|
|
""" |
|
|
Data utilities for telecom site classification |
|
|
Handles data loading, transformations, and dataset management |
|
|
""" |
|
|
|
|
|
import os |
|
|
import torch |
|
|
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler |
|
|
from torchvision import transforms, datasets |
|
|
from PIL import Image |
|
|
import numpy as np |
|
|
from typing import Tuple, Dict, List, Optional |
|
|
from collections import Counter |
|
|
import random |
|
|
|
|
|
class TelecomSiteDataset(Dataset):
    """
    Custom dataset for telecom site images.

    Expects a ``data_dir/<split>/<class_name>/`` directory layout with the
    classes 'bad' (label 0) and 'good' (label 1). Supports both training
    and validation modes with appropriate transforms.
    """

    # File extensions accepted when scanning class directories.
    # '.tif' added alongside '.tiff' — both are valid TIFF extensions and
    # PIL opens either.
    IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')

    def __init__(self, data_dir: str, split: str = 'train', image_size: int = 224):
        """
        Initialize telecom site dataset.

        Args:
            data_dir: Root directory containing train/val folders
            split: 'train' or 'val'
            image_size: Size to resize images to
        """
        self.data_dir = data_dir
        self.split = split
        self.image_size = image_size

        # Fixed class order: index 0 = 'bad', index 1 = 'good'.
        self.classes = ['bad', 'good']
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}

        # List of (image_path, label) tuples for this split.
        self.samples = self._load_samples()

        # Augmenting pipeline for 'train', deterministic resize otherwise.
        self.transform = self._get_transforms()

        print(f"π {split.upper()} Dataset loaded:")
        print(f" Total samples: {len(self.samples)}")
        print(f" Classes: {self.classes}")
        self._print_class_distribution()

    def _load_samples(self) -> List[Tuple[str, int]]:
        """Load (image_path, label) pairs for this split.

        Missing class directories are skipped with a warning rather than
        raising, so partially populated data trees still load.
        """
        samples = []
        split_dir = os.path.join(self.data_dir, self.split)

        for class_name in self.classes:
            class_dir = os.path.join(split_dir, class_name)
            if not os.path.exists(class_dir):
                print(f"β οΈ Warning: {class_dir} not found")
                continue

            class_idx = self.class_to_idx[class_name]

            # sorted() makes sample order deterministic: os.listdir order is
            # arbitrary and filesystem-dependent, which otherwise makes runs
            # non-reproducible across machines.
            for img_name in sorted(os.listdir(class_dir)):
                if img_name.lower().endswith(self.IMG_EXTENSIONS):
                    img_path = os.path.join(class_dir, img_name)
                    samples.append((img_path, class_idx))

        return samples

    def _print_class_distribution(self):
        """Print per-class sample counts for the dataset."""
        class_counts = Counter([label for _, label in self.samples])
        for class_name, class_idx in self.class_to_idx.items():
            count = class_counts.get(class_idx, 0)
            print(f" {class_name}: {count} samples")

    def _get_transforms(self) -> transforms.Compose:
        """Build the transform pipeline for this split.

        Training gets augmentation (random crop/flip/rotation/jitter and
        random erasing); any other split gets a plain resize. Both use
        ImageNet normalization statistics.
        """
        if self.split == 'train':
            return transforms.Compose([
                # Oversize first so RandomResizedCrop has margin to crop from.
                transforms.Resize((self.image_size + 32, self.image_size + 32)),
                transforms.RandomResizedCrop(self.image_size, scale=(0.8, 1.0)),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomRotation(degrees=10),
                transforms.ColorJitter(
                    brightness=0.2,
                    contrast=0.2,
                    saturation=0.2,
                    hue=0.1
                ),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]
                ),
                # RandomErasing operates on tensors, so it must come after
                # ToTensor/Normalize.
                transforms.RandomErasing(p=0.1, scale=(0.02, 0.08))
            ])
        else:
            return transforms.Compose([
                transforms.Resize((self.image_size, self.image_size)),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]
                )
            ])

    def __len__(self) -> int:
        """Number of samples in this split."""
        return len(self.samples)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        """Return (transformed image tensor, integer label) at idx.

        A corrupt or unreadable file is replaced by a black placeholder
        image (with a warning) instead of crashing the DataLoader worker.
        """
        img_path, label = self.samples[idx]

        try:
            image = Image.open(img_path).convert('RGB')
        except Exception as e:
            print(f"β οΈ Error loading image {img_path}: {e}")
            image = Image.new('RGB', (self.image_size, self.image_size), color='black')

        if self.transform:
            image = self.transform(image)

        return image, label
|
|
|
|
|
def create_data_loaders(
    data_dir: str,
    batch_size: int = 16,
    num_workers: int = 4,
    image_size: int = 224,
    use_weighted_sampling: bool = True
) -> Tuple[DataLoader, DataLoader]:
    """
    Build the train and validation data loaders.

    Args:
        data_dir: Root directory containing train/val folders
        batch_size: Batch size for both loaders
        num_workers: Number of worker processes per loader
        image_size: Size to resize images to
        use_weighted_sampling: If True (and the train set is non-empty),
            draw training samples with inverse-frequency class weights

    Returns:
        Tuple of (train_loader, val_loader)
    """
    train_dataset = TelecomSiteDataset(data_dir, 'train', image_size)
    val_dataset = TelecomSiteDataset(data_dir, 'val', image_size)

    # Pinned host memory speeds up H2D copies; only useful with CUDA.
    pin = torch.cuda.is_available()

    # A sampler and shuffle=True are mutually exclusive in DataLoader,
    # so shuffling is enabled only when no sampler is used.
    sampler = (
        create_weighted_sampler(train_dataset)
        if use_weighted_sampling and len(train_dataset) > 0
        else None
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        sampler=sampler,
        shuffle=sampler is None,
        num_workers=num_workers,
        pin_memory=pin,
        drop_last=True,
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin,
    )

    print(f"π¦ Data loaders created:")
    print(f" Batch size: {batch_size}")
    print(f" Num workers: {num_workers}")
    print(f" Train batches: {len(train_loader)}")
    print(f" Val batches: {len(val_loader)}")
    print(f" Weighted sampling: {use_weighted_sampling}")

    return train_loader, val_loader
|
|
|
|
|
def create_weighted_sampler(dataset: TelecomSiteDataset) -> WeightedRandomSampler:
    """
    Build a weighted random sampler that rebalances class frequencies.

    Each sample is weighted by the inverse frequency of its class, so
    minority classes are drawn as often as majority ones (with
    replacement).

    Args:
        dataset: The dataset to create the sampler for

    Returns:
        WeightedRandomSampler for balanced sampling
    """
    labels = [label for _, label in dataset.samples]
    label_counts = Counter(labels)
    n_total = len(dataset.samples)
    n_classes = len(dataset.classes)

    # Inverse-frequency weight per class index; a class with no samples
    # falls back to count 1 (its weight is then never used anyway).
    class_weights = {
        idx: n_total / (n_classes * label_counts.get(idx, 1))
        for idx in range(n_classes)
    }

    # Per-sample weight = weight of that sample's class.
    sample_weights = [class_weights[lbl] for lbl in labels]

    print(f"βοΈ Weighted sampler created:")
    for class_name, class_idx in dataset.class_to_idx.items():
        print(f" {class_name}: weight={class_weights[class_idx]:.3f}")

    return WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(sample_weights),
        replacement=True,
    )
|
|
|
|
|
def get_inference_transform(image_size: int = 224) -> transforms.Compose:
    """
    Build the transform pipeline used at inference/prediction time.

    Deterministic pipeline: resize, tensorize, ImageNet normalization —
    no augmentation.

    Args:
        image_size: Size to resize images to

    Returns:
        Transform pipeline for inference
    """
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
    resize = transforms.Resize((image_size, image_size))
    return transforms.Compose([resize, transforms.ToTensor(), normalize])
|
|
|
|
|
def prepare_image_for_inference(image: Image.Image, transform: transforms.Compose) -> torch.Tensor:
    """
    Convert a PIL image into a batched tensor ready for the model.

    Args:
        image: PIL Image (converted to RGB if it is not already)
        transform: Transform pipeline to apply

    Returns:
        Preprocessed tensor of shape (1, C, H, W)
    """
    rgb = image if image.mode == 'RGB' else image.convert('RGB')
    # unsqueeze(0) adds the leading batch dimension expected by the model.
    return transform(rgb).unsqueeze(0)
|
|
|
|
|
def visualize_batch(data_loader: DataLoader, num_samples: int = 8) -> None:
    """
    Visualize one batch of images from the data loader.

    Args:
        data_loader: DataLoader to sample from
        num_samples: Number of samples to visualize (capped at batch size)

    Notes:
        Prints a warning and returns if matplotlib is not installed.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        print("β οΈ Matplotlib not available for visualization")
        return

    batch_images, batch_labels = next(iter(data_loader))

    # Never try to show more images than the batch actually holds.
    n_show = min(num_samples, len(batch_images))
    if n_show == 0:
        return

    # Undo ImageNet normalization so the images display in natural colors.
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    # Size the grid from the requested count: the previous fixed 2x4 grid
    # raised IndexError whenever num_samples > 8.
    n_cols = min(4, n_show)
    n_rows = (n_show + n_cols - 1) // n_cols
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows))
    # atleast_1d handles the 1x1 case, where subplots returns a bare Axes.
    axes = np.atleast_1d(axes).ravel()

    class_names = ['Bad', 'Good']

    for i in range(n_show):
        img = torch.clamp(batch_images[i] * std + mean, 0, 1)
        axes[i].imshow(img.permute(1, 2, 0).numpy())
        axes[i].set_title(f'Class: {class_names[batch_labels[i]]}')
        axes[i].axis('off')

    # Blank out any grid cells beyond the shown samples.
    for ax in axes[n_show:]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()
|
|
|
|
|
def check_data_directory(data_dir: str) -> Dict[str, Dict[str, int]]:
    """
    Check the data directory structure and count image samples.

    Args:
        data_dir: Root directory to check; expected to contain train/ and
            val/ subdirectories, each with good/ and bad/ class folders

    Returns:
        Mapping of split name -> {class_name: image_count}. Splits whose
        directory is missing are omitted; an empty dict is returned when
        data_dir itself does not exist.
        (Annotation fixed: the function returns nested dicts, not
        Dict[str, int].)
    """
    print(f"π Checking data directory: {data_dir}")

    if not os.path.exists(data_dir):
        print(f"β Data directory not found: {data_dir}")
        return {}

    # '.tif' accepted alongside '.tiff' — both are valid TIFF extensions.
    valid_exts = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')
    counts = {}

    for split in ['train', 'val']:
        split_dir = os.path.join(data_dir, split)
        if not os.path.exists(split_dir):
            print(f"β οΈ {split} directory not found")
            continue

        split_counts = {}
        for class_name in ['good', 'bad']:
            class_dir = os.path.join(split_dir, class_name)
            if os.path.exists(class_dir):
                # Count files by extension only; contents are not validated.
                split_counts[class_name] = sum(
                    1 for f in os.listdir(class_dir)
                    if f.lower().endswith(valid_exts)
                )
            else:
                split_counts[class_name] = 0

        counts[split] = split_counts
        print(f" {split.upper()}: Good={split_counts['good']}, Bad={split_counts['bad']}")

    return counts
|
|
|
|
|
def create_sample_data_structure():
    """Print (and return) instructions describing the expected data layout.

    NOTE(review): despite the name, this creates no directories — it only
    prints the recommended structure, data requirements, and labeling
    criteria. Consider renaming or adding os.makedirs calls.

    Returns:
        The instructions text that was printed.
    """
    # Kept as one literal so the printed output and the returned value
    # are guaranteed identical.
    instructions = """
    π Data Directory Structure:

    data/
    βββ train/
    β βββ good/ # Place good telecom site images here
    β β βββ good_site_001.jpg
    β β βββ good_site_002.jpg
    β β βββ ...
    β βββ bad/ # Place bad telecom site images here
    β βββ bad_site_001.jpg
    β βββ bad_site_002.jpg
    β βββ ...
    βββ val/
    βββ good/ # Validation good images
    β βββ val_good_001.jpg
    β βββ ...
    βββ bad/ # Validation bad images
    βββ val_bad_001.jpg
    βββ ...

    π Data Requirements:
    - Minimum 50 images per class for training
    - 20% of data should be reserved for validation
    - Images should be clear and well-lit
    - Recommended resolution: 224x224 or higher
    - Supported formats: JPG, PNG, JPEG, BMP, TIFF

    π Good Site Criteria:
    - Proper cable assembly and routing
    - All cards correctly installed and labeled
    - Clean and organized equipment layout
    - Proper grounding and safety measures
    - Clear and readable labels

    π Bad Site Criteria:
    - Messy or improper cable routing
    - Missing or incorrectly installed cards
    - Poor equipment organization
    - Missing or unreadable labels
    - Safety issues or violations
    """

    print(instructions)
    return instructions