from __future__ import annotations

import os
from typing import Dict, List, Optional, Tuple

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.logits_process import LogitsProcessor, LogitsProcessorList

# ======================
# Config
# ======================
MODEL_ID = os.getenv("MODEL_ID", "microsoft/UserLM-8b")

DEFAULT_SYSTEM_PROMPT = (
    "You are a user who wants to compute rolling 7-day averages over uneven time stamps. "
    "You are suspicious of resampling magic and will accuse the assistant of witchcraft if it's not explicit."
)


# ======================
# Load model
# ======================
def load_model(model_id: str = MODEL_ID):
    tok = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    mdl = AutoModelForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )

    # Special tokens: <|eot_id|> ends a turn; <|endconversation|> must be blocked.
    eot = "<|eot_id|>"
    end_conv = "<|endconversation|>"
    eot_ids = tok.encode(eot, add_special_tokens=False)
    end_conv_ids = tok.encode(end_conv, add_special_tokens=False)

    eos_token_id = eot_ids[0] if len(eot_ids) > 0 else tok.eos_token_id
    bad_words_ids = [[tid] for tid in end_conv_ids] if len(end_conv_ids) > 0 else None

    # Guardrail 1: problematic first tokens (Appendix C.1), encoded without a
    # leading space to match the first position after the chat-template header.
    prob_first_tokens = ["I", "You", "Here", "i", "you", "here"]
    first_token_filter_ids = []
    for w in prob_first_tokens:
        ids = tok.encode(w, add_special_tokens=False)
        if ids:
            first_token_filter_ids.append(ids[0])

    return tok, mdl, eos_token_id, bad_words_ids, first_token_filter_ids


tokenizer, model, EOS_TOKEN_ID, BAD_WORDS_IDS, FIRST_TOKEN_FILTER_IDS = load_model()
model.generation_config.eos_token_id = EOS_TOKEN_ID
model.generation_config.pad_token_id = tokenizer.eos_token_id
model.eval()


# ======================
# Guardrail helpers
# ======================
def is_valid_length(text: str, min_words: int = 3, max_words: int = 25) -> bool:
    wc = len(text.split())
    return min_words <= wc <= max_words


def is_verbatim_repetition(
    new_text: str, history_pairs: List[Tuple[str, Optional[str]]], system_prompt: str
) -> bool:
    t = new_text.strip().lower()
    if t == system_prompt.strip().lower():
        return True
    for model_user, _ in history_pairs:
        if model_user and t == model_user.strip().lower():
            return True
    return False


class ForbidFirstToken(LogitsProcessor):
    """Set -inf on a token list for the *first* generated token only."""

    def __init__(self, forbid_ids: List[int], prompt_len: int):
        self.forbid = list(set(int(x) for x in forbid_ids))
        self.prompt_len = int(prompt_len)

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        # Apply only when generating the very first token (seq len == prompt_len).
        if input_ids.shape[1] == self.prompt_len and self.forbid:
            scores[:, self.forbid] = float("-inf")
        return scores
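
# Illustrative sketch (not executed): how the first-token guardrail plugs into
# generate(). `toy_ids` is a made-up stand-in prompt; the real call site is
# generate_reply() below. Kept as a comment so importing this module never
# triggers an extra generation pass.
#
#   toy_ids = tokenizer("Hello", return_tensors="pt").input_ids.to(model.device)
#   lp = LogitsProcessorList(
#       [ForbidFirstToken(FIRST_TOKEN_FILTER_IDS, prompt_len=toy_ids.shape[1])]
#   )
#   _ = model.generate(toy_ids, logits_processor=lp, max_new_tokens=8)
#
# The processor fires only while input_ids.shape[1] == prompt_len, i.e. for the
# very first sampled token; all later steps see unmodified logits.
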
# ======================
# Message utilities
# ======================
def build_hf_messages(
    system_prompt: str, history_pairs: List[Tuple[str, Optional[str]]]
) -> List[Dict[str, str]]:
    """
    Construct messages for tokenizer.apply_chat_template.

    history_pairs is a list of (model_user, human_assistant) tuples.
    """
    msgs: List[Dict[str, str]] = []
    if system_prompt.strip():
        msgs.append({"role": "system", "content": system_prompt.strip()})
    for model_user, human_assistant in history_pairs:
        if model_user:
            msgs.append({"role": "user", "content": model_user})
        if human_assistant:
            msgs.append({"role": "assistant", "content": human_assistant})
    return msgs


def pairs_to_ui_messages(
    history_pairs: List[Tuple[str, Optional[str]]]
) -> List[Dict[str, str]]:
    """
    Convert (model_user, human_assistant) pairs to Gradio Chatbot(type='messages')
    UI messages.

    Visual convention:
    - LEFT (role='assistant'): UserLM's utterances (the simulator)
    - RIGHT (role='user'): your replies (you play the assistant)
    """
    ui: List[Dict[str, str]] = []
    for model_user, human_assistant in history_pairs:
        if model_user:
            ui.append({"role": "assistant", "content": model_user})
        if human_assistant:
            ui.append({"role": "user", "content": human_assistant})
    return ui


# ======================
# Generation
# ======================
@spaces.GPU
def generate_reply(
    system_prompt: str,
    history_pairs: List[Tuple[str, Optional[str]]],
    max_new_tokens: int = 128,
    temperature: float = 1.0,
    top_p: float = 0.8,
    max_retries: int = 10,
) -> str:
    """Implements the 4 guardrails from Appendix C.1 and passes an explicit attention_mask."""
    messages = build_hf_messages(system_prompt, history_pairs)
    inputs = tokenizer.apply_chat_template(
        messages, return_tensors="pt", add_generation_prompt=True
    ).to(model.device)

    # Robust attention mask even when pad_token_id == eos_token_id.
    # If no padding is present (the usual single-sequence case), use all-ones.
    pad_id = (
        tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
    )
    if pad_id is not None and (inputs == pad_id).any():
        attention_mask = (inputs != pad_id).long()
    else:
        attention_mask = torch.ones_like(inputs, dtype=torch.long)

    for _ in range(max_retries):
        lp = LogitsProcessorList(
            [ForbidFirstToken(FIRST_TOKEN_FILTER_IDS, prompt_len=inputs.shape[1])]
        )
        with torch.no_grad():
            out = model.generate(
                input_ids=inputs,
                attention_mask=attention_mask,  # explicit mask: silences the warning & is robust
                do_sample=True,
                top_p=top_p,
                temperature=temperature,
                max_new_tokens=max_new_tokens,
                eos_token_id=EOS_TOKEN_ID,
                pad_token_id=tokenizer.eos_token_id,
                bad_words_ids=BAD_WORDS_IDS,  # Guardrail 2: block <|endconversation|>
                logits_processor=lp,  # Guardrail 1: first-token filter
            )
        gen = out[0][inputs.shape[1]:]
        text = tokenizer.decode(gen, skip_special_tokens=True).strip()

        # Guardrails 3 & 4: length bounds and verbatim-repetition filter.
        if not is_valid_length(text, min_words=3, max_words=25):
            continue
        if is_verbatim_repetition(text, history_pairs, system_prompt):
            continue
        return text

    raise RuntimeError("Failed to generate a valid user utterance after retries.")
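
# Illustrative sketch (not executed): driving generate_reply() across two
# turns. The assistant string below is a made-up placeholder, not model
# output; in the app, respond() performs exactly this bookkeeping.
#
#   pairs: List[Tuple[str, Optional[str]]] = []
#   opening = generate_reply(DEFAULT_SYSTEM_PROMPT, pairs)         # UserLM speaks first
#   pairs = [(opening, "pandas can do df.rolling('7D').mean().")]  # you answer
#   follow_up = generate_reply(DEFAULT_SYSTEM_PROMPT, pairs)       # UserLM reacts
#
# Every call re-renders the full (user, assistant) history through the chat
# template, so the model always conditions on the whole conversation rather
# than an incremental delta.
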
# ======================
# Gradio UI
# ======================
def respond(
    your_reply: str,
    history_pairs: List[Tuple[str, Optional[str]]],
    system_prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
):
    # First turn: ignore your_reply and generate the initial UserLM utterance.
    if not history_pairs:
        userlm = generate_reply(
            system_prompt,
            [],
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
        )
        history_pairs = [(userlm, None)]
        return pairs_to_ui_messages(history_pairs), history_pairs, ""

    # Subsequent turns require your reply.
    if not your_reply.strip():
        gr.Info("Type your (assistant) reply on the right, then click Generate.")
        return pairs_to_ui_messages(history_pairs), history_pairs, ""

    # Close the last pair with your reply.
    last_userlm, _ = history_pairs[-1]
    history_pairs[-1] = (last_userlm, your_reply.strip())

    # Generate the next UserLM utterance.
    userlm = generate_reply(
        system_prompt,
        history_pairs,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
    )
    history_pairs.append((userlm, None))
    return pairs_to_ui_messages(history_pairs), history_pairs, ""


def _clear():
    return [], [], DEFAULT_SYSTEM_PROMPT, ""


with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        f"""
# UserLM-8b: User Language Model Demo

**Model:** `{MODEL_ID}`

The AI plays the **user**, you play the **assistant**. Your messages appear on the **right**.
"""
    )

    system_box = gr.Textbox(
        label="User Intent",
        value=DEFAULT_SYSTEM_PROMPT,
        lines=3,
        placeholder="Enter the user's goal or intent",
    )

    # Use the messages format so we can control left/right placement explicitly.
    chatbot = gr.Chatbot(
        label="Conversation",
        height=420,
        type="messages",  # modern format; tuples are deprecated
        render_markdown=True,
        autoscroll=True,
        show_copy_button=True,
        # You can set avatar images like: avatar_images=("assets/you.png", "assets/userlm.png")
    )

    # Your reply box (you play the assistant).
    msg = gr.Textbox(
        label="Your Reply (assistant)",
        placeholder="Type your assistant response here…",
        info="Leave blank & press _Generate_ to create the **first user message**.",
        lines=2,
    )

    with gr.Accordion("Generation Settings", open=False):
        max_new_tokens = gr.Slider(16, 512, value=128, step=16, label="max_new_tokens")
        # Lower bound of 0.05: transformers rejects temperature=0.0 when do_sample=True.
        temperature = gr.Slider(0.05, 2.0, value=1.0, step=0.05, label="temperature")
        top_p = gr.Slider(0.0, 1.0, value=0.8, step=0.01, label="top_p")

    with gr.Row():
        submit_btn = gr.Button("Generate", variant="primary")
        clear_btn = gr.Button("Clear")

    # Internal state keeps the compact (userLM, you) pairs used for decoding.
    history_pairs_state = gr.State([])  # List[Tuple[str, Optional[str]]]

    with gr.Accordion("Implementation Details", open=False):
        gr.Markdown(
            """
- Decoding defaults from [the model card](https://hf.co/microsoft/UserLM-8b):
  `temperature=1.0`, `top_p=0.8`, stop on `<|eot_id|>`, and block `<|endconversation|>`.
- Guardrails from Appendix C.1 [of the paper](https://arxiv.org/abs/2510.06552):
  (1) first-token logit filter, (2) block endconversation, (3) 3–25 word length,
  (4) verbatim-repetition filter.
"""
        )

    def _submit(your_text, pairs, sys_prompt, mnt, temp, tp):
        return respond(your_text, pairs, sys_prompt, mnt, temp, tp)

    submit_btn.click(
        fn=_submit,
        inputs=[msg, history_pairs_state, system_box, max_new_tokens, temperature, top_p],
        outputs=[chatbot, history_pairs_state, msg],
    )
    msg.submit(
        fn=_submit,
        inputs=[msg, history_pairs_state, system_box, max_new_tokens, temperature, top_p],
        outputs=[chatbot, history_pairs_state, msg],
    )
    clear_btn.click(
        fn=_clear,
        outputs=[chatbot, history_pairs_state, system_box, msg],
    )


if __name__ == "__main__":
    demo.queue().launch()
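
# Usage note: MODEL_ID is read from the environment (see Config at the top),
# so another checkpoint can be tried without editing this file, e.g.
# (assuming this file is saved as app.py):
#
#   MODEL_ID=microsoft/UserLM-8b python app.py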