"""
Dataset Preparation Tool

Prepare and preprocess image-text datasets for training.
"""

import argparse
import json
import random
import shutil
from pathlib import Path
from typing import List, Optional, Tuple

from PIL import Image

# File extensions treated as input images (compared case-insensitively).
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.webp')


def _find_image_files(input_path: Path, exclude_dir: Optional[Path] = None) -> List[Path]:
    """Return each image file under *input_path* (recursively) exactly once, sorted.

    ``rglob('*')`` already includes files located directly in *input_path*, so a
    single pass suffices; combining a top-level glob with a ``**`` glob would
    yield every top-level file twice. Extensions are matched case-insensitively.
    Files inside *exclude_dir* are skipped so that previously written output is
    not re-ingested when the output directory lives under the input directory.
    """
    root = input_path.resolve()
    excluded = exclude_dir.resolve() if exclude_dir is not None else None
    # Never exclude the entire input tree (input_dir equal to / inside exclude_dir).
    if excluded is not None and (excluded == root or excluded in root.parents):
        excluded = None

    found = []
    for candidate in input_path.rglob('*'):
        if not candidate.is_file():
            continue
        if candidate.suffix.lower() not in IMAGE_EXTENSIONS:
            continue
        if excluded is not None and excluded in candidate.resolve().parents:
            continue
        found.append(candidate)
    # Sorted so processing order (and therefore name de-duplication) is stable.
    return sorted(found)


def _unique_stem(stem: str, used: set) -> str:
    """Return *stem*, suffixed with ``_N`` when needed, so it is not already in *used*.

    Comparison is case-insensitive so names differing only by case do not
    overwrite each other on case-insensitive file systems. The chosen name is
    recorded in *used*.
    """
    candidate = stem
    counter = 1
    while candidate.casefold() in used:
        candidate = f"{stem}_{counter}"
        counter += 1
    used.add(candidate.casefold())
    return candidate


def prepare_dataset(
    input_dir: str,
    output_dir: str,
    image_size: int = 512,
    min_resolution: int = 256,
    filter_low_quality: bool = True,
):
    """
    Prepare dataset for training.

    Every image found (recursively) under ``input_dir`` is center-cropped to a
    square, resized to ``image_size`` x ``image_size`` and saved as
    ``<output_dir>/images/<name>.jpg``; its caption is written to
    ``<output_dir>/captions/<name>.txt``. When two inputs share a file stem,
    later ones get a ``_N`` suffix instead of overwriting earlier output.
    A ``metadata.json`` summary is written to ``output_dir``.

    Args:
        input_dir: Directory with raw images (searched recursively)
        output_dir: Output directory for processed data
        image_size: Target image size (side length of the square output)
        min_resolution: Images with width or height below this are skipped
        filter_low_quality: Forwarded to ``process_image`` (which currently
            does not act on it)

    Raises:
        FileNotFoundError: If ``input_dir`` is not an existing directory.
    """
    input_path = Path(input_dir)
    output_path = Path(output_dir)
    if not input_path.is_dir():
        raise FileNotFoundError(f"Input directory not found: {input_path}")

    # Create output directories
    images_dir = output_path / "images"
    captions_dir = output_path / "captions"
    images_dir.mkdir(parents=True, exist_ok=True)
    captions_dir.mkdir(parents=True, exist_ok=True)

    # Find all images (each once; never our own previously written output)
    image_files = _find_image_files(input_path, exclude_dir=images_dir)
    print(f"Found {len(image_files)} images")

    # Process each image
    processed_count = 0
    skipped_count = 0
    used_stems = set()
    for img_file in image_files:
        stem = _unique_stem(img_file.stem, used_stems)
        try:
            process_image(
                img_path=img_file,
                output_img_path=images_dir / f"{stem}.jpg",
                caption_path=captions_dir / f"{stem}.txt",
                image_size=image_size,
                min_resolution=min_resolution,
                filter_low_quality=filter_low_quality,
            )
        except Exception as e:
            # Per-file boundary: one unreadable/too-small image must not abort the run.
            print(f"Error processing {img_file}: {e}")
            skipped_count += 1
            continue

        processed_count += 1
        if processed_count % 10 == 0:
            print(f"Processed: {processed_count}/{len(image_files)}")

    # Save metadata
    metadata = {
        'total_images': processed_count,
        'skipped_images': skipped_count,
        'image_size': image_size,
        'min_resolution': min_resolution,
    }
    with open(output_path / "metadata.json", 'w', encoding='utf-8') as f:
        json.dump(metadata, f, indent=2)

    print("\n✓ Dataset preparation complete!")
    print(f" Processed: {processed_count} images")
    print(f" Skipped: {skipped_count} images")
    print(f" Output: {output_path}")


def process_image(
    img_path: Path,
    output_img_path: Path,
    caption_path: Path,
    image_size: int = 512,
    min_resolution: int = 256,
    filter_low_quality: bool = True,
):
    """
    Process a single image: validate, center-crop, resize, save, and caption it.

    Args:
        img_path: Input image path
        output_img_path: Output image path (format inferred from its suffix)
        caption_path: Output caption path
        image_size: Target side length of the square output image
        min_resolution: Minimum acceptable width and height
        filter_low_quality: Currently unused; the ``min_resolution`` check is
            always applied. Kept for interface compatibility.

    Raises:
        ValueError: If the image's width or height is below ``min_resolution``.
        OSError: If Pillow cannot read or write the image.
    """
    # Load image; the context manager releases the file handle, which would
    # otherwise stay open until garbage collection.
    with Image.open(img_path) as src:
        image = src.convert('RGB')

    # Check resolution
    width, height = image.size
    if width < min_resolution or height < min_resolution:
        raise ValueError(f"Image too small: {width}x{height}")

    # Downscale large images so the *shorter* side equals image_size. The
    # square center crop below then has exactly image_size pixels per side and
    # no upsampling is needed (scaling by the longer side would shrink the
    # crop below image_size and force a lossy upscale afterwards).
    if min(width, height) > image_size * 1.5:
        scale = image_size / min(width, height)
        new_size = (
            max(image_size, round(width * scale)),
            max(image_size, round(height * scale)),
        )
        image = image.resize(new_size, Image.Resampling.LANCZOS)

    # Center crop to square
    size = min(image.size)
    left = (image.size[0] - size) // 2
    top = (image.size[1] - size) // 2
    image = image.crop((left, top, left + size, top + size))

    # Resize to target size
    image = image.resize((image_size, image_size), Image.Resampling.LANCZOS)

    # Save processed image
    image.save(output_img_path, quality=95, optimize=True)

    # Generate or load caption
    caption = generate_caption(img_path)
    with open(caption_path, 'w', encoding='utf-8') as f:
        f.write(caption)


def generate_caption(img_path: Path) -> str:
    """
    Return a caption for an image.

    If a sibling ``<stem>.txt`` file exists and is non-empty after stripping
    whitespace, its contents are used. Otherwise the caption is derived from
    the file name: underscores and hyphens become spaces, runs of whitespace
    collapse to one space, and the first character is upper-cased.

    Args:
        img_path: Path to image

    Returns:
        Caption text
    """
    # Try to load from adjacent .txt file
    txt_file = img_path.with_suffix('.txt')
    if txt_file.is_file():
        caption = txt_file.read_text(encoding='utf-8').strip()
        if caption:
            return caption

    # Use filename as fallback
    caption = " ".join(img_path.stem.replace('_', ' ').replace('-', ' ').split())
    # Upper-case only the first character; str.capitalize() would also
    # lower-case the rest (e.g. "NASA_launch" -> "Nasa launch").
    return caption[:1].upper() + caption[1:]


def create_training_splits(
    data_dir: str,
    train_ratio: float = 0.9,
    val_ratio: float = 0.05,
    test_ratio: float = 0.05,
):
    """
    Create train/val/test splits.

    Lists ``<data_dir>/images/*.jpg``, sorts them (so results do not depend on
    file-system listing order), shuffles them with a private RNG seeded with 42,
    and writes ``train.json``, ``validation.json`` and ``test.json`` to
    ``data_dir``. Each file holds the image file names and their count.

    Args:
        data_dir: Directory with processed data
        train_ratio: Training set ratio
        val_ratio: Validation set ratio
        test_ratio: Test set ratio. The test split receives every image not
            assigned to train/validation, so this value is only validated.

    Raises:
        ValueError: If any ratio is negative or the ratios sum to more than 1.
    """
    ratios = (train_ratio, val_ratio, test_ratio)
    if any(r < 0 for r in ratios) or sum(ratios) > 1 + 1e-6:
        raise ValueError(
            f"Split ratios must be non-negative and sum to at most 1, got {ratios}"
        )

    data_path = Path(data_dir)

    # Get all images in a stable order before shuffling
    images = sorted((data_path / "images").glob("*.jpg"))

    # Shuffle deterministically without touching the global random state
    random.Random(42).shuffle(images)

    # Calculate split sizes
    total = len(images)
    train_size = int(total * train_ratio)
    val_size = int(total * val_ratio)

    # Split datasets
    train_images = images[:train_size]
    val_images = images[train_size:train_size + val_size]
    test_images = images[train_size + val_size:]

    # Save splits
    def save_split(image_list, split_name):
        split_data = {
            'images': [img.name for img in image_list],
            'count': len(image_list),
        }
        with open(data_path / f"{split_name}.json", 'w', encoding='utf-8') as f:
            json.dump(split_data, f, indent=2)
        print(f"{split_name}: {len(image_list)} images")

    save_split(train_images, "train")
    save_split(val_images, "validation")
    save_split(test_images, "test")

    print("\n✓ Created training splits")
    print(f" Total: {total} images")


def main():
    """Parse command-line arguments, prepare the dataset, and optionally split it."""
    parser = argparse.ArgumentParser(description="Prepare dataset for Byte Dream training")
    parser.add_argument(
        "--input", "-i",
        type=str,
        required=True,
        help="Input directory with raw images"
    )
    parser.add_argument(
        "--output", "-o",
        type=str,
        default="./processed_data",
        help="Output directory for processed data"
    )
    parser.add_argument(
        "--size", "-s",
        type=int,
        default=512,
        help="Target image size (default: 512)"
    )
    parser.add_argument(
        "--min_res",
        type=int,
        default=256,
        help="Minimum image resolution (default: 256)"
    )
    # NOTE: forwarded as filter_low_quality, which process_image does not
    # currently act on.
    parser.add_argument(
        "--no_filter",
        action="store_true",
        help="Disable low quality filtering"
    )
    parser.add_argument(
        "--create_splits",
        action="store_true",
        help="Create train/val/test splits"
    )

    args = parser.parse_args()

    # Prepare dataset
    prepare_dataset(
        input_dir=args.input,
        output_dir=args.output,
        image_size=args.size,
        min_resolution=args.min_res,
        filter_low_quality=not args.no_filter,
    )

    # Create splits if requested
    if args.create_splits:
        create_training_splits(args.output)


if __name__ == "__main__":
    main()