# sdxl-detector-resnet50 / example_usage.py
# Author: ash12321 · Commit 84ab660 (verified): "Add Example code"
#!/usr/bin/env python3
"""
Example: Using SDXL Detector from HuggingFace
==============================================
Simple example showing how to use the SDXL detector
to classify images as real or SDXL-generated.
"""
import torch
from torchvision import transforms
from PIL import Image
from huggingface_hub import hf_hub_download
import torch.nn as nn
import torchvision.models as models
# ============================================================================
# MODEL DEFINITION
# ============================================================================
class SDXLDetector(nn.Module):
    """ResNet-50 based SDXL detector.

    A torchvision ResNet-50 backbone whose final fully connected layer is
    replaced by a small MLP head that emits two logits. ``predict_image``
    interprets index 0 as "Real" and index 1 as "SDXL-generated".
    """

    def __init__(self):
        super().__init__()
        # ``weights=None`` builds a randomly initialised ResNet-50 without
        # downloading anything. It is the torchvision >= 0.13 replacement for
        # the deprecated ``pretrained=False``. The real parameters come from
        # the checkpoint loaded via ``load_state_dict`` afterwards.
        self.backbone = models.resnet50(weights=None)
        num_features = self.backbone.fc.in_features
        # Classification head: dropout -> 512-d hidden layer with BatchNorm ->
        # ReLU -> dropout -> 2 logits. The layer layout must match the
        # checkpoint exactly for load_state_dict to succeed.
        self.backbone.fc = nn.Sequential(
            nn.Dropout(p=0.3),
            nn.Linear(num_features, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.15),
            nn.Linear(512, 2),
        )

    def forward(self, x):
        """Return raw (un-normalised) logits for a batch of images.

        Args:
            x: Image batch, expected to be (N, 3, H, W) as produced by
                ``get_transform`` plus ``unsqueeze(0)``. TODO confirm for
                other callers.

        Returns:
            Tensor of shape (N, 2) with class logits.
        """
        return self.backbone(x)
# ============================================================================
# LOAD MODEL
# ============================================================================
def load_model(device='cpu', repo_id="ash12321/sdxl-detector-resnet50", filename="best.pth"):
    """Download the detector checkpoint from the HuggingFace Hub and build the model.

    Args:
        device: Device the checkpoint tensors are mapped to and the model is
            moved to (e.g. 'cpu' or 'cuda').
        repo_id: HuggingFace Hub repository that holds the checkpoint.
        filename: Checkpoint file name inside ``repo_id``.

    Returns:
        An ``SDXLDetector`` in eval mode on ``device``.

    Raises:
        RuntimeError: from ``load_state_dict`` if the checkpoint weights do not
            match the ``SDXLDetector`` architecture.
    """
    # Download the checkpoint, or reuse the locally cached copy.
    model_path = hf_hub_download(repo_id=repo_id, filename=filename)

    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from sources you trust. If the checkpoint holds only tensors and plain
    # Python values, consider passing weights_only=True (torch >= 1.13).
    checkpoint = torch.load(model_path, map_location=device)

    # Training checkpoints nest the weights under 'model_state_dict' next to
    # metadata. A bare state_dict (e.g. re-exported weights) is accepted too.
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
        metadata = checkpoint
    else:
        state_dict = checkpoint
        metadata = {}

    # Create the model and load the weights.
    model = SDXLDetector()
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()  # disable dropout and use BatchNorm running statistics

    print(f"✅ Model loaded from {model_path}")
    # Training metadata is informational only, so skip it if it is absent
    # instead of failing with a KeyError.
    epoch = metadata.get('epoch')
    if epoch is not None:
        print(f" Trained for {epoch + 1} epochs")
    best_val_acc = metadata.get('best_val_acc')
    if best_val_acc is not None:
        print(f" Best validation accuracy: {best_val_acc:.2f}%")
    return model
# ============================================================================
# PREPROCESSING
# ============================================================================
def get_transform():
    """Build the preprocessing pipeline applied to every input image.

    Steps: resize the shorter side to 256 px, center-crop to 224x224,
    convert to a float tensor, then normalise each channel with the fixed
    mean/std values below (the standard ImageNet statistics).

    Returns:
        A ``transforms.Compose`` that maps a PIL image to a normalised tensor.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    pipeline_steps = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(pipeline_steps)
# ============================================================================
# PREDICTION
# ============================================================================
def predict_image(model, image_path, device='cpu'):
    """
    Predict if an image is real or SDXL-generated.

    Args:
        model: Loaded SDXLDetector model (expected to be in eval mode).
        image_path: Path to the image file (anything accepted by
            ``PIL.Image.open``).
        device: Device to run inference on.

    Returns:
        dict with keys:
            'prediction': 'Real' or 'SDXL-generated'
            'confidence': softmax probability of the predicted class
            'probabilities': {'real': float, 'sdxl': float}
    """
    # Open the file in a context manager so its handle is released promptly.
    # convert() returns a new, fully loaded image that outlives the block.
    with Image.open(image_path) as img:
        image = img.convert('RGB')

    transform = get_transform()
    # Add a batch dimension of 1 before moving to the target device.
    input_tensor = transform(image).unsqueeze(0).to(device)

    # Predict without building an autograd graph.
    with torch.no_grad():
        outputs = model(input_tensor)
        probs = torch.softmax(outputs, dim=1)
    prediction = torch.argmax(probs, dim=1).item()
    confidence = probs[0][prediction].item()

    # Label order matches the model's two output logits.
    labels = ['Real', 'SDXL-generated']
    return {
        'prediction': labels[prediction],
        'confidence': confidence,
        'probabilities': {
            'real': probs[0][0].item(),
            'sdxl': probs[0][1].item(),
        },
    }
# ============================================================================
# MAIN
# ============================================================================
def main(argv=None):
    """Example usage: classify one image and print the results.

    Usage: ``example_usage.py [IMAGE_PATH]`` (IMAGE_PATH defaults to
    ``test_image.jpg``).

    Args:
        argv: Optional argument list, ``sys.argv[1:]`` when None.

    Raises:
        SystemExit: if the image file does not exist. This is checked before
            the checkpoint is downloaded so a bad path fails fast.
    """
    import argparse
    from pathlib import Path

    parser = argparse.ArgumentParser(
        description="Classify an image as real or SDXL-generated."
    )
    parser.add_argument(
        "image_path",
        nargs="?",
        default="test_image.jpg",
        help="Path to the image to classify (default: test_image.jpg)",
    )
    image_path = parser.parse_args(argv).image_path

    # Check the input before paying for the model download / load.
    if not Path(image_path).is_file():
        raise SystemExit(f"Image not found: {image_path}")

    # Setup
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")

    # Load model
    model = load_model(device)

    # Prediction
    result = predict_image(model, image_path, device)
    print(f"\n📊 Results for {image_path}:")
    print(f" Prediction: {result['prediction']}")
    print(f" Confidence: {result['confidence']*100:.2f}%")
    print("\nProbabilities:")
    print(f" Real: {result['probabilities']['real']*100:.2f}%")
    print(f" SDXL: {result['probabilities']['sdxl']*100:.2f}%")
if __name__ == "__main__":  # run the example only when executed as a script
    main()