diff --git "a/x_transformer_2_3_1.py" "b/x_transformer_2_3_1.py" --- "a/x_transformer_2_3_1.py" +++ "b/x_transformer_2_3_1.py" @@ -4,7 +4,7 @@ # # Partial x-transformers code With useful modifications as a stand-alone Python module # -# Version 1.0 +# Version 10.0 # # Original source code courtesy of lucidrains # https://github.com/lucidrains/x-transformers @@ -13,7 +13,7 @@ # Original version 2.3.1 / Commit 458bc12 # # Project Los Angeles -# Tegridy Code 2025 +# Tegridy Code 2026 # #=================================================================================================================== # @@ -22,6 +22,9 @@ # !pip install torch # !pip install einops # !pip install einx +# !pip install numpy +# !pip install scikit-learn +# !pip install matplotlib # #=================================================================================================================== @@ -45,6 +48,7 @@ import torch from torch.nn import Module from torch import nn, einsum, Tensor import torch.nn.functional as F +from torch.utils.data import Dataset, DataLoader from collections import namedtuple from functools import wraps @@ -3982,7 +3986,7 @@ class AutoregressiveWrapper(Module): # whether to add router z-loss self.add_attn_z_loss = add_attn_z_loss - @torch.no_grad() + @torch.inference_mode() @eval_decorator def generate( self, @@ -4146,63 +4150,3955 @@ class AutoregressiveWrapper(Module): out, = unpack(out, ps, '* n') return out - - def compute_accuracy(self, logits, labels): - - out = torch.argmax(logits, dim=-1) - out = out.flatten() - labels = labels.flatten() - mask = (labels != self.ignore_index) # can also be self.pad_value (your choice) - out = out[mask] - labels = labels[mask] + @torch.inference_mode() + @eval_decorator + def generate_masked( + self, + prompts, + seq_len, + eos_token = None, + temperature = 1., + prompt_lens: Tensor | None = None, + filter_logits_fn: str | Callable = top_k, + restrict_to_max_seq_len = True, + amateur_model: Module | Tuple[Module] | None = None, + filter_kwargs: dict = dict(), + contrastive_decode_kwargs: dict | Tuple[dict] = dict( + beta = 0.5, + alpha = 0.1 + ), + cache_kv = True, + return_prime=False, + verbose=True, + masked_token_ids: list[int] | Tensor | None = None, + **kwargs + ): + max_seq_len, greedy, device = self.max_seq_len, temperature == 0., prompts.device - num_right = (out == labels) - num_right = torch.sum(num_right).type(torch.float32) + prompts, ps = pack([prompts], '* n') - acc = num_right / len(labels) - - return acc + b, t = prompts.shape - def forward(self, x, return_outputs = False, **kwargs): - seq, ignore_index, add_attn_z_loss = x.shape[1], self.ignore_index, self.add_attn_z_loss + # handle filter logits fn given as string + if isinstance(filter_logits_fn, str): + assert filter_logits_fn in FILTER_LOGITS_FN, f"only {join(FILTER_LOGITS_FN.keys())} are available" + filter_logits_fn = FILTER_LOGITS_FN[filter_logits_fn] - inp, target = x[:, :-1], x[:, 1:] - inp = torch.where(inp == ignore_index, self.pad_value, inp) + # prepare masked token ids tensor (if any) + if masked_token_ids is not None: + if not torch.is_tensor(masked_token_ids): + masked_token_ids = torch.tensor(masked_token_ids, dtype=torch.long, device=device) + else: + masked_token_ids = masked_token_ids.to(device=device, dtype=torch.long) + # keep unique and non-negative + masked_token_ids = torch.unique(masked_token_ids) + # remove any ids that are out of range (optional safety) + # we can't know vocab size here, so we only remove negative ids + masked_token_ids = masked_token_ids[masked_token_ids >= 0] + else: + masked_token_ids = None - if self.mask_prob > 0.: - rand = torch.randn(inp.shape, device = x.device) - rand[:, 0] = -torch.finfo(rand.dtype).max # first token should not be masked out - num_mask = min(int(seq * self.mask_prob), seq - 1) - indices = rand.topk(num_mask, dim = -1).indices - mask = ~torch.zeros_like(inp).scatter(1, indices, 1.).bool() - kwargs.update(self_attn_kv_mask = mask) + # handle variable lengthed prompts (prefixes) + seq_start_pos = None + if exists(prompt_lens): + prompts = align_right(prompts, prompt_lens, pad_id = self.pad_value) + seq_start_pos = t - prompt_lens - logits, cache = self.net( - inp, - return_intermediates = True, - return_attn_z_loss = add_attn_z_loss, - **kwargs - ) - - acc = self.compute_accuracy(logits, target) + # output from which sampled tokens appended to + out = prompts - loss_fn = F.cross_entropy if not self.net.output_is_log_prob else F.nll_loss + if verbose: + print("Generating sequence of max length:", seq_len) - loss = loss_fn( - rearrange(logits, 'b n c -> b c n'), - target, - ignore_index = ignore_index - ) + # kv caches + cache = None - if add_attn_z_loss: - loss = loss + cache.attn_z_loss + # if doing contrastive decoding, turn off filter automatically + if exists(amateur_model): + amateur_model = cast_tuple(amateur_model) + contrastive_decode_kwargs = cast_tuple(contrastive_decode_kwargs) - if not return_outputs: - return loss, acc + assert len(amateur_model) == len(contrastive_decode_kwargs) - return loss, acc, logits, cache + amateur_caches = [None] * len(amateur_model) + filter_logits_fn = identity + + for i, module in enumerate(amateur_model): + if isinstance(module, AutoregressiveWrapper): + amateur_model[i] = module.net + + module.eval() + + # sampling up to seq_len + for sl in range(seq_len): + + if restrict_to_max_seq_len: + max_len_exceeded = out.shape[-1] > max_seq_len + + assert not (cache_kv and max_len_exceeded and not self.net.can_cache_kv_outside_max_seq_len), 'the network cannot use cached key values when decoding outside the max sequence length. most likely because you are using absolute positional embedding. you can switch to rotary embeddings to resolve this issue' + + x = out[:, -max_seq_len:] + + if exists(cache): + for inter in cache.attn_intermediates: + if inter.layer_type == 'a': + inter.cached_kv = [t[..., -(max_seq_len - 1):, :] for t in inter.cached_kv] + + logits, new_cache = self.net( + x, + return_intermediates = True, + cache = cache, + seq_start_pos = seq_start_pos, + **kwargs + ) + + if cache_kv and self.net.can_cache_kv: + cache = new_cache + + logits = logits[:, -1] + + # handle contrastive decoding, Li et al. + # https://arxiv.org/abs/2210.15097 + if exists(amateur_model): + for i, (amateur, amateur_cache, amateur_contrastive_decode_kwargs) in enumerate(zip(amateur_model, amateur_caches, contrastive_decode_kwargs)): + amateur_logits, next_amateur_cache = amateur( + x, + return_intermediates = True, + cache = amateur_cache, + seq_start_pos = seq_start_pos, + **kwargs + ) + + amateur_logits = amateur_logits[:, -1] + + assert amateur_logits.shape == logits.shape, 'logits dimension are not the same between amateur and expert model' + logits = contrastive_decode_fn(logits, amateur_logits, **amateur_contrastive_decode_kwargs) + + if cache_kv and amateur.can_cache_kv: + amateur_caches[i] = next_amateur_cache + + # --- apply masked token ids here (after contrastive decoding, before filtering/sampling) + if masked_token_ids is not None and masked_token_ids.numel() > 0: + # safety: ensure indices are within logits' vocab dimension + vocab_size = logits.shape[-1] + valid_masked = masked_token_ids[masked_token_ids < vocab_size] + if valid_masked.numel() > 0: + # set logits for masked ids to a very large negative value + neg_inf = -1e9 + # logits shape: (batch, vocab) + logits[:, valid_masked] = neg_inf + + # filter by top_k, top_p (nucleus), top_a, or custom + if greedy: + sample = logits.argmax(dim = -1, keepdim = True) + else: + filtered_logits = filter_logits_fn(logits, **filter_kwargs) + probs = F.softmax(filtered_logits / temperature, dim=-1) + sample = torch.multinomial(probs, 1) + + # concat sample + out = torch.cat((out, sample), dim=-1) + + if verbose: + if sl % 32 == 0: + print(sl, '/', seq_len) + + if not exists(eos_token): + continue + + is_eos_tokens = (out == eos_token) + + if is_eos_tokens.any(dim = -1).all(): + + if verbose: + print('Model called the end of sequence at:', sl, '/', seq_len) + + break + + if exists(eos_token): + # mask out everything after the eos tokens + shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1)) + mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1 + out = out.masked_fill(mask, self.pad_value) + + if return_prime: + out = out[:, :] + + else: + out = out[:, t:] + + out, = unpack(out, ps, '* n') + + return out + @torch.inference_mode() + @eval_decorator + def generate_biased( + self, + prompts, + seq_len, + eos_token = None, + temperature = 1., + prompt_lens: Tensor | None = None, + filter_logits_fn: str | Callable = top_k, + restrict_to_max_seq_len = True, + amateur_model: Module | Tuple[Module] | None = None, + filter_kwargs: dict = dict(), + contrastive_decode_kwargs: dict | Tuple[dict] = dict( + beta = 0.5, + alpha = 0.1 + ), + cache_kv = True, + return_prime=False, + verbose=True, + logit_bias: dict | Tensor | None = None, # <-- new parameter + **kwargs + ): + """ + Autoregressive generation with optional additive logit bias. + + logit_bias: + - dict[token_id -> float] OR + - torch.Tensor of shape (vocab,) OR (batch, vocab) + """ + + max_seq_len, greedy, device = self.max_seq_len, temperature == 0., prompts.device + + prompts, ps = pack([prompts], '* n') + + b, t = prompts.shape + + # handle filter logits fn given as string + if isinstance(filter_logits_fn, str): + assert filter_logits_fn in FILTER_LOGITS_FN, f"only {join(FILTER_LOGITS_FN.keys())} are available" + filter_logits_fn = FILTER_LOGITS_FN[filter_logits_fn] + + # handle variable lengthed prompts (prefixes) + seq_start_pos = None + if exists(prompt_lens): + prompts = align_right(prompts, prompt_lens, pad_id = self.pad_value) + seq_start_pos = t - prompt_lens + + # output from which sampled tokens appended to + out = prompts + + if verbose: + print("Generating sequence of max length:", seq_len) + + # kv caches + cache = None + + # if doing contrastive decoding, turn off filter automatically + if exists(amateur_model): + amateur_model = cast_tuple(amateur_model) + contrastive_decode_kwargs = cast_tuple(contrastive_decode_kwargs) + assert len(amateur_model) == len(contrastive_decode_kwargs) + amateur_caches = [None] * len(amateur_model) + filter_logits_fn = identity + for i, module in enumerate(amateur_model): + if isinstance(module, AutoregressiveWrapper): + amateur_model[i] = module.net + module.eval() + + # ------------------------- + # Prepare logit_bias (robust vocab-size detection) + # ------------------------- + prepared_bias = None + lazy_build_bias_from_dict = None + + if exists(logit_bias): + if isinstance(logit_bias, dict): + # try to determine vocab size from model without using logits + vocab_size = None + + # common places to find vocab size + try: + if hasattr(self.net, "config") and getattr(self.net.config, "vocab_size", None) is not None: + vocab_size = int(self.net.config.vocab_size) + elif getattr(self.net, "vocab_size", None) is not None: + vocab_size = int(self.net.vocab_size) + else: + # try to infer from embedding / output projection weights + # huggingface style: get_output_embeddings() or embed_tokens or lm_head + get_out = getattr(self.net, "get_output_embeddings", None) + if callable(get_out) and get_out() is not None: + vocab_size = int(get_out().weight.shape[0]) + elif hasattr(self.net, "embed_tokens"): + vocab_size = int(self.net.embed_tokens.weight.shape[0]) + elif hasattr(self.net, "lm_head"): + vocab_size = int(self.net.lm_head.weight.shape[0]) + except Exception: + vocab_size = None + + if vocab_size is not None: + bias_vec = torch.zeros(int(vocab_size), device=device, dtype=torch.float32) + for tok, val in logit_bias.items(): + tok_i = int(tok) + if tok_i < 0 or tok_i >= vocab_size: + raise IndexError(f"logit_bias token id {tok_i} out of range for vocab size {vocab_size}") + bias_vec[tok_i] = float(val) + prepared_bias = bias_vec + else: + # can't determine vocab size yet — build lazily after first logits are available + lazy_build_bias_from_dict = {int(k): float(v) for k, v in logit_bias.items()} + + elif isinstance(logit_bias, torch.Tensor): + prepared_bias = logit_bias.to(device=device, dtype=torch.float32) + else: + raise TypeError("logit_bias must be dict or torch.Tensor") + + # sampling up to seq_len + for sl in range(seq_len): + + if restrict_to_max_seq_len: + max_len_exceeded = out.shape[-1] > max_seq_len + assert not (cache_kv and max_len_exceeded and not self.net.can_cache_kv_outside_max_seq_len), \ + 'the network cannot use cached key values when decoding outside the max sequence length. most likely because you are using absolute positional embedding. you can switch to rotary embeddings to resolve this issue' + x = out[:, -max_seq_len:] + if exists(cache): + for inter in cache.attn_intermediates: + if inter.layer_type == 'a': + inter.cached_kv = [t[..., -(max_seq_len - 1):, :] for t in inter.cached_kv] + else: + x = out + + logits, new_cache = self.net( + x, + return_intermediates = True, + cache = cache, + seq_start_pos = seq_start_pos, + **kwargs + ) + + if cache_kv and self.net.can_cache_kv: + cache = new_cache + + logits = logits[:, -1] # shape (batch, vocab) + + # If we couldn't build the bias earlier because vocab size was unknown, + # build it now from the first logits tensor. + if lazy_build_bias_from_dict is not None: + vocab_size = logits.shape[-1] + bias_vec = torch.zeros(vocab_size, device=device, dtype=torch.float32) + for tok, val in lazy_build_bias_from_dict.items(): + if tok < 0 or tok >= vocab_size: + raise IndexError(f"logit_bias token id {tok} out of range for vocab size {vocab_size}") + bias_vec[tok] = val + prepared_bias = bias_vec + lazy_build_bias_from_dict = None # only build once + + # handle contrastive decoding, Li et al. + # https://arxiv.org/abs/2210.15097 + if exists(amateur_model): + for i, (amateur, amateur_cache, amateur_contrastive_decode_kwargs) in enumerate(zip(amateur_model, amateur_caches, contrastive_decode_kwargs)): + amateur_logits, next_amateur_cache = amateur( + x, + return_intermediates = True, + cache = amateur_cache, + seq_start_pos = seq_start_pos, + **kwargs + ) + amateur_logits = amateur_logits[:, -1] + assert amateur_logits.shape == logits.shape, 'logits dimension are not the same between amateur and expert model' + logits = contrastive_decode_fn(logits, amateur_logits, **amateur_contrastive_decode_kwargs) + if cache_kv and amateur.can_cache_kv: + amateur_caches[i] = next_amateur_cache + + # ------------------------- + # Apply logit bias if provided + # ------------------------- + if exists(prepared_bias): + # prepared_bias can be (vocab,) or (batch, vocab) + if prepared_bias.dim() == 1: + # broadcast to batch + logits = logits + prepared_bias.unsqueeze(0) + elif prepared_bias.dim() == 2: + # expect shape (batch, vocab) + if prepared_bias.shape[0] != logits.shape[0]: + raise ValueError("logit_bias tensor batch size must match logits batch size") + logits = logits + prepared_bias + else: + raise ValueError("logit_bias tensor must be 1D (vocab,) or 2D (batch, vocab)") + + # filter by top_k, top_p (nucleus), top_a, or custom + if greedy: + sample = logits.argmax(dim = -1, keepdim = True) + else: + filtered_logits = filter_logits_fn(logits, **filter_kwargs) + probs = F.softmax(filtered_logits / temperature, dim=-1) + sample = torch.multinomial(probs, 1) + + # concat sample + out = torch.cat((out, sample), dim=-1) + + if verbose: + if sl % 32 == 0: + print(sl, '/', seq_len) + + if not exists(eos_token): + continue + + is_eos_tokens = (out == eos_token) + + if is_eos_tokens.any(dim = -1).all(): + if verbose: + print('Model called the end of sequence at:', sl, '/', seq_len) + break + + if exists(eos_token): + # mask out everything after the eos tokens + shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1)) + mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1 + out = out.masked_fill(mask, self.pad_value) + + if return_prime: + out = out[:, :] + else: + out = out[:, t:] + + out, = unpack(out, ps, '* n') + + return out + + @torch.inference_mode() + @eval_decorator + def generate_advanced( + self, + prompts, + seq_len, + eos_token = None, + temperature = 1., + prompt_lens: Tensor | None = None, + filter_logits_fn: str | Callable = top_k, + restrict_to_max_seq_len = True, + amateur_model: Module | Tuple[Module] | None = None, + filter_kwargs: dict = dict(), + contrastive_decode_kwargs: dict | Tuple[dict] = dict( + beta = 0.5, + alpha = 0.1 + ), + cache_kv = True, + return_prime=False, + verbose=True, + # --- new generation options --- + logits_bias: dict | None = None, # {token_id: bias_value} where bias_value is float or Tensor(batch,) + masked_tokens: list | Tensor | None = None, # list of token ids to forbid + # --- binary classifier mode --- + binary_classifier: bool = False, # if True, run classifier snippet and return preds, probs + classifier_model: Module | None = None, # model to use for binary classification + batches: list | None = None, # iterable of input batches for classifier_model + threshold: float = 0.5, # threshold for converting probs to preds + classifier_device: torch.device | None = None, + # ----------------- + **kwargs + ): + # If binary classifier mode requested, run the provided snippet and return early. + if binary_classifier: + assert classifier_model is not None, "classifier_model must be provided when binary_classifier=True" + assert batches is not None, "batches (iterable of input tensors) must be provided when binary_classifier=True" + + device = classifier_device if classifier_device is not None else (prompts.device if exists(prompts) else torch.device('cpu')) + + all_probs = [] + all_preds = [] + + classifier_model.eval() + with torch.no_grad(): + for x in batches: + x = x.to(device) + logits = classifier_model(x).squeeze() # [B] + probs = torch.sigmoid(logits) # [B] + preds = (probs >= threshold).long() + + all_probs.extend(probs.cpu().tolist()) + all_preds.extend(preds.cpu().tolist()) + + return all_preds, all_probs + + # --- normal generation path below --- + max_seq_len, greedy, device = self.max_seq_len, temperature == 0., prompts.device + + prompts, ps = pack([prompts], '* n') + + b, t = prompts.shape + + # handle filter logits fn given as string + if isinstance(filter_logits_fn, str): + assert filter_logits_fn in FILTER_LOGITS_FN, f"only {join(FILTER_LOGITS_FN.keys())} are available" + filter_logits_fn = FILTER_LOGITS_FN[filter_logits_fn] + + # handle variable lengthed prompts (prefixes) + seq_start_pos = None + if exists(prompt_lens): + prompts = align_right(prompts, prompt_lens, pad_id = self.pad_value) + seq_start_pos = t - prompt_lens + + # output from which sampled tokens appended to + out = prompts + + if verbose: + print("Generating sequence of max length:", seq_len) + + # kv caches + cache = None + + # if doing contrastive decoding, turn off filter automatically + if exists(amateur_model): + amateur_model = cast_tuple(amateur_model) + contrastive_decode_kwargs = cast_tuple(contrastive_decode_kwargs) + + assert len(amateur_model) == len(contrastive_decode_kwargs) + + amateur_caches = [None] * len(amateur_model) + filter_logits_fn = identity + + for i, module in enumerate(amateur_model): + if isinstance(module, AutoregressiveWrapper): + amateur_model[i] = module.net + + module.eval() + + # normalize inputs for new args + if exists(logits_bias): + assert isinstance(logits_bias, dict), "logits_bias must be a dict {token_id: bias_value}" + if exists(masked_tokens): + if isinstance(masked_tokens, torch.Tensor): + masked_tokens = masked_tokens.tolist() + else: + masked_tokens = list(masked_tokens) + + # sampling up to seq_len + for sl in range(seq_len): + + if restrict_to_max_seq_len: + max_len_exceeded = out.shape[-1] > max_seq_len + + assert not (cache_kv and max_len_exceeded and not self.net.can_cache_kv_outside_max_seq_len), 'the network cannot use cached key values when decoding outside the max sequence length. most likely because you are using absolute positional embedding. you can switch to rotary embeddings to resolve this issue' + + x = out[:, -max_seq_len:] + + if exists(cache): + for inter in cache.attn_intermediates: + if inter.layer_type == 'a': + inter.cached_kv = [t[..., -(max_seq_len - 1):, :] for t in inter.cached_kv] + + logits, new_cache = self.net( + x, + return_intermediates = True, + cache = cache, + seq_start_pos = seq_start_pos, + **kwargs + ) + + if cache_kv and self.net.can_cache_kv: + cache = new_cache + + logits = logits[:, -1] # shape: (batch, vocab) + + # handle contrastive decoding, Li et al. + if exists(amateur_model): + for i, (amateur, amateur_cache, amateur_contrastive_decode_kwargs) in enumerate(zip(amateur_model, amateur_caches, contrastive_decode_kwargs)): + amateur_logits, next_amateur_cache = amateur( + x, + return_intermediates = True, + cache = amateur_cache, + seq_start_pos = seq_start_pos, + **kwargs + ) + + amateur_logits = amateur_logits[:, -1] + + assert amateur_logits.shape == logits.shape, 'logits dimension are not the same between amateur and expert model' + logits = contrastive_decode_fn(logits, amateur_logits, **amateur_contrastive_decode_kwargs) + + if cache_kv and amateur.can_cache_kv: + amateur_caches[i] = next_amateur_cache + + # --- APPLY LOGITS BIAS AND MASKING HERE (before filtering / softmax) --- + # logits_bias: dict {token_id: bias_value} where bias_value is float or Tensor(batch,) + if exists(logits_bias): + # apply per-token bias updates directly to logits to avoid allocating full vocab bias tensor + for tok_id, bias_val in logits_bias.items(): + # support scalar or per-batch tensor + if isinstance(bias_val, torch.Tensor): + if bias_val.dim() == 1 and bias_val.shape[0] == b: + bias_to_add = bias_val.to(device) + else: + bias_to_add = bias_val.to(device).view(1).expand(b) + else: + bias_to_add = torch.tensor(float(bias_val), device=device).view(1).expand(b) + + logits[:, int(tok_id)] = logits[:, int(tok_id)] + bias_to_add + + # masked_tokens: list of token ids to forbid + if exists(masked_tokens) and len(masked_tokens) > 0: + NEG_INF = -1e9 + idx = torch.tensor(masked_tokens, device=device, dtype=torch.long) + idx = idx[(idx >= 0) & (idx < logits.shape[-1])] + if idx.numel() > 0: + logits.index_fill_(dim=-1, index=idx, value=NEG_INF) + # ------------------------------------------------------------------- + + # filter by top_k, top_p (nucleus), top_a, or custom + if greedy: + sample = logits.argmax(dim = -1, keepdim = True) + else: + filtered_logits = filter_logits_fn(logits, **filter_kwargs) + probs = F.softmax(filtered_logits / temperature, dim=-1) + sample = torch.multinomial(probs, 1) + + # concat sample + out = torch.cat((out, sample), dim=-1) + + if verbose: + if sl % 32 == 0: + print(sl, '/', seq_len) + + if not exists(eos_token): + continue + + is_eos_tokens = (out == eos_token) + + if is_eos_tokens.any(dim = -1).all(): + if verbose: + print('Model called the end of sequence at:', sl, '/', seq_len) + break + + if exists(eos_token): + # mask out everything after the eos tokens + shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1)) + mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1 + out = out.masked_fill(mask, self.pad_value) + + if return_prime: + out = out[:, :] + + else: + out = out[:, t:] + + out, = unpack(out, ps, '* n') + + return out + + def compute_accuracy(self, logits, labels): + + out = torch.argmax(logits, dim=-1) + out = out.flatten() + labels = labels.flatten() + + mask = (labels != self.ignore_index) # can also be self.pad_value (your choice) + out = out[mask] + labels = labels[mask] + + num_right = (out == labels) + num_right = torch.sum(num_right).type(torch.float32) + + acc = num_right / len(labels) + + return acc + + def forward(self, x, return_outputs = False, **kwargs): + seq, ignore_index, add_attn_z_loss = x.shape[1], self.ignore_index, self.add_attn_z_loss + + inp, target = x[:, :-1], x[:, 1:] + inp = torch.where(inp == ignore_index, self.pad_value, inp) + + if self.mask_prob > 0.: + rand = torch.randn(inp.shape, device = x.device) + rand[:, 0] = -torch.finfo(rand.dtype).max # first token should not be masked out + num_mask = min(int(seq * self.mask_prob), seq - 1) + indices = rand.topk(num_mask, dim = -1).indices + mask = ~torch.zeros_like(inp).scatter(1, indices, 1.).bool() + kwargs.update(self_attn_kv_mask = mask) + + logits, cache = self.net( + inp, + return_intermediates = True, + return_attn_z_loss = add_attn_z_loss, + **kwargs + ) + + acc = self.compute_accuracy(logits, target) + + loss_fn = F.cross_entropy if not self.net.output_is_log_prob else F.nll_loss + + loss = loss_fn( + rearrange(logits, 'b n c -> b c n'), + target, + ignore_index = ignore_index + ) + + if add_attn_z_loss: + loss = loss + cache.attn_z_loss + + if not return_outputs: + return loss, acc + + return loss, acc, logits, cache + + @torch.inference_mode() + @eval_decorator + def generate_expert( + self, + prompts, + seq_len, + eos_token = None, + temperature = 1., + prompt_lens: Tensor | None = None, + filter_logits_fn: str | Callable = top_k, + restrict_to_max_seq_len = True, + amateur_model: Module | Tuple[Module] | None = None, + filter_kwargs: dict = dict(), + contrastive_decode_kwargs: dict | Tuple[dict] = dict( + beta = 0.5, + alpha = 0.1 + ), + cache_kv = True, + return_prime=False, + verbose=True, + # --- new controls --- + token_type_ids: torch.LongTensor | None = None, # [vocab] + type_temperatures: dict | None = None, # {type_id: temp} + type_biases: dict | None = None, # {type_id: bias} + repetition_window: int = 64, + repetition_penalty_per_type: dict | None = None, # {type_id: penalty_scale} + rare_types: set | None = None, # e.g. {4, 5} + rare_type_boost: float = 0.0, # small, e.g. 0.5 + entropy_threshold: float = 2.0, # when below, boost rare types + # --- masked tokens option --- + forbidden_token_ids: torch.LongTensor | torch.BoolTensor | None = None, + forbidden_value: float = -1e9, + **kwargs + ): + max_seq_len, greedy, device = self.max_seq_len, temperature == 0., prompts.device + + prompts, ps = pack([prompts], '* n') + + b, t = prompts.shape + + # handle filter logits fn given as string + + if isinstance(filter_logits_fn, str): + assert filter_logits_fn in FILTER_LOGITS_FN, f"only {join(FILTER_LOGITS_FN.keys())} are available" + filter_logits_fn = FILTER_LOGITS_FN[filter_logits_fn] + + # handle variable lengthed prompts (prefixes) + + seq_start_pos = None + if exists(prompt_lens): + prompts = align_right(prompts, prompt_lens, pad_id = self.pad_value) + seq_start_pos = t - prompt_lens + + # output from which sampled tokens appended to + + out = prompts + + if verbose: + print("Generating sequence of max length:", seq_len) + + # kv caches + + cache = None + + # if doing contrastive decoding, turn off filter automatically + + if exists(amateur_model): + amateur_model = cast_tuple(amateur_model) + contrastive_decode_kwargs = cast_tuple(contrastive_decode_kwargs) + + assert len(amateur_model) == len(contrastive_decode_kwargs) + + amateur_caches = [None] * len(amateur_model) + filter_logits_fn = identity + + for i, module in enumerate(amateur_model): + if isinstance(module, AutoregressiveWrapper): + amateur_model[i] = module.net + + module.eval() + + # precompute some tensors for type controls + + if token_type_ids is not None: + token_type_ids = token_type_ids.to(device) + + # build per-token temperature and bias vectors if provided + per_token_temp = None + if type_temperatures is not None and len(type_temperatures) > 0: + per_token_temp = torch.ones_like(token_type_ids, dtype=torch.float32) + for type_id, temp_val in type_temperatures.items(): + per_token_temp[token_type_ids == type_id] = float(temp_val) + + per_token_bias = None + if type_biases is not None and len(type_biases) > 0: + per_token_bias = torch.zeros_like(token_type_ids, dtype=torch.float32) + for type_id, bias_val in type_biases.items(): + per_token_bias[token_type_ids == type_id] = float(bias_val) + + # repetition penalty per type + per_type_rep_penalty = repetition_penalty_per_type or {} + + # rare type mask + rare_type_mask = None + if rare_types is not None and len(rare_types) > 0: + rare_type_mask = torch.zeros_like(token_type_ids, dtype=torch.bool) + for rt in rare_types: + rare_type_mask |= (token_type_ids == rt) + else: + per_token_temp = None + per_token_bias = None + per_type_rep_penalty = {} + rare_type_mask = None + + # prepare forbidden mask if provided + # We'll lazily convert forbidden_token_ids into a boolean mask of shape [b, vocab] + forbidden_mask_per_batch = None + if forbidden_token_ids is not None: + # If it's a LongTensor of ids (1D) + if forbidden_token_ids.dtype in (torch.int64, torch.int32): + # create a [vocab] bool mask from ids + vocab_size = self.net.config.vocab_size if hasattr(self.net, 'config') else None + # If we can't infer vocab_size, we'll infer from token_type_ids if available + if vocab_size is None and token_type_ids is not None: + vocab_size = token_type_ids.shape[0] + assert vocab_size is not None, "Cannot infer vocab size for forbidden_token_ids; provide a boolean mask instead." + mask = torch.zeros(vocab_size, dtype=torch.bool, device=device) + ids = forbidden_token_ids.to(device) + mask[ids.clamp(0, vocab_size-1)] = True + forbidden_mask_per_batch = mask.unsqueeze(0).expand(b, -1) # [b, vocab] + elif forbidden_token_ids.dtype == torch.bool: + # could be [vocab] or [b, vocab] + if forbidden_token_ids.dim() == 1: + forbidden_mask_per_batch = forbidden_token_ids.to(device).unsqueeze(0).expand(b, -1) + elif forbidden_token_ids.dim() == 2: + assert forbidden_token_ids.shape[0] == b, "forbidden_token_ids batch dimension must match prompts batch size" + forbidden_mask_per_batch = forbidden_token_ids.to(device) + else: + raise ValueError("forbidden_token_ids boolean mask must be 1D [vocab] or 2D [b, vocab]") + else: + raise TypeError("forbidden_token_ids must be LongTensor of ids or BoolTensor mask") + + # sampling up to seq_len + + for sl in range(seq_len): + + if restrict_to_max_seq_len: + max_len_exceeded = out.shape[-1] > max_seq_len + + assert not (cache_kv and max_len_exceeded and not self.net.can_cache_kv_outside_max_seq_len), \ + 'the network cannot use cached key values when decoding outside the max sequence length. ' \ + 'most likely because you are using absolute positional embedding. ' \ + 'you can switch to rotary embeddings to resolve this issue' + + x = out[:, -max_seq_len:] + + if exists(cache): + for inter in cache.attn_intermediates: + if inter.layer_type == 'a': + inter.cached_kv = [t[..., -(max_seq_len - 1):, :] for t in inter.cached_kv] + + logits, new_cache = self.net( + x, + return_intermediates = True, + cache = cache, + seq_start_pos = seq_start_pos, + **kwargs + ) + + if cache_kv and self.net.can_cache_kv: + cache = new_cache + + logits = logits[:, -1] # [b, vocab] + + # handle contrastive decoding + + if exists(amateur_model): + for i, (amateur, amateur_cache, amateur_contrastive_decode_kwargs) in enumerate( + zip(amateur_model, amateur_caches, contrastive_decode_kwargs) + ): + amateur_logits, next_amateur_cache = amateur( + x, + return_intermediates = True, + cache = amateur_cache, + seq_start_pos = seq_start_pos, + **kwargs + ) + + amateur_logits = amateur_logits[:, -1] + + assert amateur_logits.shape == logits.shape, \ + 'logits dimension are not the same between amateur and expert model' + logits = contrastive_decode_fn(logits, amateur_logits, **amateur_contrastive_decode_kwargs) + + if cache_kv and amateur.can_cache_kv: + amateur_caches[i] = next_amateur_cache + + # --------- STRUCTURED LOGIT SHAPING (no training) --------- + + if token_type_ids is not None: + + # 1) per-token bias (type-aware) + if per_token_bias is not None: + logits = logits + per_token_bias # broadcast [vocab] + + # 2) repetition penalty per type (context-aware) + if repetition_window > 0 and len(per_type_rep_penalty) > 0: + # look at recent tokens + recent = out[:, -repetition_window:].to(device) # [b, w] + # map to types + recent_types = token_type_ids[recent] # [b, w] + + # for each type, compute frequency and apply penalty + # we do this per batch element + for bi in range(b): + types_b = recent_types[bi] # [w] + if types_b.numel() == 0: + continue + # count occurrences per type id present in penalties + for type_id, penalty_scale in per_type_rep_penalty.items(): + # penalty_scale > 1.0 means stronger penalty + mask = (types_b == type_id) + if mask.any(): + freq = mask.float().mean().item() # 0..1 + if freq > 0.0: + # build a penalty vector for this type + type_mask = (token_type_ids == type_id) # [vocab] + # subtract a penalty proportional to freq + # (log-space penalty) + logits[bi, type_mask] /= (1.0 + freq * (penalty_scale - 1.0)) + + # 3) entropy-based rare-type boost (gentle, context-aware) + if rare_type_mask is not None and rare_type_boost > 0.0: + # compute current probs & entropy (before global temperature) + probs_raw = F.softmax(logits, dim=-1) # [b, vocab] + log_probs_raw = torch.log(probs_raw + 1e-9) + entropy = -(probs_raw * log_probs_raw).sum(dim=-1) # [b] + + # for low-entropy states, gently boost rare types + low_entropy = entropy < entropy_threshold + if low_entropy.any(): + # boost only for those batch elements + boost_vec = torch.zeros_like(logits) + boost_vec[:, rare_type_mask] = rare_type_boost + logits = torch.where( + low_entropy.unsqueeze(-1), + logits + boost_vec, + logits + ) + + # 4) per-token temperature (type-aware) + # apply before global temperature + if per_token_temp is not None: + # divide logits by per-token temperature + # (smaller temp -> sharper distribution for that type) + logits = logits / per_token_temp + + # --------- APPLY FORBIDDEN TOKEN MASK --------- + if forbidden_mask_per_batch is not None: + # ensure shapes match + assert forbidden_mask_per_batch.shape[0] == b and forbidden_mask_per_batch.shape[1] == logits.shape[-1], \ + "forbidden mask shape must be [b, vocab]" + # set logits for forbidden tokens to a large negative value + logits = logits.masked_fill(forbidden_mask_per_batch, float(forbidden_value)) + + # ---------------------------------------------------------- + + # filter by top_k, top_p (nucleus), top_a, or custom + + if greedy: + sample = logits.argmax(dim = -1, keepdim = True) + else: + filtered_logits = filter_logits_fn(logits, **filter_kwargs) + probs = F.softmax(filtered_logits / temperature, dim=-1) + sample = torch.multinomial(probs, 1) + + # concat sample + + out = torch.cat((out, sample), dim=-1) + + if verbose: + if sl % 32 == 0: + print(sl, '/', seq_len) + + if not exists(eos_token): + continue + + is_eos_tokens = (out == eos_token) + + if is_eos_tokens.any(dim = -1).all(): + + if verbose: + print('Model called the end of sequence at:', sl, '/', seq_len) + + break + + if exists(eos_token): + # mask out everything after the eos tokens + shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1)) + mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1 + out = out.masked_fill(mask, self.pad_value) + + if return_prime: + out = out[:, :] + else: + out = out[:, t:] + + out, = unpack(out, ps, '* n') + + return out + +#================================================================================================================================= +# Binary classifier fuctions +# https://github.com/lucidrains/x-transformers/pull/264 +#================================================================================================================================= + +class ClsInferenceDataset(Dataset): + """ + Dataset for pairs (src_seq, label). + src_seq: list of token IDs (ints). + label: single int or float (0 or 1). + """ + def __init__(self, data_pairs): + self.data_pairs = data_pairs + + def __len__(self): + return len(self.data_pairs) + + def __getitem__(self, idx): + src_seq = self.data_pairs[idx] + x = torch.tensor(src_seq, dtype=torch.long) + return x + +def build_cls_model(num_tokens=18819, + max_seq_len=1024, + logits_dim=1, + use_cls_token=True, + squeeze_out_last_dim=True, + dim=1024, + depth=8, + heads=8, + device='cuda' + ): + + """ + Constructs the Transformer model that outputs a single logit per input. + """ + + model = TransformerWrapper( + num_tokens=num_tokens, + max_seq_len=max_seq_len, + logits_dim=logits_dim, + use_cls_token=use_cls_token, + squeeze_out_last_dim = squeeze_out_last_dim, + attn_layers=Encoder(dim=dim, + depth=depth, + heads=heads, + attn_flash=True, + rotary_pos_emb=True + ) + ) + + return model.to(device) + +def load_cls_model(checkpoint_path, device='cuda'): + + """ + Rebuilds the architecture, loads weights. + """ + + model = build_cls_model(device=device) + state = torch.load(checkpoint_path, map_location=device) + model.load_state_dict(state) + model.to(device).eval() + + return model + +def cls_predict(model, + seqs, + batch_size=8, + threshold=0.5, + seq_len=1024, + pad_token=18818, + device='cuda' + ): + + """ + Returns two lists: + - probs: float probabilities + - preds: int 0/1 predictions + """ + + def collate_fn(batch): + # batch: list of sequences (list/1D-tensor) + tensors = [s[:seq_len].detach().clone() for s in batch] + max_len = min(seq_len, max(t.size(0) for t in tensors)) + padded = torch.full((len(tensors), max_len), pad_token, dtype=torch.long) + for i, t in enumerate(tensors): + L = t.size(0) + padded[i, :L] = t + return padded + + ds = ClsInferenceDataset(seqs) + loader = DataLoader(ds, batch_size=batch_size, shuffle=False, collate_fn=collate_fn) + + all_probs = [] + all_preds = [] + + model.to(device) + model.eval() + + with torch.inference_mode(): + for x in loader: + + x = x.to(device) # [B, L] (truncated & padded) + + logits = model(x).squeeze() # [B] + + probs = torch.sigmoid(logits) # [B] + + preds = (probs >= threshold).long() + + probs = probs.cpu().tolist() + preds = preds.cpu().tolist() + + if type(preds) == list: + all_probs.extend(probs) + all_preds.extend(preds) + + else: + all_probs.append(probs) + all_preds.append(preds) + + return all_preds, all_probs + +#================================================================================================================================= +# Sequences probabilities and scores functions +#================================================================================================================================= + +import inspect +import math +from typing import Callable, Optional, Dict, Any, List, Tuple +import torch +import torch.nn.functional as F + +def print_probs_scoring_guide(): + print(inspect.getdoc(probs_scoring_guide)) + +def probs_scoring_guide(): + + """ + Return dictionary structure and metric descriptions for generate_with_probs / score_sequences. + + Returns + ------- + result : dict + A dictionary containing token-level and sequence-level scoring information. + + Keys + ---- + tokens : torch.Tensor + Tensor of token ids for each batch entry. Shape (batch, seq_len). + - Meaning: Generated tokens (for generate_with_probs) or the original + input sequences (for score_sequences). + - Interpretation: Map ids to text with your tokenizer to inspect outputs. + + token_probs : List[List[float]] + Per-batch lists of probabilities assigned to each chosen token at the time + it was produced. Values in [0, 1]. + - Meaning: Softmax probability for the selected token at each step. + - Interpretation: Higher → model more confident about that token. Do not + multiply many token_probs directly (underflow risk); use log-probs. + + token_logprobs : List[List[float]] + Per-batch lists of natural log probabilities (nats) for each chosen token: + log p(x_t | x_ Tuple[float, str]: + lp64 = torch.tensor(logp, dtype=torch.float64) + try: + p64 = float(torch.exp(lp64).item()) + except Exception: + p64 = 0.0 + if p64 == 0.0: + log10_prob = float(lp64.item() / math.log(10.0)) + display = f"~10^{log10_prob:.2f}" + else: + display = f"{p64:.6e}" + return p64, display + +def _attach_metrics_to_result(result: Dict[str, Any]) -> Dict[str, Any]: + seq_logprobs: List[float] = result.get("sequence_logprobs", []) + token_logprobs: List[List[float]] = result.get("token_logprobs", []) + token_probs: List[List[float]] = result.get("token_probs", []) + metrics = {"per_sequence": []} + for i, seq_lp in enumerate(seq_logprobs): + toks_lp = token_logprobs[i] if i < len(token_logprobs) else [] + token_count = len(toks_lp) + avg_lp = float(sum(toks_lp) / token_count) if token_count > 0 else 0.0 + avg_lp_bits = avg_lp / math.log(2.0) + try: + perplexity = math.exp(-avg_lp) + except OverflowError: + perplexity = float("inf") + log10_prob = seq_lp / math.log(10.0) + seq_prob_display = result.get("sequence_prob_display", [None]*len(seq_logprobs))[i] + if seq_prob_display is None: + seq_prob_display = f"~10^{log10_prob:.2f}" + if token_count > 0: + try: + geom_mean = math.exp(avg_lp) + geom_mean_display = f"{geom_mean:.6e}" + except OverflowError: + geom_mean_display = f"exp({avg_lp:.3f})" + else: + geom_mean_display = "n/a" + metrics["per_sequence"].append({ + "sequence_index": i, + "token_count": token_count, + "sequence_logprob_nats": float(seq_lp), + "sequence_log10": float(log10_prob), + "sequence_prob_display": seq_prob_display, + "avg_logprob_per_token_nats": float(avg_lp), + "avg_logprob_per_token_bits": float(avg_lp_bits), + "geometric_mean_token_prob": geom_mean_display, + "perplexity": float(perplexity) + }) + result["metrics"] = metrics + return result + +def _decode_token(tokenizer, tok_id: int) -> str: + if tokenizer is None: + return str(tok_id) + try: + if hasattr(tokenizer, "decode"): + return tokenizer.decode([tok_id]) + if hasattr(tokenizer, "convert_ids_to_tokens"): + return tokenizer.convert_ids_to_tokens([tok_id])[0] + except Exception: + pass + return str(tok_id) + +# --------------------------- +# generate_with_probs (with diff) +# --------------------------- +@torch.inference_mode() +def generate_with_probs( + model, + prompts: torch.Tensor, + seq_len: int, + eos_token: Optional[int] = None, + temperature: float = 1.0, + prompt_lens: Optional[torch.Tensor] = None, + filter_logits_fn: Optional[Callable] = None, + filter_kwargs: Optional[Dict[str, Any]] = None, + pad_value: Optional[int] = None, + tokenizer = None, + print_table: bool = False, + device: Optional[torch.device] = None, + verbose: bool = True, + include_top1: bool = True, + **kwargs +) -> Dict[str, Any]: + """ + Generate sequences from an autoregressive model while collecting per-token probabilities, + log-probabilities, scores and an optional "diff" view comparing sampled tokens to the + model's top-1 (greedy) tokens. + + This function runs the model in inference mode and appends sampled tokens to the provided + prompts until `seq_len` tokens have been generated (or until an `eos_token` ends all + sequences). It supports temperature sampling, optional logits filtering, and returns + detailed diagnostics useful for evaluation, debugging and analysis (per-token probs, + cumulative sequence log-probabilities, NLL, perplexity, and a diff of sampled vs top-1). + + Key behaviors + - Operates under `torch.inference_mode()` (no gradients). + - If `prompt_lens` is provided, prompts are right-aligned into a padded buffer of the + same shape as `prompts` before generation (useful when prompts are suffixes). + - If `filter_logits_fn` is provided it is applied to raw logits before softmax. + - If `temperature == 0.0` the function performs greedy decoding (argmax). + - If `include_top1` is True, the function computes the top-1 token and its probability + at each step (after optional filtering) and records whether the sampled token matched it. + - If `eos_token` is provided, generation stops early when every batch item has produced + an EOS; generated outputs after the first EOS are optionally padded with `pad_value`. + - Returned numeric log-probabilities are in natural log (nats) and converted to float64 + for sequence-level aggregation to reduce numerical error. + + Parameters + - model: A model object exposing a `net` callable with signature + `logits = model.net(tokens, return_intermediates=True, cache=None, seq_start_pos=None, **kwargs)`. + `logits` must be a tensor of shape (batch, seq, vocab) or a tuple/list whose first + element is that tensor. + - prompts (torch.Tensor): Integer token tensor of shape (batch, prompt_len) containing + prompt tokens. Prompts are copied and extended in-place to produce generated sequences. + - seq_len (int): Maximum number of tokens to generate per example (not counting prompt). + - eos_token (Optional[int]): Token id that marks end-of-sequence. If provided, generation + may stop early and outputs after the first EOS are optionally replaced with `pad_value`. + - temperature (float): Sampling temperature. `0.0` forces greedy decoding. + - prompt_lens (Optional[torch.Tensor]): Optional per-batch prompt lengths (int or tensor) + used to right-align prompts into the generation buffer when prompts are suffixes. + - filter_logits_fn (Optional[Callable]): Function applied to raw logits before softmax. + Signature should accept `(logits, **filter_kwargs)` and return logits of same shape. + - filter_kwargs (Optional[Dict[str, Any]]): Keyword arguments forwarded to `filter_logits_fn`. + - pad_value (Optional[int]): Token id used to pad generated outputs after EOS (if any). + - tokenizer: Optional tokenizer used to decode token ids for human-readable diffs and + printed tables. If absent, token ids are stringified. + - print_table (bool): If True, prints a human-readable table summarizing per-token stats. + - device (Optional[torch.device]): Device to run generation on. Defaults to `prompts.device`. + - verbose (bool): If True, prints progress messages during generation. + - include_top1 (bool): If True, compute and return top-1 tokens, their probs/logprobs, + and a `diff` structure listing positions where sampled != top-1. + - **kwargs: Additional keyword arguments forwarded to `model.net`. + + Returns + A dictionary with the following keys (types shown informally): + - "tokens" (torch.Tensor): Generated tokens (batch, generated_len) as CPU tensor. + - "token_probs" (List[List[float]]): Per-batch list of per-token sampling probabilities. + - "token_logprobs" (List[List[float]]): Per-batch list of per-token log-probabilities (nats). + - "token_scores" (List[List[float]]): Per-token scores (negative log-probabilities). + - "sequence_logprobs" (List[float]): Sum of token log-probabilities per generated sequence. + - "sequence_probs" (List[float]): Sequence probabilities (exp of sequence_logprobs) where + numerically possible; extremely small values may be represented as 0.0. + - "sequence_prob_display" (List[str]): Human-friendly display of sequence probability + (either decimal or approximate 10^x form for tiny values). + - "nll" (List[float]): Negative log-likelihood per sequence (i.e., -sequence_logprob). + - "metadata" (dict): Contains "prompt_len", "generated_len", and "temperature". + - "diff" (List[List[Dict]]): Per-batch list of dictionaries for positions where the sampled + token differed from the top-1 token. Each dict contains: + - "pos": position index within the generated span (0-based) + - "token": sampled token id + - "token_str": decoded sampled token (or id string) + - "token_prob": sampled token probability + - "top1_token": top-1 token id + - "top1_token_str": decoded top-1 token + - "top1_prob": top-1 probability + - "match": boolean (always False for entries in diff) + - If `include_top1` is True, additional keys are included: + - "top1_tokens", "top1_token_probs", "top1_token_logprobs", "top1_matches" + + After the primary result is assembled the function attaches a "metrics" entry with: + - "per_sequence": list of per-sequence metric dicts containing: + - "sequence_index", "token_count", "sequence_logprob_nats", "sequence_log10", + "sequence_prob_display", "avg_logprob_per_token_nats", "avg_logprob_per_token_bits", + "geometric_mean_token_prob", "perplexity" + + Notes and caveats + - Numerical stability: very small probabilities are clamped before log to avoid -inf; + sequence probabilities that underflow are represented with an approximate 10^x string. + - The function assumes the model's logits correspond to the next-token distribution for + the last position of the provided input; it uses `logits[:, -1, :]` for sampling. + - The function may raise exceptions if `model.net` returns tensors of unexpected shape + or if device/dtype mismatches occur. + - This function is intended for analysis and debugging; it is not optimized for maximal + throughput in production sampling loops. + + Example (conceptual) + >>> res = generate_with_probs(model, prompts, seq_len=20, temperature=0.8, tokenizer=tok) + >>> print(res["metrics"]["per_sequence"][0]["perplexity"]) + """ + if filter_kwargs is None: + filter_kwargs = {} + if device is None: + device = prompts.device + if pad_value is None: + pad_value = getattr(model, "pad_value", None) + + model.eval() + with torch.inference_mode(): + prompts_in = prompts.to(device) + b, t = prompts_in.shape + + if prompt_lens is not None: + aligned = torch.full_like(prompts_in, pad_value) + for i in range(b): + L = int(prompt_lens[i].item()) if isinstance(prompt_lens[i], torch.Tensor) else int(prompt_lens[i]) + if L > 0: + aligned[i, -L:] = prompts_in[i, -L:] + prompts_in = aligned + + out = prompts_in.clone() + + token_probs: List[List[float]] = [[] for _ in range(b)] + token_logprobs: List[List[float]] = [[] for _ in range(b)] + token_scores: List[List[float]] = [[] for _ in range(b)] + seq_logprob_tensors = [torch.tensor(0.0, dtype=torch.float64) for _ in range(b)] + + top1_tokens: List[List[int]] = [[] for _ in range(b)] + top1_token_probs: List[List[float]] = [[] for _ in range(b)] + top1_token_logprobs: List[List[float]] = [[] for _ in range(b)] + top1_matches: List[List[bool]] = [[] for _ in range(b)] + + greedy = (temperature == 0.0) + + if verbose: + print("Generating sequence of max length:", seq_len) + + for sl in range(seq_len): + max_seq_len = getattr(model, "max_seq_len", None) + x = out if max_seq_len is None else out[:, -max_seq_len:] + + logits_out = model.net(x, return_intermediates=True, cache=None, seq_start_pos=None, **kwargs) + logits = logits_out[0] if isinstance(logits_out, (tuple, list)) else logits_out + logits = logits[:, -1, :] + + # top1 (greedy) from raw logits + if include_top1: + top1_ids = logits.argmax(dim=-1, keepdim=True) # (batch,1) + filtered_for_top1 = logits if filter_logits_fn is None else filter_logits_fn(logits, **filter_kwargs) + probs_for_top1 = F.softmax(filtered_for_top1 / (temperature if temperature > 0 else 1.0), dim=-1) + top1_p = probs_for_top1.gather(1, top1_ids).squeeze(1) + top1_lp = torch.log(top1_p.clamp_min(1e-45)).to(dtype=torch.float64) + + if greedy: + filtered_logits = logits if filter_logits_fn is None else filter_logits_fn(logits, **filter_kwargs) + probs = F.softmax(filtered_logits / (temperature if temperature > 0 else 1.0), dim=-1) + sample = logits.argmax(dim=-1, keepdim=True) + else: + filtered_logits = logits if filter_logits_fn is None else filter_logits_fn(logits, **filter_kwargs) + probs = F.softmax(filtered_logits / temperature, dim=-1) + sample = torch.multinomial(probs, 1) + + picked_probs = probs.gather(1, sample).squeeze(1) + picked_logprobs = torch.log(picked_probs.clamp_min(1e-45)).to(dtype=torch.float64) + + out = torch.cat((out, sample), dim=-1) + + for i in range(b): + p = float(picked_probs[i].cpu().item()) + lp = float(picked_logprobs[i].cpu().item()) + token_probs[i].append(p) + token_logprobs[i].append(lp) + token_scores[i].append(-lp) + seq_logprob_tensors[i] = seq_logprob_tensors[i] + torch.tensor(lp, dtype=torch.float64) + + if include_top1: + tid = int(top1_ids[i].item()) + tp = float(top1_p[i].cpu().item()) + tlp = float(top1_lp[i].cpu().item()) + top1_tokens[i].append(tid) + top1_token_probs[i].append(tp) + top1_token_logprobs[i].append(tlp) + top1_matches[i].append(int(sample[i].item()) == tid) + + if verbose and (sl % 32 == 0): + print(f"{sl} / {seq_len}") + + if eos_token is not None: + last_tokens = out[:, -1] + if (last_tokens == eos_token).any(dim=-1).all(): + if verbose: + print('Model called the end of sequence at:', sl, '/', seq_len) + break + + gen = out[:, t:].cpu() + + if eos_token is not None: + for i in range(b): + seq_full = out[i].cpu() + eos_positions = (seq_full == eos_token).nonzero(as_tuple=False) + if eos_positions.numel() > 0: + first_eos_idx = int(eos_positions[0].item()) + gen_len_before_eos = max(0, first_eos_idx - t) + token_probs[i] = token_probs[i][:gen_len_before_eos] + token_logprobs[i] = token_logprobs[i][:gen_len_before_eos] + token_scores[i] = token_scores[i][:gen_len_before_eos] + seq_logprob_tensors[i] = torch.tensor(sum(token_logprobs[i]), dtype=torch.float64) + if include_top1: + top1_tokens[i] = top1_tokens[i][:gen_len_before_eos] + top1_token_probs[i] = top1_token_probs[i][:gen_len_before_eos] + top1_token_logprobs[i] = top1_token_logprobs[i][:gen_len_before_eos] + top1_matches[i] = top1_matches[i][:gen_len_before_eos] + if pad_value is not None: + start_mask = max(0, first_eos_idx - t) + if start_mask < gen.shape[1]: + gen[i, start_mask:] = pad_value + + sequence_logprobs: List[float] = [float(x.item()) for x in seq_logprob_tensors] + sequence_probs: List[float] = [] + sequence_prob_display: List[str] = [] + nll: List[float] = [] + + for lp in sequence_logprobs: + pnum, disp = _safe_exp64(lp) + sequence_probs.append(pnum) + sequence_prob_display.append(disp) + nll.append(-lp) + + result = { + "tokens": gen, + "token_probs": token_probs, + "token_logprobs": token_logprobs, + "token_scores": token_scores, + "sequence_logprobs": sequence_logprobs, + "sequence_probs": sequence_probs, + "sequence_prob_display": sequence_prob_display, + "nll": nll, + "metadata": { + "prompt_len": t, + "generated_len": gen.shape[1], + "temperature": temperature + } + } + + if include_top1: + result.update({ + "top1_tokens": top1_tokens, + "top1_token_probs": top1_token_probs, + "top1_token_logprobs": top1_token_logprobs, + "top1_matches": top1_matches + }) + + # build diff view: sampled != top1 + diff_all: List[List[Dict[str, Any]]] = [[] for _ in range(b)] + if include_top1: + for i in range(b): + for pos, (sample_tok, sample_p, t1_tok, t1_p, match) in enumerate(zip( + [int(x) for x in gen[i].tolist()], + token_probs[i], + top1_tokens[i], + top1_token_probs[i], + top1_matches[i] + )): + if not match: + diff_all[i].append({ + "pos": pos, + "token": sample_tok, + "token_str": _decode_token(tokenizer, sample_tok), + "token_prob": sample_p, + "top1_token": int(t1_tok), + "top1_token_str": _decode_token(tokenizer, int(t1_tok)), + "top1_prob": t1_p, + "match": bool(match) + }) + result["diff"] = diff_all + + result = _attach_metrics_to_result(result) + + if print_table: + for i in range(b): + print("="*110) + print(f"Batch {i} (prompt_len={t})") + print("-"*110) + print(" idx | token | prob | logprob | cum_logp | token_nll | top1_token (p) | match") + print("-"*110) + cum_logp = 0.0 + for idx, (p, lp, sc) in enumerate(zip(token_probs[i], token_logprobs[i], token_scores[i])): + cum_logp += lp + tok_id = int(gen[i, idx].item()) if idx < gen.shape[1] else -1 + tok_display = _decode_token(tokenizer, tok_id) + if include_top1: + t1_id = top1_tokens[i][idx] + t1_p = top1_token_probs[i][idx] + match_mark = "*" if top1_matches[i][idx] else " " + print(f"{idx:3d} | {tok_display:>12s} | {p:9.6f} | {lp:11.6f} | {cum_logp:12.6f} | {sc:10.6f} | {_decode_token(tokenizer, t1_id):>12s} ({t1_p:5.3f}){match_mark}") + else: + print(f"{idx:3d} | {tok_display:>12s} | {p:9.6f} | {lp:11.6f} | {cum_logp:12.6f} | {sc:10.6f}") + print("-"*110) + print(f"Sequence logprob (nats): {result['sequence_logprobs'][i]:.6f} | Sequence prob: {result['sequence_prob_display'][i]} | NLL: {result['nll'][i]:.6f}") + m = result["metrics"]["per_sequence"][i] + print(f"Avg logprob/token: {m['avg_logprob_per_token_nats']:.6f} nats ({m['avg_logprob_per_token_bits']:.4f} bits) | Perplexity: {m['perplexity']:.6f}") + if result["diff"][i]: + print("DIFF (sampled != top1) positions:") + for d in result["diff"][i]: + print(f" pos={d['pos']} token={d['token_str']}({d['token']}) p={d['token_prob']:.6f} | top1={d['top1_token_str']}({d['top1_token']}) p={d['top1_prob']:.6f}") + else: + print("No diffs: sampled tokens matched top1 at every step.") + print("="*110) + + return result + +# --------------------------- +# score_sequences (with diff) +# --------------------------- +@torch.inference_mode() +def score_sequences( + model, + sequences: torch.Tensor, + prompt_lens: Optional[torch.Tensor] = None, + eos_token: Optional[int] = None, + pad_value: Optional[int] = None, + filter_logits_fn: Optional[Callable] = None, + filter_kwargs: Optional[Dict[str, Any]] = None, + tokenizer = None, + print_table: bool = False, + device: Optional[torch.device] = None, + verbose: bool = False, + include_top1: bool = True, + **kwargs +) -> Dict[str, Any]: + """ + Compute per-token and per-sequence likelihood statistics for given full sequences + under an autoregressive model, optionally comparing each target token to the model's + top-1 prediction and producing a diff of mismatches. + + This function scores provided sequences by computing the model's next-token distribution + for each position and extracting the probability and log-probability assigned to the + actual target token (i.e., the token that follows each input prefix). It supports + masking of padding tokens, optional EOS-based truncation, and an optional logits filter. + The function returns detailed per-token lists, aggregated sequence log-probabilities, + NLLs, human-friendly probability displays, and diagnostic "diff" entries where the + target token differs from the model's greedy top-1. + + Key behaviors + - Operates under `torch.inference_mode()` (no gradients). + - Expects `sequences` shaped (batch, seq_len). The function scores tokens at positions + 1..(seq_len-1) where each target is `sequences[:, pos]` and the corresponding input + is `sequences[:, :pos]`. + - If `filter_logits_fn` is provided it is applied to the model logits before softmax. + - If `pad_value` is provided, positions where the target equals `pad_value` are masked + out and not counted in sequence sums or per-token lists. + - If `eos_token` is provided, tokens after the first EOS in each sequence are masked out. + - If `include_top1` is True, the function computes top-1 ids and probabilities and + records whether the target matched the top-1 at each scored position. + + Parameters + - model: A model object exposing a `net` callable with signature + `logits = model.net(tokens, return_intermediates=True, cache=None, seq_start_pos=None, **kwargs)`. + `logits` must be a tensor of shape (batch, seq, vocab) or a tuple/list whose first + element is that tensor. + - sequences (torch.Tensor): Integer token tensor of shape (batch, seq_len) containing + full sequences to be scored. The first token of each sequence is treated as context + and scoring begins at the second token. + - prompt_lens (Optional[torch.Tensor]): Optional per-batch prompt lengths; included in + returned metadata for bookkeeping (does not change scoring logic). + - eos_token (Optional[int]): Token id that marks end-of-sequence. If provided, tokens + after the first EOS are excluded from scoring. + - pad_value (Optional[int]): Token id used to indicate padding; masked positions are + excluded from per-token lists and sequence aggregates. + - filter_logits_fn (Optional[Callable]): Function applied to raw logits before softmax. + Signature should accept `(logits, **filter_kwargs)` and return logits of same shape. + - filter_kwargs (Optional[Dict[str, Any]]): Keyword arguments forwarded to `filter_logits_fn`. + - tokenizer: Optional tokenizer used to decode token ids for human-readable diffs and + printed tables. If absent, token ids are stringified. + - print_table (bool): If True, prints a human-readable table summarizing per-token stats. + - device (Optional[torch.device]): Device to run scoring on. Defaults to `sequences.device`. + - verbose (bool): If True, prints progress or extra information (currently minimal). + - include_top1 (bool): If True, compute and return top-1 tokens, their probs/logprobs, + and a `diff` structure listing positions where target != top-1. + - **kwargs: Additional keyword arguments forwarded to `model.net`. + + Returns + A dictionary with the following keys: + - "tokens" (torch.Tensor): The input `sequences` returned as a CPU tensor. + - "token_probs" (List[List[float]]): Per-batch lists of probabilities assigned to each + scored target token (masked positions removed). + - "token_logprobs" (List[List[float]]): Per-batch lists of log-probabilities (nats). + - "token_scores" (List[List[float]]): Per-token scores (negative log-probabilities). + - "sequence_logprobs" (List[float]): Sum of log-probabilities over unmasked target tokens. + - "sequence_probs" (List[float]): Sequence probabilities where numerically representable. + - "sequence_prob_display" (List[str]): Human-friendly display of sequence probability. + - "nll" (List[float]): Negative log-likelihood per sequence (i.e., -sequence_logprob). + - "mask" (torch.BoolTensor): Boolean mask (batch, scored_len) indicating which target + positions were included in scoring (True = scored). + - "diff" (List[List[Dict]]): Per-batch list of dicts for positions where the target + token did not match the model's top-1. Each dict contains: + - "pos": index within the scored positions (0-based) + - "token": target token id + - "token_str": decoded target token (or id string) + - "token_prob": probability assigned to the target token + - "top1_token": top-1 token id + - "top1_token_str": decoded top-1 token + - "top1_prob": top-1 probability + - "match": boolean (False for entries in diff) + - "metadata" (dict): Contains "prompt_len" (if provided), "seq_len" (original sequence + length), and "scored_len_per_batch" (number of scored tokens per batch item). + - If `include_top1` is True, additional keys are included: + - "top1_tokens", "top1_token_probs", "top1_token_logprobs", "top1_matches" + + After assembling the primary result the function attaches a "metrics" entry with: + - "per_sequence": list of per-sequence metric dicts containing: + - "sequence_index", "token_count", "sequence_logprob_nats", "sequence_log10", + "sequence_prob_display", "avg_logprob_per_token_nats", "avg_logprob_per_token_bits", + "geometric_mean_token_prob", "perplexity" + + Notes and caveats + - The function expects `sequences` to contain at least two tokens per batch item; if + `seq_len < 2` a minimal result with empty scored lists is returned. + - Numerical stability: probabilities are clamped before log to avoid -inf; extremely + small sequence probabilities are represented in approximate 10^x form. + - The function may raise ValueError if `model.net` returns logits of unexpected shape. + - This routine is intended for evaluation and analysis of model likelihoods rather than + high-performance batched scoring in production. + + Example (conceptual) + >>> res = score_sequences(model, sequences, pad_value=0, eos_token=2, tokenizer=tok) + >>> print(res["metrics"]["per_sequence"][0]["avg_logprob_per_token_nats"]) + """ + if filter_kwargs is None: + filter_kwargs = {} + if device is None: + device = sequences.device + + model.eval() + with torch.inference_mode(): + sequences = sequences.to(device) + b, L = sequences.shape + + if L < 2: + empty = [[] for _ in range(b)] + return { + "tokens": sequences.cpu(), + "token_probs": empty, + "token_logprobs": empty, + "token_scores": empty, + "sequence_probs": [1.0 for _ in range(b)], + "sequence_prob_display": [f"{1.0:.6e}" for _ in range(b)], + "sequence_logprobs": [0.0 for _ in range(b)], + "nll": [0.0 for _ in range(b)], + "mask": torch.zeros((b, 0), dtype=torch.bool), + "diff": [[] for _ in range(b)], + "metadata": {"prompt_len": None if prompt_lens is None else (prompt_lens.tolist() if isinstance(prompt_lens, torch.Tensor) else prompt_lens), + "seq_len": L, + "scored_len": 0} + } + + inputs = sequences[:, :-1] + targets = sequences[:, 1:] + + logits_out = model.net(inputs, return_intermediates=True, cache=None, seq_start_pos=None, **kwargs) + logits = logits_out[0] if isinstance(logits_out, (tuple, list)) else logits_out + + if logits.dim() != 3: + raise ValueError(f"Expected logits with shape (b, seq, vocab), got {logits.shape}") + + filtered_logits = logits if filter_logits_fn is None else filter_logits_fn(logits, **(filter_kwargs or {})) + probs = F.softmax(filtered_logits, dim=-1) + targets_unsq = targets.unsqueeze(-1) + picked_probs = probs.gather(dim=-1, index=targets_unsq).squeeze(-1) + picked_logprobs = torch.log(picked_probs.clamp_min(1e-45)).to(dtype=torch.float64) + + if include_top1: + top1_ids = probs.argmax(dim=-1) # (b, seq) + top1_p = probs.gather(-1, top1_ids.unsqueeze(-1)).squeeze(-1) + top1_lp = torch.log(top1_p.clamp_min(1e-45)).to(dtype=torch.float64) + + mask = torch.ones_like(picked_probs, dtype=torch.bool) + if pad_value is not None: + mask = mask & (targets != pad_value) + + if eos_token is not None: + for i in range(b): + seq_full = sequences[i] + eos_positions = (seq_full == eos_token).nonzero(as_tuple=False) + if eos_positions.numel() > 0: + first_eos = int(eos_positions[0].item()) + cutoff = max(0, first_eos - 1) + if cutoff + 1 < mask.shape[1]: + mask[i, cutoff+1:] = False + + token_probs: List[List[float]] = [] + token_logprobs: List[List[float]] = [] + token_scores: List[List[float]] = [] + sequence_logprobs: List[float] = [] + sequence_probs: List[float] = [] + sequence_prob_display: List[str] = [] + nll: List[float] = [] + + top1_tokens: List[List[int]] = [[] for _ in range(b)] + top1_token_probs: List[List[float]] = [[] for _ in range(b)] + top1_token_logprobs: List[List[float]] = [[] for _ in range(b)] + top1_matches: List[List[bool]] = [[] for _ in range(b)] + + diff_all: List[List[Dict[str, Any]]] = [[] for _ in range(b)] + + for i in range(b): + row_mask = mask[i] + row_probs = picked_probs[i] + row_logps = picked_logprobs[i] + kept_probs = row_probs[row_mask].cpu().tolist() + kept_logps = row_logps[row_mask].cpu().tolist() + kept_scores = [-lp for lp in kept_logps] + token_probs.append([float(x) for x in kept_probs]) + token_logprobs.append([float(x) for x in kept_logps]) + token_scores.append([float(x) for x in kept_scores]) + + if include_top1: + t1_row = top1_ids[i] + t1_p_row = top1_p[i] + t1_lp_row = top1_lp[i] + kept_t1_ids = t1_row[row_mask].cpu().tolist() + kept_t1_ps = t1_p_row[row_mask].cpu().tolist() + kept_t1_lps = t1_lp_row[row_mask].cpu().tolist() + top1_tokens[i] = [int(x) for x in kept_t1_ids] + top1_token_probs[i] = [float(x) for x in kept_t1_ps] + top1_token_logprobs[i] = [float(x) for x in kept_t1_lps] + kept_targets = targets[i][row_mask].cpu().tolist() + top1_matches[i] = [int(t == top1) for t, top1 in zip(kept_targets, kept_t1_ids)] + + # build diff entries where target != top1 + for pos_idx, (tgt, tgt_p, t1, t1_p, match) in enumerate(zip(kept_targets, kept_probs, kept_t1_ids, kept_t1_ps, top1_matches[i])): + if not match: + diff_all[i].append({ + "pos": pos_idx, + "token": int(tgt), + "token_str": _decode_token(tokenizer, int(tgt)), + "token_prob": float(tgt_p), + "top1_token": int(t1), + "top1_token_str": _decode_token(tokenizer, int(t1)), + "top1_prob": float(t1_p), + "match": bool(match) + }) + + seq_lp_tensor = torch.tensor(sum(kept_logprobs := kept_logps), dtype=torch.float64) + seq_lp = float(seq_lp_tensor.item()) + pnum, disp = _safe_exp64(seq_lp) + sequence_logprobs.append(seq_lp) + sequence_probs.append(pnum) + sequence_prob_display.append(disp) + nll.append(-seq_lp) + + result = { + "tokens": sequences.cpu(), + "token_probs": token_probs, + "token_logprobs": token_logprobs, + "token_scores": token_scores, + "sequence_logprobs": sequence_logprobs, + "sequence_probs": sequence_probs, + "sequence_prob_display": sequence_prob_display, + "nll": nll, + "mask": mask.cpu(), + "diff": diff_all, + "metadata": { + "prompt_len": None if prompt_lens is None else (prompt_lens.tolist() if isinstance(prompt_lens, torch.Tensor) else prompt_lens), + "seq_len": L, + "scored_len_per_batch": [int(m.sum().item()) for m in mask] + } + } + + if include_top1: + result.update({ + "top1_tokens": top1_tokens, + "top1_token_probs": top1_token_probs, + "top1_token_logprobs": top1_token_logprobs, + "top1_matches": top1_matches + }) + + result = _attach_metrics_to_result(result) + + if print_table: + for i in range(b): + print("=" * 120) + header = f"Batch {i} (seq_len={L})" + if prompt_lens is not None: + header += f" prompt_len={int(prompt_lens[i].item()) if isinstance(prompt_lens[i], torch.Tensor) else prompt_lens[i]}" + print(header) + print("-" * 120) + print(" idx | token | prob | logprob | cum_logp | token_nll | top1_token (p) | match") + print("-" * 120) + cum_logp = 0.0 + pos_idx = 0 + for pos in range(1, L): + if not mask[i, pos-1]: + continue + tok_id = int(sequences[i, pos].item()) + p = float(picked_probs[i, pos-1].cpu().item()) + lp = float(picked_logprobs[i, pos-1].cpu().item()) + cum_logp += lp + sc = -lp + if include_top1: + t1_id = top1_tokens[i][pos_idx] + t1_p = top1_token_probs[i][pos_idx] + match = top1_matches[i][pos_idx] + print(f"{pos_idx:3d} | {_decode_token(tokenizer, tok_id):>12s} | {p:9.6f} | {lp:11.6f} | {cum_logp:12.6f} | {sc:10.6f} | {_decode_token(tokenizer, t1_id):>12s} ({t1_p:5.3f}) | {match}") + else: + print(f"{pos_idx:3d} | {_decode_token(tokenizer, tok_id):>12s} | {p:9.6f} | {lp:11.6f} | {cum_logp:12.6f} | {sc:10.6f}") + pos_idx += 1 + print("-" * 120) + print(f"Sequence logprob (nats): {result['sequence_logprobs'][i]:.6f} | Sequence prob: {result['sequence_prob_display'][i]} | NLL: {result['nll'][i]:.6f}") + m = result["metrics"]["per_sequence"][i] + print(f"Avg logprob/token: {m['avg_logprob_per_token_nats']:.6f} nats ({m['avg_logprob_per_token_bits']:.4f} bits) | Perplexity: {m['perplexity']:.6f}") + if result["diff"][i]: + print("DIFF (target != top1) positions:") + for d in result["diff"][i]: + print(f" pos={d['pos']} token={d['token_str']}({d['token']}) p={d['token_prob']:.6f} | top1={d['top1_token_str']}({d['top1_token']}) p={d['top1_prob']:.6f}") + else: + print("No diffs: target tokens matched top1 at every scored position.") + print("=" * 120) + + return result + +#================================================================================================================================= +# ETA functions +#================================================================================================================================= + +from datetime import datetime, timedelta +from zoneinfo import ZoneInfo + +def calculate_eta( + hours_until_done: float, + *, + tz: str = "America/Los_Angeles", + now: datetime | None = None, + return_dict: bool = False, + verbose: bool = True, + ): + + """ + Compute an ETA timestamp based on the current time (or a provided time) + in a specified timezone. + + Parameters + ---------- + hours_until_done : float + Number of hours remaining until completion. + tz : str, optional + IANA timezone name (default: "America/Los_Angeles"). + now : datetime or None, optional + If provided, use this datetime as the starting point. + If None, the current time in the given timezone is used. + return_dict : bool, optional + If True, return a dictionary with ETA components. + verbose : bool, optional + If True, print a formatted ETA string. + + Returns + ------- + datetime or dict + ETA as a datetime object or a dictionary (if return_dict=True). + + Examples + -------- + + # Simple ETA 5.5 hours from now + calculate_eta(5.5) + + # ETA using a custom starting time in Tokyo + from datetime import datetime + calculate_eta( + 12, + tz="Asia/Tokyo", + now=datetime(2026, 1, 29, 8, 30), + ) + + # Get ETA as a dict without printing + info = calculate_eta(3, verbose=False, return_dict=True) + print(info["pretty"]) + """ + + # Resolve timezone + zone = ZoneInfo(tz) + + # Determine current time + current_time = now.astimezone(zone) if now else datetime.now(zone) + + # Compute ETA + eta = current_time + timedelta(hours=hours_until_done) + + # Format for printing + pretty = eta.strftime("ETA: %A, %B %d %Y @ %H:%M") + + if verbose: + print(pretty) + + if return_dict: + return { + "eta_datetime": eta, + "year": eta.year, + "month": eta.month, + "day": eta.day, + "hour": eta.hour, + "minute": eta.minute, + "second": eta.second, + "timezone": tz, + "pretty": pretty, + } + + return eta + +def calculate_training_run_eta( + num_epochs: int, + num_steps_per_epoch: int, + sec_per_iter: float, + *, + cost_per_hr: float = 0.0, + tz: str = "America/Los_Angeles", + now: datetime | None = None, + return_dict: bool = False, + verbose: bool = True, +): + """ + Compute ETA and cost for a full training run based on: + - number of epochs + - number of steps per epoch + - seconds per iteration + - optional cost per hour of compute + + Prints: + - start time + - ETA timestamp + - per-epoch runtime (h/m/s) + - total runtime (h/m/s) + - cost per epoch + - total run cost + + Returns: + datetime or dict (if return_dict=True) + + Examples: + + # 2 epochs, 7770 steps each, 15.07 sec/iter, $5.3 per/hr + calculate_training_run_eta( + num_epochs=2, + num_steps_per_epoch=7771, + cost_per_hr=5.3, + sec_per_iter=15.07, + ) + + # Get structured info without printing + info = calculate_training_run_eta( + 3, 1000, 0.5, + verbose=False, + return_dict=True + ) + print(info["eta_str"]) + """ + + zone = ZoneInfo(tz) + start_time = now.astimezone(zone) if now else datetime.now(zone) + + # Core calculations + total_iters = num_epochs * num_steps_per_epoch + total_seconds = total_iters * sec_per_iter + epoch_seconds = num_steps_per_epoch * sec_per_iter + + eta = start_time + timedelta(seconds=total_seconds) + + # Formatting helpers + def fmt(seconds: float) -> str: + seconds = int(seconds) + h = seconds // 3600 + m = (seconds % 3600) // 60 + s = seconds % 60 + return f"{h}h {m}m {s}s" + + # Cost calculations + total_hours = total_seconds / 3600 + epoch_hours = epoch_seconds / 3600 + + cost_epoch = epoch_hours * cost_per_hr + cost_total = total_hours * cost_per_hr + + # Pretty strings + start_str = start_time.strftime("%A, %B %d %Y @ %H:%M") + eta_str = eta.strftime("%A, %B %d %Y @ %H:%M") + + if verbose: + print(f"Start Time: {start_str}") + print(f"ETA: {eta_str}") + print(f"Per Epoch: {fmt(epoch_seconds)}") + print(f"Total Run: {fmt(total_seconds)}") + print(f"Cost/Epoch: ${cost_epoch:,.2f}") + print(f"Cost/Run: ${cost_total:,.2f}") + + if return_dict: + return { + "start_time": start_time, + "eta": eta, + "start_str": start_str, + "eta_str": eta_str, + "epoch_seconds": epoch_seconds, + "total_seconds": total_seconds, + "epoch_runtime_hms": fmt(epoch_seconds), + "total_runtime_hms": fmt(total_seconds), + "epoch_hours": epoch_hours, + "total_hours": total_hours, + "cost_per_hr": cost_per_hr, + "cost_epoch": cost_epoch, + "cost_total": cost_total, + "timezone": tz, + } + + return eta + +#================================================================================================================================= +# Autoregressive embeddings retrieval functions +#================================================================================================================================= + +import torch +import torch.nn.functional as F +from typing import Optional, Dict, List, Union, Set +import numpy as np +from tqdm import tqdm +from contextlib import nullcontext + +#=================================================================================================================== +# Advanced Embeddings Retrieval Function for Autoregressive X-Transformers +#=================================================================================================================== + +def get_embeddings( + model, + inputs: torch.Tensor, + pooling: str = 'mean', + mask: Optional[torch.Tensor] = None, + token_ids: Optional[List[int]] = None, + token_weights: Optional[Dict[int, float]] = None, + layer_index: int = -1, + normalize: bool = False, + device: Optional[torch.device] = None, + dtype: torch.dtype = torch.bfloat16, + pad_idx: int = 18819, + use_amp: bool = True, + verbose: bool = True, + _max_concat_tokens: Optional[int] = None, + ) -> np.ndarray: + + """ + Get embeddings for a single batch of inputs. + + Parameters + ---------- + model : AutoregressiveWrapper + Your trained transformer model + inputs : torch.Tensor + Input token sequences of shape (batch, seq_len) + pooling : str + Pooling strategy: 'mean' or 'concat' + mask : Optional[torch.Tensor] + Boolean mask, True for valid tokens. Auto-generated if None + token_ids : Optional[List[int]] + Token IDs to include. Works independently or with token_weights. + token_weights : Optional[Dict[int, float]] + Token ID to weight/priority mapping: + - 'mean': weights for weighted average + - 'concat': priority scores for selection when limiting count + - If provided WITHOUT token_ids: keys become the filter + - If provided WITH token_ids: only tokens in BOTH are used (intersection) + layer_index : int + Which layer's hidden states to use (-1 for last) + normalize : bool + L2-normalize output embeddings + device : Optional[torch.device] + Device for inference + dtype : torch.dtype + Dtype for autocast + pad_idx : int + Padding token index + use_amp : bool + Use automatic mixed precision + verbose : bool + Print warnings and info + _max_concat_tokens : Optional[int] + Internal: pre-computed max tokens for concat mode + + Returns + ------- + np.ndarray + Embeddings array: + - 'mean': (batch, dim) + - 'concat': (batch, max_tokens * dim) + """ + + model.eval() + + if device is None: + device = next(model.parameters()).device + + inputs = inputs.to(device) + + if inputs.ndim == 1: + inputs = inputs.unsqueeze(0) + + batch_size, seq_len = inputs.shape + + if mask is None: + mask = (inputs != pad_idx) + else: + mask = mask.to(device) + + if mask.dtype != torch.bool: + mask = mask.bool() + + if hasattr(model, 'net'): + net_model = model.net + else: + net_model = model + + if use_amp and device.type == 'cuda': + ctx = torch.amp.autocast(device_type='cuda', dtype=dtype) + else: + ctx = nullcontext() + + try: + with torch.no_grad(): + with ctx if use_amp else nullcontext(): + output = net_model( + inputs, + mask=mask if mask.ndim == 2 else mask.squeeze(), + return_intermediates=True, + ) + + if isinstance(output, tuple) and len(output) == 2: + _, intermediates = output + else: + intermediates = None + + hidden = _extract_hidden_states(intermediates, layer_index, verbose=verbose) + + if hidden is None: + raise ValueError("Could not extract hidden states") + + except Exception as e: + if verbose: + print(f"Warning: Could not extract hidden states, using token embeddings. Error: {e}") + hidden = _get_token_embeddings(net_model, inputs) + + seq_mask = (inputs != pad_idx) + seq_mask_expanded = seq_mask.unsqueeze(-1) + hidden = hidden * seq_mask_expanded.float() + + # Compute effective token IDs with INTUITIVE logic + effective_token_ids = _compute_effective_token_ids(token_ids, token_weights) + + if pooling == 'mean': + emb = _mean_pooling(hidden, inputs, seq_mask, effective_token_ids, token_weights, verbose=verbose) + elif pooling == 'concat': + emb = _concat_pooling(hidden, inputs, seq_mask, effective_token_ids, token_weights, + max_tokens=_max_concat_tokens, verbose=verbose) + else: + raise ValueError(f"Unknown pooling strategy: {pooling}. Use 'mean' or 'concat'") + + if normalize: + emb = F.normalize(emb, p=2, dim=-1) + + return emb.cpu().detach().numpy() + + +#=================================================================================================================== +# Batched Processing Function +#=================================================================================================================== + +def get_embeddings_batched( + model, + sequences: List[List[int]], + pooling: str = 'mean', + token_ids: Optional[List[int]] = None, + token_weights: Optional[Dict[int, float]] = None, + max_seq_len: int = 8192, + pad_idx: int = 18819, + batch_size: int = 8, + use_amp: bool = True, + dtype: torch.dtype = torch.bfloat16, + verbose: bool = True, + show_progress: bool = True, + normalize: bool = False, + ) -> np.ndarray: + + """ + Process multiple sequences in TRUE batches for memory efficiency. + + Parameters + ---------- + model : AutoregressiveWrapper + Your trained transformer model + sequences : List[List[int]] + List of token sequences (list of lists) + pooling : str + Pooling strategy: 'mean' or 'concat' + token_ids : Optional[List[int]] + Token IDs to include + token_weights : Optional[Dict[int, float]] + Token ID to weight/priority mapping + max_seq_len : int + Maximum sequence length + pad_idx : int + Padding token index + batch_size : int + Batch size for processing + use_amp : bool + Use automatic mixed precision + dtype : torch.dtype + Dtype for autocast + verbose : bool + Print messages + show_progress : bool + Show tqdm progress bar + normalize : bool + L2-normalize output embeddings + + Returns + ------- + np.ndarray + Embeddings array with consistent dimensions + """ + + model.eval() + + num_sequences = len(sequences) + + if verbose: + print(f"Processing {num_sequences} sequences in batches of {batch_size}...") + + # For concat mode: pre-scan to find max matching tokens across ALL sequences + max_concat_tokens = None + if pooling == 'concat': + effective_token_ids = _compute_effective_token_ids(token_ids, token_weights) + max_concat_tokens = _scan_max_matching_tokens(sequences, effective_token_ids, pad_idx, max_seq_len) + if verbose and max_concat_tokens is not None: + print(f"Auto-detected max matching tokens: {max_concat_tokens}") + elif verbose and max_concat_tokens == 0: + print("Warning: No sequences contain matching token IDs, using 1 token placeholder") + max_concat_tokens = 1 + + all_embeddings = [] + num_batches = (num_sequences + batch_size - 1) // batch_size + + batch_iterator = tqdm(range(num_batches), desc="Extracting embeddings", disable=not (show_progress and verbose)) if show_progress and verbose else range(num_batches) + + for batch_idx in batch_iterator: + start_idx = batch_idx * batch_size + end_idx = min((batch_idx + 1) * batch_size, num_sequences) + + batch_sequences = sequences[start_idx:end_idx] + max_len_in_batch = min(max_seq_len, max(len(seq) for seq in batch_sequences)) + + padded_batch = [] + for seq in batch_sequences: + if len(seq) > max_len_in_batch: + seq = seq[:max_len_in_batch] + else: + seq = seq + [pad_idx] * (max_len_in_batch - len(seq)) + padded_batch.append(seq) + + batch_inputs = torch.tensor(padded_batch, dtype=torch.long) + + batch_embeddings = get_embeddings( + model, + batch_inputs, + pooling=pooling, + token_ids=token_ids, + token_weights=token_weights, + pad_idx=pad_idx, + use_amp=use_amp, + dtype=dtype, + verbose=verbose and batch_idx == 0, + normalize=normalize, + _max_concat_tokens=max_concat_tokens, + ) + + all_embeddings.append(batch_embeddings) + + final_embeddings = np.concatenate(all_embeddings, axis=0) + + if verbose: + print(f"Final embeddings shape: {final_embeddings.shape}") + + return final_embeddings + +#=================================================================================================================== +# Helper Functions +#=================================================================================================================== + +def _compute_effective_token_ids(token_ids: Optional[List[int]], token_weights: Optional[Dict[int, float]]) -> Optional[Set[int]]: + """ + Compute effective token IDs with INTUITIVE logic: + + - token_ids=None, token_weights=None → None (all valid tokens) + - token_ids=[...], token_weights=None → token_ids + - token_ids=None, token_weights={...} → keys from token_weights + - token_ids=[...], token_weights={...} → INTERSECTION (only tokens in BOTH) + + This ensures token_weights acts as a filter when provided, not just weights. + """ + if token_ids is None and token_weights is None: + return None + + token_ids_set = set(token_ids) if token_ids is not None else None + weights_keys_set = set(token_weights.keys()) if token_weights is not None else None + + if token_ids_set is None and weights_keys_set is not None: + # Only token_weights provided: use its keys as filter + return weights_keys_set + elif token_ids_set is not None and weights_keys_set is None: + # Only token_ids provided: use token_ids as filter + return token_ids_set + elif token_ids_set is not None and weights_keys_set is not None: + # Both provided: INTERSECTION (only tokens in BOTH lists) + # This is the key fix for intuitive behavior + intersection = token_ids_set & weights_keys_set + if len(intersection) == 0: + # Warn but fall back to token_ids (more permissive) + print(f"Warning: token_ids and token_weights have no overlap. Using token_ids only.") + return token_ids_set + return intersection + else: + return None + +def _scan_max_matching_tokens(sequences: List[List[int]], + token_ids: Optional[Set[int]], + pad_idx: int, + max_seq_len: int) -> int: + """ + Scan all sequences to find maximum number of tokens matching token_ids. + """ + if token_ids is None: + return max(min(len(seq), max_seq_len) for seq in sequences) if sequences else 0 + + max_count = 0 + for seq in sequences: + truncated = seq[:max_seq_len] + count = sum(1 for tok in truncated if tok in token_ids and tok != pad_idx) + max_count = max(max_count, count) + + return max_count + +def _extract_hidden_states(intermediates, layer_index: int = -1, verbose: bool = True): + """Extract hidden states from LayerIntermediates object.""" + if intermediates is None: + if verbose: + print("Warning: intermediates is None") + return None + + if hasattr(intermediates, 'layer_hiddens') and intermediates.layer_hiddens is not None: + if len(intermediates.layer_hiddens) > 0: + return intermediates.layer_hiddens[layer_index] + + if hasattr(intermediates, 'hiddens') and intermediates.hiddens is not None: + if len(intermediates.hiddens) > 0: + return intermediates.hiddens[layer_index] + + if hasattr(intermediates, 'attn_intermediates') and intermediates.attn_intermediates is not None: + if len(intermediates.attn_intermediates) > 0: + attn_int = intermediates.attn_intermediates[layer_index] + if hasattr(attn_int, 'values') and attn_int.values is not None: + return attn_int.values + + if verbose: + print("Warning: Could not find layer_hiddens in intermediates") + + return None + + +def _get_token_embeddings(net_model, inputs: torch.Tensor): + """Get token embeddings directly from embedding layer.""" + if hasattr(net_model, 'token_emb'): + if hasattr(net_model.token_emb, 'emb'): + return net_model.token_emb.emb(inputs) + else: + return net_model.token_emb(inputs) + elif hasattr(net_model, 'emb'): + return net_model.emb(inputs) + else: + raise ValueError("Could not find embedding layer in model") + +def _mean_pooling( + hidden: torch.Tensor, + inputs: torch.Tensor, + mask: torch.Tensor, + token_ids: Optional[Set[int]], + token_weights: Optional[Dict[int, float]], + verbose: bool = True +) -> torch.Tensor: + """ + Mean pooling with token ID filtering and weighted averaging. + """ + batch_size, seq_len, dim = hidden.shape + device = hidden.device + + if mask.ndim > 2: + mask = mask.squeeze() + + effective_mask = mask.clone() + + if token_ids is not None: + token_mask = torch.zeros_like(mask, dtype=torch.bool, device=device) + for tid in token_ids: + token_mask = token_mask | (inputs == tid) + effective_mask = effective_mask & token_mask + + if verbose and effective_mask.sum() == 0: + print(f"Warning: No tokens match filter, falling back to all valid tokens") + effective_mask = mask + + if token_weights is not None: + weights = torch.zeros_like(effective_mask, dtype=torch.float32, device=device) + + for token_id, weight in token_weights.items(): + id_mask = (inputs == token_id) & effective_mask + weights = weights.masked_fill(id_mask, float(weight)) + + weights = weights.masked_fill(effective_mask & (weights == 0), 1.0) + + weighted_hidden = hidden * weights.unsqueeze(-1) + sum_weighted = weighted_hidden.sum(dim=1) + sum_weights = weights.sum(dim=1, keepdim=True).clamp(min=1e-9) + return sum_weighted / sum_weights + else: + masked_hidden = hidden * effective_mask.unsqueeze(-1).float() + sum_hidden = masked_hidden.sum(dim=1) + count = effective_mask.sum(dim=1, keepdim=True).clamp(min=1e-9) + return sum_hidden / count + +def _concat_pooling( + hidden: torch.Tensor, + inputs: torch.Tensor, + mask: torch.Tensor, + token_ids: Optional[Set[int]], + token_weights: Optional[Dict[int, float]], + max_tokens: Optional[int], + verbose: bool = True +) -> torch.Tensor: + """ + Concat pooling with token ID filtering and weight-based priority selection. + """ + batch_size, seq_len, dim = hidden.shape + device = hidden.device + + if max_tokens is None: + max_tokens = 1 + + output_dim = max_tokens * dim + + all_token_embs = [] + + for i in range(batch_size): + seq_mask = mask[i] + seq_inputs = inputs[i] + + if token_ids is not None: + matching_mask = torch.zeros(seq_len, dtype=torch.bool, device=device) + for tid in token_ids: + matching_mask = matching_mask | ((seq_inputs == tid) & seq_mask) + valid_indices = matching_mask.nonzero(as_tuple=True)[0] + else: + valid_indices = seq_mask.nonzero(as_tuple=True)[0] + + if len(valid_indices) == 0: + emb = torch.zeros(dim, device=device) + emb = F.pad(emb, (0, output_dim - dim)) + all_token_embs.append(emb) + continue + + matching_embs = hidden[i, valid_indices, :] + + if token_weights is not None and len(valid_indices) > max_tokens: + weights_list = [] + for idx in valid_indices: + tok_id = seq_inputs[idx].item() + weights_list.append(token_weights.get(tok_id, 1.0)) + + sorted_pairs = sorted(zip(range(len(valid_indices)), weights_list), + key=lambda x: x[1], reverse=True) + top_indices = [valid_indices[p[0]] for p in sorted_pairs[:max_tokens]] + matching_embs = hidden[i, torch.tensor(top_indices, device=device), :] + elif len(valid_indices) > max_tokens: + matching_embs = matching_embs[:max_tokens] + + if len(valid_indices) < max_tokens: + padding_needed = max_tokens - len(valid_indices) + padding = torch.zeros(padding_needed, dim, device=device) + matching_embs = torch.cat([matching_embs, padding], dim=0) + + emb = matching_embs.reshape(-1) + all_token_embs.append(emb) + + return torch.stack(all_token_embs, dim=0) + +#================================================================================================================================= +# Non-Autoregressive Encoder Embeddings Retrieval Functions +#================================================================================================================================= + +def get_enc_embeddings( + model, + sequences: List[List[int]], + seq_len: Optional[int] = 3072, + seq_pad_idx: int = 385, + batch_size: int = 64, + save_every_num_batches: int = -1, + save_file_path: str = "saved_embeddings.npy", + device: Optional[torch.device] = None, + normalize: bool = False, + pooling: str = "auto", # "auto" | "mean" | "weighted_mean" + token_type_weights: Optional[Tuple[float, float, float]] = None, # (onset_w, duration_w, pitch_w) + use_bfloat16: bool = True, # enable bfloat16 autocast when possible + return_dtype: str = "float32", # "float32" or "float16" for returned embeddings + return_numpy: bool = False, + verbose: bool = True, + show_progress_bar: bool = True + ) -> Union[Tensor, np.ndarray]: + + """ + Compute embeddings for a list of token sequences using a PyTorch model with optional bfloat16/autocast, + pooling, normalization, and periodic saving. + + This function batches input token id sequences, pads/truncates them to a fixed length, runs the model + in evaluation mode under `torch.no_grad()` and optional mixed-precision autocast, and returns a single + tensor (or NumPy array) containing per-sequence embeddings. The model is expected to accept a LongTensor + of token ids `x` and a boolean mask `mask` and to return either: + - a 2-D tensor `(B, D)` of already-pooled embeddings, or + - a 3-D tensor `(B, L, D)` of per-token embeddings (which will be pooled according to `pooling`). + + Key behaviors: + - Sequences are padded with `seq_pad_idx` and masked so padding does not affect pooling. + - If `seq_len` is provided, sequences longer than `seq_len` are truncated; otherwise the batch max length is used. + - Mixed-precision autocast is used when `use_bfloat16` is True and supported by the device; the function + falls back to the default autocast or no autocast if unavailable. + - Supports three pooling modes for per-token embeddings: + - `"auto"` or `"mean"`: simple masked mean pooling across tokens. + - `"weighted_mean"`: weighted mean pooling by token type (onset/duration/pitch) inferred from token ids; + weights are provided via `token_type_weights` and padding tokens are ignored. + - Optionally L2-normalizes embeddings (in float32) when `normalize=True`. + - Returned embeddings can be cast to `float16` for storage/transfer via `return_dtype`. + - Embeddings are collected on CPU; intermediate results can be periodically saved to `save_file_path`. + - If `return_numpy=True`, a NumPy array is returned; otherwise a CPU `torch.Tensor` is returned. + + Args: + model (torch.nn.Module): + PyTorch model used to compute embeddings. The model will be moved to `device` (or its current + parameter device if `device` is None) and set to `eval()` for inference. The forward call must + accept `x` (LongTensor) and `mask` (BoolTensor) and return embeddings when called with + `return_embeddings=True`. + sequences (List[List[int]]): + Batch of token id sequences (each sequence is a list of ints). Can be empty; an empty result + with shape `(0, 0)` will be returned in that case. + seq_len (Optional[int], default=3072): + Target sequence length for truncation/padding. If None, the maximum sequence length in the + current batch is used. + seq_pad_idx (int, default=385): + Token id used for padding positions. + batch_size (int, default=64): + Number of sequences processed per forward pass. + save_every_num_batches (int, default=-1): + If > 0, the function will save accumulated embeddings to `save_file_path` every + `save_every_num_batches` batches. A non-positive value disables periodic saving. + save_file_path (str, default="saved_embeddings.npy"): + File path used by `np.save` when periodic saving is enabled. + device (Optional[torch.device], default=None): + Device to run the model and tensors on. If None, the device of the model parameters is used. + normalize (bool, default=False): + If True, L2-normalize each embedding vector (done in float32 for numerical stability). + pooling (str, default="auto"): + Pooling strategy applied when model returns per-token embeddings: + - "auto" or "mean": masked mean pooling. + - "weighted_mean": weighted mean pooling by token type using `token_type_weights`. + Any other value raises `ValueError`. + token_type_weights (Optional[Tuple[float, float, float]], default=None): + Per-token-type weights `(onset_w, duration_w, pitch_w)` used when `pooling="weighted_mean"`. + If None, defaults to `(1.0, 1.0, 1.0)`. Token type ranges are inferred as: + onset: token_id in [0, 127] + duration:token_id in [128, 255] + pitch: token_id in [256, 383] + use_bfloat16 (bool, default=True): + If True, attempts to use `torch.bfloat16` autocast for the device; falls back gracefully if not supported. + return_dtype (str, default="float32"): + Data type for returned embeddings: `"float32"` or `"float16"`. Internally embeddings are normalized + in float32; casting to float16 happens just before collecting results if requested. + return_numpy (bool, default=False): + If True, the final result is returned as a NumPy array; otherwise a CPU `torch.Tensor` is returned. + verbose (bool, default=True): + If True, prints progress and short diagnostic messages via `tqdm`. + show_progress_bar (bool, default=True) + If True, displays tqdm progress bar. + + Returns: + Union[torch.Tensor, numpy.ndarray]: + - If `return_numpy` is False: a CPU `torch.Tensor` of shape `(N, D)` and dtype `torch.float32` + or `torch.float16` depending on `return_dtype`. + - If `return_numpy` is True: a NumPy array of shape `(N, D)` and dtype `np.float32` or `np.float16`. + `N` is the total number of input sequences and `D` is the embedding dimensionality produced by the model. + + Raises: + AssertionError: + If `return_dtype` is not one of `"float32"` or `"float16"`. + RuntimeError: + If the model returns `None` for embeddings (indicates incorrect forward flags or model behavior). + ValueError: + If the model returns an embedding tensor with unexpected dimensionality or if `pooling` is unsupported. + + Notes: + - The function uses `pad_and_mask` to produce `x` (LongTensor) and `mask` (BoolTensor). Padding tokens + are ignored by pooling operations. + - When `pooling="weighted_mean"`, if `token_ids` are not available or the model returns a 2-D tensor, + the function falls back to masked mean pooling. + - Periodic saving concatenates all embeddings collected so far and writes them with `np.save`. Save + failures are caught and reported when `verbose=True` but do not abort processing. + - The function runs the model under `torch.no_grad()` and sets `model.eval()`; it will move the model + to `device` if provided. + - For reproducible numeric behavior across devices, ensure the model and device support the requested + autocast dtype (bfloat16) and that any randomness is controlled externally. + + Example: + >>> # simple usage + >>> embs = get_embeddings_bf16(model, sequences, seq_len=1024, batch_size=32, pooling="mean", + ... normalize=True, return_dtype="float32", return_numpy=False) + """ + + assert return_dtype in ("float32", "float16"), "return_dtype must be 'float32' or 'float16'" + + model_device = next(model.parameters()).device if device is None else device + model.to(model_device) + model.eval() + + all_embs: List[Tensor] = [] + total_batches = math.ceil(len(sequences) / batch_size) if batch_size > 0 else 0 + + if verbose: + tqdm.write( + f"[get_embeddings_bf16] sequences={len(sequences)}, batch_size={batch_size}, " + f"batches={total_batches}, device={model_device}, seq_len={seq_len}, pooling={pooling}" + ) + + # Prepare autocast context using torch.amp.autocast + autocast_ctx = None + if use_bfloat16: + try: + autocast_ctx = torch.amp.autocast(device_type=model_device.type, dtype=torch.bfloat16) + except Exception: + try: + autocast_ctx = torch.amp.autocast(device_type=model_device.type) + except Exception: + autocast_ctx = None + else: + try: + autocast_ctx = torch.amp.autocast(device_type=model_device.type) + except Exception: + autocast_ctx = None + + with torch.inference_mode(): + batch_iter = range(0, len(sequences), batch_size) + pbar = tqdm(batch_iter, disable=not show_progress_bar, total=total_batches, desc="Embedding batches") + for batch_idx, i in enumerate(pbar): + batch_seqs = sequences[i : i + batch_size] + x, mask = pad_and_mask(batch_seqs, pad_idx=seq_pad_idx, seq_len=seq_len, device=model_device, verbose=verbose) + # x: (B, L) LongTensor token ids, mask: (B, L) boolean + + # Run forward under autocast if available + if autocast_ctx is not None: + with autocast_ctx: + out = model(x, return_embeddings=True, mask=mask) + else: + out = model(x, return_embeddings=True, mask=mask) + + if out is None: + raise RuntimeError("model returned None for embeddings. Check forward flags.") + + # Handle shapes + if out.dim() == 2: + # already pooled: (B, D) + emb = out + elif out.dim() == 3: + # per-token embeddings: (B, L, D) + if pooling in ("mean", "auto"): + emb = masked_mean_pool(out, mask, dim=1, verbose=verbose) + elif pooling == "weighted_mean": + # Use token ids to compute per-token weights; fallback to mean if token ids missing + emb = masked_weighted_mean_pool(out, mask, token_ids=x, token_type_weights=token_type_weights, dim=1, verbose=verbose) + else: + raise ValueError(f"unsupported pooling: {pooling}") + else: + raise ValueError(f"unexpected embedding tensor shape: {out.shape}") + + # Ensure embeddings are float32 for stable normalization/indexing + if emb.dtype != torch.float32: + emb = emb.float() + + # L2 normalize in float32 + if normalize: + emb = F.normalize(emb, p=2, dim=-1) + + # Optionally cast to float16 for return/storage + if return_dtype == "float16": + emb = emb.half() + + all_embs.append(emb.cpu()) + + # Update progress bar postfix with shapes and dtype + if verbose: + pbar.set_postfix({"batch": batch_idx + 1, "emb_shape": f"{emb.shape}", "dtype": str(emb.dtype)}) + + # Save intermediate results periodically + if save_every_num_batches > 0: + # compute 0-based batch number + bnum = batch_idx + if (bnum + 1) % save_every_num_batches == 0: + try: + concatenated = torch.cat(all_embs, dim=0).numpy() + np.save(save_file_path, concatenated) + if verbose: + tqdm.write(f"[get_embeddings_bf16] saved {concatenated.shape[0]} embeddings to {save_file_path}") + except Exception as e: + # Do not crash the whole run for a save failure; report if verbose + if verbose: + tqdm.write(f"[get_embeddings_bf16] warning: failed to save embeddings: {e}") + + if len(all_embs) == 0: + # return empty tensor/array with shape (0, 0) + empty = torch.empty((0, 0), dtype=(torch.float16 if return_dtype == "float16" else torch.float32)) + if verbose: + tqdm.write("[get_embeddings_bf16] no embeddings were produced; returning empty tensor") + return empty.numpy() if return_numpy else empty + + result = torch.cat(all_embs, dim=0) # (N, D) on CPU + + if verbose: + tqdm.write(f"[get_embeddings_bf16] finished: total_embeddings={result.shape[0]}, dim={result.shape[1]}, dtype={result.dtype}") + + if return_numpy: + return result.numpy() + + return result + +################################################################################### + +def masked_mean_pool( + token_embeddings: Tensor, + mask: Tensor, + dim: int = 1, + eps: float = 1e-9, + verbose: bool = True, + ) -> Tensor: + + """ + Compute a masked mean pooling over a specified dimension. + + This function computes the mean of `token_embeddings` along `dim`, ignoring + positions where `mask` is False. The mask is cast to the same dtype as the + embeddings to allow safe multiplication. A small epsilon is used to avoid + division by zero for sequences that are entirely masked out. + + Args: + token_embeddings: Tensor of shape (B, L, D) or similar where `dim` indexes + the sequence length. Embeddings dtype can be float16/float32/bfloat16. + mask: Boolean tensor of shape broadcastable to the sequence dimension + (e.g., (B, L)). True indicates valid tokens; False indicates padding. + dim: Dimension along which to pool (default: 1, the sequence length). + eps: Small value to avoid division by zero when a row has zero valid tokens. + verbose: If True, prints a short summary about the pooling operation. + + Returns: + Tensor of pooled embeddings with the sequence dimension removed, typically + shape (B, D). The returned dtype matches `token_embeddings.dtype`. + """ + + mask_f = mask.to(token_embeddings.dtype) # (B, L) + summed = (token_embeddings * mask_f.unsqueeze(-1)).sum(dim=dim) # (B, D) + counts = mask_f.sum(dim=dim).clamp_min(eps).unsqueeze(-1) # (B, 1) + pooled = summed / counts # (B, D) + + if verbose: + # Use tqdm.write so it doesn't interfere with progress bars + valid_counts = counts.squeeze(-1) + tqdm.write( + f"[masked_mean_pool] pooled shape={pooled.shape}, " + f"counts min={valid_counts.min().item():.3f}, max={valid_counts.max().item():.3f}" + ) + + return pooled + +################################################################################### + +def masked_weighted_mean_pool( + token_embs: Tensor, + valid_mask: Tensor, + token_ids: Optional[Tensor] = None, + token_type_weights: Optional[Tuple[float, float, float]] = None, + dim: int = 1, + verbose: bool = False, + ) -> Tensor: + + """ + Weighted mean pooling across tokens. If token_ids is provided, token types are + inferred using the same ranges as the reference code: + - onset: token_id in [0, 127] + - duration:token_id in [128, 255] + - pitch: token_id in [256, 383] + token_type_weights: (onset_w, duration_w, pitch_w). If None, defaults to (1.0,1.0,1.0) + The function multiplies each token embedding by its scalar weight and divides + by the sum of weights for valid tokens per sequence. + """ + + B, L, D = token_embs.shape + device = token_embs.device + dtype = token_embs.dtype + + if token_ids is None: + # No token-level ids available: fallback to simple masked mean + if verbose: + tqdm.write("[masked_weighted_mean_pool] token_ids is None, falling back to masked_mean_pool") + return masked_mean_pool(token_embs, valid_mask, dim=dim, verbose=verbose) + + # Default weights + if token_type_weights is None: + onset_w, duration_w, pitch_w = 1.0, 1.0, 1.0 + else: + onset_w, duration_w, pitch_w = token_type_weights + + # Build per-type boolean masks based on token id values (same ranges as reference) + onset_mask = (token_ids >= 0) & (token_ids < 128) + duration_mask = (token_ids >= 128) & (token_ids < 256) + pitch_mask = (token_ids >= 256) & (token_ids < 384) + + # Combine with valid_mask to ignore padding positions + onset_mask = onset_mask & valid_mask + duration_mask = duration_mask & valid_mask + pitch_mask = pitch_mask & valid_mask + + # Build per-token scalar weight tensor (B, L) + w = torch.ones((B, L), device=device, dtype=dtype) + if onset_w != 1.0: + w = torch.where(onset_mask, torch.tensor(onset_w, device=device, dtype=dtype), w) + if duration_w != 1.0: + w = torch.where(duration_mask, torch.tensor(duration_w, device=device, dtype=dtype), w) + if pitch_w != 1.0: + w = torch.where(pitch_mask, torch.tensor(pitch_w, device=device, dtype=dtype), w) + + # Zero out weights for padding positions + valid_mask_f = valid_mask.to(dtype) # (B, L) + w = w * valid_mask_f # (B, L) + + # Weighted sum and normalization + denom = w.sum(dim=1, keepdim=True).clamp(min=1e-6) # (B, 1) + w_exp = w.unsqueeze(-1) # (B, L, 1) + summed = (token_embs * w_exp).sum(dim=dim) # (B, D) + pooled = summed / denom # (B, D) + + return pooled + +################################################################################### + +def pad_and_mask( + sequences: List[List[int]], + pad_idx: int = 385, + seq_len: Optional[int] = None, + device: Optional[torch.device] = None, + verbose: bool = False, + ) -> Tuple[Tensor, Tensor]: + + """ + Pad and create a boolean mask for a batch of integer token sequences. + + This utility converts a list of variable-length integer sequences into a + padded LongTensor and a corresponding boolean mask indicating valid token + positions. Sequences longer than `seq_len` are truncated. If `seq_len` is + None, the function uses the maximum sequence length in the batch. + + Args: + sequences: List of token id sequences (each a list of ints). + pad_idx: Integer token id used for padding positions (default: 385). + seq_len: Optional target sequence length. If provided, sequences are + truncated or padded to this length. If None, the maximum length in + `sequences` is used. + device: Optional torch.device where the returned tensors will be placed. + If None, tensors are created on the default device. + verbose: If True, shows a small progress bar while processing sequences + and prints a summary. + + Returns: + A tuple (x, mask): + - x: LongTensor of shape (B, T) containing padded token ids. + - mask: BoolTensor of shape (B, T) where True indicates a real token. + """ + + # Fast path for empty batch + if not sequences: + empty = torch.empty((0, 0), dtype=torch.long, device=device) + empty_mask = torch.empty((0, 0), dtype=torch.bool, device=device) + return empty, empty_mask + + # Compute lengths and the batch maximum length + lengths = [len(s) for s in sequences] + batch_max = max(lengths) + + # If seq_len is given, only use it to cap lengths; but if the batch max is smaller, + # use the smaller value to avoid extra allocation/work. + if seq_len is None: + target_len = batch_max + else: + target_len = min(seq_len, batch_max) + + b = len(sequences) + if target_len == 0: + x = torch.full((b, 0), pad_idx, dtype=torch.long, device=device) + mask = torch.zeros((b, 0), dtype=torch.bool, device=device) + return x, mask + + x = torch.full((b, target_len), pad_idx, dtype=torch.long, device=device) + mask = torch.zeros((b, target_len), dtype=torch.bool, device=device) + + # iterate with optional progress display + iterator = enumerate(sequences) + if verbose: + iterator = enumerate(tqdm(sequences, disable=not verbose, desc="Pad & mask")) + + for i, seq in iterator: + if not seq: + continue + # Only truncate if seq is longer than the chosen target_len + L = len(seq) + if L > target_len: + L = target_len + # slice once to avoid creating a larger tensor then slicing + seq_slice = seq[:L] + seq_tensor = torch.tensor(seq_slice, dtype=torch.long, device=device) + else: + seq_tensor = torch.tensor(seq, dtype=torch.long, device=device) + + x[i, :L] = seq_tensor[:L] + mask[i, :L] = True + + if verbose: + tqdm.write( + f"[pad_and_mask] batch_size={b}, target_len={target_len}, " + f"min_len={min(lengths)}, max_len={max(lengths)}" + ) + + return x, mask + +#================================================================================================================================= +# Embeddings similarity comparison functions +#================================================================================================================================= + +import torch +import torch.nn.functional as F +from tqdm import tqdm +from typing import Optional, Union, Tuple + +def topk_cosine_neighbors(embeddings: torch.Tensor, + k: int = 10, + key_embeddings: Optional[torch.Tensor] = None, + row_batch: Optional[int] = None, + col_batch: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + normalize: bool = True, + dtype: Optional[torch.dtype] = None, + show_progress: bool = True) -> Tuple[torch.Tensor, torch.Tensor]: + + """ + For each query embedding, find the indices and similarities of its top-k neighbors + from a set of key embeddings, sorted by descending similarity. + + Supports both self-similarity (single array, excludes self) and pairwise + retrieval (two arrays, no exclusion). + + Optimized for maximum speed and memory efficiency across CPU, CUDA, and MPS. + Uses a streaming batched approach to handle datasets larger than GPU memory. + + Args: + embeddings (torch.Tensor): Query embeddings, shape (N_q, D). + k (int): How many neighbors to return. + key_embeddings (torch.Tensor, optional): Database/Key embeddings, shape (N_k, D). + If None, defaults to 'embeddings' (self-search). + row_batch (int, optional): Number of query rows to process at once. Auto-tuned if None. + col_batch (int, optional): Number of key columns to process at once. Auto-tuned if None. + device (str or torch.device, optional): Target device. If None, uses embeddings.device. + normalize (bool): If True, L2-normalize embeddings. Skip if already normalized. + dtype (torch.dtype, optional): Compute dtype (e.g., torch.float16, torch.bfloat16). + If None, uses embeddings.dtype. + show_progress (bool): Show tqdm progress bar. + + Returns: + top_idx (torch.Tensor): shape (N_q, k), int32 indices of nearest neighbors (indices into key_embeddings). + top_sim (torch.Tensor): shape (N_q, k), float32 cosine similarities. + """ + + # 1. Determine Search Mode (Self vs. Pairwise) + is_self_search = (key_embeddings is None) + if is_self_search: + key_embeddings = embeddings + + # 2. Device & Dtype Setup + if device is None: + device = embeddings.device + else: + device = torch.device(device) + + # Determine compute dtype + if dtype is None: + dtype = embeddings.dtype + else: + assert dtype.is_floating_point, "dtype must be a floating point type" + + # Move and cast embeddings + # Ensure contiguous for efficient matmul + query_embeddings = embeddings.to(device=device, dtype=dtype).contiguous() + key_embeddings = key_embeddings.to(device=device, dtype=dtype).contiguous() + + N_q, D = query_embeddings.shape + N_k, D_k = key_embeddings.shape + + if D != D_k: + raise ValueError(f"Query and Key embeddings must have same dimension. Got {D} and {D_k}") + + # Validation + if k < 1: + raise ValueError(f"k must be >= 1; got {k}") + + if is_self_search: + if k >= N_q: + raise ValueError(f"For self-search, k must be < N (to exclude self). Got N={N_q}, k={k}") + else: + if k > N_k: + raise ValueError(f"For pairwise search, k must be <= N_k. Got N_k={N_k}, k={k}") + + # 3. Auto-tune batch sizes based on device and memory + # Heuristics adjusted for potentially different N_q and N_k + if row_batch is None: + if device.type == 'cuda': + row_batch = 16384 + elif device.type == 'mps': + row_batch = 8192 + else: + row_batch = 4096 # CPU + + if col_batch is None: + if device.type == 'cuda': + col_batch = 16384 + elif device.type == 'mps': + col_batch = 8192 + else: + col_batch = 4096 # CPU + + # Clamp batch sizes to actual dimensions + row_batch = min(row_batch, N_q) + col_batch = min(col_batch, N_k) + + # 4. Optional Normalization + if normalize: + # Normalize in-place if possible, or reassign + query_embeddings = F.normalize(query_embeddings, p=2, dim=1) + # Only normalize keys if they are distinct from queries to avoid redundant work + # in self-search case (already normalized above) + if not is_self_search: + key_embeddings = F.normalize(key_embeddings, p=2, dim=1) + + # 5. Initialize Result Tensors (always float32 for precision in output) + top_sim = torch.empty((N_q, k), dtype=torch.float32, device=device) + top_idx = torch.empty((N_q, k), dtype=torch.int, device=device) + + # Pre-allocate reusable buffers for inner loop (memory efficiency) + # Buffers for top-k merge (size 2k) + merge_sim_buffer = torch.empty((row_batch, 2 * k), dtype=dtype, device=device) + merge_idx_buffer = torch.empty((row_batch, 2 * k), dtype=torch.int, device=device) + + # Buffer for column batch similarities + sim_buffer = torch.empty((row_batch, col_batch), dtype=dtype, device=device) + + # Value for masking (minimum possible float for the dtype) + min_val = -torch.finfo(dtype).max + + # 6. Inference Context + with torch.no_grad(): + iterator = range(0, N_q, row_batch) + if show_progress: + desc = "Query Batches" if not is_self_search else "Row Batches" + iterator = tqdm(iterator, desc=desc, leave=True) + + for i in iterator: + i_end = min(i + row_batch, N_q) + rb = i_end - i + + rows = query_embeddings[i:i_end] # (rb, D) + + # Initialize current batch top-k + # Use a tensor that persists across column batches for the current row batch + curr_sim = torch.full((rb, k), min_val, dtype=dtype, device=device) + curr_idx = torch.full((rb, k), -1, dtype=torch.int, device=device) + + for j in range(0, N_k, col_batch): + j_end = min(j + col_batch, N_k) + cb = j_end - j + + cols = key_embeddings[j:j_end] # (cb, D) + + # Compute similarities in-place into buffer + # sim_block shape: (rb, cb) + sim_block = sim_buffer[:rb, :cb] + torch.matmul(rows, cols.T, out=sim_block) + + # Mask self-similarity ONLY if self-search + if is_self_search: + offset = i - j + r_start = max(0, -offset) + r_end = min(rb, cb - offset) + + if r_start < r_end: + # Vectorized masking of the diagonal + r_range = torch.arange(r_start, r_end, dtype=torch.long, device=device) + c_range = r_range + offset + sim_block[r_range, c_range] = min_val + + # Top-k in block + if cb >= k: + blk_s, blk_p = torch.topk(sim_block, k, dim=1, largest=True, sorted=True) + blk_i = blk_p + j + else: + # Pad block to k if remaining keys are fewer than k + pad_size = k - cb + pad_vals = torch.full((rb, pad_size), min_val, dtype=dtype, device=device) + sims_padded = torch.cat([sim_block, pad_vals], dim=1) + blk_s, blk_p = torch.topk(sims_padded, k, dim=1, largest=True, sorted=True) + blk_i = blk_p + j + # Invalidate padded indices + blk_i[blk_s == min_val] = -1 + + # Merge with current best + # Layout: [curr_sim (k), blk_s (k)] -> topk(2k) -> keep k + merge_sim_buffer[:rb, :k] = curr_sim + merge_sim_buffer[:rb, k:2*k] = blk_s + merge_idx_buffer[:rb, :k] = curr_idx + merge_idx_buffer[:rb, k:2*k] = blk_i + + curr_sim, top_p = torch.topk(merge_sim_buffer[:rb, :2*k], k, dim=1, largest=True, sorted=True) + curr_idx = torch.gather(merge_idx_buffer[:rb, :2*k], dim=1, index=top_p) + + # Write results (convert to float32 for consistency) + top_sim[i:i_end] = curr_sim.to(torch.float32) + top_idx[i:i_end] = curr_idx + + # 7. Post-processing return format + if k == 1: + return top_idx.view(-1), top_sim.view(-1) + + return top_idx, top_sim + +#================================================================================================================================= +# Embeddings visualization functions +#================================================================================================================================= + +import numpy as np +import matplotlib.pyplot as plt +from sklearn.metrics import pairwise_distances + +def plot_emb_cosine_similarity(embeddings, + clip=2.0, + gamma=0.55, + cmap="inferno", + figsize=(20, 20), + dpi=300, + output_fname='embeddings_similarity_plot.png', + return_sims=False + ): + + """ + Produces a crisp, high-contrast cosine similarity heatmap. + - clip: percentile clipping (1–5 recommended) + - gamma: nonlinear contrast (0.4–0.8 recommended) + + ----------- + Use Example + ----------- + + tok_emb = model.net.token_emb.emb.weight.detach().cpu() + + plot_cosine_similarity(tok_emb) + """ + + # 1. Compute cosine similarity (not distance) + cos_dist = pairwise_distances(embeddings, metric="cosine") + cos_sim = 1 - cos_dist + + # 2. Gamma correction for contrast + sim = np.sign(cos_sim) * (np.abs(cos_sim) ** gamma) + + # 3. Percentile clipping to remove flat tails + vmin, vmax = np.percentile(sim, [clip, 100 - clip]) + + # 4. Plot + plt.figure(figsize=figsize, dpi=dpi) + plt.imshow(sim, cmap=cmap, vmin=vmin, vmax=vmax, interpolation="nearest") + plt.colorbar(fraction=0.046, pad=0.04) + plt.title("Embeddings Pairwise Cosine Similarity") + plt.xlabel("Embedding Index") + plt.ylabel("Embeddings Index") + plt.tight_layout() + plt.savefig(output_fname) + plt.show() + + if return_sims: + return sim + +#================================================================================================================================= +# Fine-tuning functions +#================================================================================================================================= + +def unfreeze_last_n_blocks_and_norms(model, + n_last=2, + verbose=True + ): + + """ + 2-3 unfrozen layers usually produce good results. Default is 2 + + Returns: configured model and optimizer + """ + + # freeze everything first + for p in model.parameters(): + p.requires_grad = False + + # unfreeze head + for p in model.net.to_logits.parameters(): + p.requires_grad = True + + # unfreeze last n blocks' params and any LayerNorms inside them that have params + layers = model.net.attn_layers.layers # ModuleList of blocks + last_blocks = list(layers)[-n_last:] + for block in last_blocks: + for name, p in block.named_parameters(): + p.requires_grad = True + + # unfreeze final norm if it has parameters + final_norm = getattr(model.net.attn_layers, "final_norm", None) + if final_norm is not None: + for p in final_norm.parameters(): + p.requires_grad = True + + # verify counts + trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) + total = sum(p.numel() for p in model.parameters()) + + if verbose: + print(f"Trainable params {trainable:,} / {total:,}") + + # collect ids for head params + head_params = list(model.net.to_logits.parameters()) + head_param_ids = {id(p) for p in head_params} + + # group trainable params into two buckets without tensor comparisons + pretrained_params = [] + head_only = [] + + for p in model.parameters(): + if not p.requires_grad: + continue + if id(p) in head_param_ids: + head_only.append(p) + else: + pretrained_params.append(p) + + # sanity checks + trainable = sum(p.numel() for p in pretrained_params) + sum(p.numel() for p in head_only) + total_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) + + assert trainable == total_trainable, "Mismatch in grouped trainable params" + + if verbose: + print(f"Pretrained params: {sum(p.numel() for p in pretrained_params):,}") + print(f"Head params: {sum(p.numel() for p in head_only):,}") + print(f"Total trainable: {total_trainable:,}") + + optim = torch.optim.Adam([ + {"params": pretrained_params, "lr": 1e-5}, + {"params": head_params, "lr": 5e-5} + ]) + + return model, optim + +def unfreeze_last_n_blocks_and_norms_full(model, + n_last_encoder=1, + n_last_decoder=2, + verbose=True + ): + + """ + Freeze entire XTransformer, then unfreeze: + - Last `n_last_encoder` encoder blocks (including all parameters in those blocks, e.g., LayerNorms) + - Last `n_last_decoder` decoder blocks (including all parameters in those blocks) + - Final encoder/decoder LayerNorms (if present and has params) + - Decoder's output head (`to_logits`) + + """ + + from x_transformer_2_3_1 import LayerNorm, RMSNorm, ScaleNorm, AdaptiveLayerNorm, AdaptiveRMSNorm + + # 1. Freeze everything + for p in model.parameters(): + p.requires_grad = False + + # 2. Unfreeze decoder head + for p in model.decoder.net.to_logits.parameters(): + p.requires_grad = True + + # 3. Helper to detect if a module is a parameterized LayerNorm-like module + def is_parametrized_norm(module): + # Custom norms from x-transformers + norm_types = (LayerNorm, RMSNorm, ScaleNorm, AdaptiveLayerNorm, AdaptiveRMSNorm) + if isinstance(module, norm_types): + return True + # Also include PyTorch built-in LayerNorm if used + if isinstance(module, torch.nn.LayerNorm): + return True + return False + + # 4. Helper to unfreeze last N blocks + norms inside them + final norm + def unfreeze_last_blocks(transformer_wrapper, n_last): + if n_last <= 0: + return + + # The actual AttentionLayers module + attn_layers = transformer_wrapper.attn_layers + layers = attn_layers.layers # ModuleList of blocks + last_blocks = list(layers)[-n_last:] + + for block in last_blocks: + # Unfreeze all parameters in the block (includes attention, FFN, and any embedded LayerNorms) + for p in block.parameters(): + p.requires_grad = True + + # Additionally, explicitly unfreeze any LayerNorm-like submodules with params (defensive) + for submodule in block.modules(): + if is_parametrized_norm(submodule): + for p in submodule.parameters(): + p.requires_grad = True + + # Unfreeze final norm (if exists and has params) + final_norm = getattr(attn_layers, 'final_norm', None) + if final_norm is not None and list(final_norm.parameters()): + for p in final_norm.parameters(): + p.requires_grad = True + + # 5. Apply to encoder and decoder + unfreeze_last_blocks(model.encoder, n_last_encoder) + unfreeze_last_blocks(model.decoder.net, n_last_decoder) # note: .net because of AutoregressiveWrapper + + # ====================== + # Parameter grouping (same as before) + # ====================== + head_params = list(model.decoder.net.to_logits.parameters()) + head_param_ids = {id(p) for p in head_params} + + pretrained_params = [] + head_only = [] + + for p in model.parameters(): + if not p.requires_grad: + continue + if id(p) in head_param_ids: + head_only.append(p) + else: + pretrained_params.append(p) + + # Sanity check + trainable = sum(p.numel() for p in pretrained_params) + sum(p.numel() for p in head_only) + total_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) + assert trainable == total_trainable, "Mismatch in grouped trainable params" + + if verbose: + print(f"Trainable params {trainable:,} / {total_trainable:,}") + print(f"Pretrained (enc/dec): {sum(p.numel() for p in pretrained_params):,}") + print(f"Head: {sum(p.numel() for p in head_only):,}") + print(f"Total trainable: {total_trainable:,}") + + # Optimizer + optim = torch.optim.Adam([ + {"params": pretrained_params, "lr": 1e-5}, + {"params": head_only, "lr": 5e-5} + ]) + + return model, optim + +#================================================================================================================================= +# Merging functions +#================================================================================================================================= + +def merge_encoder_and_decoder(model, + encoder_ckpt, + decoder_ckpt, + print_keys=False, + verbose=True + ): + + if verbose: + print('=' * 70) + print('Merging...') + print('=' * 70) + + if print_keys: + print('=' * 70) + print('Merged model keys:', model.state_dict().keys()) + print('=' * 70) + + if verbose: + print('=' * 70) + print('Loading encoder model...') + print('=' * 70) + + enc_ckpt = torch.load(decoder_ckpt, map_location='cpu') + enc_pre_sd = enc_ckpt.get('state_dict', enc_ckpt) + + if print_keys: + print('=' * 70) + print('Encoder model keys:', enc_pre_sd.keys()) + print('=' * 70) + + if verbose: + print('=' * 70) + print('Loading decoder model...') + print('=' * 70) + + dec_ckpt = torch.load(encoder_ckpt, map_location='cpu') + dec_pre_sd = dec_ckpt.get('state_dict', dec_ckpt) + + if print_keys: + print('=' * 70) + print('Decoder model keys', dec_pre_sd.keys()) + print('=' * 70) + + if verbose: + print('=' * 70) + print('Prepping merged model...') + print('=' * 70) + + model_new_sd = model.state_dict() + + for old_key, tensor in enc_pre_sd.items(): + + new_key = 'encoder.' + old_key + if new_key in model_new_sd: + model_new_sd[new_key] = tensor + + for old_key, tensor in dec_pre_sd.items(): + + new_key = old_key.replace('net.', 'decoder.net.') + if new_key in model_new_sd: + model_new_sd[new_key] = tensor + + if verbose: + print('=' * 70) + print('Final integrity check...') + print('=' * 70) + + # new_sd is your merged/updated state_dict + incompat = model.load_state_dict(model_new_sd, strict=False) + + if verbose: + # incompat is an IncompatibleKeys(namedtuple) + print("Missing keys: ", incompat.missing_keys) + print("Unexpected keys: ", incompat.unexpected_keys) + + try: + if verbose: + print('=' * 70) + print('Loading merged model...') + + model.load_state_dict(model_new_sd, strict=True) + + if verbose: + print('Done!') + print('=' * 70) + + return model + + except: + if verbose: + print('Failed to create merged model!') + print('=' * 70) + + return incompat + +#================================================================================================================================= +# Boundary Classifier functions +#================================================================================================================================= + +import os +import time +from collections import deque +from math import ceil +import numpy as np +import torch +from torch.utils.data import Dataset +from torch import nn +from typing import List, Tuple, Optional, Callable, Sequence + +class BoundaryDataset(Dataset): + def __init__(self, inputs_list, labels_list): + self.inputs = inputs_list + self.labels = labels_list + + def __len__(self): + return len(self.inputs) + + def __getitem__(self, idx): + return self.inputs[idx], self.labels[idx], None + +class BoundaryClassifier(nn.Module): + def __init__(self, num_tokens: int, max_seq_len: int, dim: int = 512, + depth: int = 12, heads: int = 16, num_labels: int = 2, + pad_token_id: int = 384, dropout: float = 0.1): + super().__init__() + self.pad_token_id = pad_token_id + + self.backbone = TransformerWrapper( + num_tokens=num_tokens, + max_seq_len=max_seq_len, + attn_layers=Encoder(dim=dim, depth=depth, heads=heads, + rotary_pos_emb=True, attn_flash=True) + ) + + self.classifier = nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(dim, dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(dim, num_labels) + ) + + def forward(self, input_ids, attn_mask=None): + hidden = self.backbone(input_ids, mask=attn_mask, return_embeddings=True) + logits = self.classifier(hidden) + return logits + +class FocalLoss(nn.Module): + def __init__(self, gamma=2.0, alpha=None, ignore_index=384): + super().__init__() + self.gamma = gamma + self.alpha = alpha # Tensor of shape [num_classes] + self.ignore_index = ignore_index + + def forward(self, logits, targets, mask=None): + B, N, C = logits.shape + logits_flat = logits.view(B * N, C) + targets_flat = targets.view(B * N) + + # 1. Create valid mask (ignore PAD and any explicit mask) + if mask is not None: + valid_mask = (targets_flat != self.ignore_index) & mask.view(-1) + else: + valid_mask = (targets_flat != self.ignore_index) + + if valid_mask.sum() == 0: + return torch.tensor(0.0, device=logits.device) + + # 2. Numerically stable log_softmax + log_probs = torch.log_softmax(logits_flat, dim=-1) + + # 3. CRITICAL FIX: Clamp targets to valid range [0, C-1] before indexing + # This prevents CUDA assert when targets contain PAD_IDX (e.g., 384) + targets_clamped = targets_flat.clamp(0, C - 1) + + # 4. Gather probabilities safely + p_t = log_probs[torch.arange(len(logits_flat), device=logits.device), targets_clamped] + p_t = torch.exp(p_t) # Convert back to probability for focal factor + + # 5. Calculate NLL Loss + nll_loss = -log_probs[torch.arange(len(logits_flat), device=logits.device), targets_clamped] + + # 6. Focal Factor + focal_factor = (1.0 - p_t) ** self.gamma + loss = focal_factor * nll_loss + + # 7. Apply Class Weights (Alpha) + if self.alpha is not None: + # Alpha must also be gathered using clamped indices + alpha_t = self.alpha[targets_clamped] + loss = alpha_t * loss + + # 8. Apply Mask (Zero out loss for PAD tokens) + loss = loss * valid_mask.float() + + return loss.sum() / valid_mask.sum().clamp(min=1.0) + +Logger = Optional[Callable[[str], None]] + +def filter_balanced_sequences( + sequences: List[List[int]], + token_types: List[List[int]], + tol: float = 0.1, + min_len: int = 1, + max_len: Optional[int] = None, + balance_target: float = 0.5, + return_indices: bool = False, + verbose: int = 2, + logger: Logger = None, + progress_chunk: int = 50000 + ) -> Tuple[List[List[int]], List[List[int]], Optional[List[int]]]: + + """ + Filter sequence/token-type pairs to those whose token-type distribution is near-balanced, + with verbosity and lightweight progress reporting. + + Parameters + ---------- + sequences : List[List[int]] + Token sequences (not used for balance computation). + token_types : List[List[int]] + Binary token-type lists (0/1) corresponding to sequences. + tol : float + Allowed absolute deviation from balance_target. + min_len : int + Minimum token_types length to consider. + max_len : Optional[int] + Maximum token_types length to consider. + balance_target : float + Target proportion of 1s (0..1). + return_indices : bool + If True, also return kept indices. + verbose : int + 0 = silent, 1 = concise, 2 = detailed chunk diagnostics. + logger : callable or None + If provided, called with status strings instead of/in addition to printing. + progress_chunk : int + Emit chunk updates every `progress_chunk` items when verbose >= 1. + + Returns + ------- + filtered_sequences, filtered_token_types, indices_or_none + """ + + def _log(msg: str): + if logger: + try: + logger(msg) + except Exception: + pass + if verbose >= 1: + print(msg) + + if len(sequences) != len(token_types): + raise ValueError("`sequences` and `token_types` must have the same length.") + + n = len(token_types) + if n == 0: + _log("Input empty: nothing to do.") + return [], [], ([] if return_indices else None) + + start_all = time.perf_counter() + _log(f"Starting filter: {n} pairs; tol={tol}; target={balance_target}; min_len={min_len}; max_len={max_len}") + + # Compute lengths and counts using Python builtins (fast for lists of ints) + # We iterate once and optionally emit chunk diagnostics to avoid storing huge intermediate lists. + lengths = np.empty(n, dtype=np.int32) + counts = np.empty(n, dtype=np.int32) + + t0 = time.perf_counter() + for i, tlist in enumerate(token_types): + lengths[i] = len(tlist) + # sum on list of ints is C-optimized and fast + counts[i] = sum(tlist) + # chunked progress logging to avoid I/O overhead + if verbose >= 2 and (i + 1) % progress_chunk == 0: + elapsed = time.perf_counter() - t0 + _log(f" scanned {i+1}/{n} token_types (elapsed {elapsed:.2f}s)") + + scan_time = time.perf_counter() - t0 + _log(f"Scanned counts and lengths in {scan_time:.2f}s") + + # Mask length constraints and nonzero lengths + nonzero_mask = lengths > 0 + mask = nonzero_mask.copy() + if min_len > 1: + mask &= (lengths >= min_len) + if max_len is not None: + mask &= (lengths <= max_len) + + candidates = int(mask.sum()) + _log(f"Candidates after length filtering: {candidates}/{n}") + + if candidates == 0: + _log("No candidates after length filtering. Exiting.") + return [], [], ([] if return_indices else None) + + # Compute proportions for candidates only + lengths_f = lengths.astype(np.float32) + proportions = np.empty_like(lengths_f) + # safe division only for masked entries + proportions[mask] = counts[mask].astype(np.float32) / lengths_f[mask] + proportions[~mask] = -1.0 # sentinel + + # Balanced criterion + balance_mask = np.abs(proportions - float(balance_target)) <= float(tol) + final_mask = mask & balance_mask + kept = int(final_mask.sum()) + elapsed_total = time.perf_counter() - start_all + _log(f"Kept {kept}/{n} sequences (elapsed total {elapsed_total:.2f}s)") + + if kept == 0: + _log("No sequences met the balance criterion. Exiting.") + return [], [], ([] if return_indices else None) + + # Get indices to keep + keep_idx = np.nonzero(final_mask)[0].tolist() + + # Build filtered lists (list comprehension over kept indices) + # This is the only place we materialize the filtered lists. + t_build = time.perf_counter() + filtered_sequences = [sequences[i] for i in keep_idx] + filtered_token_types = [token_types[i] for i in keep_idx] + build_time = time.perf_counter() - t_build + + _log(f"Built filtered lists: {kept} items (build time {build_time:.2f}s)") + + # Optional detailed stats + if verbose >= 1: + # compute some quick stats on proportions of kept items + kept_props = proportions[final_mask] + mean_prop = float(np.mean(kept_props)) + std_prop = float(np.std(kept_props)) + min_prop = float(np.min(kept_props)) + max_prop = float(np.max(kept_props)) + _log(f"Kept proportions stats: mean={mean_prop:.4f}, std={std_prop:.4f}, min={min_prop:.4f}, max={max_prop:.4f}") + + if return_indices: + return filtered_sequences, filtered_token_types, keep_idx + else: + return filtered_sequences, filtered_token_types, None + +def compute_class_counts_from_list(labels_list: Sequence[Sequence[int]], + num_labels: int = 2, + pad_idx: int = 384) -> torch.LongTensor: + if len(labels_list) == 0: + return torch.zeros(num_labels, dtype=torch.long) + + counts = [0] * num_labels + for lbl in range(num_labels): + counts[lbl] = sum(seq.count(lbl) for seq in labels_list) + + if 0 <= pad_idx < num_labels: + pad_total = sum(seq.count(pad_idx) for seq in labels_list) + counts[pad_idx] = max(0, counts[pad_idx] - pad_total) + + return torch.tensor(counts, dtype=torch.long) + +def compute_class_weights( + labels_list, + num_labels=2, + pad_idx=384, + smoothing=0.0, + power=1.0, + max_ratio=50.0 + ): + + """ + Stable, imbalance-preserving class weights. + - No renormalization that destroys imbalance + - Optional smoothing (default 0) + - Optional exponent scaling (power) + - Optional cap on extreme ratios + """ + + # Count tokens + counts = compute_class_counts_from_list(labels_list, num_labels, pad_idx).float() + counts = torch.clamp(counts, min=1.0) + + # Inverse frequency + inv = 1.0 / counts + + # Optional smoothing + if smoothing > 0: + inv = inv + smoothing + + # Optional exponent scaling + if power != 1.0: + inv = inv ** power + + # Normalize so smallest class = 1.0 + inv = inv / inv.min() + + # Cap extreme ratios + inv = torch.clamp(inv, max=max_ratio) + + return inv + +def collate_fn_from_lists(batch, pad_token_id=384): + input_seqs = [list(x[0]) for x in batch] + label_seqs = [list(x[1]) for x in batch] + + # Get actual lengths + lengths = [len(s) for s in input_seqs] + max_len = max(lengths) if lengths else 0 + B = len(input_seqs) + + input_ids = torch.full((B, max_len), pad_token_id, dtype=torch.long) + labels = torch.full((B, max_len), pad_token_id, dtype=torch.long) + attn_mask = torch.zeros((B, max_len), dtype=torch.bool) + + for i, (xseq, yseq, length) in enumerate(zip(input_seqs, label_seqs, lengths)): + if length > 0: + input_ids[i, :length] = torch.LongTensor(xseq) + labels[i, :length] = torch.LongTensor(yseq[:length]) + # IMPROVEMENT: Mask based on length, not token ID content + attn_mask[i, :length] = True + + return input_ids, labels, attn_mask + #================================================================================================================================= # This is the end of x_transformer_2_3_1 Python module #================================================================================================================================= \ No newline at end of file