# Source file: 32x_Quantum_NLP / src / cst / classical / evaluation_framework.py
# Author: melhelbawi
# Commit: feat: establish Quantum-Enhanced CST project with core components, training pipelines, and evaluation utilities, and update README.md.
# Commit hash: 94c2e42
# CST / QCST Dual License
# Non-commercial research use only.
# Commercial use requires explicit permission.
# Copyright (c) 2025 Mohamed Mohamed Elhelbawi
# All rights reserved.
# See LICENSE file in the project root for full license information.
"""
Evaluation Framework for CST Models
Comprehensive benchmarking suite for semantic disambiguation, efficiency, and multimodal tasks
"""
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
import time
import json
import logging
from typing import Dict, List, Optional, Any, Tuple
from collections import defaultdict
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix
from scipy.stats import pearsonr, spearmanr
from cst_transformer import CSTransformer
from config import CSTConfig
logger = logging.getLogger(__name__)
class PerformanceProfiler:
    """Collects per-operation wall-clock timings, CUDA memory deltas, and call counts."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all recorded timings, memory samples, and counters."""
        self.timings = defaultdict(list)       # op name -> list of elapsed seconds
        self.memory_usage = defaultdict(list)  # op name -> list of CUDA byte deltas
        self.counters = defaultdict(int)       # op name -> number of invocations

    def time_operation(self, operation_name: str):
        """Return a context manager that records one timed run of ``operation_name``."""
        return self.TimingContext(self, operation_name)

    class TimingContext:
        """Context manager measuring elapsed time (and CUDA memory, when available)."""

        def __init__(self, profiler, operation_name):
            self.profiler = profiler
            self.operation_name = operation_name
            self.start_time = None
            self.start_memory = None

        def __enter__(self):
            # perf_counter is monotonic; time.time() is wall-clock and can jump
            # (NTP adjustments, DST), which would corrupt measured durations.
            self.start_time = time.perf_counter()
            if torch.cuda.is_available():
                torch.cuda.synchronize()  # flush pending kernels before sampling memory
                self.start_memory = torch.cuda.memory_allocated()
            return self

        def __exit__(self, exc_type, exc_val, exc_tb):
            if torch.cuda.is_available():
                torch.cuda.synchronize()  # ensure async GPU work is included in the timing
                end_memory = torch.cuda.memory_allocated()
                memory_diff = end_memory - (self.start_memory or 0)
                self.profiler.memory_usage[self.operation_name].append(memory_diff)
            elapsed = time.perf_counter() - self.start_time
            self.profiler.timings[self.operation_name].append(elapsed)
            self.profiler.counters[self.operation_name] += 1

    def get_stats(self) -> Dict[str, Any]:
        """Summarize recorded timings (seconds) and memory deltas (MB) per operation."""
        stats = {}
        for op_name, times in self.timings.items():
            stats[f"{op_name}_timing"] = {
                'mean': np.mean(times),
                'std': np.std(times),
                'min': np.min(times),
                'max': np.max(times),
                'p50': np.percentile(times, 50),
                'p95': np.percentile(times, 95),
                'p99': np.percentile(times, 99),
                'total': np.sum(times),
                'count': len(times)
            }
        for op_name, memory in self.memory_usage.items():
            if memory:  # only present when CUDA sampling actually happened
                stats[f"{op_name}_memory"] = {
                    'mean_mb': np.mean(memory) / 1024 / 1024,
                    'max_mb': np.max(memory) / 1024 / 1024,
                    'total_mb': np.sum(memory) / 1024 / 1024
                }
        return stats
class WordSenseDisambiguationEvaluator:
    """Evaluator for Word Sense Disambiguation tasks.

    Scores the CST model by comparing the contextual representation of a
    target word against embeddings of candidate senses and picking the
    closest one.

    NOTE(review): both the tokenizer (hash-based) and the sense embeddings
    (freshly sampled ``torch.randn`` per call) are placeholders, so reported
    accuracy reflects chance level until real components are wired in.
    """

    def __init__(self, model: CSTransformer, config: CSTConfig):
        self.model = model
        self.config = config
        self.profiler = PerformanceProfiler()  # records per-inference latency

    def evaluate_on_semeval(self, dataset_path: str) -> Dict[str, Any]:
        """Evaluate on SemEval WSD datasets.

        Returns a dict with overall accuracy, weighted F1, per-word accuracy,
        a context-length breakdown, profiler statistics, and the sample count.
        """
        # Load SemEval data (simplified - adapt to actual format)
        test_data = self._load_semeval_data(dataset_path)
        results = {
            'predictions': [],
            'ground_truth': [],
            'ambiguous_words': [],
            'context_lengths': [],
            'prediction_confidences': []
        }
        self.model.eval()
        with torch.no_grad():
            for item in test_data:
                # Only the prediction itself is timed; bookkeeping is outside.
                with self.profiler.time_operation('wsd_inference'):
                    prediction, confidence = self._predict_word_sense(
                        item['sentence'],
                        item['target_word'],
                        item['target_position'],
                        item['sense_candidates']
                    )
                results['predictions'].append(prediction)
                results['ground_truth'].append(item['correct_sense'])
                results['ambiguous_words'].append(item['target_word'])
                # Context length measured in whitespace-delimited tokens.
                results['context_lengths'].append(len(item['sentence'].split()))
                results['prediction_confidences'].append(confidence)
        # Compute metrics
        accuracy = accuracy_score(results['ground_truth'], results['predictions'])
        f1 = f1_score(results['ground_truth'], results['predictions'], average='weighted')
        # Per-word analysis
        word_accuracies = self._compute_per_word_accuracy(
            results['ambiguous_words'],
            results['ground_truth'],
            results['predictions']
        )
        # Context length analysis
        context_analysis = self._analyze_by_context_length(
            results['context_lengths'],
            results['ground_truth'],
            results['predictions']
        )
        return {
            'overall_accuracy': accuracy,
            'weighted_f1': f1,
            'per_word_accuracy': word_accuracies,
            'context_length_analysis': context_analysis,
            'performance_stats': self.profiler.get_stats(),
            'num_samples': len(test_data)
        }

    def _load_semeval_data(self, dataset_path: str) -> List[Dict[str, Any]]:
        """Load SemEval WSD data - simplified implementation.

        Generates 200 synthetic examples over eight ambiguous words; only
        'bank' and 'plant' receive sense-specific context sentences, the
        remaining words share one generic template.  ``dataset_path`` is
        currently ignored.
        """
        # This is a placeholder - implement based on actual SemEval format
        synthetic_data = []
        ambiguous_words = ['bank', 'plant', 'scale', 'rock', 'bark', 'crown', 'mouse', 'bat']
        senses = {
            'bank': ['financial_institution', 'river_side'],
            'plant': ['factory', 'vegetation'],
            'scale': ['measurement', 'fish_covering'],
            'rock': ['stone', 'music_genre'],
            'bark': ['dog_sound', 'tree_covering'],
            'crown': ['royal_headwear', 'tooth_covering'],
            'mouse': ['computer_device', 'animal'],
            'bat': ['sports_equipment', 'flying_mammal']
        }
        for i in range(200):  # Generate 200 test cases
            word = np.random.choice(ambiguous_words)
            sense_candidates = senses[word]
            correct_sense = np.random.choice(sense_candidates)
            # Generate context sentence
            if word == 'bank':
                if correct_sense == 'financial_institution':
                    sentence = f"I went to the {word} to deposit money and check my account balance."
                else:
                    sentence = f"We sat by the {word} of the river watching the sunset."
            elif word == 'plant':
                if correct_sense == 'factory':
                    sentence = f"The manufacturing {word} operates 24 hours a day."
                else:
                    sentence = f"This {word} needs water and sunlight to grow properly."
            else:
                sentence = f"The {word} is important in this context for disambiguation."
            # Position of the word's first occurrence in the whitespace-split sentence.
            target_pos = sentence.split().index(word)
            synthetic_data.append({
                'sentence': sentence,
                'target_word': word,
                'target_position': target_pos,
                'sense_candidates': sense_candidates,
                'correct_sense': correct_sense
            })
        return synthetic_data

    def _predict_word_sense(self, sentence: str, target_word: str,
                            target_position: int, sense_candidates: List[str]) -> Tuple[str, float]:
        """Predict word sense using CST model.

        Returns the best-matching candidate sense and a softmax confidence
        computed over the cosine-similarity scores of all candidates.
        """
        # Tokenize sentence (simplified)
        words = sentence.split()
        # Create input for model
        # NOTE(review): hash() is salted per process (PYTHONHASHSEED), so these
        # token ids are not reproducible across runs — replace with a real tokenizer.
        input_ids = torch.tensor([[hash(w) % self.config.vocab_size for w in words]], dtype=torch.long)
        # Create context data emphasizing the target word
        context_data = {
            'document_embedding': torch.randn(1, self.config.raw_doc_dim),
            'metadata': {
                'author': torch.tensor([0]),
                'domain': torch.tensor([0]),
                'timestamp': torch.tensor([0.0])
            }
        }
        # Get model outputs
        outputs = self.model(input_ids, context_data)
        # Extract representation for target word
        # assumes hidden_states is (batch, seq, d_model) — TODO confirm against CSTransformer
        target_repr = outputs['hidden_states'][0, target_position]  # [d_model]
        # Compute similarity with sense embeddings (simplified)
        sense_scores = []
        for sense in sense_candidates:
            # Generate sense embedding (in practice, use pre-trained sense embeddings)
            sense_embedding = torch.randn(self.config.d_model)
            similarity = F.cosine_similarity(target_repr, sense_embedding, dim=0)
            sense_scores.append(similarity.item())
        # Predict sense with highest similarity
        best_sense_idx = np.argmax(sense_scores)
        confidence = torch.softmax(torch.tensor(sense_scores), dim=0)[best_sense_idx].item()
        return sense_candidates[best_sense_idx], confidence

    def _compute_per_word_accuracy(self, words: List[str],
                                   ground_truth: List[str],
                                   predictions: List[str]) -> Dict[str, float]:
        """Compute accuracy for each ambiguous word (word -> fraction correct)."""
        word_results = defaultdict(lambda: {'correct': 0, 'total': 0})
        for word, gt, pred in zip(words, ground_truth, predictions):
            word_results[word]['total'] += 1
            if gt == pred:
                word_results[word]['correct'] += 1
        return {word: stats['correct'] / stats['total']
                for word, stats in word_results.items()}

    def _analyze_by_context_length(self, context_lengths: List[int],
                                   ground_truth: List[str],
                                   predictions: List[str]) -> Dict[str, float]:
        """Analyze performance by context length.

        Buckets samples into half-open token-count ranges and reports
        accuracy per non-empty bucket (e.g. "10-20", "30-∞").
        """
        length_buckets = [(0, 10), (10, 20), (20, 30), (30, float('inf'))]
        bucket_results = {}
        for min_len, max_len in length_buckets:
            mask = [(min_len <= length < max_len) for length in context_lengths]
            if not any(mask):
                continue  # skip buckets with no samples
            bucket_gt = [gt for gt, m in zip(ground_truth, mask) if m]
            bucket_pred = [pred for pred, m in zip(predictions, mask) if m]
            if bucket_gt:
                accuracy = accuracy_score(bucket_gt, bucket_pred)
                bucket_name = f"{min_len}-{max_len if max_len != float('inf') else '∞'}"
                bucket_results[bucket_name] = accuracy
        return bucket_results
class EfficiencyEvaluator:
    """Evaluator for computational efficiency of the CST model versus baselines."""

    def __init__(self, cst_model: CSTransformer, baseline_models: Dict[str, Any]):
        self.cst_model = cst_model
        self.baseline_models = baseline_models  # name -> callable baseline model
        self.profiler = PerformanceProfiler()

    def benchmark_inference_speed(self, test_sequences: List[torch.Tensor],
                                  context_data_list: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Benchmark per-sample inference latency of the CST model against each baseline.

        Returns timing summaries for 'cst_model' and each entry under
        'baselines', plus 'relative_performance' where values > 1 mean the
        CST model is slower than that baseline.
        """
        # Note: the previous version pre-filled these slots with throwaway
        # dicts that were immediately overwritten; we assign the real
        # summaries directly instead.
        results = {'cst_model': {}, 'baselines': {}}
        # Benchmark CST model (context-aware forward pass).
        self.cst_model.eval()
        with torch.no_grad():
            for seq, context_data in zip(test_sequences, context_data_list):
                with self.profiler.time_operation('cst_inference'):
                    _ = self.cst_model(seq.unsqueeze(0), context_data)
        results['cst_model'] = self.profiler.get_stats().get('cst_inference_timing', {})
        # Benchmark baseline models (plain forward pass, no context data).
        # The old hasattr(forward) if/else had identical branches — removed.
        for name, baseline_model in self.baseline_models.items():
            baseline_model.eval()
            profiler = PerformanceProfiler()  # fresh profiler per baseline
            with torch.no_grad():
                for seq in test_sequences:
                    with profiler.time_operation('baseline_inference'):
                        _ = baseline_model(seq.unsqueeze(0))
            results['baselines'][name] = profiler.get_stats().get('baseline_inference_timing', {})
        # Relative performance: CST mean latency / baseline mean latency.
        cst_mean_time = results['cst_model'].get('mean', 0)
        relative_performance = {}
        for name, stats in results['baselines'].items():
            baseline_mean_time = stats.get('mean', 0)
            if baseline_mean_time > 0:  # guard against empty/zero timings
                relative_performance[name] = cst_mean_time / baseline_mean_time
        results['relative_performance'] = relative_performance
        return results

    def analyze_cache_performance(self, test_data: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Run ``test_data`` through the CST model while tracking cache statistics.

        Samples cache stats every 100 items (including step 0).
        NOTE(review): assumes the model exposes ``cst_module.clear_cache``,
        ``enable_cst_profiling`` and ``get_cst_stats`` — confirm against the
        CSTransformer implementation.
        """
        # Clear cache and enable profiling so stats start from a known state.
        self.cst_model.cst_module.clear_cache()
        self.cst_model.enable_cst_profiling(True)
        cache_stats_over_time = []
        self.cst_model.eval()
        with torch.no_grad():
            for i, item in enumerate(test_data):
                _ = self.cst_model(
                    item['input_ids'].unsqueeze(0),
                    item['context_data']
                )
                if i % 100 == 0:  # Sample every 100 steps
                    stats = self.cst_model.get_cst_stats()
                    cache_stats_over_time.append({
                        'step': i,
                        'hit_rate': stats.get('hit_rate', 0),
                        'cache_size': stats.get('cache_size', 0),
                        'ambiguous_ratio': stats.get('ambiguous_ratio', 0)
                    })
        final_stats = self.cst_model.get_cst_stats()
        return {
            'final_cache_stats': final_stats,
            'cache_evolution': cache_stats_over_time,
            'cache_efficiency_analysis': self._analyze_cache_efficiency(cache_stats_over_time)
        }

    def _analyze_cache_efficiency(self, cache_evolution: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Summarize hit-rate trend and cache-size utilization from sampled stats."""
        if not cache_evolution:
            return {}
        # Non-empty past the guard, so the per-field emptiness checks the old
        # version repeated on every line were dead code.
        hit_rates = [stats['hit_rate'] for stats in cache_evolution]
        cache_sizes = [stats['cache_size'] for stats in cache_evolution]
        return {
            'hit_rate_trend': {
                'initial': hit_rates[0],
                'final': hit_rates[-1],
                'peak': max(hit_rates),
                'mean': np.mean(hit_rates)
            },
            'cache_utilization': {
                'mean_size': np.mean(cache_sizes),
                'max_size': max(cache_sizes),
                'final_size': cache_sizes[-1]
            }
        }
class MultimodalEvaluator:
    """Evaluator for multimodal understanding tasks (currently VQA only).

    NOTE(review): answer embeddings are random placeholders (see
    ``_answer_visual_question``), so accuracy is chance-level until real
    answer encoders are integrated.
    """

    def __init__(self, model: CSTransformer, config: CSTConfig):
        self.model = model
        self.config = config
        self.profiler = PerformanceProfiler()  # times each VQA inference

    def evaluate_visual_question_answering(self, vqa_dataset: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Evaluate on Visual Question Answering tasks.

        Returns overall accuracy, per-question-type accuracy, mean prediction
        confidence, and profiler statistics.
        """
        results = {
            'predictions': [],
            'ground_truth': [],
            'question_types': [],
            'confidence_scores': []
        }
        self.model.eval()
        with torch.no_grad():
            for item in vqa_dataset:
                # Only the prediction is timed; bookkeeping happens outside.
                with self.profiler.time_operation('vqa_inference'):
                    prediction, confidence = self._answer_visual_question(
                        item['question'],
                        item['image_features'],
                        item['answer_candidates']
                    )
                results['predictions'].append(prediction)
                results['ground_truth'].append(item['correct_answer'])
                results['question_types'].append(item.get('question_type', 'unknown'))
                results['confidence_scores'].append(confidence)
        # Compute metrics
        accuracy = accuracy_score(results['ground_truth'], results['predictions'])
        # Per question type analysis
        type_analysis = self._analyze_by_question_type(
            results['question_types'],
            results['ground_truth'],
            results['predictions']
        )
        return {
            'overall_accuracy': accuracy,
            'per_question_type': type_analysis,
            'mean_confidence': np.mean(results['confidence_scores']),
            'performance_stats': self.profiler.get_stats()
        }

    def _answer_visual_question(self, question: str, image_features: torch.Tensor,
                                answer_candidates: List[str]) -> Tuple[str, float]:
        """Answer a visual question using CST model.

        Pools hidden states over the question tokens, then picks the
        candidate answer with the highest cosine similarity; the confidence
        is a softmax over all candidate scores.
        """
        # Tokenize question
        question_words = question.split()
        # NOTE(review): hash() is salted per process (PYTHONHASHSEED) — token
        # ids are not reproducible across runs; replace with a real tokenizer.
        input_ids = torch.tensor([[hash(w) % self.config.vocab_size for w in question_words]], dtype=torch.long)
        # Create multimodal context
        context_data = {
            'document_embedding': torch.randn(1, self.config.raw_doc_dim),
            'metadata': {
                'author': torch.tensor([0]),
                'domain': torch.tensor([0]),  # VQA domain
                'timestamp': torch.tensor([0.0])
            },
            'multimodal': {
                'image_clip': image_features.unsqueeze(0)  # [1, clip_dim]
            }
        }
        # Get model representation
        outputs = self.model(input_ids, context_data)
        question_repr = outputs['hidden_states'].mean(dim=1)  # Pool over sequence
        # Compare with answer embeddings (simplified)
        answer_scores = []
        for answer in answer_candidates:
            answer_embedding = torch.randn(self.config.d_model)  # Placeholder
            similarity = F.cosine_similarity(question_repr.squeeze(), answer_embedding, dim=0)
            answer_scores.append(similarity.item())
        best_answer_idx = np.argmax(answer_scores)
        confidence = torch.softmax(torch.tensor(answer_scores), dim=0)[best_answer_idx].item()
        return answer_candidates[best_answer_idx], confidence

    def _analyze_by_question_type(self, question_types: List[str],
                                  ground_truth: List[str],
                                  predictions: List[str]) -> Dict[str, float]:
        """Analyze VQA performance by question type (type -> fraction correct)."""
        type_results = defaultdict(lambda: {'correct': 0, 'total': 0})
        for qtype, gt, pred in zip(question_types, ground_truth, predictions):
            type_results[qtype]['total'] += 1
            if gt == pred:
                type_results[qtype]['correct'] += 1
        return {qtype: stats['correct'] / stats['total']
                for qtype, stats in type_results.items()}
class ComprehensiveEvaluator:
    """Main evaluator that orchestrates all evaluation tasks.

    Delegates to the task-specific evaluators (WSD, efficiency, multimodal),
    aggregates their results, and produces a summary report, a JSON dump,
    and comparison plots.
    """

    def __init__(self, cst_model: CSTransformer, baseline_models: Dict[str, Any], config: CSTConfig):
        self.cst_model = cst_model
        self.baseline_models = baseline_models  # name -> baseline model for speed comparison
        self.config = config
        # Initialize sub-evaluators
        self.wsd_evaluator = WordSenseDisambiguationEvaluator(cst_model, config)
        self.efficiency_evaluator = EfficiencyEvaluator(cst_model, baseline_models)
        self.multimodal_evaluator = MultimodalEvaluator(cst_model, config)

    def run_full_evaluation(self, test_datasets: Dict[str, Any]) -> Dict[str, Any]:
        """Run comprehensive evaluation across all tasks.

        Each task runs only when its key ('wsd', 'efficiency', 'cache_test',
        'vqa', 'glue') is present in ``test_datasets``; the comprehensive
        report is always generated from whatever ran.
        """
        results = {}
        # 1. Word Sense Disambiguation
        if 'wsd' in test_datasets:
            logger.info("Running Word Sense Disambiguation evaluation...")
            wsd_results = self.wsd_evaluator.evaluate_on_semeval(test_datasets['wsd'])
            results['word_sense_disambiguation'] = wsd_results
        # 2. Efficiency Benchmarking
        if 'efficiency' in test_datasets:
            logger.info("Running efficiency benchmarks...")
            efficiency_results = self.efficiency_evaluator.benchmark_inference_speed(
                test_datasets['efficiency']['sequences'],
                test_datasets['efficiency']['context_data']
            )
            results['efficiency'] = efficiency_results
        # 3. Cache Performance Analysis
        if 'cache_test' in test_datasets:
            logger.info("Analyzing cache performance...")
            cache_results = self.efficiency_evaluator.analyze_cache_performance(
                test_datasets['cache_test']
            )
            results['cache_performance'] = cache_results
        # 4. Multimodal Tasks
        if 'vqa' in test_datasets:
            logger.info("Running Visual Question Answering evaluation...")
            vqa_results = self.multimodal_evaluator.evaluate_visual_question_answering(
                test_datasets['vqa']
            )
            results['visual_question_answering'] = vqa_results
        # 5. GLUE-style benchmarks
        if 'glue' in test_datasets:
            logger.info("Running GLUE benchmark tasks...")
            glue_results = self.evaluate_glue_tasks(test_datasets['glue'])
            results['glue_benchmark'] = glue_results
        # 6. Generate comprehensive report
        report = self.generate_evaluation_report(results)
        results['comprehensive_report'] = report
        return results

    def evaluate_glue_tasks(self, glue_data: Dict[str, Any]) -> Dict[str, Any]:
        """Evaluate on GLUE-style tasks.

        Builds a fresh classification-head model per task and copies the
        shared weights non-strictly (head parameters may be missing from the
        source model's state dict).
        """
        glue_results = {}
        for task_name, task_data in glue_data.items():
            logger.info(f"Evaluating on {task_name}...")
            # Create task-specific model if needed
            task_model = CSTransformer(self.config, task_type='classification')
            # strict=False: tolerate missing/extra keys between task heads
            task_model.load_state_dict(self.cst_model.state_dict(), strict=False)
            predictions = []
            ground_truth = []
            task_model.eval()
            with torch.no_grad():
                for item in task_data:
                    # Prepare input
                    input_ids = torch.tensor([item['input_ids']], dtype=torch.long)
                    context_data = item['context_data']
                    # Forward pass
                    outputs = task_model(input_ids, context_data)
                    # Get prediction
                    logits = outputs['logits']
                    prediction = torch.argmax(logits, dim=-1).item()
                    predictions.append(prediction)
                    ground_truth.append(item['label'])
            # Compute metrics
            accuracy = accuracy_score(ground_truth, predictions)
            f1 = f1_score(ground_truth, predictions, average='weighted')
            glue_results[task_name] = {
                'accuracy': accuracy,
                'f1_score': f1,
                'num_samples': len(task_data)
            }
        return glue_results

    def generate_evaluation_report(self, results: Dict[str, Any]) -> Dict[str, Any]:
        """Generate comprehensive evaluation report.

        The numeric thresholds below (0.8/0.6 accuracy, 1.5x/3x slowdown,
        0.7/0.4 hit rate) are heuristic cut-offs chosen for the
        human-readable recommendations.
        """
        report = {
            'summary': {},
            'detailed_analysis': {},
            'comparisons': {},
            'recommendations': []
        }
        # Summary metrics
        if 'word_sense_disambiguation' in results:
            wsd_acc = results['word_sense_disambiguation']['overall_accuracy']
            report['summary']['wsd_accuracy'] = wsd_acc
            if wsd_acc > 0.8:
                report['recommendations'].append("Excellent WSD performance - suitable for disambiguation tasks")
            elif wsd_acc > 0.6:
                report['recommendations'].append("Good WSD performance - consider fine-tuning for critical applications")
            else:
                report['recommendations'].append("WSD performance needs improvement - review training data and context features")
        # Efficiency analysis
        if 'efficiency' in results:
            rel_perf = results['efficiency'].get('relative_performance', {})
            report['summary']['efficiency_vs_baselines'] = rel_perf
            # rel_perf values > 1 mean CST is slower than the baseline.
            avg_slowdown = np.mean(list(rel_perf.values())) if rel_perf else 1.0
            if avg_slowdown < 1.5:
                report['recommendations'].append("Efficient inference - suitable for production deployment")
            elif avg_slowdown < 3.0:
                report['recommendations'].append("Moderate overhead - optimize caching for better performance")
            else:
                report['recommendations'].append("High computational overhead - consider model compression")
        # Cache effectiveness
        if 'cache_performance' in results:
            final_hit_rate = results['cache_performance']['final_cache_stats'].get('hit_rate', 0)
            report['summary']['cache_hit_rate'] = final_hit_rate
            if final_hit_rate > 0.7:
                report['recommendations'].append("Excellent cache performance - dynamic processing is well-optimized")
            elif final_hit_rate > 0.4:
                report['recommendations'].append("Good cache performance - consider increasing cache size")
            else:
                report['recommendations'].append("Low cache hit rate - review ambiguity detection or increase cache capacity")
        # Multimodal performance
        if 'visual_question_answering' in results:
            vqa_acc = results['visual_question_answering']['overall_accuracy']
            report['summary']['vqa_accuracy'] = vqa_acc
            if vqa_acc > 0.6:
                report['recommendations'].append("Strong multimodal understanding - CST effectively integrates visual information")
            else:
                report['recommendations'].append("Multimodal performance needs improvement - enhance image processing pipeline")
        # Overall assessment: unweighted mean of whichever accuracies exist.
        accuracies = []
        if 'wsd_accuracy' in report['summary']:
            accuracies.append(report['summary']['wsd_accuracy'])
        if 'vqa_accuracy' in report['summary']:
            accuracies.append(report['summary']['vqa_accuracy'])
        if accuracies:
            avg_accuracy = np.mean(accuracies)
            report['summary']['overall_performance'] = avg_accuracy
            if avg_accuracy > 0.75:
                report['summary']['assessment'] = "Excellent overall performance"
            elif avg_accuracy > 0.6:
                report['summary']['assessment'] = "Good overall performance"
            else:
                report['summary']['assessment'] = "Performance needs improvement"
        return report

    def save_results(self, results: Dict[str, Any], output_path: str):
        """Save evaluation results to file as indented JSON.

        NOTE(review): only torch tensors are converted; non-float numpy
        scalars (e.g. np.int32) elsewhere in ``results`` would still break
        json.dump — confirm none appear.
        """
        # Convert tensors to lists for JSON serialization
        def convert_tensors(obj):
            if isinstance(obj, torch.Tensor):
                return obj.tolist()
            elif isinstance(obj, dict):
                return {k: convert_tensors(v) for k, v in obj.items()}
            elif isinstance(obj, list):
                return [convert_tensors(item) for item in obj]
            else:
                return obj
        serializable_results = convert_tensors(results)
        with open(output_path, 'w') as f:
            json.dump(serializable_results, f, indent=2)
        logger.info(f"Evaluation results saved to {output_path}")

    def plot_performance_comparison(self, results: Dict[str, Any], save_path: Optional[str] = None):
        """Generate a 2x2 grid of performance comparison plots.

        Panels: WSD per-word accuracy, baseline speedups, cache hit rate over
        time, and an overall accuracy summary.  Saves to ``save_path`` when
        given, then shows the figure.
        """
        fig, axes = plt.subplots(2, 2, figsize=(15, 12))
        # 1. WSD accuracy by word
        if 'word_sense_disambiguation' in results:
            wsd_data = results['word_sense_disambiguation']['per_word_accuracy']
            words = list(wsd_data.keys())
            accuracies = list(wsd_data.values())
            axes[0, 0].bar(words, accuracies)
            axes[0, 0].set_title('WSD Accuracy by Word')
            axes[0, 0].set_ylabel('Accuracy')
            axes[0, 0].tick_params(axis='x', rotation=45)
        # 2. Efficiency comparison
        if 'efficiency' in results:
            rel_perf = results['efficiency']['relative_performance']
            models = list(rel_perf.keys())
            # NOTE(review): assumes every relative-performance value is
            # non-zero — a zero would raise ZeroDivisionError here.
            speedups = [1/perf for perf in rel_perf.values()]  # Convert to speedup
            axes[0, 1].bar(models, speedups)
            axes[0, 1].axhline(y=1, color='r', linestyle='--', alpha=0.5)
            axes[0, 1].set_title('Inference Speed Comparison (Speedup vs CST)')
            axes[0, 1].set_ylabel('Speedup Factor')
            axes[0, 1].tick_params(axis='x', rotation=45)
        # 3. Cache performance over time
        if 'cache_performance' in results:
            cache_evolution = results['cache_performance']['cache_evolution']
            steps = [item['step'] for item in cache_evolution]
            hit_rates = [item['hit_rate'] for item in cache_evolution]
            axes[1, 0].plot(steps, hit_rates, 'b-', linewidth=2)
            axes[1, 0].set_title('Cache Hit Rate Over Time')
            axes[1, 0].set_xlabel('Steps')
            axes[1, 0].set_ylabel('Hit Rate')
            axes[1, 0].grid(True, alpha=0.3)
        # 4. Overall performance summary
        summary_metrics = []
        summary_values = []
        if 'word_sense_disambiguation' in results:
            summary_metrics.append('WSD')
            summary_values.append(results['word_sense_disambiguation']['overall_accuracy'])
        if 'visual_question_answering' in results:
            summary_metrics.append('VQA')
            summary_values.append(results['visual_question_answering']['overall_accuracy'])
        if 'glue_benchmark' in results:
            glue_scores = [task['accuracy'] for task in results['glue_benchmark'].values()]
            summary_metrics.append('GLUE Avg')
            summary_values.append(np.mean(glue_scores))
        if summary_metrics:
            axes[1, 1].bar(summary_metrics, summary_values)
            axes[1, 1].set_title('Overall Performance Summary')
            axes[1, 1].set_ylabel('Accuracy')
            axes[1, 1].set_ylim(0, 1)
        plt.tight_layout()
        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            logger.info(f"Performance plots saved to {save_path}")
        plt.show()
def create_test_datasets(config: "CSTConfig") -> Dict[str, Any]:
    """Create synthetic test datasets for evaluation.

    Builds four entries: 'wsd' (a path placeholder — the WSD evaluator
    synthesizes its own data), 'efficiency' (random variable-length
    sequences plus context data), 'cache_test' (inputs drawn from a small
    author/domain vocabulary to encourage cache hits), and 'vqa'
    (question/answer items with random CLIP-like image features).

    Args:
        config: model configuration supplying vocab_size, raw_doc_dim,
            num_authors, num_domains, and clip_dim.

    Returns:
        Dict mapping dataset name to its synthetic payload.
    """
    datasets = {}
    # WSD dataset
    datasets['wsd'] = "synthetic_wsd_data"  # Path placeholder
    # Efficiency test data: 100 variable-length sequences with full context.
    test_sequences = []
    context_data_list = []
    for _ in range(100):
        seq_len = np.random.randint(10, 100)
        seq = torch.randint(1, config.vocab_size, (seq_len,))
        context_data = {
            'document_embedding': torch.randn(config.raw_doc_dim),
            'metadata': {
                'author': torch.randint(0, config.num_authors, (1,)).item(),
                'domain': torch.randint(0, config.num_domains, (1,)).item(),
                'timestamp': torch.randn(1).item(),
            }
        }
        test_sequences.append(seq)
        context_data_list.append(context_data)
    datasets['efficiency'] = {
        'sequences': test_sequences,
        'context_data': context_data_list
    }
    # Cache test data: restrict authors/domains so repeated contexts occur.
    cache_test_data = []
    for _ in range(1000):
        seq_len = np.random.randint(5, 50)
        input_ids = torch.randint(1, config.vocab_size, (seq_len,))
        context_data = {
            'document_embedding': torch.randn(config.raw_doc_dim),
            'metadata': {
                'author': torch.randint(0, 10, (1,)).item(),  # Limited authors for cache hits
                'domain': torch.randint(0, 5, (1,)).item(),  # Limited domains for cache hits
                'timestamp': torch.randn(1).item(),
            }
        }
        cache_test_data.append({
            'input_ids': input_ids,
            'context_data': context_data
        })
    datasets['cache_test'] = cache_test_data
    # VQA dataset: 200 question/answer items with random image features.
    # Hoisted out of the loop — the candidate question list is loop-invariant.
    questions = [
        "What color is the object?",
        "How many items are visible?",
        "What is the person doing?",
        "Where is this photo taken?",
        "What type of animal is shown?"
    ]
    vqa_data = []
    for _ in range(200):
        question = np.random.choice(questions)
        image_features = torch.randn(config.clip_dim)
        # Candidate answers keyed off a keyword in the question.  The match
        # is done case-insensitively: the previous version tested
        # `"where" in question`, which never matched the capitalized
        # "Where ..." question, mislabeling location questions with
        # animal answer candidates.
        question_lower = question.lower()
        if "color" in question_lower:
            answer_candidates = ["red", "blue", "green", "yellow", "black"]
        elif "many" in question_lower:
            answer_candidates = ["one", "two", "three", "four", "many"]
        elif "doing" in question_lower:
            answer_candidates = ["walking", "running", "sitting", "standing", "eating"]
        elif "where" in question_lower:
            answer_candidates = ["park", "home", "office", "street", "beach"]
        else:
            answer_candidates = ["dog", "cat", "bird", "horse", "elephant"]
        vqa_data.append({
            'question': question,
            'image_features': image_features,
            'answer_candidates': answer_candidates,
            'correct_answer': np.random.choice(answer_candidates),
            'question_type': question.split()[0].lower()
        })
    datasets['vqa'] = vqa_data
    return datasets
def main():
    """Main evaluation script"""
    # Configuration with a handful of word ids flagged as ambiguous.
    cfg = CSTConfig()
    cfg.ambiguous_word_ids = [1, 5, 10, 15, 20, 25, 30]

    # Model under test, plus a minimal transformer stack standing in for a
    # BERT-style baseline in the speed comparison.
    cst_model = CSTransformer(cfg, task_type='mlm')
    baseline_models = {
        'standard_bert': torch.nn.Sequential(
            torch.nn.Embedding(cfg.vocab_size, cfg.d_model),
            torch.nn.TransformerEncoder(
                torch.nn.TransformerEncoderLayer(cfg.d_model, 8),
                num_layers=6
            )
        )
    }

    # Synthetic datasets and the orchestrating evaluator.
    test_datasets = create_test_datasets(cfg)
    evaluator = ComprehensiveEvaluator(cst_model, baseline_models, cfg)

    logger.info("Starting comprehensive evaluation...")
    results = evaluator.run_full_evaluation(test_datasets)

    # Persist raw results and the comparison plots.
    evaluator.save_results(results, 'cst_evaluation_results.json')
    evaluator.plot_performance_comparison(results, 'cst_performance_plots.png')

    # Human-readable console summary.
    if 'comprehensive_report' in results:
        report = results['comprehensive_report']
        summary = report['summary']
        banner = "=" * 50
        print("\n" + banner)
        print("CST EVALUATION SUMMARY")
        print(banner)
        print(f"Overall Assessment: {summary.get('assessment', 'N/A')}")
        print(f"Overall Performance: {summary.get('overall_performance', 0):.3f}")
        print(f"WSD Accuracy: {summary.get('wsd_accuracy', 0):.3f}")
        print(f"Cache Hit Rate: {summary.get('cache_hit_rate', 0):.3f}")
        print(f"VQA Accuracy: {summary.get('vqa_accuracy', 0):.3f}")
        print("\nRecommendations:")
        for idx, recommendation in enumerate(report['recommendations'], 1):
            print(f"{idx}. {recommendation}")
        print(banner)
    logger.info("Evaluation completed!")
# Script entry point: run the full evaluation pipeline when executed directly.
if __name__ == "__main__":
    main()