# Single-image inference script for the R50_TF detector.
# ----------------------------------------------------------------------------
# IMPORTS
# ----------------------------------------------------------------------------
# Standard-library imports.
import os
import argparse
import json
import time
# Third-party imports.
import yaml
import torch
from PIL import Image
import torchvision.transforms.v2 as Tv2
# Detector-local network definition.
from networks import ImageClassifier
import sys
# Make the project root (three levels up from this file) importable so the
# shared `support` package can be found regardless of the working directory.
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(project_root)
# NOTE(review): `json` and `get_device` appear unused in this file — confirm
# before removing.
from support.detect_utils import format_result, save_result, get_device
# ----------------------------------------------------------------------------
# IMAGE PREPROCESSING
# ----------------------------------------------------------------------------
def preprocess_image(image_path):
    """
    Load one image from disk and turn it into a normalized model-input tensor.

    Applies the same preprocessing as the test split in test.py (ImageNet
    mean/std normalization, no augmentation) and returns a float32 tensor
    with a leading batch dimension of 1: shape [1, C, H, W].
    """
    rgb_image = Image.open(image_path).convert('RGB')
    # Inference-time pipeline: PIL -> tensor, scale to [0, 1], normalize.
    pipeline = Tv2.Compose([
        Tv2.ToImage(),
        Tv2.ToDtype(torch.float32, scale=True),
        Tv2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # unsqueeze(0): [C, H, W] -> [1, C, H, W] so the model sees a batch of one.
    return pipeline(rgb_image).unsqueeze(0)
# ----------------------------------------------------------------------------
# CONFIG LOADING AND PARSING
# ----------------------------------------------------------------------------
def load_config(config_path):
    """Read a YAML configuration file and return its parsed contents."""
    with open(config_path, 'r') as config_file:
        return yaml.safe_load(config_file)
def parse_detector_args(detector_args, default_num_centers=1):
    """
    Translate a detector_args list (e.g. ["--arch", "nodown", "--prototype",
    "--freeze"]) into a simple settings object with attribute access.

    Unknown tokens are skipped, and a value-taking flag that appears at the
    very end with no value is ignored.
    """
    class Settings:
        def __init__(self):
            self.arch = "nodown"
            self.freeze = False
            self.prototype = False
            self.num_centers = default_num_centers

    settings = Settings()
    total = len(detector_args)
    pos = 0
    while pos < total:
        token = detector_args[pos]
        # True when a value slot exists after the current token.
        has_value = pos + 1 < total
        if token == "--freeze":
            settings.freeze = True
        elif token == "--prototype":
            settings.prototype = True
        elif token == "--arch" and has_value:
            settings.arch = detector_args[pos + 1]
            pos += 1  # consume the value token as well
        elif token == "--num_centers" and has_value:
            settings.num_centers = int(detector_args[pos + 1])
            pos += 1  # consume the value token as well
        pos += 1
    return settings
def resolve_config_path(config_path):
    """
    Return a usable path for *config_path*.

    Absolute paths pass through untouched. Relative paths are anchored at the
    project root, which sits two directory levels above this script
    (detectors/R50_TF/ -> project root).
    """
    if os.path.isabs(config_path):
        return config_path
    script_dir = os.path.dirname(os.path.abspath(__file__))
    root_dir = os.path.dirname(os.path.dirname(script_dir))
    return os.path.join(root_dir, config_path)
# ----------------------------------------------------------------------------
# INFERENCE
# ----------------------------------------------------------------------------
def run_inference(model, image_path, device):
    """
    Classify a single image and time the whole operation.

    Returns a (probability, label, runtime_ms) tuple: probability is the
    sigmoid of the raw model score, label is "fake" when probability > 0.5
    and "real" otherwise, and runtime_ms covers preprocessing + inference.
    """
    started = time.time()
    batch = preprocess_image(image_path).to(device)
    model.eval()
    # Pure inference — no gradients required.
    with torch.no_grad():
        raw_score = model(batch).squeeze(1)  # shape [1]
    probability = torch.sigmoid(raw_score).item()
    label = "fake" if probability > 0.5 else "real"
    elapsed_ms = int((time.time() - started) * 1000)
    return probability, label, elapsed_ms
# ----------------------------------------------------------------------------
# MAIN
# ----------------------------------------------------------------------------
def _resolve_checkpoint(args):
    """
    Resolve the checkpoint file path from --checkpoint or --model.

    --checkpoint wins when given. Otherwise --model is tried, in order, as:
    a run name under this detector's checkpoint/ directory, an absolute file
    path, and a path relative to the project root. Returns None when nothing
    resolves.
    """
    if args.checkpoint:
        return args.checkpoint
    model_arg = getattr(args, 'model', None)
    if not model_arg:
        return None
    detect_dir = os.path.dirname(os.path.abspath(__file__))
    candidate = os.path.join(detect_dir, 'checkpoint', model_arg, 'weights', 'best.pt')
    if os.path.exists(candidate):
        return candidate
    if os.path.isabs(model_arg) and os.path.exists(model_arg):
        return model_arg
    project_root = os.path.dirname(os.path.dirname(detect_dir))
    candidate2 = os.path.join(project_root, model_arg)
    if os.path.exists(candidate2):
        return candidate2
    return None


def _resolve_device(args, config):
    """
    Pick the torch device: an explicit --device always wins, then the config's
    global.device_override, then the default 'cuda:0'. Falls back to CPU with
    a warning when CUDA is requested but unavailable.
    """
    device_str = args.device
    if device_str is None:
        override = config.get('global', {}).get('device_override')
        # Treat missing, empty, and the literal string "null" as "no override".
        if override and override != "null":
            device_str = override
        else:
            device_str = 'cuda:0'
    if device_str.startswith('cuda') and not torch.cuda.is_available():
        print(f"Warning: CUDA requested but not available. Using CPU.")
        return torch.device('cpu')
    # Non-CUDA device strings (e.g. 'cpu') are honored as-is. Previously an
    # explicit non-CUDA device was silently forced to CPU whenever CUDA was
    # unavailable, which broke device strings like 'mps'.
    return torch.device(device_str)


def main():
    """CLI entry point: parse arguments, load config and model, classify one image."""
    parser = argparse.ArgumentParser(description='Single image inference for R50_TF detector')
    parser.add_argument('--input', type=str, required=False, help='Path to input image (alias: --image)')
    parser.add_argument('--image', type=str, required=False, help='Path to input image (alias for --input)')
    parser.add_argument('--output', type=str, default='/tmp/result.json', help='Path to output JSON file')
    parser.add_argument('--checkpoint', type=str, required=False, help='Path to model checkpoint file')
    parser.add_argument('--model', type=str, required=False, help='Model name or checkpoint directory (alias for --checkpoint)')
    parser.add_argument('--config', type=str, default='configs/R50_TF.yaml', help='Path to YAML config file')
    parser.add_argument('--device', type=str, default=None, help='Device to use (cuda:0, cpu, etc.)')
    args = parser.parse_args()

    # Normalize image argument: prefer --image over --input if both are given.
    if args.image:
        args.input = args.image
    # Fail fast with a clear message instead of crashing later in Image.open(None).
    if not args.input:
        parser.error('an input image is required (use --input or --image)')

    checkpoint_path = _resolve_checkpoint(args)
    if checkpoint_path:
        args.checkpoint = checkpoint_path
    # Fail fast instead of crashing later in torch.load(None).
    if not args.checkpoint:
        parser.error('a model checkpoint is required (use --checkpoint or --model)')

    # Resolve and load the YAML config.
    config_path = resolve_config_path(args.config)
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"Configuration file not found: {config_path}")
    config = load_config(config_path)

    # Build model settings from the config's detector_args list.
    detector_args = config.get('detector_args', [])
    settings = parse_detector_args(detector_args)

    device = _resolve_device(args, config)

    # Load model weights onto the chosen device.
    print(f"Loading model from {args.checkpoint}")
    model = ImageClassifier(settings)
    model.load_state_dict(torch.load(args.checkpoint, map_location=device))
    model.to(device)
    model.eval()

    # Run inference on the single input image.
    print(f"Running inference on {args.input}")
    probability, label, runtime_ms = run_inference(model, args.input, device)

    # Format result to match other detectors (prediction/confidence/elapsed_time).
    elapsed_time = runtime_ms / 1000.0
    formatted = format_result(label, float(round(probability, 4)), elapsed_time)

    # Save using the shared utility (if an output path is provided).
    if args.output:
        save_result(formatted, args.output)
        print(f"Results saved to {args.output}")

    # Print concise output for the user.
    print(f"Prediction: {formatted['prediction']}")
    print(f"Confidence: {formatted['confidence']:.4f}")
    print(f"Time: {formatted['elapsed_time']:.3f}s")


if __name__ == '__main__':
    main()