import json
import logging
import re
from typing import Dict, List, Optional

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_score import rouge_scorer
from sentence_transformers import SentenceTransformer

logger = logging.getLogger(__name__)


class VietMEAgentEvaluator:
    """Comprehensive evaluation for VietMEAgent - FIXED VERSION.

    Scores model predictions against ground truth along three axes:

    * language quality (BLEU + ROUGE over extracted Vietnamese text),
    * cultural relevance (precision/recall/F1 over cultural-term mentions,
      plus embedding-based context similarity),
    * visual grounding (heatmap quality, IoU vs. annotated regions, and
      detected-object overlap).

    Predictions and ground truth are dicts; the evaluator tolerates missing
    keys and simply skips samples that lack the relevant fields.
    """

    def __init__(self, cultural_kb_path: str):
        """Load the cultural knowledge base and initialize scoring tools.

        Args:
            cultural_kb_path: Path to a UTF-8 JSON file containing an
                ``objects`` mapping of cultural object names to metadata
                (each entry may carry a ``name`` display string).
        """
        # Load cultural knowledge for evaluation.
        with open(cultural_kb_path, 'r', encoding='utf-8') as f:
            self.cultural_kb = json.load(f)

        # Initialize evaluation tools. Stemming is disabled because the
        # ROUGE stemmer targets English and does not apply to Vietnamese.
        self.rouge_scorer = rouge_scorer.RougeScorer(
            ['rouge1', 'rouge2', 'rougeL'], use_stemmer=False
        )
        # Multilingual model so Vietnamese sentence embeddings are meaningful.
        self.sentence_model = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')
        self.smoothing = SmoothingFunction().method1

        # Cultural object vocabulary - EXPANDED: KB object keys, their
        # display names, and a hand-picked list of common cultural terms.
        self.cultural_vocabulary = set()
        for obj_name, obj_data in self.cultural_kb['objects'].items():
            self.cultural_vocabulary.add(obj_name.lower())
            # Add variations
            if 'name' in obj_data:
                self.cultural_vocabulary.add(obj_data['name'].lower())

        # Additional common Vietnamese cultural terms
        additional_terms = [
            'phở', 'bánh mì', 'áo dài', 'nón lá', 'chùa', 'đình', 'làng', 'thờ',
            'tết', 'trung thu', 'gỏi cuốn', 'bánh xèo', 'cà phê', 'trúc', 'tre',
            'đàn bầu', 'trống', 'sáo', 'múa lân', 'rối nước', 'việt nam'
        ]
        self.cultural_vocabulary.update(additional_terms)

        logger.info(f"Initialized evaluator with {len(self.cultural_vocabulary)} cultural terms")

    def evaluate_batch(self, predictions: List[Dict], ground_truth: List[Dict]) -> Dict:
        """Evaluate a batch of predictions.

        Args:
            predictions: Model outputs, aligned index-by-index with
                ``ground_truth``.
            ground_truth: Reference annotations.

        Returns:
            Dict with ``language_quality``, ``cultural_relevance``,
            ``visual_grounding`` and ``overall_performance`` sub-dicts.
        """
        logger.info(f"Evaluating {len(predictions)} predictions against {len(ground_truth)} ground truth")

        results = {
            'language_quality': {},
            'cultural_relevance': {},
            'visual_grounding': {},
            'overall_performance': {}
        }

        # Language quality metrics
        results['language_quality'] = self.evaluate_language_quality(predictions, ground_truth)

        # Cultural relevance metrics
        results['cultural_relevance'] = self.evaluate_cultural_relevance(predictions, ground_truth)

        # Visual grounding metrics
        results['visual_grounding'] = self.evaluate_visual_grounding(predictions, ground_truth)

        # Overall performance
        results['overall_performance'] = self.calculate_overall_performance(results)

        # Debug metrics
        self.debug_evaluation_results(results, predictions, ground_truth)

        return results

    def debug_evaluation_results(self, results: Dict, predictions: List[Dict],
                                 ground_truth: List[Dict]):
        """Log a sample prediction/ground-truth pair for manual inspection."""
        logger.info("=== EVALUATION DEBUG ===")

        # Sample text comparison (first pair only).
        if predictions and ground_truth:
            pred_text = self.extract_text_from_prediction(predictions[0])
            gt_text = self.extract_text_from_ground_truth(ground_truth[0])

            logger.info(f"Sample prediction text: {pred_text[:100]}...")
            logger.info(f"Sample ground truth text: {gt_text[:100]}...")

            # Cultural objects
            pred_cultural = self.extract_cultural_objects(predictions[0])
            gt_cultural = self.extract_cultural_objects(ground_truth[0])

            logger.info(f"Pred cultural objects: {pred_cultural}")
            logger.info(f"GT cultural objects: {gt_cultural}")

        logger.info("=== END DEBUG ===")

    def evaluate_language_quality(self, predictions: List[Dict],
                                  ground_truth: List[Dict]) -> Dict:
        """Evaluate language quality using BLEU and ROUGE - IMPROVED.

        Returns:
            Dict with mean ``bleu``, ``rouge1``, ``rouge2``, ``rougeL``
            scores and ``num_evaluated`` (pairs with usable text on both
            sides). All scores are 0.0 when nothing could be compared.
        """
        bleu_scores = []
        rouge_scores = {'rouge1': [], 'rouge2': [], 'rougeL': []}
        valid_comparisons = 0

        for pred, gt in zip(predictions, ground_truth):
            # Extract text for comparison - IMPROVED
            pred_text = self.extract_text_from_prediction(pred)
            gt_text = self.extract_text_from_ground_truth(gt)

            if pred_text and gt_text:
                # Clean and normalize text
                pred_clean = self.clean_vietnamese_text(pred_text)
                gt_clean = self.clean_vietnamese_text(gt_text)

                if pred_clean and gt_clean:
                    valid_comparisons += 1

                    # BLEU score - IMPROVED tokenization
                    pred_tokens = self.tokenize_vietnamese(pred_clean)
                    gt_tokens = self.tokenize_vietnamese(gt_clean)

                    if pred_tokens and gt_tokens:
                        # Use multiple references for a less brittle BLEU:
                        # the full reference plus first/last-word-dropped
                        # variants (only when the reference is long enough).
                        references = [gt_tokens]
                        if len(gt_tokens) > 3:
                            references.append(gt_tokens[:-1])  # Remove last word
                            references.append(gt_tokens[1:])   # Remove first word

                        bleu = sentence_bleu(
                            references,
                            pred_tokens,
                            smoothing_function=self.smoothing,
                            # Give more weight to unigrams and bigrams
                            weights=(0.5, 0.3, 0.2)
                        )
                        bleu_scores.append(bleu)

                    # ROUGE scores. RougeScorer.score takes (target,
                    # prediction) in that order; only fmeasure is read, which
                    # is symmetric, but the correct order keeps precision and
                    # recall meaningful if they are ever used.
                    try:
                        rouge_result = self.rouge_scorer.score(gt_clean, pred_clean)
                        for metric in rouge_scores:
                            rouge_scores[metric].append(rouge_result[metric].fmeasure)
                    except Exception as e:
                        logger.warning(f"ROUGE calculation failed: {e}")

        logger.info(f"Language quality: {valid_comparisons} valid comparisons out of {len(predictions)}")

        return {
            'bleu': np.mean(bleu_scores) if bleu_scores else 0.0,
            'rouge1': np.mean(rouge_scores['rouge1']) if rouge_scores['rouge1'] else 0.0,
            'rouge2': np.mean(rouge_scores['rouge2']) if rouge_scores['rouge2'] else 0.0,
            'rougeL': np.mean(rouge_scores['rougeL']) if rouge_scores['rougeL'] else 0.0,
            'num_evaluated': valid_comparisons
        }

    def clean_vietnamese_text(self, text: str) -> str:
        """Clean and normalize Vietnamese text.

        Lowercases, collapses whitespace, and strips characters other than
        word characters, whitespace, and Vietnamese diacritic letters.
        Returns "" for falsy input.
        """
        if not text:
            return ""

        # Convert to lowercase
        text = text.lower()

        # Remove extra whitespace
        text = re.sub(r'\s+', ' ', text).strip()

        # Remove special characters but keep Vietnamese diacritics
        text = re.sub(
            r'[^\w\sàáạảãâầấậẩẫăằắặẳẵèéẹẻẽêềếệểễìíịỉĩòóọỏõôồốộổỗơờớợởỡùúụủũưừứựửữỳýỵỷỹđ]',
            '', text
        )

        return text

    def tokenize_vietnamese(self, text: str) -> List[str]:
        """Tokenize Vietnamese text.

        Simple whitespace tokenization; single-character tokens are dropped
        as noise. Returns [] for falsy input.
        """
        if not text:
            return []

        # Simple word-based tokenization
        tokens = text.split()

        # Filter out very short tokens
        tokens = [t for t in tokens if len(t) > 1]

        return tokens

    def evaluate_cultural_relevance(self, predictions: List[Dict],
                                    ground_truth: List[Dict]) -> Dict:
        """Evaluate cultural relevance of predictions - IMPROVED.

        Combines explicit cultural-object fields with vocabulary mentions
        found in free text, then computes precision/recall/F1, mention
        overlap accuracy, and semantic context accuracy.
        """
        cultural_precision = []
        cultural_recall = []
        cultural_accuracy = []
        cultural_mentions = []

        for pred, gt in zip(predictions, ground_truth):
            # Extract cultural objects - IMPROVED
            pred_cultural = self.extract_cultural_objects(pred)
            gt_cultural = self.extract_cultural_objects(gt)

            # Count cultural mentions in text
            pred_text = self.extract_text_from_prediction(pred)
            gt_text = self.extract_text_from_ground_truth(gt)

            pred_mentions = self.count_cultural_mentions(pred_text)
            gt_mentions = self.count_cultural_mentions(gt_text)

            cultural_mentions.append({
                'pred_mentions': pred_mentions,
                'gt_mentions': gt_mentions,
                'mention_overlap': len(set(pred_mentions).intersection(set(gt_mentions)))
            })

            # If we have ground truth cultural objects
            if gt_cultural or gt_mentions:
                all_gt_cultural = gt_cultural.union(set(gt_mentions))
                all_pred_cultural = pred_cultural.union(set(pred_mentions))

                if all_pred_cultural:
                    precision = len(all_pred_cultural.intersection(all_gt_cultural)) / len(all_pred_cultural)
                    cultural_precision.append(precision)

                if all_gt_cultural:
                    recall = len(all_pred_cultural.intersection(all_gt_cultural)) / len(all_gt_cultural)
                    cultural_recall.append(recall)

            # Cultural context accuracy using semantic similarity
            if pred_text and gt_text:
                cultural_acc = self.evaluate_cultural_context_accuracy(pred, gt)
                cultural_accuracy.append(cultural_acc)

        # Calculate cultural mention accuracy: total overlap over total
        # ground-truth mentions across the whole batch (micro-averaged).
        mention_accuracy = 0.0
        if cultural_mentions:
            total_overlap = sum(m['mention_overlap'] for m in cultural_mentions)
            total_gt_mentions = sum(len(m['gt_mentions']) for m in cultural_mentions)
            mention_accuracy = total_overlap / total_gt_mentions if total_gt_mentions > 0 else 0.0

        return {
            'cultural_precision': np.mean(cultural_precision) if cultural_precision else 0.0,
            'cultural_recall': np.mean(cultural_recall) if cultural_recall else 0.0,
            'cultural_accuracy': np.mean(cultural_accuracy) if cultural_accuracy else 0.0,
            'cultural_mention_accuracy': mention_accuracy,
            'cultural_f1': self.calculate_f1(
                np.mean(cultural_precision) if cultural_precision else 0.0,
                np.mean(cultural_recall) if cultural_recall else 0.0
            ),
            'num_cultural_samples': len(cultural_mentions)
        }

    def count_cultural_mentions(self, text: str) -> List[str]:
        """Return the cultural vocabulary terms found (substring match) in text."""
        if not text:
            return []

        text_lower = text.lower()
        mentions = []

        for cultural_term in self.cultural_vocabulary:
            if cultural_term in text_lower:
                mentions.append(cultural_term)

        return mentions

    def evaluate_visual_grounding(self, predictions: List[Dict],
                                  ground_truth: List[Dict]) -> Dict:
        """Evaluate visual grounding accuracy - IMPROVED.

        Uses the prediction's attention heatmap (quality heuristics plus IoU
        against annotated regions when available) and the overlap between
        predicted and ground-truth cultural object lists.
        """
        grounding_scores = []
        detection_accuracy = []
        heatmap_quality = []

        for pred, gt in zip(predictions, ground_truth):
            # Heatmap-based grounding evaluation
            if 'heatmap' in pred:
                heatmap = np.array(pred['heatmap']) if isinstance(pred['heatmap'], list) else pred['heatmap']

                # Basic heatmap quality metrics. All grounding scoring is
                # kept inside this guard so quality_score is always defined
                # (an empty heatmap previously risked a NameError in the
                # no-ground-truth fallback branch).
                if heatmap.size > 0:
                    concentration = np.std(heatmap)
                    coverage = np.mean(heatmap > 0.3)
                    max_attention = np.max(heatmap)

                    # Simple quality score
                    quality_score = min(1.0, (concentration * 2 + coverage + max_attention) / 3)
                    heatmap_quality.append(quality_score)

                    # If we have ground truth regions, calculate IoU
                    if 'attention_regions' in gt:
                        iou = self.calculate_grounding_accuracy(heatmap, gt['attention_regions'])
                        grounding_scores.append(iou)
                    else:
                        # Use heatmap quality as proxy for grounding
                        grounding_scores.append(quality_score * 0.5)  # Lower weight without GT

            # Object detection accuracy: objects may live at the top level
            # or nested under image_analysis on either side.
            pred_objects = []
            if 'image_analysis' in pred and 'cultural_objects' in pred['image_analysis']:
                pred_objects = pred['image_analysis']['cultural_objects']
            elif 'cultural_objects' in pred:
                pred_objects = pred['cultural_objects']

            gt_objects = []
            if 'image_analysis' in gt and 'cultural_objects' in gt['image_analysis']:
                gt_objects = gt['image_analysis']['cultural_objects']
            elif 'cultural_objects' in gt:
                gt_objects = gt['cultural_objects']

            if gt_objects or pred_objects:
                detection_acc = self.calculate_detection_accuracy(pred_objects, gt_objects)
                detection_accuracy.append(detection_acc)

        return {
            'visual_grounding': np.mean(grounding_scores) if grounding_scores else 0.0,
            'detection_accuracy': np.mean(detection_accuracy) if detection_accuracy else 0.0,
            'heatmap_quality': np.mean(heatmap_quality) if heatmap_quality else 0.0,
            'num_grounding_samples': len(grounding_scores),
            'num_detection_samples': len(detection_accuracy)
        }

    def extract_text_from_prediction(self, prediction: Dict) -> str:
        """Extract text from prediction for evaluation - IMPROVED.

        Concatenates question explanations/answers/questions, the optional
        ``vietnamese_explanation``, and any ``image_analysis.vietnamese_text``
        entries into a single space-joined string.
        """
        texts = []

        # Extract from questions
        if 'questions' in prediction:
            for q in prediction['questions']:
                if 'explanation' in q and q['explanation']:
                    texts.append(str(q['explanation']))
                if 'answer' in q and q['answer']:
                    texts.append(str(q['answer']))
                if 'question' in q and q['question']:
                    texts.append(str(q['question']))

        # Extract from vietnamese_explanation
        if 'vietnamese_explanation' in prediction and prediction['vietnamese_explanation']:
            texts.append(str(prediction['vietnamese_explanation']))

        # Extract from image analysis
        if 'image_analysis' in prediction:
            analysis = prediction['image_analysis']
            if 'vietnamese_text' in analysis:
                texts.extend([str(t) for t in analysis['vietnamese_text'] if t])

        return ' '.join(texts)

    def extract_text_from_ground_truth(self, ground_truth: Dict) -> str:
        """Extract text from ground truth for evaluation - IMPROVED.

        Same as the prediction extractor but without the
        ``vietnamese_explanation`` field (not present in ground truth).
        """
        texts = []

        # Extract from questions
        if 'questions' in ground_truth:
            for q in ground_truth['questions']:
                if 'explanation' in q and q['explanation']:
                    texts.append(str(q['explanation']))
                if 'answer' in q and q['answer']:
                    texts.append(str(q['answer']))
                if 'question' in q and q['question']:
                    texts.append(str(q['question']))

        # Extract from image analysis
        if 'image_analysis' in ground_truth:
            analysis = ground_truth['image_analysis']
            if 'vietnamese_text' in analysis:
                texts.extend([str(t) for t in analysis['vietnamese_text'] if t])

        return ' '.join(texts)

    def extract_cultural_objects(self, data: Dict) -> set:
        """Extract cultural objects mentioned in data - IMPROVED.

        Combines vocabulary terms found in the extracted text with any
        explicit ``cultural_objects`` fields (top level or under
        ``image_analysis``). All entries are lowercased.
        """
        cultural_objects = set()

        # Get all text from the data. The presence of 'questions' is used to
        # decide which extractor matches the dict's shape.
        text = ""
        if 'questions' in data:
            text = self.extract_text_from_prediction(data)
        else:
            text = self.extract_text_from_ground_truth(data)

        text_lower = text.lower()

        # Find cultural terms in text
        for cultural_term in self.cultural_vocabulary:
            if cultural_term in text_lower:
                cultural_objects.add(cultural_term)

        # Also check explicit cultural_objects fields
        if 'cultural_objects' in data:
            for obj in data['cultural_objects']:
                cultural_objects.add(str(obj).lower())

        if 'image_analysis' in data and 'cultural_objects' in data['image_analysis']:
            for obj in data['image_analysis']['cultural_objects']:
                cultural_objects.add(str(obj).lower())

        return cultural_objects

    def evaluate_cultural_context_accuracy(self, prediction: Dict,
                                           ground_truth: Dict) -> float:
        """Evaluate accuracy of cultural context understanding - IMPROVED.

        Returns the (non-negative) cosine similarity between sentence
        embeddings of the cleaned prediction and ground-truth texts;
        0.0 when either side has no usable text or encoding fails.
        """
        # Extract cultural explanations
        pred_text = self.extract_text_from_prediction(prediction)
        gt_text = self.extract_text_from_ground_truth(ground_truth)

        if not pred_text or not gt_text:
            return 0.0

        # Clean texts
        pred_clean = self.clean_vietnamese_text(pred_text)
        gt_clean = self.clean_vietnamese_text(gt_text)

        if not pred_clean or not gt_clean:
            return 0.0

        try:
            # Use semantic similarity for cultural context evaluation
            pred_embedding = self.sentence_model.encode([pred_clean])
            gt_embedding = self.sentence_model.encode([gt_clean])

            # Calculate cosine similarity
            similarity = np.dot(pred_embedding[0], gt_embedding[0]) / (
                np.linalg.norm(pred_embedding[0]) * np.linalg.norm(gt_embedding[0])
            )

            return max(0.0, float(similarity))  # Ensure non-negative
        except Exception as e:
            logger.warning(f"Cultural context accuracy calculation failed: {e}")
            return 0.0

    def calculate_grounding_accuracy(self, pred_heatmap: np.ndarray,
                                     gt_regions: List) -> float:
        """Calculate visual grounding accuracy.

        Builds a binary mask from the (x, y, w, h) ground-truth regions,
        thresholds the heatmap at 0.5, and returns the IoU of the two masks.
        Returns 0.0 for empty inputs or on failure.
        """
        if len(gt_regions) == 0 or pred_heatmap.size == 0:
            return 0.0

        try:
            # Ensure heatmap is 2D (collapse any leading dims into rows).
            if pred_heatmap.ndim > 2:
                pred_heatmap = pred_heatmap.reshape(-1, pred_heatmap.shape[-1])

            # Create ground truth mask
            gt_mask = np.zeros_like(pred_heatmap)
            for region in gt_regions:
                if isinstance(region, (list, tuple)) and len(region) >= 4:
                    x, y, w, h = region[:4]
                    x, y, w, h = int(x), int(y), int(w), int(h)

                    # Clamp the box to the heatmap bounds; w/h keep at
                    # least one pixel so a degenerate box still marks a cell.
                    x = max(0, min(x, gt_mask.shape[1] - 1))
                    y = max(0, min(y, gt_mask.shape[0] - 1))
                    w = max(1, min(w, gt_mask.shape[1] - x))
                    h = max(1, min(h, gt_mask.shape[0] - y))

                    gt_mask[y:y+h, x:x+w] = 1

            # Threshold prediction heatmap
            pred_mask = (pred_heatmap > 0.5).astype(np.float32)

            # Calculate IoU
            intersection = np.logical_and(pred_mask, gt_mask).sum()
            union = np.logical_or(pred_mask, gt_mask).sum()

            return float(intersection / union) if union > 0 else 0.0
        except Exception as e:
            logger.warning(f"Grounding accuracy calculation failed: {e}")
            return 0.0

    def calculate_detection_accuracy(self, pred_objects: List, gt_objects: List) -> float:
        """Calculate object detection accuracy - IMPROVED.

        Jaccard similarity (set IoU) between lowercased object names.
        Both-empty counts as a perfect match.
        """
        if not gt_objects and not pred_objects:
            return 1.0
        if not gt_objects:
            return 0.0 if pred_objects else 1.0

        # Convert to lowercase and clean
        pred_set = set(str(obj).lower().strip() for obj in pred_objects if obj)
        gt_set = set(str(obj).lower().strip() for obj in gt_objects if obj)

        if not gt_set:
            return 1.0 if not pred_set else 0.0

        # Calculate Jaccard similarity (IoU for sets)
        intersection = len(pred_set.intersection(gt_set))
        union = len(pred_set.union(gt_set))

        return intersection / union if union > 0 else 0.0

    def calculate_f1(self, precision: float, recall: float) -> float:
        """Calculate F1 score (0.0 when both inputs are zero)."""
        if precision + recall == 0:
            return 0.0
        return 2 * (precision * recall) / (precision + recall)

    def calculate_overall_performance(self, results: Dict) -> Dict:
        """Calculate overall performance metrics - IMPROVED.

        Each aspect is collapsed to a single score via fixed intra-aspect
        weights, then combined with the inter-aspect weights below.
        """
        # Weight different aspects
        weights = {
            'language_quality': 0.4,   # Increased weight
            'cultural_relevance': 0.4, # Increased weight
            'visual_grounding': 0.2    # Decreased weight (often no GT data)
        }

        # Calculate weighted average using multiple metrics
        overall_score = 0.0
        component_scores = {}

        for aspect, weight in weights.items():
            if aspect in results:
                if aspect == 'language_quality':
                    # Average of ROUGE-L and BLEU (ROUGE usually more reliable for Vietnamese)
                    rouge_l = results[aspect].get('rougeL', 0.0)
                    bleu = results[aspect].get('bleu', 0.0)
                    score = (rouge_l * 0.7 + bleu * 0.3)  # Weight ROUGE-L higher
                elif aspect == 'cultural_relevance':
                    # Average of multiple cultural metrics
                    cult_acc = results[aspect].get('cultural_accuracy', 0.0)
                    cult_f1 = results[aspect].get('cultural_f1', 0.0)
                    mention_acc = results[aspect].get('cultural_mention_accuracy', 0.0)
                    score = (cult_acc * 0.4 + cult_f1 * 0.3 + mention_acc * 0.3)
                elif aspect == 'visual_grounding':
                    # Average of grounding metrics
                    grounding = results[aspect].get('visual_grounding', 0.0)
                    detection = results[aspect].get('detection_accuracy', 0.0)
                    heatmap_q = results[aspect].get('heatmap_quality', 0.0)
                    score = (grounding * 0.4 + detection * 0.4 + heatmap_q * 0.2)

                component_scores[aspect] = score
                overall_score += weight * score

        return {
            'overall_score': overall_score,
            'component_scores': component_scores,
            'weights': weights
        }

    def generate_evaluation_report(self, results: Dict,
                                   save_path: Optional[str] = None) -> str:
        """Generate comprehensive evaluation report - IMPROVED.

        Args:
            results: Output of :meth:`evaluate_batch`.
            save_path: When given, the report text is also written there
                (UTF-8).

        Returns:
            The formatted report string.
        """
        report = f"""
VietMEAgent Evaluation Report
{'='*50}

Language Quality:
  BLEU Score: {results['language_quality']['bleu']:.4f}
  ROUGE-1: {results['language_quality']['rouge1']:.4f}
  ROUGE-2: {results['language_quality']['rouge2']:.4f}
  ROUGE-L: {results['language_quality']['rougeL']:.4f}
  Samples Evaluated: {results['language_quality']['num_evaluated']}

Cultural Relevance:
  Cultural Precision: {results['cultural_relevance']['cultural_precision']:.4f}
  Cultural Recall: {results['cultural_relevance']['cultural_recall']:.4f}
  Cultural F1: {results['cultural_relevance']['cultural_f1']:.4f}
  Cultural Accuracy: {results['cultural_relevance']['cultural_accuracy']:.4f}
  Cultural Mention Accuracy: {results['cultural_relevance']['cultural_mention_accuracy']:.4f}
  Cultural Samples: {results['cultural_relevance']['num_cultural_samples']}

Visual Grounding:
  Grounding Accuracy: {results['visual_grounding']['visual_grounding']:.4f}
  Detection Accuracy: {results['visual_grounding']['detection_accuracy']:.4f}
  Heatmap Quality: {results['visual_grounding']['heatmap_quality']:.4f}
  Grounding Samples: {results['visual_grounding']['num_grounding_samples']}
  Detection Samples: {results['visual_grounding']['num_detection_samples']}

Overall Performance:
  Overall Score: {results['overall_performance']['overall_score']:.4f}
  Component Scores: {results['overall_performance']['component_scores']}

{'='*50}
"""

        if save_path:
            with open(save_path, 'w', encoding='utf-8') as f:
                f.write(report)
            logger.info(f"Evaluation report saved to {save_path}")

        return report

    def plot_evaluation_results(self, results: Dict, save_path: Optional[str] = None):
        """Plot evaluation results - IMPROVED.

        Draws a 2x2 grid of bar charts (language, cultural, visual, overall),
        optionally saving the figure, and returns the matplotlib Figure.
        """
        # Create subplots
        fig, axes = plt.subplots(2, 2, figsize=(15, 10))

        # Language Quality
        lang_metrics = ['bleu', 'rouge1', 'rouge2', 'rougeL']
        lang_scores = [results['language_quality'][m] for m in lang_metrics]
        axes[0, 0].bar(lang_metrics, lang_scores, color='skyblue')
        axes[0, 0].set_title('Language Quality Metrics')
        axes[0, 0].set_ylim(0, 1)
        axes[0, 0].tick_params(axis='x', rotation=45)

        # Cultural Relevance
        cult_metrics = ['cultural_precision', 'cultural_recall', 'cultural_f1', 'cultural_accuracy']
        cult_scores = [results['cultural_relevance'][m] for m in cult_metrics]
        axes[0, 1].bar(cult_metrics, cult_scores, color='lightcoral')
        axes[0, 1].set_title('Cultural Relevance Metrics')
        axes[0, 1].set_ylim(0, 1)
        axes[0, 1].tick_params(axis='x', rotation=45)

        # Visual Grounding
        visual_metrics = ['visual_grounding', 'detection_accuracy', 'heatmap_quality']
        visual_scores = [results['visual_grounding'][m] for m in visual_metrics]
        axes[1, 0].bar(visual_metrics, visual_scores, color='lightgreen')
        axes[1, 0].set_title('Visual Grounding Metrics')
        axes[1, 0].set_ylim(0, 1)
        axes[1, 0].tick_params(axis='x', rotation=45)

        # Overall comparison
        overall_metrics = ['Language Quality', 'Cultural Relevance', 'Visual Grounding']
        component_scores = results['overall_performance']['component_scores']
        overall_scores = [
            component_scores.get('language_quality', 0),
            component_scores.get('cultural_relevance', 0),
            component_scores.get('visual_grounding', 0)
        ]
        axes[1, 1].bar(overall_metrics, overall_scores, color='gold')
        axes[1, 1].set_title('Overall Performance Comparison')
        axes[1, 1].set_ylim(0, 1)
        axes[1, 1].tick_params(axis='x', rotation=45)

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            logger.info(f"Evaluation plots saved to {save_path}")

        plt.show()

        return fig