| import os |
| import pprint |
| import argparse |
| from tqdm import tqdm |
|
|
| import torch |
| from torch.utils import data |
| from torch import nn |
| import torch.optim as optim |
| from torchvision.transforms import Compose, Normalize, Resize |
|
|
| import clip |
| from model import CLIP |
| from simple_tokenizer import SimpleTokenizer |
|
|
| from train import train_main, load_data, load_clip, preprocess_text |
| from zero_shot import run_cxr_zero_shot, run_zero_shot |
|
|
def parse_args():
    """Parse command-line options for CLIP pre-training on chest x-ray data."""
    arg_parser = argparse.ArgumentParser()
    # Data locations.
    arg_parser.add_argument('--cxr_filepath', type=str, default='data/cxr.h5', help="Directory to load chest x-ray image data from.")
    arg_parser.add_argument('--txt_filepath', type=str, default='data/mimic_impressions.csv', help="Directory to load radiology report impressions text from.")
    # Optimization hyperparameters.
    arg_parser.add_argument('--batch_size', type=int, default=16)
    arg_parser.add_argument('--epochs', type=int, default=4)
    arg_parser.add_argument('--lr', type=float, default=1e-4)
    # Logging / checkpointing cadence and output location.
    arg_parser.add_argument('--save_interval', type=int, default=100)
    arg_parser.add_argument('--log_interval', type=int, default=10)
    arg_parser.add_argument('--save_dir', type=str, default="checkpoints/", help="Directory to save the trained model.")
    # Misc training configuration.
    arg_parser.add_argument('--seed', type=int, default=1234)
    arg_parser.add_argument('--optimizer', type=str, default="sgd")
    arg_parser.add_argument('--momentum', type=float, default=0.9)
    arg_parser.add_argument('--context_length', type=int, default=77)
    arg_parser.add_argument('--random_init', action='store_true')
    arg_parser.add_argument('--model_name', type=str, default="pt-imp")
    return arg_parser.parse_args()
|
|
def model_pipeline(config, verbose=0):
    """End-to-end training pipeline: build components, train, checkpoint.

    Args:
        config: parsed argparse namespace (see `parse_args`).
        verbose: if truthy, print the model after training.

    Returns:
        The trained model.
    """
    # Assemble every training component from the CLI configuration.
    components = make(config)
    model, data_loader, device, criterion, optimizer = components

    # Run the optimization loop.
    train(model, data_loader, device, criterion, optimizer, config)

    # Persist the final weights under save_dir/<model_name>/checkpoint.pt.
    checkpoint_path = os.path.join(config.save_dir, str(config.model_name), 'checkpoint.pt')
    save(model, checkpoint_path)

    if verbose:
        print(model)
    return model
|
|
def make(config):
    """Construct the model, data loader, loss, and optimizer from `config`.

    Args:
        config: parsed argparse namespace (see `parse_args`).

    Returns:
        Tuple of (model, data_loader, device, criterion, optimizer).

    Raises:
        ValueError: if `config.optimizer` is not "adam" or "sgd".
    """
    pretrained = not config.random_init
    data_loader, device = load_data(config.cxr_filepath, config.txt_filepath, batch_size=config.batch_size, pretrained=pretrained, column="impression")
    model = load_clip(model_path=None, pretrained=pretrained, context_length=config.context_length)
    model.to(device)
    print('Model on Device.')

    # Keep the loss on the same device as the model. The original hard-coded
    # .cuda(), which crashes on CPU-only machines even though `device` was
    # already resolved by load_data.
    criterion = nn.CrossEntropyLoss().to(device)
    if config.optimizer == "adam":
        optimizer = optim.AdamW(model.parameters(), lr=config.lr)
    elif config.optimizer == "sgd":
        optimizer = optim.SGD(model.parameters(), lr=config.lr, momentum=config.momentum)
    else:
        # Fail fast with a clear message instead of an UnboundLocalError
        # at the return statement below.
        raise ValueError(f"Unsupported optimizer: {config.optimizer!r} (expected 'adam' or 'sgd')")
    return model, data_loader, device, criterion, optimizer
|
|
def train(model, loader, device, criterion, optimizer, config):
    """Run the CLIP contrastive training loop.

    Logs the running average loss every `config.log_interval` batches and
    writes a checkpoint every `config.save_interval` batches to
    `config.save_dir/config.model_name/`.

    Args:
        model: CLIP model producing (logits_per_image, logits_per_text).
        loader: iterable yielding dicts with 'img' and 'txt' keys.
        device: torch device the model already lives on.
        criterion: loss applied symmetrically to both logit matrices.
        optimizer: optimizer over `model.parameters()`.
        config: parsed argparse namespace (see `parse_args`).
    """
    model_save_dir = os.path.join(config.save_dir, config.model_name)
    # exist_ok avoids the race between the check and the mkdir.
    os.makedirs(model_save_dir, exist_ok=True)

    example_ct = 0  # total number of examples seen
    batch_ct = 0    # total number of batches processed
    report_freq = config.log_interval

    for epoch in range(config.epochs):
        running_loss = 0.0  # loss accumulated since the last report
        # NOTE: loop variable renamed from `data` — the original shadowed the
        # `torch.utils.data` module imported at the top of the file.
        for batch in tqdm(loader):
            images = batch['img']
            texts = batch['txt']
            texts = preprocess_text(texts, model)

            loss = train_batch(images, texts, model, device, criterion, optimizer)
            example_ct += len(images)
            batch_ct += 1
            running_loss += loss.item()

            if (batch_ct % report_freq) == 0:
                train_log(running_loss / report_freq, example_ct, epoch)
                running_loss = 0.0

            if (batch_ct % config.save_interval) == 0:
                model_path = os.path.join(model_save_dir, f"checkpoint_{batch_ct}.pt")
                save(model, model_path)
                # Print after the file actually exists; the original announced
                # the save before performing it.
                print("Saved checkpoint to: ", model_path)
| |
| def train_batch(images, texts, model, device, criterion, optimizer): |
| images, texts = images.to(device), texts.to(device) |
| |
| |
| logits_per_image, logits_per_text = model(images, texts) |
| |
| |
| batch_size = images.shape[0] |
| labels = torch.arange(batch_size).to(device) |
| |
| |
| loss_img = criterion(logits_per_image, labels) |
| loss_txt = criterion(logits_per_text, labels) |
| loss = (loss_img + loss_txt)/2 |
|
|
| |
| optimizer.zero_grad() |
| loss.backward() |
| |
| |
| optimizer.step() |
| |
| return loss |
|
|
def train_log(loss, example_ct, epoch):
    """Print the average loss after `example_ct` examples.

    `epoch` is accepted for interface compatibility but is not printed.
    """
    # Single f-string; output is byte-identical to the original
    # concatenation-based formatting.
    print(f"Loss after {str(example_ct).zfill(5)} examples: {float(loss):.3f}")
| |
def save(model, path):
    """Serialize the model's parameter state dict to `path`."""
    state = model.state_dict()
    torch.save(state, path)
| |
if __name__ == "__main__":
    # Script entry point: parse CLI flags and launch the training pipeline.
    cli_config = parse_args()
    trained_model = model_pipeline(cli_config)
| |
|
|
|
|