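"""Gradio chat interface for the GoofyLM BrainrotLM models.

Loads the selected Hugging Face causal LM (cached between requests) and
streams generated tokens back to the Chatbot via TextIteratorStreamer.
"""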
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
from threading import Thread
# Initialize cache for models and tokenizers
model_cache = {}
tokenizer_cache = {}
def load_model_and_tokenizer(model_name):
"""Load model and tokenizer with caching to avoid reloading the same model"""
if model_name not in model_cache:
print(f"Loading model: {model_name}")
model = AutoModelForCausalLM.from_pretrained(
model_name,
device_map="auto",
torch_dtype=torch.float16
)
model_cache[model_name] = model
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Set pad token if missing
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# Define a custom chat template if one is not available
if tokenizer.chat_template is None:
# Basic ChatML-style template
tokenizer.chat_template = "{% for message in messages %}\n{% if message['role'] == 'system' %}<|system|>\n{{ message['content'] }}\n{% elif message['role'] == 'user' %}<|user|>\n{{ message['content'] }}\n{% elif message['role'] == 'assistant' %}<|assistant|>\n{{ message['content'] }}\n{% endif %}\n{% endfor %}\n{% if add_generation_prompt %}<|assistant|>\n{% endif %}"
tokenizer_cache[model_name] = tokenizer
return model_cache[model_name], tokenizer_cache[model_name]
# Define available models
available_models = [
"GoofyLM/BrainrotLM-Assistant-362M",
"GoofyLM/BrainrotLM2-Assistant-362M"
]
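
# Note: respond() is a generator; Gradio re-renders the Chatbot with each
# yielded history, which is what produces the token-by-token streaming effect.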
def respond(message, chat_history, model_choice, system_message, max_tokens, temperature, top_p):
    # Load selected model and tokenizer (cached after the first request)
    model, tokenizer = load_model_and_tokenizer(model_choice)

    # Build conversation messages
    messages = [{"role": "system", "content": system_message}]
    for user_msg, assistant_msg in chat_history:
        messages.append({"role": "user", "content": user_msg})
        if assistant_msg:  # May be None while a response is still streaming
            messages.append({"role": "assistant", "content": assistant_msg})

    # Add the current message
    messages.append({"role": "user", "content": message})

    # Format the prompt using the chat template
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Set up streaming
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # Configure generation parameters
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=(temperature > 0 or top_p < 1.0),
        pad_token_id=tokenizer.pad_token_id
    )

    # Start generation in a separate thread so the streamer can be
    # consumed while tokens are still being produced
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Stream the response back to the Chatbot as it is generated;
    # the final yield carries the complete reply, so no return is needed
    partial_message = ""
    for new_token in streamer:
        partial_message += new_token
        yield chat_history + [(message, partial_message)]

# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# BrainrotLM Chat Interface")
    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(height=600)
            with gr.Row():
                msg = gr.Textbox(
                    label="Message",
                    placeholder="Type your message here...",
                    lines=3,
                    show_label=False
                )
                submit = gr.Button("Send", variant="primary")
            clear = gr.Button("Clear Conversation")
        with gr.Column(scale=1):
            model_dropdown = gr.Dropdown(
                choices=available_models,
                value=available_models[0],
                label="Select Model"
            )
            system_message = gr.Textbox(
                value="Your name is BrainrotLM, an AI assistant trained by GoofyLM.",
                label="System message",
                lines=4
            )
            max_tokens = gr.Slider(1, 512, value=144, label="Max new tokens")
            temperature = gr.Slider(0.1, 2.0, value=0.67, label="Temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.95, label="Top-p (nucleus sampling)")

    # Set up event handlers
    submit_event = msg.submit(
        respond,
        inputs=[msg, chatbot, model_dropdown, system_message, max_tokens, temperature, top_p],
        outputs=chatbot
    )
    submit_click = submit.click(
        respond,
        inputs=[msg, chatbot, model_dropdown, system_message, max_tokens, temperature, top_p],
        outputs=chatbot
    )

    # Clear the message box after sending
    submit_event.then(lambda: "", None, msg)
    submit_click.then(lambda: "", None, msg)

    # Clear conversation button
    clear.click(lambda: None, None, chatbot)

if __name__ == "__main__":
    demo.launch()