# NeMo / nemo/collections/nlp/modules/common/megatron_web_server.py
# camenduru's picture
# thanks to NVIDIA ❤
# 7934b29
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import gradio as gr
from nemo.collections.nlp.modules.common.megatron.retrieval_services.util import (
convert_retrieved_to_md,
request_data,
text_generation,
)
# Public API of this module: the Retro demo web app class and the plain
# text-generation demo launcher.
__all__ = ['RetroDemoWebApp', 'get_demo']
def create_gen_function(port=5555):
    """Build the Gradio callback that forwards a prompt to the text generation server.

    Args:
        port: port handed to ``text_generation`` on every request.

    Returns:
        A callable taking the generation controls (in the order ``get_demo``
        wires them up) and returning the first sentence of the server reply.
    """

    def get_generation(prompt, greedy, add_BOS, token_to_gen, min_tokens, temp, top_p, top_k, repetition):
        # Token counts are coerced to int before being serialized into the request.
        payload = dict(
            sentences=[prompt],
            tokens_to_generate=int(token_to_gen),
            temperature=temp,
            add_BOS=add_BOS,
            top_k=top_k,
            top_p=top_p,
            greedy=greedy,
            all_probs=False,
            repetition_penalty=repetition,
            min_tokens_to_generate=int(min_tokens),
        )
        reply = text_generation(payload, port=port)
        return reply['sentences'][0]

    return get_generation
def get_demo(share, username, password, server_port=5555, web_port=9889, loop=None):
    """Build and launch the Gradio UI for the plain text-generation server.

    Args:
        share: forwarded to ``demo.launch(share=...)``.
        username: user name of the single ``auth`` pair protecting the UI.
        password: password of the single ``auth`` pair protecting the UI.
        server_port: port of the text generation server; passed to
            ``create_gen_function`` for the Submit callback.
        web_port: port the UI is served on (bound to ``0.0.0.0``).
        loop: event loop to install for the calling thread. When ``None``
            the thread's current loop is left untouched.
    """
    # Only install an explicit loop: asyncio.set_event_loop(None) would clear
    # whatever loop the calling thread already has.
    if loop is not None:
        asyncio.set_event_loop(loop)
    with gr.Blocks() as demo:
        with gr.Row():
            # `width` is not a gr.Column argument; `min_width` is the supported knob.
            with gr.Column(scale=2, min_width=200):
                greedy_flag = gr.Checkbox(label="Greedy")
                add_BOS = gr.Checkbox(label="Add BOS token", value=False)
                # `type=int` is not a gr.Number argument (ignored or rejected depending on
                # the Gradio version); `precision=0` makes the widget yield integers.
                token_to_gen = gr.Number(label='Number of Tokens to generate', value=300, precision=0)
                min_token_to_gen = gr.Number(label='Min number of Tokens to generate', value=1, precision=0)
                temperature = gr.Slider(minimum=0.0, maximum=10.0, value=1.0, label='Temperature', step=0.1)
                top_p = gr.Slider(minimum=0.0, maximum=1.0, step=0.02, value=0.9, label='Top P')
                top_k = gr.Slider(minimum=0, maximum=10000, step=2, value=0, label='Top K')
                repetition_penalty = gr.Slider(
                    minimum=1.0, maximum=5.0, step=0.02, value=1.2, label='Repetition penalty'
                )
            with gr.Column(scale=1, min_width=800):
                input_prompt = gr.Textbox(
                    label="Input",
                    value="Ariel was playing basketball. 1 of her shots went in the hoop. 2 of her shots did not go in the hoop. How many shots were there in total?",
                    lines=5,
                )
                output_box = gr.Textbox(value="", label="Output")
                btn = gr.Button(value="Submit")
                # Input order must match the positional parameters of the generated callback.
                btn.click(
                    create_gen_function(server_port),
                    inputs=[
                        input_prompt,
                        greedy_flag,
                        add_BOS,
                        token_to_gen,
                        min_token_to_gen,
                        temperature,
                        top_p,
                        top_k,
                        repetition_penalty,
                    ],
                    outputs=[output_box],
                )
    demo.launch(share=share, server_port=web_port, server_name='0.0.0.0', auth=(username, password))
class RetroDemoWebApp:
    """Gradio front end for a Retro text service plus a combo retrieval service.

    Generation requests go to the text service. Weight updates, new documents
    and index resets go to the combo service.

    Args:
        text_service_ip: host of the text generation service.
        text_service_port: port of the text generation service.
        combo_service_ip: host of the combo retrieval service.
        combo_service_port: port of the combo retrieval service.
    """

    def __init__(self, text_service_ip, text_service_port, combo_service_ip, combo_service_port):
        self.text_service_ip = text_service_ip
        self.text_service_port = text_service_port
        self.combo_service_ip = combo_service_ip
        self.combo_service_port = combo_service_port

    def get_retro_generation(
        self, prompt, greedy, add_BOS, token_to_gen, min_tokens, temp, top_p, top_k, repetition, neighbors, weight
    ):
        """Generate a continuation of ``prompt`` and render the retrieved neighbors.

        The ``weight`` is pushed to the combo service first, then the
        generation request is sent to the text service.

        Returns:
            A tuple ``(first_generated_sentence, convert_retrieved_to_md(retrieved))``
            where ``retrieved`` is the ``'retrieved'`` field of the service reply.
        """
        data = {
            "sentences": [prompt],
            "tokens_to_generate": int(token_to_gen),
            "temperature": temp,
            "add_BOS": add_BOS,
            "top_k": top_k,
            "top_p": top_p,
            "greedy": greedy,
            "all_probs": False,
            "repetition_penalty": repetition,
            "min_tokens_to_generate": int(min_tokens),
            "neighbors": int(neighbors),
        }
        self.update_weight(weight)
        output_json = text_generation(data, self.text_service_ip, self.text_service_port)
        sentences = output_json['sentences']
        retrieved = output_json['retrieved']
        return sentences[0], convert_retrieved_to_md(retrieved)

    def update_weight(self, weight):
        """Send ``[weight, 1.0 - weight]`` to the combo service and return its reply.

        NOTE(review): presumably the first entry weights the static retrieval DB
        (cf. the 'Weight for the Static Retrieval DB' slider); confirm with the
        combo service.
        """
        data = {"update_weight": [weight, 1.0 - weight]}
        return request_data(data, self.combo_service_ip, self.combo_service_port)

    def add_doc(self, doc, add_eos):
        """Submit ``doc`` (with the ``add_eos`` flag) to the combo service; return its reply."""
        data = {
            "sentences": [doc],
            "add_eos": add_eos,
        }
        return request_data(data, self.combo_service_ip, self.combo_service_port)

    def reset_index(self):
        """Send a reset request to the combo service; return its reply."""
        data = {"reset": None}
        return request_data(data, self.combo_service_ip, self.combo_service_port)

    def run_demo(self, share, username, password, port):
        """Build and launch the Retro demo UI.

        Args:
            share: forwarded to ``demo.launch(share=...)``.
            username: user name of the single ``auth`` pair protecting the UI.
            password: password of the single ``auth`` pair protecting the UI.
            port: port the UI is served on (bound to ``0.0.0.0``).
        """
        with gr.Blocks(css="table, th, td { border: 1px solid blue; table-layout: fixed; width: 100%; }") as demo:
            with gr.Row():
                # `width` is not a gr.Column argument; `min_width` is the supported knob.
                with gr.Column(scale=2, min_width=200):
                    greedy_flag = gr.Checkbox(label="Greedy", value=True)
                    add_BOS = gr.Checkbox(label="Add BOS token", value=False)
                    # `type=int` is not a gr.Number argument (ignored or rejected depending on
                    # the Gradio version); `precision=0` makes the widget yield integers.
                    token_to_gen = gr.Number(label='Number of Tokens to generate', value=30, precision=0)
                    min_token_to_gen = gr.Number(label='Min number of Tokens to generate', value=1, precision=0)
                    temperature = gr.Slider(minimum=0.0, maximum=10.0, value=1.0, label='Temperature', step=0.1)
                    top_p = gr.Slider(minimum=0.0, maximum=1.0, step=0.02, value=0.9, label='Top P')
                    top_k = gr.Slider(minimum=0, maximum=10000, step=2, value=0, label='Top K')
                    repetition_penalty = gr.Slider(
                        minimum=1.0, maximum=5.0, step=0.02, value=1.2, label='Repetition penalty'
                    )
                    k_neighbors = gr.Slider(minimum=0, maximum=50, step=1, value=2, label='Retrieved Documents')
                    weight = gr.Slider(
                        minimum=0.0, maximum=1.0, value=1.0, label='Weight for the Static Retrieval DB', step=0.02
                    )
                    add_retrieval_doc = gr.Textbox(label="Add New Retrieval Doc", value="", lines=5,)
                    add_EOS = gr.Checkbox(label="Add EOS token to Retrieval Doc", value=False)
                    with gr.Row():
                        add_btn = gr.Button(value="Add")
                        reset_btn = gr.Button(value="Reset Index")
                    output_status = gr.Label(value='')
                    add_btn.click(self.add_doc, inputs=[add_retrieval_doc, add_EOS], outputs=[output_status])
                    reset_btn.click(self.reset_index, inputs=[], outputs=[output_status])

                with gr.Column(scale=1, min_width=800):
                    input_prompt = gr.Textbox(
                        label="Input",
                        value="Ariel was playing basketball. 1 of her shots went in the hoop. 2 of her shots did not go in the hoop. How many shots were there in total?",
                        lines=5,
                    )
                    output_box = gr.Textbox(value="", label="Output")
                    btn = gr.Button(value="Submit")
                    output_retrieval = gr.HTML()
                    # Input order must match the positional parameters of get_retro_generation.
                    btn.click(
                        self.get_retro_generation,
                        inputs=[
                            input_prompt,
                            greedy_flag,
                            add_BOS,
                            token_to_gen,
                            min_token_to_gen,
                            temperature,
                            top_p,
                            top_k,
                            repetition_penalty,
                            k_neighbors,
                            weight,
                        ],
                        outputs=[output_box, output_retrieval],
                    )
        demo.launch(share=share, server_port=port, server_name='0.0.0.0', auth=(username, password))