# ash12321's picture
# Update app.py
# b408637 verified
"""
Hugging Face Space App for AI Image Detector
User: ash12321
Repository: ash12321/ai-image-detector-deepsvdd
Save this as: app.py in your Hugging Face Space
"""
import io
import traceback

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from huggingface_hub import hf_hub_download
from PIL import Image
# ======================================================================
# MODEL ARCHITECTURE (Copy from your training script)
# ======================================================================
class EfficientChannelAttention(nn.Module):
    """Channel attention that rescales each feature map by a learned gate.

    A global-average and a global-max summary of every channel are each passed
    through the same bottleneck MLP (``fc``). The two outputs are added and
    multiplied onto the input channels.

    NOTE: ``fc`` ends in a Sigmoid, so each branch is squashed *before* the sum
    and the gate lies in (0, 2), not (0, 1). This is kept as-is because trained
    checkpoints depend on it.
    """

    def __init__(self, channels, reduction=8):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        hidden = channels // reduction
        # The Sequential layout fixes the state-dict keys (fc.0.weight,
        # fc.2.weight) that saved checkpoints are matched against.
        self.fc = nn.Sequential(
            nn.Linear(channels, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channels, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` with every channel multiplied by its attention weight.

        ``x`` must be 4-D, (N, C, H, W), with C equal to ``channels``.
        """
        n, c, _, _ = x.size()
        avg_descriptor = self.avg_pool(x).view(n, c)
        max_descriptor = self.max_pool(x).view(n, c)
        gate = self.fc(avg_descriptor) + self.fc(max_descriptor)
        return x * gate.view(n, c, 1, 1)
class EnhancedDeepSVDDEncoder(nn.Module):
    """CNN that maps an RGB image batch to ``latent_dim``-dimensional embeddings.

    Layout:
    - a stride-1 stem;
    - three stride-2 stages (64 -> 128 -> 256 -> 512 channels), each ending in
      channel attention;
    - concatenated global average- and max-pooling (2 x 512 = 1024 features),
      fed through an MLP projection head.

    Submodule names and ``nn.Sequential`` indices define the state-dict keys,
    so they must stay in sync with the training script that produced the
    checkpoint.
    """

    def __init__(self, latent_dim=128):
        super().__init__()
        self.stem = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )
        self.layer1 = self._make_layer(64, 128, stride=2, use_attention=True)
        self.layer2 = self._make_layer(128, 256, stride=2, use_attention=True)
        self.layer3 = self._make_layer(256, 512, stride=2, use_attention=True)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.maxpool = nn.AdaptiveMaxPool2d((1, 1))
        # 1024 = 512 channels from the average pool + 512 from the max pool.
        self.projection = nn.Sequential(
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.4),
            nn.Linear(512, latent_dim),
            nn.BatchNorm1d(latent_dim),
        )
        self._initialize_weights()

    def _make_layer(self, in_channels, out_channels, stride, use_attention=True):
        """Build one stage: two 3x3 conv+BN pairs, optional attention, then ReLU.

        Only the first convolution applies ``stride``. The stage is purely
        sequential; there is no residual shortcut.
        """
        modules = [
            nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
        ]
        if use_attention:
            modules += [EfficientChannelAttention(out_channels)]
        return nn.Sequential(*modules, nn.ReLU(inplace=True))

    def _initialize_weights(self):
        """Re-initialize conv, 2-D batch-norm and linear layers in module order.

        The BatchNorm1d layers of the projection head are not matched by any
        branch and keep PyTorch's default initialization.
        """
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                continue
            if isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
                continue
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

    def forward(self, x):
        """Encode ``x`` (N, 3, H, W) into a (N, latent_dim) embedding."""
        features = x
        for stage in (self.stem, self.layer1, self.layer2, self.layer3):
            features = stage(features)
        pooled = torch.cat([self.avgpool(features), self.maxpool(features)], dim=1)
        return self.projection(pooled.flatten(1))
class AdvancedDeepSVDD(nn.Module):
    """One-class Deep SVDD detector: an encoder plus a hypersphere.

    The hypersphere is described by the ``center`` and ``radius`` buffers. Being
    buffers, they are stored in and restored from the state dict together with
    the encoder weights, but are not optimizer parameters. An image whose
    normalized embedding lies far from ``center`` is reported as an anomaly.
    """

    def __init__(self, latent_dim=128, nu=0.1, temperature=0.5):
        super().__init__()
        self.encoder = EnhancedDeepSVDDEncoder(latent_dim=latent_dim)
        self.register_buffer('center', torch.zeros(latent_dim))
        self.register_buffer('radius', torch.tensor(1.0))
        # nu is stored but not read anywhere in this class; presumably the
        # soft-boundary hyperparameter from training -- TODO confirm.
        self.nu = nu
        # Controls how sharply the sigmoid maps distance to an anomaly score.
        self.temperature = temperature

    def forward(self, x):
        """Return the encoder embedding of ``x`` (not L2-normalized)."""
        return self.encoder(x)

    def predict_anomaly(self, images, threshold_multiplier=1.0):
        """Score a batch of images against the hypersphere.

        Embeddings are L2-normalized along dim 1, then compared with ``center``
        using the squared Euclidean distance.

        Args:
            images: Batch of RGB images, (N, 3, H, W).
            threshold_multiplier: Scales ``radius`` to form the decision
                threshold; values below 1 flag more images.

        Returns:
            Tuple ``(is_anomaly, anomaly_scores, distances)``, one entry per image:
            - ``is_anomaly``: bool tensor, ``distances > radius * threshold_multiplier``.
            - ``anomaly_scores``: ``sigmoid((distances - radius) / temperature)``.
              It is centred on ``radius`` and does not depend on
              ``threshold_multiplier``.
            - ``distances``: squared distances to ``center``.

        Side effect: puts the module in eval mode and leaves it there.
        """
        self.eval()
        with torch.no_grad():
            unit_embeddings = F.normalize(self(images), p=2, dim=1)
            sq_distances = (unit_embeddings - self.center).pow(2).sum(dim=1)
            margin = sq_distances - self.radius
            scores = torch.sigmoid(margin / self.temperature)
            flagged = sq_distances > self.radius * threshold_multiplier
            return flagged, scores, sq_distances
# ======================================================================
# LOAD MODEL
# ======================================================================
print("🔍 AI Image Detector - Loading...")
REPO_ID = "ash12321/ai-image-detector-deepsvdd"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(f"📥 Downloading model from: {REPO_ID}")
model_path = hf_hub_download(
    repo_id=REPO_ID,
    filename="model.ckpt",
)

print("📂 Loading model checkpoint...")
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code. That is acceptable only because REPO_ID is the author's own repo;
# never point it at an untrusted repository. Also note that newer torch
# releases default to weights_only=True, which may reject a Lightning
# checkpoint that stores non-tensor metadata -- verify against the pinned
# torch version.
checkpoint = torch.load(model_path, map_location=device)

# A Lightning checkpoint nests the weights under 'state_dict'. Fall back to
# the loaded object itself in case a bare state dict was saved.
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
    state_dict = checkpoint['state_dict']
else:
    state_dict = checkpoint

model = AdvancedDeepSVDD(latent_dim=128)
# strict=False tolerates extra checkpoint entries. It also silently leaves
# any weights missing from the checkpoint (including the SVDD center/radius
# buffers) at their initial values, so report mismatches instead of hiding them.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys:
    print(
        f"⚠️ {len(load_result.missing_keys)} model tensors were not found in the "
        f"checkpoint and keep their initial values, e.g. {load_result.missing_keys[:5]}"
    )
if load_result.unexpected_keys:
    print(
        f"ℹ️ Ignored {len(load_result.unexpected_keys)} checkpoint entries the "
        f"model does not use, e.g. {load_result.unexpected_keys[:5]}"
    )
model.to(device)
model.eval()
print(f"✅ Model loaded successfully on {device}!")
# ======================================================================
# IMAGE PREPROCESSING
# ======================================================================
# Per-channel RGB statistics used for normalization (they match the commonly
# published CIFAR-10 mean/std, consistent with the stated training data).
_CIFAR10_MEAN = [0.4914, 0.4822, 0.4465]
_CIFAR10_STD = [0.2470, 0.2435, 0.2616]

# Preprocessing applied to every uploaded image:
# 1. resize to a fixed 32x32 (aspect ratio is not preserved);
# 2. convert to a channels-first float tensor;
# 3. standardize each channel.
transform = transforms.Compose(
    [
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_CIFAR10_MEAN, std=_CIFAR10_STD),
    ]
)
# ======================================================================
# PREDICTION FUNCTION
# ======================================================================
def create_visualization(image, is_ai, score, distance, threshold):
    """Render the analysed image beside a panel summarising the verdict.

    Args:
        image: The image that was analysed (anything ``imshow`` accepts; here an
            RGB PIL image).
        is_ai: Whether the detector flagged the image as AI-generated.
        score: Anomaly score in [0, 1] (sigmoid output); higher = more unusual.
        distance: Squared distance of the image embedding from the SVDD center.
        threshold: Decision boundary that ``distance`` was compared against.

    Returns:
        A PIL image of the rendered figure.
    """
    # The anomaly score is the "AI" probability. Report confidence in the
    # verdict actually shown, so a REAL verdict uses 1 - score.
    confidence = score if is_ai else 1.0 - score

    # Plain-text labels: matplotlib's bundled DejaVu fonts have no emoji
    # glyphs, so emoji render as empty boxes.
    if is_ai:
        color, bg_color, label = '#ff4444', '#ffcccc', 'AI-GENERATED'
    else:
        color, bg_color, label = '#44ff44', '#ccffcc', 'REAL IMAGE'

    comparison = '>' if distance > threshold else '≤'
    result_text = (
        f"{label}\n\n"
        f"Confidence: {confidence*100:.1f}%\n\n"
        "━━━━━━━━━━━━━━\n\n"
        f"Anomaly Score: {score:.4f}\n"
        f"Distance: {distance:.4f}\n"
        f"Threshold: {threshold:.4f}\n\n"
        f"Distance {comparison} Threshold"
    )

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    try:
        # Left: the original image.
        axes[0].imshow(image)
        axes[0].axis('off')
        axes[0].set_title('Input Image', fontsize=14, fontweight='bold')

        # Right: the textual result panel.
        axes[1].axis('off')
        axes[1].text(
            0.5, 0.5, result_text,
            ha='center', va='center',
            fontsize=13,
            fontfamily='monospace',
            bbox=dict(boxstyle='round,pad=1.2',
                      facecolor=bg_color,
                      edgecolor=color,
                      linewidth=3),
            transform=axes[1].transAxes,
        )

        # Use this figure explicitly rather than pyplot's global "current figure",
        # which concurrent requests could change underneath us.
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=100, bbox_inches='tight', facecolor='white')
    finally:
        # Always release the figure, even if rendering fails.
        plt.close(fig)

    buf.seek(0)
    return Image.open(buf)
def _format_report(is_ai, score, distance, threshold, sensitivity):
    """Build the Markdown analysis report shown next to the visualization."""
    if is_ai:
        verdict = "# 🚨 AI-GENERATED IMAGE DETECTED"
        status = "🔴"
        interpretation = "This image shows characteristics typical of AI-generated content."
    else:
        verdict = "# ✅ REAL IMAGE"
        status = "🟢"
        interpretation = "This image appears to be a real/natural photograph."

    # The anomaly score is the "AI" probability. Report confidence in the
    # verdict actually shown.
    confidence = score if is_ai else 1.0 - score
    comparison = '>' if distance > threshold else '≤'

    # The blank lines are required. Without them Markdown merges consecutive
    # lines into one paragraph and turns a line followed by "---" into a heading.
    return f"""{verdict}

## {status} Analysis Results

| Metric | Value |
|--------|-------|
| **Status** | {'AI-Generated' if is_ai else 'Real/Natural'} {status} |
| **Confidence** | {confidence*100:.1f}% |
| **Anomaly Score** | {score:.4f} |
| **Distance** | {distance:.4f} |
| **Threshold** | {threshold:.4f} |

---

### 🎯 Decision

Distance ({distance:.4f}) **{comparison}** Threshold ({threshold:.4f})

→ **{'AI-Generated' if is_ai else 'Real'}**

{interpretation}

---

### 📊 Interpretation

**Anomaly Score:** Higher = More unusual compared to real images

**Distance:** How far from typical real images

**Threshold:** Decision boundary (distance > threshold = AI)

**Sensitivity:** {sensitivity}x (Lower = more sensitive, Higher = more conservative)

---

### ⚠️ Note

Results are probabilistic. Best accuracy on natural photos similar to training data.
"""


def predict_image(image, sensitivity):
    """Run the detector on one uploaded image and build the UI outputs.

    Args:
        image: PIL image from the Gradio input, or ``None`` if nothing is uploaded.
        sensitivity: Multiplier applied to ``model.radius`` to obtain the decision
            threshold; lower values flag more images as AI-generated.

    Returns:
        ``(visualization, markdown)``. ``visualization`` is a PIL image, or
        ``None`` when there is no input or an error occurred. ``markdown`` is the
        report, or a warning/error message.
    """
    if image is None:
        return None, "⚠️ Please upload an image first!"
    try:
        # The model expects 3-channel input; normalize RGBA/grayscale/etc.
        if image.mode != 'RGB':
            image = image.convert('RGB')
        img_tensor = transform(image).unsqueeze(0).to(device)
        with torch.no_grad():
            is_fake, scores, distances = model.predict_anomaly(
                img_tensor,
                threshold_multiplier=sensitivity,
            )
        # Batch of one: unwrap to plain Python values for display.
        is_ai = bool(is_fake[0].item())
        score = float(scores[0].item())
        distance = float(distances[0].item())
        threshold = float(model.radius.item() * sensitivity)
        viz_img = create_visualization(image, is_ai, score, distance, threshold)
    except Exception as e:
        # UI boundary: show the error to the user, but keep the full traceback
        # in the server logs for debugging.
        traceback.print_exc()
        return None, f"❌ **Error:** {str(e)}\n\nPlease try a different image."
    return viz_img, _format_report(is_ai, score, distance, threshold, sensitivity)
# ======================================================================
# GRADIO INTERFACE
# ======================================================================
with gr.Blocks(title="AI Image Detector") as demo:
    # Header. Blank lines keep each line a separate Markdown paragraph.
    gr.Markdown("""
# 🔍 AI Image Detector

## Deep SVDD One-Class Learning

**Created by:** [ash12321](https://huggingface.co/ash12321)

**Model:** [ai-image-detector-deepsvdd](https://huggingface.co/ash12321/ai-image-detector-deepsvdd)

Detect AI-generated images using one-class learning. Trained on 35,000 real images from CIFAR-10.
""")

    with gr.Row():
        # Left column: input image, sensitivity control and tips.
        with gr.Column(scale=1):
            gr.Markdown("### 📤 Input")
            input_image = gr.Image(
                type="pil",
                label="Upload Image to Analyze",
                height=350,
            )
            # Passed to predict_image as the radius multiplier.
            sensitivity_slider = gr.Slider(
                minimum=0.5,
                maximum=2.0,
                value=1.0,
                step=0.1,
                label="🎚️ Detection Sensitivity",
                info="Lower = More sensitive | Higher = More conservative",
            )
            analyze_btn = gr.Button(
                "🔍 Analyze Image",
                variant="primary",
                size="lg",
            )
            gr.Markdown("""
### 💡 Tips

- Works best with natural photos
- Try AI images from DALL-E, Midjourney, Stable Diffusion
- Adjust sensitivity if needed
""")

        # Right column: rendered figure and Markdown report.
        with gr.Column(scale=1):
            gr.Markdown("### 📊 Results")
            output_viz = gr.Image(
                label="Visual Analysis",
                height=350,
            )
            output_text = gr.Markdown(
                value="Upload an image and click **Analyze** to see results."
            )

    # Run the analysis on an explicit click and whenever the input image changes.
    analyze_btn.click(
        fn=predict_image,
        inputs=[input_image, sensitivity_slider],
        outputs=[output_viz, output_text],
    )
    input_image.change(
        fn=predict_image,
        inputs=[input_image, sensitivity_slider],
        outputs=[output_viz, output_text],
    )

    # Footer
    gr.Markdown(f"""
---

## 📋 Model Information

| Specification | Value |
|--------------|-------|
| Architecture | Enhanced Deep SVDD |
| Parameters | 5.3M |
| Training Data | CIFAR-10 (35,000 images) |
| Test Loss | 0.7637 |
| Latent Dim | 128 |
| Device | {device.type.upper()} |

### ⚠️ Limitations

- Best for natural images similar to CIFAR-10
- Research model - validate before critical use
- May flag unusual real images as AI
- Trained on 32×32 images

**Built with PyTorch Lightning & Gradio** | [Model Card](https://huggingface.co/ash12321/ai-image-detector-deepsvdd)
""")

if __name__ == "__main__":
    demo.launch()