File size: 4,323 Bytes
21554ac
 
 
 
 
28513b8
21554ac
361ae5f
aedd425
21554ac
4da7b85
 
21554ac
 
 
 
 
 
1968d99
361ae5f
21554ac
361ae5f
 
 
 
 
 
21554ac
361ae5f
21554ac
4da7b85
 
361ae5f
 
1968d99
361ae5f
 
 
 
1968d99
 
361ae5f
21554ac
361ae5f
 
1968d99
361ae5f
21554ac
 
 
361ae5f
21554ac
 
aedd425
361ae5f
21554ac
1968d99
21554ac
1968d99
21554ac
361ae5f
1968d99
f5420d3
28513b8
 
aedd425
361ae5f
1968d99
 
 
361ae5f
 
28513b8
361ae5f
28513b8
361ae5f
 
 
 
aedd425
 
 
28513b8
aedd425
28513b8
 
361ae5f
 
d5f04bf
28513b8
361ae5f
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import gradio as gr
import torch, torch.nn as nn, warnings
from torchvision import transforms
from transformers import EfficientNetModel
from PIL import Image
import numpy as np
warnings.filterwarnings("ignore")
 
# ── Model ─────────────────────────────────────────────────────────
class FFTBranch(nn.Module):
    """Frequency-domain feature extractor.

    Collapses the input to one channel (mean over dim 1), takes the centred
    log-magnitude of its 2-D FFT, min-max normalises that spectrum per
    sample, and encodes it with a small CNN followed by a linear projection.

    Args:
        out_dim: Width of the feature vector produced by the projection head.
    """

    def __init__(self, out_dim=512):
        super().__init__()
        # Attribute names and layer order are kept as-is so that existing
        # checkpoints (state_dict keys like "cnn.0.weight") keep loading.
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1), nn.BatchNorm2d(32), nn.GELU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.BatchNorm2d(64), nn.GELU(), nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128), nn.GELU(),
            nn.AdaptiveAvgPool2d((4, 4)),
        )
        self.proj = nn.Sequential(
            nn.Linear(128 * 4 * 4, out_dim), nn.GELU(), nn.Dropout(0.3)
        )

    def forward(self, x):
        """Encode the frequency content of ``x``.

        Args:
            x: Image batch laid out as (N, C, H, W).

        Returns:
            Tensor of shape (N, out_dim).
        """
        gray = x.mean(dim=1, keepdim=True)
        # Shift only the two spatial axes. Without ``dim`` fftshift rolls every
        # axis, including the batch axis, which pairs each sample with another
        # sample's spectrum whenever N > 1.
        spectrum = torch.fft.fftshift(torch.fft.fft2(gray), dim=(-2, -1))
        mag = torch.log(torch.abs(spectrum) + 1e-8)
        # Per-sample min-max scaling to [0, 1]; the epsilon guards a flat spectrum.
        flat = mag.flatten(2)
        mn = flat.min(2)[0].unsqueeze(-1).unsqueeze(-1)
        mx = flat.max(2)[0].unsqueeze(-1).unsqueeze(-1)
        mag = (mag - mn) / (mx - mn + 1e-8)
        return self.proj(self.cnn(mag).flatten(1))
 
class CNNFFTDetector(nn.Module):
    """Two-branch detector fusing EfficientNet-B0 features with FFT features.

    The pooled backbone output is projected to 512 dimensions, concatenated
    with the 512-dimensional ``FFTBranch`` output, and passed through an MLP
    head that emits one raw value per sample (no activation is applied here).
    """

    def __init__(self):
        super().__init__()
        self.cnn = EfficientNetModel.from_pretrained("google/efficientnet-b0")

        # Freeze the leading 60% of the backbone's parameter tensors; the
        # remaining ones stay trainable.
        backbone_params = list(self.cnn.parameters())
        first_trainable = int(len(backbone_params) * 0.6)
        for index, param in enumerate(backbone_params):
            param.requires_grad = index >= first_trainable

        self.cnn_proj = nn.Sequential(
            nn.Linear(1280, 512),
            nn.GELU(),
            nn.Dropout(0.3),
        )
        self.fft = FFTBranch(out_dim=512)
        self.classifier = nn.Sequential(
            nn.Linear(1024, 256),
            nn.GELU(),
            nn.Dropout(0.4),
            nn.Linear(256, 64),
            nn.GELU(),
            nn.Linear(64, 1),
        )

    def forward(self, x):
        """Return the head's output for ``x``, one value per sample."""
        pooled = self.cnn(x).pooler_output
        spatial_features = self.cnn_proj(pooled)
        frequency_features = self.fft(x)
        fused = torch.cat([spatial_features, frequency_features], dim=1)
        return self.classifier(fused)
 
# Build the detector on CPU and restore the trained weights from the local
# checkpoint file.
print("Loading model...")
device = torch.device("cpu")
model  = CNNFFTDetector().to(device)
# NOTE(review): weights_only=False unpickles arbitrary Python objects, which is
# only safe while "best.pth" is a trusted file shipped with the app. Switch to
# weights_only=True if the checkpoint turns out to hold only tensors and
# plain Python values.
ckpt   = torch.load("best.pth", map_location=device, weights_only=False)
model.load_state_dict(ckpt["model_state"])
model.eval()  # disable dropout and use BatchNorm running statistics
print(f"Model ready — {ckpt['best_val_acc']*100:.2f}%")

# Preprocessing applied to every uploaded image: resize to 224x224, convert to
# a tensor, then normalise each channel as (x - 0.5) / 0.5.
tf = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5] * 3, [0.5] * 3),
])
 
def predict(image):
    """Classify an image as AI-generated or real.

    Args:
        image: A PIL image, a NumPy array (converted with
            ``Image.fromarray``), or ``None`` when nothing was uploaded.

    Returns:
        A ``(scores, verdict)`` tuple. ``scores`` maps "AI Generated" and
        "Real" to floats that sum to 1; ``verdict`` is a Markdown string
        summarising the decision. For ``None`` input a placeholder score
        dict and a prompt to upload an image are returned.
    """
    if image is None:
        return {"AI Generated": 0.0, "Real": 1.0}, "Please upload an image"
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    image = image.convert("RGB")
    tensor = tf(image).unsqueeze(0).to(device)
    with torch.no_grad():
        # The model emits a raw value; sigmoid turns it into P(AI-generated).
        score = torch.sigmoid(model(tensor)).item()

    is_fake = score >= 0.5
    fake_pct = round(score * 100, 1)
    real_pct = round((1 - score) * 100, 1)
    confidence = round(max(score, 1 - score) * 100, 1)
    label = "AI Generated / Deepfake" if is_fake else "Real Image"
    icon = "🔴" if is_fake else "🟢"
    # Two trailing spaces before "\n" force a Markdown line break.
    verdict = (
        f"## {icon} {label}\n\n"
        f"**AI/Fake:** {fake_pct}%  \n"
        f"**Real:** {real_pct}%  \n"
        f"**Confidence:** {confidence}%"
    )
    return {"AI Generated": float(score), "Real": float(1 - score)}, verdict
 
# ── UI ────────────────────────────────────────────────────────────
# Two-column layout: image upload and trigger button on the left, the class
# scores and the Markdown verdict on the right.
with gr.Blocks(theme=gr.themes.Soft(), title="LunaNet") as demo:
    gr.Markdown("# 🌙 LunaNet — AI Image & Deepfake Detector\n**Revealing the Unseen** · CNN (EfficientNetB0) + FFT · 91.47% accuracy")
    with gr.Row():
        with gr.Column():
            img_input = gr.Image(type="pil", label="Upload Image")
            btn = gr.Button("✦ Analyse", variant="primary", size="lg")
        with gr.Column():
            label_out = gr.Label(num_top_classes=2, label="Detection Result")
            md_out    = gr.Markdown(label="Verdict")

    # api_name makes it callable as /predict from external frontends
    btn.click(fn=predict, inputs=img_input, outputs=[label_out, md_out], api_name="predict")
    # Also run the detector as soon as an image is uploaded, without a click.
    img_input.upload(fn=predict, inputs=img_input, outputs=[label_out, md_out])

    gr.Markdown("---\n**Training data:** CIFAKE · 140k Faces · OpenForensics · Celeb-DF v2")

demo.launch(ssr_mode=False)