kilianhaefeli committed on
Commit
389e1f4
·
1 Parent(s): e5351ca

initial attempt at EBM sampler

Browse files
Files changed (1) hide show
  1. modeling.py +36 -9
modeling.py CHANGED
@@ -723,6 +723,7 @@ class Fast_dLLM_QwenForCausalLM(Fast_dLLM_QwenPreTrainedModel, GenerationMixin):
723
  max_new_tokens=20, # Added default value for safety
724
  mask_id=151665,
725
  threshold=1,
 
726
  small_block_size=8,
727
  block_size=32,
728
  stop_token=151645,
@@ -737,6 +738,10 @@ class Fast_dLLM_QwenForCausalLM(Fast_dLLM_QwenPreTrainedModel, GenerationMixin):
737
  if use_block_cache:
738
  raise ValueError("use_block_cache=True is not supported in this generate() implementation.")
739
  assert attention_mask is not None, "attention_mask must be provided for this generate() implementation."
 
 
 
 
740
 
741
  # pad the initial input_ids and attention_mask to be multiple of block_size
742
  if False: # input_ids.shape[1] % block_size != 0:
@@ -862,15 +867,37 @@ class Fast_dLLM_QwenForCausalLM(Fast_dLLM_QwenPreTrainedModel, GenerationMixin):
862
  logits = logits[:, start:end]
863
 
864
  x_1, p_1t = self.sample_with_top_p(logits, top_p=top_p, temperature=temperature)
865
- # Select tokens with probability greater than threshold from p_1t
866
- x1_p = torch.squeeze(torch.gather(p_1t, dim=-1, index=torch.unsqueeze(x_1, -1)), -1)
867
- x1_p = torch.where(mask_idx[:, start:end], x1_p, -torch.inf)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
868
 
869
- unmask_idx = (x1_p > threshold)
870
- # Ensure at least one token is unmasked in the current small block
871
- max_prob_idx = x1_p.argmax(dim=-1)
872
- unmask_idx[torch.arange(x_1.shape[0]), max_prob_idx] = True
873
- unmask_idx = unmask_idx & mask_idx[:, start:end]
874
 
875
  # Add 1 to iterations if the sequence is not stopped AND at least one token is generated in this iteration
876
  # aka if not finished and unmask id has some True value
@@ -965,4 +992,4 @@ class Fast_dLLM_QwenForCausalLM(Fast_dLLM_QwenPreTrainedModel, GenerationMixin):
965
  samples = torch.multinomial(flat, num_samples=1).squeeze(-1)
966
  x_1 = samples.view(*p_1t.shape[:-1])
967
 
968
- return x_1, p_1t
 
723
  max_new_tokens=20, # Added default value for safety
724
  mask_id=151665,
725
  threshold=1,
726
+ sampler="confidence",
727
  small_block_size=8,
728
  block_size=32,
729
  stop_token=151645,
 
738
  if use_block_cache:
739
  raise ValueError("use_block_cache=True is not supported in this generate() implementation.")
740
  assert attention_mask is not None, "attention_mask must be provided for this generate() implementation."
741
+ if sampler not in {"confidence", "ebm"}:
742
+ raise ValueError(f"Unsupported sampler: {sampler}. Use 'confidence' or 'ebm'.")
743
+ if sampler == "ebm" and small_block_size != block_size:
744
+ raise ValueError("sampler='ebm' currently requires small_block_size == block_size.")
745
 
746
  # pad the initial input_ids and attention_mask to be multiple of block_size
747
  if False: # input_ids.shape[1] % block_size != 0:
 
867
  logits = logits[:, start:end]
868
 
869
  x_1, p_1t = self.sample_with_top_p(logits, top_p=top_p, temperature=temperature)
870
+ mask_slice = mask_idx[:, start:end]
871
+ if sampler == "ebm":
872
+ # EBM: unmask lowest-entropy positions with cumulative entropy <= threshold.
873
+ log_p = torch.log(p_1t.clamp_min(1e-12))
874
+ entropy = -(p_1t * log_p).sum(dim=-1)
875
+ entropy = entropy.masked_fill(~mask_slice, float("inf"))
876
+
877
+ sorted_entropy, sorted_idx = torch.sort(entropy, dim=-1, descending=False)
878
+ cumulative_entropy = torch.cumsum(sorted_entropy, dim=-1)
879
+ n_to_unmask = (cumulative_entropy <= threshold).sum(dim=-1)
880
+ n_to_unmask = torch.clamp(n_to_unmask, min=1)
881
+ max_n = mask_slice.sum(dim=-1)
882
+ n_to_unmask = torch.minimum(n_to_unmask, max_n)
883
+
884
+ positions = torch.arange(sorted_idx.size(1), device=sorted_idx.device)
885
+ positions = positions.unsqueeze(0).expand_as(sorted_idx)
886
+ select = positions < n_to_unmask.unsqueeze(1)
887
+
888
+ unmask_idx = torch.zeros_like(mask_slice)
889
+ unmask_idx.scatter_(dim=1, index=sorted_idx, src=select)
890
+ unmask_idx = unmask_idx & mask_slice
891
+ else:
892
+ # Select tokens with probability greater than threshold from p_1t
893
+ x1_p = torch.squeeze(torch.gather(p_1t, dim=-1, index=torch.unsqueeze(x_1, -1)), -1)
894
+ x1_p = torch.where(mask_slice, x1_p, -torch.inf)
895
 
896
+ unmask_idx = (x1_p > threshold)
897
+ # Ensure at least one token is unmasked in the current small block
898
+ max_prob_idx = x1_p.argmax(dim=-1)
899
+ unmask_idx[torch.arange(x_1.shape[0]), max_prob_idx] = True
900
+ unmask_idx = unmask_idx & mask_slice
901
 
902
  # Add 1 to iterations if the sequence is not stopped AND at least one token is generated in this iteration
903
  # aka if not finished and unmask id has some True value
 
992
  samples = torch.multinomial(flat, num_samples=1).squeeze(-1)
993
  x_1 = samples.view(*p_1t.shape[:-1])
994
 
995
+ return x_1, p_1t