multimodalart HF Staff commited on
Commit
3483f07
·
verified ·
1 Parent(s): 4809c06

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -25
app.py CHANGED
@@ -18,6 +18,8 @@ MODEL_DIR = os.path.abspath("ckpts")
18
  # Repositories
19
  HF_MAIN_REPO = "tencent/HunyuanVideo-1.5"
20
  HF_GLYPH_REPO = "multimodalart/glyph-sdxl-v2-byt5-small"
 
 
21
 
22
  # Configuration
23
  TRANSFORMER_VERSION = "480p_i2v_distilled"
@@ -38,7 +40,7 @@ def setup_environment():
38
  if REPO_DIR not in sys.path:
39
  sys.path.insert(0, REPO_DIR)
40
 
41
- # 3. Download Main Weights
42
  os.makedirs(MODEL_DIR, exist_ok=True)
43
  target_transformer = os.path.join(MODEL_DIR, "transformer", TRANSFORMER_VERSION)
44
 
@@ -49,8 +51,6 @@ def setup_environment():
49
  allow_patterns = [
50
  f"transformer/{TRANSFORMER_VERSION}/*",
51
  "vae/*",
52
- "text_encoder/*",
53
- "vision_encoder/*",
54
  "scheduler/*",
55
  "tokenizer/*"
56
  ]
@@ -64,8 +64,35 @@ def setup_environment():
64
  print(f"Error downloading main weights: {e}")
65
  sys.exit(1)
66
 
67
- # 4. Download & Restructure Glyph Weights
68
- # The pipeline expects: ckpts/text_encoder/Glyph-SDXL-v2/checkpoints/byt5_model.pt
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  glyph_root = os.path.join(MODEL_DIR, "text_encoder", "Glyph-SDXL-v2")
70
  glyph_ckpt_target = os.path.join(glyph_root, "checkpoints", "byt5_model.pt")
71
 
@@ -73,48 +100,31 @@ def setup_environment():
73
  print(f"Downloading & Structuring Glyph Weights from {HF_GLYPH_REPO}...")
74
  try:
75
  from huggingface_hub import snapshot_download
76
- # Download to a temp folder first
77
  glyph_temp = os.path.join(MODEL_DIR, "glyph_temp")
78
- snapshot_download(
79
- repo_id=HF_GLYPH_REPO,
80
- local_dir=glyph_temp,
81
- local_dir_use_symlinks=False
82
- )
83
 
84
- # Create target structure
85
  os.makedirs(os.path.join(glyph_root, "assets"), exist_ok=True)
86
  os.makedirs(os.path.join(glyph_root, "checkpoints"), exist_ok=True)
87
 
88
- # Move Assets (color_idx.json, etc.)
89
  src_assets = os.path.join(glyph_temp, "assets")
90
  if os.path.exists(src_assets):
91
  for f in os.listdir(src_assets):
92
  shutil.copy(os.path.join(src_assets, f), os.path.join(glyph_root, "assets", f))
93
 
94
- # Move & Rename Model (pytorch_model.bin -> byt5_model.pt)
95
- # Try bin first, then safetensors (code usually loads via torch.load, so bin/pt is safer)
96
  src_bin = os.path.join(glyph_temp, "pytorch_model.bin")
97
  if os.path.exists(src_bin):
98
- print(" moving pytorch_model.bin -> byt5_model.pt")
99
  shutil.move(src_bin, glyph_ckpt_target)
100
  else:
101
- # Fallback if repo changes structure
102
- print("Warning: pytorch_model.bin not found, looking for safetensors...")
103
  src_safe = os.path.join(glyph_temp, "model.safetensors")
104
  if os.path.exists(src_safe):
105
- # Note: Standard torch.load might fail on safetensors if code expects pickle,
106
- # but let's try.
107
  shutil.move(src_safe, glyph_ckpt_target)
108
 
109
- # Clean up temp
110
  shutil.rmtree(glyph_temp, ignore_errors=True)
111
- print("Glyph setup complete.")
112
 
113
  except Exception as e:
114
  print(f"Error setting up Glyph weights: {e}")
115
- # Don't exit, maybe the model can run without it if config tweaked,
116
- # but likely it will fail later.
117
- pass
118
 
119
  print("Environment Ready.")
120
  print("=" * 50)
 
18
  # Repositories
19
  HF_MAIN_REPO = "tencent/HunyuanVideo-1.5"
20
  HF_GLYPH_REPO = "multimodalart/glyph-sdxl-v2-byt5-small"
21
+ HF_LLM_REPO = "Qwen/Qwen2.5-VL-7B-Instruct"
22
+ HF_VISION_REPO = "black-forest-labs/FLUX.1-Redux-dev" # User specified
23
 
24
  # Configuration
25
  TRANSFORMER_VERSION = "480p_i2v_distilled"
 
40
  if REPO_DIR not in sys.path:
41
  sys.path.insert(0, REPO_DIR)
42
 
43
+ # 3. Download Main Weights (Transformer, VAE, Scheduler)
44
  os.makedirs(MODEL_DIR, exist_ok=True)
45
  target_transformer = os.path.join(MODEL_DIR, "transformer", TRANSFORMER_VERSION)
46
 
 
51
  allow_patterns = [
52
  f"transformer/{TRANSFORMER_VERSION}/*",
53
  "vae/*",
 
 
54
  "scheduler/*",
55
  "tokenizer/*"
56
  ]
 
64
  print(f"Error downloading main weights: {e}")
65
  sys.exit(1)
66
 
67
+ # 4. Download LLM Text Encoder (Qwen)
68
+ llm_target = os.path.join(MODEL_DIR, "text_encoder", "llm")
69
+ if not os.path.exists(llm_target) or not os.listdir(llm_target):
70
+ print(f"Downloading LLM Text Encoder from {HF_LLM_REPO}...")
71
+ try:
72
+ from huggingface_hub import snapshot_download
73
+ snapshot_download(
74
+ repo_id=HF_LLM_REPO,
75
+ local_dir=llm_target,
76
+ local_dir_use_symlinks=False
77
+ )
78
+ except Exception as e:
79
+ print(f"Error downloading LLM: {e}")
80
+
81
+ # 5. Download Vision Encoder (SigLIP)
82
+ vision_target = os.path.join(MODEL_DIR, "vision_encoder", "siglip")
83
+ if not os.path.exists(vision_target) or not os.listdir(vision_target):
84
+ print(f"Downloading Vision Encoder from {HF_VISION_REPO}...")
85
+ try:
86
+ from huggingface_hub import snapshot_download
87
+ snapshot_download(
88
+ repo_id=HF_VISION_REPO,
89
+ local_dir=vision_target,
90
+ local_dir_use_symlinks=False
91
+ )
92
+ except Exception as e:
93
+ print(f"Error downloading Vision Encoder: {e}")
94
+
95
+ # 6. Download & Restructure Glyph Weights
96
  glyph_root = os.path.join(MODEL_DIR, "text_encoder", "Glyph-SDXL-v2")
97
  glyph_ckpt_target = os.path.join(glyph_root, "checkpoints", "byt5_model.pt")
98
 
 
100
  print(f"Downloading & Structuring Glyph Weights from {HF_GLYPH_REPO}...")
101
  try:
102
  from huggingface_hub import snapshot_download
 
103
  glyph_temp = os.path.join(MODEL_DIR, "glyph_temp")
104
+ snapshot_download(repo_id=HF_GLYPH_REPO, local_dir=glyph_temp, local_dir_use_symlinks=False)
 
 
 
 
105
 
 
106
  os.makedirs(os.path.join(glyph_root, "assets"), exist_ok=True)
107
  os.makedirs(os.path.join(glyph_root, "checkpoints"), exist_ok=True)
108
 
109
+ # Move Assets
110
  src_assets = os.path.join(glyph_temp, "assets")
111
  if os.path.exists(src_assets):
112
  for f in os.listdir(src_assets):
113
  shutil.copy(os.path.join(src_assets, f), os.path.join(glyph_root, "assets", f))
114
 
115
+ # Move Model
 
116
  src_bin = os.path.join(glyph_temp, "pytorch_model.bin")
117
  if os.path.exists(src_bin):
 
118
  shutil.move(src_bin, glyph_ckpt_target)
119
  else:
 
 
120
  src_safe = os.path.join(glyph_temp, "model.safetensors")
121
  if os.path.exists(src_safe):
 
 
122
  shutil.move(src_safe, glyph_ckpt_target)
123
 
 
124
  shutil.rmtree(glyph_temp, ignore_errors=True)
 
125
 
126
  except Exception as e:
127
  print(f"Error setting up Glyph weights: {e}")
 
 
 
128
 
129
  print("Environment Ready.")
130
  print("=" * 50)