Spaces:
Paused
Paused
| import spaces | |
| import gradio as gr | |
| import argparse | |
| import sys | |
| import time | |
| import os | |
| import random | |
| from skyreelsinfer.offload import Offload, OffloadConfig | |
| from skyreelsinfer.pipelines import SkyreelsVideoPipeline | |
| from skyreelsinfer import TaskType | |
| #from skyreelsinfer.skyreels_video_infer import SkyReelsVideoSingleGpuInfer | |
| from diffusers import HunyuanVideoTransformer3DModel | |
| from diffusers.utils import export_to_video | |
| from diffusers.utils import load_image | |
| from PIL import Image | |
| import numpy as np | |
| from torchao.quantization import float8_weight_only | |
| from torchao.quantization import quantize_ | |
| from transformers import LlamaModel | |
| import torch | |
| torch.backends.cuda.matmul.allow_tf32 = False | |
| torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False | |
| torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False | |
| torch.backends.cudnn.allow_tf32 = False | |
| torch.backends.cudnn.deterministic = False | |
| torch.backends.cudnn.benchmark = False | |
| torch.backends.cuda.preferred_blas_library="cublas" | |
| torch.backends.cuda.preferred_linalg_library="cusolver" | |
| torch.set_float32_matmul_precision("highest") | |
| torch.backends.cuda.enable_cudnn_sdp(False) # Still a good idea to keep it. | |
| os.putenv("HF_HUB_ENABLE_HF_TRANSFER","1") | |
| os.environ["SAFETENSORS_FAST_GPU"] = "1" | |
| os.putenv("TOKENIZERS_PARALLELISM","False") | |
| model_id = "Skywork/SkyReels-V1-Hunyuan-I2V" | |
| base_model_id = "hunyuanvideo-community/HunyuanVideo" | |
| device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | |
| offload_config=OffloadConfig( | |
| high_cpu_memory=True, | |
| parameters_level=True, | |
| compiler_transformer=False, | |
| ) | |
def init_predictor():
    """Load the SkyReels I2V pipeline once at startup and publish it as the
    module-global ``pipe``.

    Device placement is deliberate and the rest of the app depends on it:
    the Llama text encoder starts on CPU (generate() moves it to GPU only
    while encoding prompts), the video transformer lives on GPU for the
    denoising loop, and the VAE is parked on CPU until it is needed for
    image encoding / final decoding.
    """
    global pipe
    # Text encoder in bf16, kept on CPU to leave VRAM for the transformer.
    text_encoder = LlamaModel.from_pretrained(
        base_model_id,
        subfolder="text_encoder",
        torch_dtype=torch.bfloat16,
    ).to("cpu")
    # I2V transformer weights come from the SkyReels checkpoint, moved to
    # GPU and set to eval mode (inference only).
    transformer = HunyuanVideoTransformer3DModel.from_pretrained(
        model_id,
        # subfolder="transformer",
        torch_dtype=torch.bfloat16,
        #device="cpu",
    ).to("cuda").eval()
    # float8 weight-only quantization experiments, kept for reference:
    #quantize_(text_encoder, float8_weight_only(), device="cpu")
    #text_encoder.to("cpu")
    #torch.cuda.empty_cache()
    #quantize_(transformer, float8_weight_only(), device="cpu")
    #transformer.to("cuda")
    #torch.cuda.empty_cache()
    # Assemble the pipeline from the base repo, swapping in the SkyReels
    # transformer and the CPU-resident text encoder.
    pipe = SkyreelsVideoPipeline.from_pretrained(
        base_model_id,
        transformer=transformer,
        text_encoder=text_encoder,
        torch_dtype=torch.bfloat16,
    )  #.to("cpu")
    pipe.vae.to('cpu')
    # Tiled VAE decoding lowers peak memory for large frames.
    pipe.vae.enable_tiling()
    torch.cuda.empty_cache()
| negative_prompt = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion" | |
def generate(segment, image, prompt, size, guidance_scale, num_inference_steps, frames, seed, progress=gr.Progress(track_tqdm=True)):
    """Run one of nine segments of a chunked image-to-video generation.

    Segment 1 encodes the prompt and input image, then denoises the first
    1/8 of the timestep schedule. Segments 2-8 resume from the state file
    written by the previous segment and denoise the next 1/8. Segment 9
    decodes the final latents with the VAE and exports an MP4.

    Returns:
        (video_path, seed) on segment 9, otherwise (None, seed).
    """
    if segment == 1:
        # Fresh run: draw a random seed so repeated segment-1 clicks differ.
        random.seed(time.time())
        seed = int(random.randrange(4294967294))
        #Offload.offload(pipeline=pipe, config=offload_config)  # disabled
        # Encode prompts on GPU, then park the text encoders back on CPU to
        # free VRAM for the transformer.
        pipe.text_encoder.to("cuda")
        pipe.text_encoder_2.to("cuda")
        with torch.no_grad():
            prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_attention_mask, pooled_prompt_embeds, negative_pooled_prompt_embeds = pipe.encode_prompt(
                prompt=prompt, do_classifier_free_guidance=True, negative_prompt=negative_prompt, device=device
            )
        pipe.text_encoder.to("cpu")
        pipe.text_encoder_2.to("cpu")
        torch.cuda.empty_cache()
        generator = torch.Generator(device='cuda').manual_seed(seed)
        transformer_dtype = pipe.transformer.dtype
        prompt_embeds = prompt_embeds.to(transformer_dtype)
        prompt_attention_mask = prompt_attention_mask.to(transformer_dtype)
        pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype)
        negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
        negative_attention_mask = negative_attention_mask.to(transformer_dtype)
        negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(transformer_dtype)
        # Pack [uncond, cond] so classifier-free guidance runs in one pass.
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
        prompt_attention_mask = torch.cat([negative_attention_mask, prompt_attention_mask])
        pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds])
        pipe.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = pipe.scheduler.timesteps
        all_timesteps_cpu = timesteps.cpu()
        # Split the full schedule into 8 chunks; this call runs chunk 0.
        timesteps_split_np = np.array_split(all_timesteps_cpu.numpy(), 8)
        segment_timesteps = torch.from_numpy(timesteps_split_np[0]).to("cuda")
        # The I2V transformer takes image latents concatenated on the channel
        # dim, so the noise latents use half of config.in_channels.
        num_channels_latents = int(pipe.transformer.config.in_channels / 2)
        image = Image.open(image).convert('RGB')
        # BUG FIX: PIL's resize returns a *new* image; the original call
        # discarded the result, so the LANCZOS prescale never happened.
        image = image.resize((size, size), Image.LANCZOS)
        pipe.vae.to("cuda")
        # no_grad extended over latent preparation as well: the VAE encode
        # inside image_latents() must not build an autograd graph.
        with torch.no_grad():
            image = pipe.video_processor.preprocess(image, height=size, width=size).to(
                device, dtype=prompt_embeds.dtype
            )
            num_latent_frames = (frames - 1) // pipe.vae_scale_factor_temporal + 1
            latents = pipe.prepare_latents(
                batch_size=1, num_channels_latents=num_channels_latents, height=size, width=size, num_frames=frames,
                dtype=torch.float32, device=device, generator=generator, latents=None,
            )
            image_latents = pipe.image_latents(
                image, 1, size, size, device, torch.float32, num_channels_latents, num_latent_frames
            )
        image_latents = image_latents.to("cuda", pipe.transformer.dtype)
        pipe.vae.to("cpu")
        torch.cuda.empty_cache()
        guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0
    else:
        pipe.vae.to("cpu")
        torch.cuda.empty_cache()
        transformer_dtype = pipe.transformer.dtype
        # Resume from the state the previous segment saved to disk.
        state_file = f"SkyReel_{segment-1}_{seed}.pt"
        # NOTE(review): weights_only=False unpickles arbitrary objects; this
        # is acceptable only because the file is produced locally by this app.
        state = torch.load(state_file, weights_only=False)
        generator = torch.Generator(device='cuda').manual_seed(seed)
        latents = state["intermediate_latents"].to("cuda", dtype=torch.bfloat16)
        guidance_scale = state["guidance_scale"]
        all_timesteps_cpu = state["all_timesteps"]
        # Frames are square: state height == width, one value suffices (the
        # original assigned size twice, from height then width).
        size = state["width"]
        pipe.scheduler.set_timesteps(len(all_timesteps_cpu), device=device)
        timesteps_split_np = np.array_split(all_timesteps_cpu.numpy(), 8)
        prompt_embeds = state["prompt_embeds"].to("cuda", dtype=torch.bfloat16)
        pooled_prompt_embeds = state["pooled_prompt_embeds"].to("cuda", dtype=torch.bfloat16)
        prompt_attention_mask = state["prompt_attention_mask"].to("cuda", dtype=torch.bfloat16)
        image_latents = state["image_latents"].to("cuda", dtype=torch.bfloat16)
        if segment == 9:
            # Final pass: swap the transformer out for the VAE and decode.
            pipe.transformer.to('cpu')
            torch.cuda.empty_cache()
            pipe.vae.to("cuda")
            latents = latents.to(pipe.vae.dtype) / pipe.vae.config.scaling_factor
            # BUG FIX: decode under no_grad (the intent was already present
            # as a commented-out line); without it the decode builds an
            # unnecessary autograd graph and wastes VRAM.
            with torch.no_grad():
                video = pipe.vae.decode(latents, return_dict=False)[0]
            video = pipe.video_processor.postprocess_video(video)
            save_dir = f"./"
            video_out_file = f"{save_dir}/{seed}.mp4"
            print(f"generate video, local path: {video_out_file}")
            # BUG FIX: the original passed an undefined name `output` here,
            # which raised NameError on segment 9; the decoded frames are
            # in `video`.
            export_to_video(video, video_out_file, fps=24)
            return video_out_file, seed
        else:
            segment_timesteps = torch.from_numpy(timesteps_split_np[segment - 1]).to("cuda")
            guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0
    # Denoise this segment's share of the schedule.
    for i, t in enumerate(pipe.progress_bar(segment_timesteps)):
        latents = latents.to(transformer_dtype)
        # Duplicate latents for the [uncond, cond] CFG batch and append the
        # image-conditioning latents on the channel dimension.
        latent_model_input = torch.cat([latents] * 2)
        latent_image_input = torch.cat([image_latents] * 2)
        latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=1)
        timestep = t.repeat(latent_model_input.shape[0]).to(torch.float32)
        with torch.no_grad():
            noise_pred = pipe.transformer(
                hidden_states=latent_model_input,
                timestep=timestep,
                encoder_hidden_states=prompt_embeds,
                encoder_attention_mask=prompt_attention_mask,
                pooled_projections=pooled_prompt_embeds,
                guidance=guidance,
                # attention_kwargs=attention_kwargs,
                return_dict=False,
            )[0]
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            latents = pipe.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
    # Persist everything the next segment needs, on CPU, keyed by seed.
    intermediate_latents_cpu = latents.detach().cpu()
    original_prompt_embeds_cpu = prompt_embeds.cpu()
    original_image_latents_cpu = image_latents.cpu()
    original_pooled_prompt_embeds_cpu = pooled_prompt_embeds.cpu()
    original_prompt_attention_mask_cpu = prompt_attention_mask.cpu()
    timesteps = pipe.scheduler.timesteps
    all_timesteps_cpu = timesteps.cpu()  # Move to CPU
    state = {
        "intermediate_latents": intermediate_latents_cpu,
        "all_timesteps": all_timesteps_cpu,  # Full list generated by scheduler
        "prompt_embeds": original_prompt_embeds_cpu,  # Original CFG-packed embeds
        "image_latents": original_image_latents_cpu,
        "pooled_prompt_embeds": original_pooled_prompt_embeds_cpu,
        "prompt_attention_mask": original_prompt_attention_mask_cpu,
        "guidance_scale": guidance_scale,
        "seed": seed,
        "prompt": prompt,  # Saved for reference/verification
        "negative_prompt": negative_prompt,
        "height": size,  # Dimensions used (square frames)
        "width": size,
    }
    state_file = f"SkyReel_{segment}_{seed}.pt"
    torch.save(state, state_file)
    return None, seed
def update_ranges(total_steps):
    """Recompute the lower bound shown on each of the 8 segment sliders.

    The total step count is divided evenly into 8 segments; segment *i*
    starts at ``i * (total_steps // 8)``. Each bound is wrapped in a
    one-element list, matching how the sliders' initial values are built.
    """
    segment_size = total_steps // 8
    return [[segment * segment_size] for segment in range(8)]
# Gradio UI: one button per denoising segment plus a decode button; the nine
# near-identical event registrations of the original are collapsed into a
# single loop (DRY), with the segment number carried by a constant gr.Number.
# NOTE(review): the original file's indentation was lost in transcription;
# component nesting below is reconstructed — verify the intended layout.
with gr.Blocks() as demo:
    with gr.Row():
        image = gr.Image(label="Upload Image", type="filepath")
        prompt = gr.Textbox(label="Input Prompt")
    # Buttons 1-8 run denoising segments; the ninth decodes the video.
    button_labels = [f"Run Segment {n}" for n in range(1, 9)] + ["Run Decode Video"]
    run_buttons = [gr.Button(label, scale=0) for label in button_labels]
    result = gr.Gallery(label="Result", columns=1, show_label=False)
    seed = gr.Number(value=1, label="Seed")
    size = gr.Slider(
        label="Size",
        minimum=256,
        maximum=1024,
        step=16,
        value=368,
    )
    frames = gr.Slider(
        label="Number of Frames",
        minimum=16,
        maximum=256,
        step=8,
        value=64,
    )
    steps = gr.Slider(
        label="Number of Steps",
        minimum=1,
        maximum=96,
        step=1,
        value=25,
    )
    guidance_scale = gr.Slider(
        label="Guidance Scale",
        minimum=1.0,
        maximum=16.0,
        step=.1,
        value=6.0,
    )
    submit_button = gr.Button("Generate Video")
    output_video = gr.Video(label="Generated Video")
    # Eight read-only-ish sliders visualizing where each segment starts.
    range_sliders = []
    for i in range(8):
        slider = gr.Slider(
            minimum=1,
            maximum=250,
            value=[i * (steps.value // 8)],
            step=1,
            label=f"Range {i + 1}",
        )
        range_sliders.append(slider)
    # Keep the per-segment bounds in sync with the chosen step count.
    steps.change(
        update_ranges,
        inputs=steps,
        outputs=range_sliders,
    )
    # Register all nine click handlers; each passes its 1-based segment
    # number as the first argument of generate().
    for segment_number, button in enumerate(run_buttons, start=1):
        button.click(
            fn=generate,
            inputs=[
                gr.Number(value=segment_number),
                image,
                prompt,
                size,
                guidance_scale,
                steps,
                frames,
                seed,
            ],
            outputs=[result, seed],
        )

if __name__ == "__main__":
    init_predictor()
    demo.launch()