ovedtal committed on
Commit
cb9235c
·
0 Parent(s):

Initial Project

Browse files
2367801_Final_Project_Winter_25 (2).pdf ADDED
Binary file (282 kB). View file
 
code/environment.yml ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: aes
2
+ channels:
3
+ - defaults
4
+ dependencies:
5
+ - _libgcc_mutex=0.1=main
6
+ - _openmp_mutex=5.1=1_gnu
7
+ - ca-certificates=2024.12.31=h06a4308_0
8
+ - ld_impl_linux-64=2.40=h12ee557_0
9
+ - libffi=3.3=he6710b0_2
10
+ - libgcc-ng=11.2.0=h1234567_1
11
+ - libgomp=11.2.0=h1234567_1
12
+ - libstdcxx-ng=11.2.0=h1234567_1
13
+ - ncurses=6.4=h6a678d5_0
14
+ - openssl=1.1.1w=h7f8727e_0
15
+ - pip=25.0=py39h06a4308_0
16
+ - python=3.9.0=hdb3f193_2
17
+ - readline=8.2=h5eee18b_0
18
+ - setuptools=75.8.0=py39h06a4308_0
19
+ - sqlite=3.45.3=h5eee18b_0
20
+ - tk=8.6.14=h39e8969_0
21
+ - tzdata=2025a=h04d1e81_0
22
+ - wheel=0.45.1=py39h06a4308_0
23
+ - xz=5.4.6=h5eee18b_1
24
+ - zlib=1.2.13=h5eee18b_1
25
+ - pip:
26
+ - contourpy==1.3.0
27
+ - cycler==0.12.1
28
+ - filelock==3.17.0
29
+ - fonttools==4.55.8
30
+ - fsspec==2025.2.0
31
+ - importlib-resources==6.5.2
32
+ - jinja2==3.1.5
33
+ - kiwisolver==1.4.7
34
+ - markupsafe==3.0.2
35
+ - matplotlib==3.9.4
36
+ - mpmath==1.3.0
37
+ - networkx==3.2.1
38
+ - numpy==2.0.2
39
+ - packaging==24.2
40
+ - pillow==11.1.0
41
+ - pyparsing==3.2.1
42
+ - python-dateutil==2.9.0.post0
43
+ - six==1.17.0
44
+ - sympy==1.13.1
45
+ - torch==2.6.0
46
+ - torchaudio==2.6.0
47
+ - torchvision==0.21.0
48
+ - triton==3.2.0
49
+ - typing-extensions==4.12.2
50
+ - zipp==3.21.0
51
+ - scikit-learn
code/main.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchvision import datasets, transforms
3
+ import numpy as np
4
+ from matplotlib import pyplot as plt
5
+ from utils import plot_tsne
6
+ import numpy as np
7
+ import random
8
+ import argparse
9
+
10
+ NUM_CLASSES = 10
11
+
12
def freeze_seeds(seed=0):
    """Seed every RNG used in this project for reproducible runs.

    Covers Python's `random`, NumPy, and PyTorch. Also seeds all CUDA
    devices (a no-op on CPU-only machines) so GPU runs are reproducible
    too — `torch.manual_seed` alone does not cover multi-GPU setups.

    Args:
        seed: integer seed applied to every generator (default 0).
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # safe no-op without CUDA
16
def get_args():
    """Define and parse this script's command-line interface.

    Returns an `argparse.Namespace` with seed, data path, batch size,
    latent dimension, device, and the two boolean mode flags.
    """
    p = argparse.ArgumentParser()
    p.add_argument('--seed', default=0, type=int, help='Seed for random number generators')
    p.add_argument('--data-path', default="/datasets/cv_datasets/data", type=str, help='Path to dataset')
    p.add_argument('--batch-size', default=8, type=int, help='Size of each batch')
    p.add_argument('--latent-dim', default=128, type=int, help='encoding dimension')
    p.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu', type=str, help='Default device to use')
    p.add_argument('--mnist', action='store_true', default=False,
                   help='Whether to use MNIST (True) or CIFAR10 (False) data')
    p.add_argument('--self-supervised', action='store_true', default=False,
                   help='Whether train self-supervised with reconstruction objective, or jointly with classifier for classification objective.')
    return p.parse_args()
28
+
29
+
30
if __name__ == "__main__":
    args = get_args()
    freeze_seeds(args.seed)

    # BUG FIX: the original hard-coded 3-channel normalization stats and a
    # 32*32*3 input dimension, which crashes when --mnist is passed
    # (MNIST images are 1x28x28, CIFAR10 images are 3x32x32). Pick the
    # channel count and side length per dataset instead.
    num_channels, side = (1, 28) if args.mnist else (3, 32)
    transform = transforms.Compose([
        transforms.ToTensor(),
        # One possible convenient normalization. You don't have to use it.
        transforms.Normalize(mean=[0.5] * num_channels, std=[0.5] * num_channels)
    ])

    if args.mnist:
        train_dataset = datasets.MNIST(root=args.data_path, train=True, download=False, transform=transform)
        test_dataset = datasets.MNIST(root=args.data_path, train=False, download=False, transform=transform)
    else:
        train_dataset = datasets.CIFAR10(root=args.data_path, train=True, download=True, transform=transform)
        test_dataset = datasets.CIFAR10(root=args.data_path, train=False, download=True, transform=transform)

    # When you create your dataloader you should split train_dataset or
    # test_dataset to leave some aside for validation.

    input_dim = side * side * num_channels
    # This is just for the example. Simple flattening of the image is
    # probably not the best idea.
    encoder_model = torch.nn.Linear(input_dim, args.latent_dim).to(args.device)
    decoder_model = torch.nn.Linear(args.latent_dim, input_dim if args.self_supervised else NUM_CLASSES).to(args.device)

    # This is just for the example - you should use a dataloader.
    sample = train_dataset[0][0][None].to(args.device)
    output = decoder_model(encoder_model(sample.flatten()))
    print(output.shape)
56
+
code/utils.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from sklearn.manifold import TSNE
4
+
5
def plot_tsne(model, dataloader, device):
    '''
    Compute 2-D t-SNE projections of both the latent space and the raw
    image space over an entire dataloader, and save them as PNG files
    ('latent_tsne.png' and 'image_tsne.png') in the working directory.

    model - torch.nn.Module subclass. This is your encoder model
    dataloader - test dataloader to iterate over data for which you wish to compute projections
    device - cuda or cpu (as a string)
    '''
    # BUG FIX: utils.py never imported matplotlib.pyplot, so every `plt.*`
    # call below raised NameError. Import it here (function scope keeps the
    # fix self-contained).
    from matplotlib import pyplot as plt

    model.eval()

    images_list = []
    labels_list = []
    latent_list = []

    with torch.no_grad():
        for data in dataloader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)

            # Approximate the latent space from data.
            latent_vector = model(images)

            images_list.append(images.cpu().numpy())
            labels_list.append(labels.cpu().numpy())
            latent_list.append(latent_vector.cpu().numpy())

    images = np.concatenate(images_list, axis=0)
    labels = np.concatenate(labels_list, axis=0)
    latent_vectors = np.concatenate(latent_list, axis=0)

    # Plot t-SNE for the latent space.
    tsne_latent = TSNE(n_components=2, random_state=0)
    latent_tsne = tsne_latent.fit_transform(latent_vectors)

    plt.figure(figsize=(8, 6))
    scatter = plt.scatter(latent_tsne[:, 0], latent_tsne[:, 1], c=labels, cmap='tab10', s=10)  # Smaller points
    plt.colorbar(scatter)
    plt.title('t-SNE of Latent Space')
    plt.savefig('latent_tsne.png')
    plt.close()

    # Plot image-domain t-SNE (images flattened to vectors first).
    tsne_image = TSNE(n_components=2, random_state=42)
    images_flattened = images.reshape(images.shape[0], -1)
    image_tsne = tsne_image.fit_transform(images_flattened)

    plt.figure(figsize=(8, 6))
    scatter = plt.scatter(image_tsne[:, 0], image_tsne[:, 1], c=labels, cmap='tab10', s=10)
    plt.colorbar(scatter)
    plt.title('t-SNE of Image Space')
    plt.savefig('image_tsne.png')
    plt.close()