# NeMo / nemo/collections/nlp/modules/common/text_generation_server.py
# (source-page residue: camenduru's picture; commit 7934b29 "thanks to NVIDIA ❤")
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for generating text."""
import json
import threading
import torch
from flask import Flask, jsonify, request
from flask_restful import Api, Resource
from nemo.collections.nlp.modules.common.retro_inference_strategies import (
RetroModelTextGenerationStrategy,
RetroQAModelTextGenerationStrategy,
)
from nemo.collections.nlp.modules.common.text_generation_utils import generate
from nemo.utils import logging
# Value broadcast by MegatronGenerate.send_do_generate to signal a generate call to other ranks.
GENERATE_NUM = 0
# Process-wide lock so only one request thread runs the generate path at a time.
lock = threading.Lock()
# JSON keys accepted by the /generate endpoint; unknown keys are logged as errors.
API_ALLOWED_KEYS = {
    "all_probs",
    "sentences",
    "task_ids",
    "tokens_to_generate",
    "temperature",
    "add_BOS",
    "greedy",
    "top_k",
    "top_p",
    "neighbors",
    "repetition_penalty",
    "min_tokens_to_generate",
}
class MegatronGenerate(Resource):
    """Flask-RESTful resource that handles text-generation requests sent with ``PUT``.

    Each request is validated key by key (see ``put``). Invalid requests are answered with
    ``(message, 400)``. Valid ones first signal the other ranks via ``send_do_generate``, then
    call ``generate`` while holding the module-level ``lock``.
    """

    def __init__(self, model, inference_strategy=None):
        """Store the model and the optional inference strategy used by ``put``.

        Args:
            model: model object passed as the first argument to ``generate``.
            inference_strategy: optional strategy forwarded to ``generate`` as ``strategy``. RETRO
                strategies additionally receive ``neighbors`` via ``update_neighbors`` and expose
                ``retrieved_text``.
        """
        self.model = model
        self.inference_strategy = inference_strategy

    @staticmethod
    def send_do_generate():
        """Broadcast ``GENERATE_NUM`` from rank 0 to tell the other ranks a generate call follows."""
        choice = torch.cuda.LongTensor([GENERATE_NUM])
        torch.distributed.broadcast(choice, 0)

    @staticmethod
    def _is_number(value):
        """Return True only for plain ``int``/``float`` values (``bool`` is rejected on purpose)."""
        return type(value) in (int, float)

    def _is_retro_strategy(self):
        """Return True when the configured inference strategy is one of the RETRO strategies."""
        return isinstance(
            self.inference_strategy, (RetroModelTextGenerationStrategy, RetroQAModelTextGenerationStrategy)
        )

    def put(self):
        """Validate a JSON generation request, run ``generate`` and return its output.

        ``sentences`` is required. The optional keys and their defaults are:
        ``tokens_to_generate`` (64), ``all_probs`` (False), ``temperature`` (1.0),
        ``add_BOS`` (False), ``greedy`` (False), ``top_k`` (0), ``top_p`` (0.9),
        ``repetition_penalty`` (1.2), ``min_tokens_to_generate`` (0), ``neighbors`` (None)
        and ``task_ids`` (None).

        Returns:
            On success, ``jsonify(output)``. Tensor values in the ``generate`` output are converted
            to lists, ``full_logprob`` is dropped unless ``all_probs`` is true, and ``retrieved`` is
            added for RETRO strategies. On a validation failure, a ``(message, 400)`` tuple.
        """
        logging.info("request IP: " + str(request.remote_addr))
        data = request.get_json()
        logging.info(json.dumps(data))

        # A JSON body of `null`, a list, a number, ... has no keys: answer 400 instead of crashing.
        if not isinstance(data, dict):
            return "Request body must be a JSON object", 400

        # Unknown keys are only logged (best-effort), not rejected.
        for key in data:
            if key not in API_ALLOWED_KEYS:
                logging.error(f"The request key {key} is not allowed")

        if "sentences" not in data:
            return "Input at 'sentences' is required", 400
        sentences = data["sentences"]
        # NOTE(review): JSON decoding yields lists, never tuples, so the tensor branch below (and the
        # task_ids path that requires it) looks unreachable through this endpoint -- confirm intent.
        if isinstance(sentences, tuple):  # Input can be text or tensor
            if len(sentences[0]) != len(sentences[1]):
                return "Context tokens and context lengths must have the same number of entries", 400
            if len(sentences[0]) > 128:
                return "Maximum number of sentences is 128", 400
        elif len(sentences) > 128:
            return "Maximum number of sentences is 128", 400

        task_ids = None  # Used for ptuned/prompt tuned models only
        if "task_ids" in data:
            task_ids = data["task_ids"]
            if not isinstance(sentences, tuple):
                return (
                    "Input at 'sentences' must be a tuple of two tensors like: "
                    "(context_tokens_tensor, context_length_tensor) if task ids are given",
                    400,
                )
            if len(task_ids) != len(sentences[0]):
                return "Each sentence must have a corresponding task id for p-tuned/prompt-tuned models", 400

        # Choosing hopefully sane default. Full sequence is slow
        tokens_to_generate = data.get("tokens_to_generate", 64)
        if not isinstance(tokens_to_generate, int) or tokens_to_generate < 1:
            return "tokens_to_generate must be an integer greater than 0", 400

        all_probs = data.get("all_probs", False)
        if not isinstance(all_probs, bool):
            return "all_probs must be a boolean value", 400

        temperature = data.get("temperature", 1.0)
        if not (self._is_number(temperature) and 0.0 < temperature <= 100.0):
            return "temperature must be a positive number less than or equal to 100.0", 400

        add_BOS = data.get("add_BOS", False)
        if not isinstance(add_BOS, bool):
            return "add_BOS must be a boolean value", 400

        greedy = data.get("greedy", False)
        if not isinstance(greedy, bool):
            return "greedy must be a boolean value", 400

        top_k = data.get("top_k", 0)
        if not (self._is_number(top_k) and top_k >= 0):
            return "top_k must be a positive integer number", 400

        top_p = data.get("top_p", 0.9)
        if not (self._is_number(top_p) and 0.0 <= top_p <= 1.0):
            return "top_p must be a positive number less than or equal to 1.0", 400

        repetition_penalty = data.get("repetition_penalty", 1.2)
        if not (self._is_number(repetition_penalty) and repetition_penalty >= 1.0):
            return "repetition_penalty must be a positive number no less than 1.0", 400

        min_tokens_to_generate = data.get("min_tokens_to_generate", 0)
        if not isinstance(min_tokens_to_generate, int) or min_tokens_to_generate < 0:
            return "min_tokens_to_generate must be an integer no less than 0", 400

        neighbors = None
        # Membership test (not .get) so that an explicit `"neighbors": null` is still rejected.
        if "neighbors" in data:
            neighbors = data["neighbors"]
            if not isinstance(neighbors, int) or neighbors < 0:
                return "num of neighbors must be an integer no less than 0", 400

        is_retro = self._is_retro_strategy()
        retrieved = None
        with lock:  # Need to get lock to keep multiple threads from hitting code
            MegatronGenerate.send_do_generate()  # Tell other ranks we're doing generate
            extra = {}
            if task_ids is not None:
                extra['task_ids'] = task_ids
            if self.inference_strategy is not None:
                extra['strategy'] = self.inference_strategy
            # RETRO specific arguments. The strategy object is shared by all request threads, so it
            # is both updated and read back only while the lock is held.
            if is_retro and neighbors is not None:
                self.inference_strategy.update_neighbors(neighbors)
            output = generate(
                self.model,
                sentences,
                tokens_to_generate,
                all_probs,
                temperature,
                add_BOS,
                top_k,
                top_p,
                greedy,
                repetition_penalty,
                min_tokens_to_generate,
                **extra,
            )
            if is_retro:
                retrieved = self.inference_strategy.retrieved_text

        # Tensors are not JSON-serializable; convert them to (nested) Python lists.
        for k in output:
            if isinstance(output[k], torch.Tensor):
                output[k] = output[k].tolist()
        if not all_probs:
            # pop (not del) so an output without this key does not raise KeyError.
            output.pop('full_logprob', None)
        if is_retro:
            output['retrieved'] = retrieved
        return jsonify(output)
class MegatronServer:
    """Flask application that exposes ``MegatronGenerate`` at the ``/generate`` route."""

    def __init__(self, model, inference_strategy=None):
        """Create the Flask app and register the generation resource on it.

        Args:
            model: forwarded to ``MegatronGenerate`` as its first constructor argument.
            inference_strategy: optional strategy forwarded to ``MegatronGenerate``.
        """
        flask_app = Flask(__name__, static_url_path='')
        self.app = flask_app
        rest_api = Api(flask_app)
        constructor_args = [model, inference_strategy]
        rest_api.add_resource(MegatronGenerate, '/generate', resource_class_args=constructor_args)

    def run(self, url, port=5000):
        """Start the Flask server on ``url``:``port`` with threading enabled and debug off."""
        self.app.run(host=url, port=port, threaded=True, debug=False)