"""LunaNet: Gradio demo that classifies an image as real or AI-generated.

The detector concatenates an EfficientNet-B0 image embedding with a learned
embedding of the image's log-magnitude Fourier spectrum and feeds the result
to a small MLP that emits a single logit.
"""

import warnings

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from transformers import EfficientNetConfig, EfficientNetModel

warnings.filterwarnings("ignore")

EFFICIENTNET_ID = "google/efficientnet-b0"


# ── Model ─────────────────────────────────────────────────────────
class FFTBranch(nn.Module):
    """Embed the log-magnitude 2-D Fourier spectrum of an image batch.

    The input channels are averaged to one grayscale channel. Its centred
    (``fftshift``-ed) log-magnitude spectrum is min-max scaled to [0, 1] per
    image, and a small CNN projects it to ``out_dim`` features.

    Args:
        out_dim: Size of the output embedding.
    """

    def __init__(self, out_dim=512):
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1), nn.BatchNorm2d(32), nn.GELU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.BatchNorm2d(64), nn.GELU(), nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128), nn.GELU(),
            nn.AdaptiveAvgPool2d((4, 4)),
        )
        self.proj = nn.Sequential(
            nn.Linear(128 * 4 * 4, out_dim), nn.GELU(), nn.Dropout(0.3),
        )

    def forward(self, x):
        """Return ``(N, out_dim)`` features for an ``(N, C, H, W)`` batch."""
        gray = x.mean(dim=1, keepdim=True)
        # Shift only the spatial axes. The default (dim=None) would also roll
        # the batch axis, mixing up images whenever N > 1.
        # NOTE(review): the checkpoint may have been trained with the old
        # all-dims shift; for the N == 1 inference below both are identical.
        spectrum = torch.fft.fftshift(torch.fft.fft2(gray), dim=(-2, -1))
        # log compresses the huge dynamic range of spectral magnitudes;
        # the epsilon avoids log(0).
        mag = torch.log(torch.abs(spectrum) + 1e-8)
        # Per-image (and per-channel) min-max normalisation to [0, 1].
        flat = mag.flatten(2)
        mn = flat.min(2)[0].unsqueeze(-1).unsqueeze(-1)
        mx = flat.max(2)[0].unsqueeze(-1).unsqueeze(-1)
        mag = (mag - mn) / (mx - mn + 1e-8)
        return self.proj(self.cnn(mag).flatten(1))


class CNNFFTDetector(nn.Module):
    """Real-vs-AI-generated classifier fusing RGB and spectral features.

    The EfficientNet-B0 pooled output (1280-d, projected to 512-d) and an
    :class:`FFTBranch` embedding (512-d) are concatenated and passed through
    an MLP that emits one logit per image.

    Args:
        pretrained: If True (default), load the backbone with the pretrained
            ``google/efficientnet-b0`` weights from the Hugging Face Hub. Pass
            False when a full checkpoint will be loaded afterwards. Only the
            config is then fetched, and the architecture (and therefore the
            ``state_dict`` keys) is the same.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        if pretrained:
            self.cnn = EfficientNetModel.from_pretrained(EFFICIENTNET_ID)
        else:
            self.cnn = EfficientNetModel(EfficientNetConfig.from_pretrained(EFFICIENTNET_ID))
        # Freeze the first 60% of backbone parameter tensors (registration
        # order) and leave the rest trainable; this has no effect on
        # inference under torch.no_grad().
        params = list(self.cnn.parameters())
        cutoff = int(len(params) * 0.6)
        for i, p in enumerate(params):
            p.requires_grad = i >= cutoff
        self.cnn_proj = nn.Sequential(nn.Linear(1280, 512), nn.GELU(), nn.Dropout(0.3))
        self.fft = FFTBranch(out_dim=512)
        self.classifier = nn.Sequential(
            nn.Linear(1024, 256), nn.GELU(), nn.Dropout(0.4),
            nn.Linear(256, 64), nn.GELU(), nn.Linear(64, 1),
        )

    def forward(self, x):
        """Return ``(N, 1)`` logits for an image batch (logit >= 0 means AI-generated)."""
        c = self.cnn_proj(self.cnn(x).pooler_output)
        f = self.fft(x)
        return self.classifier(torch.cat([c, f], dim=1))


print("Loading model...")
device = torch.device("cpu")
# The strict load_state_dict below overwrites every backbone weight, so skip
# downloading the pretrained EfficientNet weights.
model = CNNFFTDetector(pretrained=False).to(device)
# SECURITY: weights_only=False unpickles arbitrary objects and can execute
# code. Only load checkpoints from a trusted source.
ckpt = torch.load("best.pth", map_location=device, weights_only=False)
model.load_state_dict(ckpt["model_state"])
model.eval()
print(f"Model ready — {ckpt['best_val_acc']*100:.2f}%")

# Resize to 224x224, then map [0, 1] pixels to [-1, 1].
# NOTE(review): presumably this must match the training preprocessing; confirm.
tf = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5] * 3, [0.5] * 3),
])


def predict(image):
    """Classify one image as AI-generated/deepfake or real.

    Args:
        image: A PIL image, a NumPy array (converted with
            ``Image.fromarray``), or None when nothing was uploaded.

    Returns:
        A ``(confidences, verdict)`` pair. ``confidences`` is a
        ``{"AI Generated": p, "Real": 1 - p}`` dict for ``gr.Label``, and
        ``verdict`` is a Markdown summary for ``gr.Markdown``.
    """
    if image is None:
        # NOTE(review): missing input is reported as 100% "Real". Kept for
        # API compatibility; callers should rely on the message here.
        return {"AI Generated": 0.0, "Real": 1.0}, "Please upload an image"
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    image = image.convert("RGB")
    tensor = tf(image).unsqueeze(0).to(device)
    with torch.no_grad():
        score = torch.sigmoid(model(tensor)).item()
    fake_pct = round(score * 100, 1)
    real_pct = round((1 - score) * 100, 1)
    is_fake = score >= 0.5
    label = "AI Generated / Deepfake" if is_fake else "Real Image"
    icon = "🔴" if is_fake else "🟢"
    confidence = round(max(score, 1 - score) * 100, 1)
    # Two trailing spaces are a Markdown hard line break between the stats.
    verdict = (
        f"## {icon} {label}\n\n"
        f"**AI/Fake:** {fake_pct}%  \n"
        f"**Real:** {real_pct}%  \n"
        f"**Confidence:** {confidence}%"
    )
    return {"AI Generated": float(score), "Real": float(1 - score)}, verdict


# ── UI ────────────────────────────────────────────────────────────
with gr.Blocks(theme=gr.themes.Soft(), title="LunaNet") as demo:
    gr.Markdown("# 🌙 LunaNet — AI Image & Deepfake Detector\n**Revealing the Unseen** · CNN (EfficientNetB0) + FFT · 91.47% accuracy")
    with gr.Row():
        with gr.Column():
            img_input = gr.Image(type="pil", label="Upload Image")
            btn = gr.Button("✦ Analyse", variant="primary", size="lg")
        with gr.Column():
            label_out = gr.Label(num_top_classes=2, label="Detection Result")
            md_out = gr.Markdown(label="Verdict")
    # api_name makes it callable as /predict from external frontends
    btn.click(fn=predict, inputs=img_input, outputs=[label_out, md_out], api_name="predict")
    img_input.upload(fn=predict, inputs=img_input, outputs=[label_out, md_out])
    gr.Markdown("---\n**Training data:** CIFAKE · 140k Faces · OpenForensics · Celeb-DF v2")

if __name__ == "__main__":
    demo.launch(ssr_mode=False)