| """GPT Blocks used for the GPT Model.""" |
|
|
| from typing import Dict, Optional, Tuple |
| import torch |
| import torch.nn as nn |
| from .attention import ATTN_CLASS_REGISTRY |
| from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY |
|
|
class MPTMLP(nn.Module):

    def __init__(self,
                 d_model: int,
                 expansion_ratio: int,
                 device: Optional[str] = None):
        super().__init__()
        self.up_proj = nn.Linear(d_model,
                                 expansion_ratio * d_model,
                                 device=device)
        self.act = nn.GELU(approximate='none')
        self.down_proj = nn.Linear(expansion_ratio * d_model,
                                   d_model,
                                   device=device)
        # Mark the output projection so MPT's param-init functions can rescale
        # residual projections.
        self.down_proj._is_residual = True

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(self.act(self.up_proj(x)))

class MPTBlock(nn.Module):

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        expansion_ratio: int,
        attn_config: Dict = {
            'attn_type': 'multihead_attention',
            'attn_pdrop': 0.0,
            'attn_impl': 'triton',
            'qk_ln': False,
            'clip_qkv': None,
            'softmax_scale': None,
            'prefix_lm': False,
            'attn_uses_sequence_id': False,
            'alibi': False,
            'alibi_bias_max': 8,
        },
        resid_pdrop: float = 0.0,
        norm_type: str = 'low_precision_layernorm',
        verbose: int = 0,
        device: Optional[str] = None,
        **kwargs):
        del kwargs  # unused; accepted so extra config keys do not raise
        super().__init__()

        norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
        attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']]

        self.norm_1 = norm_class(d_model, device=device)
        self.attn = attn_class(
            attn_impl=attn_config['attn_impl'],
            clip_qkv=attn_config['clip_qkv'],
            qk_ln=attn_config['qk_ln'],
            softmax_scale=attn_config['softmax_scale'],
            attn_pdrop=attn_config['attn_pdrop'],
            d_model=d_model,
            n_heads=n_heads,
            verbose=verbose,
            device=device,
        )
        self.norm_2 = norm_class(d_model, device=device)
        self.ffn = MPTMLP(
            d_model=d_model,
            expansion_ratio=expansion_ratio,
            device=device,
        )
        self.resid_attn_dropout = nn.Dropout(resid_pdrop)
        self.resid_ffn_dropout = nn.Dropout(resid_pdrop)

    def forward(
        self,
        x: torch.Tensor,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        long_range_past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attn_bias: Optional[torch.Tensor] = None,
        attn_bias_ae: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.ByteTensor] = None,
        is_causal: bool = True,
        topk: Optional[int] = None,
        needs_weights: Optional[bool] = None,
        faiss_indexes: Optional[Tuple] = None,
        n_layers: Optional[int] = None,
        current_layer: Optional[int] = None,
        mask_by_sim: bool = False,
        sim_threshold: Optional[float] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
               Optional[Tuple[torch.Tensor]], Optional[torch.Tensor]]:
        # Pre-norm attention sub-block. Retrieval-related arguments
        # (long_range_past_key_value, faiss_indexes, topk, mask_by_sim,
        # sim_threshold) are passed straight through to the attention class.
        a = self.norm_1(x)
        b, attn_weights, past_key_value, reshaped_idx = self.attn(
            a,
            past_key_value=past_key_value,
            long_range_past_key_value=long_range_past_key_value,
            attn_bias=attn_bias,
            attn_bias_ae=attn_bias_ae,
            attention_mask=attention_mask,
            is_causal=is_causal,
            topk=topk,
            needs_weights=needs_weights,
            faiss_indexes=faiss_indexes,
            n_layers=n_layers,
            current_layer=current_layer,
            mask_by_sim=mask_by_sim,
            sim_threshold=sim_threshold,
        )
        x = x + self.resid_attn_dropout(b)
        # Pre-norm feed-forward sub-block with its own residual connection.
        m = self.norm_2(x)
        n = self.ffn(m)
        x = x + self.resid_ffn_dropout(n)
        return x, attn_weights, past_key_value, reshaped_idx