# ByteDream / prepare_dataset.py
# Uploaded via huggingface_hub (commit 80b58c8, verified) — Enzo8930302
"""
Dataset Preparation Tool
Prepare and preprocess image-text datasets for training
"""
import argparse
from pathlib import Path
from PIL import Image
import json
import shutil
from typing import List, Tuple
def prepare_dataset(
    input_dir: str,
    output_dir: str,
    image_size: int = 512,
    min_resolution: int = 256,
    filter_low_quality: bool = True,
):
    """
    Prepare an image-text dataset for training.

    Scans ``input_dir`` recursively for images, processes each one via
    ``process_image``, and writes the results plus a ``metadata.json``
    summary into ``output_dir``.

    Args:
        input_dir: Directory with raw images (searched recursively).
        output_dir: Output directory for processed data.
        image_size: Target (square) image size in pixels.
        min_resolution: Minimum acceptable resolution; smaller images are skipped.
        filter_low_quality: Forwarded to ``process_image`` (see its docs).
    """
    input_path = Path(input_dir)
    output_path = Path(output_dir)

    # Create output directories up front so per-image writes never fail
    # on a missing parent.
    output_path.mkdir(parents=True, exist_ok=True)
    (output_path / "images").mkdir(exist_ok=True)
    (output_path / "captions").mkdir(exist_ok=True)

    # Find all images. BUGFIX: the previous `*{ext}` + `**/*{ext}` glob pair
    # double-counted every top-level image, because `**/*` already matches
    # files in the root directory. A single recursive scan with a
    # case-insensitive suffix check (so `.JPG` etc. are now picked up too)
    # plus a set keeps each file exactly once; sorting makes the order stable.
    image_extensions = {'.jpg', '.jpeg', '.png', '.webp'}
    image_files = sorted(
        p for p in set(input_path.glob("**/*"))
        if p.is_file() and p.suffix.lower() in image_extensions
    )
    print(f"Found {len(image_files)} images")

    # Process each image; failures are logged and counted, never fatal.
    processed_count = 0
    skipped_count = 0
    for img_file in image_files:
        try:
            # NOTE(review): images sharing a stem in different subdirectories
            # overwrite each other in the flat output layout — TODO confirm
            # whether input trees can contain duplicate stems.
            process_image(
                img_path=img_file,
                output_img_path=output_path / "images" / f"{img_file.stem}.jpg",
                caption_path=output_path / "captions" / f"{img_file.stem}.txt",
                image_size=image_size,
                min_resolution=min_resolution,
                filter_low_quality=filter_low_quality,
            )
            processed_count += 1
            if processed_count % 10 == 0:
                print(f"Processed: {processed_count}/{len(image_files)}")
        except Exception as e:
            print(f"Error processing {img_file}: {e}")
            skipped_count += 1

    # Persist a run summary so downstream tooling can sanity-check the output.
    metadata = {
        'total_images': processed_count,
        'skipped_images': skipped_count,
        'image_size': image_size,
        'min_resolution': min_resolution,
    }
    with open(output_path / "metadata.json", 'w') as f:
        json.dump(metadata, f, indent=2)

    print(f"\n✓ Dataset preparation complete!")
    print(f"  Processed: {processed_count} images")
    print(f"  Skipped: {skipped_count} images")
    print(f"  Output: {output_path}")
def process_image(
    img_path: Path,
    output_img_path: Path,
    caption_path: Path,
    image_size: int = 512,
    min_resolution: int = 256,
    filter_low_quality: bool = True,
):
    """
    Process a single image: validate, resize, center-crop, save, and caption.

    Args:
        img_path: Input image path.
        output_img_path: Output image path (saved as JPEG, quality 95).
        caption_path: Output caption path (UTF-8 text).
        image_size: Target square size in pixels.
        min_resolution: Minimum acceptable width/height.
        filter_low_quality: Currently unused beyond the resolution check —
            kept for interface compatibility. TODO: implement a real
            quality heuristic (blur/entropy) or retire the flag.

    Raises:
        ValueError: If either dimension is below ``min_resolution``.
    """
    # Normalize to RGB so RGBA/palette inputs save cleanly as JPEG.
    image = Image.open(img_path).convert('RGB')

    width, height = image.size
    if width < min_resolution or height < min_resolution:
        raise ValueError(f"Image too small: {width}x{height}")

    # Downscale very large images before cropping. BUGFIX: the scale was
    # previously computed from max(width, height), which shrank the *short*
    # side below image_size and forced a quality-losing upscale after the
    # square crop. Scaling by the short side keeps the eventual crop at
    # exactly image_size. round() avoids an off-by-one from float truncation.
    if min(width, height) > image_size * 1.5:
        scale = image_size / min(width, height)
        new_width = round(width * scale)
        new_height = round(height * scale)
        image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)

    # Center crop to a square of the short side.
    size = min(image.size)
    left = (image.size[0] - size) // 2
    top = (image.size[1] - size) // 2
    image = image.crop((left, top, left + size, top + size))

    # Final resize to the exact target dimensions.
    image = image.resize((image_size, image_size), Image.Resampling.LANCZOS)

    image.save(output_img_path, quality=95, optimize=True)

    # Write the caption alongside the processed image.
    caption = generate_caption(img_path)
    with open(caption_path, 'w', encoding='utf-8') as f:
        f.write(caption)
def generate_caption(img_path: Path) -> str:
    """
    Produce a caption for *img_path*.

    Prefers the contents of a sidecar ``.txt`` file next to the image;
    falls back to a caption derived from the filename.

    Args:
        img_path: Path to the image.

    Returns:
        Caption text.
    """
    # A non-empty sidecar text file always wins over the filename heuristic.
    sidecar = img_path.with_suffix('.txt')
    if sidecar.exists():
        text = sidecar.read_text(encoding='utf-8').strip()
        if text:
            return text

    # Fallback: turn the file stem into a sentence-like caption.
    words = img_path.stem.replace('_', ' ').replace('-', ' ')
    return words.capitalize()
def create_training_splits(
    data_dir: str,
    train_ratio: float = 0.9,
    val_ratio: float = 0.05,
    test_ratio: float = 0.05,
):
    """
    Create deterministic train/val/test splits for a processed dataset.

    Writes ``train.json``, ``validation.json`` and ``test.json`` into
    ``data_dir``, each listing image file names and a count.

    Args:
        data_dir: Directory containing the processed ``images/`` folder.
        train_ratio: Fraction of images assigned to training.
        val_ratio: Fraction assigned to validation.
        test_ratio: Fraction assigned to test. The test split actually
            receives every image left after train/val, so this value is
            validated against the other two rather than used directly.

    Raises:
        ValueError: If the three ratios do not sum to 1 (within 1e-6).
    """
    import random

    # BUGFIX: test_ratio was previously accepted but never checked, so
    # inconsistent ratios (e.g. 0.9/0.2/0.2) silently produced a mis-sized
    # test split. Fail loudly instead.
    if abs(train_ratio + val_ratio + test_ratio - 1.0) > 1e-6:
        raise ValueError(
            f"Split ratios must sum to 1, got "
            f"{train_ratio + val_ratio + test_ratio}"
        )

    data_path = Path(data_dir)

    # Sort before shuffling: glob order is filesystem-dependent, so sorting
    # plus the fixed seed makes the split reproducible across machines.
    images = sorted((data_path / "images").glob("*.jpg"))
    random.seed(42)
    random.shuffle(images)

    # Partition: train and val get their integer shares; test absorbs the rest.
    total = len(images)
    train_size = int(total * train_ratio)
    val_size = int(total * val_ratio)

    train_images = images[:train_size]
    val_images = images[train_size:train_size + val_size]
    test_images = images[train_size + val_size:]

    def save_split(image_list: List[Path], split_name: str) -> None:
        # One JSON manifest per split: file names only, so paths stay relative.
        split_data = {
            'images': [str(img.name) for img in image_list],
            'count': len(image_list),
        }
        with open(data_path / f"{split_name}.json", 'w') as f:
            json.dump(split_data, f, indent=2)
        print(f"{split_name}: {len(image_list)} images")

    save_split(train_images, "train")
    save_split(val_images, "validation")
    save_split(test_images, "test")

    print(f"\n✓ Created training splits")
    print(f"  Total: {total} images")
def main():
    """Command-line entry point: parse arguments and run dataset preparation."""
    arg_parser = argparse.ArgumentParser(
        description="Prepare dataset for Byte Dream training"
    )
    arg_parser.add_argument("--input", "-i", type=str, required=True,
                            help="Input directory with raw images")
    arg_parser.add_argument("--output", "-o", type=str, default="./processed_data",
                            help="Output directory for processed data")
    arg_parser.add_argument("--size", "-s", type=int, default=512,
                            help="Target image size (default: 512)")
    arg_parser.add_argument("--min_res", type=int, default=256,
                            help="Minimum image resolution (default: 256)")
    arg_parser.add_argument("--no_filter", action="store_true",
                            help="Disable low quality filtering")
    arg_parser.add_argument("--create_splits", action="store_true",
                            help="Create train/val/test splits")
    opts = arg_parser.parse_args()

    # Run the main preparation pass.
    prepare_dataset(
        input_dir=opts.input,
        output_dir=opts.output,
        image_size=opts.size,
        min_resolution=opts.min_res,
        filter_low_quality=not opts.no_filter,
    )

    # Optionally partition the processed data into train/val/test.
    if opts.create_splits:
        create_training_splits(opts.output)


if __name__ == "__main__":
    main()