File size: 17,135 Bytes
cacd4d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
"""
Custom GEPA Adapter for the GEPA Universal Prompt Optimizer
"""

import json
import logging
import re
from typing import Any, Dict, List, Optional

# Import ModelConfig
from ..models import ModelConfig

from gepa.core.adapter import GEPAAdapter, EvaluationBatch
from ..llms.vision_llm import VisionLLMClient
from ..evaluation.ui_evaluator import UITreeEvaluator
from .base_adapter import BaseGepaAdapter

logger = logging.getLogger(__name__)

class CustomGepaAdapter(BaseGepaAdapter):
    """
    Custom adapter for the GEPA Universal Prompt Optimizer.

    Bridges GEPA's optimization loop with a vision-capable LLM client and a
    UI-tree evaluator: candidate system prompts are run against
    (text, image, ground-truth JSON) samples, scored per-metric, and
    summarized into reflective datasets that GEPA uses to propose improved
    prompts.
    """

    def __init__(self, model_config: 'ModelConfig', metric_weights: Optional[Dict[str, float]] = None):
        """Initialize the custom GEPA adapter with model configuration.

        Args:
            model_config: A ModelConfig instance. A plain string model name is
                also tolerated and coerced into an OpenAI ModelConfig with no
                API key.
            metric_weights: Optional per-metric weights forwarded to the
                UITreeEvaluator.
        """
        # Convert string model to ModelConfig if needed
        if not isinstance(model_config, ModelConfig):
            model_config = ModelConfig(
                provider='openai',
                model_name=str(model_config),
                api_key=None
            )

        # Initialize components
        llm_client = VisionLLMClient(
            provider=model_config.provider,
            model_name=model_config.model_name,
            api_key=model_config.api_key,
            base_url=model_config.base_url,
            temperature=model_config.temperature,
            max_tokens=model_config.max_tokens,
            top_p=model_config.top_p,
            frequency_penalty=model_config.frequency_penalty,
            presence_penalty=model_config.presence_penalty
        )

        evaluator = UITreeEvaluator(metric_weights=metric_weights)

        # Initialize parent class.
        # NOTE(review): self.logger, self._best_score and self._best_candidate
        # are read below but never set here — presumably provided by
        # BaseGepaAdapter; verify against base_adapter.py.
        super().__init__(llm_client, evaluator)

        # Track candidates for logging
        self._last_candidate: Optional[str] = None  # last system prompt seen by evaluate()
        self._evaluation_count: int = 0             # distinct candidates evaluated so far

        self.logger.info(f"πŸš€ Initialized UI Tree adapter with {model_config.provider}/{model_config.model_name}")

    def _parse_json_safely(self, json_str: str) -> Dict[str, Any]:
        """Safely parse JSON string to dictionary with enhanced parsing and repair.

        Tries, in order: direct json.loads, a fenced ```json``` code block,
        the first brace-delimited span, and finally a repaired string from
        _repair_json. Returns {} if every strategy fails.
        """
        if not json_str or not isinstance(json_str, str):
            return {}

        # Try direct parsing first
        try:
            return json.loads(json_str)
        except json.JSONDecodeError:
            pass

        # Try to extract JSON from markdown code blocks
        json_match = re.search(r'```(?:json)?\s*(\{.*?\})\s*```', json_str, re.DOTALL)
        if json_match:
            try:
                return json.loads(json_match.group(1))
            except json.JSONDecodeError:
                pass

        # Try to find JSON object in the string (greedy: first '{' to last '}')
        json_match = re.search(r'\{.*\}', json_str, re.DOTALL)
        if json_match:
            try:
                return json.loads(json_match.group(0))
            except json.JSONDecodeError:
                pass

        # Try repair and parse
        repaired_json = self._repair_json(json_str)
        if repaired_json:
            try:
                return json.loads(repaired_json)
            except json.JSONDecodeError:
                pass

        self.logger.warning(f"Failed to parse JSON: {json_str[:100]}...")
        return {}

    def _repair_json(self, json_str: str) -> str:
        """Attempt to repair common JSON issues.

        Strips markdown fences and surrounding prose, removes trailing commas,
        and quotes bare object keys. Returns "" on failure.

        NOTE(review): the bare-key regex can also match `word:` sequences
        inside string values; acceptable here because this path is only
        reached after strict parsing has already failed.
        """
        try:
            # Remove markdown formatting
            json_str = re.sub(r'```(?:json)?\s*', '', json_str)
            json_str = re.sub(r'```\s*$', '', json_str)

            # Remove extra text before/after JSON
            json_match = re.search(r'\{.*\}', json_str, re.DOTALL)
            if json_match:
                json_str = json_match.group(0)

            # Fix common issues
            json_str = re.sub(r',\s*}', '}', json_str)  # Remove trailing commas
            json_str = re.sub(r',\s*]', ']', json_str)  # Remove trailing commas in arrays
            json_str = re.sub(r'([{,]\s*)(\w+):', r'\1"\2":', json_str)  # Quote unquoted keys

            return json_str
        except Exception as e:
            self.logger.warning(f"πŸ”§ JSON repair failed: {e}")
            return ""

    def evaluate(
        self,
        batch: List[Dict[str, Any]],
        candidate: Dict[str, str],
        capture_traces: bool = False,
    ) -> EvaluationBatch:
        """Evaluate the candidate on a batch of data.

        Args:
            batch: Samples with 'input' (text), 'image' (base64) and
                'output' (ground-truth JSON string) keys.
            candidate: Candidate mapping; only 'system_prompt' is used.
            capture_traces: When True, per-sample trajectories are recorded
                for later use in make_reflective_dataset.

        Returns:
            EvaluationBatch with raw LLM outputs, per-sample composite
            scores, and trajectories (None unless capture_traces).
        """
        outputs = []
        scores = []
        trajectories = [] if capture_traces else None

        system_prompt = candidate.get('system_prompt', '')

        # Check if this is a new candidate (different from last one)
        if self._last_candidate != system_prompt:
            self._evaluation_count += 1
            self.log_proposed_candidate(candidate, self._evaluation_count)
            self._last_candidate = system_prompt

        self.logger.info(f"πŸ“Š Evaluating {len(batch)} samples with prompt: '{system_prompt[:50]}...'")

        for i, item in enumerate(batch):
            input_text = item.get('input', '')
            image_base64 = item.get('image', '')
            ground_truth_json = item.get('output', '')

            # Call the LLM client
            llm_response = self.llm_client.generate(system_prompt, input_text, image_base64=image_base64)

            # Extract content from the response dictionary; fall back to the
            # stringified response if 'content' is missing or empty.
            if isinstance(llm_response, dict):
                llm_output_json_str = llm_response.get("content", "")
                if not llm_output_json_str:
                    llm_output_json_str = str(llm_response)
            else:
                llm_output_json_str = str(llm_response) if llm_response else ""

            # πŸ” DEBUG: Log essential info only (removed verbose JSON content)
            self.logger.debug(f"πŸ” Sample {i+1} - LLM Response Type: {type(llm_response)}")
            self.logger.debug(f"πŸ” Sample {i+1} - Response Length: {len(llm_output_json_str)} chars")

            outputs.append(llm_output_json_str)

            # Parse JSON strings to dictionaries for evaluation
            llm_output_dict = self._parse_json_safely(llm_output_json_str)
            ground_truth_dict = self._parse_json_safely(ground_truth_json)

            # Initialize evaluation_results with default values
            evaluation_results = {
                "composite_score": 0.0,
                "element_completeness": 0.0,
                "element_type_accuracy": 0.0,
                "text_content_accuracy": 0.0,
                "hierarchy_accuracy": 0.0,
                "style_accuracy": 0.0
            }

            # Calculate composite score and evaluation results.
            # Sentinel scores: 0.1 when both sides are empty (can't compare),
            # 0.05 when only one side parsed (likely a malformed LLM output).
            if not llm_output_dict and not ground_truth_dict:
                composite_score = 0.1
                evaluation_results = {k: 0.1 for k in evaluation_results.keys()}
                self.logger.warning(f"⚠️  Sample {i+1}: Empty results - using default score: {composite_score}")
            elif not llm_output_dict or not ground_truth_dict:
                composite_score = 0.05
                evaluation_results = {k: 0.05 for k in evaluation_results.keys()}
                self.logger.warning(f"⚠️  Sample {i+1}: Incomplete results - using low score: {composite_score}")
            else:
                # Calculate score using evaluator with parsed dictionaries
                evaluation_results = self.evaluator.evaluate(llm_output_dict, ground_truth_dict)
                composite_score = evaluation_results["composite_score"]

                # Clean, readable logging (removed verbose JSON dumps)
                llm_children = len(llm_output_dict.get('children', []))
                gt_children = len(ground_truth_dict.get('children', []))

                if composite_score < 0.1:
                    self.logger.warning(f"⚠️  Sample {i+1}: Low score {composite_score:.4f} - LLM: {llm_children} elements, GT: {gt_children} elements")
                    self.logger.debug(f"   Score breakdown: {evaluation_results}")
                else:
                    self.logger.info(f"βœ… Sample {i+1}: Score {composite_score:.4f} - LLM: {llm_children} elements, GT: {gt_children} elements")

            scores.append(composite_score)

            if capture_traces:
                trajectories.append({
                    'input_text': input_text,
                    'image_base64': image_base64,
                    'ground_truth_json': ground_truth_json,
                    'llm_output_json': llm_output_json_str,
                    'evaluation_results': evaluation_results
                })

        avg_score = sum(scores) / len(scores) if scores else 0.0

        # Update performance tracking (handled by parent class)
        if avg_score > self._best_score:
            self._best_score = avg_score
            self._best_candidate = candidate.copy()
            self.logger.info(f"🎯 New best candidate found with score: {avg_score:.4f}")

        self.logger.info(f"πŸ“ˆ Batch evaluation complete - Average score: {avg_score:.4f}")

        return EvaluationBatch(outputs=outputs, scores=scores, trajectories=trajectories)

    def make_reflective_dataset(
        self,
        candidate: Dict[str, str],
        eval_batch: EvaluationBatch,
        components_to_update: List[str],
    ) -> Dict[str, List[Dict[str, Any]]]:
        """Create a reflective dataset from the evaluation results.

        Args:
            candidate: The candidate whose results are being reflected on.
            eval_batch: Output of evaluate(); must have been produced with
                capture_traces=True to yield any samples.
            components_to_update: Component names GEPA wants updated; one
                sample list is built per component.

        Returns:
            Mapping of component name -> list of per-sample reflection
            records (prompt, input, outputs, scores, textual feedback).
        """
        reflective_dataset = {}
        system_prompt = candidate.get('system_prompt', '')

        # 🎯 NEW: Log the proposed new prompt being evaluated
        self.logger.info(f"πŸ“ Creating reflection dataset for prompt: '{system_prompt[:100]}...'")

        # Pretty print reflection dataset creation
        self._log_reflection_dataset_creation(candidate, eval_batch, components_to_update)

        # BUGFIX: trajectories is None when evaluate() ran without
        # capture_traces; iterating None raised TypeError here.
        traces = eval_batch.trajectories or []
        if not traces:
            self.logger.warning("⚠️  No trajectories available - reflection dataset will be empty (was evaluate() called with capture_traces=True?)")

        for component in components_to_update:
            reflective_dataset[component] = []
            for i, trace in enumerate(traces):
                feedback = self._generate_feedback(trace['evaluation_results'])
                reflective_dataset[component].append({
                    "current_prompt": system_prompt,
                    "input_text": trace['input_text'],
                    "image_base64": trace['image_base64'],
                    "generated_json": trace['llm_output_json'],
                    "ground_truth_json": trace['ground_truth_json'],
                    "score": trace['evaluation_results']["composite_score"],
                    "feedback": feedback,
                    "detailed_scores": trace['evaluation_results']
                })

        # 🎯 NEW: Log reflection dataset summary
        total_samples = sum(len(data) for data in reflective_dataset.values())
        avg_score = sum(trace['score'] for data in reflective_dataset.values() for trace in data) / total_samples if total_samples > 0 else 0.0
        self.logger.info(f"πŸ“ Reflection dataset created - {total_samples} samples, avg score: {avg_score:.4f}")

        return reflective_dataset

    def _generate_feedback(self, evaluation_results: Dict[str, float]) -> str:
        """Generate textual feedback based on evaluation results.

        Produces an overall-quality sentence (thresholds 0.8 / 0.5 on the
        composite score) plus one remedial sentence per metric below 0.7.
        """
        composite_score = evaluation_results.get("composite_score", 0.0)

        feedback_parts = []

        # Overall quality assessment
        if composite_score >= 0.8:
            feedback_parts.append("The overall quality is good.")
        elif composite_score >= 0.5:
            feedback_parts.append("The overall quality is moderate.")
        else:
            feedback_parts.append("The overall quality is low. Focus on fundamental accuracy.")

        # Specific metric feedback
        if evaluation_results.get("element_completeness", 0.0) < 0.7:
            feedback_parts.append("Element completeness is low. Ensure all UI elements are captured.")

        if evaluation_results.get("element_type_accuracy", 0.0) < 0.7:
            feedback_parts.append("Element type accuracy is low. Verify correct UI element identification (Button, Text, Image, etc.).")

        if evaluation_results.get("text_content_accuracy", 0.0) < 0.7:
            feedback_parts.append("Text content accuracy is low. Improve text extraction fidelity.")

        if evaluation_results.get("hierarchy_accuracy", 0.0) < 0.7:
            feedback_parts.append("Hierarchy accuracy is low. Ensure correct parent-child relationships.")

        if evaluation_results.get("style_accuracy", 0.0) < 0.7:
            feedback_parts.append("Style accuracy is low. Capture more styling properties (colors, sizes, positioning).")

        return " ".join(feedback_parts)

    def get_best_candidate(self) -> Optional[Dict[str, str]]:
        """Get the best candidate found so far."""
        return self._best_candidate

    def get_best_score(self) -> float:
        """Get the best score found so far."""
        return self._best_score

    def log_proposed_candidate(self, candidate: Dict[str, str], iteration: int = 0) -> None:
        """
        Log the new proposed candidate prompt.

        Args:
            candidate: The new candidate prompt from GEPA
            iteration: Current optimization iteration
        """
        system_prompt = candidate.get('system_prompt', '')

        # CONSISTENCY: use self.logger like the rest of the class (was the
        # module-level logger).
        self.logger.info("="*80)
        self.logger.info(f"NEW PROPOSED CANDIDATE (Iteration {iteration})")
        self.logger.info("="*80)
        self.logger.info("PROPOSED PROMPT:")
        self.logger.info("-" * 40)
        self.logger.debug(f'"{system_prompt}"')
        self.logger.info("-" * 40)
        self.logger.info(f"Prompt Length: {len(system_prompt)} characters")
        self.logger.info(f"Word Count: {len(system_prompt.split())} words")
        self.logger.info("="*80)

    def _log_reflection_dataset_creation(self, candidate: Dict[str, str], eval_batch: EvaluationBatch,
                                       components_to_update: List[str]) -> None:
        """
        Log the reflection dataset creation process.

        Args:
            candidate: Current candidate being evaluated
            eval_batch: Evaluation results
            components_to_update: Components being updated
        """
        system_prompt = candidate.get('system_prompt', '')

        # CONSISTENCY: use self.logger like the rest of the class (was the
        # module-level logger).
        self.logger.info("="*80)
        self.logger.info("REFLECTION DATASET CREATION")
        self.logger.info("="*80)

        self.logger.info("CURRENT PROMPT BEING ANALYZED:")
        self.logger.info("-" * 40)
        self.logger.debug(f'"{system_prompt}"')
        self.logger.info("-" * 40)

        self.logger.info("EVALUATION SUMMARY:")
        self.logger.info("-" * 40)
        if eval_batch.scores:
            avg_score = sum(eval_batch.scores) / len(eval_batch.scores)
            min_score = min(eval_batch.scores)
            max_score = max(eval_batch.scores)
            self.logger.info(f"   Average Score: {avg_score:.4f}")
            self.logger.info(f"   Min Score: {min_score:.4f}")
            self.logger.info(f"   Max Score: {max_score:.4f}")
            self.logger.info(f"   Total Samples: {len(eval_batch.scores)}")

        self.logger.info("COMPONENTS TO UPDATE:")
        self.logger.info("-" * 40)
        for i, component in enumerate(components_to_update, 1):
            self.logger.info(f"   {i}. {component}")

        if eval_batch.trajectories:
            self.logger.debug("DETAILED ANALYSIS:")
            self.logger.debug("-" * 40)
            for i, trace in enumerate(eval_batch.trajectories[:3], 1):  # Show first 3 samples
                evaluation_results = trace['evaluation_results']
                composite_score = evaluation_results.get("composite_score", 0.0)

                self.logger.debug(f"   Sample {i} (Score: {composite_score:.4f}):")

                # Show input data (truncated)
                input_text = trace['input_text'][:100] + "..." if len(trace['input_text']) > 100 else trace['input_text']
                self.logger.debug(f"      Input: \"{input_text}\"")

                # Show predicted output (truncated)
                predicted_output = trace['llm_output_json'][:100] + "..." if len(trace['llm_output_json']) > 100 else trace['llm_output_json']
                self.logger.debug(f"      Output: \"{predicted_output}\"")

                # Show detailed scores
                self.logger.debug("      Detailed Scores:")
                for metric, score in evaluation_results.items():
                    if metric != "composite_score":
                        self.logger.debug(f"        {metric.replace('_', ' ').title()}: {score:.4f}")

                # Show generated feedback
                feedback = self._generate_feedback(evaluation_results)
                self.logger.debug(f"      Feedback: \"{feedback}\"")

            if len(eval_batch.trajectories) > 3:
                self.logger.debug(f"   ... and {len(eval_batch.trajectories) - 3} more samples")

        self.logger.info("="*80)