File size: 3,864 Bytes
9c4b1c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import os
import sys
import time
import yaml
import torch
from PIL import Image
import torchvision.transforms as transforms
import argparse

# Add project root to path for imports
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(project_root)
from support.detect_utils import format_result, save_result, get_device
from networks import create_architecture

def parse_args():
    parser = argparse.ArgumentParser(description='R50_nodown single image detector')
    parser.add_argument('--image', type=str, required=True, help='Path to input image')
    parser.add_argument('--model', type=str, default='checkpoint/best.pt', help='Path to model checkpoint')
    parser.add_argument('--output', type=str, help='Path to save detection result JSON')
    parser.add_argument('--device', type=str, help='Device to run on (e.g., cuda:0, cuda:1, cpu)')
    parser.add_argument('--config', type=str, default='configs/r50_nodown.yaml', help='Path to config file')
    return parser.parse_args()

def load_config(config_path):
    """Load configuration from YAML file."""
    with open(config_path, 'r') as f:
        return yaml.safe_load(f)

def load_image(image_path, size=224):
    """Load and preprocess image."""
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image not found: {image_path}")
        
    image = Image.open(image_path).convert('RGB')
    transform = transforms.Compose([
        transforms.Resize(size),
        transforms.CenterCrop(size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    return transform(image).unsqueeze(0)

def main():
    args = parse_args()
    
    # Setup device
    if args.device:
        if args.device.startswith('cuda') and not torch.cuda.is_available():
            print("CUDA is not available, falling back to CPU")
            device = torch.device('cpu')
        else:
            device = torch.device(args.device)
    else:
        device = get_device()
    
    # Load model
    try:
        # Initialize model
        model = create_architecture("res50nodown", pretrained=True, num_classes=1).to(device)
        load_path = f'./detectors/R50_nodown/checkpoint/{args.model}/weights/best.pt'
        
        if not os.path.exists(load_path):
            raise FileNotFoundError(f"Model weights not found at: {load_path}")
        checkpoint = torch.load(load_path, map_location=device)
        if 'model' in checkpoint:
            model.load_state_dict(checkpoint['model'])
        else:
            model.load_state_dict(checkpoint)
        model.eval()
    except Exception as e:
        print(f"Error loading model: {e}")
        return

    # Load and preprocess image
    try:
        image_tensor = load_image(args.image).to(device)
    except Exception as e:
        print(f"Error loading image: {e}")
        return

    # Run detection
    start_time = time.time()
    with torch.no_grad():
        try:
            score = model(image_tensor)
            prediction = torch.sigmoid(score)
            
            confidence = prediction.item()
            
            result = format_result(
                'fake' if confidence>0.5 else 'real',
                confidence,
                time.time() - start_time
            )
            
            # Print result
            print(f"Prediction: {result['prediction']}")
            print(f"Confidence: {result['confidence']:.4f}")
            print(f"Time: {result['elapsed_time']:.3f}s")
            
            # Save result if output path provided
            if args.output:
                save_result(result, args.output)
                
        except Exception as e:
            print(f"Error during detection: {e}")
            return

if __name__ == '__main__':
    main()