import os
import argparse
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import sys
from weights_utils import get_weight
# sys.path.append(os.path.abspath('/home/gholipos-admin/Desktop/Thesis/Training_Code/Entity_Summarization/CheXbert/src'))
from . import utils
# import utils
from .models.bert_labeler import bert_labeler
from .bert_tokenizer import tokenize
from transformers import BertTokenizer
from collections import OrderedDict
from .datasets.unlabeled_dataset import UnlabeledDataset
from .constants import *
from tqdm import tqdm
def collate_fn_no_labels(sample_list):
    """Pad the tokenized reports in a batch to the batch's max length.

    The samples carry no labels; each is a dictionary with keys 'imp'
    (token-id tensor) and 'len' (sequence length) as returned by the
    dataset's __getitem__.

    @param sample_list (List): a list of sample dictionaries
    @returns batch (dict): 'imp' is a (batch_size, max_len) tensor padded
                           with PAD_IDX; 'len' is the list of original
                           sequence lengths for each item in the batch
    """
    padded_imps = torch.nn.utils.rnn.pad_sequence(
        [sample['imp'] for sample in sample_list],
        batch_first=True,
        padding_value=PAD_IDX,
    )
    lengths = [sample['len'] for sample in sample_list]
    return {'imp': padded_imps, 'len': lengths}
def load_unlabeled_data(impressions, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS,
                        shuffle=False):
    """Wrap the input impressions in a DataLoader for inference.

    @param impressions (string): impression text whose newline and
                                 whitespace escapes were replaced by ' '
    @param batch_size (int): the batch size. As per the BERT repository, the
                             max batch size that fits on a TITAN XP is 6 when
                             the max sequence length is 512, which is our case
    @param num_workers (int): number of worker processes used to load data
    @param shuffle (bool): whether to shuffle the data or not
    @returns loader (DataLoader): dataloader object over the reports, batched
                                  with collate_fn_no_labels
    """
    dataset = UnlabeledDataset(impressions)
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=collate_fn_no_labels,
    )
def label(checkpoint_path, impressions):
    """Label a report impression with each of the 14 CheXbert conditions.

    @param checkpoint_path (string): location of the saved model checkpoint;
                                     resolved to a local path via get_weight
    @param impressions (string): impression text whose newline and
                                 whitespace escapes were replaced by ' '
    @returns y_pred (List[List[int]]): predicted class ids per report, one
                                       inner list per condition in CONDITIONS
    """
    ld = load_unlabeled_data(impressions)
    model = bert_labeler()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    ckpt_path = get_weight(checkpoint_path)
    # NOTE(review): torch.load unpickles arbitrary objects -- only load
    # checkpoints from trusted sources.
    if torch.cuda.device_count() > 0:  # works even if only 1 GPU is available
        model = nn.DataParallel(model)  # to utilize multiple GPUs
        model = model.to(device)
        checkpoint = torch.load(ckpt_path)
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu'))
        # The checkpoint was saved from a DataParallel-wrapped model, so each
        # key carries a 'module.' prefix that must be stripped before loading
        # into the bare (non-parallel) model on CPU.
        new_state_dict = OrderedDict()
        for k, v in checkpoint['model_state_dict'].items():
            new_state_dict[k[7:]] = v  # remove 'module.'
        model.load_state_dict(new_state_dict)

    was_training = model.training
    model.eval()
    y_pred = [[] for _ in range(len(CONDITIONS))]

    with torch.no_grad():
        for data in tqdm(ld):
            batch = data['imp'].to(device)  # (batch_size, max_len)
            src_len = data['len']
            attn_mask = utils.generate_attention_masks(batch, src_len, device)
            out = model(batch, attn_mask)
            # out[j] holds the logits for condition j; take the argmax class
            # per report, shape (batch_size,).
            for j in range(len(out)):
                y_pred[j].append(out[j].argmax(dim=1))

    y_pred = [torch.cat(preds, dim=0) for preds in y_pred]
    # Restore the caller's training/eval mode.
    if was_training:
        model.train()
    return [t.tolist() for t in y_pred]
def save_preds(y_pred, csv_path, out_path):
    """Save predictions as out_path/labeled_reports.csv.

    @param y_pred (List[List[int]]): list of predictions for each report
    @param csv_path (string): path to csv containing reports
    @param out_path (string): path to output directory
    """
    preds = np.array(y_pred).T  # rows become reports, columns conditions
    df = pd.DataFrame(preds, columns=CONDITIONS)
    df['Report Impression'] = pd.read_csv(csv_path)['Report Impression'].tolist()
    df = df[['Report Impression'] + CONDITIONS]
    # Remap raw class ids for output: 0 (blank) -> NaN, 3 (uncertain) -> -1,
    # 2 (negative) -> 0. The order of these replacements matters.
    df.replace(0, np.nan, inplace=True)
    df.replace(3, -1, inplace=True)
    df.replace(2, 0, inplace=True)
    df.to_csv(os.path.join(out_path, 'labeled_reports.csv'), index=False)
if __name__ == '__main__':
    # CLI entry point: parse arguments, run the labeler, then save the
    # per-condition predictions next to the original reports.
    parser = argparse.ArgumentParser(description='Label a csv file containing radiology reports')
    parser.add_argument('-d', '--data', type=str, nargs='?', required=True,
                        help='path to csv containing reports. The reports should be \
                        under the \"Report Impression\" column')
    parser.add_argument('-o', '--output_dir', type=str, nargs='?', required=True,
                        help='path to intended output folder')
    parser.add_argument('-c', '--checkpoint', type=str, nargs='?', required=True,
                        help='path to the pytorch checkpoint')
    parser.add_argument('-s', '--sentence', type=str, nargs='?', required=True,
                        help="A sentence containing an impression which is replaced the '\n', and '\s' with ' '")
    args = parser.parse_args()
    predictions = label(args.checkpoint, args.sentence)
    save_preds(predictions, args.data, args.output_dir)
|