# MMRM/evaluation/metrics.py
# Author: rexera — "0-shot pipeline test" (commit 87224ba)
"""
Evaluation metrics for text restoration.
Implements Accuracy, Hits@K, and MRR as described in the paper.
"""
import inspect
from typing import Dict, List, Optional

import numpy as np
import torch
class RestorationMetrics:
    """
    Computes evaluation metrics for character restoration.

    Metrics from Section 4.3:
    - Accuracy: Top-1 exact match
    - Hits@K: Correct character in top K predictions
    - MRR: Mean Reciprocal Rank, truncated at max(top_k_values) —
      a label ranked below that contributes 0 to the mean.
    """

    def __init__(self, top_k_values: Optional[List[int]] = None):
        """
        Initialize metrics calculator.

        Args:
            top_k_values: K values for Hits@K metric (default: [5, 10, 20]).
        """
        # Copy the caller's list (and use a None sentinel) to avoid the
        # mutable-default-argument pitfall and aliasing.
        self.top_k_values = list(top_k_values) if top_k_values is not None else [5, 10, 20]
        self.reset()

    def reset(self):
        """Reset all accumulated metrics."""
        self.total_predictions = 0
        self.correct_top1 = 0
        self.hits_at_k = {k: 0 for k in self.top_k_values}
        # Running sum of per-prediction reciprocal ranks; MRR is
        # rr_sum / total_predictions. O(1) memory vs. storing every rank.
        self.rr_sum = 0.0

    def update(self, predictions: torch.Tensor, labels: torch.Tensor):
        """
        Update metrics with a batch of predictions.

        Args:
            predictions: Model logits [batch_size, num_masks, vocab_size]
            labels: True character IDs [batch_size, num_masks]
        """
        # Flatten: [B, N, V] -> [B*N, V] and [B, N] -> [B*N].
        predictions_flat = predictions.view(-1, predictions.size(-1))
        labels_flat = labels.view(-1)
        self.total_predictions += labels_flat.size(0)

        # One top-k pass at the largest K serves every metric below.
        # Clamp to vocab size so small vocabularies don't crash topk.
        max_k = min(max(self.top_k_values), predictions_flat.size(-1))
        _, top_k_indices = torch.topk(predictions_flat, k=max_k, dim=1)  # [B*N, max_k]

        # Accuracy: top-1 prediction equals the label.
        self.correct_top1 += (top_k_indices[:, 0] == labels_flat).sum().item()

        # Boolean match matrix: True where a top-k index equals the label.
        matches = top_k_indices == labels_flat.unsqueeze(1)  # [B*N, max_k]

        # Hits@K: label appears anywhere in the first k columns.
        for k in self.top_k_values:
            self.hits_at_k[k] += matches[:, :k].any(dim=1).sum().item()

        # MRR: 1 / (1-based rank of the label), 0 if not in the top max_k.
        # max(dim=1) on a bool tensor returns (any_true, index of first True).
        match_found, match_indices = matches.max(dim=1)
        ranks = match_indices.float() + 1.0
        reciprocal = torch.where(match_found, 1.0 / ranks, torch.zeros_like(ranks))
        self.rr_sum += reciprocal.sum().item()

    def compute(self) -> Dict[str, float]:
        """
        Compute final metrics.

        Returns:
            Dictionary with keys 'accuracy', 'hit_<k>' for each k, and 'mrr',
            all expressed as percentages in [0, 100].
        """
        if self.total_predictions == 0:
            # Nothing accumulated yet: report zeros instead of dividing by zero.
            return {
                'accuracy': 0.0,
                **{f'hit_{k}': 0.0 for k in self.top_k_values},
                'mrr': 0.0
            }
        total = self.total_predictions
        metrics = {
            'accuracy': 100.0 * self.correct_top1 / total,
            'mrr': 100.0 * self.rr_sum / total
        }
        for k in self.top_k_values:
            metrics[f'hit_{k}'] = 100.0 * self.hits_at_k[k] / total
        return metrics

    def get_summary(self) -> str:
        """Get formatted summary string, e.g. 'Acc: 50.00% | Hit@5: ... | MRR: ...'."""
        metrics = self.compute()
        summary = f"Acc: {metrics['accuracy']:.2f}%"
        for k in self.top_k_values:
            summary += f" | Hit@{k}: {metrics[f'hit_{k}']:.2f}%"
        summary += f" | MRR: {metrics['mrr']:.2f}"
        return summary
def evaluate_model(model, dataloader, device, num_samples: int = 1):
    """
    Evaluate model on a dataset.

    Args:
        model: Model to evaluate (MMRM, language-model, or image-only
            baseline; dispatch is based on the parameter names of
            ``model.forward``).
        dataloader: DataLoader yielding dicts with keys 'input_ids',
            'attention_mask', 'mask_positions', 'damaged_images', 'labels'.
        device: Device to run evaluation on.
        num_samples: Number of random samplings (30 for final evaluation per paper).

    Returns:
        Dictionary of metrics. For num_samples == 1, the single run's
        metrics; otherwise each metric is averaged over samplings and a
        '<name>_std' entry carries its standard deviation.
    """
    model.eval()

    # Inspect the forward signature once, outside the loops. Exact
    # parameter-name membership replaces the old substring test on
    # str(co_varnames), which also contained local variable names and
    # could misroute the dispatch on a false positive.
    forward_params = set(inspect.signature(model.forward).parameters)

    all_metrics = []
    for _ in range(num_samples):
        metrics = RestorationMetrics()
        with torch.no_grad():
            for batch in dataloader:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                mask_positions = batch['mask_positions'].to(device)
                damaged_images = batch['damaged_images'].to(device)
                labels = batch['labels'].to(device)

                # Forward pass, dispatched on the model's interface.
                if 'damaged_images' in forward_params:
                    # MMRM model: returns (text_logits, image_output).
                    text_logits, _ = model(input_ids, attention_mask, mask_positions, damaged_images)
                elif 'mask_positions' in forward_params:
                    # Language-model baseline.
                    text_logits = model(input_ids, attention_mask, mask_positions)
                else:
                    # Image-only baseline.
                    text_logits = model(damaged_images)

                metrics.update(text_logits, labels)
        all_metrics.append(metrics.compute())

    if num_samples == 1:
        return all_metrics[0]

    # Mean and standard deviation of each metric over the samplings.
    averaged_metrics = {}
    for key in all_metrics[0]:
        values = [m[key] for m in all_metrics]
        averaged_metrics[key] = np.mean(values)
        averaged_metrics[f'{key}_std'] = np.std(values)
    return averaged_metrics