Lekr0's picture
Add files using upload-large-folder tool
212a146 verified
raw
history blame
5.21 kB
import json
import os
from typing import Optional, Union
import torch
from transformers import AutoConfig
from transformers import AutoModelForCausalLM as AutoModelForCausalLMBase
from transformers import (
GptOssConfig,
Llama4Config,
Llama4TextConfig,
LlamaConfig,
Phi3Config,
PretrainedConfig,
Qwen2Config,
Qwen3Config,
Qwen3MoeConfig,
modeling_utils,
)
from .draft.llama3_eagle import LlamaForCausalLMEagle3
from .target.custom_backend import (
GptOssForCausalLM,
Llama4ForCausalLM,
LlamaForCausalLM,
Phi3ForCausalLM,
Qwen2ForCausalLM,
Qwen3ForCausalLM,
Qwen3MoeForCausalLM,
)
class AutoEagle3DraftModel(AutoModelForCausalLMBase):
# the model mapping is currently hardcoded, we should support lazy model mapping via registry
_model_mapping = {
LlamaConfig: LlamaForCausalLMEagle3,
}
@classmethod
def from_config(cls, config: PretrainedConfig, torch_dtype=None, **config_kwargs):
"""
This class method takes a configuration object and create its model based on the
_model_mapping class variable.
Args:
config (PretrainedConfig): A configuration object.
Returns:
A model instance.
"""
# get the model class from the
_model_cls = cls._model_mapping[type(config)]
model = _model_cls(config, **config_kwargs)
# Convert model to specified dtype if provided
if torch_dtype is not None:
model = model.to(dtype=torch_dtype)
return model
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: Union[str, os.PathLike[str]],
*model_args,
**kwargs,
):
original_warn = modeling_utils.logger.warning
def filtered_warning(msg):
if "embed_tokens.weight" in str(msg) and "initialized" in str(msg):
return
original_warn(msg)
modeling_utils.logger.warning = filtered_warning
try:
model = super().from_pretrained(
pretrained_model_name_or_path, *model_args, **kwargs
)
finally:
modeling_utils.logger.warning = original_warn
return model
class AutoDistributedTargetModel(AutoModelForCausalLMBase):
# the model mapping is currently hardcoded, we should support lazy model mapping via registry
_model_mapping = {
Llama4TextConfig: [Llama4ForCausalLM],
Qwen3MoeConfig: [Qwen3MoeForCausalLM],
Qwen2Config: [Qwen2ForCausalLM],
LlamaConfig: [LlamaForCausalLM],
Qwen3Config: [Qwen3ForCausalLM],
Phi3Config: [Phi3ForCausalLM],
GptOssConfig: [GptOssForCausalLM],
}
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: Union[str, os.PathLike[str]],
torch_dtype: torch.dtype = None,
device: str = None,
cache_dir: Optional[str] = None,
**config_kwargs,
):
config = AutoConfig.from_pretrained(
pretrained_model_name_or_path,
)
if isinstance(config, Llama4Config):
config = config.text_config
assert (
type(config) in cls._model_mapping
), f"Unsupported config type: {type(config)}"
model_cls = cls._model_mapping[type(config)][0]
model = model_cls.from_pretrained(
pretrained_model_name_or_path,
torch_dtype=torch_dtype,
cache_dir=cache_dir,
**config_kwargs,
)
if device is not None:
model = model.to(device)
else:
model = model.cuda()
return model
class AutoDraftModelConfig:
_config_mapping = {
"LlamaForCausalLMEagle3": LlamaConfig,
}
@classmethod
def from_file(cls, config_path: str):
"""
This class method takes a configuration file path and create its configuration object based on the
_config_mapping class variable.
Args:
config_path (str): A path to a configuration file.
Returns:
A configuration object.
"""
with open(config_path, "r") as f:
config = json.load(f)
if "tie_word_embeddings" in config:
print("Set draft model tie_word_embeddings to False")
config["tie_word_embeddings"] = False
# check for architectures
architectures = config.get("architectures", None)
if architectures is None:
raise ValueError("No architectures found in the config file")
if len(architectures) != 1:
raise ValueError("Only one architecture is supported")
architecture = architectures[0]
if architecture not in cls._config_mapping:
raise ValueError(f"Architecture {architecture} not supported")
# If draft_vocab_size is not in config or is None, set draft_vocab_size to vocab_size
if "draft_vocab_size" not in config or config["draft_vocab_size"] is None:
config["draft_vocab_size"] = config.get("vocab_size", None)
return cls._config_mapping[architecture].from_dict(config)