# app_server.py — BubbleGuard API + Web UI
# Version: 1.7.1 (/api/* routes + repo-root UI support)
import io, os, re, uuid, pathlib, tempfile, subprocess, unicodedata
from typing import Dict, Optional

from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import PlainTextResponse

import torch, joblib, torchvision
from torchvision import transforms
from transformers import RobertaTokenizerFast, AutoModelForSequenceClassification
from PIL import Image
from faster_whisper import WhisperModel

# Model artifacts live next to this file in Text/, Image/, Audio/ subfolders.
BASE = pathlib.Path(__file__).resolve().parent
TEXT_DIR = BASE / "Text"
IMG_DIR = BASE / "Image"
AUD_DIR = BASE / "Audio"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# All thresholds/indices are env-tunable so deployments can re-calibrate
# without a code change.
IMG_UNSAFE_THR = float(os.getenv("IMG_UNSAFE_THR", "0.5"))
IMG_UNSAFE_INDEX = int(os.getenv("IMG_UNSAFE_INDEX", "1"))
WHISPER_MODEL_NAME = os.getenv("WHISPER_MODEL", "base")
TEXT_UNSAFE_THR = float(os.getenv("TEXT_UNSAFE_THR", "0.60"))
SHORT_MSG_MAX_TOKENS = int(os.getenv("SHORT_MSG_MAX_TOKENS", "6"))
# Short messages get a stricter (higher) unsafe threshold: the classifier is
# noisier on very short inputs, so require more confidence before flagging.
SHORT_MSG_UNSAFE_THR = float(os.getenv("SHORT_MSG_UNSAFE_THR", "0.90"))
AUDIO_UNSAFE_INDEX = int(os.getenv("AUDIO_UNSAFE_INDEX", "1"))
AUDIO_UNSAFE_THR = float(os.getenv("AUDIO_UNSAFE_THR", "0.50"))

app = FastAPI(title="BubbleGuard API", version="1.7.1")
# NOTE(review): wildcard CORS (origins/methods/headers all "*") — fine for a
# local dev UI, but confirm this is intended before exposing publicly.
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])

# ---------- Text model ----------
if not TEXT_DIR.exists():
    raise RuntimeError(f"Missing Text dir: {TEXT_DIR}")
tok = RobertaTokenizerFast.from_pretrained(TEXT_DIR, local_files_only=True)
txtM = AutoModelForSequenceClassification.from_pretrained(TEXT_DIR, local_files_only=True).to(DEVICE).eval()

# Substrings used to recognize which class index means "safe" vs "unsafe"
# from the checkpoint's id2label names.
SAFE_LABEL_HINTS = {"safe", "ok", "clean", "benign", "non-toxic", "non_toxic", "non toxic"}
UNSAFE_LABEL_HINTS = {"unsafe", "toxic", "abuse", "harm", "offense", "nsfw", "not_safe", "not safe"}


def _infer_ids_by_name(model):
    """Guess (safe_id, unsafe_id) from the model config's id2label names.

    Returns (None, None) when the labels give no usable hint; if only one
    side is identified, the other is assumed to be its binary complement.
    """
    try:
        id2label = getattr(model.config, "id2label", {})
        norm = {}
        for k, v in id2label.items():
            # Keys may be ints or numeric strings depending on how the
            # config was serialized; skip anything non-numeric.
            try:
                ki = int(k)
            except Exception:
                try:
                    ki = int(str(k).strip())
                except Exception:
                    continue
            norm[ki] = str(v).lower()
        s = u = None
        for i, name in norm.items():
            if any(h in name for h in SAFE_LABEL_HINTS):
                s = i
            if any(h in name for h in UNSAFE_LABEL_HINTS):
                u = i
        # Binary-classifier assumption: the missing side is 1 - the known one.
        if s is not None and u is None:
            u = 1 - s
        if u is not None and s is None:
            s = 1 - u
        return s, u
    except Exception:
        return None, None


@torch.no_grad()
def _infer_ids_by_probe(model, tok, device):
    """Fallback: probe with obviously-benign phrases; the class the model
    prefers on average is taken to be the "safe" index."""
    enc = tok(["hi", "hello", "how are you", "nice to meet you", "thanks"],
              return_tensors="pt", truncation=True, padding=True, max_length=64)
    enc = {k: v.to(device) for k, v in enc.items()}
    probs = torch.softmax(model(**enc).logits, dim=-1).mean(0)
    s = int(torch.argmax(probs))
    return s, 1 - s


def _resolve_ids(model, tok, device):
    """Resolve (SAFE_ID, UNSAFE_ID): env override > label names > probe."""
    s_env, u_env = os.getenv("SAFE_ID"), os.getenv("UNSAFE_ID")
    if s_env is not None and u_env is not None:
        return int(s_env), int(u_env)
    s, u = _infer_ids_by_name(model)
    return (s, u) if (s is not None and u is not None) else _infer_ids_by_probe(model, tok, device)


SAFE_ID, UNSAFE_ID = _resolve_ids(txtM, tok, DEVICE)
print(f"[BubbleGuard] SAFE_ID={SAFE_ID} UNSAFE_ID={UNSAFE_ID} id2label={getattr(txtM.config,'id2label',None)}")


def normalize(t: str) -> str:
    """Lowercase + NFKC-normalize text, straighten curly quotes, and strip
    everything except letters, digits, whitespace and apostrophes."""
    if not isinstance(t, str):
        return ""
    t = unicodedata.normalize("NFKC", t).replace("’", "'").replace("‘", "'").replace("“", '"').replace("”", '"')
    t = re.sub(r"[^a-z0-9\s']", " ", t.lower())
    return re.sub(r"\s+", " ", t).strip()


# Whole-message allow-list: bare negations and simple refusals that the
# classifier tends to misflag.
SAFE_RE = re.compile("|".join([r"^i don'?t$", r"^i do not$", r"^don'?t$", r"^no$", r"^not really$", r"^i woulde?n'?t$", r"^i don'?t like$"]))
NEGATION_ONLY = re.compile(r"^(?:i\s+)?(?:do\s+not|don'?t|no|not)$")
# "i don't like (to) ..." openers get a relaxed threshold downstream.
NEUTRAL_DISLIKE = re.compile(r"^i don'?t like(?:\s+to)?\b")
{"people","you","him","her","them","men","women","girls","boys","muslim","christian","jew","jews","black","white","asian","gay","lesbian","trans","transgender","disabled","immigrants","refugees","poor","old","elderly","fat","skinny"} PROFANITY_TERMS = {"fuck","shit","bitch","pussy","dick","cunt","slut","whore"} GREETING_RE = re.compile("|".join([r"^hi$", r"^hello$", r"^hey(?: there)?$", r"^how are (?:you|u)\b.*$", r"^good (?:morning|afternoon|evening)\b.*$", r"^what'?s up\b.*$", r"^how'?s it going\b.*$"])) @torch.no_grad() def text_safe_payload(text: str) -> Dict: clean = normalize(text); toks = clean.split() if len(toks)==1 and toks[0] in PROFANITY_TERMS: p=[0,0]; p[UNSAFE_ID]=1.; return {"safe":False,"unsafe_prob":1.0,"label":"UNSAFE","probs":p,"tokens":1,"reason":"profanity_single_word"} if len(toks)<=SHORT_MSG_MAX_TOKENS and any(t in PROFANITY_TERMS for t in toks): p=[0,0]; p[UNSAFE_ID]=1.; return {"safe":False,"unsafe_prob":1.0,"label":"UNSAFE","probs":p,"tokens":len(toks),"reason":"profanity_short_text"} if SAFE_RE.match(clean) or NEGATION_ONLY.match(clean) or GREETING_RE.match(clean): p=[0,0]; p[SAFE_ID]=1.; return {"safe":True,"unsafe_prob":0.0,"label":"SAFE","probs":p,"tokens":len(toks),"reason":"allow_or_greeting"} if NEUTRAL_DISLIKE.match(clean): if not any(t in clean for t in SENSITIVE_TERMS) and not any(t in clean for t in PROFANITY_TERMS): enc = tok(text, return_tensors="pt", truncation=True, padding=True, max_length=512); enc = {k:v.to(DEVICE) for k,v in enc.items()} probs = torch.softmax(txtM(**enc).logits[0], dim=-1).cpu().tolist(); up=float(probs[UNSAFE_ID]) return {"safe": up<0.98, "unsafe_prob": up, "label":"SAFE" if up<0.98 else "UNSAFE", "probs": probs, "tokens": int(enc["input_ids"].shape[1]), "reason":"neutral_dislike_relaxed"} enc = tok(text, return_tensors="pt", truncation=True, padding=True, max_length=512); enc = {k:v.to(DEVICE) for k,v in enc.items()} logits = txtM(**enc).logits[0]; probs = torch.softmax(logits, dim=-1).cpu().tolist(); 
up=float(probs[UNSAFE_ID]); n=int(enc["input_ids"].shape[1]) thr = SHORT_MSG_UNSAFE_THR if n<=SHORT_MSG_MAX_TOKENS else TEXT_UNSAFE_THR return {"safe": up Dict: x = img_tf(pil.convert("RGB")).unsqueeze(0).to(DEVICE) probs = torch.softmax(imgM(x)[0], dim=0).cpu().tolist() up = float(probs[IMG_UNSAFE_INDEX]); return {"safe": up bytes: with tempfile.TemporaryDirectory() as td: ip = pathlib.Path(td)/"in"; op = pathlib.Path(td)/"out.wav"; ip.write_bytes(src) try: subprocess.run(["ffmpeg","-y","-i",str(ip),"-ac","1","-ar","16000",str(op)], check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) return op.read_bytes() except FileNotFoundError as e: raise RuntimeError("FFmpeg not found on PATH.") from e except subprocess.CalledProcessError: return src def _transcribe_wav_bytes(w: bytes) -> str: td = tempfile.mkdtemp(); p = pathlib.Path(td)/"in.wav" try: p.write_bytes(w); segs,_ = asr.transcribe(str(p), beam_size=5, language="en") return " ".join(s.text for s in segs).strip() finally: try: p.unlink(missing_ok=True) except Exception: pass try: pathlib.Path(td).rmdir() except Exception: pass def audio_safe_from_bytes(raw: bytes) -> Dict: wav = _ffmpeg_to_wav(raw); text = _transcribe_wav_bytes(wav) proba = text_clf.predict_proba([text])[0].tolist(); up=float(proba[AUDIO_UNSAFE_INDEX]) return {"safe": up