Spaces:
Sleeping
Sleeping
| """ | |
| UI Tree Evaluator for GEPA Optimizer | |
| """ | |
import difflib
import json
import logging
from collections import Counter
from typing import Any, Dict, List, Optional

from .base_evaluator import BaseEvaluator
| logger = logging.getLogger(__name__) | |
class UITreeEvaluator(BaseEvaluator):
    """
    Comprehensive evaluator for UI tree extraction quality.

    Scores a predicted UI tree against a ground-truth tree along five
    weighted metrics (completeness, type, text, hierarchy, style) and
    combines them into a composite score.  Trees are nested dicts using
    the keys ``type``, ``text``, ``style`` and ``children``.
    """

    def __init__(self, metric_weights: Optional[Dict[str, float]] = None):
        """
        Initializes the UITreeEvaluator with configurable metric weights.

        Args:
            metric_weights: A dictionary of weights for different metrics.
                If None, default weights will be used.  Whatever weights
                are chosen, they are normalized to sum to 1.0.
        """
        # Default weights for UI tree evaluation; completeness dominates
        # because missing elements are the most damaging extraction error.
        default_weights = {
            "element_completeness": 0.3,    # How many elements are captured
            "element_type_accuracy": 0.25,  # Correct element types (Button, Text, etc.)
            "text_content_accuracy": 0.2,   # Text content matches
            "hierarchy_accuracy": 0.15,     # Parent-child relationships
            "style_accuracy": 0.1,          # Style properties captured
        }
        super().__init__(metric_weights=metric_weights or default_weights)
        self._normalize_weights()

    def _normalize_weights(self):
        """Normalize ``self.metric_weights`` in place so they sum to 1.0."""
        total_weight = sum(self.metric_weights.values())
        if total_weight > 0:
            self.metric_weights = {
                name: weight / total_weight
                for name, weight in self.metric_weights.items()
            }
        else:
            # BUGFIX: the original called self.logger here, which this class
            # never defines (every other method uses the module logger).
            # NOTE(review): assumes BaseEvaluator does not supply a
            # ``logger`` attribute — confirm against base_evaluator.py.
            logger.warning("Total metric weight is zero. Scores will be zero.")

    def evaluate(self, predicted_json: Dict[str, Any], expected_json: Dict[str, Any]) -> Dict[str, float]:
        """
        Generates a weighted composite score from individual metrics.

        Args:
            predicted_json: The JSON generated by the LLM.
            expected_json: The ground truth JSON.

        Returns:
            A dictionary of individual metric scores and the composite score
            (under the key ``"composite_score"``).
        """
        scores = {
            "element_completeness": self.calculate_element_completeness(predicted_json, expected_json),
            "element_type_accuracy": self.calculate_element_type_accuracy(predicted_json, expected_json),
            "text_content_accuracy": self.calculate_text_content_accuracy(predicted_json, expected_json),
            "hierarchy_accuracy": self.calculate_hierarchy_accuracy(predicted_json, expected_json),
            "style_accuracy": self.calculate_style_accuracy(predicted_json, expected_json),
        }
        composite_score = sum(
            score * self.metric_weights.get(metric, 0)
            for metric, score in scores.items()
        )
        # Small bonus for any meaningful content so GEPA recognizes even
        # tiny prompt improvements.  BUGFIX: the original computed this
        # bonus AFTER storing the composite score in ``scores``, so the
        # bonus was dead code and never reached the caller; it must be
        # applied before the score is recorded.
        if composite_score > 0.05:
            composite_score = min(composite_score + 0.001, 1.0)
        scores["composite_score"] = composite_score
        # Lazy %-style args avoid formatting cost when DEBUG is off.
        logger.debug("Evaluation scores: %s", scores)
        logger.debug("Composite score: %.4f", composite_score)
        return scores

    def calculate_element_completeness(self, predicted: Dict, expected: Dict) -> float:
        """
        Calculates how many UI elements are captured in the predicted JSON.

        This is the most important metric for UI tree extraction.  Returns
        a value in [0, 1]; under-prediction is penalized progressively
        (over-prediction is not penalized — ratios >= 1.0 score 1.0).
        """
        def _count_elements(node):
            """Count total elements (dict nodes) in the tree."""
            if not isinstance(node, dict):
                return 0
            return 1 + sum(_count_elements(child) for child in node.get("children", []))

        try:
            predicted_count = _count_elements(predicted)
            expected_count = _count_elements(expected)
            if expected_count == 0:
                return 1.0 if predicted_count == 0 else 0.0
            completeness_ratio = predicted_count / expected_count
            if completeness_ratio >= 1.0:
                return 1.0  # Perfect or better coverage
            if completeness_ratio >= 0.8:
                return completeness_ratio  # Good coverage
            if completeness_ratio >= 0.5:
                return completeness_ratio * 0.8  # Moderate coverage with penalty
            return completeness_ratio * 0.5  # Poor coverage with heavy penalty
        except Exception as e:
            logger.warning("Error calculating element completeness: %s", e)
            return 0.0

    def calculate_element_type_accuracy(self, predicted: Dict, expected: Dict) -> float:
        """
        Calculates element type accuracy by comparing the 'type' attribute of
        corresponding nodes.

        Types are matched by frequency (multiset intersection), so duplicate
        types count up to the number of times they appear in the expected tree.
        """
        def _get_all_types(node):
            """Collect every non-None 'type' value in the tree, pre-order."""
            if not isinstance(node, dict):
                return []
            types = [node.get("type")]
            for child in node.get("children", []):
                types.extend(_get_all_types(child))
            return [t for t in types if t is not None]

        try:
            predicted_types = _get_all_types(predicted)
            expected_types = _get_all_types(expected)
            if not expected_types:
                # Partial credit when the model invents types where none exist.
                return 1.0 if not predicted_types else 0.5
            if not predicted_types:
                return 0.0
            # Multiset intersection: each type contributes matches up to its
            # expected frequency.
            expected_counts = Counter(expected_types)
            predicted_counts = Counter(predicted_types)
            total_matches = sum(
                min(predicted_counts[type_name], count)
                for type_name, count in expected_counts.items()
            )
            return total_matches / len(expected_types)
        except Exception as e:
            logger.warning("Error calculating element type accuracy: %s", e)
            return 0.0

    def calculate_hierarchy_accuracy(self, predicted: Dict, expected: Dict) -> float:
        """
        Calculates hierarchy accuracy by comparing parent-child relationships.

        The tree is flattened into (parent_type, child_type) pairs and the
        score is the fraction of distinct expected pairs also present in the
        predicted tree.
        """
        def _get_hierarchy_structure(node, parent_type="ROOT"):
            """Extract hierarchy structure as (parent_type, child_type) pairs."""
            if not isinstance(node, dict):
                return []
            current_type = node.get("type", "unknown")
            hierarchy = [(parent_type, current_type)]
            for child in node.get("children", []):
                hierarchy.extend(_get_hierarchy_structure(child, current_type))
            return hierarchy

        try:
            predicted_hierarchy = _get_hierarchy_structure(predicted)
            expected_hierarchy = _get_hierarchy_structure(expected)
            if not expected_hierarchy:
                return 1.0 if not predicted_hierarchy else 0.5
            if not predicted_hierarchy:
                return 0.0
            expected_pairs = set(expected_hierarchy)
            predicted_pairs = set(predicted_hierarchy)
            matches = len(expected_pairs & predicted_pairs)
            # expected_pairs is non-empty here, so the division is safe.
            return matches / len(expected_pairs)
        except Exception as e:
            logger.warning("Error calculating hierarchy accuracy: %s", e)
            return 0.0

    def calculate_text_content_accuracy(self, predicted: Dict, expected: Dict) -> float:
        """
        Calculates text content accuracy by comparing the 'text' attribute of
        corresponding nodes.

        Each predicted text is matched against its most similar expected text
        (difflib ratio); the score is the mean best similarity over predicted
        texts.
        """
        def _get_all_texts(node):
            """Collect every non-empty 'text' value in the tree, pre-order."""
            if not isinstance(node, dict):
                return []
            texts = [node.get("text")]
            for child in node.get("children", []):
                texts.extend(_get_all_texts(child))
            return [t for t in texts if t is not None and str(t).strip()]

        try:
            predicted_texts = _get_all_texts(predicted)
            expected_texts = _get_all_texts(expected)
            if not expected_texts:
                # Partial credit if predicted has texts but expected doesn't.
                return 1.0 if not predicted_texts else 0.5
            if not predicted_texts:
                return 0.0
            total_similarity = 0.0
            for p_text in predicted_texts:
                total_similarity += max(
                    difflib.SequenceMatcher(
                        None, str(p_text).strip(), str(e_text).strip()
                    ).ratio()
                    for e_text in expected_texts
                )
            # BUGFIX(cleanup): the original re-checked emptiness of both text
            # lists here, but those branches were unreachable — both cases
            # are excluded by the guards above.
            return total_similarity / len(predicted_texts)
        except Exception as e:
            logger.warning("Error calculating text content accuracy: %s", e)
            return 0.0

    def calculate_style_accuracy(self, predicted: Dict, expected: Dict) -> float:
        """
        Calculates style accuracy by comparing style properties.

        An expected (property, value) pair counts as matched when ANY
        predicted style dict contains the same property with the same value.
        """
        def _get_all_styles(node):
            """Extract all style dicts from the tree, pre-order."""
            if not isinstance(node, dict):
                return []
            styles = []
            if "style" in node and isinstance(node["style"], dict):
                styles.append(node["style"])
            for child in node.get("children", []):
                styles.extend(_get_all_styles(child))
            return styles

        try:
            predicted_styles = _get_all_styles(predicted)
            expected_styles = _get_all_styles(expected)
            if not expected_styles:
                return 1.0 if not predicted_styles else 0.5
            if not predicted_styles:
                return 0.0
            total_style_properties = 0
            matching_properties = 0
            for exp_style in expected_styles:
                for prop_name, prop_value in exp_style.items():
                    total_style_properties += 1
                    # Matched if any predicted style carries the same pair.
                    if any(
                        pred_style.get(prop_name) == prop_value
                        for pred_style in predicted_styles
                    ):
                        matching_properties += 1
            return (
                matching_properties / total_style_properties
                if total_style_properties > 0
                else 0.0
            )
        except Exception as e:
            logger.warning("Error calculating style accuracy: %s", e)
            return 0.0