# File size: 8,406 Bytes
# Revision: cce2b06
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
"""
SDXL Image Generation Engine

Handles model loading, image generation, and memory optimization.
"""

import torch
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
from PIL import Image
import os
from typing import Optional, Tuple
import warnings

from config import (
    MODEL_ID, 
    REFINER_ID, 
    USE_REFINER,
    DEFAULT_WIDTH,
    DEFAULT_HEIGHT,
    DEFAULT_GUIDANCE_SCALE,
    DEFAULT_NUM_STEPS,
    DEFAULT_REFINER_STEPS
)

# Suppress warnings: with no category/module arguments this installs an
# "ignore" filter that matches every warning issued in the process after this
# point, not only those raised by this module.
# NOTE(review): presumably meant to keep console output clean while models
# load -- confirm that hiding all warnings globally is intended.
warnings.filterwarnings("ignore")


class ImageGenerator:
    """
    SDXL-based image generation with optional refiner

    Pipelines are loaded lazily: nothing is loaded until ``load_models`` is
    called, either explicitly or by the first ``generate`` call.
    """

    def __init__(self, use_refiner: bool = USE_REFINER):
        """
        Initialize the image generator

        Args:
            use_refiner: Whether to use the refiner model for better quality
        """
        # "cuda" when torch reports an available CUDA device, otherwise "cpu"
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Default refiner setting; generate() can override it per call
        self.use_refiner = use_refiner
        # Pipelines stay None until load_models() succeeds
        self.base_pipe = None
        self.refiner_pipe = None
        # True once load_models() has completed without error
        self._initialized = False

        print(f"Using device: {self.device}")

    def _load_pipeline(self, model_id: str) -> DiffusionPipeline:
        """
        Load one diffusers pipeline and move it to ``self.device``

        Half-precision weights (fp16 dtype and "fp16" variant) are requested
        on CUDA; float32 with the default variant is used otherwise.

        Args:
            model_id: Model identifier passed to ``from_pretrained``

        Returns:
            The loaded pipeline, already moved to ``self.device``
        """
        on_cuda = self.device == "cuda"
        pipe = DiffusionPipeline.from_pretrained(
            model_id,
            torch_dtype=torch.float16 if on_cuda else torch.float32,
            use_safetensors=True,
            variant="fp16" if on_cuda else None,
        )
        pipe.to(self.device)
        return pipe

    def load_models(self) -> None:
        """
        Load SDXL base and optional refiner models

        Does nothing if the models are already loaded. The refiner is only
        loaded when ``self.use_refiner`` is true.

        Raises:
            Exception: Any error raised while loading is printed and then
                re-raised unchanged.
        """
        if self._initialized:
            return

        print("Loading SDXL base model...")
        print("This may take a few minutes on first load...")

        try:
            # Load base model onto the target device
            self.base_pipe = self._load_pipeline(MODEL_ID)

            # Use better scheduler, built from the pipeline's existing
            # scheduler configuration
            self.base_pipe.scheduler = DPMSolverMultistepScheduler.from_config(
                self.base_pipe.scheduler.config
            )

            # Enable memory optimizations
            if self.device == "cuda":
                self.base_pipe.enable_attention_slicing()
                # xformers is optional: enabling it is best-effort, but only
                # ordinary errors are tolerated so Ctrl-C still interrupts.
                try:
                    self.base_pipe.enable_xformers_memory_efficient_attention()
                except Exception as exc:
                    print(
                        "xformers attention unavailable "
                        f"({type(exc).__name__}); continuing without it"
                    )

            print("✅ Base model loaded successfully!")

            # Load refiner if requested
            if self.use_refiner:
                print("Loading refiner model...")
                self.refiner_pipe = self._load_pipeline(REFINER_ID)

                if self.device == "cuda":
                    self.refiner_pipe.enable_attention_slicing()

                print("✅ Refiner model loaded successfully!")

            self._initialized = True

        except Exception as e:
            print(f"❌ Error loading models: {e}")
            raise

    def generate(
        self,
        prompt: str,
        negative_prompt: str = "",
        width: int = DEFAULT_WIDTH,
        height: int = DEFAULT_HEIGHT,
        guidance_scale: float = DEFAULT_GUIDANCE_SCALE,
        num_inference_steps: int = DEFAULT_NUM_STEPS,
        seed: int = -1,
        use_refiner: Optional[bool] = None
    ) -> Tuple[Image.Image, dict]:
        """
        Generate an image from a prompt

        Models are loaded on demand if ``load_models`` has not run yet.

        Args:
            prompt: Text prompt for generation
            negative_prompt: Negative prompt to avoid certain features
            width: Image width
            height: Image height
            guidance_scale: How closely to follow the prompt (7-9 recommended)
            num_inference_steps: Number of denoising steps (30-50 recommended)
            seed: Random seed for reproducibility (-1 for random)
            use_refiner: Override default refiner setting

        Returns:
            Tuple of (generated_image, metadata). ``metadata["seed"]`` is the
            seed actually used, and ``metadata["refiner_used"]`` is True only
            when the refiner pipeline really ran.

        Raises:
            Exception: Any error raised during generation is printed and
                then re-raised unchanged.
        """
        # Ensure models are loaded
        if not self._initialized:
            self.load_models()

        # Handle seed: draw one so the result can be reproduced from metadata
        if seed == -1:
            seed = torch.randint(0, 2**32 - 1, (1,)).item()

        generator = torch.Generator(device=self.device).manual_seed(seed)

        # Determine if we should use refiner. It can only run if it was
        # loaded, which load_models() does only when self.use_refiner is set.
        use_refiner_now = use_refiner if use_refiner is not None else self.use_refiner
        refine = bool(use_refiner_now) and self.refiner_pipe is not None
        if use_refiner_now and not refine:
            print("Refiner requested but not loaded; using base model only")

        # Arguments shared by both base-pipeline call paths
        base_kwargs = {
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "width": width,
            "height": height,
            "guidance_scale": guidance_scale,
            "num_inference_steps": num_inference_steps,
            "generator": generator,
        }

        try:
            print(f"Generating image with seed: {seed}")

            if refine:
                # Generate latent with base, refine with refiner
                latent = self.base_pipe(
                    **base_kwargs,
                    output_type="latent"
                ).images[0]

                # Refine the latent
                print("Refining image...")
                image = self.refiner_pipe(
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    image=latent,
                    num_inference_steps=DEFAULT_REFINER_STEPS,
                    generator=generator
                ).images[0]
            else:
                # Generate directly
                image = self.base_pipe(**base_kwargs).images[0]

            # Metadata
            metadata = {
                "prompt": prompt,
                "negative_prompt": negative_prompt,
                "seed": seed,
                "width": width,
                "height": height,
                "guidance_scale": guidance_scale,
                "steps": num_inference_steps,
                # Report what actually happened, not merely what was requested
                "refiner_used": refine
            }

            print("✅ Image generated successfully!")
            return image, metadata

        except Exception as e:
            print(f"❌ Generation error: {e}")
            raise

    def unload_models(self) -> None:
        """
        Unload models to free memory

        Drops the references to both pipelines, runs a garbage collection
        pass, empties the CUDA cache when running on CUDA, and marks the
        generator as uninitialized so a later call reloads the models.
        """
        import gc

        self.base_pipe = None
        self.refiner_pipe = None

        # Collect first so objects kept alive only by reference cycles are
        # destroyed before the CUDA cache is emptied.
        gc.collect()

        if self.device == "cuda":
            torch.cuda.empty_cache()

        self._initialized = False
        print("Models unloaded")


# Test function
def _run_smoke_test():
    """
    Generate one small test image and save it to disk

    Loads the base model without the refiner, generates a 512x512 image
    with 20 steps and a fixed seed of 42, writes it to ``test_output.png``
    in the current working directory, and prints the generation metadata.
    """
    print("=== Image Generator Test ===\n")

    generator = ImageGenerator(use_refiner=False)
    generator.load_models()

    test_prompt = "A beautiful sunset over mountains, highly detailed, photorealistic"
    test_negative = "blurry, low quality, distorted"

    print("\nGenerating test image...")
    print(f"Prompt: {test_prompt}")

    image, metadata = generator.generate(
        prompt=test_prompt,
        negative_prompt=test_negative,
        width=512,  # Smaller for testing
        height=512,
        num_inference_steps=20,  # Fewer steps for testing
        seed=42
    )

    # Save test image
    output_path = "test_output.png"
    image.save(output_path)
    print(f"\n✅ Test image saved to: {output_path}")
    print(f"Metadata: {metadata}")


if __name__ == "__main__":
    _run_smoke_test()