formatting
Browse files
- src/mentioned/train.py +2 -12
src/mentioned/train.py
CHANGED
|
@@ -1,5 +1,4 @@
|
|
| 1 |
import wandb
|
| 2 |
-
import torch
|
| 3 |
from lightning.pytorch.loggers import WandbLogger
|
| 4 |
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
|
| 5 |
from lightning.pytorch.callbacks import ModelCheckpoint
|
|
@@ -19,20 +18,15 @@ def train(
|
|
| 19 |
val_interval: int = 1000,
|
| 20 |
stop_criterion: str = "val_f1_mention",
|
| 21 |
max_epochs: int | None = None,
|
| 22 |
-
**kwargs,
|
| 23 |
):
|
| 24 |
if max_epochs is None:
|
| 25 |
-
max_epochs =
|
| 26 |
-
# 1. Data and Model Architecture Setup
|
| 27 |
data = DataRegistry.get(data_factory)()
|
| 28 |
model = ModelRegistry.get(model_factory)(data, encoder_id)
|
| 29 |
-
|
| 30 |
wandb_logger = WandbLogger(
|
| 31 |
project=project_name,
|
| 32 |
name=encoder_id,
|
| 33 |
)
|
| 34 |
-
|
| 35 |
-
# 2. Callbacks for Training
|
| 36 |
best_checkpoint = ModelCheckpoint(
|
| 37 |
monitor=stop_criterion,
|
| 38 |
mode="max",
|
|
@@ -40,7 +34,6 @@ def train(
|
|
| 40 |
filename=f"best-{stop_criterion}",
|
| 41 |
verbose=True,
|
| 42 |
)
|
| 43 |
-
|
| 44 |
early_stopper = EarlyStopping(
|
| 45 |
monitor=stop_criterion,
|
| 46 |
min_delta=0.01,
|
|
@@ -48,14 +41,12 @@ def train(
|
|
| 48 |
verbose=True,
|
| 49 |
mode="max",
|
| 50 |
)
|
| 51 |
-
|
| 52 |
-
# 3. Training Execution
|
| 53 |
trainer = Trainer(
|
| 54 |
max_epochs=max_epochs, # Now configurable
|
| 55 |
val_check_interval=val_interval,
|
| 56 |
callbacks=[early_stopper, best_checkpoint],
|
| 57 |
logger=wandb_logger,
|
| 58 |
-
|
| 59 |
)
|
| 60 |
print(f"Starting Trainer for {max_epochs} epochs.")
|
| 61 |
trainer.fit(
|
|
@@ -81,7 +72,6 @@ def train(
|
|
| 81 |
best_model.push_to_hub(repo_id, private=True)
|
| 82 |
wandb.finish()
|
| 83 |
|
| 84 |
-
# 6. Verification: Pull from Hub and Test
|
| 85 |
print("Verifying Hub upload by pulling and re-evaluating...")
|
| 86 |
remote_model = LitMentionDetector.from_pretrained(
|
| 87 |
repo_id,
|
|
|
|
| 1 |
import wandb
|
|
|
|
| 2 |
from lightning.pytorch.loggers import WandbLogger
|
| 3 |
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
|
| 4 |
from lightning.pytorch.callbacks import ModelCheckpoint
|
|
|
|
| 18 |
val_interval: int = 1000,
|
| 19 |
stop_criterion: str = "val_f1_mention",
|
| 20 |
max_epochs: int | None = None,
|
|
|
|
| 21 |
):
|
| 22 |
if max_epochs is None:
|
| 23 |
+
max_epochs = 1000
|
|
|
|
| 24 |
data = DataRegistry.get(data_factory)()
|
| 25 |
model = ModelRegistry.get(model_factory)(data, encoder_id)
|
|
|
|
| 26 |
wandb_logger = WandbLogger(
|
| 27 |
project=project_name,
|
| 28 |
name=encoder_id,
|
| 29 |
)
|
|
|
|
|
|
|
| 30 |
best_checkpoint = ModelCheckpoint(
|
| 31 |
monitor=stop_criterion,
|
| 32 |
mode="max",
|
|
|
|
| 34 |
filename=f"best-{stop_criterion}",
|
| 35 |
verbose=True,
|
| 36 |
)
|
|
|
|
| 37 |
early_stopper = EarlyStopping(
|
| 38 |
monitor=stop_criterion,
|
| 39 |
min_delta=0.01,
|
|
|
|
| 41 |
verbose=True,
|
| 42 |
mode="max",
|
| 43 |
)
|
|
|
|
|
|
|
| 44 |
trainer = Trainer(
|
| 45 |
max_epochs=max_epochs, # Now configurable
|
| 46 |
val_check_interval=val_interval,
|
| 47 |
callbacks=[early_stopper, best_checkpoint],
|
| 48 |
logger=wandb_logger,
|
| 49 |
+
accelerator="auto",
|
| 50 |
)
|
| 51 |
print(f"Starting Trainer for {max_epochs} epochs.")
|
| 52 |
trainer.fit(
|
|
|
|
| 72 |
best_model.push_to_hub(repo_id, private=True)
|
| 73 |
wandb.finish()
|
| 74 |
|
|
|
|
| 75 |
print("Verifying Hub upload by pulling and re-evaluating...")
|
| 76 |
remote_model = LitMentionDetector.from_pretrained(
|
| 77 |
repo_id,
|