File size: 2,874 Bytes
975624b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import math
from ast import literal_eval

from lightning.pytorch import seed_everything
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
# FIX: import the logger from lightning.pytorch, not the legacy pytorch_lightning
# package — mixing the two namespaces makes Trainer's isinstance checks reject
# the logger object at runtime.
from lightning.pytorch.loggers import TensorBoardLogger
import lightning.pytorch as pl
import pandas as pd
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer

# imports from our own modules
import config
from model import CL_model
from dataset import CLDataModule

if __name__ == "__main__":
    # Fixed seed (incl. dataloader workers) so runs are reproducible.
    seed_everything(0, workers=True)
    logger = TensorBoardLogger(
        "/data/aiiih/projects/sunx/ccf_fuzzy_diag/train/prompt/logs", name="CL"
    )

    # Query dataframe: list-valued columns are stored as stringified Python
    # literals in the CSV, so parse them back with literal_eval.
    query_df = pd.read_csv(
        "/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/data_proc/query_df.csv"
    )
    query_df["concepts"] = query_df["concepts"].apply(literal_eval)
    query_df["codes"] = query_df["codes"].apply(literal_eval)
    # Drop None entries that literal_eval may have produced inside the code lists.
    query_df["codes"] = query_df["codes"].apply(
        lambda x: [val for val in x if val is not None]
    )
    train_df, val_df = train_test_split(query_df, test_size=config.split_ratio)

    # Concept dictionary: concept -> list of synonyms (plus ancestors kept on
    # all_d for the data module). finding_sites/morphology are unused here.
    all_d = pd.read_csv(
        "/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/data_proc/query_all_d.csv"
    )
    all_d.drop(columns=["finding_sites", "morphology"], inplace=True)
    all_d["synonyms"] = all_d["synonyms"].apply(literal_eval)
    all_d["ancestors"] = all_d["ancestors"].apply(literal_eval)
    dictionary = dict(zip(all_d["concept"], all_d["synonyms"]))

    # Precomputed contrastive pairs. NOTE(review): path root differs from the
    # other CSVs (/data/... vs /home/sunx/data/...) — confirm both are valid.
    pairs = pd.read_csv("/data/aiiih/projects/sunx/ccf_fuzzy_diag/data_proc/pairs.csv")

    tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")

    data_module = CLDataModule(train_df, val_df, tokenizer, dictionary, all_d, pairs)
    data_module.setup()

    model = CL_model(
        # FIX: batches per epoch is an integer count; the old expression passed
        # a float. ceil matches the DataLoader default drop_last=False.
        n_batches=math.ceil(len(data_module.train_dataset) / config.batch_size),
        n_epochs=config.max_epochs,
        lr=config.learning_rate,
        mlm_weight=config.mlm_weight,
        unfreeze=config.unfreeze_ratio,
    )

    # Save weights every N steps and keep all checkpoints (save_top_k=-1,
    # monitor=None means no "best" selection — purely periodic snapshots).
    checkpoint = ModelCheckpoint(
        dirpath="/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/train/ckpt/v2",
        filename="{epoch}-{step}",
        save_weights_only=True,
        save_last=True,
        every_n_train_steps=config.ckcpt_every_n_steps,
        monitor=None,
        save_top_k=-1,
    )

    trainer = pl.Trainer(
        accelerator=config.accelerator,
        devices=config.devices,
        strategy="ddp",
        logger=logger,
        max_epochs=config.max_epochs,
        min_epochs=config.min_epochs,
        precision=config.precision,
        callbacks=[
            # Stop early once val_loss plateaus (improvement < 1e-3 for 2 checks).
            EarlyStopping(monitor="val_loss", min_delta=1e-3, patience=2, mode="min"),
            checkpoint,
        ],
        profiler="simple",
        log_every_n_steps=config.log_every_n_steps,
    )

    trainer.fit(model, data_module)