moplaymobile committed
Commit fdbbf2e · verified · 1 Parent(s): 60aa6ef

Upload folder using huggingface_hub

.DS_Store ADDED
Binary file (6.15 kB).
 
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ whisper_large_v3_max_batch_512/decoder/rank0.engine filter=lfs diff=lfs merge=lfs -text
+ whisper_large_v3_max_batch_512/encoder/rank0.engine filter=lfs diff=lfs merge=lfs -text
model_repo_whisper_512/tensorrt_llm/1/.gitkeep ADDED
File without changes
model_repo_whisper_512/tensorrt_llm/1/model.py ADDED
@@ -0,0 +1,1518 @@
1
+ import datetime
2
+ import json
3
+ import os
4
+ import sys
5
+ import time
6
+ from dataclasses import dataclass
7
+ from random import randint
8
+ from threading import Lock, Thread
9
+ from typing import Any, List
10
+
11
+ import numpy as np
12
+ import pandas as pd
13
+ import torch
14
+ import triton_python_backend_utils as pb_utils
15
+ from torch import from_numpy
16
+ from torch.utils.dlpack import from_dlpack
17
+
18
+ import tensorrt_llm.bindings.executor as trtllm
19
+ from tensorrt_llm.llmapi.tokenizer import _xgrammar_tokenizer_info
20
+
21
+ METRIC_TOTAL_OUTPUT_TOKENS = "total_output_tokens"
22
+ METRIC_TOTAL_INPUT_TOKENS = "total_input_tokens"
23
+ import tensorrt_llm.logger as logger
24
+
25
+ # From https://github.com/pytorch/pytorch/blob/39425feac799905402abe4d15667fa47c344f2d7/torch/testing/_internal/common_utils.py#L1761
26
+ # Dict of NumPy dtype -> torch dtype (when the correspondence exists)
27
+ numpy_to_torch_dtype_dict = {
28
+ np.bool_: torch.bool,
29
+ np.uint8: torch.uint8,
30
+ np.uint16: torch.uint16,
31
+ np.uint32: torch.uint32,
32
+ np.uint64: torch.uint64,
33
+ np.int8: torch.int8,
34
+ np.int16: torch.int16,
35
+ np.int32: torch.int32,
36
+ np.int64: torch.int64,
37
+ np.float16: torch.float16,
38
+ np.float32: torch.float32,
39
+ np.float64: torch.float64,
40
+ np.complex64: torch.complex64,
41
+ np.complex128: torch.complex128
42
+ }
43
+
44
+ # Dict of torch dtype -> NumPy dtype
45
+ torch_to_numpy_dtype_dict = {
46
+ value: key
47
+ for (key, value) in numpy_to_torch_dtype_dict.items()
48
+ }
49
+ torch_to_numpy_dtype_dict.update({
50
+ torch.bfloat16: np.float32,
51
+ torch.complex32: np.complex64
52
+ })
53
+
54
+
55
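+ # Per-request bookkeeping kept while an executor request is in flight: the
+ # originating Triton request/user ids, its position and size in the batch, and
+ # running token counts used when updating the per-request metrics.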
+ @dataclass
56
+ class RequestData:
57
+ triton_req_id: int
58
+ triton_user_id: str
59
+ batch_index: int
60
+ batch_size: int
61
+ num_return_sequences: int
62
+ num_input_tokens: int
63
+ num_output_tokens: int
64
+ response_sender: Any
65
+
66
+
67
+ def mpi_comm():
68
+ from mpi4py import MPI
69
+ return MPI.COMM_WORLD
70
+
71
+
72
+ def mpi_rank():
73
+ return mpi_comm().Get_rank()
74
+
75
+
76
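+ # Fetch a named input tensor from the request. CPU tensors come back as NumPy
+ # arrays (unless force_on_torch), GPU tensors as torch tensors via DLPack; the
+ # batch size is validated and, if batch_index is given, a single element is
+ # returned with a leading batch dimension of 1.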
+ def get_input_tensor_by_name(request,
77
+ name,
78
+ expected_batch_size=None,
79
+ batch_index=None,
80
+ force_on_torch=False):
81
+ tensor = pb_utils.get_input_tensor_by_name(request, name)
82
+ if tensor is None:
83
+ return None
84
+
85
+ if tensor.is_cpu() and not force_on_torch:
86
+ tensor = tensor.as_numpy()
87
+ else:
88
+ tensor = from_dlpack(tensor.to_dlpack())
89
+
90
+ if expected_batch_size is not None and tensor.shape[
91
+ 0] != expected_batch_size:
92
+ raise pb_utils.TritonModelException(
93
+ f"Expected batch size doesn't match batch size for tensor {name}. Expected {expected_batch_size} got {tensor.shape[0]}"
94
+ )
95
+
96
+ if batch_index is not None and expected_batch_size is not None and batch_index >= expected_batch_size:
97
+ raise pb_utils.TritonModelException(
98
+ f"Invalid batch index in get_input_tensor_by_name for {name}")
99
+
100
+ if batch_index is not None:
101
+ # Add leading 1 batch dimension
102
+ if isinstance(tensor, np.ndarray):
103
+ return np.expand_dims(tensor[batch_index], axis=0)
104
+ elif isinstance(tensor, torch.Tensor):
105
+ return torch.unsqueeze(tensor[batch_index], dim=0)
106
+ else:
107
+ return tensor
108
+
109
+
110
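+ # Fetch a named input and return its scalar value for the given batch index.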
+ def get_input_scalar_by_name(request,
111
+ name,
112
+ expected_batch_size=1,
113
+ batch_index=0):
114
+ tensor = pb_utils.get_input_tensor_by_name(request, name)
115
+ if tensor is None:
116
+ return None
117
+ tensor = tensor.as_numpy()
118
+
119
+ if tensor.size != expected_batch_size:
120
+ raise pb_utils.TritonModelException(
121
+ f"Expected a scalar tensor for tensor {name}")
122
+
123
+ return tensor.item(batch_index)
124
+
125
+
126
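+ # Parse a string parameter from the model config; empty strings and unexpanded
+ # ${...} placeholders are treated as unset so the caller falls back to defaults.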
+ def read_parameter_as_type(value, name, pytype=str):
127
+ if value == "":
128
+ return None
129
+ if value.startswith("${") and value.endswith("}"):
130
+ return None
131
+ if pytype is bool:
132
+ return value.lower() in ["1", "true"]
133
+ try:
134
+ result = pytype(value)
135
+ return result
136
+ except Exception:
137
+ pb_utils.Logger.log_warn(
138
+ f"Could not read parameter '{name}' with value '{value}', will use default."
139
+ )
140
+ return None
141
+
142
+
143
+ def get_parameter(model_config, name, pytype=str):
144
+ if name not in model_config['parameters']:
145
+ return None
146
+ return read_parameter_as_type(
147
+ model_config['parameters'][name]['string_value'], name, pytype)
148
+
149
+
150
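+ # Convert the flattened [token_ids, cumulative_end_offsets] stop/bad words tensor
+ # into a list of token-id lists; an offset of -1 marks padding and is skipped.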
+ def convert_word_list(word_list):
151
+ if word_list is None:
152
+ return None
153
+ word_list = word_list.tolist()
154
+ if len(word_list) == 0 or len(word_list[0]) != 2:
155
+ raise pb_utils.TritonModelException(f"Invalid format for word list.")
156
+ words, indices = word_list[0]
157
+ result = []
158
+ current_index = 0
159
+ for i in indices:
160
+ if i == -1:
161
+ continue
162
+ if i > len(words):
163
+ raise pb_utils.TritonModelException(
164
+ f"Invalid format for word list.")
165
+ current_word = []
166
+ while current_index < i:
167
+ current_word.append(words[current_index])
168
+ current_index += 1
169
+ result.append(current_word)
170
+ return result
171
+
172
+
173
+ def parse_medusa_choices(medusa_choices):
174
+ if medusa_choices is None:
175
+ return None
176
+ try:
177
+ result = json.loads(
178
+ "[" + medusa_choices.replace("{", "[").replace("}", "]") + "]")
179
+ assert isinstance(result, list) and len(result) > 0
180
+ assert all([isinstance(x, list) for x in result])
181
+ assert all([isinstance(y, int) for x in result for y in x])
182
+ except Exception:
183
+ raise pb_utils.TritonModelException(
184
+ "Invalid format for medusa_choices")
185
+ return result
186
+
187
+
188
+ def parse_eagle_choices(eagle_choices):
189
+ return parse_medusa_choices(eagle_choices)
190
+
191
+
192
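+ # Collect the per-request sampling inputs (beam width, top-k/top-p, penalties,
+ # seed, etc.) and build a trtllm.SamplingConfig, dropping unset fields.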
+ def get_sampling_config_from_request(request, batch_size=1, batch_index=0):
193
+ kwargs = {}
194
+ kwargs['beam_width'] = get_input_scalar_by_name(
195
+ request, 'beam_width', batch_size, batch_index) or 1
196
+ kwargs['top_k'] = get_input_scalar_by_name(request, 'runtime_top_k',
197
+ batch_size, batch_index)
198
+ kwargs['top_p'] = get_input_scalar_by_name(request, 'runtime_top_p',
199
+ batch_size, batch_index)
200
+ kwargs['top_p'] = None if kwargs['top_p'] is None or kwargs[
201
+ 'top_p'] <= 0 else kwargs['top_p']
202
+ kwargs['random_seed'] = get_input_scalar_by_name(request, 'random_seed',
203
+ batch_size, batch_index)
204
+ kwargs['temperature'] = get_input_scalar_by_name(request, 'temperature',
205
+ batch_size, batch_index)
206
+ kwargs['min_length'] = get_input_scalar_by_name(request, 'min_length',
207
+ batch_size, batch_index)
208
+ kwargs['repetition_penalty'] = get_input_scalar_by_name(
209
+ request, 'repetition_penalty', batch_size, batch_index)
210
+ kwargs['presence_penalty'] = get_input_scalar_by_name(
211
+ request, 'presence_penalty', batch_size, batch_index)
212
+ kwargs['frequency_penalty'] = get_input_scalar_by_name(
213
+ request, 'frequency_penalty', batch_size, batch_index)
214
+ kwargs['length_penalty'] = get_input_scalar_by_name(
215
+ request, 'len_penalty', batch_size, batch_index)
216
+ kwargs['top_p_min'] = get_input_scalar_by_name(request,
217
+ 'runtime_top_p_min',
218
+ batch_size, batch_index)
219
+ kwargs['top_p_reset_ids'] = get_input_scalar_by_name(
220
+ request, 'runtime_top_p_reset_ids', batch_size, batch_index)
221
+ kwargs['top_p_decay'] = get_input_scalar_by_name(request,
222
+ 'runtime_top_p_decay',
223
+ batch_size, batch_index)
224
+ kwargs['beam_search_diversity_rate'] = get_input_scalar_by_name(
225
+ request, 'beam_search_diversity_rate', batch_size, batch_index)
226
+ kwargs['early_stopping'] = get_input_scalar_by_name(
227
+ request, 'early_stopping', batch_size, batch_index)
228
+ kwargs['num_return_sequences'] = get_input_scalar_by_name(
229
+ request, 'num_return_sequences', batch_size, batch_index) or 1
230
+ kwargs = {k: v for k, v in kwargs.items() if v is not None}
231
+ return trtllm.SamplingConfig(**kwargs)
232
+
233
+
234
+ def get_output_config_from_request(request, batch_size=1, batch_index=0):
235
+ kwargs = {}
236
+ kwargs["return_log_probs"] = get_input_scalar_by_name(
237
+ request, 'return_log_probs', batch_size, batch_index)
238
+ kwargs["return_context_logits"] = get_input_scalar_by_name(
239
+ request, 'return_context_logits', batch_size, batch_index)
240
+ kwargs["return_generation_logits"] = get_input_scalar_by_name(
241
+ request, 'return_generation_logits', batch_size, batch_index)
242
+ kwargs["return_perf_metrics"] = get_input_scalar_by_name(
243
+ request, 'return_perf_metrics', batch_size, batch_index)
244
+ if get_input_scalar_by_name(request, 'return_kv_cache_reuse_stats',
245
+ batch_size, batch_index):
246
+ pb_utils.Logger.log_warn(
247
+ "return_kv_cache_reuse_stats is deprecated, please use return_perf_metrics instead."
248
+ )
249
+ kwargs = {k: v for k, v in kwargs.items() if v is not None}
250
+ return trtllm.OutputConfig(**kwargs)
251
+
252
+
253
+ def get_external_draft_tokens_config_from_request(request,
254
+ batch_size=1,
255
+ batch_index=0):
256
+ kwargs = {}
257
+ draft_input_ids = get_input_tensor_by_name(request, 'draft_input_ids',
258
+ batch_size, batch_index)
259
+ if draft_input_ids is not None:
260
+ kwargs['tokens'] = draft_input_ids[0].tolist()
261
+ draft_logits = get_input_tensor_by_name(request, 'draft_logits',
262
+ batch_size, batch_index)
263
+ if draft_logits is not None:
264
+ kwargs['logits'] = from_numpy(draft_logits).squeeze(dim=0)
265
+ kwargs['acceptance_threshold'] = get_input_scalar_by_name(
266
+ request, 'draft_acceptance_threshold', batch_size, batch_index)
267
+ kwargs = {k: v for k, v in kwargs.items() if v is not None}
268
+ if len(kwargs) > 0:
269
+ return trtllm.ExternalDraftTokensConfig(**kwargs)
270
+ return None
271
+
272
+
273
+ def get_prompt_tuning_config_from_request(request,
274
+ batch_size=1,
275
+ batch_index=0,
276
+ input_length=0):
277
+ # prompt_vocab_size is unused by executor.
278
+ kwargs = {}
279
+ prompt_embedding_table = get_input_tensor_by_name(
280
+ request, 'prompt_embedding_table', batch_size, batch_index)
281
+ prompt_table_extra_ids = get_input_tensor_by_name(
282
+ request, 'prompt_table_extra_ids', batch_size, batch_index)
283
+ if prompt_embedding_table is not None:
284
+ if isinstance(prompt_embedding_table, np.ndarray):
285
+ kwargs["embedding_table"] = from_numpy(
286
+ prompt_embedding_table).squeeze(dim=0)
287
+ elif isinstance(prompt_embedding_table, torch.Tensor):
288
+ kwargs["embedding_table"] = prompt_embedding_table.squeeze(dim=0)
289
+
290
+ if prompt_table_extra_ids is not None:
291
+ prompt_table_extra_ids = prompt_table_extra_ids[0].tolist()
292
+ if len(prompt_table_extra_ids) != 0:
293
+ kwargs["input_token_extra_ids"] = prompt_table_extra_ids[
294
+ 0:input_length]
295
+ kwargs = {k: v for k, v in kwargs.items() if v is not None}
296
+ if len(kwargs) > 0:
297
+ return trtllm.PromptTuningConfig(**kwargs)
298
+ return None
299
+
300
+
301
+ def get_lora_config_from_request(request, batch_size=1, batch_index=0):
302
+ kwargs = {}
303
+ kwargs["task_id"] = get_input_scalar_by_name(request, 'lora_task_id',
304
+ batch_size, batch_index)
305
+ lora_weights = get_input_tensor_by_name(request, 'lora_weights',
306
+ batch_size, batch_index)
307
+ if lora_weights is not None:
308
+ kwargs["weights"] = from_numpy(lora_weights).squeeze(dim=0)
309
+ lora_config = get_input_tensor_by_name(request, 'lora_config', batch_size,
310
+ batch_index)
311
+ if lora_config is not None:
312
+ kwargs["config"] = from_numpy(lora_config).squeeze(dim=0)
313
+ kwargs = {k: v for k, v in kwargs.items() if v is not None}
314
+ if len(kwargs) > 0:
315
+ return trtllm.LoraConfig(**kwargs)
316
+ return None
317
+
318
+
319
+ def get_guided_decoding_params_from_request(request,
320
+ batch_size=1,
321
+ batch_index=0):
322
+ kwargs = {}
323
+ guided_decoding_guide_type = get_input_tensor_by_name(
324
+ request, 'guided_decoding_guide_type', batch_size, batch_index)
325
+ if guided_decoding_guide_type is not None:
326
+ guided_decoding_guide_type = guided_decoding_guide_type.squeeze(
327
+ axis=0)[0].decode()
328
+ guided_decoding_guide_type_mapping = {
329
+ "json": trtllm.GuidedDecodingParams.GuideType.JSON,
330
+ "json_schema": trtllm.GuidedDecodingParams.GuideType.JSON_SCHEMA,
331
+ "regex": trtllm.GuidedDecodingParams.GuideType.REGEX,
332
+ "ebnf_grammar": trtllm.GuidedDecodingParams.GuideType.EBNF_GRAMMAR
333
+ }
334
+ guided_decoding_guide_type = guided_decoding_guide_type_mapping.get(
335
+ guided_decoding_guide_type)
336
+ kwargs['guide_type'] = guided_decoding_guide_type
337
+
338
+ guided_decoding_guide = get_input_tensor_by_name(request,
339
+ 'guided_decoding_guide',
340
+ batch_size, batch_index)
341
+ if guided_decoding_guide is not None:
342
+ kwargs['guide'] = guided_decoding_guide.squeeze(axis=0)[0].decode()
343
+ kwargs = {k: v for k, v in kwargs.items() if v is not None}
344
+ if len(kwargs) > 0:
345
+ return trtllm.GuidedDecodingParams(**kwargs)
346
+ return None
347
+
348
+
349
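+ # Build an optional trtllm.KvCacheRetentionConfig from the retention_* inputs.
+ # All token-range tensors must have the same length; -1 encodes an open-ended
+ # range end or an unlimited duration.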
+ def get_kv_cache_retention_config_from_request(request,
350
+ batch_size=1,
351
+ batch_index=0):
352
+
353
+ def get_tensor_and_check_length(name: str, expected_length: int):
354
+ tensor = get_input_tensor_by_name(request, name, batch_size,
355
+ batch_index)
356
+
357
+ if tensor is None:
358
+ raise RuntimeError(f"{name} must be provided.")
359
+
360
+ tensor = np.squeeze(tensor, axis=0)
361
+
362
+ if len(tensor) != expected_length:
363
+ raise RuntimeError(
364
+ f"Invalid {name} length. Expected length {expected_length}, got length {len(tensor)}"
365
+ )
366
+
367
+ return tensor
368
+
369
+ token_range_starts = get_input_tensor_by_name(
370
+ request, "retention_token_range_starts", batch_size, batch_index)
371
+
372
+ if token_range_starts is not None:
373
+ token_range_starts = np.squeeze(token_range_starts, axis=0)
374
+
375
+ token_range_ends = get_tensor_and_check_length(
376
+ "retention_token_range_ends", len(token_range_starts))
377
+ token_range_ends = [
378
+ None if end == -1 else end for end in token_range_ends
379
+ ]
380
+
381
+ token_range_priorities = get_tensor_and_check_length(
382
+ "retention_token_range_priorities", len(token_range_starts))
383
+
384
+ token_range_durations_ms = get_input_tensor_by_name(
385
+ request, "retention_token_range_durations_ms", batch_size,
386
+ batch_index)
387
+
388
+ if token_range_durations_ms is None:
389
+ token_range_durations_ms = [None] * len(token_range_starts)
390
+ else:
391
+ token_range_durations_ms = np.squeeze(token_range_durations_ms,
392
+ axis=0)
393
+ token_range_durations_ms = [
394
+ None if duration == -1 else duration
395
+ for duration in token_range_durations_ms
396
+ ]
397
+
398
+ if len(token_range_durations_ms) != len(token_range_starts):
399
+ raise RuntimeError(
400
+ f"Invalid retention_token_range_durations length. Expected length {len(token_range_starts)}, got length {len(token_range_durations_ms)}"
401
+ )
402
+
403
+ ranges = []
404
+
405
+ for start, end, priority, duration_ms in zip(token_range_starts,
406
+ token_range_ends,
407
+ token_range_priorities,
408
+ token_range_durations_ms):
409
+ ranges.append(
410
+ trtllm.KvCacheRetentionConfig.TokenRangeRetentionConfig(
411
+ token_start=start,
412
+ token_end=end,
413
+ priority=priority.item(),
414
+ duration_ms=None if duration_ms is None else
415
+ datetime.timedelta(milliseconds=duration_ms.item())))
416
+
417
+ decode_args = {}
418
+
419
+ decode_priority = get_input_scalar_by_name(
420
+ request, "retention_decode_priority", batch_size, batch_index)
421
+ if decode_priority is not None:
422
+ decode_args['decode_retention_priority'] = decode_priority
423
+
424
+ decode_duration_ms = get_input_scalar_by_name(
425
+ request, "retention_decode_duration_ms", batch_size, batch_index)
426
+ if decode_duration_ms is not None:
427
+ decode_args[
428
+ 'decode_duration_ms'] = decode_duration_ms if decode_duration_ms != -1 else None
429
+
430
+ return trtllm.KvCacheRetentionConfig(
431
+ token_range_retention_configs=ranges, **decode_args)
432
+
433
+ return None
434
+
435
+
436
+ def get_lookahead_decoding_config_from_request(request,
437
+ executor_lookahead_config,
438
+ batch_size=1,
439
+ batch_index=0):
440
+ lookahead_window_size = get_input_tensor_by_name(request,
441
+ "lookahead_window_size",
442
+ batch_size, batch_index)
443
+
444
+ lookahead_ngram_size = get_input_tensor_by_name(request,
445
+ "lookahead_ngram_size",
446
+ batch_size, batch_index)
447
+
448
+ lookahead_verification_set_size = get_input_tensor_by_name(
449
+ request, "lookahead_verification_set_size", batch_size, batch_index)
450
+
451
+ # No per-request lookahead config was provided.
452
+ if all(x is None for x in [
453
+ lookahead_window_size, lookahead_ngram_size,
454
+ lookahead_verification_set_size
455
+ ]):
456
+ return None
457
+
458
+ # Have request lookahead config but no executor config.
459
+ if executor_lookahead_config is None:
460
+ raise RuntimeError(
461
+ "The request lookahead decoding input tensors (window_size, ngram_size and verification_set_size) can only be set if the model instance lookahead parameters are also specified"
462
+ )
463
+
464
+ return trtllm.LookaheadDecodingConfig(lookahead_window_size,
465
+ lookahead_ngram_size,
466
+ lookahead_verification_set_size)
467
+
468
+
469
+ def build_1_2_5_buckets(max_value: int) -> List[int]:
470
+ """
471
+ Builds a list of buckets with increasing powers of 10 multiplied by
472
+ mantissa values (1, 5), starting from 10 until the value exceeds
473
+ the specified maximum.
474
+
475
+ Example:
476
+ >>> build_1_2_5_buckets(1000)
477
+ [10, 50, 100, 500, 1000]
478
+ """
479
+ mantissa_lst = [1, 5]
480
+ exponent = 1 # Start from exponent 1 instead of 0
481
+ buckets: List[int] = []
482
+ while True:
483
+ for m in mantissa_lst:
484
+ value = m * 10**exponent
485
+ if value <= max_value:
486
+ buckets.append(value)
487
+ else:
488
+ return buckets
489
+ exponent += 1
490
+
491
+
492
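+ # Convert one Triton request (which may carry a batch along dim 0) into a list
+ # of trtllm.Request objects, one per batch element, wiring up sampling, output,
+ # draft-token, prompt-tuning, LoRA, guided-decoding, lookahead and KV-cache
+ # retention options.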
+ def convert_request(request,
493
+ exclude_input_from_output,
494
+ decoupled,
495
+ executor_lookahead_config=None):
496
+ inputs = {}
497
+ input_token_ids = get_input_tensor_by_name(request, 'input_ids')
498
+ if input_token_ids is None:
499
+ raise pb_utils.TritonModelException(
500
+ "A value is required for input_ids")
501
+ if len(input_token_ids.shape) != 2:
502
+ raise pb_utils.TritonModelException(f"Invalid format for input_ids")
503
+ batch_size = input_token_ids.shape[0]
504
+ requests = []
505
+ for batch_index in range(0, batch_size):
506
+ input_token_ids = get_input_tensor_by_name(request, 'input_ids',
507
+ batch_size, batch_index)[0]
508
+ if input_token_ids is None:
509
+ raise pb_utils.TritonModelException(
510
+ "A value is required for input_ids")
511
+ input_token_ids = input_token_ids.tolist()
512
+ if len(input_token_ids) == 0:
513
+ raise pb_utils.TritonModelException(
514
+ f"Invalid format for input_ids")
515
+
516
+ input_length = get_input_scalar_by_name(request, 'input_lengths',
517
+ batch_size, batch_index)
518
+ if input_length is None:
519
+ input_length = len(input_token_ids)
520
+ # Trim input token ids with input_lengths
521
+ inputs['input_token_ids'] = input_token_ids[0:input_length]
522
+ inputs['max_new_tokens'] = get_input_scalar_by_name(
523
+ request, 'request_output_len', batch_size, batch_index)
524
+ if inputs['max_new_tokens'] is None:
525
+ raise pb_utils.TritonModelException(
526
+ "A value is required for request_output_len")
527
+ inputs['streaming'] = get_input_scalar_by_name(request, 'streaming',
528
+ batch_size, batch_index)
529
+ if inputs['streaming'] and not decoupled:
530
+ raise pb_utils.TritonModelException(
531
+ "Streaming is only supported in decoupled mode.")
532
+
533
+ inputs['end_id'] = get_input_scalar_by_name(request, 'end_id',
534
+ batch_size, batch_index)
535
+ inputs['pad_id'] = get_input_scalar_by_name(request, 'pad_id',
536
+ batch_size, batch_index)
537
+ inputs['stop_words'] = convert_word_list(
538
+ get_input_tensor_by_name(request, 'stop_words_list', batch_size,
539
+ batch_index))
540
+ inputs['bad_words'] = convert_word_list(
541
+ get_input_tensor_by_name(request, 'bad_words_list', batch_size,
542
+ batch_index))
543
+ embedding_bias = get_input_tensor_by_name(request, 'embedding_bias',
544
+ batch_size, batch_index)
545
+ if embedding_bias is not None and embedding_bias.size != 0:
546
+ inputs['embedding_bias'] = from_numpy(embedding_bias).squeeze(
547
+ dim=0)
548
+
549
+ sampling_config = get_sampling_config_from_request(
550
+ request, batch_size, batch_index)
551
+ output_config = get_output_config_from_request(request, batch_size,
552
+ batch_index)
553
+ req_exclude_input_from_output = get_input_scalar_by_name(
554
+ request, 'exclude_input_in_output', batch_size, batch_index)
555
+ if req_exclude_input_from_output is None:
556
+ # if request doesn't specify exclude_input_from_output, try to use the parameter
557
+ output_config.exclude_input_from_output = (
558
+ exclude_input_from_output
559
+ if exclude_input_from_output is not None else False)
560
+ else:
561
+ output_config.exclude_input_from_output = req_exclude_input_from_output
562
+
563
+ external_draft_tokens_config = get_external_draft_tokens_config_from_request(
564
+ request, batch_size, batch_index)
565
+ prompt_tuning_config = get_prompt_tuning_config_from_request(
566
+ request, batch_size, batch_index, input_length)
567
+ lora_config = get_lora_config_from_request(request, batch_size,
568
+ batch_index)
569
+ kv_cache_retention_config = get_kv_cache_retention_config_from_request(
570
+ request, batch_size, batch_index)
571
+ request_lookahead_config = get_lookahead_decoding_config_from_request(
572
+ request, executor_lookahead_config, batch_size, batch_index)
573
+
574
+ # Inputs for mllama support
575
+ encoder_input_features = get_input_tensor_by_name(
576
+ request, 'encoder_input_features', batch_size, batch_index)
577
+ if encoder_input_features is not None:
578
+ if isinstance(encoder_input_features, np.ndarray):
579
+ encoder_input_features = from_numpy(
580
+ encoder_input_features).squeeze(dim=0)
581
+ elif isinstance(encoder_input_features, torch.Tensor):
582
+ encoder_input_features = encoder_input_features.squeeze(dim=0)
583
+ inputs['encoder_input_features'] = encoder_input_features
584
+ logger.debug(
585
+ f"inputs to llm: encoder_input_features ({encoder_input_features.shape}"
586
+ )
587
+
588
+ encoder_output_length = get_input_tensor_by_name(
589
+ request, 'encoder_output_lengths', batch_size, batch_index)
590
+ if encoder_output_length is not None:
591
+ inputs['encoder_output_length'] = np.squeeze(
592
+ encoder_output_length, axis=0)
593
+
594
+ cross_attention_mask = get_input_tensor_by_name(
595
+ request, 'cross_attention_mask', batch_size, batch_index)
596
+ if cross_attention_mask is not None:
597
+ inputs['cross_attention_mask'] = cross_attention_mask[0]
598
+ logger.debug(
599
+ f"inputs to llm: cross_attention_mask ({ cross_attention_mask.shape})"
600
+ )
601
+
602
+ skip_cross_attn_blocks = get_input_tensor_by_name(
603
+ request,
604
+ 'skip_cross_attn_blocks',
605
+ batch_size,
606
+ batch_index,
607
+ force_on_torch=True)
608
+ if skip_cross_attn_blocks is not None:
609
+ inputs['skip_cross_attn_blocks'] = skip_cross_attn_blocks[0]
610
+ logger.debug(
611
+ f"inputs to llm: skip_cross_attn_blocks ({ skip_cross_attn_blocks.shape})"
612
+ )
613
+
614
+ guided_decoding_params = get_guided_decoding_params_from_request(
615
+ request, batch_size, batch_index)
616
+
617
+ requests.append(
618
+ trtllm.Request(
619
+ **inputs,
620
+ sampling_config=sampling_config,
621
+ output_config=output_config,
622
+ external_draft_tokens_config=external_draft_tokens_config,
623
+ prompt_tuning_config=prompt_tuning_config,
624
+ lora_config=lora_config,
625
+ guided_decoding_params=guided_decoding_params,
626
+ lookahead_config=request_lookahead_config,
627
+ kv_cache_retention_config=kv_cache_retention_config))
628
+ return requests
629
+
630
+
631
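+ # Convert an executor response into a Triton InferenceResponse. Beams are padded
+ # to the longest length with -1 and optional logits, log-prob and perf-metric
+ # outputs are attached; returns (response, is_final, size of the padded output_ids
+ # array), where the last value feeds the output-token metrics.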
+ def convert_response(response,
632
+ batch_index,
633
+ batch_size,
634
+ num_return_sequences,
635
+ expected_logits_dtype=torch.float32):
636
+
637
+ if response.has_error():
638
+ return pb_utils.InferenceResponse(output_tensors=[],
639
+ error=pb_utils.TritonError(
640
+ response.error_msg)), True, 0
641
+ result = response.result
642
+ beam_lengths = np.expand_dims(
643
+ np.array([len(beam) for beam in result.output_token_ids], np.int32), 0)
644
+ max_beam_length = max([len(beam) for beam in result.output_token_ids])
645
+ output_ids = np.full((1, len(result.output_token_ids), max_beam_length),
646
+ -1, np.int32)
647
+ for idx, beam in enumerate(result.output_token_ids):
648
+ output_ids[0, idx, :len(beam)] = beam
649
+
650
+ output_lengths = output_ids.size
651
+ output_tensors = [
652
+ pb_utils.Tensor("output_ids", output_ids),
653
+ pb_utils.Tensor("sequence_length", beam_lengths),
654
+ ]
655
+
656
+ if result.cum_log_probs is not None:
657
+ output_tensors.append(
658
+ pb_utils.Tensor(
659
+ "cum_log_probs",
660
+ np.expand_dims(np.array(result.cum_log_probs, np.float32), 0)))
661
+
662
+ if result.log_probs is not None:
663
+ output_tensors.append(
664
+ pb_utils.Tensor(
665
+ "output_log_probs",
666
+ np.expand_dims(np.array(result.log_probs, np.float32), 0)))
667
+
668
+ if result.context_logits is not None:
669
+ assert (result.context_logits.dtype is expected_logits_dtype)
670
+ output_tensors.append(
671
+ pb_utils.Tensor(
672
+ "context_logits",
673
+ np.expand_dims(
674
+ np.array(
675
+ result.context_logits, torch_to_numpy_dtype_dict[
676
+ result.context_logits.dtype]), 0)))
677
+
678
+ if result.generation_logits is not None:
679
+ assert (result.generation_logits.dtype is expected_logits_dtype)
680
+ output_tensors.append(
681
+ pb_utils.Tensor(
682
+ "generation_logits",
683
+ np.expand_dims(
684
+ np.array(
685
+ result.generation_logits, torch_to_numpy_dtype_dict[
686
+ result.generation_logits.dtype]), 0)))
687
+
688
+ if batch_size > 1:
689
+ output_tensors.append(
690
+ pb_utils.Tensor(
691
+ "batch_index",
692
+ np.expand_dims(np.array([batch_index], np.int32), 0)))
693
+
694
+ if num_return_sequences > 1:
695
+ output_tensors.append(
696
+ pb_utils.Tensor(
697
+ "sequence_index",
698
+ np.expand_dims(np.array([result.sequence_index], np.int32),
699
+ 0)))
700
+
701
+ if result.request_perf_metrics is not None:
702
+ kv_cache_metrics = result.request_perf_metrics.kv_cache_metrics
703
+ output_tensors.append(
704
+ pb_utils.Tensor(
705
+ "kv_cache_alloc_new_blocks",
706
+ np.expand_dims(
707
+ np.array([kv_cache_metrics.num_new_allocated_blocks],
708
+ np.int32), 0)))
709
+ output_tensors.append(
710
+ pb_utils.Tensor(
711
+ "kv_cache_reused_blocks",
712
+ np.expand_dims(
713
+ np.array([kv_cache_metrics.num_reused_blocks], np.int32),
714
+ 0)))
715
+ output_tensors.append(
716
+ pb_utils.Tensor(
717
+ "kv_cache_alloc_total_blocks",
718
+ np.expand_dims(
719
+ np.array([kv_cache_metrics.num_total_allocated_blocks],
720
+ np.int32), 0)))
721
+
722
+ timing_metrics = result.request_perf_metrics.timing_metrics
723
+ output_tensors.append(
724
+ pb_utils.Tensor(
725
+ "arrival_time_ns",
726
+ np.expand_dims(
727
+ np.array([pd.Timedelta(timing_metrics.arrival_time).value],
728
+ np.int64), 0)))
729
+ output_tensors.append(
730
+ pb_utils.Tensor(
731
+ "first_scheduled_time_ns",
732
+ np.expand_dims(
733
+ np.array([
734
+ pd.Timedelta(timing_metrics.first_scheduled_time).value
735
+ ], np.int64), 0)))
736
+ output_tensors.append(
737
+ pb_utils.Tensor(
738
+ "first_token_time_ns",
739
+ np.expand_dims(
740
+ np.array(
741
+ [pd.Timedelta(timing_metrics.first_token_time).value],
742
+ np.int64), 0)))
743
+ output_tensors.append(
744
+ pb_utils.Tensor(
745
+ "last_token_time_ns",
746
+ np.expand_dims(
747
+ np.array(
748
+ [pd.Timedelta(timing_metrics.last_token_time).value],
749
+ np.int64), 0)))
750
+
751
+ spec_dec_metrics = result.request_perf_metrics.speculative_decoding
752
+ output_tensors.append(
753
+ pb_utils.Tensor(
754
+ "acceptance_rate",
755
+ np.expand_dims(
756
+ np.array([spec_dec_metrics.acceptance_rate], np.float32),
757
+ 0)))
758
+ output_tensors.append(
759
+ pb_utils.Tensor(
760
+ "total_accepted_draft_tokens",
761
+ np.expand_dims(
762
+ np.array([spec_dec_metrics.total_accepted_draft_tokens],
763
+ np.int32), 0)))
764
+ output_tensors.append(
765
+ pb_utils.Tensor(
766
+ "total_draft_tokens",
767
+ np.expand_dims(
768
+ np.array([spec_dec_metrics.total_draft_tokens], np.int32),
769
+ 0)))
770
+
771
+ return pb_utils.InferenceResponse(
772
+ output_tensors), result.is_final, output_lengths
773
+
774
+
775
+ def convert_scheduler_policy(batch_scheduler_policy: str):
776
+ if batch_scheduler_policy.lower() == "max_utilization":
777
+ return trtllm.CapacitySchedulerPolicy.MAX_UTILIZATION
778
+ elif batch_scheduler_policy.lower() == "guaranteed_no_evict":
779
+ return trtllm.CapacitySchedulerPolicy.GUARANTEED_NO_EVICT
780
+ raise pb_utils.TritonModelException(
781
+ f"batch_scheduler_policy value of '{batch_scheduler_policy}' is not supported."
782
+ )
783
+
784
+
785
+ def convert_batching_type(gpt_model_type: str):
786
+ if gpt_model_type is None:
787
+ return None
788
+ if gpt_model_type.lower(
789
+ ) == "inflight_fused_batching" or gpt_model_type.lower(
790
+ ) == "inflight_batching":
791
+ return trtllm.BatchingType.INFLIGHT
792
+ elif gpt_model_type.lower() == "v1":
793
+ return trtllm.BatchingType.STATIC
794
+ raise pb_utils.TritonModelException(
795
+ f"gpt_model_type value of '{gpt_model_type}' is not supported.")
796
+
797
+
798
+ def convert_decoding_mode(decoding_mode: str):
799
+ if decoding_mode is None:
800
+ return None
801
+ elif decoding_mode == "auto":
802
+ return trtllm.DecodingMode.Auto()
803
+ elif decoding_mode == "top_k":
804
+ return trtllm.DecodingMode.TopK()
805
+ elif decoding_mode == "top_p":
806
+ return trtllm.DecodingMode.TopP()
807
+ elif decoding_mode == "top_k_top_p":
808
+ return trtllm.DecodingMode.TopKTopP()
809
+ elif decoding_mode == "beam_search":
810
+ return trtllm.DecodingMode.BeamSearch()
811
+ elif decoding_mode == "medusa":
812
+ return trtllm.DecodingMode.Medusa()
813
+ elif decoding_mode == "redrafter":
814
+ return trtllm.DecodingMode.ExplicitDraftTokens()
815
+ elif decoding_mode == "lookahead":
816
+ return trtllm.DecodingMode.Lookahead()
817
+ elif decoding_mode == "eagle":
818
+ return trtllm.DecodingMode.Eagle()
819
+ raise pb_utils.TritonModelException(
820
+ f"decoding_mode value of '{decoding_mode}' is not supported.")
821
+
822
+
823
+ def convert_timestamp_to_seconds(timestamp: str):
824
+ return int(
825
+ datetime.datetime.strptime(timestamp,
826
+ "%m-%d-%Y %H:%M:%S.%f").timestamp())
827
+
828
+
829
+ def triton_string_to_torch(dtype):
830
+ type_map = {
831
+ "TYPE_BOOL": torch.bool,
832
+ "TYPE_UINT8": torch.uint8,
833
+ "TYPE_INT8": torch.int8,
834
+ "TYPE_INT16": torch.int16,
835
+ "TYPE_INT32": torch.int32,
836
+ "TYPE_INT64": torch.int64,
837
+ "TYPE_FP16": torch.float16,
838
+ "TYPE_FP32": torch.float32,
839
+ "TYPE_FP64": torch.float64,
840
+ "TYPE_BF16": torch.bfloat16
841
+ }
842
+ return type_map[dtype]
843
+
844
+
845
+ class TritonPythonModel:
846
+ """Your Python model must use the same class name. Every Python model
847
+ that is created must have "TritonPythonModel" as the class name.
848
+ """
849
+
850
+ def get_scheduler_config(self, model_config):
851
+ batch_scheduler_policy = get_parameter(model_config,
852
+ "batch_scheduler_policy")
853
+ if batch_scheduler_policy is None:
854
+ return trtllm.SchedulerConfig()
855
+ return trtllm.SchedulerConfig(
856
+ convert_scheduler_policy(batch_scheduler_policy))
857
+
858
+ def get_kv_cache_config(self, model_config):
859
+ kwargs = {
860
+ "enable_block_reuse":
861
+ get_parameter(model_config, "enable_kv_cache_reuse", bool),
862
+ "max_tokens":
863
+ get_parameter(model_config, "max_tokens_in_paged_kv_cache", int),
864
+ "sink_token_length":
865
+ get_parameter(model_config, "sink_token_length", int),
866
+ "free_gpu_memory_fraction":
867
+ get_parameter(model_config, "kv_cache_free_gpu_mem_fraction",
868
+ float),
869
+ "cross_kv_cache_fraction":
870
+ get_parameter(model_config, "cross_kv_cache_fraction", float),
871
+ "host_cache_size":
872
+ get_parameter(model_config, "kv_cache_host_memory_bytes", int),
873
+ "onboard_blocks":
874
+ get_parameter(model_config, "kv_cache_onboard_blocks", bool),
875
+ }
876
+ max_attention_window_size = get_parameter(model_config,
877
+ "max_attention_window_size")
878
+ if max_attention_window_size:
879
+ kwargs["max_attention_window"] = [
880
+ int(x) for x in max_attention_window_size.split(",")
881
+ ]
882
+ kwargs = {k: v for k, v in kwargs.items() if v is not None}
883
+ return trtllm.KvCacheConfig(**kwargs)
884
+
885
+ def get_parallel_config(self, model_config):
886
+ kwargs = {}
887
+ gpu_device_ids = get_parameter(model_config, "gpu_device_ids")
888
+ if gpu_device_ids:
889
+ kwargs["device_ids"] = [int(x) for x in gpu_device_ids.split(",")]
890
+ self.use_orchestrator_mode = os.environ.get("TRTLLM_ORCHESTRATOR",
891
+ "0") == "1"
892
+ if self.use_orchestrator_mode:
893
+ kwargs[
894
+ "communication_mode"] = trtllm.CommunicationMode.ORCHESTRATOR
895
+ worker_path = get_parameter(model_config, "worker_path")
896
+ spawn_processes = os.environ.get(
897
+ "TRTLLM_ORCHESTRATOR_SPAWN_PROCESSES", "1") == "1"
898
+ if not spawn_processes:
899
+ raise pb_utils.TritonModelException(
900
+ "Orchestrator mode with --disable-spawn-processes is not supported in the Python backend."
901
+ )
902
+ is_orchestrator = (mpi_rank() == 0) if spawn_processes else True
903
+ if worker_path is not None:
904
+ raise pb_utils.TritonModelException(
905
+ "worker_path parameter is specified, but this is no longer supported. Please specify executor_worker_path instead to specify the location of the trtllmExecutorWorker executable."
906
+ )
907
+ executor_worker_path = get_parameter(model_config,
908
+ "executor_worker_path")
909
+ kwargs["orchestrator_config"] = trtllm.OrchestratorConfig(
910
+ is_orchestrator, executor_worker_path)
911
+ if len(kwargs) > 0:
912
+ return trtllm.ParallelConfig(**kwargs)
913
+ return None
914
+
915
+ def get_peft_cache_config(self, model_config):
916
+ kwargs = {
917
+ "optimal_adapter_size":
918
+ get_parameter(model_config, "lora_cache_optimal_adapter_size",
919
+ int),
920
+ "max_adapter_size":
921
+ get_parameter(model_config, "lora_cache_max_adapter_size", int),
922
+ "device_cache_percent":
923
+ get_parameter(model_config, "lora_cache_gpu_memory_fraction",
924
+ float),
925
+ "host_cache_size":
926
+ get_parameter(model_config, "lora_cache_host_memory_bytes", int),
927
+ "lora_prefetch_dir":
928
+ get_parameter(model_config, "lora_prefetch_dir", int),
929
+ }
930
+ kwargs = {k: v for k, v in kwargs.items() if v is not None}
931
+ return trtllm.PeftCacheConfig(**kwargs)
932
+
933
+ def get_executor_lookahead_config(self, model_config):
934
+ lookahead_window_size = get_parameter(model_config,
935
+ "lookahead_window_size", int)
936
+ lookahead_ngram_size = get_parameter(model_config,
937
+ "lookahead_ngram_size", int)
938
+ lookahead_verification_set_size = get_parameter(
939
+ model_config, "lookahead_verification_set_size", int)
940
+ # executor_lookahead_config is not set
941
+ if all(item is None for item in [
942
+ lookahead_window_size, lookahead_ngram_size,
943
+ lookahead_verification_set_size
944
+ ]):
945
+ return None
946
+
947
+ incomplete_config = None in [
948
+ lookahead_window_size, lookahead_ngram_size,
949
+ lookahead_verification_set_size
950
+ ]
951
+
952
+ assert (
953
+ not incomplete_config
954
+ ), "Please set lookahead_window_size, lookahead_ngram_size and lookahead_verification_set_size together."
955
+
956
+ return trtllm.LookaheadDecodingConfig(lookahead_window_size,
957
+ lookahead_ngram_size,
958
+ lookahead_verification_set_size)
959
+
960
+ def get_decoding_config(self, model_config):
961
+
962
+ decoding_mode = convert_decoding_mode(
963
+ get_parameter(model_config, "decoding_mode"))
964
+ self.executor_lookahead_config = None
965
+ if decoding_mode == trtllm.DecodingMode.Lookahead():
966
+ # Add LAD config
967
+ self.executor_lookahead_config = self.get_executor_lookahead_config(
968
+ model_config)
969
+ eagle_choices = parse_eagle_choices(
970
+ get_parameter(model_config, "eagle_choices"))
971
+ kwargs = {
972
+ "medusa_choices":
973
+ parse_medusa_choices(get_parameter(model_config,
974
+ "medusa_choices")),
975
+ "eagle_config":
976
+ None
977
+ if eagle_choices is None else trtllm.EagleConfig(eagle_choices),
978
+ "lookahead_decoding_config":
979
+ self.executor_lookahead_config,
980
+ "decoding_mode":
981
+ decoding_mode,
982
+ }
983
+ print(kwargs)
984
+ kwargs = {k: v for k, v in kwargs.items() if v is not None}
985
+ return trtllm.DecodingConfig(**kwargs)
986
+
987
+ def get_extended_runtime_perf_knob_config(self, model_config):
988
+ kwargs = {
989
+ "multi_block_mode":
990
+ get_parameter(model_config, "multi_block_mode", bool),
991
+ "enable_context_fmha_fp32_acc":
992
+ get_parameter(model_config, "enable_context_fmha_fp32_acc", bool),
993
+ "cuda_graph_mode":
994
+ get_parameter(model_config, "cuda_graph_mode", bool),
995
+ "cuda_graph_cache_size":
996
+ get_parameter(model_config, "cuda_graph_cache_size", int),
997
+ }
998
+ kwargs = {k: v for k, v in kwargs.items() if v is not None}
999
+ return trtllm.ExtendedRuntimePerfKnobConfig(**kwargs)
1000
+
1001
+ def get_guided_decoding_config(self, model_config):
1002
+
1003
+ guided_decoding_backend = get_parameter(model_config,
1004
+ "guided_decoding_backend", str)
1005
+
1006
+ tokenizer_dir = get_parameter(model_config, "tokenizer_dir", str)
1007
+ if guided_decoding_backend not in ['xgrammar']:
1008
+ if tokenizer_dir:
1009
+ pb_utils.Logger.log_warn(
1010
+ f"Guided decoding backend has not been set but tokenizer_dir is given. Tokenizer_dir will be ignored."
1011
+ )
1012
+ return None
1013
+
1014
+ if guided_decoding_backend == 'xgrammar':
1015
+ guided_decoding_backend = trtllm.GuidedDecodingConfig.GuidedDecodingBackend.XGRAMMAR
1016
+
1017
+ if not tokenizer_dir:
1018
+ raise ValueError(
1019
+ "Guided decoding requires tokenizer's information. Please provide 'tokenizer_dir'."
1020
+ )
1021
+ from transformers import AutoTokenizer
1022
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
1023
+ pb_utils.Logger.log_info(
1024
+ f"Guided decoding has been set with {guided_decoding_backend} backend"
1025
+ )
1026
+ return trtllm.GuidedDecodingConfig(
1027
+ backend=guided_decoding_backend,
1028
+ **_xgrammar_tokenizer_info(tokenizer))
1029
+
1030
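+ # Assemble the trtllm.ExecutorConfig from model-config parameters, combining the
+ # scheduler, KV cache, parallel, PEFT cache, decoding, perf-knob and guided
+ # decoding sub-configs built by the helpers above.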
+ def get_executor_config(self, model_config):
1031
+ kwargs = {
1032
+ "max_beam_width":
1033
+ get_parameter(model_config, "max_beam_width", int),
1034
+ "scheduler_config":
1035
+ self.get_scheduler_config(model_config),
1036
+ "kv_cache_config":
1037
+ self.get_kv_cache_config(model_config),
1038
+ "enable_chunked_context":
1039
+ get_parameter(model_config, "enable_chunked_context", bool),
1040
+ "normalize_log_probs":
1041
+ get_parameter(model_config, "normalize_log_probs", bool),
1042
+ "batching_type":
1043
+ convert_batching_type(get_parameter(model_config,
1044
+ "gpt_model_type")),
1045
+ "parallel_config":
1046
+ self.get_parallel_config(model_config),
1047
+ "peft_cache_config":
1048
+ self.get_peft_cache_config(model_config),
1049
+ "decoding_config":
1050
+ self.get_decoding_config(model_config),
1051
+ "max_queue_size":
1052
+ model_config.get(
1053
+ "dynamic_batching",
1054
+ {},
1055
+ ).get(
1056
+ "default_queue_policy",
1057
+ {},
1058
+ ).get("max_queue_size"),
1059
+ "extended_runtime_perf_knob_config":
1060
+ self.get_extended_runtime_perf_knob_config(model_config),
1061
+ "guided_decoding_config":
1062
+ self.get_guided_decoding_config(model_config)
1063
+ }
1064
+ kwargs = {k: v for k, v in kwargs.items() if v is not None}
1065
+ return trtllm.ExecutorConfig(**kwargs)
1066
+
1067
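+ # Register the Triton metric families (gauges for request, memory, KV cache and
+ # batcher statistics; histograms for input/output token counts) and the labeled
+ # Metric objects that metrics_loop() and update_metrics_per_request() update.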
+ def create_metrics(self, model: str, version: str, is_v1_model: bool):
1068
+ self.request_metric_family = pb_utils.MetricFamily(
1069
+ name="nv_trt_llm_request_metrics",
1070
+ description="TRT LLM request metrics",
1071
+ kind=pb_utils.MetricFamily.GAUGE,
1072
+ )
1073
+ self.runtime_memory_metric_family = pb_utils.MetricFamily(
1074
+ name="nv_trt_llm_runtime_memory_metrics",
1075
+ description="TRT LLM runtime memory metrics",
1076
+ kind=pb_utils.MetricFamily.GAUGE,
1077
+ )
1078
+ self.kv_cache_metric_family = pb_utils.MetricFamily(
1079
+ name="nv_trt_llm_kv_cache_block_metrics",
1080
+ description="TRT LLM KV cache block metrics",
1081
+ kind=pb_utils.MetricFamily.GAUGE,
1082
+ )
1083
+ model_type = "v1" if is_v1_model else "inflight_batcher"
1084
+ self.model_type_metric_family = pb_utils.MetricFamily(
1085
+ name=f"nv_trt_llm_{model_type}_metrics",
1086
+ description=f"TRT LLM {model_type}-specific metrics",
1087
+ kind=pb_utils.MetricFamily.GAUGE,
1088
+ )
1089
+ self.general_metric_family = pb_utils.MetricFamily(
1090
+ name="nv_trt_llm_general_metrics",
1091
+ description="General TRT LLM metrics",
1092
+ kind=pb_utils.MetricFamily.GAUGE,
1093
+ )
1094
+ # Histogram families for per-request input/output token counts, observed in
+ # update_metrics_per_request().
1095
+ self.request_tokens_metric_family = pb_utils.MetricFamily(
1096
+ name="nv_llm_input_token_len",
1097
+ description="TRT LLM request metrics",
1098
+ kind=pb_utils.MetricFamily.HISTOGRAM,
1099
+ )
1100
+ self.response_tokens_metric_family = pb_utils.MetricFamily(
1101
+ name="nv_llm_output_token_len",
1102
+ description="TRT LLM response metrics",
1103
+ kind=pb_utils.MetricFamily.HISTOGRAM,
1104
+ )
1105
+ common_labels = {"model": model, "version": version}
1106
+ self.all_metrics = {
1107
+ # Request metrics
1108
+ "num_active_requests":
1109
+ self.request_metric_family.Metric(labels={
1110
+ "request_type": "active",
1111
+ **common_labels
1112
+ }),
1113
+ "max_num_active_requests":
1114
+ self.request_metric_family.Metric(labels={
1115
+ "request_type": "max",
1116
+ **common_labels
1117
+ }),
1118
+ "num_scheduled_requests":
1119
+ self.request_metric_family.Metric(labels={
1120
+ "request_type": "scheduled",
1121
+ **common_labels
1122
+ }),
1123
+ "num_context_requests":
1124
+ self.request_metric_family.Metric(labels={
1125
+ "request_type": "context",
1126
+ **common_labels
1127
+ }),
1128
+ # Runtime metrics
1129
+ "cpu_mem_usage":
1130
+ self.runtime_memory_metric_family.Metric(labels={
1131
+ "memory_type": "cpu",
1132
+ **common_labels
1133
+ }),
1134
+ "gpu_mem_usage":
1135
+ self.runtime_memory_metric_family.Metric(labels={
1136
+ "memory_type": "gpu",
1137
+ **common_labels
1138
+ }),
1139
+ "pinned_mem_usage":
1140
+ self.runtime_memory_metric_family.Metric(labels={
1141
+ "memory_type": "pinned",
1142
+ **common_labels
1143
+ }),
1144
+ # KV cache metrics
1145
+ "max_num_blocks":
1146
+ self.kv_cache_metric_family.Metric(labels={
1147
+ "kv_cache_block_type": "max",
1148
+ **common_labels
1149
+ }),
1150
+ "free_num_blocks":
1151
+ self.kv_cache_metric_family.Metric(labels={
1152
+ "kv_cache_block_type": "free",
1153
+ **common_labels
1154
+ }),
1155
+ "used_num_blocks":
1156
+ self.kv_cache_metric_family.Metric(labels={
1157
+ "kv_cache_block_type": "used",
1158
+ **common_labels
1159
+ }),
1160
+ "tokens_per_block":
1161
+ self.kv_cache_metric_family.Metric(labels={
1162
+ "kv_cache_block_type": "tokens_per",
1163
+ **common_labels
1164
+ }),
1165
+ # General metrics
1166
+ "timestamp":
1167
+ self.general_metric_family.Metric(labels={
1168
+ "general_type": "timestamp",
1169
+ **common_labels
1170
+ }),
1171
+ "iter":
1172
+ self.general_metric_family.Metric(labels={
1173
+ "general_type": "iteration_counter",
1174
+ **common_labels
1175
+ }),
1176
+ METRIC_TOTAL_OUTPUT_TOKENS:
1177
+ self.response_tokens_metric_family.Metric(
1178
+ labels={
1179
+ "response_metric_type": METRIC_TOTAL_OUTPUT_TOKENS,
1180
+ **common_labels
1181
+ },
1182
+ buckets=build_1_2_5_buckets(1000)),
1183
+ METRIC_TOTAL_INPUT_TOKENS:
1184
+ self.request_tokens_metric_family.Metric(
1185
+ labels={
1186
+ "response_metric_type": METRIC_TOTAL_INPUT_TOKENS,
1187
+ **common_labels
1188
+ },
1189
+ buckets=build_1_2_5_buckets(1000)),
1190
+ }
1191
+ if is_v1_model:
1192
+ self.all_metrics.update({
1193
+ "num_ctx_tokens":
1194
+ self.model_type_metric_family.Metric(labels={
1195
+ "v1_specific_metric": "total_context_tokens",
1196
+ **common_labels
1197
+ }),
1198
+ "num_gen_tokens":
1199
+ self.model_type_metric_family.Metric(
1200
+ labels={
1201
+ "v1_specific_metric": "total_generation_tokens",
1202
+ **common_labels
1203
+ }),
1204
+ "empty_gen_slots":
1205
+ self.model_type_metric_family.Metric(
1206
+ labels={
1207
+ "v1_specific_metric": "empty_generation_slots",
1208
+ **common_labels
1209
+ }),
1210
+ })
1211
+ else:
1212
+ self.all_metrics.update({
1213
+ "num_ctx_tokens":
1214
+ self.model_type_metric_family.Metric(
1215
+ labels={
1216
+ "inflight_batcher_specific_metric":
1217
+ "total_context_tokens",
1218
+ **common_labels
1219
+ }),
1220
+ "num_gen_requests":
1221
+ self.model_type_metric_family.Metric(
1222
+ labels={
1223
+ "inflight_batcher_specific_metric":
1224
+ "generation_requests",
1225
+ **common_labels
1226
+ }),
1227
+ "micro_batch_id":
1228
+ self.model_type_metric_family.Metric(
1229
+ labels={
1230
+ "inflight_batcher_specific_metric": "micro_batch_id",
1231
+ **common_labels
1232
+ }),
1233
+ "num_paused_requests":
1234
+ self.model_type_metric_family.Metric(
1235
+ labels={
1236
+ "inflight_batcher_specific_metric": "paused_requests",
1237
+ **common_labels
1238
+ }),
1239
+ })
1240
+
1241
+ def initialize(self, args):
1242
+ """`initialize` is called only once when the model is being loaded.
1243
+ Implementing `initialize` function is optional. This function allows
1244
+ the model to initialize any state associated with this model.
1245
+
1246
+ Parameters
1247
+ ----------
1248
+ args : dict
1249
+ Both keys and values are strings. The dictionary keys and values are:
1250
+ * model_config: A JSON string containing the model configuration
1251
+ * model_instance_kind: A string containing model instance kind
1252
+ * model_instance_device_id: A string containing model instance device ID
1253
+ * model_repository: Model repository path
1254
+ * model_version: Model version
1255
+ * model_name: Model name
1256
+ """
1257
+ model_config = json.loads(args['model_config'])
1258
+ gpt_model_path = get_parameter(model_config, "gpt_model_path")
1259
+ if get_parameter(model_config, "enable_trt_overlap", bool):
1260
+ raise pb_utils.TritonModelException(
1261
+ f"enable_trt_overlap=true is not supported.")
1262
+ self.exclude_input_from_output = get_parameter(
1263
+ model_config, "exclude_input_in_output", bool)
1264
+ executor_config = self.get_executor_config(model_config)
1265
+ self.executor = trtllm.Executor(gpt_model_path,
1266
+ trtllm.ModelType.DECODER_ONLY,
1267
+ executor_config)
1268
+ self.decoupled = pb_utils.using_decoupled_model_transaction_policy(
1269
+ model_config)
1270
+ self.cancellation_check_period_ms = get_parameter(
1271
+ model_config, "cancellation_check_period_ms", int) or 100
1272
+ self.stats_check_period_ms = get_parameter(
1273
+ model_config, "stats_check_period_ms", int) or 100
1274
+
1275
+ self.logits_dtype = None
1276
+ for output in model_config['output']:
1277
+ if output['name'] == 'context_logits' or output[
1278
+ 'name'] == 'generation_logits':
1279
+ self.logits_dtype = triton_string_to_torch(output['data_type'])
1280
+
1281
+ self.create_metrics(args["model_name"],
1282
+ args["model_version"],
1283
+ is_v1_model=executor_config.batching_type ==
1284
+ trtllm.BatchingType.STATIC)
1285
+ self.triton_user_id_to_req_ids = {}
1286
+ self.triton_req_id_to_req_ids = {}
1287
+ self.req_id_to_request_data = {}
1288
+ self.lock = Lock()
1289
+ self.running = False
1290
+ self.awaiter_thread = Thread(target=self.awaiter_loop)
1291
+ self.cancellation_thread = Thread(target=self.cancellation_loop)
1292
+ self.metrics_thread = Thread(target=self.metrics_loop)
1293
+ if self.executor.can_enqueue_requests():
1294
+ self.running = True
1295
+ self.awaiter_thread.start()
1296
+ self.cancellation_thread.start()
1297
+ self.metrics_thread.start()
1298
+ else:
1299
+ # In leader mode, worker ranks will wait here until leader is done.
1300
+ self.executor.shutdown()
1301
+
1302
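+ # Cancel every executor request associated with the given Triton request id and
+ # send an empty final response; a missing id is answered with an error.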
+ def handle_stop_request(self, triton_user_id, response_sender):
1303
+ if triton_user_id is None or triton_user_id == "":
1304
+ response_sender.send(
1305
+ pb_utils.InferenceResponse(error=pb_utils.TritonError(
1306
+ "A request id must be provided for request cancellation")),
1307
+ flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
1308
+ return
1309
+
1310
+ with self.lock:
1311
+ if triton_user_id in self.triton_user_id_to_req_ids:
1312
+ req_ids = self.triton_user_id_to_req_ids[triton_user_id]
1313
+ for req_id in req_ids:
1314
+ self.executor.cancel_request(req_id)
1315
+
1316
+ response_sender.send(
1317
+ pb_utils.InferenceResponse(),
1318
+ flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
1319
+
1320
+ def execute(self, requests):
1321
+ """`execute` must be implemented in every Python model. `execute`
1322
+ function receives a list of pb_utils.InferenceRequest as the only
1323
+ argument. This function is called when an inference is requested
1324
+ for this model.
1325
+
1326
+ Parameters
1327
+ ----------
1328
+ requests : list
1329
+ A list of pb_utils.InferenceRequest
1330
+
1331
+ Returns
1332
+ -------
1333
+ list
1334
+ A list of pb_utils.InferenceResponse. The length of this list must
1335
+ be the same as `requests`
1336
+ """
1337
+ if not self.executor.can_enqueue_requests():
1338
+ return
1339
+
1340
+ # Convert to executor requests.
1341
+
1342
+ triton_requests = []
1343
+ executor_requests = []
1344
+ batch_indices = []
1345
+ triton_user_ids = []
1346
+ triton_req_ids = []
1347
+
1348
+ for request in requests:
1349
+
1350
+ triton_user_id = request.request_id()
1351
+
1352
+ response_sender = request.get_response_sender()
1353
+ stop = get_input_scalar_by_name(request, 'stop')
1354
+
1355
+ if stop:
1356
+ self.handle_stop_request(triton_user_id, response_sender)
1357
+ else:
1358
+ #Unique request id used to identify each triton request
1359
+ triton_req_id = str(randint(0, sys.maxsize))
1360
+ self.triton_req_id_to_req_ids[triton_req_id] = set()
1361
+ if triton_user_id is not None and triton_user_id != "":
1362
+ self.triton_user_id_to_req_ids[triton_user_id] = set()
1363
+
1364
+ try:
1365
+ converted_reqs = convert_request(
1366
+ request, self.exclude_input_from_output,
1367
+ self.decoupled, self.executor_lookahead_config)
1368
+ except Exception as e:
1369
+ response_sender.send(
1370
+ pb_utils.InferenceResponse(error=pb_utils.TritonError(
1371
+ f"An error occurred when processing the input values for request id {request.request_id()}, the error was '{e}'"
1372
+ )),
1373
+ flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
1374
+ else:
1375
+ for batch_index, converted_req in enumerate(
1376
+ converted_reqs):
1377
+ triton_requests.append(request)
1378
+ executor_requests.append(converted_req)
1379
+ triton_user_ids.append(triton_user_id)
1380
+ triton_req_ids.append(triton_req_id)
1381
+ batch_indices.append(batch_index)
1382
+
1383
+ with self.lock:
1384
+ request_ids = self.executor.enqueue_requests(executor_requests)
1385
+ for req_id, triton_req_id, triton_user_id, executor_request, triton_request, batch_index in zip(
1386
+ request_ids, triton_req_ids, triton_user_ids,
1387
+ executor_requests, triton_requests, batch_indices):
1388
+
1389
+ self.req_id_to_request_data[req_id] = RequestData(
1390
+ triton_req_id, triton_user_id, batch_index,
1391
+ len(batch_indices),
1392
+ executor_request.sampling_config.num_return_sequences, 0,
1393
+ 0, triton_request.get_response_sender())
1394
+ self.triton_req_id_to_req_ids[triton_req_id].add(req_id)
1395
+ input_len = len(
1396
+ executor_request.input_token_ids
1397
+ ) if executor_request.input_token_ids is not None else 0
1398
+ self.req_id_to_request_data[
1399
+ req_id].num_input_tokens += input_len
1400
+ # This checks both request level and instance config level
1401
+ if executor_request.output_config.exclude_input_from_output == False and executor_request.streaming == False:
1402
+ self.req_id_to_request_data[
1403
+ req_id].num_output_tokens -= self.req_id_to_request_data[
1404
+ req_id].num_input_tokens * executor_request.sampling_config.beam_width
1405
+ if triton_user_id is not None and triton_user_id != "":
1406
+ self.triton_user_id_to_req_ids[triton_user_id].add(req_id)
1407
+
1408
+ return None
1409
+
1410
+ def awaiter_loop(self):
1411
+ """Gets responses from executor and returns the results."""
1412
+ while self.running:
1413
+ for response in self.executor.await_responses(
1414
+ timeout=datetime.timedelta(milliseconds=1)):
1415
+ req_id = response.request_id
1416
+ request_data = None
1417
+ with self.lock:
1418
+ if req_id not in self.req_id_to_request_data:
1419
+ continue
1420
+ request_data = self.req_id_to_request_data[req_id]
1421
+
1422
+ triton_response, is_final, output_length = convert_response(
1423
+ response, request_data.batch_index,
1424
+ request_data.batch_size, request_data.num_return_sequences,
1425
+ self.logits_dtype)
1426
+ with self.lock:
1427
+ self.req_id_to_request_data[
1428
+ req_id].num_output_tokens += output_length
1429
+ triton_request_final = False
1430
+ if is_final:
1431
+ with self.lock:
1432
+ # Check if all executor requests that are part of this triton request are finished
1433
+ self.triton_req_id_to_req_ids[
1434
+ request_data.triton_req_id].remove(req_id)
1435
+ if len(self.triton_req_id_to_req_ids[
1436
+ request_data.triton_req_id]) == 0:
1437
+ pb_utils.Logger.log_info(
1438
+ f"DELETING Req id {req_id}, triton_req_id {request_data.triton_req_id} "
1439
+ )
1440
+ triton_request_final = True
1441
+ del self.triton_req_id_to_req_ids[
1442
+ request_data.triton_req_id]
1443
+ if request_data.triton_user_id is not None and request_data.triton_user_id != "":
1444
+ del self.triton_user_id_to_req_ids[
1445
+ request_data.triton_user_id]
1446
+ self.update_metrics_per_request(req_id)
1447
+ del self.req_id_to_request_data[req_id]
1448
+
1449
+ request_data.response_sender.send(
1450
+ triton_response,
1451
+ flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
1452
+ if triton_request_final else 0)
1453
+
1454
+ def cancellation_loop(self):
1455
+ """Checks if any pending requests have been cancelled."""
1456
+ while self.running:
1457
+ time.sleep(self.cancellation_check_period_ms / 1000.0)
1458
+ with self.lock:
1459
+ for req_id, request_data in self.req_id_to_request_data.items(
1460
+ ):
1461
+ if request_data.response_sender.is_cancelled():
1462
+ self.executor.cancel_request(req_id)
1463
+
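+ # Note: cancellation is cooperative. The loop above wakes every
+ # `cancellation_check_period_ms`, polls each pending request's response
+ # sender, and asks the executor to cancel the matching request when the
+ # client has gone away.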
1464
+ def update_metrics_per_request(self, req_id):
1465
+ """Updates triton metrics after completing one request"""
1466
+ output_tokens = self.req_id_to_request_data[req_id].num_output_tokens
1467
+ input_tokens = self.req_id_to_request_data[req_id].num_input_tokens
1468
+
1469
+ self.all_metrics[METRIC_TOTAL_OUTPUT_TOKENS].observe(output_tokens)
1470
+ self.all_metrics[METRIC_TOTAL_INPUT_TOKENS].observe(input_tokens)
1471
+
1472
+ def metrics_loop(self):
1473
+ """Updates triton metrics using stats from the executor."""
1474
+ while self.running:
1475
+ time.sleep(self.stats_check_period_ms / 1000.0)
1476
+ for stat in self.executor.get_latest_iteration_stats():
1477
+ try:
1478
+ for key, metric in self.all_metrics.items():
1479
+ # Skip processing for both histogram metrics
1480
+ if isinstance(key, str) and key in [
1481
+ METRIC_TOTAL_OUTPUT_TOKENS,
1482
+ METRIC_TOTAL_INPUT_TOKENS
1483
+ ]:
1484
+ continue
1485
+ value = None
1486
+ if hasattr(stat, key):
1487
+ value = getattr(stat, key)
1488
+ elif stat.kv_cache_stats is not None and hasattr(
1489
+ stat.kv_cache_stats, key):
1490
+ value = getattr(stat.kv_cache_stats, key)
1491
+ elif stat.static_batching_stats is not None and hasattr(
1492
+ stat.static_batching_stats, key):
1493
+ value = getattr(stat.static_batching_stats, key)
1494
+ elif stat.inflight_batching_stats is not None and hasattr(
1495
+ stat.inflight_batching_stats, key):
1496
+ value = getattr(stat.inflight_batching_stats, key)
1497
+ if value is not None:
1498
+ if key == "timestamp":
1499
+ value = convert_timestamp_to_seconds(value)
1500
+ metric.set(value)
1501
+ else:
1502
+ pb_utils.Logger.log_warn(
1503
+ f"Metric \"{key}\" not found.")
1504
+ except Exception as e:
1505
+ pb_utils.Logger.log_warn(
1506
+ f"Error while processing metrics: {e}")
1507
+
1508
+ def finalize(self):
1509
+ """`finalize` is called only once when the model is being unloaded.
1510
+ Implementing `finalize` function is optional. This function allows
1511
+ the model to perform any necessary clean ups before exit.
1512
+ """
1513
+ if self.executor.can_enqueue_requests():
1514
+ self.running = False
1515
+ self.awaiter_thread.join()
1516
+ self.cancellation_thread.join()
1517
+ self.metrics_thread.join()
1518
+ self.executor.shutdown()
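For reference, a minimal client sketch against the `tensorrt_llm` model defined by the config.pbtxt below. The server address, Whisper prompt token ids, the mel-feature layout, and the semantics of `input_lengths` for this encoder-decoder setup are assumptions, not something fixed by this repo; only `input_lengths` and `request_output_len` are non-optional in config.pbtxt, everything else shown is optional. Treat this as an illustration of the request shape, not a drop-in script.

import numpy as np
import tritonclient.grpc as grpcclient

# Placeholders -- adjust for your deployment and tokenizer.
SERVER_URL = "localhost:8001"
N_MELS, N_FRAMES = 128, 3000               # from encoder/config.json (n_mels, max_input_len)
PROMPT_IDS = [50258, 50259, 50360, 50364]  # hypothetical Whisper decoder prompt token ids
EOT_ID = 50257                             # hypothetical <|endoftext|> id

def _tensor(name, array, dtype):
    t = grpcclient.InferInput(name, list(array.shape), dtype)
    t.set_data_from_numpy(array)
    return t

def transcribe(mel: np.ndarray) -> np.ndarray:
    """mel: float16 log-mel features, assumed layout (N_FRAMES, N_MELS)."""
    client = grpcclient.InferenceServerClient(SERVER_URL)
    inputs = [
        _tensor("encoder_input_features", mel[np.newaxis].astype(np.float16), "FP16"),
        # Required by config.pbtxt; exact meaning for enc-dec models depends on the
        # backend version (assumed here: encoder feature length).
        _tensor("input_lengths", np.array([[N_FRAMES]], dtype=np.int32), "INT32"),
        _tensor("request_output_len", np.array([[96]], dtype=np.int32), "INT32"),
        _tensor("decoder_input_ids", np.array([PROMPT_IDS], dtype=np.int32), "INT32"),
        _tensor("decoder_input_lengths", np.array([[len(PROMPT_IDS)]], dtype=np.int32), "INT32"),
        _tensor("end_id", np.array([[EOT_ID]], dtype=np.int32), "INT32"),
        _tensor("pad_id", np.array([[EOT_ID]], dtype=np.int32), "INT32"),
    ]
    result = client.infer("tensorrt_llm", inputs)
    return result.as_numpy("output_ids")  # token ids, layout [batch, returned sequences, length]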
model_repo_whisper_512/tensorrt_llm/config.pbtxt ADDED
@@ -0,0 +1,844 @@
1
+ # Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Redistribution and use in source and binary forms, with or without
4
+ # modification, are permitted provided that the following conditions
5
+ # are met:
6
+ # * Redistributions of source code must retain the above copyright
7
+ # notice, this list of conditions and the following disclaimer.
8
+ # * Redistributions in binary form must reproduce the above copyright
9
+ # notice, this list of conditions and the following disclaimer in the
10
+ # documentation and/or other materials provided with the distribution.
11
+ # * Neither the name of NVIDIA CORPORATION nor the names of its
12
+ # contributors may be used to endorse or promote products derived
13
+ # from this software without specific prior written permission.
14
+ #
15
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
16
+ # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18
+ # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
19
+ # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20
+ # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21
+ # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22
+ # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
23
+ # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24
+ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26
+
27
+ name: "tensorrt_llm"
28
+ backend: "tensorrtllm"
29
+ max_batch_size: 512
30
+
31
+ model_transaction_policy {
32
+ decoupled: false
33
+ }
34
+
35
+ dynamic_batching {
36
+ preferred_batch_size: [ 512 ]
37
+ max_queue_delay_microseconds: 5000
38
+ default_queue_policy: { max_queue_size: 0 }
39
+ }
40
+
41
+ input [
42
+ {
43
+ name: "input_ids"
44
+ data_type: TYPE_INT32
45
+ dims: [ -1 ]
46
+ allow_ragged_batch: true
47
+ optional: true
48
+ },
49
+ {
50
+ name: "encoder_input_features"
51
+ data_type: TYPE_FP16
52
+ dims: [ -1, -1 ]
53
+ allow_ragged_batch: true
54
+ optional: true
55
+ },
56
+ {
57
+ name: "encoder_output_lengths"
58
+ data_type: TYPE_INT32
59
+ dims: [ 1 ]
60
+ reshape: { shape: [ ] }
61
+ optional: true
62
+ },
63
+ {
64
+ name: "input_lengths"
65
+ data_type: TYPE_INT32
66
+ dims: [ 1 ]
67
+ reshape: { shape: [ ] }
68
+ },
69
+ {
70
+ name: "request_output_len"
71
+ data_type: TYPE_INT32
72
+ dims: [ 1 ]
73
+ reshape: { shape: [ ] }
74
+ },
75
+ {
76
+ name: "num_return_sequences"
77
+ data_type: TYPE_INT32
78
+ dims: [ 1 ]
79
+ reshape: { shape: [ ] }
80
+ optional: true
81
+ },
82
+ {
83
+ name: "draft_input_ids"
84
+ data_type: TYPE_INT32
85
+ dims: [ -1 ]
86
+ optional: true
87
+ allow_ragged_batch: true
88
+ },
89
+ {
90
+ name: "decoder_input_ids"
91
+ data_type: TYPE_INT32
92
+ dims: [ -1 ]
93
+ optional: true
94
+ allow_ragged_batch: true
95
+ },
96
+ {
97
+ name: "decoder_input_lengths"
98
+ data_type: TYPE_INT32
99
+ dims: [ 1 ]
100
+ optional: true
101
+ reshape: { shape: [ ] }
102
+ },
103
+ {
104
+ name: "draft_logits"
105
+ data_type: TYPE_FP32
106
+ dims: [ -1, -1 ]
107
+ optional: true
108
+ allow_ragged_batch: true
109
+ },
110
+ {
111
+ name: "draft_acceptance_threshold"
112
+ data_type: TYPE_FP32
113
+ dims: [ 1 ]
114
+ reshape: { shape: [ ] }
115
+ optional: true
116
+ },
117
+ {
118
+ name: "end_id"
119
+ data_type: TYPE_INT32
120
+ dims: [ 1 ]
121
+ reshape: { shape: [ ] }
122
+ optional: true
123
+ },
124
+ {
125
+ name: "pad_id"
126
+ data_type: TYPE_INT32
127
+ dims: [ 1 ]
128
+ reshape: { shape: [ ] }
129
+ optional: true
130
+ },
131
+ {
132
+ name: "stop_words_list"
133
+ data_type: TYPE_INT32
134
+ dims: [ 2, -1 ]
135
+ optional: true
136
+ allow_ragged_batch: true
137
+ },
138
+ {
139
+ name: "bad_words_list"
140
+ data_type: TYPE_INT32
141
+ dims: [ 2, -1 ]
142
+ optional: true
143
+ allow_ragged_batch: true
144
+ },
145
+ {
146
+ name: "embedding_bias"
147
+ data_type: TYPE_FP32
148
+ dims: [ -1 ]
149
+ optional: true
150
+ allow_ragged_batch: true
151
+ },
152
+ {
153
+ name: "beam_width"
154
+ data_type: TYPE_INT32
155
+ dims: [ 1 ]
156
+ reshape: { shape: [ ] }
157
+ optional: true
158
+ },
159
+ {
160
+ name: "temperature"
161
+ data_type: TYPE_FP32
162
+ dims: [ 1 ]
163
+ reshape: { shape: [ ] }
164
+ optional: true
165
+ },
166
+ {
167
+ name: "runtime_top_k"
168
+ data_type: TYPE_INT32
169
+ dims: [ 1 ]
170
+ reshape: { shape: [ ] }
171
+ optional: true
172
+ },
173
+ {
174
+ name: "runtime_top_p"
175
+ data_type: TYPE_FP32
176
+ dims: [ 1 ]
177
+ reshape: { shape: [ ] }
178
+ optional: true
179
+ },
180
+ {
181
+ name: "runtime_top_p_min"
182
+ data_type: TYPE_FP32
183
+ dims: [ 1 ]
184
+ reshape: { shape: [ ] }
185
+ optional: true
186
+ },
187
+ {
188
+ name: "runtime_top_p_decay"
189
+ data_type: TYPE_FP32
190
+ dims: [ 1 ]
191
+ reshape: { shape: [ ] }
192
+ optional: true
193
+ },
194
+ {
195
+ name: "runtime_top_p_reset_ids"
196
+ data_type: TYPE_INT32
197
+ dims: [ 1 ]
198
+ reshape: { shape: [ ] }
199
+ optional: true
200
+ },
201
+ {
202
+ name: "len_penalty"
203
+ data_type: TYPE_FP32
204
+ dims: [ 1 ]
205
+ reshape: { shape: [ ] }
206
+ optional: true
207
+ },
208
+ {
209
+ name: "early_stopping"
210
+ data_type: TYPE_BOOL
211
+ dims: [ 1 ]
212
+ reshape: { shape: [ ] }
213
+ optional: true
214
+ },
215
+ {
216
+ name: "repetition_penalty"
217
+ data_type: TYPE_FP32
218
+ dims: [ 1 ]
219
+ reshape: { shape: [ ] }
220
+ optional: true
221
+ },
222
+ {
223
+ name: "min_length"
224
+ data_type: TYPE_INT32
225
+ dims: [ 1 ]
226
+ reshape: { shape: [ ] }
227
+ optional: true
228
+ },
229
+ {
230
+ name: "beam_search_diversity_rate"
231
+ data_type: TYPE_FP32
232
+ dims: [ 1 ]
233
+ reshape: { shape: [ ] }
234
+ optional: true
235
+ },
236
+ {
237
+ name: "presence_penalty"
238
+ data_type: TYPE_FP32
239
+ dims: [ 1 ]
240
+ reshape: { shape: [ ] }
241
+ optional: true
242
+ },
243
+ {
244
+ name: "frequency_penalty"
245
+ data_type: TYPE_FP32
246
+ dims: [ 1 ]
247
+ reshape: { shape: [ ] }
248
+ optional: true
249
+ },
250
+ {
251
+ name: "random_seed"
252
+ data_type: TYPE_UINT64
253
+ dims: [ 1 ]
254
+ reshape: { shape: [ ] }
255
+ optional: true
256
+ },
257
+ {
258
+ name: "return_log_probs"
259
+ data_type: TYPE_BOOL
260
+ dims: [ 1 ]
261
+ reshape: { shape: [ ] }
262
+ optional: true
263
+ },
264
+ {
265
+ name: "return_context_logits"
266
+ data_type: TYPE_BOOL
267
+ dims: [ 1 ]
268
+ reshape: { shape: [ ] }
269
+ optional: true
270
+ },
271
+ {
272
+ name: "return_generation_logits"
273
+ data_type: TYPE_BOOL
274
+ dims: [ 1 ]
275
+ reshape: { shape: [ ] }
276
+ optional: true
277
+ },
278
+ {
279
+ name: "return_perf_metrics"
280
+ data_type: TYPE_BOOL
281
+ dims: [ 1 ]
282
+ reshape: { shape: [ ] }
283
+ optional: true
284
+ },
285
+ {
286
+ name: "exclude_input_in_output"
287
+ data_type: TYPE_BOOL
288
+ dims: [ 1 ]
289
+ reshape: { shape: [ ] }
290
+ optional: true
291
+ },
292
+ {
293
+ name: "stop"
294
+ data_type: TYPE_BOOL
295
+ dims: [ 1 ]
296
+ reshape: { shape: [ ] }
297
+ optional: true
298
+ },
299
+ {
300
+ name: "streaming"
301
+ data_type: TYPE_BOOL
302
+ dims: [ 1 ]
303
+ reshape: { shape: [ ] }
304
+ optional: true
305
+ },
306
+ {
307
+ name: "prompt_embedding_table"
308
+ data_type: TYPE_FP16
309
+ dims: [ -1, -1 ]
310
+ optional: true
311
+ allow_ragged_batch: true
312
+ },
313
+ {
314
+ name: "prompt_table_extra_ids"
315
+ data_type: TYPE_UINT64
316
+ dims: [ -1 ]
317
+ optional: true
318
+ allow_ragged_batch: true
319
+ },
320
+ {
321
+ name: "prompt_vocab_size"
322
+ data_type: TYPE_INT32
323
+ dims: [ 1 ]
324
+ reshape: { shape: [ ] }
325
+ optional: true
326
+ },
327
+ # cross_attention_mask shape `[bs, seq_len, num_images*num_tiles]`
328
+ {
329
+ name: "cross_attention_mask"
330
+ data_type: TYPE_BOOL
331
+ dims: [ -1, -1 ]
332
+ optional: true
333
+ allow_ragged_batch: true
334
+ },
335
+ # the unique task ID for the given LoRA.
336
+ # To perform inference with a specific LoRA for the first time, `lora_task_id`, `lora_weights` and `lora_config` must all be given.
337
+ # The LoRA will be cached, so that subsequent requests for the same task only require `lora_task_id`.
338
+ # If the cache is full, the oldest LoRA will be evicted to make space for new ones. An error is returned if `lora_task_id` is not cached.
339
+ {
340
+ name: "lora_task_id"
341
+ data_type: TYPE_UINT64
342
+ dims: [ 1 ]
343
+ reshape: { shape: [ ] }
344
+ optional: true
345
+ },
346
+ # weights for a lora adapter shape [ num_lora_modules_layers, D x Hi + Ho x D ]
347
+ # where the last dimension holds the in / out adapter weights for the associated module (e.g. attn_qkv) and model layer
348
+ # each of the in / out tensors are first flattened and then concatenated together in the format above.
349
+ # D=adapter_size (R value), Hi=hidden_size_in, Ho=hidden_size_out.
350
+ {
351
+ name: "lora_weights"
352
+ data_type: TYPE_FP16
353
+ dims: [ -1, -1 ]
354
+ optional: true
355
+ allow_ragged_batch: true
356
+ },
357
+ # module identifier (same size a first dimension of lora_weights)
358
+ # See LoraModule::ModuleType for model id mapping
359
+ #
360
+ # "attn_qkv": 0 # compbined qkv adapter
361
+ # "attn_q": 1 # q adapter
362
+ # "attn_k": 2 # k adapter
363
+ # "attn_v": 3 # v adapter
364
+ # "attn_dense": 4 # adapter for the dense layer in attention
365
+ # "mlp_h_to_4h": 5 # for llama2 adapter for gated mlp layer after attention / RMSNorm: up projection
366
+ # "mlp_4h_to_h": 6 # for llama2 adapter for gated mlp layer after attention / RMSNorm: down projection
367
+ # "mlp_gate": 7 # for llama2 adapter for gated mlp later after attention / RMSNorm: gate
368
+ #
369
+ # last dim holds [ module_id, layer_idx, adapter_size (D aka R value) ]
370
+ {
371
+ name: "lora_config"
372
+ data_type: TYPE_INT32
373
+ dims: [ -1, 3 ]
374
+ optional: true
375
+ allow_ragged_batch: true
376
+ },
377
+ {
378
+ name: "context_phase_params"
379
+ data_type: TYPE_UINT8
380
+ dims: [ -1 ]
381
+ optional: true
382
+ allow_ragged_batch: true
383
+ },
384
+ # skip_cross_attn_blocks shape `[bs, 1]`, only used in mllama
385
+ {
386
+ name: "skip_cross_attn_blocks"
387
+ data_type: TYPE_BOOL
388
+ dims: [ 1 ]
389
+ optional: true
390
+ allow_ragged_batch: true
391
+ },
392
+ {
393
+ name: "retention_token_range_starts"
394
+ data_type: TYPE_INT32
395
+ dims: [ -1 ]
396
+ optional: true
397
+ allow_ragged_batch: true
398
+ },
399
+ {
400
+ name: "retention_token_range_ends"
401
+ data_type: TYPE_INT32
402
+ dims: [ -1 ]
403
+ optional: true
404
+ allow_ragged_batch: true
405
+ },
406
+ {
407
+ name: "retention_token_range_priorities"
408
+ data_type: TYPE_INT32
409
+ dims: [ -1 ]
410
+ optional: true
411
+ allow_ragged_batch: true
412
+ },
413
+ {
414
+ name: "retention_token_range_durations_ms"
415
+ data_type: TYPE_INT32
416
+ dims: [ -1 ]
417
+ optional: true
418
+ allow_ragged_batch: true
419
+ },
420
+ {
421
+ name: "retention_decode_priority"
422
+ data_type: TYPE_INT32
423
+ dims: [ 1 ]
424
+ optional: true
425
+ allow_ragged_batch: true
426
+ },
427
+ {
428
+ name: "retention_decode_duration_ms"
429
+ data_type: TYPE_INT32
430
+ dims: [ 1 ]
431
+ optional: true
432
+ allow_ragged_batch: true
433
+ },
434
+ {
435
+ name: "guided_decoding_guide_type"
436
+ data_type: TYPE_STRING
437
+ dims: [ 1 ]
438
+ optional: true
439
+ allow_ragged_batch: true
440
+ },
441
+ {
442
+ name: "guided_decoding_guide"
443
+ data_type: TYPE_STRING
444
+ dims: [ 1 ]
445
+ optional: true
446
+ allow_ragged_batch: true
447
+ },
448
+ {
449
+ name: "lookahead_window_size"
450
+ data_type: TYPE_INT32
451
+ dims: [ 1 ]
452
+ optional: true
453
+ allow_ragged_batch: true
454
+ },
455
+ {
456
+ name: "lookahead_ngram_size"
457
+ data_type: TYPE_INT32
458
+ dims: [ 1 ]
459
+ optional: true
460
+ allow_ragged_batch: true
461
+ },
462
+ {
463
+ name: "lookahead_verification_set_size"
464
+ data_type: TYPE_INT32
465
+ dims: [ 1 ]
466
+ optional: true
467
+ allow_ragged_batch: true
468
+ }
469
+ ]
470
+ output [
471
+ {
472
+ name: "output_ids"
473
+ data_type: TYPE_INT32
474
+ dims: [ -1, -1 ]
475
+ },
476
+ {
477
+ name: "sequence_length"
478
+ data_type: TYPE_INT32
479
+ dims: [ -1 ]
480
+ },
481
+ {
482
+ name: "cum_log_probs"
483
+ data_type: TYPE_FP32
484
+ dims: [ -1 ]
485
+ },
486
+ {
487
+ name: "output_log_probs"
488
+ data_type: TYPE_FP32
489
+ dims: [ -1, -1 ]
490
+ },
491
+ {
492
+ name: "context_logits"
493
+ data_type: TYPE_FP32
494
+ dims: [ -1, -1 ]
495
+ },
496
+ {
497
+ name: "generation_logits"
498
+ data_type: TYPE_FP32
499
+ dims: [ -1, -1, -1 ]
500
+ },
501
+ {
502
+ name: "batch_index"
503
+ data_type: TYPE_INT32
504
+ dims: [ 1 ]
505
+ },
506
+ {
507
+ name: "sequence_index"
508
+ data_type: TYPE_INT32
509
+ dims: [ 1 ]
510
+ },
511
+ {
512
+ name: "context_phase_params"
513
+ data_type: TYPE_UINT8
514
+ dims: [ -1 ]
515
+ },
516
+ {
517
+ name: "kv_cache_alloc_new_blocks"
518
+ data_type: TYPE_INT32
519
+ dims: [ 1 ]
520
+ },
521
+ {
522
+ name: "kv_cache_reused_blocks"
523
+ data_type: TYPE_INT32
524
+ dims: [ 1 ]
525
+ },
526
+ {
527
+ name: "kv_cache_alloc_total_blocks"
528
+ data_type: TYPE_INT32
529
+ dims: [ 1 ]
530
+ },
531
+ {
532
+ name: "arrival_time_ns"
533
+ data_type: TYPE_INT64
534
+ dims: [ 1 ]
535
+ },
536
+ {
537
+ name: "first_scheduled_time_ns"
538
+ data_type: TYPE_INT64
539
+ dims: [ 1 ]
540
+ },
541
+ {
542
+ name: "first_token_time_ns"
543
+ data_type: TYPE_INT64
544
+ dims: [ 1 ]
545
+ },
546
+ {
547
+ name: "last_token_time_ns"
548
+ data_type: TYPE_INT64
549
+ dims: [ 1 ]
550
+ },
551
+ {
552
+ name: "acceptance_rate"
553
+ data_type: TYPE_FP32
554
+ dims: [ 1 ]
555
+ },
556
+ {
557
+ name: "total_accepted_draft_tokens"
558
+ data_type: TYPE_INT32
559
+ dims: [ 1 ]
560
+ },
561
+ {
562
+ name: "total_draft_tokens"
563
+ data_type: TYPE_INT32
564
+ dims: [ 1 ]
565
+ }
566
+ ]
567
+ instance_group [
568
+ {
569
+ count: 1
570
+ kind : KIND_CPU
571
+ }
572
+ ]
573
+ parameters: {
574
+ key: "max_beam_width"
575
+ value: {
576
+ string_value: "4"
577
+ }
578
+ }
579
+ parameters: {
580
+ key: "FORCE_CPU_ONLY_INPUT_TENSORS"
581
+ value: {
582
+ string_value: "no"
583
+ }
584
+ }
585
+ parameters: {
586
+ key: "gpt_model_type"
587
+ value: {
588
+ string_value: "inflight_fused_batching"
589
+ }
590
+ }
591
+ parameters: {
592
+ key: "gpt_model_path"
593
+ value: {
594
+ string_value: "/models/whisper_large_v3_max_batch_512/decoder"
595
+ }
596
+ }
597
+ parameters: {
598
+ key: "encoder_model_path"
599
+ value: {
600
+ string_value: "/models/whisper_large_v3_max_batch_512/encoder"
601
+ }
602
+ }
603
+ parameters: {
604
+ key: "max_tokens_in_paged_kv_cache"
605
+ value: {
606
+ string_value: "24000"
607
+ }
608
+ }
609
+ parameters: {
610
+ key: "max_attention_window_size"
611
+ value: {
612
+ string_value: ""
613
+ }
614
+ }
615
+ parameters: {
616
+ key: "sink_token_length"
617
+ value: {
618
+ string_value: "${sink_token_length}"
619
+ }
620
+ }
621
+ parameters: {
622
+ key: "batch_scheduler_policy"
623
+ value: {
624
+ string_value: ""
625
+ }
626
+ }
627
+ parameters: {
628
+ key: "kv_cache_free_gpu_mem_fraction"
629
+ value: {
630
+ string_value: "0.5"
631
+ }
632
+ }
633
+ parameters: {
634
+ key: "cross_kv_cache_fraction"
635
+ value: {
636
+ string_value: "0.5"
637
+ }
638
+ }
639
+ parameters: {
640
+ key: "kv_cache_host_memory_bytes"
641
+ value: {
642
+ string_value: "${kv_cache_host_memory_bytes}"
643
+ }
644
+ }
645
+ # kv_cache_onboard_blocks is an internal implementation detail.
646
+ parameters: {
647
+ key: "kv_cache_onboard_blocks"
648
+ value: {
649
+ string_value: "${kv_cache_onboard_blocks}"
650
+ }
651
+ }
652
+ # enable_trt_overlap is deprecated and doesn't have any effect on the runtime
653
+ # parameters: {
654
+ # key: "enable_trt_overlap"
655
+ # value: {
656
+ # string_value: "${enable_trt_overlap}"
657
+ # }
658
+ # }
659
+ parameters: {
660
+ key: "exclude_input_in_output"
661
+ value: {
662
+ string_value: "True"
663
+ }
664
+ }
665
+ parameters: {
666
+ key: "cancellation_check_period_ms"
667
+ value: {
668
+ string_value: "${cancellation_check_period_ms}"
669
+ }
670
+ }
671
+ parameters: {
672
+ key: "stats_check_period_ms"
673
+ value: {
674
+ string_value: "${stats_check_period_ms}"
675
+ }
676
+ }
677
+ parameters: {
678
+ key: "iter_stats_max_iterations"
679
+ value: {
680
+ string_value: "${iter_stats_max_iterations}"
681
+ }
682
+ }
683
+ parameters: {
684
+ key: "request_stats_max_iterations"
685
+ value: {
686
+ string_value: "${request_stats_max_iterations}"
687
+ }
688
+ }
689
+ parameters: {
690
+ key: "enable_kv_cache_reuse"
691
+ value: {
692
+ string_value: "false"
693
+ }
694
+ }
695
+ parameters: {
696
+ key: "normalize_log_probs"
697
+ value: {
698
+ string_value: ""
699
+ }
700
+ }
701
+ parameters: {
702
+ key: "enable_chunked_context"
703
+ value: {
704
+ string_value: "false"
705
+ }
706
+ }
707
+ parameters: {
708
+ key: "gpu_device_ids"
709
+ value: {
710
+ string_value: ""
711
+ }
712
+ }
713
+ parameters: {
714
+ key: "participant_ids"
715
+ value: {
716
+ string_value: "${participant_ids}"
717
+ }
718
+ }
719
+ parameters: {
720
+ key: "lora_cache_optimal_adapter_size"
721
+ value: {
722
+ string_value: "${lora_cache_optimal_adapter_size}"
723
+ }
724
+ }
725
+ parameters: {
726
+ key: "lora_cache_max_adapter_size"
727
+ value: {
728
+ string_value: "${lora_cache_max_adapter_size}"
729
+ }
730
+ }
731
+ parameters: {
732
+ key: "lora_cache_gpu_memory_fraction"
733
+ value: {
734
+ string_value: "${lora_cache_gpu_memory_fraction}"
735
+ }
736
+ }
737
+ parameters: {
738
+ key: "lora_cache_host_memory_bytes"
739
+ value: {
740
+ string_value: "${lora_cache_host_memory_bytes}"
741
+ }
742
+ }
743
+ parameters: {
744
+ key: "lora_prefetch_dir"
745
+ value: {
746
+ string_value: "${lora_prefetch_dir}"
747
+ }
748
+ }
749
+ parameters: {
750
+ key: "decoding_mode"
751
+ value: {
752
+ string_value: ""
753
+ }
754
+ }
755
+ parameters: {
756
+ key: "executor_worker_path"
757
+ value: {
758
+ string_value: "/opt/tritonserver/backends/tensorrtllm/trtllmExecutorWorker"
759
+ }
760
+ }
761
+ parameters: {
762
+ key: "lookahead_window_size"
763
+ value: {
764
+ string_value: "${lookahead_window_size}"
765
+ }
766
+ }
767
+ parameters: {
768
+ key: "lookahead_ngram_size"
769
+ value: {
770
+ string_value: "${lookahead_ngram_size}"
771
+ }
772
+ }
773
+ parameters: {
774
+ key: "lookahead_verification_set_size"
775
+ value: {
776
+ string_value: "${lookahead_verification_set_size}"
777
+ }
778
+ }
779
+ parameters: {
780
+ key: "medusa_choices"
781
+ value: {
782
+ string_value: "${medusa_choices}"
783
+ }
784
+ }
785
+ parameters: {
786
+ key: "eagle_choices"
787
+ value: {
788
+ string_value: "${eagle_choices}"
789
+ }
790
+ }
791
+ parameters: {
792
+ key: "gpu_weights_percent"
793
+ value: {
794
+ string_value: "${gpu_weights_percent}"
795
+ }
796
+ }
797
+ parameters: {
798
+ key: "enable_context_fmha_fp32_acc"
799
+ value: {
800
+ string_value: ""
801
+ }
802
+ }
803
+ parameters: {
804
+ key: "multi_block_mode"
805
+ value: {
806
+ string_value: "${multi_block_mode}"
807
+ }
808
+ }
809
+ parameters: {
810
+ key: "cuda_graph_mode"
811
+ value: {
812
+ string_value: "${cuda_graph_mode}"
813
+ }
814
+ }
815
+ parameters: {
816
+ key: "cuda_graph_cache_size"
817
+ value: {
818
+ string_value: "${cuda_graph_cache_size}"
819
+ }
820
+ }
821
+ parameters: {
822
+ key: "speculative_decoding_fast_logits"
823
+ value: {
824
+ string_value: "${speculative_decoding_fast_logits}"
825
+ }
826
+ }
827
+ parameters: {
828
+ key: "tokenizer_dir"
829
+ value: {
830
+ string_value: "${tokenizer_dir}"
831
+ }
832
+ }
833
+ parameters: {
834
+ key: "guided_decoding_backend"
835
+ value: {
836
+ string_value: "${guided_decoding_backend}"
837
+ }
838
+ }
839
+ parameters: {
840
+ key: "xgrammar_tokenizer_info_path"
841
+ value: {
842
+ string_value: "${xgrammar_tokenizer_info_path}"
843
+ }
844
+ }
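Note that several string_value fields above are left as ${...} template placeholders (sink_token_length, cancellation_check_period_ms, stats_check_period_ms, iter_stats_max_iterations, tokenizer_dir, and others). They have to be replaced with concrete values before Triton loads the model; the tensorrtllm_backend repository ships a fill_template.py helper for this, and a plain substitution like the sketch below works as well. The values shown are illustrative only.

import re
from pathlib import Path

# Illustrative values -- tune for your deployment.
values = {
    "sink_token_length": "0",
    "cancellation_check_period_ms": "100",
    "stats_check_period_ms": "100",
    "iter_stats_max_iterations": "1000",
    "request_stats_max_iterations": "0",
    "kv_cache_host_memory_bytes": "0",
    "kv_cache_onboard_blocks": "true",
}

path = Path("model_repo_whisper_512/tensorrt_llm/config.pbtxt")
text = path.read_text()
# Replace known placeholders; leave unknown ${...} untouched so they stay easy to spot.
text = re.sub(r"\$\{(\w+)\}", lambda m: values.get(m.group(1), m.group(0)), text)
path.write_text(text)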
whisper_large_v3_max_batch_512/decoder/config.json ADDED
@@ -0,0 +1,174 @@
1
+ {
2
+ "version": "0.17.0.post1",
3
+ "pretrained_config": {
4
+ "architecture": "DecoderModel",
5
+ "dtype": "float16",
6
+ "vocab_size": 51866,
7
+ "hidden_size": 1280,
8
+ "num_hidden_layers": 32,
9
+ "num_attention_heads": 20,
10
+ "hidden_act": "gelu",
11
+ "logits_dtype": "float16",
12
+ "norm_epsilon": 1e-05,
13
+ "runtime_defaults": null,
14
+ "position_embedding_type": "learned_absolute",
15
+ "num_key_value_heads": 20,
16
+ "intermediate_size": 5120,
17
+ "max_position_embeddings": 448,
18
+ "mapping": {
19
+ "world_size": 1,
20
+ "gpus_per_node": 8,
21
+ "cp_size": 1,
22
+ "tp_size": 1,
23
+ "pp_size": 1,
24
+ "moe_tp_size": 1,
25
+ "moe_ep_size": 1,
26
+ "auto_parallel": false
27
+ },
28
+ "quantization": {
29
+ "quant_algo": null,
30
+ "kv_cache_quant_algo": null,
31
+ "group_size": 128,
32
+ "smoothquant_val": 0.5,
33
+ "clamp_val": null,
34
+ "use_meta_recipe": false,
35
+ "has_zero_point": false,
36
+ "pre_quant_scale": false,
37
+ "exclude_modules": null
38
+ },
39
+ "use_parallel_embedding": false,
40
+ "embedding_sharding_dim": 0,
41
+ "head_size": 64,
42
+ "qk_layernorm": false,
43
+ "rotary_embedding_dim": 64,
44
+ "use_prompt_tuning": false,
45
+ "has_position_embedding": true,
46
+ "layernorm_type": 0,
47
+ "has_attention_qkvo_bias": true,
48
+ "has_mlp_bias": true,
49
+ "has_model_final_layernorm": true,
50
+ "has_embedding_layernorm": false,
51
+ "has_embedding_scale": false,
52
+ "ffn_hidden_size": 5120,
53
+ "q_scaling": 1.0,
54
+ "layernorm_position": 0,
55
+ "relative_attention": false,
56
+ "max_distance": 0,
57
+ "num_buckets": 0,
58
+ "model_type": "whisper",
59
+ "rescale_before_lm_head": false,
60
+ "encoder_hidden_size": 1280,
61
+ "encoder_num_heads": 20,
62
+ "encoder_head_size": null,
63
+ "skip_cross_kv": false,
64
+ "type_vocab_size": null,
65
+ "encoder_num_kv_heads": null,
66
+ "mlp_type": 0,
67
+ "residual_scaling": 1.0,
68
+ "has_lm_head_bias": false
69
+ },
70
+ "build_config": {
71
+ "max_input_len": 14,
72
+ "max_seq_len": 114,
73
+ "opt_batch_size": 8,
74
+ "max_batch_size": 512,
75
+ "max_beam_width": 4,
76
+ "max_num_tokens": 8192,
77
+ "opt_num_tokens": 2048,
78
+ "max_prompt_embedding_table_size": 0,
79
+ "kv_cache_type": "PAGED",
80
+ "gather_context_logits": false,
81
+ "gather_generation_logits": false,
82
+ "strongly_typed": true,
83
+ "force_num_profiles": null,
84
+ "profiling_verbosity": "layer_names_only",
85
+ "enable_debug_output": false,
86
+ "max_draft_len": 0,
87
+ "speculative_decoding_mode": 1,
88
+ "use_refit": false,
89
+ "input_timing_cache": null,
90
+ "output_timing_cache": "model.cache",
91
+ "lora_config": {
92
+ "lora_dir": [],
93
+ "lora_ckpt_source": "hf",
94
+ "max_lora_rank": 64,
95
+ "lora_target_modules": [],
96
+ "trtllm_modules_to_hf_modules": {}
97
+ },
98
+ "auto_parallel_config": {
99
+ "world_size": 1,
100
+ "gpus_per_node": 8,
101
+ "cluster_key": "H100-PCIe",
102
+ "cluster_info": null,
103
+ "sharding_cost_model": "alpha_beta",
104
+ "comm_cost_model": "alpha_beta",
105
+ "enable_pipeline_parallelism": false,
106
+ "enable_shard_unbalanced_shape": false,
107
+ "enable_shard_dynamic_shape": false,
108
+ "enable_reduce_scatter": true,
109
+ "builder_flags": null,
110
+ "debug_mode": false,
111
+ "infer_shape": true,
112
+ "validation_mode": false,
113
+ "same_buffer_io": {
114
+ "past_key_value_(\\d+)": "present_key_value_\\1"
115
+ },
116
+ "same_spec_io": {},
117
+ "sharded_io_allowlist": [
118
+ "past_key_value_\\d+",
119
+ "present_key_value_\\d*"
120
+ ],
121
+ "fill_weights": false,
122
+ "parallel_config_cache": null,
123
+ "profile_cache": null,
124
+ "dump_path": null,
125
+ "debug_outputs": []
126
+ },
127
+ "weight_sparsity": false,
128
+ "weight_streaming": false,
129
+ "plugin_config": {
130
+ "dtype": "float16",
131
+ "bert_attention_plugin": "float16",
132
+ "gpt_attention_plugin": "float16",
133
+ "gemm_plugin": "float16",
134
+ "explicitly_disable_gemm_plugin": false,
135
+ "gemm_swiglu_plugin": null,
136
+ "fp8_rowwise_gemm_plugin": null,
137
+ "qserve_gemm_plugin": null,
138
+ "identity_plugin": null,
139
+ "nccl_plugin": null,
140
+ "lora_plugin": null,
141
+ "weight_only_groupwise_quant_matmul_plugin": null,
142
+ "weight_only_quant_matmul_plugin": null,
143
+ "smooth_quant_plugins": true,
144
+ "smooth_quant_gemm_plugin": null,
145
+ "layernorm_quantization_plugin": null,
146
+ "rmsnorm_quantization_plugin": null,
147
+ "quantize_per_token_plugin": false,
148
+ "quantize_tensor_plugin": false,
149
+ "moe_plugin": null,
150
+ "mamba_conv1d_plugin": "auto",
151
+ "low_latency_gemm_plugin": null,
152
+ "low_latency_gemm_swiglu_plugin": null,
153
+ "context_fmha": true,
154
+ "bert_context_fmha_fp32_acc": false,
155
+ "paged_kv_cache": true,
156
+ "remove_input_padding": true,
157
+ "reduce_fusion": false,
158
+ "user_buffer": false,
159
+ "tokens_per_block": 64,
160
+ "use_paged_context_fmha": false,
161
+ "use_fp8_context_fmha": false,
162
+ "multiple_profiles": false,
163
+ "paged_state": false,
164
+ "streamingllm": false,
165
+ "manage_weights": false,
166
+ "use_fused_mlp": true,
167
+ "pp_reduce_scatter": false
168
+ },
169
+ "use_strip_plan": false,
170
+ "max_encoder_input_len": 3000,
171
+ "monitor_memory": false,
172
+ "use_mrope": false
173
+ }
174
+ }
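The decoder build values line up with the serving config: max_batch_size 512 matches config.pbtxt, max_beam_width 4 matches the max_beam_width parameter, max_input_len 14 bounds the decoder prompt, and max_seq_len 114 leaves roughly 100 generated tokens per 30-second window. A small sanity check along these lines (paths assume this repo's layout) can catch mismatches after rebuilding engines:

import json

with open("whisper_large_v3_max_batch_512/decoder/config.json") as f:
    dec = json.load(f)["build_config"]
with open("whisper_large_v3_max_batch_512/encoder/config.json") as f:
    enc = json.load(f)["build_config"]

assert dec["max_batch_size"] == enc["max_batch_size"] == 512  # must match config.pbtxt max_batch_size
assert dec["max_beam_width"] == 4                             # must match the max_beam_width parameter
print("generation budget per request:", dec["max_seq_len"] - dec["max_input_len"], "tokens")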
whisper_large_v3_max_batch_512/decoder/rank0.engine ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e14050661dc50c175348498694b2eea42e45e71d10d73be9646b159dfa3dcb4a
3
+ size 2166109620
whisper_large_v3_max_batch_512/encoder/config.json ADDED
@@ -0,0 +1,152 @@
1
+ {
2
+ "version": "0.17.0.post1",
3
+ "pretrained_config": {
4
+ "architecture": "WhisperEncoder",
5
+ "dtype": "float16",
6
+ "vocab_size": 51866,
7
+ "hidden_size": 1280,
8
+ "num_hidden_layers": 32,
9
+ "num_attention_heads": 20,
10
+ "hidden_act": "gelu",
11
+ "logits_dtype": "float32",
12
+ "norm_epsilon": 1e-05,
13
+ "runtime_defaults": null,
14
+ "position_embedding_type": "learned_absolute",
15
+ "num_key_value_heads": 20,
16
+ "intermediate_size": 5120,
17
+ "max_position_embeddings": 1500,
18
+ "mapping": {
19
+ "world_size": 1,
20
+ "gpus_per_node": 8,
21
+ "cp_size": 1,
22
+ "tp_size": 1,
23
+ "pp_size": 1,
24
+ "moe_tp_size": 1,
25
+ "moe_ep_size": 1,
26
+ "auto_parallel": false
27
+ },
28
+ "quantization": {
29
+ "quant_algo": null,
30
+ "kv_cache_quant_algo": null,
31
+ "group_size": 128,
32
+ "smoothquant_val": 0.5,
33
+ "clamp_val": null,
34
+ "use_meta_recipe": false,
35
+ "has_zero_point": false,
36
+ "pre_quant_scale": false,
37
+ "exclude_modules": null
38
+ },
39
+ "use_parallel_embedding": false,
40
+ "embedding_sharding_dim": 0,
41
+ "head_size": 64,
42
+ "qk_layernorm": false,
43
+ "rotary_embedding_dim": 64,
44
+ "has_position_embedding": true,
45
+ "n_mels": 128,
46
+ "num_languages": 100
47
+ },
48
+ "build_config": {
49
+ "max_input_len": 3000,
50
+ "max_seq_len": 3000,
51
+ "opt_batch_size": 8,
52
+ "max_batch_size": 512,
53
+ "max_beam_width": 1,
54
+ "max_num_tokens": 8192,
55
+ "opt_num_tokens": 512,
56
+ "max_prompt_embedding_table_size": 0,
57
+ "kv_cache_type": "PAGED",
58
+ "gather_context_logits": false,
59
+ "gather_generation_logits": false,
60
+ "strongly_typed": true,
61
+ "force_num_profiles": null,
62
+ "profiling_verbosity": "layer_names_only",
63
+ "enable_debug_output": false,
64
+ "max_draft_len": 0,
65
+ "speculative_decoding_mode": 1,
66
+ "use_refit": false,
67
+ "input_timing_cache": null,
68
+ "output_timing_cache": "model.cache",
69
+ "lora_config": {
70
+ "lora_dir": [],
71
+ "lora_ckpt_source": "hf",
72
+ "max_lora_rank": 64,
73
+ "lora_target_modules": [],
74
+ "trtllm_modules_to_hf_modules": {}
75
+ },
76
+ "auto_parallel_config": {
77
+ "world_size": 1,
78
+ "gpus_per_node": 8,
79
+ "cluster_key": "H100-PCIe",
80
+ "cluster_info": null,
81
+ "sharding_cost_model": "alpha_beta",
82
+ "comm_cost_model": "alpha_beta",
83
+ "enable_pipeline_parallelism": false,
84
+ "enable_shard_unbalanced_shape": false,
85
+ "enable_shard_dynamic_shape": false,
86
+ "enable_reduce_scatter": true,
87
+ "builder_flags": null,
88
+ "debug_mode": false,
89
+ "infer_shape": true,
90
+ "validation_mode": false,
91
+ "same_buffer_io": {
92
+ "past_key_value_(\\d+)": "present_key_value_\\1"
93
+ },
94
+ "same_spec_io": {},
95
+ "sharded_io_allowlist": [
96
+ "past_key_value_\\d+",
97
+ "present_key_value_\\d*"
98
+ ],
99
+ "fill_weights": false,
100
+ "parallel_config_cache": null,
101
+ "profile_cache": null,
102
+ "dump_path": null,
103
+ "debug_outputs": []
104
+ },
105
+ "weight_sparsity": false,
106
+ "weight_streaming": false,
107
+ "plugin_config": {
108
+ "dtype": "float16",
109
+ "bert_attention_plugin": "float16",
110
+ "gpt_attention_plugin": "auto",
111
+ "gemm_plugin": null,
112
+ "explicitly_disable_gemm_plugin": true,
113
+ "gemm_swiglu_plugin": null,
114
+ "fp8_rowwise_gemm_plugin": null,
115
+ "qserve_gemm_plugin": null,
116
+ "identity_plugin": null,
117
+ "nccl_plugin": null,
118
+ "lora_plugin": null,
119
+ "weight_only_groupwise_quant_matmul_plugin": null,
120
+ "weight_only_quant_matmul_plugin": null,
121
+ "smooth_quant_plugins": true,
122
+ "smooth_quant_gemm_plugin": null,
123
+ "layernorm_quantization_plugin": null,
124
+ "rmsnorm_quantization_plugin": null,
125
+ "quantize_per_token_plugin": false,
126
+ "quantize_tensor_plugin": false,
127
+ "moe_plugin": null,
128
+ "mamba_conv1d_plugin": "auto",
129
+ "low_latency_gemm_plugin": null,
130
+ "low_latency_gemm_swiglu_plugin": null,
131
+ "context_fmha": true,
132
+ "bert_context_fmha_fp32_acc": false,
133
+ "paged_kv_cache": true,
134
+ "remove_input_padding": true,
135
+ "reduce_fusion": false,
136
+ "user_buffer": false,
137
+ "tokens_per_block": 64,
138
+ "use_paged_context_fmha": false,
139
+ "use_fp8_context_fmha": false,
140
+ "multiple_profiles": false,
141
+ "paged_state": false,
142
+ "streamingllm": false,
143
+ "manage_weights": false,
144
+ "use_fused_mlp": true,
145
+ "pp_reduce_scatter": false
146
+ },
147
+ "use_strip_plan": false,
148
+ "max_encoder_input_len": 1024,
149
+ "monitor_memory": false,
150
+ "use_mrope": false
151
+ }
152
+ }
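The encoder is built for Whisper large-v3 style features: n_mels 128 and max_input_len / max_seq_len 3000, i.e. one 30-second window of log-mel frames at the usual 10 ms hop (assuming 16 kHz audio and a hop length of 160 samples):

SAMPLE_RATE = 16_000   # Hz, standard Whisper front end
HOP_LENGTH = 160       # samples per mel frame (10 ms)
CHUNK_SECONDS = 30

frames = CHUNK_SECONDS * SAMPLE_RATE // HOP_LENGTH
assert frames == 3000  # matches max_input_len in the encoder build config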
whisper_large_v3_max_batch_512/encoder/rank0.engine ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b3f2f4cc9b86f771778b657e581a32222addfec3f26e78868308a535629a0354
3
+ size 1297639156
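Both rank0.engine entries are Git LFS pointer files: the diff stores only the sha256 and size (about 2.2 GB for the decoder engine and 1.3 GB for the encoder engine), and the actual TensorRT engines are fetched through LFS. When pulling the repo programmatically, a sketch like the following retrieves the real binaries; the repo id is a placeholder for this repository.

from huggingface_hub import snapshot_download

# "your-org/your-whisper-trtllm-repo" is a placeholder for this model repository's id.
local_dir = snapshot_download(
    repo_id="your-org/your-whisper-trtllm-repo",
    allow_patterns=["whisper_large_v3_max_batch_512/*", "model_repo_whisper_512/*"],
)
print("engines and Triton model repo downloaded to", local_dir)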