File size: 12,680 Bytes
cc5958e
 
 
 
 
 
 
 
 
 
 
 
 
6287e96
fb609fe
6287e96
fb609fe
6287e96
 
 
fb609fe
6287e96
 
 
 
 
 
 
cc5958e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fb609fe
6287e96
cc5958e
 
 
 
 
 
 
 
 
 
 
6287e96
fb609fe
 
 
 
 
 
 
 
 
6287e96
fb609fe
6287e96
 
fb609fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
713f69a
fb609fe
 
 
 
 
713f69a
fb609fe
 
 
6287e96
fb609fe
713f69a
 
6287e96
fb609fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6287e96
fb609fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6287e96
 
 
fb609fe
 
6287e96
 
 
713f69a
 
 
 
 
fb609fe
 
 
 
 
6287e96
 
fb609fe
6287e96
fb609fe
6287e96
fb609fe
 
6287e96
fb609fe
6287e96
 
 
 
cc5958e
6287e96
 
 
cc5958e
7537a36
 
 
cc5958e
 
 
 
 
 
 
6287e96
fb609fe
cc5958e
 
 
 
 
 
 
 
 
 
7537a36
 
cc5958e
 
 
fb609fe
7537a36
 
 
 
 
 
fb609fe
6287e96
 
 
fb609fe
6287e96
 
 
 
fb609fe
6287e96
fb609fe
 
 
cc5958e
fb609fe
cc5958e
fb609fe
cc5958e
fb609fe
 
713f69a
 
 
 
 
 
 
 
 
fb609fe
713f69a
fb609fe
6287e96
fb609fe
6287e96
 
 
 
 
 
fb609fe
6287e96
 
 
 
fb609fe
6287e96
 
fb609fe
6287e96
 
fb609fe
6287e96
fb609fe
 
6287e96
 
 
 
 
 
 
 
 
 
fb609fe
 
6287e96
fb609fe
 
6287e96
 
fb609fe
6287e96
fb609fe
6287e96
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
"""Pepe Meme Generator - Core generation logic.

This module contains the main PepeGenerator class which handles:
- Loading and caching Stable Diffusion models
- Managing LoRA and LCM-LoRA adapters
- Configuring schedulers and optimizations
- Generating images from text prompts
- Progress tracking during generation

The generator supports multiple models, automatic GPU/CPU detection,
memory optimizations, and both standard and fast (LCM) inference modes.

"""

from typing import Optional, List, Callable
import torch
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler, LCMScheduler
import streamlit as st
from PIL import Image
import logging
import os

from .config import ModelConfig

logger = logging.getLogger(__name__)


class PepeGenerator:
    """
    Main generator class for creating Pepe meme images.

    This class manages the entire image generation pipeline including:
    - Model loading and caching (with Streamlit cache_resource)
    - LoRA and LCM-LoRA adapter management
    - Scheduler configuration (DPM Solver or LCM)
    - Memory optimizations (attention slicing, VAE slicing, xformers)
    - Device management (automatic CUDA/CPU detection)
    - Progress tracking callbacks

    The generator is designed to work efficiently on both GPU and CPU,
    with automatic optimizations based on available hardware.

    Attributes:
        config: ModelConfig instance with generation settings
        device: Torch device string ('cuda' or 'cpu')
        pipe: Cached StableDiffusionPipeline instance
    """

    def __init__(self, config: Optional[ModelConfig] = None):
        """
        Initialize the Pepe generator with configuration.

        Sets up the generator by determining the compute device (GPU/CPU),
        loading the model pipeline, and caching it for reuse. The model
        loading is cached using Streamlit's cache_resource decorator to avoid
        reloading on every interaction.

        Args:
            config: ModelConfig instance. If None, uses default configuration.
        """
        self.config = config or ModelConfig()
        self.device = self._get_device(self.config.FORCE_CPU)
        # _load_model is a cached staticmethod: all cache-key inputs are passed
        # explicitly so Streamlit reuses the pipeline across reruns.
        self.pipe = self._load_model(
            self.config.BASE_MODEL,
            self.config.USE_LORA,
            self.config.LORA_PATH,
            self.config.FORCE_CPU,
            self.config.USE_LCM,
            self.config.LCM_LORA_PATH
        )
        logger.info(f"PepeGenerator initialized on {self.device}")

    @staticmethod
    @st.cache_resource
    def _load_model(base_model: str, use_lora: bool, lora_path: Optional[str], 
                    force_cpu: bool = False, use_lcm: bool = False, 
                    lcm_lora_path: Optional[str] = None) -> StableDiffusionPipeline:
        """
        Load and cache the Stable Diffusion model with LoRA and LCM support.

        Args:
            base_model: Hugging Face model ID or local path of the base pipeline.
            use_lora: Whether to load the fine-tuned Pepe LoRA adapter.
            lora_path: Local path or HF model ID for the Pepe LoRA weights.
            force_cpu: If True, run on CPU even when CUDA is available (testing).
            use_lcm: Whether to load an LCM-LoRA for few-step fast inference.
            lcm_lora_path: Local path or HF model ID for the LCM-LoRA weights.

        Returns:
            A configured StableDiffusionPipeline, moved to the chosen device,
            with adapters, scheduler, and memory optimizations applied.
        """
        logger.info("="*60)
        logger.info("LOADING NEW MODEL PIPELINE")
        logger.info(f"Base Model: {base_model}")
        logger.info(f"LoRA Enabled: {use_lora}")
        if use_lora and lora_path:
            logger.info(f"LoRA Path: {lora_path}")
        logger.info(f"LCM Enabled: {use_lcm}")
        if use_lcm and lcm_lora_path:
            logger.info(f"LCM-LoRA Path: {lcm_lora_path}")
        logger.info(f"Force CPU: {force_cpu}")
        logger.info("="*60)

        # Determine appropriate dtype based on device
        if force_cpu:
            device = "cpu"
            logger.info("🔧 FORCED CPU MODE - GPU disabled for testing")
        else:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        # device is already "cpu" when force_cpu is set, so dtype follows device alone:
        # half precision on GPU for memory/speed, full precision on CPU.
        torch_dtype = torch.float16 if device == "cuda" else torch.float32
        logger.info(f"Using device: {device}, dtype: {torch_dtype}")

        pipe = StableDiffusionPipeline.from_pretrained(
            base_model,
            torch_dtype=torch_dtype,
            safety_checker=None,  # Disabled for meme generation - users must comply with SD license
        )

        # Load LoRA weights if configured
        if use_lora and lora_path:
            logger.info(f"Loading LoRA weights from: {lora_path}")
            try:
                # Check if it's a local path or Hugging Face model ID
                # Explicitly name it "pepe" to avoid "default_0" naming
                if os.path.exists(lora_path):
                    # Local path
                    pipe.load_lora_weights(lora_path, adapter_name="pepe")
                    logger.info("LoRA weights loaded successfully from local path")
                elif "/" in lora_path:
                    # Hugging Face model ID (format: username/model_name)
                    pipe.load_lora_weights(lora_path, adapter_name="pepe")
                    logger.info(f"✅ LoRA weights loaded successfully from Hugging Face: {lora_path}")
                else:
                    logger.warning(f"Invalid LoRA path format: {lora_path}")

                # If not using LCM, set Pepe LoRA as the active adapter
                if not use_lcm:
                    pipe.set_adapters(["pepe"])
                    logger.info("✅ Pepe LoRA active")
            except Exception as e:
                # Best-effort: a broken LoRA should not prevent base-model generation.
                logger.error(f"Failed to load LoRA weights: {e}")
                logger.info("Continuing without LoRA weights...")

        # Load LCM-LoRA on top if configured (this enables fast inference!)
        if use_lcm and lcm_lora_path:
            logger.info(f"Loading LCM-LoRA from: {lcm_lora_path}")
            try:
                # Load LCM-LoRA as a separate adapter
                pipe.load_lora_weights(lcm_lora_path, adapter_name="lcm")
                logger.info("✅ LCM-LoRA loaded successfully")

                # If we have both Pepe LoRA and LCM-LoRA, fuse them
                if use_lora:
                    logger.info("Fusing Pepe LoRA + LCM-LoRA adapters...")
                    # Use the correct adapter names: "pepe" and "lcm"
                    pipe.set_adapters(["pepe", "lcm"], adapter_weights=[1.0, 1.0])
                    logger.info("✅ Both LoRAs fused successfully (pepe + lcm)")
                else:
                    # Only LCM, set it as active
                    pipe.set_adapters(["lcm"])
                    logger.info("✅ LCM-LoRA active (solo mode)")
            except Exception as e:
                # Fall back to standard inference; scheduler selection below
                # depends on the (possibly downgraded) use_lcm flag.
                logger.error(f"Failed to load LCM-LoRA: {e}")
                logger.info("Continuing without LCM...")
                use_lcm = False

        # Set appropriate scheduler based on LCM mode
        if use_lcm:
            pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
            logger.info("⚡ Using LCM Scheduler (few-step mode)")
        else:
            pipe.scheduler = DPMSolverMultistepScheduler.from_config(
                pipe.scheduler.config
            )
            logger.info("🔧 Using DPM Solver Scheduler (standard mode)")

        # Enable memory optimizations (beneficial on both CPU and GPU)
        pipe.enable_attention_slicing()
        pipe.enable_vae_slicing()

        if device == "cuda" and not force_cpu:
            pipe = pipe.to("cuda")
            try:
                pipe.enable_xformers_memory_efficient_attention()
            except ImportError:
                logger.warning("xformers not available, using default attention")
            except Exception as e:
                logger.warning(f"Could not enable xformers: {e}")
        else:
            if force_cpu:
                logger.info("Running on CPU - FORCED for testing")
            else:
                logger.info("Running on CPU - memory optimizations applied")

        logger.info("Model loaded successfully")
        return pipe

    @staticmethod
    def _get_device(force_cpu: bool = False) -> str:
        """Determine the best available device ('cuda' if available, else 'cpu')."""
        if force_cpu:
            return "cpu"
        return "cuda" if torch.cuda.is_available() else "cpu"

    def generate(
        self,
        prompt: str,
        negative_prompt: Optional[str] = None,
        num_inference_steps: int = 25,
        guidance_scale: float = 7.5,
        width: int = 512,
        height: int = 512,
        seed: Optional[int] = None,
        progress_callback: Optional[Callable[[int, int], None]] = None,
        style: str = "default",
        raw_prompt: bool = False
    ) -> Image.Image:
        """
        Generate a Pepe meme image from a text prompt.

        This method runs the diffusion process to generate an image based on
        the provided text prompt. It supports various parameters to control
        the generation quality, style, and randomness.

        Args:
            prompt: Text description of the desired image. For best results with
                the fine-tuned model, include the trigger word 'pepe_style_frog'.
            negative_prompt: Text describing what to avoid in the image.
                If None, uses default from config.
            num_inference_steps: Number of denoising steps (4-8 for LCM, 20-50 normal).
            guidance_scale: CFG scale (1.0-2.0 for LCM, 5.0-15.0 normal).
            width: Output image width in pixels (must be divisible by 8).
            height: Output image height in pixels (must be divisible by 8).
            seed: Random seed for reproducible generation.
            progress_callback: Optional callback(current_step, total_steps).
            style: Style preset to apply ('default', 'happy', 'sad', 'smug', 'angry', 'thinking', 'surprised').
            raw_prompt: If True, use prompt as-is without trigger words or style modifiers.

        Returns:
            PIL Image object (Image.Image) containing the generated image.
        """
        # Handle raw prompt mode - use prompt as-is if requested
        if raw_prompt:
            enhanced_prompt = prompt
        else:
            # Apply style preset if not in raw mode
            enhanced_prompt = self._apply_style_preset(prompt, style) if style != "default" else prompt

        # Set default negative prompt
        if negative_prompt is None:
            negative_prompt = self.config.DEFAULT_NEGATIVE_PROMPT

        # Set seed for reproducibility
        generator = None
        if seed is not None:
            generator = torch.Generator(device=self.device).manual_seed(seed)

        logger.info(f"Generating: {enhanced_prompt[:50]}...")
        logger.debug(f"Full prompt: {enhanced_prompt}")
        logger.debug(f"Model config - Base: {self.config.BASE_MODEL}, LoRA: {self.config.USE_LORA}")

        # Create callback wrapper if provided (using new diffusers API)
        callback_on_step_end_fn = None
        if progress_callback:
            def callback_on_step_end_fn(pipe, step, timestep, callback_kwargs):
                # step is 0-based inside diffusers; report 1-based progress.
                progress_callback(step + 1, num_inference_steps)
                return callback_kwargs

        # Generate image (removed autocast for CPU compatibility)
        output = self.pipe(
            prompt=enhanced_prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            generator=generator,
            width=width,
            height=height,
            callback_on_step_end=callback_on_step_end_fn,
        )

        return output.images[0]

    def generate_batch(
        self,
        prompt: str,
        num_images: int = 4,
        **kwargs
    ) -> List[Image.Image]:
        """
        Generate multiple variations of a prompt.

        If a 'seed' keyword is supplied, each image i uses the deterministic
        seed `seed + i`, so the whole batch is reproducible (previously only
        the first image honored the seed and the rest were random). Without a
        seed, each image gets an independent random seed. The caller's kwargs
        dict is never mutated.

        Args:
            prompt: Text prompt passed to generate() for every image.
            num_images: Number of variations to produce.
            **kwargs: Additional keyword arguments forwarded to generate()
                (optionally including a base 'seed').

        Returns:
            List of generated PIL images, in generation order.
        """
        # Work on a copy so the caller's dict is left untouched.
        kwargs = dict(kwargs)
        base_seed = kwargs.pop('seed', None)

        images = []
        for i in range(num_images):
            if base_seed is not None:
                seed = base_seed + i  # deterministic, distinct per image
            else:
                seed = torch.randint(0, 100000, (1,)).item()
            images.append(self.generate(prompt, seed=seed, **kwargs))

        return images

    def _apply_style_preset(self, prompt: str, style: str) -> str:
        """
        Apply style-specific prompt enhancements.

        Prepends the config's trigger word, appends the modifier phrase for a
        known style (unknown styles are silently skipped), then appends fixed
        quality tags.

        Args:
            prompt: The user's raw prompt text.
            style: Style preset key ('happy', 'sad', 'smug', 'angry',
                'thinking', 'surprised'); other values add no modifier.

        Returns:
            The enhanced prompt string.
        """
        style_modifiers = {
            "happy": "cheerful, smiling, joyful",
            "sad": "melancholic, crying, emotional",
            "smug": "confident, satisfied, smirking",
            "angry": "frustrated, mad, intense",
            "thinking": "contemplative, philosophical",
            "surprised": "shocked, amazed, wide eyes",
        }

        # Use trigger word from config
        trigger_word = self.config.TRIGGER_WORD

        base = f"{trigger_word}, {prompt}"

        if style in style_modifiers:
            base = f"{base}, {style_modifiers[style]}"

        base = f"{base}, high quality, detailed, meme art"

        return base
        return base