| from transformers import AutoTokenizer, TFGPT2LMHeadModel, pipeline |
| from transformers.pipelines import TextGenerationPipeline |
| from typing import Union |
|
|
class QuoteGenerator:
    """Generates short quotes from comma-separated topic tags using a
    fine-tuned GPT-2 model hosted on the Hugging Face hub."""

    def __init__(self, model_name: str = 'gruhit-patel/quote-generator-v2'):
        """Download/load the tokenizer and TF model weights.

        Args:
            model_name: Hub id (or local path) of the fine-tuned model.
        """
        self.model_name = model_name
        # Built lazily by load_generator(); stays None until then.
        self.quote_generator: Union[TextGenerationPipeline, None] = None
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = TFGPT2LMHeadModel.from_pretrained(self.model_name)
        # Fallback tags used when the caller passes no tags.
        self.default_tags = 'love,life'
        print("Model has been loaded")

    def load_generator(self) -> None:
        """Build the text-generation pipeline from the loaded model/tokenizer."""
        self.quote_generator = pipeline('text-generation', model=self.model, tokenizer=self.tokenizer)
        print("Pipeline has been generated")

    def preprocess_tags(self, tags: Union[None, str] = None) -> str:
        """Format tags into the prompt shape the model was trained on:
        ``<bos>tags<bot>:``. Falls back to ``self.default_tags`` when tags is None."""
        if tags is None:
            tags = self.default_tags

        return self.tokenizer.bos_token + tags + '<bot>:'

    def generate_quote(self, tags: Union[None, str], max_new_tokens: int, do_sample: bool,
                       num_beams: int, top_k: int, top_p: float, temperature: float) -> str:
        """Generate a single quote conditioned on the given tags.

        Args:
            tags: Comma-separated topics, or None to use the defaults.
            max_new_tokens: Upper bound on generated tokens.
            do_sample: Whether to sample (vs. greedy/beam decode).
            num_beams: Beam count; 1 disables beam search.
            top_k: Top-k sampling cutoff.
            top_p: Nucleus-sampling cumulative-probability cutoff.
            temperature: Sampling temperature.

        Returns:
            The generated text (prompt included, as returned by the pipeline).
        """
        # Lazily build the pipeline so that callers who never invoked
        # load_generator() don't hit an AttributeError here.
        if self.quote_generator is None:
            self.load_generator()

        prompt = self.preprocess_tags(tags)
        output = self.quote_generator(
            prompt,
            max_new_tokens=max_new_tokens,
            num_beams=num_beams,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            do_sample=do_sample,
            # early_stopping is only meaningful for beam search.
            early_stopping=num_beams > 1,
        )

        return output[0]['generated_text']
|
|
|
|