mariamSoub commited on
Commit
3bc62c1
·
verified ·
1 Parent(s): 6ac33ef

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -15
app.py CHANGED
@@ -88,28 +88,35 @@ tokenizer.save_pretrained("./bias-eliminator-model")
88
  bias_prompt_eliminator = pipeline("text-generation", model="./bias-eliminator-model", tokenizer="./bias-eliminator-model")
89
 
90
  def show_neutralized_prompt(input_text):
 
 
 
 
 
 
 
 
 
 
 
 
91
  sep = " -> "
92
  input_text_format = input_text + sep
93
-
94
- result = bias_prompt_eliminator(
95
- input_text_format,
96
- max_length=60,
97
- do_sample=False
98
- )
99
 
100
  generated_text = result[0]['generated_text']
101
 
102
- if sep in generated_text:
103
- output = generated_text.split(sep)[-1].strip()
104
 
 
 
 
 
 
 
 
 
105
 
106
- if len(output.split()) < 3:
107
- return "Rewrite using neutral, inclusive language."
108
-
109
- return output.capitalize()
110
-
111
- return "Could not generate neutral version."
112
-
113
  # FAIRNESS MODEL (MNLI)
114
  mnli_model_name = "facebookAI/roberta-large-mnli"
115
  mnli_tokenizer = AutoTokenizer.from_pretrained(mnli_model_name)
 
88
  bias_prompt_eliminator = pipeline("text-generation", model="./bias-eliminator-model", tokenizer="./bias-eliminator-model")
89
 
90
  def show_neutralized_prompt(input_text):
91
+ # input into retrained gpt2 model requires the format:
92
+ # "<input_text><text sep>"
93
+ #
94
+ # Where: <input_text> is the user prompt
95
+ # <text sep> is the string " -> "
96
+ #
97
+ # Example:
98
+ #
99
+ # <input text> = "Explain why immigrants struggle with career advancement in public services."
100
+ # Input format to model is:
101
+ # <input_text><text sep> = "Explain why immigrants struggle with career advancement in public services. ->"
102
+
103
  sep = " -> "
104
  input_text_format = input_text + sep
105
+ result = bias_prompt_eliminator(input_text_format, max_length=30, num_return_sequences=1)
 
 
 
 
 
106
 
107
  generated_text = result[0]['generated_text']
108
 
109
+ first = generated_text.find(sep)
 
110
 
111
+ if first != -1:
112
+ second = generated_text.find(sep, first +len(sep))
113
+ else:
114
+ second = -1
115
+ if second != -1:
116
+ print(generated_text[0:second])
117
+ else:
118
+ print(generated_text[0:first])
119
 
 
 
 
 
 
 
 
120
  # FAIRNESS MODEL (MNLI)
121
  mnli_model_name = "facebookAI/roberta-large-mnli"
122
  mnli_tokenizer = AutoTokenizer.from_pretrained(mnli_model_name)