#!/usr/bin/env python
# ──────────────────────────────────────────────────────────────
# BubbleAI Image-Safety Detector — final checkpoint-aligned build
# (matches classifier.0 / classifier.3 keys)
# Coder: Amir Mehdi Memari — 2025-08-06
# ──────────────────────────────────────────────────────────────
from __future__ import annotations
import pathlib, typing as t
import torch
import torchvision
from torchvision import transforms
from PIL import Image
import gradio as gr
# ── 1. Paths & device ─────────────────────────────────────────
REPO_DIR = pathlib.Path(__file__).parent  # directory holding this script
# Checkpoint is expected to sit next to the script (repo root on HF Spaces).
CKPT_PATH = REPO_DIR / "resnet_safety_classifier.pth"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # prefer GPU when present
# ── 2. Architecture that matches checkpoint keys ──────────────
class SafetyResNet(torch.nn.Module):
    """Binary image-safety classifier on a ResNet-50 backbone.

    Layer layout is chosen so state-dict keys line up exactly with the
    checkpoint:
        feature_extractor.*  -- ResNet-50 conv1 ... layer4 (no avgpool/fc)
        classifier.0.*       -- Linear(2048 -> 512)
        classifier.3.*       -- Linear(512 -> 2)
    """

    def __init__(self) -> None:
        super().__init__()
        backbone = torchvision.models.resnet50(weights=None)
        # First 8 children = conv1, bn1, relu, maxpool, layer1..layer4;
        # this drops the stock avgpool and fc head.
        trunk = list(backbone.children())[:8]
        self.feature_extractor = torch.nn.Sequential(*trunk)
        self.pool = torch.nn.AdaptiveAvgPool2d((1, 1))
        # ReLU/Dropout occupy indices 1 and 2 so the two Linear layers land
        # at indices 0 and 3, matching classifier.0.* / classifier.3.* keys.
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(2048, 512, bias=True),
            torch.nn.ReLU(inplace=True),
            torch.nn.Dropout(p=0.30),
            torch.nn.Linear(512, 2, bias=True),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map an image batch (B, 3, H, W) to raw logits of shape (B, 2)."""
        feats = self.feature_extractor(x)  # (B, 2048, H/32, W/32)
        pooled = self.pool(feats)          # (B, 2048, 1, 1)
        flat = torch.flatten(pooled, 1)    # (B, 2048)
        return self.classifier(flat)       # (B, 2)
# ── 3. Instantiate & load weights ─────────────────────────────
model = SafetyResNet().to(DEVICE)
# weights_only=True restricts unpickling to tensors/containers — sufficient
# for a plain state_dict and guards against arbitrary-code-execution
# payloads hidden in a tampered checkpoint file.
state = torch.load(CKPT_PATH, map_location=DEVICE, weights_only=True)
model.load_state_dict(state, strict=True)  # strict: every key must match the architecture
model.eval()  # disable dropout for deterministic inference

# Index i of the logit/probability vector corresponds to CLASSES[i].
CLASSES = ["Safe", "Unsafe"]
# ── 4. Preprocessing (ImageNet stats) ─────────────────────────
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Standard ImageNet eval pipeline: shorter side to 256, centre 224 crop,
# scale to [0, 1], then normalise with the backbone's training statistics.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])
# ── 5. Inference helper ───────────────────────────────────────
@torch.inference_mode()
def predict(img: Image.Image) -> t.Dict[str, float]:
    """Classify one PIL image and return {class name: probability}.

    Converts the input to RGB first: Gradio can hand over grayscale ("L"),
    RGBA, or palette images, which would otherwise break the 3-channel
    Normalize step in `preprocess`.
    """
    tensor = preprocess(img.convert("RGB")).unsqueeze(0).to(DEVICE)
    probs = torch.softmax(model(tensor)[0], dim=0).cpu().tolist()
    return {name: float(p) for name, p in zip(CLASSES, probs)}
# ── 6. Gradio interface ───────────────────────────────────────
# Components are named up front so the Interface call reads declaratively.
_image_input = gr.Image(type="pil", label="Upload an image")
_label_output = gr.Label(num_top_classes=2, label="Prediction")

demo = gr.Interface(
    fn=predict,
    inputs=_image_input,
    outputs=_label_output,
    title="BubbleAI Image-Safety Detector",
    description="Drag an image here to see Safe vs Unsafe probabilities.",
    cache_examples=False,
)
# ── 7. Launch (HF Space calls this automatically) ─────────────
if __name__ == "__main__":
    # Local runs start the Gradio server here; on HF Spaces the platform
    # imports this module and serves `demo` itself.
    demo.launch()