from __future__ import annotations

import os
from typing import Any, Dict, List, Tuple

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

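# Model checkpoint (overridable via the MODEL_ID environment variable) and the
# default user intent that seeds the simulated user when the demo starts.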
MODEL_ID = os.getenv("MODEL_ID", "microsoft/UserLM-8b")
DEFAULT_SYSTEM_PROMPT = (
    "You are a user who wants to implement a special type of sequence. "
    "The sequence sums up the two previous numbers in the sequence and adds 1 to the result. "
    "The first two numbers in the sequence are 1 and 1."
)


def load_model(model_id: str = MODEL_ID):
    """Load tokenizer and model, with a reasonable dtype and device fallback."""
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=True,
        torch_dtype="auto",
        device_map="auto",
    )

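    # Token handling for the guardrails: <|eot_id|> marks the end of a single user
    # turn and is used as the EOS token, while <|endconversation|> would end the
    # whole dialogue, so its token ids are collected here to be blocked later.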
    end_token = "<|eot_id|>"
    end_conv_token = "<|endconversation|>"
    end_token_ids = tokenizer.encode(end_token, add_special_tokens=False)
    end_conv_token_ids = tokenizer.encode(end_conv_token, add_special_tokens=False)

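    # Guardrail 1: collect the first token id of generic openers such as "I",
    # "You", and "Here" so they can be suppressed at the start of a user turn.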
    problematic_tokens = ["I", "You", "Here", "i", "you", "here"]
    first_token_filter_ids = []
    for token in problematic_tokens:
        token_ids = tokenizer.encode(token, add_special_tokens=False)
        if len(token_ids) > 0:
            first_token_filter_ids.append(token_ids[0])

    eos_token_id = (
        end_token_ids[0] if len(end_token_ids) > 0 else tokenizer.eos_token_id
    )
    bad_words_ids = (
        [[tid] for tid in end_conv_token_ids] if len(end_conv_token_ids) > 0 else None
    )

    return tokenizer, model, eos_token_id, bad_words_ids, first_token_filter_ids


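# The model is loaded once at import time; the returned handles are module-level
# globals used by the handlers below.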
tokenizer, model, EOS_TOKEN_ID, BAD_WORDS_IDS, FIRST_TOKEN_FILTER_IDS = load_model()
model.eval()


def build_messages(
    system_prompt: str, history: List[Tuple[str, str]]
) -> List[Dict[str, str]]:
    """Transform Gradio history into chat template messages.

    History is stored as (model_user, human_assistant) tuples.
    """
    messages: List[Dict[str, str]] = []
    if system_prompt.strip():
        messages.append({"role": "system", "content": system_prompt.strip()})

    for model_user, human_assistant in history:
        if model_user:
            messages.append({"role": "user", "content": model_user})
        if human_assistant:
            messages.append({"role": "assistant", "content": human_assistant})

    return messages
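
# The messages built above follow the standard chat-template structure, e.g.
# (contents are illustrative):
#   [{"role": "system", "content": "<user intent>"},
#    {"role": "user", "content": "<UserLM turn>"},
#    {"role": "assistant", "content": "<your reply>"}]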


def apply_first_token_filter(
    logits: torch.Tensor, filter_ids: List[int]
) -> torch.Tensor:
    """Apply a logit filter for problematic first tokens (Guardrail 1).

    Kept as a reference helper for manual decoding loops; `generate_reply` applies
    the same guardrail through `begin_suppress_tokens`.
    """
    logits_filtered = logits.clone()
    for token_id in filter_ids:
        logits_filtered[0, -1, token_id] = float("-inf")
    return logits_filtered


def is_valid_length(text: str, min_words: int = 3, max_words: int = 50) -> bool:
    """Check whether generated text meets the length requirements (Guardrail 3).

    The paper used max_words=25 for its simulation experiments; the interactive
    demo uses 50 to allow slightly longer responses while still preventing the
    model from revealing the entire intent at once.
    """
    word_count = len(text.split())
    return min_words <= word_count <= max_words


def is_verbatim_repetition(
    new_text: str, history: List[Tuple[str, str]], system_prompt: str
) -> bool:
    """Check whether text exactly repeats a prior user turn or the system prompt (Guardrail 4)."""
    new_text_normalized = new_text.strip().lower()

    if new_text_normalized == system_prompt.strip().lower():
        return True

    for model_user, _ in history:
        if model_user and new_text_normalized == model_user.strip().lower():
            return True

    return False


@spaces.GPU
def generate_reply(
    messages: List[Dict[str, str]],
    history: List[Tuple[str, str]],
    system_prompt: str,
    max_new_tokens: int = 256,
    temperature: float = 1.0,
    top_p: float = 0.8,
    max_retries: int = 5,
) -> str:
    """Run generation with the guardrails from Appendix C.1.

    Implements all four guardrails from the paper:
    1. Filter problematic first tokens (via `begin_suppress_tokens`)
    2. Avoid premature dialogue termination (`<|endconversation|>` is blocked
       via `bad_words_ids`)
    3. Enforce length thresholds, with retry
    4. Filter verbatim repetitions, with retry
    """

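    # Each attempt samples one candidate user turn; candidates that violate the
    # length or repetition guardrails are discarded and regenerated.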
    for attempt in range(max_retries):
        inputs = tokenizer.apply_chat_template(
            messages,
            return_tensors="pt",
            add_generation_prompt=True,
        ).to(model.device)

        with torch.no_grad():
            outputs = model.generate(
                input_ids=inputs,
                do_sample=True,
                top_p=top_p,
                temperature=temperature,
                max_new_tokens=max_new_tokens,
                eos_token_id=EOS_TOKEN_ID,
                pad_token_id=tokenizer.eos_token_id,
                bad_words_ids=BAD_WORDS_IDS,
                # Guardrail 1: suppress problematic first tokens ("I", "You",
                # "Here", ...) at the start of the generated turn.
                begin_suppress_tokens=FIRST_TOKEN_FILTER_IDS or None,
            )

        generated = outputs[0][inputs.shape[1] :]
        text = tokenizer.decode(generated, skip_special_tokens=True).strip()

        if not is_valid_length(text):
            continue

        if is_verbatim_repetition(text, history, system_prompt):
            continue

        return text

    raise RuntimeError(
        f"Failed to generate valid response after {max_retries} attempts"
    )


def respond(
    assistant_message: str,
    chat_history: List[Tuple[str, str]],
    system_prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
):
    """Generate the next user turn.

    Flow:
    - If the history is empty: generate the first user message (the
      assistant_message input is ignored).
    - Otherwise: record the assistant response and generate the next user turn.

    History format: (model_user, human_assistant) tuples.
    """

    if len(chat_history) == 0:
        messages = build_messages(system_prompt, [])
        user_reply = generate_reply(
            messages,
            chat_history,
            system_prompt,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
        )

        chat_history = [(user_reply, None)]
        return chat_history, chat_history
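
    # A non-empty history means the human should have replied: attach the
    # assistant response to the latest user turn, then sample the next user message.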

    if not assistant_message.strip():
        gr.Info(
            "Please type your assistant response before generating the next user message."
        )
        return chat_history, chat_history

    last_model_user, _ = chat_history[-1]
    chat_history[-1] = (last_model_user, assistant_message.strip())

    messages = build_messages(system_prompt, chat_history)

    user_reply = generate_reply(
        messages,
        chat_history,
        system_prompt,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
    )

    chat_history.append((user_reply, None))

    return chat_history, chat_history


def clear_state():
    """Reset the conversation history and restore the default user intent."""
    return [], DEFAULT_SYSTEM_PROMPT


with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        f"""
        # UserLM-8b: User Language Model Demo

        **Model:** `{MODEL_ID}`

        The AI plays the user; you play the assistant.
        """
    )

    with gr.Row():
        system_box = gr.Textbox(
            label="User Intent",
            value=DEFAULT_SYSTEM_PROMPT,
            lines=3,
            placeholder="Enter the user's goal or intent",
        )

    chatbot = gr.Chatbot(
        height=420,
        label="Conversation",
    )

    with gr.Row():
        msg = gr.Textbox(
            label="Assistant Response",
            placeholder="Leave empty for first generation, then type your responses",
            lines=2,
        )

    with gr.Accordion("Generation Settings", open=False):
        max_new_tokens = gr.Slider(16, 512, value=256, step=16, label="max_new_tokens")
        temperature = gr.Slider(0.0, 2.0, value=1.0, step=0.05, label="temperature")
        top_p = gr.Slider(0.0, 1.0, value=0.8, step=0.01, label="top_p")

    with gr.Row():
        submit_btn = gr.Button("Generate", variant="primary")
        clear_btn = gr.Button("Clear")

    state = gr.State([])
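
    # `state` mirrors the chatbot's (user, assistant) tuples so that `respond`
    # always receives the full history; only the chatbot component is rendered.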

    with gr.Accordion("Implementation Details", open=False):
        gr.Markdown(
            """
            Based on Appendix C.1 of the UserLM paper:
            - Sampling: temp=1.0, top_p=0.8
            - First token filtering for problematic tokens
            - Length constraints: 3-50 words
            - Repetition filtering
            """
        )

    def _submit(asst_text, history, system_prompt, mnt, temp, tp):
        new_history, visible = respond(asst_text, history, system_prompt, mnt, temp, tp)
        return "", visible
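
    # Submitting clears the assistant textbox and updates the chatbot; `state` is
    # refreshed by the chatbot.change handler further below.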

    submit_btn.click(
        fn=_submit,
        inputs=[msg, state, system_box, max_new_tokens, temperature, top_p],
        outputs=[msg, chatbot],
    )
    msg.submit(
        fn=_submit,
        inputs=[msg, state, system_box, max_new_tokens, temperature, top_p],
        outputs=[msg, chatbot],
    )

    def _sync_state(chat):
        return chat

    chatbot.change(_sync_state, inputs=[chatbot], outputs=[state])

    def _clear():
        history, sys = clear_state()
        return history, sys, history, ""

    clear_btn.click(_clear, outputs=[state, system_box, chatbot, msg])


if __name__ == "__main__":
    demo.queue().launch()