File size: 2,874 Bytes
975624b | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 | from lightning.pytorch import seed_everything
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
import lightning.pytorch as pl
import pandas as pd
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer
from ast import literal_eval
from pytorch_lightning.loggers import TensorBoardLogger
# imports from our own modules
import config
from model import CL_model
from dataset import CLDataModule
# Filesystem locations and model identifiers used by this training script.
# NOTE(review): these paths live under two different roots ("/data/aiiih/..."
# and "/home/sunx/data/aiiih/..."); presumably one is a link to the other --
# confirm before consolidating them.
LOG_DIR = "/data/aiiih/projects/sunx/ccf_fuzzy_diag/train/prompt/logs"
QUERY_CSV = "/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/data_proc/query_df.csv"
CONCEPT_CSV = (
    "/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/data_proc/query_all_d.csv"
)
PAIRS_CSV = "/data/aiiih/projects/sunx/ccf_fuzzy_diag/data_proc/pairs.csv"
CKPT_DIR = "/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/train/ckpt/v2"
TOKENIZER_NAME = "emilyalsentzer/Bio_ClinicalBERT"


def steps_per_epoch(n_samples, batch_size):
    """Return how many batches are needed to cover ``n_samples`` samples.

    The result is an integer rounded up, so a final partial batch counts as
    one batch. Plain ``n_samples / batch_size`` would yield a fractional
    float, which is not a meaningful batch count.

    Args:
        n_samples: Number of samples in the dataset (non-negative).
        batch_size: Number of samples per batch (must be positive).

    Returns:
        The integer number of batches per epoch.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size!r}")
    return (n_samples + batch_size - 1) // batch_size


def load_query_frame(path):
    """Read the query CSV at ``path`` and decode its list-valued columns.

    The "concepts" and "codes" columns are stored as Python literals in the
    CSV and are parsed with ``ast.literal_eval``. ``None`` entries are then
    removed from every "codes" list.
    """
    frame = pd.read_csv(path)
    for column in ("concepts", "codes"):
        frame[column] = frame[column].apply(literal_eval)
    frame["codes"] = frame["codes"].apply(
        lambda codes: [code for code in codes if code is not None]
    )
    return frame


def load_concept_table(path):
    """Read the concept CSV at ``path`` and decode its list-valued columns.

    The "finding_sites" and "morphology" columns are dropped, and the
    "synonyms" and "ancestors" columns are parsed with ``ast.literal_eval``.
    """
    table = pd.read_csv(path)
    table.drop(columns=["finding_sites", "morphology"], inplace=True)
    for column in ("synonyms", "ancestors"):
        table[column] = table[column].apply(literal_eval)
    return table


def build_checkpoint_callback():
    """Create the checkpoint callback used during training.

    With ``monitor=None`` and ``save_top_k=-1`` no metric is tracked and no
    checkpoint is pruned; a weights-only checkpoint is written every
    ``config.ckcpt_every_n_steps`` training steps, plus a "last" checkpoint.
    """
    return ModelCheckpoint(
        dirpath=CKPT_DIR,
        filename="{epoch}-{step}",
        save_weights_only=True,
        save_last=True,
        every_n_train_steps=config.ckcpt_every_n_steps,
        monitor=None,
        save_top_k=-1,
    )


def build_trainer(logger, checkpoint):
    """Create the DDP trainer with early stopping on "val_loss"."""
    early_stopping = EarlyStopping(
        monitor="val_loss", min_delta=1e-3, patience=2, mode="min"
    )
    return pl.Trainer(
        accelerator=config.accelerator,
        devices=config.devices,
        strategy="ddp",
        logger=logger,
        max_epochs=config.max_epochs,
        min_epochs=config.min_epochs,
        precision=config.precision,
        callbacks=[early_stopping, checkpoint],
        profiler="simple",
        log_every_n_steps=config.log_every_n_steps,
    )


def main():
    """Load the data, build the model and trainer, and run training."""
    seed_everything(0, workers=True)

    # Take the logger class from lightning.pytorch, the same package that
    # provides the Trainer and callbacks used here; the file-level import
    # pulls TensorBoardLogger from the separate pytorch_lightning package.
    logger = pl.loggers.TensorBoardLogger(LOG_DIR, name="CL")

    query_df = load_query_frame(QUERY_CSV)
    # No random_state is passed, so the split relies on the global seed set
    # by seed_everything above.
    train_df, val_df = train_test_split(query_df, test_size=config.split_ratio)

    all_d = load_concept_table(CONCEPT_CSV)
    dictionary = dict(zip(all_d["concept"], all_d["synonyms"]))
    pairs = pd.read_csv(PAIRS_CSV)

    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
    data_module = CLDataModule(train_df, val_df, tokenizer, dictionary, all_d, pairs)
    # setup() is called by hand because train_dataset must exist before the
    # model is built (its length is needed for n_batches below).
    data_module.setup()

    # NOTE(review): rounding up assumes the last partial batch is kept, and
    # under DDP each process presumably iterates only a shard of the data --
    # confirm how CL_model consumes n_batches.
    n_batches = steps_per_epoch(len(data_module.train_dataset), config.batch_size)
    model = CL_model(
        n_batches=n_batches,
        n_epochs=config.max_epochs,
        lr=config.learning_rate,
        mlm_weight=config.mlm_weight,
        unfreeze=config.unfreeze_ratio,
    )

    trainer = build_trainer(logger, build_checkpoint_callback())
    trainer.fit(model, data_module)


if __name__ == "__main__":
    main()
|