rajux75 committed (verified)
Commit bf7d351 · Parent: 5b6c27d

Update services/generation.py

Files changed (1): services/generation.py (+241 -126)

services/generation.py CHANGED
@@ -2,167 +2,213 @@
  import torch
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
  from PIL import Image
- import numpy as np # Added import for numpy array conversion later
- import config
- from utils.helpers import decode_base64_image, encode_image_base64, encode_video_base64
  import logging
  import gc # Garbage collector
- from typing import List, Tuple # Added Tuple for generate_video_sync return type hint
  from diffusers import StableDiffusionPipeline, DiffusionPipeline, DPMSolverMultistepScheduler, LCMScheduler
- # from peft import PeftConfig # Usually not needed directly if using load_lora_weights

- logger = logging.getLogger(__name__)

  # --- Global Model Cache ---
- # Use a dictionary to hold loaded models and tokenizers
- # This allows loading them only once when the app starts.
  model_cache = {}

  def load_models():
-     """Loads all models into the cache. Called at application startup."""
-     logger.info("Loading models...")
      try: # <<<--- Start of the MAIN try block for all models ---<<<

-         # --- Text Generation Model ---
          logger.info(f"Loading text model: {config.TEXT_MODEL_NAME}")
          model_cache["text_tokenizer"] = AutoTokenizer.from_pretrained(config.TEXT_MODEL_NAME)
-         model_cache["text_model"] = AutoModelForSeq2SeqLM.from_pretrained(config.TEXT_MODEL_NAME).to(config.DEVICE)
-         logger.info("Text model loaded.")
-
-         # --- Image Generation Model (Base) ---
-         logger.info(f"Loading image model: {config.IMAGE_MODEL_NAME}")
          image_pipeline = StableDiffusionPipeline.from_pretrained(
              config.IMAGE_MODEL_NAME,
-             torch_dtype=config.DTYPE
          )
-         # Default scheduler (will be potentially overridden by LCM)
          image_pipeline.scheduler = DPMSolverMultistepScheduler.from_config(image_pipeline.scheduler.config)
-         logger.info("Image base pipeline loaded. Default scheduler: DPMSolverMultistepScheduler.")

-         # --- Attempt to Load LCM LoRA (Optional Speedup) ---
-         # Check if IMAGE_LCM_LORA_NAME is defined and not empty in config
-         lcm_lora_name = getattr(config, 'IMAGE_LCM_LORA_NAME', None) # Safely get LoRA name
          if lcm_lora_name:
              try:
-                 logger.info(f"Attempting to load LCM LoRA: {lcm_lora_name}")
-                 # Load LoRA weights directly into the pipeline
                  image_pipeline.load_lora_weights(lcm_lora_name)
-                 # Fuse LoRA for potential speedup (optional, test impact)
                  # image_pipeline.fuse_lora()
-                 logger.info("LCM LoRA loaded successfully.")

-                 # IMPORTANT: Set LCM Scheduler *only if* LoRA loaded successfully
                  image_pipeline.scheduler = LCMScheduler.from_config(image_pipeline.scheduler.config)
-                 logger.info("Switched scheduler to LCMScheduler.")

              except Exception as e:
                  logger.warning(f"Could not load or apply LCM LoRA '{lcm_lora_name}'. Using default scheduler. Error: {e}", exc_info=True)
-                 # Scheduler already set to DPMSolverMultistepScheduler above, so no action needed here
          else:
-             logger.info("No IMAGE_LCM_LORA_NAME configured in environment/config. Using default scheduler.")
-
-         # --- Image Pipeline Device Placement and Optimizations ---
-         image_pipeline = image_pipeline.to(config.DEVICE)
-         logger.info(f"Image pipeline moved to device: {config.DEVICE}")
-
-         if config.DEVICE == "cuda":
-             # Optional: Enable memory efficient attention mechanisms if GPU available and libs installed
-             try:
-                 # Requires: pip install xformers
-                 # image_pipeline.enable_xformers_memory_efficient_attention()
-                 # logger.info("Enabled xformers memory efficient attention.")
-                 pass # Keep commented out if xformers not installed/intended
-             except ImportError:
-                 logger.warning("xformers not installed or enabled. Consider installing for potential memory savings on GPU.")
-                 # Fallback option if xformers is not available
-                 # try:
-                 #     image_pipeline.enable_attention_slicing()
-                 #     logger.info("Enabled attention slicing.")
-                 # except Exception as attn_slice_e:
-                 #     logger.warning(f"Could not enable attention slicing: {attn_slice_e}")
-
-         # --- Store Image Pipeline in Cache ---
          model_cache["image_pipeline"] = image_pipeline
-         logger.info("Image model setup complete and cached.")
-

-         # --- Video Generation Model ---
-         logger.info(f"Loading video model: {config.VIDEO_MODEL_NAME}")
          video_pipeline = DiffusionPipeline.from_pretrained(
-             config.VIDEO_MODEL_NAME,
              torch_dtype=config.DTYPE,
-             variant="fp16" if config.DTYPE == torch.float16 else None # Zeroscope often has fp16 variants
          )
          video_pipeline.scheduler = DPMSolverMultistepScheduler.from_config(video_pipeline.scheduler.config)
-         logger.info("Video pipeline loaded. Scheduler: DPMSolverMultistepScheduler.")

-         # Enable CPU offloading *before* potentially moving parts to GPU if not offloading everything
-         # This is crucial for fitting larger models in limited VRAM/RAM.
          try:
              video_pipeline.enable_model_cpu_offload()
-             logger.info("Enabled model CPU offload for video pipeline.")
          except AttributeError:
-             logger.warning("Video pipeline class may not support enable_model_cpu_offload(). Attempting to move entire model to device.")
-             # Fallback if offload method isn't available on this specific pipeline class
              try:
                  video_pipeline = video_pipeline.to(config.DEVICE)
-                 logger.info(f"Video pipeline moved to device: {config.DEVICE}")
              except Exception as move_err:
                  logger.error(f"Failed to move video pipeline to device {config.DEVICE}: {move_err}", exc_info=True)
-                 # Decide if you want to raise here or let it fail later
-                 # raise

-         # Store video pipeline in cache
          model_cache["video_pipeline"] = video_pipeline
-         logger.info("Video model setup complete and cached.")

-         # --- Success Message ---
-         logger.info("All configured models loaded successfully.") # Runs only if all steps above succeed

      except Exception as e: # <<<--- Catches errors from ANY model loading step ---<<<
-         logger.error(f"FATAL: Error loading one or more models during startup: {e}", exc_info=True)
-         # Re-raise the exception to prevent the application from starting
-         # in a state where essential models are missing.
          raise


  def generate_ideas_sync(prompt: str, max_length: int, num_ideas: int) -> List[str]:
-     """Synchronous function for text generation (run in thread pool)."""
      tokenizer = model_cache.get("text_tokenizer")
      model = model_cache.get("text_model")
      if not tokenizer or not model:
-         # This should ideally not happen if load_models raises on failure
-         logger.error("Attempted to generate ideas but text model/tokenizer not found in cache.")
          raise RuntimeError("Text model not loaded or available.")

-     logger.debug(f"Generating ideas for prompt: '{prompt}'")
-     input_text = prompt # Keep original prompt based on request model

      try:
-         inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=512).to(config.DEVICE) # Max input length for model

-         # Generation parameters
-         with torch.no_grad(): # Ensure no gradients are computed
              outputs = model.generate(
                  **inputs,
                  max_length=max_length,
                  num_return_sequences=num_ideas,
-                 do_sample=True, # Use sampling for more diverse ideas
-                 temperature=0.8,
-                 top_k=50,
-                 top_p=0.95,
-                 no_repeat_ngram_size=2 # Avoid repetitive phrases
              )

          ideas = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
-         logger.debug(f"Generated {len(ideas)} ideas.")

      finally:
-         # Clean up GPU memory if applicable
          del inputs
          del outputs
          if config.DEVICE == "cuda":
              torch.cuda.empty_cache()
          gc.collect()
          logger.debug("Cleaned up resources after idea generation.")

@@ -170,19 +216,44 @@ def generate_ideas_sync(prompt: str, max_length: int, num_ideas: int) -> List[st


  def generate_image_sync(prompt: str, negative_prompt: str | None, height: int, width: int, num_inference_steps: int, guidance_scale: float) -> str:
-     """Synchronous function for image generation (run in thread pool)."""
      pipeline = model_cache.get("image_pipeline")
      if not pipeline:
-         logger.error("Attempted to generate image but image pipeline not found in cache.")
          raise RuntimeError("Image pipeline not loaded or available.")

-     logger.debug(f"Generating image for prompt: '{prompt}'")
-     # Note: If using LCM, optimal steps are much lower (e.g., 4-8) and guidance might be 0 or 1.
-     # Consider adding logic here or in the API route to adjust params if LCM is active.
-     # For now, it uses the user-provided parameters.

      try:
-         with torch.no_grad(): # Conserve memory during inference
              result = pipeline(
                  prompt=prompt,
                  negative_prompt=negative_prompt,
@@ -190,18 +261,22 @@ def generate_image_sync(prompt: str, negative_prompt: str | None, height: int, w
                  width=width,
                  num_inference_steps=num_inference_steps,
                  guidance_scale=guidance_scale,
-                 # generator=torch.Generator(device=config.DEVICE).manual_seed(seed) # Optional: for reproducibility
              )
-         image: Image.Image = result.images[0]
-         logger.debug("Image generation complete.")

-         # Encode image to base64
          image_base64 = encode_image_base64(image, format="PNG")
-         logger.debug("Image encoded to base64.")

      finally:
-         # Clean up GPU memory if applicable
-         # pipeline object itself is persistent in cache, don't delete it
          if config.DEVICE == "cuda":
              torch.cuda.empty_cache()
          gc.collect()
@@ -212,61 +287,101 @@ def generate_image_sync(prompt: str, negative_prompt: str | None, height: int, w

  def generate_video_sync(
      image_base64: str,
-     prompt: str | None,
      motion_bucket_id: int,
      noise_aug_strength: float,
      num_frames: int,
      fps: int,
      num_inference_steps: int,
      guidance_scale: float
- ) -> Tuple[str, str]: # Corrected return type hint
-     """Synchronous function for video generation (run in thread pool)."""
      pipeline = model_cache.get("video_pipeline")
      if not pipeline:
-         logger.error("Attempted to generate video but video pipeline not found in cache.")
          raise RuntimeError("Video pipeline not loaded or available.")

      logger.debug("Decoding base64 input image for video generation.")
      try:
          input_image = decode_base64_image(image_base64)
      except Exception as decode_err:
-         logger.error(f"Failed to decode base64 image: {decode_err}", exc_info=True)
-         raise ValueError("Invalid base64 input image.") from decode_err

-     logger.debug(f"Generating video from image, frames={num_frames}, fps={fps}")

      try:
          with torch.no_grad():
              # CPU offload handles device placement if enabled during load_models
-             video_frames_pil = pipeline(
                  input_image,
-                 prompt=prompt, # Zeroscope uses prompt less directly, more for style maybe
                  num_inference_steps=num_inference_steps,
                  num_frames=num_frames,
-                 height=input_image.height, # Match input image size usually
                  width=input_image.width,
                  guidance_scale=guidance_scale,
                  motion_bucket_id=motion_bucket_id,
                  noise_aug_strength=noise_aug_strength
-             ).frames[0] # Output is often nested [[frame1, frame2...]]
-             logger.debug("Video frame generation complete.")
-
-         # video_frames is usually List[PIL.Image], convert to numpy for encoding
          video_frames_np = [np.array(frame) for frame in video_frames_pil]
          logger.debug("Converted video frames to NumPy arrays.")

-         # Encode video to base64
-         video_base64, actual_format = encode_video_base64(video_frames_np, fps=fps, format="MP4") # Request MP4, helper handles fallback
-         logger.debug(f"Video encoded to base64 with format: {actual_format}")

      finally:
-         # Clean up GPU/CPU memory
-         # Offloading handles VRAM well, but ensure general RAM is freed
-         del input_image
-         if 'video_frames_pil' in locals(): del video_frames_pil
-         if 'video_frames_np' in locals(): del video_frames_np
          if config.DEVICE == "cuda":
-             torch.cuda.empty_cache() # Still good practice
          gc.collect()
          logger.debug("Cleaned up resources after video generation.")

  import torch
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
  from PIL import Image
+ import numpy as np
+ import config # Your configuration file (config.py)
+ from utils.helpers import decode_base64_image, encode_image_base64, encode_video_base64 # Your helper functions
  import logging
  import gc # Garbage collector
+ from typing import List, Tuple
  from diffusers import StableDiffusionPipeline, DiffusionPipeline, DPMSolverMultistepScheduler, LCMScheduler
+ # Note: peft is required for load_lora_weights, ensure it's in requirements.txt

+ logger = logging.getLogger(__name__) # Get logger instance

  # --- Global Model Cache ---
+ # Using a dictionary to store loaded models and pipelines allows loading them
+ # only once when the application starts, saving time and resources on subsequent requests.
  model_cache = {}

+ # ==============================================================================
+ # Model Loading Function (Called during Application Startup)
+ # ==============================================================================
+
  def load_models():
+     """
+     Loads all configured machine learning models into the global `model_cache`.
+     This function is called once during the application's startup lifespan event.
+     If any essential model fails to load, it raises an exception to prevent
+     the application from starting in a faulty state.
+     """
+     logger.info("Initiating model loading sequence...")
      try: # <<<--- Start of the MAIN try block for all models ---<<<

+         # --- 1. Text Generation Model ---
          logger.info(f"Loading text model: {config.TEXT_MODEL_NAME}")
+         # Load tokenizer associated with the text model
          model_cache["text_tokenizer"] = AutoTokenizer.from_pretrained(config.TEXT_MODEL_NAME)
+         # Load the sequence-to-sequence language model
+         # IMPORTANT: Add from_tf=True if the primary weights are TensorFlow format (like google/flan-t5-base)
+         model_cache["text_model"] = AutoModelForSeq2SeqLM.from_pretrained(
+             config.TEXT_MODEL_NAME,
+             from_tf=True # Required for google/flan-t5-base which has tf_model.h5
+         ).to(config.DEVICE) # Move model to the configured device (CPU or CUDA)
+         logger.info(f"Text model '{config.TEXT_MODEL_NAME}' loaded successfully (from TF weights if applicable) onto {config.DEVICE}.")
+
+         # --- 2. Image Generation Model (Base Pipeline) ---
+         logger.info(f"Loading base image generation model: {config.IMAGE_MODEL_NAME}")
+         # Load the Stable Diffusion pipeline
          image_pipeline = StableDiffusionPipeline.from_pretrained(
              config.IMAGE_MODEL_NAME,
+             torch_dtype=config.DTYPE # Use configured dtype (float16 on CUDA, float32 on CPU)
          )
+         # Set a default fast scheduler (can be overridden by LCM later)
          image_pipeline.scheduler = DPMSolverMultistepScheduler.from_config(image_pipeline.scheduler.config)
+         logger.info(f"Image base pipeline '{config.IMAGE_MODEL_NAME}' loaded. Default scheduler: DPMSolverMultistepScheduler.")

+         # --- 3. Attempt to Load LCM LoRA (Optional Speedup) ---
+         # Safely check if an LCM LoRA is configured in config.py or environment variables
+         lcm_lora_name = getattr(config, 'IMAGE_LCM_LORA_NAME', None)
          if lcm_lora_name:
+             logger.info(f"Attempting to load LCM LoRA: {lcm_lora_name} (Requires 'peft' library)")
              try:
+                 # Load the LoRA weights into the existing pipeline. Requires 'peft'.
                  image_pipeline.load_lora_weights(lcm_lora_name)
+                 # Optional: Fuse LoRA weights for potential minor speedup. Test impact.
                  # image_pipeline.fuse_lora()
+                 logger.info(f"LCM LoRA '{lcm_lora_name}' loaded successfully.")

+                 # IMPORTANT: Switch to the LCM Scheduler *only if* LoRA loaded successfully
                  image_pipeline.scheduler = LCMScheduler.from_config(image_pipeline.scheduler.config)
+                 logger.info("Switched image pipeline scheduler to LCMScheduler for optimized LCM inference.")

+             except ImportError as peft_import_error:
+                 logger.error(f"Failed to load LCM LoRA '{lcm_lora_name}': 'peft' library not installed. Please add 'peft' to requirements.txt. Falling back to default scheduler. Error: {peft_import_error}")
              except Exception as e:
+                 # Catch other potential errors during LoRA loading (e.g., network issues, invalid LoRA)
                  logger.warning(f"Could not load or apply LCM LoRA '{lcm_lora_name}'. Using default scheduler. Error: {e}", exc_info=True)
          else:
+             logger.info("No IMAGE_LCM_LORA_NAME configured. Using default image scheduler.")
+
+         # --- 4. Image Pipeline Device Placement and Final Setup ---
+         image_pipeline = image_pipeline.to(config.DEVICE) # Move the potentially modified pipeline to the device
+         logger.info(f"Image pipeline finalized and moved to device: {config.DEVICE}")
+
+         # Optional GPU Optimizations (Commented out as current config targets CPU)
+         # if config.DEVICE == "cuda":
+         #     try:
+         #         # Requires: pip install xformers
+         #         # image_pipeline.enable_xformers_memory_efficient_attention()
+         #         # logger.info("Enabled xformers memory efficient attention for image pipeline.")
+         #         pass
+         #     except ImportError:
+         #         logger.warning("xformers not installed or enabled. Consider installing for potential memory savings on GPU.")
+         #         # Fallback option: Attention slicing (less memory saving than xformers)
+         #         # try:
+         #         #     image_pipeline.enable_attention_slicing()
+         #         #     logger.info("Enabled attention slicing for image pipeline.")
+         #         # except Exception as attn_slice_e:
+         #         #     logger.warning(f"Could not enable attention slicing: {attn_slice_e}")
+
+         # Store the finalized image pipeline in the cache
          model_cache["image_pipeline"] = image_pipeline
+         logger.info("Image generation model setup complete and cached.")

+         # --- 5. Video Generation Model ---
+         logger.info(f"Loading video generation model: {config.VIDEO_MODEL_NAME}")
+         # Load the video diffusion pipeline (e.g., Zeroscope)
          video_pipeline = DiffusionPipeline.from_pretrained(
+             config.VIDEO_MODEL_NAME, # Make sure this includes user/org (e.g., "cerspense/zeroscope_v2_576w")
              torch_dtype=config.DTYPE,
+             variant="fp16" if config.DTYPE == torch.float16 else None # Use fp16 variant if on CUDA and available
          )
+         # Set a standard scheduler for the video pipeline
          video_pipeline.scheduler = DPMSolverMultistepScheduler.from_config(video_pipeline.scheduler.config)
+         logger.info(f"Video pipeline '{config.VIDEO_MODEL_NAME}' loaded. Scheduler: DPMSolverMultistepScheduler.")

+         # --- 6. Video Pipeline Memory Optimization (CPU Offload) ---
+         # Enable CPU offloading to save VRAM (if on GPU) or manage RAM usage (on CPU)
+         # This keeps parts of the model on the CPU until needed.
          try:
              video_pipeline.enable_model_cpu_offload()
+             logger.info("Enabled model CPU offload for video pipeline (good for memory saving).")
          except AttributeError:
+             # Fallback if the specific pipeline class doesn't support this method
+             logger.warning(f"Video pipeline class {type(video_pipeline).__name__} may not support enable_model_cpu_offload(). Attempting to move entire model to device {config.DEVICE}.")
              try:
                  video_pipeline = video_pipeline.to(config.DEVICE)
+                 logger.info(f"Video pipeline moved entirely to device: {config.DEVICE}")
              except Exception as move_err:
                  logger.error(f"Failed to move video pipeline to device {config.DEVICE}: {move_err}", exc_info=True)
+                 raise # Re-raise if moving the whole model also fails

+         # Store the video pipeline in the cache
          model_cache["video_pipeline"] = video_pipeline
+         logger.info("Video generation model setup complete and cached.")

+         # --- Success ---
+         logger.info("All configured models loaded successfully.") # Only logs if all steps above succeed

      except Exception as e: # <<<--- Catches errors from ANY model loading step ---<<<
+         logger.error(f"FATAL: Error occurred during model loading sequence: {e}", exc_info=True)
+         # Re-raise the exception. This will be caught by the application's
+         # lifespan manager, which should prevent the server from starting properly.
          raise


+ # ==============================================================================
+ # Synchronous Generation Functions (Run in Thread Pool)
+ # ==============================================================================
+
  def generate_ideas_sync(prompt: str, max_length: int, num_ideas: int) -> List[str]:
+     """
+     Synchronous function to generate text ideas using the loaded language model.
+     Designed to be run in a thread pool to avoid blocking the main async event loop.
+
+     Args:
+         prompt: The input prompt or instruction for idea generation.
+         max_length: Maximum number of tokens for the generated output.
+         num_ideas: The desired number of distinct idea sequences to generate.
+
+     Returns:
+         A list of generated idea strings.
+
+     Raises:
+         RuntimeError: If the text model or tokenizer is not loaded in the cache.
+     """
      tokenizer = model_cache.get("text_tokenizer")
      model = model_cache.get("text_model")
      if not tokenizer or not model:
+         logger.error("Execution failure: Text model or tokenizer not found in cache during idea generation.")
          raise RuntimeError("Text model not loaded or available.")

+     logger.debug(f"Generating {num_ideas} ideas for prompt: '{prompt}' with max_length={max_length}")
+     input_text = prompt # Use the direct prompt as input
+
+     # Variables to hold intermediate results for cleanup
+     inputs = None
+     outputs = None
+     ideas = []

      try:
+         # Prepare model inputs
+         inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=512).to(config.DEVICE) # Limit input length

+         # Perform inference without calculating gradients to save memory
+         with torch.no_grad():
              outputs = model.generate(
                  **inputs,
                  max_length=max_length,
                  num_return_sequences=num_ideas,
+                 do_sample=True, # Enable sampling for diversity
+                 temperature=0.7, # Control randomness (lower = more focused)
+                 top_k=50, # Consider top k words
+                 top_p=0.95, # Use nucleus sampling
+                 no_repeat_ngram_size=2 # Prevent short repetitive phrases
              )

+         # Decode the generated token sequences into strings
          ideas = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
+         logger.debug(f"Successfully generated {len(ideas)} raw idea(s).")

      finally:
+         # --- Resource Cleanup ---
+         # Explicitly delete large tensor variables to help GC
          del inputs
          del outputs
+         # Clear CUDA cache if running on GPU
          if config.DEVICE == "cuda":
              torch.cuda.empty_cache()
+         # Trigger garbage collection
          gc.collect()
          logger.debug("Cleaned up resources after idea generation.")


  def generate_image_sync(prompt: str, negative_prompt: str | None, height: int, width: int, num_inference_steps: int, guidance_scale: float) -> str:
+     """
+     Synchronous function to generate an image using the loaded diffusion pipeline.
+     Designed to be run in a thread pool.
+
+     Args:
+         prompt: The text prompt describing the desired image.
+         negative_prompt: Text prompt describing concepts to avoid.
+         height: Desired image height in pixels.
+         width: Desired image width in pixels.
+         num_inference_steps: Number of diffusion steps (more steps = more detail, slower).
+                              NOTE: If using LCM, this should be very low (e.g., 4-8).
+         guidance_scale: How strongly the prompt guides generation (higher = stricter adherence).
+                         NOTE: If using LCM, this should be low (e.g., 0.0-1.5).
+
+     Returns:
+         A base64 encoded string representing the generated PNG image.
+
+     Raises:
+         RuntimeError: If the image pipeline is not loaded in the cache.
+     """
      pipeline = model_cache.get("image_pipeline")
      if not pipeline:
+         logger.error("Execution failure: Image pipeline not found in cache during image generation.")
          raise RuntimeError("Image pipeline not loaded or available.")

+     # Log parameters, potentially adjust for LCM if detected
+     lcm_active = isinstance(pipeline.scheduler, LCMScheduler)
+     logger.debug(f"Generating image for prompt: '{prompt}' (LCM Active: {lcm_active})")
+     logger.debug(f"Params: steps={num_inference_steps}, guidance={guidance_scale}, size={width}x{height}")
+     if lcm_active and (num_inference_steps > 10 or guidance_scale > 2.0):
+         logger.warning(f"LCM scheduler is active, but parameters (steps={num_inference_steps}, guidance={guidance_scale}) seem high. Optimal LCM uses low steps (4-8) and low guidance (0-1.5).")
+
+     image = None
+     image_base64 = None

      try:
+         # Perform inference without calculating gradients
+         with torch.no_grad():
              result = pipeline(
                  prompt=prompt,
                  negative_prompt=negative_prompt,

                  width=width,
                  num_inference_steps=num_inference_steps,
                  guidance_scale=guidance_scale,
+                 # Optional: Add generator for reproducibility
+                 # generator=torch.Generator(device=config.DEVICE).manual_seed(some_seed)
              )
+         # Extract the first image from the result
+         image = result.images[0]
+         logger.debug("Image generation inference complete.")

+         # Encode the generated PIL image to a base64 string
          image_base64 = encode_image_base64(image, format="PNG")
+         logger.debug("Image encoded to base64 PNG format.")

      finally:
+         # --- Resource Cleanup ---
+         # The pipeline object itself remains in the cache
+         del image # Delete the generated PIL image object
+         # Clear CUDA cache if applicable
          if config.DEVICE == "cuda":
              torch.cuda.empty_cache()
          gc.collect()


  def generate_video_sync(
      image_base64: str,
+     prompt: str | None, # Optional prompt for video models that use it
      motion_bucket_id: int,
      noise_aug_strength: float,
      num_frames: int,
      fps: int,
      num_inference_steps: int,
      guidance_scale: float
+ ) -> Tuple[str, str]:
+     """
+     Synchronous function to generate a video clip from an input image using the loaded video pipeline.
+     Designed to be run in a thread pool.
+
+     Args:
+         image_base64: Base64 encoded string of the input image.
+         prompt: Optional text prompt (model-dependent usage).
+         motion_bucket_id: Control parameter for motion amount (model-specific, e.g., Zeroscope).
+         noise_aug_strength: Amount of noise added to input image (model-specific).
+         num_frames: Number of frames desired in the output video.
+         fps: Frames per second for the output video encoding.
+         num_inference_steps: Number of diffusion steps for video generation.
+         guidance_scale: Guidance scale for video generation.
+
+     Returns:
+         A tuple containing:
+         - video_base64 (str): Base64 encoded string of the generated video (MP4 or GIF).
+         - actual_format (str): The actual format the video was encoded in ("MP4" or "GIF").
+
+     Raises:
+         RuntimeError: If the video pipeline is not loaded in the cache.
+         ValueError: If the input `image_base64` is invalid.
+     """
      pipeline = model_cache.get("video_pipeline")
      if not pipeline:
+         logger.error("Execution failure: Video pipeline not found in cache during video generation.")
          raise RuntimeError("Video pipeline not loaded or available.")

      logger.debug("Decoding base64 input image for video generation.")
      try:
+         # Decode the input image
          input_image = decode_base64_image(image_base64)
+         logger.debug(f"Decoded input image: size={input_image.size}, mode={input_image.mode}")
      except Exception as decode_err:
+         logger.error(f"Failed to decode base64 input image: {decode_err}", exc_info=True)
+         # Raise a specific error that can be caught by the API route
+         raise ValueError("Invalid base64 input image provided.") from decode_err

+     logger.debug(f"Generating video: frames={num_frames}, fps={fps}, steps={num_inference_steps}")
+
+     # Variables for cleanup
+     video_frames_pil = None
+     video_frames_np = None
+     video_base64 = None
+     actual_format = "N/A"

      try:
+         # Perform inference without gradients
          with torch.no_grad():
              # CPU offload handles device placement if enabled during load_models
+             pipeline_output = pipeline(
                  input_image,
+                 prompt=prompt,
                  num_inference_steps=num_inference_steps,
                  num_frames=num_frames,
+                 height=input_image.height, # Use input image dimensions
                  width=input_image.width,
                  guidance_scale=guidance_scale,
+                 # Model-specific parameters (like for Zeroscope)
                  motion_bucket_id=motion_bucket_id,
                  noise_aug_strength=noise_aug_strength
+             )
+         # Output format can vary; often `.frames` is a list containing one list of PIL images
+         if hasattr(pipeline_output, 'frames') and isinstance(pipeline_output.frames, list) and len(pipeline_output.frames) > 0:
+             video_frames_pil = pipeline_output.frames[0] # Assuming the structure [[frame1, frame2...]]
+         else:
+             # Handle potential variations in output structure if needed
+             logger.error(f"Unexpected video pipeline output structure: {type(pipeline_output)}")
+             raise RuntimeError("Video generation produced unexpected output format.")
+         logger.debug(f"Video frame generation complete ({len(video_frames_pil)} frames).")
+
+         # Convert PIL frames to NumPy arrays for video encoding
          video_frames_np = [np.array(frame) for frame in video_frames_pil]
          logger.debug("Converted video frames to NumPy arrays.")

+         # Encode the NumPy frames into a base64 video string (tries MP4, falls back to GIF)
+         video_base64, actual_format = encode_video_base64(video_frames_np, fps=fps, format="MP4")
+         logger.debug(f"Video encoded to base64 with actual format: {actual_format}")

      finally:
+         # --- Resource Cleanup ---
+         del input_image # Delete decoded input image
+         del video_frames_pil # Delete list of PIL frames
+         del video_frames_np # Delete list of numpy frames
+         # Clear CUDA cache if applicable
          if config.DEVICE == "cuda":
+             torch.cuda.empty_cache()
          gc.collect()
          logger.debug("Cleaned up resources after video generation.")
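
The docstrings describe every *_sync function as "designed to be run in a thread pool" and load_models() as a startup hook. A sketch of a caller under that contract is shown below; it assumes a FastAPI application (suggested by the diff's "lifespan event" comments), and the app module name, route path, and default parameter values are illustrative.

# app.py (hypothetical caller; not part of this commit)
import asyncio
from contextlib import asynccontextmanager

from fastapi import FastAPI

from services import generation


@asynccontextmanager
async def lifespan(app: FastAPI):
    # load_models() raises on any failure, so a broken model set
    # aborts startup instead of serving requests without models.
    generation.load_models()
    yield


app = FastAPI(lifespan=lifespan)


@app.post("/ideas")
async def ideas(prompt: str, max_length: int = 64, num_ideas: int = 3) -> list[str]:
    # asyncio.to_thread dispatches the blocking generate call to the default
    # thread pool, keeping the event loop free while the model runs.
    return await asyncio.to_thread(
        generation.generate_ideas_sync, prompt, max_length, num_ideas
    )

generate_image_sync and generate_video_sync can be dispatched the same way; only the argument lists differ.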