File size: 5,566 Bytes
bacbc00
 
 
 
 
 
 
 
 
 
 
 
66fb10b
bacbc00
 
66fb10b
bacbc00
 
 
66fb10b
 
 
 
bacbc00
 
 
 
 
66fb10b
bacbc00
66fb10b
 
 
bacbc00
66fb10b
bacbc00
66fb10b
bacbc00
66fb10b
 
 
bacbc00
66fb10b
bacbc00
66fb10b
bacbc00
66fb10b
bacbc00
 
 
66fb10b
 
 
 
 
 
 
bacbc00
 
66fb10b
bacbc00
 
66fb10b
 
 
 
bacbc00
66fb10b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bacbc00
66fb10b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bacbc00
 
 
 
66fb10b
 
bacbc00
 
66fb10b
 
 
 
 
 
bacbc00
66fb10b
 
bacbc00
 
 
 
 
 
66fb10b
 
 
 
 
 
 
 
bacbc00
 
66fb10b
bacbc00
 
 
66fb10b
 
 
bacbc00
 
 
 
 
 
 
66fb10b
bacbc00
 
 
 
66fb10b
 
 
 
bacbc00
 
66fb10b
 
bacbc00
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import os
# Set process-level knobs BEFORE heavy libraries import/initialize:
# - silence the HF tokenizers fork warning,
# - cap BLAS thread pools so a small CPU host stays responsive.
# setdefault() respects values already provided by the environment.
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
os.environ.setdefault("OMP_NUM_THREADS", "2")
os.environ.setdefault("MKL_NUM_THREADS", "2")

import gradio as gr
import torch
import torch.nn.functional as F
from transformers import (
    DebertaTokenizer,
    DebertaForSequenceClassification,
    T5Tokenizer,
    T5ForConditionalGeneration,
)

# keep CPU predictable: limit intra-op and inter-op parallelism so
# inference latency doesn't fluctuate on shared/small hosts
torch.set_num_threads(2)
torch.set_num_interop_threads(1)

# Hugging Face Hub repo IDs for the two fine-tuned models
DETECT_REPO = "jokugeorgin/CI_MA_Detect"
REFRAME_REPO = "jokugeorgin/CI_MA_Reframe"


class MicroaggressionPipeline:
    """Two-stage text pipeline: detect microaggressions (DeBERTa binary
    classifier), then propose reframings (T5 seq2seq) for positives.

    Loading happens once in ``__init__`` (downloads from the HF Hub on
    first run); both models are put in eval mode and moved to GPU when
    available, CPU otherwise.
    """

    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        # ---- Load detection (DeBERTa) ----
        print("Loading detection model...")
        self.det_tok = DebertaTokenizer.from_pretrained(DETECT_REPO)
        self.det_mod = DebertaForSequenceClassification.from_pretrained(
            DETECT_REPO, num_labels=2
        ).to(self.device)
        self.det_mod.eval()

        # ---- Load reframing (T5) ----
        print("Loading reframing model...")
        self.ref_tok = T5Tokenizer.from_pretrained(REFRAME_REPO)
        self.ref_mod = T5ForConditionalGeneration.from_pretrained(
            REFRAME_REPO
        ).to(self.device)
        self.ref_mod.eval()

        # warm-up (tiny forward pass so first request is snappy)
        print("Warming up...")
        _ = self.analyze("hello", threshold=0.5, k=1)
        print("Ready!")

    @torch.no_grad()
    def detect(self, text: str, threshold: float = 0.5):
        """Classify *text*.

        Returns ``(is_micro, conf, raw_label)`` where ``is_micro`` is True
        when the positive-class probability reaches *threshold*, ``conf``
        is the probability of the returned label, and ``raw_label`` is
        ``"LABEL_0"`` / ``"LABEL_1"``.
        """
        enc = self.det_tok(
            text,
            max_length=128,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )
        enc = {k: v.to(self.device) for k, v in enc.items()}
        logits = self.det_mod(**enc).logits
        probs = F.softmax(logits, dim=1)[0]

        # BUG FIX: the old code compared the *argmax* class confidence to
        # the threshold (`bool(argmax) and conf >= threshold`). For a
        # binary softmax the argmax probability is always >= 0.5, so every
        # slider value <= 0.5 (half the UI range) had no effect, and a
        # higher threshold could never admit a borderline positive.
        # Gate directly on the positive-class probability instead, and
        # derive the reported label/confidence from that decision so the
        # three returned values are always consistent with each other.
        p_micro = float(probs[1])
        is_micro = p_micro >= threshold
        pred_idx = 1 if is_micro else 0
        conf = float(probs[pred_idx])
        return is_micro, conf, f"LABEL_{pred_idx}"

    @torch.no_grad()
    def reframe(self, text: str, k: int = 3):
        """Generate up to *k* distinct reframings of *text*.

        Beam search is capped (4 beams, <=5 return sequences) to bound
        CPU latency. NOTE(review): ``do_sample=True`` together with
        ``num_beams`` performs beam-multinomial sampling, so outputs are
        nondeterministic run to run — confirm that is intended.
        """
        pref = f"rephrase: {text}"
        enc = self.ref_tok(
            pref, return_tensors="pt", max_length=192, truncation=True
        )
        enc = {k: v.to(self.device) for k, v in enc.items()}
        out = self.ref_mod.generate(
            **enc,
            max_length=192,
            num_beams=4,
            num_return_sequences=max(1, min(k, 5)),  # clamp to [1, num_beams+1)
            no_repeat_ngram_size=2,
            do_sample=True,
            temperature=0.7,
            early_stopping=True,
        )
        # De-duplicate decoded candidates, preserving generation order.
        seen = set()
        options = []
        for seq in out:
            s = self.ref_tok.decode(seq, skip_special_tokens=True).strip()
            if s and s not in seen:
                seen.add(s)
                options.append(s)
            if len(options) >= k:
                break
        # Pad to k by repeating the last option so the UI always gets k
        # fields (only when at least one option exists).
        while len(options) < k and options:
            options.append(options[-1])
        return options[:k]

    def analyze(self, text: str, threshold: float = 0.5, k: int = 3):
        """Run detection; reframe only when flagged.

        Returns ``(is_micro, conf, raw_label, options)`` with ``options``
        empty for non-flagged text.
        """
        is_micro, conf, raw_label = self.detect(text, threshold=threshold)
        options = self.reframe(text, k=k) if is_micro else []
        return is_micro, conf, raw_label, options


PIPELINE = MicroaggressionPipeline()


def gradio_interface(text: str, threshold: float):
    """Gradio click handler.

    Runs the shared pipeline on *text* and returns four strings:
    a markdown verdict header plus exactly three reframing slots
    (blank when unused). Empty/whitespace input short-circuits with
    an error message.
    """
    cleaned = (text or "").strip()
    if not cleaned:
        return "❌ Please enter some text", "", "", ""

    is_micro, conf, raw_label, options = PIPELINE.analyze(
        cleaned, threshold=float(threshold), k=3
    )

    status = (
        "⚠️ **Microaggression Detected**"
        if is_micro
        else "✅ **No Microaggression Detected**"
    )
    header = f"{status}  \nConfidence: {conf:.1%}  \nRaw label: {raw_label}"

    # Always hand the UI exactly three option fields.
    slots = (options + ["", "", ""])[:3]
    return header, slots[0], slots[1], slots[2]


# ---- UI layout: two-column input/result row, then three reframing
# boxes, examples, and the click wiring ------------------------------
with gr.Blocks(title="Microaggression Analyzer") as demo:
    gr.Markdown("# 🔍 Microaggression Analyzer\nDetect and reframe microaggressions in text")

    with gr.Row():
        with gr.Column():
            # Left column: free-text input, decision threshold, trigger.
            text_in = gr.Textbox(
                label="Enter text to analyze",
                placeholder="Type or paste text...",
                lines=3,
            )
            # Slider range 0.3-0.9 feeds PIPELINE.analyze(threshold=...).
            thr = gr.Slider(
                minimum=0.3, maximum=0.9, value=0.5, step=0.1, label="Detection Threshold"
            )
            analyze_btn = gr.Button("Analyze", variant="primary")
        with gr.Column():
            # Right column: markdown verdict (label + confidence).
            result_md = gr.Markdown(label="Result")

    gr.Markdown("### Suggested Reframings")
    with gr.Row():
        # Three fixed slots; gradio_interface pads with "" when fewer.
        opt1 = gr.Textbox(label="Option 1", lines=2)
        opt2 = gr.Textbox(label="Option 2", lines=2)
        opt3 = gr.Textbox(label="Option 3", lines=2)

    # Clickable example inputs (text, threshold) pairs.
    gr.Examples(
        examples=[
            ["You speak good English for someone from there.", 0.5],
            ["Where are you really from?", 0.5],
            ["You're so articulate.", 0.5],
        ],
        inputs=[text_in, thr],
    )

    # Wire the button: outputs map 1:1 to gradio_interface's 4 returns.
    analyze_btn.click(
        fn=gradio_interface,
        inputs=[text_in, thr],
        outputs=[result_md, opt1, opt2, opt3],
        # (gradio v5) optional per-event limit:
        # concurrency_limit="default"
    )

# (gradio v5) queue API: concurrency_count was removed; cap concurrent
# workers and pending queue size here instead.
demo.queue(default_concurrency_limit=2, max_size=16)
demo.launch(show_api=True)