# NOTE: this Space was captured from the Hugging Face Spaces UI while its
# status banner read "Runtime error" — see the pad-token fix below.
import gradio as gr
import time
from transformers import AutoModelForCausalLM, AutoTokenizer
from langchain.memory import ConversationBufferWindowMemory
from peft import PeftModel
import torch
import re

print("Initializing model")

# Initialize the tokenizer and base model.
BASE_MODEL_ID = "mistralai/Mistral-7B-Instruct-v0.2"
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_ID)
# BUG FIX: adding the [PAD] special token grows the tokenizer vocabulary, so
# the model's embedding matrix must be resized to match; otherwise any [PAD]
# id indexes past the embedding table and crashes at inference time.
base_model.resize_token_embeddings(len(tokenizer))
# Attach the fine-tuned LoRA adapter on top of the base weights.
ft_model = PeftModel.from_pretrained(base_model, "nuratamton/story_sculptor_mistral")
# ft_model = ft_model.merge_and_unload()  # optional: fold adapter into base weights
ft_model.eval()

# Run on GPU when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ft_model.to(device)

# Rolling conversation memory: keep only the last 10 exchanges as context.
memory = ConversationBufferWindowMemory(k=10)
def slow_echo(message, history):
    """Stream the agent's reply to the UI one character at a time.

    Gradio calls this with the user's message and the chat history; the
    history argument is unused because context lives in `memory`.
    """
    reply = chat_interface(message)
    shown = ""
    for ch in reply:
        time.sleep(0.05)  # typewriter pacing between characters
        shown += ch
        yield shown
def chat_interface(user_in):
    """Generate one story turn from the fine-tuned model.

    Builds a Mistral [INST] prompt from the rolling conversation memory plus
    the new user input, samples a continuation, records the exchange in
    memory, and returns the model's reply as a string.
    """
    # Robust quit check: ignore surrounding whitespace and letter case.
    if user_in.strip().lower() == "quit":
        return "Goodbye!"

    # Prepend the last-k exchanges so the story stays consistent across turns.
    memory_context = memory.load_memory_variables({})["history"]
    user_input = f"[INST] Continue the game and maintain context and keep the story consistent throughout: {memory_context}{user_in}[/INST]"

    encodings = tokenizer(user_input, return_tensors="pt", padding=True).to(device)
    input_ids = encodings["input_ids"]
    attention_mask = encodings["attention_mask"]

    # inference_mode avoids building autograd state during generation.
    with torch.inference_mode():
        output_ids = ft_model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=1000,
            num_return_sequences=1,
            do_sample=True,
            temperature=1.1,
            top_p=0.9,
            repetition_penalty=1.2,
        )

    # Slice off the prompt tokens; keep only the newly generated continuation.
    generated_ids = output_ids[0, input_ids.shape[-1]:]
    response = tokenizer.decode(generated_ids, skip_special_tokens=True)

    # Persist this turn so the next call sees it in memory_context.
    memory.save_context({"input": user_in}, {"output": response})
    print(f"Game Agent: {response}")
    return response
# Build the streaming chat UI, enable request queuing, and expose it
# via a public share link. ChatInterface.queue() returns the interface
# itself, so splitting the chain is behaviorally identical.
iface = gr.ChatInterface(slow_echo)
iface.queue()
iface.launch(share=True)