Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from peft import PeftModel | |
| import torch | |
# Load the base model and tokenizer
def load_model(
    base_model_id="microsoft/phi-2",
    adapter_id="satyanayak/PHI2-SFT-OASST1",
):
    """Load the base causal LM, apply the fine-tuned PEFT adapter, and load the tokenizer.

    Args:
        base_model_id: Hugging Face Hub id of the base model. The tokenizer is
            loaded from the same id.
        adapter_id: Hub id of the PEFT adapter applied on top of the base model.

    Returns:
        A ``(model, tokenizer)`` tuple: the ``PeftModel``-wrapped base model and
        the base model's tokenizer.
    """
    # float16 halves memory on a GPU, but half precision on CPU is slow and,
    # depending on the PyTorch version, some ops raise "not implemented for
    # 'Half'". Fall back to float32 when no CUDA device is available.
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32

    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_id,
        torch_dtype=dtype,
        device_map="auto",
        trust_remote_code=True,
    )
    # Load the fine-tuned adapter on top of the base weights.
    model = PeftModel.from_pretrained(
        base_model,
        adapter_id,
        torch_dtype=dtype,
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(
        base_model_id,
        trust_remote_code=True,
    )
    return model, tokenizer
# Generate response
def extract_assistant_reply(generated_text, stop_markers=("### Human:", "\nHuman:")):
    """Return only the first assistant turn from freshly generated text.

    The model is prompted with a ``Human: ...\\nAssistant:`` transcript, so it
    may keep going and invent further dialogue turns. Everything from the
    earliest stop marker onward is dropped, then surrounding whitespace is
    stripped.

    Args:
        generated_text: Decoded text of the newly generated tokens only
            (the prompt must already be excluded).
        stop_markers: Substrings that signal the start of a new user turn.
            "### Human:" is an assumption about the adapter's training
            transcript format (common for OASST1 derivatives) — adjust if needed.

    Returns:
        The trimmed reply, or ``""`` if nothing precedes the first marker.
    """
    cut = len(generated_text)
    for marker in stop_markers:
        idx = generated_text.find(marker)
        if idx != -1:
            cut = min(cut, idx)
    return generated_text[:cut].strip()


def generate_response(prompt, max_length=512, temperature=0.7, top_p=0.9):
    """Generate a sampled reply from the fine-tuned model for one user prompt.

    Reads the module-level ``model`` and ``tokenizer`` globals.

    Args:
        prompt: User message, wrapped as ``Human: {prompt}\\nAssistant:``.
        max_length: Upper bound on prompt + generated tokens (Hugging Face
            ``max_length`` semantics, not new tokens only). Cast to ``int``
            because UI sliders may deliver floats.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Returns:
        The assistant's reply text, without the prompt and without any
        follow-up dialogue turns the model invented.
    """
    inputs = tokenizer(f"Human: {prompt}\nAssistant:", return_tensors="pt").to(model.device)
    prompt_len = inputs["input_ids"].shape[-1]
    outputs = model.generate(
        **inputs,
        max_length=int(max_length),
        temperature=float(temperature),
        top_p=float(top_p),
        do_sample=True,
        # The tokenizer defines no pad token; reuse EOS to silence the warning.
        pad_token_id=tokenizer.eos_token_id,
    )
    # Decode only the tokens produced after the prompt. Splitting the full
    # decode on the *last* "Assistant:" would return an invented later turn
    # whenever the model continued the conversation on its own.
    response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    return extract_assistant_reply(response)
# Example rows for the Interface: every row supplies a value for each input,
# in order (prompt, maximum length, temperature, top-p). All rows share the
# same default generation settings.
EXAMPLE_PROMPTS = [
    [question, 512, 0.7, 0.9]
    for question in (
        "What is the capital of France?",
        "Write a short poem about autumn.",
        "Explain quantum computing in simple terms.",
        "Give me a recipe for chocolate chip cookies.",
        "What are the benefits of regular exercise?",
    )
]
# Load model and tokenizer once at import time; generate_response reads the
# resulting module-level `model` and `tokenizer` globals.
# flush=True: when stdout is not a terminal (e.g. container logs), Python
# block-buffers it. Without flushing, these progress messages could appear
# only after the slow model load has already finished.
print("Loading model...", flush=True)
model, tokenizer = load_model()
print("Model loaded!", flush=True)
# Create Gradio interface. Input components map, in order, onto the
# parameters of generate_response(prompt, max_length, temperature, top_p).
_prompt_box = gr.Textbox(
    label="Enter your prompt",
    placeholder="Type your message here...",
    lines=4,
)
_length_slider = gr.Slider(minimum=64, maximum=1024, value=512, step=64, label="Maximum Length")
_temperature_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature")
_top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.1, label="Top P")
_response_box = gr.Textbox(label="Response", lines=10)

demo = gr.Interface(
    fn=generate_response,
    inputs=[_prompt_box, _length_slider, _temperature_slider, _top_p_slider],
    outputs=_response_box,
    examples=EXAMPLE_PROMPTS,
    title="Phi-2 Assistant",
    description=(
        "This is a fine-tuned version of Phi-2 on the OpenAssistant dataset. "
        "Enter your prompt and adjust generation parameters as needed."
    ),
)

# Launch the app only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch(share=True)