# ius/utils/eval_utils.py
# Synced from GitHub via hub-sync by pgatoula (commit 99ec8a2, verified).
import torch
import torch.nn as nn
import numpy as np
from torch import load
from torch.utils.data import DataLoader
from typing import Union
from model.epu import EPUCNN
from utils.omega_parser import EPUCNNParams
from utils.config_utils import model_cfg_to_epucnn
class EPUCNNEval(EPUCNN):
    """An ``EPUCNN`` built from an ``EPUCNNParams`` config and kept in eval mode.

    Typical use::

        model = EPUCNNEval(cfg).load_ckpt(device, "path/to/ckpt.pt")
    """

    def __init__(self, epu_cfg: EPUCNNParams):
        """Construct the network from *epu_cfg* and switch it to evaluation mode.

        Args:
            epu_cfg: model configuration; ``model_cfg_to_epucnn`` translates it
                into the keyword arguments of ``EPUCNN.__init__``.
        """
        constructor_kwargs = model_cfg_to_epucnn(epu_cfg)
        super().__init__(**constructor_kwargs)
        self.eval()

    def load_ckpt(self, device: torch.device, ckpt_path: str):
        """Load a saved state dict into this model and move it to *device*.

        Args:
            device: target device; also used as ``map_location`` when reading
                the checkpoint file.
            ckpt_path: path to a file containing the model's state dict.

        Returns:
            ``self`` (now on *device* and in eval mode), so calls can be chained.
        """
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints; consider weights_only=True if the installed torch supports it.
        checkpoint_weights = load(ckpt_path, map_location=device)
        self.load_state_dict(checkpoint_weights)
        self.to(device)
        self.eval()
        return self
class InferenceRunnerEPUCNN:
    """Run batched, gradient-free inference with an EPU-CNN style model.

    The wrapped model is called as ``model(x, ret_raw_logits=...)``. Its outputs
    are gathered on the CPU and concatenated into NumPy arrays.
    """

    def __init__(
        self,
        epu_model: "Union[EPUCNN, nn.Module]",
        device: torch.device,
        mode: str = 'binary',
    ):
        """
        Args:
            epu_model: model to run; must accept a ``ret_raw_logits`` keyword.
            device: device each input batch is moved to before the forward
                pass. This class does not move the model itself; presumably the
                caller has already placed it on *device* — confirm at call sites.
            mode: stored as ``self.mode``; not read by any code in this class.
        """
        self.epu_model = epu_model
        self.device = device
        self.mode = mode

    def predict(self, dataloader: DataLoader, raw_logits=False, return_predictions: bool = False):
        """Run the model over every batch of *dataloader* without gradients.

        Args:
            dataloader: yields ``(x, y)`` pairs. ``y`` is assumed to be a
                per-sample label tensor of shape [bs]; it is reshaped to [bs, 1].
            raw_logits: forwarded to the model as ``ret_raw_logits``.
            return_predictions: despite its name, this flag controls whether the
                *targets* are also collected; predictions are always returned.
                The name is kept for backward compatibility.

        Returns:
            dict with ``"predictions"`` (model outputs concatenated along axis 0)
            and, when *return_predictions* is true, ``"targets"`` as a float32
            array of shape [N, 1].

        Raises:
            ValueError: if *dataloader* yields no batches.
        """
        self.epu_model.eval()
        all_predictions = []
        all_targets = []
        with torch.no_grad():
            for x, y in dataloader:
                y_hat = self.epu_model(x.to(self.device), ret_raw_logits=raw_logits)
                all_predictions.append(y_hat.detach().cpu().numpy())
                if return_predictions:
                    # Targets are only consumed host-side, so skip the round trip
                    # to self.device; from [bs] to [bs, 1].
                    all_targets.append(y.to(dtype=torch.float32).unsqueeze(1).cpu().numpy())

        if not all_predictions:
            # np.concatenate([]) would fail with an opaque message; say why instead.
            raise ValueError("predict(): dataloader yielded no batches; nothing to concatenate")

        results = {"predictions": np.concatenate(all_predictions, axis=0)}
        if return_predictions:
            results["targets"] = np.concatenate(all_targets, axis=0)
        return results