Spaces:
Build error
Build error
t5 led exchange
Browse files- src/Surveyor.py +13 -5
src/Surveyor.py
CHANGED
|
@@ -131,8 +131,12 @@ class Surveyor:
|
|
| 131 |
#self.summ_tokenizer.save_pretrained(models_dir + "/summ_tokenizer")
|
| 132 |
self.model = Summarizer(custom_model=self.summ_model, custom_tokenizer=self.summ_tokenizer)
|
| 133 |
|
| 134 |
-
|
| 135 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
self.ledmodel.eval()
|
| 137 |
if not no_save_models:
|
| 138 |
self.ledmodel.save_pretrained(models_dir + "/ledmodel")
|
|
@@ -144,7 +148,7 @@ class Surveyor:
|
|
| 144 |
self.embedder.save(models_dir + "/embedder")
|
| 145 |
else:
|
| 146 |
print("\nInitializing from previously saved models at" + models_dir)
|
| 147 |
-
self.title_tokenizer = AutoTokenizer.from_pretrained(title_model_name)
|
| 148 |
self.title_model = AutoModelForSeq2SeqLM.from_pretrained(models_dir + "/title_model").to(self.torch_device)
|
| 149 |
self.title_model.eval()
|
| 150 |
|
|
@@ -157,8 +161,12 @@ class Surveyor:
|
|
| 157 |
self.summ_model.eval()
|
| 158 |
self.model = Summarizer(custom_model=self.summ_model, custom_tokenizer=self.summ_tokenizer)
|
| 159 |
|
| 160 |
-
|
| 161 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 162 |
self.ledmodel.eval()
|
| 163 |
|
| 164 |
self.embedder = SentenceTransformer(models_dir + "/embedder")
|
|
|
|
| 131 |
#self.summ_tokenizer.save_pretrained(models_dir + "/summ_tokenizer")
|
| 132 |
self.model = Summarizer(custom_model=self.summ_model, custom_tokenizer=self.summ_tokenizer)
|
| 133 |
|
| 134 |
+
if 't5' not in ledmodel_name:
|
| 135 |
+
self.ledtokenizer = LEDTokenizer.from_pretrained(ledmodel_name)
|
| 136 |
+
self.ledmodel = LEDForConditionalGeneration.from_pretrained(ledmodel_name).to(self.torch_device)
|
| 137 |
+
else:
|
| 138 |
+
self.ledtokenizer = T5Tokenizer.from_pretrained(ledmodel_name)
|
| 139 |
+
self.ledmodel = T5ForConditionalGeneration.from_pretrained(ledmodel_name).to(self.torch_device)
|
| 140 |
self.ledmodel.eval()
|
| 141 |
if not no_save_models:
|
| 142 |
self.ledmodel.save_pretrained(models_dir + "/ledmodel")
|
|
|
|
| 148 |
self.embedder.save(models_dir + "/embedder")
|
| 149 |
else:
|
| 150 |
print("\nInitializing from previously saved models at" + models_dir)
|
| 151 |
+
self.title_tokenizer = AutoTokenizer.from_pretrained(title_model_name)
|
| 152 |
self.title_model = AutoModelForSeq2SeqLM.from_pretrained(models_dir + "/title_model").to(self.torch_device)
|
| 153 |
self.title_model.eval()
|
| 154 |
|
|
|
|
| 161 |
self.summ_model.eval()
|
| 162 |
self.model = Summarizer(custom_model=self.summ_model, custom_tokenizer=self.summ_tokenizer)
|
| 163 |
|
| 164 |
+
if 't5' not in ledmodel_name:
|
| 165 |
+
self.ledtokenizer = LEDTokenizer.from_pretrained(ledmodel_name)
|
| 166 |
+
self.ledmodel = LEDForConditionalGeneration.from_pretrained(models_dir + "/ledmodel").to(self.torch_device)
|
| 167 |
+
else:
|
| 168 |
+
self.ledtokenizer = T5Tokenizer.from_pretrained(ledmodel_name)
|
| 169 |
+
self.ledmodel = T5ForConditionalGeneration.from_pretrained(models_dir + "/ledmodel").to(self.torch_device)
|
| 170 |
self.ledmodel.eval()
|
| 171 |
|
| 172 |
self.embedder = SentenceTransformer(models_dir + "/embedder")
|