"""
GSLM Unit Language Model - HuggingFace compatible implementation.

Based on fairseq's transformer_lm_big architecture.
"""

import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import AutoConfig, AutoModelForCausalLM
from transformers.modeling_outputs import BaseModelOutput, CausalLMOutputWithCrossAttentions
from transformers.modeling_utils import PreTrainedModel

try:
    from .config import GSLMConfig
except ImportError:
    # Fall back to a flat import when the module is used outside the package.
    from config import GSLMConfig


@dataclass
class CausalLMOutput:
    """Output container for the language model."""

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[Union[torch.FloatTensor, List[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding for transformer models."""

    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # Registered as a buffer so it moves with the module but is not trained.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add positional encoding to the input tensor."""
        return x + self.pe[:, :x.size(1)]
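

# A minimal shape check for PositionalEncoding (illustrative values only, not
# taken from the original configuration):
#
#     enc = PositionalEncoding(d_model=16, max_len=100)
#     x = torch.zeros(2, 10, 16)
#     assert enc(x).shape == (2, 10, 16)  # the encoding broadcasts over the batch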


class MultiheadAttention(nn.Module):
    """Multi-head scaled dot-product attention."""

    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scaling = self.head_dim ** -0.5

        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.attn_dropout = nn.Dropout(dropout)

    def forward(
        self,
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
        value: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
        key_padding_mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            query: [batch_size, tgt_len, embed_dim]
            key: [batch_size, src_len, embed_dim]; defaults to query (self-attention)
            value: [batch_size, src_len, embed_dim]; defaults to query (self-attention)
            attn_mask: [tgt_len, src_len] additive mask, broadcast over batch and heads
            key_padding_mask: [batch_size, src_len] boolean mask, True marks padding
        """
        if key is None:
            key = query
        if value is None:
            value = query

        batch_size, tgt_len, embed_dim = query.size()
        src_len = key.size(1)

        # Pre-scale queries so q @ k^T is already divided by sqrt(head_dim).
        q = self.q_proj(query) * self.scaling
        k = self.k_proj(key)
        v = self.v_proj(value)

        # Reshape to [batch_size, num_heads, len, head_dim].
        q = q.view(batch_size, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, src_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, src_len, self.num_heads, self.head_dim).transpose(1, 2)

        attn_weights = torch.matmul(q, k.transpose(-2, -1))

        if attn_mask is not None:
            if attn_mask.dim() == 2:
                attn_mask = attn_mask.unsqueeze(0).unsqueeze(0)
            attn_weights = attn_weights + attn_mask

        if key_padding_mask is not None:
            attn_weights = attn_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2),
                float('-inf')
            )

        # Softmax in float32 for numerical stability, then cast back.
        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(attn_weights)
        attn_weights = self.attn_dropout(attn_weights)

        attn_output = torch.matmul(attn_weights, v)
        attn_output = attn_output.transpose(1, 2).contiguous().view(
            batch_size, tgt_len, embed_dim
        )
        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights
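

# Quick shape sanity check for MultiheadAttention (illustrative values only):
#
#     mha = MultiheadAttention(embed_dim=16, num_heads=4)
#     out, weights = mha(torch.randn(2, 5, 16))
#     assert out.shape == (2, 5, 16)        # [batch, tgt_len, embed_dim]
#     assert weights.shape == (2, 4, 5, 5)  # [batch, heads, tgt_len, src_len]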


class TransformerDecoderLayer(nn.Module):
    """Pre-norm transformer decoder layer (LayerNorm before each sublayer)."""

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation: str = "relu"
    ):
        super().__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=attention_dropout)

        # Position-wise feed-forward network.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = F.relu if activation == "relu" else F.gelu

    def forward(
        self,
        x: torch.Tensor,
        self_attn_mask: Optional[torch.Tensor] = None,
        self_attn_padding_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Args:
            x: [batch_size, seq_len, d_model]
            self_attn_mask: [seq_len, seq_len]
            self_attn_padding_mask: [batch_size, seq_len]
        """
        # Self-attention sublayer: pre-norm, then residual connection.
        residual = x
        x = self.norm1(x)
        x, _ = self.self_attn(x, x, x, self_attn_mask, self_attn_padding_mask)
        x = self.dropout1(x)
        x = residual + x

        # Feed-forward sublayer: pre-norm, then residual connection.
        residual = x
        x = self.norm2(x)
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        x = self.dropout2(x)
        x = residual + x

        return x
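

# A single decoder-layer pass under a causal mask (illustrative values only):
#
#     layer = TransformerDecoderLayer(d_model=16, nhead=4, dim_feedforward=32)
#     mask = torch.triu(torch.full((5, 5), float('-inf')), diagonal=1)
#     y = layer(torch.randn(2, 5, 16), self_attn_mask=mask)
#     assert y.shape == (2, 5, 16)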


class GSLMPreTrainedModel(PreTrainedModel):
    """
    An abstract class that handles weight initialization and provides a simple
    interface for downloading and loading pretrained models.
    """

    config_class = GSLMConfig
    base_model_prefix = "gslm"
    supports_gradient_checkpointing = True
    _no_split_modules = ["TransformerDecoderLayer"]

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.d_model ** -0.5)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.d_model ** -0.5)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
class GSLMForCausalLM(GSLMPreTrainedModel): |
|
|
""" |
|
|
GSLM Unit Language Model - Transformer LM Big Architecture |
|
|
HuggingFace compatible version with modified forward API |
|
|
""" |
|
|
|
|
|
def __init__(self, config: GSLMConfig): |
|
|
super().__init__(config) |
|
|
self.config = config |
|
|
|
|
|
self.d_model = config.d_model |
|
|
self.vocab_size = config.vocab_size |
|
|
self.pad_idx = config.pad_idx |
|
|
self.max_seq_length = config.max_seq_length |
|
|
|
|
|
|
|
|
self.transformer = nn.Module() |
|
|
|
|
|
|
|
|
self.transformer.wte = nn.Embedding(config.vocab_size, config.d_model, padding_idx=self.pad_idx) |
|
|
self.embed_scale = math.sqrt(config.d_model) |
|
|
|
|
|
|
|
|
self.pos_encoder = PositionalEncoding(config.d_model, config.max_seq_length) |
|
|
|
|
|
|
|
|
self.transformer.h = nn.ModuleList([ |
|
|
TransformerDecoderLayer( |
|
|
config.d_model, |
|
|
config.nhead, |
|
|
config.dim_feedforward, |
|
|
config.dropout, |
|
|
config.attention_dropout |
|
|
) for _ in range(config.num_layers) |
|
|
]) |
|
|
|
|
|
|
|
|
self.transformer.ln_f = nn.LayerNorm(config.d_model) |
|
|
|
|
|
|
|
|
if config.share_input_output_embed: |
|
|
self.coch_head = lambda x: F.linear(x, self.transformer.wte.weight) |
|
|
else: |
|
|
self.coch_head = nn.Linear(config.d_model, config.vocab_size, bias=False) |
|
|
|
|
|
|
|
|
self.transformer.drop = nn.Dropout(config.dropout) |
|
|
|
|
|
|
|
|
self.future_heads = None |
|
|
|
|
|
|
|
|
self.post_init() |
|
|
|
|
|
def get_input_embeddings(self): |
|
|
return self.transformer.wte |
|
|
|
|
|
def set_input_embeddings(self, value): |
|
|
self.transformer.wte = value |
|
|
|
|
|
def get_output_embeddings(self): |
|
|
if self.config.share_input_output_embed: |
|
|
return self.transformer.wte |
|
|
else: |
|
|
return self.coch_head |
|
|
|
|
|
def _create_causal_mask(self, seq_len: int, device) -> torch.Tensor: |
|
|
"""Create causal attention mask.""" |
|
|
mask = torch.triu( |
|
|
torch.full((seq_len, seq_len), float('-inf'), device=device), |
|
|
diagonal=1 |
|
|
) |
|
|
return mask |
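
    # For seq_len = 3 the additive mask produced above is
    # (0 = may attend, -inf = blocked):
    #
    #     [[0., -inf, -inf],
    #      [0.,   0., -inf],
    #      [0.,   0.,   0.]]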

    def forward(
        self,
        seq=None,
        input_ids=None,
        tgt=None,
        labels=None,
        output_logits=False,
        output_hidden_states=False,
        return_dict=False,
        up_until_layer=None,
        **kwargs
    ):
        """
        Forward pass accepting both legacy (seq/tgt) and HF-style
        (input_ids/labels) argument names.

        Args:
            seq: torch.Tensor of shape (b, t) - input token IDs (legacy)
            input_ids: torch.Tensor of shape (b, t) - input token IDs (HF standard)
            tgt: torch.Tensor of shape (b, t) or None - target token IDs (legacy)
            labels: torch.Tensor of shape (b, t) or None - target token IDs (HF standard)
            output_logits: bool - whether to output logits
            output_hidden_states: bool - whether to output all hidden states
            return_dict: bool - whether to return dictionary output
            up_until_layer: int or None - stop before applying the layer at this index

        Returns:
            Output type depends on return_dict and the other flags.
        """
        if seq is None and input_ids is not None:
            seq = input_ids
        elif seq is None and input_ids is None:
            raise ValueError("Either 'seq' or 'input_ids' must be provided")

        if tgt is None and labels is not None:
            tgt = labels

        batch_size, seq_len = seq.shape
        device = seq.device

        causal_mask = self._create_causal_mask(seq_len, device)
        padding_mask = seq.eq(self.pad_idx)

        # Embed tokens, add positional encodings, apply embedding dropout.
        tok_emb = self.transformer.wte(seq) * self.embed_scale
        x = self.pos_encoder(tok_emb)
        x = self.transformer.drop(x)

        all_hidden_states = []

        for block_idx, block in enumerate(self.transformer.h):
            if output_hidden_states:
                all_hidden_states.append(x)

            # Stop before applying the layer at index `up_until_layer`.
            if up_until_layer is not None and block_idx == up_until_layer:
                break

            x = block(x, causal_mask, padding_mask)

        if output_hidden_states and (up_until_layer is None or block_idx == len(self.transformer.h) - 1):
            all_hidden_states.append(x)

        # Hidden-states-only path: skip the final norm and output head.
        if output_hidden_states and not output_logits and tgt is None:
            return BaseModelOutput(
                last_hidden_state=x,
                hidden_states=tuple(all_hidden_states) if all_hidden_states else None,
            )

        x = self.transformer.ln_f(x)
        logits = self.coch_head(x)

        if tgt is not None:
            # Shift so that position i predicts token i + 1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = tgt[..., 1:].contiguous()

            loss = F.cross_entropy(
                shift_logits.reshape(-1, self.config.vocab_size),
                shift_labels.reshape(-1),
                ignore_index=self.pad_idx
            )

            if return_dict:
                # Wrap logits in a list when output_logits is requested.
                all_logits = [logits] if output_logits else None

                return CausalLMOutput(
                    loss=loss,
                    logits=all_logits if output_logits else logits,
                    hidden_states=tuple(all_hidden_states) if (output_hidden_states and all_hidden_states) else None,
                )

            return logits, loss

        if return_dict:
            return CausalLMOutputWithCrossAttentions(
                logits=logits,
                hidden_states=tuple(all_hidden_states) if output_hidden_states else None,
            )

        return logits, None
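
    # Return conventions for forward(), summarized from the branches above:
    #
    #     logits, loss = model(input_ids=ids, labels=ids)            # tuple API
    #     logits, _ = model(input_ids=ids)                           # no labels -> loss is None
    #     out = model(input_ids=ids, labels=ids, return_dict=True)   # out.loss, out.logits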

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor = None,
        seq: torch.Tensor = None,
        max_length: int = 100,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        **kwargs
    ) -> torch.Tensor:
        """Generate sequences autoregressively from the language model."""
        if input_ids is None and seq is not None:
            input_ids = seq
        elif input_ids is None:
            raise ValueError("Either 'input_ids' or 'seq' must be provided")

        if pad_token_id is None:
            pad_token_id = self.pad_idx

        batch_size = input_ids.shape[0]
        device = input_ids.device

        # 1 = still generating, 0 = already emitted EOS.
        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=device)

        while input_ids.shape[1] < max_length:
            # Full forward pass; only the last position's logits are needed.
            logits, _ = self.forward(input_ids)
            next_token_logits = logits[:, -1, :]

            if temperature != 1.0:
                next_token_logits = next_token_logits / temperature

            # Top-k filtering: drop everything below the k-th largest logit.
            if top_k is not None:
                indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
                next_token_logits[indices_to_remove] = -float('inf')

            # Top-p (nucleus) filtering: keep the smallest set of tokens whose
            # cumulative probability exceeds top_p.
            if top_p is not None:
                sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

                sorted_indices_to_remove = cumulative_probs > top_p
                # Shift right so the first token above the threshold is kept.
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0

                indices_to_remove = sorted_indices_to_remove.scatter(
                    dim=-1, index=sorted_indices, src=sorted_indices_to_remove
                )
                next_token_logits[indices_to_remove] = -float('inf')

            probs = F.softmax(next_token_logits, dim=-1)
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(-1)

            # Finished sequences keep emitting pad tokens.
            if eos_token_id is not None:
                tokens_to_add = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
                unfinished_sequences = unfinished_sequences * (next_tokens != eos_token_id).long()
            else:
                tokens_to_add = next_tokens

            input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1)

            # Stop early once every sequence has finished.
            if eos_token_id is not None and unfinished_sequences.sum() == 0:
                break

        return input_ids


# Register with the HuggingFace Auto classes so that AutoConfig /
# AutoModelForCausalLM can resolve the "gslm" model type.
AutoConfig.register("gslm", GSLMConfig)
AutoModelForCausalLM.register(GSLMConfig, GSLMForCausalLM)
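

if __name__ == "__main__":
    # Minimal smoke test: a sketch only. The GSLMConfig keyword names below
    # (vocab_size, d_model, nhead, num_layers, dim_feedforward, dropout,
    # attention_dropout, max_seq_length, pad_idx, share_input_output_embed)
    # mirror the attributes accessed in this file; adjust them if the actual
    # config class differs.
    config = GSLMConfig(
        vocab_size=104,
        d_model=64,
        nhead=4,
        num_layers=2,
        dim_feedforward=128,
        dropout=0.1,
        attention_dropout=0.1,
        max_seq_length=128,
        pad_idx=0,
        share_input_output_embed=True,
    )
    model = GSLMForCausalLM(config).eval()

    ids = torch.randint(1, config.vocab_size, (2, 16))
    logits, loss = model(input_ids=ids, labels=ids)
    print("logits:", tuple(logits.shape), "loss:", float(loss))

    generated = model.generate(input_ids=ids[:, :4], max_length=12, top_k=20)
    print("generated:", tuple(generated.shape))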