bazaar-research's picture
Upload folder using huggingface_hub
fb11af9 verified
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/model_loader/loader.py
from abc import ABC
import torch
from transformers import AutoModel, AutoModelForCausalLM, AutoModelForVision2Seq, PreTrainedModel
from transformers.modeling_utils import no_init_weights
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
from ..utils import logging
from ..utils.import_utils import is_torch_npu_available, is_vescale_available
from .module_utils import init_empty_weights, load_model_weights
from .registry import get_registry
logger = logging.get_logger(__name__)
class BaseModelLoader(ABC):
def __init__(self):
pass
def load_model(self, model_config, **kwargs):
raise NotImplementedError
class HuggingfaceLoader(BaseModelLoader):
def __init__(self):
super().__init__()
def load_model(self, init_kwargs: dict, **kwargs):
model_config = init_kwargs["config"]
architecture = _get_model_arch_from_config(model_config)
if type(model_config) in AutoModelForVision2Seq._model_mapping.keys(): # assume built-in models
load_class = AutoModelForVision2Seq
elif "ForCausalLM" in architecture and type(model_config) in AutoModelForCausalLM._model_mapping.keys():
load_class = AutoModelForCausalLM
else:
load_class = AutoModel
init_device = kwargs.pop("init_device", "cuda")
weights_path = kwargs.pop("weights_path", None)
empty_init = kwargs.pop("empty_init", False)
logger.info_rank0(
f"Loading model from Huggingface modeling.\n"
f"init_device: {init_device}\n"
f"empty_init: {empty_init}\n"
f"weights_path: {weights_path}"
)
if weights_path is None: # init empty model from config
if is_torch_npu_available() and init_device == "cuda":
init_device = "npu"
if init_device == "meta":
with torch.device(init_device), no_init_weights():
logger.info_rank0("Init empty model on meta device from config without init_weights.")
model = load_class.from_config(**init_kwargs)
else:
with torch.device(init_device):
logger.info_rank0("Init empty model from config.")
model = load_class.from_config(**init_kwargs)
else:
if is_vescale_available() and init_device == "meta":
from vescale.initialize.meta_init import meta_device_init
with meta_device_init():
model = load_class.from_config(**init_kwargs)
else:
with init_empty_weights(), no_init_weights():
model = load_class.from_config(**init_kwargs)
if not empty_init:
load_model_weights(model, weights_path, init_device)
return model
class CustomizedModelingLoader(BaseModelLoader):
def __init__(self, model_cls: PreTrainedModel):
super().__init__()
self.model_cls = model_cls # model class from code_path
def load_model(self, init_kwargs: dict, **kwargs):
init_kwargs.pop("trust_remote_code", True)
init_device = kwargs.pop("init_device", "cuda")
weights_path = kwargs.pop("weights_path", None)
empty_init = kwargs.pop("empty_init", False)
vlm_repo_id = kwargs.pop("vlm_repo_id", None)
enable_expert_vision = kwargs.pop("enable_expert_vision", False)
expert_vision_path = kwargs.pop("expert_vision_path", None)
post_training = kwargs.pop("post_training", False)
adanorm_time = kwargs.pop("adanorm_time", False)
incremental_training = kwargs.pop("incremental_training", False)
depth_incremental_training = kwargs.pop("depth_incremental_training", False)
norm_qkv = kwargs.pop("norm_qkv", False)
logger.info_rank0(
f"Loading model from customized modeling.\n"
f"init_device: {init_device}\n"
f"empty_init: {empty_init}\n"
f"weights_path: {weights_path}"
)
if weights_path is None: # init empty model from config
if is_torch_npu_available() and init_device == "cuda":
init_device = "npu"
if init_device == "meta":
with torch.device(init_device), no_init_weights():
logger.info_rank0("Init empty model on meta device from config without init_weights.")
model = self.model_cls._from_config(**init_kwargs)
else:
with torch.device(init_device):
logger.info_rank0("Init empty model from config.")
model = self.model_cls._from_config(**init_kwargs)
else:
load_vlm_only = False
if is_vescale_available() and init_device == "meta":
from vescale.initialize.meta_init import meta_device_init
with meta_device_init():
model = self.model_cls._from_config(**init_kwargs)
else:
with init_empty_weights(), no_init_weights():
if (self.model_cls.__name__ == "PI0Policy" and
self.model_cls.__module__ == "lingbotvla.models.vla.pi0.modeling_pi0"):
model = self.model_cls(config=init_kwargs['config'], tokenizer_path=init_kwargs['config'].tokenizer_path).to(init_kwargs['torch_dtype'])
if vlm_repo_id is not None:
load_vlm_only = True
elif (self.model_cls.__name__ == "LingbotVlaPolicy" and
self.model_cls.__module__ == "lingbotvla.models.vla.pi0.modeling_lingbot_vla"):
model = self.model_cls(config=init_kwargs['config'], tokenizer_path=init_kwargs['config'].tokenizer_path).to(init_kwargs['torch_dtype'])
if vlm_repo_id is not None and incremental_training:
load_vlm_only = True
else:
model = self.model_cls._from_config(**init_kwargs)
if not empty_init:
load_model_weights(model, weights_path, init_device, load_vlm_only=load_vlm_only, enable_expert_vision=enable_expert_vision, expert_vision_path=expert_vision_path, post_training=post_training, incremental_training=incremental_training, depth_incremental_training=depth_incremental_training, norm_qkv=norm_qkv, adanorm_time=adanorm_time)
# we should tie embeddings after loading weights because init_empty_weights() leads to untied weights,
if getattr(model.config, "tie_word_embeddings", True):
try:
input_embeddings = model.get_input_embeddings()
output_embeddings = model.get_output_embeddings()
output_embeddings._parameters["weight"] = input_embeddings._parameters["weight"]
except Exception as e:
logger.info_rank0(f"Failed to tie embeddings: {e}")
return model
def _get_model_arch_from_config(model_config):
arch_name = model_config.architectures
if isinstance(arch_name, list):
arch_name = arch_name[0]
return arch_name
def get_loader(model_config, force_use_huggingface):
if isinstance(model_config, PI0Config):
if 'qwen' not in model_config.tokenizer_path.lower():
model_arch = 'PI0Policy'
elif 'qwen2' in model_config.tokenizer_path.lower():
model_arch = 'LingbotVlaPolicy'
else:
model_arch = _get_model_arch_from_config(model_config) # Qwen2VLForConditionalGeneration
loader = HuggingfaceLoader()
if not force_use_huggingface:
model_registry = get_registry()
if model_arch in model_registry.supported_models:
model_cls = model_registry.get_model_cls_from_model_arch(model_arch)
loader = CustomizedModelingLoader(model_cls=model_cls)
return loader