File size: 4,595 Bytes
84ab660
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
#!/usr/bin/env python3
"""
Example: Using SDXL Detector from HuggingFace
==============================================

Simple example showing how to use the SDXL detector
to classify images as real or SDXL-generated.
"""

import torch
from torchvision import transforms
from PIL import Image
from huggingface_hub import hf_hub_download
import torch.nn as nn
import torchvision.models as models

# ============================================================================
# MODEL DEFINITION
# ============================================================================

class SDXLDetector(nn.Module):
    """ResNet-50 based binary classifier: real image vs. SDXL-generated.

    The ResNet-50 backbone is created without pretrained weights (they come
    from the downloaded checkpoint instead), and its final fully-connected
    layer is replaced by a small MLP head that emits two logits. Index 0 is
    "real" and index 1 is "SDXL-generated", matching the label order used
    by ``predict_image``.
    """

    def __init__(self):
        super().__init__()
        self.backbone = self._build_backbone()
        num_features = self.backbone.fc.in_features

        # Classification head. Layer order and sizes must stay exactly like
        # this so the checkpoint's `backbone.fc.*` state-dict keys line up.
        self.backbone.fc = nn.Sequential(
            nn.Dropout(p=0.3),
            nn.Linear(num_features, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.15),
            nn.Linear(512, 2)
        )

    @staticmethod
    def _build_backbone():
        """Return a ResNet-50 with randomly initialised (non-pretrained) weights."""
        try:
            # torchvision >= 0.13: `weights=None` replaces the deprecated
            # `pretrained=False` and avoids a deprecation warning.
            return models.resnet50(weights=None)
        except TypeError:
            # Older torchvision releases only understand `pretrained`.
            return models.resnet50(pretrained=False)

    def forward(self, x):
        """Return raw (un-normalised) logits, one row of 2 values per input."""
        return self.backbone(x)

# ============================================================================
# LOAD MODEL
# ============================================================================

def load_model(device='cpu', repo_id="ash12321/sdxl-detector-resnet50",
               filename="best.pth"):
    """Download the SDXL detector checkpoint from the HuggingFace Hub and load it.

    Args:
        device: Device to map the checkpoint to and place the model on
            (e.g. 'cpu' or 'cuda').
        repo_id: HuggingFace Hub repository that holds the checkpoint.
        filename: Checkpoint file name inside ``repo_id``.

    Returns:
        An ``SDXLDetector`` with the checkpoint weights loaded, moved to
        ``device`` and switched to eval mode.
    """
    # Download checkpoint; the returned value is a local file path.
    model_path = hf_hub_download(repo_id=repo_id, filename=filename)

    # NOTE(review): depending on the installed torch version (and its
    # `weights_only` default), torch.load may unpickle arbitrary objects.
    # Only point repo_id/filename at repositories you trust.
    checkpoint = torch.load(model_path, map_location=device)

    # Training checkpoints wrap the weights under 'model_state_dict' next to
    # training metadata; otherwise treat the file as a bare state dict.
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
        metadata = checkpoint
    else:
        state_dict = checkpoint
        metadata = {}

    # Create model and load weights
    model = SDXLDetector()
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()  # inference mode for Dropout / BatchNorm layers

    print(f"✅ Model loaded from {model_path}")
    # Training metadata is informational only; skip whatever is missing
    # instead of failing after the weights already loaded successfully.
    if 'epoch' in metadata:
        print(f"   Trained for {metadata['epoch'] + 1} epochs")
    if 'best_val_acc' in metadata:
        print(f"   Best validation accuracy: {metadata['best_val_acc']:.2f}%")

    return model

# ============================================================================
# PREPROCESSING
# ============================================================================

def get_transform():
    """Build the inference-time preprocessing pipeline.

    Steps: resize (shorter side to 256 px), center-crop to 224x224, convert
    the PIL image to a tensor, then normalise each channel with the standard
    ImageNet mean/std values.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    pipeline = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(pipeline)

# ============================================================================
# PREDICTION
# ============================================================================

def predict_image(model, image_path, device='cpu', transform=None):
    """
    Predict if an image is real or SDXL-generated

    Args:
        model: Loaded SDXLDetector model; should already be in eval mode
            (as returned by ``load_model``).
        image_path: Path to image file (anything ``PIL.Image.open`` accepts)
        device: Device to run inference on
        transform: Optional preprocessing transform. Defaults to
            ``get_transform()``; pass one in to avoid rebuilding it on every
            call when classifying many images.

    Returns:
        dict with prediction, confidence, and probabilities
    """
    if transform is None:
        transform = get_transform()

    # Open inside a context manager so the underlying file handle is closed
    # promptly; convert() returns a new, fully loaded image.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    input_tensor = transform(image).unsqueeze(0).to(device)  # add batch dim

    # Predict
    with torch.no_grad():
        outputs = model(input_tensor)
        probs = torch.softmax(outputs, dim=1)[0]  # the single batch row
        prediction = int(torch.argmax(probs).item())

    # One device->host transfer instead of a separate .item() per value.
    prob_values = probs.tolist()
    confidence = prob_values[prediction]

    # Format results (index order matches the model's two output logits)
    labels = ['Real', 'SDXL-generated']

    return {
        'prediction': labels[prediction],
        'confidence': confidence,
        'probabilities': {
            'real': prob_values[0],
            'sdxl': prob_values[1]
        }
    }

# ============================================================================
# MAIN
# ============================================================================

def main(argv=None):
    """Example usage: classify one image and print the result.

    Args:
        argv: Optional argument list; when None, argparse reads
            ``sys.argv[1:]``. The only (optional) positional argument is the
            image path, defaulting to ``test_image.jpg``.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Classify an image as real or SDXL-generated."
    )
    parser.add_argument(
        "image_path",
        nargs="?",
        default="test_image.jpg",
        help="Path to the image to classify (default: test_image.jpg)",
    )
    args = parser.parse_args(argv)

    # Setup
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")

    # Load model
    model = load_model(device)

    # Example prediction
    image_path = args.image_path
    result = predict_image(model, image_path, device)

    print(f"\n📊 Results for {image_path}:")
    print(f"   Prediction: {result['prediction']}")
    print(f"   Confidence: {result['confidence']*100:.2f}%")
    # Blank line, then a sub-heading indented like the lines above it.
    print("\n   Probabilities:")
    print(f"      Real: {result['probabilities']['real']*100:.2f}%")
    print(f"      SDXL: {result['probabilities']['sdxl']*100:.2f}%")


if __name__ == "__main__":
    main()