Spaces:
Sleeping
Sleeping
| """ | |
| Hugging Face Space App for AI Image Detector | |
| User: ash12321 | |
| Repository: ash12321/ai-image-detector-deepsvdd | |
| Save this as: app.py in your Hugging Face Space | |
| """ | |
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from PIL import Image | |
| import torchvision.transforms as transforms | |
| from huggingface_hub import hf_hub_download | |
| import matplotlib.pyplot as plt | |
| import io | |
| import numpy as np | |
| # ====================================================================== | |
| # MODEL ARCHITECTURE (Copy from your training script) | |
| # ====================================================================== | |
class EfficientChannelAttention(nn.Module):
    """Channel-attention gate fed by average- and max-pooled channel descriptors.

    Each channel map is summarised twice (global average and global max). Both
    summaries pass through the same bottleneck MLP, which ends in a sigmoid, and
    the two gates are added before rescaling the input channels. Because the
    sigmoid is applied per branch, the combined gate lies in (0, 2).

    Args:
        channels: Number of input (and output) channels.
        reduction: Bottleneck factor for the MLP's hidden width.
    """

    def __init__(self, channels, reduction=8):
        super().__init__()
        hidden = channels // reduction
        # The attribute names and the Sequential layout (fc.0 / fc.2 weights) are
        # checkpoint state_dict keys, so they must stay exactly as they are.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channels, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Scale each channel of the 4-D input ``x`` by its attention gate."""
        batch, chans, _, _ = x.size()
        # Shared MLP applied to the average-pooled, then the max-pooled, descriptor.
        gate = sum(
            self.fc(pool(x).view(batch, chans))
            for pool in (self.avg_pool, self.max_pool)
        )
        return x * gate.view(batch, chans, 1, 1)
class EnhancedDeepSVDDEncoder(nn.Module):
    """Convolutional encoder mapping a 3-channel image batch to a latent vector.

    Architecture:
    - A stride-1 3x3 stem (3 -> 64 channels).
    - Three stride-2 stages (64 -> 128 -> 256 -> 512), each with channel attention.
    - Global average pooling and global max pooling, concatenated (2 x 512 features).
    - An MLP projection to ``latent_dim``, ending in BatchNorm1d.

    Submodule names and the order in which layers are created are kept fixed.
    The names form the checkpoint's state_dict keys. The order fixes how the
    layers' random initialisers consume the RNG.

    Args:
        latent_dim: Size of the output embedding.
    """

    def __init__(self, latent_dim=128):
        super().__init__()
        self.stem = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )
        # Register the downsampling stages as layer1, layer2 and layer3, in that order.
        stage_widths = ((64, 128), (128, 256), (256, 512))
        for idx, (c_in, c_out) in enumerate(stage_widths, start=1):
            setattr(self, f"layer{idx}",
                    self._make_layer(c_in, c_out, stride=2, use_attention=True))
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.maxpool = nn.AdaptiveMaxPool2d((1, 1))
        concat_dim = 2 * 512  # average- and max-pooled layer3 features side by side
        self.projection = nn.Sequential(
            nn.Linear(concat_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.4),
            nn.Linear(512, latent_dim),
            nn.BatchNorm1d(latent_dim),
        )
        self._initialize_weights()

    def _make_layer(self, in_channels, out_channels, stride, use_attention=True):
        """Build one stage: conv3x3 -> BN -> ReLU -> conv3x3 -> BN [-> attention] -> ReLU.

        Only the first convolution uses ``stride``; the second keeps the resolution.
        """
        modules = [
            nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
        ]
        if use_attention:
            modules.append(EfficientChannelAttention(out_channels))
        modules.append(nn.ReLU(inplace=True))
        return nn.Sequential(*modules)

    def _initialize_weights(self):
        """Re-initialise Conv2d (Kaiming), BatchNorm2d (1/0) and Linear (N(0, 0.01)) layers."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                continue
            if isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
                continue
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, mean=0.0, std=0.01)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(self, x):
        """Encode ``x`` (a batch with 3 input channels) into a (batch, latent_dim) tensor."""
        for stage in (self.stem, self.layer1, self.layer2, self.layer3):
            x = stage(x)
        pooled = torch.cat((self.avgpool(x), self.maxpool(x)), dim=1)
        return self.projection(pooled.flatten(1))
class AdvancedDeepSVDD(nn.Module):
    """Deep SVDD one-class model: an encoder plus a hypersphere (``center``, ``radius``).

    ``center`` and ``radius`` are registered buffers, so they are saved and
    restored through ``state_dict``. The zero center and unit radius set here
    are only placeholders until a checkpoint is loaded.

    Args:
        latent_dim: Size of the encoder embedding and of ``center``.
        nu: Stored on the instance but not read by this class. It is presumably
            the training-time soft-boundary hyperparameter; confirm against the
            training script.
        temperature: Divisor that softens ``distance - radius`` before the sigmoid
            in :meth:`predict_anomaly`.
    """

    def __init__(self, latent_dim=128, nu=0.1, temperature=0.5):
        super().__init__()
        self.encoder = EnhancedDeepSVDDEncoder(latent_dim=latent_dim)
        self.register_buffer('center', torch.zeros(latent_dim))
        self.register_buffer('radius', torch.tensor(1.0))
        self.nu = nu
        self.temperature = temperature

    def forward(self, x):
        """Return the raw (un-normalised) encoder embeddings for ``x``."""
        return self.encoder(x)

    def predict_anomaly(self, images, threshold_multiplier=1.0):
        """Score a batch of images against the learned hypersphere.

        Embeddings are L2-normalised, then their squared Euclidean distance to
        ``center`` is computed. The module runs in eval mode for this call, and
        its previous train/eval mode is restored afterwards, so calling it
        during training does not silently leave the model in eval mode.

        Args:
            images: Batch accepted by the encoder, on the model's device.
            threshold_multiplier: Scales ``radius`` to form the decision threshold.
                Values below 1 flag more images; values above 1 flag fewer.

        Returns:
            ``(is_anomaly, anomaly_scores, distances)``, one entry per image:
            - ``is_anomaly`` is ``distance > radius * threshold_multiplier``.
            - ``anomaly_scores`` is ``sigmoid((distance - radius) / temperature)``.
            - ``distances`` are the squared distances themselves.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                embeddings = F.normalize(self(images), p=2, dim=1)
                distances = torch.sum((embeddings - self.center) ** 2, dim=1)
                # NOTE(review): the score is centred on the raw radius, not on the
                # scaled threshold. With threshold_multiplier != 1, a score above
                # 0.5 does not necessarily coincide with is_anomaly.
                anomaly_scores = torch.sigmoid((distances - self.radius) / self.temperature)
                threshold = self.radius * threshold_multiplier
                is_anomaly = distances > threshold
        finally:
            self.train(was_training)
        return is_anomaly, anomaly_scores, distances
| # ====================================================================== | |
| # LOAD MODEL | |
| # ====================================================================== | |
print("🔍 AI Image Detector - Loading...")
REPO_ID = "ash12321/ai-image-detector-deepsvdd"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(f"📥 Downloading model from: {REPO_ID}")
model_path = hf_hub_download(
    repo_id=REPO_ID,
    filename="model.ckpt",
)

print("🔄 Loading model checkpoint...")
# NOTE(review): torch.load unpickles the file. That is acceptable only because
# the checkpoint comes from the author's own repo. On PyTorch >= 2.6,
# torch.load defaults to weights_only=True, which may reject checkpoints that
# contain non-tensor objects. Confirm against the torch version in requirements.
checkpoint = torch.load(model_path, map_location=device)
# Accept both a checkpoint wrapping {'state_dict': ...} and a bare state_dict.
state_dict = checkpoint.get('state_dict', checkpoint)

model = AdvancedDeepSVDD(latent_dim=128)
# strict=False tolerates extra keys, but it would also silently leave weights
# at their initial values on a key mismatch. That includes the center/radius
# buffers that drive every prediction, so report any mismatch loudly.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys:
    print(f"⚠️ {len(load_result.missing_keys)} model keys missing from checkpoint "
          f"(left at initial values): {load_result.missing_keys[:10]}")
if load_result.unexpected_keys:
    print(f"⚠️ Ignored {len(load_result.unexpected_keys)} unexpected checkpoint keys: "
          f"{load_result.unexpected_keys[:10]}")
model.to(device)
model.eval()
print(f"✅ Model loaded successfully on {device}!")
| # ====================================================================== | |
| # IMAGE PREPROCESSING | |
| # ====================================================================== | |
# Preprocessing pipeline applied to every uploaded image before inference.
# The model card in this app says it was trained on 32x32 CIFAR-10 images. The
# mean/std values look like the usual CIFAR-10 channel statistics; confirm they
# match the training script.
transform = transforms.Compose(
    [
        # Note: non-square inputs are stretched to 32x32, not cropped.
        transforms.Resize(size=(32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.4914, 0.4822, 0.4465],
            std=[0.2470, 0.2435, 0.2616],
        ),
    ]
)
| # ====================================================================== | |
| # PREDICTION FUNCTION | |
| # ====================================================================== | |
def create_visualization(image, is_ai, score, distance, threshold):
    """Render the input image beside a colour-coded verdict panel.

    Args:
        image: Image drawn on the left with ``imshow`` (a PIL image in this app).
        is_ai: The verdict. It selects red "AI-GENERATED" or green "REAL IMAGE" styling.
        score: Anomaly score. Callers in this app pass a sigmoid output in (0, 1),
            where higher means more AI-like.
        distance: The model's distance value for this image.
        threshold: Decision threshold that ``distance`` was compared against.

    Returns:
        A PIL ``Image`` decoded from the rendered PNG.
    """
    # Show confidence in the verdict actually displayed. The anomaly score is the
    # AI-likeness, so a REAL verdict is backed by 1 - score.
    confidence = score if is_ai else 1.0 - score

    # Plain-text labels: matplotlib's default DejaVu fonts have no emoji glyphs.
    if is_ai:
        color, bg_color, label = '#ff4444', '#ffcccc', 'AI-GENERATED'
    else:
        color, bg_color, label = '#44ff44', '#ccffcc', 'REAL IMAGE'
    comparison = '>' if distance > threshold else '≤'
    result_text = "\n".join([
        label,
        "",
        f"Confidence: {confidence * 100:.1f}%",
        "",
        "─" * 14,
        "",
        f"Anomaly Score: {score:.4f}",
        f"Distance: {distance:.4f}",
        f"Threshold: {threshold:.4f}",
        "",
        f"Distance {comparison} Threshold",
    ])

    # Work on this Figure/Axes directly instead of pyplot's global "current
    # figure". pyplot state is process-global: if requests are ever handled
    # concurrently, plt.close() with no argument could close another request's
    # figure. Closing in `finally` also avoids leaking the figure on errors.
    fig, (image_ax, result_ax) = plt.subplots(1, 2, figsize=(12, 5))
    try:
        image_ax.imshow(image)
        image_ax.axis('off')
        image_ax.set_title('Input Image', fontsize=14, fontweight='bold')

        result_ax.axis('off')
        result_ax.text(
            0.5, 0.5, result_text,
            ha='center', va='center',
            fontsize=13,
            fontfamily='monospace',
            bbox=dict(boxstyle='round,pad=1.2',
                      facecolor=bg_color,
                      edgecolor=color,
                      linewidth=3),
            transform=result_ax.transAxes,
        )
        fig.tight_layout()

        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=100, bbox_inches='tight', facecolor='white')
    finally:
        plt.close(fig)

    buf.seek(0)
    return Image.open(buf)
def predict_image(image, sensitivity):
    """Gradio handler: classify an uploaded image and build the result views.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or None when empty.
        sensitivity: Multiplier applied to the model radius to form the decision
            threshold. Lower values flag more images as AI.

    Returns:
        ``(visualization, markdown)``: a PIL image (or None) and a Markdown report.
        Errors are reported in the Markdown string rather than raised, so the UI
        always receives something displayable.
    """
    if image is None:
        return None, "⚠️ Please upload an image first!"
    try:
        # Preprocess
        if image.mode != 'RGB':
            image = image.convert('RGB')
        img_tensor = transform(image).unsqueeze(0).to(device)

        # Predict
        with torch.no_grad():
            is_fake, scores, distances = model.predict_anomaly(
                img_tensor,
                threshold_multiplier=sensitivity,
            )

        # Extract scalar values for the single-image batch
        is_ai = bool(is_fake[0].item())
        score = float(scores[0].item())
        distance = float(distances[0].item())
        threshold = float(model.radius.item() * sensitivity)
        # The anomaly score measures AI-likeness; report confidence in the verdict shown.
        confidence = score if is_ai else 1.0 - score
        comparison = '>' if distance > threshold else '≤'

        viz_img = create_visualization(image, is_ai, score, distance, threshold)

        if is_ai:
            verdict = "# 🚨 AI-GENERATED IMAGE DETECTED"
            status = "🔴"
            interpretation = "This image shows characteristics typical of AI-generated content."
        else:
            verdict = "# ✅ REAL IMAGE"
            status = "🟢"
            interpretation = "This image appears to be a real/natural photograph."

        # The blank lines matter. Without them Markdown merges consecutive lines
        # into one paragraph, and a paragraph directly followed by "---" renders
        # as a heading.
        output_text = f"""{verdict}

## {status} Analysis Results

| Metric | Value |
|--------|-------|
| **Status** | {'AI-Generated' if is_ai else 'Real/Natural'} {status} |
| **Confidence** | {confidence*100:.1f}% |
| **Anomaly Score** | {score:.4f} |
| **Distance** | {distance:.4f} |
| **Threshold** | {threshold:.4f} |

---

### 🎯 Decision

Distance ({distance:.4f}) **{comparison}** Threshold ({threshold:.4f})
→ **{'AI-Generated' if is_ai else 'Real'}**

{interpretation}

---

### 📖 Interpretation

- **Anomaly Score:** Higher = More unusual compared to real images
- **Distance:** How far from typical real images
- **Threshold:** Decision boundary (distance > threshold = AI)
- **Sensitivity:** {sensitivity}x (Lower = more sensitive, Higher = more conservative)

---

### ⚠️ Note

Results are probabilistic. Best accuracy on natural photos similar to training data.
"""
        return viz_img, output_text
    except Exception as e:
        # UI boundary: show the error to the user, but keep the traceback in the logs.
        import traceback
        traceback.print_exc()
        return None, f"❌ **Error:** {str(e)}\n\nPlease try a different image."
| # ====================================================================== | |
| # GRADIO INTERFACE | |
| # ====================================================================== | |
# Markdown strings below are kept flush-left, so no line is indented 4+ spaces
# (which Markdown renders as a code block). They are separated by blank lines
# so consecutive lines don't merge into a single paragraph.
with gr.Blocks(title="AI Image Detector") as demo:
    gr.Markdown("""
# 🔍 AI Image Detector

## Deep SVDD One-Class Learning

**Created by:** [ash12321](https://huggingface.co/ash12321)

**Model:** [ai-image-detector-deepsvdd](https://huggingface.co/ash12321/ai-image-detector-deepsvdd)

Detect AI-generated images using one-class learning. Trained on 35,000 real images from CIFAR-10.
""")

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 📤 Input")
            input_image = gr.Image(
                type="pil",
                label="Upload Image to Analyze",
                height=350,
            )
            sensitivity_slider = gr.Slider(
                minimum=0.5,
                maximum=2.0,
                value=1.0,
                step=0.1,
                label="🎚️ Detection Sensitivity",
                info="Lower = More sensitive | Higher = More conservative",
            )
            analyze_btn = gr.Button(
                "🔍 Analyze Image",
                variant="primary",
                size="lg",
            )
            gr.Markdown("""
### 💡 Tips

- Works best with natural photos
- Try AI images from DALL-E, Midjourney, Stable Diffusion
- Adjust sensitivity if needed
""")

        with gr.Column(scale=1):
            gr.Markdown("### 📊 Results")
            output_viz = gr.Image(
                label="Visual Analysis",
                height=350,
            )
            output_text = gr.Markdown(
                value="Upload an image and click **Analyze** to see results."
            )

    # Run the analysis on button click and whenever the image input changes.
    # Clearing the image also fires .change; predict_image then returns its
    # "please upload" message.
    analyze_btn.click(
        fn=predict_image,
        inputs=[input_image, sensitivity_slider],
        outputs=[output_viz, output_text],
    )
    input_image.change(
        fn=predict_image,
        inputs=[input_image, sensitivity_slider],
        outputs=[output_viz, output_text],
    )

    # Footer
    gr.Markdown(f"""
---

## 📋 Model Information

| Specification | Value |
|--------------|-------|
| Architecture | Enhanced Deep SVDD |
| Parameters | 5.3M |
| Training Data | CIFAR-10 (35,000 images) |
| Test Loss | 0.7637 |
| Latent Dim | 128 |
| Device | {device.type.upper()} |

### ⚠️ Limitations

- Best for natural images similar to CIFAR-10
- Research model - validate before critical use
- May flag unusual real images as AI
- Trained on 32×32 images

**Built with PyTorch Lightning & Gradio** | [Model Card](https://huggingface.co/ash12321/ai-image-detector-deepsvdd)
""")

if __name__ == "__main__":
    demo.launch()