QingShuai commited on
Commit
a6c3d88
Β·
1 Parent(s): 4cadce0

update gradio

Browse files
Files changed (1) hide show
  1. gradio_app.py +20 -6
gradio_app.py CHANGED
@@ -11,6 +11,21 @@ from typing import List, Optional, Tuple, Union
11
 
12
  import gradio as gr
13
  import torch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  # Import spaces for Hugging Face Zero GPU support
16
  try:
@@ -60,7 +75,7 @@ def _init_runtime_if_needed():
60
  if "USE_HF_MODELS" not in os.environ:
61
  os.environ["USE_HF_MODELS"] = "1"
62
 
63
- skip_text = True
64
  _global_runtime = T2MRuntime(
65
  config_path=cfg,
66
  ckpt_name=ckpt,
@@ -801,12 +816,10 @@ class T2MGradioUI:
801
  )
802
 
803
 
804
- def create_demo():
805
  """Create the Gradio demo with Zero GPU support."""
806
  global _global_runtime, _global_args
807
 
808
- final_model_path = './configs/base'
809
-
810
  class Args:
811
  model_path = final_model_path
812
  output_dir = "output/gradio"
@@ -857,7 +870,7 @@ def create_demo():
857
  if "USE_HF_MODELS" not in os.environ:
858
  os.environ["USE_HF_MODELS"] = "1"
859
 
860
- skip_text = True
861
  runtime = T2MRuntime(
862
  config_path=cfg,
863
  ckpt_name=ckpt,
@@ -874,7 +887,8 @@ def create_demo():
874
 
875
 
876
  # Create demo at module level for Hugging Face Spaces
877
- demo = create_demo()
 
878
 
879
  if __name__ == "__main__":
880
  demo.launch()
 
11
 
12
  import gradio as gr
13
  import torch
14
+ from huggingface_hub import snapshot_download
15
+
16
+ def try_to_download_model():
17
+ repo_id = "tencent/HY-Motion-1.0"
18
+ target_folder = "HY-Motion-1.0-Lite"
19
+ print(f">>> start download ", repo_id, target_folder)
20
+ local_dir = snapshot_download(
21
+ repo_id=repo_id,
22
+ allow_patterns=f"{target_folder}/*",
23
+ local_dir="./downloaded_models"
24
+ )
25
+ final_model_path = os.path.join(local_dir, target_folder)
26
+ print(f">>> Final model path: {final_model_path}")
27
+ return final_model_path
28
+
29
 
30
  # Import spaces for Hugging Face Zero GPU support
31
  try:
 
75
  if "USE_HF_MODELS" not in os.environ:
76
  os.environ["USE_HF_MODELS"] = "1"
77
 
78
+ skip_text = False
79
  _global_runtime = T2MRuntime(
80
  config_path=cfg,
81
  ckpt_name=ckpt,
 
816
  )
817
 
818
 
819
+ def create_demo(final_model_path):
820
  """Create the Gradio demo with Zero GPU support."""
821
  global _global_runtime, _global_args
822
 
 
 
823
  class Args:
824
  model_path = final_model_path
825
  output_dir = "output/gradio"
 
870
  if "USE_HF_MODELS" not in os.environ:
871
  os.environ["USE_HF_MODELS"] = "1"
872
 
873
+ skip_text = False
874
  runtime = T2MRuntime(
875
  config_path=cfg,
876
  ckpt_name=ckpt,
 
887
 
888
 
889
  # Create demo at module level for Hugging Face Spaces
890
+ final_model_path = try_to_download_model()
891
+ demo = create_demo(final_model_path)
892
 
893
  if __name__ == "__main__":
894
  demo.launch()