jbilcke-hf committed · Commit 4e8d40c (verified) · Parent: 80253b4

Update handler.py

Files changed (1): handler.py (+172 -78)
handler.py CHANGED
@@ -1,12 +1,14 @@
 from dataclasses import dataclass
 from typing import Dict, Any, Optional
 import base64
+import asyncio
 import logging
 import random
 import traceback
 import torch
 from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
 from varnish import Varnish
+from varnish.utils import is_truthy, process_input_image
 
 from enhance_a_video import enable_enhance, inject_enhance_for_hunyuanvideo, set_enhance_weight
 from teacache import enable_teacache, disable_teacache
@@ -15,6 +17,9 @@ from teacache import enable_teacache, disable_teacache
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
+# Check environment variable for pipeline support
+support_image_prompt = is_truthy(os.getenv("SUPPORT_INPUT_IMAGE_PROMPT"))
+
 @dataclass
 class GenerationConfig:
     """Configuration for video generation"""
@@ -51,7 +56,12 @@ class GenerationConfig:
 
     # Enhance-A-Video settings
     enable_enhance_a_video: bool = True
-    enhance_a_video_weight: float = 4.0
+    enhance_a_video_weight: float = 5.0
+
+    # LoRA settings
+    lora_model_name: str = ""         # HuggingFace repo ID or path to LoRA model
+    lora_model_weight_file: str = ""  # Specific weight file to load from the LoRA model
+    lora_model_trigger: str = ""      # Optional trigger word to prepend to the prompt
 
     def validate_and_adjust(self) -> 'GenerationConfig':
         """Validate and adjust parameters"""
@@ -83,25 +93,37 @@ class EndpointHandler:
             subfolder="transformer",
             torch_dtype=torch.bfloat16
         )
-        inject_enhance_for_hunyuanvideo(transformer)
-
-        # Initialize HunyuanVideo pipeline with the enhanced transformer
-        self.pipeline = HunyuanVideoPipeline.from_pretrained(
-            path,
-            transformer=transformer,
-            torch_dtype=torch.float16,
-        ).to(self.device)
-
-
-        # Initialize text encoders in float16
-        self.pipeline.text_encoder = self.pipeline.text_encoder.half()
-        self.pipeline.text_encoder_2 = self.pipeline.text_encoder_2.half()
-
-        # Initialize transformer in bfloat16
-        self.pipeline.transformer = self.pipeline.transformer.to(torch.bfloat16)
-
-        # Initialize VAE in float16
-        self.pipeline.vae = self.pipeline.vae.half()
+
+        if support_image_prompt:
+            # Initialize image-to-video pipeline
+            self.image_to_video = HunyuanImageToVideoPipeline.from_pretrained(
+                path,
+                transformer=transformer,
+                torch_dtype=torch.float16,
+            ).to(self.device)
+
+            # Initialize components in appropriate precision
+            self.image_to_video.text_encoder = self.image_to_video.text_encoder.half()
+            self.image_to_video.text_encoder_2 = self.image_to_video.text_encoder_2.half()
+            self.image_to_video.transformer = self.image_to_video.transformer.to(torch.bfloat16)
+            self.image_to_video.vae = self.image_to_video.vae.half()
+        else:
+            # Initialize text-to-video pipeline
+            self.text_to_video = HunyuanVideoPipeline.from_pretrained(
+                path,
+                transformer=transformer,
+                torch_dtype=torch.float16,
+            ).to(self.device)
+
+            # Initialize components in appropriate precision
+            self.text_to_video.text_encoder = self.text_to_video.text_encoder.half()
+            self.text_to_video.text_encoder_2 = self.text_to_video.text_encoder_2.half()
+            self.text_to_video.transformer = self.text_to_video.transformer.to(torch.bfloat16)
+            self.text_to_video.vae = self.text_to_video.vae.half()
+
+
+        # Initialize LoRA tracking
+        self._current_lora_model = None
 
         # Initialize Varnish for post-processing
         self.varnish = Varnish(
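
Note: the image branch instantiates HunyuanImageToVideoPipeline, but the import hunk at the top of this diff still only imports HunyuanVideoPipeline and HunyuanVideoTransformer3DModel, so that branch would raise a NameError as committed (unless the class is imported elsewhere in the file). The fix would be along these lines, assuming the class is exported under this exact name (the source module is not confirmed by the diff):

    import os  # also needed by the support_image_prompt check at module level
    from diffusers import (
        HunyuanVideoPipeline,
        HunyuanVideoTransformer3DModel,
        HunyuanImageToVideoPipeline,  # assumption: exact export name/module unverified
    )
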
@@ -109,6 +131,56 @@ class EndpointHandler:
             model_base_dir="/repository/varnish"
         )
 
+    async def process_frames(
+        self,
+        frames: torch.Tensor,
+        config: GenerationConfig
+    ) -> tuple[str, dict]:
+        """Post-process generated frames using Varnish
+
+        Args:
+            frames: Generated video frames tensor
+            config: Generation configuration
+
+        Returns:
+            Tuple of (video data URI, metadata dictionary)
+        """
+        try:
+            # Process video with Varnish
+            result = await self.varnish(
+                input_data=frames,
+                fps=config.fps,
+                double_num_frames=config.double_num_frames,
+                super_resolution=config.super_resolution,
+                grain_amount=config.grain_amount,
+                enable_audio=config.enable_audio,
+                audio_prompt=config.audio_prompt,
+                audio_negative_prompt=config.audio_negative_prompt
+            )
+
+            # Convert to data URI
+            video_uri = await result.write(type="data-uri", quality=config.quality)
+
+            # Collect metadata
+            metadata = {
+                "width": result.metadata.width,
+                "height": result.metadata.height,
+                "num_frames": result.metadata.frame_count,
+                "fps": result.metadata.fps,
+                "duration": result.metadata.duration,
+                "seed": config.seed,
+                "enable_teacache": config.enable_teacache,
+                "teacache_threshold": config.teacache_threshold if config.enable_teacache else 0,
+                "enable_enhance_a_video": config.enable_enhance_a_video,
+                "enhance_a_video_weight": config.enhance_a_video_weight if config.enable_enhance_a_video else 0,
+            }
+
+            return video_uri, metadata
+
+        except Exception as e:
+            logger.error(f"Error in process_frames: {str(e)}")
+            raise RuntimeError(f"Failed to process frames: {str(e)}")
+
     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
         """Process video generation requests
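
The new process_frames coroutine centralizes the Varnish post-processing and metadata collection that the old __call__ ran inline through two separate run_until_complete calls. Since __call__ stays synchronous, it still needs an event loop to drive the coroutine; a minimal self-contained sketch of that sync-to-async bridge, with a stub coroutine standing in for process_frames:

    import asyncio
    import torch

    async def process_frames_stub(frames: torch.Tensor) -> tuple[str, dict]:
        # Stand-in for EndpointHandler.process_frames: pretend post-processing.
        await asyncio.sleep(0)
        return "data:video/mp4;base64,...", {"num_frames": int(frames.shape[0])}

    frames = torch.zeros(8, 3, 64, 64)  # dummy (frames, channels, height, width)
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    video_uri, metadata = loop.run_until_complete(process_frames_stub(frames))

asyncio.run(process_frames_stub(frames)) would be the simpler spelling, though it creates and closes a fresh loop per call, while the handler's pattern reuses one across requests.
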
 
@@ -156,7 +228,11 @@ class EndpointHandler:
             teacache_threshold=params.get("teacache_threshold", 0.15),
 
             enable_enhance_a_video=params.get("enable_enhance_a_video", True),
-            enhance_a_video_weight=params.get("enhance_a_video_weight", 4.0)
+            enhance_a_video_weight=params.get("enhance_a_video_weight", 5.0),
+
+            lora_model_name=params.get("lora_model_name", ""),
+            lora_model_weight_file=params.get("lora_model_weight_file", ""),
+            lora_model_trigger=params.get("lora_model_trigger", ""),
         ).validate_and_adjust()
 
         try:
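
Every new parameter is optional with a backward-compatible default (and the enhance weight default moves from 4.0 to 5.0). A hypothetical request payload exercising them, assuming the handler follows the usual Inference Endpoints convention of reading a "parameters" dict (that extraction of params happens in an unchanged part of __call__ not shown in this diff):

    data = {
        "inputs": "a cat walking on grass",
        "parameters": {
            "enhance_a_video_weight": 5.0,
            "lora_model_name": "some-user/some-hunyuan-lora",  # hypothetical repo ID
            "lora_model_weight_file": "",  # empty: let the loader pick the weight file
            "lora_model_trigger": "<cat-style>",
        },
    }
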
@@ -178,77 +254,95 @@ class EndpointHandler:
             #else:
             #    disable_teacache(self.pipeline.transformer)
 
-            # Configure Enhance-A-Video weight if enabled
-            if config.enable_enhance_a_video:
-                set_enhance_weight(config.enhance_a_video_weight)
-                enable_enhance()
-            else:
-                # Reset enhance weight to 0 to effectively disable it
-                set_enhance_weight(0)
-
-            # Generate video frames
             with torch.inference_mode():
-                output = self.pipeline(
-                    prompt=config.prompt,
+                # Configure Enhance-A-Video weight if enabled
+                if config.enable_enhance_a_video:
+                    set_enhance_weight(config.enhance_a_video_weight)
+                    enable_enhance()
+                else:
+                    # Reset enhance weight to 0 to effectively disable it
+                    set_enhance_weight(0)
 
+                # Prepare generation parameters
+                generation_kwargs = {
+                    "prompt": config.prompt,
+
                     # Failed to generate video: HunyuanVideoPipeline.__call__() got an unexpected keyword argument 'negative_prompt'
-                    #negative_prompt=config.negative_prompt,
+                    #"negative_prompt": config.negative_prompt,
 
-                    num_frames=config.num_frames,
-                    height=config.height,
-                    width=config.width,
-                    num_inference_steps=config.num_inference_steps,
-                    guidance_scale=config.guidance_scale,
-                    generator=generator,
-                    output_type="pt",
-                ).frames
-
-                # Process with Varnish
-                import asyncio
+                    "num_frames": config.num_frames,
+                    "height": config.height,
+                    "width": config.width,
+                    "num_inference_steps": config.num_inference_steps,
+                    "guidance_scale": config.guidance_scale,
+                    "generator": generator,
+                    "output_type": "pt",
+                }
+
+                # Handle LoRA loading/unloading
+                if hasattr(self, '_current_lora_model'):
+                    if self._current_lora_model != (config.lora_model_name, config.lora_model_weight_file):
+                        # Unload previous LoRA if it exists and is different
+                        if support_image_prompt and hasattr(self.image_to_video, 'unload_lora_weights'):
+                            self.image_to_video.unload_lora_weights()
+                        else:
+                            if hasattr(self.text_to_video, 'unload_lora_weights'):
+                                self.text_to_video.unload_lora_weights()
+
+                        if config.lora_model_name:
+                            # Load new LoRA
+                            if support_image_prompt and hasattr(self.image_to_video, 'load_lora_weights'):
+                                self.image_to_video.load_lora_weights(
+                                    config.lora_model_name,
+                                    weight_name=config.lora_model_weight_file if config.lora_model_weight_file else None,
+                                    token=hf_token,
+                                )
+                            else:
+                                if hasattr(self.text_to_video, 'load_lora_weights'):
+                                    self.text_to_video.load_lora_weights(
+                                        config.lora_model_name,
+                                        weight_name=config.lora_model_weight_file if config.lora_model_weight_file else None,
+                                        token=hf_token,
+                                    )
+                        self._current_lora_model = (config.lora_model_name, config.lora_model_weight_file)
+
+                # Modify prompt if trigger word is provided
+                if config.lora_model_trigger:
+                    generation_kwargs["prompt"] = f"{config.lora_model_trigger} {generation_kwargs['prompt']}"
+
+                # Check if image-to-video generation is requested
+                if support_image_prompt and input_image:
+                    self._configure_teacache(self.image_to_video, config)
+                    processed_image = process_input_image(
+                        input_image,
+                        config.width,
+                        config.height,
+                        config.input_image_quality,
+                    )
+                    generation_kwargs["image"] = processed_image
+                    frames = self.image_to_video(**generation_kwargs).frames
+                else:
+                    self._configure_teacache(self.text_to_video, config)
+                    frames = self.text_to_video(**generation_kwargs).frames
 
                 try:
                     loop = asyncio.get_event_loop()
                 except RuntimeError:
                     loop = asyncio.new_event_loop()
                     asyncio.set_event_loop(loop)
-
-                result = loop.run_until_complete(
-                    self.varnish(
-                        input_data=output,
-                        fps=config.fps,
-                        double_num_frames=config.double_num_frames,
-                        super_resolution=config.super_resolution,
-                        grain_amount=config.grain_amount,
-                        enable_audio=config.enable_audio,
-                        audio_prompt=config.audio_prompt,
-                        audio_negative_prompt=config.audio_negative_prompt,
-                    )
-                )
-
-                # Get video data URI
-                video_uri = loop.run_until_complete(
-                    result.write(
-                        type="data-uri",
-                        quality=config.quality
-                    )
-                )
+
+                video_uri, metadata = loop.run_until_complete(self.process_frames(frames, config))
 
                 return {
                     "video": video_uri,
                     "content-type": "video/mp4",
-                    "metadata": {
-                        "width": result.metadata.width,
-                        "height": result.metadata.height,
-                        "num_frames": result.metadata.frame_count,
-                        "fps": result.metadata.fps,
-                        "duration": result.metadata.duration,
-                        "seed": config.seed,
-                        "enable_teacache": config.enable_teacache,
-                        "teacache_threshold": config.teacache_threshold if config.enable_teacache else 0,
-                        "enable_enhance_a_video": config.enable_enhance_a_video,
-                        "enhance_a_video_weight": config.enhance_a_video_weight if config.enable_enhance_a_video else 0,
-                    }
+                    "metadata": metadata
                 }
 
         except Exception as e:
             message = f"Error generating video ({str(e)})\n{traceback.format_exc()}"
             logger.error(message)
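
The _current_lora_model tuple acts as a one-slot cache: a LoRA is unloaded and reloaded only when the requested (name, weight file) pair differs from the previous call, avoiding a reload of adapter weights on every request. (hf_token and input_image are presumably defined in parts of handler.py this diff does not touch.) The same pattern in isolation, with a generic diffusers pipeline standing in:

    class LoraSlot:
        """Illustrative one-slot LoRA cache mirroring the handler's logic."""

        def __init__(self, pipe):
            self.pipe = pipe     # any diffusers pipeline with LoRA support
            self.current = None  # (name, weight_file) currently loaded

        def ensure(self, name: str, weight_file: str = "") -> None:
            requested = (name, weight_file)
            if requested == self.current:
                return  # already loaded; skip the expensive reload
            if self.current is not None and hasattr(self.pipe, "unload_lora_weights"):
                self.pipe.unload_lora_weights()
            if name:
                self.pipe.load_lora_weights(name, weight_name=weight_file or None)
            self.current = requested
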
 