Update app.py
actual last attempt
app.py
CHANGED
@@ -65,35 +65,6 @@ model.eval()
 # ----------------------
 
 
-def build_messages_for_userlm(
-    system_prompt: str, history: List[Tuple[str, str]]
-) -> List[Dict[str, str]]:
-    """Build messages for UserLM generation.
-
-    In history tuples: (user_msg, assistant_msg) where:
-    - user_msg: what UserLM previously generated
-    - assistant_msg: what the human (playing assistant) said
-
-    For UserLM training, these roles were flipped, so we need to reconstruct
-    the conversation as UserLM saw it during training.
-    """
-    messages: List[Dict[str, str]] = []
-
-    # System prompt defines the user's intent
-    if system_prompt.strip():
-        messages.append({"role": "system", "content": system_prompt.strip()})
-
-    # Add conversation history in the format UserLM expects
-    # UserLM was trained to generate "user" role messages given prior context
-    for user_msg, assistant_msg in history:
-        if user_msg:
-            messages.append({"role": "user", "content": user_msg})
-        if assistant_msg:
-            messages.append({"role": "assistant", "content": assistant_msg})
-
-    return messages
-
-
 def is_valid_length(text: str, min_words: int = 3, max_words: int = 50) -> bool:
     """Check if generated text meets length requirements (Guardrail 3)."""
     word_count = len(text.split())
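The removed build_messages_for_userlm helper rebuilt the prompt from (user_msg, assistant_msg) tuples. After this commit the history is already stored as chat-style message dicts, so the prompt is assembled by prepending the system message and extending with the history (see the generate_next_turn hunk further down). A minimal sketch of the new shape; the conversation contents here are made up for illustration:

from typing import Dict, List

# Hypothetical example history; in the app it is accumulated by the Gradio callbacks.
system_prompt = "You want help booking a flight to Tokyo."
chat_history: List[Dict[str, str]] = [
    {"role": "user", "content": "Hi, I need to book a flight to Tokyo."},        # generated by UserLM
    {"role": "assistant", "content": "Sure - which dates are you thinking of?"}, # typed by the human
]

# Prompt assembly as done in the new generate_next_turn: system message first,
# then the stored history verbatim.
messages: List[Dict[str, str]] = []
if system_prompt.strip():
    messages.append({"role": "system", "content": system_prompt.strip()})
messages.extend(chat_history)

print([m["role"] for m in messages])  # ['system', 'user', 'assistant']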
@@ -101,21 +72,19 @@ def is_valid_length(text: str, min_words: int = 3, max_words: int = 50) -> bool:
 
 
 def is_verbatim_repetition(
-    new_text: str, history: List[Tuple[str, str]], system_prompt: str
+    new_text: str, history: List[Dict], system_prompt: str
 ) -> bool:
-    """Check if text is exact repetition of the system prompt or a previous user message.
-
-    History format: (assistant_msg, user_msg) - so user messages are in position 1
-    """
+    """Check if text is exact repetition. History is now list of message dicts."""
     new_text_normalized = new_text.strip().lower()
 
     if new_text_normalized == system_prompt.strip().lower():
         return True
 
-    # Check against previous user messages
-    for asst_msg, user_msg in history:
-        if user_msg and new_text_normalized == user_msg.strip().lower():
-            return True
+    # Check against previous user messages
+    for msg in history:
+        if msg.get("role") == "user" and msg.get("content"):
+            if new_text_normalized == msg["content"].strip().lower():
+                return True
 
     return False
 
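A quick usage check of the reworked guard with the dict-based history, assuming the function above is in scope; the strings are invented for the example:

history = [
    {"role": "user", "content": "Can you recommend a laptop?"},
    {"role": "assistant", "content": "What is your budget?"},
]
intent = "You want advice on buying a laptop."

# An exact repeat of an earlier user turn (modulo case/whitespace) is flagged...
assert is_verbatim_repetition("  can you recommend a laptop?", history, intent)
# ...while a genuinely new follow-up is not.
assert not is_verbatim_repetition("My budget is around $800.", history, intent)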
@@ -123,7 +92,7 @@ def is_verbatim_repetition(
 @spaces.GPU
 def generate_user_message(
     messages: List[Dict[str, str]],
-    history: List[Tuple[str, str]],
+    history: List[Dict],
     system_prompt: str,
     max_new_tokens: int = 256,
     temperature: float = 1.0,
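The new history parameter mirrors the signature of is_verbatim_repetition, so presumably it is threaded in so the guardrails can vet each candidate before it is returned; the committed body of generate_user_message is not part of this hunk. A minimal retry loop under that assumption, purely illustrative (generate_once stands in for the actual sampling call):

def generate_with_guardrails(generate_once, history, system_prompt, max_attempts=3):
    """Illustrative sketch: re-sample until a candidate passes the guardrails."""
    candidate = ""
    for _ in range(max_attempts):
        candidate = generate_once()  # e.g. a closure around model.generate + decode
        if is_valid_length(candidate) and not is_verbatim_repetition(
            candidate, history, system_prompt
        ):
            break
    return candidate  # the last sample is returned even if every attempt failed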
@@ -174,38 +143,35 @@ def generate_user_message(
 
 def generate_next_turn(
     assistant_response: str,
-    chat_history: List[Tuple[str, str]],
+    chat_history: List[Dict],
     system_prompt: str,
     max_new_tokens: int,
     temperature: float,
     top_p: float,
 ):
     """
-
-
+    History format: List of {"role": "user"/"assistant", "content": "..."}
+    - "user" role = UserLM (displays LEFT)
+    - "assistant" role = Human (displays RIGHT)
     """
 
-    # If we have an assistant response,
-    if assistant_response.strip()
-
-
+    # If we have an assistant response, add it to history
+    if assistant_response.strip():
+        chat_history.append(
+            {"role": "assistant", "content": assistant_response.strip()}
+        )
 
-    # Build messages for UserLM
+    # Build messages for UserLM from history
     messages = []
     if system_prompt.strip():
         messages.append({"role": "system", "content": system_prompt.strip()})
-
-    for asst_msg, user_msg in chat_history:
-        if user_msg:
-            messages.append({"role": "user", "content": user_msg})
-        if asst_msg:
-            messages.append({"role": "assistant", "content": asst_msg})
+    messages.extend(chat_history)
 
     # Generate next user message
     try:
         user_msg = generate_user_message(
             messages,
-
+            chat_history,
             system_prompt,
             max_new_tokens=max_new_tokens,
             temperature=temperature,
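Two small points about the history bookkeeping in this hunk: the human's reply is appended to chat_history in place, while the freshly generated user turn is added by concatenation, so the caller's list also grows by one even though a new list is returned. A stand-alone illustration with no model call (contents invented):

chat_history = [{"role": "user", "content": "Hi, I need a hotel in Rome."}]
assistant_response = "  For which nights?  "

# Step 1 (as in generate_next_turn): append the human's reply in place.
if assistant_response.strip():
    chat_history.append({"role": "assistant", "content": assistant_response.strip()})

# Step 2: pretend the model produced this next user turn.
user_msg = "Check-in Friday, two nights."
new_history = chat_history + [{"role": "user", "content": user_msg}]

print(len(chat_history), len(new_history))  # 2 3 - the input list was mutated too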
@@ -214,14 +180,14 @@ def generate_next_turn(
     except Exception as e:
         user_msg = f"(Generation error: {e})"
 
-    # Add
-    new_history = chat_history + [
+    # Add new user message to history
+    new_history = chat_history + [{"role": "user", "content": user_msg}]
 
     return "", new_history, "Generate Next User Message"
 
 
 def clear_conversation():
-    return [], DEFAULT_SYSTEM_PROMPT, "Generate First User Message"
+    return [], DEFAULT_SYSTEM_PROMPT, [], "Generate First User Message", []
 
 
 # ----------------------
@@ -270,7 +236,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
     top_p = gr.Slider(0.0, 1.0, value=0.8, step=0.01, label="top_p")
 
     with gr.Row():
-        submit_btn = gr.Button("Generate
+        submit_btn = gr.Button("Generate User Message", variant="primary")
         clear_btn = gr.Button("Clear")
 
     state = gr.State([])
@@ -291,23 +257,26 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
     )
 
     def _submit(asst_text, history, system_prompt, mnt, temp, tp):
-
+        new_msg, new_history = generate_next_turn(
+            asst_text, history, system_prompt, mnt, temp, tp
+        )
+        return new_msg, new_history, new_history
 
     submit_btn.click(
         fn=_submit,
         inputs=[msg, state, system_box, max_new_tokens, temperature, top_p],
-        outputs=[msg, state,
+        outputs=[msg, state, chatbot],
     )
     msg.submit(
         fn=_submit,
         inputs=[msg, state, system_box, max_new_tokens, temperature, top_p],
-        outputs=[msg, state,
+        outputs=[msg, state, chatbot],
    )
 
-
-
-
-
+    clear_btn.click(
+        fn=clear_conversation,
+        outputs=[state, system_box, chatbot],
+    )
 
 if __name__ == "__main__":
     demo.queue().launch()
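A note on the wiring: a Gradio event handler must return one value per component listed in outputs. _submit returns three values for [msg, state, chatbot], which matches, but generate_next_turn as shown still returns three values ("", new_history, "Generate Next User Message") that _submit unpacks into two names, and clear_conversation returns five values against the three outputs wired to clear_btn. A minimal sketch that keeps those arities aligned, offered only as a possible follow-up, not as the committed code:

def _submit(asst_text, history, system_prompt, mnt, temp, tp):
    # generate_next_turn returns (cleared_textbox_value, new_history, button_label);
    # unpack all three and forward only what the outputs below need.
    new_msg, new_history, _label = generate_next_turn(
        asst_text, history, system_prompt, mnt, temp, tp
    )
    return new_msg, new_history, new_history  # -> msg, state, chatbot


def clear_conversation():
    # one value per component in clear_btn.click(outputs=[state, system_box, chatbot])
    return [], DEFAULT_SYSTEM_PROMPT, []


clear_btn.click(fn=clear_conversation, outputs=[state, system_box, chatbot])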