"""
Model Evaluation Script
Evaluate trained model on validation/test sets with comprehensive metrics.
Usage:
python scripts/evaluate.py --model outputs/checkpoints/best_doctamper.pth --dataset doctamper
"""
import argparse
import sys
from pathlib import Path
import json
import numpy as np
from tqdm import tqdm
import torch
# Add src to path
sys.path.insert(0, str(Path(__file__).parent.parent))
from src.config import get_config
from src.models import get_model
from src.data import get_dataset
from src.training.metrics import SegmentationMetrics
from src.utils import plot_training_curves
def parse_args():
    """Parse command-line options for the evaluation script.

    Returns an ``argparse.Namespace`` with: ``model`` (checkpoint path,
    required), ``dataset`` (one of the supported dataset names, required),
    ``split`` (default ``'val'``), ``output`` (default
    ``'outputs/evaluation'``), and ``config`` (default ``'config.yaml'``).
    """
    cli = argparse.ArgumentParser(description="Evaluate forgery detection model")
    cli.add_argument('--model', type=str, required=True,
                     help='Path to model checkpoint')
    cli.add_argument('--dataset', type=str, required=True,
                     choices=['doctamper', 'rtm', 'casia', 'receipts'],
                     help='Dataset to evaluate on')
    cli.add_argument('--split', type=str, default='val',
                     help='Data split (val/test)')
    cli.add_argument('--output', type=str, default='outputs/evaluation',
                     help='Output directory')
    cli.add_argument('--config', type=str, default='config.yaml',
                     help='Path to config file')
    return cli.parse_args()
def main():
    """Evaluate a trained forgery-detection checkpoint on one dataset split.

    Loads the model and dataset named on the command line, runs a forward
    pass over every sample, aggregates segmentation metrics (plus per-sample
    IoU/Dice statistics for datasets with pixel masks), prints a summary,
    and writes the results to ``<output>/<dataset>_<split>_results.json``.
    """
    args = parse_args()

    # Load config
    config = get_config(args.config)

    # Device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print("\n" + "=" * 60)
    print("Model Evaluation")
    print("=" * 60)
    print(f"Model: {args.model}")
    print(f"Dataset: {args.dataset}")
    print(f"Split: {args.split}")
    print(f"Device: {device}")
    print("=" * 60)

    # Load model weights; accept both full training checkpoints
    # ({'model_state_dict': ...}) and bare state dicts.
    model = get_model(config).to(device)
    checkpoint = torch.load(args.model, map_location=device)
    if 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        model.load_state_dict(checkpoint)
    model.eval()
    print("Model loaded")

    # Load dataset
    dataset = get_dataset(config, args.dataset, split=args.split)
    print(f"Dataset loaded: {len(dataset)} samples")

    # Create output directory
    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Evaluate
    metrics = SegmentationMetrics()
    has_pixel_mask = config.has_pixel_mask(args.dataset)

    print("\nEvaluating...")
    all_ious = []
    all_dices = []
    eps = 1e-8  # guards against division by zero on empty predictions/masks
    with torch.no_grad():
        for i in tqdm(range(len(dataset)), desc="Evaluating"):
            try:
                image, mask, metadata = dataset[i]

                # Add a batch dimension and move to device
                image = image.unsqueeze(0).to(device)
                mask = mask.unsqueeze(0).to(device)

                # Forward pass; model returns (logits, aux) — aux unused here
                logits, _ = model(image)
                probs = torch.sigmoid(logits)

                # Pixel-level metrics only apply to datasets with pixel masks
                if has_pixel_mask:
                    metrics.update(probs, mask, has_pixel_mask=True)

                    # Per-sample IoU / Dice (hoist sums — each was computed
                    # twice in the original)
                    pred_binary = (probs > 0.5).float()
                    pred_sum = pred_binary.sum().item()
                    mask_sum = mask.sum().item()
                    intersection = (pred_binary * mask).sum().item()
                    union = pred_sum + mask_sum - intersection
                    all_ious.append(intersection / (union + eps))
                    all_dices.append((2 * intersection) / (pred_sum + mask_sum + eps))
            except Exception as e:
                # Best-effort: skip unreadable/corrupt samples rather than
                # aborting the whole evaluation run.
                print(f"Error on sample {i}: {e}")
                continue

    # Compute final metrics
    results = metrics.compute()

    # Add per-sample statistics. Cast numpy scalars to builtin floats so
    # json.dump below never rejects them and the isinstance(value, float)
    # print filter treats them uniformly.
    if has_pixel_mask and all_ious:
        results['iou_mean'] = float(np.mean(all_ious))
        results['iou_std'] = float(np.std(all_ious))
        results['dice_mean'] = float(np.mean(all_dices))
        results['dice_std'] = float(np.std(all_dices))

    # Print results
    print("\n" + "=" * 60)
    print("Evaluation Results")
    print("=" * 60)
    for key, value in results.items():
        if isinstance(value, float):
            print(f" {key}: {value:.4f}")

    # Save results
    results_path = output_dir / f'{args.dataset}_{args.split}_results.json'
    with open(results_path, 'w') as f:
        json.dump(results, f, indent=2)

    print(f"\nResults saved to: {results_path}")
    print("=" * 60)


if __name__ == '__main__':
    main()