multimodalart HF Staff commited on
Commit
6e7e289
·
verified ·
1 Parent(s): 007bd3f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -11
app.py CHANGED
@@ -21,23 +21,29 @@ MODEL_DIR = os.path.abspath("ckpts")
21
  # Repositories
22
  HF_MAIN_REPO = "tencent/HunyuanVideo-1.5"
23
  HF_GLYPH_REPO = "multimodalart/glyph-sdxl-v2-byt5-small"
24
- HF_VISION_REPO = "black-forest-labs/FLUX.1-Redux-dev"
 
25
 
26
  # Configuration
27
  TRANSFORMER_VERSION = "480p_i2v_distilled"
28
  DTYPE = torch.bfloat16
 
29
  ENABLE_OFFLOADING = False
30
 
31
  def setup_environment():
32
  print("=" * 50)
33
  print("Checking Environment & Dependencies...")
34
 
 
35
  if not os.path.exists(REPO_DIR):
 
36
  subprocess.run(["git", "clone", REPO_URL, REPO_DIR], check=True)
37
 
 
38
  if REPO_DIR not in sys.path:
39
  sys.path.insert(0, REPO_DIR)
40
 
 
41
  os.makedirs(MODEL_DIR, exist_ok=True)
42
  target_transformer = os.path.join(MODEL_DIR, "transformer", TRANSFORMER_VERSION)
43
 
@@ -49,30 +55,52 @@ def setup_environment():
49
  f"transformer/{TRANSFORMER_VERSION}/*",
50
  "vae/*",
51
  "scheduler/*",
52
- "tokenizer/*",
53
- "text_encoder/*" # Download LLM here too to simplify
54
  ]
55
- snapshot_download(repo_id=HF_MAIN_REPO, local_dir=MODEL_DIR, allow_patterns=allow_patterns, local_dir_use_symlinks=False)
 
 
 
 
 
56
  except Exception as e:
57
  print(f"Error downloading main weights: {e}")
58
  sys.exit(1)
59
 
60
- # Vision Encoder
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  vision_target = os.path.join(MODEL_DIR, "vision_encoder", "siglip")
62
  if not os.path.exists(vision_target) or not os.listdir(vision_target):
63
  print(f"Downloading Vision Encoder from {HF_VISION_REPO}...")
64
  try:
65
  from huggingface_hub import snapshot_download
66
- snapshot_download(repo_id=HF_VISION_REPO, local_dir=vision_target, local_dir_use_symlinks=False)
 
 
 
 
67
  except Exception as e:
68
  print(f"Error downloading Vision Encoder: {e}")
69
 
70
- # Glyph Weights
71
  glyph_root = os.path.join(MODEL_DIR, "text_encoder", "Glyph-SDXL-v2")
72
  glyph_ckpt_target = os.path.join(glyph_root, "checkpoints", "byt5_model.pt")
73
 
74
  if not os.path.exists(glyph_ckpt_target):
75
- print(f"Downloading Glyph Weights from {HF_GLYPH_REPO}...")
76
  try:
77
  from huggingface_hub import snapshot_download
78
  glyph_temp = os.path.join(MODEL_DIR, "glyph_temp")
@@ -81,21 +109,25 @@ def setup_environment():
81
  os.makedirs(os.path.join(glyph_root, "assets"), exist_ok=True)
82
  os.makedirs(os.path.join(glyph_root, "checkpoints"), exist_ok=True)
83
 
 
84
  src_assets = os.path.join(glyph_temp, "assets")
85
  if os.path.exists(src_assets):
86
  for f in os.listdir(src_assets):
87
  shutil.copy(os.path.join(src_assets, f), os.path.join(glyph_root, "assets", f))
88
 
 
89
  src_bin = os.path.join(glyph_temp, "pytorch_model.bin")
90
  if os.path.exists(src_bin):
91
  shutil.move(src_bin, glyph_ckpt_target)
92
  else:
93
  src_safe = os.path.join(glyph_temp, "model.safetensors")
94
- if os.path.exists(src_safe): shutil.move(src_safe, glyph_ckpt_target)
 
95
 
96
  shutil.rmtree(glyph_temp, ignore_errors=True)
97
- except Exception:
98
- pass
 
99
 
100
  print("Environment Ready.")
101
  print("=" * 50)
 
21
  # Repositories
22
  HF_MAIN_REPO = "tencent/HunyuanVideo-1.5"
23
  HF_GLYPH_REPO = "multimodalart/glyph-sdxl-v2-byt5-small"
24
+ HF_LLM_REPO = "Qwen/Qwen2.5-VL-7B-Instruct"
25
+ HF_VISION_REPO = "black-forest-labs/FLUX.1-Redux-dev" # User specified
26
 
27
  # Configuration
28
  TRANSFORMER_VERSION = "480p_i2v_distilled"
29
  DTYPE = torch.bfloat16
30
+ # ZeroGPU: Set False so we control offloading manually (CPU -> GPU -> CPU)
31
  ENABLE_OFFLOADING = False
32
 
33
  def setup_environment():
34
  print("=" * 50)
35
  print("Checking Environment & Dependencies...")
36
 
37
+ # 1. Clone Code Repository
38
  if not os.path.exists(REPO_DIR):
39
+ print(f"Cloning repository to {REPO_DIR}...")
40
  subprocess.run(["git", "clone", REPO_URL, REPO_DIR], check=True)
41
 
42
+ # 2. Add Repo to Python Path
43
  if REPO_DIR not in sys.path:
44
  sys.path.insert(0, REPO_DIR)
45
 
46
+ # 3. Download Main Weights (Transformer, VAE, Scheduler)
47
  os.makedirs(MODEL_DIR, exist_ok=True)
48
  target_transformer = os.path.join(MODEL_DIR, "transformer", TRANSFORMER_VERSION)
49
 
 
55
  f"transformer/{TRANSFORMER_VERSION}/*",
56
  "vae/*",
57
  "scheduler/*",
58
+ "tokenizer/*"
 
59
  ]
60
+ snapshot_download(
61
+ repo_id=HF_MAIN_REPO,
62
+ local_dir=MODEL_DIR,
63
+ allow_patterns=allow_patterns,
64
+ local_dir_use_symlinks=False
65
+ )
66
  except Exception as e:
67
  print(f"Error downloading main weights: {e}")
68
  sys.exit(1)
69
 
70
+ # 4. Download LLM Text Encoder (Qwen)
71
+ llm_target = os.path.join(MODEL_DIR, "text_encoder", "llm")
72
+ if not os.path.exists(llm_target) or not os.listdir(llm_target):
73
+ print(f"Downloading LLM Text Encoder from {HF_LLM_REPO}...")
74
+ try:
75
+ from huggingface_hub import snapshot_download
76
+ snapshot_download(
77
+ repo_id=HF_LLM_REPO,
78
+ local_dir=llm_target,
79
+ local_dir_use_symlinks=False
80
+ )
81
+ except Exception as e:
82
+ print(f"Error downloading LLM: {e}")
83
+
84
+ # 5. Download Vision Encoder (SigLIP)
85
  vision_target = os.path.join(MODEL_DIR, "vision_encoder", "siglip")
86
  if not os.path.exists(vision_target) or not os.listdir(vision_target):
87
  print(f"Downloading Vision Encoder from {HF_VISION_REPO}...")
88
  try:
89
  from huggingface_hub import snapshot_download
90
+ snapshot_download(
91
+ repo_id=HF_VISION_REPO,
92
+ local_dir=vision_target,
93
+ local_dir_use_symlinks=False
94
+ )
95
  except Exception as e:
96
  print(f"Error downloading Vision Encoder: {e}")
97
 
98
+ # 6. Download & Restructure Glyph Weights
99
  glyph_root = os.path.join(MODEL_DIR, "text_encoder", "Glyph-SDXL-v2")
100
  glyph_ckpt_target = os.path.join(glyph_root, "checkpoints", "byt5_model.pt")
101
 
102
  if not os.path.exists(glyph_ckpt_target):
103
+ print(f"Downloading & Structuring Glyph Weights from {HF_GLYPH_REPO}...")
104
  try:
105
  from huggingface_hub import snapshot_download
106
  glyph_temp = os.path.join(MODEL_DIR, "glyph_temp")
 
109
  os.makedirs(os.path.join(glyph_root, "assets"), exist_ok=True)
110
  os.makedirs(os.path.join(glyph_root, "checkpoints"), exist_ok=True)
111
 
112
+ # Move Assets
113
  src_assets = os.path.join(glyph_temp, "assets")
114
  if os.path.exists(src_assets):
115
  for f in os.listdir(src_assets):
116
  shutil.copy(os.path.join(src_assets, f), os.path.join(glyph_root, "assets", f))
117
 
118
+ # Move Model
119
  src_bin = os.path.join(glyph_temp, "pytorch_model.bin")
120
  if os.path.exists(src_bin):
121
  shutil.move(src_bin, glyph_ckpt_target)
122
  else:
123
  src_safe = os.path.join(glyph_temp, "model.safetensors")
124
+ if os.path.exists(src_safe):
125
+ shutil.move(src_safe, glyph_ckpt_target)
126
 
127
  shutil.rmtree(glyph_temp, ignore_errors=True)
128
+
129
+ except Exception as e:
130
+ print(f"Error setting up Glyph weights: {e}")
131
 
132
  print("Environment Ready.")
133
  print("=" * 50)