Spaces:
Sleeping
Sleeping
feat: establish Quantum-Enhanced CST project with core components, training pipelines, and evaluation utilities, and update README.md.
94c2e42
# CST / QCST Dual License
# Non-commercial research use only.
# Commercial use requires explicit permission.
# Copyright (c) 2025 Mohamed Mohamed Elhelbawi
# All rights reserved.
# See LICENSE file in the project root for full license information.
"""
Evaluation Framework for CST Models
Comprehensive benchmarking suite for semantic disambiguation, efficiency, and multimodal tasks
"""
import json
import logging
import time
import zlib
from collections import defaultdict
from typing import Dict, List, Optional, Any, Tuple

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn.functional as F
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix
from torch.utils.data import DataLoader

from config import CSTConfig
from cst_transformer import CSTransformer
# Module-level logger; handler and level configuration is left to the caller.
logger = logging.getLogger(__name__)
class PerformanceProfiler:
    """Record wall-clock timings and CUDA memory deltas for named operations.

    Typical use::

        profiler = PerformanceProfiler()
        with profiler.time_operation("inference"):
            run_model()
        stats = profiler.get_stats()
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Discard all recorded timings, memory samples and call counters."""
        self.timings = defaultdict(list)       # operation name -> elapsed seconds
        self.memory_usage = defaultdict(list)  # operation name -> allocated-byte deltas
        self.counters = defaultdict(int)       # operation name -> completed runs

    def time_operation(self, operation_name: str):
        """Return a context manager that records one run of ``operation_name``."""
        return self.TimingContext(self, operation_name)

    class TimingContext:
        """Context manager that measures the body of a single ``with`` block."""

        def __init__(self, profiler, operation_name):
            self.profiler = profiler
            self.operation_name = operation_name
            self.start_time = None
            self.start_memory = None

        def __enter__(self):
            if torch.cuda.is_available():
                # Drain previously queued GPU work *before* starting the clock
                # so it is not attributed to this operation.
                torch.cuda.synchronize()
                self.start_memory = torch.cuda.memory_allocated()
            # perf_counter is monotonic, unlike time.time(), so the elapsed
            # value can never be negative after a system clock adjustment.
            self.start_time = time.perf_counter()
            return self

        def __exit__(self, exc_type, exc_val, exc_tb):
            cuda_available = torch.cuda.is_available()
            if cuda_available:
                # Wait for the kernels launched inside the block to finish.
                torch.cuda.synchronize()
            elapsed = time.perf_counter() - self.start_time

            if cuda_available:
                memory_diff = torch.cuda.memory_allocated() - (self.start_memory or 0)
                self.profiler.memory_usage[self.operation_name].append(memory_diff)

            self.profiler.timings[self.operation_name].append(elapsed)
            self.profiler.counters[self.operation_name] += 1
            # Returning None lets any exception raised in the block propagate.

    def get_stats(self) -> Dict[str, Any]:
        """Summarise everything recorded since the last :meth:`reset`.

        Returns:
            A dict with one ``"<operation>_timing"`` entry per timed operation
            (mean, std, min, max, p50, p95, p99, total in seconds, plus count)
            and one ``"<operation>_memory"`` entry (mean/max/total in MiB) for
            every operation that has memory samples.
        """
        bytes_per_mib = 1024 * 1024
        stats: Dict[str, Any] = {}

        for op_name, times in self.timings.items():
            stats[f"{op_name}_timing"] = {
                'mean': float(np.mean(times)),
                'std': float(np.std(times)),
                'min': float(np.min(times)),
                'max': float(np.max(times)),
                'p50': float(np.percentile(times, 50)),
                'p95': float(np.percentile(times, 95)),
                'p99': float(np.percentile(times, 99)),
                'total': float(np.sum(times)),
                'count': len(times),
            }

        for op_name, memory in self.memory_usage.items():
            if memory:
                stats[f"{op_name}_memory"] = {
                    'mean_mb': float(np.mean(memory)) / bytes_per_mib,
                    'max_mb': float(np.max(memory)) / bytes_per_mib,
                    'total_mb': float(np.sum(memory)) / bytes_per_mib,
                }

        return stats
class WordSenseDisambiguationEvaluator:
    """Evaluator for Word Sense Disambiguation (WSD) tasks.

    NOTE: the data loader and the sense embeddings used here are placeholders
    (synthetic sentences, random sense vectors), so the reported accuracy only
    exercises the pipeline; it is not a meaningful benchmark score.
    """

    def __init__(self, model: "CSTransformer", config: "CSTConfig"):
        self.model = model
        self.config = config
        self.profiler = PerformanceProfiler()

    def evaluate_on_semeval(self, dataset_path: str) -> Dict[str, Any]:
        """Run the model over the WSD test set and aggregate the metrics.

        Args:
            dataset_path: Forwarded to :meth:`_load_semeval_data`.

        Returns:
            Dict with overall accuracy, weighted F1, per-word accuracy,
            accuracy per context-length bucket, mean prediction confidence,
            profiler statistics and the number of samples.
        """
        test_data = self._load_semeval_data(dataset_path)

        predictions: List[str] = []
        ground_truth: List[str] = []
        ambiguous_words: List[str] = []
        context_lengths: List[int] = []
        confidences: List[float] = []

        self.model.eval()
        with torch.no_grad():
            for item in test_data:
                with self.profiler.time_operation('wsd_inference'):
                    prediction, confidence = self._predict_word_sense(
                        item['sentence'],
                        item['target_word'],
                        item['target_position'],
                        item['sense_candidates'],
                    )
                predictions.append(prediction)
                ground_truth.append(item['correct_sense'])
                ambiguous_words.append(item['target_word'])
                context_lengths.append(len(item['sentence'].split()))
                confidences.append(confidence)

        return {
            'overall_accuracy': accuracy_score(ground_truth, predictions),
            'weighted_f1': f1_score(ground_truth, predictions, average='weighted'),
            'per_word_accuracy': self._compute_per_word_accuracy(
                ambiguous_words, ground_truth, predictions
            ),
            'context_length_analysis': self._analyze_by_context_length(
                context_lengths, ground_truth, predictions
            ),
            'mean_confidence': float(np.mean(confidences)) if confidences else 0.0,
            'performance_stats': self.profiler.get_stats(),
            'num_samples': len(test_data),
        }

    def _load_semeval_data(self, dataset_path: str) -> List[Dict[str, Any]]:
        """Return 200 synthetic WSD samples.

        Placeholder: ``dataset_path`` is currently ignored; replace this with a
        real SemEval parser. Samples are drawn with the global NumPy RNG.
        """
        senses = {
            'bank': ['financial_institution', 'river_side'],
            'plant': ['factory', 'vegetation'],
            'scale': ['measurement', 'fish_covering'],
            'rock': ['stone', 'music_genre'],
            'bark': ['dog_sound', 'tree_covering'],
            'crown': ['royal_headwear', 'tooth_covering'],
            'mouse': ['computer_device', 'animal'],
            'bat': ['sports_equipment', 'flying_mammal'],
        }
        ambiguous_words = ['bank', 'plant', 'scale', 'rock', 'bark', 'crown', 'mouse', 'bat']

        # Sense-specific sentence templates; every other (word, sense) pair
        # falls back to the generic template below.
        templates = {
            ('bank', 'financial_institution'):
                "I went to the {word} to deposit money and check my account balance.",
            ('bank', 'river_side'):
                "We sat by the {word} of the river watching the sunset.",
            ('plant', 'factory'):
                "The manufacturing {word} operates 24 hours a day.",
            ('plant', 'vegetation'):
                "This {word} needs water and sunlight to grow properly.",
        }
        generic_template = "The {word} is important in this context for disambiguation."

        synthetic_data = []
        for _ in range(200):
            # np.random.choice returns np.str_; convert so plain strings flow
            # through the metrics and any later JSON export.
            word = str(np.random.choice(ambiguous_words))
            sense_candidates = list(senses[word])
            correct_sense = str(np.random.choice(sense_candidates))

            template = templates.get((word, correct_sense), generic_template)
            sentence = template.format(word=word)

            synthetic_data.append({
                'sentence': sentence,
                'target_word': word,
                'target_position': sentence.split().index(word),
                'sense_candidates': sense_candidates,
                'correct_sense': correct_sense,
            })
        return synthetic_data

    def _encode_words(self, words: List[str]) -> torch.Tensor:
        """Map whitespace tokens to ids in ``[0, vocab_size)`` as a 1 x N tensor.

        CRC32 is used instead of the built-in ``hash`` because string hashes
        are salted per interpreter process, which made token ids (and thus the
        evaluation) change from run to run.
        """
        vocab_size = self.config.vocab_size
        token_ids = [zlib.crc32(word.encode('utf-8')) % vocab_size for word in words]
        return torch.tensor([token_ids], dtype=torch.long)

    def _predict_word_sense(self, sentence: str, target_word: str,
                            target_position: int,
                            sense_candidates: List[str]) -> Tuple[str, float]:
        """Pick the most likely sense for the token at ``target_position``.

        Returns:
            ``(predicted_sense, confidence)`` where confidence is the softmax
            probability of the winning candidate over the similarity scores.
        """
        input_ids = self._encode_words(sentence.split())

        # Neutral context: random document embedding, zeroed metadata.
        context_data = {
            'document_embedding': torch.randn(1, self.config.raw_doc_dim),
            'metadata': {
                'author': torch.tensor([0]),
                'domain': torch.tensor([0]),
                'timestamp': torch.tensor([0.0]),
            },
        }

        outputs = self.model(input_ids, context_data)
        # Assumes hidden_states is (batch, seq, d_model) - TODO confirm against
        # CSTransformer.
        target_repr = outputs['hidden_states'][0, target_position]

        # Placeholder: each candidate gets a fresh random embedding; swap in
        # pre-trained sense embeddings for a meaningful prediction.
        sense_scores = [
            F.cosine_similarity(
                target_repr, torch.randn(self.config.d_model), dim=0
            ).item()
            for _ in sense_candidates
        ]

        best_sense_idx = int(np.argmax(sense_scores))
        confidence = torch.softmax(torch.tensor(sense_scores), dim=0)[best_sense_idx].item()
        return sense_candidates[best_sense_idx], confidence

    def _compute_per_word_accuracy(self, words: List[str],
                                   ground_truth: List[str],
                                   predictions: List[str]) -> Dict[str, float]:
        """Return the fraction of correct predictions for each ambiguous word."""
        totals = defaultdict(int)
        correct = defaultdict(int)
        for word, gt, pred in zip(words, ground_truth, predictions):
            totals[word] += 1
            if gt == pred:
                correct[word] += 1
        return {word: correct[word] / total for word, total in totals.items()}

    def _analyze_by_context_length(self, context_lengths: List[int],
                                   ground_truth: List[str],
                                   predictions: List[str]) -> Dict[str, float]:
        """Return accuracy per sentence-length bucket; empty buckets are omitted."""
        length_buckets = [(0, 10), (10, 20), (20, 30), (30, float('inf'))]
        bucket_results = {}
        for min_len, max_len in length_buckets:
            selected = [
                (gt, pred)
                for length, gt, pred in zip(context_lengths, ground_truth, predictions)
                if min_len <= length < max_len
            ]
            if not selected:
                continue
            bucket_gt = [gt for gt, _ in selected]
            bucket_pred = [pred for _, pred in selected]
            bucket_name = f"{min_len}-{max_len if max_len != float('inf') else '∞'}"
            bucket_results[bucket_name] = accuracy_score(bucket_gt, bucket_pred)
        return bucket_results
class EfficiencyEvaluator:
    """Evaluator for computational efficiency (latency and cache behaviour)."""

    def __init__(self, cst_model: "CSTransformer", baseline_models: Dict[str, Any]):
        self.cst_model = cst_model
        self.baseline_models = baseline_models
        self.profiler = PerformanceProfiler()

    def benchmark_inference_speed(self, test_sequences: List[torch.Tensor],
                                  context_data_list: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Time the CST model and every baseline on the same sequences.

        Args:
            test_sequences: Token-id tensors; each gets a leading batch
                dimension added via ``unsqueeze(0)`` before the forward pass.
            context_data_list: Context dict for each sequence (CST model only).

        Returns:
            Dict with ``'cst_model'`` timing stats, ``'baselines'`` timing
            stats per baseline name, and ``'relative_performance'`` mapping
            baseline name to ``cst_mean_time / baseline_mean_time`` (values
            above 1 mean CST is slower). Baselines without a positive mean
            time are left out of the ratio.
        """
        # Start clean so repeated calls do not mix in timings of earlier runs.
        self.profiler.reset()

        self.cst_model.eval()
        with torch.no_grad():
            for seq, context_data in zip(test_sequences, context_data_list):
                with self.profiler.time_operation('cst_inference'):
                    _ = self.cst_model(seq.unsqueeze(0), context_data)
        cst_timing = self.profiler.get_stats().get('cst_inference_timing', {})

        baseline_timings = {}
        for name, baseline_model in self.baseline_models.items():
            baseline_model.eval()
            profiler = PerformanceProfiler()  # separate stats per baseline
            with torch.no_grad():
                for seq in test_sequences:
                    with profiler.time_operation('baseline_inference'):
                        _ = baseline_model(seq.unsqueeze(0))
            baseline_timings[name] = profiler.get_stats().get('baseline_inference_timing', {})

        cst_mean_time = cst_timing.get('mean', 0)
        relative_performance = {
            name: cst_mean_time / stats['mean']
            for name, stats in baseline_timings.items()
            if stats.get('mean', 0) > 0
        }

        return {
            'cst_model': cst_timing,
            'baselines': baseline_timings,
            'relative_performance': relative_performance,
        }

    def analyze_cache_performance(self, test_data: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Feed ``test_data`` through the CST model and track its cache stats.

        The cache is cleared and CST profiling switched on first; statistics
        are then sampled every 100 items via ``get_cst_stats()``.

        Returns:
            Dict with the final stats, the sampled evolution and a summary
            produced by :meth:`_analyze_cache_efficiency`.
        """
        self.cst_model.cst_module.clear_cache()
        self.cst_model.enable_cst_profiling(True)

        cache_stats_over_time = []
        self.cst_model.eval()
        with torch.no_grad():
            for step, item in enumerate(test_data):
                _ = self.cst_model(item['input_ids'].unsqueeze(0), item['context_data'])
                if step % 100 == 0:
                    stats = self.cst_model.get_cst_stats()
                    cache_stats_over_time.append({
                        'step': step,
                        'hit_rate': stats.get('hit_rate', 0),
                        'cache_size': stats.get('cache_size', 0),
                        'ambiguous_ratio': stats.get('ambiguous_ratio', 0),
                    })

        return {
            'final_cache_stats': self.cst_model.get_cst_stats(),
            'cache_evolution': cache_stats_over_time,
            'cache_efficiency_analysis': self._analyze_cache_efficiency(cache_stats_over_time),
        }

    def _analyze_cache_efficiency(self, cache_evolution: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Summarise hit-rate trend and cache size over the sampled steps.

        Returns an empty dict when no samples were collected.
        """
        if not cache_evolution:
            return {}

        hit_rates = [sample['hit_rate'] for sample in cache_evolution]
        cache_sizes = [sample['cache_size'] for sample in cache_evolution]

        return {
            'hit_rate_trend': {
                'initial': hit_rates[0],
                'final': hit_rates[-1],
                'peak': max(hit_rates),
                'mean': np.mean(hit_rates),
            },
            'cache_utilization': {
                'mean_size': np.mean(cache_sizes),
                'max_size': max(cache_sizes),
                'final_size': cache_sizes[-1],
            },
        }
class MultimodalEvaluator:
    """Evaluator for multimodal understanding tasks.

    NOTE: answer embeddings are random placeholders, so the reported accuracy
    only exercises the pipeline.
    """

    def __init__(self, model: "CSTransformer", config: "CSTConfig"):
        self.model = model
        self.config = config
        self.profiler = PerformanceProfiler()

    def evaluate_visual_question_answering(self, vqa_dataset: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Evaluate on Visual Question Answering items.

        Each item must provide ``question``, ``image_features``,
        ``answer_candidates`` and ``correct_answer``; ``question_type`` is
        optional and defaults to ``'unknown'``.

        Returns:
            Dict with overall accuracy, accuracy per question type, mean
            confidence and profiler statistics.
        """
        predictions: List[str] = []
        ground_truth: List[str] = []
        question_types: List[str] = []
        confidence_scores: List[float] = []

        self.model.eval()
        with torch.no_grad():
            for item in vqa_dataset:
                with self.profiler.time_operation('vqa_inference'):
                    prediction, confidence = self._answer_visual_question(
                        item['question'],
                        item['image_features'],
                        item['answer_candidates'],
                    )
                predictions.append(prediction)
                ground_truth.append(item['correct_answer'])
                question_types.append(item.get('question_type', 'unknown'))
                confidence_scores.append(confidence)

        return {
            'overall_accuracy': accuracy_score(ground_truth, predictions),
            'per_question_type': self._analyze_by_question_type(
                question_types, ground_truth, predictions
            ),
            'mean_confidence': np.mean(confidence_scores),
            'performance_stats': self.profiler.get_stats(),
        }

    def _encode_words(self, words: List[str]) -> torch.Tensor:
        """Map whitespace tokens to ids in ``[0, vocab_size)`` as a 1 x N tensor.

        CRC32 replaces the built-in ``hash``, whose string hashes are salted
        per interpreter process and therefore not reproducible across runs.
        """
        vocab_size = self.config.vocab_size
        token_ids = [zlib.crc32(word.encode('utf-8')) % vocab_size for word in words]
        return torch.tensor([token_ids], dtype=torch.long)

    def _answer_visual_question(self, question: str, image_features: torch.Tensor,
                                answer_candidates: List[str]) -> Tuple[str, float]:
        """Choose the best answer candidate for ``question`` given an image.

        Returns:
            ``(answer, confidence)`` where confidence is the softmax
            probability of the winning candidate over the similarity scores.
        """
        input_ids = self._encode_words(question.split())

        context_data = {
            'document_embedding': torch.randn(1, self.config.raw_doc_dim),
            'metadata': {
                'author': torch.tensor([0]),
                'domain': torch.tensor([0]),  # VQA domain
                'timestamp': torch.tensor([0.0]),
            },
            'multimodal': {
                # Adds a leading batch dimension to the image features.
                'image_clip': image_features.unsqueeze(0),
            },
        }

        outputs = self.model(input_ids, context_data)
        # Mean-pool over dim 1, then drop only the leading batch dimension.
        # Assumes hidden_states is (batch, seq, d_model) - TODO confirm.
        question_repr = outputs['hidden_states'].mean(dim=1).squeeze(0)

        # Placeholder: every candidate is scored against a random embedding.
        answer_scores = [
            F.cosine_similarity(
                question_repr, torch.randn(self.config.d_model), dim=0
            ).item()
            for _ in answer_candidates
        ]

        best_answer_idx = int(np.argmax(answer_scores))
        confidence = torch.softmax(torch.tensor(answer_scores), dim=0)[best_answer_idx].item()
        return answer_candidates[best_answer_idx], confidence

    def _analyze_by_question_type(self, question_types: List[str],
                                  ground_truth: List[str],
                                  predictions: List[str]) -> Dict[str, float]:
        """Return the fraction of correct answers for each question type."""
        totals = defaultdict(int)
        correct = defaultdict(int)
        for qtype, gt, pred in zip(question_types, ground_truth, predictions):
            totals[qtype] += 1
            if gt == pred:
                correct[qtype] += 1
        return {qtype: correct[qtype] / total for qtype, total in totals.items()}
class ComprehensiveEvaluator:
    """Main evaluator that orchestrates all evaluation tasks."""

    def __init__(self, cst_model: "CSTransformer", baseline_models: Dict[str, Any],
                 config: "CSTConfig"):
        self.cst_model = cst_model
        self.baseline_models = baseline_models
        self.config = config

        # Sub-evaluators share the same CST model instance.
        self.wsd_evaluator = WordSenseDisambiguationEvaluator(cst_model, config)
        self.efficiency_evaluator = EfficiencyEvaluator(cst_model, baseline_models)
        self.multimodal_evaluator = MultimodalEvaluator(cst_model, config)

    def run_full_evaluation(self, test_datasets: Dict[str, Any]) -> Dict[str, Any]:
        """Run every evaluation whose dataset key is present.

        Recognised keys: ``'wsd'``, ``'efficiency'``, ``'cache_test'``,
        ``'vqa'`` and ``'glue'``. A ``'comprehensive_report'`` entry is always
        added to the returned results.
        """
        results: Dict[str, Any] = {}

        if 'wsd' in test_datasets:
            logger.info("Running Word Sense Disambiguation evaluation...")
            results['word_sense_disambiguation'] = self.wsd_evaluator.evaluate_on_semeval(
                test_datasets['wsd']
            )

        if 'efficiency' in test_datasets:
            logger.info("Running efficiency benchmarks...")
            efficiency_data = test_datasets['efficiency']
            results['efficiency'] = self.efficiency_evaluator.benchmark_inference_speed(
                efficiency_data['sequences'],
                efficiency_data['context_data'],
            )

        if 'cache_test' in test_datasets:
            logger.info("Analyzing cache performance...")
            results['cache_performance'] = self.efficiency_evaluator.analyze_cache_performance(
                test_datasets['cache_test']
            )

        if 'vqa' in test_datasets:
            logger.info("Running Visual Question Answering evaluation...")
            results['visual_question_answering'] = (
                self.multimodal_evaluator.evaluate_visual_question_answering(
                    test_datasets['vqa']
                )
            )

        if 'glue' in test_datasets:
            logger.info("Running GLUE benchmark tasks...")
            results['glue_benchmark'] = self.evaluate_glue_tasks(test_datasets['glue'])

        results['comprehensive_report'] = self.generate_evaluation_report(results)
        return results

    def evaluate_glue_tasks(self, glue_data: Dict[str, Any]) -> Dict[str, Any]:
        """Evaluate on GLUE-style classification tasks.

        For each task a classification ``CSTransformer`` is built and loaded
        (non-strictly) with the weights of the wrapped CST model.

        Returns:
            Mapping of task name to accuracy, weighted F1 and sample count.
        """
        glue_results = {}
        for task_name, task_data in glue_data.items():
            logger.info(f"Evaluating on {task_name}...")

            task_model = CSTransformer(self.config, task_type='classification')
            # strict=False: parameters that do not match are skipped.
            task_model.load_state_dict(self.cst_model.state_dict(), strict=False)

            predictions = []
            ground_truth = []
            task_model.eval()
            with torch.no_grad():
                for item in task_data:
                    input_ids = torch.tensor([item['input_ids']], dtype=torch.long)
                    outputs = task_model(input_ids, item['context_data'])
                    predictions.append(torch.argmax(outputs['logits'], dim=-1).item())
                    ground_truth.append(item['label'])

            glue_results[task_name] = {
                'accuracy': accuracy_score(ground_truth, predictions),
                'f1_score': f1_score(ground_truth, predictions, average='weighted'),
                'num_samples': len(task_data),
            }
        return glue_results

    def generate_evaluation_report(self, results: Dict[str, Any]) -> Dict[str, Any]:
        """Condense raw results into summary metrics and recommendations.

        Returns:
            Dict with ``'summary'``, ``'detailed_analysis'``, ``'comparisons'``
            and ``'recommendations'``; only the summary and recommendations
            are populated here.
        """
        summary: Dict[str, Any] = {}
        recommendations: List[str] = []

        if 'word_sense_disambiguation' in results:
            wsd_acc = results['word_sense_disambiguation']['overall_accuracy']
            summary['wsd_accuracy'] = wsd_acc
            if wsd_acc > 0.8:
                recommendations.append("Excellent WSD performance - suitable for disambiguation tasks")
            elif wsd_acc > 0.6:
                recommendations.append("Good WSD performance - consider fine-tuning for critical applications")
            else:
                recommendations.append("WSD performance needs improvement - review training data and context features")

        if 'efficiency' in results:
            rel_perf = results['efficiency'].get('relative_performance', {})
            summary['efficiency_vs_baselines'] = rel_perf
            # With no baselines to compare against, the ratio is taken as 1.0.
            avg_slowdown = np.mean(list(rel_perf.values())) if rel_perf else 1.0
            if avg_slowdown < 1.5:
                recommendations.append("Efficient inference - suitable for production deployment")
            elif avg_slowdown < 3.0:
                recommendations.append("Moderate overhead - optimize caching for better performance")
            else:
                recommendations.append("High computational overhead - consider model compression")

        if 'cache_performance' in results:
            final_hit_rate = results['cache_performance']['final_cache_stats'].get('hit_rate', 0)
            summary['cache_hit_rate'] = final_hit_rate
            if final_hit_rate > 0.7:
                recommendations.append("Excellent cache performance - dynamic processing is well-optimized")
            elif final_hit_rate > 0.4:
                recommendations.append("Good cache performance - consider increasing cache size")
            else:
                recommendations.append("Low cache hit rate - review ambiguity detection or increase cache capacity")

        if 'visual_question_answering' in results:
            vqa_acc = results['visual_question_answering']['overall_accuracy']
            summary['vqa_accuracy'] = vqa_acc
            if vqa_acc > 0.6:
                recommendations.append("Strong multimodal understanding - CST effectively integrates visual information")
            else:
                recommendations.append("Multimodal performance needs improvement - enhance image processing pipeline")

        # Overall assessment averages whichever of WSD / VQA accuracy exist.
        accuracies = [
            summary[key] for key in ('wsd_accuracy', 'vqa_accuracy') if key in summary
        ]
        if accuracies:
            avg_accuracy = np.mean(accuracies)
            summary['overall_performance'] = avg_accuracy
            if avg_accuracy > 0.75:
                summary['assessment'] = "Excellent overall performance"
            elif avg_accuracy > 0.6:
                summary['assessment'] = "Good overall performance"
            else:
                summary['assessment'] = "Performance needs improvement"

        return {
            'summary': summary,
            'detailed_analysis': {},
            'comparisons': {},
            'recommendations': recommendations,
        }

    @staticmethod
    def _to_serializable(obj: Any) -> Any:
        """Recursively convert tensors, NumPy values and tuples to JSON types."""
        if isinstance(obj, (torch.Tensor, np.ndarray)):
            return obj.tolist()
        if isinstance(obj, np.generic):
            # NumPy scalars such as np.int64 are rejected by the json module.
            return obj.item()
        if isinstance(obj, dict):
            return {
                (key.item() if isinstance(key, np.generic) else key):
                    ComprehensiveEvaluator._to_serializable(value)
                for key, value in obj.items()
            }
        if isinstance(obj, (list, tuple)):
            return [ComprehensiveEvaluator._to_serializable(item) for item in obj]
        return obj

    def save_results(self, results: Dict[str, Any], output_path: str):
        """Write ``results`` to ``output_path`` as indented JSON (UTF-8)."""
        serializable_results = self._to_serializable(results)
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(serializable_results, f, indent=2)
        logger.info(f"Evaluation results saved to {output_path}")

    def plot_performance_comparison(self, results: Dict[str, Any], save_path: Optional[str] = None):
        """Draw a 2x2 figure of the available results.

        Panels: WSD accuracy per word, speed-up relative to baselines, cache
        hit rate over time and an overall accuracy summary. Panels whose data
        is missing stay empty. The figure is saved when ``save_path`` is
        given, then shown and closed.
        """
        fig, axes = plt.subplots(2, 2, figsize=(15, 12))

        # 1. WSD accuracy by word
        if 'word_sense_disambiguation' in results:
            wsd_data = results['word_sense_disambiguation']['per_word_accuracy']
            ax = axes[0, 0]
            ax.bar(list(wsd_data.keys()), list(wsd_data.values()))
            ax.set_title('WSD Accuracy by Word')
            ax.set_ylabel('Accuracy')
            ax.tick_params(axis='x', rotation=45)

        # 2. Efficiency comparison
        if 'efficiency' in results:
            rel_perf = results['efficiency'].get('relative_performance', {})
            # rel_perf is cst_time / baseline_time; invert it to a speed-up.
            # Non-positive ratios are skipped to avoid dividing by zero.
            speedup_by_model = {
                name: 1 / ratio for name, ratio in rel_perf.items() if ratio > 0
            }
            ax = axes[0, 1]
            ax.bar(list(speedup_by_model.keys()), list(speedup_by_model.values()))
            ax.axhline(y=1, color='r', linestyle='--', alpha=0.5)
            ax.set_title('Inference Speed Comparison (Speedup vs CST)')
            ax.set_ylabel('Speedup Factor')
            ax.tick_params(axis='x', rotation=45)

        # 3. Cache performance over time
        if 'cache_performance' in results:
            cache_evolution = results['cache_performance']['cache_evolution']
            steps = [sample['step'] for sample in cache_evolution]
            hit_rates = [sample['hit_rate'] for sample in cache_evolution]
            ax = axes[1, 0]
            ax.plot(steps, hit_rates, 'b-', linewidth=2)
            ax.set_title('Cache Hit Rate Over Time')
            ax.set_xlabel('Steps')
            ax.set_ylabel('Hit Rate')
            ax.grid(True, alpha=0.3)

        # 4. Overall performance summary
        summary_metrics = []
        summary_values = []
        if 'word_sense_disambiguation' in results:
            summary_metrics.append('WSD')
            summary_values.append(results['word_sense_disambiguation']['overall_accuracy'])
        if 'visual_question_answering' in results:
            summary_metrics.append('VQA')
            summary_values.append(results['visual_question_answering']['overall_accuracy'])
        if 'glue_benchmark' in results:
            glue_scores = [task['accuracy'] for task in results['glue_benchmark'].values()]
            summary_metrics.append('GLUE Avg')
            summary_values.append(np.mean(glue_scores))
        if summary_metrics:
            ax = axes[1, 1]
            ax.bar(summary_metrics, summary_values)
            ax.set_title('Overall Performance Summary')
            ax.set_ylabel('Accuracy')
            ax.set_ylim(0, 1)

        plt.tight_layout()
        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            logger.info(f"Performance plots saved to {save_path}")
        plt.show()
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
def create_test_datasets(config: "CSTConfig") -> Dict[str, Any]:
    """Create synthetic test datasets for evaluation.

    Reads ``vocab_size``, ``raw_doc_dim``, ``num_authors``, ``num_domains``
    and ``clip_dim`` from ``config``. Values are drawn from the global NumPy
    and torch RNGs, so seed both for reproducible data.

    Returns:
        Dict with the keys ``'wsd'`` (path placeholder), ``'efficiency'``
        (100 sequences plus context), ``'cache_test'`` (1000 items) and
        ``'vqa'`` (200 items).
    """

    def random_context(num_authors: int, num_domains: int) -> Dict[str, Any]:
        # One random document embedding plus scalar metadata.
        return {
            'document_embedding': torch.randn(config.raw_doc_dim),
            'metadata': {
                'author': torch.randint(0, num_authors, (1,)).item(),
                'domain': torch.randint(0, num_domains, (1,)).item(),
                'timestamp': torch.randn(1).item(),
            },
        }

    datasets: Dict[str, Any] = {}

    # WSD dataset
    datasets['wsd'] = "synthetic_wsd_data"  # Path placeholder

    # Efficiency test data
    test_sequences = []
    context_data_list = []
    for _ in range(100):
        seq_len = np.random.randint(10, 100)
        test_sequences.append(torch.randint(1, config.vocab_size, (seq_len,)))
        context_data_list.append(random_context(config.num_authors, config.num_domains))
    datasets['efficiency'] = {
        'sequences': test_sequences,
        'context_data': context_data_list,
    }

    # Cache test data: few distinct authors/domains to encourage cache hits.
    cache_test_data = []
    for _ in range(1000):
        seq_len = np.random.randint(5, 50)
        input_ids = torch.randint(1, config.vocab_size, (seq_len,))
        cache_test_data.append({
            'input_ids': input_ids,
            'context_data': random_context(10, 5),
        })
    datasets['cache_test'] = cache_test_data

    # VQA dataset
    questions = [
        "What color is the object?",
        "How many items are visible?",
        "What is the person doing?",
        "Where is this photo taken?",
        "What type of animal is shown?",
    ]
    # Keyword -> candidates, checked in order against the lower-cased question.
    # Matching on the raw text missed "Where ..." (capital W), which gave
    # location questions the animal answer set.
    keyword_candidates = [
        ("color", ["red", "blue", "green", "yellow", "black"]),
        ("many", ["one", "two", "three", "four", "many"]),
        ("doing", ["walking", "running", "sitting", "standing", "eating"]),
        ("where", ["park", "home", "office", "street", "beach"]),
    ]
    default_candidates = ["dog", "cat", "bird", "horse", "elephant"]

    vqa_data = []
    for _ in range(200):
        question = str(np.random.choice(questions))
        image_features = torch.randn(config.clip_dim)

        lowered = question.lower()
        answer_candidates = list(default_candidates)
        for keyword, candidates in keyword_candidates:
            if keyword in lowered:
                answer_candidates = list(candidates)
                break

        vqa_data.append({
            'question': question,
            'image_features': image_features,
            'answer_candidates': answer_candidates,
            'correct_answer': str(np.random.choice(answer_candidates)),
            'question_type': question.split()[0].lower(),
        })
    datasets['vqa'] = vqa_data

    return datasets
def main():
    """Run the full evaluation on synthetic data and print a summary.

    Writes ``cst_evaluation_results.json`` and ``cst_performance_plots.png``
    to the current working directory.
    """
    # Without a configured handler the INFO progress messages would be
    # dropped; basicConfig does nothing if logging is already configured.
    logging.basicConfig(level=logging.INFO)

    # Setup
    config = CSTConfig()
    config.ambiguous_word_ids = [1, 5, 10, 15, 20, 25, 30]

    # Load models
    cst_model = CSTransformer(config, task_type='mlm')

    # Dummy baseline for comparison. batch_first=True is required because the
    # benchmark feeds [1, seq_len] inputs; with the default layout the encoder
    # would treat each token as its own length-1 sequence.
    baseline_models = {
        'standard_bert': torch.nn.Sequential(
            torch.nn.Embedding(config.vocab_size, config.d_model),
            torch.nn.TransformerEncoder(
                torch.nn.TransformerEncoderLayer(config.d_model, 8, batch_first=True),
                num_layers=6,
            ),
        )
    }

    test_datasets = create_test_datasets(config)
    evaluator = ComprehensiveEvaluator(cst_model, baseline_models, config)

    logger.info("Starting comprehensive evaluation...")
    results = evaluator.run_full_evaluation(test_datasets)

    evaluator.save_results(results, 'cst_evaluation_results.json')
    evaluator.plot_performance_comparison(results, 'cst_performance_plots.png')

    # Print summary
    if 'comprehensive_report' in results:
        report = results['comprehensive_report']
        summary = report['summary']
        separator = "=" * 50
        print("\n" + separator)
        print("CST EVALUATION SUMMARY")
        print(separator)
        print(f"Overall Assessment: {summary.get('assessment', 'N/A')}")
        print(f"Overall Performance: {summary.get('overall_performance', 0):.3f}")
        print(f"WSD Accuracy: {summary.get('wsd_accuracy', 0):.3f}")
        print(f"Cache Hit Rate: {summary.get('cache_hit_rate', 0):.3f}")
        print(f"VQA Accuracy: {summary.get('vqa_accuracy', 0):.3f}")
        print("\nRecommendations:")
        for i, rec in enumerate(report['recommendations'], 1):
            print(f"{i}. {rec}")
        print(separator)

    logger.info("Evaluation completed!")


if __name__ == "__main__":
    main()