LilPhat23 committed on
Commit 7a941ae · verified · 1 Parent(s): 7cd5956

Upload configuration_phogpt.py

Files changed (1)
  1. configuration_phogpt.py +142 -0
configuration_phogpt.py ADDED
@@ -0,0 +1,142 @@
+"""A HuggingFace-style model configuration for PhoGPT."""
+import warnings
+from typing import Any, Dict, Optional, Union
+from transformers import PretrainedConfig
+from .attention import check_alibi_support, is_flash_v1_installed, is_flash_v2_installed
+from .blocks import attn_config_defaults as phogpt_attn_defaults
+from .fc import FC_CLASS_REGISTRY
+from .norm import LPLayerNorm
+from .ffn import FFN_CLASS_REGISTRY
+from .warnings import VersionedDeprecationWarning
+
+ffn_config_defaults: Dict = {'ffn_type': 'phogpt_mlp'}
+init_config_defaults: Dict = {
+    'name': 'kaiming_normal_',
+    'fan_mode': 'fan_in',
+    'init_nonlinearity': 'relu',
+    'init_div_is_residual': True,
+    'emb_init_std': None,
+    'emb_init_uniform_lim': None,
+    'init_std': None,
+    'init_gain': 0.0
+}
+
+class PhoGPTConfig(PretrainedConfig):
+    model_type = 'phogpt'
+
+    def __init__(
+        self,
+        hidden_size: int = 4096,
+        num_attention_heads: int = 32,
+        num_hidden_layers: int = 32,
+        expansion_ratio: Union[int, float] = 4,
+        max_seq_len: int = 4096,
+        vocab_size: int = 51200,
+        resid_pdrop: float = 0.0,
+        emb_pdrop: float = 0.0,
+        learned_pos_emb: bool = True,
+        attn_config: Dict = phogpt_attn_defaults,
+        ffn_config: Dict = ffn_config_defaults,
+        init_device: str = 'cpu',
+        logit_scale: Optional[Union[float, str]] = None,
+        no_bias: bool = False,
+        embedding_fraction: float = 1.0,
+        norm_type: str = 'low_precision_layernorm',
+        use_cache: bool = False,
+        init_config: Dict = init_config_defaults,
+        fc_type: str = 'torch',
+        tie_word_embeddings: bool = True,
+        use_pad_tok_in_ffn: bool = True,
+        **kwargs: Any
+    ):
+        """PhoGPT configuration class.
+
+        Args:
+            hidden_size (int): Model hidden size (embedding dimension)
+            num_attention_heads (int): Number of attention heads
+            num_hidden_layers (int): Number of transformer layers
+            expansion_ratio (int | float): FFN expansion ratio
+            max_seq_len (int): Max sequence length
+            vocab_size (int): Vocabulary size
+            resid_pdrop (float): Dropout on residuals
+            emb_pdrop (float): Dropout on embeddings
+            learned_pos_emb (bool): Use learned positional embeddings
+            attn_config (dict): Attention configuration dictionary
+            ffn_config (dict): Feedforward network config dictionary
+            init_device (str): Device for initialization
+            logit_scale (float | str): Logit scaling
+            no_bias (bool): Disable biases
+            embedding_fraction (float): Scale embedding gradients
+            norm_type (str): LayerNorm type
+            use_cache (bool): Return past key/value
+            init_config (dict): Weight initialization config
+            fc_type (str): Fully connected layer type ('torch' or 'te')
+            tie_word_embeddings (bool): Tie input/output embeddings
+            use_pad_tok_in_ffn (bool): Forward pad tokens through FFN
+        """
+        self.hidden_size = hidden_size
+        self.num_attention_heads = num_attention_heads
+        self.num_hidden_layers = num_hidden_layers
+        self.expansion_ratio = expansion_ratio
+        self.max_seq_len = max_seq_len
+        self.vocab_size = vocab_size
+        self.resid_pdrop = resid_pdrop
+        self.emb_pdrop = emb_pdrop
+        self.learned_pos_emb = learned_pos_emb
+        self.attn_config = attn_config
+        self.ffn_config = ffn_config
+        self.init_device = init_device
+        self.logit_scale = logit_scale
+        self.no_bias = no_bias
+        self.embedding_fraction = embedding_fraction
+        self.norm_type = norm_type
+        self.use_cache = use_cache
+        self.init_config = init_config
+        self.fc_type = fc_type
+        self.use_pad_tok_in_ffn = use_pad_tok_in_ffn
+
+        if 'name' in kwargs:
+            del kwargs['name']
+        if 'loss_fn' in kwargs:
+            del kwargs['loss_fn']
+
+        if self.attn_config.get('alibi', False) or self.attn_config.get('rope', False):
+            self.learned_pos_emb = False
+            warnings.warn("alibi or rope is enabled, setting learned_pos_emb to False.")
+
+        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
+        self._validate_config()
+
+    def _set_config_defaults(self, config: Dict[str, Any], config_defaults: Dict[str, Any]) -> Dict[str, Any]:
+        for k, v in config_defaults.items():
+            if k not in config:
+                config[k] = v
+            elif isinstance(v, dict):
+                config[k] = self._set_config_defaults(config.get(k, {}), v)
+        return config
+
+    def _validate_config(self) -> None:
+        self.attn_config = self._set_config_defaults(self.attn_config, phogpt_attn_defaults)
+        self.ffn_config = self._set_config_defaults(self.ffn_config, ffn_config_defaults)
+        self.init_config = self._set_config_defaults(self.init_config, init_config_defaults)
+
+        if self.hidden_size % self.num_attention_heads != 0:
+            raise ValueError("hidden_size must be divisible by num_attention_heads")
+
+        for prob in [self.attn_config.get('attn_pdrop', 0.0), self.resid_pdrop, self.emb_pdrop]:
+            if not 0.0 <= prob <= 1.0:
+                raise ValueError("Dropout probabilities must be between 0 and 1")
+
+        if self.embedding_fraction > 1 or self.embedding_fraction <= 0:
+            raise ValueError("embedding_fraction must be in (0, 1]")
+
+        if not (self.learned_pos_emb or self.attn_config.get('alibi', False) or self.attn_config.get('rope', False)):
+            warnings.warn("No positional encoding used: learned_pos_emb, alibi, or rope should be enabled.")
+
+        if self.fc_type == 'te' or self.ffn_config.get('ffn_type') == 'te_ln_mlp':
+            try:
+                import transformer_engine.pytorch as te
+                del te
+            except ImportError:
+                raise ImportError("fc_type='te' requires TransformerEngine installed.")
+
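For context, a minimal usage sketch (not part of the uploaded file). It assumes the sibling modules imported above (attention, blocks, fc, norm, ffn, warnings) ship in the same repository so the relative imports resolve, and the Hub repo id below is a placeholder, not a real repository.

from transformers import AutoConfig

# Direct instantiation: unspecified fields fall back to the defaults above, and
# _validate_config() runs at the end of __init__, merging partial dicts such as
# attn_config with their defaults and checking the constraints.
config = PhoGPTConfig(
    hidden_size=4096,
    num_attention_heads=32,
    num_hidden_layers=32,
    max_seq_len=4096,
    attn_config={'attn_pdrop': 0.1},  # remaining attention keys are filled from phogpt_attn_defaults
)

# Loading from a Hub repo that includes configuration_phogpt.py
# ("your-namespace/phogpt" is a hypothetical placeholder).
config = AutoConfig.from_pretrained("your-namespace/phogpt", trust_remote_code=True)

Because attn_config, ffn_config, and init_config are merged recursively by _set_config_defaults, callers only need to pass the keys they want to override.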