Update handler.py
Browse files
- handler.py +19 -9
handler.py
CHANGED
|
@@ -8,13 +8,11 @@ import traceback
|
|
| 8 |
import torch
|
| 9 |
|
| 10 |
# note: there is no HunyuanImageToVideoPipeline yet in Diffusers
|
| 11 |
-
from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
|
| 12 |
from diffusers.hooks import apply_enhance_a_video, EnhanceAVideoConfig
|
| 13 |
from varnish import Varnish
|
| 14 |
from varnish.utils import is_truthy, process_input_image
|
| 15 |
|
| 16 |
-
from teacache import enable_teacache, disable_teacache
|
| 17 |
-
|
| 18 |
# Configure logging
|
| 19 |
logging.basicConfig(level=logging.INFO)
|
| 20 |
logger = logging.getLogger(__name__)
|
|
@@ -52,12 +50,12 @@ class GenerationConfig:
|
|
| 52 |
audio_negative_prompt: str = "voices, voice, talking, speaking, speech"
|
| 53 |
|
| 54 |
# TeaCache settings
|
| 55 |
-
enable_teacache: bool =
|
| 56 |
teacache_threshold: float = 0.15 # values: 0 (original), 0.1 (1.6x speedup), 0.15 (2.1x speedup)
|
| 57 |
|
| 58 |
|
| 59 |
# Enhance-A-Video settings
|
| 60 |
-
enable_enhance_a_video: bool =
|
| 61 |
enhance_a_video_weight: float = 5.0
|
| 62 |
|
| 63 |
# LoRA settings
|
|
@@ -95,7 +93,7 @@ class EndpointHandler:
|
|
| 95 |
subfolder="transformer",
|
| 96 |
torch_dtype=torch.bfloat16
|
| 97 |
)
|
| 98 |
-
|
| 99 |
if support_image_prompt:
|
| 100 |
raise Exception("Please use a version of Diffusers that supports HunyuanImageToVideoPipeline")
|
| 101 |
# # Initialize image-to-video pipeline
|
|
@@ -124,6 +122,21 @@ class EndpointHandler:
|
|
| 124 |
self.text_to_video.transformer = self.text_to_video.transformer.to(torch.bfloat16)
|
| 125 |
self.text_to_video.vae = self.text_to_video.vae.half()
|
| 126 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
|
| 128 |
# Initialize LoRA tracking
|
| 129 |
self._current_lora_model = None
|
|
@@ -309,7 +322,6 @@ class EndpointHandler:
|
|
| 309 |
|
| 310 |
# Check if image-to-video generation is requested
|
| 311 |
if support_image_prompt and input_image:
|
| 312 |
-
self._configure_teacache(self.image_to_video, config)
|
| 313 |
processed_image = process_input_image(
|
| 314 |
input_image,
|
| 315 |
config.width,
|
|
@@ -326,8 +338,6 @@ class EndpointHandler:
|
|
| 326 |
|
| 327 |
frames = self.image_to_video(**generation_kwargs).frames
|
| 328 |
else:
|
| 329 |
-
self._configure_teacache(self.text_to_video, config)
|
| 330 |
-
|
| 331 |
apply_enhance_a_video(self.text_to_video.transformer, EnhanceAVideoConfig(
|
| 332 |
weight=config.enhance_a_video_weight if config.enable_enhance_a_video else 0.0,
|
| 333 |
num_frames_callback=lambda: (config.num_frames - 1),
|
|
|
|
| 8 |
import torch
|
| 9 |
|
| 10 |
# note: there is no HunyuanImageToVideoPipeline yet in Diffusers
|
| 11 |
+
from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel, FasterCacheConfig
|
| 12 |
from diffusers.hooks import apply_enhance_a_video, EnhanceAVideoConfig
|
| 13 |
from varnish import Varnish
|
| 14 |
from varnish.utils import is_truthy, process_input_image
|
| 15 |
|
|
|
|
|
|
|
| 16 |
# Configure logging
|
| 17 |
logging.basicConfig(level=logging.INFO)
|
| 18 |
logger = logging.getLogger(__name__)
|
|
|
|
| 50 |
audio_negative_prompt: str = "voices, voice, talking, speaking, speech"
|
| 51 |
|
| 52 |
# TeaCache settings
|
| 53 |
+
enable_teacache: bool = False
|
| 54 |
teacache_threshold: float = 0.15 # values: 0 (original), 0.1 (1.6x speedup), 0.15 (2.1x speedup)
|
| 55 |
|
| 56 |
|
| 57 |
# Enhance-A-Video settings
|
| 58 |
+
enable_enhance_a_video: bool = False
|
| 59 |
enhance_a_video_weight: float = 5.0
|
| 60 |
|
| 61 |
# LoRA settings
|
|
|
|
| 93 |
subfolder="transformer",
|
| 94 |
torch_dtype=torch.bfloat16
|
| 95 |
)
|
| 96 |
+
|
| 97 |
if support_image_prompt:
|
| 98 |
raise Exception("Please use a version of Diffusers that supports HunyuanImageToVideoPipeline")
|
| 99 |
# # Initialize image-to-video pipeline
|
|
|
|
| 122 |
self.text_to_video.transformer = self.text_to_video.transformer.to(torch.bfloat16)
|
| 123 |
self.text_to_video.vae = self.text_to_video.vae.half()
|
| 124 |
|
| 125 |
+
# enable FasterCache
|
| 126 |
+
|
| 127 |
+
# These values are taken from the FasterCache example in the Diffusers PR linked below:
|
| 128 |
+
# https://github.com/huggingface/diffusers/pull/10163/files#diff-777f4ee62cb325371233a450e0f6cc0ba357a3fade2ec2dea912260b4f8d08ceR67-R74
|
| 129 |
+
|
| 130 |
+
faster_cache_config = FasterCacheConfig(
|
| 131 |
+
spatial_attention_block_skip_range=2,
|
| 132 |
+
spatial_attention_timestep_skip_range=(-1, 901),
|
| 133 |
+
unconditional_batch_skip_range=2,
|
| 134 |
+
attention_weight_callback=lambda _: 0.5,
|
| 135 |
+
is_guidance_distilled=True,
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
self.text_to_video.transformer.enable_cache(config)
|
| 139 |
+
|
| 140 |
|
| 141 |
# Initialize LoRA tracking
|
| 142 |
self._current_lora_model = None
|
|
|
|
| 322 |
|
| 323 |
# Check if image-to-video generation is requested
|
| 324 |
if support_image_prompt and input_image:
|
|
|
|
| 325 |
processed_image = process_input_image(
|
| 326 |
input_image,
|
| 327 |
config.width,
|
|
|
|
| 338 |
|
| 339 |
frames = self.image_to_video(**generation_kwargs).frames
|
| 340 |
else:
|
|
|
|
|
|
|
| 341 |
apply_enhance_a_video(self.text_to_video.transformer, EnhanceAVideoConfig(
|
| 342 |
weight=config.enhance_a_video_weight if config.enable_enhance_a_video else 0.0,
|
| 343 |
num_frames_callback=lambda: (config.num_frames - 1),
|