kadarakos commited on
Commit
7956b7a
·
1 Parent(s): 436d409

add configurable stop criterion

Browse files
Files changed (1) hide show
  1. src/mentioned/train.py +4 -3
src/mentioned/train.py CHANGED
@@ -17,6 +17,7 @@ def train(
17
  encoder_id: str = "distilroberta-base",
18
  patience: int = 5,
19
  val_interval: int = 1000,
 
20
  ):
21
  data = DataRegistry.get(data_factory)()
22
  model = ModelRegistry.get(model_factory)(data, encoder_id)
@@ -26,14 +27,14 @@ def train(
26
  )
27
  # Save only the best model for the PoC purposes.
28
  best_checkpoint = ModelCheckpoint(
29
- monitor="val_f1_mention",
30
  mode="max",
31
  save_top_k=1,
32
- filename="best-mention-f1",
33
  verbose=True,
34
  )
35
  early_stopper = EarlyStopping(
36
- monitor="val_f1_mention",
37
  min_delta=0.01,
38
  patience=patience,
39
  verbose=True,
 
17
  encoder_id: str = "distilroberta-base",
18
  patience: int = 5,
19
  val_interval: int = 1000,
20
+ stop_criterion: str = "val_f1_mention",
21
  ):
22
  data = DataRegistry.get(data_factory)()
23
  model = ModelRegistry.get(model_factory)(data, encoder_id)
 
27
  )
28
  # Save only the best model for the PoC purposes.
29
  best_checkpoint = ModelCheckpoint(
30
+ monitor=stop_criterion,
31
  mode="max",
32
  save_top_k=1,
33
+ filename=f"best-{stop_criterion}",
34
  verbose=True,
35
  )
36
  early_stopper = EarlyStopping(
37
+ monitor=stop_criterion,
38
  min_delta=0.01,
39
  patience=patience,
40
  verbose=True,