File size: 8,036 Bytes
80b58c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
"""

Dataset Preparation Tool

Prepare and preprocess image-text datasets for training

"""

import argparse
from pathlib import Path
from PIL import Image
import json
import shutil
from typing import List, Tuple


def prepare_dataset(
    input_dir: str,
    output_dir: str,
    image_size: int = 512,
    min_resolution: int = 256,
    filter_low_quality: bool = True,
):
    """
    Prepare dataset for training.

    Recursively collects supported images under ``input_dir``, processes each
    into ``output_dir/images`` with a matching caption in
    ``output_dir/captions``, and writes a ``metadata.json`` summary.

    Args:
        input_dir: Directory with raw images.
        output_dir: Output directory for processed data.
        image_size: Target (square) image size in pixels.
        min_resolution: Minimum acceptable width/height; smaller images are skipped.
        filter_low_quality: Forwarded to ``process_image``.
    """
    input_path = Path(input_dir)
    output_path = Path(output_dir)

    # Create output directories
    output_path.mkdir(parents=True, exist_ok=True)
    (output_path / "images").mkdir(exist_ok=True)
    (output_path / "captions").mkdir(exist_ok=True)

    # Find all images. One recursive glob per extension is sufficient:
    # pathlib's "**" matches zero or more directories, so top-level files are
    # included too. The original additionally ran a non-recursive glob, which
    # listed every top-level image twice; the set comprehension deduplicates,
    # and sorting keeps the processing order stable across runs.
    image_extensions = ['.jpg', '.jpeg', '.png', '.webp']
    image_files = sorted(
        {p for ext in image_extensions for p in input_path.glob(f"**/*{ext}")}
    )

    print(f"Found {len(image_files)} images")

    # Process each image; per-file failures are logged and counted, not fatal.
    # NOTE(review): outputs are keyed by stem only, so "a.jpg" and "sub/a.png"
    # would overwrite each other — confirm inputs have unique stems.
    processed_count = 0
    skipped_count = 0

    for img_file in image_files:
        try:
            process_image(
                img_path=img_file,
                output_img_path=output_path / "images" / f"{img_file.stem}.jpg",
                caption_path=output_path / "captions" / f"{img_file.stem}.txt",
                image_size=image_size,
                min_resolution=min_resolution,
                filter_low_quality=filter_low_quality,
            )
            processed_count += 1

            if processed_count % 10 == 0:
                print(f"Processed: {processed_count}/{len(image_files)}")

        except Exception as e:
            print(f"Error processing {img_file}: {e}")
            skipped_count += 1

    # Save a summary of the run next to the processed data.
    metadata = {
        'total_images': processed_count,
        'skipped_images': skipped_count,
        'image_size': image_size,
        'min_resolution': min_resolution,
    }

    with open(output_path / "metadata.json", 'w') as f:
        json.dump(metadata, f, indent=2)

    print(f"\n✓ Dataset preparation complete!")
    print(f"  Processed: {processed_count} images")
    print(f"  Skipped: {skipped_count} images")
    print(f"  Output: {output_path}")

def process_image(
    img_path: Path,
    output_img_path: Path,
    caption_path: Path,
    image_size: int = 512,
    min_resolution: int = 256,
    filter_low_quality: bool = True,
):
    """
    Process a single image: validate, center-crop to a square, resize, save,
    and write its caption.

    Args:
        img_path: Input image path.
        output_img_path: Output image path (saved as JPEG).
        caption_path: Output caption path.
        image_size: Target square size in pixels.
        min_resolution: Minimum acceptable width/height.
        filter_low_quality: Accepted for interface compatibility; no quality
            filtering beyond the resolution check is implemented here.

    Raises:
        ValueError: If either dimension is below ``min_resolution``.
    """
    # Load image and normalize to RGB (drops alpha / palette modes).
    image = Image.open(img_path).convert('RGB')

    # Check resolution
    width, height = image.size

    if width < min_resolution or height < min_resolution:
        raise ValueError(f"Image too small: {width}x{height}")

    # Downscale clearly-oversized images before cropping. Scale by the SHORT
    # side so it lands at image_size. The original divided by the LONG side,
    # which shrank the short side below image_size and forced a lossy upscale
    # in the final resize.
    if min(width, height) > image_size * 1.5:
        scale = image_size / min(width, height)
        new_width = int(width * scale)
        new_height = int(height * scale)
        image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)

    # Center crop to a square on the short side.
    size = min(image.size)
    left = (image.size[0] - size) // 2
    top = (image.size[1] - size) // 2
    image = image.crop((left, top, left + size, top + size))

    # Resize to the exact target size.
    image = image.resize((image_size, image_size), Image.Resampling.LANCZOS)

    # Save processed image (high-quality JPEG; format inferred from suffix).
    image.save(output_img_path, quality=95, optimize=True)

    # Generate or load the caption alongside the image.
    caption = generate_caption(img_path)

    with open(caption_path, 'w', encoding='utf-8') as f:
        f.write(caption)

def generate_caption(img_path: Path) -> str:
    """
    Return a caption for *img_path*.

    Prefers the contents of a sibling ``.txt`` file sharing the image's stem;
    otherwise derives a caption from the filename itself.

    Args:
        img_path: Path to the image.

    Returns:
        Caption text.
    """
    # A non-empty sidecar .txt file takes priority.
    sidecar = img_path.with_suffix('.txt')
    if sidecar.exists():
        text = sidecar.read_text(encoding='utf-8').strip()
        if text:
            return text

    # Fallback: turn the filename stem into space-separated words,
    # then capitalize the first letter.
    words = img_path.stem.replace('_', ' ').replace('-', ' ')
    return words.capitalize()


def create_training_splits(
    data_dir: str,
    train_ratio: float = 0.9,
    val_ratio: float = 0.05,
    test_ratio: float = 0.05,
):
    """
    Create deterministic train/val/test splits over the processed images.

    Writes ``train.json``, ``validation.json`` and ``test.json`` into
    ``data_dir``, each listing image filenames and a count.

    Args:
        data_dir: Directory with processed data (expects an ``images`` subdir
            of .jpg files).
        train_ratio: Training set ratio.
        val_ratio: Validation set ratio.
        test_ratio: Test set ratio. The test split receives every image not
            assigned to train/val, so the three ratios must sum to 1.0.

    Raises:
        ValueError: If the ratios do not sum to 1.0.
    """
    # The original silently ignored test_ratio; validate instead so an
    # inconsistent call fails loudly rather than producing surprise splits.
    total_ratio = train_ratio + val_ratio + test_ratio
    if abs(total_ratio - 1.0) > 1e-6:
        raise ValueError(f"Split ratios must sum to 1.0, got {total_ratio}")

    data_path = Path(data_dir)

    # Get all images
    images = list((data_path / "images").glob("*.jpg"))

    # Shuffle deterministically so repeated runs produce identical splits.
    import random
    random.seed(42)
    random.shuffle(images)

    # Calculate split sizes; the test split takes the remainder so every
    # image is assigned exactly once.
    total = len(images)
    train_size = int(total * train_ratio)
    val_size = int(total * val_ratio)

    train_images = images[:train_size]
    val_images = images[train_size:train_size + val_size]
    test_images = images[train_size + val_size:]

    def save_split(image_list, split_name):
        # Persist one split as {"images": [...], "count": N}.
        split_data = {
            'images': [str(img.name) for img in image_list],
            'count': len(image_list),
        }

        with open(data_path / f"{split_name}.json", 'w') as f:
            json.dump(split_data, f, indent=2)

        print(f"{split_name}: {len(image_list)} images")

    save_split(train_images, "train")
    save_split(val_images, "validation")
    save_split(test_images, "test")

    print(f"\n✓ Created training splits")
    print(f"  Total: {total} images")


def main():
    """Command-line entry point: parse arguments and run dataset preparation."""
    parser = argparse.ArgumentParser(description="Prepare dataset for Byte Dream training")

    parser.add_argument(
        "--input", "-i",
        type=str,
        required=True,
        help="Input directory with raw images",
    )
    parser.add_argument(
        "--output", "-o",
        type=str,
        default="./processed_data",
        help="Output directory for processed data",
    )
    parser.add_argument(
        "--size", "-s",
        type=int,
        default=512,
        help="Target image size (default: 512)",
    )
    parser.add_argument(
        "--min_res",
        type=int,
        default=256,
        help="Minimum image resolution (default: 256)",
    )
    parser.add_argument(
        "--no_filter",
        action="store_true",
        help="Disable low quality filtering",
    )
    parser.add_argument(
        "--create_splits",
        action="store_true",
        help="Create train/val/test splits",
    )

    args = parser.parse_args()

    # Run the preparation pipeline over the requested directories.
    prepare_dataset(
        input_dir=args.input,
        output_dir=args.output,
        image_size=args.size,
        min_resolution=args.min_res,
        filter_low_quality=not args.no_filter,
    )

    # Optionally build train/val/test splits over the freshly processed output.
    if args.create_splits:
        create_training_splits(args.output)


# Run the CLI only when executed as a script, not on import.
if __name__ == "__main__":
    main()