|
|
|
|
|
""" |
|
|
Example: Using SDXL Detector from HuggingFace
==============================================

Simple example showing how to use the SDXL detector
to classify images as real or SDXL-generated.
|
|
""" |
|
|
|
|
|
import torch |
|
|
from torchvision import transforms |
|
|
from PIL import Image |
|
|
from huggingface_hub import hf_hub_download |
|
|
import torch.nn as nn |
|
|
import torchvision.models as models |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SDXLDetector(nn.Module):
    """ResNet-50 based SDXL detector.

    Wraps a torchvision ResNet-50 whose final fully connected layer is
    replaced by a small classification head ending in two output units.
    """

    def __init__(self):
        """Build the ResNet-50 backbone and attach the classification head."""
        super().__init__()
        # Rely on the constructor default (no pretrained weights) rather than
        # the deprecated ``pretrained=False`` keyword: the result is the same
        # on old and new torchvision releases, without the deprecation
        # warning newer releases emit for ``pretrained``.
        self.backbone = models.resnet50()
        num_features = self.backbone.fc.in_features

        # Swap the stock classifier for the custom head. The layer order must
        # stay fixed because checkpoint keys are indexed by position inside
        # the Sequential container.
        self.backbone.fc = nn.Sequential(
            nn.Dropout(p=0.3),
            nn.Linear(num_features, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.15),
            nn.Linear(512, 2),
        )

    def forward(self, x):
        """Run ``x`` through the backbone and return its output."""
        return self.backbone(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_model(device='cpu'):
    """Load model from HuggingFace Hub.

    Downloads ``best.pth`` from the ``ash12321/sdxl-detector-resnet50``
    repository, restores its ``'model_state_dict'`` into a fresh
    ``SDXLDetector``, moves the model to ``device`` and puts it in
    evaluation mode.

    Args:
        device: Used as ``map_location`` for ``torch.load`` and as the
            target of ``model.to``.

    Returns:
        The loaded ``SDXLDetector`` in eval mode.

    Raises:
        KeyError: If the checkpoint has no ``'model_state_dict'`` entry.
    """
    model_path = hf_hub_download(
        repo_id="ash12321/sdxl-detector-resnet50",
        filename="best.pth",
    )

    # NOTE(security): torch.load unpickles the downloaded file, which can
    # execute arbitrary code. Only load checkpoints from repositories you
    # trust; consider ``weights_only=True`` if the installed torch version
    # supports it and the checkpoint holds only tensors and plain values.
    checkpoint = torch.load(model_path, map_location=device)

    model = SDXLDetector()
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()

    print(f"✅ Model loaded from {model_path}")

    # Training metadata is purely informational: a checkpoint that lacks it
    # must not make an otherwise successful load fail with KeyError.
    epoch = checkpoint.get('epoch')
    if epoch is not None:
        print(f" Trained for {epoch + 1} epochs")

    best_val_acc = checkpoint.get('best_val_acc')
    if best_val_acc is not None:
        print(f" Best validation accuracy: {best_val_acc:.2f}%")

    return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_transform():
    """Build the preprocessing pipeline applied to every input image.

    The steps, in order: resize to 256, center-crop to 224, convert to a
    tensor, then normalize each channel with a fixed mean and std.

    Returns:
        A ``transforms.Compose`` chaining the four steps.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    steps = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(steps)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def predict_image(model, image_path, device='cpu'):
    """
    Predict if an image is real or SDXL-generated

    Args:
        model: Loaded SDXLDetector model
        image_path: Path to image file
        device: Device to run inference on

    Returns:
        dict with prediction, confidence, and probabilities:
        ``'prediction'`` is ``'Real'`` or ``'SDXL-generated'``,
        ``'confidence'`` is the softmax probability of the predicted class,
        and ``'probabilities'`` maps ``'real'`` / ``'sdxl'`` to the softmax
        probabilities at class index 0 / 1.
    """
    # Image.open is lazy and keeps the file open; the context manager closes
    # it deterministically once the RGB copy has been made, instead of
    # leaving the handle to the garbage collector.
    with Image.open(image_path) as source_image:
        image = source_image.convert('RGB')

    transform = get_transform()
    # unsqueeze(0) adds the leading batch dimension before the forward pass.
    input_tensor = transform(image).unsqueeze(0).to(device)

    # Inference only: skip gradient tracking.
    with torch.no_grad():
        outputs = model(input_tensor)
        probs = torch.softmax(outputs, dim=1)
        prediction = torch.argmax(probs, dim=1).item()
        confidence = probs[0][prediction].item()

    # Class index -> human-readable label (index 0 is real, 1 is SDXL).
    labels = ['Real', 'SDXL-generated']

    return {
        'prediction': labels[prediction],
        'confidence': confidence,
        'probabilities': {
            'real': probs[0][0].item(),
            'sdxl': probs[0][1].item(),
        },
    }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main(image_path="test_image.jpg"):
    """Example usage.

    Picks CUDA when available (otherwise CPU), loads the detector,
    classifies one image and prints the prediction, its confidence and the
    per-class probabilities.

    Args:
        image_path: Path of the image to classify. Defaults to
            ``"test_image.jpg"``, so calling ``main()`` with no arguments
            behaves as before.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")

    model = load_model(device)
    result = predict_image(model, image_path, device)
    probabilities = result['probabilities']

    print(f"\n📊 Results for {image_path}:")
    print(f" Prediction: {result['prediction']}")
    print(f" Confidence: {result['confidence']*100:.2f}%")
    # No placeholders here, so a plain string literal is enough.
    print(" \nProbabilities:")
    print(f" Real: {probabilities['real']*100:.2f}%")
    print(f" SDXL: {probabilities['sdxl']*100:.2f}%")


# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|
|