File size: 1,613 Bytes
d8289f0
 
 
 
 
 
 
 
 
34f8215
d8289f0
 
 
 
 
 
8dd9580
 
 
 
d8289f0
 
 
 
 
 
 
 
c898199
d8289f0
 
c898199
d8289f0
 
 
 
 
 
 
 
 
 
c898199
 
 
d8289f0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import torch
from torch.utils.data import DataLoader, TensorDataset


class PopulationHealthScreener:
    """Score texts with a tokenizer/model pair, one batch at a time.

    Texts are tokenized to fixed-length tensors, fed through the model with
    gradients disabled, and the sigmoid of every logit is returned as the
    score. The model is expected to return an object exposing ``.logits``
    (presumably a Hugging Face classification model -- confirm with the
    caller that constructs this class).
    """

    # Not referenced inside this class; presumably the name of the column
    # callers read the texts from -- confirm against the calling code.
    input_col = 'Abstract'
    # Every text is truncated/padded to exactly this many tokens.
    max_length = 512
    # Number of texts per forward pass (and per progress callback).
    batch_size = 8

    def __init__(self, tokenizer, model):
        """Keep the tokenizer and model, moving the model to the chosen device.

        Args:
            tokenizer: Callable that turns text(s) into a mapping holding
                ``'input_ids'`` and ``'attention_mask'`` tensors.
            model: Module called as ``model(input_ids, attention_mask=...)``.
        """
        self.tokenizer = tokenizer
        self.model = model

        # Prefer the GPU when one is visible, otherwise stay on the CPU.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)

    def load_fine_tuned_weights(self, model_path):
        """Load a saved state dict into the model, tolerating key mismatches.

        Args:
            model_path: Location of the checkpoint, passed to ``torch.load``.

        Returns:
            The result of ``load_state_dict`` (its ``missing_keys`` and
            ``unexpected_keys``), so callers can detect a checkpoint that
            does not actually match the model instead of silently running
            with partly uninitialised weights.
        """
        # NOTE(security): torch.load unpickles the file. On PyTorch versions
        # where ``weights_only`` is not the default this can execute
        # arbitrary code, so only load checkpoints from a trusted source.
        state_dict = torch.load(model_path, map_location=self.device)
        # strict=False ignores missing/unexpected keys; hand the report back
        # rather than discarding it.
        return self.model.load_state_dict(state_dict, strict=False)

    def load_test_data(self, data):
        """Tokenize ``data`` and wrap the tensors in an ordered DataLoader.

        Args:
            data: Text(s) in whatever form the tokenizer accepts.

        Returns:
            A ``DataLoader`` yielding ``(input_ids, attention_mask)`` batches
            of ``batch_size`` rows, in input order (no shuffling).
        """
        encodings = self.tokenizer(
            data, truncation=True, padding='max_length',
            max_length=self.max_length, return_tensors='pt'
        )
        dataset = TensorDataset(encodings['input_ids'], encodings['attention_mask'])
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=False)

    def predict(self, data, on_batch=None):
        """Return one sigmoid score per logit for the given texts.

        Args:
            data: Sized collection of texts, or a single string (treated as
                a one-element collection).
            on_batch: Optional callback invoked after every batch as
                ``on_batch(n_done, n_total)``.

        Returns:
            A flat list of floats in input order. Empty input yields ``[]``
            without touching the tokenizer or the model.
        """
        if isinstance(data, str):
            # A bare string would otherwise be counted per character, which
            # makes the progress totals wrong.
            data = [data]

        n_total = len(data)
        # len() rather than truthiness: array-like inputs may not define a
        # boolean value. Some tokenizers may also reject an empty batch.
        if n_total == 0:
            return []

        test_data = self.load_test_data(data)
        predictions = []
        self.model.eval()

        for i, (input_ids, attention_mask) in enumerate(test_data):
            input_ids = input_ids.to(self.device)
            attention_mask = attention_mask.to(self.device)

            with torch.no_grad():
                # Only the logits are consumed, so hidden states are not
                # requested; that avoids keeping every layer's activations.
                outputs = self.model(input_ids, attention_mask=attention_mask)
                predictions.extend(torch.sigmoid(outputs.logits).flatten().tolist())

            if on_batch:
                # The last batch may be short, so clamp to the real total.
                on_batch(min((i + 1) * self.batch_size, n_total), n_total)

        return predictions