"""Microaggression Analyzer — Gradio demo app.

Two-stage pipeline:
  1. Detection: a fine-tuned DeBERTa sequence classifier (binary).
  2. Reframing: a fine-tuned T5 model that rewrites flagged text.

NOTE(review): this file arrived with all newlines stripped (a syntax
error); it has been reconstructed into valid, formatted Python with no
intended behavior change.
"""

import os

# Keep tokenizer/BLAS thread pools small and predictable on shared CPU hosts.
# Must be set before the libraries that read them are imported.
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
os.environ.setdefault("OMP_NUM_THREADS", "2")
os.environ.setdefault("MKL_NUM_THREADS", "2")

import gradio as gr
import torch
import torch.nn.functional as F
from transformers import (
    DebertaTokenizer,
    DebertaForSequenceClassification,
    T5Tokenizer,
    T5ForConditionalGeneration,
)

# keep CPU predictable
torch.set_num_threads(2)
torch.set_num_interop_threads(1)

DETECT_REPO = "jokugeorgin/CI_MA_Detect"
REFRAME_REPO = "jokugeorgin/CI_MA_Reframe"


class MicroaggressionPipeline:
    """Loads both models once and serves detect/reframe/analyze calls."""

    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        # ---- Load detection (DeBERTa) ----
        print("Loading detection model...")
        self.det_tok = DebertaTokenizer.from_pretrained(DETECT_REPO)
        self.det_mod = DebertaForSequenceClassification.from_pretrained(
            DETECT_REPO, num_labels=2
        ).to(self.device)
        self.det_mod.eval()

        # ---- Load reframing (T5) ----
        print("Loading reframing model...")
        self.ref_tok = T5Tokenizer.from_pretrained(REFRAME_REPO)
        self.ref_mod = T5ForConditionalGeneration.from_pretrained(
            REFRAME_REPO
        ).to(self.device)
        self.ref_mod.eval()

        # warm-up (tiny forward pass so first request is snappy)
        print("Warming up...")
        _ = self.analyze("hello", threshold=0.5, k=1)
        print("Ready!")

    @torch.no_grad()
    def detect(self, text: str, threshold: float = 0.5):
        """Classify *text*.

        Returns (is_micro, confidence, raw_label). ``is_micro`` is True only
        when the argmax class is index 1 AND its softmax probability meets
        ``threshold``.
        """
        enc = self.det_tok(
            text,
            max_length=128,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )
        enc = {name: t.to(self.device) for name, t in enc.items()}
        logits = self.det_mod(**enc).logits
        probs = F.softmax(logits, dim=1)[0]
        pred_idx = int(torch.argmax(logits, dim=1))
        conf = float(probs[pred_idx])
        # Label index 1 is treated as "microaggression"; index 0 as benign.
        is_micro = bool(pred_idx) and (conf >= threshold)
        return is_micro, conf, f"LABEL_{pred_idx}"

    @torch.no_grad()
    def reframe(self, text: str, k: int = 3):  # capped for latency on CPU
        """Generate up to *k* distinct rewrites of *text* (T5, beam sampling).

        Duplicates from the sampler are dropped; if fewer than *k* unique
        outputs remain, the last one is repeated so the caller always gets
        *k* entries (when any exist). Sampling makes output non-deterministic.
        """
        pref = f"rephrase: {text}"
        enc = self.ref_tok(
            pref, return_tensors="pt", max_length=192, truncation=True
        )
        # NOTE: loop variable renamed from `k` to avoid shadowing the
        # `k` parameter (confusing even though comprehension scope is isolated).
        enc = {name: t.to(self.device) for name, t in enc.items()}
        out = self.ref_mod.generate(
            **enc,
            max_length=192,
            num_beams=4,
            num_return_sequences=max(1, min(k, 5)),
            no_repeat_ngram_size=2,
            do_sample=True,
            temperature=0.7,
            early_stopping=True,
        )
        seen = set()
        options = []
        for seq in out:
            s = self.ref_tok.decode(seq, skip_special_tokens=True).strip()
            if s and s not in seen:
                seen.add(s)
                options.append(s)
            if len(options) >= k:
                break
        # Pad with the last unique option so the UI always has k fields.
        while len(options) < k and options:
            options.append(options[-1])
        return options[:k]

    def analyze(self, text: str, threshold: float = 0.5, k: int = 3):
        """Detect, then reframe only if flagged.

        Returns (is_micro, confidence, raw_label, options) where ``options``
        is empty when no microaggression was detected.
        """
        is_micro, conf, raw_label = self.detect(text, threshold=threshold)
        options = self.reframe(text, k=k) if is_micro else []
        return is_micro, conf, raw_label, options


PIPELINE = MicroaggressionPipeline()


def gradio_interface(text: str, threshold: float):
    """Gradio event handler: returns (markdown_header, opt1, opt2, opt3)."""
    text = (text or "").strip()
    if not text:
        return "❌ Please enter some text", "", "", ""
    is_micro, conf, raw_label, options = PIPELINE.analyze(
        text, threshold=float(threshold), k=3
    )
    if is_micro:
        header = f"⚠️ **Microaggression Detected** \nConfidence: {conf:.1%} \nRaw label: {raw_label}"
    else:
        header = f"✅ **No Microaggression Detected** \nConfidence: {conf:.1%} \nRaw label: {raw_label}"
    # pad to 3 fields for the UI
    opts = (options + ["", "", ""])[:3]
    return header, opts[0], opts[1], opts[2]


with gr.Blocks(title="Microaggression Analyzer") as demo:
    gr.Markdown("# 🔍 Microaggression Analyzer\nDetect and reframe microaggressions in text")
    with gr.Row():
        with gr.Column():
            text_in = gr.Textbox(
                label="Enter text to analyze",
                placeholder="Type or paste text...",
                lines=3,
            )
            thr = gr.Slider(
                minimum=0.3, maximum=0.9, value=0.5, step=0.1, label="Detection Threshold"
            )
            analyze_btn = gr.Button("Analyze", variant="primary")
        with gr.Column():
            result_md = gr.Markdown(label="Result")
            gr.Markdown("### Suggested Reframings")
            with gr.Row():
                opt1 = gr.Textbox(label="Option 1", lines=2)
                opt2 = gr.Textbox(label="Option 2", lines=2)
                opt3 = gr.Textbox(label="Option 3", lines=2)

    gr.Examples(
        examples=[
            ["You speak good English for someone from there.", 0.5],
            ["Where are you really from?", 0.5],
            ["You're so articulate.", 0.5],
        ],
        inputs=[text_in, thr],
    )

    analyze_btn.click(
        fn=gradio_interface,
        inputs=[text_in, thr],
        outputs=[result_md, opt1, opt2, opt3],
        # (gradio v5) optional per-event limit:
        # concurrency_limit="default"
    )

# (gradio v5) no concurrency_count; use default_concurrency_limit if you want
demo.queue(default_concurrency_limit=2, max_size=16)
demo.launch(show_api=True)