#!/usr/bin/env python3
"""
Example: Using SDXL Detector from HuggingFace
==============================================

Simple example showing how to use the SDXL detector to classify images
as real or SDXL-generated.
"""

import torch
import torch.nn as nn
import torchvision.models as models
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import transforms


# ============================================================================
# MODEL DEFINITION
# ============================================================================

class SDXLDetector(nn.Module):
    """ResNet-50 based SDXL detector.

    The stock ResNet-50 classification layer is replaced by a small MLP head
    (dropout -> linear -> batch norm -> ReLU -> dropout -> linear) that emits
    two logits per image.
    """

    def __init__(self):
        super().__init__()
        # No-argument construction yields a randomly initialised network on
        # both old (``pretrained=False``) and new (``weights=None``)
        # torchvision APIs, and avoids the deprecated ``pretrained`` keyword.
        # The trained weights are loaded afterwards from the checkpoint.
        self.backbone = models.resnet50()
        num_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Sequential(
            nn.Dropout(p=0.3),
            nn.Linear(num_features, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.15),
            nn.Linear(512, 2),
        )

    def forward(self, x):
        """Run ``x`` through the backbone and return the two-class logits."""
        return self.backbone(x)


# ============================================================================
# LOAD MODEL
# ============================================================================

def load_model(device='cpu') -> "SDXLDetector":
    """Load model from HuggingFace Hub.

    Downloads ``best.pth`` from the ``ash12321/sdxl-detector-resnet50``
    repository, restores the weights stored under ``'model_state_dict'``,
    moves the model to ``device`` and switches it to eval mode.

    Args:
        device: Device the checkpoint tensors and the model are mapped to.

    Returns:
        The loaded SDXLDetector, in eval mode.
    """
    # Download checkpoint
    model_path = hf_hub_download(
        repo_id="ash12321/sdxl-detector-resnet50",
        filename="best.pth",
    )

    # Load checkpoint
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only use checkpoints from a repository you trust;
    # on recent PyTorch versions consider passing ``weights_only=True``.
    checkpoint = torch.load(model_path, map_location=device)

    # Create model and load weights
    model = SDXLDetector()
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()

    print(f"✅ Model loaded from {model_path}")
    # 'epoch' is reported with +1, i.e. it is treated as zero-based here.
    print(f" Trained for {checkpoint['epoch'] + 1} epochs")
    print(f" Best validation accuracy: {checkpoint['best_val_acc']:.2f}%")

    return model


# ============================================================================
# PREPROCESSING
# ============================================================================

def get_transform() -> transforms.Compose:
    """Get image preprocessing transform.

    Returns:
        A pipeline that resizes to 256, center-crops to 224, converts to a
        tensor and normalizes each channel with the mean/std values commonly
        used for ImageNet-trained models.
    """
    return transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ])


# ============================================================================
# PREDICTION
# ============================================================================

def predict_image(model, image_path, device='cpu') -> dict:
    """
    Predict if an image is real or SDXL-generated

    Args:
        model: Loaded SDXLDetector model
        image_path: Path to image file
        device: Device to run inference on

    Returns:
        dict with prediction, confidence, and probabilities
    """
    # Load and preprocess image. The context manager closes the underlying
    # file once the pixel data has been converted into a new RGB image.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert('RGB')
    transform = get_transform()
    # unsqueeze(0) adds the batch dimension the model expects.
    input_tensor = transform(image).unsqueeze(0).to(device)

    # Predict
    with torch.no_grad():
        outputs = model(input_tensor)
        probs = torch.softmax(outputs, dim=1)
        prediction = torch.argmax(probs, dim=1).item()
        confidence = probs[0][prediction].item()

    # Format results (class index 0 -> 'Real', 1 -> 'SDXL-generated')
    labels = ['Real', 'SDXL-generated']

    return {
        'prediction': labels[prediction],
        'confidence': confidence,
        'probabilities': {
            'real': probs[0][0].item(),
            'sdxl': probs[0][1].item(),
        },
    }


# ============================================================================
# MAIN
# ============================================================================

def main() -> None:
    """Example usage: load the detector and classify ``test_image.jpg``."""
    # Setup
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")

    # Load model
    model = load_model(device)

    # Example prediction
    image_path = "test_image.jpg"  # Replace with your image
    result = predict_image(model, image_path, device)

    print(f"\n📊 Results for {image_path}:")
    print(f" Prediction: {result['prediction']}")
    print(f" Confidence: {result['confidence']*100:.2f}%")
    print(" \nProbabilities:")
    print(f" Real: {result['probabilities']['real']*100:.2f}%")
    print(f" SDXL: {result['probabilities']['sdxl']*100:.2f}%")


if __name__ == "__main__":
    main()