Alhdrawi committed on
Commit
9b2bc96
·
verified ·
1 Parent(s): b91fa02

Upload run_train.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. run_train.py +143 -0
run_train.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pprint
3
+ import argparse
4
+ from tqdm import tqdm
5
+
6
+ import torch
7
+ from torch.utils import data
8
+ from torch import nn
9
+ import torch.optim as optim
10
+ from torchvision.transforms import Compose, Normalize, Resize
11
+
12
+ import clip
13
+ from model import CLIP
14
+ from simple_tokenizer import SimpleTokenizer
15
+
16
+ from train import train_main, load_data, load_clip, preprocess_text
17
+ from zero_shot import run_cxr_zero_shot, run_zero_shot
18
+
19
def parse_args():
    """Parse the command-line options for a training run.

    Returns:
        argparse.Namespace: one attribute per option declared below.
        ``--optimizer`` is limited to ``"adam"`` and ``"sgd"``, so an
        unsupported name is rejected here with a usage error instead of
        failing later when the optimizer is built.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--cxr_filepath', type=str, default='data/cxr.h5', help="Directory to load chest x-ray image data from.")
    parser.add_argument('--txt_filepath', type=str, default='data/mimic_impressions.csv', help="Directory to load radiology report impressions text from.")
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--epochs', type=int, default=4)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--save_interval', type=int, default=100)
    parser.add_argument('--log_interval', type=int, default=10)
    parser.add_argument('--save_dir', type=str, default="checkpoints/", help="Directory to save the trained model.")
    parser.add_argument('--seed', type=int, default=1234)
    # Only these two optimizer names are handled when the optimizer is built.
    parser.add_argument('--optimizer', type=str, default="sgd", choices=["adam", "sgd"])
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--context_length', type=int, default=77)
    parser.add_argument('--random_init', action='store_true')
    parser.add_argument('--model_name', type=str, default="pt-imp")
    return parser.parse_args()
37
+
38
def model_pipeline(config, verbose=0):
    """Build the training components, train the model, and save it.

    Args:
        config: parsed options; ``save_dir`` and ``model_name`` are read
            here to build the final checkpoint path, and the whole object
            is forwarded to ``make`` and ``train``.
        verbose: when truthy, print the model after saving.

    Returns:
        The trained model.
    """
    # (model, data_loader, device, criterion, optimizer)
    components = make(config)
    model = components[0]

    # Train with exactly the pieces that were just built.
    train(*components, config)

    # Persist the final weights under <save_dir>/<model_name>/checkpoint.pt.
    final_checkpoint = os.path.join(
        config.save_dir,
        str(config.model_name),
        'checkpoint.pt',
    )
    save(model, final_checkpoint)

    if verbose:
        print(model)
    return model
52
+
53
def make(config):
    """Build the model, data loader, loss, and optimizer for a run.

    Args:
        config: parsed options. Reads ``random_init``, ``cxr_filepath``,
            ``txt_filepath``, ``batch_size``, ``context_length``,
            ``optimizer``, ``lr``, and (for SGD only) ``momentum``.

    Returns:
        tuple: ``(model, data_loader, device, criterion, optimizer)``.

    Raises:
        ValueError: if ``config.optimizer`` is neither ``"adam"`` nor
            ``"sgd"``.
    """
    pretrained = not config.random_init
    data_loader, device = load_data(config.cxr_filepath, config.txt_filepath, batch_size=config.batch_size, pretrained=pretrained, column="impression")
    model = load_clip(model_path=None, pretrained=pretrained, context_length=config.context_length)
    model.to(device)
    print('Model on Device.')

    # Place the loss on the same device as the model rather than
    # hard-coding CUDA.
    criterion = nn.CrossEntropyLoss().to(device)

    # make the optimizer
    if config.optimizer == "adam":
        optimizer = optim.AdamW(model.parameters(), lr=config.lr)
    elif config.optimizer == "sgd":
        optimizer = optim.SGD(model.parameters(), lr=config.lr, momentum=config.momentum)
    else:
        # Fail with a clear message instead of an unbound local below.
        raise ValueError(
            f"Unsupported optimizer {config.optimizer!r}; expected 'adam' or 'sgd'."
        )
    return model, data_loader, device, criterion, optimizer
67
+
68
def train(model, loader, device, criterion, optimizer, config):
    """Run the training loop, logging the loss and saving checkpoints.

    Args:
        model: model being trained; passed to ``preprocess_text``,
            ``train_batch``, and ``save``.
        loader: iterable of batches; each batch is indexed with ``'img'``
            and ``'txt'``.
        device: device forwarded to ``train_batch``.
        criterion: loss forwarded to ``train_batch``.
        optimizer: optimizer forwarded to ``train_batch``.
        config: parsed options. Reads ``save_dir``, ``model_name``,
            ``epochs``, ``log_interval``, and ``save_interval``.
    """
    model_save_dir = os.path.join(config.save_dir, config.model_name)
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(model_save_dir, exist_ok=True)

    # Run training
    example_ct = 0  # number of examples seen
    batch_ct = 0  # number of batches seen, counted across all epochs
    report_freq = config.log_interval

    # The reporting window follows batch_ct, which is never reset, so the
    # running sum must also persist across epoch boundaries. Resetting it
    # per epoch would divide a partial sum by report_freq.
    running_loss = 0.0

    for epoch in range(config.epochs):
        for batch in tqdm(loader):
            # get the images and the preprocessed texts
            images = batch['img']
            texts = preprocess_text(batch['txt'], model)

            # perform step for a single batch
            loss = train_batch(images, texts, model, device, criterion, optimizer)
            example_ct += len(images)
            batch_ct += 1
            running_loss += loss.item()

            # Report metrics every `report_freq` batch
            if batch_ct % report_freq == 0:
                train_log(running_loss / report_freq, example_ct, epoch)
                running_loss = 0.0

            # Save an intermediate checkpoint every `save_interval` batches
            if batch_ct % config.save_interval == 0:
                model_path = os.path.join(model_save_dir, f"checkpoint_{batch_ct}.pt")
                save(model, model_path)
                # Announce only after the file has actually been written.
                print("Saved checkpoint to: ", model_path)
107
+
108
def train_batch(images, texts, model, device, criterion, optimizer):
    """Run one optimization step on a single batch and return its loss.

    Args:
        images: image batch; moved to ``device`` before the forward pass.
        texts: text batch; moved to ``device`` before the forward pass.
        model: callable returning ``(logits_per_image, logits_per_text)``.
        device: device the inputs and labels are placed on.
        criterion: loss applied to each set of logits against the labels.
        optimizer: optimizer whose gradients are cleared and then stepped.

    Returns:
        The loss tensor: the mean of the image-side and text-side losses.
    """
    images = images.to(device)
    texts = texts.to(device)

    # Forward pass
    img_logits, txt_logits = model(images, texts)

    # Labels are 0..N-1: the target for row i is index i.
    n_samples = images.shape[0]
    targets = torch.arange(n_samples, device=device)

    # Average the image-side and text-side losses.
    image_loss = criterion(img_logits, targets)
    text_loss = criterion(txt_logits, targets)
    loss = (image_loss + text_loss) / 2

    # Backward pass, then parameter update.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss
131
+
132
def train_log(loss, example_ct, epoch):
    """Print the current loss alongside the number of examples seen.

    Args:
        loss: loss value to report; converted to ``float`` and shown with
            three decimals.
        example_ct: number of examples seen so far; zero-padded to at
            least five characters.
        epoch: accepted for the caller's convenience; not used in the
            printed message.
    """
    loss_value = float(loss)
    padded_count = str(example_ct).zfill(5)
    print(f"Loss after {padded_count} examples: {loss_value:.3f}")
135
+
136
def save(model, path):
    """Write the model's ``state_dict`` to ``path`` with ``torch.save``.

    Args:
        model: object exposing ``state_dict()``.
        path: destination file path for the serialized state.
    """
    weights = model.state_dict()
    torch.save(weights, path)
138
+
139
if __name__ == "__main__":
    # Script entry point: parse the CLI options, then build, train, and
    # save the model.
    args = parse_args()
    model = model_pipeline(config=args)
142
+
143
+