File size: 3,108 Bytes
56ccbf2
0d4961f
56ccbf2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0d4961f
56ccbf2
 
 
0d4961f
 
56ccbf2
0d4961f
13f0e11
0d4961f
 
56ccbf2
 
 
0d4961f
56ccbf2
 
0d4961f
56ccbf2
13f0e11
56ccbf2
0d4961f
13f0e11
 
 
56ccbf2
 
 
0d4961f
56ccbf2
 
 
 
 
 
0d4961f
56ccbf2
 
 
 
 
 
13f0e11
56ccbf2
13f0e11
 
 
 
 
 
0d4961f
56ccbf2
 
0d4961f
13f0e11
0d4961f
 
 
 
 
56ccbf2
 
0d4961f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56ccbf2
 
 
0d4961f
56ccbf2
 
0d4961f
56ccbf2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import torch
from threading import Thread

# Dropdown label -> Hugging Face repo id for every selectable LFM2 checkpoint.
# Every repo id is the label prefixed with the "LiquidAI/" organisation.
MODEL_NAMES = {
    f"LFM2-{size}": f"LiquidAI/LFM2-{size}"
    for size in ("350M", "700M", "1.2B", "2.6B", "8B-A1B")
}

# (tokenizer, model) pairs keyed by dropdown label; populated by load_model().
model_cache = {}

def load_model(model_key):
    """Return the ``(tokenizer, model)`` pair for the dropdown label *model_key*.

    The first request for a label downloads/loads the checkpoint named in
    ``MODEL_NAMES`` and stores the pair in ``model_cache``. Later requests are
    served from that cache. Nothing is ever evicted, so every model selected
    during the process lifetime stays resident.

    Weights are loaded in float16 and moved to the GPU when CUDA is available.
    Otherwise they are loaded in float32 on the CPU.

    Raises:
        KeyError: if *model_key* is not a key of ``MODEL_NAMES``.
    """
    cached = model_cache.get(model_key)
    if cached is not None:
        return cached

    repo_id = MODEL_NAMES[model_key]
    print(f"Loading {repo_id}...")
    tokenizer = AutoTokenizer.from_pretrained(repo_id)

    on_gpu = torch.cuda.is_available()
    target_device = "cuda" if on_gpu else "cpu"
    weight_dtype = torch.float16 if on_gpu else torch.float32
    model = AutoModelForCausalLM.from_pretrained(repo_id, dtype=weight_dtype).to(target_device)

    pair = (tokenizer, model)
    model_cache[model_key] = pair
    return pair


def _split_pending_message(message, history):
    """Separate the user turn being answered from the prior conversation.

    A caller may append the user's text to the chat history (and clear the
    input box) before invoking the chat handler, in which case *message*
    arrives empty and the real text is the last history entry. A trailing
    user turn is therefore taken as the pending message when *message* is
    empty or identical to it. This ensures the turn is neither sent to the
    model twice nor replaced by an empty user bubble.

    Returns:
        ``(message, history)``, where *history* is a new list without the
        pending user turn.
    """
    turns = list(history or [])
    if turns and turns[-1].get("role") == "user":
        last_text = turns[-1].get("content", "")
        if not message or message == last_text:
            return last_text, turns[:-1]
    return message, turns


def _encode_chat(tokenizer, history, message):
    """Tokenize *history* plus the new user *message* into model inputs.

    Uses the tokenizer's chat template when it has one, so the model sees its
    own role markers and an opened assistant turn. Otherwise it falls back
    to plain ``Role: text`` lines ending in ``Assistant:``.
    """
    # Keep only role/content: the UI may attach extra keys (e.g. metadata)
    # that a chat template does not expect.
    chat = [{"role": turn["role"], "content": turn["content"]} for turn in history]
    chat.append({"role": "user", "content": message})

    if getattr(tokenizer, "chat_template", None):
        return tokenizer.apply_chat_template(
            chat,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt",
        )

    prompt = "".join(f"{turn['role'].capitalize()}: {turn['content']}\n" for turn in chat)
    prompt += "Assistant:"
    return tokenizer(prompt, return_tensors="pt")


def chat_with_model(message, history, model_choice):
    """Stream the selected model's reply to *message* as a full chat history.

    Args:
        message: The user's new text. It may be empty if the caller already
            appended it as the last (user) entry of *history*.
        history: Prior turns as ``{"role": ..., "content": ...}`` dicts.
        model_choice: Key into ``MODEL_NAMES``.

    Yields:
        ``history`` followed by the user turn and the assistant reply
        generated so far, once per streamed text chunk.

    Raises:
        Whatever ``model.generate`` raised in the worker thread, re-raised
        here once the stream has been closed.
    """
    message, history = _split_pending_message(message, history)
    tokenizer, model = load_model(model_choice)

    inputs = _encode_chat(tokenizer, history, message).to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=256,
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
    )

    errors = []

    def _generate():
        # Runs in the worker thread so this generator can consume the stream.
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:  # re-raised in the consumer below
            errors.append(exc)
            # A failed generate() never signals end-of-stream; close it here
            # so the `for new_text in streamer` loop cannot block forever.
            streamer.end()

    thread = Thread(target=_generate)
    thread.start()

    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        # Yield the full chat including the updated assistant message.
        yield history + [
            {"role": "user", "content": message},
            {"role": "assistant", "content": partial_text},
        ]

    thread.join()
    if errors:
        raise errors[0]


def create_demo():
    """Build the chat UI: model dropdown, chat window, message box and Clear button.

    Submitting the textbox first appends the text as a user turn (clearing
    the box). It then streams the selected model's reply via
    ``chat_with_model``.

    Returns:
        The configured, not yet launched ``gr.Blocks`` app.
    """
    with gr.Blocks(title="LiquidAI Chat Playground") as demo:
        gr.Markdown("## 💧 LiquidAI Chat Playground")

        model_choice = gr.Dropdown(
            label="Select Model",
            choices=list(MODEL_NAMES.keys()),
            value="LFM2-1.2B",
        )

        chatbot = gr.Chatbot(
            label="Chat with LiquidAI",
            type="messages",
            height=450,
        )

        msg = gr.Textbox(label="Your message", placeholder="Type something...")
        clear = gr.Button("Clear")

        def add_user_message(user_message, chat_history):
            """Append the submitted text as a user turn and clear the textbox.

            Blank submissions are not appended, so no empty turn reaches the model.
            """
            chat_history = list(chat_history or [])
            if user_message and user_message.strip():
                chat_history.append({"role": "user", "content": user_message})
            return "", chat_history

        def respond(chat_history, model_key):
            """Stream a reply to the trailing user turn of *chat_history*.

            By the time this runs, ``add_user_message`` has already cleared the
            textbox. The pending text is therefore read from the history, not
            from the textbox. It is removed from the history handed to
            ``chat_with_model``, which yields it back together with the reply.
            """
            chat_history = list(chat_history or [])
            if not chat_history or chat_history[-1].get("role") != "user":
                # Nothing to answer (e.g. a blank submission): keep the chat as is.
                yield chat_history
                return
            pending = chat_history[-1].get("content", "")
            yield from chat_with_model(pending, chat_history[:-1], model_key)

        msg.submit(add_user_message, [msg, chatbot], [msg, chatbot], queue=False).then(
            respond, [chatbot, model_choice], chatbot
        )

        clear.click(lambda: [], None, chatbot, queue=False)

    return demo


if __name__ == "__main__":
    # Script entry point: build the UI, enable Gradio's request queue (the
    # chat handler is a generator that streams its reply), then serve on
    # port 7860 on all network interfaces.
    demo = create_demo()
    demo.queue().launch(server_name="0.0.0.0", server_port=7860)