File size: 8,950 Bytes
3f9fa87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
#!/usr/bin/env python3
"""
Generate images using Illustrious model from augmented prompts.
Supports resuming from interruptions by checking existing files.
"""

import json
import os
import random
import argparse
import hashlib
from pathlib import Path
from tqdm import tqdm
import torch
from diffusers import StableDiffusionXLPipeline
from PIL import Image
import logging

# Configure the root logger once at import time, then create this module's logger
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)

class IllustriousImageGenerator:
    """Batch image generator that feeds JSONL prompts to an SDXL pipeline.

    Every generated image is written to ``<output_dir>/<hash>.png`` together
    with a sidecar ``<output_dir>/metadata/<hash>.json``. The hash is derived
    from the prompt text and the chosen dimensions; a run skips every hash
    whose image and metadata both exist, which makes interrupted runs
    resumable.
    """

    def __init__(self, model_path, output_dir, jsonl_path):
        """Store the configured paths and create the output directories.

        Args:
            model_path: Path to the single-file model checkpoint.
            output_dir: Directory that receives the generated PNG files.
            jsonl_path: JSONL file holding one prompt object per line.
        """
        self.model_path = model_path
        self.output_dir = Path(output_dir)
        self.jsonl_path = jsonl_path
        self.pipe = None

        # Image dimensions to choose from
        self.dimensions = [512, 768, 1024, 1536, 2048]

        # Create output directories
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.metadata_dir = self.output_dir / "metadata"
        self.metadata_dir.mkdir(exist_ok=True)

    def load_model(self):
        """Load the checkpoint as a float16 SDXL pipeline and move it to CUDA.

        Raises:
            Exception: Whatever the pipeline loader or the device transfer
                raises; the error is logged and re-raised unchanged.
        """
        logger.info(f"Loading model from {self.model_path}")
        try:
            self.pipe = StableDiffusionXLPipeline.from_single_file(
                self.model_path,
                torch_dtype=torch.float16,
                use_safetensors=True,
            )
            self.pipe.to("cuda")
            logger.info("Model loaded successfully")
        except Exception as e:
            logger.error(f"Error loading model: {e}")
            raise

    def generate_filename_hash(self, prompt_data, width, height):
        """Return a 12-character hex id for a prompt/dimension combination.

        Missing ``positive_prompt`` / ``negative_prompt`` keys are treated as
        empty strings, matching how ``generate_single_image`` reads them, so a
        prompt without a negative prompt is no longer rejected with KeyError.
        Hashes for prompts that carry both keys are unchanged.

        Args:
            prompt_data: Mapping with optional prompt text keys.
            width: Image width in pixels.
            height: Image height in pixels.

        Returns:
            The first 12 hex digits of the MD5 of the combined content.
        """
        positive_prompt = prompt_data.get('positive_prompt', '')
        negative_prompt = prompt_data.get('negative_prompt', '')
        content = f"{positive_prompt}_{negative_prompt}_{width}_{height}"
        return hashlib.md5(content.encode()).hexdigest()[:12]

    def is_already_generated(self, filename_hash):
        """Return True when both the image and its metadata file exist."""
        image_path = self.output_dir / f"{filename_hash}.png"
        metadata_path = self.metadata_dir / f"{filename_hash}.json"
        return image_path.exists() and metadata_path.exists()

    def load_prompts(self):
        """Load prompts from the JSONL file.

        Blank lines are skipped silently; lines that are not valid JSON are
        skipped with a warning naming the line number.

        Returns:
            List of decoded prompt entries in file order.

        Raises:
            Exception: Any error opening or reading the file is logged and
                re-raised.
        """
        prompts = []
        try:
            with open(self.jsonl_path, 'r', encoding='utf-8') as f:
                for line_num, line in enumerate(f, 1):
                    line = line.strip()
                    if not line:
                        # A blank line (e.g. trailing newline) is not an error
                        continue
                    try:
                        prompts.append(json.loads(line))
                    except json.JSONDecodeError as e:
                        logger.warning(f"Error parsing line {line_num}: {e}")
            logger.info(f"Loaded {len(prompts)} prompts from {self.jsonl_path}")
            return prompts
        except Exception as e:
            logger.error(f"Error loading prompts: {e}")
            raise

    def get_random_dimensions(self, rng=None):
        """Pick a random width and height from ``self.dimensions``.

        Args:
            rng: Optional ``random.Random`` instance to draw from. When
                omitted, the module-level ``random`` generator is used.

        Returns:
            Tuple ``(width, height)``; the two values are drawn independently.
        """
        source = random if rng is None else rng
        width = source.choice(self.dimensions)
        height = source.choice(self.dimensions)
        return width, height

    def save_metadata(self, filename_hash, prompt_data, width, height, generation_params):
        """Write the JSON sidecar describing one generated image.

        The file is written under a temporary name and then renamed into
        place: ``is_already_generated`` treats the sidecar's existence as the
        completion marker, so a half-written file must never carry the final
        name.

        Args:
            filename_hash: Id returned by ``generate_filename_hash``.
            prompt_data: The prompt entry the image was generated from.
            width: Image width in pixels.
            height: Image height in pixels.
            generation_params: Extra sampler settings merged into
                ``generation_parameters``.
        """
        metadata = {
            "filename_hash": filename_hash,
            "original_prompt_data": prompt_data,
            "generation_parameters": {
                "width": width,
                "height": height,
                **generation_params
            },
            "model_info": {
                # str() keeps this serializable when a Path was passed in
                "model_path": str(self.model_path),
                "model_type": "StableDiffusionXL",
                "torch_dtype": "float16"
            }
        }

        metadata_path = self.metadata_dir / f"{filename_hash}.json"
        tmp_path = metadata_path.with_name(metadata_path.name + ".tmp")
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(metadata, f, indent=2, ensure_ascii=False)
        tmp_path.replace(metadata_path)

    def generate_single_image(self, prompt_data, width, height, num_inference_steps=35, guidance_scale=7.5):
        """Run the pipeline once and return the image with its settings.

        Args:
            prompt_data: Mapping with optional ``positive_prompt`` and
                ``negative_prompt`` keys (missing keys become ``''``).
            width: Image width in pixels.
            height: Image height in pixels.
            num_inference_steps: Step count passed to the pipeline.
            guidance_scale: Guidance scale passed to the pipeline.

        Returns:
            Tuple ``(image, params)`` where ``image`` is the first image of
            the pipeline output and ``params`` records the step count and
            guidance scale that were used.

        Raises:
            Exception: Any pipeline error is logged and re-raised.
        """
        try:
            positive_prompt = prompt_data.get('positive_prompt', '')
            negative_prompt = prompt_data.get('negative_prompt', '')

            # Generate image
            image = self.pipe(
                prompt=positive_prompt,
                negative_prompt=negative_prompt,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                width=width,
                height=height
            ).images[0]

            return image, {
                "num_inference_steps": num_inference_steps,
                "guidance_scale": guidance_scale
            }

        except Exception as e:
            logger.error(f"Error generating image: {e}")
            raise

    def generate_images(self, max_images=None, num_inference_steps=35, guidance_scale=7.5):
        """Generate one image per prompt, skipping those already on disk.

        Dimensions are drawn from a private generator seeded with 42, one
        draw per prompt in file order, so a rerun over the same prompt file
        reproduces the same hashes and can skip finished work. The global
        ``random`` state is left untouched.

        Args:
            max_images: Only the first ``max_images`` prompts are processed;
                ``None`` means no limit and ``0`` means no prompts.
            num_inference_steps: Step count passed to the pipeline.
            guidance_scale: Guidance scale passed to the pipeline.
        """
        # Load prompts first so a bad prompt file fails before the model load
        prompts = self.load_prompts()

        if max_images is not None:
            prompts = prompts[:max_images]

        # Load model if not already loaded (and only if there is work to do)
        if prompts and self.pipe is None:
            self.load_model()

        generated_count = 0
        skipped_count = 0
        failed_count = 0

        # Private RNG: same sequence as random.seed(42) without reseeding
        # the process-wide generator
        rng = random.Random(42)

        logger.info(f"Starting generation for {len(prompts)} prompts")

        for i, prompt_data in enumerate(tqdm(prompts, desc="Generating images")):
            # Draw outside the try block so every prompt consumes exactly one
            # pair of draws, keeping later prompts aligned across runs
            width, height = self.get_random_dimensions(rng)

            try:
                # Generate filename hash
                filename_hash = self.generate_filename_hash(prompt_data, width, height)

                # Check if already generated
                if self.is_already_generated(filename_hash):
                    logger.info(f"Skipping {filename_hash} - already exists")
                    skipped_count += 1
                    continue

                # Generate image
                logger.info(f"Generating image {i+1}/{len(prompts)} - {width}x{height}")
                image, generation_params = self.generate_single_image(
                    prompt_data, width, height, num_inference_steps, guidance_scale
                )

                # Save image first; the metadata file written afterwards
                # marks the pair as complete
                image_path = self.output_dir / f"{filename_hash}.png"
                image.save(image_path)

                # Save metadata
                self.save_metadata(filename_hash, prompt_data, width, height, generation_params)

                generated_count += 1
                logger.info(f"Saved image: {image_path}")

            except Exception as e:
                # Best effort: one bad prompt must not abort the whole batch
                failed_count += 1
                logger.error(f"Error processing prompt {i+1}: {e}")
                continue

        logger.info(f"Generation complete! Generated: {generated_count}, Skipped: {skipped_count}")
        if failed_count:
            logger.warning(f"{failed_count} prompts failed; see errors above")

    def cleanup(self):
        """Release the pipeline and empty the CUDA cache.

        Safe to call repeatedly: the attribute is reset to ``None`` rather
        than deleted, so later ``self.pipe is None`` checks keep working.
        """
        if self.pipe is not None:
            self.pipe = None
            torch.cuda.empty_cache()


def main():
    """Parse the command-line options, run the generator, and always clean up."""
    cli = argparse.ArgumentParser(description="Generate images using Illustrious model")

    # (flag, add_argument keyword arguments), registered in display order
    option_specs = (
        ("--model-path", {
            "default": "models/waiNSFWIllustrious_v140.safetensors",
            "help": "Path to the Illustrious model file",
        }),
        ("--jsonl-path", {
            "default": "augmented_prompts.jsonl",
            "help": "Path to the JSONL file containing prompts",
        }),
        ("--output-dir", {
            "default": "illustrious_generated",
            "help": "Output directory for generated images",
        }),
        ("--max-images", {
            "type": int,
            "default": None,
            "help": "Maximum number of images to generate (for testing)",
        }),
        ("--num-inference-steps", {
            "type": int,
            "default": 35,
            "help": "Number of inference steps",
        }),
        ("--guidance-scale", {
            "type": float,
            "default": 7.5,
            "help": "Guidance scale for generation",
        }),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    opts = cli.parse_args()

    # Build the generator from the parsed paths
    generator = IllustriousImageGenerator(
        jsonl_path=opts.jsonl_path,
        output_dir=opts.output_dir,
        model_path=opts.model_path,
    )

    run_settings = {
        "max_images": opts.max_images,
        "num_inference_steps": opts.num_inference_steps,
        "guidance_scale": opts.guidance_scale,
    }
    try:
        generator.generate_images(**run_settings)
    finally:
        # Release the pipeline even when generation fails or is interrupted
        generator.cleanup()


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()