kilianhaefeli committed on
Commit ·
389e1f4
1
Parent(s): e5351ca
initial attempt at EBM sampler
Browse files- modeling.py +36 -9
modeling.py
CHANGED
|
@@ -723,6 +723,7 @@ class Fast_dLLM_QwenForCausalLM(Fast_dLLM_QwenPreTrainedModel, GenerationMixin):
|
|
| 723 |
max_new_tokens=20, # Added default value for safety
|
| 724 |
mask_id=151665,
|
| 725 |
threshold=1,
|
|
|
|
| 726 |
small_block_size=8,
|
| 727 |
block_size=32,
|
| 728 |
stop_token=151645,
|
|
@@ -737,6 +738,10 @@ class Fast_dLLM_QwenForCausalLM(Fast_dLLM_QwenPreTrainedModel, GenerationMixin):
|
|
| 737 |
if use_block_cache:
|
| 738 |
raise ValueError("use_block_cache=True is not supported in this generate() implementation.")
|
| 739 |
assert attention_mask is not None, "attention_mask must be provided for this generate() implementation."
|
|
|
|
|
|
|
|
|
|
|
|
|
| 740 |
|
| 741 |
# pad the initial input_ids and attention_mask to be multiple of block_size
|
| 742 |
if False: # input_ids.shape[1] % block_size != 0:
|
|
@@ -862,15 +867,37 @@ class Fast_dLLM_QwenForCausalLM(Fast_dLLM_QwenPreTrainedModel, GenerationMixin):
|
|
| 862 |
logits = logits[:, start:end]
|
| 863 |
|
| 864 |
x_1, p_1t = self.sample_with_top_p(logits, top_p=top_p, temperature=temperature)
|
| 865 |
-
|
| 866 |
-
|
| 867 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 868 |
|
| 869 |
-
|
| 870 |
-
|
| 871 |
-
|
| 872 |
-
|
| 873 |
-
|
| 874 |
|
| 875 |
# Add 1 to iterations if the sequence is not stopped AND at least one token is generated in this iteration
|
| 876 |
# aka if not finished and unmask id has some True value
|
|
@@ -965,4 +992,4 @@ class Fast_dLLM_QwenForCausalLM(Fast_dLLM_QwenPreTrainedModel, GenerationMixin):
|
|
| 965 |
samples = torch.multinomial(flat, num_samples=1).squeeze(-1)
|
| 966 |
x_1 = samples.view(*p_1t.shape[:-1])
|
| 967 |
|
| 968 |
-
return x_1, p_1t
|
|
|
|
| 723 |
max_new_tokens=20, # Added default value for safety
|
| 724 |
mask_id=151665,
|
| 725 |
threshold=1,
|
| 726 |
+
sampler="confidence",
|
| 727 |
small_block_size=8,
|
| 728 |
block_size=32,
|
| 729 |
stop_token=151645,
|
|
|
|
| 738 |
if use_block_cache:
|
| 739 |
raise ValueError("use_block_cache=True is not supported in this generate() implementation.")
|
| 740 |
assert attention_mask is not None, "attention_mask must be provided for this generate() implementation."
|
| 741 |
+
if sampler not in {"confidence", "ebm"}:
|
| 742 |
+
raise ValueError(f"Unsupported sampler: {sampler}. Use 'confidence' or 'ebm'.")
|
| 743 |
+
if sampler == "ebm" and small_block_size != block_size:
|
| 744 |
+
raise ValueError("sampler='ebm' currently requires small_block_size == block_size.")
|
| 745 |
|
| 746 |
# pad the initial input_ids and attention_mask to be multiple of block_size
|
| 747 |
if False: # input_ids.shape[1] % block_size != 0:
|
|
|
|
| 867 |
logits = logits[:, start:end]
|
| 868 |
|
| 869 |
x_1, p_1t = self.sample_with_top_p(logits, top_p=top_p, temperature=temperature)
|
| 870 |
+
mask_slice = mask_idx[:, start:end]
|
| 871 |
+
if sampler == "ebm":
|
| 872 |
+
# EBM: unmask lowest-entropy positions with cumulative entropy <= threshold.
|
| 873 |
+
log_p = torch.log(p_1t.clamp_min(1e-12))
|
| 874 |
+
entropy = -(p_1t * log_p).sum(dim=-1)
|
| 875 |
+
entropy = entropy.masked_fill(~mask_slice, float("inf"))
|
| 876 |
+
|
| 877 |
+
sorted_entropy, sorted_idx = torch.sort(entropy, dim=-1, descending=False)
|
| 878 |
+
cumulative_entropy = torch.cumsum(sorted_entropy, dim=-1)
|
| 879 |
+
n_to_unmask = (cumulative_entropy <= threshold).sum(dim=-1)
|
| 880 |
+
n_to_unmask = torch.clamp(n_to_unmask, min=1)
|
| 881 |
+
max_n = mask_slice.sum(dim=-1)
|
| 882 |
+
n_to_unmask = torch.minimum(n_to_unmask, max_n)
|
| 883 |
+
|
| 884 |
+
positions = torch.arange(sorted_idx.size(1), device=sorted_idx.device)
|
| 885 |
+
positions = positions.unsqueeze(0).expand_as(sorted_idx)
|
| 886 |
+
select = positions < n_to_unmask.unsqueeze(1)
|
| 887 |
+
|
| 888 |
+
unmask_idx = torch.zeros_like(mask_slice)
|
| 889 |
+
unmask_idx.scatter_(dim=1, index=sorted_idx, src=select)
|
| 890 |
+
unmask_idx = unmask_idx & mask_slice
|
| 891 |
+
else:
|
| 892 |
+
# Select tokens with probability greater than threshold from p_1t
|
| 893 |
+
x1_p = torch.squeeze(torch.gather(p_1t, dim=-1, index=torch.unsqueeze(x_1, -1)), -1)
|
| 894 |
+
x1_p = torch.where(mask_slice, x1_p, -torch.inf)
|
| 895 |
|
| 896 |
+
unmask_idx = (x1_p > threshold)
|
| 897 |
+
# Ensure at least one token is unmasked in the current small block
|
| 898 |
+
max_prob_idx = x1_p.argmax(dim=-1)
|
| 899 |
+
unmask_idx[torch.arange(x_1.shape[0]), max_prob_idx] = True
|
| 900 |
+
unmask_idx = unmask_idx & mask_slice
|
| 901 |
|
| 902 |
# Add 1 to iterations if the sequence is not stopped AND at least one token is generated in this iteration
|
| 903 |
# aka if not finished and unmask id has some True value
|
|
|
|
| 992 |
samples = torch.multinomial(flat, num_samples=1).squeeze(-1)
|
| 993 |
x_1 = samples.view(*p_1t.shape[:-1])
|
| 994 |
|
| 995 |
+
return x_1, p_1t
|