# Hugging Face Space app.py — "Update app.py", commit 66fb10b (verified), uploaded by jokugeorgin
# Environment knobs must be set before the heavy libraries are imported:
# disable tokenizer thread fan-out and cap BLAS threads for a small CPU Space.
import os
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
os.environ.setdefault("OMP_NUM_THREADS", "2")
os.environ.setdefault("MKL_NUM_THREADS", "2")
import gradio as gr
import torch
import torch.nn.functional as F
from transformers import (
    DebertaTokenizer,
    DebertaForSequenceClassification,
    T5Tokenizer,
    T5ForConditionalGeneration,
)
# keep CPU predictable
torch.set_num_threads(2)
torch.set_num_interop_threads(1)
# Hugging Face Hub repos: binary detector (DeBERTa) and reframer (T5).
DETECT_REPO = "jokugeorgin/CI_MA_Detect"
REFRAME_REPO = "jokugeorgin/CI_MA_Reframe"
class MicroaggressionPipeline:
    """Two-stage pipeline: DeBERTa microaggression detection + T5 reframing.

    Both models are downloaded from the Hub once at construction time,
    moved to the best available device (CUDA if present, else CPU), and
    kept in eval mode. A tiny warm-up request is run so the first real
    request is fast.
    """

    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        # ---- Load detection (DeBERTa) ----
        print("Loading detection model...")
        self.det_tok = DebertaTokenizer.from_pretrained(DETECT_REPO)
        self.det_mod = DebertaForSequenceClassification.from_pretrained(
            DETECT_REPO, num_labels=2
        ).to(self.device)
        self.det_mod.eval()
        # ---- Load reframing (T5) ----
        print("Loading reframing model...")
        self.ref_tok = T5Tokenizer.from_pretrained(REFRAME_REPO)
        self.ref_mod = T5ForConditionalGeneration.from_pretrained(
            REFRAME_REPO
        ).to(self.device)
        self.ref_mod.eval()
        # warm-up (tiny forward pass so first request is snappy)
        print("Warming up...")
        _ = self.analyze("hello", threshold=0.5, k=1)
        print("Ready!")

    @torch.no_grad()
    def detect(self, text: str, threshold: float = 0.5):
        """Classify *text* as microaggression / not.

        Returns ``(is_micro, confidence, raw_label)``. ``is_micro`` is
        True only when the argmax class is 1 AND its softmax probability
        meets *threshold*; ``raw_label`` is ``"LABEL_0"``/``"LABEL_1"``.
        """
        enc = self.det_tok(
            text,
            max_length=128,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )
        enc = {key: val.to(self.device) for key, val in enc.items()}
        logits = self.det_mod(**enc).logits
        probs = F.softmax(logits, dim=1)[0]
        pred_idx = int(torch.argmax(logits, dim=1))
        conf = float(probs[pred_idx])
        # Label index 1 is treated as "microaggression"; gate on threshold.
        is_micro = bool(pred_idx) and (conf >= threshold)
        return is_micro, conf, f"LABEL_{pred_idx}"

    @torch.no_grad()
    def reframe(self, text: str, k: int = 3):
        """Generate up to *k* distinct reframings of *text*.

        Uses beam sampling (4 beams, temperature 0.7), de-duplicates the
        decoded outputs, and pads by repeating the last option so callers
        always receive *k* items (unless nothing non-empty was generated).
        """
        # capped for latency on CPU
        num_beams = 4
        # BUG FIX: the original capped num_return_sequences at 5 while using
        # only 4 beams; transformers requires num_return_sequences <= num_beams
        # for beam search, so any k >= 5 made generate() raise ValueError.
        # Cap at the beam count instead.
        n_seqs = max(1, min(k, num_beams))
        # presumably this matches the task prefix used at fine-tuning — confirm
        pref = f"rephrase: {text}"
        enc = self.ref_tok(
            pref, return_tensors="pt", max_length=192, truncation=True
        )
        enc = {key: val.to(self.device) for key, val in enc.items()}
        out = self.ref_mod.generate(
            **enc,
            max_length=192,
            num_beams=num_beams,
            num_return_sequences=n_seqs,
            no_repeat_ngram_size=2,
            do_sample=True,
            temperature=0.7,
            early_stopping=True,
        )
        # De-duplicate while preserving generation (beam-score) order.
        seen = set()
        options = []
        for seq in out:
            s = self.ref_tok.decode(seq, skip_special_tokens=True).strip()
            if s and s not in seen:
                seen.add(s)
                options.append(s)
            if len(options) >= k:
                break
        # Pad with the last option so the UI always gets k fields.
        while len(options) < k and options:
            options.append(options[-1])
        return options[:k]

    def analyze(self, text: str, threshold: float = 0.5, k: int = 3):
        """Detect first; generate reframings only when *text* is flagged.

        Returns ``(is_micro, confidence, raw_label, options)`` where
        ``options`` is ``[]`` when no microaggression was detected.
        """
        is_micro, conf, raw_label = self.detect(text, threshold=threshold)
        options = self.reframe(text, k=k) if is_micro else []
        return is_micro, conf, raw_label, options
# Module-level singleton: models load (and warm up) once at import time,
# so every Gradio request reuses the same pipeline instance.
PIPELINE = MicroaggressionPipeline()
def gradio_interface(text: str, threshold: float):
    """Gradio callback: analyze *text* and return (header_md, opt1, opt2, opt3).

    Empty/whitespace-only input short-circuits with an error header and
    three blank option fields, without touching the pipeline.
    """
    cleaned = (text or "").strip()
    if not cleaned:
        return "❌ Please enter some text", "", "", ""
    is_micro, conf, raw_label, options = PIPELINE.analyze(
        cleaned, threshold=float(threshold), k=3
    )
    verdict = (
        "⚠️ **Microaggression Detected**"
        if is_micro
        else "✅ **No Microaggression Detected**"
    )
    header = f"{verdict} \nConfidence: {conf:.1%} \nRaw label: {raw_label}"
    # The UI always expects exactly three option fields — pad with blanks.
    padded = (options + ["", "", ""])[:3]
    return (header, *padded)
# ---- UI layout: input column (textbox + threshold slider + button) on the
# left, result markdown on the right, three reframing fields below. ----
with gr.Blocks(title="Microaggression Analyzer") as demo:
    gr.Markdown("# 🔍 Microaggression Analyzer\nDetect and reframe microaggressions in text")
    with gr.Row():
        with gr.Column():
            text_in = gr.Textbox(
                label="Enter text to analyze",
                placeholder="Type or paste text...",
                lines=3,
            )
            # Slider range matches detect()'s confidence gate (0.3–0.9).
            thr = gr.Slider(
                minimum=0.3, maximum=0.9, value=0.5, step=0.1, label="Detection Threshold"
            )
            analyze_btn = gr.Button("Analyze", variant="primary")
        with gr.Column():
            result_md = gr.Markdown(label="Result")
    gr.Markdown("### Suggested Reframings")
    with gr.Row():
        opt1 = gr.Textbox(label="Option 1", lines=2)
        opt2 = gr.Textbox(label="Option 2", lines=2)
        opt3 = gr.Textbox(label="Option 3", lines=2)
    # Clickable example inputs (text + threshold pairs).
    gr.Examples(
        examples=[
            ["You speak good English for someone from there.", 0.5],
            ["Where are you really from?", 0.5],
            ["You're so articulate.", 0.5],
        ],
        inputs=[text_in, thr],
    )
    # Outputs must match gradio_interface's 4-tuple return order.
    analyze_btn.click(
        fn=gradio_interface,
        inputs=[text_in, thr],
        outputs=[result_md, opt1, opt2, opt3],
        # (gradio v5) optional per-event limit:
        # concurrency_limit="default"
    )
# (gradio v5) no concurrency_count; use default_concurrency_limit if you want
demo.queue(default_concurrency_limit=2, max_size=16)
demo.launch(show_api=True)