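"""Evaluate a trained image-captioning model on the test split.

Computes corpus-level BLEU-1..4, METEOR, and ROUGE-L, prints the scores, and
saves a handful of sample predictions as JSONL.
"""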
import argparse
import json
import os
from typing import Dict, List

import torch
from nltk.translate.bleu_score import SmoothingFunction, corpus_bleu
from nltk.translate.meteor_score import single_meteor_score
from rouge_score import rouge_scorer

from .config import PathsConfig, TrainingConfig, get_device, set_seed
from .dataset import create_dataloader, create_tokenizer
from .model import ImageCaptioningModel


def parse_args() -> argparse.Namespace:
"""
Parse command-line arguments for evaluation.
"""
parser = argparse.ArgumentParser(description="Evaluate image captioning model on test set.")
parser.add_argument("--data_root", type=str, default="/Users/ryan/Downloads/visuallyimpair", help="Root path to dataset.")
parser.add_argument("--checkpoint", type=str, required=True, help="Path to model checkpoint (.pt).")
parser.add_argument("--batch_size", type=int, default=16, help="Batch size for evaluation.")
parser.add_argument("--max_length", type=int, default=50, help="Maximum caption length during generation.")
parser.add_argument("--num_beams", type=int, default=3, help="Number of beams for beam search.")
parser.add_argument("--seed", type=int, default=42, help="Random seed.")
parser.add_argument("--output_samples", type=str, default="evaluation_samples.jsonl", help="File to save sample predictions.")
return parser.parse_args()
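

# Example invocation (module path and file paths below are placeholders):
#   python -m <package>.evaluate --checkpoint checkpoints/best.pt \
#       --data_root /path/to/dataset --num_beams 3
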
def compute_metrics(
references: List[List[str]],
hypotheses: List[str],
) -> Dict[str, float]:
"""
Compute BLEU (1-4), METEOR, and ROUGE-L metrics.
"""
if not references or not hypotheses:
raise ValueError("References and hypotheses must be non-empty.")
if len(references) != len(hypotheses):
raise ValueError("Number of references and hypotheses must match.")
    smoothie = SmoothingFunction().method4
    # NLTK's corpus_bleu expects token lists, not raw strings (strings would be
    # scored as character n-grams), so tokenize by whitespace first.
    tokenized_refs = [[ref.split() for ref in ref_list] for ref_list in references]
    tokenized_hyps = [hyp.split() for hyp in hypotheses]
    # Corpus-level BLEU with the standard cumulative n-gram weights.
    bleu1 = corpus_bleu(
        tokenized_refs,
        tokenized_hyps,
        weights=(1.0, 0.0, 0.0, 0.0),
        smoothing_function=smoothie,
    )
    bleu2 = corpus_bleu(
        tokenized_refs,
        tokenized_hyps,
        weights=(0.5, 0.5, 0.0, 0.0),
        smoothing_function=smoothie,
    )
    bleu3 = corpus_bleu(
        tokenized_refs,
        tokenized_hyps,
        weights=(1.0 / 3, 1.0 / 3, 1.0 / 3, 0.0),
        smoothing_function=smoothie,
    )
    bleu4 = corpus_bleu(
        tokenized_refs,
        tokenized_hyps,
        weights=(0.25, 0.25, 0.25, 0.25),
        smoothing_function=smoothie,
    )
# METEOR
meteor_scores: List[float] = []
for ref_list, hyp in zip(references, hypotheses):
# Use the first reference for METEOR; tokenize by simple whitespace.
# If NLTK's WordNet data is missing, fall back to a simple unigram F1.
ref_tokens = ref_list[0].split()
hyp_tokens = hyp.split()
try:
meteor_scores.append(single_meteor_score(ref_tokens, hyp_tokens))
except LookupError:
ref_set = set(ref_tokens)
hyp_set = set(hyp_tokens)
if not ref_set or not hyp_set:
meteor_scores.append(0.0)
else:
overlap = len(ref_set & hyp_set)
precision = overlap / len(hyp_set)
recall = overlap / len(ref_set)
if precision + recall == 0:
meteor_scores.append(0.0)
else:
meteor_scores.append(2 * precision * recall / (precision + recall))
meteor = sum(meteor_scores) / max(1, len(meteor_scores))
# ROUGE-L
rouge = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True)
rouge_l_scores: List[float] = []
for ref_list, hyp in zip(references, hypotheses):
scores = rouge.score(ref_list[0], hyp)
rouge_l_scores.append(scores["rougeL"].fmeasure)
rouge_l = sum(rouge_l_scores) / max(1, len(rouge_l_scores))
return {
"BLEU-1": bleu1,
"BLEU-2": bleu2,
"BLEU-3": bleu3,
"BLEU-4": bleu4,
"METEOR": meteor,
"ROUGE-L": rouge_l,
}
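

# Illustrative only: a tiny smoke test for compute_metrics on made-up captions.
# It is not called anywhere in this script; run it by hand (e.g. from a REPL)
# to sanity-check the metric plumbing before a full evaluation pass.
def _smoke_test_metrics() -> None:
    refs = [["a dog runs in the park"], ["two people ride bicycles down the street"]]
    hyps = ["a dog runs in a park", "two people riding bikes on the street"]
    for name, value in compute_metrics(refs, hyps).items():
        print(f"{name}: {value:.4f}")

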
def run_evaluation(args: argparse.Namespace) -> None:
"""
Run evaluation on the test set, compute metrics, and save sample predictions.
"""
paths_cfg = PathsConfig(data_root=args.data_root)
training_cfg = TrainingConfig(
batch_size=args.batch_size,
max_caption_length=args.max_length,
num_epochs=1,
)
set_seed(args.seed)
device = get_device()
tokenizer = create_tokenizer()
test_loader, tokenizer = create_dataloader(
paths_cfg=paths_cfg,
training_cfg=training_cfg,
split="test",
tokenizer=tokenizer,
shuffle=False,
)
model = ImageCaptioningModel(training_cfg=training_cfg)
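    # Assumes the checkpoint is a raw state_dict saved via
    # torch.save(model.state_dict(), path); adjust here if it wraps metadata.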
state_dict = torch.load(args.checkpoint, map_location=device)
model.load_state_dict(state_dict)
model.to(device)
model.eval()
references: List[List[str]] = []
hypotheses: List[str] = []
num_samples_to_save = 50
saved_samples: List[Dict[str, str]] = []
with torch.no_grad():
for batch in test_loader:
images = batch["image"].to(device)
# Use the raw caption string from the dataset as reference
captions = batch["caption"]
# Generate predictions one image at a time to respect generate() constraints
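            # Design note: per-image calls keep decoding simple at the cost of
            # throughput; a batched generate() pass would be faster if supported.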
for idx in range(images.size(0)):
single_image = images[idx : idx + 1]
ref_caption = captions[idx]
pred_text_list = model.generate(
images=single_image,
max_length=args.max_length,
num_beams=args.num_beams,
)
pred_text = pred_text_list[0]
references.append([ref_caption])
hypotheses.append(pred_text)
if len(saved_samples) < num_samples_to_save:
saved_samples.append(
{
"image_id": batch["image_id"][idx],
"reference": ref_caption,
"prediction": pred_text,
}
)
metrics = compute_metrics(references, hypotheses)
print("Evaluation metrics:")
for name, value in metrics.items():
print(f" {name}: {value:.4f}")
# Save sample predictions
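    # Each JSONL line is one object of the form (values illustrative):
    #   {"image_id": "...", "reference": "...", "prediction": "..."}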
output_path = args.output_samples
with open(output_path, "w", encoding="utf-8") as f:
for sample in saved_samples:
f.write(json.dumps(sample) + "\n")
print(f"Saved {len(saved_samples)} sample predictions to {output_path}")


def main() -> None:
args = parse_args()
if not os.path.exists(args.checkpoint):
raise FileNotFoundError(f"Checkpoint not found: {args.checkpoint}")
run_evaluation(args)


if __name__ == "__main__":
main()