aeb56 committed
Commit 9905f0a · 1 parent: a82de92

Switch to transformers inference (vLLM doesn't support the KimiLinear architecture)

Files changed (2):
  1. app.py +132 -269
  2. requirements.txt +8 -9
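
In short: instead of launching a vLLM OpenAI-compatible server as a subprocess, the app now loads the checkpoint in-process with transformers, using `trust_remote_code=True` to pull in the KimiLinear modeling code shipped with the checkpoint and `device_map`/`max_memory` (via accelerate) to shard the 48B weights across the available GPUs. A minimal sketch of that loading path, with settings mirroring the new app.py (the 23 GB per-GPU cap assumes the Space's 4x L4 hardware):

```python
# Minimal sketch of the new loading path (settings mirror app.py in this commit;
# exact memory caps depend on the hardware the Space runs on).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_NAME = "optiviseapp/kimi-linear-48b-a3b-instruct-fine-tune"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,
    device_map="balanced",   # shard layers evenly across visible GPUs (needs accelerate)
    max_memory={i: "23GB" for i in range(torch.cuda.device_count())},
    trust_remote_code=True,  # KimiLinear ships its modeling code with the checkpoint
    low_cpu_mem_usage=True,
)

# Quick smoke test of generation with the same sampling defaults as the app.
inputs = tokenizer("User: Hello\nAssistant:", return_tensors="pt").to(model.device)
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=64, temperature=0.7, top_p=0.9, do_sample=True)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```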
app.py CHANGED
@@ -1,312 +1,175 @@
  import gradio as gr
- import requests
- import json
- import subprocess
- import time
  import os
- import signal
- import sys

  # Model configuration
  MODEL_NAME = "optiviseapp/kimi-linear-48b-a3b-instruct-fine-tune"
- VLLM_PORT = 8000
- VLLM_PROCESS = None

- def start_vllm_server():
-     """Start vLLM server in background"""
-     global VLLM_PROCESS
-
-     if VLLM_PROCESS is not None:
-         return "✅ vLLM server already running"
-
-     try:
-         # Start vLLM server with tensor parallelism for multi-GPU
-         cmd = [
-             "python3", "-m", "vllm.entrypoints.openai.api_server",
-             "--model", MODEL_NAME,
-             "--host", "0.0.0.0",
-             "--port", str(VLLM_PORT),
-             "--dtype", "bfloat16",
-             "--trust-remote-code",
-             "--tensor-parallel-size", "4", # Use all 4 GPUs
-             "--max-model-len", "8192", # Limit context to save memory
-         ]
-
-         log_file = open("/tmp/vllm.log", "w")
-         VLLM_PROCESS = subprocess.Popen(
-             cmd,
-             stdout=log_file,
-             stderr=subprocess.STDOUT,
-             preexec_fn=os.setsid if sys.platform != 'win32' else None
-         )
-
-         status_msg = "🔄 **vLLM server starting...**\n\n"
-         status_msg += "This takes 5-10 minutes for the 48B model.\n\n"
-         status_msg += "**Progress:**\n"
-         status_msg += "1. Downloading model (if not cached)\n"
-         status_msg += "2. Loading weights across 4 GPUs\n"
-         status_msg += "3. Initializing inference engine\n\n"
-         status_msg += "**Status:** Initializing...\n\n"
-         status_msg += "_Check logs at /tmp/vllm.log for details_"

-         # Wait longer for big model - up to 10 minutes
-         max_retries = 300 # 300 * 2 seconds = 10 minutes
-         for i in range(max_retries):
-             try:
-                 response = requests.get(f"http://localhost:{VLLM_PORT}/health", timeout=2)
-                 if response.status_code == 200:
-                     return "✅ **vLLM server started successfully!**\n\nYou can now start chatting below."
-             except requests.exceptions.RequestException:
-                 pass

-             # Check if process died
-             if VLLM_PROCESS.poll() is not None:
-                 # Process ended
-                 with open("/tmp/vllm.log", "r") as f:
-                     last_lines = f.readlines()[-20:]
-                 error_msg = "❌ **vLLM server crashed during startup**\n\n"
-                 error_msg += "**Last log lines:**\n```\n"
-                 error_msg += "".join(last_lines)
-                 error_msg += "\n```"
-                 return error_msg

-             time.sleep(2)
-
-         # Timeout but process still running
-         return "⚠️ **vLLM server started but taking longer than expected**\n\nThe server may still be initializing. Wait a few more minutes and try sending a message."
-
-     except Exception as e:
-         return f"❌ **Failed to start vLLM server:**\n\n{str(e)}"
-
- def view_logs():
-     """View vLLM server logs"""
-     try:
-         if not os.path.exists("/tmp/vllm.log"):
-             return "📝 No logs yet. Start the server first."
-
-         with open("/tmp/vllm.log", "r") as f:
-             lines = f.readlines()
-             last_lines = lines[-50:] # Last 50 lines
-
-         log_text = "📋 **vLLM Server Logs (Last 50 lines)**\n\n```\n"
-         log_text += "".join(last_lines)
-         log_text += "\n```"
-         return log_text
-     except Exception as e:
-         return f"❌ Error reading logs: {str(e)}"
-
- def chat(message, history, system_prompt, max_tokens, temperature, top_p):
-     """Send chat message to vLLM server"""
-     try:
-         # Build messages
-         messages = []
-
-         if system_prompt.strip():
-             messages.append({"role": "system", "content": system_prompt.strip()})
-
-         # Add history
-         for human, assistant in history:
-             messages.append({"role": "user", "content": human})
-             if assistant:
-                 messages.append({"role": "assistant", "content": assistant})
-
-         # Add current message
-         messages.append({"role": "user", "content": message})
-
-         # Call vLLM API
-         response = requests.post(
-             f"http://localhost:{VLLM_PORT}/v1/chat/completions",
-             headers={"Content-Type": "application/json"},
-             json={
-                 "model": MODEL_NAME,
-                 "messages": messages,
-                 "max_tokens": max_tokens,
-                 "temperature": temperature,
-                 "top_p": top_p,
-                 "stream": False
-             },
-             timeout=300
-         )

-         if response.status_code == 200:
-             result = response.json()
-             assistant_message = result["choices"][0]["message"]["content"]
-             return assistant_message
-         else:
-             return f"❌ Error: {response.status_code} - {response.text}"

-     except requests.exceptions.ConnectionError:
-         return "❌ Cannot connect to vLLM server. Please start the server first."
-     except Exception as e:
-         return f"❌ Error: {str(e)}"

- # Custom CSS
- custom_css = """
- .gradio-container {
-     max-width: 1200px !important;
- }
- """

- # Create Gradio interface
- with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, title="Kimi 48B Fine-tuned") as demo:
      gr.Markdown("""
-     # 🚀 Kimi Linear 48B A3B - Fine-tuned Inference

-     High-performance inference using **vLLM** for the fine-tuned Kimi-Linear-48B-A3B-Instruct model.

      **Model:** `optiviseapp/kimi-linear-48b-a3b-instruct-fine-tune`
      """)

      with gr.Row():
          with gr.Column(scale=1):
-             gr.Markdown("### 🎛️ Server Control")
-             start_btn = gr.Button("🚀 Start vLLM Server", variant="primary", size="lg")
-             server_status = gr.Markdown("**Status:** Server not started")
-             view_logs_btn = gr.Button("📋 View Server Logs", size="sm")
-             logs_display = gr.Markdown("", visible=False)

              gr.Markdown("---")
-             gr.Markdown("### ⚙️ Generation Settings")

              system_prompt = gr.Textbox(
-                 label="System Prompt (Optional)",
-                 placeholder="You are a helpful AI assistant...",
-                 lines=3,
-                 value=""
-             )
-
-             max_tokens = gr.Slider(
-                 minimum=50,
-                 maximum=4096,
-                 value=1024,
-                 step=1,
-                 label="Max Tokens"
              )

-             temperature = gr.Slider(
-                 minimum=0.0,
-                 maximum=2.0,
-                 value=0.7,
-                 step=0.05,
-                 label="Temperature"
-             )
-
-             top_p = gr.Slider(
-                 minimum=0.0,
-                 maximum=1.0,
-                 value=0.9,
-                 step=0.05,
-                 label="Top P"
-             )
-
-             gr.Markdown("""
-             ### 📖 Instructions
-
-             1. **Start Server** - Click the button above (takes 2-5 min)
-             2. **Wait for "✅"** - Server is ready when you see green checkmark
-             3. **Start Chatting** - Type your message below
-
-             **Note:** First message may be slow as the model loads into memory.
-             """)

          with gr.Column(scale=2):
              gr.Markdown("### 💬 Chat")
-
-             chatbot = gr.Chatbot(
-                 height=500,
-                 show_copy_button=True
-             )

              with gr.Row():
-                 msg = gr.Textbox(
-                     label="Your Message",
-                     placeholder="Type your message here...",
-                     lines=2,
-                     scale=4
-                 )
-                 send_btn = gr.Button("📤 Send", variant="primary", scale=1)

-             with gr.Row():
-                 clear_btn = gr.Button("🗑️ Clear Chat")
-
-     # Event handlers
-     start_btn.click(
-         fn=start_vllm_server,
-         outputs=server_status
-     )
-
-     def show_logs():
-         return {logs_display: gr.update(value=view_logs(), visible=True)}
-
-     view_logs_btn.click(
-         fn=show_logs,
-         outputs=logs_display
-     )
-
-     def user_message(user_msg, history):
-         return "", history + [[user_msg, None]]
-
-     def bot_response(history, system_prompt, max_tokens, temperature, top_p):
-         if not history or history[-1][1] is not None:
-             return history
-
-         user_msg = history[-1][0]
-         bot_msg = chat(user_msg, history[:-1], system_prompt, max_tokens, temperature, top_p)
-         history[-1][1] = bot_msg
-         return history

-     msg.submit(
-         user_message,
-         [msg, chatbot],
-         [msg, chatbot],
-         queue=False
-     ).then(
-         bot_response,
-         [chatbot, system_prompt, max_tokens, temperature, top_p],
-         chatbot
-     )

-     send_btn.click(
-         user_message,
-         [msg, chatbot],
-         [msg, chatbot],
-         queue=False
-     ).then(
-         bot_response,
-         [chatbot, system_prompt, max_tokens, temperature, top_p],
-         chatbot
-     )

-     clear_btn.click(lambda: None, None, chatbot, queue=False)

      gr.Markdown("""
      ---
-
-     **Powered by vLLM** - High-performance LLM inference engine
-
      **Model:** [optiviseapp/kimi-linear-48b-a3b-instruct-fine-tune](https://huggingface.co/optiviseapp/kimi-linear-48b-a3b-instruct-fine-tune)
      """)

- # Cleanup on exit
- def cleanup():
-     global VLLM_PROCESS
-     if VLLM_PROCESS:
-         try:
-             if sys.platform == 'win32':
-                 VLLM_PROCESS.terminate()
-             else:
-                 os.killpg(os.getpgid(VLLM_PROCESS.pid), signal.SIGTERM)
-         except:
-             pass
-
- import atexit
- atexit.register(cleanup)
-
  if __name__ == "__main__":
-     demo.queue()
-     demo.launch(
-         server_name="0.0.0.0",
-         server_port=7860,
-         share=True,
-         show_error=True
-     )
  import gradio as gr
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
  import os

  # Model configuration
  MODEL_NAME = "optiviseapp/kimi-linear-48b-a3b-instruct-fine-tune"

+ class ChatBot:
+     def __init__(self):
+         self.model = None
+         self.tokenizer = None
+         self.loaded = False
+
+     def load_model(self):
+         if self.loaded:
+             return "✅ Model already loaded!"

+         try:
+             yield "🔄 Loading tokenizer..."
+             self.tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)

+             yield "🔄 Loading model (this takes 5-10 minutes)...\n\nThe 48B model is being distributed across 4 GPUs..."

+             # Configure memory for 4 GPUs
+             num_gpus = torch.cuda.device_count()
+             max_memory = {i: f"{int(23)}GB" for i in range(num_gpus)} # L4 has 24GB, leave 1GB
+
+             self.model = AutoModelForCausalLM.from_pretrained(
+                 MODEL_NAME,
+                 torch_dtype=torch.bfloat16,
+                 device_map="balanced", # Distribute evenly
+                 max_memory=max_memory,
+                 trust_remote_code=True,
+                 low_cpu_mem_usage=True,
+             )
+
+             self.model.eval()
+             self.loaded = True
+
+             # Get GPU distribution info
+             if hasattr(self.model, 'hf_device_map'):
+                 device_info = "\n\n**GPU Distribution:**\n"
+                 devices = {}
+                 for name, device in self.model.hf_device_map.items():
+                     if device not in devices:
+                         devices[device] = 0
+                     devices[device] += 1
+                 for device, count in devices.items():
+                     device_info += f"- {device}: {count} layers\n"
+             else:
+                 device_info = ""
+
+             yield f"✅ **Model loaded successfully!**{device_info}\n\nYou can now start chatting below."
+
+         except Exception as e:
+             self.loaded = False
+             yield f"❌ **Error loading model:**\n\n{str(e)}"
+
+     def chat(self, message, history, system_prompt, max_tokens, temperature, top_p):
+         if not self.loaded:
+             return "❌ Please load the model first by clicking the 'Load Model' button."

+         try:
+             # Build prompt from history
+             conversation = []
+             if system_prompt.strip():
+                 conversation.append(f"System: {system_prompt}")
+
+             for user_msg, bot_msg in history:
+                 conversation.append(f"User: {user_msg}")
+                 if bot_msg:
+                     conversation.append(f"Assistant: {bot_msg}")
+
+             conversation.append(f"User: {message}")
+             conversation.append("Assistant:")
+
+             prompt = "\n".join(conversation)
+
+             # Tokenize
+             inputs = self.tokenizer(prompt, return_tensors="pt")
+             inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
+
+             # Generate
+             with torch.no_grad():
+                 outputs = self.model.generate(
+                     **inputs,
+                     max_new_tokens=max_tokens,
+                     temperature=temperature,
+                     top_p=top_p,
+                     do_sample=temperature > 0,
+                     pad_token_id=self.tokenizer.eos_token_id,
+                 )

+             # Decode
+             response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+             # Extract assistant response
+             if "Assistant:" in response:
+                 response = response.split("Assistant:")[-1].strip()
+
+             return response
+
+         except Exception as e:
+             return f"❌ Error: {str(e)}"

+ # Initialize
+ bot = ChatBot()

+ # UI
+ with gr.Blocks(theme=gr.themes.Soft(), title="Kimi 48B Fine-tuned") as demo:
      gr.Markdown("""
+     # 🚀 Kimi Linear 48B A3B - Fine-tuned

+     Chat interface for the fine-tuned Kimi model.

      **Model:** `optiviseapp/kimi-linear-48b-a3b-instruct-fine-tune`
      """)

+     # Show GPU info
+     if torch.cuda.is_available():
+         gpu_count = torch.cuda.device_count()
+         gpu_name = torch.cuda.get_device_name(0)
+         total_vram = sum(torch.cuda.get_device_properties(i).total_memory / 1024**3 for i in range(gpu_count))
+         gr.Markdown(f"**Hardware:** {gpu_count}x {gpu_name} ({total_vram:.0f}GB total VRAM)")
+
      with gr.Row():
          with gr.Column(scale=1):
+             gr.Markdown("### 🎛️ Controls")
+
+             load_btn = gr.Button("🚀 Load Model", variant="primary", size="lg")
+             status = gr.Markdown("**Status:** Model not loaded")

              gr.Markdown("---")
+             gr.Markdown("### ⚙️ Settings")

              system_prompt = gr.Textbox(
+                 label="System Prompt",
+                 placeholder="You are a helpful assistant...",
+                 lines=2
              )

+             max_tokens = gr.Slider(50, 2048, 512, label="Max Tokens", step=1)
+             temperature = gr.Slider(0, 2, 0.7, label="Temperature", step=0.1)
+             top_p = gr.Slider(0, 1, 0.9, label="Top P", step=0.05)

          with gr.Column(scale=2):
              gr.Markdown("### 💬 Chat")
+             chatbot = gr.Chatbot(height=500, show_copy_button=True)

              with gr.Row():
+                 msg = gr.Textbox(label="Message", placeholder="Type here...", scale=4)
+                 send = gr.Button("Send", variant="primary", scale=1)

+             clear = gr.Button("Clear")

+     # Events
+     load_btn.click(bot.load_model, outputs=status)

+     def respond(message, history, system, max_tok, temp, top):
+         bot_message = bot.chat(message, history, system, max_tok, temp, top)
+         history.append((message, bot_message))
+         return history, ""

+     msg.submit(respond, [msg, chatbot, system_prompt, max_tokens, temperature, top_p], [chatbot, msg])
+     send.click(respond, [msg, chatbot, system_prompt, max_tokens, temperature, top_p], [chatbot, msg])
+     clear.click(lambda: None, None, chatbot)

      gr.Markdown("""
      ---
      **Model:** [optiviseapp/kimi-linear-48b-a3b-instruct-fine-tune](https://huggingface.co/optiviseapp/kimi-linear-48b-a3b-instruct-fine-tune)
      """)

  if __name__ == "__main__":
+     demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
requirements.txt CHANGED
@@ -1,12 +1,11 @@
- # vLLM for high-performance inference
- vllm>=0.6.0

- # Core dependencies
  gradio==4.19.2
- requests>=2.31.0

- # Note: vLLM automatically installs:
- # - torch
- # - transformers
- # - tokenizers
- # - etc.

+ # Core ML dependencies
+ torch>=2.1.0
+ transformers>=4.56.0
+ accelerate>=0.34.0
+ sentencepiece>=0.1.99

+ # UI
  gradio==4.19.2

+ # Utils
+ safetensors>=0.4.0
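
The pinned stack replaces vLLM entirely, so torch, transformers, and accelerate now have to be declared explicitly rather than arriving as transitive dependencies. A quick pre-flight check of that environment (a hypothetical helper, not part of this commit) might look like:

```python
# Hypothetical pre-flight check for the new dependency stack (not part of this commit).
import torch
import transformers
import accelerate

print("transformers:", transformers.__version__)  # requirements.txt asks for >=4.56.0
print("accelerate:", accelerate.__version__)       # required for device_map sharding
print("CUDA available:", torch.cuda.is_available())
for i in range(torch.cuda.device_count()):
    props = torch.cuda.get_device_properties(i)
    print(f"GPU {i}: {props.name}, {props.total_memory / 1024**3:.0f} GB")
```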