codelion committed
Commit 671befc · verified · Parent: 1c747f7

Add diffusion-based generate() method for proper text generation

Files changed (1): modeling_dhara.py +152 -0
modeling_dhara.py CHANGED
@@ -755,6 +755,158 @@ class DharaForMaskedDiffusion(DharaPreTrainedModel, GenerationMixin):
         })
         return model_inputs
 
+    @torch.no_grad()
+    def generate(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        max_length: Optional[int] = None,
+        max_new_tokens: Optional[int] = None,
+        num_diffusion_steps: int = 10,
+        temperature: float = 1.0,
+        top_p: float = 0.9,
+        do_sample: bool = True,
+        pad_token_id: Optional[int] = None,
+        eos_token_id: Optional[int] = None,
+        **kwargs
+    ) -> torch.LongTensor:
+        """
+        Generate text using masked diffusion sampling.
+
+        This method performs iterative denoising: starting from fully masked tokens,
+        it progressively unmasks positions in order of model confidence.
+
+        Args:
+            input_ids: Input prompt token IDs [batch_size, prompt_len]
+            max_length: Maximum total sequence length (prompt + generation)
+            max_new_tokens: Number of new tokens to generate (alternative to max_length)
+            num_diffusion_steps: Number of denoising iterations (more = higher quality, slower)
+            temperature: Sampling temperature (higher = more random)
+            top_p: Nucleus sampling threshold
+            do_sample: Whether to sample or take the argmax
+            pad_token_id: Padding token ID (resolved from config; currently unused)
+            eos_token_id: End-of-sequence token ID (resolved from config; currently unused)
+
+        Returns:
+            Generated token IDs including the prompt
+        """
+        # Resolve the target device
+        device = input_ids.device if input_ids is not None else next(self.parameters()).device
+
+        # Determine generation length
+        if input_ids is not None:
+            batch_size, prompt_len = input_ids.shape
+        else:
+            batch_size, prompt_len = 1, 0
+            input_ids = torch.empty(batch_size, 0, dtype=torch.long, device=device)
+
+        if max_new_tokens is not None:
+            gen_len = max_new_tokens
+        elif max_length is not None:
+            gen_len = max_length - prompt_len
+        else:
+            gen_len = 50  # Default generation length
+
+        if gen_len <= 0:
+            return input_ids
+
+        # Resolve special token IDs
+        mask_token_id = self.config.mask_token_id
+        if pad_token_id is None:
+            pad_token_id = getattr(self.config, 'pad_token_id', 0)
+        if eos_token_id is None:
+            eos_token_id = getattr(self.config, 'eos_token_id', 2)
+
+        # Initialize: prompt followed by masked tokens for generation
+        total_len = prompt_len + gen_len
+        tokens = torch.full((batch_size, total_len), mask_token_id, dtype=torch.long, device=device)
+        tokens[:, :prompt_len] = input_ids
+
+        # Track which positions are still masked (i.e., still need generation)
+        is_masked = torch.ones(batch_size, total_len, dtype=torch.bool, device=device)
+        is_masked[:, :prompt_len] = False  # Prompt is never masked
+
+        # Number of tokens to unmask per step
+        tokens_per_step = max(1, gen_len // num_diffusion_steps)
+
+        # Iterative denoising
+        for step in range(num_diffusion_steps):
+            # Forward pass to get logits
+            outputs = self(input_ids=tokens)
+            logits = outputs.logits  # [batch, seq_len, vocab]
+
+            # Only consider masked positions
+            masked_positions = is_masked.clone()
+
+            if not masked_positions.any():
+                break  # All tokens have been generated
+
+            # Apply temperature (guard against division by zero when temperature == 0)
+            if temperature > 0 and temperature != 1.0:
+                logits = logits / temperature
+
+            # Get probabilities
+            probs = F.softmax(logits, dim=-1)
+
+            # Confidence (max probability) at each position
+            confidence, _ = probs.max(dim=-1)  # [batch, seq_len]
+
+            # Exclude already-generated positions from the confidence ranking
+            confidence = confidence.masked_fill(~masked_positions, -float('inf'))
+
+            # Determine how many tokens to unmask this step
+            remaining_masked = masked_positions.sum(dim=1)  # [batch]
+
+            # On the last step, unmask everything that remains
+            if step == num_diffusion_steps - 1:
+                num_to_unmask = remaining_masked
+            else:
+                num_to_unmask = torch.minimum(
+                    torch.tensor(tokens_per_step, device=device).expand(batch_size),
+                    remaining_masked
+                )
+
+            # For each batch item, unmask the highest-confidence positions
+            for b in range(batch_size):
+                if num_to_unmask[b] == 0:
+                    continue
+
+                # Confidence scores for this batch item
+                conf_b = confidence[b]  # [seq_len]
+
+                # Top-k positions with the highest confidence
+                k = int(num_to_unmask[b].item())
+                _, top_indices = conf_b.topk(k)
+
+                # Sample or take the argmax at each selected position
+                for idx in top_indices:
+                    pos_logits = logits[b, idx]  # [vocab]
+
+                    if do_sample and temperature > 0:
+                        # Top-p (nucleus) sampling
+                        sorted_logits, sorted_indices = torch.sort(pos_logits, descending=True)
+                        sorted_probs = F.softmax(sorted_logits, dim=-1)
+                        cumsum_probs = torch.cumsum(sorted_probs, dim=-1)
+
+                        # Drop tokens past the top_p mass, shifted by one so the most likely token is always kept
+                        sorted_indices_to_remove = cumsum_probs > top_p
+                        sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
+                        sorted_indices_to_remove[0] = False
+
+                        sorted_logits[sorted_indices_to_remove] = float('-inf')
+                        probs_filtered = F.softmax(sorted_logits, dim=-1)
+
+                        # Sample from the renormalized, filtered distribution
+                        sampled_idx = torch.multinomial(probs_filtered, 1)
+                        token_id = sorted_indices[sampled_idx]
+                    else:
+                        # Greedy (argmax) decoding
+                        token_id = pos_logits.argmax()
+
+                    tokens[b, idx] = token_id
+                    is_masked[b, idx] = False
+
+        return tokens
+
     def save_pretrained(self, save_directory, **kwargs):
         kwargs['safe_serialization'] = kwargs.get('safe_serialization', True)
         return super().save_pretrained(save_directory, **kwargs)
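
For context, here is a minimal usage sketch of the new method. The repo ID, the AutoModel mapping via `trust_remote_code`, and the prompt are illustrative assumptions, not part of this commit:

```python
# Hypothetical usage sketch -- "user/dhara" is a placeholder repo ID, and we
# assume the repo's auto_map exposes DharaForMaskedDiffusion through AutoModel.
import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("user/dhara", trust_remote_code=True)
model = AutoModel.from_pretrained("user/dhara", trust_remote_code=True).eval()

input_ids = tokenizer("The quick brown fox", return_tensors="pt").input_ids

# Diffusion-style generation: more steps = more unmasking rounds (slower, higher quality)
output_ids = model.generate(
    input_ids,
    max_new_tokens=64,
    num_diffusion_steps=16,
    temperature=0.8,
    top_p=0.9,
    do_sample=True,
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```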
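
The nucleus-sampling shift inside the inner loop can look opaque; a toy example with made-up logits shows why the shifted mask guarantees that the highest-probability token always survives the `top_p` cutoff:

```python
import torch
import torch.nn.functional as F

pos_logits = torch.tensor([2.0, 1.0, 0.5, -1.0])  # made-up values
sorted_logits, sorted_indices = torch.sort(pos_logits, descending=True)
cumsum = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
# cumsum ~= [0.61, 0.83, 0.97, 1.00]; the raw mask cumsum > 0.9 would be
# [False, False, True, True], but shifting it right by one keeps the token
# that crosses the threshold -- and position 0 is always kept explicitly.
remove = cumsum > 0.9
remove[1:] = remove[:-1].clone()
remove[0] = False
print(remove)  # tensor([False, False, False,  True])
```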
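One scheduling detail worth noting: each step unmasks roughly `max(1, gen_len // num_diffusion_steps)` positions, and the final step flushes whatever remains. With `gen_len = 50` and the default `num_diffusion_steps = 10`, that is 5 tokens per step; if `num_diffusion_steps` exceeds `gen_len`, the loop unmasks one token per step and exits early via the `masked_positions.any()` check once everything is generated.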