# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from typing import Any, Optional
from omegaconf import MISSING
from transformers import AutoConfig
from verl.base_config import BaseConfig
from verl.utils import hf_processor, hf_tokenizer
from verl.utils.fs import copy_to_local
from verl.utils.import_utils import import_external_libs
from verl.utils.model import get_generation_config, update_model_config
# Public API of this module.
__all__ = [
    "HFModelConfig",
    "MtpConfig",
]
@dataclass
class MtpConfig(BaseConfig):
    """Settings for the model's MTP parameters.

    Switches:
        enable: Load and save the MTP parameters without using them.
        enable_train: Use the MTP parameters during training.
        enable_rollout: Use the MTP parameters during rollout.

    Training options:
        detach_encoder: Detach encoder parameters during MTP training.
        mtp_loss_scaling_factor: Scaling factor applied to the MTP loss.

    SGLang rollout options (defaults):
        speculative_algorithm: "EAGLE"
        speculative_num_steps: 3
        speculative_eagle_topk: 1
        speculative_num_draft_tokens: 4

    vLLM rollout options (defaults):
        method: "mtp"
        num_speculative_tokens: 1
    """

    # --- switches ---
    enable: bool = False
    enable_train: bool = False
    enable_rollout: bool = False

    # --- training ---
    detach_encoder: bool = False
    mtp_loss_scaling_factor: float = 0.1

    # --- SGLang rollout ---
    speculative_algorithm: str = "EAGLE"
    speculative_num_steps: int = 3
    speculative_eagle_topk: int = 1
    speculative_num_draft_tokens: int = 4

    # --- vLLM rollout ---
    method: str = "mtp"
    num_speculative_tokens: int = 1
@dataclass
class HFModelConfig(BaseConfig):
    """Configuration for a HuggingFace model plus its tokenizer/processor and fine-tuning options.

    Only ``path`` is required. ``__post_init__`` performs these steps:
      * validates ``target_modules``;
      * copies the model, tokenizer and config locations to local paths via ``copy_to_local``;
      * builds the tokenizer and processor when ``load_tokenizer`` is set;
      * loads ``generation_config`` and ``hf_config``;
      * patches ``hf_config`` with the tokenizer's special-token ids and ``override_config``.
    """

    # Fields that __post_init__ (re)assigns after construction. Presumably BaseConfig keeps every
    # other field read-only once built — TODO confirm against BaseConfig.
    _mutable_fields = {
        "hf_config_path",
        "tokenizer_path",
        "hf_config",
        "generation_config",
        "tokenizer",
        "processor",
        "local_path",
        "architectures",
        "local_hf_config_path",
        "local_tokenizer_path",
        "mtp",
    }

    # path, hf_config_path and tokenizer_path are kept separate in case they differ;
    # hf_config_path and tokenizer_path default to path when left as None.
    path: str = MISSING
    local_path: Optional[str] = None
    hf_config_path: Optional[str] = None
    local_hf_config_path: Optional[str] = None
    tokenizer_path: Optional[str] = None
    local_tokenizer_path: Optional[str] = None

    # whether to load tokenizer. This is useful when we only want to load model config
    load_tokenizer: bool = True

    hf_config: Any = None
    generation_config: Any = None
    tokenizer: Any = None
    processor: Any = None

    # whether to use shared memory
    use_shm: bool = False
    trust_remote_code: bool = False

    # custom chat template for the model
    custom_chat_template: Optional[str] = None

    external_lib: Optional[str] = None

    override_config: dict = field(default_factory=dict)

    enable_gradient_checkpointing: bool = True
    enable_activation_offload: bool = False

    use_remove_padding: bool = True

    # TODO: unify fsdp and megatron lora config
    # fsdp lora related. We may setup a separate config later
    lora_rank: int = 0
    lora_alpha: int = 16
    target_modules: Optional[Any] = "all-linear"  # allow both "all-linear" and ["q_proj","k_proj"]
    target_parameters: Optional[list[str]] = None  # for lora adapter on nn.Parameter
    exclude_modules: Optional[str] = None

    # megatron lora config
    lora: dict[str, Any] = field(default_factory=dict)

    # path to pre-trained LoRA adapter to load for continued training
    lora_adapter_path: Optional[str] = None

    use_liger: bool = False

    use_fused_kernels: bool = False
    fused_kernel_options: dict = field(default_factory=dict)

    # TiledMLP configuration for memory-efficient MLP computation
    tiled_mlp: dict = field(default_factory=lambda: {"enabled": False, "num_shards": 4})

    architectures: Optional[list[str]] = None

    mtp: MtpConfig = field(default_factory=MtpConfig)

    def __post_init__(self):
        """Resolve local paths, load tokenizer/processor and hf_config, and validate options.

        Raises:
            TypeError: if ``target_modules`` is neither None, a string, nor a list of strings.
            AssertionError: if the loaded config does not declare exactly one architecture.
        """
        # Validate the cheap, purely user-supplied option first so a bad value fails
        # before any copy_to_local / tokenizer / AutoConfig work is done.
        self._validate_target_modules()

        import_external_libs(self.external_lib)

        if self.hf_config_path is None:
            self.hf_config_path = self.path
        if self.tokenizer_path is None:
            self.tokenizer_path = self.path

        self.local_path = copy_to_local(self.path, use_shm=self.use_shm)

        if self.load_tokenizer:
            self._load_tokenizer_and_processor()

        self.local_hf_config_path = copy_to_local(self.hf_config_path, use_shm=self.use_shm)
        self.generation_config = get_generation_config(
            self.local_hf_config_path, trust_remote_code=self.trust_remote_code
        )

        self._load_hf_config()

    def _validate_target_modules(self):
        """Ensure ``target_modules`` is None, a string, or a list of strings.

        Raises:
            TypeError: for any other type, or for a list holding a non-string element.
        """
        if self.target_modules is None:
            return
        if not isinstance(self.target_modules, (str, list)):
            raise TypeError(
                "target_modules must be a string or a list of strings, "
                f"but got {type(self.target_modules).__name__}"
            )
        if isinstance(self.target_modules, list):
            for module in self.target_modules:
                if not isinstance(module, str):
                    raise TypeError(
                        f"All elements in target_modules list must be strings, but found {type(module).__name__}"
                    )

    def _load_tokenizer_and_processor(self):
        """Copy the tokenizer files locally, build tokenizer and processor, apply the custom chat template."""
        self.local_tokenizer_path = copy_to_local(self.tokenizer_path, use_shm=self.use_shm)
        self.tokenizer = hf_tokenizer(self.local_tokenizer_path, trust_remote_code=self.trust_remote_code)
        self.processor = hf_processor(self.local_tokenizer_path, trust_remote_code=self.trust_remote_code)

        if self.custom_chat_template is not None:
            # The template goes on the processor when one was built, otherwise on the tokenizer.
            if self.processor is not None:
                self.processor.chat_template = self.custom_chat_template
            else:
                self.tokenizer.chat_template = self.custom_chat_template

    def _load_hf_config(self):
        """Load ``hf_config`` from ``local_hf_config_path``, apply overrides, and record the architecture.

        Raises:
            AssertionError: if the config does not declare exactly one architecture.
        """
        attn_implementation = self.override_config.get("attn_implementation", "flash_attention_2")
        self.hf_config = AutoConfig.from_pretrained(
            self.local_hf_config_path, trust_remote_code=self.trust_remote_code, attn_implementation=attn_implementation
        )

        # Special-token ids come from the tokenizer (when loaded); user overrides are applied on top.
        override_config_kwargs = {}
        if self.tokenizer is not None:
            override_config_kwargs.update(
                {
                    "bos_token_id": self.tokenizer.bos_token_id,
                    "eos_token_id": self.tokenizer.eos_token_id,
                    "pad_token_id": self.tokenizer.pad_token_id,
                }
            )

        # TODO: (vermouth1992). self.config.model in megatron differs from that of fsdp in the override_config.
        override_config = self.override_config.get("model_config", self.override_config)
        override_config_kwargs.update(override_config)
        update_model_config(self.hf_config, override_config_kwargs=override_config_kwargs)

        # Plain instance attribute, not a declared dataclass field.
        self.share_embeddings_and_output_weights = getattr(self.hf_config, "tie_word_embeddings", False)

        # get model architectures
        self.architectures = getattr(self.hf_config, "architectures", None)
        assert self.architectures is not None and len(self.architectures) == 1, (
            f"Expect only one architecture, got {self.architectures}"
        )

        # per model patch
        if getattr(self.hf_config, "model_type", None) == "kimi_vl":
            self.hf_config.text_config.topk_method = "greedy"

    def get_processor(self):
        """Return the processor if one was loaded, otherwise the tokenizer (None if neither was loaded)."""
        return self.processor if self.processor is not None else self.tokenizer