File size: 4,933 Bytes
51fdac5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
"""

Model Evaluation Script



Evaluate trained model on validation/test sets with comprehensive metrics.



Usage:

    python scripts/evaluate.py --model outputs/checkpoints/best_doctamper.pth --dataset doctamper

"""

import argparse
import sys
from pathlib import Path
import json
import numpy as np
from tqdm import tqdm
import torch

# Add src to path
sys.path.insert(0, str(Path(__file__).parent.parent))

from src.config import get_config
from src.models import get_model
from src.data import get_dataset
from src.training.metrics import SegmentationMetrics
from src.utils import plot_training_curves


def parse_args():
    """Parse and return command-line arguments for the evaluation script."""
    parser = argparse.ArgumentParser(description="Evaluate forgery detection model")

    parser.add_argument(
        '--model', type=str, required=True,
        help='Path to model checkpoint')
    parser.add_argument(
        '--dataset', type=str, required=True,
        choices=['doctamper', 'rtm', 'casia', 'receipts'],
        help='Dataset to evaluate on')
    parser.add_argument(
        '--split', type=str, default='val',
        help='Data split (val/test)')
    parser.add_argument(
        '--output', type=str, default='outputs/evaluation',
        help='Output directory')
    parser.add_argument(
        '--config', type=str, default='config.yaml',
        help='Path to config file')

    return parser.parse_args()


def _load_model(config, checkpoint_path, device):
    """Build the model from *config*, load weights, and return it in eval mode.

    Accepts both full training checkpoints (dict with a 'model_state_dict'
    key) and bare state dicts, matching the original loading logic.
    """
    model = get_model(config).to(device)
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    if 'model_state_dict' in checkpoint:
        checkpoint = checkpoint['model_state_dict']
    model.load_state_dict(checkpoint)
    model.eval()
    return model


def _per_sample_scores(probs, mask, eps=1e-8):
    """Return (iou, dice) for one sample, binarizing *probs* at 0.5.

    *eps* guards against division by zero when both prediction and mask
    are empty.
    """
    pred = (probs > 0.5).float()
    intersection = (pred * mask).sum().item()
    total = pred.sum().item() + mask.sum().item()
    iou = intersection / (total - intersection + eps)
    dice = (2 * intersection) / (total + eps)
    return iou, dice


def main():
    """Evaluate a checkpoint on a dataset split and save metrics as JSON.

    Reads CLI arguments (see parse_args), runs the model sample-by-sample
    under torch.no_grad(), accumulates SegmentationMetrics plus per-sample
    IoU/Dice statistics (when the dataset has pixel masks), and writes the
    results to <output>/<dataset>_<split>_results.json.
    """
    args = parse_args()

    # Load config
    config = get_config(args.config)

    # Device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print("\n" + "=" * 60)
    print("Model Evaluation")
    print("=" * 60)
    print(f"Model: {args.model}")
    print(f"Dataset: {args.dataset}")
    print(f"Split: {args.split}")
    print(f"Device: {device}")
    print("=" * 60)

    # Load model
    model = _load_model(config, args.model, device)
    print("Model loaded")

    # Load dataset
    dataset = get_dataset(config, args.dataset, split=args.split)
    print(f"Dataset loaded: {len(dataset)} samples")

    # Create output directory
    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Evaluate
    metrics = SegmentationMetrics()
    has_pixel_mask = config.has_pixel_mask(args.dataset)

    print("\nEvaluating...")

    all_ious = []
    all_dices = []

    with torch.no_grad():
        for i in tqdm(range(len(dataset)), desc="Evaluating"):
            try:
                image, mask, metadata = dataset[i]

                # Move to device, adding a batch dimension
                image = image.unsqueeze(0).to(device)
                mask = mask.unsqueeze(0).to(device)

                # Forward pass
                logits, _ = model(image)
                probs = torch.sigmoid(logits)

                # Update metrics (pixel-level only when masks exist)
                if has_pixel_mask:
                    metrics.update(probs, mask, has_pixel_mask=True)
                    iou, dice = _per_sample_scores(probs, mask)
                    all_ious.append(iou)
                    all_dices.append(dice)

            except Exception as e:
                # Best-effort evaluation: report and skip unreadable samples
                # rather than aborting the whole run.
                print(f"Error on sample {i}: {e}")
                continue

    # Compute final metrics
    results = metrics.compute()

    # Add per-sample statistics; cast numpy scalars to builtin float so
    # json.dump never fails (np.float32 is not JSON-serializable).
    if has_pixel_mask and all_ious:
        results['iou_mean'] = float(np.mean(all_ious))
        results['iou_std'] = float(np.std(all_ious))
        results['dice_mean'] = float(np.mean(all_dices))
        results['dice_std'] = float(np.std(all_dices))

    # Print results
    print("\n" + "=" * 60)
    print("Evaluation Results")
    print("=" * 60)

    for key, value in results.items():
        if isinstance(value, float):
            print(f"  {key}: {value:.4f}")

    # Save results
    results_path = output_dir / f'{args.dataset}_{args.split}_results.json'
    with open(results_path, 'w') as f:
        json.dump(results, f, indent=2)

    print(f"\nResults saved to: {results_path}")
    print("=" * 60)


# Run the evaluation only when executed as a script (not on import).
if __name__ == '__main__':
    main()