pszemraj committed on
Commit
006fc23
·
verified ·
1 Parent(s): 15dc377

fix turn representation

Browse files
Files changed (1) hide show
  1. app.py +77 -81
app.py CHANGED
@@ -65,38 +65,37 @@ model.eval()
65
  # ----------------------
66
 
67
 
68
- def build_messages(
69
  system_prompt: str, history: List[Tuple[str, str]]
70
  ) -> List[Dict[str, str]]:
71
- """Transform Gradio history [(user, assistant), ...] into chat template messages."""
 
 
 
 
 
 
 
 
72
  messages: List[Dict[str, str]] = []
 
 
73
  if system_prompt.strip():
74
  messages.append({"role": "system", "content": system_prompt.strip()})
 
 
 
75
  for user_msg, assistant_msg in history:
76
  if user_msg:
77
  messages.append({"role": "user", "content": user_msg})
78
  if assistant_msg:
79
  messages.append({"role": "assistant", "content": assistant_msg})
80
- return messages
81
-
82
 
83
- def apply_first_token_filter(
84
- logits: torch.Tensor, filter_ids: List[int]
85
- ) -> torch.Tensor:
86
- """Apply logit filter for problematic first tokens (Guardrail 1)."""
87
- logits_filtered = logits.clone()
88
- for token_id in filter_ids:
89
- logits_filtered[0, -1, token_id] = float("-inf")
90
- return logits_filtered
91
 
92
 
93
  def is_valid_length(text: str, min_words: int = 3, max_words: int = 50) -> bool:
94
- """Check if generated text meets length requirements (Guardrail 3).
95
-
96
- Paper used max_words=25 for their simulation experiments, but we use 50
97
- for interactive demo to allow slightly longer responses while still preventing
98
- the model from revealing the entire intent at once.
99
- """
100
  word_count = len(text.split())
101
  return min_words <= word_count <= max_words
102
 
@@ -111,7 +110,7 @@ def is_verbatim_repetition(
111
  if new_text_normalized == system_prompt.strip().lower():
112
  return True
113
 
114
- # Check against previous user messages
115
  for user_msg, _ in history:
116
  if user_msg and new_text_normalized == user_msg.strip().lower():
117
  return True
@@ -120,7 +119,7 @@ def is_verbatim_repetition(
120
 
121
 
122
  @spaces.GPU
123
- def generate_reply(
124
  messages: List[Dict[str, str]],
125
  history: List[Tuple[str, str]],
126
  system_prompt: str,
@@ -129,17 +128,9 @@ def generate_reply(
129
  top_p: float = 0.8,
130
  max_retries: int = 5,
131
  ) -> str:
132
- """Run generation with guardrails from Appendix C.1.
133
-
134
- Implements all 4 guardrails from the paper:
135
- 1. Filter problematic first tokens
136
- 2. Optionally avoid dialogue termination (disabled by default for demo)
137
- 3. Enforce length thresholds with retry
138
- 4. Filter verbatim repetitions with retry
139
- """
140
 
141
  for attempt in range(max_retries):
142
- # Prepare input ids using the model's chat template
143
  inputs = tokenizer.apply_chat_template(
144
  messages,
145
  return_tensors="pt",
@@ -155,10 +146,9 @@ def generate_reply(
155
  max_new_tokens=max_new_tokens,
156
  eos_token_id=EOS_TOKEN_ID,
157
  pad_token_id=tokenizer.eos_token_id,
158
- bad_words_ids=BAD_WORDS_IDS, # Prevents <|endconversation|>
159
  )
160
 
161
- # Slice off the prompt tokens to get only the new text
162
  generated = outputs[0][inputs.shape[1] :]
163
  text = tokenizer.decode(generated, skip_special_tokens=True).strip()
164
 
@@ -169,10 +159,9 @@ def generate_reply(
169
  if is_verbatim_repetition(text, history, system_prompt):
170
  continue
171
 
172
- # Success - return the valid text
173
  return text
174
 
175
- # If all retries failed, return a fallback message
176
  return "(Unable to generate valid response after multiple attempts)"
177
 
178
 
@@ -181,32 +170,39 @@ def generate_reply(
181
  # ----------------------
182
 
183
 
184
- def respond(
185
- assistant_message: str,
186
  chat_history: List[Tuple[str, str]],
187
  system_prompt: str,
188
  max_new_tokens: int,
189
  temperature: float,
190
  top_p: float,
191
  ):
192
- """Generate next user turn.
 
193
 
194
  Flow:
195
- - If history empty + no assistant msg: Generate first user turn
196
- - If history exists: Fill in assistant response to last turn, then generate next user turn
 
 
 
 
 
 
197
  """
198
 
199
- # Update history with assistant's message (if provided)
200
- if assistant_message.strip() and len(chat_history) > 0:
201
- # Fill in the assistant response slot for the last turn
202
  last_user_msg, _ = chat_history[-1]
203
- chat_history[-1] = (last_user_msg, assistant_message.strip())
204
 
205
- # Build messages for user turn generation
206
- messages = build_messages(system_prompt, chat_history)
207
 
 
208
  try:
209
- user_reply = generate_reply(
210
  messages,
211
  chat_history,
212
  system_prompt,
@@ -215,16 +211,20 @@ def respond(
215
  top_p=top_p,
216
  )
217
  except Exception as e:
218
- user_reply = f"(Generation error: {e})"
219
 
220
  # Add new user message to history (with empty assistant slot)
221
- chat_history = chat_history + [(user_reply, "")]
 
 
 
 
222
 
223
- return chat_history, chat_history
224
 
225
 
226
- def clear_state():
227
- return [], DEFAULT_SYSTEM_PROMPT
228
 
229
 
230
  # ----------------------
@@ -236,9 +236,10 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
236
  # UserLM-8b: User Language Model Demo
237
 
238
  **How to use:**
239
- 1. Set the user's intent in the box below (what the user wants to accomplish)
240
- 2. Click **Generate User Message** to create the first user message
241
- 3. Type assistant responses and click Generate to continue the conversation
 
242
 
243
  **Model:** `{MODEL_ID}` on **{device}**
244
  """
@@ -249,15 +250,20 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
249
  label="User Intent",
250
  value=DEFAULT_SYSTEM_PROMPT,
251
  lines=3,
252
- placeholder="Enter a high-level user intent (e.g., 'You are a user who wants to...')",
253
  )
254
 
255
- chatbot = gr.Chatbot(height=420, label="Conversation")
 
 
 
 
 
256
 
257
  with gr.Row():
258
  msg = gr.Textbox(
259
- label="Assistant Response (optional for first turn)",
260
- placeholder="Leave empty to generate first user message, or type assistant response to continue",
261
  lines=2,
262
  )
263
 
@@ -267,54 +273,44 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
267
  top_p = gr.Slider(0.0, 1.0, value=0.8, step=0.01, label="top_p")
268
 
269
  with gr.Row():
270
- submit_btn = gr.Button("Generate User Message", variant="primary")
271
  clear_btn = gr.Button("Clear")
272
 
273
- state = gr.State([]) # chat history state: List[Tuple[user, assistant]]
274
 
275
  with gr.Accordion("Implementation Details", open=False):
276
  gr.Markdown(
277
  """
278
  ### Generation Strategy
279
 
280
- Based on [Appendix C.1](https://arxiv.org/abs/2510.06552) of the UserLM paper, this demo implements:
281
- - **Recommended sampling:** temp=1.0, top_p=0.8 (not the typical 0.8/0.9)
282
- - **First token filtering:** Blocks problematic tokens (I, You, Here) that cause repetition
283
- - **Length constraints:** 3-50 words per turn to prevent revealing entire intent at once
284
  - **Repetition filtering:** Prevents verbatim copies of prior turns
285
 
286
- These guardrails are essential for the 8B model to produce realistic user behavior.
287
-
288
- **Note:** Unlike assistant LMs, UserLM simulates human *users* in conversations.
289
  """
290
  )
291
 
292
  def _submit(asst_text, history, system_prompt, mnt, temp, tp):
293
- new_history, visible = respond(asst_text, history, system_prompt, mnt, temp, tp)
294
- return "", visible
295
 
296
  submit_btn.click(
297
  fn=_submit,
298
  inputs=[msg, state, system_box, max_new_tokens, temperature, top_p],
299
- outputs=[msg, chatbot],
300
  )
301
  msg.submit(
302
  fn=_submit,
303
  inputs=[msg, state, system_box, max_new_tokens, temperature, top_p],
304
- outputs=[msg, chatbot],
305
  )
306
 
307
- # Keep state in sync with the visible Chatbot
308
- def _sync_state(chat):
309
- return chat
310
-
311
- chatbot.change(_sync_state, inputs=[chatbot], outputs=[state])
312
-
313
- def _clear():
314
- history, sys = clear_state()
315
- return history, sys, history, ""
316
 
317
- clear_btn.click(_clear, outputs=[state, system_box, chatbot, msg])
318
 
319
  if __name__ == "__main__":
320
- demo.queue().launch()
 
65
  # ----------------------
66
 
67
 
68
+ def build_messages_for_userlm(
69
  system_prompt: str, history: List[Tuple[str, str]]
70
  ) -> List[Dict[str, str]]:
71
+ """Build messages for UserLM generation.
72
+
73
+ In history tuples: (user_msg, assistant_msg) where:
74
+ - user_msg: what UserLM previously generated
75
+ - assistant_msg: what the human (playing assistant) said
76
+
77
+ For UserLM training, these roles were flipped, so we need to reconstruct
78
+ the conversation as UserLM saw it during training.
79
+ """
80
  messages: List[Dict[str, str]] = []
81
+
82
+ # System prompt defines the user's intent
83
  if system_prompt.strip():
84
  messages.append({"role": "system", "content": system_prompt.strip()})
85
+
86
+ # Add conversation history in the format UserLM expects
87
+ # UserLM was trained to generate "user" role messages given prior context
88
  for user_msg, assistant_msg in history:
89
  if user_msg:
90
  messages.append({"role": "user", "content": user_msg})
91
  if assistant_msg:
92
  messages.append({"role": "assistant", "content": assistant_msg})
 
 
93
 
94
+ return messages
 
 
 
 
 
 
 
95
 
96
 
97
  def is_valid_length(text: str, min_words: int = 3, max_words: int = 50) -> bool:
98
+ """Check if generated text meets length requirements (Guardrail 3)."""
 
 
 
 
 
99
  word_count = len(text.split())
100
  return min_words <= word_count <= max_words
101
 
 
110
  if new_text_normalized == system_prompt.strip().lower():
111
  return True
112
 
113
+ # Check against previous user messages (UserLM's prior outputs)
114
  for user_msg, _ in history:
115
  if user_msg and new_text_normalized == user_msg.strip().lower():
116
  return True
 
119
 
120
 
121
  @spaces.GPU
122
+ def generate_user_message(
123
  messages: List[Dict[str, str]],
124
  history: List[Tuple[str, str]],
125
  system_prompt: str,
 
128
  top_p: float = 0.8,
129
  max_retries: int = 5,
130
  ) -> str:
131
+ """Generate a user message with guardrails from Appendix C.1."""
 
 
 
 
 
 
 
132
 
133
  for attempt in range(max_retries):
 
134
  inputs = tokenizer.apply_chat_template(
135
  messages,
136
  return_tensors="pt",
 
146
  max_new_tokens=max_new_tokens,
147
  eos_token_id=EOS_TOKEN_ID,
148
  pad_token_id=tokenizer.eos_token_id,
149
+ bad_words_ids=BAD_WORDS_IDS,
150
  )
151
 
 
152
  generated = outputs[0][inputs.shape[1] :]
153
  text = tokenizer.decode(generated, skip_special_tokens=True).strip()
154
 
 
159
  if is_verbatim_repetition(text, history, system_prompt):
160
  continue
161
 
 
162
  return text
163
 
164
+ # If all retries failed
165
  return "(Unable to generate valid response after multiple attempts)"
166
 
167
 
 
170
  # ----------------------
171
 
172
 
173
+ def generate_next_turn(
174
+ assistant_response: str,
175
  chat_history: List[Tuple[str, str]],
176
  system_prompt: str,
177
  max_new_tokens: int,
178
  temperature: float,
179
  top_p: float,
180
  ):
181
+ """
182
+ Generate the next user message from UserLM.
183
 
184
  Flow:
185
+ - If chat_history is empty: Generate first user message
186
+ - If chat_history exists:
187
+ 1. Add assistant's response to last turn
188
+ 2. Generate next user message
189
+
190
+ Tuple structure: (user_message_from_userlm, assistant_response_from_human)
191
+ - Position 0 (left): UserLM's messages
192
+ - Position 1 (right): Human's assistant responses
193
  """
194
 
195
+ # If we have an assistant response, add it to the last turn
196
+ if assistant_response.strip() and len(chat_history) > 0:
 
197
  last_user_msg, _ = chat_history[-1]
198
+ chat_history = chat_history[:-1] + [(last_user_msg, assistant_response.strip())]
199
 
200
+ # Build messages for UserLM
201
+ messages = build_messages_for_userlm(system_prompt, chat_history)
202
 
203
+ # Generate next user message
204
  try:
205
+ user_msg = generate_user_message(
206
  messages,
207
  chat_history,
208
  system_prompt,
 
211
  top_p=top_p,
212
  )
213
  except Exception as e:
214
+ user_msg = f"(Generation error: {e})"
215
 
216
  # Add new user message to history (with empty assistant slot)
217
+ new_history = chat_history + [(user_msg, "")]
218
+
219
+ # Determine button text for next action
220
+ needs_assistant_response = True
221
+ button_text = "Generate Next User Message"
222
 
223
+ return "", new_history, button_text
224
 
225
 
226
+ def clear_conversation():
227
+ return [], DEFAULT_SYSTEM_PROMPT, "Generate First User Message"
228
 
229
 
230
  # ----------------------
 
236
  # UserLM-8b: User Language Model Demo
237
 
238
  **How to use:**
239
+ 1. Set the user's intent below
240
+ 2. Click "Generate First User Message"
241
+ 3. Type your assistant response and click "Generate Next User Message"
242
+ 4. Repeat step 3 to continue the conversation
243
 
244
  **Model:** `{MODEL_ID}` on **{device}**
245
  """
 
250
  label="User Intent",
251
  value=DEFAULT_SYSTEM_PROMPT,
252
  lines=3,
253
+ placeholder="Enter what the user wants to accomplish",
254
  )
255
 
256
+ chatbot = gr.Chatbot(
257
+ height=420,
258
+ label="Conversation",
259
+ type="tuples",
260
+ # Left side = UserLM (simulated user), Right side = You (playing assistant)
261
+ )
262
 
263
  with gr.Row():
264
  msg = gr.Textbox(
265
+ label="Your Assistant Response",
266
+ placeholder="Type your assistant response here (leave empty for first turn)",
267
  lines=2,
268
  )
269
 
 
273
  top_p = gr.Slider(0.0, 1.0, value=0.8, step=0.01, label="top_p")
274
 
275
  with gr.Row():
276
+ submit_btn = gr.Button("Generate First User Message", variant="primary")
277
  clear_btn = gr.Button("Clear")
278
 
279
+ state = gr.State([])
280
 
281
  with gr.Accordion("Implementation Details", open=False):
282
  gr.Markdown(
283
  """
284
  ### Generation Strategy
285
 
286
+ Based on [Appendix C.1](https://arxiv.org/abs/2510.06552), this implements:
287
+ - **Sampling:** temp=1.0, top_p=0.8 (paper recommendations)
288
+ - **First token filtering:** Blocks I/You/Here to prevent repetition
289
+ - **Length constraints:** 3-50 words to avoid revealing entire intent at once
290
  - **Repetition filtering:** Prevents verbatim copies of prior turns
291
 
292
+ **Note:** UserLM simulates human users, not assistants. You play the assistant role.
 
 
293
  """
294
  )
295
 
296
  def _submit(asst_text, history, system_prompt, mnt, temp, tp):
297
+ return generate_next_turn(asst_text, history, system_prompt, mnt, temp, tp)
 
298
 
299
  submit_btn.click(
300
  fn=_submit,
301
  inputs=[msg, state, system_box, max_new_tokens, temperature, top_p],
302
+ outputs=[msg, state, submit_btn],
303
  )
304
  msg.submit(
305
  fn=_submit,
306
  inputs=[msg, state, system_box, max_new_tokens, temperature, top_p],
307
+ outputs=[msg, state, submit_btn],
308
  )
309
 
310
+ # Keep chatbot display in sync with state
311
+ state.change(lambda x: x, inputs=[state], outputs=[chatbot])
 
 
 
 
 
 
 
312
 
313
+ clear_btn.click(fn=clear_conversation, outputs=[state, system_box, submit_btn])
314
 
315
  if __name__ == "__main__":
316
+ demo.queue().launch()