| from torch import nn | |
| import torchvision | |
class BadNet(nn.Module):
    """ResNet-18 backbone followed by a linear classification head.

    Currently only CIFAR-10 is assumed as the target dataset.

    The extra ``nn.Linear`` head is stacked on top of the backbone's own
    classifier output (``self.model.fc.out_features`` features, i.e. the
    ImageNet logits of torchvision's resnet18) instead of replacing
    ``self.model.fc``. That layout is kept so existing checkpoints with
    ``model.*`` and ``fc.*`` state-dict keys still load.
    """

    def __init__(self, output_label, pretrained=True) -> None:
        """Build the backbone and the classification head.

        Args:
            output_label: number of classes produced by ``forward``.
            pretrained: if True (the previous hard-coded behaviour), load
                torchvision's ImageNet weights for resnet18, which may
                trigger a download; if False, use random initialisation.
        """
        super().__init__()
        # torchvision >= 0.13 deprecates ``pretrained=`` in favour of the
        # ``weights=`` enum; ResNet18_Weights.DEFAULT is the same ImageNet
        # checkpoint that ``pretrained=True`` selects. Older releases lack
        # the enum, so fall back to the legacy keyword there.
        weights_enum = getattr(torchvision.models, "ResNet18_Weights", None)
        if weights_enum is not None:
            weights = weights_enum.DEFAULT if pretrained else None
            self.model = torchvision.models.resnet18(weights=weights)
        else:
            self.model = torchvision.models.resnet18(pretrained=pretrained)
        # The head consumes the backbone's classifier *output*, not its input.
        num_features = self.model.fc.out_features
        self.fc = nn.Linear(in_features=num_features, out_features=output_label)

    def forward(self, xs):
        """Return class logits with ``output_label`` columns per sample.

        NOTE(review): ``xs`` is presumably an image batch ``(N, 3, H, W)``
        as resnet18 expects; confirm against the data pipeline.
        """
        out = self.model(xs)
        return self.fc(out)
| # class BadNet(nn.Module): | |
| # def __init__(self, input_channels, output_num): | |
| # super().__init__() | |
| # self.conv1 = nn.Sequential( | |
| # nn.Conv2d(in_channels=input_channels, out_channels=16, kernel_size=5, stride=1), | |
| # nn.ReLU(), | |
| # nn.AvgPool2d(kernel_size=2, stride=2) | |
| # ) | |
| # self.conv2 = nn.Sequential( | |
| # nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1), | |
| # nn.ReLU(), | |
| # nn.AvgPool2d(kernel_size=2, stride=2) | |
| # ) | |
| # fc1_input_features = 800 if input_channels == 3 else 512 | |
| # self.fc1 = nn.Sequential( | |
| # nn.Linear(in_features=fc1_input_features, out_features=512), | |
| # nn.ReLU() | |
| # ) | |
| # self.fc2 = nn.Sequential( | |
| # nn.Linear(in_features=512, out_features=output_num), | |
| # nn.Softmax(dim=-1) | |
| # ) | |
| # self.dropout = nn.Dropout(p=.5) | |
| # def forward(self, x): | |
| # x = self.conv1(x) | |
| # x = self.conv2(x) | |
| # print(x.shape) | |
| # x = x.view(x.size(0), -1) | |
| # x = self.fc1(x) | |
| # x = self.fc2(x) | |
| # return x |