pszemraj committed on
Commit
d9a9b75
·
verified ·
1 Parent(s): 5eed9b0

Update app.py

Browse files

actual last attempt

Files changed (1) hide show
  1. app.py +34 -65
app.py CHANGED
@@ -65,35 +65,6 @@ model.eval()
65
  # ----------------------
66
 
67
 
68
- def build_messages_for_userlm(
69
- system_prompt: str, history: List[Tuple[str, str]]
70
- ) -> List[Dict[str, str]]:
71
- """Build messages for UserLM generation.
72
-
73
- In history tuples: (user_msg, assistant_msg) where:
74
- - user_msg: what UserLM previously generated
75
- - assistant_msg: what the human (playing assistant) said
76
-
77
- For UserLM training, these roles were flipped, so we need to reconstruct
78
- the conversation as UserLM saw it during training.
79
- """
80
- messages: List[Dict[str, str]] = []
81
-
82
- # System prompt defines the user's intent
83
- if system_prompt.strip():
84
- messages.append({"role": "system", "content": system_prompt.strip()})
85
-
86
- # Add conversation history in the format UserLM expects
87
- # UserLM was trained to generate "user" role messages given prior context
88
- for user_msg, assistant_msg in history:
89
- if user_msg:
90
- messages.append({"role": "user", "content": user_msg})
91
- if assistant_msg:
92
- messages.append({"role": "assistant", "content": assistant_msg})
93
-
94
- return messages
95
-
96
-
97
  def is_valid_length(text: str, min_words: int = 3, max_words: int = 50) -> bool:
98
  """Check if generated text meets length requirements (Guardrail 3)."""
99
  word_count = len(text.split())
@@ -101,21 +72,19 @@ def is_valid_length(text: str, min_words: int = 3, max_words: int = 50) -> bool:
101
 
102
 
103
  def is_verbatim_repetition(
104
- new_text: str, history: List[Tuple[str, str]], system_prompt: str
105
  ) -> bool:
106
- """Check if text is exact repetition of prior user turn or system prompt (Guardrail 4).
107
-
108
- History format: (assistant_msg, user_msg) - so user messages are in position 1
109
- """
110
  new_text_normalized = new_text.strip().lower()
111
 
112
  if new_text_normalized == system_prompt.strip().lower():
113
  return True
114
 
115
- # User messages are now in position 1 of the tuple
116
- for _, user_msg in history:
117
- if user_msg and new_text_normalized == user_msg.strip().lower():
118
- return True
 
119
 
120
  return False
121
 
@@ -123,7 +92,7 @@ def is_verbatim_repetition(
123
  @spaces.GPU
124
  def generate_user_message(
125
  messages: List[Dict[str, str]],
126
- history: List[Tuple[str, str]],
127
  system_prompt: str,
128
  max_new_tokens: int = 256,
129
  temperature: float = 1.0,
@@ -174,38 +143,35 @@ def generate_user_message(
174
 
175
  def generate_next_turn(
176
  assistant_response: str,
177
- chat_history: List[Tuple[str, str]],
178
  system_prompt: str,
179
  max_new_tokens: int,
180
  temperature: float,
181
  top_p: float,
182
  ):
183
  """
184
- Tuple: (human_assistant, userlm_user)
185
- Testing if Gradio shows first element on RIGHT not LEFT
 
186
  """
187
 
188
- # If we have an assistant response, fill in position 0 of last turn
189
- if assistant_response.strip() and len(chat_history) > 0:
190
- _, last_user_msg = chat_history[-1]
191
- chat_history = chat_history[:-1] + [(assistant_response.strip(), last_user_msg)]
 
192
 
193
- # Build messages for UserLM - user msgs are now in position 1
194
  messages = []
195
  if system_prompt.strip():
196
  messages.append({"role": "system", "content": system_prompt.strip()})
197
-
198
- for asst_msg, user_msg in chat_history:
199
- if user_msg:
200
- messages.append({"role": "user", "content": user_msg})
201
- if asst_msg:
202
- messages.append({"role": "assistant", "content": asst_msg})
203
 
204
  # Generate next user message
205
  try:
206
  user_msg = generate_user_message(
207
  messages,
208
- [(u, a) for a, u in chat_history], # Swap for repetition check
209
  system_prompt,
210
  max_new_tokens=max_new_tokens,
211
  temperature=temperature,
@@ -214,14 +180,14 @@ def generate_next_turn(
214
  except Exception as e:
215
  user_msg = f"(Generation error: {e})"
216
 
217
- # Add: (empty_assistant, new_user_msg)
218
- new_history = chat_history + [("", user_msg)]
219
 
220
  return "", new_history, "Generate Next User Message"
221
 
222
 
223
  def clear_conversation():
224
- return [], DEFAULT_SYSTEM_PROMPT, "Generate First User Message"
225
 
226
 
227
  # ----------------------
@@ -270,7 +236,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
270
  top_p = gr.Slider(0.0, 1.0, value=0.8, step=0.01, label="top_p")
271
 
272
  with gr.Row():
273
- submit_btn = gr.Button("Generate First User Message", variant="primary")
274
  clear_btn = gr.Button("Clear")
275
 
276
  state = gr.State([])
@@ -291,23 +257,26 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
291
  )
292
 
293
  def _submit(asst_text, history, system_prompt, mnt, temp, tp):
294
- return generate_next_turn(asst_text, history, system_prompt, mnt, temp, tp)
 
 
 
295
 
296
  submit_btn.click(
297
  fn=_submit,
298
  inputs=[msg, state, system_box, max_new_tokens, temperature, top_p],
299
- outputs=[msg, state, submit_btn],
300
  )
301
  msg.submit(
302
  fn=_submit,
303
  inputs=[msg, state, system_box, max_new_tokens, temperature, top_p],
304
- outputs=[msg, state, submit_btn],
305
  )
306
 
307
- # Keep chatbot display in sync with state
308
- state.change(lambda x: x, inputs=[state], outputs=[chatbot])
309
-
310
- clear_btn.click(fn=clear_conversation, outputs=[state, system_box, submit_btn])
311
 
312
  if __name__ == "__main__":
313
  demo.queue().launch()
 
65
  # ----------------------
66
 
67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  def is_valid_length(text: str, min_words: int = 3, max_words: int = 50) -> bool:
69
  """Check if generated text meets length requirements (Guardrail 3)."""
70
  word_count = len(text.split())
 
72
 
73
 
74
  def is_verbatim_repetition(
75
+ new_text: str, history: List[Dict], system_prompt: str
76
  ) -> bool:
77
+ """Check if text is exact repetition. History is now list of message dicts."""
 
 
 
78
  new_text_normalized = new_text.strip().lower()
79
 
80
  if new_text_normalized == system_prompt.strip().lower():
81
  return True
82
 
83
+ # Check against previous user messages
84
+ for msg in history:
85
+ if msg.get("role") == "user" and msg.get("content"):
86
+ if new_text_normalized == msg["content"].strip().lower():
87
+ return True
88
 
89
  return False
90
 
 
92
  @spaces.GPU
93
  def generate_user_message(
94
  messages: List[Dict[str, str]],
95
+ history: List[Dict],
96
  system_prompt: str,
97
  max_new_tokens: int = 256,
98
  temperature: float = 1.0,
 
143
 
144
  def generate_next_turn(
145
  assistant_response: str,
146
+ chat_history: List[Dict],
147
  system_prompt: str,
148
  max_new_tokens: int,
149
  temperature: float,
150
  top_p: float,
151
  ):
152
  """
153
+ History format: List of {"role": "user"/"assistant", "content": "..."}
154
+ - "user" role = UserLM (displays LEFT)
155
+ - "assistant" role = Human (displays RIGHT)
156
  """
157
 
158
+ # If we have an assistant response, add it to history
159
+ if assistant_response.strip():
160
+ chat_history.append(
161
+ {"role": "assistant", "content": assistant_response.strip()}
162
+ )
163
 
164
+ # Build messages for UserLM from history
165
  messages = []
166
  if system_prompt.strip():
167
  messages.append({"role": "system", "content": system_prompt.strip()})
168
+ messages.extend(chat_history)
 
 
 
 
 
169
 
170
  # Generate next user message
171
  try:
172
  user_msg = generate_user_message(
173
  messages,
174
+ chat_history,
175
  system_prompt,
176
  max_new_tokens=max_new_tokens,
177
  temperature=temperature,
 
180
  except Exception as e:
181
  user_msg = f"(Generation error: {e})"
182
 
183
+ # Add new user message to history
184
+ new_history = chat_history + [{"role": "user", "content": user_msg}]
185
 
186
  return "", new_history, "Generate Next User Message"
187
 
188
 
189
  def clear_conversation():
190
+ return [], DEFAULT_SYSTEM_PROMPT, [], "Generate First User Message", []
191
 
192
 
193
  # ----------------------
 
236
  top_p = gr.Slider(0.0, 1.0, value=0.8, step=0.01, label="top_p")
237
 
238
  with gr.Row():
239
+ submit_btn = gr.Button("Generate User Message", variant="primary")
240
  clear_btn = gr.Button("Clear")
241
 
242
  state = gr.State([])
 
257
  )
258
 
259
  def _submit(asst_text, history, system_prompt, mnt, temp, tp):
260
+ new_msg, new_history = generate_next_turn(
261
+ asst_text, history, system_prompt, mnt, temp, tp
262
+ )
263
+ return new_msg, new_history, new_history
264
 
265
  submit_btn.click(
266
  fn=_submit,
267
  inputs=[msg, state, system_box, max_new_tokens, temperature, top_p],
268
+ outputs=[msg, state, chatbot],
269
  )
270
  msg.submit(
271
  fn=_submit,
272
  inputs=[msg, state, system_box, max_new_tokens, temperature, top_p],
273
+ outputs=[msg, state, chatbot],
274
  )
275
 
276
+ clear_btn.click(
277
+ fn=clear_conversation,
278
+ outputs=[state, system_box, chatbot],
279
+ )
280
 
281
  if __name__ == "__main__":
282
  demo.queue().launch()