""" |
|
|
Helper functions for HuggingFace integration and model initialization. |
|
|
""" |
|
|
|
|
|
import json |
|
|
import os |
|
|
|
|
|
|
|
|
def load_hf_token():
    """Load a HuggingFace access token from a local file or environment variables."""
    token_file_paths = [
        "/home/aknapitsch/hf_token.txt",
    ]

    # First, look for a token file on disk.
    for token_path in token_file_paths:
        if os.path.exists(token_path):
            try:
                with open(token_path, "r") as f:
                    token = f.read().strip()
                print(f"Loaded HuggingFace token from: {token_path}")
                return token
            except Exception as e:
                print(f"Error reading token from {token_path}: {e}")
                continue
        else:
            print(f"Token file not found: {token_path}")

    # Fall back to the common environment variables.
    token = (
        os.getenv("HF_TOKEN")
        or os.getenv("HUGGING_FACE_HUB_TOKEN")
        or os.getenv("HUGGING_FACE_MODEL_TOKEN")
    )
    if token:
        print("Loaded HuggingFace token from environment variable")
        return token

    print(
        "Warning: No HuggingFace token found. Model loading may fail for private repositories."
    )
    return None
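
# Usage sketch (hypothetical repo id): the token returned by load_hf_token() can
# be passed straight to huggingface_hub helpers, e.g.
#
#   from huggingface_hub import hf_hub_download
#   config_path = hf_hub_download(
#       repo_id="some-org/some-model",  # hypothetical private repository
#       filename="config.json",
#       token=load_hf_token(),
#   )
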
def init_hydra_config(config_path, overrides=None):
    """Initialize Hydra and compose the config at ``config_path``, with optional overrides."""
    import hydra

    # Hydra expects a config directory relative to this file and a config name
    # without the file extension.
    config_dir = os.path.dirname(config_path)
    config_name = os.path.basename(config_path).split(".")[0]
    relative_path = os.path.relpath(config_dir, os.path.dirname(__file__))
    # Clear any previously initialized global Hydra state before re-initializing.
    hydra.core.global_hydra.GlobalHydra.instance().clear()
    hydra.initialize(version_base=None, config_path=relative_path)
    if overrides is not None:
        cfg = hydra.compose(config_name=config_name, overrides=overrides)
    else:
        cfg = hydra.compose(config_name=config_name)
    return cfg
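
# Example (hypothetical path and override key): compose a config and override a
# value, as in
#
#   cfg = init_hydra_config(
#       "configs/model.yaml",         # hypothetical config file
#       overrides=["model.lr=1e-4"],  # hypothetical override
#   )
#   print(cfg.model)
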
def initialize_mapanything_model(high_level_config, device):
    """
    Initialize MapAnything model with a three-tier fallback approach:
    1. Try HuggingFace from_pretrained()
    2. Download the HF config, build the model with the local factory, and load HF weights
    3. Fall back to the purely local configuration

    Args:
        high_level_config (dict): Configuration dictionary containing model settings
        device (torch.device): Device to load the model on

    Returns:
        torch.nn.Module: Initialized MapAnything model
    """
    import torch
    from huggingface_hub import hf_hub_download

    from mapanything.models import init_model, MapAnything

    print("Initializing MapAnything model...")

    cfg = init_hydra_config(
        high_level_config["path"], overrides=high_level_config["config_overrides"]
    )

    # Tier 1: load the full model directly from the HuggingFace Hub.
    try:
        print("Loading MapAnything model via from_pretrained...")
        model = MapAnything.from_pretrained(high_level_config["hf_model_name"]).to(
            device
        )
        print("Loading MapAnything model via from_pretrained succeeded")
        return model
    except Exception as e:
        print(f"from_pretrained failed: {e}")
        print("Falling back to local configuration approach using hf_hub_download...")

    # Tier 2: download the model configuration from the Hub and build the model
    # with the local factory.
    try:
        print("Downloading model configuration from HuggingFace Hub...")
        config_path = hf_hub_download(
            repo_id=high_level_config["hf_model_name"],
            filename=high_level_config["config_name"],
            token=load_hf_token(),
        )

        with open(config_path, "r") as f:
            downloaded_config = json.load(f)

        print("Using downloaded configuration for model initialization")
        model = init_model(
            model_str=downloaded_config.get(
                "model_str", high_level_config["model_str"]
            ),
            model_config=downloaded_config.get(
                "model_config", cfg.model.model_config
            ),
            torch_hub_force_reload=high_level_config.get(
                "torch_hub_force_reload", False
            ),
        )
    except Exception as config_e:
        # Tier 3: fall back to the purely local Hydra configuration.
        print(f"Failed to download/use HuggingFace config: {config_e}")
        print("Falling back to local configuration...")

        model = init_model(
            model_str=cfg.model.model_str,
            model_config=cfg.model.model_config,
            torch_hub_force_reload=high_level_config.get(
                "torch_hub_force_reload", False
            ),
        )

    # For the locally built model (tiers 2 and 3), fetch and load pretrained
    # weights from the Hub.
    try:
        checkpoint_filename = high_level_config["checkpoint_name"]
        checkpoint_path = hf_hub_download(
            repo_id=high_level_config["hf_model_name"],
            filename=checkpoint_filename,
            token=load_hf_token(),
        )

        print("Loading checkpoint...")
        if checkpoint_filename.endswith(".safetensors"):
            from safetensors.torch import load_file

            checkpoint = load_file(checkpoint_path)
        else:
            checkpoint = torch.load(
                checkpoint_path, map_location="cpu", weights_only=False
            )

        print("Loading state dict...")
        # Checkpoints may nest the weights under a "model" or "state_dict" key.
        if "model" in checkpoint:
            model.load_state_dict(checkpoint["model"], strict=False)
        elif "state_dict" in checkpoint:
            model.load_state_dict(checkpoint["state_dict"], strict=False)
        else:
            model.load_state_dict(checkpoint, strict=False)

        print(
            f"Successfully loaded pretrained weights from HuggingFace Hub ({checkpoint_filename})"
        )
    except Exception as e:
        print(f"Warning: Could not load pretrained weights: {e}")
        print("Proceeding with randomly initialized model...")

    model = model.to(device)
    return model
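
# Minimal usage sketch (hypothetical config values; the real keys and repo id
# come from the caller's high-level config):
#
#   import torch
#
#   high_level_config = {
#       "path": "configs/model.yaml",             # hypothetical Hydra config path
#       "config_overrides": [],                   # e.g. ["model.foo=bar"]
#       "hf_model_name": "some-org/mapanything",  # hypothetical repo id
#       "config_name": "config.json",
#       "checkpoint_name": "model.safetensors",
#       "model_str": "mapanything",               # hypothetical model string
#   }
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#   model = initialize_mapanything_model(high_level_config, device)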