Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
|
@@ -34,6 +34,35 @@ video_transforms = transforms.Compose(
|
|
| 34 |
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
|
| 35 |
]
|
| 36 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
|
| 38 |
def resize_image_to_bucket(image: Union[Image.Image, np.ndarray], bucket_reso: Tuple[int, int]) -> np.ndarray:
|
| 39 |
"""
|
|
@@ -65,37 +94,6 @@ def resize_image_to_bucket(image: Union[Image.Image, np.ndarray], bucket_reso: T
|
|
| 65 |
image = image[crop_top : crop_top + bucket_height, crop_left : crop_left + bucket_width]
|
| 66 |
return image
|
| 67 |
|
| 68 |
-
def construct_video_pipeline(model_id: str, lora_path: str):
|
| 69 |
-
# Load model and LORA
|
| 70 |
-
transformer = HunyuanVideoTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
|
| 71 |
-
pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
|
| 72 |
-
|
| 73 |
-
# Enable memory savings
|
| 74 |
-
pipe.vae.enable_tiling()
|
| 75 |
-
pipe.enable_model_cpu_offload()
|
| 76 |
-
|
| 77 |
-
with torch.no_grad(): # enable image inputs
|
| 78 |
-
initial_input_channels = pipe.transformer.config.in_channels
|
| 79 |
-
new_img_in = HunyuanVideoPatchEmbed(
|
| 80 |
-
patch_size=(pipe.transformer.config.patch_size_t, pipe.transformer.config.patch_size, pipe.transformer.config.patch_size),
|
| 81 |
-
in_chans=pipe.transformer.config.in_channels * 2,
|
| 82 |
-
embed_dim=pipe.transformer.config.num_attention_heads * pipe.transformer.config.attention_head_dim,
|
| 83 |
-
)
|
| 84 |
-
new_img_in = new_img_in.to(pipe.device, dtype=pipe.dtype)
|
| 85 |
-
new_img_in.proj.weight.zero_()
|
| 86 |
-
new_img_in.proj.weight[:, :initial_input_channels].copy_(pipe.transformer.x_embedder.proj.weight)
|
| 87 |
-
if pipe.transformer.x_embedder.proj.bias is not None:
|
| 88 |
-
new_img_in.proj.bias.copy_(pipe.transformer.x_embedder.proj.bias)
|
| 89 |
-
pipe.transformer.x_embedder = new_img_in
|
| 90 |
-
|
| 91 |
-
lora_state_dict = safetensors.torch.load_file(lora_path, device="cpu")
|
| 92 |
-
transformer_lora_state_dict = {f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") and "lora" in k}
|
| 93 |
-
pipe.load_lora_into_transformer(transformer_lora_state_dict, transformer=pipe.transformer, adapter_name="i2v", _pipeline=pipe)
|
| 94 |
-
pipe.set_adapters(["i2v"], adapter_weights=[1.0])
|
| 95 |
-
pipe.fuse_lora(components=["transformer"], lora_scale=1.0, adapter_names=["i2v"])
|
| 96 |
-
pipe.unload_lora_weights()
|
| 97 |
-
|
| 98 |
-
return pipe
|
| 99 |
|
| 100 |
def generate_video(pipe, prompt: str, frame1_path: str, frame2_path: str, guidance_scale: float, num_frames: int, num_inference_steps: int) -> bytes:
|
| 101 |
# Load and preprocess frames
|
|
@@ -317,13 +315,11 @@ def main():
|
|
| 317 |
outputs = [
|
| 318 |
gr.Video(label="Generated Video"),
|
| 319 |
]
|
| 320 |
-
|
| 321 |
-
def generate_video_wrapper(*args):
|
| 322 |
-
return generate_video(pipe, *args)
|
| 323 |
|
| 324 |
# Create the Gradio interface
|
| 325 |
iface = gr.Interface(
|
| 326 |
-
fn=generate_video_wrapper,
|
| 327 |
inputs=inputs,
|
| 328 |
outputs=outputs,
|
| 329 |
title="Hunyuan Video Generator",
|
|
|
|
| 34 |
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
|
| 35 |
]
|
| 36 |
)
|
| 37 |
+
model_id = "hunyuanvideo-community/HunyuanVideo"
|
| 38 |
+
lora_path = hf_hub_download("dashtoon/hunyuan-video-keyframe-control-lora", "i2v.sft") # Replace with the actual LORA path
|
| 39 |
+
transformer = HunyuanVideoTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
|
| 40 |
+
pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
|
| 41 |
+
|
| 42 |
+
# Enable memory savings
|
| 43 |
+
pipe.vae.enable_tiling()
|
| 44 |
+
pipe.enable_model_cpu_offload()
|
| 45 |
+
|
| 46 |
+
with torch.no_grad(): # enable image inputs
|
| 47 |
+
initial_input_channels = pipe.transformer.config.in_channels
|
| 48 |
+
new_img_in = HunyuanVideoPatchEmbed(
|
| 49 |
+
patch_size=(pipe.transformer.config.patch_size_t, pipe.transformer.config.patch_size, pipe.transformer.config.patch_size),
|
| 50 |
+
in_chans=pipe.transformer.config.in_channels * 2,
|
| 51 |
+
embed_dim=pipe.transformer.config.num_attention_heads * pipe.transformer.config.attention_head_dim,
|
| 52 |
+
)
|
| 53 |
+
new_img_in = new_img_in.to(pipe.device, dtype=pipe.dtype)
|
| 54 |
+
new_img_in.proj.weight.zero_()
|
| 55 |
+
new_img_in.proj.weight[:, :initial_input_channels].copy_(pipe.transformer.x_embedder.proj.weight)
|
| 56 |
+
if pipe.transformer.x_embedder.proj.bias is not None:
|
| 57 |
+
new_img_in.proj.bias.copy_(pipe.transformer.x_embedder.proj.bias)
|
| 58 |
+
pipe.transformer.x_embedder = new_img_in
|
| 59 |
+
|
| 60 |
+
lora_state_dict = safetensors.torch.load_file(lora_path, device="cpu")
|
| 61 |
+
transformer_lora_state_dict = {f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") and "lora" in k}
|
| 62 |
+
pipe.load_lora_into_transformer(transformer_lora_state_dict, transformer=pipe.transformer, adapter_name="i2v", _pipeline=pipe)
|
| 63 |
+
pipe.set_adapters(["i2v"], adapter_weights=[1.0])
|
| 64 |
+
pipe.fuse_lora(components=["transformer"], lora_scale=1.0, adapter_names=["i2v"])
|
| 65 |
+
pipe.unload_lora_weights()
|
| 66 |
|
| 67 |
def resize_image_to_bucket(image: Union[Image.Image, np.ndarray], bucket_reso: Tuple[int, int]) -> np.ndarray:
|
| 68 |
"""
|
|
|
|
| 94 |
image = image[crop_top : crop_top + bucket_height, crop_left : crop_left + bucket_width]
|
| 95 |
return image
|
| 96 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
|
| 98 |
def generate_video(pipe, prompt: str, frame1_path: str, frame2_path: str, guidance_scale: float, num_frames: int, num_inference_steps: int) -> bytes:
|
| 99 |
# Load and preprocess frames
|
|
|
|
| 315 |
outputs = [
|
| 316 |
gr.Video(label="Generated Video"),
|
| 317 |
]
|
| 318 |
+
|
|
|
|
|
|
|
| 319 |
|
| 320 |
# Create the Gradio interface
|
| 321 |
iface = gr.Interface(
|
| 322 |
+
fn=generate_video,
|
| 323 |
inputs=inputs,
|
| 324 |
outputs=outputs,
|
| 325 |
title="Hunyuan Video Generator",
|