Upload folder using huggingface_hub

- generation_utils.py +38 -3
- modeling_dream.py +14 -0
- softmasking_utils.py +122 -0
generation_utils.py
CHANGED
@@ -30,6 +30,7 @@ from transformers.utils import (
     is_torchdynamo_compiling,
     logging,
 )
+from .softmasking_utils import SMArgs, get_mixing_factors_for_softmasking
 
 logger = logging.get_logger(__name__)
 
@@ -351,12 +352,25 @@ class DreamGenerationMixin:
             attention_mask=attention_mask
         )
 
+        ###### LOAD IN Softmasking PARAMETERS ######
+        sm_args = SMArgs(
+            sm_alg=kwargs.pop("transparency_alg", "none"),
+            sm_schedule=kwargs.pop("transparency_scheduling", "none"),
+            scale=kwargs.pop("transparency_scale", 0.0),
+            steepness=kwargs.pop("transparency_steepness", 0.0),
+            offset=kwargs.pop("transparency_centre", 0.0),
+            mixinputs_k=kwargs.pop("mixinputs_k", 1),
+            mixinputs_temp=kwargs.pop("mixture_temp", 1.0),
+        )
+        #############################################
+
         result = self._sample(
             input_ids,
             attention_mask=attention_mask,
             generation_config=generation_config,
             generation_tokens_hook_func=generation_tokens_hook_func,
-            generation_logits_hook_func=generation_logits_hook_func
+            generation_logits_hook_func=generation_logits_hook_func,
+            sm_args=sm_args
         )
         return result
 
@@ -366,7 +380,8 @@ class DreamGenerationMixin:
         attention_mask: Optional[torch.LongTensor],
         generation_config: DreamGenerationConfig,
         generation_tokens_hook_func,
-        generation_logits_hook_func
+        generation_logits_hook_func,
+        sm_args: SMArgs,
     ) -> Union[DreamModelOutput, torch.LongTensor]:
         # init values
         output_history = generation_config.output_history
@@ -407,9 +422,18 @@ class DreamGenerationMixin:
 
         # this allows user-defined token control of the intermediate steps
        x = generation_tokens_hook_func(None, x, None)
+
+        # Initialize necessary SM variables
+        inputs_embeds = None
+        embed_weights = self.get_input_embeddings().weight  # (V,D)
+        max_gen_length = (x == mask_token_id).sum().item()
+
        for i in range(steps):
            mask_index = (x == mask_token_id)
-            logits = self(x, attention_mask, tok_idx).logits
+            if inputs_embeds is None:
+                logits = self(x, attention_mask, tok_idx).logits
+            else:
+                logits = self(inputs_embeds=inputs_embeds, attention_mask=attention_mask, position_ids=tok_idx).logits
            logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
 
            # this allows user-defined logits control of the intermediate steps
@@ -454,6 +478,17 @@ class DreamGenerationMixin:
             # this allows user-defined token control of the intermediate steps
             x = generation_tokens_hook_func(i, x, logits)
 
+            # DO SOFTMASKING MIXING
+            if sm_args.sm_alg != "none":
+                p_sm = get_mixing_factors_for_softmasking(
+                    x,
+                    logits,
+                    mask_token_id,
+                    max_gen_length,
+                    sm_args
+                )
+                inputs_embeds = torch.matmul(p_sm, embed_weights)  # (B,T,D)
+
             if histories is not None:
                 histories.append(x.clone())
 
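In short, the generation_utils.py changes pop a set of new `transparency_*` kwargs off the generation call, bundle them into an `SMArgs`, and thread them into `_sample`, where the mixed distribution `p_sm` is turned into soft input embeddings that replace the hard token ids on the next denoising step. A minimal usage sketch, not part of the commit: the checkpoint name, the `diffusion_generate` entry point (the method in this file that calls `_sample`), and all parameter values are illustrative assumptions.

    import torch
    from transformers import AutoModel, AutoTokenizer

    # Illustrative checkpoint; any Dream-style diffusion LM exposing
    # DreamGenerationMixin should behave the same way.
    model = AutoModel.from_pretrained("Dream-org/Dream-v0-Instruct-7B", trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained("Dream-org/Dream-v0-Instruct-7B", trust_remote_code=True)

    inputs = tokenizer("Write a haiku about diffusion.", return_tensors="pt")
    out = model.diffusion_generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_new_tokens=64,
        steps=64,
        # new kwargs consumed by the SMArgs block above
        transparency_alg="mixinputs_with_temp",   # or "mixinputs_with_topk"
        transparency_scheduling="linear",
        transparency_scale=0.5,
        transparency_steepness=5.0,
        transparency_centre=-1.0,
        mixture_temp=1.0,
    )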
modeling_dream.py
CHANGED
@@ -743,6 +743,20 @@ class DreamModel(DreamGenerationMixin, DreamPreTrainedModel):
 
         # Initialize weights and apply final processing
         self.post_init()
+
+    # Apparently needed for LM Eval backend
+    # Adapted from LLaDa-Instruct
+    def prepare_inputs_for_generation(
+        self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple]] = None, **kwargs
+    ):
+        if past_key_values:
+            # This is because we want the model to only process the last generated token.
+            input_ids = input_ids[:, -1:]
+        model_inputs = {"input_ids": input_ids, "past_key_values": past_key_values}
+
+        model_inputs.update(kwargs)
+        model_inputs["use_cache"] = kwargs.pop("use_cache", self.config.use_cache)
+        return model_inputs
 
     def reset_rope_parameters(self):
         self.model.rotary_emb.reset_parameters()
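The new `prepare_inputs_for_generation` hook follows the standard Hugging Face contract that evaluation harnesses such as lm-eval expect: when a KV cache is present, only the newest token is forwarded for processing. A self-contained restatement of that contract, illustrative and not part of the commit:

    import torch
    from typing import List, Optional, Tuple

    def prepare_inputs_for_generation(input_ids: torch.LongTensor,
                                      past_key_values: Optional[List[Tuple]] = None,
                                      use_cache: bool = True, **kwargs):
        if past_key_values:
            input_ids = input_ids[:, -1:]  # only the last generated token is reprocessed
        model_inputs = {"input_ids": input_ids, "past_key_values": past_key_values}
        model_inputs.update(kwargs)
        model_inputs["use_cache"] = use_cache
        return model_inputs

    ids = torch.tensor([[11, 22, 33]])
    assert prepare_inputs_for_generation(ids)["input_ids"].shape == (1, 3)
    assert prepare_inputs_for_generation(ids, past_key_values=[(None, None)])["input_ids"].shape == (1, 1)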
softmasking_utils.py
ADDED
@@ -0,0 +1,122 @@
+import torch
+import torch.nn.functional as F
+from dataclasses import dataclass
+
+
+@dataclass
+class SMArgs:
+    """Arguments for Softmasking."""
+
+    # SM algorithm
+    sm_alg: str = "none"  # "mixinputs_with_topk" or "mixinputs_with_temp"
+    sm_schedule: str = "none"  # "none", "linear", or "stepwise"
+
+    # lambda(·) parameters
+    scale: float = 0.0  # overall strength of mixing (0 disables mixing)
+    steepness: float = 0.0  # sigmoid steepness for the entropy -> lambda map
+    offset: float = 0.0  # sigmoid offset for the entropy -> lambda map
+
+    # used only when sm_alg == "mixinputs_with_topk"
+    mixinputs_k: int = 3
+    # used only when sm_alg == "mixinputs_with_temp"
+    mixinputs_temp: float = 1.0
+
+
+def get_mixing_factors_for_softmasking(input_ids, logits_prelim, mask_token_id, max_gen_length, sm_args):
+    """Compute the Softmasking output probabilities: a convex combination of the
+    one-hot input tokens and the model's predictive distribution."""
+
+    # Create a one-hot distribution for the original input `xt`.
+    xt_one_hot = F.one_hot(input_ids, num_classes=logits_prelim.shape[-1]).to(logits_prelim.dtype)
+
+    # First get the negative entropy to calculate lambda
+    temperature = sm_args.mixinputs_temp if sm_args.sm_alg == "mixinputs_with_temp" else 1.0
+    neg_entropy, p = get_neg_entropy_and_probabilities(logits_prelim, temperature=temperature)
+
+    # Update scale with the schedule if needed
+    if sm_args.sm_schedule != "none":
+        num_mask_token = (input_ids == mask_token_id).sum().item()
+        scale = get_time_dependence(
+            max_gen_length=max_gen_length,
+            num_mask_token=num_mask_token,
+            scale=sm_args.scale,
+            schedule=sm_args.sm_schedule
+        )
+    else:
+        scale = sm_args.scale
+
+    # Calculate the lambda tensor
+    mask_positions = (input_ids == mask_token_id)
+    lambda_tensor = calculate_lambda_tensor(neg_entropy, mask_positions,
+                                            scale, sm_args.steepness, sm_args.offset)
+
+    if sm_args.sm_alg == "mixinputs_with_topk":
+        # Only fill probabilities for the top-k tokens
+        p = get_only_topk_probs(logits_prelim, sm_args.mixinputs_k)
+
+    # Create the convex combination for the output probabilities
+    p_out = (1 - lambda_tensor) * xt_one_hot \
+        + lambda_tensor * p
+
+    return p_out
+
+
+def get_neg_entropy_and_probabilities(logits, temperature=1.0):
+    """Get negative entropy and probabilities from logits."""
+
+    epsilon = 1e-10
+    p = torch.softmax(logits / temperature, dim=-1)  # (B,T,V)
+    logp = torch.log(p + epsilon)
+    neg_entropy = torch.sum(p * logp, dim=-1)  # (B,T)
+    return neg_entropy, p
+
+
+def calculate_lambda_tensor(neg_entropy, mask_positions, scale, steepness, offset):
+    """Calculate the lambda tensor from negative entropy."""
+
+    if neg_entropy is None or scale == 0.0:
+        # No mixing: a zero lambda of shape (B,T,1) to match the main path below.
+        # (torch.zeros_like(neg_entropy) would fail when neg_entropy is None.)
+        return torch.zeros(*mask_positions.shape, 1, device=mask_positions.device)
+
+    # map negative entropy to [0, scale] using a sigmoid
+    lambda_tensor = scale * torch.sigmoid(steepness * (neg_entropy - offset))
+
+    # apply only on mask positions
+    lambda_tensor = torch.where(mask_positions, lambda_tensor, torch.zeros_like(lambda_tensor))
+    return lambda_tensor.unsqueeze(-1)  # (B,T,1)
+
+
+def get_only_topk_probs(logits, mixinputs_k=3):
+    """Compute a full-vocabulary probability tensor where only the top-k tokens per
+    position receive softmax probabilities and all other entries are zero."""
+
+    topk_logits, topk_indices = torch.topk(logits, k=mixinputs_k, dim=-1)  # (B, T, k)
+
+    topk_probs = torch.softmax(topk_logits, dim=-1)  # (B, T, k)
+    topk_sum = topk_probs.sum(dim=-1)  # (B, T)
+    assert torch.allclose(topk_sum, torch.ones_like(topk_sum), atol=1e-1), \
+        f"Top-k softmax probabilities do not sum to 1: max deviation = {(topk_sum - 1).abs().max().item()}"
+
+    probs_full = torch.zeros_like(logits)  # (B, T, V)
+    probs_full.scatter_(-1, topk_indices, topk_probs)  # fill top-k entries
+    assert torch.sum(probs_full > 0).item() == mixinputs_k * logits.shape[0] * logits.shape[1], \
+        f"Number of non-zero entries in probs_full is incorrect: got {torch.sum(probs_full > 0).item()}, expected {mixinputs_k * logits.shape[0] * logits.shape[1]}"
+
+    return probs_full
+
+
+def get_time_dependence(
+    max_gen_length: int,
+    num_mask_token: int,
+    scale: float,
+    schedule: str,
+    sm_to_hm: bool = True,
+    threshold: float = 0.5,
+) -> float:
+    """Return the scale factor depending on decoding progress."""
+    t = num_mask_token / max_gen_length if max_gen_length else 1.0
+
+    if schedule == "none":
+        return scale
+
+    if schedule == "linear":
+        return scale * (t if sm_to_hm else 1 - t)
+
+    if schedule == "stepwise":
+        cond = t > threshold if sm_to_hm else t < threshold
+        return scale if cond else 0.0
+
+    raise ValueError(f"Unknown schedule: {schedule}")
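Because `p_out` is a convex combination of two valid distributions (the one-hot inputs and the model's softmax), each row remains a valid distribution, and lambda is zeroed off the mask so already-decoded tokens stay exactly one-hot. A toy check of the functions above; shapes and values are illustrative, and it assumes softmasking_utils.py is importable from the working directory:

    import torch
    from softmasking_utils import SMArgs, get_mixing_factors_for_softmasking

    B, T, V = 1, 4, 8
    mask_token_id = V - 1
    torch.manual_seed(0)

    input_ids = torch.full((B, T), mask_token_id)  # everything still masked ...
    input_ids[0, 0] = 3                            # ... except one decoded token
    logits = torch.randn(B, T, V)

    sm_args = SMArgs(sm_alg="mixinputs_with_temp", sm_schedule="linear",
                     scale=0.5, steepness=5.0, offset=-1.0, mixinputs_temp=1.0)

    p_out = get_mixing_factors_for_softmasking(input_ids, logits, mask_token_id,
                                               max_gen_length=T, sm_args=sm_args)
    assert p_out.shape == (B, T, V)
    assert torch.allclose(p_out.sum(-1), torch.ones(B, T), atol=1e-5)

    print(p_out[0, 0].max().item())  # exactly 1.0: the decoded position stays one-hot
    print(p_out[0, 1].max().item())  # below 1.0: mass leaks to the model's prediction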