pszemraj committed on
Commit
c502b05
·
verified ·
1 Parent(s): b84bed8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -31
app.py CHANGED
@@ -18,18 +18,16 @@ DEFAULT_SYSTEM_PROMPT = (
18
  "The first two numbers in the sequence are 1 and 1."
19
  )
20
 
21
- device = "cuda" if torch.cuda.is_available() else "cpu"
22
-
23
 
24
  def load_model(model_id: str = MODEL_ID):
25
  """Load tokenizer and model, with a reasonable dtype and device fallback."""
26
  tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
27
 
28
- dtype = torch.float16 if device == "cuda" else torch.float32
29
  model = AutoModelForCausalLM.from_pretrained(
30
  model_id,
31
  trust_remote_code=True,
32
- torch_dtype=dtype,
 
33
  )
34
 
35
  # Special tokens for stopping / filtering
@@ -57,7 +55,6 @@ def load_model(model_id: str = MODEL_ID):
57
 
58
 
59
  tokenizer, model, EOS_TOKEN_ID, BAD_WORDS_IDS, FIRST_TOKEN_FILTER_IDS = load_model()
60
- model = model.to(device)
61
  model.eval()
62
 
63
  # ----------------------
@@ -70,18 +67,17 @@ def build_messages(
70
  ) -> List[Dict[str, str]]:
71
  """Transform Gradio history into chat template messages.
72
 
73
- IMPORTANT: History is stored as (human_assistant_msg, model_user_msg) for display,
74
- but we need to flip it back to (user, assistant) for the model's chat template.
75
  """
76
  messages: List[Dict[str, str]] = []
77
  if system_prompt.strip():
78
  messages.append({"role": "system", "content": system_prompt.strip()})
79
 
80
- # Flip the roles: history stores (human's assistant msg, model's user msg)
81
- for human_assistant, model_user in history:
82
- if model_user: # Model's user message
83
  messages.append({"role": "user", "content": model_user})
84
- if human_assistant: # Human's assistant response
85
  messages.append({"role": "assistant", "content": human_assistant})
86
 
87
  return messages
@@ -118,8 +114,8 @@ def is_verbatim_repetition(
118
  if new_text_normalized == system_prompt.strip().lower():
119
  return True
120
 
121
- # Check against previous model user messages (stored in second position)
122
- for _, model_user in history:
123
  if model_user and new_text_normalized == model_user.strip().lower():
124
  return True
125
 
@@ -151,7 +147,7 @@ def generate_reply(
151
  messages,
152
  return_tensors="pt",
153
  add_generation_prompt=True,
154
- ).to(device)
155
 
156
  with torch.no_grad():
157
  outputs = model.generate(
@@ -202,10 +198,9 @@ def respond(
202
 
203
  Flow:
204
  - If history empty: Generate first user message (ignores assistant_message input)
205
- - If history exists with assistant message: Add it and generate next user turn
206
- - If history exists without assistant message: Warning to user
207
 
208
- History format: (human_assistant_msg, model_user_msg) for proper display
209
  """
210
 
211
  # First message generation - ignore any text in the assistant box
@@ -222,24 +217,21 @@ def respond(
222
  top_p=top_p,
223
  )
224
 
225
- # Start conversation with first user message
226
- chat_history = [("", user_reply)]
227
  return chat_history, chat_history
228
 
229
  # Subsequent messages - require assistant response
230
  if not assistant_message.strip():
231
  # User clicked generate without providing assistant response
232
- # Just return current state without changes
233
  gr.Info(
234
  "Please type your assistant response before generating the next user message."
235
  )
236
  return chat_history, chat_history
237
 
238
- # Update history with human's assistant message
239
- if len(chat_history) > 0:
240
- # Fill in the human's assistant response for the last turn
241
- _, last_model_user = chat_history[-1]
242
- chat_history[-1] = (assistant_message.strip(), last_model_user)
243
 
244
  # Build messages for next user turn generation
245
  messages = build_messages(system_prompt, chat_history)
@@ -253,8 +245,8 @@ def respond(
253
  top_p=top_p,
254
  )
255
 
256
- # Add new model user message to history
257
- chat_history = chat_history + [("", user_reply)]
258
 
259
  return chat_history, chat_history
260
 
@@ -271,7 +263,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
271
  f"""
272
  # UserLM-8b: User Language Model Demo
273
 
274
- **Model:** `{MODEL_ID}` | **Device:** `{device}`
275
 
276
  The AI plays the user, you play the assistant.
277
  """
@@ -285,11 +277,9 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
285
  placeholder="Enter the user's goal or intent",
286
  )
287
 
288
- # Display with role labels to clarify the reversal
289
  chatbot = gr.Chatbot(
290
  height=420,
291
  label="Conversation",
292
- avatar_images=(None, None), # Remove default avatars to avoid confusion
293
  )
294
 
295
  with gr.Row():
@@ -308,7 +298,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
308
  submit_btn = gr.Button("Generate", variant="primary")
309
  clear_btn = gr.Button("Clear")
310
 
311
- state = gr.State([]) # chat history state: List[Tuple[human_assistant, model_user]]
312
 
313
  with gr.Accordion("Implementation Details", open=False):
314
  gr.Markdown(
 
18
  "The first two numbers in the sequence are 1 and 1."
19
  )
20
 
 
 
21
 
22
  def load_model(model_id: str = MODEL_ID):
23
  """Load tokenizer and model, with a reasonable dtype and device fallback."""
24
  tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
25
 
 
26
  model = AutoModelForCausalLM.from_pretrained(
27
  model_id,
28
  trust_remote_code=True,
29
+ torch_dtype="auto",
30
+ device_map="auto",
31
  )
32
 
33
  # Special tokens for stopping / filtering
 
55
 
56
 
57
  tokenizer, model, EOS_TOKEN_ID, BAD_WORDS_IDS, FIRST_TOKEN_FILTER_IDS = load_model()
 
58
  model.eval()
59
 
60
  # ----------------------
 
67
  ) -> List[Dict[str, str]]:
68
  """Transform Gradio history into chat template messages.
69
 
70
+ History is stored as (model_user, human_assistant) tuples.
 
71
  """
72
  messages: List[Dict[str, str]] = []
73
  if system_prompt.strip():
74
  messages.append({"role": "system", "content": system_prompt.strip()})
75
 
76
+ # Each tuple is (model_user, human_assistant)
77
+ for model_user, human_assistant in history:
78
+ if model_user:
79
  messages.append({"role": "user", "content": model_user})
80
+ if human_assistant:
81
  messages.append({"role": "assistant", "content": human_assistant})
82
 
83
  return messages
 
114
  if new_text_normalized == system_prompt.strip().lower():
115
  return True
116
 
117
+ # Check against previous model user messages (first element in tuple)
118
+ for model_user, _ in history:
119
  if model_user and new_text_normalized == model_user.strip().lower():
120
  return True
121
 
 
147
  messages,
148
  return_tensors="pt",
149
  add_generation_prompt=True,
150
+ ).to(model.device)
151
 
152
  with torch.no_grad():
153
  outputs = model.generate(
 
198
 
199
  Flow:
200
  - If history empty: Generate first user message (ignores assistant_message input)
201
+ - If history exists: Add assistant response and generate next user turn
 
202
 
203
+ History format: (model_user, human_assistant)
204
  """
205
 
206
  # First message generation - ignore any text in the assistant box
 
217
  top_p=top_p,
218
  )
219
 
220
+ # Start conversation with first user message (empty assistant slot)
221
+ chat_history = [(user_reply, None)]
222
  return chat_history, chat_history
223
 
224
  # Subsequent messages - require assistant response
225
  if not assistant_message.strip():
226
  # User clicked generate without providing assistant response
 
227
  gr.Info(
228
  "Please type your assistant response before generating the next user message."
229
  )
230
  return chat_history, chat_history
231
 
232
+ # Update the last tuple with the assistant response
233
+ last_model_user, _ = chat_history[-1]
234
+ chat_history[-1] = (last_model_user, assistant_message.strip())
 
 
235
 
236
  # Build messages for next user turn generation
237
  messages = build_messages(system_prompt, chat_history)
 
245
  top_p=top_p,
246
  )
247
 
248
+ # Add new model user message (with empty assistant slot)
249
+ chat_history.append((user_reply, None))
250
 
251
  return chat_history, chat_history
252
 
 
263
  f"""
264
  # UserLM-8b: User Language Model Demo
265
 
266
+ **Model:** `{MODEL_ID}`
267
 
268
  The AI plays the user, you play the assistant.
269
  """
 
277
  placeholder="Enter the user's goal or intent",
278
  )
279
 
 
280
  chatbot = gr.Chatbot(
281
  height=420,
282
  label="Conversation",
 
283
  )
284
 
285
  with gr.Row():
 
298
  submit_btn = gr.Button("Generate", variant="primary")
299
  clear_btn = gr.Button("Clear")
300
 
301
+ state = gr.State([]) # chat history: List[Tuple[model_user, human_assistant]]
302
 
303
  with gr.Accordion("Implementation Details", open=False):
304
  gr.Markdown(