Spaces:
Build error
Build error
| #!/usr/bin/env python3 | |
| import os | |
| import torch | |
| import torch.nn as nn | |
| import torchvision | |
| import torchvision.transforms as transforms | |
| from captum.insights import AttributionVisualizer, Batch | |
| from captum.insights.attr_vis.features import ImageFeature | |
def get_classes():
    """Return the ten CIFAR-10 class names shown by the visualizer.

    The order follows the CIFAR-10 label indices (0 = Plane ... 9 = Truck).
    A fresh list is built on every call, so mutating the result does not
    affect later calls.
    """
    return "Plane Car Bird Cat Deer Dog Frog Horse Ship Truck".split()
def get_pretrained_model():
    """Build the small CIFAR-10 CNN and load its bundled weights.

    The weights are read from ``models/cifar_torchvision.pt`` located next to
    this source file.

    Returns:
        The ``Net`` module with the checkpoint's state dict loaded, switched
        to evaluation mode for inference.

    Raises:
        FileNotFoundError: if the checkpoint file does not exist.
    """

    class Net(nn.Module):
        """LeNet-style CNN: two conv/ReLU/pool stages, then three linear layers."""

        def __init__(self) -> None:
            super().__init__()
            self.conv1 = nn.Conv2d(3, 6, 5)
            self.pool1 = nn.MaxPool2d(2, 2)
            self.pool2 = nn.MaxPool2d(2, 2)
            self.conv2 = nn.Conv2d(6, 16, 5)
            self.fc1 = nn.Linear(16 * 5 * 5, 120)
            self.fc2 = nn.Linear(120, 84)
            self.fc3 = nn.Linear(84, 10)
            # One ReLU module per use site (rather than a shared one);
            # presumably so each activation can be hooked separately for
            # layer attribution. The attribute names must match the checkpoint.
            self.relu1 = nn.ReLU()
            self.relu2 = nn.ReLU()
            self.relu3 = nn.ReLU()
            self.relu4 = nn.ReLU()

        def forward(self, x):
            x = self.pool1(self.relu1(self.conv1(x)))
            x = self.pool2(self.relu2(self.conv2(x)))
            # Flatten: for 32x32 inputs the feature map here is 16 x 5 x 5
            # (32 -conv5-> 28 -pool-> 14 -conv5-> 10 -pool-> 5).
            x = x.view(-1, 16 * 5 * 5)
            x = self.relu3(self.fc1(x))
            x = self.relu4(self.fc2(x))
            # Raw logits; no final activation is applied here.
            return self.fc3(x)

    net = Net()
    pt_path = os.path.abspath(
        os.path.join(os.path.dirname(__file__), "models/cifar_torchvision.pt")
    )
    # map_location="cpu" lets a checkpoint saved from GPU tensors load on a
    # CPU-only host. load_state_dict copies into the CPU model either way, so
    # the result is unchanged when a GPU is present.
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(pt_path, map_location="cpu")
    net.load_state_dict(state_dict)
    # The model is only used for inference.
    net.eval()
    return net
def baseline_func(input):
    """Return an all-zero attribution baseline shaped like ``input``.

    Passed to captum as a baseline transform: attributions are measured
    relative to this reference input.

    Args:
        input: the input to build a baseline for. The parameter name shadows
            the builtin but is kept for interface compatibility. For tensors,
            the result keeps the same shape, dtype and device.

    Returns:
        Zeros with the same shape as ``input``.
    """
    if isinstance(input, torch.Tensor):
        # ``input * 0`` is not reliably zero: inf/NaN entries become NaN and
        # negative entries become -0.0. ``zeros_like`` is exactly zero.
        return torch.zeros_like(input)
    # Non-tensor inputs (plain numbers, arrays) keep the original behaviour.
    return input * 0
def formatted_data_iter():
    """Yield CIFAR-10 test-set images as captum ``Batch`` objects.

    The test split is downloaded to ``data/test`` if it is not already
    present. Batches of 4 are yielded in dataset order (no shuffling), and
    the generator finishes once the test set is exhausted.

    Yields:
        ``Batch(inputs=images, labels=labels)``, where ``images`` are produced
        by ``transforms.ToTensor()`` (no normalization applied here).
    """
    dataset = torchvision.datasets.CIFAR10(
        root="data/test", train=False, download=True, transform=transforms.ToTensor()
    )
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=4, shuffle=False, num_workers=2
    )
    # Iterate with ``for`` instead of ``while True: next(...)``. Inside a
    # generator, a StopIteration escaping ``next()`` is converted to
    # RuntimeError (PEP 479), so the old loop crashed at the end of the data.
    for images, labels in dataloader:
        yield Batch(inputs=images, labels=labels)
def main():
    """Assemble the captum Insights visualizer for the CIFAR-10 model and serve it."""
    # Normalization applied to images before they reach the model.
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    model = get_pretrained_model()

    def to_probabilities(logits):
        # Softmax over dimension 1 turns the model's output logits into
        # per-class scores.
        return torch.nn.functional.softmax(logits, 1)

    class_names = get_classes()
    photo_feature = ImageFeature(
        "Photo",
        baseline_transforms=[baseline_func],
        input_transforms=[normalize],
    )
    batches = formatted_data_iter()

    visualizer = AttributionVisualizer(
        models=[model],
        score_func=to_probabilities,
        classes=class_names,
        features=[photo_feature],
        dataset=batches,
    )
    # NOTE(review): debug=True is passed through to captum's server; consider
    # turning it off for a publicly reachable deployment.
    visualizer.serve(debug=True)


if __name__ == "__main__":
    main()