| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import asyncio |
| |
|
| | import gradio as gr |
| |
|
| | from nemo.collections.nlp.modules.common.megatron.retrieval_services.util import ( |
| | convert_retrieved_to_md, |
| | request_data, |
| | text_generation, |
| | ) |
| |
|
# Public API of this module; `create_gen_function` is an internal helper of `get_demo`.
__all__ = [
    'RetroDemoWebApp',
    'get_demo',
]
| |
|
| |
|
def create_gen_function(port=5555):
    """Build a Gradio click callback that forwards a prompt to the text generation service.

    Args:
        port: Port of the text generation service, passed to ``text_generation`` as ``port``.

    Returns:
        A callable ``get_generation(prompt, greedy, add_BOS, token_to_gen, min_tokens,
        temp, top_p, top_k, repetition)`` that sends one generation request and returns
        the first entry of the response's ``'sentences'`` list.
    """

    def get_generation(prompt, greedy, add_BOS, token_to_gen, min_tokens, temp, top_p, top_k, repetition):
        data = {
            "sentences": [prompt],
            "tokens_to_generate": int(token_to_gen),
            "temperature": temp,
            "add_BOS": add_BOS,
            # Gradio documents Slider values as floats; top_k is a count, so send it
            # as an int like the other count-valued fields in this payload.
            "top_k": int(top_k),
            "top_p": top_p,
            "greedy": greedy,
            "all_probs": False,
            "repetition_penalty": repetition,
            "min_tokens_to_generate": int(min_tokens),
        }
        sentences = text_generation(data, port=port)['sentences']
        # Only a single prompt is sent, so only the first result is relevant.
        return sentences[0]

    return get_generation
| |
|
| |
|
def get_demo(share, username, password, server_port=5555, web_port=9889, loop=None):
    """Build and launch a Gradio demo for plain text generation.

    Args:
        share: Forwarded to ``demo.launch(share=...)``.
        username: Username of the Gradio ``auth`` tuple.
        password: Password of the Gradio ``auth`` tuple.
        server_port: Port of the text generation service called by the Submit button.
        web_port: Port the Gradio web server listens on (bound to ``0.0.0.0``).
        loop: Event loop to install for the calling thread before launching. When
            ``None``, the thread's current event loop is left untouched.
    """
    # Only install a loop when one is given: asyncio.set_event_loop(None) would
    # discard the thread's current loop, after which asyncio.get_event_loop()
    # raises RuntimeError (even in the main thread).
    if loop is not None:
        asyncio.set_event_loop(loop)
    with gr.Blocks() as demo:
        with gr.Row():
            # Left column: sampling / generation controls.
            with gr.Column(scale=2, width=200):
                greedy_flag = gr.Checkbox(label="Greedy")
                add_BOS = gr.Checkbox(label="Add BOS token", value=False)
                token_to_gen = gr.Number(label='Number of Tokens to generate', value=300, type=int)
                min_token_to_gen = gr.Number(label='Min number of Tokens to generate', value=1, type=int)
                temperature = gr.Slider(minimum=0.0, maximum=10.0, value=1.0, label='Temperature', step=0.1)
                top_p = gr.Slider(minimum=0.0, maximum=1.0, step=0.02, value=0.9, label='Top P')
                top_k = gr.Slider(minimum=0, maximum=10000, step=2, value=0, label='Top K')
                repetition_penality = gr.Slider(
                    minimum=1.0, maximum=5.0, step=0.02, value=1.2, label='Repetition penalty'
                )
            # Right column: prompt input, generated output and the submit button.
            with gr.Column(scale=1, min_width=800):
                input_prompt = gr.Textbox(
                    label="Input",
                    value="Ariel was playing basketball. 1 of her shots went in the hoop. 2 of her shots did not go in the hoop. How many shots were there in total?",
                    lines=5,
                )
                output_box = gr.Textbox(value="", label="Output")
                btn = gr.Button(value="Submit")
                # The input order must match get_generation's positional parameters.
                btn.click(
                    create_gen_function(server_port),
                    inputs=[
                        input_prompt,
                        greedy_flag,
                        add_BOS,
                        token_to_gen,
                        min_token_to_gen,
                        temperature,
                        top_p,
                        top_k,
                        repetition_penality,
                    ],
                    outputs=[output_box],
                )
    demo.launch(share=share, server_port=web_port, server_name='0.0.0.0', auth=(username, password))
| |
|
| |
|
class RetroDemoWebApp:
    """Gradio web UI for a RETRO text generation service and its "combo" retrieval service.

    Generation requests go to the text service. Retrieval-weight updates, new
    retrieval documents and index resets go to the combo service.

    Args:
        text_service_ip: Host of the text generation service.
        text_service_port: Port of the text generation service.
        combo_service_ip: Host of the combo retrieval service.
        combo_service_port: Port of the combo retrieval service.
    """

    def __init__(self, text_service_ip, text_service_port, combo_service_ip, combo_service_port):
        self.text_service_ip = text_service_ip
        self.text_service_port = text_service_port
        self.combo_service_ip = combo_service_ip
        self.combo_service_port = combo_service_port

    def get_retro_generation(
        self, prompt, greedy, add_BOS, token_to_gen, min_tokens, temp, top_p, top_k, repetition, neighbors, weight
    ):
        """Push the retrieval weight, then request a generation for ``prompt``.

        Returns:
            A tuple ``(generated_text, retrieved_markup)``: the first entry of the
            response's ``'sentences'`` and the response's ``'retrieved'`` field
            rendered by ``convert_retrieved_to_md``.
        """
        data = {
            "sentences": [prompt],
            "tokens_to_generate": int(token_to_gen),
            "temperature": temp,
            "add_BOS": add_BOS,
            # Gradio documents Slider values as floats; top_k is a count, so send it
            # as an int like the other count-valued fields in this payload.
            "top_k": int(top_k),
            "top_p": top_p,
            "greedy": greedy,
            "all_probs": False,
            "repetition_penalty": repetition,
            "min_tokens_to_generate": int(min_tokens),
            "neighbors": int(neighbors),
        }
        # The weight is sent to the combo service before the generation request.
        self.update_weight(weight)
        output_json = text_generation(data, self.text_service_ip, self.text_service_port)
        sentences = output_json['sentences']
        retrieved = output_json['retrieved']
        return sentences[0], convert_retrieved_to_md(retrieved)

    def update_weight(self, weight):
        """Send ``[weight, 1.0 - weight]`` to the combo service and return its response.

        The UI labels ``weight`` as the weight for the static retrieval DB. The
        complement presumably applies to the other (dynamic) DB; confirm against
        the combo service.
        """
        data = {"update_weight": [weight, 1.0 - weight]}
        return request_data(data, self.combo_service_ip, self.combo_service_port)

    def add_doc(self, doc, add_eos):
        """Send ``doc`` (with the ``add_eos`` flag) to the combo service and return its response."""
        data = {
            "sentences": [doc],
            "add_eos": add_eos,
        }
        return request_data(data, self.combo_service_ip, self.combo_service_port)

    def reset_index(self):
        """Send a reset request to the combo service and return its response."""
        data = {"reset": None}
        return request_data(data, self.combo_service_ip, self.combo_service_port)

    def run_demo(self, share, username, password, port):
        """Build and launch the RETRO Gradio UI.

        Args:
            share: Forwarded to ``demo.launch(share=...)``.
            username: Username of the Gradio ``auth`` tuple.
            password: Password of the Gradio ``auth`` tuple.
            port: Port the Gradio web server listens on (bound to ``0.0.0.0``).
        """
        with gr.Blocks(css="table, th, td { border: 1px solid blue; table-layout: fixed; width: 100%; }") as demo:
            with gr.Row():
                # Left column: generation controls plus retrieval-index management.
                with gr.Column(scale=2, width=200):
                    greedy_flag = gr.Checkbox(label="Greedy", value=True)
                    add_BOS = gr.Checkbox(label="Add BOS token", value=False)
                    token_to_gen = gr.Number(label='Number of Tokens to generate', value=30, type=int)
                    min_token_to_gen = gr.Number(label='Min number of Tokens to generate', value=1, type=int)
                    temperature = gr.Slider(minimum=0.0, maximum=10.0, value=1.0, label='Temperature', step=0.1)
                    top_p = gr.Slider(minimum=0.0, maximum=1.0, step=0.02, value=0.9, label='Top P')
                    top_k = gr.Slider(minimum=0, maximum=10000, step=2, value=0, label='Top K')
                    repetition_penality = gr.Slider(
                        minimum=1.0, maximum=5.0, step=0.02, value=1.2, label='Repetition penalty'
                    )
                    k_neighbors = gr.Slider(minimum=0, maximum=50, step=1, value=2, label='Retrieved Documents')
                    weight = gr.Slider(
                        minimum=0.0, maximum=1.0, value=1.0, label='Weight for the Static Retrieval DB', step=0.02
                    )
                    add_retrival_doc = gr.Textbox(label="Add New Retrieval Doc", value="", lines=5,)
                    add_EOS = gr.Checkbox(label="Add EOS token to Retrieval Doc", value=False)
                    with gr.Row():
                        add_btn = gr.Button(value="Add")
                        reset_btn = gr.Button(value="Reset Index")
                    # Shows the combo service's response to Add / Reset Index.
                    output_status = gr.Label(value='')
                    add_btn.click(self.add_doc, inputs=[add_retrival_doc, add_EOS], outputs=[output_status])
                    reset_btn.click(self.reset_index, inputs=[], outputs=[output_status])

                # Right column: prompt, generated output and the retrieved neighbors.
                with gr.Column(scale=1, min_width=800):
                    input_prompt = gr.Textbox(
                        label="Input",
                        value="Ariel was playing basketball. 1 of her shots went in the hoop. 2 of her shots did not go in the hoop. How many shots were there in total?",
                        lines=5,
                    )
                    output_box = gr.Textbox(value="", label="Output")
                    btn = gr.Button(value="Submit")
                    output_retrieval = gr.HTML()
                    # The input order must match get_retro_generation's positional parameters.
                    btn.click(
                        self.get_retro_generation,
                        inputs=[
                            input_prompt,
                            greedy_flag,
                            add_BOS,
                            token_to_gen,
                            min_token_to_gen,
                            temperature,
                            top_p,
                            top_k,
                            repetition_penality,
                            k_neighbors,
                            weight,
                        ],
                        outputs=[output_box, output_retrieval],
                    )
        demo.launch(share=share, server_port=port, server_name='0.0.0.0', auth=(username, password))
| |
|