| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import math |
| | from contextlib import nullcontext |
| | from typing import TYPE_CHECKING, Optional |
| |
|
| | import torch |
| | from transformers.integrations import is_deepspeed_zero3_enabled |
| |
|
| | from ...extras import logging |
| |
|
| |
|
| | if TYPE_CHECKING: |
| | from transformers import PreTrainedModel, PreTrainedTokenizer |
| |
|
| |
|
# Module-level logger from the project's ``logging`` wrapper; the functions below
# call its ``info_rank0`` / ``warning_rank0`` methods.
logger = logging.get_logger(__name__)
| |
|
| |
|
| | def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int) -> None: |
| | """Initialize new token embeddings with mean + Gaussian noise. |
| | |
| | This is the default initialization method used by LlamaFactory. |
| | |
| | Args: |
| | embed_weight: The embedding weight matrix to initialize (shape: [vocab_size, embedding_dim]) |
| | num_new_tokens: Number of new tokens added at the end of the embedding matrix |
| | """ |
| | embedding_dim = embed_weight.size(1) |
| | avg_weight = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True) |
| | noise_weight = torch.empty_like(embed_weight[-num_new_tokens:]) |
| | noise_weight.normal_(mean=0, std=(1.0 / math.sqrt(embedding_dim))) |
| | embed_weight[-num_new_tokens:] = avg_weight + noise_weight |
| |
|
| |
|
| | def _description_based_initialization( |
| | embed_weight: "torch.Tensor", |
| | num_new_tokens: int, |
| | descriptions: dict[str, str], |
| | tokenizer: "PreTrainedTokenizer", |
| | model: "PreTrainedModel", |
| | add_noise: bool = False, |
| | ) -> None: |
| | """Initialize new token embeddings based on textual descriptions. |
| | |
| | For each new token, this function: |
| | 1. Tokenizes its description text |
| | 2. Gets embeddings of the description tokens |
| | 3. Averages them to initialize the new token's embedding |
| | 4. Optionally adds Gaussian noise |
| | |
| | Args: |
| | embed_weight: The embedding weight matrix to initialize (shape: [vocab_size, embedding_dim]) |
| | num_new_tokens: Number of new tokens added |
| | descriptions: Dict mapping token string to its description text |
| | e.g., {"<think>": "A token representing reasoning process"} |
| | tokenizer: The tokenizer instance |
| | model: The model instance (used to get input embeddings) |
| | add_noise: Whether to add Gaussian noise to the initialization |
| | |
| | Example: |
| | descriptions = { |
| | "<|START_OF_SVG|>": "Marks the beginning of an SVG document", |
| | "<|END_OF_SVG|>": "Marks the end of an SVG document" |
| | } |
| | """ |
| | embedding_dim = embed_weight.size(1) |
| |
|
| | for i, desc in enumerate(descriptions.values()): |
| | |
| | tokens = tokenizer(desc, return_tensors="pt", add_special_tokens=False) |
| |
|
| | with torch.no_grad(): |
| | token_ids = tokens["input_ids"][0] |
| | |
| | device = embed_weight.device |
| | token_ids = token_ids.to(device) |
| |
|
| | |
| | valid_token_ids = token_ids[token_ids < (len(tokenizer) - num_new_tokens)] |
| |
|
| | if len(valid_token_ids) == 0: |
| | |
| | logger.warning_rank0( |
| | f"Description for token {i + 1}/{num_new_tokens} contains no valid tokens. " |
| | "Using mean of existing embeddings." |
| | ) |
| | base_embedding = embed_weight[:-num_new_tokens].mean(dim=0) |
| | else: |
| | |
| | token_embeds = model.get_input_embeddings()(valid_token_ids) |
| | base_embedding = token_embeds.mean(dim=0) |
| |
|
| | |
| | if add_noise: |
| | noise = torch.randn_like(base_embedding) * (1.0 / math.sqrt(embedding_dim)) |
| | embed_weight[-num_new_tokens + i] = base_embedding + noise |
| | else: |
| | embed_weight[-num_new_tokens + i] = base_embedding |
| |
|
| |
|
def _initialize_embeddings(
    embed_weight: "torch.Tensor",
    num_new_tokens: int,
    init_method: str,
    new_special_tokens_config: Optional[dict],
    tokenizer: "PreTrainedTokenizer",
    model: "PreTrainedModel",
) -> None:
    """Single source of truth for embedding initialization.

    Selects the initialization method and applies it to ``embed_weight`` in place.
    The description-based methods need a non-empty ``new_special_tokens_config``. Without one,
    or for an unrecognized ``init_method``, a warning is logged and ``noise_init`` is used.

    Args:
        embed_weight: The embedding weight matrix to initialize.
        num_new_tokens: Number of new tokens added.
        init_method: Initialization method ('noise_init', 'desc_init', 'desc_init_w_noise').
        new_special_tokens_config: Config dict with token descriptions (required for desc_init methods).
        tokenizer: The tokenizer instance.
        model: The model instance.
    """
    # desc method -> (add_noise, log message)
    desc_methods = {
        "desc_init": (False, "Using semantic initialization (desc_init) for new special tokens"),
        "desc_init_w_noise": (
            True,
            "Using semantic initialization with noise (desc_init_w_noise) for new special tokens",
        ),
    }

    if init_method in desc_methods and new_special_tokens_config:
        add_noise, message = desc_methods[init_method]
        logger.info_rank0(message)
        _description_based_initialization(
            embed_weight, num_new_tokens, new_special_tokens_config, tokenizer, model, add_noise=add_noise
        )
        return

    if init_method in desc_methods:
        logger.warning_rank0(
            f"init_method='{init_method}' requires descriptions config, falling back to 'noise_init'"
        )
    elif init_method != "noise_init":
        # Distinguish typos / unsupported values from a missing descriptions config.
        logger.warning_rank0(f"Unknown init_method='{init_method}', falling back to 'noise_init'")

    logger.info_rank0("Using noisy mean initialization (noise_init) for new special tokens")
    _noisy_mean_initialization(embed_weight, num_new_tokens)
| |
|
| |
|
def resize_embedding_layer(
    model: "PreTrainedModel",
    tokenizer: "PreTrainedTokenizer",
    new_special_tokens_config: Optional[dict] = None,
    init_special_tokens: str = "noise_init",
) -> None:
    r"""Resize token embeddings and initialize new tokens.

    Does nothing when ``len(tokenizer)`` fits in the current input embedding matrix. Otherwise:
    - the embeddings are resized to ``len(tokenizer)`` padded to a multiple of 64;
    - the new rows of the input embeddings (and of the output embeddings when they are not tied)
      are initialized with ``init_special_tokens``;
    - ``model.config.vocab_size`` is updated.

    Args:
        model: The model to resize.
        tokenizer: The tokenizer (used to get target vocab size).
        new_special_tokens_config: Optional dict with token descriptions for semantic initialization.
        init_special_tokens: Initialization method ('noise_init', 'desc_init', 'desc_init_w_noise').

    Raises:
        ValueError: If the model is quantized, or its output embeddings are not a ``torch.nn.Linear``.
    """

    def gather_embeddings():
        # Under DeepSpeed ZeRO-3, gather the embedding weights so their full shape/values are visible.
        # The context is rebuilt on every call so it wraps the *current* parameter objects:
        # resizing may replace them, and a context built before the resize would gather stale ones.
        if not is_deepspeed_zero3_enabled():
            return nullcontext()

        import deepspeed  # type: ignore

        params = [model.get_input_embeddings().weight]
        if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings:
            params.append(model.get_output_embeddings().weight)

        return deepspeed.zero.GatheredParameters(params, modifier_rank=0)

    with gather_embeddings():
        current_embedding_size = model.get_input_embeddings().weight.size(0)

    if len(tokenizer) <= current_embedding_size:
        return

    if getattr(model, "quantization_method", None):
        raise ValueError("Cannot resize embedding layers of a quantized model.")

    if not isinstance(model.get_output_embeddings(), torch.nn.Linear):
        raise ValueError("Current model does not support resizing embedding layers.")

    model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)
    with gather_embeddings():
        new_embedding_size = model.get_input_embeddings().weight.size(0)
        num_new_tokens = new_embedding_size - current_embedding_size
        logger.info_rank0(
            f"Resizing embeddings: {current_embedding_size} -> {new_embedding_size} (+{num_new_tokens} tokens)"
        )

        _initialize_embeddings(
            model.get_input_embeddings().weight.data,
            num_new_tokens,
            init_special_tokens,
            new_special_tokens_config,
            tokenizer,
            model,
        )

        # Tied output embeddings share the input weight, which was already initialized above.
        if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings:
            _initialize_embeddings(
                model.get_output_embeddings().weight.data,
                num_new_tokens,
                init_special_tokens,
                new_special_tokens_config,
                tokenizer,
                model,
            )

    model.config.vocab_size = new_embedding_size
    logger.info_rank0(f"Resized token embeddings from {current_embedding_size} to {new_embedding_size}.")
| |
|