MetiMiester commited on
Commit
fb1e49b
Β·
verified Β·
1 Parent(s): ba00459

Upload 4 files

Browse files
app.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # ──────────────────────────────────────────────────────────────
3
+ # BubbleAI Image-Safety Detector – Hugging Face Space
4
+ # -------------------------------------------------------------
5
+ # Gradio app that classifies uploaded images as β€œSafe” or
6
+ # β€œUnsafe” using a fine-tuned ResNet-50.
7
+ #
8
+ # Coder: Amir Mehdi Memari (2025-08-06)
9
+ # Description:
10
+ # β€’ Loads the checkpoint `resnet_safety_classifier.pth`
11
+ # (must be in the same repo directory; tracked with Git-LFS).
12
+ # β€’ Applies standard ImageNet preprocessing.
13
+ # β€’ Returns class probabilities via a simple Gradio UI.
14
+ # Usage:
15
+ # The HF Space builder executes `python app.py` automatically.
16
+ # ──────────────────────────────────────────────────────────────
17
+
18
+ from __future__ import annotations
19
+ import pathlib, typing as t
20
+
21
+ import torch
22
+ import torchvision
23
+ from torchvision import transforms
24
+ from PIL import Image
25
+ import gradio as gr
26
+
27
+
28
+ # ── 1. Locate checkpoint ──────────────────────────────────────
29
+ REPO_DIR = pathlib.Path(__file__).parent
30
+ CKPT_PATH = REPO_DIR / "resnet_safety_classifier.pth"
31
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
32
+
33
+
34
+ # ── 2. Define the model architecture exactly as trained ───────
35
+ class SafetyResNet(torch.nn.Module):
36
+ """
37
+ ResNet-50 backbone with a 2-unit classifier head.
38
+ """
39
+ def __init__(self) -> None:
40
+ super().__init__()
41
+ base = torchvision.models.resnet50(weights=None)
42
+ base.fc = torch.nn.Linear(base.fc.in_features, 2) # Safe / Unsafe
43
+ self.model = base
44
+
45
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
46
+ return self.model(x)
47
+
48
+
49
+ # ── 3. Instantiate & load weights ─────────────────────────────
50
+ model = SafetyResNet().to(DEVICE)
51
+ state = torch.load(CKPT_PATH, map_location=DEVICE)
52
+ model.load_state_dict(state, strict=True)
53
+ model.eval() # inference mode
54
+ CLASSES = ["Safe", "Unsafe"]
55
+
56
+
57
+ # ── 4. Pre-processing pipeline (ImageNet stats) ───────────────
58
+ preprocess = transforms.Compose([
59
+ transforms.Resize(256),
60
+ transforms.CenterCrop(224),
61
+ transforms.ToTensor(),
62
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
63
+ std =[0.229, 0.224, 0.225]),
64
+ ])
65
+
66
+
67
+ # ── 5. Inference helper ───────────────────────────────────────
68
+ @torch.inference_mode()
69
+ def predict(img: Image.Image) -> t.Dict[str, float]:
70
+ """
71
+ Returns {class_name: probability} for a single PIL image.
72
+ """
73
+ tensor = preprocess(img).unsqueeze(0).to(DEVICE)
74
+ probs = torch.softmax(model(tensor)[0], dim=0).cpu().tolist()
75
+ return {CLASSES[i]: float(probs[i]) for i in range(2)}
76
+
77
+
78
+ # ── 6. Build Gradio Interface ─────────────────────────────────
79
+ demo = gr.Interface(
80
+ fn=predict,
81
+ inputs=gr.Image(type="pil", label="Upload an image"),
82
+ outputs=gr.Label(num_top_classes=2, label="Prediction"),
83
+ title="BubbleAI Image-Safety Detector",
84
+ description=(
85
+ "This demo classifies images as **Safe** or **Unsafe** (NSFW) "
86
+ "using a fine-tuned ResNet-50. Probabilities are shown for both "
87
+ "classes. Model weights Β© 2025 Amir Mehdi Memari."
88
+ ),
89
+ cache_examples=False, # disable weighty example cache rebuilds
90
+ )
91
+
92
+ # ── 7. Launch (HF Spaces auto-calls this in production) ───────
93
+ if __name__ == "__main__":
94
+ demo.launch()
clip_safety_classifier.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bd1c0a11caa8a6ebeda4bc1258f1d09be2cc8a15f19fab8b27eda0ffb1da4183
3
+ size 605757795
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # β€”β€”β€” Core DL stack β€”β€”β€”
2
+ torch>=2.2 # runtime will auto-select the CPU build
3
+ torchvision>=0.17 # provides ResNet + transforms
4
+
5
+ # β€”β€”β€” App / UI β€”β€”β€”
6
+ gradio>=4.26 # Space UI
7
+ pillow # PIL image handling (pulled by torchvision, listed for clarity)
resnet_safety_classifier.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ce42fb248e341f8cbf133d1164b9453e88c088ba2fcc291f3dd8e65746598b18
3
+ size 98558659