m23csa016 commited on
Commit
e760b35
·
1 Parent(s): 00045fc

Added data files

Browse files
cycleGAN/data/.gitkeep ADDED
File without changes
cycleGAN/data/__init__.py ADDED
File without changes
cycleGAN/data/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (146 Bytes). View file
 
cycleGAN/data/__pycache__/customdataset.cpython-312.pyc ADDED
Binary file (4.46 kB). View file
 
cycleGAN/data/customdataset.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import sys
3
+ import torch
4
+ from torch.utils.data import Dataset, DataLoader
5
+ from PIL import Image, ImageFilter
6
+ import pandas as pd
7
+ import numpy as np
8
+ import os
9
+
10
+ # sys.path.append("/iitjhome/m23csa016/DLASS4")
11
+ sys.path.append("data")
12
+ TRAIN_LABELS = "data/Assignment_4/Train/Train_labels.csv"
13
+ TEST_LABELS = "data/Assignment_4/Test/Test_Labels.csv"
14
+
15
+ TRAIN_DATA_DIR = "data/Assignment_4/Train/Train_data"
16
+ TEST_DATA_DIR = "data/Assignment_4/Test/Test"
17
+
18
+ TRAIN_SKETCH_DIR = "data/Assignment_4/Train/Contours"
19
+ TEST_SKETCH_DIR = "data/Assignment_4/Test/Test_contours"
20
+
21
+ # Create Dataset
22
+ class ISICDataset(Dataset):
23
+ def __init__(self, datadir, csvpath, sketchdir, transform=None):
24
+ self.datadir = datadir
25
+ self.csv = pd.read_csv(csvpath)
26
+ self.sketchdir = sketchdir
27
+ self.transform = transform
28
+
29
+ def __len__(self):
30
+ return len(self.csv[:300])
31
+
32
+ def __getitem__(self, index):
33
+ img_path = os.path.join(self.datadir, self.csv.iloc[index, 0] + ".jpg")
34
+ image = Image.open(img_path)
35
+
36
+ # Apply Gaussian blur
37
+ blurred_image = image.filter(ImageFilter.GaussianBlur(radius=2))
38
+
39
+ # Apply sharpening
40
+ sharpened_image = image.filter(ImageFilter.UnsharpMask)
41
+
42
+ labels = self.csv.iloc[index, 1:].values
43
+
44
+ sketch_name = random.choice(os.listdir(self.sketchdir))
45
+ sketch_path = os.path.join(self.sketchdir, sketch_name)
46
+ fs, ext = os.path.splitext(sketch_path)
47
+
48
+ while ext not in ['.jpg', '.jpeg', '.png']:
49
+ sketch_name = random.choice(os.listdir(self.sketchdir))
50
+ sketch_path = os.path.join(self.sketchdir, sketch_name)
51
+ fs, ext = os.path.splitext(sketch_path)
52
+
53
+ sketch = Image.open(sketch_path)
54
+
55
+ image = self.transform['img'](sharpened_image)
56
+ sketch = self.transform['sketch'](sketch)
57
+
58
+ true_label = np.argmax(labels)
59
+ # Create a numpy array of zeros with shape (num_classes, 256, 256)
60
+ encoded_label = np.zeros((7, image.size(1), image.size(1)), dtype=np.float32)
61
+
62
+ # Set all elements in the channel corresponding to the true label to 1
63
+ encoded_label[true_label, :, :] = 1 # Use broadcasting to set all elements
64
+
65
+ label = torch.tensor(encoded_label, dtype=torch.float32)
66
+
67
+ return label, image, sketch
68
+
69
+
70
+ def prepdata(config, transform=None):
71
+ # Train Dataset and Dataloader
72
+ train_dataset = ISICDataset(TRAIN_DATA_DIR, TRAIN_LABELS, TRAIN_SKETCH_DIR, transform=transform)
73
+ train_dataloader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True, num_workers=2)
74
+
75
+ # Train Dataset and Dataloader
76
+ test_dataset = ISICDataset(TEST_DATA_DIR, TEST_LABELS, TEST_SKETCH_DIR, transform=transform)
77
+ test_dataloader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=True, num_workers=2)
78
+
79
+ return train_dataloader, test_dataloader
80
+
81
+
82
+
cycleGAN/data/make_dataset.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from zipfile import ZipFile
2
+ import os
3
+
4
+ # Extract the data
5
+ DATA_DIR = "data"
6
+ EXTRACT_PATH = "data"
7
+
8
+ TRAIN_DIR = os.path.join(DATA_DIR, "Assignment_4/Train")
9
+ TRAIN_FILES = os.listdir(TRAIN_DIR)
10
+
11
+ TEST_DIR = os.path.join(DATA_DIR, "Assignment_4/Test")
12
+ TEST_FILES = os.listdir(os.path.join(DATA_DIR, "Assignment_4/Test"))
13
+
14
+ def extract_zipfile(file_path, extract_path):
15
+ with ZipFile(file_path, 'r') as asz:
16
+ asz.extractall(extract_path)
17
+
18
+ extract_zipfile(os.path.join(DATA_DIR, "Assignment_4-20240408T115837Z-002.zip"), EXTRACT_PATH)
19
+
20
+ for file in TRAIN_FILES:
21
+ if file.endswith(".zip"):
22
+ file_path = os.path.join(TRAIN_DIR, file)
23
+ extract_zipfile(file_path, TRAIN_DIR)
24
+
25
+ for file in TEST_FILES:
26
+ if file.endswith(".zip"):
27
+ file_path = os.path.join(TEST_DIR, file)
28
+ extract_zipfile(file_path, TEST_DIR)