MetiMiester commited on
Commit
0b31297
Β·
verified Β·
1 Parent(s): f326800

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +106 -94
app.py CHANGED
@@ -1,94 +1,106 @@
1
- #!/usr/bin/env python
2
- # ──────────────────────────────────────────────────────────────
3
- # BubbleAI Image-Safety Detector – Hugging Face Space
4
- # -------------------------------------------------------------
5
- # Gradio app that classifies uploaded images as β€œSafe” or
6
- # β€œUnsafe” using a fine-tuned ResNet-50.
7
- #
8
- # Coder: Amir Mehdi Memari (2025-08-06)
9
- # Description:
10
- # β€’ Loads the checkpoint `resnet_safety_classifier.pth`
11
- # (must be in the same repo directory; tracked with Git-LFS).
12
- # β€’ Applies standard ImageNet preprocessing.
13
- # β€’ Returns class probabilities via a simple Gradio UI.
14
- # Usage:
15
- # The HF Space builder executes `python app.py` automatically.
16
- # ──────────────────────────────────────────────────────────────
17
-
18
- from __future__ import annotations
19
- import pathlib, typing as t
20
-
21
- import torch
22
- import torchvision
23
- from torchvision import transforms
24
- from PIL import Image
25
- import gradio as gr
26
-
27
-
28
- # ── 1. Locate checkpoint ──────────────────────────────────────
29
- REPO_DIR = pathlib.Path(__file__).parent
30
- CKPT_PATH = REPO_DIR / "resnet_safety_classifier.pth"
31
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
32
-
33
-
34
- # ── 2. Define the model architecture exactly as trained ───────
35
- class SafetyResNet(torch.nn.Module):
36
- """
37
- ResNet-50 backbone with a 2-unit classifier head.
38
- """
39
- def __init__(self) -> None:
40
- super().__init__()
41
- base = torchvision.models.resnet50(weights=None)
42
- base.fc = torch.nn.Linear(base.fc.in_features, 2) # Safe / Unsafe
43
- self.model = base
44
-
45
- def forward(self, x: torch.Tensor) -> torch.Tensor:
46
- return self.model(x)
47
-
48
-
49
- # ── 3. Instantiate & load weights ─────────────────────────────
50
- model = SafetyResNet().to(DEVICE)
51
- state = torch.load(CKPT_PATH, map_location=DEVICE)
52
- model.load_state_dict(state, strict=True)
53
- model.eval() # inference mode
54
- CLASSES = ["Safe", "Unsafe"]
55
-
56
-
57
- # ── 4. Pre-processing pipeline (ImageNet stats) ───────────────
58
- preprocess = transforms.Compose([
59
- transforms.Resize(256),
60
- transforms.CenterCrop(224),
61
- transforms.ToTensor(),
62
- transforms.Normalize(mean=[0.485, 0.456, 0.406],
63
- std =[0.229, 0.224, 0.225]),
64
- ])
65
-
66
-
67
- # ── 5. Inference helper ───────────────────────────────────────
68
- @torch.inference_mode()
69
- def predict(img: Image.Image) -> t.Dict[str, float]:
70
- """
71
- Returns {class_name: probability} for a single PIL image.
72
- """
73
- tensor = preprocess(img).unsqueeze(0).to(DEVICE)
74
- probs = torch.softmax(model(tensor)[0], dim=0).cpu().tolist()
75
- return {CLASSES[i]: float(probs[i]) for i in range(2)}
76
-
77
-
78
- # ── 6. Build Gradio Interface ─────────────────────────────────
79
- demo = gr.Interface(
80
- fn=predict,
81
- inputs=gr.Image(type="pil", label="Upload an image"),
82
- outputs=gr.Label(num_top_classes=2, label="Prediction"),
83
- title="BubbleAI Image-Safety Detector",
84
- description=(
85
- "This demo classifies images as **Safe** or **Unsafe** (NSFW) "
86
- "using a fine-tuned ResNet-50. Probabilities are shown for both "
87
- "classes. Model weights Β© 2025 Amir Mehdi Memari."
88
- ),
89
- cache_examples=False, # disable weighty example cache rebuilds
90
- )
91
-
92
- # ── 7. Launch (HF Spaces auto-calls this in production) ───────
93
- if __name__ == "__main__":
94
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # ──────────────────────────────────────────────────────────────
3
+ # BubbleAI Image-Safety Detector – fixed version
4
+ # -------------------------------------------------------------
5
+ # Loads checkpoint `resnet_safety_classifier.pth` whose keys are
6
+ # saved under feature_extractor.* and classifier.*; serves a
7
+ # Gradio UI that predicts β€œSafe” vs β€œUnsafe”.
8
+ #
9
+ # Coder: Amir Mehdi Memari (2025-08-06)
10
+ # ──────────────────────────────────────────────────────────────
11
+
12
+ from __future__ import annotations
13
+ import pathlib
14
+ import typing as t
15
+
16
+ import torch
17
+ import torchvision
18
+ from torchvision import transforms
19
+ from PIL import Image
20
+ import gradio as gr
21
+
22
+
23
+ # ── 1. Paths & device ─────────────────────────────────────────
24
+ REPO_DIR = pathlib.Path(__file__).parent
25
+ CKPT_PATH = REPO_DIR / "resnet_safety_classifier.pth"
26
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
27
+
28
+
29
+ # ── 2. Architecture that matches checkpoint keys ──────────────
30
+ class SafetyResNet(torch.nn.Module):
31
+ """
32
+ ResNet-50 backbone (conv1 β–Έ layer4) ➜ global avg-pool
33
+ ➜ MLP (2048β†’512β†’2). Keys align with:
34
+ β€’ feature_extractor.*
35
+ β€’ classifier.*
36
+ """
37
+ def __init__(self) -> None:
38
+ super().__init__()
39
+
40
+ base = torchvision.models.resnet50(weights=None)
41
+
42
+ # keep stem + 4 stages (0-7) [conv1, bn1, relu, maxpool, layer1-4]
43
+ self.feature_extractor = torch.nn.Sequential(*list(base.children())[:8])
44
+
45
+ self.pool = torch.nn.AdaptiveAvgPool2d((1, 1))
46
+ self.classifier = torch.nn.Sequential(
47
+ torch.nn.Flatten(), # (B, 2048, 1, 1) β†’ (B, 2048)
48
+ torch.nn.Linear(2048, 512, bias=True),
49
+ torch.nn.ReLU(inplace=True),
50
+ torch.nn.Dropout(p=0.30),
51
+ torch.nn.Linear(512, 2, bias=True)
52
+ )
53
+
54
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
55
+ x = self.feature_extractor(x) # (B, 2048, H/32, W/32)
56
+ x = self.pool(x) # (B, 2048, 1, 1)
57
+ x = self.classifier(x) # (B, 2)
58
+ return x
59
+
60
+
61
+ # ── 3. Instantiate & load weights ─────────────────────────────
62
+ model = SafetyResNet().to(DEVICE)
63
+ state = torch.load(CKPT_PATH, map_location=DEVICE)
64
+ model.load_state_dict(state, strict=True)
65
+ model.eval()
66
+
67
+ CLASSES = ["Safe", "Unsafe"]
68
+
69
+
70
+ # ── 4. Pre-processing pipeline (ImageNet stats) ───────────────
71
+ preprocess = transforms.Compose([
72
+ transforms.Resize(256),
73
+ transforms.CenterCrop(224),
74
+ transforms.ToTensor(),
75
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
76
+ std =[0.229, 0.224, 0.225]),
77
+ ])
78
+
79
+
80
+ # ── 5. Inference helper ───────────────────────────────────────
81
+ @torch.inference_mode()
82
+ def predict(img: Image.Image) -> t.Dict[str, float]:
83
+ """
84
+ Returns {class_name: probability} for a single PIL image.
85
+ """
86
+ tensor = preprocess(img).unsqueeze(0).to(DEVICE)
87
+ probs = torch.softmax(model(tensor)[0], dim=0).cpu().tolist()
88
+ return {CLASSES[i]: float(probs[i]) for i in range(2)}
89
+
90
+
91
+ # ── 6. Gradio interface ───────────────────────────────────────
92
+ demo = gr.Interface(
93
+ fn=predict,
94
+ inputs=gr.Image(type="pil", label="Upload an image"),
95
+ outputs=gr.Label(num_top_classes=2, label="Prediction"),
96
+ title="BubbleAI Image-Safety Detector",
97
+ description=(
98
+ "Drop an image to check whether it's **Safe** or **Unsafe** "
99
+ "using BubbleAI’s ResNet-50 classifier."
100
+ ),
101
+ cache_examples=False,
102
+ )
103
+
104
+ # ── 7. Launch ────────────────────────────────────────────────
105
+ if __name__ == "__main__":
106
+ demo.launch()