Spaces:
Build error
Build error
Commit
·
0f3e715
1
Parent(s):
39fb3bc
Improve inputs
Browse files
app.py
CHANGED
|
@@ -13,13 +13,7 @@ def reduce_mean(value, mask, axis=None):
|
|
| 13 |
|
| 14 |
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
| 15 |
|
| 16 |
-
max_input_len = 256
|
| 17 |
-
max_output_len = 32
|
| 18 |
-
m = 10
|
| 19 |
-
top_p = 0.5
|
| 20 |
-
|
| 21 |
class InteractiveRainier:
|
| 22 |
-
|
| 23 |
def __init__(self):
|
| 24 |
self.tokenizer = transformers.AutoTokenizer.from_pretrained('allenai/unifiedqa-t5-large')
|
| 25 |
self.rainier_model = transformers.AutoModelForSeq2SeqLM.from_pretrained('liujch1998/rainier-large').to(device)
|
|
@@ -46,7 +40,7 @@ class InteractiveRainier:
|
|
| 46 |
choices.append(choice)
|
| 47 |
return choices
|
| 48 |
|
| 49 |
-
def run(self, question):
|
| 50 |
tokenized = self.tokenizer(question, return_tensors='pt', padding='max_length', truncation='longest_first', max_length=max_input_len).to(device) # (1, L)
|
| 51 |
knowledges_ids = self.rainier_model.generate(
|
| 52 |
input_ids=tokenized.input_ids,
|
|
@@ -107,8 +101,8 @@ class InteractiveRainier:
|
|
| 107 |
|
| 108 |
rainier = InteractiveRainier()
|
| 109 |
|
| 110 |
-
def predict(question,
|
| 111 |
-
result = rainier.run(
|
| 112 |
output = ''
|
| 113 |
output += f'QA model answer without knowledge: {result["knowless_pred"]}\n'
|
| 114 |
output += f'QA model answer with knowledge: {result["knowful_pred"]}\n'
|
|
@@ -120,13 +114,29 @@ def predict(question, choices):
|
|
| 120 |
output += f'Knowledge selected to make the prediction: {result["selected_knowledge"]}\n'
|
| 121 |
return output
|
| 122 |
|
| 123 |
-
input_question = gr.inputs.
|
| 124 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
output_text = gr.outputs.Textbox(label='Output')
|
| 126 |
|
| 127 |
gr.Interface(
|
| 128 |
fn=predict,
|
| 129 |
-
inputs=[input_question,
|
| 130 |
outputs=output_text,
|
| 131 |
title="Rainier",
|
| 132 |
).launch()
|
|
|
|
| 13 |
|
| 14 |
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
| 15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
class InteractiveRainier:
|
|
|
|
| 17 |
def __init__(self):
|
| 18 |
self.tokenizer = transformers.AutoTokenizer.from_pretrained('allenai/unifiedqa-t5-large')
|
| 19 |
self.rainier_model = transformers.AutoModelForSeq2SeqLM.from_pretrained('liujch1998/rainier-large').to(device)
|
|
|
|
| 40 |
choices.append(choice)
|
| 41 |
return choices
|
| 42 |
|
| 43 |
+
def run(self, question, max_input_len, max_output_len, m, top_p):
|
| 44 |
tokenized = self.tokenizer(question, return_tensors='pt', padding='max_length', truncation='longest_first', max_length=max_input_len).to(device) # (1, L)
|
| 45 |
knowledges_ids = self.rainier_model.generate(
|
| 46 |
input_ids=tokenized.input_ids,
|
|
|
|
| 101 |
|
| 102 |
rainier = InteractiveRainier()
|
| 103 |
|
| 104 |
+
def predict(question, kg_model, qa_model, max_input_len, max_output_len, m, top_p):
|
| 105 |
+
result = rainier.run(question, max_input_len, max_output_len, m, top_p)
|
| 106 |
output = ''
|
| 107 |
output += f'QA model answer without knowledge: {result["knowless_pred"]}\n'
|
| 108 |
output += f'QA model answer with knowledge: {result["knowful_pred"]}\n'
|
|
|
|
| 114 |
output += f'Knowledge selected to make the prediction: {result["selected_knowledge"]}\n'
|
| 115 |
return output
|
| 116 |
|
| 117 |
+
input_question = gr.inputs.Dropdown(
|
| 118 |
+
choices=[
|
| 119 |
+
'If the mass of an object gets bigger what will happen to the amount of matter contained within it? \\n (A) gets bigger (B) gets smaller',
|
| 120 |
+
'What would vinyl be an odd thing to replace? \\n (A) pants (B) record albums (C) record store (D) cheese (E) wallpaper',
|
| 121 |
+
'Some pelycosaurs gave rise to reptile ancestral to \\n (A) lamphreys (B) angiosperm (C) mammals (D) paramecium (E) animals (F) protozoa (G) arachnids (H) backbones',
|
| 122 |
+
'Sydney rubbed Addison’s head because she had a horrible headache. What will happen to Sydney? \\n (A) drift to sleep (B) receive thanks (C) be reprimanded',
|
| 123 |
+
'Adam always spent all of the free time watching Tv unlike Hunter who volunteered, due to _ being lazy. \\n (A) Adam (B) Hunter',
|
| 124 |
+
'Causes bad breath and frightens blood-suckers \\n (A) tuna (B) iron (C) trash (D) garlic (E) pubs',
|
| 125 |
+
],
|
| 126 |
+
label='Question:',
|
| 127 |
+
info='A multiple-choice commonsense question. Please follow the UnifiedQA input format: "{question} \\n (A) ... (B) ... (C) ..."',
|
| 128 |
+
)
|
| 129 |
+
input_kg_model = gr.inputs.Textbox(label='Knowledge generation model:', value='liujch1998/rainier-large', interactive=False)
|
| 130 |
+
input_qa_model = gr.inputs.Textbox(label='QA model:', value='allenai/unifiedqa-t5-large', interactive=False)
|
| 131 |
+
input_max_input_len = gr.inputs.Number(label='Max question length:', value=256, precision=0)
|
| 132 |
+
input_max_output_len = gr.inputs.Number(label='Max knowledge length:', value=32, precision=0)
|
| 133 |
+
input_m = gr.inputs.Slider(label='Number of generated knowledges:', value=10, mininum=1, maximum=20, step=1)
|
| 134 |
+
input_top_p = gr.inputs.Slider(label='Top_p for knowledge generation:', value=0.5, mininum=0.0, maximum=1.0, step=0.05)
|
| 135 |
output_text = gr.outputs.Textbox(label='Output')
|
| 136 |
|
| 137 |
gr.Interface(
|
| 138 |
fn=predict,
|
| 139 |
+
inputs=[input_question, input_kg_model, input_qa_model, input_max_input_len, input_max_output_len, input_m, input_top_p],
|
| 140 |
outputs=output_text,
|
| 141 |
title="Rainier",
|
| 142 |
).launch()
|