msmaje commited on
Commit
0c0c1de
·
verified ·
1 Parent(s): 0225ee0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +130 -72
app.py CHANGED
@@ -44,40 +44,60 @@ except ImportError:
44
  # Configuration
45
  # -----------------------------------------------------------------------------
46
  MODEL_NAME = "msmaje/phdhatamodel"
47
- SUPPORTED_LANGUAGES = ["Hausa", "Yoruba", "Igbo", "Swahili", "Amharic", "Nigerian Pidgin"]
48
  LANGUAGE_CODES = {
49
  "Hausa": "ha",
50
  "Yoruba": "yo",
51
  "Igbo": "ig",
52
- "Swahili": "sw",
53
- "Amharic": "am",
54
  "Nigerian Pidgin": "pcm"
55
  }
56
 
57
  # -----------------------------------------------------------------------------
58
  # Model Loading
59
  # -----------------------------------------------------------------------------
60
- print("Loading model and tokenizer...")
61
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
62
- model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
63
- model.eval()
64
- print("✅ Model loaded successfully!")
 
 
 
 
 
 
 
 
 
65
 
66
  # Initialize explainability tools
67
  if LIME_AVAILABLE:
68
- lime_explainer = LimeTextExplainer(class_names=["Human", "AI"])
 
 
 
 
 
69
 
70
  if SHAP_AVAILABLE:
71
- # Create a wrapper for SHAP
72
- def model_predict_proba(texts):
73
- inputs = tokenizer(texts, return_tensors="pt", truncation=True,
74
- max_length=128, padding=True)
75
- with torch.no_grad():
76
- outputs = model(**inputs)
77
- probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
78
- return probs.numpy()
79
-
80
- shap_explainer = shap.Explainer(model_predict_proba, tokenizer)
 
 
 
 
 
 
 
 
81
 
82
  # -----------------------------------------------------------------------------
83
  # Bias and Fairness Metrics
@@ -153,39 +173,50 @@ def get_shap_explanation(text, language="English"):
153
  return "⚠️ SHAP is not installed. Install with: pip install shap", None
154
 
155
  try:
156
- # Get SHAP values
157
- shap_values = shap_explainer([text])
158
 
159
- # Create visualization
160
- fig, ax = plt.subplots(figsize=(12, 6))
161
- shap.plots.text(shap_values[0], display=False)
162
- plt.tight_layout()
 
 
163
 
164
- # Extract token attributions
165
- tokens = tokenizer.tokenize(text)[:20] # Limit to first 20 tokens
166
- values = shap_values.values[0][:len(tokens), 1] # AI class
 
167
 
168
- attribution_data = {
169
- "Token": tokens,
170
- "Attribution": values.tolist()
171
- }
172
 
173
- explanation = f"## SHAP Explanation for {language}\n\n"
174
- explanation += "Tokens with **positive values** push toward AI-generated classification.\n"
175
- explanation += "Tokens with **negative values** push toward Human-written classification.\n\n"
176
- explanation += f"Top 5 most influential tokens:\n"
 
 
 
 
 
 
 
 
 
 
177
 
178
- top_indices = np.argsort(np.abs(values))[-5:][::-1]
179
  for idx in top_indices:
180
- token = tokens[idx]
181
- value = values[idx]
182
- direction = "→ AI" if value > 0 else "→ Human"
183
- explanation += f"- **{token}**: {value:.4f} {direction}\n"
184
 
185
- return explanation, (fig, attribution_data)
186
 
187
  except Exception as e:
188
- return f"❌ SHAP explanation failed: {str(e)}", None
189
 
190
  def get_lime_explanation(text, language="English"):
191
  """Generate LIME-based explanation"""
@@ -194,19 +225,27 @@ def get_lime_explanation(text, language="English"):
194
 
195
  try:
196
  def predict_fn(texts):
197
- inputs = tokenizer(texts, return_tensors="pt", truncation=True,
198
- max_length=128, padding=True)
199
- with torch.no_grad():
200
- outputs = model(**inputs)
201
- probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
202
- return probs.numpy()
 
 
 
 
 
 
 
 
203
 
204
  # Generate explanation
205
  exp = lime_explainer.explain_instance(
206
  text,
207
  predict_fn,
208
  num_features=10,
209
- num_samples=100
210
  )
211
 
212
  # Create visualization
@@ -219,7 +258,7 @@ def get_lime_explanation(text, language="English"):
219
  explanation = f"## LIME Explanation for {language}\n\n"
220
  explanation += "Features with **positive weights** indicate AI-generated characteristics.\n"
221
  explanation += "Features with **negative weights** indicate Human-written characteristics.\n\n"
222
- explanation += "Top contributing features:\n"
223
 
224
  for feature, weight in weights[:5]:
225
  direction = "→ AI" if weight > 0 else "→ Human"
@@ -228,7 +267,7 @@ def get_lime_explanation(text, language="English"):
228
  return explanation, fig
229
 
230
  except Exception as e:
231
- return f"❌ LIME explanation failed: {str(e)}", None
232
 
233
  # -----------------------------------------------------------------------------
234
  # Main Classification Function
@@ -263,29 +302,37 @@ def classify_with_explanation(text, language, explainer_type="SHAP"):
263
  else:
264
  result += "❓ **Low confidence** - Uncertain, mixed characteristics detected\n"
265
 
266
- # Probability breakdown
267
- prob_chart = {
268
  "Class": ["Human-written", "AI-generated"],
269
  "Probability": [float(probabilities[0][0]), float(probabilities[0][1])]
270
- }
271
 
272
  # Generate explanation
273
- explanation_text = None
274
  explanation_viz = None
275
 
276
  if explainer_type == "SHAP" and SHAP_AVAILABLE:
277
  explanation_text, explanation_viz = get_shap_explanation(text, language)
 
 
278
  elif explainer_type == "LIME" and LIME_AVAILABLE:
279
  explanation_text, explanation_viz = get_lime_explanation(text, language)
280
  elif explainer_type == "Both":
281
  shap_text, shap_viz = get_shap_explanation(text, language)
282
  lime_text, lime_viz = get_lime_explanation(text, language)
283
  explanation_text = shap_text + "\n\n---\n\n" + lime_text
284
- explanation_viz = (shap_viz, lime_viz) if shap_viz and lime_viz else shap_viz or lime_viz
 
 
 
 
 
 
285
  else:
286
- explanation_text = "⚠️ Selected explainer not available"
287
 
288
- return result, prob_chart, explanation_text, explanation_viz
289
 
290
  # -----------------------------------------------------------------------------
291
  # Bias Auditing Function
@@ -431,12 +478,28 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
431
  x="Class",
432
  y="Probability",
433
  title="Prediction Probabilities",
434
- y_lim=[0, 1]
 
 
435
  )
436
 
437
  with gr.Row():
438
- explanation_output = gr.Markdown(label="Explanation")
439
- explanation_viz = gr.Plot(label="Visual Explanation")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
440
 
441
  classify_btn.click(
442
  fn=classify_with_explanation,
@@ -491,7 +554,7 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
491
  - Per-language performance metrics
492
 
493
  ## 🌍 Supported Languages
494
- Hausa, Yoruba, Igbo, Swahili, Amharic, Nigerian Pidgin
495
 
496
  ## 📊 Model Performance
497
  - Accuracy: 100%
@@ -500,9 +563,11 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
500
  - AAOD: 0.0 (No bias)
501
 
502
  ## 🔬 Technical Details
503
- - Base Model: AfroXLMR-base
504
  - Parameters: ~270M
505
  - Max Sequence Length: 128 tokens
 
 
506
 
507
  ## 📚 Citation
508
  ```bibtex
@@ -524,11 +589,4 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
524
  """)
525
 
526
  if __name__ == "__main__":
527
- demo.queue(api_open=False)
528
- demo.launch(
529
- server_name="0.0.0.0",
530
- server_port=7860,
531
- show_error=True,
532
- share=True # <-- important for Spaces
533
- )
534
-
 
44
  # Configuration
45
  # -----------------------------------------------------------------------------
46
  MODEL_NAME = "msmaje/phdhatamodel"
47
+ SUPPORTED_LANGUAGES = ["Hausa", "Yoruba", "Igbo", "Nigerian Pidgin"]
48
  LANGUAGE_CODES = {
49
  "Hausa": "ha",
50
  "Yoruba": "yo",
51
  "Igbo": "ig",
 
 
52
  "Nigerian Pidgin": "pcm"
53
  }
54
 
55
  # -----------------------------------------------------------------------------
56
  # Model Loading
57
  # -----------------------------------------------------------------------------
58
+ print("📥 Loading model and tokenizer...")
59
+ try:
60
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
61
+ model = AutoModelForSequenceClassification.from_pretrained(
62
+ MODEL_NAME,
63
+ output_attentions=True # Enable attention outputs for explainability
64
+ )
65
+ model.eval()
66
+ print("✅ Model loaded successfully!")
67
+ print(f" Model: {MODEL_NAME}")
68
+ print(f" Device: {'GPU' if torch.cuda.is_available() else 'CPU'}")
69
+ except Exception as e:
70
+ print(f"❌ Error loading model: {e}")
71
+ raise
72
 
73
  # Initialize explainability tools
74
  if LIME_AVAILABLE:
75
+ try:
76
+ lime_explainer = LimeTextExplainer(class_names=["Human", "AI"])
77
+ print("✅ LIME explainer initialized")
78
+ except Exception as e:
79
+ print(f"⚠️ LIME initialization failed: {e}")
80
+ LIME_AVAILABLE = False
81
 
82
  if SHAP_AVAILABLE:
83
+ try:
84
+ # Create a wrapper for SHAP
85
+ def model_predict_proba(texts):
86
+ if isinstance(texts, str):
87
+ texts = [texts]
88
+ inputs = tokenizer(texts, return_tensors="pt", truncation=True,
89
+ max_length=128, padding=True)
90
+ with torch.no_grad():
91
+ outputs = model(**inputs)
92
+ probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
93
+ return probs.numpy()
94
+
95
+ shap_explainer = shap.Explainer(model_predict_proba, tokenizer)
96
+ print("✅ SHAP explainer initialized")
97
+ except Exception as e:
98
+ print(f"⚠️ SHAP initialization failed: {e}")
99
+ print(" Will use attention-based explanations as fallback")
100
+ SHAP_AVAILABLE = False
101
 
102
  # -----------------------------------------------------------------------------
103
  # Bias and Fairness Metrics
 
173
  return "⚠️ SHAP is not installed. Install with: pip install shap", None
174
 
175
  try:
176
+ # Simpler approach - use attention weights as proxy for SHAP
177
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=128)
178
 
179
+ with torch.no_grad():
180
+ outputs = model(**inputs, output_attentions=True)
181
+ # Get mean attention across all layers and heads
182
+ attentions = outputs.attentions
183
+ mean_attention = torch.mean(torch.stack([att.mean(dim=1) for att in attentions]), dim=0)
184
+ token_importance = mean_attention[0].sum(dim=0).numpy()
185
 
186
+ # Get tokens
187
+ tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
188
+ tokens = tokens[1:-1] # Remove [CLS] and [SEP]
189
+ token_importance = token_importance[1:-1] # Match tokens
190
 
191
+ # Normalize
192
+ token_importance = token_importance / (token_importance.max() + 1e-8)
 
 
193
 
194
+ # Create simple bar plot
195
+ fig, ax = plt.subplots(figsize=(12, 6))
196
+ colors = ['red' if x < 0 else 'green' for x in token_importance]
197
+ ax.barh(range(min(20, len(tokens))), token_importance[:20], color=colors[:20])
198
+ ax.set_yticks(range(min(20, len(tokens))))
199
+ ax.set_yticklabels(tokens[:20])
200
+ ax.set_xlabel('Importance (Attention Weight)')
201
+ ax.set_title(f'Token Importance - {language}')
202
+ ax.invert_yaxis()
203
+ plt.tight_layout()
204
+
205
+ explanation = f"## Attention-Based Explanation for {language}\n\n"
206
+ explanation += "Tokens with **higher values** are more important for classification.\n\n"
207
+ explanation += f"Top 5 most important tokens:\n"
208
 
209
+ top_indices = np.argsort(token_importance)[-5:][::-1]
210
  for idx in top_indices:
211
+ if idx < len(tokens):
212
+ token = tokens[idx]
213
+ value = token_importance[idx]
214
+ explanation += f"- **{token}**: {value:.4f}\n"
215
 
216
+ return explanation, fig
217
 
218
  except Exception as e:
219
+ return f"❌ Explanation failed: {str(e)}", None
220
 
221
  def get_lime_explanation(text, language="English"):
222
  """Generate LIME-based explanation"""
 
225
 
226
  try:
227
  def predict_fn(texts):
228
+ """Prediction function for LIME"""
229
+ if isinstance(texts, str):
230
+ texts = [texts]
231
+
232
+ results = []
233
+ for txt in texts:
234
+ inputs = tokenizer(txt, return_tensors="pt", truncation=True,
235
+ max_length=128, padding=True)
236
+ with torch.no_grad():
237
+ outputs = model(**inputs)
238
+ probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
239
+ results.append(probs[0].numpy())
240
+
241
+ return np.array(results)
242
 
243
  # Generate explanation
244
  exp = lime_explainer.explain_instance(
245
  text,
246
  predict_fn,
247
  num_features=10,
248
+ num_samples=50 # Reduced for speed
249
  )
250
 
251
  # Create visualization
 
258
  explanation = f"## LIME Explanation for {language}\n\n"
259
  explanation += "Features with **positive weights** indicate AI-generated characteristics.\n"
260
  explanation += "Features with **negative weights** indicate Human-written characteristics.\n\n"
261
+ explanation += "Top contributing features:\n\n"
262
 
263
  for feature, weight in weights[:5]:
264
  direction = "→ AI" if weight > 0 else "→ Human"
 
267
  return explanation, fig
268
 
269
  except Exception as e:
270
+ return f"❌ LIME explanation failed: {str(e)}\n\nTry using SHAP instead.", None
271
 
272
  # -----------------------------------------------------------------------------
273
  # Main Classification Function
 
302
  else:
303
  result += "❓ **Low confidence** - Uncertain, mixed characteristics detected\n"
304
 
305
+ # Probability breakdown - Create DataFrame for BarPlot
306
+ prob_data = pd.DataFrame({
307
  "Class": ["Human-written", "AI-generated"],
308
  "Probability": [float(probabilities[0][0]), float(probabilities[0][1])]
309
+ })
310
 
311
  # Generate explanation
312
+ explanation_text = ""
313
  explanation_viz = None
314
 
315
  if explainer_type == "SHAP" and SHAP_AVAILABLE:
316
  explanation_text, explanation_viz = get_shap_explanation(text, language)
317
+ if explanation_viz and isinstance(explanation_viz, tuple):
318
+ explanation_viz = explanation_viz[0] # Extract just the figure
319
  elif explainer_type == "LIME" and LIME_AVAILABLE:
320
  explanation_text, explanation_viz = get_lime_explanation(text, language)
321
  elif explainer_type == "Both":
322
  shap_text, shap_viz = get_shap_explanation(text, language)
323
  lime_text, lime_viz = get_lime_explanation(text, language)
324
  explanation_text = shap_text + "\n\n---\n\n" + lime_text
325
+ # Use SHAP visualization by default for "Both"
326
+ if shap_viz and isinstance(shap_viz, tuple):
327
+ explanation_viz = shap_viz[0]
328
+ elif isinstance(shap_viz, plt.Figure):
329
+ explanation_viz = shap_viz
330
+ else:
331
+ explanation_viz = lime_viz
332
  else:
333
+ explanation_text = "⚠️ Selected explainer not available. Please install SHAP and/or LIME."
334
 
335
+ return result, prob_data, explanation_text, explanation_viz
336
 
337
  # -----------------------------------------------------------------------------
338
  # Bias Auditing Function
 
478
  x="Class",
479
  y="Probability",
480
  title="Prediction Probabilities",
481
+ y_lim=[0, 1],
482
+ height=300,
483
+ width=400
484
  )
485
 
486
  with gr.Row():
487
+ with gr.Column():
488
+ explanation_output = gr.Markdown(label="Explanation")
489
+ with gr.Column():
490
+ explanation_viz = gr.Plot(label="Visual Explanation")
491
+
492
+ # Examples to help users
493
+ gr.Examples(
494
+ examples=[
495
+ ["Ka rubuta labari game da kasuwa a Kano", "Hausa", "SHAP"],
496
+ ["Ìwé yìí jẹ́ ìwé tó dára púpọ̀ fún àwọn akẹ́kọ̀ọ́", "Yoruba", "LIME"],
497
+ ["Akwụkwọ a dị mma maka ụmụ akwụkwọ", "Igbo", "SHAP"],
498
+ ["Dis book dey very good for students wey wan learn", "Nigerian Pidgin", "Both"]
499
+ ],
500
+ inputs=[text_input, language_select, explainer_select],
501
+ label="Try these examples in different languages"
502
+ )
503
 
504
  classify_btn.click(
505
  fn=classify_with_explanation,
 
554
  - Per-language performance metrics
555
 
556
  ## 🌍 Supported Languages
557
+ Hausa, Yoruba, Igbo, Nigerian Pidgin
558
 
559
  ## 📊 Model Performance
560
  - Accuracy: 100%
 
563
  - AAOD: 0.0 (No bias)
564
 
565
  ## 🔬 Technical Details
566
+ - Base Model: AfroXLMR-base (davlan/afro-xlmr-base)
567
  - Parameters: ~270M
568
  - Max Sequence Length: 128 tokens
569
+ - Training Dataset: PhD HATA African Dataset
570
+ - Languages: 4 West African languages
571
 
572
  ## 📚 Citation
573
  ```bibtex
 
589
  """)
590
 
591
  if __name__ == "__main__":
592
+ demo.launch()