kadarakos commited on
Commit
f3f859d
·
1 Parent(s): 7956b7a

fix deserialization when using labeler

Browse files
Files changed (1) hide show
  1. src/mentioned/train.py +30 -13
src/mentioned/train.py CHANGED
@@ -1,5 +1,5 @@
1
  import wandb
2
-
3
  from lightning.pytorch.loggers import WandbLogger
4
  from lightning.pytorch.callbacks.early_stopping import EarlyStopping
5
  from lightning.pytorch.callbacks import ModelCheckpoint
@@ -18,14 +18,20 @@ def train(
18
  patience: int = 5,
19
  val_interval: int = 1000,
20
  stop_criterion: str = "val_f1_mention",
 
21
  ):
 
 
 
22
  data = DataRegistry.get(data_factory)()
23
  model = ModelRegistry.get(model_factory)(data, encoder_id)
 
24
  wandb_logger = WandbLogger(
25
  project=project_name,
26
  name=encoder_id,
27
  )
28
- # Save only the best model for the PoC purposes.
 
29
  best_checkpoint = ModelCheckpoint(
30
  monitor=stop_criterion,
31
  mode="max",
@@ -33,6 +39,7 @@ def train(
33
  filename=f"best-{stop_criterion}",
34
  verbose=True,
35
  )
 
36
  early_stopper = EarlyStopping(
37
  monitor=stop_criterion,
38
  min_delta=0.01,
@@ -40,13 +47,15 @@ def train(
40
  verbose=True,
41
  mode="max",
42
  )
 
 
43
  trainer = Trainer(
 
44
  val_check_interval=val_interval,
45
- check_val_every_n_epoch=None,
46
  callbacks=[early_stopper, best_checkpoint],
47
  logger=wandb_logger,
48
  )
49
- print("Starting Trainer.")
50
  trainer.fit(
51
  model=model,
52
  train_dataloaders=data.train_loader,
@@ -54,24 +63,32 @@ def train(
54
  )
55
  trainer.test(dataloaders=data.test_loader, ckpt_path="best", weights_only=False)
56
  print(f"Pushing best model to: {repo_id}")
57
- fresh_model = ModelRegistry.get(model_factory)(data, encoder_id)
 
 
 
58
  best_model = LitMentionDetector.load_from_checkpoint(
59
  trainer.checkpoint_callback.best_model_path,
60
- tokenizer=fresh_model.tokenizer,
61
- encoder=fresh_model.encoder,
62
- mention_detector=fresh_model.mention_detector,
 
 
63
  weights_only=False,
64
  )
65
  best_model.push_to_hub(repo_id, private=True)
66
  wandb.finish()
67
 
68
- print("Testing pulling and evaluating the model.")
69
- fresh_model = ModelRegistry.get(model_factory)(data, encoder_id)
70
  remote_model = LitMentionDetector.from_pretrained(
71
  repo_id,
72
- tokenizer=fresh_model.tokenizer,
73
- encoder=fresh_model.encoder,
74
- mention_detector=fresh_model.mention_detector,
 
 
75
  )
 
76
  verify_trainer = Trainer(accelerator="auto", logger=False)
77
  verify_trainer.test(model=remote_model, dataloaders=data.test_loader)
 
1
  import wandb
2
+ import torch
3
  from lightning.pytorch.loggers import WandbLogger
4
  from lightning.pytorch.callbacks.early_stopping import EarlyStopping
5
  from lightning.pytorch.callbacks import ModelCheckpoint
 
18
  patience: int = 5,
19
  val_interval: int = 1000,
20
  stop_criterion: str = "val_f1_mention",
21
+ max_epochs: int | None = None,
22
  ):
23
+ if max_epochs is None:
24
+ max_epochs = 999
25
+ # 1. Data and Model Architecture Setup
26
  data = DataRegistry.get(data_factory)()
27
  model = ModelRegistry.get(model_factory)(data, encoder_id)
28
+
29
  wandb_logger = WandbLogger(
30
  project=project_name,
31
  name=encoder_id,
32
  )
33
+
34
+ # 2. Callbacks for Training
35
  best_checkpoint = ModelCheckpoint(
36
  monitor=stop_criterion,
37
  mode="max",
 
39
  filename=f"best-{stop_criterion}",
40
  verbose=True,
41
  )
42
+
43
  early_stopper = EarlyStopping(
44
  monitor=stop_criterion,
45
  min_delta=0.01,
 
47
  verbose=True,
48
  mode="max",
49
  )
50
+
51
+ # 3. Training Execution
52
  trainer = Trainer(
53
+ max_epochs=max_epochs, # Now configurable
54
  val_check_interval=val_interval,
 
55
  callbacks=[early_stopper, best_checkpoint],
56
  logger=wandb_logger,
57
  )
58
+ print(f"Starting Trainer for {max_epochs} epochs.")
59
  trainer.fit(
60
  model=model,
61
  train_dataloaders=data.train_loader,
 
63
  )
64
  trainer.test(dataloaders=data.test_loader, ckpt_path="best", weights_only=False)
65
  print(f"Pushing best model to: {repo_id}")
66
+ fresh_bundle = ModelRegistry.get(model_factory)(data, encoder_id)
67
+ labeler = getattr(fresh_bundle, "mention_labeler", None)
68
+ l2id = getattr(fresh_bundle, "label2id", None)
69
+
70
  best_model = LitMentionDetector.load_from_checkpoint(
71
  trainer.checkpoint_callback.best_model_path,
72
+ tokenizer=fresh_bundle.tokenizer,
73
+ encoder=fresh_bundle.encoder,
74
+ mention_detector=fresh_bundle.mention_detector,
75
+ label2id=l2id,
76
+ mention_labeler=labeler,
77
  weights_only=False,
78
  )
79
  best_model.push_to_hub(repo_id, private=True)
80
  wandb.finish()
81
 
82
+ # 6. Verification: Pull from Hub and Test
83
+ print("Verifying Hub upload by pulling and re-evaluating...")
84
  remote_model = LitMentionDetector.from_pretrained(
85
  repo_id,
86
+ tokenizer=fresh_bundle.tokenizer,
87
+ encoder=fresh_bundle.encoder,
88
+ mention_detector=fresh_bundle.mention_detector,
89
+ label2id=l2id,
90
+ mention_labeler=labeler,
91
  )
92
+
93
  verify_trainer = Trainer(accelerator="auto", logger=False)
94
  verify_trainer.test(model=remote_model, dataloaders=data.test_loader)