File size: 5,017 Bytes
f3d77fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f2c43cd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
"""
Utility functions for the Z-Image Turbo LoRA Generator
"""

import os
import json
from pathlib import Path
from datetime import datetime


def get_output_dir():
    """
    Return the directory where generated LoRAs are written.

    The directory is ``output_loras``, relative to the current working
    directory. It is created if missing and reused if it already exists.

    Returns:
        Path: The (relative) output directory.
    """
    loras_dir = Path("output_loras")
    loras_dir.mkdir(exist_ok=True)  # no-op when the directory already exists
    return loras_dir


def validate_images(images):
    """
    Validate that images are provided and that every path exists.

    Args:
        images: A single image path (``str`` or path-like), or a list/tuple
            of image paths.

    Returns:
        tuple: ``(is_valid, image_list, error_message)``.
        - On success: ``(True, image_list, None)``. ``image_list`` is the
          original list when a list was given, otherwise a new list.
        - On failure: ``(False, [], message)``.
    """
    # Catches None, "", and empty lists/tuples alike.
    if not images:
        return False, [], "No images provided"

    # Normalize to a list. A tuple is a collection of paths, not a single
    # path; wrapping it would stringify the whole tuple and never match a file.
    if isinstance(images, list):
        image_list = images
    elif isinstance(images, tuple):
        image_list = list(images)
    else:
        image_list = [images]

    # Report the first missing path.
    for img_path in image_list:
        if not os.path.exists(str(img_path)):
            return False, [], f"Image not found: {img_path}"

    return True, image_list, None


def generate_training_metadata(
    project_name,
    trigger_word,
    num_images,
    training_steps,
    batch_size,
    learning_rate,
    resolution,
    rank,
    alpha,
):
    """
    Build the metadata record that describes a LoRA training run.

    Args:
        project_name: Name of the LoRA project.
        trigger_word: Word that activates the LoRA in prompts.
        num_images: Number of training images.
        training_steps: Total number of training steps.
        batch_size: Training batch size.
        learning_rate: Optimizer learning rate.
        resolution: Training image resolution.
        rank: LoRA rank (network dimension).
        alpha: LoRA alpha value.

    Returns:
        dict: Metadata with the keys ``project_name``, ``trigger_word``,
        ``num_images``, ``training_config``, ``model_info`` and
        ``created_at``, in that order. ``created_at`` is a local naive
        ISO-8601 timestamp.
    """
    hyperparams = dict(
        steps=training_steps,
        batch_size=batch_size,
        learning_rate=learning_rate,
        resolution=resolution,
        rank=rank,
        alpha=alpha,
    )
    model_details = dict(
        type="LoRA",
        format="safetensors",
        compatibility="Z-Image Turbo",
    )

    # Insertion order matters: it determines the key order of json.dumps output.
    record = {}
    record["project_name"] = project_name
    record["trigger_word"] = trigger_word
    record["num_images"] = num_images
    record["training_config"] = hyperparams
    record["model_info"] = model_details
    record["created_at"] = datetime.now().isoformat()
    return record


def format_log_message(step, message):
    """
    Prefix a log message with the current wall-clock time and step number.

    Args:
        step: Current training step (any value; rendered with ``str``-style
            formatting).
        message: Log message text.

    Returns:
        str: A string shaped like ``"[HH:MM:SS] Step <step>: <message>"``,
        using local time.
    """
    now = datetime.now()
    # datetime.__format__ delegates to strftime, so this matches "%H:%M:%S".
    return f"[{now:%H:%M:%S}] Step {step}: {message}"


def cleanup_old_outputs(max_age_hours=24, output_dir=None):
    """
    Remove top-level entries of the output directory older than a cutoff.

    Age is measured from each entry's own modification time (``lstat``), so
    symbolic links are judged by the link itself and never followed.
    - Regular files and symlinks (including links to directories) are unlinked.
    - Real directories are removed recursively.
    - Other entry types are left alone.

    A directory's mtime only changes when entries directly inside it are
    added, removed or renamed. Editing a file nested inside a directory
    therefore does not refresh the directory's age.

    Args:
        max_age_hours: Maximum age in hours for entries to keep.
        output_dir: Directory to clean (str or path-like). Defaults to
            ``get_output_dir()``.

    Returns:
        int: Number of top-level entries removed.
    """
    import shutil
    import time

    if output_dir is None:
        output_dir = get_output_dir()
    output_dir = Path(output_dir)

    current_time = time.time()
    max_age_seconds = max_age_hours * 3600
    removed = 0

    for item in output_dir.iterdir():
        try:
            age = current_time - item.lstat().st_mtime
            if age <= max_age_seconds:
                continue
            # Check symlinks first: rmtree() refuses to operate on a symlink,
            # and unlinking removes only the link, never its target.
            if item.is_symlink() or item.is_file():
                item.unlink()
            elif item.is_dir():
                shutil.rmtree(item)
            else:
                continue
        except FileNotFoundError:
            # The entry vanished between listing and removal (e.g. a
            # concurrent cleanup); there is nothing left to delete.
            continue
        removed += 1

    return removed


# Example utility for real implementation (not used in demo)
def create_training_command(
    images_dir,
    output_dir,
    trigger_word,
    rank=16,
    alpha=16,
    learning_rate=1e-4,
    steps=500,
    batch_size=1,
    resolution=512,
    pretrained_model="v1-5-pruned.safetensors",
    output_name="lora",
):
    """
    Build a Kohya ``train_network.py`` LoRA training command (for reference).

    This would be used in a real implementation with actual LoRA training.
    The result is an argv list, suitable for ``subprocess.run`` without a
    shell.

    Args:
        images_dir: Training image directory (``--train_data_dir``).
        output_dir: Where the trained LoRA is written (``--output_dir``).
        trigger_word: Accepted for API symmetry; it is NOT currently passed
            into the command.
        rank: LoRA network dimension (``--network_dim``).
        alpha: LoRA alpha (``--network_alpha``).
        learning_rate: Learning rate (``--learning_rate``).
        steps: Maximum training steps (``--max_train_steps``).
        batch_size: Training batch size (``--train_batch_size``).
        resolution: Square training resolution; emitted as ``"R,R"``.
        pretrained_model: Base checkpoint (``--pretrained_model``). The
            default preserves the previous hard-coded SD 1.5 checkpoint.
        output_name: Base name of the saved LoRA (``--output_name``).

    Returns:
        list[str]: The command as an argument list.
    """
    return [
        "python", "train_network.py",
        "--pretrained_model", str(pretrained_model),
        "--train_data_dir", str(images_dir),
        "--output_dir", str(output_dir),
        "--output_name", str(output_name),
        "--network_module", "networks.lora",
        "--network_dim", str(rank),
        "--network_alpha", str(alpha),
        "--train_batch_size", str(batch_size),
        "--learning_rate", str(learning_rate),
        "--max_train_steps", str(steps),
        "--resolution", f"{resolution},{resolution}",
        "--clip_skip", "2",
        "--enable_bucket",
        "--caption_column", "text",
        "--shuffle_caption",
        "--weighted_captions",
    ]


if __name__ == "__main__":
    # Smoke-test the helpers when this module is executed directly.
    print("Testing utilities...")
    out_dir = get_output_dir()
    print(f"Output directory: {out_dir}")

    sample_run = dict(
        project_name="test_lora",
        trigger_word="test_style",
        num_images=10,
        training_steps=500,
        batch_size=1,
        learning_rate=1e-4,
        resolution=512,
        rank=16,
        alpha=16,
    )
    metadata = generate_training_metadata(**sample_run)
    rendered = json.dumps(metadata, indent=2)
    print(f"Metadata: {rendered}")

    print("Utilities test complete!")