|
|
""" |
|
|
SDXL Detector Model
===================

Vision Transformer-based model for detecting SDXL-generated images.

This model is a binary classifier that detects whether an image
was generated by Stable Diffusion XL (SDXL).

⚠️ IMPORTANT: This model ONLY detects SDXL images!
- SDXL images → Classified as "Fake"
- Real images → Classified as "Real"
- FLUX/Midjourney/other AI → Classified as "Real" (not trained on these!)

For comprehensive AI detection, use this as part of an ensemble with
other specialized detectors.

Architecture:
- Base: Vision Transformer (ViT-base-patch16-224)
- Classifier: Dropout + Linear (768 → 2)
- Output: Binary (0=Real, 1=SDXL-Fake)

Quick Start:
    import torch
    from transformers import ViTForImageClassification, ViTImageProcessor
    from PIL import Image

    # Load model
    model = ViTForImageClassification.from_pretrained(
        "ash12321/sdxl-detector-vit"
    )
    processor = ViTImageProcessor.from_pretrained(
        "google/vit-base-patch16-224"
    )

    # Process image (the model expects 3-channel RGB input)
    image = Image.open("test.jpg").convert("RGB")
    inputs = processor(images=image, return_tensors="pt")

    # Get prediction
    outputs = model(**inputs)
    probs = torch.softmax(outputs.logits, dim=1)

    if probs[0][1] > 0.5:
        print(f"SDXL-Generated: {probs[0][1]:.2%}")
    else:
        print(f"Not SDXL: {probs[0][0]:.2%}")

Performance:
    Test Accuracy: 99.60%
    Precision: 99.30%
    Recall: 99.90%
    False Positive Rate: 0.70%
    False Negative Rate: 0.10%
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from transformers import ViTForImageClassification, ViTImageProcessor |
|
|
from PIL import Image |
|
|
from typing import Dict, Union, Optional |
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
class SDXLDetector:
    """
    SDXL Image Detector

    Easy-to-use wrapper for detecting SDXL-generated images.

    The classifier weights and the image processor are loaded once in the
    constructor; `detect` scores a single image and `batch_detect` scores a
    list of images one after another.

    Attributes:
        device: Torch device string the model was moved to.
        model_path: HuggingFace repo id or local path the weights came from.
        model: Loaded ``ViTForImageClassification``, switched to eval mode.
        processor: ``ViTImageProcessor`` that turns images into model inputs.
    """

    def __init__(
        self,
        model_path: str = "ash12321/sdxl-detector-vit",
        device: Optional[str] = None,
    ) -> None:
        """
        Initialize SDXL detector

        Args:
            model_path: HuggingFace model repo or local path
            device: Device to use ('cuda', 'cpu', or None for auto)
        """
        if device is None:
            # Auto-select: prefer the GPU whenever torch can see one.
            device = "cuda" if torch.cuda.is_available() else "cpu"

        self.device = device
        self.model_path = model_path

        # Load the classifier, move it to the target device and switch it
        # to evaluation mode.
        self.model = ViTForImageClassification.from_pretrained(model_path)
        self.model.to(device)
        self.model.eval()

        # NOTE: preprocessing always comes from the base ViT checkpoint,
        # not from `model_path`.
        self.processor = ViTImageProcessor.from_pretrained(
            "google/vit-base-patch16-224"
        )

        print(f"✅ SDXL Detector loaded on {device}")

    @staticmethod
    def _as_rgb(image: Union[str, Path, Image.Image]) -> Image.Image:
        """Return `image` as an RGB PIL image, reading it from disk if needed."""
        if isinstance(image, (str, Path)):
            # The context manager closes the underlying file; convert()
            # produces an independent image that stays usable afterwards.
            with Image.open(image) as opened:
                return opened.convert('RGB')
        if image.mode != 'RGB':
            # RGBA / L / P images would not match the 3-channel model input.
            return image.convert('RGB')
        return image

    def detect(
        self,
        image: Union[str, Path, Image.Image],
        threshold: float = 0.5,
    ) -> Dict[str, Union[bool, float, str]]:
        """
        Detect if image is SDXL-generated

        Args:
            image: Image path or PIL Image (any mode; converted to RGB)
            threshold: Classification threshold (default 0.5). The image is
                labelled SDXL only when its SDXL probability is strictly
                greater than this value.

        Returns:
            dict with keys:
                - is_sdxl: bool - True if SDXL-generated
                - confidence: float - Probability of the predicted class
                - sdxl_probability: float - Probability of being SDXL
                - real_probability: float - Probability of being real
                - label: str - Human-readable label
        """
        rgb_image = self._as_rgb(image)

        # Preprocess and move every input tensor to the model's device.
        inputs = self.processor(images=rgb_image, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # Inference only: no gradients are needed.
        with torch.no_grad():
            outputs = self.model(**inputs)
            probs = torch.softmax(outputs.logits, dim=1)
            # Class index 1 is SDXL, index 0 is real / not SDXL.
            sdxl_prob = probs[0][1].item()
            real_prob = probs[0][0].item()

        is_sdxl = sdxl_prob > threshold

        return {
            'is_sdxl': is_sdxl,
            'confidence': sdxl_prob if is_sdxl else real_prob,
            'sdxl_probability': sdxl_prob,
            'real_probability': real_prob,
            'label': 'SDXL-Generated' if is_sdxl else 'Not SDXL',
        }

    def batch_detect(
        self,
        images: list,
        threshold: float = 0.5,
    ) -> list:
        """
        Detect SDXL on multiple images

        Images are scored one at a time by calling `detect` on each entry,
        so the result order matches the input order.

        Args:
            images: List of image paths or PIL Images
            threshold: Classification threshold

        Returns:
            List of detection results (one `detect` dict per image)
        """
        return [self.detect(img, threshold) for img in images]
|
|
|
|
|
|
|
|
def detect_sdxl(
    image_path: Union[str, Path, Image.Image],
    threshold: float = 0.5,
    device: Optional[str] = None,
) -> Dict[str, Union[bool, float, str]]:
    """
    Quick function to detect SDXL image

    A new `SDXLDetector` is constructed on every call, so the model is
    loaded each time. When scoring several images, create one
    `SDXLDetector` and reuse it instead.

    Args:
        image_path: Path to image (anything `SDXLDetector.detect` accepts)
        threshold: Classification threshold
        device: Device to use ('cuda', 'cpu', or None for auto)

    Returns:
        Detection results dictionary

    Example:
        >>> result = detect_sdxl("image.jpg")
        >>> print(f"Is SDXL: {result['is_sdxl']}")
        >>> print(f"Confidence: {result['confidence']:.2%}")
    """
    detector = SDXLDetector(device=device)
    return detector.detect(image_path, threshold)
|
|
|
|
|
|
|
|
|
|
|
# Static description of the published checkpoint: what it detects, what it
# does not, and the evaluation/training figures reported for it.
MODEL_INFO = {
    # Identity
    "name": "SDXL Detector",
    "version": "1.0",
    "type": "Binary Classifier",

    # Scope: the only generator this model flags, and those it does not
    "detects": "Stable Diffusion XL (SDXL) images",
    "does_not_detect": [
        "FLUX images",
        "Midjourney images",
        "DALL-E images",
        "Other AI generators",
    ],

    # Network and label mapping (class index -> meaning)
    "architecture": "Vision Transformer (ViT-base-patch16-224)",
    "input_size": (224, 224),
    "classes": {
        0: "Real / Not SDXL",
        1: "SDXL-Generated",
    },

    # Reported metrics, stored as fractions rather than percentages
    "performance": {
        "test_accuracy": 0.9960,
        "precision": 0.9930,
        "recall": 0.9990,
        "f1_score": 0.9960,
        "false_positive_rate": 0.0070,
        "false_negative_rate": 0.0010,
    },

    # Training set sizes and schedule
    "training": {
        "real_images": 8000,
        "sdxl_images": 8000,
        "epochs": 12,
        "best_epoch": 3,
    },
}
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Running the module directly prints a summary of MODEL_INFO followed by
    # a copy-pasteable usage snippet; no model is loaded here.
    banner = "=" * 60
    metrics = MODEL_INFO['performance']

    print(banner)
    print("SDXL Detector - Model Information")
    print(banner)
    print(f"\nModel: {MODEL_INFO['name']}")
    print(f"Detects: {MODEL_INFO['detects']}")

    print("\n⚠️ Does NOT detect:")
    for generator in MODEL_INFO['does_not_detect']:
        print(f" - {generator}")

    print("\n📊 Performance:")
    # (display label, key in MODEL_INFO['performance']) in print order
    for heading, metric_key in (
        ("Accuracy", "test_accuracy"),
        ("Precision", "precision"),
        ("Recall", "recall"),
        ("FPR", "false_positive_rate"),
        ("FNR", "false_negative_rate"),
    ):
        print(f" {heading}: {metrics[metric_key]:.2%}")

    print("\n" + banner)
    print("Example Usage:")
    print(banner)
    print("""
from model import SDXLDetector

# Initialize detector
detector = SDXLDetector()

# Detect single image
result = detector.detect("image.jpg")
print(f"Is SDXL: {result['is_sdxl']}")
print(f"Confidence: {result['confidence']:.2%}")

# Or use quick function
from model import detect_sdxl
result = detect_sdxl("image.jpg")
""")