|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Utilities for generating text.""" |
|
|
|
|
|
import json |
|
|
import threading |
|
|
|
|
|
import torch |
|
|
from flask import Flask, jsonify, request |
|
|
from flask_restful import Api, Resource |
|
|
|
|
|
from nemo.collections.nlp.modules.common.retro_inference_strategies import ( |
|
|
RetroModelTextGenerationStrategy, |
|
|
RetroQAModelTextGenerationStrategy, |
|
|
) |
|
|
from nemo.collections.nlp.modules.common.text_generation_utils import generate |
|
|
from nemo.utils import logging |
|
|
|
|
|
# Sentinel value broadcast from rank 0 to signal worker ranks to run generation.
GENERATE_NUM = 0

# Serializes requests: only one generation may run at a time.
lock = threading.Lock()

# Request-body keys the PUT endpoint understands; anything else is logged as an error.
API_ALLOWED_KEYS = {
    "all_probs",
    "sentences",
    "task_ids",
    "tokens_to_generate",
    "temperature",
    "add_BOS",
    "greedy",
    "top_k",
    "top_p",
    "neighbors",
    "repetition_penalty",
    "min_tokens_to_generate",
}
|
|
|
|
|
|
|
|
class MegatronGenerate(Resource):
    """Flask-RESTful resource serving Megatron text generation (PUT /generate).

    Rank 0 runs this HTTP server; before generating it broadcasts GENERATE_NUM
    to the other ranks so every rank enters ``generate`` together.
    """

    def __init__(self, model, inference_strategy=None):
        # model: the Megatron model used for generation.
        # inference_strategy: optional text-generation strategy (e.g. RETRO
        #   strategies) forwarded to ``generate`` and queried for retrieved text.
        self.model = model
        self.inference_strategy = inference_strategy

    @staticmethod
    def send_do_generate():
        """Broadcast the GENERATE_NUM signal from rank 0 to unblock worker ranks."""
        choice = torch.cuda.LongTensor([GENERATE_NUM])
        torch.distributed.broadcast(choice, 0)

    @staticmethod
    def _is_number(value):
        # Exactly int or float; bool is rejected even though it subclasses int
        # (preserves the original ``type(x) == int or type(x) == float`` semantics).
        return type(value) in (int, float)

    def put(self):
        """Validate the JSON request body, run generation, and return the result.

        Returns the JSON-serialized output of ``generate`` on success, or a
        ``(message, 400)`` / plain-message response on validation failure.
        """
        args = request.get_json()  # parse once; the original re-parsed per field
        logging.info("request IP: " + str(request.remote_addr))
        logging.info(json.dumps(args))

        # Unknown keys are logged but deliberately tolerated (best-effort API).
        for key in args:
            if key not in API_ALLOWED_KEYS:
                logging.error(f"The request key {key} is not allowed")

        sentences = args["sentences"]
        if isinstance(sentences, tuple):  # (context_tokens_tensor, context_length_tensor)
            # BUG FIX: the original compared ``sentences[0] > 128`` (the tensor
            # itself, a TypeError at runtime) instead of its length.
            if len(sentences[0]) != len(sentences[1]) or len(sentences[0]) > 128:
                return "Maximum number of sentences is 128", 400
        elif len(sentences) > 128:
            return "Maximum number of sentences is 128", 400

        task_ids = None  # only used by p-tuned/prompt-tuned models
        if "task_ids" in args:
            task_ids = args["task_ids"]
            if not isinstance(sentences, tuple):
                # Typo fix: "must by" -> "must be".
                return "Input at 'sentences' must be a tuple of two tensors like:\
(context_tokens_tensor, context_length_tensor) if task ids are given"
            if len(task_ids) != len(sentences[0]):
                return "Each sentence must have a corresponding task id for p-tuned/prompt-tuned models"

        tokens_to_generate = 64  # default number of new tokens to produce
        if "tokens_to_generate" in args:
            tokens_to_generate = args["tokens_to_generate"]
            if not isinstance(tokens_to_generate, int) or tokens_to_generate < 1:
                return "tokens_to_generate must be an integer greater than 0"

        all_probs = False  # when True, full log-probabilities are kept in the output
        if "all_probs" in args:
            all_probs = args["all_probs"]
            if not isinstance(all_probs, bool):
                return "all_probs must be a boolean value"

        temperature = 1.0
        if "temperature" in args:
            temperature = args["temperature"]
            if not self._is_number(temperature) or not (0.0 < temperature <= 100.0):
                return "temperature must be a positive number less than or equal to 100.0"

        add_BOS = False
        if "add_BOS" in args:
            add_BOS = args["add_BOS"]
            if not isinstance(add_BOS, bool):
                return "add_BOS must be a boolean value"

        greedy = False
        if "greedy" in args:
            greedy = args["greedy"]
            if not isinstance(greedy, bool):
                return "greedy must be a boolean value"

        top_k = 0
        if "top_k" in args:
            top_k = args["top_k"]
            # NOTE(review): floats are accepted here even though the message says
            # integer — kept as in the original for backward compatibility.
            if not self._is_number(top_k) or top_k < 0:
                return "top_k must be a positive integer number"

        top_p = 0.9
        if "top_p" in args:
            top_p = args["top_p"]
            if not self._is_number(top_p) or not (0.0 <= top_p <= 1.0):
                return "top_p must be a positive number less than or equal to 1.0"

        repetition_penalty = 1.2
        if "repetition_penalty" in args:
            repetition_penalty = args["repetition_penalty"]
            if not self._is_number(repetition_penalty) or repetition_penalty < 1.0:
                return "repetition_penalty must be a positive number no less than 1.0"

        min_tokens_to_generate = 0
        if "min_tokens_to_generate" in args:
            min_tokens_to_generate = args["min_tokens_to_generate"]
            if not isinstance(min_tokens_to_generate, int) or min_tokens_to_generate < 0:
                return "min_tokens_to_generate must be an integer no less than 0"

        neighbors = None  # RETRO-only: per-request override of retrieved-neighbor count
        if "neighbors" in args:
            neighbors = args["neighbors"]
            if not isinstance(neighbors, int) or neighbors < 0:
                return "num of neighbors must be an integer no less than 0"

        with lock:  # one generation at a time (rank 0 coordinates all ranks)
            MegatronGenerate.send_do_generate()  # wake the worker ranks
            extra = {}
            if task_ids is not None:
                extra['task_ids'] = task_ids
            if self.inference_strategy is not None:
                extra['strategy'] = self.inference_strategy
                # RETRO strategies support changing the neighbor count per request.
                if isinstance(
                    self.inference_strategy,
                    (RetroModelTextGenerationStrategy, RetroQAModelTextGenerationStrategy),
                ):
                    if neighbors is not None:
                        self.inference_strategy.update_neighbors(neighbors)

            output = generate(
                self.model,
                sentences,
                tokens_to_generate,
                all_probs,
                temperature,
                add_BOS,
                top_k,
                top_p,
                greedy,
                repetition_penalty,
                min_tokens_to_generate,
                **extra,
            )
            # Tensors are not JSON-serializable; convert them to nested lists.
            for k in output:
                if isinstance(output[k], torch.Tensor):
                    output[k] = output[k].tolist()
            if not all_probs:
                del output['full_logprob']  # large payload; only returned on request

            if self.inference_strategy is not None:
                if isinstance(
                    self.inference_strategy,
                    (RetroModelTextGenerationStrategy, RetroQAModelTextGenerationStrategy),
                ):
                    # Surface the passages RETRO retrieved for this request.
                    output['retrieved'] = self.inference_strategy.retrieved_text
        return jsonify(output)
|
|
|
|
|
|
|
|
class MegatronServer(object):
    """Thin HTTP wrapper exposing a MegatronGenerate resource at /generate."""

    def __init__(self, model, inference_strategy=None):
        # Build the Flask app and register the generation endpoint.
        self.app = Flask(__name__, static_url_path='')
        Api(self.app).add_resource(
            MegatronGenerate, '/generate', resource_class_args=[model, inference_strategy]
        )

    def run(self, url, port=5000):
        """Start the Flask server on *url*:*port* (threaded, debug off)."""
        self.app.run(url, threaded=True, port=port, debug=False)
|
|
|