Spaces:
Sleeping
Sleeping
"""
Model Evaluation Script

Evaluate trained model on validation/test sets with comprehensive metrics.

Usage:
    python scripts/evaluate.py --model outputs/checkpoints/best_doctamper.pth --dataset doctamper
"""
import argparse
import json
import sys
from pathlib import Path

import numpy as np
import torch
from tqdm import tqdm

# Make the project root importable so the `src` package resolves when this
# file is run directly as a script (must happen before the src.* imports).
sys.path.insert(0, str(Path(__file__).parent.parent))

from src.config import get_config
from src.data import get_dataset
from src.models import get_model
from src.training.metrics import SegmentationMetrics
from src.utils import plot_training_curves
def parse_args():
    """Parse command-line arguments for the evaluation script.

    Returns:
        argparse.Namespace with fields: model, dataset, split, output, config.
    """
    parser = argparse.ArgumentParser(description="Evaluate forgery detection model")

    # Required inputs: which checkpoint to load and which dataset to score.
    parser.add_argument('--model', required=True, type=str,
                        help='Path to model checkpoint')
    parser.add_argument('--dataset', required=True, type=str,
                        choices=['doctamper', 'rtm', 'casia', 'receipts'],
                        help='Dataset to evaluate on')

    # Optional knobs with sensible defaults.
    parser.add_argument('--split', default='val', type=str,
                        help='Data split (val/test)')
    parser.add_argument('--output', default='outputs/evaluation', type=str,
                        help='Output directory')
    parser.add_argument('--config', default='config.yaml', type=str,
                        help='Path to config file')

    return parser.parse_args()
def _load_model(config, checkpoint_path, device):
    """Build the model from config and load weights from `checkpoint_path`.

    Accepts both full training checkpoints (dicts carrying a
    'model_state_dict' key) and bare state dicts. Returns the model on
    `device` in eval mode.
    """
    model = get_model(config).to(device)
    checkpoint = torch.load(checkpoint_path, map_location=device)
    if 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        model.load_state_dict(checkpoint)
    model.eval()
    return model


def _per_sample_scores(pred_binary, mask, eps=1e-8):
    """Return (iou, dice) for one binarized prediction / ground-truth pair.

    `eps` guards against division by zero when both tensors are empty.
    """
    intersection = (pred_binary * mask).sum().item()
    pred_sum = pred_binary.sum().item()
    mask_sum = mask.sum().item()
    union = pred_sum + mask_sum - intersection
    iou = intersection / (union + eps)
    dice = (2 * intersection) / (pred_sum + mask_sum + eps)
    return iou, dice


def main():
    """Evaluate a trained forgery-detection model on one dataset split.

    Loads the checkpoint named on the command line, runs inference over
    every sample of the chosen split, aggregates segmentation metrics, and
    writes the results to `<output>/<dataset>_<split>_results.json`.
    """
    args = parse_args()
    config = get_config(args.config)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print("\n" + "=" * 60)
    print("Model Evaluation")
    print("=" * 60)
    print(f"Model: {args.model}")
    print(f"Dataset: {args.dataset}")
    print(f"Split: {args.split}")
    print(f"Device: {device}")
    print("=" * 60)

    model = _load_model(config, args.model, device)
    print("Model loaded")

    dataset = get_dataset(config, args.dataset, split=args.split)
    print(f"Dataset loaded: {len(dataset)} samples")

    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)

    metrics = SegmentationMetrics()
    # Some datasets provide only image-level labels; pixel metrics are
    # meaningful only when a real ground-truth mask exists.
    has_pixel_mask = config.has_pixel_mask(args.dataset)
    print("\nEvaluating...")

    all_ious = []
    all_dices = []
    with torch.no_grad():
        for i in tqdm(range(len(dataset)), desc="Evaluating"):
            try:
                image, mask, metadata = dataset[i]
                # Add a batch dimension and move to the eval device.
                image = image.unsqueeze(0).to(device)
                mask = mask.unsqueeze(0).to(device)

                logits, _ = model(image)
                probs = torch.sigmoid(logits)

                if has_pixel_mask:
                    metrics.update(probs, mask, has_pixel_mask=True)
                    # Per-sample IoU/Dice at the conventional 0.5 threshold.
                    pred_binary = (probs > 0.5).float()
                    iou, dice = _per_sample_scores(pred_binary, mask)
                    all_ious.append(iou)
                    all_dices.append(dice)
            except Exception as e:
                # Best-effort evaluation: report and skip corrupt/unreadable
                # samples rather than aborting the whole run.
                print(f"Error on sample {i}: {e}")
                continue

    results = metrics.compute()

    if has_pixel_mask and all_ious:
        # Cast numpy scalars to plain Python floats: json.dump would
        # otherwise raise "Object of type float64 is not JSON serializable".
        results['iou_mean'] = float(np.mean(all_ious))
        results['iou_std'] = float(np.std(all_ious))
        results['dice_mean'] = float(np.mean(all_dices))
        results['dice_std'] = float(np.std(all_dices))

    print("\n" + "=" * 60)
    print("Evaluation Results")
    print("=" * 60)
    for key, value in results.items():
        if isinstance(value, float):
            print(f"  {key}: {value:.4f}")

    results_path = output_dir / f'{args.dataset}_{args.split}_results.json'
    with open(results_path, 'w') as f:
        json.dump(results, f, indent=2)
    print(f"\nResults saved to: {results_path}")
    print("=" * 60)
# Entry point: run the evaluation only when executed as a script,
# not when this module is imported.
if __name__ == '__main__':
    main()