BubbleGuard / app_server.py
MetiMiester's picture
Update app_server.py
d4336c3 verified
# app_server.py — BubbleGuard API + Web UI
# Version: 1.7.1 (/api/* routes + repo-root UI support)
import io, os, re, uuid, pathlib, tempfile, subprocess, unicodedata
from typing import Dict, Optional
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import PlainTextResponse
import torch, joblib, torchvision
from torchvision import transforms
from transformers import RobertaTokenizerFast, AutoModelForSequenceClassification
from PIL import Image
from faster_whisper import WhisperModel
# Model artefacts live in sub-folders next to this file.
BASE = pathlib.Path(__file__).resolve().parent
TEXT_DIR, IMG_DIR, AUD_DIR = (BASE / sub for sub in ("Text", "Image", "Audio"))

# Use the GPU when torch can see one, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Tunables: each is read from an environment variable of the same name,
# except WHISPER_MODEL_NAME, which is read from WHISPER_MODEL.
IMG_UNSAFE_THR = float(os.environ.get("IMG_UNSAFE_THR", "0.5"))
IMG_UNSAFE_INDEX = int(os.environ.get("IMG_UNSAFE_INDEX", "1"))
WHISPER_MODEL_NAME = os.environ.get("WHISPER_MODEL", "base")
TEXT_UNSAFE_THR = float(os.environ.get("TEXT_UNSAFE_THR", "0.60"))
SHORT_MSG_MAX_TOKENS = int(os.environ.get("SHORT_MSG_MAX_TOKENS", "6"))
SHORT_MSG_UNSAFE_THR = float(os.environ.get("SHORT_MSG_UNSAFE_THR", "0.90"))
AUDIO_UNSAFE_INDEX = int(os.environ.get("AUDIO_UNSAFE_INDEX", "1"))
AUDIO_UNSAFE_THR = float(os.environ.get("AUDIO_UNSAFE_THR", "0.50"))
app = FastAPI(
    title="BubbleGuard API",
    version="1.7.1",
)
# Fully permissive CORS: any origin, HTTP method and request header is accepted.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
# ---------- Text model ----------
# Refuse to start when the text classifier folder is absent.
if not TEXT_DIR.exists():
    raise RuntimeError(f"Missing Text dir: {TEXT_DIR}")
# Tokenizer and classifier are both read from TEXT_DIR only (local_files_only=True).
tok = RobertaTokenizerFast.from_pretrained(TEXT_DIR, local_files_only=True)
txtM = AutoModelForSequenceClassification.from_pretrained(TEXT_DIR, local_files_only=True)
txtM = txtM.to(DEVICE)
txtM = txtM.eval()
# Substrings looked for in the model's id2label names.  Note the overlaps:
# "safe" occurs inside "unsafe"/"not safe", "toxic" inside "non-toxic".
SAFE_LABEL_HINTS = {"safe","ok","clean","benign","non-toxic","non_toxic","non toxic"}
UNSAFE_LABEL_HINTS = {"unsafe","toxic","abuse","harm","offense","nsfw","not_safe","not safe"}


def _longest_hint(name, hints):
    """Return the length of the longest hint that occurs in *name* (0 if none)."""
    return max((len(h) for h in hints if h in name), default=0)


def _infer_ids_by_name(model):
    """Guess the (safe, unsafe) class indices from ``model.config.id2label``.

    Each label name is lower-cased and scanned for SAFE_LABEL_HINTS and
    UNSAFE_LABEL_HINTS.  Because hints overlap as substrings, a label is
    assigned to whichever family has the *longest* matching hint.  For example,
    "unsafe" (6) beats "safe" (4), and "non-toxic" (9) beats "toxic" (5).
    Ties, including no match at all, leave the label unassigned.

    If only one side is found and the label ids are exactly {0, 1}, the other
    side is taken as its complement.

    Args:
        model: object exposing ``config.id2label`` (keys int or int-like).

    Returns:
        ``(safe_id, unsafe_id)``, or ``(None, None)`` when the names are
        inconclusive or both sides would resolve to the same index.
    """
    config = getattr(model, "config", None)
    id2label = getattr(config, "id2label", None) or {}
    if not hasattr(id2label, "items"):
        return None, None

    norm = {}
    for k, v in id2label.items():
        try:
            ki = int(k)
        except (TypeError, ValueError, OverflowError):
            try:
                ki = int(str(k).strip())
            except (TypeError, ValueError, OverflowError):
                continue  # non-numeric key: cannot be a class index
        norm[ki] = str(v).lower()

    s = u = None
    for i, name in norm.items():
        safe_len = _longest_hint(name, SAFE_LABEL_HINTS)
        unsafe_len = _longest_hint(name, UNSAFE_LABEL_HINTS)
        if unsafe_len > safe_len:
            u = i
        elif safe_len > unsafe_len:
            s = i

    # "1 - x" is only a valid complement for a binary head indexed 0/1.
    if set(norm) == {0, 1}:
        if s is not None and u is None:
            u = 1 - s
        elif u is not None and s is None:
            s = 1 - u

    if s is None or u is None or s == u:
        return None, None
    return s, u
@torch.no_grad()
def _infer_ids_by_probe(model, tok, device):
    """Fallback: pick the SAFE index as the class favoured on benign sentences.

    A handful of harmless greetings is scored; the class with the highest mean
    softmax probability is taken as SAFE and ``1 - safe`` as UNSAFE (a binary
    head is assumed).
    """
    benign_samples = ["hi", "hello", "how are you", "nice to meet you", "thanks"]
    batch = tok(benign_samples, return_tensors="pt", truncation=True, padding=True, max_length=64)
    batch = {name: tensor.to(device) for name, tensor in batch.items()}
    logits = model(**batch).logits
    mean_probs = logits.softmax(dim=-1).mean(dim=0)
    safe_id = int(mean_probs.argmax())
    unsafe_id = 1 - safe_id
    return safe_id, unsafe_id
def _resolve_ids(model, tok, device):
    """Decide which logit index means SAFE and which means UNSAFE.

    Precedence:
    1. the SAFE_ID and UNSAFE_ID environment variables, used only when both are set;
    2. label-name heuristics;
    3. a probe run on benign sentences.
    """
    env_safe = os.getenv("SAFE_ID")
    env_unsafe = os.getenv("UNSAFE_ID")
    if env_safe is None or env_unsafe is None:
        s, u = _infer_ids_by_name(model)
        if s is None or u is None:
            return _infer_ids_by_probe(model, tok, device)
        return s, u
    return int(env_safe), int(env_unsafe)
# Resolve the text model's label indices once, at import time, and log them.
SAFE_ID, UNSAFE_ID = _resolve_ids(txtM, tok, DEVICE)
_text_id2label = getattr(txtM.config, "id2label", None)
print(f"[BubbleGuard] SAFE_ID={SAFE_ID} UNSAFE_ID={UNSAFE_ID} id2label={_text_id2label}")
# Typographic quotes folded to ASCII; NFKC leaves these characters untouched.
# Escapes are used so the table cannot be damaged by a file re-encoding.
_QUOTE_FOLD = str.maketrans({"\u2018": "'", "\u2019": "'", "\u201c": '"', "\u201d": '"'})


def normalize(t: str) -> str:
    """Canonicalise free text for the rule-based checks.

    Steps:
    1. NFKC-normalise.
    2. Fold curly single/double quotes to ASCII.
    3. Lower-case.
    4. Replace every character other than a-z, 0-9, whitespace and the
       apostrophe with a space.
    5. Collapse whitespace runs and trim.

    Non-string input yields "".
    """
    if not isinstance(t, str):
        return ""
    t = unicodedata.normalize("NFKC", t).translate(_QUOTE_FOLD)
    t = re.sub(r"[^a-z0-9\s']", " ", t.lower())
    return re.sub(r"\s+", " ", t).strip()
# Whole-message short replies that are let through without the model.
_SAFE_REPLY_PATTERNS = (
    r"^i don'?t$",
    r"^i do not$",
    r"^don'?t$",
    r"^no$",
    r"^not really$",
    r"^i woulde?n'?t$",
    r"^i don'?t like$",
)
SAFE_RE = re.compile("|".join(_SAFE_REPLY_PATTERNS))

# A lone negation, optionally preceded by "i".
NEGATION_ONLY = re.compile(r"^(?:i\s+)?(?:do\s+not|don'?t|no|not)$")

# Messages that open with "i don't like" (optionally "i don't like to").
NEUTRAL_DISLIKE = re.compile(r"^i don'?t like(?:\s+to)?\b")

# Group / person words; their presence disables the relaxed "dislike" path.
SENSITIVE_TERMS = {
    "people", "you", "him", "her", "them",
    "men", "women", "girls", "boys",
    "muslim", "christian", "jew", "jews",
    "black", "white", "asian",
    "gay", "lesbian", "trans", "transgender",
    "disabled", "immigrants", "refugees",
    "poor", "old", "elderly", "fat", "skinny",
}

# Words that flag very short messages as unsafe without the model.
PROFANITY_TERMS = {
    "fuck", "shit", "bitch", "pussy",
    "dick", "cunt", "slut", "whore",
}
# Greeting allow-list (matched against normalize()d text).  A match skips
# the classifier entirely.  Every alternative is therefore fully anchored: at
# most two filler words may follow the greeting.  This keeps a message that
# merely *starts* with a greeting ("good morning <anything>") going through
# the classifier.  The previous "\b.*$" tails let arbitrary content through.
_GREETING_FILLER = r"(?:\s+(?:doing|going|today|tonight|there|everyone|all|guys|friend|friends)){0,2}$"
GREETING_RE = re.compile("|".join([
    r"^hi$",
    r"^hello$",
    r"^hey(?: there)?$",
    r"^how are (?:you|u)" + _GREETING_FILLER,
    r"^good (?:morning|afternoon|evening)" + _GREETING_FILLER,
    r"^what'?s up" + _GREETING_FILLER,
    r"^how'?s it going" + _GREETING_FILLER,
]))
def _rule_verdict(safe: bool, n_tokens: int, reason: str) -> Dict:
    """Build the response for a decision made by a rule, not by the model.

    ``probs`` is one-hot: 1.0 at SAFE_ID when safe, otherwise at UNSAFE_ID.
    """
    probs = [0, 0]
    probs[SAFE_ID if safe else UNSAFE_ID] = 1.
    return {"safe": safe, "unsafe_prob": 0.0 if safe else 1.0,
            "label": "SAFE" if safe else "UNSAFE", "probs": probs,
            "tokens": n_tokens, "reason": reason}


@torch.no_grad()
def _model_probs(text: str):
    """Run the text classifier on raw *text*.

    Returns:
        ``(softmax probabilities as a list, number of tokenizer tokens)``.
    """
    enc = tok(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
    enc = {k: v.to(DEVICE) for k, v in enc.items()}
    logits = txtM(**enc).logits[0]
    probs = torch.softmax(logits, dim=-1).cpu().tolist()
    return probs, int(enc["input_ids"].shape[1])


@torch.no_grad()
def text_safe_payload(text: str) -> Dict:
    """Classify a chat message as SAFE or UNSAFE.

    Cheap rules on the normalised text run first.  Otherwise the classifier's
    softmax probability at UNSAFE_ID is compared with a threshold.

    Returns:
        A dict with these keys:

        - safe, unsafe_prob, probs;
        - label: "SAFE"/"UNSAFE", always consistent with ``safe``;
        - tokens;
        - reason: which rule or threshold decided.
    """
    clean = normalize(text)
    toks = clean.split()

    # 1) Profanity in a very short message: flagged without the model.
    if len(toks) == 1 and toks[0] in PROFANITY_TERMS:
        return _rule_verdict(False, 1, "profanity_single_word")
    if len(toks) <= SHORT_MSG_MAX_TOKENS and any(t in PROFANITY_TERMS for t in toks):
        return _rule_verdict(False, len(toks), "profanity_short_text")

    # 2) Allow-listed replies / greetings: passed without the model.
    if SAFE_RE.match(clean) or NEGATION_ONLY.match(clean) or GREETING_RE.match(clean):
        return _rule_verdict(True, len(toks), "allow_or_greeting")

    # 3) "i don't like ..." mentioning no sensitive/profane term: only a
    #    near-certain score (>= 0.98) flags it.  The term test is a substring
    #    test (e.g. "her" also hits "weather"), which errs toward step 4.
    if NEUTRAL_DISLIKE.match(clean):
        if not any(t in clean for t in SENSITIVE_TERMS) and not any(t in clean for t in PROFANITY_TERMS):
            probs, n = _model_probs(text)
            up = float(probs[UNSAFE_ID])
            safe = up < 0.98
            return {"safe": safe, "unsafe_prob": up, "label": "SAFE" if safe else "UNSAFE",
                    "probs": probs, "tokens": n, "reason": "neutral_dislike_relaxed"}

    # 4) Default: threshold on the model score; short inputs use a stricter bar.
    # NOTE(review): n counts tokenizer tokens (likely including special tokens),
    # whereas step 1 counts whitespace words against the same
    # SHORT_MSG_MAX_TOKENS — confirm this is intended.
    probs, n = _model_probs(text)
    up = float(probs[UNSAFE_ID])
    short = n <= SHORT_MSG_MAX_TOKENS
    thr = SHORT_MSG_UNSAFE_THR if short else TEXT_UNSAFE_THR
    safe = up < thr
    # "label" used to be the raw argmax index here, which could contradict "safe"
    # and differed from every other branch.
    return {"safe": safe, "unsafe_prob": up, "label": "SAFE" if safe else "UNSAFE",
            "probs": probs, "tokens": n,
            "reason": "short_msg_threshold" if short else "global_threshold"}
# ---------- Image ----------
class SafetyResNet(torch.nn.Module):
    """Two-class image safety classifier: ResNet-50 trunk plus a small MLP head.

    Attribute names and Sequential layouts define the state_dict keys that the
    checkpoint is loaded against with strict=True, so they must not change.
    """

    def __init__(self):
        super().__init__()
        backbone = torchvision.models.resnet50(weights=None)
        # First 8 children = everything up to and including layer4, i.e.
        # ResNet's own avgpool and fc are dropped.
        trunk_layers = list(backbone.children())[:8]
        self.feature_extractor = torch.nn.Sequential(*trunk_layers)
        self.pool = torch.nn.AdaptiveAvgPool2d(1)
        self.cls = torch.nn.Sequential(
            torch.nn.Linear(2048, 512),
            torch.nn.ReLU(inplace=True),
            torch.nn.Dropout(p=0.30),
            torch.nn.Linear(512, 2),
        )

    def forward(self, x):
        feats = self.feature_extractor(x)
        pooled = self.pool(feats)
        return self.cls(torch.flatten(pooled, 1))
# Refuse to start when the image classifier folder is absent.
if not IMG_DIR.exists():
    raise RuntimeError(f"Missing Image dir: {IMG_DIR}")
imgM = SafetyResNet().to(DEVICE)
# NOTE(review): torch.load unpickles the checkpoint — only ship trusted files here.
imgM.load_state_dict(
    torch.load(IMG_DIR / "resnet_safety_classifier.pth", map_location=DEVICE),
    strict=True,
)
imgM.eval()

# Preprocessing: resize the short side to 256, centre-crop 224x224, then
# normalise with the usual ImageNet channel mean/std.
img_tf = transforms.Compose([
    transforms.Resize(256, interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
@torch.no_grad()
def image_safe_payload(pil: Image.Image) -> Dict:
    """Score a PIL image with the image classifier.

    Returns:
        A dict with these keys:

        - unsafe_prob: softmax probability at IMG_UNSAFE_INDEX;
        - safe: True when unsafe_prob < IMG_UNSAFE_THR;
        - probs: the full probability list.
    """
    rgb = pil.convert("RGB")
    batch = img_tf(rgb).unsqueeze(0).to(DEVICE)
    logits = imgM(batch)[0]
    probs = logits.softmax(dim=0).cpu().tolist()
    unsafe_prob = float(probs[IMG_UNSAFE_INDEX])
    return {"safe": unsafe_prob < IMG_UNSAFE_THR, "unsafe_prob": unsafe_prob, "probs": probs}
# ---------- Audio ----------
# Check the audio classifier folder *before* loading Whisper, so a broken
# deployment fails immediately instead of after a potentially slow model load
# (same order as the text and image sections).
if not AUD_DIR.exists():
    raise RuntimeError(f"Missing Audio dir: {AUD_DIR}")
# float16 on GPU, int8 on CPU.
compute_type = "float16" if DEVICE == "cuda" else "int8"
asr = WhisperModel(WHISPER_MODEL_NAME, device=DEVICE, compute_type=compute_type)
# NOTE(review): joblib.load unpickles the file — only ship trusted artefacts.
# Presumably a fitted text pipeline exposing predict_proba (used below).
text_clf = joblib.load(AUD_DIR / "text_pipeline_balanced.joblib")
def _ffmpeg_to_wav(src: bytes, timeout: Optional[float] = None) -> bytes:
    """Transcode uploaded audio bytes to mono 16 kHz WAV with the ffmpeg CLI.

    Args:
        src: raw uploaded bytes, in whatever format ffmpeg can read.
        timeout: optional limit in seconds for the ffmpeg run (None = no limit).

    Returns:
        The WAV bytes.  When ffmpeg rejects the input, *src* is returned
        unchanged (best effort).

    Raises:
        RuntimeError: ffmpeg is not on PATH, or it exceeded *timeout*.
    """
    with tempfile.TemporaryDirectory() as td:
        ip = pathlib.Path(td) / "in"
        op = pathlib.Path(td) / "out.wav"
        ip.write_bytes(src)
        try:
            subprocess.run(
                ["ffmpeg", "-y", "-i", str(ip), "-ac", "1", "-ar", "16000", str(op)],
                check=True,
                # Without an explicit stdin ffmpeg inherits the server's and may
                # read from / block on it as interactive input.
                stdin=subprocess.DEVNULL,
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
                timeout=timeout,
            )
        except FileNotFoundError as e:
            raise RuntimeError("FFmpeg not found on PATH.") from e
        except subprocess.TimeoutExpired as e:
            raise RuntimeError(f"FFmpeg timed out after {timeout} s.") from e
        except subprocess.CalledProcessError:
            return src
        # Read outside the try: a missing output file must not be reported
        # as a missing ffmpeg binary.
        return op.read_bytes()
def _transcribe_wav_bytes(w: bytes) -> str:
    """Write *w* to a scratch file, transcribe it with Whisper (English, beam 5)
    and return the segment texts joined by spaces and stripped.

    Cleanup of the scratch file and directory is best-effort and never raises.
    """
    scratch = pathlib.Path(tempfile.mkdtemp())
    wav_path = scratch / "in.wav"
    try:
        wav_path.write_bytes(w)
        segments, _info = asr.transcribe(str(wav_path), beam_size=5, language="en")
        # Segments are produced lazily, so they are consumed before cleanup runs.
        pieces = [seg.text for seg in segments]
        return " ".join(pieces).strip()
    finally:
        for undo in (lambda: wav_path.unlink(missing_ok=True), scratch.rmdir):
            try:
                undo()
            except Exception:
                pass
def audio_safe_from_bytes(raw: bytes) -> Dict:
    """Transcribe uploaded audio and score the transcript with the audio classifier.

    Returns:
        A dict with these keys:

        - unsafe_prob: the classifier probability at AUDIO_UNSAFE_INDEX;
        - safe: True when unsafe_prob < AUDIO_UNSAFE_THR;
        - text: the transcript;
        - probs: the full probability list.
    """
    transcript = _transcribe_wav_bytes(_ffmpeg_to_wav(raw))
    scores = text_clf.predict_proba([transcript])[0].tolist()
    unsafe_prob = float(scores[AUDIO_UNSAFE_INDEX])
    return {
        "safe": unsafe_prob < AUDIO_UNSAFE_THR,
        "unsafe_prob": unsafe_prob,
        "text": transcript,
        "probs": scores,
    }
# ---------- Routes (/api/*) ----------
@app.get("/api/health")
def health():
    """Report the device, Whisper model name, thresholds and label indices in use."""
    img_cfg = {"unsafe_threshold": IMG_UNSAFE_THR, "unsafe_index": IMG_UNSAFE_INDEX}
    text_cfg = {
        "TEXT_UNSAFE_THR": TEXT_UNSAFE_THR,
        "SHORT_MSG_MAX_TOKENS": SHORT_MSG_MAX_TOKENS,
        "SHORT_MSG_UNSAFE_THR": SHORT_MSG_UNSAFE_THR,
    }
    audio_cfg = {"unsafe_index": AUDIO_UNSAFE_INDEX, "unsafe_threshold": AUDIO_UNSAFE_THR}
    index_cfg = {"SAFE_ID": SAFE_ID, "UNSAFE_ID": UNSAFE_ID}
    return {
        "ok": True,
        "device": DEVICE,
        "whisper": WHISPER_MODEL_NAME,
        "img": img_cfg,
        "text_thresholds": text_cfg,
        "audio": audio_cfg,
        "safe_unsafe_indices(text_model)": index_cfg,
    }
@app.post("/api/check_text")
def check_text(text: str = Form(...)):
    """Screen a text message submitted as the form field ``text``.

    Responds 400 for blank input and 500 if screening raises.
    """
    if text.strip() == "":
        raise HTTPException(400, "Empty text")
    try:
        result = text_safe_payload(text)
    except Exception as e:
        raise HTTPException(500, f"Text screening error: {e}")
    return result
@app.post("/api/check_image")
async def check_image(file: UploadFile = File(...)):
    """Screen an uploaded image.

    Responds 400 for an empty upload or bytes PIL cannot decode, and 500 if
    the classifier fails.  Decoding and inference are blocking work.  This is
    an ``async def`` route, so they are pushed to FastAPI's threadpool instead
    of stalling the event loop (and every other request) while they run.
    """
    data = await file.read()
    if not data: raise HTTPException(400, "Empty image")

    def _decode() -> Image.Image:
        # Image.open is lazy.  load() forces the full decode now, so truncated
        # or corrupt files become a 400 here rather than a 500 during screening.
        img = Image.open(io.BytesIO(data))
        img.load()
        return img

    try: pil = await run_in_threadpool(_decode)
    except Exception: raise HTTPException(400, "Invalid image")
    try: return await run_in_threadpool(image_safe_payload, pil)
    except Exception as e: raise HTTPException(500, f"Image screening error: {e}")
@app.post("/api/check_audio")
async def check_audio(file: UploadFile = File(...)):
    """Screen an uploaded audio clip (ffmpeg -> Whisper -> text classifier).

    Responds 400 for an empty upload.  A RuntimeError (e.g. missing ffmpeg)
    becomes a 500 with its message; anything else becomes a 500 prefixed
    "Audio processing error".  The pipeline is blocking (subprocess plus model
    inference), so it runs in the threadpool rather than on the event loop.
    """
    raw = await file.read()
    if not raw: raise HTTPException(400, "Empty audio")
    try: return await run_in_threadpool(audio_safe_from_bytes, raw)
    except RuntimeError as e: raise HTTPException(500, f"{e}")
    except Exception as e: raise HTTPException(500, f"Audio processing error: {e}")
# ---------- Static ----------
# Serve the web UI from ./static when present, otherwise from the repo root if
# it holds an index.html; with neither, "/" returns a plain-text notice.
static_dir = BASE / "static"
if static_dir.exists():
    _ui_dir, _ui_mount_name = static_dir, "static"
elif (BASE / "index.html").exists():
    # NOTE(review): this mounts the whole BASE directory, which also contains
    # this source file and the Text/, Image/ and Audio/ model folders — all of
    # it becomes downloadable.  Consider a dedicated UI folder or an allow-list.
    _ui_dir, _ui_mount_name = BASE, "static-root"
else:
    _ui_dir = _ui_mount_name = None

if _ui_dir is None:
    @app.get("/", response_class=PlainTextResponse)
    def _root_fallback():
        return "BubbleGuard API is running. Add index.html to repo root."
else:
    # Mounted after the /api/* routes, which are registered earlier and so matched first.
    app.mount("/", StaticFiles(directory=str(_ui_dir), html=True), name=_ui_mount_name)