import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import torch
from threading import Thread

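# UI labels mapped to their Hugging Face repo IDs.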
MODEL_NAMES = {
    "LFM2-350M": "LiquidAI/LFM2-350M",
    "LFM2-700M": "LiquidAI/LFM2-700M",
    "LFM2-1.2B": "LiquidAI/LFM2-1.2B",
    "LFM2-2.6B": "LiquidAI/LFM2-2.6B",
    "LFM2-8B-A1B": "LiquidAI/LFM2-8B-A1B",
}

model_cache = {}


def load_model(model_key):
    """Return the (tokenizer, model) pair for model_key, loading it on first use."""
    if model_key in model_cache:
        return model_cache[model_key]

    model_name = MODEL_NAMES[model_key]
    print(f"Loading {model_name}...")
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Use half precision on GPU to cut memory use; fall back to float32 on CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    ).to(device)

    model_cache[model_key] = (tokenizer, model)
    return tokenizer, model


def chat_with_model(history, model_choice):
    tokenizer, model = load_model(model_choice)
    device = model.device

    # The latest user turn is already the last entry in history (appended by
    # add_user_message in create_demo, which also clears the textbox), so the
    # prompt is built from history alone and ends with a cue for the model.
    prompt = ""
    for turn in history:
        prompt += f"{turn['role'].capitalize()}: {turn['content']}\n"
    prompt += "Assistant:"
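    # If the checkpoint ships a chat template, tokenizer.apply_chat_template(
    # history, add_generation_prompt=True, return_tensors="pt") would likely
    # match the model's training format more closely than this hand-rolled prompt.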

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=256,
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
    )
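
    # generate() blocks, so run it in a background thread and yield partial
    # text as tokens arrive through the streamer.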
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        # Re-yield the full history plus the growing assistant reply so the
        # Chatbot re-renders on every chunk.
        yield history + [{"role": "assistant", "content": partial_text}]
    thread.join()


def create_demo():
    with gr.Blocks(title="LiquidAI Chat Playground") as demo:
        gr.Markdown("## 💧 LiquidAI Chat Playground")

        model_choice = gr.Dropdown(
            label="Select Model",
            choices=list(MODEL_NAMES.keys()),
            value="LFM2-1.2B",
        )

        chatbot = gr.Chatbot(
            label="Chat with LiquidAI",
            type="messages",
            height=450,
        )

        msg = gr.Textbox(label="Your message", placeholder="Type something...")
        clear = gr.Button("Clear")

        def add_user_message(user_message, chat_history):
            # Append the user turn to the history and clear the textbox.
            chat_history = chat_history + [{"role": "user", "content": user_message}]
            return "", chat_history

        # On submit: record the user turn first, then stream the model's reply.
        # chat_with_model reads the latest user message from the history itself,
        # since the textbox has already been cleared by this point.
        msg.submit(add_user_message, [msg, chatbot], [msg, chatbot], queue=False).then(
            chat_with_model, [chatbot, model_choice], chatbot
        )

        clear.click(lambda: [], None, chatbot, queue=False)

    return demo


if __name__ == "__main__":
    demo = create_demo()
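    # Enable the request queue (streamed generator responses are delivered
    # through it) and serve on all interfaces at port 7860.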
    demo.queue()
    demo.launch(server_name="0.0.0.0", server_port=7860)