import gradio as gr
import json
import glob
import os
import re
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread

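# --- Configuration ---
# Paths below are relative to the working directory this script is run from;
# adjust them to match your local checkout.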
BASE_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
ADAPTER_PATH = "../../training_pipeline/trainers/outputs/fantecchi-nsfw-bot"
CHATBOT_PROFILES = "../../../chatbots/profiles_json/*.json"

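# Load every chatbot profile JSON and index it by its "name" field.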
profiles = {}
for filepath in glob.glob(CHATBOT_PROFILES):
    with open(filepath, 'r', encoding='utf-8') as f:
        data = json.load(f)
        profiles[data['name']] = data

print(f"Loaded {len(profiles)} chatbot profiles.")

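# Model and tokenizer are module-level globals, loaded lazily from the UI
# so the app starts instantly and the slow model load is user-triggered.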
model = None
tokenizer = None


def load_model():
    global model, tokenizer
    if model is not None:
        return "Model already loaded."

    print("Loading tokenizer and model on CPU (bfloat16)... This may take a moment.")
    try:
        tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
        if tokenizer.pad_token is None:
            # TinyLlama ships without a pad token; reuse EOS for padding.
            tokenizer.pad_token = tokenizer.eos_token

        model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            torch_dtype=torch.bfloat16,
            device_map="cpu"
        )
        # The LoRA adapter is optional: fall back to the base model if missing.
        if os.path.exists(ADAPTER_PATH):
            model.load_adapter(ADAPTER_PATH)
            print("Successfully loaded custom LoRA adapter on CPU (bfloat16).")
        else:
            print(f"WARNING: Adapter path {ADAPTER_PATH} not found. Running base model.")

        return "Model loaded successfully (CPU bfloat16)!"
    except Exception as e:
        return f"Error loading model: {str(e)}"

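# Appended to every system prompt so the model emits "Name: text" lines that
# parse_multi_character_output() can split into separate chat bubbles.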
FORMATTING_RULE = (
    "\n\n[CRITICAL INSTRUCTION: You are controlling multiple NPCs. "
    "Every single paragraph or line of dialogue/action MUST begin with the specific character's name followed by a colon. "
    "Example:\nAlpha: *growls* Get him.\nBeta: *nods* Yes, Alpha.\n"
    "If a new character is introduced, you must prefix their actions with their name and a colon as well.]"
)

def parse_multi_character_output(generated_text):
    """
    Parses the LLM output looking for 'Name: text' patterns.
    Splits the output into multiple distinct Gradio 'messages' so each character gets their own bubble.
    """
    parsed_messages = []

    # Match a 'Name:' prefix at the start of a line and capture everything up
    # to the next name prefix (or the end of the string, via \Z).
    pattern = re.compile(r'(?m)^([A-Za-z0-9\'\- ]+):\s*(.*?)(?=(?:^[A-Za-z0-9\'\- ]+:)|\Z)', re.DOTALL)
    matches = pattern.findall(generated_text.strip())

    if not matches:
        # No name prefixes found: attribute the whole reply to the Narrator.
        parsed_messages.append({
            "role": "assistant",
            "content": generated_text.strip(),
            "metadata": {"title": "Narrator"}
        })
        return parsed_messages

    for name, text in matches:
        name = name.strip()
        text = text.strip()
        if text:
            parsed_messages.append({
                "role": "assistant",
                "content": text,
                "metadata": {"title": name}
            })

    return parsed_messages

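# Streaming generator: yields (history, prompt_token_count) pairs so Gradio
# can update the chat window incrementally while tokens arrive.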
def generate_response(history, profile_name, temp, top_p, max_tokens):
    # The user's newest message is already the last entry in `history`
    # (user_submit appends it and clears the textbox before this runs),
    # so the prompt is built entirely from history.
    if model is None or tokenizer is None:
        history.append({"role": "assistant", "content": "Please load the model first using the 'Load Model' button.", "metadata": {"title": "System"}})
        yield history, 0
        return

    profile = profiles.get(profile_name)
    if not profile:
        history.append({"role": "assistant", "content": "Error: Profile not found.", "metadata": {"title": "System"}})
        yield history, 0
        return

    # Build the system prompt from the profile's scenario and character sheets.
    scenario = profile.get("scenario", "")
    chars = profile.get("characters", [])
    char_desc = "\n".join([f"{c.get('name', 'NPC')}: {c.get('behavior', '')} {c.get('appearance', '')}" for c in chars])
    sys_prompt = f"Scenario: {scenario}\nCharacters:\n{char_desc}\n" + FORMATTING_RULE

    chatml_messages = [{"role": "system", "content": str(sys_prompt)}]

    # Keep only the most recent turns to bound the prompt length.
    MAX_HISTORY = 10
    recent_history = history[-MAX_HISTORY:] if len(history) > MAX_HISTORY else history

    for msg in recent_history:
        content = str(msg["content"])
        if msg["role"] == "user":
            chatml_messages.append({"role": "user", "content": content})
        else:
            # Gradio may round-trip metadata as None, so guard before .get().
            name = (msg.get("metadata") or {}).get("title", "Narrator")
            formatted_content = f"{name}: {content}"
            # Merge consecutive character lines into a single assistant turn.
            if chatml_messages and chatml_messages[-1]["role"] == "assistant":
                chatml_messages[-1]["content"] += f"\n{formatted_content}"
            else:
                chatml_messages.append({"role": "assistant", "content": formatted_content})

    prompt = tokenizer.apply_chat_template(chatml_messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    token_count = inputs.input_ids.shape[1]

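    # Run generation on a background thread and stream tokens through
    # TextIteratorStreamer so the UI can render partial output.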
    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)

    generate_kwargs = dict(
        **inputs,
        max_new_tokens=int(max_tokens),
        temperature=float(temp),
        top_p=float(top_p),
        do_sample=float(temp) > 0,  # greedy decoding when temperature is 0
        pad_token_id=tokenizer.pad_token_id,
        streamer=streamer,
        stop_strings=["<|user|>", "[USER]", "User:", "Narrator:"],
        tokenizer=tokenizer,  # generate() requires the tokenizer when stop_strings is set
    )

    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    # Show a placeholder bubble, then overwrite it as tokens stream in.
    full_response = ""
    history.append({"role": "assistant", "content": "...", "metadata": {"title": "Generating..."}})

    for new_text in streamer:
        full_response += new_text
        history[-1]["content"] = full_response
        yield history, token_count

    # The streamer is exhausted only once generation has finished.
    thread.join()

    # Split the finished reply into one bubble per character.
    parsed = parse_multi_character_output(full_response)
    if parsed and len(parsed) == 1:
        history[-1]["metadata"]["title"] = parsed[0]["metadata"]["title"]
        history[-1]["content"] = parsed[0]["content"]
    elif len(parsed) > 1:
        history.pop()
        for p in parsed:
            history.append(p)

    yield history, token_count

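# --- Gradio UI ---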
with gr.Blocks(theme=gr.themes.Monochrome()) as app:
    gr.Markdown("# Fantecchi Local Chatbot Interface\nLoad your trained multi-character bots and chat entirely locally.")

    with gr.Row():
        with gr.Column(scale=1):
            profile_dropdown = gr.Dropdown(choices=list(profiles.keys()), label="Select Chatbot Scenario", value=list(profiles.keys())[0] if profiles else None)
            load_btn = gr.Button("Load Model (CPU)", variant="primary")
            load_status = gr.Textbox(label="Status", interactive=False)

            scenario_box = gr.Textbox(label="Scenario Description", lines=5, interactive=False)
            avatar_display = gr.Image(label="Character Avatar", interactive=False)

            with gr.Accordion("Advanced Settings & Memory", open=False):
                temp_slider = gr.Slider(minimum=0.0, maximum=1.5, value=0.8, step=0.1, label="Temperature")
                top_p_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, step=0.05, label="Top P")
                max_tokens_slider = gr.Slider(minimum=32, maximum=1024, value=256, step=32, label="Max New Tokens")

                token_counter = gr.Number(label="Last Prompt Token Count", value=0, interactive=False)
                memory_viewer = gr.Textbox(label="Bot Memory (System Prompt)", lines=10, interactive=False)

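            # Keep the sidebar (scenario text, avatar, system-prompt view)
            # in sync with the selected profile.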
            def update_sidebar(name):
                if not name:
                    return "", None, ""
                p = profiles[name]
                img = p.get("image_prompt", None)
                sys_mem = p.get("description", "") + FORMATTING_RULE
                # Only treat image_prompt as an avatar if it is a URL.
                if img and img.startswith("http"):
                    return p.get("scenario", ""), img, sys_mem
                return p.get("scenario", ""), None, sys_mem

            profile_dropdown.change(fn=update_sidebar, inputs=[profile_dropdown], outputs=[scenario_box, avatar_display, memory_viewer])

        with gr.Column(scale=3):
            # type="messages" lets the Chatbot accept the role/content dicts
            # (with metadata titles) used throughout this file.
            chatbot = gr.Chatbot(height=600, type="messages")
            msg_input = gr.Textbox(placeholder="Type your response...", label="Your Input")

            def init_chat(profile_name):
                # Seed the chat with the profile's opening message.
                if not profile_name:
                    return []
                profile = profiles[profile_name]
                first_msg = profile.get("first_mes", "")
                return [{"role": "assistant", "content": first_msg, "metadata": {"title": "Scenario Start"}}]

            profile_dropdown.change(fn=init_chat, inputs=[profile_dropdown], outputs=[chatbot])

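            # Two-step submit: first echo the user's message and clear the
            # textbox, then stream the bot's reply from history alone.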
            def user_submit(user_text, history):
                history.append({"role": "user", "content": user_text})
                return "", history

            # msg_input is cleared by user_submit before the .then() step runs,
            # so generate_response reads the user's message from the history.
            msg_input.submit(user_submit, inputs=[msg_input, chatbot], outputs=[msg_input, chatbot]).then(
                generate_response,
                inputs=[chatbot, profile_dropdown, temp_slider, top_p_slider, max_tokens_slider],
                outputs=[chatbot, token_counter]
            )

    load_btn.click(fn=load_model, inputs=[], outputs=[load_status])

if __name__ == "__main__":
    # launch() takes no theme argument; the theme is set on gr.Blocks above.
    app.launch(server_name="127.0.0.1", server_port=7860)
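# Usage sketch (assuming this file is saved as app.py):
#   python app.py
# then open http://127.0.0.1:7860 in a browser.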