"""
Evaluation metrics for text restoration.

Implements Accuracy, Hits@K, and MRR as described in the paper.
"""

import inspect
from typing import Dict, List, Sequence

import numpy as np
import torch


class RestorationMetrics:
    """
    Computes evaluation metrics for character restoration.

    Metrics from Section 4.3:
    - Accuracy: Top-1 exact match
    - Hits@K: Correct character in top K predictions
    - MRR: Mean Reciprocal Rank

    All values returned by ``compute()`` are scaled by 100 (percentages).
    MRR is truncated: only the first ``max(top_k_values)`` predictions (or
    the whole vocabulary, if smaller) are ranked, and a true label outside
    that window contributes a reciprocal rank of 0.
    """

    def __init__(self, top_k_values: Sequence[int] = (5, 10, 20)):
        """
        Initialize metrics calculator.

        Args:
            top_k_values: K values for Hits@K metric. The sequence is copied,
                so neither the default nor a caller's list is aliased.

        Raises:
            ValueError: If ``top_k_values`` is empty or contains a K < 1.
        """
        # Copy into a fresh list: a shared mutable default would otherwise be
        # aliased by every instance constructed without arguments.
        self.top_k_values: List[int] = list(top_k_values)
        if not self.top_k_values:
            raise ValueError("top_k_values must contain at least one K")
        if min(self.top_k_values) < 1:
            raise ValueError(
                f"every K in top_k_values must be >= 1, got {self.top_k_values}"
            )
        self.reset()

    def reset(self):
        """Reset all accumulated metrics."""
        self.total_predictions = 0
        self.correct_top1 = 0
        self.hits_at_k = {k: 0 for k in self.top_k_values}
        # Per-position reciprocal ranks, kept as a public record for
        # backward compatibility.
        self.reciprocal_ranks: List[float] = []
        # Running sum of reciprocal ranks; compute() derives MRR from it.
        self.rr_sum = 0.0

    def update(self, predictions: torch.Tensor, labels: torch.Tensor):
        """
        Update metrics with a batch of predictions.

        Args:
            predictions: Model logits [batch_size, num_masks, vocab_size]
            labels: True character IDs [batch_size, num_masks]

        Every label position is counted toward the total. A label that does
        not appear among the top-K indices (including any out-of-vocabulary
        padding value, if the caller uses one) counts as a miss.
        """
        # reshape (unlike view) also accepts non-contiguous inputs.
        # predictions: [B, N, V] -> [B*N, V]; labels: [B, N] -> [B*N]
        logits = predictions.reshape(-1, predictions.size(-1))
        targets = labels.reshape(-1)

        self.total_predictions += targets.size(0)

        # torch.topk raises if k exceeds the vocabulary size, so clamp it.
        max_k = min(max(self.top_k_values), logits.size(-1))
        _, top_k_indices = torch.topk(logits, k=max_k, dim=1)  # [B*N, max_k]

        # matches[i, j] is True iff the j-th best prediction equals label i.
        # topk returns distinct indices, so each row has at most one True.
        matches = top_k_indices == targets.unsqueeze(1)  # [B*N, max_k]

        # Accuracy: the top-1 prediction equals the label.
        self.correct_top1 += matches[:, 0].sum().item()

        # Hits@K: the label appears within the first K columns.
        for k in self.top_k_values:
            self.hits_at_k[k] += matches[:, :k].any(dim=1).sum().item()

        # MRR: with at most one True per row, dividing by the 1-based rank
        # and summing along the row gives 1/rank, or 0 when not found.
        ranks = torch.arange(
            1, max_k + 1, device=matches.device, dtype=torch.float32
        )
        reciprocal_ranks = (matches.float() / ranks).sum(dim=1)

        batch_rr = reciprocal_ranks.tolist()
        self.reciprocal_ranks.extend(batch_rr)
        # Summing Python floats accumulates in double precision.
        self.rr_sum += sum(batch_rr)

    def compute(self) -> Dict[str, float]:
        """
        Compute final metrics.

        Returns:
            Dictionary with keys (in this order): 'accuracy', one
            'hit_{k}' per K in ``top_k_values``, and 'mrr'. All values
            are percentages; every value is 0.0 if nothing was recorded.
        """
        total = self.total_predictions
        if total == 0:
            return {
                'accuracy': 0.0,
                **{f'hit_{k}': 0.0 for k in self.top_k_values},
                'mrr': 0.0,
            }

        metrics = {'accuracy': 100.0 * self.correct_top1 / total}
        for k in self.top_k_values:
            metrics[f'hit_{k}'] = 100.0 * self.hits_at_k[k] / total
        metrics['mrr'] = 100.0 * self.rr_sum / total
        return metrics

    def get_summary(self) -> str:
        """Get formatted summary string."""
        metrics = self.compute()
        summary = f"Acc: {metrics['accuracy']:.2f}%"
        for k in self.top_k_values:
            summary += f" | Hit@{k}: {metrics[f'hit_{k}']:.2f}%"
        summary += f" | MRR: {metrics['mrr']:.2f}"
        return summary


def _forward_param_names(model) -> frozenset:
    """Return the parameter names of ``model.forward`` (empty if unavailable)."""
    forward = getattr(model, 'forward', None)
    if forward is None:
        return frozenset()
    try:
        return frozenset(inspect.signature(forward).parameters)
    except (TypeError, ValueError):
        # Some callables (e.g. C-implemented ones) expose no signature.
        return frozenset()


def evaluate_model(model, dataloader, device, num_samples: int = 1):
    """
    Evaluate model on a dataset.

    The call convention is chosen from the parameter names of
    ``model.forward``:
    - has ``damaged_images``: multimodal model, called with
      (input_ids, attention_mask, mask_positions, damaged_images) and
      expected to return a ``(text_logits, _)`` pair;
    - otherwise has ``mask_positions``: language-model baseline, called with
      (input_ids, attention_mask, mask_positions);
    - otherwise: image-only baseline, called with (damaged_images).

    Each batch must be a mapping with the keys 'input_ids', 'attention_mask',
    'mask_positions', 'damaged_images' and 'labels'. The model is switched to
    eval mode and left there.

    Args:
        model: Model to evaluate
        dataloader: DataLoader for evaluation data
        device: Device to run evaluation on
        num_samples: Number of random samplings (30 for final evaluation per paper)

    Returns:
        Dictionary of averaged metrics over all samplings. With
        ``num_samples > 1`` each metric also gets a '{key}_std' entry.

    Raises:
        ValueError: If ``num_samples`` is less than 1.
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")

    model.eval()

    # Decide the call convention once, from real parameter names (code
    # object co_varnames would also include the method's local variables).
    param_names = _forward_param_names(model)
    if 'damaged_images' in param_names:
        mode = 'multimodal'
    elif 'mask_positions' in param_names:
        mode = 'text'
    else:
        mode = 'image'

    all_metrics = []
    for _sample_idx in range(num_samples):
        metrics = RestorationMetrics()

        with torch.no_grad():
            for batch in dataloader:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                mask_positions = batch['mask_positions'].to(device)
                damaged_images = batch['damaged_images'].to(device)
                labels = batch['labels'].to(device)

                if mode == 'multimodal':
                    # MMRM model
                    text_logits, _ = model(
                        input_ids, attention_mask, mask_positions, damaged_images
                    )
                elif mode == 'text':
                    # Language model baseline
                    text_logits = model(input_ids, attention_mask, mask_positions)
                else:
                    # Image-only baseline
                    text_logits = model(damaged_images)

                metrics.update(text_logits, labels)

        all_metrics.append(metrics.compute())

    if num_samples == 1:
        return all_metrics[0]

    # Average (and spread) of every metric across the samplings.
    averaged_metrics = {}
    for key in all_metrics[0]:
        values = [m[key] for m in all_metrics]
        averaged_metrics[key] = np.mean(values)
        averaged_metrics[f'{key}_std'] = np.std(values)
    return averaged_metrics