"""Contrastive-learning (CL) training entry point.

Loads the query / dictionary / pair CSVs, builds the CLDataModule and
CL_model, and launches a DDP Lightning training run with TensorBoard
logging, periodic checkpointing, and early stopping on ``val_loss``.
"""

import lightning.pytorch as pl
import pandas as pd
from lightning.pytorch import seed_everything
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
# FIX: import the logger from the same `lightning.pytorch` namespace as the
# rest of the Lightning imports. Mixing the legacy `pytorch_lightning`
# package with `lightning.pytorch` can break Trainer's internal
# isinstance checks on loggers/callbacks.
from lightning.pytorch.loggers import TensorBoardLogger
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer
from ast import literal_eval

# imports from our own modules
import config
from model import CL_model
from dataset import CLDataModule

# Hard-coded project locations (two different mount prefixes are in use).
_DATA_PROC = "/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/data_proc"
_PROJ = "/data/aiiih/projects/sunx/ccf_fuzzy_diag"

if __name__ == "__main__":
    # Deterministic run: seeds python/numpy/torch (and dataloader workers),
    # which also makes the un-seeded train_test_split below reproducible.
    seed_everything(0, workers=True)

    logger = TensorBoardLogger(f"{_PROJ}/train/prompt/logs", name="CL")

    # Query dataframe: list-valued columns are stored as string literals in
    # the CSV, so parse them back with ast.literal_eval.
    query_df = pd.read_csv(f"{_DATA_PROC}/query_df.csv")
    query_df["concepts"] = query_df["concepts"].apply(literal_eval)
    query_df["codes"] = query_df["codes"].apply(literal_eval)
    # Drop None entries that literal_eval may yield inside each code list.
    query_df["codes"] = query_df["codes"].apply(
        lambda x: [val for val in x if val is not None]
    )
    train_df, val_df = train_test_split(query_df, test_size=config.split_ratio)

    # Concept dictionary: concept -> list of synonyms.
    all_d = pd.read_csv(f"{_DATA_PROC}/query_all_d.csv")
    all_d.drop(columns=["finding_sites", "morphology"], inplace=True)
    all_d["synonyms"] = all_d["synonyms"].apply(literal_eval)
    all_d["ancestors"] = all_d["ancestors"].apply(literal_eval)
    dictionary = dict(zip(all_d["concept"], all_d["synonyms"]))

    pairs = pd.read_csv(f"{_PROJ}/data_proc/pairs.csv")

    tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
    data_module = CLDataModule(train_df, val_df, tokenizer, dictionary, all_d, pairs)
    data_module.setup()

    model = CL_model(
        # NOTE(review): true division gives a float batch count; confirm
        # CL_model expects this (vs. math.ceil) before changing.
        n_batches=len(data_module.train_dataset) / config.batch_size,
        n_epochs=config.max_epochs,
        lr=config.learning_rate,
        mlm_weight=config.mlm_weight,
        unfreeze=config.unfreeze_ratio,
    )

    # Keep every periodic checkpoint (save_top_k=-1, no monitored metric);
    # weights only, plus a rolling "last" checkpoint for resuming.
    checkpoint = ModelCheckpoint(
        dirpath=f"{_DATA_PROC.rsplit('/', 1)[0]}/train/ckpt/v2",
        filename="{epoch}-{step}",
        save_weights_only=True,
        save_last=True,
        every_n_train_steps=config.ckcpt_every_n_steps,
        monitor=None,
        save_top_k=-1,
    )

    trainer = pl.Trainer(
        accelerator=config.accelerator,
        devices=config.devices,
        strategy="ddp",
        logger=logger,
        max_epochs=config.max_epochs,
        min_epochs=config.min_epochs,
        precision=config.precision,
        callbacks=[
            # Stop when val_loss fails to improve by >=1e-3 for 2 checks.
            EarlyStopping(monitor="val_loss", min_delta=1e-3, patience=2, mode="min"),
            checkpoint,
        ],
        profiler="simple",
        log_every_n_steps=config.log_every_n_steps,
    )
    trainer.fit(model, data_module)