# services/generation.py
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from PIL import Image
import numpy as np
import config # Your configuration file (config.py)
from utils.helpers import decode_base64_image, encode_image_base64, encode_video_base64 # Your helper functions
import logging
import gc # Garbage collector
from typing import List, Tuple
from diffusers import StableDiffusionPipeline, DiffusionPipeline, DPMSolverMultistepScheduler, LCMScheduler
# Note: peft is required for load_lora_weights, ensure it's in requirements.txt
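# The helpers imported from utils.helpers are assumed to have roughly the
# signatures sketched below (hypothetical; inferred only from how they are
# called in this module, not from the actual helpers file):
#
#   decode_base64_image(data: str) -> PIL.Image.Image
#   encode_image_base64(image: PIL.Image.Image, format: str = "PNG") -> str
#   encode_video_base64(frames: list[np.ndarray], fps: int, format: str = "MP4") -> tuple[str, str]
#       # returns (video_base64, actual_format), e.g. falling back from MP4 to GIF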
logger = logging.getLogger(__name__) # Get logger instance
# --- Global Model Cache ---
# Using a dictionary to store loaded models and pipelines allows loading them
# only once when the application starts, saving time and resources on subsequent requests.
model_cache = {}
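# The config attributes used throughout this module are assumed to look roughly
# like the sketch below (illustrative values only; the real config.py may differ):
#
#   # config.py (assumed shape)
#   import torch
#   TEXT_MODEL_NAME = "google/flan-t5-base"              # any seq2seq LM with PyTorch weights
#   IMAGE_MODEL_NAME = "runwayml/stable-diffusion-v1-5"  # Stable Diffusion checkpoint
#   VIDEO_MODEL_NAME = "cerspense/zeroscope_v2_576w"     # include the user/org prefix
#   IMAGE_LCM_LORA_NAME = None                           # or e.g. "latent-consistency/lcm-lora-sdv1-5"
#   DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
#   DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32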
# ==============================================================================
# Model Loading Function (Called during Application Startup)
# ==============================================================================
def load_models():
"""
Loads all configured machine learning models into the global `model_cache`.
This function is called once during the application's startup lifespan event.
If any essential model fails to load, it raises an exception to prevent
the application from starting in a faulty state.
"""
logger.info("Initiating model loading sequence...")
try: # <<<--- Start of the MAIN try block for all models ---<<<
# --- 1. Text Generation Model ---
logger.info(f"Loading text model: {config.TEXT_MODEL_NAME}")
# Load tokenizer associated with the text model
model_cache["text_tokenizer"] = AutoTokenizer.from_pretrained(config.TEXT_MODEL_NAME)
# Load the sequence-to-sequence language model
# Assuming PyTorch weights (.bin or .safetensors) are available for the model.
model_cache["text_model"] = AutoModelForSeq2SeqLM.from_pretrained(
config.TEXT_MODEL_NAME
# REMOVED: from_tf=True - Attempt to load PyTorch weights directly.
).to(config.DEVICE) # Move model to the configured device (CPU or CUDA)
logger.info(f"Text model '{config.TEXT_MODEL_NAME}' loaded successfully (using PyTorch weights) onto {config.DEVICE}.")
# --- 2. Image Generation Model (Base Pipeline) ---
logger.info(f"Loading base image generation model: {config.IMAGE_MODEL_NAME}")
# Load the Stable Diffusion pipeline
image_pipeline = StableDiffusionPipeline.from_pretrained(
config.IMAGE_MODEL_NAME,
torch_dtype=config.DTYPE # Use configured dtype (float16 on CUDA, float32 on CPU)
)
# Set a default fast scheduler (can be overridden by LCM later)
image_pipeline.scheduler = DPMSolverMultistepScheduler.from_config(image_pipeline.scheduler.config)
logger.info(f"Image base pipeline '{config.IMAGE_MODEL_NAME}' loaded. Default scheduler: DPMSolverMultistepScheduler.")
# --- 3. Attempt to Load LCM LoRA (Optional Speedup) ---
# Safely check if an LCM LoRA is configured in config.py or environment variables
lcm_lora_name = getattr(config, 'IMAGE_LCM_LORA_NAME', None)
if lcm_lora_name:
logger.info(f"Attempting to load LCM LoRA: {lcm_lora_name} (Requires 'peft' library)")
try:
# Load the LoRA weights into the existing pipeline. Requires 'peft'.
image_pipeline.load_lora_weights(lcm_lora_name)
# Optional: Fuse LoRA weights for potential minor speedup. Test impact.
# image_pipeline.fuse_lora()
logger.info(f"LCM LoRA '{lcm_lora_name}' loaded successfully.")
# IMPORTANT: Switch to the LCM Scheduler *only if* LoRA loaded successfully
image_pipeline.scheduler = LCMScheduler.from_config(image_pipeline.scheduler.config)
logger.info("Switched image pipeline scheduler to LCMScheduler for optimized LCM inference.")
except ImportError as peft_import_error:
logger.error(f"Failed to load LCM LoRA '{lcm_lora_name}': 'peft' library not installed. Please add 'peft' to requirements.txt. Falling back to default scheduler. Error: {peft_import_error}")
except Exception as e:
# Catch other potential errors during LoRA loading (e.g., network issues, invalid LoRA)
logger.warning(f"Could not load or apply LCM LoRA '{lcm_lora_name}'. Using default scheduler. Error: {e}", exc_info=True)
else:
logger.info("No IMAGE_LCM_LORA_NAME configured. Using default image scheduler.")
# --- 4. Image Pipeline Device Placement and Final Setup ---
image_pipeline = image_pipeline.to(config.DEVICE) # Move the potentially modified pipeline to the device
logger.info(f"Image pipeline finalized and moved to device: {config.DEVICE}")
# Optional GPU Optimizations (Commented out as current config targets CPU)
# if config.DEVICE == "cuda":
# try:
# # Requires: pip install xformers
# # image_pipeline.enable_xformers_memory_efficient_attention()
# # logger.info("Enabled xformers memory efficient attention for image pipeline.")
# pass
# except ImportError:
# logger.warning("xformers not installed or enabled. Consider installing for potential memory savings on GPU.")
# # Fallback option: Attention slicing (less memory saving than xformers)
# # try:
# # image_pipeline.enable_attention_slicing()
# # logger.info("Enabled attention slicing for image pipeline.")
# # except Exception as attn_slice_e:
# # logger.warning(f"Could not enable attention slicing: {attn_slice_e}")
# Store the finalized image pipeline in the cache
model_cache["image_pipeline"] = image_pipeline
logger.info("Image generation model setup complete and cached.")
# --- 5. Video Generation Model ---
logger.info(f"Loading video generation model: {config.VIDEO_MODEL_NAME}")
# Load the video diffusion pipeline (e.g., Zeroscope)
video_pipeline = DiffusionPipeline.from_pretrained(
config.VIDEO_MODEL_NAME, # Make sure this includes user/org (e.g., "cerspense/zeroscope_v2_576w")
torch_dtype=config.DTYPE,
variant="fp16" if config.DTYPE == torch.float16 else None # Use fp16 variant if on CUDA and available
)
# Set a standard scheduler for the video pipeline
video_pipeline.scheduler = DPMSolverMultistepScheduler.from_config(video_pipeline.scheduler.config)
logger.info(f"Video pipeline '{config.VIDEO_MODEL_NAME}' loaded. Scheduler: DPMSolverMultistepScheduler.")
# --- 6. Video Pipeline Memory Optimization (CPU Offload) ---
# Enable CPU offloading to save VRAM (if on GPU) or manage RAM usage (on CPU)
# This keeps parts of the model on the CPU until needed.
try:
video_pipeline.enable_model_cpu_offload()
logger.info("Enabled model CPU offload for video pipeline (good for memory saving).")
except AttributeError:
# Fallback if the specific pipeline class doesn't support this method
logger.warning(f"Video pipeline class {type(video_pipeline).__name__} may not support enable_model_cpu_offload(). Attempting to move entire model to device {config.DEVICE}.")
try:
video_pipeline = video_pipeline.to(config.DEVICE)
logger.info(f"Video pipeline moved entirely to device: {config.DEVICE}")
except Exception as move_err:
logger.error(f"Failed to move video pipeline to device {config.DEVICE}: {move_err}", exc_info=True)
raise # Re-raise if moving the whole model also fails
# Store the video pipeline in the cache
model_cache["video_pipeline"] = video_pipeline
logger.info("Video generation model setup complete and cached.")
# --- Success ---
logger.info("All configured models loaded successfully.") # Only logs if all steps above succeed
except Exception as e: # <<<--- Catches errors from ANY model loading step ---<<<
logger.error(f"FATAL: Error occurred during model loading sequence: {e}", exc_info=True)
# Re-raise the exception. This will be caught by the application's
# lifespan manager, which should prevent the server from starting properly.
raise
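# Minimal sketch of how load_models() is expected to be wired into the
# application's startup lifespan event, assuming a FastAPI app (module and
# handler names here are illustrative, not the actual app code):
#
#   from contextlib import asynccontextmanager
#   from fastapi import FastAPI
#   from services.generation import load_models, model_cache
#
#   @asynccontextmanager
#   async def lifespan(app: FastAPI):
#       load_models()        # raises on failure, so the server refuses to start
#       yield
#       model_cache.clear()  # release model references on shutdown
#
#   app = FastAPI(lifespan=lifespan)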
# ==============================================================================
# Synchronous Generation Functions (Run in Thread Pool)
# ==============================================================================
def generate_ideas_sync(prompt: str, max_length: int, num_ideas: int) -> List[str]:
"""
Synchronous function to generate text ideas using the loaded language model.
Designed to be run in a thread pool to avoid blocking the main async event loop.
Args:
prompt: The input prompt or instruction for idea generation.
max_length: Maximum number of tokens for the generated output.
num_ideas: The desired number of distinct idea sequences to generate.
Returns:
A list of generated idea strings.
Raises:
RuntimeError: If the text model or tokenizer is not loaded in the cache.
"""
tokenizer = model_cache.get("text_tokenizer")
model = model_cache.get("text_model")
if not tokenizer or not model:
logger.error("Execution failure: Text model or tokenizer not found in cache during idea generation.")
raise RuntimeError("Text model not loaded or available.")
logger.debug(f"Generating {num_ideas} ideas for prompt: '{prompt}' with max_length={max_length}")
input_text = prompt # Use the direct prompt as input
# Variables to hold intermediate results for cleanup
inputs = None
outputs = None
ideas = []
try:
# Prepare model inputs
inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=512).to(config.DEVICE) # Limit input length
# Perform inference without calculating gradients to save memory
with torch.no_grad():
outputs = model.generate(
**inputs,
max_length=max_length,
num_return_sequences=num_ideas,
do_sample=True, # Enable sampling for diversity
temperature=0.7, # Control randomness (lower = more focused)
top_k=50, # Consider top k words
top_p=0.95, # Use nucleus sampling
no_repeat_ngram_size=2 # Prevent short repetitive phrases
)
# Decode the generated token sequences into strings
ideas = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
logger.debug(f"Successfully generated {len(ideas)} raw idea(s).")
finally:
# --- Resource Cleanup ---
# Explicitly delete large tensor variables to help GC
del inputs
del outputs
# Clear CUDA cache if running on GPU
if config.DEVICE == "cuda":
torch.cuda.empty_cache()
# Trigger garbage collection
gc.collect()
logger.debug("Cleaned up resources after idea generation.")
return ideas
def generate_image_sync(prompt: str, negative_prompt: str | None, height: int, width: int, num_inference_steps: int, guidance_scale: float) -> str:
"""
Synchronous function to generate an image using the loaded diffusion pipeline.
Designed to be run in a thread pool.
Args:
prompt: The text prompt describing the desired image.
negative_prompt: Text prompt describing concepts to avoid.
height: Desired image height in pixels.
width: Desired image width in pixels.
num_inference_steps: Number of diffusion steps (more steps = more detail, slower).
NOTE: If using LCM, this should be very low (e.g., 4-8).
guidance_scale: How strongly the prompt guides generation (higher = stricter adherence).
NOTE: If using LCM, this should be low (e.g., 0.0-1.5).
Returns:
A base64 encoded string representing the generated PNG image.
Raises:
RuntimeError: If the image pipeline is not loaded in the cache.
"""
pipeline = model_cache.get("image_pipeline")
if not pipeline:
logger.error("Execution failure: Image pipeline not found in cache during image generation.")
raise RuntimeError("Image pipeline not loaded or available.")
# Log parameters, potentially adjust for LCM if detected
lcm_active = isinstance(pipeline.scheduler, LCMScheduler)
logger.debug(f"Generating image for prompt: '{prompt}' (LCM Active: {lcm_active})")
logger.debug(f"Params: steps={num_inference_steps}, guidance={guidance_scale}, size={width}x{height}")
if lcm_active and (num_inference_steps > 10 or guidance_scale > 2.0):
logger.warning(f"LCM scheduler is active, but parameters (steps={num_inference_steps}, guidance={guidance_scale}) seem high. Optimal LCM uses low steps (4-8) and low guidance (0-1.5).")
image = None
image_base64 = None
try:
# Perform inference without calculating gradients
with torch.no_grad():
result = pipeline(
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
# Optional: Add generator for reproducibility
# generator=torch.Generator(device=config.DEVICE).manual_seed(some_seed)
)
# Extract the first image from the result
image = result.images[0]
logger.debug("Image generation inference complete.")
# Encode the generated PIL image to a base64 string
image_base64 = encode_image_base64(image, format="PNG")
logger.debug("Image encoded to base64 PNG format.")
finally:
# --- Resource Cleanup ---
# The pipeline object itself remains in the cache
del image # Delete the generated PIL image object
# Clear CUDA cache if applicable
if config.DEVICE == "cuda":
torch.cuda.empty_cache()
gc.collect()
logger.debug("Cleaned up resources after image generation.")
return image_base64
def generate_video_sync(
image_base64: str,
prompt: str | None, # Optional prompt for video models that use it
motion_bucket_id: int,
noise_aug_strength: float,
num_frames: int,
fps: int,
num_inference_steps: int,
guidance_scale: float
) -> Tuple[str, str]:
"""
Synchronous function to generate a video clip from an input image using the loaded video pipeline.
Designed to be run in a thread pool.
Args:
image_base64: Base64 encoded string of the input image.
prompt: Optional text prompt (model-dependent usage).
        motion_bucket_id: Controls the amount of motion (specific to image-to-video models such as Stable Video Diffusion).
        noise_aug_strength: Amount of noise added to the input image (likewise specific to image-to-video pipelines).
num_frames: Number of frames desired in the output video.
fps: Frames per second for the output video encoding.
num_inference_steps: Number of diffusion steps for video generation.
guidance_scale: Guidance scale for video generation.
Returns:
A tuple containing:
- video_base64 (str): Base64 encoded string of the generated video (MP4 or GIF).
- actual_format (str): The actual format the video was encoded in ("MP4" or "GIF").
Raises:
RuntimeError: If the video pipeline is not loaded in the cache.
ValueError: If the input `image_base64` is invalid.
"""
pipeline = model_cache.get("video_pipeline")
if not pipeline:
logger.error("Execution failure: Video pipeline not found in cache during video generation.")
raise RuntimeError("Video pipeline not loaded or available.")
logger.debug("Decoding base64 input image for video generation.")
try:
# Decode the input image
input_image = decode_base64_image(image_base64)
logger.debug(f"Decoded input image: size={input_image.size}, mode={input_image.mode}")
except Exception as decode_err:
logger.error(f"Failed to decode base64 input image: {decode_err}", exc_info=True)
# Raise a specific error that can be caught by the API route
raise ValueError("Invalid base64 input image provided.") from decode_err
logger.debug(f"Generating video: frames={num_frames}, fps={fps}, steps={num_inference_steps}")
# Variables for cleanup
video_frames_pil = None
video_frames_np = None
video_base64 = None
actual_format = "N/A"
try:
# Perform inference without gradients
with torch.no_grad():
# CPU offload handles device placement if enabled during load_models
pipeline_output = pipeline(
input_image,
prompt=prompt,
num_inference_steps=num_inference_steps,
num_frames=num_frames,
height=input_image.height, # Use input image dimensions
width=input_image.width,
guidance_scale=guidance_scale,
                # Image-to-video control parameters (e.g., Stable Video Diffusion);
                # other video pipelines may not accept these keyword arguments.
                motion_bucket_id=motion_bucket_id,
                noise_aug_strength=noise_aug_strength
)
# Output format can vary; often `.frames` is a list containing one list of PIL images
if hasattr(pipeline_output, 'frames') and isinstance(pipeline_output.frames, list) and len(pipeline_output.frames) > 0:
video_frames_pil = pipeline_output.frames[0] # Assuming the structure [[frame1, frame2...]]
else:
# Handle potential variations in output structure if needed
logger.error(f"Unexpected video pipeline output structure: {type(pipeline_output)}")
raise RuntimeError("Video generation produced unexpected output format.")
logger.debug(f"Video frame generation complete ({len(video_frames_pil)} frames).")
# Convert PIL frames to NumPy arrays for video encoding
video_frames_np = [np.array(frame) for frame in video_frames_pil]
logger.debug("Converted video frames to NumPy arrays.")
# Encode the NumPy frames into a base64 video string (tries MP4, falls back to GIF)
video_base64, actual_format = encode_video_base64(video_frames_np, fps=fps, format="MP4")
logger.debug(f"Video encoded to base64 with actual format: {actual_format}")
finally:
# --- Resource Cleanup ---
del input_image # Delete decoded input image
if 'video_frames_pil' in locals(): del video_frames_pil # Delete list of PIL frames if it exists
if 'video_frames_np' in locals(): del video_frames_np # Delete list of numpy frames if it exists
# Clear CUDA cache if applicable
if config.DEVICE == "cuda":
torch.cuda.empty_cache()
gc.collect()
logger.debug("Cleaned up resources after video generation.")
return video_base64, actual_format
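# Minimal sketch of how the *_sync functions above can be called from async
# request handlers without blocking the event loop (the endpoint shape and the
# use of asyncio.to_thread are assumptions, not part of this service module):
#
#   import asyncio
#   from services.generation import generate_ideas_sync
#
#   async def ideas_endpoint(prompt: str) -> list[str]:
#       # Offload the blocking generation call to a worker thread
#       return await asyncio.to_thread(
#           generate_ideas_sync, prompt, max_length=128, num_ideas=3
#       )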