Shymaa2611 commited on
Commit
8e26987
·
1 Parent(s): bc85cac
Files changed (1) hide show
  1. inference.py +7 -4
inference.py CHANGED
@@ -3,16 +3,19 @@ from dataset import clean
3
  import re
4
  import gdown
5
  import os
 
 
6
 
7
  def load_tokenizer_model():
8
  folder_url = "https://drive.google.com/drive/folders/1DDJ9t-HfMrf6OLYim5bVrP20QgyOZahc?usp=drive_link"
9
  gdown.download_folder(folder_url, output="ChatbotCheckpoint", quiet=False)
10
- model_name="ChatbotCheckpoint"
11
  model = GPT2LMHeadModel.from_pretrained(model_name)
12
  model.eval()
13
- tokenizer=GPT2Tokenizer.from_pretrained(model_name)
14
- tokenizer.pad_token=tokenizer.eos_token
15
- return tokenizer,model
 
16
 
17
  def generate_answer(query):
18
  tokenizer,model=load_tokenizer_model()
 
3
  import re
4
  import gdown
5
  import os
6
+ cache_dir = os.path.expanduser("~/.cache/gdown")
7
+ os.makedirs(cache_dir, exist_ok=True)
8
 
9
  def load_tokenizer_model():
10
  folder_url = "https://drive.google.com/drive/folders/1DDJ9t-HfMrf6OLYim5bVrP20QgyOZahc?usp=drive_link"
11
  gdown.download_folder(folder_url, output="ChatbotCheckpoint", quiet=False)
12
+ model_name = "ChatbotCheckpoint"
13
  model = GPT2LMHeadModel.from_pretrained(model_name)
14
  model.eval()
15
+ tokenizer = GPT2Tokenizer.from_pretrained(model_name)
16
+ tokenizer.pad_token = tokenizer.eos_token
17
+ return tokenizer, model
18
+
19
 
20
  def generate_answer(query):
21
  tokenizer,model=load_tokenizer_model()