fxxkingusername commited on
Commit
d2aee5b
·
verified ·
1 Parent(s): fe88809

Upload src/training\data_loader.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/training//data_loader.py +511 -0
src/training//data_loader.py ADDED
@@ -0,0 +1,511 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Enhanced data loader for architectural style classification.
3
+ Includes advanced augmentation and better data handling.
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from torch.utils.data import Dataset, DataLoader
9
+ from torchvision import transforms
10
+ import numpy as np
11
+ from typing import Dict, List, Optional, Tuple, Any
12
+ import os
13
+ from PIL import Image
14
+ import random
15
+ import albumentations as A
16
+ from albumentations.pytorch import ToTensorV2
17
+
18
+
19
+ class EnhancedArchitecturalDataset(Dataset):
20
+ """Enhanced dataset for architectural style classification with advanced augmentation."""
21
+
22
+ def __init__(self, data_dir: str, transform: Optional[transforms.Compose] = None,
23
+ split: str = 'train', num_samples: Optional[int] = None, use_albumentations: bool = True):
24
+ self.data_dir = data_dir
25
+ self.split = split
26
+ self.use_albumentations = use_albumentations
27
+
28
+ # Use enhanced transforms if albumentations is available
29
+ if use_albumentations:
30
+ self.transform = transform or self._get_enhanced_transform()
31
+ else:
32
+ self.transform = transform or self._get_default_transform()
33
+
34
+ # Load data paths and labels
35
+ self.data_paths, self.labels = self._load_data()
36
+
37
+ # Limit samples if specified
38
+ if num_samples and len(self.data_paths) > 0:
39
+ # Ensure we don't try to sample more than available
40
+ actual_samples = min(num_samples, len(self.data_paths))
41
+ indices = random.sample(range(len(self.data_paths)), actual_samples)
42
+ self.data_paths = [self.data_paths[i] for i in indices]
43
+ self.labels = [self.labels[i] for i in indices]
44
+
45
+ def _load_data(self) -> Tuple[List[str], List[int]]:
46
+ """Load data paths and labels."""
47
+ data_paths = []
48
+ labels = []
49
+
50
+ # Check if data directory exists
51
+ if not os.path.exists(self.data_dir):
52
+ print(f"Warning: Data directory {self.data_dir} does not exist. Using sample data.")
53
+ return self._generate_sample_data()
54
+
55
+ # First try to load from directory structure directly in data_dir (real data)
56
+ real_data_found = False
57
+ for class_idx in range(25): # 25 architectural styles
58
+ class_dir = os.path.join(self.data_dir, str(class_idx))
59
+ if os.path.exists(class_dir):
60
+ real_data_found = True
61
+ for filename in os.listdir(class_dir):
62
+ if filename.lower().endswith(('.jpg', '.jpeg', '.png')):
63
+ data_paths.append(os.path.join(class_dir, filename))
64
+ labels.append(class_idx)
65
+
66
+ if real_data_found:
67
+ print(f"Loading real data from directory: {self.data_dir}")
68
+ return data_paths, labels
69
+
70
+ # Fallback to sample_data subdirectory if no real data found
71
+ sample_data_dir = os.path.join(self.data_dir, 'sample_data')
72
+ if os.path.exists(sample_data_dir):
73
+ print(f"Loading data from sample_data directory: {sample_data_dir}")
74
+ # Load from sample_data directory structure
75
+ for class_idx in range(25): # 25 architectural styles
76
+ class_dir = os.path.join(sample_data_dir, str(class_idx))
77
+ if os.path.exists(class_dir):
78
+ for filename in os.listdir(class_dir):
79
+ if filename.lower().endswith(('.jpg', '.jpeg', '.png')):
80
+ data_paths.append(os.path.join(class_dir, filename))
81
+ labels.append(class_idx)
82
+
83
+ return data_paths, labels
84
+
85
+ def _get_enhanced_transform(self) -> A.Compose:
86
+ """Get enhanced transforms using Albumentations."""
87
+ if self.split == 'train':
88
+ return A.Compose([
89
+ A.Resize(256, 256),
90
+ A.RandomCrop(224, 224, p=0.8),
91
+ A.HorizontalFlip(p=0.5),
92
+ A.VerticalFlip(p=0.1),
93
+ A.RandomRotate90(p=0.3),
94
+ A.Rotate(limit=15, p=0.5),
95
+ A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2, rotate_limit=15, p=0.5),
96
+ A.OneOf([
97
+ A.MotionBlur(blur_limit=3, p=0.3),
98
+ A.MedianBlur(blur_limit=3, p=0.3),
99
+ A.Blur(blur_limit=3, p=0.3),
100
+ ], p=0.2),
101
+ A.OneOf([
102
+ A.CLAHE(clip_limit=2, p=0.3),
103
+ A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.3),
104
+ A.RandomGamma(gamma_limit=(80, 120), p=0.3),
105
+ ], p=0.5),
106
+ A.OneOf([
107
+ A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=0.3),
108
+ A.RGBShift(r_shift_limit=20, g_shift_limit=20, b_shift_limit=20, p=0.3),
109
+ ], p=0.3),
110
+ A.OneOf([
111
+ A.GaussNoise(var_limit=(10.0, 50.0), p=0.3),
112
+ A.ISONoise(color_shift=(0.01, 0.05), p=0.3),
113
+ ], p=0.2),
114
+ A.OneOf([
115
+ A.ElasticTransform(alpha=1, sigma=50, alpha_affine=50, p=0.3),
116
+ A.GridDistortion(num_steps=5, distort_limit=0.3, p=0.3),
117
+ A.OpticalDistortion(distort_limit=0.3, shift_limit=0.3, p=0.3),
118
+ ], p=0.2),
119
+ A.CoarseDropout(max_holes=8, max_height=32, max_width=32, p=0.3),
120
+ A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
121
+ ToTensorV2(),
122
+ ])
123
+ else:
124
+ return A.Compose([
125
+ A.Resize(224, 224),
126
+ A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
127
+ ToTensorV2(),
128
+ ])
129
+
130
+ def _get_default_transform(self) -> transforms.Compose:
131
+ """Get default transforms for architectural images."""
132
+ if self.split == 'train':
133
+ return transforms.Compose([
134
+ transforms.Resize((256, 256)),
135
+ transforms.RandomCrop((224, 224)),
136
+ transforms.RandomHorizontalFlip(p=0.5),
137
+ transforms.RandomVerticalFlip(p=0.1),
138
+ transforms.RandomRotation(degrees=15),
139
+ transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
140
+ transforms.RandomGrayscale(p=0.1),
141
+ transforms.ToTensor(),
142
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
143
+ std=[0.229, 0.224, 0.225])
144
+ ])
145
+ else:
146
+ return transforms.Compose([
147
+ transforms.Resize((224, 224)),
148
+ transforms.ToTensor(),
149
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
150
+ std=[0.229, 0.224, 0.225])
151
+ ])
152
+
153
+ def _generate_sample_data(self) -> Tuple[List[str], List[int]]:
154
+ """Generate sample data for testing."""
155
+ print("Generating sample data for testing...")
156
+
157
+ # Create sample data directory
158
+ sample_dir = os.path.join(self.data_dir, 'sample_data')
159
+ os.makedirs(sample_dir, exist_ok=True)
160
+
161
+ data_paths = []
162
+ labels = []
163
+
164
+ # Generate sample images for each class
165
+ for class_idx in range(25):
166
+ class_dir = os.path.join(sample_dir, str(class_idx))
167
+ os.makedirs(class_dir, exist_ok=True)
168
+
169
+ # Generate 20 sample images per class (increased from 10)
170
+ for i in range(20):
171
+ # Create a simple colored image as placeholder
172
+ img_array = np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)
173
+
174
+ # Add some class-specific patterns
175
+ if class_idx < 5: # Ancient styles
176
+ img_array[:, :, 0] = np.random.randint(100, 200) # Reddish
177
+ elif class_idx < 10: # Medieval styles
178
+ img_array[:, :, 1] = np.random.randint(100, 200) # Greenish
179
+ elif class_idx < 15: # Renaissance styles
180
+ img_array[:, :, 2] = np.random.randint(100, 200) # Bluish
181
+ elif class_idx < 20: # Modern styles
182
+ img_array[:, :, :] = np.random.randint(150, 255) # Bright
183
+ else: # Contemporary styles
184
+ img_array[:, :, :] = np.random.randint(0, 100) # Dark
185
+
186
+ # Save image
187
+ img = Image.fromarray(img_array)
188
+ img_path = os.path.join(class_dir, f'sample_{i}.jpg')
189
+ img.save(img_path)
190
+
191
+ data_paths.append(img_path)
192
+ labels.append(class_idx)
193
+
194
+ print(f"Generated {len(data_paths)} sample images")
195
+ return data_paths, labels
196
+
197
+ def __len__(self) -> int:
198
+ return len(self.data_paths)
199
+
200
+ def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
201
+ img_path = self.data_paths[idx]
202
+ label = self.labels[idx]
203
+
204
+ # Load image
205
+ try:
206
+ image = Image.open(img_path).convert('RGB')
207
+ except:
208
+ # If image loading fails, create a random image
209
+ image = Image.fromarray(np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8))
210
+
211
+ # Apply transforms
212
+ if self.use_albumentations and isinstance(self.transform, A.Compose):
213
+ # Convert PIL image to numpy array for Albumentations
214
+ image_np = np.array(image)
215
+ transformed = self.transform(image=image_np)
216
+ image = transformed['image']
217
+ else:
218
+ # Use torchvision transforms
219
+ if self.transform:
220
+ image = self.transform(image)
221
+
222
+ return image, label
223
+
224
+
225
+ class EnhancedArchitecturalDataLoader:
226
+ """Enhanced data loader factory for architectural style classification."""
227
+
228
+ def __init__(self, data_dir: str, batch_size: int = 16, num_workers: int = 4, use_albumentations: bool = True):
229
+ self.data_dir = data_dir
230
+ self.batch_size = batch_size
231
+ self.num_workers = num_workers
232
+ self.use_albumentations = use_albumentations
233
+
234
+ # Define transforms
235
+ self.train_transform = self._get_train_transform()
236
+ self.val_transform = self._get_val_transform()
237
+ self.test_transform = self._get_test_transform()
238
+
239
+ def _get_train_transform(self):
240
+ """Get training transforms with advanced augmentation."""
241
+ if self.use_albumentations:
242
+ return A.Compose([
243
+ A.Resize(256, 256),
244
+ A.RandomCrop(224, 224, p=0.8),
245
+ A.HorizontalFlip(p=0.5),
246
+ A.VerticalFlip(p=0.1),
247
+ A.RandomRotate90(p=0.3),
248
+ A.Rotate(limit=15, p=0.5),
249
+ A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2, rotate_limit=15, p=0.5),
250
+ A.OneOf([
251
+ A.MotionBlur(blur_limit=3, p=0.3),
252
+ A.MedianBlur(blur_limit=3, p=0.3),
253
+ A.Blur(blur_limit=3, p=0.3),
254
+ ], p=0.2),
255
+ A.OneOf([
256
+ A.CLAHE(clip_limit=2, p=0.3),
257
+ A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.3),
258
+ A.RandomGamma(gamma_limit=(80, 120), p=0.3),
259
+ ], p=0.5),
260
+ A.OneOf([
261
+ A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=0.3),
262
+ A.RGBShift(r_shift_limit=20, g_shift_limit=20, b_shift_limit=20, p=0.3),
263
+ ], p=0.3),
264
+ A.OneOf([
265
+ A.GaussNoise(var_limit=(10.0, 50.0), p=0.3),
266
+ A.ISONoise(color_shift=(0.01, 0.05), p=0.3),
267
+ ], p=0.2),
268
+ A.OneOf([
269
+ A.ElasticTransform(alpha=1, sigma=50, alpha_affine=50, p=0.3),
270
+ A.GridDistortion(num_steps=5, distort_limit=0.3, p=0.3),
271
+ A.OpticalDistortion(distort_limit=0.3, shift_limit=0.3, p=0.3),
272
+ ], p=0.2),
273
+ A.CoarseDropout(max_holes=8, max_height=32, max_width=32, p=0.3),
274
+ A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
275
+ ToTensorV2(),
276
+ ])
277
+ else:
278
+ return transforms.Compose([
279
+ transforms.Resize((256, 256)),
280
+ transforms.RandomCrop((224, 224)),
281
+ transforms.RandomHorizontalFlip(p=0.5),
282
+ transforms.RandomVerticalFlip(p=0.1),
283
+ transforms.RandomRotation(degrees=15),
284
+ transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
285
+ transforms.RandomGrayscale(p=0.1),
286
+ transforms.ToTensor(),
287
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
288
+ std=[0.229, 0.224, 0.225])
289
+ ])
290
+
291
+ def _get_val_transform(self):
292
+ """Get validation transforms."""
293
+ if self.use_albumentations:
294
+ return A.Compose([
295
+ A.Resize(224, 224),
296
+ A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
297
+ ToTensorV2(),
298
+ ])
299
+ else:
300
+ return transforms.Compose([
301
+ transforms.Resize((224, 224)),
302
+ transforms.ToTensor(),
303
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
304
+ std=[0.229, 0.224, 0.225])
305
+ ])
306
+
307
+ def _get_test_transform(self):
308
+ """Get test transforms."""
309
+ return self._get_val_transform()
310
+
311
+ def get_train_loader(self, num_samples: Optional[int] = None) -> DataLoader:
312
+ """Get training data loader."""
313
+ dataset = EnhancedArchitecturalDataset(
314
+ self.data_dir,
315
+ transform=self.train_transform,
316
+ split='train',
317
+ num_samples=num_samples,
318
+ use_albumentations=self.use_albumentations
319
+ )
320
+
321
+ return DataLoader(
322
+ dataset,
323
+ batch_size=self.batch_size,
324
+ shuffle=True,
325
+ num_workers=self.num_workers,
326
+ pin_memory=True,
327
+ drop_last=True # Drop incomplete batches for better training
328
+ )
329
+
330
+ def get_val_loader(self, num_samples: Optional[int] = None) -> DataLoader:
331
+ """Get validation data loader."""
332
+ dataset = EnhancedArchitecturalDataset(
333
+ self.data_dir,
334
+ transform=self.val_transform,
335
+ split='val',
336
+ num_samples=num_samples,
337
+ use_albumentations=self.use_albumentations
338
+ )
339
+
340
+ return DataLoader(
341
+ dataset,
342
+ batch_size=self.batch_size,
343
+ shuffle=False,
344
+ num_workers=self.num_workers,
345
+ pin_memory=True
346
+ )
347
+
348
+ def get_test_loader(self, num_samples: Optional[int] = None) -> DataLoader:
349
+ """Get test data loader."""
350
+ dataset = EnhancedArchitecturalDataset(
351
+ self.data_dir,
352
+ transform=self.test_transform,
353
+ split='test',
354
+ num_samples=num_samples,
355
+ use_albumentations=self.use_albumentations
356
+ )
357
+
358
+ return DataLoader(
359
+ dataset,
360
+ batch_size=self.batch_size,
361
+ shuffle=False,
362
+ num_workers=self.num_workers,
363
+ pin_memory=True
364
+ )
365
+
366
+ def get_all_loaders(self, num_samples: Optional[int] = None) -> Tuple[DataLoader, DataLoader, DataLoader]:
367
+ """Get all data loaders."""
368
+ train_loader = self.get_train_loader(num_samples)
369
+ val_loader = self.get_val_loader(num_samples)
370
+ test_loader = self.get_test_loader(num_samples)
371
+
372
+ return train_loader, val_loader, test_loader
373
+
374
+
375
+ # Keep the original classes for backward compatibility
376
+ class ArchitecturalDataset(EnhancedArchitecturalDataset):
377
+ """Backward compatibility wrapper."""
378
+ pass
379
+
380
+ class ArchitecturalDataLoader(EnhancedArchitecturalDataLoader):
381
+ """Backward compatibility wrapper."""
382
+ pass
383
+
384
+
385
+ class SampleDataGenerator:
386
+ """Generate sample data for testing and development."""
387
+
388
+ def __init__(self, output_dir: str = 'data/sample'):
389
+ self.output_dir = output_dir
390
+ os.makedirs(output_dir, exist_ok=True)
391
+
392
+ def generate_sample_dataset(self, num_classes: int = 25, samples_per_class: int = 100):
393
+ """Generate a complete sample dataset."""
394
+ print(f"Generating sample dataset with {num_classes} classes and {samples_per_class} samples per class...")
395
+
396
+ for class_idx in range(num_classes):
397
+ class_dir = os.path.join(self.output_dir, str(class_idx))
398
+ os.makedirs(class_dir, exist_ok=True)
399
+
400
+ for sample_idx in range(samples_per_class):
401
+ # Generate sample image
402
+ img_array = self._generate_sample_image(class_idx)
403
+
404
+ # Save image
405
+ img = Image.fromarray(img_array)
406
+ img_path = os.path.join(class_dir, f'sample_{sample_idx:03d}.jpg')
407
+ img.save(img_path)
408
+
409
+ print(f"Sample dataset generated in {self.output_dir}")
410
+ print(f"Total images: {num_classes * samples_per_class}")
411
+
412
+ def _generate_sample_image(self, class_idx: int) -> np.ndarray:
413
+ """Generate a sample image for a specific class."""
414
+ # Base image
415
+ img_array = np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)
416
+
417
+ # Add class-specific characteristics
418
+ if class_idx < 5: # Ancient styles (Greek, Roman, etc.)
419
+ # Add columns and arches pattern
420
+ img_array = self._add_ancient_patterns(img_array)
421
+ elif class_idx < 10: # Medieval styles (Gothic, Romanesque)
422
+ # Add pointed arches and spires
423
+ img_array = self._add_medieval_patterns(img_array)
424
+ elif class_idx < 15: # Renaissance styles
425
+ # Add symmetry and classical elements
426
+ img_array = self._add_renaissance_patterns(img_array)
427
+ elif class_idx < 20: # Modern styles
428
+ # Add clean lines and geometric shapes
429
+ img_array = self._add_modern_patterns(img_array)
430
+ else: # Contemporary styles
431
+ # Add abstract and experimental elements
432
+ img_array = self._add_contemporary_patterns(img_array)
433
+
434
+ return img_array
435
+
436
+ def _add_ancient_patterns(self, img_array: np.ndarray) -> np.ndarray:
437
+ """Add ancient architectural patterns."""
438
+ # Add column-like vertical lines
439
+ for i in range(0, 224, 40):
440
+ img_array[:, i:i+10, :] = [150, 100, 50] # Brown columns
441
+
442
+ # Add arch-like curves
443
+ for i in range(50, 174, 60):
444
+ for j in range(50, 174):
445
+ if (j - 112) ** 2 + (i - 87) ** 2 < 1000:
446
+ img_array[j, i:i+20, :] = [200, 150, 100] # Light brown arches
447
+
448
+ return img_array
449
+
450
+ def _add_medieval_patterns(self, img_array: np.ndarray) -> np.ndarray:
451
+ """Add medieval architectural patterns."""
452
+ # Add pointed arches
453
+ for i in range(50, 174, 60):
454
+ for j in range(50, 174):
455
+ if abs(j - 112) < 30 and (i - 87) ** 2 > 500:
456
+ img_array[j, i:i+20, :] = [100, 100, 150] # Blue-gray arches
457
+
458
+ # Add spires
459
+ for i in range(20, 204, 80):
460
+ img_array[0:50, i:i+10, :] = [80, 80, 120] # Dark blue spires
461
+
462
+ return img_array
463
+
464
+ def _add_renaissance_patterns(self, img_array: np.ndarray) -> np.ndarray:
465
+ """Add renaissance architectural patterns."""
466
+ # Add symmetrical facade
467
+ for i in range(50, 174):
468
+ img_array[i, 50:174, :] = [180, 180, 200] # Light facade
469
+
470
+ # Add classical elements
471
+ for i in range(0, 224, 60):
472
+ img_array[100:120, i:i+20, :] = [150, 120, 80] # Classical frieze
473
+
474
+ return img_array
475
+
476
+ def _add_modern_patterns(self, img_array: np.ndarray) -> np.ndarray:
477
+ """Add modern architectural patterns."""
478
+ # Add clean horizontal lines
479
+ for i in range(0, 224, 30):
480
+ img_array[i:i+5, :, :] = [200, 200, 200] # White lines
481
+
482
+ # Add geometric shapes
483
+ for i in range(50, 174, 40):
484
+ for j in range(50, 174, 40):
485
+ img_array[j:j+20, i:i+20, :] = [100, 150, 200] # Blue rectangles
486
+
487
+ return img_array
488
+
489
+ def _add_contemporary_patterns(self, img_array: np.ndarray) -> np.ndarray:
490
+ """Add contemporary architectural patterns."""
491
+ # Add abstract patterns
492
+ for i in range(0, 224, 20):
493
+ for j in range(0, 224, 20):
494
+ if random.random() > 0.7:
495
+ color = np.random.randint(0, 255, 3)
496
+ img_array[j:j+15, i:i+15, :] = color
497
+
498
+ # Add curved elements
499
+ for i in range(50, 174):
500
+ for j in range(50, 174):
501
+ if (i - 112) ** 2 + (j - 87) ** 2 < 2000:
502
+ img_array[j, i, :] = [150, 100, 150] # Purple curves
503
+
504
+ return img_array
505
+
506
+
507
+ def create_sample_dataset(data_dir: str = 'data/sample', num_samples: int = 1000):
508
+ """Create a sample dataset for testing."""
509
+ generator = SampleDataGenerator(data_dir)
510
+ generator.generate_sample_dataset(num_classes=25, samples_per_class=num_samples//25)
511
+ return data_dir