import wandb
from lightning import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.loggers import WandbLogger

from mentioned.data import DataRegistry
from mentioned.model import LitMentionDetector, ModelRegistry


def train(
    model_factory: str = "model_v1",
    data_factory: str = "litbank_mentions",
    repo_id: str = "kadarakos/mention-detector-poc-dry-run",
    project_name: str = "mention-detector-poc",
    encoder_id: str = "distilroberta-base",
    patience: int = 5,
    val_interval: int = 1000,
    stop_criterion: str = "val_f1_mention",
    max_epochs: int | None = None,
    stop_mode: str = "max",
) -> None:
    """Train a mention detector, push the best checkpoint to the Hub, and verify it.

    The run has four stages: fit with early stopping and best-checkpoint
    tracking, test the best checkpoint, push the best model to ``repo_id`` as
    a private repository, then pull it back and re-run the test set on it.

    Args:
        model_factory: Name looked up in ``ModelRegistry``; the returned
            factory is called as ``factory(data, encoder_id)``.
        data_factory: Name looked up in ``DataRegistry``; the returned factory
            is called with no arguments. The result must expose
            ``train_loader``, ``val_loader`` and ``test_loader``.
        repo_id: Hub repository the best model is pushed to and pulled from.
        project_name: Weights & Biases project name.
        encoder_id: Passed to the model factory; also used as the W&B run name.
        patience: Forwarded to ``EarlyStopping(patience=...)``.
        val_interval: Forwarded to ``Trainer(val_check_interval=...)``.
        stop_criterion: Logged metric monitored by both the checkpoint and the
            early-stopping callbacks.
        max_epochs: Upper bound on training epochs; ``None`` means 1000.
        stop_mode: ``"max"`` if a larger ``stop_criterion`` is better,
            ``"min"`` if smaller is better (e.g. a loss). Applied to both
            callbacks so they always agree.
    """
    if max_epochs is None:
        max_epochs = 1000

    data = DataRegistry.get(data_factory)()
    build_model = ModelRegistry.get(model_factory)
    model = build_model(data, encoder_id)

    wandb_logger = WandbLogger(
        project=project_name,
        name=encoder_id,
    )
    best_checkpoint = ModelCheckpoint(
        monitor=stop_criterion,
        mode=stop_mode,
        save_top_k=1,
        filename=f"best-{stop_criterion}",
        verbose=True,
    )
    early_stopper = EarlyStopping(
        monitor=stop_criterion,
        min_delta=0.01,
        patience=patience,
        verbose=True,
        mode=stop_mode,
    )
    trainer = Trainer(
        max_epochs=max_epochs,
        val_check_interval=val_interval,
        callbacks=[early_stopper, best_checkpoint],
        logger=wandb_logger,
        accelerator="auto",
    )

    print(f"Starting Trainer for {max_epochs} epochs.")
    trainer.fit(
        model=model,
        train_dataloaders=data.train_loader,
        val_dataloaders=data.val_loader,
    )
    trainer.test(dataloaders=data.test_loader, ckpt_path="best", weights_only=False)

    print(f"Pushing best model to: {repo_id}")
    # The checkpoint is loaded into freshly built components rather than the
    # ones that were just trained.
    fresh_bundle = build_model(data, encoder_id)
    best_model = LitMentionDetector.load_from_checkpoint(
        best_checkpoint.best_model_path,
        tokenizer=fresh_bundle.tokenizer,
        encoder=fresh_bundle.encoder,
        mention_detector=fresh_bundle.mention_detector,
        # Optional components: not every bundle is guaranteed to define them.
        label2id=getattr(fresh_bundle, "label2id", None),
        mention_labeler=getattr(fresh_bundle, "mention_labeler", None),
        weights_only=False,
    )
    best_model.push_to_hub(repo_id, private=True)
    wandb.finish()

    print("Verifying Hub upload by pulling and re-evaluating...")
    # Build another untouched bundle for the round-trip check. The modules in
    # `fresh_bundle` were handed to `load_from_checkpoint`, and loading a state
    # dict writes into modules in place, so they presumably already hold the
    # best weights. Reusing them would let this check pass even if the Hub
    # download restored nothing.
    verify_bundle = build_model(data, encoder_id)
    remote_model = LitMentionDetector.from_pretrained(
        repo_id,
        tokenizer=verify_bundle.tokenizer,
        encoder=verify_bundle.encoder,
        mention_detector=verify_bundle.mention_detector,
        label2id=getattr(verify_bundle, "label2id", None),
        mention_labeler=getattr(verify_bundle, "mention_labeler", None),
    )
    # logger=False: the W&B run is already finished, so nothing is logged here.
    verify_trainer = Trainer(accelerator="auto", logger=False)
    verify_trainer.test(model=remote_model, dataloaders=data.test_loader)