| | """PyTorch INFLM model.""" |
| |
|
| | import torch |
| | from torch import nn |
| | from transformers.models.llama.modeling_llama import ( |
| | LlamaDecoderLayer, |
| | LlamaModel, |
| | LlamaForCausalLM |
| | ) |
| | from .configuration_inflm import INFLMConfig |
| |
|
| | _CONFIG_FOR_DOC = "INFLMConfig" |
| |
|
| |
|
| | class INFLMDecoderLayer(LlamaDecoderLayer): |
| | def __init__(self, config: INFLMConfig, layer_idx: int): |
| | super().__init__(config, layer_idx) |
| | self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) |
| | self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) |
| |
|
| |
|
| | class INFLMModel(LlamaModel): |
| | config_class = INFLMConfig |
| | _no_split_modules = ["INFLMDecoderLayer"] |
| | |
| | def __init__(self, config: INFLMConfig): |
| | super().__init__(config) |
| | self.padding_idx = config.pad_token_id |
| | self.vocab_size = config.vocab_size |
| |
|
| | self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) |
| | self.layers = nn.ModuleList([INFLMDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) |
| | self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) |
| |
|
| | self.gradient_checkpointing = False |
| | |
| | self.post_init() |
| |
|
| |
|
| | class INFLMForCausalLM(LlamaForCausalLM): |
| | _tied_weights_keys = ["lm_head.weight"] |
| |
|
| | def __init__(self, config: INFLMConfig): |
| | super().__init__(config) |
| | self.model = INFLMModel(config) |
| | self.vocab_size = config.vocab_size |
| | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) |
| |
|
| | |
| | self.post_init() |
| |
|
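

if __name__ == "__main__":
    # Illustrative usage sketch only, not part of the modeling code: it assumes this
    # file ships alongside configuration_inflm.py in a Hugging Face model repository
    # and that a converted checkpoint exists at the placeholder path below, so the
    # custom classes are picked up via trust_remote_code.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_path = "path/to/inflm-checkpoint"  # placeholder path, not a real checkpoint name
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)

    inputs = tokenizer("Hello", return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=32)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))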