# CHOPT-NEW / train.py
# Uploaded by sxtforreal ("Upload 5 files", commit 975624b, verified).
from lightning.pytorch import seed_everything
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
import lightning.pytorch as pl
import pandas as pd
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer
from ast import literal_eval
from pytorch_lightning.loggers import TensorBoardLogger
# imports from our own modules
import config
from model import CL_model
from dataset import CLDataModule
if __name__ == "__main__":
    # Script entry point: load the query / concept tables from CSV, build the
    # data module and CL_model, then run Trainer.fit with step-based
    # checkpointing and early stopping on "val_loss".
    import math  # local import: only needed for the batch-count computation

    # Fixed seed for repeatability. train_test_split below is given no
    # random_state, so the split relies on this global seeding.
    seed_everything(0, workers=True)

    # NOTE(review): TensorBoardLogger is imported from `pytorch_lightning`
    # while the Trainer and callbacks come from `lightning.pytorch` -- confirm
    # that mixing the two distributions is intended.
    logger = TensorBoardLogger(
        "/data/aiiih/projects/sunx/ccf_fuzzy_diag/train/prompt/logs", name="CL"
    )

    # --- Query table -------------------------------------------------------
    # NOTE(review): some paths are rooted at /data/aiiih and others at
    # /home/sunx/data/aiiih -- presumably the same mount; verify.
    query_df = pd.read_csv(
        "/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/data_proc/query_df.csv"
    )
    # These columns are stored in the CSV as string representations of Python
    # literals; parse them back into Python objects.
    query_df["concepts"] = query_df["concepts"].apply(literal_eval)
    query_df["codes"] = query_df["codes"].apply(literal_eval)
    # Drop missing (None) entries from every code collection.
    query_df["codes"] = query_df["codes"].apply(
        lambda x: [val for val in x if val is not None]
    )
    train_df, val_df = train_test_split(query_df, test_size=config.split_ratio)

    # --- Concept table -----------------------------------------------------
    all_d = pd.read_csv(
        "/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/data_proc/query_all_d.csv"
    )
    all_d.drop(columns=["finding_sites", "morphology"], inplace=True)
    all_d["synonyms"] = all_d["synonyms"].apply(literal_eval)
    all_d["ancestors"] = all_d["ancestors"].apply(literal_eval)
    # Lookup from each concept to its parsed synonyms.
    dictionary = dict(zip(all_d["concept"], all_d["synonyms"]))

    pairs = pd.read_csv("/data/aiiih/projects/sunx/ccf_fuzzy_diag/data_proc/pairs.csv")

    # --- Data module and model ---------------------------------------------
    tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
    data_module = CLDataModule(train_df, val_df, tokenizer, dictionary, all_d, pairs)
    # setup() is invoked by hand because the model constructor below needs
    # len(data_module.train_dataset).
    data_module.setup()

    # Batches per epoch must be a whole number: true division would hand the
    # model a fractional count (e.g. 31.25), so round up instead.
    # NOTE(review): this assumes the loader keeps the final partial batch and
    # does not account for the number of DDP devices -- confirm against
    # CLDataModule and CL_model.
    n_batches = math.ceil(len(data_module.train_dataset) / config.batch_size)

    model = CL_model(
        n_batches=n_batches,
        n_epochs=config.max_epochs,
        lr=config.learning_rate,
        mlm_weight=config.mlm_weight,
        unfreeze=config.unfreeze_ratio,
    )

    # --- Callbacks -----------------------------------------------------------
    # Weights-only checkpoints every `config.ckcpt_every_n_steps` training
    # steps; no metric is monitored and save_top_k=-1 keeps all of them, with
    # save_last=True additionally requesting a "last" checkpoint.
    checkpoint = ModelCheckpoint(
        dirpath="/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/train/ckpt/v2",
        filename="{epoch}-{step}",
        save_weights_only=True,
        save_last=True,
        every_n_train_steps=config.ckcpt_every_n_steps,
        monitor=None,
        save_top_k=-1,
    )
    early_stopping = EarlyStopping(
        monitor="val_loss", min_delta=1e-3, patience=2, mode="min"
    )

    # --- Training ------------------------------------------------------------
    trainer = pl.Trainer(
        accelerator=config.accelerator,
        devices=config.devices,
        strategy="ddp",
        logger=logger,
        max_epochs=config.max_epochs,
        min_epochs=config.min_epochs,
        precision=config.precision,
        callbacks=[
            early_stopping,
            checkpoint,
        ],
        profiler="simple",
        log_every_n_steps=config.log_every_n_steps,
    )
    trainer.fit(model, data_module)