Spaces:
Runtime error
Runtime error
fixes
Browse files- src/models/model.py +9 -9
src/models/model.py
CHANGED
|
@@ -7,7 +7,7 @@ from transformers import (
|
|
| 7 |
)
|
| 8 |
from torch.utils.data import Dataset, DataLoader
|
| 9 |
import pytorch_lightning as pl
|
| 10 |
-
from pytorch_lightning.loggers import MLFlowLogger
|
| 11 |
from pytorch_lightning import Trainer
|
| 12 |
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
|
| 13 |
from pytorch_lightning import LightningDataModule
|
|
@@ -328,12 +328,12 @@ class Summarization:
|
|
| 328 |
)
|
| 329 |
|
| 330 |
MLlogger = MLFlowLogger(experiment_name="Summarization",
|
| 331 |
-
|
| 332 |
-
save_dir=
|
| 333 |
|
| 334 |
# WandLogger = WandbLogger(project="summarization-dagshub")
|
| 335 |
|
| 336 |
-
#logger = DAGsHubLogger(metrics_path='reports/training_metrics.txt')
|
| 337 |
|
| 338 |
early_stop_callback = (
|
| 339 |
[
|
|
@@ -352,7 +352,7 @@ class Summarization:
|
|
| 352 |
gpus = -1 if use_gpu and torch.cuda.is_available() else 0
|
| 353 |
|
| 354 |
trainer = Trainer(
|
| 355 |
-
logger=
|
| 356 |
callbacks=early_stop_callback,
|
| 357 |
max_epochs=max_epochs,
|
| 358 |
gpus=gpus,
|
|
@@ -460,10 +460,10 @@ class Summarization:
|
|
| 460 |
num_return_sequences=num_return_sequences,
|
| 461 |
)
|
| 462 |
preds = self.tokenizer.decode(
|
| 463 |
-
|
| 464 |
-
|
| 465 |
-
|
| 466 |
-
|
| 467 |
return preds
|
| 468 |
|
| 469 |
def evaluate(
|
|
|
|
| 7 |
)
|
| 8 |
from torch.utils.data import Dataset, DataLoader
|
| 9 |
import pytorch_lightning as pl
|
| 10 |
+
from pytorch_lightning.loggers import MLFlowLogger
|
| 11 |
from pytorch_lightning import Trainer
|
| 12 |
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
|
| 13 |
from pytorch_lightning import LightningDataModule
|
|
|
|
| 328 |
)
|
| 329 |
|
| 330 |
MLlogger = MLFlowLogger(experiment_name="Summarization",
|
| 331 |
+
tracking_uri="https://dagshub.com/gagan3012/summarization.mlflow")
|
| 332 |
+
#save_dir="reports/training_metrics.txt"
|
| 333 |
|
| 334 |
# WandLogger = WandbLogger(project="summarization-dagshub")
|
| 335 |
|
| 336 |
+
# logger = DAGsHubLogger(metrics_path='reports/training_metrics.txt')
|
| 337 |
|
| 338 |
early_stop_callback = (
|
| 339 |
[
|
|
|
|
| 352 |
gpus = -1 if use_gpu and torch.cuda.is_available() else 0
|
| 353 |
|
| 354 |
trainer = Trainer(
|
| 355 |
+
logger=MLlogger,
|
| 356 |
callbacks=early_stop_callback,
|
| 357 |
max_epochs=max_epochs,
|
| 358 |
gpus=gpus,
|
|
|
|
| 460 |
num_return_sequences=num_return_sequences,
|
| 461 |
)
|
| 462 |
preds = self.tokenizer.decode(
|
| 463 |
+
generated_ids[0],
|
| 464 |
+
skip_special_tokens=skip_special_tokens,
|
| 465 |
+
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
| 466 |
+
)
|
| 467 |
return preds
|
| 468 |
|
| 469 |
def evaluate(
|