File size: 8,388 Bytes
36dd4e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
"""
Dataset loading and preprocessing for crop disease detection
"""

import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
from pathlib import Path

def get_transforms(split='train', input_size=224):
    """
    Build the preprocessing pipeline for one dataset split.

    Every split is resized to a square ``input_size`` x ``input_size`` image,
    converted to a tensor and normalized with ImageNet channel statistics.
    Only ``split == 'train'`` additionally applies random flips, a small
    random rotation and color jitter; any other value gets no augmentation.

    Args:
        split: 'train', 'val', or 'test'
        input_size: Side length of the square output image (default: 224)

    Returns:
        transforms.Compose: Composed transforms
    """
    resize_step = transforms.Resize((input_size, input_size))
    tensor_steps = [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ]

    # Augmentations sit between the resize and the tensor conversion,
    # and are used for the training split only.
    augment_steps = []
    if split == 'train':
        augment_steps = [
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.3),
            transforms.RandomRotation(degrees=15),
            transforms.ColorJitter(brightness=0.2, contrast=0.2,
                                   saturation=0.2, hue=0.1),
        ]

    return transforms.Compose([resize_step, *augment_steps, *tensor_steps])

def get_inference_transforms(input_size=224):
    """
    Build the deterministic preprocessing pipeline used at prediction time.

    Resizes to ``input_size`` x ``input_size``, converts to a tensor and
    normalizes with ImageNet channel statistics. No random augmentation.

    Args:
        input_size: Side length of the square output image (default: 224)

    Returns:
        transforms.Compose: Composed transforms for inference
    """
    imagenet_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                              std=[0.229, 0.224, 0.225])
    steps = [
        transforms.Resize((input_size, input_size)),
        transforms.ToTensor(),
        imagenet_normalize,
    ]
    return transforms.Compose(steps)

class CropDiseaseDataset(Dataset):
    """Custom dataset for crop disease images.

    ``data_dir`` is expected to hold one sub-directory per class (hidden
    directories are ignored). Every regular file inside a class directory
    whose suffix is in ``IMAGE_EXTENSIONS`` becomes one
    ``(path, class_index)`` sample.
    """

    # File suffixes (compared case-insensitively) treated as images.
    IMAGE_EXTENSIONS = frozenset({'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif'})

    def __init__(self, data_dir, transform=None, class_to_idx=None):
        """
        Args:
            data_dir: Path to dataset directory (train/val/test)
            transform: Optional transform to be applied on images
            class_to_idx: Dictionary mapping class names to indices. When
                given (e.g. the training split's mapping reused for val/test),
                every class directory found in ``data_dir`` must appear in it.

        Raises:
            KeyError: If ``class_to_idx`` is given and some class directory
                in ``data_dir`` is missing from it.
        """
        self.data_dir = Path(data_dir)
        self.transform = transform

        # Class names are the visible sub-directories, sorted so the default
        # index assignment is the same on every run and machine.
        self.classes = [
            entry.name
            for entry in sorted(self.data_dir.iterdir())
            if entry.is_dir() and not entry.name.startswith('.')
        ]

        # Create class to index mapping if not provided
        if class_to_idx is None:
            self.class_to_idx = {cls_name: idx for idx, cls_name in enumerate(self.classes)}
        else:
            self.class_to_idx = class_to_idx
            unknown = [name for name in self.classes if name not in class_to_idx]
            if unknown:
                raise KeyError(
                    f"Class directories {unknown} in {self.data_dir} are not "
                    f"present in the provided class_to_idx mapping"
                )

        # Collect all image samples; files are sorted so sample order does not
        # depend on the filesystem's directory listing order.
        self.samples = []
        for class_name in self.classes:
            class_idx = self.class_to_idx[class_name]
            for img_path in sorted((self.data_dir / class_name).iterdir()):
                # is_file() excludes directories that happen to end in ".jpg" etc.
                if img_path.is_file() and img_path.suffix.lower() in self.IMAGE_EXTENSIONS:
                    self.samples.append((str(img_path), class_idx))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``, applying ``transform`` if set."""
        img_path, label = self.samples[idx]

        try:
            # The context manager closes the underlying file handle; convert()
            # returns an independent in-memory image, so it stays usable.
            with Image.open(img_path) as img:
                image = img.convert('RGB')
        except Exception as e:
            # Deliberate best-effort: substitute a gray placeholder so one bad
            # file does not abort an epoch, but report why loading failed.
            print(f"Warning: Could not load {img_path} ({e}), creating dummy image")
            image = Image.new('RGB', (224, 224), color=(128, 128, 128))

        # Apply transforms
        if self.transform:
            image = self.transform(image)

        return image, label

    def get_class_names(self):
        """Return list of class names"""
        return self.classes

    def get_class_to_idx(self):
        """Return class to index mapping"""
        return self.class_to_idx

def get_data_transforms(input_size=224, resize_size=None):
    """Get data transforms for training and validation.

    Args:
        input_size: Side length of the square tensor both pipelines produce
            (default: 224).
        resize_size: Square size the training image is resized to before the
            random crop. Defaults to ``input_size * 256 // 224``, i.e. 256 for
            the default 224, which reproduces the original pipeline.

    Returns:
        tuple: ``(train_transforms, val_transforms)`` as ``transforms.Compose``.
        The train pipeline resizes, takes a random resized crop and applies
        flips/rotation/color jitter. The val pipeline only resizes to
        ``input_size`` x ``input_size``. Both end with ToTensor + ImageNet
        normalization.
    """
    if resize_size is None:
        # Keep the conventional 256/224 resize-to-crop ratio.
        resize_size = input_size * 256 // 224

    # ImageNet normalization values
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )

    # Training transforms with data augmentation
    train_transforms = transforms.Compose([
        transforms.Resize((resize_size, resize_size)),
        transforms.RandomResizedCrop(input_size),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.3),
        transforms.RandomRotation(degrees=15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.ToTensor(),
        normalize
    ])

    # Validation/Test transforms (no augmentation)
    val_transforms = transforms.Compose([
        transforms.Resize((input_size, input_size)),
        transforms.ToTensor(),
        normalize
    ])

    return train_transforms, val_transforms

def create_data_loaders(data_dir, batch_size=32, num_workers=0):
    """
    Build DataLoaders for the train, val and test splits.

    Reads ``<data_dir>/train``, ``<data_dir>/val`` and ``<data_dir>/test``.
    Val and test reuse the class-to-index mapping discovered on the training
    split. Only the training loader shuffles. Memory pinning is enabled when
    CUDA is available.

    Args:
        data_dir: Dataset root containing the three split directories.
        batch_size: Samples per batch for every loader.
        num_workers: Worker processes per loader.

    Returns:
        tuple: ``(train_loader, val_loader, test_loader, class_names)``, where
        ``class_names`` is the training split's class list.
    """
    train_tf, eval_tf = get_data_transforms()
    use_pinned_memory = torch.cuda.is_available()

    train_set = CropDiseaseDataset(
        data_dir=os.path.join(data_dir, 'train'),
        transform=train_tf
    )
    # Reuse the training mapping so a label index means the same class everywhere.
    shared_mapping = train_set.get_class_to_idx()
    val_set = CropDiseaseDataset(
        data_dir=os.path.join(data_dir, 'val'),
        transform=eval_tf,
        class_to_idx=shared_mapping
    )
    test_set = CropDiseaseDataset(
        data_dir=os.path.join(data_dir, 'test'),
        transform=eval_tf,
        class_to_idx=shared_mapping
    )

    def _make_loader(dataset, shuffle):
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=use_pinned_memory
        )

    return (
        _make_loader(train_set, True),
        _make_loader(val_set, False),
        _make_loader(test_set, False),
        train_set.get_class_names(),
    )

def get_class_weights(data_dir):
    """Calculate class weights for handling imbalanced datasets.

    Uses inverse-frequency weighting ``total_samples / (num_classes * count)``
    over the samples found in ``<data_dir>/train``. A class directory with no
    images is counted as having one sample, so the division stays defined.

    Args:
        data_dir: Dataset root containing the ``train`` split directory.

    Returns:
        torch.FloatTensor: One weight per class, in the order of the training
        split's ``classes`` list.
    """
    from collections import Counter

    train_dataset = CropDiseaseDataset(data_dir=os.path.join(data_dir, 'train'))

    # Count samples per label index directly rather than round-tripping
    # through class names.
    label_counts = Counter(label for _, label in train_dataset.samples)

    total_samples = len(train_dataset.samples)
    num_classes = len(train_dataset.classes)

    class_weights = []
    for class_name in train_dataset.classes:
        count = label_counts[train_dataset.class_to_idx[class_name]] or 1
        class_weights.append(total_samples / (num_classes * count))

    return torch.FloatTensor(class_weights)

if __name__ == "__main__":
    # Smoke-test dataset loading against the local "data" directory.
    import sys

    data_dir = "data"

    try:
        train_loader, val_loader, test_loader, class_names = create_data_loaders(data_dir, batch_size=4)

        print("Dataset loaded successfully!")
        print(f"Number of classes: {len(class_names)}")
        print(f"Classes: {class_names}")
        print(f"Train batches: {len(train_loader)}")
        print(f"Val batches: {len(val_loader)}")
        print(f"Test batches: {len(test_loader)}")

        # Test loading a batch
        for images, labels in train_loader:
            print(f"Batch shape: {images.shape}")
            print(f"Label shape: {labels.shape}")
            break

    except Exception as e:
        # Top-level boundary: report the problem, then exit non-zero so shells
        # and CI scripts can tell the smoke test failed.
        print(f"Error loading dataset: {e}")
        print("Make sure the dataset is properly organized in data/train, data/val, data/test")
        sys.exit(1)