File size: 12,146 Bytes
cacd4d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
"""
UI Tree Evaluator for GEPA Optimizer
"""

import difflib
import json
import logging
from collections import Counter
from typing import Any, Dict, List, Optional

from .base_evaluator import BaseEvaluator

logger = logging.getLogger(__name__)

class UITreeEvaluator(BaseEvaluator):
    """
    Comprehensive evaluator for UI tree extraction quality.

    Compares a predicted UI tree against a ground-truth tree and produces
    per-metric scores plus a weighted composite score. Trees are nested
    dicts; each node may carry "type", "text" and "style" keys and a
    "children" list of child nodes.
    """

    def __init__(self, metric_weights: Optional[Dict[str, float]] = None):
        """
        Initializes the UITreeEvaluator with configurable metric weights.

        Args:
            metric_weights: A dictionary of weights for different metrics.
                            If None, default weights will be used. Weights
                            are normalized to sum to 1.0.
        """
        # Default weights for UI tree evaluation, ordered by importance.
        default_weights = {
            "element_completeness": 0.3,      # How many elements are captured
            "element_type_accuracy": 0.25,    # Correct element types (Button, Text, etc.)
            "text_content_accuracy": 0.2,     # Text content matches
            "hierarchy_accuracy": 0.15,       # Parent-child relationships
            "style_accuracy": 0.1,            # Style properties captured
        }

        # Use provided weights or defaults.
        weights = metric_weights or default_weights

        # Initialize parent class (stores metric_weights on self).
        super().__init__(metric_weights=weights)

        # Normalize weights so the composite score stays in [0, 1].
        self._normalize_weights()

    def _normalize_weights(self) -> None:
        """Normalize ``self.metric_weights`` in place so they sum to 1.0."""
        total_weight = sum(self.metric_weights.values())
        if total_weight > 0:
            self.metric_weights = {k: v / total_weight for k, v in self.metric_weights.items()}
        else:
            # Use the module-level logger for consistency with the other
            # methods (previously this used ``self.logger``, which relies on
            # the base class defining that attribute).
            logger.warning("Total metric weight is zero. Scores will be zero.")

    def evaluate(self, predicted_json: Dict[str, Any], expected_json: Dict[str, Any]) -> Dict[str, float]:
        """
        Generates a weighted composite score from individual metrics.

        Args:
            predicted_json: The JSON generated by the LLM.
            expected_json: The ground truth JSON.

        Returns:
            A dictionary of individual metric scores and the composite score
            (under the key "composite_score").
        """
        scores = {
            "element_completeness": self.calculate_element_completeness(predicted_json, expected_json),
            "element_type_accuracy": self.calculate_element_type_accuracy(predicted_json, expected_json),
            "text_content_accuracy": self.calculate_text_content_accuracy(predicted_json, expected_json),
            "hierarchy_accuracy": self.calculate_hierarchy_accuracy(predicted_json, expected_json),
            "style_accuracy": self.calculate_style_accuracy(predicted_json, expected_json),
        }

        composite_score = sum(scores[metric] * self.metric_weights.get(metric, 0) for metric in scores)

        # Small improvement bonus for better prompts (encourages GEPA to
        # accept even tiny improvements). Bug fix: this bonus was previously
        # computed AFTER composite_score had been stored in the returned
        # dict, so it never had any effect; it is now applied before storing.
        if composite_score > 0.05:  # If we have any meaningful content
            composite_score = min(composite_score + 0.001, 1.0)

        scores["composite_score"] = composite_score

        # Detailed logging for debugging (lazy %-formatting avoids work
        # when DEBUG is disabled).
        logger.debug("Evaluation scores: %s", scores)
        logger.debug("Composite score: %.4f", composite_score)

        return scores

    def calculate_element_completeness(self, predicted: Dict, expected: Dict) -> float:
        """
        Calculates how many UI elements are captured in the predicted JSON.

        This is the most important metric for UI tree extraction. The score
        is the predicted/expected node-count ratio, penalized progressively
        for low coverage.
        """
        def _count_elements(node):
            """Count total nodes in the tree (non-dict nodes count as 0)."""
            if not isinstance(node, dict):
                return 0
            count = 1  # Count current node
            for child in node.get("children", []):
                count += _count_elements(child)
            return count

        try:
            predicted_count = _count_elements(predicted)
            expected_count = _count_elements(expected)

            if expected_count == 0:
                # Nothing expected: perfect only if nothing was predicted.
                return 1.0 if predicted_count == 0 else 0.0

            completeness_ratio = predicted_count / expected_count

            # Capturing at least as many elements as expected caps at 1.0;
            # under-coverage is penalized progressively.
            if completeness_ratio >= 1.0:
                return 1.0  # Perfect or better
            elif completeness_ratio >= 0.8:
                return completeness_ratio  # Good coverage
            elif completeness_ratio >= 0.5:
                return completeness_ratio * 0.8  # Moderate coverage with penalty
            else:
                return completeness_ratio * 0.5  # Poor coverage with heavy penalty

        except Exception as e:
            logger.warning(f"Error calculating element completeness: {e}")
            return 0.0

    def calculate_element_type_accuracy(self, predicted: Dict, expected: Dict) -> float:
        """
        Calculates element type accuracy by comparing the 'type' attribute of
        corresponding nodes (Button, Text, Image, etc.).

        Types are compared by multiset overlap: each expected occurrence of a
        type can be matched by at most one predicted occurrence.
        """
        def _get_all_types(node):
            """Collect every non-None 'type' value in the tree."""
            if not isinstance(node, dict):
                return []
            types = [node.get("type")]
            for child in node.get("children", []):
                types.extend(_get_all_types(child))
            return [t for t in types if t is not None]

        try:
            predicted_types = _get_all_types(predicted)
            expected_types = _get_all_types(expected)

            if not expected_types:
                # Partial credit when the prediction invents types the
                # ground truth does not have.
                return 1.0 if not predicted_types else 0.5

            if not predicted_types:
                return 0.0

            # Multiset intersection via Counter: matches per type are capped
            # at the expected frequency.
            expected_type_counts = Counter(expected_types)
            predicted_type_counts = Counter(predicted_types)
            total_matches = sum(
                min(predicted_type_counts[type_name], expected_count)
                for type_name, expected_count in expected_type_counts.items()
            )

            return total_matches / len(expected_types)

        except Exception as e:
            logger.warning(f"Error calculating element type accuracy: {e}")
            return 0.0

    def calculate_hierarchy_accuracy(self, predicted: Dict, expected: Dict) -> float:
        """
        Calculates hierarchy accuracy by comparing parent-child relationships.

        The tree is flattened to a set of (parent_type, child_type) pairs and
        scored by the fraction of expected pairs that appear in the prediction.
        """
        def _get_hierarchy_structure(node, parent_type="ROOT"):
            """Extract hierarchy structure as (parent_type, child_type) pairs."""
            if not isinstance(node, dict):
                return []

            current_type = node.get("type", "unknown")
            hierarchy = [(parent_type, current_type)]

            for child in node.get("children", []):
                hierarchy.extend(_get_hierarchy_structure(child, current_type))

            return hierarchy

        try:
            predicted_hierarchy = _get_hierarchy_structure(predicted)
            expected_hierarchy = _get_hierarchy_structure(expected)

            if not expected_hierarchy:
                return 1.0 if not predicted_hierarchy else 0.5

            if not predicted_hierarchy:
                return 0.0

            # Set-based comparison: duplicate relationships count once.
            expected_hierarchy_set = set(expected_hierarchy)
            predicted_hierarchy_set = set(predicted_hierarchy)

            matches = len(expected_hierarchy_set.intersection(predicted_hierarchy_set))
            total_expected = len(expected_hierarchy_set)

            return matches / total_expected if total_expected > 0 else 0.0

        except Exception as e:
            logger.warning(f"Error calculating hierarchy accuracy: {e}")
            return 0.0

    def calculate_text_content_accuracy(self, predicted: Dict, expected: Dict) -> float:
        """
        Calculates text content accuracy by comparing the 'text' attribute of
        corresponding nodes.

        Each predicted text is scored by its best fuzzy match (difflib ratio)
        against any expected text; the result is the mean best similarity.
        """
        def _get_all_texts(node):
            """Collect every non-empty 'text' value in the tree."""
            if not isinstance(node, dict):
                return []
            texts = [node.get("text")]
            for child in node.get("children", []):
                texts.extend(_get_all_texts(child))
            return [t for t in texts if t is not None and str(t).strip()]

        try:
            predicted_texts = _get_all_texts(predicted)
            expected_texts = _get_all_texts(expected)

            if not expected_texts:
                return 1.0 if not predicted_texts else 0.5  # Partial credit if predicted has texts but expected doesn't

            if not predicted_texts:
                return 0.0  # No predicted texts, so no match

            # Best-match fuzzy similarity per predicted text, averaged.
            # (Dead re-checks of the empty cases after the loop were removed:
            # both were already handled above.)
            total_similarity = 0.0
            for p_text in predicted_texts:
                best_similarity = max(
                    difflib.SequenceMatcher(None, str(p_text).strip(), str(e_text).strip()).ratio()
                    for e_text in expected_texts
                )
                total_similarity += best_similarity

            return total_similarity / len(predicted_texts)
        except Exception as e:
            logger.warning(f"Error calculating text content accuracy: {e}")
            return 0.0

    def calculate_style_accuracy(self, predicted: Dict, expected: Dict) -> float:
        """
        Calculates style accuracy by comparing style properties.

        Scored as the fraction of expected (property, value) pairs that appear
        with the same value in any predicted node's style dict.
        """
        def _get_all_styles(node):
            """Extract all dict-valued 'style' entries from the tree."""
            if not isinstance(node, dict):
                return []

            styles = []
            if "style" in node and isinstance(node["style"], dict):
                styles.append(node["style"])

            for child in node.get("children", []):
                styles.extend(_get_all_styles(child))

            return styles

        try:
            predicted_styles = _get_all_styles(predicted)
            expected_styles = _get_all_styles(expected)

            if not expected_styles:
                return 1.0 if not predicted_styles else 0.5

            if not predicted_styles:
                return 0.0

            # Count expected style properties matched anywhere in the
            # predicted tree (a property matches at most once per expected
            # occurrence, hence the break).
            total_style_properties = 0
            matching_properties = 0

            for exp_style in expected_styles:
                for prop_name, prop_value in exp_style.items():
                    total_style_properties += 1

                    for pred_style in predicted_styles:
                        if prop_name in pred_style and pred_style[prop_name] == prop_value:
                            matching_properties += 1
                            break

            return matching_properties / total_style_properties if total_style_properties > 0 else 0.0

        except Exception as e:
            logger.warning(f"Error calculating style accuracy: {e}")
            return 0.0