""" SDXL Detector Model =================== Vision Transformer-based model for detecting SDXL-generated images. This model is a binary classifier that detects whether an image was generated by Stable Diffusion XL (SDXL). ⚠️ IMPORTANT: This model ONLY detects SDXL images! - SDXL images → Classified as "Fake" - Real images → Classified as "Real" - FLUX/Midjourney/other AI → Classified as "Real" (not trained on these!) For comprehensive AI detection, use this as part of an ensemble with other specialized detectors. Architecture: - Base: Vision Transformer (ViT-base-patch16-224) - Classifier: Dropout + Linear (768 → 2) - Output: Binary (0=Real, 1=SDXL-Fake) Quick Start: from transformers import ViTForImageClassification, ViTImageProcessor from PIL import Image # Load model model = ViTForImageClassification.from_pretrained( "ash12321/sdxl-detector-vit" ) processor = ViTImageProcessor.from_pretrained( "google/vit-base-patch16-224" ) # Process image image = Image.open("test.jpg") inputs = processor(images=image, return_tensors="pt") # Get prediction outputs = model(**inputs) probs = torch.softmax(outputs.logits, dim=1) if probs[0][1] > 0.5: print(f"SDXL-Generated: {probs[0][1]:.2%}") else: print(f"Not SDXL: {probs[0][0]:.2%}") Performance: Test Accuracy: 99.60% Precision: 99.30% Recall: 99.90% False Positive Rate: 0.70% False Negative Rate: 0.10% """ import torch import torch.nn as nn from transformers import ViTForImageClassification, ViTImageProcessor from PIL import Image from typing import Dict, Union, Optional from pathlib import Path class SDXLDetector: """ SDXL Image Detector Easy-to-use wrapper for detecting SDXL-generated images. """ def __init__( self, model_path: str = "ash12321/sdxl-detector-vit", device: str = None ): """ Initialize SDXL detector Args: model_path: HuggingFace model repo or local path device: Device to use ('cuda', 'cpu', or None for auto) """ if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" self.device = device self.model_path = model_path # Load model and processor self.model = ViTForImageClassification.from_pretrained(model_path) self.model.to(device) self.model.eval() self.processor = ViTImageProcessor.from_pretrained( "google/vit-base-patch16-224" ) print(f"✅ SDXL Detector loaded on {device}") def detect( self, image: Union[str, Path, Image.Image], threshold: float = 0.5 ) -> Dict[str, Union[bool, float]]: """ Detect if image is SDXL-generated Args: image: Image path or PIL Image threshold: Classification threshold (default 0.5) Returns: dict with keys: - is_sdxl: bool - True if SDXL-generated - confidence: float - Confidence in prediction - sdxl_probability: float - Probability of being SDXL - real_probability: float - Probability of being real - label: str - Human-readable label """ # Load image if path if isinstance(image, (str, Path)): image = Image.open(image).convert('RGB') # Process image inputs = self.processor(images=image, return_tensors="pt") inputs = {k: v.to(self.device) for k, v in inputs.items()} # Get prediction with torch.no_grad(): outputs = self.model(**inputs) probs = torch.softmax(outputs.logits, dim=1) sdxl_prob = probs[0][1].item() real_prob = probs[0][0].item() is_sdxl = sdxl_prob > threshold return { 'is_sdxl': is_sdxl, 'confidence': sdxl_prob if is_sdxl else real_prob, 'sdxl_probability': sdxl_prob, 'real_probability': real_prob, 'label': 'SDXL-Generated' if is_sdxl else 'Not SDXL' } def batch_detect( self, images: list, threshold: float = 0.5 ) -> list: """ Detect SDXL on multiple images Args: images: List of image paths or PIL Images threshold: Classification threshold Returns: List of detection results """ return [self.detect(img, threshold) for img in images] def detect_sdxl( image_path: str, threshold: float = 0.5, device: str = None ) -> Dict[str, Union[bool, float]]: """ Quick function to detect SDXL image Args: image_path: Path to image threshold: Classification threshold device: Device to use Returns: Detection results dictionary Example: >>> result = detect_sdxl("image.jpg") >>> print(f"Is SDXL: {result['is_sdxl']}") >>> print(f"Confidence: {result['confidence']:.2%}") """ detector = SDXLDetector(device=device) return detector.detect(image_path, threshold) # Model specifications MODEL_INFO = { 'name': 'SDXL Detector', 'version': '1.0', 'type': 'Binary Classifier', 'detects': 'Stable Diffusion XL (SDXL) images', 'does_not_detect': [ 'FLUX images', 'Midjourney images', 'DALL-E images', 'Other AI generators' ], 'architecture': 'Vision Transformer (ViT-base-patch16-224)', 'input_size': (224, 224), 'classes': { 0: 'Real / Not SDXL', 1: 'SDXL-Generated' }, 'performance': { 'test_accuracy': 0.9960, 'precision': 0.9930, 'recall': 0.9990, 'f1_score': 0.9960, 'false_positive_rate': 0.0070, 'false_negative_rate': 0.0010 }, 'training': { 'real_images': 8000, 'sdxl_images': 8000, 'epochs': 12, 'best_epoch': 3 } } if __name__ == "__main__": print("="*60) print("SDXL Detector - Model Information") print("="*60) print(f"\nModel: {MODEL_INFO['name']}") print(f"Detects: {MODEL_INFO['detects']}") print(f"\n⚠️ Does NOT detect:") for item in MODEL_INFO['does_not_detect']: print(f" - {item}") print(f"\n📊 Performance:") print(f" Accuracy: {MODEL_INFO['performance']['test_accuracy']:.2%}") print(f" Precision: {MODEL_INFO['performance']['precision']:.2%}") print(f" Recall: {MODEL_INFO['performance']['recall']:.2%}") print(f" FPR: {MODEL_INFO['performance']['false_positive_rate']:.2%}") print(f" FNR: {MODEL_INFO['performance']['false_negative_rate']:.2%}") print("\n" + "="*60) print("Example Usage:") print("="*60) print(""" from model import SDXLDetector # Initialize detector detector = SDXLDetector() # Detect single image result = detector.detect("image.jpg") print(f"Is SDXL: {result['is_sdxl']}") print(f"Confidence: {result['confidence']:.2%}") # Or use quick function from model import detect_sdxl result = detect_sdxl("image.jpg") """)