Spaces:
Sleeping
Sleeping
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # | |
| # This source code is licensed under the Apache License, Version 2.0 | |
| # found in the LICENSE file in the root directory of this source tree. | |
| """ | |
| Helper functions for HuggingFace integration and model initialization. | |
| """ | |
| import json | |
| import os | |
def load_hf_token():
    """Return a HuggingFace access token read from environment variables.

    The variables ``HF_TOKEN``, ``HUGGING_FACE_HUB_TOKEN`` and
    ``HUGGING_FACE_MODEL_TOKEN`` are consulted in that order; the first one
    holding a non-blank value wins. Surrounding whitespace (for example a
    trailing newline from a pasted secret) is stripped, and a variable that
    contains only whitespace is treated as unset so it cannot shadow a valid
    token stored in a later variable.

    See https://huggingface.co/docs/hub/spaces-overview#managing-secrets for
    the ways such a secret can be provided.

    Returns:
        str | None: The token, or ``None`` when none of the variables is set
        (a warning is printed in that case).
    """
    candidate_vars = (
        "HF_TOKEN",
        "HUGGING_FACE_HUB_TOKEN",
        "HUGGING_FACE_MODEL_TOKEN",
    )
    for var_name in candidate_vars:
        token = os.getenv(var_name, "").strip()
        if token:
            print("Loaded HuggingFace token from environment variable")
            return token

    print(
        "Warning: No HuggingFace token found. Model loading may fail for private repositories."
    )
    return None
def init_hydra_config(config_path, overrides=None):
    """Compose a Hydra config from the path of a config file.

    Any previously initialized global Hydra instance is cleared first, so the
    function can be called repeatedly within one process.

    Args:
        config_path (str): Path to the Hydra config file. A bare file name
            (no directory component) is resolved against the current working
            directory.
        overrides (list[str], optional): Hydra override strings forwarded to
            ``hydra.compose``.

    Returns:
        The composed Hydra config object.

    Raises:
        ValueError: If ``config_path`` is empty.
    """
    # Validate before the (comparatively slow) hydra import so bad input fails fast.
    if not config_path:
        raise ValueError("config_path must be a non-empty path to a Hydra config file")

    import hydra

    # Work on an absolute path so a bare file name still yields a usable directory.
    config_dir, config_file = os.path.split(os.path.abspath(config_path))
    # splitext drops only the final extension, so "train.v2.yaml" -> "train.v2".
    config_name = os.path.splitext(config_file)[0]

    # hydra.initialize() is given the config directory relative to this module's
    # own directory.
    module_dir = os.path.dirname(os.path.abspath(__file__))
    relative_path = os.path.relpath(config_dir, module_dir)

    hydra.core.global_hydra.GlobalHydra.instance().clear()
    hydra.initialize(version_base=None, config_path=relative_path)

    if overrides is not None:
        return hydra.compose(config_name=config_name, overrides=overrides)
    return hydra.compose(config_name=config_name)
def initialize_mapanything_model(high_level_config, device):
    """Initialize a MapAnything model with a three-tier fallback.

    1. Try ``MapAnything.from_pretrained()`` on the HuggingFace repository.
    2. Otherwise download the JSON config from the Hub, build the model with
       the local ``init_model`` factory and load the Hub checkpoint.
    3. If the Hub config cannot be used, build the model purely from the local
       Hydra config and still try to load the Hub checkpoint.

    If the checkpoint cannot be loaded in tiers 2/3, a warning is printed and
    the model is returned with whatever weights ``init_model`` produced.

    Args:
        high_level_config (dict): Configuration dictionary. Keys read here:
            ``path``, ``config_overrides``, ``hf_model_name``, ``config_name``,
            ``checkpoint_name`` (all required) and ``model_str``,
            ``torch_hub_force_reload`` (optional).
        device (torch.device): Device to load the model on.

    Returns:
        torch.nn.Module: Initialized MapAnything model moved to ``device``.
    """
    import torch
    from huggingface_hub import hf_hub_download
    from mapanything.models import init_model, MapAnything

    print("Initializing MapAnything model...")

    # The Hydra config is composed up front because tiers 2 and 3 both fall
    # back to values from it.
    cfg = init_hydra_config(
        high_level_config["path"], overrides=high_level_config["config_overrides"]
    )

    # Tier 1: let the model class load itself from the Hub.
    try:
        print("Loading MapAnything model from_pretrained...")
        model = MapAnything.from_pretrained(high_level_config["hf_model_name"]).to(
            device
        )
        print("Loading MapAnything model from_pretrained succeeded...")
        return model
    except Exception as e:
        print(f"from_pretrained failed: {e}")
        print("Falling back to local configuration approach using hf_hub_download...")

    force_reload = high_level_config.get("torch_hub_force_reload", False)

    # Tier 2: build the model from the config published on the Hub.
    try:
        print("Downloading model configuration from HuggingFace Hub...")
        config_path = hf_hub_download(
            repo_id=high_level_config["hf_model_name"],
            filename=high_level_config["config_name"],
            token=load_hf_token(),
        )
        with open(config_path, "r", encoding="utf-8") as f:
            downloaded_config = json.load(f)
        print("Using downloaded configuration for model initialization")

        # Resolve the fallback lazily: a missing optional "model_str" key in
        # high_level_config must not raise and discard a usable Hub config.
        model_str = downloaded_config.get("model_str")
        if model_str is None:
            model_str = high_level_config.get("model_str", cfg.model.model_str)

        model = init_model(
            model_str=model_str,
            model_config=downloaded_config.get(
                "model_config", cfg.model.model_config
            ),
            torch_hub_force_reload=force_reload,
        )
    except Exception as config_e:
        # Tier 3: purely local configuration.
        print(f"Failed to download/use HuggingFace config: {config_e}")
        print("Falling back to local configuration...")
        model = init_model(
            model_str=cfg.model.model_str,
            model_config=cfg.model.model_config,
            torch_hub_force_reload=force_reload,
        )

    # Load the pretrained weights from the Hub (best effort).
    try:
        checkpoint_filename = high_level_config["checkpoint_name"]
        checkpoint_path = hf_hub_download(
            repo_id=high_level_config["hf_model_name"],
            filename=checkpoint_filename,
            token=load_hf_token(),
        )

        print("start loading checkpoint")
        if checkpoint_filename.endswith(".safetensors"):
            from safetensors.torch import load_file

            checkpoint = load_file(checkpoint_path)
        else:
            # NOTE(security): weights_only=False unpickles arbitrary Python
            # objects; only use this with repositories you trust.
            checkpoint = torch.load(
                checkpoint_path, map_location="cpu", weights_only=False
            )

        print("start loading state_dict")
        # Checkpoints may wrap the weights under "model" or "state_dict".
        state_dict = checkpoint
        for wrapper_key in ("model", "state_dict"):
            if wrapper_key in checkpoint:
                state_dict = checkpoint[wrapper_key]
                break

        load_result = model.load_state_dict(state_dict, strict=False)
        # strict=False hides key mismatches, so surface them explicitly.
        missing_keys = getattr(load_result, "missing_keys", None)
        unexpected_keys = getattr(load_result, "unexpected_keys", None)
        if missing_keys or unexpected_keys:
            print(
                f"Warning: state_dict mismatch - missing keys: {missing_keys}, "
                f"unexpected keys: {unexpected_keys}"
            )
        print(
            f"Successfully loaded pretrained weights from HuggingFace Hub ({checkpoint_filename})"
        )
    except Exception as e:
        print(f"Warning: Could not load pretrained weights: {e}")
        print("Proceeding with randomly initialized model...")

    model = model.to(device)
    return model
def initialize_mapanything_local(local_config, device):
    """Initialize a MapAnything model entirely from local resources.

    Args:
        local_config (dict):
            - path (str): Path to the Hydra config (for example ``configs/train.yaml``).
            - checkpoint_path (str | os.PathLike): Local path to the pretrained checkpoint.
            - config_overrides (list[str], optional): Hydra override strings.
            - config_json_path (str, optional): JSON file containing
              ``model_str``/``model_config`` overrides.
            - model_str (str, optional): Model alias used when the JSON does not
              provide one (defaults to the Hydra config value).
            - model_config (optional): Model config used when the JSON does not
              provide one (defaults to the Hydra config value).
            - torch_hub_force_reload (bool, optional): Forwarded to ``init_model``.
            - strict (bool, optional): ``load_state_dict`` strict flag, defaults
              to False so older checkpoints remain compatible.
        device (torch.device | str): Target device that will host the model.

    Returns:
        torch.nn.Module: MapAnything model moved to ``device`` and switched to ``eval()``.

    Raises:
        ValueError: If ``path`` or ``checkpoint_path`` is missing from ``local_config``.
        FileNotFoundError: If the JSON config or the checkpoint cannot be found.
    """
    if "path" not in local_config or "checkpoint_path" not in local_config:
        raise ValueError("local_config must provide both 'path' and 'checkpoint_path'")

    # Check both files before composing the Hydra config or building the
    # model, so a wrong path fails immediately instead of after that work.
    config_json_path = local_config.get("config_json_path")
    if config_json_path and not os.path.exists(config_json_path):
        raise FileNotFoundError(f"Config JSON not found: {config_json_path}")

    # os.fspath lets callers pass pathlib.Path; the suffix test below needs a str.
    checkpoint_path = os.fspath(local_config["checkpoint_path"])
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

    import torch
    from mapanything.models import init_model

    cfg = init_hydra_config(
        local_config["path"], overrides=local_config.get("config_overrides")
    )

    model_config_json = None
    if config_json_path:
        with open(config_json_path, "r", encoding="utf-8") as f:
            model_config_json = json.load(f)

    # Precedence: JSON file > local_config entry > Hydra config.
    model_str = None
    model_config = None
    if model_config_json:
        model_str = model_config_json.get("model_str")
        model_config = model_config_json.get("model_config")
    if model_str is None:
        model_str = local_config.get("model_str", cfg.model.model_str)
    if model_config is None:
        model_config = local_config.get("model_config", cfg.model.model_config)

    model = init_model(
        model_str=model_str,
        model_config=model_config,
        torch_hub_force_reload=local_config.get("torch_hub_force_reload", False),
    )

    if checkpoint_path.endswith(".safetensors"):
        from safetensors.torch import load_file as load_safetensors

        checkpoint = load_safetensors(checkpoint_path)
    else:
        # NOTE(security): weights_only=False unpickles arbitrary Python
        # objects; only load checkpoints from sources you trust.
        checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)

    # Checkpoints may wrap the weights under "model" or "state_dict".
    state_dict = checkpoint
    if isinstance(checkpoint, dict):
        for wrapper_key in ("model", "state_dict"):
            if wrapper_key in checkpoint:
                state_dict = checkpoint[wrapper_key]
                break

    strict = local_config.get("strict", False)
    load_result = model.load_state_dict(state_dict, strict=strict)
    # With strict=False key mismatches are ignored, so report them explicitly.
    missing_keys = getattr(load_result, "missing_keys", None)
    unexpected_keys = getattr(load_result, "unexpected_keys", None)
    if missing_keys or unexpected_keys:
        print(
            f"Warning: state_dict mismatch - missing keys: {missing_keys}, "
            f"unexpected keys: {unexpected_keys}"
        )

    model = model.to(device).eval()
    return model