# services/generation.py
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from PIL import Image
import numpy as np
import config # Your configuration file (config.py)
from utils.helpers import decode_base64_image, encode_image_base64, encode_video_base64 # Your helper functions
import logging
import gc # Garbage collector
from typing import List, Tuple
from diffusers import StableDiffusionPipeline, DiffusionPipeline, DPMSolverMultistepScheduler, LCMScheduler
# Note: peft is required for load_lora_weights, ensure it's in requirements.txt
logger = logging.getLogger(__name__) # Get logger instance
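# --- Assumed external interfaces (illustrative sketch, not the actual project files) ---
# config.py is expected to expose roughly the following attributes; the values shown
# here are hypothetical placeholders, not this project's real settings:
#
#   DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
#   DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
#   TEXT_MODEL_NAME = "google/flan-t5-base"                 # placeholder
#   IMAGE_MODEL_NAME = "runwayml/stable-diffusion-v1-5"     # placeholder
#   VIDEO_MODEL_NAME = "cerspense/zeroscope_v2_576w"        # example cited in comments below
#   IMAGE_LCM_LORA_NAME = None  # or e.g. "latent-consistency/lcm-lora-sdv1-5"
#
# utils/helpers.py is assumed to provide signatures along these lines (inferred from usage):
#   decode_base64_image(b64: str) -> PIL.Image.Image
#   encode_image_base64(img: PIL.Image.Image, format: str = "PNG") -> str
#   encode_video_base64(frames: list[np.ndarray], fps: int, format: str = "MP4") -> tuple[str, str]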
# --- Global Model Cache ---
# Using a dictionary to store loaded models and pipelines allows loading them
# only once when the application starts, saving time and resources on subsequent requests.
model_cache = {}
# ==============================================================================
# Model Loading Function (Called during Application Startup)
# ==============================================================================
def load_models():
"""
Loads all configured machine learning models into the global `model_cache`.
This function is called once during the application's startup lifespan event.
If any essential model fails to load, it raises an exception to prevent
the application from starting in a faulty state.
"""
logger.info("Initiating model loading sequence...")
try: # <<<--- Start of the MAIN try block for all models ---<<<
# --- 1. Text Generation Model ---
logger.info(f"Loading text model: {config.TEXT_MODEL_NAME}")
# Load tokenizer associated with the text model
model_cache["text_tokenizer"] = AutoTokenizer.from_pretrained(config.TEXT_MODEL_NAME)
# Load the sequence-to-sequence language model
# Assuming PyTorch weights (.bin or .safetensors) are available for the model.
model_cache["text_model"] = AutoModelForSeq2SeqLM.from_pretrained(
config.TEXT_MODEL_NAME
# REMOVED: from_tf=True - Attempt to load PyTorch weights directly.
).to(config.DEVICE) # Move model to the configured device (CPU or CUDA)
logger.info(f"Text model '{config.TEXT_MODEL_NAME}' loaded successfully (using PyTorch weights) onto {config.DEVICE}.")
# --- 2. Image Generation Model (Base Pipeline) ---
logger.info(f"Loading base image generation model: {config.IMAGE_MODEL_NAME}")
# Load the Stable Diffusion pipeline
image_pipeline = StableDiffusionPipeline.from_pretrained(
config.IMAGE_MODEL_NAME,
torch_dtype=config.DTYPE # Use configured dtype (float16 on CUDA, float32 on CPU)
)
# Set a default fast scheduler (can be overridden by LCM later)
image_pipeline.scheduler = DPMSolverMultistepScheduler.from_config(image_pipeline.scheduler.config)
logger.info(f"Image base pipeline '{config.IMAGE_MODEL_NAME}' loaded. Default scheduler: DPMSolverMultistepScheduler.")
# --- 3. Attempt to Load LCM LoRA (Optional Speedup) ---
# Safely check if an LCM LoRA is configured in config.py or environment variables
lcm_lora_name = getattr(config, 'IMAGE_LCM_LORA_NAME', None)
if lcm_lora_name:
logger.info(f"Attempting to load LCM LoRA: {lcm_lora_name} (Requires 'peft' library)")
try:
# Load the LoRA weights into the existing pipeline. Requires 'peft'.
image_pipeline.load_lora_weights(lcm_lora_name)
# Optional: Fuse LoRA weights for potential minor speedup. Test impact.
# image_pipeline.fuse_lora()
logger.info(f"LCM LoRA '{lcm_lora_name}' loaded successfully.")
# IMPORTANT: Switch to the LCM Scheduler *only if* LoRA loaded successfully
image_pipeline.scheduler = LCMScheduler.from_config(image_pipeline.scheduler.config)
logger.info("Switched image pipeline scheduler to LCMScheduler for optimized LCM inference.")
except ImportError as peft_import_error:
logger.error(f"Failed to load LCM LoRA '{lcm_lora_name}': 'peft' library not installed. Please add 'peft' to requirements.txt. Falling back to default scheduler. Error: {peft_import_error}")
except Exception as e:
# Catch other potential errors during LoRA loading (e.g., network issues, invalid LoRA)
logger.warning(f"Could not load or apply LCM LoRA '{lcm_lora_name}'. Using default scheduler. Error: {e}", exc_info=True)
else:
logger.info("No IMAGE_LCM_LORA_NAME configured. Using default image scheduler.")
# --- 4. Image Pipeline Device Placement and Final Setup ---
image_pipeline = image_pipeline.to(config.DEVICE) # Move the potentially modified pipeline to the device
logger.info(f"Image pipeline finalized and moved to device: {config.DEVICE}")
# Optional GPU Optimizations (Commented out as current config targets CPU)
# if config.DEVICE == "cuda":
# try:
# # Requires: pip install xformers
# # image_pipeline.enable_xformers_memory_efficient_attention()
# # logger.info("Enabled xformers memory efficient attention for image pipeline.")
# pass
# except ImportError:
# logger.warning("xformers not installed or enabled. Consider installing for potential memory savings on GPU.")
# # Fallback option: Attention slicing (less memory saving than xformers)
# # try:
# # image_pipeline.enable_attention_slicing()
# # logger.info("Enabled attention slicing for image pipeline.")
# # except Exception as attn_slice_e:
# # logger.warning(f"Could not enable attention slicing: {attn_slice_e}")
# Store the finalized image pipeline in the cache
model_cache["image_pipeline"] = image_pipeline
logger.info("Image generation model setup complete and cached.")
# --- 5. Video Generation Model ---
logger.info(f"Loading video generation model: {config.VIDEO_MODEL_NAME}")
# Load the video diffusion pipeline (e.g., Zeroscope)
video_pipeline = DiffusionPipeline.from_pretrained(
config.VIDEO_MODEL_NAME, # Make sure this includes user/org (e.g., "cerspense/zeroscope_v2_576w")
torch_dtype=config.DTYPE,
variant="fp16" if config.DTYPE == torch.float16 else None # Use fp16 variant if on CUDA and available
)
# Set a standard scheduler for the video pipeline
video_pipeline.scheduler = DPMSolverMultistepScheduler.from_config(video_pipeline.scheduler.config)
logger.info(f"Video pipeline '{config.VIDEO_MODEL_NAME}' loaded. Scheduler: DPMSolverMultistepScheduler.")
# --- 6. Video Pipeline Memory Optimization (CPU Offload) ---
# Enable CPU offloading to save VRAM (if on GPU) or manage RAM usage (on CPU)
# This keeps parts of the model on the CPU until needed.
try:
video_pipeline.enable_model_cpu_offload()
logger.info("Enabled model CPU offload for video pipeline (good for memory saving).")
except AttributeError:
# Fallback if the specific pipeline class doesn't support this method
logger.warning(f"Video pipeline class {type(video_pipeline).__name__} may not support enable_model_cpu_offload(). Attempting to move entire model to device {config.DEVICE}.")
try:
video_pipeline = video_pipeline.to(config.DEVICE)
logger.info(f"Video pipeline moved entirely to device: {config.DEVICE}")
except Exception as move_err:
logger.error(f"Failed to move video pipeline to device {config.DEVICE}: {move_err}", exc_info=True)
raise # Re-raise if moving the whole model also fails
# Store the video pipeline in the cache
model_cache["video_pipeline"] = video_pipeline
logger.info("Video generation model setup complete and cached.")
# --- Success ---
logger.info("All configured models loaded successfully.") # Only logs if all steps above succeed
except Exception as e: # <<<--- Catches errors from ANY model loading step ---<<<
logger.error(f"FATAL: Error occurred during model loading sequence: {e}", exc_info=True)
# Re-raise the exception. This will be caught by the application's
# lifespan manager, which should prevent the server from starting properly.
raise
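# Illustrative wiring (sketch, assuming a FastAPI app; the module and route names are
# hypothetical, not part of this file). load_models() is intended to run once inside the
# application's lifespan event, e.g.:
#
#   from contextlib import asynccontextmanager
#   from fastapi import FastAPI
#   from services.generation import load_models, model_cache
#
#   @asynccontextmanager
#   async def lifespan(app: FastAPI):
#       load_models()        # raises on failure so the server does not start half-loaded
#       yield
#       model_cache.clear()  # drop model references on shutdown
#
#   app = FastAPI(lifespan=lifespan)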
# ==============================================================================
# Synchronous Generation Functions (Run in Thread Pool)
# ==============================================================================
def generate_ideas_sync(prompt: str, max_length: int, num_ideas: int) -> List[str]:
"""
Synchronous function to generate text ideas using the loaded language model.
Designed to be run in a thread pool to avoid blocking the main async event loop.
Args:
prompt: The input prompt or instruction for idea generation.
max_length: Maximum number of tokens for the generated output.
num_ideas: The desired number of distinct idea sequences to generate.
Returns:
A list of generated idea strings.
Raises:
RuntimeError: If the text model or tokenizer is not loaded in the cache.
"""
tokenizer = model_cache.get("text_tokenizer")
model = model_cache.get("text_model")
if not tokenizer or not model:
logger.error("Execution failure: Text model or tokenizer not found in cache during idea generation.")
raise RuntimeError("Text model not loaded or available.")
logger.debug(f"Generating {num_ideas} ideas for prompt: '{prompt}' with max_length={max_length}")
input_text = prompt # Use the direct prompt as input
# Variables to hold intermediate results for cleanup
inputs = None
outputs = None
ideas = []
try:
# Prepare model inputs
inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=512).to(config.DEVICE) # Limit input length
# Perform inference without calculating gradients to save memory
with torch.no_grad():
outputs = model.generate(
**inputs,
max_length=max_length,
num_return_sequences=num_ideas,
do_sample=True, # Enable sampling for diversity
temperature=0.7, # Control randomness (lower = more focused)
top_k=50, # Consider top k words
top_p=0.95, # Use nucleus sampling
no_repeat_ngram_size=2 # Prevent short repetitive phrases
)
# Decode the generated token sequences into strings
ideas = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
logger.debug(f"Successfully generated {len(ideas)} raw idea(s).")
finally:
# --- Resource Cleanup ---
# Explicitly delete large tensor variables to help GC
del inputs
del outputs
# Clear CUDA cache if running on GPU
if config.DEVICE == "cuda":
torch.cuda.empty_cache()
# Trigger garbage collection
gc.collect()
logger.debug("Cleaned up resources after idea generation.")
return ideas
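# Illustrative async usage (sketch; the endpoint path and request shape are hypothetical).
# Because generate_ideas_sync blocks, an async route should dispatch it to a worker thread:
#
#   from fastapi.concurrency import run_in_threadpool
#
#   @app.post("/ideas")
#   async def ideas_endpoint(prompt: str, max_length: int = 128, num_ideas: int = 3):
#       ideas = await run_in_threadpool(generate_ideas_sync, prompt, max_length, num_ideas)
#       return {"ideas": ideas}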
def generate_image_sync(prompt: str, negative_prompt: str | None, height: int, width: int, num_inference_steps: int, guidance_scale: float) -> str:
"""
Synchronous function to generate an image using the loaded diffusion pipeline.
Designed to be run in a thread pool.
Args:
prompt: The text prompt describing the desired image.
negative_prompt: Text prompt describing concepts to avoid.
height: Desired image height in pixels.
width: Desired image width in pixels.
num_inference_steps: Number of diffusion steps (more steps = more detail, slower).
NOTE: If using LCM, this should be very low (e.g., 4-8).
guidance_scale: How strongly the prompt guides generation (higher = stricter adherence).
NOTE: If using LCM, this should be low (e.g., 0.0-1.5).
Returns:
A base64 encoded string representing the generated PNG image.
Raises:
RuntimeError: If the image pipeline is not loaded in the cache.
"""
pipeline = model_cache.get("image_pipeline")
if not pipeline:
logger.error("Execution failure: Image pipeline not found in cache during image generation.")
raise RuntimeError("Image pipeline not loaded or available.")
# Log parameters, potentially adjust for LCM if detected
lcm_active = isinstance(pipeline.scheduler, LCMScheduler)
logger.debug(f"Generating image for prompt: '{prompt}' (LCM Active: {lcm_active})")
logger.debug(f"Params: steps={num_inference_steps}, guidance={guidance_scale}, size={width}x{height}")
if lcm_active and (num_inference_steps > 10 or guidance_scale > 2.0):
logger.warning(f"LCM scheduler is active, but parameters (steps={num_inference_steps}, guidance={guidance_scale}) seem high. Optimal LCM uses low steps (4-8) and low guidance (0-1.5).")
image = None
image_base64 = None
try:
# Perform inference without calculating gradients
with torch.no_grad():
result = pipeline(
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
# Optional: Add generator for reproducibility
# generator=torch.Generator(device=config.DEVICE).manual_seed(some_seed)
)
# Extract the first image from the result
image = result.images[0]
logger.debug("Image generation inference complete.")
# Encode the generated PIL image to a base64 string
image_base64 = encode_image_base64(image, format="PNG")
logger.debug("Image encoded to base64 PNG format.")
finally:
# --- Resource Cleanup ---
# The pipeline object itself remains in the cache
del image # Delete the generated PIL image object
# Clear CUDA cache if applicable
if config.DEVICE == "cuda":
torch.cuda.empty_cache()
gc.collect()
logger.debug("Cleaned up resources after image generation.")
return image_base64
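# Illustrative call (sketch; the prompt text and parameter values are assumptions, not
# fixed defaults). Typical settings differ depending on which scheduler is active:
#   - LCM scheduler active:  num_inference_steps around 4-8, guidance_scale around 0-1.5
#   - Default DPM-Solver:    num_inference_steps around 25,  guidance_scale around 7.5
#
#   img_b64 = generate_image_sync(
#       prompt="a watercolor fox in a misty forest",
#       negative_prompt="blurry, low quality",
#       height=512, width=512,
#       num_inference_steps=25,
#       guidance_scale=7.5,
#   )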
def generate_video_sync(
image_base64: str,
prompt: str | None, # Optional prompt for video models that use it
motion_bucket_id: int,
noise_aug_strength: float,
num_frames: int,
fps: int,
num_inference_steps: int,
guidance_scale: float
) -> Tuple[str, str]:
"""
Synchronous function to generate a video clip from an input image using the loaded video pipeline.
Designed to be run in a thread pool.
Args:
image_base64: Base64 encoded string of the input image.
prompt: Optional text prompt (model-dependent usage).
        motion_bucket_id: Control parameter for the amount of motion (model-specific; used by image-to-video pipelines such as Stable Video Diffusion).
        noise_aug_strength: Amount of noise added to the input image (model-specific).
num_frames: Number of frames desired in the output video.
fps: Frames per second for the output video encoding.
num_inference_steps: Number of diffusion steps for video generation.
guidance_scale: Guidance scale for video generation.
Returns:
A tuple containing:
- video_base64 (str): Base64 encoded string of the generated video (MP4 or GIF).
- actual_format (str): The actual format the video was encoded in ("MP4" or "GIF").
Raises:
RuntimeError: If the video pipeline is not loaded in the cache.
ValueError: If the input `image_base64` is invalid.
"""
pipeline = model_cache.get("video_pipeline")
if not pipeline:
logger.error("Execution failure: Video pipeline not found in cache during video generation.")
raise RuntimeError("Video pipeline not loaded or available.")
logger.debug("Decoding base64 input image for video generation.")
try:
# Decode the input image
input_image = decode_base64_image(image_base64)
logger.debug(f"Decoded input image: size={input_image.size}, mode={input_image.mode}")
except Exception as decode_err:
logger.error(f"Failed to decode base64 input image: {decode_err}", exc_info=True)
# Raise a specific error that can be caught by the API route
raise ValueError("Invalid base64 input image provided.") from decode_err
logger.debug(f"Generating video: frames={num_frames}, fps={fps}, steps={num_inference_steps}")
# Variables for cleanup
video_frames_pil = None
video_frames_np = None
video_base64 = None
actual_format = "N/A"
try:
# Perform inference without gradients
with torch.no_grad():
# CPU offload handles device placement if enabled during load_models
pipeline_output = pipeline(
input_image,
prompt=prompt,
num_inference_steps=num_inference_steps,
num_frames=num_frames,
height=input_image.height, # Use input image dimensions
width=input_image.width,
guidance_scale=guidance_scale,
                # Model-specific parameters (used by image-to-video pipelines such as Stable Video Diffusion)
motion_bucket_id=motion_bucket_id,
noise_aug_strength=noise_aug_strength
)
# Output format can vary; often `.frames` is a list containing one list of PIL images
if hasattr(pipeline_output, 'frames') and isinstance(pipeline_output.frames, list) and len(pipeline_output.frames) > 0:
video_frames_pil = pipeline_output.frames[0] # Assuming the structure [[frame1, frame2...]]
else:
# Handle potential variations in output structure if needed
logger.error(f"Unexpected video pipeline output structure: {type(pipeline_output)}")
raise RuntimeError("Video generation produced unexpected output format.")
logger.debug(f"Video frame generation complete ({len(video_frames_pil)} frames).")
# Convert PIL frames to NumPy arrays for video encoding
video_frames_np = [np.array(frame) for frame in video_frames_pil]
logger.debug("Converted video frames to NumPy arrays.")
# Encode the NumPy frames into a base64 video string (tries MP4, falls back to GIF)
video_base64, actual_format = encode_video_base64(video_frames_np, fps=fps, format="MP4")
logger.debug(f"Video encoded to base64 with actual format: {actual_format}")
finally:
# --- Resource Cleanup ---
del input_image # Delete decoded input image
if 'video_frames_pil' in locals(): del video_frames_pil # Delete list of PIL frames if it exists
if 'video_frames_np' in locals(): del video_frames_np # Delete list of numpy frames if it exists
# Clear CUDA cache if applicable
if config.DEVICE == "cuda":
torch.cuda.empty_cache()
gc.collect()
logger.debug("Cleaned up resources after video generation.")
    return video_base64, actual_format
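# Illustrative usage (sketch; writing the clip to disk is an assumption, not part of this
# module). The returned format string tells the caller whether MP4 or GIF encoding was used:
#
#   import base64
#   video_b64, fmt = generate_video_sync(
#       image_base64=input_b64, prompt=None,
#       motion_bucket_id=127, noise_aug_strength=0.02,
#       num_frames=24, fps=8, num_inference_steps=25, guidance_scale=7.5,
#   )
#   with open(f"clip.{fmt.lower()}", "wb") as f:
#       f.write(base64.b64decode(video_b64))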