import os
import json
import gradio as gr
from datetime import datetime
from threading import Lock
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
import torch
# ========== Auto-create folders ==========
os.makedirs("chat_history", exist_ok=True)
os.makedirs("system", exist_ok=True)

# ========== Load System Context ==========
context_path = "system/context.txt"
if not os.path.exists(context_path):
    raise FileNotFoundError(f"Missing system context file at {context_path}!")

with open(context_path, "r", encoding="utf-8") as f:
    loaded_context = f.read()
# ========== Simple Chatbot Logic ==========
lock = Lock()

# Provide the folder path, not the file path
model_folder = "model/Mistral-7B-Instruct-v0.3"

# Load the model and tokenizer
model = AutoModelForCausalLM.from_pretrained(model_folder, torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_folder)

# Set pad_token to eos_token if pad_token is not available
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Initialize the pipeline for text generation
generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
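# NOTE (assumption): a 7B model in bfloat16 needs roughly 15 GB of memory, so on
# GPU-backed hardware it can help to let transformers place the weights itself
# (requires the `accelerate` package), e.g.:
#   model = AutoModelForCausalLM.from_pretrained(
#       model_folder, torch_dtype=torch.bfloat16, device_map="auto")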
# ========== Helper Functions ==========
def sanitize_username(username):
    return ''.join(c for c in username if c.isalnum() or c in ('_', '-')).strip()

def user_folder(username):
    return os.path.join("chat_history", username)

def load_latest_history(username):
    folder = user_folder(username)
    if not os.path.exists(folder):
        os.makedirs(folder, exist_ok=True)
        return []
    files = sorted(os.listdir(folder), reverse=True)
    if not files:
        return []
    latest_file = os.path.join(folder, files[0])
    with open(latest_file, "r", encoding="utf-8") as f:
        lines = f.readlines()
    history = []
    for line in lines:
        if ": " in line:
            user, msg = line.split(": ", 1)
            history.append((user.strip(), msg.strip()))
    return history
def save_history(username, history):
    folder = user_folder(username)
    os.makedirs(folder, exist_ok=True)
    filepath = os.path.join(folder, "history.txt")
    with open(filepath, "a", encoding="utf-8") as f:
        # Only write the last two new entries (user + Sanny Lin)
        for user, msg in history[-2:]:
            f.write(f"{user}: {msg}\n")
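# NOTE (assumption): save_history() always appends to a single "history.txt",
# while load_latest_history() picks the newest file in the folder. If separate
# per-session files are wanted, a sortable timestamped name (using the already
# imported datetime) would keep the two functions consistent, e.g.:
#   filepath = os.path.join(folder, f"{datetime.now():%Y%m%d_%H%M%S}.txt")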
def format_chat(history):
    formatted = ""
    for user, msg in history:
        if user == "Sanny Lin":
            formatted += f"""
            <div style='text-align: left; margin: 5px;'>
                <span class='sanny-message' style='background-color: #e74c3c; color: white; padding: 10px 15px; border-radius: 20px; display: inline-block; max-width: 70%; word-wrap: break-word;'>
                    {msg}
                </span>
            </div>
            """
        else:
            formatted += f"""
            <div style='text-align: right; margin: 5px;'>
                <span style='background-color: #3498db; color: white; padding: 10px 15px; border-radius: 20px; display: inline-block; max-width: 70%; word-wrap: break-word;'>
                    {msg}
                </span>
            </div>
            """
    return formatted
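# NOTE (assumption): format_chat() interpolates messages into raw HTML, so a
# message containing '<' or '>' can break the layout. Escaping first avoids
# that, e.g. with the standard library:
#   import html
#   msg = html.escape(msg)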
def generate_reply(username, user_message, history):
    with lock:
        if not user_message.strip():
            return format_chat(history), history

        # Keep only the last 30 messages as conversation context
        history = history[-30:]

        # Build an alternating user/assistant message list. The system context and
        # the per-user instruction are folded into the final user turn so the prompt
        # also works with chat templates (like Mistral-Instruct's) that expect
        # strictly alternating user/assistant roles.
        messages = []
        for user, msg in history:
            role = "user" if user == username else "assistant"
            messages.append({"role": role, "content": msg})

        user_prompt = f"You are chatting with {username} now. Reply to this message:\n{user_message}"
        if not history:
            user_prompt = f"{loaded_context}\n\n{user_prompt}"
        messages.append({"role": "user", "content": user_prompt})

        # Render the conversation with the model's chat template and generate the reply
        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        generated_output = generator(
            prompt,
            max_new_tokens=512,       # cap the length of the generated reply
            return_full_text=False,   # return only the new text, not the prompt
            num_return_sequences=1,
            do_sample=True,           # sample with moderate temperature/top_p
            temperature=0.5,
            top_p=0.5,
            top_k=0,
            repetition_penalty=1.0,
        )
        response = generated_output[0]["generated_text"].strip()

        # Smart truncation: cut off at 4096 characters without splitting a word
        max_chars = 4096
        if len(response) > max_chars:
            truncated_response = response[:max_chars]
            last_space_idx = truncated_response.rfind(" ")
            response = truncated_response[:last_space_idx] if last_space_idx != -1 else truncated_response

        # Add the user message and assistant's response to history, then persist them
        history.append((username, user_message))
        history.append(("Sanny Lin", response))
        save_history(username, history)
        return format_chat(history), history
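# NOTE: for Mistral-Instruct checkpoints, apply_chat_template() renders the
# conversation into the instruction format the model was fine-tuned on,
# roughly "<s>[INST] user text [/INST] assistant text</s>[INST] ... [/INST]".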
# ========== Gradio Interface ==========
with gr.Blocks(theme=gr.themes.Monochrome(), css="""
    @font-face {
        font-family: "DaemonFont";
        src: url('static/daemon.otf') format('opentype');
    }
    body { background-color: #121212 !important; }
    .gradio-container { background-color: #121212 !important; }
    textarea { background-color: #1e1e1e !important; color: white; }
    input { background-color: #1e1e1e !important; color: white; }
    #chat_display { overflow-y: auto; height: calc(100vh - 200px); }
    .sanny-message {
        font-family: "DaemonFont", sans-serif;
    }
""") as demo:
    chat_display = gr.HTML(value="", elem_id="chat_display", show_label=False)

    with gr.Row():
        username_box = gr.Textbox(label="Username", placeholder="Enter username...", interactive=True, scale=2)
        user_input = gr.Textbox(placeholder="Type your message...", lines=2, show_label=False, scale=8)
        send_button = gr.Button("Send", scale=1)

    username_state = gr.State("")
    history_state = gr.State([])
    def user_send(user_message, username, history, username_input):
        if not username_input.strip():
            return "<div style='color: red;'>Please enter a valid username first.</div>", history, username
        username_input = sanitize_username(username_input)
        if not username:
            username = username_input
        history = history or load_latest_history(username)
        chat_html, history = generate_reply(username, user_message, history)
        return chat_html, history, username
    send_button.click(
        fn=user_send,
        inputs=[user_input, username_state, history_state, username_box],
        outputs=[chat_display, history_state, username_state]
    )
    send_button.click(lambda: "", None, user_input)  # Clear the input box after send

    demo.load(None, None, None, js="""
    () => {
        const textbox = document.querySelector('textarea');
        const sendButton = document.querySelector('button');
        textbox.addEventListener('keydown', function(e) {
            if (e.key === 'Enter' && !e.shiftKey) {
                e.preventDefault();
                sendButton.click();
            }
        });
    }
    """)
demo.launch(share=False)