Spaces · Paused
Commit 2246e78
Parent(s): c599f41
Refactor app for automatic model loading on startup

- Load model globally at startup (no button needed)
- Move CUDA initialization inside @spaces.GPU decorated function (sketched below)
- Remove model loading UI components
- Add 120s duration to GPU decorator for longer generation
- Simplify app flow for better UX

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <noreply@anthropic.com>
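For orientation, here is a minimal sketch of the pattern this commit moves to, condensed from the app.py diff below. The short `(first_image, last_image, prompt)` signature and the stub return value are simplifications for illustration; the real `generate_interpolation` takes the full parameter list shown in the diff.

```python
import spaces  # Hugging Face Spaces ZeroGPU helper
import torch
from cogvideox_interpolation.pipeline import CogVideoXInterpolationPipeline

# Model is loaded once at import time (on CPU), so no "Load Model" button is needed.
MODEL_PATH = "feizhengcong/CogvideoX-Interpolation"
pipe = CogVideoXInterpolationPipeline.from_pretrained(MODEL_PATH, torch_dtype=torch.float16)
pipe.vae.enable_tiling()
pipe.vae.enable_slicing()


@spaces.GPU(duration=120)  # request the GPU for up to 120 s per call
def generate_interpolation(first_image, last_image, prompt):
    # CUDA is only touched inside the decorated function, where a GPU is attached.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    pipe.to(device)
    # ... interpolation call elided; see the full function in the diff below.
    return None, "stub"
```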
app.py CHANGED

```diff
@@ -9,38 +9,21 @@ from PIL import Image
 
 from cogvideox_interpolation.pipeline import CogVideoXInterpolationPipeline
 
-#
-
-
+# Load model globally at startup
+print("Loading CogVideoX-Interpolation model...")
+MODEL_PATH = "feizhengcong/CogvideoX-Interpolation"
+dtype = torch.float16
 
-@spaces.GPU
-def load_model(model_path):
-    """Load the CogVideoX-Interpolation model"""
-    global pipe
-
-    print(f"Using device: {device}")
-
-    # Determine dtype based on model variant
-    dtype = torch.bfloat16 if "5b" in model_path.lower() else torch.float16
-
-    pipe = CogVideoXInterpolationPipeline.from_pretrained(model_path, torch_dtype=dtype)
-
-    # Memory optimization
-    if device == "cuda":
-        pipe.enable_sequential_cpu_offload()
-    else:
-        pipe = pipe.to(device)
-
-    pipe.vae.enable_tiling()
-    pipe.vae.enable_slicing()
-
-    print("Model loaded successfully!")
-    return "✓ Model loaded successfully!"
-
-
-@spaces.GPU
+pipe = CogVideoXInterpolationPipeline.from_pretrained(
+    MODEL_PATH,
+    torch_dtype=dtype
+)
+pipe.vae.enable_tiling()
+pipe.vae.enable_slicing()
+print("Model loaded successfully!")
 
+
+@spaces.GPU(duration=120)
 def generate_interpolation(
     first_image,
     last_image,
@@ -53,9 +36,6 @@ def generate_interpolation(
 ):
     """Generate interpolated video between two keyframes"""
 
-    if pipe is None:
-        return None, "⚠️ Please load the model first!"
-
     if first_image is None or last_image is None:
         return None, "⚠️ Please upload both start and end frame images!"
 
@@ -63,6 +43,11 @@
         return None, "⚠️ Please provide a text prompt describing the motion!"
 
     try:
+        # Move model to CUDA inside the decorated function
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        pipe.to(device)
+        pipe.enable_sequential_cpu_offload()
+
         # Convert numpy arrays to PIL Images if needed
         if not isinstance(first_image, Image.Image):
             first_image = Image.fromarray(first_image)
@@ -115,26 +100,12 @@ with gr.Blocks(title="CogVideoX Keyframe Interpolation") as demo:
     Generate smooth video transitions between two keyframe images using AI.
 
     **Instructions:**
-    1.
-    2.
-    3.
-    4. Adjust parameters and generate!
+    1. Upload start and end frame images
+    2. Describe the motion/transition in the text prompt
+    3. Adjust parameters and generate!
     """
     )
 
-    with gr.Row():
-        with gr.Column():
-            gr.Markdown("### 🔧 Model Setup")
-            model_path_input = gr.Textbox(
-                label="Model Path",
-                placeholder="e.g., /path/to/CogVideoX-5b-I2V-inter or feizhengcong/CogvideoX-Interpolation",
-                value="feizhengcong/CogvideoX-Interpolation",
-            )
-            load_btn = gr.Button("Load Model", variant="primary")
-            model_status = gr.Textbox(label="Status", interactive=False)
-
-    gr.Markdown("---")
-
     with gr.Row():
         with gr.Column():
             gr.Markdown("### 🖼️ Input Keyframes")
@@ -210,8 +181,6 @@ with gr.Blocks(title="CogVideoX Keyframe Interpolation") as demo:
     )
 
     # Event handlers
-    load_btn.click(fn=load_model, inputs=[model_path_input], outputs=[model_status])
-
     generate_btn.click(
         fn=generate_interpolation,
         inputs=[
@@ -228,16 +197,4 @@ with gr.Blocks(title="CogVideoX Keyframe Interpolation") as demo:
     )
 
 if __name__ == "__main__":
-    print("=" * 50)
-    print("CogVideoX Keyframe Interpolation Gradio App")
-    print("=" * 50)
-    print(f"Device: {device}")
-    print(f"CUDA available: {torch.cuda.is_available()}")
-    if torch.cuda.is_available():
-        print(f"GPU: {torch.cuda.get_device_name(0)}")
-        print(
-            f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB"
-        )
-    print("=" * 50)
-
     demo.launch()
```
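The refactor follows the usual ZeroGPU execution model: the Space starts with no GPU attached, so the weights are loaded to CPU at import time and only moved to CUDA inside the `@spaces.GPU`-decorated call, where a device is attached for the duration of the request; `duration=120` widens that per-call allocation beyond the default window so longer interpolations can finish.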