File size: 5,248 Bytes
e76f2f8
 
 
 
 
0e929c3
 
 
e76f2f8
0e929c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e76f2f8
e4f4e0b
 
 
 
 
 
 
0e929c3
e4f4e0b
0e929c3
edcc1b2
 
e4f4e0b
 
 
 
edcc1b2
e4f4e0b
 
edcc1b2
e4f4e0b
 
 
edcc1b2
e4f4e0b
 
edcc1b2
e4f4e0b
 
 
 
 
 
 
 
 
 
edcc1b2
e4f4e0b
 
 
edcc1b2
e4f4e0b
 
 
 
 
 
 
0e929c3
e4f4e0b
0e929c3
 
 
 
e4f4e0b
0e929c3
e4f4e0b
 
 
 
 
 
 
 
 
 
 
 
0e929c3
 
 
 
 
 
e4f4e0b
0e929c3
e4f4e0b
0e929c3
 
 
e4f4e0b
6c38322
01bfc56
0e929c3
 
e4f4e0b
 
 
 
 
 
 
 
 
 
 
0e929c3
e4f4e0b
 
 
 
 
 
 
e76f2f8
edcc1b2
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
from threading import Thread

# Caches for loaded models and tokenizers, keyed by the model id passed to
# load_model_and_tokenizer(). Entries are never evicted.
model_cache = {}
tokenizer_cache = {}


def load_model_and_tokenizer(model_name):
    """Load a causal LM and its tokenizer, caching both by model id.

    The first call for a given ``model_name`` loads the model and tokenizer via
    ``from_pretrained``; subsequent calls return the cached objects.

    Args:
        model_name: Hugging Face Hub id (or local path) given to ``from_pretrained``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    # Require both entries: a model without its tokenizer must be (re)loaded.
    if model_name in model_cache and model_name in tokenizer_cache:
        return model_cache[model_name], tokenizer_cache[model_name]

    print(f"Loading model: {model_name}")
    # Half precision is only dependable on GPU; many CPU kernels are slow or
    # lack float16 implementations, so use float32 when CUDA is unavailable.
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        torch_dtype=dtype,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Set pad token if missing, reusing the EOS token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Fall back to a minimal <|system|>/<|user|>/<|assistant|> role-tag template
    # when the tokenizer ships without one, so apply_chat_template() still works.
    if tokenizer.chat_template is None:
        tokenizer.chat_template = "{% for message in messages %}\n{% if message['role'] == 'system' %}<|system|>\n{{ message['content'] }}\n{% elif message['role'] == 'user' %}<|user|>\n{{ message['content'] }}\n{% elif message['role'] == 'assistant' %}<|assistant|>\n{{ message['content'] }}\n{% endif %}\n{% endfor %}\n{% if add_generation_prompt %}<|assistant|>\n{% endif %}"

    # Populate the caches only after both loads succeeded, so a failure part-way
    # through cannot leave a model cached without its tokenizer.
    model_cache[model_name] = model
    tokenizer_cache[model_name] = tokenizer
    return model, tokenizer

# Hugging Face Hub ids offered in the model dropdown; index 0 is the default choice.
available_models = [
    "GoofyLM/BrainrotLM-Assistant-362M",
    "GoofyLM/BrainrotLM2-Assistant-362M",
]

def respond(message, chat_history, model_choice, system_message, max_tokens, temperature, top_p):
    """Stream a model reply to ``message`` given the prior conversation.

    Used as a Gradio generator handler: every yielded value is the complete
    chatbot history (``(user, assistant)`` pairs) with the in-progress reply
    as the last pair.

    Args:
        message: The new user message. Empty/whitespace-only input is ignored.
        chat_history: Prior ``(user_msg, assistant_msg)`` pairs; ``None`` is
            treated as an empty history.
        model_choice: Model id passed to ``load_model_and_tokenizer``.
        system_message: System prompt placed first in the conversation.
        max_tokens: Maximum number of new tokens to generate (cast to int).
        temperature: Sampling temperature; a value <= 0 selects greedy decoding.
        top_p: Nucleus-sampling mass, only used when sampling.

    Yields:
        The updated chat history, first with an empty reply and then after
        every streamed text chunk.

    Raises:
        Any exception raised by ``model.generate`` in the worker thread is
        re-raised here once streaming has stopped.
    """
    chat_history = list(chat_history or [])

    # Nothing to answer: leave the conversation unchanged.
    if not (message and message.strip()):
        yield chat_history
        return

    model, tokenizer = load_model_and_tokenizer(model_choice)

    # Build conversation messages
    messages = [{"role": "system", "content": system_message}]
    for user_msg, assistant_msg in chat_history:
        messages.append({"role": "user", "content": user_msg})
        if assistant_msg:  # May be None/empty for a turn without a reply
            messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})

    # Format prompt using chat template
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    do_sample = temperature > 0
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=int(max_tokens),
        do_sample=do_sample,
        pad_token_id=tokenizer.pad_token_id,
    )
    if do_sample:
        # Sampling knobs are only meaningful (and only valid) when sampling;
        # transformers rejects temperature <= 0 together with do_sample=True.
        generation_kwargs.update(temperature=temperature, top_p=top_p)

    errors = []

    def _generate():
        # Runs in a worker thread so this generator can consume the streamer.
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:
            errors.append(exc)
            # generate() never reached streamer.end(); signal end-of-stream so
            # the loop below stops instead of blocking forever.
            streamer.end()

    thread = Thread(target=_generate, daemon=True)
    thread.start()

    partial_message = ""
    # Show the user's turn immediately, even before the first token arrives.
    yield chat_history + [(message, partial_message)]
    for new_text in streamer:
        partial_message += new_text
        yield chat_history + [(message, partial_message)]

    thread.join()
    if errors:
        raise errors[0]

# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# BrainrotLM Chat Interface")

    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(height=600)

            with gr.Row():
                msg = gr.Textbox(
                    label="Message",
                    placeholder="Type your message here...",
                    lines=3,
                    show_label=False
                )
                submit = gr.Button("Send", variant="primary")

            clear = gr.Button("Clear Conversation")

        with gr.Column(scale=1):
            model_dropdown = gr.Dropdown(
                choices=available_models,
                value=available_models[0],
                label="Select Model"
            )

            system_message = gr.Textbox(
                value="Your name is BrainrotLM, an AI assistant trained by GoofyLM.",
                label="System message",
                lines=4
            )

            max_tokens = gr.Slider(1, 512, value=144, label="Max new tokens")
            temperature = gr.Slider(0.1, 2.0, value=0.67, label="Temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.95, label="Top-p (nucleus sampling)")

    # Pressing Enter in the textbox and clicking Send run the same handler with
    # the same inputs; keep a single list so the two cannot drift apart.
    respond_inputs = [msg, chatbot, model_dropdown, system_message, max_tokens, temperature, top_p]

    submit_event = msg.submit(respond, inputs=respond_inputs, outputs=chatbot)
    submit_click = submit.click(respond, inputs=respond_inputs, outputs=chatbot)

    # Clear message box once the streamed reply has finished
    submit_event.then(lambda: "", None, msg)
    submit_click.then(lambda: "", None, msg)

    # Reset to an empty history and stop any in-flight stream; otherwise the
    # still-running generator would repopulate the chatbot after clearing.
    clear.click(lambda: [], None, chatbot, cancels=[submit_event, submit_click])

if __name__ == "__main__":
    # Streaming (generator) handlers and event cancellation need the request
    # queue. Gradio 4+ enables it by default; Gradio 3.x requires it explicitly.
    demo.queue().launch()