""" Example usage of dataset loaders. This script demonstrates how to use both Fashion-MNIST and VITON-HD datasets with their respective adapters. """ import numpy as np try: from torchvision import transforms TORCHVISION_AVAILABLE = True except ImportError: TORCHVISION_AVAILABLE = False transforms = None from tryon.datasets import ( FashionMNIST, load_fashion_mnist, get_fashion_mnist_class_name, get_fashion_mnist_class_names, VITONHD, ) # ============================================================================ # FASHION-MNIST EXAMPLES # ============================================================================ def fashion_mnist_class_based(): """Example using the class-based approach (recommended for extensibility).""" print("\n" + "=" * 60) print("FASHION-MNIST: CLASS-BASED APPROACH") print("=" * 60) # Create dataset instance dataset = FashionMNIST(download=True) print(f"\nDataset: {dataset}") # Get dataset info before loading info = dataset.get_info() print(f"\nDataset info:") print(f" Name: {info['name']}") print(f" Classes: {info['num_classes']}") print(f" Image shape: {info['image_shape']}") print(f" Train size: {info['train_size']}") print(f" Test size: {info['test_size']}") # Load the dataset (train_images, train_labels), (test_images, test_labels) = dataset.load( normalize=True, flatten=False ) print(f"\nDataset loaded!") print(f" Training set: {train_images.shape} images, {train_labels.shape} labels") print(f" Test set: {test_images.shape} images, {test_labels.shape} labels") # Get updated info after loading info = dataset.get_info() print(f"\nAfter loading:") print(f" Normalized: {info['normalized']}") print(f" Flattened: {info['flattened']}") # Use class methods print(f"\nClass names:") for i, name in enumerate(dataset.get_class_names()): print(f" {i}: {name}") # Get a random sample idx = np.random.randint(0, len(train_images)) print(f"\nRandom sample:") print(f" Index: {idx}") print(f" Label: {train_labels[idx]} ({dataset.get_class_name(train_labels[idx])})") print(f" Image shape: {train_images[idx].shape}") def fashion_mnist_function_based(): """Example using the function-based approach (simpler, backward compatible).""" print("\n" + "=" * 60) print("FASHION-MNIST: FUNCTION-BASED APPROACH") print("=" * 60) # Load the dataset using convenience function (train_images, train_labels), (test_images, test_labels) = load_fashion_mnist( download=True, normalize=True, flatten=False ) print(f"\nDataset loaded!") print(f" Training set: {train_images.shape} images, {train_labels.shape} labels") print(f" Test set: {test_images.shape} images, {test_labels.shape} labels") print(f" Image dtype: {train_images.dtype}") print(f" Range: [{train_images.min():.2f}, {train_images.max():.2f}]") # Display class names print(f"\nClass names:") for i, name in enumerate(get_fashion_mnist_class_names()): print(f" {i}: {name}") # Show label distribution print(f"\nLabel distribution (training set):") unique, counts = np.unique(train_labels, return_counts=True) for label, count in zip(unique, counts): print(f" {label} ({get_fashion_mnist_class_name(label)}): {count} samples") # Get a random sample idx = np.random.randint(0, len(train_images)) print(f"\nRandom sample:") print(f" Index: {idx}") print(f" Label: {train_labels[idx]} ({get_fashion_mnist_class_name(train_labels[idx])})") print(f" Image shape: {train_images[idx].shape}") print(f" Image stats: min={train_images[idx].min():.2f}, max={train_images[idx].max():.2f}, mean={train_images[idx].mean():.2f}") # ============================================================================ # VITON-HD EXAMPLES # ============================================================================ def viton_hd_dataloader(): """Example using PyTorch DataLoader (recommended for large datasets).""" if not TORCHVISION_AVAILABLE: print("\n" + "=" * 60) print("VITON-HD: PYTORCH DATALOADER APPROACH (Recommended)") print("=" * 60) print("\n⚠ torchvision is not installed. Install it with: pip install torchvision") return print("\n" + "=" * 60) print("VITON-HD: PYTORCH DATALOADER APPROACH (Recommended)") print("=" * 60) # Create dataset instance dataset = VITONHD( data_dir="./datasets/viton_hd", # Update with your path download=False ) print(f"\nDataset: {dataset}") print(f"Info: {dataset.get_info()}") # Define transforms transform = transforms.Compose([ transforms.Resize((512, 384)), # Resize for faster processing transforms.ToTensor(), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) ]) # Get DataLoader for training train_loader = dataset.get_dataloader( split='train', batch_size=4, shuffle=True, num_workers=2, transform=transform, return_numpy=False # Return PyTorch tensors ) print(f"\nTrain DataLoader created:") print(f" Batch size: {train_loader.batch_size}") print(f" Number of batches: {len(train_loader)}") # Iterate through batches (lazy loading - images loaded on-demand) print(f"\nIterating through first 2 batches:") for batch_idx, batch in enumerate(train_loader): if batch_idx >= 2: break person_imgs = batch['person'] clothing_imgs = batch['clothing'] print(f"\n Batch {batch_idx + 1}:") print(f" Person images shape: {person_imgs.shape}") print(f" Clothing images shape: {clothing_imgs.shape}") print(f" Person paths: {batch['person_path'][:2]}") # Show first 2 paths def viton_hd_single_sample(): """Example loading a single sample.""" print("\n" + "=" * 60) print("VITON-HD: SINGLE SAMPLE APPROACH") print("=" * 60) dataset = VITONHD( data_dir="./datasets/viton_hd", # Update with your path download=False ) # Get a single sample sample = dataset.get_sample( index=0, split='train', return_numpy=True ) print(f"\nSample 0:") print(f" Person image shape: {sample['person'].shape}") print(f" Clothing image shape: {sample['clothing'].shape}") print(f" Person path: {sample['person_path']}") print(f" Clothing path: {sample['clothing_path']}") print(f" Index: {sample['index']}") def viton_hd_load_to_memory(): """Example loading samples into memory (use with caution for large datasets).""" print("\n" + "=" * 60) print("VITON-HD: LOAD TO MEMORY APPROACH (Use with caution)") print("=" * 60) dataset = VITONHD( data_dir="./datasets/viton_hd", # Update with your path download=False ) # Load only first 10 samples (to avoid memory issues) (person_imgs, clothing_imgs), _ = dataset.load( split='train', max_samples=10, # Limit to 10 samples normalize=True ) print(f"\nLoaded samples:") print(f" Person images shape: {person_imgs.shape}") print(f" Clothing images shape: {clothing_imgs.shape}") print(f" Data type: {person_imgs.dtype}") print(f" Value range: [{person_imgs.min():.2f}, {person_imgs.max():.2f}]") def viton_hd_custom_transform(): """Example with custom transforms.""" if not TORCHVISION_AVAILABLE: print("\n" + "=" * 60) print("VITON-HD: CUSTOM TRANSFORMS APPROACH") print("=" * 60) print("\n⚠ torchvision is not installed. Install it with: pip install torchvision") return print("\n" + "=" * 60) print("VITON-HD: CUSTOM TRANSFORMS APPROACH") print("=" * 60) dataset = VITONHD( data_dir="./datasets/viton_hd", # Update with your path download=False ) # Custom transform pipeline custom_transform = transforms.Compose([ transforms.Resize((256, 256)), transforms.RandomHorizontalFlip(p=0.5), transforms.ColorJitter(brightness=0.2, contrast=0.2), transforms.ToTensor(), ]) train_loader = dataset.get_dataloader( split='train', batch_size=2, shuffle=True, transform=custom_transform ) print(f"\nCustom transform DataLoader created") print(f" Applied transforms: Resize, RandomHorizontalFlip, ColorJitter, ToTensor") # Get one batch batch = next(iter(train_loader)) print(f"\n Batch shape: {batch['person'].shape}") # ============================================================================ # MAIN FUNCTION # ============================================================================ def main(): """Run all examples.""" print("=" * 60) print("DATASET USAGE EXAMPLES") print("=" * 60) print("\nThis script demonstrates usage of both Fashion-MNIST and VITON-HD datasets.") print("\n" + "-" * 60) print("FASHION-MNIST") print("-" * 60) print("Small dataset (60MB) - loads entirely into memory") print("Use cases: Classification, quick prototyping") print("\n" + "-" * 60) print("VITON-HD") print("-" * 60) print("Large dataset (4.6GB) - uses lazy loading via PyTorch DataLoader") print("Use cases: Virtual try-on, high-resolution image generation") print("\nNote: Update 'data_dir' path for VITON-HD to point to your dataset") print("VITON-HD dataset structure should be:") print(" data_dir/") print(" person/") print(" clothing/") print(" train_pairs.txt") print(" test_pairs.txt") print("\n" + "=" * 60) print("FASHION-MNIST EXAMPLES") print("=" * 60) # Uncomment the Fashion-MNIST examples you want to run: # fashion_mnist_class_based() # fashion_mnist_function_based() print("\n" + "=" * 60) print("VITON-HD EXAMPLES") print("=" * 60) # Uncomment the VITON-HD examples you want to run: # viton_hd_dataloader() # viton_hd_single_sample() # viton_hd_load_to_memory() # viton_hd_custom_transform() print("\n" + "=" * 60) print("Examples ready! Uncomment the functions you want to run.") print("=" * 60) if __name__ == '__main__': main()