| import torch |
| import gc |
| from ts.torch_handler.base_handler import BaseHandler |
| from transformers import GPT2LMHeadModel |
|
|
| import logging |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
class SampleTransformerModel(BaseHandler):
    """TorchServe handler that samples token sequences from a GPT-2 LM-head model.

    Each request payload is a mapping with three entries:

    * ``text`` -- the prompt, given as a flat list of token ids,
    * ``length`` -- number of tokens to generate on top of the prompt,
    * ``num_samples`` -- number of sequences to sample for that prompt.

    The response contains exactly one entry per request: the tensor returned
    by ``generate`` converted to a nested list of token ids.
    """

    def __init__(self):
        super().__init__()
        self.model = None
        self.device = None
        self.initialized = False

    def load_model(self, model_dir):
        """Load the pretrained GPT-2 model from ``model_dir`` onto ``self.device``."""
        self.model = GPT2LMHeadModel.from_pretrained(model_dir, return_dict=True)
        self.model.to(self.device)

    def initialize(self, ctx):
        """Pick the target device, load the model and switch it to eval mode.

        Args:
            ctx: TorchServe context; ``ctx.system_properties`` supplies
                ``model_dir`` and, optionally, ``gpu_id``.
        """
        properties = ctx.system_properties
        model_dir = properties.get("model_dir")
        gpu_id = properties.get("gpu_id")

        # Only address a CUDA device when an id was actually assigned;
        # otherwise the device string would become the invalid "cuda:None".
        if torch.cuda.is_available() and gpu_id is not None:
            self.device = torch.device("cuda:" + str(gpu_id))
        else:
            self.device = torch.device("cpu")

        self.load_model(model_dir)
        self.model.eval()
        self.initialized = True

    @staticmethod
    def _extract_payload(request):
        """Return the payload of one request as a dict, decoding raw JSON if needed."""
        import json  # local import keeps the module's import block untouched

        payload = request.get("body")
        if payload is None:
            payload = request.get("data")
        if isinstance(payload, (bytes, bytearray)):
            payload = payload.decode("utf-8")
        if isinstance(payload, str):
            payload = json.loads(payload)
        if not isinstance(payload, dict):
            raise ValueError(
                "request payload must be a JSON object with "
                "'text', 'length' and 'num_samples'"
            )
        return payload

    def preprocess(self, requests):
        """Turn the raw request batch into per-request generation inputs.

        Args:
            requests: list of TorchServe request mappings.

        Returns:
            A list with one dict per request, holding ``input_ids`` (tensor
            of shape ``(1, len(text))`` on ``self.device``), ``num_samples``
            and ``length`` (prompt length plus requested new tokens).

        Raises:
            ValueError: if a payload is not a mapping or lacks ``text`` or
                ``length``.
        """
        batch = []
        for request in requests:
            payload = self._extract_payload(request)
            text = payload.get("text")
            length = payload.get("length")
            if text is None or length is None:
                raise ValueError("request payload requires 'text' and 'length'")
            batch.append(
                {
                    "input_ids": torch.tensor([text]).to(self.device),
                    "num_samples": payload.get("num_samples"),
                    # generate() receives this as max_length, so the prompt
                    # tokens are counted in as well.
                    "length": length + len(text),
                }
            )
        return batch

    def inference(self, input_batch):
        """Sample sequences for every prepared request.

        Requests are generated one at a time because prompt length, target
        length and sample count may all differ between them.

        Args:
            input_batch: list of dicts as produced by :meth:`preprocess`.

        Returns:
            A list of token-id tensors, one per request, in request order.
        """
        config = self.model.config
        generated = []
        for item in input_batch:
            generated.append(
                self.model.generate(
                    item["input_ids"],
                    bos_token_id=config.bos_token_id,
                    eos_token_id=config.eos_token_id,
                    # The EOS token id doubles as the padding token id.
                    pad_token_id=config.eos_token_id,
                    do_sample=True,
                    max_length=item["length"],
                    top_k=50,
                    top_p=0.95,
                    no_repeat_ngram_size=2,
                    num_return_sequences=item["num_samples"],
                )
            )

        # Best-effort release of cached GPU memory between batches.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return generated

    def postprocess(self, inference_output):
        """Convert the generated tensors into JSON-serializable nested lists.

        Args:
            inference_output: list of tensors as returned by :meth:`inference`.

        Returns:
            A list with one nested list of token ids per request.
        """
        return [sequences.cpu().tolist() for sequences in inference_output]

    def handle(self, data, context):
        """Entry point: run preprocess, inference and postprocess.

        Args:
            data: list of request mappings forming the current batch.
            context: TorchServe context (not used here).

        Returns:
            A list whose length equals ``len(data)``; each element holds the
            sampled token-id sequences for the matching request.
        """
        batch = self.preprocess(data)
        generated = self.inference(batch)
        return self.postprocess(generated)
|
|