File size: 2,726 Bytes
99ec8a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import torch
import torch.nn as nn
import numpy as np

from torch import load
from torch.utils.data import DataLoader
from typing import Union

from model.epu import EPUCNN
from utils.omega_parser import EPUCNNParams
from utils.config_utils import model_cfg_to_epucnn


class EPUCNNEval(EPUCNN):
    """EPUCNN built from an ``EPUCNNParams`` config and switched to eval mode on construction."""

    def __init__(self, epu_cfg: EPUCNNParams):
        # Translate the parsed config into EPUCNN constructor keyword arguments.
        constructor_kwargs = model_cfg_to_epucnn(epu_cfg)
        super().__init__(**constructor_kwargs)
        self.eval()

    def load_ckpt(self, device: torch.device, ckpt_path: str):
        """Load a saved state dict into this model, move it to ``device`` and return ``self``.

        Args:
            device: Passed to ``torch.load`` as ``map_location`` and then used as the
                target of ``self.to``.
            ckpt_path: Path of the checkpoint file holding the state dict.

        Returns:
            This same model instance (in eval mode), so calls can be chained.
        """
        # NOTE(review): torch.load unpickles the file; only point this at trusted checkpoints.
        weights = load(ckpt_path, map_location=device)
        self.load_state_dict(weights)
        self.to(device)
        # Re-assert eval mode after loading, as the constructor does.
        self.eval()
        return self


class InferenceRunnerEPUCNN:
    """Run batched, gradient-free inference with an EPU-CNN model.

    Args:
        epu_model: Model to evaluate; it is called as
            ``epu_model(x, ret_raw_logits=...)`` on every input batch.
        device: Device the input batches ``x`` are moved to before the forward pass.
        mode: Stored on the instance; not read by :meth:`predict`.
    """

    def __init__(self, epu_model: "Union[EPUCNN, nn.Module]", device: torch.device,
                 mode: str = 'binary'):
        self.epu_model = epu_model
        self.device = device
        self.mode = mode

    def predict(self, dataloader: DataLoader, raw_logits: bool = False,
                return_predictions: bool = False) -> dict:
        """Run the model over every ``(x, y)`` batch of ``dataloader`` and gather the outputs.

        Args:
            dataloader: Iterable yielding ``(x, y)`` pairs of tensors.
            raw_logits: Forwarded to the model as ``ret_raw_logits``.
            return_predictions: Despite its name, this flag controls whether the
                *targets* are collected and returned; predictions are always returned.

        Returns:
            A dict with ``"predictions"``: the model outputs concatenated along
            axis 0 as a NumPy array. When ``return_predictions`` is true it also has
            ``"targets"``: each ``y`` cast to float32, unsqueezed at dim 1
            (``[bs]`` -> ``[bs, 1]``) and concatenated along axis 0.

        Raises:
            ValueError: If ``dataloader`` yields no batches.
        """
        # Switches the model to eval mode; its previous train/eval state is not restored.
        self.epu_model.eval()

        all_predictions = []
        all_targets = []

        with torch.no_grad():
            for x, y in dataloader:
                y_hat = self.epu_model(x.to(self.device), ret_raw_logits=raw_logits)
                all_predictions.append(y_hat.detach().cpu().numpy())
                if return_predictions:
                    # Targets are only consumed on the host, so they skip the copy to
                    # self.device (and back) that the forward pass needs for x.
                    all_targets.append(y.to(dtype=torch.float32).unsqueeze(1).cpu().numpy())

        if not all_predictions:
            # np.concatenate([]) would fail with an opaque message; say what went wrong.
            raise ValueError("predict() received a dataloader that yielded no batches")

        results = {"predictions": np.concatenate(all_predictions, axis=0)}
        if return_predictions:
            results["targets"] = np.concatenate(all_targets, axis=0)
        return results