File size: 19,630 Bytes
2a5411f
 
 
 
bf7d351
 
 
2a5411f
 
bf7d351
45852d0
bf7d351
2a5411f
bf7d351
2a5411f
 
bf7d351
 
2a5411f
 
bf7d351
 
 
 
2a5411f
bf7d351
 
 
 
 
 
 
45852d0
 
bf7d351
2a5411f
bf7d351
2a5411f
bf7d351
12eb42e
bf7d351
12eb42e
 
bf7d351
12eb42e
bf7d351
 
 
 
45852d0
 
bf7d351
45852d0
bf7d351
2a5411f
bf7d351
45852d0
bf7d351
 
 
45852d0
bf7d351
45852d0
bf7d351
45852d0
bf7d351
45852d0
bf7d351
45852d0
bf7d351
45852d0
bf7d351
45852d0
bf7d351
 
45852d0
bf7d351
45852d0
 
bf7d351
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45852d0
bf7d351
45852d0
bf7d351
 
 
2a5411f
bf7d351
2a5411f
bf7d351
2a5411f
bf7d351
2a5411f
bf7d351
45852d0
bf7d351
 
 
45852d0
 
bf7d351
45852d0
bf7d351
 
45852d0
 
bf7d351
45852d0
 
bf7d351
45852d0
bf7d351
2a5411f
bf7d351
2a5411f
bf7d351
 
2a5411f
45852d0
bf7d351
 
 
45852d0
2a5411f
 
bf7d351
 
 
 
2a5411f
bf7d351
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2a5411f
 
 
bf7d351
45852d0
2a5411f
bf7d351
 
 
 
 
 
 
2a5411f
45852d0
bf7d351
 
45852d0
bf7d351
 
45852d0
 
 
 
bf7d351
 
 
 
 
45852d0
 
bf7d351
45852d0
bf7d351
45852d0
 
bf7d351
 
45852d0
 
bf7d351
45852d0
 
bf7d351
45852d0
 
 
2a5411f
 
 
 
bf7d351
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2a5411f
 
bf7d351
45852d0
 
bf7d351
 
 
 
 
 
 
 
 
2a5411f
 
bf7d351
 
2a5411f
 
 
 
 
 
 
bf7d351
 
2a5411f
bf7d351
 
 
2a5411f
bf7d351
2a5411f
bf7d351
2a5411f
 
bf7d351
 
 
 
2a5411f
 
 
45852d0
2a5411f
 
 
 
 
 
bf7d351
2a5411f
 
 
 
 
 
bf7d351
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2a5411f
 
bf7d351
45852d0
 
 
 
bf7d351
45852d0
bf7d351
45852d0
bf7d351
 
 
2a5411f
bf7d351
 
 
 
 
 
 
2a5411f
 
bf7d351
2a5411f
45852d0
bf7d351
2a5411f
bf7d351
2a5411f
 
bf7d351
2a5411f
 
bf7d351
2a5411f
 
bf7d351
 
 
 
 
 
 
 
 
 
 
45852d0
 
2a5411f
bf7d351
 
 
2a5411f
 
bf7d351
 
12eb42e
 
bf7d351
2a5411f
bf7d351
2a5411f
45852d0
2a5411f
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
# services/generation.py
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from PIL import Image
import numpy as np
import config # Your configuration file (config.py)
from utils.helpers import decode_base64_image, encode_image_base64, encode_video_base64 # Your helper functions
import logging
import gc # Garbage collector
from typing import List, Tuple
from diffusers import StableDiffusionPipeline, DiffusionPipeline, DPMSolverMultistepScheduler, LCMScheduler
# Note: peft is required for load_lora_weights, ensure it's in requirements.txt

logger = logging.getLogger(name=__name__)  # Module-scoped logger for this service

# --- Global Model Cache ---
# Filled once by `load_models()` at application startup and read by the
# generate_* functions below, so heavy models/pipelines are not reloaded per request.
model_cache: dict[str, object] = dict()

# ==============================================================================
# Model Loading Function (Called during Application Startup)
# ==============================================================================

def _load_text_model() -> None:
    """Load the text tokenizer and seq2seq model, caching each as soon as it is ready."""
    logger.info(f"Loading text model: {config.TEXT_MODEL_NAME}")
    model_cache["text_tokenizer"] = AutoTokenizer.from_pretrained(config.TEXT_MODEL_NAME)
    # PyTorch weights (.bin / .safetensors) are loaded directly; no from_tf conversion.
    model_cache["text_model"] = AutoModelForSeq2SeqLM.from_pretrained(
        config.TEXT_MODEL_NAME
    ).to(config.DEVICE)
    logger.info(f"Text model '{config.TEXT_MODEL_NAME}' loaded successfully (using PyTorch weights) onto {config.DEVICE}.")


def _apply_optional_lcm_lora(image_pipeline) -> None:
    """
    Load the optional LCM LoRA named by `config.IMAGE_LCM_LORA_NAME` into `image_pipeline`.

    On success the scheduler is switched to LCMScheduler. Any loading failure is
    logged and swallowed, so the pipeline keeps its default scheduler (LCM is a
    best-effort speedup, not a requirement).
    """
    lcm_lora_name = getattr(config, 'IMAGE_LCM_LORA_NAME', None)
    if not lcm_lora_name:
        logger.info("No IMAGE_LCM_LORA_NAME configured. Using default image scheduler.")
        return

    logger.info(f"Attempting to load LCM LoRA: {lcm_lora_name} (Requires 'peft' library)")
    try:
        image_pipeline.load_lora_weights(lcm_lora_name)  # requires 'peft'
        logger.info(f"LCM LoRA '{lcm_lora_name}' loaded successfully.")
        # Switch to the LCM scheduler only once the LoRA is actually in place.
        image_pipeline.scheduler = LCMScheduler.from_config(image_pipeline.scheduler.config)
        logger.info("Switched image pipeline scheduler to LCMScheduler for optimized LCM inference.")
    except ImportError as peft_import_error:
        logger.error(f"Failed to load LCM LoRA '{lcm_lora_name}': 'peft' library not installed. Please add 'peft' to requirements.txt. Falling back to default scheduler. Error: {peft_import_error}")
    except Exception as e:
        # Network issues, invalid LoRA repo, incompatible weights, ...
        logger.warning(f"Could not load or apply LCM LoRA '{lcm_lora_name}'. Using default scheduler. Error: {e}", exc_info=True)


def _load_image_pipeline():
    """Build the Stable Diffusion pipeline (DPM++ scheduler, optional LCM LoRA) on `config.DEVICE`."""
    logger.info(f"Loading base image generation model: {config.IMAGE_MODEL_NAME}")
    image_pipeline = StableDiffusionPipeline.from_pretrained(
        config.IMAGE_MODEL_NAME,
        torch_dtype=config.DTYPE,
    )
    # Fast default scheduler; may be replaced by LCMScheduler below.
    image_pipeline.scheduler = DPMSolverMultistepScheduler.from_config(image_pipeline.scheduler.config)
    logger.info(f"Image base pipeline '{config.IMAGE_MODEL_NAME}' loaded. Default scheduler: DPMSolverMultistepScheduler.")

    _apply_optional_lcm_lora(image_pipeline)

    # NOTE: on CUDA, xformers / attention slicing could be enabled here to save memory.
    image_pipeline = image_pipeline.to(config.DEVICE)
    logger.info(f"Image pipeline finalized and moved to device: {config.DEVICE}")
    return image_pipeline


def _place_video_pipeline(video_pipeline):
    """
    Put the video pipeline where inference will run and return it.

    Model CPU offload in diffusers pages sub-models onto an accelerator at call
    time. On a CPU-only deployment it fails, either when enabled or at the first
    inference. So it is only attempted when `config.DEVICE` is "cuda".
    Otherwise, or if offload is unavailable, the whole pipeline is moved to
    `config.DEVICE`.
    """
    if config.DEVICE == "cuda":
        try:
            video_pipeline.enable_model_cpu_offload()
            logger.info("Enabled model CPU offload for video pipeline (good for memory saving).")
            return video_pipeline
        except (AttributeError, ImportError) as offload_err:
            # AttributeError: pipeline class lacks the method; ImportError: 'accelerate' missing.
            logger.warning(f"Could not enable model CPU offload for video pipeline {type(video_pipeline).__name__} ({offload_err}). Attempting to move entire model to device {config.DEVICE}.")
    else:
        logger.info(f"Model CPU offload requires a CUDA device; skipping it on device '{config.DEVICE}'.")

    try:
        video_pipeline = video_pipeline.to(config.DEVICE)
        logger.info(f"Video pipeline moved entirely to device: {config.DEVICE}")
    except Exception as move_err:
        logger.error(f"Failed to move video pipeline to device {config.DEVICE}: {move_err}", exc_info=True)
        raise
    return video_pipeline


def _load_video_pipeline():
    """Build the video diffusion pipeline with a DPM++ scheduler and place it for inference."""
    logger.info(f"Loading video generation model: {config.VIDEO_MODEL_NAME}")
    video_pipeline = DiffusionPipeline.from_pretrained(
        config.VIDEO_MODEL_NAME,  # full repo id including user/org
        torch_dtype=config.DTYPE,
        # Request the fp16 weight variant only when running in float16.
        variant="fp16" if config.DTYPE == torch.float16 else None,
    )
    video_pipeline.scheduler = DPMSolverMultistepScheduler.from_config(video_pipeline.scheduler.config)
    logger.info(f"Video pipeline '{config.VIDEO_MODEL_NAME}' loaded. Scheduler: DPMSolverMultistepScheduler.")
    return _place_video_pipeline(video_pipeline)


def load_models():
    """
    Load every configured model into the global `model_cache`.

    Intended to be called once during application startup. In order, it caches:
    "text_tokenizer" and "text_model", then "image_pipeline", then "video_pipeline".
    Entries loaded before a failure remain cached.

    Raises:
        Exception: whatever the underlying loader raised, after logging it as fatal,
            so the startup lifespan can refuse to start in a half-loaded state.
    """
    logger.info("Initiating model loading sequence...")
    try:
        _load_text_model()

        model_cache["image_pipeline"] = _load_image_pipeline()
        logger.info("Image generation model setup complete and cached.")

        model_cache["video_pipeline"] = _load_video_pipeline()
        logger.info("Video generation model setup complete and cached.")

        logger.info("All configured models loaded successfully.")
    except Exception as e:
        logger.error(f"FATAL: Error occurred during model loading sequence: {e}", exc_info=True)
        raise


# ==============================================================================
# Synchronous Generation Functions (Run in Thread Pool)
# ==============================================================================

def generate_ideas_sync(prompt: str, max_length: int, num_ideas: int) -> List[str]:
    """
    Sample `num_ideas` text completions for `prompt` from the cached seq2seq model.

    This call blocks. Dispatch it to a worker thread so the async event loop stays free.

    Args:
        prompt: Text fed to the tokenizer (truncated to 512 input tokens).
        max_length: Forwarded to `model.generate` as `max_length`.
        num_ideas: Forwarded to `model.generate` as `num_return_sequences`.

    Returns:
        The generated sequences decoded with special tokens removed.

    Raises:
        RuntimeError: If the tokenizer or model is missing from `model_cache`.
    """
    tokenizer = model_cache.get("text_tokenizer")
    model = model_cache.get("text_model")
    if not (tokenizer and model):
        logger.error("Execution failure: Text model or tokenizer not found in cache during idea generation.")
        raise RuntimeError("Text model not loaded or available.")

    logger.debug(f"Generating {num_ideas} ideas for prompt: '{prompt}' with max_length={max_length}")

    # Sampling configuration: diverse but reasonably focused, no repeated bigrams.
    sampling_options = dict(
        max_length=max_length,
        num_return_sequences=num_ideas,
        do_sample=True,
        temperature=0.7,
        top_k=50,
        top_p=0.95,
        no_repeat_ngram_size=2,
    )

    # Held outside the try so the finally clause can release them.
    encoded = None
    generated = None
    ideas = []

    try:
        encoded = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(config.DEVICE)
        with torch.no_grad():
            generated = model.generate(**encoded, **sampling_options)
        ideas = [tokenizer.decode(sequence, skip_special_tokens=True) for sequence in generated]
        logger.debug(f"Successfully generated {len(ideas)} raw idea(s).")
    finally:
        # Drop tensor references, then let CUDA and the GC reclaim memory.
        del encoded, generated
        if config.DEVICE == "cuda":
            torch.cuda.empty_cache()
        gc.collect()
        logger.debug("Cleaned up resources after idea generation.")

    return ideas


def generate_image_sync(prompt: str, negative_prompt: str | None, height: int, width: int, num_inference_steps: int, guidance_scale: float) -> str:
    """
    Render one image for `prompt` with the cached diffusion pipeline and return it as base64 PNG.

    This call blocks. Dispatch it to a worker thread.

    Args:
        prompt: Description of the desired image.
        negative_prompt: Concepts to steer away from (may be None).
        height: Output height in pixels.
        width: Output width in pixels.
        num_inference_steps: Denoising steps. With an LCM scheduler, small values (about 4-8) are expected.
        guidance_scale: Classifier-free guidance strength. With LCM, low values (about 0-1.5) are expected.

    Returns:
        Base64 text of the PNG-encoded first generated image.

    Raises:
        RuntimeError: If no image pipeline is present in `model_cache`.
    """
    pipeline = model_cache.get("image_pipeline")
    if not pipeline:
        logger.error("Execution failure: Image pipeline not found in cache during image generation.")
        raise RuntimeError("Image pipeline not loaded or available.")

    lcm_active = isinstance(pipeline.scheduler, LCMScheduler)
    logger.debug(f"Generating image for prompt: '{prompt}' (LCM Active: {lcm_active})")
    logger.debug(f"Params: steps={num_inference_steps}, guidance={guidance_scale}, size={width}x{height}")
    if lcm_active:
        # LCM is tuned for very few steps and little guidance; flag settings far from that.
        if num_inference_steps > 10 or guidance_scale > 2.0:
            logger.warning(f"LCM scheduler is active, but parameters (steps={num_inference_steps}, guidance={guidance_scale}) seem high. Optimal LCM uses low steps (4-8) and low guidance (0-1.5).")

    generated_image = None
    encoded_png = None

    try:
        # A seeded `generator=torch.Generator(...)` could be passed here for reproducible output.
        with torch.no_grad():
            output = pipeline(
                prompt=prompt,
                negative_prompt=negative_prompt,
                height=height,
                width=width,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
            )
        generated_image = output.images[0]
        logger.debug("Image generation inference complete.")

        encoded_png = encode_image_base64(generated_image, format="PNG")
        logger.debug("Image encoded to base64 PNG format.")
    finally:
        # The pipeline stays cached; only per-request objects are released here.
        del generated_image
        if config.DEVICE == "cuda":
            torch.cuda.empty_cache()
        gc.collect()
        logger.debug("Cleaned up resources after image generation.")

    return encoded_png


def _filter_pipeline_kwargs(pipeline, candidate_kwargs: dict) -> dict:
    """
    Keep only the keyword arguments that `pipeline.__call__` accepts.

    Different diffusers video pipelines take different parameters. Passing an
    unsupported keyword raises TypeError, so unsupported ones are dropped and
    logged instead. If the signature cannot be inspected, or the call accepts
    **kwargs, everything is passed through unchanged.
    """
    import inspect  # local import: only needed for this introspection

    try:
        parameters = inspect.signature(pipeline.__call__).parameters
    except (TypeError, ValueError):
        return dict(candidate_kwargs)
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in parameters.values()):
        return dict(candidate_kwargs)

    accepted = {name: value for name, value in candidate_kwargs.items() if name in parameters}
    # Only warn about dropped arguments that actually carried a value.
    ignored = sorted(name for name, value in candidate_kwargs.items()
                     if name not in accepted and value is not None)
    if ignored:
        logger.warning(f"Video pipeline {type(pipeline).__name__} does not accept {ignored}; these arguments are ignored.")
    return accepted


def _first_video_frames(pipeline_output) -> list:
    """
    Return the frames of the first generated video from a pipeline output's `.frames`.

    Supported layouts:
    - a nested list `[[frame, ...], ...]` (one inner list per video);
    - a flat list of frames;
    - an ndarray shaped (batch, frames, H, W, C) or (frames, H, W, C).

    Raises:
        RuntimeError: For any other output structure.
    """
    frames = getattr(pipeline_output, "frames", None)
    if isinstance(frames, np.ndarray) and frames.ndim in (4, 5) and frames.size > 0:
        video = frames[0] if frames.ndim == 5 else frames
        return list(video)
    if isinstance(frames, (list, tuple)) and len(frames) > 0:
        first = frames[0]
        if isinstance(first, (list, tuple)) or (isinstance(first, np.ndarray) and first.ndim == 4):
            return list(first)  # nested: first entry is one whole video
        return list(frames)  # flat: the sequence itself is the frames of one video
    logger.error(f"Unexpected video pipeline output structure: {type(pipeline_output)}")
    raise RuntimeError("Video generation produced unexpected output format.")


def _frame_to_uint8(frame) -> np.ndarray:
    """Convert a frame (PIL image or array) to a NumPy array, scaling float [0, 1] data to uint8."""
    array = np.array(frame)
    if np.issubdtype(array.dtype, np.floating):
        # Float outputs are in [0, 1]. Scale them so every path yields uint8,
        # matching what PIL frames produce.
        array = (np.clip(array, 0.0, 1.0) * 255.0).round().astype(np.uint8)
    return array


def generate_video_sync(
    image_base64: str,
    prompt: str | None, # Optional prompt for video models that use it
    motion_bucket_id: int,
    noise_aug_strength: float,
    num_frames: int,
    fps: int,
    num_inference_steps: int,
    guidance_scale: float
) -> Tuple[str, str]:
    """
    Generate a video clip conditioned on an input image using the cached video pipeline.

    This call blocks. Dispatch it to a worker thread. Only the arguments the
    loaded pipeline's `__call__` actually accepts are forwarded; the rest are
    logged and ignored. For example, Stable Video Diffusion-style pipelines take
    `motion_bucket_id` / `noise_aug_strength` but not `prompt` / `guidance_scale`.

    Args:
        image_base64: Base64-encoded input image (passed to the pipeline as `image`).
        prompt: Optional text prompt, for pipelines that accept one.
        motion_bucket_id: Motion-amount conditioning, for pipelines that accept it.
        noise_aug_strength: Noise added to the conditioning image, for pipelines that accept it.
        num_frames: Number of frames to generate.
        fps: Frame rate used when encoding the output video.
        num_inference_steps: Number of denoising steps.
        guidance_scale: Guidance strength, for pipelines that accept it.

    Returns:
        A tuple containing:
        - video_base64 (str): the encoded video.
        - actual_format (str): the format the encoder actually produced ("MP4" or "GIF").

    Raises:
        RuntimeError: If the video pipeline is not cached or its output layout is unrecognized.
        ValueError: If `image_base64` cannot be decoded.
    """
    pipeline = model_cache.get("video_pipeline")
    if not pipeline:
        logger.error("Execution failure: Video pipeline not found in cache during video generation.")
        raise RuntimeError("Video pipeline not loaded or available.")

    logger.debug("Decoding base64 input image for video generation.")
    try:
        input_image = decode_base64_image(image_base64)
        logger.debug(f"Decoded input image: size={input_image.size}, mode={input_image.mode}")
    except Exception as decode_err:
        logger.error(f"Failed to decode base64 input image: {decode_err}", exc_info=True)
        # A dedicated error type lets the API route map this to a client error.
        raise ValueError("Invalid base64 input image provided.") from decode_err

    logger.debug(f"Generating video: frames={num_frames}, fps={fps}, steps={num_inference_steps}")

    # The image is passed by keyword (not positionally), so pipelines whose first
    # parameter is `prompt` do not receive it twice.
    call_kwargs = _filter_pipeline_kwargs(pipeline, {
        "image": input_image,
        "prompt": prompt,
        "num_inference_steps": num_inference_steps,
        "num_frames": num_frames,
        "height": input_image.height,  # keep the input image's resolution
        "width": input_image.width,
        "guidance_scale": guidance_scale,
        "motion_bucket_id": motion_bucket_id,
        "noise_aug_strength": noise_aug_strength,
    })

    pipeline_output = None
    video_frames = None
    video_frames_np = None
    video_base64 = None
    actual_format = "N/A"

    try:
        # Device placement is handled by CPU offload / .to() configured in load_models.
        with torch.no_grad():
            pipeline_output = pipeline(**call_kwargs)
        video_frames = _first_video_frames(pipeline_output)
        logger.debug(f"Video frame generation complete ({len(video_frames)} frames).")

        video_frames_np = [_frame_to_uint8(frame) for frame in video_frames]
        logger.debug("Converted video frames to NumPy arrays.")

        # encode_video_base64 tries MP4 first and reports the format it actually used.
        video_base64, actual_format = encode_video_base64(video_frames_np, fps=fps, format="MP4")
        logger.debug(f"Video encoded to base64 with actual format: {actual_format}")
    finally:
        # Release per-request objects (all names are bound before the try).
        del input_image, call_kwargs, pipeline_output, video_frames, video_frames_np
        if config.DEVICE == "cuda":
            torch.cuda.empty_cache()
        gc.collect()
        logger.debug("Cleaned up resources after video generation.")

    return video_base64, actual_format