sdxl-detector-vit / model.py
ash12321's picture
Create model.py
13ee6ba verified
"""
SDXL Detector Model
===================
Vision Transformer-based model for detecting SDXL-generated images.
This model is a binary classifier that detects whether an image
was generated by Stable Diffusion XL (SDXL).
⚠️ IMPORTANT: This model ONLY detects SDXL images!
- SDXL images → Classified as "Fake"
- Real images → Classified as "Real"
- FLUX/Midjourney/other AI → Classified as "Real" (not trained on these!)
For comprehensive AI detection, use this as part of an ensemble with
other specialized detectors.
Architecture:
- Base: Vision Transformer (ViT-base-patch16-224)
- Classifier: Dropout + Linear (768 → 2)
- Output: Binary (0=Real, 1=SDXL-Fake)
Quick Start:
    import torch
    from transformers import ViTForImageClassification, ViTImageProcessor
    from PIL import Image

    # Load model
    model = ViTForImageClassification.from_pretrained(
        "ash12321/sdxl-detector-vit"
    )
    processor = ViTImageProcessor.from_pretrained(
        "google/vit-base-patch16-224"
    )

    # Process image (convert to RGB so RGBA / grayscale files also work)
    image = Image.open("test.jpg").convert("RGB")
    inputs = processor(images=image, return_tensors="pt")

    # Get prediction (no gradients needed for inference)
    with torch.no_grad():
        outputs = model(**inputs)
    probs = torch.softmax(outputs.logits, dim=1)
    if probs[0][1] > 0.5:
        print(f"SDXL-Generated: {probs[0][1]:.2%}")
    else:
        print(f"Not SDXL: {probs[0][0]:.2%}")
Performance:
Test Accuracy: 99.60%
Precision: 99.30%
Recall: 99.90%
False Positive Rate: 0.70%
False Negative Rate: 0.10%
"""
import torch
import torch.nn as nn
from transformers import ViTForImageClassification, ViTImageProcessor
from PIL import Image
from typing import Dict, Union, Optional
from pathlib import Path
class SDXLDetector:
    """
    SDXL Image Detector

    Easy-to-use wrapper for detecting SDXL-generated images. Output index 1
    of the classifier is treated as "SDXL-generated" and index 0 as
    "real / not SDXL".
    """

    def __init__(
        self,
        model_path: str = "ash12321/sdxl-detector-vit",
        device: Optional[str] = None
    ):
        """
        Initialize SDXL detector

        Loads the classifier from ``model_path``, moves it to ``device`` and
        puts it in eval mode. The image processor config always comes from
        ``google/vit-base-patch16-224``, independent of ``model_path``.

        Args:
            model_path: HuggingFace model repo or local path
            device: Device to use ('cuda', 'cpu', or None for auto: 'cuda'
                when available, otherwise 'cpu')
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.model_path = model_path

        # Load model and processor
        self.model = ViTForImageClassification.from_pretrained(model_path)
        self.model.to(device)
        self.model.eval()
        self.processor = ViTImageProcessor.from_pretrained(
            "google/vit-base-patch16-224"
        )
        print(f"✅ SDXL Detector loaded on {device}")

    @staticmethod
    def _to_rgb(image):
        """Return ``image`` as an RGB PIL image when it is a path or PIL image.

        Any other input type (e.g. an array the caller prepared) is returned
        unchanged so it reaches the processor exactly as before.
        """
        if isinstance(image, (str, Path)):
            # The context manager closes the file handle; convert() returns a
            # new image whose pixel data no longer depends on the open file.
            with Image.open(image) as img:
                return img.convert('RGB')
        if isinstance(image, Image.Image) and image.mode != 'RGB':
            # Path inputs are always forced to RGB above; do the same for
            # in-memory PIL images so RGBA / L / P inputs don't reach the
            # processor with a channel count other than 3.
            return image.convert('RGB')
        return image

    def detect(
        self,
        image: Union[str, Path, Image.Image],
        threshold: float = 0.5
    ) -> Dict[str, Union[bool, float, str]]:
        """
        Detect if image is SDXL-generated

        Args:
            image: Image path or PIL Image (non-RGB PIL images are converted
                to RGB; the caller's image object is not modified)
            threshold: Classification threshold (default 0.5); the image is
                flagged as SDXL only when its SDXL probability is strictly
                greater than this value

        Returns:
            dict with keys:
                - is_sdxl: bool - True if SDXL-generated
                - confidence: float - Probability of the predicted class
                - sdxl_probability: float - Probability of being SDXL
                - real_probability: float - Probability of being real
                - label: str - Human-readable label
        """
        image = self._to_rgb(image)

        # Process image
        inputs = self.processor(images=image, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # Get prediction
        with torch.no_grad():
            outputs = self.model(**inputs)
        probs = torch.softmax(outputs.logits, dim=1)
        sdxl_prob = probs[0][1].item()
        real_prob = probs[0][0].item()
        is_sdxl = sdxl_prob > threshold

        return {
            'is_sdxl': is_sdxl,
            'confidence': sdxl_prob if is_sdxl else real_prob,
            'sdxl_probability': sdxl_prob,
            'real_probability': real_prob,
            'label': 'SDXL-Generated' if is_sdxl else 'Not SDXL'
        }

    def batch_detect(
        self,
        images: list,
        threshold: float = 0.5
    ) -> list:
        """
        Detect SDXL on multiple images

        Each image goes through ``detect`` individually (one forward pass per
        image), so results match calling ``detect`` in a loop.

        Args:
            images: List of image paths or PIL Images
            threshold: Classification threshold

        Returns:
            List of detection results, in the same order as ``images``
        """
        return [self.detect(img, threshold) for img in images]
def detect_sdxl(
    image_path: Union[str, Path],
    threshold: float = 0.5,
    device: Optional[str] = None
) -> Dict[str, Union[bool, float, str]]:
    """
    Quick function to detect SDXL image

    Builds a fresh ``SDXLDetector`` (loading the model and processor) on
    every call. For more than one image, create a single ``SDXLDetector``
    and call its ``detect`` method repeatedly instead.

    Args:
        image_path: Path to image
        threshold: Classification threshold; SDXL only when the SDXL
            probability is strictly greater than this value
        device: Device to use ('cuda', 'cpu', or None for auto)

    Returns:
        Detection results dictionary (same keys as ``SDXLDetector.detect``)

    Example:
        >>> result = detect_sdxl("image.jpg")
        >>> print(f"Is SDXL: {result['is_sdxl']}")
        >>> print(f"Confidence: {result['confidence']:.2%}")
    """
    detector = SDXLDetector(device=device)
    return detector.detect(image_path, threshold)
# Static model-card metadata: scope of detection, class index names,
# reported test metrics and a summary of the training run.
MODEL_INFO = dict(
    name='SDXL Detector',
    version='1.0',
    type='Binary Classifier',
    detects='Stable Diffusion XL (SDXL) images',
    does_not_detect=[
        'FLUX images',
        'Midjourney images',
        'DALL-E images',
        'Other AI generators',
    ],
    architecture='Vision Transformer (ViT-base-patch16-224)',
    input_size=(224, 224),
    # Classifier output index -> class name.
    classes={
        0: 'Real / Not SDXL',
        1: 'SDXL-Generated',
    },
    performance=dict(
        test_accuracy=0.9960,
        precision=0.9930,
        recall=0.9990,
        f1_score=0.9960,
        false_positive_rate=0.0070,
        false_negative_rate=0.0010,
    ),
    training=dict(
        real_images=8000,
        sdxl_images=8000,
        epochs=12,
        best_epoch=3,
    ),
)
if __name__ == "__main__":
    # Print a human-readable summary of MODEL_INFO followed by a usage snippet.
    separator = "=" * 60
    performance = MODEL_INFO['performance']

    print(separator)
    print("SDXL Detector - Model Information")
    print(separator)
    print(f"\nModel: {MODEL_INFO['name']}")
    print(f"Detects: {MODEL_INFO['detects']}")

    print("\n⚠️ Does NOT detect:")
    for item in MODEL_INFO['does_not_detect']:
        print(f" - {item}")

    print("\n📊 Performance:")
    # (display label, key in MODEL_INFO['performance']) in print order.
    for label, key in (
        ("Accuracy", 'test_accuracy'),
        ("Precision", 'precision'),
        ("Recall", 'recall'),
        ("FPR", 'false_positive_rate'),
        ("FNR", 'false_negative_rate'),
    ):
        print(f" {label}: {performance[key]:.2%}")

    print("\n" + separator)
    print("Example Usage:")
    print(separator)
    print("""
from model import SDXLDetector
# Initialize detector
detector = SDXLDetector()
# Detect single image
result = detector.detect("image.jpg")
print(f"Is SDXL: {result['is_sdxl']}")
print(f"Confidence: {result['confidence']:.2%}")
# Or use quick function
from model import detect_sdxl
result = detect_sdxl("image.jpg")
""")