File size: 11,367 Bytes
d4336c3
eef2847
 
 
81c8695
 
 
 
 
eef2847
81c8695
 
 
 
 
 
 
eef2847
 
81c8695
 
eef2847
f1f4eb8
eef2847
 
 
 
81c8695
eef2847
81c8695
eef2847
d4336c3
81c8695
d4336c3
 
eef2847
 
 
d4336c3
 
81c8695
d4336c3
81c8695
d4336c3
81c8695
 
d4336c3
 
 
 
81c8695
eef2847
81c8695
d4336c3
eef2847
 
 
 
81c8695
 
 
 
d4336c3
 
 
eef2847
d4336c3
81c8695
d4336c3
 
 
81c8695
d4336c3
81c8695
d4336c3
 
81c8695
d4336c3
 
 
 
eef2847
d4336c3
81c8695
 
d4336c3
eef2847
d4336c3
81c8695
 
 
eef2847
d4336c3
 
 
 
f1f4eb8
d4336c3
81c8695
eef2847
d4336c3
 
 
 
 
 
 
 
 
81c8695
 
 
 
 
 
d4336c3
 
81c8695
d4336c3
 
 
81c8695
 
eef2847
 
 
d4336c3
81c8695
d4336c3
 
eef2847
d4336c3
 
81c8695
d4336c3
81c8695
d4336c3
81c8695
d4336c3
 
 
 
81c8695
d4336c3
 
81c8695
d4336c3
 
81c8695
eef2847
f1f4eb8
 
 
81c8695
eef2847
d4336c3
 
 
81c8695
d4336c3
eef2847
81c8695
d4336c3
 
 
 
 
81c8695
eef2847
81c8695
d4336c3
 
 
81c8695
eef2847
81c8695
 
d4336c3
 
 
 
 
81c8695
eef2847
81c8695
 
d4336c3
 
 
 
81c8695
d4336c3
f1f4eb8
 
 
d4336c3
f1f4eb8
81c8695
 
d4336c3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
# app_server.py β€” BubbleGuard API + Web UI
# Version: 1.7.1 (/api/* routes + repo-root UI support)

import io, os, re, uuid, pathlib, tempfile, subprocess, unicodedata
from typing import Dict, Optional
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import PlainTextResponse
import torch, joblib, torchvision
from torchvision import transforms
from transformers import RobertaTokenizerFast, AutoModelForSequenceClassification
from PIL import Image
from faster_whisper import WhisperModel

# Paths: model assets are expected next to this file in Text/, Image/, Audio/.
BASE = pathlib.Path(__file__).resolve().parent
TEXT_DIR = BASE / "Text"
IMG_DIR  = BASE / "Image"
AUD_DIR  = BASE / "Audio"

# Runtime configuration; every knob can be overridden via environment variables.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
IMG_UNSAFE_THR   = float(os.getenv("IMG_UNSAFE_THR", "0.5"))    # image: unsafe_prob >= thr -> flagged
IMG_UNSAFE_INDEX = int(os.getenv("IMG_UNSAFE_INDEX", "1"))      # image: class index meaning "unsafe"
WHISPER_MODEL_NAME = os.getenv("WHISPER_MODEL", "base")         # faster-whisper model size
TEXT_UNSAFE_THR        = float(os.getenv("TEXT_UNSAFE_THR", "0.60"))      # text: default unsafe threshold
SHORT_MSG_MAX_TOKENS   = int(os.getenv("SHORT_MSG_MAX_TOKENS", "6"))      # text: "short message" cut-off
SHORT_MSG_UNSAFE_THR   = float(os.getenv("SHORT_MSG_UNSAFE_THR", "0.90")) # text: more lenient bar for short msgs (needs higher prob to flag)
AUDIO_UNSAFE_INDEX = int(os.getenv("AUDIO_UNSAFE_INDEX", "1"))  # audio: class index meaning "unsafe"
AUDIO_UNSAFE_THR   = float(os.getenv("AUDIO_UNSAFE_THR", "0.50"))

app = FastAPI(title="BubbleGuard API", version="1.7.1")
# NOTE(review): CORS is wide open (any origin/method/header) — acceptable for a
# local demo, but tighten allow_origins before exposing this publicly.
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])

# ---------- Text model ----------
# RoBERTa-based sequence classifier loaded strictly from local files
# (local_files_only=True — never downloads). Fails fast at import time
# if the Text/ asset directory is missing.
if not TEXT_DIR.exists(): raise RuntimeError(f"Missing Text dir: {TEXT_DIR}")
tok = RobertaTokenizerFast.from_pretrained(TEXT_DIR, local_files_only=True)
txtM = AutoModelForSequenceClassification.from_pretrained(TEXT_DIR, local_files_only=True).to(DEVICE).eval()

SAFE_LABEL_HINTS   = {"safe","ok","clean","benign","non-toxic","non_toxic","non toxic"}
UNSAFE_LABEL_HINTS = {"unsafe","toxic","abuse","harm","offense","nsfw","not_safe","not safe"}

def _infer_ids_by_name(model):
    try:
        id2label = getattr(model.config, "id2label", {})
        norm = {}
        for k, v in id2label.items():
            try: ki = int(k)
            except Exception: 
                try: ki = int(str(k).strip())
                except Exception: continue
            norm[ki] = str(v).lower()
        s = u = None
        for i, name in norm.items():
            if any(h in name for h in SAFE_LABEL_HINTS): s = i
            if any(h in name for h in UNSAFE_LABEL_HINTS): u = i
        if s is not None and u is None: u = 1 - s
        if u is not None and s is None: s = 1 - u
        return s, u
    except Exception:
        return None, None

@torch.no_grad()
def _infer_ids_by_probe(model, tok, device):
    enc = tok(["hi","hello","how are you","nice to meet you","thanks"], return_tensors="pt", truncation=True, padding=True, max_length=64)
    enc = {k:v.to(device) for k,v in enc.items()}
    probs = torch.softmax(model(**enc).logits, dim=-1).mean(0)
    s = int(torch.argmax(probs)); return s, 1 - s

def _resolve_ids(model, tok, device):
    s_env, u_env = os.getenv("SAFE_ID"), os.getenv("UNSAFE_ID")
    if s_env is not None and u_env is not None: return int(s_env), int(u_env)
    s, u = _infer_ids_by_name(model)
    return (s, u) if (s is not None and u is not None) else _infer_ids_by_probe(model, tok, device)

# Resolve the label indices once at startup and log them so a mis-detected
# mapping is visible immediately in the server output.
SAFE_ID, UNSAFE_ID = _resolve_ids(txtM, tok, DEVICE)
print(f"[BubbleGuard] SAFE_ID={SAFE_ID}  UNSAFE_ID={UNSAFE_ID}  id2label={getattr(txtM.config,'id2label',None)}")

def normalize(t: str) -> str:
    """Canonicalize a message for the regex gates: NFKC fold, straighten
    curly quotes, lowercase, drop everything but [a-z0-9'] and whitespace,
    then collapse runs of whitespace. Non-strings normalize to ""."""
    if not isinstance(t, str):
        return ""
    folded = unicodedata.normalize("NFKC", t)
    for curly, straight in (("’", "'"), ("‘", "'"), ("“", '"'), ("”", '"')):
        folded = folded.replace(curly, straight)
    stripped = re.sub(r"[^a-z0-9\s']", " ", folded.lower())
    return re.sub(r"\s+", " ", stripped).strip()

# Allow-list: bare refusals/negations the classifier tends to over-flag.
SAFE_RE = re.compile("|".join([r"^i don'?t$", r"^i do not$", r"^don'?t$", r"^no$", r"^not really$", r"^i woulde?n'?t$", r"^i don'?t like$"]))
# A message that is nothing but a negation ("i do not", "don't", "no", "not").
NEGATION_ONLY = re.compile(r"^(?:i\s+)?(?:do\s+not|don'?t|no|not)$")
# "i don't like (to) ..." opener — routed through a relaxed model threshold.
NEUTRAL_DISLIKE = re.compile(r"^i don'?t like(?:\s+to)?\b")
# Person/group words: their presence disables the relaxed "dislike" path.
# NOTE(review): matched via substring on the normalized text, so e.g. "you"
# also matches inside "young" — confirm that is intended.
SENSITIVE_TERMS = {"people","you","him","her","them","men","women","girls","boys","muslim","christian","jew","jews","black","white","asian","gay","lesbian","trans","transgender","disabled","immigrants","refugees","poor","old","elderly","fat","skinny"}
# Words that short-circuit short messages straight to UNSAFE.
PROFANITY_TERMS = {"fuck","shit","bitch","pussy","dick","cunt","slut","whore"}
# Common greetings / small talk, always SAFE.
GREETING_RE = re.compile("|".join([r"^hi$", r"^hello$", r"^hey(?: there)?$", r"^how are (?:you|u)\b.*$", r"^good (?:morning|afternoon|evening)\b.*$", r"^what'?s up\b.*$", r"^how'?s it going\b.*$"]))

@torch.no_grad()
def text_safe_payload(text: str) -> Dict:
    """Screen one text message and return a verdict dict.

    Gate cascade (first match wins):
      1. a lone profanity word                      -> UNSAFE
      2. profanity anywhere in a short message      -> UNSAFE
      3. allow-listed negations / greetings         -> SAFE
      4. neutral "i don't like ..." with no sensitive/profane terms
         -> model verdict with a relaxed 0.98 threshold
      5. otherwise -> model verdict; short messages use the more lenient
         SHORT_MSG_UNSAFE_THR, everything else TEXT_UNSAFE_THR

    Returns keys: safe, unsafe_prob, label, probs, tokens, reason.
    NOTE(review): in branch 5 "label" is the raw argmax index as a string,
    while earlier branches use "SAFE"/"UNSAFE" — confirm consumers handle both.
    """
    clean = normalize(text); toks = clean.split()
    # 1) single profanity word
    if len(toks)==1 and toks[0] in PROFANITY_TERMS:
        p=[0,0]; p[UNSAFE_ID]=1.; return {"safe":False,"unsafe_prob":1.0,"label":"UNSAFE","probs":p,"tokens":1,"reason":"profanity_single_word"}
    # 2) profanity inside a short message
    if len(toks)<=SHORT_MSG_MAX_TOKENS and any(t in PROFANITY_TERMS for t in toks):
        p=[0,0]; p[UNSAFE_ID]=1.; return {"safe":False,"unsafe_prob":1.0,"label":"UNSAFE","probs":p,"tokens":len(toks),"reason":"profanity_short_text"}
    # 3) allow-listed negations and greetings bypass the model entirely
    if SAFE_RE.match(clean) or NEGATION_ONLY.match(clean) or GREETING_RE.match(clean):
        p=[0,0]; p[SAFE_ID]=1.; return {"safe":True,"unsafe_prob":0.0,"label":"SAFE","probs":p,"tokens":len(toks),"reason":"allow_or_greeting"}
    # 4) "i don't like ..." with no sensitive/profane terms: run the model
    #    but only flag at an extreme (>= 0.98) unsafe probability.
    #    (Substring checks here match inside words — see SENSITIVE_TERMS note.)
    if NEUTRAL_DISLIKE.match(clean):
        if not any(t in clean for t in SENSITIVE_TERMS) and not any(t in clean for t in PROFANITY_TERMS):
            enc = tok(text, return_tensors="pt", truncation=True, padding=True, max_length=512); enc = {k:v.to(DEVICE) for k,v in enc.items()}
            probs = torch.softmax(txtM(**enc).logits[0], dim=-1).cpu().tolist(); up=float(probs[UNSAFE_ID])
            return {"safe": up<0.98, "unsafe_prob": up, "label":"SAFE" if up<0.98 else "UNSAFE", "probs": probs, "tokens": int(enc["input_ids"].shape[1]), "reason":"neutral_dislike_relaxed"}
    # 5) default path: model verdict against a length-dependent threshold.
    #    The original (un-normalized) text is what gets tokenized.
    enc = tok(text, return_tensors="pt", truncation=True, padding=True, max_length=512); enc = {k:v.to(DEVICE) for k,v in enc.items()}
    logits = txtM(**enc).logits[0]; probs = torch.softmax(logits, dim=-1).cpu().tolist(); up=float(probs[UNSAFE_ID]); n=int(enc["input_ids"].shape[1])
    # n here counts model tokens (incl. special tokens), not whitespace words.
    thr = SHORT_MSG_UNSAFE_THR if n<=SHORT_MSG_MAX_TOKENS else TEXT_UNSAFE_THR
    return {"safe": up<thr, "unsafe_prob": up, "label": str(int(torch.argmax(logits))), "probs": probs, "tokens": n, "reason": "short_msg_threshold" if n<=SHORT_MSG_MAX_TOKENS else "global_threshold"}

# ---------- Image ----------
class SafetyResNet(torch.nn.Module):
    """ResNet-50 backbone (randomly initialized here; real weights come from
    the checkpoint loaded below) with a small 2-way safety head.

    The attribute names (feature_extractor / pool / cls) and the layer order
    inside `cls` must not change: the checkpoint is loaded with strict=True
    against exactly these state_dict keys.
    """
    def __init__(self):
        super().__init__()
        backbone = torchvision.models.resnet50(weights=None)
        # Keep everything up to and including layer4; drop avgpool + fc.
        self.feature_extractor = torch.nn.Sequential(*list(backbone.children())[:8])
        self.pool = torch.nn.AdaptiveAvgPool2d(1)
        self.cls = torch.nn.Sequential(
            torch.nn.Linear(2048, 512),
            torch.nn.ReLU(True),
            torch.nn.Dropout(0.30),
            torch.nn.Linear(512, 2),
        )

    def forward(self, x):
        features = self.feature_extractor(x)
        pooled = torch.flatten(self.pool(features), 1)
        return self.cls(pooled)

# Load the fine-tuned weights (strict=True: state_dict keys must match the
# SafetyResNet attribute names exactly) and freeze into eval mode.
if not IMG_DIR.exists(): raise RuntimeError(f"Missing Image dir: {IMG_DIR}")
imgM = SafetyResNet().to(DEVICE); imgM.load_state_dict(torch.load(IMG_DIR/"resnet_safety_classifier.pth", map_location=DEVICE), strict=True); imgM.eval()
# Standard ImageNet eval transform: resize 256 -> center-crop 224 -> normalize.
img_tf = transforms.Compose([ transforms.Resize(256, interpolation=transforms.InterpolationMode.BILINEAR), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]) ])

@torch.no_grad()
def image_safe_payload(pil: Image.Image) -> Dict:
    """Classify a single PIL image; flags it when the unsafe-class
    probability reaches IMG_UNSAFE_THR."""
    tensor = img_tf(pil.convert("RGB")).unsqueeze(0).to(DEVICE)
    logits = imgM(tensor)[0]
    probs = torch.softmax(logits, dim=0).cpu().tolist()
    unsafe_p = float(probs[IMG_UNSAFE_INDEX])
    return {"safe": unsafe_p < IMG_UNSAFE_THR, "unsafe_prob": unsafe_p, "probs": probs}

# ---------- Audio ----------
# faster-whisper ASR: fp16 on GPU, int8 quantization on CPU.
compute_type = "float16" if DEVICE=="cuda" else "int8"
asr = WhisperModel(WHISPER_MODEL_NAME, device=DEVICE, compute_type=compute_type)
if not AUD_DIR.exists(): raise RuntimeError(f"Missing Audio dir: {AUD_DIR}")
# Transcript classifier loaded via joblib; only predict_proba is used here
# (presumably a scikit-learn pipeline — verify against the training code).
text_clf = joblib.load(AUD_DIR/"text_pipeline_balanced.joblib")

def _ffmpeg_to_wav(src: bytes) -> bytes:
    """Transcode arbitrary audio bytes to 16 kHz mono WAV via ffmpeg.

    Raises RuntimeError if ffmpeg is not installed. If ffmpeg runs but
    fails on the input, the original bytes are returned unchanged as a
    best-effort fallback (the ASR may still be able to decode them).
    """
    with tempfile.TemporaryDirectory() as workdir:
        src_path = pathlib.Path(workdir) / "in"
        dst_path = pathlib.Path(workdir) / "out.wav"
        src_path.write_bytes(src)
        cmd = ["ffmpeg", "-y", "-i", str(src_path), "-ac", "1", "-ar", "16000", str(dst_path)]
        try:
            subprocess.run(cmd, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
        except FileNotFoundError as e:
            raise RuntimeError("FFmpeg not found on PATH.") from e
        except subprocess.CalledProcessError:
            return src
        return dst_path.read_bytes()

def _transcribe_wav_bytes(w: bytes) -> str:
    """Write WAV bytes to a temp file and transcribe them with faster-whisper.

    Returns the concatenated segment texts, stripped. Uses
    tempfile.TemporaryDirectory so cleanup is guaranteed and removes the
    whole tree; the previous mkdtemp + unlink/rmdir dance silently leaked
    the directory whenever rmdir found it non-empty.
    """
    with tempfile.TemporaryDirectory() as td:
        wav_path = pathlib.Path(td) / "in.wav"
        wav_path.write_bytes(w)
        # transcribe() yields segments lazily; consume them while the file exists.
        segments, _info = asr.transcribe(str(wav_path), beam_size=5, language="en")
        return " ".join(seg.text for seg in segments).strip()

def audio_safe_from_bytes(raw: bytes) -> Dict:
    """Full audio pipeline: transcode -> transcribe -> classify the transcript."""
    transcript = _transcribe_wav_bytes(_ffmpeg_to_wav(raw))
    class_probs = text_clf.predict_proba([transcript])[0].tolist()
    unsafe_p = float(class_probs[AUDIO_UNSAFE_INDEX])
    return {"safe": unsafe_p < AUDIO_UNSAFE_THR, "unsafe_prob": unsafe_p, "text": transcript, "probs": class_probs}

# ---------- Routes (/api/*) ----------
@app.get("/api/health")
def health():
    """Liveness endpoint that also reports the effective configuration."""
    text_cfg = {
        "TEXT_UNSAFE_THR": TEXT_UNSAFE_THR,
        "SHORT_MSG_MAX_TOKENS": SHORT_MSG_MAX_TOKENS,
        "SHORT_MSG_UNSAFE_THR": SHORT_MSG_UNSAFE_THR,
    }
    return {
        "ok": True,
        "device": DEVICE,
        "whisper": WHISPER_MODEL_NAME,
        "img": {"unsafe_threshold": IMG_UNSAFE_THR, "unsafe_index": IMG_UNSAFE_INDEX},
        "text_thresholds": text_cfg,
        "audio": {"unsafe_index": AUDIO_UNSAFE_INDEX, "unsafe_threshold": AUDIO_UNSAFE_THR},
        "safe_unsafe_indices(text_model)": {"SAFE_ID": SAFE_ID, "UNSAFE_ID": UNSAFE_ID},
    }

@app.post("/api/check_text")
def check_text(text: str = Form(...)):
    """Screen a text message. 400 on empty/whitespace input, 500 on model failure."""
    if not text.strip():
        raise HTTPException(400, "Empty text")
    try:
        return text_safe_payload(text)
    except Exception as e:
        # Chain the cause explicitly (raise ... from e) so the original
        # traceback is preserved in server logs.
        raise HTTPException(500, f"Text screening error: {e}") from e

@app.post("/api/check_image")
async def check_image(file: UploadFile = File(...)):
    """Screen an uploaded image. 400 on empty or undecodable data, 500 on model failure."""
    data = await file.read()
    if not data:
        raise HTTPException(400, "Empty image")
    try:
        pil = Image.open(io.BytesIO(data))
    except Exception as e:
        # Chain so the PIL decode error is visible in server logs.
        raise HTTPException(400, "Invalid image") from e
    try:
        return image_safe_payload(pil)
    except Exception as e:
        raise HTTPException(500, f"Image screening error: {e}") from e

@app.post("/api/check_audio")
async def check_audio(file: UploadFile = File(...)):
    """Screen an uploaded audio clip (any format ffmpeg can decode).
    400 on empty upload; 500 on pipeline failure (e.g. ffmpeg missing)."""
    raw = await file.read()
    if not raw:
        raise HTTPException(400, "Empty audio")
    try:
        return audio_safe_from_bytes(raw)
    except RuntimeError as e:
        # Raised by _ffmpeg_to_wav when ffmpeg is not on PATH.
        raise HTTPException(500, f"{e}") from e
    except Exception as e:
        raise HTTPException(500, f"Audio processing error: {e}") from e

# ---------- Static ----------
# Serve the web UI: prefer ./static; fall back to the repo root when it holds
# an index.html; otherwise expose a plain-text hint at "/". Mounted after the
# route definitions above so the /api/* paths keep precedence.
static_dir = BASE / "static"
if static_dir.exists():
    app.mount("/", StaticFiles(directory=str(static_dir), html=True), name="static")
elif (BASE/"index.html").exists():
    app.mount("/", StaticFiles(directory=str(BASE), html=True), name="static-root")
else:
    @app.get("/", response_class=PlainTextResponse)
    def _root_fallback(): return "BubbleGuard API is running. Add index.html to repo root."