AkashKumarave commited on
Commit
84a350d
·
verified ·
1 Parent(s): 6950f14

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +79 -24
app.py CHANGED
@@ -81,13 +81,15 @@ def validate_image(image_content: bytes):
81
  img = Image.open(io.BytesIO(image_content))
82
  if img.format not in ["PNG", "JPEG"]:
83
  raise HTTPException(status_code=400, detail="Only PNG or JPEG images are supported")
 
84
  return True, img.format.lower()
85
  except Exception as e:
 
86
  raise HTTPException(status_code=400, detail=f"Image validation error: {str(e)}")
87
 
88
  # ===== API FUNCTIONS =====
89
  def create_multi_image_task(subject_images: List[bytes], prompt: str):
90
- """Create image generation task with Gemini API (exactly two images)"""
91
  headers = {
92
  "Content-Type": "application/json"
93
  }
@@ -103,31 +105,62 @@ def create_multi_image_task(subject_images: List[bytes], prompt: str):
103
  }
104
  })
105
 
 
 
 
106
  payload = {
107
  "contents": [
108
  {
109
  "parts": [
110
- {"text": prompt},
111
  *subject_image_list
112
  ]
113
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  ]
115
  }
116
 
117
- try:
118
- logger.info(f"Sending request to Gemini API with payload: {payload}")
119
- response = requests.post(CREATE_TASK_ENDPOINT, json=payload, headers=headers)
120
- response.raise_for_status()
121
- data = response.json()
122
- logger.info(f"API response: {data}")
123
- if not data.get("candidates") or not data["candidates"][0].get("content"):
124
- raise HTTPException(status_code=500, detail="No valid content returned from API")
125
- return data
126
- except requests.exceptions.RequestException as e:
127
- logger.error(f"API request failed: {str(e)}")
128
- if hasattr(e, 'response') and e.response:
129
- logger.error(f"API response: {e.response.text}")
130
- raise HTTPException(status_code=500, detail=f"API Error: {str(e)}")
 
 
 
 
 
 
 
131
 
132
  # ===== MAIN PROCESSING =====
133
  async def generate_image(subject_images: List[bytes], prompt: str):
@@ -159,21 +192,43 @@ async def generate_image(subject_images: List[bytes], prompt: str):
159
 
160
  parts = candidate["content"]["parts"]
161
  logger.info(f"Response parts: {parts}")
162
- # Find the part with inline_data
163
  image_base64 = None
 
 
164
  for part in parts:
165
  if "inline_data" in part and "data" in part["inline_data"]:
166
  image_base64 = part["inline_data"]["data"]
167
- break
 
 
 
 
168
  elif "text" in part:
169
- logger.info(f"Text part found: {part['text']}")
 
 
 
 
 
 
 
 
170
 
171
- if not image_base64:
172
- logger.error(f"No inline_data found in response parts: {parts}")
173
- raise HTTPException(status_code=500, detail="No inline_data found in API response")
 
 
 
 
 
 
 
 
 
 
174
 
175
- # Decode and save the image
176
- image_data = base64.b64decode(image_base64)
177
  output_dir = Path("/tmp")
178
  output_dir.mkdir(exist_ok=True)
179
  output_path = output_dir / f"gemini_output_{int(time.time())}.png"
 
81
  img = Image.open(io.BytesIO(image_content))
82
  if img.format not in ["PNG", "JPEG"]:
83
  raise HTTPException(status_code=400, detail="Only PNG or JPEG images are supported")
84
+ logger.info(f"Validated image: format={img.format}, size={size_mb:.2f}MB")
85
  return True, img.format.lower()
86
  except Exception as e:
87
+ logger.error(f"Image validation error: {str(e)}")
88
  raise HTTPException(status_code=400, detail=f"Image validation error: {str(e)}")
89
 
90
  # ===== API FUNCTIONS =====
91
  def create_multi_image_task(subject_images: List[bytes], prompt: str):
92
+ """Create image generation task with Gemini API (up to two images)"""
93
  headers = {
94
  "Content-Type": "application/json"
95
  }
 
105
  }
106
  })
107
 
108
+ # Use a more descriptive prompt structure as per documentation
109
+ enhanced_prompt = f"A photorealistic composition combining elements from the provided images: {prompt}. Ensure the scene is cohesive, with soft, natural lighting and a balanced aspect ratio of 16:9."
110
+
111
  payload = {
112
  "contents": [
113
  {
114
  "parts": [
115
+ {"text": enhanced_prompt},
116
  *subject_image_list
117
  ]
118
  }
119
+ ],
120
+ "generationConfig": {
121
+ "response_mime_type": "image/png"
122
+ },
123
+ "safetySettings": [
124
+ {
125
+ "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
126
+ "threshold": "BLOCK_NONE"
127
+ },
128
+ {
129
+ "category": "HARM_CATEGORY_HATE_SPEECH",
130
+ "threshold": "BLOCK_NONE"
131
+ },
132
+ {
133
+ "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
134
+ "threshold": "BLOCK_NONE"
135
+ },
136
+ {
137
+ "category": "HARM_CATEGORY_HARASSMENT",
138
+ "threshold": "BLOCK_NONE"
139
+ }
140
  ]
141
  }
142
 
143
+ max_retries = 1
144
+ for attempt in range(max_retries + 1):
145
+ try:
146
+ logger.info(f"Sending request to Gemini API (attempt {attempt + 1}): {payload}")
147
+ response = requests.post(CREATE_TASK_ENDPOINT, json=payload, headers=headers)
148
+ response.raise_for_status()
149
+ data = response.json()
150
+ logger.info(f"API response: {data}")
151
+ if "safetyRatings" in data:
152
+ logger.info(f"Safety ratings: {data['safetyRatings']}")
153
+ if not data.get("candidates") or not data["candidates"][0].get("content"):
154
+ raise HTTPException(status_code=500, detail="No valid content returned from API")
155
+ return data
156
+ except requests.exceptions.RequestException as e:
157
+ logger.error(f"API request failed: {str(e)}")
158
+ if hasattr(e, 'response') and e.response:
159
+ logger.error(f"API response: {e.response.text}")
160
+ if e.response.status_code in [429, 500] and attempt < max_retries:
161
+ time.sleep(2 ** attempt) # Exponential backoff
162
+ continue
163
+ raise HTTPException(status_code=500, detail=f"API Error: {str(e)}")
164
 
165
  # ===== MAIN PROCESSING =====
166
  async def generate_image(subject_images: List[bytes], prompt: str):
 
192
 
193
  parts = candidate["content"]["parts"]
194
  logger.info(f"Response parts: {parts}")
195
+ # Find the part with inline_data or file_uri
196
  image_base64 = None
197
+ file_uri = None
198
+ text_response = None
199
  for part in parts:
200
  if "inline_data" in part and "data" in part["inline_data"]:
201
  image_base64 = part["inline_data"]["data"]
202
+ if not image_base64:
203
+ logger.warning("Empty inline_data.data received")
204
+ elif "fileUri" in part:
205
+ file_uri = part["fileUri"]
206
+ logger.info(f"File URI found: {file_uri}")
207
  elif "text" in part:
208
+ text_response = part["text"]
209
+ logger.info(f"Text part found: {text_response}")
210
+
211
+ if not image_base64 and not file_uri:
212
+ error_detail = text_response or "No image data (inline_data or fileUri) found in API response"
213
+ if image_base64 == "":
214
+ error_detail = f"Empty inline_data.data returned by API: {text_response or 'No additional details'}"
215
+ logger.error(f"No image data in response parts: {parts}")
216
+ raise HTTPException(status_code=500, detail=f"API error: {error_detail}")
217
 
218
+ if file_uri:
219
+ # Download image from file_uri
220
+ logger.info(f"Downloading image from {file_uri}")
221
+ response = requests.get(file_uri)
222
+ response.raise_for_status()
223
+ image_data = response.content
224
+ else:
225
+ # Decode base64 image
226
+ try:
227
+ image_data = base64.b64decode(image_base64)
228
+ except Exception as e:
229
+ logger.error(f"Failed to decode base64 image: {str(e)}")
230
+ raise HTTPException(status_code=500, detail=f"Failed to decode image data: {str(e)}")
231
 
 
 
232
  output_dir = Path("/tmp")
233
  output_dir.mkdir(exist_ok=True)
234
  output_path = output_dir / f"gemini_output_{int(time.time())}.png"