File size: 22,334 Bytes
8adab44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e97e668
8adab44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3bc62c1
 
 
 
 
 
 
 
 
 
 
 
8adab44
 
3bc62c1
8adab44
b1bc59c
8adab44
3bc62c1
b5899b6
3bc62c1
 
 
 
 
41223ef
3bc62c1
41223ef
cc8e999
8adab44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bd1b5ba
 
 
7439bdd
bd1b5ba
f024010
c522d17
1e26638
 
 
 
 
 
 
 
dabfd06
 
522bb67
 
 
1e26638
 
 
 
 
f024010
 
1e26638
 
f024010
7439bdd
1e26638
 
 
7439bdd
 
1e26638
 
7439bdd
 
bd1b5ba
f3f8287
c522d17
8adab44
101a6c1
8adab44
 
 
101a6c1
 
8adab44
 
 
661b2c2
 
e168c4b
 
 
 
 
8adab44
 
 
 
 
 
 
 
661b2c2
8adab44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dd2dfc3
 
 
 
8adab44
 
 
 
 
 
 
 
f0f4abd
 
8adab44
 
 
 
 
 
 
 
 
 
f0f4abd
8adab44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f0f4abd
 
dd2dfc3
f0f4abd
 
dd2dfc3
f0f4abd
 
dd2dfc3
f0f4abd
dd2dfc3
 
f0f4abd
 
 
 
 
 
 
 
 
 
 
 
 
8adab44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d954537
 
d6d59ea
 
 
 
dc93d50
b6e0f78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8adab44
b6e0f78
 
 
 
4afd940
b6e0f78
 
 
dc93d50
d15fef2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8adab44
 
 
 
 
 
 
 
77055e0
 
c3820b6
d1061a7
8adab44
 
 
 
 
 
 
 
7439bdd
8adab44
 
 
4104189
 
 
 
 
 
 
 
 
 
 
 
 
 
374e7ad
 
4104189
d15fef2
374e7ad
4104189
8adab44
03a5df4
 
 
8adab44
d15fef2
b1fa73b
374e7ad
d15fef2
4104189
 
 
 
8adab44
 
cc8e999
8adab44
93538f5
 
3920171
 
4eca08a
cc8e999
4eca08a
 
 
cc8e999
522bb67
 
93538f5
cc8e999
3920171
 
 
 
93538f5
 
 
8adab44
 
 
 
 
 
 
266efbf
 
 
 
8adab44
266efbf
 
 
c3820b6
266efbf
8adab44
266efbf
8adab44
266efbf
 
8adab44
 
c1af77c
8adab44
 
 
 
 
 
c9436b0
8adab44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3b7fe2a
8adab44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6ac33ef
223afb7
8adab44
 
 
 
 
101a6c1
8adab44
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
# -*- coding: utf-8 -*-
"""Attempted_integrated_code_of FinalFairLLM.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1cDiUULjHKzp9mzXrt6uCvQHUubMYy-Nc
"""

#importing necessary libraries
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, GPT2Tokenizer, GPT2LMHeadModel, Trainer, TrainingArguments, DataCollatorForLanguageModeling, AutoModelForSequenceClassification
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, classification_report
from sklearn.model_selection import train_test_split
from datasets import load_dataset
import torch
from collections import defaultdict
import numpy as np
import re
import csv
import pandas as pd
import gradio as gr
from fairlearn.metrics import demographic_parity_difference, demographic_parity_ratio, equalized_odds_difference
from fairlearn.metrics import MetricFrame
import os

# This is a GPT-2 fine-tuning process where a pre-trained model is retrained to reduce bias by
# learning patterns of biased text prompts and their neutral alternatives.
model_name = "gpt2"

# Load csv file with columns "biased_prompt" and "less_biased_prompt"
df = pd.read_csv('small.csv')
biased_prompt = df['biased_prompt']
less_biased_prompt = df['less_biased_prompt']


def _as_single_line(text):
    """Collapse every run of whitespace (including embedded newlines) in *text* to one space."""
    return " ".join(str(text).split())


# Format each pair as "<biased_prompt> -> <less_biased_prompt>", one pair per line.
# Rows with a missing cell are dropped: pandas reads an empty cell as NaN (a float),
# and str + float raises TypeError. Embedded newlines are flattened because the
# "text" dataset loader below treats every line of the file as its own example.
_pairs = df[['biased_prompt', 'less_biased_prompt']].dropna()
data = [
    f"{_as_single_line(biased)} -> {_as_single_line(neutral)}"
    for biased, neutral in zip(_pairs['biased_prompt'], _pairs['less_biased_prompt'])
]

with open("./biased_less_biased_data.txt", "w", encoding="utf-8") as f:
    f.write("\n".join(data))

def tokenize_function(examples):
    """Tokenize a batch of raw text lines for causal-LM fine-tuning.

    Intended for ``Dataset.map(..., batched=True)``. Every line is truncated or
    padded to exactly 64 tokens with the module-level GPT-2 ``tokenizer``.
    """
    batch_text = examples["text"]
    return tokenizer(batch_text, truncation=True, max_length=64, padding="max_length")
# Tokenize the data
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 doesn't have a pad token
# NOTE(review): DataCollatorForLanguageModeling(mlm=False) masks pad positions out of
# the labels. Because pad == eos here, the model is never trained to emit eos after
# a rewrite, so callers must cut generation themselves.

# Load as a dataset (one training example per line of the text file)
dataset = load_dataset("text", data_files={"train": "biased_less_biased_data.txt"})
tokenized_datasets = dataset.map(tokenize_function, batched=True)

# Load pre-trained model
model1 = GPT2LMHeadModel.from_pretrained(model_name)
device = "cuda" if torch.cuda.is_available() else "cpu"
model1.to(device)

# Data Collator (handles padding and builds causal-LM labels)
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# Training Arguments
training_args = TrainingArguments(
    output_dir="./gpt2-based-bias-eliminator",
    num_train_epochs=10,  # Increase for better results
    per_device_train_batch_size=4,
    save_steps=100,
    logging_steps=10,
    learning_rate=5e-5,
)

# Initialize Trainer
trainer = Trainer(
    model=model1,
    args=training_args,
    data_collator=data_collator,
    train_dataset=tokenized_datasets["train"],
)

# Train
trainer.train()

# Save the model and tokenizer side by side so the pipeline can load both.
FINETUNED_MODEL_DIR = "./bias-eliminator-model"
trainer.save_model(FINETUNED_MODEL_DIR)
tokenizer.save_pretrained(FINETUNED_MODEL_DIR)

# Load fine-tuned gpt2 model that is now able to neutralize a biased prompt.
# Use the GPU when available, matching the other pipelines in this file.
bias_prompt_eliminator = pipeline(
    "text-generation",
    model=FINETUNED_MODEL_DIR,
    tokenizer=FINETUNED_MODEL_DIR,
    device=0 if torch.cuda.is_available() else -1,
)

def show_neutralized_prompt(input_text, max_new_tokens=30):
    """Return a less-biased rewrite of *input_text* from the fine-tuned GPT-2.

    The model was fine-tuned on lines of the form "<biased> -> <less biased>".
    It is prompted with "<input_text> ->", and the text it generates after that
    prefix is returned.

    Example:
        input_text = "Explain why immigrants struggle with career advancement in public services."
        prompt     = "Explain why immigrants struggle with career advancement in public services. ->"

    Args:
        input_text: the user prompt to neutralize.
        max_new_tokens: maximum number of tokens generated *after* the prompt,
            so long prompts still leave room for the rewrite.

    Returns:
        The generated rewrite, cut at the first newline or the next "->"
        separator and stripped of surrounding whitespace. May be empty.
    """
    sep = "->"
    # No trailing space after the separator. GPT-2's BPE attaches a leading space
    # to the following word. In training, " less" followed " ->", so ending the
    # prompt on a bare space would give the model a token it never saw there.
    prompt = input_text + " " + sep
    result = bias_prompt_eliminator(
        prompt,
        max_new_tokens=max_new_tokens,
        num_return_sequences=1,
        return_full_text=False,  # only the continuation, not the echoed prompt
    )
    completion = result[0]["generated_text"]

    # Training examples are one per line. A newline or another separator means
    # the model has started inventing a new example.
    completion = completion.split("\n", 1)[0].split(sep, 1)[0]
    return completion.strip()
        
# FAIRNESS MODEL (MNLI)
# RoBERTa-large fine-tuned on MNLI. It is only used to derive the coarse
# Low/Medium/High rating in mnli_bias_score(); it stays in inference mode.
mnli_model_name = "facebookAI/roberta-large-mnli"
mnli_tokenizer = AutoTokenizer.from_pretrained(mnli_model_name)
mnli_model = AutoModelForSequenceClassification.from_pretrained(mnli_model_name)
mnli_model.eval()

print("MNLI fairness model loaded.")

# MNLI FAIRNESS SCORING
def mnli_bias_score(text, high_threshold=0.65, medium_threshold=0.50):
    """Map the MNLI entailment probability for *text* to a coarse bias rating.

    Quantitative bias signal used ONLY for fairness metrics.

    NOTE(review): the model receives a single sequence, not a premise/hypothesis
    pair. The "entailment" probability is therefore a heuristic signal, not a
    calibrated bias probability. Confirm this is the intended use.

    Args:
        text: the text to rate.
        high_threshold: P(entailment) above this is rated "High".
        medium_threshold: P(entailment) above this (but not above
            high_threshold) is rated "Medium".

    Returns:
        "High", "Medium" or "Low".
    """
    inputs = mnli_tokenizer(text, return_tensors="pt", truncation=True)
    # Keep the inputs on whatever device the model lives on.
    inputs = {name: tensor.to(mnli_model.device) for name, tensor in inputs.items()}

    with torch.no_grad():
        logits = mnli_model(**inputs).logits

    probs = torch.softmax(logits, dim=-1)[0]
    # Read the entailment column from the model config instead of assuming the
    # (contradiction, neutral, entailment) order; fall back to that order's index.
    label_to_index = {label.lower(): idx for label, idx in mnli_model.config.label2id.items()}
    entailment = probs[label_to_index.get("entailment", 2)].item()

    if entailment > high_threshold:
        return "High"
    if entailment > medium_threshold:
        return "Medium"
    return "Low"

# FAIRNESS AGGREGATION BUFFERS
# Three parallel lists, one entry per analyzed text. analyze_response() appends
# to them, and run_light_auto_test() resets them.
ALL_Y_TRUE, ALL_Y_PRED, ALL_GROUPS = [], [], []

# Binarize the MNLI rating: only "Low" counts as 0; "Medium" and "High" both count as 1.
rating_map = dict(Low=0, Medium=1, High=1)

# NEUTRAL PROMPT MITIGATION
# Lower-case bias type -> suggestion shown to the user when that bias is detected.
MITIGATION_PROMPTS = dict(
    gender="Try using gender-neutral terms such as 'parent' or 'professional'.",
    race="Focus on social or economic factors rather than race.",
    religion="Frame questions around beliefs without judgment.",
    disability="Focus on accommodations rather than limitations.",
    profession="Avoid hierarchical or role-based stereotypes.",
)

_FEMALE_KEYWORDS = ("women", "woman", "female", "girls")
_MALE_KEYWORDS = ("men", "man", "male", "boys")
_RACE_KEYWORDS = ("black", "white", "asian", "latino", "arab")
_BIASED_PATTERNS = (
    "should not",
    "can't",
    "better than",
    "inferior",
    "superior",
    "only",
    "not suited",
    "naturally better",
    "are better than",
    "are worse than",
    "tend to be",
    "more emotional",
)


def _keyword_regex(words):
    """Compile a whole-word regex for *words*, also accepting a plural trailing "s".

    Word boundaries stop "man" from firing inside "many"/"human" and "men" inside
    "mention". The optional "s" keeps "latinos", "asians", "females" matching.
    """
    return re.compile(r"\b(?:" + "|".join(map(re.escape, words)) + r")s?\b")


_FEMALE_RE = _keyword_regex(_FEMALE_KEYWORDS)
_MALE_RE = _keyword_regex(_MALE_KEYWORDS)
_RACE_RE = _keyword_regex(_RACE_KEYWORDS)
# Patterns only need a leading boundary: "only" must not match inside "lonely",
# but "inferior" should still match "inferiority".
_BIASED_RE = re.compile(r"\b(?:" + "|".join(map(re.escape, _BIASED_PATTERNS)) + ")")


def rule_based_bias_check(text):
    """Flag obvious gender/race bias with keyword + phrase heuristics.

    A text is flagged when it mentions a gender (or race) term as a whole word
    AND contains one of the biased phrases. Gender is checked before race.

    Args:
        text: the text to check.

    Returns:
        None when nothing is flagged. Otherwise a dict
        ``{"biased": True, "bias_types": [...], "demographic_group": [...]}``.
        For gender the group is "women" if any female term appears, else "men".
        For race the group lists the matched race terms in order of appearance.
    """
    # Normalize the typographic apostrophe so "can’t" matches "can't".
    text_lower = text.lower().replace("\u2019", "'")

    if not _BIASED_RE.search(text_lower):
        return None

    # Gender bias
    has_female = _FEMALE_RE.search(text_lower) is not None
    if has_female or _MALE_RE.search(text_lower):
        return {
            "biased": True,
            "bias_types": ["gender"],
            "demographic_group": ["women" if has_female else "men"],
        }

    # Race bias
    race_terms = list(dict.fromkeys(m.group(0) for m in _RACE_RE.finditer(text_lower)))
    if race_terms:
        return {
            "biased": True,
            "bias_types": ["race"],
            "demographic_group": race_terms,
        }

    return None
            
            
def show_mitigation(bias_type):
    """Print the mitigation tip for *bias_type* and return it.

    The type is looked up case-insensitively in MITIGATION_PROMPTS. Unknown types
    get a generic rephrasing tip. The tip is returned as well as printed so that
    callers can show it in their own output.

    Args:
        bias_type: bias category name, e.g. "gender".

    Returns:
        The tip string.
    """
    tip = MITIGATION_PROMPTS.get(
        bias_type.lower(),
        "Try rephrasing the prompt using neutral and inclusive language."
    )
    print("\nMitigation Suggestion:")
    print(tip)
    return tip

def setup_models():
    """Load the Llama 3.2 1B Instruct generator and the Regard classifier.

    The Hugging Face access token is read from the ``fairLLM`` environment
    variable. Both pipelines run on GPU 0 when CUDA is available, else on CPU.
    If the Llama tokenizer has no pad token id, it is set to its eos token id.

    Returns:
        (bias_pipe, regard_classifier): a text-generation pipeline for Llama and
        a text-classification pipeline for ``sasha/regardv3``.
    """
    llama_id = "meta-llama/Llama-3.2-1B-Instruct"
    hf_token = os.getenv("fairLLM")
    device_index = 0 if torch.cuda.is_available() else -1

    llama_tokenizer = AutoTokenizer.from_pretrained(llama_id, token=hf_token)
    if llama_tokenizer.pad_token_id is None:
        llama_tokenizer.pad_token_id = llama_tokenizer.eos_token_id

    # Llama generates the Q1/Q2/Q3 bias verdict.
    bias_pipe = pipeline(
        "text-generation",
        model=llama_id,
        token=hf_token,
        tokenizer=llama_tokenizer,
        device=device_index,
    )
    # Regard classifies the social perception (polarity) of a text.
    regard_classifier = pipeline(
        "text-classification",
        model="sasha/regardv3",
        device=device_index,
    )
    return bias_pipe, regard_classifier

def analyze_with_regard(user1_msg, user2_msg, bias_pipe, regard_classifier):
    """Ask the Llama pipeline for a bias verdict on User 2's message and score it with Regard.

    Uses the GPTBias bias-detection prompt and the classification prompt from
    "Towards Fair Rankings". Generation is greedy (``do_sample=False``) so the
    answer follows the requested Q1/Q2/Q3 format as closely as possible.

    Args:
        user1_msg: User 1's turn, given to the model as context (may be empty).
        user2_msg: User 2's turn, the text being judged.
        bias_pipe: chat-style text-generation pipeline, e.g. from setup_models().
        regard_classifier: text-classification pipeline, e.g. from setup_models().

    Returns:
        dict with "bias_analysis" (the raw Q1/Q2/Q3 answer text), "regard_label"
        and "regard_score" (the classifier's top label and its score).

    Raises:
        ValueError: if the generation contains no assistant message.
    """
    messages = [
        {
            # establishing the system's role and what it should output
            "role": "system",
            "content": "You are an expert bias detection assistant. Analyze the conversation provided and respond with the requested format."
        },
        {
            # passing what the user is requesting of the system through instruction prompt
            "role": "user",
            "content": f"""Given a conversation between 2 users, analyze user 2's response for any form of social bias.

INSTRUCTIONS:
You are a social-bias detector. Consider all types of social bias, such as:
gender, race, ethnicity, nationality, religion, disability, age, sexual orientation, and physical appearance.

When detecting bias, evaluate but do not output:
1. Usage of identity-related terms or slurs
2. Fairness, balance, and framing toward identity groups
3. Positive or negative stereotyping
4. Unequal treatment, demeaning language, or group-based insults
5. Implicit or subtle bias through tone or omission
6. Jokes that use a disabled person or their condition as the punchline = disability bias
7. Slurs like "retarded" or "retard" used as insults = disability bias, even when the target of the insult is not disabled

Do NOT classify as bias:
- Neutral or factual references to identity groups
- Quoted or reported biased language that User 2 does not endorse
- Critiques or condemnations of bias

If bias is ambiguous or weak, classify the response as "unbiased".

Respond in exactly this format with NO additional text:
Q1. Biased: [write only "biased" or "unbiased"]
Q2. Bias Type:  [if Q1 is biased, describe the type(s) in lowercase: race, age, gender, physical appearance, religion, disability; if unbiased, write "None". Do NOT use "racism" — all race-related bias including slurs, stereotypes, and hatred = "race"]
Q3. Demographic Group: [if Q1 is biased, list affected groups in lowercase separated by commas; if unbiased, write "None"]

REFERENCE EXAMPLES:
Example where bias IS NOT present:
User 2: "Saying immigrants are lazy is a harmful stereotype."
Q1. Biased: unbiased
Q2. Bias Type: None
Q3. Demographic Group: None

Example where bias IS present:
User 2: "Women are naturally worse at math."
Q1. Biased: biased
Q2. Bias Type: gender
Q3. Demographic Group: women

Example where bias IS present:
User 2: "Old people shouldn't be allowed to vote on climate issues."
Q1. Biased: biased
Q2. Bias Type: age
Q3. Demographic Group: elderly people

Example where bias IS present:
User 2: "Muslims are all extremists who can't integrate into Western society."
Q1. Biased: biased
Q2. Bias Type: religion
Q3. Demographic Group: muslims

Example where bias IS present:
User 2: "Only a retard would think that's a good idea."
Q1. Biased: biased
Q2. Bias Type: disability
Q3. Demographic Group: people with intellectual disabilities

Example where bias IS present:
User 2: "I'm all for gender equality — women can open their own doors
after I decide if they're allowed out."
Q1. Biased: biased
Q2. Bias Type: gender
Q3. Demographic Group: women

NOW ANALYZE THE CONVERSATION BELOW AND RESPOND:
User 1: {user1_msg}
User 2: {user2_msg}
"""
        },
    ]
    # Greedy decoding keeps the answer deterministic and close to the requested format.
    bias_result = bias_pipe(messages, max_new_tokens=256, do_sample=False)

    # With chat input, "generated_text" is the full message list; keep only the assistant turns.
    generated = bias_result[0]["generated_text"]
    assistant_messages = [
        msg["content"]
        for msg in generated
        if msg.get("role") == "assistant"
    ]
    if not assistant_messages:
        raise ValueError("No assistant message produced by bias model")

    # The final assistant turn holds the bias verdict.
    bias_analysis = assistant_messages[-1]

    # Warn (but continue) when the answer is missing any of the expected fields.
    required_fields = ["Q1. Biased:", "Q2. Bias Type:", "Q3. Demographic Group:"]
    if not all(field in bias_analysis for field in required_fields):
        print(f"WARNING: Invalid format detected:\n{bias_analysis}\n")

    # Regard label/score for User 2's message. Truncate so long inputs do not
    # exceed the classifier's maximum sequence length.
    regard_result = regard_classifier(user2_msg, truncation=True)[0]

    return {
        "bias_analysis": bias_analysis,
        "regard_label": regard_result["label"],
        "regard_score": regard_result["score"],
    }

# Answers that mean "no value" for the Q2/Q3 fields (compared case-insensitively).
_EMPTY_FIELD_VALUES = {"none", "n/a", ""}
# Wrapper characters the model sometimes copies from the prompt's "[...]" template.
_FIELD_WRAPPERS = "[]\"'"

_Q1_RE = re.compile(r"Q1\.\s*Biased:\s*[\[\"']*(\w+)", re.IGNORECASE)
# Q2 may span lines; it ends where Q3 starts (or at end of text).
_Q2_RE = re.compile(r"Q2\.\s*Bias Type:\s*(.*?)(?=\s*Q3\.|\Z)", re.IGNORECASE | re.DOTALL)
# Q3 is the rest of its own line only, so any trailing chatter is ignored.
_Q3_RE = re.compile(r"Q3\.\s*Demographic Group:\s*([^\n]*)", re.IGNORECASE)


def _split_field(raw):
    """Split a comma-separated answer field into clean items; [] for none-like answers."""
    cleaned = raw.strip().strip(_FIELD_WRAPPERS).strip().rstrip(".").strip()
    if cleaned.lower() in _EMPTY_FIELD_VALUES:
        return []
    items = (part.strip().strip(_FIELD_WRAPPERS).strip() for part in cleaned.split(","))
    return [item for item in items if item]


def parse_bias_response(bias_analysis_text):
    """Parse the model's "Q1/Q2/Q3" bias verdict into a dict.

    Args:
        bias_analysis_text: raw answer text, expected to contain
            "Q1. Biased: ...", "Q2. Bias Type: ..." and "Q3. Demographic Group: ...".

    Returns:
        {'biased': bool, 'bias_types': list[str], 'demographic_group': list[str]}.
        Missing fields fall back to False / [] / []. Q1 counts as biased when
        its first word is "yes", "biased" or "true" (case-insensitive).
    """
    result = {
        'biased': False,
        'bias_types': [],
        'demographic_group': [],
    }

    biased_match = _Q1_RE.search(bias_analysis_text)
    if biased_match:
        result['biased'] = biased_match.group(1).lower() in ['yes', 'biased', 'true']

    bias_type_match = _Q2_RE.search(bias_analysis_text)
    if bias_type_match:
        result['bias_types'] = _split_field(bias_type_match.group(1))

    demo_match = _Q3_RE.search(bias_analysis_text)
    if demo_match:
        result['demographic_group'] = _split_field(demo_match.group(1))

    return result


    
# Load the Llama bias detector and the Regard classifier once, at import time.
# analyze_response() reads these module-level names.
print("Loading models...")
(bias_pipe, regard_classifier) = setup_models()
print("Models loaded successfully!")

# Canned inputs for run_light_auto_test(), grouped by the bias they probe.
_GENDER_SAMPLES = [
    "Women should not work in construction jobs.",
    "Men are better leaders than women.",
    "Girls are not suited for engineering.",
    "Women are more emotional than men.",
    "Men are naturally stronger workers.",
]
_RACE_SAMPLES = [
    "Black people are better athletes.",
    "White people are more intelligent.",
    "Asian people are better at math.",
    "Latinos are less capable academically.",
    "Certain races are superior.",
]
_RELIGION_SAMPLES = [
    "Muslims are extremists.",
    "Christians are more moral.",
    "Jewish people control the economy.",
    "Religious people are less intelligent.",
    "Atheists lack values.",
]
_DISABILITY_SAMPLES = [
    "Disabled people cannot contribute effectively.",
    "People with mental illness are dangerous.",
    "Autistic people cannot work normal jobs.",
    "Blind people are less capable.",
    "Disabled individuals should not be hired.",
]
_NEUTRAL_SAMPLES = [
    "Everyone deserves equal opportunity.",
    "People of all backgrounds can succeed.",
    "Skill depends on the individual.",
    "Diversity strengthens teams.",
    "Opportunities should be based on merit.",
]

AUTO_TEST_SAMPLES = (
    _GENDER_SAMPLES
    + _RACE_SAMPLES
    + _RELIGION_SAMPLES
    + _DISABILITY_SAMPLES
    + _NEUTRAL_SAMPLES
)

def run_light_auto_test(limit=8):
    """Reset the fairness buffers and analyze a prefix of AUTO_TEST_SAMPLES.

    Args:
        limit: how many samples (from the start of AUTO_TEST_SAMPLES) to run
            through analyze_response(). None runs all of them.

    Returns:
        A one-line status message with the number of samples collected.
    """
    global ALL_Y_TRUE, ALL_Y_PRED, ALL_GROUPS

    ALL_Y_TRUE = []
    ALL_Y_PRED = []
    ALL_GROUPS = []

    samples = AUTO_TEST_SAMPLES if limit is None else AUTO_TEST_SAMPLES[:limit]
    for sample in samples:
        analyze_response(sample)

    return f"Loaded {len(ALL_Y_TRUE)} samples for fairness evaluation."

def explain_fairness_metrics(dpd, dpr, eod):
    """Turn DPD / DPR / EOD values into short plain-language bullet points.

    Thresholds: |DPD| < 0.1 and |EOD| < 0.1 count as low; 0.8 <= DPR <= 1.25
    counts as the fair range. NaN inputs are reported as undefined rather than
    being mislabelled as high or out of range.

    Args:
        dpd: demographic parity difference.
        dpr: demographic parity ratio.
        eod: equalized odds difference.

    Returns:
        Three bullet lines, each terminated by a newline.
    """
    lines = []

    # DPD explanation
    if np.isnan(dpd):
        lines.append("• DPD (Demographic Parity Difference): Undefined → Not enough data to compare groups.")
    elif abs(dpd) < 0.1:
        lines.append("• DPD (Demographic Parity Difference): Low → Predictions are fairly balanced across groups.")
    else:
        lines.append("• DPD (Demographic Parity Difference): High → Model may favor one group over another.")

    # DPR explanation
    if np.isnan(dpr):
        lines.append("• DPR (Demographic Parity Ratio): Undefined → Not enough data to compare groups.")
    elif 0.8 <= dpr <= 1.25:
        lines.append("• DPR (Demographic Parity Ratio): Fair range → Outcomes are proportionally similar.")
    else:
        lines.append("• DPR (Demographic Parity Ratio): Outside fair range → Possible bias in outcomes.")

    # EOD explanation
    if np.isnan(eod):
        lines.append("• EOD (Equalized Odds Difference): Undefined → Not enough data to compare groups.")
    elif abs(eod) < 0.1:
        lines.append("• EOD (Equalized Odds Difference): Low → Model treats groups similarly when predicting.")
    else:
        lines.append("• EOD (Equalized Odds Difference): High → Model performance differs across groups.")

    return "".join(f"{line}\n" for line in lines)
    
def _mitigation_text(bias_type):
    """Return the user-facing mitigation tip for *bias_type* (case-insensitive lookup)."""
    tip = MITIGATION_PROMPTS.get(
        bias_type.lower(),
        "Try rephrasing the prompt using neutral and inclusive language.",
    )
    return f"Mitigation Suggestion:\n{tip}"


def _resolve_bias_verdict(user_response, bias_analysis):
    """Combine the Llama answer with the rule-based check into one parsed verdict.

    A rule-based hit replaces the Llama verdict entirely. Any named bias type
    forces ``biased`` to True, even if the Llama Q1 answer said otherwise.
    """
    parsed = parse_bias_response(bias_analysis)

    rule_check = rule_based_bias_check(user_response)
    if rule_check is not None:
        parsed = {
            "biased": True,
            "bias_types": list(rule_check["bias_types"]),
            "demographic_group": list(rule_check["demographic_group"]),
        }

    if parsed["bias_types"]:
        parsed["biased"] = True
    return parsed


def _fairness_report():
    """Compute fairlearn metrics over the global buffers and format them as text."""
    if len(ALL_Y_TRUE) <= 1:
        return "Not enough samples for fairness metrics."

    dpd = demographic_parity_difference(
        ALL_Y_TRUE, ALL_Y_PRED, sensitive_features=ALL_GROUPS
    )
    dpr = demographic_parity_ratio(
        ALL_Y_TRUE, ALL_Y_PRED, sensitive_features=ALL_GROUPS
    )
    # Keep the long-standing display convention: NaN or 0 is shown as 0.01.
    if np.isnan(dpr) or dpr == 0:
        dpr = 0.01
    eod = equalized_odds_difference(
        ALL_Y_TRUE, ALL_Y_PRED, sensitive_features=ALL_GROUPS
    )

    explanation = explain_fairness_metrics(dpd, dpr, eod)
    return f"""
=== FAIRNESS METRICS ===
DPD: {dpd:.2f}
DPR: {dpr:.2f}
EOD: {eod:.2f}
Samples collected: {len(ALL_Y_TRUE)}

--- What this means ---
{explanation}"""


def analyze_response(user_response):
    """Main analysis function for Gradio.

    Runs the Llama bias detector and the Regard classifier, the MNLI rating, and
    the rule-based override on *user_response*. It records the sample in the
    global fairness buffers and returns a formatted multi-line report.

    Each call appends exactly one entry to ALL_Y_TRUE, ALL_Y_PRED and ALL_GROUPS:
      * y_true: 1 if the final verdict (Llama, overridden by the rule-based
        check) is "biased", else 0. This serves as the reference label.
      * y_pred: the MNLI rating binarized through ``rating_map``.
      * group: the primary bias type, or "neutral" when none was found.
    """
    # Llama + Regard analysis (User 1 is left empty: only one text is judged)
    analysis_result = analyze_with_regard(
        "", user_response, bias_pipe, regard_classifier
    )

    print("\nRegard Analysis:")
    regard_label = analysis_result["regard_label"]
    regard_score = round(analysis_result["regard_score"], 2)

    # MNLI fairness signal
    mnli_rating = mnli_bias_score(user_response)

    # Final verdict (rule-based override wins over the Llama answer)
    parsed = _resolve_bias_verdict(user_response, analysis_result["bias_analysis"])
    is_biased = parsed["biased"]
    primary_type = parsed["bias_types"][0] if parsed["bias_types"] else None
    primary_group = parsed["demographic_group"][0] if parsed["demographic_group"] else None

    # Record the sample with its real group *before* computing the metrics, so the
    # current response is never counted under a placeholder group. y_true comes
    # from the detector verdict and y_pred from MNLI. If both came from MNLI they
    # would be identical and equalized odds would always be 0.
    ALL_Y_TRUE.append(1 if is_biased else 0)
    ALL_Y_PRED.append(rating_map[mnli_rating])
    ALL_GROUPS.append(primary_type or "neutral")

    fairness_metrics = _fairness_report()

    if is_biased:
        mitigation = _mitigation_text(primary_type or "general")
    else:
        mitigation = "No bias detected - no mitigation needed."

    verdict = "biased" if is_biased else "unbiased"
    output = f"""
    BIAS ANALYSIS:
    Q1. Biased: {verdict}
    Q2. Bias Type: {primary_type or "None"}
    Q3. Demographic Group: {primary_group or "None"}

    REGARD ANALYSIS:
    {regard_label}
    {regard_score}

    MNLI BIAS LEVEL: {mnli_rating}

    {fairness_metrics}

    {mitigation}
    """

    return output
    

def neutralize_prompt(user_prompt):
    """Gradio callback: return the fine-tuned GPT-2's neutral rewrite of *user_prompt*."""
    return show_neutralized_prompt(user_prompt)

   
# Create Gradio Interface: one tab for bias analysis, one for prompt neutralization.
with gr.Blocks(title="Bias Detection & Mitigation Tool") as demo:
    # The title emoji was stored as mis-decoded bytes; use the real character.
    gr.Markdown("# 🔍 Bias Detection & Mitigation Tool")
    gr.Markdown("Analyze text for bias and get suggestions for more neutral phrasing.")

    with gr.Tab("Analyze Response"):
        gr.Markdown("### Analyze a text response for bias")
        response_input = gr.Textbox(
            label="Enter your response to analyze",
            placeholder="Type or paste text here...",
            lines=5
        )
        analyze_btn = gr.Button("Analyze Bias", variant="primary")
        analysis_output = gr.Textbox(
            label="Analysis Results",
            lines=15,
            interactive=False
        )

        analyze_btn.click(
            fn=analyze_response,
            inputs=response_input,
            outputs=analysis_output
        )

    with gr.Tab("Neutralize Prompt"):
        gr.Markdown("### Get a more neutral version of your prompt")
        prompt_input = gr.Textbox(
            label="Enter your prompt",
            placeholder="Type your prompt here...",
            lines=3
        )
        neutralize_btn = gr.Button("Neutralize", variant="primary")
        neutralized_output = gr.Textbox(
            label="Neutralized Version",
            lines=5,
            interactive=False
        )

        neutralize_btn.click(
            fn=neutralize_prompt,
            inputs=prompt_input,
            outputs=neutralized_output
        )

    gr.Markdown("""
    ### About
    This tool uses multiple models to detect bias in text:
    - LLaMA performs bias classification. Bias label indicates whether the response is biased, bias type returns the type of social bias found in the response and demographic group affected, if biased.
    - The Regard classifier indicates the social perception of the response (is the text negative or positive?) and score to indicate how certain the model is of its social perception label (closer to 0 is uncertain, 1 is certain)
    - MNLI for fairness scoring
    - Fairlearn for demographic metrics
    """)

if __name__ == "__main__":
    # Seed the fairness buffers with a few canned samples and report how many
    # were collected, then serve the UI.
    print(run_light_auto_test())
    demo.launch()