# gslm-ulm / modeling_gslm.py
# (Hugging Face listing residue, commented out so the file parses:)
# klemenk's picture
# Update modeling_gslm.py
# 2c0d1ed verified
# custom_code/modeling_unit_lm.py
import math
import os
import random
from typing import List, Dict, Optional, Tuple, Any
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutput
from .configuration_unit_lm import UnitLMConfig
from .units_dictionary import UnitDictionary
# -----------------------
# Positional embeddings
# -----------------------
class SinusoidalPositionalEmbedding(nn.Module):
    """Fixed (non-learned) sine/cosine positional embeddings, precomputed once.

    Follows the "Attention Is All You Need" recipe: even channels carry sin,
    odd channels carry cos, with geometrically decaying frequencies.
    """

    def __init__(self, dim: int, max_len: int):
        super().__init__()
        position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)
        inv_freq = torch.exp(
            torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim)
        )
        table = torch.zeros(max_len, dim)
        table[:, 0::2] = torch.sin(position * inv_freq)
        table[:, 1::2] = torch.cos(position * inv_freq)
        # Non-persistent: rebuilt on construction, kept out of the state dict.
        self.register_buffer("pe", table, persistent=False)

    def forward(self, positions: torch.LongTensor):
        """Gather table rows for the given position indices; returns (len(positions), dim)."""
        return self.pe.index_select(0, positions)
# -----------------------
# Blocks
# -----------------------
def build_norm(norm_type: str, dim: int, bias: bool):
    """Construct a normalization layer for the trunk.

    Only "layernorm" is supported. `bias` is accepted for signature parity but
    is not forwarded — nn.LayerNorm keeps its default affine parameters here.
    """
    if norm_type != "layernorm":
        raise ValueError(f"Unsupported norm_type={norm_type}")
    return nn.LayerNorm(dim, eps=1e-5)
class MLP(nn.Module):
    """Position-wise feed-forward block: Linear(dim -> 4*dim) -> GELU -> Linear -> Dropout."""

    def __init__(self, dim: int, dropout: float, bias: bool):
        super().__init__()
        hidden = 4 * dim
        self.fc1 = nn.Linear(dim, hidden, bias=bias)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden, dim, bias=bias)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        hidden = self.act(self.fc1(x))
        return self.drop(self.fc2(hidden))
class CausalSelfAttention(nn.Module):
    """Multi-head self-attention with a strict causal (lower-triangular) mask.

    `impl` selects the projection layout: "separate_qkv" mirrors fairseq's
    three independent projections; any other value uses one fused QKV linear.
    Submodule names (q_proj/k_proj/v_proj, c_attn, out_proj) are kept for
    state-dict compatibility.
    """

    def __init__(self, n_embd: int, n_head: int, bias: bool, impl: str):
        super().__init__()
        self.n_head = n_head
        self.n_embd = n_embd
        self.head_dim = n_embd // n_head
        assert n_embd % n_head == 0, "n_embd must be divisible by n_head"
        self.impl = impl
        if impl == "separate_qkv":
            self.q_proj = nn.Linear(n_embd, n_embd, bias=bias)
            self.k_proj = nn.Linear(n_embd, n_embd, bias=bias)
            self.v_proj = nn.Linear(n_embd, n_embd, bias=bias)
        else:
            # Fused projection is supported, but fairseq checkpoints use the
            # separate layout above.
            self.c_attn = nn.Linear(n_embd, 3 * n_embd, bias=bias)
        self.out_proj = nn.Linear(n_embd, n_embd, bias=bias)

    def forward(self, x):
        B, T, C = x.shape
        if self.impl == "separate_qkv":
            q, k, v = self.q_proj(x), self.k_proj(x), self.v_proj(x)
        else:
            q, k, v = self.c_attn(x).split(self.n_embd, dim=-1)
        # (B, T, C) -> (B, n_head, T, head_dim)
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        # Mask out future positions (strictly above the diagonal).
        future = torch.ones(T, T, device=x.device, dtype=torch.bool).triu(1)
        scores = scores.masked_fill(future[None, None, :, :], float("-inf"))
        weights = F.softmax(scores, dim=-1)
        out = (weights @ v).transpose(1, 2).contiguous().view(B, T, C)
        return self.out_proj(out)
class Block(nn.Module):
    """Pre-norm Transformer decoder block: residual attention, then residual MLP."""

    def __init__(self, cfg: UnitLMConfig):
        super().__init__()
        self.ln1 = build_norm(cfg.norm_type, cfg.n_embd, cfg.bias)
        self.attn = CausalSelfAttention(cfg.n_embd, cfg.n_head, cfg.bias, cfg.attn_impl)
        self.ln2 = build_norm(cfg.norm_type, cfg.n_embd, cfg.bias)
        self.mlp = MLP(cfg.n_embd, cfg.dropout, cfg.bias)

    def forward(self, x):
        attn_out = self.attn(self.ln1(x))
        x = x + attn_out
        mlp_out = self.mlp(self.ln2(x))
        return x + mlp_out
# -----------------------
# Unit Language Model
# -----------------------
class UnitLanguageModel(PreTrainedModel):
    """
    Decoder-only Transformer LM over discrete unit tokens, fairseq-compatible
    topology.

    Public surface:
      - forward(input_ids[, tgt]) -> CausalLMOutput (logits, optional CE loss)
      - encode(unit_str) -> LongTensor of unit IDs (requires a loaded dictionary)
      - rollout(...) -> autoregressive continuation over token IDs
      - sample() / sample_top_hypotheses() -> string-in / string-out generation
    """
    config_class = UnitLMConfig

    def __init__(self, config: UnitLMConfig):
        super().__init__(config)
        self.cfg = config
        # Unit dictionary; attached later by the from_pretrained override below.
        self.dictionary: Optional[UnitDictionary] = None
        # Token embedding + dropout.
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.dropout = nn.Dropout(config.dropout)
        # Positional-embedding flavour is config-driven.
        if config.pos_embed == "sinusoidal":
            self.wpe = SinusoidalPositionalEmbedding(config.n_embd, config.max_position_embeddings)
        elif config.pos_embed == "learned":
            self.wpe = nn.Embedding(config.max_position_embeddings, config.n_embd)
        else:
            self.wpe = None  # no positional signal; not typical for fairseq
        # Transformer trunk.
        self.h = nn.ModuleList([Block(config) for _ in range(config.n_layer)])
        self.ln_f = build_norm(config.norm_type, config.n_embd, config.bias)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        if config.tie_word_embeddings:
            # Weight tying: output projection shares the input embedding matrix.
            self.lm_head.weight = self.wte.weight
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # GPT-style init: N(0, 0.02) for weights, zeros for biases.
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)

    # ------------- helpers for positions -------------
    def _pos_emb(self, B: int, T: int, device) -> Optional[torch.Tensor]:
        """Return a (T, C) positional-embedding slice, or None when disabled."""
        pos = torch.arange(0, T, dtype=torch.long, device=device)
        if isinstance(self.wpe, SinusoidalPositionalEmbedding):
            return self.wpe(pos)  # (T, C)
        elif isinstance(self.wpe, nn.Embedding):
            return self.wpe(pos)
        return None

    # ------------- forward -------------
    def forward(
        self,
        input_ids: torch.LongTensor,             # (B, T)
        tgt: Optional[torch.LongTensor] = None,  # (B, T) targets aligned with logits
        return_dict: bool = True,
    ) -> CausalLMOutput:
        """Run the trunk; if `tgt` is given, also compute token-level CE loss.

        Positions equal to cfg.pad_token_id in `tgt` are excluded from the
        loss. Returns CausalLMOutput when return_dict, else (logits, loss).
        """
        B, T = input_ids.shape
        tok = self.wte(input_ids)  # (B, T, C)
        if self.wpe is not None:
            pos = self._pos_emb(B, T, input_ids.device)
            x = self.dropout(tok + pos.unsqueeze(0))
        else:
            x = self.dropout(tok)
        for blk in self.h:
            x = blk(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)  # (B, T, V)
        loss = None
        if tgt is not None:
            # Fall back to cross_entropy's default (-100) when no pad id is
            # configured; passing ignore_index=None would raise.
            ignore = self.cfg.pad_token_id if self.cfg.pad_token_id is not None else -100
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)), tgt.reshape(-1),
                ignore_index=ignore,
            )
        if return_dict:
            return CausalLMOutput(loss=loss, logits=logits)
        return logits, loss

    # ------------- dictionary I/O (parity with sampler) -------------
    def encode(self, unit_str: str, append_eos: bool = False) -> torch.LongTensor:
        """Encode a space-separated unit string into IDs via the dictionary."""
        if self.dictionary is None:
            raise RuntimeError("Dictionary not loaded. Use from_pretrained(..., trust_remote_code=True).")
        return self.dictionary.encode_line(unit_str, add_if_not_exist=False, append_eos=append_eos)

    def _strip_pad(self, x: torch.LongTensor) -> torch.LongTensor:
        """Drop pad tokens; no-op when the config defines no pad id."""
        if self.cfg.pad_token_id is None:
            return x
        return x[x != self.cfg.pad_token_id]

    def _post_process_prediction(self, tokens: torch.LongTensor) -> str:
        # Mimic the fairseq util: strip pad, then detokenize via the dictionary.
        toks = self._strip_pad(tokens)
        return self.dictionary.string(toks)

    # ------------- sampling core -------------
    @staticmethod
    def _sample_token(logits: torch.Tensor, temperature: float,
                      top_k: int = 0, top_p: float = 0.0) -> torch.LongTensor:
        """Sample one token id per row from (B, V) logits.

        temperature <= 0 degenerates to greedy argmax; top-k and nucleus
        (top-p) filtering are applied before the final multinomial draw.
        """
        if temperature <= 0:
            return torch.argmax(logits, dim=-1)
        logits = logits / max(1e-6, temperature)
        if top_k and top_k > 0:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = -float("inf")
        if top_p and top_p > 0.0:
            sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
            probs = F.softmax(sorted_logits, dim=-1)
            cum = torch.cumsum(probs, dim=-1)
            # Shift the cutoff so the token that crosses the p threshold is kept.
            mask = cum > top_p
            mask[:, 1:] = mask[:, :-1].clone()
            mask[:, 0] = False
            filtered = torch.full_like(sorted_logits, -float("inf"))
            filtered[~mask] = sorted_logits[~mask]
            # sorted_idx is a permutation, so scatter rewrites every column,
            # mapping the filtered values back to their original positions.
            logits = filtered.scatter(1, sorted_idx, filtered)
        probs = F.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1).squeeze(-1)

    def _target_len(self, src_len: int, max_len_a: float, max_len_b: int) -> int:
        # fairseq convention: max_len = a * |src| + b
        return int(max_len_a * src_len + max_len_b)

    # ------------- rollout over token IDs -------------
    @torch.no_grad()
    def rollout(
        self,
        src_tokens: torch.LongTensor,       # (B, T_prompt)
        temperature: Optional[float] = None,
        sampling: Optional[bool] = None,
        beam: Optional[int] = None,
        prefix_size: Optional[int] = None,  # accepted for API parity; unused here
        max_len_a: Optional[float] = None,
        max_len_b: Optional[int] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        seed: Optional[int] = None,
        stop_on_eos: bool = True,
        return_full_sequence: bool = False,
    ) -> Tuple[torch.LongTensor, torch.Tensor]:
        """
        Autoregressively continue token IDs (IDs-in, IDs-out).

        Any argument left as None falls back to the matching cfg.generation_*
        default. Returns (continuations, logits_per_step) where:
          - continuations: (B, T_new), or the full (B, T_prompt + T_new)
            sequence when return_full_sequence is True
          - logits_per_step: (B, T_new, V) pre-softmax logits at each step

        NOTE: with stop_on_eos, generation halts only on a step where every
        batch row emits EOS simultaneously; already-finished rows keep
        receiving tokens until then.
        """
        if seed is not None:
            random.seed(seed)
            torch.manual_seed(seed)
        cfg = self.cfg
        temperature = cfg.generation_temperature if temperature is None else temperature
        sampling = cfg.generation_sampling if sampling is None else sampling
        beam = cfg.generation_beam if beam is None else beam
        prefix_size = cfg.generation_prefix_size if prefix_size is None else prefix_size
        max_len_a = cfg.generation_max_len_a if max_len_a is None else max_len_a
        max_len_b = cfg.generation_max_len_b if max_len_b is None else max_len_b
        top_k = cfg.generation_top_k if top_k is None else top_k
        top_p = cfg.generation_top_p if top_p is None else top_p
        if beam and beam > 1:
            # Minimal beam search (no length penalty). For units, beam=1 is typical.
            return self._beam_generate(
                src_tokens, beam, temperature, prefix_size, max_len_a, max_len_b, stop_on_eos
            )
        device = next(self.parameters()).device
        src_tokens = src_tokens.to(device)
        B, T0 = src_tokens.shape
        # Budget of new tokens to generate.
        tgt_len = self._target_len(T0, max_len_a, max_len_b)
        seq = src_tokens.clone()
        logits_steps = []
        for step in range(tgt_len):
            # Full re-forward each step (no KV cache).
            out = self.forward(seq)
            logits = out.logits[:, -1, :]  # (B, V)
            if sampling:
                next_id = self._sample_token(logits, temperature, top_k=top_k, top_p=top_p).unsqueeze(1)
            else:
                next_id = torch.argmax(logits, dim=-1, keepdim=True)
            logits_steps.append(logits.unsqueeze(1))
            seq = torch.cat([seq, next_id], dim=1)
            if stop_on_eos and (next_id == self.cfg.eos_token_id).all():
                break
        continuations = seq[:, T0:]  # (B, T_new)
        all_logits = (torch.cat(logits_steps, dim=1) if logits_steps
                      else torch.empty(B, 0, self.cfg.vocab_size, device=device))
        if return_full_sequence:
            return seq, all_logits
        return continuations, all_logits

    # ------------- minimal beam search (optional) -------------
    @torch.no_grad()
    def _beam_generate(self, src_tokens, beam, temperature, prefix_size, max_len_a, max_len_b, stop_on_eos):
        """Basic per-sentence beam search; temperature is ignored (pure log-prob scoring).

        Returns (continuations, empty tensor) — per-step logits are not tracked.
        """
        device = next(self.parameters()).device
        src_tokens = src_tokens.to(device)
        B, T0 = src_tokens.shape
        tgt_len = self._target_len(T0, max_len_a, max_len_b)
        # Per batch item: list of (cumulative log-prob, (1, T) sequence).
        sequences = [[(0.0, src_tokens[b:b + 1])] for b in range(B)]
        finished = [[] for _ in range(B)]
        for _ in range(tgt_len):
            for b in range(B):
                all_exp = []
                for score, seq in sequences[b]:
                    out = self.forward(seq)
                    logprobs = F.log_softmax(out.logits[:, -1, :], dim=-1)  # (1, V)
                    top_scores, top_ids = torch.topk(logprobs, k=min(beam, logprobs.size(-1)), dim=-1)
                    for s, i in zip(top_scores[0].tolist(), top_ids[0].tolist()):
                        new_seq = torch.cat([seq, torch.tensor([[i]], device=device, dtype=torch.long)], dim=1)
                        all_exp.append((score + s, new_seq))
                # Keep the top `beam` expansions.
                all_exp.sort(key=lambda x: x[0], reverse=True)
                sequences[b] = all_exp[:beam]
                # Move hypotheses ending in EOS to the finished pool.
                remain = []
                for sc, sq in sequences[b]:
                    if stop_on_eos and sq[0, -1].item() == self.cfg.eos_token_id:
                        finished[b].append((sc, sq))
                    else:
                        remain.append((sc, sq))
                if remain:
                    sequences[b] = remain
                else:
                    # All candidates finished: keep top finished so the loop can proceed.
                    sequences[b] = finished[b][:beam] if finished[b] else sequences[b]
        # Pick the best hypothesis per batch item (finished preferred).
        outs = []
        for b in range(B):
            pool = finished[b] if finished[b] else sequences[b]
            pool.sort(key=lambda x: x[0], reverse=True)
            outs.append(pool[0][1][:, T0:])  # continuation only
        maxlen = max(x.size(1) for x in outs)
        # BUGFIX: pad along the time axis (last dim). The previous (0, 0, 0, n)
        # spec padded the singleton batch dim, leaving continuations of unequal
        # length, so torch.cat(dim=0) could fail.
        outs = [F.pad(x, (0, maxlen - x.size(1)), value=self.cfg.pad_token_id) for x in outs]
        return torch.cat(outs, dim=0), torch.empty(0)

    # ------------- String API (parity with your sampler) -------------
    @torch.no_grad()
    def sample(
        self, sentences: List[str] | str, beam: int = 1, verbose: bool = False, **kwargs
    ):
        """Generate one continuation per input; str-in -> str-out, list-in -> list-out."""
        hypos = self.sample_top_hypotheses(sentences, beam=beam, verbose=verbose, **kwargs)
        if isinstance(sentences, str):
            return hypos[0]
        return [h[0] for h in hypos]

    @torch.no_grad()
    def sample_top_hypotheses(
        self, sentences: List[str] | str, beam: int = 1, verbose: bool = False, **kwargs
    ) -> List[List[str]]:
        """Encode unit strings, batch-generate, and return one hypothesis list per input.

        NOTE(review): shorter prompts are right-padded with pad_token_id and
        the model attends over those pads — confirm this matches the original
        fairseq sampler's behaviour.
        """
        if isinstance(sentences, str):
            return self.sample_top_hypotheses([sentences], beam=beam, verbose=verbose, **kwargs)
        # Encode each sentence (units separated by spaces) and right-pad to a batch.
        encoded = [self.encode(s) for s in sentences]
        max_len = max(e.size(0) for e in encoded)
        pad_id = self.cfg.pad_token_id
        src = torch.stack([F.pad(e, (0, max_len - e.size(0)), value=pad_id) for e in encoded], dim=0).to(self.device)
        # Generation defaults with per-call overrides.
        kwargs = dict(
            temperature=kwargs.get("temperature", self.cfg.generation_temperature),
            sampling=kwargs.get("sampling", self.cfg.generation_sampling),
            beam=beam,
            prefix_size=kwargs.get("prefix_size", self.cfg.generation_prefix_size),
            max_len_a=kwargs.get("max_len_a", self.cfg.generation_max_len_a),
            max_len_b=kwargs.get("max_len_b", self.cfg.generation_max_len_b),
            top_k=kwargs.get("top_k", self.cfg.generation_top_k),
            top_p=kwargs.get("top_p", self.cfg.generation_top_p),
            seed=kwargs.get("seed", None),
            stop_on_eos=True,
        )
        cont, _ = self.rollout(src, **kwargs)  # (B, T_new)
        # Post-process prompt + continuation back to unit strings.
        outs: List[List[str]] = []
        for b in range(src.size(0)):
            full = torch.cat([src[b], cont[b]], dim=0)
            outs.append([self._post_process_prediction(full)])
        return outs

    # ------------- Pretrained override to load dict.txt -------------
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, *model_args, **kwargs):
        """Load weights as usual, then attach the unit dictionary from config.dict_file.

        NOTE(review): the os.path.join assumes a local checkpoint directory; a
        raw Hub repo id would not resolve on disk — confirm downstream usage.
        """
        model = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        repo_root = os.fspath(pretrained_model_name_or_path)
        dict_path = os.path.join(repo_root, model.config.dict_file)
        model.dictionary = UnitDictionary.from_file(dict_path)
        return model