Spaces:
Sleeping
Sleeping
| """ | |
| Batch LLM Client for cost-effective processing using Gemini Batch API. | |
| This client provides 50% cost savings by using Google's Gemini Batch API | |
| instead of real-time API calls. Ideal for large-scale prompt optimization | |
| where latency is acceptable. | |
| Features: | |
| - 50% cost reduction compared to standard API | |
| - Automatic batching and job management | |
| - Built-in retry and polling logic | |
| - Thread-safe operation | |
| - Comprehensive error handling | |
| Author: GEPA Optimizer Team | |
| """ | |
| import os | |
| import json | |
| import time | |
| import logging | |
| import tempfile | |
| import io | |
| from pathlib import Path | |
| from typing import Dict, List, Any, Optional, Tuple | |
| from .base_llm import BaseLLMClient | |
| try: | |
| from PIL import Image | |
| PIL_AVAILABLE = True | |
| except ImportError: | |
| PIL_AVAILABLE = False | |
| Image = None | |
| try: | |
| from google import genai | |
| from google.genai import types | |
| GENAI_AVAILABLE = True | |
| except ImportError: | |
| GENAI_AVAILABLE = False | |
| genai = None | |
| types = None | |
| logger = logging.getLogger(__name__) | |
class BatchLLMClient(BaseLLMClient):
    """
    Batch LLM client that uses Gemini Batch API for cost-effective processing.

    This client processes multiple requests together in batch jobs, providing:
    - 50% cost savings vs standard API
    - No rate limit impact
    - Automatic job management and polling

    Usage:
        >>> from gepa_optimizer.llms import BatchLLMClient
        >>>
        >>> client = BatchLLMClient(
        ...     provider="google",
        ...     model_name="gemini-2.5-flash",
        ...     api_key="your-key",
        ...     batch_size=20,
        ...     polling_interval=30
        ... )
        >>>
        >>> # Use just like VisionLLMClient - adapter handles the rest!
        >>> result = client.generate(
        ...     system_prompt="You are a helpful assistant",
        ...     user_prompt="Analyze this image",
        ...     image_base64="..."
        ... )

    Performance Note:
        Batch processing adds latency (30s+ polling time) but reduces costs by 50%.
        Choose this mode for large-scale optimization where cost > speed.
    """

    def __init__(
        self,
        provider: str,
        model_name: str,
        api_key: Optional[str] = None,
        batch_size: int = 20,
        polling_interval: int = 30,
        max_polling_time: int = 3600,
        temp_dir: str = ".gepa_batch_temp",
        **kwargs
    ):
        """
        Initialize Batch LLM Client.

        Args:
            provider: Must be "google" or "gemini"
            model_name: Gemini model (e.g., "gemini-2.5-flash", "gemini-1.5-flash")
            api_key: Google API key (defaults to GEMINI_API_KEY env var)
            batch_size: Number of samples to process per batch job (1-100)
            polling_interval: Seconds between job status checks (default: 30)
            max_polling_time: Maximum seconds to wait for job completion (default: 3600)
            temp_dir: Directory for temporary files (default: ".gepa_batch_temp")
            **kwargs: Additional parameters

        Raises:
            ValueError: If provider is not Google/Gemini, or no API key is available
            ImportError: If google-genai is not installed
        """
        super().__init__(provider=provider, model_name=model_name, **kwargs)

        # Validate provider -- batch mode is Gemini-only.
        if provider.lower() not in ["google", "gemini"]:
            raise ValueError(
                f"BatchLLMClient only supports Google/Gemini provider. Got: {provider}"
            )

        # Check dependencies (module-level import may have silently failed).
        if not GENAI_AVAILABLE:
            raise ImportError(
                "google-genai not installed. Install with: pip install google-genai"
            )

        # Configuration
        self.batch_size = batch_size
        self.polling_interval = polling_interval
        self.max_polling_time = max_polling_time
        self.temp_dir = Path(temp_dir)
        # parents=True so a nested temp_dir (e.g. "out/.gepa_batch_temp") works.
        self.temp_dir.mkdir(parents=True, exist_ok=True)

        # Initialize Gemini client
        from ..utils.api_keys import APIKeyManager
        self.api_key = api_key or APIKeyManager().get_api_key("google")
        if not self.api_key:
            raise ValueError(
                "Google API key required. Provide via api_key parameter or "
                "set GEMINI_API_KEY environment variable."
            )
        self.client = genai.Client(api_key=self.api_key)

        logger.info(
            f"✓ BatchLLMClient initialized: {model_name} "
            f"(batch_size={batch_size}, polling={polling_interval}s)"
        )

    def generate(
        self,
        system_prompt: str,
        user_prompt: str,
        image_base64: Optional[str] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Generate response using batch API.

        Note: This method is primarily for compatibility. For batch optimization,
        the adapter will call generate_batch() directly with multiple requests.

        Args:
            system_prompt: System-level instructions
            user_prompt: User's input prompt
            image_base64: Optional base64 encoded image
            **kwargs: Additional generation parameters

        Returns:
            Dict with 'content' key containing generated text
        """
        # Single request - process as a batch of 1.
        requests = [{
            'system_prompt': system_prompt,
            'user_prompt': user_prompt,
            'image_base64': image_base64
        }]
        results = self.generate_batch(requests)
        return results[0] if results else {"content": "", "error": "No results"}

    def generate_batch(
        self,
        requests: List[Dict[str, Any]],
        timeout_override: Optional[int] = None
    ) -> List[Dict[str, Any]]:
        """
        Process multiple requests in a single batch job.

        This is the main method called by UniversalGepaAdapter during GEPA optimization.

        Args:
            requests: List of request dicts with keys:
                - system_prompt: System instructions
                - user_prompt: User input
                - image_base64: Optional base64 image
            timeout_override: Override max_polling_time for this batch

        Returns:
            List of response dicts with 'content' key

        Raises:
            RuntimeError: If batch job fails
            TimeoutError: If polling exceeds timeout
        """
        logger.info(f"📦 Processing batch of {len(requests)} requests via Gemini Batch API...")
        start_time = time.time()
        try:
            # Step 1: Upload images if needed
            file_uris, mime_types = self._upload_images_for_batch(requests)
            # Step 2: Create JSONL file
            jsonl_path = self._create_batch_jsonl(requests, file_uris, mime_types)
            # Step 3: Submit batch job
            batch_job_name = self._submit_batch_job(jsonl_path)
            # Step 4: Wait for completion
            timeout = timeout_override or self.max_polling_time
            self._wait_for_batch_completion(batch_job_name, timeout)
            # Step 5: Retrieve results
            results = self._retrieve_batch_results(batch_job_name)
            # Cleanup
            jsonl_path.unlink(missing_ok=True)

            elapsed_time = time.time() - start_time
            # Guard against an empty result list to avoid ZeroDivisionError
            # in the per-request timing log.
            per_request = elapsed_time / len(results) if results else 0.0
            logger.info(
                f"✓ Batch processing complete: {len(results)} results in {elapsed_time:.1f}s "
                f"(~{per_request:.1f}s per request)"
            )
            return results
        except Exception as e:
            elapsed_time = time.time() - start_time
            logger.error(f"❌ Batch processing failed after {elapsed_time:.1f}s: {e}")
            raise

    def _upload_images_for_batch(self, requests: List[Dict]) -> Tuple[List[Optional[str]], List[Optional[str]]]:
        """
        Upload images to Gemini and return file URIs and MIME types.

        Args:
            requests: List of request dicts

        Returns:
            Tuple of (file_uris, mime_types) - both are lists with None for
            requests without images (and for requests whose upload failed).
        """
        import base64  # local import kept function-scoped, hoisted out of the loop

        file_uris: List[Optional[str]] = []
        mime_types: List[Optional[str]] = []

        images_to_upload = sum(1 for r in requests if r.get('image_base64'))
        if images_to_upload > 0:
            logger.info(f"  ⬆️ Uploading {images_to_upload} images to Gemini...")

        # Maps detected Pillow format names to (file extension, MIME type).
        format_map = {
            'jpeg': ('.jpg', 'image/jpeg'),
            'jpg': ('.jpg', 'image/jpeg'),
            'png': ('.png', 'image/png'),
            'gif': ('.gif', 'image/gif'),
            'webp': ('.webp', 'image/webp'),
            'bmp': ('.bmp', 'image/bmp'),
            'tiff': ('.tiff', 'image/tiff'),
            'tif': ('.tiff', 'image/tiff'),
        }

        for i, request in enumerate(requests):
            image_base64 = request.get('image_base64')
            if not image_base64:
                file_uris.append(None)
                mime_types.append(None)
                continue

            temp_path: Optional[str] = None
            try:
                # Decode image data
                image_data = base64.b64decode(image_base64)

                # Detect image format using Pillow (best-effort; None if unavailable)
                image_format = None
                if PIL_AVAILABLE:
                    try:
                        img = Image.open(io.BytesIO(image_data))
                        image_format = img.format.lower() if img.format else None
                    except Exception as e:
                        logger.warning(f"  ⚠️ Could not detect image format: {e}")

                # Get extension and MIME type (default to PNG if unknown)
                ext, mime_type = format_map.get(image_format, ('.png', 'image/png'))
                if image_format and image_format not in format_map:
                    logger.warning(f"  ⚠️ Unknown image format '{image_format}' for image {i}, defaulting to PNG")
                elif not image_format:
                    logger.debug(f"  ℹ️ Could not detect format for image {i}, using PNG")

                # Save to temp file with correct extension
                temp_file = tempfile.NamedTemporaryFile(
                    delete=False,
                    suffix=ext,
                    dir=self.temp_dir
                )
                temp_path = temp_file.name
                temp_file.write(image_data)
                temp_file.close()

                # Upload to Gemini with correct MIME type
                uploaded_file = self.client.files.upload(
                    file=temp_path,
                    config=types.UploadFileConfig(
                        display_name=f"batch_image_{i}_{int(time.time())}{ext}",
                        mime_type=mime_type
                    )
                )
                logger.debug(f"  ✓ Uploaded image {i} as {mime_type}")

                # Wait for file to be active before it can be referenced in a batch
                self._wait_for_file_active(uploaded_file)
                file_uris.append(uploaded_file.uri)
                mime_types.append(mime_type)
            except Exception as e:
                logger.error(f"  ✗ Failed to upload image {i}: {e}")
                file_uris.append(None)
                mime_types.append(None)
            finally:
                # Always remove the temp file, even when the upload failed
                # (the original success-only cleanup leaked files on error).
                if temp_path:
                    Path(temp_path).unlink(missing_ok=True)

        if images_to_upload > 0:
            successful = sum(1 for uri in file_uris if uri is not None)
            logger.info(f"  ✓ Uploaded {successful}/{images_to_upload} images successfully")
        return file_uris, mime_types

    def _create_batch_jsonl(
        self,
        requests: List[Dict],
        file_uris: List[Optional[str]],
        mime_types: List[Optional[str]]
    ) -> Path:
        """
        Create JSONL file for batch job.

        Args:
            requests: List of request dicts
            file_uris: List of uploaded file URIs
            mime_types: List of MIME types for uploaded files

        Returns:
            Path to created JSONL file
        """
        timestamp = int(time.time())
        jsonl_path = self.temp_dir / f"batch_{timestamp}.jsonl"
        with open(jsonl_path, 'w', encoding='utf-8') as f:
            for i, (request, file_uri, mime_type) in enumerate(zip(requests, file_uris, mime_types)):
                # Combine system and user prompts into a single text part.
                system_prompt = request.get('system_prompt', '')
                user_prompt = request.get('user_prompt', '')
                full_prompt = f"{system_prompt}\n\n{user_prompt}".strip()

                # Build request parts
                parts = [{"text": full_prompt}]
                if file_uri:
                    parts.append({
                        "file_data": {
                            "file_uri": file_uri,
                            "mime_type": mime_type or "image/png"  # Use actual MIME type
                        }
                    })

                # Gemini Batch API format according to official docs
                # Reference: https://ai.google.dev/gemini-api/docs/batch-inference
                # NOTE: The "request" wrapper is REQUIRED for Gemini 2.5 batch API
                batch_request = {
                    "custom_id": f"request-{i}",
                    "request": {
                        "contents": [{
                            "role": "user",
                            "parts": parts
                        }]
                    }
                }
                f.write(json.dumps(batch_request, ensure_ascii=False) + '\n')

        logger.info(f"  📝 Created JSONL file: {jsonl_path.name} ({len(requests)} requests)")
        return jsonl_path

    def _submit_batch_job(self, jsonl_path: Path) -> str:
        """
        Submit batch job to Gemini.

        Args:
            jsonl_path: Path to JSONL file

        Returns:
            Batch job name

        Raises:
            RuntimeError: If upload or job creation fails
            ValueError: If the JSONL file contains invalid JSON
        """
        # Upload JSONL file
        # Try multiple methods as the google-genai SDK can be finicky
        try:
            logger.info(f"  📤 Uploading JSONL file: {jsonl_path.name}")
            # Read and validate file content
            with open(jsonl_path, 'r', encoding='utf-8') as f:
                content = f.read()
            line_count = len(content.strip().split('\n'))
            logger.debug(f"  📄 JSONL: {len(content)} bytes, {line_count} lines")

            # Validate JSONL format before spending an upload on a broken file.
            for line_num, line in enumerate(content.strip().split('\n'), 1):
                try:
                    json.loads(line)
                except json.JSONDecodeError as e:
                    logger.error(f"  ❌ Invalid JSON at line {line_num}: {e}")
                    logger.error(f"  Content: {line[:100]}...")
                    raise ValueError(f"Invalid JSONL format at line {line_num}") from e

            # Method 1: Try uploading with Path object
            logger.info(f"  🔄 Upload method 1: Using Path object...")
            try:
                jsonl_file = self.client.files.upload(
                    file=jsonl_path,
                    config=types.UploadFileConfig(
                        display_name=f'gepa-batch-{int(time.time())}',
                        mime_type='application/json'  # Try application/json instead of application/jsonl
                    )
                )
                logger.info(f"  ✓ JSONL file uploaded: {jsonl_file.name}")
            except Exception as e1:
                logger.warning(f"  ⚠️ Method 1 failed: {e1}")
                logger.info(f"  🔄 Upload method 2: Using string path...")
                # Method 2: Fallback to string path
                try:
                    jsonl_file = self.client.files.upload(
                        file=str(jsonl_path.absolute()),
                        config=types.UploadFileConfig(
                            display_name=f'gepa-batch-{int(time.time())}',
                            mime_type='application/json'
                        )
                    )
                    logger.info(f"  ✓ JSONL file uploaded (method 2): {jsonl_file.name}")
                except Exception as e2:
                    logger.error(f"  ❌ Method 2 also failed: {e2}")
                    raise e2
        except KeyError as e:
            logger.error(f"❌ KeyError during JSONL upload: {e}")
            logger.error(f"  This suggests the Gemini API response format changed")
            logger.error(f"  Try updating google-genai: pip install --upgrade google-genai")
            raise RuntimeError(f"Gemini Batch API response format error: {e}") from e
        except Exception as e:
            logger.error(f"❌ Failed to upload JSONL file: {e}")
            logger.error(f"  File path: {jsonl_path}")
            logger.error(f"  File exists: {jsonl_path.exists()}")
            logger.error(f"  File size: {jsonl_path.stat().st_size if jsonl_path.exists() else 'N/A'} bytes")
            raise RuntimeError(f"Gemini Batch API file upload failed: {e}") from e

        # Wait for JSONL to be active
        try:
            logger.info(f"  ⏳ Waiting for JSONL file to be processed...")
            self._wait_for_file_active(jsonl_file)
        except Exception as e:
            logger.error(f"❌ JSONL file processing failed: {e}")
            raise

        # Create batch job
        try:
            logger.info(f"  🚀 Creating batch job...")
            batch_job = self.client.batches.create(
                model=self.model_name,
                src=jsonl_file.name,
                config={'display_name': f'gepa-opt-{int(time.time())}'}
            )
            logger.info(f"  ✓ Batch job submitted: {batch_job.name}")
            return batch_job.name
        except Exception as e:
            logger.error(f"❌ Failed to create batch job: {e}")
            raise RuntimeError(f"Batch job creation failed: {e}") from e

    def _wait_for_batch_completion(self, job_name: str, timeout: int):
        """
        Poll batch job until completion.

        Args:
            job_name: Batch job name
            timeout: Maximum seconds to wait

        Raises:
            TimeoutError: If polling exceeds timeout
            RuntimeError: If batch job fails or is cancelled
        """
        logger.info(f"  ⏳ Polling for completion (checking every {self.polling_interval}s)...")
        start_time = time.time()
        poll_count = 0
        while True:
            elapsed = time.time() - start_time
            if elapsed > timeout:
                raise TimeoutError(
                    f"Batch job timeout after {elapsed:.0f}s "
                    f"(max: {timeout}s)"
                )
            try:
                batch_job = self.client.batches.get(name=job_name)
                state = batch_job.state.name
                # Success states
                if state in ['JOB_STATE_SUCCEEDED', 'SUCCEEDED']:
                    logger.info(f"  ✓ Batch job completed in {elapsed:.0f}s")
                    return
                # Failure states
                if state in ['JOB_STATE_FAILED', 'FAILED']:
                    raise RuntimeError(f"Batch job failed with state: {state}")
                if state in ['JOB_STATE_CANCELLED', 'CANCELLED']:
                    raise RuntimeError(f"Batch job was cancelled: {state}")
                # Still processing
                poll_count += 1
                if poll_count % 5 == 0:  # Log every 5 polls
                    logger.info(f"  ... still processing ({elapsed:.0f}s elapsed, state: {state})")
                time.sleep(self.polling_interval)
            except (TimeoutError, RuntimeError):
                raise
            except Exception as e:
                # Transient status-check failures (network etc.) -> short retry.
                logger.warning(f"  ⚠️ Error checking job status: {e}, retrying...")
                time.sleep(5)

    def _retrieve_batch_results(self, job_name: str) -> List[Dict[str, Any]]:
        """
        Retrieve and parse batch results.

        Args:
            job_name: Batch job name

        Returns:
            List of result dicts

        Raises:
            RuntimeError: If the job exposes neither inline responses nor a results file
        """
        batch_job = self.client.batches.get(name=job_name)
        # Check for inline responses (preferred)
        if hasattr(batch_job.dest, 'inlined_responses') and batch_job.dest.inlined_responses:
            logger.info(f"  📥 Processing inline responses...")
            return self._parse_inline_results(batch_job.dest.inlined_responses)
        # Download results file (fallback)
        if hasattr(batch_job.dest, 'file_name') and batch_job.dest.file_name:
            logger.info(f"  📥 Downloading results file: {batch_job.dest.file_name}")
            file_data = self.client.files.download(file=batch_job.dest.file_name)
            return self._parse_file_results(file_data)
        raise RuntimeError("No results available from batch job")

    def _parse_inline_results(self, inline_responses) -> List[Dict[str, Any]]:
        """Parse inline batch results into the standard response-dict format."""
        results = []
        for response_obj in inline_responses:
            if hasattr(response_obj, 'response') and response_obj.response:
                text = self._extract_text_from_response(response_obj.response)
                results.append({
                    "content": text,
                    "role": "assistant",
                    "model": self.model_name,
                    "provider": "google"
                })
            else:
                error_msg = str(getattr(response_obj, 'error', 'Unknown error'))
                logger.warning(f"  ⚠️ Response error: {error_msg}")
                results.append({
                    "content": "",
                    "error": error_msg
                })
        return results

    def _parse_file_results(self, file_data) -> List[Dict[str, Any]]:
        """Parse JSONL results file (bytes or str) into response dicts."""
        if isinstance(file_data, bytes):
            jsonl_content = file_data.decode('utf-8')
        else:
            jsonl_content = file_data
        results = []
        for line_num, line in enumerate(jsonl_content.strip().split('\n'), 1):
            if not line.strip():
                continue
            try:
                result = json.loads(line)
                if 'response' in result:
                    text = self._extract_text_from_dict(result['response'])
                    results.append({
                        "content": text,
                        "role": "assistant",
                        "model": self.model_name,
                        "provider": "google"
                    })
                else:
                    error_msg = result.get('error', 'Unknown error')
                    logger.warning(f"  ⚠️ Line {line_num} error: {error_msg}")
                    results.append({
                        "content": "",
                        "error": error_msg
                    })
            except json.JSONDecodeError as e:
                logger.error(f"  ✗ Line {line_num}: JSON decode error: {e}")
                results.append({"content": "", "error": f"JSON decode error: {e}"})
        return results

    def _extract_text_from_response(self, response_obj) -> str:
        """Extract text from a response object, falling back to str()."""
        try:
            # Direct text attribute
            if hasattr(response_obj, 'text'):
                return response_obj.text
            # Navigate through candidates
            if hasattr(response_obj, 'candidates') and response_obj.candidates:
                candidate = response_obj.candidates[0]
                if hasattr(candidate, 'content'):
                    content = candidate.content
                    if hasattr(content, 'parts') and content.parts:
                        part = content.parts[0]
                        if hasattr(part, 'text'):
                            return part.text
            # Fallback to string representation
            return str(response_obj)
        except Exception as e:
            logger.error(f"Error extracting text from response: {e}")
            return ""

    def _extract_text_from_dict(self, response_dict: Dict) -> str:
        """Extract text from a response dictionary, falling back to JSON dump."""
        try:
            # Direct text key
            if 'text' in response_dict:
                return response_dict['text']
            # Navigate through candidates
            if 'candidates' in response_dict and response_dict['candidates']:
                candidate = response_dict['candidates'][0]
                if 'content' in candidate and 'parts' in candidate['content']:
                    parts = candidate['content']['parts']
                    if parts and 'text' in parts[0]:
                        return parts[0]['text']
            # Fallback to JSON string
            return json.dumps(response_dict)
        except Exception as e:
            logger.error(f"Error extracting text from dict: {e}")
            return ""

    def _wait_for_file_active(self, uploaded_file, timeout: int = 60):
        """
        Wait for uploaded file to become active.

        Args:
            uploaded_file: Uploaded file object
            timeout: Maximum seconds to wait

        Raises:
            TimeoutError: If file processing exceeds timeout
            RuntimeError: If file processing fails
        """
        start_time = time.time()
        while uploaded_file.state.name == "PROCESSING":
            if time.time() - start_time > timeout:
                raise TimeoutError(f"File processing timeout: {uploaded_file.name}")
            time.sleep(1)
            uploaded_file = self.client.files.get(name=uploaded_file.name)
        if uploaded_file.state.name != "ACTIVE":
            raise RuntimeError(
                f"File processing failed: {uploaded_file.name} "
                f"(state: {uploaded_file.state.name})"
            )

    def get_model_info(self) -> Dict[str, str]:
        """Get model information for logging and debugging."""
        return {
            'provider': self.provider,
            'model_name': self.model_name,
            'class': self.__class__.__name__,
            'mode': 'batch',
            'batch_size': str(self.batch_size),
            'polling_interval': f'{self.polling_interval}s'
        }