# Source: demectai / app.py (Hugging Face Space by nahid112376)
# Commit b60a0a4 — "Simplify: Use ResNet-1D only (93.4% accuracy) - remove TCN and stacking"
#!/usr/bin/env python3
"""
AI Image Detector - API Endpoint for Hugging Face Spaces
Uses ResNet-1D only (93.37% accuracy) - no TCN or stacking
"""
import gradio as gr
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
# ==================== ResNet-1D MODEL ====================
class ResidualBlock1D(nn.Module):
    """Bottleneck residual block (1x1 -> 3x3 -> 1x1) for 1-D feature maps.

    The bottleneck width is ``out_channels // 4``. When the stride or the
    channel count changes, the skip path projects the input with a strided
    1x1 convolution; otherwise it is the identity.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super(ResidualBlock1D, self).__init__()
        bottleneck = out_channels // 4
        # Squeeze -> transform -> expand; BN replaces conv bias throughout.
        self.conv1 = nn.Conv1d(in_channels, bottleneck, 1, bias=False)
        self.bn1 = nn.BatchNorm1d(bottleneck)
        self.conv2 = nn.Conv1d(bottleneck, bottleneck, 3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm1d(bottleneck)
        self.conv3 = nn.Conv1d(bottleneck, out_channels, 1, bias=False)
        self.bn3 = nn.BatchNorm1d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(0.3)
        # Identity skip unless the output shape differs from the input shape.
        if stride == 1 and in_channels == out_channels:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv1d(in_channels, out_channels, 1, stride=stride, bias=False),
                nn.BatchNorm1d(out_channels),
            )

    def forward(self, x):
        """Run the bottleneck path, add the skip connection, then ReLU + dropout."""
        skip = self.shortcut(x)
        h = self.relu(self.bn1(self.conv1(x)))
        h = self.relu(self.bn2(self.conv2(h)))
        h = self.bn3(self.conv3(h))
        return self.dropout(self.relu(h + skip))
class ResNet1D(nn.Module):
    """ResNet-style 1-D classifier over (batch, seq, input_dim) feature sequences.

    Stem (stride-2 conv + stride-2 maxpool) is followed by three residual
    stages, global average pooling, and a two-layer classification head.
    """

    def __init__(self, input_dim, num_classes=2):
        super(ResNet1D, self).__init__()
        # Stem: treat the feature dimension as channels, halve the sequence twice.
        self.conv1 = nn.Conv1d(input_dim, 64, 7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm1d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool1d(3, stride=2, padding=1)
        # Three stages of two bottleneck blocks; stages 2-3 downsample by 2.
        self.layer1 = self._make_layer(64, 256, 2, 1)
        self.layer2 = self._make_layer(256, 512, 2, 2)
        self.layer3 = self._make_layer(512, 1024, 2, 2)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes),
        )

    def _make_layer(self, in_ch, out_ch, blocks, stride):
        """Build one stage: a (possibly strided) block followed by identity blocks."""
        stage = []
        for idx in range(blocks):
            src = in_ch if idx == 0 else out_ch
            stage.append(ResidualBlock1D(src, out_ch, stride if idx == 0 else 1))
        return nn.Sequential(*stage)

    def forward(self, x):
        # (batch, seq, features) -> (batch, features, seq) for Conv1d.
        h = x.transpose(1, 2)
        h = self.maxpool(self.relu(self.bn1(self.conv1(h))))
        for stage in (self.layer1, self.layer2, self.layer3):
            h = stage(h)
        h = self.avgpool(h)
        return self.fc(torch.flatten(h, 1))
# ==================== DETECTOR CLASS ====================
class AIDetector:
    """Two-stage AI-image detector: Qwen2.5-VL feature extraction + ResNet-1D classifier.

    Both models are loaded lazily on first use so the Gradio app starts quickly
    and only pays the load cost when a request actually arrives.
    """

    def __init__(self):
        self.device = "cpu"      # CPU-only inference (Spaces free tier)
        self.hidden_dim = 2048   # feature width fed to the classifier
        self.max_patches = 103   # fixed sequence length the classifier expects
        # Models (lazy loaded)
        self.qwen_model = None
        self.qwen_processor = None
        self.resnet = None
        self.loaded = False

    def load_qwen(self):
        """Load Qwen2.5-VL for feature extraction (idempotent; no-op once loaded)."""
        if self.qwen_model is None:
            print("Loading Qwen2.5-VL-3B-Instruct...")
            from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
            model_id = "Qwen/Qwen2.5-VL-3B-Instruct"
            self.qwen_processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
            self.qwen_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                model_id,
                torch_dtype=torch.float32,  # full precision: no half-precision speedup on CPU
                device_map="cpu",
                trust_remote_code=True,
                low_cpu_mem_usage=True
            )
            self.qwen_model.eval()
            print("Qwen loaded!")

    def load_classifier(self):
        """Load the ResNet-1D classifier weights (idempotent; no-op once loaded)."""
        if not self.loaded:
            print("Loading ResNet-1D classifier...")
            self.resnet = ResNet1D(self.hidden_dim).to(self.device)
            # weights_only=True: the checkpoint is a plain state dict, so refuse
            # arbitrary pickled objects (torch.load on a full pickle can execute
            # attacker-controlled code).
            self.resnet.load_state_dict(
                torch.load("resnet1d_best.pth", map_location=self.device, weights_only=True)
            )
            self.resnet.eval()
            self.loaded = True
            print("ResNet-1D loaded!")

    def extract_features(self, image: Image.Image) -> np.ndarray:
        """Return per-token hidden states for `image` from the Qwen backbone.

        Preprocessing (aspect-preserving resize onto a 256x256 black canvas,
        fixed "Image" prompt) mirrors the training pipeline exactly, per the
        inline comments carried over from training code.
        """
        self.load_qwen()
        from qwen_vl_utils import process_vision_info
        # Resize with aspect ratio preservation (same as training)
        image = image.convert("RGB")
        target_size = (256, 256)
        width, height = image.size
        scale = min(target_size[0] / width, target_size[1] / height)
        new_width = int(width * scale)
        new_height = int(height * scale)
        image = image.resize((new_width, new_height), Image.LANCZOS)
        # Center on black canvas (same as training)
        canvas = Image.new('RGB', target_size, (0, 0, 0))
        paste_x = (target_size[0] - new_width) // 2
        paste_y = (target_size[1] - new_height) // 2
        canvas.paste(image, (paste_x, paste_y))
        image = canvas
        # Same prompt as training
        messages = [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": "Image"}]}]
        text = self.qwen_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        image_inputs, _ = process_vision_info(messages)
        inputs = self.qwen_processor(text=[text], images=image_inputs, padding=True, return_tensors="pt")
        with torch.no_grad():
            # Run only the base transformer (`.model`, no LM head); features are
            # taken from the final hidden layer.
            outputs = self.qwen_model.model(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                pixel_values=inputs.get("pixel_values"),
                image_grid_thw=inputs.get("image_grid_thw"),
                output_hidden_states=True
            )
        # Drop the batch dimension: (seq_len, hidden_dim) on CPU as numpy.
        features = outputs.last_hidden_state[0].cpu().numpy()
        return features

    def pad_features(self, features: np.ndarray) -> np.ndarray:
        """Zero-pad or truncate `features` to (max_patches, hidden_dim), float32."""
        n = features.shape[0]
        if n < self.max_patches:
            padded = np.zeros((self.max_patches, self.hidden_dim), dtype=np.float32)
            padded[:n] = features
            return padded
        return features[:self.max_patches].astype(np.float32)

    def predict(self, image: Image.Image) -> dict:
        """Classify `image` and return percentages, verdict, and confidence.

        Returns a dict with keys "ai_percentage", "real_percentage",
        "verdict" ("AI Generated" / "Real"), and "confidence" (0-100).
        """
        self.load_classifier()
        # Extract features
        features = self.extract_features(image)
        features = self.pad_features(features)
        x = torch.FloatTensor(features).unsqueeze(0).to(self.device)
        # Get ResNet-1D prediction
        # Class 0 = Fake/AI, Class 1 = Real
        with torch.no_grad():
            output = torch.softmax(self.resnet(x), dim=1)
            real_prob = output[0, 1].item()
        ai_percentage = (1 - real_prob) * 100
        return {
            "ai_percentage": round(ai_percentage, 2),
            "real_percentage": round(real_prob * 100, 2),
            "verdict": "AI Generated" if ai_percentage > 50 else "Real",
            "confidence": round(max(ai_percentage, 100 - ai_percentage), 2)
        }
# ==================== GRADIO API ====================
# Module-level singleton; the heavy models inside are loaded lazily on first request.
detector = AIDetector()
def detect(image):
    """Gradio endpoint: run the detector on a PIL image.

    Always returns a JSON-serializable dict; failures are reported as
    {"error": ...} rather than raised, so the API never 500s on bad input.
    """
    if image is None:
        return {"error": "No image provided"}
    try:
        return detector.predict(image)
    except Exception as e:
        return {"error": str(e)}
# JSON-in/JSON-out interface so the Space can be called programmatically.
demo = gr.Interface(
    fn=detect,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=gr.JSON(label="Result"),
    title="AI Image Detector API",
    description="Upload an image to detect if it's AI-generated. Uses ResNet-1D (93.4% accuracy).",
    allow_flagging="never",  # NOTE(review): deprecated in Gradio 4.x (renamed flagging_mode) — confirm installed version
    api_name="predict"
)
if __name__ == "__main__":
    # api_open=True keeps the queue's API endpoint accessible to external clients.
    demo.queue(api_open=True).launch()