Spaces:
Runtime error
Runtime error
Fixes
Browse files- src/models/model.py +3 -3
src/models/model.py
CHANGED
|
@@ -6,7 +6,7 @@ from dagshub.pytorch_lightning import DAGsHubLogger
|
|
| 6 |
from transformers import (
|
| 7 |
AdamW,
|
| 8 |
T5ForConditionalGeneration,
|
| 9 |
-
T5TokenizerFast as T5Tokenizer, MT5Tokenizer, MT5ForConditionalGeneration,ByT5Tokenizer,
|
| 10 |
)
|
| 11 |
from torch.utils.data import Dataset, DataLoader
|
| 12 |
import pytorch_lightning as pl
|
|
@@ -248,7 +248,7 @@ class Summarization:
|
|
| 248 |
""" initiates Summarization class """
|
| 249 |
pass
|
| 250 |
|
| 251 |
-
def from_pretrained(self,
|
| 252 |
"""
|
| 253 |
loads T5/MT5 Model model for training/finetuning
|
| 254 |
Args:
|
|
@@ -345,7 +345,7 @@ class Summarization:
|
|
| 345 |
trainer.fit(self.T5Model, self.data_module)
|
| 346 |
|
| 347 |
def load_model(
|
| 348 |
-
self, model_type:str ='t5'
|
| 349 |
):
|
| 350 |
"""
|
| 351 |
loads a checkpoint for inferencing/prediction
|
|
|
|
| 6 |
from transformers import (
|
| 7 |
AdamW,
|
| 8 |
T5ForConditionalGeneration,
|
| 9 |
+
T5TokenizerFast as T5Tokenizer, MT5Tokenizer, MT5ForConditionalGeneration, ByT5Tokenizer,
|
| 10 |
)
|
| 11 |
from torch.utils.data import Dataset, DataLoader
|
| 12 |
import pytorch_lightning as pl
|
|
|
|
| 248 |
""" initiates Summarization class """
|
| 249 |
pass
|
| 250 |
|
| 251 |
+
def from_pretrained(self, model_type="t5", model_name="t5-base") -> None:
|
| 252 |
"""
|
| 253 |
loads T5/MT5 Model model for training/finetuning
|
| 254 |
Args:
|
|
|
|
| 345 |
trainer.fit(self.T5Model, self.data_module)
|
| 346 |
|
| 347 |
def load_model(
|
| 348 |
+
self, model_type: str = 't5', model_dir: str = "../../models", use_gpu: bool = False
|
| 349 |
):
|
| 350 |
"""
|
| 351 |
loads a checkpoint for inferencing/prediction
|