|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from __future__ import annotations |
|
|
import pathlib, typing as t |
|
|
|
|
|
import torch |
|
|
import torchvision |
|
|
from torchvision import transforms |
|
|
from PIL import Image |
|
|
import gradio as gr |
|
|
|
|
|
|
|
|
# Directory containing this script; the checkpoint is looked up next to it so
# the app does not depend on the current working directory.
REPO_DIR = pathlib.Path(__file__).parent

# Serialized weights for SafetyResNet (read once at import time).
CKPT_PATH = REPO_DIR / "resnet_safety_classifier.pth"

# Use the GPU when PyTorch reports one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
|
|
|
|
class SafetyResNet(torch.nn.Module):
    """Two-class image classifier built on a ResNet-50 trunk.

    Architecture::

        ResNet-50 backbone (conv1 .. layer4) -> global average pool
            -> Linear(2048 -> 512) -> ReLU -> Dropout(0.30) -> Linear(512 -> 2)

    The attribute names and the positions inside ``classifier`` determine the
    state-dict keys (``feature_extractor.*``, ``classifier.0.*``,
    ``classifier.3.*``), so they must not be renamed or reordered without
    re-exporting the checkpoint.
    """

    def __init__(self) -> None:
        """Build the backbone, the pooling layer and the classification head."""
        super().__init__()

        # weights=None builds the architecture only; no pretrained weights
        # are loaded here.
        base = torchvision.models.resnet50(weights=None)

        # Keep the first eight children (the convolutional trunk) and drop
        # torchvision's own average pool and fully connected layer.
        self.feature_extractor = torch.nn.Sequential(*list(base.children())[:8])
        self.pool = torch.nn.AdaptiveAvgPool2d((1, 1))

        # Only indices 0 and 3 hold parameters, hence the checkpoint keys
        # classifier.0.* and classifier.3.*.
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(2048, 512, bias=True),
            torch.nn.ReLU(inplace=True),
            torch.nn.Dropout(p=0.30),
            torch.nn.Linear(512, 2, bias=True),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw (pre-softmax) class scores for a batch.

        Args:
            x: Batched input tensor; presumably ``(N, 3, H, W)`` images as
                expected by the ResNet stem -- confirm against callers.

        Returns:
            One row of class scores per batch element.
        """
        features = self.feature_extractor(x)
        # Collapse the spatial grid to 1x1, then flatten everything after the
        # batch dimension into a single feature vector per sample.
        pooled = torch.flatten(self.pool(features), 1)
        return self.classifier(pooled)
|
|
|
|
|
|
|
|
# Build the network on the target device and restore the trained weights.
model = SafetyResNet().to(DEVICE)

# NOTE(review): torch.load unpickles the file, so only point CKPT_PATH at a
# checkpoint you trust. On PyTorch versions that support it, consider passing
# weights_only=True, since only a state dict is needed here.
state = torch.load(CKPT_PATH, map_location=DEVICE)

# strict=True makes a key mismatch between checkpoint and module an error
# instead of silently leaving layers uninitialised.
model.load_state_dict(state, strict=True)

# Inference only: eval() switches Dropout (and BatchNorm) to evaluation mode.
model.eval()

# predict() pairs CLASSES[i] with output score i, so the order matters.
CLASSES = ["Safe", "Unsafe"]

# Input pipeline applied to every image before it reaches the model:
# resize, centre-crop to 224x224, convert to a tensor, then normalise each of
# the three channels. The mean/std values match the ImageNet statistics
# documented for torchvision's pretrained models.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
|
|
|
|
|
|
|
|
@torch.inference_mode()
def predict(img: t.Optional[Image.Image]) -> t.Dict[str, float]:
    """Classify one image and return a probability per class label.

    Args:
        img: Image handed over by the Gradio ``Image`` input. May be ``None``
            when the form is submitted without an image.

    Returns:
        Mapping from each entry of ``CLASSES`` to its softmax probability.

    Raises:
        gr.Error: If no image was supplied.
    """
    if img is None:
        # Surface a readable message in the UI instead of an opaque failure
        # from inside the transform pipeline.
        raise gr.Error("Please upload an image first.")

    # Normalize() is configured for exactly three channels, so grayscale,
    # palette or RGBA images are converted before preprocessing.
    batch = preprocess(img.convert("RGB")).unsqueeze(0).to(DEVICE)

    # The batch holds a single image; take its score vector and turn the
    # scores into probabilities.
    scores = model(batch)[0]
    probs = torch.softmax(scores, dim=0).cpu().tolist()

    return {label: float(prob) for label, prob in zip(CLASSES, probs)}
|
|
|
|
|
|
|
|
# Gradio UI: a single image input wired to predict(), with the returned
# label -> probability mapping rendered by a Label component.
demo = gr.Interface(
    fn=predict,
    # type="pil" asks Gradio to hand predict() a PIL image.
    inputs=gr.Image(type="pil", label="Upload an image"),
    outputs=gr.Label(num_top_classes=2, label="Prediction"),
    title="BubbleAI Image-Safety Detector",
    description="Drag an image here to see Safe vs Unsafe probabilities.",
    cache_examples=False,
)


if __name__ == "__main__":
    # Start the web app only when executed as a script, so importing this
    # module does not launch a server.
    demo.launch()
|
|
|