smoorsmith committed (verified)
Commit a9a06d4
1 parent: d49164e

Upload folder using huggingface_hub

Files changed (3)
  1. generation_utils.py +38 -3
  2. modeling_dream.py +14 -0
  3. softmasking_utils.py +122 -0
generation_utils.py CHANGED

@@ -30,6 +30,7 @@ from transformers.utils import (
     is_torchdynamo_compiling,
     logging,
 )
+from .softmasking_utils import SMArgs, get_mixing_factors_for_softmasking
 
 logger = logging.get_logger(__name__)
 
@@ -351,12 +352,25 @@ class DreamGenerationMixin:
             attention_mask=attention_mask
         )
 
+        ###### LOAD IN Softmasking PARAMETERS ######
+        sm_args = SMArgs(
+            sm_alg=kwargs.pop("transparency_alg", "none"),
+            sm_schedule=kwargs.pop("transparency_scheduling", "none"),
+            scale=kwargs.pop("transparency_scale", 0.0),
+            steepness=kwargs.pop("transparency_steepness", 0.0),
+            offset=kwargs.pop("transparency_centre", 0.0),
+            mixinputs_k=kwargs.pop("mixinputs_k", 1),
+            mixinputs_temp=kwargs.pop("mixture_temp", 1.0),
+        )
+        #############################################
+
         result = self._sample(
             input_ids,
             attention_mask=attention_mask,
             generation_config=generation_config,
             generation_tokens_hook_func=generation_tokens_hook_func,
-            generation_logits_hook_func=generation_logits_hook_func
+            generation_logits_hook_func=generation_logits_hook_func,
+            sm_args=sm_args
         )
         return result
 
@@ -366,7 +380,8 @@ class DreamGenerationMixin:
         attention_mask: Optional[torch.LongTensor],
         generation_config: DreamGenerationConfig,
         generation_tokens_hook_func,
-        generation_logits_hook_func
+        generation_logits_hook_func,
+        sm_args: SMArgs,
     ) -> Union[DreamModelOutput, torch.LongTensor]:
         # init values
         output_history = generation_config.output_history
@@ -407,9 +422,18 @@
 
         # this allows user-defined token control of the intermediate steps
         x = generation_tokens_hook_func(None, x, None)
+
+        # Initialize necessary SM variables
+        inputs_embeds = None
+        embed_weights = self.get_input_embeddings().weight  # (V, D)
+        max_gen_length = (x == mask_token_id).sum().item()
+
         for i in range(steps):
             mask_index = (x == mask_token_id)
-            logits = self(x, attention_mask, tok_idx).logits
+            if inputs_embeds is None:
+                logits = self(x, attention_mask, tok_idx).logits
+            else:
+                logits = self(inputs_embeds=inputs_embeds, attention_mask=attention_mask, position_ids=tok_idx).logits
             logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
 
             # this allows user-defined logits control of the intermediate steps
@@ -454,6 +478,17 @@
             # this allows user-defined token control of the intermediate steps
             x = generation_tokens_hook_func(i, x, logits)
 
+            # DO SOFTMASKING MIXING
+            if sm_args.sm_alg != "none":
+                p_sm = get_mixing_factors_for_softmasking(
+                    x,
+                    logits,
+                    mask_token_id,
+                    max_gen_length,
+                    sm_args
+                )
+                inputs_embeds = torch.matmul(p_sm, embed_weights)  # (B, T, D)
+
             if histories is not None:
                 histories.append(x.clone())
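For context, a minimal usage sketch of the new Softmasking parameters. It assumes the mixin's public entry point (diffusion_generate in upstream Dream) forwards extra keyword arguments to the code path above; the call and its other arguments are illustrative, not part of this commit.

    # Hypothetical call; each transparency_*/mixture_* kwarg below is popped by the SMArgs block above.
    out = model.diffusion_generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=128,
        steps=128,
        transparency_alg="mixinputs_with_temp",  # or "mixinputs_with_topk" / "none"
        transparency_scheduling="linear",        # "none", "linear", or "stepwise"
        transparency_scale=0.9,                  # overall mixing strength
        transparency_steepness=5.0,              # sigmoid steepness of the entropy -> lambda map
        transparency_centre=-1.0,                # sigmoid offset (applied to negative entropy)
        mixture_temp=0.8,                        # softmax temperature used for the mixture
    )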
modeling_dream.py CHANGED

@@ -743,6 +743,20 @@ class DreamModel(DreamGenerationMixin, DreamPreTrainedModel):
 
         # Initialize weights and apply final processing
         self.post_init()
+
+    # Apparently needed for the LM Eval backend
+    # Adapted from LLaDa-Instruct
+    def prepare_inputs_for_generation(
+        self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple]] = None, **kwargs
+    ):
+        if past_key_values:
+            # With a cache present, the model should only process the last generated token.
+            input_ids = input_ids[:, -1:]
+        model_inputs = {"input_ids": input_ids, "past_key_values": past_key_values}
+
+        model_inputs.update(kwargs)
+        model_inputs["use_cache"] = kwargs.pop("use_cache", self.config.use_cache)
+        return model_inputs
 
     def reset_rope_parameters(self):
         self.model.rotary_emb.reset_parameters()
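The contract this helper adds, sketched under the assumption of a loaded DreamModel instance `model` and some non-empty cache object `past` (both hypothetical here):

    # No cache: the full sequence passes through unchanged.
    inputs = model.prepare_inputs_for_generation(input_ids)
    assert inputs["input_ids"].shape == input_ids.shape

    # Non-empty cache: only the last position is kept, so the model
    # processes just the newly generated token.
    inputs = model.prepare_inputs_for_generation(input_ids, past_key_values=past)
    assert inputs["input_ids"].shape[-1] == 1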
softmasking_utils.py ADDED
@@ -0,0 +1,122 @@
+import torch
+import torch.nn.functional as F
+from dataclasses import dataclass
+
+@dataclass
+class SMArgs:
+    """Arguments for Softmasking"""
+
+    # SM algorithm
+    sm_alg: str = "none"       # "mixinputs_with_topk" or "mixinputs_with_temp"
+    sm_schedule: str = "none"  # "none", "linear", or "stepwise"
+
+    # lambda(·) parameters
+    scale: float = 0.0      # overall strength of mixing (0 disables mixing)
+    steepness: float = 0.0  # sigmoid steepness for the entropy -> lambda map
+    offset: float = 0.0     # sigmoid offset for the entropy -> lambda map
+
+    # used only when sm_alg == "mixinputs_with_topk"
+    mixinputs_k: int = 3
+    # used only when sm_alg == "mixinputs_with_temp"
+    mixinputs_temp: float = 1.0
+
+def get_mixing_factors_for_softmasking(input_ids, logits_prelim, mask_token_id, max_gen_length, sm_args):
+    """Compute the mixed output probabilities for Softmasking."""
+
+    # Create a one-hot distribution for the original input `xt`.
+    xt_one_hot = F.one_hot(input_ids, num_classes=logits_prelim.shape[-1]).to(logits_prelim.dtype)
+
+    # First get the negative entropy to calculate lambda
+    temperature = sm_args.mixinputs_temp if sm_args.sm_alg == "mixinputs_with_temp" else 1.0
+    neg_entropy, p = get_neg_entropy_and_probabilities(logits_prelim, temperature=temperature)
+
+    # Update scale with the schedule if needed
+    if sm_args.sm_schedule != "none":
+        num_mask_token = (input_ids == mask_token_id).sum().item()
+        scale = get_time_dependence(
+            max_gen_length=max_gen_length,
+            num_mask_token=num_mask_token,
+            scale=sm_args.scale,
+            schedule=sm_args.sm_schedule
+        )
+    else:
+        scale = sm_args.scale
+
+    # Calculate the lambda tensor
+    mask_positions = (input_ids == mask_token_id)
+    lambda_tensor = calculate_lambda_tensor(neg_entropy, mask_positions,
+                                            scale, sm_args.steepness, sm_args.offset)
+
+    if sm_args.sm_alg == "mixinputs_with_topk":
+        # Only fill probabilities for the top-k tokens
+        p = get_only_topk_probs(logits_prelim, sm_args.mixinputs_k)
+
+    # Convex combination of the one-hot input and the model distribution
+    p_out = (1 - lambda_tensor) * xt_one_hot \
+        + lambda_tensor * p
+
+    return p_out
+
+def get_neg_entropy_and_probabilities(logits, temperature=1.0):
+    """Get negative entropy and probabilities from logits"""
+
+    epsilon = 1e-10
+    p = torch.softmax(logits / temperature, dim=-1)  # (B, T, V)
+    logp = torch.log(p + epsilon)
+    neg_entropy = torch.sum(p * logp, dim=-1)  # (B, T)
+    return neg_entropy, p
+
+def calculate_lambda_tensor(neg_entropy, mask_positions, scale, steepness, offset):
+    """Calculate the lambda tensor from negative entropy"""
+
+    if scale == 0.0:
+        return torch.zeros_like(neg_entropy).unsqueeze(-1)  # (B, T, 1)
+
+    # map negative entropy into [0, scale] with a sigmoid
+    lambda_tensor = neg_entropy
+    lambda_tensor = scale * torch.sigmoid(steepness * (lambda_tensor - offset))
+
+    # apply only on mask positions
+    lambda_tensor = torch.where(mask_positions, lambda_tensor, torch.zeros_like(lambda_tensor))
+    return lambda_tensor.unsqueeze(-1)  # (B, T, 1)
+
+def get_only_topk_probs(logits, mixinputs_k=3):
+    """Compute a full-vocabulary probability tensor where only the top-k tokens per position
+    receive softmax probabilities and all other entries are zero."""
+
+    topk_logits, topk_indices = torch.topk(logits, k=mixinputs_k, dim=-1)  # (B, T, k)
+
+    topk_probs = torch.softmax(topk_logits, dim=-1)  # (B, T, k)
+    topk_sum = topk_probs.sum(dim=-1)  # (B, T)
+    assert torch.allclose(topk_sum, torch.ones_like(topk_sum), atol=1e-1), \
+        f"Top-k softmax probabilities do not sum to 1: max deviation = {(topk_sum - 1).abs().max().item()}"
+
+    probs_full = torch.zeros_like(logits)  # (B, T, V)
+    probs_full.scatter_(-1, topk_indices, topk_probs)  # fill top-k
+    assert torch.sum(probs_full > 0).item() == mixinputs_k * logits.shape[0] * logits.shape[1], \
+        f"Number of non-zero entries in probs_full is incorrect: got {torch.sum(probs_full > 0).item()}, expected {mixinputs_k * logits.shape[0] * logits.shape[1]}"
+
+    return probs_full
+
+def get_time_dependence(
+    max_gen_length: int,
+    num_mask_token: int,
+    scale: float,
+    schedule: str,
+    sm_to_hm: bool = True,
+    threshold: float = 0.5,
+) -> float:
+    """Return a scale factor that depends on decoding progress."""
+    t = num_mask_token / max_gen_length if max_gen_length else 1.0
+
+    if schedule == "none":
+        return scale
+
+    if schedule == "linear":
+        return scale * (t if sm_to_hm else 1 - t)
+
+    if schedule == "stepwise":
+        cond = t > threshold if sm_to_hm else t < threshold
+        return scale if cond else 0.0
+
+    raise ValueError(f"Unknown schedule: {schedule}")
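
To make the mixing concrete: for each position the module returns the convex combination p_out = (1 - lambda) * onehot(x_t) + lambda * p, where lambda is a sigmoid of the negative entropy scaled by `scale` and zeroed on unmasked positions. A self-contained toy run, assuming softmasking_utils.py is importable on its own (all shapes and values below are made up for illustration):

    import torch
    from softmasking_utils import SMArgs, get_mixing_factors_for_softmasking

    B, T, V, MASK = 1, 4, 8, 7
    torch.manual_seed(0)
    input_ids = torch.tensor([[2, MASK, MASK, 5]])  # two masked positions
    logits = torch.randn(B, T, V)                   # stand-in model logits
    args = SMArgs(sm_alg="mixinputs_with_temp", scale=0.5, steepness=4.0, offset=-1.0)

    p_out = get_mixing_factors_for_softmasking(input_ids, logits, MASK,
                                               max_gen_length=2, sm_args=args)
    # Every row of p_out is a distribution over the vocabulary...
    assert torch.allclose(p_out.sum(-1), torch.ones(B, T), atol=1e-5)
    # ...and unmasked positions stay exactly one-hot (lambda is zeroed there).
    assert p_out[0, 0, 2] == 1.0 and p_out[0, 3, 5] == 1.0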