import json
import logging
import re
from typing import Dict, List, Optional

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_score import rouge_scorer
from sentence_transformers import SentenceTransformer

logger = logging.getLogger(__name__)

class VietMEAgentEvaluator:
    """Comprehensive evaluation for VietMEAgent - FIXED VERSION"""

    def __init__(self, cultural_kb_path: str):
        # Load cultural knowledge for evaluation
        with open(cultural_kb_path, 'r', encoding='utf-8') as f:
            self.cultural_kb = json.load(f)

        # Initialize evaluation tools
        self.rouge_scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=False)
        self.sentence_model = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')
        self.smoothing = SmoothingFunction().method1

        # Cultural object vocabulary - EXPANDED
        self.cultural_vocabulary = set()
        for obj_name, obj_data in self.cultural_kb['objects'].items():
            self.cultural_vocabulary.add(obj_name.lower())
            # Add variations
            if 'name' in obj_data:
                self.cultural_vocabulary.add(obj_data['name'].lower())

        # Additional common Vietnamese cultural terms
        additional_terms = [
            'phở', 'bánh mì', 'áo dài', 'nón lá', 'chùa', 'đình', 'làng', 'thờ',
            'tết', 'trung thu', 'gỏi cuốn', 'bánh xèo', 'cà phê', 'trúc', 'tre',
            'đàn bầu', 'trống', 'sáo', 'múa lân', 'rối nước', 'việt nam'
        ]
        self.cultural_vocabulary.update(additional_terms)

        logger.info(f"Initialized evaluator with {len(self.cultural_vocabulary)} cultural terms")

    def evaluate_batch(self, predictions: List[Dict], ground_truth: List[Dict]) -> Dict:
        """Evaluate a batch of predictions"""
        logger.info(f"Evaluating {len(predictions)} predictions against {len(ground_truth)} ground truth")

        results = {
            'language_quality': {},
            'cultural_relevance': {},
            'visual_grounding': {},
            'overall_performance': {}
        }

        # Language quality metrics
        results['language_quality'] = self.evaluate_language_quality(predictions, ground_truth)

        # Cultural relevance metrics
        results['cultural_relevance'] = self.evaluate_cultural_relevance(predictions, ground_truth)

        # Visual grounding metrics
        results['visual_grounding'] = self.evaluate_visual_grounding(predictions, ground_truth)

        # Overall performance
        results['overall_performance'] = self.calculate_overall_performance(results)

        # Debug metrics
        self.debug_evaluation_results(results, predictions, ground_truth)

        return results

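    # NOTE: A sketch of the per-sample dict shape this evaluator expects, inferred from the
    # extraction helpers further down (extract_text_from_prediction / extract_text_from_ground_truth
    # and evaluate_visual_grounding). Field values are hypothetical examples:
    #
    #   prediction = {
    #       'questions': [{'question': '...', 'answer': '...', 'explanation': '...'}],
    #       'vietnamese_explanation': '...',
    #       'image_analysis': {'vietnamese_text': ['...'], 'cultural_objects': ['nón lá']},
    #       'heatmap': [[0.1, 0.8], [0.2, 0.4]],           # optional, used for visual grounding
    #   }
    #   ground_truth = {
    #       'questions': [...],
    #       'image_analysis': {...},
    #       'attention_regions': [[x, y, w, h], ...],      # optional, enables IoU-based grounding
    #   }
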
    def debug_evaluation_results(self, results: Dict, predictions: List[Dict], ground_truth: List[Dict]):
        """Debug evaluation results"""
        logger.info("=== EVALUATION DEBUG ===")

        # Sample text comparison
        if predictions and ground_truth:
            pred_text = self.extract_text_from_prediction(predictions[0])
            gt_text = self.extract_text_from_ground_truth(ground_truth[0])
            logger.info(f"Sample prediction text: {pred_text[:100]}...")
            logger.info(f"Sample ground truth text: {gt_text[:100]}...")

            # Cultural objects
            pred_cultural = self.extract_cultural_objects(predictions[0])
            gt_cultural = self.extract_cultural_objects(ground_truth[0])
            logger.info(f"Pred cultural objects: {pred_cultural}")
            logger.info(f"GT cultural objects: {gt_cultural}")

        logger.info("=== END DEBUG ===")

    def evaluate_language_quality(self, predictions: List[Dict], ground_truth: List[Dict]) -> Dict:
        """Evaluate language quality using BLEU and ROUGE - IMPROVED"""
        bleu_scores = []
        rouge_scores = {'rouge1': [], 'rouge2': [], 'rougeL': []}
        valid_comparisons = 0

        for pred, gt in zip(predictions, ground_truth):
            # Extract text for comparison - IMPROVED
            pred_text = self.extract_text_from_prediction(pred)
            gt_text = self.extract_text_from_ground_truth(gt)

            if pred_text and gt_text:
                # Clean and normalize text
                pred_clean = self.clean_vietnamese_text(pred_text)
                gt_clean = self.clean_vietnamese_text(gt_text)

                if pred_clean and gt_clean:
                    valid_comparisons += 1

                    # BLEU score - IMPROVED tokenization
                    pred_tokens = self.tokenize_vietnamese(pred_clean)
                    gt_tokens = self.tokenize_vietnamese(gt_clean)

                    if pred_tokens and gt_tokens:
                        # Use multiple references for better BLEU
                        references = [gt_tokens]
                        # Add variations
                        if len(gt_tokens) > 3:
                            references.append(gt_tokens[:-1])  # Remove last word
                            references.append(gt_tokens[1:])   # Remove first word

                        bleu = sentence_bleu(
                            references,
                            pred_tokens,
                            smoothing_function=self.smoothing,
                            weights=(0.5, 0.3, 0.2)  # Give more weight to unigrams and bigrams
                        )
                        bleu_scores.append(bleu)

                    # ROUGE scores (RougeScorer.score expects (target, prediction))
                    try:
                        rouge_result = self.rouge_scorer.score(gt_clean, pred_clean)
                        for metric in rouge_scores:
                            rouge_scores[metric].append(rouge_result[metric].fmeasure)
                    except Exception as e:
                        logger.warning(f"ROUGE calculation failed: {e}")

        logger.info(f"Language quality: {valid_comparisons} valid comparisons out of {len(predictions)}")

        return {
            'bleu': np.mean(bleu_scores) if bleu_scores else 0.0,
            'rouge1': np.mean(rouge_scores['rouge1']) if rouge_scores['rouge1'] else 0.0,
            'rouge2': np.mean(rouge_scores['rouge2']) if rouge_scores['rouge2'] else 0.0,
            'rougeL': np.mean(rouge_scores['rougeL']) if rouge_scores['rougeL'] else 0.0,
            'num_evaluated': valid_comparisons
        }

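    # NOTE: A minimal sketch of the BLEU call used above, on toy tokens (values illustrative):
    #
    #   refs = [['một', 'chiếc', 'nón', 'lá'], ['chiếc', 'nón', 'lá']]
    #   hyp = ['một', 'nón', 'lá']
    #   score = sentence_bleu(refs, hyp, weights=(0.5, 0.3, 0.2),
    #                         smoothing_function=SmoothingFunction().method1)
    #
    # With three weights only 1- to 3-gram precisions contribute, so combined with smoothing,
    # short but on-topic outputs are not zeroed out the way a default 4-gram BLEU would be.
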
    def clean_vietnamese_text(self, text: str) -> str:
        """Clean and normalize Vietnamese text"""
        if not text:
            return ""

        # Convert to lowercase
        text = text.lower()

        # Remove extra whitespace
        text = re.sub(r'\s+', ' ', text).strip()

        # Remove special characters but keep Vietnamese diacritics
        text = re.sub(r'[^\w\sàáạảãâầấậẩẫăằắặẳẵèéẹẻẽêềếệểễìíịỉĩòóọỏõôồốộổỗơờớợởỡùúụủũưừứựửữỳýỵỷỹđ]', '', text)

        return text

    def tokenize_vietnamese(self, text: str) -> List[str]:
        """Tokenize Vietnamese text"""
        if not text:
            return []

        # Simple word-based tokenization
        tokens = text.split()

        # Filter out very short tokens
        tokens = [t for t in tokens if len(t) > 1]

        return tokens

    def evaluate_cultural_relevance(self, predictions: List[Dict], ground_truth: List[Dict]) -> Dict:
        """Evaluate cultural relevance of predictions - IMPROVED"""
        cultural_precision = []
        cultural_recall = []
        cultural_accuracy = []
        cultural_mentions = []

        for pred, gt in zip(predictions, ground_truth):
            # Extract cultural objects - IMPROVED
            pred_cultural = self.extract_cultural_objects(pred)
            gt_cultural = self.extract_cultural_objects(gt)

            # Count cultural mentions in text
            pred_text = self.extract_text_from_prediction(pred)
            gt_text = self.extract_text_from_ground_truth(gt)

            pred_mentions = self.count_cultural_mentions(pred_text)
            gt_mentions = self.count_cultural_mentions(gt_text)

            cultural_mentions.append({
                'pred_mentions': pred_mentions,
                'gt_mentions': gt_mentions,
                'mention_overlap': len(set(pred_mentions).intersection(set(gt_mentions)))
            })

            # If we have ground truth cultural objects
            if gt_cultural or gt_mentions:
                all_gt_cultural = gt_cultural.union(set(gt_mentions))
                all_pred_cultural = pred_cultural.union(set(pred_mentions))

                if all_pred_cultural:
                    precision = len(all_pred_cultural.intersection(all_gt_cultural)) / len(all_pred_cultural)
                    cultural_precision.append(precision)

                if all_gt_cultural:
                    recall = len(all_pred_cultural.intersection(all_gt_cultural)) / len(all_gt_cultural)
                    cultural_recall.append(recall)

            # Cultural context accuracy using semantic similarity
            if pred_text and gt_text:
                cultural_acc = self.evaluate_cultural_context_accuracy(pred, gt)
                cultural_accuracy.append(cultural_acc)

        # Calculate cultural mention accuracy
        mention_accuracy = 0.0
        if cultural_mentions:
            total_overlap = sum(m['mention_overlap'] for m in cultural_mentions)
            total_gt_mentions = sum(len(m['gt_mentions']) for m in cultural_mentions)
            mention_accuracy = total_overlap / total_gt_mentions if total_gt_mentions > 0 else 0.0

        return {
            'cultural_precision': np.mean(cultural_precision) if cultural_precision else 0.0,
            'cultural_recall': np.mean(cultural_recall) if cultural_recall else 0.0,
            'cultural_accuracy': np.mean(cultural_accuracy) if cultural_accuracy else 0.0,
            'cultural_mention_accuracy': mention_accuracy,
            'cultural_f1': self.calculate_f1(
                np.mean(cultural_precision) if cultural_precision else 0.0,
                np.mean(cultural_recall) if cultural_recall else 0.0
            ),
            'num_cultural_samples': len(cultural_mentions)
        }

    def count_cultural_mentions(self, text: str) -> List[str]:
        """Count mentions of cultural terms in text"""
        if not text:
            return []

        text_lower = text.lower()
        mentions = []

        for cultural_term in self.cultural_vocabulary:
            if cultural_term in text_lower:
                mentions.append(cultural_term)

        return mentions

    def evaluate_visual_grounding(self, predictions: List[Dict], ground_truth: List[Dict]) -> Dict:
        """Evaluate visual grounding accuracy - IMPROVED"""
        grounding_scores = []
        detection_accuracy = []
        heatmap_quality = []

        for pred, gt in zip(predictions, ground_truth):
            # Heatmap-based grounding evaluation
            if 'heatmap' in pred:
                heatmap = np.array(pred['heatmap']) if isinstance(pred['heatmap'], list) else pred['heatmap']

                # Basic heatmap quality metrics
                if heatmap.size > 0:
                    concentration = np.std(heatmap)
                    coverage = np.mean(heatmap > 0.3)
                    max_attention = np.max(heatmap)

                    # Simple quality score
                    quality_score = min(1.0, (concentration * 2 + coverage + max_attention) / 3)
                    heatmap_quality.append(quality_score)

                    # If we have ground truth regions, calculate IoU
                    if 'attention_regions' in gt:
                        iou = self.calculate_grounding_accuracy(heatmap, gt['attention_regions'])
                        grounding_scores.append(iou)
                    else:
                        # Use heatmap quality as proxy for grounding
                        grounding_scores.append(quality_score * 0.5)  # Lower weight without GT

            # Object detection accuracy
            pred_objects = []
            if 'image_analysis' in pred and 'cultural_objects' in pred['image_analysis']:
                pred_objects = pred['image_analysis']['cultural_objects']
            elif 'cultural_objects' in pred:
                pred_objects = pred['cultural_objects']

            gt_objects = []
            if 'image_analysis' in gt and 'cultural_objects' in gt['image_analysis']:
                gt_objects = gt['image_analysis']['cultural_objects']
            elif 'cultural_objects' in gt:
                gt_objects = gt['cultural_objects']

            if gt_objects or pred_objects:
                detection_acc = self.calculate_detection_accuracy(pred_objects, gt_objects)
                detection_accuracy.append(detection_acc)

        return {
            'visual_grounding': np.mean(grounding_scores) if grounding_scores else 0.0,
            'detection_accuracy': np.mean(detection_accuracy) if detection_accuracy else 0.0,
            'heatmap_quality': np.mean(heatmap_quality) if heatmap_quality else 0.0,
            'num_grounding_samples': len(grounding_scores),
            'num_detection_samples': len(detection_accuracy)
        }

    def extract_text_from_prediction(self, prediction: Dict) -> str:
        """Extract text from prediction for evaluation - IMPROVED"""
        texts = []

        # Extract from questions
        if 'questions' in prediction:
            for q in prediction['questions']:
                if 'explanation' in q and q['explanation']:
                    texts.append(str(q['explanation']))
                if 'answer' in q and q['answer']:
                    texts.append(str(q['answer']))
                if 'question' in q and q['question']:
                    texts.append(str(q['question']))

        # Extract from vietnamese_explanation
        if 'vietnamese_explanation' in prediction and prediction['vietnamese_explanation']:
            texts.append(str(prediction['vietnamese_explanation']))

        # Extract from image analysis
        if 'image_analysis' in prediction:
            analysis = prediction['image_analysis']
            if 'vietnamese_text' in analysis:
                texts.extend([str(t) for t in analysis['vietnamese_text'] if t])

        return ' '.join(texts)

    def extract_text_from_ground_truth(self, ground_truth: Dict) -> str:
        """Extract text from ground truth for evaluation - IMPROVED"""
        texts = []

        # Extract from questions
        if 'questions' in ground_truth:
            for q in ground_truth['questions']:
                if 'explanation' in q and q['explanation']:
                    texts.append(str(q['explanation']))
                if 'answer' in q and q['answer']:
                    texts.append(str(q['answer']))
                if 'question' in q and q['question']:
                    texts.append(str(q['question']))

        # Extract from image analysis
        if 'image_analysis' in ground_truth:
            analysis = ground_truth['image_analysis']
            if 'vietnamese_text' in analysis:
                texts.extend([str(t) for t in analysis['vietnamese_text'] if t])

        return ' '.join(texts)

    def extract_cultural_objects(self, data: Dict) -> set:
        """Extract cultural objects mentioned in data - IMPROVED"""
        cultural_objects = set()

        # Get all text from the data
        text = ""
        if 'questions' in data:
            text = self.extract_text_from_prediction(data)
        else:
            text = self.extract_text_from_ground_truth(data)

        text_lower = text.lower()

        # Find cultural terms in text
        for cultural_term in self.cultural_vocabulary:
            if cultural_term in text_lower:
                cultural_objects.add(cultural_term)

        # Also check explicit cultural_objects fields
        if 'cultural_objects' in data:
            for obj in data['cultural_objects']:
                cultural_objects.add(str(obj).lower())

        if 'image_analysis' in data and 'cultural_objects' in data['image_analysis']:
            for obj in data['image_analysis']['cultural_objects']:
                cultural_objects.add(str(obj).lower())

        return cultural_objects

    def evaluate_cultural_context_accuracy(self, prediction: Dict, ground_truth: Dict) -> float:
        """Evaluate accuracy of cultural context understanding - IMPROVED"""
        # Extract cultural explanations
        pred_text = self.extract_text_from_prediction(prediction)
        gt_text = self.extract_text_from_ground_truth(ground_truth)

        if not pred_text or not gt_text:
            return 0.0

        # Clean texts
        pred_clean = self.clean_vietnamese_text(pred_text)
        gt_clean = self.clean_vietnamese_text(gt_text)

        if not pred_clean or not gt_clean:
            return 0.0

        try:
            # Use semantic similarity for cultural context evaluation
            pred_embedding = self.sentence_model.encode([pred_clean])
            gt_embedding = self.sentence_model.encode([gt_clean])

            # Calculate cosine similarity
            similarity = np.dot(pred_embedding[0], gt_embedding[0]) / (
                np.linalg.norm(pred_embedding[0]) * np.linalg.norm(gt_embedding[0])
            )

            return max(0.0, float(similarity))  # Ensure non-negative
        except Exception as e:
            logger.warning(f"Cultural context accuracy calculation failed: {e}")
            return 0.0

    def calculate_grounding_accuracy(self, pred_heatmap: np.ndarray, gt_regions: List) -> float:
        """Calculate visual grounding accuracy"""
        if len(gt_regions) == 0 or pred_heatmap.size == 0:
            return 0.0

        try:
            # Ensure heatmap is 2D
            if pred_heatmap.ndim > 2:
                pred_heatmap = pred_heatmap.reshape(-1, pred_heatmap.shape[-1])

            # Create ground truth mask
            gt_mask = np.zeros_like(pred_heatmap)
            for region in gt_regions:
                if isinstance(region, (list, tuple)) and len(region) >= 4:
                    x, y, w, h = region[:4]
                    x, y, w, h = int(x), int(y), int(w), int(h)

                    # Ensure bounds
                    x = max(0, min(x, gt_mask.shape[1] - 1))
                    y = max(0, min(y, gt_mask.shape[0] - 1))
                    w = max(1, min(w, gt_mask.shape[1] - x))
                    h = max(1, min(h, gt_mask.shape[0] - y))

                    gt_mask[y:y+h, x:x+w] = 1

            # Threshold prediction heatmap
            pred_mask = (pred_heatmap > 0.5).astype(np.float32)

            # Calculate IoU
            intersection = np.logical_and(pred_mask, gt_mask).sum()
            union = np.logical_or(pred_mask, gt_mask).sum()

            return float(intersection / union) if union > 0 else 0.0
        except Exception as e:
            logger.warning(f"Grounding accuracy calculation failed: {e}")
            return 0.0

    def calculate_detection_accuracy(self, pred_objects: List, gt_objects: List) -> float:
        """Calculate object detection accuracy - IMPROVED"""
        if not gt_objects and not pred_objects:
            return 1.0
        if not gt_objects:
            return 0.0 if pred_objects else 1.0

        # Convert to lowercase and clean
        pred_set = set(str(obj).lower().strip() for obj in pred_objects if obj)
        gt_set = set(str(obj).lower().strip() for obj in gt_objects if obj)

        if not gt_set:
            return 1.0 if not pred_set else 0.0

        # Calculate Jaccard similarity (IoU for sets)
        intersection = len(pred_set.intersection(gt_set))
        union = len(pred_set.union(gt_set))

        return intersection / union if union > 0 else 0.0

    def calculate_f1(self, precision: float, recall: float) -> float:
        """Calculate F1 score"""
        if precision + recall == 0:
            return 0.0
        return 2 * (precision * recall) / (precision + recall)

    def calculate_overall_performance(self, results: Dict) -> Dict:
        """Calculate overall performance metrics - IMPROVED"""
        # Weight different aspects
        weights = {
            'language_quality': 0.4,    # Increased weight
            'cultural_relevance': 0.4,  # Increased weight
            'visual_grounding': 0.2     # Decreased weight (often no GT data)
        }

        # Calculate weighted average using multiple metrics
        overall_score = 0.0
        component_scores = {}

        for aspect, weight in weights.items():
            if aspect in results:
                if aspect == 'language_quality':
                    # Average of ROUGE-L and BLEU (ROUGE usually more reliable for Vietnamese)
                    rouge_l = results[aspect].get('rougeL', 0.0)
                    bleu = results[aspect].get('bleu', 0.0)
                    score = rouge_l * 0.7 + bleu * 0.3  # Weight ROUGE-L higher
                elif aspect == 'cultural_relevance':
                    # Average of multiple cultural metrics
                    cult_acc = results[aspect].get('cultural_accuracy', 0.0)
                    cult_f1 = results[aspect].get('cultural_f1', 0.0)
                    mention_acc = results[aspect].get('cultural_mention_accuracy', 0.0)
                    score = cult_acc * 0.4 + cult_f1 * 0.3 + mention_acc * 0.3
                elif aspect == 'visual_grounding':
                    # Average of grounding metrics
                    grounding = results[aspect].get('visual_grounding', 0.0)
                    detection = results[aspect].get('detection_accuracy', 0.0)
                    heatmap_q = results[aspect].get('heatmap_quality', 0.0)
                    score = grounding * 0.4 + detection * 0.4 + heatmap_q * 0.2

                component_scores[aspect] = score
                overall_score += weight * score

        return {
            'overall_score': overall_score,
            'component_scores': component_scores,
            'weights': weights
        }

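    # NOTE: A worked example of the weighting above (input numbers illustrative):
    # with rougeL = 0.30 and bleu = 0.10 the language component is 0.7*0.30 + 0.3*0.10 = 0.24;
    # with cultural_accuracy = 0.50, cultural_f1 = 0.40, mention_accuracy = 0.30 the cultural
    # component is 0.4*0.50 + 0.3*0.40 + 0.3*0.30 = 0.41; with grounding = detection = 0.25 and
    # heatmap_quality = 0.50 the visual component is 0.4*0.25 + 0.4*0.25 + 0.2*0.50 = 0.30.
    # The overall score is then 0.4*0.24 + 0.4*0.41 + 0.2*0.30 = 0.32.
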
    def generate_evaluation_report(self, results: Dict, save_path: Optional[str] = None) -> str:
        """Generate comprehensive evaluation report - IMPROVED"""
        report = f"""
VietMEAgent Evaluation Report
{'='*50}

Language Quality:
  BLEU Score: {results['language_quality']['bleu']:.4f}
  ROUGE-1: {results['language_quality']['rouge1']:.4f}
  ROUGE-2: {results['language_quality']['rouge2']:.4f}
  ROUGE-L: {results['language_quality']['rougeL']:.4f}
  Samples Evaluated: {results['language_quality']['num_evaluated']}

Cultural Relevance:
  Cultural Precision: {results['cultural_relevance']['cultural_precision']:.4f}
  Cultural Recall: {results['cultural_relevance']['cultural_recall']:.4f}
  Cultural F1: {results['cultural_relevance']['cultural_f1']:.4f}
  Cultural Accuracy: {results['cultural_relevance']['cultural_accuracy']:.4f}
  Cultural Mention Accuracy: {results['cultural_relevance']['cultural_mention_accuracy']:.4f}
  Cultural Samples: {results['cultural_relevance']['num_cultural_samples']}

Visual Grounding:
  Grounding Accuracy: {results['visual_grounding']['visual_grounding']:.4f}
  Detection Accuracy: {results['visual_grounding']['detection_accuracy']:.4f}
  Heatmap Quality: {results['visual_grounding']['heatmap_quality']:.4f}
  Grounding Samples: {results['visual_grounding']['num_grounding_samples']}
  Detection Samples: {results['visual_grounding']['num_detection_samples']}

Overall Performance:
  Overall Score: {results['overall_performance']['overall_score']:.4f}
  Component Scores: {results['overall_performance']['component_scores']}
{'='*50}
"""

        if save_path:
            with open(save_path, 'w', encoding='utf-8') as f:
                f.write(report)
            logger.info(f"Evaluation report saved to {save_path}")

        return report

    def plot_evaluation_results(self, results: Dict, save_path: Optional[str] = None):
        """Plot evaluation results - IMPROVED"""
        # Create subplots
        fig, axes = plt.subplots(2, 2, figsize=(15, 10))

        # Language Quality
        lang_metrics = ['bleu', 'rouge1', 'rouge2', 'rougeL']
        lang_scores = [results['language_quality'][m] for m in lang_metrics]
        axes[0, 0].bar(lang_metrics, lang_scores, color='skyblue')
        axes[0, 0].set_title('Language Quality Metrics')
        axes[0, 0].set_ylim(0, 1)
        axes[0, 0].tick_params(axis='x', rotation=45)

        # Cultural Relevance
        cult_metrics = ['cultural_precision', 'cultural_recall', 'cultural_f1', 'cultural_accuracy']
        cult_scores = [results['cultural_relevance'][m] for m in cult_metrics]
        axes[0, 1].bar(cult_metrics, cult_scores, color='lightcoral')
        axes[0, 1].set_title('Cultural Relevance Metrics')
        axes[0, 1].set_ylim(0, 1)
        axes[0, 1].tick_params(axis='x', rotation=45)

        # Visual Grounding
        visual_metrics = ['visual_grounding', 'detection_accuracy', 'heatmap_quality']
        visual_scores = [results['visual_grounding'][m] for m in visual_metrics]
        axes[1, 0].bar(visual_metrics, visual_scores, color='lightgreen')
        axes[1, 0].set_title('Visual Grounding Metrics')
        axes[1, 0].set_ylim(0, 1)
        axes[1, 0].tick_params(axis='x', rotation=45)

        # Overall comparison
        overall_metrics = ['Language Quality', 'Cultural Relevance', 'Visual Grounding']
        component_scores = results['overall_performance']['component_scores']
        overall_scores = [
            component_scores.get('language_quality', 0),
            component_scores.get('cultural_relevance', 0),
            component_scores.get('visual_grounding', 0)
        ]
        axes[1, 1].bar(overall_metrics, overall_scores, color='gold')
        axes[1, 1].set_title('Overall Performance Comparison')
        axes[1, 1].set_ylim(0, 1)
        axes[1, 1].tick_params(axis='x', rotation=45)

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            logger.info(f"Evaluation plots saved to {save_path}")

        plt.show()
        return fig

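# NOTE: Minimal usage sketch, not part of the original module. The file paths and the on-disk
# structure of predictions.json / ground_truth.json are assumptions; see the per-sample dict
# sketch above evaluate_batch for the fields the evaluator actually reads.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    # Hypothetical knowledge-base path
    evaluator = VietMEAgentEvaluator(cultural_kb_path="cultural_kb.json")

    # Hypothetical prediction / ground-truth dumps (lists of per-sample dicts)
    with open("predictions.json", "r", encoding="utf-8") as f:
        predictions = json.load(f)
    with open("ground_truth.json", "r", encoding="utf-8") as f:
        ground_truth = json.load(f)

    results = evaluator.evaluate_batch(predictions, ground_truth)
    print(evaluator.generate_evaluation_report(results, save_path="evaluation_report.txt"))
    evaluator.plot_evaluation_results(results, save_path="evaluation_plots.png")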