import time

import gradio as gr
import spaces
import torch
from diffusers import AutoencoderKLCogVideoX, CogVideoXDDIMScheduler
from diffusers.utils import export_to_video
from PIL import Image
from transformers import T5EncoderModel, T5Tokenizer

from cogvideo_transformer import CustomCogVideoXTransformer3DModel
from EF_Net import EF_Net
from Sci_Fi_inbetweening_pipeline import CogVideoXEFNetInbetweeningPipeline

# Global pipeline state, shared between the Setup and Generate tabs
pipe = None
device = "cuda" if torch.cuda.is_available() else "cpu"


@spaces.GPU
def load_pipeline(
    pretrained_model_path="THUDM/CogVideoX-5b",
    ef_net_path="weights/EF_Net.pth",
    dtype_str="bfloat16",
):
    """Load the Sci-Fi inbetweening pipeline."""
    global pipe

    dtype = torch.float16 if dtype_str == "float16" else torch.bfloat16

    # Load the CogVideoX components
    tokenizer = T5Tokenizer.from_pretrained(
        pretrained_model_path, subfolder="tokenizer"
    )
    text_encoder = T5EncoderModel.from_pretrained(
        pretrained_model_path, subfolder="text_encoder"
    )
    transformer = CustomCogVideoXTransformer3DModel.from_pretrained(
        pretrained_model_path, subfolder="transformer"
    )
    vae = AutoencoderKLCogVideoX.from_pretrained(
        pretrained_model_path, subfolder="vae"
    )
    scheduler = CogVideoXDDIMScheduler.from_pretrained(
        pretrained_model_path, subfolder="scheduler"
    )

    # Load EF-Net (frozen, inference only)
    EF_Net_model = (
        EF_Net(num_layers=4, downscale_coef=8, in_channels=2, num_attention_heads=48)
        .requires_grad_(False)
        .eval()
    )
    ckpt = torch.load(ef_net_path, map_location="cpu", weights_only=False)
    m, u = EF_Net_model.load_state_dict(ckpt["state_dict"], strict=False)
    print(f"[EF-Net loaded] Missing: {len(m)} | Unexpected: {len(u)}")

    # Create the pipeline and switch to trailing timestep spacing
    pipe = CogVideoXEFNetInbetweeningPipeline(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        transformer=transformer,
        vae=vae,
        EF_Net_model=EF_Net_model,
        scheduler=scheduler,
    )
    pipe.scheduler = CogVideoXDDIMScheduler.from_config(
        pipe.scheduler.config, timestep_spacing="trailing"
    )
    pipe.to(device)
    pipe = pipe.to(dtype=dtype)
    # Slicing and tiling reduce VAE memory use when decoding long videos
    pipe.vae.enable_slicing()
    pipe.vae.enable_tiling()

    return "Pipeline loaded successfully!"


@spaces.GPU
def generate_inbetweening(
    first_image: Image.Image,
    last_image: Image.Image,
    prompt: str,
    num_frames: int = 49,
    guidance_scale: float = 6.0,
    ef_net_weights: float = 1.0,
    ef_net_guidance_start: float = 0.0,
    ef_net_guidance_end: float = 1.0,
    seed: int = 42,
    progress=gr.Progress(),
):
    """Generate a frame-inbetweening video between the two input frames."""
    global pipe

    if pipe is None:
        return None, "Please load the pipeline first!"
    if first_image is None or last_image is None:
        return None, "Please upload both start and end frames!"
    if not prompt.strip():
        return None, "Please provide a text prompt!"
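
    # How the EF-Net knobs below are read (inferred from the parameter names
    # and the pipeline call, not from separate documentation): EF_Net_weights
    # scales the strength of EF-Net's start/end-frame guidance, while
    # EF_Net_guidance_start/EF_Net_guidance_end bound the fraction of the
    # denoising schedule over which that guidance is applied (0.0-1.0 = all steps).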
    try:
        progress(0, desc="Starting generation...")
        start_time = time.time()

        # Run the diffusion pipeline; the generator pins the random seed
        progress(0.2, desc="Processing frames...")
        video_frames = pipe(
            first_image=first_image,
            last_image=last_image,
            prompt=prompt,
            num_frames=num_frames,
            use_dynamic_cfg=False,
            guidance_scale=guidance_scale,
            generator=torch.Generator(device=device).manual_seed(seed),
            EF_Net_weights=ef_net_weights,
            EF_Net_guidance_start=ef_net_guidance_start,
            EF_Net_guidance_end=ef_net_guidance_end,
        ).frames[0]

        progress(0.9, desc="Exporting video...")

        # Export the frames to a timestamped MP4
        output_path = f"output_{int(time.time())}.mp4"
        export_to_video(video_frames, output_path, fps=7)

        elapsed_time = time.time() - start_time
        status_msg = f"Video generated successfully in {elapsed_time:.2f}s"
        progress(1.0, desc="Done!")

        return output_path, status_msg

    except Exception as e:
        return None, f"Error: {str(e)}"


# Create the Gradio interface
with gr.Blocks(title="Sci-Fi: Frame Inbetweening") as demo:
    gr.Markdown(
        """
        # Sci-Fi: Symmetric Constraint for Frame Inbetweening

        Upload start and end frames to generate a smooth inbetweening video.

        **Note:** Load the pipeline from the Setup tab before generating videos.
        """
    )

    with gr.Tab("Generate"):
        with gr.Row():
            with gr.Column():
                first_image = gr.Image(label="Start Frame", type="pil")
                last_image = gr.Image(label="End Frame", type="pil")
            with gr.Column():
                prompt = gr.Textbox(
                    label="Prompt",
                    placeholder="Describe the motion or content...",
                    lines=3,
                )
                with gr.Accordion("Advanced Settings", open=False):
                    num_frames = gr.Slider(
                        minimum=13,
                        maximum=49,
                        value=49,
                        step=12,
                        label="Number of Frames",
                    )
                    guidance_scale = gr.Slider(
                        minimum=1.0,
                        maximum=15.0,
                        value=6.0,
                        step=0.5,
                        label="Guidance Scale",
                    )
                    ef_net_weights = gr.Slider(
                        minimum=0.0,
                        maximum=2.0,
                        value=1.0,
                        step=0.1,
                        label="EF-Net Weights",
                    )
                    ef_net_guidance_start = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=0.0,
                        step=0.1,
                        label="EF-Net Guidance Start",
                    )
                    ef_net_guidance_end = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=1.0,
                        step=0.1,
                        label="EF-Net Guidance End",
                    )
                    seed = gr.Number(label="Seed", value=42, precision=0)

        generate_btn = gr.Button("Generate Video", variant="primary", size="lg")

        with gr.Row():
            output_video = gr.Video(label="Generated Video")
            status_text = gr.Textbox(label="Status", lines=2)

        generate_btn.click(
            fn=generate_inbetweening,
            inputs=[
                first_image,
                last_image,
                prompt,
                num_frames,
                guidance_scale,
                ef_net_weights,
                ef_net_guidance_start,
                ef_net_guidance_end,
                seed,
            ],
            outputs=[output_video, status_text],
        )

    with gr.Tab("Setup"):
        gr.Markdown(
            """
            ## Load Pipeline

            Configure and load the model before generating videos.

            **Default paths:**
            - Model: `THUDM/CogVideoX-5b` (or your downloaded path)
            - EF-Net: `weights/EF_Net.pth`
            """
        )
        with gr.Row():
            model_path = gr.Textbox(
                label="Pretrained Model Path",
                value="THUDM/CogVideoX-5b",
                placeholder="Path to CogVideoX model",
            )
            ef_net_path = gr.Textbox(
                label="EF-Net Checkpoint Path",
                value="weights/EF_Net.pth",
                placeholder="Path to EF-Net weights",
            )
        dtype_choice = gr.Radio(
            choices=["bfloat16", "float16"], value="bfloat16", label="Data Type"
        )
        load_btn = gr.Button("Load Pipeline", variant="primary")
        load_status = gr.Textbox(label="Load Status", interactive=False)

        load_btn.click(
            fn=load_pipeline,
            inputs=[model_path, ef_net_path, dtype_choice],
            outputs=load_status,
        )

    with gr.Tab("Examples"):
        gr.Markdown(
            """
            ## Example Inputs

            Try these example frame pairs from the `example_input_pairs/` folder.
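
            Selecting an example only pre-fills the start frame, end frame, and
            prompt on the Generate tab; the pipeline must still be loaded from
            the Setup tab before generating.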
""" ) gr.Examples( examples=[ [ "example_input_pairs/input_pair1/start.jpg", "example_input_pairs/input_pair1/end.jpg", "A smooth transition between frames", ], [ "example_input_pairs/input_pair2/start.jpg", "example_input_pairs/input_pair2/end.jpg", "Natural motion interpolation", ], ], inputs=[first_image, last_image, prompt], ) if __name__ == "__main__": demo.launch()