projectlosangeles committed on
Commit
ec18889
·
verified ·
1 Parent(s): b4122db

Upload x_transformer_2_3_1.py

Browse files
Files changed (1) hide show
  1. x_transformer_2_3_1.py +657 -4
x_transformer_2_3_1.py CHANGED
@@ -4,7 +4,7 @@
4
  #
5
  # Partial x-transformers code With useful modifications as a stand-alone Python module
6
  #
7
- # Version 3.0
8
  #
9
  # Original source code courtesy of lucidrains
10
  # https://github.com/lucidrains/x-transformers
@@ -45,6 +45,7 @@ import torch
45
  from torch.nn import Module
46
  from torch import nn, einsum, Tensor
47
  import torch.nn.functional as F
 
48
 
49
  from collections import namedtuple
50
  from functools import wraps
@@ -3982,7 +3983,7 @@ class AutoregressiveWrapper(Module):
3982
  # whether to add router z-loss
3983
  self.add_attn_z_loss = add_attn_z_loss
3984
 
3985
- @torch.no_grad()
3986
  @eval_decorator
3987
  def generate(
3988
  self,
@@ -4147,7 +4148,7 @@ class AutoregressiveWrapper(Module):
4147
 
4148
  return out
4149
 
4150
- @torch.no_grad()
4151
  @eval_decorator
4152
  def generate_masked(
4153
  self,
@@ -4328,7 +4329,7 @@ class AutoregressiveWrapper(Module):
4328
 
4329
  return out
4330
 
4331
- @torch.no_grad()
4332
  @eval_decorator
4333
  def generate_biased(
4334
  self,
@@ -4556,6 +4557,226 @@ class AutoregressiveWrapper(Module):
4556
  out, = unpack(out, ps, '* n')
4557
 
4558
  return out
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4559
 
4560
  def compute_accuracy(self, logits, labels):
4561
 
@@ -4613,6 +4834,438 @@ class AutoregressiveWrapper(Module):
4613
 
4614
  return loss, acc, logits, cache
4615
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4616
  #=================================================================================================================================
4617
  # This is the end of x_transformer_2_3_1 Python module
4618
  #=================================================================================================================================
 
4
  #
5
  # Partial x-transformers code With useful modifications as a stand-alone Python module
6
  #
7
+ # Version 6.0
8
  #
9
  # Original source code courtesy of lucidrains
10
  # https://github.com/lucidrains/x-transformers
 
45
  from torch.nn import Module
46
  from torch import nn, einsum, Tensor
47
  import torch.nn.functional as F
48
+ from torch.utils.data import Dataset, DataLoader
49
 
50
  from collections import namedtuple
51
  from functools import wraps
 
3983
  # whether to add router z-loss
3984
  self.add_attn_z_loss = add_attn_z_loss
3985
 
3986
+ @torch.inference_mode()
3987
  @eval_decorator
3988
  def generate(
3989
  self,
 
4148
 
4149
  return out
4150
 
4151
+ @torch.inference_mode()
4152
  @eval_decorator
4153
  def generate_masked(
4154
  self,
 
4329
 
4330
  return out
4331
 
4332
+ @torch.inference_mode()
4333
  @eval_decorator
4334
  def generate_biased(
4335
  self,
 
4557
  out, = unpack(out, ps, '* n')
4558
 
4559
  return out
4560
+
4561
@torch.inference_mode()
@eval_decorator
def generate_advanced(
    self,
    prompts,
    seq_len,
    eos_token = None,
    temperature = 1.,
    prompt_lens: Tensor | None = None,
    filter_logits_fn: str | Callable = top_k,
    restrict_to_max_seq_len = True,
    amateur_model: Module | Tuple[Module] | None = None,
    filter_kwargs: dict = dict(),
    contrastive_decode_kwargs: dict | Tuple[dict] = dict(
        beta = 0.5,
        alpha = 0.1
    ),
    cache_kv = True,
    return_prime=False,
    verbose=True,
    # --- new generation options ---
    logits_bias: dict | None = None,             # {token_id: bias_value} where bias_value is float or Tensor(batch,)
    masked_tokens: list | Tensor | None = None,  # token ids to forbid during sampling
    # --- binary classifier mode ---
    binary_classifier: bool = False,             # if True, run classifier inference and return (preds, probs)
    classifier_model: Module | None = None,      # model used for binary classification
    batches: list | None = None,                 # iterable of input tensors for classifier_model
    threshold: float = 0.5,                      # probability threshold for 0/1 predictions
    classifier_device: torch.device | None = None,
    # -----------------
    **kwargs
):
    """
    Autoregressive token generation with optional per-token logit biasing,
    token masking, contrastive decoding, and a shortcut binary-classifier mode.

    When `binary_classifier=True`, the method ignores the generation path and
    instead runs `classifier_model` over `batches`, returning
    `(all_preds, all_probs)` as Python lists.

    Otherwise it samples up to `seq_len` new tokens after `prompts` and returns
    the generated tokens (full prompt + continuation when `return_prime=True`,
    continuation only otherwise). `temperature == 0.` selects greedy argmax
    decoding. Extra `**kwargs` are forwarded to `self.net`.
    """
    # If binary classifier mode requested, run classifier inference and return early.
    if binary_classifier:
        assert classifier_model is not None, "classifier_model must be provided when binary_classifier=True"
        assert batches is not None, "batches (iterable of input tensors) must be provided when binary_classifier=True"

        device = classifier_device if classifier_device is not None else (prompts.device if exists(prompts) else torch.device('cpu'))

        all_probs = []
        all_preds = []

        classifier_model.eval()
        with torch.no_grad():
            for x in batches:
                x = x.to(device)
                # FIX: reshape(-1) (not squeeze()) keeps the logits 1-D even for a
                # batch of one — squeeze() produced a 0-dim tensor whose .tolist()
                # is a float, which made list.extend() below raise TypeError
                logits = classifier_model(x).reshape(-1)  # [B]
                probs = torch.sigmoid(logits)             # [B]
                preds = (probs >= threshold).long()

                all_probs.extend(probs.cpu().tolist())
                all_preds.extend(preds.cpu().tolist())

        return all_preds, all_probs

    # --- normal generation path below ---
    max_seq_len, greedy, device = self.max_seq_len, temperature == 0., prompts.device

    prompts, ps = pack([prompts], '* n')

    b, t = prompts.shape

    # handle filter logits fn given as string
    if isinstance(filter_logits_fn, str):
        assert filter_logits_fn in FILTER_LOGITS_FN, f"only {join(FILTER_LOGITS_FN.keys())} are available"
        filter_logits_fn = FILTER_LOGITS_FN[filter_logits_fn]

    # handle variable lengthed prompts (prefixes)
    seq_start_pos = None
    if exists(prompt_lens):
        prompts = align_right(prompts, prompt_lens, pad_id = self.pad_value)
        seq_start_pos = t - prompt_lens

    # output to which sampled tokens are appended
    out = prompts

    if verbose:
        print("Generating sequence of max length:", seq_len)

    # kv caches
    cache = None

    # if doing contrastive decoding, turn off filter automatically
    if exists(amateur_model):
        amateur_model = cast_tuple(amateur_model)
        contrastive_decode_kwargs = cast_tuple(contrastive_decode_kwargs)

        assert len(amateur_model) == len(contrastive_decode_kwargs)

        amateur_caches = [None] * len(amateur_model)
        filter_logits_fn = identity

        for i, module in enumerate(amateur_model):
            if isinstance(module, AutoregressiveWrapper):
                amateur_model[i] = module.net

            module.eval()

    # normalize inputs for new args
    if exists(logits_bias):
        assert isinstance(logits_bias, dict), "logits_bias must be a dict {token_id: bias_value}"
    if exists(masked_tokens):
        if isinstance(masked_tokens, torch.Tensor):
            masked_tokens = masked_tokens.tolist()
        else:
            masked_tokens = list(masked_tokens)

    # sampling up to seq_len
    for sl in range(seq_len):

        if restrict_to_max_seq_len:
            max_len_exceeded = out.shape[-1] > max_seq_len

            assert not (cache_kv and max_len_exceeded and not self.net.can_cache_kv_outside_max_seq_len), 'the network cannot use cached key values when decoding outside the max sequence length. most likely because you are using absolute positional embedding. you can switch to rotary embeddings to resolve this issue'

            x = out[:, -max_seq_len:]

            if exists(cache):
                for inter in cache.attn_intermediates:
                    if inter.layer_type == 'a':
                        inter.cached_kv = [t[..., -(max_seq_len - 1):, :] for t in inter.cached_kv]
        else:
            # FIX: previously `x` was only assigned inside the branch above, so
            # restrict_to_max_seq_len=False raised UnboundLocalError on the first step
            x = out

        logits, new_cache = self.net(
            x,
            return_intermediates = True,
            cache = cache,
            seq_start_pos = seq_start_pos,
            **kwargs
        )

        if cache_kv and self.net.can_cache_kv:
            cache = new_cache

        logits = logits[:, -1]  # shape: (batch, vocab)

        # handle contrastive decoding, Li et al.
        if exists(amateur_model):
            for i, (amateur, amateur_cache, amateur_contrastive_decode_kwargs) in enumerate(zip(amateur_model, amateur_caches, contrastive_decode_kwargs)):
                amateur_logits, next_amateur_cache = amateur(
                    x,
                    return_intermediates = True,
                    cache = amateur_cache,
                    seq_start_pos = seq_start_pos,
                    **kwargs
                )

                amateur_logits = amateur_logits[:, -1]

                assert amateur_logits.shape == logits.shape, 'logits dimension are not the same between amateur and expert model'
                logits = contrastive_decode_fn(logits, amateur_logits, **amateur_contrastive_decode_kwargs)

                if cache_kv and amateur.can_cache_kv:
                    amateur_caches[i] = next_amateur_cache

        # --- APPLY LOGITS BIAS AND MASKING HERE (before filtering / softmax) ---
        # logits_bias: dict {token_id: bias_value} where bias_value is float or Tensor(batch,)
        if exists(logits_bias):
            # apply per-token bias updates directly to logits to avoid allocating full vocab bias tensor
            for tok_id, bias_val in logits_bias.items():
                # support scalar or per-batch tensor
                if isinstance(bias_val, torch.Tensor):
                    if bias_val.dim() == 1 and bias_val.shape[0] == b:
                        bias_to_add = bias_val.to(device)
                    else:
                        bias_to_add = bias_val.to(device).view(1).expand(b)
                else:
                    bias_to_add = torch.tensor(float(bias_val), device=device).view(1).expand(b)

                logits[:, int(tok_id)] = logits[:, int(tok_id)] + bias_to_add

        # masked_tokens: list of token ids to forbid
        if exists(masked_tokens) and len(masked_tokens) > 0:
            NEG_INF = -1e9
            idx = torch.tensor(masked_tokens, device=device, dtype=torch.long)
            # drop out-of-range ids so index_fill_ cannot fail
            idx = idx[(idx >= 0) & (idx < logits.shape[-1])]
            if idx.numel() > 0:
                logits.index_fill_(dim=-1, index=idx, value=NEG_INF)
        # -------------------------------------------------------------------

        # filter by top_k, top_p (nucleus), top_a, or custom
        if greedy:
            sample = logits.argmax(dim = -1, keepdim = True)
        else:
            filtered_logits = filter_logits_fn(logits, **filter_kwargs)
            probs = F.softmax(filtered_logits / temperature, dim=-1)
            sample = torch.multinomial(probs, 1)

        # concat sample
        out = torch.cat((out, sample), dim=-1)

        if verbose:
            if sl % 32 == 0:
                print(sl, '/', seq_len)

        if not exists(eos_token):
            continue

        is_eos_tokens = (out == eos_token)

        # stop once every sequence in the batch has produced an eos token
        if is_eos_tokens.any(dim = -1).all():
            if verbose:
                print('Model called the end of sequence at:', sl, '/', seq_len)
            break

    if exists(eos_token):
        # mask out everything after the eos tokens
        shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
        mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
        out = out.masked_fill(mask, self.pad_value)

    if return_prime:
        out = out[:, :]

    else:
        out = out[:, t:]

    out, = unpack(out, ps, '* n')

    return out
4780
 
4781
  def compute_accuracy(self, logits, labels):
4782
 
 
4834
 
4835
  return loss, acc, logits, cache
4836
 
4837
@torch.inference_mode()
@eval_decorator
def generate_expert(
    self,
    prompts,
    seq_len,
    eos_token = None,
    temperature = 1.,
    prompt_lens: Tensor | None = None,
    filter_logits_fn: str | Callable = top_k,
    restrict_to_max_seq_len = True,
    amateur_model: Module | Tuple[Module] | None = None,
    filter_kwargs: dict = dict(),
    contrastive_decode_kwargs: dict | Tuple[dict] = dict(
        beta = 0.5,
        alpha = 0.1
    ),
    cache_kv = True,
    return_prime=False,
    verbose=True,
    # --- new controls ---
    token_type_ids: torch.LongTensor | None = None,    # [vocab] — type id for every vocab entry
    type_temperatures: dict | None = None,             # {type_id: temp}
    type_biases: dict | None = None,                   # {type_id: bias}
    repetition_window: int = 64,                       # how many recent tokens to inspect
    repetition_penalty_per_type: dict | None = None,   # {type_id: penalty_scale}
    rare_types: set | None = None,                     # e.g. {4, 5}
    rare_type_boost: float = 0.0,                      # small additive boost, e.g. 0.5
    entropy_threshold: float = 2.0,                    # when entropy drops below, boost rare types
    # --- masked tokens option ---
    forbidden_token_ids: torch.LongTensor | torch.BoolTensor | None = None,
    forbidden_value: float = -1e9,
    **kwargs
):
    """
    Autoregressive generation with type-aware logit shaping.

    Extends plain sampling with per-token-type controls, all applied to the
    final-step logits before filtering/softmax: additive biases per type,
    a frequency-based per-type repetition penalty over the last
    `repetition_window` sampled tokens, an entropy-gated additive boost for
    `rare_types`, per-type temperatures, and a hard mask that sets
    `forbidden_token_ids` logits to `forbidden_value`.

    Returns the sampled tokens; the prompt prefix is kept when
    `return_prime=True`, otherwise only the continuation is returned.
    `temperature == 0.` selects greedy decoding. `**kwargs` are forwarded
    to `self.net`.
    """
    max_seq_len, greedy, device = self.max_seq_len, temperature == 0., prompts.device

    prompts, ps = pack([prompts], '* n')

    b, t = prompts.shape

    # handle filter logits fn given as string

    if isinstance(filter_logits_fn, str):
        assert filter_logits_fn in FILTER_LOGITS_FN, f"only {join(FILTER_LOGITS_FN.keys())} are available"
        filter_logits_fn = FILTER_LOGITS_FN[filter_logits_fn]

    # handle variable lengthed prompts (prefixes)

    seq_start_pos = None
    if exists(prompt_lens):
        prompts = align_right(prompts, prompt_lens, pad_id = self.pad_value)
        seq_start_pos = t - prompt_lens

    # output to which sampled tokens are appended

    out = prompts

    if verbose:
        print("Generating sequence of max length:", seq_len)

    # kv caches

    cache = None

    # if doing contrastive decoding, turn off filter automatically

    if exists(amateur_model):
        amateur_model = cast_tuple(amateur_model)
        contrastive_decode_kwargs = cast_tuple(contrastive_decode_kwargs)

        assert len(amateur_model) == len(contrastive_decode_kwargs)

        amateur_caches = [None] * len(amateur_model)
        filter_logits_fn = identity

        for i, module in enumerate(amateur_model):
            if isinstance(module, AutoregressiveWrapper):
                amateur_model[i] = module.net

            module.eval()

    # precompute some tensors for type controls

    if token_type_ids is not None:
        token_type_ids = token_type_ids.to(device)

        # build per-token temperature and bias vectors if provided
        per_token_temp = None
        if type_temperatures is not None and len(type_temperatures) > 0:
            per_token_temp = torch.ones_like(token_type_ids, dtype=torch.float32)
            for type_id, temp_val in type_temperatures.items():
                per_token_temp[token_type_ids == type_id] = float(temp_val)

        per_token_bias = None
        if type_biases is not None and len(type_biases) > 0:
            per_token_bias = torch.zeros_like(token_type_ids, dtype=torch.float32)
            for type_id, bias_val in type_biases.items():
                per_token_bias[token_type_ids == type_id] = float(bias_val)

        # repetition penalty per type
        per_type_rep_penalty = repetition_penalty_per_type or {}

        # rare type mask — True for vocab entries whose type is in rare_types
        rare_type_mask = None
        if rare_types is not None and len(rare_types) > 0:
            rare_type_mask = torch.zeros_like(token_type_ids, dtype=torch.bool)
            for rt in rare_types:
                rare_type_mask |= (token_type_ids == rt)
    else:
        # no type table: disable every type-aware control
        per_token_temp = None
        per_token_bias = None
        per_type_rep_penalty = {}
        rare_type_mask = None

    # prepare forbidden mask if provided
    # We'll lazily convert forbidden_token_ids into a boolean mask of shape [b, vocab]
    forbidden_mask_per_batch = None
    if forbidden_token_ids is not None:
        # If it's a LongTensor of ids (1D)
        if forbidden_token_ids.dtype in (torch.int64, torch.int32):
            # create a [vocab] bool mask from ids
            # NOTE(review): assumes self.net exposes .config.vocab_size — TODO confirm;
            # otherwise vocab size is inferred from token_type_ids
            vocab_size = self.net.config.vocab_size if hasattr(self.net, 'config') else None
            # If we can't infer vocab_size, we'll infer from token_type_ids if available
            if vocab_size is None and token_type_ids is not None:
                vocab_size = token_type_ids.shape[0]
            assert vocab_size is not None, "Cannot infer vocab size for forbidden_token_ids; provide a boolean mask instead."
            mask = torch.zeros(vocab_size, dtype=torch.bool, device=device)
            ids = forbidden_token_ids.to(device)
            # NOTE(review): clamp silently remaps out-of-range ids onto 0 / vocab_size-1,
            # which forbids those boundary tokens — confirm this is intended
            mask[ids.clamp(0, vocab_size-1)] = True
            forbidden_mask_per_batch = mask.unsqueeze(0).expand(b, -1)  # [b, vocab]
        elif forbidden_token_ids.dtype == torch.bool:
            # could be [vocab] or [b, vocab]
            if forbidden_token_ids.dim() == 1:
                forbidden_mask_per_batch = forbidden_token_ids.to(device).unsqueeze(0).expand(b, -1)
            elif forbidden_token_ids.dim() == 2:
                assert forbidden_token_ids.shape[0] == b, "forbidden_token_ids batch dimension must match prompts batch size"
                forbidden_mask_per_batch = forbidden_token_ids.to(device)
            else:
                raise ValueError("forbidden_token_ids boolean mask must be 1D [vocab] or 2D [b, vocab]")
        else:
            raise TypeError("forbidden_token_ids must be LongTensor of ids or BoolTensor mask")

    # sampling up to seq_len

    for sl in range(seq_len):

        # NOTE(review): `x` is only assigned inside this branch — calling with
        # restrict_to_max_seq_len=False would raise UnboundLocalError below; confirm
        if restrict_to_max_seq_len:
            max_len_exceeded = out.shape[-1] > max_seq_len

            assert not (cache_kv and max_len_exceeded and not self.net.can_cache_kv_outside_max_seq_len), \
                'the network cannot use cached key values when decoding outside the max sequence length. ' \
                'most likely because you are using absolute positional embedding. ' \
                'you can switch to rotary embeddings to resolve this issue'

            x = out[:, -max_seq_len:]

            if exists(cache):
                for inter in cache.attn_intermediates:
                    if inter.layer_type == 'a':
                        inter.cached_kv = [t[..., -(max_seq_len - 1):, :] for t in inter.cached_kv]

        logits, new_cache = self.net(
            x,
            return_intermediates = True,
            cache = cache,
            seq_start_pos = seq_start_pos,
            **kwargs
        )

        if cache_kv and self.net.can_cache_kv:
            cache = new_cache

        logits = logits[:, -1]  # [b, vocab]

        # handle contrastive decoding

        if exists(amateur_model):
            for i, (amateur, amateur_cache, amateur_contrastive_decode_kwargs) in enumerate(
                zip(amateur_model, amateur_caches, contrastive_decode_kwargs)
            ):
                amateur_logits, next_amateur_cache = amateur(
                    x,
                    return_intermediates = True,
                    cache = amateur_cache,
                    seq_start_pos = seq_start_pos,
                    **kwargs
                )

                amateur_logits = amateur_logits[:, -1]

                assert amateur_logits.shape == logits.shape, \
                    'logits dimension are not the same between amateur and expert model'
                logits = contrastive_decode_fn(logits, amateur_logits, **amateur_contrastive_decode_kwargs)

                if cache_kv and amateur.can_cache_kv:
                    amateur_caches[i] = next_amateur_cache

        # --------- STRUCTURED LOGIT SHAPING (no training) ---------

        if token_type_ids is not None:

            # 1) per-token bias (type-aware)
            if per_token_bias is not None:
                logits = logits + per_token_bias  # broadcast [vocab]

            # 2) repetition penalty per type (context-aware)
            if repetition_window > 0 and len(per_type_rep_penalty) > 0:
                # look at recent tokens
                recent = out[:, -repetition_window:].to(device)  # [b, w]
                # map to types
                recent_types = token_type_ids[recent]  # [b, w]

                # for each type, compute frequency and apply penalty
                # we do this per batch element
                for bi in range(b):
                    types_b = recent_types[bi]  # [w]
                    if types_b.numel() == 0:
                        continue
                    # count occurrences per type id present in penalties
                    for type_id, penalty_scale in per_type_rep_penalty.items():
                        # penalty_scale > 1.0 means stronger penalty
                        mask = (types_b == type_id)
                        if mask.any():
                            freq = mask.float().mean().item()  # 0..1
                            if freq > 0.0:
                                # build a penalty vector for this type
                                type_mask = (token_type_ids == type_id)  # [vocab]
                                # NOTE(review): this DIVIDES the raw logits rather than
                                # subtracting; for negative logits dividing by > 1 moves
                                # them toward 0 and so RAISES their probability — likely
                                # the opposite of a penalty. Confirm intent (cf. the
                                # sign-aware CTRL-style repetition penalty).
                                logits[bi, type_mask] /= (1.0 + freq * (penalty_scale - 1.0))

            # 3) entropy-based rare-type boost (gentle, context-aware)
            if rare_type_mask is not None and rare_type_boost > 0.0:
                # compute current probs & entropy (before global temperature)
                probs_raw = F.softmax(logits, dim=-1)  # [b, vocab]
                log_probs_raw = torch.log(probs_raw + 1e-9)
                entropy = -(probs_raw * log_probs_raw).sum(dim=-1)  # [b]

                # for low-entropy states, gently boost rare types
                low_entropy = entropy < entropy_threshold
                if low_entropy.any():
                    # boost only for those batch elements
                    boost_vec = torch.zeros_like(logits)
                    boost_vec[:, rare_type_mask] = rare_type_boost
                    logits = torch.where(
                        low_entropy.unsqueeze(-1),
                        logits + boost_vec,
                        logits
                    )

            # 4) per-token temperature (type-aware)
            # apply before global temperature
            if per_token_temp is not None:
                # divide logits by per-token temperature
                # (smaller temp -> sharper distribution for that type)
                logits = logits / per_token_temp

        # --------- APPLY FORBIDDEN TOKEN MASK ---------
        if forbidden_mask_per_batch is not None:
            # ensure shapes match
            assert forbidden_mask_per_batch.shape[0] == b and forbidden_mask_per_batch.shape[1] == logits.shape[-1], \
                "forbidden mask shape must be [b, vocab]"
            # set logits for forbidden tokens to a large negative value
            logits = logits.masked_fill(forbidden_mask_per_batch, float(forbidden_value))

        # ----------------------------------------------------------

        # filter by top_k, top_p (nucleus), top_a, or custom

        if greedy:
            sample = logits.argmax(dim = -1, keepdim = True)
        else:
            filtered_logits = filter_logits_fn(logits, **filter_kwargs)
            probs = F.softmax(filtered_logits / temperature, dim=-1)
            sample = torch.multinomial(probs, 1)

        # concat sample

        out = torch.cat((out, sample), dim=-1)

        if verbose:
            if sl % 32 == 0:
                print(sl, '/', seq_len)

        if not exists(eos_token):
            continue

        is_eos_tokens = (out == eos_token)

        # stop once every sequence in the batch has produced an eos token
        if is_eos_tokens.any(dim = -1).all():

            if verbose:
                print('Model called the end of sequence at:', sl, '/', seq_len)

            break

    if exists(eos_token):
        # mask out everything after the eos tokens
        shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
        mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
        out = out.masked_fill(mask, self.pad_value)

    if return_prime:
        out = out[:, :]
    else:
        out = out[:, t:]

    out, = unpack(out, ps, '* n')

    return out
5147
+
5148
+ #=========================================================================================
5149
+
5150
+ # Binary classifier functions
5151
+
5152
class ClsInferenceDataset(Dataset):
    """
    Minimal Dataset over token-id sequences for classifier inference.

    Each stored item is a sequence of integer token ids (a list or 1-D
    tensor); indexing converts it to a ``torch.long`` tensor. NOTE(review):
    despite the attribute name ``data_pairs``, items are plain sequences —
    no labels are stored or returned here.
    """

    def __init__(self, data_pairs):
        self.data_pairs = data_pairs

    def __len__(self):
        return len(self.data_pairs)

    def __getitem__(self, idx):
        return torch.tensor(self.data_pairs[idx], dtype=torch.long)
5168
+
5169
def build_cls_model(num_tokens=18819,
                    max_seq_len=1024,
                    logits_dim=1,
                    use_cls_token=True,
                    squeeze_out_last_dim=True,
                    dim=1024,
                    depth=8,
                    heads=8,
                    device='cuda'
                    ):
    """
    Construct the binary-classifier Transformer: an Encoder-based
    TransformerWrapper that emits a single logit per input sequence
    (logits_dim=1, CLS-token pooling, last dim squeezed out).

    Returns the model moved to `device`.
    """

    encoder = Encoder(dim=dim,
                      depth=depth,
                      heads=heads
                      )

    cls_model = TransformerWrapper(num_tokens=num_tokens,
                                   max_seq_len=max_seq_len,
                                   logits_dim=logits_dim,
                                   use_cls_token=use_cls_token,
                                   squeeze_out_last_dim=squeeze_out_last_dim,
                                   attn_layers=encoder
                                   )

    return cls_model.to(device)
5197
+
5198
def load_cls_model(checkpoint_path, device='cuda'):
    """
    Rebuild the classifier architecture, load trained weights from
    `checkpoint_path`, and return the model on `device` in eval mode.
    """

    cls_model = build_cls_model(device=device)

    # NOTE(review): torch.load unpickles arbitrary objects — only load trusted
    # checkpoints (consider weights_only=True on torch versions that support it)
    checkpoint = torch.load(checkpoint_path, map_location=device)
    cls_model.load_state_dict(checkpoint)

    cls_model.to(device).eval()

    return cls_model
5210
+
5211
def cls_predict(model,
                seqs,
                batch_size=8,
                threshold=0.5,
                seq_len=1024,
                pad_token=18818,
                device='cuda'
                ):
    """
    Run the binary classifier over token sequences.

    Sequences are truncated to `seq_len`, right-padded per batch with
    `pad_token`, and fed to `model`, whose output is treated as one logit
    per sequence.

    Returns two lists, in this order:
        - preds: int 0/1 predictions (prob >= threshold)
        - probs: float probabilities
    (FIX: the previous docstring listed probs first, but the function has
    always returned (preds, probs).)
    """

    def collate_fn(batch):
        # batch: list of 1-D LongTensors; truncate, then pad to the longest
        # (truncated) sequence in this batch
        tensors = [s[:seq_len].detach().clone() for s in batch]
        max_len = min(seq_len, max(t.size(0) for t in tensors))
        padded = torch.full((len(tensors), max_len), pad_token, dtype=torch.long)
        for i, t in enumerate(tensors):
            padded[i, :t.size(0)] = t
        return padded

    ds = ClsInferenceDataset(seqs)
    loader = DataLoader(ds, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

    all_probs = []
    all_preds = []

    model.to(device)
    model.eval()

    with torch.inference_mode():
        for x in loader:

            x = x.to(device)  # [B, L] (truncated & padded)

            # FIX: reshape(-1) instead of squeeze() — squeeze() collapsed a
            # batch of one to a 0-dim tensor whose .tolist() is a scalar,
            # which forced a type(preds)==list special case downstream.
            # reshape(-1) is always 1-D, so extend() alone suffices.
            logits = model(x).reshape(-1)        # [B]

            probs = torch.sigmoid(logits)        # [B]
            preds = (probs >= threshold).long()  # [B]

            all_probs.extend(probs.cpu().tolist())
            all_preds.extend(preds.cpu().tolist())

    return all_preds, all_probs
5268
+
5269
  #=================================================================================================================================
5270
  # This is the end of x_transformer_2_3_1 Python module
5271
  #=================================================================================================================================