Aditya12kumar commited on
Commit
8b8d155
·
verified ·
1 Parent(s): fb7fcbd

Create quote_generation.py

Browse files
Files changed (1) hide show
  1. quote_generation.py +33 -0
quote_generation.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import pipeline, AutoTokenizer, AutoModelWithLMHead
2
+ from transformers.pipelines import TextGenerationPipeline
3
+
4
+ class QuoteGenerator:
5
+ def __init__(self):
6
+ self.quote_generator: TextGenerationPipeline
7
+ self.tokenizer = AutoTokenizer.from_pretrained("nandinib1999/quote-generator")
8
+ self.model = AutoModelWithLMHead.from_pretrained("nandinib1999/quote-generator")
9
+ self.default_prompts = ['life is not fair', 'she is a warrior', 'friendship is like', 'you cannot find light']
10
+
11
+ def load_generator(self) -> None:
12
+ self.quote_generator = pipeline('text-generation', model = self.model, tokenizer = self.tokenizer)
13
+ def clean_text(self, quote):
14
+ quote = quote.strip()
15
+ # Remove any special characters
16
+ quote = re.sub('[~`^*&%$@#|]', ' ', quote)
17
+ # Remove extra spaces
18
+ quote = re.sub("\s\s+", " ", quote)
19
+ # Remove space before punctuation
20
+ quote = re.sub('\s[.,;!?]', '\1', quote)
21
+ # Sentence case
22
+ quote = quote.capitalize()
23
+ return quote
24
+
25
+ def generate_quote(self, prompt_start: str, min_length=3, max_length=30, top_p=0.9):
26
+ if len(prompt_start.strip()) == 0:
27
+ rand_index = random.randint(0, len(self.default_prompts)-1)
28
+ prompt_start = self.default_prompts[rand_index]
29
+
30
+ quote_gen = self.quote_generator(prompt_start.strip(), min_length = min_length, max_length = max_length, temperature = 1.0, top_k = 5, top_p = top_p)
31
+ quote = quote_gen[0]['generated_text']
32
+ quote_clean = self.clean_text(quote)
33
+ return quote_clean