better UX

app.py CHANGED

@@ -182,18 +182,31 @@ def generate_reply(
 
 
 def respond(
-    user_message: str,
+    assistant_message: str,
     chat_history: List[Tuple[str, str]],
    system_prompt: str,
     max_new_tokens: int,
     temperature: float,
     top_p: float,
 ):
-
-
+    """Generate the next user turn.
+
+    Flow:
+    - If the history is empty and no assistant message is given: generate the first user turn.
+    - If history exists: fill in the assistant response to the last turn, then generate the next user turn.
+    """
+
+    # Update the history with the assistant's message (if provided)
+    if assistant_message.strip() and len(chat_history) > 0:
+        # Fill in the assistant-response slot of the last turn
+        last_user_msg, _ = chat_history[-1]
+        chat_history[-1] = (last_user_msg, assistant_message.strip())
+
+    # Build messages for user-turn generation
+    messages = build_messages(system_prompt, chat_history)
 
     try:
-        reply = generate_reply(
+        user_reply = generate_reply(
             messages,
             chat_history,
             system_prompt,
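The hunk above now threads an `assistant_message` into `respond` and calls `build_messages(system_prompt, chat_history)`, which is defined outside this diff. A minimal sketch of what such a helper could look like; the body below, including whether roles must be inverted for a user LM's chat template, is an assumption for illustration, not code from this commit:

```python
from typing import Dict, List, Tuple


def build_messages(system_prompt: str, chat_history: List[Tuple[str, str]]) -> List[Dict[str, str]]:
    """Hypothetical sketch: flatten (user, assistant) pairs into a chat-format list.

    A user LM generates the *user* side of the dialogue, so depending on the
    model's chat template the roles may need to be swapped; this sketch keeps
    the plain mapping and skips the newest turn's still-empty assistant slot.
    """
    messages = [{"role": "system", "content": system_prompt}]
    for user_turn, assistant_turn in chat_history:
        messages.append({"role": "user", "content": user_turn})
        if assistant_turn:  # the latest turn's assistant slot may be ""
            messages.append({"role": "assistant", "content": assistant_turn})
    return messages
```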
@@ -202,9 +215,11 @@ def respond(
             top_p=top_p,
         )
     except Exception as e:
-        reply = f"(Generation error: {e})"
+        user_reply = f"(Generation error: {e})"
+
+    # Add the new user message to the history (with an empty assistant slot)
+    chat_history = chat_history + [(user_reply, "")]
 
-    chat_history = chat_history + [(user_message, reply)]
     return chat_history, chat_history
 
 
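To make the two flows in the new docstring concrete, this is how a caller exercises `respond` as changed above (the intent string and assistant reply are illustrative):

```python
INTENT = "You are a user who wants help writing a haiku about autumn."

# First turn: empty history and an empty assistant box -> the model opens the conversation.
history, visible = respond(
    assistant_message="",
    chat_history=[],
    system_prompt=INTENT,
    max_new_tokens=256,
    temperature=1.0,  # paper-recommended sampling settings
    top_p=0.8,
)
# history == [(first_user_msg, "")] -- the assistant slot stays empty for now.

# Continuation: the assistant's reply fills that slot, then a new user turn is appended.
history, visible = respond(
    assistant_message="Sure! What mood should the haiku convey?",
    chat_history=history,
    system_prompt=INTENT,
    max_new_tokens=256,
    temperature=1.0,
    top_p=0.8,
)
```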
@@ -220,86 +235,62 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
         f"""
         # UserLM-8b: User Language Model Demo
 
-        **
-
-
-
-        - Enforces length thresholds (3-50 words per turn)
-        - Prevents verbatim repetition of prior turns
-        - Uses recommended sampling params: temp=1.0, top_p=0.8
+        **How to use:**
+        1. Set the user's intent in the box below (what the user wants to accomplish)
+        2. Click **Generate User Message** to create the first user message
+        3. Type assistant responses and click Generate to continue the conversation
 
-        **
-        The system prompt defines the user's high-level intent.
+        **Model:** `{MODEL_ID}` on **{device}**
         """
     )
 
     with gr.Row():
         system_box = gr.Textbox(
-            label="User Intent
+            label="User Intent",
             value=DEFAULT_SYSTEM_PROMPT,
             lines=3,
             placeholder="Enter a high-level user intent (e.g., 'You are a user who wants to...')",
         )
 
-    chatbot = gr.Chatbot(height=420, label="
+    chatbot = gr.Chatbot(height=420, label="Conversation")
 
     with gr.Row():
         msg = gr.Textbox(
-            label="Assistant Response",
-            placeholder="
+            label="Assistant Response (optional for first turn)",
+            placeholder="Leave empty to generate the first user message, or type the assistant's response to continue",
             lines=2,
         )
 
-    with gr.Accordion(
-
-
-        max_new_tokens = gr.Slider(
-            16,
-            512,
-            value=256,
-            step=16,
-            label="max_new_tokens",
-            info="Max tokens per user turn. Paper used stricter limits for simulation.",
-        )
-        temperature = gr.Slider(
-            0.0,
-            2.0,
-            value=1.0,
-            step=0.05,
-            label="temperature",
-            info="Paper recommends 1.0 for realistic user diversity",
-        )
-        top_p = gr.Slider(
-            0.0,
-            1.0,
-            value=0.8,
-            step=0.01,
-            label="top_p",
-            info="Paper recommends 0.8 (not 0.9)",
-        )
+    with gr.Accordion("Generation Settings", open=False):
+        max_new_tokens = gr.Slider(16, 512, value=256, step=16, label="max_new_tokens")
+        temperature = gr.Slider(0.0, 2.0, value=1.0, step=0.05, label="temperature")
+        top_p = gr.Slider(0.0, 1.0, value=0.8, step=0.01, label="top_p")
 
     with gr.Row():
-        submit_btn = gr.Button("Generate User
+        submit_btn = gr.Button("Generate User Message", variant="primary")
         clear_btn = gr.Button("Clear")
 
     state = gr.State([])  # chat history state: List[Tuple[user, assistant]]
 
-    gr.Markdown(
+    with gr.Accordion("Implementation Details", open=False):
+        gr.Markdown(
+            """
+            ### Generation Strategy
+
+            Based on [Appendix C.1](https://arxiv.org/abs/2510.06552) of the UserLM paper, this demo implements:
+            - **Recommended sampling:** temp=1.0, top_p=0.8 (rather than the typical temp=0.8, top_p=0.9)
+            - **First-token filtering:** blocks problematic opening tokens (I, You, Here) that lead to repetitive turns
+            - **Length constraints:** 3-50 words per turn, so the simulated user does not reveal the entire intent at once
+            - **Repetition filtering:** prevents verbatim copies of prior turns
+
+            These guardrails are essential for the 8B model to produce realistic user behavior.
+
+            **Note:** Unlike assistant LMs, UserLM simulates the human *user* side of a conversation.
             """
-        ### Usage Tips:
-        - The **system prompt** defines the user's goal (keep it high-level, not overly specific)
-        - Type what the **assistant says** in response
-        - Click **Generate User Response** to simulate how a human user would reply
-        - UserLM naturally reveals intent across multiple turns, not all at once
-        """
-    )
-
-    def _submit(user_text, history, system_prompt, mnt, temp, tp):
-        if not user_text or not user_text.strip():
-            return gr.update(), history
-        new_history, visible = respond(
-            user_text.strip(), history, system_prompt, mnt, temp, tp
         )
+
+    def _submit(asst_text, history, system_prompt, mnt, temp, tp):
+        new_history, visible = respond(asst_text, history, system_prompt, mnt, temp, tp)
         return "", visible
 
     submit_btn.click(
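The "Generation Strategy" notes added in this hunk describe guardrails (first-token filtering, a 3-50 word length window, repetition filtering) whose implementation lives outside the diff, presumably inside `generate_reply`. A minimal sketch of how such checks could wrap a Hugging Face `generate` call; the function name, the resample loop, and the use of `bad_words_ids` are assumptions for illustration, not the app's actual code:

```python
def sample_user_turn(model, tokenizer, prompt_ids, prior_turns, max_new_tokens=256, max_tries=5):
    """Hypothetical sketch: resample until a candidate passes the guardrails."""
    # Discourage the problematic openers. Note that bad_words_ids bans these
    # token sequences anywhere in the output, whereas the accordion text only
    # mentions filtering the *first* token -- a stricter variant would check that.
    bad_words_ids = tokenizer(["I", "You", "Here"], add_special_tokens=False).input_ids

    text = ""
    for _ in range(max_tries):
        out = model.generate(
            prompt_ids,
            do_sample=True,
            temperature=1.0,  # recommended sampling params
            top_p=0.8,
            max_new_tokens=max_new_tokens,
            bad_words_ids=bad_words_ids,
        )
        text = tokenizer.decode(out[0, prompt_ids.shape[1]:], skip_special_tokens=True).strip()
        if not 3 <= len(text.split()) <= 50:  # length threshold (3-50 words)
            continue
        if any(text == turn.strip() for turn in prior_turns):  # verbatim repetition
            continue
        return text
    return text  # every sample failed a filter; fall back to the last one
```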
@@ -326,4 +317,4 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
     clear_btn.click(_clear, outputs=[state, system_box, chatbot, msg])
 
 if __name__ == "__main__":
-    demo.queue().launch()
+    demo.queue().launch()
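The `submit_btn.click(` call itself is cut off at the hunk boundary. Given the components and the `_submit` signature above, the wiring plausibly looks like the sketch below; the argument order and outputs are inferred, not shown in the diff. Since `_submit` returns only `("", visible)`, the portion of the call not shown here presumably also routes the updated history back into `state`:

```python
# Hypothetical wiring, inferred from the component names and _submit's signature.
submit_btn.click(
    _submit,
    inputs=[msg, state, system_box, max_new_tokens, temperature, top_p],
    outputs=[msg, chatbot],
)
```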