"""
Example usage of dataset loaders.
This script demonstrates how to use both Fashion-MNIST and VITON-HD datasets
with their respective adapters.
"""
import numpy as np
try:
from torchvision import transforms
TORCHVISION_AVAILABLE = True
except ImportError:
TORCHVISION_AVAILABLE = False
transforms = None
from tryon.datasets import (
FashionMNIST,
load_fashion_mnist,
get_fashion_mnist_class_name,
get_fashion_mnist_class_names,
VITONHD,
)
# ============================================================================
# FASHION-MNIST EXAMPLES
# ============================================================================
def fashion_mnist_class_based():
    """Demonstrate the object-oriented Fashion-MNIST interface.

    Creates a ``FashionMNIST`` instance with ``download=True``, prints the
    metadata it reports before and after ``load()``, lists the class names
    and describes one randomly picked training sample. All output goes to
    stdout; nothing is returned.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("FASHION-MNIST: CLASS-BASED APPROACH")
    print(rule)

    fmnist = FashionMNIST(download=True)
    print(f"\nDataset: {fmnist}")

    # Metadata is queried once before any image data has been loaded.
    meta = fmnist.get_info()
    print("\nDataset info:")
    pre_load_fields = (
        ("Name", "name"),
        ("Classes", "num_classes"),
        ("Image shape", "image_shape"),
        ("Train size", "train_size"),
        ("Test size", "test_size"),
    )
    for caption, key in pre_load_fields:
        print(f" {caption}: {meta[key]}")

    # load() hands back a (train, test) pair of (images, labels) tuples.
    train_split, test_split = fmnist.load(normalize=True, flatten=False)
    train_images, train_labels = train_split
    test_images, test_labels = test_split
    print("\nDataset loaded!")
    print(f" Training set: {train_images.shape} images, {train_labels.shape} labels")
    print(f" Test set: {test_images.shape} images, {test_labels.shape} labels")

    # Query the metadata again; these two keys are read only after load().
    meta = fmnist.get_info()
    print("\nAfter loading:")
    for caption, key in (("Normalized", "normalized"), ("Flattened", "flattened")):
        print(f" {caption}: {meta[key]}")

    print("\nClass names:")
    for class_id, class_name in enumerate(fmnist.get_class_names()):
        print(f" {class_id}: {class_name}")

    # Describe one uniformly drawn training example.
    pick = np.random.randint(low=0, high=len(train_images))
    picked_label = train_labels[pick]
    print("\nRandom sample:")
    print(f" Index: {pick}")
    print(f" Label: {picked_label} ({fmnist.get_class_name(picked_label)})")
    print(f" Image shape: {train_images[pick].shape}")
def fashion_mnist_function_based():
    """Demonstrate the module-level Fashion-MNIST helper functions.

    Loads the data through ``load_fashion_mnist`` and prints array shapes,
    dtype and value range, the class names, the per-label sample counts of
    the training set, and statistics for one random training image.
    All output goes to stdout; nothing is returned.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("FASHION-MNIST: FUNCTION-BASED APPROACH")
    print(rule)

    # The convenience loader returns ((train_x, train_y), (test_x, test_y)).
    train_split, test_split = load_fashion_mnist(
        download=True,
        normalize=True,
        flatten=False,
    )
    train_images, train_labels = train_split
    test_images, test_labels = test_split

    print("\nDataset loaded!")
    print(f" Training set: {train_images.shape} images, {train_labels.shape} labels")
    print(f" Test set: {test_images.shape} images, {test_labels.shape} labels")
    print(f" Image dtype: {train_images.dtype}")
    print(f" Range: [{train_images.min():.2f}, {train_images.max():.2f}]")

    print("\nClass names:")
    for class_id, class_name in enumerate(get_fashion_mnist_class_names()):
        print(f" {class_id}: {class_name}")

    # np.unique with return_counts yields parallel arrays: labels, frequencies.
    print("\nLabel distribution (training set):")
    for label, count in zip(*np.unique(train_labels, return_counts=True)):
        print(f" {label} ({get_fashion_mnist_class_name(label)}): {count} samples")

    # Describe one uniformly drawn training example.
    pick = np.random.randint(low=0, high=len(train_images))
    picked_label = train_labels[pick]
    print("\nRandom sample:")
    print(f" Index: {pick}")
    print(f" Label: {picked_label} ({get_fashion_mnist_class_name(picked_label)})")
    print(f" Image shape: {train_images[pick].shape}")
    print(
        f" Image stats: min={train_images[pick].min():.2f}, "
        f"max={train_images[pick].max():.2f}, "
        f"mean={train_images[pick].mean():.2f}"
    )
# ============================================================================
# VITON-HD EXAMPLES
# ============================================================================
def viton_hd_dataloader():
    """Demonstrate batched VITON-HD access through ``get_dataloader``.

    Prints the section header, then returns early with an install hint if
    torchvision could not be imported. Otherwise it builds a ``VITONHD``
    dataset rooted at ``./datasets/viton_hd``, requests a training loader
    with a resize/to-tensor/normalize pipeline and reports the shapes of
    the first two batches. All output goes to stdout; nothing is returned.
    """
    rule = "=" * 60
    # The header is the same whether or not torchvision is present.
    print("\n" + rule)
    print("VITON-HD: PYTORCH DATALOADER APPROACH (Recommended)")
    print(rule)

    if not TORCHVISION_AVAILABLE:
        print("\n⚠ torchvision is not installed. Install it with: pip install torchvision")
        return

    viton = VITONHD(
        data_dir="./datasets/viton_hd",  # Update with your path
        download=False,
    )
    print(f"\nDataset: {viton}")
    print(f"Info: {viton.get_info()}")

    # Preprocessing: shrink to 512x384, convert to tensor, then normalize
    # each channel with mean 0.5 / std 0.5.
    pipeline_steps = [
        transforms.Resize((512, 384)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
    preprocessing = transforms.Compose(pipeline_steps)

    loader = viton.get_dataloader(
        split='train',
        batch_size=4,
        shuffle=True,
        num_workers=2,
        transform=preprocessing,
        return_numpy=False,  # Ask for PyTorch tensors
    )
    print("\nTrain DataLoader created:")
    print(f" Batch size: {loader.batch_size}")
    print(f" Number of batches: {len(loader)}")

    # Only the first two batches are reported; the loop stops on the third.
    print("\nIterating through first 2 batches:")
    for batch_no, batch in enumerate(loader, start=1):
        if batch_no > 2:
            break
        people = batch['person']
        garments = batch['clothing']
        print(f"\n Batch {batch_no}:")
        print(f" Person images shape: {people.shape}")
        print(f" Clothing images shape: {garments.shape}")
        # Show at most the first two paths of the batch.
        print(f" Person paths: {batch['person_path'][:2]}")
def viton_hd_single_sample():
    """Demonstrate fetching one VITON-HD sample by index.

    Builds a ``VITONHD`` dataset rooted at ``./datasets/viton_hd``, asks for
    training sample 0 with ``return_numpy=True`` and prints the image shapes,
    file paths and index found in the returned mapping. All output goes to
    stdout; nothing is returned.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("VITON-HD: SINGLE SAMPLE APPROACH")
    print(rule)

    viton = VITONHD(
        data_dir="./datasets/viton_hd",  # Update with your path
        download=False,
    )

    first = viton.get_sample(index=0, split='train', return_numpy=True)

    print("\nSample 0:")
    print(f" Person image shape: {first['person'].shape}")
    print(f" Clothing image shape: {first['clothing'].shape}")
    # The remaining entries are printed verbatim under their captions.
    for caption, key in (
        ("Person path", "person_path"),
        ("Clothing path", "clothing_path"),
        ("Index", "index"),
    ):
        print(f" {caption}: {first[key]}")
def viton_hd_load_to_memory():
    """Demonstrate loading a capped number of VITON-HD samples via ``load``.

    Builds a ``VITONHD`` dataset rooted at ``./datasets/viton_hd``, calls
    ``load`` for the training split with ``max_samples=10`` and
    ``normalize=True``, and prints shape, dtype and value range of the
    result. All output goes to stdout; nothing is returned.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("VITON-HD: LOAD TO MEMORY APPROACH (Use with caution)")
    print(rule)

    viton = VITONHD(
        data_dir="./datasets/viton_hd",  # Update with your path
        download=False,
    )

    # Cap the request at 10 samples to keep memory use small; the second
    # element of the returned pair is not needed here.
    loaded_pair, _ = viton.load(split='train', max_samples=10, normalize=True)
    people, garments = loaded_pair

    print("\nLoaded samples:")
    print(f" Person images shape: {people.shape}")
    print(f" Clothing images shape: {garments.shape}")
    print(f" Data type: {people.dtype}")
    print(f" Value range: [{people.min():.2f}, {people.max():.2f}]")
def viton_hd_custom_transform():
    """Demonstrate a VITON-HD loader with a custom augmentation pipeline.

    Prints the section header, then returns early with an install hint if
    torchvision could not be imported. Otherwise it builds a ``VITONHD``
    dataset rooted at ``./datasets/viton_hd``, composes resize, random
    horizontal flip, colour jitter and to-tensor transforms, and prints the
    shape of the first ``'person'`` batch. All output goes to stdout;
    nothing is returned.
    """
    rule = "=" * 60
    # The header is the same whether or not torchvision is present.
    print("\n" + rule)
    print("VITON-HD: CUSTOM TRANSFORMS APPROACH")
    print(rule)

    if not TORCHVISION_AVAILABLE:
        print("\n⚠ torchvision is not installed. Install it with: pip install torchvision")
        return

    viton = VITONHD(
        data_dir="./datasets/viton_hd",  # Update with your path
        download=False,
    )

    # Augmentations are applied in list order, ending with tensor conversion.
    augmentation_steps = [
        transforms.Resize((256, 256)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.ToTensor(),
    ]
    augmentations = transforms.Compose(augmentation_steps)

    loader = viton.get_dataloader(
        split='train',
        batch_size=2,
        shuffle=True,
        transform=augmentations,
    )
    print("\nCustom transform DataLoader created")
    print(" Applied transforms: Resize, RandomHorizontalFlip, ColorJitter, ToTensor")

    # Pull exactly one batch from a fresh iterator.
    first_batch = next(iter(loader))
    print(f"\n Batch shape: {first_batch['person'].shape}")
# ============================================================================
# MAIN FUNCTION
# ============================================================================
def main():
    """Print an overview of the example functions in this script.

    Writes a description of both datasets, the directory layout expected
    for VITON-HD and a closing hint to stdout. The example calls themselves
    are left commented out so the reader can enable the ones they want.
    Nothing is returned.
    """
    heavy = "=" * 60
    light = "-" * 60

    def banner(title, rule, lead="\n"):
        # Three-line heading: rule, title, rule (optionally preceded by a gap).
        print(lead + rule)
        print(title)
        print(rule)

    banner("DATASET USAGE EXAMPLES", heavy, lead="")
    print("\nThis script demonstrates usage of both Fashion-MNIST and VITON-HD datasets.")

    banner("FASHION-MNIST", light)
    print("Small dataset (60MB) - loads entirely into memory")
    print("Use cases: Classification, quick prototyping")

    banner("VITON-HD", light)
    print("Large dataset (4.6GB) - uses lazy loading via PyTorch DataLoader")
    print("Use cases: Virtual try-on, high-resolution image generation")

    print("\nNote: Update 'data_dir' path for VITON-HD to point to your dataset")
    print("VITON-HD dataset structure should be:")
    for layout_line in (
        " data_dir/",
        " person/",
        " clothing/",
        " train_pairs.txt",
        " test_pairs.txt",
    ):
        print(layout_line)

    banner("FASHION-MNIST EXAMPLES", heavy)
    # Uncomment the Fashion-MNIST examples you want to run:
    # fashion_mnist_class_based()
    # fashion_mnist_function_based()

    banner("VITON-HD EXAMPLES", heavy)
    # Uncomment the VITON-HD examples you want to run:
    # viton_hd_dataloader()
    # viton_hd_single_sample()
    # viton_hd_load_to_memory()
    # viton_hd_custom_transform()

    banner("Examples ready! Uncomment the functions you want to run.", heavy)
# Script entry point: print the overview only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()
|