import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
from threading import Thread
import spaces
import time
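# A requirements.txt for this Space is not shown here; it would most likely need
# gradio, torch, transformers, accelerate, spaces, and autoawq (for the AWQ
# checkpoint). Treat that list as an assumption, not the original pins.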
# Global variables for model and tokenizer
model = None
tokenizer = None


def load_model():
    """Load the model and tokenizer lazily (called from the GPU-decorated handlers)."""
    global model, tokenizer

    if model is None:
        print("Loading model...")

        # Configure quantization (optional, remove if not needed)
        # nf4_config = BitsAndBytesConfig(
        #     load_in_4bit=True,
        #     bnb_4bit_quant_type="nf4",
        #     bnb_4bit_use_double_quant=True,
        #     bnb_4bit_compute_dtype=torch.bfloat16
        # )

        # Load model - adjust model name and settings as needed
        model = AutoModelForCausalLM.from_pretrained(
            "Rorical/0-roleplay-20240814-AWQ",
            device_map="auto",  # Changed from "cuda" to "auto"
            # attn_implementation="flash_attention_2",
            torch_dtype=torch.float16,
            # quantization_config=nf4_config,  # Uncomment if using quantization
        )

        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(
            "Rorical/0-roleplay-20240814-AWQ"
        )

        # Set custom chat template (ChatML-style, using the character name as the role)
        tokenizer.chat_template = "{% for message in messages %}{{'<|im_start|>' + ((message['role'] + '\n') if message['role'] != '' else '') + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}"

        print("Model loaded successfully!")
@spaces.GPU  # Requests a GPU for this call on ZeroGPU Spaces
def chat_response(message, history, character_name="ηη", max_tokens=512, temperature=0.7, top_p=0.9):
    """Generate a chat response using the loaded model."""
    global model, tokenizer

    # Load model if not already loaded
    if model is None or tokenizer is None:
        load_model()

    # Convert history to the expected format for the model
    messages = []
    for msg in history:
        if msg['role'] == 'user':
            messages.append({"role": "user", "content": msg['content']})
        elif msg['role'] == 'assistant':
            messages.append({"role": character_name, "content": msg['content']})

    # Add current message
    messages.append({"role": "user", "content": message})

    try:
        # Format chat using tokenizer template
        formatted_chat = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False
        )
        # Open the assistant turn so the model continues as the character
        formatted_chat += f"<|im_start|>{character_name}\n"

        # Tokenize input
        inputs = tokenizer(formatted_chat, return_tensors="pt", add_special_tokens=False)

        # Move to the same device as the model
        inputs = inputs.to(model.device)

        # Generate response
        with torch.no_grad():
            start_time = time.time()
            generate_ids = model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                pad_token_id=tokenizer.eos_token_id,
                do_sample=True,
                use_cache=True,
            )
            end_time = time.time()

        # Decode only the newly generated tokens (skip the prompt)
        response_text = tokenizer.decode(
            generate_ids[0, inputs['input_ids'].size(1):],
            skip_special_tokens=True
        ).strip()

        print(f"Generation time: {end_time - start_time:.2f}s")
        return response_text

    except Exception as e:
        print(f"Error during generation: {e}")
        return f"Sorry, I encountered an error: {str(e)}"
@spaces.GPU  # Requests a GPU for this call on ZeroGPU Spaces
def chat_response_streaming(message, history, character_name="ηη", max_tokens=512, temperature=0.7, top_p=0.9):
    """Generate a streaming chat response."""
    global model, tokenizer

    # Load model if not already loaded
    if model is None or tokenizer is None:
        load_model()

    # Convert history to the expected format for the model
    messages = []
    for msg in history:
        if msg['role'] == 'user':
            messages.append({"role": "user", "content": msg['content']})
        elif msg['role'] == 'assistant':
            messages.append({"role": character_name, "content": msg['content']})

    # Add current message
    messages.append({"role": "user", "content": message})

    try:
        # Format chat using tokenizer template
        formatted_chat = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False
        )
        formatted_chat += f"<|im_start|>{character_name}\n"

        # Tokenize input
        inputs = tokenizer(formatted_chat, return_tensors="pt", add_special_tokens=False)

        # Move to the same device as the model
        inputs = inputs.to(model.device)

        # Set up streaming
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = dict(
            inputs,
            streamer=streamer,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=True,
            use_cache=True,
        )

        # Start generation in a separate thread so the streamer can be consumed here
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        # Stream the response chunks as they arrive
        partial_response = ""
        for new_text in streamer:
            if new_text:
                partial_response += new_text
                yield partial_response

        thread.join()

    except Exception as e:
        print(f"Error during streaming generation: {e}")
        yield f"Sorry, I encountered an error: {str(e)}"
def create_interface():
    """Create the Gradio interface."""

    # Custom CSS for better styling
    css = """
    .gradio-container {
        max-width: 1200px !important;
        margin: auto !important;
    }
    .chat-message {
        font-size: 16px !important;
    }
    """

    with gr.Blocks(css=css, title="AI Roleplay Chatbot", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🤖 AI Roleplay Chatbot")
        gr.Markdown("Chat with an AI character using advanced language modeling.")
        with gr.Row():
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(
                    label="Chat",
                    height=500,
                    show_label=False,
                    container=True,
                    elem_classes=["chat-message"],
                    type='messages'
                )

                with gr.Row():
                    msg = gr.Textbox(
                        label="Message",
                        placeholder="Type your message here...",
                        lines=2,
                        max_lines=4,
                        show_label=False,
                        container=False
                    )
                    send_btn = gr.Button("Send", variant="primary", size="lg")

                with gr.Row():
                    clear_btn = gr.Button("Clear Chat", variant="secondary")
                    retry_btn = gr.Button("Retry Last", variant="secondary")

            with gr.Column(scale=1):
                gr.Markdown("### Settings")

                character_name = gr.Textbox(
                    label="Character Name",
                    value="ηη",
                    placeholder="Enter character name..."
                )

                max_tokens = gr.Slider(
                    label="Max Tokens",
                    minimum=50,
                    maximum=1024,
                    value=512,
                    step=50
                )

                temperature = gr.Slider(
                    label="Temperature",
                    minimum=0.1,
                    maximum=2.0,
                    value=0.7,
                    step=0.1
                )

                top_p = gr.Slider(
                    label="Top P",
                    minimum=0.1,
                    maximum=1.0,
                    value=0.9,
                    step=0.05
                )

                streaming = gr.Checkbox(
                    label="Enable Streaming",
                    value=True
                )
        # Event handlers
        def respond(message, chat_history, char_name, max_tok, temp, top_p_val, use_streaming):
            if use_streaming:
                # For streaming responses, append the user turn and an empty assistant turn first
                chat_history.append({"role": "user", "content": message})
                chat_history.append({"role": "assistant", "content": ""})
                for partial_response in chat_response_streaming(
                    message, chat_history[:-2], char_name, max_tok, temp, top_p_val
                ):
                    chat_history[-1]["content"] = partial_response
                    yield chat_history, ""
            else:
                # For non-streaming responses
                response = chat_response(message, chat_history, char_name, max_tok, temp, top_p_val)
                chat_history.append({"role": "user", "content": message})
                chat_history.append({"role": "assistant", "content": response})
                yield chat_history, ""
        def clear_chat():
            return [], ""

        def retry_last(chat_history, char_name, max_tok, temp, top_p_val, use_streaming):
            # This is a generator, so early exits must yield (a bare return would emit nothing)
            if not chat_history:
                yield chat_history, ""
                return

            # Find the last user message
            last_user_message = None
            for i in range(len(chat_history) - 1, -1, -1):
                if chat_history[i]['role'] == 'user':
                    last_user_message = chat_history[i]['content']
                    # Remove the last user message and any assistant responses after it
                    chat_history = chat_history[:i]
                    break

            if last_user_message is None:
                yield chat_history, ""
                return

            if use_streaming:
                chat_history.append({"role": "user", "content": last_user_message})
                chat_history.append({"role": "assistant", "content": ""})
                for partial_response in chat_response_streaming(
                    last_user_message, chat_history[:-2], char_name, max_tok, temp, top_p_val
                ):
                    chat_history[-1]["content"] = partial_response
                    yield chat_history, ""
            else:
                response = chat_response(last_user_message, chat_history, char_name, max_tok, temp, top_p_val)
                chat_history.append({"role": "user", "content": last_user_message})
                chat_history.append({"role": "assistant", "content": response})
                yield chat_history, ""
        # Connect events
        msg.submit(
            respond,
            inputs=[msg, chatbot, character_name, max_tokens, temperature, top_p, streaming],
            outputs=[chatbot, msg]
        )

        send_btn.click(
            respond,
            inputs=[msg, chatbot, character_name, max_tokens, temperature, top_p, streaming],
            outputs=[chatbot, msg]
        )

        clear_btn.click(clear_chat, outputs=[chatbot, msg])

        retry_btn.click(
            retry_last,
            inputs=[chatbot, character_name, max_tokens, temperature, top_p, streaming],
            outputs=[chatbot, msg]
        )

    return demo
demo = create_interface()
demo.launch(
    show_error=True
)
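# To try the app outside of Spaces (assuming a local GPU is available), run
# `python app.py`; Gradio serves the UI on http://localhost:7860 by default.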