# services/generation.py
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from PIL import Image
import numpy as np
import config  # Your configuration file (config.py)
from utils.helpers import decode_base64_image, encode_image_base64, encode_video_base64  # Your helper functions
import logging
import gc  # Garbage collector
from typing import List, Tuple
from diffusers import StableDiffusionPipeline, DiffusionPipeline, DPMSolverMultistepScheduler, LCMScheduler

# Note: 'peft' is required for load_lora_weights(); ensure it is listed in requirements.txt.
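
# The attributes read from `config` below are assumptions about your config.py;
# a minimal sketch with illustrative values (adjust model names to your deployment):
#
#   import torch
#   DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
#   DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
#   TEXT_MODEL_NAME = "google/flan-t5-base"                      # any seq2seq LM
#   IMAGE_MODEL_NAME = "runwayml/stable-diffusion-v1-5"
#   IMAGE_LCM_LORA_NAME = "latent-consistency/lcm-lora-sdv1-5"   # optional
#   VIDEO_MODEL_NAME = "cerspense/zeroscope_v2_576w"
#
# utils.helpers is likewise assumed to provide:
#   decode_base64_image(s: str) -> PIL.Image.Image
#   encode_image_base64(img, format="PNG") -> str
#   encode_video_base64(frames, fps, format="MP4") -> tuple[str, str]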

logger = logging.getLogger(__name__)  # Get logger instance

# --- Global Model Cache ---
# Using a dictionary to store loaded models and pipelines allows loading them
# only once when the application starts, saving time and resources on subsequent requests.
model_cache = {}

# ==============================================================================
# Model Loading Function (Called during Application Startup)
# ==============================================================================
def load_models():
    """
    Load all configured machine learning models into the global `model_cache`.

    This function is called once during the application's startup lifespan event.
    If any essential model fails to load, it raises an exception to prevent
    the application from starting in a faulty state.
    """
    logger.info("Initiating model loading sequence...")
    try:  # <<<--- Start of the MAIN try block for all models ---<<<
        # --- 1. Text Generation Model ---
        logger.info(f"Loading text model: {config.TEXT_MODEL_NAME}")
        # Load the tokenizer associated with the text model
        model_cache["text_tokenizer"] = AutoTokenizer.from_pretrained(config.TEXT_MODEL_NAME)
        # Load the sequence-to-sequence language model.
        # Assumes PyTorch weights (.bin or .safetensors) are available for the model
        # (from_tf=True was removed so PyTorch weights load directly).
        model_cache["text_model"] = AutoModelForSeq2SeqLM.from_pretrained(
            config.TEXT_MODEL_NAME
        ).to(config.DEVICE)  # Move model to the configured device (CPU or CUDA)
        logger.info(f"Text model '{config.TEXT_MODEL_NAME}' loaded successfully (PyTorch weights) onto {config.DEVICE}.")

        # --- 2. Image Generation Model (Base Pipeline) ---
        logger.info(f"Loading base image generation model: {config.IMAGE_MODEL_NAME}")
        # Load the Stable Diffusion pipeline
        image_pipeline = StableDiffusionPipeline.from_pretrained(
            config.IMAGE_MODEL_NAME,
            torch_dtype=config.DTYPE  # Configured dtype (float16 on CUDA, float32 on CPU)
        )
        # Set a fast default scheduler (may be overridden by LCM below)
        image_pipeline.scheduler = DPMSolverMultistepScheduler.from_config(image_pipeline.scheduler.config)
        logger.info(f"Image base pipeline '{config.IMAGE_MODEL_NAME}' loaded. Default scheduler: DPMSolverMultistepScheduler.")

        # --- 3. Attempt to Load LCM LoRA (Optional Speedup) ---
        # Safely check whether an LCM LoRA is configured in config.py or environment variables
        lcm_lora_name = getattr(config, 'IMAGE_LCM_LORA_NAME', None)
        if lcm_lora_name:
            logger.info(f"Attempting to load LCM LoRA: {lcm_lora_name} (requires the 'peft' library)")
            try:
                # Load the LoRA weights into the existing pipeline. Requires 'peft'.
                image_pipeline.load_lora_weights(lcm_lora_name)
                # Optional: fuse LoRA weights for a potential minor speedup. Measure the impact before enabling.
                # image_pipeline.fuse_lora()
                logger.info(f"LCM LoRA '{lcm_lora_name}' loaded successfully.")
                # IMPORTANT: switch to the LCM scheduler *only if* the LoRA loaded successfully
                image_pipeline.scheduler = LCMScheduler.from_config(image_pipeline.scheduler.config)
                logger.info("Switched image pipeline scheduler to LCMScheduler for optimized LCM inference.")
            except ImportError as peft_import_error:
                logger.error(f"Failed to load LCM LoRA '{lcm_lora_name}': 'peft' is not installed. Add 'peft' to requirements.txt. Falling back to the default scheduler. Error: {peft_import_error}")
            except Exception as e:
                # Catch other potential errors during LoRA loading (e.g., network issues, invalid LoRA)
                logger.warning(f"Could not load or apply LCM LoRA '{lcm_lora_name}'. Using the default scheduler. Error: {e}", exc_info=True)
        else:
            logger.info("No IMAGE_LCM_LORA_NAME configured. Using the default image scheduler.")

        # --- 4. Image Pipeline Device Placement and Final Setup ---
        image_pipeline = image_pipeline.to(config.DEVICE)  # Move the (possibly LoRA-patched) pipeline to the device
        logger.info(f"Image pipeline finalized and moved to device: {config.DEVICE}")

        # Optional GPU optimizations (commented out because the current config targets CPU)
        # if config.DEVICE == "cuda":
        #     try:
        #         # Requires: pip install xformers
        #         image_pipeline.enable_xformers_memory_efficient_attention()
        #         logger.info("Enabled xformers memory-efficient attention for image pipeline.")
        #     except ImportError:
        #         logger.warning("xformers not installed. Consider installing it for memory savings on GPU.")
        #         # Fallback option: attention slicing (smaller memory saving than xformers)
        #         # try:
        #         #     image_pipeline.enable_attention_slicing()
        #         #     logger.info("Enabled attention slicing for image pipeline.")
        #         # except Exception as attn_slice_e:
        #         #     logger.warning(f"Could not enable attention slicing: {attn_slice_e}")

        # Store the finalized image pipeline in the cache
        model_cache["image_pipeline"] = image_pipeline
        logger.info("Image generation model setup complete and cached.")

        # --- 5. Video Generation Model ---
        logger.info(f"Loading video generation model: {config.VIDEO_MODEL_NAME}")
        # Load the video diffusion pipeline (e.g., Zeroscope)
        video_pipeline = DiffusionPipeline.from_pretrained(
            config.VIDEO_MODEL_NAME,  # Must include the user/org (e.g., "cerspense/zeroscope_v2_576w")
            torch_dtype=config.DTYPE,
            variant="fp16" if config.DTYPE == torch.float16 else None  # Use the fp16 variant if on CUDA and available
        )
        # Set a standard scheduler for the video pipeline
        video_pipeline.scheduler = DPMSolverMultistepScheduler.from_config(video_pipeline.scheduler.config)
        logger.info(f"Video pipeline '{config.VIDEO_MODEL_NAME}' loaded. Scheduler: DPMSolverMultistepScheduler.")

        # --- 6. Video Pipeline Memory Optimization (CPU Offload) ---
        # Enable CPU offloading to save VRAM when running on GPU (requires 'accelerate').
        # Submodules stay on the CPU and are moved to the GPU only while they run;
        # on a CPU-only deployment this offers little benefit and may be unsupported.
        try:
            video_pipeline.enable_model_cpu_offload()
            logger.info("Enabled model CPU offload for video pipeline (saves memory).")
        except AttributeError:
            # Fallback if this pipeline class doesn't support the method
            logger.warning(f"Video pipeline class {type(video_pipeline).__name__} may not support enable_model_cpu_offload(). Attempting to move the entire model to device {config.DEVICE}.")
            try:
                video_pipeline = video_pipeline.to(config.DEVICE)
                logger.info(f"Video pipeline moved entirely to device: {config.DEVICE}")
            except Exception as move_err:
                logger.error(f"Failed to move video pipeline to device {config.DEVICE}: {move_err}", exc_info=True)
                raise  # Re-raise if moving the whole model also fails

        # Store the video pipeline in the cache
        model_cache["video_pipeline"] = video_pipeline
        logger.info("Video generation model setup complete and cached.")

        # --- Success ---
        logger.info("All configured models loaded successfully.")  # Only logged if every step above succeeded

    except Exception as e:  # <<<--- Catches errors from ANY model loading step ---<<<
        logger.error(f"FATAL: Error occurred during model loading sequence: {e}", exc_info=True)
        # Re-raise so the application's lifespan manager can abort startup
        # instead of serving requests in a faulty state.
        raise
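
# A minimal sketch of wiring load_models() into application startup. A FastAPI app
# and a services.generation import path are assumed here, not part of this file:
#
#   from contextlib import asynccontextmanager
#   from fastapi import FastAPI
#   from services import generation
#
#   @asynccontextmanager
#   async def lifespan(app: FastAPI):
#       generation.load_models()        # raises on failure, aborting startup
#       yield
#       generation.model_cache.clear()  # drop references so memory can be reclaimed
#
#   app = FastAPI(lifespan=lifespan)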

# ==============================================================================
# Synchronous Generation Functions (Run in Thread Pool)
# ==============================================================================
def generate_ideas_sync(prompt: str, max_length: int, num_ideas: int) -> List[str]:
    """
    Synchronously generate text ideas using the loaded language model.

    Designed to be run in a thread pool to avoid blocking the main async event loop.

    Args:
        prompt: The input prompt or instruction for idea generation.
        max_length: Maximum number of tokens for the generated output.
        num_ideas: The desired number of distinct idea sequences to generate.

    Returns:
        A list of generated idea strings.

    Raises:
        RuntimeError: If the text model or tokenizer is not loaded in the cache.
    """
    tokenizer = model_cache.get("text_tokenizer")
    model = model_cache.get("text_model")
    if not tokenizer or not model:
        logger.error("Execution failure: text model or tokenizer not found in cache during idea generation.")
        raise RuntimeError("Text model not loaded or available.")
| logger.debug(f"Generating {num_ideas} ideas for prompt: '{prompt}' with max_length={max_length}") | |
| input_text = prompt # Use the direct prompt as input | |
| # Variables to hold intermediate results for cleanup | |
| inputs = None | |
| outputs = None | |
| ideas = [] | |
| try: | |
| # Prepare model inputs | |
| inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=512).to(config.DEVICE) # Limit input length | |
| # Perform inference without calculating gradients to save memory | |
| with torch.no_grad(): | |
| outputs = model.generate( | |
| **inputs, | |
| max_length=max_length, | |
| num_return_sequences=num_ideas, | |
| do_sample=True, # Enable sampling for diversity | |
| temperature=0.7, # Control randomness (lower = more focused) | |
| top_k=50, # Consider top k words | |
| top_p=0.95, # Use nucleus sampling | |
| no_repeat_ngram_size=2 # Prevent short repetitive phrases | |
| ) | |
| # Decode the generated token sequences into strings | |
| ideas = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs] | |
| logger.debug(f"Successfully generated {len(ideas)} raw idea(s).") | |
| finally: | |
| # --- Resource Cleanup --- | |
| # Explicitly delete large tensor variables to help GC | |
| del inputs | |
| del outputs | |
| # Clear CUDA cache if running on GPU | |
| if config.DEVICE == "cuda": | |
| torch.cuda.empty_cache() | |
| # Trigger garbage collection | |
| gc.collect() | |
| logger.debug("Cleaned up resources after idea generation.") | |
| return ideas | |
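
# A minimal sketch of calling generate_ideas_sync from an async route without
# blocking the event loop. The route path, parameters, and use of anyio.to_thread
# are illustrative assumptions:
#
#   import anyio
#
#   @app.post("/ideas")
#   async def ideas_endpoint(prompt: str, max_length: int = 128, num_ideas: int = 3):
#       results = await anyio.to_thread.run_sync(
#           lambda: generate_ideas_sync(prompt, max_length, num_ideas)
#       )
#       return {"ideas": results}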

def generate_image_sync(prompt: str, negative_prompt: str | None, height: int, width: int, num_inference_steps: int, guidance_scale: float) -> str:
    """
    Synchronously generate an image using the loaded diffusion pipeline.

    Designed to be run in a thread pool.

    Args:
        prompt: The text prompt describing the desired image.
        negative_prompt: Text prompt describing concepts to avoid.
        height: Desired image height in pixels.
        width: Desired image width in pixels.
        num_inference_steps: Number of diffusion steps (more steps = more detail, slower).
            NOTE: If using LCM, this should be very low (e.g., 4-8).
        guidance_scale: How strongly the prompt guides generation (higher = stricter adherence).
            NOTE: If using LCM, this should be low (e.g., 0.0-1.5).

    Returns:
        A base64-encoded string representing the generated PNG image.

    Raises:
        RuntimeError: If the image pipeline is not loaded in the cache.
    """
    pipeline = model_cache.get("image_pipeline")
    if not pipeline:
        logger.error("Execution failure: image pipeline not found in cache during image generation.")
        raise RuntimeError("Image pipeline not loaded or available.")
    # Log parameters and warn when they look wrong for the active scheduler
    lcm_active = isinstance(pipeline.scheduler, LCMScheduler)
    logger.debug(f"Generating image for prompt: '{prompt}' (LCM active: {lcm_active})")
    logger.debug(f"Params: steps={num_inference_steps}, guidance={guidance_scale}, size={width}x{height}")
    if lcm_active and (num_inference_steps > 10 or guidance_scale > 2.0):
        logger.warning(f"LCM scheduler is active, but parameters (steps={num_inference_steps}, guidance={guidance_scale}) look high. Optimal LCM uses low steps (4-8) and low guidance (0-1.5).")
    image = None
    image_base64 = None
    try:
        # Run inference without gradient tracking
        with torch.no_grad():
            result = pipeline(
                prompt=prompt,
                negative_prompt=negative_prompt,
                height=height,
                width=width,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                # Optional: add a generator for reproducibility
                # generator=torch.Generator(device=config.DEVICE).manual_seed(some_seed)
            )
        # Extract the first image from the result
        image = result.images[0]
        logger.debug("Image generation inference complete.")
        # Encode the generated PIL image to a base64 string
        image_base64 = encode_image_base64(image, format="PNG")
        logger.debug("Image encoded to base64 PNG format.")
    finally:
        # --- Resource Cleanup ---
        # The pipeline object itself remains in the cache
        del image  # Drop the generated PIL image object
        # Clear the CUDA cache if applicable
        if config.DEVICE == "cuda":
            torch.cuda.empty_cache()
        gc.collect()
        logger.debug("Cleaned up resources after image generation.")
    return image_base64
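
# Example call with LCM-appropriate parameters (illustrative values only; the
# prompt and dimensions are assumptions):
#
#   png_b64 = generate_image_sync(
#       prompt="a watercolor fox in a misty forest",
#       negative_prompt="blurry, low quality",
#       height=512, width=512,
#       num_inference_steps=6,   # LCM: 4-8 steps
#       guidance_scale=1.0,      # LCM: 0.0-1.5
#   )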

def generate_video_sync(
    image_base64: str,
    prompt: str | None,  # Optional prompt, for video models that use one
    motion_bucket_id: int,
    noise_aug_strength: float,
    num_frames: int,
    fps: int,
    num_inference_steps: int,
    guidance_scale: float
) -> Tuple[str, str]:
    """
    Synchronously generate a video clip from an input image using the loaded video pipeline.

    Designed to be run in a thread pool.

    Args:
        image_base64: Base64-encoded string of the input image.
        prompt: Optional text prompt (model-dependent usage).
        motion_bucket_id: Controls the amount of motion (model-specific; used by
            image-to-video models such as Stable Video Diffusion).
        noise_aug_strength: Amount of noise added to the input image (model-specific).
        num_frames: Number of frames desired in the output video.
        fps: Frames per second for the output video encoding.
        num_inference_steps: Number of diffusion steps for video generation.
        guidance_scale: Guidance scale for video generation.

    Returns:
        A tuple containing:
            - video_base64 (str): Base64-encoded string of the generated video (MP4 or GIF).
            - actual_format (str): The actual format the video was encoded in ("MP4" or "GIF").

    Raises:
        RuntimeError: If the video pipeline is not loaded in the cache.
        ValueError: If the input `image_base64` is invalid.
    """
    pipeline = model_cache.get("video_pipeline")
    if not pipeline:
        logger.error("Execution failure: video pipeline not found in cache during video generation.")
        raise RuntimeError("Video pipeline not loaded or available.")
    logger.debug("Decoding base64 input image for video generation.")
    try:
        # Decode the input image
        input_image = decode_base64_image(image_base64)
        logger.debug(f"Decoded input image: size={input_image.size}, mode={input_image.mode}")
    except Exception as decode_err:
        logger.error(f"Failed to decode base64 input image: {decode_err}", exc_info=True)
        # Raise a specific error that the API route can catch
        raise ValueError("Invalid base64 input image provided.") from decode_err
    logger.debug(f"Generating video: frames={num_frames}, fps={fps}, steps={num_inference_steps}")
    # Variables defined up front for cleanup
    video_frames_pil = None
    video_frames_np = None
    video_base64 = None
    actual_format = "N/A"
    try:
        # Run inference without gradient tracking
        with torch.no_grad():
            # CPU offload handles device placement if it was enabled in load_models()
            pipeline_output = pipeline(
                input_image,
                prompt=prompt,
                num_inference_steps=num_inference_steps,
                num_frames=num_frames,
                height=input_image.height,  # Use the input image dimensions
                width=input_image.width,
                guidance_scale=guidance_scale,
                # Model-specific parameters (e.g., Stable Video Diffusion accepts
                # motion_bucket_id / noise_aug_strength; other pipelines may not)
                motion_bucket_id=motion_bucket_id,
                noise_aug_strength=noise_aug_strength
            )
        # The output structure varies; `.frames` is often a list containing one list of PIL images
        if hasattr(pipeline_output, 'frames') and isinstance(pipeline_output.frames, list) and len(pipeline_output.frames) > 0:
            video_frames_pil = pipeline_output.frames[0]  # Assumes the structure [[frame1, frame2, ...]]
        else:
            # Handle potential variations in output structure if needed
            logger.error(f"Unexpected video pipeline output structure: {type(pipeline_output)}")
            raise RuntimeError("Video generation produced unexpected output format.")
        logger.debug(f"Video frame generation complete ({len(video_frames_pil)} frames).")
        # Convert PIL frames to NumPy arrays for video encoding
        video_frames_np = [np.array(frame) for frame in video_frames_pil]
        logger.debug("Converted video frames to NumPy arrays.")
        # Encode the NumPy frames into a base64 video string (tries MP4, falls back to GIF)
        video_base64, actual_format = encode_video_base64(video_frames_np, fps=fps, format="MP4")
        logger.debug(f"Video encoded to base64 with actual format: {actual_format}")
    finally:
        # --- Resource Cleanup ---
        del input_image  # Drop the decoded input image
        # These were initialized to None above, so a plain `del` is always safe here
        del video_frames_pil
        del video_frames_np
        # Clear the CUDA cache if applicable
        if config.DEVICE == "cuda":
            torch.cuda.empty_cache()
        gc.collect()
        logger.debug("Cleaned up resources after video generation.")
    return video_base64, actual_format
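
# A minimal end-to-end sketch chaining the two generators (illustrative values;
# whether the loaded video pipeline accepts all of these parameters depends on
# the model configured in config.VIDEO_MODEL_NAME):
#
#   img_b64 = generate_image_sync(
#       "a sailboat at sunset", None, height=512, width=512,
#       num_inference_steps=6, guidance_scale=1.0,
#   )
#   vid_b64, fmt = generate_video_sync(
#       image_base64=img_b64,
#       prompt="gentle waves, slow camera pan",
#       motion_bucket_id=127, noise_aug_strength=0.02,
#       num_frames=24, fps=8,
#       num_inference_steps=25, guidance_scale=7.5,
#   )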