File size: 2,922 Bytes
ba006b9
 
 
 
 
 
0ba7b45
643063e
ba006b9
 
ca23a08
 
0ba7b45
ca23a08
 
0ba7b45
ca23a08
 
7956b7a
f3f859d
ca23a08
f3f859d
814721c
0ba7b45
 
ba006b9
ca23a08
 
ba006b9
 
7956b7a
ba006b9
 
7956b7a
ba006b9
 
 
7956b7a
ba006b9
ca23a08
ba006b9
 
 
 
f3f859d
ca23a08
ba006b9
 
814721c
ba006b9
f3f859d
ba006b9
 
0ba7b45
 
ba006b9
0ba7b45
ca23a08
f3f859d
 
 
 
ba006b9
 
f3f859d
 
 
 
 
ba006b9
 
ca23a08
ba006b9
 
f3f859d
ba006b9
 
f3f859d
 
 
 
 
ba006b9
f3f859d
ba006b9
0ba7b45
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import wandb
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning import Trainer

from mentioned.model import LitMentionDetector, ModelRegistry
from mentioned.data import DataRegistry


def train(
    model_factory: str = "model_v1",
    data_factory: str = "litbank_mentions",
    repo_id: str = "kadarakos/mention-detector-poc-dry-run",
    project_name: str = "mention-detector-poc",
    encoder_id: str = "distilroberta-base",
    patience: int = 5,
    val_interval: int = 1000,
    stop_criterion: str = "val_f1_mention",
    max_epochs: int | None = None,
):
    """Train a mention detector, push the best checkpoint to the Hub, and verify it.

    Pipeline: build data/model from the registries, fit with early stopping on
    *stop_criterion*, evaluate the best checkpoint on the test set, push the
    best model to the Hugging Face Hub, then pull it back and re-evaluate to
    confirm the upload round-trips.

    Args:
        model_factory: Key into ``ModelRegistry`` selecting the model bundle.
        data_factory: Key into ``DataRegistry`` selecting the data module.
        repo_id: Hub repository the best model is pushed to (private).
        project_name: Weights & Biases project name.
        encoder_id: HF encoder checkpoint id; also used as the W&B run name.
        patience: Early-stopping patience, counted in validation checks.
        val_interval: Run validation every this many training batches.
        stop_criterion: Metric (maximized) monitored for checkpointing and
            early stopping.
        max_epochs: Upper bound on epochs; ``None`` means 1000, i.e. train
            until early stopping fires.
    """
    if max_epochs is None:
        # High cap: early stopping is the expected terminator, not the epoch limit.
        max_epochs = 1000
    data = DataRegistry.get(data_factory)()
    model = ModelRegistry.get(model_factory)(data, encoder_id)

    trainer = _build_trainer(
        project_name, encoder_id, stop_criterion, patience, val_interval, max_epochs
    )
    print(f"Starting Trainer for {max_epochs} epochs.")
    trainer.fit(
        model=model,
        train_dataloaders=data.train_loader,
        val_dataloaders=data.val_loader,
    )
    # Evaluate the best (not the last) checkpoint on the held-out test set.
    trainer.test(dataloaders=data.test_loader, ckpt_path="best", weights_only=False)

    print(f"Pushing best model to: {repo_id}")
    best_model, fresh_bundle = _load_best_checkpoint(
        trainer, model_factory, data, encoder_id
    )
    best_model.push_to_hub(repo_id, private=True)
    wandb.finish()

    _verify_hub_upload(repo_id, fresh_bundle, data)


def _build_trainer(
    project_name: str,
    encoder_id: str,
    stop_criterion: str,
    patience: int,
    val_interval: int,
    max_epochs: int,
) -> Trainer:
    """Assemble a Trainer with W&B logging, best-model checkpointing, and
    early stopping, all keyed on *stop_criterion* (maximized)."""
    wandb_logger = WandbLogger(
        project=project_name,
        name=encoder_id,
    )
    best_checkpoint = ModelCheckpoint(
        monitor=stop_criterion,
        mode="max",
        save_top_k=1,
        filename=f"best-{stop_criterion}",
        verbose=True,
    )
    early_stopper = EarlyStopping(
        monitor=stop_criterion,
        min_delta=0.01,
        patience=patience,
        verbose=True,
        mode="max",
    )
    return Trainer(
        max_epochs=max_epochs,
        val_check_interval=val_interval,
        callbacks=[early_stopper, best_checkpoint],
        logger=wandb_logger,
        accelerator="auto",
    )


def _load_best_checkpoint(trainer, model_factory: str, data, encoder_id: str):
    """Reload the best checkpoint into a freshly constructed model bundle.

    Returns:
        The loaded ``LitMentionDetector`` and the fresh bundle (the bundle is
        reused when verifying the Hub upload).
    """
    fresh_bundle = ModelRegistry.get(model_factory)(data, encoder_id)
    # Optional components: a bundle may not define these, so default to None.
    labeler = getattr(fresh_bundle, "mention_labeler", None)
    l2id = getattr(fresh_bundle, "label2id", None)

    best_model = LitMentionDetector.load_from_checkpoint(
        trainer.checkpoint_callback.best_model_path,
        tokenizer=fresh_bundle.tokenizer,
        encoder=fresh_bundle.encoder,
        mention_detector=fresh_bundle.mention_detector,
        label2id=l2id,
        mention_labeler=labeler,
        weights_only=False,
    )
    return best_model, fresh_bundle


def _verify_hub_upload(repo_id: str, fresh_bundle, data) -> None:
    """Pull the just-pushed model back from the Hub and re-run the test set."""
    print("Verifying Hub upload by pulling and re-evaluating...")
    labeler = getattr(fresh_bundle, "mention_labeler", None)
    l2id = getattr(fresh_bundle, "label2id", None)
    remote_model = LitMentionDetector.from_pretrained(
        repo_id,
        tokenizer=fresh_bundle.tokenizer,
        encoder=fresh_bundle.encoder,
        mention_detector=fresh_bundle.mention_detector,
        label2id=l2id,
        mention_labeler=labeler,
    )

    # Fresh Trainer with logging disabled: this is a sanity check, not a run.
    verify_trainer = Trainer(accelerator="auto", logger=False)
    verify_trainer.test(model=remote_model, dataloaders=data.test_loader)