# services/generation.py
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from PIL import Image
import numpy as np
import config  # Your configuration file (config.py)
from utils.helpers import decode_base64_image, encode_image_base64, encode_video_base64  # Your helper functions
import logging
import gc  # Garbage collector
from typing import List, Tuple
from diffusers import StableDiffusionPipeline, DiffusionPipeline, DPMSolverMultistepScheduler, LCMScheduler
# Note: peft is required for load_lora_weights, ensure it's in requirements.txt
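# ------------------------------------------------------------------------------
# Expected attributes on `config` (illustrative sketch only -- not the actual
# config.py). The attribute names below are the ones referenced in this module;
# the example values are assumptions and should be replaced with your own.
#
#   TEXT_MODEL_NAME = "google/flan-t5-base"                      # any seq2seq LM
#   IMAGE_MODEL_NAME = "runwayml/stable-diffusion-v1-5"          # base SD checkpoint
#   IMAGE_LCM_LORA_NAME = "latent-consistency/lcm-lora-sdv1-5"   # optional LCM LoRA
#   VIDEO_MODEL_NAME = "cerspense/zeroscope_v2_576w"             # video diffusion model
#   DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
#   DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
# ------------------------------------------------------------------------------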
logger = logging.getLogger(__name__)  # Get logger instance

# --- Global Model Cache ---
# Using a dictionary to store loaded models and pipelines allows loading them
# only once when the application starts, saving time and resources on subsequent requests.
model_cache = {}


# ==============================================================================
# Model Loading Function (Called during Application Startup)
# ==============================================================================
def load_models():
    """
    Loads all configured machine learning models into the global `model_cache`.

    This function is called once during the application's startup lifespan event.
    If any essential model fails to load, it raises an exception to prevent the
    application from starting in a faulty state.
    """
    logger.info("Initiating model loading sequence...")
    try:  # <<<--- Start of the MAIN try block for all models ---<<<
        # --- 1. Text Generation Model ---
        logger.info(f"Loading text model: {config.TEXT_MODEL_NAME}")
        # Load tokenizer associated with the text model
        model_cache["text_tokenizer"] = AutoTokenizer.from_pretrained(config.TEXT_MODEL_NAME)
        # Load the sequence-to-sequence language model.
        # Assuming PyTorch weights (.bin or .safetensors) are available for the model.
        model_cache["text_model"] = AutoModelForSeq2SeqLM.from_pretrained(
            config.TEXT_MODEL_NAME
            # REMOVED: from_tf=True - attempt to load PyTorch weights directly.
        ).to(config.DEVICE)  # Move model to the configured device (CPU or CUDA)
        logger.info(f"Text model '{config.TEXT_MODEL_NAME}' loaded successfully (using PyTorch weights) onto {config.DEVICE}.")

        # --- 2. Image Generation Model (Base Pipeline) ---
        logger.info(f"Loading base image generation model: {config.IMAGE_MODEL_NAME}")
        # Load the Stable Diffusion pipeline
        image_pipeline = StableDiffusionPipeline.from_pretrained(
            config.IMAGE_MODEL_NAME,
            torch_dtype=config.DTYPE  # Use configured dtype (float16 on CUDA, float32 on CPU)
        )
        # Set a default fast scheduler (can be overridden by LCM later)
        image_pipeline.scheduler = DPMSolverMultistepScheduler.from_config(image_pipeline.scheduler.config)
        logger.info(f"Image base pipeline '{config.IMAGE_MODEL_NAME}' loaded. Default scheduler: DPMSolverMultistepScheduler.")

        # --- 3. Attempt to Load LCM LoRA (Optional Speedup) ---
        # Safely check if an LCM LoRA is configured in config.py or environment variables
        lcm_lora_name = getattr(config, 'IMAGE_LCM_LORA_NAME', None)
        if lcm_lora_name:
            logger.info(f"Attempting to load LCM LoRA: {lcm_lora_name} (requires the 'peft' library)")
            try:
                # Load the LoRA weights into the existing pipeline. Requires 'peft'.
                image_pipeline.load_lora_weights(lcm_lora_name)
                # Optional: Fuse LoRA weights for a potential minor speedup. Test the impact.
                # image_pipeline.fuse_lora()
                logger.info(f"LCM LoRA '{lcm_lora_name}' loaded successfully.")
                # IMPORTANT: Switch to the LCM scheduler *only if* the LoRA loaded successfully
                image_pipeline.scheduler = LCMScheduler.from_config(image_pipeline.scheduler.config)
                logger.info("Switched image pipeline scheduler to LCMScheduler for optimized LCM inference.")
            except ImportError as peft_import_error:
                logger.error(f"Failed to load LCM LoRA '{lcm_lora_name}': 'peft' library not installed. Please add 'peft' to requirements.txt. Falling back to default scheduler. Error: {peft_import_error}")
            except Exception as e:
                # Catch other potential errors during LoRA loading (e.g., network issues, invalid LoRA)
                logger.warning(f"Could not load or apply LCM LoRA '{lcm_lora_name}'. Using default scheduler. Error: {e}", exc_info=True)
        else:
            logger.info("No IMAGE_LCM_LORA_NAME configured. Using default image scheduler.")

        # --- 4. Image Pipeline Device Placement and Final Setup ---
        image_pipeline = image_pipeline.to(config.DEVICE)  # Move the potentially modified pipeline to the device
        logger.info(f"Image pipeline finalized and moved to device: {config.DEVICE}")

        # Optional GPU optimizations (commented out as the current config targets CPU)
        # if config.DEVICE == "cuda":
        #     try:
        #         # Requires: pip install xformers
        #         image_pipeline.enable_xformers_memory_efficient_attention()
        #         logger.info("Enabled xformers memory efficient attention for image pipeline.")
        #     except ImportError:
        #         logger.warning("xformers not installed. Consider installing for potential memory savings on GPU.")
        #         # Fallback option: attention slicing (less memory saving than xformers)
        #         try:
        #             image_pipeline.enable_attention_slicing()
        #             logger.info("Enabled attention slicing for image pipeline.")
        #         except Exception as attn_slice_e:
        #             logger.warning(f"Could not enable attention slicing: {attn_slice_e}")

        # Store the finalized image pipeline in the cache
        model_cache["image_pipeline"] = image_pipeline
        logger.info("Image generation model setup complete and cached.")

        # --- 5. Video Generation Model ---
        logger.info(f"Loading video generation model: {config.VIDEO_MODEL_NAME}")
        # Load the video diffusion pipeline (e.g., Zeroscope)
        video_pipeline = DiffusionPipeline.from_pretrained(
            config.VIDEO_MODEL_NAME,  # Make sure this includes user/org (e.g., "cerspense/zeroscope_v2_576w")
            torch_dtype=config.DTYPE,
            variant="fp16" if config.DTYPE == torch.float16 else None  # Use fp16 variant if on CUDA and available
        )
        # Set a standard scheduler for the video pipeline
        video_pipeline.scheduler = DPMSolverMultistepScheduler.from_config(video_pipeline.scheduler.config)
        logger.info(f"Video pipeline '{config.VIDEO_MODEL_NAME}' loaded. Scheduler: DPMSolverMultistepScheduler.")

        # --- 6. Video Pipeline Memory Optimization (CPU Offload) ---
        # Enable CPU offloading to save VRAM (if on GPU) or manage RAM usage (on CPU).
        # This keeps parts of the model on the CPU until they are needed.
        try:
            video_pipeline.enable_model_cpu_offload()
            logger.info("Enabled model CPU offload for video pipeline (good for memory saving).")
        except AttributeError:
            # Fallback if the specific pipeline class doesn't support this method
            logger.warning(f"Video pipeline class {type(video_pipeline).__name__} may not support enable_model_cpu_offload(). Attempting to move entire model to device {config.DEVICE}.")
            try:
                video_pipeline = video_pipeline.to(config.DEVICE)
                logger.info(f"Video pipeline moved entirely to device: {config.DEVICE}")
            except Exception as move_err:
                logger.error(f"Failed to move video pipeline to device {config.DEVICE}: {move_err}", exc_info=True)
                raise  # Re-raise if moving the whole model also fails

        # Store the video pipeline in the cache
        model_cache["video_pipeline"] = video_pipeline
        logger.info("Video generation model setup complete and cached.")

        # --- Success ---
        logger.info("All configured models loaded successfully.")  # Only logs if all steps above succeed

    except Exception as e:  # <<<--- Catches errors from ANY model loading step ---<<<
        logger.error(f"FATAL: Error occurred during model loading sequence: {e}", exc_info=True)
        # Re-raise the exception. This will be caught by the application's
        # lifespan manager, which should prevent the server from starting properly.
        raise
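# Example wiring for the startup lifespan event mentioned in the docstring above.
# This is an illustrative sketch only: it assumes a FastAPI application, and the
# module and app names shown here are hypothetical rather than part of this project.
#
#   from contextlib import asynccontextmanager
#   from fastapi import FastAPI
#   from services.generation import load_models
#
#   @asynccontextmanager
#   async def lifespan(app: FastAPI):
#       load_models()  # raises on failure, so a broken model setup aborts startup
#       yield          # models stay cached in model_cache for the app's lifetime
#
#   app = FastAPI(lifespan=lifespan)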
# ==============================================================================
# Synchronous Generation Functions (Run in Thread Pool)
# ==============================================================================
def generate_ideas_sync(prompt: str, max_length: int, num_ideas: int) -> List[str]:
    """
    Synchronous function to generate text ideas using the loaded language model.
    Designed to be run in a thread pool to avoid blocking the main async event loop.

    Args:
        prompt: The input prompt or instruction for idea generation.
        max_length: Maximum number of tokens for the generated output.
        num_ideas: The desired number of distinct idea sequences to generate.

    Returns:
        A list of generated idea strings.

    Raises:
        RuntimeError: If the text model or tokenizer is not loaded in the cache.
    """
    tokenizer = model_cache.get("text_tokenizer")
    model = model_cache.get("text_model")
    if not tokenizer or not model:
        logger.error("Execution failure: Text model or tokenizer not found in cache during idea generation.")
        raise RuntimeError("Text model not loaded or available.")

    logger.debug(f"Generating {num_ideas} ideas for prompt: '{prompt}' with max_length={max_length}")
    input_text = prompt  # Use the direct prompt as input

    # Variables to hold intermediate results for cleanup
    inputs = None
    outputs = None
    ideas = []
    try:
        # Prepare model inputs (limit input length to 512 tokens)
        inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=512).to(config.DEVICE)
        # Perform inference without calculating gradients to save memory
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_length=max_length,
                num_return_sequences=num_ideas,
                do_sample=True,          # Enable sampling for diversity
                temperature=0.7,         # Control randomness (lower = more focused)
                top_k=50,                # Consider only the top-k tokens
                top_p=0.95,              # Use nucleus sampling
                no_repeat_ngram_size=2   # Prevent short repetitive phrases
            )
        # Decode the generated token sequences into strings
        ideas = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
        logger.debug(f"Successfully generated {len(ideas)} raw idea(s).")
    finally:
        # --- Resource Cleanup ---
        # Explicitly delete large tensor variables to help GC
        del inputs
        del outputs
        # Clear CUDA cache if running on GPU
        if config.DEVICE == "cuda":
            torch.cuda.empty_cache()
        # Trigger garbage collection
        gc.collect()
        logger.debug("Cleaned up resources after idea generation.")

    return ideas
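# Example of calling this function from async code without blocking the event loop
# (illustrative sketch; the endpoint name and argument values are hypothetical,
# not part of this project):
#
#   import asyncio
#
#   async def ideas_endpoint(prompt: str) -> List[str]:
#       loop = asyncio.get_running_loop()
#       # Positional args after the callable are forwarded to generate_ideas_sync
#       return await loop.run_in_executor(None, generate_ideas_sync, prompt, 128, 3)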
def generate_image_sync(prompt: str, negative_prompt: str | None, height: int, width: int, num_inference_steps: int, guidance_scale: float) -> str:
    """
    Synchronous function to generate an image using the loaded diffusion pipeline.
    Designed to be run in a thread pool.

    Args:
        prompt: The text prompt describing the desired image.
        negative_prompt: Text prompt describing concepts to avoid.
        height: Desired image height in pixels.
        width: Desired image width in pixels.
        num_inference_steps: Number of diffusion steps (more steps = more detail, slower).
            NOTE: If using LCM, this should be very low (e.g., 4-8).
        guidance_scale: How strongly the prompt guides generation (higher = stricter adherence).
            NOTE: If using LCM, this should be low (e.g., 0.0-1.5).

    Returns:
        A base64 encoded string representing the generated PNG image.

    Raises:
        RuntimeError: If the image pipeline is not loaded in the cache.
    """
    pipeline = model_cache.get("image_pipeline")
    if not pipeline:
        logger.error("Execution failure: Image pipeline not found in cache during image generation.")
        raise RuntimeError("Image pipeline not loaded or available.")

    # Log parameters and warn if they look unsuitable for the active scheduler
    lcm_active = isinstance(pipeline.scheduler, LCMScheduler)
    logger.debug(f"Generating image for prompt: '{prompt}' (LCM active: {lcm_active})")
    logger.debug(f"Params: steps={num_inference_steps}, guidance={guidance_scale}, size={width}x{height}")
    if lcm_active and (num_inference_steps > 10 or guidance_scale > 2.0):
        logger.warning(f"LCM scheduler is active, but parameters (steps={num_inference_steps}, guidance={guidance_scale}) seem high. Optimal LCM uses low steps (4-8) and low guidance (0-1.5).")

    image = None
    image_base64 = None
    try:
        # Perform inference without calculating gradients
        with torch.no_grad():
            result = pipeline(
                prompt=prompt,
                negative_prompt=negative_prompt,
                height=height,
                width=width,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                # Optional: Add a generator for reproducibility
                # generator=torch.Generator(device=config.DEVICE).manual_seed(some_seed)
            )
        # Extract the first image from the result
        image = result.images[0]
        logger.debug("Image generation inference complete.")
        # Encode the generated PIL image to a base64 string
        image_base64 = encode_image_base64(image, format="PNG")
        logger.debug("Image encoded to base64 PNG format.")
    finally:
        # --- Resource Cleanup ---
        # The pipeline object itself remains in the cache
        del image  # Delete the generated PIL image object
        # Clear CUDA cache if applicable
        if config.DEVICE == "cuda":
            torch.cuda.empty_cache()
        gc.collect()
        logger.debug("Cleaned up resources after image generation.")

    return image_base64
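# Example invocation (illustrative sketch; the prompts and values are hypothetical).
# The low step count and guidance match the LCM notes in the docstring above; with
# the default DPMSolverMultistepScheduler, values around steps=25 and
# guidance_scale=7.5 are more typical.
#
#   png_b64 = generate_image_sync(
#       prompt="a watercolor fox in a misty forest",
#       negative_prompt="blurry, low quality",
#       height=512,
#       width=512,
#       num_inference_steps=6,   # low step count for LCM
#       guidance_scale=1.0,      # low guidance for LCM
#   )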
def generate_video_sync(
    image_base64: str,
    prompt: str | None,  # Optional prompt for video models that use it
    motion_bucket_id: int,
    noise_aug_strength: float,
    num_frames: int,
    fps: int,
    num_inference_steps: int,
    guidance_scale: float
) -> Tuple[str, str]:
    """
    Synchronous function to generate a video clip from an input image using the
    loaded video pipeline. Designed to be run in a thread pool.

    Args:
        image_base64: Base64 encoded string of the input image.
        prompt: Optional text prompt (model-dependent usage).
        motion_bucket_id: Control parameter for motion amount (model-specific, e.g., Stable Video Diffusion).
        noise_aug_strength: Amount of noise added to the input image (model-specific).
        num_frames: Number of frames desired in the output video.
        fps: Frames per second for the output video encoding.
        num_inference_steps: Number of diffusion steps for video generation.
        guidance_scale: Guidance scale for video generation.

    Returns:
        A tuple containing:
            - video_base64 (str): Base64 encoded string of the generated video (MP4 or GIF).
            - actual_format (str): The actual format the video was encoded in ("MP4" or "GIF").

    Raises:
        RuntimeError: If the video pipeline is not loaded in the cache.
        ValueError: If the input `image_base64` is invalid.
    """
    pipeline = model_cache.get("video_pipeline")
    if not pipeline:
        logger.error("Execution failure: Video pipeline not found in cache during video generation.")
        raise RuntimeError("Video pipeline not loaded or available.")

    logger.debug("Decoding base64 input image for video generation.")
    try:
        # Decode the input image
        input_image = decode_base64_image(image_base64)
        logger.debug(f"Decoded input image: size={input_image.size}, mode={input_image.mode}")
    except Exception as decode_err:
        logger.error(f"Failed to decode base64 input image: {decode_err}", exc_info=True)
        # Raise a specific error that can be caught by the API route
        raise ValueError("Invalid base64 input image provided.") from decode_err

    logger.debug(f"Generating video: frames={num_frames}, fps={fps}, steps={num_inference_steps}")

    # Variables for cleanup
    video_frames_pil = None
    video_frames_np = None
    video_base64 = None
    actual_format = "N/A"
    try:
        # Perform inference without gradients
        with torch.no_grad():
            # CPU offload handles device placement if enabled during load_models
            pipeline_output = pipeline(
                input_image,
                prompt=prompt,
                num_inference_steps=num_inference_steps,
                num_frames=num_frames,
                height=input_image.height,  # Use input image dimensions
                width=input_image.width,
                guidance_scale=guidance_scale,
                # Model-specific motion controls (e.g., Stable Video Diffusion)
                motion_bucket_id=motion_bucket_id,
                noise_aug_strength=noise_aug_strength
            )
        # Output format can vary; often `.frames` is a list containing one list of PIL images
        if hasattr(pipeline_output, 'frames') and isinstance(pipeline_output.frames, list) and len(pipeline_output.frames) > 0:
            video_frames_pil = pipeline_output.frames[0]  # Assuming the structure [[frame1, frame2, ...]]
        else:
            # Handle potential variations in output structure if needed
            logger.error(f"Unexpected video pipeline output structure: {type(pipeline_output)}")
            raise RuntimeError("Video generation produced unexpected output format.")

        logger.debug(f"Video frame generation complete ({len(video_frames_pil)} frames).")

        # Convert PIL frames to NumPy arrays for video encoding
        video_frames_np = [np.array(frame) for frame in video_frames_pil]
        logger.debug("Converted video frames to NumPy arrays.")

        # Encode the NumPy frames into a base64 video string (tries MP4, falls back to GIF)
        video_base64, actual_format = encode_video_base64(video_frames_np, fps=fps, format="MP4")
        logger.debug(f"Video encoded to base64 with actual format: {actual_format}")
    finally:
        # --- Resource Cleanup ---
        del input_image  # Delete decoded input image
        if 'video_frames_pil' in locals():
            del video_frames_pil  # Delete list of PIL frames if it exists
        if 'video_frames_np' in locals():
            del video_frames_np  # Delete list of NumPy frames if it exists
        # Clear CUDA cache if applicable
        if config.DEVICE == "cuda":
            torch.cuda.empty_cache()
        gc.collect()
        logger.debug("Cleaned up resources after video generation.")

    return video_base64, actual_format
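# Example of persisting the result to disk (illustrative sketch; the input image,
# file name, and parameter values below are hypothetical -- motion_bucket_id and
# noise_aug_strength only apply to pipelines that accept them, such as
# Stable Video Diffusion):
#
#   import base64
#
#   video_b64, fmt = generate_video_sync(
#       image_base64=some_input_image_b64,
#       prompt=None,
#       motion_bucket_id=127,
#       noise_aug_strength=0.02,
#       num_frames=24,
#       fps=8,
#       num_inference_steps=25,
#       guidance_scale=7.5,
#   )
#   with open(f"clip.{fmt.lower()}", "wb") as fh:
#       fh.write(base64.b64decode(video_b64))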