"""
Evaluation on simulated test set with 30 random samplings.
Implements evaluation protocol from Section 4.4.
"""
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
import numpy as np

from config import Config
from models.mmrm import MMRM
from data_utils.dataset import MMRMDataset
from evaluation.metrics import RestorationMetrics
from utils.font_utils import FontManager
from utils.tensorboard_tracker import TensorBoardTracker


def evaluate_on_test_set(
    config: Config,
    checkpoint_path: str,
    num_samples: int = 30,
    num_masks: int = 1
) -> dict:
"""
Evaluate model on test set with multiple random samplings.
As per paper Section 4.4: "all simulation results are the averages obtained
after randomly sampling the damaged characters on the test set 30 times"
Args:
config: Configuration object
checkpoint_path: Path to model checkpoint
num_samples: Number of random samplings (default: 30)
num_masks: Number of masks per sample (1 for single, or use random 1-5)
Returns:
Dictionary of averaged metrics
"""
    device = torch.device(config.device if torch.cuda.is_available() else "cpu")

    # Initialize TensorBoard tracker
    tb_tracker = TensorBoardTracker(config)

    # Load model
    model = MMRM(config).to(device)
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
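    # Expected checkpoint layouts, inferred from the loading logic below
    # (key names come from this file; exact contents depend on the training scripts):
    #   Phase 1: {'model_state_dict': ContextEncoder weights, 'decoder_state_dict': TextDecoder weights, ...}
    #   Phase 2: {'model_state_dict': full MMRM weights, ...}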
    # Check if this is a Phase 1 checkpoint (separate state dicts) or Phase 2 (full model)
    if 'decoder_state_dict' in checkpoint:
        print("Detected Phase 1 checkpoint (separate encoder/decoder state dicts).")
        # Phase 1 saves context_encoder as 'model_state_dict' and text_decoder as 'decoder_state_dict'
        try:
            model.context_encoder.load_state_dict(checkpoint['model_state_dict'])
            model.text_decoder.load_state_dict(checkpoint['decoder_state_dict'])
            print("Successfully loaded Phase 1 weights into ContextEncoder and TextDecoder.")
            print("Warning: ImageEncoder and ImageDecoder are using random initialization (expected for Phase 1).")
        except RuntimeError as e:
            # Report the load error in detail, then re-raise with the original traceback
            print(f"Error loading Phase 1 weights: {e}")
            raise
    else:
        # Phase 2 saves the full MMRM model in 'model_state_dict'
        model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    print(f"Loaded model from {checkpoint_path}")
    print(f"Evaluating with {num_samples} random samplings...")

    # Initialize tokenizer and font manager
    tokenizer = AutoTokenizer.from_pretrained(config.roberta_model)
    font_manager = FontManager(config.font_dir, config.image_size, config.min_black_pixels)

    # Start TensorBoard run
    with tb_tracker.start_run(run_name="Evaluation_Simulation"):
        # Log evaluation parameters
        tb_tracker.log_params({
            "checkpoint_path": checkpoint_path,
            "num_samples": num_samples,
            "num_masks": num_masks,
            "evaluation_type": "simulation"
        })
        tb_tracker.set_tags({
            "evaluation": "simulation",
            "num_samplings": str(num_samples)
        })

        # Run multiple samplings
        all_metrics = []
        for sample_idx in range(num_samples):
            print(f"\nSampling {sample_idx + 1}/{num_samples}")

            # Create test dataset (randomness comes from mask selection and augmentation)
            test_dataset = MMRMDataset(
                config,
                'test',
                tokenizer,
                font_manager,
                num_masks=num_masks,
                curriculum_epoch=None
            )
            test_loader = DataLoader(
                test_dataset,
                batch_size=config.batch_size,
                shuffle=False,
                num_workers=config.num_workers,
                pin_memory=config.pin_memory
            )

            # Evaluate
            metrics = RestorationMetrics(config.top_k_values)
            with torch.no_grad():
                for batch in test_loader:
                    input_ids = batch['input_ids'].to(device)
                    attention_mask = batch['attention_mask'].to(device)
                    mask_positions = batch['mask_positions'].to(device)
                    damaged_images = batch['damaged_images'].to(device)
                    labels = batch['labels'].to(device)

                    # Forward pass
                    text_logits, _ = model(input_ids, attention_mask, mask_positions, damaged_images)

                    # Update metrics
                    metrics.update(text_logits, labels)

            sample_metrics = metrics.compute()
            all_metrics.append(sample_metrics)

            # Log individual sampling metrics to TensorBoard
            sample_metrics_prefixed = {f"eval/sampling_{sample_idx}_{k}": v for k, v in sample_metrics.items()}
            tb_tracker.log_metrics(sample_metrics_prefixed)

            print(f" Acc={sample_metrics['accuracy']:.2f}%, "
                  f"Hit@5={sample_metrics['hit_5']:.2f}%, "
                  f"Hit@10={sample_metrics['hit_10']:.2f}%, "
                  f"Hit@20={sample_metrics['hit_20']:.2f}%, "
                  f"MRR={sample_metrics['mrr']:.2f}")

        # Compute average and std
        averaged_metrics = {}
        for key in all_metrics[0].keys():
            values = [m[key] for m in all_metrics]
            averaged_metrics[key] = np.mean(values)
            averaged_metrics[f'{key}_std'] = np.std(values)
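        # Note: np.std defaults to ddof=0 (population standard deviation), so the
        # ± values reported below are population rather than sample deviations.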
        # Log averaged metrics to TensorBoard
        tb_tracker.log_metrics({
            "eval/avg_accuracy": averaged_metrics['accuracy'],
            "eval/avg_hit_5": averaged_metrics['hit_5'],
            "eval/avg_hit_10": averaged_metrics['hit_10'],
            "eval/avg_hit_20": averaged_metrics['hit_20'],
            "eval/avg_mrr": averaged_metrics['mrr'],
            "eval/std_accuracy": averaged_metrics['accuracy_std'],
            "eval/std_hit_5": averaged_metrics['hit_5_std'],
            "eval/std_hit_10": averaged_metrics['hit_10_std'],
            "eval/std_hit_20": averaged_metrics['hit_20_std'],
            "eval/std_mrr": averaged_metrics['mrr_std']
        })

        # Log all metrics as a JSON artifact
        tb_tracker.log_dict(averaged_metrics, "evaluation_results.json", artifact_path="metrics")

    print(f"\n{'='*70}")
    print(f"Final Results (averaged over {num_samples} samplings):")
    print(f"{'='*70}")
    print(f"Accuracy: {averaged_metrics['accuracy']:.2f} ± {averaged_metrics['accuracy_std']:.2f}%")
    print(f"Hit@5:    {averaged_metrics['hit_5']:.2f} ± {averaged_metrics['hit_5_std']:.2f}%")
    print(f"Hit@10:   {averaged_metrics['hit_10']:.2f} ± {averaged_metrics['hit_10_std']:.2f}%")
    print(f"Hit@20:   {averaged_metrics['hit_20']:.2f} ± {averaged_metrics['hit_20_std']:.2f}%")
    print(f"MRR:      {averaged_metrics['mrr']:.2f} ± {averaged_metrics['mrr_std']:.2f}")
    print(f"{'='*70}")

    return averaged_metrics
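

# Example (illustrative; the checkpoint path below is hypothetical): calling the
# evaluation entry point programmatically instead of via the CLI block below.
#
#   from config import Config
#   cfg = Config()
#   results = evaluate_on_test_set(cfg, "checkpoints/phase2_best.pt", num_samples=30)
#   print(results["accuracy"], results["accuracy_std"])
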
if __name__ == "__main__":
    import sys

    if len(sys.argv) < 2:
        print("Usage: python evaluate_simulation.py <checkpoint_path> [num_samples]")
        sys.exit(1)
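    # Illustrative invocation (the checkpoint path here is hypothetical):
    #   python evaluate_simulation.py checkpoints/phase2_best.pt 30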
    checkpoint_path = sys.argv[1]
    num_samples = int(sys.argv[2]) if len(sys.argv) > 2 else 30

    config = Config()
    results = evaluate_on_test_set(config, checkpoint_path, num_samples)