File size: 10,411 Bytes
9a8d870
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
import gradio as gr
import json
import glob
import os
import re
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread

# Paths (relative paths resolve against the current working directory)
BASE_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
ADAPTER_PATH = "../../training_pipeline/trainers/outputs/fantecchi-nsfw-bot"
CHATBOT_PROFILES = "../../../chatbots/profiles_json/*.json"


def load_profiles(pattern):
    """Load every chatbot profile JSON file matching *pattern*, keyed by its ``name``.

    Files are read in sorted path order (``glob.glob`` returns them in arbitrary
    order), so the resulting dict - and therefore the dropdown's default entry -
    is deterministic. A file that cannot be read or parsed, or that has no
    ``name`` key, is skipped with a warning instead of aborting startup. When two
    files share a name, the later one (in sorted order) wins.

    Args:
        pattern: A ``glob`` pattern such as ``"profiles/*.json"``.

    Returns:
        Dict mapping profile name to the parsed JSON object.
    """
    loaded = {}
    for filepath in sorted(glob.glob(pattern)):
        try:
            with open(filepath, 'r', encoding='utf-8') as f:
                data = json.load(f)
            name = data['name']
        except (OSError, ValueError, KeyError, TypeError) as e:
            # ValueError covers JSONDecodeError and UnicodeDecodeError;
            # TypeError covers a top-level JSON value that is not an object.
            print(f"WARNING: Skipping chatbot profile {filepath}: {e}")
            continue
        if name in loaded:
            print(f"WARNING: Duplicate profile name {name!r} in {filepath}; overriding the earlier one.")
        loaded[name] = data
    return loaded


# Load Chatbot Profiles
profiles = load_profiles(CHATBOT_PROFILES)

print(f"Loaded {len(profiles)} chatbot profiles.")

# Global model variables (populated by load_model())
model = None
tokenizer = None

def load_model():
    """Load the tokenizer and base model on CPU (bfloat16) and attach the adapter if present.

    The module-level ``model`` and ``tokenizer`` globals are assigned only after
    every step has succeeded. A failure part-way through (e.g. a broken adapter)
    therefore does not leave a half-initialised base model behind that later
    calls would report as "already loaded".

    Returns:
        A human-readable status string for the UI status box.
    """
    global model, tokenizer
    if model is not None:
        return "Model already loaded."

    print("Loading tokenizer and model on CPU (bfloat16)... This may take a moment.")
    try:
        new_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
        if new_tokenizer.pad_token is None:
            # generate() is given pad_token_id later; reuse EOS when no pad token exists.
            new_tokenizer.pad_token = new_tokenizer.eos_token

        # Loading on CPU using bfloat16 for memory efficiency
        new_model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            torch_dtype=torch.bfloat16,
            device_map="cpu",
        )
        if os.path.exists(ADAPTER_PATH):
            new_model.load_adapter(ADAPTER_PATH)
            print("Successfully loaded custom LoRA adapter on CPU (bfloat16).")
        else:
            print(f"WARNING: Adapter path {ADAPTER_PATH} not found. Running base model.")
    except Exception as e:
        # UI boundary: report the failure in the status box instead of crashing the app.
        return f"Error loading model: {e}"

    # Commit both globals together, only on full success.
    tokenizer = new_tokenizer
    model = new_model
    return "Model loaded successfully (CPU bfloat16)!"

# Formatting rule appended to every system prompt. It asks the model to start
# each paragraph with "Name:" so parse_multi_character_output() can split the
# reply into one chat bubble per character.
FORMATTING_RULE = "".join((
    "\n\n[CRITICAL INSTRUCTION: You are controlling multiple NPCs. ",
    "Every single paragraph or line of dialogue/action MUST begin with the specific character's name followed by a colon. ",
    "Example:\nAlpha: *growls* Get him.\nBeta: *nods* Yes, Alpha.\n",
    "If a new character is introduced, you must prefix their actions with their name and a colon as well.]",
))

# A speaker turn: a name (letters, digits, apostrophes, hyphens, spaces) at the
# start of a line, a colon, then everything up to the next such name line or the
# end of the text. Compiled once at import instead of on every call.
_SPEAKER_TURN_RE = re.compile(
    r'(?m)^([A-Za-z0-9\'\- ]+):\s*(.*?)(?=(?:^[A-Za-z0-9\'\- ]+:)|\Z)',
    re.DOTALL,
)


def parse_multi_character_output(generated_text):
    """Split LLM output into one Gradio chat message per ``Name: text`` turn.

    Each turn becomes an assistant message whose ``metadata["title"]`` is the
    speaker name, so every character gets their own bubble. Turns with empty
    text are dropped. Text the model wrote before its first speaker tag is kept
    as a leading "Narrator" message. If no non-empty turn is found at all, the
    whole (stripped) text is returned as a single "Narrator" message, so the
    result is never empty.

    Args:
        generated_text: Raw decoded model output.

    Returns:
        A non-empty list of ``{"role": "assistant", "content": str,
        "metadata": {"title": str}}`` dicts.
    """
    text = generated_text.strip()

    def bubble(content, title):
        return {"role": "assistant", "content": content, "metadata": {"title": title}}

    parsed_messages = []
    matches = list(_SPEAKER_TURN_RE.finditer(text))
    if matches:
        # Narration before the first speaker tag would otherwise be lost.
        preamble = text[:matches[0].start()].strip()
        if preamble:
            parsed_messages.append(bubble(preamble, "Narrator"))
        for match in matches:
            name = match.group(1).strip()
            body = match.group(2).strip()
            if body:
                parsed_messages.append(bubble(body, name))

    if not parsed_messages:
        # The LLM failed to format properly (or produced only empty turns):
        # show the raw text as a single Narrator bubble.
        parsed_messages.append(bubble(text, "Narrator"))
    return parsed_messages

# Stop generation as soon as the model starts writing the user's (or a narrator's) turn.
_STOP_STRINGS = ["<|user|>", "[USER]", "User:", "Narrator:"]
# Only the most recent chat messages are sent to the model to stay within the context window.
_MAX_HISTORY = 10


def _message_text(content):
    """Return the plain text of a chat message's ``content``.

    Strings pass through unchanged. A list of content parts (plain strings or
    dicts with a ``"text"`` key) is joined with newlines. NOTE(review): this is
    an assumed shape for non-string Gradio content; confirm against the
    installed Gradio version. Anything else falls back to ``str()``, as before.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        parts = [
            part if isinstance(part, str) else part["text"]
            for part in content
            if isinstance(part, str) or (isinstance(part, dict) and isinstance(part.get("text"), str))
        ]
        if parts:
            return "\n".join(parts)
    return str(content)


def _build_chat_messages(profile, history, message):
    """Build the chat-template message list: system prompt, recent history, new user turn.

    Assistant bubbles are rendered as ``"Title: content"``, and consecutive
    assistant bubbles are merged into one assistant turn. This mirrors the
    ``Name: text`` format the model is instructed to produce. Bubbles titled
    "System" are UI notices emitted by this app and are not sent to the model.
    """
    scenario = profile.get("scenario", "")
    chars = profile.get("characters", [])
    char_desc = "\n".join(
        f"{c.get('name', 'NPC')}: {c.get('behavior', '')} {c.get('appearance', '')}"
        for c in chars
    )
    sys_prompt = f"Scenario: {scenario}\nCharacters:\n{char_desc}\n" + FORMATTING_RULE
    chat_messages = [{"role": "system", "content": str(sys_prompt)}]

    for msg in history[-_MAX_HISTORY:]:
        content = _message_text(msg.get("content", ""))
        if msg.get("role") == "user":
            chat_messages.append({"role": "user", "content": content})
            continue
        # metadata may be missing or None; fall back to "Narrator".
        name = (msg.get("metadata") or {}).get("title") or "Narrator"
        if name == "System":
            continue
        formatted_content = f"{name}: {content}"
        if chat_messages[-1]["role"] == "assistant":
            chat_messages[-1]["content"] += f"\n{formatted_content}"
        else:
            chat_messages.append({"role": "assistant", "content": formatted_content})

    # The UI appends the user's text to `history` and clears the textbox before
    # this runs, so `message` normally arrives empty. Add it only when it is a
    # genuinely new turn, to avoid an empty or duplicated trailing user turn.
    message = _message_text(message)
    last = chat_messages[-1]
    if message.strip() and not (last["role"] == "user" and last["content"] == message):
        chat_messages.append({"role": "user", "content": message})
    return chat_messages


def _strip_trailing_stop_strings(text):
    """Remove a stop string left at the end of *text*.

    ``generate`` halts only after emitting a stop string, so the string itself
    (e.g. a trailing "[USER]") is still present in the streamed output.
    """
    stripped = text.rstrip()
    for stop in _STOP_STRINGS:
        if stripped.endswith(stop):
            return stripped[: -len(stop)].rstrip()
    return text


def generate_response(message, history, profile_name, temp, top_p, max_tokens):
    """Stream the bot's reply to the latest user turn into the chat history.

    Args:
        message: Textbox contents when generation starts. Usually empty, because
            the UI has already moved the user's text into ``history``. Added as a
            user turn only when non-blank and not already the last user turn.
        history: Messages-format chat history (dicts with role/content/metadata);
            mutated in place.
        profile_name: Key into the module-level ``profiles`` dict.
        temp: Sampling temperature; 0 selects greedy decoding.
        top_p: Nucleus-sampling threshold (used only when sampling).
        max_tokens: Maximum number of new tokens to generate.

    Yields:
        ``(history, prompt_token_count)`` pairs for the ``[chatbot, token_counter]``
        outputs. The count is 0 on early exits.
    """
    if history is None:
        history = []

    if model is None or tokenizer is None:
        history.append({"role": "assistant", "content": "Please load the model first using the 'Load Model' button.", "metadata": {"title": "System"}})
        yield history, 0
        return

    profile = profiles.get(profile_name)
    if not profile:
        history.append({"role": "assistant", "content": "Error: Profile not found.", "metadata": {"title": "System"}})
        # Both outputs (chatbot, token counter) must be supplied on every yield.
        yield history, 0
        return

    chat_messages = _build_chat_messages(profile, history, message)
    prompt = tokenizer.apply_chat_template(chat_messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    token_count = inputs.input_ids.shape[1]

    # Use Streamer for real-time feedback
    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)

    temperature = float(temp)
    do_sample = temperature > 0
    generate_kwargs = dict(
        **inputs,
        max_new_tokens=int(max_tokens),
        do_sample=do_sample,
        pad_token_id=tokenizer.pad_token_id,
        streamer=streamer,
        stop_strings=_STOP_STRINGS,  # Prevent self-chatting
        # stop_strings requires the tokenizer to be passed to generate(); without
        # it generate() raises inside the worker thread and the streamer times out.
        tokenizer=tokenizer,
    )
    if do_sample:
        # Only meaningful when sampling; transformers warns if set under greedy decoding.
        generate_kwargs.update(temperature=temperature, top_p=float(top_p))

    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    full_response = ""
    history.append({"role": "assistant", "content": "...", "metadata": {"title": "Generating..."}})
    for new_text in streamer:
        full_response += new_text
        history[-1]["content"] = full_response
        yield history, token_count
    thread.join()

    # After the stream ends, split "Name: ..." turns into per-character bubbles.
    parsed = parse_multi_character_output(_strip_trailing_stop_strings(full_response))
    if len(parsed) > 1:
        # Several characters spoke: replace the generating bubble with one per turn.
        history.pop()
        history.extend(parsed)
    elif parsed:
        # Single speaker: retitle the existing bubble.
        history[-1]["metadata"]["title"] = parsed[0]["metadata"]["title"]
        history[-1]["content"] = parsed[0]["content"]
    else:
        # Nothing parseable: keep the raw text but drop the "Generating..." title.
        history[-1]["metadata"]["title"] = "Narrator"

    yield history, token_count

# Gradio Interface
with gr.Blocks() as app:
    gr.Markdown("# Fantecchi Local Chatbot Interface\nLoad your trained multi-character bots and chat entirely locally.")

    with gr.Row():
        with gr.Column(scale=1):
            profile_dropdown = gr.Dropdown(choices=list(profiles.keys()), label="Select Chatbot Scenario", value=list(profiles.keys())[0] if profiles else None)
            # load_model() places the model on the CPU, so the label does not promise VRAM.
            load_btn = gr.Button("Load Model", variant="primary")
            load_status = gr.Textbox(label="Status", interactive=False)

            scenario_box = gr.Textbox(label="Scenario Description", lines=5, interactive=False)
            avatar_display = gr.Image(label="Character Avatar", interactive=False)

            with gr.Accordion("Advanced Settings & Memory", open=False):
                temp_slider = gr.Slider(minimum=0.0, maximum=1.5, value=0.8, step=0.1, label="Temperature")
                top_p_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, step=0.05, label="Top P")
                max_tokens_slider = gr.Slider(minimum=32, maximum=1024, value=256, step=32, label="Max New Tokens")

                token_counter = gr.Number(label="Last Prompt Token Count", value=0, interactive=False)
                memory_viewer = gr.Textbox(label="Bot Memory (System Prompt)", lines=10, interactive=False)

            def update_sidebar(name):
                """Return (scenario text, avatar URL or None, system-prompt preview) for profile *name*."""
                if not name:
                    return "", None, ""
                p = profiles[name]
                scenario = p.get("scenario", "")
                # Build the same system prompt generate_response() sends (scenario +
                # character sheet + formatting rule), so this box shows what the model sees.
                char_desc = "\n".join(
                    f"{c.get('name', 'NPC')}: {c.get('behavior', '')} {c.get('appearance', '')}"
                    for c in p.get("characters", [])
                )
                sys_mem = f"Scenario: {scenario}\nCharacters:\n{char_desc}\n" + FORMATTING_RULE
                img = p.get("image_prompt", None)
                # Only http(s) URLs are displayed; other values are not shown as an image.
                avatar = img if isinstance(img, str) and img.startswith("http") else None
                return scenario, avatar, sys_mem

            sidebar_outputs = [scenario_box, avatar_display, memory_viewer]
            profile_dropdown.change(fn=update_sidebar, inputs=[profile_dropdown], outputs=sidebar_outputs)
            # .change does not fire for the dropdown's initial value, so also fill the sidebar on page load.
            app.load(fn=update_sidebar, inputs=[profile_dropdown], outputs=sidebar_outputs)

        with gr.Column(scale=3):
            # NOTE(review): the callbacks below assume messages-format history (dicts with
            # role/content/metadata); on Gradio versions where that is not the default,
            # pass type="messages" here.
            chatbot = gr.Chatbot(height=600)
            msg_input = gr.Textbox(placeholder="Type your response...", label="Your Input")

            def init_chat(profile_name):
                """Reset the chat to the selected profile's first message (empty when it has none)."""
                if not profile_name:
                    return []
                first_msg = profiles[profile_name].get("first_mes", "")
                if not first_msg:
                    return []
                return [{"role": "assistant", "content": first_msg, "metadata": {"title": "Scenario Start"}}]

            profile_dropdown.change(fn=init_chat, inputs=[profile_dropdown], outputs=[chatbot])
            # Same as above: seed the chat for the default profile when the page loads.
            app.load(fn=init_chat, inputs=[profile_dropdown], outputs=[chatbot])

            # Submission logic
            def user_submit(user_text, history):
                """Append the user's text to the chat history and clear the textbox.

                Raises gr.Error on blank input, so the chained generation step is skipped.
                """
                if not user_text or not user_text.strip():
                    raise gr.Error("Please type a message before sending.")
                history = list(history or [])
                history.append({"role": "user", "content": user_text})
                return "", history

            # .success: generate only if user_submit completed without raising.
            msg_input.submit(user_submit, inputs=[msg_input, chatbot], outputs=[msg_input, chatbot]).success(
                generate_response,
                inputs=[msg_input, chatbot, profile_dropdown, temp_slider, top_p_slider, max_tokens_slider],
                outputs=[chatbot, token_counter]
            )

    load_btn.click(fn=load_model, inputs=[], outputs=[load_status])

if __name__ == "__main__":
    # Serve on the loopback interface only (127.0.0.1), port 7860.
    # NOTE(review): passing `theme` to launch() matches newer Gradio releases;
    # older ones take it in gr.Blocks(). Confirm against the installed version.
    launch_settings = dict(
        server_name="127.0.0.1",
        server_port=7860,
        theme=gr.themes.Monochrome(),
    )
    app.launch(**launch_settings)