Spaces:
Runtime error
import gradio as gr
import mlx.core as mx

import utils


def load_model(model_path, adapter_path):
    """Load the base model and tokenizer, optionally applying adapter weights.

    Args:
        model_path: Path to the base model directory (passed to ``utils.load``).
        adapter_path: Path to an ``.npz`` file of fine-tuned adapter weights;
            a falsy value skips adapter loading.

    Returns:
        A ``(model, tokenizer)`` tuple. Adapter-loading failures are logged
        and the base model is returned unchanged (best-effort by design).
    """
    model, tokenizer, _ = utils.load(model_path)
    if adapter_path:
        try:
            adapter_weights = mx.load(adapter_path)
            # load_weights accepts a list of (name, array) pairs; strict=False
            # tolerates adapter keys absent from the base model.
            # NOTE: the previous filter `k in model.parameters()` compared
            # dotted flat weight paths against the *top-level* keys of the
            # nested parameters dict, silently dropping every adapter weight.
            model.load_weights(list(adapter_weights.items()), strict=False)
            print(f"Loaded adapter weights from {adapter_path}")
        except Exception as e:
            print(f"Error loading adapter weights: {str(e)}")
    return model, tokenizer
# Generate response
def generate_response(model, tokenizer, prompt, max_tokens, temperature):
    """Generate up to ``max_tokens`` tokens of model output for ``prompt``.

    Args:
        model: The loaded MLX model (passed through to ``utils.generate``).
        tokenizer: Tokenizer providing ``encode``/``decode`` and ``eos_token_id``.
        prompt: Text prompt to condition generation on.
        max_tokens: Hard cap on the number of generated tokens.
        temperature: Sampling temperature forwarded to ``utils.generate``.

    Returns:
        The decoded generated text, excluding the EOS token.
    """
    prompt_tokens = mx.array(tokenizer.encode(prompt))
    generated_tokens = []
    for token in utils.generate(prompt_tokens, model, temperature):
        token_id = token.item()  # hoisted: previously called .item() twice per step
        if token_id == tokenizer.eos_token_id:
            # Stop *before* appending so the EOS marker never appears in the
            # decoded output (the original appended it first, then broke).
            break
        generated_tokens.append(token_id)
        if len(generated_tokens) >= max_tokens:
            break
    return tokenizer.decode(generated_tokens)
# Inference function
def infer(question, max_tokens, temperature):
    """Wrap ``question`` in a Q/A prompt and return the model's answer.

    Uses the module-level ``model`` and ``tokenizer`` loaded at import time.
    """
    qa_prompt = f"Q: {question}\nA:"
    return generate_response(model, tokenizer, qa_prompt, max_tokens, temperature)
# Load the model and tokenizer once, at module import time, so every
# Gradio request reuses the same instances instead of reloading.
model_path = "./phi-2"  # Update this with the actual path to your model
adapter_path = "./adapters.npz"  # Update this with the actual path to your adapters
model, tokenizer = load_model(model_path, adapter_path)

# Build the Gradio interface: a question box plus two sampling controls,
# wired to the infer() function above.
iface = gr.Interface(
    fn=infer,
    inputs=[
        gr.Textbox(lines=2, placeholder="Enter your question here..."),
        gr.Slider(minimum=1, maximum=500, value=100, step=1, label="Max Tokens"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature"),
    ],
    outputs="text",
    title="Fine-tuned Phi-2 Q&A Demo",
    description="Ask a question and get an answer from the fine-tuned Phi-2 model. Finetuned on OASST1 dataset.",
)

# Launch the web server (blocks until the app is stopped).
iface.launch()