"""Data loading and preprocessing utilities."""
from pathlib import Path
from typing import Optional, Tuple

import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

import config


class TabletDataset(Dataset):
    """Dataset for loading tablet images and, optionally, ground-truth masks."""

    def __init__(self, root_dir: Path, transform=None, mask_dir: Optional[Path] = None):
        """
        Args:
            root_dir: Directory containing ``*.png`` images (``str`` is also
                accepted and converted to ``Path``).
            transform: Optional callable applied to each RGB PIL image.
            mask_dir: Optional directory containing ground truth masks. A mask
                is paired with an image by identical file name.

        Raises:
            ValueError: If ``root_dir`` contains no ``*.png`` files.
        """
        self.root_dir = Path(root_dir)
        self.transform = transform
        self.mask_dir = Path(mask_dir) if mask_dir is not None else None

        # Sorted so that dataset indices are deterministic.
        self.image_paths = sorted(self.root_dir.glob("*.png"))
        if not self.image_paths:
            raise ValueError(f"No images found in {root_dir}")

    def __len__(self) -> int:
        return len(self.image_paths)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, str, Optional[torch.Tensor]]:
        """
        Returns:
            image: Image after ``self.transform`` (the raw RGB PIL image if no
                transform was given).
            image_path: Path to the image, as a string.
            mask: Binary float mask (values 0.0/1.0) if a matching mask file
                exists in ``mask_dir``, otherwise ``None``.
        """
        img_path = self.image_paths[idx]
        # Context manager releases the file handle; convert() returns a new,
        # fully loaded image that does not depend on the open file.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        mask = self._load_mask(img_path.name)

        if self.transform is not None:
            image = self.transform(image)

        return image, str(img_path), mask

    def _load_mask(self, name: str) -> Optional[torch.Tensor]:
        """Load, resize and binarize the mask named ``name``; ``None`` if absent."""
        if self.mask_dir is None:
            return None
        mask_path = self.mask_dir / name
        if not mask_path.exists():
            return None

        with Image.open(mask_path) as m:
            mask_img = m.convert("L")
        # NEAREST keeps label values intact; bilinear would blur the edges,
        # and the ">0" binarization below would turn that blur into a dilated mask.
        resize = transforms.Resize(
            config.IMAGE_SIZE, interpolation=transforms.InterpolationMode.NEAREST
        )
        mask_img = resize(mask_img)
        mask = torch.tensor(np.array(mask_img), dtype=torch.float32)
        return (mask > 0).float()  # Binarize


def get_transforms(is_train: bool = False):
    """Build the image preprocessing pipeline.

    Args:
        is_train: Currently ignored. It is kept for API compatibility, because
            no augmentation is used for unsupervised anomaly detection.

    Returns:
        A ``transforms.Compose`` that resizes to ``config.IMAGE_SIZE``, converts
        to a tensor and normalizes with ``config.MEAN`` / ``config.STD``.
    """
    transform_list = [
        transforms.Resize(config.IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=config.MEAN, std=config.STD),
    ]
    return transforms.Compose(transform_list)


def custom_collate(batch):
    """Collate ``(image, path, mask)`` samples where masks may be ``None``.

    Returns:
        images: Stacked image tensor.
        paths: List of image path strings.
        masks: ``None`` if no sample in the batch has a mask; otherwise a
            stacked tensor in which missing masks are all-zero tensors.
    """
    images = torch.stack([item[0] for item in batch])
    paths = [item[1] for item in batch]
    masks = [item[2] for item in batch]

    # Use the first available mask as the shape/dtype template. Indexing
    # masks[0] would fail whenever the first sample has no mask.
    template = next((m for m in masks if m is not None), None)
    if template is None:
        return images, paths, None

    stacked = torch.stack(
        [m if m is not None else torch.zeros_like(template) for m in masks]
    )
    return images, paths, stacked


def get_dataloader(data_dir: Path, batch_size: int = 32, shuffle: bool = False,
                   mask_dir: Optional[Path] = None) -> DataLoader:
    """Create a DataLoader over the ``*.png`` images in ``data_dir``.

    Batches are ``(images, paths, masks)`` as produced by ``custom_collate``.
    """
    transform = get_transforms()
    dataset = TabletDataset(data_dir, transform=transform, mask_dir=mask_dir)

    # Set num_workers to 0 for Windows compatibility
    num_workers = 0

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=False,  # Disable for CPU
        collate_fn=custom_collate,
    )


def denormalize_image(tensor: torch.Tensor) -> torch.Tensor:
    """Undo ``Normalize(config.MEAN, config.STD)`` for visualization.

    Works for ``[3, H, W]`` and, via broadcasting, ``[B, 3, H, W]`` tensors.
    Statistics are created on the input's device, so CUDA tensors work too.
    """
    mean = torch.tensor(config.MEAN, device=tensor.device).view(3, 1, 1)
    std = torch.tensor(config.STD, device=tensor.device).view(3, 1, 1)
    return tensor * std + mean


def load_single_image(image_path: str) -> Tuple[torch.Tensor, Image.Image]:
    """
    Load and preprocess a single image for inference

    Args:
        image_path: Path to the image

    Returns:
        preprocessed: Preprocessed tensor [1, 3, H, W]
        original: Original RGB PIL image
    """
    # The context manager closes the file; convert() yields an independent,
    # loaded copy that remains valid afterwards.
    with Image.open(image_path) as img:
        original = img.convert("RGB")
    transform = get_transforms()
    preprocessed = transform(original).unsqueeze(0)  # Add batch dimension
    return preprocessed, original