| | import torch |
| | import gc |
| | from ts.torch_handler.base_handler import BaseHandler |
| | from transformers import GPT2LMHeadModel |
| |
|
| | import logging |
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
| |
|
class SampleTransformerModel(BaseHandler):
    """TorchServe handler that serves GPT-2 text generation.

    Lifecycle: TorchServe calls ``initialize`` once with the model context,
    then ``handle`` per request batch, which runs the standard
    preprocess -> inference -> postprocess pipeline.
    """

    def __init__(self):
        super(SampleTransformerModel, self).__init__()
        self.model = None
        self.device = None
        self.initialized = False

    def load_model(self, model_dir):
        """Load GPT-2 weights from ``model_dir`` and move them to ``self.device``.

        ``self.device`` must already be set (done in ``initialize``).
        """
        self.model = GPT2LMHeadModel.from_pretrained(model_dir, return_dict=True)
        self.model.to(self.device)

    def initialize(self, ctx):
        """One-time setup: pick a device, load the model, switch to eval mode.

        :param ctx: TorchServe context; ``ctx.system_properties`` supplies
            ``model_dir`` and the worker's assigned ``gpu_id``.
        """
        properties = ctx.system_properties
        model_dir = properties.get("model_dir")
        # Use the GPU TorchServe assigned to this worker when CUDA is
        # available; otherwise fall back to CPU.
        self.device = torch.device(
            "cuda:" + str(properties.get("gpu_id"))
            if torch.cuda.is_available()
            else "cpu"
        )

        self.load_model(model_dir)

        self.model.eval()
        self.initialized = True

    def preprocess(self, requests):
        """Build the generation inputs from a batch of requests.

        Each request body is expected to carry ``text`` (a list of token
        ids), ``num_samples`` and ``length`` — assumed, TODO confirm
        against the client contract.

        NOTE(review): every iteration overwrites the same dict keys, so
        only the LAST request of a multi-request batch is actually served
        (effectively batch size 1). Kept as-is to preserve the existing
        ``inference`` contract; verify the server-side batch size is 1.

        :return: dict with ``input_ids`` (tensor on ``self.device``),
            ``num_samples`` and ``length`` (prompt length + requested
            new-token budget).
        """
        input_batch = {}
        for data in requests:  # idx from enumerate was unused; plain loop
            body = data.get("body")  # hoisted: was looked up three times
            input_ids = torch.tensor([body.get("text")]).to(self.device)
            input_batch["input_ids"] = input_ids
            input_batch["num_samples"] = body.get("num_samples")
            # Total max_length = requested continuation + prompt length.
            input_batch["length"] = body.get("length") + len(body.get("text"))
        del requests
        gc.collect()
        return input_batch

    def inference(self, input_batch):
        """Run sampling-based generation and free the inputs afterwards.

        :param input_batch: dict produced by ``preprocess``.
        :return: generated token-id tensor from ``model.generate``.
        """
        input_ids = input_batch["input_ids"]
        length = input_batch["length"]

        # Pure inference: disable autograd so no graph/bookkeeping is
        # retained (complements the explicit cache/GC cleanup below).
        with torch.no_grad():
            inference_output = self.model.generate(
                input_ids,
                bos_token_id=self.model.config.bos_token_id,
                eos_token_id=self.model.config.eos_token_id,
                # GPT-2 has no dedicated pad token; reuse EOS for padding.
                pad_token_id=self.model.config.eos_token_id,
                do_sample=True,
                max_length=length,
                top_k=50,
                top_p=0.95,
                no_repeat_ngram_size=2,
                num_return_sequences=input_batch["num_samples"],
            )

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        del input_batch
        gc.collect()
        return inference_output

    def postprocess(self, inference_output):
        """Convert generated token tensors to plain nested Python lists.

        :return: single-element list wrapping the token-id lists, matching
            TorchServe's one-response-per-batch convention.
        """
        output = inference_output.cpu().numpy().tolist()
        del inference_output
        gc.collect()
        return [output]

    def handle(self, data, context):
        """TorchServe entry point: preprocess -> inference -> postprocess."""
        # Guard against a worker invoking handle before initialize
        # (lazy init; normally TorchServe calls initialize first).
        if not self.initialized:
            self.initialize(context)

        data = self.preprocess(data)
        data = self.inference(data)
        data = self.postprocess(data)
        return data