import gradio as gr
import spaces
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import time
from typing import List, Tuple

# Model configuration
MODEL_PATH = "microsoft/UserLM-8b"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Global variables for model and tokenizer
model = None
tokenizer = None


def load_model():
    """Load the model and tokenizer."""
    global model, tokenizer
    print(f"Loading model {MODEL_PATH}...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        trust_remote_code=True,
        torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
        low_cpu_mem_usage=True,
    ).to(DEVICE)
    print(f"Model loaded successfully on {DEVICE}")
    return model, tokenizer

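
# On ZeroGPU Spaces, GPU work is expected to run inside a function decorated
# with `spaces.GPU` (assumption: this Space targets ZeroGPU hardware; outside
# Spaces the decorator is a no-op).
@spaces.GPU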
def generate_response(
    message: str,
    chat_history: List[Tuple[str, str]],
    system_prompt: str,
    temperature: float,
    top_p: float,
    max_new_tokens: int,
) -> str:
    """Generate a response from the model."""
    global model, tokenizer

    # Load model if not already loaded
    if model is None or tokenizer is None:
        model, tokenizer = load_model()

    # Build conversation history
    messages = []

    # Add system prompt if provided
    if system_prompt.strip():
        messages.append({"role": "system", "content": system_prompt})

    # Add chat history
    for user_msg, assistant_msg in chat_history:
        messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})

    # Add current message
    messages.append({"role": "user", "content": message})

    # Tokenize input
    inputs = tokenizer.apply_chat_template(messages, return_tensors="pt").to(DEVICE)

    # Define special tokens
    end_token = "<|eot_id|>"
    end_token_id = tokenizer.encode(end_token, add_special_tokens=False)
    end_conv_token = "<|endconversation|>"
    end_conv_token_id = tokenizer.encode(end_conv_token, add_special_tokens=False)
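    # <|eot_id|> closes a single turn and is used as the stop token below;
    # <|endconversation|> would end the whole simulated conversation, so its
    # token ids are suppressed via bad_words_ids.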
    # Generate response
    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs,
            do_sample=True,
            top_p=top_p,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            eos_token_id=end_token_id,
            pad_token_id=tokenizer.eos_token_id,
            bad_words_ids=[[token_id] for token_id in end_conv_token_id],
        )

    # Decode response
    response = tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
    return response
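

# Hypothetical standalone call (outside the Gradio UI), for illustration only:
#   reply = generate_response(
#       "Let's implement this sequence in Python together.",
#       chat_history=[],
#       system_prompt="You are a user who wants to implement a special sequence.",
#       temperature=1.0,
#       top_p=0.8,
#       max_new_tokens=100,
#   )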


def respond(
    message: str,
    chat_history: List[Tuple[str, str]],
    system_prompt: str,
    temperature: float,
    top_p: float,
    max_new_tokens: int,
):
    """Stream response to the chatbot."""
    # Generate complete response
    bot_message = generate_response(
        message,
        chat_history,
        system_prompt,
        temperature,
        top_p,
        max_new_tokens,
    )

    # Add to chat history
    chat_history.append((message, bot_message))

    # Stream the response character by character for better UX
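    # (The reply was generated in full above; this loop only replays it, so the
    # streaming is cosmetic rather than token-level streaming from the model.)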
    partial_message = ""
    for char in bot_message:
        partial_message += char
        time.sleep(0.01)  # Small delay for streaming effect
        yield chat_history[:-1] + [(message, partial_message)]
    yield chat_history


def clear_conversation():
    """Clear the conversation history."""
    return [], None

# Create the Gradio interface
with gr.Blocks(title="UserLM-8b Chat", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🤖 UserLM-8b Chat Interface
        Chat with Microsoft's UserLM-8b. The model is trained to simulate the user side of a conversation, so its replies read like user turns rather than assistant answers.
        [Built with anycoder](https://huggingface.co/spaces/akhaliq/anycoder)
        """
    )
    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(
                height=500,
                show_copy_button=True,
                bubble_full_width=False,
                avatar_images=(None, "🤖"),
                render_markdown=True,
            )

            with gr.Row():
                msg = gr.Textbox(
                    label="Message",
                    placeholder="Type your message here and press Enter...",
                    lines=2,
                    scale=4,
                    autofocus=True,
                )
                submit_btn = gr.Button("Send", variant="primary", scale=1)

            with gr.Row():
                clear_btn = gr.ClearButton(
                    [chatbot, msg],
                    value="🗑️ Clear Chat"
                )
                retry_btn = gr.Button("🔄 Retry Last")
                undo_btn = gr.Button("↩️ Undo Last")

        with gr.Column(scale=1):
            gr.Markdown("### ⚙️ Settings")

            system_prompt = gr.Textbox(
                label="System Prompt",
                placeholder="Set the behavior of the model...",
                value="You are a user who wants to implement a special type of sequence. The sequence sums up the two previous numbers in the sequence and adds 1 to the result. The first two numbers in the sequence are 1 and 1.",
                lines=4,
            )
            temperature = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=1.0,
                step=0.1,
                label="Temperature",
                info="Higher values make output more random"
            )
            top_p = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.8,
                step=0.05,
                label="Top-p (nucleus sampling)",
                info="Lower values focus on more likely tokens"
            )
            max_new_tokens = gr.Slider(
                minimum=10,
                maximum=512,
                value=100,
                step=10,
                label="Max New Tokens",
                info="Maximum number of tokens to generate"
            )

            gr.Markdown(
                """
                ### 📊 Model Info
                - **Model**: microsoft/UserLM-8b
                - **Parameters**: 8 billion
                - **Device**: """ + DEVICE.upper() + """
                - **Precision**: FP16 (CUDA) / FP32 (CPU)
                """
            )

    # Store conversation history
    chat_history = gr.State([])
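    # Note: this State is not wired into any of the handlers below; the Chatbot
    # component itself carries the visible conversation history.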
    # Event handlers
    def user_submit(message, history):
        return "", history + [(message, None)]

    def bot_respond(history, system, temp, top_p, max_tokens):
        if not history or history[-1][1] is not None:
            # Nothing pending to answer; yield the history unchanged
            yield history
            return
        message = history[-1][0]
        history_without_last = history[:-1]
        for new_history in respond(message, history_without_last, system, temp, top_p, max_tokens):
            yield new_history

    def retry_last(history, system, temp, top_p, max_tokens):
        if not history:
            yield history
            return
        # Remove last exchange and regenerate
        last_user_msg = history[-1][0]
        history = history[:-1]
        for new_history in respond(last_user_msg, history, system, temp, top_p, max_tokens):
            yield new_history

    def undo_last(history):
        if history:
            return history[:-1]
        return history
    # Connect events
    msg.submit(
        user_submit,
        [msg, chatbot],
        [msg, chatbot],
        queue=False
    ).then(
        bot_respond,
        [chatbot, system_prompt, temperature, top_p, max_new_tokens],
        chatbot
    )

    submit_btn.click(
        user_submit,
        [msg, chatbot],
        [msg, chatbot],
        queue=False
    ).then(
        bot_respond,
        [chatbot, system_prompt, temperature, top_p, max_new_tokens],
        chatbot
    )

    retry_btn.click(
        retry_last,
        [chatbot, system_prompt, temperature, top_p, max_new_tokens],
        chatbot
    )

    undo_btn.click(
        undo_last,
        chatbot,
        chatbot
    )
    # Show a startup notice; the model itself is loaded lazily on the first message
    demo.load(
        fn=lambda: gr.Info("The model loads with the first message and may take a moment."),
        inputs=None,
        outputs=None
    )
    # Examples
    gr.Examples(
        examples=[
            ["Can you help me understand how this sequence works?"],
            ["What would be the next 5 numbers in the sequence?"],
            ["Let's implement this sequence in Python together."],
            ["Can you explain the pattern: 1, 1, 3, 5, 9, 15...?"],
        ],
        inputs=msg,
        label="Example Messages",
    )

if __name__ == "__main__":
    demo.launch(
        share=False,
        show_error=True,
        server_name="0.0.0.0",
        server_port=7860,
    )
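
# Assumed Space dependencies (a typical requirements.txt for this app; exact
# versions not pinned here): gradio, spaces, torch, transformers, accelerate
# (`low_cpu_mem_usage=True` relies on accelerate being available).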