import os
import sys

# Make the interface/ directory importable so its modules (e.g. app_logic) can be reused.
sys.path.append(os.path.join(os.path.dirname(__file__), "interface"))

# interface/app.py uses relative paths such as "../training/...", which break
# when the app is launched from the project root. Rather than changing the
# working directory, this script re-implements the launcher with paths that
# resolve correctly from the root (as required on Hugging Face Spaces).

import gradio as gr
import json
import glob
import re
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread

# Paths adjusted for root execution
BASE_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
ADAPTER_PATH = "core/training_pipeline/trainers/outputs/fantecchi-nsfw-bot"
CHATBOT_PROFILES = "chatbots/profiles_json/*.json"

# The core logic from interface/app.py is duplicated below with the updated
# paths so that the app runs unmodified on Hugging Face Spaces.

# Load Chatbot Profiles
profiles = {}
for filepath in glob.glob(CHATBOT_PROFILES):
    with open(filepath, 'r', encoding='utf-8') as f:
        data = json.load(f)
        profiles[data['name']] = data

print(f"Loaded {len(profiles)} chatbot profiles.")

model = None
tokenizer = None

def load_model():
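    """Load the base model and, if present, the LoRA adapter; cache them globally.

    Safe to call repeatedly: returns early once the model is loaded.
    """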
    global model, tokenizer
    if model is not None:
        return "Model already loaded."
    
    print("Loading tokenizer and model on CPU (bfloat16)...")
    try:
        tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
            
        model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            torch_dtype=torch.bfloat16,
            device_map="auto",  # lets accelerate use the GPU if the Space has one (e.g. a T4)
        )
        if os.path.exists(ADAPTER_PATH):
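            # Lazy import: peft is only needed when a trained adapter is present.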
            from peft import PeftModel
            model = PeftModel.from_pretrained(model, ADAPTER_PATH)
            print("Successfully loaded custom LoRA adapter.")
        else:
            print(f"WARNING: Adapter path {ADAPTER_PATH} not found. Running base model.")
            
        return "Model loaded successfully!"
    except Exception as e:
        return f"Error loading model: {str(e)}"

FORMATTING_RULE = (
    "\n\n[CRITICAL INSTRUCTION: You are controlling multiple NPCs. "
    "Every single paragraph or line of dialogue/action MUST begin with the specific character's name followed by a colon.]"
)

def parse_multi_character_output(generated_text):
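    """Split a raw generation into per-character chat messages.

    Each "Name: text" segment becomes one assistant message whose metadata
    title carries the speaker's name; output with no "Name:" prefix falls back
    to a single message titled "Narrator". For example, a generation reading
    "Mira: Hello." on one line and "Guard: Halt!" on the next yields two
    messages titled "Mira" and "Guard".
    """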
    parsed_messages = []
    pattern = re.compile(r'(?m)^([A-Za-z0-9\'\- ]+):\s*(.*?)(?=(?:^[A-Za-z0-9\'\- ]+:)|\Z)', re.DOTALL)
    matches = pattern.findall(generated_text.strip())
    
    if not matches:
        parsed_messages.append({"role": "assistant", "content": generated_text.strip(), "metadata": {"title": "Narrator"}})
        return parsed_messages
        
    for match in matches:
        name = match[0].strip()
        text = match[1].strip()
        if text:
            parsed_messages.append({"role": "assistant", "content": text, "metadata": {"title": name}})
    return parsed_messages

def generate_response(message, history, profile_name, temp, top_p, max_tokens):
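    """Stream a reply for the selected profile, yielding (history, prompt_token_count).

    The last ten turns of history are packed into a ChatML-style message list,
    rendered with the tokenizer's chat template, and generated on a background
    thread so tokens can be streamed into the chat window as they arrive.
    """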
    if model is None or tokenizer is None:
        history.append({"role": "assistant", "content": "Please load the model first.", "metadata": {"title": "System"}})
        yield history, 0
        return
        
    profile = profiles.get(profile_name)
    if not profile:
        history.append({"role": "assistant", "content": "Error: Profile not found.", "metadata": {"title": "System"}})
        yield history, 0
        return

    scenario = profile.get("scenario", "")
    chars = profile.get("characters", [])
    char_desc = "\n".join([f"{c.get('name', 'NPC')}: {c.get('behavior', '')} {c.get('appearance', '')}" for c in chars])
    sys_prompt = f"Scenario: {scenario}\nCharacters:\n{char_desc}\n" + FORMATTING_RULE
    
    chatml_messages = [{"role": "system", "content": str(sys_prompt)}]
    for msg in history[-10:]:
        content = str(msg["content"])
        if msg["role"] == "user":
            chatml_messages.append({"role": "user", "content": content})
        else:
            name = msg.get("metadata", {}).get("title", "Narrator")
            chatml_messages.append({"role": "assistant", "content": f"{name}: {content}"})
                
    chatml_messages.append({"role": "user", "content": str(message)})

    prompt = tokenizer.apply_chat_template(chatml_messages, tokenize=False, add_generation_prompt=True)
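    # TinyLlama-1.1B-Chat-v1.0 ships a Zephyr-style chat template, so the
    # rendered prompt should look roughly like this (an assumption; check the
    # tokenizer's chat_template to confirm):
    #   <|system|>
    #   {system prompt}</s>
    #   <|user|>
    #   {message}</s>
    #   <|assistant|>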
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    token_count = inputs.input_ids.shape[1]

    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        **inputs,
        max_new_tokens=int(max_tokens),
        temperature=float(temp),
        top_p=float(top_p),
        do_sample=float(temp) > 0,  # fall back to greedy decoding at temperature 0
        pad_token_id=tokenizer.pad_token_id,
        streamer=streamer,
    )

    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

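    # model.generate runs on the worker thread; TextIteratorStreamer hands
    # tokens back through a queue, so the loop below blocks on each new token
    # and raises queue.Empty if none arrives within the 20-second timeout.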
    full_response = ""
    history.append({"role": "assistant", "content": "...", "metadata": {"title": "Generating..."}})
    for new_text in streamer:
        full_response += new_text
        history[-1]["content"] = full_response
        yield history, token_count

    parsed = parse_multi_character_output(full_response)
    if parsed:
        history.pop()
        for p in parsed:
            history.append(p)
    yield history, token_count

with gr.Blocks() as demo:
    gr.Markdown("# Fantecchi Hugging Face Interface")
    with gr.Row():
        with gr.Column(scale=1):
            profile_dropdown = gr.Dropdown(choices=list(profiles.keys()), label="Select Scenario", value=list(profiles.keys())[0] if profiles else None)
            load_btn = gr.Button("Load Model", variant="primary")
            load_status = gr.Textbox(label="Status", interactive=False)
            # Displays the prompt token count yielded by generate_response so
            # both of that event's outputs (chatbot, token_box) are wired up.
            token_box = gr.Number(label="Prompt Tokens", interactive=False)
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(height=600, type="messages")
            msg_input = gr.Textbox(placeholder="Type your response...")
            
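            # Append the user's turn to the history and clear the textbox; the
            # chained .then() below streams the model's reply.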
            def user_submit(user_text, history):
                return "", history + [{"role": "user", "content": user_text}]

            msg_input.submit(
                user_submit, [msg_input, chatbot], [msg_input, chatbot]
            ).then(
                generate_response,
                # Fixed sampling settings (temperature, top_p, max_new_tokens)
                # passed via gr.State; swap in sliders to make them user-tunable.
                [msg_input, chatbot, profile_dropdown, gr.State(0.8), gr.State(0.9), gr.State(256)],
                [chatbot, token_box],  # generate_response yields (history, token_count)
            )

    load_btn.click(load_model, outputs=load_status)

if __name__ == "__main__":
    demo.launch()