import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Fine-tuned checkpoint: microsoft/phi-2 finetuned on the OpenAssistant/oasst1 dataset.
model_path = "gupta1912/phi-2-custom-oasst1"

# trust_remote_code=True is needed because the phi-2 repo ships custom modeling code.
# NOTE(review): loading happens at import time and downloads the weights on first run.
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_path)
# Cache for the text-generation pipeline so the (expensive) wrapper is built once,
# not on every request.
_generator = None


def _get_generator():
    """Build the text-generation pipeline lazily and reuse it across calls."""
    global _generator
    if _generator is None:
        _generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
    return _generator


def generate_text(prompt, response_length):
    """Generate a reply for *prompt* with the fine-tuned phi-2 model.

    Args:
        prompt: User query; coerced to ``str``.
        response_length: Maximum total generation length in tokens
            (passed as ``max_length``); coerced to ``int``.

    Returns:
        The model's answer — the text after the ``[/INST] `` marker, or the
        full generated text if the marker is missing (instead of crashing).
    """
    prompt = str(prompt)
    max_len = int(response_length)
    gen = _get_generator()
    # max_length is a per-call generation kwarg, so the cached pipeline can be shared.
    result = gen(f"<s>[INST] {prompt} [/INST]", max_length=max_len)
    full_text = result[0]['generated_text']
    # Original code did split(...)[1], which raises IndexError when the model
    # output lacks the marker; partition degrades gracefully.
    _, sep, answer = full_text.partition("[/INST] ")
    return answer if sep else full_text
def gradio_fn(prompt, response_length):
    """Gradio callback: forward the UI inputs to generate_text and return its reply."""
    return generate_text(prompt, response_length)
# Description rendered (as markdown) under the app title in the Gradio UI.
markdown_description = """
- This is a Gradio app that answers the query you ask it
- Uses **microsoft/phi-2** model finetuned on **OpenAssistant/oasst1** dataset
"""

# Two inputs (prompt textbox + response-length slider), one text output.
demo = gr.Interface(
    fn=gradio_fn,
    inputs=[
        gr.Textbox(info="How may I help you ? please enter your prompt here..."),
        gr.Slider(
            value=50,
            minimum=50,
            maximum=300,
            info="Choose a response length min chars=50, max=300",
        ),
    ],
    outputs=gr.Textbox(),
    title="custom trained phi2 - Dialog Partner",
    description=markdown_description,
)

# queue() serializes concurrent requests; share=True exposes a temporary public URL.
demo.queue().launch(share=True, debug=True)