# LHMPP / core/utils/model_download_utils.py
# Lingteng Qiu (邱陵腾)
# rm assets & wheels
# 434b0b0
# -*- coding: utf-8 -*-
# @Organization : Tongyi Lab, Alibaba
# @Author : Lingteng Qiu
# @Email : 220019047@link.cuhk.edu.cn
# @Time : 2025-03-20 14:38:28
# @Function : LHM++ auto-download: HuggingFace Hub and models
import os
import subprocess
import sys
from typing import Dict
sys.path.append("./")
from core.utils.model_card import HuggingFace_MODEL_CARD, HuggingFace_Prior_MODEL_CARD
# --- Hugging Face Hub Import (auto-install if missing) ---
# Resolve `snapshot_download` once at import time. If the package is missing,
# try a one-off `pip install` into the running interpreter and re-import it.
# On any failure `hf_snapshot` stays None; the AutoModelQuery query methods
# check for that and raise ImportError when called, instead of this module
# failing to import.
package_name = "huggingface_hub"
hf_snapshot = None
try:
    from huggingface_hub import snapshot_download as hf_snapshot_import

    hf_snapshot = hf_snapshot_import
    print(f"{package_name} imported successfully.")
except ImportError:
    print(f"{package_name} is not installed. Attempting to install...")
    try:
        subprocess.check_call([sys.executable, "-m", "pip", "install", package_name])
        print(f"{package_name} has been installed.")
        # The import system caches directory listings of sys.path entries.
        # A package installed while this process is running may not be found
        # until those caches are invalidated.
        import importlib

        importlib.invalidate_caches()
        from huggingface_hub import snapshot_download as hf_snapshot_import

        hf_snapshot = hf_snapshot_import
    except Exception as e:
        print(f"Failed to install or import {package_name}: {e}")
except Exception as e:
    print(f"An unexpected error occurred during {package_name} import: {e}")
def _is_valid_model_dir(path: str) -> bool:
    """Return True if ``path`` holds LHM++ / LHM model files (config + weights).

    A directory qualifies only when it contains BOTH:

    * a config file: ``config.json`` or ``configuration.json``; and
    * weights: ``pytorch_model.bin``, ``model.safetensors``, or any other
      ``*.safetensors`` file (LHM++ may ship sharded safetensors).

    Args:
        path: Candidate directory. Empty/None and non-directory paths are
            rejected.

    Returns:
        True when a config file and at least one weight file are present.
    """
    if not path or not os.path.isdir(path):
        return False
    config_files = ("config.json", "configuration.json")
    weight_files = ("pytorch_model.bin", "model.safetensors")
    # isfile (not exists): a *directory* called e.g. config.json must not count.
    has_config = any(os.path.isfile(os.path.join(path, f)) for f in config_files)
    if not has_config:
        return False
    has_weights = any(os.path.isfile(os.path.join(path, f)) for f in weight_files)
    if not has_weights:
        # LHM++ may use sharded safetensors
        has_weights = any(
            name.endswith(".safetensors") and os.path.isfile(os.path.join(path, name))
            for name in os.listdir(path)
        )
    # Both are required: a config-only or weights-only directory cannot be loaded.
    return has_weights
def _get_max_step_folder(current_path: str):
    """Locate the checkpoint directory to load under ``current_path``.

    Two layouts are recognized:

    * LHM: ``current_path`` contains ``step_<N>`` sub-directories. The one
      with the largest ``N`` is returned. A suffix that does not parse as an
      int is treated as step 0.
    * LHM++: there are no ``step_*`` sub-directories. ``current_path``
      itself is returned if ``_is_valid_model_dir`` accepts it.

    Returns:
        The chosen directory path, or None when ``current_path`` is not a
        directory or holds no usable checkpoint.
    """
    if not os.path.isdir(current_path):
        return None

    def step_index(folder_name):
        # "step_1200" -> 1200; malformed names rank as step 0.
        try:
            return int(folder_name.split("_")[1])
        except (IndexError, ValueError):
            return 0

    candidates = []
    for entry in os.listdir(current_path):
        if entry.startswith("step_") and os.path.isdir(os.path.join(current_path, entry)):
            candidates.append(entry)

    if candidates:
        best = max(candidates, key=step_index)
        return os.path.join(current_path, best)

    # Flat LHM++ layout: the directory itself is the checkpoint.
    return current_path if _is_valid_model_dir(current_path) else None
class AutoModelQuery:
    """
    LHM++ auto-download: query model path from local cache or download from HuggingFace.

    Model names are mapped to HuggingFace repo ids through
    ``HuggingFace_MODEL_CARD`` (main models) and ``HuggingFace_Prior_MODEL_CARD``
    (prior bundles). Snapshots are cached under ``<save_dir>/huggingface``.
    Prior bundles are additionally symlinked item-by-item into ``<save_dir>``.
    """

    def __init__(self, save_dir: str = "./pretrained_models", hf_kwargs=None):
        """Create the cache directories.

        Args:
            save_dir: Root directory for downloaded models. A broken symlink
                at this path is removed and replaced by a real directory. If
                it cannot be removed, ``pretrained_models_local`` in the same
                parent directory is used instead.
            hf_kwargs: Accepted for backward compatibility; currently unused.
        """
        save_dir = os.path.abspath(save_dir)
        # If broken symlink (target missing), remove it and create real dir
        if os.path.lexists(save_dir) and not os.path.exists(save_dir):
            try:
                os.unlink(save_dir)
                print(f"Removed broken symlink, will create real dir: {save_dir}")
            except OSError:
                fallback = os.path.join(os.path.dirname(save_dir), "pretrained_models_local")
                print(f"Could not remove broken symlink {save_dir}; using {fallback}")
                save_dir = fallback
        self.base_save_dir = save_dir
        self.hf_save_dir = os.path.join(self.base_save_dir, "huggingface")
        os.makedirs(self.base_save_dir, exist_ok=True)
        os.makedirs(self.hf_save_dir, exist_ok=True)
        # Wraps a message in ANSI red escape codes; returns the string (does not print).
        self._logger = lambda x: "\033[31m{}\033[0m".format(x)

    def _ensure_trailing_slash(self, path: str) -> str:
        """Return ``path`` with a trailing "/" appended; empty paths are returned unchanged."""
        return path + "/" if path and path[-1] != "/" else path

    def query_huggingface_model(self, model_name: str, local_only: bool = False) -> str:
        """Return the local snapshot path of a main model, downloading if allowed.

        Args:
            model_name: Key into ``HuggingFace_MODEL_CARD``.
            local_only: If True, only the local cache is consulted.

        Raises:
            ImportError: huggingface_hub could not be imported.
            ValueError: ``model_name`` is not in the model card.
            FileNotFoundError: The snapshot is not available (cache miss with
                ``local_only``, or download failure).
            Exception: Any other error from ``snapshot_download`` is re-raised.
        """
        if hf_snapshot is None:
            print(self._logger("Hugging Face Hub library not available."))
            raise ImportError("huggingface_hub not imported")
        if model_name not in HuggingFace_MODEL_CARD:
            raise ValueError(f"Model '{model_name}' not found in HuggingFace_MODEL_CARD.")
        model_repo_id = HuggingFace_MODEL_CARD[model_name]
        action = "Checking cache for" if local_only else "Querying/Downloading"
        print(f"{action} Hugging Face model: {model_repo_id}")
        try:
            model_path = hf_snapshot(
                repo_id=model_repo_id,
                cache_dir=self.hf_save_dir,
                local_files_only=local_only,
            )
            print(
                f"Hugging Face model path {'found locally' if local_only else 'obtained'}: {model_path}"
            )
            return model_path
        except FileNotFoundError:
            if local_only:
                print(
                    f"Hugging Face model {model_repo_id} not found in local cache {self.hf_save_dir}."
                )
            else:
                print(self._logger(f"Cannot download {model_repo_id} from Hugging Face"))
            raise
        except Exception as exc:
            log_prefix = "Local check for" if local_only else "Download attempt for"
            print(
                self._logger(f"{log_prefix} Hugging Face model {model_repo_id} failed: {exc}")
            )
            raise

    def query_prior_huggingface_model(self, model_name: str, local_only: bool = False) -> str:
        """Return the local snapshot path of a prior bundle, downloading if allowed.

        Args:
            model_name: Key into ``HuggingFace_Prior_MODEL_CARD``.
            local_only: If True, only the local cache is consulted.

        Raises:
            ImportError: huggingface_hub could not be imported.
            ValueError: ``model_name`` is not in the prior model card.
            Exception: Errors from ``snapshot_download`` propagate unchanged.
        """
        if hf_snapshot is None:
            raise ImportError("huggingface_hub not imported")
        if model_name not in HuggingFace_Prior_MODEL_CARD:
            raise ValueError(f"Model '{model_name}' not in HuggingFace_Prior_MODEL_CARD")
        repo_id = HuggingFace_Prior_MODEL_CARD[model_name]
        print(f"{'Checking' if local_only else 'Downloading'} HF prior: {repo_id}")
        return hf_snapshot(
            repo_id=repo_id,
            cache_dir=self.hf_save_dir,
            local_files_only=local_only,
        )

    def _link_prior_bundle_to_base(self, prior_path: str) -> None:
        """
        Create symlinks in base_save_dir for each top-level item in the prior bundle.

        Entries whose destination already exists and resolves are kept as-is.
        Broken symlinks are removed and re-created. This enables flat paths like
        pretrained_models/human_model_files, pretrained_models/voxel_grid, etc.
        Absolute paths are used as symlink targets (Linux/HuggingFace Spaces).
        Failures are logged, not raised.
        """
        if not prior_path:
            print(self._logger("Prior path is empty; nothing to link."))
            return
        # abspath also normalizes (drops trailing slashes, collapses "..").
        prior_path = os.path.abspath(prior_path)
        if not os.path.isdir(prior_path):
            print(self._logger(f"Prior path is not a directory: {prior_path}"))
            return
        base = os.path.abspath(self.base_save_dir)
        os.makedirs(base, exist_ok=True)
        linked = []
        for name in os.listdir(prior_path):
            src = os.path.join(prior_path, name)
            dst = os.path.join(base, name)
            # Skip bundle entries that are themselves dangling symlinks.
            if not os.path.exists(src):
                continue
            if os.path.lexists(dst):
                if os.path.exists(dst):
                    linked.append(name)
                    continue
                # dst is a broken symlink: replace it.
                try:
                    os.unlink(dst)
                    print(f"Removed broken symlink: {name}")
                except OSError as e:
                    print(self._logger(f"Failed to remove broken symlink {name}: {e}"))
                    continue
            try:
                os.symlink(src, dst)
                print(f"Linked {name} -> {src}")
                linked.append(name)
            except OSError as e:
                print(self._logger(f"Failed to link {name}: {e}"))
        if linked:
            print(f"Prior symlinks in pretrained_models: {linked}")

    def query_prior(self, model_name: str) -> str:
        """
        Prior model query: skip if local exists, else download from HuggingFace.

        Returns the snapshot directory with a trailing "/". The bundle is also
        linked into ``base_save_dir``.

        Raises:
            ValueError: ``model_name`` is not in the prior model card.
            Exception: Any error from the download step propagates.
        """
        if model_name not in HuggingFace_Prior_MODEL_CARD:
            raise ValueError(f"Prior model '{model_name}' not in HuggingFace_Prior_MODEL_CARD")
        # 1. Check local HuggingFace cache. Any failure here (cache miss,
        #    missing library, ...) just means "try downloading".
        try:
            path = self.query_prior_huggingface_model(model_name, local_only=True)
            if path and os.path.isdir(path):
                print(f"Prior model exists locally (HF): {path}")
                self._link_prior_bundle_to_base(path)
                return self._ensure_trailing_slash(path)
        except Exception as e:
            print(f"Local HF cache check: {e}")
        # 2. Download from HuggingFace
        path = self.query_prior_huggingface_model(model_name, local_only=False)
        print(f"Prior model downloaded from HuggingFace: {path}")
        self._link_prior_bundle_to_base(path)
        return self._ensure_trailing_slash(path)

    def download_all_prior_models(self) -> Dict[str, str]:
        """
        Download all prior models to pretrained_models; skip if local exists.

        Models are processed in the model card's iteration order. This is
        deterministic, whereas iterating a set of names varies run to run.
        Failures are logged and the model is omitted from the result.

        Returns:
            {model_name: model_path} for every prior that was resolved.
        """
        result = {}
        for name in HuggingFace_Prior_MODEL_CARD:
            try:
                result[name] = self.query_prior(name)
            except Exception as e:
                print(self._logger(f"Prior model {name} failed: {e}"))
        return result

    def query(self, model_name: str) -> str:
        """
        Query model path: check local HuggingFace cache, else download from HuggingFace.

        Returns:
            The snapshot directory with a trailing "/".

        Raises:
            ValueError: ``model_name`` is not in the model card.
            ImportError: huggingface_hub is unavailable for the download step.
            FileNotFoundError: The model could be neither found nor downloaded.
        """
        print(f"\n--- Querying model: {model_name} ---")
        if model_name not in HuggingFace_MODEL_CARD:
            raise ValueError(f"Model name '{model_name}' not found in HuggingFace_MODEL_CARD.")
        # 1. Local HuggingFace cache
        try:
            print("Step 1: Checking local Hugging Face cache...")
            model_path = self.query_huggingface_model(model_name, local_only=True)
            if model_path:
                print(f"Success: Found in local Hugging Face cache: {model_path}")
                return self._ensure_trailing_slash(model_path)
        except FileNotFoundError:
            print("Info: Not found in local Hugging Face cache.")
        except ImportError:
            print(self._logger("Warning: Hugging Face library not available for local check."))
        except Exception as e:
            print(self._logger(f"Warning: Error checking local Hugging Face cache: {e}"))
        # 2. Download from HuggingFace
        print("Info: Model not found in local cache. Attempting download...")
        try:
            print("Step 2: Attempting download from Hugging Face...")
            model_path = self.query_huggingface_model(model_name, local_only=False)
            if model_path:
                print(f"Success: Downloaded from Hugging Face: {model_path}")
                return self._ensure_trailing_slash(model_path)
        except ImportError:
            print(self._logger("Warning: Hugging Face library not available, cannot download."))
            raise
        except Exception as e:
            error_msg = f"Failed to find or download model '{model_name}' from Hugging Face: {e}"
            print(self._logger(error_msg))
            raise FileNotFoundError(error_msg) from e
        # Reached only if snapshot_download returned an empty path.
        raise FileNotFoundError(f"Failed to find or download model '{model_name}'")
if __name__ == "__main__":
    # Manual smoke test:
    #   python -m core.utils.model_download_utils          -> resolve/download `test_models`
    #   python -m core.utils.model_download_utils --prior  -> download every prior bundle
    # Exits with status 1 if anything failed, so shell scripts / CI can detect it.
    test_save_dir = "./pretrained_models"
    print(f"Initializing AutoModelQuery with save_dir: {os.path.abspath(test_save_dir)}")
    automodel = AutoModelQuery(save_dir=test_save_dir)
    # Download all prior models: python -m core.utils.model_download_utils --prior
    if "--prior" in sys.argv:
        # A list (not a set) keeps the card's order stable across runs.
        prior_names = list(HuggingFace_Prior_MODEL_CARD)
        print(f"\n--- Downloading prior models: {prior_names} ---")
        result = automodel.download_all_prior_models()
        for name, path in result.items():
            print(f" {name}: {path}")
        # download_all_prior_models logs and skips failures; surface them here.
        missing = [name for name in prior_names if name not in result]
        if missing:
            print(f"Prior models failed: {missing}")
            sys.exit(1)
        print("Prior models download complete.")
        sys.exit(0)
    test_models = ["LHMPP-700M"]
    failed = []
    for model_to_test in test_models:
        print(f"\n--- Testing {model_to_test} ---")
        try:
            model_path_test = automodel.query(model_to_test)
            print(f"===> Final path for {model_to_test}: {model_path_test}")
            print(f" Does path exist? {os.path.exists(model_path_test)}")
        except Exception as e:
            print(f"Error: {e}")
            failed.append(model_to_test)
    if failed:
        sys.exit(1)