|
|
|
|
|
""" |
|
|
CompI Phase 1.E: Dataset Preparation for LoRA Fine-tuning |
|
|
|
|
|
This tool helps prepare your personal style dataset for LoRA training: |
|
|
- Organize and validate style images |
|
|
- Generate appropriate captions |
|
|
- Resize and format images for training |
|
|
- Create training/validation splits |
|
|
|
|
|
Usage: |
|
|
python src/generators/compi_phase1e_dataset_prep.py --help |
|
|
python src/generators/compi_phase1e_dataset_prep.py --input-dir my_style_images --style-name "my_art_style" |
|
|
""" |
|
|
|
|
|
import os |
|
|
import argparse |
|
|
import json |
|
|
import shutil |
|
|
from pathlib import Path |
|
|
from typing import List, Dict, Tuple |
|
|
import random |
|
|
|
|
|
from PIL import Image, ImageOps |
|
|
import pandas as pd |
|
|
|
|
|
|
|
|
|
|
|
# Default value of the --size option (edge length, in pixels, of the square output).
DEFAULT_IMAGE_SIZE = 512

# File extensions accepted during validation; compared against the lower-cased suffix.
SUPPORTED_FORMATS = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp'}

# prepare_dataset() prints a warning when fewer valid images than this are found.
MIN_IMAGES_RECOMMENDED = 10

# Default value of the --train-split option (fraction of images used for training).
TRAIN_SPLIT_RATIO = 0.8
|
|
|
|
|
|
|
|
|
|
|
def setup_args():
    """Build the argument parser for this tool and parse the command line.

    Returns:
        The ``argparse.Namespace`` produced by ``parse_args()``, carrying
        ``input_dir``, ``style_name``, ``output_dir``, ``trigger_word``,
        ``size``, ``caption_template``, ``train_split``, ``copy_images`` and
        ``validate_only``.
    """
    usage_examples = """
Examples:
  # Prepare dataset from a folder of images
  python %(prog)s --input-dir my_artwork --style-name "impressionist_style"

  # Custom output directory and image size
  python %(prog)s --input-dir paintings --style-name "oil_painting" --output-dir datasets/oil_style --size 768

  # Generate captions with custom trigger word
  python %(prog)s --input-dir sketches --style-name "pencil_sketch" --trigger-word "sketch_style"
"""

    parser = argparse.ArgumentParser(
        description="CompI Phase 1.E: Dataset Preparation for LoRA Fine-tuning",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=usage_examples,
    )

    # One (flag, keyword-arguments) pair per option, registered in this order
    # so that --help lists them in the same sequence.
    option_table = [
        ("--input-dir", {
            "required": True,
            "help": "Directory containing your style images",
        }),
        ("--style-name", {
            "required": True,
            "help": "Name for your style (used in file naming and captions)",
        }),
        ("--output-dir", {
            "help": "Output directory for prepared dataset (default: datasets/{style_name})",
        }),
        ("--trigger-word", {
            "help": "Trigger word for style (default: style_name)",
        }),
        ("--size", {
            "type": int,
            "default": DEFAULT_IMAGE_SIZE,
            "help": f"Target image size in pixels (default: {DEFAULT_IMAGE_SIZE})",
        }),
        ("--caption-template", {
            "default": "a painting in {trigger_word} style",
            "help": "Template for generating captions",
        }),
        ("--train-split", {
            "type": float,
            "default": TRAIN_SPLIT_RATIO,
            "help": f"Ratio for train/validation split (default: {TRAIN_SPLIT_RATIO})",
        }),
        ("--copy-images", {
            "action": "store_true",
            "help": "Copy images instead of creating symlinks",
        }),
        ("--validate-only", {
            "action": "store_true",
            "help": "Only validate input directory without processing",
        }),
    ]
    for flag, options in option_table:
        parser.add_argument(flag, **options)

    return parser.parse_args()
|
|
|
|
|
def validate_image_directory(input_dir: str) -> Tuple[List[str], List[str]]:
    """Sort the files of *input_dir* into usable images and rejected entries.

    Only regular files directly inside *input_dir* are considered; anything
    else (e.g. sub-directories) is skipped without being reported.

    Args:
        input_dir: Directory to scan.

    Returns:
        ``(valid_images, invalid_files)``. ``valid_images`` holds bare file
        names. ``invalid_files`` holds human-readable strings of the form
        ``"<name> (<reason>)"`` for unsupported extensions, images smaller
        than 64 pixels on either side, and files Pillow could not open.

    Raises:
        FileNotFoundError: If *input_dir* does not exist.
    """
    if not os.path.exists(input_dir):
        raise FileNotFoundError(f"Input directory not found: {input_dir}")

    accepted: List[str] = []
    rejected: List[str] = []

    for entry in os.listdir(input_dir):
        full_path = os.path.join(input_dir, entry)
        if not os.path.isfile(full_path):
            continue

        suffix = Path(entry).suffix.lower()
        if suffix not in SUPPORTED_FORMATS:
            rejected.append(f"{entry} (unsupported format: {suffix})")
            continue

        # Opening the file is the only step expected to fail; any error is
        # reported as a corrupt file rather than aborting the scan.
        try:
            with Image.open(full_path) as handle:
                dimensions = handle.size
        except Exception as exc:
            rejected.append(f"{entry} (corrupt: {str(exc)})")
            continue

        if dimensions[0] < 64 or dimensions[1] < 64:
            rejected.append(f"{entry} (too small: {dimensions})")
        else:
            accepted.append(entry)

    return accepted, rejected
|
|
|
|
|
def process_image(input_path: str, output_path: str, target_size: int) -> Dict:
    """Convert one source image into a square RGB PNG for training.

    The image is rotated according to its EXIF orientation tag (if any),
    converted to RGB, then scaled and centre-cropped to
    ``target_size`` x ``target_size`` with Lanczos resampling.

    Args:
        input_path: Path of the source image.
        output_path: Path the PNG is written to.
        target_size: Edge length, in pixels, of the square output.

    Returns:
        A dict with ``original_size`` (size as stored in the source file),
        ``processed_size`` (size of the saved image) and ``format``
        (always ``'PNG'``).
    """
    with Image.open(input_path) as img:
        original_size = img.size

        # Photos often store the pixels unrotated plus an EXIF orientation
        # tag. Apply it first, otherwise the crop below would be taken from a
        # sideways or upside-down picture.
        img = ImageOps.exif_transpose(img)

        if img.mode != 'RGB':
            img = img.convert('RGB')

        img = ImageOps.fit(img, (target_size, target_size), Image.Resampling.LANCZOS)

        # PNG is lossless; a JPEG-style ``quality`` option has no effect on
        # it, so none is passed.
        img.save(output_path, 'PNG')

        return {
            'original_size': original_size,
            'processed_size': img.size,
            'format': 'PNG',
        }
|
|
|
|
|
def generate_captions(image_files: List[str], caption_template: str, trigger_word: str) -> Dict[str, str]:
    """Map every image file name to its training caption.

    Args:
        image_files: File names to caption.
        caption_template: ``str.format`` template; ``{trigger_word}`` is the
            placeholder that gets filled in.
        trigger_word: Value substituted for ``{trigger_word}``.

    Returns:
        A dict from file name to caption. Every file receives the same text.
    """
    # The template is formatted once per file, so an empty file list never
    # touches the template at all.
    return {
        name: caption_template.format(trigger_word=trigger_word)
        for name in image_files
    }
|
|
|
|
|
def create_dataset_structure(output_dir: str, style_name: str):
    """Create the dataset root and its ``images``/``train``/``validation`` folders.

    Existing directories are left as they are; missing parents are created.

    Args:
        output_dir: Root directory of the dataset.
        style_name: Accepted for interface compatibility; not used here.

    Returns:
        The dataset root as a ``Path``.
    """
    root = Path(output_dir)
    root.mkdir(parents=True, exist_ok=True)

    for child in ("images", "train", "validation"):
        (root / child).mkdir(parents=True, exist_ok=True)

    return root
|
|
|
|
|
def split_dataset(image_files: List[str], train_ratio: float) -> Tuple[List[str], List[str]]:
    """Randomly partition *image_files* into training and validation lists.

    The caller's list is left untouched; shuffling happens on a copy.

    Args:
        image_files: File names to split.
        train_ratio: Fraction of files assigned to training. Values outside
            ``[0, 1]`` are treated as 0 or 1 respectively.

    Returns:
        ``(train_files, val_files)``. Together they contain every input file
        exactly once; the training list holds
        ``int(len(image_files) * train_ratio)`` entries.
    """
    shuffled = list(image_files)
    random.shuffle(shuffled)

    # Clamp so a ratio below 0 or above 1 cannot turn into a negative or
    # out-of-range slice index.
    train_count = int(len(shuffled) * train_ratio)
    train_count = max(0, min(train_count, len(shuffled)))

    return shuffled[:train_count], shuffled[train_count:]
|
|
|
|
|
def save_metadata(dataset_dir: Path, metadata: Dict):
    """Write *metadata* as indented JSON to ``dataset_info.json``.

    Args:
        dataset_dir: Directory that receives the file.
        metadata: JSON-serialisable dict to store.

    Prints the path of the written file.
    """
    target = dataset_dir / "dataset_info.json"

    with target.open('w') as handle:
        json.dump(metadata, handle, indent=2)

    print(f"π Dataset metadata saved to: {target}")
|
|
|
|
|
def create_captions_file(dataset_dir: Path, captions: Dict[str, str], split_name: str):
    """Write the captions of one split to ``<split_name>_captions.txt``.

    Each line has the form ``"<filename>: <caption>"``. The file is always
    written as UTF-8 so that non-ASCII file names or trigger words do not
    depend on the platform's default encoding.

    Args:
        dataset_dir: Directory that receives the captions file.
        captions: Mapping from image file name to caption.
        split_name: Split label used as the file-name prefix.

    Returns:
        Path of the written captions file.
    """
    captions_file = dataset_dir / f"{split_name}_captions.txt"

    with open(captions_file, 'w', encoding='utf-8') as f:
        f.writelines(
            f"{filename}: {caption}\n" for filename, caption in captions.items()
        )

    return captions_file
|
|
|
|
|
|
|
|
|
|
|
def _unique_png_name(filename: str, taken: set) -> str:
    """Return a ``.png`` name derived from *filename* that is not in *taken*.

    Names are compared case-insensitively so the result is also unique on
    case-insensitive file systems. The chosen name is recorded in *taken*.
    """
    stem = Path(filename).stem
    candidate = f"{stem}.png"
    counter = 1
    while candidate.lower() in taken:
        candidate = f"{stem}_{counter}.png"
        counter += 1
    taken.add(candidate.lower())
    return candidate


def prepare_dataset(args):
    """Run the whole dataset preparation described by the parsed arguments.

    Steps: validate the input directory, create the output folders, split
    the valid images into train/validation, generate captions, convert every
    image to a square PNG inside its split folder, write one captions file
    per split and finally write ``dataset_info.json``.

    Args:
        args: Namespace from ``setup_args()``. Reads ``input_dir``,
            ``style_name``, ``output_dir``, ``trigger_word``, ``size``,
            ``caption_template``, ``train_split`` and ``validate_only``.
            ``copy_images`` is not consulted by this function.

    Raises:
        FileNotFoundError: If the input directory does not exist.
        ValueError: If the input directory contains no valid image (only
            when ``validate_only`` is not set).
    """
    # NOTE(review): several status prefixes below look like mis-encoded emoji.
    # They are reproduced as found; only the two literals that the damage had
    # split across lines were repaired.
    print(f"π¨ CompI Phase 1.E: Preparing LoRA Dataset for '{args.style_name}'")
    print("=" * 60)

    input_dir = Path(args.input_dir)
    if args.output_dir:
        output_dir = Path(args.output_dir)
    else:
        output_dir = Path("datasets") / args.style_name

    trigger_word = args.trigger_word or args.style_name

    print(f"π Input directory: {input_dir}")
    print(f"π Output directory: {output_dir}")
    print(f"π― Style name: {args.style_name}")
    print(f"π€ Trigger word: {trigger_word}")
    print(f"π Target size: {args.size}x{args.size}")

    print("\nπ Validating input directory...")
    valid_images, invalid_files = validate_image_directory(str(input_dir))

    print(f"✅ Found {len(valid_images)} valid images")
    if invalid_files:
        print(f"β οΈ Found {len(invalid_files)} invalid files:")
        # Show at most five entries so a messy folder does not flood the log.
        for invalid in invalid_files[:5]:
            print(f" - {invalid}")
        if len(invalid_files) > 5:
            print(f" ... and {len(invalid_files) - 5} more")

    if len(valid_images) < MIN_IMAGES_RECOMMENDED:
        print(f"β οΈ Warning: Only {len(valid_images)} images found. Recommended minimum: {MIN_IMAGES_RECOMMENDED}")
        print(" Consider adding more images for better style learning.")

    if args.validate_only:
        print("✅ Validation complete (--validate-only specified)")
        return

    # Without a single usable image there is nothing to build; stop before
    # creating an empty dataset that would be reported as ready for training.
    if not valid_images:
        raise ValueError(f"No valid images found in: {input_dir}")

    print("\nπ Creating dataset structure...")
    dataset_dir = create_dataset_structure(str(output_dir), args.style_name)

    train_files, val_files = split_dataset(valid_images, args.train_split)
    print(f"π Dataset split: {len(train_files)} train, {len(val_files)} validation")

    print("\nπ Generating captions...")
    all_captions = generate_captions(valid_images, args.caption_template, trigger_word)

    print("\nπΌοΈ Processing images...")
    processed_count = 0
    processing_stats = []

    for split_name, file_list in [("train", train_files), ("validation", val_files)]:
        if not file_list:
            continue

        split_dir = dataset_dir / split_name
        split_captions = {}
        # Output names already handed out in this split. Sources such as
        # "a.jpg" and "a.png" share a stem and would otherwise overwrite each
        # other's PNG and caption.
        used_names = set()

        for filename in file_list:
            input_path = input_dir / filename
            output_filename = _unique_png_name(filename, used_names)
            output_path = split_dir / output_filename

            try:
                stats = process_image(str(input_path), str(output_path), args.size)
                processing_stats.append(stats)
                split_captions[output_filename] = all_captions[filename]
                processed_count += 1

                if processed_count % 10 == 0:
                    print(f" Processed {processed_count}/{len(valid_images)} images...")

            except Exception as e:
                # One unreadable image must not abort the whole run; it is
                # reported here and counted as failed in the metadata.
                print(f"β Error processing {filename}: {e}")

        if split_captions:
            captions_file = create_captions_file(dataset_dir, split_captions, split_name)
            print(f"π Created {split_name} captions: {captions_file}")

    metadata = {
        'style_name': args.style_name,
        'trigger_word': trigger_word,
        'total_images': len(valid_images),
        'train_images': len(train_files),
        'validation_images': len(val_files),
        'image_size': args.size,
        'caption_template': args.caption_template,
        'created_at': pd.Timestamp.now().isoformat(),
        'processing_stats': {
            'processed_count': processed_count,
            'failed_count': len(valid_images) - processed_count,
        },
    }

    save_metadata(dataset_dir, metadata)

    print("\nπ Dataset preparation complete!")
    print(f"π Dataset location: {dataset_dir}")
    print(f"π Ready for LoRA training with {processed_count} processed images")
    print("\nπ‘ Next steps:")
    print(f" 1. Review the generated dataset in: {dataset_dir}")
    print(f" 2. Run LoRA training: python src/generators/compi_phase1e_lora_training.py --dataset-dir {dataset_dir}")
|
|
|
|
|
def main():
    """Parse the command line, run the preparation and report the outcome.

    Returns:
        ``0`` when ``prepare_dataset`` finishes, ``1`` when it raises; the
        error is printed instead of propagating.
    """
    cli_args = setup_args()

    try:
        prepare_dataset(cli_args)
    except Exception as err:
        print(f"β Error: {err}")
        status = 1
    else:
        status = 0

    return status
|
|
|
|
|
if __name__ == "__main__":
    # ``exit`` is a helper injected by the ``site`` module for interactive
    # sessions and is not guaranteed to exist (e.g. under ``python -S``).
    # Raising SystemExit directly hands main()'s return code to the
    # interpreter without relying on it.
    raise SystemExit(main())
|
|
|