pszemraj commited on
Commit
b84bed8
·
verified ·
1 Parent(s): d9a9b75

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +146 -75
app.py CHANGED
@@ -65,43 +65,88 @@ model.eval()
65
  # ----------------------
66
 
67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  def is_valid_length(text: str, min_words: int = 3, max_words: int = 50) -> bool:
69
- """Check if generated text meets length requirements (Guardrail 3)."""
 
 
 
 
 
70
  word_count = len(text.split())
71
  return min_words <= word_count <= max_words
72
 
73
 
74
  def is_verbatim_repetition(
75
- new_text: str, history: List[Dict], system_prompt: str
76
  ) -> bool:
77
- """Check if text is exact repetition. History is now list of message dicts."""
78
  new_text_normalized = new_text.strip().lower()
79
 
 
80
  if new_text_normalized == system_prompt.strip().lower():
81
  return True
82
 
83
- # Check against previous user messages
84
- for msg in history:
85
- if msg.get("role") == "user" and msg.get("content"):
86
- if new_text_normalized == msg["content"].strip().lower():
87
- return True
88
 
89
  return False
90
 
91
 
92
  @spaces.GPU
93
- def generate_user_message(
94
  messages: List[Dict[str, str]],
95
- history: List[Dict],
96
  system_prompt: str,
97
  max_new_tokens: int = 256,
98
  temperature: float = 1.0,
99
  top_p: float = 0.8,
100
  max_retries: int = 5,
101
  ) -> str:
102
- """Generate a user message with guardrails from Appendix C.1."""
 
 
 
 
 
 
 
103
 
104
  for attempt in range(max_retries):
 
105
  inputs = tokenizer.apply_chat_template(
106
  messages,
107
  return_tensors="pt",
@@ -117,9 +162,10 @@ def generate_user_message(
117
  max_new_tokens=max_new_tokens,
118
  eos_token_id=EOS_TOKEN_ID,
119
  pad_token_id=tokenizer.eos_token_id,
120
- bad_words_ids=BAD_WORDS_IDS,
121
  )
122
 
 
123
  generated = outputs[0][inputs.shape[1] :]
124
  text = tokenizer.decode(generated, skip_special_tokens=True).strip()
125
 
@@ -130,10 +176,13 @@ def generate_user_message(
130
  if is_verbatim_repetition(text, history, system_prompt):
131
  continue
132
 
 
133
  return text
134
 
135
- # If all retries failed
136
- return "(Unable to generate valid response after multiple attempts)"
 
 
137
 
138
 
139
  # ----------------------
@@ -141,35 +190,30 @@ def generate_user_message(
141
  # ----------------------
142
 
143
 
144
- def generate_next_turn(
145
- assistant_response: str,
146
- chat_history: List[Dict],
147
  system_prompt: str,
148
  max_new_tokens: int,
149
  temperature: float,
150
  top_p: float,
151
  ):
152
- """
153
- History format: List of {"role": "user"/"assistant", "content": "..."}
154
- - "user" role = UserLM (displays LEFT)
155
- - "assistant" role = Human (displays RIGHT)
156
- """
157
 
158
- # If we have an assistant response, add it to history
159
- if assistant_response.strip():
160
- chat_history.append(
161
- {"role": "assistant", "content": assistant_response.strip()}
162
- )
163
 
164
- # Build messages for UserLM from history
165
- messages = []
166
- if system_prompt.strip():
167
- messages.append({"role": "system", "content": system_prompt.strip()})
168
- messages.extend(chat_history)
 
 
169
 
170
- # Generate next user message
171
- try:
172
- user_msg = generate_user_message(
173
  messages,
174
  chat_history,
175
  system_prompt,
@@ -177,17 +221,46 @@ def generate_next_turn(
177
  temperature=temperature,
178
  top_p=top_p,
179
  )
180
- except Exception as e:
181
- user_msg = f"(Generation error: {e})"
182
 
183
- # Add new user message to history
184
- new_history = chat_history + [{"role": "user", "content": user_msg}]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
 
186
- return "", new_history, "Generate Next User Message"
187
 
188
 
189
- def clear_conversation():
190
- return [], DEFAULT_SYSTEM_PROMPT, [], "Generate First User Message", []
191
 
192
 
193
  # ----------------------
@@ -198,13 +271,9 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
198
  f"""
199
  # UserLM-8b: User Language Model Demo
200
 
201
- **How to use:**
202
- 1. Set the user's intent below
203
- 2. Click "Generate First User Message"
204
- 3. Type your assistant response and click "Generate Next User Message"
205
- 4. Repeat step 3 to continue the conversation
206
 
207
- **Model:** `{MODEL_ID}` on **{device}**
208
  """
209
  )
210
 
@@ -213,20 +282,20 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
213
  label="User Intent",
214
  value=DEFAULT_SYSTEM_PROMPT,
215
  lines=3,
216
- placeholder="Enter what the user wants to accomplish",
217
  )
218
 
 
219
  chatbot = gr.Chatbot(
220
  height=420,
221
  label="Conversation",
222
- type="messages", # Changed from tuples to have more control
223
- # Will manually format messages with role attribute
224
  )
225
 
226
  with gr.Row():
227
  msg = gr.Textbox(
228
- label="Your Assistant Response",
229
- placeholder="Type your assistant response here (leave empty for first turn)",
230
  lines=2,
231
  )
232
 
@@ -236,47 +305,49 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
236
  top_p = gr.Slider(0.0, 1.0, value=0.8, step=0.01, label="top_p")
237
 
238
  with gr.Row():
239
- submit_btn = gr.Button("Generate User Message", variant="primary")
240
  clear_btn = gr.Button("Clear")
241
 
242
- state = gr.State([])
243
 
244
  with gr.Accordion("Implementation Details", open=False):
245
  gr.Markdown(
246
  """
247
- ### Generation Strategy
248
-
249
- Based on [Appendix C.1](https://arxiv.org/abs/2510.06552), this implements:
250
- - **Sampling:** temp=1.0, top_p=0.8 (paper recommendations)
251
- - **First token filtering:** Blocks I/You/Here to prevent repetition
252
- - **Length constraints:** 3-50 words to avoid revealing entire intent at once
253
- - **Repetition filtering:** Prevents verbatim copies of prior turns
254
-
255
- **Note:** UserLM simulates human users, not assistants. You play the assistant role.
256
- """
257
  )
258
 
259
  def _submit(asst_text, history, system_prompt, mnt, temp, tp):
260
- new_msg, new_history = generate_next_turn(
261
- asst_text, history, system_prompt, mnt, temp, tp
262
- )
263
- return new_msg, new_history, new_history
264
 
265
  submit_btn.click(
266
  fn=_submit,
267
  inputs=[msg, state, system_box, max_new_tokens, temperature, top_p],
268
- outputs=[msg, state, chatbot],
269
  )
270
  msg.submit(
271
  fn=_submit,
272
  inputs=[msg, state, system_box, max_new_tokens, temperature, top_p],
273
- outputs=[msg, state, chatbot],
274
  )
275
 
276
- clear_btn.click(
277
- fn=clear_conversation,
278
- outputs=[state, system_box, chatbot],
279
- )
 
 
 
 
 
 
 
280
 
281
  if __name__ == "__main__":
282
  demo.queue().launch()
 
65
  # ----------------------
66
 
67
 
68
+ def build_messages(
69
+ system_prompt: str, history: List[Tuple[str, str]]
70
+ ) -> List[Dict[str, str]]:
71
+ """Transform Gradio history into chat template messages.
72
+
73
+ IMPORTANT: History is stored as (human_assistant_msg, model_user_msg) for display,
74
+ but we need to flip it back to (user, assistant) for the model's chat template.
75
+ """
76
+ messages: List[Dict[str, str]] = []
77
+ if system_prompt.strip():
78
+ messages.append({"role": "system", "content": system_prompt.strip()})
79
+
80
+ # Flip the roles: history stores (human's assistant msg, model's user msg)
81
+ for human_assistant, model_user in history:
82
+ if model_user: # Model's user message
83
+ messages.append({"role": "user", "content": model_user})
84
+ if human_assistant: # Human's assistant response
85
+ messages.append({"role": "assistant", "content": human_assistant})
86
+
87
+ return messages
88
+
89
+
90
+ def apply_first_token_filter(
91
+ logits: torch.Tensor, filter_ids: List[int]
92
+ ) -> torch.Tensor:
93
+ """Apply logit filter for problematic first tokens (Guardrail 1)."""
94
+ logits_filtered = logits.clone()
95
+ for token_id in filter_ids:
96
+ logits_filtered[0, -1, token_id] = float("-inf")
97
+ return logits_filtered
98
+
99
+
100
  def is_valid_length(text: str, min_words: int = 3, max_words: int = 50) -> bool:
101
+ """Check if generated text meets length requirements (Guardrail 3).
102
+
103
+ Paper used max_words=25 for their simulation experiments, but we use 50
104
+ for interactive demo to allow slightly longer responses while still preventing
105
+ the model from revealing the entire intent at once.
106
+ """
107
  word_count = len(text.split())
108
  return min_words <= word_count <= max_words
109
 
110
 
111
  def is_verbatim_repetition(
112
+ new_text: str, history: List[Tuple[str, str]], system_prompt: str
113
  ) -> bool:
114
+ """Check if text is exact repetition of prior user turn or system prompt (Guardrail 4)."""
115
  new_text_normalized = new_text.strip().lower()
116
 
117
+ # Check against system prompt
118
  if new_text_normalized == system_prompt.strip().lower():
119
  return True
120
 
121
+ # Check against previous model user messages (stored in second position)
122
+ for _, model_user in history:
123
+ if model_user and new_text_normalized == model_user.strip().lower():
124
+ return True
 
125
 
126
  return False
127
 
128
 
129
  @spaces.GPU
130
+ def generate_reply(
131
  messages: List[Dict[str, str]],
132
+ history: List[Tuple[str, str]],
133
  system_prompt: str,
134
  max_new_tokens: int = 256,
135
  temperature: float = 1.0,
136
  top_p: float = 0.8,
137
  max_retries: int = 5,
138
  ) -> str:
139
+ """Run generation with guardrails from Appendix C.1.
140
+
141
+ Implements all 4 guardrails from the paper:
142
+ 1. Filter problematic first tokens
143
+ 2. Optionally avoid dialogue termination (disabled by default for demo)
144
+ 3. Enforce length thresholds with retry
145
+ 4. Filter verbatim repetitions with retry
146
+ """
147
 
148
  for attempt in range(max_retries):
149
+ # Prepare input ids using the model's chat template
150
  inputs = tokenizer.apply_chat_template(
151
  messages,
152
  return_tensors="pt",
 
162
  max_new_tokens=max_new_tokens,
163
  eos_token_id=EOS_TOKEN_ID,
164
  pad_token_id=tokenizer.eos_token_id,
165
+ bad_words_ids=BAD_WORDS_IDS, # Prevents <|endconversation|>
166
  )
167
 
168
+ # Slice off the prompt tokens to get only the new text
169
  generated = outputs[0][inputs.shape[1] :]
170
  text = tokenizer.decode(generated, skip_special_tokens=True).strip()
171
 
 
176
  if is_verbatim_repetition(text, history, system_prompt):
177
  continue
178
 
179
+ # Success - return the valid text
180
  return text
181
 
182
+ # If all retries failed, raise an error
183
+ raise RuntimeError(
184
+ f"Failed to generate valid response after {max_retries} attempts"
185
+ )
186
 
187
 
188
  # ----------------------
 
190
  # ----------------------
191
 
192
 
193
+ def respond(
194
+ assistant_message: str,
195
+ chat_history: List[Tuple[str, str]],
196
  system_prompt: str,
197
  max_new_tokens: int,
198
  temperature: float,
199
  top_p: float,
200
  ):
201
+ """Generate next user turn.
 
 
 
 
202
 
203
+ Flow:
204
+ - If history empty: Generate first user message (ignores assistant_message input)
205
+ - If history exists with assistant message: Add it and generate next user turn
206
+ - If history exists without assistant message: Warning to user
 
207
 
208
+ History format: (human_assistant_msg, model_user_msg) for proper display
209
+ """
210
+
211
+ # First message generation - ignore any text in the assistant box
212
+ if len(chat_history) == 0:
213
+ # Generate initial user message from system prompt alone
214
+ messages = build_messages(system_prompt, [])
215
 
216
+ user_reply = generate_reply(
 
 
217
  messages,
218
  chat_history,
219
  system_prompt,
 
221
  temperature=temperature,
222
  top_p=top_p,
223
  )
 
 
224
 
225
+ # Start conversation with first user message
226
+ chat_history = [("", user_reply)]
227
+ return chat_history, chat_history
228
+
229
+ # Subsequent messages - require assistant response
230
+ if not assistant_message.strip():
231
+ # User clicked generate without providing assistant response
232
+ # Just return current state without changes
233
+ gr.Info(
234
+ "Please type your assistant response before generating the next user message."
235
+ )
236
+ return chat_history, chat_history
237
+
238
+ # Update history with human's assistant message
239
+ if len(chat_history) > 0:
240
+ # Fill in the human's assistant response for the last turn
241
+ _, last_model_user = chat_history[-1]
242
+ chat_history[-1] = (assistant_message.strip(), last_model_user)
243
+
244
+ # Build messages for next user turn generation
245
+ messages = build_messages(system_prompt, chat_history)
246
+
247
+ user_reply = generate_reply(
248
+ messages,
249
+ chat_history,
250
+ system_prompt,
251
+ max_new_tokens=max_new_tokens,
252
+ temperature=temperature,
253
+ top_p=top_p,
254
+ )
255
+
256
+ # Add new model user message to history
257
+ chat_history = chat_history + [("", user_reply)]
258
 
259
+ return chat_history, chat_history
260
 
261
 
262
+ def clear_state():
263
+ return [], DEFAULT_SYSTEM_PROMPT
264
 
265
 
266
  # ----------------------
 
271
  f"""
272
  # UserLM-8b: User Language Model Demo
273
 
274
+ **Model:** `{MODEL_ID}` | **Device:** `{device}`
 
 
 
 
275
 
276
+ The AI plays the user, you play the assistant.
277
  """
278
  )
279
 
 
282
  label="User Intent",
283
  value=DEFAULT_SYSTEM_PROMPT,
284
  lines=3,
285
+ placeholder="Enter the user's goal or intent",
286
  )
287
 
288
+ # Display with role labels to clarify the reversal
289
  chatbot = gr.Chatbot(
290
  height=420,
291
  label="Conversation",
292
+ avatar_images=(None, None), # Remove default avatars to avoid confusion
 
293
  )
294
 
295
  with gr.Row():
296
  msg = gr.Textbox(
297
+ label="Assistant Response",
298
+ placeholder="Leave empty for first generation, then type your responses",
299
  lines=2,
300
  )
301
 
 
305
  top_p = gr.Slider(0.0, 1.0, value=0.8, step=0.01, label="top_p")
306
 
307
  with gr.Row():
308
+ submit_btn = gr.Button("Generate", variant="primary")
309
  clear_btn = gr.Button("Clear")
310
 
311
+ state = gr.State([]) # chat history state: List[Tuple[human_assistant, model_user]]
312
 
313
  with gr.Accordion("Implementation Details", open=False):
314
  gr.Markdown(
315
  """
316
+ Based on Appendix C.1 of the UserLM paper:
317
+ - Sampling: temp=1.0, top_p=0.8
318
+ - First token filtering for problematic tokens
319
+ - Length constraints: 3-50 words
320
+ - Repetition filtering
321
+ """
 
 
 
 
322
  )
323
 
324
  def _submit(asst_text, history, system_prompt, mnt, temp, tp):
325
+ new_history, visible = respond(asst_text, history, system_prompt, mnt, temp, tp)
326
+ # Clear input box after submission
327
+ return "", visible
 
328
 
329
  submit_btn.click(
330
  fn=_submit,
331
  inputs=[msg, state, system_box, max_new_tokens, temperature, top_p],
332
+ outputs=[msg, chatbot],
333
  )
334
  msg.submit(
335
  fn=_submit,
336
  inputs=[msg, state, system_box, max_new_tokens, temperature, top_p],
337
+ outputs=[msg, chatbot],
338
  )
339
 
340
+ # Keep state in sync with the visible Chatbot
341
+ def _sync_state(chat):
342
+ return chat
343
+
344
+ chatbot.change(_sync_state, inputs=[chatbot], outputs=[state])
345
+
346
+ def _clear():
347
+ history, sys = clear_state()
348
+ return history, sys, history, ""
349
+
350
+ clear_btn.click(_clear, outputs=[state, system_box, chatbot, msg])
351
 
352
  if __name__ == "__main__":
353
  demo.queue().launch()