File size: 1,625 Bytes
8b8d155
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import random
import re
from typing import Optional

from transformers import pipeline, AutoTokenizer, AutoModelWithLMHead
from transformers.pipelines import TextGenerationPipeline

class QuoteGenerator:
    """Generate short quotes with the ``nandinib1999/quote-generator`` model.

    The tokenizer and model are downloaded/loaded in ``__init__``. The
    text-generation pipeline is built by :meth:`load_generator`. If that has not
    been called, :meth:`generate_quote` builds it on first use.
    """

    MODEL_NAME = "nandinib1999/quote-generator"

    def __init__(self):
        # Built by load_generator(); stays None until the pipeline is created.
        self.quote_generator: Optional[TextGenerationPipeline] = None
        self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL_NAME)
        # NOTE(review): AutoModelWithLMHead is deprecated in transformers;
        # AutoModelForCausalLM is the modern equivalent for this model type.
        self.model = AutoModelWithLMHead.from_pretrained(self.MODEL_NAME)
        # Prompts used when the caller supplies an empty or blank prompt.
        self.default_prompts = ['life is not fair', 'she is a warrior', 'friendship is like', 'you cannot find light']

    def load_generator(self) -> None:
        """Build the text-generation pipeline from the loaded model and tokenizer."""
        self.quote_generator = pipeline('text-generation', model=self.model, tokenizer=self.tokenizer)

    def clean_text(self, quote: str) -> str:
        """Tidy raw generated text into a presentable quote.

        Replaces a fixed set of special characters with spaces, collapses runs
        of whitespace, removes whitespace before ``. , ; ! ?``, trims the
        ends, and applies ``str.capitalize`` (first character upper-cased,
        the rest lower-cased).

        Args:
            quote: Raw text returned by the generation pipeline.

        Returns:
            The cleaned quote.
        """
        # Remove any special characters.
        quote = re.sub(r"[~`^*&%$@#|]", " ", quote)
        # Collapse runs of two or more whitespace characters into one space.
        quote = re.sub(r"\s{2,}", " ", quote)
        # Remove whitespace before punctuation, keeping the punctuation mark.
        quote = re.sub(r"\s+([.,;!?])", r"\1", quote)
        # Strip last so spaces introduced by the substitutions above are trimmed
        # and capitalize() sees the first real character.
        quote = quote.strip()
        # Sentence case.
        return quote.capitalize()

    def generate_quote(self, prompt_start: str = "", min_length=3, max_length=30, top_p=0.9) -> str:
        """Sample a quote that continues ``prompt_start``.

        Args:
            prompt_start: Opening words of the quote. If empty, blank or None,
                a random entry from ``self.default_prompts`` is used.
            min_length: Forwarded to the pipeline as ``min_length``.
            max_length: Forwarded to the pipeline as ``max_length``
                (transformers counts this in tokens, including the prompt).
            top_p: Nucleus-sampling threshold forwarded to the pipeline.

        Returns:
            The generated text (prompt included), passed through
            :meth:`clean_text`.
        """
        prompt = (prompt_start or "").strip()
        if not prompt:
            prompt = random.choice(self.default_prompts).strip()

        # Build the pipeline lazily so callers need not call load_generator() first.
        if getattr(self, "quote_generator", None) is None:
            self.load_generator()

        quote_gen = self.quote_generator(
            prompt,
            min_length=min_length,
            max_length=max_length,
            # Without do_sample=True generation is greedy and the
            # temperature/top_k/top_p settings below are ignored.
            do_sample=True,
            temperature=1.0,
            top_k=5,
            top_p=top_p,
        )
        quote = quote_gen[0]['generated_text']
        return self.clean_text(quote)