""" Custom GEPA Adapter for the GEPA Universal Prompt Optimizer """ import json import logging import re from typing import Any, Dict, List, Optional # Import ModelConfig from ..models import ModelConfig from gepa.core.adapter import GEPAAdapter, EvaluationBatch from ..llms.vision_llm import VisionLLMClient from ..evaluation.ui_evaluator import UITreeEvaluator from .base_adapter import BaseGepaAdapter logger = logging.getLogger(__name__) class CustomGepaAdapter(BaseGepaAdapter): """ Custom adapter for the GEPA Universal Prompt Optimizer. """ def __init__(self, model_config: 'ModelConfig', metric_weights: Optional[Dict[str, float]] = None): """Initialize the custom GEPA adapter with model configuration.""" # Convert string model to ModelConfig if needed if not isinstance(model_config, ModelConfig): model_config = ModelConfig( provider='openai', model_name=str(model_config), api_key=None ) # Initialize components llm_client = VisionLLMClient( provider=model_config.provider, model_name=model_config.model_name, api_key=model_config.api_key, base_url=model_config.base_url, temperature=model_config.temperature, max_tokens=model_config.max_tokens, top_p=model_config.top_p, frequency_penalty=model_config.frequency_penalty, presence_penalty=model_config.presence_penalty ) evaluator = UITreeEvaluator(metric_weights=metric_weights) # Initialize parent class super().__init__(llm_client, evaluator) # Track candidates for logging self._last_candidate = None self._evaluation_count = 0 self.logger.info(f"🚀 Initialized UI Tree adapter with {model_config.provider}/{model_config.model_name}") def _parse_json_safely(self, json_str: str) -> Dict[str, Any]: """Safely parse JSON string to dictionary with enhanced parsing and repair.""" if not json_str or not isinstance(json_str, str): return {} # Try direct parsing first try: return json.loads(json_str) except json.JSONDecodeError: pass # Try to extract JSON from markdown code blocks json_match = re.search(r'```(?:json)?\s*(\{.*?\})\s*```', json_str, re.DOTALL) if json_match: try: return json.loads(json_match.group(1)) except json.JSONDecodeError: pass # Try to find JSON object in the string json_match = re.search(r'\{.*\}', json_str, re.DOTALL) if json_match: try: return json.loads(json_match.group(0)) except json.JSONDecodeError: pass # Try repair and parse repaired_json = self._repair_json(json_str) if repaired_json: try: return json.loads(repaired_json) except json.JSONDecodeError: pass self.logger.warning(f"Failed to parse JSON: {json_str[:100]}...") return {} def _repair_json(self, json_str: str) -> str: """Attempt to repair common JSON issues.""" try: # Remove markdown formatting json_str = re.sub(r'```(?:json)?\s*', '', json_str) json_str = re.sub(r'```\s*$', '', json_str) # Remove extra text before/after JSON json_match = re.search(r'\{.*\}', json_str, re.DOTALL) if json_match: json_str = json_match.group(0) # Fix common issues json_str = re.sub(r',\s*}', '}', json_str) # Remove trailing commas json_str = re.sub(r',\s*]', ']', json_str) # Remove trailing commas in arrays json_str = re.sub(r'([{,]\s*)(\w+):', r'\1"\2":', json_str) # Quote unquoted keys return json_str except Exception as e: self.logger.warning(f"🔧 JSON repair failed: {e}") return "" def evaluate( self, batch: List[Dict[str, Any]], candidate: Dict[str, str], capture_traces: bool = False, ) -> EvaluationBatch: """Evaluate the candidate on a batch of data.""" outputs = [] scores = [] trajectories = [] if capture_traces else None system_prompt = candidate.get('system_prompt', '') # Check if 
        if self._last_candidate != system_prompt:
            self._evaluation_count += 1
            self.log_proposed_candidate(candidate, self._evaluation_count)
            self._last_candidate = system_prompt

        self.logger.info(f"📊 Evaluating {len(batch)} samples with prompt: '{system_prompt[:50]}...'")

        for i, item in enumerate(batch):
            input_text = item.get('input', '')
            image_base64 = item.get('image', '')
            ground_truth_json = item.get('output', '')

            # Call the LLM client
            llm_response = self.llm_client.generate(system_prompt, input_text, image_base64=image_base64)

            # Extract the content from the response dictionary
            if isinstance(llm_response, dict):
                llm_output_json_str = llm_response.get("content", "")
                if not llm_output_json_str:
                    llm_output_json_str = str(llm_response)
            else:
                llm_output_json_str = str(llm_response) if llm_response else ""

            # Log essential response info at debug level (verbose JSON content omitted)
            self.logger.debug(f"🔍 Sample {i+1} - LLM Response Type: {type(llm_response)}")
            self.logger.debug(f"🔍 Sample {i+1} - Response Length: {len(llm_output_json_str)} chars")

            outputs.append(llm_output_json_str)

            # Parse JSON strings into dictionaries for evaluation
            llm_output_dict = self._parse_json_safely(llm_output_json_str)
            ground_truth_dict = self._parse_json_safely(ground_truth_json)

            # Initialize evaluation_results with default values
            evaluation_results = {
                "composite_score": 0.0,
                "element_completeness": 0.0,
                "element_type_accuracy": 0.0,
                "text_content_accuracy": 0.0,
                "hierarchy_accuracy": 0.0,
                "style_accuracy": 0.0
            }

            # Calculate the composite score and evaluation results
            if not llm_output_dict and not ground_truth_dict:
                composite_score = 0.1
                evaluation_results = {k: 0.1 for k in evaluation_results}
                self.logger.warning(f"⚠️ Sample {i+1}: Empty results - using default score: {composite_score}")
            elif not llm_output_dict or not ground_truth_dict:
                composite_score = 0.05
                evaluation_results = {k: 0.05 for k in evaluation_results}
                self.logger.warning(f"⚠️ Sample {i+1}: Incomplete results - using low score: {composite_score}")
            else:
                # Score the parsed dictionaries with the evaluator
                evaluation_results = self.evaluator.evaluate(llm_output_dict, ground_truth_dict)
                composite_score = evaluation_results["composite_score"]

                # Clean, readable logging without verbose JSON dumps
                llm_children = len(llm_output_dict.get('children', []))
                gt_children = len(ground_truth_dict.get('children', []))
                if composite_score < 0.1:
                    self.logger.warning(f"⚠️ Sample {i+1}: Low score {composite_score:.4f} - LLM: {llm_children} elements, GT: {gt_children} elements")
                    self.logger.debug(f"   Score breakdown: {evaluation_results}")
                else:
                    self.logger.info(f"✅ Sample {i+1}: Score {composite_score:.4f} - LLM: {llm_children} elements, GT: {gt_children} elements")

            scores.append(composite_score)

            if capture_traces:
                trajectories.append({
                    'input_text': input_text,
                    'image_base64': image_base64,
                    'ground_truth_json': ground_truth_json,
                    'llm_output_json': llm_output_json_str,
                    'evaluation_results': evaluation_results
                })

        avg_score = sum(scores) / len(scores) if scores else 0.0

        # Update performance tracking (attributes defined by the parent class)
        if avg_score > self._best_score:
            self._best_score = avg_score
            self._best_candidate = candidate.copy()
            self.logger.info(f"🎯 New best candidate found with score: {avg_score:.4f}")

        self.logger.info(f"📈 Batch evaluation complete - Average score: {avg_score:.4f}")

        return EvaluationBatch(outputs=outputs, scores=scores, trajectories=trajectories)

    def make_reflective_dataset(
        self,
        candidate: Dict[str, str],
        eval_batch: EvaluationBatch,
        components_to_update: List[str],
    ) -> Dict[str, List[Dict[str, Any]]]:
        """Create a reflective dataset from the evaluation results."""
        reflective_dataset = {}
        system_prompt = candidate.get('system_prompt', '')

        # Log the prompt whose traces are being turned into reflection data
        self.logger.info(f"📝 Creating reflection dataset for prompt: '{system_prompt[:100]}...'")

        # Pretty-print the reflection dataset creation
        self._log_reflection_dataset_creation(candidate, eval_batch, components_to_update)

        for component in components_to_update:
            reflective_dataset[component] = []
            for i, trace in enumerate(eval_batch.trajectories):
                feedback = self._generate_feedback(trace['evaluation_results'])
                reflective_dataset[component].append({
                    "current_prompt": system_prompt,
                    "input_text": trace['input_text'],
                    "image_base64": trace['image_base64'],
                    "generated_json": trace['llm_output_json'],
                    "ground_truth_json": trace['ground_truth_json'],
                    "score": trace['evaluation_results']["composite_score"],
                    "feedback": feedback,
                    "detailed_scores": trace['evaluation_results']
                })

        # Log a summary of the reflection dataset
        total_samples = sum(len(data) for data in reflective_dataset.values())
        avg_score = (
            sum(trace['score'] for data in reflective_dataset.values() for trace in data) / total_samples
            if total_samples > 0
            else 0.0
        )
        self.logger.info(f"📝 Reflection dataset created - {total_samples} samples, avg score: {avg_score:.4f}")

        return reflective_dataset

    def _generate_feedback(self, evaluation_results: Dict[str, float]) -> str:
        """Generate textual feedback based on evaluation results."""
        composite_score = evaluation_results.get("composite_score", 0.0)
        feedback_parts = []

        # Overall quality assessment
        if composite_score >= 0.8:
            feedback_parts.append("The overall quality is good.")
        elif composite_score >= 0.5:
            feedback_parts.append("The overall quality is moderate.")
        else:
            feedback_parts.append("The overall quality is low. Focus on fundamental accuracy.")

        # Per-metric feedback
        if evaluation_results.get("element_completeness", 0.0) < 0.7:
            feedback_parts.append("Element completeness is low. Ensure all UI elements are captured.")
        if evaluation_results.get("element_type_accuracy", 0.0) < 0.7:
            feedback_parts.append("Element type accuracy is low. Verify correct UI element identification (Button, Text, Image, etc.).")
        if evaluation_results.get("text_content_accuracy", 0.0) < 0.7:
            feedback_parts.append("Text content accuracy is low. Improve text extraction fidelity.")
        if evaluation_results.get("hierarchy_accuracy", 0.0) < 0.7:
            feedback_parts.append("Hierarchy accuracy is low. Ensure correct parent-child relationships.")
        if evaluation_results.get("style_accuracy", 0.0) < 0.7:
            feedback_parts.append("Style accuracy is low. Capture more styling properties (colors, sizes, positioning).")

        return " ".join(feedback_parts)

    def get_best_candidate(self) -> Optional[Dict[str, str]]:
        """Get the best candidate found so far."""
        return self._best_candidate

    def get_best_score(self) -> float:
        """Get the best score found so far."""
        return self._best_score

    def log_proposed_candidate(self, candidate: Dict[str, str], iteration: int = 0):
        """
        Log a newly proposed candidate prompt.

        Args:
            candidate: The new candidate prompt proposed by GEPA.
            iteration: Current optimization iteration.
        """
        system_prompt = candidate.get('system_prompt', '')

        logger.info("=" * 80)
        logger.info(f"NEW PROPOSED CANDIDATE (Iteration {iteration})")
        logger.info("=" * 80)
        logger.info("PROPOSED PROMPT:")
        logger.info("-" * 40)
        logger.debug(f'"{system_prompt}"')
        logger.info("-" * 40)
        logger.info(f"Prompt Length: {len(system_prompt)} characters")
        logger.info(f"Word Count: {len(system_prompt.split())} words")
        logger.info("=" * 80)

    def _log_reflection_dataset_creation(
        self,
        candidate: Dict[str, str],
        eval_batch: EvaluationBatch,
        components_to_update: List[str],
    ):
        """
        Log the reflection dataset creation process.

        Args:
            candidate: Current candidate being evaluated.
            eval_batch: Evaluation results.
            components_to_update: Components being updated.
        """
        system_prompt = candidate.get('system_prompt', '')

        logger.info("=" * 80)
        logger.info("REFLECTION DATASET CREATION")
        logger.info("=" * 80)
        logger.info("CURRENT PROMPT BEING ANALYZED:")
        logger.info("-" * 40)
        logger.debug(f'"{system_prompt}"')
        logger.info("-" * 40)

        logger.info("EVALUATION SUMMARY:")
        logger.info("-" * 40)
        if eval_batch.scores:
            avg_score = sum(eval_batch.scores) / len(eval_batch.scores)
            min_score = min(eval_batch.scores)
            max_score = max(eval_batch.scores)
            logger.info(f"  Average Score: {avg_score:.4f}")
            logger.info(f"  Min Score: {min_score:.4f}")
            logger.info(f"  Max Score: {max_score:.4f}")
            logger.info(f"  Total Samples: {len(eval_batch.scores)}")

        logger.info("COMPONENTS TO UPDATE:")
        logger.info("-" * 40)
        for i, component in enumerate(components_to_update, 1):
            logger.info(f"  {i}. {component}")

        if eval_batch.trajectories:
            logger.debug("DETAILED ANALYSIS:")
            logger.debug("-" * 40)
            # Show the first three samples in detail
            for i, trace in enumerate(eval_batch.trajectories[:3], 1):
                evaluation_results = trace['evaluation_results']
                composite_score = evaluation_results.get("composite_score", 0.0)
                logger.debug(f"  Sample {i} (Score: {composite_score:.4f}):")

                # Show the input data (truncated)
                input_text = trace['input_text'][:100] + "..." if len(trace['input_text']) > 100 else trace['input_text']
                logger.debug(f'    Input: "{input_text}"')

                # Show the predicted output (truncated)
                predicted_output = trace['llm_output_json'][:100] + "..." if len(trace['llm_output_json']) > 100 else trace['llm_output_json']
                logger.debug(f'    Output: "{predicted_output}"')

                # Show detailed scores
                logger.debug("    Detailed Scores:")
                for metric, score in evaluation_results.items():
                    if metric != "composite_score":
                        logger.debug(f"      {metric.replace('_', ' ').title()}: {score:.4f}")

                # Show the generated feedback
                feedback = self._generate_feedback(evaluation_results)
                logger.debug(f'    Feedback: "{feedback}"')

            if len(eval_batch.trajectories) > 3:
                logger.debug(f"  ... and {len(eval_batch.trajectories) - 3} more samples")

        logger.info("=" * 80)
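

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module API): evaluating a
# seed prompt on a tiny batch. The dataset field names ('input', 'image',
# 'output') follow what evaluate() expects above; the model name, image
# payload, and ground-truth JSON below are placeholder values.
#
#   adapter = CustomGepaAdapter("gpt-4o-mini")  # a plain string is wrapped in a ModelConfig
#   batch = [{
#       "input": "Login screen",
#       "image": "<base64-encoded screenshot>",
#       "output": '{"type": "Screen", "children": []}',
#   }]
#   candidate = {"system_prompt": "Return the UI as a JSON tree."}
#   result = adapter.evaluate(batch, candidate, capture_traces=True)
#   print(result.scores)
#
# In a full GEPA run, the same instance is passed as the adapter argument to
# gepa's optimizer entry point; the exact call signature depends on the
# installed gepa version.
# ---------------------------------------------------------------------------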