Akshajzclap commited on
Commit
c3c17f1
·
verified ·
1 Parent(s): e55366b

Update app/services/image_generation_service.py

Browse files
app/services/image_generation_service.py CHANGED
@@ -1,77 +1,168 @@
 
1
  import os
2
- from google import genai
3
- from google.genai import types
4
- from io import BytesIO
5
- from app.schemas.common_schemas import ModelSchema, CameraSchema
6
- from app.core.config import settings
7
  import mimetypes
8
- import traceback
9
  import tempfile
 
 
 
 
10
 
11
- # Initialize client and model using the settings object
 
 
 
 
 
12
  if not settings.GEMINI_API_KEY:
13
  raise ValueError("GEMINI_API_KEY not found in environment or .env file")
14
 
15
  client = genai.Client(api_key=settings.GEMINI_API_KEY)
16
  IMAGE_GEN_MODEL = settings.IMAGE_GEN_MODEL
17
 
18
- def save_binary_file(filename, data):
19
- """Saves binary data to a file."""
20
- with open(filename, "wb") as f:
21
- f.write(data)
22
 
23
- async def generate_image_from_files_and_prompt(
24
- image_files: list[tuple[bytes, str]], # Expect list of (bytes, original_filename)
25
- prompt: str
26
- ) -> bytes | None:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  """
28
- Generates an image using GenAI based on exactly two input images and model/camera schemas.
29
- Returns the image data as bytes or None if an error occurs.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  """
31
- # if not client.api_key:
32
- # print("[GenAI Error] API Key not configured.")
33
- # return None
34
 
 
 
 
 
35
 
36
- temp_file_paths = []
37
- uploaded_file_infos = []
38
- parts = []
 
 
 
39
 
40
  try:
41
- for i, (img_bytes, original_filename) in enumerate(image_files):
42
- # Guess mime type from original filename, default if not guessable
43
- mime_type, _ = mimetypes.guess_type(original_filename)
 
 
44
  if not mime_type:
45
- mime_type = "application/octet-stream" # Default MIME type
46
-
47
- # Create a named temporary file to get a persistent path
48
- # Suffix helps genai identify file type, though mime_type in upload is better
49
- suffix = os.path.splitext(original_filename)[1]
50
- with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
51
- tmp_file.write(img_bytes)
52
- temp_file_paths.append(tmp_file.name)
53
-
54
- print(f"[GenAI Info] Uploading temporary file: {temp_file_paths[-1]} with MIME type: {mime_type}")
55
- # Upload the file using its path
56
- temp_file_path = temp_file_paths[-1]
57
- with open(temp_file_paths[-1], "rb") as f:
58
- uploaded_file = client.files.upload(file=temp_file_path)
59
-
60
- uploaded_file_infos.append(uploaded_file)
61
- parts.append(types.Part.from_uri(
62
- file_uri=uploaded_file.uri,
63
- mime_type=uploaded_file.mime_type, # Use mime_type from upload response
64
- ))
65
- print(f"[GenAI Info] File {original_filename} uploaded. URI: {uploaded_file.uri}")
66
-
67
-
68
- prompt = prompt
69
- parts.append(types.Part.from_text(text=prompt))
70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  contents = [types.Content(role="user", parts=parts)]
72
 
 
73
  generate_content_config = types.GenerateContentConfig(
74
- response_modalities=["IMAGE", "TEXT"], # Requesting IMAGE modality
75
  safety_settings=[
76
  types.SafetySetting(category="HARM_CATEGORY_HARASSMENT", threshold="BLOCK_NONE"),
77
  types.SafetySetting(category="HARM_CATEGORY_HATE_SPEECH", threshold="BLOCK_NONE"),
@@ -80,48 +171,93 @@ async def generate_image_from_files_and_prompt(
80
  ],
81
  )
82
 
83
- print(f"[GenAI Info] Generating content with model: {IMAGE_GEN_MODEL}")
84
  response = client.models.generate_content(
85
  model=IMAGE_GEN_MODEL,
86
  contents=contents,
87
- config=generate_content_config
88
  )
89
-
90
- # print(f"[GenAI Response]: {response}") # For debugging the whole response
91
 
92
- if response.candidates:
 
93
  for candidate in response.candidates:
94
- if candidate.content is not None and hasattr(candidate.content, "parts"):
95
- for part in candidate.content.parts:
96
- if part.inline_data and part.inline_data.data:
97
- print("[GenAI Info] Image data found in response.")
98
- return part.inline_data.data
99
- elif part.text:
100
- print(f"[GenAI Info or Warning] Text response from model: {part.text}")
101
- else:
102
- print("[GenAI Warning] Candidate content is None or missing 'parts'. Full candidate:", candidate)
103
-
104
- print("[GenAI Warning] No image data found directly in response parts.")
105
- # Fallback: Check if the text part contains an error that explains why no image was generated
106
- if response.candidates and response.candidates[0].content.parts and response.candidates[0].content.parts[0].text:
107
- pass # Already printed above
 
 
 
 
 
 
 
 
 
 
 
108
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  return None
110
 
111
- # except types.generation_types.BlockedPromptException as e:
112
- # print(f"[GenAI Error] Prompt was blocked: {e}")
113
- # return None
114
  except Exception as e:
115
- print(f"[GenAI Error] Image generation failed: {e}")
116
- # You might want to log the full traceback here for debugging
117
- traceback.print_exc()
 
 
 
 
 
118
  return None
 
119
  finally:
120
- # Clean up temporary files
121
  for path in temp_file_paths:
122
  try:
123
- os.remove(path)
124
- print(f"[GenAI Info] Deleted temporary file: {path}")
125
- except OSError as e:
126
- print(f"[GenAI Error] Failed to delete temporary file {path}: {e}")
127
-
 
 
 
 
 
 
 
 
1
+ # app/services/image_generation_service.py
2
  import os
3
+ import logging
 
 
 
 
4
  import mimetypes
5
+ import traceback
6
  import tempfile
7
+ from typing import List, Tuple, Optional
8
+
9
+ from google import genai
10
+ from google.genai import types
11
 
12
+ from app.core.config import settings
13
+
14
+ # Basic logger for this module. The app-level startup should configure logging more globally.
15
+ logger = logging.getLogger(__name__)
16
+
17
+ # Ensure the Gemini API key exists
18
  if not settings.GEMINI_API_KEY:
19
  raise ValueError("GEMINI_API_KEY not found in environment or .env file")
20
 
21
  client = genai.Client(api_key=settings.GEMINI_API_KEY)
22
  IMAGE_GEN_MODEL = settings.IMAGE_GEN_MODEL
23
 
 
 
 
 
24
 
25
+ def _safe_info(msg: str, *args, **kwargs) -> None:
26
+ """Log info but ignore low-level IO errors (BrokenPipeError) that can occur in containers."""
27
+ try:
28
+ logger.info(msg, *args, **kwargs)
29
+ except (BrokenPipeError, OSError):
30
+ # stdout/stderr might be closed by the launcher; swallow these errors
31
+ pass
32
+ except Exception:
33
+ # don't allow logging problems to interrupt business logic
34
+ pass
35
+
36
+
37
+ def _safe_error(msg: str, *args, exc: Optional[BaseException] = None, **kwargs) -> None:
38
+ """Log error and exception safely."""
39
+ try:
40
+ if exc is not None:
41
+ logger.exception(msg, *args, **kwargs)
42
+ else:
43
+ logger.error(msg, *args, **kwargs)
44
+ except (BrokenPipeError, OSError):
45
+ pass
46
+ except Exception:
47
+ pass
48
+
49
+
50
+ def _try_delete_uploaded(uploaded_obj) -> None:
51
  """
52
+ Attempt to delete a previously uploaded file from GenAI service.
53
+ This is best-effort: we handle multiple possible attribute names and suppress errors.
54
+ """
55
+ if uploaded_obj is None:
56
+ return
57
+ # Common fields that might identify the uploaded file
58
+ candidates = []
59
+ # google genai's uploaded file object may have attributes like: name, uri, id
60
+ for attr in ("name", "uri", "id", "file_id"):
61
+ val = getattr(uploaded_obj, attr, None)
62
+ if val:
63
+ candidates.append((attr, val))
64
+ # Try a few delete patterns (best-effort)
65
+ try:
66
+ # If the client provides a delete method, call it. This API surface may vary by SDK version.
67
+ if hasattr(client.files, "delete"):
68
+ try:
69
+ # If uploaded_obj has a 'name' attribute this is common for many SDKs
70
+ name = getattr(uploaded_obj, "name", None)
71
+ if name:
72
+ client.files.delete(name)
73
+ _safe_info(f"[GenAI Info] Deleted uploaded file by name: {name}")
74
+ return
75
+ except Exception:
76
+ # continue to other attempts
77
+ pass
78
+
79
+ # If we have a URI with a resource name, try to request deletion via client.files.delete with URI
80
+ uri = getattr(uploaded_obj, "uri", None)
81
+ if uri:
82
+ try:
83
+ client.files.delete(uri)
84
+ _safe_info(f"[GenAI Info] Deleted uploaded file by uri: {uri}")
85
+ return
86
+ except Exception:
87
+ pass
88
+
89
+ # Last resort: try deleting by id if present
90
+ file_id = getattr(uploaded_obj, "id", None) or getattr(uploaded_obj, "file_id", None)
91
+ if file_id:
92
+ try:
93
+ client.files.delete(file_id)
94
+ _safe_info(f"[GenAI Info] Deleted uploaded file by id: {file_id}")
95
+ return
96
+ except Exception:
97
+ pass
98
+
99
+ except Exception as e:
100
+ _safe_error(f"[GenAI Error] Exception while attempting to delete uploaded file {uploaded_obj}: {e}", exc=e)
101
+
102
+
103
+ async def generate_image_from_files_and_prompt(
104
+ image_files: List[Tuple[bytes, str]], # list of (bytes, original_filename)
105
+ prompt: str,
106
+ ) -> Optional[bytes]:
107
  """
108
+ Uploads provided image bytes to GenAI, requests an image generation with the prompt,
109
+ and returns the generated image bytes (or None on failure).
 
110
 
111
+ Args:
112
+ image_files: list of tuples (file_bytes, original_filename).
113
+ The function expects at least one image; two images is common for 'replace' flows.
114
+ prompt: textual prompt to guide generation.
115
 
116
+ Returns:
117
+ bytes of generated image (if present) or None on failure.
118
+ """
119
+ temp_file_paths: List[str] = []
120
+ uploaded_file_infos: List[object] = []
121
+ parts: List[types.Part] = []
122
 
123
  try:
124
+ # 1) Write incoming bytes to persistent temp files and upload them
125
+ for img_bytes, original_filename in image_files:
126
+ # Determine suffix and mime type
127
+ suffix = os.path.splitext(original_filename or "")[1] or ""
128
+ mime_type, _ = mimetypes.guess_type(original_filename or "")
129
  if not mime_type:
130
+ mime_type = "application/octet-stream"
131
+
132
+ # create a temp file and persist its path (we'll cleanup in finally)
133
+ with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
134
+ tmp.write(img_bytes)
135
+ tmp_path = tmp.name
136
+ temp_file_paths.append(tmp_path)
137
+ _safe_info("[GenAI Info] Created temporary file: %s (mime: %s)", tmp_path, mime_type)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
 
139
+ # Upload using client.files.upload. Keep the uploaded file info for potential cleanup
140
+ try:
141
+ # The SDK may accept either a path string or a file-like object.
142
+ uploaded = client.files.upload(file=tmp_path)
143
+ uploaded_file_infos.append(uploaded)
144
+ # The uploaded object may have attributes like 'uri' and 'mime_type'
145
+ uri = getattr(uploaded, "uri", None)
146
+ uploaded_mime = getattr(uploaded, "mime_type", mime_type)
147
+ parts.append(
148
+ types.Part.from_uri(
149
+ file_uri=uri if uri else tmp_path, # fallback to path if SDK didn't return uri
150
+ mime_type=uploaded_mime or mime_type,
151
+ )
152
+ )
153
+ _safe_info("[GenAI Info] Uploaded file %s -> uri=%s", tmp_path, uri)
154
+ except Exception as e:
155
+ _safe_error(f"[GenAI Error] Upload failed for {tmp_path}: {e}", exc=e)
156
+ # continue loop so we still try to clean up temp files; then fail
157
+ raise
158
+
159
+ # 2) Add prompt as part
160
+ parts.append(types.Part.from_text(text=prompt))
161
  contents = [types.Content(role="user", parts=parts)]
162
 
163
+ # 3) Prepare config (safety settings preserved)
164
  generate_content_config = types.GenerateContentConfig(
165
+ response_modalities=["IMAGE", "TEXT"],
166
  safety_settings=[
167
  types.SafetySetting(category="HARM_CATEGORY_HARASSMENT", threshold="BLOCK_NONE"),
168
  types.SafetySetting(category="HARM_CATEGORY_HATE_SPEECH", threshold="BLOCK_NONE"),
 
171
  ],
172
  )
173
 
174
+ _safe_info("[GenAI Info] Requesting generation with model=%s", IMAGE_GEN_MODEL)
175
  response = client.models.generate_content(
176
  model=IMAGE_GEN_MODEL,
177
  contents=contents,
178
+ config=generate_content_config,
179
  )
 
 
180
 
181
+ # 4) Parse response - look for image bytes in candidate parts
182
+ if getattr(response, "candidates", None):
183
  for candidate in response.candidates:
184
+ content = getattr(candidate, "content", None)
185
+ if not content:
186
+ _safe_info("[GenAI Warning] Candidate has no content; skipping.")
187
+ continue
188
+ parts_list = getattr(content, "parts", []) or []
189
+ for part in parts_list:
190
+ # inline_data.data is where binary bytes typically live in SDK responses
191
+ inline_data = getattr(part, "inline_data", None)
192
+ if inline_data is not None:
193
+ data_bytes = getattr(inline_data, "data", None)
194
+ if data_bytes:
195
+ _safe_info("[GenAI Info] Found image bytes in response; returning bytes.")
196
+ return data_bytes
197
+ # alternate: some SDKs return base64-encoded strings in text fields; handle common fallbacks
198
+ text_val = getattr(part, "text", None)
199
+ if text_val and isinstance(text_val, str) and text_val.startswith("data:image/"):
200
+ # data URL -> decode base64 portion
201
+ try:
202
+ header, b64 = text_val.split(",", 1)
203
+ import base64
204
+ data_bytes = base64.b64decode(b64)
205
+ _safe_info("[GenAI Info] Extracted image bytes from data URL in text part.")
206
+ return data_bytes
207
+ except Exception:
208
+ _safe_error("[GenAI Error] Failed to decode data URL in text part.", exc=None)
209
 
210
+ # If part has 'uri' that points to a generated asset, attempt to fetch it (best-effort)
211
+ part_uri = getattr(part, "uri", None)
212
+ if part_uri:
213
+ try:
214
+ # Use client.files.get or HTTP fetch as available. We'll try client.files.get if present.
215
+ if hasattr(client.files, "get"):
216
+ fetched = client.files.get(part_uri)
217
+ # fetched could include bytes in different fields; try common ones
218
+ data_bytes = getattr(fetched, "data", None) or getattr(fetched, "content", None)
219
+ if data_bytes:
220
+ _safe_info("[GenAI Info] Fetched generated image from part.uri via client.files.get")
221
+ return data_bytes
222
+ else:
223
+ # Fall back to HTTP GET if the URI is an http(s) link
224
+ if str(part_uri).startswith("http"):
225
+ import requests as _req
226
+ r = _req.get(part_uri, timeout=10)
227
+ if r.status_code == 200:
228
+ _safe_info("[GenAI Info] Fetched generated image from HTTP URI in part.")
229
+ return r.content
230
+ except Exception as e:
231
+ _safe_error(f"[GenAI Error] Failed to fetch part URI {part_uri}: {e}", exc=e)
232
+
233
+ # If we reach here no image bytes found
234
+ _safe_info("[GenAI Warning] No image bytes found in response candidates; returning None.")
235
  return None
236
 
 
 
 
237
  except Exception as e:
238
+ # Log full traceback safely for debugging
239
+ _safe_error(f"[GenAI Error] Image generation failed: {e}", exc=e)
240
+ try:
241
+ # print stack trace to logger (safe wrapper)
242
+ tb = traceback.format_exc()
243
+ _safe_error(f"[GenAI Error] Traceback:\n{tb}")
244
+ except Exception:
245
+ pass
246
  return None
247
+
248
  finally:
249
+ # Cleanup temporary files created locally
250
  for path in temp_file_paths:
251
  try:
252
+ if path and os.path.exists(path):
253
+ os.remove(path)
254
+ _safe_info("[GenAI Info] Deleted temporary file: %s", path)
255
+ except Exception as e:
256
+ _safe_error(f"[GenAI Error] Failed to delete temporary file {path}: {e}", exc=e)
257
+
258
+ # Attempt to delete uploaded files from GenAI (best-effort)
259
+ for uploaded in uploaded_file_infos:
260
+ try:
261
+ _try_delete_uploaded(uploaded)
262
+ except Exception as e:
263
+ _safe_error(f"[GenAI Error] Failed to delete uploaded file record {uploaded}: {e}", exc=e)