Update app.py
app.py CHANGED
@@ -18,18 +18,16 @@ DEFAULT_SYSTEM_PROMPT = (
     "The first two numbers in the sequence are 1 and 1."
 )
 
-device = "cuda" if torch.cuda.is_available() else "cpu"
-
 
 def load_model(model_id: str = MODEL_ID):
     """Load tokenizer and model, with a reasonable dtype and device fallback."""
     tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
 
-    dtype = torch.float16 if device == "cuda" else torch.float32
     model = AutoModelForCausalLM.from_pretrained(
         model_id,
         trust_remote_code=True,
-        torch_dtype=dtype,
+        torch_dtype="auto",
+        device_map="auto",
     )
 
     # Special tokens for stopping / filtering
@@ -57,7 +55,6 @@ def load_model(model_id: str = MODEL_ID):
 
 
 tokenizer, model, EOS_TOKEN_ID, BAD_WORDS_IDS, FIRST_TOKEN_FILTER_IDS = load_model()
-model = model.to(device)
 model.eval()
 
 # ----------------------
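For context on these two hunks: `device_map="auto"` asks transformers (via the accelerate package) to place the weights at load time, and `torch_dtype="auto"` takes the dtype recorded in the checkpoint config, so the module-level `device` probe and the explicit `model.to(device)` are no longer needed. A minimal sketch of the before/after loading styles; the model ID below is a placeholder, since the value of `MODEL_ID` is defined outside this diff:

import torch
from transformers import AutoModelForCausalLM

MODEL_ID = "some-org/some-causal-lm"  # placeholder; the app's MODEL_ID is defined elsewhere

# Before: pick device and dtype by hand, then move the whole model.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
model_old = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=dtype).to(device)

# After: let the library decide placement and dtype at load time.
model_new = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype="auto",   # dtype from the checkpoint config
    device_map="auto",    # needs `accelerate`; may shard across GPUs and CPU
)
print(model_new.device)   # where the (first) parameters ended up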
@@ -70,18 +67,17 @@ def build_messages(
 ) -> List[Dict[str, str]]:
     """Transform Gradio history into chat template messages.
 
-    Gradio stores this demo's history as (human_assistant, model_user),
-    but we need to flip it back to (user, assistant) for the model's chat template.
+    History is stored as (model_user, human_assistant) tuples.
     """
     messages: List[Dict[str, str]] = []
     if system_prompt.strip():
         messages.append({"role": "system", "content": system_prompt.strip()})
 
-    # Each tuple is (human_assistant, model_user)
-    for human_assistant, model_user in history:
-        if model_user:
+    # Each tuple is (model_user, human_assistant)
+    for model_user, human_assistant in history:
+        if model_user:
             messages.append({"role": "user", "content": model_user})
-        if human_assistant:
+        if human_assistant:
             messages.append({"role": "assistant", "content": human_assistant})
 
     return messages
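Read on its own, the rewritten loop is a straightforward flattening of paired turns into role-tagged messages. A self-contained rendering of the function as it stands after this hunk, with a small usage example:

from typing import Dict, List, Optional, Tuple

def build_messages(
    system_prompt: str,
    history: List[Tuple[Optional[str], Optional[str]]],
) -> List[Dict[str, str]]:
    """Flatten (model_user, human_assistant) tuples into chat-template messages."""
    messages: List[Dict[str, str]] = []
    if system_prompt.strip():
        messages.append({"role": "system", "content": system_prompt.strip()})
    for model_user, human_assistant in history:
        if model_user:
            messages.append({"role": "user", "content": model_user})
        if human_assistant:
            messages.append({"role": "assistant", "content": human_assistant})
    return messages

history = [
    ("How do I reverse a list in Python?", "Use list.reverse() or slicing."),
    ("Does slicing make a copy?", None),  # trailing None: awaiting the human's reply
]
print(build_messages("You are a curious user.", history))
# -> system, user, assistant, user (the None slot is skipped)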
@@ -118,8 +114,8 @@ def is_verbatim_repetition(
     if new_text_normalized == system_prompt.strip().lower():
         return True
 
-    # Check against previous model user messages (second element in tuple)
-    for _, model_user in history:
+    # Check against previous model user messages (first element in tuple)
+    for model_user, _ in history:
         if model_user and new_text_normalized == model_user.strip().lower():
             return True
 
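The repetition guard compares a normalized candidate against the system prompt and every prior model-generated user turn; with the new tuple order the user turn is the first element. A standalone sketch of the check as it reads after this hunk (only strip()+lower() normalization is visible in the shown lines; the real helper's full signature sits outside this diff):

from typing import List, Optional, Tuple

def is_verbatim_repetition(
    new_text: str,
    system_prompt: str,
    history: List[Tuple[Optional[str], Optional[str]]],
) -> bool:
    new_text_normalized = new_text.strip().lower()
    if new_text_normalized == system_prompt.strip().lower():
        return True
    # Check against previous model user messages (first element in tuple)
    for model_user, _ in history:
        if model_user and new_text_normalized == model_user.strip().lower():
            return True
    return False

print(is_verbatim_repetition("Hello there!", "sys", [("hello there!", "hi")]))  # True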
@@ -151,7 +147,7 @@ def generate_reply(
         messages,
         return_tensors="pt",
         add_generation_prompt=True,
-    ).to(device)
+    ).to(model.device)
 
     with torch.no_grad():
         outputs = model.generate(
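`model.device` is a property transformers models expose that reports where the (first) parameters live, so the tokenized prompt follows the model wherever `device_map` placed it, instead of relying on the deleted global. The same pattern with a plain torch module, which lacks that property, looks like this:

import torch

model = torch.nn.Linear(16, 16)
if torch.cuda.is_available():
    model = model.cuda()

inputs = torch.randint(0, 100, (1, 8))
# Plain nn.Modules have no .device property; ask the parameters instead.
# transformers models wrap exactly this lookup behind `model.device`.
inputs = inputs.to(next(model.parameters()).device)
print(inputs.device)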
@@ -202,10 +198,9 @@ def respond(
 
     Flow:
     - If history empty: Generate first user message (ignores assistant_message input)
-    - If history exists with assistant message: Generate next user turn
-    - If history exists without assistant message: Warning to user
+    - If history exists: Add assistant response and generate next user turn
 
-    History format: (human_assistant, model_user)
+    History format: (model_user, human_assistant)
     """
 
     # First message generation - ignore any text in the assistant box
@@ -222,24 +217,21 @@ def respond(
             top_p=top_p,
         )
 
-        # Start conversation with first user message
-        chat_history = [(None, user_reply)]
+        # Start conversation with first user message (empty assistant slot)
+        chat_history = [(user_reply, None)]
         return chat_history, chat_history
 
     # Subsequent messages - require assistant response
     if not assistant_message.strip():
         # User clicked generate without providing assistant response
-        # Just return current state without changes
         gr.Info(
             "Please type your assistant response before generating the next user message."
         )
         return chat_history, chat_history
 
-    # Update the last tuple with the assistant response
-    # (history is stored as (human_assistant, model_user) tuples,
-    # so the assistant reply goes in the first slot)
-    _, last_model_user = chat_history[-1]
-    chat_history[-1] = (assistant_message.strip(), last_model_user)
+    # Update the last tuple with the assistant response
+    last_model_user, _ = chat_history[-1]
+    chat_history[-1] = (last_model_user, assistant_message.strip())
 
     # Build messages for next user turn generation
     messages = build_messages(system_prompt, chat_history)
@@ -253,8 +245,8 @@ def respond(
         top_p=top_p,
     )
 
-    # Add new model user message
-    chat_history.append((None, user_reply))
+    # Add new model user message (with empty assistant slot)
+    chat_history.append((user_reply, None))
 
     return chat_history, chat_history
 
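The history bookkeeping in this and the previous hunk reduces to: fill the open assistant slot of the last tuple, then append a fresh tuple holding the next generated user turn. A pure-Python sketch of that state machine, with generation stubbed out (the real app calls generate_reply()):

from typing import Dict, List, Optional, Tuple

History = List[Tuple[Optional[str], Optional[str]]]  # (model_user, human_assistant)

def fake_generate_reply(messages: List[Dict[str, str]]) -> str:
    # Stand-in for the real generate_reply(); counts prior user turns.
    return f"user turn #{sum(m['role'] == 'user' for m in messages) + 1}"

def respond(assistant_message: str, chat_history: History) -> History:
    if not chat_history:
        # First turn: the model opens the conversation as the user.
        return [(fake_generate_reply([]), None)]
    if not assistant_message.strip():
        return chat_history  # nothing to add yet
    # Fill the open assistant slot of the last tuple ...
    last_model_user, _ = chat_history[-1]
    chat_history[-1] = (last_model_user, assistant_message.strip())
    # ... rebuild the message list (build_messages() in the real app) ...
    messages: List[Dict[str, str]] = []
    for model_user, human_assistant in chat_history:
        if model_user:
            messages.append({"role": "user", "content": model_user})
        if human_assistant:
            messages.append({"role": "assistant", "content": human_assistant})
    # ... and append a fresh tuple with the next generated user turn.
    chat_history.append((fake_generate_reply(messages), None))
    return chat_history

h: History = []
h = respond("", h)                    # -> [('user turn #1', None)]
h = respond("Here is an answer.", h)  # fills the slot, appends turn #2
print(h)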
@@ -271,7 +263,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
     f"""
     # UserLM-8b: User Language Model Demo
 
-    **Model:** `{MODEL_ID}`
+    **Model:** `{MODEL_ID}`
 
     The AI plays the user, you play the assistant.
     """
@@ -285,11 +277,9 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
         placeholder="Enter the user's goal or intent",
     )
 
-    # Display with role labels to clarify the reversal
     chatbot = gr.Chatbot(
         height=420,
         label="Conversation",
-        avatar_images=(None, None),  # Remove default avatars to avoid confusion
     )
 
     with gr.Row():
@@ -308,7 +298,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
     submit_btn = gr.Button("Generate", variant="primary")
     clear_btn = gr.Button("Clear")
 
-    state = gr.State([])  # chat history
+    state = gr.State([])  # chat history: List[Tuple[model_user, human_assistant]]
 
     with gr.Accordion("Implementation Details", open=False):
         gr.Markdown(
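Why the (model_user, human_assistant) order matters for the UI: Gradio's classic tuple-format Chatbot renders each pair as (left user bubble, right bot bubble), so with the model's generated turn stored first, the reversed-role conversation displays correctly without any flipping. A minimal wiring sketch under that assumption (tuple format; recent Gradio versions prefer type="messages"):

import gradio as gr

def step(assistant_msg, history):
    # Placeholder generation; the real app calls the loaded model here.
    history = history + [(f"next user turn (after {assistant_msg!r})", None)]
    return history, history

with gr.Blocks() as demo:
    chatbot = gr.Chatbot(label="Conversation")  # renders (user, bot) tuples
    state = gr.State([])                        # List[Tuple[model_user, human_assistant]]
    assistant_box = gr.Textbox(label="Your assistant reply")
    gr.Button("Generate").click(step, [assistant_box, state], [chatbot, state])

demo.launch()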