Spaces:
Running
Running
# app_server.py — BubbleGuard API + Web UI
# Version: 1.7.1 (/api/* routes + repo-root UI support)
| import io, os, re, uuid, pathlib, tempfile, subprocess, unicodedata | |
| from typing import Dict, Optional | |
| from fastapi import FastAPI, UploadFile, File, Form, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.responses import PlainTextResponse | |
| import torch, joblib, torchvision | |
| from torchvision import transforms | |
| from transformers import RobertaTokenizerFast, AutoModelForSequenceClassification | |
| from PIL import Image | |
| from faster_whisper import WhisperModel | |
# ---- Paths, device, and tunable thresholds (all overridable via env vars) ----
BASE = pathlib.Path(__file__).resolve().parent
TEXT_DIR = BASE / "Text"   # RoBERTa text-safety model directory
IMG_DIR = BASE / "Image"   # ResNet image-safety weights directory
AUD_DIR = BASE / "Audio"   # audio transcript-classifier pipeline directory
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Image classifier: which probability index counts as "unsafe", and its threshold.
IMG_UNSAFE_THR = float(os.getenv("IMG_UNSAFE_THR", "0.5"))
IMG_UNSAFE_INDEX = int(os.getenv("IMG_UNSAFE_INDEX", "1"))
WHISPER_MODEL_NAME = os.getenv("WHISPER_MODEL", "base")  # faster-whisper model size
# Text thresholds: short messages (<= SHORT_MSG_MAX_TOKENS) use a stricter bar.
TEXT_UNSAFE_THR = float(os.getenv("TEXT_UNSAFE_THR", "0.60"))
SHORT_MSG_MAX_TOKENS = int(os.getenv("SHORT_MSG_MAX_TOKENS", "6"))
SHORT_MSG_UNSAFE_THR = float(os.getenv("SHORT_MSG_UNSAFE_THR", "0.90"))
# Audio classifier: which probability index counts as "unsafe", and its threshold.
AUDIO_UNSAFE_INDEX = int(os.getenv("AUDIO_UNSAFE_INDEX", "1"))
AUDIO_UNSAFE_THR = float(os.getenv("AUDIO_UNSAFE_THR", "0.50"))
app = FastAPI(title="BubbleGuard API", version="1.7.1")
# NOTE(review): CORS is wide open — acceptable for a demo, tighten for production.
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
# ---------- Text model ----------
# Load the local RoBERTa sequence classifier used for text safety screening.
if not TEXT_DIR.exists(): raise RuntimeError(f"Missing Text dir: {TEXT_DIR}")
tok = RobertaTokenizerFast.from_pretrained(TEXT_DIR, local_files_only=True)
txtM = AutoModelForSequenceClassification.from_pretrained(TEXT_DIR, local_files_only=True).to(DEVICE).eval()
# Substrings that identify which class label is the "safe" / "unsafe" side.
SAFE_LABEL_HINTS = {"safe","ok","clean","benign","non-toxic","non_toxic","non toxic"}
UNSAFE_LABEL_HINTS = {"unsafe","toxic","abuse","harm","offense","nsfw","not_safe","not safe"}
def _infer_ids_by_name(model):
    """Guess (safe_id, unsafe_id) from the model's id2label names.

    Returns (None, None) when the label names carry no usable hint.
    """
    try:
        raw_map = getattr(model.config, "id2label", {})
        by_index = {}
        for raw_key, raw_name in raw_map.items():
            # Keys may arrive as ints or numeric strings; skip anything else.
            try:
                idx = int(raw_key)
            except Exception:
                try:
                    idx = int(str(raw_key).strip())
                except Exception:
                    continue
            by_index[idx] = str(raw_name).lower()
        safe_id = unsafe_id = None
        for idx, name in by_index.items():
            if any(hint in name for hint in SAFE_LABEL_HINTS):
                safe_id = idx
            if any(hint in name for hint in UNSAFE_LABEL_HINTS):
                unsafe_id = idx
        # Binary-classifier assumption: knowing one side determines the other.
        if safe_id is not None and unsafe_id is None:
            unsafe_id = 1 - safe_id
        if unsafe_id is not None and safe_id is None:
            safe_id = 1 - unsafe_id
        return safe_id, unsafe_id
    except Exception:
        return None, None
def _infer_ids_by_probe(model, tok, device):
    """Probe the model with obviously-benign phrases; the class with the
    highest mean probability is assumed to be "safe" (binary model)."""
    benign = ["hi", "hello", "how are you", "nice to meet you", "thanks"]
    enc = tok(benign, return_tensors="pt", truncation=True, padding=True, max_length=64)
    enc = {name: tensor.to(device) for name, tensor in enc.items()}
    mean_probs = torch.softmax(model(**enc).logits, dim=-1).mean(0)
    safe_id = int(torch.argmax(mean_probs))
    return safe_id, 1 - safe_id
def _resolve_ids(model, tok, device):
    """Determine (SAFE_ID, UNSAFE_ID): env override > label-name hints > live probe."""
    env_safe = os.getenv("SAFE_ID")
    env_unsafe = os.getenv("UNSAFE_ID")
    # Explicit env override wins, but only when BOTH ids are provided.
    if env_safe is not None and env_unsafe is not None:
        return int(env_safe), int(env_unsafe)
    safe_id, unsafe_id = _infer_ids_by_name(model)
    if safe_id is not None and unsafe_id is not None:
        return safe_id, unsafe_id
    return _infer_ids_by_probe(model, tok, device)
# Resolve which logits index means SAFE/UNSAFE once at startup, then log it.
SAFE_ID, UNSAFE_ID = _resolve_ids(txtM, tok, DEVICE)
print(f"[BubbleGuard] SAFE_ID={SAFE_ID} UNSAFE_ID={UNSAFE_ID} id2label={getattr(txtM.config,'id2label',None)}")
def normalize(t: str) -> str:
    """Lowercase and reduce *t* to plain ASCII words for the regex allow-lists.

    Non-strings become "". Curly quotes are straightened so contractions like
    "don't" match the patterns, then every character outside [a-z0-9'] and
    whitespace is dropped and whitespace runs are collapsed to single spaces.
    """
    if not isinstance(t, str):
        return ""
    t = unicodedata.normalize("NFKC", t)
    # BUGFIX: the four quote-replacement literals had been mojibake'd into the
    # same garbled character, so two of the replaces could never fire. Restore
    # the intended curly-quote -> ASCII mapping (NFKC does not straighten them).
    t = t.replace("\u2019", "'").replace("\u2018", "'").replace("\u201c", '"').replace("\u201d", '"')
    t = re.sub(r"[^a-z0-9\s']", " ", t.lower())
    return re.sub(r"\s+", " ", t).strip()
# Exact-match allow-list: bare refusals ("i don't", "no", ...) are always SAFE.
SAFE_RE = re.compile("|".join([r"^i don'?t$", r"^i do not$", r"^don'?t$", r"^no$", r"^not really$", r"^i woulde?n'?t$", r"^i don'?t like$"]))
# Message consisting solely of a negation ("i do not", "don't", "no", "not").
NEGATION_ONLY = re.compile(r"^(?:i\s+)?(?:do\s+not|don'?t|no|not)$")
# "i don't like (to) ..." openers get relaxed scoring when nothing sensitive follows.
NEUTRAL_DISLIKE = re.compile(r"^i don'?t like(?:\s+to)?\b")
# Person/group words that disable the neutral-dislike relaxation.
SENSITIVE_TERMS = {"people","you","him","her","them","men","women","girls","boys","muslim","christian","jew","jews","black","white","asian","gay","lesbian","trans","transgender","disabled","immigrants","refugees","poor","old","elderly","fat","skinny"}
# Hard profanity: short messages containing any of these are UNSAFE outright.
PROFANITY_TERMS = {"fuck","shit","bitch","pussy","dick","cunt","slut","whore"}
# Common greetings/small talk that are always SAFE.
GREETING_RE = re.compile("|".join([r"^hi$", r"^hello$", r"^hey(?: there)?$", r"^how are (?:you|u)\b.*$", r"^good (?:morning|afternoon|evening)\b.*$", r"^what'?s up\b.*$", r"^how'?s it going\b.*$"]))
def _hard_verdict(unsafe: bool, n_tokens: int, reason: str) -> Dict:
    """Build a rule-based (non-model) payload with a one-hot probs vector."""
    probs = [0, 0]
    probs[UNSAFE_ID if unsafe else SAFE_ID] = 1.
    return {"safe": not unsafe, "unsafe_prob": 1.0 if unsafe else 0.0,
            "label": "UNSAFE" if unsafe else "SAFE", "probs": probs,
            "tokens": n_tokens, "reason": reason}
def _score_text(raw_text: str):
    """Run the RoBERTa classifier; return (unsafe_prob, probs, model_token_count)."""
    enc = tok(raw_text, return_tensors="pt", truncation=True, padding=True, max_length=512)
    enc = {k: v.to(DEVICE) for k, v in enc.items()}
    # Inference only — no_grad avoids building an autograd graph per request.
    with torch.no_grad():
        logits = txtM(**enc).logits[0]
    probs = torch.softmax(logits, dim=-1).cpu().tolist()
    return float(probs[UNSAFE_ID]), probs, int(enc["input_ids"].shape[1])
def text_safe_payload(text: str) -> Dict:
    """Screen *text* for safety.

    Rule branches (profanity, allow-list, greetings) short-circuit the model;
    otherwise the classifier's unsafe probability is compared against the
    short-message or global threshold. Returns a dict with keys
    safe / unsafe_prob / label / probs / tokens / reason.
    """
    clean = normalize(text)
    toks = clean.split()
    tok_set = set(toks)
    # Hard profanity rules operate on whole words.
    if len(toks) == 1 and toks[0] in PROFANITY_TERMS:
        return _hard_verdict(True, 1, "profanity_single_word")
    if len(toks) <= SHORT_MSG_MAX_TOKENS and tok_set & PROFANITY_TERMS:
        return _hard_verdict(True, len(toks), "profanity_short_text")
    # Allow-list: bare refusals, pure negations, and greetings are always safe.
    if SAFE_RE.match(clean) or NEGATION_ONLY.match(clean) or GREETING_RE.match(clean):
        return _hard_verdict(False, len(toks), "allow_or_greeting")
    if NEUTRAL_DISLIKE.match(clean):
        # BUGFIX: this previously used substring matching over the whole string
        # ("her" in "weather", "them" in "theme"), which tripped on innocuous
        # words and almost always disabled the relaxation; use whole-token
        # membership, consistent with the profanity rule above.
        if not (tok_set & SENSITIVE_TERMS) and not (tok_set & PROFANITY_TERMS):
            up, probs, n_model_toks = _score_text(text)
            return {"safe": up < 0.98, "unsafe_prob": up,
                    "label": "SAFE" if up < 0.98 else "UNSAFE", "probs": probs,
                    "tokens": n_model_toks, "reason": "neutral_dislike_relaxed"}
    up, probs, n_model_toks = _score_text(text)
    short = n_model_toks <= SHORT_MSG_MAX_TOKENS
    thr = SHORT_MSG_UNSAFE_THR if short else TEXT_UNSAFE_THR
    # CONSISTENCY FIX: the label previously leaked the raw argmax class index
    # (e.g. "0"/"1") while every other branch returns "SAFE"/"UNSAFE"; align it
    # with the thresholded decision reported in "safe".
    return {"safe": up < thr, "unsafe_prob": up,
            "label": "UNSAFE" if up >= thr else "SAFE", "probs": probs,
            "tokens": n_model_toks,
            "reason": "short_msg_threshold" if short else "global_threshold"}
# ---------- Image ----------
class SafetyResNet(torch.nn.Module):
    """Binary image-safety classifier on a ResNet-50 feature trunk.

    The trunk keeps the first 8 children of torchvision's resnet50 (everything
    before its own avgpool/fc); a global average pool and a small 2-way MLP
    head produce the safe/unsafe logits.
    """
    def __init__(self):
        super().__init__()
        trunk = torchvision.models.resnet50(weights=None)
        # Attribute names must stay as-is: they are the state_dict keys that
        # the saved checkpoint is loaded against.
        self.feature_extractor = torch.nn.Sequential(*list(trunk.children())[:8])
        self.pool = torch.nn.AdaptiveAvgPool2d(1)
        self.cls = torch.nn.Sequential(
            torch.nn.Linear(2048, 512),
            torch.nn.ReLU(True),
            torch.nn.Dropout(0.30),
            torch.nn.Linear(512, 2),
        )
    def forward(self, x):
        pooled = self.pool(self.feature_extractor(x))
        return self.cls(torch.flatten(pooled, 1))
# Load the trained safety weights (strict=True: architecture must match exactly).
if not IMG_DIR.exists(): raise RuntimeError(f"Missing Image dir: {IMG_DIR}")
imgM = SafetyResNet().to(DEVICE); imgM.load_state_dict(torch.load(IMG_DIR/"resnet_safety_classifier.pth", map_location=DEVICE), strict=True); imgM.eval()
# Standard ImageNet preprocessing: resize to 256, center-crop 224, normalize.
img_tf = transforms.Compose([ transforms.Resize(256, interpolation=transforms.InterpolationMode.BILINEAR), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]) ])
def image_safe_payload(pil: Image.Image) -> Dict:
    """Classify one PIL image.

    Returns {"safe", "unsafe_prob", "probs"}; the image is safe iff the
    probability at IMG_UNSAFE_INDEX is below IMG_UNSAFE_THR.
    """
    batch = img_tf(pil.convert("RGB")).unsqueeze(0).to(DEVICE)
    # IMPROVED: wrap inference in no_grad — the original built an autograd
    # graph on every request for no reason (wasted memory/CPU).
    with torch.no_grad():
        logits = imgM(batch)[0]
    probs = torch.softmax(logits, dim=0).cpu().tolist()
    unsafe_prob = float(probs[IMG_UNSAFE_INDEX])
    return {"safe": unsafe_prob < IMG_UNSAFE_THR, "unsafe_prob": unsafe_prob, "probs": probs}
# ---------- Audio ----------
# faster-whisper ASR: fp16 on GPU, int8 quantization on CPU.
compute_type = "float16" if DEVICE=="cuda" else "int8"
asr = WhisperModel(WHISPER_MODEL_NAME, device=DEVICE, compute_type=compute_type)
if not AUD_DIR.exists(): raise RuntimeError(f"Missing Audio dir: {AUD_DIR}")
# Pre-trained text pipeline (joblib) that scores the ASR transcript for safety.
text_clf = joblib.load(AUD_DIR/"text_pipeline_balanced.joblib")
def _ffmpeg_to_wav(src: bytes) -> bytes:
    """Transcode arbitrary audio bytes to 16 kHz mono WAV via ffmpeg.

    Raises RuntimeError when ffmpeg is not on PATH; on any other ffmpeg
    failure the original bytes are returned unchanged (best effort).
    """
    with tempfile.TemporaryDirectory() as workdir:
        src_path = pathlib.Path(workdir) / "in"
        wav_path = pathlib.Path(workdir) / "out.wav"
        src_path.write_bytes(src)
        cmd = ["ffmpeg", "-y", "-i", str(src_path), "-ac", "1", "-ar", "16000", str(wav_path)]
        try:
            subprocess.run(cmd, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
        except FileNotFoundError as exc:
            raise RuntimeError("FFmpeg not found on PATH.") from exc
        except subprocess.CalledProcessError:
            return src
        return wav_path.read_bytes()
def _transcribe_wav_bytes(w: bytes) -> str:
    """Transcribe WAV bytes to English text with the module-level Whisper model.

    The bytes are staged in a temporary file because faster-whisper takes a
    path; the whole staging directory is removed afterwards.
    """
    # IMPROVED: was tempfile.mkdtemp() with hand-rolled unlink/rmdir cleanup,
    # which silently leaked the directory whenever rmdir failed; the context
    # manager guarantees full cleanup even on exceptions.
    with tempfile.TemporaryDirectory() as workdir:
        wav_path = pathlib.Path(workdir) / "in.wav"
        wav_path.write_bytes(w)
        segments, _ = asr.transcribe(str(wav_path), beam_size=5, language="en")
        # Consume the segment generator while the file still exists.
        return " ".join(seg.text for seg in segments).strip()
def audio_safe_from_bytes(raw: bytes) -> Dict:
    """Full audio pipeline: transcode -> transcribe -> classify the transcript."""
    transcript = _transcribe_wav_bytes(_ffmpeg_to_wav(raw))
    proba = text_clf.predict_proba([transcript])[0].tolist()
    unsafe_prob = float(proba[AUDIO_UNSAFE_INDEX])
    return {"safe": unsafe_prob < AUDIO_UNSAFE_THR, "unsafe_prob": unsafe_prob,
            "text": transcript, "probs": proba}
# ---------- Routes (/api/*) ----------
# NOTE(review): this revision had NO route decorators even though the header
# and version string promise "/api/*" routes, so none of the handlers below
# were actually registered. Paths are reconstructed from that header — confirm
# against the original deployment.
@app.get("/api/health")
def health():
    """Liveness/config endpoint: reports device, model, and active thresholds."""
    return {
        "ok": True, "device": DEVICE, "whisper": WHISPER_MODEL_NAME,
        "img": {"unsafe_threshold": IMG_UNSAFE_THR, "unsafe_index": IMG_UNSAFE_INDEX},
        "text_thresholds": {"TEXT_UNSAFE_THR": TEXT_UNSAFE_THR, "SHORT_MSG_MAX_TOKENS": SHORT_MSG_MAX_TOKENS, "SHORT_MSG_UNSAFE_THR": SHORT_MSG_UNSAFE_THR},
        "audio": {"unsafe_index": AUDIO_UNSAFE_INDEX, "unsafe_threshold": AUDIO_UNSAFE_THR},
        "safe_unsafe_indices(text_model)": {"SAFE_ID": SAFE_ID, "UNSAFE_ID": UNSAFE_ID},
    }
# NOTE(review): decorator reconstructed — it was missing, leaving this handler
# unregistered; confirm the exact path against the original deployment.
@app.post("/api/check_text")
def check_text(text: str = Form(...)):
    """Screen a text message; 400 on empty input, 500 on screening failure."""
    if not text.strip():
        raise HTTPException(400, "Empty text")
    try:
        return text_safe_payload(text)
    except Exception as e:
        raise HTTPException(500, f"Text screening error: {e}")
# NOTE(review): decorator reconstructed — it was missing, leaving this handler
# unregistered; confirm the exact path against the original deployment.
@app.post("/api/check_image")
async def check_image(file: UploadFile = File(...)):
    """Screen an uploaded image; 400 on empty/undecodable input, 500 on model failure."""
    data = await file.read()
    if not data:
        raise HTTPException(400, "Empty image")
    try:
        pil = Image.open(io.BytesIO(data))
    except Exception:
        raise HTTPException(400, "Invalid image")
    try:
        return image_safe_payload(pil)
    except Exception as e:
        raise HTTPException(500, f"Image screening error: {e}")
# NOTE(review): decorator reconstructed — it was missing, leaving this handler
# unregistered; confirm the exact path against the original deployment.
@app.post("/api/check_audio")
async def check_audio(file: UploadFile = File(...)):
    """Screen an uploaded audio clip; 400 on empty input, 500 on pipeline failure."""
    raw = await file.read()
    if not raw:
        raise HTTPException(400, "Empty audio")
    try:
        return audio_safe_from_bytes(raw)
    except RuntimeError as e:
        # e.g. ffmpeg missing from PATH — surfaced as-is.
        raise HTTPException(500, f"{e}")
    except Exception as e:
        raise HTTPException(500, f"Audio processing error: {e}")
# ---------- Static ----------
# Serve the web UI from ./static when present, else from the repo root if an
# index.html exists there; otherwise expose a plain-text landing message.
static_dir = BASE / "static"
if static_dir.exists():
    app.mount("/", StaticFiles(directory=str(static_dir), html=True), name="static")
elif (BASE/"index.html").exists():
    app.mount("/", StaticFiles(directory=str(BASE), html=True), name="static-root")
else:
    # NOTE(review): the @app.get decorator was missing, so this fallback was
    # never registered — PlainTextResponse is imported above for exactly this.
    @app.get("/", response_class=PlainTextResponse)
    def _root_fallback():
        return "BubbleGuard API is running. Add index.html to repo root."