Spaces:
Runtime error
Runtime error
fixes
Browse files- src/models/model.py +7 -2
src/models/model.py
CHANGED
|
@@ -13,6 +13,8 @@ from pytorch_lightning.callbacks.early_stopping import EarlyStopping
|
|
| 13 |
from pytorch_lightning import LightningDataModule
|
| 14 |
from pytorch_lightning import LightningModule
|
| 15 |
from datasets import load_metric
|
|
|
|
|
|
|
| 16 |
|
| 17 |
# from dagshub.pytorch_lightning import DAGsHubLogger
|
| 18 |
|
|
@@ -331,7 +333,10 @@ class Summarization:
|
|
| 331 |
|
| 332 |
WandLogger = WandbLogger(project="summarization-dagshub")
|
| 333 |
|
| 334 |
-
# logger = DAGsHubLogger(metrics_path='reports/
|
|
|
|
|
|
|
|
|
|
| 335 |
|
| 336 |
early_stop_callback = (
|
| 337 |
[
|
|
@@ -360,7 +365,7 @@ class Summarization:
|
|
| 360 |
trainer.fit(self.T5Model, self.data_module)
|
| 361 |
|
| 362 |
def load_model(
|
| 363 |
-
self, model_type: str = 't5', model_dir: str = "
|
| 364 |
):
|
| 365 |
"""
|
| 366 |
loads a checkpoint for inferencing/prediction
|
|
|
|
| 13 |
from pytorch_lightning import LightningDataModule
|
| 14 |
from pytorch_lightning import LightningModule
|
| 15 |
from datasets import load_metric
|
| 16 |
+
from tqdm.auto import tqdm
|
| 17 |
+
|
| 18 |
|
| 19 |
# from dagshub.pytorch_lightning import DAGsHubLogger
|
| 20 |
|
|
|
|
| 333 |
|
| 334 |
WandLogger = WandbLogger(project="summarization-dagshub")
|
| 335 |
|
| 336 |
+
# logger = DAGsHubLogger(metrics_path='reports/training_metrics.txt')
|
| 337 |
+
|
| 338 |
+
df = pd.read_json(r'wandb/latest-run/files/wandb-summary.json')
|
| 339 |
+
df.to_csv(r'reports/training_metrics.txt', index=False)
|
| 340 |
|
| 341 |
early_stop_callback = (
|
| 342 |
[
|
|
|
|
| 365 |
trainer.fit(self.T5Model, self.data_module)
|
| 366 |
|
| 367 |
def load_model(
|
| 368 |
+
self, model_type: str = 't5', model_dir: str = "models", use_gpu: bool = False
|
| 369 |
):
|
| 370 |
"""
|
| 371 |
loads a checkpoint for inferencing/prediction
|