File size: 1,005 Bytes
957e2dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import torch
from torch import nn

################################################################################
# Wrapper for all PyTorch audio classifiers
################################################################################


class Model(nn.Module):
    """Base wrapper for PyTorch audio classifiers.

    Gives attack algorithms and prediction code one uniform interface to
    program against. Every hook defined here is a stub that raises
    ``NotImplementedError``; concrete classifiers are expected to override
    them.
    """

    def __init__(self):
        """Initialize the underlying ``nn.Module`` state."""
        super(Model, self).__init__()

    def forward(self, x: torch.Tensor):
        """Run a forward pass of the classifier on ``x``.

        The base implementation always raises ``NotImplementedError``;
        subclasses must provide the actual computation.
        """
        raise NotImplementedError

    def load_weights(self, path: str):
        """Load model weights from the checkpoint file at ``path``.

        The base implementation always raises ``NotImplementedError``;
        subclasses must provide the actual loading logic.
        """
        raise NotImplementedError

    @staticmethod
    def match_predict(y_pred: torch.Tensor, y_true: torch.Tensor):
        """Decide whether the predicted and true targets are equivalent.

        Static so it can be called without an instance. The base
        implementation always raises ``NotImplementedError``; subclasses
        must supply the comparison rule.
        """
        raise NotImplementedError