File size: 6,740 Bytes
87224ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
"""
Evaluation metrics for text restoration.
Implements Accuracy, Hits@K, and MRR as described in the paper.
"""

from typing import Dict, List, Sequence

import numpy as np
import torch


class RestorationMetrics:
    """
    Accumulates evaluation metrics for character restoration.

    Metrics from Section 4.3:
    - Accuracy: top-1 exact match
    - Hits@K: true character appears in the top-K predictions
    - MRR: Mean Reciprocal Rank, truncated at max(top_k_values) — labels
      ranked below the largest K contribute a reciprocal rank of 0.
    """

    def __init__(self, top_k_values: Sequence[int] = (5, 10, 20)):
        """
        Initialize the metrics accumulator.

        Args:
            top_k_values: K values for the Hits@K metric. Defaults to an
                immutable tuple so instances never share accumulator state
                (a mutable list default would be shared across calls).
        """
        self.top_k_values: List[int] = list(top_k_values)
        self.reset()

    def reset(self):
        """Reset all accumulated state, including the MRR accumulator."""
        self.total_predictions = 0
        self.correct_top1 = 0
        self.hits_at_k = {k: 0 for k in self.top_k_values}
        # Running sum of reciprocal ranks; MRR = rr_sum / total_predictions.
        # A sum + count is O(1) memory, unlike storing every rank in a list.
        self.rr_sum = 0.0

    def update(self, predictions: torch.Tensor, labels: torch.Tensor):
        """
        Update metrics with one batch of predictions.

        Args:
            predictions: Model logits [batch_size, num_masks, vocab_size]
            labels: True character IDs [batch_size, num_masks]
        """
        # Flatten to per-prediction rows: [B, N, V] -> [B*N, V], [B, N] -> [B*N]
        logits = predictions.view(-1, predictions.size(-1))
        targets = labels.view(-1)

        self.total_predictions += targets.size(0)

        # One topk call covering the largest K; clamp so a small vocabulary
        # (vocab_size < max K) cannot crash torch.topk.
        max_k = min(max(self.top_k_values), logits.size(-1))
        _, top_k_indices = torch.topk(logits, k=max_k, dim=1)  # [B*N, max_k]

        # Accuracy: top-1 prediction equals the label.
        self.correct_top1 += (top_k_indices[:, 0] == targets).sum().item()

        # Hits@K: label present anywhere in the top-k columns.
        matches = top_k_indices == targets.unsqueeze(1)  # [B*N, max_k] bool
        for k in self.top_k_values:
            # Slicing with k > max_k is safely clamped by Python semantics.
            self.hits_at_k[k] += matches[:, :k].any(dim=1).sum().item()

        # MRR: 1 / (1-based rank of the true label), 0 if not in top max_k.
        # matches.max over dim 1 yields (found_flag, first True index).
        match_found, match_indices = matches.max(dim=1)
        ranks = match_indices.float() + 1.0
        reciprocal = torch.where(match_found, 1.0 / ranks, torch.zeros_like(ranks))
        self.rr_sum += reciprocal.sum().item()

    def compute(self) -> Dict[str, float]:
        """
        Compute final metrics from the accumulated counts.

        Returns:
            Dictionary with keys: 'accuracy', 'hit_<k>' for each configured
            K, and 'mrr'. All values are percentages in [0, 100].
        """
        if self.total_predictions == 0:
            return {
                'accuracy': 0.0,
                **{f'hit_{k}': 0.0 for k in self.top_k_values},
                'mrr': 0.0
            }

        n = self.total_predictions
        metrics = {
            'accuracy': 100.0 * self.correct_top1 / n,
            'mrr': 100.0 * self.rr_sum / n
        }
        for k in self.top_k_values:
            metrics[f'hit_{k}'] = 100.0 * self.hits_at_k[k] / n

        return metrics

    def get_summary(self) -> str:
        """Return a one-line formatted summary of the current metrics."""
        m = self.compute()
        parts = [f"Acc: {m['accuracy']:.2f}%"]
        parts += [f"Hit@{k}: {m[f'hit_{k}']:.2f}%" for k in self.top_k_values]
        parts.append(f"MRR: {m['mrr']:.2f}")
        return " | ".join(parts)


def evaluate_model(model, dataloader, device, num_samples: int = 1):
    """
    Evaluate model on a dataset.

    Args:
        model: Model to evaluate. The forward signature decides dispatch:
            a `damaged_images` parameter -> multimodal (returns a logits
            tuple); a `mask_positions` parameter -> language-model baseline;
            otherwise -> image-only baseline.
        dataloader: DataLoader yielding dicts with 'input_ids',
            'attention_mask', 'mask_positions', 'damaged_images', 'labels'.
        device: Device to run evaluation on.
        num_samples: Number of random samplings (30 for final evaluation
            per paper).

    Returns:
        Dictionary of metrics. For num_samples == 1, the single run's
        metrics; otherwise each metric's mean plus a '<key>_std' entry.
    """
    model.eval()

    # Inspect the forward signature once — it is invariant across batches.
    # Use tuple membership over the *parameter* names only
    # (co_varnames[:co_argcount]); the old substring test on
    # str(co_varnames) also matched locals and near-miss names like
    # 'damaged_images_aug'.
    forward_code = model.forward.__code__
    arg_names = forward_code.co_varnames[:forward_code.co_argcount]
    takes_images = 'damaged_images' in arg_names
    takes_masks = 'mask_positions' in arg_names

    all_metrics = []

    for _ in range(num_samples):
        metrics = RestorationMetrics()

        with torch.no_grad():
            for batch in dataloader:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                mask_positions = batch['mask_positions'].to(device)
                damaged_images = batch['damaged_images'].to(device)
                labels = batch['labels'].to(device)

                if takes_images:
                    # MMRM model: returns (text_logits, image_output)
                    text_logits, _ = model(input_ids, attention_mask,
                                           mask_positions, damaged_images)
                elif takes_masks:
                    # Language model baseline
                    text_logits = model(input_ids, attention_mask,
                                        mask_positions)
                else:
                    # Image-only baseline
                    text_logits = model(damaged_images)

                metrics.update(text_logits, labels)

        all_metrics.append(metrics.compute())

    # Single sampling: return the run's metrics directly (no *_std keys).
    if num_samples == 1:
        return all_metrics[0]

    averaged_metrics = {}
    for key in all_metrics[0]:
        values = [m[key] for m in all_metrics]
        averaged_metrics[key] = np.mean(values)
        averaged_metrics[f'{key}_std'] = np.std(values)

    return averaged_metrics