| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import argparse |
| | import os |
| | from concurrent import futures |
| |
|
| | import api.nmt_pb2 as nmt |
| | import api.nmt_pb2_grpc as nmtsrv |
| | import grpc |
| | import torch |
| |
|
| | import nemo.collections.nlp as nemo_nlp |
| | from nemo.utils import logging |
| |
|
| |
|
def get_args(argv=None):
    """Parse command-line options for the Riva NMT gRPC server.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process arguments. ``None`` (the default) parses ``sys.argv[1:]``,
            so the original zero-argument call keeps working.

    Returns:
        argparse.Namespace with ``model_dir``, ``punctuation_model``, ``port``,
        ``batch_size``, ``beam_size``, ``len_pen`` and ``max_delta_length``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_dir", required=True, type=str, help="Path to a folder containing .nemo translation model files.",
    )
    parser.add_argument(
        "--punctuation_model",
        default="",
        type=str,
        help=(
            "Optionally provide a path to a .nemo file for punctuation and capitalization "
            "(recommended if working with Riva speech recognition outputs)"
        ),
    )
    parser.add_argument("--port", default=50052, type=int, required=False, help="Port the gRPC server listens on")
    # batch_size bounds how many texts are sent to a model in a single translate() call.
    parser.add_argument(
        "--batch_size", type=int, default=256, help="Maximum number of texts translated together in one batch"
    )
    parser.add_argument("--beam_size", type=int, default=1, help="Beam Size")
    parser.add_argument("--len_pen", type=float, default=0.6, help="Length Penalty")
    parser.add_argument("--max_delta_length", type=int, default=5, help="Max Delta Generation Length.")

    return parser.parse_args(argv)
| |
|
| |
|
def batches(lst, n):
    """Lazily split ``lst`` into consecutive slices of at most ``n`` items.

    The final slice carries the remainder and may be shorter than ``n``.
    Slicing preserves the sequence type (lists yield lists, strings yield
    strings).
    """
    # Offsets come from range(), so a zero step raises ValueError on first
    # iteration and a negative step simply produces no slices.
    offsets = range(0, len(lst), n)
    yield from (lst[offset : offset + n] for offset in offsets)
| |
|
| |
|
class RivaTranslateServicer(nmtsrv.RivaTranslateServicer):
    """gRPC servicer that translates text with NeMo ``MTEncDecModel`` checkpoints.

    Every ``.nemo`` file in ``model_dir`` is restored at construction time and
    indexed by the model's ``src_language`` / ``tgt_language`` attributes. If a
    punctuation/capitalization model is configured, it is applied to each batch
    of input texts before translation.
    """

    def __init__(
        self, model_dir, punctuation_model_path, beam_size=1, len_pen=0.6, max_delta_length=5, batch_size=256,
    ):
        """Load all translation models and the optional punctuation model.

        Args:
            model_dir: Directory listed (non-recursively) for ``.nemo`` translation checkpoints.
            punctuation_model_path: Path to a ``.nemo`` punctuation/capitalization model,
                or ``""`` to skip punctuation restoration.
            beam_size: Beam size copied onto every translation model's beam search.
            len_pen: Length penalty copied onto every translation model's beam search.
            max_delta_length: ``max_delta_length`` copied onto every translation model's beam search.
            batch_size: Maximum number of texts passed to one ``translate`` call.

        Raises:
            FileNotFoundError: If ``punctuation_model_path`` is non-empty but does not exist.
            NotImplementedError: If the punctuation model path is not a ``.nemo`` file.
            ValueError: If a checkpoint lacks ``src_language``/``tgt_language`` or two
                checkpoints cover the same language pair.
        """
        self._models = {}  # {src_language: {tgt_language: model}}
        self._beam_size = beam_size
        self._len_pen = len_pen
        self._max_delta_length = max_delta_length
        self._batch_size = batch_size
        self._punctuation_model_path = punctuation_model_path
        self._model_dir = model_dir

        # Sorted so the load order (and which file triggers a duplicate-pair error) is deterministic;
        # os.listdir returns entries in arbitrary order.
        model_paths = sorted(
            os.path.join(model_dir, fname) for fname in os.listdir(model_dir) if fname.endswith(".nemo")
        )

        for model_path in model_paths:
            logging.info(f"Loading model {model_path}")
            self._load_model(model_path)

        if self._punctuation_model_path != "":
            # Explicit check instead of `assert`, which is stripped under `python -O`.
            if not os.path.exists(punctuation_model_path):
                raise FileNotFoundError(f"Punctuation model not found: {punctuation_model_path}")
            logging.info(f"Loading punctuation model {punctuation_model_path}")
            self._load_puncutation_model(punctuation_model_path)

        logging.info("Models loaded. Ready for inference requests.")

    def _load_puncutation_model(self, punctuation_model_path):
        """Restore the punctuation/capitalization model into ``self.punctuation_model``.

        The model is put in eval mode and moved to the GPU when CUDA is available.

        Raises:
            NotImplementedError: If the path does not end with ``.nemo``.
        """
        if not punctuation_model_path.endswith(".nemo"):
            raise NotImplementedError(f"Only support .nemo files, but got: {punctuation_model_path}")

        model = nemo_nlp.models.PunctuationCapitalizationModel.restore_from(restore_path=punctuation_model_path)
        model.eval()
        if torch.cuda.is_available():
            model = model.cuda()
        self.punctuation_model = model

    def _load_model(self, model_path):
        """Restore one translation checkpoint and register it under its language pair.

        Beam-search settings from the constructor are applied to the restored model.

        Raises:
            NotImplementedError: If the path does not end with ``.nemo``.
            ValueError: If the model lacks ``src_language``/``tgt_language`` attributes, or a
                model for the same language pair was already loaded.
        """
        if not model_path.endswith(".nemo"):
            raise NotImplementedError(f"Only support .nemo files, but got: {model_path}")

        logging.info("Attempting to initialize from .nemo file")
        model = nemo_nlp.models.machine_translation.MTEncDecModel.restore_from(restore_path=model_path)
        model = model.eval()
        model.beam_search.beam_size = self._beam_size
        model.beam_search.len_pen = self._len_pen
        model.beam_search.max_delta_length = self._max_delta_length

        if not hasattr(model, "src_language") or not hasattr(model, "tgt_language"):
            raise ValueError(
                "Could not find src_language and tgt_language in model attributes. If using NeMo rc1 checkpoints, please edit the config files to add model.src_language and model.tgt_language"
            )

        src_language = model.src_language
        tgt_language = model.tgt_language

        targets = self._models.setdefault(src_language, {})
        if tgt_language in targets:
            raise ValueError(f"Already found model for language pair {src_language}-{tgt_language}")

        # Move to the GPU once, and only after the pair is known to be new.
        if torch.cuda.is_available():
            model = model.cuda()
        targets[tgt_language] = model

    def TranslateText(self, request, context):
        """Translate ``request.texts`` from ``request.source_language`` to ``request.target_language``.

        Texts are processed in chunks of at most ``batch_size``. When a punctuation model was
        configured, each chunk is punctuated/capitalized before translation. An unknown language
        pair sets ``INVALID_ARGUMENT`` on ``context`` and returns an empty response.
        """
        logging.info(f"Request received w/ {len(request.texts)} utterances")

        model = self._models.get(request.source_language, {}).get(request.target_language)
        if model is None:
            context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
            context.set_details(
                f"Could not find source-target language pair {request.source_language}-{request.target_language} in list of models."
            )
            return nmt.TranslateTextResponse()

        results = []
        for batch in batches(list(request.texts), self._batch_size):
            if self._punctuation_model_path != "":
                batch = self.punctuation_model.add_punctuation_capitalization(batch)
            results.extend(nmt.Translation(translation=text) for text in model.translate(text=batch))

        return nmt.TranslateTextResponse(translations=results)
| |
|
| |
|
def serve():
    """Parse CLI options, load the models and serve RivaTranslate over insecure gRPC.

    Blocks until the server terminates.
    """
    cli = get_args()
    grpc_server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    translate_servicer = RivaTranslateServicer(
        model_dir=cli.model_dir,
        punctuation_model_path=cli.punctuation_model,
        beam_size=cli.beam_size,
        len_pen=cli.len_pen,
        batch_size=cli.batch_size,
        max_delta_length=cli.max_delta_length,
    )
    nmtsrv.add_RivaTranslateServicer_to_server(translate_servicer, grpc_server)

    # Listen on all IPv6 (and, where dual-stack, IPv4) interfaces on the requested port.
    listen_address = f"[::]:{cli.port}"
    grpc_server.add_insecure_port(listen_address)
    grpc_server.start()
    grpc_server.wait_for_termination()
| |
|
| |
|
if __name__ == "__main__":
    # Start the server only when run as a script, not when imported.
    serve()
| |
|