| """
|
| Dataset Preparation Tool
|
| Prepare and preprocess image-text datasets for training
|
| """
|
|
|
| import argparse
|
| from pathlib import Path
|
| from PIL import Image
|
| import json
|
| import shutil
|
| from typing import List, Tuple
|
|
|
|
|
def prepare_dataset(
    input_dir: str,
    output_dir: str,
    image_size: int = 512,
    min_resolution: int = 256,
    filter_low_quality: bool = True,
):
    """Prepare an image-text dataset for training.

    Recursively collects images under ``input_dir``, processes each one via
    :func:`process_image` (resize / center-crop / caption), and writes the
    results plus a ``metadata.json`` summary into ``output_dir``.

    Args:
        input_dir: Directory with raw images (searched recursively).
        output_dir: Output directory for processed data.
        image_size: Target (square) image size in pixels.
        min_resolution: Minimum acceptable width/height; smaller images are skipped.
        filter_low_quality: Forwarded to process_image (currently reserved there).
    """
    input_path = Path(input_dir)
    output_path = Path(output_dir)

    output_path.mkdir(parents=True, exist_ok=True)
    (output_path / "images").mkdir(exist_ok=True)
    (output_path / "captions").mkdir(exist_ok=True)

    # BUG FIX: the old code globbed both "*{ext}" and "**/*{ext}"; pathlib's
    # "**" also matches the root directory, so every top-level image was
    # found (and processed) twice. A single recursive walk with a
    # case-insensitive suffix check finds each file exactly once, also picks
    # up upper-case extensions (.JPG, .PNG, ...), and sorting makes the
    # processing order deterministic.
    image_extensions = {'.jpg', '.jpeg', '.png', '.webp'}
    image_files = sorted(
        p for p in input_path.rglob("*")
        if p.is_file() and p.suffix.lower() in image_extensions
    )

    print(f"Found {len(image_files)} images")

    processed_count = 0
    skipped_count = 0

    for img_file in image_files:
        try:
            # NOTE(review): outputs are keyed by stem only, so two source
            # images with the same stem in different subdirectories will
            # overwrite each other -- confirm this is acceptable upstream.
            process_image(
                img_path=img_file,
                output_img_path=output_path / "images" / f"{img_file.stem}.jpg",
                caption_path=output_path / "captions" / f"{img_file.stem}.txt",
                image_size=image_size,
                min_resolution=min_resolution,
                filter_low_quality=filter_low_quality,
            )
            processed_count += 1

            if processed_count % 10 == 0:
                print(f"Processed: {processed_count}/{len(image_files)}")

        except Exception as e:
            # Best-effort: one bad image must not abort the whole run.
            print(f"Error processing {img_file}: {e}")
            skipped_count += 1

    metadata = {
        'total_images': processed_count,
        'skipped_images': skipped_count,
        'image_size': image_size,
        'min_resolution': min_resolution,
    }

    with open(output_path / "metadata.json", 'w') as f:
        json.dump(metadata, f, indent=2)

    print(f"\n✓ Dataset preparation complete!")
    print(f"  Processed: {processed_count} images")
    print(f"  Skipped: {skipped_count} images")
    print(f"  Output: {output_path}")
|
|
|
|
|
def process_image(
    img_path: Path,
    output_img_path: Path,
    caption_path: Path,
    image_size: int = 512,
    min_resolution: int = 256,
    filter_low_quality: bool = True,
):
    """Process a single image: validate, resize, center-crop, save, caption.

    Args:
        img_path: Input image path.
        output_img_path: Where the processed JPEG is written.
        caption_path: Where the caption text file is written.
        image_size: Final (square) output size in pixels.
        min_resolution: Minimum acceptable width/height.
        filter_low_quality: Reserved -- no quality filtering beyond the
            resolution check is implemented yet (TODO).

    Raises:
        ValueError: If the image is smaller than ``min_resolution``.
    """
    image = Image.open(img_path).convert('RGB')

    width, height = image.size

    if width < min_resolution or height < min_resolution:
        raise ValueError(f"Image too small: {width}x{height}")

    # Pre-shrink very large images so the center-crop + final LANCZOS resize
    # stay cheap.
    # BUG FIX: the old code scaled by image_size / max(width, height), which
    # shrank the *short* side below image_size and forced a quality-losing
    # upscale at the final resize. Scaling by the short side keeps the
    # eventual square crop at exactly image_size pixels.
    if min(width, height) > image_size * 1.5:
        scale = image_size / min(width, height)
        image = image.resize(
            (int(width * scale), int(height * scale)),
            Image.Resampling.LANCZOS,
        )

    # Center-crop to a square whose side is the image's short side.
    size = min(image.size)
    left = (image.size[0] - size) // 2
    top = (image.size[1] - size) // 2
    image = image.crop((left, top, left + size, top + size))

    # Normalize to the exact target resolution and save as JPEG.
    image = image.resize((image_size, image_size), Image.Resampling.LANCZOS)
    image.save(output_img_path, quality=95, optimize=True)

    # Caption comes from a sidecar .txt file or the filename.
    caption = generate_caption(img_path)

    with open(caption_path, 'w', encoding='utf-8') as f:
        f.write(caption)
|
|
|
|
|
def generate_caption(img_path: Path) -> str:
    """Return a caption for *img_path*.

    A non-empty sidecar ``.txt`` file next to the image takes priority;
    otherwise a caption is derived from the filename (underscores and
    hyphens become spaces, first letter capitalized).

    Args:
        img_path: Path to image

    Returns:
        Caption text
    """
    sidecar = img_path.with_suffix('.txt')

    if sidecar.exists():
        with open(sidecar, 'r', encoding='utf-8') as fh:
            text = fh.read().strip()
        if text:
            return text

    # Fall back to a filename-derived caption.
    words = img_path.stem.replace('_', ' ').replace('-', ' ')
    return words.capitalize()
|
|
|
|
|
def create_training_splits(
    data_dir: str,
    train_ratio: float = 0.9,
    val_ratio: float = 0.05,
    test_ratio: float = 0.05,
):
    """Split the processed images into train/validation/test sets.

    Writes ``train.json``, ``validation.json`` and ``test.json`` into
    ``data_dir``, each listing image file names and a count.

    Args:
        data_dir: Directory with processed data (expects an ``images/`` subdir).
        train_ratio: Training set ratio.
        val_ratio: Validation set ratio.
        test_ratio: Test set ratio; the test split is taken as the remainder
            after train+val, so the three ratios must sum to 1.

    Raises:
        ValueError: If the ratios do not sum to 1.
    """
    # BUG FIX: test_ratio was silently ignored. The test split is still the
    # remainder, but inconsistent ratios now fail loudly instead of quietly
    # producing an unexpected split.
    total_ratio = train_ratio + val_ratio + test_ratio
    if abs(total_ratio - 1.0) > 1e-6:
        raise ValueError(f"Split ratios must sum to 1, got {total_ratio}")

    data_path = Path(data_dir)

    # BUG FIX: glob order is filesystem-dependent, so the seeded shuffle was
    # not actually reproducible across machines. Sorting first makes the
    # seed meaningful.
    images = sorted((data_path / "images").glob("*.jpg"))

    import random
    random.seed(42)
    random.shuffle(images)

    total = len(images)
    train_size = int(total * train_ratio)
    val_size = int(total * val_ratio)

    train_images = images[:train_size]
    val_images = images[train_size:train_size + val_size]
    test_images = images[train_size + val_size:]

    def save_split(image_list, split_name):
        # Persist one split as {"images": [...names...], "count": N}.
        split_data = {
            'images': [str(img.name) for img in image_list],
            'count': len(image_list),
        }

        with open(data_path / f"{split_name}.json", 'w') as f:
            json.dump(split_data, f, indent=2)

        print(f"{split_name}: {len(image_list)} images")

    save_split(train_images, "train")
    save_split(val_images, "validation")
    save_split(test_images, "test")

    print(f"\n✓ Created training splits")
    print(f"  Total: {total} images")
|
|
|
|
|
def main():
    """CLI entry point: parse arguments, prepare the dataset, optionally split."""
    arg_parser = argparse.ArgumentParser(
        description="Prepare dataset for Byte Dream training"
    )

    arg_parser.add_argument(
        "--input", "-i",
        type=str,
        required=True,
        help="Input directory with raw images",
    )
    arg_parser.add_argument(
        "--output", "-o",
        type=str,
        default="./processed_data",
        help="Output directory for processed data",
    )
    arg_parser.add_argument(
        "--size", "-s",
        type=int,
        default=512,
        help="Target image size (default: 512)",
    )
    arg_parser.add_argument(
        "--min_res",
        type=int,
        default=256,
        help="Minimum image resolution (default: 256)",
    )
    arg_parser.add_argument(
        "--no_filter",
        action="store_true",
        help="Disable low quality filtering",
    )
    arg_parser.add_argument(
        "--create_splits",
        action="store_true",
        help="Create train/val/test splits",
    )

    opts = arg_parser.parse_args()

    prepare_dataset(
        input_dir=opts.input,
        output_dir=opts.output,
        image_size=opts.size,
        min_resolution=opts.min_res,
        filter_low_quality=not opts.no_filter,
    )

    # Splitting is opt-in via --create_splits.
    if opts.create_splits:
        create_training_splits(opts.output)
|
|
|
|
|
# Allow running this module directly as a standalone CLI script.
if __name__ == "__main__":

    main()
|
|
|