kadarakos committed on
Commit
814721c
·
1 Parent(s): 693dd2b

formatting

Browse files
Files changed (1) hide show
  1. src/mentioned/train.py +2 -12
src/mentioned/train.py CHANGED
@@ -1,5 +1,4 @@
1
  import wandb
2
- import torch
3
  from lightning.pytorch.loggers import WandbLogger
4
  from lightning.pytorch.callbacks.early_stopping import EarlyStopping
5
  from lightning.pytorch.callbacks import ModelCheckpoint
@@ -19,20 +18,15 @@ def train(
19
  val_interval: int = 1000,
20
  stop_criterion: str = "val_f1_mention",
21
  max_epochs: int | None = None,
22
- **kwargs,
23
  ):
24
  if max_epochs is None:
25
- max_epochs = 999
26
- # 1. Data and Model Architecture Setup
27
  data = DataRegistry.get(data_factory)()
28
  model = ModelRegistry.get(model_factory)(data, encoder_id)
29
-
30
  wandb_logger = WandbLogger(
31
  project=project_name,
32
  name=encoder_id,
33
  )
34
-
35
- # 2. Callbacks for Training
36
  best_checkpoint = ModelCheckpoint(
37
  monitor=stop_criterion,
38
  mode="max",
@@ -40,7 +34,6 @@ def train(
40
  filename=f"best-{stop_criterion}",
41
  verbose=True,
42
  )
43
-
44
  early_stopper = EarlyStopping(
45
  monitor=stop_criterion,
46
  min_delta=0.01,
@@ -48,14 +41,12 @@ def train(
48
  verbose=True,
49
  mode="max",
50
  )
51
-
52
- # 3. Training Execution
53
  trainer = Trainer(
54
  max_epochs=max_epochs, # Now configurable
55
  val_check_interval=val_interval,
56
  callbacks=[early_stopper, best_checkpoint],
57
  logger=wandb_logger,
58
- **kwargs,
59
  )
60
  print(f"Starting Trainer for {max_epochs} epochs.")
61
  trainer.fit(
@@ -81,7 +72,6 @@ def train(
81
  best_model.push_to_hub(repo_id, private=True)
82
  wandb.finish()
83
 
84
- # 6. Verification: Pull from Hub and Test
85
  print("Verifying Hub upload by pulling and re-evaluating...")
86
  remote_model = LitMentionDetector.from_pretrained(
87
  repo_id,
 
1
  import wandb
 
2
  from lightning.pytorch.loggers import WandbLogger
3
  from lightning.pytorch.callbacks.early_stopping import EarlyStopping
4
  from lightning.pytorch.callbacks import ModelCheckpoint
 
18
  val_interval: int = 1000,
19
  stop_criterion: str = "val_f1_mention",
20
  max_epochs: int | None = None,
 
21
  ):
22
  if max_epochs is None:
23
+ max_epochs = 1000
 
24
  data = DataRegistry.get(data_factory)()
25
  model = ModelRegistry.get(model_factory)(data, encoder_id)
 
26
  wandb_logger = WandbLogger(
27
  project=project_name,
28
  name=encoder_id,
29
  )
 
 
30
  best_checkpoint = ModelCheckpoint(
31
  monitor=stop_criterion,
32
  mode="max",
 
34
  filename=f"best-{stop_criterion}",
35
  verbose=True,
36
  )
 
37
  early_stopper = EarlyStopping(
38
  monitor=stop_criterion,
39
  min_delta=0.01,
 
41
  verbose=True,
42
  mode="max",
43
  )
 
 
44
  trainer = Trainer(
45
  max_epochs=max_epochs, # Now configurable
46
  val_check_interval=val_interval,
47
  callbacks=[early_stopper, best_checkpoint],
48
  logger=wandb_logger,
49
+ accelerator="auto",
50
  )
51
  print(f"Starting Trainer for {max_epochs} epochs.")
52
  trainer.fit(
 
72
  best_model.push_to_hub(repo_id, private=True)
73
  wandb.finish()
74
 
 
75
  print("Verifying Hub upload by pulling and re-evaluating...")
76
  remote_model = LitMentionDetector.from_pretrained(
77
  repo_id,