| import torch |
| import torch.nn as nn |
| import numpy as np |
|
|
| from torch import load |
| from torch.utils.data import DataLoader |
| from typing import Union |
|
|
| from model.epu import EPUCNN |
| from utils.omega_parser import EPUCNNParams |
| from utils.config_utils import model_cfg_to_epucnn |
|
|
|
|
class EPUCNNEval(EPUCNN):
    """EPUCNN variant built from a config object and kept in eval mode.

    Both construction and checkpoint loading switch the module to eval mode,
    so instances are ready for inference.
    """

    def __init__(self, epu_cfg: EPUCNNParams):
        """Build the network from ``epu_cfg`` and switch it to eval mode.

        Args:
            epu_cfg: Parsed model parameters; ``model_cfg_to_epucnn`` turns
                them into the keyword arguments passed to ``EPUCNN.__init__``.
        """
        ctor_kwargs = model_cfg_to_epucnn(epu_cfg)
        super().__init__(**ctor_kwargs)
        self.eval()

    def load_ckpt(self, device: torch.device, ckpt_path: str):
        """Load weights from ``ckpt_path`` onto ``device`` and return ``self``.

        The loaded object is handed straight to ``load_state_dict``. The
        checkpoint is therefore presumably a bare state dict rather than a
        wrapper dict holding optimizer state etc. (TODO confirm against how
        checkpoints are saved).

        NOTE(review): ``torch.load`` unpickles the file; only load
        checkpoints from trusted sources.

        Args:
            device: Target device. Used both as ``map_location`` and for
                moving the module.
            ckpt_path: Path of the checkpoint file.

        Returns:
            This module, on ``device`` and in eval mode.
        """
        checkpoint_weights = load(ckpt_path, map_location=device)
        self.load_state_dict(checkpoint_weights)
        # Place parameters on the target device, then re-assert inference mode.
        self.to(device)
        self.eval()
        return self
|
|
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
class InferenceRunnerEPUCNN:
    """Run batched, gradient-free inference with an EPU-CNN style model.

    The wrapped model is called as ``model(x, ret_raw_logits=...)`` on every
    batch. Its outputs are gathered on the host as NumPy arrays.
    """

    def __init__(self, epu_model: Union[EPUCNN, nn.Module], device: torch.device, mode: str = 'binary'):
        """
        Args:
            epu_model: Model to evaluate. Its forward call must accept a
                ``ret_raw_logits`` keyword argument. It is not moved to
                ``device`` here; the caller is expected to have placed it there.
            device: Device that each input batch is moved to before the
                forward pass.
            mode: Stored on the instance; not read by any code in this class.
        """
        self.epu_model = epu_model
        self.device = device
        self.mode = mode

    def predict(self, dataloader: DataLoader, raw_logits=False, return_predictions: bool = False):
        """Run the model over every batch of ``dataloader``.

        Side effect: switches ``self.epu_model`` to eval mode and leaves it
        in eval mode after the call.

        Args:
            dataloader: Iterable yielding ``(x, y)`` pairs of tensors.
            raw_logits: Forwarded to the model as ``ret_raw_logits``.
            return_predictions: Despite its name, this controls whether the
                ground-truth *targets* are collected. Predictions are always
                returned.

        Returns:
            A dict with these keys:

            - ``"predictions"``: model outputs concatenated along axis 0.
            - ``"targets"``: only when ``return_predictions`` is true. Labels
              as float32 with an extra trailing axis, i.e. shape ``(N, 1)``
              for 1-D labels.

        Raises:
            ValueError: If ``dataloader`` yields no batches.
        """
        self.epu_model.eval()

        prediction_chunks = []
        target_chunks = []

        with torch.no_grad():
            for x, y in dataloader:
                y_hat = self.epu_model(x.to(self.device), ret_raw_logits=raw_logits)
                prediction_chunks.append(y_hat.detach().cpu().numpy())
                if return_predictions:
                    # Labels never reach the model, so convert them on the host
                    # instead of round-tripping them through ``self.device``.
                    target_chunks.append(
                        y.detach().to("cpu", dtype=torch.float32).unsqueeze(1).numpy()
                    )

        if not prediction_chunks:
            # np.concatenate would fail here with an opaque message.
            raise ValueError(
                "predict() received a dataloader that yielded no batches; "
                "nothing to concatenate."
            )

        results = {"predictions": np.concatenate(prediction_chunks, axis=0)}
        if return_predictions:
            results["targets"] = np.concatenate(target_chunks, axis=0)
        return results
|
|