"""Gradio chat playground for the LiquidAI LFM2 model family.

Loads a selected LFM2 checkpoint from Hugging Face and streams its replies
into a chat UI.
"""

from threading import Thread

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
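
# Display name shown in the UI -> Hugging Face repo id.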
MODEL_NAMES = {
    "LFM2-350M": "LiquidAI/LFM2-350M",
    "LFM2-700M": "LiquidAI/LFM2-700M",
    "LFM2-1.2B": "LiquidAI/LFM2-1.2B",
    "LFM2-2.6B": "LiquidAI/LFM2-2.6B",
    "LFM2-8B-A1B": "LiquidAI/LFM2-8B-A1B",
}
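
# Loaded (tokenizer, model) pairs keyed by display name, so each checkpoint
# is only pulled into memory once per process.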
model_cache = {}


def load_model(model_key):
    """Load and cache the tokenizer/model pair for the selected model."""
    if model_key in model_cache:
        return model_cache[model_key]

    model_name = MODEL_NAMES[model_key]
    print(f"Loading {model_name}...")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        # fp16 keeps GPU memory use down; fall back to fp32 on CPU.
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        device_map="auto",
    )
    model_cache[model_key] = (tokenizer, model)
    return tokenizer, model
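

# Streaming chat callback wired to the Chatbot below. The sampling settings
# (max_new_tokens=256, temperature=0.7, top_p=0.9) are the original script's
# illustrative defaults, not values tuned for LFM2.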
def chat_with_model(history, model_choice):
    """Stream a reply to the newest user message in the chat history."""
    tokenizer, model = load_model(model_choice)

    # Build a plain-text transcript from every completed turn; the final
    # history entry is the pending user message appended by user_submit.
    # Note: the LFM2 repos ship a chat template, so apply_chat_template would
    # likely be a more robust way to format this prompt.
    prompt = ""
    for user_msg, bot_msg in history[:-1]:
        prompt += f"User: {user_msg}\nAssistant: {bot_msg}\n"
    prompt += f"User: {history[-1][0]}\nAssistant:"

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=256,
        temperature=0.7,
        do_sample=True,
        top_p=0.9,
    )

    # generate() blocks, so run it on a worker thread and consume the streamer
    # here, yielding the growing history so the Chatbot updates live.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        history[-1] = (history[-1][0], partial_text)
        yield history
    thread.join()


def create_demo():
    """Build the Blocks UI: model picker, chat window, message box, clear button."""
    with gr.Blocks(title="LiquidAI Chat Interface") as demo:
        gr.Markdown("## 💧 LiquidAI Model Chat Playground")

        with gr.Row():
            model_choice = gr.Dropdown(
                label="Select Model",
                choices=list(MODEL_NAMES.keys()),
                value="LFM2-1.2B",
            )

        # Tuple-style history [(user, bot), ...]; newer Gradio releases prefer
        # type="messages", but tuples are kept to match the handlers below.
        chatbot = gr.Chatbot(label="Chat with the model", height=450)
        msg = gr.Textbox(label="Your message", placeholder="Type a message and hit Enter")
        clear = gr.Button("Clear Chat")

        def user_submit(user_message, chat_history):
            # Append the new user turn with an empty reply slot and clear the
            # textbox; chat_with_model streams the reply into that slot.
            return "", chat_history + [(user_message, "")]

        # Two-step event: record the user turn immediately (unqueued), then
        # stream the model's reply into the chat window. Passing the chatbot
        # state (not the already-cleared textbox) to chat_with_model is what
        # gives it the message text.
        msg.submit(
            user_submit,
            [msg, chatbot],
            [msg, chatbot],
            queue=False,
        ).then(
            chat_with_model,
            [chatbot, model_choice],
            chatbot,
        )

        clear.click(lambda: None, None, chatbot, queue=False)

    return demo
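

# Queuing must be enabled for generator (streaming) callbacks; max_size caps
# how many requests may wait in line.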
if __name__ == "__main__":
    demo = create_demo()
    demo.queue(max_size=32)
    demo.launch(server_name="0.0.0.0", server_port=7860)
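
# Usage (assuming this script is saved as app.py):
#   python app.py
# then open http://localhost:7860 in a browser.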