"""
Create Test Dataset

Generate synthetic sample images, each with a matching caption file, for
testing the training pipeline.
"""

import random
from pathlib import Path

from PIL import Image, ImageDraw


def create_test_dataset(output_dir: str = "./dataset", num_images: int = 10):
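    """
    Create a test dataset of synthetic images with matching caption files.

    Each image is a 512x512 JPEG with a vertical color gradient and 3 to 8
    random circles, rectangles, or triangles. Its caption is saved next to it
    as a UTF-8 .txt file with the same name.

    Args:
        output_dir: Output directory (created if it does not exist)
        num_images: Number of test images to create
    """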
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)
    
    print(f"Creating {num_images} test images in {output_path}...")
    
    # Base colors for the background gradient
    colors = [
        (255, 99, 71),    # Tomato
        (64, 224, 208),   # Turquoise
        (255, 215, 0),    # Gold
        (138, 43, 226),   # Blue Violet
        (50, 205, 50),    # Lime Green
        (255, 165, 0),    # Orange
        (219, 112, 147),  # Pale Violet Red
        (70, 130, 180),   # Steel Blue
        (255, 192, 203),  # Pink
        (144, 238, 144),  # Light Green
    ]
    
    shapes = ['circle', 'rectangle', 'triangle']
    
    captions = []
    
    for i in range(num_images):
        # Generate random image
        img_size = 512
        img = Image.new('RGB', (img_size, img_size), color=(240, 240, 240))
        draw = ImageDraw.Draw(img)
        
        # Random parameters
        num_shapes = random.randint(3, 8)
        bg_color = random.choice(colors)
        
        # Draw a vertical background gradient: each row scales the base color
        # from 80% (top row) up to just under 100% (bottom row)
        for y in range(img_size):
            r = int(bg_color[0] * (0.8 + 0.2 * y / img_size))
            g = int(bg_color[1] * (0.8 + 0.2 * y / img_size))
            b = int(bg_color[2] * (0.8 + 0.2 * y / img_size))
            draw.line([(0, y), (img_size, y)], fill=(r, g, b))
        
        # Draw random shapes
        for _ in range(num_shapes):
            shape = random.choice(shapes)
            color = tuple(random.randint(50, 255) for _ in range(3))
            
            x1 = random.randint(50, img_size - 50)
            y1 = random.randint(50, img_size - 50)
            size = random.randint(30, 100)
            
            if shape == 'circle':
                bbox = [x1, y1, x1 + size, y1 + size]
                draw.ellipse(bbox, fill=color, outline=(0, 0, 0))
            elif shape == 'rectangle':
                bbox = [x1, y1, x1 + size, y1 + size // 2]
                draw.rectangle(bbox, fill=color, outline=(0, 0, 0))
            elif shape == 'triangle':
                points = [
                    (x1, y1),
                    (x1 + size, y1),
                    (x1 + size // 2, y1 + size)
                ]
                draw.polygon(points, fill=color, outline=(0, 0, 0))
        
        # Save image
        img_path = output_path / f"test_image_{i+1:03d}.jpg"
        img.save(img_path, quality=95)
        
        # Create caption (the background is described by its full RGB base color)
        caption = f"A colorful abstract composition with {num_shapes} geometric shapes on a gradient background based on RGB {bg_color}"
        captions.append(caption)
        
        # Save caption
        caption_path = output_path / f"test_image_{i+1:03d}.txt"
        with open(caption_path, 'w', encoding='utf-8') as f:
            f.write(caption)
        
        print(f"  Created: {img_path.name}")
    
    print("\n✓ Test dataset created successfully!")
    print(f"  Location: {output_path.absolute()}")
    print(f"  Images: {num_images}")
    print("\nTo train with this dataset:")
    print(f"  python train.py --config config.yaml --train_data {output_path}")
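
# Sketch (hypothetical helper, not called by create_test_dataset): the
# background loop above scales every channel of the base color by
# 0.8 + 0.2 * y / img_size. Row 0 is therefore 80% of the base color, and the
# last row approaches 100% without reaching it. Isolating the formula makes
# that range easy to check without rendering an image.
def gradient_row_color(base_color: tuple, y: int, height: int) -> tuple:
    """Hypothetical helper: return the RGB fill the gradient loop uses for row y."""
    scale = 0.8 + 0.2 * y / height
    # int() truncates toward zero, matching the script's per-channel conversion
    return tuple(int(channel * scale) for channel in base_color)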

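# Sketch (hypothetical helper, not called by create_test_dataset): the shape
# loop above picks a top-left corner in [50, img_size - 50] and a size in
# [30, 100]. A shape can therefore extend up to 50 px past the right or bottom
# edge of the canvas, where Pillow clips it. If fully visible shapes are
# wanted, one option is to bound the corner by the shape's own size, as below.
import random


def fitted_corner(rng: random.Random, img_size: int, size: int, margin: int = 50) -> tuple:
    """Hypothetical helper: pick (x1, y1) so a size x size box stays inside the margin.

    When the canvas is too small for the requested size, the corner is pinned
    to the margin instead (the box may then still overflow).
    """
    # Largest corner coordinate with corner + size <= img_size - margin; never
    # below the lower bound, which would make randint raise ValueError.
    upper = max(margin, img_size - margin - size)
    return rng.randint(margin, upper), rng.randint(margin, upper)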

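# Sketch (hypothetical helper): each image is written next to a UTF-8 .txt
# caption with the same stem (test_image_001.jpg / test_image_001.txt). This
# assumes the training pipeline reads captions from such sidecar files. How
# train.py actually loads --train_data is not shown in this script.
from pathlib import Path


def load_caption_pairs(dataset_dir: str) -> list:
    """Hypothetical helper: return sorted (image_path, caption) pairs.

    Images without a matching .txt caption are skipped.
    """
    pairs = []
    for img_path in sorted(Path(dataset_dir).glob("*.jpg")):
        caption_path = img_path.with_suffix(".txt")
        if caption_path.is_file():
            pairs.append((img_path, caption_path.read_text(encoding="utf-8")))
    return pairs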
if __name__ == "__main__":
    create_test_dataset()