Spaces:
Running
on
Zero
Running
on
Zero
| """ | |
| Gemini API Client | |
| ================= | |
| Client for Google Gemini Image APIs (Flash and Pro models). | |
| Handles API communication and response parsing. | |
| """ | |
| import base64 | |
| import logging | |
| from io import BytesIO | |
| from typing import Optional | |
| from PIL import Image | |
| from google import genai | |
| from google.genai import types | |
| from .models import GenerationRequest, GenerationResult | |
| logger = logging.getLogger(__name__) | |
class GeminiClient:
    """
    Client for Gemini Image APIs.

    Supports:
    - Gemini 2.5 Flash Image (up to ~3 reference images)
    - Gemini 3 Pro Image Preview (up to 14 reference images, 1K/2K/4K)

    The model is selected once, at construction time, via ``use_pro_model``.
    ``generate`` catches every exception raised while building or sending
    the request and reports it as ``GenerationResult.error_result(...)``.
    """

    # Model names (updated January 2026)
    # See: https://ai.google.dev/gemini-api/docs/image-generation
    MODEL_FLASH = "gemini-2.5-flash-image"  # Fast, efficient image generation
    MODEL_PRO = "gemini-3-pro-image-preview"  # Pro quality, advanced text rendering

    # Valid resolutions for Pro model
    VALID_RESOLUTIONS = ["1K", "2K", "4K"]

    # Aspect ratio to (width, height) mapping, consulted by get_dimensions().
    ASPECT_RATIOS = {
        "1:1": (1024, 1024),
        "16:9": (1344, 768),
        "9:16": (768, 1344),
        "21:9": (1536, 640),  # Cinematic ultra-wide
        "3:2": (1248, 832),
        "2:3": (832, 1248),
        "3:4": (864, 1184),
        "4:3": (1344, 1008),
        "4:5": (1024, 1280),
        "5:4": (1280, 1024),
    }

    def __init__(self, api_key: str, use_pro_model: bool = False):
        """
        Initialize Gemini client.

        Args:
            api_key: Google Gemini API key
            use_pro_model: If True, use Pro model with enhanced capabilities

        Raises:
            ValueError: If ``api_key`` is empty or None.
        """
        if not api_key:
            raise ValueError("API key is required for Gemini client")

        self.api_key = api_key
        self.use_pro_model = use_pro_model
        self.client = genai.Client(api_key=api_key)

        model_name = self.MODEL_PRO if use_pro_model else self.MODEL_FLASH
        logger.info(f"GeminiClient initialized with model: {model_name}")

    def generate(
        self,
        request: GenerationRequest,
        resolution: str = "1K"
    ) -> GenerationResult:
        """
        Generate image using Gemini API.

        Args:
            request: GenerationRequest object
            resolution: Resolution for Pro model ("1K", "2K", "4K");
                ignored when the Flash model is in use.

        Returns:
            GenerationResult object. Failures (including any exception
            raised by the SDK call) are returned as error results rather
            than propagated.
        """
        try:
            model_name = self.MODEL_PRO if self.use_pro_model else self.MODEL_FLASH
            logger.info(f"Generating with {model_name}: {request.prompt[:100]}...")

            # Reference images (if any) followed by the text prompt.
            contents = self._build_contents(request)

            # Resolution is only forwarded for the Pro model.
            config = self._build_config(
                request,
                resolution if self.use_pro_model else None
            )

            response = self.client.models.generate_content(
                model=model_name,
                contents=contents,
                config=config
            )

            return self._parse_response(response)
        except Exception as e:
            # Top-level boundary: log with traceback, surface as an error result.
            logger.error(f"Gemini generation failed: {e}", exc_info=True)
            return GenerationResult.error_result(f"Gemini API error: {str(e)}")

    def _build_contents(self, request: GenerationRequest) -> list:
        """Build contents list for API request: non-None input images, then the prompt."""
        contents = []
        if request.has_input_images:
            contents.extend(img for img in request.input_images if img is not None)
        contents.append(request.prompt)
        return contents

    @staticmethod
    def _normalize_aspect_ratio(aspect_ratio: str) -> str:
        """
        Return the leading whitespace-delimited token of ``aspect_ratio``.

        Lets callers pass a ratio followed by extra text (for example
        "16:9 (Widescreen)") and still obtain the bare "16:9". A string
        that contains no token at all is returned unchanged instead of
        raising IndexError.
        """
        tokens = aspect_ratio.split()
        return tokens[0] if tokens else aspect_ratio

    def _build_config(
        self,
        request: GenerationRequest,
        resolution: Optional[str] = None
    ) -> types.GenerateContentConfig:
        """
        Build generation config for API request.

        Args:
            request: Supplies ``aspect_ratio`` and ``temperature``.
            resolution: Output resolution; only applied when the Pro model
                is in use. Values outside VALID_RESOLUTIONS fall back to "1K".

        Returns:
            A GenerateContentConfig requesting image and text modalities.
        """
        image_config_kwargs = {
            "aspect_ratio": self._normalize_aspect_ratio(request.aspect_ratio)
        }

        # Add resolution for Pro model
        if resolution and self.use_pro_model:
            if resolution not in self.VALID_RESOLUTIONS:
                logger.warning(f"Invalid resolution '{resolution}', defaulting to '1K'")
                resolution = "1K"
            # NOTE(review): the public Gemini image-generation docs appear to
            # name this ImageConfig field `image_size`; confirm that
            # `output_image_resolution` is accepted by the installed
            # google-genai version.
            image_config_kwargs["output_image_resolution"] = resolution
            logger.info(f"Pro model resolution: {resolution}")

        return types.GenerateContentConfig(
            temperature=request.temperature,
            response_modalities=["image", "text"],
            image_config=types.ImageConfig(**image_config_kwargs)
        )

    def _parse_response(self, response) -> GenerationResult:
        """
        Parse API response and extract image.

        Only the first candidate is inspected. The first part carrying
        ``inline_data`` decides the outcome: it is either decoded into a
        PIL image (success) or reported as a decoding error.
        """
        if response is None:
            return GenerationResult.error_result("No response from API")

        if not getattr(response, 'candidates', None):
            return GenerationResult.error_result("No candidates in response")

        candidate = response.candidates[0]

        # Check finish reason
        if hasattr(candidate, 'finish_reason'):
            finish_reason = str(candidate.finish_reason)
            logger.info(f"Finish reason: {finish_reason}")
            if 'SAFETY' in finish_reason or 'PROHIBITED' in finish_reason:
                return GenerationResult.error_result(
                    f"Content blocked by safety filters: {finish_reason}"
                )

        # Check for content
        content = getattr(candidate, 'content', None)
        if content is None:
            finish_reason = getattr(candidate, 'finish_reason', 'UNKNOWN')
            return GenerationResult.error_result(
                f"No content in response (finish_reason: {finish_reason})"
            )

        # Extract image from parts
        for part in getattr(content, 'parts', None) or []:
            inline_data = getattr(part, 'inline_data', None)
            if not inline_data:
                continue  # e.g. a text-only part
            try:
                image_data = inline_data.data
                # Handle both bytes and base64 string
                if isinstance(image_data, str):
                    image_data = base64.b64decode(image_data)

                image = Image.open(BytesIO(image_data))
                # Force a full decode now so corrupt data fails here,
                # inside the try, rather than later in the caller.
                image.load()
            except Exception as e:
                logger.error(f"Failed to decode image: {e}")
                return GenerationResult.error_result(
                    f"Image decoding error: {str(e)}"
                )

            logger.info(f"Image generated: {image.size}, {image.mode}")
            return GenerationResult.success_result(
                image=image,
                message="Generated successfully"
            )

        return GenerationResult.error_result("No image data in response")

    def is_healthy(self) -> bool:
        """
        Report whether the client has a non-empty API key configured.

        This is a local check only; no request is sent to the API.
        """
        return bool(self.api_key)

    @classmethod
    def get_dimensions(cls, aspect_ratio: str) -> tuple:
        """
        Get pixel dimensions for aspect ratio.

        Args:
            aspect_ratio: A ratio such as "16:9", optionally followed by
                extra whitespace-separated text, which is ignored.

        Returns:
            The (width, height) tuple from ASPECT_RATIOS, or (1024, 1024)
            when the ratio is not in the table.
        """
        ratio = cls._normalize_aspect_ratio(aspect_ratio)
        return cls.ASPECT_RATIOS.get(ratio, (1024, 1024))