# Provenance: file uploaded to a HuggingFace Hub repo ("Upload 169 files", commit b74998d).
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
"""
Helper functions for HuggingFace integration and model initialization.
"""
import json
import os
def load_hf_token():
    """Return a HuggingFace access token from the environment, or None.

    Checks HF_TOKEN, HUGGING_FACE_HUB_TOKEN, and HUGGING_FACE_MODEL_TOKEN
    (in that order) and returns the first non-empty value found.
    See https://huggingface.co/docs/hub/spaces-overview#managing-secrets
    for how Spaces expose secrets as environment variables.
    """
    for var_name in ("HF_TOKEN", "HUGGING_FACE_HUB_TOKEN", "HUGGING_FACE_MODEL_TOKEN"):
        token = os.getenv(var_name)
        if token:
            print("Loaded HuggingFace token from environment variable")
            return token
    # No token anywhere — private-repo downloads will likely fail downstream.
    print(
        "Warning: No HuggingFace token found. Model loading may fail for private repositories."
    )
    return None
def init_hydra_config(config_path, overrides=None):
    """Initialize Hydra and compose the config at ``config_path``.

    Args:
        config_path (str): Path to the Hydra YAML config file.
        overrides (list[str], optional): Hydra override strings
            (e.g. ``["model.foo=bar"]``). Defaults to no overrides.

    Returns:
        The composed config object (omegaconf DictConfig).
    """
    import hydra

    config_dir = os.path.dirname(config_path)
    # Strip only the final extension ("model.base.yaml" -> "model.base").
    # The previous split(".")[0] truncated dotted filenames to "model".
    config_name = os.path.splitext(os.path.basename(config_path))[0]
    # hydra.initialize requires the config dir relative to this source file
    relative_path = os.path.relpath(config_dir, os.path.dirname(__file__))
    # Clear any previously initialized global Hydra state so repeated
    # calls (e.g. when switching models) do not raise.
    hydra.core.global_hydra.GlobalHydra.instance().clear()
    hydra.initialize(version_base=None, config_path=relative_path)
    # compose expects a list for overrides, so normalize None to [].
    cfg = hydra.compose(
        config_name=config_name,
        overrides=overrides if overrides is not None else [],
    )
    return cfg
def initialize_mapanything_model(high_level_config, device):
    """
    Initialize MapAnything model with three-tier fallback approach:
    1. Try HuggingFace from_pretrained()
    2. Download HF config + use local model factory + load HF weights
    3. Pure local configuration fallback

    Args:
        high_level_config (dict): Configuration dictionary containing model settings.
            Keys read below: "path" (Hydra config path), "config_overrides",
            "hf_model_name" (HF repo id), "config_name" (config filename on the Hub),
            "checkpoint_name" (weights filename on the Hub), "model_str",
            and optionally "torch_hub_force_reload".
        device (torch.device): Device to load the model on
    Returns:
        torch.nn.Module: Initialized MapAnything model
    """
    import torch
    from huggingface_hub import hf_hub_download
    from mapanything.models import init_model, MapAnything
    print("Initializing MapAnything model...")
    # Initialize Hydra config and create model from configuration.
    # Composed up-front so the tier-2/3 fallbacks below can reuse it.
    cfg = init_hydra_config(
        high_level_config["path"], overrides=high_level_config["config_overrides"]
    )
    # Tier 1: try from_pretrained first (loads config + weights in one step)
    try:
        print("Loading MapAnything model from_pretrained...")
        model = MapAnything.from_pretrained(high_level_config["hf_model_name"]).to(
            device
        )
        print("Loading MapAnything model from_pretrained succeeded...")
        # Weights are already loaded; skip the manual checkpoint download below.
        return model
    except Exception as e:
        print(f"from_pretrained failed: {e}")
        print("Falling back to local configuration approach using hf_hub_download...")
    # Create model from local configuration instead of using from_pretrained
    # Tier 2: try to download and use the config from HuggingFace Hub
    try:
        print("Downloading model configuration from HuggingFace Hub...")
        config_path = hf_hub_download(
            repo_id=high_level_config["hf_model_name"],
            filename=high_level_config["config_name"],
            token=load_hf_token(),
        )
        # Load the config from the downloaded file
        with open(config_path, "r") as f:
            downloaded_config = json.load(f)
        print("Using downloaded configuration for model initialization")
        # Prefer values from the downloaded config; fall back to the
        # high-level config / composed Hydra config for missing keys.
        model = init_model(
            model_str=downloaded_config.get(
                "model_str", high_level_config["model_str"]
            ),
            model_config=downloaded_config.get(
                "model_config", cfg.model.model_config
            ),
            torch_hub_force_reload=high_level_config.get(
                "torch_hub_force_reload", False
            ),
        )
    except Exception as config_e:
        print(f"Failed to download/use HuggingFace config: {config_e}")
        print("Falling back to local configuration...")
        # Tier 3: fall back to the purely local Hydra configuration
        model = init_model(
            model_str=cfg.model.model_str,
            model_config=cfg.model.model_config,
            torch_hub_force_reload=high_level_config.get(
                "torch_hub_force_reload", False
            ),
        )
    # Load the pretrained weights from HuggingFace Hub.
    # Tiers 2/3 built an uninitialized model, so weights come separately.
    try:
        # First, let's see what files are available in the repository
        try:
            checkpoint_filename = high_level_config["checkpoint_name"]
            # Download the model weights
            checkpoint_path = hf_hub_download(
                repo_id=high_level_config["hf_model_name"],
                filename=checkpoint_filename,
                token=load_hf_token(),
            )
            # Load the weights
            print("start loading checkpoint")
            if checkpoint_filename.endswith(".safetensors"):
                from safetensors.torch import load_file
                checkpoint = load_file(checkpoint_path)
            else:
                # NOTE(review): weights_only=False unpickles arbitrary objects —
                # only safe because the checkpoint comes from a trusted repo.
                checkpoint = torch.load(
                    checkpoint_path, map_location="cpu", weights_only=False
                )
            print("start loading state_dict")
            # Checkpoints may nest weights under "model" or "state_dict";
            # strict=False tolerates key mismatches from older checkpoints.
            if "model" in checkpoint:
                model.load_state_dict(checkpoint["model"], strict=False)
            elif "state_dict" in checkpoint:
                model.load_state_dict(checkpoint["state_dict"], strict=False)
            else:
                model.load_state_dict(checkpoint, strict=False)
            print(
                f"Successfully loaded pretrained weights from HuggingFace Hub ({checkpoint_filename})"
            )
        except Exception as inner_e:
            # Log the detailed cause, then let the outer handler decide.
            print(f"Error listing repository files or loading weights: {inner_e}")
            raise inner_e
    except Exception as e:
        # Best-effort: continue with random init rather than failing hard.
        print(f"Warning: Could not load pretrained weights: {e}")
        print("Proceeding with randomly initialized model...")
    model = model.to(device)
    return model
def initialize_mapanything_local(local_config, device):
    """Build a MapAnything model entirely from on-disk resources.

    Args:
        local_config (dict):
            - path (str): Path to the Hydra config (for example ``configs/train.yaml``).
            - checkpoint_path (str): Local path to the pretrained checkpoint.
            - config_overrides (list[str], optional): Hydra override strings.
            - config_json_path (str, optional): JSON file with ``model_str``/``model_config`` overrides.
            - model_str (str, optional): Model alias fallback (defaults to the Hydra config value).
            - torch_hub_force_reload (bool, optional): Forwarded to ``init_model``.
            - strict (bool, optional): ``load_state_dict`` strict flag; defaults to
              False so older checkpoints remain compatible.
        device (torch.device | str): Target device that will host the model.

    Returns:
        torch.nn.Module: Model moved to ``device`` and switched to ``eval()``.

    Raises:
        ValueError: When required keys are missing from ``local_config``.
        FileNotFoundError: When the JSON config or checkpoint cannot be found.
    """
    # Validate required keys before touching any heavyweight imports.
    if not ("path" in local_config and "checkpoint_path" in local_config):
        raise ValueError("local_config must provide both 'path' and 'checkpoint_path'")
    import torch

    from mapanything.models import init_model

    cfg = init_hydra_config(
        local_config["path"], overrides=local_config.get("config_overrides")
    )

    # An optional side-car JSON may override model_str / model_config.
    json_cfg = None
    json_path = local_config.get("config_json_path")
    if json_path:
        if not os.path.exists(json_path):
            raise FileNotFoundError(f"Config JSON not found: {json_path}")
        with open(json_path, "r") as f:
            json_cfg = json.load(f)

    # Resolution order: JSON side-car -> local_config -> composed Hydra config.
    chosen_str = json_cfg.get("model_str") if json_cfg else None
    chosen_cfg = json_cfg.get("model_config") if json_cfg else None
    if chosen_str is None:
        chosen_str = local_config.get("model_str", cfg.model.model_str)
    if chosen_cfg is None:
        chosen_cfg = local_config.get("model_config", cfg.model.model_config)

    model = init_model(
        model_str=chosen_str,
        model_config=chosen_cfg,
        torch_hub_force_reload=local_config.get("torch_hub_force_reload", False),
    )

    ckpt_path = local_config["checkpoint_path"]
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
    if ckpt_path.endswith(".safetensors"):
        from safetensors.torch import load_file as load_safetensors

        checkpoint = load_safetensors(ckpt_path)
    else:
        checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)

    def _unwrap(ckpt):
        # Checkpoints may nest the weights under "model" or "state_dict".
        if isinstance(ckpt, dict):
            for key in ("model", "state_dict"):
                if key in ckpt:
                    return ckpt[key]
        return ckpt

    model.load_state_dict(_unwrap(checkpoint), strict=local_config.get("strict", False))
    return model.to(device).eval()