File size: 4,595 Bytes
84ab660 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
#!/usr/bin/env python3
"""
Example: Using SDXL Detector from HuggingFace
==============================================
Simple example showing how to use the SDXL detector
to classify images as real or SDXL-generated.
"""
import torch
from torchvision import transforms
from PIL import Image
from huggingface_hub import hf_hub_download
import torch.nn as nn
import torchvision.models as models
# ============================================================================
# MODEL DEFINITION
# ============================================================================
class SDXLDetector(nn.Module):
    """ResNet-50 based SDXL detector.

    A torchvision ResNet-50 backbone built without pretrained weights (the
    trained parameters are expected to come from a checkpoint via
    ``load_state_dict``). Its final fully connected layer is replaced by a
    small MLP head that outputs two logits.
    """

    def __init__(self):
        super().__init__()
        # Build the bare architecture. `weights=None` is the torchvision >= 0.13
        # spelling; the old `pretrained=False` argument is deprecated and emits
        # a warning there, so only fall back to it on older torchvision, whose
        # resnet50() rejects the `weights` keyword with a TypeError.
        try:
            self.backbone = models.resnet50(weights=None)
        except TypeError:
            self.backbone = models.resnet50(pretrained=False)
        num_features = self.backbone.fc.in_features
        # Classification head. The layer order must stay exactly as is:
        # checkpoint state_dict keys (backbone.fc.0, backbone.fc.1, ...) are
        # derived from these Sequential indices.
        self.backbone.fc = nn.Sequential(
            nn.Dropout(p=0.3),
            nn.Linear(num_features, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.15),
            nn.Linear(512, 2),
        )

    def forward(self, x):
        """Return the raw (un-normalized) 2-class logits for the batch ``x``."""
        return self.backbone(x)
# ============================================================================
# LOAD MODEL
# ============================================================================
def load_model(device='cpu', repo_id="ash12321/sdxl-detector-resnet50",
               filename="best.pth"):
    """Download the detector checkpoint from the HuggingFace Hub and build the model.

    Args:
        device: Device used as ``map_location`` for the checkpoint and that
            the model is moved onto.
        repo_id: HuggingFace Hub repository containing the checkpoint.
        filename: Name of the checkpoint file inside ``repo_id``.

    Returns:
        An ``SDXLDetector`` with the checkpoint weights, on ``device``, in
        eval mode.

    Raises:
        KeyError: if the checkpoint has no ``'model_state_dict'`` entry.
    """
    # Download checkpoint
    model_path = hf_hub_download(repo_id=repo_id, filename=filename)

    # NOTE(security): torch.load unpickles the file, and unpickling can execute
    # arbitrary code. Only point repo_id/filename at repositories you trust. If
    # the checkpoint holds only tensors and plain Python values, consider
    # passing weights_only=True (the default from PyTorch 2.6 onwards).
    checkpoint = torch.load(model_path, map_location=device)

    # Create model and load weights
    model = SDXLDetector()
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()  # inference mode: disables dropout, BatchNorm uses running stats

    print(f"✅ Model loaded from {model_path}")
    # Training metadata is informational only. Don't crash after a successful
    # load just because a checkpoint omits it.
    epoch = checkpoint.get('epoch')
    if epoch is not None:
        print(f" Trained for {epoch + 1} epochs")
    best_val_acc = checkpoint.get('best_val_acc')
    if best_val_acc is not None:
        print(f" Best validation accuracy: {best_val_acc:.2f}%")
    return model
# ============================================================================
# PREPROCESSING
# ============================================================================
def get_transform():
    """Build the preprocessing pipeline applied to every input image.

    The pipeline has four steps:
    1. Resize so the shorter side is 256 px.
    2. Take a 224x224 center crop.
    3. Convert to a float tensor.
    4. Normalize each channel with the standard ImageNet mean/std.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    pipeline = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(pipeline)
# ============================================================================
# PREDICTION
# ============================================================================
def predict_image(model, image_path, device='cpu'):
    """
    Predict if an image is real or SDXL-generated.

    Args:
        model: Loaded SDXLDetector model. It should already be in eval mode
            (``load_model`` calls ``model.eval()``); this function does not
            change the model's mode.
        image_path: Path to an image file, or an already-opened
            ``PIL.Image.Image``.
        device: Device to run inference on.

    Returns:
        dict with keys:
            'prediction':    'Real' or 'SDXL-generated' (the argmax class)
            'confidence':    softmax probability of the predicted class
            'probabilities': {'real': p_real, 'sdxl': p_sdxl}
    """
    # Load the image. Opening it in a `with` block closes the file handle
    # right away. convert() loads the pixels and returns a new image, so the
    # result stays usable after the file is closed.
    if isinstance(image_path, Image.Image):
        image = image_path.convert('RGB')
    else:
        with Image.open(image_path) as img:
            image = img.convert('RGB')

    # Preprocess and add a batch dimension of one.
    input_tensor = get_transform()(image).unsqueeze(0).to(device)

    # Predict (no autograd graph is needed for inference).
    with torch.no_grad():
        outputs = model(input_tensor)
        # Take the single row of the batch: one probability per class.
        probs = torch.softmax(outputs, dim=1)[0]
    prediction = int(torch.argmax(probs).item())
    class_probs = probs.tolist()

    # Format results. The label order matches the model's output indices.
    labels = ['Real', 'SDXL-generated']
    return {
        'prediction': labels[prediction],
        'confidence': class_probs[prediction],
        'probabilities': {
            'real': class_probs[0],
            'sdxl': class_probs[1],
        },
    }
# ============================================================================
# MAIN
# ============================================================================
def main():
    """Classify a single image and print the prediction.

    Usage: ``python <script> [IMAGE_PATH]``. The image path defaults to
    ``test_image.jpg`` when no argument is given.
    """
    import os
    import sys

    # Image to classify: the first CLI argument, or the placeholder default.
    image_path = sys.argv[1] if len(sys.argv) > 1 else "test_image.jpg"
    # Fail fast with a readable message, before downloading and loading
    # the model.
    if not os.path.isfile(image_path):
        raise SystemExit(f"❌ Image not found: {image_path}")

    # Setup
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")

    # Load model
    model = load_model(device)

    # Predict
    result = predict_image(model, image_path, device)

    print(f"\n📊 Results for {image_path}:")
    print(f" Prediction: {result['prediction']}")
    print(f" Confidence: {result['confidence']*100:.2f}%")
    print(" \nProbabilities:")
    print(f" Real: {result['probabilities']['real']*100:.2f}%")
    print(f" SDXL: {result['probabilities']['sdxl']*100:.2f}%")
if __name__ == "__main__":
main()
|