"""
Evaluation on the simulated test set with 30 random samplings.
Implements the evaluation protocol from Section 4.4 of the paper.
"""

import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
import numpy as np

from config import Config
from models.mmrm import MMRM
from data_utils.dataset import MMRMDataset
from evaluation.metrics import RestorationMetrics
from utils.font_utils import FontManager
from utils.tensorboard_tracker import TensorBoardTracker


def evaluate_on_test_set(
    config: Config,
    checkpoint_path: str,
    num_samples: int = 30,
    num_masks: int = 1
) -> dict:
    """
    Evaluate model on test set with multiple random samplings.
    
    As per paper Section 4.4: "all simulation results are the averages obtained
    after randomly sampling the damaged characters on the test set 30 times"
    
    Args:
        config: Configuration object
        checkpoint_path: Path to model checkpoint
        num_samples: Number of random samplings (default: 30)
        num_masks: Number of masks per sample (1 for single, or use random 1-5)
        
    Returns:
        Dictionary of averaged metrics
    """
    # Fall back to CPU when CUDA is unavailable, even if the config requests "cuda"
    device = torch.device(config.device if torch.cuda.is_available() else "cpu")
    
    # Initialize TensorBoard tracker
    tb_tracker = TensorBoardTracker(config)
    
    # Load model
    model = MMRM(config).to(device)
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    
    # Check if this is a Phase 1 checkpoint (separate state dicts) or Phase 2 (full model)
    if 'decoder_state_dict' in checkpoint:
        print("Detected Phase 1 checkpoint (separate encoder/decoder state dicts).")
        # Phase 1 saves context_encoder as 'model_state_dict' and text_decoder as 'decoder_state_dict'
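        # Expected Phase 1 checkpoint layout (other keys, e.g. optimizer state, may also be present):
        #   {'model_state_dict': <ContextEncoder weights>, 'decoder_state_dict': <TextDecoder weights>, ...}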
        try:
            model.context_encoder.load_state_dict(checkpoint['model_state_dict'])
            model.text_decoder.load_state_dict(checkpoint['decoder_state_dict'])
            print("Successfully loaded Phase 1 weights into ContextEncoder and TextDecoder.")
            print("Warning: ImageEncoder and ImageDecoder are using random initialization (expected for Phase 1).")
        except RuntimeError as e:
            # Report the state-dict mismatch before re-raising with the original traceback
            print(f"Error loading Phase 1 weights: {e}")
            raise
    else:
        # Phase 2 saves the full MMRM model in 'model_state_dict'
        model.load_state_dict(checkpoint['model_state_dict'])
        
    model.eval()
    
    print(f"Loaded model from {checkpoint_path}")
    print(f"Evaluating with {num_samples} random samplings...")
    
    # Initialize tokenizer and font manager
    tokenizer = AutoTokenizer.from_pretrained(config.roberta_model)
    font_manager = FontManager(config.font_dir, config.image_size, config.min_black_pixels)
    
    # Start TensorBoard run
    with tb_tracker.start_run(run_name="Evaluation_Simulation"):
        # Log evaluation parameters
        tb_tracker.log_params({
            "checkpoint_path": checkpoint_path,
            "num_samples": num_samples,
            "num_masks": num_masks,
            "evaluation_type": "simulation"
        })
        tb_tracker.set_tags({
            "evaluation": "simulation",
            "num_samplings": str(num_samples)
        })
        
        # Run multiple samplings
        all_metrics = []
        
        for sample_idx in range(num_samples):
            print(f"\nSampling {sample_idx + 1}/{num_samples}")
            
            # Create test dataset (randomness comes from mask selection and augmentation)
            test_dataset = MMRMDataset(
                config,
                'test',
                tokenizer,
                font_manager,
                num_masks=num_masks,
                curriculum_epoch=None
            )
            
            test_loader = DataLoader(
                test_dataset,
                batch_size=config.batch_size,
                shuffle=False,
                num_workers=config.num_workers,
                pin_memory=config.pin_memory
            )
            
            # Evaluate
            metrics = RestorationMetrics(config.top_k_values)
            
            with torch.no_grad():
                for batch in test_loader:
                    input_ids = batch['input_ids'].to(device)
                    attention_mask = batch['attention_mask'].to(device)
                    mask_positions = batch['mask_positions'].to(device)
                    damaged_images = batch['damaged_images'].to(device)
                    labels = batch['labels'].to(device)
                    
                    # Forward pass
                    text_logits, _ = model(input_ids, attention_mask, mask_positions, damaged_images)
                    
                    # Update metrics
                    metrics.update(text_logits, labels)
            
            sample_metrics = metrics.compute()
            all_metrics.append(sample_metrics)
            
            # Log individual sampling metrics to TensorBoard
            sample_metrics_prefixed = {f"eval/sampling_{sample_idx}_{k}": v for k, v in sample_metrics.items()}
            tb_tracker.log_metrics(sample_metrics_prefixed)
            
            print(f"  Acc={sample_metrics['accuracy']:.2f}%, "
                  f"Hit@5={sample_metrics['hit_5']:.2f}%, "
                  f"Hit@10={sample_metrics['hit_10']:.2f}%, "
                  f"Hit@20={sample_metrics['hit_20']:.2f}%, "
                  f"MRR={sample_metrics['mrr']:.2f}")
        
        # Compute average and std
        averaged_metrics = {}
        for key in all_metrics[0].keys():
            values = [m[key] for m in all_metrics]
            averaged_metrics[key] = np.mean(values)
            averaged_metrics[f'{key}_std'] = np.std(values)
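        # Resulting layout (metric names assume RestorationMetrics reports accuracy/hit_k/mrr,
        # as consumed below): {'accuracy': <mean>, 'accuracy_std': <std>, 'hit_5': <mean>, ...}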
        
        # Log averaged metrics to TensorBoard
        tb_tracker.log_metrics({
            "eval/avg_accuracy": averaged_metrics['accuracy'],
            "eval/avg_hit_5": averaged_metrics['hit_5'],
            "eval/avg_hit_10": averaged_metrics['hit_10'],
            "eval/avg_hit_20": averaged_metrics['hit_20'],
            "eval/avg_mrr": averaged_metrics['mrr'],
            "eval/std_accuracy": averaged_metrics['accuracy_std'],
            "eval/std_hit_5": averaged_metrics['hit_5_std'],
            "eval/std_hit_10": averaged_metrics['hit_10_std'],
            "eval/std_hit_20": averaged_metrics['hit_20_std'],
            "eval/std_mrr": averaged_metrics['mrr_std']
        })
        
        # Log all metrics as a JSON artifact
        tb_tracker.log_dict(averaged_metrics, "evaluation_results.json", artifact_path="metrics")
        
        print(f"\n{'='*70}")
        print(f"Final Results (averaged over {num_samples} samplings):")
        print(f"{'='*70}")
        print(f"Accuracy: {averaged_metrics['accuracy']:.2f} ± {averaged_metrics['accuracy_std']:.2f}%")
        print(f"Hit@5:    {averaged_metrics['hit_5']:.2f} ± {averaged_metrics['hit_5_std']:.2f}%")
        print(f"Hit@10:   {averaged_metrics['hit_10']:.2f} ± {averaged_metrics['hit_10_std']:.2f}%")
        print(f"Hit@20:   {averaged_metrics['hit_20']:.2f} ± {averaged_metrics['hit_20_std']:.2f}%")
        print(f"MRR:      {averaged_metrics['mrr']:.2f} ± {averaged_metrics['mrr_std']:.2f}")
        print(f"{'='*70}")
        
        return averaged_metrics



if __name__ == "__main__":
    import sys
    
    if len(sys.argv) < 2:
        print("Usage: python evaluate_simulation.py <checkpoint_path> [num_samples]")
        sys.exit(1)
    
    checkpoint_path = sys.argv[1]
    num_samples = int(sys.argv[2]) if len(sys.argv) > 2 else 30
    
    config = Config()
    results = evaluate_on_test_set(config, checkpoint_path, num_samples)