| |
| import math |
| import os |
| import random |
| from typing import List, Dict, Optional, Tuple, Any |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from transformers import PreTrainedModel |
| from transformers.modeling_outputs import CausalLMOutput |
|
|
| from .configuration_unit_lm import UnitLMConfig |
| from .units_dictionary import UnitDictionary |
|
|
|
|
| |
| |
| |
class SinusoidalPositionalEmbedding(nn.Module):
    """Fixed (non-learned) sine/cosine position table, precomputed to ``max_len``.

    Even feature indices hold sines and odd indices hold cosines, using the
    10000^(-2i/dim) frequency schedule from "Attention Is All You Need".
    """

    def __init__(self, dim: int, max_len: int):
        super().__init__()
        positions = torch.arange(max_len, dtype=torch.float).unsqueeze(1)
        inv_freq = torch.exp(
            torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim)
        )
        angles = positions * inv_freq
        table = torch.zeros(max_len, dim)
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)
        # Buffer, not a Parameter: follows .to()/.cuda() but is never trained;
        # persistent=False keeps it out of the state dict.
        self.register_buffer("pe", table, persistent=False)

    def forward(self, positions: torch.LongTensor):
        """Return the (len(positions), dim) rows of the table at ``positions``."""
        return self.pe.index_select(0, positions)
|
|
|
|
| |
| |
| |
def build_norm(norm_type: str, dim: int, bias: bool):
    """Construct the normalization layer named by ``norm_type``.

    Only ``"layernorm"`` is supported.  NOTE(review): ``bias`` is currently
    ignored for LayerNorm — confirm whether it should control the affine
    bias term.

    Raises:
        ValueError: for any unsupported ``norm_type``.
    """
    if norm_type != "layernorm":
        raise ValueError(f"Unsupported norm_type={norm_type}")
    return nn.LayerNorm(dim, eps=1e-5)
|
|
class MLP(nn.Module):
    """Position-wise feed-forward block: Linear -> GELU -> Linear -> Dropout.

    Hidden width is the conventional 4x expansion of ``dim``.
    """

    def __init__(self, dim: int, dropout: float, bias: bool):
        super().__init__()
        hidden = 4 * dim
        # Submodule registration order matters for reproducible seeded init.
        self.fc1 = nn.Linear(dim, hidden, bias=bias)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden, dim, bias=bias)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        hidden = self.act(self.fc1(x))
        projected = self.fc2(hidden)
        return self.drop(projected)
|
|
class CausalSelfAttention(nn.Module):
    """Multi-head self-attention with a strict upper-triangular (future) mask.

    Two weight layouts are supported:
      * ``impl == "separate_qkv"``: independent q/k/v projections
        (fairseq-style checkpoints);
      * anything else: a single fused qkv projection (``c_attn``, GPT-2 style).
    Both layouts compute the same attention.
    """

    def __init__(self, n_embd: int, n_head: int, bias: bool, impl: str):
        super().__init__()
        assert n_embd % n_head == 0, "n_embd must be divisible by n_head"
        self.n_head = n_head
        self.n_embd = n_embd
        self.head_dim = n_embd // n_head
        self.impl = impl
        # Registration order is preserved for reproducible seeded init.
        if impl == "separate_qkv":
            self.q_proj = nn.Linear(n_embd, n_embd, bias=bias)
            self.k_proj = nn.Linear(n_embd, n_embd, bias=bias)
            self.v_proj = nn.Linear(n_embd, n_embd, bias=bias)
        else:
            self.c_attn = nn.Linear(n_embd, 3 * n_embd, bias=bias)
        self.out_proj = nn.Linear(n_embd, n_embd, bias=bias)

    def forward(self, x):
        B, T, C = x.shape
        if self.impl == "separate_qkv":
            q, k, v = self.q_proj(x), self.k_proj(x), self.v_proj(x)
        else:
            q, k, v = self.c_attn(x).split(self.n_embd, dim=-1)

        def split_heads(t):
            # (B, T, C) -> (B, n_head, T, head_dim)
            return t.view(B, T, self.n_head, self.head_dim).transpose(1, 2)

        q, k, v = split_heads(q), split_heads(k), split_heads(v)

        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        # Disallow attending to strictly-future positions.
        future = torch.triu(
            torch.ones(T, T, device=x.device, dtype=torch.bool), diagonal=1
        )
        scores = scores.masked_fill(future[None, None], float("-inf"))
        weights = F.softmax(scores, dim=-1)

        out = (weights @ v).transpose(1, 2).contiguous().view(B, T, C)
        return self.out_proj(out)
|
|
class Block(nn.Module):
    """One pre-norm transformer layer: residual attention then residual MLP."""

    def __init__(self, cfg: UnitLMConfig):
        super().__init__()
        self.ln1 = build_norm(cfg.norm_type, cfg.n_embd, cfg.bias)
        self.attn = CausalSelfAttention(cfg.n_embd, cfg.n_head, cfg.bias, cfg.attn_impl)
        self.ln2 = build_norm(cfg.norm_type, cfg.n_embd, cfg.bias)
        self.mlp = MLP(cfg.n_embd, cfg.dropout, cfg.bias)

    def forward(self, x):
        # LayerNorm is applied before each sub-layer (pre-norm ordering).
        attn_out = self.attn(self.ln1(x))
        x = x + attn_out
        mlp_out = self.mlp(self.ln2(x))
        return x + mlp_out
|
|
|
|
| |
| |
| |
class UnitLanguageModel(PreTrainedModel):
    """
    Decoder-only Transformer LM for unit tokens, fairseq-compatible topology.
    Provides:
      - forward(input_ids[, tgt]) -> logits, optional CE loss
      - encode(unit_str) / decode
      - sample(), sample_top_hypotheses(), rollout()
    """
    config_class = UnitLMConfig

    def __init__(self, config: UnitLMConfig):
        super().__init__(config)
        self.cfg = config

        # Filled in by ``from_pretrained`` (requires the dict file on disk).
        self.dictionary: Optional[UnitDictionary] = None

        # Token embedding + input dropout.
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.dropout = nn.Dropout(config.dropout)

        # Positional encoding: fixed sinusoidal, learned, or none.
        if config.pos_embed == "sinusoidal":
            self.wpe = SinusoidalPositionalEmbedding(config.n_embd, config.max_position_embeddings)
        elif config.pos_embed == "learned":
            self.wpe = nn.Embedding(config.max_position_embeddings, config.n_embd)
        else:
            self.wpe = None

        # Transformer trunk, final norm, and LM head.
        self.h = nn.ModuleList([Block(config) for _ in range(config.n_layer)])
        self.ln_f = build_norm(config.norm_type, config.n_embd, config.bias)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Weight tying: output projection shares the input embedding matrix.
        if config.tie_word_embeddings:
            self.lm_head.weight = self.wte.weight

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """GPT-style init: N(0, 0.02) weights, zero biases."""
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)

    def _pos_emb(self, B: int, T: int, device) -> Optional[torch.Tensor]:
        """Return a (T, n_embd) positional embedding for positions 0..T-1,
        or None when positional encoding is disabled."""
        pos = torch.arange(0, T, dtype=torch.long, device=device)
        if isinstance(self.wpe, SinusoidalPositionalEmbedding):
            return self.wpe(pos)
        elif isinstance(self.wpe, nn.Embedding):
            return self.wpe(pos)
        return None

    def forward(
        self,
        input_ids: torch.LongTensor,
        tgt: Optional[torch.LongTensor] = None,
        return_dict: bool = True,
    ) -> CausalLMOutput:
        """Run the LM over ``input_ids`` (B, T).

        Args:
            input_ids: token IDs, shape (B, T).
            tgt: optional target IDs of the same shape; when given, a
                token-level cross-entropy loss is computed with pad
                positions ignored.
            return_dict: if True, return a ``CausalLMOutput``; otherwise a
                ``(logits, loss)`` tuple.

        NOTE(review): pad positions in ``input_ids`` are attended to like
        real tokens (no attention mask) — confirm callers left-strip or
        avoid padding prompts.
        """
        B, T = input_ids.shape
        tok = self.wte(input_ids)
        if self.wpe is not None:
            pos = self._pos_emb(B, T, input_ids.device)
            x = self.dropout(tok + pos.unsqueeze(0))
        else:
            x = self.dropout(tok)

        for blk in self.h:
            x = blk(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)

        loss = None
        if tgt is not None:
            # Flat (B*T, V) vs (B*T,) cross-entropy; pad targets contribute 0.
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                tgt.reshape(-1),
                ignore_index=self.cfg.pad_token_id,
            )

        if return_dict:
            return CausalLMOutput(loss=loss, logits=logits)
        return logits, loss

    def encode(self, unit_str: str, append_eos: bool = False) -> torch.LongTensor:
        """Encode a space-separated unit string into a 1-D ID tensor.

        Raises:
            RuntimeError: if the dictionary has not been loaded (it is
                attached by ``from_pretrained``).
        """
        if self.dictionary is None:
            raise RuntimeError("Dictionary not loaded. Use from_pretrained(..., trust_remote_code=True).")
        return self.dictionary.encode_line(unit_str, add_if_not_exist=False, append_eos=append_eos)

    def _strip_pad(self, x: torch.LongTensor) -> torch.LongTensor:
        """Drop all pad tokens from a 1-D ID tensor (no-op if pad_token_id is None)."""
        if self.cfg.pad_token_id is None:
            return x
        return x[x != self.cfg.pad_token_id]

    def _post_process_prediction(self, tokens: torch.LongTensor) -> str:
        """Strip padding and render token IDs back to a unit string."""
        toks = self._strip_pad(tokens)
        return self.dictionary.string(toks)

    @staticmethod
    def _sample_token(logits: torch.Tensor, temperature: float,
                      top_k: int = 0, top_p: float = 0.0) -> torch.LongTensor:
        """Sample one token ID per row of ``logits`` (B, V).

        ``temperature <= 0`` degenerates to greedy argmax. top-k and
        nucleus (top-p) filtering are applied, in that order, before
        sampling from the renormalized distribution.
        """
        if temperature <= 0:
            return torch.argmax(logits, dim=-1)
        # Clamp the divisor so temperature ~ 0 cannot blow up the logits.
        logits = logits / max(1e-6, temperature)

        if top_k and top_k > 0:
            # Keep only the top-k logits per row.
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = -float("inf")

        if top_p and top_p > 0.0:
            sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
            probs = F.softmax(sorted_logits, dim=-1)
            cum = torch.cumsum(probs, dim=-1)
            # Mask everything after cumulative mass exceeds top_p; the shift
            # keeps the first token crossing the threshold.
            mask = cum > top_p
            mask[:, 1:] = mask[:, :-1].clone()
            mask[:, 0] = False
            filtered = torch.full_like(sorted_logits, -float("inf"))
            filtered[~mask] = sorted_logits[~mask]
            # Scatter the sorted-order values back to vocabulary order; every
            # column index appears exactly once, so all positions are written.
            logits = filtered.scatter(1, sorted_idx, filtered)

        probs = F.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1).squeeze(-1)

    def _target_len(self, src_len: int, max_len_a: float, max_len_b: int) -> int:
        """fairseq-style max target length: floor(a * src_len + b)."""
        return int(max_len_a * src_len + max_len_b)

    @torch.no_grad()
    def rollout(
        self,
        src_tokens: torch.LongTensor,
        temperature: Optional[float] = None,
        sampling: Optional[bool] = None,
        beam: Optional[int] = None,
        prefix_size: Optional[int] = None,
        max_len_a: Optional[float] = None,
        max_len_b: Optional[int] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        seed: Optional[int] = None,
        stop_on_eos: bool = True,
        return_full_sequence: bool = False,
    ) -> Tuple[torch.LongTensor, torch.Tensor]:
        """
        Autoregressively continue token IDs (IDs-in, IDs-out).
        Returns (continuations, logits_per_step) where:
          - continuations: (B, T_new)
          - logits_per_step: (B, T_new, V)

        Any argument left as None falls back to the corresponding
        ``generation_*`` field of the config. With ``beam > 1`` this
        delegates to ``_beam_generate`` (which returns empty logits).

        NOTE(review): ``prefix_size`` is accepted but unused on the
        sampling/greedy path — confirm intended semantics. Generation
        re-runs the full forward each step (no KV cache), and early stop
        fires only when ALL batch rows emit EOS at the same step.
        """
        if seed is not None:
            random.seed(seed); torch.manual_seed(seed)

        cfg = self.cfg
        temperature = cfg.generation_temperature if temperature is None else temperature
        sampling = cfg.generation_sampling if sampling is None else sampling
        beam = cfg.generation_beam if beam is None else beam
        prefix_size = cfg.generation_prefix_size if prefix_size is None else prefix_size
        max_len_a = cfg.generation_max_len_a if max_len_a is None else max_len_a
        max_len_b = cfg.generation_max_len_b if max_len_b is None else max_len_b
        top_k = cfg.generation_top_k if top_k is None else top_k
        top_p = cfg.generation_top_p if top_p is None else top_p

        if beam and beam > 1:
            return self._beam_generate(
                src_tokens, beam, temperature, prefix_size, max_len_a, max_len_b, stop_on_eos
            )

        device = next(self.parameters()).device
        src_tokens = src_tokens.to(device)
        B, T0 = src_tokens.shape

        tgt_len = self._target_len(T0, max_len_a, max_len_b)

        seq = src_tokens.clone()
        logits_steps = []

        for step in range(tgt_len):
            out = self.forward(seq)
            logits = out.logits[:, -1, :]  # next-token distribution per row
            if sampling:
                next_id = self._sample_token(logits, temperature, top_k=top_k, top_p=top_p).unsqueeze(1)
            else:
                next_id = torch.argmax(logits, dim=-1, keepdim=True)
            logits_steps.append(logits.unsqueeze(1))
            seq = torch.cat([seq, next_id], dim=1)

            # Stop only when every row produced EOS on this step.
            if stop_on_eos and (next_id == self.cfg.eos_token_id).all():
                break

        continuations = seq[:, T0:]
        all_logits = torch.cat(logits_steps, dim=1) if logits_steps else torch.empty(B, 0, self.cfg.vocab_size, device=device)

        if return_full_sequence:
            return seq, all_logits
        return continuations, all_logits

    @torch.no_grad()
    def _beam_generate(self, src_tokens, beam, temperature, prefix_size, max_len_a, max_len_b, stop_on_eos):
        """Simple per-sentence beam search (sum of log-probs, no length norm).

        Returns (continuations, empty_tensor): continuations are the best
        hypothesis per batch row, right-padded with pad_token_id to a common
        length. NOTE(review): ``temperature`` and ``prefix_size`` are
        accepted but unused here — confirm intended semantics.
        """
        device = next(self.parameters()).device
        src_tokens = src_tokens.to(device)
        B, T0 = src_tokens.shape
        tgt_len = self._target_len(T0, max_len_a, max_len_b)

        # Per batch row: list of (cumulative log-prob, (1, T) sequence).
        sequences = [[(0.0, src_tokens[b:b+1])] for b in range(B)]
        finished = [[] for _ in range(B)]

        for _ in range(tgt_len):
            new_sequences = [[] for _ in range(B)]
            for b in range(B):
                cand = sequences[b]
                all_exp = []
                for score, seq in cand:
                    out = self.forward(seq)
                    logprobs = F.log_softmax(out.logits[:, -1, :], dim=-1)
                    top_scores, top_ids = torch.topk(logprobs, k=min(beam, logprobs.size(-1)), dim=-1)
                    for s, i in zip(top_scores[0].tolist(), top_ids[0].tolist()):
                        new_seq = torch.cat([seq, torch.tensor([[i]], device=device, dtype=torch.long)], dim=1)
                        all_exp.append((score + s, new_seq))
                # Keep the best `beam` expansions.
                all_exp.sort(key=lambda x: x[0], reverse=True)
                sequences[b] = all_exp[:beam]
                # Move hypotheses that just emitted EOS to the finished pool.
                remain = []
                for sc, sq in sequences[b]:
                    if stop_on_eos and sq[0, -1].item() == self.cfg.eos_token_id:
                        finished[b].append((sc, sq))
                    else:
                        remain.append((sc, sq))
                if remain:
                    sequences[b] = remain
                else:
                    # All beams finished: keep finished hypotheses alive so the
                    # final selection below has something to choose from.
                    sequences[b] = finished[b][:beam] if finished[b] else sequences[b]

        outs = []
        for b in range(B):
            pool = finished[b] if finished[b] else sequences[b]
            pool.sort(key=lambda x: x[0], reverse=True)
            best = pool[0][1]
            outs.append(best[:, T0:])
        maxlen = max(x.size(1) for x in outs)
        # BUGFIX: F.pad pads the LAST dimension first, so the time dimension of
        # a (1, T) tensor needs a 2-element pad tuple. The previous 4-element
        # tuple (0, 0, 0, k) padded the batch dimension instead, producing
        # mismatched shapes for the cat below.
        outs = [F.pad(x, (0, maxlen - x.size(1)), value=self.cfg.pad_token_id) for x in outs]
        return torch.cat(outs, dim=0), torch.empty(0)

    @torch.no_grad()
    def sample(
        self, sentences: List[str] | str, beam: int = 1, verbose: bool = False, **kwargs
    ):
        """Generate one continuation per input; str in -> str out, list in -> list out."""
        hypos = self.sample_top_hypotheses(sentences, beam=beam, verbose=verbose, **kwargs)
        if isinstance(sentences, str):
            return hypos[0]
        return [h[0] for h in hypos]

    @torch.no_grad()
    def sample_top_hypotheses(
        self, sentences: List[str] | str, beam: int = 1, verbose: bool = False, **kwargs
    ) -> List[List[str]]:
        """Encode unit strings, generate continuations, and decode back.

        Returns one list of hypothesis strings per input sentence (currently
        a single hypothesis each). A single string input is normalized to a
        one-element batch. NOTE(review): shorter prompts are right-padded
        with pad_token_id and attended to without a mask — confirm.
        """
        if isinstance(sentences, str):
            return self.sample_top_hypotheses([sentences], beam=beam, verbose=verbose, **kwargs)

        encoded = [self.encode(s) for s in sentences]
        max_len = max(e.size(0) for e in encoded)
        pad_id = self.cfg.pad_token_id
        src = torch.stack([F.pad(e, (0, max_len - e.size(0)), value=pad_id) for e in encoded], dim=0).to(self.device)

        # Resolve every generation knob, falling back to config defaults.
        kwargs = dict(
            temperature=kwargs.get("temperature", self.cfg.generation_temperature),
            sampling=kwargs.get("sampling", self.cfg.generation_sampling),
            beam=beam,
            prefix_size=kwargs.get("prefix_size", self.cfg.generation_prefix_size),
            max_len_a=kwargs.get("max_len_a", self.cfg.generation_max_len_a),
            max_len_b=kwargs.get("max_len_b", self.cfg.generation_max_len_b),
            top_k=kwargs.get("top_k", self.cfg.generation_top_k),
            top_p=kwargs.get("top_p", self.cfg.generation_top_p),
            seed=kwargs.get("seed", None),
            stop_on_eos=True,
        )
        cont, _ = self.rollout(src, **kwargs)

        outs: List[List[str]] = []
        for b in range(src.size(0)):
            full = torch.cat([src[b], cont[b]], dim=0)
            outs.append([self._post_process_prediction(full)])
        return outs

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, *model_args, **kwargs):
        """Load the model, then attach the unit dictionary from the repo dir.

        NOTE(review): the dictionary path is built with os.path.join on the
        model id — this assumes a LOCAL checkout; a pure hub model id would
        yield an invalid path. Confirm intended usage.
        """
        model = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        repo_root = os.fspath(pretrained_model_name_or_path)
        dict_path = os.path.join(repo_root, model.config.dict_file)
        model.dictionary = UnitDictionary.from_file(dict_path)
        return model
|
|