AuriStream-base / modeling_auristream.py
gretatuckute's picture
Doc strings + comments
df1b42e verified
"""
AuriStream Model for HuggingFace Transformers.
AuriStream is a speech language model by Greta Tuckute and Klemen Kotar.
This model predicts cochlear tokens from a tokenizer such as WavCochCausalV8192.
https://huggingface.co/TuKoResearch/WavCochCausalV8192
"""
import math
from typing import Optional, List
import torch
import torch.nn as nn
from torch.nn import functional as F
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutput, BaseModelOutput
from .configuration_auristream import AuriStreamConfig
# ============================================================================
# Building Blocks
# ============================================================================
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension.

    Args:
        dim: Size of the last (feature) dimension; length of the gain vector.
        weight: If true, a learnable per-feature gain initialised to ones is
            created; otherwise ``self.weight`` is ``None``.
        bias: Accepted in the signature but not used; no bias term is created.
        eps: Constant added to the mean square before the reciprocal sqrt.
    """

    def __init__(self, dim: int, weight: bool = True, bias: bool = False, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        if weight:
            self.weight = nn.Parameter(torch.ones(dim))
        else:
            self.weight = None

    def _norm(self, x):
        """Divide ``x`` by the root mean square of its last dimension."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        """Normalise ``x`` in float32, cast back, then apply the gain if any."""
        normed = self._norm(x.float()).type_as(x)
        if self.weight is None:
            return normed
        return normed * self.weight
class Rotary(nn.Module):
    """Rotary position embedding (RoPE) cos/sin table generator.

    Args:
        dim: Per-head feature size; ``dim // 2`` (rounded up) frequencies are used.
        base: Base of the geometric frequency progression.
    """

    def __init__(self, dim: int, base: float = 10000):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer("inv_freq", 1.0 / (base ** exponents))

    def forward(self, x):
        """Return ``(cos, sin)`` tables for positions ``0 .. x.shape[1] - 1``.

        Only ``x.shape[1]`` and ``x.device`` are read from ``x``. Both outputs
        have shape ``(1, x.shape[1], 1, len(inv_freq))``.
        """
        n_positions = x.shape[1]
        positions = torch.arange(n_positions, device=x.device).type_as(self.inv_freq)
        angles = torch.outer(positions, self.inv_freq).to(x.device)
        cos_table = angles.cos()
        sin_table = angles.sin()
        return cos_table[None, :, None, :], sin_table[None, :, None, :]
def apply_rotary_emb(x, cos, sin):
    """Rotate the two halves of the last dimension of ``x`` by ``cos``/``sin``.

    ``x`` must be 4-D. Its last dimension is split in half and each
    (first-half, second-half) pair is rotated; ``cos`` and ``sin`` must
    broadcast against one half. The result has the same shape as ``x``.
    """
    assert x.ndim == 4  # multihead attention expected
    half = x.shape[3] // 2
    lower, upper = x[..., :half], x[..., half:]
    rotated_lower = lower * cos + upper * sin
    rotated_upper = lower * (-sin) + upper * cos
    return torch.cat([rotated_lower, rotated_upper], dim=3)
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with rotary position embeddings (RoPE).

    Uses a single bias-free linear layer for the fused query/key/value
    projection and a bias-free output projection. ``config`` must provide
    ``n_head`` and ``n_embd`` (with ``n_embd`` divisible by ``n_head``) and may
    provide ``rope_theta`` (defaults to 10000 when missing or ``None``).
    """

    def __init__(self, config: AuriStreamConfig):
        super().__init__()
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = self.n_embd // self.n_head
        assert self.n_embd % self.n_head == 0
        # Key, query, value projections for all heads
        self.c_attn = nn.Linear(self.n_embd, 3 * self.n_embd, bias=False)
        # Output projection
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
        # RoPE
        rope_theta = getattr(config, 'rope_theta', 10000)
        if rope_theta is None:
            rope_theta = 10000
        self.rotary = Rotary(self.head_dim, base=rope_theta)

    def forward(self, x, return_kv=False, return_attn_maps=False):
        """Causal self-attention over a full sequence.

        Args:
            x: Input of shape (B, T, C) with ``C == n_embd``.
            return_kv: If true, also return the rotated keys and the values,
                each of shape (B, n_head, T, head_dim).
            return_attn_maps: If true, also return the causally masked
                attention weights of shape (B, n_head, T, T). Takes precedence
                over ``return_kv``.

        Returns:
            ``y`` of shape (B, T, C), or ``(y, attn_maps)``, or ``(y, k, v)``.
        """
        B, T, C = x.size()
        # Calculate query, key, values for all heads
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        q = q.view(B, T, self.n_head, self.head_dim)
        k = k.view(B, T, self.n_head, self.head_dim)
        v = v.view(B, T, self.n_head, self.head_dim)
        # Apply RoPE while time is still dim 1, matching the (1, T, 1, hd/2) tables
        cos, sin = self.rotary(q)
        q = apply_rotary_emb(q, cos, sin)
        k = apply_rotary_emb(k, cos, sin)
        # (B, T, n_head, hd) -> (B, n_head, T, hd)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        attn_maps = None
        if not return_kv and not return_attn_maps:
            y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        else:
            # Manual implementation of attention
            att = torch.einsum('bnsh,bnkh->bnsk', q, k) * (1.0 / math.sqrt(k.size(-1)))
            # True where the key position lies strictly after the query position;
            # built directly on the input's device.
            pos = torch.arange(T, device=x.device)
            future = (pos[None, :] > pos[:, None]).view(1, 1, T, T)
            att = att.masked_fill(future, float('-inf'))
            attn_maps = F.softmax(att, dim=-1, dtype=torch.float32).to(q.dtype)
            y = torch.einsum('bnsk,bnkh->bnsh', attn_maps, v)

        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.c_proj(y)
        if return_attn_maps:
            # Return the masked weights that actually produced ``y``.
            return y, attn_maps
        if return_kv:
            return y, k, v
        return y

    def kv_cache_forward(self, x, k_cache=None, v_cache=None):
        """Forward pass with KV cache for efficient generation.

        Args:
            x: New tokens' activations of shape (B, T, C).
            k_cache: Previously computed (already rotated) keys, with the
                sequence axis at dim 2, i.e. (B, n_head, cache_len, head_dim),
                or ``None``.
            v_cache: Previously computed values with the same layout, or ``None``.

        Returns:
            ``(y, k, v)`` where ``y`` is (B, T, C) and ``k``/``v`` are the
            caches extended by the ``T`` new positions.
        """
        B, T, C = x.size()
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        q = q.view(B, T, self.n_head, self.head_dim)
        k = k.view(B, T, self.n_head, self.head_dim)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)

        # Apply RoPE with correct position: the new tokens occupy absolute
        # positions cache_len .. cache_len + T - 1.
        cache_len = k_cache.shape[2] if k_cache is not None else 0
        # Rotary only reads shape[1] and the device, so a tiny placeholder suffices.
        placeholder = x.new_empty((1, cache_len + T))
        cos, sin = self.rotary(placeholder)
        cos = cos[:, cache_len:cache_len + T, :, :]
        sin = sin[:, cache_len:cache_len + T, :, :]
        # Rotate before moving heads to dim 1 so the (1, T, 1, hd/2) tables
        # broadcast over the time axis for any T, not only T == 1.
        q = apply_rotary_emb(q, cos, sin).transpose(1, 2)
        k = apply_rotary_emb(k, cos, sin).transpose(1, 2)

        # Concatenate with cache
        if k_cache is not None:
            k = torch.cat((k_cache, k), dim=2)
        if v_cache is not None:
            v = torch.cat((v_cache, v), dim=2)

        # Attention
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        if T > 1:
            # Several new tokens at once: each may see the cache and earlier
            # new tokens, but not later ones.
            q_pos = torch.arange(cache_len, cache_len + T, device=x.device)
            k_pos = torch.arange(k.size(2), device=x.device)
            att = att.masked_fill(k_pos[None, :] > q_pos[:, None], float('-inf'))
        att = F.softmax(att, dim=-1)
        y = att @ v
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.c_proj(y)
        return y, k, v
class MLP(nn.Module):
    """Feed-forward sub-layer: expand to ``4 * n_embd``, SiLU, project back, dropout.

    Reads ``n_embd``, ``bias`` and ``dropout`` from ``config``.
    """

    def __init__(self, config: AuriStreamConfig):
        super().__init__()
        width = config.n_embd
        hidden_width = 4 * width
        self.c_fc = nn.Linear(width, hidden_width, bias=config.bias)
        # The attribute is called ``gelu`` but the activation is SiLU.
        self.gelu = nn.SiLU()
        self.c_proj = nn.Linear(hidden_width, width, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Apply expand -> activation -> project -> dropout to ``x``."""
        hidden = self.gelu(self.c_fc(x))
        return self.dropout(self.c_proj(hidden))
class Block(nn.Module):
    """Transformer block with pre-normalization.

    Computes ``x + attn_scale * attn(norm1(x))`` followed by
    ``x + mlp(norm2(x))``. ``attn_scale`` is fixed to 1.0 at construction.
    """

    def __init__(self, config: AuriStreamConfig):
        super().__init__()
        self.attn = CausalSelfAttention(config)
        self.mlp = MLP(config)
        self.attn_scale = 1.0
        self.norm1 = RMSNorm(config.n_embd, bias=config.bias)
        self.norm2 = RMSNorm(config.n_embd, bias=config.bias)

    def forward(self, x, return_kv=False, k_cache=None, v_cache=None):
        """Run the block, optionally producing or consuming a KV cache.

        Args:
            x: Input activations.
            return_kv: If true (and no cache is given), also return the keys
                and values computed by the attention layer.
            k_cache: Cached keys for this block, or ``None``.
            v_cache: Cached values for this block, or ``None``.

        Returns:
            ``x`` alone on the plain path; ``(x, k, v)`` when ``return_kv`` is
            true or when both ``k_cache`` and ``v_cache`` are given (the cache
            path takes precedence over ``return_kv``).
        """
        normed = self.norm1(x)
        use_cache = k_cache is not None and v_cache is not None
        k = v = None
        if use_cache:
            attn_out, k, v = self.attn.kv_cache_forward(normed, k_cache, v_cache)
        elif return_kv:
            attn_out, k, v = self.attn(normed, return_kv=True)
        else:
            attn_out = self.attn(normed)

        # Apply the same attention scaling on every path so that cached
        # generation matches the plain forward pass.
        x = x + self.attn_scale * attn_out
        x = x + self.mlp(self.norm2(x))

        if use_cache or return_kv:
            return x, k, v
        return x
# ============================================================================
# Main Model
# ============================================================================
class AuriStreamPreTrainedModel(PreTrainedModel):
    """Common base for AuriStream models: config binding and weight init."""

    config_class = AuriStreamConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Block"]

    def _init_weights(self, module):
        """Initialise ``nn.Linear`` / ``nn.Embedding`` weights from N(0, 0.02).

        Linear biases, when present, are set to zero. Other module types are
        left untouched.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if isinstance(module, nn.Linear) and module.bias is not None:
            torch.nn.init.zeros_(module.bias)
class AuriStreamModel(AuriStreamPreTrainedModel):
    """
    AuriStream speech language model.

    A GPT-like transformer model for cochlear token prediction with optional
    multi-token prediction (MTP) heads for improved representation learning and
    novel inference capabilities.

    Developed by Greta Tuckute and Klemen Kotar.
    """

    config_class = AuriStreamConfig

    def __init__(self, config: AuriStreamConfig):
        super().__init__(config)
        self.config = config

        # Transformer components (no wrapper to match weight keys)
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.drop = nn.Dropout(config.dropout)
        self.h = nn.ModuleList([Block(config) for _ in range(config.n_layer)])
        self.ln_f = RMSNorm(config.n_embd, bias=config.bias)

        # Multi-token prediction heads. A missing or ``None`` ``n_pred_steps``
        # is treated as 1 (no extra heads).
        n_pred_steps = getattr(config, 'n_pred_steps', None) or 1
        if n_pred_steps > 1:
            self.future_heads = nn.ModuleList([
                nn.Linear(config.n_embd, config.vocab_size, bias=False)
                for _ in range(n_pred_steps - 1)
            ])
        else:
            self.future_heads = None

        # "Standard" LM output head
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Initialize weights
        self.apply(self._init_weights)

        # Apply special scaled init to residual projections
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.wte

    def set_input_embeddings(self, value):
        """Replace the token embedding module."""
        self.wte = value

    def get_num_params(self, non_embedding=True):
        """Return the total number of parameters in the model.

        NOTE(review): ``non_embedding`` is currently ignored; the count always
        includes the token embedding.
        """
        return sum(p.numel() for p in self.parameters())

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_logits: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
        up_until_layer: Optional[int] = None,
        normalize_embeddings: Optional[str] = None,
        # Legacy arguments for compatibility
        seq: Optional[torch.LongTensor] = None,
        tgt: Optional[torch.LongTensor] = None,
    ):
        """
        Forward pass for the AuriStream model.

        Args:
            input_ids: Input token IDs of shape (batch_size, seq_len)
            labels: Target token IDs for computing loss. They are compared
                position-by-position with the logits (no shifting is done here).
            output_logits: Whether to return all logits (including from future heads).
                The first element corresponds to the standard next-token head (prediction of i+1);
                subsequent elements correspond to future heads predicting tokens i+2, i+3, etc.
            output_hidden_states: Whether to return all hidden states, including the input
                embedding state and final pre-ln_f state. Matches HuggingFace GPT-style.
            return_dict: Whether to return a dict or tuple. If True, return a CausalLMOutput dict,
                otherwise return a tuple.
            up_until_layer: If set, the loop stops before running the block with this
                index, so only blocks ``0 .. up_until_layer - 1`` are applied and the
                last collected hidden state is the input to block ``up_until_layer``.
                If logits or a loss are requested as well, they are computed from
                that intermediate activation.
            normalize_embeddings: 'l2' or 'learned' to normalize the returned hidden
                states; any other value leaves them unchanged.
            seq: Legacy argument (alias for input_ids for backward compatibility)
            tgt: Legacy argument (alias for labels for backward compatibility)

        Returns:
            If only hidden states are requested (``output_hidden_states`` true,
            ``output_logits`` false, no labels): a BaseModelOutput with
            ``last_hidden_state`` and ``hidden_states``, regardless of ``return_dict``.
            Otherwise, if return_dict is True:
                CausalLMOutput with fields:
                    - loss (optional): Scalar training loss
                    - logits: Tensor or list of tensors of prediction logits
                    - hidden_states (optional): List of hidden states
            Otherwise:
                Tuple of (logits or list of logits, loss).

        Raises:
            ValueError: If neither ``input_ids`` nor ``seq`` is provided.
        """
        # Handle legacy arguments
        if seq is not None:
            input_ids = seq
        if tgt is not None:
            labels = tgt
        if input_ids is None:
            raise ValueError("Either `input_ids` or `seq` must be provided.")

        # Get embeddings
        x = self.drop(self.wte(input_ids))

        # Forward through transformer blocks, recording the input of each block
        all_hidden_states = []
        exited_early = False
        for block_idx, block in enumerate(self.h):
            all_hidden_states.append(x)
            if up_until_layer is not None and block_idx == up_until_layer:
                exited_early = True
                break
            x = block(x)

        # Append final pre-ln_f state only if every block was run; on early
        # exit the current state has already been recorded above.
        if not exited_early:
            all_hidden_states.append(x)

        # Normalize hidden states if requested
        hs_to_return = all_hidden_states
        if output_hidden_states and normalize_embeddings is not None:
            if normalize_embeddings == 'l2':
                # Preserve direction, discard magnitude: every token vector gets unit L2 norm.
                hs_to_return = [F.normalize(h, p=2, dim=-1) for h in all_hidden_states]
            elif normalize_embeddings == 'learned':
                # Use the learned RMSNorm each state is fed through next: the
                # block's norm1 for block inputs, ln_f for the final state.
                n_blocks = len(self.h)
                hs_to_return = [
                    self.h[i].norm1(h) if i < n_blocks else self.ln_f(h)
                    for i, h in enumerate(all_hidden_states)
                ]

        # If only hidden states requested (not logits), return early
        if output_hidden_states and not output_logits and labels is None:
            return BaseModelOutput(
                last_hidden_state=x,
                hidden_states=hs_to_return,
            )

        # Final layer norm and output head
        x = self.ln_f(x)
        logits = self.lm_head(x)

        # lm_head is the "standard" head predicting token i+1. self.future_heads[j]
        # predicts token i+2+j, so it only sees the first T-(j+1) positions.
        # The future logits are computed once and shared between the returned
        # logits and the loss; they are skipped when neither needs them.
        future_logits = []
        if self.future_heads is not None and (output_logits or labels is not None):
            future_logits = [
                head(x[:, :-(i + 1)]) for i, head in enumerate(self.future_heads)
            ]

        all_logits = [logits, *future_logits] if output_logits else None

        # Compute loss if labels provided
        loss = None
        if labels is not None:
            vocab_size = self.config.vocab_size
            # Loss from the first "standard" lm head
            loss = F.cross_entropy(
                logits.reshape(-1, vocab_size),
                labels.reshape(-1),
            )
            # Multi-token prediction loss: head i is scored against labels shifted by i+1
            for i, head_logits in enumerate(future_logits):
                loss = loss + F.cross_entropy(
                    head_logits.reshape(-1, vocab_size),
                    labels[:, (i + 1):].reshape(-1),
                )

        logits_out = all_logits if output_logits else logits
        if not return_dict:
            return logits_out, loss

        return CausalLMOutput(
            loss=loss,
            logits=logits_out,
            hidden_states=hs_to_return if output_hidden_states else None,
        )

    def sample_logits(
        self,
        logits: torch.FloatTensor,
        temperature: float = 0.9,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
    ) -> torch.LongTensor:
        """Sample from logits with temperature, top-k, and top-p.

        Args:
            logits: Tensor whose last dimension indexes the vocabulary.
            temperature: Divides the logits; exactly 0.0 selects the argmax.
            top_k: Keep only the ``top_k`` largest logits. ``None`` or a
                non-positive value disables the filter.
            top_p: Nucleus threshold; tokens are dropped once the cumulative
                probability of higher-ranked tokens exceeds it (the top token
                is always kept). ``None`` disables the filter.

        Returns:
            Sampled token IDs with the shape of ``logits`` minus its last dimension.
        """
        if temperature == 0.0:
            return torch.argmax(logits, dim=-1)

        # Division creates a new tensor, so the caller's logits are not modified.
        logits = logits / temperature

        if top_k is not None and top_k > 0:
            top_values, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            kth_value = top_values[..., [-1]]
            logits = logits.masked_fill(logits < kth_value, float('-inf'))

        if top_p is not None:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            remove_sorted = cumulative_probs > top_p
            # Shift right so the first token that crosses the threshold is kept.
            remove_sorted[..., 1:] = remove_sorted[..., :-1].clone()
            remove_sorted[..., 0] = False
            # Map the mask from sorted order back to vocabulary order.
            remove = remove_sorted.scatter(
                dim=-1, index=sorted_indices, src=remove_sorted
            )
            logits = logits.masked_fill(remove, float('-inf'))

        probs = F.softmax(logits, dim=-1)
        flat_probs = probs.reshape(-1, probs.size(-1))
        sampled = torch.multinomial(flat_probs, num_samples=1)
        return sampled.view(*logits.shape[:-1])

    @torch.no_grad()
    def generate(
        self,
        seq: torch.Tensor,
        n_tokens: int = 1,
        temp: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        seed: Optional[int] = None,
    ):
        """
        Generate new tokens autoregressively.

        Args:
            seq: Input token IDs of shape (batch_size, seq_len)
            n_tokens: Number of tokens to generate. At least one token is
                always generated, even if ``n_tokens`` is less than 1.
            temp: Sampling temperature
            top_k: Top-k sampling parameter
            top_p: Nucleus sampling parameter
            seed: Random seed. If given, seeds Python's ``random``, NumPy and
                torch global generators.

        Returns:
            Tuple of (generated_tokens, all_logits): token IDs of shape
            (batch_size, generated) and the unscaled logits each token was
            sampled from, of shape (batch_size, generated, vocab_size).
        """
        if seed is not None:
            import random
            import numpy as np
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)

        # Encode conditioning sequence into per-layer KV caches. The caches
        # are kept as Python lists so they are not re-stacked at every step.
        x = self.drop(self.wte(seq))
        k_cache = []
        v_cache = []
        for block in self.h:
            x, k, v = block(x, return_kv=True)
            k_cache.append(k)
            v_cache.append(v)
        x = self.ln_f(x)

        # First prediction comes from the last conditioning position
        logits = self.lm_head(x[:, [-1]])
        predictions = [self.sample_logits(logits, temperature=temp, top_k=top_k, top_p=top_p)]
        all_logits = [logits]

        # Generate remaining tokens one at a time, feeding back the last sample
        for _ in range(n_tokens - 1):
            x = self.drop(self.wte(predictions[-1]))
            for block_idx, block in enumerate(self.h):
                x, k, v = block(x, k_cache=k_cache[block_idx], v_cache=v_cache[block_idx])
                k_cache[block_idx] = k
                v_cache[block_idx] = v
            logits = self.lm_head(self.ln_f(x))
            predictions.append(self.sample_logits(logits, temperature=temp, top_k=top_k, top_p=top_p))
            all_logits.append(logits)

        pred_coch = torch.cat(predictions, dim=1)
        all_logits = torch.cat(all_logits, dim=1)
        return pred_coch, all_logits
# Backward-compatibility alias: ``AuriStream`` is the same class object as
# ``AuriStreamModel``.
AuriStream = AuriStreamModel