ishanjmukherjee committed
Commit bfda7e6 · 1 Parent(s): 014c0fa

Python files for modeling

Files changed (3):
  1. configuration_glm2.py +40 -0
  2. glm_tokenizer.py +53 -0
  3. modeling_glm2.py +466 -0
configuration_glm2.py ADDED
@@ -0,0 +1,40 @@
+ """gLM2 model configuration
+
+ Copied straight from https://huggingface.co/tattabio/gLM2_650M/blob/main/configuration_glm2.py
+ """
+
+ from typing import Optional
+ from transformers import PretrainedConfig
+ from transformers.utils import logging
+
+ logger = logging.get_logger(__name__)
+
+
+ class gLM2Config(PretrainedConfig):
+     model_type = "gLM2"
+
+     def __init__(
+         self,
+         dim: int = 640,
+         depth: int = 30,
+         heads: int = 10,
+         vocab_size: int = 37,
+         swiglu_multiple_of: int = 256,
+         ffn_dim_multiplier: Optional[float] = None,
+         norm_eps: float = 1e-5,
+         **kwargs
+     ):
+         super().__init__(**kwargs)
+         self.dim = dim
+         self.depth = depth
+         self.heads = heads
+         self.vocab_size = vocab_size
+         self.swiglu_multiple_of = swiglu_multiple_of
+         self.ffn_dim_multiplier = ffn_dim_multiplier
+         self.norm_eps = norm_eps
+
+         self.auto_map = {
+             "AutoConfig": "configuration_glm2.gLM2Config",
+             "AutoModel": "modeling_glm2.gLM2Model",
+             "AutoModelForMaskedLM": "modeling_glm2.gLM2ForMaskedLM"
+         }
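
A quick illustrative sketch (not part of the committed file): instantiating the config with its defaults and reading the auto_map entries that let the transformers auto classes resolve these custom modules when loading with trust_remote_code=True.

from configuration_glm2 import gLM2Config

config = gLM2Config()  # dim=640, depth=30, heads=10, vocab_size=37 by default
print(config.model_type)                         # "gLM2"
print(config.auto_map["AutoModelForMaskedLM"])   # "modeling_glm2.gLM2ForMaskedLM"

# Extra keyword arguments are forwarded to PretrainedConfig via **kwargs.
tiny = gLM2Config(dim=256, depth=12, heads=8)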
glm_tokenizer.py ADDED
@@ -0,0 +1,53 @@
+ """gLM tokenizer
+
+ Copied straight from https://huggingface.co/tattabio/gLM2_650M/blob/main/glm_tokenizer.py
+ """
+
+ from tokenizers import Tokenizer
+ from tokenizers.models import BPE
+ from transformers import PreTrainedTokenizerFast
+
+
+ class gLM2Tokenizer(PreTrainedTokenizerFast):
+
+     VOCAB = [
+         "<cls>", "<pad>", "<eos>", "<unk>",
+         "L", "A", "G", "V", "S", "E", "R", "T", "I", "D", "P", "K",
+         "Q", "N", "F", "Y", "M", "H", "W", "C", "X", "B", "U", "Z",
+         "O", "a", "t", "c", "g", "<+>", "<->", "<mask>", "<sep>",
+     ]
+
+     def __init__(
+         self,
+         unk_token="<unk>",
+         cls_token="<cls>",
+         pad_token="<pad>",
+         mask_token="<mask>",
+         eos_token="<eos>",
+         sep_token="<sep>",
+         pos_token="<+>",
+         neg_token="<->",
+         **kwargs,
+     ):
+         all_tokens = self.VOCAB
+         token_to_id = {tok: ind for ind, tok in enumerate(all_tokens)}
+
+         bpe = BPE(token_to_id, merges=[], unk_token=str(unk_token))
+         tokenizer = Tokenizer(bpe)
+         special_tokens = [cls_token, pad_token,
+                           mask_token, eos_token, sep_token, pos_token, neg_token]
+
+         tokenizer.add_special_tokens(
+             special_tokens,
+         )
+
+         super().__init__(
+             tokenizer_object=tokenizer,
+             unk_token=unk_token,
+             cls_token=cls_token,
+             pad_token=pad_token,
+             mask_token=mask_token,
+             eos_token=eos_token,
+             sep_token=sep_token,
+             **kwargs,
+         )
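
An illustrative usage sketch for the tokenizer (not part of the committed file). gLM2 inputs mix upper-case amino acids, lower-case nucleotides, and the <+>/<-> strand tokens, all of which are single entries in VOCAB; because the BPE model has no merges, each residue should map to a single-character token.

from glm_tokenizer import gLM2Tokenizer

tokenizer = gLM2Tokenizer()
ids = tokenizer("<+>MKTAYIAK<sep><->atgcat")["input_ids"]
print(ids)
print(tokenizer.convert_ids_to_tokens(ids))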
modeling_glm2.py ADDED
@@ -0,0 +1,466 @@
+ """PyTorch gLM2 model.
+
+ Copied straight from https://huggingface.co/tattabio/gLM2_650M/blob/main/modeling_glm2.py
+ """
+
+ import torch
+ from einops import rearrange, repeat
+ from typing import Optional, Tuple, Union
+ from torch import nn
+ from torch.nn import CrossEntropyLoss
+ from transformers.modeling_outputs import (
+     BaseModelOutput,
+     MaskedLMOutput,
+ )
+ from transformers.modeling_utils import PreTrainedModel
+ from transformers.utils import logging
+ from .configuration_glm2 import gLM2Config
+
+ logger = logging.get_logger(__name__)
+
+
+ def rotate_half(x, interleaved=False):
+     if not interleaved:
+         x1, x2 = x.chunk(2, dim=-1)
+         return torch.cat((-x2, x1), dim=-1)
+     else:
+         x1, x2 = x[..., ::2], x[..., 1::2]
+         return rearrange(
+             torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
+         )
+
+
+ def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
+     """
+     x: (batch_size, seqlen, nheads, headdim)
+     cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
+     """
+     ro_dim = cos.shape[-1] * 2
+     assert ro_dim <= x.shape[-1]
+     seqlen = x.shape[1]
+     cos, sin = cos[:seqlen], sin[:seqlen]
+     cos = repeat(
+         cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
+     )
+     sin = repeat(
+         sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
+     )
+     return torch.cat(
+         [
+             x[..., :ro_dim] * cos +
+             rotate_half(x[..., :ro_dim], interleaved) * sin,
+             x[..., ro_dim:],
+         ],
+         dim=-1,
+     )
+
+
+ class RotaryEmbedding(torch.nn.Module):
+     """
+     Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/layers/rotary.py.
+     Changed to use the torch version of apply_rotary_emb_func.
+     """
+
+     def __init__(
+         self,
+         dim: int,
+         base=10000.0,
+         interleaved=False,
+         scale_base=None,
+         pos_idx_in_fp32=True,
+         device=None,
+     ):
+         super().__init__()
+         self.dim = dim
+         self.base = float(base)
+         self.pos_idx_in_fp32 = pos_idx_in_fp32
+         # Generate and save the inverse frequency buffer (non trainable)
+         inv_freq = self._compute_inv_freq(device)
+         self.register_buffer("inv_freq", inv_freq, persistent=False)
+         self.interleaved = interleaved
+         self.scale_base = scale_base
+         scale = (
+             (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim)
+             / (1.4 * dim)
+             if scale_base is not None
+             else None
+         )
+         self.register_buffer("scale", scale, persistent=False)
+
+         self._seq_len_cached = 0
+         self._cos_cached = None
+         self._sin_cached = None
+         self._cos_k_cached = None
+         self._sin_k_cached = None
+
+     def _compute_inv_freq(self, device=None):
+         return 1.0 / (
+             self.base
+             ** (
+                 torch.arange(0, self.dim, 2, device=device,
+                              dtype=torch.float32)
+                 / self.dim
+             )
+         )
+
+     def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
+         # Reset the tables if the sequence length has changed,
+         # if we're on a new device (possibly due to tracing for instance),
+         # or if we're switching from inference mode to training
+         if (
+             seqlen > self._seq_len_cached
+             or self._cos_cached is None
+             or self._cos_cached.device != device
+             or self._cos_cached.dtype != dtype
+             or (self.training and self._cos_cached.is_inference())
+         ):
+             self._seq_len_cached = seqlen
+             # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
+             # And the output of arange can be quite large, so bf16 would lose a lot of precision.
+             # However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
+             if self.pos_idx_in_fp32:
+                 t = torch.arange(seqlen, device=device, dtype=torch.float32)
+                 # We want fp32 here as well since inv_freq will be multiplied with t, and the output
+                 # will be large. Having it in bf16 will lose a lot of precision and cause the
+                 # cos & sin output to change significantly.
+                 # We want to recompute self.inv_freq if it was not loaded in fp32
+                 if self.inv_freq.dtype != torch.float32:
+                     inv_freq = self._compute_inv_freq(device=device)
+                 else:
+                     inv_freq = self.inv_freq
+             else:
+                 t = torch.arange(seqlen, device=device,
+                                  dtype=self.inv_freq.dtype)
+                 inv_freq = self.inv_freq
+             # Don't do einsum, it converts fp32 to fp16 under AMP
+             # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+             freqs = torch.outer(t, inv_freq)
+             if self.scale is None:
+                 self._cos_cached = torch.cos(freqs).to(dtype)
+                 self._sin_cached = torch.sin(freqs).to(dtype)
+             else:
+                 power = (
+                     torch.arange(
+                         seqlen, dtype=self.scale.dtype, device=self.scale.device
+                     )
+                     - seqlen // 2
+                 ) / self.scale_base
+                 scale = self.scale.to(device=power.device) ** rearrange(
+                     power, "s -> s 1"
+                 )
+                 # We want the multiplication by scale to happen in fp32
+                 self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
+                 self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
+                 self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
+                 self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
+
+     def forward(
+         self,
+         qkv: torch.Tensor,
+         max_seqlen: Optional[int] = None,
+     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+         """
+         qkv: (batch, seqlen, 3, nheads, headdim)
+         """
+         seqlen = qkv.shape[1]
+         if seqlen > self._seq_len_cached:
+             self._update_cos_sin_cache(
+                 seqlen, device=qkv.device, dtype=qkv.dtype)
+         elif max_seqlen is not None:
+             self._update_cos_sin_cache(
+                 max_seqlen, device=qkv.device, dtype=qkv.dtype)
+         q_rot = apply_rotary_emb_torch(
+             qkv[:, :, 0], self._cos_cached, self._sin_cached, self.interleaved
+         )
+         k_rot = apply_rotary_emb_torch(
+             qkv[:, :, 1], self._cos_cached, self._sin_cached, self.interleaved
+         )
+         return torch.stack((q_rot, k_rot, qkv[:, :, 2]), dim=2)
+
+
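
To make the expected shapes concrete, here is a small check of the rotary module above (illustrative only, not part of the committed file; assumes RotaryEmbedding as defined in this file is in scope). It consumes a packed qkv tensor of shape (batch, seqlen, 3, nheads, headdim) and rotates only the query and key slices.

import torch

rotary = RotaryEmbedding(dim=64)        # dim here is the per-head dimension
qkv = torch.randn(2, 16, 3, 10, 64)     # (batch, seqlen, 3, nheads, headdim)
out = rotary(qkv)
assert out.shape == qkv.shape
# Values (index 2 along the third axis) pass through unchanged.
assert torch.equal(out[:, :, 2], qkv[:, :, 2])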
+ # @torch.jit.script
+ def rmsnorm_func(hidden_states, weight, variance_epsilon):
+     """Apply the root mean square normalization."""
+     input_dtype = hidden_states.dtype
+     hidden_states = hidden_states.to(torch.float32)
+     variance = hidden_states.pow(2).mean(-1, keepdim=True)
+     hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
+     return (weight * hidden_states).to(input_dtype)
+
+
+ class RMSNorm(nn.Module):
+     """Root mean square normalization."""
+
+     def __init__(self, dim, eps=1e-6):
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(dim))
+         self.register_buffer(
+             "variance_epsilon",
+             torch.tensor(eps),
+             persistent=False,
+         )
+
+     def forward(self, hidden_states):
+         return rmsnorm_func(hidden_states, self.weight, self.variance_epsilon)
+
+
+ class Attention(nn.Module):
+     """Multi-head attention module."""
+
+     def __init__(self, config: gLM2Config):
+         super().__init__()
+         self.n_heads = config.heads
+         self.head_dim = config.dim // config.heads
+
+         self.wqkv = nn.Linear(config.dim, self.n_heads *
+                               self.head_dim * 3, bias=False)
+         self.wo = nn.Linear(config.heads * self.head_dim,
+                             config.dim, bias=False)
+
+         self.rotary_emb = RotaryEmbedding(self.head_dim)
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         bsz, seqlen, h_size = x.shape
+         qkv = self.wqkv(x)
+
+         qkv = qkv.view(bsz, seqlen, 3, self.n_heads, self.head_dim)
+         qkv = self.rotary_emb(qkv)
+
+         # (batch, nheads, 3, seqlen, headdim)
+         qkv = torch.transpose(qkv, 3, 1)
+         q = qkv[:, :, 0]
+         k = qkv[:, :, 1]
+         v = qkv[:, :, 2]
+         if attention_mask is not None:
+             attention_mask = attention_mask[:, None, None, :]
+             attention_mask = attention_mask.expand(
+                 bsz, self.n_heads, seqlen, seqlen
+             ).bool()
+         # [B, heads, seq, D]
+         output = torch.nn.functional.scaled_dot_product_attention(
+             q, k, v, attn_mask=attention_mask
+         )
+         output = output.permute(0, 2, 1, 3).contiguous()
+
+         output = output.view(bsz, seqlen, h_size)
+         return self.wo(output)
+
+
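
An illustrative note on the mask handling above (not part of the committed file; assumes gLM2Config and the Attention class from these files are in scope): the per-token attention_mask (1 = keep, 0 = padding) is broadcast to (batch, heads, seqlen, seqlen) and passed to scaled_dot_product_attention as a boolean mask, so padded key positions are ignored for every query.

import torch

config = gLM2Config()                    # dim=640, heads=10 by default
attn = Attention(config)
x = torch.randn(2, 8, config.dim)
mask = torch.tensor([[1] * 8,            # fully valid sequence
                     [1] * 5 + [0] * 3]) # last three positions are padding
y = attn(x, attention_mask=mask)
print(y.shape)                           # torch.Size([2, 8, 640])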
+ class FeedForward(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         hidden_dim: int,
+         multiple_of: int,
+         ffn_dim_multiplier: Optional[float],
+     ):
+         """
+         SwiGLU FeedForward module.
+
+         Args:
+             dim (int): Input dimension.
+             hidden_dim (int): Hidden dimension of the feedforward layer.
+             multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
+             ffn_dim_multiplier (float, optional): Custom multiplier for hidden dimension. Defaults to None.
+         """
+         super().__init__()
+         hidden_dim = int(2 * hidden_dim / 3)
+         # custom dim factor multiplier
+         if ffn_dim_multiplier is not None:
+             hidden_dim = int(ffn_dim_multiplier * hidden_dim)
+         hidden_dim = multiple_of * \
+             ((hidden_dim + multiple_of - 1) // multiple_of)
+
+         self.w1 = nn.Linear(dim, hidden_dim, bias=False)
+         self.w2 = nn.Linear(hidden_dim, dim, bias=False)
+         self.w3 = nn.Linear(dim, hidden_dim, bias=False)
+
+     def forward(self, x):
+         return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x))
+
+
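
Worked through with the gLM2Config defaults (dim=640, swiglu_multiple_of=256, ffn_dim_multiplier=None), the hidden-dimension arithmetic above resolves as follows (illustrative only, not part of the committed file):

dim = 640
hidden_dim = 4 * dim                  # 2560, as passed in by TransformerBlock
hidden_dim = int(2 * hidden_dim / 3)  # 1706
multiple_of = 256
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)  # 1792
# So w1 and w3 project 640 -> 1792 and w2 projects 1792 -> 640.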
+ class TransformerBlock(nn.Module):
+     def __init__(self, config: gLM2Config):
+         super().__init__()
+         self.n_heads = config.heads
+         self.dim = config.dim
+         self.head_dim = config.dim // config.heads
+         self.attention = Attention(config)
+         self.feed_forward = FeedForward(
+             dim=config.dim,
+             hidden_dim=4 * config.dim,
+             multiple_of=config.swiglu_multiple_of,
+             ffn_dim_multiplier=config.ffn_dim_multiplier,
+         )
+         self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
+         self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         r = self.attention(self.attention_norm(
+             x), attention_mask=attention_mask)
+         h = x + r
+         r = self.feed_forward(self.ffn_norm(h))
+         out = h + r
+         return out
+
+
+ class TransformerLayers(nn.Module):
+     def __init__(self, config: gLM2Config):
+         super().__init__()
+         self.config = config
+         self.layers = torch.nn.ModuleList(
+             [TransformerBlock(config=config) for _ in range(config.depth)]
+         )
+
+     def forward(
+         self,
+         x: torch.FloatTensor,
+         attention_mask: Optional[torch.BoolTensor] = None,
+         return_all_hiddens: bool = False,
+     ):
+         if x.shape[-1] != self.config.dim:
+             raise ValueError(
+                 f"Input feature dim should be {self.config.dim}, but input has shape {x.shape}"
+             )
+         hiddens = []
+         for layer in self.layers:
+             x = layer(x, attention_mask=attention_mask)
+             if return_all_hiddens:
+                 hiddens.append(x)
+
+         if return_all_hiddens:
+             return x, hiddens
+         return x
+
+
+ class gLM2PreTrainedModel(PreTrainedModel):
+     """
+     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+     models.
+     """
+     config_class = gLM2Config
+     base_model_prefix = "glm2"
+     supports_gradient_checkpointing = False
+
+     # https://github.com/huggingface/transformers/blob/7032e0203262ebb2ebf55da8d2e01f873973e835/src/transformers/models/bert/modeling_bert.py#L748
+     # Takes `self` first so PreTrainedModel.post_init() can call it as a bound method per submodule.
+     def _init_weights(self, module, initializer_range=0.02):
+         if isinstance(module, nn.Linear):
+             nn.init.normal_(module.weight, std=initializer_range)
+             if module.bias is not None:
+                 nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             nn.init.normal_(module.weight, std=initializer_range)
+             if module.padding_idx is not None:
+                 nn.init.zeros_(module.weight[module.padding_idx])
+
+
+ class gLM2Model(gLM2PreTrainedModel):
+     """gLM2 Model."""
+
+     def __init__(self, config: gLM2Config):
+         super().__init__(config)
+         self.config = config
+
+         self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
+         self.encoder = TransformerLayers(config)
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple[torch.Tensor], BaseModelOutput]:
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         h = self.tok_embeddings(input_ids)
+         if output_hidden_states:
+             sequence_output, all_hidden_states = self.encoder(
+                 h, attention_mask, return_all_hiddens=True)
+         else:
+             sequence_output = self.encoder(h, attention_mask)
+             all_hidden_states = None
+
+         if not return_dict:
+             return (sequence_output, all_hidden_states)
+
+         return BaseModelOutput(
+             last_hidden_state=sequence_output,
+             hidden_states=all_hidden_states,
+         )
+
+
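
An illustrative note on the encoder outputs above (not part of the committed file; assumes the classes in this file are in scope, and uses a tiny config purely for demonstration): with output_hidden_states=True the model returns one hidden state per TransformerBlock, i.e. config.depth entries, alongside last_hidden_state.

import torch

config = gLM2Config(dim=64, depth=2, heads=4)
model = gLM2Model(config)
input_ids = torch.randint(0, config.vocab_size, (1, 12))
out = model(input_ids, output_hidden_states=True)
print(out.last_hidden_state.shape)   # torch.Size([1, 12, 64])
print(len(out.hidden_states))        # 2, one per TransformerBlock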
+ class gLM2ForMaskedLM(gLM2PreTrainedModel):
+
+     def __init__(self, config: gLM2Config):
+         super().__init__(config)
+
+         self.glm2 = gLM2Model(config)
+         self.lm_head = gLM2LMHead(config)
+         self.init_weights()
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, MaskedLMOutput]:
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         outputs = self.glm2(
+             input_ids,
+             attention_mask=attention_mask,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+         sequence_output = outputs[0]
+         prediction_scores = self.lm_head(sequence_output)
+
+         masked_lm_loss = None
+         if labels is not None:
+             loss_fct = CrossEntropyLoss()
+
+             labels = labels.to(prediction_scores.device)
+             masked_lm_loss = loss_fct(
+                 prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+
+         if not return_dict:
+             output = (prediction_scores,) + outputs[2:]
+             return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+
+         return MaskedLMOutput(
+             loss=masked_lm_loss,
+             logits=prediction_scores,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
+
+ class gLM2LMHead(nn.Module):
+     """gLM2 head for masked language modeling."""
+
+     def __init__(self, config):
+         super().__init__()
+
+         self.norm = RMSNorm(config.dim, eps=config.norm_eps)
+         self.proj_output = nn.Linear(
+             config.dim, config.vocab_size, bias=False)
+
+     def forward(self, features):
+         return self.proj_output(self.norm(features))
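
Taken together, the three files define a custom architecture that the transformers auto classes can load through the auto_map entries in the config. A minimal end-to-end sketch (illustrative only; "user/repo" is a placeholder for wherever these files and the weights are hosted, and loading them requires trust_remote_code=True):

import torch
from transformers import AutoModelForMaskedLM
from glm_tokenizer import gLM2Tokenizer

model = AutoModelForMaskedLM.from_pretrained("user/repo", trust_remote_code=True)
tokenizer = gLM2Tokenizer()

enc = tokenizer("<+>MKT<mask>YIAK", return_tensors="pt")
with torch.no_grad():
    out = model(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"])
print(out.logits.shape)   # (1, seq_len, vocab_size=37)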