yixinsong committed on
Commit cbda492 · 1 Parent(s): 7ed7f72
Files changed (1)
  1. modeling_falcon.py_bak +0 -1277
modeling_falcon.py_bak DELETED
@@ -1,1277 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2023 the Falcon authors and HuggingFace Inc. team. All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- """PyTorch Falcon model."""
16
-
17
- import math
18
- from typing import Optional, Tuple, Union
19
-
20
- import torch
21
- import torch.utils.checkpoint
22
- from torch import nn
23
- from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
24
- from torch.nn import functional as F
25
-
26
- from transformers.modeling_outputs import (
27
- BaseModelOutputWithPastAndCrossAttentions,
28
- CausalLMOutputWithCrossAttentions,
29
- QuestionAnsweringModelOutput,
30
- SequenceClassifierOutputWithPast,
31
- TokenClassifierOutput,
32
- )
33
- from transformers.modeling_utils import PreTrainedModel
34
- from transformers.utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
35
- from .configuration_falcon import FalconConfig
36
-
37
-
38
- logger = logging.get_logger(__name__)
39
-
40
- FALCON_PRETRAINED_MODEL_ARCHIVE_LIST = [
41
- "tiiuae/falcon-40b",
42
- "tiiuae/falcon-40b-instruct",
43
- "tiiuae/falcon-7b",
44
- "tiiuae/falcon-7b-instruct",
45
- "tiiuae/falcon-rw-7b",
46
- "tiiuae/falcon-rw-1b",
47
- ]
48
- _CHECKPOINT_FOR_DOC = "Rocketknight1/falcon-rw-1b"
49
- _CONFIG_FOR_DOC = "FalconConfig"
50
-
51
-
52
- # NOTE(Hesslow): Unfortunately we did not fuse matmul and bias during training, this means that there's one additional quantization to bfloat16 between the operations.
53
- # In order not to degrade the quality of our HF-port, we keep these characteristics in the final model.
54
- class FalconLinear(nn.Linear):
55
- def forward(self, input: torch.Tensor) -> torch.Tensor:
56
- hidden_states = input @ self.weight.T
57
- if self.bias is None:
58
- return hidden_states
59
- return hidden_states + self.bias
60
-
61
-
62
- # rotary pos emb helpers (torch.jit.script does not seem to support staticmethod...)
63
- def rotate_half(x):
64
- x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
65
- return torch.cat((-x2, x1), dim=-1)
66
-
67
-
68
- class FalconRotaryEmbedding(nn.Module):
69
- """Implementation of RotaryEmbedding from GPT-NeoX.
70
- This implementation is designed to operate on queries and keys that are compatible with `[batch_size,
71
- n_heads_per_partition, seq_len, head_dim]` (e.g. MinGPTAttention format).
72
- """
73
-
74
- def __init__(self, head_dim: int, base=10000):
75
- super().__init__()
76
- inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
77
- self.register_buffer("inv_freq", inv_freq, persistent=False)
78
- self.head_dim = head_dim
79
- self.seq_len_cached = -1
80
- self.cos_cached: torch.Tensor | None = None
81
- self.sin_cached: torch.Tensor | None = None
82
-
83
- def cos_sin(self, seq_len: int, past_key_values_length: int, device="cpu", dtype=torch.bfloat16) -> torch.Tensor:
84
- total_length = seq_len + past_key_values_length
85
- if total_length > self.seq_len_cached:
86
- self.seq_len_cached = total_length
87
- t = torch.arange(total_length, device=device, dtype=self.inv_freq.dtype)
88
- freqs = torch.einsum("i,j->ij", t, self.inv_freq)
89
- emb = torch.cat((freqs, freqs), dim=-1).to(device)
90
-
91
- if dtype in [torch.float16, torch.bfloat16]:
92
- emb = emb.float()
93
-
94
- self.cos_cached = emb.cos()[None, :, :]
95
- self.sin_cached = emb.sin()[None, :, :]
96
-
97
- self.cos_cached = self.cos_cached.type(dtype)
98
- self.sin_cached = self.sin_cached.type(dtype)
99
-
100
- return (
101
- self.cos_cached[:, past_key_values_length : seq_len + past_key_values_length],
102
- self.sin_cached[:, past_key_values_length : seq_len + past_key_values_length],
103
- )
104
-
105
- def forward(self, query, key, past_key_values_length=0):
106
- batch, seq_len, head_dim = query.shape
107
- cos, sin = self.cos_sin(seq_len, past_key_values_length, query.device, query.dtype)
108
- return (query * cos) + (rotate_half(query) * sin), (key * cos) + (rotate_half(key) * sin)
109
-
110
-
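# Illustrative sketch (not part of the original file): how the rotary helpers above act on
# toy query/key tensors shaped [batch_size * num_heads, seq_len, head_dim]; all sizes are made up.
rotary_demo = FalconRotaryEmbedding(head_dim=64)
demo_q = torch.randn(2, 5, 64)
demo_k = torch.randn(2, 5, 64)
demo_q_rot, demo_k_rot = rotary_demo(demo_q, demo_k, past_key_values_length=0)
assert demo_q_rot.shape == demo_q.shape and demo_k_rot.shape == demo_k.shape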
111
- def _make_causal_mask(
112
- input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
113
- ) -> torch.BoolTensor:
114
- """
115
- Make causal mask used for self-attention. This mask does not take the existing attention mask into account - it
116
- just blocks tokens from attending forwards in the sequence. The output shape will be `[batch_size, 1,
117
- target_length, target_length+past_key_values_length]`.
118
- """
119
- batch_size, target_length = input_ids_shape
120
-
121
- mask = torch.triu(torch.ones((target_length, target_length), dtype=torch.bool, device=device), diagonal=1)
122
- # If past_key_values_length is 0 this is an empty tensor and the concatenation is a no-op.
123
- # This code style is an unfortunate consequence of getting your TF engineer to port models; doing it this
124
- # way avoids a data-dependent conditional, which will help me when I have to port this to XLA later.
125
- past_mask = torch.zeros((target_length, past_key_values_length), dtype=torch.bool, device=device)
126
- mask = torch.cat([past_mask, mask], dim=-1)
127
- expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length)
128
- return expanded_mask
129
-
130
-
131
- def _expand_mask(mask: torch.Tensor, past_key_values_length: int) -> torch.BoolTensor:
132
- """
133
- Expands attention_mask from `[batch_size, seq_length]` to `[batch_size, 1, seq_length, seq_length + past_length]`.
134
- """
135
- batch_size, total_length = mask.shape
136
- seq_length = total_length - past_key_values_length if past_key_values_length is not None else total_length
137
-
138
- expanded_mask = ~(mask[:, None, None, :].to(torch.bool))
139
- return expanded_mask.expand(batch_size, 1, seq_length, total_length)
140
-
141
-
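# Illustrative sketch (not part of the original file): combining the two mask helpers above for
# one sequence with 2 cached tokens and 3 new tokens; True marks positions that must be ignored.
demo_attn_mask = torch.ones((1, 5), dtype=torch.long)  # covers past + current tokens, no padding
demo_causal = _make_causal_mask((1, 3), device=demo_attn_mask.device, past_key_values_length=2)
demo_padding = _expand_mask(demo_attn_mask, past_key_values_length=2)
demo_combined = demo_causal | demo_padding  # [batch_size, 1, seq_length, seq_length + past] = [1, 1, 3, 5]
assert demo_combined.shape == (1, 1, 3, 5)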
142
- def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
143
- batch_size, seq_length = attention_mask.shape
144
- closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
145
- base = torch.tensor(
146
- 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
147
- )
148
- powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32)
149
- slopes = torch.pow(base, powers)
150
-
151
- if closest_power_of_2 != num_heads:
152
- extra_base = torch.tensor(
153
- 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
154
- )
155
- num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
156
- extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32)
157
- slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
158
-
159
- # Note: alibi will be added to the attention bias that will be applied to the query, key product of attention
160
- # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length)
161
- # => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length)
162
- # => the query_length dimension will then be broadcasted correctly
163
- # This is more or less identical to T5's relative position bias:
164
- # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
165
- arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
166
- alibi = slopes[..., None].bfloat16() * arange_tensor
167
- return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)
168
-
169
-
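# Illustrative sketch (not part of the original file): alibi bias for a toy batch of 2 sequences
# and 4 heads; the result is broadcast over the query dimension when added to attention logits.
demo_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 1, 1]])  # first sequence has one padding token
demo_alibi = build_alibi_tensor(demo_mask, num_heads=4, dtype=torch.float32)
assert demo_alibi.shape == (2 * 4, 1, 4)  # [batch_size * num_heads, 1, seq_length]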
170
- # Copied from transformers.models.bloom.modeling_bloom.dropout_add
171
- def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
172
- """
173
- Dropout add function
174
-
175
- Args:
176
- x (`torch.tensor`, *required*):
177
- input tensor
178
- residual (`torch.tensor`, *required*):
179
- residual tensor
180
- prob (`float`, *required*):
181
- dropout probability
182
- training (`bool`, *required*):
183
- training mode
184
- """
185
- out = F.dropout(x, p=prob, training=training)
186
- out = residual + out
187
- return out
188
-
189
-
190
- class FalconAttention(nn.Module):
191
- def __init__(self, config: FalconConfig):
192
- super().__init__()
193
-
194
- self.hidden_size = config.hidden_size
195
- self.num_heads = config.num_attention_heads
196
- self.head_dim = self.hidden_size // self.num_heads
197
- self.split_size = self.hidden_size
198
- self.hidden_dropout = config.hidden_dropout
199
-
200
- if self.head_dim * self.num_heads != self.hidden_size:
201
- raise ValueError(
202
- f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:"
203
- f" {self.num_heads})."
204
- )
205
-
206
- self.maybe_rotary = FalconRotaryEmbedding(config.head_dim) if config.rotary else lambda q, k, t: (q, k)
207
-
208
- # Layer-wise attention scaling
209
- self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
210
- self.beta = self.inv_norm_factor
211
- if config.new_decoder_architecture:
212
- qkv_out_dim = (config.num_kv_heads * 2 + config.num_attention_heads) * self.head_dim
213
- elif config.multi_query:
214
- qkv_out_dim = self.hidden_size + 2 * self.head_dim
215
- else:
216
- qkv_out_dim = 3 * self.hidden_size
217
- self.query_key_value = FalconLinear(self.hidden_size, qkv_out_dim, bias=config.bias)
218
- self.new_decoder_architecture = config.new_decoder_architecture
219
- self.multi_query = config.multi_query
220
- self.dense = FalconLinear(self.hidden_size, self.hidden_size, bias=config.bias)
221
- self.attention_dropout = nn.Dropout(config.attention_dropout)
222
- self.num_kv_heads = config.num_kv_heads if (self.new_decoder_architecture or not self.multi_query) else 1
223
-
224
- def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
225
- """
226
- Split the last dimension into (num_heads, head_dim), results share same memory storage as `fused_qkv`
227
-
228
- Args:
229
- fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim]
230
-
231
- Returns:
232
- query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim]
233
- value: [batch_size, seq_length, num_heads, head_dim]
234
- """
235
- if self.new_decoder_architecture:
236
- batch, seq_len, _ = fused_qkv.shape
237
- qkv = fused_qkv.view(batch, seq_len, -1, self.num_heads // self.num_kv_heads + 2, self.head_dim)
238
- query = qkv[:, :, :, :-2]
239
- key = qkv[:, :, :, [-2]]
240
- value = qkv[:, :, :, [-1]]
241
- key = torch.broadcast_to(key, query.shape)
242
- value = torch.broadcast_to(value, query.shape)
243
-
244
- query, key, value = [x.flatten(2, 3) for x in (query, key, value)]
245
- return query, key, value
246
- elif not self.multi_query:
247
- batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
248
- fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim)
249
- return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :]
250
- else:
251
- batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
252
- fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads + 2, self.head_dim)
253
- return fused_qkv[..., :-2, :], fused_qkv[..., [-2], :], fused_qkv[..., [-1], :]
254
-
255
- # Copied from transformers.models.bloom.modeling_bloom.BloomAttention._merge_heads
256
- def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
257
- """
258
- Merge heads together over the last dimension
259
-
260
- Args:
261
- x (`torch.tensor`, *required*): [batch_size * num_heads, seq_length, head_dim]
262
-
263
- Returns:
264
- torch.tensor: [batch_size, seq_length, num_heads * head_dim]
265
- """
266
- # What we want to achieve is:
267
- # batch_size * num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads * head_dim
268
- batch_size_and_num_heads, seq_length, _ = x.shape
269
- batch_size = batch_size_and_num_heads // self.num_heads
270
-
271
- # First view to decompose the batch size
272
- # batch_size * num_heads, seq_length, head_dim -> batch_size, num_heads, seq_length, head_dim
273
- x = x.view(batch_size, self.num_heads, seq_length, self.head_dim)
274
-
275
- # batch_size, num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads, head_dim
276
- x = x.permute(0, 2, 1, 3)
277
-
278
- # batch_size, seq_length, num_heads, head_dim -> batch_size, seq_length, num_heads * head_dim
279
- return x.reshape(batch_size, seq_length, self.num_heads * self.head_dim)
280
-
281
- def forward(
282
- self,
283
- hidden_states: torch.Tensor,
284
- alibi: Optional[torch.Tensor],
285
- attention_mask: torch.Tensor,
286
- layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
287
- head_mask: Optional[torch.Tensor] = None,
288
- use_cache: bool = False,
289
- output_attentions: bool = False,
290
- ):
291
- fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
292
- num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads
293
- # 3 x [batch_size, seq_length, num_heads, head_dim]
294
- (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
295
-
296
- batch_size, query_length, _, _ = query_layer.shape
297
-
298
- query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, query_length, self.head_dim)
299
- key_layer = key_layer.transpose(1, 2).reshape(
300
- batch_size * num_kv_heads,
301
- query_length,
302
- self.head_dim,
303
- )
304
- value_layer = value_layer.transpose(1, 2).reshape(batch_size * num_kv_heads, query_length, self.head_dim)
305
-
306
- past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
307
- query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)
308
-
309
- if layer_past is not None:
310
- past_key, past_value = layer_past
311
- # concatenate along seq_length dimension:
312
- # - key: [batch_size * self.num_heads, kv_length, head_dim]
313
- # - value: [batch_size * self.num_heads, kv_length, head_dim]
314
- key_layer = torch.cat((past_key, key_layer), dim=1)
315
- value_layer = torch.cat((past_value, value_layer), dim=1)
316
-
317
- _, kv_length, _ = key_layer.shape
318
- if use_cache:
319
- present = (key_layer, value_layer)
320
- else:
321
- present = None
322
-
323
- attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float("-1e9")).to(query_layer.dtype)
324
-
325
- query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
326
- key_layer_ = key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
327
- value_layer_ = value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
328
-
329
- if alibi is None:
330
- if output_attentions:
331
- # F.scaled_dot_product_attention doesn't return the attention weights, so we have
332
- # to do it by hand if we want them
333
- attention_scores = query_layer_ @ key_layer_.transpose(-1, -2)
334
- attention_scores /= math.sqrt(self.head_dim)
335
-
336
- attention_scores = F.softmax(
337
- attention_scores + attention_mask_float, dim=-1, dtype=hidden_states.dtype
338
- )
339
- attn_output = attention_scores @ value_layer_
340
- else:
341
- attn_output = F.scaled_dot_product_attention(
342
- query_layer_, key_layer_, value_layer_, attention_mask_float, 0.0, is_causal=False
343
- )
344
- attention_scores = None
345
-
346
- attn_output = attn_output.view(batch_size, self.num_heads, query_length, self.head_dim)
347
- attn_output = attn_output.permute(0, 2, 1, 3)
348
- attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)
349
-
350
- output_tensor = self.dense(attn_output)
351
-
352
- if output_attentions:
353
- return output_tensor, present, attention_scores
354
- else:
355
- return output_tensor, present
356
-
357
- else:
358
- matmul_result = query_layer_ @ key_layer_.transpose(-1, -2)
359
-
360
- # change view to [batch_size, num_heads, q_length, kv_length]
361
- attention_scores = matmul_result.view(batch_size, self.num_heads, query_length, kv_length)
362
-
363
- # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
364
- input_dtype = attention_scores.dtype
365
- # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
366
- if input_dtype == torch.float16 or input_dtype == torch.bfloat16:
367
- attention_scores = attention_scores.to(torch.float32)
368
- # Matt (HF) note: We could possibly use F.scaled_dot_product_attention here too, by
369
- # adding (alibi * self.inv_norm_factor) to attention_mask_float. I think this would be mathematically
370
- # equivalent and more performant, but there might be a numerical difference. If you're reading this
371
- # and you'd like to experiment and maybe file a PR, feel free!
372
- attention_logits = attention_scores + alibi.view(batch_size, self.num_heads, 1, -1)
373
- attention_logits *= self.inv_norm_factor
374
- attention_probs = F.softmax(attention_logits + attention_mask_float, dim=-1, dtype=hidden_states.dtype)
375
- # [batch_size, num_heads, q_length, kv_length]
376
- attention_probs = self.attention_dropout(attention_probs)
377
-
378
- if head_mask is not None:
379
- attention_probs = attention_probs * head_mask
380
-
381
- # change view [batch_size, num_heads, q_length, kv_length]
382
- attention_probs_reshaped = attention_probs.view(batch_size, self.num_heads, query_length, kv_length)
383
-
384
- # matmul: [batch_size * num_heads, q_length, head_dim]
385
- context_layer = (attention_probs_reshaped @ value_layer_).flatten(0, 1)
386
-
387
- # change view [batch_size, num_heads, q_length, head_dim]
388
- context_layer = self._merge_heads(context_layer)
389
-
390
- output_tensor = self.dense(context_layer)
391
-
392
- if output_attentions:
393
- return output_tensor, present, attention_probs
394
- else:
395
- return output_tensor, present
396
-
397
-
398
- class FalconMLP(nn.Module):
399
- def __init__(self, config: FalconConfig, i):
400
- super().__init__()
401
- hidden_size = config.hidden_size
402
-
403
- self.dense_h_to_4h = FalconLinear(hidden_size, 4 * hidden_size, bias=config.bias)
404
- # self.act = nn.GELU()
405
- self.act = nn.ReLU()
406
- self.layer_id = i
407
- # print('using relu')
408
- self.dense_4h_to_h = FalconLinear(4 * hidden_size, hidden_size, bias=config.bias)
409
- self.hidden_dropout = config.hidden_dropout
410
- self.cnt = 0
- self.tot = 0  # running count of token positions accumulated into `record`
411
- self.record = torch.zeros(config.ffn_dim, dtype=torch.int64).cuda()
412
-
413
- def forward(self, x: torch.Tensor) -> torch.Tensor:
414
- x = self.act(self.dense_h_to_4h(x))
415
- hidden_states = x.flatten(0, -2)  # flatten [batch, seq, ffn] to [tokens, ffn] so counts are per neuron
416
- self.record += torch.count_nonzero(hidden_states, dim=0).to(self.record.device)
417
- self.cnt += 1
418
- self.tot += hidden_states.shape[0]
419
- if self.cnt > 500:
420
- tmp = self.record.cpu()
421
- torch.save(tmp, "/nvme/syx/activation_13b_count/activation_{}.pt".format(self.layer_id))
422
- if (self.layer_id == 47):
423
- import sys
424
- sys.exit()
425
- x = self.dense_4h_to_h(x)
426
- return x
427
-
428
-
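# Illustrative sketch (not part of the original file): reading back one of the per-layer activation
# count files that the instrumented FalconMLP above writes (path and layer id as hard-coded there).
demo_layer_id = 0
demo_counts = torch.load("/nvme/syx/activation_13b_count/activation_{}.pt".format(demo_layer_id))
demo_never_fired = (demo_counts == 0).float().mean().item()
print("layer {}: {:.2%} of FFN neurons never produced a nonzero activation".format(demo_layer_id, demo_never_fired))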
429
- class FalconDecoderLayer(nn.Module):
430
- def __init__(self, config: FalconConfig, layer_idx: int = 0):
431
- super().__init__()
432
- hidden_size = config.hidden_size
433
- self.num_heads = config.num_attention_heads
434
- self.self_attention = FalconAttention(config)
435
- self.mlp = FalconMLP(config, layer_idx)  # FalconMLP requires the layer index for its activation dump
436
- self.hidden_dropout = config.hidden_dropout
437
- self.config = config
438
-
439
- if config.new_decoder_architecture:
440
- # The layer norm before self-attention
441
- self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
442
- # The layer norm before the MLP
443
- self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
444
- else:
445
- self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
446
- if not config.parallel_attn:
447
- self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
448
-
449
- def forward(
450
- self,
451
- hidden_states: torch.Tensor,
452
- alibi: Optional[torch.Tensor],
453
- attention_mask: torch.Tensor,
454
- layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
455
- head_mask: Optional[torch.Tensor] = None,
456
- use_cache: bool = False,
457
- output_attentions: bool = False,
458
- ):
459
- residual = hidden_states
460
-
461
- if self.config.new_decoder_architecture:
462
- attention_layernorm_out = self.ln_attn(hidden_states)
463
- mlp_layernorm_out = self.ln_mlp(hidden_states)
464
- else:
465
- attention_layernorm_out = self.input_layernorm(hidden_states)
466
-
467
- # Self attention.
468
- attn_outputs = self.self_attention(
469
- attention_layernorm_out,
470
- layer_past=layer_past,
471
- attention_mask=attention_mask,
472
- alibi=alibi,
473
- head_mask=head_mask,
474
- use_cache=use_cache,
475
- output_attentions=output_attentions,
476
- )
477
-
478
- attention_output = attn_outputs[0]
479
-
480
- if not self.config.new_decoder_architecture:
481
- if self.config.parallel_attn:
482
- mlp_layernorm_out = attention_layernorm_out
483
- else:
484
- residual = dropout_add(
485
- attention_output, residual, self.config.attention_dropout, training=self.training
486
- )
487
- mlp_layernorm_out = self.post_attention_layernorm(residual)
488
-
489
- outputs = attn_outputs[1:]
490
-
491
- # MLP.
492
- mlp_output = self.mlp(mlp_layernorm_out)
493
-
494
- if self.config.new_decoder_architecture or self.config.parallel_attn:
495
- mlp_output += attention_output
496
-
497
- output = dropout_add(mlp_output, residual, self.config.hidden_dropout, training=self.training)
498
-
499
- if use_cache:
500
- outputs = (output,) + outputs
501
- else:
502
- outputs = (output,) + outputs[1:]
503
-
504
- return outputs # hidden_states, present, attentions
505
-
506
-
507
- FALCON_START_DOCSTRING = r"""
508
-
509
- This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
510
- library implements for all its models (such as downloading or saving, resizing the input embeddings, etc.)
511
-
512
- This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
513
- Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
514
- and behavior.
515
-
516
- Parameters:
517
- config ([`FalconConfig`]): Model configuration class with all the parameters of the model.
518
- Initializing with a config file does not load the weights associated with the model, only the
519
- configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
520
- """
521
-
522
- FALCON_INPUTS_DOCSTRING = r"""
523
- Args:
524
- input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
525
- `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0][0].shape[2]`
526
- (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.
527
-
528
- If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
529
- `input_ids`.
530
-
531
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
532
- [`PreTrainedTokenizer.__call__`] for details.
533
-
534
- [What are input IDs?](../glossary#input-ids)
535
- past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.num_hidden_layers`):
536
- Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
537
- `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
538
- their past given to this model should not be passed as `input_ids` as they have already been computed.
539
-
540
- Each element of `past_key_values` is a tuple (past_key, past_value):
541
- - past_key: [batch_size * num_heads, head_dim, kv_length]
542
- - past_value: [batch_size * num_heads, kv_length, head_dim]
543
- attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
544
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
545
-
546
- - 1 for tokens that are **not masked**,
547
- - 0 for tokens that are **masked**.
548
-
549
- [What are attention masks?](../glossary#attention-mask)
550
- head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
551
- Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
552
-
553
- - 1 indicates the head is **not masked**,
554
- - 0 indicates the head is **masked**.
555
-
556
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
557
- Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
558
- is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
559
- model's internal embedding lookup matrix.
560
-
561
- If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
562
- `past_key_values`).
563
- use_cache (`bool`, *optional*):
564
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
565
- `past_key_values`).
566
- output_attentions (`bool`, *optional*):
567
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
568
- tensors for more detail.
569
- output_hidden_states (`bool`, *optional*):
570
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
571
- more detail.
572
- return_dict (`bool`, *optional*):
573
- Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
574
- """
575
-
576
-
577
- class FalconPreTrainedModel(PreTrainedModel):
578
- """
579
- An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
580
- models.
581
- """
582
-
583
- config_class = FalconConfig
584
- base_model_prefix = "transformer"
585
- supports_gradient_checkpointing = True
586
- _no_split_modules = ["FalconDecoderLayer"]
587
-
588
- def __init__(self, *inputs, **kwargs):
589
- super().__init__(*inputs, **kwargs)
590
-
591
- def _init_weights(self, module: nn.Module):
592
- """Initialize the weights."""
593
- if isinstance(module, nn.Linear) or isinstance(module, FalconLinear):
594
- # Slightly different from the TF version which uses truncated_normal for initialization
595
- # cf https://github.com/pytorch/pytorch/pull/5617
596
- module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
597
- if module.bias is not None:
598
- module.bias.data.zero_()
599
- elif isinstance(module, nn.Embedding):
600
- module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
601
- if module.padding_idx is not None:
602
- module.weight.data[module.padding_idx].zero_()
603
- elif isinstance(module, LayerNorm):
604
- module.bias.data.zero_()
605
- module.weight.data.fill_(1.0)
606
-
607
- # Copied from transformers.models.bloom.modeling_bloom.BloomPreTrainedModel._set_gradient_checkpointing with BloomModel->FalconModel
608
- def _set_gradient_checkpointing(self, module: nn.Module, value: bool = False):
609
- if isinstance(module, FalconModel):
610
- module.gradient_checkpointing = value
611
-
612
- @staticmethod
613
- def _convert_cache_to_standard_format(
614
- past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int
615
- ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
616
- """
617
- Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size,
618
- num_heads, ...]))
619
- """
620
- batch_size_times_num_heads, kv_length, head_dim = past_key_value[0][0].shape
621
- # [batch_size * self.num_heads, kv_length, head_dim] -> [batch_size, num_heads, kv_length, head_dim]
622
- # Note that don't want to use self.num_attention_heads because the number of heads may vary depending
623
- # on whether we use multi_query attention.
624
- num_heads = batch_size_times_num_heads // batch_size
625
- return tuple(
626
- (
627
- layer_past[0].view(batch_size, num_heads, kv_length, head_dim),
628
- layer_past[1].view(batch_size, num_heads, kv_length, head_dim),
629
- )
630
- for layer_past in past_key_value
631
- )
632
-
633
- @staticmethod
634
- def _convert_to_rw_cache(
635
- past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
636
- ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
637
- batch_size, num_heads, kv_length, head_dim = past_key_value[0][0].shape
638
- batch_size_times_num_heads = batch_size * num_heads
639
- # [batch_size, num_heads, kv_length, head_dim] -> [batch_size * num_heads, kv_length, head_dim]
640
- return tuple(
641
- (
642
- layer_past[0].view(batch_size_times_num_heads, kv_length, head_dim),
643
- layer_past[1].view(batch_size_times_num_heads, kv_length, head_dim),
644
- )
645
- for layer_past in past_key_value
646
- )
647
-
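# Illustrative sketch (not part of the original file): round-tripping a toy single-layer cache
# through the two static helpers above (batch 2, 4 heads, 3 cached tokens, head_dim 8).
demo_standard_cache = ((torch.randn(2, 4, 3, 8), torch.randn(2, 4, 3, 8)),)
demo_rw_cache = FalconPreTrainedModel._convert_to_rw_cache(demo_standard_cache)
demo_roundtrip = FalconPreTrainedModel._convert_cache_to_standard_format(demo_rw_cache, batch_size=2)
assert demo_rw_cache[0][0].shape == (8, 3, 8) and demo_roundtrip[0][0].shape == (2, 4, 3, 8)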
648
-
649
- @add_start_docstrings(
650
- "The bare Falcon Model transformer outputting raw hidden-states without any specific head on top.",
651
- FALCON_START_DOCSTRING,
652
- )
653
- class FalconModel(FalconPreTrainedModel):
654
- def __init__(self, config: FalconConfig):
655
- super().__init__(config)
656
-
657
- self.embed_dim = config.hidden_size
658
- self.num_heads = config.num_attention_heads
659
- self.use_alibi = config.alibi
660
-
661
- # Embedding + LN Embedding
662
- self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
663
-
664
- # Transformer blocks
665
- self.h = nn.ModuleList([FalconDecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
666
-
667
- # Final Layer Norm
668
- self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
669
-
670
- self.gradient_checkpointing = False
671
-
672
- # Initialize weights and apply final processing
673
- self.post_init()
674
-
675
- def get_input_embeddings(self):
676
- return self.word_embeddings
677
-
678
- @staticmethod
679
- def _prepare_attn_mask(
680
- attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
681
- ) -> torch.BoolTensor:
682
- # Create a causal mask
683
- # The attention mask we receive as input should cover the whole extended sequence, including any past
684
- # cache, so its shape should be [batch_size, seq_length + past_key_values_length]
685
- # The output shape will be [batch_size, 1, seq_length, seq_length + past_key_values_length]
686
- if input_shape[1] + past_key_values_length != attention_mask.shape[1]:
687
- raise ValueError(
688
- "Attention mask shape should be (batch_size, seq_length + past_key_values_length)"
689
- f" but is {attention_mask.shape} with input_ids shape {input_shape} and past length"
690
- f" {past_key_values_length}."
691
- )
692
- combined_attention_mask = None
693
- device = attention_mask.device
694
- _, seq_length = input_shape
695
-
696
- if seq_length > 1:
697
- combined_attention_mask = _make_causal_mask(
698
- input_shape, device=device, past_key_values_length=past_key_values_length
699
- )
700
-
701
- # [batch_size, seq_length + past_key_values_length] -> [batch_size, 1, seq_length, seq_length + past_key_values_length]
702
- expanded_attn_mask = _expand_mask(attention_mask, past_key_values_length=past_key_values_length)
703
- combined_attention_mask = (
704
- expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
705
- )
706
-
707
- return combined_attention_mask
708
-
709
- def set_input_embeddings(self, new_embeddings: torch.Tensor):
710
- self.word_embeddings = new_embeddings
711
-
712
- @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
713
- @add_code_sample_docstrings(
714
- checkpoint=_CHECKPOINT_FOR_DOC,
715
- output_type=BaseModelOutputWithPastAndCrossAttentions,
716
- config_class=_CONFIG_FOR_DOC,
717
- )
718
- def forward(
719
- self,
720
- input_ids: Optional[torch.LongTensor] = None,
721
- past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
722
- attention_mask: Optional[torch.Tensor] = None,
723
- head_mask: Optional[torch.LongTensor] = None,
724
- inputs_embeds: Optional[torch.LongTensor] = None,
725
- use_cache: Optional[bool] = None,
726
- output_attentions: Optional[bool] = None,
727
- output_hidden_states: Optional[bool] = None,
728
- return_dict: Optional[bool] = None,
729
- ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
730
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
731
- output_hidden_states = (
732
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
733
- )
734
- use_cache = use_cache if use_cache is not None else self.config.use_cache
735
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
736
-
737
- if input_ids is not None and inputs_embeds is not None:
738
- raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
739
- elif input_ids is not None:
740
- batch_size, seq_length = input_ids.shape
741
- elif inputs_embeds is not None:
742
- batch_size, seq_length, _ = inputs_embeds.shape
743
- else:
744
- raise ValueError("You have to specify either input_ids or inputs_embeds")
745
-
746
- if past_key_values is None:
747
- past_key_values = tuple([None] * len(self.h))
748
- else:
749
- past_key_values = self._convert_to_rw_cache(past_key_values)
750
-
751
- # Prepare head mask if needed
752
- # 1.0 in head_mask indicates we keep the head
753
- # attention_probs has shape batch_size x num_heads x N x N
754
- # head_mask has shape n_layer x batch x num_heads x N x N
755
- head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
756
-
757
- if inputs_embeds is None:
758
- inputs_embeds = self.word_embeddings(input_ids)
759
-
760
- hidden_states = inputs_embeds
761
-
762
- presents = () if use_cache else None
763
- all_self_attentions = () if output_attentions else None
764
- all_hidden_states = () if output_hidden_states else None
765
-
766
- # Compute alibi tensor: check build_alibi_tensor documentation
767
- past_key_values_length = 0
768
- if past_key_values[0] is not None:
769
- past_key_values_length = past_key_values[0][0].shape[1] # 1 because RW-cache, not standard format
770
- if attention_mask is None:
771
- attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=hidden_states.device)
772
- else:
773
- attention_mask = attention_mask.to(hidden_states.device)
774
-
775
- if self.use_alibi:
776
- alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
777
- else:
778
- alibi = None
779
-
780
- causal_mask = self._prepare_attn_mask(
781
- attention_mask,
782
- input_shape=(batch_size, seq_length),
783
- past_key_values_length=past_key_values_length,
784
- )
785
-
786
- for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
787
- if output_hidden_states:
788
- all_hidden_states = all_hidden_states + (hidden_states,)
789
-
790
- if self.gradient_checkpointing and self.training:
791
- if use_cache:
792
- logger.warning(
793
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
794
- )
795
- use_cache = False
796
-
797
- def create_custom_forward(module):
798
- def custom_forward(*inputs):
799
- # None for past_key_value
800
- return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
801
-
802
- return custom_forward
803
-
804
- outputs = torch.utils.checkpoint.checkpoint(
805
- create_custom_forward(block),
806
- hidden_states,
807
- alibi,
808
- causal_mask,
809
- head_mask[i],
810
- )
811
- else:
812
- outputs = block(
813
- hidden_states,
814
- layer_past=layer_past,
815
- attention_mask=causal_mask,
816
- head_mask=head_mask[i],
817
- use_cache=use_cache,
818
- output_attentions=output_attentions,
819
- alibi=alibi,
820
- )
821
-
822
- hidden_states = outputs[0]
823
- if use_cache is True:
824
- presents = presents + (outputs[1],)
825
-
826
- if output_attentions:
827
- all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
828
-
829
- # Add last hidden state
830
- hidden_states = self.ln_f(hidden_states)
831
-
832
- if output_hidden_states:
833
- all_hidden_states = all_hidden_states + (hidden_states,)
834
-
835
- if presents is not None:
836
- presents = self._convert_cache_to_standard_format(presents, batch_size)
837
-
838
- if not return_dict:
839
- return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
840
-
841
- return BaseModelOutputWithPastAndCrossAttentions(
842
- last_hidden_state=hidden_states,
843
- past_key_values=presents,
844
- hidden_states=all_hidden_states,
845
- attentions=all_self_attentions,
846
- )
847
-
848
-
849
- @add_start_docstrings(
850
- "The Falcon Model transformer with a language modeling head on top (linear layer with weights tied to the input embeddings).",
851
- FALCON_START_DOCSTRING,
852
- )
853
- class FalconForCausalLM(FalconPreTrainedModel):
854
- _tied_weights_keys = ["lm_head.weight"]
855
-
856
- def __init__(self, config: FalconConfig):
857
- super().__init__(config)
858
- self.transformer = FalconModel(config)
859
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
860
-
861
- # Initialize weights and apply final processing
862
- self.post_init()
863
-
864
- def get_output_embeddings(self):
865
- return self.lm_head
866
-
867
- def set_output_embeddings(self, new_embeddings: torch.Tensor):
868
- self.lm_head = new_embeddings
869
-
870
- def prepare_inputs_for_generation(
871
- self,
872
- input_ids: torch.LongTensor,
873
- past_key_values: Optional[torch.Tensor] = None,
874
- attention_mask: Optional[torch.Tensor] = None,
875
- **kwargs,
876
- ) -> dict:
877
- if past_key_values is not None:
878
- input_ids = input_ids[:, -1:]
879
-
880
- return {
881
- "input_ids": input_ids,
882
- "past_key_values": past_key_values,
883
- "use_cache": kwargs.get("use_cache"),
884
- "attention_mask": attention_mask,
885
- }
886
-
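# Illustrative sketch (not part of the original file): once `past_key_values` exists, generation
# feeds back only the newest token id, mirroring prepare_inputs_for_generation above. `model` and
# `input_ids` are assumed to be an already-loaded FalconForCausalLM and a [batch, seq] LongTensor.
#
#   outputs = model(input_ids, use_cache=True)
#   next_token = outputs.logits[:, -1:].argmax(dim=-1)
#   step_inputs = model.prepare_inputs_for_generation(
#       torch.cat([input_ids, next_token], dim=1),
#       past_key_values=outputs.past_key_values,
#   )
#   assert step_inputs["input_ids"].shape[1] == 1  # only the newest token is passed in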
887
- @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
888
- @add_code_sample_docstrings(
889
- checkpoint=_CHECKPOINT_FOR_DOC,
890
- output_type=CausalLMOutputWithCrossAttentions,
891
- config_class=_CONFIG_FOR_DOC,
892
- )
893
- def forward(
894
- self,
895
- input_ids: Optional[torch.LongTensor] = None,
896
- past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
897
- attention_mask: Optional[torch.Tensor] = None,
898
- head_mask: Optional[torch.Tensor] = None,
899
- inputs_embeds: Optional[torch.Tensor] = None,
900
- labels: Optional[torch.Tensor] = None,
901
- use_cache: Optional[bool] = None,
902
- output_attentions: Optional[bool] = None,
903
- output_hidden_states: Optional[bool] = None,
904
- return_dict: Optional[bool] = None,
905
- ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
906
- r"""
907
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
908
- Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
909
- `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
910
- are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
911
- """
912
-
913
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
914
-
915
- transformer_outputs = self.transformer(
916
- input_ids,
917
- past_key_values=past_key_values,
918
- attention_mask=attention_mask,
919
- head_mask=head_mask,
920
- inputs_embeds=inputs_embeds,
921
- use_cache=use_cache,
922
- output_attentions=output_attentions,
923
- output_hidden_states=output_hidden_states,
924
- return_dict=return_dict,
925
- )
926
- hidden_states = transformer_outputs[0]
927
-
928
- lm_logits = self.lm_head(hidden_states)
929
-
930
- loss = None
931
- if labels is not None:
932
- # Shift so that tokens < n predict n
933
- shift_logits = lm_logits[..., :-1, :].contiguous()
934
- shift_labels = labels[..., 1:].contiguous()
935
- batch_size, seq_length, vocab_size = shift_logits.shape
936
- # Flatten the tokens
937
- loss_fct = CrossEntropyLoss()
938
- loss = loss_fct(
939
- shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
940
- )
941
-
942
- if not return_dict:
943
- output = (lm_logits,) + transformer_outputs[1:]
944
- return ((loss,) + output) if loss is not None else output
945
-
946
- return CausalLMOutputWithCrossAttentions(
947
- loss=loss,
948
- logits=lm_logits,
949
- past_key_values=transformer_outputs.past_key_values,
950
- hidden_states=transformer_outputs.hidden_states,
951
- attentions=transformer_outputs.attentions,
952
- )
953
-
954
- def _reorder_cache(
955
- self, past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
956
- ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
957
- """
958
- This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
959
- [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
960
- beam_idx at every generation step.
961
-
962
- Output shares the same memory storage as `past`.
963
- """
964
-
965
- # Get a copy of `beam_idx` on all the devices where we need those indices.
966
- device_to_beam_idx = {
967
- past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past
968
- }
969
- reordered_past = tuple(
970
- (
971
- layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
972
- layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
973
- )
974
- for layer_past in past
975
- )
976
- return reordered_past
977
-
978
-
979
- @add_start_docstrings(
980
- """
981
- The Falcon Model transformer with a sequence classification head on top (linear layer).
982
-
983
- [`FalconForSequenceClassification`] uses the last token in order to do the classification, as other causal models
984
- (e.g. GPT-1) do.
985
-
986
- Since it does classification on the last token, it requires to know the position of the last token. If a
987
- `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
988
- no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
989
- padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
990
- each row of the batch).
991
- """,
992
- FALCON_START_DOCSTRING,
993
- )
994
- class FalconForSequenceClassification(FalconPreTrainedModel):
995
- def __init__(self, config: FalconConfig):
996
- super().__init__(config)
997
- self.num_labels = config.num_labels
998
- self.transformer = FalconModel(config)
999
- self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
1000
-
1001
- # Initialize weights and apply final processing
1002
- self.post_init()
1003
-
1004
- @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
1005
- @add_code_sample_docstrings(
1006
- checkpoint=_CHECKPOINT_FOR_DOC,
1007
- output_type=SequenceClassifierOutputWithPast,
1008
- config_class=_CONFIG_FOR_DOC,
1009
- )
1010
- def forward(
1011
- self,
1012
- input_ids: Optional[torch.LongTensor] = None,
1013
- past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
1014
- attention_mask: Optional[torch.Tensor] = None,
1015
- head_mask: Optional[torch.Tensor] = None,
1016
- inputs_embeds: Optional[torch.Tensor] = None,
1017
- labels: Optional[torch.Tensor] = None,
1018
- use_cache: Optional[bool] = None,
1019
- output_attentions: Optional[bool] = None,
1020
- output_hidden_states: Optional[bool] = None,
1021
- return_dict: Optional[bool] = None,
1022
- ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
1023
- r"""
1024
- labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1025
- Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1026
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1027
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1028
- """
1029
-
1030
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1031
-
1032
- transformer_outputs = self.transformer(
1033
- input_ids,
1034
- past_key_values=past_key_values,
1035
- attention_mask=attention_mask,
1036
- head_mask=head_mask,
1037
- inputs_embeds=inputs_embeds,
1038
- use_cache=use_cache,
1039
- output_attentions=output_attentions,
1040
- output_hidden_states=output_hidden_states,
1041
- return_dict=return_dict,
1042
- )
1043
-
1044
- hidden_states = transformer_outputs[0]
1045
- logits = self.score(hidden_states)
1046
-
1047
- if input_ids is not None:
1048
- batch_size = input_ids.shape[0]
1049
- else:
1050
- batch_size = inputs_embeds.shape[0]
1051
-
1052
- if self.config.pad_token_id is None and batch_size != 1:
1053
- raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1054
- if self.config.pad_token_id is None:
1055
- sequence_lengths = -1
1056
- else:
1057
- if input_ids is not None:
1058
- sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(dim=-1) - 1
1059
- else:
1060
- sequence_lengths = -1
1061
- logger.warning(
1062
- f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
1063
- "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
1064
- )
1065
-
1066
- pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1067
-
1068
- loss = None
1069
- if labels is not None:
1070
- if self.config.problem_type is None:
1071
- if self.num_labels == 1:
1072
- self.config.problem_type = "regression"
1073
- elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1074
- self.config.problem_type = "single_label_classification"
1075
- else:
1076
- self.config.problem_type = "multi_label_classification"
1077
-
1078
- if self.config.problem_type == "regression":
1079
- loss_fct = MSELoss()
1080
- if self.num_labels == 1:
1081
- loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1082
- else:
1083
- loss = loss_fct(pooled_logits, labels)
1084
- elif self.config.problem_type == "single_label_classification":
1085
- loss_fct = CrossEntropyLoss()
1086
- loss = loss_fct(pooled_logits, labels)
1087
- elif self.config.problem_type == "multi_label_classification":
1088
- loss_fct = BCEWithLogitsLoss()
1089
- loss = loss_fct(pooled_logits, labels)
1090
- if not return_dict:
1091
- output = (pooled_logits,) + transformer_outputs[1:]
1092
- return ((loss,) + output) if loss is not None else output
1093
-
1094
- return SequenceClassifierOutputWithPast(
1095
- loss=loss,
1096
- logits=pooled_logits,
1097
- past_key_values=transformer_outputs.past_key_values,
1098
- hidden_states=transformer_outputs.hidden_states,
1099
- attentions=transformer_outputs.attentions,
1100
- )
1101
-
1102
-
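# Illustrative sketch (not part of the original file): how the classifier above locates the last
# non-padding token per row when a pad_token_id is configured (0 is an assumed toy pad id).
demo_input_ids = torch.tensor([[5, 6, 7, 0, 0], [3, 4, 8, 9, 2]])
demo_sequence_lengths = torch.ne(demo_input_ids, 0).sum(dim=-1) - 1
assert demo_sequence_lengths.tolist() == [2, 4]  # index of the token whose logits are pooled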
1103
- @add_start_docstrings(
1104
- """
1105
- Falcon Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
1106
- Named-Entity-Recognition (NER) tasks.
1107
- """,
1108
- FALCON_START_DOCSTRING,
1109
- )
1110
- class FalconForTokenClassification(FalconPreTrainedModel):
1111
- def __init__(self, config: FalconConfig):
1112
- super().__init__(config)
1113
- self.num_labels = config.num_labels
1114
-
1115
- self.transformer = FalconModel(config)
1116
- if getattr(config, "classifier_dropout", None) is not None:
1117
- classifier_dropout = config.classifier_dropout
1118
- elif getattr(config, "hidden_dropout", None) is not None:
1119
- classifier_dropout = config.hidden_dropout
1120
- else:
1121
- classifier_dropout = 0.1
1122
- self.dropout = nn.Dropout(classifier_dropout)
1123
- self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1124
-
1125
- # Initialize weights and apply final processing
1126
- self.post_init()
1127
-
1128
- @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
1129
- @add_code_sample_docstrings(
1130
- checkpoint=_CHECKPOINT_FOR_DOC,
1131
- output_type=TokenClassifierOutput,
1132
- config_class=_CONFIG_FOR_DOC,
1133
- )
1134
- def forward(
1135
- self,
1136
- input_ids: Optional[torch.LongTensor] = None,
1137
- past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
1138
- attention_mask: Optional[torch.Tensor] = None,
1139
- head_mask: Optional[torch.Tensor] = None,
1140
- inputs_embeds: Optional[torch.Tensor] = None,
1141
- labels: Optional[torch.Tensor] = None,
1142
- use_cache: Optional[bool] = None,
1143
- output_attentions: Optional[bool] = None,
1144
- output_hidden_states: Optional[bool] = None,
1145
- return_dict: Optional[bool] = None,
1146
- ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
1147
- r"""
1148
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1149
- Labels for computing the token classification loss. Indices should be in
1150
- `[0, ..., config.num_labels - 1]`.
1152
- """
1153
-
1154
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1155
-
1156
- transformer_outputs = self.transformer(
1157
- input_ids,
1158
- past_key_values=past_key_values,
1159
- attention_mask=attention_mask,
1160
- head_mask=head_mask,
1161
- inputs_embeds=inputs_embeds,
1162
- use_cache=use_cache,
1163
- output_attentions=output_attentions,
1164
- output_hidden_states=output_hidden_states,
1165
- return_dict=return_dict,
1166
- )
1167
-
1168
- hidden_states = transformer_outputs[0]
1169
- hidden_states = self.dropout(hidden_states)
1170
- logits = self.classifier(hidden_states)
1171
-
1172
- loss = None
1173
- if labels is not None:
1174
- batch_size, seq_length = labels.shape
1175
- loss_fct = CrossEntropyLoss()
1176
- loss = loss_fct(
1177
- logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length)
1178
- )
1179
-
1180
- if not return_dict:
1181
- output = (logits,) + transformer_outputs[2:]
1182
- return ((loss,) + output) if loss is not None else output
1183
-
1184
- return TokenClassifierOutput(
1185
- loss=loss,
1186
- logits=logits,
1187
- hidden_states=transformer_outputs.hidden_states,
1188
- attentions=transformer_outputs.attentions,
1189
- )
1190
-
1191
-
1192
- @add_start_docstrings(
1193
- """
1194
- The Falcon Model transformer with a span classification head on top for extractive question-answering tasks like
1195
- SQuAD (a linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
1196
- """,
1197
- FALCON_START_DOCSTRING,
1198
- )
1199
- class FalconForQuestionAnswering(FalconPreTrainedModel):
1200
- def __init__(self, config):
1201
- super().__init__(config)
1202
- self.transformer = FalconModel(config)
1203
- self.qa_outputs = nn.Linear(config.hidden_size, 2)
1204
-
1205
- # Initialize weights and apply final processing
1206
- self.post_init()
1207
-
1208
- @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
1209
- def forward(
1210
- self,
1211
- input_ids: Optional[torch.LongTensor] = None,
1212
- attention_mask: Optional[torch.FloatTensor] = None,
1213
- head_mask: Optional[torch.FloatTensor] = None,
1214
- inputs_embeds: Optional[torch.FloatTensor] = None,
1215
- start_positions: Optional[torch.LongTensor] = None,
1216
- end_positions: Optional[torch.LongTensor] = None,
1217
- output_attentions: Optional[bool] = None,
1218
- output_hidden_states: Optional[bool] = None,
1219
- return_dict: Optional[bool] = None,
1220
- ) -> Union[Tuple, QuestionAnsweringModelOutput]:
1221
- r"""
1222
- start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1223
- Labels for position (index) of the start of the labelled span for computing the token classification loss.
1224
- Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
1225
- are not taken into account for computing the loss.
1226
- end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1227
- Labels for position (index) of the end of the labelled span for computing the token classification loss.
1228
- Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
1229
- are not taken into account for computing the loss.
1230
- """
1231
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1232
-
1233
- outputs = self.transformer(
1234
- input_ids,
1235
- attention_mask=attention_mask,
1236
- head_mask=head_mask,
1237
- inputs_embeds=inputs_embeds,
1238
- output_attentions=output_attentions,
1239
- output_hidden_states=output_hidden_states,
1240
- return_dict=return_dict,
1241
- )
1242
-
1243
- sequence_output = outputs[0]
1244
-
1245
- logits = self.qa_outputs(sequence_output)
1246
- start_logits, end_logits = logits.split(1, dim=-1)
1247
- start_logits = start_logits.squeeze(-1).contiguous()
1248
- end_logits = end_logits.squeeze(-1).contiguous()
1249
-
1250
- total_loss = None
1251
- if start_positions is not None and end_positions is not None:
1252
- # If we are on multi-GPU, the position tensors may have an extra dimension; squeeze it
1253
- if len(start_positions.size()) > 1:
1254
- start_positions = start_positions.squeeze(-1)
1255
- if len(end_positions.size()) > 1:
1256
- end_positions = end_positions.squeeze(-1)
1257
- # sometimes the start/end positions are outside our model inputs, we ignore these terms
1258
- ignored_index = start_logits.size(1)
1259
- start_positions = start_positions.clamp(0, ignored_index)
1260
- end_positions = end_positions.clamp(0, ignored_index)
1261
-
1262
- loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1263
- start_loss = loss_fct(start_logits, start_positions)
1264
- end_loss = loss_fct(end_logits, end_positions)
1265
- total_loss = (start_loss + end_loss) / 2
1266
-
1267
- if not return_dict:
1268
- output = (start_logits, end_logits) + outputs[2:]
1269
- return ((total_loss,) + output) if total_loss is not None else output
1270
-
1271
- return QuestionAnsweringModelOutput(
1272
- loss=total_loss,
1273
- start_logits=start_logits,
1274
- end_logits=end_logits,
1275
- hidden_states=outputs.hidden_states,
1276
- attentions=outputs.attentions,
1277
- )