File size: 3,657 Bytes
09967a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
"""
PyTorch datasets for the LIDC-IDRI flat directory format: paired images and majority-vote masks for training/evaluation, plus an image-only set for prediction.
"""
import os
import glob
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset
import torchvision.transforms.functional as TF
import random


class LIDCFlatDataset(Dataset):
    """Dataset over a flat directory of matched image/mask PNG pairs.

    Expected layout::

        root_dir/
            images/<name>.png
            masks/<name>.png    # same basename as the matching image

    Each item is ``(image, mask, sample_id)``: ``image`` is a grayscale float
    tensor in [0, 1] of shape [1, H, W], ``mask`` is a binarized (0.0/1.0)
    float tensor of the same layout, and ``sample_id`` is the image basename
    without its extension.
    """

    def __init__(self, root_dir, augment=False, img_size=128):
        """
        Args:
            root_dir: Directory containing 'images/' and 'masks/' subdirectories
            augment: Whether to apply random flips and a random rotation
            img_size: Expected image size. Only stored as ``self.img_size``;
                no resizing is performed, so images should already be this size.

        Raises:
            RuntimeError: If ``root_dir/images`` contains no ``*.png`` files.
        """
        self.root_dir = root_dir
        self.augment = augment
        self.img_size = img_size

        self.image_dir = os.path.join(root_dir, "images")
        self.mask_dir = os.path.join(root_dir, "masks")

        # glob order is filesystem-dependent; sorting makes indices deterministic.
        self.image_files = sorted(glob.glob(os.path.join(self.image_dir, "*.png")))

        if not self.image_files:
            raise RuntimeError(f"No images found in {self.image_dir}")

        print(f"Dataset: {len(self.image_files)} samples from {root_dir}")

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load, optionally augment, and return ``(image, mask, sample_id)``.

        Raises:
            FileNotFoundError: If no mask with the image's basename exists in
                ``self.mask_dir``.
        """
        img_path = self.image_files[idx]
        basename = os.path.basename(img_path)
        mask_path = os.path.join(self.mask_dir, basename)

        # Report the broken pairing with both paths; PIL's own error would
        # only mention the mask file.
        if not os.path.isfile(mask_path):
            raise FileNotFoundError(
                f"Mask not found for image {img_path}: expected {mask_path}"
            )

        # The context managers close the file handles deterministically;
        # convert() returns a new, fully loaded image that outlives the block.
        with Image.open(img_path) as img_file:
            image = img_file.convert("L")
        with Image.open(mask_path) as mask_file:
            mask = mask_file.convert("L")

        if self.augment:
            image, mask = self._augment(image, mask)

        image = TF.to_tensor(image)  # [1, H, W], range [0, 1]
        mask = TF.to_tensor(mask)    # [1, H, W], range [0, 1]
        # NOTE(review): the 0.5 threshold assumes foreground is stored as a high
        # pixel value (e.g. 255). Masks saved with 0/1 pixel values would
        # binarize to all zeros, so confirm the on-disk encoding.
        mask = (mask > 0.5).float()

        return image, mask, self._sample_id(img_path)

    @staticmethod
    def _augment(image, mask):
        """Apply identical random flips and rotation to an image/mask pair."""
        # Random horizontal flip
        if random.random() > 0.5:
            image = TF.hflip(image)
            mask = TF.hflip(mask)

        # Random vertical flip
        if random.random() > 0.5:
            image = TF.vflip(image)
            mask = TF.vflip(mask)

        # Random rotation (±15 degrees). The same angle is used for both so the
        # mask stays aligned with the image; torchvision's default rotate
        # interpolation is nearest, which keeps mask values within their set.
        angle = random.uniform(-15, 15)
        image = TF.rotate(image, angle, fill=0)
        mask = TF.rotate(mask, angle, fill=0)
        return image, mask

    @staticmethod
    def _sample_id(path):
        """Return the basename of *path* with its extension stripped."""
        return os.path.splitext(os.path.basename(path))[0]

    def get_sample_id(self, idx):
        """Get sample ID without loading the image."""
        return self._sample_id(self.image_files[idx])


class LIDCTestDataset(Dataset):
    """Image-only dataset for prediction; no masks are read.

    Reads ``root_dir/images/*.png`` and yields ``(image, sample_id)`` pairs,
    where ``image`` is a grayscale float tensor in [0, 1] of shape [1, H, W]
    and ``sample_id`` is the image basename without its extension.
    """

    def __init__(self, root_dir, img_size=128):
        """
        Args:
            root_dir: Directory containing an 'images/' subdirectory
            img_size: Expected image size. Only stored as ``self.img_size``;
                no resizing is performed.

        Raises:
            RuntimeError: If ``root_dir/images`` contains no ``*.png`` files.
        """
        self.root_dir = root_dir
        self.img_size = img_size
        self.image_dir = os.path.join(root_dir, "images")
        # glob order is filesystem-dependent; sorting makes prediction order
        # deterministic.
        self.image_files = sorted(glob.glob(os.path.join(self.image_dir, "*.png")))

        if not self.image_files:
            raise RuntimeError(f"No images found in {self.image_dir}")

        print(f"Test dataset: {len(self.image_files)} samples from {root_dir}")

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load image *idx* as a grayscale tensor and return ``(image, sample_id)``."""
        img_path = self.image_files[idx]

        # The context manager closes the file handle deterministically;
        # convert() returns a new, fully loaded image that outlives the block.
        with Image.open(img_path) as img_file:
            image = img_file.convert("L")
        image = TF.to_tensor(image)  # [1, H, W], range [0, 1]

        return image, self.get_sample_id(idx)

    def get_sample_id(self, idx):
        """Get sample ID without loading the image."""
        return os.path.splitext(os.path.basename(self.image_files[idx]))[0]