# MMRM/evaluation/evaluate_simulation.py
# Author: rexera — "0-shot pipeline test" (commit 87224ba)
"""
Evaluation on simulated test set with 30 random samplings.
Implements evaluation protocol from Section 4.4.
"""
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
import numpy as np
from config import Config
from models.mmrm import MMRM
from data_utils.dataset import MMRMDataset
from evaluation.metrics import RestorationMetrics
from utils.font_utils import FontManager
from utils.tensorboard_tracker import TensorBoardTracker
def evaluate_on_test_set(
    config: Config,
    checkpoint_path: str,
    num_samples: int = 30,
    num_masks: int = 1
) -> dict:
    """
    Evaluate model on the test set with multiple random samplings.

    As per paper Section 4.4: "all simulation results are the averages obtained
    after randomly sampling the damaged characters on the test set 30 times".

    Args:
        config: Configuration object.
        checkpoint_path: Path to a model checkpoint (Phase 1 or Phase 2 format).
        num_samples: Number of random samplings to average over (default: 30).
        num_masks: Number of masks per sample (1 for single, or use random 1-5).

    Returns:
        Dictionary of metrics averaged over all samplings; each metric key
        also has a '<key>_std' companion with its standard deviation.

    Raises:
        RuntimeError: If the checkpoint state dicts do not match the model.
    """
    # BUG FIX: the original condition was
    #   torch.cuda.is_available() or config.device == "cuda"
    # which selected "cuda" even when CUDA was unavailable (whenever the
    # config asked for it) and then crashed at .to(device). Fall back to
    # CPU whenever CUDA is not actually available.
    device = torch.device(config.device if torch.cuda.is_available() else "cpu")

    # Initialize TensorBoard tracker
    tb_tracker = TensorBoardTracker(config)

    # Load model
    model = MMRM(config).to(device)
    # SECURITY NOTE: weights_only=False unpickles arbitrary objects; only
    # load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    # Check if this is a Phase 1 checkpoint (separate state dicts) or Phase 2 (full model)
    if 'decoder_state_dict' in checkpoint:
        print("Detected Phase 1 checkpoint (separate encoder/decoder state dicts).")
        # Phase 1 saves context_encoder as 'model_state_dict' and text_decoder as 'decoder_state_dict'
        try:
            model.context_encoder.load_state_dict(checkpoint['model_state_dict'])
            model.text_decoder.load_state_dict(checkpoint['decoder_state_dict'])
            print("Successfully loaded Phase 1 weights into ContextEncoder and TextDecoder.")
            print("Warning: ImageEncoder and ImageDecoder are using random initialization (expected for Phase 1).")
        except RuntimeError as e:
            print(f"Error loading Phase 1 weights: {e}")
            raise  # bare raise preserves the original traceback (was `raise e`)
    else:
        # Phase 2 saves the full MMRM model in 'model_state_dict'
        model.load_state_dict(checkpoint['model_state_dict'])

    model.eval()
    print(f"Loaded model from {checkpoint_path}")
    print(f"Evaluating with {num_samples} random samplings...")

    # Initialize tokenizer and font manager
    tokenizer = AutoTokenizer.from_pretrained(config.roberta_model)
    font_manager = FontManager(config.font_dir, config.image_size, config.min_black_pixels)

    # Start TensorBoard run
    with tb_tracker.start_run(run_name="Evaluation_Simulation"):
        # Log evaluation parameters
        tb_tracker.log_params({
            "checkpoint_path": checkpoint_path,
            "num_samples": num_samples,
            "num_masks": num_masks,
            "evaluation_type": "simulation"
        })
        tb_tracker.set_tags({
            "evaluation": "simulation",
            "num_samplings": str(num_samples)
        })

        # Run multiple samplings
        all_metrics = []
        for sample_idx in range(num_samples):
            print(f"\nSampling {sample_idx + 1}/{num_samples}")

            # A fresh dataset per sampling re-rolls the random mask
            # selection / augmentation inside MMRMDataset.
            test_dataset = MMRMDataset(
                config,
                'test',
                tokenizer,
                font_manager,
                num_masks=num_masks,
                curriculum_epoch=None
            )
            test_loader = DataLoader(
                test_dataset,
                batch_size=config.batch_size,
                shuffle=False,
                num_workers=config.num_workers,
                pin_memory=config.pin_memory
            )

            # Evaluate one full pass over the test loader
            metrics = RestorationMetrics(config.top_k_values)
            with torch.no_grad():
                for batch in test_loader:
                    input_ids = batch['input_ids'].to(device)
                    attention_mask = batch['attention_mask'].to(device)
                    mask_positions = batch['mask_positions'].to(device)
                    damaged_images = batch['damaged_images'].to(device)
                    labels = batch['labels'].to(device)
                    # Forward pass (image reconstruction output unused here)
                    text_logits, _ = model(input_ids, attention_mask, mask_positions, damaged_images)
                    metrics.update(text_logits, labels)

            sample_metrics = metrics.compute()
            all_metrics.append(sample_metrics)

            # Log individual sampling metrics to TensorBoard
            sample_metrics_prefixed = {f"eval/sampling_{sample_idx}_{k}": v for k, v in sample_metrics.items()}
            tb_tracker.log_metrics(sample_metrics_prefixed)

            print(f" Acc={sample_metrics['accuracy']:.2f}%, "
                  f"Hit@5={sample_metrics['hit_5']:.2f}%, "
                  f"Hit@10={sample_metrics['hit_10']:.2f}%, "
                  f"Hit@20={sample_metrics['hit_20']:.2f}%, "
                  f"MRR={sample_metrics['mrr']:.2f}")

        # Compute mean and std across samplings for every metric key
        averaged_metrics = {}
        for key in all_metrics[0].keys():
            values = [m[key] for m in all_metrics]
            averaged_metrics[key] = np.mean(values)
            averaged_metrics[f'{key}_std'] = np.std(values)

        # Log averaged metrics to TensorBoard
        tb_tracker.log_metrics({
            "eval/avg_accuracy": averaged_metrics['accuracy'],
            "eval/avg_hit_5": averaged_metrics['hit_5'],
            "eval/avg_hit_10": averaged_metrics['hit_10'],
            "eval/avg_hit_20": averaged_metrics['hit_20'],
            "eval/avg_mrr": averaged_metrics['mrr'],
            "eval/std_accuracy": averaged_metrics['accuracy_std'],
            "eval/std_hit_5": averaged_metrics['hit_5_std'],
            "eval/std_hit_10": averaged_metrics['hit_10_std'],
            "eval/std_hit_20": averaged_metrics['hit_20_std'],
            "eval/std_mrr": averaged_metrics['mrr_std']
        })

        # Log all metrics as a JSON artifact
        tb_tracker.log_dict(averaged_metrics, "evaluation_results.json", artifact_path="metrics")

    print(f"\n{'='*70}")
    print(f"Final Results (averaged over {num_samples} samplings):")
    print(f"{'='*70}")
    print(f"Accuracy: {averaged_metrics['accuracy']:.2f} ± {averaged_metrics['accuracy_std']:.2f}%")
    print(f"Hit@5: {averaged_metrics['hit_5']:.2f} ± {averaged_metrics['hit_5_std']:.2f}%")
    print(f"Hit@10: {averaged_metrics['hit_10']:.2f} ± {averaged_metrics['hit_10_std']:.2f}%")
    print(f"Hit@20: {averaged_metrics['hit_20']:.2f} ± {averaged_metrics['hit_20_std']:.2f}%")
    print(f"MRR: {averaged_metrics['mrr']:.2f} ± {averaged_metrics['mrr_std']:.2f}")
    print(f"{'='*70}")

    return averaged_metrics
if __name__ == "__main__":
    import sys

    def _main(argv) -> dict:
        """Parse CLI arguments and run the simulated-test-set evaluation."""
        if len(argv) < 2:
            print("Usage: python evaluate_simulation.py <checkpoint_path> [num_samples]")
            sys.exit(1)
        ckpt_path = argv[1]
        # Optional second argument overrides the default of 30 samplings.
        samplings = int(argv[2]) if len(argv) > 2 else 30
        return evaluate_on_test_set(Config(), ckpt_path, samplings)

    results = _main(sys.argv)