fulviodeo's picture
Added prediction box
c898199
raw
history blame
1.61 kB
import warnings

import torch
from torch.utils.data import DataLoader, TensorDataset
class PopulationHealthScreener:
    """Score free-text documents with a transformer-style classifier.

    Wraps a tokenizer and a model: texts are tokenized (truncated/padded to
    ``max_length``), run through the model in evaluation mode without
    gradients, and each output logit is mapped through a sigmoid to a score
    in ``[0, 1]``.

    NOTE(review): the class name suggests screening abstracts for
    population-health relevance; that domain meaning is not established by
    this code.
    """

    # Name of the text column. Not read by any method in this class;
    # presumably used by callers to pick the texts out of a table.
    input_col = 'Abstract'
    # Tokens per example: longer texts are truncated, shorter ones padded.
    max_length = 512
    # Number of examples per forward pass.
    batch_size = 8

    def __init__(self, tokenizer, model):
        """Store the tokenizer/model and move the model to the best device.

        Args:
            tokenizer: Callable accepting ``truncation``, ``padding``,
                ``max_length`` and ``return_tensors`` keyword arguments and
                returning a mapping with ``input_ids`` and ``attention_mask``
                tensors (presumably a Hugging Face tokenizer).
            model: ``torch.nn.Module`` whose call result exposes ``.logits``.
        """
        self.tokenizer = tokenizer
        self.model = model
        # Prefer a GPU when one is visible; otherwise run on CPU.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)

    def load_fine_tuned_weights(self, model_path):
        """Load a saved ``state_dict`` into the wrapped model.

        Loading is non-strict, so checkpoint keys that do not match the model
        are skipped instead of raising. Because a checkpoint saved with
        different key names (e.g. an extra prefix) would otherwise leave the
        model silently un-loaded, a ``RuntimeWarning`` is emitted whenever
        model keys are missing from the checkpoint.

        Args:
            model_path: Path or file-like object accepted by ``torch.load``.

        Returns:
            The ``(missing_keys, unexpected_keys)`` result of
            ``load_state_dict`` so callers can inspect what was skipped.
        """
        # SECURITY: torch.load unpickles the file (unless weights-only loading
        # is in effect for the installed torch version); only load trusted
        # checkpoints.
        state_dict = torch.load(model_path, map_location=self.device)
        result = self.model.load_state_dict(state_dict, strict=False)
        if result.missing_keys:
            warnings.warn(
                f'{len(result.missing_keys)} model key(s) were not found in '
                f'{model_path!r} and keep their previous values, e.g. '
                f'{list(result.missing_keys)[:5]}',
                RuntimeWarning,
                stacklevel=2,
            )
        return result

    def load_test_data(self, data):
        """Tokenize ``data`` and wrap it in an ordered ``DataLoader``.

        Args:
            data: Text(s) passed unchanged to the tokenizer.

        Returns:
            A ``DataLoader`` yielding ``(input_ids, attention_mask)`` batches
            of up to ``batch_size`` examples in input order (no shuffling), so
            predictions line up with the inputs.
        """
        encodings = self.tokenizer(
            data,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt',
        )
        dataset = TensorDataset(encodings['input_ids'], encodings['attention_mask'])
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=False)

    def predict(self, data, on_batch=None):
        """Return sigmoid scores of the model's logits for the given texts.

        Args:
            data: A single string (treated as one document) or any iterable
                of strings, e.g. a list or a pandas Series.
            on_batch: Optional progress callback, invoked after every batch as
                ``on_batch(n_done, n_total)``.

        Returns:
            List of floats in ``[0, 1]`` in input order; ``[]`` for empty
            input. Assumes the model emits one logit per example (single-label
            head) — TODO confirm; with more logits per example the flattened
            list is longer than the input.
        """
        # A bare string is one document; len() on it would count characters
        # and corrupt the progress totals.
        texts = [data] if isinstance(data, str) else list(data)
        n_total = len(texts)
        if n_total == 0:
            return []

        test_data = self.load_test_data(texts)
        predictions = []
        self.model.eval()
        for i, (input_ids, attention_mask) in enumerate(test_data):
            input_ids = input_ids.to(self.device)
            attention_mask = attention_mask.to(self.device)
            with torch.no_grad():
                # Only logits are used, so hidden states are not requested;
                # they would keep every layer's activations in memory.
                outputs = self.model(input_ids, attention_mask=attention_mask)
            predictions.extend(torch.sigmoid(outputs.logits).flatten().tolist())
            if on_batch is not None:
                on_batch(min((i + 1) * self.batch_size, n_total), n_total)
        return predictions