""" UI Tree Evaluator for GEPA Optimizer """ import json import logging import difflib from typing import Any, Dict, List, Optional from .base_evaluator import BaseEvaluator logger = logging.getLogger(__name__) class UITreeEvaluator(BaseEvaluator): """ Comprehensive evaluator for UI tree extraction quality. """ def __init__(self, metric_weights: Optional[Dict[str, float]] = None): """ Initializes the UITreeEvaluator with configurable metric weights. Args: metric_weights: A dictionary of weights for different metrics. If None, default weights will be used. """ # Set default weights for UI tree evaluation default_weights = { "element_completeness": 0.3, # How many elements are captured "element_type_accuracy": 0.25, # Correct element types (Button, Text, etc.) "text_content_accuracy": 0.2, # Text content matches "hierarchy_accuracy": 0.15, # Parent-child relationships "style_accuracy": 0.1, # Style properties captured } # Use provided weights or defaults weights = metric_weights or default_weights # Initialize parent class super().__init__(metric_weights=weights) # Normalize weights self._normalize_weights() def _normalize_weights(self): """Normalize weights to sum to 1.0""" total_weight = sum(self.metric_weights.values()) if total_weight > 0: self.metric_weights = {k: v / total_weight for k, v in self.metric_weights.items()} else: self.logger.warning("Total metric weight is zero. Scores will be zero.") def evaluate(self, predicted_json: Dict[str, Any], expected_json: Dict[str, Any]) -> Dict[str, float]: """ Generates a weighted composite score from individual metrics. Args: predicted_json: The JSON generated by the LLM. expected_json: The ground truth JSON. Returns: A dictionary of individual metric scores and the composite score. """ scores = { "element_completeness": self.calculate_element_completeness(predicted_json, expected_json), "element_type_accuracy": self.calculate_element_type_accuracy(predicted_json, expected_json), "text_content_accuracy": self.calculate_text_content_accuracy(predicted_json, expected_json), "hierarchy_accuracy": self.calculate_hierarchy_accuracy(predicted_json, expected_json), "style_accuracy": self.calculate_style_accuracy(predicted_json, expected_json), } composite_score = sum(scores[metric] * self.metric_weights.get(metric, 0) for metric in scores) scores["composite_score"] = composite_score # Add detailed logging for debugging logger.debug(f"Evaluation scores: {scores}") logger.debug(f"Composite score: {composite_score:.4f}") # Add small improvement bonus for better prompts (encourage GEPA to accept improvements) # This helps GEPA recognize even tiny improvements if composite_score > 0.05: # If we have any meaningful content composite_score = min(composite_score + 0.001, 1.0) # Small bonus to encourage acceptance return scores def calculate_element_completeness(self, predicted: Dict, expected: Dict) -> float: """ Calculates how many UI elements are captured in the predicted JSON. This is the most important metric for UI tree extraction. """ def _count_elements(node): """Count total elements in the tree""" if not isinstance(node, dict): return 0 count = 1 # Count current node for child in node.get("children", []): count += _count_elements(child) return count try: predicted_count = _count_elements(predicted) expected_count = _count_elements(expected) if expected_count == 0: return 1.0 if predicted_count == 0 else 0.0 # Score based on how many elements are captured completeness_ratio = predicted_count / expected_count # Give bonus for capturing more elements (up to 1.0) # Penalize heavily for missing elements if completeness_ratio >= 1.0: return 1.0 # Perfect or better elif completeness_ratio >= 0.8: return completeness_ratio # Good coverage elif completeness_ratio >= 0.5: return completeness_ratio * 0.8 # Moderate coverage with penalty else: return completeness_ratio * 0.5 # Poor coverage with heavy penalty except Exception as e: logger.warning(f"Error calculating element completeness: {e}") return 0.0 def calculate_element_type_accuracy(self, predicted: Dict, expected: Dict) -> float: """ Calculates element type accuracy by comparing the 'type' attribute of corresponding nodes. Focuses on common UI element types like Button, Text, Image, etc. """ def _get_all_types(node): if not isinstance(node, dict): return [] types = [node.get("type")] for child in node.get("children", []): types.extend(_get_all_types(child)) return [t for t in types if t is not None] try: predicted_types = _get_all_types(predicted) expected_types = _get_all_types(expected) if not expected_types: return 1.0 if not predicted_types else 0.5 if not predicted_types: return 0.0 # Count matching types with frequency consideration expected_type_counts = {} for t in expected_types: expected_type_counts[t] = expected_type_counts.get(t, 0) + 1 predicted_type_counts = {} for t in predicted_types: predicted_type_counts[t] = predicted_type_counts.get(t, 0) + 1 # Calculate accuracy based on type matches total_matches = 0 for type_name, expected_count in expected_type_counts.items(): predicted_count = predicted_type_counts.get(type_name, 0) # Count matches up to the expected count total_matches += min(predicted_count, expected_count) return total_matches / len(expected_types) if expected_types else 0.0 except Exception as e: logger.warning(f"Error calculating element type accuracy: {e}") return 0.0 def calculate_hierarchy_accuracy(self, predicted: Dict, expected: Dict) -> float: """ Calculates hierarchy accuracy by comparing parent-child relationships. """ def _get_hierarchy_structure(node, parent_type="ROOT"): """Extract hierarchy structure as (parent_type, child_type) pairs""" if not isinstance(node, dict): return [] current_type = node.get("type", "unknown") hierarchy = [(parent_type, current_type)] for child in node.get("children", []): hierarchy.extend(_get_hierarchy_structure(child, current_type)) return hierarchy try: predicted_hierarchy = _get_hierarchy_structure(predicted) expected_hierarchy = _get_hierarchy_structure(expected) if not expected_hierarchy: return 1.0 if not predicted_hierarchy else 0.5 if not predicted_hierarchy: return 0.0 # Count matching hierarchy relationships expected_hierarchy_set = set(expected_hierarchy) predicted_hierarchy_set = set(predicted_hierarchy) matches = len(expected_hierarchy_set.intersection(predicted_hierarchy_set)) total_expected = len(expected_hierarchy_set) return matches / total_expected if total_expected > 0 else 0.0 except Exception as e: logger.warning(f"Error calculating hierarchy accuracy: {e}") return 0.0 def calculate_text_content_accuracy(self, predicted: Dict, expected: Dict) -> float: """ Calculates text content accuracy by comparing the 'text' attribute of corresponding nodes. """ def _get_all_texts(node): if not isinstance(node, dict): return [] texts = [node.get("text")] for child in node.get("children", []): texts.extend(_get_all_texts(child)) return [t for t in texts if t is not None and str(t).strip()] try: predicted_texts = _get_all_texts(predicted) expected_texts = _get_all_texts(expected) if not expected_texts: return 1.0 if not predicted_texts else 0.5 # Partial credit if predicted has texts but expected doesn't if not predicted_texts: return 0.0 # No predicted texts, so no match total_similarity = 0.0 for p_text in predicted_texts: best_similarity = 0.0 for e_text in expected_texts: similarity = difflib.SequenceMatcher(None, str(p_text).strip(), str(e_text).strip()).ratio() best_similarity = max(best_similarity, similarity) total_similarity += best_similarity # Average similarity over all predicted texts if not predicted_texts and not expected_texts: return 1.0 elif not predicted_texts: return 0.0 else: return total_similarity / len(predicted_texts) except Exception as e: logger.warning(f"Error calculating text content accuracy: {e}") return 0.0 def calculate_style_accuracy(self, predicted: Dict, expected: Dict) -> float: """ Calculates style accuracy by comparing style properties. """ def _get_all_styles(node): """Extract all style properties from the tree""" if not isinstance(node, dict): return [] styles = [] if "style" in node and isinstance(node["style"], dict): styles.append(node["style"]) for child in node.get("children", []): styles.extend(_get_all_styles(child)) return styles try: predicted_styles = _get_all_styles(predicted) expected_styles = _get_all_styles(expected) if not expected_styles: return 1.0 if not predicted_styles else 0.5 if not predicted_styles: return 0.0 # Calculate style property overlap total_style_properties = 0 matching_properties = 0 for exp_style in expected_styles: for prop_name, prop_value in exp_style.items(): total_style_properties += 1 # Find matching property in predicted styles for pred_style in predicted_styles: if prop_name in pred_style and pred_style[prop_name] == prop_value: matching_properties += 1 break return matching_properties / total_style_properties if total_style_properties > 0 else 0.0 except Exception as e: logger.warning(f"Error calculating style accuracy: {e}") return 0.0