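# NOTE(review): "Spaces: Sleeping" appears to be a hosting-page status banner
# that was captured together with this file; it is not part of the module.
"""
Data utilities for telecom site classification.

Handles data loading, transformations, and dataset management.
"""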
import os
import random
from collections import Counter
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
from torchvision import datasets, transforms
class TelecomSiteDataset(Dataset):
    """
    Custom dataset for telecom site images.

    Images are read from ``<data_dir>/<split>/<class_name>/`` for every name
    in ``classes``. The 'train' split gets an augmenting transform pipeline;
    any other split gets a deterministic resize + normalize pipeline.
    """

    # Lower-case file suffixes that are accepted as images.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff')
    # Per-channel mean / std handed to transforms.Normalize in both pipelines.
    NORM_MEAN = [0.485, 0.456, 0.406]
    NORM_STD = [0.229, 0.224, 0.225]

    def __init__(self, data_dir: str, split: str = 'train', image_size: int = 224):
        """
        Initialize telecom site dataset.

        Args:
            data_dir: Root directory containing train/val folders
            split: 'train' or 'val'
            image_size: Size to resize images to
        """
        self.data_dir = data_dir
        self.split = split
        self.image_size = image_size

        # Class mapping: the list position is the integer label.
        self.classes = ['bad', 'good']  # 0: bad, 1: good
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}

        # Image paths + labels, then the transform pipeline for this split.
        self.samples = self._load_samples()
        self.transform = self._get_transforms()

        print(f"📊 {split.upper()} Dataset loaded:")
        print(f" Total samples: {len(self.samples)}")
        print(f" Classes: {self.classes}")
        self._print_class_distribution()

    def _load_samples(self) -> List[Tuple[str, int]]:
        """
        Collect ``(image_path, class_index)`` pairs for the current split.

        Class folders that do not exist are reported and skipped. File names
        are visited in sorted order so that a given index always maps to the
        same file (``os.listdir`` returns entries in arbitrary order).
        """
        samples: List[Tuple[str, int]] = []
        split_dir = os.path.join(self.data_dir, self.split)

        for class_name in self.classes:
            class_dir = os.path.join(split_dir, class_name)
            if not os.path.isdir(class_dir):
                print(f"⚠️ Warning: {class_dir} not found")
                continue

            class_idx = self.class_to_idx[class_name]
            for img_name in sorted(os.listdir(class_dir)):
                # Suffix match is case-insensitive.
                if img_name.lower().endswith(self.IMAGE_EXTENSIONS):
                    samples.append((os.path.join(class_dir, img_name), class_idx))

        return samples

    def _print_class_distribution(self):
        """Print the number of loaded samples for every class."""
        class_counts = Counter(label for _, label in self.samples)
        for class_name, class_idx in self.class_to_idx.items():
            count = class_counts.get(class_idx, 0)
            print(f" {class_name}: {count} samples")

    def _get_transforms(self) -> transforms.Compose:
        """Build the transform pipeline that matches ``self.split``."""
        normalize = transforms.Normalize(mean=self.NORM_MEAN, std=self.NORM_STD)

        if self.split == 'train':
            return transforms.Compose([
                # Resize slightly larger than the target so the random crop
                # has some margin to work with.
                transforms.Resize((self.image_size + 32, self.image_size + 32)),
                transforms.RandomResizedCrop(self.image_size, scale=(0.8, 1.0)),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomRotation(degrees=10),
                transforms.ColorJitter(
                    brightness=0.2,
                    contrast=0.2,
                    saturation=0.2,
                    hue=0.1
                ),
                transforms.ToTensor(),
                normalize,
                # Applied after ToTensor / Normalize, i.e. on the tensor.
                transforms.RandomErasing(p=0.1, scale=(0.02, 0.08))
            ])

        return transforms.Compose([
            transforms.Resize((self.image_size, self.image_size)),
            transforms.ToTensor(),
            normalize
        ])

    def __len__(self) -> int:
        """Return the number of samples found for this split."""
        return len(self.samples)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        """
        Return the transformed image and its label for ``idx``.

        An image that cannot be read is reported and replaced by a black
        ``image_size`` x ``image_size`` RGB image, so one bad file does not
        abort iteration.
        """
        img_path, label = self.samples[idx]

        try:
            # The context manager releases the file handle; convert() returns
            # an independent, fully loaded image.
            with Image.open(img_path) as img:
                image = img.convert('RGB')
        except Exception as e:  # deliberate best-effort fallback
            print(f"⚠️ Error loading image {img_path}: {e}")
            image = Image.new('RGB', (self.image_size, self.image_size), color='black')

        if self.transform is not None:
            image = self.transform(image)

        return image, label
def create_data_loaders(
    data_dir: str,
    batch_size: int = 16,
    num_workers: int = 4,
    image_size: int = 224,
    use_weighted_sampling: bool = True
) -> Tuple[DataLoader, DataLoader]:
    """
    Create train and validation data loaders.

    Args:
        data_dir: Root directory containing train/val folders
        batch_size: Batch size for data loaders
        num_workers: Number of worker processes
        image_size: Size to resize images to
        use_weighted_sampling: Whether to use weighted sampling for imbalanced data

    Returns:
        Tuple of (train_loader, val_loader)
    """
    train_dataset = TelecomSiteDataset(data_dir, 'train', image_size)
    val_dataset = TelecomSiteDataset(data_dir, 'val', image_size)

    # A weighted sampler is only built when there is something to sample.
    train_sampler = None
    if use_weighted_sampling and len(train_dataset) > 0:
        train_sampler = create_weighted_sampler(train_dataset)

    pin_memory = torch.cuda.is_available()

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        sampler=train_sampler,
        # `sampler` and `shuffle` are mutually exclusive in DataLoader.
        shuffle=(train_sampler is None),
        num_workers=num_workers,
        pin_memory=pin_memory,
        # Drop the ragged last batch only when at least one full batch
        # exists; otherwise a small dataset would yield no batches at all.
        drop_last=len(train_dataset) >= batch_size
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory
    )

    print("📦 Data loaders created:")
    print(f" Batch size: {batch_size}")
    print(f" Num workers: {num_workers}")
    print(f" Train batches: {len(train_loader)}")
    print(f" Val batches: {len(val_loader)}")
    # Report what is actually in effect, not merely what was requested.
    print(f" Weighted sampling: {train_sampler is not None}")

    return train_loader, val_loader
def create_weighted_sampler(dataset: TelecomSiteDataset) -> WeightedRandomSampler:
    """
    Create weighted random sampler for imbalanced datasets.

    Each class gets the weight ``total / (num_classes * class_count)``
    (inverse frequency); a class with no samples is treated as having a
    count of 1. Every sample carries the weight of its class, and the
    sampler draws ``len(dataset.samples)`` indices with replacement.

    Args:
        dataset: The dataset to create sampler for; its ``samples``,
            ``classes`` and ``class_to_idx`` attributes are read.

    Returns:
        WeightedRandomSampler for balanced sampling

    Raises:
        ValueError: If the dataset contains no samples.
    """
    total_samples = len(dataset.samples)
    if total_samples == 0:
        # Fail early with a clear message instead of deep inside torch.
        raise ValueError("Cannot create a weighted sampler for an empty dataset")

    class_counts = Counter(label for _, label in dataset.samples)
    num_classes = len(dataset.classes)

    # Inverse-frequency weight per class index.
    class_weights = {
        class_idx: total_samples / (num_classes * class_counts.get(class_idx, 1))
        for class_idx in range(num_classes)
    }

    sample_weights = [class_weights[label] for _, label in dataset.samples]
    sampler = WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(sample_weights),
        replacement=True
    )

    print("⚖️ Weighted sampler created:")
    for class_name, class_idx in dataset.class_to_idx.items():
        print(f" {class_name}: weight={class_weights[class_idx]:.3f}")

    return sampler
def get_inference_transform(image_size: int = 224) -> transforms.Compose:
    """
    Get transform for inference/prediction.

    The pipeline resizes to ``image_size`` x ``image_size``, converts to a
    tensor and normalizes each channel; it contains no random steps.

    Args:
        image_size: Size to resize images to

    Returns:
        Transform pipeline for inference
    """
    return transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])
def prepare_image_for_inference(image: Image.Image, transform: transforms.Compose) -> torch.Tensor:
    """
    Prepare a PIL image for model inference.

    Args:
        image: PIL Image; converted to RGB first when its mode differs
        transform: Transform pipeline applied to the RGB image

    Returns:
        Preprocessed tensor ready for model, with a leading batch
        dimension of size 1 added
    """
    if image.mode != 'RGB':
        image = image.convert('RGB')

    # unsqueeze(0) adds the batch dimension in front of the transform output.
    return transform(image).unsqueeze(0)
def visualize_batch(data_loader: DataLoader, num_samples: int = 8) -> None:
    """
    Visualize a batch of images from the data loader.

    Takes the first batch, undoes the channel normalization and shows up to
    ``num_samples`` images in a grid that is 4 columns wide (at least 2 rows,
    more when needed). Prints a notice and returns when matplotlib is not
    installed or the loader yields no batch.

    Args:
        data_loader: DataLoader to sample from
        num_samples: Number of samples to visualize
    """
    # Only the import is guarded, so unrelated errors are not mislabelled.
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        print("⚠️ Matplotlib not available for visualization")
        return

    try:
        batch_images, batch_labels = next(iter(data_loader))
    except StopIteration:
        print("⚠️ Data loader is empty; nothing to visualize")
        return

    batch_images = batch_images.detach().cpu()

    # Same statistics as the Normalize step, shaped to broadcast over (C, H, W).
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    count = max(0, min(num_samples, len(batch_images)))

    # Grow the grid beyond the default 2x4 so every requested image has an axis.
    cols = 4
    rows = max(2, -(-count // cols))
    fig, axes = plt.subplots(rows, cols, figsize=(12, 3 * rows))
    axes = axes.flatten()
    for ax in axes:
        ax.axis('off')  # also hides the cells that stay empty

    class_names = ['Bad', 'Good']
    for i in range(count):
        # Denormalize and clamp into the displayable [0, 1] range.
        img = torch.clamp(batch_images[i] * std + mean, 0, 1)
        # (C, H, W) -> (H, W, C) for imshow.
        img_np = img.permute(1, 2, 0).numpy()

        axes[i].imshow(img_np)
        axes[i].set_title(f'Class: {class_names[int(batch_labels[i])]}')

    plt.tight_layout()
    plt.show()
def check_data_directory(data_dir: str) -> Dict[str, Dict[str, int]]:
    """
    Check the data directory structure and count samples.

    Args:
        data_dir: Root directory to check

    Returns:
        Dictionary mapping each split that exists ('train', 'val') to a
        dictionary of image counts per class ('good', 'bad'). A missing
        class folder counts as 0; a missing split is left out. Empty when
        ``data_dir`` itself does not exist.
    """
    image_suffixes = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff')

    print(f"🔍 Checking data directory: {data_dir}")
    if not os.path.isdir(data_dir):
        print(f"❌ Data directory not found: {data_dir}")
        return {}

    counts: Dict[str, Dict[str, int]] = {}
    for split in ['train', 'val']:
        split_dir = os.path.join(data_dir, split)
        if not os.path.isdir(split_dir):
            print(f"⚠️ {split} directory not found")
            continue

        split_counts = {}
        for class_name in ['good', 'bad']:
            class_dir = os.path.join(split_dir, class_name)
            if os.path.isdir(class_dir):
                # Case-insensitive suffix match against the accepted formats.
                split_counts[class_name] = sum(
                    1 for f in os.listdir(class_dir)
                    if f.lower().endswith(image_suffixes)
                )
            else:
                split_counts[class_name] = 0

        counts[split] = split_counts
        print(f" {split.upper()}: Good={split_counts['good']}, Bad={split_counts['bad']}")

    return counts
def create_sample_data_structure() -> str:
    """
    Print and return instructions describing the expected data layout.

    No directories are created; the text shows the train/val folder tree,
    the data requirements and the criteria for good and bad sites.

    Returns:
        The instruction text that was printed
    """
    lines = [
        "",
        "📁 Data Directory Structure:",
        "data/",
        "├── train/",
        "│   ├── good/          # Place good telecom site images here",
        "│   │   ├── good_site_001.jpg",
        "│   │   ├── good_site_002.jpg",
        "│   │   └── ...",
        "│   └── bad/           # Place bad telecom site images here",
        "│       ├── bad_site_001.jpg",
        "│       ├── bad_site_002.jpg",
        "│       └── ...",
        "└── val/",
        "    ├── good/          # Validation good images",
        "    │   ├── val_good_001.jpg",
        "    │   └── ...",
        "    └── bad/           # Validation bad images",
        "        ├── val_bad_001.jpg",
        "        └── ...",
        "📋 Data Requirements:",
        "- Minimum 50 images per class for training",
        "- 20% of data should be reserved for validation",
        "- Images should be clear and well-lit",
        "- Recommended resolution: 224x224 or higher",
        "- Supported formats: JPG, PNG, JPEG, BMP, TIFF",
        "👍 Good Site Criteria:",
        "- Proper cable assembly and routing",
        "- All cards correctly installed and labeled",
        "- Clean and organized equipment layout",
        "- Proper grounding and safety measures",
        "- Clear and readable labels",
        "👎 Bad Site Criteria:",
        "- Messy or improper cable routing",
        "- Missing or incorrectly installed cards",
        "- Poor equipment organization",
        "- Missing or unreadable labels",
        "- Safety issues or violations",
        "",
    ]
    instructions = "\n".join(lines)
    print(instructions)
    return instructions