|
|
import os


# These must be set BEFORE torch/transformers are imported below, because
# the libraries read them at import time.

# Disable HF fast-tokenizers parallelism (avoids the fork-related warning
# and extra tokenizer worker threads).
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")


# Cap OpenMP worker threads (CPU-bound inference on a small host).
os.environ.setdefault("OMP_NUM_THREADS", "2")


# Cap MKL worker threads to match.
os.environ.setdefault("MKL_NUM_THREADS", "2")
|
|
|
|
|
import gradio as gr |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from transformers import ( |
|
|
DebertaTokenizer, |
|
|
DebertaForSequenceClassification, |
|
|
T5Tokenizer, |
|
|
T5ForConditionalGeneration, |
|
|
) |
|
|
|
|
|
|
|
|
# Keep PyTorch's CPU usage modest: 2 intra-op threads, 1 inter-op thread.
# NOTE(review): set_num_interop_threads must run before any parallel torch
# work starts, which is why this sits right after the imports.
torch.set_num_threads(2)


torch.set_num_interop_threads(1)


# Hugging Face Hub repositories for the two fine-tuned models.
DETECT_REPO = "jokugeorgin/CI_MA_Detect"  # DeBERTa binary classifier


REFRAME_REPO = "jokugeorgin/CI_MA_Reframe"  # T5 rephrasing model
|
|
|
|
|
|
|
|
class MicroaggressionPipeline:
    """Two-stage text pipeline: detect microaggressions, then reframe them.

    Stage 1 is a DeBERTa binary classifier (label 1 = microaggression);
    stage 2 is a T5 seq2seq model that rewrites flagged text.  Both models
    are downloaded from the Hugging Face Hub at construction time and kept
    in eval mode on ``self.device``.
    """

    def __init__(self):
        # Prefer GPU when available; every tensor is moved to this device.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        print("Loading detection model...")
        self.det_tok = DebertaTokenizer.from_pretrained(DETECT_REPO)
        self.det_mod = DebertaForSequenceClassification.from_pretrained(
            DETECT_REPO, num_labels=2
        ).to(self.device)
        self.det_mod.eval()

        print("Loading reframing model...")
        self.ref_tok = T5Tokenizer.from_pretrained(REFRAME_REPO)
        self.ref_mod = T5ForConditionalGeneration.from_pretrained(
            REFRAME_REPO
        ).to(self.device)
        self.ref_mod.eval()

        # One throwaway pass so the first real request doesn't pay the
        # lazy-initialization cost.
        print("Warming up...")
        _ = self.analyze("hello", threshold=0.5, k=1)
        print("Ready!")

    @torch.no_grad()
    def detect(self, text: str, threshold: float = 0.5):
        """Classify *text* for a microaggression.

        Args:
            text: Input string (truncated/padded to 128 tokens).
            threshold: Minimum softmax probability of the predicted class
                required to flag the text.

        Returns:
            ``(is_micro, confidence, raw_label)`` — *is_micro* is True only
            when the argmax class is 1 AND its probability meets *threshold*.
        """
        enc = self.det_tok(
            text,
            max_length=128,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )
        enc = {name: tensor.to(self.device) for name, tensor in enc.items()}
        logits = self.det_mod(**enc).logits
        probs = F.softmax(logits, dim=1)[0]
        pred_idx = int(torch.argmax(logits, dim=1))
        conf = float(probs[pred_idx])

        is_micro = bool(pred_idx) and (conf >= threshold)
        return is_micro, conf, f"LABEL_{pred_idx}"

    @torch.no_grad()
    def reframe(self, text: str, k: int = 3):
        """Return up to *k* distinct reframings of *text*.

        Uses beam-multinomial sampling (beams + temperature).  Duplicate or
        empty decodes are dropped; if fewer than *k* unique options survive,
        the last one is repeated so callers always get *k* strings (unless
        generation produced nothing at all).
        """
        pref = f"rephrase: {text}"
        enc = self.ref_tok(
            pref, return_tensors="pt", max_length=192, truncation=True
        )
        enc = {name: tensor.to(self.device) for name, tensor in enc.items()}

        # generate() raises ValueError when num_return_sequences exceeds
        # num_beams, so clamp the request to the beam count (previously
        # min(k, 5) with 4 beams crashed for k == 5).  The padding loop
        # below still tops the result up to k.
        num_beams = 4
        n_seq = max(1, min(k, num_beams))
        out = self.ref_mod.generate(
            **enc,
            max_length=192,
            num_beams=num_beams,
            num_return_sequences=n_seq,
            no_repeat_ngram_size=2,
            do_sample=True,
            temperature=0.7,
            early_stopping=True,
        )

        # Deduplicate decoded candidates, preserving generation order.
        seen = set()
        options = []
        for seq in out:
            s = self.ref_tok.decode(seq, skip_special_tokens=True).strip()
            if s and s not in seen:
                seen.add(s)
                options.append(s)
            if len(options) >= k:
                break
        # Pad with the last option so the caller always receives k items.
        while len(options) < k and options:
            options.append(options[-1])
        return options[:k]

    def analyze(self, text: str, threshold: float = 0.5, k: int = 3):
        """Detect, then reframe only when a microaggression is flagged.

        Returns ``(is_micro, confidence, raw_label, options)`` where
        *options* is an empty list for clean text.
        """
        is_micro, conf, raw_label = self.detect(text, threshold=threshold)
        options = self.reframe(text, k=k) if is_micro else []
        return is_micro, conf, raw_label, options
|
|
|
|
|
|
|
|
# Instantiate once at import time: downloads/loads both models and runs the
# warm-up pass, so the app is slow to start but fast on first request.
PIPELINE = MicroaggressionPipeline()
|
|
|
|
|
|
|
|
def gradio_interface(text: str, threshold: float):
    """UI callback: analyze *text* and map the result onto four outputs.

    Returns a markdown verdict plus three reframing slots; the slots are
    blank when the text is clean or fewer suggestions were produced.
    """
    cleaned = (text or "").strip()
    if not cleaned:
        return "❌ Please enter some text", "", "", ""

    is_micro, conf, raw_label, options = PIPELINE.analyze(
        cleaned, threshold=float(threshold), k=3
    )

    status = (
        "⚠️ **Microaggression Detected**"
        if is_micro
        else "✅ **No Microaggression Detected**"
    )
    header = f"{status} \nConfidence: {conf:.1%} \nRaw label: {raw_label}"

    # Always hand back exactly three suggestion strings.
    slots = list(options[:3])
    while len(slots) < 3:
        slots.append("")
    return header, slots[0], slots[1], slots[2]
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Gradio UI: input column (text, threshold slider, button) on the left,
# the markdown verdict on the right, three reframing slots underneath.
# Component creation order inside the `with` contexts defines the layout.
# ---------------------------------------------------------------------------
with gr.Blocks(title="Microaggression Analyzer") as demo:
    gr.Markdown("# 🔍 Microaggression Analyzer\nDetect and reframe microaggressions in text")

    with gr.Row():
        with gr.Column():
            # Free-form text to analyze.
            text_in = gr.Textbox(
                label="Enter text to analyze",
                placeholder="Type or paste text...",
                lines=3,
            )
            # Confidence cutoff, passed through to the detection stage.
            thr = gr.Slider(
                minimum=0.3, maximum=0.9, value=0.5, step=0.1, label="Detection Threshold"
            )
            analyze_btn = gr.Button("Analyze", variant="primary")
        with gr.Column():
            # Verdict rendered as markdown.
            result_md = gr.Markdown(label="Result")

    gr.Markdown("### Suggested Reframings")
    with gr.Row():
        opt1 = gr.Textbox(label="Option 1", lines=2)
        opt2 = gr.Textbox(label="Option 2", lines=2)
        opt3 = gr.Textbox(label="Option 3", lines=2)

    # Canned example inputs (text, threshold) shown under the form.
    gr.Examples(
        examples=[
            ["You speak good English for someone from there.", 0.5],
            ["Where are you really from?", 0.5],
            ["You're so articulate.", 0.5],
        ],
        inputs=[text_in, thr],
    )

    # Wire the button to the callback: 2 inputs -> 4 outputs.
    analyze_btn.click(
        fn=gradio_interface,
        inputs=[text_in, thr],
        outputs=[result_md, opt1, opt2, opt3],




    )




# Queue requests with a small concurrency cap, matching the 2-thread CPU
# budget configured above.
demo.queue(default_concurrency_limit=2, max_size=16)


demo.launch(show_api=True)
|
|
|