ovedtal committed on
Commit ·
cb9235c
0
Parent(s):
Initial Project
Browse files- 2367801_Final_Project_Winter_25 (2).pdf +0 -0
- code/environment.yml +51 -0
- code/main.py +56 -0
- code/utils.py +54 -0
2367801_Final_Project_Winter_25 (2).pdf
ADDED
|
Binary file (282 kB). View file
|
|
|
code/environment.yml
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: aes
|
| 2 |
+
channels:
|
| 3 |
+
- defaults
|
| 4 |
+
dependencies:
|
| 5 |
+
- _libgcc_mutex=0.1=main
|
| 6 |
+
- _openmp_mutex=5.1=1_gnu
|
| 7 |
+
- ca-certificates=2024.12.31=h06a4308_0
|
| 8 |
+
- ld_impl_linux-64=2.40=h12ee557_0
|
| 9 |
+
- libffi=3.3=he6710b0_2
|
| 10 |
+
- libgcc-ng=11.2.0=h1234567_1
|
| 11 |
+
- libgomp=11.2.0=h1234567_1
|
| 12 |
+
- libstdcxx-ng=11.2.0=h1234567_1
|
| 13 |
+
- ncurses=6.4=h6a678d5_0
|
| 14 |
+
- openssl=1.1.1w=h7f8727e_0
|
| 15 |
+
- pip=25.0=py39h06a4308_0
|
| 16 |
+
- python=3.9.0=hdb3f193_2
|
| 17 |
+
- readline=8.2=h5eee18b_0
|
| 18 |
+
- setuptools=75.8.0=py39h06a4308_0
|
| 19 |
+
- sqlite=3.45.3=h5eee18b_0
|
| 20 |
+
- tk=8.6.14=h39e8969_0
|
| 21 |
+
- tzdata=2025a=h04d1e81_0
|
| 22 |
+
- wheel=0.45.1=py39h06a4308_0
|
| 23 |
+
- xz=5.4.6=h5eee18b_1
|
| 24 |
+
- zlib=1.2.13=h5eee18b_1
|
| 25 |
+
- pip:
|
| 26 |
+
- contourpy==1.3.0
|
| 27 |
+
- cycler==0.12.1
|
| 28 |
+
- filelock==3.17.0
|
| 29 |
+
- fonttools==4.55.8
|
| 30 |
+
- fsspec==2025.2.0
|
| 31 |
+
- importlib-resources==6.5.2
|
| 32 |
+
- jinja2==3.1.5
|
| 33 |
+
- kiwisolver==1.4.7
|
| 34 |
+
- markupsafe==3.0.2
|
| 35 |
+
- matplotlib==3.9.4
|
| 36 |
+
- mpmath==1.3.0
|
| 37 |
+
- networkx==3.2.1
|
| 38 |
+
- numpy==2.0.2
|
| 39 |
+
- packaging==24.2
|
| 40 |
+
- pillow==11.1.0
|
| 41 |
+
- pyparsing==3.2.1
|
| 42 |
+
- python-dateutil==2.9.0.post0
|
| 43 |
+
- six==1.17.0
|
| 44 |
+
- sympy==1.13.1
|
| 45 |
+
- torch==2.6.0
|
| 46 |
+
- torchaudio==2.6.0
|
| 47 |
+
- torchvision==0.21.0
|
| 48 |
+
- triton==3.2.0
|
| 49 |
+
- typing-extensions==4.12.2
|
| 50 |
+
- zipp==3.21.0
|
| 51 |
+
- scikit-learn
|
code/main.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torchvision import datasets, transforms
|
| 3 |
+
import numpy as np
|
| 4 |
+
from matplotlib import pyplot as plt
|
| 5 |
+
from utils import plot_tsne
|
| 6 |
+
import numpy as np
|
| 7 |
+
import random
|
| 8 |
+
import argparse
|
| 9 |
+
|
| 10 |
+
NUM_CLASSES = 10
|
| 11 |
+
|
| 12 |
+
def freeze_seeds(seed=0):
|
| 13 |
+
random.seed(seed)
|
| 14 |
+
np.random.seed(seed)
|
| 15 |
+
torch.manual_seed(seed)
|
| 16 |
+
def get_args():
|
| 17 |
+
parser = argparse.ArgumentParser()
|
| 18 |
+
parser.add_argument('--seed', default=0, type=int, help='Seed for random number generators')
|
| 19 |
+
parser.add_argument('--data-path', default="/datasets/cv_datasets/data", type=str, help='Path to dataset')
|
| 20 |
+
parser.add_argument('--batch-size', default=8, type=int, help='Size of each batch')
|
| 21 |
+
parser.add_argument('--latent-dim', default=128, type=int, help='encoding dimension')
|
| 22 |
+
parser.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu', type=str, help='Default device to use')
|
| 23 |
+
parser.add_argument('--mnist', action='store_true', default=False,
|
| 24 |
+
help='Whether to use MNIST (True) or CIFAR10 (False) data')
|
| 25 |
+
parser.add_argument('--self-supervised', action='store_true', default=False,
|
| 26 |
+
help='Whether train self-supervised with reconstruction objective, or jointly with classifier for classification objective.')
|
| 27 |
+
return parser.parse_args()
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
if __name__ == "__main__":
|
| 31 |
+
transform = transforms.Compose([
|
| 32 |
+
transforms.ToTensor(),
|
| 33 |
+
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) #one possible convenient normalization. You don't have to use it.
|
| 34 |
+
])
|
| 35 |
+
|
| 36 |
+
args = get_args()
|
| 37 |
+
freeze_seeds(args.seed)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
if args.mnist:
|
| 41 |
+
train_dataset = datasets.MNIST(root=args.data_path, train=True, download=False, transform=transform)
|
| 42 |
+
test_dataset = datasets.MNIST(root=args.data_path, train=False, download=False, transform=transform)
|
| 43 |
+
else:
|
| 44 |
+
train_dataset = datasets.CIFAR10(root=args.data_path, train=True, download=True, transform=transform)
|
| 45 |
+
test_dataset = datasets.CIFAR10(root=args.data_path, train=False, download=True, transform=transform)
|
| 46 |
+
|
| 47 |
+
# When you create your dataloader you should split train_dataset or test_dataset to leave some aside for validation
|
| 48 |
+
|
| 49 |
+
#this is just for the example. Simple flattening of the image is probably not the best idea
|
| 50 |
+
encoder_model = torch.nn.Linear(32*32*3,args.latent_dim).to(args.device)
|
| 51 |
+
decoder_model = torch.nn.Linear(args.latent_dim,32*32*3 if args.self_supervised else NUM_CLASSES).to(args.device)
|
| 52 |
+
|
| 53 |
+
sample = train_dataset[0][0][None].to(args.device) #This is just for the example - you should use a dataloader
|
| 54 |
+
output = decoder_model(encoder_model(sample.flatten()))
|
| 55 |
+
print(output.shape)
|
| 56 |
+
|
code/utils.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
import torch
from matplotlib import pyplot as plt  # plot_tsne uses plt; was missing → NameError
from sklearn.manifold import TSNE
|
| 4 |
+
|
| 5 |
+
def plot_tsne(model, dataloader, device):
|
| 6 |
+
'''
|
| 7 |
+
model - torch.nn.Module subclass. This is your encoder model
|
| 8 |
+
dataloader - test dataloader to over over data for which you wish to compute projections
|
| 9 |
+
device - cuda or cpu (as a string)
|
| 10 |
+
'''
|
| 11 |
+
model.eval()
|
| 12 |
+
|
| 13 |
+
images_list = []
|
| 14 |
+
labels_list = []
|
| 15 |
+
latent_list = []
|
| 16 |
+
|
| 17 |
+
with torch.no_grad():
|
| 18 |
+
for data in dataloader:
|
| 19 |
+
images, labels = data
|
| 20 |
+
images, labels = images.to(device), labels.to(device)
|
| 21 |
+
|
| 22 |
+
#approximate the latent space from data
|
| 23 |
+
latent_vector = model(images)
|
| 24 |
+
|
| 25 |
+
images_list.append(images.cpu().numpy())
|
| 26 |
+
labels_list.append(labels.cpu().numpy())
|
| 27 |
+
latent_list.append(latent_vector.cpu().numpy())
|
| 28 |
+
|
| 29 |
+
images = np.concatenate(images_list, axis=0)
|
| 30 |
+
labels = np.concatenate(labels_list, axis=0)
|
| 31 |
+
latent_vectors = np.concatenate(latent_list, axis=0)
|
| 32 |
+
|
| 33 |
+
# Plot TSNE for latent space
|
| 34 |
+
tsne_latent = TSNE(n_components=2, random_state=0)
|
| 35 |
+
latent_tsne = tsne_latent.fit_transform(latent_vectors)
|
| 36 |
+
|
| 37 |
+
plt.figure(figsize=(8, 6))
|
| 38 |
+
scatter = plt.scatter(latent_tsne[:, 0], latent_tsne[:, 1], c=labels, cmap='tab10', s=10) # Smaller points
|
| 39 |
+
plt.colorbar(scatter)
|
| 40 |
+
plt.title('t-SNE of Latent Space')
|
| 41 |
+
plt.savefig('latent_tsne.png')
|
| 42 |
+
plt.close()
|
| 43 |
+
|
| 44 |
+
#plot image domain tsne
|
| 45 |
+
tsne_image = TSNE(n_components=2, random_state=42)
|
| 46 |
+
images_flattened = images.reshape(images.shape[0], -1)
|
| 47 |
+
image_tsne = tsne_image.fit_transform(images_flattened)
|
| 48 |
+
|
| 49 |
+
plt.figure(figsize=(8, 6))
|
| 50 |
+
scatter = plt.scatter(image_tsne[:, 0], image_tsne[:, 1], c=labels, cmap='tab10', s=10)
|
| 51 |
+
plt.colorbar(scatter)
|
| 52 |
+
plt.title('t-SNE of Image Space')
|
| 53 |
+
plt.savefig('image_tsne.png')
|
| 54 |
+
plt.close()
|