# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging
import os
import sys

import torch
import torch.distributed as dist

from nemo.deploy import DeployPyTriton
from nemo.deploy.nlp.hf_deployable import HuggingFaceLLMDeploy

LOGGER = logging.getLogger("NeMo")
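
# Example invocations (illustrative; the model ID and Triton model name below
# are placeholders):
#
#   Single process:
#     python deploy_inframework_hf_triton.py \
#         --hf_model_id_path meta-llama/Llama-2-7b-hf \
#         --triton_model_name llama
#
#   Multi-GPU, one process per GPU (required for --tp_plan):
#     torchrun --nproc-per-node=2 deploy_inframework_hf_triton.py \
#         --hf_model_id_path meta-llama/Llama-2-7b-hf \
#         --triton_model_name llama --tp_plan auto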


def setup_torch_dist(rank, world_size):
    """Set up the PyTorch distributed environment for multi-GPU inference.

    Args:
        rank (int): The rank of the current process
        world_size (int): Total number of processes participating in inference
    """
    torch.cuda.set_device(rank)
    # Initialize the NCCL process group so that all ranks can communicate
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
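
# Note: torchrun sets the RANK and WORLD_SIZE environment variables for each
# process it launches; hf_deploy() reads them below to decide whether the
# process group needs to be initialized.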


def get_args(argv):
    """Parse command line arguments for deploying HuggingFace models to Triton.

    Args:
        argv: Command line arguments, excluding the program name

    Returns:
        argparse.Namespace: Parsed command line arguments including:
            - hf_model_id_path: Path to a local HuggingFace model or a Hub model ID
            - task: Model task type (currently only text-generation)
            - device_map: Device mapping strategy
            - tp_plan: Tensor parallelism plan
            - trust_remote_code: Whether to trust remote code
            - triton_model_name: Name for the model in Triton
            - triton_model_version: Model version number
            - triton_port: Triton HTTP port
            - triton_http_address: Triton HTTP address
            - max_batch_size: Maximum inference batch size
            - debug_mode: Enable debug logging
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="Deploy HuggingFace models to Triton Inference Server",
    )
    parser.add_argument(
        "-hp",
        "--hf_model_id_path",
        type=str,
        help="Path to a local HuggingFace model directory or a model ID from the HuggingFace Hub",
    )
    parser.add_argument(
        "-t",
        "--task",
        nargs='?',
        choices=['text-generation'],
        default="text-generation",
        type=str,
        help="Task type for the HuggingFace model (currently only text-generation is supported)",
    )
    parser.add_argument(
        "-dvm",
        "--device_map",
        nargs='?',
        choices=['auto', 'balanced', 'balanced_low_0', 'sequential'],
        default=None,
        type=str,
        help="Device mapping strategy for model placement (e.g. 'auto', 'sequential')",
    )
    parser.add_argument(
        "-tpp",
        "--tp_plan",
        nargs='?',
        choices=['auto'],
        default=None,
        type=str,
        help="Tensor parallelism plan for distributed inference (requires torchrun)",
    )
    parser.add_argument(
        "-trc",
        "--trust_remote_code",
        default=False,
        action='store_true',
        help="Allow loading remote code from the HuggingFace Hub",
    )
    parser.add_argument(
        "-tmn", "--triton_model_name", required=True, type=str, help="Name to identify the model in Triton"
    )
    parser.add_argument(
        "-tmv", "--triton_model_version", default=1, type=int, help="Version number for the model in Triton"
    )
    parser.add_argument(
        "-trp", "--triton_port", default=8000, type=int, help="Port number for the Triton server HTTP endpoint"
    )
    parser.add_argument(
        "-tha",
        "--triton_http_address",
        default="0.0.0.0",
        type=str,
        help="Network interface address for the Triton HTTP endpoint",
    )
    parser.add_argument(
        "-mbs", "--max_batch_size", default=8, type=int, help="Maximum batch size for model inference"
    )
    parser.add_argument(
        "-dm", "--debug_mode", default=False, action='store_true', help="Enable verbose debug logging"
    )
    args = parser.parse_args(argv)
    return args
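
# For reference, a hedged example of the namespace get_args() produces
# ("gpt2" is a placeholder model ID; the defaults come from the parser above):
#
#   args = get_args(["--hf_model_id_path", "gpt2", "--triton_model_name", "gpt2"])
#   assert args.task == "text-generation" and args.max_batch_size == 8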


def hf_deploy(argv):
    """Deploy a HuggingFace model to Triton Inference Server.

    This function handles the deployment workflow, including:
    - Parsing command line arguments
    - Setting up the distributed environment if needed
    - Initializing the HuggingFace model
    - Starting the Triton server

    Args:
        argv: Command line arguments

    Raises:
        ValueError: If required arguments are missing or invalid
    """
    args = get_args(argv)

    if args.debug_mode:
        loglevel = logging.DEBUG
    else:
        loglevel = logging.INFO

    LOGGER.setLevel(loglevel)
    LOGGER.info("Logging level set to {}".format(loglevel))
    LOGGER.info(args)

    if args.hf_model_id_path is None:
        raise ValueError("In-Framework deployment requires a HuggingFace model ID or path.")
if "RANK" in os.environ:
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
if world_size > 1:
setup_torch_dist(rank, world_size)
else:
if args.device_map == "auto":
LOGGER.warning(
"device_map is set to auto and it is recommended that the script"
"is started with torchrun with a process per GPU. You might "
"see unexpected issues during the inference otherwise."
)
if args.tp_plan is not None:
raise ValueError("tp_plan is only available with torchrun.")
    hf_deployable = HuggingFaceLLMDeploy(
        hf_model_id_path=args.hf_model_id_path,
        task=args.task,
        trust_remote_code=args.trust_remote_code,
        device_map=args.device_map,
        tp_plan=args.tp_plan,
    )

    # Only rank 0 starts the Triton server; any other ranks act as workers
    # that participate in distributed inference.
    start_triton_server = True
    if dist.is_initialized():
        if dist.get_rank() > 0:
            start_triton_server = False

    if start_triton_server:
        try:
            nm = DeployPyTriton(
                model=hf_deployable,
                triton_model_name=args.triton_model_name,
                triton_model_version=args.triton_model_version,
                max_batch_size=args.max_batch_size,
                http_port=args.triton_port,
                address=args.triton_http_address,
            )

            LOGGER.info("Triton deploy function will be called.")
            nm.deploy()
        except Exception as error:
            LOGGER.error("An error occurred during the deploy function. Error message: " + str(error))
            if dist.is_initialized():
                dist.barrier()
            return

        try:
            LOGGER.info("Model serving on Triton will be started.")
            nm.serve()
        except Exception as error:
            LOGGER.error("An error occurred while serving the model. Error message: " + str(error))

        if dist.is_initialized():
            if dist.get_world_size() > 1:
                # Signal the worker ranks to stop waiting for generation requests
                torch.distributed.broadcast(torch.tensor([1], dtype=torch.long, device="cuda"), src=0)

        LOGGER.info("Model serving will be stopped.")
        nm.stop()
    else:
        if dist.is_initialized():
            if dist.get_rank() > 0:
                # Worker ranks block here and serve generation requests
                # coordinated by rank 0 until the stop signal is broadcast.
                hf_deployable.generate_other_ranks()

    if dist.is_initialized():
        dist.barrier()
        dist.destroy_process_group()
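

# Once the server is up, a client can query the model over Triton's HTTP
# endpoint. A minimal sketch, assuming NeMo's query utilities (the class name
# and query_llm signature may differ across NeMo versions; verify against
# nemo.deploy.nlp before relying on this):
#
#   from nemo.deploy.nlp import NemoQueryLLM
#
#   nq = NemoQueryLLM(url="localhost:8000", model_name="llama")
#   output = nq.query_llm(prompts=["What is the color of a banana?"], max_output_len=32)
#   print(output)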

if __name__ == '__main__':
    hf_deploy(sys.argv[1:])