"""
GSLM Unit Language Model - HuggingFace Compatible Implementation
Based on fairseq's transformer_lm_big architecture
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, Tuple, Union, List
from dataclasses import dataclass
from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_outputs import BaseModelOutput, CausalLMOutputWithCrossAttentions
from transformers import AutoConfig, AutoModelForCausalLM
# Import config - handle both local and remote imports
try:
from .config import GSLMConfig
except ImportError:
# Fallback for when file is accessed directly
from config import GSLMConfig
# For backward compatibility with the API
@dataclass
class CausalLMOutput:
loss: Optional[torch.FloatTensor] = None
    logits: Optional[Union[torch.FloatTensor, List[torch.FloatTensor]]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
class PositionalEncoding(nn.Module):
"""Sinusoidal positional encoding for transformer models."""
def __init__(self, d_model: int, max_len: int = 5000):
super().__init__()
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() *
(-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
self.register_buffer('pe', pe.unsqueeze(0))
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Add positional encoding to input tensor."""
return x + self.pe[:, :x.size(1)]
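# Usage sketch (illustrative, not from the original source): the encoding
# table is precomputed once and sliced per call, so this module has no
# learnable parameters.
#
#   pe = PositionalEncoding(d_model=1024, max_len=4096)
#   x = torch.randn(2, 50, 1024)   # [batch, seq_len, d_model]
#   y = pe(x)                      # same shape; encodings for positions 0..49 added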
class MultiheadAttention(nn.Module):
"""Multi-head attention mechanism."""
def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0):
super().__init__()
assert embed_dim % num_heads == 0
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.scaling = self.head_dim ** -0.5
self.k_proj = nn.Linear(embed_dim, embed_dim)
self.v_proj = nn.Linear(embed_dim, embed_dim)
self.q_proj = nn.Linear(embed_dim, embed_dim)
self.out_proj = nn.Linear(embed_dim, embed_dim)
self.attn_dropout = nn.Dropout(dropout)
def forward(
self,
query: torch.Tensor,
key: Optional[torch.Tensor] = None,
value: Optional[torch.Tensor] = None,
attn_mask: Optional[torch.Tensor] = None,
key_padding_mask: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
query: [batch_size, tgt_len, embed_dim]
key: [batch_size, src_len, embed_dim]
value: [batch_size, src_len, embed_dim]
attn_mask: [tgt_len, src_len] or [batch_size * num_heads, tgt_len, src_len]
key_padding_mask: [batch_size, src_len]
"""
if key is None:
key = query
if value is None:
value = query
batch_size, tgt_len, embed_dim = query.size()
src_len = key.size(1)
# Project and reshape
q = self.q_proj(query) * self.scaling
k = self.k_proj(key)
v = self.v_proj(value)
q = q.view(batch_size, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
k = k.view(batch_size, src_len, self.num_heads, self.head_dim).transpose(1, 2)
v = v.view(batch_size, src_len, self.num_heads, self.head_dim).transpose(1, 2)
# Compute attention scores
attn_weights = torch.matmul(q, k.transpose(-2, -1))
# Apply masks
        if attn_mask is not None:
            if attn_mask.dim() == 2:
                # [tgt_len, src_len] -> broadcast over batch and heads
                attn_mask = attn_mask.unsqueeze(0).unsqueeze(0)
            elif attn_mask.dim() == 3:
                # [batch_size * num_heads, tgt_len, src_len] -> [batch_size, num_heads, tgt_len, src_len]
                attn_mask = attn_mask.view(batch_size, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights + attn_mask
if key_padding_mask is not None:
attn_weights = attn_weights.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2),
float('-inf')
)
# Softmax
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(attn_weights)
attn_weights = self.attn_dropout(attn_weights)
# Apply attention to values
attn_output = torch.matmul(attn_weights, v)
attn_output = attn_output.transpose(1, 2).contiguous().view(
batch_size, tgt_len, embed_dim
)
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights
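# Usage sketch (illustrative, not from the original source): with a causal
# mask, position i attends only to positions <= i; -inf entries become zero
# attention weight after the softmax.
#
#   attn = MultiheadAttention(embed_dim=1024, num_heads=16)
#   x = torch.randn(2, 8, 1024)
#   mask = torch.triu(torch.full((8, 8), float('-inf')), diagonal=1)
#   out, weights = attn(x, attn_mask=mask)   # out: [2, 8, 1024], weights: [2, 16, 8, 8]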
class TransformerDecoderLayer(nn.Module):
"""Transformer decoder layer."""
def __init__(
self,
d_model: int,
nhead: int,
dim_feedforward: int = 2048,
dropout: float = 0.1,
attention_dropout: float = 0.1,
activation: str = "relu"
):
super().__init__()
self.self_attn = MultiheadAttention(d_model, nhead, dropout=attention_dropout)
# Feedforward network
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
# Layer normalization
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
# Dropout modules
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
# Activation
self.activation = F.relu if activation == "relu" else F.gelu
def forward(
self,
x: torch.Tensor,
self_attn_mask: Optional[torch.Tensor] = None,
self_attn_padding_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
Args:
x: [batch_size, seq_len, d_model]
self_attn_mask: [seq_len, seq_len]
self_attn_padding_mask: [batch_size, seq_len]
"""
# Self-attention block
residual = x
x = self.norm1(x)
x, _ = self.self_attn(x, x, x, self_attn_mask, self_attn_padding_mask)
x = self.dropout1(x)
x = residual + x
# Feedforward block
residual = x
x = self.norm2(x)
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
x = self.dropout2(x)
x = residual + x
return x
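# Note: as the forward pass above shows, this layer uses pre-norm ordering
# (LayerNorm before each sub-block, residual added after). A shape-check
# sketch (illustrative only):
#
#   layer = TransformerDecoderLayer(d_model=1024, nhead=16, dim_feedforward=4096)
#   x = torch.randn(2, 8, 1024)
#   y = layer(x)   # same shape as x; pass a causal mask for autoregressive use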
class GSLMPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.
"""
config_class = GSLMConfig
base_model_prefix = "gslm"
supports_gradient_checkpointing = True
_no_split_modules = ["TransformerDecoderLayer"]
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.d_model ** -0.5)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.d_model ** -0.5)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
class GSLMForCausalLM(GSLMPreTrainedModel):
"""
GSLM Unit Language Model - Transformer LM Big Architecture
HuggingFace compatible version with modified forward API
"""
def __init__(self, config: GSLMConfig):
super().__init__(config)
self.config = config
self.d_model = config.d_model
self.vocab_size = config.vocab_size
self.pad_idx = config.pad_idx
self.max_seq_length = config.max_seq_length
# Create transformer module container for compatibility
self.transformer = nn.Module()
# Token embeddings (wte for compatibility)
self.transformer.wte = nn.Embedding(config.vocab_size, config.d_model, padding_idx=self.pad_idx)
self.embed_scale = math.sqrt(config.d_model)
# Positional encoding
self.pos_encoder = PositionalEncoding(config.d_model, config.max_seq_length)
# Transformer decoder layers (h for compatibility)
self.transformer.h = nn.ModuleList([
TransformerDecoderLayer(
config.d_model,
config.nhead,
config.dim_feedforward,
config.dropout,
config.attention_dropout
) for _ in range(config.num_layers)
])
# Final layer norm (ln_f for compatibility)
self.transformer.ln_f = nn.LayerNorm(config.d_model)
# Output projection (coch_head for compatibility)
if config.share_input_output_embed:
self.coch_head = lambda x: F.linear(x, self.transformer.wte.weight)
else:
self.coch_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
# Dropout
self.transformer.drop = nn.Dropout(config.dropout)
# Future heads not supported in GSLM
self.future_heads = None
# Initialize weights
self.post_init()
def get_input_embeddings(self):
return self.transformer.wte
def set_input_embeddings(self, value):
self.transformer.wte = value
def get_output_embeddings(self):
if self.config.share_input_output_embed:
return self.transformer.wte
else:
return self.coch_head
def _create_causal_mask(self, seq_len: int, device) -> torch.Tensor:
"""Create causal attention mask."""
mask = torch.triu(
torch.full((seq_len, seq_len), float('-inf'), device=device),
diagonal=1
)
return mask
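    # For seq_len = 4 the mask produced above is (0 = may attend, -inf = blocked):
    #   [[0., -inf, -inf, -inf],
    #    [0.,   0., -inf, -inf],
    #    [0.,   0.,   0., -inf],
    #    [0.,   0.,   0.,   0.]]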
def forward(
self,
seq=None,
input_ids=None,
tgt=None,
labels=None,
output_logits=False,
output_hidden_states=False,
return_dict=False,
up_until_layer=None,
**kwargs
):
"""
        Forward pass supporting both the legacy GSLM API (seq/tgt) and the
        standard HuggingFace API (input_ids/labels).
Args:
seq: torch.Tensor of shape (b, t) - input token IDs (legacy)
input_ids: torch.Tensor of shape (b, t) - input token IDs (HF standard)
tgt: torch.Tensor of shape (b, t) or None - target token IDs (legacy)
labels: torch.Tensor of shape (b, t) or None - target token IDs (HF standard)
output_logits: bool - whether to output logits
output_hidden_states: bool - whether to output all hidden states
return_dict: bool - whether to return dictionary output
up_until_layer: int or None - stop at specific layer
Returns:
Depending on return_dict and other flags
"""
# Handle both seq and input_ids for compatibility
if seq is None and input_ids is not None:
seq = input_ids
elif seq is None and input_ids is None:
raise ValueError("Either 'seq' or 'input_ids' must be provided")
# Handle both tgt and labels for compatibility
if tgt is None and labels is not None:
tgt = labels
batch_size, seq_len = seq.shape
device = seq.device
# Create causal mask
causal_mask = self._create_causal_mask(seq_len, device)
# Create padding mask
padding_mask = seq.eq(self.pad_idx)
# Token embeddings
tok_emb = self.transformer.wte(seq) * self.embed_scale
# Add positional encoding (sinusoidal, not learned)
x = self.pos_encoder(tok_emb)
x = self.transformer.drop(x)
all_hidden_states = []
        # Pass through transformer layers
        exited_early = False
        for block_idx, block in enumerate(self.transformer.h):
            # Save hidden state before block
            if output_hidden_states:
                all_hidden_states.append(x)
            # Check if we should stop early
            if up_until_layer is not None and block_idx == up_until_layer:
                exited_early = True
                break
            # Forward the block
            x = block(x, causal_mask, padding_mask)
        # Append the final hidden state unless we exited early
        # (on early exit, x was already appended at the top of the loop)
        if output_hidden_states and not exited_early:
            all_hidden_states.append(x)
        # If only hidden states were requested, return early
        # (note: x has not yet been passed through the final layer norm)
if output_hidden_states and not output_logits and tgt is None:
model_output = BaseModelOutput(
last_hidden_state=x,
hidden_states=tuple(all_hidden_states) if all_hidden_states else None,
)
return model_output
# Final layer norm
x = self.transformer.ln_f(x)
# Compute logits
logits = self.coch_head(x)
# Compute loss if targets provided
if tgt is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = tgt[..., 1:].contiguous()
loss = F.cross_entropy(
shift_logits.reshape(-1, self.config.vocab_size),
shift_labels.reshape(-1),
ignore_index=self.pad_idx
)
            if return_dict:
                # For compatibility with the legacy API, wrap the logits in a
                # list when output_logits is set
                model_output = CausalLMOutput(
                    loss=loss,
                    logits=[logits] if output_logits else logits,
                    hidden_states=tuple(all_hidden_states) if output_hidden_states and all_hidden_states else None,
                )
                return model_output
            return logits, loss
# No targets provided
if return_dict:
return CausalLMOutputWithCrossAttentions(
logits=logits,
hidden_states=tuple(all_hidden_states) if output_hidden_states else None,
)
return logits, None
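    # Usage sketch (illustrative, not from the original source): with
    # return_dict=False (the default), both calling conventions return the
    # same (logits, loss) tuple; the token IDs below are placeholders.
    #
    #   ids = torch.randint(0, model.vocab_size, (1, 16))
    #   logits, loss = model(input_ids=ids, labels=ids)   # HF-style
    #   logits, loss = model(seq=ids, tgt=ids)            # legacy GSLM-style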
@torch.no_grad()
def generate(
self,
input_ids: torch.Tensor = None,
seq: torch.Tensor = None,
max_length: int = 100,
temperature: float = 1.0,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
**kwargs
) -> torch.Tensor:
"""Generate sequences using the language model."""
# Handle both input_ids and seq
if input_ids is None and seq is not None:
input_ids = seq
elif input_ids is None:
raise ValueError("Either 'input_ids' or 'seq' must be provided")
if pad_token_id is None:
pad_token_id = self.pad_idx
batch_size = input_ids.shape[0]
device = input_ids.device
# Keep track of which sequences are done
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=device)
while input_ids.shape[1] < max_length:
            # Forward pass (no KV cache: the full prefix is re-encoded each step)
logits, _ = self.forward(input_ids)
next_token_logits = logits[:, -1, :]
# Apply temperature
if temperature != 1.0:
next_token_logits = next_token_logits / temperature
            # Apply top-k sampling
            if top_k is not None:
                top_k = min(top_k, next_token_logits.size(-1))  # guard against top_k > vocab size
                indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
                next_token_logits[indices_to_remove] = -float('inf')
# Apply top-p (nucleus) sampling
if top_p is not None:
sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
# Remove tokens with cumulative probability above the threshold
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = sorted_indices_to_remove.scatter(
dim=-1, index=sorted_indices, src=sorted_indices_to_remove
)
next_token_logits[indices_to_remove] = -float('inf')
# Sample from the distribution
probs = F.softmax(next_token_logits, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(-1)
# Update unfinished sequences
if eos_token_id is not None:
tokens_to_add = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
unfinished_sequences = unfinished_sequences * (next_tokens != eos_token_id).long()
else:
tokens_to_add = next_tokens
# Concatenate tokens
input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1)
# Stop if all sequences are finished
if eos_token_id is not None and unfinished_sequences.sum() == 0:
break
return input_ids
# Register the config and model with HuggingFace's Auto classes
AutoConfig.register("gslm", GSLMConfig)
AutoModelForCausalLM.register(GSLMConfig, GSLMForCausalLM)
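if __name__ == "__main__":
    # Minimal smoke test, a sketch rather than part of the model definition.
    # It assumes GSLMConfig accepts the field names used throughout this file
    # as keyword arguments; the values below are illustrative and do NOT match
    # the released GSLM-HuBERT200 configuration.
    cfg = GSLMConfig(
        vocab_size=104,
        d_model=128,
        nhead=4,
        num_layers=2,
        dim_feedforward=256,
        dropout=0.1,
        attention_dropout=0.1,
        max_seq_length=128,
        pad_idx=0,
        share_input_output_embed=True,
    )
    model = GSLMForCausalLM(cfg).eval()
    prompt = torch.randint(1, cfg.vocab_size, (1, 8))  # avoid the pad index
    generated = model.generate(input_ids=prompt, max_length=16, temperature=1.0, top_k=10)
    print(generated.shape)  # expected: torch.Size([1, 16])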