klemenk committed on
Commit 29b3bed · verified · 1 Parent(s): 13e5b3f

Update modeling.py

Files changed (1)
  1. modeling.py +453 -101
modeling.py CHANGED
@@ -1,130 +1,482 @@
 """
-GSLM Model Configuration
 """

-import json
 import os
-from typing import Optional


-class GSLMConfig:
-    """
-    Configuration class for GSLM (Generative Spoken Language Model).

-    This configuration class stores all parameters needed to initialize a GSLMModel.
-    """

-    model_type = "gslm"

     def __init__(
-        self,
-        vocab_size: int = 204,
-        d_model: int = 1024,
-        nhead: int = 16,
-        num_layers: int = 12,
-        dim_feedforward: int = 4096,
         dropout: float = 0.1,
         attention_dropout: float = 0.1,
-        max_seq_length: int = 3072,
-        pad_idx: int = 0,
-        share_input_output_embed: bool = True,
-        activation: str = "relu",
-        architecture: str = "transformer_lm_big",
-        **kwargs
     ):
-        """
-        Initialize GSLM configuration.

         Args:
-            vocab_size: Size of the vocabulary
-            d_model: Dimensionality of the embeddings and hidden states
-            nhead: Number of attention heads
-            num_layers: Number of transformer layers
-            dim_feedforward: Dimensionality of the feedforward network
-            dropout: Dropout probability
-            attention_dropout: Dropout probability for attention weights
-            max_seq_length: Maximum sequence length
-            pad_idx: Padding token index
-            share_input_output_embed: Whether to share input and output embeddings
-            activation: Activation function ("relu" or "gelu")
-            architecture: Model architecture name
         """
-        self.vocab_size = vocab_size
-        self.d_model = d_model
-        self.nhead = nhead
-        self.num_layers = num_layers
-        self.dim_feedforward = dim_feedforward
-        self.dropout = dropout
-        self.attention_dropout = attention_dropout
-        self.max_seq_length = max_seq_length
-        self.pad_idx = pad_idx
-        self.share_input_output_embed = share_input_output_embed
-        self.activation = activation
-        self.architecture = architecture
-
-        # Handle any extra kwargs
-        for key, value in kwargs.items():
-            setattr(self, key, value)
-
-    def to_dict(self):
-        """Convert configuration to dictionary."""
-        output = {}
-        for key, value in self.__dict__.items():
-            if not key.startswith('_'):
-                output[key] = value
-        output['model_type'] = self.model_type
-        return output
-
-    def to_json_string(self):
-        """Convert configuration to JSON string."""
-        return json.dumps(self.to_dict(), indent=2, sort_keys=True)

-    def save_pretrained(self, save_directory: str):
-        """Save configuration to directory."""
-        if not os.path.exists(save_directory):
-            os.makedirs(save_directory)

-        config_file = os.path.join(save_directory, "config.json")
-        with open(config_file, 'w') as f:
-            f.write(self.to_json_string())
-
-    @classmethod
-    def from_dict(cls, config_dict: dict):
-        """Create configuration from dictionary."""
-        return cls(**config_dict)
-
-    @classmethod
-    def from_json_file(cls, json_file: str):
-        """Create configuration from JSON file."""
-        with open(json_file, 'r') as f:
-            config_dict = json.load(f)
-        return cls.from_dict(config_dict)
-
-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
         """
-        Load configuration from pretrained model.

         Args:
-            pretrained_model_name_or_path: Path to pretrained model or model identifier
-            **kwargs: Additional configuration parameters to override

         Returns:
-            GSLMConfig instance
         """
         if os.path.isdir(pretrained_model_name_or_path):
-            config_file = os.path.join(pretrained_model_name_or_path, "config.json")
         else:
-            config_file = pretrained_model_name_or_path

-        # Load config from file
-        config = cls.from_json_file(config_file)
-
-        # Override with any provided kwargs
-        for key, value in kwargs.items():
-            setattr(config, key, value)

-        return config

-    def __repr__(self):
-        return f"{self.__class__.__name__} {self.to_json_string()}"

 """
+GSLM Unit Language Model - HuggingFace Compatible Implementation
+Based on fairseq's transformer_lm_big architecture
 """

+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
 import os
+import json
+from typing import Optional, Tuple, Dict, Union, List
+from dataclasses import dataclass

+# Import config - handle both local and remote imports
+try:
+    from .config import GSLMConfig
+except ImportError:
+    # Fallback for when file is accessed directly
+    from config import GSLMConfig

+# Import or define the output classes
+@dataclass
+class BaseModelOutput:
+    last_hidden_state: torch.FloatTensor
+    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+@dataclass
+class CausalLMOutput:
+    loss: Optional[torch.FloatTensor] = None
+    logits: Union[torch.FloatTensor, List[torch.FloatTensor]] = None
+    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
+class PositionalEncoding(nn.Module):
+    """Sinusoidal positional encoding for transformer models."""

+    def __init__(self, d_model: int, max_len: int = 5000):
+        super().__init__()
+        pe = torch.zeros(max_len, d_model)
+        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
+        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
+                             (-math.log(10000.0) / d_model))
+        pe[:, 0::2] = torch.sin(position * div_term)
+        pe[:, 1::2] = torch.cos(position * div_term)
+        self.register_buffer('pe', pe.unsqueeze(0))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Add positional encoding to input tensor."""
+        return x + self.pe[:, :x.size(1)]
+
+
+class MultiheadAttention(nn.Module):
+    """Multi-head attention mechanism."""

+    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0):
+        super().__init__()
+        assert embed_dim % num_heads == 0
+        self.embed_dim = embed_dim
+        self.num_heads = num_heads
+        self.head_dim = embed_dim // num_heads
+        self.scaling = self.head_dim ** -0.5
+
+        self.k_proj = nn.Linear(embed_dim, embed_dim)
+        self.v_proj = nn.Linear(embed_dim, embed_dim)
+        self.q_proj = nn.Linear(embed_dim, embed_dim)
+        self.out_proj = nn.Linear(embed_dim, embed_dim)
+        self.attn_dropout = nn.Dropout(dropout)
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: Optional[torch.Tensor] = None,
+        value: Optional[torch.Tensor] = None,
+        attn_mask: Optional[torch.Tensor] = None,
+        key_padding_mask: Optional[torch.Tensor] = None
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Args:
+            query: [batch_size, tgt_len, embed_dim]
+            key: [batch_size, src_len, embed_dim]
+            value: [batch_size, src_len, embed_dim]
+            attn_mask: [tgt_len, src_len] or [batch_size * num_heads, tgt_len, src_len]
+            key_padding_mask: [batch_size, src_len]
+        """
+        if key is None:
+            key = query
+        if value is None:
+            value = query
+
+        batch_size, tgt_len, embed_dim = query.size()
+        src_len = key.size(1)
+
+        # Project and reshape
+        q = self.q_proj(query) * self.scaling
+        k = self.k_proj(key)
+        v = self.v_proj(value)
+
+        q = q.view(batch_size, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
+        k = k.view(batch_size, src_len, self.num_heads, self.head_dim).transpose(1, 2)
+        v = v.view(batch_size, src_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+        # Compute attention scores
+        attn_weights = torch.matmul(q, k.transpose(-2, -1))
+
+        # Apply masks
+        if attn_mask is not None:
+            if attn_mask.dim() == 2:
+                attn_mask = attn_mask.unsqueeze(0).unsqueeze(0)
+            attn_weights = attn_weights + attn_mask
+
+        if key_padding_mask is not None:
+            attn_weights = attn_weights.masked_fill(
+                key_padding_mask.unsqueeze(1).unsqueeze(2),
+                float('-inf')
+            )
+
+        # Softmax
+        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(attn_weights)
+        attn_weights = self.attn_dropout(attn_weights)
+
+        # Apply attention to values
+        attn_output = torch.matmul(attn_weights, v)
+        attn_output = attn_output.transpose(1, 2).contiguous().view(
+            batch_size, tgt_len, embed_dim
+        )
+        attn_output = self.out_proj(attn_output)
+
+        return attn_output, attn_weights
+
+
+class TransformerDecoderLayer(nn.Module):
+    """Transformer decoder layer."""

     def __init__(
+        self,
+        d_model: int,
+        nhead: int,
+        dim_feedforward: int = 2048,
         dropout: float = 0.1,
         attention_dropout: float = 0.1,
+        activation: str = "relu"
     ):
+        super().__init__()
+        self.self_attn = MultiheadAttention(d_model, nhead, dropout=attention_dropout)
+
+        # Feedforward network
+        self.linear1 = nn.Linear(d_model, dim_feedforward)
+        self.dropout = nn.Dropout(dropout)
+        self.linear2 = nn.Linear(dim_feedforward, d_model)

+        # Layer normalization
+        self.norm1 = nn.LayerNorm(d_model)
+        self.norm2 = nn.LayerNorm(d_model)
+
+        # Dropout modules
+        self.dropout1 = nn.Dropout(dropout)
+        self.dropout2 = nn.Dropout(dropout)
+
+        # Activation
+        self.activation = F.relu if activation == "relu" else F.gelu
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        self_attn_mask: Optional[torch.Tensor] = None,
+        self_attn_padding_mask: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
+        """
         Args:
+            x: [batch_size, seq_len, d_model]
+            self_attn_mask: [seq_len, seq_len]
+            self_attn_padding_mask: [batch_size, seq_len]
         """
+        # Self-attention block
+        residual = x
+        x = self.norm1(x)
+        x, _ = self.self_attn(x, x, x, self_attn_mask, self_attn_padding_mask)
+        x = self.dropout1(x)
+        x = residual + x
+
+        # Feedforward block
+        residual = x
+        x = self.norm2(x)
+        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
+        x = self.dropout2(x)
+        x = residual + x
+
+        return x
+
+
+class GSLMForCausalLM(nn.Module):
+    """
+    GSLM Unit Language Model - Transformer LM Big Architecture
+    HuggingFace compatible version with modified forward API
+    """

+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+
+        self.d_model = config.d_model
+        self.vocab_size = config.vocab_size
+        self.pad_idx = getattr(config, 'pad_idx', 0)
+        self.max_seq_length = config.max_seq_length
+
+        # Create transformer module container for compatibility
+        self.transformer = nn.Module()
+
+        # Token embeddings (wte for compatibility)
+        self.transformer.wte = nn.Embedding(config.vocab_size, config.d_model, padding_idx=self.pad_idx)
+        self.embed_scale = math.sqrt(config.d_model)
+
+        # Positional encoding
+        self.pos_encoder = PositionalEncoding(config.d_model, config.max_seq_length)
+
+        # Transformer decoder layers (h for compatibility)
+        self.transformer.h = nn.ModuleList([
+            TransformerDecoderLayer(
+                config.d_model,
+                config.nhead,
+                config.dim_feedforward,
+                config.dropout,
+                config.attention_dropout
+            ) for _ in range(config.num_layers)
+        ])
+
+        # Final layer norm (ln_f for compatibility)
+        self.transformer.ln_f = nn.LayerNorm(config.d_model)
+
+        # Output projection (coch_head for compatibility)
+        if config.share_input_output_embed:
+            self.coch_head = lambda x: F.linear(x, self.transformer.wte.weight)
+        else:
+            self.coch_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

+        # Dropout
+        self.transformer.drop = nn.Dropout(config.dropout)
+
+        # Future heads not supported in GSLM
+        self.future_heads = None
+
+        # Initialize parameters
+        self.init_weights()
+
+    def init_weights(self):
+        """Initialize model parameters."""
+        # Initialize embeddings
+        nn.init.normal_(self.transformer.wte.weight, mean=0, std=self.d_model ** -0.5)
+        nn.init.constant_(self.transformer.wte.weight[self.pad_idx], 0)
+
+        # Initialize output projection if not shared
+        if not self.config.share_input_output_embed:
+            nn.init.normal_(self.coch_head.weight, mean=0, std=self.d_model ** -0.5)
+
+    def _create_causal_mask(self, seq_len: int, device) -> torch.Tensor:
+        """Create causal attention mask."""
+        mask = torch.triu(
+            torch.full((seq_len, seq_len), float('-inf'), device=device),
+            diagonal=1
+        )
+        return mask
+
+    def forward(
+        self,
+        seq,
+        tgt=None,
+        output_logits=False,
+        output_hidden_states=False,
+        return_dict=False,
+        up_until_layer=None
+    ):
         """
+        Compatible forward method with the specified API.

         Args:
+            seq: torch.Tensor of shape (b, t) - input token IDs
+            tgt: torch.Tensor of shape (b, t) or None - target token IDs
+            output_logits: bool - whether to output logits
+            output_hidden_states: bool - whether to output all hidden states
+            return_dict: bool - whether to return dictionary output
+            up_until_layer: int or None - stop at specific layer

         Returns:
+            Depending on return_dict and other flags
         """
+        batch_size, seq_len = seq.shape
+        device = seq.device
+
+        # Create causal mask
+        causal_mask = self._create_causal_mask(seq_len, device)
+
+        # Create padding mask
+        padding_mask = seq.eq(self.pad_idx)
+
+        # Token embeddings
+        tok_emb = self.transformer.wte(seq) * self.embed_scale
+
+        # Add positional encoding (sinusoidal, not learned)
+        x = self.pos_encoder(tok_emb)
+        x = self.transformer.drop(x)
+
+        all_hidden_states = []
+
+        # Pass through transformer layers
+        for block_idx, block in enumerate(self.transformer.h):
+            # Save hidden state before block
+            if output_hidden_states:
+                all_hidden_states.append(x)
+
+            # Check if we should stop early
+            if up_until_layer is not None and block_idx == up_until_layer:
+                break
+
+            # Forward the block
+            x = block(x, causal_mask, padding_mask)
+
+        # Append the last hidden state if we didn't exit early
+        if output_hidden_states and (up_until_layer is None or block_idx == len(self.transformer.h) - 1):
+            all_hidden_states.append(x)
+
+        # If only hidden states requested
+        if output_hidden_states and not output_logits and tgt is None:
+            model_output = BaseModelOutput(
+                last_hidden_state=x,
+                hidden_states=tuple(all_hidden_states) if all_hidden_states else None,
+            )
+            return model_output
+
+        # Final layer norm
+        x = self.transformer.ln_f(x)
+
+        # Compute logits
+        logits = self.coch_head(x)
+
+        # Compute loss if targets provided
+        if tgt is not None:
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = tgt[..., 1:].contiguous()
+
+            loss = F.cross_entropy(
+                shift_logits.reshape(-1, self.config.vocab_size),
+                shift_labels.reshape(-1),
+                ignore_index=self.pad_idx
+            )
+
+            if return_dict:
+                if output_logits:
+                    # For compatibility, wrap single logits in list
+                    all_logits = [logits]
+
+                if output_hidden_states:
+                    model_output = CausalLMOutput(
+                        loss=loss,
+                        logits=all_logits if output_logits else logits,
+                        hidden_states=tuple(all_hidden_states) if all_hidden_states else None,
+                    )
+                else:
+                    model_output = CausalLMOutput(
+                        loss=loss,
+                        logits=all_logits if output_logits else logits,
+                    )
+                return model_output
+
+            return logits, loss
+
+        return logits, None
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, config=None, **kwargs):
+        """Load model from pretrained weights."""
+        import os
+        from huggingface_hub import hf_hub_download
+
+        # Load config if not provided
+        if config is None:
+            if os.path.isdir(pretrained_model_name_or_path):
+                config_path = os.path.join(pretrained_model_name_or_path, "config.json")
+                config = GSLMConfig.from_pretrained(config_path)
+            else:
+                # Download config from hub
+                config_path = hf_hub_download(
+                    repo_id=pretrained_model_name_or_path,
+                    filename="config.json"
+                )
+                config = GSLMConfig.from_pretrained(config_path)
+
+        # Create model
+        model = cls(config)
+
+        # Load weights
         if os.path.isdir(pretrained_model_name_or_path):
+            weights_file = os.path.join(pretrained_model_name_or_path, "model.safetensors")
         else:
+            # Download weights from hub
+            weights_file = hf_hub_download(
+                repo_id=pretrained_model_name_or_path,
+                filename="model.safetensors"
+            )

+        if weights_file.endswith('.safetensors'):
+            from safetensors.torch import load_file
+            state_dict = load_file(weights_file)
+        else:
+            state_dict = torch.load(weights_file, map_location='cpu')

+        model.load_state_dict(state_dict)
+
+        return model

+    @torch.no_grad()
+    def generate(
+        self,
+        input_ids: torch.Tensor,
+        max_length: int = 100,
+        temperature: float = 1.0,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        pad_token_id: Optional[int] = None,
+        eos_token_id: Optional[int] = None
+    ) -> torch.Tensor:
+        """Generate sequences using the language model."""
+        if pad_token_id is None:
+            pad_token_id = self.pad_idx
+
+        batch_size = input_ids.shape[0]
+        device = input_ids.device
+
+        # Keep track of which sequences are done
+        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=device)
+
+        while input_ids.shape[1] < max_length:
+            # Forward pass
+            logits, _ = self.forward(input_ids)
+            next_token_logits = logits[:, -1, :]
+
+            # Apply temperature
+            if temperature != 1.0:
+                next_token_logits = next_token_logits / temperature
+
+            # Apply top-k sampling
+            if top_k is not None:
+                indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
+                next_token_logits[indices_to_remove] = -float('inf')
+
+            # Apply top-p (nucleus) sampling
+            if top_p is not None:
+                sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
+                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+
+                # Remove tokens with cumulative probability above the threshold
+                sorted_indices_to_remove = cumulative_probs > top_p
+                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+                sorted_indices_to_remove[..., 0] = 0
+
+                indices_to_remove = sorted_indices_to_remove.scatter(
+                    dim=-1, index=sorted_indices, src=sorted_indices_to_remove
+                )
+                next_token_logits[indices_to_remove] = -float('inf')
+
+            # Sample from the distribution
+            probs = F.softmax(next_token_logits, dim=-1)
+            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(-1)
+
+            # Update unfinished sequences
+            if eos_token_id is not None:
+                tokens_to_add = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
+                unfinished_sequences = unfinished_sequences * (next_tokens != eos_token_id).long()
+            else:
+                tokens_to_add = next_tokens
+
+            # Concatenate tokens
+            input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1)
+
+            # Stop if all sequences are finished
+            if eos_token_id is not None and unfinished_sequences.sum() == 0:
+                break
+
+        return input_ids
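
For orientation, a minimal usage sketch of the API this commit introduces. The checkpoint path, token IDs, and sampling settings below are hypothetical placeholders; only the class name, forward/tuple API, and generate signature come from the diff above.

import torch
from modeling import GSLMForCausalLM

# Hypothetical checkpoint location - substitute a real repo id or local directory.
model = GSLMForCausalLM.from_pretrained("path/to/gslm-checkpoint")
model.eval()

# A batch of unit-token IDs (placeholder values; the old config's default vocab_size was 204).
seq = torch.tensor([[17, 42, 42, 99, 3]])

# Default tuple API: logits of shape (1, 5, vocab_size); loss is None without targets.
logits, _ = model(seq)

# Passing tgt computes the shifted next-token cross-entropy loss.
_, loss = model(seq, tgt=seq)

# Ancestral sampling with nucleus filtering.
generated = model.generate(seq, max_length=50, top_p=0.9)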