Spaces:
Runtime error
Runtime error
Commit
·
0743d21
1
Parent(s):
21a6ab6
created function for llama and gpt strategies
Browse files
app.py
CHANGED
|
@@ -112,6 +112,31 @@ def llama_respond(tab_name, message, chat_history):
|
|
| 112 |
time.sleep(2)
|
| 113 |
return tab_name, "", chat_history
|
| 114 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
def vicuna_strategies_respond(strategy, task_name, task_ling_ent, message, chat_history):
|
| 116 |
formatted_prompt = ""
|
| 117 |
if (task_name == "POS Tagging"):
|
|
@@ -144,6 +169,38 @@ def vicuna_strategies_respond(strategy, task_name, task_ling_ent, message, chat_
|
|
| 144 |
time.sleep(2)
|
| 145 |
return task_name, "", chat_history
|
| 146 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
def interface():
|
| 148 |
|
| 149 |
# prompt = template_single.format(tab_name, textbox_prompt)
|
|
@@ -205,11 +262,13 @@ def interface():
|
|
| 205 |
llama_S1_chatbot = gr.Chatbot(label="llama-7b")
|
| 206 |
gpt_S1_chatbot = gr.Chatbot(label="gpt-3.5")
|
| 207 |
gr.Markdown("Strategy 2 Instruction-Based Prompting")
|
|
|
|
| 208 |
with gr.Row():
|
| 209 |
vicuna_S2_chatbot = gr.Chatbot(label="vicuna-7b")
|
| 210 |
llama_S2_chatbot = gr.Chatbot(label="llama-7b")
|
| 211 |
gpt_S2_chatbot = gr.Chatbot(label="gpt-3.5")
|
| 212 |
gr.Markdown("Strategy 3 Structured Prompting")
|
|
|
|
| 213 |
with gr.Row():
|
| 214 |
vicuna_S3_chatbot = gr.Chatbot(label="vicuna-7b")
|
| 215 |
llama_S3_chatbot = gr.Chatbot(label="llama-7b")
|
|
@@ -223,9 +282,9 @@ def interface():
|
|
| 223 |
# Event Handlers for Vicuna Chatbot POS/Chunk
|
| 224 |
task_btn.click(vicuna_strategies_respond, inputs=[strategy1, task, task_linguistic_entities, task_prompt, vicuna_S1_chatbot],
|
| 225 |
outputs=[task, task_prompt, vicuna_S1_chatbot])
|
| 226 |
-
task_btn.click(vicuna_strategies_respond, inputs=[
|
| 227 |
outputs=[task, task_prompt, vicuna_S2_chatbot])
|
| 228 |
-
task_btn.click(vicuna_strategies_respond, inputs=[
|
| 229 |
outputs=[task, task_prompt, vicuna_S3_chatbot])
|
| 230 |
|
| 231 |
# Event Handler for LLaMA Chatbot POS/Chunk
|
|
|
|
| 112 |
time.sleep(2)
|
| 113 |
return tab_name, "", chat_history
|
| 114 |
|
| 115 |
+
def gpt_strategies_respond(strategy, task_name, task_ling_ent, message, chat_history, max_convo_length = 10):
|
| 116 |
+
formatted_system_prompt = ""
|
| 117 |
+
if (task_name == "POS Tagging"):
|
| 118 |
+
if (strategy == "S1"):
|
| 119 |
+
formatted_system_prompt = f'''Generate the output only for the assistant. Please output any {task_ling_ent} in the following sentence one per line without any additional text: {message}'''
|
| 120 |
+
elif (strategy == "S2"):
|
| 121 |
+
formatted_system_prompt = f'''Please POS tag the following sentence using Universal POS tag set without generating any additional text: {message}'''
|
| 122 |
+
elif (strategy == "S3"):
|
| 123 |
+
formatted_system_prompt = f'''Please POS tag the following sentence using Universal POS tag set without generating any additional text: {message}'''
|
| 124 |
+
elif (task_name == "Chunking"):
|
| 125 |
+
if (strategy == "S1"):
|
| 126 |
+
formatted_system_prompt = f'''Generate the output only for the assistant. Please output any {task_ling_ent} in the following sentence one per line without any additional text: {message}'''
|
| 127 |
+
elif (strategy == "S2"):
|
| 128 |
+
formatted_system_prompt = f'''Please chunk the following sentence in CoNLL 2000 format with BIO tags without outputing any additional text: {message}'''
|
| 129 |
+
elif (strategy == "S3"):
|
| 130 |
+
formatted_system_prompt = f'''Please chunk the following sentence in CoNLL 2000 format with BIO tags without outputing any additional text: {message}'''
|
| 131 |
+
|
| 132 |
+
formatted_prompt = format_chat_prompt(message, chat_history, max_convo_length)
|
| 133 |
+
print('Prompt + Context:')
|
| 134 |
+
print(formatted_prompt)
|
| 135 |
+
bot_message = chat(system_prompt = formatted_system_prompt,
|
| 136 |
+
user_prompt = formatted_prompt)
|
| 137 |
+
chat_history.append((message, bot_message))
|
| 138 |
+
return "", chat_history
|
| 139 |
+
|
| 140 |
def vicuna_strategies_respond(strategy, task_name, task_ling_ent, message, chat_history):
|
| 141 |
formatted_prompt = ""
|
| 142 |
if (task_name == "POS Tagging"):
|
|
|
|
| 169 |
time.sleep(2)
|
| 170 |
return task_name, "", chat_history
|
| 171 |
|
| 172 |
+
def llama_strategies_respond(strategy, task_name, task_ling_ent, message, chat_history):
|
| 173 |
+
formatted_prompt = ""
|
| 174 |
+
if (task_name == "POS Tagging"):
|
| 175 |
+
if (strategy == "S1"):
|
| 176 |
+
formatted_prompt = f'''Generate the output only for the assistant. Please output any {task_ling_ent} in the following sentence one per line without any additional text: {message}'''
|
| 177 |
+
elif (strategy == "S2"):
|
| 178 |
+
formatted_prompt = f'''Please POS tag the following sentence using Universal POS tag set without generating any additional text: {message}'''
|
| 179 |
+
elif (strategy == "S3"):
|
| 180 |
+
formatted_prompt = f'''Please POS tag the following sentence using Universal POS tag set without generating any additional text: {message}'''
|
| 181 |
+
elif (task_name == "Chunking"):
|
| 182 |
+
if (strategy == "S1"):
|
| 183 |
+
formatted_prompt = f'''Generate the output only for the assistant. Please output any {task_ling_ent} in the following sentence one per line without any additional text: {message}'''
|
| 184 |
+
elif (strategy == "S2"):
|
| 185 |
+
formatted_prompt = f'''Please chunk the following sentence in CoNLL 2000 format with BIO tags without outputing any additional text: {message}'''
|
| 186 |
+
elif (strategy == "S3"):
|
| 187 |
+
formatted_prompt = f'''Please chunk the following sentence in CoNLL 2000 format with BIO tags without outputing any additional text: {message}'''
|
| 188 |
+
|
| 189 |
+
# print('Llama Strategies - Prompt + Context:')
|
| 190 |
+
# print(formatted_prompt)
|
| 191 |
+
input_ids = llama_tokenizer.encode(formatted_prompt, return_tensors="pt")
|
| 192 |
+
output_ids = llama_model.generate(input_ids, do_sample=True, max_length=1024, num_beams=5, no_repeat_ngram_size=2)
|
| 193 |
+
bot_message = llama_tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
| 194 |
+
# print(bot_message)
|
| 195 |
+
|
| 196 |
+
# Remove formatted prompt from bot_message
|
| 197 |
+
bot_message = bot_message.replace(formatted_prompt, '')
|
| 198 |
+
# print(bot_message)
|
| 199 |
+
|
| 200 |
+
chat_history.append((formatted_prompt, bot_message))
|
| 201 |
+
time.sleep(2)
|
| 202 |
+
return task_name, "", chat_history
|
| 203 |
+
|
| 204 |
def interface():
|
| 205 |
|
| 206 |
# prompt = template_single.format(tab_name, textbox_prompt)
|
|
|
|
| 262 |
llama_S1_chatbot = gr.Chatbot(label="llama-7b")
|
| 263 |
gpt_S1_chatbot = gr.Chatbot(label="gpt-3.5")
|
| 264 |
gr.Markdown("Strategy 2 Instruction-Based Prompting")
|
| 265 |
+
strategy2 = gr.Markdown("S2", visible=False)
|
| 266 |
with gr.Row():
|
| 267 |
vicuna_S2_chatbot = gr.Chatbot(label="vicuna-7b")
|
| 268 |
llama_S2_chatbot = gr.Chatbot(label="llama-7b")
|
| 269 |
gpt_S2_chatbot = gr.Chatbot(label="gpt-3.5")
|
| 270 |
gr.Markdown("Strategy 3 Structured Prompting")
|
| 271 |
+
strategy3 = gr.Markdown("S3", visible=False)
|
| 272 |
with gr.Row():
|
| 273 |
vicuna_S3_chatbot = gr.Chatbot(label="vicuna-7b")
|
| 274 |
llama_S3_chatbot = gr.Chatbot(label="llama-7b")
|
|
|
|
| 282 |
# Event Handlers for Vicuna Chatbot POS/Chunk
|
| 283 |
task_btn.click(vicuna_strategies_respond, inputs=[strategy1, task, task_linguistic_entities, task_prompt, vicuna_S1_chatbot],
|
| 284 |
outputs=[task, task_prompt, vicuna_S1_chatbot])
|
| 285 |
+
task_btn.click(vicuna_strategies_respond, inputs=[strategy2, task, task_linguistic_entities, task_prompt, vicuna_S2_chatbot],
|
| 286 |
outputs=[task, task_prompt, vicuna_S2_chatbot])
|
| 287 |
+
task_btn.click(vicuna_strategies_respond, inputs=[strategy3, task, task_linguistic_entities, task_prompt, vicuna_S3_chatbot],
|
| 288 |
outputs=[task, task_prompt, vicuna_S3_chatbot])
|
| 289 |
|
| 290 |
# Event Handler for LLaMA Chatbot POS/Chunk
|