from typing import Any, Dict, List, Union
from pathlib import Path

import torch
from transformers import (
    BartConfig,
    BartForConditionalGeneration,
    PreTrainedTokenizerFast,
)


class EndpointHandler:

    def __init__(self, path: str = ""):
        # Build the base KoBART architecture from the HuggingFace Hub config,
        # then load the fine-tuned summarization weights shipped with the repo.
        self.model_path = str(Path(path) / "kobartbasekosummary.pt")
        config = BartConfig.from_pretrained("hyunwoongko/kobart")
        self.model = BartForConditionalGeneration(config).eval().to("cpu")
        # The checkpoint holds the weights of the inner BartModel, hence
        # self.model.model rather than self.model.
        self.model.model.load_state_dict(torch.load(
            self.model_path,
            map_location="cpu",
        ))
        self.tokenizer = PreTrainedTokenizerFast.from_pretrained(
            "hyunwoongko/kobart",
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        model = self.model
        tokenizer = self.tokenizer

        # Generation parameters
        beam = 5
        sampling = False
        temperature = 1.0
        sampling_topk = -1    # disabled unless > 0
        sampling_topp = -1.0  # disabled unless > 0
        length_penalty = 1.0
        max_len_a = 1   # max_length = max_len_a * input_length + max_len_b
        max_len_b = 50
        no_repeat_ngram_size = 4
        return_tokens = False
        bad_words_ids = None

        # Accept {"inputs": "text"}, {"inputs": ["text", ...]}, or a bare payload.
        inputs = data.pop("inputs", data)
        texts = [inputs] if isinstance(inputs, str) else inputs

        tokenized = self.tokenize(tokenizer, texts)
        input_ids = tokenized["input_ids"]
        attention_mask = tokenized["attention_mask"]

        # Always ban <unk> in the output, on top of any caller-supplied bad words.
        unk_ids = [[tokenizer.convert_tokens_to_ids("<unk>")]]

        generated = model.generate(
            input_ids.to("cpu"),
            attention_mask=attention_mask.to("cpu"),
            use_cache=True,
            early_stopping=False,
            decoder_start_token_id=tokenizer.bos_token_id,
            num_beams=beam,
            do_sample=sampling,
            temperature=temperature,
            top_k=sampling_topk if sampling_topk > 0 else None,
            top_p=sampling_topp if sampling_topp > 0 else None,
            no_repeat_ngram_size=no_repeat_ngram_size,
            bad_words_ids=(unk_ids if not bad_words_ids
                           else bad_words_ids + unk_ids),
            length_penalty=length_penalty,
            max_length=max_len_a * len(input_ids[0]) + max_len_b,
        )

        if return_tokens:
            output = [
                tokenizer.convert_ids_to_tokens(ids)
                for ids in generated.tolist()
            ]
            summ_result = output[0] if isinstance(inputs, str) else output
        else:
            output = tokenizer.batch_decode(
                generated.tolist(),
                skip_special_tokens=True,
            )
            summ_result = (output[0].strip() if isinstance(inputs, str)
                           else [o.strip() for o in output])

        return {"summarization": summ_result}

    def tokenize(
        self,
        tokenizer,
        texts: Union[str, List[str]],
        max_len: int = 1024,
    ) -> Dict:
        if isinstance(texts, str):
            texts = [texts]
        # Prepend <s> by hand; <eos> is inserted after padding, below.
        texts = [f"<s> {text}" for text in texts]
        eos = tokenizer.convert_tokens_to_ids(tokenizer.eos_token)
        eos_list = [eos for _ in range(len(texts))]
        tokens = tokenizer(
            texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            add_special_tokens=False,
            max_length=max_len - 1,  # reserve one position for <eos>
        )
        return self.add_bos_eos_tokens(tokenizer, tokens, eos_list)

    def add_bos_eos_tokens(self, tokenizer, tokens, eos_list):
        input_ids = tokens["input_ids"]
        attention_mask = tokens["attention_mask"]
        token_added_ids, token_added_masks = [], []

        pad_id = tokenizer.convert_tokens_to_ids("<pad>")
        for input_id, atn_mask, eos in zip(
                input_ids,
                attention_mask,
                eos_list,
        ):
            # Find the last non-padding position and insert <eos> right after it.
            maximum_idx = [
                i for i, val in enumerate(input_id) if val != pad_id
            ]
            idx_to_add = max(maximum_idx) + 1 if maximum_idx else 0

            eos = torch.tensor([eos])
            additional_atn_mask = torch.tensor([1])
            input_id = torch.cat([
                input_id[:idx_to_add],
                eos,
                input_id[idx_to_add:],
            ]).long()
            atn_mask = torch.cat([
                atn_mask[:idx_to_add],
                additional_atn_mask,
                atn_mask[idx_to_add:],
            ]).long()
            token_added_ids.append(input_id.unsqueeze(0))
            token_added_masks.append(atn_mask.unsqueeze(0))

        tokens["input_ids"] = torch.cat(token_added_ids, dim=0)
tokens["attention_mask"] = torch.cat(token_added_masks, dim=0) return tokens