AhmadMustafa commited on
Commit
3b51f73
·
1 Parent(s): c4279ee
Files changed (1) hide show
  1. app.py +13 -4
app.py CHANGED
@@ -17,13 +17,12 @@ pipe = None
17
  device = "cuda" if torch.cuda.is_available() else "cpu"
18
 
19
 
20
- @spaces.GPU(duration=1000)
21
- def load_pipeline(
22
  pretrained_model_path="THUDM/CogVideoX-5b",
23
  ef_net_path="weights/EF_Net.pth",
24
  dtype_str="bfloat16",
25
  ):
26
- """Load the Sci-Fi pipeline"""
27
  global pipe
28
 
29
  dtype = torch.float16 if dtype_str == "float16" else torch.bfloat16
@@ -77,6 +76,16 @@ def load_pipeline(
77
  return "Pipeline loaded successfully!"
78
 
79
 
 
 
 
 
 
 
 
 
 
 
80
  @spaces.GPU(duration=1000)
81
  def generate_inbetweening(
82
  first_image: Image.Image,
@@ -289,7 +298,7 @@ if __name__ == "__main__":
289
  # Automatically load pipeline on startup
290
  print("Loading pipeline automatically on startup...")
291
  try:
292
- load_pipeline()
293
  print("Pipeline loaded successfully!")
294
  except Exception as e:
295
  print(f"Failed to load pipeline on startup: {e}")
 
17
  device = "cuda" if torch.cuda.is_available() else "cpu"
18
 
19
 
20
+ def _load_pipeline_internal(
 
21
  pretrained_model_path="THUDM/CogVideoX-5b",
22
  ef_net_path="weights/EF_Net.pth",
23
  dtype_str="bfloat16",
24
  ):
25
+ """Internal function to load the Sci-Fi pipeline"""
26
  global pipe
27
 
28
  dtype = torch.float16 if dtype_str == "float16" else torch.bfloat16
 
76
  return "Pipeline loaded successfully!"
77
 
78
 
79
+ @spaces.GPU(duration=1000)
80
+ def load_pipeline(
81
+ pretrained_model_path="THUDM/CogVideoX-5b",
82
+ ef_net_path="weights/EF_Net.pth",
83
+ dtype_str="bfloat16",
84
+ ):
85
+ """Load the Sci-Fi pipeline (Gradio wrapper)"""
86
+ return _load_pipeline_internal(pretrained_model_path, ef_net_path, dtype_str)
87
+
88
+
89
  @spaces.GPU(duration=1000)
90
  def generate_inbetweening(
91
  first_image: Image.Image,
 
298
  # Automatically load pipeline on startup
299
  print("Loading pipeline automatically on startup...")
300
  try:
301
+ _load_pipeline_internal()
302
  print("Pipeline loaded successfully!")
303
  except Exception as e:
304
  print(f"Failed to load pipeline on startup: {e}")