yuu1234 commited on
Commit
bd48ebf
·
1 Parent(s): 0943e27
Files changed (5) hide show
  1. Dockerfile +9 -0
  2. README.md +3 -3
  3. app.py +197 -0
  4. requirements.txt +7 -0
  5. resnet50_best_2.pth +3 -0
Dockerfile ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10-slim
2
+
3
+ WORKDIR /app
4
+ COPY . .
5
+
6
+ RUN pip install --no-cache-dir -r requirements.txt
7
+
8
+ EXPOSE 7860
9
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -1,8 +1,8 @@
1
  ---
2
- title: SFW Check
3
  emoji: 🦀
4
- colorFrom: pink
5
- colorTo: indigo
6
  sdk: docker
7
  pinned: false
8
  ---
 
1
  ---
2
+ title: SFW
3
  emoji: 🦀
4
+ colorFrom: yellow
5
+ colorTo: blue
6
  sdk: docker
7
  pinned: false
8
  ---
app.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchvision import models
4
+
5
+ from torchvision import transforms
6
+ from PIL import Image
7
+ from fastapi import FastAPI, HTTPException
8
+ from pydantic import BaseModel
9
+ from typing import List
10
+ import requests
11
+ from io import BytesIO
12
+ import gradio as gr
13
+
14
+ # =====================
15
+ # CONFIG
16
+ # =====================
17
+ DEVICE = torch.device("cpu") # HF free → CPU
18
+ NUM_CLASSES = 3
19
+ IMG_SIZE = 224
20
+ MODEL_PATH = "resnet50_best_2.pth"
21
+
22
+ class_mapping = {
23
+ 0: "safe",
24
+ 1: "sexy",
25
+ 2: "violence"
26
+ }
27
+
28
+ # =====================
29
+ # LOAD MODEL (ONCE)
30
+ # =====================
31
+ model = models.resnet50(weights=None)
32
+
33
+ in_features = model.fc.in_features
34
+ model.fc = nn.Sequential(
35
+ nn.Dropout(p=0.5),
36
+ nn.Linear(in_features, NUM_CLASSES)
37
+ )
38
+
39
+ # Load weights
40
+ state_dict = torch.load(
41
+ MODEL_PATH,
42
+ map_location="cpu"
43
+ )
44
+ model.load_state_dict(state_dict)
45
+
46
+ model.to(DEVICE)
47
+ model.eval()
48
+
49
+ # =====================
50
+ # TRANSFORM
51
+ # =====================
52
+ transform = transforms.Compose([
53
+ transforms.Resize((IMG_SIZE, IMG_SIZE)),
54
+ transforms.ToTensor(),
55
+ transforms.Normalize(
56
+ mean=[0.485, 0.456, 0.406],
57
+ std=[0.229, 0.224, 0.225]
58
+ )
59
+ ])
60
+
61
+ # =====================
62
+ # CORE PREDICT
63
+ # =====================
64
+ def predict_single(img: Image.Image):
65
+ x = transform(img).unsqueeze(0)
66
+ with torch.no_grad():
67
+ probs = torch.softmax(model(x), dim=1)[0]
68
+ conf, pred = torch.max(probs, dim=0)
69
+ return pred.item(), conf.item(), probs.tolist()
70
+
71
+
72
+ def predict_batch(images: List[Image.Image]):
73
+ xs = torch.stack([transform(img) for img in images])
74
+ with torch.no_grad():
75
+ probs = torch.softmax(model(xs), dim=1)
76
+ confs, preds = torch.max(probs, dim=1)
77
+ return preds.tolist(), confs.tolist(), probs.tolist()
78
+
79
+ # =====================
80
+ # FASTAPI APP
81
+ # =====================
82
+ app = FastAPI(title="SFW Image Moderation API")
83
+
84
+ class ImageRequest(BaseModel):
85
+ image_url: str
86
+
87
+ class BatchImageRequest(BaseModel):
88
+ image_urls: List[str]
89
+
90
+ # ---------- Single image ----------
91
+ @app.post("/moderate")
92
+ def moderate(req: ImageRequest):
93
+ try:
94
+ r = requests.get(req.image_url, timeout=5)
95
+ r.raise_for_status()
96
+ img = Image.open(BytesIO(r.content)).convert("RGB")
97
+ except Exception:
98
+ raise HTTPException(status_code=400, detail="Invalid image URL")
99
+
100
+ pred, conf, probs = predict_single(img)
101
+ label = class_mapping[pred]
102
+
103
+ return {
104
+ "allowed": label == "safe" and conf >= 0.6,
105
+ "label": label,
106
+ "confidence": round(conf, 4),
107
+ "probabilities": {
108
+ class_mapping[i]: round(probs[i], 4)
109
+ for i in range(NUM_CLASSES)
110
+ }
111
+ }
112
+
113
+ # ---------- Batch images ----------
114
+ @app.post("/moderate_batch")
115
+ def moderate_batch(req: BatchImageRequest):
116
+ if len(req.image_urls) > 16:
117
+ raise HTTPException(400, "Max 16 images per request")
118
+
119
+ images = []
120
+ urls = []
121
+
122
+ for url in req.image_urls:
123
+ try:
124
+ r = requests.get(url, timeout=5)
125
+ r.raise_for_status()
126
+ img = Image.open(BytesIO(r.content)).convert("RGB")
127
+ images.append(img)
128
+ urls.append(url)
129
+ except Exception:
130
+ continue
131
+
132
+ if not images:
133
+ raise HTTPException(400, "No valid images")
134
+
135
+ preds, confs, probs_list = predict_batch(images)
136
+
137
+ results = []
138
+ for url, pred, conf, probs in zip(urls, preds, confs, probs_list):
139
+ label = class_mapping[pred]
140
+ results.append({
141
+ "image_url": url,
142
+ "allowed": label == "safe" and conf >= 0.6,
143
+ "label": label,
144
+ "confidence": round(conf, 4),
145
+ "probabilities": {
146
+ class_mapping[i]: round(probs[i], 4)
147
+ for i in range(NUM_CLASSES)
148
+ }
149
+ })
150
+
151
+ return {
152
+ "count": len(results),
153
+ "results": results
154
+ }
155
+
156
+ # =====================
157
+ # GRADIO UI (DEMO ONLY)
158
+ # =====================
159
+ def ui_predict(image_url):
160
+ try:
161
+ r = requests.get(image_url, timeout=5)
162
+ r.raise_for_status()
163
+ img = Image.open(BytesIO(r.content)).convert("RGB")
164
+
165
+ pred, conf, probs = predict_single(img)
166
+
167
+ result = {
168
+ "label": class_mapping[pred],
169
+ "confidence": round(conf, 4),
170
+ "safe": round(probs[0], 4),
171
+ "sexy": round(probs[1], 4),
172
+ "violence": round(probs[2], 4),
173
+ }
174
+
175
+ return img, result
176
+
177
+ except Exception:
178
+ return None, {"error": "Invalid image URL"}
179
+
180
+ ui = gr.Interface(
181
+ fn=ui_predict,
182
+ inputs=gr.Textbox(
183
+ label="Image URL",
184
+ placeholder="Paste image URL here..."
185
+ ),
186
+ outputs=[
187
+ gr.Image(label="Preview Image"),
188
+ gr.JSON(label="Prediction")
189
+ ],
190
+ title="SFW Image Moderation Demo",
191
+ description="Demo UI. Backend should call API endpoints directly."
192
+ )
193
+
194
+ # =====================
195
+ # MOUNT UI TO FASTAPI
196
+ # =====================
197
+ app = gr.mount_gradio_app(app, ui, path="/")
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ fastapi
4
+ uvicorn
5
+ pillow
6
+ requests
7
+ gradio
resnet50_best_2.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dd9b80b9df73c3f0c46dbe7c7cc2953779fd6c2c4fc91961cf1057ab617fce2b
3
+ size 94376330