from ast import literal_eval

import lightning.pytorch as pl
import pandas as pd
from lightning.pytorch import seed_everything
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

# Keep a single Lightning namespace: importing the logger from the legacy
# pytorch_lightning package would fail the Trainer's logger type check.
from lightning.pytorch.loggers import TensorBoardLogger
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer

import config
from dataset import CLDataModule
from model import CL_model


if __name__ == "__main__": |
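    # Fix all RNGs (Python, NumPy, torch) for reproducibility; workers=True also
    # seeds the dataloader worker processes.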
    seed_everything(0, workers=True)

    logger = TensorBoardLogger(
        "/data/aiiih/projects/sunx/ccf_fuzzy_diag/train/prompt/logs", name="CL"
    )
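
    # The query table stores its list-valued columns as strings; literal_eval
    # parses them back into Python lists, and None entries are dropped from codes.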
    query_df = pd.read_csv(
        "/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/data_proc/query_df.csv"
    )
    query_df["concepts"] = query_df["concepts"].apply(literal_eval)
    query_df["codes"] = query_df["codes"].apply(literal_eval)
    query_df["codes"] = query_df["codes"].apply(
        lambda x: [val for val in x if val is not None]
    )
    train_df, val_df = train_test_split(query_df, test_size=config.split_ratio)
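
    # Concept table: drop the columns not used downstream, parse the stringified
    # list columns, and build a concept -> synonyms lookup for the data module.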
    all_d = pd.read_csv(
        "/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/data_proc/query_all_d.csv"
    )
    all_d.drop(columns=["finding_sites", "morphology"], inplace=True)
    all_d["synonyms"] = all_d["synonyms"].apply(literal_eval)
    all_d["ancestors"] = all_d["ancestors"].apply(literal_eval)
    dictionary = dict(zip(all_d["concept"], all_d["synonyms"]))
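
    # Precomputed pairs table, consumed by the data module.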
    pairs = pd.read_csv("/data/aiiih/projects/sunx/ccf_fuzzy_diag/data_proc/pairs.csv")
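
    # Clinical-domain BERT tokenizer, passed to the data module for text encoding.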
    tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
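
    # setup() is called eagerly (before fit) so len(data_module.train_dataset) is
    # available when sizing the model below.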
    data_module = CLDataModule(train_df, val_df, tokenizer, dictionary, all_d, pairs)
    data_module.setup()
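
    # n_batches is the (possibly fractional) number of batches per epoch;
    # presumably CL_model combines it with n_epochs to size its LR schedule.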
    model = CL_model(
        n_batches=len(data_module.train_dataset) / config.batch_size,
        n_epochs=config.max_epochs,
        lr=config.learning_rate,
        mlm_weight=config.mlm_weight,
        unfreeze=config.unfreeze_ratio,
    )
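
    # monitor=None with save_top_k=-1 keeps every checkpoint written each
    # ckcpt_every_n_steps training steps, plus a rolling "last" copy.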
    checkpoint = ModelCheckpoint(
        dirpath="/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/train/ckpt/v2",
        filename="{epoch}-{step}",
        save_weights_only=True,
        save_last=True,
        every_n_train_steps=config.ckcpt_every_n_steps,
        monitor=None,
        save_top_k=-1,
    )
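
    # Multi-GPU DDP run; EarlyStopping halts training once val_loss has failed to
    # improve by at least 1e-3 for two consecutive validation checks.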
    trainer = pl.Trainer(
        accelerator=config.accelerator,
        devices=config.devices,
        strategy="ddp",
        logger=logger,
        max_epochs=config.max_epochs,
        min_epochs=config.min_epochs,
        precision=config.precision,
        callbacks=[
            EarlyStopping(monitor="val_loss", min_delta=1e-3, patience=2, mode="min"),
            checkpoint,
        ],
        profiler="simple",
        log_every_n_steps=config.log_every_n_steps,
    )

    trainer.fit(model, data_module)