xiaoyewuz-Ruster commited on
Commit
cb9b291
·
verified ·
1 Parent(s): 10c27bb

Upload TextGenerationPipeline

Browse files
Files changed (1) hide show
  1. zzjrabbit3.py +299 -0
zzjrabbit3.py ADDED
@@ -0,0 +1,299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from tokenizers import Tokenizer, decoders, pre_tokenizers
6
+ from tokenizers.models import BPE
7
+ from transformers import (
8
+ GenerationMixin,
9
+ PreTrainedConfig,
10
+ PreTrainedModel,
11
+ TokenizersBackend,
12
+ )
13
+ from transformers.modeling_outputs import BaseModelOutput, CausalLMOutput
14
+
15
+
16
+ class ZZJRabbit3Config(PreTrainedConfig):
17
+ model_type = "zzjrabbit3"
18
+
19
+ def __init__(
20
+ self,
21
+ vocab_size: int = 100000,
22
+ hidden_size: int = 1024,
23
+ num_hidden_layers: int = 12,
24
+ num_attention_heads: int = 8,
25
+ attention_dropout: float | int = 0.0,
26
+ pad_token_id: int | None = None,
27
+ eos_token_id: int | list[int] | None = None,
28
+ **kwargs,
29
+ ):
30
+ self.vocab_size = vocab_size
31
+ self.hidden_size = hidden_size
32
+ self.num_hidden_layers = num_hidden_layers
33
+ self.num_attention_heads = num_attention_heads
34
+ self.attention_dropout = attention_dropout
35
+ self.pad_token_id = pad_token_id
36
+ self.eos_token_id = eos_token_id
37
+ super().__init__(**kwargs)
38
+
39
+
40
+ class ZZJRabbit3RotaryEmbedding(nn.Module):
41
+ def __init__(self, dim, max_position_embeddings=2048, base=10000):
42
+ """
43
+ Rotary Embedding 模块
44
+
45
+ Args:
46
+ dim: 每个 token embedding 的维度
47
+ max_position_embeddings: 最大位置数
48
+ base: rotary embedding 的频率基底
49
+ """
50
+ super().__init__()
51
+ self.dim = dim
52
+ self.base = base
53
+
54
+ # 生成频率向量
55
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
56
+ self.register_buffer("inv_freq", inv_freq)
57
+
58
+ # 可选:预先计算 cos/sin
59
+ t = torch.arange(max_position_embeddings, dtype=torch.float32)
60
+ freqs = torch.einsum("i,j->ij", t, inv_freq)
61
+ self.register_buffer("cos_cached", freqs.cos())
62
+ self.register_buffer("sin_cached", freqs.sin())
63
+
64
+ def forward(self, position_ids):
65
+ """
66
+ position_ids: (batch_size, seq_len)
67
+ 返回:
68
+ cos: (batch_size, seq_len, dim)
69
+ sin: (batch_size, seq_len, dim)
70
+ """
71
+ # 从缓存中选取对应位置
72
+ cos = self.cos_cached[position_ids] # shape (batch, seq_len, dim/2)
73
+ sin = self.sin_cached[position_ids]
74
+
75
+ # 将维度对齐为 (dim)
76
+ # cos/sin 当前 shape 为 (..., dim/2),重复到 dim
77
+ cos = torch.stack([cos, cos], dim=-1).flatten(-2)
78
+ sin = torch.stack([sin, sin], dim=-1).flatten(-2)
79
+ return cos, sin
80
+
81
+
82
+ def rotate_half(x):
83
+ """[-x2, x1]"""
84
+ x1 = x[..., : x.shape[-1] // 2]
85
+ x2 = x[..., x.shape[-1] // 2 :]
86
+ return torch.cat((-x2, x1), dim=-1)
87
+
88
+
89
+ def apply_rotary_pos_emb(q, k, sin, cos):
90
+ cos = cos.unsqueeze(1)
91
+ sin = sin.unsqueeze(1)
92
+
93
+ q_embed = (q * cos) + (rotate_half(q) * sin)
94
+ k_embed = (k * cos) + (rotate_half(k) * sin)
95
+
96
+ return q_embed, k_embed
97
+
98
+
99
+ class ZZJRabbit3Attention(nn.Module):
100
+ def __init__(self, config: ZZJRabbit3Config):
101
+ super().__init__()
102
+ self.config = config
103
+ self.head_dim = config.hidden_size // config.num_attention_heads
104
+ self.q_proj = nn.Linear(config.hidden_size, config.hidden_size)
105
+ self.k_proj = nn.Linear(config.hidden_size, config.hidden_size)
106
+ self.v_proj = nn.Linear(config.hidden_size, config.hidden_size)
107
+ self.out_proj = nn.Linear(config.hidden_size, config.hidden_size)
108
+ self.dropout = nn.Dropout(0.1)
109
+
110
+ def forward(
111
+ self,
112
+ x: torch.Tensor,
113
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
114
+ key_padding_mask: Optional[torch.BoolTensor] = None,
115
+ attn_mask: Optional[torch.BoolTensor] = None,
116
+ ) -> torch.Tensor:
117
+ batch_size = x.size(0)
118
+ Q = (
119
+ self.q_proj(x)
120
+ .view(batch_size, -1, self.config.num_attention_heads, self.head_dim)
121
+ .transpose(1, 2)
122
+ )
123
+ K = (
124
+ self.k_proj(x)
125
+ .view(batch_size, -1, self.config.num_attention_heads, self.head_dim)
126
+ .transpose(1, 2)
127
+ )
128
+ V = (
129
+ self.v_proj(x)
130
+ .view(batch_size, -1, self.config.num_attention_heads, self.head_dim)
131
+ .transpose(1, 2)
132
+ )
133
+ cos, sin = position_embeddings
134
+ Q, K = apply_rotary_pos_emb(Q, K, sin.to(Q.dtype), cos.to(Q.dtype))
135
+ scores = torch.matmul(Q, K.transpose(-2, -1)) * (self.head_dim**-0.5)
136
+ if key_padding_mask is not None:
137
+ scores = scores.masked_fill(
138
+ key_padding_mask.view(batch_size, 1, 1, -1), float("-inf")
139
+ )
140
+ if attn_mask is not None:
141
+ scores = scores.masked_fill(attn_mask, float("-inf"))
142
+ attn_weights = nn.functional.softmax(scores, dim=-1)
143
+ attn_weights = self.dropout(attn_weights)
144
+ context = torch.matmul(attn_weights, V)
145
+ context = context.transpose(1, 2).contiguous()
146
+ context = context.view(batch_size, -1, self.config.hidden_size)
147
+ return self.out_proj(context)
148
+
149
+
150
+ class ZZJRabbit3Layer(nn.Module):
151
+ def __init__(self, config: ZZJRabbit3Config):
152
+ super().__init__()
153
+ self.attn = ZZJRabbit3Attention(config)
154
+ self.l1 = nn.Linear(config.hidden_size, config.hidden_size)
155
+ self.l2 = nn.Linear(config.hidden_size, config.hidden_size)
156
+ self.activate = nn.ReLU()
157
+ self.norm = nn.RMSNorm(config.hidden_size)
158
+
159
+ def forward(
160
+ self,
161
+ x: torch.Tensor,
162
+ postition_embeddings: tuple[torch.Tensor, torch.Tensor],
163
+ attention_mask: Optional[torch.Tensor] = None,
164
+ ) -> torch.Tensor:
165
+ key_padding_mask = None
166
+ attn_mask = torch.gt(
167
+ torch.triu(torch.ones(x.size(-2), x.size(-2), device=x.device), 1), 0
168
+ )
169
+ if attention_mask is not None:
170
+ key_padding_mask = torch.lt(attention_mask, 1)
171
+ attn = self.attn(
172
+ x,
173
+ postition_embeddings,
174
+ key_padding_mask=key_padding_mask,
175
+ attn_mask=attn_mask,
176
+ )
177
+ x = self.norm(x + attn)
178
+ o = self.l1(x)
179
+ o = self.activate(o)
180
+ o = self.l2(o)
181
+ return self.norm(x + o)
182
+
183
+
184
+ class ZZJRabbit3Model(PreTrainedModel):
185
+ config_class = ZZJRabbit3Config
186
+
187
+ def __init__(self, config: ZZJRabbit3Config, **kwargs):
188
+ super().__init__(config, **kwargs)
189
+ self.config = config
190
+ self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
191
+ self.rotary_emb = ZZJRabbit3RotaryEmbedding(
192
+ config.hidden_size // config.num_attention_heads
193
+ )
194
+ self.layers = nn.ModuleList(
195
+ [ZZJRabbit3Layer(config) for _ in range(config.num_hidden_layers)]
196
+ )
197
+ self.post_init()
198
+
199
+ def forward(
200
+ self,
201
+ input_ids: torch.Tensor,
202
+ return_dict: Optional[bool] = None,
203
+ attention_mask: Optional[torch.Tensor] = None,
204
+ **kwargs,
205
+ ) -> tuple | BaseModelOutput:
206
+ res = self.embedding(input_ids)
207
+ batch_size, seq_len = input_ids.shape
208
+ position_ids = (
209
+ torch.arange(seq_len, device=input_ids.device)
210
+ .unsqueeze(0)
211
+ .expand(batch_size, -1)
212
+ )
213
+ position_embeddings = self.rotary_emb(position_ids)
214
+ for layer in self.layers:
215
+ res = layer(res, position_embeddings, attention_mask)
216
+ if not return_dict:
217
+ return (res,)
218
+ else:
219
+ return BaseModelOutput(res)
220
+
221
+
222
+ class ZZJRabbit3ForCausalLM(PreTrainedModel, GenerationMixin):
223
+ config_class = ZZJRabbit3Config
224
+
225
+ def __init__(self, config: ZZJRabbit3Config, **kwargs):
226
+ super().__init__(config, **kwargs)
227
+ self.model = ZZJRabbit3Model(config, **kwargs)
228
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
229
+ self.post_init()
230
+
231
+ def forward(
232
+ self,
233
+ input_ids: torch.Tensor,
234
+ return_dict: Optional[bool] = None,
235
+ labels: Optional[torch.Tensor] = None,
236
+ attention_mask: Optional[torch.Tensor] = None,
237
+ logits_to_keep: Union[int, torch.Tensor] = 0,
238
+ **kwargs,
239
+ ) -> tuple | CausalLMOutput:
240
+ hidden = self.model(input_ids=input_ids, attention_mask=attention_mask)[0]
241
+ logits = self.lm_head(
242
+ hidden[
243
+ :,
244
+ slice(-logits_to_keep, None)
245
+ if isinstance(logits_to_keep, int)
246
+ else logits_to_keep,
247
+ :,
248
+ ]
249
+ )
250
+ if labels is not None:
251
+ loss = self.loss_function(
252
+ logits=logits,
253
+ labels=labels,
254
+ vocab_size=self.config.vocab_size,
255
+ **kwargs,
256
+ )
257
+ if not return_dict:
258
+ return (loss, logits) if labels is not None else (logits,)
259
+ else:
260
+ return (
261
+ CausalLMOutput(logits=logits, loss=loss)
262
+ if labels is not None
263
+ else CausalLMOutput(logits=logits)
264
+ )
265
+
266
+ @classmethod
267
+ def can_generate(cls):
268
+ return True
269
+
270
+ def prepare_inputs_for_generation(self, input_ids, **kwargs):
271
+ return {"input_ids": input_ids}
272
+
273
+
274
+ class ZZJRabbit3Tokenizer(TokenizersBackend):
275
+ model = BPE
276
+
277
+ def __init__(
278
+ self,
279
+ vocab=None,
280
+ merges=None,
281
+ unk_token="<eos>",
282
+ eos_token="<eos>",
283
+ pad_token="<eos>",
284
+ **kwargs,
285
+ ):
286
+ self._vocab = vocab or {
287
+ "<eos>": 0,
288
+ }
289
+ self._merges = merges or []
290
+
291
+ self._tokenizer = Tokenizer(
292
+ BPE(vocab=self._vocab, merges=self._merges, fuse_unk=True)
293
+ )
294
+ self._tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
295
+ self._tokenizer.decoder = decoders.ByteLevel()
296
+
297
+ super().__init__(
298
+ unk_token=unk_token, eos_token=eos_token, pad_token=pad_token, **kwargs
299
+ )