LongCat0830 committed
Commit f2cd2b4 · verified · 1 Parent(s): 968b18d

Add files using upload-large-folder tool
config.json CHANGED
@@ -1,37 +1,37 @@
 {
-  "architectures": [
-    "LongcatCausalLM"
-  ],
-  "attention_bias": false,
-  "attention_dropout": 0.0,
-  "auto_map": {
-    "AutoConfig": "configuration_longcat.LongcatConfig",
-    "AutoModel": "modeling_longcat.LongcatModel",
-    "AutoModelForCausalLM": "modeling_longcat.LongcatForCausalLM"
-  },
-  "vocab_size": 131072,
-  "hidden_size": 6144,
-  "ffn_hidden_size": 12288,
-  "expert_ffn_hidden_size": 2048,
-  "num_layers": 28,
-  "num_attention_heads": 64,
-  "kv_lora_rank": 512,
-  "q_lora_rank": 1536,
-  "qk_rope_head_dim": 64,
-  "v_head_dim": 128,
-  "qk_nope_head_dim": 128,
-  "mla_scale_q_lora": true,
-  "mla_scale_kv_lora": true,
-  "routed_scaling_factor": 6.0,
-  "n_routed_experts": 512,
-  "max_position_embeddings": 131072,
-  "rms_norm_eps": 1e-5,
-  "use_cache": true,
-  "bos_token_id": 1,
-  "eos_token_id": 2,
-  "rope_theta": 10000000.0,
-  "attention_method": "MLA",
-  "zero_expert_num": 256,
-  "zero_expert_type": "identity",
-  "moe_topk": 12
-}
+  "architectures": [
+    "LongcatFlashForCausalLM"
+  ],
+  "attention_bias": false,
+  "attention_dropout": 0.0,
+  "auto_map": {
+    "AutoConfig": "configuration_longcat_flash.LongcatFlashConfig",
+    "AutoModel": "modeling_longcat_flash.LongcatFlashModel",
+    "AutoModelForCausalLM": "modeling_longcat_flash.LongcatFlashForCausalLM"
+  },
+  "vocab_size": 131072,
+  "hidden_size": 6144,
+  "ffn_hidden_size": 12288,
+  "expert_ffn_hidden_size": 2048,
+  "num_layers": 28,
+  "num_attention_heads": 64,
+  "kv_lora_rank": 512,
+  "q_lora_rank": 1536,
+  "qk_rope_head_dim": 64,
+  "v_head_dim": 128,
+  "qk_nope_head_dim": 128,
+  "mla_scale_q_lora": true,
+  "mla_scale_kv_lora": true,
+  "routed_scaling_factor": 6.0,
+  "n_routed_experts": 512,
+  "max_position_embeddings": 131072,
+  "rms_norm_eps": 1e-5,
+  "use_cache": true,
+  "bos_token_id": 1,
+  "eos_token_id": 2,
+  "rope_theta": 10000000.0,
+  "attention_method": "MLA",
+  "zero_expert_num": 256,
+  "zero_expert_type": "identity",
+  "moe_topk": 12
+}
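Since the renamed classes are exposed through `auto_map`, loading this checkpoint goes through `trust_remote_code`. Below is a minimal sketch of what that looks like; the repo id is a placeholder, not the real checkpoint name, and the printed values mirror the new config.json above.

```python
from transformers import AutoConfig

# Placeholder repo id -- substitute the actual checkpoint. trust_remote_code=True
# is required so that configuration_longcat_flash.LongcatFlashConfig (registered
# under "AutoConfig" in auto_map) is fetched and executed from the repo itself.
config = AutoConfig.from_pretrained("your-org/longcat-flash-checkpoint", trust_remote_code=True)

print(config.model_type)        # "longcat_flash"
print(config.n_routed_experts)  # 512 routed experts ...
print(config.zero_expert_num)   # ... plus 256 zero (identity) experts
print(config.moe_topk)          # 12 experts selected per token
```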
configuration_longcat_flash.py ADDED
@@ -0,0 +1,216 @@
+
+"""LongcatFlash model configuration"""
+
+from transformers.configuration_utils import PretrainedConfig
+from transformers.modeling_rope_utils import rope_config_validation
+
+
+LONGCAT_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
+
+
+class LongcatFlashConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`LongcatFlashModel`]. It is used to instantiate a
+    LongcatFlash model according to the specified arguments, defining the model architecture. Instantiating a
+    configuration with the defaults will yield a similar configuration to that of LongcatFlash.
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 131072):
+            Vocabulary size of the LongcatFlash model. Defines the number of different tokens that can be represented
+            by the `inputs_ids` passed when calling [`LongcatFlashModel`].
+        hidden_size (`int`, *optional*, defaults to 7168):
+            Dimension of the hidden representations.
+        ffn_hidden_size (`int`, *optional*, defaults to 18432):
+            Dimension of the MLP representations.
+        expert_ffn_hidden_size (`int`, *optional*, defaults to 2048):
+            Dimension of the MoE representations.
+        num_layers (`int`, *optional*, defaults to 61):
+            Number of hidden layers in the Transformer decoder.
+        num_attention_heads (`int`, *optional*, defaults to 128):
+            Number of attention heads for each attention layer in the Transformer decoder.
+        num_key_value_heads (`int`, *optional*):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+            `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
+            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by meanpooling all the original heads within that group. For more details, check out [this
+            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
+            `num_attention_heads`.
+        n_routed_experts (`int`, *optional*, defaults to 256):
+            Number of routed experts.
+        routed_scaling_factor (`float`, *optional*, defaults to 1):
+            Scaling factor for routed experts.
+        kv_lora_rank (`int`, *optional*, defaults to 512):
+            Rank of the LoRA matrices for key and value projections.
+        q_lora_rank (`int`, *optional*, defaults to 1536):
+            Rank of the LoRA matrices for query projections.
+        qk_rope_head_dim (`int`, *optional*, defaults to 64):
+            Dimension of the query/key heads that use rotary position embeddings.
+        v_head_dim (`int`, *optional*, defaults to 128):
+            Dimension of the value heads.
+        qk_nope_head_dim (`int`, *optional*, defaults to 128):
+            Dimension of the query/key heads that don't use rotary position embeddings.
+        mla_scale_q_lora (`bool`, *optional*, defaults to `True`):
+            Whether to scale the query LoRA output by `(hidden_size / q_lora_rank) ** 0.5`.
+        mla_scale_kv_lora (`bool`, *optional*, defaults to `True`):
+            Whether to scale the key/value LoRA output by `(hidden_size / kv_lora_rank) ** 0.5`.
+        moe_topk (`int`, *optional*, defaults to 8):
+            Number of experts selected for each token.
+        norm_topk_prob (`bool`, *optional*, defaults to `False`):
+            Whether to normalize the weights of the routed experts.
+        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+            The non-linear activation function (function or string) in the decoder.
+        max_position_embeddings (`int`, *optional*, defaults to 4096):
+            The maximum sequence length that this model might ever be used with.
+        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+            The epsilon used by the rms normalization layers.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        pad_token_id (`int`, *optional*):
+            Padding token id.
+        bos_token_id (`int`, *optional*, defaults to 0):
+            Beginning of stream token id.
+        eos_token_id (`int`, *optional*, defaults to 1):
+            End of stream token id.
+        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether to tie weight embeddings.
+        rope_theta (`float`, *optional*, defaults to 10000.0):
+            The base period of the RoPE embeddings.
+        attention_bias (`bool`, *optional*, defaults to `False`):
+            Whether to use a bias in the query, key, value and output projection layers during self-attention.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        attention_method (`str`, *optional*, defaults to `"MLA"`):
+            The attention method to use.
+        initializer_range (`float`, *optional*, defaults to 0.006):
+            The initializer range for the model.
+        router_bias (`bool`, *optional*, defaults to `False`):
+            Whether to use a bias in the router.
+        zero_expert_num (`int`, *optional*):
+            The number of zero experts to use.
+        zero_expert_type (`str`, *optional*):
+            The type of zero expert to use.
+
+    ```python
+    >>> from transformers import LongcatFlashModel, LongcatFlashConfig
+
+    >>> # Initializing a LongcatFlash style configuration
+    >>> configuration = LongcatFlashConfig()
+
+    >>> # Initializing a model from the configuration
+    >>> model = LongcatFlashModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "longcat_flash"
+    keys_to_ignore_at_inference = ["past_key_values"]
+    base_model_tp_plan = {
+        "layers.*.self_attn.k_proj": "colwise",
+        "layers.*.self_attn.v_proj": "colwise",
+        "layers.*.self_attn.o_proj": "rowwise",
+        "layers.*.mlp.experts.*.gate_proj": "local_colwise",
+        "layers.*.mlp.experts.*.up_proj": "local_colwise",
+        "layers.*.mlp.experts.*.down_proj": "local_rowwise",
+        "layers.*.mlps.*.gate_proj": "local_colwise",
+        "layers.*.mlps.*.up_proj": "local_colwise",
+        "layers.*.mlps.*.down_proj": "local_rowwise",
+    }
+    base_model_pp_plan = {
+        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+        "norm": (["hidden_states"], ["hidden_states"]),
+    }
+
+    def __init__(
+        self,
+        vocab_size=131072,
+        hidden_size=7168,
+        ffn_hidden_size=18432,
+        expert_ffn_hidden_size=2048,
+        num_layers=61,
+        num_attention_heads=128,
+        num_key_value_heads=None,
+        n_routed_experts=256,
+        routed_scaling_factor=1,
+        kv_lora_rank=512,
+        q_lora_rank=1536,
+        qk_rope_head_dim=64,
+        v_head_dim=128,
+        qk_nope_head_dim=128,
+        mla_scale_q_lora=True,
+        mla_scale_kv_lora=True,
+        moe_topk=8,
+        norm_topk_prob=False,
+        hidden_act="silu",
+        max_position_embeddings=4096,
+        rms_norm_eps=1e-6,
+        use_cache=True,
+        pad_token_id=None,
+        bos_token_id=0,
+        eos_token_id=1,
+        tie_word_embeddings=False,
+        rope_theta=10000.0,
+        attention_bias=False,
+        attention_dropout=0.0,
+        attention_method="MLA",
+        initializer_range=0.006,
+        router_bias=False,
+        zero_expert_num=None,
+        zero_expert_type=None,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.ffn_hidden_size = ffn_hidden_size
+        self.expert_ffn_hidden_size = expert_ffn_hidden_size
+        self.num_layers = num_layers
+        self.num_attention_heads = num_attention_heads
+        self.n_routed_experts = n_routed_experts
+        self.routed_scaling_factor = routed_scaling_factor
+        self.kv_lora_rank = kv_lora_rank
+        self.q_lora_rank = q_lora_rank
+        self.qk_rope_head_dim = qk_rope_head_dim
+        self.v_head_dim = v_head_dim
+        self.qk_nope_head_dim = qk_nope_head_dim
+        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
+        self.moe_topk = moe_topk
+        self.norm_topk_prob = norm_topk_prob
+        self.mla_scale_q_lora = mla_scale_q_lora
+        self.mla_scale_kv_lora = mla_scale_kv_lora
+        self.attention_method = attention_method
+        self.initializer_range = initializer_range
+        self.router_bias = router_bias
+        self.zero_expert_num = zero_expert_num
+        self.zero_expert_type = zero_expert_type
+
+        if self.attention_method == "MLA":
+            self.head_dim = qk_rope_head_dim
+        else:
+            raise ValueError('attention_method should be one of ["MLA"]')
+
+        if num_key_value_heads is None:
+            num_key_value_heads = num_attention_heads
+
+        self.num_key_value_heads = num_key_value_heads
+        self.hidden_act = hidden_act
+        self.rms_norm_eps = rms_norm_eps
+        self.use_cache = use_cache
+        self.rope_theta = rope_theta
+        self.attention_bias = attention_bias
+        self.attention_dropout = attention_dropout
+
+        rope_config_validation(self)
+
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
+
+    @property
+    def num_hidden_layers(self):
+        return self.num_layers
+
+
+__all__ = ["LongcatFlashConfig"]
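As a quick sanity check of the fields derived in `__init__` above, the following sketch instantiates the config directly (it assumes only this file, saved next to the script):

```python
from configuration_longcat_flash import LongcatFlashConfig

cfg = LongcatFlashConfig()  # defaults mirror the __init__ signature above

# qk heads concatenate a non-RoPE part and a RoPE part.
assert cfg.qk_head_dim == cfg.qk_nope_head_dim + cfg.qk_rope_head_dim  # 128 + 64 == 192
# num_hidden_layers is the property alias consumed by LongcatFlashModel.
assert cfg.num_hidden_layers == cfg.num_layers == 61
# Unspecified num_key_value_heads falls back to MHA.
assert cfg.num_key_value_heads == cfg.num_attention_heads == 128

# With the shipped config.json values, the router in modeling_longcat_flash.py
# scores n_routed_experts + zero_expert_num = 512 + 256 = 768 candidates per token.
cfg = LongcatFlashConfig(n_routed_experts=512, zero_expert_num=256, zero_expert_type="identity", moe_topk=12)
```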
modeling_longcat_flash.py ADDED
@@ -0,0 +1,644 @@
+from typing import Callable, Optional, Union
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from transformers.activations import ACT2FN
+from transformers.cache_utils import Cache, DynamicCache
+from transformers.generation import GenerationMixin
+from transformers.integrations import use_kernel_forward_from_hub
+from transformers.masking_utils import create_causal_mask
+from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
+from transformers.modeling_layers import GradientCheckpointingLayer
+from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from transformers.processing_utils import Unpack
+from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple
+from transformers.utils.generic import check_model_inputs
+from .configuration_longcat_flash import LongcatFlashConfig
+
+
+@use_kernel_forward_from_hub("RMSNorm")
+class LongcatFlashRMSNorm(nn.Module):
+    def __init__(self, hidden_size, eps=1e-6):
+        """
+        LongcatFlashRMSNorm is equivalent to T5LayerNorm
+        """
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        return self.weight * hidden_states.to(input_dtype)
+
+    def extra_repr(self):
+        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+class LongcatFlashRotaryEmbedding(nn.Module):
+    def __init__(self, config: LongcatFlashConfig, device=None):
+        super().__init__()
+        # BC: "rope_type" was originally "type"
+        if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
+            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+        else:
+            self.rope_type = "default"
+        self.max_seq_len_cached = config.max_position_embeddings
+        self.original_max_seq_len = config.max_position_embeddings
+
+        self.config = config
+        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self.original_inv_freq = self.inv_freq
+
+    @torch.no_grad()
+    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
+    def forward(self, x, position_ids):
+        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+        position_ids_expanded = position_ids[:, None, :].float()
+
+        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            emb = torch.cat((freqs, freqs), dim=-1)
+            cos = emb.cos() * self.attention_scaling
+            sin = emb.sin() * self.attention_scaling
+
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+class LongcatFlashMLP(nn.Module):
+    def __init__(self, config, hidden_size=None, intermediate_size=None):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
+        self.intermediate_size = config.ffn_hidden_size if intermediate_size is None else intermediate_size
+
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+        self.act_fn = ACT2FN[config.hidden_act]
+
+    def forward(self, x):
+        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+        return down_proj
+
+
+class LongcatFlashTopkRouter(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.top_k = config.moe_topk
+        self.n_routed_experts = (
+            config.n_routed_experts
+            if config.zero_expert_num is None
+            else config.n_routed_experts + config.zero_expert_num
+        )
+        self.routed_scaling_factor = config.routed_scaling_factor
+        self.norm_topk_prob = config.norm_topk_prob
+        self.router_bias = config.router_bias
+
+        self.classifier = nn.Linear(config.hidden_size, self.n_routed_experts, bias=self.router_bias)
+        self.register_buffer("e_score_correction_bias", torch.zeros(self.n_routed_experts))
+
+    @torch.no_grad()
+    def get_topk_indices(self, scores):
+        scores_for_choice = scores.view(-1, self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0)
+        topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1]
+        return topk_indices
+
+    def forward(self, hidden_states):
+        hidden_states = hidden_states.view(-1, self.config.hidden_size)
+        router_logits = F.linear(hidden_states.type(torch.float32), self.classifier.weight.type(torch.float32))
+        scores = router_logits.softmax(dim=-1)
+        topk_indices = self.get_topk_indices(scores)
+        topk_weights = scores.gather(1, topk_indices)
+        if self.norm_topk_prob:
+            denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
+            topk_weights /= denominator
+        topk_weights = topk_weights * self.routed_scaling_factor
+        return topk_indices, topk_weights
+
+
+class LongcatFlashMoE(nn.Module):
+    """
+    Mixture-of-experts module, optionally extended with parameter-free "zero experts"
+    (see `zero_expert_num` / `zero_expert_type`).
+    """
+
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.experts = nn.ModuleList(
+            [
+                LongcatFlashMLP(config, intermediate_size=config.expert_ffn_hidden_size)
+                for _ in range(config.n_routed_experts)
+            ]
+        )
+        self.router = LongcatFlashTopkRouter(config)
+        self.zero_expert_num = config.zero_expert_num
+        self.zero_expert_type = config.zero_expert_type
+
+    def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor):
+        final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype)
+        total_experts = len(self.experts) if self.zero_expert_num is None else len(self.experts) + self.zero_expert_num
+
+        expert_mask = torch.nn.functional.one_hot(topk_indices, num_classes=total_experts)
+        expert_mask = expert_mask.permute(2, 0, 1)
+
+        for expert_idx in range(total_experts):
+            expert = self.experts[expert_idx] if expert_idx < len(self.experts) else None
+            mask = expert_mask[expert_idx]
+            token_indices, weight_indices = torch.where(mask)
+
+            if token_indices.numel() > 0:
+                expert_weights = topk_weights[token_indices, weight_indices]
+                expert_input = hidden_states[token_indices]
+
+                if self.zero_expert_num is None or expert_idx < len(self.experts):
+                    expert_output = expert(expert_input)
+                elif self.zero_expert_type == "identity":
+                    expert_output = expert_input
+                else:
+                    raise ValueError(f"Unknown zero_expert_type: {self.zero_expert_type}")
+
+                weighted_output = expert_output * expert_weights.unsqueeze(-1)
+                final_hidden_states.index_add_(0, token_indices, weighted_output)
+
+        return final_hidden_states.type(hidden_states.dtype)
+
+    def forward(self, hidden_states):
+        orig_shape = hidden_states.shape
+        topk_indices, topk_weights = self.router(hidden_states)
+        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+        hidden_states = self.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape)
+        return hidden_states
+
+
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """
+    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+    """
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+def eager_attention_forward(
+    module: nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    scaling: float,
+    dropout: float = 0.0,
+    **kwargs: Unpack[TransformersKwargs],
+):
+    key_states = repeat_kv(key, module.num_key_value_groups)
+    value_states = repeat_kv(value, module.num_key_value_groups)
+
+    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+    if attention_mask is not None:
+        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+        attn_weights = attn_weights + causal_mask
+
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+    attn_output = torch.matmul(attn_weights, value_states)
+    attn_output = attn_output.transpose(1, 2).contiguous()
+
+    return attn_output, attn_weights
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1, use_mla=False):
+    """Applies Rotary Position Embedding to the query and key tensors.
+
+    Args:
+        q (`torch.Tensor`): The query tensor.
+        k (`torch.Tensor`): The key tensor.
+        cos (`torch.Tensor`): The cosine part of the rotary embedding.
+        sin (`torch.Tensor`): The sine part of the rotary embedding.
+        position_ids (`torch.Tensor`, *optional*):
+            Deprecated and unused.
+        unsqueeze_dim (`int`, *optional*, defaults to 1):
+            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+    Returns:
+        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+    """
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
+
+    if use_mla:
+        b, h, s, d = q.shape
+        q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
+
+        b, h, s, d = k.shape
+        k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
+
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
+class LongcatFlashMLA(nn.Module):
+    """Modified from Deepseek MLA"""
+
+    def __init__(self, config: LongcatFlashConfig, layer_idx: int):
+        super().__init__()
+        self.config = config
+        self.layer_idx = layer_idx
+        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+        self.attention_dropout = config.attention_dropout
+        self.num_heads = config.num_attention_heads
+        self.rope_theta = config.rope_theta
+        self.q_lora_rank = config.q_lora_rank
+        self.qk_rope_head_dim = config.qk_rope_head_dim
+        self.kv_lora_rank = config.kv_lora_rank
+        self.v_head_dim = config.v_head_dim
+        self.qk_nope_head_dim = config.qk_nope_head_dim
+        self.qk_head_dim = config.qk_head_dim
+
+        self.is_causal = True
+        if self.q_lora_rank is None:
+            self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.qk_head_dim, bias=False)
+        else:
+            self.q_a_proj = nn.Linear(config.hidden_size, config.q_lora_rank, bias=config.attention_bias)
+            self.q_a_layernorm = LongcatFlashRMSNorm(config.q_lora_rank)
+            self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False)
+
+        self.kv_a_proj_with_mqa = nn.Linear(
+            config.hidden_size,
+            self.kv_lora_rank + self.qk_rope_head_dim,
+            bias=config.attention_bias,
+        )
+        self.kv_a_layernorm = LongcatFlashRMSNorm(self.kv_lora_rank)
+        self.kv_b_proj = nn.Linear(
+            self.kv_lora_rank,
+            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
+            bias=False,
+        )
+
+        self.o_proj = nn.Linear(
+            self.num_heads * self.v_head_dim,
+            config.hidden_size,
+            bias=config.attention_bias,
+        )
+
+        # Set the scales to None when disabled: forward() checks `is not None`.
+        self.mla_scale_q_lora = (config.hidden_size / self.q_lora_rank) ** 0.5 if config.mla_scale_q_lora else None
+        self.mla_scale_kv_lora = (config.hidden_size / self.kv_lora_rank) ** 0.5 if config.mla_scale_kv_lora else None
+        self.scaling = self.qk_head_dim ** (-0.5)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_embeddings: tuple[torch.Tensor, torch.Tensor],
+        attention_mask: Optional[torch.Tensor],
+        past_key_value: Optional[Cache] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs: Unpack[FlashAttentionKwargs],
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+        batch_size, seq_length = hidden_states.shape[:-1]
+        query_shape = (batch_size, seq_length, -1, self.qk_head_dim)
+        key_shape = (batch_size, seq_length, -1, self.qk_nope_head_dim + self.v_head_dim)
+
+        if self.q_lora_rank is None:
+            q_states = self.q_proj(hidden_states).view(query_shape).transpose(1, 2)
+        else:
+            q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))).view(query_shape).transpose(1, 2)
+        q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+
+        # apply q_lora scaling
+        if self.mla_scale_q_lora is not None:
+            q_pass = q_pass * self.mla_scale_q_lora
+            q_rot = q_rot * self.mla_scale_q_lora
+
+        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
+        k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+        k_pass = self.kv_a_layernorm(k_pass)
+
+        # apply kv_lora scaling
+        if self.mla_scale_kv_lora is not None:
+            k_pass = k_pass * self.mla_scale_kv_lora
+
+        k_pass = self.kv_b_proj(k_pass).view(key_shape).transpose(1, 2)
+        k_pass, value_states = torch.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
+
+        k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim)
+
+        cos, sin = position_embeddings
+        q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin, use_mla=True)
+        k_rot = k_rot.expand(*k_pass.shape[:-1], -1)
+
+        query_states = torch.cat((q_pass, q_rot), dim=-1)
+        key_states = torch.cat((k_pass, k_rot), dim=-1)
+
+        if past_key_value is not None:
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+        if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
+            value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim])
+
+        attention_interface: Callable = eager_attention_forward
+        if self.config._attn_implementation != "eager":
+            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+        attn_output, attn_weights = attention_interface(
+            self,
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            dropout=0.0 if not self.training else self.attention_dropout,
+            scaling=self.scaling,
+            **kwargs,
+        )
+
+        if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
+            attn_output = attn_output[:, :, :, : self.v_head_dim]
+
+        attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()
+        attn_output = self.o_proj(attn_output)
+        return attn_output, attn_weights
+
+
+def create_attention_block(class_name, *args, **kwargs):
+    attention_mapping = {"MLA": LongcatFlashMLA}
+
+    chosen_class = attention_mapping.get(class_name)
+    if not chosen_class:
+        raise ValueError(f"No class found for name: {class_name}")
+
+    return chosen_class(*args, **kwargs)
+
+
+class LongcatFlashDecoderLayer(GradientCheckpointingLayer):
+    def __init__(self, config: LongcatFlashConfig, layer_idx: int):
+        super().__init__()
+        self.layer_idx = layer_idx
+        self.hidden_size = config.hidden_size
+        self.mlp = LongcatFlashMoE(config)
+
+        self_attn = []
+        mlps = []
+        input_layernorm = []
+        post_attention_layernorm = []
+        for i in range(2):
+            self_attn.append(
+                create_attention_block(config.attention_method, config=config, layer_idx=layer_idx * 2 + i)
+            )
+            mlps.append(LongcatFlashMLP(config))
+            input_layernorm.append(LongcatFlashRMSNorm(config.hidden_size, eps=config.rms_norm_eps))
+            post_attention_layernorm.append(LongcatFlashRMSNorm(config.hidden_size, eps=config.rms_norm_eps))
+
+        self.self_attn = nn.ModuleList(self_attn)
+        self.mlps = nn.ModuleList(mlps)
+        self.input_layernorm = nn.ModuleList(input_layernorm)
+        self.post_attention_layernorm = nn.ModuleList(post_attention_layernorm)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        use_cache: Optional[bool] = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+        **kwargs: Unpack[FlashAttentionKwargs],
+    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
+        for i in range(2):
+            residual = hidden_states
+
+            hidden_states = self.input_layernorm[i](hidden_states)
+
+            hidden_states, _ = self.self_attn[i](
+                hidden_states=hidden_states,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_value,
+                use_cache=use_cache,
+                cache_position=cache_position,
+                position_embeddings=position_embeddings,
+                **kwargs,
+            )
+            hidden_states = residual + hidden_states
+
+            residual = hidden_states
+            hidden_states = self.post_attention_layernorm[i](hidden_states)
+
+            if i == 0:
+                shortcut_mlp_output = self.mlp(hidden_states)  # shortcut output (MoE output)
+
+            hidden_states = self.mlps[i](hidden_states)
+            hidden_states = residual + hidden_states
+            if i == 1:
+                hidden_states = hidden_states + shortcut_mlp_output
+
+        return hidden_states
+
+
+@auto_docstring
+class LongcatFlashPreTrainedModel(PreTrainedModel):
+    config: LongcatFlashConfig
+    base_model_prefix = "model"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["LongcatFlashDecoderLayer"]
+    _skip_keys_device_placement = ["past_key_values"]
+    _supports_flash_attn = True
+    _supports_sdpa = True
+    _supports_flex_attn = True
+    _can_compile_fullgraph = True
+    _supports_attention_backend = True
+    _can_record_outputs = {
+        "hidden_states": LongcatFlashDecoderLayer,
+        "attentions": LongcatFlashMLA,
+    }
+
+
+@auto_docstring
+class LongcatFlashModel(LongcatFlashPreTrainedModel):
+    _keys_to_ignore_on_load_unexpected = [r"model\.mtp.*"]
+
+    def __init__(self, config: LongcatFlashConfig):
+        super().__init__(config)
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+        self.layers = nn.ModuleList(
+            [LongcatFlashDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+        )
+        self.norm = LongcatFlashRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.rotary_emb = LongcatFlashRotaryEmbedding(config=config)
+        self.gradient_checkpointing = False
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @check_model_inputs
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> BaseModelOutputWithPast:
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+        if inputs_embeds is None:
+            inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
+
+        if use_cache and past_key_values is None:
+            past_key_values = DynamicCache()
+
+        if cache_position is None:
+            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+            cache_position: torch.Tensor = torch.arange(
+                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+            )
+
+        if position_ids is None:
+            position_ids = cache_position.unsqueeze(0)
+
+        causal_mask = create_causal_mask(
+            config=self.config,
+            input_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            cache_position=cache_position,
+            past_key_values=past_key_values,
+            position_ids=position_ids,
+        )
+
+        hidden_states = inputs_embeds
+        position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+            hidden_states = decoder_layer(
+                hidden_states,
+                attention_mask=causal_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_values,
+                cache_position=cache_position,
+                position_embeddings=position_embeddings,
+                **kwargs,
+            )
+
+        hidden_states = self.norm(hidden_states)
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=past_key_values,
+        )
+
+
+@auto_docstring
+class LongcatFlashForCausalLM(LongcatFlashPreTrainedModel, GenerationMixin):
+    _tied_weights_keys = ["lm_head.weight"]
+    _tp_plan = {"lm_head": "colwise_rep"}
+    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
+    _keys_to_ignore_on_load_unexpected = [r"model\.mtp.*"]
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.model = LongcatFlashModel(config)
+        self.vocab_size = config.vocab_size
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def set_decoder(self, decoder):
+        self.model = decoder
+
+    def get_decoder(self):
+        return self.model
+
+    @can_return_tuple
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        logits_to_keep: Union[int, torch.Tensor] = 0,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> CausalLMOutputWithPast:
+        r"""
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, LongcatFlashForCausalLM
+
+        >>> model = LongcatFlashForCausalLM.from_pretrained("meta-longcat_flash/LongcatFlash-2-7b-hf")
+        >>> tokenizer = AutoTokenizer.from_pretrained("meta-longcat_flash/LongcatFlash-2-7b-hf")
+
+        >>> prompt = "Hey, are you conscious? Can you talk to me?"
+        >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+        ```"""
+        outputs: BaseModelOutputWithPast = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            **kwargs,
+        )
+
+        hidden_states = outputs.last_hidden_state
+        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+__all__ = ["LongcatFlashPreTrainedModel", "LongcatFlashModel", "LongcatFlashForCausalLM"]
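To see the whole stack fit together (two MLA/dense-MLP passes per decoder layer plus the shared shortcut MoE branch), here is a tiny CPU smoke test. It is a sketch under assumptions: a recent transformers version that provides the imported utilities, and both files living in a package (the hypothetical `longcat_flash_pkg` below) so the relative import in modeling_longcat_flash.py resolves. The sizes are arbitrary, not the released checkpoint.

```python
import torch

# Hypothetical package layout: longcat_flash_pkg/{__init__.py,
# configuration_longcat_flash.py, modeling_longcat_flash.py}
from longcat_flash_pkg.configuration_longcat_flash import LongcatFlashConfig
from longcat_flash_pkg.modeling_longcat_flash import LongcatFlashForCausalLM

# Arbitrary tiny dimensions for a CPU smoke test -- not the shipped config.
cfg = LongcatFlashConfig(
    vocab_size=128, hidden_size=64, ffn_hidden_size=128, expert_ffn_hidden_size=32,
    num_layers=2, num_attention_heads=2, q_lora_rank=24, kv_lora_rank=16,
    qk_nope_head_dim=16, qk_rope_head_dim=8, v_head_dim=16,
    n_routed_experts=4, moe_topk=2, zero_expert_num=2, zero_expert_type="identity",
)
model = LongcatFlashForCausalLM(cfg).eval()

input_ids = torch.randint(0, cfg.vocab_size, (1, 8))
with torch.no_grad():
    out = model(input_ids=input_ids)

# Each LongcatFlashDecoderLayer ran two attention/MLP sublayers plus one MoE
# shortcut; logits cover the full sequence because logits_to_keep defaults to 0.
print(out.logits.shape)  # torch.Size([1, 8, 128])
```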