File size: 902 Bytes
6039b52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch
from torch.nn import Module
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np

class ModelTester:
    """Run a model over a dataloader without gradients and collect its outputs."""

    def __init__(
        self,
        model: Module,
        device: torch.device,
        show_prgress_bar: bool = True
    ) -> None:
        """Store the model and the inference settings.

        Args:
            model: Module that is called on every batch.
            device: Device each batch tensor is moved to before the forward
                pass. The model itself is not moved by this class.
            show_prgress_bar: When true, iteration is wrapped in a tqdm
                progress bar. The (misspelled) name is kept as-is because it
                is part of the public interface.
        """
        self.model = model
        self.device = device
        self.show_prgress_bar = show_prgress_bar

    def test(self, dataloader: DataLoader) -> np.ndarray:
        """Run the model on every batch and concatenate the predictions.

        The model is switched to eval mode and is NOT switched back
        afterwards. The forward passes run under ``torch.no_grad()``.

        Args:
            dataloader: Iterable of batches. Each batch is either a sequence
                of tensors, which are passed to the model as positional
                arguments, or a single tensor, which is passed as the only
                argument.

        Returns:
            The per-batch model outputs, converted to NumPy arrays and
            concatenated along axis 0.

        Raises:
            ValueError: If ``dataloader`` yields no batches.
        """
        self.model.eval()
        outputs = []

        batches = dataloader
        if self.show_prgress_bar:
            # tqdm infers the total from len() on its own when the iterable
            # supports it; calling len() explicitly would raise for loaders
            # that have no length (e.g. backed by an iterable-style dataset).
            batches = tqdm(dataloader, desc="embedding")

        with torch.no_grad():
            for batch in batches:
                # A bare tensor must be wrapped; iterating it directly would
                # split it along its first dimension into separate arguments.
                if isinstance(batch, torch.Tensor):
                    batch = (batch,)
                inputs = [item.to(self.device) for item in batch]
                pred: torch.Tensor = self.model(*inputs)
                outputs.append(pred.cpu().numpy())

        if not outputs:
            # Same exception type np.concatenate would raise on an empty
            # list, but with a message that points at the actual cause.
            raise ValueError(
                "dataloader yielded no batches; nothing to concatenate")
        return np.concatenate(outputs, axis=0)