| | """
|
| | Data utilities for fire detection classification
|
| | Handles data loading, transformations, and dataset management
|
| | """
|
| |
|
| | import os
|
| | import torch
|
| | from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
|
| | from torchvision import transforms, datasets
|
| | from PIL import Image
|
| | import numpy as np
|
| | from typing import Tuple, Dict, List, Optional
|
| | from collections import Counter
|
| | import random
|
| |
|
class FireDetectionDataset(Dataset):
    """
    Custom dataset for fire detection images

    Supports both training and validation modes with appropriate transforms.
    Expects a directory layout of ``<data_dir>/<split>/<class_name>/...``.
    """

    # Image file extensions accepted by the directory scan (matched
    # case-insensitively against the lowercased filename).
    _IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff')

    def __init__(self, data_dir: str, split: str = 'train', image_size: int = 224):
        """
        Initialize fire detection dataset

        Args:
            data_dir: Root directory containing train/val folders
            split: 'train' or 'val'
            image_size: Size to resize images to
        """
        self.data_dir = data_dir
        self.split = split
        self.image_size = image_size

        # Fixed label order: 'fire' -> 0, 'no_fire' -> 1.
        self.classes = ['fire', 'no_fire']
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}

        self.samples = self._load_samples()
        self.transform = self._get_transforms()

        print(f"🔥 {split.upper()} Dataset loaded:")
        print(f"  Total samples: {len(self.samples)}")
        print(f"  Classes: {self.classes}")
        self._print_class_distribution()

    def _load_samples(self) -> List[Tuple[str, int]]:
        """Load (image_path, label) pairs from <data_dir>/<split>/<class>/."""
        samples = []
        split_dir = os.path.join(self.data_dir, self.split)

        for class_name in self.classes:
            class_dir = os.path.join(split_dir, class_name)
            if not os.path.exists(class_dir):
                print(f"⚠️ Warning: {class_dir} not found")
                continue

            class_idx = self.class_to_idx[class_name]

            # Sort directories and files so sample ordering is
            # deterministic across runs and filesystems (os.walk order
            # is otherwise platform-dependent).
            for root, dirs, files in os.walk(class_dir):
                dirs.sort()
                for img_name in sorted(files):
                    if img_name.lower().endswith(self._IMAGE_EXTENSIONS):
                        img_path = os.path.join(root, img_name)
                        samples.append((img_path, class_idx))

        return samples

    def _print_class_distribution(self):
        """Print class distribution for the dataset"""
        class_counts = Counter([label for _, label in self.samples])
        for class_name, class_idx in self.class_to_idx.items():
            count = class_counts.get(class_idx, 0)
            print(f"  {class_name}: {count} samples")

    def _get_transforms(self) -> transforms.Compose:
        """Get appropriate transforms for the split.

        Training gets augmentation (crop/flip/rotate/jitter/erasing);
        validation gets a plain resize. Both use ImageNet normalization.
        """
        if self.split == 'train':
            return transforms.Compose([
                transforms.Resize((self.image_size + 32, self.image_size + 32)),
                transforms.RandomResizedCrop(self.image_size, scale=(0.8, 1.0)),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomRotation(degrees=10),
                transforms.ColorJitter(
                    brightness=0.2,
                    contrast=0.2,
                    saturation=0.2,
                    hue=0.1
                ),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]
                ),
                # RandomErasing operates on tensors, so it must come
                # after ToTensor/Normalize.
                transforms.RandomErasing(p=0.1, scale=(0.02, 0.08))
            ])
        else:
            return transforms.Compose([
                transforms.Resize((self.image_size, self.image_size)),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]
                )
            ])

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        """Get a (transformed_image, label) sample from the dataset."""
        img_path, label = self.samples[idx]

        try:
            # Context manager closes the underlying file handle promptly
            # instead of leaking it until garbage collection (PIL opens
            # files lazily and otherwise keeps them open).
            with Image.open(img_path) as img:
                image = img.convert('RGB')
        except Exception as e:
            print(f"⚠️ Error loading image {img_path}: {e}")
            # Fall back to a black placeholder so a single corrupt file
            # does not abort a whole training epoch.
            image = Image.new('RGB', (self.image_size, self.image_size), color='black')

        if self.transform:
            image = self.transform(image)

        return image, label
|
| |
|
def create_data_loaders(
    data_dir: str,
    batch_size: int = 16,
    num_workers: int = 4,
    image_size: int = 224,
    use_weighted_sampling: bool = True
) -> Tuple[DataLoader, DataLoader]:
    """
    Build the train and validation DataLoaders for fire detection.

    Args:
        data_dir: Root directory containing train/val folders
        batch_size: Batch size for data loaders
        num_workers: Number of worker processes
        image_size: Size to resize images to
        use_weighted_sampling: Whether to use weighted sampling for imbalanced data

    Returns:
        Tuple of (train_loader, val_loader)
    """
    train_dataset = FireDetectionDataset(data_dir, 'train', image_size)
    val_dataset = FireDetectionDataset(data_dir, 'val', image_size)

    # A sampler is only built when requested and the dataset is non-empty;
    # DataLoader forbids combining sampler= with shuffle=True, hence the
    # shuffle toggle below.
    sampler = (
        create_weighted_sampler(train_dataset)
        if use_weighted_sampling and len(train_dataset) > 0
        else None
    )

    pin = torch.cuda.is_available()

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        sampler=sampler,
        shuffle=sampler is None,
        num_workers=num_workers,
        pin_memory=pin,
        drop_last=True
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin
    )

    # Summary of the loader configuration.
    for line in (
        "📦 Data loaders created:",
        f"  Batch size: {batch_size}",
        f"  Num workers: {num_workers}",
        f"  Train batches: {len(train_loader)}",
        f"  Val batches: {len(val_loader)}",
        f"  Weighted sampling: {use_weighted_sampling}",
    ):
        print(line)

    return train_loader, val_loader
|
| |
|
def create_weighted_sampler(dataset: FireDetectionDataset) -> WeightedRandomSampler:
    """
    Build a WeightedRandomSampler that rebalances an imbalanced dataset.

    Args:
        dataset: The dataset to create sampler for

    Returns:
        WeightedRandomSampler for balanced sampling
    """
    labels = [label for _, label in dataset.samples]
    counts = Counter(labels)
    n_total = len(labels)

    # Inverse-frequency weighting: rarer classes receive proportionally
    # larger per-sample weights so each class is drawn roughly equally.
    weight_of = {cls: n_total / cnt for cls, cnt in counts.items()}

    sampler = WeightedRandomSampler(
        weights=[weight_of[lbl] for lbl in labels],
        num_samples=n_total,
        replacement=True
    )

    print(f"⚖️ Weighted sampler created:")
    for class_name, class_idx in dataset.class_to_idx.items():
        count = counts.get(class_idx, 0)
        weight = weight_of.get(class_idx, 0)
        print(f"  {class_name}: {count} samples, weight: {weight:.2f}")

    return sampler
|
| |
|
def get_inference_transform(image_size: int = 224) -> transforms.Compose:
    """
    Build the deterministic preprocessing pipeline used at inference time.

    Args:
        image_size: Size to resize images to

    Returns:
        Transform pipeline for inference
    """
    # ImageNet statistics, matching the normalization used in training.
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
    return transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        normalize,
    ])
|
| |
|
def prepare_image_for_inference(image: Image.Image, transform: transforms.Compose) -> torch.Tensor:
    """
    Turn a PIL image into a batched tensor ready for the model.

    Args:
        image: PIL Image
        transform: Transform pipeline

    Returns:
        Tensor ready for model inference
    """
    # Apply the preprocessing pipeline, then prepend a batch axis of 1.
    return transform(image).unsqueeze(0)
|
| |
|
def visualize_batch(data_loader: DataLoader, num_samples: int = 8) -> None:
    """
    Visualize a batch of images from the data loader.

    Args:
        data_loader: DataLoader to sample from
        num_samples: Number of samples to visualize
    """
    import matplotlib.pyplot as plt

    images, labels = next(iter(data_loader))
    num_shown = min(num_samples, len(images))

    # Undo ImageNet normalization so the images display correctly.
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    # Size the grid to the number of samples actually shown. The
    # original hardcoded a 2x4 grid, which raised IndexError whenever
    # num_samples > 8 and the batch was large enough.
    ncols = 4
    nrows = max(1, (num_shown + ncols - 1) // ncols)
    fig, axes = plt.subplots(nrows, ncols, figsize=(15, 4 * nrows))
    axes = np.atleast_1d(axes).flatten()

    class_names = ['Fire', 'No Fire']

    for i in range(num_shown):
        # Denormalize and clamp into the displayable [0, 1] range.
        img = torch.clamp(images[i] * std + mean, 0, 1)
        img_np = img.permute(1, 2, 0).numpy()

        axes[i].imshow(img_np)
        axes[i].set_title(f'{class_names[labels[i]]}')
        axes[i].axis('off')

    # Blank out any leftover axes in the grid.
    for ax in axes[num_shown:]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()
|
| |
|
def check_data_directory(data_dir: str) -> Dict[str, int]:
    """
    Inspect the data directory layout and report per-split image counts.

    Args:
        data_dir: Directory to check

    Returns:
        Dictionary with data counts
    """
    counts: Dict[str, int] = {}

    if not os.path.exists(data_dir):
        print(f"❌ Data directory not found: {data_dir}")
        return counts

    print(f"📊 Data Directory Analysis: {data_dir}")
    print("=" * 50)

    extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff')
    grand_total = 0

    for split in ('train', 'val'):
        split_dir = os.path.join(data_dir, split)
        if not os.path.exists(split_dir):
            continue

        print(f"\n{split.upper()} SET:")
        split_total = 0

        for class_name in ('fire', 'no_fire'):
            class_dir = os.path.join(split_dir, class_name)
            if not os.path.exists(class_dir):
                continue

            # Recursively count files with a recognized image extension.
            n_images = sum(
                1
                for _, _, files in os.walk(class_dir)
                for fname in files
                if fname.lower().endswith(extensions)
            )

            print(f"  {class_name}: {n_images} images")
            counts[f"{split}_{class_name}"] = n_images
            split_total += n_images

        print(f"  Total {split}: {split_total}")
        grand_total += split_total
        counts[f"{split}_total"] = split_total

    print(f"\nOVERALL TOTAL: {grand_total} images")
    counts['total'] = grand_total
    print("=" * 50)

    return counts
|
| |
|
def create_sample_data_structure(base_dir: str = 'data') -> None:
    """
    Create the expected directory structure for fire detection data.

    Args:
        base_dir: Root directory under which the train/val class folders
            are created. Defaults to 'data', preserving the original
            hardcoded behavior for existing callers.
    """
    print("🔥 Creating sample fire detection data structure...")

    # <base_dir>/<split>/<class_name> for both splits and both classes.
    directories = [
        os.path.join(base_dir, split, class_name)
        for split in ('train', 'val')
        for class_name in ('fire', 'no_fire')
    ]

    for directory in directories:
        os.makedirs(directory, exist_ok=True)

    print("✅ Sample data structure created")
    print("  Please add your fire detection images to the appropriate directories")
    print(f"  - {base_dir}/train/fire/ (training fire images)")
    print(f"  - {base_dir}/train/no_fire/ (training no-fire images)")
    print(f"  - {base_dir}/val/fire/ (validation fire images)")
    print(f"  - {base_dir}/val/no_fire/ (validation no-fire images)")