import torch
from transformers import T5ForConditionalGeneration, T5TokenizerFast
class TextSummarizer:
    """Abstractive text summarizer backed by a fine-tuned FLAN-T5 checkpoint.

    Usage:
        summarizer = TextSummarizer(device="cuda")
        summaries = summarizer("some long text ...")
    """

    # Default local checkpoint directory (adjacent string literals concatenate
    # to the original single path).
    DEFAULT_MODEL_DIR = (
        "./pretrained_models/"
        "flan-t5-large-finetuned-openai-summarize_from_feedback"
    )

    def __init__(self,
                 device: str = 'cuda',
                 model_dir: str = DEFAULT_MODEL_DIR,
                 model_type: str = "t5"):
        """
        Loads the summarization model onto the requested device.
        Args:
            device (str, optional): device to run. Defaults to "cuda".
            model_dir (str, optional): path to a model directory.
                Defaults to the bundled FLAN-T5 summarization checkpoint.
            model_type (str, optional): only "t5" is implemented.
                Defaults to "t5".
        """
        self._load_model(
            model_type=model_type,
            model_dir=model_dir,
            device=device)

    def _load_model(self,
                    model_type: str = "t5",
                    model_dir: str = "outputs",
                    device: str = 'cuda'):
        """
        loads a checkpoint for inferencing/prediction
        Args:
            model_type (str, optional): "t5" or "mt5". Defaults to "t5".
            model_dir (str, optional): path to model directory.
                Defaults to "outputs".
            device (str, optional): device to run. Defaults to "cuda".
        Raises:
            NotImplementedError: if model_type is not "t5".
        """
        if model_type == "t5":
            self.model = T5ForConditionalGeneration.from_pretrained(model_dir)
            self.tokenizer = T5TokenizerFast.from_pretrained(model_dir)
        else:
            raise NotImplementedError(
                f"model_type {model_type} not implemented")
        self.device = torch.device(device)
        self.model = self.model.to(self.device)

    def _predict(
        self,
        source_text: str,
        max_length: int = 512,
        num_return_sequences: int = 1,
        num_beams: int = 2,
        top_k: int = 50,
        top_p: float = 0.95,
        do_sample: bool = True,
        repetition_penalty: float = 2.5,
        length_penalty: float = 1.0,
        early_stopping: bool = True,
        skip_special_tokens: bool = True,
        clean_up_tokenization_spaces: bool = True,
    ):
        """
        generates prediction for T5/MT5 model
        Args:
            source_text (str): any text for generating predictions
            max_length (int, optional): max token length of prediction.
                Defaults to 512.
            num_return_sequences (int, optional): number of predictions to be
                returned. Defaults to 1.
            num_beams (int, optional): number of beams. Defaults to 2.
            top_k (int, optional): Defaults to 50.
            top_p (float, optional): Defaults to 0.95.
            do_sample (bool, optional): Defaults to True.
            repetition_penalty (float, optional): Defaults to 2.5.
            length_penalty (float, optional): Defaults to 1.0.
            early_stopping (bool, optional): Defaults to True.
            skip_special_tokens (bool, optional): Defaults to True.
            clean_up_tokenization_spaces (bool, optional): Defaults to True.
        Returns:
            list[str]: returns predictions
        """
        input_ids = self.tokenizer.encode(
            source_text, return_tensors="pt", add_special_tokens=True)
        input_ids = input_ids.to(self.device)
        # BUG FIX: `do_sample` was previously accepted but never forwarded to
        # generate(), so sampling was controlled by the checkpoint's generation
        # config and top_k/top_p were silently ignored when it disabled
        # sampling. Forward it explicitly so the documented default holds.
        generated_ids = self.model.generate(
            input_ids=input_ids,
            num_beams=num_beams,
            max_length=max_length,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            early_stopping=early_stopping,
            do_sample=do_sample,
            top_p=top_p,
            top_k=top_k,
            num_return_sequences=num_return_sequences,
        )
        return [
            self.tokenizer.decode(
                ids,
                skip_special_tokens=skip_special_tokens,
                clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            )
            for ids in generated_ids
        ]

    def __call__(self, source_text):
        """Summarize *source_text*; returns the list of generated summaries."""
        generated_text = self._predict(source_text=source_text)
        return generated_text