Spaces:
Sleeping
Sleeping
| """ | |
| Custom GEPA Adapter for the GEPA Universal Prompt Optimizer | |
| """ | |
| import json | |
| import logging | |
| import re | |
| from typing import Any, Dict, List, Optional | |
| # Import ModelConfig | |
| from ..models import ModelConfig | |
| from gepa.core.adapter import GEPAAdapter, EvaluationBatch | |
| from ..llms.vision_llm import VisionLLMClient | |
| from ..evaluation.ui_evaluator import UITreeEvaluator | |
| from .base_adapter import BaseGepaAdapter | |
| logger = logging.getLogger(__name__) | |
class CustomGepaAdapter(BaseGepaAdapter):
    """
    Custom adapter for the GEPA Universal Prompt Optimizer.

    Pairs a VisionLLMClient (generates UI-tree JSON from a system prompt,
    input text and an optional base64 image) with a UITreeEvaluator (scores
    the generated tree against a ground-truth tree), and implements the
    evaluate() / make_reflective_dataset() hooks that the GEPA optimizer
    drives during prompt optimization.
    """
| def __init__(self, model_config: 'ModelConfig', metric_weights: Optional[Dict[str, float]] = None): | |
| """Initialize the custom GEPA adapter with model configuration.""" | |
| # Convert string model to ModelConfig if needed | |
| if not isinstance(model_config, ModelConfig): | |
| model_config = ModelConfig( | |
| provider='openai', | |
| model_name=str(model_config), | |
| api_key=None | |
| ) | |
| # Initialize components | |
| llm_client = VisionLLMClient( | |
| provider=model_config.provider, | |
| model_name=model_config.model_name, | |
| api_key=model_config.api_key, | |
| base_url=model_config.base_url, | |
| temperature=model_config.temperature, | |
| max_tokens=model_config.max_tokens, | |
| top_p=model_config.top_p, | |
| frequency_penalty=model_config.frequency_penalty, | |
| presence_penalty=model_config.presence_penalty | |
| ) | |
| evaluator = UITreeEvaluator(metric_weights=metric_weights) | |
| # Initialize parent class | |
| super().__init__(llm_client, evaluator) | |
| # Track candidates for logging | |
| self._last_candidate = None | |
| self._evaluation_count = 0 | |
| self.logger.info(f"🚀 Initialized UI Tree adapter with {model_config.provider}/{model_config.model_name}") | |
| def _parse_json_safely(self, json_str: str) -> Dict[str, Any]: | |
| """Safely parse JSON string to dictionary with enhanced parsing and repair.""" | |
| if not json_str or not isinstance(json_str, str): | |
| return {} | |
| # Try direct parsing first | |
| try: | |
| return json.loads(json_str) | |
| except json.JSONDecodeError: | |
| pass | |
| # Try to extract JSON from markdown code blocks | |
| json_match = re.search(r'```(?:json)?\s*(\{.*?\})\s*```', json_str, re.DOTALL) | |
| if json_match: | |
| try: | |
| return json.loads(json_match.group(1)) | |
| except json.JSONDecodeError: | |
| pass | |
| # Try to find JSON object in the string | |
| json_match = re.search(r'\{.*\}', json_str, re.DOTALL) | |
| if json_match: | |
| try: | |
| return json.loads(json_match.group(0)) | |
| except json.JSONDecodeError: | |
| pass | |
| # Try repair and parse | |
| repaired_json = self._repair_json(json_str) | |
| if repaired_json: | |
| try: | |
| return json.loads(repaired_json) | |
| except json.JSONDecodeError: | |
| pass | |
| self.logger.warning(f"Failed to parse JSON: {json_str[:100]}...") | |
| return {} | |
| def _repair_json(self, json_str: str) -> str: | |
| """Attempt to repair common JSON issues.""" | |
| try: | |
| # Remove markdown formatting | |
| json_str = re.sub(r'```(?:json)?\s*', '', json_str) | |
| json_str = re.sub(r'```\s*$', '', json_str) | |
| # Remove extra text before/after JSON | |
| json_match = re.search(r'\{.*\}', json_str, re.DOTALL) | |
| if json_match: | |
| json_str = json_match.group(0) | |
| # Fix common issues | |
| json_str = re.sub(r',\s*}', '}', json_str) # Remove trailing commas | |
| json_str = re.sub(r',\s*]', ']', json_str) # Remove trailing commas in arrays | |
| json_str = re.sub(r'([{,]\s*)(\w+):', r'\1"\2":', json_str) # Quote unquoted keys | |
| return json_str | |
| except Exception as e: | |
| self.logger.warning(f"🔧 JSON repair failed: {e}") | |
| return "" | |
    def evaluate(
        self,
        batch: List[Dict[str, Any]],
        candidate: Dict[str, str],
        capture_traces: bool = False,
    ) -> EvaluationBatch:
        """Evaluate the candidate's system prompt on a batch of samples.

        For each sample, the vision LLM generates a UI-tree JSON from the
        sample's input text/image, the output and ground truth are parsed,
        and a composite score is computed by the evaluator (or a small floor
        score is assigned when parsing fails).

        Args:
            batch: Samples with 'input' (text), 'image' (base64 string) and
                'output' (ground-truth JSON string) keys.
            candidate: Candidate components; only 'system_prompt' is used.
            capture_traces: When True, record a per-sample trajectory dict
                for later use by make_reflective_dataset().

        Returns:
            EvaluationBatch with raw LLM output strings, per-sample scores,
            and trajectories (None when capture_traces is False).
        """
        outputs = []
        scores = []
        trajectories = [] if capture_traces else None
        system_prompt = candidate.get('system_prompt', '')
        # Log the candidate once per distinct prompt (GEPA may re-evaluate
        # the same candidate across several batches).
        if self._last_candidate != system_prompt:
            self._evaluation_count += 1
            self.log_proposed_candidate(candidate, self._evaluation_count)
            self._last_candidate = system_prompt
        self.logger.info(f"📊 Evaluating {len(batch)} samples with prompt: '{system_prompt[:50]}...'")
        for i, item in enumerate(batch):
            input_text = item.get('input', '')
            image_base64 = item.get('image', '')
            ground_truth_json = item.get('output', '')
            # Call the LLM client
            llm_response = self.llm_client.generate(system_prompt, input_text, image_base64=image_base64)
            # Normalize the response to a string: the client may return a
            # dict carrying a "content" key, or a bare value.
            if isinstance(llm_response, dict):
                llm_output_json_str = llm_response.get("content", "")
                if not llm_output_json_str:
                    # Empty content: fall back to the whole response repr.
                    llm_output_json_str = str(llm_response)
            else:
                llm_output_json_str = str(llm_response) if llm_response else ""
            # 🔍 DEBUG: Log essential info only (removed verbose JSON content)
            self.logger.debug(f"🔍 Sample {i+1} - LLM Response Type: {type(llm_response)}")
            self.logger.debug(f"🔍 Sample {i+1} - Response Length: {len(llm_output_json_str)} chars")
            outputs.append(llm_output_json_str)
            # Parse JSON strings to dictionaries for evaluation
            llm_output_dict = self._parse_json_safely(llm_output_json_str)
            ground_truth_dict = self._parse_json_safely(ground_truth_json)
            # Default (all-zero) metric breakdown, overwritten below.
            evaluation_results = {
                "composite_score": 0.0,
                "element_completeness": 0.0,
                "element_type_accuracy": 0.0,
                "text_content_accuracy": 0.0,
                "hierarchy_accuracy": 0.0,
                "style_accuracy": 0.0
            }
            # Calculate composite score and evaluation results
            if not llm_output_dict and not ground_truth_dict:
                # Both sides unparseable: assign a small floor score.
                # NOTE(review): both-empty (0.1) scores above one-empty (0.05);
                # presumably so an empty ground truth isn't over-penalized —
                # confirm this ordering is intentional.
                composite_score = 0.1
                evaluation_results = {k: 0.1 for k in evaluation_results.keys()}
                self.logger.warning(f"⚠️ Sample {i+1}: Empty results - using default score: {composite_score}")
            elif not llm_output_dict or not ground_truth_dict:
                # Exactly one side unparseable: lowest non-zero score.
                composite_score = 0.05
                evaluation_results = {k: 0.05 for k in evaluation_results.keys()}
                self.logger.warning(f"⚠️ Sample {i+1}: Incomplete results - using low score: {composite_score}")
            else:
                # Calculate score using evaluator with parsed dictionaries
                evaluation_results = self.evaluator.evaluate(llm_output_dict, ground_truth_dict)
                composite_score = evaluation_results["composite_score"]
                # Clean, readable logging (removed verbose JSON dumps)
                llm_children = len(llm_output_dict.get('children', []))
                gt_children = len(ground_truth_dict.get('children', []))
                if composite_score < 0.1:
                    self.logger.warning(f"⚠️ Sample {i+1}: Low score {composite_score:.4f} - LLM: {llm_children} elements, GT: {gt_children} elements")
                    self.logger.debug(f" Score breakdown: {evaluation_results}")
                else:
                    self.logger.info(f"✅ Sample {i+1}: Score {composite_score:.4f} - LLM: {llm_children} elements, GT: {gt_children} elements")
            scores.append(composite_score)
            if capture_traces:
                trajectories.append({
                    'input_text': input_text,
                    'image_base64': image_base64,
                    'ground_truth_json': ground_truth_json,
                    'llm_output_json': llm_output_json_str,
                    'evaluation_results': evaluation_results
                })
        avg_score = sum(scores) / len(scores) if scores else 0.0
        # Update performance tracking (handled by parent class)
        if avg_score > self._best_score:
            self._best_score = avg_score
            self._best_candidate = candidate.copy()
            self.logger.info(f"🎯 New best candidate found with score: {avg_score:.4f}")
        self.logger.info(f"📈 Batch evaluation complete - Average score: {avg_score:.4f}")
        return EvaluationBatch(outputs=outputs, scores=scores, trajectories=trajectories)
| def make_reflective_dataset( | |
| self, | |
| candidate: Dict[str, str], | |
| eval_batch: EvaluationBatch, | |
| components_to_update: List[str], | |
| ) -> Dict[str, List[Dict[str, Any]]]: | |
| """Create a reflective dataset from the evaluation results.""" | |
| reflective_dataset = {} | |
| system_prompt = candidate.get('system_prompt', '') | |
| # 🎯 NEW: Log the proposed new prompt being evaluated | |
| self.logger.info(f"📝 Creating reflection dataset for prompt: '{system_prompt[:100]}...'") | |
| # Pretty print reflection dataset creation | |
| self._log_reflection_dataset_creation(candidate, eval_batch, components_to_update) | |
| for component in components_to_update: | |
| reflective_dataset[component] = [] | |
| for i, trace in enumerate(eval_batch.trajectories): | |
| feedback = self._generate_feedback(trace['evaluation_results']) | |
| reflective_dataset[component].append({ | |
| "current_prompt": system_prompt, | |
| "input_text": trace['input_text'], | |
| "image_base64": trace['image_base64'], | |
| "generated_json": trace['llm_output_json'], | |
| "ground_truth_json": trace['ground_truth_json'], | |
| "score": trace['evaluation_results']["composite_score"], | |
| "feedback": feedback, | |
| "detailed_scores": trace['evaluation_results'] | |
| }) | |
| # 🎯 NEW: Log reflection dataset summary | |
| total_samples = sum(len(data) for data in reflective_dataset.values()) | |
| avg_score = sum(trace['score'] for data in reflective_dataset.values() for trace in data) / total_samples if total_samples > 0 else 0.0 | |
| self.logger.info(f"📝 Reflection dataset created - {total_samples} samples, avg score: {avg_score:.4f}") | |
| return reflective_dataset | |
| def _generate_feedback(self, evaluation_results: Dict[str, float]) -> str: | |
| """Generate textual feedback based on evaluation results.""" | |
| composite_score = evaluation_results.get("composite_score", 0.0) | |
| feedback_parts = [] | |
| # Overall quality assessment | |
| if composite_score >= 0.8: | |
| feedback_parts.append("The overall quality is good.") | |
| elif composite_score >= 0.5: | |
| feedback_parts.append("The overall quality is moderate.") | |
| else: | |
| feedback_parts.append("The overall quality is low. Focus on fundamental accuracy.") | |
| # Specific metric feedback | |
| if evaluation_results.get("element_completeness", 0.0) < 0.7: | |
| feedback_parts.append("Element completeness is low. Ensure all UI elements are captured.") | |
| if evaluation_results.get("element_type_accuracy", 0.0) < 0.7: | |
| feedback_parts.append("Element type accuracy is low. Verify correct UI element identification (Button, Text, Image, etc.).") | |
| if evaluation_results.get("text_content_accuracy", 0.0) < 0.7: | |
| feedback_parts.append("Text content accuracy is low. Improve text extraction fidelity.") | |
| if evaluation_results.get("hierarchy_accuracy", 0.0) < 0.7: | |
| feedback_parts.append("Hierarchy accuracy is low. Ensure correct parent-child relationships.") | |
| if evaluation_results.get("style_accuracy", 0.0) < 0.7: | |
| feedback_parts.append("Style accuracy is low. Capture more styling properties (colors, sizes, positioning).") | |
| return " ".join(feedback_parts) | |
| def get_best_candidate(self) -> Optional[Dict[str, str]]: | |
| """Get the best candidate found so far.""" | |
| return self._best_candidate | |
| def get_best_score(self) -> float: | |
| """Get the best score found so far.""" | |
| return self._best_score | |
| def log_proposed_candidate(self, candidate: Dict[str, str], iteration: int = 0): | |
| """ | |
| Log the new proposed candidate prompt. | |
| Args: | |
| candidate: The new candidate prompt from GEPA | |
| iteration: Current optimization iteration | |
| """ | |
| system_prompt = candidate.get('system_prompt', '') | |
| logger.info("="*80) | |
| logger.info(f"NEW PROPOSED CANDIDATE (Iteration {iteration})") | |
| logger.info("="*80) | |
| logger.info(f"PROPOSED PROMPT:") | |
| logger.info("-" * 40) | |
| logger.debug(f'"{system_prompt}"') | |
| logger.info("-" * 40) | |
| logger.info(f"Prompt Length: {len(system_prompt)} characters") | |
| logger.info(f"Word Count: {len(system_prompt.split())} words") | |
| logger.info("="*80) | |
    def _log_reflection_dataset_creation(self, candidate: Dict[str, str], eval_batch: EvaluationBatch,
                                         components_to_update: List[str]):
        """
        Log the reflection dataset creation process.

        Emits a banner-framed summary at INFO level (prompt stats, score
        aggregates, components) and a detailed per-sample breakdown of the
        first three trajectories at DEBUG level.

        Args:
            candidate: Current candidate being evaluated
            eval_batch: Evaluation results
            components_to_update: Components being updated
        """
        system_prompt = candidate.get('system_prompt', '')
        logger.info("="*80)
        logger.info("REFLECTION DATASET CREATION")
        logger.info("="*80)
        logger.info(f"CURRENT PROMPT BEING ANALYZED:")
        logger.info("-" * 40)
        # Full prompt text only at DEBUG level to keep INFO logs compact.
        logger.debug(f'"{system_prompt}"')
        logger.info("-" * 40)
        logger.info(f"EVALUATION SUMMARY:")
        logger.info("-" * 40)
        # Aggregate stats only when at least one score exists.
        if eval_batch.scores:
            avg_score = sum(eval_batch.scores) / len(eval_batch.scores)
            min_score = min(eval_batch.scores)
            max_score = max(eval_batch.scores)
            logger.info(f" Average Score: {avg_score:.4f}")
            logger.info(f" Min Score: {min_score:.4f}")
            logger.info(f" Max Score: {max_score:.4f}")
            logger.info(f" Total Samples: {len(eval_batch.scores)}")
        logger.info(f"COMPONENTS TO UPDATE:")
        logger.info("-" * 40)
        for i, component in enumerate(components_to_update, 1):
            logger.info(f" {i}. {component}")
        if eval_batch.trajectories:
            logger.debug(f"DETAILED ANALYSIS:")
            logger.debug("-" * 40)
            for i, trace in enumerate(eval_batch.trajectories[:3], 1):  # Show first 3 samples
                evaluation_results = trace['evaluation_results']
                composite_score = evaluation_results.get("composite_score", 0.0)
                logger.debug(f" Sample {i} (Score: {composite_score:.4f}):")
                # Show input data (truncated to 100 chars)
                input_text = trace['input_text'][:100] + "..." if len(trace['input_text']) > 100 else trace['input_text']
                logger.debug(f" Input: \"{input_text}\"")
                # Show predicted output (truncated to 100 chars)
                predicted_output = trace['llm_output_json'][:100] + "..." if len(trace['llm_output_json']) > 100 else trace['llm_output_json']
                logger.debug(f" Output: \"{predicted_output}\"")
                # Show every detailed metric except the composite itself.
                logger.debug(f" Detailed Scores:")
                for metric, score in evaluation_results.items():
                    if metric != "composite_score":
                        logger.debug(f" {metric.replace('_', ' ').title()}: {score:.4f}")
                # Show generated feedback
                feedback = self._generate_feedback(evaluation_results)
                logger.debug(f" Feedback: \"{feedback}\"")
            if len(eval_batch.trajectories) > 3:
                logger.debug(f" ... and {len(eval_batch.trajectories) - 3} more samples")
        logger.info("="*80)