|
|
""" |
|
|
Evaluation metrics for text restoration. |
|
|
Implements Accuracy, Hits@K, and MRR as described in the paper. |
|
|
""" |
|
|
|
|
|
from typing import Dict, List, Optional

import numpy as np
import torch
|
|
|
|
|
|
|
|
class RestorationMetrics:
    """
    Computes evaluation metrics for character restoration.

    Metrics from Section 4.3:
    - Accuracy: Top-1 exact match
    - Hits@K: Correct character in top K predictions
    - MRR: Mean Reciprocal Rank, computed over the top max(K) predictions;
      a label absent from that list contributes a reciprocal rank of 0.
    """

    def __init__(self, top_k_values: Optional[List[int]] = None):
        """
        Initialize metrics calculator.

        Args:
            top_k_values: K values for Hits@K metric. Defaults to [5, 10, 20].
        """
        # Copy (or build) the list so instances never share a mutable default.
        self.top_k_values = list(top_k_values) if top_k_values is not None else [5, 10, 20]
        self.reset()

    def reset(self):
        """Reset all accumulated metrics."""
        self.total_predictions = 0
        self.correct_top1 = 0
        self.hits_at_k = {k: 0 for k in self.top_k_values}
        self.reciprocal_ranks = []
        # Running sum of reciprocal ranks, kept in sync with reciprocal_ranks.
        # (Previously lazily created in update() and never cleared here, so it
        # silently accumulated across reset() calls.)
        self.rr_sum = 0.0

    def update(self, predictions: torch.Tensor, labels: torch.Tensor):
        """
        Update metrics with a batch of predictions.

        Args:
            predictions: Model logits [batch_size, num_masks, vocab_size]
            labels: True character IDs [batch_size, num_masks]
        """
        # Flatten batch and mask dimensions into one axis of predictions.
        predictions_flat = predictions.view(-1, predictions.size(-1))
        labels_flat = labels.view(-1)

        self.total_predictions += labels_flat.size(0)

        # topk cannot request more entries than the vocabulary provides,
        # so clamp max_k to the vocab size.
        max_k = min(max(self.top_k_values), predictions_flat.size(-1))
        _, top_k_indices = torch.topk(predictions_flat, k=max_k, dim=1)

        # Top-1 accuracy: the highest-scoring prediction matches the label.
        self.correct_top1 += (top_k_indices[:, 0] == labels_flat).sum().item()

        # Boolean matrix marking where the label appears among the top-k.
        labels_expanded = labels_flat.unsqueeze(1).expand_as(top_k_indices)
        matches = top_k_indices == labels_expanded

        for k in self.top_k_values:
            # Slicing past max_k is harmless: it just keeps all columns.
            self.hits_at_k[k] += matches[:, :k].any(dim=1).sum().item()

        # Rank of the label within the top-k list. max() over a boolean
        # tensor returns the index of the first True, i.e. the match rank - 1.
        match_found, match_indices = matches.max(dim=1)
        ranks = match_indices.float() + 1.0
        reciprocal_ranks = torch.where(match_found, 1.0 / ranks, torch.zeros_like(ranks))

        self.rr_sum += reciprocal_ranks.sum().item()
        self.reciprocal_ranks.extend(reciprocal_ranks.tolist())

    def compute(self) -> Dict[str, float]:
        """
        Compute final metrics as percentages.

        Returns:
            Dictionary with keys: 'accuracy', 'hit_<k>' for each configured k,
            and 'mrr'. All values are 0.0 when no predictions were recorded.
        """
        if self.total_predictions == 0:
            return {
                'accuracy': 0.0,
                **{f'hit_{k}': 0.0 for k in self.top_k_values},
                'mrr': 0.0
            }

        metrics = {
            'accuracy': 100.0 * self.correct_top1 / self.total_predictions,
            'mrr': 100.0 * np.mean(self.reciprocal_ranks) if self.reciprocal_ranks else 0.0
        }

        for k in self.top_k_values:
            metrics[f'hit_{k}'] = 100.0 * self.hits_at_k[k] / self.total_predictions

        return metrics

    def get_summary(self) -> str:
        """Get a one-line formatted summary of all metrics."""
        metrics = self.compute()
        summary = f"Acc: {metrics['accuracy']:.2f}%"
        for k in self.top_k_values:
            summary += f" | Hit@{k}: {metrics[f'hit_{k}']:.2f}%"
        summary += f" | MRR: {metrics['mrr']:.2f}"
        return summary
|
|
|
|
|
|
|
|
def evaluate_model(model, dataloader, device, num_samples: int = 1):
    """
    Evaluate model on a dataset.

    Args:
        model: Model to evaluate. The call convention is chosen by
            inspecting the declared parameters of ``model.forward``.
        dataloader: DataLoader for evaluation data; each batch is a dict
            with 'input_ids', 'attention_mask', 'mask_positions',
            'damaged_images' and 'labels' tensors.
        device: Device to run evaluation on
        num_samples: Number of random samplings (30 for final evaluation per paper)

    Returns:
        Dictionary of metrics. When num_samples == 1, the single run's
        metrics; otherwise per-key mean plus a '<key>_std' entry.
    """
    model.eval()

    all_metrics = []

    for _ in range(num_samples):
        metrics = RestorationMetrics()

        with torch.no_grad():
            for batch in dataloader:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                mask_positions = batch['mask_positions'].to(device)
                damaged_images = batch['damaged_images'].to(device)
                labels = batch['labels'].to(device)

                # Dispatch on the DECLARED parameters of forward().
                # The previous substring test against str(co_varnames)
                # could false-positive: co_varnames also lists local
                # variables, and 'in str(...)' matches partial names.
                # NOTE(review): co_argcount excludes *args / keyword-only
                # params — confirm all model variants use plain positional
                # parameters.
                code = model.forward.__code__
                forward_params = code.co_varnames[:code.co_argcount]

                if 'damaged_images' in forward_params:
                    # Multimodal model: text + image inputs; returns
                    # (text_logits, aux) — aux is discarded here.
                    text_logits, _ = model(input_ids, attention_mask, mask_positions, damaged_images)
                elif 'mask_positions' in forward_params:
                    # Text-only model.
                    text_logits = model(input_ids, attention_mask, mask_positions)
                else:
                    # Vision-only model.
                    text_logits = model(damaged_images)

                metrics.update(text_logits, labels)

        all_metrics.append(metrics.compute())

    if num_samples == 1:
        return all_metrics[0]

    # Report mean and standard deviation across the independent samplings.
    averaged_metrics = {}
    for key in all_metrics[0]:
        values = [m[key] for m in all_metrics]
        averaged_metrics[key] = np.mean(values)
        averaged_metrics[f'{key}_std'] = np.std(values)

    return averaged_metrics
|
|
|