import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
from threading import Thread
import spaces
import time

# Global variables for model and tokenizer
model = None
tokenizer = None


def load_model():
    """Load the model and tokenizer - must be GPU decorated for Spaces"""
    global model, tokenizer

    if model is None:
        print("Loading model...")

        # Configure quantization (optional, remove if not needed)
        # nf4_config = BitsAndBytesConfig(
        #     load_in_4bit=True,
        #     bnb_4bit_quant_type="nf4",
        #     bnb_4bit_use_double_quant=True,
        #     bnb_4bit_compute_dtype=torch.bfloat16
        # )

        # Load model - adjust model name and settings as needed
        model = AutoModelForCausalLM.from_pretrained(
            "Rorical/0-roleplay-20240814-AWQ",
            device_map="auto",  # Changed from "cuda" to "auto"
            # attn_implementation="flash_attention_2",
            torch_dtype=torch.float16,
            # quantization_config=nf4_config,  # Uncomment if using quantization
        )

        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(
            "Rorical/0-roleplay-20240814-AWQ"
        )

        # Set custom chat template
        tokenizer.chat_template = "{% for message in messages %}{{'<|im_start|>' + ((message['role'] + '\n') if message['role'] != '' else '') + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}"

        print("Model loaded successfully!")


def chat_response(message, history, character_name="珞璃", max_tokens=512, temperature=0.7, top_p=0.9):
    """Generate chat response using the loaded model"""
    global model, tokenizer

    # Load model if not already loaded
    if model is None or tokenizer is None:
        load_model()

    # Convert history to the expected format for the model
    messages = []
    for msg in history:
        if msg['role'] == 'user':
            messages.append({"role": "user", "content": msg['content']})
        elif msg['role'] == 'assistant':
            messages.append({"role": character_name, "content": msg['content']})

    # Add current message
    messages.append({"role": "user", "content": message})

    try:
        # Format chat using tokenizer template
        formatted_chat = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=False
        )
        formatted_chat += f"<|im_start|>{character_name}\n"

        # Tokenize input
        inputs = tokenizer(formatted_chat, return_tensors="pt", add_special_tokens=False)

        # Move to the same device as model
        inputs = inputs.to(model.device)

        # Generate response
        with torch.no_grad():
            start_time = time.time()
            generate_ids = model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                pad_token_id=tokenizer.eos_token_id,
                do_sample=True,
                use_cache=True,
            )
            end_time = time.time()

        # Decode only the newly generated tokens (skip the prompt)
        response_text = tokenizer.decode(
            generate_ids[0, inputs['input_ids'].size(1):],
            skip_special_tokens=True
        ).strip()

        print(f"Generation time: {end_time - start_time:.2f}s")
        return response_text

    except Exception as e:
        print(f"Error during generation: {e}")
        return f"Sorry, I encountered an error: {str(e)}"

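# For reference, the custom chat template set in load_model() renders each
# turn as "<|im_start|>{role}\n{content}<|im_end|>\n", so a short history
# becomes:
#
#   <|im_start|>user
#   Hello!<|im_end|>
#   <|im_start|>珞璃
#   Hi!<|im_end|>
#
# Both response functions then append an unterminated
# f"<|im_start|>{character_name}\n" so the model continues in character.
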
def chat_response_streaming(message, history, character_name="珞璃", max_tokens=512, temperature=0.7, top_p=0.9):
    """Generate streaming chat response"""
    global model, tokenizer

    # Load model if not already loaded
    if model is None or tokenizer is None:
        load_model()

    # Convert history to the expected format for the model
    messages = []
    for msg in history:
        if msg['role'] == 'user':
            messages.append({"role": "user", "content": msg['content']})
        elif msg['role'] == 'assistant':
            messages.append({"role": character_name, "content": msg['content']})

    # Add current message
    messages.append({"role": "user", "content": message})

    try:
        # Format chat using tokenizer template
        formatted_chat = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=False
        )
        formatted_chat += f"<|im_start|>{character_name}\n"

        # Tokenize input
        inputs = tokenizer(formatted_chat, return_tensors="pt", add_special_tokens=False)

        # Move to the same device as model
        inputs = inputs.to(model.device)

        # Set up streaming
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

        generation_kwargs = dict(
            inputs,
            streamer=streamer,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=True,
            use_cache=True,
        )

        # Start generation in a separate thread
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        # Stream the response
        partial_response = ""
        for new_text in streamer:
            if new_text:
                partial_response += new_text
                yield partial_response

        thread.join()

    except Exception as e:
        print(f"Error during streaming generation: {e}")
        yield f"Sorry, I encountered an error: {str(e)}"

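# Caveat: if model.generate raises inside the worker thread, iterating the
# streamer can block indefinitely. TextIteratorStreamer accepts a timeout in
# seconds, e.g. TextIteratorStreamer(tokenizer, skip_prompt=True,
# skip_special_tokens=True, timeout=60.0), which makes the iterator raise
# instead of hanging - worth enabling in production.
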
def create_interface():
    """Create the Gradio interface"""

    # Custom CSS for better styling
    css = """
    .gradio-container {
        max-width: 1200px !important;
        margin: auto !important;
    }
    .chat-message {
        font-size: 16px !important;
    }
    """

    with gr.Blocks(css=css, title="AI Roleplay Chatbot", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🤖 AI Roleplay Chatbot")
        gr.Markdown("Chat with an AI character using advanced language modeling.")

        with gr.Row():
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(
                    label="Chat",
                    height=500,
                    show_label=False,
                    container=True,
                    elem_classes=["chat-message"],
                    type='messages'
                )

                with gr.Row():
                    msg = gr.Textbox(
                        label="Message",
                        placeholder="Type your message here...",
                        lines=2,
                        max_lines=4,
                        show_label=False,
                        container=False
                    )
                    send_btn = gr.Button("Send", variant="primary", size="lg")

                with gr.Row():
                    clear_btn = gr.Button("Clear Chat", variant="secondary")
                    retry_btn = gr.Button("Retry Last", variant="secondary")

            with gr.Column(scale=1):
                gr.Markdown("### Settings")

                character_name = gr.Textbox(
                    label="Character Name",
                    value="珞璃",
                    placeholder="Enter character name..."
                )

                max_tokens = gr.Slider(
                    label="Max Tokens",
                    minimum=50,
                    maximum=1024,
                    value=512,
                    step=50
                )

                temperature = gr.Slider(
                    label="Temperature",
                    minimum=0.1,
                    maximum=2.0,
                    value=0.7,
                    step=0.1
                )

                top_p = gr.Slider(
                    label="Top P",
                    minimum=0.1,
                    maximum=1.0,
                    value=0.9,
                    step=0.05
                )

                streaming = gr.Checkbox(
                    label="Enable Streaming",
                    value=True
                )

        # Event handlers
        @spaces.GPU(duration=30)
        def respond(message, chat_history, char_name, max_tok, temp, top_p_val, use_streaming):
            if use_streaming:
                # For streaming response - add user message first, plus an
                # empty assistant turn that gets filled in as tokens arrive
                chat_history.append({"role": "user", "content": message})
                chat_history.append({"role": "assistant", "content": ""})

                for partial_response in chat_response_streaming(
                    message, chat_history[:-2], char_name, max_tok, temp, top_p_val
                ):
                    chat_history[-1]["content"] = partial_response
                    yield chat_history, ""
            else:
                # For non-streaming response
                response = chat_response(message, chat_history, char_name, max_tok, temp, top_p_val)
                chat_history.append({"role": "user", "content": message})
                chat_history.append({"role": "assistant", "content": response})
                yield chat_history, ""

        def clear_chat():
            return [], ""

        def retry_last(chat_history, char_name, max_tok, temp, top_p_val, use_streaming):
            # This function is a generator (it yields below), so early exits
            # must yield a value and return bare - `return chat_history, ""`
            # would silently produce no output.
            if not chat_history:
                yield chat_history, ""
                return

            # Find the last user message
            last_user_message = None
            for i in range(len(chat_history) - 1, -1, -1):
                if chat_history[i]['role'] == 'user':
                    last_user_message = chat_history[i]['content']
                    # Remove the last user message and any assistant responses after it
                    chat_history = chat_history[:i]
                    break

            if last_user_message is None:
                yield chat_history, ""
                return

            if use_streaming:
                chat_history.append({"role": "user", "content": last_user_message})
                chat_history.append({"role": "assistant", "content": ""})

                for partial_response in chat_response_streaming(
                    last_user_message, chat_history[:-2], char_name, max_tok, temp, top_p_val
                ):
                    chat_history[-1]["content"] = partial_response
                    yield chat_history, ""
            else:
                response = chat_response(last_user_message, chat_history, char_name, max_tok, temp, top_p_val)
                chat_history.append({"role": "user", "content": last_user_message})
                chat_history.append({"role": "assistant", "content": response})
                yield chat_history, ""

        # Connect events
        msg.submit(
            respond,
            inputs=[msg, chatbot, character_name, max_tokens, temperature, top_p, streaming],
            outputs=[chatbot, msg]
        )

        send_btn.click(
            respond,
            inputs=[msg, chatbot, character_name, max_tokens, temperature, top_p, streaming],
            outputs=[chatbot, msg]
        )

        clear_btn.click(clear_chat, outputs=[chatbot, msg])

        retry_btn.click(
            retry_last,
            inputs=[chatbot, character_name, max_tokens, temperature, top_p, streaming],
            outputs=[chatbot, msg]
        )

    return demo


demo = create_interface()
demo.launch(
    show_error=True
)
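
# Note: the streaming handlers above are Python generators, which require
# Gradio's event queue. Gradio 4.x enables the queue by default; on older 3.x
# versions you would need to enable it explicitly, e.g.
# demo.queue().launch(show_error=True).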