rajux75 committed
Commit 45852d0 · verified · 1 Parent(s): 5f7ce0f

Update services/generation.py

Files changed (1)
  1. services/generation.py +153 -87
services/generation.py CHANGED
@@ -2,17 +2,17 @@
  import torch
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
  from PIL import Image
  import config
  from utils.helpers import decode_base64_image, encode_image_base64, encode_video_base64
  import logging
  import gc # Garbage collector
- from typing import List
- from diffusers import StableDiffusionPipeline, DiffusionPipeline, DPMSolverMultistepScheduler, LCMScheduler # Import LCMScheduler
- from peft import PeftConfig # Import PeftConfig (if needed, usually handled by load_lora_weights)

  logger = logging.getLogger(__name__)

-
  # --- Global Model Cache ---
  # Use a dictionary to hold loaded models and tokenizers
  # This allows loading them only once when the app starts.
@@ -21,53 +21,72 @@ model_cache = {}
  def load_models():
      """Loads all models into the cache. Called at application startup."""
      logger.info("Loading models...")
-     try:
-         # Text Generation Model
          logger.info(f"Loading text model: {config.TEXT_MODEL_NAME}")
          model_cache["text_tokenizer"] = AutoTokenizer.from_pretrained(config.TEXT_MODEL_NAME)
          model_cache["text_model"] = AutoModelForSeq2SeqLM.from_pretrained(config.TEXT_MODEL_NAME).to(config.DEVICE)
          logger.info("Text model loaded.")

-         # --- Image Generation Model ---
-         logger.info(f"Loading image model: {config.IMAGE_MODEL_NAME}")
-         image_pipeline = StableDiffusionPipeline.from_pretrained(
-             config.IMAGE_MODEL_NAME,
-             torch_dtype=config.DTYPE
-         )
-
-         # --- Load LCM LoRA ---
-         try:
-             logger.info(f"Loading LCM LoRA: {config.IMAGE_LCM_LORA_NAME}")
-             # Load LoRA weights directly into the pipeline
-             image_pipeline.load_lora_weights(config.IMAGE_LCM_LORA_NAME)
-             # Fuse LoRA for potential speedup (optional, test impact)
-             # image_pipeline.fuse_lora()
-             logger.info("LCM LoRA loaded successfully.")
-
-             # --- IMPORTANT: Set LCM Scheduler ---
-             image_pipeline.scheduler = LCMScheduler.from_config(image_pipeline.scheduler.config)
-             logger.info("Switched scheduler to LCMScheduler.")
-
-         except Exception as e:
-             logger.warning(f"Could not load or apply LCM LoRA '{config.IMAGE_LCM_LORA_NAME}'. Falling back to base model scheduler. Error: {e}", exc_info=True)
-             # Fallback to a standard fast scheduler if LCM fails
              image_pipeline.scheduler = DPMSolverMultistepScheduler.from_config(image_pipeline.scheduler.config)

-         image_pipeline = image_pipeline.to(config.DEVICE)
-         if config.DEVICE == "cuda":
-             try:
-                 # image_pipeline.enable_xformers_memory_efficient_attention()
-                 pass
-             except ImportError:
-                 logger.warning("xformers not installed...")
-             # image_pipeline.enable_attention_slicing()
-
-         model_cache["image_pipeline"] = image_pipeline
-         logger.info("Image model setup complete.")
-
-         # Video Generation Model
          logger.info(f"Loading video model: {config.VIDEO_MODEL_NAME}")
          video_pipeline = DiffusionPipeline.from_pretrained(
              config.VIDEO_MODEL_NAME,
@@ -75,19 +94,36 @@ def load_models():
              variant="fp16" if config.DTYPE == torch.float16 else None # Zeroscope often has fp16 variants
          )
          video_pipeline.scheduler = DPMSolverMultistepScheduler.from_config(video_pipeline.scheduler.config)
-         video_pipeline.enable_model_cpu_offload() # Crucial for low VRAM environments like Spaces CPU/T4
-         # video_pipeline = video_pipeline.to(config.DEVICE) # CPU offload handles device placement
-
          model_cache["video_pipeline"] = video_pipeline
-         logger.info("Video model loaded.")

-     except Exception as e:
-         logger.error(f"Error loading models: {e}", exc_info=True)
-         # Depending on policy, you might want to raise the exception
-         # or allow the app to start with missing models (endpoints will fail)
-         raise # Reraise to prevent app start if essential models fail

-     logger.info("All models loaded successfully.")

  def generate_ideas_sync(prompt: str, max_length: int, num_ideas: int) -> List[str]:
@@ -95,33 +131,41 @@ def generate_ideas_sync(prompt: str, max_length: int, num_ideas: int) -> List[st
      tokenizer = model_cache.get("text_tokenizer")
      model = model_cache.get("text_model")
      if not tokenizer or not model:
-         raise RuntimeError("Text model not loaded.")

-     # Adjust prompt slightly for better instruction following if needed (e.g., for Flan-T5)
-     # input_text = f"Generate {num_ideas} content ideas about: {prompt}"
      input_text = prompt # Keep original prompt based on request model

-     inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=512).to(config.DEVICE) # Max input length for model
-
-     # Generation parameters
-     outputs = model.generate(
-         **inputs,
-         max_length=max_length,
-         num_return_sequences=num_ideas,
-         do_sample=True, # Use sampling for more diverse ideas
-         temperature=0.8,
-         top_k=50,
-         top_p=0.95,
-         no_repeat_ngram_size=2 # Avoid repetitive phrases
-     )
-
-     ideas = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
-     # Clean up GPU memory if applicable
-     del inputs
-     del outputs
-     if config.DEVICE == "cuda":
-         torch.cuda.empty_cache()
-     gc.collect()
      return ideas

@@ -129,7 +173,13 @@ def generate_image_sync(prompt: str, negative_prompt: str | None, height: int, w
      """Synchronous function for image generation (run in thread pool)."""
      pipeline = model_cache.get("image_pipeline")
      if not pipeline:
-         raise RuntimeError("Image pipeline not loaded.")

      try:
          with torch.no_grad(): # Conserve memory during inference
@@ -143,15 +193,19 @@ def generate_image_sync(prompt: str, negative_prompt: str | None, height: int, w
                  # generator=torch.Generator(device=config.DEVICE).manual_seed(seed) # Optional: for reproducibility
              )
              image: Image.Image = result.images[0]

              # Encode image to base64
              image_base64 = encode_image_base64(image, format="PNG")

      finally:
          # Clean up GPU memory if applicable
          if config.DEVICE == "cuda":
              torch.cuda.empty_cache()
          gc.collect()

      return image_base64

@@ -165,18 +219,26 @@ def generate_video_sync(
      fps: int,
      num_inference_steps: int,
      guidance_scale: float
- ) -> tuple[str, str]:
      """Synchronous function for video generation (run in thread pool)."""
      pipeline = model_cache.get("video_pipeline")
      if not pipeline:
-         raise RuntimeError("Video pipeline not loaded.")

-     input_image = decode_base64_image(image_base64)

      try:
          with torch.no_grad():
-             # CPU offload handles device placement, no need for explicit .to(config.DEVICE)
-             video_frames = pipeline(
                  input_image,
                  prompt=prompt, # Zeroscope uses prompt less directly, more for style maybe
                  num_inference_steps=num_inference_steps,
@@ -187,21 +249,25 @@ def generate_video_sync(
                  motion_bucket_id=motion_bucket_id,
                  noise_aug_strength=noise_aug_strength
              ).frames[0] # Output is often nested [[frame1, frame2...]]

          # video_frames is usually List[PIL.Image], convert to numpy for encoding
-         video_frames_np = [np.array(frame) for frame in video_frames]

          # Encode video to base64
          video_base64, actual_format = encode_video_base64(video_frames_np, fps=fps, format="MP4") # Request MP4, helper handles fallback

      finally:
          # Clean up GPU/CPU memory
          # Offloading handles VRAM well, but ensure general RAM is freed
          del input_image
-         del video_frames
-         del video_frames_np
          if config.DEVICE == "cuda":
              torch.cuda.empty_cache() # Still good practice
          gc.collect()

      return video_base64, actual_format
 
services/generation.py (updated file after this commit):

  import torch
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
  from PIL import Image
+ import numpy as np # Added import for numpy array conversion later
  import config
  from utils.helpers import decode_base64_image, encode_image_base64, encode_video_base64
  import logging
  import gc # Garbage collector
+ from typing import List, Tuple # Added Tuple for generate_video_sync return type hint
+ from diffusers import StableDiffusionPipeline, DiffusionPipeline, DPMSolverMultistepScheduler, LCMScheduler
+ # from peft import PeftConfig # Usually not needed directly if using load_lora_weights

  logger = logging.getLogger(__name__)

  # --- Global Model Cache ---
  # Use a dictionary to hold loaded models and tokenizers
  # This allows loading them only once when the app starts.
 
  def load_models():
      """Loads all models into the cache. Called at application startup."""
      logger.info("Loading models...")
+     try: # <<<--- Start of the MAIN try block for all models ---<<<
+
+         # --- Text Generation Model ---
          logger.info(f"Loading text model: {config.TEXT_MODEL_NAME}")
          model_cache["text_tokenizer"] = AutoTokenizer.from_pretrained(config.TEXT_MODEL_NAME)
          model_cache["text_model"] = AutoModelForSeq2SeqLM.from_pretrained(config.TEXT_MODEL_NAME).to(config.DEVICE)
          logger.info("Text model loaded.")

+         # --- Image Generation Model (Base) ---
+         logger.info(f"Loading image model: {config.IMAGE_MODEL_NAME}")
+         image_pipeline = StableDiffusionPipeline.from_pretrained(
+             config.IMAGE_MODEL_NAME,
+             torch_dtype=config.DTYPE
+         )
+         # Default scheduler (will be potentially overridden by LCM)
          image_pipeline.scheduler = DPMSolverMultistepScheduler.from_config(image_pipeline.scheduler.config)
+         logger.info("Image base pipeline loaded. Default scheduler: DPMSolverMultistepScheduler.")
+
+         # --- Attempt to Load LCM LoRA (Optional Speedup) ---
+         # Check if IMAGE_LCM_LORA_NAME is defined and not empty in config
+         lcm_lora_name = getattr(config, 'IMAGE_LCM_LORA_NAME', None) # Safely get LoRA name
+         if lcm_lora_name:
+             try:
+                 logger.info(f"Attempting to load LCM LoRA: {lcm_lora_name}")
+                 # Load LoRA weights directly into the pipeline
+                 image_pipeline.load_lora_weights(lcm_lora_name)
+                 # Fuse LoRA for potential speedup (optional, test impact)
+                 # image_pipeline.fuse_lora()
+                 logger.info("LCM LoRA loaded successfully.")
+
+                 # IMPORTANT: Set LCM Scheduler *only if* LoRA loaded successfully
+                 image_pipeline.scheduler = LCMScheduler.from_config(image_pipeline.scheduler.config)
+                 logger.info("Switched scheduler to LCMScheduler.")
+
+             except Exception as e:
+                 logger.warning(f"Could not load or apply LCM LoRA '{lcm_lora_name}'. Using default scheduler. Error: {e}", exc_info=True)
+                 # Scheduler already set to DPMSolverMultistepScheduler above, so no action needed here
+         else:
+             logger.info("No IMAGE_LCM_LORA_NAME configured in environment/config. Using default scheduler.")
+
+         # --- Image Pipeline Device Placement and Optimizations ---
+         image_pipeline = image_pipeline.to(config.DEVICE)
+         logger.info(f"Image pipeline moved to device: {config.DEVICE}")

+         if config.DEVICE == "cuda":
+             # Optional: Enable memory efficient attention mechanisms if GPU available and libs installed
+             try:
+                 # Requires: pip install xformers
+                 # image_pipeline.enable_xformers_memory_efficient_attention()
+                 # logger.info("Enabled xformers memory efficient attention.")
+                 pass # Keep commented out if xformers not installed/intended
+             except ImportError:
+                 logger.warning("xformers not installed or enabled. Consider installing for potential memory savings on GPU.")
+                 # Fallback option if xformers is not available
+                 # try:
+                 #     image_pipeline.enable_attention_slicing()
+                 #     logger.info("Enabled attention slicing.")
+                 # except Exception as attn_slice_e:
+                 #     logger.warning(f"Could not enable attention slicing: {attn_slice_e}")
+
+         # --- Store Image Pipeline in Cache ---
+         model_cache["image_pipeline"] = image_pipeline
+         logger.info("Image model setup complete and cached.")
+
+         # --- Video Generation Model ---
          logger.info(f"Loading video model: {config.VIDEO_MODEL_NAME}")
          video_pipeline = DiffusionPipeline.from_pretrained(
              config.VIDEO_MODEL_NAME,

              variant="fp16" if config.DTYPE == torch.float16 else None # Zeroscope often has fp16 variants
          )
          video_pipeline.scheduler = DPMSolverMultistepScheduler.from_config(video_pipeline.scheduler.config)
+         logger.info("Video pipeline loaded. Scheduler: DPMSolverMultistepScheduler.")
+
+         # Enable CPU offloading *before* potentially moving parts to GPU if not offloading everything
+         # This is crucial for fitting larger models in limited VRAM/RAM.
+         try:
+             video_pipeline.enable_model_cpu_offload()
+             logger.info("Enabled model CPU offload for video pipeline.")
+         except AttributeError:
+             logger.warning("Video pipeline class may not support enable_model_cpu_offload(). Attempting to move entire model to device.")
+             # Fallback if offload method isn't available on this specific pipeline class
+             try:
+                 video_pipeline = video_pipeline.to(config.DEVICE)
+                 logger.info(f"Video pipeline moved to device: {config.DEVICE}")
+             except Exception as move_err:
+                 logger.error(f"Failed to move video pipeline to device {config.DEVICE}: {move_err}", exc_info=True)
+                 # Decide if you want to raise here or let it fail later
+                 # raise
+
+         # Store video pipeline in cache
          model_cache["video_pipeline"] = video_pipeline
+         logger.info("Video model setup complete and cached.")

+         # --- Success Message ---
+         logger.info("All configured models loaded successfully.") # Runs only if all steps above succeed

+     except Exception as e: # <<<--- Catches errors from ANY model loading step ---<<<
+         logger.error(f"FATAL: Error loading one or more models during startup: {e}", exc_info=True)
+         # Re-raise the exception to prevent the application from starting
+         # in a state where essential models are missing.
+         raise

  def generate_ideas_sync(prompt: str, max_length: int, num_ideas: int) -> List[str]:
 
      tokenizer = model_cache.get("text_tokenizer")
      model = model_cache.get("text_model")
      if not tokenizer or not model:
+         # This should ideally not happen if load_models raises on failure
+         logger.error("Attempted to generate ideas but text model/tokenizer not found in cache.")
+         raise RuntimeError("Text model not loaded or available.")

+     logger.debug(f"Generating ideas for prompt: '{prompt}'")

      input_text = prompt # Keep original prompt based on request model

+     try:
+         inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=512).to(config.DEVICE) # Max input length for model
+
+         # Generation parameters
+         with torch.no_grad(): # Ensure no gradients are computed
+             outputs = model.generate(
+                 **inputs,
+                 max_length=max_length,
+                 num_return_sequences=num_ideas,
+                 do_sample=True, # Use sampling for more diverse ideas
+                 temperature=0.8,
+                 top_k=50,
+                 top_p=0.95,
+                 no_repeat_ngram_size=2 # Avoid repetitive phrases
+             )
+
+         ideas = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
+         logger.debug(f"Generated {len(ideas)} ideas.")
+
+     finally:
+         # Clean up GPU memory if applicable
+         del inputs
+         del outputs
+         if config.DEVICE == "cuda":
+             torch.cuda.empty_cache()
+         gc.collect()
+         logger.debug("Cleaned up resources after idea generation.")
+
      return ideas


      """Synchronous function for image generation (run in thread pool)."""
      pipeline = model_cache.get("image_pipeline")
      if not pipeline:
+         logger.error("Attempted to generate image but image pipeline not found in cache.")
+         raise RuntimeError("Image pipeline not loaded or available.")
+
+     logger.debug(f"Generating image for prompt: '{prompt}'")
+     # Note: If using LCM, optimal steps are much lower (e.g., 4-8) and guidance might be 0 or 1.
+     # Consider adding logic here or in the API route to adjust params if LCM is active.
+     # For now, it uses the user-provided parameters.

      try:
          with torch.no_grad(): # Conserve memory during inference

                  # generator=torch.Generator(device=config.DEVICE).manual_seed(seed) # Optional: for reproducibility
              )
              image: Image.Image = result.images[0]
+             logger.debug("Image generation complete.")

              # Encode image to base64
              image_base64 = encode_image_base64(image, format="PNG")
+             logger.debug("Image encoded to base64.")

      finally:
          # Clean up GPU memory if applicable
+         # pipeline object itself is persistent in cache, don't delete it
          if config.DEVICE == "cuda":
              torch.cuda.empty_cache()
          gc.collect()
+         logger.debug("Cleaned up resources after image generation.")

      return image_base64

 
      fps: int,
      num_inference_steps: int,
      guidance_scale: float
+ ) -> Tuple[str, str]: # Corrected return type hint
      """Synchronous function for video generation (run in thread pool)."""
      pipeline = model_cache.get("video_pipeline")
      if not pipeline:
+         logger.error("Attempted to generate video but video pipeline not found in cache.")
+         raise RuntimeError("Video pipeline not loaded or available.")
+
+     logger.debug("Decoding base64 input image for video generation.")
+     try:
+         input_image = decode_base64_image(image_base64)
+     except Exception as decode_err:
+         logger.error(f"Failed to decode base64 image: {decode_err}", exc_info=True)
+         raise ValueError("Invalid base64 input image.") from decode_err

+     logger.debug(f"Generating video from image, frames={num_frames}, fps={fps}")

      try:
          with torch.no_grad():
+             # CPU offload handles device placement if enabled during load_models
+             video_frames_pil = pipeline(
                  input_image,
                  prompt=prompt, # Zeroscope uses prompt less directly, more for style maybe
                  num_inference_steps=num_inference_steps,

                  motion_bucket_id=motion_bucket_id,
                  noise_aug_strength=noise_aug_strength
              ).frames[0] # Output is often nested [[frame1, frame2...]]
+             logger.debug("Video frame generation complete.")

          # video_frames is usually List[PIL.Image], convert to numpy for encoding
+         video_frames_np = [np.array(frame) for frame in video_frames_pil]
+         logger.debug("Converted video frames to NumPy arrays.")

          # Encode video to base64
          video_base64, actual_format = encode_video_base64(video_frames_np, fps=fps, format="MP4") # Request MP4, helper handles fallback
+         logger.debug(f"Video encoded to base64 with format: {actual_format}")

      finally:
          # Clean up GPU/CPU memory
          # Offloading handles VRAM well, but ensure general RAM is freed
          del input_image
+         if 'video_frames_pil' in locals(): del video_frames_pil
+         if 'video_frames_np' in locals(): del video_frames_np
          if config.DEVICE == "cuda":
              torch.cuda.empty_cache() # Still good practice
          gc.collect()
+         logger.debug("Cleaned up resources after video generation.")

      return video_base64, actual_format
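
Review note: load_models() reads TEXT_MODEL_NAME, IMAGE_MODEL_NAME, VIDEO_MODEL_NAME, DEVICE, DTYPE and, via getattr(), the optional IMAGE_LCM_LORA_NAME from the config module, which this commit does not show. A minimal sketch of what that module might contain; the attribute names match what generation.py reads, but every concrete model ID below is an assumption, not part of this commit:

    # config.py -- hypothetical sketch; only the attribute names are taken from the diff.
    import torch

    TEXT_MODEL_NAME = "google/flan-t5-base"                     # assumed seq2seq model (diff comments mention Flan-T5)
    IMAGE_MODEL_NAME = "runwayml/stable-diffusion-v1-5"         # assumed SD base model
    IMAGE_LCM_LORA_NAME = "latent-consistency/lcm-lora-sdv1-5"  # optional; getattr() tolerates absence
    VIDEO_MODEL_NAME = "cerspense/zeroscope_v2_576w"            # assumed; diff comments mention Zeroscope

    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32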
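
The docstrings say the *_sync functions are meant to run in a thread pool and that load_models() is "called at application startup". A hypothetical wiring under the assumption of a FastAPI app; the framework, route shape, and defaults are not named anywhere in the commit:

    # Hypothetical startup/endpoint wiring -- a sketch, not part of the commit.
    import asyncio
    from fastapi import FastAPI
    from services import generation

    app = FastAPI()

    @app.on_event("startup")
    def load() -> None:
        # load_models() re-raises on failure, so startup aborts rather than
        # leaving the app running with missing models.
        generation.load_models()

    @app.get("/ideas")
    async def ideas(prompt: str, max_length: int = 64, num_ideas: int = 3):
        # Run the blocking, GPU-bound call off the event loop in a worker thread.
        results = await asyncio.to_thread(
            generation.generate_ideas_sync, prompt, max_length, num_ideas
        )
        return {"ideas": results}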
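
The new comment in generate_image_sync notes that with the LCM scheduler active, optimal step counts drop sharply (e.g. 4-8) and guidance should be near 0-1, and suggests adjusting parameters "here or in the API route". One possible sketch of that adjustment; the helper name and the clamp values are illustrative assumptions, not part of the commit:

    # Hypothetical parameter clamp for LCM mode -- illustrative only.
    from diffusers import LCMScheduler

    def adjust_for_lcm(pipeline, num_inference_steps: int, guidance_scale: float):
        """Clamp steps/guidance to LCM-friendly values when the LCM scheduler is active."""
        if isinstance(pipeline.scheduler, LCMScheduler):
            # LCM-distilled weights converge in very few steps and need little/no CFG.
            return min(num_inference_steps, 8), min(guidance_scale, 1.0)
        return num_inference_steps, guidance_scale

Calling this at the top of generate_image_sync (e.g. steps, cfg = adjust_for_lcm(pipeline, num_inference_steps, guidance_scale)) would make user-supplied parameters safe in both scheduler modes.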