# NOTE: the three lines below were Hugging Face web-page residue accidentally
# committed with the file ("MetiMiester's picture / Update app.py /
# 113fb27 verified"); kept here as a comment so the module parses.
#!/usr/bin/env python
# ──────────────────────────────────────────────────────────────
# BubbleAI Image-Safety Detector – final checkpoint-aligned build
# (matches classifier.0 / classifier.3 keys)
# Coder: Amir Mehdi Memari – 2025-08-06
# ──────────────────────────────────────────────────────────────
from __future__ import annotations
import pathlib, typing as t
import torch
import torchvision
from torchvision import transforms
from PIL import Image
import gradio as gr
# ── 1. Paths & device ─────────────────────────────────────────
REPO_DIR = pathlib.Path(__file__).parent                 # folder holding this script
CKPT_PATH = REPO_DIR / "resnet_safety_classifier.pth"    # fine-tuned weights file
_has_gpu = torch.cuda.is_available()
DEVICE = "cuda" if _has_gpu else "cpu"                   # prefer GPU when present
# ── 2. Architecture that matches checkpoint keys ──────────────
class SafetyResNet(torch.nn.Module):
    """Binary image-safety classifier.

    A ResNet-50 backbone (conv1 through layer4, no avgpool/fc) feeds a
    global average pool, then a small MLP head:

        Linear(2048 -> 512) -> ReLU -> Dropout(0.3) -> Linear(512 -> 2)

    The checkpoint stores parameters under these key prefixes, which the
    attribute names below must match exactly:

        feature_extractor.*   (backbone)
        classifier.0.*        (first Linear)
        classifier.3.*        (second Linear)
    """

    def __init__(self) -> None:
        super().__init__()
        backbone = torchvision.models.resnet50(weights=None)
        # Keep children up to and including layer4; drop avgpool and fc.
        trunk = list(backbone.children())[:8]
        self.feature_extractor = torch.nn.Sequential(*trunk)
        self.pool = torch.nn.AdaptiveAvgPool2d((1, 1))
        head = [
            torch.nn.Linear(2048, 512, bias=True),  # 'classifier.0' in checkpoint
            torch.nn.ReLU(inplace=True),
            torch.nn.Dropout(p=0.30),
            torch.nn.Linear(512, 2, bias=True),     # 'classifier.3' in checkpoint
        ]
        self.classifier = torch.nn.Sequential(*head)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a batch of images (B, 3, H, W) to raw logits (B, 2)."""
        feats = self.feature_extractor(x)  # (B, 2048, H/32, W/32)
        pooled = self.pool(feats)          # (B, 2048, 1, 1)
        flat = torch.flatten(pooled, 1)    # (B, 2048)
        return self.classifier(flat)
# ── 3. Instantiate & load weights ─────────────────────────────
model = SafetyResNet().to(DEVICE)
# weights_only=True restricts unpickling to tensors/containers, so a
# tampered checkpoint cannot execute arbitrary code; a plain state dict
# (as saved here) loads identically.  Requires torch >= 1.13.
state = torch.load(CKPT_PATH, map_location=DEVICE, weights_only=True)
model.load_state_dict(state, strict=True)  # keys must match the class exactly
model.eval()  # disable dropout for deterministic inference

# Output index -> human-readable label (index 0 = Safe, index 1 = Unsafe).
CLASSES = ["Safe", "Unsafe"]
# ── 4. Preprocessing (standard ImageNet statistics) ───────────
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
preprocess = transforms.Compose([
    transforms.Resize(256),       # shorter side -> 256 px
    transforms.CenterCrop(224),   # 224x224 centre crop
    transforms.ToTensor(),        # HWC uint8 -> CHW float in [0, 1]
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])
# ── 5. Inference helper ───────────────────────────────────────
@torch.inference_mode()
def predict(img: Image.Image) -> t.Dict[str, float]:
    """Classify one PIL image and return ``{label: probability}``.

    The image is converted to RGB first: ``preprocess`` normalizes with
    3-channel statistics, so a grayscale or RGBA upload would otherwise
    crash on a channel-count mismatch.  ``convert`` is a no-op for RGB.
    """
    tensor = preprocess(img.convert("RGB")).unsqueeze(0).to(DEVICE)
    logits = model(tensor)[0]                      # (2,) raw scores
    probs = torch.softmax(logits, dim=0).cpu().tolist()
    return {label: float(p) for label, p in zip(CLASSES, probs)}
# ── 6. Gradio interface ───────────────────────────────────────
_image_input = gr.Image(type="pil", label="Upload an image")
_label_output = gr.Label(num_top_classes=2, label="Prediction")
demo = gr.Interface(
    fn=predict,
    inputs=_image_input,
    outputs=_label_output,
    title="BubbleAI Image-Safety Detector",
    description="Drag an image here to see Safe vs Unsafe probabilities.",
    cache_examples=False,
)

# ── 7. Launch ─────────────────────────────────────────────────
# Hugging Face Spaces imports this module and serves `demo` itself; the
# guard lets the script also be run directly during local development.
if __name__ == "__main__":
    demo.launch()