|
|
import time |
|
|
|
|
|
import gradio as gr |
|
|
import spaces |
|
|
import torch |
|
|
from diffusers import AutoencoderKLCogVideoX, CogVideoXDDIMScheduler |
|
|
from diffusers.utils import export_to_video |
|
|
from PIL import Image |
|
|
from transformers import T5EncoderModel, T5Tokenizer |
|
|
|
|
|
from cogvideo_transformer import CustomCogVideoXTransformer3DModel |
|
|
from EF_Net import EF_Net |
|
|
from Sci_Fi_inbetweening_pipeline import CogVideoXEFNetInbetweeningPipeline |
|
|
|
|
|
|
|
|
# Lazily-initialized pipeline handle; populated by load_pipeline().
pipe = None

# Prefer a CUDA device when torch can see one; otherwise run on CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
|
|
|
|
|
|
|
|
@spaces.GPU(duration=1000)
def load_pipeline(
    pretrained_model_path="THUDM/CogVideoX-5b",
    ef_net_path="weights/EF_Net.pth",
    dtype_str="bfloat16",
):
    """Load the Sci-Fi inbetweening pipeline into the module-level ``pipe``.

    Args:
        pretrained_model_path: HF repo id or local path of the CogVideoX base
            model; components are read from its subfolders.
        ef_net_path: Path to the EF-Net checkpoint (a dict with a
            ``"state_dict"`` key).
        dtype_str: ``"float16"`` selects float16; any other value falls back
            to bfloat16 (the default).

    Returns:
        A human-readable status string for the Gradio UI.
    """
    global pipe

    dtype = torch.float16 if dtype_str == "float16" else torch.bfloat16

    # Load each CogVideoX component from its subfolder of the base model.
    tokenizer = T5Tokenizer.from_pretrained(
        pretrained_model_path, subfolder="tokenizer"
    )
    text_encoder = T5EncoderModel.from_pretrained(
        pretrained_model_path, subfolder="text_encoder"
    )
    transformer = CustomCogVideoXTransformer3DModel.from_pretrained(
        pretrained_model_path, subfolder="transformer"
    )
    vae = AutoencoderKLCogVideoX.from_pretrained(pretrained_model_path, subfolder="vae")
    scheduler = CogVideoXDDIMScheduler.from_pretrained(
        pretrained_model_path, subfolder="scheduler"
    )

    # EF-Net is inference-only: freeze all parameters and switch to eval mode.
    EF_Net_model = (
        EF_Net(num_layers=4, downscale_coef=8, in_channels=2, num_attention_heads=48)
        .requires_grad_(False)
        .eval()
    )

    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # objects — only load checkpoints from trusted sources.
    ckpt = torch.load(ef_net_path, map_location="cpu", weights_only=False)
    # Plain shallow copy of the checkpoint's state dict (the previous identity
    # dict comprehension did the same thing, less directly).
    EF_Net_state_dict = dict(ckpt["state_dict"])
    m, u = EF_Net_model.load_state_dict(EF_Net_state_dict, strict=False)
    print(f"[EF-Net loaded] Missing: {len(m)} | Unexpected: {len(u)}")

    pipe = CogVideoXEFNetInbetweeningPipeline(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        transformer=transformer,
        vae=vae,
        EF_Net_model=EF_Net_model,
        scheduler=scheduler,
    )
    # Rebuild the scheduler with "trailing" timestep spacing for inference.
    pipe.scheduler = CogVideoXDDIMScheduler.from_config(
        pipe.scheduler.config, timestep_spacing="trailing"
    )

    pipe.to(device)
    pipe = pipe.to(dtype=dtype)

    # Reduce peak VRAM during VAE decode of long videos.
    pipe.vae.enable_slicing()
    pipe.vae.enable_tiling()

    return "Pipeline loaded successfully!"
|
|
|
|
|
|
|
|
@spaces.GPU(duration=1000)
def generate_inbetweening(
    first_image: Image.Image,
    last_image: Image.Image,
    prompt: str,
    num_frames: int = 49,
    guidance_scale: float = 6.0,
    ef_net_weights: float = 1.0,
    ef_net_guidance_start: float = 0.0,
    ef_net_guidance_end: float = 1.0,
    seed: int = 42,
    progress=gr.Progress(),
):
    """Generate a frame-inbetweening video between two key frames.

    Args:
        first_image: Start key frame (PIL image from the UI).
        last_image: End key frame (PIL image from the UI).
        prompt: Text description guiding the interpolation.
        num_frames: Number of frames to synthesize.
        guidance_scale: Classifier-free guidance strength.
        ef_net_weights: Strength of the EF-Net conditioning.
        ef_net_guidance_start: Fraction of the denoising schedule at which
            EF-Net guidance begins.
        ef_net_guidance_end: Fraction of the denoising schedule at which
            EF-Net guidance stops.
        seed: RNG seed for reproducible generation.
        progress: Gradio progress tracker (injected by the UI).

    Returns:
        ``(video_path, status_message)``; ``video_path`` is ``None`` on any
        validation failure or runtime error.
    """
    global pipe

    # Validate inputs; report problems to the UI instead of raising.
    if pipe is None:
        return None, "Please load the pipeline first!"

    if first_image is None or last_image is None:
        return None, "Please upload both start and end frames!"

    # Guard against None (empty Gradio textbox) as well as whitespace-only
    # prompts — the original `prompt.strip()` raised AttributeError on None.
    if not prompt or not prompt.strip():
        return None, "Please provide a text prompt!"

    try:
        progress(0, desc="Starting generation...")
        start_time = time.time()

        progress(0.2, desc="Processing frames...")
        video_frames = pipe(
            first_image=first_image,
            last_image=last_image,
            prompt=prompt,
            num_frames=num_frames,
            use_dynamic_cfg=False,
            guidance_scale=guidance_scale,
            generator=torch.Generator(device=device).manual_seed(seed),
            EF_Net_weights=ef_net_weights,
            EF_Net_guidance_start=ef_net_guidance_start,
            EF_Net_guidance_end=ef_net_guidance_end,
        ).frames[0]

        progress(0.9, desc="Exporting video...")

        # Timestamped output name; second-granularity collisions are
        # acceptable for a single-user demo.
        output_path = f"output_{int(time.time())}.mp4"
        export_to_video(video_frames, output_path, fps=7)

        elapsed_time = time.time() - start_time
        status_msg = f"Video generated successfully in {elapsed_time:.2f}s"

        progress(1.0, desc="Done!")
        return output_path, status_msg

    # Surface any failure as a status string for the UI rather than crashing
    # the Gradio worker.
    except Exception as e:
        return None, f"Error: {str(e)}"
|
|
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Gradio UI. Three tabs: Generate (inference), Setup (model loading), and
# Examples (sample frame pairs). `demo` is launched from the __main__ guard.
# ---------------------------------------------------------------------------
with gr.Blocks(title="Sci-Fi: Frame Inbetweening") as demo:
    gr.Markdown(
        """
    # Sci-Fi: Symmetric Constraint for Frame Inbetweening

    Upload start and end frames to generate smooth inbetweening video.

    **Note:** Make sure to load the pipeline first before generating videos.
    """
    )

    with gr.Tab("Generate"):
        with gr.Row():
            with gr.Column():
                # Key-frame inputs, delivered to the pipeline as PIL images.
                first_image = gr.Image(label="Start Frame", type="pil")
                last_image = gr.Image(label="End Frame", type="pil")

            with gr.Column():
                prompt = gr.Textbox(
                    label="Prompt",
                    placeholder="Describe the motion or content...",
                    lines=3,
                )

                with gr.Accordion("Advanced Settings", open=False):
                    # Frame count is restricted to 13, 25, 37 or 49 by the
                    # slider's min/max/step configuration.
                    num_frames = gr.Slider(
                        minimum=13,
                        maximum=49,
                        value=49,
                        step=12,
                        label="Number of Frames",
                    )
                    # Classifier-free guidance strength.
                    guidance_scale = gr.Slider(
                        minimum=1.0,
                        maximum=15.0,
                        value=6.0,
                        step=0.5,
                        label="Guidance Scale",
                    )
                    # Strength of the EF-Net conditioning signal.
                    ef_net_weights = gr.Slider(
                        minimum=0.0,
                        maximum=2.0,
                        value=1.0,
                        step=0.1,
                        label="EF-Net Weights",
                    )
                    # Fraction of the denoising schedule where EF-Net
                    # guidance starts / stops (passed straight to the pipe).
                    ef_net_guidance_start = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=0.0,
                        step=0.1,
                        label="EF-Net Guidance Start",
                    )
                    ef_net_guidance_end = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=1.0,
                        step=0.1,
                        label="EF-Net Guidance End",
                    )
                    # Integer RNG seed for reproducible results.
                    seed = gr.Number(label="Seed", value=42, precision=0)

                generate_btn = gr.Button("Generate Video", variant="primary", size="lg")

        with gr.Row():
            output_video = gr.Video(label="Generated Video")
            status_text = gr.Textbox(label="Status", lines=2)

        # Wire the button to generate_inbetweening; input order must match
        # that function's positional parameters.
        generate_btn.click(
            fn=generate_inbetweening,
            inputs=[
                first_image,
                last_image,
                prompt,
                num_frames,
                guidance_scale,
                ef_net_weights,
                ef_net_guidance_start,
                ef_net_guidance_end,
                seed,
            ],
            outputs=[output_video, status_text],
        )

    with gr.Tab("Setup"):
        gr.Markdown(
            """
    ## Load Pipeline

    Configure and load the model before generating videos.

    **Default paths:**
    - Model: `THUDM/CogVideoX-5b` (or your downloaded path)
    - EF-Net: `weights/EF_Net.pth`
    """
        )

        with gr.Row():
            model_path = gr.Textbox(
                label="Pretrained Model Path",
                value="THUDM/CogVideoX-5b",
                placeholder="Path to CogVideoX model",
            )
            ef_net_path = gr.Textbox(
                label="EF-Net Checkpoint Path",
                value="weights/EF_Net.pth",
                placeholder="Path to EF-Net weights",
            )

        # Precision used when casting the loaded pipeline.
        dtype_choice = gr.Radio(
            choices=["bfloat16", "float16"], value="bfloat16", label="Data Type"
        )

        load_btn = gr.Button("Load Pipeline", variant="primary")
        load_status = gr.Textbox(label="Load Status", interactive=False)

        # Wire the button to load_pipeline; it returns a status string.
        load_btn.click(
            fn=load_pipeline,
            inputs=[model_path, ef_net_path, dtype_choice],
            outputs=load_status,
        )

    with gr.Tab("Examples"):
        gr.Markdown(
            """
    ## Example Inputs

    Try these example frame pairs from the `example_input_pairs/` folder.
    """
        )

        # Each example fills (start frame, end frame, prompt) in the
        # Generate tab's inputs.
        gr.Examples(
            examples=[
                [
                    "example_input_pairs/input_pair1/start.jpg",
                    "example_input_pairs/input_pair1/end.jpg",
                    "A smooth transition between frames",
                ],
                [
                    "example_input_pairs/input_pair2/start.jpg",
                    "example_input_pairs/input_pair2/end.jpg",
                    "Natural motion interpolation",
                ],
            ],
            inputs=[first_image, last_image, prompt],
        )
|
|
|
|
|
# Script entry point: eagerly load the model once, then start the web UI.
if __name__ == "__main__":
    print("Loading pipeline automatically on startup...")
    try:
        load_pipeline()
    except Exception as exc:
        # Startup loading is best-effort; the Setup tab remains available.
        print(f"Failed to load pipeline on startup: {exc}")
        print("You can manually load it from the Setup tab.")
    else:
        print("Pipeline loaded successfully!")

    demo.launch()
|
|
|