import os
# Tame thread/parallelism behavior *before* the heavy libraries are imported:
# tokenizers reads TOKENIZERS_PARALLELISM at import time, and OMP/MKL thread
# counts must be set before the native math libraries initialize.
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
os.environ.setdefault("OMP_NUM_THREADS", "2")
os.environ.setdefault("MKL_NUM_THREADS", "2")

import gradio as gr
import torch
import torch.nn.functional as F
from transformers import (
    DebertaTokenizer,
    DebertaForSequenceClassification,
    T5Tokenizer,
    T5ForConditionalGeneration,
)

# keep CPU predictable
torch.set_num_threads(2)
torch.set_num_interop_threads(1)

# Hugging Face Hub repo ids for the two fine-tuned models.
DETECT_REPO = "jokugeorgin/CI_MA_Detect"    # DeBERTa binary classifier
REFRAME_REPO = "jokugeorgin/CI_MA_Reframe"  # T5 rephrasing model
class MicroaggressionPipeline:
    """Two-stage pipeline: DeBERTa detection, then T5 reframing.

    Both models are downloaded from the Hugging Face Hub at construction
    time, moved to GPU when available, and put in eval mode. A tiny warm-up
    pass runs once so the first real request is fast.
    """

    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        # ---- Load detection (DeBERTa) ----
        print("Loading detection model...")
        self.det_tok = DebertaTokenizer.from_pretrained(DETECT_REPO)
        self.det_mod = DebertaForSequenceClassification.from_pretrained(
            DETECT_REPO, num_labels=2
        ).to(self.device)
        self.det_mod.eval()

        # ---- Load reframing (T5) ----
        print("Loading reframing model...")
        self.ref_tok = T5Tokenizer.from_pretrained(REFRAME_REPO)
        self.ref_mod = T5ForConditionalGeneration.from_pretrained(
            REFRAME_REPO
        ).to(self.device)
        self.ref_mod.eval()

        # warm-up (tiny forward pass so first request is snappy)
        print("Warming up...")
        _ = self.analyze("hello", threshold=0.5, k=1)
        print("Ready!")

    @torch.no_grad()
    def detect(self, text: str, threshold: float = 0.5):
        """Classify *text*; return (is_microaggression, confidence, raw_label).

        FIX: the threshold is applied to P(micro) = probs[1] directly.
        The previous argmax-then-threshold scheme made thresholds below
        0.5 a no-op (argmax == 1 already implies P(micro) > 0.5), so the
        UI slider's 0.3-0.5 range had no effect. `confidence` is the
        probability of the *reported* decision.
        """
        enc = self.det_tok(
            text,
            max_length=128,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )
        enc = {key: val.to(self.device) for key, val in enc.items()}
        logits = self.det_mod(**enc).logits
        probs = F.softmax(logits, dim=1)[0]
        micro_prob = float(probs[1])
        is_micro = micro_prob >= threshold
        pred_idx = 1 if is_micro else 0
        conf = micro_prob if is_micro else float(probs[0])
        return is_micro, conf, f"LABEL_{pred_idx}"

    @torch.no_grad()
    def reframe(self, text: str, k: int = 3):
        """Generate up to *k* reframings of *text*, deduplicated and padded to k."""
        # Beam width; generation length is capped for latency on CPU.
        num_beams = 4
        # FIX: num_return_sequences must not exceed num_beams or generate()
        # raises a ValueError — the old cap of 5 crashed for k >= 5.
        n_seq = max(1, min(k, num_beams))
        pref = f"rephrase: {text}"
        enc = self.ref_tok(
            pref, return_tensors="pt", max_length=192, truncation=True
        )
        # `key`, not `k`: avoid visually shadowing the `k` parameter.
        enc = {key: val.to(self.device) for key, val in enc.items()}
        out = self.ref_mod.generate(
            **enc,
            max_length=192,
            num_beams=num_beams,
            num_return_sequences=n_seq,
            no_repeat_ngram_size=2,
            do_sample=True,  # sampled beams add diversity across the options
            temperature=0.7,
            early_stopping=True,
        )
        # Decode, drop duplicates, keep at most k unique strings.
        seen = set()
        options = []
        for seq in out:
            s = self.ref_tok.decode(seq, skip_special_tokens=True).strip()
            if s and s not in seen:
                seen.add(s)
                options.append(s)
            if len(options) >= k:
                break
        # Pad with the last option so callers always receive k entries
        # (unless nothing decoded at all, in which case return []).
        while len(options) < k and options:
            options.append(options[-1])
        return options[:k]

    def analyze(self, text: str, threshold: float = 0.5, k: int = 3):
        """Detect, then reframe only when a microaggression was flagged.

        Returns (is_micro, confidence, raw_label, reframing_options).
        """
        is_micro, conf, raw_label = self.detect(text, threshold=threshold)
        options = self.reframe(text, k=k) if is_micro else []
        return is_micro, conf, raw_label, options
# Module-level singleton: models are loaded once at import time (this also
# triggers the Hub downloads and the warm-up pass), so every request reuses
# the same pipeline instance.
PIPELINE = MicroaggressionPipeline()
def gradio_interface(text: str, threshold: float):
    """Gradio callback: analyze *text* and fan the result out to 4 outputs.

    Returns (result_markdown, option1, option2, option3); empty input
    short-circuits with an error message and blank options.
    """
    cleaned = (text or "").strip()
    if not cleaned:
        return "❌ Please enter some text", "", "", ""

    is_micro, conf, raw_label, options = PIPELINE.analyze(
        cleaned, threshold=float(threshold), k=3
    )

    verdict = (
        "⚠️ **Microaggression Detected**"
        if is_micro
        else "✅ **No Microaggression Detected**"
    )
    header = f"{verdict} \nConfidence: {conf:.1%} \nRaw label: {raw_label}"

    # Always hand the UI exactly three option fields, blank-padded.
    first, second, third = (options + ["", "", ""])[:3]
    return header, first, second, third
# UI layout: input column (textbox + threshold slider + button) on the left,
# result markdown and three reframing textboxes on the right.
with gr.Blocks(title="Microaggression Analyzer") as demo:
    gr.Markdown("# 🔍 Microaggression Analyzer\nDetect and reframe microaggressions in text")
    with gr.Row():
        with gr.Column():
            text_in = gr.Textbox(
                label="Enter text to analyze",
                placeholder="Type or paste text...",
                lines=3,
            )
            # Slider range 0.3-0.9 maps onto the detection threshold used
            # by the pipeline's detect() step.
            thr = gr.Slider(
                minimum=0.3, maximum=0.9, value=0.5, step=0.1, label="Detection Threshold"
            )
            analyze_btn = gr.Button("Analyze", variant="primary")
        with gr.Column():
            result_md = gr.Markdown(label="Result")
            gr.Markdown("### Suggested Reframings")
            with gr.Row():
                opt1 = gr.Textbox(label="Option 1", lines=2)
                opt2 = gr.Textbox(label="Option 2", lines=2)
                opt3 = gr.Textbox(label="Option 3", lines=2)
    # Clickable example inputs (text, threshold) pre-filling the form.
    gr.Examples(
        examples=[
            ["You speak good English for someone from there.", 0.5],
            ["Where are you really from?", 0.5],
            ["You're so articulate.", 0.5],
        ],
        inputs=[text_in, thr],
    )
    # Wire the button to the callback; outputs map 1:1 onto the 4-tuple
    # returned by gradio_interface.
    analyze_btn.click(
        fn=gradio_interface,
        inputs=[text_in, thr],
        outputs=[result_md, opt1, opt2, opt3],
        # (gradio v5) optional per-event limit:
        # concurrency_limit="default"
    )

# (gradio v5) no concurrency_count; use default_concurrency_limit if you want
demo.queue(default_concurrency_limit=2, max_size=16)
demo.launch(show_api=True)