Spaces:
Runtime error
Runtime error
Fixes
Browse files - src/models/model.py +16 -5
src/models/model.py
CHANGED
|
@@ -161,8 +161,6 @@ class LightningModel(LightningModule):
|
|
| 161 |
self.model = model
|
| 162 |
self.tokenizer = tokenizer
|
| 163 |
self.output = output
|
| 164 |
-
# self.val_acc = Accuracy()
|
| 165 |
-
# self.train_acc = Accuracy()
|
| 166 |
|
| 167 |
def forward(self, input_ids, attention_mask, decoder_attention_mask, labels=None):
|
| 168 |
""" forward step """
|
|
@@ -347,7 +345,7 @@ class Summarization:
|
|
| 347 |
trainer.fit(self.T5Model, self.data_module)
|
| 348 |
|
| 349 |
def load_model(
|
| 350 |
-
self, model_dir: str = "../../models", use_gpu: bool = False
|
| 351 |
):
|
| 352 |
"""
|
| 353 |
loads a checkpoint for inferencing/prediction
|
|
@@ -356,8 +354,21 @@ class Summarization:
|
|
| 356 |
model_dir (str, optional): path to model directory. Defaults to "../../models".
|
| 357 |
use_gpu (bool, optional): if True, model uses gpu for inferencing/prediction. Defaults to False.
|
| 358 |
"""
|
| 359 |
-
|
| 360 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 361 |
|
| 362 |
if use_gpu:
|
| 363 |
if torch.cuda.is_available():
|
|
|
|
| 161 |
self.model = model
|
| 162 |
self.tokenizer = tokenizer
|
| 163 |
self.output = output
|
|
|
|
|
|
|
| 164 |
|
| 165 |
def forward(self, input_ids, attention_mask, decoder_attention_mask, labels=None):
|
| 166 |
""" forward step """
|
|
|
|
| 345 |
trainer.fit(self.T5Model, self.data_module)
|
| 346 |
|
| 347 |
def load_model(
|
| 348 |
+
self, model_type:str ='t5' , model_dir: str = "../../models", use_gpu: bool = False
|
| 349 |
):
|
| 350 |
"""
|
| 351 |
loads a checkpoint for inferencing/prediction
|
|
|
|
| 354 |
model_dir (str, optional): path to model directory. Defaults to "../../models".
|
| 355 |
use_gpu (bool, optional): if True, model uses gpu for inferencing/prediction. Defaults to False.
|
| 356 |
"""
|
| 357 |
+
if model_type == "t5":
|
| 358 |
+
self.tokenizer = T5Tokenizer.from_pretrained(f"{model_dir}")
|
| 359 |
+
self.model = T5ForConditionalGeneration.from_pretrained(
|
| 360 |
+
f"{model_dir}", return_dict=True
|
| 361 |
+
)
|
| 362 |
+
elif model_type == "mt5":
|
| 363 |
+
self.tokenizer = MT5Tokenizer.from_pretrained(f"{model_dir}")
|
| 364 |
+
self.model = MT5ForConditionalGeneration.from_pretrained(
|
| 365 |
+
f"{model_dir}", return_dict=True
|
| 366 |
+
)
|
| 367 |
+
elif model_type == "byt5":
|
| 368 |
+
self.tokenizer = ByT5Tokenizer.from_pretrained(f"{model_dir}")
|
| 369 |
+
self.model = T5ForConditionalGeneration.from_pretrained(
|
| 370 |
+
f"{model_dir}", return_dict=True
|
| 371 |
+
)
|
| 372 |
|
| 373 |
if use_gpu:
|
| 374 |
if torch.cuda.is_available():
|