import datetime
import json
import os
import sys
import time
from random import randint
from threading import Lock, Thread

import numpy as np
import torch
import triton_python_backend_utils as pb_utils
from torch import from_numpy
from torch.utils.dlpack import from_dlpack

import tensorrt_llm.bindings.executor as trtllm

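# The helpers below unpack Triton input tensors/parameters and build the
# corresponding tensorrt_llm.bindings.executor (trtllm) configuration objects.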
def get_input_tensor_by_name(request,
                             name,
                             expected_batch_size=None,
                             batch_index=None):
    tensor = pb_utils.get_input_tensor_by_name(request, name)
    if tensor is None:
        return None

    if tensor.is_cpu():
        tensor = tensor.as_numpy()
    else:
        tensor = from_dlpack(tensor.to_dlpack())

    if expected_batch_size is not None and tensor.shape[
            0] != expected_batch_size:
        raise pb_utils.TritonModelException(
            f"Expected batch size doesn't match batch size for tensor {name}. Expected {expected_batch_size} got {tensor.shape[0]}"
        )

    if batch_index is not None and expected_batch_size is not None and batch_index >= expected_batch_size:
        raise pb_utils.TritonModelException(
            f"Invalid batch index in get_input_tensor_by_name for {name}")

    if batch_index is not None:
        # Select the requested batch entry and re-add a leading batch
        # dimension of 1.
        if isinstance(tensor, np.ndarray):
            return np.expand_dims(tensor[batch_index], axis=0)
        elif isinstance(tensor, torch.Tensor):
            return torch.unsqueeze(tensor[batch_index], dim=0)
    else:
        return tensor

def get_input_scalar_by_name(request,
                             name,
                             expected_batch_size=1,
                             batch_index=0):
    tensor = pb_utils.get_input_tensor_by_name(request, name)
    if tensor is None:
        return None
    tensor = tensor.as_numpy()

    if tensor.size != expected_batch_size:
        raise pb_utils.TritonModelException(
            f"Expected a scalar tensor for tensor {name}")

    return tensor.item(batch_index)

def read_parameter_as_type(value, name, pytype=str):
    if value == "":
        return None
    if value.startswith("${") and value.endswith("}"):
        return None
    if pytype is bool:
        return value.lower() in ["1", "true"]
    try:
        result = pytype(value)
        return result
    except Exception:
        pb_utils.Logger.log_warning(
            f"Could not read parameter '{name}' with value '{value}', will use default."
        )
        return None

def get_parameter(model_config, name, pytype=str):
    if name not in model_config['parameters']:
        return None
    return read_parameter_as_type(
        model_config['parameters'][name]['string_value'], name, pytype)

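# Stop/bad word lists arrive as a flattened [1, 2, num_tokens] tensor: row 0
# holds the concatenated token ids and row 1 holds the exclusive end offset of
# each word, padded with -1. For example, words [[1, 2], [3]] are encoded as
# [[1, 2, 3], [2, 3, -1]]; convert_word_list() rebuilds the nested list.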
def convert_word_list(word_list):
    if word_list is None:
        return None
    word_list = word_list.tolist()
    if len(word_list) == 0 or len(word_list[0]) != 2:
        raise pb_utils.TritonModelException("Invalid format for word list.")
    words, indices = word_list[0]
    result = []
    current_index = 0
    for i in indices:
        if i == -1:
            continue
        if i > len(words):
            raise pb_utils.TritonModelException(
                "Invalid format for word list.")
        current_word = []
        while current_index < i:
            current_word.append(words[current_index])
            current_index += 1
        result.append(current_word)
    return result

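# medusa_choices is configured as a brace-delimited string such as
# "{0, 0, 0}, {0, 1}". parse_medusa_choices() converts the braces to JSON
# brackets and parses the result into a list of integer lists, e.g.
# [[0, 0, 0], [0, 1]].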
def parse_medusa_choices(medusa_choices):
    if medusa_choices is None:
        return None
    try:
        result = json.loads(
            "[" + medusa_choices.replace("{", "[").replace("}", "]") + "]")
        assert isinstance(result, list) and len(result) > 0
        assert all([isinstance(x, list) for x in result])
        assert all([isinstance(y, int) for x in result for y in x])
    except Exception:
        raise pb_utils.TritonModelException(
            "Invalid format for medusa_choices")
    return result

def get_sampling_config_from_request(request, batch_size=1, batch_index=0):
    kwargs = {}
    kwargs['beam_width'] = get_input_scalar_by_name(
        request, 'beam_width', batch_size, batch_index) or 1
    kwargs['top_k'] = get_input_scalar_by_name(request, 'runtime_top_k',
                                               batch_size, batch_index)
    kwargs['top_p'] = get_input_scalar_by_name(request, 'runtime_top_p',
                                               batch_size, batch_index)
    kwargs['top_p'] = None if kwargs['top_p'] is None or kwargs[
        'top_p'] <= 0 else kwargs['top_p']
    kwargs['random_seed'] = get_input_scalar_by_name(request, 'random_seed',
                                                     batch_size, batch_index)
    kwargs['temperature'] = get_input_scalar_by_name(request, 'temperature',
                                                     batch_size, batch_index)
    kwargs['min_length'] = get_input_scalar_by_name(request, 'min_length',
                                                    batch_size, batch_index)
    kwargs['repetition_penalty'] = get_input_scalar_by_name(
        request, 'repetition_penalty', batch_size, batch_index)
    kwargs['presence_penalty'] = get_input_scalar_by_name(
        request, 'presence_penalty', batch_size, batch_index)
    kwargs['frequency_penalty'] = get_input_scalar_by_name(
        request, 'frequency_penalty', batch_size, batch_index)
    kwargs['length_penalty'] = get_input_scalar_by_name(
        request, 'len_penalty', batch_size, batch_index)
    kwargs['top_p_min'] = get_input_scalar_by_name(request,
                                                   'runtime_top_p_min',
                                                   batch_size, batch_index)
    kwargs['top_p_reset_ids'] = get_input_scalar_by_name(
        request, 'runtime_top_p_reset_ids', batch_size, batch_index)
    kwargs['top_p_decay'] = get_input_scalar_by_name(request,
                                                     'runtime_top_p_decay',
                                                     batch_size, batch_index)
    kwargs['beam_search_diversity_rate'] = get_input_scalar_by_name(
        request, 'beam_search_diversity_rate', batch_size, batch_index)
    kwargs['early_stopping'] = get_input_scalar_by_name(
        request, 'early_stopping', batch_size, batch_index)
    kwargs = {k: v for k, v in kwargs.items() if v is not None}
    return trtllm.SamplingConfig(**kwargs)

def get_output_config_from_request(request,
                                   exclude_input_from_output,
                                   batch_size=1,
                                   batch_index=0):
    kwargs = {}
    kwargs["return_log_probs"] = get_input_scalar_by_name(
        request, 'return_log_probs', batch_size, batch_index)
    kwargs["return_context_logits"] = get_input_scalar_by_name(
        request, 'return_context_logits', batch_size, batch_index)
    kwargs["return_generation_logits"] = get_input_scalar_by_name(
        request, 'return_generation_logits', batch_size, batch_index)
    kwargs["exclude_input_from_output"] = exclude_input_from_output
    kwargs = {k: v for k, v in kwargs.items() if v is not None}
    return trtllm.OutputConfig(**kwargs)

def get_external_draft_tokens_config_from_request(request,
                                                  batch_size=1,
                                                  batch_index=0):
    kwargs = {}
    draft_input_ids = get_input_tensor_by_name(request, 'draft_input_ids',
                                               batch_size, batch_index)
    if draft_input_ids is not None:
        kwargs['tokens'] = draft_input_ids[0].tolist()
    draft_logits = get_input_tensor_by_name(request, 'draft_logits',
                                            batch_size, batch_index)
    if draft_logits is not None:
        kwargs['logits'] = from_numpy(draft_logits).squeeze()
    kwargs['acceptance_threshold'] = get_input_scalar_by_name(
        request, 'draft_acceptance_threshold', batch_size, batch_index)
    kwargs = {k: v for k, v in kwargs.items() if v is not None}
    if len(kwargs) > 0:
        return trtllm.ExternalDraftTokensConfig(**kwargs)
    return None

def get_prompt_tuning_config_from_request(request,
                                          batch_size=1,
                                          batch_index=0):
    # Only the prompt embedding table is forwarded to the executor.
    kwargs = {}
    prompt_embedding_table = get_input_tensor_by_name(
        request, 'prompt_embedding_table', batch_size, batch_index)
    if prompt_embedding_table is not None:
        if isinstance(prompt_embedding_table, np.ndarray):
            kwargs["embedding_table"] = from_numpy(
                prompt_embedding_table).squeeze()
        elif isinstance(prompt_embedding_table, torch.Tensor):
            # Already a torch tensor (converted from DLPack in
            # get_input_tensor_by_name); just drop the leading batch dim.
            kwargs["embedding_table"] = prompt_embedding_table.squeeze(dim=0)
    kwargs = {k: v for k, v in kwargs.items() if v is not None}
    if len(kwargs) > 0:
        return trtllm.PromptTuningConfig(**kwargs)
    return None

def get_lora_config_from_request(request, batch_size=1, batch_index=0):
    kwargs = {}
    kwargs["task_id"] = get_input_scalar_by_name(request, 'lora_task_id',
                                                 batch_size, batch_index)
    lora_weights = get_input_tensor_by_name(request, 'lora_weights',
                                            batch_size, batch_index)
    if lora_weights is not None:
        kwargs["weights"] = from_numpy(lora_weights).squeeze()
    lora_config = get_input_tensor_by_name(request, 'lora_config', batch_size,
                                           batch_index)
    if lora_config is not None:
        kwargs["config"] = from_numpy(lora_config).squeeze()
    kwargs = {k: v for k, v in kwargs.items() if v is not None}
    if len(kwargs) > 0:
        return trtllm.LoraConfig(**kwargs)
    return None

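# A single Triton request may carry a batch of prompts. convert_request()
# fans it out into one trtllm.Request per batch entry so the executor can
# schedule and complete each entry independently.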
def convert_request(request, exclude_input_from_output, decoupled):
    inputs = {}
    input_token_ids = get_input_tensor_by_name(request, 'input_ids')
    if input_token_ids is None:
        raise pb_utils.TritonModelException(
            "A value is required for input_ids")
    if len(input_token_ids.shape) != 2:
        raise pb_utils.TritonModelException("Invalid format for input_ids")
    batch_size = input_token_ids.shape[0]
    requests = []
    for batch_index in range(0, batch_size):
        input_token_ids = get_input_tensor_by_name(request, 'input_ids',
                                                   batch_size, batch_index)[0]
        if input_token_ids is None:
            raise pb_utils.TritonModelException(
                "A value is required for input_ids")
        input_token_ids = input_token_ids.tolist()
        if len(input_token_ids) == 0:
            raise pb_utils.TritonModelException(
                "Invalid format for input_ids")

        input_length = get_input_scalar_by_name(request, 'input_lengths',
                                                batch_size, batch_index)
        if input_length is None:
            input_length = len(input_token_ids)
        # Trim the input token ids to the provided input length.
        inputs['input_token_ids'] = input_token_ids[0:input_length]

        inputs['max_new_tokens'] = get_input_scalar_by_name(
            request, 'request_output_len', batch_size, batch_index)
        if inputs['max_new_tokens'] is None:
            raise pb_utils.TritonModelException(
                "A value is required for request_output_len")
        inputs['streaming'] = get_input_scalar_by_name(request, 'streaming',
                                                       batch_size, batch_index)
        if inputs['streaming'] and not decoupled:
            raise pb_utils.TritonModelException(
                "Streaming is only supported in decoupled mode.")
        inputs['end_id'] = get_input_scalar_by_name(request, 'end_id',
                                                    batch_size, batch_index)
        inputs['pad_id'] = get_input_scalar_by_name(request, 'pad_id',
                                                    batch_size, batch_index)
        inputs['stop_words'] = convert_word_list(
            get_input_tensor_by_name(request, 'stop_words_list', batch_size,
                                     batch_index))
        inputs['bad_words'] = convert_word_list(
            get_input_tensor_by_name(request, 'bad_words_list', batch_size,
                                     batch_index))
        embedding_bias = get_input_tensor_by_name(request, 'embedding_bias',
                                                  batch_size, batch_index)
        if embedding_bias is not None and embedding_bias.size != 0:
            inputs['embedding_bias'] = from_numpy(embedding_bias).squeeze()

        sampling_config = get_sampling_config_from_request(
            request, batch_size, batch_index)
        output_config = get_output_config_from_request(
            request, exclude_input_from_output, batch_size, batch_index)
        external_draft_tokens_config = get_external_draft_tokens_config_from_request(
            request, batch_size, batch_index)
        prompt_tuning_config = get_prompt_tuning_config_from_request(
            request, batch_size, batch_index)
        lora_config = get_lora_config_from_request(request, batch_size,
                                                   batch_index)

        requests.append(
            trtllm.Request(
                **inputs,
                sampling_config=sampling_config,
                output_config=output_config,
                external_draft_tokens_config=external_draft_tokens_config,
                prompt_tuning_config=prompt_tuning_config,
                lora_config=lora_config,
            ))
    return requests

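# Responses are packed with a leading batch dimension of 1 and one row per
# beam: output_ids has shape [1, num_beams, max_beam_length], padded with -1.
# Optional outputs (log probs, logits) fall back to zero-filled placeholders
# when the executor did not return them.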
def convert_response(response, batch_index):
    if response.has_error():
        return pb_utils.InferenceResponse(output_tensors=[],
                                          error=pb_utils.TritonError(
                                              response.error_msg)), True
    result = response.result
    beam_lengths = np.expand_dims(
        np.array([len(beam) for beam in result.output_token_ids], np.int32), 0)
    max_beam_length = max([len(beam) for beam in result.output_token_ids])
    output_ids = np.full((1, len(result.output_token_ids), max_beam_length),
                         -1, np.int32)
    for idx, beam in enumerate(result.output_token_ids):
        output_ids[0, idx, :len(beam)] = beam
    output_tensors = [
        pb_utils.Tensor("output_ids", output_ids),
        pb_utils.Tensor("sequence_length", beam_lengths),
    ]
    output_tensors.append(
        pb_utils.Tensor(
            "cum_log_probs",
            np.expand_dims(np.array(result.cum_log_probs, np.float32), 0)
            if result.cum_log_probs is not None else np.zeros(
                (1, 1), np.float32)))
    output_tensors.append(
        pb_utils.Tensor(
            "output_log_probs",
            np.expand_dims(np.array(result.log_probs, np.float32), 0) if
            result.log_probs is not None else np.zeros((1, 1, 1), np.float32)))
    output_tensors.append(
        pb_utils.Tensor(
            "context_logits",
            np.expand_dims(np.array(result.context_logits, np.float32), 0)
            if result.context_logits is not None else np.zeros(
                (1, 1, 1), np.float32)))
    output_tensors.append(
        pb_utils.Tensor(
            "generation_logits",
            np.expand_dims(np.array(result.generation_logits, np.float32), 0)
            if result.generation_logits is not None else np.zeros(
                (1, 1, 1, 1), np.float32)))
    output_tensors.append(
        pb_utils.Tensor("batch_index",
                        np.expand_dims(np.array([batch_index], np.int32), 0)))

    return pb_utils.InferenceResponse(output_tensors), result.is_final

def convert_scheduler_policy(batch_scheduler_policy: str):
    if batch_scheduler_policy.lower() == "max_utilization":
        return trtllm.CapacitySchedulerPolicy.MAX_UTILIZATION
    elif batch_scheduler_policy.lower() == "guaranteed_no_evict":
        return trtllm.CapacitySchedulerPolicy.GUARANTEED_NO_EVICT
    raise pb_utils.TritonModelException(
        f"batch_scheduler_policy value of '{batch_scheduler_policy}' is not supported."
    )

def convert_batching_type(gpt_model_type: str):
    if gpt_model_type is None:
        return None
    if gpt_model_type.lower() in ("inflight_fused_batching",
                                  "inflight_batching"):
        return trtllm.BatchingType.INFLIGHT
    elif gpt_model_type.lower() == "v1":
        return trtllm.BatchingType.STATIC
    raise pb_utils.TritonModelException(
        f"gpt_model_type value of '{gpt_model_type}' is not supported.")

def convert_decoding_mode(decoding_mode: str):
    if decoding_mode is None:
        return None
    elif decoding_mode == "auto":
        return trtllm.DecodingMode.Auto()
    elif decoding_mode == "top_k":
        return trtllm.DecodingMode.TopK()
    elif decoding_mode == "top_p":
        return trtllm.DecodingMode.TopP()
    elif decoding_mode == "top_k_top_p":
        return trtllm.DecodingMode.TopKTopP()
    elif decoding_mode == "beam_search":
        return trtllm.DecodingMode.BeamSearch()
    elif decoding_mode == "medusa":
        return trtllm.DecodingMode.Medusa()
    raise pb_utils.TritonModelException(
        f"decoding_mode value of '{decoding_mode}' is not supported.")

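# Executor iteration stats report timestamps as "MM-DD-YYYY HH:MM:SS.ffffff";
# metrics_loop() converts them to Unix epoch seconds before exporting the
# timestamp gauge.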
def convert_timestamp_to_seconds(timestamp: str):
    return int(
        datetime.datetime.strptime(timestamp,
                                   "%m-%d-%Y %H:%M:%S.%f").timestamp())

class TritonPythonModel:
    """Your Python model must use the same class name. Every Python model
    that is created must have "TritonPythonModel" as the class name.
    """

    def get_scheduler_config(self, model_config):
        batch_scheduler_policy = get_parameter(model_config,
                                               "batch_scheduler_policy")
        if batch_scheduler_policy is None:
            return trtllm.SchedulerConfig()
        return trtllm.SchedulerConfig(
            convert_scheduler_policy(batch_scheduler_policy))

    def get_kv_cache_config(self, model_config):
        kwargs = {
            "enable_block_reuse":
            get_parameter(model_config, "enable_kv_cache_reuse", bool),
            "max_tokens":
            get_parameter(model_config, "max_tokens_in_paged_kv_cache", int),
            "sink_token_length":
            get_parameter(model_config, "sink_token_length", int),
            "free_gpu_memory_fraction":
            get_parameter(model_config, "kv_cache_free_gpu_mem_fraction",
                          float),
            "host_cache_size":
            get_parameter(model_config, "kv_cache_host_memory_bytes", int),
            "onboard_blocks":
            get_parameter(model_config, "kv_cache_onboard_blocks", bool),
        }
        max_attention_window_size = get_parameter(model_config,
                                                  "max_attention_window_size")
        if max_attention_window_size:
            kwargs["max_attention_window"] = [
                int(x) for x in max_attention_window_size.split(",")
            ]
        kwargs = {k: v for k, v in kwargs.items() if v is not None}
        return trtllm.KvCacheConfig(**kwargs)

    def get_parallel_config(self, model_config):
        kwargs = {}
        gpu_device_ids = get_parameter(model_config, "gpu_device_ids")
        if gpu_device_ids:
            kwargs["device_ids"] = [int(x) for x in gpu_device_ids.split(",")]
        # Orchestrator mode is enabled through the TRTLLM_ORCHESTRATOR
        # environment variable.
        self.use_orchestrator_mode = os.environ.get("TRTLLM_ORCHESTRATOR",
                                                    "0") == "1"
        if self.use_orchestrator_mode:
            kwargs[
                "communication_mode"] = trtllm.CommunicationMode.ORCHESTRATOR
            worker_path = get_parameter(model_config, "worker_path")
            if worker_path is not None:
                raise pb_utils.TritonModelException(
                    "worker_path parameter is specified, but this is no longer supported. Please specify executor_worker_path instead to specify the location of the trtllmExecutorWorker executable."
                )
            executor_worker_path = get_parameter(model_config,
                                                 "executor_worker_path")
            kwargs["orchestrator_config"] = trtllm.OrchestratorConfig(
                True, executor_worker_path)
        if len(kwargs) > 0:
            return trtllm.ParallelConfig(**kwargs)
        return None

    def get_peft_cache_config(self, model_config):
        kwargs = {
            "optimal_adapter_size":
            get_parameter(model_config, "lora_cache_optimal_adapter_size",
                          int),
            "max_adapter_size":
            get_parameter(model_config, "lora_cache_max_adapter_size", int),
            "device_cache_percent":
            get_parameter(model_config, "lora_cache_gpu_memory_fraction",
                          float),
            "host_cache_size":
            get_parameter(model_config, "lora_cache_host_memory_bytes", int),
        }
        kwargs = {k: v for k, v in kwargs.items() if v is not None}
        return trtllm.PeftCacheConfig(**kwargs)

    def get_decoding_config(self, model_config):
        kwargs = {
            "medusa_choices":
            parse_medusa_choices(get_parameter(model_config,
                                               "medusa_choices")),
            "decoding_mode":
            convert_decoding_mode(get_parameter(model_config,
                                                "decoding_mode")),
        }
        kwargs = {k: v for k, v in kwargs.items() if v is not None}
        return trtllm.DecodingConfig(**kwargs)

    def get_extended_runtime_perf_knob_config(self, model_config):
        kwargs = {
            "multi_block_mode":
            get_parameter(model_config, "multi_block_mode", bool),
            "enable_context_fmha_fp32_acc":
            get_parameter(model_config, "enable_context_fmha_fp32_acc", bool)
        }
        kwargs = {k: v for k, v in kwargs.items() if v is not None}
        return trtllm.ExtendedRuntimePerfKnobConfig(**kwargs)

    def get_executor_config(self, model_config):
        kwargs = {
            "max_beam_width":
            get_parameter(model_config, "max_beam_width", int),
            "scheduler_config":
            self.get_scheduler_config(model_config),
            "kv_cache_config":
            self.get_kv_cache_config(model_config),
            "enable_chunked_context":
            get_parameter(model_config, "enable_chunked_context", bool),
            "normalize_log_probs":
            get_parameter(model_config, "normalize_log_probs", bool),
            "batching_type":
            convert_batching_type(get_parameter(model_config,
                                                "gpt_model_type")),
            "parallel_config":
            self.get_parallel_config(model_config),
            "peft_cache_config":
            self.get_peft_cache_config(model_config),
            "decoding_config":
            self.get_decoding_config(model_config),
            "max_queue_size":
            model_config.get(
                "dynamic_batching",
                {},
            ).get(
                "default_queue_policy",
                {},
            ).get("max_queue_size"),
            "extended_runtime_perf_knob_config":
            self.get_extended_runtime_perf_knob_config(model_config)
        }
        kwargs = {k: v for k, v in kwargs.items() if v is not None}
        return trtllm.ExecutorConfig(**kwargs)

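    # Metrics are exposed as Triton gauge families and are updated from the
    # executor's iteration statistics by metrics_loop().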
    def create_metrics(self, model: str, version: str, is_v1_model: bool):
        self.request_metric_family = pb_utils.MetricFamily(
            name="nv_trt_llm_request_metrics",
            description="TRT LLM request metrics",
            kind=pb_utils.MetricFamily.GAUGE,
        )
        self.runtime_memory_metric_family = pb_utils.MetricFamily(
            name="nv_trt_llm_runtime_memory_metrics",
            description="TRT LLM runtime memory metrics",
            kind=pb_utils.MetricFamily.GAUGE,
        )
        self.kv_cache_metric_family = pb_utils.MetricFamily(
            name="nv_trt_llm_kv_cache_block_metrics",
            description="TRT LLM KV cache block metrics",
            kind=pb_utils.MetricFamily.GAUGE,
        )
        model_type = "v1" if is_v1_model else "inflight_batcher"
        self.model_type_metric_family = pb_utils.MetricFamily(
            name=f"nv_trt_llm_{model_type}_metrics",
            description=f"TRT LLM {model_type}-specific metrics",
            kind=pb_utils.MetricFamily.GAUGE,
        )
        self.general_metric_family = pb_utils.MetricFamily(
            name="nv_trt_llm_general_metrics",
            description="General TRT LLM metrics",
            kind=pb_utils.MetricFamily.GAUGE,
        )
        common_labels = {"model": model, "version": version}
        self.all_metrics = {
            # Request metrics
            "num_active_requests":
            self.request_metric_family.Metric(labels={
                "request_type": "active",
                **common_labels
            }),
            "max_num_active_requests":
            self.request_metric_family.Metric(labels={
                "request_type": "max",
                **common_labels
            }),
            "num_scheduled_requests":
            self.request_metric_family.Metric(labels={
                "request_type": "scheduled",
                **common_labels
            }),
            "num_context_requests":
            self.request_metric_family.Metric(labels={
                "request_type": "context",
                **common_labels
            }),
            # Runtime memory metrics
            "cpu_mem_usage":
            self.runtime_memory_metric_family.Metric(labels={
                "memory_type": "cpu",
                **common_labels
            }),
            "gpu_mem_usage":
            self.runtime_memory_metric_family.Metric(labels={
                "memory_type": "gpu",
                **common_labels
            }),
            "pinned_mem_usage":
            self.runtime_memory_metric_family.Metric(labels={
                "memory_type": "pinned",
                **common_labels
            }),
            # KV cache block metrics
            "max_num_blocks":
            self.kv_cache_metric_family.Metric(labels={
                "kv_cache_block_type": "max",
                **common_labels
            }),
            "free_num_blocks":
            self.kv_cache_metric_family.Metric(labels={
                "kv_cache_block_type": "free",
                **common_labels
            }),
            "used_num_blocks":
            self.kv_cache_metric_family.Metric(labels={
                "kv_cache_block_type": "used",
                **common_labels
            }),
            "tokens_per_block":
            self.kv_cache_metric_family.Metric(labels={
                "kv_cache_block_type": "tokens_per",
                **common_labels
            }),
            # General metrics
            "timestamp":
            self.general_metric_family.Metric(labels={
                "general_type": "timestamp",
                **common_labels
            }),
            "iter":
            self.general_metric_family.Metric(labels={
                "general_type": "iteration_counter",
                **common_labels
            }),
        }
        if is_v1_model:
            self.all_metrics.update({
                "num_ctx_tokens":
                self.model_type_metric_family.Metric(labels={
                    "v1_specific_metric": "total_context_tokens",
                    **common_labels
                }),
                "num_gen_tokens":
                self.model_type_metric_family.Metric(
                    labels={
                        "v1_specific_metric": "total_generation_tokens",
                        **common_labels
                    }),
                "empty_gen_slots":
                self.model_type_metric_family.Metric(
                    labels={
                        "v1_specific_metric": "empty_generation_slots",
                        **common_labels
                    }),
            })
        else:
            self.all_metrics.update({
                "num_ctx_tokens":
                self.model_type_metric_family.Metric(
                    labels={
                        "inflight_batcher_specific_metric":
                        "total_context_tokens",
                        **common_labels
                    }),
                "num_gen_requests":
                self.model_type_metric_family.Metric(
                    labels={
                        "inflight_batcher_specific_metric":
                        "generation_requests",
                        **common_labels
                    }),
                "micro_batch_id":
                self.model_type_metric_family.Metric(
                    labels={
                        "inflight_batcher_specific_metric": "micro_batch_id",
                        **common_labels
                    }),
                "num_paused_requests":
                self.model_type_metric_family.Metric(
                    labels={
                        "inflight_batcher_specific_metric": "paused_requests",
                        **common_labels
                    }),
            })

    def initialize(self, args):
        """`initialize` is called only once when the model is being loaded.
        Implementing `initialize` function is optional. This function allows
        the model to initialize any state associated with this model.

        Parameters
        ----------
        args : dict
          Both keys and values are strings. The dictionary keys and values are:
          * model_config: A JSON string containing the model configuration
          * model_instance_kind: A string containing model instance kind
          * model_instance_device_id: A string containing model instance device ID
          * model_repository: Model repository path
          * model_version: Model version
          * model_name: Model name
        """
        model_config = json.loads(args['model_config'])
        gpt_model_path = get_parameter(model_config, "gpt_model_path")
        if get_parameter(model_config, "enable_trt_overlap", bool):
            raise pb_utils.TritonModelException(
                "enable_trt_overlap=true is not supported.")
        self.exclude_input_from_output = get_parameter(
            model_config, "exclude_input_in_output", bool)
        executor_config = self.get_executor_config(model_config)
        self.executor = trtllm.Executor(gpt_model_path,
                                        trtllm.ModelType.DECODER_ONLY,
                                        executor_config)
        self.decoupled = pb_utils.using_decoupled_model_transaction_policy(
            model_config)
        self.cancellation_check_period_ms = get_parameter(
            model_config, "cancellation_check_period_ms", int) or 100
        self.stats_check_period_ms = get_parameter(
            model_config, "stats_check_period_ms", int) or 100

        if not self.decoupled:
            raise pb_utils.TritonModelException(
                "Please enable decoupled transaction policy in the model configuration to serve this model"
            )

        self.create_metrics(args["model_name"],
                            args["model_version"],
                            is_v1_model=executor_config.batching_type ==
                            trtllm.BatchingType.STATIC)
        self.triton_user_id_to_req_ids = {}
        self.triton_req_id_to_req_ids = {}
        self.req_id_to_request_data = {}
        self.lock = Lock()
        self.running = False
        self.awaiter_thread = Thread(target=self.awaiter_loop)
        self.cancellation_thread = Thread(target=self.cancellation_loop)
        self.metrics_thread = Thread(target=self.metrics_loop)
        if self.executor.can_enqueue_requests():
            self.running = True
            self.awaiter_thread.start()
            self.cancellation_thread.start()
            self.metrics_thread.start()
        else:
            # Ranks that cannot enqueue requests do not serve Triton traffic.
            self.executor.shutdown()

    def handle_stop_request(self, triton_user_id, response_sender):
        if triton_user_id is None or triton_user_id == "":
            response_sender.send(
                pb_utils.InferenceResponse(error=pb_utils.TritonError(
                    "A request id must be provided for request cancellation")),
                flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
            return

        with self.lock:
            if triton_user_id in self.triton_user_id_to_req_ids:
                req_ids = self.triton_user_id_to_req_ids[triton_user_id]
                for req_id in req_ids:
                    self.executor.cancel_request(req_id)

        response_sender.send(
            pb_utils.InferenceResponse(),
            flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)

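    # execute() only enqueues work: converted requests are handed to the
    # executor and their response senders are stored in req_id_to_request_data.
    # Responses are sent asynchronously from awaiter_loop().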
    def execute(self, requests):
        """`execute` must be implemented in every Python model. `execute`
        function receives a list of pb_utils.InferenceRequest as the only
        argument. This function is called when an inference is requested
        for this model.

        Parameters
        ----------
        requests : list
          A list of pb_utils.InferenceRequest

        Returns
        -------
        list
          A list of pb_utils.InferenceResponse. The length of this list must
          be the same as `requests`
        """
        if not self.executor.can_enqueue_requests():
            return

        # Convert the Triton requests to executor requests.
        triton_requests = []
        executor_requests = []
        batch_indices = []
        triton_user_ids = []
        triton_req_ids = []

        for request in requests:

            triton_user_id = request.request_id()

            response_sender = request.get_response_sender()
            stop = get_input_scalar_by_name(request, 'stop')

            if stop:
                self.handle_stop_request(triton_user_id, response_sender)
            else:
                # Internal id used to track all executor requests spawned by
                # this Triton request.
                triton_req_id = str(randint(0, sys.maxsize))
                self.triton_req_id_to_req_ids[triton_req_id] = set()
                if triton_user_id is not None and triton_user_id != "":
                    self.triton_user_id_to_req_ids[triton_user_id] = set()

                try:
                    converted_reqs = convert_request(
                        request, self.exclude_input_from_output,
                        self.decoupled)
                except Exception as e:
                    response_sender.send(
                        pb_utils.InferenceResponse(error=pb_utils.TritonError(
                            f"An error occurred when processing the input values for request id {request.request_id()}, the error was '{e}'"
                        )),
                        flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
                else:
                    for batch_index, converted_req in enumerate(
                            converted_reqs):
                        triton_requests.append(request)
                        executor_requests.append(converted_req)
                        triton_user_ids.append(triton_user_id)
                        triton_req_ids.append(triton_req_id)
                        batch_indices.append(batch_index)

        with self.lock:
            request_ids = self.executor.enqueue_requests(executor_requests)
            for req_id, triton_req_id, triton_user_id, triton_request, batch_index in zip(
                    request_ids, triton_req_ids, triton_user_ids,
                    triton_requests, batch_indices):
                self.req_id_to_request_data[
                    req_id] = triton_req_id, triton_user_id, batch_index, triton_request.get_response_sender(
                    )
                self.triton_req_id_to_req_ids[triton_req_id].add(req_id)
                if triton_user_id is not None and triton_user_id != "":
                    self.triton_user_id_to_req_ids[triton_user_id].add(req_id)

        return None

    def awaiter_loop(self):
        """Gets responses from executor and returns the results."""
        while self.running:
            for response in self.executor.await_responses(
                    timeout=datetime.timedelta(milliseconds=1)):
                req_id = response.request_id
                with self.lock:
                    if req_id not in self.req_id_to_request_data:
                        continue
                    triton_req_id, triton_user_id, batch_index, response_sender = self.req_id_to_request_data[
                        req_id]

                triton_response, is_final = convert_response(
                    response, batch_index)

                triton_request_final = False
                if is_final:
                    with self.lock:
                        # Check whether all executor requests spawned by this
                        # Triton request have finished.
                        self.triton_req_id_to_req_ids[triton_req_id].remove(
                            req_id)
                        if len(self.triton_req_id_to_req_ids[triton_req_id]
                               ) == 0:
                            pb_utils.Logger.log_info(
                                f"DELETING Req id {req_id}, triton_req_id {triton_req_id} "
                            )
                            triton_request_final = True
                            del self.triton_req_id_to_req_ids[triton_req_id]
                            if triton_user_id is not None and triton_user_id != "":
                                del self.triton_user_id_to_req_ids[
                                    triton_user_id]
                        del self.req_id_to_request_data[req_id]

                response_sender.send(
                    triton_response,
                    flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
                    if triton_request_final else 0)

                # Remove the local reference so the response sender can be
                # cleaned up.
                del response_sender

    def cancellation_loop(self):
        """Checks if any pending requests have been cancelled."""
        while self.running:
            time.sleep(self.cancellation_check_period_ms / 1000.0)
            with self.lock:
                for req_id, (triton_req_id, triton_user_id, batch_index,
                             response_sender
                             ) in self.req_id_to_request_data.items():
                    if response_sender.is_cancelled():
                        self.executor.cancel_request(req_id)
                    # Remove the local reference so the response sender can
                    # be cleaned up.
                    del response_sender

    def metrics_loop(self):
        """Updates triton metrics using stats from the executor."""
        while self.running:
            time.sleep(self.stats_check_period_ms / 1000.0)
            for stat in self.executor.get_latest_iteration_stats():
                try:
                    for key, metric in self.all_metrics.items():
                        value = None
                        if hasattr(stat, key):
                            value = getattr(stat, key)
                        elif stat.kv_cache_stats is not None and hasattr(
                                stat.kv_cache_stats, key):
                            value = getattr(stat.kv_cache_stats, key)
                        elif stat.static_batching_stats is not None and hasattr(
                                stat.static_batching_stats, key):
                            value = getattr(stat.static_batching_stats, key)
                        elif stat.inflight_batching_stats is not None and hasattr(
                                stat.inflight_batching_stats, key):
                            value = getattr(stat.inflight_batching_stats, key)
                        if value is not None:
                            if key == "timestamp":
                                value = convert_timestamp_to_seconds(value)
                            metric.set(value)
                        else:
                            pb_utils.Logger.log_warn(
                                f"Metric \"{key}\" not found.")
                except Exception as e:
                    pb_utils.Logger.log_warn(
                        f"Error while processing metrics: {e}")

    def finalize(self):
        """`finalize` is called only once when the model is being unloaded.
        Implementing `finalize` function is optional. This function allows
        the model to perform any necessary clean ups before exit.
        """
        if self.executor.can_enqueue_requests():
            self.running = False
            self.awaiter_thread.join()
            self.cancellation_thread.join()
            self.metrics_thread.join()
            self.executor.shutdown()