jokugeorgin commited on
Commit
66fb10b
·
verified ·
1 Parent(s): beb2921

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +99 -65
app.py CHANGED
@@ -10,111 +10,142 @@ from transformers import (
10
  DebertaTokenizer,
11
  DebertaForSequenceClassification,
12
  T5Tokenizer,
13
- T5ForConditionalGeneration
14
  )
15
 
 
16
  torch.set_num_threads(2)
17
  torch.set_num_interop_threads(1)
18
 
 
 
 
 
19
  class MicroaggressionPipeline:
20
  def __init__(self):
21
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22
  print(f"Using device: {self.device}")
23
 
 
24
  print("Loading detection model...")
25
- self.detection_tokenizer = DebertaTokenizer.from_pretrained("jokugeorgin/CI_MA_Detect")
26
- self.detection_model = DebertaForSequenceClassification.from_pretrained(
27
- "jokugeorgin/CI_MA_Detect", num_labels=2
28
  ).to(self.device)
29
- self.detection_model.eval()
30
 
 
31
  print("Loading reframing model...")
32
- self.reframing_tokenizer = T5Tokenizer.from_pretrained("jokugeorgin/CI_MA_Reframe")
33
- self.reframing_model = T5ForConditionalGeneration.from_pretrained(
34
- "jokugeorgin/CI_MA_Reframe"
35
  ).to(self.device)
36
- self.reframing_model.eval()
37
 
 
38
  print("Warming up...")
39
- _ = self.analyze("hello", threshold=0.5)
40
  print("Ready!")
41
 
42
  @torch.no_grad()
43
- def analyze(self, text, threshold=0.5, k=3):
44
- enc = self.detection_tokenizer(
45
- text, max_length=128, truncation=True, padding="max_length", return_tensors="pt"
 
 
 
 
46
  )
47
  enc = {k: v.to(self.device) for k, v in enc.items()}
48
- logits = self.detection_model(**enc).logits
49
  probs = F.softmax(logits, dim=1)[0]
50
  pred_idx = int(torch.argmax(logits, dim=1))
51
- confidence = float(probs[pred_idx])
52
- is_micro = bool(pred_idx) and (confidence >= threshold)
 
 
53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  options = []
55
- if is_micro:
56
- prefixed = f"rephrase: {text}"
57
- genc = self.reframing_tokenizer(
58
- prefixed, return_tensors="pt", max_length=192, truncation=True
59
- )
60
- genc = {k: v.to(self.device) for k, v in genc.items()}
61
- out = self.reframing_model.generate(
62
- **genc,
63
- max_length=192,
64
- num_beams=4,
65
- num_return_sequences=max(1, min(k, 5)),
66
- no_repeat_ngram_size=2,
67
- do_sample=True,
68
- temperature=0.7,
69
- early_stopping=True,
70
- )
71
- seen = set()
72
- for seq in out:
73
- s = self.reframing_tokenizer.decode(seq, skip_special_tokens=True).strip()
74
- if s and s not in seen:
75
- seen.add(s)
76
- options.append(s)
77
- if len(options) >= k:
78
- break
79
- while len(options) < k and options:
80
- options.append(options[-1])
81
-
82
- return is_micro, confidence, options[:k]
83
-
84
- pipeline = MicroaggressionPipeline()
85
-
86
- def gradio_interface(text, threshold):
87
  text = (text or "").strip()
88
  if not text:
89
  return "❌ Please enter some text", "", "", ""
90
 
91
- is_micro, confidence, options = pipeline.analyze(text, threshold=threshold, k=3)
92
-
93
- result = (
94
- f"⚠️ **Microaggression Detected**\n\nConfidence: {confidence:.1%}"
95
- if is_micro else
96
- f"✅ **No Microaggression Detected**\n\nConfidence: {confidence:.1%}"
97
  )
98
 
 
 
 
 
 
 
99
  opts = (options + ["", "", ""])[:3]
100
- return result, opts[0], opts[1], opts[2]
 
101
 
102
  with gr.Blocks(title="Microaggression Analyzer") as demo:
103
  gr.Markdown("# 🔍 Microaggression Analyzer\nDetect and reframe microaggressions in text")
104
 
105
  with gr.Row():
106
  with gr.Column():
107
- text_input = gr.Textbox(label="Enter text to analyze", placeholder="Type or paste text...", lines=3)
108
- threshold = gr.Slider(minimum=0.3, maximum=0.9, value=0.5, step=0.1, label="Detection Threshold")
 
 
 
 
 
 
109
  analyze_btn = gr.Button("Analyze", variant="primary")
110
  with gr.Column():
111
- result_output = gr.Markdown(label="Result")
112
 
113
  gr.Markdown("### Suggested Reframings")
114
  with gr.Row():
115
- option1 = gr.Textbox(label="Option 1", lines=2)
116
- option2 = gr.Textbox(label="Option 2", lines=2)
117
- option3 = gr.Textbox(label="Option 3", lines=2)
118
 
119
  gr.Examples(
120
  examples=[
@@ -122,14 +153,17 @@ with gr.Blocks(title="Microaggression Analyzer") as demo:
122
  ["Where are you really from?", 0.5],
123
  ["You're so articulate.", 0.5],
124
  ],
125
- inputs=[text_input, threshold],
126
  )
127
 
128
  analyze_btn.click(
129
  fn=gradio_interface,
130
- inputs=[text_input, threshold],
131
- outputs=[result_output, option1, option2, option3],
 
 
132
  )
133
 
134
- demo.queue(concurrency_count=2, max_size=16)
 
135
  demo.launch(show_api=True)
 
10
  DebertaTokenizer,
11
  DebertaForSequenceClassification,
12
  T5Tokenizer,
13
+ T5ForConditionalGeneration,
14
  )
15
 
16
+ # keep CPU predictable
17
  torch.set_num_threads(2)
18
  torch.set_num_interop_threads(1)
19
 
20
+ DETECT_REPO = "jokugeorgin/CI_MA_Detect"
21
+ REFRAME_REPO = "jokugeorgin/CI_MA_Reframe"
22
+
23
+
24
  class MicroaggressionPipeline:
25
  def __init__(self):
26
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27
  print(f"Using device: {self.device}")
28
 
29
+ # ---- Load detection (DeBERTa) ----
30
  print("Loading detection model...")
31
+ self.det_tok = DebertaTokenizer.from_pretrained(DETECT_REPO)
32
+ self.det_mod = DebertaForSequenceClassification.from_pretrained(
33
+ DETECT_REPO, num_labels=2
34
  ).to(self.device)
35
+ self.det_mod.eval()
36
 
37
+ # ---- Load reframing (T5) ----
38
  print("Loading reframing model...")
39
+ self.ref_tok = T5Tokenizer.from_pretrained(REFRAME_REPO)
40
+ self.ref_mod = T5ForConditionalGeneration.from_pretrained(
41
+ REFRAME_REPO
42
  ).to(self.device)
43
+ self.ref_mod.eval()
44
 
45
+ # warm-up (tiny forward pass so first request is snappy)
46
  print("Warming up...")
47
+ _ = self.analyze("hello", threshold=0.5, k=1)
48
  print("Ready!")
49
 
50
  @torch.no_grad()
51
+ def detect(self, text: str, threshold: float = 0.5):
52
+ enc = self.det_tok(
53
+ text,
54
+ max_length=128,
55
+ truncation=True,
56
+ padding="max_length",
57
+ return_tensors="pt",
58
  )
59
  enc = {k: v.to(self.device) for k, v in enc.items()}
60
+ logits = self.det_mod(**enc).logits
61
  probs = F.softmax(logits, dim=1)[0]
62
  pred_idx = int(torch.argmax(logits, dim=1))
63
+ conf = float(probs[pred_idx])
64
+
65
+ is_micro = bool(pred_idx) and (conf >= threshold)
66
+ return is_micro, conf, f"LABEL_{pred_idx}"
67
 
68
+ @torch.no_grad()
69
+ def reframe(self, text: str, k: int = 3):
70
+ # capped for latency on CPU
71
+ pref = f"rephrase: {text}"
72
+ enc = self.ref_tok(
73
+ pref, return_tensors="pt", max_length=192, truncation=True
74
+ )
75
+ enc = {k: v.to(self.device) for k, v in enc.items()}
76
+ out = self.ref_mod.generate(
77
+ **enc,
78
+ max_length=192,
79
+ num_beams=4,
80
+ num_return_sequences=max(1, min(k, 5)),
81
+ no_repeat_ngram_size=2,
82
+ do_sample=True,
83
+ temperature=0.7,
84
+ early_stopping=True,
85
+ )
86
+ seen = set()
87
  options = []
88
+ for seq in out:
89
+ s = self.ref_tok.decode(seq, skip_special_tokens=True).strip()
90
+ if s and s not in seen:
91
+ seen.add(s)
92
+ options.append(s)
93
+ if len(options) >= k:
94
+ break
95
+ while len(options) < k and options:
96
+ options.append(options[-1])
97
+ return options[:k]
98
+
99
+ def analyze(self, text: str, threshold: float = 0.5, k: int = 3):
100
+ is_micro, conf, raw_label = self.detect(text, threshold=threshold)
101
+ options = self.reframe(text, k=k) if is_micro else []
102
+ return is_micro, conf, raw_label, options
103
+
104
+
105
+ PIPELINE = MicroaggressionPipeline()
106
+
107
+
108
+ def gradio_interface(text: str, threshold: float):
 
 
 
 
 
 
 
 
 
 
 
109
  text = (text or "").strip()
110
  if not text:
111
  return "❌ Please enter some text", "", "", ""
112
 
113
+ is_micro, conf, raw_label, options = PIPELINE.analyze(
114
+ text, threshold=float(threshold), k=3
 
 
 
 
115
  )
116
 
117
+ if is_micro:
118
+ header = f"⚠️ **Microaggression Detected** \nConfidence: {conf:.1%} \nRaw label: {raw_label}"
119
+ else:
120
+ header = f"✅ **No Microaggression Detected** \nConfidence: {conf:.1%} \nRaw label: {raw_label}"
121
+
122
+ # pad to 3 fields for the UI
123
  opts = (options + ["", "", ""])[:3]
124
+ return header, opts[0], opts[1], opts[2]
125
+
126
 
127
  with gr.Blocks(title="Microaggression Analyzer") as demo:
128
  gr.Markdown("# 🔍 Microaggression Analyzer\nDetect and reframe microaggressions in text")
129
 
130
  with gr.Row():
131
  with gr.Column():
132
+ text_in = gr.Textbox(
133
+ label="Enter text to analyze",
134
+ placeholder="Type or paste text...",
135
+ lines=3,
136
+ )
137
+ thr = gr.Slider(
138
+ minimum=0.3, maximum=0.9, value=0.5, step=0.1, label="Detection Threshold"
139
+ )
140
  analyze_btn = gr.Button("Analyze", variant="primary")
141
  with gr.Column():
142
+ result_md = gr.Markdown(label="Result")
143
 
144
  gr.Markdown("### Suggested Reframings")
145
  with gr.Row():
146
+ opt1 = gr.Textbox(label="Option 1", lines=2)
147
+ opt2 = gr.Textbox(label="Option 2", lines=2)
148
+ opt3 = gr.Textbox(label="Option 3", lines=2)
149
 
150
  gr.Examples(
151
  examples=[
 
153
  ["Where are you really from?", 0.5],
154
  ["You're so articulate.", 0.5],
155
  ],
156
+ inputs=[text_in, thr],
157
  )
158
 
159
  analyze_btn.click(
160
  fn=gradio_interface,
161
+ inputs=[text_in, thr],
162
+ outputs=[result_md, opt1, opt2, opt3],
163
+ # (gradio v5) optional per-event limit:
164
+ # concurrency_limit="default"
165
  )
166
 
167
+ # (gradio v5) no concurrency_count; use default_concurrency_limit if you want
168
+ demo.queue(default_concurrency_limit=2, max_size=16)
169
  demo.launch(show_api=True)