| |
| |
| |
| |
| |
| |
|
|
| import os |
| import subprocess |
| import sys |
| from typing import Dict |
|
|
| sys.path.append("./") |
|
|
| from core.utils.model_card import HuggingFace_MODEL_CARD, HuggingFace_Prior_MODEL_CARD |
|
|
| |
# Resolve ``huggingface_hub.snapshot_download`` at import time, installing the
# package with pip on demand. ``hf_snapshot`` stays ``None`` when the library
# can be neither imported nor installed, so callers must check for ``None``.
package_name = "huggingface_hub"
hf_snapshot = None
try:
    from huggingface_hub import snapshot_download as hf_snapshot_import

    hf_snapshot = hf_snapshot_import
    print(f"{package_name} imported successfully.")
except ImportError:
    print(f"{package_name} is not installed. Attempting to install...")
    try:
        subprocess.check_call([sys.executable, "-m", "pip", "install", package_name])
        print(f"{package_name} has been installed.")

        # The package was written to disk after the interpreter started, so
        # the import system's cached directory listings may not include it
        # yet; refresh them before retrying the import.
        import importlib

        importlib.invalidate_caches()
        from huggingface_hub import snapshot_download as hf_snapshot_import

        hf_snapshot = hf_snapshot_import
    except Exception as e:
        # Best effort: installation or the second import failed; leave
        # hf_snapshot as None instead of aborting module import.
        print(f"Failed to install or import {package_name}: {e}")
except Exception as e:
    # Anything other than ImportError raised by the first import attempt.
    print(f"An unexpected error occurred during {package_name} import: {e}")
|
|
|
|
def _is_valid_model_dir(path: str) -> bool:
    """Return True if *path* looks like an LHM++ or LHM model directory.

    A directory qualifies as soon as ONE of these markers is present:

    - a config file: ``config.json`` or ``configuration.json``;
    - a weights file: ``pytorch_model.bin`` or ``model.safetensors``;
    - any regular file whose name ends with ``.safetensors``.

    An empty/``None`` path or a path that is not a directory yields False.
    """
    if not path or not os.path.isdir(path):
        return False

    config_files = ("config.json", "configuration.json")
    weight_files = ("pytorch_model.bin", "model.safetensors")

    # Either marker is sufficient, so stop at the first hit.
    if any(os.path.exists(os.path.join(path, f)) for f in config_files):
        return True
    if any(os.path.exists(os.path.join(path, f)) for f in weight_files):
        return True

    # Fall back to scanning for any *.safetensors regular file.
    return any(
        f.endswith(".safetensors")
        for f in os.listdir(path)
        if os.path.isfile(os.path.join(path, f))
    )
|
|
|
|
def _get_max_step_folder(current_path: str):
    """Find the best model checkpoint path under *current_path*.

    - For LHM: look for ``step_*`` sub-folders and return the one with the
      largest step number, joined onto *current_path*.
    - For LHM++: flat structure (config.json + model.safetensors) -> return
      *current_path* itself when ``_is_valid_model_dir`` accepts it.

    Returns ``None`` when *current_path* is empty, is not a directory, or
    holds neither ``step_*`` folders nor model files.
    """
    # Same empty-path guard as _is_valid_model_dir.
    if not current_path or not os.path.isdir(current_path):
        return None

    step_folders = [
        entry
        for entry in os.listdir(current_path)
        if entry.startswith("step_")
        and os.path.isdir(os.path.join(current_path, entry))
    ]

    if not step_folders:
        return current_path if _is_valid_model_dir(current_path) else None

    def _step_num(name):
        # Names were filtered on the "step_" prefix, so index 1 always
        # exists; a non-numeric suffix ranks as step 0.
        try:
            return int(name.split("_")[1])
        except ValueError:
            return 0

    return os.path.join(current_path, max(step_folders, key=_step_num))
|
|
|
|
class AutoModelQuery:
    """
    LHM++ auto-download: query model path from local cache or download from HuggingFace.

    Attributes:
        base_save_dir: Absolute directory that holds all downloaded assets.
        hf_save_dir: ``<base_save_dir>/huggingface``; passed as ``cache_dir``
            to the HuggingFace snapshot download.
    """

    def __init__(self, save_dir: str = "./pretrained_models", hf_kwargs=None):
        """Create the save directories, replacing a dangling symlink if needed.

        Args:
            save_dir: Root directory for downloaded models (made absolute).
            hf_kwargs: Accepted for interface compatibility; currently unused.
        """
        save_dir = os.path.abspath(save_dir)

        # lexists() is True but exists() is False only for a dangling symlink.
        if os.path.lexists(save_dir) and not os.path.exists(save_dir):
            try:
                os.unlink(save_dir)
                print(f"Removed broken symlink, will create real dir: {save_dir}")
            except OSError:
                # Cannot remove the link: fall back to a sibling directory.
                save_dir = os.path.join(
                    os.path.dirname(save_dir), "pretrained_models_local"
                )
        self.base_save_dir = save_dir
        self.hf_save_dir = os.path.join(self.base_save_dir, "huggingface")

        os.makedirs(self.base_save_dir, exist_ok=True)
        os.makedirs(self.hf_save_dir, exist_ok=True)

        # Despite the name this only formats: it wraps a message in ANSI red.
        self._logger = lambda x: "\033[31m{}\033[0m".format(x)

    def _ensure_trailing_slash(self, path: str) -> str:
        """Return *path* with a single trailing "/" appended when missing."""
        return path + "/" if path and path[-1] != "/" else path

    def query_huggingface_model(self, model_name: str, local_only: bool = False) -> str:
        """Query HuggingFace model, optionally local cache only.

        Args:
            model_name: Key into ``HuggingFace_MODEL_CARD``.
            local_only: Forwarded as ``local_files_only`` to the snapshot call.

        Returns:
            The path returned by the snapshot download.

        Raises:
            ImportError: ``huggingface_hub`` could not be imported.
            ValueError: *model_name* is not in ``HuggingFace_MODEL_CARD``.
            Exception: Whatever the snapshot call raises, re-raised after
                being reported.
        """
        if hf_snapshot is None:
            print(self._logger("Hugging Face Hub library not available."))
            raise ImportError("huggingface_hub not imported")

        if model_name not in HuggingFace_MODEL_CARD:
            raise ValueError(
                f"Model '{model_name}' not found in HuggingFace_MODEL_CARD."
            )

        model_repo_id = HuggingFace_MODEL_CARD[model_name]
        action = "Checking cache for" if local_only else "Querying/Downloading"
        print(f"{action} Hugging Face model: {model_repo_id}")

        try:
            model_path = hf_snapshot(
                repo_id=model_repo_id,
                cache_dir=self.hf_save_dir,
                local_files_only=local_only,
            )
            print(
                f"Hugging Face model path {'found locally' if local_only else 'obtained'}: {model_path}"
            )
            return model_path
        except FileNotFoundError:
            if local_only:
                print(
                    f"Hugging Face model {model_repo_id} not found in local cache {self.hf_save_dir}."
                )
            else:
                print(
                    self._logger(f"Cannot download {model_repo_id} from Hugging Face")
                )
            raise
        except Exception as exc:
            log_prefix = "Local check for" if local_only else "Download attempt for"
            print(
                self._logger(
                    f"{log_prefix} Hugging Face model {model_repo_id} failed: {exc}"
                )
            )
            # Bare raise re-raises the active exception with its traceback.
            raise

    def query_prior_huggingface_model(
        self, model_name: str, local_only: bool = False
    ) -> str:
        """Query HuggingFace prior model.

        Raises:
            ImportError: ``huggingface_hub`` could not be imported.
            ValueError: *model_name* is not in ``HuggingFace_Prior_MODEL_CARD``.
        """
        if hf_snapshot is None:
            raise ImportError("huggingface_hub not imported")
        if model_name not in HuggingFace_Prior_MODEL_CARD:
            raise ValueError(
                f"Model '{model_name}' not in HuggingFace_Prior_MODEL_CARD"
            )
        repo_id = HuggingFace_Prior_MODEL_CARD[model_name]
        print(f"{'Checking' if local_only else 'Downloading'} HF prior: {repo_id}")
        return hf_snapshot(
            repo_id=repo_id,
            cache_dir=self.hf_save_dir,
            local_files_only=local_only,
        )

    def _link_prior_bundle_to_base(self, prior_path: str) -> None:
        """
        Create symlinks in base_save_dir for each top-level item in the prior bundle.
        Skips items that already exist and have valid targets.
        Removes broken symlinks and re-creates them. Enables flat paths like
        pretrained_models/human_model_files, pretrained_models/voxel_grid, etc.
        Uses absolute paths for symlink targets (Linux/HuggingFace Spaces).

        Failures to remove or create an individual link are reported and
        skipped; they do not raise.
        """
        prior_path = os.path.abspath(os.path.normpath(prior_path.rstrip("/")))
        if not os.path.isdir(prior_path):
            print(self._logger(f"Prior path is not a directory: {prior_path}"))
            return

        base = os.path.abspath(self.base_save_dir)
        os.makedirs(base, exist_ok=True)

        linked = []
        for name in os.listdir(prior_path):
            src = os.path.join(prior_path, name)
            dst = os.path.join(base, name)

            # A bundle entry that is itself a dangling link cannot be linked.
            if not os.path.exists(src):
                continue

            if os.path.lexists(dst):
                if os.path.exists(dst):
                    # Something valid is already there: leave it untouched.
                    linked.append(name)
                    continue
                # dst is a dangling symlink: drop it so it can be re-created.
                try:
                    os.unlink(dst)
                    print(f"Removed broken symlink: {name}")
                except OSError as e:
                    print(self._logger(f"Failed to remove broken symlink {name}: {e}"))
                    continue

            try:
                src_abs = os.path.abspath(src)
                os.symlink(src_abs, dst)
                print(f"Linked {name} -> {src_abs}")
                linked.append(name)
            except OSError as e:
                print(self._logger(f"Failed to link {name}: {e}"))

        if linked:
            print(f"Prior symlinks in pretrained_models: {linked}")

    def query_prior(self, model_name: str) -> str:
        """
        Prior model query: skip if local exists, else download from HuggingFace.

        Returns the prior path with a trailing slash. On both the cached and
        the downloaded route the bundle's top-level items are symlinked into
        ``base_save_dir``.

        Raises:
            ValueError: *model_name* is not in ``HuggingFace_Prior_MODEL_CARD``.
        """
        if model_name not in HuggingFace_Prior_MODEL_CARD:
            raise ValueError(
                f"Prior model '{model_name}' not in HuggingFace_Prior_MODEL_CARD"
            )

        # First try the local cache; any failure falls through to a download.
        try:
            path = self.query_prior_huggingface_model(model_name, local_only=True)
            if path and os.path.isdir(path):
                print(f"Prior model exists locally (HF): {path}")
                self._link_prior_bundle_to_base(path)
                return self._ensure_trailing_slash(path)
        except Exception as e:
            print(f"Local HF cache check: {e}")

        # Not cached: download; errors from this call propagate to the caller.
        path = self.query_prior_huggingface_model(model_name, local_only=False)
        print(f"Prior model downloaded from HuggingFace: {path}")
        self._link_prior_bundle_to_base(path)
        return self._ensure_trailing_slash(path)

    def download_all_prior_models(self) -> Dict[str, str]:
        """
        Download all prior models to pretrained_models; skip if local exists.
        Returns {model_name: model_path} mapping.

        Models whose query fails are reported and omitted from the result.
        """
        result = {}
        # Iterate the card directly so models are processed in its own key
        # order rather than in arbitrary set order.
        for name in HuggingFace_Prior_MODEL_CARD:
            try:
                result[name] = self.query_prior(name)
            except Exception as e:
                print(self._logger(f"Prior model {name} failed: {e}"))
        return result

    def query(self, model_name: str) -> str:
        """
        Query model path: check local HuggingFace cache, else download from HuggingFace.

        Returns the model path with a trailing slash.

        Raises:
            ValueError: *model_name* is not in ``HuggingFace_MODEL_CARD``.
            ImportError: ``huggingface_hub`` is unavailable for the download.
            FileNotFoundError: The model was neither cached nor downloadable.
        """
        print(f"\n--- Querying model: {model_name} ---")

        if model_name not in HuggingFace_MODEL_CARD:
            raise ValueError(
                f"Model name '{model_name}' not found in HuggingFace_MODEL_CARD."
            )

        # Step 1: local cache. Every failure here is non-fatal.
        try:
            print("Step 1: Checking local Hugging Face cache...")
            model_path = self.query_huggingface_model(model_name, local_only=True)
            if model_path:
                print(f"Success: Found in local Hugging Face cache: {model_path}")
                return self._ensure_trailing_slash(model_path)
        except FileNotFoundError:
            print("Info: Not found in local Hugging Face cache.")
        except ImportError:
            print(
                self._logger(
                    "Warning: Hugging Face library not available for local check."
                )
            )
        except Exception as e:
            print(
                self._logger(f"Warning: Error checking local Hugging Face cache: {e}")
            )

        # Step 2: download. Failures are surfaced to the caller.
        print("Info: Model not found in local cache. Attempting download...")
        try:
            print("Step 2: Attempting download from Hugging Face...")
            model_path = self.query_huggingface_model(model_name, local_only=False)
            if model_path:
                print(f"Success: Downloaded from Hugging Face: {model_path}")
                return self._ensure_trailing_slash(model_path)
        except ImportError:
            print(
                self._logger(
                    "Warning: Hugging Face library not available, cannot download."
                )
            )
            raise
        except Exception as e:
            error_msg = f"Failed to find or download model '{model_name}' from Hugging Face: {e}"
            print(self._logger(error_msg))
            raise FileNotFoundError(error_msg) from e

        # Reached only when the download returned an empty path.
        raise FileNotFoundError(f"Failed to find or download model '{model_name}'")
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test / CLI entry point.
    #   --prior : fetch every prior model listed in HuggingFace_Prior_MODEL_CARD
    #   default : resolve each model named in test_models
    test_save_dir = "./pretrained_models"
    print(
        f"Initializing AutoModelQuery with save_dir: {os.path.abspath(test_save_dir)}"
    )
    automodel = AutoModelQuery(save_dir=test_save_dir)

    if "--prior" in sys.argv:
        prior_names = set(HuggingFace_Prior_MODEL_CARD)
        print(f"\n--- Downloading prior models: {prior_names} ---")
        result = automodel.download_all_prior_models()
        for name, path in result.items():
            print(f" {name}: {path}")
        print("Prior models download complete.")
        sys.exit(0)

    test_models = ["LHMPP-700M"]
    for model_to_test in test_models:
        print(f"\n--- Testing {model_to_test} ---")
        try:
            model_path_test = automodel.query(model_to_test)
            print(f"===> Final path for {model_to_test}: {model_path_test}")
            print(f" Does path exist? {os.path.exists(model_path_test)}")
        except Exception as e:
            # Best effort: report the failure and move on to the next model.
            print(f"Error: {e}")
|
|