"""
Batch LLM Client for cost-effective processing using Gemini Batch API.

This client provides 50% cost savings by using Google's Gemini Batch API
instead of real-time API calls. Ideal for large-scale prompt optimization
where latency is acceptable.

Features:
- 50% cost reduction compared to standard API
- Automatic batching and job management
- Built-in retry and polling logic
- Thread-safe operation
- Comprehensive error handling

Author: GEPA Optimizer Team
"""

import base64
import io
import json
import logging
import os
import tempfile
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

from .base_llm import BaseLLMClient

# Pillow is optional; used only to sniff the format of uploaded images.
try:
    from PIL import Image
    PIL_AVAILABLE = True
except ImportError:
    PIL_AVAILABLE = False
    Image = None

# google-genai is required at runtime but kept optional at import time so the
# package can be imported without it installed; __init__ raises if missing.
try:
    from google import genai
    from google.genai import types
    GENAI_AVAILABLE = True
except ImportError:
    GENAI_AVAILABLE = False
    genai = None
    types = None

logger = logging.getLogger(__name__)


class BatchLLMClient(BaseLLMClient):
    """
    Batch LLM client that uses Gemini Batch API for cost-effective processing.

    This client processes multiple requests together in batch jobs, providing:
    - 50% cost savings vs standard API
    - No rate limit impact
    - Automatic job management and polling

    Usage:
        >>> from gepa_optimizer.llms import BatchLLMClient
        >>>
        >>> client = BatchLLMClient(
        ...     provider="google",
        ...     model_name="gemini-2.5-flash",
        ...     api_key="your-key",
        ...     batch_size=20,
        ...     polling_interval=30
        ... )
        >>>
        >>> # Use just like VisionLLMClient - adapter handles the rest!
        >>> result = client.generate(
        ...     system_prompt="You are a helpful assistant",
        ...     user_prompt="Analyze this image",
        ...     image_base64="..."
        ... )

    Performance Note:
        Batch processing adds latency (30s+ polling time) but reduces costs
        by 50%. Choose this mode for large-scale optimization where cost > speed.
    """

    def __init__(
        self,
        provider: str,
        model_name: str,
        api_key: Optional[str] = None,
        batch_size: int = 20,
        polling_interval: int = 30,
        max_polling_time: int = 3600,
        temp_dir: str = ".gepa_batch_temp",
        **kwargs
    ):
        """
        Initialize Batch LLM Client.

        Args:
            provider: Must be "google" or "gemini"
            model_name: Gemini model (e.g., "gemini-2.5-flash", "gemini-1.5-flash")
            api_key: Google API key (defaults to the key resolved by APIKeyManager)
            batch_size: Number of samples to process per batch job (1-100)
            polling_interval: Seconds between job status checks (default: 30)
            max_polling_time: Maximum seconds to wait for job completion (default: 3600)
            temp_dir: Directory for temporary files (default: ".gepa_batch_temp")
            **kwargs: Additional parameters forwarded to BaseLLMClient

        Raises:
            ValueError: If provider is not Google/Gemini, or no API key is available
            ImportError: If google-genai is not installed
        """
        super().__init__(provider=provider, model_name=model_name, **kwargs)

        # Only the Gemini Batch API is supported by this client.
        if provider.lower() not in ["google", "gemini"]:
            raise ValueError(
                f"BatchLLMClient only supports Google/Gemini provider. Got: {provider}"
            )

        # Fail fast if the optional dependency was not importable at module load.
        if not GENAI_AVAILABLE:
            raise ImportError(
                "google-genai not installed. Install with: pip install google-genai"
            )

        # Configuration
        self.batch_size = batch_size
        self.polling_interval = polling_interval
        self.max_polling_time = max_polling_time
        self.temp_dir = Path(temp_dir)
        self.temp_dir.mkdir(exist_ok=True)

        # Resolve API key: explicit argument wins, otherwise delegate to the
        # project's key manager. Imported locally to avoid a circular import.
        from ..utils.api_keys import APIKeyManager
        self.api_key = api_key or APIKeyManager().get_api_key("google")
        if not self.api_key:
            raise ValueError(
                "Google API key required. Provide via api_key parameter or "
                "set GEMINI_API_KEY environment variable."
            )

        self.client = genai.Client(api_key=self.api_key)

        logger.info(
            f"✓ BatchLLMClient initialized: {model_name} "
            f"(batch_size={batch_size}, polling={polling_interval}s)"
        )

    def generate(
        self,
        system_prompt: str,
        user_prompt: str,
        image_base64: Optional[str] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Generate response using batch API.

        Note: This method is primarily for compatibility. For batch optimization,
        the adapter will call generate_batch() directly with multiple requests.

        Args:
            system_prompt: System-level instructions
            user_prompt: User's input prompt
            image_base64: Optional base64 encoded image
            **kwargs: Additional generation parameters (currently unused)

        Returns:
            Dict with 'content' key containing generated text
        """
        # Single request - process as a batch of 1
        requests = [{
            'system_prompt': system_prompt,
            'user_prompt': user_prompt,
            'image_base64': image_base64
        }]

        results = self.generate_batch(requests)
        return results[0] if results else {"content": "", "error": "No results"}

    def generate_batch(
        self,
        requests: List[Dict[str, Any]],
        timeout_override: Optional[int] = None
    ) -> List[Dict[str, Any]]:
        """
        Process multiple requests in a single batch job.

        This is the main method called by UniversalGepaAdapter during
        GEPA optimization.

        Args:
            requests: List of request dicts with keys:
                - system_prompt: System instructions
                - user_prompt: User input
                - image_base64: Optional base64 image
            timeout_override: Override max_polling_time for this batch

        Returns:
            List of response dicts with 'content' key

        Raises:
            RuntimeError: If batch job fails
            TimeoutError: If polling exceeds timeout
        """
        logger.info(f"đŸ“Ļ Processing batch of {len(requests)} requests via Gemini Batch API...")
        start_time = time.time()

        try:
            # Step 1: Upload images if needed
            file_uris, mime_types = self._upload_images_for_batch(requests)

            # Step 2: Create JSONL file
            jsonl_path = self._create_batch_jsonl(requests, file_uris, mime_types)

            # Step 3: Submit batch job
            batch_job_name = self._submit_batch_job(jsonl_path)

            # Step 4: Wait for completion
            timeout = timeout_override or self.max_polling_time
            self._wait_for_batch_completion(batch_job_name, timeout)

            # Step 5: Retrieve results
            results = self._retrieve_batch_results(batch_job_name)

            # Cleanup
            jsonl_path.unlink(missing_ok=True)

            elapsed_time = time.time() - start_time
            # Guard against an empty result set to avoid ZeroDivisionError in
            # the per-request average.
            per_request = elapsed_time / len(results) if results else 0.0
            logger.info(
                f"✓ Batch processing complete: {len(results)} results in {elapsed_time:.1f}s "
                f"(~{per_request:.1f}s per request)"
            )

            return results

        except Exception as e:
            elapsed_time = time.time() - start_time
            logger.error(f"❌ Batch processing failed after {elapsed_time:.1f}s: {e}")
            raise

    def _upload_images_for_batch(self, requests: List[Dict]) -> Tuple[List[Optional[str]], List[Optional[str]]]:
        """
        Upload images to Gemini and return file URIs and MIME types.

        Args:
            requests: List of request dicts

        Returns:
            Tuple of (file_uris, mime_types) - both are lists with None for
            requests without images (and for requests whose upload failed)
        """
        file_uris = []
        mime_types = []

        images_to_upload = sum(1 for r in requests if r.get('image_base64'))
        if images_to_upload > 0:
            logger.info(f"  âŦ†ī¸ Uploading {images_to_upload} images to Gemini...")

        for i, request in enumerate(requests):
            image_base64 = request.get('image_base64')
            if not image_base64:
                file_uris.append(None)
                mime_types.append(None)
                continue

            try:
                # Decode image data
                image_data = base64.b64decode(image_base64)

                # Detect image format using Pillow (best-effort; None if
                # Pillow is unavailable or the bytes are not recognized)
                image_format = None
                if PIL_AVAILABLE:
                    try:
                        img = Image.open(io.BytesIO(image_data))
                        image_format = img.format.lower() if img.format else None
                    except Exception as e:
                        logger.warning(f"  âš ī¸ Could not detect image format: {e}")

                # Map format to extension and MIME type
                format_map = {
                    'jpeg': ('.jpg', 'image/jpeg'),
                    'jpg': ('.jpg', 'image/jpeg'),
                    'png': ('.png', 'image/png'),
                    'gif': ('.gif', 'image/gif'),
                    'webp': ('.webp', 'image/webp'),
                    'bmp': ('.bmp', 'image/bmp'),
                    'tiff': ('.tiff', 'image/tiff'),
                    'tif': ('.tiff', 'image/tiff'),
                }

                # Get extension and MIME type (default to PNG if unknown)
                ext, mime_type = format_map.get(image_format, ('.png', 'image/png'))

                if image_format and image_format not in format_map:
                    logger.warning(f"  âš ī¸ Unknown image format '{image_format}' for image {i}, defaulting to PNG")
                elif not image_format:
                    logger.debug(f"  â„šī¸ Could not detect format for image {i}, using PNG")

                # Save to temp file with correct extension. The try/finally
                # guarantees the temp file is removed even when the upload or
                # activation wait below raises (previously it leaked on error).
                temp_file = tempfile.NamedTemporaryFile(
                    delete=False,
                    suffix=ext,
                    dir=self.temp_dir
                )
                try:
                    temp_file.write(image_data)
                    temp_file.close()

                    # Upload to Gemini with correct MIME type
                    uploaded_file = self.client.files.upload(
                        file=temp_file.name,
                        config=types.UploadFileConfig(
                            display_name=f"batch_image_{i}_{int(time.time())}{ext}",
                            mime_type=mime_type
                        )
                    )
                    logger.debug(f"  ✓ Uploaded image {i} as {mime_type}")

                    # Wait for file to be active
                    self._wait_for_file_active(uploaded_file)
                    file_uris.append(uploaded_file.uri)
                    mime_types.append(mime_type)
                finally:
                    # Cleanup temp file on both success and failure paths
                    Path(temp_file.name).unlink(missing_ok=True)

            except Exception as e:
                # Best-effort: a failed image upload degrades the request to
                # text-only rather than aborting the whole batch.
                logger.error(f"  ✗ Failed to upload image {i}: {e}")
                file_uris.append(None)
                mime_types.append(None)

        if images_to_upload > 0:
            successful = sum(1 for uri in file_uris if uri is not None)
            logger.info(f"  ✓ Uploaded {successful}/{images_to_upload} images successfully")

        return file_uris, mime_types

    def _create_batch_jsonl(
        self,
        requests: List[Dict],
        file_uris: List[Optional[str]],
        mime_types: List[Optional[str]]
    ) -> Path:
        """
        Create JSONL file for batch job.

        Args:
            requests: List of request dicts
            file_uris: List of uploaded file URIs
            mime_types: List of MIME types for uploaded files

        Returns:
            Path to created JSONL file
        """
        timestamp = int(time.time())
        jsonl_path = self.temp_dir / f"batch_{timestamp}.jsonl"

        with open(jsonl_path, 'w', encoding='utf-8') as f:
            for i, (request, file_uri, mime_type) in enumerate(zip(requests, file_uris, mime_types)):
                # Combine system and user prompts
                system_prompt = request.get('system_prompt', '')
                user_prompt = request.get('user_prompt', '')
                full_prompt = f"{system_prompt}\n\n{user_prompt}".strip()

                # Build request parts
                parts = [{"text": full_prompt}]
                if file_uri:
                    parts.append({
                        "file_data": {
                            "file_uri": file_uri,
                            "mime_type": mime_type or "image/png"  # Use actual MIME type
                        }
                    })

                # Gemini Batch API format according to official docs
                # Reference: https://ai.google.dev/gemini-api/docs/batch-inference
                # NOTE: The "request" wrapper is REQUIRED for Gemini 2.5 batch API
                batch_request = {
                    "custom_id": f"request-{i}",
                    "request": {
                        "contents": [{
                            "role": "user",
                            "parts": parts
                        }]
                    }
                }

                f.write(json.dumps(batch_request, ensure_ascii=False) + '\n')

        logger.info(f"  📝 Created JSONL file: {jsonl_path.name} ({len(requests)} requests)")
        return jsonl_path

    def _submit_batch_job(self, jsonl_path: Path) -> str:
        """
        Submit batch job to Gemini.

        Args:
            jsonl_path: Path to JSONL file

        Returns:
            Batch job name

        Raises:
            RuntimeError: If upload or job creation fails
        """
        # Upload JSONL file
        # Try multiple methods as the google-genai SDK can be finicky
        try:
            logger.info(f"  📤 Uploading JSONL file: {jsonl_path.name}")

            # Read and validate file content
            with open(jsonl_path, 'r', encoding='utf-8') as f:
                content = f.read()
            line_count = len(content.strip().split('\n'))
            logger.debug(f"  📄 JSONL: {len(content)} bytes, {line_count} lines")

            # Validate JSONL format before uploading so a malformed file
            # fails locally with a clear line number instead of server-side.
            for line_num, line in enumerate(content.strip().split('\n'), 1):
                try:
                    json.loads(line)
                except json.JSONDecodeError as e:
                    logger.error(f"  ❌ Invalid JSON at line {line_num}: {e}")
                    logger.error(f"     Content: {line[:100]}...")
                    raise ValueError(f"Invalid JSONL format at line {line_num}") from e

            # Method 1: Try uploading with Path object
            logger.info(f"  🔄 Upload method 1: Using Path object...")
            try:
                jsonl_file = self.client.files.upload(
                    file=jsonl_path,
                    config=types.UploadFileConfig(
                        display_name=f'gepa-batch-{int(time.time())}',
                        mime_type='application/json'  # Try application/json instead of application/jsonl
                    )
                )
                logger.info(f"  ✓ JSONL file uploaded: {jsonl_file.name}")
            except Exception as e1:
                logger.warning(f"  âš ī¸ Method 1 failed: {e1}")
                logger.info(f"  🔄 Upload method 2: Using string path...")

                # Method 2: Fallback to string path
                try:
                    jsonl_file = self.client.files.upload(
                        file=str(jsonl_path.absolute()),
                        config=types.UploadFileConfig(
                            display_name=f'gepa-batch-{int(time.time())}',
                            mime_type='application/json'
                        )
                    )
                    logger.info(f"  ✓ JSONL file uploaded (method 2): {jsonl_file.name}")
                except Exception as e2:
                    logger.error(f"  ❌ Method 2 also failed: {e2}")
                    raise e2

        except KeyError as e:
            logger.error(f"❌ KeyError during JSONL upload: {e}")
            logger.error(f"   This suggests the Gemini API response format changed")
            logger.error(f"   Try updating google-genai: pip install --upgrade google-genai")
            raise RuntimeError(f"Gemini Batch API response format error: {e}") from e
        except Exception as e:
            logger.error(f"❌ Failed to upload JSONL file: {e}")
            logger.error(f"   File path: {jsonl_path}")
            logger.error(f"   File exists: {jsonl_path.exists()}")
            logger.error(f"   File size: {jsonl_path.stat().st_size if jsonl_path.exists() else 'N/A'} bytes")
            raise RuntimeError(f"Gemini Batch API file upload failed: {e}") from e

        # Wait for JSONL to be active
        try:
            logger.info(f"  âŗ Waiting for JSONL file to be processed...")
            self._wait_for_file_active(jsonl_file)
        except Exception as e:
            logger.error(f"❌ JSONL file processing failed: {e}")
            raise

        # Create batch job
        try:
            logger.info(f"  🚀 Creating batch job...")
            batch_job = self.client.batches.create(
                model=self.model_name,
                src=jsonl_file.name,
                config={'display_name': f'gepa-opt-{int(time.time())}'}
            )
            logger.info(f"  ✓ Batch job submitted: {batch_job.name}")
            return batch_job.name
        except Exception as e:
            logger.error(f"❌ Failed to create batch job: {e}")
            raise RuntimeError(f"Batch job creation failed: {e}") from e

    def _wait_for_batch_completion(self, job_name: str, timeout: int):
        """
        Poll batch job until completion.

        Args:
            job_name: Batch job name
            timeout: Maximum seconds to wait

        Raises:
            TimeoutError: If polling exceeds timeout
            RuntimeError: If batch job fails or is cancelled
        """
        logger.info(f"  âŗ Polling for completion (checking every {self.polling_interval}s)...")

        start_time = time.time()
        poll_count = 0

        while True:
            elapsed = time.time() - start_time
            if elapsed > timeout:
                raise TimeoutError(
                    f"Batch job timeout after {elapsed:.0f}s "
                    f"(max: {timeout}s)"
                )

            try:
                batch_job = self.client.batches.get(name=job_name)
                state = batch_job.state.name

                # Success states
                if state in ['JOB_STATE_SUCCEEDED', 'SUCCEEDED']:
                    logger.info(f"  ✓ Batch job completed in {elapsed:.0f}s")
                    return

                # Failure states
                if state in ['JOB_STATE_FAILED', 'FAILED']:
                    raise RuntimeError(f"Batch job failed with state: {state}")

                if state in ['JOB_STATE_CANCELLED', 'CANCELLED']:
                    raise RuntimeError(f"Batch job was cancelled: {state}")

                # Still processing
                poll_count += 1
                if poll_count % 5 == 0:  # Log every 5 polls
                    logger.info(f"  ... still processing ({elapsed:.0f}s elapsed, state: {state})")

                time.sleep(self.polling_interval)

            except (TimeoutError, RuntimeError):
                raise
            except Exception as e:
                # Transient API/network errors: retry after a short pause
                logger.warning(f"  âš ī¸ Error checking job status: {e}, retrying...")
                time.sleep(5)

    def _retrieve_batch_results(self, job_name: str) -> List[Dict[str, Any]]:
        """
        Retrieve and parse batch results.

        Args:
            job_name: Batch job name

        Returns:
            List of result dicts

        Raises:
            RuntimeError: If the job exposes neither inline responses nor a
                results file
        """
        batch_job = self.client.batches.get(name=job_name)

        # Check for inline responses (preferred)
        if hasattr(batch_job.dest, 'inlined_responses') and batch_job.dest.inlined_responses:
            logger.info(f"  đŸ“Ĩ Processing inline responses...")
            return self._parse_inline_results(batch_job.dest.inlined_responses)

        # Download results file (fallback)
        if hasattr(batch_job.dest, 'file_name') and batch_job.dest.file_name:
            logger.info(f"  đŸ“Ĩ Downloading results file: {batch_job.dest.file_name}")
            file_data = self.client.files.download(file=batch_job.dest.file_name)
            return self._parse_file_results(file_data)

        raise RuntimeError("No results available from batch job")

    def _parse_inline_results(self, inline_responses) -> List[Dict[str, Any]]:
        """Parse inline batch results."""
        results = []
        for response_obj in inline_responses:
            if hasattr(response_obj, 'response') and response_obj.response:
                text = self._extract_text_from_response(response_obj.response)
                results.append({
                    "content": text,
                    "role": "assistant",
                    "model": self.model_name,
                    "provider": "google"
                })
            else:
                error_msg = str(getattr(response_obj, 'error', 'Unknown error'))
                logger.warning(f"  âš ī¸ Response error: {error_msg}")
                results.append({
                    "content": "",
                    "error": error_msg
                })
        return results

    def _parse_file_results(self, file_data) -> List[Dict[str, Any]]:
        """Parse JSONL results file."""
        if isinstance(file_data, bytes):
            jsonl_content = file_data.decode('utf-8')
        else:
            jsonl_content = file_data

        results = []
        for line_num, line in enumerate(jsonl_content.strip().split('\n'), 1):
            if not line.strip():
                continue
            try:
                result = json.loads(line)
                if 'response' in result:
                    text = self._extract_text_from_dict(result['response'])
                    results.append({
                        "content": text,
                        "role": "assistant",
                        "model": self.model_name,
                        "provider": "google"
                    })
                else:
                    error_msg = result.get('error', 'Unknown error')
                    logger.warning(f"  âš ī¸ Line {line_num} error: {error_msg}")
                    results.append({
                        "content": "",
                        "error": error_msg
                    })
            except json.JSONDecodeError as e:
                logger.error(f"  ✗ Line {line_num}: JSON decode error: {e}")
                results.append({"content": "", "error": f"JSON decode error: {e}"})

        return results

    def _extract_text_from_response(self, response_obj) -> str:
        """Extract text from a response object, always returning a str."""
        try:
            # Direct text attribute. Guard against a present-but-None .text
            # (the SDK's convenience property can be None) so we fall through
            # to the candidates path instead of returning None from a -> str
            # method.
            if hasattr(response_obj, 'text'):
                text = response_obj.text
                if text is not None:
                    return text

            # Navigate through candidates
            if hasattr(response_obj, 'candidates') and response_obj.candidates:
                candidate = response_obj.candidates[0]
                if hasattr(candidate, 'content'):
                    content = candidate.content
                    if hasattr(content, 'parts') and content.parts:
                        part = content.parts[0]
                        if hasattr(part, 'text'):
                            return part.text

            # Fallback to string representation
            return str(response_obj)
        except Exception as e:
            logger.error(f"Error extracting text from response: {e}")
            return ""

    def _extract_text_from_dict(self, response_dict: Dict) -> str:
        """Extract text from response dictionary."""
        try:
            # Direct text key
            if 'text' in response_dict:
                return response_dict['text']

            # Navigate through candidates
            if 'candidates' in response_dict and response_dict['candidates']:
                candidate = response_dict['candidates'][0]
                if 'content' in candidate and 'parts' in candidate['content']:
                    parts = candidate['content']['parts']
                    if parts and 'text' in parts[0]:
                        return parts[0]['text']

            # Fallback to JSON string
            return json.dumps(response_dict)
        except Exception as e:
            logger.error(f"Error extracting text from dict: {e}")
            return ""

    def _wait_for_file_active(self, uploaded_file, timeout: int = 60):
        """
        Wait for uploaded file to become active.

        Args:
            uploaded_file: Uploaded file object
            timeout: Maximum seconds to wait

        Raises:
            TimeoutError: If file processing exceeds timeout
            RuntimeError: If file processing fails
        """
        start_time = time.time()
        while uploaded_file.state.name == "PROCESSING":
            if time.time() - start_time > timeout:
                raise TimeoutError(f"File processing timeout: {uploaded_file.name}")
            time.sleep(1)
            uploaded_file = self.client.files.get(name=uploaded_file.name)

        if uploaded_file.state.name != "ACTIVE":
            raise RuntimeError(
                f"File processing failed: {uploaded_file.name} "
                f"(state: {uploaded_file.state.name})"
            )

    def get_model_info(self) -> Dict[str, str]:
        """Get model information for logging and debugging."""
        return {
            'provider': self.provider,
            'model_name': self.model_name,
            'class': self.__class__.__name__,
            'mode': 'batch',
            'batch_size': str(self.batch_size),
            'polling_interval': f'{self.polling_interval}s'
        }