# ----------------------------------------------------------------------------
# IMPORTS
# ----------------------------------------------------------------------------
import os
import argparse
import json
import time
import yaml
import torch
from PIL import Image
import torchvision.transforms.v2 as Tv2

import sys

# Make the project root importable before pulling in the shared support utilities
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(project_root)

from networks import ImageClassifier
from support.detect_utils import format_result, save_result, get_device
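
# Example invocations (paths and run name are illustrative, not part of the repo):
#   python detect.py --input sample.jpg --model my_run --output /tmp/result.json
#   python detect.py --image sample.jpg --checkpoint /path/to/best.pt --device cpu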


# ----------------------------------------------------------------------------
# IMAGE PREPROCESSING
# ----------------------------------------------------------------------------
def preprocess_image(image_path):
    """
    Load and preprocess a single image for model input.
    Uses the same normalization as test.py (ImageNet stats).
    """
    # Load image
    image = Image.open(image_path).convert('RGB')
    
    # Apply transforms (same as test split without augmentation)
    transform = Tv2.Compose([
        Tv2.ToImage(),
        Tv2.ToDtype(torch.float32, scale=True),
        Tv2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    
    # Apply transforms and add batch dimension
    tensor = transform(image)
    tensor = tensor.unsqueeze(0)  # Add batch dimension
    
    return tensor
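
# Illustrative call: preprocess_image("sample.jpg") returns a float32 tensor of
# shape [1, 3, H, W] (no resizing is applied), ImageNet-normalized and ready
# to pass to ImageClassifier.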


# ----------------------------------------------------------------------------
# CONFIG LOADING AND PARSING
# ----------------------------------------------------------------------------
def load_config(config_path):
    """Load configuration from YAML file."""
    with open(config_path, 'r') as f:
        config = yaml.safe_load(f)
    return config
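
# Expected config shape, inferred from how it is read below (values illustrative):
#   detector_args: ["--arch", "nodown", "--prototype"]
#   global:
#     device_override: "cuda:0"   # optional; empty or "null" means unset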


def parse_detector_args(detector_args, default_num_centers=1):
    """
    Parse detector_args list (e.g., ["--arch", "nodown", "--prototype", "--freeze"])
    into a settings object.
    """
    class Settings:
        def __init__(self):
            self.arch = "nodown"
            self.freeze = False
            self.prototype = False
            self.num_centers = default_num_centers
    
    settings = Settings()
    
    i = 0
    while i < len(detector_args):
        arg = detector_args[i]
        
        if arg == "--arch":
            if i + 1 < len(detector_args):
                settings.arch = detector_args[i + 1]
                i += 2
            else:
                i += 1
        elif arg == "--freeze":
            settings.freeze = True
            i += 1
        elif arg == "--prototype":
            settings.prototype = True
            i += 1
        elif arg == "--num_centers":
            if i + 1 < len(detector_args):
                settings.num_centers = int(detector_args[i + 1])
                i += 2
            else:
                i += 1
        else:
            i += 1
    
    return settings
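
# Illustrative call: parse_detector_args(["--arch", "nodown", "--freeze"]) yields
# settings.arch == "nodown", settings.freeze is True, settings.prototype is False,
# and settings.num_centers left at its default of 1.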


def resolve_config_path(config_path):
    """
    Resolve config path. If relative, resolve relative to project root
    (two levels up from detect.py location).
    """
    if os.path.isabs(config_path):
        return config_path
    
    # Get directory of detect.py (detectors/R50_TF/)
    detect_dir = os.path.dirname(os.path.abspath(__file__))
    # Go two levels up to project root
    project_root = os.path.dirname(os.path.dirname(detect_dir))
    # Join with config path
    return os.path.join(project_root, config_path)
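
# Illustrative resolution, assuming detect.py lives in detectors/R50_TF/:
#   resolve_config_path("configs/R50_TF.yaml") -> <project_root>/configs/R50_TF.yaml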


# ----------------------------------------------------------------------------
# INFERENCE
# ----------------------------------------------------------------------------
def run_inference(model, image_path, device):
    """
    Run inference on a single image.
    Returns: (probability, label, runtime_ms)
    """
    start_time = time.time()
    
    # Preprocess image
    image_tensor = preprocess_image(image_path)
    image_tensor = image_tensor.to(device)
    
    # Run inference
    model.eval()
    with torch.no_grad():
        raw_score_tensor = model(image_tensor).squeeze(1)  # shape [1]
    
    # Convert to probability using sigmoid
    probability = torch.sigmoid(raw_score_tensor).item()
    
    # Determine label (fake if probability > 0.5, else real)
    label = "fake" if probability > 0.5 else "real"
    
    # Calculate runtime in milliseconds
    runtime_ms = int((time.time() - start_time) * 1000)
    
    return probability, label, runtime_ms
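
# Illustrative call: run_inference(model, "sample.jpg", device) might return
# (0.9731, "fake", 42): the sigmoid probability, the thresholded label, and the
# wall-clock runtime in milliseconds (including preprocessing).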


# ----------------------------------------------------------------------------
# MAIN
# ----------------------------------------------------------------------------
def main():
    parser = argparse.ArgumentParser(description='Single image inference for R50_TF detector')
    parser.add_argument('--input', type=str, required=False, help='Path to input image (alias: --image)')
    parser.add_argument('--image', type=str, required=False, help='Path to input image (alias for --input)')
    parser.add_argument('--output', type=str, default='/tmp/result.json', help='Path to output JSON file')
    parser.add_argument('--checkpoint', type=str, required=False, help='Path to model checkpoint file')
    parser.add_argument('--model', type=str, required=False, help='Model name or checkpoint directory (alias for --checkpoint)')
    parser.add_argument('--config', type=str, default='configs/R50_TF.yaml', help='Path to YAML config file')
    parser.add_argument('--device', type=str, default=None, help='Device to use (cuda:0, cpu, etc.)')
    
    args = parser.parse_args()

    # Normalize the image argument: prefer --image over --input if both are given
    if args.image:
        args.input = args.image
    if not args.input:
        parser.error("An input image is required: pass --input or --image")

    # Resolve the checkpoint path: --checkpoint wins, otherwise derive it from --model
    checkpoint_path = None
    if args.checkpoint:
        checkpoint_path = args.checkpoint
    elif args.model:
        detect_dir = os.path.dirname(os.path.abspath(__file__))
        candidate = os.path.join(detect_dir, 'checkpoint', args.model, 'weights', 'best.pt')
        if os.path.exists(candidate):
            checkpoint_path = candidate
        elif os.path.isabs(args.model) and os.path.exists(args.model):
            # --model points directly at a checkpoint file
            checkpoint_path = args.model
        else:
            # Try resolving relative to the project root
            project_root = os.path.dirname(os.path.dirname(detect_dir))
            candidate2 = os.path.join(project_root, args.model)
            if os.path.exists(candidate2):
                checkpoint_path = candidate2

    # Fail fast with a clear message rather than letting torch.load choke on None
    if checkpoint_path is None:
        parser.error("No checkpoint found: pass --checkpoint or a resolvable --model")
    args.checkpoint = checkpoint_path
    
    # Resolve config path
    config_path = resolve_config_path(args.config)
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"Configuration file not found: {config_path}")
    
    # Load config
    config = load_config(config_path)
    
    # Parse detector_args from config
    detector_args = config.get('detector_args', [])
    settings = parse_detector_args(detector_args)
    
    # Device precedence: an explicit --device flag wins, then the config's
    # global.device_override, then the 'cuda:0' default. args.device defaults
    # to None so an explicit flag can be distinguished from the fallback.
    if args.device is not None:
        device_str = args.device
    else:
        device_override = config.get('global', {}).get('device_override')
        if device_override and device_override != "null":
            device_str = device_override
        else:
            device_str = 'cuda:0'
    
    # Fall back to CPU when CUDA is requested but unavailable
    if device_str.startswith('cuda') and not torch.cuda.is_available():
        print("Warning: CUDA requested but not available. Using CPU.")
        device_str = 'cpu'
    device = torch.device(device_str)
    
    # Load model
    print(f"Loading model from {args.checkpoint}")
    model = ImageClassifier(settings)
    model.load_state_dict(torch.load(args.checkpoint, map_location=device))
    model.to(device)
    model.eval()
    
    # Run inference
    print(f"Running inference on {args.input}")
    probability, label, runtime_ms = run_inference(model, args.input, device)
    
    # Format result to match other detectors (prediction/confidence/elapsed_time)
    elapsed_time = runtime_ms / 1000.0
    formatted = format_result(label, float(round(probability, 4)), elapsed_time)
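    # Example payload (keys per format_result; values illustrative):
    #   {"prediction": "fake", "confidence": 0.9731, "elapsed_time": 0.042}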

    # Save using shared utility (if output path is provided)
    if args.output:
        save_result(formatted, args.output)
        print(f"Results saved to {args.output}")

    # Print concise output for user
    print(f"Prediction: {formatted['prediction']}")
    print(f"Confidence: {formatted['confidence']:.4f}")
    print(f"Time: {formatted['elapsed_time']:.3f}s")


if __name__ == '__main__':
    main()