MetiMiester committed on
Commit
eef2847
·
verified ·
1 Parent(s): b3f739e

Update app_server.py

Browse files
Files changed (1) hide show
  1. app_server.py +92 -197
app_server.py CHANGED
@@ -1,15 +1,7 @@
1
  # app_server.py — BubbleGuard API + Dating-style Web Chat (Static UI)
2
- # Version: 1.7.0 (HF Drive-ready + repo-root UI support)
3
- # Author: Amir
4
-
5
- import io
6
- import os
7
- import re
8
- import uuid
9
- import pathlib
10
- import tempfile
11
- import subprocess
12
- import unicodedata
13
  from typing import Dict, Optional
14
 
15
  from fastapi import FastAPI, UploadFile, File, Form, HTTPException
@@ -17,9 +9,7 @@ from fastapi.middleware.cors import CORSMiddleware
17
  from fastapi.staticfiles import StaticFiles
18
  from fastapi.responses import PlainTextResponse
19
 
20
- import torch
21
- import joblib
22
- import torchvision
23
  from torchvision import transforms
24
  from transformers import RobertaTokenizerFast, AutoModelForSequenceClassification
25
  from PIL import Image
@@ -28,52 +18,39 @@ from faster_whisper import WhisperModel
28
  # -------------------------- Paths & Config --------------------------
29
  BASE = pathlib.Path(__file__).resolve().parent
30
  TEXT_DIR = BASE / "Text"
31
- IMG_DIR = BASE / "Image"
32
- AUD_DIR = BASE / "Audio"
33
  STATIC_DIR = BASE / "static"
34
 
35
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
36
 
37
- # Image thresholds / mapping
38
- IMG_UNSAFE_THR = float(os.getenv("IMG_UNSAFE_THR", "0.5"))
39
  IMG_UNSAFE_INDEX = int(os.getenv("IMG_UNSAFE_INDEX", "1"))
40
 
41
- # Whisper model selection
42
- WHISPER_MODEL_NAME = os.getenv("WHISPER_MODEL", "base") # large-v2 | medium | small | base | tiny
43
 
44
- # Text thresholds and heuristics
45
- TEXT_UNSAFE_THR = float(os.getenv("TEXT_UNSAFE_THR", "0.60"))
46
- SHORT_MSG_MAX_TOKENS = int(os.getenv("SHORT_MSG_MAX_TOKENS", "6"))
47
- SHORT_MSG_UNSAFE_THR = float(os.getenv("SHORT_MSG_UNSAFE_THR", "0.90"))
48
 
49
- # Audio mapping/threshold (can differ from text)
50
  AUDIO_UNSAFE_INDEX = int(os.getenv("AUDIO_UNSAFE_INDEX", "1"))
51
- AUDIO_UNSAFE_THR = float(os.getenv("AUDIO_UNSAFE_THR", "0.50"))
52
 
53
- app = FastAPI(title="BubbleGuard API", version="1.7.0")
54
 
55
- # CORS open for demo; restrict in production
56
  app.add_middleware(
57
- CORSMiddleware,
58
- allow_origins=["*"],
59
- allow_methods=["*"],
60
- allow_headers=["*"],
61
  )
62
 
63
  # -------------------------- Text Classifier -------------------------
64
  if not TEXT_DIR.exists():
65
- raise RuntimeError(f"Text model dir not found: {TEXT_DIR}. Make sure download_assets ran.")
66
-
67
- try:
68
- tok = RobertaTokenizerFast.from_pretrained(TEXT_DIR, local_files_only=True)
69
- txtM = AutoModelForSequenceClassification.from_pretrained(
70
- TEXT_DIR, local_files_only=True
71
- ).to(DEVICE).eval()
72
- except Exception as e:
73
- raise RuntimeError(f"Failed to load text model from {TEXT_DIR}: {e}")
74
-
75
- # ------------------------ Label mapping (robust) --------------------
76
- SAFE_LABEL_HINTS = {"safe", "ok", "clean", "benign", "non-toxic", "non_toxic", "non toxic"}
77
  UNSAFE_LABEL_HINTS = {"unsafe", "toxic", "abuse", "harm", "offense", "nsfw", "not_safe", "not safe"}
78
 
79
  def _infer_ids_by_name(model) -> (Optional[int], Optional[int]):
@@ -91,18 +68,13 @@ def _infer_ids_by_name(model) -> (Optional[int], Optional[int]):
91
  except Exception:
92
  continue
93
  norm[ki] = str(v).lower()
94
- safe_idx = None
95
- unsafe_idx = None
96
  for i, name in norm.items():
97
- if any(h in name for h in SAFE_LABEL_HINTS):
98
- safe_idx = i
99
- if any(h in name for h in UNSAFE_LABEL_HINTS):
100
- unsafe_idx = i
101
- if safe_idx is not None and unsafe_idx is None:
102
- unsafe_idx = 1 - safe_idx
103
- if unsafe_idx is not None and safe_idx is None:
104
- safe_idx = 1 - unsafe_idx
105
- return safe_idx, unsafe_idx
106
  except Exception:
107
  return None, None
108
 
@@ -111,10 +83,8 @@ def _infer_ids_by_probe(model, tok, device) -> (int, int):
111
  samples = ["hi", "hello", "how are you", "nice to meet you", "thanks"]
112
  enc = tok(samples, return_tensors="pt", truncation=True, padding=True, max_length=64)
113
  enc = {k: v.to(device) for k, v in enc.items()}
114
- logits = model(**enc).logits # [B, 2]
115
- probs = torch.softmax(logits, dim=-1).mean(0) # [2]
116
- safe_idx = int(torch.argmax(probs).item())
117
- unsafe_idx = 1 - safe_idx
118
  return safe_idx, unsafe_idx
119
 
120
  def _resolve_safe_unsafe_ids(model, tok, device) -> (int, int):
@@ -129,116 +99,63 @@ def _resolve_safe_unsafe_ids(model, tok, device) -> (int, int):
129
  SAFE_ID, UNSAFE_ID = _resolve_safe_unsafe_ids(txtM, tok, DEVICE)
130
  print(f"[BubbleGuard] SAFE_ID={SAFE_ID} UNSAFE_ID={UNSAFE_ID} id2label={getattr(txtM.config, 'id2label', None)}")
131
 
132
- # ------------------------ Normalization utils -----------------------
133
  def normalize(text: str) -> str:
134
- if not isinstance(text, str):
135
- return ""
136
  t = unicodedata.normalize("NFKC", text)
137
- t = t.replace("’", "'").replace("‘", "'").replace("“", '"').replace("”", '"')
138
- t = t.lower()
139
- t = re.sub(r"[^a-z0-9\s']", " ", t)
140
- t = re.sub(r"\s+", " ", t).strip()
141
- return t
142
-
143
- SAFE_PHRASES = [
144
- r"^i don'?t$",
145
- r"^i do not$",
146
- r"^don'?t$",
147
- r"^no$",
148
- r"^not really$",
149
- r"^i wouldn'?t$",
150
- r"^i woulde?n'?t$",
151
- r"^i don'?t like$",
152
- ]
153
  SAFE_RE = re.compile("|".join(SAFE_PHRASES))
154
  NEGATION_ONLY = re.compile(r"^(?:i\s+)?(?:do\s+not|don'?t|no|not)$")
155
  NEUTRAL_DISLIKE = re.compile(r"^i don'?t like(?:\s+to)?\b")
156
 
157
- SENSITIVE_TERMS = {
158
- "people", "you", "him", "her", "them", "men", "women", "girls", "boys",
159
- "muslim", "christian", "jew", "jews", "black", "white", "asian",
160
- "gay", "lesbian", "trans", "transgender", "disabled",
161
- "immigrants", "refugees", "poor", "old", "elderly", "fat", "skinny"
162
- }
163
- PROFANITY_TERMS = {"fuck", "shit", "bitch", "pussy", "dick", "cunt", "slut", "whore"}
164
-
165
- GREETINGS = [
166
- r"^hi$",
167
- r"^hello$",
168
- r"^hey(?: there)?$",
169
- r"^how are (?:you|u)\b.*$",
170
- r"^good (?:morning|afternoon|evening)\b.*$",
171
- r"^what'?s up\b.*$",
172
- r"^how'?s it going\b.*$",
173
- ]
174
  GREETING_RE = re.compile("|".join(GREETINGS))
175
 
176
  @torch.no_grad()
177
  def text_safe_payload(text: str) -> Dict:
178
- clean = normalize(text)
179
- toks = clean.split()
180
 
181
- # A) single-word profanity → UNSAFE
182
  if len(toks) == 1 and toks[0] in PROFANITY_TERMS:
183
- probs = [0.0, 0.0]; probs[UNSAFE_ID] = 1.0
184
- return {"safe": False, "unsafe_prob": 1.0, "label": "UNSAFE", "probs": probs,
185
- "tokens": 1, "reason": "profanity_single_word",
186
- "params": {"TEXT_UNSAFE_THR": TEXT_UNSAFE_THR, "SHORT_MSG_MAX_TOKENS": SHORT_MSG_MAX_TOKENS, "SHORT_MSG_UNSAFE_THR": SHORT_MSG_UNSAFE_THR}}
187
 
188
- # B) short text with any profanity → UNSAFE
189
  if len(toks) <= SHORT_MSG_MAX_TOKENS and any(t in PROFANITY_TERMS for t in toks):
190
- probs = [0.0, 0.0]; probs[UNSAFE_ID] = 1.0
191
- return {"safe": False, "unsafe_prob": 1.0, "label": "UNSAFE", "probs": probs,
192
- "tokens": len(toks), "reason": "profanity_short_text",
193
- "params": {"TEXT_UNSAFE_THR": TEXT_UNSAFE_THR, "SHORT_MSG_MAX_TOKENS": SHORT_MSG_MAX_TOKENS, "SHORT_MSG_UNSAFE_THR": SHORT_MSG_UNSAFE_THR}}
194
 
195
- # Allow-lists
196
  if SAFE_RE.match(clean) or NEGATION_ONLY.match(clean) or GREETING_RE.match(clean):
197
- probs = [0.0, 0.0]; probs[SAFE_ID] = 1.0
198
- reason = "greeting_allow" if GREETING_RE.match(clean) else "allow_list"
199
- return {"safe": True, "unsafe_prob": 0.0, "label": "SAFE", "probs": probs,
200
- "tokens": len(toks), "reason": reason,
201
- "params": {"TEXT_UNSAFE_THR": TEXT_UNSAFE_THR, "SHORT_MSG_MAX_TOKENS": SHORT_MSG_MAX_TOKENS, "SHORT_MSG_UNSAFE_THR": SHORT_MSG_UNSAFE_THR}}
202
 
203
- # Neutral dislike relax
204
  if NEUTRAL_DISLIKE.match(clean):
205
- has_sensitive = any(term in clean for term in SENSITIVE_TERMS)
206
- has_profanity = any(term in clean for term in PROFANITY_TERMS)
207
- if not has_sensitive and not has_profanity:
208
  enc = tok(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
209
  enc = {k: v.to(DEVICE) for k, v in enc.items()}
210
- logits = txtM(**enc).logits[0]
211
- probs = torch.softmax(logits, dim=-1).detach().cpu().tolist()
212
- unsafe_prob = float(probs[UNSAFE_ID])
213
- is_safe = unsafe_prob < 0.98
214
- return {"safe": bool(is_safe), "unsafe_prob": unsafe_prob,
215
- "label": "SAFE" if is_safe else "UNSAFE",
216
- "probs": probs, "tokens": int(enc["input_ids"].shape[1]),
217
- "reason": "neutral_dislike_relaxed",
218
- "params": {"TEXT_UNSAFE_THR": TEXT_UNSAFE_THR, "SHORT_MSG_MAX_TOKENS": SHORT_MSG_MAX_TOKENS, "SHORT_MSG_UNSAFE_THR": SHORT_MSG_UNSAFE_THR}}
219
-
220
- # Normal model path
221
  enc = tok(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
222
  enc = {k: v.to(DEVICE) for k, v in enc.items()}
223
  logits = txtM(**enc).logits[0]
224
- probs = torch.softmax(logits, dim=-1).detach().cpu().tolist()
225
- unsafe_prob = float(probs[UNSAFE_ID])
226
- pred_idx = int(torch.argmax(logits))
227
- num_tokens = int(enc["input_ids"].shape[1])
228
 
229
- if num_tokens <= SHORT_MSG_MAX_TOKENS:
230
- is_safe = unsafe_prob < SHORT_MSG_UNSAFE_THR
231
- reason = "short_msg_threshold"
232
- else:
233
- is_safe = unsafe_prob < TEXT_UNSAFE_THR
234
- reason = "global_threshold"
235
-
236
- label = (txtM.config.id2label.get(pred_idx)
237
- if isinstance(getattr(txtM.config, "id2label", None), dict) else None) or str(pred_idx)
238
-
239
- return {"safe": bool(is_safe), "unsafe_prob": unsafe_prob, "label": label,
240
- "probs": probs, "tokens": num_tokens, "reason": reason,
241
- "params": {"TEXT_UNSAFE_THR": TEXT_UNSAFE_THR, "SHORT_MSG_MAX_TOKENS": SHORT_MSG_MAX_TOKENS, "SHORT_MSG_UNSAFE_THR": SHORT_MSG_UNSAFE_THR}}
242
 
243
  # -------------------------- Image Classifier ------------------------
244
  class SafetyResNet(torch.nn.Module):
@@ -248,28 +165,18 @@ class SafetyResNet(torch.nn.Module):
248
  self.feature_extractor = torch.nn.Sequential(*list(base.children())[:8])
249
  self.pool = torch.nn.AdaptiveAvgPool2d(1)
250
  self.classifier = torch.nn.Sequential(
251
- torch.nn.Linear(2048, 512),
252
- torch.nn.ReLU(True),
253
- torch.nn.Dropout(0.30),
254
- torch.nn.Linear(512, 2),
255
  )
256
-
257
  def forward(self, x):
258
  x = self.pool(self.feature_extractor(x))
259
  return self.classifier(torch.flatten(x, 1))
260
 
261
  if not IMG_DIR.exists():
262
- raise RuntimeError(f"Image model dir not found: {IMG_DIR}. Make sure download_assets ran.")
263
-
264
- try:
265
- imgM = SafetyResNet().to(DEVICE)
266
- imgM.load_state_dict(
267
- torch.load(IMG_DIR / "resnet_safety_classifier.pth", map_location=DEVICE),
268
- strict=True
269
- )
270
- imgM.eval()
271
- except Exception as e:
272
- raise RuntimeError(f"Failed to load image model weights from {IMG_DIR}: {e}")
273
 
274
  img_tf = transforms.Compose([
275
  transforms.Resize(256, interpolation=transforms.InterpolationMode.BILINEAR),
@@ -279,68 +186,56 @@ img_tf = transforms.Compose([
279
  ])
280
 
281
  @torch.no_grad()
282
- def image_safe_payload(pil_img: Image.Image) -> Dict:
283
- x = img_tf(pil_img.convert("RGB")).unsqueeze(0).to(DEVICE)
284
- logits = imgM(x)[0]
285
- probs = torch.softmax(logits, dim=0).detach().cpu().tolist() # [2]
286
- unsafe_p = float(probs[IMG_UNSAFE_INDEX])
287
- return {"safe": unsafe_p < IMG_UNSAFE_THR, "unsafe_prob": unsafe_p, "probs": probs}
288
 
289
  # -------------------------- Audio (ASR -> NLP) ----------------------
290
  compute_type = "float16" if DEVICE == "cuda" else "int8"
291
-
292
- try:
293
- asr = WhisperModel(WHISPER_MODEL_NAME, device=DEVICE, compute_type=compute_type)
294
- except Exception as e:
295
- raise RuntimeError(
296
- f"Failed to load Whisper model '{WHISPER_MODEL_NAME}': {e}. "
297
- f"Tip: ensure ffmpeg is installed (Dockerfile/apt.txt)."
298
- )
299
 
300
  if not AUD_DIR.exists():
301
- raise RuntimeError(f"Audio pipeline dir not found: {AUD_DIR}. Make sure download_assets ran.")
302
-
303
- try:
304
- text_clf = joblib.load(AUD_DIR / "text_pipeline_balanced.joblib")
305
- except Exception as e:
306
- raise RuntimeError(f"Failed to load audio text pipeline from {AUD_DIR}: {e}")
307
 
308
  def _ffmpeg_to_wav(src_bytes: bytes) -> bytes:
309
  with tempfile.TemporaryDirectory() as td:
310
  in_path = pathlib.Path(td) / f"in-{uuid.uuid4().hex}.bin"
311
  out_path = pathlib.Path(td) / "out.wav"
312
  in_path.write_bytes(src_bytes)
313
- cmd = ["ffmpeg", "-y", "-i", str(in_path), "-ac", "1", "-ar", "16000", str(out_path)]
314
  try:
315
  subprocess.run(cmd, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
316
  return out_path.read_bytes()
317
  except FileNotFoundError as e:
318
- raise RuntimeError("FFmpeg not found. Install it and ensure 'ffmpeg' is on PATH.") from e
319
  except subprocess.CalledProcessError:
320
  return src_bytes
321
 
322
  def _transcribe_wav_bytes(wav_bytes: bytes) -> str:
323
  td = tempfile.mkdtemp()
324
- path = pathlib.Path(td) / "in.wav"
325
  try:
326
- path.write_bytes(wav_bytes)
327
- segments, _ = asr.transcribe(str(path), beam_size=5, language="en")
328
  return " ".join(s.text for s in segments).strip()
329
  finally:
330
- try: path.unlink(missing_ok=True)
331
  except Exception: pass
332
  try: pathlib.Path(td).rmdir()
333
  except Exception: pass
334
 
335
- def audio_safe_from_bytes(raw_bytes: bytes) -> Dict:
336
- wav = _ffmpeg_to_wav(raw_bytes)
337
  text = _transcribe_wav_bytes(wav)
338
  proba = text_clf.predict_proba([text])[0].tolist()
339
- unsafe_p = float(proba[AUDIO_UNSAFE_INDEX])
340
- return {"safe": unsafe_p < AUDIO_UNSAFE_THR, "unsafe_prob": unsafe_p, "text": text, "probs": proba}
341
 
342
- # ------------------------------ Routes ------------------------------
343
- @app.get("/health")
344
  def health():
345
  return {
346
  "ok": True,
@@ -356,7 +251,7 @@ def health():
356
  "safe_unsafe_indices(text_model)": {"SAFE_ID": SAFE_ID, "UNSAFE_ID": UNSAFE_ID},
357
  }
358
 
359
- @app.post("/check_text")
360
  def check_text(text: str = Form(...)):
361
  if not text or not text.strip():
362
  raise HTTPException(400, "Empty text")
@@ -365,7 +260,7 @@ def check_text(text: str = Form(...)):
365
  except Exception as e:
366
  raise HTTPException(500, f"Text screening error: {e}")
367
 
368
- @app.post("/check_image")
369
  async def check_image(file: UploadFile = File(...)):
370
  data = await file.read()
371
  if not data:
@@ -379,7 +274,7 @@ async def check_image(file: UploadFile = File(...)):
379
  except Exception as e:
380
  raise HTTPException(500, f"Image screening error: {e}")
381
 
382
- @app.post("/check_audio")
383
  async def check_audio(file: UploadFile = File(...)):
384
  raw = await file.read()
385
  if not raw:
@@ -392,7 +287,7 @@ async def check_audio(file: UploadFile = File(...)):
392
  raise HTTPException(500, f"Audio processing error: {e}")
393
 
394
  # --------------------------- Static Mount ---------------------------
395
- # Serve web UI from /static if it exists; otherwise serve from repo root.
396
  static_dir = BASE / "static"
397
  root_index = BASE / "index.html"
398
 
 
1
  # app_server.py — BubbleGuard API + Dating-style Web Chat (Static UI)
2
+ # Version: 1.7.1 (/api/* routes + repo-root UI support)
3
+
4
+ import io, os, re, uuid, pathlib, tempfile, subprocess, unicodedata
 
 
 
 
 
 
 
 
5
  from typing import Dict, Optional
6
 
7
  from fastapi import FastAPI, UploadFile, File, Form, HTTPException
 
9
  from fastapi.staticfiles import StaticFiles
10
  from fastapi.responses import PlainTextResponse
11
 
12
+ import torch, joblib, torchvision
 
 
13
  from torchvision import transforms
14
  from transformers import RobertaTokenizerFast, AutoModelForSequenceClassification
15
  from PIL import Image
 
18
  # -------------------------- Paths & Config --------------------------
19
  BASE = pathlib.Path(__file__).resolve().parent
20
  TEXT_DIR = BASE / "Text"
21
+ IMG_DIR = BASE / "Image"
22
+ AUD_DIR = BASE / "Audio"
23
  STATIC_DIR = BASE / "static"
24
 
25
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
26
 
27
+ IMG_UNSAFE_THR = float(os.getenv("IMG_UNSAFE_THR", "0.5"))
 
28
  IMG_UNSAFE_INDEX = int(os.getenv("IMG_UNSAFE_INDEX", "1"))
29
 
30
+ WHISPER_MODEL_NAME = os.getenv("WHISPER_MODEL", "base")
 
31
 
32
+ TEXT_UNSAFE_THR = float(os.getenv("TEXT_UNSAFE_THR", "0.60"))
33
+ SHORT_MSG_MAX_TOKENS = int(os.getenv("SHORT_MSG_MAX_TOKENS", "6"))
34
+ SHORT_MSG_UNSAFE_THR = float(os.getenv("SHORT_MSG_UNSAFE_THR", "0.90"))
 
35
 
 
36
  AUDIO_UNSAFE_INDEX = int(os.getenv("AUDIO_UNSAFE_INDEX", "1"))
37
+ AUDIO_UNSAFE_THR = float(os.getenv("AUDIO_UNSAFE_THR", "0.50"))
38
 
39
+ app = FastAPI(title="BubbleGuard API", version="1.7.1")
40
 
 
41
  app.add_middleware(
42
+ CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"]
 
 
 
43
  )
44
 
45
  # -------------------------- Text Classifier -------------------------
46
  if not TEXT_DIR.exists():
47
+ raise RuntimeError(f"Text model dir not found: {TEXT_DIR}. Run download_assets first.")
48
+
49
+ tok = RobertaTokenizerFast.from_pretrained(TEXT_DIR, local_files_only=True)
50
+ txtM = AutoModelForSequenceClassification.from_pretrained(TEXT_DIR, local_files_only=True).to(DEVICE).eval()
51
+
52
+ # -------- Label mapping (robust) --------
53
+ SAFE_LABEL_HINTS = {"safe", "ok", "clean", "benign", "non-toxic", "non_toxic", "non toxic"}
 
 
 
 
 
54
  UNSAFE_LABEL_HINTS = {"unsafe", "toxic", "abuse", "harm", "offense", "nsfw", "not_safe", "not safe"}
55
 
56
  def _infer_ids_by_name(model) -> (Optional[int], Optional[int]):
 
68
  except Exception:
69
  continue
70
  norm[ki] = str(v).lower()
71
+ s = u = None
 
72
  for i, name in norm.items():
73
+ if any(h in name for h in SAFE_LABEL_HINTS): s = i
74
+ if any(h in name for h in UNSAFE_LABEL_HINTS): u = i
75
+ if s is not None and u is None: u = 1 - s
76
+ if u is not None and s is None: s = 1 - u
77
+ return s, u
 
 
 
 
78
  except Exception:
79
  return None, None
80
 
 
83
  samples = ["hi", "hello", "how are you", "nice to meet you", "thanks"]
84
  enc = tok(samples, return_tensors="pt", truncation=True, padding=True, max_length=64)
85
  enc = {k: v.to(device) for k, v in enc.items()}
86
+ probs = torch.softmax(model(**enc).logits, dim=-1).mean(0)
87
+ safe_idx = int(torch.argmax(probs).item()); unsafe_idx = 1 - safe_idx
 
 
88
  return safe_idx, unsafe_idx
89
 
90
  def _resolve_safe_unsafe_ids(model, tok, device) -> (int, int):
 
99
  SAFE_ID, UNSAFE_ID = _resolve_safe_unsafe_ids(txtM, tok, DEVICE)
100
  print(f"[BubbleGuard] SAFE_ID={SAFE_ID} UNSAFE_ID={UNSAFE_ID} id2label={getattr(txtM.config, 'id2label', None)}")
101
 
102
+ # ------------------------ Text utils ------------------------
103
  def normalize(text: str) -> str:
104
+ if not isinstance(text, str): return ""
 
105
  t = unicodedata.normalize("NFKC", text)
106
+ t = t.replace("’","'").replace("‘","'").replace("“",'"').replace("”",'"')
107
+ t = re.sub(r"[^a-z0-9\s']", " ", t.lower())
108
+ return re.sub(r"\s+", " ", t).strip()
109
+
110
+ SAFE_PHRASES = [r"^i don'?t$", r"^i do not$", r"^don'?t$", r"^no$", r"^not really$", r"^i woulde?n'?t$", r"^i don'?t like$"]
 
 
 
 
 
 
 
 
 
 
 
111
  SAFE_RE = re.compile("|".join(SAFE_PHRASES))
112
  NEGATION_ONLY = re.compile(r"^(?:i\s+)?(?:do\s+not|don'?t|no|not)$")
113
  NEUTRAL_DISLIKE = re.compile(r"^i don'?t like(?:\s+to)?\b")
114
 
115
+ SENSITIVE_TERMS = {"people","you","him","her","them","men","women","girls","boys",
116
+ "muslim","christian","jew","jews","black","white","asian",
117
+ "gay","lesbian","trans","transgender","disabled",
118
+ "immigrants","refugees","poor","old","elderly","fat","skinny"}
119
+ PROFANITY_TERMS = {"fuck","shit","bitch","pussy","dick","cunt","slut","whore"}
120
+
121
+ GREETINGS = [r"^hi$", r"^hello$", r"^hey(?: there)?$", r"^how are (?:you|u)\b.*$",
122
+ r"^good (?:morning|afternoon|evening)\b.*$", r"^what'?s up\b.*$", r"^how'?s it going\b.*$"]
 
 
 
 
 
 
 
 
 
123
  GREETING_RE = re.compile("|".join(GREETINGS))
124
 
125
  @torch.no_grad()
126
  def text_safe_payload(text: str) -> Dict:
127
+ clean = normalize(text); toks = clean.split()
 
128
 
 
129
  if len(toks) == 1 and toks[0] in PROFANITY_TERMS:
130
+ p = [0.0,0.0]; p[UNSAFE_ID]=1.0
131
+ return {"safe":False,"unsafe_prob":1.0,"label":"UNSAFE","probs":p,"tokens":1,"reason":"profanity_single_word"}
 
 
132
 
 
133
  if len(toks) <= SHORT_MSG_MAX_TOKENS and any(t in PROFANITY_TERMS for t in toks):
134
+ p = [0.0,0.0]; p[UNSAFE_ID]=1.0
135
+ return {"safe":False,"unsafe_prob":1.0,"label":"UNSAFE","probs":p,"tokens":len(toks),"reason":"profanity_short_text"}
 
 
136
 
 
137
  if SAFE_RE.match(clean) or NEGATION_ONLY.match(clean) or GREETING_RE.match(clean):
138
+ p=[0.0,0.0]; p[SAFE_ID]=1.0
139
+ return {"safe":True,"unsafe_prob":0.0,"label":"SAFE","probs":p,"tokens":len(toks),"reason":"allow_or_greeting"}
 
 
 
140
 
 
141
  if NEUTRAL_DISLIKE.match(clean):
142
+ if not any(t in clean for t in SENSITIVE_TERMS) and not any(t in clean for t in PROFANITY_TERMS):
 
 
143
  enc = tok(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
144
  enc = {k: v.to(DEVICE) for k, v in enc.items()}
145
+ probs = torch.softmax(txtM(**enc).logits[0], dim=-1).cpu().tolist()
146
+ up = float(probs[UNSAFE_ID]); safe = up < 0.98
147
+ return {"safe":bool(safe),"unsafe_prob":up,"label":"SAFE" if safe else "UNSAFE",
148
+ "probs":probs,"tokens":int(enc["input_ids"].shape[1]),"reason":"neutral_dislike_relaxed"}
149
+
 
 
 
 
 
 
150
  enc = tok(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
151
  enc = {k: v.to(DEVICE) for k, v in enc.items()}
152
  logits = txtM(**enc).logits[0]
153
+ probs = torch.softmax(logits, dim=-1).cpu().tolist()
154
+ up = float(probs[UNSAFE_ID]); toks = int(enc["input_ids"].shape[1])
 
 
155
 
156
+ safe = up < (SHORT_MSG_UNSAFE_THR if toks <= SHORT_MSG_MAX_TOKENS else TEXT_UNSAFE_THR)
157
+ return {"safe":bool(safe),"unsafe_prob":up,"label":str(int(torch.argmax(logits))),
158
+ "probs":probs,"tokens":toks,"reason":"short_msg_threshold" if toks<=SHORT_MSG_MAX_TOKENS else "global_threshold"}
 
 
 
 
 
 
 
 
 
 
159
 
160
  # -------------------------- Image Classifier ------------------------
161
  class SafetyResNet(torch.nn.Module):
 
165
  self.feature_extractor = torch.nn.Sequential(*list(base.children())[:8])
166
  self.pool = torch.nn.AdaptiveAvgPool2d(1)
167
  self.classifier = torch.nn.Sequential(
168
+ torch.nn.Linear(2048, 512), torch.nn.ReLU(True), torch.nn.Dropout(0.30), torch.nn.Linear(512, 2)
 
 
 
169
  )
 
170
  def forward(self, x):
171
  x = self.pool(self.feature_extractor(x))
172
  return self.classifier(torch.flatten(x, 1))
173
 
174
  if not IMG_DIR.exists():
175
+ raise RuntimeError(f"Image model dir not found: {IMG_DIR}. Run download_assets first.")
176
+
177
+ imgM = SafetyResNet().to(DEVICE)
178
+ imgM.load_state_dict(torch.load(IMG_DIR / "resnet_safety_classifier.pth", map_location=DEVICE), strict=True)
179
+ imgM.eval()
 
 
 
 
 
 
180
 
181
  img_tf = transforms.Compose([
182
  transforms.Resize(256, interpolation=transforms.InterpolationMode.BILINEAR),
 
186
  ])
187
 
188
  @torch.no_grad()
189
+ def image_safe_payload(pil: Image.Image) -> Dict:
190
+ x = img_tf(pil.convert("RGB")).unsqueeze(0).to(DEVICE)
191
+ probs = torch.softmax(imgM(x)[0], dim=0).cpu().tolist()
192
+ up = float(probs[IMG_UNSAFE_INDEX])
193
+ return {"safe": up < IMG_UNSAFE_THR, "unsafe_prob": up, "probs": probs}
 
194
 
195
  # -------------------------- Audio (ASR -> NLP) ----------------------
196
  compute_type = "float16" if DEVICE == "cuda" else "int8"
197
+ asr = WhisperModel(WHISPER_MODEL_NAME, device=DEVICE, compute_type=compute_type)
 
 
 
 
 
 
 
198
 
199
  if not AUD_DIR.exists():
200
+ raise RuntimeError(f"Audio pipeline dir not found: {AUD_DIR}. Run download_assets first.")
201
+ text_clf = joblib.load(AUD_DIR / "text_pipeline_balanced.joblib")
 
 
 
 
202
 
203
  def _ffmpeg_to_wav(src_bytes: bytes) -> bytes:
204
  with tempfile.TemporaryDirectory() as td:
205
  in_path = pathlib.Path(td) / f"in-{uuid.uuid4().hex}.bin"
206
  out_path = pathlib.Path(td) / "out.wav"
207
  in_path.write_bytes(src_bytes)
208
+ cmd = ["ffmpeg","-y","-i",str(in_path),"-ac","1","-ar","16000",str(out_path)]
209
  try:
210
  subprocess.run(cmd, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
211
  return out_path.read_bytes()
212
  except FileNotFoundError as e:
213
+ raise RuntimeError("FFmpeg not found on PATH.") from e
214
  except subprocess.CalledProcessError:
215
  return src_bytes
216
 
217
  def _transcribe_wav_bytes(wav_bytes: bytes) -> str:
218
  td = tempfile.mkdtemp()
219
+ p = pathlib.Path(td) / "in.wav"
220
  try:
221
+ p.write_bytes(wav_bytes)
222
+ segments, _ = asr.transcribe(str(p), beam_size=5, language="en")
223
  return " ".join(s.text for s in segments).strip()
224
  finally:
225
+ try: p.unlink(missing_ok=True)
226
  except Exception: pass
227
  try: pathlib.Path(td).rmdir()
228
  except Exception: pass
229
 
230
+ def audio_safe_from_bytes(raw: bytes) -> Dict:
231
+ wav = _ffmpeg_to_wav(raw)
232
  text = _transcribe_wav_bytes(wav)
233
  proba = text_clf.predict_proba([text])[0].tolist()
234
+ up = float(proba[AUDIO_UNSAFE_INDEX])
235
+ return {"safe": up < AUDIO_UNSAFE_THR, "unsafe_prob": up, "text": text, "probs": proba}
236
 
237
+ # ------------------------------ Routes (under /api) ------------------------------
238
+ @app.get("/api/health")
239
  def health():
240
  return {
241
  "ok": True,
 
251
  "safe_unsafe_indices(text_model)": {"SAFE_ID": SAFE_ID, "UNSAFE_ID": UNSAFE_ID},
252
  }
253
 
254
+ @app.post("/api/check_text")
255
  def check_text(text: str = Form(...)):
256
  if not text or not text.strip():
257
  raise HTTPException(400, "Empty text")
 
260
  except Exception as e:
261
  raise HTTPException(500, f"Text screening error: {e}")
262
 
263
+ @app.post("/api/check_image")
264
  async def check_image(file: UploadFile = File(...)):
265
  data = await file.read()
266
  if not data:
 
274
  except Exception as e:
275
  raise HTTPException(500, f"Image screening error: {e}")
276
 
277
+ @app.post("/api/check_audio")
278
  async def check_audio(file: UploadFile = File(...)):
279
  raw = await file.read()
280
  if not raw:
 
287
  raise HTTPException(500, f"Audio processing error: {e}")
288
 
289
  # --------------------------- Static Mount ---------------------------
290
+ # Serve UI from /static if present; otherwise from repo root (index.html at root).
291
  static_dir = BASE / "static"
292
  root_index = BASE / "index.html"
293