import json
from typing import Optional

import gradio as gr
import spaces
import torch
from transformers import pipeline

# Global cache of loaded pipelines, keyed by model id
model_cache = {}
# Available models
AVAILABLE_MODELS = {
    "Apollo-1-4B": "Loom-Labs/Apollo-1-4B",
    "Apollo-1-8B": "Loom-Labs/Apollo-1-8B",
    "Apollo-1-2B": "Loom-Labs/Apollo-1-2B",
    "Daedalus-1-2B": "Loom-Labs/Daedalus-1-2B",
    "Daedalus-1-8B": "Loom-Labs/Daedalus-1-8B",
}
def initialize_model(model_name):
    global model_cache
    if model_name not in AVAILABLE_MODELS:
        raise ValueError(f"Model {model_name} not found in available models")
    model_id = AVAILABLE_MODELS[model_name]
    # Load and cache the pipeline on first use
    if model_id not in model_cache:
        try:
            model_cache[model_id] = pipeline(
                "text-generation",
                model=model_id,
                torch_dtype=torch.float16,
                device_map="auto",
                trust_remote_code=True,
            )
        except Exception:
            # Fall back to CPU (full precision) if GPU loading fails
            model_cache[model_id] = pipeline(
                "text-generation",
                model=model_id,
                torch_dtype=torch.float32,
                device_map="cpu",
                trust_remote_code=True,
            )
    return model_cache[model_id]
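
# Optional warm-up sketch (an assumption, not part of the original flow): on
# hardware with enough memory, preloading one model at startup avoids the
# first-request download/load latency. Uncomment to enable:
#   initialize_model("Apollo-1-2B")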
@spaces.GPU
def generate_response(message, history, model_name, max_length=512, temperature=0.7, top_p=0.9):
    """Generate a response using the selected model."""
    # Initialize the model inside the GPU-decorated function
    try:
        model_pipe = initialize_model(model_name)
    except Exception as e:
        return f"Error loading model {model_name}: {str(e)}"

    # Rebuild the conversation as a list of chat messages
    messages = []
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    # Add the current message
    messages.append({"role": "user", "content": message})

    # Generate the response
    try:
        # Some models may not support the messages format, so try both
        try:
            # Chat-style input first; max_new_tokens bounds the reply length
            # without counting prompt tokens (max_length would include them)
            response = model_pipe(
                messages,
                max_new_tokens=max_length,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
                pad_token_id=model_pipe.tokenizer.eos_token_id,
                return_full_text=False,
            )
        except Exception:
            # Fall back to a flat prompt string
            conversation_text = ""
            for msg in messages:
                if msg["role"] == "user":
                    conversation_text += f"User: {msg['content']}\n"
                else:
                    conversation_text += f"Assistant: {msg['content']}\n"
            conversation_text += "Assistant:"
            response = model_pipe(
                conversation_text,
                max_new_tokens=max_length,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
                pad_token_id=model_pipe.tokenizer.eos_token_id,
                return_full_text=False,
            )

        # Extract the generated text
        if isinstance(response, list) and len(response) > 0:
            generated_text = response[0]["generated_text"]
        else:
            generated_text = str(response)

        # Chat pipelines may return a message list; take the last turn
        if isinstance(generated_text, list):
            assistant_response = generated_text[-1]["content"]
        else:
            # Strip any echoed prompt and keep only the assistant's reply
            assistant_response = str(generated_text).strip()
            if "Assistant:" in assistant_response:
                assistant_response = assistant_response.split("Assistant:")[-1].strip()

        return assistant_response
    except Exception as e:
        return f"Error generating response: {str(e)}"
def generate(
    model: str,
    user_input: str,
    history: Optional[str] = "",
    temperature: float = 0.7,
    system_prompt: Optional[str] = "",
    max_tokens: int = 512,
):
    """
    API endpoint for LLM generation.

    Args:
        model: Model name to use (one of AVAILABLE_MODELS, e.g. Apollo-1-2B,
            Apollo-1-4B, Apollo-1-8B, Daedalus-1-2B, or Daedalus-1-8B)
        user_input: Current user message/input
        history: JSON string of conversation history in the format
            [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
        temperature: Temperature for generation (0.1-2.0)
        system_prompt: System prompt to guide the model
        max_tokens: Maximum tokens to generate (1-8192)

    Returns:
        Generated response from the model
    """
    # Validate the model name
    if model not in AVAILABLE_MODELS:
        return f"Error: Model {model} not available. Available models: {list(AVAILABLE_MODELS.keys())}"

    # Load (or fetch from cache) up front so loading errors surface early
    try:
        initialize_model(model)
    except Exception as e:
        return f"Error loading model {model}: {str(e)}"

    # Parse the history JSON (if any) into Gradio-style [user, assistant] pairs
    gradio_history = []
    if history and history.strip():
        try:
            history_list = json.loads(history)
            current_pair = [None, None]
            for msg in history_list:
                if isinstance(msg, dict) and "role" in msg and "content" in msg:
                    if msg["role"] == "user":
                        if current_pair[0] is not None:
                            gradio_history.append([current_pair[0], current_pair[1]])
                        current_pair = [msg["content"], None]
                    elif msg["role"] == "assistant":
                        current_pair[1] = msg["content"]
            if current_pair[0] is not None:
                gradio_history.append([current_pair[0], current_pair[1]])
        except (json.JSONDecodeError, TypeError):
            # If history parsing fails, continue without history
            pass

    # Prepend the system prompt to the user input if provided
    final_user_input = user_input
    if system_prompt and system_prompt.strip():
        final_user_input = f"System: {system_prompt}\n\nUser: {user_input}"

    # Delegate to the chat generation function (top_p fixed at 0.9)
    return generate_response(final_user_input, gradio_history, model, max_tokens, temperature, 0.9)
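
# Illustrative direct call (every parameter shown exists in the signature
# above; the values themselves are examples only):
#   generate(
#       model="Apollo-1-2B",
#       user_input="Summarize nucleus sampling in one sentence.",
#       history='[{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}]',
#       temperature=0.7,
#       max_tokens=256,
#   )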
# Create the Gradio interface
def create_interface():
    with gr.Blocks(title="Multi-Model Chat") as demo:
        gr.Markdown("""
        # 🚀 Loom Labs Model Chat Interface
        Chat with models from Loom Labs.

        **Available Models:**
        - Apollo-1-4B (4 billion parameters)
        - Apollo-1-8B (8 billion parameters)
        - Apollo-1-2B (2 billion parameters)
        - Daedalus-1-2B (2 billion parameters)
        - Daedalus-1-8B (8 billion parameters)
        """)
        with gr.Row():
            model_selector = gr.Dropdown(
                choices=list(AVAILABLE_MODELS.keys()),
                value="Apollo-1-4B",
                label="Select Model",
                info="Choose which model to use for generation",
            )

        chatbot = gr.Chatbot(
            height=400,
            placeholder="Select a model and start chatting...",
            label="Chat",
        )

        msg = gr.Textbox(
            placeholder="Type your message here...",
            label="Message",
            lines=2,
        )

        with gr.Row():
            submit_btn = gr.Button("Send", variant="primary")
            clear_btn = gr.Button("Clear Chat", variant="secondary")

        with gr.Accordion("Advanced Settings", open=False):
            max_length = gr.Slider(
                minimum=200,
                maximum=8192,
                value=2048,
                step=50,
                label="Max Length",
                info="Maximum length of the generated response",
            )
            temperature = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=0.7,
                step=0.1,
                label="Temperature",
                info="Controls randomness in generation",
            )
            top_p = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.9,
                step=0.1,
                label="Top P",
                info="Controls diversity via nucleus sampling",
            )
        # Event handlers
        def user_message(message, history):
            # Append the user's turn and clear the input box
            return "", history + [[message, None]]

        def bot_response(history, model_name, max_len, temp, top_p):
            if history:
                user_msg = history[-1][0]  # renamed to avoid shadowing user_message()
                bot_message = generate_response(
                    user_msg,
                    history[:-1],
                    model_name,
                    max_len,
                    temp,
                    top_p,
                )
                history[-1][1] = bot_message
            return history

        def model_changed(model_name):
            return gr.update(placeholder=f"Chat with {model_name}...")

        # Wire up the events
        msg.submit(user_message, [msg, chatbot], [msg, chatbot]).then(
            bot_response, [chatbot, model_selector, max_length, temperature, top_p], chatbot
        )
        submit_btn.click(user_message, [msg, chatbot], [msg, chatbot]).then(
            bot_response, [chatbot, model_selector, max_length, temperature, top_p], chatbot
        )
        clear_btn.click(lambda: None, None, chatbot, queue=False)
        model_selector.change(model_changed, model_selector, chatbot)

    return demo
# Launch the app
if __name__ == "__main__":
    demo = create_interface()
    # Enable API and launch
    demo.launch(share=True)
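
# Hedged client sketch (not part of the original file): once the Space is
# live, the event handlers wired above can be driven remotely with
# gradio_client. The Space ID below is a placeholder; run client.view_api()
# to discover the actual endpoint names and signatures before calling predict().
#   from gradio_client import Client
#   client = Client("owner/space-name")  # placeholder Space ID
#   client.view_api()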