MetiMiester committed on
Commit
d4336c3
·
verified ·
1 Parent(s): f8924f9

Update app_server.py

Browse files
Files changed (1) hide show
  1. app_server.py +88 -189
app_server.py CHANGED
@@ -1,76 +1,57 @@
1
- # app_server.py — BubbleGuard API + Dating-style Web Chat (Static UI)
2
  # Version: 1.7.1 (/api/* routes + repo-root UI support)
3
 
4
  import io, os, re, uuid, pathlib, tempfile, subprocess, unicodedata
5
  from typing import Dict, Optional
6
-
7
  from fastapi import FastAPI, UploadFile, File, Form, HTTPException
8
  from fastapi.middleware.cors import CORSMiddleware
9
  from fastapi.staticfiles import StaticFiles
10
  from fastapi.responses import PlainTextResponse
11
-
12
  import torch, joblib, torchvision
13
  from torchvision import transforms
14
  from transformers import RobertaTokenizerFast, AutoModelForSequenceClassification
15
  from PIL import Image
16
  from faster_whisper import WhisperModel
17
 
18
- # -------------------------- Paths & Config --------------------------
19
  BASE = pathlib.Path(__file__).resolve().parent
20
  TEXT_DIR = BASE / "Text"
21
  IMG_DIR = BASE / "Image"
22
  AUD_DIR = BASE / "Audio"
23
- STATIC_DIR = BASE / "static"
24
 
25
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
26
-
27
  IMG_UNSAFE_THR = float(os.getenv("IMG_UNSAFE_THR", "0.5"))
28
  IMG_UNSAFE_INDEX = int(os.getenv("IMG_UNSAFE_INDEX", "1"))
29
-
30
  WHISPER_MODEL_NAME = os.getenv("WHISPER_MODEL", "base")
31
-
32
  TEXT_UNSAFE_THR = float(os.getenv("TEXT_UNSAFE_THR", "0.60"))
33
  SHORT_MSG_MAX_TOKENS = int(os.getenv("SHORT_MSG_MAX_TOKENS", "6"))
34
  SHORT_MSG_UNSAFE_THR = float(os.getenv("SHORT_MSG_UNSAFE_THR", "0.90"))
35
-
36
  AUDIO_UNSAFE_INDEX = int(os.getenv("AUDIO_UNSAFE_INDEX", "1"))
37
  AUDIO_UNSAFE_THR = float(os.getenv("AUDIO_UNSAFE_THR", "0.50"))
38
 
39
  app = FastAPI(title="BubbleGuard API", version="1.7.1")
 
40
 
41
- app.add_middleware(
42
- CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"]
43
- )
44
-
45
- # -------------------------- Text Classifier -------------------------
46
- if not TEXT_DIR.exists():
47
- raise RuntimeError(f"Text model dir not found: {TEXT_DIR}. Run download_assets first.")
48
-
49
  tok = RobertaTokenizerFast.from_pretrained(TEXT_DIR, local_files_only=True)
50
  txtM = AutoModelForSequenceClassification.from_pretrained(TEXT_DIR, local_files_only=True).to(DEVICE).eval()
51
 
52
- # -------- Label mapping (robust) --------
53
- SAFE_LABEL_HINTS = {"safe", "ok", "clean", "benign", "non-toxic", "non_toxic", "non toxic"}
54
- UNSAFE_LABEL_HINTS = {"unsafe", "toxic", "abuse", "harm", "offense", "nsfw", "not_safe", "not safe"}
55
 
56
- def _infer_ids_by_name(model) -> (Optional[int], Optional[int]):
57
  try:
58
- id2label = getattr(model.config, "id2label", None)
59
- if not isinstance(id2label, dict):
60
- return None, None
61
  norm = {}
62
  for k, v in id2label.items():
63
- try:
64
- ki = int(k)
65
- except Exception:
66
- try:
67
- ki = int(str(k).strip())
68
- except Exception:
69
- continue
70
  norm[ki] = str(v).lower()
71
  s = u = None
72
  for i, name in norm.items():
73
- if any(h in name for h in SAFE_LABEL_HINTS): s = i
74
  if any(h in name for h in UNSAFE_LABEL_HINTS): u = i
75
  if s is not None and u is None: u = 1 - s
76
  if u is not None and s is None: s = 1 - u
@@ -79,148 +60,92 @@ def _infer_ids_by_name(model) -> (Optional[int], Optional[int]):
79
  return None, None
80
 
81
  @torch.no_grad()
82
- def _infer_ids_by_probe(model, tok, device) -> (int, int):
83
- samples = ["hi", "hello", "how are you", "nice to meet you", "thanks"]
84
- enc = tok(samples, return_tensors="pt", truncation=True, padding=True, max_length=64)
85
- enc = {k: v.to(device) for k, v in enc.items()}
86
  probs = torch.softmax(model(**enc).logits, dim=-1).mean(0)
87
- safe_idx = int(torch.argmax(probs).item()); unsafe_idx = 1 - safe_idx
88
- return safe_idx, unsafe_idx
89
 
90
- def _resolve_safe_unsafe_ids(model, tok, device) -> (int, int):
91
- s_env = os.getenv("SAFE_ID"); u_env = os.getenv("UNSAFE_ID")
92
- if s_env is not None and u_env is not None:
93
- return int(s_env), int(u_env)
94
  s, u = _infer_ids_by_name(model)
95
- if s is not None and u is not None:
96
- return s, u
97
- return _infer_ids_by_probe(model, tok, device)
98
 
99
- SAFE_ID, UNSAFE_ID = _resolve_safe_unsafe_ids(txtM, tok, DEVICE)
100
- print(f"[BubbleGuard] SAFE_ID={SAFE_ID} UNSAFE_ID={UNSAFE_ID} id2label={getattr(txtM.config, 'id2label', None)}")
101
 
102
- # ------------------------ Text utils ------------------------
103
- def normalize(text: str) -> str:
104
- if not isinstance(text, str): return ""
105
- t = unicodedata.normalize("NFKC", text)
106
- t = t.replace("’","'").replace("‘","'").replace("“",'"').replace("”",'"')
107
- t = re.sub(r"[^a-z0-9\s']", " ", t.lower())
108
- return re.sub(r"\s+", " ", t).strip()
109
 
110
- SAFE_PHRASES = [r"^i don'?t$", r"^i do not$", r"^don'?t$", r"^no$", r"^not really$", r"^i woulde?n'?t$", r"^i don'?t like$"]
111
- SAFE_RE = re.compile("|".join(SAFE_PHRASES))
112
  NEGATION_ONLY = re.compile(r"^(?:i\s+)?(?:do\s+not|don'?t|no|not)$")
113
  NEUTRAL_DISLIKE = re.compile(r"^i don'?t like(?:\s+to)?\b")
114
-
115
- SENSITIVE_TERMS = {"people","you","him","her","them","men","women","girls","boys",
116
- "muslim","christian","jew","jews","black","white","asian",
117
- "gay","lesbian","trans","transgender","disabled",
118
- "immigrants","refugees","poor","old","elderly","fat","skinny"}
119
  PROFANITY_TERMS = {"fuck","shit","bitch","pussy","dick","cunt","slut","whore"}
120
-
121
- GREETINGS = [r"^hi$", r"^hello$", r"^hey(?: there)?$", r"^how are (?:you|u)\b.*$",
122
- r"^good (?:morning|afternoon|evening)\b.*$", r"^what'?s up\b.*$", r"^how'?s it going\b.*$"]
123
- GREETING_RE = re.compile("|".join(GREETINGS))
124
 
125
  @torch.no_grad()
126
  def text_safe_payload(text: str) -> Dict:
127
  clean = normalize(text); toks = clean.split()
128
-
129
- if len(toks) == 1 and toks[0] in PROFANITY_TERMS:
130
- p = [0.0,0.0]; p[UNSAFE_ID]=1.0
131
- return {"safe":False,"unsafe_prob":1.0,"label":"UNSAFE","probs":p,"tokens":1,"reason":"profanity_single_word"}
132
-
133
- if len(toks) <= SHORT_MSG_MAX_TOKENS and any(t in PROFANITY_TERMS for t in toks):
134
- p = [0.0,0.0]; p[UNSAFE_ID]=1.0
135
- return {"safe":False,"unsafe_prob":1.0,"label":"UNSAFE","probs":p,"tokens":len(toks),"reason":"profanity_short_text"}
136
-
137
  if SAFE_RE.match(clean) or NEGATION_ONLY.match(clean) or GREETING_RE.match(clean):
138
- p=[0.0,0.0]; p[SAFE_ID]=1.0
139
- return {"safe":True,"unsafe_prob":0.0,"label":"SAFE","probs":p,"tokens":len(toks),"reason":"allow_or_greeting"}
140
-
141
  if NEUTRAL_DISLIKE.match(clean):
142
  if not any(t in clean for t in SENSITIVE_TERMS) and not any(t in clean for t in PROFANITY_TERMS):
143
- enc = tok(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
144
- enc = {k: v.to(DEVICE) for k, v in enc.items()}
145
- probs = torch.softmax(txtM(**enc).logits[0], dim=-1).cpu().tolist()
146
- up = float(probs[UNSAFE_ID]); safe = up < 0.98
147
- return {"safe":bool(safe),"unsafe_prob":up,"label":"SAFE" if safe else "UNSAFE",
148
- "probs":probs,"tokens":int(enc["input_ids"].shape[1]),"reason":"neutral_dislike_relaxed"}
149
-
150
- enc = tok(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
151
- enc = {k: v.to(DEVICE) for k, v in enc.items()}
152
- logits = txtM(**enc).logits[0]
153
- probs = torch.softmax(logits, dim=-1).cpu().tolist()
154
- up = float(probs[UNSAFE_ID]); toks = int(enc["input_ids"].shape[1])
155
-
156
- safe = up < (SHORT_MSG_UNSAFE_THR if toks <= SHORT_MSG_MAX_TOKENS else TEXT_UNSAFE_THR)
157
- return {"safe":bool(safe),"unsafe_prob":up,"label":str(int(torch.argmax(logits))),
158
- "probs":probs,"tokens":toks,"reason":"short_msg_threshold" if toks<=SHORT_MSG_MAX_TOKENS else "global_threshold"}
159
-
160
- # -------------------------- Image Classifier ------------------------
161
  class SafetyResNet(torch.nn.Module):
162
  def __init__(self):
163
  super().__init__()
164
  base = torchvision.models.resnet50(weights=None)
165
  self.feature_extractor = torch.nn.Sequential(*list(base.children())[:8])
166
  self.pool = torch.nn.AdaptiveAvgPool2d(1)
167
- self.classifier = torch.nn.Sequential(
168
- torch.nn.Linear(2048, 512), torch.nn.ReLU(True), torch.nn.Dropout(0.30), torch.nn.Linear(512, 2)
169
- )
170
- def forward(self, x):
171
- x = self.pool(self.feature_extractor(x))
172
- return self.classifier(torch.flatten(x, 1))
173
 
174
- if not IMG_DIR.exists():
175
- raise RuntimeError(f"Image model dir not found: {IMG_DIR}. Run download_assets first.")
176
-
177
- imgM = SafetyResNet().to(DEVICE)
178
- imgM.load_state_dict(torch.load(IMG_DIR / "resnet_safety_classifier.pth", map_location=DEVICE), strict=True)
179
- imgM.eval()
180
-
181
- img_tf = transforms.Compose([
182
- transforms.Resize(256, interpolation=transforms.InterpolationMode.BILINEAR),
183
- transforms.CenterCrop(224),
184
- transforms.ToTensor(),
185
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
186
- ])
187
 
188
  @torch.no_grad()
189
  def image_safe_payload(pil: Image.Image) -> Dict:
190
  x = img_tf(pil.convert("RGB")).unsqueeze(0).to(DEVICE)
191
  probs = torch.softmax(imgM(x)[0], dim=0).cpu().tolist()
192
- up = float(probs[IMG_UNSAFE_INDEX])
193
- return {"safe": up < IMG_UNSAFE_THR, "unsafe_prob": up, "probs": probs}
194
 
195
- # -------------------------- Audio (ASR -> NLP) ----------------------
196
- compute_type = "float16" if DEVICE == "cuda" else "int8"
197
  asr = WhisperModel(WHISPER_MODEL_NAME, device=DEVICE, compute_type=compute_type)
 
 
198
 
199
- if not AUD_DIR.exists():
200
- raise RuntimeError(f"Audio pipeline dir not found: {AUD_DIR}. Run download_assets first.")
201
- text_clf = joblib.load(AUD_DIR / "text_pipeline_balanced.joblib")
202
-
203
- def _ffmpeg_to_wav(src_bytes: bytes) -> bytes:
204
  with tempfile.TemporaryDirectory() as td:
205
- in_path = pathlib.Path(td) / f"in-{uuid.uuid4().hex}.bin"
206
- out_path = pathlib.Path(td) / "out.wav"
207
- in_path.write_bytes(src_bytes)
208
- cmd = ["ffmpeg","-y","-i",str(in_path),"-ac","1","-ar","16000",str(out_path)]
209
  try:
210
- subprocess.run(cmd, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
211
- return out_path.read_bytes()
212
- except FileNotFoundError as e:
213
- raise RuntimeError("FFmpeg not found on PATH.") from e
214
- except subprocess.CalledProcessError:
215
- return src_bytes
216
 
217
- def _transcribe_wav_bytes(wav_bytes: bytes) -> str:
218
- td = tempfile.mkdtemp()
219
- p = pathlib.Path(td) / "in.wav"
220
  try:
221
- p.write_bytes(wav_bytes)
222
- segments, _ = asr.transcribe(str(p), beam_size=5, language="en")
223
- return " ".join(s.text for s in segments).strip()
224
  finally:
225
  try: p.unlink(missing_ok=True)
226
  except Exception: pass
@@ -228,74 +153,48 @@ def _transcribe_wav_bytes(wav_bytes: bytes) -> str:
228
  except Exception: pass
229
 
230
  def audio_safe_from_bytes(raw: bytes) -> Dict:
231
- wav = _ffmpeg_to_wav(raw)
232
- text = _transcribe_wav_bytes(wav)
233
- proba = text_clf.predict_proba([text])[0].tolist()
234
- up = float(proba[AUDIO_UNSAFE_INDEX])
235
- return {"safe": up < AUDIO_UNSAFE_THR, "unsafe_prob": up, "text": text, "probs": proba}
236
 
237
- # ------------------------------ Routes (under /api) ------------------------------
238
  @app.get("/api/health")
239
  def health():
240
- return {
241
- "ok": True,
242
- "device": DEVICE,
243
- "whisper": WHISPER_MODEL_NAME,
244
- "img": {"unsafe_threshold": IMG_UNSAFE_THR, "unsafe_index": IMG_UNSAFE_INDEX},
245
- "text_thresholds": {
246
- "TEXT_UNSAFE_THR": TEXT_UNSAFE_THR,
247
- "SHORT_MSG_MAX_TOKENS": SHORT_MSG_MAX_TOKENS,
248
- "SHORT_MSG_UNSAFE_THR": SHORT_MSG_UNSAFE_THR,
249
- },
250
- "audio": {"unsafe_index": AUDIO_UNSAFE_INDEX, "unsafe_threshold": AUDIO_UNSAFE_THR},
251
- "safe_unsafe_indices(text_model)": {"SAFE_ID": SAFE_ID, "UNSAFE_ID": UNSAFE_ID},
252
- }
253
 
254
  @app.post("/api/check_text")
255
  def check_text(text: str = Form(...)):
256
- if not text or not text.strip():
257
- raise HTTPException(400, "Empty text")
258
- try:
259
- return text_safe_payload(text)
260
- except Exception as e:
261
- raise HTTPException(500, f"Text screening error: {e}")
262
 
263
  @app.post("/api/check_image")
264
  async def check_image(file: UploadFile = File(...)):
265
  data = await file.read()
266
- if not data:
267
- raise HTTPException(400, "Empty image")
268
- try:
269
- pil = Image.open(io.BytesIO(data))
270
- except Exception:
271
- raise HTTPException(400, "Invalid image")
272
- try:
273
- return image_safe_payload(pil)
274
- except Exception as e:
275
- raise HTTPException(500, f"Image screening error: {e}")
276
 
277
  @app.post("/api/check_audio")
278
  async def check_audio(file: UploadFile = File(...)):
279
  raw = await file.read()
280
- if not raw:
281
- raise HTTPException(400, "Empty audio")
282
- try:
283
- return audio_safe_from_bytes(raw)
284
- except RuntimeError as e:
285
- raise HTTPException(500, f"{e}")
286
- except Exception as e:
287
- raise HTTPException(500, f"Audio processing error: {e}")
288
 
289
- # --------------------------- Static Mount ---------------------------
290
- # Serve UI from /static if present; otherwise from repo root (index.html at root).
291
  static_dir = BASE / "static"
292
- root_index = BASE / "index.html"
293
-
294
  if static_dir.exists():
295
  app.mount("/", StaticFiles(directory=str(static_dir), html=True), name="static")
296
- elif root_index.exists():
297
  app.mount("/", StaticFiles(directory=str(BASE), html=True), name="static-root")
298
  else:
299
  @app.get("/", response_class=PlainTextResponse)
300
- def _root_fallback():
301
- return "BubbleGuard API is running. Put index.html at repo root or add a 'static/' folder."
 
1
+ # app_server.py — BubbleGuard API + Web UI
2
  # Version: 1.7.1 (/api/* routes + repo-root UI support)
3
 
4
  import io, os, re, uuid, pathlib, tempfile, subprocess, unicodedata
5
  from typing import Dict, Optional
 
6
  from fastapi import FastAPI, UploadFile, File, Form, HTTPException
7
  from fastapi.middleware.cors import CORSMiddleware
8
  from fastapi.staticfiles import StaticFiles
9
  from fastapi.responses import PlainTextResponse
 
10
  import torch, joblib, torchvision
11
  from torchvision import transforms
12
  from transformers import RobertaTokenizerFast, AutoModelForSequenceClassification
13
  from PIL import Image
14
  from faster_whisper import WhisperModel
15
 
 
16
  BASE = pathlib.Path(__file__).resolve().parent
17
  TEXT_DIR = BASE / "Text"
18
  IMG_DIR = BASE / "Image"
19
  AUD_DIR = BASE / "Audio"
 
20
 
21
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
22
  IMG_UNSAFE_THR = float(os.getenv("IMG_UNSAFE_THR", "0.5"))
23
  IMG_UNSAFE_INDEX = int(os.getenv("IMG_UNSAFE_INDEX", "1"))
 
24
  WHISPER_MODEL_NAME = os.getenv("WHISPER_MODEL", "base")
 
25
  TEXT_UNSAFE_THR = float(os.getenv("TEXT_UNSAFE_THR", "0.60"))
26
  SHORT_MSG_MAX_TOKENS = int(os.getenv("SHORT_MSG_MAX_TOKENS", "6"))
27
  SHORT_MSG_UNSAFE_THR = float(os.getenv("SHORT_MSG_UNSAFE_THR", "0.90"))
 
28
  AUDIO_UNSAFE_INDEX = int(os.getenv("AUDIO_UNSAFE_INDEX", "1"))
29
  AUDIO_UNSAFE_THR = float(os.getenv("AUDIO_UNSAFE_THR", "0.50"))
30
 
31
  app = FastAPI(title="BubbleGuard API", version="1.7.1")
32
+ app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
33
 
34
+ # ---------- Text model ----------
35
+ if not TEXT_DIR.exists(): raise RuntimeError(f"Missing Text dir: {TEXT_DIR}")
 
 
 
 
 
 
36
  tok = RobertaTokenizerFast.from_pretrained(TEXT_DIR, local_files_only=True)
37
  txtM = AutoModelForSequenceClassification.from_pretrained(TEXT_DIR, local_files_only=True).to(DEVICE).eval()
38
 
39
+ SAFE_LABEL_HINTS = {"safe","ok","clean","benign","non-toxic","non_toxic","non toxic"}
40
+ UNSAFE_LABEL_HINTS = {"unsafe","toxic","abuse","harm","offense","nsfw","not_safe","not safe"}
 
41
 
42
+ def _infer_ids_by_name(model):
43
  try:
44
+ id2label = getattr(model.config, "id2label", {})
 
 
45
  norm = {}
46
  for k, v in id2label.items():
47
+ try: ki = int(k)
48
+ except Exception:
49
+ try: ki = int(str(k).strip())
50
+ except Exception: continue
 
 
 
51
  norm[ki] = str(v).lower()
52
  s = u = None
53
  for i, name in norm.items():
54
+ if any(h in name for h in SAFE_LABEL_HINTS): s = i
55
  if any(h in name for h in UNSAFE_LABEL_HINTS): u = i
56
  if s is not None and u is None: u = 1 - s
57
  if u is not None and s is None: s = 1 - u
 
60
  return None, None
61
 
62
@torch.no_grad()
def _infer_ids_by_probe(model, tok, device):
    """Guess (safe_id, unsafe_id) empirically: classify a handful of
    obviously benign probe sentences and take the winning class as SAFE.
    Assumes a binary classifier, so the other index is UNSAFE."""
    benign = ["hi", "hello", "how are you", "nice to meet you", "thanks"]
    batch = tok(benign, return_tensors="pt", truncation=True, padding=True, max_length=64)
    batch = {key: tensor.to(device) for key, tensor in batch.items()}
    # Average class probabilities over the whole probe batch.
    mean_probs = torch.softmax(model(**batch).logits, dim=-1).mean(dim=0)
    safe_id = int(torch.argmax(mean_probs))
    return safe_id, 1 - safe_id
 
68
 
69
def _resolve_ids(model, tok, device):
    """Resolve (SAFE_ID, UNSAFE_ID) for the text model, in priority order:
    explicit env override, the model's id2label names, then a runtime probe."""
    env_safe = os.getenv("SAFE_ID")
    env_unsafe = os.getenv("UNSAFE_ID")
    if env_safe is not None and env_unsafe is not None:
        # Operator override always wins.
        return int(env_safe), int(env_unsafe)
    safe_id, unsafe_id = _infer_ids_by_name(model)
    if safe_id is not None and unsafe_id is not None:
        return safe_id, unsafe_id
    # Last resort: probe the model with benign sentences.
    return _infer_ids_by_probe(model, tok, device)
 
 
74
 
75
+ SAFE_ID, UNSAFE_ID = _resolve_ids(txtM, tok, DEVICE)
76
+ print(f"[BubbleGuard] SAFE_ID={SAFE_ID} UNSAFE_ID={UNSAFE_ID} id2label={getattr(txtM.config,'id2label',None)}")
77
 
78
def normalize(t: str) -> str:
    """Normalize a message for the rule-based filters: NFKC-fold, map curly
    quotes to ASCII, lowercase, and keep only [a-z0-9'] plus single spaces.

    BUG FIX: the previous revision called ``.replace("", "'")`` — replacing
    the *empty* string inserts an apostrophe between every character
    (``"hi"`` -> ``"'h'i'"``).  The first argument must be the right single
    quotation mark U+2019, as in the pre-refactor version.
    """
    if not isinstance(t, str):
        return ""
    t = unicodedata.normalize("NFKC", t)
    # Curly quotes -> ASCII so the contraction regexes (e.g. r"don'?t") match.
    t = t.replace("\u2019", "'").replace("\u2018", "'").replace("\u201c", '"').replace("\u201d", '"')
    t = re.sub(r"[^a-z0-9\s']", " ", t.lower())
    return re.sub(r"\s+", " ", t).strip()
 
 
 
82
 
83
+ SAFE_RE = re.compile("|".join([r"^i don'?t$", r"^i do not$", r"^don'?t$", r"^no$", r"^not really$", r"^i woulde?n'?t$", r"^i don'?t like$"]))
 
84
  NEGATION_ONLY = re.compile(r"^(?:i\s+)?(?:do\s+not|don'?t|no|not)$")
85
  NEUTRAL_DISLIKE = re.compile(r"^i don'?t like(?:\s+to)?\b")
86
+ SENSITIVE_TERMS = {"people","you","him","her","them","men","women","girls","boys","muslim","christian","jew","jews","black","white","asian","gay","lesbian","trans","transgender","disabled","immigrants","refugees","poor","old","elderly","fat","skinny"}
 
 
 
 
87
  PROFANITY_TERMS = {"fuck","shit","bitch","pussy","dick","cunt","slut","whore"}
88
+ GREETING_RE = re.compile("|".join([r"^hi$", r"^hello$", r"^hey(?: there)?$", r"^how are (?:you|u)\b.*$", r"^good (?:morning|afternoon|evening)\b.*$", r"^what'?s up\b.*$", r"^how'?s it going\b.*$"]))
 
 
 
89
 
90
@torch.no_grad()
def text_safe_payload(text: str) -> Dict:
    """Screen a chat message and return a verdict payload with keys:
    safe (bool), unsafe_prob (float), label, probs, tokens, reason.

    Check order: profanity shortcuts, allow-list/greeting shortcuts, relaxed
    handling of neutral "i don't like ..." phrasing, then the RoBERTa
    classifier with a short-message vs. global threshold.

    FIX: the shortcut paths previously built ``p=[0,0]`` (ints) and assigned
    ``1.`` into one slot, producing a mixed int/float ``probs`` list in the
    JSON response; restored to all-float, consistent with the model paths.
    """
    clean = normalize(text)
    toks = clean.split()

    # Hard profanity gates — skip the model entirely.
    if len(toks) == 1 and toks[0] in PROFANITY_TERMS:
        p = [0.0, 0.0]
        p[UNSAFE_ID] = 1.0
        return {"safe": False, "unsafe_prob": 1.0, "label": "UNSAFE", "probs": p,
                "tokens": 1, "reason": "profanity_single_word"}
    if len(toks) <= SHORT_MSG_MAX_TOKENS and any(t in PROFANITY_TERMS for t in toks):
        p = [0.0, 0.0]
        p[UNSAFE_ID] = 1.0
        return {"safe": False, "unsafe_prob": 1.0, "label": "UNSAFE", "probs": p,
                "tokens": len(toks), "reason": "profanity_short_text"}

    # Benign shortcuts: bare negations and greetings are always safe.
    if SAFE_RE.match(clean) or NEGATION_ONLY.match(clean) or GREETING_RE.match(clean):
        p = [0.0, 0.0]
        p[SAFE_ID] = 1.0
        return {"safe": True, "unsafe_prob": 0.0, "label": "SAFE", "probs": p,
                "tokens": len(toks), "reason": "allow_or_greeting"}

    # "i don't like ..." with no sensitive target and no profanity: only
    # flag when the model is near-certain (unsafe prob >= 0.98).
    if NEUTRAL_DISLIKE.match(clean):
        if not any(t in clean for t in SENSITIVE_TERMS) and not any(t in clean for t in PROFANITY_TERMS):
            enc = tok(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
            enc = {k: v.to(DEVICE) for k, v in enc.items()}
            probs = torch.softmax(txtM(**enc).logits[0], dim=-1).cpu().tolist()
            up = float(probs[UNSAFE_ID])
            safe = up < 0.98
            return {"safe": safe, "unsafe_prob": up,
                    "label": "SAFE" if safe else "UNSAFE", "probs": probs,
                    "tokens": int(enc["input_ids"].shape[1]),
                    "reason": "neutral_dislike_relaxed"}

    # Default path: RoBERTa with a stricter threshold for short messages.
    enc = tok(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
    enc = {k: v.to(DEVICE) for k, v in enc.items()}
    logits = txtM(**enc).logits[0]
    probs = torch.softmax(logits, dim=-1).cpu().tolist()
    up = float(probs[UNSAFE_ID])
    n = int(enc["input_ids"].shape[1])
    thr = SHORT_MSG_UNSAFE_THR if n <= SHORT_MSG_MAX_TOKENS else TEXT_UNSAFE_THR
    # NOTE(review): label here is the argmax index as a string ("0"/"1"),
    # unlike the "SAFE"/"UNSAFE" strings above — preserved from the original.
    return {"safe": up < thr, "unsafe_prob": up,
            "label": str(int(torch.argmax(logits))), "probs": probs, "tokens": n,
            "reason": "short_msg_threshold" if n <= SHORT_MSG_MAX_TOKENS else "global_threshold"}
108
+
109
+ # ---------- Image ----------
 
 
 
 
 
 
 
 
 
110
class SafetyResNet(torch.nn.Module):
    """ResNet-50 backbone with a 2-way safety classification head.

    FIX: the head attribute must be named ``classifier`` (not ``cls``) so its
    parameters serialize as ``classifier.*`` — the key names the
    ``resnet_safety_classifier.pth`` checkpoint was saved with (the previous
    revision used ``classifier``); it is loaded with ``strict=True``, which
    would raise on the renamed keys.
    """

    def __init__(self):
        super().__init__()
        base = torchvision.models.resnet50(weights=None)
        # Keep layers up to and including layer4; drop avgpool/fc.
        self.feature_extractor = torch.nn.Sequential(*list(base.children())[:8])
        self.pool = torch.nn.AdaptiveAvgPool2d(1)
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(2048, 512),
            torch.nn.ReLU(True),
            torch.nn.Dropout(0.30),
            torch.nn.Linear(512, 2),
        )

    def forward(self, x):
        feats = self.pool(self.feature_extractor(x))
        return self.classifier(torch.flatten(feats, 1))
 
 
 
 
118
 
119
+ if not IMG_DIR.exists(): raise RuntimeError(f"Missing Image dir: {IMG_DIR}")
120
+ imgM = SafetyResNet().to(DEVICE); imgM.load_state_dict(torch.load(IMG_DIR/"resnet_safety_classifier.pth", map_location=DEVICE), strict=True); imgM.eval()
121
+ img_tf = transforms.Compose([ transforms.Resize(256, interpolation=transforms.InterpolationMode.BILINEAR), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]) ])
 
 
 
 
 
 
 
 
 
 
122
 
123
@torch.no_grad()
def image_safe_payload(pil: Image.Image) -> Dict:
    """Classify a PIL image; safe iff the unsafe-class probability is
    below IMG_UNSAFE_THR."""
    batch = img_tf(pil.convert("RGB")).unsqueeze(0).to(DEVICE)
    probs = torch.softmax(imgM(batch)[0], dim=0).cpu().tolist()
    unsafe_prob = float(probs[IMG_UNSAFE_INDEX])
    return {
        "safe": unsafe_prob < IMG_UNSAFE_THR,
        "unsafe_prob": unsafe_prob,
        "probs": probs,
    }
 
128
 
129
+ # ---------- Audio ----------
130
+ compute_type = "float16" if DEVICE=="cuda" else "int8"
131
  asr = WhisperModel(WHISPER_MODEL_NAME, device=DEVICE, compute_type=compute_type)
132
+ if not AUD_DIR.exists(): raise RuntimeError(f"Missing Audio dir: {AUD_DIR}")
133
+ text_clf = joblib.load(AUD_DIR/"text_pipeline_balanced.joblib")
134
 
135
def _ffmpeg_to_wav(src: bytes) -> bytes:
    """Best-effort transcode of arbitrary audio bytes to 16 kHz mono WAV.

    Returns the original bytes unchanged when ffmpeg rejects the input;
    raises RuntimeError when the ffmpeg binary is missing from PATH.
    """
    with tempfile.TemporaryDirectory() as workdir:
        src_path = pathlib.Path(workdir) / "in"
        dst_path = pathlib.Path(workdir) / "out.wav"
        src_path.write_bytes(src)
        cmd = ["ffmpeg", "-y", "-i", str(src_path), "-ac", "1", "-ar", "16000", str(dst_path)]
        try:
            subprocess.run(cmd, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
        except FileNotFoundError as e:
            raise RuntimeError("FFmpeg not found on PATH.") from e
        except subprocess.CalledProcessError:
            # Unrecognized container/codec: let the caller try the raw bytes.
            return src
        return dst_path.read_bytes()
 
 
143
 
144
+ def _transcribe_wav_bytes(w: bytes) -> str:
145
+ td = tempfile.mkdtemp(); p = pathlib.Path(td)/"in.wav"
 
146
  try:
147
+ p.write_bytes(w); segs,_ = asr.transcribe(str(p), beam_size=5, language="en")
148
+ return " ".join(s.text for s in segs).strip()
 
149
  finally:
150
  try: p.unlink(missing_ok=True)
151
  except Exception: pass
 
153
  except Exception: pass
154
 
155
def audio_safe_from_bytes(raw: bytes) -> Dict:
    """Transcode, transcribe, then score the transcript with the text
    pipeline; safe iff unsafe prob is below AUDIO_UNSAFE_THR."""
    transcript = _transcribe_wav_bytes(_ffmpeg_to_wav(raw))
    proba = text_clf.predict_proba([transcript])[0].tolist()
    unsafe_prob = float(proba[AUDIO_UNSAFE_INDEX])
    return {
        "safe": unsafe_prob < AUDIO_UNSAFE_THR,
        "unsafe_prob": unsafe_prob,
        "text": transcript,
        "probs": proba,
    }
 
 
159
 
160
+ # ---------- Routes (/api/*) ----------
161
@app.get("/api/health")
def health():
    """Report runtime device, Whisper model, and every active threshold."""
    return {
        "ok": True,
        "device": DEVICE,
        "whisper": WHISPER_MODEL_NAME,
        "img": {
            "unsafe_threshold": IMG_UNSAFE_THR,
            "unsafe_index": IMG_UNSAFE_INDEX,
        },
        "text_thresholds": {
            "TEXT_UNSAFE_THR": TEXT_UNSAFE_THR,
            "SHORT_MSG_MAX_TOKENS": SHORT_MSG_MAX_TOKENS,
            "SHORT_MSG_UNSAFE_THR": SHORT_MSG_UNSAFE_THR,
        },
        "audio": {
            "unsafe_index": AUDIO_UNSAFE_INDEX,
            "unsafe_threshold": AUDIO_UNSAFE_THR,
        },
        "safe_unsafe_indices(text_model)": {"SAFE_ID": SAFE_ID, "UNSAFE_ID": UNSAFE_ID},
    }
 
 
 
 
 
 
 
 
168
 
169
@app.post("/api/check_text")
def check_text(text: str = Form(...)):
    """Screen a text message; 400 on blank input, 500 on screening failure."""
    if not text.strip():
        raise HTTPException(400, "Empty text")
    try:
        return text_safe_payload(text)
    except Exception as e:
        raise HTTPException(500, f"Text screening error: {e}")
 
 
 
174
 
175
@app.post("/api/check_image")
async def check_image(file: UploadFile = File(...)):
    """Screen an uploaded image; 400 on empty/undecodable payloads,
    500 on classifier failure."""
    data = await file.read()
    if not data:
        raise HTTPException(400, "Empty image")
    try:
        pil = Image.open(io.BytesIO(data))
    except Exception:
        raise HTTPException(400, "Invalid image")
    try:
        return image_safe_payload(pil)
    except Exception as e:
        raise HTTPException(500, f"Image screening error: {e}")
 
 
 
 
 
183
 
184
@app.post("/api/check_audio")
async def check_audio(file: UploadFile = File(...)):
    """Screen an uploaded audio clip; 400 on empty payloads, 500 on
    transcode/ASR/classifier failure."""
    raw = await file.read()
    if not raw:
        raise HTTPException(400, "Empty audio")
    try:
        return audio_safe_from_bytes(raw)
    except RuntimeError as e:
        # e.g. ffmpeg missing from PATH — surface the message verbatim.
        raise HTTPException(500, f"{e}")
    except Exception as e:
        raise HTTPException(500, f"Audio processing error: {e}")
 
 
 
 
191
 
192
# ---------- Static ----------
# Serve the UI from ./static when present, else from the repo root when an
# index.html exists there, else expose a plain-text fallback at "/".
static_dir = BASE / "static"
root_index = BASE / "index.html"
if static_dir.exists():
    app.mount("/", StaticFiles(directory=str(static_dir), html=True), name="static")
elif root_index.exists():
    app.mount("/", StaticFiles(directory=str(BASE), html=True), name="static-root")
else:
    @app.get("/", response_class=PlainTextResponse)
    def _root_fallback():
        return "BubbleGuard API is running. Add index.html to repo root."