File size: 1,942 Bytes
aa1d2f8
 
5cf843c
 
aa1d2f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5a5bcff
7d99e02
 
 
 
 
ff733bc
 
 
 
fe87420
ff733bc
7d99e02
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import numpy as np
import math
import torch
import torch.nn.functional as F
# Custom sigmoid function
def sigmoid(x):
    """Return the logistic sigmoid ``1 / (1 + e**-x)`` of a scalar ``x``.

    Evaluated in a numerically stable way: ``math.exp`` is only ever called
    with a non-positive argument, so very large ``|x|`` saturates to 0.0 or
    1.0 instead of raising ``OverflowError`` (the naive form overflows once
    ``-x`` exceeds roughly 709).
    """
    if x >= 0:
        return 1 / (1 + math.exp(-x))
    # For negative x rewrite as e**x / (1 + e**x) so the exponent stays <= 0.
    z = math.exp(x)
    return z / (1 + z)

# Vectorized sigmoid: applies the scalar ``sigmoid`` element-wise to array-like input.
# ``otypes=[float]`` fixes the output dtype up front, so numpy skips the extra
# trial call on the first element it otherwise makes to infer the output type,
# and size-0 inputs return an empty float array instead of raising ValueError.
sigmoid_v = np.vectorize(sigmoid, otypes=[float])

def inference(model, dataloader, device):
    """
    Run a classifier over ``dataloader`` and score it by mean class-1 probability.

    Each batch is unpacked as ``(input_ids, attention_mask, labels)``. The
    model is called as ``model(input_ids, attention_mask)`` under
    ``torch.no_grad()`` and is assumed to return a logits tensor of shape
    ``(batch, num_classes)`` with ``num_classes >= 2`` (not checked here).
    Logits are converted to class probabilities with a softmax over the class
    axis. The returned score is the mean probability of class index 1 (the
    "pathogenic" column) over every example seen.

    Args:
        model (torch.nn.Module): The trained BERT model. It is switched to
            evaluation mode via ``model.eval()`` as a side effect.
        dataloader (Iterable): Yields ``(input_ids, attention_mask, labels)``
            tensor triples, e.g. a ``torch.utils.data.DataLoader``.
        device (torch.device): Device the model inputs are moved to before the
            forward pass (e.g. 'cpu' or 'cuda').

    Returns:
        tuple: ``(score, labels_list, logits_list)`` where ``score`` is a
        float (``nan`` if the dataloader yielded no examples),
        ``labels_list`` holds one numpy label entry per example and
        ``logits_list`` holds one numpy logits row per example.
    """
    model.eval()

    logits_list = []
    labels_list = []

    # One no_grad context for the whole pass rather than re-entering it per batch.
    with torch.no_grad():
        for batch in dataloader:
            b_input_ids, b_attn_mask, b_labels = batch
            # Only the model inputs need to live on ``device``; labels are
            # just collected on the CPU.
            logits = model(b_input_ids.to(device), b_attn_mask.to(device))

            # extend() splits each batch into per-example rows.
            logits_list.extend(logits.cpu().numpy())
            labels_list.extend(b_labels.cpu().numpy())

    if not logits_list:
        # Nothing to average: return NaN (like np.mean of an empty array)
        # instead of failing inside softmax on an empty 1-D tensor.
        return float("nan"), labels_list, logits_list

    # Stack into one (num_examples, num_classes) array first; building a tensor
    # directly from a Python list of ndarrays is very slow.
    all_logits = torch.from_numpy(np.stack(logits_list))
    probs = torch.softmax(all_logits, dim=1)
    # Probability of class index 1 ("pathogenic"), averaged over all examples.
    pathogenic_probs = probs[:, 1].numpy()
    score = float(np.mean(pathogenic_probs))

    return score, labels_list, logits_list