Spaces:
Build error
Build error
| import gradio as gr | |
def generate_text(context, num_samples, context_length, model_name):
    """Generate ``num_samples`` text continuations of ``context`` with a Pythia model.

    Args:
        context: Prompt text; coerced with ``str()``.
        num_samples: Number of samples to request; coerced with ``int()``.
        context_length: Forwarded to ``base.main`` as ``max_new_tokens``;
            coerced with ``int()``.
        model_name: One of the known model names; selects the checkpoint
            directory and is forwarded to ``base.main``.

    Returns:
        One string containing a titled section per generated sample, in the
        order returned by ``base.main``.

    Raises:
        ValueError: If ``model_name`` is not a known model (including ``None``,
            e.g. when no dropdown choice was made). Previously such input fell
            through the ``if/elif`` chain and failed with ``UnboundLocalError``.
    """
    from pathlib import Path

    model_name = str(model_name)

    # Both 160m variants load from the same checkpoint directory; only the
    # model_name forwarded to base.main differs between them.
    checkpoint_root = Path("/home/user/app/checkpoints/EleutherAI")
    checkpoint_subdirs = {
        "pythia_160m_deduped_custom": "pythia-160m-deduped",
        "pythia_160m_deduped_huggingface": "pythia-160m-deduped",
        "pythia_70m_deduped": "pythia-70m-deduped",
        "pythia_410m_deduped": "pythia-410m-deduped",
    }
    try:
        ckpt_dir = checkpoint_root / checkpoint_subdirs[model_name]
    except KeyError:
        raise ValueError(
            f"Unknown model_name {model_name!r}; expected one of "
            f"{sorted(checkpoint_subdirs)}"
        ) from None

    # Imported only after validation so an invalid model name fails fast
    # with a clear error instead of first loading the generation module.
    from base import main

    output_msg_list = main(
        prompt=str(context),
        checkpoint_dir=ckpt_dir,
        model_name=model_name,
        num_samples=int(num_samples),
        max_new_tokens=int(context_length),
    )

    # Each section: title line, blank line, sample text, blank line.
    return "".join(
        f"--Generated message : {idx} using the model : {model_name}--\n\n{msg}\n\n"
        for idx, msg in enumerate(output_msg_list, start=1)
    )
def gradio_fn(context, num_samples, context_length, model_name):
    """Gradio callback: hand the four UI inputs to ``generate_text``.

    Returns exactly the string produced by ``generate_text``.
    """
    return generate_text(
        context=context,
        num_samples=num_samples,
        context_length=context_length,
        model_name=model_name,
    )
# Markdown shown below the app title (passed as ``description=`` to gr.Interface).
# The length input is forwarded as max_new_tokens, so it is described in tokens,
# not characters.
markdown_description = """
- This Gradio app generates text from:
  - a text context (prompt),
  - a maximum number of new tokens to generate,
  - a number of samples,
  - using the selected GPT model.
- The following models are currently available:
  **(a)** pythia_160m_deduped_huggingface **(b)** pythia_160m_deduped_custom
  **(c)** pythia_410m_deduped **(d)** pythia_70m_deduped
"""
# Build the UI. The inputs map positionally onto gradio_fn's parameters:
# (context, num_samples, context_length, model_name).
demo = gr.Interface(
    fn=gradio_fn,
    inputs=[
        gr.Textbox(info="Start my passage with: 'I would like to'"),
        gr.Number(
            value=1,
            minimum=1,
            maximum=5,
            precision=0,  # deliver an integer sample count rather than a float
            info="Number of samples to be generated min=1, max=5",
        ),
        # This value is forwarded to base.main as max_new_tokens, so the
        # help text speaks of tokens rather than characters.
        gr.Slider(
            value=50,
            minimum=50,
            maximum=250,
            info="Max new tokens for passage min=50, max=250",
        ),
        gr.Dropdown(
            [
                "pythia_160m_deduped_huggingface",
                "pythia_160m_deduped_custom",
                "pythia_410m_deduped",
                "pythia_70m_deduped",
            ],
            # Without a default, submitting before picking a model sends None,
            # which generate_text cannot map to a checkpoint.
            value="pythia_160m_deduped_huggingface",
            multiselect=False,
            label="Model-Name",
            info="Pretrained model to be used for text generation",
        ),
    ],
    outputs=gr.Textbox(),
    title="DialogGen - Dialogue Generator",
    description=markdown_description,
    article=" **Credits** : https://github.com/Lightning-AI/lit-gpt ",
)
demo.launch(share=True)