File size: 2,640 Bytes
71153bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
from transformers import PretrainedConfig
from typing import Literal, Optional

class DiffusionLlamaConfig(PretrainedConfig):
    """HF-style configuration for a diffusion LLaMA model.

    Mirrors the fields of a lit-gpt-style ``Config`` dataclass, with the
    derived-value logic from the original ``Config.__post_init__`` (padded
    vocab size, query-group count, intermediate size) folded into
    ``__init__`` so the object is fully resolved after construction.
    """

    model_type = "diff_llama"

    def __init__(
        self,
        block_size: int = 4096,
        vocab_size: int = 50254,
        padding_multiple: int = 512,
        padded_vocab_size: Optional[int] = None,
        n_layer: int = 16,
        n_head: int = 32,
        n_embd: int = 4096,
        rotary_percentage: float = 0.25,
        parallel_residual: bool = True,
        bias: bool = True,
        n_query_groups: Optional[int] = None,
        shared_attention_norm: bool = False,
        norm_class: Literal["LayerNorm", "RMSNorm", "FusedRMSNorm"] = "LayerNorm",
        norm_eps: float = 1e-5,
        mlp_class: Literal["GptNeoxMLP", "LLaMAMLP"] = "GptNeoxMLP",
        intermediate_size: Optional[int] = None,
        condense_ratio: int = 1,
        initializer_range: float = 0.02,
        **kwargs,
    ):
        """Build the config, resolving derived defaults.

        Args:
            block_size: Maximum sequence length the model supports.
            vocab_size: Raw tokenizer vocabulary size.
            padding_multiple: Multiple to which the embedding table is padded
                (for hardware-friendly matrix shapes).
            padded_vocab_size: Explicit padded size; if ``None`` it is derived
                as the smallest multiple of ``padding_multiple`` >= ``vocab_size``.
            n_layer: Number of transformer blocks.
            n_head: Number of attention heads.
            n_embd: Hidden (embedding) dimension.
            rotary_percentage: Fraction of head dims receiving rotary embeddings.
            parallel_residual: Use GPT-NeoX-style parallel attention/MLP residual.
            bias: Whether linear layers carry bias terms.
            n_query_groups: KV head groups for GQA/MQA; ``None`` means
                multi-head attention (one group per head).
            shared_attention_norm: Share the norm between attention and MLP.
            norm_class: Name of the normalization layer implementation.
            norm_eps: Epsilon for the normalization layers.
            mlp_class: Name of the MLP block implementation.
            intermediate_size: MLP hidden width; defaults to ``4 * n_embd``
                when ``None`` (LLaMA checkpoints usually set it explicitly).
            condense_ratio: Position-index condense ratio for RoPE scaling.
            initializer_range: Std-dev used for weight initialization.
            **kwargs: Forwarded to :class:`PretrainedConfig`.
        """
        self.block_size = block_size
        self.vocab_size = vocab_size
        self.padding_multiple = padding_multiple

        # Derived value 1 (from the original Config.__post_init__):
        # pad the vocab up to the next multiple of `padding_multiple`.
        if padded_vocab_size is None:
            self.padded_vocab_size = self._find_multiple(vocab_size, padding_multiple)
        else:
            self.padded_vocab_size = padded_vocab_size

        self.n_layer = n_layer
        self.n_head = n_head
        self.n_embd = n_embd
        self.rotary_percentage = rotary_percentage
        self.parallel_residual = parallel_residual
        self.bias = bias

        # Derived value 2: default to multi-head attention (one KV group
        # per query head) unless grouped-query attention was requested.
        self.n_query_groups = n_query_groups if n_query_groups is not None else n_head

        self.shared_attention_norm = shared_attention_norm
        self.norm_class = norm_class
        self.norm_eps = norm_eps
        self.mlp_class = mlp_class

        # Derived value 3: conventional 4x expansion when unspecified,
        # though LLaMA checkpoints usually set intermediate_size explicitly.
        self.intermediate_size = intermediate_size if intermediate_size is not None else 4 * n_embd

        self.condense_ratio = condense_ratio
        self.initializer_range = initializer_range

        super().__init__(**kwargs)

    @property
    def head_size(self) -> int:
        """Per-head dimension (``n_embd`` split evenly across ``n_head``)."""
        return self.n_embd // self.n_head

    @staticmethod
    def _find_multiple(n: int, k: int) -> int:
        """Return the smallest multiple of ``k`` that is >= ``n``.

        Raises:
            ValueError: if ``k`` is not positive. (The previous ``k > 0``
                guard fell through to ``n % k`` and raised an opaque
                ``ZeroDivisionError`` for ``k == 0`` and returned a wrong
                value for negative ``k``.)
        """
        if k <= 0:
            raise ValueError(f"padding multiple must be positive, got {k}")
        if n % k == 0:
            return n
        return n + k - (n % k)