Spaces:
Runtime error
Runtime error
0.11 simplifying wo pharia
Browse files
app.py
CHANGED
|
@@ -8,13 +8,15 @@ import os
|
|
| 8 |
|
| 9 |
from threading import Thread
|
| 10 |
|
|
|
|
|
|
|
| 11 |
logging.basicConfig(level=logging.DEBUG)
|
| 12 |
|
| 13 |
HF_TOKEN = os.environ.get("HF_TOKEN", None)
|
| 14 |
login(token=HF_TOKEN)
|
| 15 |
|
| 16 |
models_available = [
|
| 17 |
-
"
|
| 18 |
"mistralai/Mistral-7B-Instruct-v0.3",
|
| 19 |
]
|
| 20 |
|
|
@@ -58,7 +60,6 @@ def load_model_a(model_id):
|
|
| 58 |
device_map="auto",
|
| 59 |
trust_remote_code=True,
|
| 60 |
).eval()
|
| 61 |
-
model_a.tie_weights()
|
| 62 |
return gr.update(label=model_id)
|
| 63 |
|
| 64 |
def load_model_b(model_id):
|
|
@@ -97,29 +98,17 @@ def generate_both(system_prompt, input_text, chatbot_a, chatbot_b, max_new_token
|
|
| 97 |
new_messages_a = system_prompt_list + chat_history_a + input_text_list
|
| 98 |
new_messages_b = system_prompt_list + chat_history_b + input_text_list
|
| 99 |
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
if "Pharia" in model_id_b:
|
| 113 |
-
logging.debug("model b is Pharia based, applying own template")
|
| 114 |
-
formatted_message_b = apply_chat_template(new_messages_a, add_generation_prompt=True)
|
| 115 |
-
logging.debug(f"***** formatted message is {formatted_message_b}")
|
| 116 |
-
input_ids_b = tokenizer_b(formatted_message_b, return_tensors="pt").input_ids.to(model_b.device)
|
| 117 |
-
else:
|
| 118 |
-
input_ids_b = tokenizer_b.apply_chat_template(
|
| 119 |
-
new_messages_b,
|
| 120 |
-
add_generation_prompt=True,
|
| 121 |
-
return_tensors="pt"
|
| 122 |
-
).to(model_b.device)
|
| 123 |
|
| 124 |
generation_kwargs_a = dict(
|
| 125 |
input_ids=input_ids_a,
|
|
|
|
| 8 |
|
| 9 |
from threading import Thread
|
| 10 |
|
| 11 |
+
# Status: Breaks during generation
|
| 12 |
+
|
| 13 |
logging.basicConfig(level=logging.DEBUG)
|
| 14 |
|
| 15 |
HF_TOKEN = os.environ.get("HF_TOKEN", None)
|
| 16 |
login(token=HF_TOKEN)
|
| 17 |
|
| 18 |
models_available = [
|
| 19 |
+
"NousResearch/Meta-Llama-3.1-8B-Instruct",
|
| 20 |
"mistralai/Mistral-7B-Instruct-v0.3",
|
| 21 |
]
|
| 22 |
|
|
|
|
| 60 |
device_map="auto",
|
| 61 |
trust_remote_code=True,
|
| 62 |
).eval()
|
|
|
|
| 63 |
return gr.update(label=model_id)
|
| 64 |
|
| 65 |
def load_model_b(model_id):
|
|
|
|
| 98 |
new_messages_a = system_prompt_list + chat_history_a + input_text_list
|
| 99 |
new_messages_b = system_prompt_list + chat_history_b + input_text_list
|
| 100 |
|
| 101 |
+
input_ids_a = tokenizer_a.apply_chat_template(
|
| 102 |
+
new_messages_a,
|
| 103 |
+
add_generation_prompt=True,
|
| 104 |
+
return_tensors="pt"
|
| 105 |
+
).to(model_a.device)
|
| 106 |
+
|
| 107 |
+
input_ids_b = tokenizer_b.apply_chat_template(
|
| 108 |
+
new_messages_b,
|
| 109 |
+
add_generation_prompt=True,
|
| 110 |
+
return_tensors="pt"
|
| 111 |
+
).to(model_b.device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
|
| 113 |
generation_kwargs_a = dict(
|
| 114 |
input_ids=input_ids_a,
|