File size: 7,926 Bytes
068b511
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
import gradio as gr
import torch
from diffusers.utils import export_to_video
from cogvideox_interpolation.pipeline import CogVideoXInterpolationPipeline
from PIL import Image
import tempfile
import os

# Module-level pipeline handle; stays None until load_model() assigns it.
pipe = None

# Run on the GPU whenever torch reports CUDA support, otherwise on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

def load_model(model_path):
    """Load the CogVideoX-Interpolation pipeline into the module-level ``pipe``.

    Args:
        model_path: Local checkpoint directory or Hub repo id, forwarded to
            ``CogVideoXInterpolationPipeline.from_pretrained``.

    Returns:
        A status string for the UI: the success message, a warning when no
        usable path was given, or an error message when loading failed.
        Loading errors are reported through the return value rather than
        raised, matching how ``generate_interpolation`` reports failures.
    """
    global pipe

    # Reject a missing/blank path here instead of failing inside the loader
    # with an error that does not mention the actual cause.
    if not isinstance(model_path, str) or not model_path.strip():
        return "⚠️ Please provide a model path!"
    model_path = model_path.strip()

    print(f"Loading model from {model_path}...")
    print(f"Using device: {device}")

    # Determine dtype based on model variant
    # NOTE(review): a path without "5b" in it (including the UI's default
    # repo id) is loaded in float16 — confirm that matches the checkpoint.
    dtype = torch.bfloat16 if "5b" in model_path.lower() else torch.float16

    try:
        new_pipe = CogVideoXInterpolationPipeline.from_pretrained(
            model_path,
            torch_dtype=dtype,
        )

        # Memory optimization
        if device == "cuda":
            new_pipe.enable_sequential_cpu_offload()
        else:
            new_pipe = new_pipe.to(device)

        new_pipe.vae.enable_tiling()
        new_pipe.vae.enable_slicing()
    except Exception as e:
        # Top-level UI boundary: surface the failure in the status box and
        # keep whatever pipeline was loaded before (if any) untouched.
        error_msg = f"❌ Error loading model: {str(e)}"
        print(error_msg)
        return error_msg

    # Publish the pipeline only once it is fully configured, so a failure
    # half-way through setup never leaves a partially prepared global.
    pipe = new_pipe

    print("Model loaded successfully!")
    return "✓ Model loaded successfully!"

def _first_video_frames(output):
    """Return the frame sequence of the first video in a pipeline result.

    Accepts either an output object exposing ``frames`` or a plain tuple
    whose first element holds the frames. If the frames are a batch of
    videos (first element is itself a list/tuple of frames, or a 4-D array),
    the first video is returned; otherwise the frames are assumed to already
    be a single video and are returned unchanged.
    """
    frames = getattr(output, "frames", None)
    if frames is None:
        frames = output[0]

    first = frames[0]
    if isinstance(first, (list, tuple)) or getattr(first, "ndim", 0) == 4:
        return first
    return frames


def generate_interpolation(
    first_image,
    last_image,
    prompt,
    num_frames=49,
    num_inference_steps=50,
    guidance_scale=6.0,
    fps=8,
    seed=42
):
    """Generate an interpolated video between two keyframes.

    Args:
        first_image: Start frame, a PIL image or an array accepted by
            ``Image.fromarray``.
        last_image: End frame, same accepted types as ``first_image``.
        prompt: Text description of the motion between the two frames.
        num_frames: Number of frames to generate.
        num_inference_steps: Number of denoising steps.
        guidance_scale: Guidance scale forwarded to the pipeline.
        fps: Frame rate of the exported video.
        seed: Seed for the random generator.

    Returns:
        A ``(video_path, status)`` tuple. ``video_path`` is the path of the
        exported ``.mp4`` file, or ``None`` when validation or generation
        failed; ``status`` is a message describing the outcome.
    """

    if pipe is None:
        return None, "⚠️ Please load the model first!"

    if first_image is None or last_image is None:
        return None, "⚠️ Please upload both start and end frame images!"

    # Guard against None as well as blank text so an empty prompt yields the
    # warning instead of an AttributeError outside the try block.
    if not prompt or not prompt.strip():
        return None, "⚠️ Please provide a text prompt describing the motion!"

    output_path = None
    try:
        # UI widgets may deliver whole numbers as floats; the generator seed
        # in particular must be an int.
        num_frames = int(num_frames)
        num_inference_steps = int(num_inference_steps)
        guidance_scale = float(guidance_scale)
        fps = int(fps)
        seed = int(seed)

        # Convert numpy arrays to PIL Images if needed
        if not isinstance(first_image, Image.Image):
            first_image = Image.fromarray(first_image)
        if not isinstance(last_image, Image.Image):
            last_image = Image.fromarray(last_image)

        print(f"Generating video with prompt: {prompt}")
        print(f"Parameters: frames={num_frames}, steps={num_inference_steps}, guidance={guidance_scale}")

        # Generate video
        generator = torch.Generator(device=device).manual_seed(seed)

        output = pipe(
            prompt=prompt,
            first_image=first_image,
            last_image=last_image,
            num_videos_per_prompt=1,
            num_inference_steps=num_inference_steps,
            num_frames=num_frames,
            guidance_scale=guidance_scale,
            generator=generator,
        )
        # The pipeline result may be batched per video; export needs the
        # frames of a single video.
        video = _first_video_frames(output)

        # Reserve a temporary .mp4 path; the handle is closed before export
        # so the writer can open the file itself.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_file:
            output_path = temp_file.name

        export_to_video(video, output_path, fps=fps)

        status = f"✓ Video generated successfully! ({num_frames} frames at {fps} fps)"
        print(status)

        return output_path, status

    except Exception as e:
        # Do not leave an empty/partial temp file behind when export fails.
        if output_path is not None and os.path.exists(output_path):
            try:
                os.remove(output_path)
            except OSError:
                pass  # best-effort cleanup; the original error is reported below
        error_msg = f"❌ Error: {str(e)}"
        print(error_msg)
        return None, error_msg

# Create Gradio interface
# Sections appear in this order: intro text, model loader, keyframe inputs
# beside the generation settings, output area, example prompts, event wiring.
with gr.Blocks(title="CogVideoX Keyframe Interpolation") as demo:
    gr.Markdown("""
    # 🎬 CogVideoX Keyframe Interpolation

    Generate smooth video transitions between two keyframe images using AI.

    **Instructions:**
    1. First, load the model by providing the path to your checkpoint
    2. Upload start and end frame images
    3. Describe the motion/transition in the text prompt
    4. Adjust parameters and generate!
    """)

    # Model loader: checkpoint path, load button and a read-only status box.
    with gr.Row(), gr.Column():
        gr.Markdown("### 🔧 Model Setup")
        model_path_input = gr.Textbox(
            value="feizhengcong/CogvideoX-Interpolation",
            placeholder="e.g., /path/to/CogVideoX-5b-I2V-inter or feizhengcong/CogvideoX-Interpolation",
            label="Model Path",
        )
        load_btn = gr.Button("Load Model", variant="primary")
        model_status = gr.Textbox(interactive=False, label="Status")

    gr.Markdown("---")

    with gr.Row():
        # Left column: the two keyframes, delivered to the handler as PIL images.
        with gr.Column():
            gr.Markdown("### 🖼️ Input Keyframes")
            first_image_input = gr.Image(type="pil", height=300, label="Start Frame")
            last_image_input = gr.Image(type="pil", height=300, label="End Frame")

        # Right column: prompt, sampling controls and the generate button.
        with gr.Column():
            gr.Markdown("### ⚙️ Generation Settings")
            prompt_input = gr.Textbox(
                lines=4,
                placeholder="Describe the motion/transition between the frames...",
                label="Motion Description",
            )

            with gr.Row():
                num_frames_slider = gr.Slider(
                    minimum=13, maximum=49, step=4, value=49,
                    label="Number of Frames",
                    info="Must be 4k+1 format (13, 17, 21, ..., 49)",
                )
                fps_slider = gr.Slider(
                    minimum=4, maximum=16, step=2, value=8,
                    label="FPS",
                )

            with gr.Row():
                num_steps_slider = gr.Slider(
                    minimum=20, maximum=100, step=5, value=50,
                    label="Inference Steps",
                    info="More steps = better quality but slower",
                )
                guidance_slider = gr.Slider(
                    minimum=1.0, maximum=15.0, step=0.5, value=6.0,
                    label="Guidance Scale",
                    info="Higher = stronger prompt following",
                )

            seed_input = gr.Number(value=42, precision=0, label="Random Seed")

            generate_btn = gr.Button("🎬 Generate Video", size="lg", variant="primary")

    gr.Markdown("---")

    # Output area: the rendered video plus the status text from generation.
    with gr.Row(), gr.Column():
        gr.Markdown("### 🎥 Generated Video")
        output_video = gr.Video(label="Output")
        generation_status = gr.Textbox(interactive=False, label="Generation Status")

    # Examples
    gr.Markdown("---")
    gr.Markdown("### 💡 Example Prompts")
    gr.Examples(
        label="Click to use example prompts",
        inputs=prompt_input,
        examples=[
            ["A person walks forward slowly, their body moving naturally with each step."],
            ["The camera smoothly pans from left to right, revealing the scene."],
            ["A dancer gracefully transitions from one pose to another."],
            ["The sun sets gradually, changing the lighting and colors of the scene."],
            ["A car accelerates down the street, moving from standstill to motion."],
        ],
    )

    # Event handlers
    load_btn.click(fn=load_model, inputs=[model_path_input], outputs=[model_status])

    # Input order mirrors the positional parameters of generate_interpolation.
    generate_btn.click(
        fn=generate_interpolation,
        inputs=[
            first_image_input, last_image_input, prompt_input,
            num_frames_slider, num_steps_slider, guidance_slider,
            fps_slider, seed_input,
        ],
        outputs=[output_video, generation_status],
    )

if __name__ == "__main__":
    # Print a short startup banner describing the compute environment.
    banner_rule = "=" * 50
    cuda_ok = torch.cuda.is_available()

    print(banner_rule)
    print("CogVideoX Keyframe Interpolation Gradio App")
    print(banner_rule)
    print(f"Device: {device}")
    print(f"CUDA available: {cuda_ok}")
    if cuda_ok:
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        total_bytes = torch.cuda.get_device_properties(0).total_memory
        print(f"GPU Memory: {total_bytes / 1024**3:.2f} GB")
    print(banner_rule)

    # Serve on all interfaces at port 7860, without a public share link.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
    }
    demo.launch(**launch_options)