Spaces:
Sleeping
Sleeping
File size: 17,135 Bytes
cacd4d0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 |
"""
Custom GEPA Adapter for the GEPA Universal Prompt Optimizer
"""
import json
import logging
import re
from typing import Any, Dict, List, Optional
# Import ModelConfig
from ..models import ModelConfig
from gepa.core.adapter import GEPAAdapter, EvaluationBatch
from ..llms.vision_llm import VisionLLMClient
from ..evaluation.ui_evaluator import UITreeEvaluator
from .base_adapter import BaseGepaAdapter
logger = logging.getLogger(__name__)
class CustomGepaAdapter(BaseGepaAdapter):
    """
    Custom adapter for the GEPA Universal Prompt Optimizer.

    Builds a ``VisionLLMClient`` and a ``UITreeEvaluator`` from the given
    configuration and hands them to ``BaseGepaAdapter``. ``evaluate`` runs a
    candidate system prompt over a batch and scores every JSON answer against
    its ground truth; ``make_reflective_dataset`` converts the captured traces
    into per-component feedback records.

    NOTE(review): ``self.logger``, ``self.llm_client``, ``self.evaluator``,
    ``self._best_score`` and ``self._best_candidate`` are used but never
    assigned in this class; presumably ``BaseGepaAdapter.__init__`` creates
    them -- confirm against the base class.
    NOTE(review): several log messages start with characters that look like
    mis-decoded emoji. They are reproduced unchanged because the intended
    glyphs cannot all be recovered from this file.
    """

    # Keys of the per-sample score dictionary, composite score first.
    _METRIC_KEYS = (
        "composite_score",
        "element_completeness",
        "element_type_accuracy",
        "text_content_accuracy",
        "hierarchy_accuracy",
        "style_accuracy",
    )

    # A metric scoring below this value contributes its hint to the feedback.
    _LOW_METRIC_THRESHOLD = 0.7

    # (metric key, hint emitted when the metric is below the threshold).
    _METRIC_FEEDBACK = (
        ("element_completeness",
         "Element completeness is low. Ensure all UI elements are captured."),
        ("element_type_accuracy",
         "Element type accuracy is low. Verify correct UI element identification (Button, Text, Image, etc.)."),
        ("text_content_accuracy",
         "Text content accuracy is low. Improve text extraction fidelity."),
        ("hierarchy_accuracy",
         "Hierarchy accuracy is low. Ensure correct parent-child relationships."),
        ("style_accuracy",
         "Style accuracy is low. Capture more styling properties (colors, sizes, positioning)."),
    )

    # Limits used by the debug section of the reflection log.
    _MAX_DETAILED_TRACES = 3
    _PREVIEW_CHARS = 100

    def __init__(self, model_config: 'ModelConfig', metric_weights: Optional[Dict[str, float]] = None):
        """Initialize the custom GEPA adapter with model configuration.

        Args:
            model_config: A ``ModelConfig``. Any other value is converted with
                ``str()`` and used as the model name of an ``'openai'``
                configuration without an API key.
            metric_weights: Optional weights forwarded unchanged to
                ``UITreeEvaluator``.
        """
        # Convert string model to ModelConfig if needed
        if not isinstance(model_config, ModelConfig):
            model_config = ModelConfig(
                provider='openai',
                model_name=str(model_config),
                api_key=None,
            )

        # Initialize components
        llm_client = VisionLLMClient(
            provider=model_config.provider,
            model_name=model_config.model_name,
            api_key=model_config.api_key,
            base_url=model_config.base_url,
            temperature=model_config.temperature,
            max_tokens=model_config.max_tokens,
            top_p=model_config.top_p,
            frequency_penalty=model_config.frequency_penalty,
            presence_penalty=model_config.presence_penalty,
        )
        evaluator = UITreeEvaluator(metric_weights=metric_weights)

        # Initialize parent class
        super().__init__(llm_client, evaluator)

        # Track candidates for logging: the last prompt seen by evaluate()
        # and how many distinct consecutive prompts have been evaluated.
        self._last_candidate = None
        self._evaluation_count = 0
        self.logger.info(f"π Initialized UI Tree adapter with {model_config.provider}/{model_config.model_name}")

    @staticmethod
    def _loads_dict(text: str) -> Optional[Dict[str, Any]]:
        """Parse *text* as JSON and return it only if it is a JSON object.

        Returns ``None`` both for invalid JSON and for valid JSON whose top
        level is not an object (list, string, number, ``null``), so callers
        can move on to their next fallback.
        """
        try:
            parsed = json.loads(text)
        except json.JSONDecodeError:
            return None
        return parsed if isinstance(parsed, dict) else None

    def _parse_json_safely(self, json_str: str) -> Dict[str, Any]:
        """Safely parse JSON string to dictionary with enhanced parsing and repair.

        Attempts, in order: the whole string, the first ``{...}`` inside a
        markdown code fence, the span from the first ``{`` to the last ``}``,
        and finally the output of ``_repair_json``.

        Args:
            json_str: Text expected to contain a JSON object. An already
                parsed ``dict`` is returned as-is.

        Returns:
            The parsed dictionary, or ``{}`` when the input is empty, is not
            a string/dict, or no attempt yields a JSON object. A warning is
            logged when every attempt on a non-empty string fails.
        """
        # Datasets may store the ground truth already decoded.
        if isinstance(json_str, dict):
            return json_str
        if not json_str or not isinstance(json_str, str):
            return {}

        # Try direct parsing first
        parsed = self._loads_dict(json_str)
        if parsed is not None:
            return parsed

        # Try to extract JSON from markdown code blocks
        fenced = re.search(r'```(?:json)?\s*(\{.*?\})\s*```', json_str, re.DOTALL)
        if fenced:
            parsed = self._loads_dict(fenced.group(1))
            if parsed is not None:
                return parsed

        # Try to find JSON object in the string
        braced = re.search(r'\{.*\}', json_str, re.DOTALL)
        if braced:
            parsed = self._loads_dict(braced.group(0))
            if parsed is not None:
                return parsed

        # Try repair and parse
        repaired_json = self._repair_json(json_str)
        if repaired_json:
            parsed = self._loads_dict(repaired_json)
            if parsed is not None:
                return parsed

        self.logger.warning(f"Failed to parse JSON: {json_str[:100]}...")
        return {}

    def _repair_json(self, json_str: str) -> str:
        """Attempt to repair common JSON issues.

        Strips markdown fences, keeps only the span from the first ``{`` to
        the last ``}``, removes trailing commas and quotes bare object keys.
        The substitutions are purely textual, so they also apply inside
        string values.

        Args:
            json_str: Raw text that failed to parse as JSON.

        Returns:
            The rewritten text (not guaranteed to be valid JSON), or ``""``
            when the input cannot be processed.
        """
        try:
            # Remove markdown formatting
            json_str = re.sub(r'```(?:json)?\s*', '', json_str)
            json_str = re.sub(r'```\s*$', '', json_str)

            # Remove extra text before/after JSON
            braced = re.search(r'\{.*\}', json_str, re.DOTALL)
            if braced:
                json_str = braced.group(0)

            # Fix common issues
            json_str = re.sub(r',\s*}', '}', json_str)  # Remove trailing commas
            json_str = re.sub(r',\s*]', ']', json_str)  # Remove trailing commas in arrays
            json_str = re.sub(r'([{,]\s*)(\w+):', r'\1"\2":', json_str)  # Quote unquoted keys
            return json_str
        except (TypeError, re.error) as e:
            self.logger.warning(f"π§ JSON repair failed: {e}")
            return ""

    @staticmethod
    def _extract_response_text(llm_response: Any) -> str:
        """Return the textual payload of an LLM client response.

        For a dict the ``"content"`` entry is used; when it is missing or
        empty the whole dict is stringified instead. Any other falsy
        response becomes ``""`` and anything else goes through ``str()``.
        """
        if isinstance(llm_response, dict):
            content = llm_response.get("content", "")
            # str() keeps the result a string even for unexpected content.
            return str(content) if content else str(llm_response)
        return str(llm_response) if llm_response else ""

    @staticmethod
    def _count_children(tree: Dict[str, Any]) -> int:
        """Return ``len(tree['children'])``, or 0 if absent or unsized."""
        try:
            return len(tree.get('children', []))
        except TypeError:
            # e.g. "children": null -- this number is only used for logging.
            return 0

    def _score_sample(self, sample_no: int, llm_output_json_str: str, ground_truth_json: Any) -> tuple:
        """Score one model answer against its ground truth.

        Args:
            sample_no: 1-based position of the sample, used in log messages.
            llm_output_json_str: Raw text returned by the model.
            ground_truth_json: Expected JSON, as text or as a parsed dict.

        Returns:
            ``(composite_score, evaluation_results)``. When both sides fail
            to parse every metric is 0.1; when exactly one side fails every
            metric is 0.05; otherwise both values come from
            ``self.evaluator.evaluate``.
        """
        # Parse JSON strings to dictionaries for evaluation
        llm_output_dict = self._parse_json_safely(llm_output_json_str)
        ground_truth_dict = self._parse_json_safely(ground_truth_json)

        if not llm_output_dict and not ground_truth_dict:
            composite_score = 0.1
            evaluation_results = dict.fromkeys(self._METRIC_KEYS, 0.1)
            self.logger.warning(f"β οΈ Sample {sample_no}: Empty results - using default score: {composite_score}")
            return composite_score, evaluation_results

        if not llm_output_dict or not ground_truth_dict:
            composite_score = 0.05
            evaluation_results = dict.fromkeys(self._METRIC_KEYS, 0.05)
            self.logger.warning(f"β οΈ Sample {sample_no}: Incomplete results - using low score: {composite_score}")
            return composite_score, evaluation_results

        # Calculate score using evaluator with parsed dictionaries
        evaluation_results = self.evaluator.evaluate(llm_output_dict, ground_truth_dict)
        composite_score = evaluation_results["composite_score"]

        # Clean, readable logging (no JSON dumps)
        llm_children = self._count_children(llm_output_dict)
        gt_children = self._count_children(ground_truth_dict)
        if composite_score < 0.1:
            self.logger.warning(f"β οΈ Sample {sample_no}: Low score {composite_score:.4f} - LLM: {llm_children} elements, GT: {gt_children} elements")
            self.logger.debug(f" Score breakdown: {evaluation_results}")
        else:
            self.logger.info(f"✅ Sample {sample_no}: Score {composite_score:.4f} - LLM: {llm_children} elements, GT: {gt_children} elements")
        return composite_score, evaluation_results

    def evaluate(
        self,
        batch: List[Dict[str, Any]],
        candidate: Dict[str, str],
        capture_traces: bool = False,
    ) -> EvaluationBatch:
        """Evaluate the candidate on a batch of data.

        Args:
            batch: Samples; each is read through the keys ``'input'``,
                ``'image'`` and ``'output'`` (each defaulting to ``''``).
            candidate: Mapping whose ``'system_prompt'`` entry is the prompt
                under test.
            capture_traces: When true, one trace dict per sample is recorded
                for ``make_reflective_dataset``.

        Returns:
            An ``EvaluationBatch`` with the raw model outputs, the per-sample
            composite scores, and the traces (``None`` unless
            ``capture_traces`` is set).
        """
        outputs = []
        scores = []
        trajectories = [] if capture_traces else None
        system_prompt = candidate.get('system_prompt', '')

        # Check if this is a new candidate (different from last one)
        if self._last_candidate != system_prompt:
            self._evaluation_count += 1
            self.log_proposed_candidate(candidate, self._evaluation_count)
            self._last_candidate = system_prompt

        self.logger.info(f"π Evaluating {len(batch)} samples with prompt: '{system_prompt[:50]}...'")

        for sample_no, item in enumerate(batch, start=1):
            input_text = item.get('input', '')
            image_base64 = item.get('image', '')
            ground_truth_json = item.get('output', '')

            # Call the LLM client
            llm_response = self.llm_client.generate(system_prompt, input_text, image_base64=image_base64)
            llm_output_json_str = self._extract_response_text(llm_response)

            # Log essential info only (no verbose JSON content)
            self.logger.debug(f"π Sample {sample_no} - LLM Response Type: {type(llm_response)}")
            self.logger.debug(f"π Sample {sample_no} - Response Length: {len(llm_output_json_str)} chars")
            outputs.append(llm_output_json_str)

            composite_score, evaluation_results = self._score_sample(
                sample_no, llm_output_json_str, ground_truth_json
            )
            scores.append(composite_score)

            if trajectories is not None:
                trajectories.append({
                    'input_text': input_text,
                    'image_base64': image_base64,
                    'ground_truth_json': ground_truth_json,
                    'llm_output_json': llm_output_json_str,
                    'evaluation_results': evaluation_results,
                })

        avg_score = sum(scores) / len(scores) if scores else 0.0

        # Update performance tracking (attributes presumably owned by the parent class)
        if avg_score > self._best_score:
            self._best_score = avg_score
            self._best_candidate = candidate.copy()
            self.logger.info(f"π― New best candidate found with score: {avg_score:.4f}")

        self.logger.info(f"π Batch evaluation complete - Average score: {avg_score:.4f}")
        return EvaluationBatch(outputs=outputs, scores=scores, trajectories=trajectories)

    def _trace_to_record(self, system_prompt: str, trace: Dict[str, Any]) -> Dict[str, Any]:
        """Convert one trace recorded by ``evaluate`` into a reflection record."""
        evaluation_results = trace['evaluation_results']
        return {
            "current_prompt": system_prompt,
            "input_text": trace['input_text'],
            "image_base64": trace['image_base64'],
            "generated_json": trace['llm_output_json'],
            "ground_truth_json": trace['ground_truth_json'],
            "score": evaluation_results["composite_score"],
            "feedback": self._generate_feedback(evaluation_results),
            "detailed_scores": evaluation_results,
        }

    def make_reflective_dataset(
        self,
        candidate: Dict[str, str],
        eval_batch: EvaluationBatch,
        components_to_update: List[str],
    ) -> Dict[str, List[Dict[str, Any]]]:
        """Create a reflective dataset from the evaluation results.

        Args:
            candidate: Mapping whose ``'system_prompt'`` entry is copied into
                every record.
            eval_batch: Result of ``evaluate``; its ``trajectories`` are only
                populated when ``capture_traces=True`` was used.
            components_to_update: Component names; each receives its own
                list holding one record per trace.

        Returns:
            ``{component: [record, ...]}``. Every list is empty when the
            batch carries no trajectories.
        """
        system_prompt = candidate.get('system_prompt', '')

        # Log the prompt being reflected on
        self.logger.info(f"π Creating reflection dataset for prompt: '{system_prompt[:100]}...'")

        # Pretty print reflection dataset creation
        self._log_reflection_dataset_creation(candidate, eval_batch, components_to_update)

        traces = eval_batch.trajectories
        if traces is None:
            # evaluate() leaves trajectories as None without capture_traces.
            self.logger.warning("No trajectories captured; reflective dataset will be empty")
            traces = []

        reflective_dataset = {
            component: [self._trace_to_record(system_prompt, trace) for trace in traces]
            for component in components_to_update
        }

        # Log reflection dataset summary
        record_scores = [
            record['score']
            for records in reflective_dataset.values()
            for record in records
        ]
        total_samples = len(record_scores)
        avg_score = sum(record_scores) / total_samples if total_samples > 0 else 0.0
        self.logger.info(f"π Reflection dataset created - {total_samples} samples, avg score: {avg_score:.4f}")
        return reflective_dataset

    def _generate_feedback(self, evaluation_results: Dict[str, float]) -> str:
        """Generate textual feedback based on evaluation results.

        Args:
            evaluation_results: Metric name to score; missing metrics are
                treated as 0.0.

        Returns:
            One overall-quality sentence followed by one hint for every
            metric in ``_METRIC_FEEDBACK`` scoring below
            ``_LOW_METRIC_THRESHOLD``, joined by single spaces.
        """
        composite_score = evaluation_results.get("composite_score", 0.0)

        # Overall quality assessment
        if composite_score >= 0.8:
            feedback_parts = ["The overall quality is good."]
        elif composite_score >= 0.5:
            feedback_parts = ["The overall quality is moderate."]
        else:
            feedback_parts = ["The overall quality is low. Focus on fundamental accuracy."]

        # Specific metric feedback
        feedback_parts.extend(
            hint
            for metric, hint in self._METRIC_FEEDBACK
            if evaluation_results.get(metric, 0.0) < self._LOW_METRIC_THRESHOLD
        )
        return " ".join(feedback_parts)

    def get_best_candidate(self) -> Optional[Dict[str, str]]:
        """Get the best candidate found so far."""
        return self._best_candidate

    def get_best_score(self) -> float:
        """Get the best score found so far."""
        return self._best_score

    def log_proposed_candidate(self, candidate: Dict[str, str], iteration: int = 0):
        """
        Log the new proposed candidate prompt.

        The prompt text itself is logged at DEBUG level; the banner and the
        length statistics are logged at INFO level.

        Args:
            candidate: The new candidate prompt from GEPA
            iteration: Current optimization iteration
        """
        system_prompt = candidate.get('system_prompt', '')
        logger.info("=" * 80)
        logger.info(f"NEW PROPOSED CANDIDATE (Iteration {iteration})")
        logger.info("=" * 80)
        logger.info("PROPOSED PROMPT:")
        logger.info("-" * 40)
        logger.debug(f'"{system_prompt}"')
        logger.info("-" * 40)
        logger.info(f"Prompt Length: {len(system_prompt)} characters")
        logger.info(f"Word Count: {len(system_prompt.split())} words")
        logger.info("=" * 80)

    @classmethod
    def _preview(cls, text: str) -> str:
        """Return *text* cut to ``_PREVIEW_CHARS`` characters plus ``...`` if longer."""
        if len(text) > cls._PREVIEW_CHARS:
            return text[:cls._PREVIEW_CHARS] + "..."
        return text

    def _log_reflection_dataset_creation(self, candidate: Dict[str, str], eval_batch: EvaluationBatch,
                                         components_to_update: List[str]):
        """
        Log the reflection dataset creation process.

        Emits the score summary and component list at INFO level, and a
        per-sample breakdown of the first ``_MAX_DETAILED_TRACES`` traces at
        DEBUG level.

        Args:
            candidate: Current candidate being evaluated
            eval_batch: Evaluation results
            components_to_update: Components being updated
        """
        system_prompt = candidate.get('system_prompt', '')
        logger.info("=" * 80)
        logger.info("REFLECTION DATASET CREATION")
        logger.info("=" * 80)
        logger.info("CURRENT PROMPT BEING ANALYZED:")
        logger.info("-" * 40)
        logger.debug(f'"{system_prompt}"')
        logger.info("-" * 40)

        logger.info("EVALUATION SUMMARY:")
        logger.info("-" * 40)
        batch_scores = eval_batch.scores
        if batch_scores:
            avg_score = sum(batch_scores) / len(batch_scores)
            logger.info(f" Average Score: {avg_score:.4f}")
            logger.info(f" Min Score: {min(batch_scores):.4f}")
            logger.info(f" Max Score: {max(batch_scores):.4f}")
            logger.info(f" Total Samples: {len(batch_scores)}")

        logger.info("COMPONENTS TO UPDATE:")
        logger.info("-" * 40)
        for position, component in enumerate(components_to_update, 1):
            logger.info(f" {position}. {component}")

        traces = eval_batch.trajectories
        if traces:
            logger.debug("DETAILED ANALYSIS:")
            logger.debug("-" * 40)
            # Show only the first few samples to keep the log readable
            for position, trace in enumerate(traces[:self._MAX_DETAILED_TRACES], 1):
                evaluation_results = trace['evaluation_results']
                composite_score = evaluation_results.get("composite_score", 0.0)
                logger.debug(f" Sample {position} (Score: {composite_score:.4f}):")

                # Show input data (truncated)
                input_text = self._preview(trace['input_text'])
                logger.debug(f" Input: \"{input_text}\"")

                # Show predicted output (truncated)
                predicted_output = self._preview(trace['llm_output_json'])
                logger.debug(f" Output: \"{predicted_output}\"")

                # Show detailed scores
                logger.debug(" Detailed Scores:")
                for metric, score in evaluation_results.items():
                    if metric == "composite_score":
                        continue
                    label = metric.replace('_', ' ').title()
                    if isinstance(score, (int, float)):
                        logger.debug(f" {label}: {score:.4f}")
                    else:
                        # Non-numeric entries cannot take the float format.
                        logger.debug(f" {label}: {score}")

                # Show generated feedback
                feedback = self._generate_feedback(evaluation_results)
                logger.debug(f" Feedback: \"{feedback}\"")

            if len(traces) > self._MAX_DETAILED_TRACES:
                logger.debug(f" ... and {len(traces) - self._MAX_DETAILED_TRACES} more samples")

        logger.info("=" * 80)
|