Spaces:
Running
on
Zero
Running
on
Zero
Upload x_transformer_2_3_1.py
Browse files- x_transformer_2_3_1.py +657 -4
x_transformer_2_3_1.py
CHANGED
|
@@ -4,7 +4,7 @@
|
|
| 4 |
#
|
| 5 |
# Partial x-transformers code With useful modifications as a stand-alone Python module
|
| 6 |
#
|
| 7 |
-
# Version
|
| 8 |
#
|
| 9 |
# Original source code courtesy of lucidrains
|
| 10 |
# https://github.com/lucidrains/x-transformers
|
|
@@ -45,6 +45,7 @@ import torch
|
|
| 45 |
from torch.nn import Module
|
| 46 |
from torch import nn, einsum, Tensor
|
| 47 |
import torch.nn.functional as F
|
|
|
|
| 48 |
|
| 49 |
from collections import namedtuple
|
| 50 |
from functools import wraps
|
|
@@ -3982,7 +3983,7 @@ class AutoregressiveWrapper(Module):
|
|
| 3982 |
# whether to add router z-loss
|
| 3983 |
self.add_attn_z_loss = add_attn_z_loss
|
| 3984 |
|
| 3985 |
-
@torch.
|
| 3986 |
@eval_decorator
|
| 3987 |
def generate(
|
| 3988 |
self,
|
|
@@ -4147,7 +4148,7 @@ class AutoregressiveWrapper(Module):
|
|
| 4147 |
|
| 4148 |
return out
|
| 4149 |
|
| 4150 |
-
@torch.
|
| 4151 |
@eval_decorator
|
| 4152 |
def generate_masked(
|
| 4153 |
self,
|
|
@@ -4328,7 +4329,7 @@ class AutoregressiveWrapper(Module):
|
|
| 4328 |
|
| 4329 |
return out
|
| 4330 |
|
| 4331 |
-
@torch.
|
| 4332 |
@eval_decorator
|
| 4333 |
def generate_biased(
|
| 4334 |
self,
|
|
@@ -4556,6 +4557,226 @@ class AutoregressiveWrapper(Module):
|
|
| 4556 |
out, = unpack(out, ps, '* n')
|
| 4557 |
|
| 4558 |
return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4559 |
|
| 4560 |
def compute_accuracy(self, logits, labels):
|
| 4561 |
|
|
@@ -4613,6 +4834,438 @@ class AutoregressiveWrapper(Module):
|
|
| 4613 |
|
| 4614 |
return loss, acc, logits, cache
|
| 4615 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4616 |
#=================================================================================================================================
|
| 4617 |
# This is the end of x_transformer_2_3_1 Python module
|
| 4618 |
#=================================================================================================================================
|
|
|
|
| 4 |
#
|
| 5 |
# Partial x-transformers code With useful modifications as a stand-alone Python module
|
| 6 |
#
|
| 7 |
+
# Version 6.0
|
| 8 |
#
|
| 9 |
# Original source code courtesy of lucidrains
|
| 10 |
# https://github.com/lucidrains/x-transformers
|
|
|
|
| 45 |
from torch.nn import Module
|
| 46 |
from torch import nn, einsum, Tensor
|
| 47 |
import torch.nn.functional as F
|
| 48 |
+
from torch.utils.data import Dataset, DataLoader
|
| 49 |
|
| 50 |
from collections import namedtuple
|
| 51 |
from functools import wraps
|
|
|
|
| 3983 |
# whether to add router z-loss
|
| 3984 |
self.add_attn_z_loss = add_attn_z_loss
|
| 3985 |
|
| 3986 |
+
@torch.inference_mode()
|
| 3987 |
@eval_decorator
|
| 3988 |
def generate(
|
| 3989 |
self,
|
|
|
|
| 4148 |
|
| 4149 |
return out
|
| 4150 |
|
| 4151 |
+
@torch.inference_mode()
|
| 4152 |
@eval_decorator
|
| 4153 |
def generate_masked(
|
| 4154 |
self,
|
|
|
|
| 4329 |
|
| 4330 |
return out
|
| 4331 |
|
| 4332 |
+
@torch.inference_mode()
|
| 4333 |
@eval_decorator
|
| 4334 |
def generate_biased(
|
| 4335 |
self,
|
|
|
|
| 4557 |
out, = unpack(out, ps, '* n')
|
| 4558 |
|
| 4559 |
return out
|
| 4560 |
+
|
| 4561 |
+
@torch.inference_mode()
@eval_decorator
def generate_advanced(
    self,
    prompts,
    seq_len,
    eos_token = None,
    temperature = 1.,
    prompt_lens: Tensor | None = None,
    filter_logits_fn: str | Callable = top_k,
    restrict_to_max_seq_len = True,
    amateur_model: Module | Tuple[Module] | None = None,
    filter_kwargs: dict = dict(),
    contrastive_decode_kwargs: dict | Tuple[dict] = dict(
        beta = 0.5,
        alpha = 0.1
    ),
    cache_kv = True,
    return_prime=False,
    verbose=True,
    # --- new generation options ---
    logits_bias: dict | None = None,            # {token_id: bias_value} where bias_value is float or Tensor(batch,)
    masked_tokens: list | Tensor | None = None, # list of token ids to forbid
    # --- binary classifier mode ---
    binary_classifier: bool = False,            # if True, run classifier snippet and return preds, probs
    classifier_model: Module | None = None,     # model to use for binary classification
    batches: list | None = None,                # iterable of input batches for classifier_model
    threshold: float = 0.5,                     # threshold for converting probs to preds
    classifier_device: torch.device | None = None,
    # -----------------
    **kwargs
):
    """
    Autoregressively sample up to `seq_len` tokens after `prompts`, with optional
    per-token logit biases and forbidden tokens; or, in binary classifier mode,
    score `batches` with `classifier_model` and return early.

    Args:
        prompts: token ids; packed to (batch, n) with pattern '* n' before use.
        seq_len: maximum number of tokens to sample.
        eos_token: if given, sampling stops once every row contains it, and
            everything after the first occurrence in a row is set to `self.pad_value`.
        temperature: softmax temperature; exactly 0 selects argmax (greedy).
        prompt_lens: per-row prompt lengths; prompts are right-aligned with
            `align_right` and `seq_start_pos` is passed to the network.
        filter_logits_fn: callable, or a key of `FILTER_LOGITS_FN`.
        restrict_to_max_seq_len: feed only the last `self.max_seq_len` tokens
            to the network and trim cached keys/values accordingly.
        amateur_model: model(s) for contrastive decoding; when given, the
            logits filter is replaced by `identity`.
        filter_kwargs: keyword arguments forwarded to `filter_logits_fn`.
        contrastive_decode_kwargs: one kwargs dict per amateur model.
        cache_kv: reuse the network's returned intermediates as cache when
            `self.net.can_cache_kv` is true.
        return_prime: if true the prompt is kept in the returned tensor.
        verbose: print progress every 32 steps and on early stop.
        logits_bias: {token_id: bias}; bias is a number or a tensor of shape (batch,).
        masked_tokens: token ids whose logits are set to -1e9; ids outside the
            vocabulary are ignored.
        binary_classifier, classifier_model, batches, threshold, classifier_device:
            classifier mode inputs, see Returns.
        **kwargs: forwarded to `self.net` (and the amateur models).

    Returns:
        Generation mode: the sampled tokens (prompt included when `return_prime`),
        unpacked back to the input's leading shape.
        Classifier mode: `(all_preds, all_probs)`, two flat Python lists.
    """

    # If binary classifier mode requested, run the provided snippet and return early.
    if binary_classifier:
        assert classifier_model is not None, "classifier_model must be provided when binary_classifier=True"
        assert batches is not None, "batches (iterable of input tensors) must be provided when binary_classifier=True"

        device = classifier_device if classifier_device is not None else (prompts.device if exists(prompts) else torch.device('cpu'))

        all_probs = []
        all_preds = []

        classifier_model.eval()

        with torch.no_grad():
            for x in batches:
                x = x.to(device)

                # squeeze() turns a batch of one into a 0-dim tensor, whose tolist()
                # is a bare float that list.extend() rejects; keep at least 1 dim.
                logits = torch.atleast_1d(classifier_model(x).squeeze())  # [B]
                probs = torch.sigmoid(logits)                             # [B]
                preds = (probs >= threshold).long()

                all_probs.extend(probs.cpu().tolist())
                all_preds.extend(preds.cpu().tolist())

        return all_preds, all_probs

    # --- normal generation path below ---
    max_seq_len, greedy, device = self.max_seq_len, temperature == 0., prompts.device

    prompts, ps = pack([prompts], '* n')

    b, t = prompts.shape

    # handle filter logits fn given as string
    if isinstance(filter_logits_fn, str):
        assert filter_logits_fn in FILTER_LOGITS_FN, f"only {join(FILTER_LOGITS_FN.keys())} are available"
        filter_logits_fn = FILTER_LOGITS_FN[filter_logits_fn]

    # handle variable lengthed prompts (prefixes)
    seq_start_pos = None
    if exists(prompt_lens):
        prompts = align_right(prompts, prompt_lens, pad_id = self.pad_value)
        seq_start_pos = t - prompt_lens

    # output from which sampled tokens appended to
    out = prompts

    if verbose:
        print("Generating sequence of max length:", seq_len)

    # kv caches
    cache = None

    # if doing contrastive decoding, turn off filter automatically
    if exists(amateur_model):
        # list() so the item assignment below is valid even if cast_tuple
        # (defined elsewhere in the file) hands back an immutable tuple
        amateur_model = list(cast_tuple(amateur_model))
        contrastive_decode_kwargs = cast_tuple(contrastive_decode_kwargs)

        assert len(amateur_model) == len(contrastive_decode_kwargs)

        amateur_caches = [None] * len(amateur_model)
        filter_logits_fn = identity

        for i, module in enumerate(amateur_model):
            if isinstance(module, AutoregressiveWrapper):
                amateur_model[i] = module.net

            module.eval()

    # normalize inputs for new args (done once; they do not change between steps)
    bias_terms = []
    if exists(logits_bias):
        assert isinstance(logits_bias, dict), "logits_bias must be a dict {token_id: bias_value}"

        for tok_id, bias_val in logits_bias.items():
            # support scalar or per-batch tensor
            if isinstance(bias_val, torch.Tensor):
                if bias_val.dim() == 1 and bias_val.shape[0] == b:
                    bias_to_add = bias_val.to(device)
                else:
                    bias_to_add = bias_val.to(device).view(1).expand(b)
            else:
                bias_to_add = torch.tensor(float(bias_val), device=device).view(1).expand(b)

            bias_terms.append((int(tok_id), bias_to_add))

    masked_idx = None
    if exists(masked_tokens):
        if isinstance(masked_tokens, torch.Tensor):
            masked_tokens = masked_tokens.tolist()
        else:
            masked_tokens = list(masked_tokens)

        if len(masked_tokens) > 0:
            masked_idx = torch.tensor(masked_tokens, device=device, dtype=torch.long)

    NEG_INF = -1e9

    # sampling up to seq_len
    for sl in range(seq_len):

        # default to the full sequence so `x` is always bound, then crop if asked
        x = out

        if restrict_to_max_seq_len:
            max_len_exceeded = out.shape[-1] > max_seq_len

            assert not (cache_kv and max_len_exceeded and not self.net.can_cache_kv_outside_max_seq_len), 'the network cannot use cached key values when decoding outside the max sequence length. most likely because you are using absolute positional embedding. you can switch to rotary embeddings to resolve this issue'

            x = out[:, -max_seq_len:]

            if exists(cache):
                for inter in cache.attn_intermediates:
                    if inter.layer_type == 'a':
                        inter.cached_kv = [kv[..., -(max_seq_len - 1):, :] for kv in inter.cached_kv]

        logits, new_cache = self.net(
            x,
            return_intermediates = True,
            cache = cache,
            seq_start_pos = seq_start_pos,
            **kwargs
        )

        if cache_kv and self.net.can_cache_kv:
            cache = new_cache

        logits = logits[:, -1]  # shape: (batch, vocab)

        # handle contrastive decoding, Li et al.
        if exists(amateur_model):
            for i, (amateur, amateur_cache, amateur_contrastive_decode_kwargs) in enumerate(zip(amateur_model, amateur_caches, contrastive_decode_kwargs)):
                amateur_logits, next_amateur_cache = amateur(
                    x,
                    return_intermediates = True,
                    cache = amateur_cache,
                    seq_start_pos = seq_start_pos,
                    **kwargs
                )

                amateur_logits = amateur_logits[:, -1]

                assert amateur_logits.shape == logits.shape, 'logits dimension are not the same between amateur and expert model'
                logits = contrastive_decode_fn(logits, amateur_logits, **amateur_contrastive_decode_kwargs)

                if cache_kv and amateur.can_cache_kv:
                    amateur_caches[i] = next_amateur_cache

        # --- APPLY LOGITS BIAS AND MASKING HERE (before filtering / softmax) ---
        # per-token bias updates go straight into the logits column,
        # avoiding a full-vocab bias tensor
        for tok_id, bias_to_add in bias_terms:
            logits[:, tok_id] = logits[:, tok_id] + bias_to_add

        # masked_tokens: token ids to forbid (ids outside the vocab are dropped)
        if masked_idx is not None:
            idx = masked_idx[(masked_idx >= 0) & (masked_idx < logits.shape[-1])]
            if idx.numel() > 0:
                logits.index_fill_(dim=-1, index=idx, value=NEG_INF)
        # -------------------------------------------------------------------

        # filter by top_k, top_p (nucleus), top_a, or custom
        if greedy:
            sample = logits.argmax(dim = -1, keepdim = True)
        else:
            filtered_logits = filter_logits_fn(logits, **filter_kwargs)
            probs = F.softmax(filtered_logits / temperature, dim=-1)
            sample = torch.multinomial(probs, 1)

        # concat sample
        out = torch.cat((out, sample), dim=-1)

        if verbose:
            if sl % 32 == 0:
                print(sl, '/', seq_len)

        if not exists(eos_token):
            continue

        if (out == eos_token).any(dim = -1).all():
            if verbose:
                print('Model called the end of sequence at:', sl, '/', seq_len)
            break

    if exists(eos_token):
        # mask out everything after the eos tokens
        # (recomputed here so it is defined even when the loop never ran)
        is_eos_tokens = (out == eos_token)
        shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
        mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
        out = out.masked_fill(mask, self.pad_value)

    if not return_prime:
        # drop the prompt, keep only the newly sampled tokens
        out = out[:, t:]

    out, = unpack(out, ps, '* n')

    return out
|
| 4780 |
|
| 4781 |
def compute_accuracy(self, logits, labels):
|
| 4782 |
|
|
|
|
| 4834 |
|
| 4835 |
return loss, acc, logits, cache
|
| 4836 |
|
| 4837 |
+
@torch.inference_mode()
@eval_decorator
def generate_expert(
    self,
    prompts,
    seq_len,
    eos_token = None,
    temperature = 1.,
    prompt_lens: Tensor | None = None,
    filter_logits_fn: str | Callable = top_k,
    restrict_to_max_seq_len = True,
    amateur_model: Module | Tuple[Module] | None = None,
    filter_kwargs: dict = dict(),
    contrastive_decode_kwargs: dict | Tuple[dict] = dict(
        beta = 0.5,
        alpha = 0.1
    ),
    cache_kv = True,
    return_prime=False,
    verbose=True,
    # --- new controls ---
    token_type_ids: torch.LongTensor | None = None,   # [vocab]
    type_temperatures: dict | None = None,            # {type_id: temp}
    type_biases: dict | None = None,                  # {type_id: bias}
    repetition_window: int = 64,
    repetition_penalty_per_type: dict | None = None,  # {type_id: penalty_scale}
    rare_types: set | None = None,                    # e.g. {4, 5}
    rare_type_boost: float = 0.0,                     # small, e.g. 0.5
    entropy_threshold: float = 2.0,                   # when below, boost rare types
    # --- masked tokens option ---
    forbidden_token_ids: torch.LongTensor | torch.BoolTensor | None = None,
    forbidden_value: float = -1e9,
    **kwargs
):
    """
    Autoregressively sample up to `seq_len` tokens after `prompts`, shaping the
    logits per token *type* (bias, repetition penalty, low-entropy boost,
    temperature) and optionally forbidding tokens. No training is involved.

    Args:
        prompts: token ids; packed to (batch, n) with pattern '* n' before use.
        seq_len: maximum number of tokens to sample.
        eos_token: if given, sampling stops once every row contains it, and
            everything after the first occurrence in a row is set to `self.pad_value`.
        temperature: global softmax temperature; exactly 0 selects argmax.
        prompt_lens: per-row prompt lengths; prompts are right-aligned with
            `align_right` and `seq_start_pos` is passed to the network.
        filter_logits_fn: callable, or a key of `FILTER_LOGITS_FN`.
        restrict_to_max_seq_len: feed only the last `self.max_seq_len` tokens
            to the network and trim cached keys/values accordingly.
        amateur_model: model(s) for contrastive decoding; when given, the
            logits filter is replaced by `identity`.
        filter_kwargs: keyword arguments forwarded to `filter_logits_fn`.
        contrastive_decode_kwargs: one kwargs dict per amateur model.
        cache_kv: reuse the network's returned intermediates as cache when
            `self.net.can_cache_kv` is true.
        return_prime: if true the prompt is kept in the returned tensor.
        verbose: print progress every 32 steps and on early stop.
        token_type_ids: [vocab] tensor mapping token id -> type id. All type
            controls below are ignored when this is None.
        type_temperatures: {type_id: temp}; logits of that type are divided by temp.
        type_biases: {type_id: bias}; added to logits of that type.
        repetition_window: number of trailing tokens inspected for repetition.
        repetition_penalty_per_type: {type_id: penalty_scale}. With `freq` the
            share of the window occupied by that type, logits of the type are
            penalised by `1 + freq * (penalty_scale - 1)`.
        rare_types: type ids that receive `rare_type_boost` on rows whose
            softmax entropy is below `entropy_threshold`.
        forbidden_token_ids: integer tensor of ids, or a bool mask of shape
            [vocab] or [batch, vocab]; those logits are set to `forbidden_value`.
            Integer ids outside the vocabulary are ignored.
        **kwargs: forwarded to `self.net` (and the amateur models).

    Returns:
        The sampled tokens (prompt included when `return_prime`), unpacked back
        to the input's leading shape.

    Raises:
        TypeError: `forbidden_token_ids` is neither an int32/int64 nor a bool tensor.
        ValueError: a bool `forbidden_token_ids` is not 1D or 2D.
    """

    max_seq_len, greedy, device = self.max_seq_len, temperature == 0., prompts.device

    prompts, ps = pack([prompts], '* n')

    b, t = prompts.shape

    # handle filter logits fn given as string
    if isinstance(filter_logits_fn, str):
        assert filter_logits_fn in FILTER_LOGITS_FN, f"only {join(FILTER_LOGITS_FN.keys())} are available"
        filter_logits_fn = FILTER_LOGITS_FN[filter_logits_fn]

    # handle variable lengthed prompts (prefixes)
    seq_start_pos = None
    if exists(prompt_lens):
        prompts = align_right(prompts, prompt_lens, pad_id = self.pad_value)
        seq_start_pos = t - prompt_lens

    # output from which sampled tokens appended to
    out = prompts

    if verbose:
        print("Generating sequence of max length:", seq_len)

    # kv caches
    cache = None

    # if doing contrastive decoding, turn off filter automatically
    if exists(amateur_model):
        # list() so the item assignment below is valid even if cast_tuple
        # (defined elsewhere in the file) hands back an immutable tuple
        amateur_model = list(cast_tuple(amateur_model))
        contrastive_decode_kwargs = cast_tuple(contrastive_decode_kwargs)

        assert len(amateur_model) == len(contrastive_decode_kwargs)

        amateur_caches = [None] * len(amateur_model)
        filter_logits_fn = identity

        for i, module in enumerate(amateur_model):
            if isinstance(module, AutoregressiveWrapper):
                amateur_model[i] = module.net

            module.eval()

    # precompute some tensors for type controls (all stay disabled without token_type_ids)
    per_token_temp = None
    per_token_bias = None
    rare_type_mask = None
    rep_penalties = []  # (type_id, [vocab] bool mask of that type, penalty_scale)

    if token_type_ids is not None:
        token_type_ids = token_type_ids.to(device)

        # build per-token temperature and bias vectors if provided
        if type_temperatures is not None and len(type_temperatures) > 0:
            per_token_temp = torch.ones_like(token_type_ids, dtype=torch.float32)
            for type_id, temp_val in type_temperatures.items():
                per_token_temp[token_type_ids == type_id] = float(temp_val)

        if type_biases is not None and len(type_biases) > 0:
            per_token_bias = torch.zeros_like(token_type_ids, dtype=torch.float32)
            for type_id, bias_val in type_biases.items():
                per_token_bias[token_type_ids == type_id] = float(bias_val)

        # repetition penalty per type: the vocab masks never change, build them once
        for type_id, penalty_scale in (repetition_penalty_per_type or {}).items():
            rep_penalties.append((type_id, token_type_ids == type_id, penalty_scale))

        # rare type mask
        if rare_types is not None and len(rare_types) > 0:
            rare_type_mask = torch.zeros_like(token_type_ids, dtype=torch.bool)
            for rt in rare_types:
                rare_type_mask |= (token_type_ids == rt)

    # prepare forbidden tokens; integer ids are turned into a [b, vocab] mask
    # lazily, at the first step, when the vocab size is known from the logits
    forbidden_ids = None
    forbidden_mask_per_batch = None

    if forbidden_token_ids is not None:
        if forbidden_token_ids.dtype in (torch.int64, torch.int32):
            forbidden_ids = forbidden_token_ids.to(device).long().reshape(-1)
        elif forbidden_token_ids.dtype == torch.bool:
            # could be [vocab] or [b, vocab]
            if forbidden_token_ids.dim() == 1:
                forbidden_mask_per_batch = forbidden_token_ids.to(device).unsqueeze(0).expand(b, -1)
            elif forbidden_token_ids.dim() == 2:
                assert forbidden_token_ids.shape[0] == b, "forbidden_token_ids batch dimension must match prompts batch size"
                forbidden_mask_per_batch = forbidden_token_ids.to(device)
            else:
                raise ValueError("forbidden_token_ids boolean mask must be 1D [vocab] or 2D [b, vocab]")
        else:
            raise TypeError("forbidden_token_ids must be LongTensor of ids or BoolTensor mask")

    # sampling up to seq_len
    for sl in range(seq_len):

        # default to the full sequence so `x` is always bound, then crop if asked
        x = out

        if restrict_to_max_seq_len:
            max_len_exceeded = out.shape[-1] > max_seq_len

            assert not (cache_kv and max_len_exceeded and not self.net.can_cache_kv_outside_max_seq_len), \
                'the network cannot use cached key values when decoding outside the max sequence length. ' \
                'most likely because you are using absolute positional embedding. ' \
                'you can switch to rotary embeddings to resolve this issue'

            x = out[:, -max_seq_len:]

            if exists(cache):
                for inter in cache.attn_intermediates:
                    if inter.layer_type == 'a':
                        inter.cached_kv = [kv[..., -(max_seq_len - 1):, :] for kv in inter.cached_kv]

        logits, new_cache = self.net(
            x,
            return_intermediates = True,
            cache = cache,
            seq_start_pos = seq_start_pos,
            **kwargs
        )

        if cache_kv and self.net.can_cache_kv:
            cache = new_cache

        logits = logits[:, -1]  # [b, vocab]

        # handle contrastive decoding
        if exists(amateur_model):
            for i, (amateur, amateur_cache, amateur_contrastive_decode_kwargs) in enumerate(
                zip(amateur_model, amateur_caches, contrastive_decode_kwargs)
            ):
                amateur_logits, next_amateur_cache = amateur(
                    x,
                    return_intermediates = True,
                    cache = amateur_cache,
                    seq_start_pos = seq_start_pos,
                    **kwargs
                )

                amateur_logits = amateur_logits[:, -1]

                assert amateur_logits.shape == logits.shape, \
                    'logits dimension are not the same between amateur and expert model'
                logits = contrastive_decode_fn(logits, amateur_logits, **amateur_contrastive_decode_kwargs)

                if cache_kv and amateur.can_cache_kv:
                    amateur_caches[i] = next_amateur_cache

        # --------- STRUCTURED LOGIT SHAPING (no training) ---------

        if token_type_ids is not None:

            # 1) per-token bias (type-aware)
            if per_token_bias is not None:
                logits = logits + per_token_bias  # broadcast [vocab]

            # 2) repetition penalty per type (context-aware)
            if repetition_window > 0 and len(rep_penalties) > 0 and out.shape[-1] > 0:
                # types of the most recent tokens
                recent_types = token_type_ids[out[:, -repetition_window:]]  # [b, w]

                for type_id, type_mask, penalty_scale in rep_penalties:
                    # share of the window occupied by this type, per row (0..1)
                    freq = (recent_types == type_id).float().mean(dim = -1, keepdim = True)  # [b, 1]
                    divisor = (1.0 + freq * (penalty_scale - 1.0)).to(logits.dtype)

                    # A penalty must lower the score whatever its sign: shrink
                    # positive logits, push negative ones further down. Plain
                    # division would raise negative logits instead. Rows with
                    # freq == 0 get divisor == 1 and are left untouched.
                    penalized = torch.where(logits > 0, logits / divisor, logits * divisor)
                    logits = torch.where(type_mask, penalized, logits)

            # 3) entropy-based rare-type boost (gentle, context-aware)
            if rare_type_mask is not None and rare_type_boost > 0.0:
                # compute current probs & entropy (before global temperature)
                probs_raw = F.softmax(logits, dim=-1)               # [b, vocab]
                log_probs_raw = torch.log(probs_raw + 1e-9)
                entropy = -(probs_raw * log_probs_raw).sum(dim=-1)  # [b]

                # boost rare-type tokens only on rows that are low-entropy
                low_entropy = entropy < entropy_threshold
                boost_here = low_entropy.unsqueeze(-1) & rare_type_mask  # [b, vocab]
                logits = torch.where(boost_here, logits + rare_type_boost, logits)

            # 4) per-token temperature (type-aware), applied before global temperature
            if per_token_temp is not None:
                logits = logits / per_token_temp

        # --------- APPLY FORBIDDEN TOKEN MASK ---------
        if forbidden_ids is not None and forbidden_mask_per_batch is None:
            vocab_size = logits.shape[-1]
            # drop out-of-range ids instead of clamping them: clamping would
            # silently forbid token 0 or the last token of the vocabulary
            in_range = forbidden_ids[(forbidden_ids >= 0) & (forbidden_ids < vocab_size)]
            vocab_mask = torch.zeros(vocab_size, dtype=torch.bool, device=logits.device)
            vocab_mask[in_range] = True
            forbidden_mask_per_batch = vocab_mask.unsqueeze(0).expand(b, -1)  # [b, vocab]

        if forbidden_mask_per_batch is not None:
            # ensure shapes match
            assert forbidden_mask_per_batch.shape[0] == b and forbidden_mask_per_batch.shape[1] == logits.shape[-1], \
                "forbidden mask shape must be [b, vocab]"
            # set logits for forbidden tokens to a large negative value
            logits = logits.masked_fill(forbidden_mask_per_batch, float(forbidden_value))

        # ----------------------------------------------------------

        # filter by top_k, top_p (nucleus), top_a, or custom
        if greedy:
            sample = logits.argmax(dim = -1, keepdim = True)
        else:
            filtered_logits = filter_logits_fn(logits, **filter_kwargs)
            probs = F.softmax(filtered_logits / temperature, dim=-1)
            sample = torch.multinomial(probs, 1)

        # concat sample
        out = torch.cat((out, sample), dim=-1)

        if verbose:
            if sl % 32 == 0:
                print(sl, '/', seq_len)

        if not exists(eos_token):
            continue

        if (out == eos_token).any(dim = -1).all():
            if verbose:
                print('Model called the end of sequence at:', sl, '/', seq_len)
            break

    if exists(eos_token):
        # mask out everything after the eos tokens
        # (recomputed here so it is defined even when the loop never ran)
        is_eos_tokens = (out == eos_token)
        shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
        mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
        out = out.masked_fill(mask, self.pad_value)

    if not return_prime:
        # drop the prompt, keep only the newly sampled tokens
        out = out[:, t:]

    out, = unpack(out, ps, '* n')

    return out
|
| 5147 |
+
|
| 5148 |
+
#=========================================================================================
|
| 5149 |
+
|
| 5150 |
+
# Binary classifier functions
|
| 5151 |
+
|
| 5152 |
+
class ClsInferenceDataset(Dataset):
    """
    Inference-only dataset over token-id sequences (no labels).

    Each element of `data_pairs` is one sequence: a list of int token IDs or a
    1D tensor. Indexing returns that sequence as a `torch.long` tensor.
    """

    def __init__(self, data_pairs):
        # kept by reference; items are converted lazily in __getitem__
        self.data_pairs = data_pairs

    def __len__(self):
        return len(self.data_pairs)

    def __getitem__(self, idx):
        src_seq = self.data_pairs[idx]

        # torch.tensor() on an existing tensor emits a copy-construct warning,
        # so tensors are copied explicitly instead
        if isinstance(src_seq, torch.Tensor):
            return src_seq.detach().clone().to(dtype=torch.long)

        return torch.tensor(src_seq, dtype=torch.long)
|
| 5168 |
+
|
| 5169 |
+
def build_cls_model(num_tokens=18819,
                    max_seq_len=1024,
                    logits_dim=1,
                    use_cls_token=True,
                    squeeze_out_last_dim=True,
                    dim=1024,
                    depth=8,
                    heads=8,
                    device='cuda'
                    ):

    """
    Build a TransformerWrapper around an Encoder stack and move it to `device`.

    `dim`, `depth` and `heads` configure the Encoder; every other argument
    except `device` is forwarded unchanged to TransformerWrapper. With the
    defaults (`logits_dim=1`) the head emits one logit per input.
    """

    # the attention stack is built first, then handed to the wrapper
    encoder_stack = Encoder(dim=dim, depth=depth, heads=heads)

    classifier = TransformerWrapper(
        num_tokens=num_tokens,
        max_seq_len=max_seq_len,
        logits_dim=logits_dim,
        use_cls_token=use_cls_token,
        squeeze_out_last_dim=squeeze_out_last_dim,
        attn_layers=encoder_stack,
    )

    return classifier.to(device)
|
| 5197 |
+
|
| 5198 |
+
def load_cls_model(checkpoint_path, device='cuda'):

    """
    Rebuild the classifier with build_cls_model's default architecture, load
    the state dict stored at `checkpoint_path`, and return the model in eval
    mode on `device`.

    The checkpoint must match the default architecture; other sizes cannot be
    loaded through this helper because no architecture arguments are forwarded.
    """

    model = build_cls_model(device=device)

    # SECURITY NOTE(review): torch.load unpickles the file, which can execute
    # arbitrary code depending on the installed torch version's `weights_only`
    # default. Only load checkpoints from trusted sources; consider passing
    # weights_only=True once the minimum supported torch version allows it.
    state = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state)

    # build_cls_model already placed the model on `device`; only eval mode is left
    model.eval()

    return model
|
| 5210 |
+
|
| 5211 |
+
def cls_predict(model,
                seqs,
                batch_size=8,
                threshold=0.5,
                seq_len=1024,
                pad_token=18818,
                device='cuda'
                ):

    """
    Run `model` over `seqs` in batches and threshold the sigmoid of its output.

    Args:
        model: module called as `model(x)` with a [B, L] long tensor; moved to
            `device` and put in eval mode (both persist after the call).
        seqs: sequences of token IDs, wrapped in ClsInferenceDataset.
        batch_size: number of sequences per batch.
        threshold: probabilities >= threshold become prediction 1, else 0.
        seq_len: sequences are truncated to this many tokens.
        pad_token: value used to right-pad shorter sequences within a batch.
        device: device the model and batches are moved to.

    Returns:
        (all_preds, all_probs), in that order:
            - all_preds: int 0/1 predictions
            - all_probs: float probabilities
    """

    def collate_fn(batch):
        # batch: list of 1D tensors; truncate, then right-pad to the longest one
        clipped = [s[:seq_len].detach().clone() for s in batch]
        width = max(c.size(0) for c in clipped)
        padded = torch.full((len(clipped), width), pad_token, dtype=torch.long)
        for row, c in enumerate(clipped):
            padded[row, :c.size(0)] = c
        return padded

    ds = ClsInferenceDataset(seqs)
    loader = DataLoader(ds, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

    all_probs = []
    all_preds = []

    model.to(device)
    model.eval()

    with torch.inference_mode():
        for x in loader:

            x = x.to(device)  # [B, L] (truncated & padded)

            # squeeze() turns a batch of one into a 0-dim tensor; restore the
            # batch axis so every batch yields a list and can simply be extended
            logits = torch.atleast_1d(model(x).squeeze())  # [B]

            probs = torch.sigmoid(logits)  # [B]
            preds = (probs >= threshold).long()

            all_probs.extend(probs.cpu().tolist())
            all_preds.extend(preds.cpu().tolist())

    return all_preds, all_probs
|
| 5268 |
+
|
| 5269 |
#=================================================================================================================================
|
| 5270 |
# This is the end of x_transformer_2_3_1 Python module
|
| 5271 |
#=================================================================================================================================
|