"""Local Gradio chat UI for multi-character roleplay bots.

Loads a TinyLlama base model (optionally with a LoRA adapter) on CPU in
bfloat16 and serves chatbot "profiles" (scenario + character list) read from
JSON files. The LLM is instructed to prefix every line with "Name:" so the
streamed output can be split into one chat bubble per character.
"""

import glob
import json
import os
import re
from threading import Thread

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Paths
BASE_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
ADAPTER_PATH = "../../training_pipeline/trainers/outputs/fantecchi-nsfw-bot"
CHATBOT_PROFILES = "../../../chatbots/profiles_json/*.json"

# Load Chatbot Profiles, keyed by each profile's "name" field.
# A single malformed file should not prevent the app from starting.
profiles = {}
for filepath in glob.glob(CHATBOT_PROFILES):
    try:
        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)
        profiles[data['name']] = data
    except (OSError, json.JSONDecodeError, KeyError) as e:
        print(f"WARNING: skipping profile {filepath}: {e}")
print(f"Loaded {len(profiles)} chatbot profiles.")

# Global model variables (populated lazily by load_model)
model = None
tokenizer = None


def load_model():
    """Load the base model (and LoRA adapter, if present) onto the CPU.

    Idempotent: a second call returns immediately without reloading.

    Returns:
        str: a human-readable status message for display in the UI.
    """
    global model, tokenizer
    if model is not None:
        return "Model already loaded."
    print("Loading tokenizer and model on CPU (bfloat16)... This may take a moment.")
    try:
        tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        # Loading on CPU using bfloat16 for memory efficiency
        model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            torch_dtype=torch.bfloat16,
            device_map="cpu",
        )
        if os.path.exists(ADAPTER_PATH):
            model.load_adapter(ADAPTER_PATH)
            print("Successfully loaded custom LoRA adapter on CPU (bfloat16).")
        else:
            print(f"WARNING: Adapter path {ADAPTER_PATH} not found. Running base model.")
        return "Model loaded successfully (CPU bfloat16)!"
    except Exception as e:
        # Surface the failure in the UI status box rather than crashing the app.
        return f"Error loading model: {str(e)}"


# System formatting rule to force the LLM to tag outputs by character name
FORMATTING_RULE = (
    "\n\n[CRITICAL INSTRUCTION: You are controlling multiple NPCs. "
    "Every single paragraph or line of dialogue/action MUST begin with the specific character's name followed by a colon. "
    "Example:\nAlpha: *growls* Get him.\nBeta: *nods* Yes, Alpha.\n"
    "If a new character is introduced, you must prefix their actions with their name and a colon as well.]"
)


def _build_system_prompt(profile):
    """Build the system prompt for a profile: scenario, character sheet, and
    the mandatory name-tagging rule. Shared by generation and the UI viewer
    so the "Bot Memory" panel shows exactly what the model receives.
    """
    scenario = profile.get("scenario", "")
    chars = profile.get("characters", [])
    char_desc = "\n".join(
        f"{c.get('name', 'NPC')}: {c.get('behavior', '')} {c.get('appearance', '')}"
        for c in chars
    )
    return f"Scenario: {scenario}\nCharacters:\n{char_desc}\n" + FORMATTING_RULE


def parse_multi_character_output(generated_text):
    """
    Parses the LLM output looking for 'Name: text' patterns.
    Splits the output into multiple distinct Gradio 'messages' so each
    character gets their own bubble.

    Returns:
        list[dict]: Gradio "messages"-format dicts; falls back to a single
        "Narrator" bubble when the model ignored the formatting rule.
    """
    parsed_messages = []
    # Regex matches a name at the start of a line (allowing some spaces), a colon,
    # and the rest of the text up to the next "Name:" line or end of string.
    pattern = re.compile(
        r'(?m)^([A-Za-z0-9\'\- ]+):\s*(.*?)(?=(?:^[A-Za-z0-9\'\- ]+:)|\Z)',
        re.DOTALL,
    )
    matches = pattern.findall(generated_text.strip())
    if not matches:
        # If the LLM failed to format properly, just return it as a Narrator bubble
        parsed_messages.append({
            "role": "assistant",
            "content": generated_text.strip(),
            "metadata": {"title": "Narrator"},
        })
        return parsed_messages
    for name, text in matches:
        name = name.strip()
        text = text.strip()
        if text:
            parsed_messages.append({
                "role": "assistant",
                "content": text,
                "metadata": {"title": name},
            })
    return parsed_messages


def generate_response(message, history, profile_name, temp, top_p, max_tokens):
    """Stream a model response into the chat history.

    Generator yielding (history, prompt_token_count) pairs so Gradio can
    update the chatbot and the token counter live. Error states are reported
    as "System" bubbles instead of raising.
    """
    if model is None or tokenizer is None:
        history.append({"role": "assistant", "content": "Please load the model first using the 'Load Model' button.", "metadata": {"title": "System"}})
        yield history, 0
        return

    profile = profiles.get(profile_name)
    if not profile:
        history.append({"role": "assistant", "content": "Error: Profile not found.", "metadata": {"title": "System"}})
        # FIX: this handler has two outputs; the original yielded only `history`.
        yield history, 0
        return

    # Build the system prompt (scenario + characters + formatting rule).
    sys_prompt = _build_system_prompt(profile)
    chatml_messages = [{"role": "system", "content": str(sys_prompt)}]

    # Limit history to stay within context window (e.g., last 10 messages)
    MAX_HISTORY = 10
    recent_history = history[-MAX_HISTORY:]
    for msg in recent_history:
        content = str(msg["content"])  # Force string
        if msg["role"] == "user":
            chatml_messages.append({"role": "user", "content": content})
        else:
            # Re-tag assistant turns with the speaking character's name and
            # merge consecutive assistant turns into a single ChatML message.
            name = msg.get("metadata", {}).get("title", "Narrator")
            formatted_content = f"{name}: {content}"
            if chatml_messages and chatml_messages[-1]["role"] == "assistant":
                chatml_messages[-1]["content"] += f"\n{formatted_content}"
            else:
                chatml_messages.append({"role": "assistant", "content": formatted_content})

    # FIX: the textbox is cleared by user_submit before this handler runs, so
    # `message` is normally "" and the user's turn is already in `history`.
    # Only append a separate user turn when a non-empty message was passed in,
    # otherwise the prompt ended with an empty trailing user turn.
    if str(message).strip():
        chatml_messages.append({"role": "user", "content": str(message)})

    prompt = tokenizer.apply_chat_template(chatml_messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    token_count = inputs.input_ids.shape[1]

    # Use Streamer for real-time feedback
    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        **inputs,
        max_new_tokens=int(max_tokens),
        temperature=float(temp),
        top_p=float(top_p),
        do_sample=float(temp) > 0,
        pad_token_id=tokenizer.pad_token_id,
        streamer=streamer,
        stop_strings=["<|user|>", "[USER]", "User:", "Narrator:"],  # Prevent self-chatting
        # FIX: transformers requires a tokenizer whenever stop_strings is used,
        # otherwise generate() raises a ValueError.
        tokenizer=tokenizer,
    )
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    full_response = ""
    history.append({"role": "assistant", "content": "...", "metadata": {"title": "Generating..."}})
    for new_text in streamer:
        full_response += new_text
        history[-1]["content"] = full_response
        yield history, token_count

    # After stream ends, try to parse the title from the text (e.g. "Name: ...")
    parsed = parse_multi_character_output(full_response)
    if parsed and len(parsed) == 1:
        # If it's a single character response, update the current bubble's title
        history[-1]["metadata"]["title"] = parsed[0]["metadata"]["title"]
        history[-1]["content"] = parsed[0]["content"]
    elif len(parsed) > 1:
        # If multiple characters spoke, replace the generating bubble with the parsed ones
        history.pop()
        for p in parsed:
            history.append(p)
    yield history, token_count


# Gradio Interface
# FIX: the theme must be set on gr.Blocks(); launch() has no `theme` kwarg.
with gr.Blocks(theme=gr.themes.Monochrome()) as app:
    gr.Markdown("# Fantecchi Local Chatbot Interface\nLoad your trained multi-character bots and chat entirely locally.")
    with gr.Row():
        with gr.Column(scale=1):
            profile_dropdown = gr.Dropdown(
                choices=list(profiles.keys()),
                label="Select Chatbot Scenario",
                value=list(profiles.keys())[0] if profiles else None,
            )
            # FIX: label said "VRAM" but the model is explicitly loaded on CPU.
            load_btn = gr.Button("Load Model (CPU)", variant="primary")
            load_status = gr.Textbox(label="Status", interactive=False)
            scenario_box = gr.Textbox(label="Scenario Description", lines=5, interactive=False)
            avatar_display = gr.Image(label="Character Avatar", interactive=False)
            with gr.Accordion("Advanced Settings & Memory", open=False):
                temp_slider = gr.Slider(minimum=0.0, maximum=1.5, value=0.8, step=0.1, label="Temperature")
                top_p_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, step=0.05, label="Top P")
                max_tokens_slider = gr.Slider(minimum=32, maximum=1024, value=256, step=32, label="Max New Tokens")
                token_counter = gr.Number(label="Last Prompt Token Count", value=0, interactive=False)
                memory_viewer = gr.Textbox(label="Bot Memory (System Prompt)", lines=10, interactive=False)

            def update_sidebar(name):
                """Refresh scenario text, avatar and system-prompt preview."""
                if not name:
                    return "", None, ""
                p = profiles[name]
                img = p.get("image_prompt", None)
                # FIX: show the exact system prompt generate_response sends,
                # not the unrelated "description" field.
                sys_mem = _build_system_prompt(p)
                # Only treat the field as an avatar when it is a URL.
                if img and img.startswith("http"):
                    return p.get("scenario", ""), img, sys_mem
                return p.get("scenario", ""), None, sys_mem

            profile_dropdown.change(
                fn=update_sidebar,
                inputs=[profile_dropdown],
                outputs=[scenario_box, avatar_display, memory_viewer],
            )

        with gr.Column(scale=3):
            # Gradio 5+ handles the message format automatically or via different props
            chatbot = gr.Chatbot(height=600)
            msg_input = gr.Textbox(placeholder="Type your response...", label="Your Input")

            def init_chat(profile_name):
                # When a new profile is selected, clear history and inject the First Message
                if not profile_name:
                    return []
                profile = profiles[profile_name]
                first_msg = profile.get("first_mes", "")
                return [{"role": "assistant", "content": first_msg, "metadata": {"title": "Scenario Start"}}]

            profile_dropdown.change(fn=init_chat, inputs=[profile_dropdown], outputs=[chatbot])

            # Submission logic
            def user_submit(user_text, history):
                """Append the user's turn and clear the input box."""
                history.append({"role": "user", "content": user_text})
                return "", history

            msg_input.submit(
                user_submit,
                inputs=[msg_input, chatbot],
                outputs=[msg_input, chatbot],
            ).then(
                generate_response,
                inputs=[msg_input, chatbot, profile_dropdown, temp_slider, top_p_slider, max_tokens_slider],
                outputs=[chatbot, token_counter],
            )

    load_btn.click(fn=load_model, inputs=[], outputs=[load_status])


if __name__ == "__main__":
    app.launch(server_name="127.0.0.1", server_port=7860)