import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import os

# Model configuration
MODEL_NAME = "optiviseapp/kimi-linear-48b-a3b-instruct-fine-tune"


class ChatBot:
    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.loaded = False
    def load_model(self):
        if self.loaded:
            yield "✅ Model already loaded!"
            return
        try:
            yield "⏳ Loading tokenizer..."
            self.tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)

            yield "⏳ Loading model (this takes 5-10 minutes)...\n\nThe 48B model is being distributed across 4 GPUs..."

            # Configure memory for 4 GPUs: each L4 has 24GB, leave ~1GB headroom
            num_gpus = torch.cuda.device_count()
            max_memory = {i: "23GB" for i in range(num_gpus)}

            self.model = AutoModelForCausalLM.from_pretrained(
                MODEL_NAME,
                torch_dtype=torch.bfloat16,
                device_map="balanced",  # distribute evenly across GPUs
                max_memory=max_memory,
                trust_remote_code=True,
                low_cpu_mem_usage=True,
            )
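            # device_map="balanced" lets accelerate spread the checkpoint's layers
            # evenly across all visible GPUs, subject to the per-device max_memory
            # caps above; the resulting placement is recorded in model.hf_device_map,
            # which is what the status report below reads.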
            self.model.eval()
            self.loaded = True

            # Get GPU distribution info
            if hasattr(self.model, "hf_device_map"):
                device_info = "\n\n**GPU Distribution:**\n"
                devices = {}
                for name, device in self.model.hf_device_map.items():
                    if device not in devices:
                        devices[device] = 0
                    devices[device] += 1
                for device, count in devices.items():
                    device_info += f"- {device}: {count} layers\n"
            else:
                device_info = ""

            yield f"✅ **Model loaded successfully!**{device_info}\n\nYou can now start chatting below."
        except Exception as e:
            self.loaded = False
            yield f"❌ **Error loading model:**\n\n{str(e)}"
    def chat(self, message, history, system_prompt, max_tokens, temperature, top_p):
        if not self.loaded:
            return "❌ Please load the model first by clicking the 'Load Model' button."
        try:
            # Build prompt from history
            conversation = []
            if system_prompt.strip():
                conversation.append(f"System: {system_prompt}")
            for user_msg, bot_msg in history:
                conversation.append(f"User: {user_msg}")
                if bot_msg:
                    conversation.append(f"Assistant: {bot_msg}")
            conversation.append(f"User: {message}")
            conversation.append("Assistant:")
            prompt = "\n".join(conversation)
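            # Note: this hand-rolls a plain "System:/User:/Assistant:" prompt. If the
            # fine-tuned tokenizer ships a chat template, the prompt could instead be
            # built with it (a sketch, assuming such a template exists):
            #   messages = [{"role": "system", "content": system_prompt}] if system_prompt.strip() else []
            #   for user_msg, bot_msg in history:
            #       messages.append({"role": "user", "content": user_msg})
            #       if bot_msg:
            #           messages.append({"role": "assistant", "content": bot_msg})
            #   messages.append({"role": "user", "content": message})
            #   prompt = self.tokenizer.apply_chat_template(
            #       messages, tokenize=False, add_generation_prompt=True
            #   )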
            # Tokenize
            inputs = self.tokenizer(prompt, return_tensors="pt")
            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

            # Generate
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    do_sample=temperature > 0,
                    pad_token_id=self.tokenizer.eos_token_id,
                )

            # Decode
            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
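            # Note: for decoder-only models, generate() returns the prompt followed by
            # the continuation, so an alternative to the "Assistant:" split below is to
            # decode only the newly generated tokens:
            #   new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
            #   response = self.tokenizer.decode(new_tokens, skip_special_tokens=True)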
            # Extract assistant response
            if "Assistant:" in response:
                response = response.split("Assistant:")[-1].strip()
            return response
        except Exception as e:
            return f"❌ Error: {str(e)}"

# Initialize
bot = ChatBot()

# UI
with gr.Blocks(theme=gr.themes.Soft(), title="Kimi 48B Fine-tuned") as demo:
    gr.Markdown("""
    # 🚀 Kimi Linear 48B A3B - Fine-tuned

    Chat interface for the fine-tuned Kimi model.

    **Model:** `optiviseapp/kimi-linear-48b-a3b-instruct-fine-tune`
    """)

    # Show GPU info
    if torch.cuda.is_available():
        gpu_count = torch.cuda.device_count()
        gpu_name = torch.cuda.get_device_name(0)
        total_vram = sum(torch.cuda.get_device_properties(i).total_memory / 1024**3 for i in range(gpu_count))
        gr.Markdown(f"**Hardware:** {gpu_count}x {gpu_name} ({total_vram:.0f}GB total VRAM)")
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 🎛️ Controls")
            load_btn = gr.Button("🚀 Load Model", variant="primary", size="lg")
            status = gr.Markdown("**Status:** Model not loaded")

            gr.Markdown("---")
            gr.Markdown("### ⚙️ Settings")
            system_prompt = gr.Textbox(
                label="System Prompt",
                placeholder="You are a helpful assistant...",
                lines=2
            )
            max_tokens = gr.Slider(50, 2048, 512, label="Max Tokens", step=1)
            temperature = gr.Slider(0, 2, 0.7, label="Temperature", step=0.1)
            top_p = gr.Slider(0, 1, 0.9, label="Top P", step=0.05)

        with gr.Column(scale=2):
            gr.Markdown("### 💬 Chat")
            chatbot = gr.Chatbot(height=500, show_copy_button=True)
            with gr.Row():
                msg = gr.Textbox(label="Message", placeholder="Type here...", scale=4)
                send = gr.Button("Send", variant="primary", scale=1)
            clear = gr.Button("Clear")
    # Events
    load_btn.click(bot.load_model, outputs=status)

    def respond(message, history, system, max_tok, temp, top):
        bot_message = bot.chat(message, history, system, max_tok, temp, top)
        history.append((message, bot_message))
        return history, ""
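    # Note: respond() and ChatBot.chat() use the legacy list-of-(user, assistant)
    # tuples that gr.Chatbot accepts by default; newer Gradio releases prefer
    # gr.Chatbot(type="messages") with {"role": ..., "content": ...} dicts, which
    # would require adapting both functions.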
    msg.submit(respond, [msg, chatbot, system_prompt, max_tokens, temperature, top_p], [chatbot, msg])
    send.click(respond, [msg, chatbot, system_prompt, max_tokens, temperature, top_p], [chatbot, msg])
    clear.click(lambda: None, None, chatbot)

    gr.Markdown("""
    ---
    **Model:** [optiviseapp/kimi-linear-48b-a3b-instruct-fine-tune](https://huggingface.co/optiviseapp/kimi-linear-48b-a3b-instruct-fine-tune)
    """)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
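# Note: share=True only matters when this script is run locally; on Hugging Face
# Spaces the app is already served publicly, so no share link is created there.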