# Deploy Universal Prompt Optimizer to HF Spaces (clean) — commit cacd4d0 (Suhasdev)
"""
Batch LLM Client for cost-effective processing using Gemini Batch API.
This client provides 50% cost savings by using Google's Gemini Batch API
instead of real-time API calls. Ideal for large-scale prompt optimization
where latency is acceptable.
Features:
- 50% cost reduction compared to standard API
- Automatic batching and job management
- Built-in retry and polling logic
- Thread-safe operation
- Comprehensive error handling
Author: GEPA Optimizer Team
"""
import os
import json
import time
import logging
import tempfile
import io
from pathlib import Path
from typing import Dict, List, Any, Optional, Tuple
from .base_llm import BaseLLMClient
try:
from PIL import Image
PIL_AVAILABLE = True
except ImportError:
PIL_AVAILABLE = False
Image = None
try:
from google import genai
from google.genai import types
GENAI_AVAILABLE = True
except ImportError:
GENAI_AVAILABLE = False
genai = None
types = None
logger = logging.getLogger(__name__)
class BatchLLMClient(BaseLLMClient):
    """
    Batch LLM client that uses the Gemini Batch API for cost-effective processing.

    This client processes multiple requests together in batch jobs, providing:
    - 50% cost savings vs standard API
    - No rate limit impact
    - Automatic job management and polling

    Usage:
        >>> from gepa_optimizer.llms import BatchLLMClient
        >>>
        >>> client = BatchLLMClient(
        ...     provider="google",
        ...     model_name="gemini-2.5-flash",
        ...     api_key="your-key",
        ...     batch_size=20,
        ...     polling_interval=30
        ... )
        >>>
        >>> # Use just like VisionLLMClient - adapter handles the rest!
        >>> result = client.generate(
        ...     system_prompt="You are a helpful assistant",
        ...     user_prompt="Analyze this image",
        ...     image_base64="..."
        ... )

    Performance Note:
        Batch processing adds latency (30s+ polling time) but reduces costs by 50%.
        Choose this mode for large-scale optimization where cost > speed.
    """

    def __init__(
        self,
        provider: str,
        model_name: str,
        api_key: Optional[str] = None,
        batch_size: int = 20,
        polling_interval: int = 30,
        max_polling_time: int = 3600,
        temp_dir: str = ".gepa_batch_temp",
        **kwargs
    ):
        """
        Initialize Batch LLM Client.

        Args:
            provider: Must be "google" or "gemini"
            model_name: Gemini model (e.g., "gemini-2.5-flash", "gemini-1.5-flash")
            api_key: Google API key (defaults to GEMINI_API_KEY env var)
            batch_size: Number of samples to process per batch job (1-100)
            polling_interval: Seconds between job status checks (default: 30)
            max_polling_time: Maximum seconds to wait for job completion (default: 3600)
            temp_dir: Directory for temporary files (default: ".gepa_batch_temp")
            **kwargs: Additional parameters forwarded to BaseLLMClient

        Raises:
            ValueError: If provider is not Google/Gemini, or no API key is available
            ImportError: If google-genai is not installed
        """
        super().__init__(provider=provider, model_name=model_name, **kwargs)
        # Validate provider: only the Gemini Batch API is supported by this client.
        if provider.lower() not in ["google", "gemini"]:
            raise ValueError(
                f"BatchLLMClient only supports Google/Gemini provider. Got: {provider}"
            )
        # Check dependencies
        if not GENAI_AVAILABLE:
            raise ImportError(
                "google-genai not installed. Install with: pip install google-genai"
            )
        # Configuration
        self.batch_size = batch_size
        self.polling_interval = polling_interval
        self.max_polling_time = max_polling_time
        self.temp_dir = Path(temp_dir)
        # parents=True so a nested temp_dir (e.g. "runs/.gepa_batch_temp") works too.
        self.temp_dir.mkdir(parents=True, exist_ok=True)
        # Initialize Gemini client
        from ..utils.api_keys import APIKeyManager
        self.api_key = api_key or APIKeyManager().get_api_key("google")
        if not self.api_key:
            raise ValueError(
                "Google API key required. Provide via api_key parameter or "
                "set GEMINI_API_KEY environment variable."
            )
        self.client = genai.Client(api_key=self.api_key)
        logger.info(
            f"✓ BatchLLMClient initialized: {model_name} "
            f"(batch_size={batch_size}, polling={polling_interval}s)"
        )

    def generate(
        self,
        system_prompt: str,
        user_prompt: str,
        image_base64: Optional[str] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Generate a response using the batch API.

        Note: This method is primarily for compatibility. For batch optimization,
        the adapter will call generate_batch() directly with multiple requests.

        Args:
            system_prompt: System-level instructions
            user_prompt: User's input prompt
            image_base64: Optional base64 encoded image
            **kwargs: Additional generation parameters (currently unused)

        Returns:
            Dict with 'content' key containing generated text
        """
        # Single request - process as a batch of 1.
        requests = [{
            'system_prompt': system_prompt,
            'user_prompt': user_prompt,
            'image_base64': image_base64
        }]
        results = self.generate_batch(requests)
        return results[0] if results else {"content": "", "error": "No results"}

    def generate_batch(
        self,
        requests: List[Dict[str, Any]],
        timeout_override: Optional[int] = None
    ) -> List[Dict[str, Any]]:
        """
        Process multiple requests in a single batch job.

        This is the main method called by UniversalGepaAdapter during GEPA
        optimization. Pipeline: upload images -> write JSONL -> submit job ->
        poll until done -> parse results.

        Args:
            requests: List of request dicts with keys:
                - system_prompt: System instructions
                - user_prompt: User input
                - image_base64: Optional base64 image
            timeout_override: Override max_polling_time for this batch

        Returns:
            List of response dicts with 'content' key

        Raises:
            RuntimeError: If batch job fails
            TimeoutError: If polling exceeds timeout
        """
        logger.info(f"📦 Processing batch of {len(requests)} requests via Gemini Batch API...")
        start_time = time.time()
        jsonl_path: Optional[Path] = None
        try:
            # Step 1: Upload images if needed
            file_uris, mime_types = self._upload_images_for_batch(requests)
            # Step 2: Create JSONL file
            jsonl_path = self._create_batch_jsonl(requests, file_uris, mime_types)
            # Step 3: Submit batch job
            batch_job_name = self._submit_batch_job(jsonl_path)
            # Step 4: Wait for completion
            timeout = timeout_override or self.max_polling_time
            self._wait_for_batch_completion(batch_job_name, timeout)
            # Step 5: Retrieve results
            results = self._retrieve_batch_results(batch_job_name)
            elapsed_time = time.time() - start_time
            # Guard against an empty result list to avoid ZeroDivisionError.
            per_request = elapsed_time / len(results) if results else 0.0
            logger.info(
                f"✓ Batch processing complete: {len(results)} results in {elapsed_time:.1f}s "
                f"(~{per_request:.1f}s per request)"
            )
            return results
        except Exception as e:
            elapsed_time = time.time() - start_time
            logger.error(f"❌ Batch processing failed after {elapsed_time:.1f}s: {e}")
            raise
        finally:
            # Always clean up the JSONL temp file, even on failure (previously
            # it leaked whenever submission/polling raised).
            if jsonl_path is not None:
                jsonl_path.unlink(missing_ok=True)

    def _upload_images_for_batch(self, requests: List[Dict]) -> Tuple[List[Optional[str]], List[Optional[str]]]:
        """
        Upload images to Gemini and return file URIs and MIME types.

        Args:
            requests: List of request dicts

        Returns:
            Tuple of (file_uris, mime_types) - both are lists aligned with
            `requests`, containing None for requests without images or whose
            upload failed.
        """
        import base64  # hoisted out of the per-image loop

        file_uris: List[Optional[str]] = []
        mime_types: List[Optional[str]] = []
        images_to_upload = sum(1 for r in requests if r.get('image_base64'))
        if images_to_upload > 0:
            logger.info(f"   ⬆️ Uploading {images_to_upload} images to Gemini...")

        # Maps a Pillow-detected format to (file extension, MIME type).
        format_map = {
            'jpeg': ('.jpg', 'image/jpeg'),
            'jpg': ('.jpg', 'image/jpeg'),
            'png': ('.png', 'image/png'),
            'gif': ('.gif', 'image/gif'),
            'webp': ('.webp', 'image/webp'),
            'bmp': ('.bmp', 'image/bmp'),
            'tiff': ('.tiff', 'image/tiff'),
            'tif': ('.tiff', 'image/tiff'),
        }

        for i, request in enumerate(requests):
            image_base64 = request.get('image_base64')
            if not image_base64:
                file_uris.append(None)
                mime_types.append(None)
                continue
            temp_path: Optional[Path] = None
            try:
                image_data = base64.b64decode(image_base64)
                # Detect image format using Pillow (best effort).
                image_format = None
                if PIL_AVAILABLE:
                    try:
                        img = Image.open(io.BytesIO(image_data))
                        image_format = img.format.lower() if img.format else None
                    except Exception as e:
                        logger.warning(f"   ⚠️ Could not detect image format: {e}")
                # Get extension and MIME type (default to PNG if unknown).
                ext, mime_type = format_map.get(image_format, ('.png', 'image/png'))
                if image_format and image_format not in format_map:
                    logger.warning(f"   ⚠️ Unknown image format '{image_format}' for image {i}, defaulting to PNG")
                elif not image_format:
                    logger.debug(f"   ℹ️ Could not detect format for image {i}, using PNG")
                # Save to temp file with correct extension; the SDK uploads
                # from a file path, so the bytes must hit disk first.
                temp_file = tempfile.NamedTemporaryFile(
                    delete=False,
                    suffix=ext,
                    dir=self.temp_dir
                )
                try:
                    temp_file.write(image_data)
                finally:
                    temp_file.close()
                temp_path = Path(temp_file.name)
                # Upload to Gemini with correct MIME type.
                uploaded_file = self.client.files.upload(
                    file=str(temp_path),
                    config=types.UploadFileConfig(
                        display_name=f"batch_image_{i}_{int(time.time())}{ext}",
                        mime_type=mime_type
                    )
                )
                logger.debug(f"   ✓ Uploaded image {i} as {mime_type}")
                # Wait for file to become ACTIVE before referencing it.
                self._wait_for_file_active(uploaded_file)
                file_uris.append(uploaded_file.uri)
                mime_types.append(mime_type)
            except Exception as e:
                logger.error(f"   ✗ Failed to upload image {i}: {e}")
                file_uris.append(None)
                mime_types.append(None)
            finally:
                # Remove the temp file even when upload/activation failed
                # (previously it leaked on any exception after creation).
                if temp_path is not None:
                    temp_path.unlink(missing_ok=True)

        if images_to_upload > 0:
            successful = sum(1 for uri in file_uris if uri is not None)
            logger.info(f"   ✓ Uploaded {successful}/{images_to_upload} images successfully")
        return file_uris, mime_types

    def _create_batch_jsonl(
        self,
        requests: List[Dict],
        file_uris: List[Optional[str]],
        mime_types: List[Optional[str]]
    ) -> Path:
        """
        Create a JSONL file describing the batch job requests.

        Args:
            requests: List of request dicts
            file_uris: List of uploaded file URIs (None for text-only requests)
            mime_types: List of MIME types for uploaded files

        Returns:
            Path to created JSONL file
        """
        timestamp = int(time.time())
        jsonl_path = self.temp_dir / f"batch_{timestamp}.jsonl"
        with open(jsonl_path, 'w', encoding='utf-8') as f:
            for i, (request, file_uri, mime_type) in enumerate(zip(requests, file_uris, mime_types)):
                # Combine system and user prompts into one user turn; the batch
                # request format used here carries a single "contents" entry.
                system_prompt = request.get('system_prompt', '')
                user_prompt = request.get('user_prompt', '')
                full_prompt = f"{system_prompt}\n\n{user_prompt}".strip()
                parts: List[Dict[str, Any]] = [{"text": full_prompt}]
                if file_uri:
                    parts.append({
                        "file_data": {
                            "file_uri": file_uri,
                            "mime_type": mime_type or "image/png"  # Use actual MIME type
                        }
                    })
                # Gemini Batch API format according to official docs
                # Reference: https://ai.google.dev/gemini-api/docs/batch-inference
                # NOTE: The "request" wrapper is REQUIRED for Gemini 2.5 batch API
                batch_request = {
                    "custom_id": f"request-{i}",
                    "request": {
                        "contents": [{
                            "role": "user",
                            "parts": parts
                        }]
                    }
                }
                f.write(json.dumps(batch_request, ensure_ascii=False) + '\n')
        logger.info(f"   📝 Created JSONL file: {jsonl_path.name} ({len(requests)} requests)")
        return jsonl_path

    def _submit_batch_job(self, jsonl_path: Path) -> str:
        """
        Upload the JSONL file and submit a batch job to Gemini.

        Args:
            jsonl_path: Path to JSONL file

        Returns:
            Batch job name

        Raises:
            RuntimeError: If the upload or job creation fails
            ValueError: If the JSONL file contains invalid JSON
        """
        # Upload JSONL file.
        # Try multiple methods as the google-genai SDK can be finicky.
        try:
            logger.info(f"   📤 Uploading JSONL file: {jsonl_path.name}")
            # Read and validate file content before uploading, so malformed
            # requests fail fast with a line number instead of a cryptic API error.
            with open(jsonl_path, 'r', encoding='utf-8') as f:
                content = f.read()
            line_count = len(content.strip().split('\n'))
            logger.debug(f"   📄 JSONL: {len(content)} bytes, {line_count} lines")
            for line_num, line in enumerate(content.strip().split('\n'), 1):
                try:
                    json.loads(line)
                except json.JSONDecodeError as e:
                    logger.error(f"   ❌ Invalid JSON at line {line_num}: {e}")
                    logger.error(f"      Content: {line[:100]}...")
                    raise ValueError(f"Invalid JSONL format at line {line_num}") from e
            # Method 1: Try uploading with Path object
            logger.info(f"   🔄 Upload method 1: Using Path object...")
            try:
                jsonl_file = self.client.files.upload(
                    file=jsonl_path,
                    config=types.UploadFileConfig(
                        display_name=f'gepa-batch-{int(time.time())}',
                        mime_type='application/json'  # Try application/json instead of application/jsonl
                    )
                )
                logger.info(f"   ✓ JSONL file uploaded: {jsonl_file.name}")
            except Exception as e1:
                logger.warning(f"   ⚠️ Method 1 failed: {e1}")
                logger.info(f"   🔄 Upload method 2: Using string path...")
                # Method 2: Fallback to string path
                try:
                    jsonl_file = self.client.files.upload(
                        file=str(jsonl_path.absolute()),
                        config=types.UploadFileConfig(
                            display_name=f'gepa-batch-{int(time.time())}',
                            mime_type='application/json'
                        )
                    )
                    logger.info(f"   ✓ JSONL file uploaded (method 2): {jsonl_file.name}")
                except Exception as e2:
                    logger.error(f"   ❌ Method 2 also failed: {e2}")
                    raise e2
        except KeyError as e:
            logger.error(f"❌ KeyError during JSONL upload: {e}")
            logger.error(f"   This suggests the Gemini API response format changed")
            logger.error(f"   Try updating google-genai: pip install --upgrade google-genai")
            raise RuntimeError(f"Gemini Batch API response format error: {e}") from e
        except Exception as e:
            logger.error(f"❌ Failed to upload JSONL file: {e}")
            logger.error(f"   File path: {jsonl_path}")
            logger.error(f"   File exists: {jsonl_path.exists()}")
            logger.error(f"   File size: {jsonl_path.stat().st_size if jsonl_path.exists() else 'N/A'} bytes")
            raise RuntimeError(f"Gemini Batch API file upload failed: {e}") from e
        # Wait for JSONL to be active
        try:
            logger.info(f"   ⏳ Waiting for JSONL file to be processed...")
            self._wait_for_file_active(jsonl_file)
        except Exception as e:
            logger.error(f"❌ JSONL file processing failed: {e}")
            raise
        # Create batch job
        try:
            logger.info(f"   🚀 Creating batch job...")
            batch_job = self.client.batches.create(
                model=self.model_name,
                src=jsonl_file.name,
                config={'display_name': f'gepa-opt-{int(time.time())}'}
            )
            logger.info(f"   ✓ Batch job submitted: {batch_job.name}")
            return batch_job.name
        except Exception as e:
            logger.error(f"❌ Failed to create batch job: {e}")
            raise RuntimeError(f"Batch job creation failed: {e}") from e

    def _wait_for_batch_completion(self, job_name: str, timeout: int):
        """
        Poll the batch job until it reaches a terminal state.

        Args:
            job_name: Batch job name
            timeout: Maximum seconds to wait

        Raises:
            TimeoutError: If polling exceeds timeout
            RuntimeError: If batch job fails or is cancelled
        """
        logger.info(f"   ⏳ Polling for completion (checking every {self.polling_interval}s)...")
        start_time = time.time()
        poll_count = 0
        while True:
            elapsed = time.time() - start_time
            if elapsed > timeout:
                raise TimeoutError(
                    f"Batch job timeout after {elapsed:.0f}s "
                    f"(max: {timeout}s)"
                )
            try:
                batch_job = self.client.batches.get(name=job_name)
                state = batch_job.state.name
                # Success states
                if state in ['JOB_STATE_SUCCEEDED', 'SUCCEEDED']:
                    logger.info(f"   ✓ Batch job completed in {elapsed:.0f}s")
                    return
                # Failure states
                if state in ['JOB_STATE_FAILED', 'FAILED']:
                    raise RuntimeError(f"Batch job failed with state: {state}")
                if state in ['JOB_STATE_CANCELLED', 'CANCELLED']:
                    raise RuntimeError(f"Batch job was cancelled: {state}")
                # Still processing
                poll_count += 1
                if poll_count % 5 == 0:  # Log every 5 polls to keep output quiet
                    logger.info(f"   ... still processing ({elapsed:.0f}s elapsed, state: {state})")
                time.sleep(self.polling_interval)
            except (TimeoutError, RuntimeError):
                # Terminal conditions propagate; only transient API errors retry.
                raise
            except Exception as e:
                logger.warning(f"   ⚠️ Error checking job status: {e}, retrying...")
                time.sleep(5)

    def _retrieve_batch_results(self, job_name: str) -> List[Dict[str, Any]]:
        """
        Retrieve and parse batch results.

        Args:
            job_name: Batch job name

        Returns:
            List of result dicts

        Raises:
            RuntimeError: If the job exposes neither inline responses nor a
                results file
        """
        batch_job = self.client.batches.get(name=job_name)
        # Check for inline responses (preferred)
        if hasattr(batch_job.dest, 'inlined_responses') and batch_job.dest.inlined_responses:
            logger.info(f"   📥 Processing inline responses...")
            return self._parse_inline_results(batch_job.dest.inlined_responses)
        # Download results file (fallback)
        if hasattr(batch_job.dest, 'file_name') and batch_job.dest.file_name:
            logger.info(f"   📥 Downloading results file: {batch_job.dest.file_name}")
            file_data = self.client.files.download(file=batch_job.dest.file_name)
            return self._parse_file_results(file_data)
        raise RuntimeError("No results available from batch job")

    def _parse_inline_results(self, inline_responses) -> List[Dict[str, Any]]:
        """Parse inline batch results into the client's result-dict format."""
        results = []
        for response_obj in inline_responses:
            if hasattr(response_obj, 'response') and response_obj.response:
                text = self._extract_text_from_response(response_obj.response)
                results.append({
                    "content": text,
                    "role": "assistant",
                    "model": self.model_name,
                    "provider": "google"
                })
            else:
                error_msg = str(getattr(response_obj, 'error', 'Unknown error'))
                logger.warning(f"   ⚠️ Response error: {error_msg}")
                results.append({
                    "content": "",
                    "error": error_msg
                })
        return results

    def _parse_file_results(self, file_data) -> List[Dict[str, Any]]:
        """Parse a downloaded JSONL results file (bytes or str) into result dicts."""
        if isinstance(file_data, bytes):
            jsonl_content = file_data.decode('utf-8')
        else:
            jsonl_content = file_data
        results = []
        for line_num, line in enumerate(jsonl_content.strip().split('\n'), 1):
            if not line.strip():
                continue
            try:
                result = json.loads(line)
                if 'response' in result:
                    text = self._extract_text_from_dict(result['response'])
                    results.append({
                        "content": text,
                        "role": "assistant",
                        "model": self.model_name,
                        "provider": "google"
                    })
                else:
                    error_msg = result.get('error', 'Unknown error')
                    logger.warning(f"   ⚠️ Line {line_num} error: {error_msg}")
                    results.append({
                        "content": "",
                        "error": error_msg
                    })
            except json.JSONDecodeError as e:
                logger.error(f"   ✗ Line {line_num}: JSON decode error: {e}")
                results.append({"content": "", "error": f"JSON decode error: {e}"})
        return results

    def _extract_text_from_response(self, response_obj) -> str:
        """Extract text from an SDK response object, falling back to str()."""
        try:
            # Direct text attribute; skip when it is None so the candidate
            # walk below gets a chance (previously a None 'text' attribute
            # produced "content": None downstream).
            text = getattr(response_obj, 'text', None)
            if text is not None:
                return text
            # Navigate through candidates -> content -> parts[0].text
            if hasattr(response_obj, 'candidates') and response_obj.candidates:
                candidate = response_obj.candidates[0]
                if hasattr(candidate, 'content'):
                    content = candidate.content
                    if hasattr(content, 'parts') and content.parts:
                        part = content.parts[0]
                        if hasattr(part, 'text') and part.text is not None:
                            return part.text
            # Fallback to string representation
            return str(response_obj)
        except Exception as e:
            logger.error(f"Error extracting text from response: {e}")
            return ""

    def _extract_text_from_dict(self, response_dict: Dict) -> str:
        """Extract text from a response dictionary, falling back to its JSON dump."""
        try:
            # Direct text key
            if 'text' in response_dict:
                return response_dict['text']
            # Navigate through candidates -> content -> parts[0].text
            if 'candidates' in response_dict and response_dict['candidates']:
                candidate = response_dict['candidates'][0]
                if 'content' in candidate and 'parts' in candidate['content']:
                    parts = candidate['content']['parts']
                    if parts and 'text' in parts[0]:
                        return parts[0]['text']
            # Fallback to JSON string
            return json.dumps(response_dict)
        except Exception as e:
            logger.error(f"Error extracting text from dict: {e}")
            return ""

    def _wait_for_file_active(self, uploaded_file, timeout: int = 60):
        """
        Wait for an uploaded file to become ACTIVE.

        Args:
            uploaded_file: Uploaded file object
            timeout: Maximum seconds to wait

        Raises:
            TimeoutError: If file processing exceeds timeout
            RuntimeError: If file processing ends in any state other than ACTIVE
        """
        start_time = time.time()
        while uploaded_file.state.name == "PROCESSING":
            if time.time() - start_time > timeout:
                raise TimeoutError(f"File processing timeout: {uploaded_file.name}")
            time.sleep(1)
            # Re-fetch the file to observe state transitions.
            uploaded_file = self.client.files.get(name=uploaded_file.name)
        if uploaded_file.state.name != "ACTIVE":
            raise RuntimeError(
                f"File processing failed: {uploaded_file.name} "
                f"(state: {uploaded_file.state.name})"
            )

    def get_model_info(self) -> Dict[str, str]:
        """Get model information for logging and debugging."""
        return {
            'provider': self.provider,
            'model_name': self.model_name,
            'class': self.__class__.__name__,
            'mode': 'batch',
            'batch_size': str(self.batch_size),
            'polling_interval': f'{self.polling_interval}s'
        }