File size: 2,726 Bytes
99ec8a2 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 | import torch
import torch.nn as nn
import numpy as np
from torch import load
from torch.utils.data import DataLoader
from typing import Union
from model.epu import EPUCNN
from utils.omega_parser import EPUCNNParams
from utils.config_utils import model_cfg_to_epucnn
class EPUCNNEval(EPUCNN):
    """EPUCNN variant constructed from a config object and kept in evaluation mode.

    Args:
        epu_cfg: Model configuration; converted into ``EPUCNN`` constructor
            keyword arguments by ``model_cfg_to_epucnn``.
    """

    def __init__(self, epu_cfg: EPUCNNParams):
        super().__init__(**model_cfg_to_epucnn(epu_cfg))
        # Start in eval mode; load_ckpt re-applies it after loading weights.
        self.eval()

    def load_ckpt(self, device: Union[torch.device, str], ckpt_path: str, *, strict: bool = True) -> "EPUCNNEval":
        """Load a saved state dict into this model and move the model to ``device``.

        Args:
            device: Target device. Also passed as ``map_location`` so the
                checkpoint tensors are materialized on it directly.
            ckpt_path: Path to a checkpoint readable by ``torch.load`` whose
                content is a state dict for this model.
            strict: Forwarded to ``load_state_dict``. The default ``True``
                keeps the previous behaviour: missing or unexpected keys
                raise an error.

        Returns:
            ``self`` (now on ``device`` and in eval mode), to allow chaining.
        """
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from trusted sources, or pass
        # weights_only=True if the installed torch version supports it.
        state_dict = load(ckpt_path, map_location=device)
        self.load_state_dict(state_dict, strict=strict)
        self.to(device)
        self.eval()
        return self
class InferenceRunnerEPUCNN:
    """Run batched, gradient-free inference with an EPU-CNN style model.

    Args:
        epu_model: Model to evaluate. ``predict`` calls it as
            ``epu_model(x, ret_raw_logits=...)``, so it must accept that
            keyword argument.
        device: Device each input batch is moved to before the forward pass.
        mode: Stored on the instance; ``predict`` does not read it.
    """

    def __init__(self, epu_model: Union["EPUCNN", nn.Module], device: torch.device, mode: str = 'binary'):
        # "EPUCNN" is a string forward reference so the annotation is not
        # evaluated when the class is defined.
        self.epu_model = epu_model
        self.device = device
        self.mode = mode

    def predict(self, dataloader: DataLoader, raw_logits: bool = False, return_predictions: bool = False) -> dict:
        """Run the model over every ``(x, y)`` batch of ``dataloader``.

        Note: this switches ``epu_model`` to eval mode and leaves it there.

        Args:
            dataloader: Iterable yielding ``(x, y)`` pairs of tensors.
            raw_logits: Forwarded to the model as ``ret_raw_logits``.
            return_predictions: Despite its name, this flag controls whether
                the targets are collected; predictions are always returned.

        Returns:
            Dict with ``"predictions"`` (model outputs concatenated along
            axis 0, as NumPy) and, when ``return_predictions`` is True,
            ``"targets"`` (float32 labels reshaped with ``unsqueeze(1)``;
            assumes per-batch labels are 1-D, i.e. ``[bs] -> [bs, 1]``).

        Raises:
            ValueError: If ``dataloader`` yields no batches.
        """
        self.epu_model.eval()
        all_predictions = []
        all_targets = []
        with torch.no_grad():
            for x, y in dataloader:
                y_hat = self.epu_model(x.to(self.device), ret_raw_logits=raw_logits)
                # detach() guards against models that re-enable grad internally.
                all_predictions.append(y_hat.detach().cpu().numpy())
                if return_predictions:
                    # Targets never enter the model, so convert them on the host
                    # instead of round-tripping through self.device.
                    all_targets.append(y.to(dtype=torch.float32).unsqueeze(1).cpu().numpy())

        if not all_predictions:
            # np.concatenate would otherwise fail with an opaque message.
            raise ValueError("dataloader yielded no batches; nothing to predict")

        results = {"predictions": np.concatenate(all_predictions, axis=0)}
        if return_predictions:
            results["targets"] = np.concatenate(all_targets, axis=0)
        return results
|