import wandb
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning import Trainer

from mentioned.model import LitMentionDetector, ModelRegistry
from mentioned.data import DataRegistry


def train(
    model_factory: str = "model_v1",
    data_factory: str = "litbank_mentions",
    repo_id: str = "kadarakos/mention-detector-poc-dry-run",
    project_name: str = "mention-detector-poc",
    encoder_id: str = "distilroberta-base",
    patience: int = 5,
    val_interval: int = 1000,
    stop_criterion: str = "val_f1_mention",
    max_epochs: int | None = None,
):
    """Train a mention detector, push the best checkpoint to the Hub, and verify the upload."""
    if max_epochs is None:
        max_epochs = 1000
    # Build the data module and model from their registries.
    data = DataRegistry.get(data_factory)()
    model = ModelRegistry.get(model_factory)(data, encoder_id)

    # Log metrics to Weights & Biases.
    wandb_logger = WandbLogger(
        project=project_name,
        name=encoder_id,
    )
    # Keep only the single best checkpoint according to the monitored metric.
    best_checkpoint = ModelCheckpoint(
        monitor=stop_criterion,
        mode="max",
        save_top_k=1,
        filename=f"best-{stop_criterion}",
        verbose=True,
    )
    # Stop training once the metric has not improved for `patience` validation checks.
    early_stopper = EarlyStopping(
        monitor=stop_criterion,
        min_delta=0.01,
        patience=patience,
        verbose=True,
        mode="max",
    )
    trainer = Trainer(
        max_epochs=max_epochs,
        val_check_interval=val_interval,
        callbacks=[early_stopper, best_checkpoint],
        logger=wandb_logger,
        accelerator="auto",
    )
    print(f"Starting Trainer for {max_epochs} epochs.")
    trainer.fit(
        model=model,
        train_dataloaders=data.train_loader,
        val_dataloaders=data.val_loader,
    )
    # Evaluate the best checkpoint on the held-out test set.
    trainer.test(dataloaders=data.test_loader, ckpt_path="best", weights_only=False)
print(f"Pushing best model to: {repo_id}")
fresh_bundle = ModelRegistry.get(model_factory)(data, encoder_id)
labeler = getattr(fresh_bundle, "mention_labeler", None)
l2id = getattr(fresh_bundle, "label2id", None)
best_model = LitMentionDetector.load_from_checkpoint(
trainer.checkpoint_callback.best_model_path,
tokenizer=fresh_bundle.tokenizer,
encoder=fresh_bundle.encoder,
mention_detector=fresh_bundle.mention_detector,
label2id=l2id,
mention_labeler=labeler,
weights_only=False,
)
best_model.push_to_hub(repo_id, private=True)
wandb.finish()
print("Verifying Hub upload by pulling and re-evaluating...")
remote_model = LitMentionDetector.from_pretrained(
repo_id,
tokenizer=fresh_bundle.tokenizer,
encoder=fresh_bundle.encoder,
mention_detector=fresh_bundle.mention_detector,
label2id=l2id,
mention_labeler=labeler,
)
verify_trainer = Trainer(accelerator="auto", logger=False)
verify_trainer.test(model=remote_model, dataloaders=data.test_loader)
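

# Illustrative entry point, added as an assumption rather than taken from the
# original module; the project may expose `train` through its own CLI instead.
# The arguments shown are simply the function's defaults.
if __name__ == "__main__":
    train(
        model_factory="model_v1",
        data_factory="litbank_mentions",
        encoder_id="distilroberta-base",
    )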