from typing import Literal, Optional

from transformers import PretrainedConfig


class DiffusionLlamaConfig(PretrainedConfig):
    """Hugging Face configuration for the ``diff_llama`` model.

    Field names mirror the original ``Config`` dataclass, whose
    ``__post_init__`` logic is ported into ``__init__`` here. Every argument
    is stored unchanged as an attribute of the same name, except the ones
    marked *derived* below, which get a computed value when passed as
    ``None``. Any extra ``**kwargs`` are forwarded to
    ``PretrainedConfig.__init__``.

    Args:
        n_head: Must be positive and divide ``n_embd``.
        padded_vocab_size: *Derived.* When ``None``, ``vocab_size`` rounded
            up to the next multiple of ``padding_multiple``.
        n_query_groups: *Derived.* When ``None``, equal to ``n_head``.
            Otherwise it must be positive and divide ``n_head``.
        intermediate_size: *Derived.* When ``None``, ``4 * n_embd``.

    Raises:
        ValueError: If ``n_head`` is not positive or does not divide
            ``n_embd``; if ``n_query_groups`` is not positive or does not
            divide ``n_head``; or if ``padded_vocab_size`` has to be derived
            and ``padding_multiple`` is not positive.
    """

    model_type = "diff_llama"

    def __init__(
        self,
        block_size: int = 4096,
        vocab_size: int = 50254,
        padding_multiple: int = 512,
        padded_vocab_size: Optional[int] = None,
        n_layer: int = 16,
        n_head: int = 32,
        n_embd: int = 4096,
        rotary_percentage: float = 0.25,
        parallel_residual: bool = True,
        bias: bool = True,
        n_query_groups: Optional[int] = None,
        shared_attention_norm: bool = False,
        norm_class: Literal["LayerNorm", "RMSNorm", "FusedRMSNorm"] = "LayerNorm",
        norm_eps: float = 1e-5,
        mlp_class: Literal["GptNeoxMLP", "LLaMAMLP"] = "GptNeoxMLP",
        intermediate_size: Optional[int] = None,
        condense_ratio: int = 1,
        initializer_range: float = 0.02,
        **kwargs,
    ):
        # Consistency checks carried over from the original
        # Config.__post_init__. The positivity tests also keep the modulo
        # operations below from dividing by zero.
        if n_head <= 0 or n_embd % n_head != 0:
            raise ValueError(
                f"n_embd ({n_embd}) must be divisible by a positive "
                f"n_head ({n_head})"
            )
        if n_query_groups is not None and (
            n_query_groups <= 0 or n_head % n_query_groups != 0
        ):
            raise ValueError(
                f"n_head ({n_head}) must be divisible by a positive "
                f"n_query_groups ({n_query_groups})"
            )

        self.block_size = block_size
        self.vocab_size = vocab_size
        self.padding_multiple = padding_multiple

        # Logic from original Config.__post_init__
        # 1. Calculate padded vocab size: round vocab_size up to a multiple
        #    of padding_multiple unless an explicit value was given.
        self.padded_vocab_size = (
            self._find_multiple(vocab_size, padding_multiple)
            if padded_vocab_size is None
            else padded_vocab_size
        )

        self.n_layer = n_layer
        self.n_head = n_head
        self.n_embd = n_embd
        self.rotary_percentage = rotary_percentage
        self.parallel_residual = parallel_residual
        self.bias = bias

        # 2. Calculate query groups: one group per head unless specified.
        self.n_query_groups = n_head if n_query_groups is None else n_query_groups

        self.shared_attention_norm = shared_attention_norm
        self.norm_class = norm_class
        self.norm_eps = norm_eps
        self.mlp_class = mlp_class

        # 3. Calculate intermediate size.
        # Default to 4x if not specified, though LLaMA usually specifies it
        # explicitly.
        self.intermediate_size = (
            4 * n_embd if intermediate_size is None else intermediate_size
        )

        self.condense_ratio = condense_ratio
        self.initializer_range = initializer_range

        super().__init__(**kwargs)

    @property
    def head_size(self) -> int:
        """Per-head width: ``n_embd // n_head``."""
        return self.n_embd // self.n_head

    def _find_multiple(self, n: int, k: int) -> int:
        """Return the smallest multiple of ``k`` that is ``>= n``.

        Raises:
            ValueError: If ``k`` is not positive. The previous version still
                evaluated ``n % k`` for ``k == 0`` and failed with an opaque
                ``ZeroDivisionError``.
        """
        if k <= 0:
            raise ValueError(f"padding multiple must be positive, got {k}")
        remainder = n % k
        return n if remainder == 0 else n + k - remainder