upmodel-lwx committed on
Commit
6a02c44
·
verified ·
1 Parent(s): effaa79

Upload folder using huggingface_hub

Browse files
configuration_telechat3.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 the Big Science Workshop and HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """ Telechat configuration"""
17
+
18
+ from transformers.configuration_utils import PretrainedConfig
19
+
20
+
21
class Telechat3Config(PretrainedConfig):
    """Configuration class for the Telechat3 model.

    Holds the hyper-parameters used to instantiate a Telechat3 model and inherits
    serialization / hub behavior from `PretrainedConfig`.
    """

    model_type = "telechat3"
    keys_to_ignore_at_inference = ["past_key_values"]
    # Tensor-parallel sharding plan consumed by `transformers` (colwise/rowwise per projection).
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    # Pipeline-parallel plan: module name -> (input names, output names).
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }

    def __init__(
        self,
        attention_bias=False,
        attention_dropout=0.0,
        bos_token_id=1,
        eos_token_id=2,
        head_dim=128,
        hidden_act="silu",
        hidden_size=6144,
        initializer_range=0.0048,
        intermediate_size=24576,
        max_position_embeddings=2048,
        mlp_bias=False,
        model_type="telechat3",  # accepted for backward compatibility; ignored — the class attribute is authoritative
        num_attention_heads=48,
        num_hidden_layers=64,
        num_key_value_heads=None,
        original_max_position_embeddings=8192,
        pad_token_id=None,
        pretraining_tp=1,
        rms_norm_eps=1e-5,
        rope_scaling=None,
        rope_theta=1000000.0,
        tie_word_embeddings=False,
        use_cache=True,
        vocab_size=131072,
        **kwargs,
    ):
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.hidden_size = hidden_size
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.mlp_bias = mlp_bias
        self.max_position_embeddings = max_position_embeddings
        # Fix: this value was previously accepted but never assigned, so it was silently
        # dropped and never serialized. Store it so RoPE-scaling helpers can read it.
        self.original_max_position_embeddings = original_max_position_embeddings
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # for backward compatibility: default to plain multi-head attention (no KV grouping)
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads

        self.initializer_range = initializer_range

        self.pretraining_tp = pretraining_tp
        self.rms_norm_eps = rms_norm_eps
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.use_cache = use_cache
        self.vocab_size = vocab_size

        # `head_dim`, when explicitly given, must agree with hidden_size / num_attention_heads.
        if head_dim is not None and head_dim != self.hidden_size // self.num_attention_heads:
            raise ValueError("head_dim != hidden_size//num_attention_head.Please check the config.")
        self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads

        # Validate the correctness of rotary position embeddings parameters
        # BC: if there is a 'type' field, copy it to 'rope_type'.
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
modeling_telechat3.py ADDED
@@ -0,0 +1,899 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ import math
21
+ from typing import Callable, Optional, Union
22
+
23
+ import torch
24
+ import torch.utils.checkpoint
25
+ from torch import nn
26
+
27
+ from transformers.activations import ACT2FN
28
+ from transformers.cache_utils import Cache, DynamicCache
29
+ from transformers.configuration_utils import PretrainedConfig
30
+ from transformers.generation import GenerationMixin
31
+ from transformers.integrations import use_kernel_forward_from_hub
32
+ from transformers.masking_utils import create_causal_mask
33
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
34
+ from transformers.modeling_layers import GradientCheckpointingLayer
35
+ from transformers.modeling_outputs import (
36
+ BaseModelOutputWithPast,
37
+ CausalLMOutputWithPast,
38
+ QuestionAnsweringModelOutput,
39
+ SequenceClassifierOutputWithPast,
40
+ TokenClassifierOutput,
41
+ )
42
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
43
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
44
+ from transformers.processing_utils import Unpack
45
+ from transformers.utils import LossKwargs, auto_docstring, can_return_tuple, logging
46
+
47
+ from .configuration_telechat3 import Telechat3Config
48
+
49
+ logger = logging.get_logger(__name__)
50
+
51
+
52
+ # Compute the inverse frequencies
53
def find_correction_dim(num_rotations, dim, base, max_position_embeddings):
    """Inverse dimension formula to find the dimension based on the number of rotations"""
    rotations_span = num_rotations * 2 * math.pi
    return dim * math.log(max_position_embeddings / rotations_span) / (2.0 * math.log(base))
56
+
57
+
58
def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings):
    """Find dimension range bounds based on rotations"""

    def _dim_for(num_rotations):
        # Inline equivalent of find_correction_dim (inverse dimension formula).
        return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))

    lower_bound = math.floor(_dim_for(low_rot))
    upper_bound = math.ceil(_dim_for(high_rot))
    # Clamp to the valid dimension index range [0, dim - 1].
    return max(lower_bound, 0), min(upper_bound, dim - 1)
63
+
64
+
65
def linear_ramp_factor(min, max, dim):
    """Ramp rising linearly from 0 at `min` to 1 at `max` across `dim` positions, clamped to [0, 1].

    NOTE(review): `min`/`max` shadow the builtins; names kept for interface compatibility.
    """
    if min == max:
        max += 0.001  # Prevent singularity

    positions = torch.arange(dim, dtype=torch.float32)
    return torch.clamp((positions - min) / (max - min), 0, 1)
72
+
73
+
74
def _compute_telechat_yarn_parameters(
    config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs
) -> tuple["torch.Tensor", float]:
    """
    Computes the inverse frequencies with NTK scaling. Please refer to the
    [original paper](https://huggingface.co/papers/2309.00071)
    Args:
        config ([`~transformers.PretrainedConfig`]):
            The model configuration.
        device (`torch.device`):
            The device to use for initialization of the inverse frequencies.
        seq_len (`int`, *optional*):
            The current sequence length. Unused for this type of RoPE.
        rope_kwargs (`Dict`, *optional*):
            BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
    Returns:
        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
        post-processing scaling factor applied to the computed cos/sin.
    """
    # No need to keep BC with yarn, unreleased when this new pattern was created.
    if len(rope_kwargs) > 0:
        raise ValueError(
            f"Unexpected arguments: `**rope_kwargs` should be unset in `_compute_yarn_parameters`, got {rope_kwargs}"
        )

    base = config.rope_theta
    # Only a fraction of the head dim may be rotary; defaults to the full head dim (factor 1.0).
    partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
    head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
    dim = int(head_dim * partial_rotary_factor)
    factor = config.rope_scaling["factor"]
    attention_factor = config.rope_scaling.get("attention_factor")
    mscale = config.rope_scaling.get("mscale")
    mscale_all_dim = config.rope_scaling.get("mscale_all_dim")

    # NOTE: DeepSeek-V3 (and potentially other models) modify `max_position_embeddings` and have a
    # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
    # values to compute the default attention scaling factor, instead of using `factor`.
    if "original_max_position_embeddings" in config.rope_scaling:
        original_max_position_embeddings = config.rope_scaling["original_max_position_embeddings"]
        factor = config.max_position_embeddings / original_max_position_embeddings
    else:
        original_max_position_embeddings = config.max_position_embeddings

    def get_mscale(scale, mscale=1):
        # Attention magnitude correction term; identity when no scaling is applied (scale <= 1).
        if scale <= 1:
            return 1.0
        return 0.07 * mscale * math.log(scale) + 1.0

    # Sets the attention factor as suggested in the paper
    if attention_factor is None:
        if mscale and mscale_all_dim:
            attention_factor = float(get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dim))
        else:
            attention_factor = get_mscale(factor)

    # Optional config options
    # beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly)
    beta_fast = config.rope_scaling.get("beta_fast") or 32
    beta_slow = config.rope_scaling.get("beta_slow") or 1

    # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs
    # to expand the possible context length. In other words, interpolation = apply scaling factor.
    pos_freqs = base ** (torch.arange(0, dim, 2).to(device=device, dtype=torch.float) / dim)
    inv_freq_extrapolation = 1.0 / pos_freqs
    inv_freq_interpolation = 1.0 / (factor * pos_freqs)

    # Dimension indices below `low` keep the original frequencies (extrapolation);
    # indices above `high` use the scaled frequencies (interpolation).
    low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_max_position_embeddings)

    # Get n-dimensional rotational scaling corrected for extrapolation
    inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).to(device=device, dtype=torch.float)
    # Blend the two frequency sets per-dimension according to the ramp.
    inv_freq = (
        inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
        + inv_freq_extrapolation * inv_freq_extrapolation_factor
    )
    return inv_freq, attention_factor
149
+
150
+
151
# Register the custom YaRN variant so rotary embeddings can resolve rope_type 'telechat3-yarn'.
ROPE_INIT_FUNCTIONS['telechat3-yarn'] = _compute_telechat_yarn_parameters
152
+
153
+
154
+ @use_kernel_forward_from_hub("RMSNorm")
155
class Telechat3RMSNorm(nn.Module):
    """Root-mean-square layer norm (equivalent to T5LayerNorm): x / rms(x) scaled by a learned weight."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # Normalize in float32 for numerical stability, then cast back.
        original_dtype = hidden_states.dtype
        states_fp32 = hidden_states.float()
        mean_square = states_fp32.pow(2).mean(dim=-1, keepdim=True)
        normalized = states_fp32 * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(original_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
173
+
174
+
175
class Telechat3RotaryEmbedding(nn.Module):
    """Computes rotary position embedding (RoPE) cos/sin tables shared by all decoder layers."""

    def __init__(self, config: Telechat3Config, device=None):
        super().__init__()
        # BC: "rope_type" was originally "type"
        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        # Resolves e.g. the custom 'telechat3-yarn' initializer registered in this module.
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        # Not persisted in the state dict; recomputed from config on load.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Kept so dynamic RoPE variants (see `dynamic_rope_update`) can restore the original
        # frequencies — presumably; verify against the transformers decorator's contract.
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        # `x` is only used for its device and dtype.
        # inv_freq: (dim/2,) -> (batch, dim/2, 1); position_ids: (batch, seq) -> (batch, 1, seq)
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            # Outer product of positions and frequencies: (batch, seq, dim/2).
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            # Duplicate along the last dim so cos/sin cover the full head dimension.
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
207
+
208
+
209
def rotate_half(x):
    """Rotates half the hidden dims of the input: (x1, x2) -> (-x2, x1) on the last axis."""
    half = x.shape[-1] // 2
    front, back = x[..., :half], x[..., half:]
    return torch.cat((-back, front), dim=-1)
214
+
215
+
216
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*): Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension along which `cos`/`sin` are unsqueezed so they broadcast over the heads
            dimension of `q`/`k`: use 1 for a (batch, heads, seq, head_dim) layout, 2 for
            (batch, seq, heads, head_dim).

    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(t):
        # Inline equivalent of rotate_half: (x1, x2) -> (-x2, x1) on the last dim.
        half = t.shape[-1] // 2
        return torch.cat((-t[..., half:], t[..., :half]), dim=-1)

    q_embed = (q * cos_b) + (_rotate(q) * sin_b)
    k_embed = (k * cos_b) + (_rotate(k) * sin_b)
    return q_embed, k_embed
241
+
242
+
243
class Telechat3MLP(nn.Module):
    """Gated feed-forward block: down_proj(act(gate_proj(x)) * up_proj(x))."""

    def __init__(self, config):
        super().__init__()
        hidden = config.hidden_size
        inner = config.intermediate_size
        use_bias = config.mlp_bias
        self.config = config
        self.hidden_size = hidden
        self.intermediate_size = inner
        self.gate_proj = nn.Linear(hidden, inner, bias=use_bias)
        self.up_proj = nn.Linear(hidden, inner, bias=use_bias)
        self.down_proj = nn.Linear(inner, hidden, bias=use_bias)
        # Activation looked up by name (e.g. "silu") from the transformers activation registry.
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        gated = self.act_fn(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
257
+
258
+
259
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states  # nothing to repeat; return the input untouched
    expanded = hidden_states.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, head_dim)
    return expanded.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
269
+
270
+
271
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """Reference (non-fused) attention: softmax(Q·K^T * scaling + mask) @ V.

    Returns the attention output as (batch, seq, heads, head_dim) plus the attention weights.
    """

    def _expand_kv(states):
        # Inline equivalent of repeat_kv: duplicate each KV head `num_key_value_groups` times
        # so grouped-query K/V line up with the query heads.
        b, kv_heads, s, d = states.shape
        n_rep = module.num_key_value_groups
        if n_rep == 1:
            return states
        return states[:, :, None, :, :].expand(b, kv_heads, n_rep, s, d).reshape(b, kv_heads * n_rep, s, d)

    key_states = _expand_kv(key)
    value_states = _expand_kv(value)

    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        # Crop the mask to the current key length (handles cached prefixes).
        attn_weights = attn_weights + attention_mask[:, :, :, : key_states.shape[-2]]

    # Softmax in float32 for stability, then cast back to the query dtype.
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value_states).transpose(1, 2).contiguous()

    return attn_output, attn_weights
295
+
296
+
297
class Telechat3Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: Telechat3Config, layer_idx: int):
        super().__init__()
        self.config = config
        # Index of this layer; used by the KV cache to store/retrieve this layer's states.
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        # Grouped-query attention: number of query heads sharing each key/value head.
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim ** -0.5  # 1/sqrt(head_dim) attention temperature
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
        """Returns (attn_output, attn_weights); weights may be None depending on the backend."""
        # (batch, seq) prefix; the heads dimension is inferred with -1 below.
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        # Project and reshape to (batch, num_heads, seq, head_dim).
        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Dispatch to the configured attention backend; eager is the reference implementation.
        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        # Merge heads back to (batch, seq, num_heads * head_dim) and apply the output projection.
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
365
+
366
+
367
class Telechat3DecoderLayer(GradientCheckpointingLayer):
    """One decoder block: pre-norm self-attention and pre-norm MLP, each with a residual connection."""

    def __init__(self, config: Telechat3Config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = Telechat3Attention(config=config, layer_idx=layer_idx)

        self.mlp = Telechat3MLP(config)
        self.input_layernorm = Telechat3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Telechat3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """Returns (hidden_states,) plus the self-attention weights when `output_attentions` is set."""
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)  # pre-norm before attention

        # Self Attention
        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)  # pre-norm before MLP
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights,)

        return outputs
418
+
419
+
420
@auto_docstring
class Telechat3PreTrainedModel(PreTrainedModel):
    """Base class wiring Telechat3 into `transformers`: weight init, sharding hints, backend/cache support flags."""

    config_class = Telechat3Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    # Never split a decoder layer across devices when sharding the model.
    _no_split_modules = ["Telechat3DecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    # Supported attention backends.
    _supports_flash_attn_3 = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True
    # Supported cache implementations.
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True
    _supports_attention_backend = True

    def _init_weights(self, module):
        """Initialize weights: normal(0, initializer_range) for Linear/Embedding, ones for RMSNorm."""
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                # Keep the padding embedding at exactly zero.
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, Telechat3RMSNorm):
            module.weight.data.fill_(1.0)
448
+
449
+
450
@auto_docstring
class Telechat3Model(Telechat3PreTrainedModel):
    """Bare Telechat3 decoder stack: token embeddings -> decoder layers -> final RMSNorm (no LM head)."""

    def __init__(self, config: Telechat3Config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [Telechat3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = Telechat3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # One rotary embedding shared by all layers; cos/sin are computed once per forward pass.
        self.rotary_emb = Telechat3RotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> BaseModelOutputWithPast:
        # Fall back to config defaults when the flags are not given explicitly.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        # Exactly one of input_ids / inputs_embeds must be provided.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
        if not isinstance(past_key_values, (type(None), Cache)):
            raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        if cache_position is None:
            # Absolute positions of the new tokens, offset by what is already cached.
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # Build the (backend-appropriate) causal mask once and reuse it for every layer.
        causal_mask = create_causal_mask(
            config=self.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **flash_attn_kwargs,
            )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
574
+
575
+
576
# Combined kwargs type accepted by the CausalLM forward: flash-attention args plus loss-related args.
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
577
+
578
+
579
@auto_docstring
class Telechat3ForCausalLM(Telechat3PreTrainedModel, GenerationMixin):
    """Telechat3 decoder topped with a language-modeling head, usable for causal-LM training and generation."""

    # The LM head may share its weight matrix with the input embeddings (weight tying).
    _tied_weights_keys = ["lm_head.weight"]
    # Tensor-parallel / pipeline-parallel plans for the head.
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        super().__init__(config)
        self.model = Telechat3Model(config)
        self.vocab_size = config.vocab_size
        # Projects decoder hidden states to vocabulary logits; bias-free.
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        # Input embeddings live on the inner decoder.
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[KwargsForCausalLM],
    ) -> CausalLMOutputWithPast:
        """Run the decoder and project its hidden states to vocabulary logits.

        `logits_to_keep` restricts the head projection: an int keeps only the last
        N positions (0 keeps all), a tensor selects explicit indices — saving memory
        during generation. When `labels` is given, the causal-LM loss is computed
        via `self.loss_function` (provided by the base class).
        """

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
665
+
666
+
667
@auto_docstring(
    custom_intro="""
    The Telechat3 Model transformer with a sequence classification head on top (linear layer).

    [`Telechat3ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-2) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """
)
class Telechat3ForSequenceClassification(Telechat3PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = Telechat3Model(config)
        # Classification head applied per token; only one token's logits are pooled below.
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> SequenceClassifierOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """

        transformer_outputs: BaseModelOutputWithPast = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        hidden_states = transformer_outputs.last_hidden_state
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        # Without a pad token we cannot locate the last real token per row, so only batch size 1 is supported.
        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            last_non_pad_token = -1
        elif input_ids is not None:
            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
            last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
        else:
            last_non_pad_token = -1
            logger.warning_once(
                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
            )

        # Pool one logit vector per sequence: the logits at the last non-padding position.
        pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
765
+
766
+
767
@auto_docstring
class Telechat3ForQuestionAnswering(Telechat3PreTrainedModel):
    """Telechat3 decoder with a span-classification head producing start/end logits for extractive QA."""

    base_model_prefix = "transformer"

    # Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Telechat3
    def __init__(self, config):
        super().__init__(config)
        self.transformer = Telechat3Model(config)
        # Two outputs per token: one start logit and one end logit.
        self.qa_outputs = nn.Linear(config.hidden_size, 2)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.transformer.embed_tokens

    def set_input_embeddings(self, value):
        self.transformer.embed_tokens = value

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        **kwargs,
    ) -> QuestionAnsweringModelOutput:
        """Encode the input, then split the per-token head output into start/end span logits.

        A loss is computed only when both `start_positions` and `end_positions` are provided.
        """
        encoder_out: BaseModelOutputWithPast = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        # (batch, seq, 2) -> two contiguous (batch, seq) tensors.
        span_logits = self.qa_outputs(encoder_out.last_hidden_state)
        start_logits, end_logits = (
            half.squeeze(-1).contiguous() for half in span_logits.split(1, dim=-1)
        )

        total_loss = None
        if start_positions is not None and end_positions is not None:
            total_loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=encoder_out.hidden_states,
            attentions=encoder_out.attentions,
        )
829
+
830
+
831
@auto_docstring
class Telechat3ForTokenClassification(Telechat3PreTrainedModel):
    """Telechat3 decoder with a dropout + linear head producing one label distribution per token."""

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = Telechat3Model(config)
        # Dropout rate fallback chain: explicit classifier_dropout, then the
        # model's hidden_dropout, then a fixed 0.1 default.
        if getattr(config, "classifier_dropout", None) is not None:
            classifier_dropout = config.classifier_dropout
        elif getattr(config, "hidden_dropout", None) is not None:
            classifier_dropout = config.hidden_dropout
        else:
            classifier_dropout = 0.1
        self.dropout = nn.Dropout(classifier_dropout)
        self.score = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> TokenClassifierOutput:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """

        outputs: BaseModelOutputWithPast = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        sequence_output = outputs.last_hidden_state
        sequence_output = self.dropout(sequence_output)
        logits = self.score(sequence_output)

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.config)

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
tokenization_telechat3.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from shutil import copyfile
3
+ from typing import Any, Dict, List, Optional, Tuple
4
+ import sentencepiece as spm
5
+ from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
6
+ from transformers.utils import logging
7
+
8
logger = logging.get_logger(__name__)

# File name of the SentencePiece model expected next to the tokenizer config.
VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}

# TODO: when we get download url from huggingface, refresh the map
PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {},
    "tokenizer_file": {},
}
17
+
18
+
19
class Telechat3Tokenizer(PreTrainedTokenizer):
    """SentencePiece-based tokenizer for Telechat3.

    Wraps a ``sentencepiece`` model file and optionally prepends ``bos_token`` /
    appends ``eos_token`` when building model inputs (controlled by
    ``add_bos_token`` / ``add_eos_token``).
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file,
        unk_token="<unk>",
        bos_token="<_start>",
        eos_token="<_end>",
        pad_token="<_pad>",
        sp_model_kwargs: Optional[Dict[str, Any]] = None,
        add_bos_token=True,
        add_eos_token=False,
        clean_up_tokenization_spaces=False,
        **kwargs,
    ):
        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
        # Wrap plain-string special tokens so they are treated atomically.
        unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
        bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
        eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
        pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
        # Load the SentencePiece model before super().__init__, which may tokenize
        # during special-token setup.
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.Load(vocab_file)
        super().__init__(
            # BUGFIX: `unk_token` was previously accepted but never forwarded to
            # the base class, so it was silently ignored.
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
            pad_token=pad_token,
            add_bos_token=add_bos_token,
            add_eos_token=add_eos_token,
            sp_model_kwargs=self.sp_model_kwargs,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **kwargs,
        )
        self.vocab_file = vocab_file
        self.add_bos_token = add_bos_token
        self.add_eos_token = add_eos_token

    def __getstate__(self):
        # The SentencePieceProcessor is not picklable; drop it and rebuild in __setstate__.
        state = self.__dict__.copy()
        state["sp_model"] = None
        return state

    def __setstate__(self, d):
        self.__dict__ = d
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.Load(self.vocab_file)

    @property
    def vocab_size(self):
        """Returns vocab size"""
        return self.sp_model.get_piece_size()

    def get_vocab(self):
        """Returns vocab as a dict"""
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    @property
    def vocab(self):
        return self.get_vocab()

    def _tokenize(self, text):
        """Returns a tokenized string."""
        return self.sp_model.encode(text, out_type=str)

    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        return self.sp_model.piece_to_id(token)

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        token = self.sp_model.IdToPiece(index)
        return token

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) in a single string."""
        current_sub_tokens = []
        out_string = ""
        for token in tokens:
            # make sure that special tokens are not decoded using sentencepiece model
            if token in self.all_special_tokens:
                out_string += self.sp_model.decode(current_sub_tokens) + token
                current_sub_tokens = []
            else:
                current_sub_tokens.append(token)
        out_string += self.sp_model.decode(current_sub_tokens)
        return out_string

    def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """
        Save the vocabulary and special tokens file to a directory.

        Args:
            save_directory (`str`):
                The directory in which to save the vocabulary.

        Returns:
            `Tuple(str)`: Paths to the files saved (empty if `save_directory` is invalid).
        """
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            # BUGFIX: return an empty tuple (matching the annotated return type)
            # instead of None, so callers that iterate or concatenate the result
            # do not crash.
            return ()
        out_vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )

        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
            copyfile(self.vocab_file, out_vocab_file)
        elif not os.path.isfile(self.vocab_file):
            # Original file is gone; re-serialize the in-memory model instead.
            with open(out_vocab_file, "wb") as fi:
                content_spiece_model = self.sp_model.serialized_model_proto()
                fi.write(content_spiece_model)

        return (out_vocab_file,)

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """Add BOS/EOS around each sequence according to add_bos_token/add_eos_token."""
        bos_token_id = [self.bos_token_id] if self.add_bos_token else []
        eos_token_id = [self.eos_token_id] if self.add_eos_token else []

        output = bos_token_id + token_ids_0 + eos_token_id

        if token_ids_1 is not None:
            output = output + bos_token_id + token_ids_1 + eos_token_id

        return output

    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` method.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )

        bos_token_id = [1] if self.add_bos_token else []
        eos_token_id = [1] if self.add_eos_token else []

        if token_ids_1 is None:
            return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
        return (
            bos_token_id
            + ([0] * len(token_ids_0))
            + eos_token_id
            + bos_token_id
            + ([0] * len(token_ids_1))
            + eos_token_id
        )

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT
        sequence pair mask has the following format:

        ```
        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
        | first sequence | second sequence |
        ```

        if token_ids_1 is None, only returns the first portion of the mask (0s).

        Args:
            token_ids_0 (`List[int]`):
                List of ids.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
        """
        bos_token_id = [self.bos_token_id] if self.add_bos_token else []
        eos_token_id = [self.eos_token_id] if self.add_eos_token else []

        output = [0] * len(bos_token_id + token_ids_0 + eos_token_id)

        if token_ids_1 is not None:
            output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)

        return output