Spaces:
Runtime error
Runtime error
changed file names
Browse files
app.py
CHANGED
|
@@ -29,7 +29,7 @@ import base64
|
|
| 29 |
# else:
|
| 30 |
# print("Google Drive is already mounted.")
|
| 31 |
|
| 32 |
-
list_c1 = torch.load('
|
| 33 |
|
| 34 |
class CustomDataset(torch.utils.data.Dataset):
|
| 35 |
def __init__(self, data):
|
|
@@ -54,7 +54,7 @@ def get_images():
|
|
| 54 |
pil_images = [transform_to_pil(image) for image in images]
|
| 55 |
return pil_images, labels.tolist()
|
| 56 |
|
| 57 |
-
list_c2 = torch.load('
|
| 58 |
dataset_c2 = CustomDataset(list_c2)
|
| 59 |
dataloader_c2 = torch.utils.data.DataLoader(dataset_c2, batch_size=10, shuffle=True)
|
| 60 |
def get_images_2():
|
|
@@ -173,7 +173,7 @@ class Network(nn.Module):
|
|
| 173 |
loaded_model_non_dann = Network()
|
| 174 |
loaded_model_non_dann = loaded_model_non_dann.to(device)
|
| 175 |
# Load the saved state dictionary
|
| 176 |
-
loaded_model_non_dann.load_state_dict(torch.load('
|
| 177 |
loaded_model_non_dann.eval()
|
| 178 |
|
| 179 |
## DANN
|
|
@@ -181,7 +181,7 @@ loaded_model_non_dann.eval()
|
|
| 181 |
loaded_model_dann = Network()
|
| 182 |
loaded_model_dann = loaded_model_dann.to(device)
|
| 183 |
# Load the saved state dictionary
|
| 184 |
-
loaded_model_dann.load_state_dict(torch.load('
|
| 185 |
loaded_model_dann.eval()
|
| 186 |
|
| 187 |
img_size = 28 # for mnist
|
|
|
|
| 29 |
# else:
|
| 30 |
# print("Google Drive is already mounted.")
|
| 31 |
|
| 32 |
+
list_c1 = torch.load('list_mnist_m_non_dann_misclassified_dann_classified_08_07.pt')
|
| 33 |
|
| 34 |
class CustomDataset(torch.utils.data.Dataset):
|
| 35 |
def __init__(self, data):
|
|
|
|
| 54 |
pil_images = [transform_to_pil(image) for image in images]
|
| 55 |
return pil_images, labels.tolist()
|
| 56 |
|
| 57 |
+
list_c2 = torch.load('list_mnist_m_non_dann_misclassified_dann_misclassified_08_07.pt')
|
| 58 |
dataset_c2 = CustomDataset(list_c2)
|
| 59 |
dataloader_c2 = torch.utils.data.DataLoader(dataset_c2, batch_size=10, shuffle=True)
|
| 60 |
def get_images_2():
|
|
|
|
| 173 |
loaded_model_non_dann = Network()
|
| 174 |
loaded_model_non_dann = loaded_model_non_dann.to(device)
|
| 175 |
# Load the saved state dictionary
|
| 176 |
+
loaded_model_non_dann.load_state_dict(torch.load('non_dann_08_07.pt', map_location=device), strict=False)
|
| 177 |
loaded_model_non_dann.eval()
|
| 178 |
|
| 179 |
## DANN
|
|
|
|
| 181 |
loaded_model_dann = Network()
|
| 182 |
loaded_model_dann = loaded_model_dann.to(device)
|
| 183 |
# Load the saved state dictionary
|
| 184 |
+
loaded_model_dann.load_state_dict(torch.load('dann_08_07.pt', map_location=device), strict=False)
|
| 185 |
loaded_model_dann.eval()
|
| 186 |
|
| 187 |
img_size = 28 # for mnist
|