Spaces:
Running
Running
File size: 11,367 Bytes
d4336c3 eef2847 81c8695 eef2847 81c8695 eef2847 81c8695 eef2847 f1f4eb8 eef2847 81c8695 eef2847 81c8695 eef2847 d4336c3 81c8695 d4336c3 eef2847 d4336c3 81c8695 d4336c3 81c8695 d4336c3 81c8695 d4336c3 81c8695 eef2847 81c8695 d4336c3 eef2847 81c8695 d4336c3 eef2847 d4336c3 81c8695 d4336c3 81c8695 d4336c3 81c8695 d4336c3 81c8695 d4336c3 eef2847 d4336c3 81c8695 d4336c3 eef2847 d4336c3 81c8695 eef2847 d4336c3 f1f4eb8 d4336c3 81c8695 eef2847 d4336c3 81c8695 d4336c3 81c8695 d4336c3 81c8695 eef2847 d4336c3 81c8695 d4336c3 eef2847 d4336c3 81c8695 d4336c3 81c8695 d4336c3 81c8695 d4336c3 81c8695 d4336c3 81c8695 d4336c3 81c8695 eef2847 f1f4eb8 81c8695 eef2847 d4336c3 81c8695 d4336c3 eef2847 81c8695 d4336c3 81c8695 eef2847 81c8695 d4336c3 81c8695 eef2847 81c8695 d4336c3 81c8695 eef2847 81c8695 d4336c3 81c8695 d4336c3 f1f4eb8 d4336c3 f1f4eb8 81c8695 d4336c3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 |
# app_server.py — BubbleGuard API + Web UI
# Version: 1.7.1 (/api/* routes + repo-root UI support)
import io, os, re, uuid, pathlib, tempfile, subprocess, unicodedata
from typing import Dict, Optional
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import PlainTextResponse
import torch, joblib, torchvision
from torchvision import transforms
from transformers import RobertaTokenizerFast, AutoModelForSequenceClassification
from PIL import Image
from faster_whisper import WhisperModel
# Paths to the bundled model artifacts, expected in directories next to this file.
BASE = pathlib.Path(__file__).resolve().parent
TEXT_DIR = BASE / "Text"
IMG_DIR = BASE / "Image"
AUD_DIR = BASE / "Audio"
# Use CUDA when available, otherwise fall back to CPU for all three models.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Image classifier: decision threshold and which logit index means "unsafe".
IMG_UNSAFE_THR = float(os.getenv("IMG_UNSAFE_THR", "0.5"))
IMG_UNSAFE_INDEX = int(os.getenv("IMG_UNSAFE_INDEX", "1"))
# faster-whisper model size for speech-to-text (e.g. "base", "small").
WHISPER_MODEL_NAME = os.getenv("WHISPER_MODEL", "base")
# Text classifier thresholds; short messages get a stricter (higher) cutoff.
TEXT_UNSAFE_THR = float(os.getenv("TEXT_UNSAFE_THR", "0.60"))
SHORT_MSG_MAX_TOKENS = int(os.getenv("SHORT_MSG_MAX_TOKENS", "6"))
SHORT_MSG_UNSAFE_THR = float(os.getenv("SHORT_MSG_UNSAFE_THR", "0.90"))
# Audio transcript classifier: unsafe class index and decision threshold.
AUDIO_UNSAFE_INDEX = int(os.getenv("AUDIO_UNSAFE_INDEX", "1"))
AUDIO_UNSAFE_THR = float(os.getenv("AUDIO_UNSAFE_THR", "0.50"))
app = FastAPI(title="BubbleGuard API", version="1.7.1")
# NOTE(review): CORS is wide open (all origins/methods/headers) — acceptable
# for a demo UI, but tighten before any public deployment.
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
# ---------- Text model ----------
# Fail fast at import time if the fine-tuned checkpoint directory is missing.
if not TEXT_DIR.exists(): raise RuntimeError(f"Missing Text dir: {TEXT_DIR}")
# Tokenizer + sequence classifier loaded strictly from local files (no hub access).
tok = RobertaTokenizerFast.from_pretrained(TEXT_DIR, local_files_only=True)
txtM = AutoModelForSequenceClassification.from_pretrained(TEXT_DIR, local_files_only=True).to(DEVICE).eval()
SAFE_LABEL_HINTS = {"safe","ok","clean","benign","non-toxic","non_toxic","non toxic"}
UNSAFE_LABEL_HINTS = {"unsafe","toxic","abuse","harm","offense","nsfw","not_safe","not safe"}
def _infer_ids_by_name(model):
try:
id2label = getattr(model.config, "id2label", {})
norm = {}
for k, v in id2label.items():
try: ki = int(k)
except Exception:
try: ki = int(str(k).strip())
except Exception: continue
norm[ki] = str(v).lower()
s = u = None
for i, name in norm.items():
if any(h in name for h in SAFE_LABEL_HINTS): s = i
if any(h in name for h in UNSAFE_LABEL_HINTS): u = i
if s is not None and u is None: u = 1 - s
if u is not None and s is None: s = 1 - u
return s, u
except Exception:
return None, None
@torch.no_grad()
def _infer_ids_by_probe(model, tok, device):
enc = tok(["hi","hello","how are you","nice to meet you","thanks"], return_tensors="pt", truncation=True, padding=True, max_length=64)
enc = {k:v.to(device) for k,v in enc.items()}
probs = torch.softmax(model(**enc).logits, dim=-1).mean(0)
s = int(torch.argmax(probs)); return s, 1 - s
def _resolve_ids(model, tok, device):
s_env, u_env = os.getenv("SAFE_ID"), os.getenv("UNSAFE_ID")
if s_env is not None and u_env is not None: return int(s_env), int(u_env)
s, u = _infer_ids_by_name(model)
return (s, u) if (s is not None and u is not None) else _infer_ids_by_probe(model, tok, device)
# Resolve the label indices once at startup and log them for debugging.
SAFE_ID, UNSAFE_ID = _resolve_ids(txtM, tok, DEVICE)
print(f"[BubbleGuard] SAFE_ID={SAFE_ID} UNSAFE_ID={UNSAFE_ID} id2label={getattr(txtM.config,'id2label',None)}")
def normalize(t: str) -> str:
if not isinstance(t, str): return ""
t = unicodedata.normalize("NFKC", t).replace("β","'").replace("β","'").replace("β",'"').replace("β",'"')
t = re.sub(r"[^a-z0-9\s']", " ", t.lower()); return re.sub(r"\s+", " ", t).strip()
# Whitelist: short refusals/negations that the classifier tends to over-flag.
SAFE_RE = re.compile("|".join([r"^i don'?t$", r"^i do not$", r"^don'?t$", r"^no$", r"^not really$", r"^i woulde?n'?t$", r"^i don'?t like$"]))
# Messages that consist solely of a bare negation ("no", "not", "don't", ...).
NEGATION_ONLY = re.compile(r"^(?:i\s+)?(?:do\s+not|don'?t|no|not)$")
# "I don't like (to) ..." openers get relaxed screening unless they target people.
NEUTRAL_DISLIKE = re.compile(r"^i don'?t like(?:\s+to)?\b")
# Group/identity words: a "dislike" aimed at any of these is NOT treated as neutral.
SENSITIVE_TERMS = {"people","you","him","her","them","men","women","girls","boys","muslim","christian","jew","jews","black","white","asian","gay","lesbian","trans","transgender","disabled","immigrants","refugees","poor","old","elderly","fat","skinny"}
# Profanity tokens that short-circuit short messages straight to UNSAFE.
PROFANITY_TERMS = {"fuck","shit","bitch","pussy","dick","cunt","slut","whore"}
# Greetings / small talk that is always allowed without running the model.
GREETING_RE = re.compile("|".join([r"^hi$", r"^hello$", r"^hey(?: there)?$", r"^how are (?:you|u)\b.*$", r"^good (?:morning|afternoon|evening)\b.*$", r"^what'?s up\b.*$", r"^how'?s it going\b.*$"]))
@torch.no_grad()
def text_safe_payload(text: str) -> Dict:
    """Screen one chat message.

    Returns a dict with keys: safe (bool), unsafe_prob (float),
    label ("SAFE"/"UNSAFE"), probs (list), tokens (int), reason (str).
    Rule-based short-circuits run before the RoBERTa classifier.
    """
    clean = normalize(text); toks = clean.split()
    # Hard block: a lone profanity token.
    if len(toks)==1 and toks[0] in PROFANITY_TERMS:
        p=[0,0]; p[UNSAFE_ID]=1.; return {"safe":False,"unsafe_prob":1.0,"label":"UNSAFE","probs":p,"tokens":1,"reason":"profanity_single_word"}
    # Hard block: any profanity inside a short message.
    if len(toks)<=SHORT_MSG_MAX_TOKENS and any(t in PROFANITY_TERMS for t in toks):
        p=[0,0]; p[UNSAFE_ID]=1.; return {"safe":False,"unsafe_prob":1.0,"label":"UNSAFE","probs":p,"tokens":len(toks),"reason":"profanity_short_text"}
    # Hard allow: greetings and bare negations/refusals.
    if SAFE_RE.match(clean) or NEGATION_ONLY.match(clean) or GREETING_RE.match(clean):
        p=[0,0]; p[SAFE_ID]=1.; return {"safe":True,"unsafe_prob":0.0,"label":"SAFE","probs":p,"tokens":len(toks),"reason":"allow_or_greeting"}
    # "I don't like ..." about non-people topics: still run the model but with
    # a very permissive 0.98 cutoff to avoid false positives on mild dislikes.
    # NOTE(review): the SENSITIVE_TERMS test is a substring match on the whole
    # string, so e.g. "you" also matches inside "young" — confirm intended.
    if NEUTRAL_DISLIKE.match(clean):
        if not any(t in clean for t in SENSITIVE_TERMS) and not any(t in clean for t in PROFANITY_TERMS):
            enc = tok(text, return_tensors="pt", truncation=True, padding=True, max_length=512); enc = {k:v.to(DEVICE) for k,v in enc.items()}
            probs = torch.softmax(txtM(**enc).logits[0], dim=-1).cpu().tolist(); up=float(probs[UNSAFE_ID])
            return {"safe": up<0.98, "unsafe_prob": up, "label":"SAFE" if up<0.98 else "UNSAFE", "probs": probs, "tokens": int(enc["input_ids"].shape[1]), "reason":"neutral_dislike_relaxed"}
    # Default path: classify and threshold (stricter cutoff for short texts).
    enc = tok(text, return_tensors="pt", truncation=True, padding=True, max_length=512); enc = {k:v.to(DEVICE) for k,v in enc.items()}
    logits = txtM(**enc).logits[0]; probs = torch.softmax(logits, dim=-1).cpu().tolist(); up=float(probs[UNSAFE_ID]); n=int(enc["input_ids"].shape[1])
    thr = SHORT_MSG_UNSAFE_THR if n<=SHORT_MSG_MAX_TOKENS else TEXT_UNSAFE_THR
    is_safe = up < thr
    # Bug fix: label was str(argmax index) ("0"/"1"), which both differed in
    # format from every other branch's "SAFE"/"UNSAFE" and could contradict
    # the thresholded `safe` flag when argmax and threshold disagreed.
    return {"safe": is_safe, "unsafe_prob": up, "label": "SAFE" if is_safe else "UNSAFE", "probs": probs, "tokens": n, "reason": "short_msg_threshold" if n<=SHORT_MSG_MAX_TOKENS else "global_threshold"}
# ---------- Image ----------
class SafetyResNet(torch.nn.Module):
    """ResNet-50 trunk + small MLP head emitting 2 logits (safe/unsafe).

    NOTE: the attribute names (feature_extractor / pool / cls) and the
    Sequential layer ordering are load-bearing — the checkpoint is loaded
    with strict=True and its keys are derived from these names.
    """
    def __init__(self):
        super().__init__()
        backbone = torchvision.models.resnet50(weights=None)
        trunk = list(backbone.children())[:8]  # drop the avgpool + fc head
        self.feature_extractor = torch.nn.Sequential(*trunk)
        self.pool = torch.nn.AdaptiveAvgPool2d(1)
        self.cls = torch.nn.Sequential(
            torch.nn.Linear(2048, 512),
            torch.nn.ReLU(True),
            torch.nn.Dropout(0.30),
            torch.nn.Linear(512, 2),
        )
    def forward(self, x):
        feats = self.feature_extractor(x)
        pooled = torch.flatten(self.pool(feats), 1)
        return self.cls(pooled)
# Fail fast if the image model directory is missing.
if not IMG_DIR.exists(): raise RuntimeError(f"Missing Image dir: {IMG_DIR}")
# Load the checkpoint with strict=True: state-dict keys must match SafetyResNet exactly.
imgM = SafetyResNet().to(DEVICE); imgM.load_state_dict(torch.load(IMG_DIR/"resnet_safety_classifier.pth", map_location=DEVICE), strict=True); imgM.eval()
# Standard ImageNet preprocessing: resize, 224 center-crop, per-channel normalize.
img_tf = transforms.Compose([ transforms.Resize(256, interpolation=transforms.InterpolationMode.BILINEAR), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]) ])
@torch.no_grad()
def image_safe_payload(pil: Image.Image) -> Dict:
    """Classify one image; unsafe when P(unsafe) >= IMG_UNSAFE_THR."""
    batch = img_tf(pil.convert("RGB")).unsqueeze(0).to(DEVICE)
    logits = imgM(batch)[0]
    probs = torch.softmax(logits, dim=0).cpu().tolist()
    unsafe_p = float(probs[IMG_UNSAFE_INDEX])
    return {"safe": unsafe_p < IMG_UNSAFE_THR, "unsafe_prob": unsafe_p, "probs": probs}
# ---------- Audio ----------
# float16 on GPU, int8 quantization on CPU — standard faster-whisper settings.
compute_type = "float16" if DEVICE=="cuda" else "int8"
asr = WhisperModel(WHISPER_MODEL_NAME, device=DEVICE, compute_type=compute_type)
if not AUD_DIR.exists(): raise RuntimeError(f"Missing Audio dir: {AUD_DIR}")
# Pipeline that classifies the transcript via predict_proba.  NOTE(review):
# presumably a scikit-learn text pipeline — confirm against the training repo.
text_clf = joblib.load(AUD_DIR/"text_pipeline_balanced.joblib")
def _ffmpeg_to_wav(src: bytes) -> bytes:
    """Transcode arbitrary audio bytes to 16 kHz mono WAV via ffmpeg.

    Best effort: returns *src* unchanged when ffmpeg exits non-zero;
    raises RuntimeError when the ffmpeg binary is not installed.
    """
    with tempfile.TemporaryDirectory() as td:
        src_path = pathlib.Path(td)/"in"
        dst_path = pathlib.Path(td)/"out.wav"
        src_path.write_bytes(src)
        cmd = ["ffmpeg","-y","-i",str(src_path),"-ac","1","-ar","16000",str(dst_path)]
        try:
            subprocess.run(cmd, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
        except FileNotFoundError as e:
            raise RuntimeError("FFmpeg not found on PATH.") from e
        except subprocess.CalledProcessError:
            return src
        return dst_path.read_bytes()
def _transcribe_wav_bytes(w: bytes) -> str:
    """Transcribe WAV bytes to English text with faster-whisper.

    Robustness/consistency fix: the previous mkdtemp + manual unlink/rmdir
    cleanup leaked the temp directory whenever transcription raised or rmdir
    failed; TemporaryDirectory (already used by _ffmpeg_to_wav) removes it
    recursively even on error.  The segment generator is consumed inside the
    context so the file still exists while whisper reads it lazily.
    """
    with tempfile.TemporaryDirectory() as td:
        wav_path = pathlib.Path(td)/"in.wav"
        wav_path.write_bytes(w)
        segs, _ = asr.transcribe(str(wav_path), beam_size=5, language="en")
        return " ".join(s.text for s in segs).strip()
def audio_safe_from_bytes(raw: bytes) -> Dict:
    """Transcode -> transcribe -> classify the transcript; return the verdict."""
    transcript = _transcribe_wav_bytes(_ffmpeg_to_wav(raw))
    proba = text_clf.predict_proba([transcript])[0].tolist()
    unsafe_p = float(proba[AUDIO_UNSAFE_INDEX])
    return {"safe": unsafe_p < AUDIO_UNSAFE_THR, "unsafe_prob": unsafe_p, "text": transcript, "probs": proba}
# ---------- Routes (/api/*) ----------
@app.get("/api/health")
def health():
    """Liveness probe that also reports the effective runtime configuration."""
    return {
        "ok": True,
        "device": DEVICE,
        "whisper": WHISPER_MODEL_NAME,
        "img": {"unsafe_threshold": IMG_UNSAFE_THR, "unsafe_index": IMG_UNSAFE_INDEX},
        "text_thresholds": {"TEXT_UNSAFE_THR": TEXT_UNSAFE_THR, "SHORT_MSG_MAX_TOKENS": SHORT_MSG_MAX_TOKENS, "SHORT_MSG_UNSAFE_THR": SHORT_MSG_UNSAFE_THR},
        "audio": {"unsafe_index": AUDIO_UNSAFE_INDEX, "unsafe_threshold": AUDIO_UNSAFE_THR},
        "safe_unsafe_indices(text_model)": {"SAFE_ID": SAFE_ID, "UNSAFE_ID": UNSAFE_ID},
    }
@app.post("/api/check_text")
def check_text(text: str = Form(...)):
    """Screen a text message; 400 on empty input, 500 on model failure."""
    if not text.strip():
        raise HTTPException(400, "Empty text")
    try:
        return text_safe_payload(text)
    except Exception as e:
        raise HTTPException(500, f"Text screening error: {e}")
@app.post("/api/check_image")
async def check_image(file: UploadFile = File(...)):
    """Screen an uploaded image; 400 on empty or undecodable data."""
    data = await file.read()
    if not data:
        raise HTTPException(400, "Empty image")
    try:
        pil = Image.open(io.BytesIO(data))
    except Exception:
        raise HTTPException(400, "Invalid image")
    try:
        return image_safe_payload(pil)
    except Exception as e:
        raise HTTPException(500, f"Image screening error: {e}")
@app.post("/api/check_audio")
async def check_audio(file: UploadFile = File(...)):
    """Screen an uploaded audio clip via transcription; 400 on empty upload."""
    raw = await file.read()
    if not raw:
        raise HTTPException(400, "Empty audio")
    try:
        return audio_safe_from_bytes(raw)
    except RuntimeError as e:
        # RuntimeError from the pipeline signals a missing ffmpeg binary.
        raise HTTPException(500, f"{e}")
    except Exception as e:
        raise HTTPException(500, f"Audio processing error: {e}")
# ---------- Static ----------
# Serve the web UI: prefer ./static, else the repo root when an index.html
# sits next to this file, else expose a plain-text placeholder at "/".
static_dir = BASE / "static"
if static_dir.exists():
    app.mount("/", StaticFiles(directory=str(static_dir), html=True), name="static")
elif (BASE/"index.html").exists():
    app.mount("/", StaticFiles(directory=str(BASE), html=True), name="static-root")
else:
    # Registered only when no UI files exist, so "/" still answers something.
    @app.get("/", response_class=PlainTextResponse)
    def _root_fallback(): return "BubbleGuard API is running. Add index.html to repo root."
|