File size: 11,756 Bytes
338d95d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 |
#!/usr/bin/env python3
"""
CompI Phase 1.E: Dataset Preparation for LoRA Fine-tuning
This tool helps prepare your personal style dataset for LoRA training:
- Organize and validate style images
- Generate appropriate captions
- Resize and format images for training
- Create training/validation splits
Usage:
python src/generators/compi_phase1e_dataset_prep.py --help
python src/generators/compi_phase1e_dataset_prep.py --input-dir my_style_images --style-name "my_art_style"
"""
import argparse
import json
import os
import random
import shutil
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import pandas as pd
from PIL import Image, ImageOps
# -------- 1. CONFIGURATION --------
# Default side length (pixels) of the square images written for training.
DEFAULT_IMAGE_SIZE = 512
# Lower-case file extensions accepted as candidate images (compared after
# .lower()); both common TIFF spellings are listed.
SUPPORTED_FORMATS = {'.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff', '.webp'}
# Fewer valid images than this triggers a warning; processing still continues.
MIN_IMAGES_RECOMMENDED = 10
# Default fraction of images assigned to the training split; the rest become validation.
TRAIN_SPLIT_RATIO = 0.8
# -------- 2. UTILITY FUNCTIONS --------
def _unit_interval(value: str) -> float:
    """argparse ``type=`` hook: parse *value* as a float in the closed range [0, 1]."""
    ratio = float(value)
    # The chained comparison is also False for NaN, so "nan" is rejected as well.
    if not 0.0 <= ratio <= 1.0:
        raise argparse.ArgumentTypeError(f"expected a ratio between 0 and 1, got {value}")
    return ratio


def _positive_int(value: str) -> int:
    """argparse ``type=`` hook: parse *value* as an integer greater than zero."""
    number = int(value)
    if number <= 0:
        raise argparse.ArgumentTypeError(f"expected a positive integer, got {value}")
    return number


def setup_args(argv: Optional[List[str]] = None):
    """Build the command-line parser and parse the arguments.

    Args:
        argv: Arguments to parse. ``None`` (the default) lets argparse read
            ``sys.argv[1:]`` exactly as before; passing a list allows the
            parser to be driven from tests or other Python code.

    Returns:
        The parsed ``argparse.Namespace``. ``--size`` is guaranteed to be a
        positive integer and ``--train-split`` a float in [0, 1]; anything
        else is reported by argparse as a usage error.
    """
    parser = argparse.ArgumentParser(
        description="CompI Phase 1.E: Dataset Preparation for LoRA Fine-tuning",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Prepare dataset from a folder of images
  python %(prog)s --input-dir my_artwork --style-name "impressionist_style"

  # Custom output directory and image size
  python %(prog)s --input-dir paintings --style-name "oil_painting" --output-dir datasets/oil_style --size 768

  # Generate captions with custom trigger word
  python %(prog)s --input-dir sketches --style-name "pencil_sketch" --trigger-word "sketch_style"
""",
    )
    parser.add_argument("--input-dir", required=True,
                        help="Directory containing your style images")
    parser.add_argument("--style-name", required=True,
                        help="Name for your style (used in file naming and captions)")
    parser.add_argument("--output-dir",
                        help="Output directory for prepared dataset (default: datasets/{style_name})")
    parser.add_argument("--trigger-word",
                        help="Trigger word for style (default: style_name)")
    parser.add_argument("--size", type=_positive_int, default=DEFAULT_IMAGE_SIZE,
                        help=f"Target image size in pixels (default: {DEFAULT_IMAGE_SIZE})")
    parser.add_argument("--caption-template",
                        default="a painting in {trigger_word} style",
                        help="Template for generating captions")
    parser.add_argument("--train-split", type=_unit_interval, default=TRAIN_SPLIT_RATIO,
                        help=f"Ratio for train/validation split (default: {TRAIN_SPLIT_RATIO})")
    # NOTE(review): prepare_dataset never reads args.copy_images; every image is
    # re-encoded to PNG, so this flag currently has no effect.
    parser.add_argument("--copy-images", action="store_true",
                        help="Copy images instead of creating symlinks")
    parser.add_argument("--validate-only", action="store_true",
                        help="Only validate input directory without processing")
    return parser.parse_args(argv)
def validate_image_directory(input_dir: str, min_size: int = 64) -> Tuple[List[str], List[str]]:
    """Scan *input_dir* (non-recursively) and classify the files it contains.

    A file counts as a valid image when its lower-cased extension is in
    SUPPORTED_FORMATS, PIL can open it, both sides are at least *min_size*
    pixels, and ``Image.verify()`` finds no structural damage. ``Image.open``
    alone only reads the header, so truncated files would otherwise slip
    through and fail later during processing. Subdirectories and other
    non-regular entries are skipped silently.

    Args:
        input_dir: Directory to scan.
        min_size: Minimum accepted width and height in pixels (default 64).

    Returns:
        (valid_images, invalid_files). valid_images holds bare filenames.
        invalid_files holds human-readable "<name> (<reason>)" strings. Both
        are in sorted filename order, so repeated runs are deterministic.

    Raises:
        FileNotFoundError: if *input_dir* does not exist.
        NotADirectoryError: if *input_dir* exists but is not a directory.
    """
    if not os.path.exists(input_dir):
        raise FileNotFoundError(f"Input directory not found: {input_dir}")
    if not os.path.isdir(input_dir):
        raise NotADirectoryError(f"Input path is not a directory: {input_dir}")

    valid_images: List[str] = []
    invalid_files: List[str] = []
    # os.listdir order is filesystem-dependent; sort for reproducible output.
    for filename in sorted(os.listdir(input_dir)):
        filepath = os.path.join(input_dir, filename)
        if not os.path.isfile(filepath):
            continue

        ext = Path(filename).suffix.lower()
        if ext not in SUPPORTED_FORMATS:
            invalid_files.append(f"{filename} (unsupported format: {ext})")
            continue

        try:
            with Image.open(filepath) as img:
                width, height = img.size
                if width < min_size or height < min_size:
                    invalid_files.append(f"{filename} (too small: {img.size})")
                    continue
                # verify() must come last: the image object is unusable afterwards.
                img.verify()
        except Exception as e:  # any PIL failure means "unreadable" here, not a crash
            invalid_files.append(f"{filename} (corrupt: {str(e)})")
            continue
        valid_images.append(filename)
    return valid_images, invalid_files
def process_image(input_path: str, output_path: str, target_size: int) -> Dict:
    """Normalize one image into a square RGB PNG suitable for training.

    Any EXIF orientation tag is applied first. The image is then converted to
    RGB, scaled and center-cropped to ``target_size`` x ``target_size`` with
    LANCZOS resampling (``ImageOps.fit``), and written to *output_path* as PNG.

    Returns:
        Dict with 'original_size' (width, height as stored in the file, before
        EXIF rotation), 'processed_size' (the square output size) and
        'format' ('PNG').
    """
    with Image.open(input_path) as img:
        original_size = img.size
        # Cameras and phones often store unrotated pixels plus an EXIF
        # Orientation tag; bake the rotation in so the crop matches what
        # image viewers display.
        img = ImageOps.exif_transpose(img)
        if img.mode != 'RGB':
            img = img.convert('RGB')
        # Resize maintaining aspect ratio, then center crop.
        img = ImageOps.fit(img, (target_size, target_size), Image.Resampling.LANCZOS)
        # PNG is lossless and Pillow's PNG writer has no 'quality' option,
        # so none is passed.
        img.save(output_path, 'PNG')
    return {
        'original_size': original_size,
        'processed_size': img.size,
        'format': 'PNG'
    }
def generate_captions(image_files: List[str], caption_template: str, trigger_word: str) -> Dict[str, str]:
    """Map each image filename to a caption rendered from *caption_template*.

    The template is formatted with the single keyword ``trigger_word``, so
    every image currently receives identical caption text. Per-image captions,
    for example from an image-captioning model such as BLIP, could be produced
    here instead.
    """
    return {
        name: caption_template.format(trigger_word=trigger_word)
        for name in image_files
    }
def create_dataset_structure(output_dir: str, style_name: str):
    """Ensure the dataset root and its fixed subdirectories exist.

    Creates *output_dir* (with any missing parents) plus the ``images``,
    ``train`` and ``validation`` subdirectories; existing directories are left
    as they are. *style_name* is accepted for interface compatibility but is
    not used.

    NOTE(review): prepare_dataset in this file writes only to train/ and
    validation/; the images/ directory appears to stay empty.

    Returns:
        The dataset root as a ``Path``.
    """
    root = Path(output_dir)
    root.mkdir(parents=True, exist_ok=True)
    for subdir_name in ("images", "train", "validation"):
        (root / subdir_name).mkdir(parents=True, exist_ok=True)
    return root
def split_dataset(image_files: List[str], train_ratio: float,
                  seed: Optional[int] = None) -> Tuple[List[str], List[str]]:
    """Randomly partition *image_files* into (train, validation) lists.

    The caller's list is left untouched; a shuffled copy is split so that the
    first ``int(len * train_ratio)`` names become training data and the rest
    validation data. When the list is non-empty and ``train_ratio > 0``, at
    least one image always goes to training. Without that, tiny datasets
    (e.g. a single image at ratio 0.8) would have nothing to train on.

    Args:
        image_files: Filenames to split.
        train_ratio: Fraction in [0, 1] assigned to the training split.
        seed: Optional seed for a private RNG, giving a reproducible split
            without touching the global ``random`` state. ``None`` (the
            default) uses the global generator, as before.

    Raises:
        ValueError: if *train_ratio* is outside [0, 1] (or NaN).
    """
    if not 0.0 <= train_ratio <= 1.0:
        raise ValueError(f"train_ratio must be between 0 and 1, got {train_ratio}")
    shuffled = list(image_files)
    rng = random if seed is None else random.Random(seed)
    rng.shuffle(shuffled)
    train_count = int(len(shuffled) * train_ratio)
    if shuffled and train_ratio > 0:
        train_count = max(train_count, 1)
    return shuffled[:train_count], shuffled[train_count:]
def save_metadata(dataset_dir: Path, metadata: Dict):
    """Serialize *metadata* to ``dataset_info.json`` inside *dataset_dir*.

    The file is written as JSON with 2-space indentation, replacing any
    existing file, and its path is printed to stdout.
    """
    info_path = dataset_dir / "dataset_info.json"
    with info_path.open('w') as stream:
        json.dump(metadata, stream, indent=2)
    print(f"π Dataset metadata saved to: {info_path}")
def create_captions_file(dataset_dir: Path, captions: Dict[str, str], split_name: str):
    """Write one "<filename>: <caption>" line per entry to <split_name>_captions.txt.

    The file is created directly inside *dataset_dir* and replaces any
    existing file. It is encoded as UTF-8 explicitly because the default text
    encoding is locale-dependent (e.g. cp1252 on many Windows setups). Under
    such a locale, non-ASCII trigger words or filenames would either raise
    UnicodeEncodeError or be written inconsistently across machines.

    Returns:
        Path of the written captions file.
    """
    captions_file = dataset_dir / f"{split_name}_captions.txt"
    with open(captions_file, 'w', encoding='utf-8') as f:
        f.writelines(f"{filename}: {caption}\n" for filename, caption in captions.items())
    return captions_file
# -------- 3. MAIN PROCESSING FUNCTION --------
def _report_invalid_files(invalid_files: List[str], limit: int = 5) -> None:
    """Print how many files were rejected, listing at most *limit* of them."""
    print(f"β οΈ Found {len(invalid_files)} invalid files:")
    for invalid in invalid_files[:limit]:
        print(f" - {invalid}")
    if len(invalid_files) > limit:
        print(f" ... and {len(invalid_files) - limit} more")


def _png_output_name(filename: str, taken: set) -> str:
    """Pick the output PNG name for *filename* that is not already in *taken*.

    *taken* holds case-folded names already written to the same directory.
    ``<stem>.png`` is used when free; otherwise ``<stem>_1.png``,
    ``<stem>_2.png``, ... are tried in turn. This is needed because inputs
    such as ``a.jpg`` and ``a.png`` share a stem and would otherwise overwrite
    each other.
    """
    stem = Path(filename).stem
    candidate = f"{stem}.png"
    suffix = 1
    # Case-fold so "A.png" vs "a.png" also counts as a clash, as it would on
    # a case-insensitive filesystem.
    while candidate.casefold() in taken:
        candidate = f"{stem}_{suffix}.png"
        suffix += 1
    return candidate


def prepare_dataset(args: argparse.Namespace):
    """Run the pipeline: validate, split, caption, process images, save metadata.

    Reads args.input_dir, args.style_name, args.output_dir, args.trigger_word,
    args.size, args.caption_template, args.train_split and args.validate_only.
    Processed PNGs go into <output>/train and <output>/validation. A
    "<split>_captions.txt" is written for each split with at least one
    processed image, and dataset_info.json is saved last.

    NOTE(review): args.copy_images is parsed by setup_args but never read here.
    NOTE(review): most message prefixes below (e.g. "π", "β οΈ") look like
    mis-decoded UTF-8 emoji; restore them from the original source if available.

    Raises:
        FileNotFoundError / NotADirectoryError: from validate_image_directory.
        ValueError: if no valid images are found (unless --validate-only).
    """
    print(f"π¨ CompI Phase 1.E: Preparing LoRA Dataset for '{args.style_name}'")
    print("=" * 60)

    # Setup paths
    input_dir = Path(args.input_dir)
    if args.output_dir:
        output_dir = Path(args.output_dir)
    else:
        output_dir = Path("datasets") / args.style_name
    trigger_word = args.trigger_word or args.style_name

    print(f"π Input directory: {input_dir}")
    print(f"π Output directory: {output_dir}")
    print(f"π― Style name: {args.style_name}")
    print(f"π€ Trigger word: {trigger_word}")
    print(f"π Target size: {args.size}x{args.size}")

    # Validate input directory
    print(f"\nπ Validating input directory...")
    valid_images, invalid_files = validate_image_directory(str(input_dir))
    # The mis-encoded prefix on this message had broken the literal across two
    # lines (a SyntaxError); the intended check-mark emoji is restored.
    print(f"✅ Found {len(valid_images)} valid images")
    if invalid_files:
        _report_invalid_files(invalid_files)
    if len(valid_images) < MIN_IMAGES_RECOMMENDED:
        print(f"β οΈ Warning: Only {len(valid_images)} images found. Recommended minimum: {MIN_IMAGES_RECOMMENDED}")
        print(" Consider adding more images for better style learning.")

    if args.validate_only:
        print("✅ Validation complete (--validate-only specified)")
        return

    # An empty dataset is useless for training; fail loudly instead of
    # reporting "complete" with zero images.
    if not valid_images:
        raise ValueError(f"No valid images found in {input_dir}; nothing to prepare")

    # Create dataset structure
    print(f"\nπ Creating dataset structure...")
    dataset_dir = create_dataset_structure(str(output_dir), args.style_name)

    # Split dataset (pass a copy so a shuffling implementation cannot reorder ours)
    train_files, val_files = split_dataset(list(valid_images), args.train_split)
    print(f"π Dataset split: {len(train_files)} train, {len(val_files)} validation")

    # Generate captions
    print(f"\nπ Generating captions...")
    all_captions = generate_captions(valid_images, args.caption_template, trigger_word)

    # Process images
    print(f"\nπΌοΈ Processing images...")
    total = len(valid_images)
    processed_count = 0
    for split_name, file_list in [("train", train_files), ("validation", val_files)]:
        if not file_list:
            continue

        split_dir = dataset_dir / split_name
        split_captions = {}
        taken_names: set = set()  # case-folded names already written to split_dir
        for filename in file_list:
            output_filename = _png_output_name(filename, taken_names)
            try:
                process_image(str(input_dir / filename), str(split_dir / output_filename), args.size)
            except Exception as e:  # best-effort: one unreadable image must not abort the run
                print(f"β Error processing {filename}: {e}")
                continue
            taken_names.add(output_filename.casefold())
            split_captions[output_filename] = all_captions[filename]
            processed_count += 1
            if processed_count % 10 == 0:
                print(f" Processed {processed_count}/{total} images...")

        # Create captions file for this split
        if split_captions:
            captions_file = create_captions_file(dataset_dir, split_captions, split_name)
            print(f"π Created {split_name} captions: {captions_file}")

    # Save metadata
    metadata = {
        'style_name': args.style_name,
        'trigger_word': trigger_word,
        'total_images': total,
        'train_images': len(train_files),
        'validation_images': len(val_files),
        'image_size': args.size,
        'caption_template': args.caption_template,
        'created_at': pd.Timestamp.now().isoformat(),
        'processing_stats': {
            'processed_count': processed_count,
            'failed_count': total - processed_count
        }
    }
    save_metadata(dataset_dir, metadata)

    print(f"\nπ Dataset preparation complete!")
    print(f"π Dataset location: {dataset_dir}")
    print(f"π Ready for LoRA training with {processed_count} processed images")
    print(f"\nπ‘ Next steps:")
    print(f" 1. Review the generated dataset in: {dataset_dir}")
    print(f" 2. Run LoRA training: python src/generators/compi_phase1e_lora_training.py --dataset-dir {dataset_dir}")
def main():
    """Command-line entry point; returns the process exit status.

    Arguments are parsed before the guarded section, so argparse's own exits
    (bad usage, --help) pass through untouched. Any exception raised while
    preparing the dataset is printed and mapped to status 1. A clean run
    returns 0.
    """
    cli_args = setup_args()
    try:
        prepare_dataset(cli_args)
    except Exception as err:
        print(f"β Error: {err}")
        return 1
    else:
        return 0
if __name__ == "__main__":
exit(main())
|