|
|
""" |
|
|
Main GepaOptimizer class - the heart of the optimization system |
|
|
""" |
|
|
|
|
|
import time |
|
|
import logging |
|
|
from typing import Any, Dict, List, Optional, Union |
|
|
import asyncio |
|
|
import io |
|
|
import sys |
|
|
from contextlib import redirect_stdout, redirect_stderr |
|
|
|
|
|
import gepa |
|
|
from ..utils.api_keys import APIKeyManager |
|
|
from .result import ResultProcessor |
|
|
from ..data.converters import UniversalConverter |
|
|
from ..models.result import OptimizationResult, OptimizedResult |
|
|
from ..models.config import OptimizationConfig, ModelConfig |
|
|
from ..utils.helpers import sanitize_prompt |
|
|
from ..utils.exceptions import GepaDependencyError, InvalidInputError, DatasetError, GepaOptimizerError |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
class GepaOptimizer: |
|
|
""" |
|
|
Main class for prompt optimization using GEPA |
|
|
|
|
|
This is the primary interface that users interact with. |
|
|
Provides both simple and advanced optimization capabilities. |
|
|
""" |
|
|
|
|
|
def __init__(self, config: Optional[OptimizationConfig] = None, |
|
|
adapter_type: str = "universal", |
|
|
custom_adapter: Optional[Any] = None, |
|
|
llm_model_name: Optional[str] = None, |
|
|
metric_weights: Optional[Dict[str, float]] = None, |
|
|
**kwargs): |
|
|
""" |
|
|
Initialize the optimizer |
|
|
|
|
|
Args: |
|
|
config: Optimization configuration (required) |
|
|
adapter_type: Type of adapter to use ("universal" only - fully configurable) |
|
|
custom_adapter: Custom adapter instance (overrides adapter_type) |
|
|
llm_model_name: [Deprecated] Use config.model instead. Will be removed in future versions. |
|
|
metric_weights: [Deprecated] Not used - evaluator handles metrics. Will be removed in future versions. |
|
|
**kwargs: Additional parameters for universal adapter (llm_client, evaluator, etc.) |
|
|
|
|
|
Raises: |
|
|
ValueError: If required configuration is missing |
|
|
GepaDependencyError: If GEPA library is not available |
|
|
""" |
|
|
if config is None: |
|
|
raise ValueError("config parameter is required. Use OptimizationConfig to configure the optimizer.") |
|
|
|
|
|
|
|
|
self.logger = logging.getLogger(__name__) |
|
|
|
|
|
self.config = config |
|
|
self.converter = UniversalConverter(data_split_config=config.data_split) |
|
|
self.api_manager = APIKeyManager() |
|
|
self.result_processor = ResultProcessor() |
|
|
|
|
|
|
|
|
if custom_adapter: |
|
|
|
|
|
from .base_adapter import BaseGepaAdapter |
|
|
if not isinstance(custom_adapter, BaseGepaAdapter): |
|
|
raise TypeError("custom_adapter must be an instance of BaseGepaAdapter") |
|
|
self.adapter = custom_adapter |
|
|
self.logger.info("Using user-provided custom adapter") |
|
|
elif adapter_type == "universal": |
|
|
|
|
|
llm_client = kwargs.get('llm_client') |
|
|
evaluator = kwargs.get('evaluator') |
|
|
|
|
|
if not llm_client or not evaluator: |
|
|
raise ValueError( |
|
|
"llm_client and evaluator are required for universal adapter. " |
|
|
"Example: GepaOptimizer(config=config, adapter_type='universal', " |
|
|
"llm_client=llm_client, evaluator=evaluator)" |
|
|
) |
|
|
|
|
|
from .universal_adapter import UniversalGepaAdapter |
|
|
self.adapter = UniversalGepaAdapter( |
|
|
llm_client=llm_client, |
|
|
evaluator=evaluator, |
|
|
data_converter=kwargs.get('data_converter') |
|
|
) |
|
|
self.logger.info("Using universal adapter") |
|
|
else: |
|
|
raise ValueError( |
|
|
f"Unknown adapter_type: {adapter_type}. " |
|
|
f"Only 'universal' is supported. " |
|
|
f"Provide llm_client and evaluator when using universal adapter." |
|
|
) |
|
|
|
|
|
|
|
|
self.custom_adapter = self.adapter |
|
|
|
|
|
|
|
|
model_info = self.adapter.get_performance_stats() |
|
|
self.logger.info(f"Initialized adapter: {model_info}") |
|
|
|
|
|
|
|
|
logging.basicConfig( |
|
|
level=logging.INFO, |
|
|
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' |
|
|
) |
|
|
|
|
|
|
|
|
if gepa is None: |
|
|
raise GepaDependencyError("GEPA library is not available. Please install it with: pip install gepa") |
|
|
|
|
|
async def train(self, |
|
|
seed_prompt: str, |
|
|
dataset: Union[List[Any], str, Dict, Any], |
|
|
**kwargs) -> OptimizedResult: |
|
|
""" |
|
|
Main training method for prompt optimization |
|
|
|
|
|
Args: |
|
|
seed_prompt: Initial prompt to optimize |
|
|
dataset: Training data in any format |
|
|
**kwargs: Additional parameters that can override config |
|
|
|
|
|
Returns: |
|
|
OptimizedResult: Optimization result with improved prompt |
|
|
|
|
|
Raises: |
|
|
InvalidInputError: For invalid input parameters |
|
|
DatasetError: For issues with dataset processing |
|
|
GepaOptimizerError: For optimization failures |
|
|
""" |
|
|
start_time = time.time() |
|
|
session_id = f"opt_{int(start_time)}_{id(self)}" |
|
|
|
|
|
try: |
|
|
self.logger.info(f"Starting optimization session: {session_id}") |
|
|
self.logger.info(f"Using model: {self.config.model.model_name} (provider: {self.config.model.provider})") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from ..utils.pareto_logger import reset_pareto_logger |
|
|
reset_pareto_logger() |
|
|
self.logger.info("✅ Reset Pareto logger for new optimization run") |
|
|
|
|
|
|
|
|
self._update_config_from_kwargs(kwargs) |
|
|
|
|
|
|
|
|
self._validate_inputs(seed_prompt) |
|
|
|
|
|
|
|
|
|
|
|
if isinstance(dataset, dict) and all(k in dataset for k in ['train', 'val', 'test']): |
|
|
|
|
|
self.logger.info("✅ Detected pre-split dataset - using user's split (no re-splitting)") |
|
|
trainset_raw = dataset.get('train', []) |
|
|
valset_raw = dataset.get('val', []) |
|
|
testset_raw = dataset.get('test', []) |
|
|
|
|
|
|
|
|
trainset = self.converter._standardize(trainset_raw) |
|
|
valset = self.converter._standardize(valset_raw) |
|
|
testset = self.converter._standardize(testset_raw) if testset_raw else [] |
|
|
|
|
|
self.logger.info( |
|
|
f"Using pre-split dataset: {len(trainset)} train (Dfeedback), " |
|
|
f"{len(valset)} val (Dpareto), {len(testset)} test (held-out)" |
|
|
) |
|
|
else: |
|
|
|
|
|
self.logger.info("Converting dataset to GEPA format with 3-way split...") |
|
|
trainset, valset, testset = self.converter.convert( |
|
|
dataset, |
|
|
split_config=self.config.data_split |
|
|
) |
|
|
|
|
|
|
|
|
split_strategy = self.config.data_split.small_dataset_strategy |
|
|
strategy_note = "" |
|
|
if split_strategy == 'adaptive': |
|
|
total_size = len(trainset) + len(valset) + len(testset) |
|
|
train_ratio, val_ratio, test_ratio = self.config.data_split.get_adaptive_ratios(total_size) |
|
|
strategy_note = f" (adaptive: {train_ratio*100:.0f}%/{val_ratio*100:.0f}%/{test_ratio*100:.0f}% ratios)" |
|
|
self.logger.info( |
|
|
f"Dataset split{strategy_note}: {len(trainset)} train (Dfeedback), " |
|
|
f"{len(valset)} val (Dpareto), {len(testset)} test (held-out)" |
|
|
) |
|
|
|
|
|
if not trainset: |
|
|
raise DatasetError("Dataset appears to be empty after conversion") |
|
|
|
|
|
|
|
|
seed_candidate = self._create_seed_candidate(seed_prompt) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
try: |
|
|
self.adapter._valset_size = len(valset) if valset else 0 |
|
|
self.logger.info(f"✅ Set valset_size in adapter: {len(valset) if valset else 0} for Dpareto detection") |
|
|
except AttributeError: |
|
|
self.logger.warning("⚠️ Could not set _valset_size in adapter - attribute not supported") |
|
|
|
|
|
try: |
|
|
self.adapter._valset = valset |
|
|
self.logger.info(f"✅ Stored valset in adapter ({len(valset) if valset else 0} samples)") |
|
|
except AttributeError: |
|
|
self.logger.warning("⚠️ Could not set _valset in adapter - attribute not supported") |
|
|
|
|
|
|
|
|
|
|
|
baseline_val_score = None |
|
|
if valset: |
|
|
self.logger.info("📊 Evaluating seed prompt on validation set for baseline...") |
|
|
|
|
|
|
|
|
try: |
|
|
self.adapter._is_baseline_evaluation = True |
|
|
self.logger.info("✅ Set baseline evaluation flag in adapter") |
|
|
except AttributeError: |
|
|
self.logger.warning("⚠️ Could not set _is_baseline_evaluation in adapter") |
|
|
|
|
|
try: |
|
|
|
|
|
eval_result = self.adapter.evaluate( |
|
|
batch=valset, |
|
|
candidate=seed_candidate, |
|
|
capture_traces=False |
|
|
) |
|
|
baseline_val_score = sum(eval_result.scores) / len(eval_result.scores) if eval_result.scores else 0.0 |
|
|
self.logger.info(f"📊 Baseline validation score: {baseline_val_score:.4f} (on {len(valset)} samples)") |
|
|
|
|
|
|
|
|
if hasattr(self.adapter, '_baseline_score'): |
|
|
self.adapter._baseline_score = baseline_val_score |
|
|
|
|
|
|
|
|
|
|
|
from ..utils.pareto_logger import get_pareto_logger |
|
|
pareto_log = get_pareto_logger() |
|
|
pareto_log.set_baseline(baseline_val_score) |
|
|
self.logger.info(f"✅ Baseline set in Pareto logger: {baseline_val_score:.4f}") |
|
|
|
|
|
except Exception as e: |
|
|
self.logger.warning(f"Baseline evaluation failed: {e}") |
|
|
import traceback |
|
|
self.logger.debug(f"Baseline evaluation error: {traceback.format_exc()}") |
|
|
finally: |
|
|
try: |
|
|
self.adapter._is_baseline_evaluation = False |
|
|
self.logger.debug("✅ Reset baseline evaluation flag - optimization can begin") |
|
|
except AttributeError: |
|
|
pass |
|
|
|
|
|
|
|
|
self.logger.info("Starting GEPA optimization...") |
|
|
gepa_result, actual_iterations = await self._run_gepa_optimization( |
|
|
adapter=self.adapter, |
|
|
seed_candidate=seed_candidate, |
|
|
trainset=trainset, |
|
|
valset=valset, |
|
|
**kwargs |
|
|
) |
|
|
|
|
|
|
|
|
best_candidate = self._extract_best_candidate(gepa_result) |
|
|
|
|
|
|
|
|
|
|
|
self.logger.info(f"\n{'═'*80}") |
|
|
self.logger.info(f"📝 EXTRACTING OPTIMIZED PROMPT FROM GEPA RESULT") |
|
|
self.logger.info(f"{'═'*80}") |
|
|
self.logger.info(f"best_candidate keys: {list(best_candidate.keys()) if isinstance(best_candidate, dict) else 'N/A'}") |
|
|
|
|
|
optimized_prompt = best_candidate.get('system_prompt', seed_prompt) |
|
|
if not optimized_prompt or optimized_prompt.strip() == '': |
|
|
|
|
|
optimized_prompt = best_candidate.get('prompt', best_candidate.get('text', seed_prompt)) |
|
|
|
|
|
|
|
|
best_fitness = best_candidate.get('fitness') or self.adapter.get_best_score() if hasattr(self.adapter, 'get_best_score') else None |
|
|
candidate_source = best_candidate.get('source', 'unknown') |
|
|
|
|
|
self.logger.info(f"\n✅ EXTRACTED OPTIMIZED PROMPT:") |
|
|
self.logger.info(f" Source: {candidate_source}") |
|
|
if best_fitness is not None: |
|
|
self.logger.info(f" Fitness: f={best_fitness:.4f}") |
|
|
self.logger.info(f" Length: {len(optimized_prompt)} characters") |
|
|
self.logger.info(f" Words: {len(optimized_prompt.split())} words") |
|
|
self.logger.info(f"\n📝 FULL OPTIMIZED PROMPT TEXT:") |
|
|
self.logger.info(f"{'─'*80}") |
|
|
self.logger.info(optimized_prompt) |
|
|
self.logger.info(f"{'─'*80}") |
|
|
|
|
|
if optimized_prompt != seed_prompt: |
|
|
self.logger.info(f"\n✅ SUCCESS: Prompt WAS OPTIMIZED!") |
|
|
self.logger.info(f" Seed length: {len(seed_prompt)} chars") |
|
|
self.logger.info(f" Optimized length: {len(optimized_prompt)} chars") |
|
|
self.logger.info(f" Difference: {len(optimized_prompt) - len(seed_prompt):+d} chars") |
|
|
if best_fitness is not None: |
|
|
baseline_fitness = 0.5 |
|
|
improvement = best_fitness - baseline_fitness |
|
|
improvement_pct = (improvement / baseline_fitness * 100) if baseline_fitness > 0 else 0 |
|
|
self.logger.info(f" Fitness: f={best_fitness:.4f} (improvement: {improvement:+.4f} ({improvement_pct:+.1f}%))") |
|
|
else: |
|
|
self.logger.warning(f"\n⚠️ WARNING: Optimized prompt is IDENTICAL to seed prompt") |
|
|
self.logger.warning(f" This means GEPA didn't modify the prompt during optimization") |
|
|
if best_fitness is not None: |
|
|
self.logger.warning(f" Best fitness found: f={best_fitness:.4f}") |
|
|
self.logger.warning(f" 💡 Check if LLEGO best candidate is being properly extracted") |
|
|
|
|
|
self.logger.info(f"{'═'*80}\n") |
|
|
|
|
|
|
|
|
optimized_test_score = None |
|
|
improvement_data = {} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
optimized_val_score = best_fitness |
|
|
|
|
|
if baseline_val_score is not None and optimized_val_score is not None: |
|
|
absolute_improvement = optimized_val_score - baseline_val_score |
|
|
relative_improvement = ( |
|
|
(absolute_improvement / baseline_val_score * 100) |
|
|
if baseline_val_score > 0 else 0 |
|
|
) |
|
|
|
|
|
improvement_data = { |
|
|
'baseline_val_score': baseline_val_score, |
|
|
'optimized_val_score': optimized_val_score, |
|
|
'absolute_improvement': absolute_improvement, |
|
|
'relative_improvement_percent': relative_improvement |
|
|
} |
|
|
|
|
|
self.logger.info( |
|
|
f"📈 Validation improvement: {relative_improvement:+.2f}% " |
|
|
f"(baseline val: {baseline_val_score:.4f} → optimized val: {optimized_val_score:.4f})" |
|
|
) |
|
|
|
|
|
|
|
|
if testset and self.config.evaluate_on_test: |
|
|
self.logger.info("📊 Evaluating optimized prompt on test set...") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from ..llms.llego_enhanced_llm import LLEGOEnhancedLLMClient |
|
|
if hasattr(self.adapter, 'llm_client') and isinstance(self.adapter.llm_client, LLEGOEnhancedLLMClient): |
|
|
if hasattr(self.adapter.llm_client, '_adapter_generated_candidates'): |
|
|
self.adapter.llm_client._adapter_generated_candidates = [] |
|
|
self.logger.info("✅ Cleared LLEGO candidate queue for clean test evaluation") |
|
|
if hasattr(self.adapter.llm_client, '_candidate_queue'): |
|
|
self.adapter.llm_client._candidate_queue = [] |
|
|
self.logger.info("✅ Cleared LLEGO hybrid candidate queue for clean test evaluation") |
|
|
|
|
|
|
|
|
try: |
|
|
optimized_test_score = self._evaluate_candidate_on_testset( |
|
|
best_candidate, |
|
|
testset |
|
|
) |
|
|
self.logger.info(f"📊 Optimized test score: {optimized_test_score:.4f}") |
|
|
|
|
|
|
|
|
improvement_data['optimized_test_score'] = optimized_test_score |
|
|
|
|
|
if baseline_val_score is not None: |
|
|
test_vs_baseline = ( |
|
|
((optimized_test_score - baseline_val_score) / baseline_val_score * 100) |
|
|
if baseline_val_score > 0 else 0 |
|
|
) |
|
|
self.logger.info( |
|
|
f"📊 Test set vs validation baseline: {test_vs_baseline:+.2f}% " |
|
|
f"(baseline val: {baseline_val_score:.4f} → optimized test: {optimized_test_score:.4f})" |
|
|
) |
|
|
except Exception as e: |
|
|
self.logger.warning(f"Test evaluation failed: {e}") |
|
|
|
|
|
|
|
|
optimization_time = time.time() - start_time |
|
|
|
|
|
processed_result = self.result_processor.process_full_result( |
|
|
result=gepa_result, |
|
|
original_prompt=seed_prompt, |
|
|
optimization_time=optimization_time, |
|
|
actual_iterations=actual_iterations, |
|
|
test_metrics=improvement_data |
|
|
) |
|
|
|
|
|
|
|
|
final_improvement_data = {**processed_result.get('improvement_data', {}), **improvement_data} |
|
|
|
|
|
|
|
|
|
|
|
result = OptimizedResult( |
|
|
original_prompt=seed_prompt, |
|
|
optimized_prompt=optimized_prompt, |
|
|
improvement_data=final_improvement_data, |
|
|
optimization_time=optimization_time, |
|
|
dataset_size=len(trainset) + len(valset) + len(testset), |
|
|
total_iterations=processed_result.get('total_iterations', 0), |
|
|
status=processed_result.get('status', 'completed'), |
|
|
error_message=processed_result.get('error_message'), |
|
|
detailed_result=OptimizationResult( |
|
|
session_id=session_id, |
|
|
original_prompt=seed_prompt, |
|
|
optimized_prompt=optimized_prompt, |
|
|
improvement_data=final_improvement_data, |
|
|
optimization_time=optimization_time, |
|
|
dataset_size=len(trainset) + len(valset) + len(testset), |
|
|
total_iterations=processed_result.get('total_iterations', 0), |
|
|
status=processed_result.get('status', 'completed'), |
|
|
error_message=processed_result.get('error_message') |
|
|
) |
|
|
) |
|
|
|
|
|
self.logger.info(f"✅ Optimization completed in {optimization_time:.2f}s") |
|
|
return result |
|
|
|
|
|
except Exception as e: |
|
|
optimization_time = time.time() - start_time |
|
|
error_msg = f"Optimization failed: {str(e)}" |
|
|
self.logger.error(error_msg) |
|
|
|
|
|
|
|
|
return OptimizedResult( |
|
|
original_prompt=seed_prompt, |
|
|
optimized_prompt=seed_prompt, |
|
|
improvement_data={'error': error_msg}, |
|
|
optimization_time=optimization_time, |
|
|
dataset_size=0, |
|
|
total_iterations=0, |
|
|
status='failed', |
|
|
error_message=error_msg |
|
|
) |
|
|
|
|
|
def _update_config_from_kwargs(self, kwargs: Dict[str, Any]) -> None: |
|
|
"""Update configuration with runtime overrides from kwargs.""" |
|
|
updated_params = [] |
|
|
|
|
|
for key, value in kwargs.items(): |
|
|
if hasattr(self.config, key): |
|
|
setattr(self.config, key, value) |
|
|
updated_params.append(f"{key}={value}") |
|
|
else: |
|
|
self.logger.warning(f"Unknown parameter '{key}' ignored") |
|
|
|
|
|
if updated_params: |
|
|
self.logger.info(f"Updated config parameters: {', '.join(updated_params)}") |
|
|
|
|
|
def _validate_inputs(self, seed_prompt: str) -> None: |
|
|
""" |
|
|
Validate input parameters for optimization |
|
|
|
|
|
Args: |
|
|
seed_prompt: The seed prompt to validate |
|
|
|
|
|
Raises: |
|
|
InvalidInputError: If validation fails |
|
|
""" |
|
|
if not seed_prompt or not isinstance(seed_prompt, str): |
|
|
raise InvalidInputError("Seed prompt must be a non-empty string") |
|
|
|
|
|
if len(seed_prompt.strip()) < 10: |
|
|
raise InvalidInputError("Seed prompt is too short (minimum 10 characters)") |
|
|
|
|
|
|
|
|
model_config = self.config.model |
|
|
if not hasattr(model_config, 'model_name') or not model_config.model_name: |
|
|
raise InvalidInputError("Model name is required") |
|
|
|
|
|
reflection_config = self.config.reflection_model |
|
|
if not hasattr(reflection_config, 'model_name') or not reflection_config.model_name: |
|
|
raise InvalidInputError("Reflection model name is required") |
|
|
|
|
|
def _clean_reflection_prompt(self, prompt: str, max_length: int = 50000) -> str: |
|
|
""" |
|
|
Clean reflection prompt by removing base64 images and truncating if too long. |
|
|
|
|
|
🔥 CRITICAL: GEPA's reflective dataset includes base64 images which create |
|
|
massive prompts (7MB+) that exceed token limits. This function: |
|
|
1. Strips all base64 image data |
|
|
2. Removes excessive detailed_scores entries |
|
|
3. Truncates to reasonable size |
|
|
4. Preserves essential feedback information |
|
|
|
|
|
Args: |
|
|
prompt: Original prompt from GEPA (may contain base64) |
|
|
max_length: Maximum length after cleaning (default: 50K chars) |
|
|
|
|
|
Returns: |
|
|
Cleaned prompt without base64, within size limits |
|
|
""" |
|
|
import re |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
base64_pattern = r'[A-Za-z0-9+/=]{5000,}' |
|
|
cleaned = re.sub(base64_pattern, '[IMAGE_DATA_REMOVED]', prompt) |
|
|
|
|
|
|
|
|
|
|
|
detailed_scores_pattern = r'### detailed_scores[^\n]*\n[^#]*(?:image_base64|base64)[^\n]*(?:\n[^#]*)*' |
|
|
cleaned = re.sub(detailed_scores_pattern, '### detailed_scores: [REMOVED_FOR_BREVITY]', cleaned, flags=re.IGNORECASE | re.MULTILINE) |
|
|
|
|
|
|
|
|
cleaned = re.sub(r'image_base64[^\n]*', 'image_base64: [REMOVED]', cleaned, flags=re.IGNORECASE) |
|
|
cleaned = re.sub(r'"[A-Za-z0-9+/=]{10000,}"', '[LARGE_DATA_STRING_REMOVED]', cleaned) |
|
|
|
|
|
|
|
|
if len(cleaned) > max_length: |
|
|
|
|
|
|
|
|
truncated_size = len(cleaned) - max_length |
|
|
cleaned = cleaned[:max_length] + f"\n\n[TRUNCATED {truncated_size} characters of detailed evaluation data]" |
|
|
self.logger.warning(f"⚠️ Prompt truncated: {len(prompt)} → {len(cleaned)} chars") |
|
|
|
|
|
return cleaned |
|
|
|
|
|
def _validate_models(self, task_lm, reflection_lm): |
|
|
""" |
|
|
Validate if specified models are supported. |
|
|
|
|
|
Note: No hardcoded restrictions - the API provider will validate model existence. |
|
|
This method is kept for potential future validation logic but doesn't restrict users. |
|
|
""" |
|
|
|
|
|
|
|
|
self.logger.debug(f"Using task model: {task_lm}, reflection model: {reflection_lm}") |
|
|
|
|
|
def _create_seed_candidate(self, seed_prompt: str) -> Dict[str, str]: |
|
|
"""Create a seed candidate from the input prompt.""" |
|
|
sanitized_prompt = sanitize_prompt(seed_prompt) |
|
|
return {'system_prompt': sanitized_prompt} |
|
|
|
|
|
async def _run_gepa_optimization(self, adapter, seed_candidate: Any, trainset: List[Any], valset: List[Any], **kwargs) -> tuple: |
|
|
""" |
|
|
Run GEPA optimization with the given adapter and data |
|
|
|
|
|
Args: |
|
|
adapter: Custom adapter for GEPA |
|
|
seed_candidate: Initial prompt candidate |
|
|
trainset: Training dataset |
|
|
valset: Validation dataset |
|
|
**kwargs: Additional optimization parameters that can override config |
|
|
|
|
|
Returns: |
|
|
Dict with optimization results |
|
|
|
|
|
Raises: |
|
|
GepaOptimizerError: If optimization fails |
|
|
|
|
|
Note: |
|
|
The following parameters are required in the config: |
|
|
- max_metric_calls: Maximum number of metric evaluations |
|
|
- batch_size: Batch size for evaluation |
|
|
- max_iterations: Maximum number of optimization iterations |
|
|
""" |
|
|
try: |
|
|
|
|
|
max_metric_calls = self.config.max_metric_calls |
|
|
batch_size = self.config.batch_size |
|
|
max_iterations = self.config.max_iterations |
|
|
|
|
|
|
|
|
from ..llms.vision_llm import VisionLLMClient |
|
|
base_reflection_lm_client = VisionLLMClient( |
|
|
provider=self.config.reflection_model.provider, |
|
|
model_name=self.config.reflection_model.model_name, |
|
|
api_key=self.config.reflection_model.api_key, |
|
|
base_url=self.config.reflection_model.base_url, |
|
|
temperature=self.config.reflection_model.temperature, |
|
|
max_tokens=self.config.reflection_model.max_tokens, |
|
|
top_p=self.config.reflection_model.top_p, |
|
|
frequency_penalty=self.config.reflection_model.frequency_penalty, |
|
|
presence_penalty=self.config.reflection_model.presence_penalty |
|
|
) |
|
|
|
|
|
reflection_lm_client = base_reflection_lm_client |
|
|
|
|
|
|
|
|
if self.config.use_llego_operators: |
|
|
self.logger.info("🧬 LLEGO genetic operators ENABLED") |
|
|
self.logger.info(f" α={self.config.alpha}, τ={self.config.tau}, ν={self.config.nu}") |
|
|
self.logger.info(f" Crossover offspring: {self.config.n_crossover}, Mutation offspring: {self.config.n_mutation}") |
|
|
|
|
|
|
|
|
from ..operators.llego_operators import LLEGOIntegrationLayer, PromptCandidate |
|
|
|
|
|
|
|
|
llego = LLEGOIntegrationLayer( |
|
|
alpha=self.config.alpha, |
|
|
tau=self.config.tau, |
|
|
nu=self.config.nu, |
|
|
population_size=self.config.population_size, |
|
|
n_crossover=self.config.n_crossover, |
|
|
n_mutation=self.config.n_mutation |
|
|
) |
|
|
|
|
|
|
|
|
llego.initialize_population( |
|
|
seed_prompt=seed_candidate.get('system_prompt', ''), |
|
|
initial_fitness=0.5 |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
if self.config.enable_gepa_reflection_with_llego: |
|
|
self.logger.info("🔥 HYBRID MODE: Wrapping reflection_lm_client with LLEGO") |
|
|
from ..llms.llego_enhanced_llm import LLEGOEnhancedLLMClient |
|
|
|
|
|
|
|
|
reflection_lm_client = LLEGOEnhancedLLMClient( |
|
|
base_llm=base_reflection_lm_client, |
|
|
llego_layer=llego, |
|
|
config=self.config, |
|
|
verbose=True |
|
|
) |
|
|
self.logger.info("✅ reflection_lm_client wrapped with LLEGO (hybrid mode enabled)") |
|
|
|
|
|
|
|
|
|
|
|
if hasattr(adapter, 'reflection_lm_client'): |
|
|
adapter.reflection_lm_client = reflection_lm_client |
|
|
self.logger.info("✅ Stored reflection_lm_client reference in adapter") |
|
|
else: |
|
|
|
|
|
adapter.reflection_lm_client = reflection_lm_client |
|
|
self.logger.info("✅ Added reflection_lm_client attribute to adapter") |
|
|
|
|
|
|
|
|
if hasattr(adapter, '_config'): |
|
|
adapter._config = self.config |
|
|
self.logger.info("✅ Stored config in adapter for hybrid mode") |
|
|
else: |
|
|
adapter._config = self.config |
|
|
self.logger.info("✅ Added _config attribute to adapter") |
|
|
|
|
|
if hasattr(adapter, '_reflection_lm_client'): |
|
|
adapter._reflection_lm_client = reflection_lm_client |
|
|
self.logger.info("✅ Stored _reflection_lm_client in adapter for hybrid mode") |
|
|
else: |
|
|
adapter._reflection_lm_client = reflection_lm_client |
|
|
self.logger.info("✅ Added _reflection_lm_client attribute to adapter") |
|
|
|
|
|
|
|
|
|
|
|
if hasattr(adapter, 'llego'): |
|
|
if adapter.llego is None: |
|
|
adapter.llego = llego |
|
|
self.logger.info("✅ CRITICAL: Set LLEGO layer in adapter (was None)") |
|
|
else: |
|
|
self.logger.debug("✅ LLEGO layer already set in adapter") |
|
|
else: |
|
|
|
|
|
adapter.llego = llego |
|
|
self.logger.info("✅ CRITICAL: Added LLEGO layer to adapter") |
|
|
|
|
|
|
|
|
|
|
|
if not hasattr(adapter, '_reflection_lm_client') or adapter._reflection_lm_client is None: |
|
|
adapter._reflection_lm_client = reflection_lm_client |
|
|
self.logger.info("✅ Set _reflection_lm_client in adapter (required for propose_new_texts)") |
|
|
|
|
|
|
|
|
|
|
|
if self.config.enable_gepa_reflection_with_llego: |
|
|
|
|
|
self.logger.info("🔥 HYBRID MODE: Enabling hybrid candidate generation in LLEGO wrapper") |
|
|
|
|
|
|
|
|
llm_client = self.adapter.llm_client |
|
|
from ..llms.llego_enhanced_llm import LLEGOEnhancedLLMClient |
|
|
|
|
|
if isinstance(llm_client, LLEGOEnhancedLLMClient): |
|
|
|
|
|
llm_client.config = self.config |
|
|
self.logger.info("✅ Updated LLEGO wrapper with hybrid mode config") |
|
|
else: |
|
|
|
|
|
llego_wrapped_llm = LLEGOEnhancedLLMClient( |
|
|
base_llm=llm_client, |
|
|
llego_layer=llego, |
|
|
config=self.config, |
|
|
verbose=True |
|
|
) |
|
|
|
|
|
self.adapter.llm_client = llego_wrapped_llm |
|
|
self.logger.info("✅ Wrapped LLM client with LLEGO (hybrid mode enabled)") |
|
|
|
|
|
adapter = self.adapter |
|
|
else: |
|
|
|
|
|
self.logger.info("🧬 LLEGO-ONLY MODE: Recreating adapter with LLEGO integration...") |
|
|
if hasattr(self, 'adapter') and self.adapter: |
|
|
from .universal_adapter import UniversalGepaAdapter |
|
|
|
|
|
|
|
|
original_llm = self.adapter.llm_client |
|
|
|
|
|
if hasattr(original_llm, 'base_llm'): |
|
|
original_llm = original_llm.base_llm |
|
|
|
|
|
evaluator = self.adapter.evaluator |
|
|
data_converter = self.adapter.data_converter |
|
|
|
|
|
|
|
|
from ..llms.llego_enhanced_llm import LLEGOEnhancedLLMClient |
|
|
llego_wrapped_llm = LLEGOEnhancedLLMClient( |
|
|
base_llm=original_llm, |
|
|
llego_layer=llego, |
|
|
config=None, |
|
|
verbose=True |
|
|
) |
|
|
|
|
|
adapter = UniversalGepaAdapter( |
|
|
llm_client=llego_wrapped_llm, |
|
|
evaluator=evaluator, |
|
|
data_converter=data_converter, |
|
|
llego_layer=llego |
|
|
) |
|
|
self.logger.info("✅ Adapter recreated with LLEGO-enhanced LLM client") |
|
|
else: |
|
|
adapter = self.adapter |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def reflection_lm_callable(prompt: str) -> str: |
|
|
""" |
|
|
Reflection callable that delegates to LLEGO-wrapped client. |
|
|
In hybrid mode, the wrapper generates candidates from both GEPA and LLEGO. |
|
|
|
|
|
🔥 CRITICAL: Clean the prompt to remove base64 images and truncate if too long. |
|
|
""" |
|
|
|
|
|
cleaned_prompt = self._clean_reflection_prompt(prompt) |
|
|
|
|
|
self.logger.info(f"\n{'🔥'*40}") |
|
|
self.logger.info(f"🔥 reflection_lm_callable CALLED (delegating to LLEGO wrapper)") |
|
|
self.logger.info(f"🔥 Original prompt length: {len(prompt)} chars") |
|
|
self.logger.info(f"🔥 Cleaned prompt length: {len(cleaned_prompt)} chars") |
|
|
self.logger.info(f"🔥 Truncation: {len(prompt) - len(cleaned_prompt)} chars removed") |
|
|
self.logger.info(f"🔥 First 200 chars (cleaned): {cleaned_prompt[:200]}...") |
|
|
self.logger.info(f"{'🔥'*40}\n") |
|
|
|
|
|
try: |
|
|
|
|
|
|
|
|
if isinstance(reflection_lm_client, LLEGOEnhancedLLMClient): |
|
|
reflection_lm_client.set_reflection_context( |
|
|
current_prompt=cleaned_prompt, |
|
|
feedback=None, |
|
|
in_reflection=True |
|
|
) |
|
|
self.logger.info("✅ Reflection context set on reflection_lm_client") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
optimization_system_prompt = """You are an expert prompt engineer specializing in iterative prompt optimization. |
|
|
|
|
|
Your task: Given the CURRENT PROMPT and its EVALUATION FEEDBACK, generate an IMPROVED version of the prompt that addresses all identified issues. |
|
|
|
|
|
Core Requirements: |
|
|
1. OUTPUT ONLY the improved prompt text (no explanations, no analysis, no meta-commentary) |
|
|
2. START directly with the prompt (e.g., "You are a mobile GUI agent..." or similar task-appropriate opening) |
|
|
3. PRESERVE the core task domain and output format requirements |
|
|
4. INTEGRATE improvements from feedback naturally into the prompt structure |
|
|
5. MAINTAIN clarity, specificity, and actionability |
|
|
|
|
|
Quality Standards: |
|
|
- Be specific and concrete (avoid vague instructions) |
|
|
- Use clear, imperative language for task instructions |
|
|
- Include edge case handling if feedback identifies confusion |
|
|
- Ensure the prompt is self-contained and unambiguous |
|
|
|
|
|
DO NOT include: |
|
|
- Analysis of what went wrong |
|
|
- Explanations of your changes |
|
|
- Meta-text like "Here's an improved version..." or "Based on feedback..." |
|
|
- Recommendations or suggestions (those are already in the feedback) |
|
|
|
|
|
Output the improved prompt directly and only the prompt.""" |
|
|
|
|
|
result = reflection_lm_client.generate( |
|
|
system_prompt=optimization_system_prompt, |
|
|
user_prompt=cleaned_prompt, |
|
|
image_base64="" |
|
|
) |
|
|
|
|
|
|
|
|
if isinstance(result, dict): |
|
|
candidate = result.get("content", str(result)) |
|
|
source = result.get("source", "unknown") |
|
|
self.logger.info(f"✅ Candidate from {source} (FULL TEXT):") |
|
|
self.logger.info(f" '{candidate}'") |
|
|
return candidate |
|
|
else: |
|
|
candidate = str(result) |
|
|
self.logger.info(f"✅ Candidate generated (FULL TEXT):") |
|
|
self.logger.info(f" '{candidate}'") |
|
|
return candidate |
|
|
|
|
|
except Exception as e: |
|
|
self.logger.error(f"❌ Error in reflection_lm_callable: {e}") |
|
|
import traceback |
|
|
self.logger.error(traceback.format_exc()) |
|
|
|
|
|
return prompt |
|
|
|
|
|
|
|
|
if self.config.enable_gepa_reflection_with_llego and isinstance(reflection_lm_client, LLEGOEnhancedLLMClient): |
|
|
|
|
|
reflection_lm_client.set_reflection_context( |
|
|
current_prompt=seed_candidate.get('system_prompt', ''), |
|
|
feedback=None, |
|
|
in_reflection=True |
|
|
) |
|
|
|
|
|
else: |
|
|
|
|
|
adapter = self.adapter |
|
|
|
|
|
|
|
|
|
|
|
if not hasattr(adapter, '_reflection_lm_client') or adapter._reflection_lm_client is None: |
|
|
adapter._reflection_lm_client = reflection_lm_client |
|
|
self.logger.info("✅ Set _reflection_lm_client in adapter (required for propose_new_texts)") |
|
|
|
|
|
|
|
|
def reflection_lm_callable(prompt: str) -> str: |
|
|
"""Standard callable wrapper for reflection model that GEPA expects""" |
|
|
try: |
|
|
|
|
|
optimization_system_prompt = """You are an expert prompt engineer specializing in iterative prompt optimization. |
|
|
|
|
|
Your task: Given the CURRENT PROMPT and its EVALUATION FEEDBACK, generate an IMPROVED version of the prompt that addresses all identified issues. |
|
|
|
|
|
Core Requirements: |
|
|
1. OUTPUT ONLY the improved prompt text (no explanations, no analysis, no meta-commentary) |
|
|
2. START directly with the prompt (e.g., "You are a mobile GUI agent..." or similar task-appropriate opening) |
|
|
3. PRESERVE the core task domain and output format requirements |
|
|
4. INTEGRATE improvements from feedback naturally into the prompt structure |
|
|
5. MAINTAIN clarity, specificity, and actionability |
|
|
|
|
|
Quality Standards: |
|
|
- Be specific and concrete (avoid vague instructions) |
|
|
- Use clear, imperative language for task instructions |
|
|
- Include edge case handling if feedback identifies confusion |
|
|
- Ensure the prompt is self-contained and unambiguous |
|
|
|
|
|
DO NOT include: |
|
|
- Analysis of what went wrong |
|
|
- Explanations of your changes |
|
|
- Meta-text like "Here's an improved version..." or "Based on feedback..." |
|
|
- Recommendations or suggestions (those are already in the feedback) |
|
|
|
|
|
Output the improved prompt directly and only the prompt.""" |
|
|
|
|
|
|
|
|
result = reflection_lm_client.generate( |
|
|
system_prompt=optimization_system_prompt, |
|
|
user_prompt=prompt, |
|
|
image_base64="" |
|
|
) |
|
|
|
|
|
|
|
|
if isinstance(result, dict): |
|
|
return result.get("content", str(result)) |
|
|
else: |
|
|
return str(result) |
|
|
|
|
|
except Exception as e: |
|
|
self.logger.error(f"Reflection model error: {e}") |
|
|
return prompt |
|
|
self.logger.info( |
|
|
f"Starting GEPA optimization with {max_iterations} iterations, " |
|
|
f"batch size {batch_size}, max metric calls: {max_metric_calls}" |
|
|
) |
|
|
self.logger.info( |
|
|
f"GEPA parameters: candidate_selection_strategy=pareto, " |
|
|
f"reflection_minibatch_size={batch_size}, " |
|
|
f"skip_perfect_score=False, " |
|
|
f"module_selector=round_robin" |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
reflection_lm_passed = reflection_lm_callable |
|
|
self.logger.info(f"✅ reflection_lm_callable passed to GEPA (LLEGO={self.config.use_llego_operators})") |
|
|
|
|
|
|
|
|
|
|
|
gepa_params = { |
|
|
'adapter': adapter, |
|
|
'seed_candidate': seed_candidate, |
|
|
'trainset': trainset, |
|
|
'valset': valset, |
|
|
'max_metric_calls': max_metric_calls, |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
'task_lm': None, |
|
|
'reflection_lm': reflection_lm_passed, |
|
|
|
|
|
|
|
|
'candidate_selection_strategy': 'pareto', |
|
|
'skip_perfect_score': False, |
|
|
'reflection_minibatch_size': batch_size, |
|
|
'perfect_score': 1.0, |
|
|
'module_selector': 'round_robin', |
|
|
'display_progress_bar': self.config.verbose, |
|
|
'raise_on_exception': True, |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
VALID_GEPA_PARAMS = { |
|
|
'seed_candidate', 'trainset', 'valset', 'adapter', 'task_lm', 'reflection_lm', |
|
|
'candidate_selection_strategy', 'skip_perfect_score', 'batch_sampler', |
|
|
'reflection_minibatch_size', 'perfect_score', 'reflection_prompt_template', |
|
|
'module_selector', 'use_merge', 'max_merge_invocations', 'merge_val_overlap_floor', |
|
|
'max_metric_calls', 'stop_callbacks', 'logger', 'run_dir', 'use_wandb', |
|
|
'wandb_api_key', 'wandb_init_kwargs', 'use_mlflow', 'mlflow_tracking_uri', |
|
|
'mlflow_experiment_name', 'track_best_outputs', 'display_progress_bar', |
|
|
'use_cloudpickle', 'seed', 'raise_on_exception', 'val_evaluation_policy' |
|
|
} |
|
|
|
|
|
|
|
|
for key, value in kwargs.items(): |
|
|
if key in VALID_GEPA_PARAMS and key not in gepa_params: |
|
|
gepa_params[key] = value |
|
|
elif key not in VALID_GEPA_PARAMS: |
|
|
self.logger.debug(f"⚠️ Filtering out invalid GEPA parameter: {key}") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
gepa_output = io.StringIO() |
|
|
|
|
|
|
|
|
from ..utils.clean_logger import get_clean_logger |
|
|
clean_log = get_clean_logger() |
|
|
clean_log.log_iteration_start(1, seed_prompt=seed_candidate.get('system_prompt', '')) |
|
|
|
|
|
|
|
|
if hasattr(adapter, '_valset_size'): |
|
|
adapter._valset_size = len(valset) |
|
|
self.logger.debug(f"✅ Set valset_size in adapter: {len(valset)} for Dpareto detection") |
|
|
|
|
|
|
|
|
|
|
|
if hasattr(adapter, '_valset'): |
|
|
adapter._valset = valset |
|
|
self.logger.debug(f"✅ Stored valset in adapter ({len(valset)} samples) for Dpareto evaluation of generated candidates") |
|
|
else: |
|
|
|
|
|
adapter._valset = valset |
|
|
self.logger.debug(f"✅ Added _valset attribute to adapter ({len(valset)} samples)") |
|
|
|
|
|
|
|
|
result = await asyncio.get_event_loop().run_in_executor( |
|
|
None, |
|
|
lambda: self._run_gepa_with_logging(gepa_params, gepa_output) |
|
|
) |
|
|
|
|
|
|
|
|
gepa_logs = gepa_output.getvalue() |
|
|
actual_iterations = self._log_pareto_front_info(gepa_logs) |
|
|
|
|
|
return result, actual_iterations |
|
|
except Exception as e: |
|
|
|
|
|
self.logger.warning(f"GEPA optimization failed: {e}") |
|
|
|
|
|
|
|
|
best_candidate = adapter.get_best_candidate() |
|
|
best_score = adapter.get_best_score() |
|
|
|
|
|
if best_candidate and best_score > 0: |
|
|
self.logger.info(f"🎯 Using cached best result with score: {best_score:.4f}") |
|
|
|
|
|
|
|
|
return { |
|
|
'best_candidate': best_candidate, |
|
|
'best_score': best_score, |
|
|
'partial_result': True, |
|
|
'error': f'GEPA failed but returning best result found: {str(e)}' |
|
|
} |
|
|
else: |
|
|
|
|
|
raise GepaOptimizerError(f"GEPA optimization failed: {str(e)}") |
|
|
|
|
|
def _run_gepa_with_logging(self, gepa_params: Dict[str, Any], output_buffer: io.StringIO) -> Any: |
|
|
"""Run GEPA optimization while capturing its output.""" |
|
|
self.logger.info("🔄 Calling gepa.optimize() - GEPA should now:") |
|
|
self.logger.info(" 1. Evaluate seed on validation set") |
|
|
self.logger.info(" 2. For each iteration: evaluate on training minibatch (capture_traces=True)") |
|
|
self.logger.info(" 3. Call make_reflective_dataset() with trajectories") |
|
|
self.logger.info(" 4. Call propose_new_texts() or reflection_lm to generate new candidates") |
|
|
self.logger.info(" 5. Evaluate new candidates and update Pareto front") |
|
|
|
|
|
|
|
|
with redirect_stdout(output_buffer), redirect_stderr(output_buffer): |
|
|
result = gepa.optimize(**gepa_params) |
|
|
|
|
|
|
|
|
gepa_output = output_buffer.getvalue() |
|
|
if gepa_output: |
|
|
self.logger.info("📋 GEPA Output (captured):") |
|
|
for line in gepa_output.split('\n')[:50]: |
|
|
if line.strip(): |
|
|
self.logger.info(f" GEPA: {line}") |
|
|
|
|
|
return result |
|
|
|
|
|
def _log_pareto_front_info(self, gepa_logs: str) -> int: |
|
|
"""Extract and log pareto front information from GEPA logs. Returns max iteration count.""" |
|
|
lines = gepa_logs.split('\n') |
|
|
current_iteration = 0 |
|
|
max_iteration = 0 |
|
|
|
|
|
for line in lines: |
|
|
|
|
|
if 'iteration' in line.lower(): |
|
|
|
|
|
import re |
|
|
iteration_match = re.search(r'iteration\s+(\d+)', line.lower()) |
|
|
if iteration_match: |
|
|
current_iteration = int(iteration_match.group(1)) |
|
|
max_iteration = max(max_iteration, current_iteration) |
|
|
|
|
|
from ..utils.clean_logger import get_clean_logger |
|
|
clean_log = get_clean_logger() |
|
|
if current_iteration > clean_log.current_iteration: |
|
|
clean_log.current_iteration = current_iteration |
|
|
|
|
|
|
|
|
if 'pareto front' in line.lower() or 'new program' in line.lower(): |
|
|
self.logger.info(f"GEPA Pareto Update: {line.strip()}") |
|
|
elif 'iteration' in line.lower() and ('score' in line.lower() or 'program' in line.lower()): |
|
|
self.logger.debug(f"{line.strip()}") |
|
|
elif 'best' in line.lower() and 'score' in line.lower(): |
|
|
self.logger.info(f"{line.strip()}") |
|
|
|
|
|
|
|
|
if 'evaluating' in line.lower() and 'candidate' in line.lower(): |
|
|
self.logger.debug(f"{line.strip()}") |
|
|
|
|
|
self.logger.info(f"GEPA Optimization Complete: {max_iteration} iterations") |
|
|
|
|
|
|
|
|
|
|
|
return max_iteration |
|
|
|
|
|
def _extract_best_candidate(self, gepa_result: Any) -> Dict[str, str]: |
|
|
""" |
|
|
Extract the best candidate from GEPA Pareto front (single source of truth). |
|
|
|
|
|
GEPA Pareto front is the single source of truth because: |
|
|
- All candidates (GEPA reflection, LLEGO crossover, LLEGO mutation) are evaluated on Dpareto |
|
|
- All non-dominated candidates are added to GEPA Pareto front |
|
|
- Therefore, the best candidate MUST be in GEPA Pareto front |
|
|
|
|
|
Args: |
|
|
gepa_result: Raw result from gepa.optimize() (used only as fallback edge case) |
|
|
|
|
|
Returns: |
|
|
Best candidate dictionary with prompt components from GEPA Pareto front |
|
|
""" |
|
|
try: |
|
|
self.logger.info(f"\n{'═'*80}") |
|
|
self.logger.info(f"🔍 EXTRACTING BEST CANDIDATE FROM GEPA PARETO FRONT") |
|
|
self.logger.info(f"{'═'*80}") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from ..utils.pareto_logger import get_pareto_logger |
|
|
pareto_log = get_pareto_logger() |
|
|
|
|
|
if pareto_log.pareto_front: |
|
|
try: |
|
|
|
|
|
gepa_pareto_best = max(pareto_log.pareto_front, key=lambda x: x['score']) |
|
|
gepa_pareto_fitness = gepa_pareto_best['score'] |
|
|
gepa_pareto_prompt = gepa_pareto_best['prompt'] |
|
|
gepa_pareto_type = gepa_pareto_best.get('type', 'unknown') |
|
|
gepa_pareto_notation = gepa_pareto_best.get('notation', 'S') |
|
|
|
|
|
best_candidate = { |
|
|
'system_prompt': gepa_pareto_prompt, |
|
|
'fitness': gepa_pareto_fitness, |
|
|
'source': 'gepa_pareto_front', |
|
|
'candidate_type': gepa_pareto_type, |
|
|
'notation': gepa_pareto_notation |
|
|
} |
|
|
|
|
|
self.logger.info(f"✅ SELECTED: Best candidate from GEPA Pareto front") |
|
|
self.logger.info(f" Notation: {gepa_pareto_notation}") |
|
|
self.logger.info(f" Fitness: f({gepa_pareto_notation})={gepa_pareto_fitness:.4f}") |
|
|
self.logger.info(f" Type: {gepa_pareto_type}") |
|
|
self.logger.info(f" Prompt length: {len(gepa_pareto_prompt)} chars") |
|
|
self.logger.info(f" 💡 GEPA Pareto front is single source of truth (all candidates evaluated on Dpareto)") |
|
|
|
|
|
return best_candidate |
|
|
|
|
|
except Exception as e: |
|
|
self.logger.error(f"❌ Failed to extract from GEPA Pareto front: {e}") |
|
|
import traceback |
|
|
self.logger.error(traceback.format_exc()) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.logger.warning(f"⚠️ GEPA Pareto front is empty - using gepa_result as fallback") |
|
|
self.logger.warning(f" This should not happen if all candidates are evaluated on Dpareto") |
|
|
|
|
|
|
|
|
if hasattr(gepa_result, 'best_candidate'): |
|
|
gepa_candidate = gepa_result.best_candidate |
|
|
gepa_prompt = gepa_candidate.get('system_prompt') if isinstance(gepa_candidate, dict) else str(gepa_candidate) |
|
|
gepa_fitness = getattr(gepa_result, 'best_score', None) |
|
|
|
|
|
if gepa_prompt: |
|
|
self.logger.info(f"✅ Using gepa_result.best_candidate as fallback") |
|
|
return { |
|
|
'system_prompt': gepa_prompt, |
|
|
'fitness': float(gepa_fitness) if gepa_fitness is not None else None, |
|
|
'source': 'gepa_result_fallback', |
|
|
'candidate_type': 'unknown', |
|
|
'notation': 'S' |
|
|
} |
|
|
|
|
|
|
|
|
self.logger.error(f"❌ No candidates found anywhere - returning empty prompt") |
|
|
return {'system_prompt': ''} |
|
|
|
|
|
except Exception as e: |
|
|
self.logger.error(f"❌ Error extracting best candidate: {e}") |
|
|
import traceback |
|
|
self.logger.error(traceback.format_exc()) |
|
|
return {'system_prompt': ''} |
|
|
|
|
|
def _evaluate_candidate_on_testset( |
|
|
self, |
|
|
candidate: Dict[str, str], |
|
|
testset: List[Dict] |
|
|
) -> float: |
|
|
""" |
|
|
Evaluate a candidate prompt on the held-out test set. |
|
|
|
|
|
Args: |
|
|
candidate: Prompt candidate to evaluate |
|
|
testset: Test dataset (not used during optimization) |
|
|
|
|
|
Returns: |
|
|
Average composite score on test set |
|
|
|
|
|
Raises: |
|
|
TestSetEvaluationError: If evaluation fails |
|
|
""" |
|
|
from ..utils.exceptions import TestSetEvaluationError |
|
|
|
|
|
try: |
|
|
|
|
|
eval_result = self.adapter.evaluate( |
|
|
batch=testset, |
|
|
candidate=candidate, |
|
|
capture_traces=False |
|
|
) |
|
|
|
|
|
if not eval_result.scores: |
|
|
raise TestSetEvaluationError("No scores returned from test evaluation") |
|
|
|
|
|
|
|
|
avg_score = sum(eval_result.scores) / len(eval_result.scores) |
|
|
|
|
|
self.logger.debug( |
|
|
f"Test set evaluation: {len(eval_result.scores)} samples, " |
|
|
f"scores: {eval_result.scores}, avg: {avg_score:.4f}" |
|
|
) |
|
|
|
|
|
return avg_score |
|
|
|
|
|
except Exception as e: |
|
|
raise TestSetEvaluationError(f"Failed to evaluate on test set: {str(e)}") |
|
|
|
|
|
def optimize_sync(self, |
|
|
model: str, |
|
|
seed_prompt: str, |
|
|
dataset: Any, |
|
|
reflection_lm: str, |
|
|
max_metric_calls: int = 150, |
|
|
**kwargs) -> OptimizedResult: |
|
|
""" |
|
|
Synchronous version of the optimization method |
|
|
|
|
|
Args: |
|
|
model: Target model to optimize for |
|
|
seed_prompt: Initial prompt to optimize |
|
|
dataset: Training data in any format |
|
|
reflection_lm: Model for reflection |
|
|
max_metric_calls: Budget for optimization attempts |
|
|
**kwargs: Additional optimization parameters |
|
|
|
|
|
Returns: |
|
|
OptimizedResult: Optimization result |
|
|
""" |
|
|
|
|
|
loop = asyncio.new_event_loop() |
|
|
asyncio.set_event_loop(loop) |
|
|
|
|
|
try: |
|
|
result = loop.run_until_complete( |
|
|
self.train(model, seed_prompt, dataset, reflection_lm, max_metric_calls, **kwargs) |
|
|
) |
|
|
return result |
|
|
finally: |
|
|
loop.close() |
|
|
|
|
|
|
|
|
|
|
|
def optimize_prompt(
    model: Union[str, ModelConfig],
    seed_prompt: str,
    dataset: Any,
    reflection_model: Optional[Union[str, ModelConfig]] = None,
    **kwargs
) -> OptimizedResult:
    """One-shot convenience wrapper: build a config and optimizer, then run training.

    Args:
        model: Target model configuration (name or ModelConfig).
        seed_prompt: Initial prompt to optimize.
        dataset: Training data in any supported format.
        reflection_model: Model used for reflection; defaults to ``model``.
        **kwargs: Extra settings; ``max_iterations``, ``max_metric_calls`` and
            ``batch_size`` are read for the config, and everything is also
            forwarded to ``train``.

    Returns:
        OptimizedResult: Optimization result.
    """
    chosen_reflection = model if reflection_model is None else reflection_model

    config = OptimizationConfig(
        model=model,
        reflection_model=chosen_reflection,
        max_iterations=kwargs.get('max_iterations', 10),
        max_metric_calls=kwargs.get('max_metric_calls', 50),
        batch_size=kwargs.get('batch_size', 4),
    )

    optimizer = GepaOptimizer(config=config)
    # NOTE(review): optimize_sync invokes train(model, seed_prompt, dataset,
    # reflection_lm, max_metric_calls, ...) while this call passes only
    # (seed_prompt, dataset, ...) — confirm train() accepts this form.
    return asyncio.run(optimizer.train(seed_prompt, dataset, **kwargs))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|