| |
| |
| |
|
|
| import os |
| import numpy as np |
| import torch |
| import torch.nn as nn |
|
|
# Public API of this module: only the predictor class is exported by ``from ... import *``.
__all__ = ["AccuracyPredictor"]
|
|
|
|
class AccuracyPredictor(nn.Module):
    """MLP that regresses an accuracy value from an encoded architecture.

    The network is ``n_layers`` blocks of ``Linear -> ReLU`` followed by a
    bias-free ``Linear`` producing one value per sample; a non-trainable
    offset ``base_acc`` is added to that output.

    Args:
        arch_encoder: Object exposing ``n_dim`` (input feature length) and
            ``arch2feature(arch_dict)`` (feature vector for one architecture).
        hidden_size: Width of every hidden ``Linear`` layer.
        n_layers: Number of hidden ``Linear -> ReLU`` blocks (0 is allowed:
            the output head then consumes the raw features).
        checkpoint_path: Optional path to a saved state dict, possibly wrapped
            under a ``"state_dict"`` key. Loaded (strictly) if the file exists;
            a notice is printed if the path is given but missing.
        device: Device on which the layers and ``base_acc`` are placed.
        base_acc_val: Optional initial value for ``base_acc`` (a scalar or any
            single-element array-like); defaults to 0. A loaded checkpoint
            overrides it.

    Raises:
        RuntimeError: if ``base_acc_val`` has more than one element.
    """

    def __init__(
        self,
        arch_encoder,
        hidden_size=400,
        n_layers=3,
        checkpoint_path=None,
        device="cuda:0",
        base_acc_val=None,
    ):
        super().__init__()
        self.arch_encoder = arch_encoder
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.device = device
        self.base_acc_val = base_acc_val

        # Module indices (layers.<i>.0 for hidden blocks, layers.<n_layers> for
        # the head) are kept identical to the original layout so existing
        # checkpoints still load.
        layers = []
        in_features = self.arch_encoder.n_dim
        for _ in range(self.n_layers):
            layers.append(
                nn.Sequential(
                    nn.Linear(in_features, self.hidden_size),
                    nn.ReLU(inplace=True),
                )
            )
            in_features = self.hidden_size
        # Track the head's input width so n_layers == 0 feeds n_dim features.
        layers.append(nn.Linear(in_features, 1, bias=False))
        self.layers = nn.Sequential(*layers)

        # Always store base_acc as a 1-element tensor of the default float
        # dtype, matching the zero-initialised case. This keeps checkpoints
        # interchangeable between models built with and without base_acc_val,
        # and prevents an int value from creating an integer parameter that
        # would truncate a float base_acc loaded from a checkpoint.
        if self.base_acc_val is not None:
            init = torch.as_tensor(self.base_acc_val, dtype=torch.get_default_dtype())
            # reshape(1) rejects multi-element inputs; clone avoids aliasing
            # the caller's tensor/array.
            init = init.detach().reshape(1).clone()
        else:
            init = torch.zeros(1)
        self.base_acc = nn.Parameter(init.to(self.device), requires_grad=False)

        if checkpoint_path is not None:
            if os.path.exists(checkpoint_path):
                # NOTE(review): torch.load unpickles the file (fully so on
                # torch versions where weights_only defaults to False); only
                # load trusted checkpoints.
                checkpoint = torch.load(checkpoint_path, map_location="cpu")
                if "state_dict" in checkpoint:
                    checkpoint = checkpoint["state_dict"]
                self.load_state_dict(checkpoint)
                print("Loaded checkpoint from %s" % checkpoint_path)
            else:
                # Best-effort: keep fresh weights, but do not fail silently.
                print(
                    "Checkpoint %s not found; keeping freshly initialised weights"
                    % checkpoint_path
                )

        self.layers = self.layers.to(self.device)

    def forward(self, x):
        """Return predicted accuracies for a batch of feature vectors.

        Args:
            x: Float tensor whose last dimension is ``arch_encoder.n_dim``
                (presumably ``(batch, n_dim)``), on ``self.device``.

        Returns:
            The head output with its trailing size-1 dimension removed, plus
            ``base_acc``; for ``(batch, n_dim)`` input this is ``(batch,)``.
        """
        # squeeze(-1) instead of squeeze(): a batch of one must stay 1-D.
        y = self.layers(x).squeeze(-1)
        return y + self.base_acc

    def predict_acc(self, arch_dict_list):
        """Encode each architecture and return its predicted accuracy.

        Args:
            arch_dict_list: Iterable of architecture descriptions accepted by
                ``arch_encoder.arch2feature``.

        Returns:
            One prediction per entry (assuming ``arch2feature`` returns a
            1-D vector of length ``n_dim``); an empty input yields an empty
            tensor on ``base_acc``'s device.
        """
        features = [self.arch_encoder.arch2feature(arch) for arch in arch_dict_list]
        if not features:
            # np.array([]) has no feature dimension and would crash nn.Linear.
            return self.base_acc.new_empty(0)
        X = torch.as_tensor(np.asarray(features), dtype=torch.float32, device=self.device)
        return self.forward(X)
|
|