File size: 3,957 Bytes
0b31297
 
113fb27
 
 
0b31297
 
 
113fb27
0b31297
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113fb27
 
 
 
0b31297
 
 
 
 
 
 
113fb27
0b31297
113fb27
0b31297
 
113fb27
0b31297
 
 
113fb27
 
 
 
0b31297
 
 
 
 
113fb27
0b31297
 
 
 
113fb27
0b31297
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113fb27
0b31297
 
 
113fb27
0b31297
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
#!/usr/bin/env python
# ──────────────────────────────────────────────────────────────
#  BubbleAI Image-Safety Detector – final checkpoint-aligned build
#  (matches classifier.0 / classifier.3 keys)
#  Coder: Amir Mehdi Memari – 2025-08-06
# ──────────────────────────────────────────────────────────────

from __future__ import annotations
import pathlib, typing as t

import torch
import torchvision
from torchvision import transforms
from PIL import Image
import gradio as gr

# ── 1. Paths & device ─────────────────────────────────────────
# Anchor file lookups to the directory that contains this script.
REPO_DIR = pathlib.Path(__file__).parent
CKPT_PATH = REPO_DIR.joinpath("resnet_safety_classifier.pth")

# Use the GPU when PyTorch reports one is available, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# ── 2. Architecture that matches checkpoint keys ──────────────
class SafetyResNet(torch.nn.Module):
    """Two-class (Safe / Unsafe) image classifier on a ResNet-50 trunk.

    Layout -- the attribute names and Sequential indices are load-bearing,
    because they form the state_dict keys stored in the checkpoint:

        feature_extractor   ResNet-50 conv1 -> bn1 -> relu -> maxpool
                            -> layer1 .. layer4   (avgpool / fc dropped)
        pool                global average pool to 1x1
        classifier          [0] Linear(2048 -> 512)
                            [1] ReLU
                            [2] Dropout(p=0.30)
                            [3] Linear(512 -> 2)

    Only classifier[0] and classifier[3] own parameters, which is why the
    checkpoint contains ``classifier.0.*`` and ``classifier.3.*`` keys.
    """

    def __init__(self) -> None:
        super().__init__()

        # Randomly initialised ResNet-50 (weights come from the checkpoint).
        # Its first eight children are conv1, bn1, relu, maxpool and
        # layer1..layer4, i.e. everything before avgpool and fc.
        trunk_layers = list(torchvision.models.resnet50(weights=None).children())
        self.feature_extractor = torch.nn.Sequential(*trunk_layers[:8])
        self.pool = torch.nn.AdaptiveAvgPool2d((1, 1))

        # List order fixes the Sequential indices 0..3 -- do not reorder.
        head = [
            torch.nn.Linear(2048, 512, bias=True),  # classifier.0
            torch.nn.ReLU(inplace=True),            # classifier.1 (no params)
            torch.nn.Dropout(p=0.30),               # classifier.2 (no params)
            torch.nn.Linear(512, 2, bias=True),     # classifier.3
        ]
        self.classifier = torch.nn.Sequential(*head)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw (pre-softmax) class logits for an image batch.

        Args:
            x: image batch, presumably (B, 3, H, W) normalised RGB -- TODO
               confirm for callers other than this app's preprocessing.

        Returns:
            Logits of shape (B, 2).
        """
        pooled = self.pool(self.feature_extractor(x))          # (B, 2048, 1, 1)
        return self.classifier(pooled.flatten(start_dim=1))    # (B, 2)

# ── 3. Instantiate & load weights ─────────────────────────────
model = SafetyResNet().to(DEVICE)

# The checkpoint is a plain state_dict (it is handed straight to
# load_state_dict below), so restrict unpickling to tensors and basic
# containers. weights_only=True is the default from torch 2.6 onward;
# passing it explicitly stops older torch releases from executing
# arbitrary pickled code and silences their FutureWarning.
state = torch.load(CKPT_PATH, map_location=DEVICE, weights_only=True)

# strict=True: checkpoint keys must match the module's parameters and
# buffers exactly, so a layout mismatch raises instead of silently leaving
# layers randomly initialised.
model.load_state_dict(state, strict=True)

# Inference mode: Dropout becomes a no-op and BatchNorm uses its running stats.
model.eval()

# Label for each logit index (index i <-> logit i). The order is assumed to
# match the labels used in training -- verify against the training code.
CLASSES = ["Safe", "Unsafe"]

# ── 4. Preprocessing (ImageNet stats) ─────────────────────────
# Per-channel (R, G, B) normalisation constants of ImageNet.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

preprocess = transforms.Compose(
    [
        transforms.Resize(256),      # shorter side -> 256 px, aspect ratio kept
        transforms.CenterCrop(224),  # central 224x224 patch
        transforms.ToTensor(),       # PIL image -> float CHW tensor in [0, 1]
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)

# ── 5. Inference helper ───────────────────────────────────────
@torch.inference_mode()
def predict(img: t.Optional[Image.Image]) -> t.Dict[str, float]:
    """Classify one image and return the probability of each class.

    Args:
        img: PIL image from the upload widget, or ``None`` when the user
            presses Submit without providing an image.

    Returns:
        Mapping from every label in ``CLASSES`` to its softmax probability.

    Raises:
        gr.Error: when no image was supplied, so Gradio shows a readable
            message instead of a stack trace.
    """
    if img is None:
        raise gr.Error("Please upload an image first.")

    # Normalize() uses 3-channel statistics; RGBA, greyscale or palette
    # images would otherwise fail inside the transform pipeline.
    if img.mode != "RGB":
        img = img.convert("RGB")

    batch = preprocess(img).unsqueeze(0).to(DEVICE)   # add batch dimension
    logits = model(batch)[0]                          # logits of the single image
    probs = torch.softmax(logits, dim=0).cpu().tolist()

    # Follow CLASSES rather than a hard-coded count, so labels and
    # probabilities stay paired.
    return dict(zip(CLASSES, probs))

# ── 6. Gradio interface ───────────────────────────────────────
# One PIL image in; a label widget showing both class probabilities out.
_upload = gr.Image(type="pil", label="Upload an image")
_verdict = gr.Label(num_top_classes=2, label="Prediction")

demo = gr.Interface(
    fn=predict,
    inputs=_upload,
    outputs=_verdict,
    title="BubbleAI Image-Safety Detector",
    description="Drag an image here to see Safe vs Unsafe probabilities.",
    cache_examples=False,  # no examples are configured, so nothing to cache
)

# ── 7. Launch (HF Space calls this automatically) ─────────────
if __name__ == "__main__":
    # Start the web server only when the file is executed as a script,
    # not when it is imported as a module.
    demo.launch()