# ----------------------------------------------------------------------------
# IMPORTS
# ----------------------------------------------------------------------------
import os
import sys
import time
import torch
import argparse
from PIL import Image
import torchvision.transforms as transforms
# Add project root to path for imports
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(project_root)
from support.detect_utils import format_result, save_result, get_device
from networks import create_architecture
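# NOTE: torch, torchvision, and Pillow are the only third-party dependencies
# imported here; `support.detect_utils` and `networks` are project-local
# modules made importable by the sys.path insertion above.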
# ----------------------------------------------------------------------------
# IMAGE PREPROCESSING
# ----------------------------------------------------------------------------
def preprocess_image(image_path, size=224):
    """Load and preprocess a single image for model input."""
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image not found: {image_path}")
    image = Image.open(image_path).convert('RGB')
    transform = transforms.Compose([
        transforms.Resize(size),
        transforms.CenterCrop(size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    return transform(image).unsqueeze(0)  # Add batch dimension
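# Note: with the default size=224, preprocess_image returns a float tensor of
# shape [1, 3, 224, 224], normalized with the ImageNet mean/std used above.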
# ----------------------------------------------------------------------------
# ARGUMENT PARSING
# ----------------------------------------------------------------------------
def parse_args():
    parser = argparse.ArgumentParser(description='CLIP-D single image detector')
    parser.add_argument('--image', type=str, required=True, help='Path to input image')
    parser.add_argument('--model', type=str, default='pretrained', help='Name of the model checkpoint directory')
    parser.add_argument('--output', type=str, help='Path to save detection result JSON')
    parser.add_argument('--device', type=str, help='Device to run on (e.g., cuda:0, cuda:1, cpu)')
    return parser.parse_args()
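# Example invocation (script name and paths are illustrative):
#   python detect_single_image.py --image samples/test.jpg --model pretrained \
#       --output results/test.json --device cuda:0
#
# Expected console output (values illustrative):
#   Prediction: fake
#   Confidence: 0.9123
#   Time: 0.045s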
# ----------------------------------------------------------------------------
# MAIN
# ----------------------------------------------------------------------------
def main():
    args = parse_args()

    # Setup device
    device = torch.device(args.device) if args.device else get_device()

    # Load model
    try:
        # Checkpoint path is resolved relative to the current working directory
        load_path = f'./detectors/CLIP-D/checkpoint/{args.model}/weights/best.pt'
        if not os.path.exists(load_path):
            raise FileNotFoundError(f"Model weights not found at: {load_path}")
        checkpoint = torch.load(load_path, map_location=device)

        # Initialize model and load state (the checkpoint may store the state
        # dict directly or nested under a 'model' key)
        model = create_architecture("opencliplinearnext_clipL14commonpool", pretrained=False, num_classes=1).to(device)
        if 'model' in checkpoint:
            model.load_state_dict(checkpoint['model'])
        else:
            model.load_state_dict(checkpoint)
        model.eval()
    except Exception as e:
        print(f"Error loading model: {e}")
        sys.exit(1)

    # Load and preprocess image
    try:
        image_tensor = preprocess_image(args.image).to(device)
    except Exception as e:
        print(f"Error loading image: {e}")
        sys.exit(1)

    # Run detection
    start_time = time.time()
    with torch.no_grad():
        try:
            score = model(image_tensor)
            confidence = torch.sigmoid(score).item()
            result = format_result(
                'fake' if confidence > 0.5 else 'real',
                confidence,
                time.time() - start_time
            )

            # Print result
            print(f"Prediction: {result['prediction']}")
            print(f"Confidence: {result['confidence']:.4f}")
            print(f"Time: {result['elapsed_time']:.3f}s")

            # Save result if output path provided
            if args.output:
                save_result(result, args.output)
        except Exception as e:
            print(f"Error during detection: {e}")
            sys.exit(1)
if __name__ == '__main__':
    main()