# Diff_LLaMA_336M_sudoku_simple_sft_320 / configuration_diff_llama.py
# Uploaded by zzy1123 ("Upload DiffusionLlamaLM", commit 71153bb, verified)
from transformers import PretrainedConfig
from typing import Literal, Optional
class DiffusionLlamaConfig(PretrainedConfig):
    """Configuration for a DiffusionLlama language model.

    Stores the architectural hyperparameters (block size, vocab padding,
    attention/MLP layout) in the style of a lit-gpt ``Config`` while
    subclassing HuggingFace ``PretrainedConfig`` so it serializes alongside
    the model. Derived values (padded vocab size, query-group count,
    intermediate size) are resolved here, replicating the original
    ``Config.__post_init__`` logic.
    """

    model_type = "diff_llama"

    def __init__(
        self,
        block_size: int = 4096,
        vocab_size: int = 50254,
        padding_multiple: int = 512,
        padded_vocab_size: Optional[int] = None,
        n_layer: int = 16,
        n_head: int = 32,
        n_embd: int = 4096,
        rotary_percentage: float = 0.25,
        parallel_residual: bool = True,
        bias: bool = True,
        n_query_groups: Optional[int] = None,
        shared_attention_norm: bool = False,
        norm_class: Literal["LayerNorm", "RMSNorm", "FusedRMSNorm"] = "LayerNorm",
        norm_eps: float = 1e-5,
        mlp_class: Literal["GptNeoxMLP", "LLaMAMLP"] = "GptNeoxMLP",
        intermediate_size: Optional[int] = None,
        condense_ratio: int = 1,
        initializer_range: float = 0.02,
        **kwargs,
    ):
        """Initialize the config.

        Args:
            block_size: Maximum sequence length the model supports.
            vocab_size: Raw tokenizer vocabulary size.
            padding_multiple: Round the vocab up to a multiple of this when
                ``padded_vocab_size`` is not given explicitly.
            padded_vocab_size: Embedding-table size; derived from
                ``vocab_size``/``padding_multiple`` when ``None``.
            n_layer: Number of transformer blocks.
            n_head: Number of attention heads.
            n_embd: Hidden (embedding) dimension.
            rotary_percentage: Fraction of each head's dimension that gets
                rotary position embeddings.
            parallel_residual: Whether attention and MLP share one residual
                stream (GPT-NeoX style) instead of running sequentially.
            bias: Whether linear layers carry bias terms.
            n_query_groups: Number of KV groups for grouped-query attention;
                defaults to ``n_head`` (i.e. standard multi-head attention).
            shared_attention_norm: Share one norm between attention and MLP
                when using parallel residual.
            norm_class: Name of the normalization layer implementation.
            norm_eps: Epsilon used by the normalization layers.
            mlp_class: Name of the MLP block implementation.
            intermediate_size: MLP hidden size; defaults to ``4 * n_embd``.
            condense_ratio: Divisor applied to position indices (context
                condensation / interpolation).
            initializer_range: Std-dev for weight initialization.
            **kwargs: Forwarded to ``PretrainedConfig`` (e.g. token ids).
        """
        self.block_size = block_size
        self.vocab_size = vocab_size
        self.padding_multiple = padding_multiple

        # Logic from original Config.__post_init__
        # 1. Pad the vocab so the embedding/LM-head dims are hardware-friendly.
        if padded_vocab_size is None:
            self.padded_vocab_size = self._find_multiple(vocab_size, padding_multiple)
        else:
            self.padded_vocab_size = padded_vocab_size

        self.n_layer = n_layer
        self.n_head = n_head
        self.n_embd = n_embd
        self.rotary_percentage = rotary_percentage
        self.parallel_residual = parallel_residual
        self.bias = bias

        # 2. Grouped-query attention: fall back to one KV group per head (MHA).
        self.n_query_groups = n_query_groups if n_query_groups is not None else n_head

        self.shared_attention_norm = shared_attention_norm
        self.norm_class = norm_class
        self.norm_eps = norm_eps
        self.mlp_class = mlp_class

        # 3. MLP hidden size: default to 4x if not specified, though LLaMA
        # usually specifies it explicitly.
        self.intermediate_size = intermediate_size if intermediate_size is not None else 4 * n_embd

        self.condense_ratio = condense_ratio
        self.initializer_range = initializer_range
        super().__init__(**kwargs)

    @property
    def head_size(self) -> int:
        """Per-head dimension (assumes ``n_embd`` is divisible by ``n_head``)."""
        return self.n_embd // self.n_head

    @staticmethod
    def _find_multiple(n: int, k: int) -> int:
        """Return the smallest multiple of ``k`` that is >= ``n``.

        Raises:
            ValueError: If ``k`` is not a positive integer. (The original
                code guarded the fast path with ``k > 0`` but still divided
                by ``k`` in the fallthrough, raising ``ZeroDivisionError``
                for ``k == 0`` and returning nonsense for negative ``k``.)
        """
        if k <= 0:
            raise ValueError(f"padding multiple must be positive, got {k}")
        if n % k == 0:
            return n
        return n + k - (n % k)