import random
import re

from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, AutoModelWithLMHead
from transformers.pipelines import TextGenerationPipeline

class QuoteGenerator:
    """Generate short quotes with a fine-tuned causal language model.

    The tokenizer and model are loaded with ``from_pretrained`` in
    ``__init__``. The text-generation pipeline is built by
    ``load_generator``, which ``generate_quote`` calls automatically if the
    pipeline has not been built yet.
    """

    # Hugging Face Hub id of the model used when no other is given.
    DEFAULT_MODEL = "nandinib1999/quote-generator"

    # Characters treated as noise in generated text; each is replaced by a space.
    _SPECIAL_CHARS = re.compile(r"[~`^*&%$@#|]")
    # Any run of whitespace (spaces, tabs, newlines), collapsed to one space.
    _WHITESPACE_RUN = re.compile(r"\s+")
    # Whitespace directly before punctuation; group 1 keeps the punctuation mark.
    _SPACE_BEFORE_PUNCT = re.compile(r"\s+([.,;!?])")

    def __init__(self, model_name: str = DEFAULT_MODEL):
        """Load the tokenizer and model for *model_name*.

        Args:
            model_name: Hub id or local path passed to ``from_pretrained``.
                Defaults to ``DEFAULT_MODEL``.
        """
        # None until load_generator() builds it, so readiness can be checked
        # (the original bare annotation left the attribute undefined).
        self.quote_generator: "TextGenerationPipeline | None" = None
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # The text-generation pipeline requires a causal LM; AutoModelWithLMHead
        # is deprecated in favor of AutoModelForCausalLM.
        self.model = AutoModelForCausalLM.from_pretrained(model_name)
        # Prompts picked at random when generate_quote() gets a blank prompt.
        self.default_prompts = ['life is not fair', 'she is a warrior', 'friendship is like', 'you cannot find light']

    def load_generator(self) -> None:
        """Build the ``text-generation`` pipeline from the loaded model and tokenizer."""
        self.quote_generator = pipeline('text-generation', model=self.model, tokenizer=self.tokenizer)

    def clean_text(self, quote):
        """Normalize raw generated text into a presentable quote.

        Steps: replace noise characters with spaces, collapse whitespace runs,
        drop whitespace before punctuation, trim the ends, and upper-case the
        first character (the rest of the text is left untouched so "I",
        acronyms and names keep their case).

        Args:
            quote: Raw text, e.g. a pipeline's ``generated_text``.

        Returns:
            The cleaned string ("" for empty input).
        """
        quote = self._SPECIAL_CHARS.sub(" ", quote)
        quote = self._WHITESPACE_RUN.sub(" ", quote)
        quote = self._SPACE_BEFORE_PUNCT.sub(r"\1", quote)
        # Strip after the substitutions so replaced leading/trailing noise
        # characters do not leave stray spaces behind.
        quote = quote.strip()
        return quote[:1].upper() + quote[1:]

    def generate_quote(self, prompt_start: str, min_length=3, max_length=30, top_p=0.9,
                       temperature=1.0, top_k=5):
        """Generate one cleaned quote continuing *prompt_start*.

        Args:
            prompt_start: Opening words of the quote. If empty/blank (or None),
                a random entry of ``self.default_prompts`` is used instead.
            min_length: Passed through to the pipeline as ``min_length``.
            max_length: Passed through to the pipeline as ``max_length``.
            top_p: Nucleus-sampling probability mass.
            temperature: Sampling temperature (default 1.0, as before).
            top_k: Number of highest-probability tokens kept when sampling
                (default 5, as before).

        Returns:
            The first generated sequence, cleaned by ``clean_text``.
        """
        prompt = (prompt_start or "").strip()
        if not prompt:
            prompt = random.choice(self.default_prompts).strip()

        # Build the pipeline on first use instead of failing with AttributeError.
        if getattr(self, "quote_generator", None) is None:
            self.load_generator()

        outputs = self.quote_generator(
            prompt,
            min_length=min_length,
            max_length=max_length,
            # temperature/top_k/top_p only take effect in sampling mode.
            do_sample=True,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )
        return self.clean_text(outputs[0]['generated_text'])