# Upload-platform artifacts (not Python code) — preserved as comments:
# VCal's picture
# Upload PyTorch model
# cecbf37 verified
##############################################################
## C O P Y R I G H T (c) 2024 ##
## DH Healthcare GmbH and/or its affiliates ##
## All Rights Reserved ##
##############################################################
## ##
## THIS IS UNPUBLISHED PROPRIETARY SOURCE CODE OF ##
## DH Healthcare GmbH and/or its affiliates. ##
## The copyright notice above does not evidence any ##
## actual or intended publication of such source code. ##
## ##
##############################################################
import os
import pathlib
import numpy as np
import triton_python_backend_utils as pb_utils
from nlpserving.family_history.serving.models.family_history_model import FamilyHistoryClassificationModel
class TritonPythonModel:
    """Triton Python-backend wrapper around FamilyHistoryClassificationModel.

    Loads the model from this file's parent directory, warms it up once,
    and serves batched text-classification requests. Texts from all
    concurrent requests are pooled into a single model call, with a batch
    size adapted to the average input length, then split back into one
    response per request (as the Triton execute contract requires).
    """

    def initialize(self, args):
        """Load the model and read performance settings from the environment.

        Args:
            args: Triton-provided initialization dict (unused here).
        """
        model_dir = os.path.join(pathlib.Path(__file__).parent.resolve(), '../')
        self.model = FamilyHistoryClassificationModel(model_dir=model_dir)

        # Performance configuration, overridable via environment variables.
        self.batch_size = int(os.environ.get('INFERENCE_BATCH_SIZE', 64))
        self.max_sequence_length = int(
            os.environ.get('MAX_SEQUENCE_LENGTH', 512)
        )

        # Warm up with a dummy inference so the first real request does not
        # pay one-time initialization costs. Warmup is deliberately
        # best-effort: a failure here must not prevent the server from
        # starting, so the broad except is intentional.
        try:
            self.model(['warmup text'], batch_size=1, top_k=1)
        except Exception:
            pass

    def execute(self, requests):
        """Run pooled inference and return one response per request.

        Args:
            requests: list of pb_utils.InferenceRequest, each expected to
                carry a "text" input tensor of UTF-8 encoded strings.

        Returns:
            list of pb_utils.InferenceResponse, exactly one per request
            (Triton requires len(responses) == len(requests)), each with
            an "output" tensor of UTF-8 encoded result strings.
        """
        if not requests:
            return []

        all_texts, boundaries = self._gather_texts(requests)
        if not all_texts:
            # No request carried any text. Still answer every request
            # (with an empty output tensor) — returning [] here would
            # violate Triton's one-response-per-request contract.
            return [self._make_response([]) for _ in requests]

        all_outputs = self.model(
            all_texts,
            batch_size=self._adaptive_batch_size(all_texts),
            top_k=1,
        )
        return [
            self._make_response(all_outputs[start:end])
            for start, end in boundaries
        ]

    def _gather_texts(self, requests):
        """Pool decoded texts from all requests.

        Returns:
            (all_texts, boundaries): the pooled list of decoded strings
            and, per request, the (start, end) slice of that list it
            contributed. A request with a missing "text" input contributes
            an empty slice instead of crashing the whole batch.
        """
        all_texts = []
        boundaries = []
        for request in requests:
            tensor = pb_utils.get_input_tensor_by_name(request, "text")
            if tensor is None:
                boundaries.append((len(all_texts), len(all_texts)))
                continue
            start = len(all_texts)
            all_texts.extend(
                item.decode('utf-8') for item in tensor.as_numpy()
            )
            boundaries.append((start, len(all_texts)))
        return all_texts, boundaries

    def _adaptive_batch_size(self, texts):
        """Choose a batch size from the average text length.

        Long texts halve the configured batch size (to bound memory);
        short texts double it (to improve throughput). Clamped to at
        least 1 so a configured batch size of 1 never becomes 0.

        Args:
            texts: non-empty list of input strings.
        """
        avg_chars = sum(len(text) for text in texts) / len(texts)
        if avg_chars > 1000:
            base = self.batch_size // 2
        elif avg_chars < 200:
            base = self.batch_size * 2
        else:
            base = self.batch_size
        return max(1, min(len(texts), base))

    def _make_response(self, outputs):
        """Build an InferenceResponse serializing each output as UTF-8 bytes."""
        encoded = np.array(
            [str(output_dict).encode('utf-8') for output_dict in outputs],
            dtype=object,
        )
        return pb_utils.InferenceResponse(
            output_tensors=[pb_utils.Tensor("output", encoded)]
        )

    def finalize(self):
        """Release the model reference so its resources can be reclaimed."""
        if hasattr(self, 'model'):
            del self.model