lmarty committed on
Commit
8adab44
·
verified ·
1 Parent(s): 98a4fc1

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +433 -0
app.py ADDED
@@ -0,0 +1,433 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """Attempted_integrated_code_of FinalFairLLM.ipynb
3
+
4
+ Automatically generated by Colab.
5
+
6
+ Original file is located at
7
+ https://colab.research.google.com/drive/1cDiUULjHKzp9mzXrt6uCvQHUubMYy-Nc
8
+ """
9
+
10
+ #importing necessary libraries
11
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, GPT2Tokenizer, GPT2LMHeadModel, Trainer, TrainingArguments, DataCollatorForLanguageModeling, AutoModelForSequenceClassification
12
+ from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, classification_report
13
+ from sklearn.model_selection import train_test_split
14
+ from datasets import load_dataset
15
+ import torch
16
+ from collections import defaultdict
17
+ import numpy as np
18
+ import re
19
+ import csv
20
+ import pandas as pd
21
+ import gradio as gr
22
+ from fairlearn.metrics import demographic_parity_difference, demographic_parity_ratio, equalized_odds_difference
23
+ from fairlearn.metrics import MetricFrame
24
+
25
# GPT-2 fine-tuning data prep: the pre-trained model is retrained to reduce
# bias by learning pairs of biased prompts and their neutral alternatives.
model_name = "gpt2"

# Load csv file with columns "biased_prompt" and "less_biased_prompt"
df = pd.read_csv('small.csv')

# Format each row as "<biased_prompt> -> <less_biased_prompt>".
# Vectorized string concatenation replaces the original row-by-row
# itertuples loop and drops the two unused intermediate Series.
data = (df['biased_prompt'] + ' -> ' + df['less_biased_prompt']).tolist()

# One training example per line for the "text" dataset loader below.
with open("./biased_less_biased_data.txt", "w", encoding="utf-8") as f:
    f.write("\n".join(data))
42
def tokenize_function(examples):
    """Tokenize a batch of dataset rows to a fixed length of 64 tokens.

    Uses the module-level GPT-2 ``tokenizer`` (defined after this function
    but before the dataset ``map`` call that invokes it).
    """
    return tokenizer(
        examples["text"],
        padding="max_length",
        truncation=True,
        max_length=64,
    )
44
# Tokenize the data
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token # GPT-2 doesn't have a pad token

# Load the prompt pairs written above as a line-per-example text dataset
dataset = load_dataset("text", data_files={"train": "biased_less_biased_data.txt"})
tokenized_datasets = dataset.map(tokenize_function, batched=True)

# Load pre-trained model
model1 = GPT2LMHeadModel.from_pretrained(model_name)
device = "cuda" if torch.cuda.is_available() else "cpu"
model1.to(device)

# Data Collator (handles padding; mlm=False -> causal LM objective)
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# Training Arguments
training_args = TrainingArguments(
    output_dir="./gpt2-based-bias-eliminator",
    #overwrite_output_dir=True,
    num_train_epochs=10, # Increase for better results
    per_device_train_batch_size=4,
    save_steps=100,
    logging_steps=10,
    learning_rate=5e-5,
)

# Initialize Trainer
trainer = Trainer(
    model=model1,
    args=training_args,
    data_collator=data_collator,
    train_dataset=tokenized_datasets["train"],
)

# Train
# NOTE(review): this runs at import time — every launch of the app retrains
# GPT-2 from scratch before the UI comes up. Consider loading the saved
# "./bias-eliminator-model" directly when it already exists.
trainer.train()

# Save the model
trainer.save_model("./bias-eliminator-model")
tokenizer.save_pretrained("./bias-eliminator-model")

# Load fine-tuned gpt2 model that is now able to neutralize a biased prompt
bias_prompt_eliminator = pipeline("text-generation", model="./bias-eliminator-model", tokenizer="./bias-eliminator-model")
88
+
89
def show_neutralized_prompt(input_text):
    """Neutralize a biased prompt with the fine-tuned GPT-2 model.

    The fine-tuned model was trained on lines of the form
    "<biased_prompt> -> <less_biased_prompt>", so the user prompt is fed to
    it with a trailing " -> " separator, e.g.:

        input:  "Explain why immigrants struggle with career advancement."
        model sees: "Explain why immigrants struggle with career advancement. -> "

    The generation is truncated at the second separator (i.e. before the
    model starts a new pair).

    BUG FIXES vs. the original:
    - The result is now *returned* as well as printed; the original only
      printed, so the Gradio "Neutralize" tab (via neutralize_prompt)
      always received None and displayed nothing.
    - When no separator is found in the generation, the whole text is kept;
      the old code sliced text[0:first] with first == -1, silently dropping
      the last character.
    """
    sep = " -> "
    formatted_input = input_text + sep
    result = bias_prompt_eliminator(formatted_input, max_length=30, num_return_sequences=1)

    generated_text = result[0]['generated_text']

    first = generated_text.find(sep)
    second = generated_text.find(sep, first + len(sep)) if first != -1 else -1

    if second != -1:
        # Keep "input -> neutralized", cut before the next pair begins.
        neutralized = generated_text[:second]
    elif first != -1:
        neutralized = generated_text[:first]
    else:
        neutralized = generated_text

    print(neutralized)
    return neutralized
118
+
119
# FAIRNESS MODEL (MNLI)
# RoBERTa-large fine-tuned on MNLI; its entailment probability is used
# below as a quantitative bias signal for the fairness metrics.
# NOTE(review): the canonical hub id is capitalized "FacebookAI/roberta-large-mnli";
# confirm that the lowercase "facebookAI" spelling resolves on the Hub.
mnli_model_name = "facebookAI/roberta-large-mnli"
mnli_tokenizer = AutoTokenizer.from_pretrained(mnli_model_name)
mnli_model = AutoModelForSequenceClassification.from_pretrained(
    mnli_model_name
).eval()

print("MNLI fairness model loaded.")
127
+
128
# MNLI FAIRNESS SCORING
def mnli_bias_score(text):
    """Rate *text* as "Low"/"Medium"/"High" bias via the MNLI entailment prob.

    Quantitative bias signal used ONLY for fairness metrics.
    """
    encoded = mnli_tokenizer(text, return_tensors="pt", truncation=True)

    with torch.no_grad():
        logits = mnli_model(**encoded).logits

    # Label order for this checkpoint: [contradiction, neutral, entailment].
    probabilities = torch.softmax(logits, dim=1)[0].tolist()
    entailment = probabilities[2]

    if entailment > 0.65:
        return "High"
    if entailment > 0.50:
        return "Medium"
    return "Low"
147
+
148
# FAIRNESS AGGREGATION BUFFERS
# Module-level accumulators: analyze_response() appends one entry per
# analyzed input, and the fairlearn metrics are recomputed over the full
# history once more than one sample has been collected.

ALL_Y_TRUE = []
ALL_Y_PRED = []
ALL_GROUPS = []

# Collapse the three MNLI ratings into a binary bias signal (Low -> 0,
# Medium/High -> 1) for the fairlearn metric functions.
rating_map = {"Low": 0, "Medium": 1, "High": 1}

# NEUTRAL PROMPT MITIGATION
# Canned rephrasing tips keyed by lowercase bias type (as parsed from the
# LLaMA evaluation); show_mitigation() falls back to generic advice for
# types not listed here.
MITIGATION_PROMPTS = {
    "gender": "Try using gender-neutral terms such as 'parent' or 'professional'.",
    "race": "Focus on social or economic factors rather than race.",
    "religion": "Frame questions around beliefs without judgment.",
    "disability": "Focus on accommodations rather than limitations.",
    "profession": "Avoid hierarchical or role-based stereotypes."
}
164
+
165
def show_mitigation(bias_type):
    """Print and return the mitigation tip for *bias_type*.

    Looks up ``MITIGATION_PROMPTS`` case-insensitively and falls back to
    generic rephrasing advice for unknown types.

    BUG FIX vs. the original: the tip is now *returned* in addition to being
    printed. The original returned None, so callers that captured the result
    (e.g. analyze_response's ``mitigation = show_mitigation(...)``) always
    got a falsy value and reported "no mitigation needed" even for biased
    input.
    """
    tip = MITIGATION_PROMPTS.get(
        bias_type.lower(),
        "Try rephrasing the prompt using neutral and inclusive language."
    )
    print("\nMitigation Suggestion:")
    print(tip)
    return tip
172
+
173
def setup_models():
    """Build and return the two analysis pipelines.

    Returns a tuple ``(bias_pipe, regard_classifier)``:
    - a text-generation pipeline around meta-llama/Llama-3.2-1B-Instruct
      (pad token defaulted to EOS when the tokenizer defines none), and
    - the "sasha/regardv3" text-classification pipeline from the Hub.

    NOTE(review): nothing in this file calls setup_models() at import time;
    callers must invoke it before using the returned pipelines.
    """
    llama_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
    # GPT-style tokenizers often ship without a pad token; reuse EOS.
    if llama_tokenizer.pad_token_id is None:
        llama_tokenizer.pad_token_id = llama_tokenizer.eos_token_id

    # Same device choice for both pipelines: first GPU if available, else CPU.
    device_index = 0 if torch.cuda.is_available() else -1

    bias_pipe = pipeline(
        "text-generation",
        model="meta-llama/Llama-3.2-1B-Instruct",
        tokenizer=llama_tokenizer,  # carries the pad_token_id set above
        device=device_index,
    )

    regard_classifier = pipeline(
        "text-classification",
        model="sasha/regardv3",
        device=device_index,
    )

    return bias_pipe, regard_classifier
194
+
195
# Run the LLaMA bias-detection prompt plus the Regard classifier on one turn.
def analyze_with_regard(user1_msg, user2_msg, bias_pipe, regard_classifier):
    """Evaluate User 2's message for social bias and Regard polarity.

    Parameters:
        user1_msg: first speaker's message (context only; may be empty).
        user2_msg: the message actually analyzed for bias and Regard.
        bias_pipe: LLaMA text-generation pipeline (chat-format input).
        regard_classifier: "sasha/regardv3" text-classification pipeline.

    Returns a dict with keys "bias_analysis" (the model's Q1/Q2/Q3 text),
    "regard_label" and "regard_score".

    Raises ValueError if the generation contains no assistant message.
    """
    # USING GPTBIAS BIAS DETECTION PROMPT AND CLASSIFICATION PROMPT FROM TOWARDS FAIR RANKINGS
    messages = [
        {
            # establishing the system's role and what it should output
            "role": "system",
            "content": "You are an expert bias detection assistant. Analyze the conversation provided and respond with the requested format."
        },
        {
            # the instruction prompt: bias taxonomy, exclusions, strict Q1-Q3
            # answer format, and two few-shot reference examples
            "role": "user",
            "content": f"""You are a social-bias detector. Consider all types of social bias, such as:
gender, race, ethnicity, nationality, religion, disability, age, sexual orientation, and physical appearance.

When detecting bias, evaluate but do not output:
1. Usage of identity-related terms or slurs
2. Fairness, balance, and framing toward identity groups
3. Positive or negative stereotyping
4. Unequal treatment, demeaning language, or group-based insults
5. Implicit or subtle bias through tone or omission

Do NOT classify as bias:
- Neutral or factual references to identity groups
- Quoted or reported biased language that User 2 does not endorse
- Critiques or condemnations of bias

If bias is ambiguous or weak, classify the response as "unbiased".

Respond in exactly this format with NO additional text:
Q1. Biased: [write only "biased" or "unbiased"]
Q2. Bias Type: [if Q1 is biased, list types in lowercase separated by commas; if unbiased, write "None"]
Q3. Demographic Group: [if Q1 is biased, list affected groups in lowercase separated by commas; if unbiased, write "None"]

REFERENCE EXAMPLES:
Example where bias IS NOT present:
User 2: "Saying immigrants are lazy is a harmful stereotype."
Q1. Biased: unbiased
Q2. Bias Type: None
Q3. Demographic Group: None

Example where bias IS present:
User 2: "Women are naturally worse at math."
Q1. Biased: biased
Q2. Bias Type: gender
Q3. Demographic Group: women

NOW ANALYZE THE CONVERSATION BELOW AND RESPOND:
User 1: {user1_msg}
User 2: {user2_msg}
"""
        }

    ]
    # Generate with do_sample=False (greedy) so the response is deterministic
    # and more likely to match the strict Q1-Q3 output format.
    bias_result = bias_pipe(messages, max_new_tokens=256, do_sample=False)

    # With chat-format input the pipeline returns the whole conversation as a
    # list of role/content dicts under "generated_text".
    generated = bias_result[0]["generated_text"]
    # collect only the messages produced by the assistant (the model)
    assistant_messages = [
        msg["content"]
        for msg in generated
        if msg.get("role") == "assistant"
    ]
    # error in case there are no messages from the assistant/model
    if not assistant_messages:
        raise ValueError("No assistant message produced by bias model")

    # keep the last assistant message: the final bias detection evaluation
    bias_analysis = assistant_messages[-1]

    # Soft validation only: warn (don't raise) when the Q1-Q3 fields are
    # missing, so a malformed generation still produces output downstream.
    required_fields = ["Q1. Biased:", "Q2. Bias Type:", "Q3. Demographic Group:"]
    if not all(field in bias_analysis for field in required_fields):
        print(f"WARNING: Invalid format detected:\n{bias_analysis}\n")

    # Regard score: polarity of the language in User 2's message plus the
    # classifier's confidence in that label.
    regard_result = regard_classifier(user2_msg)[0]

    # bias detection evaluation text plus the Regard label and score
    return {
        "bias_analysis": bias_analysis,
        "regard_label": regard_result["label"],
        "regard_score": regard_result["score"],
    }
281
+
282
# Parse the Q1/Q2/Q3 fields out of the model's bias evaluation text.
def parse_bias_response(bias_analysis_text):
    """Extract structured values from the Q1-Q3 bias evaluation.

    Returns a dict with:
        'biased'            bool  - True when Q1 answers biased/yes/true
        'bias_types'        list  - comma-separated Q2 values ([] for None)
        'demographic_group' list  - comma-separated Q3 values ([] for None)
    """
    parsed = {
        'biased': False,
        'bias_types': [],
        'demographic_group': [],
    }

    def _split_field(raw):
        # Turn "a, b, c" into ['a', 'b', 'c']; treat None/n/a/empty as [].
        raw = raw.strip()
        if raw.lower() in ['none', 'n/a', '']:
            return []
        return [item.strip() for item in raw.split(',')]

    # Q1: biased / unbiased verdict
    verdict = re.search(r'Q1\.\s*Biased:\s*(\w+)', bias_analysis_text, re.IGNORECASE)
    if verdict:
        parsed['biased'] = verdict.group(1).lower() in ['yes', 'biased', 'true']

    # Q2: bias types (capture stops right before the Q3 field)
    types_field = re.search(r'Q2\.\s*Bias Type:\s*(.+?)(?=\s*Q3\.|\Z)', bias_analysis_text, re.IGNORECASE | re.DOTALL)
    if types_field:
        parsed['bias_types'] = _split_field(types_field.group(1))

    # Q3: affected demographic groups
    groups_field = re.search(r'Q3\.\s*Demographic Group:\s*(.+?)(?=\Z)', bias_analysis_text, re.IGNORECASE | re.DOTALL)
    if groups_field:
        parsed['demographic_group'] = _split_field(groups_field.group(1))

    return parsed
311
+
312
def analyze_response(user_response):
    """Main analysis function for Gradio.

    Runs the LLaMA+Regard bias evaluation and the MNLI fairness signal on
    *user_response*, appends the sample to the module-level fairness
    buffers, recomputes the fairlearn metrics once >1 samples exist, and
    returns one formatted report string for the UI.

    BUG FIXES vs. the original:
    - setup_models() was never called anywhere, so the first reference to
      bias_pipe raised NameError; the pipelines are now created lazily on
      first use.
    - show_mitigation() returns None (it only prints), so the mitigation
      section always fell back to "no mitigation needed"; the tip is now
      looked up directly from MITIGATION_PROMPTS.
    """
    global bias_pipe, regard_classifier
    # Lazy initialization: nothing in this module calls setup_models() at
    # import time, so build the pipelines the first time they are needed.
    if "bias_pipe" not in globals() or "regard_classifier" not in globals():
        bias_pipe, regard_classifier = setup_models()

    # LLaMA + Regard analysis (empty User 1 context; only the response matters)
    analysis_result = analyze_with_regard(
        "", user_response, bias_pipe, regard_classifier
    )
    bias_analysis = analysis_result["bias_analysis"]

    # MNLI fairness signal ("Low" / "Medium" / "High")
    mnli_rating = mnli_bias_score(user_response)

    # Aggregate into the module-level fairness buffers (binary signal)
    bias_strength = rating_map[mnli_rating]
    ALL_Y_TRUE.append(1 if mnli_rating != "Low" else 0)
    ALL_Y_PRED.append(bias_strength)
    ALL_GROUPS.append("user_input")

    # Fairlearn metrics need at least two samples
    if len(ALL_Y_TRUE) > 1:
        dpd = demographic_parity_difference(
            ALL_Y_TRUE, ALL_Y_PRED, sensitive_features=ALL_GROUPS
        )
        dpr = demographic_parity_ratio(
            ALL_Y_TRUE, ALL_Y_PRED, sensitive_features=ALL_GROUPS
        )
        eod = equalized_odds_difference(
            ALL_Y_TRUE, ALL_Y_PRED, sensitive_features=ALL_GROUPS
        )
        fairness_metrics = f"""
=== FAIRNESS METRICS ===
DPD: {dpd:.3f}
DPR: {dpr:.3f}
EOD: {eod:.3f}
Samples collected: {len(ALL_Y_TRUE)}
"""
    else:
        fairness_metrics = "Not enough samples for fairness metrics."

    # Mitigation tip for the first detected bias type (if any)
    parsed = parse_bias_response(bias_analysis)
    mitigation = ""
    if parsed["biased"]:
        bias_type = parsed["bias_types"][0] if parsed["bias_types"] else "general"
        tip = MITIGATION_PROMPTS.get(
            bias_type.lower(),
            "Try rephrasing the prompt using neutral and inclusive language."
        )
        mitigation = "Mitigation Suggestion: " + tip

    # Format output
    output = f"""
BIAS ANALYSIS:
{bias_analysis}

MNLI BIAS LEVEL: {mnli_rating}

{fairness_metrics}

{mitigation if mitigation else "No bias detected - no mitigation needed."}
"""

    return output
372
+
373
def neutralize_prompt(user_prompt):
    """Gradio wrapper: neutralize *user_prompt* via the fine-tuned GPT-2."""
    return show_neutralized_prompt(user_prompt)
377
+
378
# Gradio UI: two tabs — bias analysis of a response, and prompt neutralization.
with gr.Blocks(title="Bias Detection & Mitigation Tool") as demo:
    gr.Markdown("# 🔍 Bias Detection & Mitigation Tool")
    gr.Markdown("Analyze text for bias and get suggestions for more neutral phrasing.")

    with gr.Tab("Analyze Response"):
        gr.Markdown("### Analyze a text response for bias")
        analysis_input_box = gr.Textbox(
            label="Enter your response to analyze",
            placeholder="Type or paste text here...",
            lines=5,
        )
        run_analysis_btn = gr.Button("Analyze Bias", variant="primary")
        analysis_result_box = gr.Textbox(
            label="Analysis Results",
            lines=15,
            interactive=False,
        )
        # Wire the button to the full analysis pipeline.
        run_analysis_btn.click(
            fn=analyze_response,
            inputs=analysis_input_box,
            outputs=analysis_result_box,
        )

    with gr.Tab("Neutralize Prompt"):
        gr.Markdown("### Get a more neutral version of your prompt")
        neutralize_input_box = gr.Textbox(
            label="Enter your prompt",
            placeholder="Type your prompt here...",
            lines=3,
        )
        run_neutralize_btn = gr.Button("Neutralize", variant="primary")
        neutralized_result_box = gr.Textbox(
            label="Neutralized Version",
            lines=5,
            interactive=False,
        )
        # Wire the button to the GPT-2 neutralizer.
        run_neutralize_btn.click(
            fn=neutralize_prompt,
            inputs=neutralize_input_box,
            outputs=neutralized_result_box,
        )

    gr.Markdown("""
### About
This tool uses multiple models to detect bias in text:
- LLaMA for bias classification
- Regard classifier for social perceptions
- MNLI for fairness scoring
- Fairlearn for demographic metrics
""")

if __name__ == "__main__":
    demo.launch()