Spaces:
Sleeping
| import torch | |
| from torch.utils.data import DataLoader, TensorDataset | |
class PopulationHealthScreener:
    """Score free-text documents with a caller-supplied tokenizer/model pair.

    Each input text is tokenized, run through the model in batches, and
    converted to a probability by applying a sigmoid to the model's logits.
    """

    # Column name callers are presumably expected to read texts from (e.g. a
    # DataFrame column); it is not read anywhere inside this class.
    input_col = 'Abstract'
    # Token length every text is truncated / padded to.
    max_length = 512
    # Number of texts per forward pass.
    batch_size = 8

    def __init__(self, tokenizer, model):
        """Store the tokenizer and model, moving the model to GPU if available.

        Args:
            tokenizer: Callable accepting ``truncation``, ``padding``,
                ``max_length`` and ``return_tensors`` keyword arguments and
                returning a mapping with ``'input_ids'`` and
                ``'attention_mask'`` tensors (Hugging Face-style).
            model: ``torch.nn.Module`` whose forward output exposes a
                ``logits`` tensor.
        """
        self.tokenizer = tokenizer
        self.model = model
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)

    def load_fine_tuned_weights(self, model_path, strict=False):
        """Load a saved ``state_dict`` into the wrapped model.

        Args:
            model_path: Anything ``torch.load`` accepts (path or file object).
            strict: Forwarded to ``load_state_dict``. Defaults to ``False``,
                the historical lenient behaviour.

        Returns:
            The ``load_state_dict`` result, whose ``missing_keys`` and
            ``unexpected_keys`` let callers verify the checkpoint matched.

        Raises:
            RuntimeError: from ``load_state_dict`` when ``strict=True`` and
                the keys do not match.

        Warns:
            UserWarning: when loading leniently and some keys did not match.
                The affected parameters keep their previous values, so a
                mismatched checkpoint would otherwise go unnoticed.
        """
        import warnings

        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        state_dict = torch.load(model_path, map_location=self.device)
        result = self.model.load_state_dict(state_dict, strict=strict)
        missing = list(result.missing_keys)
        unexpected = list(result.unexpected_keys)
        if missing or unexpected:
            # Show only a short preview; a prefix mismatch can list every key.
            warnings.warn(
                f"Checkpoint {model_path!r} did not fully match the model: "
                f"{len(missing)} missing key(s) {missing[:5]}, "
                f"{len(unexpected)} unexpected key(s) {unexpected[:5]}",
                stacklevel=2,
            )
        return result

    @staticmethod
    def _as_text_list(data):
        """Normalize *data* (a str or any iterable of str) to a list of str."""
        # A lone string is one document, not an iterable of characters.
        if isinstance(data, str):
            return [data]
        # Materialize Series / generators / other iterables: Hugging Face-style
        # tokenizers expect a str or a list/tuple batch, and len() is needed.
        return list(data)

    def load_test_data(self, data):
        """Tokenize texts and wrap them in a DataLoader that keeps input order.

        Args:
            data: A single string or any iterable of strings.

        Returns:
            ``DataLoader`` yielding ``(input_ids, attention_mask)`` batches of
            at most ``batch_size`` rows, in input order (``shuffle=False``).
        """
        texts = self._as_text_list(data)
        encodings = self.tokenizer(
            texts, truncation=True, padding='max_length',
            max_length=self.max_length, return_tensors='pt',
        )
        dataset = TensorDataset(encodings['input_ids'], encodings['attention_mask'])
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=False)

    def predict(self, data, on_batch=None):
        """Return one sigmoid score per logit, in input order.

        Puts the model in eval mode as a side effect.

        Args:
            data: A single string or any iterable of strings.
            on_batch: Optional ``callback(done, total)``, called after every
                batch with the number of texts scored so far and the total.

        Returns:
            List of floats; ``[]`` for empty input.
        """
        texts = self._as_text_list(data)
        n_total = len(texts)
        if n_total == 0:
            # Nothing to score; avoid handing an empty batch to tokenizer/model.
            return []
        test_data = self.load_test_data(texts)
        predictions = []
        self.model.eval()
        for i, (input_ids, attention_mask) in enumerate(test_data):
            input_ids = input_ids.to(self.device)
            attention_mask = attention_mask.to(self.device)
            with torch.no_grad():
                # NOTE(review): hidden states are requested but never used here;
                # drop the flag if the model does not depend on it — TODO confirm.
                outputs = self.model(input_ids, attention_mask=attention_mask, output_hidden_states=True)
            # flatten(): presumably one logit per text (binary head); with several
            # logits per row the scores would be interleaved — TODO confirm.
            predictions.extend(torch.sigmoid(outputs.logits).flatten().tolist())
            if on_batch:
                on_batch(min((i + 1) * self.batch_size, n_total), n_total)
        return predictions