FlameF0X committed on
Commit 0d4961f · verified · 1 Parent(s): 393c05e

Update app.py

Files changed (1): app.py (+50, -48)
app.py CHANGED
@@ -1,9 +1,8 @@
 import gradio as gr
-from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
 import torch
 from threading import Thread
 
-# Available model options
 MODEL_NAMES = {
     "LFM2-350M": "LiquidAI/LFM2-350M",
     "LFM2-700M": "LiquidAI/LFM2-700M",
@@ -12,91 +11,94 @@ MODEL_NAMES = {
     "LFM2-8B-A1B": "LiquidAI/LFM2-8B-A1B",
 }
 
-# Cache for loaded models
 model_cache = {}
 
 def load_model(model_key):
-    """Load and cache the selected model."""
     if model_key in model_cache:
         return model_cache[model_key]
-
+
     model_name = MODEL_NAMES[model_key]
     print(f"Loading {model_name}...")
     tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+    device = "cuda" if torch.cuda.is_available() else "cpu"
     model = AutoModelForCausalLM.from_pretrained(
-        model_name,
-        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-        device_map="auto"
-    )
+        model_name,
+        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
+        device_map=None,  # Disable meta/offload shenanigans
+    ).to(device)
+
     model_cache[model_key] = (tokenizer, model)
     return tokenizer, model
 
+
 def chat_with_model(message, history, model_choice):
     tokenizer, model = load_model(model_choice)
+    device = model.device
 
-    # Build the chat history as a string
+    # Convert the Gradio message history into a string prompt
    prompt = ""
-    for user_msg, bot_msg in history:
-        prompt += f"User: {user_msg}\nAssistant: {bot_msg}\n"
+    for msg in history:
+        if msg["role"] == "user":
+            prompt += f"User: {msg['content']}\n"
+        elif msg["role"] == "assistant":
+            prompt += f"Assistant: {msg['content']}\n"
     prompt += f"User: {message}\nAssistant:"
 
-    # Streaming setup
     streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
-    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+    inputs = tokenizer(prompt, return_tensors="pt").to(device)
 
     generation_kwargs = dict(
         **inputs,
         streamer=streamer,
         max_new_tokens=256,
         temperature=0.7,
+        top_p=0.9,
         do_sample=True,
-        top_p=0.9
     )
 
     thread = Thread(target=model.generate, kwargs=generation_kwargs)
     thread.start()
 
-    partial_text = ""
+    partial = ""
     for new_text in streamer:
-        partial_text += new_text
-        yield partial_text
+        partial += new_text
+        yield partial
+
 
 def create_demo():
-    with gr.Blocks(title="LiquidAI Chat Interface") as demo:
-        gr.Markdown("## 💧 LiquidAI Model Chat Playground")
-
-        with gr.Row():
-            model_choice = gr.Dropdown(
-                label="Select Model",
-                choices=list(MODEL_NAMES.keys()),
-                value="LFM2-1.2B"
-            )
-
-        chatbot = gr.Chatbot(label="Chat with the model", height=450)
-        msg = gr.Textbox(label="Your message", placeholder="Type a message and hit Enter")
-
-        clear = gr.Button("Clear Chat")
-
-        def user_submit(user_message, chat_history, model_choice):
-            chat_history = chat_history + [(user_message, "")]
-            return "", chat_history, model_choice
-
-        msg.submit(
-            user_submit,
-            [msg, chatbot, model_choice],
-            [msg, chatbot, model_choice],
-            queue=False
-        ).then(
-            chat_with_model,
-            [msg, chatbot, model_choice],
-            chatbot
+    with gr.Blocks(title="LiquidAI Chat Playground") as demo:
+        gr.Markdown("## 💧 LiquidAI Chat Interface")
+
+        model_choice = gr.Dropdown(
+            label="Select Model",
+            choices=list(MODEL_NAMES.keys()),
+            value="LFM2-1.2B"
         )
 
-        clear.click(lambda: None, None, chatbot, queue=False)
+        chatbot = gr.Chatbot(
+            label="Chat with LiquidAI",
+            type="messages",
+            height=450
+        )
+
+        msg = gr.Textbox(label="Your message", placeholder="Type something...")
+        clear = gr.Button("Clear")
+
+        def add_user_message(user_message, chat_history):
+            chat_history = chat_history + [{"role": "user", "content": user_message}]
+            return "", chat_history
+
+        msg.submit(add_user_message, [msg, chatbot], [msg, chatbot], queue=False).then(
+            chat_with_model, [msg, chatbot, model_choice], chatbot
+        )
+
+        clear.click(lambda: [], None, chatbot, queue=False)
 
     return demo
 
+
 if __name__ == "__main__":
     demo = create_demo()
-    demo.queue(max_size=32)
+    demo.queue()
     demo.launch(server_name="0.0.0.0", server_port=7860)
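
For reference, the core of this commit is the switch to Gradio's messages-format history: with type="messages", the Chatbot passes history as a list of role/content dicts instead of (user, bot) tuples, and the rewritten loop in chat_with_model flattens that list into a plain-text prompt. The sketch below exercises just that prompt-building logic in isolation; the history and message values are made up for illustration and are not part of the commit.

# Illustrative history in Gradio's type="messages" format (values made up)
history = [
    {"role": "user", "content": "Hi!"},
    {"role": "assistant", "content": "Hello! How can I help?"},
]
message = "Which LFM2 model is fastest?"

# Same prompt-building loop as the updated chat_with_model
prompt = ""
for msg in history:
    if msg["role"] == "user":
        prompt += f"User: {msg['content']}\n"
    elif msg["role"] == "assistant":
        prompt += f"Assistant: {msg['content']}\n"
prompt += f"User: {message}\nAssistant:"

print(prompt)
# User: Hi!
# Assistant: Hello! How can I help?
# User: Which LFM2 model is fastest?
# Assistant:

Because chat_with_model is a generator that yields progressively longer partial strings, it can also be smoke-tested outside the UI with something like `for partial in chat_with_model("Hi", [], "LFM2-350M"): ...`, assuming the LiquidAI weights can be downloaded in the current environment.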