klemenk committed
Commit d65cc40 · verified · 1 Parent(s): 8838b95

Rename modeling_gslm_ulm.py to modeling.py

Files changed (2)
  1. modeling.py +130 -0
  2. modeling_gslm_ulm.py +0 -457
modeling.py ADDED
@@ -0,0 +1,130 @@
+ """
+ GSLM Model Configuration
+ """
+
+ import json
+ import os
+ from typing import Optional
+
+
+ class GSLMConfig:
+     """
+     Configuration class for GSLM (Generative Spoken Language Model).
+
+     This configuration class stores all parameters needed to initialize a GSLMModel.
+     """
+
+     model_type = "gslm"
+
+     def __init__(
+         self,
+         vocab_size: int = 204,
+         d_model: int = 1024,
+         nhead: int = 16,
+         num_layers: int = 12,
+         dim_feedforward: int = 4096,
+         dropout: float = 0.1,
+         attention_dropout: float = 0.1,
+         max_seq_length: int = 3072,
+         pad_idx: int = 0,
+         share_input_output_embed: bool = True,
+         activation: str = "relu",
+         architecture: str = "transformer_lm_big",
+         **kwargs
+     ):
+         """
+         Initialize GSLM configuration.
+
+         Args:
+             vocab_size: Size of the vocabulary
+             d_model: Dimensionality of the embeddings and hidden states
+             nhead: Number of attention heads
+             num_layers: Number of transformer layers
+             dim_feedforward: Dimensionality of the feedforward network
+             dropout: Dropout probability
+             attention_dropout: Dropout probability for attention weights
+             max_seq_length: Maximum sequence length
+             pad_idx: Padding token index
+             share_input_output_embed: Whether to share input and output embeddings
+             activation: Activation function ("relu" or "gelu")
+             architecture: Model architecture name
+         """
+         self.vocab_size = vocab_size
+         self.d_model = d_model
+         self.nhead = nhead
+         self.num_layers = num_layers
+         self.dim_feedforward = dim_feedforward
+         self.dropout = dropout
+         self.attention_dropout = attention_dropout
+         self.max_seq_length = max_seq_length
+         self.pad_idx = pad_idx
+         self.share_input_output_embed = share_input_output_embed
+         self.activation = activation
+         self.architecture = architecture
+
+         # Handle any extra kwargs
+         for key, value in kwargs.items():
+             setattr(self, key, value)
+
+     def to_dict(self):
+         """Convert configuration to dictionary."""
+         output = {}
+         for key, value in self.__dict__.items():
+             if not key.startswith('_'):
+                 output[key] = value
+         output['model_type'] = self.model_type
+         return output
+
+     def to_json_string(self):
+         """Convert configuration to JSON string."""
+         return json.dumps(self.to_dict(), indent=2, sort_keys=True)
+
+     def save_pretrained(self, save_directory: str):
+         """Save configuration to directory."""
+         if not os.path.exists(save_directory):
+             os.makedirs(save_directory)
+
+         config_file = os.path.join(save_directory, "config.json")
+         with open(config_file, 'w') as f:
+             f.write(self.to_json_string())
+
+     @classmethod
+     def from_dict(cls, config_dict: dict):
+         """Create configuration from dictionary."""
+         return cls(**config_dict)
+
+     @classmethod
+     def from_json_file(cls, json_file: str):
+         """Create configuration from JSON file."""
+         with open(json_file, 'r') as f:
+             config_dict = json.load(f)
+         return cls.from_dict(config_dict)
+
+     @classmethod
+     def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
+         """
+         Load configuration from pretrained model.
+
+         Args:
+             pretrained_model_name_or_path: Path to pretrained model or model identifier
+             **kwargs: Additional configuration parameters to override
+
+         Returns:
+             GSLMConfig instance
+         """
+         if os.path.isdir(pretrained_model_name_or_path):
+             config_file = os.path.join(pretrained_model_name_or_path, "config.json")
+         else:
+             config_file = pretrained_model_name_or_path
+
+         # Load config from file
+         config = cls.from_json_file(config_file)
+
+         # Override with any provided kwargs
+         for key, value in kwargs.items():
+             setattr(config, key, value)
+
+         return config
+
+     def __repr__(self):
+         return f"{self.__class__.__name__} {self.to_json_string()}"
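
For reference, a minimal round trip with the new GSLMConfig might look like the sketch below; the "./gslm_config" path and the override values are hypothetical, chosen only for illustration.

from modeling import GSLMConfig

# Build a config, overriding two defaults
config = GSLMConfig(vocab_size=104, num_layers=6)

# Write config.json, then reload it; keyword arguments override loaded values
config.save_pretrained("./gslm_config")
reloaded = GSLMConfig.from_pretrained("./gslm_config", dropout=0.0)

assert reloaded.vocab_size == 104 and reloaded.dropout == 0.0
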
modeling_gslm_ulm.py DELETED
@@ -1,457 +0,0 @@
- """
- GSLM ULM model definition for HuggingFace.
- """
-
- import math
- import torch
- import torch.nn as nn
- from torch.nn import functional as F
- from typing import Optional, Tuple, Dict, List, Union
- from transformers import PreTrainedModel
- from transformers.modeling_outputs import BaseModelOutput, CausalLMOutput
- from .configuration_gslm_ulm import GSLMULMConfig
-
-
- class GSLMULM(PreTrainedModel):
-     """
-     GSLM Unit Language Model - Transformer Language Model for discrete speech units.
-     """
-     config_class = GSLMULMConfig
-     base_model_prefix = "transformer"
-     supports_gradient_checkpointing = True
-
-     def __init__(self, config):
-         super().__init__(config)
-         self.config = config
-
-         self.transformer = nn.ModuleDict(dict(
-             wte = nn.Embedding(config.vocab_size, config.n_embd),
-             drop = nn.Dropout(config.dropout),
-             h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
-             ln_f = RMSNorm(config.n_embd, bias=config.bias),
-         ))
-
-         # Sinusoidal positional encoding
-         if getattr(config, 'use_sinusoidal_embeddings', True):
-             self.pos_encoder = SinusoidalPositionalEncoding(
-                 config.n_embd,
-                 config.max_position_embeddings
-             )
-         else:
-             self.transformer.wpe = nn.Embedding(config.max_position_embeddings, config.n_embd)
-             self.pos_encoder = None
-
-         # Language modeling head
-         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
-
-         # Weight tying
-         self.lm_head.weight = self.transformer.wte.weight
-
-         # Initialize weights
-         self.post_init()
-
-     def get_input_embeddings(self):
-         return self.transformer.wte
-
-     def set_input_embeddings(self, new_embeddings):
-         self.transformer.wte = new_embeddings
-
-     def get_output_embeddings(self):
-         return self.lm_head
-
-     def set_output_embeddings(self, new_embeddings):
-         self.lm_head = new_embeddings
-
-     def _init_weights(self, module):
-         if isinstance(module, nn.Linear):
-             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
-             if module.bias is not None:
-                 torch.nn.init.zeros_(module.bias)
-         elif isinstance(module, nn.Embedding):
-             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
-             if hasattr(module, 'padding_idx') and module.padding_idx is not None:
-                 module.weight.data[module.padding_idx].zero_()
-         # Special scaled init for residual projections
-         if hasattr(module, "c_proj") and isinstance(module.c_proj, nn.Linear):
-             torch.nn.init.normal_(module.c_proj.weight, mean=0.0, std=0.02/math.sqrt(2 * self.config.n_layer))
-
-     def forward(
-         self,
-         input_ids: Optional[torch.LongTensor] = None,
-         attention_mask: Optional[torch.FloatTensor] = None,
-         labels: Optional[torch.LongTensor] = None,
-         output_attentions: Optional[bool] = None,
-         output_hidden_states: Optional[bool] = None,
-         return_dict: Optional[bool] = None,
-         **kwargs
-     ) -> Union[Tuple, CausalLMOutput]:
-         """
-         Forward pass of the model.
-
-         Args:
-             input_ids: Input token IDs of shape (batch_size, sequence_length)
-             attention_mask: Attention mask (not used, kept for compatibility)
-             labels: Labels for language modeling loss
-             output_attentions: Whether to return attention weights
-             output_hidden_states: Whether to return hidden states
-             return_dict: Whether to return a dictionary
-
-         Returns:
-             CausalLMOutput or tuple
-         """
-         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-         # Get embeddings
-         tok_emb = self.transformer.wte(input_ids)
-         tok_emb = tok_emb * math.sqrt(self.config.n_embd)  # Scale embeddings
-
-         # Add positional encoding
-         if self.pos_encoder is not None:
-             x = self.pos_encoder(tok_emb)
-         else:
-             pos = torch.arange(0, input_ids.size(1), dtype=torch.long, device=input_ids.device)
-             pos_emb = self.transformer.wpe(pos)
-             x = tok_emb + pos_emb
-
-         x = self.transformer.drop(x)
-
-         # Pass through transformer blocks
-         all_hidden_states = () if output_hidden_states else None
-         all_attentions = () if output_attentions else None
-
-         for i, block in enumerate(self.transformer.h):
-             if output_hidden_states:
-                 all_hidden_states = all_hidden_states + (x,)
-
-             outputs = block(x, output_attentions=output_attentions)
-             x = outputs[0]
-
-             if output_attentions:
-                 all_attentions = all_attentions + (outputs[1],)
-
-         # Final layer norm
-         x = self.transformer.ln_f(x)
-
-         if output_hidden_states:
-             all_hidden_states = all_hidden_states + (x,)
-
-         # Language modeling head
-         logits = self.lm_head(x)
-
-         loss = None
-         if labels is not None:
-             # Shift so that tokens < n predict n
-             shift_logits = logits[..., :-1, :].contiguous()
-             shift_labels = labels[..., 1:].contiguous()
-             # Flatten the tokens
-             loss_fct = nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)
-             loss = loss_fct(
-                 shift_logits.view(-1, shift_logits.size(-1)),
-                 shift_labels.view(-1)
-             )
-
-         if not return_dict:
-             # Build the tuple step by step so logits are kept even when the
-             # optional outputs are disabled
-             output = (logits,)
-             if output_hidden_states:
-                 output = output + (all_hidden_states,)
-             if output_attentions:
-                 output = output + (all_attentions,)
-             return ((loss,) + output) if loss is not None else output
-
-         return CausalLMOutput(
-             loss=loss,
-             logits=logits,
-             hidden_states=all_hidden_states,
-             attentions=all_attentions,
-         )
-
-     def sample_logits(
-         self,
-         logits: torch.FloatTensor,
-         temperature: float = 1.0,
-         top_k: Optional[int] = None,
-         top_p: Optional[float] = None
-     ) -> torch.LongTensor:
-         """
-         Sample from logits with temperature, top-k, and top-p (nucleus) sampling.
-         """
-         # If temperature is 0.0, use argmax
-         if temperature == 0.0:
-             return torch.argmax(logits, dim=-1)
-
-         # Apply temperature
-         logits = logits / temperature
-
-         # Apply top-k filtering if specified
-         if top_k is not None:
-             v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
-             logits[logits < v[..., [-1]]] = -float('Inf')
-
-         # Apply top-p (nucleus) filtering if specified
-         if top_p is not None:
-             sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
-             sorted_probs = F.softmax(sorted_logits, dim=-1)
-             cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
-
-             # Remove tokens with cumulative probability above the threshold
-             sorted_indices_to_remove = cumulative_probs > top_p
-             sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
-             sorted_indices_to_remove[..., 0] = 0
-
-             indices_to_remove = sorted_indices_to_remove.scatter(
-                 dim=-1, index=sorted_indices, src=sorted_indices_to_remove
-             )
-             logits[indices_to_remove] = -float('Inf')
-
-         # Compute softmax probabilities
-         probs = F.softmax(logits, dim=-1)
-
-         # Sample from the distribution
-         flat_probs = probs.view(-1, probs.size(-1))
-         sampled = torch.multinomial(flat_probs, num_samples=1)
-         sampled = sampled.view(*logits.shape[:-1])
-
-         return sampled
-
-     @torch.no_grad()
-     def generate(
-         self,
-         input_ids: torch.LongTensor,
-         max_new_tokens: int = 100,
-         temperature: float = 1.0,
-         top_k: Optional[int] = None,
-         top_p: Optional[float] = None,
-         eos_token_id: Optional[int] = None,
-         pad_token_id: Optional[int] = None,
-         attention_mask: Optional[torch.FloatTensor] = None,
-         **kwargs
-     ) -> torch.LongTensor:
-         """
-         Generate sequences autoregressively.
-
-         Args:
-             input_ids: Input token IDs of shape (batch_size, sequence_length)
-             max_new_tokens: Maximum number of tokens to generate
-             temperature: Sampling temperature
-             top_k: Top-k filtering parameter
-             top_p: Top-p (nucleus) filtering parameter
-             eos_token_id: End-of-sequence token ID
-             pad_token_id: Padding token ID
-             attention_mask: Attention mask (not used, kept for compatibility)
-
-         Returns:
-             Generated token IDs
-         """
-         if eos_token_id is None:
-             eos_token_id = self.config.eos_token_id
-         if pad_token_id is None:
-             pad_token_id = self.config.pad_token_id
-
-         batch_size = input_ids.shape[0]
-         device = input_ids.device
-
-         # Keep track of which sequences are done
-         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=device)
-
-         # Note: no key-value cache is implemented here; each step re-encodes the full sequence
-
-         for _ in range(max_new_tokens):
-             # Forward pass
-             outputs = self.forward(input_ids)
-             next_token_logits = outputs.logits[:, -1, :]
-
-             # Sample next tokens
-             next_tokens = self.sample_logits(
-                 next_token_logits,
-                 temperature=temperature,
-                 top_k=top_k,
-                 top_p=top_p
-             )
-
-             # Update sequences
-             if eos_token_id is not None:
-                 tokens_to_add = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
-                 unfinished_sequences = unfinished_sequences * (next_tokens != eos_token_id).long()
-             else:
-                 tokens_to_add = next_tokens
-
-             # Concatenate tokens
-             input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1)
-
-             # Stop if all sequences are finished
-             if eos_token_id is not None and unfinished_sequences.sum() == 0:
-                 break
-
-         return input_ids
-
-     def prepare_inputs_for_generation(
-         self,
-         input_ids,
-         past_key_values=None,
-         attention_mask=None,
-         **kwargs
-     ):
-         """Prepare inputs for generation."""
-         # If past_key_values is used, only use the last token
-         if past_key_values:
-             input_ids = input_ids[:, -1:]
-
-         return {
-             "input_ids": input_ids,
-             "past_key_values": past_key_values,
-             "attention_mask": attention_mask,
-         }
-
-
- class SinusoidalPositionalEncoding(nn.Module):
-     """Sinusoidal positional encoding."""
-
-     def __init__(self, d_model: int, max_len: int = 5000):
-         super().__init__()
-         pe = torch.zeros(max_len, d_model)
-         position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
-         div_term = torch.exp(torch.arange(0, d_model, 2).float() *
-                              (-math.log(10000.0) / d_model))
-         pe[:, 0::2] = torch.sin(position * div_term)
-         pe[:, 1::2] = torch.cos(position * div_term)
-         self.register_buffer('pe', pe.unsqueeze(0))
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         """Add positional encoding to input tensor."""
-         return x + self.pe[:, :x.size(1)]
-
-
- class Block(nn.Module):
-     """Transformer block with pre-normalization."""
-
-     def __init__(self, config):
-         super().__init__()
-         self.attn = CausalSelfAttention(config)
-         self.mlp = MLP(config)
-         self.norm1 = RMSNorm(config.n_embd, bias=config.bias)
-         self.norm2 = RMSNorm(config.n_embd, bias=config.bias)
-         self.config = config
-
-     def forward(self, x, output_attentions=False):
-         # Pre-norm attention block
-         attn_output = self.attn(self.norm1(x), output_attentions=output_attentions)
-         if output_attentions:
-             attn_output, attn_weights = attn_output
-             x = x + attn_output
-             x = x + self.mlp(self.norm2(x))
-             return x, attn_weights
-         else:
-             x = x + attn_output
-             x = x + self.mlp(self.norm2(x))
-             return (x,)
-
-
- class CausalSelfAttention(nn.Module):
-     """Multi-head causal self-attention."""
-
-     def __init__(self, config):
-         super().__init__()
-         assert config.n_embd % config.n_head == 0
-
-         self.n_head = config.n_head
-         self.n_embd = config.n_embd
-         self.head_dim = config.n_embd // config.n_head
-         self.scaling = self.head_dim ** -0.5
-
-         # Key, query, value projections for all heads
-         self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
-         # Output projection
-         self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
-         # Attention dropout
-         self.attn_dropout = nn.Dropout(config.attention_dropout)
-         self.resid_dropout = nn.Dropout(config.dropout)
-
-     def forward(self, x, output_attentions=False):
-         B, T, C = x.size()  # batch size, sequence length, embedding dimensionality
-
-         # Calculate query, key, values for all heads
-         q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
-         q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)  # (B, nh, T, hs)
-         k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)  # (B, nh, T, hs)
-         v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)  # (B, nh, T, hs)
-
-         # Causal self-attention. q is deliberately not pre-scaled here:
-         # scaled_dot_product_attention applies the 1/sqrt(head_dim) factor itself,
-         # so pre-scaling would scale the scores twice; the manual branch scales explicitly.
-         if not output_attentions:
-             # Use flash attention when available
-             y = F.scaled_dot_product_attention(
-                 q, k, v,
-                 attn_mask=None,
-                 dropout_p=self.attn_dropout.p if self.training else 0.0,
-                 is_causal=True
-             )
-         else:
-             # Manual implementation for attention weights
-             att = torch.matmul(q, k.transpose(-2, -1)) * self.scaling  # (B, nh, T, T)
-
-             # Causal mask
-             causal_mask = torch.triu(
-                 torch.ones(T, T, dtype=torch.bool, device=x.device),
-                 diagonal=1
-             )
-             att = att.masked_fill(causal_mask, float('-inf'))
-
-             # Softmax
-             att = F.softmax(att, dim=-1, dtype=torch.float32).to(q.dtype)
-             att = self.attn_dropout(att)
-
-             # Apply attention to values
-             y = torch.matmul(att, v)  # (B, nh, T, hs)
-
-         # Re-assemble all head outputs
-         y = y.transpose(1, 2).contiguous().view(B, T, C)
-
-         # Output projection
-         y = self.resid_dropout(self.c_proj(y))
-
-         if output_attentions:
-             return y, att
-         else:
-             return y
-
-
- class MLP(nn.Module):
-     """Position-wise feed-forward network."""
-
-     def __init__(self, config):
-         super().__init__()
-         self.c_fc = nn.Linear(config.n_embd, config.n_inner, bias=config.bias)
-         self.c_proj = nn.Linear(config.n_inner, config.n_embd, bias=config.bias)
-         self.act = nn.GELU()
-         self.dropout = nn.Dropout(config.dropout)
-
-     def forward(self, x):
-         x = self.c_fc(x)
-         x = self.act(x)
-         x = self.c_proj(x)
-         x = self.dropout(x)
-         return x
-
-
- class RMSNorm(nn.Module):
-     """Root Mean Square Layer Normalization."""
-
-     def __init__(self, dim: int, eps: float = 1e-6, bias: bool = False):
-         super().__init__()
-         self.eps = eps
-         self.weight = nn.Parameter(torch.ones(dim))
-         self.bias = nn.Parameter(torch.zeros(dim)) if bias else None
-
-     def forward(self, x):
-         norm_x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
-         if self.bias is not None:
-             return self.weight * norm_x + self.bias
-         return self.weight * norm_x
-
-
- # Register the model
- from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
-
- AutoConfig.register("gslm_ulm", GSLMULMConfig)
- AutoModel.register(GSLMULMConfig, GSLMULM)
- AutoModelForCausalLM.register(GSLMULMConfig, GSLMULM)
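
For context, the registration block above is what wired the removed classes into the transformers Auto* machinery before this rename. A minimal, hypothetical sketch of the usage it enabled; the configuration fields are assumptions inferred from the attributes the model code reads, since configuration_gslm_ulm.py is not part of this diff:

import torch
from transformers import AutoConfig, AutoModelForCausalLM

# "gslm_ulm" resolves through AutoConfig.register(...) above; the keyword
# arguments are assumed GSLMULMConfig fields, not confirmed by this diff.
config = AutoConfig.for_model("gslm_ulm", vocab_size=104, n_embd=256, n_layer=2, n_head=4, n_inner=1024)
model = AutoModelForCausalLM.from_config(config).eval()

units = torch.tensor([[5, 17, 42]])  # a short prompt of discrete speech units
out = model.generate(units, max_new_tokens=20, temperature=0.9, top_p=0.95)
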