""" GSLM Unit Language Model - HuggingFace Compatible Implementation Based on fairseq's transformer_lm_big architecture """ import torch import torch.nn as nn import torch.nn.functional as F import math import os import json from typing import Optional, Tuple, Dict, Union, List from dataclasses import dataclass from transformers.modeling_utils import PreTrainedModel from transformers.modeling_outputs import BaseModelOutput, CausalLMOutputWithCrossAttentions from transformers import AutoConfig, AutoModelForCausalLM # Import config - handle both local and remote imports try: from .config import GSLMConfig except ImportError: # Fallback for when file is accessed directly from config import GSLMConfig # For backward compatibility with the API @dataclass class CausalLMOutput: loss: Optional[torch.FloatTensor] = None logits: Union[torch.FloatTensor, List[torch.FloatTensor]] = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None class PositionalEncoding(nn.Module): """Sinusoidal positional encoding for transformer models.""" def __init__(self, d_model: int, max_len: int = 5000): super().__init__() pe = torch.zeros(max_len, d_model) position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) self.register_buffer('pe', pe.unsqueeze(0)) def forward(self, x: torch.Tensor) -> torch.Tensor: """Add positional encoding to input tensor.""" return x + self.pe[:, :x.size(1)] class MultiheadAttention(nn.Module): """Multi-head attention mechanism.""" def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0): super().__init__() assert embed_dim % num_heads == 0 self.embed_dim = embed_dim self.num_heads = num_heads self.head_dim = embed_dim // num_heads self.scaling = self.head_dim ** -0.5 self.k_proj = nn.Linear(embed_dim, embed_dim) self.v_proj = nn.Linear(embed_dim, embed_dim) self.q_proj = nn.Linear(embed_dim, embed_dim) self.out_proj = nn.Linear(embed_dim, embed_dim) self.attn_dropout = nn.Dropout(dropout) def forward( self, query: torch.Tensor, key: Optional[torch.Tensor] = None, value: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None, key_padding_mask: Optional[torch.Tensor] = None ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: query: [batch_size, tgt_len, embed_dim] key: [batch_size, src_len, embed_dim] value: [batch_size, src_len, embed_dim] attn_mask: [tgt_len, src_len] or [batch_size * num_heads, tgt_len, src_len] key_padding_mask: [batch_size, src_len] """ if key is None: key = query if value is None: value = query batch_size, tgt_len, embed_dim = query.size() src_len = key.size(1) # Project and reshape q = self.q_proj(query) * self.scaling k = self.k_proj(key) v = self.v_proj(value) q = q.view(batch_size, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) k = k.view(batch_size, src_len, self.num_heads, self.head_dim).transpose(1, 2) v = v.view(batch_size, src_len, self.num_heads, self.head_dim).transpose(1, 2) # Compute attention scores attn_weights = torch.matmul(q, k.transpose(-2, -1)) # Apply masks if attn_mask is not None: if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0).unsqueeze(0) attn_weights = attn_weights + attn_mask if key_padding_mask is not None: attn_weights = attn_weights.masked_fill( key_padding_mask.unsqueeze(1).unsqueeze(2), float('-inf') ) # Softmax attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(attn_weights) attn_weights = self.attn_dropout(attn_weights) # Apply attention to values attn_output = torch.matmul(attn_weights, v) attn_output = attn_output.transpose(1, 2).contiguous().view( batch_size, tgt_len, embed_dim ) attn_output = self.out_proj(attn_output) return attn_output, attn_weights class TransformerDecoderLayer(nn.Module): """Transformer decoder layer.""" def __init__( self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1, attention_dropout: float = 0.1, activation: str = "relu" ): super().__init__() self.self_attn = MultiheadAttention(d_model, nhead, dropout=attention_dropout) # Feedforward network self.linear1 = nn.Linear(d_model, dim_feedforward) self.dropout = nn.Dropout(dropout) self.linear2 = nn.Linear(dim_feedforward, d_model) # Layer normalization self.norm1 = nn.LayerNorm(d_model) self.norm2 = nn.LayerNorm(d_model) # Dropout modules self.dropout1 = nn.Dropout(dropout) self.dropout2 = nn.Dropout(dropout) # Activation self.activation = F.relu if activation == "relu" else F.gelu def forward( self, x: torch.Tensor, self_attn_mask: Optional[torch.Tensor] = None, self_attn_padding_mask: Optional[torch.Tensor] = None ) -> torch.Tensor: """ Args: x: [batch_size, seq_len, d_model] self_attn_mask: [seq_len, seq_len] self_attn_padding_mask: [batch_size, seq_len] """ # Self-attention block residual = x x = self.norm1(x) x, _ = self.self_attn(x, x, x, self_attn_mask, self_attn_padding_mask) x = self.dropout1(x) x = residual + x # Feedforward block residual = x x = self.norm2(x) x = self.linear2(self.dropout(self.activation(self.linear1(x)))) x = self.dropout2(x) x = residual + x return x class GSLMPreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. """ config_class = GSLMConfig base_model_prefix = "gslm" supports_gradient_checkpointing = True _no_split_modules = ["TransformerDecoderLayer"] def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=self.config.d_model ** -0.5) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=self.config.d_model ** -0.5) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) class GSLMForCausalLM(GSLMPreTrainedModel): """ GSLM Unit Language Model - Transformer LM Big Architecture HuggingFace compatible version with modified forward API """ def __init__(self, config: GSLMConfig): super().__init__(config) self.config = config self.d_model = config.d_model self.vocab_size = config.vocab_size self.pad_idx = config.pad_idx self.max_seq_length = config.max_seq_length # Create transformer module container for compatibility self.transformer = nn.Module() # Token embeddings (wte for compatibility) self.transformer.wte = nn.Embedding(config.vocab_size, config.d_model, padding_idx=self.pad_idx) self.embed_scale = math.sqrt(config.d_model) # Positional encoding self.pos_encoder = PositionalEncoding(config.d_model, config.max_seq_length) # Transformer decoder layers (h for compatibility) self.transformer.h = nn.ModuleList([ TransformerDecoderLayer( config.d_model, config.nhead, config.dim_feedforward, config.dropout, config.attention_dropout ) for _ in range(config.num_layers) ]) # Final layer norm (ln_f for compatibility) self.transformer.ln_f = nn.LayerNorm(config.d_model) # Output projection (coch_head for compatibility) if config.share_input_output_embed: self.coch_head = lambda x: F.linear(x, self.transformer.wte.weight) else: self.coch_head = nn.Linear(config.d_model, config.vocab_size, bias=False) # Dropout self.transformer.drop = nn.Dropout(config.dropout) # Future heads not supported in GSLM self.future_heads = None # Initialize weights self.post_init() def get_input_embeddings(self): return self.transformer.wte def set_input_embeddings(self, value): self.transformer.wte = value def get_output_embeddings(self): if self.config.share_input_output_embed: return self.transformer.wte else: return self.coch_head def _create_causal_mask(self, seq_len: int, device) -> torch.Tensor: """Create causal attention mask.""" mask = torch.triu( torch.full((seq_len, seq_len), float('-inf'), device=device), diagonal=1 ) return mask def forward( self, seq=None, input_ids=None, tgt=None, labels=None, output_logits=False, output_hidden_states=False, return_dict=False, up_until_layer=None, **kwargs ): """ Compatible forward method with the specified API. Args: seq: torch.Tensor of shape (b, t) - input token IDs (legacy) input_ids: torch.Tensor of shape (b, t) - input token IDs (HF standard) tgt: torch.Tensor of shape (b, t) or None - target token IDs (legacy) labels: torch.Tensor of shape (b, t) or None - target token IDs (HF standard) output_logits: bool - whether to output logits output_hidden_states: bool - whether to output all hidden states return_dict: bool - whether to return dictionary output up_until_layer: int or None - stop at specific layer Returns: Depending on return_dict and other flags """ # Handle both seq and input_ids for compatibility if seq is None and input_ids is not None: seq = input_ids elif seq is None and input_ids is None: raise ValueError("Either 'seq' or 'input_ids' must be provided") # Handle both tgt and labels for compatibility if tgt is None and labels is not None: tgt = labels batch_size, seq_len = seq.shape device = seq.device # Create causal mask causal_mask = self._create_causal_mask(seq_len, device) # Create padding mask padding_mask = seq.eq(self.pad_idx) # Token embeddings tok_emb = self.transformer.wte(seq) * self.embed_scale # Add positional encoding (sinusoidal, not learned) x = self.pos_encoder(tok_emb) x = self.transformer.drop(x) all_hidden_states = [] # Pass through transformer layers for block_idx, block in enumerate(self.transformer.h): # Save hidden state before block if output_hidden_states: all_hidden_states.append(x) # Check if we should stop early if up_until_layer is not None and block_idx == up_until_layer: break # Forward the block x = block(x, causal_mask, padding_mask) # Append the last hidden state if we didn't exit early if output_hidden_states and (up_until_layer is None or block_idx == len(self.transformer.h) - 1): all_hidden_states.append(x) # If only hidden states requested if output_hidden_states and not output_logits and tgt is None: model_output = BaseModelOutput( last_hidden_state=x, hidden_states=tuple(all_hidden_states) if all_hidden_states else None, ) return model_output # Final layer norm x = self.transformer.ln_f(x) # Compute logits logits = self.coch_head(x) # Compute loss if targets provided if tgt is not None: # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = tgt[..., 1:].contiguous() loss = F.cross_entropy( shift_logits.reshape(-1, self.config.vocab_size), shift_labels.reshape(-1), ignore_index=self.pad_idx ) if return_dict: if output_logits: # For compatibility, wrap single logits in list all_logits = [logits] if output_hidden_states: model_output = CausalLMOutput( loss=loss, logits=all_logits if output_logits else logits, hidden_states=tuple(all_hidden_states) if all_hidden_states else None, ) else: model_output = CausalLMOutput( loss=loss, logits=all_logits if output_logits else logits, ) return model_output return logits, loss # No targets provided if return_dict: return CausalLMOutputWithCrossAttentions( logits=logits, hidden_states=tuple(all_hidden_states) if output_hidden_states else None, ) return logits, None @torch.no_grad() def generate( self, input_ids: torch.Tensor = None, seq: torch.Tensor = None, max_length: int = 100, temperature: float = 1.0, top_k: Optional[int] = None, top_p: Optional[float] = None, pad_token_id: Optional[int] = None, eos_token_id: Optional[int] = None, **kwargs ) -> torch.Tensor: """Generate sequences using the language model.""" # Handle both input_ids and seq if input_ids is None and seq is not None: input_ids = seq elif input_ids is None: raise ValueError("Either 'input_ids' or 'seq' must be provided") if pad_token_id is None: pad_token_id = self.pad_idx batch_size = input_ids.shape[0] device = input_ids.device # Keep track of which sequences are done unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=device) while input_ids.shape[1] < max_length: # Forward pass logits, _ = self.forward(input_ids) next_token_logits = logits[:, -1, :] # Apply temperature if temperature != 1.0: next_token_logits = next_token_logits / temperature # Apply top-k sampling if top_k is not None: indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None] next_token_logits[indices_to_remove] = -float('inf') # Apply top-p (nucleus) sampling if top_p is not None: sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True) cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) # Remove tokens with cumulative probability above the threshold sorted_indices_to_remove = cumulative_probs > top_p sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() sorted_indices_to_remove[..., 0] = 0 indices_to_remove = sorted_indices_to_remove.scatter( dim=-1, index=sorted_indices, src=sorted_indices_to_remove ) next_token_logits[indices_to_remove] = -float('inf') # Sample from the distribution probs = F.softmax(next_token_logits, dim=-1) next_tokens = torch.multinomial(probs, num_samples=1).squeeze(-1) # Update unfinished sequences if eos_token_id is not None: tokens_to_add = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) unfinished_sequences = unfinished_sequences * (next_tokens != eos_token_id).long() else: tokens_to_add = next_tokens # Concatenate tokens input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1) # Stop if all sequences are finished if eos_token_id is not None and unfinished_sequences.sum() == 0: break return input_ids # Register the model with AutoModel AutoConfig.register("gslm", GSLMConfig) AutoModelForCausalLM.register(GSLMConfig, GSLMForCausalLM)