FlameF0X committed
Commit 56ccbf2 · verified · 1 parent: 2517cb9

Create app.py

Files changed (1)
  1. app.py +107 -0
app.py ADDED
@@ -0,0 +1,107 @@
+ import gradio as gr
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+ import torch
+ from threading import Thread
+
+ # Available model options
+ MODEL_NAMES = {
+     "LFM2-350M": "LiquidAI/LFM2-350M",
+     "LFM2-700M": "LiquidAI/LFM2-700M",
+     "LFM2-1.2B": "LiquidAI/LFM2-1.2B",
+     "LFM2-2.6B": "LiquidAI/LFM2-2.6B",
+     "LFM2-8B-A1B": "LiquidAI/LFM2-8B-A1B",
+ }
+
+ # Cache for loaded models
+ model_cache = {}
+
+ def load_model(model_key):
+     """Load and cache the selected model."""
+     if model_key in model_cache:
+         return model_cache[model_key]
+
+     model_name = MODEL_NAMES[model_key]
+     print(f"Loading {model_name}...")
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+     model = AutoModelForCausalLM.from_pretrained(
+         model_name,
+         torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+         device_map="auto"
+     )
+     model_cache[model_key] = (tokenizer, model)
+     return tokenizer, model
+
+ def chat_with_model(history, model_choice):
+     """Stream the model's reply for the latest user turn in the history."""
+     tokenizer, model = load_model(model_choice)
+     # user_submit has already appended (message, "") as the last history entry
+     message = history[-1][0]
+
+     # Build the chat history as a plain-text prompt
+     prompt = ""
+     for user_msg, bot_msg in history[:-1]:
+         prompt += f"User: {user_msg}\nAssistant: {bot_msg}\n"
+     prompt += f"User: {message}\nAssistant:"
+
+     # Streaming setup: run generate() on a worker thread, consume tokens here
+     streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+
+     generation_kwargs = dict(
+         **inputs,
+         streamer=streamer,
+         max_new_tokens=256,
+         temperature=0.7,
+         do_sample=True,
+         top_p=0.9
+     )
+
+     thread = Thread(target=model.generate, kwargs=generation_kwargs)
+     thread.start()
+
+     # Yield the whole history each step so the Chatbot re-renders the growing reply
+     partial_text = ""
+     for new_text in streamer:
+         partial_text += new_text
+         history[-1] = (message, partial_text)
+         yield history
+
+ def create_demo():
+     with gr.Blocks(title="LiquidAI Chat Interface") as demo:
+         gr.Markdown("## 💧 LiquidAI Model Chat Playground")
+
+         with gr.Row():
+             model_choice = gr.Dropdown(
+                 label="Select Model",
+                 choices=list(MODEL_NAMES.keys()),
+                 value="LFM2-1.2B"
+             )
+
+         chatbot = gr.Chatbot(label="Chat with the model", height=450)
+         msg = gr.Textbox(label="Your message", placeholder="Type a message and hit Enter")
+
+         clear = gr.Button("Clear Chat")
+
+         def user_submit(user_message, chat_history):
+             # Append the new turn with an empty reply and clear the textbox
+             return "", chat_history + [(user_message, "")]
+
+         msg.submit(
+             user_submit,
+             [msg, chatbot],
+             [msg, chatbot],
+             queue=False
+         ).then(
+             chat_with_model,
+             [chatbot, model_choice],
+             chatbot
+         )
+
+         clear.click(lambda: None, None, chatbot, queue=False)
+
+     return demo
+
+ if __name__ == "__main__":
+     demo = create_demo()
+     demo.queue(max_size=32)
+     demo.launch(server_name="0.0.0.0", server_port=7860)
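
A note on the prompt format: the hand-rolled `User:`/`Assistant:` strings in `chat_with_model` work, but instruct-tuned checkpoints usually respond better when the prompt goes through the tokenizer's built-in chat template. A minimal sketch of that variant, assuming the LFM2 tokenizers ship a chat template (not verified here):

# Replace the string-concatenation prompt with the tokenizer's chat template.
# Assumption: the checkpoint's tokenizer defines a chat template.
messages = []
for user_msg, bot_msg in history[:-1]:
    messages.append({"role": "user", "content": user_msg})
    messages.append({"role": "assistant", "content": bot_msg})
messages.append({"role": "user", "content": message})

input_ids = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,  # end with the assistant header so the model replies as the bot
    return_tensors="pt",
).to(model.device)

# generate() then takes the ids directly, e.g.
# generation_kwargs = dict(input_ids=input_ids, streamer=streamer, ...)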
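
Running this outside the Space presumably needs `gradio`, `transformers`, `torch`, and `accelerate` (the `device_map="auto"` path in `load_model` requires accelerate); `python app.py` then serves the UI at http://localhost:7860. One design consequence of `model_cache`: every model selected during a session stays resident, so switching between the larger checkpoints accumulates GPU memory until the process restarts.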