|
|
"""
|
|
|
Data utilities for fire detection classification
|
|
|
Handles data loading, transformations, and dataset management
|
|
|
"""
|
|
|
|
|
|
import os
|
|
|
import torch
|
|
|
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
|
|
|
from torchvision import transforms, datasets
|
|
|
from PIL import Image
|
|
|
import numpy as np
|
|
|
from typing import Tuple, Dict, List, Optional
|
|
|
from collections import Counter
|
|
|
import random
|
|
|
|
|
|
class FireDetectionDataset(Dataset):
    """
    Custom dataset for fire detection images

    Images are discovered recursively under
    ``<data_dir>/<split>/<class_name>/`` for the classes ``fire`` (label 0)
    and ``no_fire`` (label 1). The ``'train'`` split uses an augmenting
    transform pipeline; any other split value uses a deterministic
    resize + normalize pipeline.
    """

    # File extensions (matched case-insensitively) that count as images.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff')

    def __init__(self, data_dir: str, split: str = 'train', image_size: int = 224):
        """
        Initialize fire detection dataset

        Scans the split directory, builds the transform pipeline and prints
        a short summary (sample count, classes, per-class distribution).

        Args:
            data_dir: Root directory containing train/val folders
            split: 'train' or 'val'
            image_size: Size to resize images to
        """
        self.data_dir = data_dir
        self.split = split
        self.image_size = image_size

        # Label index follows the position in this list: fire -> 0, no_fire -> 1.
        self.classes = ['fire', 'no_fire']
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}

        self.samples = self._load_samples()
        self.transform = self._get_transforms()

        print(f"🔥 {split.upper()} Dataset loaded:")
        print(f" Total samples: {len(self.samples)}")
        print(f" Classes: {self.classes}")
        self._print_class_distribution()

    def _load_samples(self) -> List[Tuple[str, int]]:
        """
        Load image paths and corresponding labels

        Returns:
            List of ``(image_path, class_index)`` tuples. A missing class
            directory only produces a warning and contributes no samples.
        """
        samples = []
        split_dir = os.path.join(self.data_dir, self.split)

        for class_name in self.classes:
            class_dir = os.path.join(split_dir, class_name)
            if not os.path.exists(class_dir):
                print(f"⚠️ Warning: {class_dir} not found")
                continue

            class_idx = self.class_to_idx[class_name]

            for root, dirs, files in os.walk(class_dir):
                # os.walk yields entries in filesystem-dependent order. Sorting
                # both the traversal and the file names keeps sample indices
                # stable across machines and runs.
                dirs.sort()
                for img_name in sorted(files):
                    if img_name.lower().endswith(self.IMAGE_EXTENSIONS):
                        img_path = os.path.join(root, img_name)
                        samples.append((img_path, class_idx))

        return samples

    def _print_class_distribution(self):
        """Print class distribution for the dataset"""
        class_counts = Counter(label for _, label in self.samples)
        for class_name, class_idx in self.class_to_idx.items():
            count = class_counts.get(class_idx, 0)
            print(f" {class_name}: {count} samples")

    def _get_transforms(self) -> transforms.Compose:
        """
        Get appropriate transforms for the split

        Returns:
            For ``split == 'train'``: resize to ``image_size + 32``, random
            resized crop back to ``image_size``, flip, rotation, color
            jitter, tensor conversion, normalization and random erasing.
            For every other split: resize, tensor conversion and
            normalization only.
        """
        # Same channel statistics for both pipelines (these match the widely
        # used ImageNet mean/std values).
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )

        if self.split == 'train':
            return transforms.Compose([
                transforms.Resize((self.image_size + 32, self.image_size + 32)),
                transforms.RandomResizedCrop(self.image_size, scale=(0.8, 1.0)),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomRotation(degrees=10),
                transforms.ColorJitter(
                    brightness=0.2,
                    contrast=0.2,
                    saturation=0.2,
                    hue=0.1
                ),
                transforms.ToTensor(),
                normalize,
                # RandomErasing works on tensors, so it must come after ToTensor.
                transforms.RandomErasing(p=0.1, scale=(0.02, 0.08))
            ])

        return transforms.Compose([
            transforms.Resize((self.image_size, self.image_size)),
            transforms.ToTensor(),
            normalize
        ])

    def __len__(self) -> int:
        """Return the number of discovered samples"""
        return len(self.samples)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        """
        Get a sample from the dataset

        Args:
            idx: Index into the sample list

        Returns:
            Tuple of (transformed image, class index). If the image cannot
            be read, a black placeholder image is used instead and the
            original label is kept.
        """
        img_path, label = self.samples[idx]

        try:
            # The context manager closes the underlying file deterministically;
            # convert() returns an independent in-memory image, so it stays
            # usable after the file is closed.
            with Image.open(img_path) as img:
                image = img.convert('RGB')
        except Exception as e:
            # Deliberate best-effort: one unreadable file must not abort an
            # epoch, so report it and substitute a black image.
            print(f"⚠️ Error loading image {img_path}: {e}")
            image = Image.new('RGB', (self.image_size, self.image_size), color='black')

        if self.transform is not None:
            image = self.transform(image)

        return image, label
|
|
|
|
|
|
def create_data_loaders(
    data_dir: str,
    batch_size: int = 16,
    num_workers: int = 4,
    image_size: int = 224,
    use_weighted_sampling: bool = True
) -> Tuple[DataLoader, DataLoader]:
    """
    Create train and validation data loaders

    Args:
        data_dir: Root directory containing train/val folders
        batch_size: Batch size for data loaders
        num_workers: Number of worker processes
        image_size: Size to resize images to
        use_weighted_sampling: Whether to use weighted sampling for imbalanced data

    Returns:
        Tuple of (train_loader, val_loader)
    """
    train_dataset = FireDetectionDataset(data_dir, 'train', image_size)
    val_dataset = FireDetectionDataset(data_dir, 'val', image_size)

    # A weighted sampler cannot be built from an empty dataset, so fall back
    # to plain shuffling in that case.
    train_sampler = None
    if use_weighted_sampling and len(train_dataset) > 0:
        train_sampler = create_weighted_sampler(train_dataset)

    pin_memory = torch.cuda.is_available()

    # Dropping the last partial batch keeps training batch sizes uniform, but
    # when the whole train set is smaller than one batch it would leave the
    # loader with zero batches. Only drop when at least one full batch exists.
    drop_last = len(train_dataset) >= batch_size

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        sampler=train_sampler,
        # sampler and shuffle are mutually exclusive in DataLoader.
        shuffle=(train_sampler is None),
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=drop_last
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory
    )

    print(f"📦 Data loaders created:")
    print(f" Batch size: {batch_size}")
    print(f" Num workers: {num_workers}")
    print(f" Train batches: {len(train_loader)}")
    print(f" Val batches: {len(val_loader)}")
    # Report whether a sampler is actually in use, not merely requested.
    print(f" Weighted sampling: {train_sampler is not None}")

    return train_loader, val_loader
|
|
|
|
|
|
def create_weighted_sampler(dataset: FireDetectionDataset) -> WeightedRandomSampler:
    """
    Build a WeightedRandomSampler that counteracts class imbalance

    Each sample is weighted by ``total_samples / samples_in_its_class``, so
    rarer classes are drawn more often. Sampling is done with replacement
    and yields as many indices as the dataset has samples. A per-class
    summary is printed.

    Args:
        dataset: The dataset to create sampler for

    Returns:
        WeightedRandomSampler for balanced sampling
    """
    labels = [label for _, label in dataset.samples]
    n_total = len(dataset.samples)
    per_class = Counter(labels)

    # Inverse-frequency weight for every class that actually occurs.
    weight_of = {idx: n_total / n_in_class for idx, n_in_class in per_class.items()}

    sampler = WeightedRandomSampler(
        weights=[weight_of[label] for label in labels],
        num_samples=n_total,
        replacement=True
    )

    print(f"⚖️ Weighted sampler created:")
    for class_name, class_idx in dataset.class_to_idx.items():
        # Classes without samples are reported with count 0 and weight 0.
        n_in_class = per_class.get(class_idx, 0)
        class_weight = weight_of.get(class_idx, 0)
        print(f" {class_name}: {n_in_class} samples, weight: {class_weight:.2f}")

    return sampler
|
|
|
|
|
|
def get_inference_transform(image_size: int = 224) -> transforms.Compose:
    """
    Build the deterministic preprocessing pipeline used for prediction

    The pipeline resizes to a square of ``image_size`` pixels, converts the
    image to a tensor and normalizes each channel.

    Args:
        image_size: Size to resize images to

    Returns:
        Transform pipeline for inference
    """
    target_size = (image_size, image_size)
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
    steps = [
        transforms.Resize(target_size),
        transforms.ToTensor(),
        normalize,
    ]
    return transforms.Compose(steps)
|
|
|
|
|
|
def prepare_image_for_inference(image: Image.Image, transform: transforms.Compose) -> torch.Tensor:
    """
    Turn a single image into a one-element batch

    Args:
        image: PIL Image
        transform: Transform pipeline

    Returns:
        The transformed image with an extra leading dimension of size 1,
        ready for model inference
    """
    # unsqueeze(0) adds the leading batch dimension.
    return transform(image).unsqueeze(0)
|
|
|
|
|
|
def visualize_batch(data_loader: DataLoader, num_samples: int = 8) -> None:
    """
    Visualize a batch of images from the data loader

    Takes the first batch, undoes the normalization and shows up to
    ``num_samples`` images in a grid with four columns. The grid grows with
    the number of images shown (never fewer than two rows), and subplots
    without an image are left blank.

    Args:
        data_loader: DataLoader to sample from
        num_samples: Number of samples to visualize
    """
    # Imported lazily so the module can be used without matplotlib installed.
    import matplotlib.pyplot as plt

    images, labels = next(iter(data_loader))

    # Channel statistics reshaped to broadcast over (C, H, W) images; they are
    # used to invert the Normalize step before display.
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    shown = min(num_samples, len(images))

    # Size the grid from the number of images so that asking for more than
    # eight samples cannot index past the available axes. Up to eight images
    # still gives the 2x4 layout with a 15x8 inch figure.
    cols = 4
    rows = max(2, -(-shown // cols))  # ceiling division
    _fig, axes = plt.subplots(rows, cols, figsize=(15, 4 * rows))
    axes = axes.flatten()

    class_names = ['Fire', 'No Fire']

    for i in range(shown):
        # Undo normalization and clamp to the displayable [0, 1] range.
        img = torch.clamp(images[i] * std + mean, 0, 1)

        # imshow expects channels last: (C, H, W) -> (H, W, C).
        img_np = img.permute(1, 2, 0).numpy()

        axes[i].imshow(img_np)
        axes[i].set_title(f'{class_names[int(labels[i])]}')

    # Hide ticks and frames everywhere, including unused subplots.
    for ax in axes:
        ax.axis('off')

    plt.tight_layout()
    plt.show()
|
|
|
|
|
|
def check_data_directory(data_dir: str) -> Dict[str, int]:
    """
    Report how many images each split/class folder contains

    Looks for ``<data_dir>/<split>/<class>`` with splits ``train``/``val``
    and classes ``fire``/``no_fire``, counts image files recursively and
    prints a summary. Splits or classes whose folder is missing are skipped
    and get no entry in the result.

    Args:
        data_dir: Directory to check

    Returns:
        Dictionary with data counts, keyed ``"<split>_<class>"``,
        ``"<split>_total"`` and ``"total"``. Empty if ``data_dir`` does
        not exist.
    """
    image_suffixes = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff')
    summary = {}

    if not os.path.exists(data_dir):
        print(f"❌ Data directory not found: {data_dir}")
        return summary

    print(f"📊 Data Directory Analysis: {data_dir}")
    print("=" * 50)

    grand_total = 0

    for split in ('train', 'val'):
        split_path = os.path.join(data_dir, split)
        if os.path.exists(split_path):
            print(f"\n{split.upper()} SET:")
            images_in_split = 0

            for class_name in ('fire', 'no_fire'):
                class_path = os.path.join(split_path, class_name)
                if os.path.exists(class_path):
                    # Count matching files at any depth below the class folder.
                    n_images = sum(
                        1
                        for _, _, filenames in os.walk(class_path)
                        for filename in filenames
                        if filename.lower().endswith(image_suffixes)
                    )

                    print(f" {class_name}: {n_images} images")
                    summary[f"{split}_{class_name}"] = n_images
                    images_in_split += n_images

            print(f" Total {split}: {images_in_split}")
            grand_total += images_in_split
            summary[f"{split}_total"] = images_in_split

    print(f"\nOVERALL TOTAL: {grand_total} images")
    summary['total'] = grand_total
    print("=" * 50)

    return summary
|
|
|
|
|
|
def create_sample_data_structure(base_dir: str = 'data') -> None:
    """
    Create sample data structure for testing

    Creates the empty folders ``<base_dir>/{train,val}/{fire,no_fire}``
    (existing folders are left untouched) and prints where to put the images.

    Args:
        base_dir: Root folder of the data structure. Defaults to ``'data'``,
            relative to the current working directory.
    """
    print("🔥 Creating sample fire detection data structure...")

    # (split, class folder, description shown to the user)
    layout = [
        ('train', 'fire', 'training fire images'),
        ('train', 'no_fire', 'training no-fire images'),
        ('val', 'fire', 'validation fire images'),
        ('val', 'no_fire', 'validation no-fire images'),
    ]

    for split, class_name, _ in layout:
        # exist_ok makes repeated calls harmless.
        os.makedirs(os.path.join(base_dir, split, class_name), exist_ok=True)

    print("✅ Sample data structure created")
    print(" Please add your fire detection images to the appropriate directories")
    for split, class_name, description in layout:
        print(f" - {base_dir}/{split}/{class_name}/ ({description})")