Spaces:
Sleeping
Sleeping
File size: 4,323 Bytes
21554ac 28513b8 21554ac 361ae5f aedd425 21554ac 4da7b85 21554ac 1968d99 361ae5f 21554ac 361ae5f 21554ac 361ae5f 21554ac 4da7b85 361ae5f 1968d99 361ae5f 1968d99 361ae5f 21554ac 361ae5f 1968d99 361ae5f 21554ac 361ae5f 21554ac aedd425 361ae5f 21554ac 1968d99 21554ac 1968d99 21554ac 361ae5f 1968d99 f5420d3 28513b8 aedd425 361ae5f 1968d99 361ae5f 28513b8 361ae5f 28513b8 361ae5f aedd425 28513b8 aedd425 28513b8 361ae5f d5f04bf 28513b8 361ae5f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 | import gradio as gr
import torch, torch.nn as nn, warnings
from torchvision import transforms
from transformers import EfficientNetModel
from PIL import Image
import numpy as np
warnings.filterwarnings("ignore")
# ── Model ─────────────────────────────────────────────────────────
class FFTBranch(nn.Module):
    """Frequency-domain feature branch.

    The input is averaged over dim 1 into a single channel, converted to a
    centred 2-D FFT log-magnitude map, min-max scaled per sample, and passed
    through a small CNN plus a linear projection.

    Args:
        out_dim: Size of the feature vector produced for each sample.
    """

    def __init__(self, out_dim=512):
        super().__init__()
        # Kept as one flat Sequential so parameter names (cnn.0 ... cnn.11)
        # stay compatible with previously saved state dicts.
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1), nn.BatchNorm2d(32), nn.GELU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.BatchNorm2d(64), nn.GELU(), nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128), nn.GELU(),
            nn.AdaptiveAvgPool2d((4, 4)),
        )
        self.proj = nn.Sequential(
            nn.Linear(128 * 4 * 4, out_dim),
            nn.GELU(),
            nn.Dropout(0.3),
        )

    def forward(self, x):
        """Return a ``(batch, out_dim)`` feature tensor for a 4-D input ``x``."""
        gray = x.mean(dim=1, keepdim=True)
        # Shift only the two spatial axes. Without ``dim`` fftshift rolls
        # every axis, including the batch axis, which pairs each spectrum
        # with the wrong sample whenever the batch holds more than one image.
        spectrum = torch.fft.fftshift(torch.fft.fft2(gray), dim=(-2, -1))
        # The epsilon keeps log() finite where the magnitude is exactly zero.
        log_mag = torch.log(torch.abs(spectrum) + 1e-8)
        # Per-sample min-max scaling; the epsilon guards a flat spectrum.
        flat = log_mag.flatten(2)
        lo = flat.min(2)[0].unsqueeze(-1).unsqueeze(-1)
        hi = flat.max(2)[0].unsqueeze(-1).unsqueeze(-1)
        scaled = (log_mag - lo) / (hi - lo + 1e-8)
        return self.proj(self.cnn(scaled).flatten(1))
class CNNFFTDetector(nn.Module):
    """Two-branch binary classifier: EfficientNet-B0 features plus FFT features.

    The pooled backbone output (1280 values) and the FFT branch output are
    each projected to 512 features, concatenated, and reduced by an MLP to a
    single logit per sample.
    """

    def __init__(self):
        super().__init__()
        self.cnn = EfficientNetModel.from_pretrained("google/efficientnet-b0")

        # Freeze the first 60% of the backbone's parameter tensors (in
        # iteration order) and leave the remainder trainable.
        backbone_params = list(self.cnn.parameters())
        first_trainable = int(len(backbone_params) * 0.6)
        for index, param in enumerate(backbone_params):
            param.requires_grad = not (index < first_trainable)

        self.cnn_proj = nn.Sequential(
            nn.Linear(1280, 512),
            nn.GELU(),
            nn.Dropout(0.3),
        )
        self.fft = FFTBranch(out_dim=512)
        self.classifier = nn.Sequential(
            nn.Linear(1024, 256),
            nn.GELU(),
            nn.Dropout(0.4),
            nn.Linear(256, 64),
            nn.GELU(),
            nn.Linear(64, 1),
        )

    def forward(self, x):
        """Return one raw logit per sample; callers apply the sigmoid."""
        pooled = self.cnn(x).pooler_output
        spatial_feats = self.cnn_proj(pooled)
        spectral_feats = self.fft(x)
        fused = torch.cat([spatial_feats, spectral_feats], dim=1)
        return self.classifier(fused)
# Build the detector on CPU and restore the trained weights from best.pth.
print("Loading model...")
device = torch.device("cpu")
model = CNNFFTDetector().to(device)
# NOTE(security): weights_only=False lets torch.load unpickle arbitrary Python
# objects. Only ever point this at a checkpoint you produced yourself; switch
# to weights_only=True if the checkpoint contents allow it.
ckpt = torch.load("best.pth", map_location=device, weights_only=False)
# The checkpoint is read as a dict with "model_state" and "best_val_acc".
model.load_state_dict(ckpt["model_state"])
model.eval()  # disable dropout and use BatchNorm running statistics
print(f"Model ready — {ckpt['best_val_acc']*100:.2f}%")

# Preprocessing applied to every uploaded image before inference.
tf = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5] * 3, [0.5] * 3),  # mean 0.5, std 0.5 per channel
])
def _format_verdict(score):
    """Render the Markdown verdict for a fake-probability ``score`` in [0, 1]."""
    is_fake = score >= 0.5
    icon = "🔴" if is_fake else "🟢"
    label = "AI Generated / Deepfake" if is_fake else "Real Image"
    fake_pct = round(score * 100, 1)
    real_pct = round((1 - score) * 100, 1)
    confidence = round(max(score, 1 - score) * 100, 1)
    # Two trailing spaces before "\n" form a Markdown hard line break; with a
    # single space the three figures would collapse into one paragraph.
    return (
        f"## {icon} {label}\n\n"
        f"**AI/Fake:** {fake_pct}%  \n"
        f"**Real:** {real_pct}%  \n"
        f"**Confidence:** {confidence}%"
    )


def predict(image):
    """Classify an image as AI-generated or real.

    Args:
        image: A PIL image, a NumPy array, or ``None`` when nothing has been
            uploaded.

    Returns:
        A ``(scores, verdict)`` tuple. ``scores`` maps ``"AI Generated"`` and
        ``"Real"`` to probabilities that sum to 1; ``verdict`` is a Markdown
        summary string.
    """
    if image is None:
        return {"AI Generated": 0.0, "Real": 1.0}, "Please upload an image"
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    # Force three channels so greyscale/RGBA uploads match the transform.
    image = image.convert("RGB")
    batch = tf(image).unsqueeze(0).to(device)
    with torch.no_grad():
        score = torch.sigmoid(model(batch)).item()
    scores = {"AI Generated": float(score), "Real": float(1 - score)}
    return scores, _format_verdict(score)
# ── UI ────────────────────────────────────────────────────────────
# Two-column Gradio layout: upload + button on the left, results on the right.
with gr.Blocks(theme=gr.themes.Soft(), title="LunaNet") as demo:
    # NOTE(review): the leading icon was restored from a mis-encoded 4-byte
    # sequence; confirm 🌙 is the intended glyph.
    gr.Markdown(
        "# 🌙 LunaNet — AI Image & Deepfake Detector\n"
        "**Revealing the Unseen** · CNN (EfficientNetB0) + FFT · 91.47% accuracy"
    )
    with gr.Row():
        with gr.Column():
            img_input = gr.Image(type="pil", label="Upload Image")
            # NOTE(review): button glyph restored from mis-encoded text; confirm.
            btn = gr.Button("✦ Analyse", variant="primary", size="lg")
        with gr.Column():
            label_out = gr.Label(num_top_classes=2, label="Detection Result")
            md_out = gr.Markdown(label="Verdict")
    # api_name makes it callable as /predict from external frontends
    btn.click(
        fn=predict,
        inputs=img_input,
        outputs=[label_out, md_out],
        api_name="predict",
    )
    # Also run the detector as soon as a file is uploaded.
    img_input.upload(fn=predict, inputs=img_input, outputs=[label_out, md_out])
    gr.Markdown(
        "---\n**Training data:** CIFAKE · 140k Faces · OpenForensics · Celeb-DF v2"
    )

demo.launch(ssr_mode=False)