| import io |
| import os |
| import math |
| import random |
| import torch.nn as nn |
| from torch.utils import data |
| import argparse |
| import numpy as np |
| from torchvision import transforms |
| from PIL import Image |
| import torch |
| import torch.utils.data as Data |
| from torch.autograd import Variable |
| import torch.nn.functional as F |
| from torch.utils.data import Dataset |
|
|
|
|
|
|
|
|
|
|
def matrix_poly(matrix, d):
    """Return ``(I + matrix / d) ** d`` (matrix power).

    ``d`` is used both as the size of the identity matrix (so ``matrix`` must
    be ``d x d``) and as the exponent. For large ``d`` this approaches the
    matrix exponential of ``matrix``.
    """
    identity = torch.eye(d, device=matrix.device)
    return torch.matrix_power(identity + matrix / d, d)
| |
| |
def _h_A(A, m):
    """Return ``trace((I + (A*A)/m) ** m) - m`` for an ``m x m`` matrix ``A``.

    ``A*A`` is the elementwise square. NOTE(review): this matches the
    polynomial acyclicity constraint used in DAG-GNN-style models, so callers
    presumably use it as such a penalty — confirm.
    """
    return torch.trace(matrix_poly(A * A, m)) - m
| |
|
|
def mask_threshold(x):
    """Round ``x`` by adding 0.5 and truncating; always returns float32.

    The truncation is an int32 cast (toward zero), so this is round-half-up
    only for non-negative inputs (e.g. -0.7 maps to 0).
    """
    shifted = x + 0.5
    return shifted.to(torch.int32).to(torch.float32)
|
|
|
|
class CausalCircuit(Dataset):
    """Causal-circuit images (first view only) with their latent labels.

    Loads ``../data/causal_circuit/train.npz`` when ``dataset == "train"`` and
    ``../data/causal_circuit/test.npz`` otherwise. The path is relative to the
    current working directory; ``root`` is accepted for interface
    compatibility but is not used to locate the file.

    Only samples whose latent column 0, 1 or 2 exceeds 0.4 are kept. Each item
    is ``(image, label)`` where ``label`` holds latent columns 3, 2, 1, 0 (in
    that order) as float32.

    NOTE(review): based on the indexing below, assumes ``original_latents`` is
    (N, views, >=4) and ``imgs`` is (N, views) of encoded image bytes —
    confirm against the data generator.
    """

    # Order in which latent columns are returned by __getitem__.
    _LABEL_ORDER = [3, 2, 1, 0]

    def __init__(self, root, dataset="train"):
        split = "train" if dataset == "train" else "test"
        # The context manager closes the underlying zip archive once the
        # arrays have been materialised.
        with np.load(f"../data/causal_circuit/{split}.npz") as npz:
            latents = npz["original_latents"]
            encoded = npz["imgs"]

        labels = latents[:, 0, :]
        keep = (labels[:, 0] > 0.4) | (labels[:, 1] > 0.4) | (labels[:, 2] > 0.4)
        self.img_labels_1 = labels[keep]
        self.img_labels = self.img_labels_1

        # np.take with the (k, 1) argwhere result yields length-1 byte arrays,
        # which BytesIO reads through the buffer protocol.
        selected = np.take(encoded[:, 0], np.argwhere(keep))
        self.imgs = [Image.open(io.BytesIO(buf)) for buf in selected]
        self.labels = list(self.img_labels)

        self.dataset = dataset
        self.transforms = transforms.Compose([transforms.Resize(128), transforms.ToTensor()])

    def __getitem__(self, idx):
        img = self.imgs[idx]
        # Index the single row first: converting the whole label list to an
        # array on every call made each lookup O(N).
        label = torch.from_numpy(np.asarray(self.labels[idx])[self._LABEL_ORDER])
        if self.transforms:
            img = self.transforms(img)
        return img, label.float()

    def __len__(self):
        return len(self.imgs)
|
|
|
|
| |
class CausalCircuitSimplified(Dataset):
    """Causal-circuit images from both views, restricted to three latent bands.

    Loads ``../data/causal_circuit/train.npz`` when ``dataset == "train"`` and
    ``../data/causal_circuit/test.npz`` otherwise. The path is relative to the
    current working directory; ``root`` is accepted for interface
    compatibility but is not used to locate the file.

    Latents and encoded images of view 0 and view 1 are stacked. Only samples
    matching one of the three masks in ``_regime_masks`` are kept, ordered
    band by band. Each item is ``(image, label)`` where ``label`` holds latent
    columns 3, 2, 1, 0 (in that order) as float32.

    Attributes:
        imgs: list of lazily-decoded PIL images.
        labels: list of kept latent rows, aligned with ``imgs``.
        img_labels: array of the kept latent rows (same order as ``labels``).
        img_labels_1, img_labels_2, img_labels_3: kept rows of each band.
        img_labels_0: view-0 latents before filtering.

    NOTE(review): based on the indexing below, assumes ``original_latents`` is
    (N, >=2, >=4) and ``imgs`` is (N, >=2) of encoded image bytes — confirm.
    """

    # Order in which latent columns are returned by __getitem__.
    _LABEL_ORDER = [3, 2, 1, 0]

    def __init__(self, root, dataset="train"):
        split = "train" if dataset == "train" else "test"
        # The context manager closes the underlying zip archive once the
        # arrays have been materialised.
        with np.load(f"../data/causal_circuit/{split}.npz") as npz:
            latents = npz["original_latents"]
            encoded = npz["imgs"]

        self.img_labels_0 = latents[:, 0, :]
        stacked = np.concatenate((latents[:, 0, :], latents[:, 1, :]))
        masks = self._regime_masks(stacked)

        self.img_labels_1, self.img_labels_2, self.img_labels_3 = (stacked[m] for m in masks)
        self.img_labels = np.concatenate((self.img_labels_1, self.img_labels_2, self.img_labels_3))

        # Indices are concatenated band by band, matching the label order
        # above. np.take with (k, 1) indices yields length-1 byte arrays,
        # which BytesIO reads through the buffer protocol.
        indices = np.concatenate([np.argwhere(m) for m in masks])
        images = np.concatenate((encoded[:, 0], encoded[:, 1]))
        selected = np.take(images, indices)

        self.imgs = [Image.open(io.BytesIO(buf)) for buf in selected]
        self.labels = list(self.img_labels)

        self.dataset = dataset
        self.transforms = transforms.Compose([transforms.Resize(128), transforms.ToTensor()])

    @staticmethod
    def _regime_masks(lab):
        """Return the three boolean row masks (bands of latent column 3) over ``lab``."""
        c0, c1, c2, c3 = lab[:, 0], lab[:, 1], lab[:, 2], lab[:, 3]
        return (
            (c3 > 0.1) & (c3 < 0.4) & (c0 > 0.5) & (c1 > 0.4) & (c2 < 0.2),
            (c3 > 0.4) & (c3 < 0.7) & (c0 > 0.5) & (c2 < 0.2) & (c1 < 0.2),
            (c3 > 0.7) & (c3 < 1) & (c0 > 0.5) & (c2 > 0.4) & (c1 < 0.2),
        )

    def __getitem__(self, idx):
        img = self.imgs[idx]
        # Index the single row first: converting the whole label list to an
        # array on every call made each lookup O(N).
        label = torch.from_numpy(np.asarray(self.labels[idx])[self._LABEL_ORDER])
        if self.transforms:
            img = self.transforms(img)
        return img, label.float()

    def __len__(self):
        return len(self.imgs)
| |
| |
|
|
|
|
class SyntheticLabeled(data.Dataset):
    """Labeled synthetic images read from the directory ``root/dataset``.

    Each file name is parsed as ``<prefix>_<l1>_<l2>_..._<lk>.<ext>``: every
    underscore-separated token after the first is an integer label. Items are
    ``(image_tensor, float_label_tensor)``.
    """

    def __init__(self, root, dataset="train"):
        root = root + "/" + dataset
        names = os.listdir(root)

        self.dataset = dataset
        self.imgs = [os.path.join(root, name) for name in names]
        self.imglabel = self._labels_from_filenames(names)

        # Per-column label statistics over every file in the directory.
        label = np.asarray(self.imglabel)
        self.label_avg = np.mean(label, axis=0)
        self.label_std = np.std(label, axis=0)

        self.transforms = transforms.Compose([transforms.ToTensor()])

    @staticmethod
    def _labels_from_filenames(names):
        """Parse the integer labels encoded in each file name (extension removed)."""
        return [list(map(int, os.path.splitext(name)[0].split("_")[1:])) for name in names]

    def __getitem__(self, idx):
        label = torch.from_numpy(np.asarray(self.imglabel[idx]))
        # The context manager releases the file handle once pixels are read;
        # the image is converted only once, by whichever branch applies.
        with Image.open(self.imgs[idx]) as pil_img:
            if self.transforms:
                img = self.transforms(pil_img)
            else:
                img = torch.from_numpy(np.asarray(pil_img).reshape(96, 96, 4))
        return img, label.float()

    def __len__(self):
        return len(self.imgs)
| |
| |
| |
| |
|
|
def tupleize_data(images):
    """Pair each group's original image path with every other path in that group.

    Args:
        images: iterable of groups, each a list of file paths. Every non-empty
            group must contain a path with ``'orig'`` in its last 10
            characters; that path is the group's original.

    Returns:
        List of ``(original_path, other_path)`` tuples: groups in input order
        and, within a group, the other paths in their listed order.

    Raises:
        ValueError: if a non-empty group contains no original path.

    The input lists are not modified. If a group has several matching paths,
    the first is used as the original and the rest are paired like any other
    member.
    """
    pairs = []
    for group in images:
        orig_pos = next(
            (pos for pos, path in enumerate(group) if 'orig' in path[-10:]), None)
        if orig_pos is None:
            if group:
                # Without an original the group cannot be paired; fail loudly
                # instead of reusing another group's original.
                raise ValueError(
                    "no path containing 'orig' in its last 10 characters in "
                    f"group starting with {group[0]!r}")
            continue
        original = group[orig_pos]
        pairs.extend((original, path) for pos, path in enumerate(group) if pos != orig_pos)
    return pairs
|
|
|
|
class SyntheticPaired(Dataset):
    """Paired (original, intervened) synthetic images from per-sample folders.

    Reads every sub-folder of ``root/dataset``. Within a folder, the original
    and intervened files are paired by ``tupleize_data``. File stems are
    parsed as ``..._<u1>_<u2>_<u3>_<u4>_<k>``: the four integers before the
    last token are latent values, and the last token is the intervention
    index ``k`` (one-hot encoded over 4 classes).

    Items are ``(x, x_int, u_x, u_y, target)``.
    """

    def __init__(self, root, dataset="train", mode="train"):
        root = root + "/" + dataset
        self.mode = mode  # stored for callers; not read inside this class

        groups = []
        for sample in os.listdir(root):
            folder = os.path.join(root, sample)
            groups.append([os.path.join(folder, name) for name in os.listdir(folder)])
        self.data = tupleize_data(groups)

        # Latents of the original (u_x) and intervened (u_y) image of each pair.
        self.u_x = torch.tensor(
            [self._latents_from_path(orig) for orig, _ in self.data], dtype=torch.float32)
        self.u_y = torch.tensor(
            [self._latents_from_path(intv) for _, intv in self.data], dtype=torch.float32)

        # NOTE(review): each entry is a one-element list, so on NumPy >= 1.23
        # I_one_hot has shape (N, 1, 4) and each target is (1, 4) — confirm
        # callers expect the extra axis.
        self.I = [[self._intervention_from_path(intv)] for _, intv in self.data]
        self.I_one_hot = np.eye(4)[self.I].astype('float32')

        self.transforms = transforms.Compose([transforms.ToTensor()])

    @staticmethod
    def _name_tokens(path):
        """Split a file's base name, with its extension removed, on underscores."""
        stem = os.path.splitext(os.path.basename(path))[0]
        return stem.split("_")

    @classmethod
    def _latents_from_path(cls, path):
        """Return the four integer latents encoded before the last name token."""
        return [int(tok) for tok in cls._name_tokens(path)[-5:-1]]

    @classmethod
    def _intervention_from_path(cls, path):
        """Return the intervention index encoded as the last name token."""
        return int(cls._name_tokens(path)[-1])

    @staticmethod
    def _read_image(path):
        """Decode an image file to an ndarray, closing the file afterwards."""
        with Image.open(path) as img:
            return np.asarray(img)

    def __getitem__(self, idx):
        path_x, path_int = self.data[idx]
        x = self._read_image(path_x)
        x_int = self._read_image(path_int)

        target = torch.from_numpy(np.asarray(self.I_one_hot[idx]))

        if self.transforms:
            x = self.transforms(x)
            x_int = self.transforms(x_int)
        else:
            # torch.from_numpy: NumPy has no from_numpy, so the old call failed.
            x = torch.from_numpy(x.reshape(96, 96, 4))
            x_int = torch.from_numpy(x_int.reshape(96, 96, 4))

        return x, x_int, self.u_x[idx], self.u_y[idx], target

    def __len__(self):
        return len(self.data)
|
|
def get_paired_data2(dataset_dir, batch_size, dataset="train", mode="train"):
    """Split ``SyntheticPaired`` sequentially 70/15/15 and wrap each part in a loader.

    The splits are contiguous index ranges in dataset order (no shuffling
    before splitting); each loader shuffles within its own split.

    Returns:
        ``(train_dataset, val_dataset, test_dataset,
        train_loader, val_loader, test_loader)``
    """
    full = SyntheticPaired(dataset_dir, dataset=dataset, mode=mode)
    total = len(full)
    bounds = (0, int(total * 0.7), int(total * 0.85), total)

    subsets = [
        torch.utils.data.Subset(full, list(range(lo, hi)))
        for lo, hi in zip(bounds, bounds[1:])
    ]
    loaders = [Data.DataLoader(part, batch_size=batch_size, shuffle=True) for part in subsets]
    return (*subsets, *loaders)
|
|
|
|
def get_paired_data(dataset_dir, batch_size, dataset="train", mode="train"):
    """Randomly split ``SyntheticPaired`` about 70/15/15 (seed 42) and build loaders.

    Prints the dataset size. The train split gets ``int(0.7 * n) + 1`` items
    (capped at ``n``). The remainder is halved into validation and test;
    when the remainder is odd, test gets the extra item. Both splits use
    fresh generators seeded with 42, so results are reproducible.

    Returns:
        ``(train_dataset, val_dataset, test_dataset,
        train_loader, val_loader, test_loader)``
    """
    dataset = SyntheticPaired(dataset_dir, dataset=dataset, mode=mode)
    total = len(dataset)
    print(total)

    # The old lengths [int(0.7n)+1, int(0.3n)] only summed to n for some n
    # (e.g. n=10 gave 8+3), making random_split raise. Deriving the hold-out
    # as the remainder keeps identical sizes wherever the old code worked.
    n_train = min(int(total * 0.7) + 1, total)
    n_holdout = total - n_train
    train_dataset, holdout = torch.utils.data.random_split(
        dataset, [n_train, n_holdout], generator=torch.Generator().manual_seed(42))

    # Likewise [m/2, m/2] failed for odd m; the extra item goes to test.
    n_val = n_holdout // 2
    val_dataset, test_dataset = torch.utils.data.random_split(
        holdout, [n_val, n_holdout - n_val], generator=torch.Generator().manual_seed(42))

    train_loader = Data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = Data.DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
    test_loader = Data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

    return train_dataset, val_dataset, test_dataset, train_loader, val_loader, test_loader
|
|
|
|
|
|
def get_paired_loaders(train_dataset, val_dataset, test_dataset, batch_size):
    """Wrap three datasets in shuffling DataLoaders with a shared batch size.

    Returns:
        ``(train_loader, val_loader, test_loader)``
    """
    splits = (train_dataset, val_dataset, test_dataset)
    return tuple(
        Data.DataLoader(split, batch_size=batch_size, shuffle=True) for split in splits
    )
|
|
|
|
def get_batch_unin_dataset_withlabel(dataset_dir, batch_size, dataset="train"):
    """Print the size of a ``SyntheticLabeled`` split and return an unshuffled loader over it."""
    labeled = SyntheticLabeled(dataset_dir, dataset)
    print(len(labeled))
    return Data.DataLoader(labeled, batch_size=batch_size, shuffle=False)
|
|
|
|
|
|
def get_circuit_data(dataset_dir, batch_size, dataset="train"):
    """Print the size of a ``CausalCircuit`` split and return a shuffling loader over it."""
    circuit = CausalCircuit(dataset_dir, dataset)
    print(len(circuit))
    return Data.DataLoader(circuit, batch_size=batch_size, shuffle=True)
|
|
|
|
def get_simplified_circuit_data(dataset_dir, batch_size, dataset="train"):
    """Print the size of a ``CausalCircuitSimplified`` split and return a shuffling loader."""
    circuit = CausalCircuitSimplified(dataset_dir, dataset)
    print(len(circuit))
    return Data.DataLoader(circuit, batch_size=batch_size, shuffle=True)
|
|
|
|
|
|
class dataload(data.Dataset):
    """Unlabeled image dataset over every entry of the directory ``root``.

    Items are image tensors produced by ``ToTensor``. If ``self.transforms``
    is falsy, items are raw pixel tensors reshaped to (96, 96, 4).
    """

    def __init__(self, root):
        self.imgs = [os.path.join(root, name) for name in os.listdir(root)]
        self.transforms = transforms.Compose([transforms.ToTensor()])

    def __getitem__(self, idx):
        # The context manager releases the file handle once pixels are read;
        # the image is converted only once, by whichever branch applies.
        with Image.open(self.imgs[idx]) as pil_img:
            if self.transforms:
                return self.transforms(pil_img)
            return torch.from_numpy(np.asarray(pil_img).reshape(96, 96, 4))

    def __len__(self):
        return len(self.imgs)