#!/usr/bin/env python # ────────────────────────────────────────────────────────────── # BubbleAI Image-Safety Detector – final checkpoint-aligned build # (matches classifier.0 / classifier.3 keys) # Coder: Amir Mehdi Memari – 2025-08-06 # ────────────────────────────────────────────────────────────── from __future__ import annotations import pathlib, typing as t import torch import torchvision from torchvision import transforms from PIL import Image import gradio as gr # ── 1. Paths & device ───────────────────────────────────────── REPO_DIR = pathlib.Path(__file__).parent CKPT_PATH = REPO_DIR / "resnet_safety_classifier.pth" DEVICE = "cuda" if torch.cuda.is_available() else "cpu" # ── 2. Architecture that matches checkpoint keys ────────────── class SafetyResNet(torch.nn.Module): """ ResNet-50 backbone (conv1 ▸ layer4) + global-avg-pool ➜ Linear(2048→512) ➜ ReLU ➜ Dropout ➜ Linear(512→2) Stored in checkpoint as: feature_extractor.* , classifier.0.* , classifier.3.* """ def __init__(self) -> None: super().__init__() base = torchvision.models.resnet50(weights=None) self.feature_extractor = torch.nn.Sequential(*list(base.children())[:8]) self.pool = torch.nn.AdaptiveAvgPool2d((1, 1)) self.classifier = torch.nn.Sequential( torch.nn.Linear(2048, 512, bias=True), # index 0 ← matches checkpoint torch.nn.ReLU(inplace=True), torch.nn.Dropout(p=0.30), torch.nn.Linear(512, 2, bias=True) # index 3 ← matches checkpoint ) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.feature_extractor(x) # (B, 2048, H/32, W/32) x = self.pool(x) # (B, 2048, 1, 1) x = torch.flatten(x, 1) # (B, 2048) x = self.classifier(x) # (B, 2) return x # ── 3. Instantiate & load weights ───────────────────────────── model = SafetyResNet().to(DEVICE) state = torch.load(CKPT_PATH, map_location=DEVICE) model.load_state_dict(state, strict=True) # ← should succeed now model.eval() CLASSES = ["Safe", "Unsafe"] # ── 4. 
# ── 4. Preprocessing (ImageNet stats) ─────────────────────────
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])


# ── 5. Inference helper ───────────────────────────────────────
@torch.inference_mode()
def predict(img: Image.Image) -> t.Dict[str, float]:
    """
    Classify one uploaded image as Safe vs Unsafe.

    Parameters
    ----------
    img : PIL.Image.Image
        Image in any PIL mode (RGBA, palette, grayscale, …).

    Returns
    -------
    dict
        {"Safe": p, "Unsafe": 1-p} softmax probabilities.
    """
    # Force 3-channel RGB first: ToTensor on an RGBA/L/P upload would give
    # the wrong channel count and crash the 3-channel Normalize / conv1.
    tensor = preprocess(img.convert("RGB")).unsqueeze(0).to(DEVICE)
    probs = torch.softmax(model(tensor)[0], dim=0).cpu().tolist()
    return dict(zip(CLASSES, map(float, probs)))


# ── 6. Gradio interface ───────────────────────────────────────
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload an image"),
    outputs=gr.Label(num_top_classes=2, label="Prediction"),
    title="BubbleAI Image-Safety Detector",
    description="Drag an image here to see Safe vs Unsafe probabilities.",
    cache_examples=False,
)

# ── 7. Launch (HF Space calls this automatically) ─────────────
if __name__ == "__main__":
    demo.launch()