# Sci-Fi / app.py
# Last commit: d230b19 ("rename") by AhmadMustafa; file size 9.4 kB.
import tempfile
import time
import traceback

import gradio as gr
import spaces
import torch
from diffusers import AutoencoderKLCogVideoX, CogVideoXDDIMScheduler
from diffusers.utils import export_to_video
from PIL import Image
from transformers import T5EncoderModel, T5Tokenizer

from cogvideo_transformer import CustomCogVideoXTransformer3DModel
from EF_Net import EF_Net
from Sci_Fi_inbetweening_pipeline import CogVideoXEFNetInbetweeningPipeline
# Module-level state shared by the Gradio callbacks.
# `pipe` stays None until load_pipeline() has built the model.
pipe = None

# Prefer the GPU whenever torch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
@spaces.GPU(duration=1000)
def load_pipeline(
    pretrained_model_path="THUDM/CogVideoX-5b",
    ef_net_path="weights/EF_Net.pth",
    dtype_str="bfloat16",
):
    """Load the Sci-Fi pipeline and publish it in the module-level ``pipe``.

    Args:
        pretrained_model_path: Hugging Face repo id or local directory with the
            ``tokenizer``, ``text_encoder``, ``transformer``, ``vae`` and
            ``scheduler`` subfolders.
        ef_net_path: Path to the EF-Net checkpoint. It must contain a
            ``"state_dict"`` entry.
        dtype_str: ``"float16"`` selects ``torch.float16``. Any other value
            selects ``torch.bfloat16``.

    Returns:
        A status message shown in the Setup tab.

    Raises:
        Whatever the underlying loaders raise (missing files, download errors,
        out-of-memory, ...). On failure, any previously published pipeline is
        left untouched.
    """
    global pipe
    dtype = torch.float16 if dtype_str == "float16" else torch.bfloat16

    # Load the CogVideoX components.
    tokenizer = T5Tokenizer.from_pretrained(
        pretrained_model_path, subfolder="tokenizer"
    )
    text_encoder = T5EncoderModel.from_pretrained(
        pretrained_model_path, subfolder="text_encoder"
    )
    transformer = CustomCogVideoXTransformer3DModel.from_pretrained(
        pretrained_model_path, subfolder="transformer"
    )
    vae = AutoencoderKLCogVideoX.from_pretrained(pretrained_model_path, subfolder="vae")
    scheduler = CogVideoXDDIMScheduler.from_pretrained(
        pretrained_model_path, subfolder="scheduler"
    )

    # Load EF-Net, frozen and in eval mode (inference only).
    ef_net_model = (
        EF_Net(num_layers=4, downscale_coef=8, in_channels=2, num_attention_heads=48)
        .requires_grad_(False)
        .eval()
    )
    # SECURITY NOTE(review): weights_only=False unpickles arbitrary objects, and
    # ef_net_path can be typed into the Setup tab. Any file reachable on disk can
    # therefore execute code at load time. Switch to weights_only=True if the
    # checkpoint only holds tensors/plain containers (TODO confirm).
    ckpt = torch.load(ef_net_path, map_location="cpu", weights_only=False)
    # Plain-dict copy of the checkpoint's state dict (same contents as the
    # former identity comprehension).
    state_dict = dict(ckpt["state_dict"])
    missing, unexpected = ef_net_model.load_state_dict(state_dict, strict=False)
    print(f"[EF-Net loaded] Missing: {len(missing)} | Unexpected: {len(unexpected)}")
    # The weights have been copied into the model, so the CPU copies can go
    # before the large device transfer below.
    del ckpt, state_dict

    # Assemble the pipeline in a local name. The global is only replaced once
    # everything below has succeeded.
    pipeline = CogVideoXEFNetInbetweeningPipeline(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        transformer=transformer,
        vae=vae,
        EF_Net_model=ef_net_model,
        scheduler=scheduler,
    )
    # Swap in a scheduler with "trailing" timestep spacing, keeping the rest of
    # the loaded scheduler config.
    pipeline.scheduler = CogVideoXDDIMScheduler.from_config(
        pipeline.scheduler.config, timestep_spacing="trailing"
    )
    # Cast first, then move. Moving the float32 weights before casting needs
    # roughly twice the device memory of the final half-precision model.
    pipeline = pipeline.to(dtype=dtype)
    pipeline = pipeline.to(device)
    pipeline.vae.enable_slicing()
    pipeline.vae.enable_tiling()

    # Publish only a fully configured pipeline, so a failure above (e.g. OOM
    # during the device move) never leaves generate_inbetweening with a
    # half-initialised one.
    # NOTE(review): on ZeroGPU hardware, @spaces.GPU functions may execute in a
    # separate worker process, in which case this global assignment might not
    # be visible to later requests. Confirm on the target hardware.
    pipe = pipeline
    return "Pipeline loaded successfully!"
@spaces.GPU(duration=1000)
def generate_inbetweening(
    first_image: Image.Image,
    last_image: Image.Image,
    prompt: str,
    num_frames: int = 49,
    guidance_scale: float = 6.0,
    ef_net_weights: float = 1.0,
    ef_net_guidance_start: float = 0.0,
    ef_net_guidance_end: float = 1.0,
    seed: int = 42,
    progress=gr.Progress(),
):
    """Generate an in-between video from a start frame, an end frame and a prompt.

    Args:
        first_image: Start frame (PIL image).
        last_image: End frame (PIL image).
        prompt: Text prompt. It must contain non-whitespace characters.
        num_frames: Number of frames to generate (coerced to ``int``).
        guidance_scale: Guidance scale forwarded to the pipeline (dynamic CFG
            is disabled).
        ef_net_weights: Forwarded as ``EF_Net_weights``.
        ef_net_guidance_start: Forwarded as ``EF_Net_guidance_start``. It must
            not exceed ``ef_net_guidance_end``.
        ef_net_guidance_end: Forwarded as ``EF_Net_guidance_end``.
        seed: Seed for a ``torch.Generator`` on ``device`` (coerced to ``int``).
        progress: Gradio progress tracker, injected by Gradio.

    Returns:
        ``(video_path, status_message)``. ``video_path`` is ``None`` when input
        validation fails or generation raises. Errors are reported through the
        status message (with a traceback printed to stderr) rather than raised.
    """
    if pipe is None:
        return None, "Please load the pipeline first!"
    if first_image is None or last_image is None:
        return None, "Please upload both start and end frames!"
    # Guard against None as well as empty/whitespace-only prompts.
    if not prompt or not prompt.strip():
        return None, "Please provide a text prompt!"
    if ef_net_guidance_start > ef_net_guidance_end:
        return None, "EF-Net Guidance Start must not exceed EF-Net Guidance End!"
    try:
        progress(0, desc="Starting generation...")
        start_time = time.time()

        # Generate video
        progress(0.2, desc="Processing frames...")
        video_frames = pipe(
            first_image=first_image,
            last_image=last_image,
            prompt=prompt,
            # UI widgets may deliver floats; the pipeline and manual_seed want ints.
            num_frames=int(num_frames),
            use_dynamic_cfg=False,
            guidance_scale=guidance_scale,
            generator=torch.Generator(device=device).manual_seed(int(seed)),
            EF_Net_weights=ef_net_weights,
            EF_Net_guidance_start=ef_net_guidance_start,
            EF_Net_guidance_end=ef_net_guidance_end,
        ).frames[0]

        progress(0.9, desc="Exporting video...")
        # A unique temp file stops concurrent requests that finish in the same
        # second from overwriting each other's output. It also keeps results
        # out of the working directory.
        with tempfile.NamedTemporaryFile(
            prefix="output_", suffix=".mp4", delete=False
        ) as tmp:
            output_path = tmp.name
        export_to_video(video_frames, output_path, fps=7)

        elapsed_time = time.time() - start_time
        status_msg = f"Video generated successfully in {elapsed_time:.2f}s"
        progress(1.0, desc="Done!")
        return output_path, status_msg
    except Exception as e:
        # UI boundary: report the error to the user, but keep the traceback in
        # the server logs for debugging.
        traceback.print_exc()
        return None, f"Error: {e}"
# Create Gradio interface
# Components are created in page order, so the statement order below defines
# the layout: a Generate tab, a Setup tab and an Examples tab.
with gr.Blocks(title="Sci-Fi: Frame Inbetweening") as demo:
    gr.Markdown(
        """
# Sci-Fi: Symmetric Constraint for Frame Inbetweening
Upload start and end frames to generate smooth inbetweening video.
**Note:** Make sure to load the pipeline first before generating videos.
"""
    )
    with gr.Tab("Generate"):
        with gr.Row():
            # Left column: the two key frames, delivered to callbacks as PIL images.
            with gr.Column():
                first_image = gr.Image(label="Start Frame", type="pil")
                last_image = gr.Image(label="End Frame", type="pil")
            # Right column: prompt, advanced generation settings, run button.
            with gr.Column():
                prompt = gr.Textbox(
                    label="Prompt",
                    placeholder="Describe the motion or content...",
                    lines=3,
                )
                with gr.Accordion("Advanced Settings", open=False):
                    # Range 13..49 with step 12 allows 13, 25, 37 or 49 frames.
                    # NOTE(review): presumably chosen to match the model's
                    # supported clip lengths; confirm.
                    num_frames = gr.Slider(
                        minimum=13,
                        maximum=49,
                        value=49,
                        step=12,
                        label="Number of Frames",
                    )
                    guidance_scale = gr.Slider(
                        minimum=1.0,
                        maximum=15.0,
                        value=6.0,
                        step=0.5,
                        label="Guidance Scale",
                    )
                    ef_net_weights = gr.Slider(
                        minimum=0.0,
                        maximum=2.0,
                        value=1.0,
                        step=0.1,
                        label="EF-Net Weights",
                    )
                    # Start/end are fractions in [0, 1] passed straight to the
                    # pipeline's EF_Net_guidance_start / EF_Net_guidance_end.
                    ef_net_guidance_start = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=0.0,
                        step=0.1,
                        label="EF-Net Guidance Start",
                    )
                    ef_net_guidance_end = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=1.0,
                        step=0.1,
                        label="EF-Net Guidance End",
                    )
                    # precision=0 rounds the entered value to a whole number.
                    seed = gr.Number(label="Seed", value=42, precision=0)
                generate_btn = gr.Button("Generate Video", variant="primary", size="lg")
        with gr.Row():
            output_video = gr.Video(label="Generated Video")
            status_text = gr.Textbox(label="Status", lines=2)
        # Input order must match generate_inbetweening's positional parameters.
        generate_btn.click(
            fn=generate_inbetweening,
            inputs=[
                first_image,
                last_image,
                prompt,
                num_frames,
                guidance_scale,
                ef_net_weights,
                ef_net_guidance_start,
                ef_net_guidance_end,
                seed,
            ],
            outputs=[output_video, status_text],
        )
    with gr.Tab("Setup"):
        gr.Markdown(
            """
## Load Pipeline
Configure and load the model before generating videos.
**Default paths:**
- Model: `THUDM/CogVideoX-5b` (or your downloaded path)
- EF-Net: `weights/EF_Net.pth`
"""
        )
        # Defaults mirror load_pipeline's own keyword defaults.
        with gr.Row():
            model_path = gr.Textbox(
                label="Pretrained Model Path",
                value="THUDM/CogVideoX-5b",
                placeholder="Path to CogVideoX model",
            )
            ef_net_path = gr.Textbox(
                label="EF-Net Checkpoint Path",
                value="weights/EF_Net.pth",
                placeholder="Path to EF-Net weights",
            )
        dtype_choice = gr.Radio(
            choices=["bfloat16", "float16"], value="bfloat16", label="Data Type"
        )
        load_btn = gr.Button("Load Pipeline", variant="primary")
        # Read-only: shows the message returned by load_pipeline.
        load_status = gr.Textbox(label="Load Status", interactive=False)
        load_btn.click(
            fn=load_pipeline,
            inputs=[model_path, ef_net_path, dtype_choice],
            outputs=load_status,
        )
    with gr.Tab("Examples"):
        gr.Markdown(
            """
## Example Inputs
Try these example frame pairs from the `example_input_pairs/` folder.
"""
        )
        # Each row supplies values for (first_image, last_image, prompt) in the
        # Generate tab. Paths are relative to the working directory.
        gr.Examples(
            examples=[
                [
                    "example_input_pairs/input_pair1/start.jpg",
                    "example_input_pairs/input_pair1/end.jpg",
                    "A smooth transition between frames",
                ],
                [
                    "example_input_pairs/input_pair2/start.jpg",
                    "example_input_pairs/input_pair2/end.jpg",
                    "Natural motion interpolation",
                ],
            ],
            inputs=[first_image, last_image, prompt],
        )
if __name__ == "__main__":
    # Best-effort warm start. If loading fails, the UI still launches and the
    # user can retry from the Setup tab.
    print("Loading pipeline automatically on startup...")
    try:
        load_pipeline()
    except Exception as exc:
        print(f"Failed to load pipeline on startup: {exc}")
        print("You can manually load it from the Setup tab.")
    else:
        print("Pipeline loaded successfully!")
    demo.launch()