# VTO / app/services/image_generation_service.py
# Akshajzclap's picture
# Update app/services/image_generation_service.py
# c3c17f1 verified
# app/services/image_generation_service.py
import os
import logging
import mimetypes
import traceback
import tempfile
from typing import List, Tuple, Optional
from google import genai
from google.genai import types
from app.core.config import settings
# Basic logger for this module. The app-level startup should configure logging more globally.
logger = logging.getLogger(__name__)

# Ensure the Gemini API key exists.
# This check runs at import time, so importing this module raises ValueError
# when settings.GEMINI_API_KEY is empty or missing.
if not settings.GEMINI_API_KEY:
    raise ValueError("GEMINI_API_KEY not found in environment or .env file")

# Single module-level GenAI client, shared by every function in this module.
client = genai.Client(api_key=settings.GEMINI_API_KEY)
# Model identifier passed as `model=` to client.models.generate_content; read from settings.
IMAGE_GEN_MODEL = settings.IMAGE_GEN_MODEL
def _safe_info(msg: str, *args, **kwargs) -> None:
    """Emit an INFO record on the module logger without ever raising.

    Any exception raised while logging is discarded. That includes the
    BrokenPipeError / OSError seen when a container launcher has already
    closed stdout/stderr. Logging must never interrupt business logic.

    Args:
        msg: Log message, optionally a %-style format string.
        *args: Positional format arguments forwarded to ``logger.info``.
        **kwargs: Keyword arguments forwarded to ``logger.info``.
    """
    try:
        logger.info(msg, *args, **kwargs)
    except Exception:
        # BrokenPipeError and OSError are Exception subclasses, so this single
        # handler covers the closed-stream case as well as everything else.
        return
def _safe_error(msg: str, *args, exc: Optional[BaseException] = None, **kwargs) -> None:
    """Emit an ERROR record on the module logger without ever raising.

    When ``exc`` is given, its traceback is attached to the record explicitly.
    ``logger.exception`` is not used because it reads ``sys.exc_info()``, which
    is empty when this helper is called outside an ``except`` block; the
    supplied exception would then be dropped from the log.

    Args:
        msg: Log message, optionally a %-style format string.
        *args: Positional format arguments forwarded to ``logger.error``.
        exc: Exception whose type, value and traceback should be logged.
        **kwargs: Keyword arguments forwarded to ``logger.error``. An explicit
            ``exc_info`` passed by the caller takes precedence over ``exc``.
    """
    try:
        if exc is not None:
            # Build the tuple ourselves so the traceback is attached even for
            # exception classes whose instances evaluate as falsy.
            kwargs.setdefault("exc_info", (type(exc), exc, exc.__traceback__))
        logger.error(msg, *args, **kwargs)
    except Exception:
        # Covers BrokenPipeError/OSError from closed streams too; logging
        # problems must never interrupt business logic.
        return
def _try_delete_uploaded(uploaded_obj) -> None:
    """Best-effort deletion of a previously uploaded file from the GenAI service.

    The identifiers found on ``uploaded_obj`` are tried in order (``name``,
    then ``uri``, then ``id``/``file_id``) until one deletion succeeds. The
    function never raises: every failure is logged and suppressed, because a
    failed cleanup must not fail the request.

    Args:
        uploaded_obj: Object returned by ``client.files.upload``, or None
            (in which case nothing happens).
    """
    if uploaded_obj is None:
        return

    try:
        delete = getattr(client.files, "delete", None)
        if delete is None:
            # This SDK version exposes no delete API; nothing we can do.
            return

        identifiers = (
            ("name", getattr(uploaded_obj, "name", None)),
            ("uri", getattr(uploaded_obj, "uri", None)),
            (
                "id",
                getattr(uploaded_obj, "id", None) or getattr(uploaded_obj, "file_id", None),
            ),
        )

        attempted = False
        for label, value in identifiers:
            if not value:
                continue
            attempted = True
            try:
                # `name` is keyword-only in google-genai's Files.delete. A
                # positional call raises TypeError, so no file would ever be
                # deleted.
                delete(name=value)
            except Exception as e:
                _safe_info(
                    "[GenAI Warning] Delete by %s failed for %s: %s", label, value, e
                )
                continue
            _safe_info(f"[GenAI Info] Deleted uploaded file by {label}: {value}")
            return

        if attempted:
            _safe_error(
                "[GenAI Error] Could not delete uploaded file %s with any known identifier",
                uploaded_obj,
            )
    except Exception as e:
        _safe_error(
            f"[GenAI Error] Exception while attempting to delete uploaded file {uploaded_obj}: {e}",
            exc=e,
        )
def _decode_image_data_url(text_val: str) -> Optional[bytes]:
    """Decode the base64 payload of a ``data:image/...`` URL; return None if malformed."""
    import base64

    try:
        _header, b64_payload = text_val.split(",", 1)
        # binascii.Error (bad base64) is a ValueError subclass, as is the
        # unpacking failure when the URL contains no comma.
        return base64.b64decode(b64_payload)
    except ValueError as e:
        _safe_error("[GenAI Error] Failed to decode data URL in text part: %s", e, exc=e)
        return None


def _fetch_image_from_uri(part_uri) -> Optional[bytes]:
    """Best-effort fetch of a generated asset referenced by URI; return None on any failure."""
    try:
        if hasattr(client.files, "get"):
            # `name` is keyword-only in google-genai's Files.get.
            fetched = client.files.get(name=part_uri)
            # The fetched object could carry bytes in different fields; try common ones.
            data_bytes = getattr(fetched, "data", None) or getattr(fetched, "content", None)
            if data_bytes:
                _safe_info("[GenAI Info] Fetched generated image from part.uri via client.files.get")
                return data_bytes
        elif str(part_uri).startswith("http"):
            # Imported lazily: only needed on this fallback path.
            import requests as _req

            r = _req.get(part_uri, timeout=10)
            if r.status_code == 200:
                _safe_info("[GenAI Info] Fetched generated image from HTTP URI in part.")
                return r.content
    except Exception as e:
        _safe_error("[GenAI Error] Failed to fetch part URI %s: %s", part_uri, e, exc=e)
    return None


def _image_bytes_from_part(part) -> Optional[bytes]:
    """Return image bytes carried by one response part, or None if it has none."""
    # inline_data.data is where binary bytes typically live in SDK responses.
    inline_data = getattr(part, "inline_data", None)
    if inline_data is not None:
        data_bytes = getattr(inline_data, "data", None)
        if data_bytes:
            _safe_info("[GenAI Info] Found image bytes in response; returning bytes.")
            return data_bytes

    # Fallback: a base64 data URL embedded in a text part.
    text_val = getattr(part, "text", None)
    if isinstance(text_val, str) and text_val.startswith("data:image/"):
        decoded = _decode_image_data_url(text_val)
        if decoded:
            _safe_info("[GenAI Info] Extracted image bytes from data URL in text part.")
            return decoded

    # Fallback: a URI that points to a generated asset.
    part_uri = getattr(part, "uri", None)
    if part_uri:
        return _fetch_image_from_uri(part_uri)
    return None


def _extract_image_bytes(response) -> Optional[bytes]:
    """Return the first image payload found in the response candidates, or None."""
    for candidate in getattr(response, "candidates", None) or []:
        content = getattr(candidate, "content", None)
        if not content:
            _safe_info("[GenAI Warning] Candidate has no content; skipping.")
            continue
        for part in getattr(content, "parts", None) or []:
            image_bytes = _image_bytes_from_part(part)
            if image_bytes:
                return image_bytes
    return None


def _upload_image_as_part(
    img_bytes: bytes,
    original_filename: str,
    temp_file_paths: List[str],
    uploaded_file_infos: List[object],
) -> "types.Part":
    """Write one image to a temp file, upload it, and return a Part referencing the upload.

    The temp path and the upload result are appended to the caller-owned lists
    as soon as they exist, so the caller's cleanup covers them even if a later
    step raises.
    """
    filename = original_filename or ""
    suffix = os.path.splitext(filename)[1]
    mime_type = mimetypes.guess_type(filename)[0] or "application/octet-stream"

    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
        tmp_path = tmp.name
        # Record the path before writing so a failed write() cannot leak the file.
        temp_file_paths.append(tmp_path)
        tmp.write(img_bytes)
    _safe_info("[GenAI Info] Created temporary file: %s (mime: %s)", tmp_path, mime_type)

    try:
        uploaded = client.files.upload(file=tmp_path)
    except Exception as e:
        _safe_error("[GenAI Error] Upload failed for %s: %s", tmp_path, e, exc=e)
        raise
    uploaded_file_infos.append(uploaded)

    uri = getattr(uploaded, "uri", None)
    if not uri:
        # A local temp path is not a URI the service can read, so fail loudly
        # instead of sending a request that cannot succeed.
        raise RuntimeError(f"Upload of {tmp_path} returned no file URI")

    uploaded_mime = getattr(uploaded, "mime_type", None) or mime_type
    _safe_info("[GenAI Info] Uploaded file %s -> uri=%s", tmp_path, uri)
    return types.Part.from_uri(file_uri=uri, mime_type=uploaded_mime)


def _generate_image_blocking(
    image_files: List[Tuple[bytes, str]],
    prompt: str,
) -> Optional[bytes]:
    """Synchronous upload -> generate -> parse pipeline; returns None on any failure."""
    temp_file_paths: List[str] = []
    uploaded_file_infos: List[object] = []
    try:
        # 1) Upload every input image and reference it as a content part.
        parts = [
            _upload_image_as_part(img_bytes, original_filename, temp_file_paths, uploaded_file_infos)
            for img_bytes, original_filename in image_files
        ]

        # 2) The prompt goes last, after the image parts.
        parts.append(types.Part.from_text(text=prompt))
        contents = [types.Content(role="user", parts=parts)]

        # 3) Request both image and text modalities; all four harm categories
        #    are set to BLOCK_NONE (safety settings preserved from the original).
        generate_content_config = types.GenerateContentConfig(
            response_modalities=["IMAGE", "TEXT"],
            safety_settings=[
                types.SafetySetting(category="HARM_CATEGORY_HARASSMENT", threshold="BLOCK_NONE"),
                types.SafetySetting(category="HARM_CATEGORY_HATE_SPEECH", threshold="BLOCK_NONE"),
                types.SafetySetting(category="HARM_CATEGORY_SEXUALLY_EXPLICIT", threshold="BLOCK_NONE"),
                types.SafetySetting(category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="BLOCK_NONE"),
            ],
        )
        _safe_info("[GenAI Info] Requesting generation with model=%s", IMAGE_GEN_MODEL)
        response = client.models.generate_content(
            model=IMAGE_GEN_MODEL,
            contents=contents,
            config=generate_content_config,
        )

        # 4) Pull the first image payload out of the response.
        image_bytes = _extract_image_bytes(response)
        if image_bytes is None:
            _safe_info("[GenAI Warning] No image bytes found in response candidates; returning None.")
        return image_bytes
    except Exception as e:
        # One call is enough: _safe_error attaches the traceback when `exc` is given.
        _safe_error("[GenAI Error] Image generation failed: %s", e, exc=e)
        return None
    finally:
        # Remove local temp files.
        for path in temp_file_paths:
            try:
                os.remove(path)
                _safe_info("[GenAI Info] Deleted temporary file: %s", path)
            except FileNotFoundError:
                pass  # already gone; nothing to clean up
            except OSError as e:
                _safe_error("[GenAI Error] Failed to delete temporary file %s: %s", path, e, exc=e)
        # Remove uploaded copies from the GenAI service (best-effort).
        for uploaded in uploaded_file_infos:
            try:
                _try_delete_uploaded(uploaded)
            except Exception as e:
                _safe_error(
                    "[GenAI Error] Failed to delete uploaded file record %s: %s", uploaded, e, exc=e
                )


async def generate_image_from_files_and_prompt(
    image_files: List[Tuple[bytes, str]],  # list of (bytes, original_filename)
    prompt: str,
) -> Optional[bytes]:
    """
    Upload the given images to GenAI, request an image generation guided by
    ``prompt``, and return the generated image bytes.

    The GenAI client calls used here are synchronous, so the whole pipeline
    runs in the event loop's default executor. This keeps the loop free to
    serve other requests while the upload and generation are in flight.

    Args:
        image_files: list of tuples (file_bytes, original_filename). The
            filename is only used to derive the temp-file suffix and a MIME
            type guess. Two images is common for 'replace' flows.
        prompt: textual prompt to guide generation.

    Returns:
        bytes of the generated image, or None if no image was produced or any
        step failed (failures are logged, never raised).
    """
    import asyncio  # local import: the module's import block does not include asyncio

    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, _generate_image_blocking, image_files, prompt)