Biswajeet1 commited on
Commit
bd3980f
·
verified ·
1 Parent(s): 327c90d

Update dataset.py

Browse files
Files changed (1) hide show
  1. dataset.py +93 -93
dataset.py CHANGED
@@ -1,93 +1,93 @@
1
- import os
2
- import numpy as np
3
- import torch
4
- import cv2
5
- from torch.utils.data import Dataset
6
- import albumentations as A
7
- from albumentations.pytorch import ToTensorV2
8
-
9
-
10
- class CocoSegmentationDataset(Dataset):
11
-
12
- def __init__(self, coco, image_folder,
13
- category_name=None,
14
- transform=None):
15
-
16
- self.coco = coco
17
- self.image_folder = image_folder
18
- self.transform = transform
19
-
20
- if category_name:
21
- self.cat_ids = self.coco.getCatIds(catNms=[category_name])
22
- self.img_ids = self.coco.getImgIds(catIds=self.cat_ids)
23
- else:
24
- # Use all categories and all images if no specific category is provided
25
- self.cat_ids = self.coco.getCatIds()
26
- self.img_ids = self.coco.getImgIds()
27
-
28
- def __len__(self):
29
- return len(self.img_ids)
30
-
31
- def __getitem__(self, index):
32
-
33
- img_id = self.img_ids[index]
34
- img_info = self.coco.loadImgs(img_id)[0]
35
- img_path = os.path.join(self.image_folder, img_info['file_name'])
36
-
37
- # Load image with OpenCV (BGR to RGB)
38
- image = cv2.imread(img_path)
39
- image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
40
-
41
- # Fetch annotations for the image. If self.cat_ids is everything, it gets all annotations.
42
- ann_ids = self.coco.getAnnIds(
43
- imgIds=img_info['id'],
44
- catIds=self.cat_ids,
45
- iscrowd=None
46
- )
47
-
48
- anns = self.coco.loadAnns(ann_ids)
49
- mask = np.zeros((img_info['height'], img_info['width']))
50
-
51
- for ann in anns:
52
- mask = np.maximum(mask, self.coco.annToMask(ann))
53
-
54
- if self.transform:
55
- augmented = self.transform(image=image, mask=mask)
56
- image = augmented['image']
57
- mask = augmented['mask']
58
-
59
- if not isinstance(mask, torch.Tensor):
60
- mask = torch.from_numpy(mask).float()
61
-
62
- if mask.ndim == 2:
63
- mask = mask.unsqueeze(0)
64
-
65
- return image, mask
66
-
67
-
68
- def get_train_transforms(image_size=256):
69
- return A.Compose([
70
- A.LongestMaxSize(max_size=image_size),
71
- A.PadIfNeeded(min_height=image_size, min_width=image_size, border_mode=cv2.BORDER_CONSTANT, fill=(123.675, 116.28, 103.53), fill_mask=0),
72
- A.HorizontalFlip(p=0.5),
73
- A.VerticalFlip(p=0.3),
74
- A.RandomBrightnessContrast(p=0.4),
75
- A.Affine(
76
- scale=(0.9, 1.1),
77
- rotate=(-15, 15),
78
- translate_percent=(0.05, 0.05),
79
- p=0.5
80
- ),
81
- A.GaussianBlur(p=0.2),
82
- A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
83
- ToTensorV2(),
84
- ])
85
-
86
-
87
- def get_val_transforms(image_size=256):
88
- return A.Compose([
89
- A.LongestMaxSize(max_size=image_size),
90
- A.PadIfNeeded(min_height=image_size, min_width=image_size, border_mode=cv2.BORDER_CONSTANT, fill=(123.675, 116.28, 103.53), fill_mask=0),
91
- A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
92
- ToTensorV2(),
93
- ])
 
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ import cv2
5
+ from torch.utils.data import Dataset
6
+ import albumentations as A
7
+ from albumentations.pytorch import ToTensorV2
8
+
9
+
10
+ class CocoSegmentationDataset(Dataset):
11
+
12
+ def __init__(self, coco, image_folder,
13
+ category_name=None,
14
+ transform=None):
15
+
16
+ self.coco = coco
17
+ self.image_folder = image_folder
18
+ self.transform = transform
19
+
20
+ if category_name:
21
+ self.cat_ids = self.coco.getCatIds(catNms=[category_name])
22
+ self.img_ids = self.coco.getImgIds(catIds=self.cat_ids)
23
+ else:
24
+ # Use all categories and all images if no specific category is provided
25
+ self.cat_ids = self.coco.getCatIds()
26
+ self.img_ids = self.coco.getImgIds()
27
+
28
+ def __len__(self):
29
+ return len(self.img_ids)
30
+
31
+ def __getitem__(self, index):
32
+
33
+ img_id = self.img_ids[index]
34
+ img_info = self.coco.loadImgs(img_id)[0]
35
+ img_path = os.path.join(self.image_folder, img_info['file_name'])
36
+
37
+ # Load image with OpenCV (BGR to RGB)
38
+ image = cv2.imread(img_path)
39
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
40
+
41
+ # Fetch annotations for the image. If self.cat_ids is everything, it gets all annotations.
42
+ ann_ids = self.coco.getAnnIds(
43
+ imgIds=img_info['id'],
44
+ catIds=self.cat_ids,
45
+ iscrowd=None
46
+ )
47
+
48
+ anns = self.coco.loadAnns(ann_ids)
49
+ mask = np.zeros((img_info['height'], img_info['width']))
50
+
51
+ for ann in anns:
52
+ mask = np.maximum(mask, self.coco.annToMask(ann))
53
+
54
+ if self.transform:
55
+ augmented = self.transform(image=image, mask=mask)
56
+ image = augmented['image']
57
+ mask = augmented['mask']
58
+
59
+ if not isinstance(mask, torch.Tensor):
60
+ mask = torch.from_numpy(mask).float()
61
+
62
+ if mask.ndim == 2:
63
+ mask = mask.unsqueeze(0)
64
+
65
+ return image, mask
66
+
67
+
68
+ def get_train_transforms(image_size=256):
69
+ return A.Compose([
70
+ A.LongestMaxSize(max_size=image_size),
71
+ A.PadIfNeeded(min_height=image_size, min_width=image_size, border_mode=cv2.BORDER_CONSTANT, value=(123.675, 116.28, 103.53), mask_value=0),
72
+ A.HorizontalFlip(p=0.5),
73
+ A.VerticalFlip(p=0.3),
74
+ A.RandomBrightnessContrast(p=0.4),
75
+ A.Affine(
76
+ scale=(0.9, 1.1),
77
+ rotate=(-15, 15),
78
+ translate_percent=(0.05, 0.05),
79
+ p=0.5
80
+ ),
81
+ A.GaussianBlur(p=0.2),
82
+ A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
83
+ ToTensorV2(),
84
+ ])
85
+
86
+
87
+ def get_val_transforms(image_size=256):
88
+ return A.Compose([
89
+ A.LongestMaxSize(max_size=image_size),
90
+ A.PadIfNeeded(min_height=image_size, min_width=image_size, border_mode=cv2.BORDER_CONSTANT, value=(123.675, 116.28, 103.53), mask_value=0),
91
+ A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
92
+ ToTensorV2(),
93
+ ])