AhmadMustafa Claude committed on
Commit
2246e78
·
1 Parent(s): c599f41

Refactor app for automatic model loading on startup

Browse files

- Load model globally at startup (no button needed)
- Move CUDA initialization inside @spaces.GPU decorated function
- Remove model loading UI components
- Add 120s duration to GPU decorator for longer generation
- Simplify app flow for better UX

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

Files changed (1) hide show
  1. app.py +20 -63
app.py CHANGED
@@ -9,38 +9,21 @@ from PIL import Image
9
 
10
  from cogvideox_interpolation.pipeline import CogVideoXInterpolationPipeline
11
 
12
- # Global variable to store the pipeline
13
- pipe = None
14
- device = "cuda" if torch.cuda.is_available() else "cpu"
 
15
 
 
 
 
 
 
 
 
16
 
17
- @spaces.GPU
18
- def load_model(model_path):
19
- """Load the CogVideoX-Interpolation model"""
20
- global pipe
21
 
22
- print(f"Loading model from {model_path}...")
23
- print(f"Using device: {device}")
24
-
25
- # Determine dtype based on model variant
26
- dtype = torch.bfloat16 if "5b" in model_path.lower() else torch.float16
27
-
28
- pipe = CogVideoXInterpolationPipeline.from_pretrained(model_path, torch_dtype=dtype)
29
-
30
- # Memory optimization
31
- if device == "cuda":
32
- pipe.enable_sequential_cpu_offload()
33
- else:
34
- pipe = pipe.to(device)
35
-
36
- pipe.vae.enable_tiling()
37
- pipe.vae.enable_slicing()
38
-
39
- print("Model loaded successfully!")
40
- return "✓ Model loaded successfully!"
41
-
42
-
43
- @spaces.GPU
44
  def generate_interpolation(
45
  first_image,
46
  last_image,
@@ -53,9 +36,6 @@ def generate_interpolation(
53
  ):
54
  """Generate interpolated video between two keyframes"""
55
 
56
- if pipe is None:
57
- return None, "⚠️ Please load the model first!"
58
-
59
  if first_image is None or last_image is None:
60
  return None, "⚠️ Please upload both start and end frame images!"
61
 
@@ -63,6 +43,11 @@ def generate_interpolation(
63
  return None, "⚠️ Please provide a text prompt describing the motion!"
64
 
65
  try:
 
 
 
 
 
66
  # Convert numpy arrays to PIL Images if needed
67
  if not isinstance(first_image, Image.Image):
68
  first_image = Image.fromarray(first_image)
@@ -115,26 +100,12 @@ with gr.Blocks(title="CogVideoX Keyframe Interpolation") as demo:
115
  Generate smooth video transitions between two keyframe images using AI.
116
 
117
  **Instructions:**
118
- 1. First, load the model by providing the path to your checkpoint
119
- 2. Upload start and end frame images
120
- 3. Describe the motion/transition in the text prompt
121
- 4. Adjust parameters and generate!
122
  """
123
  )
124
 
125
- with gr.Row():
126
- with gr.Column():
127
- gr.Markdown("### 🔧 Model Setup")
128
- model_path_input = gr.Textbox(
129
- label="Model Path",
130
- placeholder="e.g., /path/to/CogVideoX-5b-I2V-inter or feizhengcong/CogvideoX-Interpolation",
131
- value="feizhengcong/CogvideoX-Interpolation",
132
- )
133
- load_btn = gr.Button("Load Model", variant="primary")
134
- model_status = gr.Textbox(label="Status", interactive=False)
135
-
136
- gr.Markdown("---")
137
-
138
  with gr.Row():
139
  with gr.Column():
140
  gr.Markdown("### 🖼️ Input Keyframes")
@@ -210,8 +181,6 @@ with gr.Blocks(title="CogVideoX Keyframe Interpolation") as demo:
210
  )
211
 
212
  # Event handlers
213
- load_btn.click(fn=load_model, inputs=[model_path_input], outputs=[model_status])
214
-
215
  generate_btn.click(
216
  fn=generate_interpolation,
217
  inputs=[
@@ -228,16 +197,4 @@ with gr.Blocks(title="CogVideoX Keyframe Interpolation") as demo:
228
  )
229
 
230
  if __name__ == "__main__":
231
- print("=" * 50)
232
- print("CogVideoX Keyframe Interpolation Gradio App")
233
- print("=" * 50)
234
- print(f"Device: {device}")
235
- print(f"CUDA available: {torch.cuda.is_available()}")
236
- if torch.cuda.is_available():
237
- print(f"GPU: {torch.cuda.get_device_name(0)}")
238
- print(
239
- f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB"
240
- )
241
- print("=" * 50)
242
-
243
  demo.launch()
 
9
 
10
  from cogvideox_interpolation.pipeline import CogVideoXInterpolationPipeline
11
 
12
+ # Load model globally at startup
13
+ print("Loading CogVideoX-Interpolation model...")
14
+ MODEL_PATH = "feizhengcong/CogvideoX-Interpolation"
15
+ dtype = torch.float16
16
 
17
+ pipe = CogVideoXInterpolationPipeline.from_pretrained(
18
+ MODEL_PATH,
19
+ torch_dtype=dtype
20
+ )
21
+ pipe.vae.enable_tiling()
22
+ pipe.vae.enable_slicing()
23
+ print("Model loaded successfully!")
24
 
 
 
 
 
25
 
26
+ @spaces.GPU(duration=120)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  def generate_interpolation(
28
  first_image,
29
  last_image,
 
36
  ):
37
  """Generate interpolated video between two keyframes"""
38
 
 
 
 
39
  if first_image is None or last_image is None:
40
  return None, "⚠️ Please upload both start and end frame images!"
41
 
 
43
  return None, "⚠️ Please provide a text prompt describing the motion!"
44
 
45
  try:
46
+ # Move model to CUDA inside the decorated function
47
+ device = "cuda" if torch.cuda.is_available() else "cpu"
48
+ pipe.to(device)
49
+ pipe.enable_sequential_cpu_offload()
50
+
51
  # Convert numpy arrays to PIL Images if needed
52
  if not isinstance(first_image, Image.Image):
53
  first_image = Image.fromarray(first_image)
 
100
  Generate smooth video transitions between two keyframe images using AI.
101
 
102
  **Instructions:**
103
+ 1. Upload start and end frame images
104
+ 2. Describe the motion/transition in the text prompt
105
+ 3. Adjust parameters and generate!
 
106
  """
107
  )
108
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  with gr.Row():
110
  with gr.Column():
111
  gr.Markdown("### 🖼️ Input Keyframes")
 
181
  )
182
 
183
  # Event handlers
 
 
184
  generate_btn.click(
185
  fn=generate_interpolation,
186
  inputs=[
 
197
  )
198
 
199
  if __name__ == "__main__":
 
 
 
 
 
 
 
 
 
 
 
 
200
  demo.launch()