import warnings
from typing import Callable, List, Optional, Sequence

import torch
from torch.utils.data import DataLoader, TensorDataset


class PopulationHealthScreener:
    """Score texts with a tokenizer/model pair and return sigmoid outputs.

    The tokenizer and model are supplied by the caller; this class only
    moves the model to the selected device, batches the tokenized inputs
    and applies a sigmoid to the model's ``logits``.
    """

    # Not referenced inside this class; presumably the column name callers
    # read the texts from -- verify against the calling code.
    input_col = 'Abstract'
    # Every text is truncated/padded to exactly this many tokens.
    max_length = 512
    # Number of texts sent through the model per forward pass.
    batch_size = 8

    def __init__(self, tokenizer, model):
        """Store the tokenizer and model and move the model to the device.

        Args:
            tokenizer: Callable that accepts the texts plus the
                ``truncation``/``padding``/``max_length``/``return_tensors``
                keywords and returns a mapping containing ``'input_ids'``
                and ``'attention_mask'`` tensors.
            model: Module whose forward pass returns an object exposing
                ``.logits``.
        """
        self.tokenizer = tokenizer
        self.model = model
        # Prefer CUDA when it is available, otherwise fall back to CPU.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)

    def load_fine_tuned_weights(self, model_path):
        """Load a state dict from ``model_path`` into the model.

        Loading is non-strict, so keys that do not line up are tolerated;
        a ``RuntimeWarning`` is emitted listing them so that a wrong or
        partial checkpoint does not go unnoticed.

        Args:
            model_path: Path (or file-like object) handed to ``torch.load``.

        Returns:
            The value returned by ``load_state_dict`` (carrying
            ``missing_keys`` and ``unexpected_keys``).
        """
        # NOTE(security): torch.load unpickles the file. Only load
        # checkpoints from trusted sources, or pass ``weights_only=True``
        # where the installed torch version supports it.
        state_dict = torch.load(model_path, map_location=self.device)
        result = self.model.load_state_dict(state_dict, strict=False)

        missing = list(getattr(result, 'missing_keys', ()))
        unexpected = list(getattr(result, 'unexpected_keys', ()))
        if missing or unexpected:
            warnings.warn(
                'Checkpoint keys did not fully match the model '
                f'(missing={missing}, unexpected={unexpected})',
                RuntimeWarning,
                stacklevel=2,
            )
        return result

    def load_test_data(self, data):
        """Tokenize ``data`` and wrap it in an ordered ``DataLoader``.

        Args:
            data: Texts to tokenize, passed to the tokenizer unchanged.

        Returns:
            A ``DataLoader`` yielding ``(input_ids, attention_mask)``
            batches of at most ``batch_size`` rows, in input order.
        """
        encodings = self.tokenizer(
            data,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt',
        )
        dataset = TensorDataset(encodings['input_ids'], encodings['attention_mask'])
        # shuffle=False keeps predictions aligned with the input order.
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=False)

    def predict(
        self,
        data: Sequence[str],
        on_batch: Optional[Callable[[int, int], None]] = None,
    ) -> List[float]:
        """Return the sigmoid of the model logits for every text in ``data``.

        Args:
            data: Texts to score. A single ``str`` is treated as one text.
            on_batch: Optional progress callback, invoked after each batch
                as ``on_batch(n_done, n_total)``.

        Returns:
            A flat list of floats: the sigmoid of ``outputs.logits`` for
            each batch, flattened and concatenated in input order. Empty
            input yields an empty list.
        """
        # A bare string has len() == number of characters, which would
        # make the reported total meaningless; treat it as one text.
        if isinstance(data, str):
            data = [data]

        n_total = len(data)
        if n_total == 0:
            # Nothing to score; skip the tokenizer and the model entirely.
            return []

        test_data = self.load_test_data(data)
        predictions = []
        n_done = 0
        self.model.eval()
        for input_ids, attention_mask in test_data:
            input_ids = input_ids.to(self.device)
            attention_mask = attention_mask.to(self.device)
            with torch.no_grad():
                # Only the logits are consumed below, so hidden states are
                # not requested from the model.
                outputs = self.model(input_ids, attention_mask=attention_mask)
            predictions.extend(torch.sigmoid(outputs.logits).flatten().tolist())
            n_done += input_ids.size(0)
            if on_batch is not None:
                on_batch(min(n_done, n_total), n_total)
        return predictions