File size: 11,756 Bytes
338d95d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
#!/usr/bin/env python3
"""
CompI Phase 1.E: Dataset Preparation for LoRA Fine-tuning

This tool helps prepare your personal style dataset for LoRA training:
- Organize and validate style images
- Generate appropriate captions
- Resize and format images for training
- Create training/validation splits

Usage:
    python src/generators/compi_phase1e_dataset_prep.py --help
    python src/generators/compi_phase1e_dataset_prep.py --input-dir my_style_images --style-name "my_art_style"
"""

import os
import argparse
import json
import shutil
from pathlib import Path
from typing import List, Dict, Tuple
import random

from PIL import Image, ImageOps
import pandas as pd

# -------- 1. CONFIGURATION --------

# Edge length, in pixels, of the square training images (default for --size).
DEFAULT_IMAGE_SIZE = 512
# Lower-case file extensions (with leading dot) accepted as input images.
# Both TIFF spellings are listed: '.tif' is as common as '.tiff'.
SUPPORTED_FORMATS = {'.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff', '.webp'}
# Below this many valid images a warning is printed; processing still continues.
MIN_IMAGES_RECOMMENDED = 10
# Fraction of images assigned to the training split (default for --train-split).
TRAIN_SPLIT_RATIO = 0.8

# -------- 2. UTILITY FUNCTIONS --------

def setup_args(argv=None):
    """Build the command-line parser and parse the arguments.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read the process command line, which is
            what the original zero-argument call did.

    Returns:
        argparse.Namespace with the attributes ``input_dir``, ``style_name``,
        ``output_dir``, ``trigger_word``, ``size``, ``caption_template``,
        ``train_split``, ``copy_images`` and ``validate_only``.

    Raises:
        SystemExit: raised by argparse (status 2) on unknown/missing options,
            a non-positive ``--size``, or a ``--train-split`` outside [0, 1].
    """
    parser = argparse.ArgumentParser(
        description="CompI Phase 1.E: Dataset Preparation for LoRA Fine-tuning",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Prepare dataset from a folder of images
  python %(prog)s --input-dir my_artwork --style-name "impressionist_style"
  
  # Custom output directory and image size
  python %(prog)s --input-dir paintings --style-name "oil_painting" --output-dir datasets/oil_style --size 768
  
  # Generate captions with custom trigger word
  python %(prog)s --input-dir sketches --style-name "pencil_sketch" --trigger-word "sketch_style"
        """
    )

    parser.add_argument("--input-dir", required=True,
                       help="Directory containing your style images")

    parser.add_argument("--style-name", required=True,
                       help="Name for your style (used in file naming and captions)")

    parser.add_argument("--output-dir",
                       help="Output directory for prepared dataset (default: datasets/{style_name})")

    parser.add_argument("--trigger-word",
                       help="Trigger word for style (default: style_name)")

    parser.add_argument("--size", type=int, default=DEFAULT_IMAGE_SIZE,
                       help=f"Target image size in pixels (default: {DEFAULT_IMAGE_SIZE})")

    parser.add_argument("--caption-template",
                       default="a painting in {trigger_word} style",
                       help="Template for generating captions")

    parser.add_argument("--train-split", type=float, default=TRAIN_SPLIT_RATIO,
                       help=f"Ratio for train/validation split (default: {TRAIN_SPLIT_RATIO})")

    # NOTE(review): this flag is parsed but prepare_dataset() in this file
    # never reads it; images are always re-encoded into the output folders.
    parser.add_argument("--copy-images", action="store_true",
                       help="Copy images instead of creating symlinks")

    parser.add_argument("--validate-only", action="store_true",
                       help="Only validate input directory without processing")

    args = parser.parse_args(argv)

    # Fail fast on values that would otherwise make every image fail to
    # resize or produce a nonsensical train/validation split.
    if args.size <= 0:
        parser.error("--size must be a positive integer")
    if not 0.0 <= args.train_split <= 1.0:
        parser.error("--train-split must be between 0 and 1")

    return args

def validate_image_directory(input_dir: str) -> Tuple[List[str], List[str]]:
    """Classify the files directly inside ``input_dir`` as usable or not.

    Only regular files are considered; sub-directories are skipped silently
    and are not reported. A file is usable when its lower-cased extension is
    in ``SUPPORTED_FORMATS``, Pillow can open and verify it, and both
    dimensions are at least 64 pixels.

    Args:
        input_dir: Path of the directory to scan (not recursive).

    Returns:
        ``(valid_images, invalid_files)`` where ``valid_images`` holds bare
        file names and ``invalid_files`` holds human-readable strings of the
        form ``"<name> (<reason>)"``. Both lists are in sorted file-name order.

    Raises:
        FileNotFoundError: ``input_dir`` does not exist.
        NotADirectoryError: ``input_dir`` exists but is not a directory.
    """
    if not os.path.exists(input_dir):
        raise FileNotFoundError(f"Input directory not found: {input_dir}")
    if not os.path.isdir(input_dir):
        raise NotADirectoryError(f"Input path is not a directory: {input_dir}")

    valid_images = []
    invalid_files = []

    # os.listdir() returns entries in arbitrary order; sorting makes the
    # validation report reproducible between runs and machines.
    for filename in sorted(os.listdir(input_dir)):
        filepath = os.path.join(input_dir, filename)

        if not os.path.isfile(filepath):
            continue

        ext = Path(filename).suffix.lower()
        if ext not in SUPPORTED_FORMATS:
            invalid_files.append(f"{filename} (unsupported format: {ext})")
            continue

        try:
            with Image.open(filepath) as img:
                size = img.size
                too_small = size[0] < 64 or size[1] < 64
                if not too_small:
                    # Image.open() only parses the header; verify() asks
                    # Pillow to check the file for damage without decoding
                    # the pixel data.
                    img.verify()
        except Exception as e:
            # Deliberately broad: Pillow raises several unrelated exception
            # types for damaged files, and one bad file must not abort the scan.
            invalid_files.append(f"{filename} (corrupt: {str(e)})")
            continue

        if too_small:
            invalid_files.append(f"{filename} (too small: {size})")
        else:
            valid_images.append(filename)

    return valid_images, invalid_files

def process_image(input_path: str, output_path: str, target_size: int) -> Dict:
    """Convert one source image into a square RGB PNG for training.

    The image is rotated/mirrored according to its EXIF orientation (if any),
    converted to RGB, scaled and centre-cropped to
    ``target_size`` x ``target_size`` with Lanczos resampling, and written to
    ``output_path`` as PNG.

    Args:
        input_path: Path of the source image.
        output_path: Destination path of the PNG file.
        target_size: Edge length of the square output, in pixels.

    Returns:
        Dict with ``'original_size'`` (``(width, height)`` as stored in the
        source file), ``'processed_size'`` (size of the written image) and
        ``'format'`` (always ``'PNG'``).

    Raises:
        Whatever Pillow raises when the file cannot be opened, decoded or
        written; the caller is expected to handle it.
    """
    with Image.open(input_path) as source:
        # Dimensions as stored in the file, before any orientation fix.
        original_size = source.size

        # Cameras often store photos unrotated plus an EXIF Orientation tag.
        # The resize/crop below works on raw pixels, so without this step
        # such photos would be cropped and saved sideways or mirrored.
        img = ImageOps.exif_transpose(source)

        if img.mode != 'RGB':
            img = img.convert('RGB')

        # Scale to cover the target square, then centre-crop the overflow.
        img = ImageOps.fit(img, (target_size, target_size), Image.Resampling.LANCZOS)

        # PNG is lossless; the JPEG-style ``quality`` option is not a PNG
        # save parameter, so it is not passed.
        img.save(output_path, 'PNG')

        return {
            'original_size': original_size,
            'processed_size': img.size,
            'format': 'PNG'
        }

def generate_captions(image_files: List[str], caption_template: str, trigger_word: str) -> Dict[str, str]:
    """Map every image file name to a caption built from the template.

    Args:
        image_files: File names to caption.
        caption_template: ``str.format`` template; ``{trigger_word}`` is the
            placeholder that gets substituted.
        trigger_word: Value inserted for ``{trigger_word}``.

    Returns:
        Dict of ``file name -> caption``. Every image receives the same text.
    """
    # Template-only captions for now; a per-image captioning model
    # (for example BLIP) could be substituted here later.
    return {
        name: caption_template.format(trigger_word=trigger_word)
        for name in image_files
    }

def create_dataset_structure(output_dir: str, style_name: str):
    """Ensure the dataset folder layout exists and return its root.

    Creates ``output_dir`` together with its ``images``, ``train`` and
    ``validation`` sub-folders. Folders that already exist are left alone.

    Args:
        output_dir: Root folder of the dataset.
        style_name: Accepted for interface compatibility; not used here.

    Returns:
        ``pathlib.Path`` pointing at ``output_dir``.
    """
    root = Path(output_dir)
    root.mkdir(parents=True, exist_ok=True)

    for subfolder in ("images", "train", "validation"):
        (root / subfolder).mkdir(parents=True, exist_ok=True)

    return root

def split_dataset(image_files: List[str], train_ratio: float, seed=None) -> Tuple[List[str], List[str]]:
    """Randomly split file names into training and validation lists.

    The input list is not modified: shuffling happens on a copy, so callers
    that keep using ``image_files`` afterwards see it unchanged.

    Args:
        image_files: File names to distribute.
        train_ratio: Fraction of files assigned to training; the training
            count is ``int(len(image_files) * train_ratio)`` (rounded down).
        seed: Optional seed for a reproducible split. ``None`` (the default)
            uses the module-level ``random`` generator.

    Returns:
        ``(train_files, val_files)``; together they contain every input name
        exactly once.
    """
    shuffled = list(image_files)
    if seed is None:
        random.shuffle(shuffled)
    else:
        # A private generator keeps the global random state untouched.
        random.Random(seed).shuffle(shuffled)

    train_count = int(len(shuffled) * train_ratio)
    return shuffled[:train_count], shuffled[train_count:]

def save_metadata(dataset_dir: Path, metadata: Dict):
    """Write ``metadata`` as indented JSON to ``dataset_info.json``.

    Args:
        dataset_dir: Dataset root folder; the file is created directly in it.
        metadata: JSON-serialisable dictionary describing the dataset.

    Prints the path of the written file.
    """
    metadata_file = dataset_dir / "dataset_info.json"

    # Explicit encoding so the file is identical on every platform.
    with open(metadata_file, 'w', encoding='utf-8') as f:
        json.dump(metadata, f, indent=2)

    print(f"📄 Dataset metadata saved to: {metadata_file}")

def create_captions_file(dataset_dir: Path, captions: Dict[str, str], split_name: str):
    """Write the captions of one split to ``<split_name>_captions.txt``.

    Each line has the form ``<filename>: <caption>``.

    Args:
        dataset_dir: Dataset root folder; the file is created directly in it.
        captions: Mapping of image file name to caption text.
        split_name: Split label used as the file-name prefix.

    Returns:
        Path of the written captions file.
    """
    captions_file = dataset_dir / f"{split_name}_captions.txt"

    # File names, trigger words and captions may contain non-ASCII text;
    # UTF-8 keeps the file readable regardless of the platform default.
    with open(captions_file, 'w', encoding='utf-8') as f:
        f.writelines(
            f"{filename}: {caption}\n" for filename, caption in captions.items()
        )

    return captions_file

# -------- 3. MAIN PROCESSING FUNCTION --------

def _unique_png_name(stem: str, taken: set) -> str:
    """Return ``<stem>.png``, or ``<stem>_<n>.png`` if that name is already taken.

    ``taken`` holds case-folded names already handed out for one output
    folder and is updated in place. Case is ignored so that names differing
    only in case do not clash on case-insensitive file systems.
    """
    candidate = f"{stem}.png"
    counter = 1
    while candidate.casefold() in taken:
        candidate = f"{stem}_{counter}.png"
        counter += 1
    taken.add(candidate.casefold())
    return candidate


def prepare_dataset(args):
    """Run the whole dataset preparation described by ``args``.

    Steps: validate the input folder, stop early for ``--validate-only``,
    create the output layout, split the images into train/validation,
    generate captions, convert every image to a square PNG inside its split
    folder, write one captions file per split and finally write
    ``dataset_info.json``.

    Args:
        args: Parsed command-line namespace. Reads ``input_dir``,
            ``output_dir``, ``style_name``, ``trigger_word``, ``size``,
            ``caption_template``, ``train_split`` and ``validate_only``.

    Raises:
        FileNotFoundError: the input directory does not exist (raised by
            ``validate_image_directory``).
        ValueError: the input directory contains no usable image, so there
            is nothing to build a dataset from.

    Failures on individual images are reported and counted in the metadata
    but do not abort the run.
    """
    print(f"🎨 CompI Phase 1.E: Preparing LoRA Dataset for '{args.style_name}'")
    print("=" * 60)

    # Setup paths
    input_dir = Path(args.input_dir)
    if args.output_dir:
        output_dir = Path(args.output_dir)
    else:
        output_dir = Path("datasets") / args.style_name

    trigger_word = args.trigger_word or args.style_name

    print(f"📁 Input directory: {input_dir}")
    print(f"📁 Output directory: {output_dir}")
    print(f"🎯 Style name: {args.style_name}")
    print(f"🔤 Trigger word: {trigger_word}")
    print(f"📏 Target size: {args.size}x{args.size}")

    # Validate input directory
    print("\n🔍 Validating input directory...")
    valid_images, invalid_files = validate_image_directory(str(input_dir))

    print(f"✅ Found {len(valid_images)} valid images")
    if invalid_files:
        print(f"⚠️  Found {len(invalid_files)} invalid files:")
        for invalid in invalid_files[:5]:  # Show first 5
            print(f"   - {invalid}")
        if len(invalid_files) > 5:
            print(f"   ... and {len(invalid_files) - 5} more")

    if len(valid_images) < MIN_IMAGES_RECOMMENDED:
        print(f"⚠️  Warning: Only {len(valid_images)} images found. Recommended minimum: {MIN_IMAGES_RECOMMENDED}")
        print("   Consider adding more images for better style learning.")

    if args.validate_only:
        print("✅ Validation complete (--validate-only specified)")
        return

    # An empty dataset must not be reported as "ready for training".
    if not valid_images:
        raise ValueError(f"No valid images found in {input_dir}")

    # Create dataset structure
    print("\n📁 Creating dataset structure...")
    dataset_dir = create_dataset_structure(str(output_dir), args.style_name)

    # Split dataset
    train_files, val_files = split_dataset(valid_images, args.train_split)
    print(f"📊 Dataset split: {len(train_files)} train, {len(val_files)} validation")

    # Generate captions
    print("\n📝 Generating captions...")
    all_captions = generate_captions(valid_images, args.caption_template, trigger_word)

    # Process images
    print("\n🖼️  Processing images...")
    processed_count = 0

    for split_name, file_list in [("train", train_files), ("validation", val_files)]:
        if not file_list:
            continue

        split_dir = dataset_dir / split_name
        split_captions = {}
        # Output names already used in this split folder. Inputs such as
        # "foo.jpg" and "foo.png" share a stem and would otherwise overwrite
        # each other once both are written as "foo.png".
        used_names = set()

        for filename in file_list:
            input_path = input_dir / filename
            output_filename = _unique_png_name(Path(filename).stem, used_names)
            output_path = split_dir / output_filename

            try:
                process_image(str(input_path), str(output_path), args.size)
                split_captions[output_filename] = all_captions[filename]
                processed_count += 1

                if processed_count % 10 == 0:
                    print(f"   Processed {processed_count}/{len(valid_images)} images...")

            except Exception as e:
                # Best effort: one unreadable image should not abort the run.
                print(f"❌ Error processing {filename}: {e}")

        # Create captions file for this split
        if split_captions:
            captions_file = create_captions_file(dataset_dir, split_captions, split_name)
            print(f"📝 Created {split_name} captions: {captions_file}")

    # Save metadata. The *_images counts are the planned split sizes; the
    # number actually written is reported under 'processing_stats'.
    metadata = {
        'style_name': args.style_name,
        'trigger_word': trigger_word,
        'total_images': len(valid_images),
        'train_images': len(train_files),
        'validation_images': len(val_files),
        'image_size': args.size,
        'caption_template': args.caption_template,
        'created_at': pd.Timestamp.now().isoformat(),
        'processing_stats': {
            'processed_count': processed_count,
            'failed_count': len(valid_images) - processed_count
        }
    }

    save_metadata(dataset_dir, metadata)

    print("\n🎉 Dataset preparation complete!")
    print(f"📁 Dataset location: {dataset_dir}")
    print(f"📊 Ready for LoRA training with {processed_count} processed images")
    print("\n💡 Next steps:")
    print(f"   1. Review the generated dataset in: {dataset_dir}")
    print(f"   2. Run LoRA training: python src/generators/compi_phase1e_lora_training.py --dataset-dir {dataset_dir}")

def main():
    """Command-line entry point.

    Parses the arguments, runs ``prepare_dataset`` and converts any exception
    it raises into a printed error message.

    Returns:
        Process exit status: ``0`` on success, ``1`` if preparation failed.
    """
    args = setup_args()

    try:
        prepare_dataset(args)
    except Exception as e:
        # Top-level boundary: report the problem instead of a traceback.
        print(f"❌ Error: {e}")
        return 1

    return 0

if __name__ == "__main__":
    # The built-in ``exit`` is installed by the ``site`` module for
    # interactive use and is missing when site is disabled (``python -S``);
    # raising SystemExit is always available and sets the same exit status.
    raise SystemExit(main())