"""
UI Tree Evaluator for GEPA Optimizer
"""
import json
import logging
import difflib
from typing import Any, Dict, List, Optional
from .base_evaluator import BaseEvaluator
logger = logging.getLogger(__name__)

class UITreeEvaluator(BaseEvaluator):
    """
    Comprehensive evaluator for UI tree extraction quality.

    Scores a predicted UI tree against a ground-truth tree on five weighted
    metrics: element completeness, element type accuracy, text content
    accuracy, hierarchy accuracy, and style accuracy.
    """

    def __init__(self, metric_weights: Optional[Dict[str, float]] = None):
        """
        Initializes the UITreeEvaluator with configurable metric weights.

        Args:
            metric_weights: A dictionary of weights for the individual metrics.
                If None, the default weights below are used.
        """
        # Default weights for UI tree evaluation.
        default_weights = {
            "element_completeness": 0.3,    # How many elements are captured
            "element_type_accuracy": 0.25,  # Correct element types (Button, Text, etc.)
            "text_content_accuracy": 0.2,   # Text content matches
            "hierarchy_accuracy": 0.15,     # Parent-child relationships
            "style_accuracy": 0.1,          # Style properties captured
        }
        # Use the provided weights or fall back to the defaults.
        weights = metric_weights or default_weights
        super().__init__(metric_weights=weights)
        # Rescale so the weights sum to 1.0.
        self._normalize_weights()

    def _normalize_weights(self):
        """Normalize weights so they sum to 1.0."""
        total_weight = sum(self.metric_weights.values())
        if total_weight > 0:
            self.metric_weights = {k: v / total_weight for k, v in self.metric_weights.items()}
        else:
            logger.warning("Total metric weight is zero. Scores will be zero.")
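
    # Illustrative example: passing metric_weights={"element_completeness": 3,
    # "element_type_accuracy": 1} normalizes them to 0.75 and 0.25; metrics
    # absent from the dict then contribute weight 0 in evaluate() below.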

    def evaluate(self, predicted_json: Dict[str, Any], expected_json: Dict[str, Any]) -> Dict[str, float]:
        """
        Computes the individual metrics and a weighted composite score.

        Args:
            predicted_json: The JSON generated by the LLM.
            expected_json: The ground truth JSON.

        Returns:
            A dictionary of individual metric scores plus the composite score.
        """
        scores = {
            "element_completeness": self.calculate_element_completeness(predicted_json, expected_json),
            "element_type_accuracy": self.calculate_element_type_accuracy(predicted_json, expected_json),
            "text_content_accuracy": self.calculate_text_content_accuracy(predicted_json, expected_json),
            "hierarchy_accuracy": self.calculate_hierarchy_accuracy(predicted_json, expected_json),
            "style_accuracy": self.calculate_style_accuracy(predicted_json, expected_json),
        }
        composite_score = sum(scores[metric] * self.metric_weights.get(metric, 0) for metric in scores)
        # Add a small bonus whenever there is any meaningful content, so GEPA
        # recognizes and accepts even tiny improvements between prompts.
        if composite_score > 0.05:
            composite_score = min(composite_score + 0.001, 1.0)
        scores["composite_score"] = composite_score
        logger.debug(f"Evaluation scores: {scores}")
        logger.debug(f"Composite score: {composite_score:.4f}")
        return scores
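
    # Worked example: with the default weights, metric scores of 1.0, 0.8,
    # 0.5, 1.0, and 0.0 (in the order above) combine to
    # 0.3*1.0 + 0.25*0.8 + 0.2*0.5 + 0.15*1.0 + 0.1*0.0 = 0.75, and the
    # acceptance bonus lifts the returned composite to 0.751.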

    def calculate_element_completeness(self, predicted: Dict, expected: Dict) -> float:
        """
        Calculates how many of the expected UI elements are captured in the
        predicted JSON. This is the most important metric for UI tree extraction.
        """
        def _count_elements(node):
            """Count the total number of elements in the tree."""
            if not isinstance(node, dict):
                return 0
            count = 1  # Count the current node
            for child in node.get("children", []):
                count += _count_elements(child)
            return count

        try:
            predicted_count = _count_elements(predicted)
            expected_count = _count_elements(expected)
            if expected_count == 0:
                return 1.0 if predicted_count == 0 else 0.0
            completeness_ratio = predicted_count / expected_count
            # Full credit at or above the expected count; increasingly heavy
            # penalties as more elements go missing.
            if completeness_ratio >= 1.0:
                return 1.0  # All expected elements captured (or more)
            elif completeness_ratio >= 0.8:
                return completeness_ratio  # Good coverage
            elif completeness_ratio >= 0.5:
                return completeness_ratio * 0.8  # Moderate coverage, mild penalty
            else:
                return completeness_ratio * 0.5  # Poor coverage, heavy penalty
        except Exception as e:
            logger.warning(f"Error calculating element completeness: {e}")
            return 0.0
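
    # Worked example: an expected tree of 10 nodes against a predicted tree of
    # 6 nodes gives a ratio of 0.6, which lands in the 0.5-0.8 band and scores
    # 0.6 * 0.8 = 0.48; a 12-node prediction would score a flat 1.0.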

    def calculate_element_type_accuracy(self, predicted: Dict, expected: Dict) -> float:
        """
        Calculates element type accuracy by comparing the 'type' attribute of
        nodes across the two trees. Focuses on common UI element types such as
        Button, Text, and Image.
        """
        def _get_all_types(node):
            if not isinstance(node, dict):
                return []
            types = [node.get("type")]
            for child in node.get("children", []):
                types.extend(_get_all_types(child))
            return [t for t in types if t is not None]

        try:
            predicted_types = _get_all_types(predicted)
            expected_types = _get_all_types(expected)
            if not expected_types:
                return 1.0 if not predicted_types else 0.5
            if not predicted_types:
                return 0.0
            # Count matching types, taking frequency into account.
            expected_type_counts = {}
            for t in expected_types:
                expected_type_counts[t] = expected_type_counts.get(t, 0) + 1
            predicted_type_counts = {}
            for t in predicted_types:
                predicted_type_counts[t] = predicted_type_counts.get(t, 0) + 1
            # Each type contributes matches only up to its expected count.
            total_matches = 0
            for type_name, expected_count in expected_type_counts.items():
                predicted_count = predicted_type_counts.get(type_name, 0)
                total_matches += min(predicted_count, expected_count)
            return total_matches / len(expected_types)
        except Exception as e:
            logger.warning(f"Error calculating element type accuracy: {e}")
            return 0.0
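
    # Worked example: expected types ["Button", "Button", "Text"] against
    # predicted ["Button", "Text", "Text"] gives min(1, 2) + min(2, 1) = 2
    # matches out of 3 expected types, for a score of 2/3.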

    def calculate_hierarchy_accuracy(self, predicted: Dict, expected: Dict) -> float:
        """
        Calculates hierarchy accuracy by comparing parent-child relationships.
        """
        def _get_hierarchy_structure(node, parent_type="ROOT"):
            """Extract the hierarchy as (parent_type, child_type) pairs."""
            if not isinstance(node, dict):
                return []
            current_type = node.get("type", "unknown")
            hierarchy = [(parent_type, current_type)]
            for child in node.get("children", []):
                hierarchy.extend(_get_hierarchy_structure(child, current_type))
            return hierarchy

        try:
            predicted_hierarchy = _get_hierarchy_structure(predicted)
            expected_hierarchy = _get_hierarchy_structure(expected)
            if not expected_hierarchy:
                return 1.0 if not predicted_hierarchy else 0.5
            if not predicted_hierarchy:
                return 0.0
            # Score the fraction of unique expected (parent, child) pairs that
            # also appear in the prediction.
            expected_hierarchy_set = set(expected_hierarchy)
            predicted_hierarchy_set = set(predicted_hierarchy)
            matches = len(expected_hierarchy_set.intersection(predicted_hierarchy_set))
            return matches / len(expected_hierarchy_set)
        except Exception as e:
            logger.warning(f"Error calculating hierarchy accuracy: {e}")
            return 0.0
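
    # Worked example: a Screen with Button and Text children yields the pairs
    # ("ROOT", "Screen"), ("Screen", "Button"), and ("Screen", "Text"); a
    # prediction missing the Text child matches 2 of those 3 unique pairs and
    # scores 2/3.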

    def calculate_text_content_accuracy(self, predicted: Dict, expected: Dict) -> float:
        """
        Calculates text content accuracy by fuzzily comparing the 'text'
        attribute of nodes across the two trees.
        """
        def _get_all_texts(node):
            if not isinstance(node, dict):
                return []
            texts = [node.get("text")]
            for child in node.get("children", []):
                texts.extend(_get_all_texts(child))
            return [t for t in texts if t is not None and str(t).strip()]

        try:
            predicted_texts = _get_all_texts(predicted)
            expected_texts = _get_all_texts(expected)
            if not expected_texts:
                # Partial credit if the prediction has texts the ground truth lacks.
                return 1.0 if not predicted_texts else 0.5
            if not predicted_texts:
                return 0.0
            # For each predicted text, take the best fuzzy match against any
            # expected text, then average over all predicted texts.
            total_similarity = 0.0
            for p_text in predicted_texts:
                best_similarity = 0.0
                for e_text in expected_texts:
                    similarity = difflib.SequenceMatcher(None, str(p_text).strip(), str(e_text).strip()).ratio()
                    best_similarity = max(best_similarity, similarity)
                total_similarity += best_similarity
            return total_similarity / len(predicted_texts)
        except Exception as e:
            logger.warning(f"Error calculating text content accuracy: {e}")
            return 0.0
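
    # Worked example: difflib.SequenceMatcher(None, "Submit", "Submit!").ratio()
    # is 12/13 (about 0.92), so near-misses in punctuation or whitespace still
    # earn most of the credit.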

    def calculate_style_accuracy(self, predicted: Dict, expected: Dict) -> float:
        """
        Calculates style accuracy by comparing style properties across the trees.
        """
        def _get_all_styles(node):
            """Extract all style dictionaries from the tree."""
            if not isinstance(node, dict):
                return []
            styles = []
            if "style" in node and isinstance(node["style"], dict):
                styles.append(node["style"])
            for child in node.get("children", []):
                styles.extend(_get_all_styles(child))
            return styles

        try:
            predicted_styles = _get_all_styles(predicted)
            expected_styles = _get_all_styles(expected)
            if not expected_styles:
                return 1.0 if not predicted_styles else 0.5
            if not predicted_styles:
                return 0.0
            # Score the fraction of expected style properties whose name and
            # value both appear in at least one predicted style dictionary.
            total_style_properties = 0
            matching_properties = 0
            for exp_style in expected_styles:
                for prop_name, prop_value in exp_style.items():
                    total_style_properties += 1
                    for pred_style in predicted_styles:
                        if prop_name in pred_style and pred_style[prop_name] == prop_value:
                            matching_properties += 1
                            break
            return matching_properties / total_style_properties if total_style_properties > 0 else 0.0
        except Exception as e:
            logger.warning(f"Error calculating style accuracy: {e}")
            return 0.0
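

# Minimal usage sketch (illustrative trees, not real optimizer output). This
# module uses a relative import, so run it from within its package, e.g.
# `python -m <package>.ui_tree_evaluator`.
if __name__ == "__main__":
    expected = {
        "type": "Screen",
        "children": [
            {"type": "Button", "text": "Submit", "style": {"color": "blue"}},
            {"type": "Text", "text": "Welcome"},
        ],
    }
    predicted = {
        "type": "Screen",
        "children": [
            {"type": "Button", "text": "Submit!", "style": {"color": "blue"}},
        ],
    }
    evaluator = UITreeEvaluator()
    for name, score in evaluator.evaluate(predicted, expected).items():
        print(f"{name}: {score:.3f}")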