Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -12,7 +12,6 @@ MODEL_LIST = ["mistralai/Mistral-Nemo-Instruct-2407"]
|
|
| 12 |
HF_TOKEN = os.environ.get("HF_TOKEN", None)
|
| 13 |
login(token=HF_TOKEN)
|
| 14 |
|
| 15 |
-
# MODEL = os.environ.get("MODEL_ID")
|
| 16 |
MODEL = "mistralai/Mistral-Nemo-Instruct-2407"
|
| 17 |
|
| 18 |
TITLE = "<h1><center>Mistral-Nemo</center></h1>"
|
|
@@ -46,13 +45,13 @@ footer{visibility: hidden}
|
|
| 46 |
|
| 47 |
device = "cuda" # or "cpu"
|
| 48 |
|
| 49 |
-
#
|
| 50 |
tokenizer = AutoTokenizer.from_pretrained(MODEL, fix_mistral_regex=True)
|
| 51 |
model = AutoModelForCausalLM.from_pretrained(
|
| 52 |
MODEL,
|
| 53 |
dtype=torch.bfloat16, # torch_dtype is deprecated in newer transformers
|
| 54 |
device_map="auto",
|
| 55 |
-
ignore_mismatched_sizes=True
|
| 56 |
)
|
| 57 |
|
| 58 |
|
|
@@ -68,13 +67,15 @@ def _system_prompt_for(name: str) -> str:
|
|
| 68 |
|
| 69 |
@spaces.GPU()
|
| 70 |
def get_response(conversation):
|
|
|
|
|
|
|
|
|
|
| 71 |
temperature = 0.3
|
| 72 |
max_new_tokens = 512
|
| 73 |
top_p = 1.0
|
| 74 |
top_k = 20
|
| 75 |
penalty = 1.2
|
| 76 |
|
| 77 |
-
# conversation is already in messages format [{'role', 'content'}, ...]
|
| 78 |
input_text = tokenizer.apply_chat_template(conversation, tokenize=False)
|
| 79 |
inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
|
| 80 |
streamer = TextIteratorStreamer(
|
|
@@ -110,21 +111,21 @@ def get_response(conversation):
|
|
| 110 |
@spaces.GPU()
|
| 111 |
def stream_chat(history, character_a, character_b):
|
| 112 |
"""
|
| 113 |
-
history: list of messages
|
| 114 |
[{"role": "user" | "assistant", "content": "..."}, ...]
|
| 115 |
|
| 116 |
In the UI:
|
| 117 |
-
-
|
| 118 |
-
-
|
| 119 |
|
| 120 |
-
Each
|
| 121 |
1. B says something new (as 'user')
|
| 122 |
2. A replies (as 'assistant')
|
| 123 |
"""
|
| 124 |
if history is None:
|
| 125 |
history = []
|
| 126 |
|
| 127 |
-
# ---------- B speaks ----------
|
| 128 |
if len(history) == 0:
|
| 129 |
# First turn: B introduces themselves to A
|
| 130 |
b_user_prompt = (
|
|
@@ -132,7 +133,7 @@ def stream_chat(history, character_a, character_b):
|
|
| 132 |
"Introduce yourself and start the conversation."
|
| 133 |
)
|
| 134 |
else:
|
| 135 |
-
#
|
| 136 |
last_msg = history[-1]
|
| 137 |
last_text = last_msg["content"]
|
| 138 |
b_user_prompt = (
|
|
@@ -148,7 +149,7 @@ def stream_chat(history, character_a, character_b):
|
|
| 148 |
response_b = get_response(conv_for_b)
|
| 149 |
print("response_b:", response_b)
|
| 150 |
|
| 151 |
-
# ---------- A speaks ----------
|
| 152 |
conv_for_a = [
|
| 153 |
{"role": "system", "content": _system_prompt_for(character_a)},
|
| 154 |
*history,
|
|
@@ -157,10 +158,10 @@ def stream_chat(history, character_a, character_b):
|
|
| 157 |
response_a = get_response(conv_for_a)
|
| 158 |
print("response_a:", response_a)
|
| 159 |
|
| 160 |
-
# ---------- Append to chat history
|
| 161 |
new_history = history + [
|
| 162 |
-
{"role": "user", "content": response_b}, # B
|
| 163 |
-
{"role": "assistant", "content": response_a}, # A
|
| 164 |
]
|
| 165 |
print("history:", new_history)
|
| 166 |
|
|
@@ -173,14 +174,12 @@ def get_img(keyword):
|
|
| 173 |
bing_crawler = BingImageCrawler(storage={"root_dir": path})
|
| 174 |
bing_crawler.crawl(keyword=keyword, max_num=1)
|
| 175 |
|
| 176 |
-
# Look for image files in the folder
|
| 177 |
for file_name in os.listdir(path):
|
| 178 |
if file_name.lower().endswith(
|
| 179 |
(".png", ".jpg", ".jpeg", ".gif", ".bmp", ".tiff")
|
| 180 |
):
|
| 181 |
return os.path.join(path, file_name)
|
| 182 |
|
| 183 |
-
# If no image is found
|
| 184 |
return None
|
| 185 |
|
| 186 |
|
|
@@ -188,7 +187,7 @@ def set_characters(a, b):
|
|
| 188 |
img_a = get_img(a)
|
| 189 |
img_b = get_img(b)
|
| 190 |
# avatar_images=(user_avatar, assistant_avatar) => (B, A)
|
| 191 |
-
#
|
| 192 |
return img_a, img_b, gr.update(avatar_images=(img_b, img_a), value=[])
|
| 193 |
|
| 194 |
|
|
@@ -234,8 +233,8 @@ with gr.Blocks() as demo:
|
|
| 234 |
gr.Markdown(" ")
|
| 235 |
image_b = gr.Image(show_label=False, interactive=False)
|
| 236 |
|
| 237 |
-
# IMPORTANT: type=
|
| 238 |
-
chat = gr.Chatbot(show_label=False
|
| 239 |
submit_button = gr.Button("Start Conversation")
|
| 240 |
|
| 241 |
character_button.click(
|
|
|
|
| 12 |
HF_TOKEN = os.environ.get("HF_TOKEN", None)
|
| 13 |
login(token=HF_TOKEN)
|
| 14 |
|
|
|
|
| 15 |
MODEL = "mistralai/Mistral-Nemo-Instruct-2407"
|
| 16 |
|
| 17 |
TITLE = "<h1><center>Mistral-Nemo</center></h1>"
|
|
|
|
| 45 |
|
| 46 |
device = "cuda" # or "cpu"
|
| 47 |
|
| 48 |
+
# fix_mistral_regex=True is the recommended flag when loading this tokenizer
|
| 49 |
tokenizer = AutoTokenizer.from_pretrained(MODEL, fix_mistral_regex=True)
|
| 50 |
model = AutoModelForCausalLM.from_pretrained(
|
| 51 |
MODEL,
|
| 52 |
dtype=torch.bfloat16, # torch_dtype is deprecated in newer transformers
|
| 53 |
device_map="auto",
|
| 54 |
+
ignore_mismatched_sizes=True,
|
| 55 |
)
|
| 56 |
|
| 57 |
|
|
|
|
| 67 |
|
| 68 |
@spaces.GPU()
|
| 69 |
def get_response(conversation):
|
| 70 |
+
"""
|
| 71 |
+
conversation: list of {"role": "system" | "user" | "assistant", "content": str}
|
| 72 |
+
"""
|
| 73 |
temperature = 0.3
|
| 74 |
max_new_tokens = 512
|
| 75 |
top_p = 1.0
|
| 76 |
top_k = 20
|
| 77 |
penalty = 1.2
|
| 78 |
|
|
|
|
| 79 |
input_text = tokenizer.apply_chat_template(conversation, tokenize=False)
|
| 80 |
inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
|
| 81 |
streamer = TextIteratorStreamer(
|
|
|
|
| 111 |
@spaces.GPU()
|
| 112 |
def stream_chat(history, character_a, character_b):
|
| 113 |
"""
|
| 114 |
+
history: chat history in messages format, i.e. a list of dicts:
|
| 115 |
[{"role": "user" | "assistant", "content": "..."}, ...]
|
| 116 |
|
| 117 |
In the UI:
|
| 118 |
+
- user messages = Character B
|
| 119 |
+
- assistant messages = Character A
|
| 120 |
|
| 121 |
+
Each click:
|
| 122 |
1. B says something new (as 'user')
|
| 123 |
2. A replies (as 'assistant')
|
| 124 |
"""
|
| 125 |
if history is None:
|
| 126 |
history = []
|
| 127 |
|
| 128 |
+
# ---------- B speaks (user side) ----------
|
| 129 |
if len(history) == 0:
|
| 130 |
# First turn: B introduces themselves to A
|
| 131 |
b_user_prompt = (
|
|
|
|
| 133 |
"Introduce yourself and start the conversation."
|
| 134 |
)
|
| 135 |
else:
|
| 136 |
+
# Take the last history entry, which is A's latest reply because turns are appended as (B, A) pairs; B responds to it
|
| 137 |
last_msg = history[-1]
|
| 138 |
last_text = last_msg["content"]
|
| 139 |
b_user_prompt = (
|
|
|
|
| 149 |
response_b = get_response(conv_for_b)
|
| 150 |
print("response_b:", response_b)
|
| 151 |
|
| 152 |
+
# ---------- A speaks (assistant side) ----------
|
| 153 |
conv_for_a = [
|
| 154 |
{"role": "system", "content": _system_prompt_for(character_a)},
|
| 155 |
*history,
|
|
|
|
| 158 |
response_a = get_response(conv_for_a)
|
| 159 |
print("response_a:", response_a)
|
| 160 |
|
| 161 |
+
# ---------- Append to chat history ----------
|
| 162 |
new_history = history + [
|
| 163 |
+
{"role": "user", "content": response_b}, # B's line
|
| 164 |
+
{"role": "assistant", "content": response_a}, # A's line
|
| 165 |
]
|
| 166 |
print("history:", new_history)
|
| 167 |
|
|
|
|
| 174 |
bing_crawler = BingImageCrawler(storage={"root_dir": path})
|
| 175 |
bing_crawler.crawl(keyword=keyword, max_num=1)
|
| 176 |
|
|
|
|
| 177 |
for file_name in os.listdir(path):
|
| 178 |
if file_name.lower().endswith(
|
| 179 |
(".png", ".jpg", ".jpeg", ".gif", ".bmp", ".tiff")
|
| 180 |
):
|
| 181 |
return os.path.join(path, file_name)
|
| 182 |
|
|
|
|
| 183 |
return None
|
| 184 |
|
| 185 |
|
|
|
|
| 187 |
img_a = get_img(a)
|
| 188 |
img_b = get_img(b)
|
| 189 |
# avatar_images=(user_avatar, assistant_avatar) => (B, A)
|
| 190 |
+
# also reset chat history when characters change
|
| 191 |
return img_a, img_b, gr.update(avatar_images=(img_b, img_a), value=[])
|
| 192 |
|
| 193 |
|
|
|
|
| 233 |
gr.Markdown(" ")
|
| 234 |
image_b = gr.Image(show_label=False, interactive=False)
|
| 235 |
|
| 236 |
+
# IMPORTANT: do not pass a 'type=' argument here; the installed Gradio version does not accept it
|
| 237 |
+
chat = gr.Chatbot(show_label=False)
|
| 238 |
submit_button = gr.Button("Start Conversation")
|
| 239 |
|
| 240 |
character_button.click(
|