Add diffusion-based generate() method for proper text generation
modeling_dhara.py CHANGED (+152 -0)
@@ -755,6 +755,158 @@ class DharaForMaskedDiffusion(DharaPreTrainedModel, GenerationMixin):
         })
         return model_inputs
 
+    @torch.no_grad()
+    def generate(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        max_length: Optional[int] = None,
+        max_new_tokens: Optional[int] = None,
+        num_diffusion_steps: int = 10,
+        temperature: float = 1.0,
+        top_p: float = 0.9,
+        do_sample: bool = True,
+        pad_token_id: Optional[int] = None,
+        eos_token_id: Optional[int] = None,
+        **kwargs
+    ) -> torch.LongTensor:
+        """
+        Generate text using masked diffusion sampling.
+
+        This method performs iterative denoising: starting from fully masked tokens,
+        it progressively unmasks positions based on model confidence.
+
+        Args:
+            input_ids: Input prompt token IDs [batch_size, prompt_len]
+            max_length: Maximum total sequence length (prompt + generation)
+            max_new_tokens: Number of new tokens to generate (alternative to max_length)
+            num_diffusion_steps: Number of denoising iterations (more = higher quality, slower)
+            temperature: Sampling temperature (higher = more random)
+            top_p: Nucleus sampling threshold
+            do_sample: Whether to sample or take argmax
+            pad_token_id: Token ID for padding
+            eos_token_id: Token ID for end of sequence
+
+        Returns:
+            Generated token IDs including the prompt
+        """
+        # Handle device and dtype
+        device = input_ids.device if input_ids is not None else next(self.parameters()).device
+
+        # Determine generation length
+        if input_ids is not None:
+            batch_size, prompt_len = input_ids.shape
+        else:
+            batch_size, prompt_len = 1, 0
+            input_ids = torch.empty(batch_size, 0, dtype=torch.long, device=device)
+
+        if max_new_tokens is not None:
+            gen_len = max_new_tokens
+        elif max_length is not None:
+            gen_len = max_length - prompt_len
+        else:
+            gen_len = 50  # Default generation length
+
+        if gen_len <= 0:
+            return input_ids
+
+        # Get special token IDs
+        mask_token_id = self.config.mask_token_id
+        if pad_token_id is None:
+            pad_token_id = self.config.pad_token_id if hasattr(self.config, 'pad_token_id') else 0
+        if eos_token_id is None:
+            eos_token_id = self.config.eos_token_id if hasattr(self.config, 'eos_token_id') else 2
+
+        # Initialize: prompt + masked tokens for generation
+        total_len = prompt_len + gen_len
+        tokens = torch.full((batch_size, total_len), mask_token_id, dtype=torch.long, device=device)
+        tokens[:, :prompt_len] = input_ids
+
+        # Track which positions are masked (need generation)
+        is_masked = torch.ones(batch_size, total_len, dtype=torch.bool, device=device)
+        is_masked[:, :prompt_len] = False  # Prompt is not masked
+
+        # Number of tokens to unmask per step
+        tokens_per_step = max(1, gen_len // num_diffusion_steps)
+
+        # Iterative denoising
+        for step in range(num_diffusion_steps):
+            # Forward pass to get logits
+            outputs = self(input_ids=tokens)
+            logits = outputs.logits  # [batch, seq_len, vocab]
+
+            # Only consider masked positions
+            masked_positions = is_masked.clone()
+
+            if not masked_positions.any():
+                break  # All tokens have been generated
+
+            # Apply temperature
+            if temperature != 1.0:
+                logits = logits / temperature
+
+            # Get probabilities
+            probs = F.softmax(logits, dim=-1)
+
+            # Calculate confidence (max prob) for each position
+            confidence, _ = probs.max(dim=-1)  # [batch, seq_len]
+
+            # Mask out already-generated positions from confidence calculation
+            confidence = confidence.masked_fill(~masked_positions, -float('inf'))
+
+            # Determine how many tokens to unmask this step
+            remaining_masked = masked_positions.sum(dim=1)  # [batch]
+
+            # For the last step, unmask everything remaining
+            if step == num_diffusion_steps - 1:
+                num_to_unmask = remaining_masked
+            else:
+                num_to_unmask = torch.minimum(
+                    torch.tensor(tokens_per_step, device=device).expand(batch_size),
+                    remaining_masked
+                )
+
+            # For each batch item, unmask the highest confidence positions
+            for b in range(batch_size):
+                if num_to_unmask[b] == 0:
+                    continue
+
+                # Get confidence scores for this batch item
+                conf_b = confidence[b]  # [seq_len]
+
+                # Get top-k positions with highest confidence
+                k = int(num_to_unmask[b].item())
+                _, top_indices = conf_b.topk(k)
+
+                # Sample or argmax for these positions
+                for idx in top_indices:
+                    pos_logits = logits[b, idx]  # [vocab]
+
+                    if do_sample and temperature > 0:
+                        # Top-p (nucleus) sampling
+                        sorted_logits, sorted_indices = torch.sort(pos_logits, descending=True)
+                        sorted_probs = F.softmax(sorted_logits, dim=-1)
+                        cumsum_probs = torch.cumsum(sorted_probs, dim=-1)
+
+                        # Remove tokens with cumulative probability above top_p
+                        sorted_indices_to_remove = cumsum_probs > top_p
+                        sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
+                        sorted_indices_to_remove[0] = False
+
+                        sorted_logits[sorted_indices_to_remove] = float('-inf')
+                        probs_filtered = F.softmax(sorted_logits, dim=-1)
+
+                        # Sample
+                        sampled_idx = torch.multinomial(probs_filtered, 1)
+                        token_id = sorted_indices[sampled_idx]
+                    else:
+                        # Greedy (argmax)
+                        token_id = pos_logits.argmax()
+
+                    tokens[b, idx] = token_id
+                    is_masked[b, idx] = False
+
+        return tokens
+
     def save_pretrained(self, save_directory, **kwargs):
         kwargs['safe_serialization'] = kwargs.get('safe_serialization', True)
        return super().save_pretrained(save_directory, **kwargs)
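
For reference, a minimal usage sketch of the new method (not part of this diff). It assumes the checkpoint is loaded with trust_remote_code=True so that the custom DharaForMaskedDiffusion class and its generate() override are resolved; the repo id, the AutoModel auto class, and the prompt are placeholders, not values taken from this repository.

from transformers import AutoModel, AutoTokenizer

model_id = "path/to/dhara-checkpoint"  # placeholder, replace with the actual repo id
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
# AutoModel is an assumption; the repo's auto_map may expose the class under a different auto class
model = AutoModel.from_pretrained(model_id, trust_remote_code=True)
model.eval()

prompt = "The history of diffusion language models"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids

# Keyword arguments match the new generate() signature above
output_ids = model.generate(
    input_ids=input_ids,
    max_new_tokens=64,       # 64 masked slots appended after the prompt
    num_diffusion_steps=16,  # max(1, 64 // 16) = 4 positions unmasked per step
    temperature=0.8,
    top_p=0.9,
    do_sample=True,
)

# generate() returns prompt + generated tokens as a single LongTensor
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))

With max_new_tokens=64 and num_diffusion_steps=16, the schedule in the diff unmasks max(1, 64 // 16) = 4 positions per denoising step, and the final step unmasks whatever remains. Because this generate() is defined directly on DharaForMaskedDiffusion, it takes precedence over the generate() inherited from GenerationMixin.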