Add model and code
Browse files- .gitignore +9 -0
- __init__.py +0 -0
- backbone_automodel.py +116 -0
- backbone_custom_modeling_qwen3.py +179 -0
- backbone_encoder_decoder.py +654 -0
- denoiser_base.py +464 -0
- diffusion.py +1 -1
- noise_schedule_noise_schedules.py +80 -0
.gitignore
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
.hf_cache
|
| 3 |
+
.idea
|
| 4 |
+
.ipynb_checkpoints/
|
| 5 |
+
.pytest_cache/
|
| 6 |
+
.ruff_cache/
|
| 7 |
+
.DS_Store
|
| 8 |
+
outputs/
|
| 9 |
+
watch_folder
|
__init__.py
ADDED
|
File without changes
|
backbone_automodel.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Literal
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
from transformers import (
|
| 6 |
+
AutoConfig,
|
| 7 |
+
AutoModel,
|
| 8 |
+
AutoModelForCausalLM,
|
| 9 |
+
AutoModelForMaskedLM,
|
| 10 |
+
DynamicCache,
|
| 11 |
+
)
|
| 12 |
+
from transformers.modeling_outputs import (
|
| 13 |
+
BaseModelOutputWithPast,
|
| 14 |
+
CausalLMOutputWithPast,
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
from .backbone_custom_modeling_qwen3 import CustomQwen3ForCausalLM
|
| 18 |
+
|
| 19 |
+
try:
|
| 20 |
+
from torch.nn.attention.flex_attention import BlockMask
|
| 21 |
+
except ImportError:
|
| 22 |
+
BlockMask = None
|
| 23 |
+
|
| 24 |
+
# Maps the string names accepted by AutoModelFromPreTrained(automodel_cls=...)
# to the corresponding Hugging Face auto-model factory classes.
AUTO_MODEL_CLS = {
    "AutoModel": AutoModel,
    "AutoModelForCausalLM": AutoModelForCausalLM,
    "AutoModelForMaskedLM": AutoModelForMaskedLM,
}
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class AutoModelFromPreTrained(nn.Module):
    """Simple wrapper class that enables using AutoModel from pre-trained.

    Loads a Hugging Face backbone either from pre-trained weights
    (via one of the classes in ``AUTO_MODEL_CLS``) or freshly initialized
    from config (``reinit_model=True``, which instantiates
    ``CustomQwen3ForCausalLM``), optionally truncating the transformer
    stack to ``num_layers`` layers.
    """

    def __init__(
        self,
        automodel_cls: Literal[
            "AutoModel",
            "AutoModelForCausalLM",
            "AutoModelForMaskedLM",
        ],
        pretrained_model_name_or_path: str,
        trust_remote_code: bool = True,
        num_layers: int = -1,  # -1 keeps all layers
        keep_top_layers: bool = False,  # keep the last num_layers instead of first
        reinit_model: bool = False,  # random init from config; skips weight loading
        use_causal_mask: bool = False,  # force causal masking in forward()
        **automodel_init_kwargs,
    ):
        super().__init__()
        self.use_causal_mask = use_causal_mask
        if reinit_model:
            # NOTE(review): if num_layers is left at -1 here, -1 is passed to
            # the config as num_hidden_layers — callers should provide an
            # explicit positive num_layers when reinitializing. TODO confirm.
            auto_config = AutoConfig.from_pretrained(
                pretrained_model_name_or_path,
                num_hidden_layers=num_layers,
                trust_remote_code=trust_remote_code,
                **automodel_init_kwargs,
            )
            self.model = CustomQwen3ForCausalLM(auto_config)
            # self.model = AUTO_MODEL_CLS[automodel_cls].from_config(auto_config)
        else:
            self.model = AUTO_MODEL_CLS[automodel_cls].from_pretrained(
                pretrained_model_name_or_path,
                trust_remote_code=trust_remote_code,
                **automodel_init_kwargs,
            )
            # Resolve -1 to the full depth, then truncate the layer stack.
            num_layers = (
                len(self.model.model.layers) if num_layers == -1 else num_layers
            )
            if keep_top_layers:
                self.model.model.layers = self.model.model.layers[-num_layers:]
            else:
                self.model.model.layers = self.model.model.layers[:num_layers]

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: torch.FloatTensor | BlockMask | None = None,
        position_ids: torch.LongTensor | None = None,
        cache_position: torch.LongTensor | None = None,
        past_key_values: DynamicCache | None = None,
        fix_cache_length: bool = False,  # False for AR, True for diffusion models
        return_updated_cache=False,
        **kwargs,
    ) -> CausalLMOutputWithPast | BaseModelOutputWithPast:
        """Run the wrapped model.

        When ``fix_cache_length`` is True, the per-layer cache lengths are
        recorded before the call and the (in-place extended) DynamicCache is
        truncated back afterwards, so repeated diffusion denoising steps do
        not grow the cache. When ``return_updated_cache`` is True, only the
        updated cache is returned (wrapped in BaseModelOutputWithPast).
        """
        prev_cache_len = None
        if past_key_values is not None and fix_cache_length:
            # Record the current sequence length of each layer's KV cache.
            prev_cache_len = [
                past_key_values[i][0].shape[-2]  # type: ignore
                for i in range(len(past_key_values))
            ]
        if self.use_causal_mask:
            attention_mask = None  # None --> enforces use of causal mask
        model_output = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            cache_position=cache_position,
            past_key_values=past_key_values,
            **kwargs,
        )
        if return_updated_cache:
            return BaseModelOutputWithPast(past_key_values=model_output.past_key_values)
        if (
            prev_cache_len is not None
            and model_output.get("past_key_values", None) is not None
        ):
            # DynamicCache extends along sequence dimension by default;
            # truncate back to original cache len
            for i, cache_len in enumerate(prev_cache_len):
                model_output.past_key_values.key_cache[i] = (
                    model_output.past_key_values.key_cache[i][..., :cache_len, :]
                )
                model_output.past_key_values.value_cache[i] = (
                    model_output.past_key_values.value_cache[i][..., :cache_len, :]
                )
        return model_output
|
backbone_custom_modeling_qwen3.py
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Callable, Optional, Tuple
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
from transformers.models.qwen3.modeling_qwen3 import (
|
| 6 |
+
ALL_ATTENTION_FUNCTIONS,
|
| 7 |
+
Cache,
|
| 8 |
+
FlashAttentionKwargs,
|
| 9 |
+
Qwen3Attention,
|
| 10 |
+
Qwen3Config,
|
| 11 |
+
Qwen3DecoderLayer,
|
| 12 |
+
Qwen3ForCausalLM,
|
| 13 |
+
Qwen3Model,
|
| 14 |
+
eager_attention_forward,
|
| 15 |
+
rotate_half,
|
| 16 |
+
)
|
| 17 |
+
from transformers.processing_utils import Unpack
|
| 18 |
+
from transformers.utils import logging
|
| 19 |
+
|
| 20 |
+
logger = logging.get_logger(__name__)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def custom_apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1, q_start_idx=0):
    """Apply Rotary Position Embedding to the query and key tensors.

    Unlike the stock helper, the query may cover only the tail of the
    sequence (positions >= ``q_start_idx``) while the key spans the full
    sequence, so the cos/sin tables are sliced for the query only.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)
    # Slice the rotary tables to match the (possibly shorter) query length.
    q_cos = cos_b[..., q_start_idx:, :]
    q_sin = sin_b[..., q_start_idx:, :]
    q_embed = q * q_cos + rotate_half(q) * q_sin
    k_embed = k * cos_b + rotate_half(k) * sin_b
    return q_embed, k_embed
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class CustomQwen3Attention(Qwen3Attention):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    Extends Qwen3Attention with ``q_start_idx``: queries are computed only
    for the tail of the sequence (positions >= q_start_idx), while keys and
    values cover the full input — used when encoder hidden states are
    prepended to the decoder input.
    """

    def __init__(self, config: Qwen3Config, layer_idx: int):
        super().__init__(config, layer_idx=layer_idx)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        q_start_idx: int = 0,  # > 0: decoder pass w/encoder inputs in hidden_states
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)
        # Only positions >= q_start_idx produce queries (and hence outputs).
        sa_hidden_sates = hidden_states[:, q_start_idx:, :]
        query_input_shape = sa_hidden_sates.shape[:-1]
        query_hidden_shape = (*query_input_shape, -1, self.head_dim)

        query_states = self.q_norm(
            self.q_proj(sa_hidden_sates).reshape(query_hidden_shape)
        ).transpose(1, 2)
        # Keys/values are projected from the full sequence.
        key_states = self.k_norm(
            self.k_proj(hidden_states).view(hidden_shape)
        ).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        # Rotary tables are sliced for the query inside the helper.
        query_states, key_states = custom_apply_rotary_pos_emb(
            query_states, key_states, cos, sin, q_start_idx=q_start_idx
        )

        if past_key_value is not None:
            # sin and cos are specific to RoPE models
            # cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )

        # NOTE: downcast for flex-attention compatibility
        query_states, key_states = (
            query_states.to(value_states.dtype),
            key_states.to(value_states.dtype),
        )

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[
                self.config._attn_implementation
            ]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=self.sliding_window,  # diff with Llama
            **kwargs,
        )

        # Output covers only the query positions (length = seq_len - q_start_idx).
        attn_output = attn_output.reshape(*query_input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
class CustomQwen3DecoderLayer(Qwen3DecoderLayer):
    """Qwen3 decoder layer whose attention is CustomQwen3Attention.

    When ``q_start_idx > 0`` the attention returns outputs only for the
    tail positions, so the residual connection is sliced to match.
    """

    def __init__(self, config: Qwen3Config, layer_idx: int):
        super().__init__(config, layer_idx=layer_idx)
        # Swap in the q_start_idx-aware attention implementation.
        self.self_attn = CustomQwen3Attention(config=config, layer_idx=layer_idx)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        q_start_idx: int = 0,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[
        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
    ]:
        # Residual only for the positions that produce attention outputs.
        residual = hidden_states[:, q_start_idx:, ...]

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            q_start_idx=q_start_idx,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights,)

        return outputs
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
class CustomQwen3Model(Qwen3Model):
    """Qwen3 backbone whose decoder layers are CustomQwen3DecoderLayer."""

    def __init__(self, config: Qwen3Config):
        super().__init__(config)
        # Replace the stock decoder stack with the q_start_idx-aware layers.
        custom_layers = [
            CustomQwen3DecoderLayer(config, idx)
            for idx in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(custom_layers)
        # Initialize weights and apply final processing
        self.post_init()
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
class CustomQwen3ForCausalLM(Qwen3ForCausalLM):
    """Causal LM head on top of CustomQwen3Model."""

    def __init__(self, config: Qwen3Config):
        super().__init__(config)
        # Swap in the backbone built from custom decoder layers.
        self.model = CustomQwen3Model(config)
        # Initialize weights and apply final processing
        self.post_init()
|
backbone_encoder_decoder.py
ADDED
|
@@ -0,0 +1,654 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from functools import partial
|
| 3 |
+
from typing import Optional, Tuple, Union
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch import nn
|
| 7 |
+
from transformers import AutoConfig, AutoModelForCausalLM
|
| 8 |
+
from transformers.cache_utils import DynamicCache
|
| 9 |
+
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
|
| 10 |
+
from transformers.modeling_outputs import (
|
| 11 |
+
BaseModelOutputWithPast,
|
| 12 |
+
CausalLMOutputWithPast,
|
| 13 |
+
ModelOutput,
|
| 14 |
+
)
|
| 15 |
+
from transformers.processing_utils import Unpack
|
| 16 |
+
from transformers.utils import logging
|
| 17 |
+
|
| 18 |
+
from .backbone_custom_modeling_qwen3 import CustomQwen3ForCausalLM
|
| 19 |
+
|
| 20 |
+
try:
|
| 21 |
+
from torch.nn.attention.flex_attention import BlockMask
|
| 22 |
+
except ImportError:
|
| 23 |
+
BlockMask = None
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
logger = logging.get_logger(__name__)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@dataclass
class EncoderBaseModelOutputWithPast(ModelOutput):
    """Custom (encoder) model output.
    Stores previous decoder and updated encoder cache and encoder last hidden state.
    """

    # Decoder KV cache, passed through unchanged.
    past_key_values: Optional[Union[Tuple[Tuple[torch.FloatTensor]], DynamicCache]] = (
        None
    )
    # Final hidden states produced by the encoder, later consumed by the decoder.
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
    # Encoder KV cache updated with the latest encoder inputs.
    encoder_past_key_values: Optional[
        Union[Tuple[Tuple[torch.FloatTensor]], DynamicCache]
    ] = None
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
@dataclass
class DecoderCausalLMOutputWithPast(ModelOutput):
    """Custom (decoder) model output.
    Stores previous encoder and updated decoder cache and decoder logits.
    """

    # LM-head logits for the decoder (noisy) positions.
    logits: Optional[torch.FloatTensor] = None
    # Decoder KV cache after this forward pass.
    past_key_values: Optional[Union[Tuple[Tuple[torch.FloatTensor]], DynamicCache]] = (
        None
    )
    # Encoder KV cache, passed through unchanged.
    encoder_past_key_values: Optional[
        Union[Tuple[Tuple[torch.FloatTensor]], DynamicCache]
    ] = None
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class LLMasEncoderDecoder(nn.Module):
    """Uses a pre-trained causal LM as an encoder-decoder pair.

    The encoder processes clean tokens; the decoder attends to the encoder's
    last hidden states (prepended to its own embedded inputs, signalled by
    ``q_start_idx``) and produces logits for the noisy tokens. Encoder and
    decoder may share weights (``tie_encoder_decoder_weights``), in which
    case the decoder reuses the **top** layers of the encoder.

    Fixes vs. previous revision:
    - assert message referenced ``self.decoder.layers`` (no such attribute —
      raised AttributeError when the assert fired) instead of
      ``self.decoder.model.layers``;
    - in the non-tied, non-reinit decoder path, ``num_decoder_layers == -1``
      was never resolved to the real layer count before slicing, so the
      default silently dropped one layer (``layers[:-1]`` / ``layers[1:]``).
    """

    def __init__(
        self,
        pretrained_model_name_or_path: str,
        max_length: int,
        attn_backend: str = "sdpa",
        freeze_encoder: bool = False,
        reinit_encoder: bool = False,
        reinit_decoder: bool = False,
        tie_encoder_decoder_weights: bool = False,
        use_encoder_causal_mask: bool = False,
        num_encoder_layers: int = -1,
        num_decoder_layers: int = -1,
        keep_top_encoder_layers: bool = False,
        keep_top_decoder_layers: bool = False,
        use_gradient_checkpointing: bool = False,
        **llm_init_kwargs,
    ):
        assert not (tie_encoder_decoder_weights and reinit_decoder), (
            "Cannot tie encoder-decoder weights and reinitialize decoder."
        )
        assert not (tie_encoder_decoder_weights and freeze_encoder), (
            "Cannot freeze encoder weights when tying encoder-decoder weights."
        )
        super().__init__()
        self.use_encoder_causal_mask = use_encoder_causal_mask
        self.tie_encoder_decoder_weights = tie_encoder_decoder_weights

        if reinit_encoder:
            assert num_encoder_layers > 0
            encoder_config = AutoConfig.from_pretrained(
                pretrained_model_name_or_path,
                trust_remote_code=True,
                num_hidden_layers=num_encoder_layers,
                attn_implementation=attn_backend,
                **llm_init_kwargs,
            )
            self.encoder = CustomQwen3ForCausalLM(encoder_config)
        else:
            self.encoder = CustomQwen3ForCausalLM.from_pretrained(
                pretrained_model_name_or_path,
                trust_remote_code=True,
                attn_implementation=attn_backend,
                **llm_init_kwargs,
            )
            assert num_encoder_layers <= len(self.encoder.model.layers), (
                f"Cannot keep {num_encoder_layers} layers. "
                f"Pre-trained model only has {len(self.encoder.model.layers)} layers."
            )
            # Resolve -1 to the full depth before slicing.
            num_encoder_layers = (
                len(self.encoder.model.layers)
                if num_encoder_layers == -1
                else num_encoder_layers
            )
            if keep_top_encoder_layers:
                self.encoder.model.layers = self.encoder.model.layers[
                    -num_encoder_layers:
                ]
            else:
                self.encoder.model.layers = self.encoder.model.layers[
                    :num_encoder_layers
                ]

        if freeze_encoder:
            # Keep embeddings trainable; freeze everything else.
            for name, param in self.encoder.named_parameters():
                if "embed_tokens" not in name:
                    param.requires_grad = False
        if use_gradient_checkpointing:
            self.encoder.gradient_checkpointing_enable()

        if tie_encoder_decoder_weights:
            self.decoder = self.encoder
            num_decoder_layers = (
                len(self.decoder.model.layers)
                if num_decoder_layers == -1
                else num_decoder_layers
            )
            assert num_decoder_layers <= len(self.decoder.model.layers), (
                f"Cannot keep {num_decoder_layers} layers. "
                f"Pre-trained model only has {len(self.decoder.model.layers)} layers."
            )
            # Keep **top** layers when tying weights
            self.decoder_layer_idxs = list(range(len(self.encoder.model.layers)))[
                -num_decoder_layers:
            ]

        else:
            if reinit_decoder:
                assert num_decoder_layers > 0
                decoder_config = AutoConfig.from_pretrained(
                    pretrained_model_name_or_path,
                    trust_remote_code=True,
                    num_hidden_layers=num_decoder_layers,
                    attn_implementation=attn_backend,
                    **llm_init_kwargs,
                )
                self.decoder = CustomQwen3ForCausalLM(decoder_config)
            else:
                self.decoder = CustomQwen3ForCausalLM.from_pretrained(
                    pretrained_model_name_or_path,
                    trust_remote_code=True,
                    attn_implementation=attn_backend,
                    **llm_init_kwargs,
                )
                assert num_decoder_layers <= len(self.decoder.model.layers), (
                    f"Cannot keep {num_decoder_layers} layers. "
                    f"Pre-trained model only has "
                    f"{len(self.decoder.model.layers)} layers."
                )
                # Resolve -1 to the full depth (mirrors the encoder path);
                # previously -1 leaked into the slices below and dropped a layer.
                num_decoder_layers = (
                    len(self.decoder.model.layers)
                    if num_decoder_layers == -1
                    else num_decoder_layers
                )
                if keep_top_decoder_layers:
                    self.decoder.model.layers = self.decoder.model.layers[
                        -num_decoder_layers:
                    ]
                else:
                    self.decoder.model.layers = self.decoder.model.layers[
                        :num_decoder_layers
                    ]
            # Decoder embeds inputs via the encoder's embed_tokens (see forward).
            del self.decoder.model.embed_tokens
            # if in the original LM, the lm_head is weight-tied to embedding,
            # point decoder lm_head to encoder's (instead of initializing separately)
            if (
                self.encoder.lm_head.weight.data_ptr()
                == self.encoder.model.embed_tokens.weight.data_ptr()
            ):
                self.decoder.lm_head = self.encoder.lm_head
            else:
                del self.encoder.lm_head
            if use_gradient_checkpointing:
                self.decoder.gradient_checkpointing_enable()
        self.max_length = max_length

    def freeze_encoder(self):
        """Freeze all encoder backbone parameters (not the lm_head)."""
        for p in self.encoder.model.parameters():
            p.requires_grad = False

    def unfreeze_encoder(self):
        """Make all encoder backbone parameters trainable again."""
        for p in self.encoder.model.parameters():
            p.requires_grad = True

    # noinspection PyUnusedLocal
    def forward(
        self,
        # Decoder inputs
        input_ids: torch.LongTensor,
        attention_mask: Optional[Union[torch.FloatTensor, BlockMask]] = None,
        position_ids: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        past_key_values: Optional[DynamicCache] = None,
        encoder_last_hidden_state: Optional[torch.FloatTensor] = None,
        # Encoder inputs
        encoder_input_ids: Optional[torch.LongTensor] = None,
        encoder_attention_mask: Optional[Union[torch.FloatTensor, BlockMask]] = None,
        encoder_position_ids: Optional[torch.LongTensor] = None,
        encoder_cache_position: Optional[torch.LongTensor] = None,
        encoder_past_key_values: Optional[DynamicCache] = None,
        # Additional args
        fix_cache_length: bool = True,  # Not used; compatibility with other backbones
        return_updated_cache: bool = False,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> Union[DecoderCausalLMOutputWithPast, EncoderBaseModelOutputWithPast]:
        """Encode clean tokens (optional) and decode noisy tokens.

        During training/eval ``encoder_last_hidden_state`` is None and the
        encoder is run on ``encoder_input_ids``. During generation the
        cached ``encoder_last_hidden_state`` can be passed in directly.
        """
        # During training/eval encoder_last_hidden_state is None.
        # During generation encoder_last_hidden_state can be not None.
        new_seen_tokens = (
            0
            if encoder_last_hidden_state is None
            else encoder_last_hidden_state.shape[1]
        )
        # Encode clean tokens
        if encoder_input_ids is not None:
            if self.use_encoder_causal_mask:
                encoder_attention_mask = None  # None --> enforces use of causal mask
            if encoder_cache_position is None and encoder_position_ids is not None:
                encoder_cache_position = encoder_position_ids[0]
            encoder_output = self.encoder.model(
                input_ids=encoder_input_ids,
                attention_mask=encoder_attention_mask,
                position_ids=encoder_position_ids,
                use_cache=True,
                past_key_values=encoder_past_key_values,
                cache_position=encoder_cache_position,
            )
            if return_updated_cache:
                # encoder_output.past_key_values now contains latest encoder input
                return EncoderBaseModelOutputWithPast(
                    encoder_last_hidden_state=encoder_output.last_hidden_state,
                    encoder_past_key_values=encoder_output.past_key_values,
                    past_key_values=past_key_values,
                )
            encoder_last_hidden_state = encoder_output.last_hidden_state

        # Run decoder with xattn to clean token hidden states
        if encoder_last_hidden_state is None:  # No new encoder tokens
            q_start_idx = 0
            # Decoder has no embed_tokens of its own; reuse the encoder's.
            decoder_hidden_states = self.encoder.model.embed_tokens(input_ids)
            if cache_position is None:
                if position_ids is not None:
                    cache_position = position_ids[0]
                else:
                    past_seen_tokens = (
                        past_key_values.get_seq_length()
                        if past_key_values is not None
                        else 0
                    )
                    cache_position = torch.arange(
                        past_seen_tokens,
                        past_seen_tokens + decoder_hidden_states.shape[1],
                        device=decoder_hidden_states.device,
                    )
            if position_ids is None:
                position_ids = cache_position.unsqueeze(0)
            decoder_position_embeddings = self.decoder.model.rotary_emb(
                decoder_hidden_states, position_ids
            )
        else:
            # Prepend encoder hidden states; only the tail (the noisy tokens)
            # produces decoder outputs.
            q_start_idx = encoder_last_hidden_state.shape[1]
            decoder_hidden_states = self.encoder.model.embed_tokens(input_ids)
            decoder_hidden_states = torch.cat(
                [
                    encoder_last_hidden_state,
                    decoder_hidden_states,
                ],
                dim=1,
            )
            if cache_position is None:
                if position_ids is not None:
                    cache_position = position_ids[0]
                else:
                    past_seen_tokens = (
                        past_key_values.get_seq_length()
                        if past_key_values is not None
                        else 0
                    )
                    cache_position = torch.cat(
                        [
                            torch.arange(  # clean token position ids
                                past_seen_tokens,
                                past_seen_tokens + encoder_last_hidden_state.shape[1],
                                device=decoder_hidden_states.device,
                            ),
                            torch.arange(  # noisy position ids
                                past_seen_tokens + new_seen_tokens,
                                past_seen_tokens + new_seen_tokens + input_ids.shape[1],
                                device=decoder_hidden_states.device,
                            ),
                        ],
                        dim=-1,
                    )
            if position_ids is None:
                position_ids = cache_position.unsqueeze(0)
            decoder_position_embeddings = self.decoder.model.rotary_emb(
                decoder_hidden_states, position_ids
            )

        if hasattr(self.decoder.model, "_update_causal_mask"):  # bc on transformers
            # noinspection PyProtectedMember
            attention_mask = self.decoder.model._update_causal_mask(
                attention_mask=attention_mask,
                input_tensor=decoder_hidden_states,
                cache_position=cache_position,
                past_key_values=past_key_values,
                output_attentions=False,
            )
        for decoder_layer in self.decoder.model.layers:
            layer_idx = decoder_layer.self_attn.layer_idx
            if (
                self.tie_encoder_decoder_weights
                and layer_idx not in self.decoder_layer_idxs
            ):
                # With tied weights, only the top encoder layers act as decoder.
                continue
            # past_key_values gets updated in-place.
            # Record previous length to re-truncate after each layer forward
            if past_key_values is not None and len(past_key_values) > layer_idx:
                prev_cache_len = past_key_values[layer_idx][0].shape[-2]  # type: ignore
            else:
                prev_cache_len = 0
            # Keep the encoder-derived entries; drop the noisy-token entries.
            cache_len = prev_cache_len + new_seen_tokens

            if self.decoder.model.gradient_checkpointing and self.training:
                # noinspection PyProtectedMember
                decoder_hidden_states = self.decoder._gradient_checkpointing_func(
                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                    decoder_hidden_states,  # hidden_states=,
                    attention_mask,  # attention_mask=,
                    position_ids,  # position_ids=,
                    past_key_values,  # past_key_value=,
                    False,  # output_attentions=,
                    True,  # use_cache=,
                    cache_position,  # cache_position=,
                    decoder_position_embeddings,  # position_embeddings=,
                    q_start_idx,  # q_start_idx=
                )[0]  # Shape: (input_ids.shape[0], input_ids.shape[1], hidden_dim)
            else:
                decoder_hidden_states = decoder_layer(
                    hidden_states=decoder_hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=False,
                    use_cache=True,
                    cache_position=cache_position,
                    position_embeddings=decoder_position_embeddings,
                    q_start_idx=q_start_idx,  # Indicates where to slice output
                    **flash_attn_kwargs,
                )[0]  # Shape: (input_ids.shape[0], input_ids.shape[1], hidden_dim)
            # Update decoder_hidden_states
            if q_start_idx > 0:
                # Re-prepend encoder states so the next layer sees the full sequence.
                decoder_hidden_states = torch.cat(
                    [
                        encoder_last_hidden_state,
                        decoder_hidden_states,
                    ],
                    dim=1,
                )

            if past_key_values is not None:
                # DynamicCache extends along sequence dimension by default;
                # truncate back to original cache len + encoder output length
                past_key_values.key_cache[layer_idx] = past_key_values.key_cache[
                    layer_idx
                ][..., :cache_len, :]
                past_key_values.value_cache[layer_idx] = past_key_values.value_cache[
                    layer_idx
                ][..., :cache_len, :]
        decoder_hidden_states = self.decoder.model.norm(
            decoder_hidden_states[:, q_start_idx:, :]
        )
        logits = self.decoder.lm_head(decoder_hidden_states)
        return DecoderCausalLMOutputWithPast(
            logits=logits,
            past_key_values=past_key_values,
            encoder_past_key_values=encoder_past_key_values,
            # Do not need to store encoder_last_hidden_state.
            # If it was passed in, then it has become part of the past_key_values cache.
        )
|
| 392 |
+
|
| 393 |
+
|
| 394 |
+
class LLMasEncoderDecoderShareKV(nn.Module):
    """Encoder-decoder built from a pre-trained causal LM in which the decoder
    attends directly to the encoder's key/value cache.

    The encoder runs once over the clean (context) tokens and populates a
    ``DynamicCache``. During the decoder pass, each decoder layer appends its
    own K/V entries to that cache while attending, and the cache is truncated
    back to its pre-layer length afterwards, so the stored cache always holds
    encoder-token K/V only.

    Args:
        pretrained_model_name_or_path: HF hub id or local path of the base LM.
        max_length: Maximum sequence length supported by the model.
        attn_backend: Attention implementation ("sdpa", "flash_attention_2", ...).
        freeze_encoder: If True, freeze all encoder params except embeddings.
        reinit_encoder / reinit_decoder: Re-initialize weights from config
            instead of loading pre-trained weights.
        tie_encoder_decoder_weights: Use the same module for encoder and
            decoder; mutually exclusive with ``reinit_decoder`` and
            ``freeze_encoder``.
        use_encoder_causal_mask: Force a causal mask in the encoder.
        num_encoder_layers / num_decoder_layers: Number of layers to keep
            (-1 keeps all).
        keep_top_encoder_layers / keep_top_decoder_layers: Keep the top
            (instead of bottom) layers when truncating.
        use_gradient_checkpointing: Enable gradient checkpointing.
        **llm_init_kwargs: Forwarded to ``AutoConfig`` / ``from_pretrained``.
    """

    def __init__(
        self,
        pretrained_model_name_or_path: str,
        max_length: int,
        attn_backend: str = "sdpa",
        freeze_encoder: bool = False,
        reinit_encoder: bool = False,
        reinit_decoder: bool = False,
        tie_encoder_decoder_weights: bool = False,
        use_encoder_causal_mask: bool = False,
        num_encoder_layers: int = -1,
        num_decoder_layers: int = -1,
        keep_top_encoder_layers: bool = False,
        keep_top_decoder_layers: bool = False,
        use_gradient_checkpointing: bool = False,
        **llm_init_kwargs,
    ):
        assert not (tie_encoder_decoder_weights and reinit_decoder), (
            "Cannot tie encoder-decoder weights and reinitialize decoder."
        )
        assert not (tie_encoder_decoder_weights and freeze_encoder), (
            "Cannot freeze encoder weights when tying encoder-decoder weights."
        )
        super().__init__()
        self.use_encoder_causal_mask = use_encoder_causal_mask
        self.tie_encoder_decoder_weights = tie_encoder_decoder_weights

        if reinit_encoder:
            assert num_encoder_layers > 0
            encoder_config = AutoConfig.from_pretrained(
                pretrained_model_name_or_path,
                trust_remote_code=True,
                num_hidden_layers=num_encoder_layers,
                attn_implementation=attn_backend,
                **llm_init_kwargs,
            )
            self.encoder = AutoModelForCausalLM.from_config(encoder_config)
        else:
            self.encoder = AutoModelForCausalLM.from_pretrained(
                pretrained_model_name_or_path,
                trust_remote_code=True,
                attn_implementation=attn_backend,
                **llm_init_kwargs,
            )
            assert num_encoder_layers <= len(self.encoder.model.layers), (
                f"Cannot keep {num_encoder_layers} layers. "
                f"Pre-trained model only has {len(self.encoder.model.layers)} layers."
            )
            num_encoder_layers = (
                len(self.encoder.model.layers)
                if num_encoder_layers == -1
                else num_encoder_layers
            )
            if keep_top_encoder_layers:
                self.encoder.model.layers = self.encoder.model.layers[
                    -num_encoder_layers:
                ]
            else:
                self.encoder.model.layers = self.encoder.model.layers[
                    :num_encoder_layers
                ]

        if freeze_encoder:
            # Embeddings stay trainable even when the encoder body is frozen.
            for name, param in self.encoder.named_parameters():
                if "embed_tokens" not in name:
                    param.requires_grad = False
        if use_gradient_checkpointing:
            self.encoder.gradient_checkpointing_enable()

        if tie_encoder_decoder_weights:
            self.decoder = self.encoder
            num_decoder_layers = (
                len(self.decoder.model.layers)
                if num_decoder_layers == -1
                else num_decoder_layers
            )
            assert num_decoder_layers <= len(self.decoder.model.layers), (
                f"Cannot keep {num_decoder_layers} layers. "
                f"Pre-trained model only has {len(self.decoder.model.layers)} layers."
            )
            # Keep **top** layers when tying weights
            self.decoder_layer_idxs = list(range(len(self.encoder.model.layers)))[
                -num_decoder_layers:
            ]

        else:
            if reinit_decoder:
                assert num_decoder_layers > 0
                decoder_config = AutoConfig.from_pretrained(
                    pretrained_model_name_or_path,
                    trust_remote_code=True,
                    num_hidden_layers=num_decoder_layers,
                    attn_implementation=attn_backend,
                    **llm_init_kwargs,
                )
                # BUGFIX: AutoModelForCausalLM cannot be instantiated directly;
                # it must be built via `from_config` (as the encoder path does).
                self.decoder = AutoModelForCausalLM.from_config(decoder_config)
            else:
                self.decoder = AutoModelForCausalLM.from_pretrained(
                    pretrained_model_name_or_path,
                    trust_remote_code=True,
                    attn_implementation=attn_backend,
                    **llm_init_kwargs,
                )
                # BUGFIX: message previously read `self.decoder.layers`, which
                # does not exist and would raise AttributeError while the
                # assertion message was being formatted.
                assert num_decoder_layers <= len(self.decoder.model.layers), (
                    f"Cannot keep {num_decoder_layers} layers. "
                    f"Pre-trained model only has {len(self.decoder.model.layers)} layers."
                )
                if keep_top_decoder_layers:
                    self.decoder.model.layers = self.decoder.model.layers[
                        -num_decoder_layers:
                    ]
                else:
                    self.decoder.model.layers = self.decoder.model.layers[
                        :num_decoder_layers
                    ]
            # Decoder reuses the encoder's embedding table (see forward()).
            del self.decoder.model.embed_tokens
        # Even for frozen encoder, ensure embedding tokens are trainable
        self.encoder.model.embed_tokens.requires_grad_(True)
        # The encoder's final layer only needs to produce K/V for the shared
        # cache, so its query/output projections, MLP, layernorms, and the
        # final model norm are never used and can be frozen.
        unused_self_attn_params = ["o_proj", "q_norm", "q_proj"]
        unused_layernorm_params = ["input_layernorm", "post_attention_layernorm"]
        for unused_param in unused_self_attn_params:
            if hasattr(self.encoder.model.layers[-1].self_attn, unused_param):
                getattr(
                    self.encoder.model.layers[-1].self_attn, unused_param
                ).requires_grad_(False)
        self.encoder.model.layers[-1].mlp.requires_grad_(False)
        self.encoder.model.norm.requires_grad_(False)
        for unused_param in unused_layernorm_params:
            if hasattr(self.encoder.model.layers[-1], unused_param):
                getattr(self.encoder.model.layers[-1], unused_param).requires_grad_(
                    False
                )
        # if in the original LM, the lm_head is weight-tied to embedding,
        # point decoder lm_head to encoder's (instead of initializing separately)
        if (
            self.encoder.lm_head.weight.data_ptr()
            == self.encoder.model.embed_tokens.weight.data_ptr()
        ):
            self.decoder.lm_head = self.encoder.lm_head
        else:
            del self.encoder.lm_head
        if use_gradient_checkpointing:
            self.decoder.gradient_checkpointing_enable()
        self.max_length = max_length

    def freeze_encoder(self):
        """Freeze all encoder parameters (including embeddings)."""
        for p in self.encoder.model.parameters():
            p.requires_grad = False

    def unfreeze_encoder(self):
        """Unfreeze all encoder parameters."""
        for p in self.encoder.model.parameters():
            p.requires_grad = True

    # noinspection PyUnusedLocal
    def forward(
        self,
        # Decoder inputs
        input_ids: torch.LongTensor,
        attention_mask: Optional[Union[torch.FloatTensor, BlockMask]] = None,
        position_ids: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        past_key_values: Optional[DynamicCache] = None,
        encoder_last_hidden_state: Optional[torch.FloatTensor] = None,  # Not used
        # Encoder inputs
        encoder_input_ids: Optional[torch.LongTensor] = None,
        encoder_attention_mask: Optional[Union[torch.FloatTensor, BlockMask]] = None,
        encoder_position_ids: Optional[torch.LongTensor] = None,
        encoder_cache_position: Optional[torch.LongTensor] = None,
        encoder_past_key_values: Optional[DynamicCache] = None,  # Not used
        # Additional args
        fix_cache_length: bool = True,  # Not used; compatibility with other backbones
        return_updated_cache: bool = False,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> Union[CausalLMOutputWithPast, BaseModelOutputWithPast]:
        """Run the (optional) encoder pass and the KV-sharing decoder pass.

        If ``encoder_input_ids`` is given, the encoder extends
        ``past_key_values`` with K/V for those tokens; with
        ``return_updated_cache=True`` only the updated cache is returned.
        Otherwise the decoder consumes the cache and returns logits.
        """
        # Encode clean tokens
        if encoder_input_ids is not None:
            if self.use_encoder_causal_mask:
                encoder_attention_mask = None  # None --> enforces use of causal mask
            if encoder_cache_position is None and encoder_position_ids is not None:
                encoder_cache_position = encoder_position_ids[0]
            past_key_values = self.encoder.model(
                input_ids=encoder_input_ids,
                attention_mask=encoder_attention_mask,
                position_ids=encoder_position_ids,
                use_cache=True,
                past_key_values=past_key_values,
                cache_position=encoder_cache_position,
            ).past_key_values
            if return_updated_cache:
                # encoder_output.past_key_values now contains latest encoder input
                return BaseModelOutputWithPast(
                    past_key_values=past_key_values,
                )

        # Run decoder with xattn to clean token hidden states
        decoder_hidden_states = self.encoder.model.embed_tokens(input_ids)
        if cache_position is None:
            if position_ids is not None:
                cache_position = position_ids[0]
            else:  # During training / validation position_ids are not provided
                cache_position = torch.arange(
                    decoder_hidden_states.shape[1],
                    device=decoder_hidden_states.device,
                )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)
        decoder_position_embeddings = self.decoder.model.rotary_emb(
            decoder_hidden_states, position_ids
        )

        if hasattr(self.decoder.model, "_update_causal_mask"):  # bc on transformers
            # noinspection PyProtectedMember
            attention_mask = self.decoder.model._update_causal_mask(
                attention_mask=attention_mask,
                input_tensor=decoder_hidden_states,
                cache_position=cache_position,
                past_key_values=past_key_values,
                output_attentions=False,
            )
        for decoder_layer in self.decoder.model.layers:
            layer_idx = decoder_layer.self_attn.layer_idx
            if (
                self.tie_encoder_decoder_weights
                and layer_idx not in self.decoder_layer_idxs
            ):
                continue
            # past_key_values gets updated in-place.
            # Record previous length to truncate after each layer forward
            if past_key_values is not None and len(past_key_values) > layer_idx:
                prev_cache_len = past_key_values[layer_idx][0].shape[-2]  # type: ignore
            else:
                prev_cache_len = 0

            decoder_hidden_states = decoder_layer(
                hidden_states=decoder_hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=False,
                use_cache=True,
                cache_position=position_ids[0],
                position_embeddings=decoder_position_embeddings,
                **flash_attn_kwargs,
            )[0]  # Shape: (input_ids.shape[0], input_ids.shape[1], hidden_dim)

            if past_key_values is not None:
                # DynamicCache extends along sequence dimension by default;
                # truncate back to original cache len + encoder output length
                past_key_values.key_cache[layer_idx] = past_key_values.key_cache[
                    layer_idx
                ][..., :prev_cache_len, :]
                past_key_values.value_cache[layer_idx] = past_key_values.value_cache[
                    layer_idx
                ][..., :prev_cache_len, :]
        decoder_hidden_states = self.decoder.model.norm(decoder_hidden_states)
        logits = self.decoder.lm_head(decoder_hidden_states)
        return CausalLMOutputWithPast(
            logits=logits,
            past_key_values=past_key_values,
        )
denoiser_base.py
ADDED
|
@@ -0,0 +1,464 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import inspect
|
| 3 |
+
import sys
|
| 4 |
+
from abc import ABC, abstractmethod
|
| 5 |
+
from collections import OrderedDict
|
| 6 |
+
from dataclasses import dataclass, field
|
| 7 |
+
from typing import Any, Dict, Optional, Tuple, Union
|
| 8 |
+
|
| 9 |
+
import hydra.utils
|
| 10 |
+
import torch
|
| 11 |
+
from hydra.errors import InstantiationException
|
| 12 |
+
from transformers import (
|
| 13 |
+
AutoTokenizer,
|
| 14 |
+
DynamicCache,
|
| 15 |
+
GenerationConfig,
|
| 16 |
+
LogitsProcessorList,
|
| 17 |
+
PretrainedConfig,
|
| 18 |
+
PreTrainedModel,
|
| 19 |
+
StoppingCriteriaList,
|
| 20 |
+
)
|
| 21 |
+
from transformers.cache_utils import Cache
|
| 22 |
+
from transformers.generation.utils import GenerateOutput
|
| 23 |
+
from transformers.modeling_outputs import ModelOutput
|
| 24 |
+
|
| 25 |
+
# Local imports not used, but added here so that HF push_to_hub adds them to model repo
|
| 26 |
+
# noinspection PyUnresolvedReferences
|
| 27 |
+
from .backbone_automodel import AutoModelFromPreTrained # noqa: F401
|
| 28 |
+
from .backbone_encoder_decoder import ( # noqa: F401
|
| 29 |
+
LLMasEncoderDecoder,
|
| 30 |
+
LLMasEncoderDecoderShareKV,
|
| 31 |
+
)
|
| 32 |
+
from .noise_schedule_noise_schedules import ( # noqa: F401
|
| 33 |
+
CosineNoise,
|
| 34 |
+
ExponentialNoise,
|
| 35 |
+
LinearNoise,
|
| 36 |
+
LogarithmicNoise,
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
@dataclass
class DenoiserInput(OrderedDict):
    """Input to the denoiser model.

    Bundles the (noised) token ids with masks, noise-schedule values, and any
    extra keyword arguments that are forwarded verbatim to the backbone
    (see ``Denoiser._backbone_forward``).
    """

    xt: torch.LongTensor  # (B, L) token_ids
    x0: Optional[torch.LongTensor] = None  # (B, L) token_ids (not used in gen.)
    attention_mask: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Union[torch.FloatTensor, Cache]] = None
    context_mask: Optional[torch.FloatTensor] = None  # indicator for context tokens
    tokens_mask: Optional[torch.FloatTensor] = None  # (B, L)
    t: Optional[torch.FloatTensor] = None  # (B,) | # (B, L)
    alpha_t: Optional[torch.FloatTensor] = None  # (B,) | (B, 1|L) | (B, 1|L, 1)
    alpha_t_prime: Optional[torch.FloatTensor] = None  # (B,) | (B, 1|L) | (B, 1|L, 1)
    # Extra kwargs splatted into the backbone forward call.
    backbone_kwargs: dict[str, Any] = field(default_factory=dict)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
@dataclass
class LossAndNllOutput(OrderedDict):
    """Loss output for denoiser models.

    Returned by ``Denoiser._compute_loss`` implementations.
    """

    loss: torch.FloatTensor  # training loss
    nlls: torch.FloatTensor  # negative log-likelihood values
    # Auxiliary loss terms / metrics keyed by name.
    other_loss_terms: dict = field(default_factory=dict)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
@dataclass
class DenoiserOutput(ModelOutput):
    """Output of the denoiser model (returned by ``Denoiser.forward``)."""

    # Output of `_forward`; by default log-probabilities (log_softmax of logits).
    denoiser_output: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None  # raw backbone logits
    tokens_mask: Optional[torch.FloatTensor] = None  # Which tokens contribute to loss
    past_key_values: Optional[Cache] = None
    loss: Optional[torch.FloatTensor] = None  # None when compute_loss=False
    nlls: Optional[torch.FloatTensor] = None
    other_loss_terms: Optional[dict[str, Any]] = None
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class DenoiserConfig(PretrainedConfig):
    """Configuration class for Denoiser models.

    This class is used to initialize the model and contains all the necessary
    parameters for the model's architecture.

    Sub-configs (``backbone_config``, ``noise_config``) are plain dicts that
    are later instantiated via hydra in ``Denoiser.__init__``.
    """

    model_type = "denoiser"

    def __init__(
        self,
        length: Optional[int] = None,
        backbone_config: Optional[Dict[str, Any]] = None,
        noise_config: Optional[Dict[str, Any]] = None,
        tokenization_config: Optional[Dict[str, Any]] = None,
        time_conditioned_backbone: Optional[bool] = None,
        attn_backend: str = "sdpa",  # "sdpa", "flash_attention_2", "flex_attention"
        train_on_context: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Mirror tokenizer-derived ids/sizes as top-level config attributes so
        # downstream code can read e.g. `config.mask_token_id` directly.
        for v in [
            "vocab_size",
            "mask_token_id",
            "pad_token_id",
            "bos_token_id",
            "eos_token_id",
            "pad_vocab_size_multiple",
        ]:
            if tokenization_config is not None and (
                getattr(self, v, None) is None or v in tokenization_config
            ):
                setattr(self, v, tokenization_config.get(v, None))
            else:
                # NOTE(review): this branch also fires when tokenization_config
                # is provided but lacks the key while the attribute already has
                # a non-None value (e.g. set via **kwargs by PretrainedConfig),
                # resetting it to None — confirm this clobbering is intentional.
                setattr(self, v, None)
        self.backbone_config = backbone_config
        self.noise_config = noise_config
        self.tokenization_config = tokenization_config
        self.length = length
        self.time_conditioned_backbone = time_conditioned_backbone
        self.attn_backend = attn_backend
        self.train_on_context = train_on_context
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
class Denoiser(ABC, PreTrainedModel):
|
| 123 |
+
"""Abstract base class for denoising models.
|
| 124 |
+
|
| 125 |
+
This class defines the interface for AR, Diffusion, and Flow-based parametrizations.
|
| 126 |
+
"""
|
| 127 |
+
|
| 128 |
+
config_class = DenoiserConfig
|
| 129 |
+
|
| 130 |
+
def __init__(
|
| 131 |
+
self,
|
| 132 |
+
config: DenoiserConfig,
|
| 133 |
+
**kwargs,
|
| 134 |
+
):
|
| 135 |
+
"""
|
| 136 |
+
Initialize the Denoiser with a configuration and optional dataset type.
|
| 137 |
+
|
| 138 |
+
Parameters:
|
| 139 |
+
config (Any): Configuration object for the model.
|
| 140 |
+
"""
|
| 141 |
+
super().__init__(config)
|
| 142 |
+
self.config = config
|
| 143 |
+
self.vocab_size = config.vocab_size
|
| 144 |
+
self.mask_token_id = config.mask_token_id
|
| 145 |
+
self.pad_token_id = config.pad_token_id
|
| 146 |
+
self.bos_token_id = config.bos_token_id
|
| 147 |
+
self.eos_token_id = config.eos_token_id
|
| 148 |
+
try:
|
| 149 |
+
self.backbone = hydra.utils.instantiate(config.backbone_config)
|
| 150 |
+
except InstantiationException:
|
| 151 |
+
# When using HF and `from_pretrained`, the modules specified in `_target_`
|
| 152 |
+
# fields in our configs are already being imported under a name with the
|
| 153 |
+
# following format: transformers_modules.<repo_id>.<commit_id>.
|
| 154 |
+
# When hydra attempts to instantiate and calls importlib under the hood, the
|
| 155 |
+
# desired module is not found.
|
| 156 |
+
# The snippet below aliases the desired module, enabling seamless use of
|
| 157 |
+
# `hydra.utils.instantiate`.
|
| 158 |
+
sys_modules = copy.deepcopy(list(sys.modules.keys()))
|
| 159 |
+
repo_root_module = ".".join(__name__.split(".")[:-1])
|
| 160 |
+
for name in sys_modules:
|
| 161 |
+
if name.startswith(repo_root_module):
|
| 162 |
+
short = name.split(".")[-1]
|
| 163 |
+
if short not in sys.modules:
|
| 164 |
+
sys.modules[short] = sys.modules[name]
|
| 165 |
+
del sys_modules
|
| 166 |
+
self.backbone = hydra.utils.instantiate(config.backbone_config)
|
| 167 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
| 168 |
+
config.tokenizer_name,
|
| 169 |
+
trust_remote_code=True,
|
| 170 |
+
)
|
| 171 |
+
self.noise_schedule = (
|
| 172 |
+
hydra.utils.instantiate(config.noise_config)
|
| 173 |
+
if config.noise_config is not None
|
| 174 |
+
else None
|
| 175 |
+
)
|
| 176 |
+
self.time_conditioned_backbone = (
|
| 177 |
+
config.time_conditioned_backbone
|
| 178 |
+
if config.time_conditioned_backbone is not None
|
| 179 |
+
else "noise" in inspect.getfullargspec(self.backbone.forward).args
|
| 180 |
+
)
|
| 181 |
+
# List that can contain any parameters that should not be pushed to HF,
|
| 182 |
+
# e.g., registered buffers for static attention masks
|
| 183 |
+
self.skip_params_for_push = []
|
| 184 |
+
|
| 185 |
+
@abstractmethod
|
| 186 |
+
def _prepare_inputs(
|
| 187 |
+
self,
|
| 188 |
+
input_ids: torch.LongTensor,
|
| 189 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 190 |
+
context_mask: Optional[torch.FloatTensor] = None,
|
| 191 |
+
t: Optional[torch.FloatTensor] = None,
|
| 192 |
+
past_key_values: Optional[Cache] = None,
|
| 193 |
+
) -> DenoiserInput:
|
| 194 |
+
"""
|
| 195 |
+
Prepare inputs for the model.
|
| 196 |
+
|
| 197 |
+
Parameters:
|
| 198 |
+
input_ids (LongTensor): Input tensor to the model.
|
| 199 |
+
attention_mask (Optional[FloatTensor]): Attention mask for the model.
|
| 200 |
+
t (Optional[FloatTensor]): Time step for the model.
|
| 201 |
+
past_key_values (Optional[Cache]): Past key values for the model.
|
| 202 |
+
Returns:
|
| 203 |
+
Denoiser inputs.
|
| 204 |
+
"""
|
| 205 |
+
raise NotImplementedError("Denoiser subclasses must implement _prepare_inputs")
|
| 206 |
+
|
| 207 |
+
def _prepare_inputs_inference(
|
| 208 |
+
self,
|
| 209 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 210 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 211 |
+
context: Optional[torch.LongTensor] = None,
|
| 212 |
+
context_mask: Optional[torch.FloatTensor] = None,
|
| 213 |
+
cache: Optional[Dict[str, Any]] = None,
|
| 214 |
+
**backbone_kwargs: Any,
|
| 215 |
+
) -> Tuple[DenoiserInput, Dict[str, Any]]:
|
| 216 |
+
raise NotImplementedError(
|
| 217 |
+
"Denoiser subclasses must implement _prepare_inputs_inference"
|
| 218 |
+
)
|
| 219 |
+
# assert input_ids is not None or context is not None, (
|
| 220 |
+
# "Must provide either input_ids or context."
|
| 221 |
+
# )
|
| 222 |
+
# cache = cache if cache is not None else {}
|
| 223 |
+
# past_key_values = cache.pop("past_key_values", DynamicCache())
|
| 224 |
+
# if context is not None:
|
| 225 |
+
# if input_ids is not None:
|
| 226 |
+
# if context_mask is None:
|
| 227 |
+
# context_mask = torch.cat(
|
| 228 |
+
# [torch.ones_like(context), torch.zeros_like(input_ids)], dim=-1
|
| 229 |
+
# )
|
| 230 |
+
# input_ids = torch.cat([context, input_ids], dim=-1)
|
| 231 |
+
# else:
|
| 232 |
+
# input_ids = context
|
| 233 |
+
# context_mask = torch.ones_like(input_ids)
|
| 234 |
+
# if attention_mask is None:
|
| 235 |
+
# cache_length = self._get_past_key_values_seq_length(past_key_values)
|
| 236 |
+
# full_seq_length = cache_length + input_ids.shape[-1]
|
| 237 |
+
# attention_mask = torch.ones(
|
| 238 |
+
# (input_ids.shape[0], 1, input_ids.shape[1], full_seq_length),
|
| 239 |
+
# device=input_ids.device,
|
| 240 |
+
# ) # Make attention mask 4D
|
| 241 |
+
# attention_mask = self._preprocess_attention_mask(
|
| 242 |
+
# attention_mask, dtype=torch.float
|
| 243 |
+
# )
|
| 244 |
+
# return DenoiserInput(
|
| 245 |
+
# xt=input_ids,
|
| 246 |
+
# attention_mask=attention_mask,
|
| 247 |
+
# past_key_values=past_key_values,
|
| 248 |
+
# context_mask=context_mask,
|
| 249 |
+
# backbone_kwargs=backbone_kwargs,
|
| 250 |
+
# ), cache
|
| 251 |
+
|
| 252 |
+
@abstractmethod
|
| 253 |
+
def _compute_loss(
|
| 254 |
+
self,
|
| 255 |
+
model_output: torch.FloatTensor,
|
| 256 |
+
denoiser_inputs: DenoiserInput,
|
| 257 |
+
**kwargs: Any,
|
| 258 |
+
) -> LossAndNllOutput:
|
| 259 |
+
"""
|
| 260 |
+
Compute the loss for the denoising model.
|
| 261 |
+
|
| 262 |
+
Parameters:
|
| 263 |
+
model_output (FloatTensor): Output tensor from self.forward.
|
| 264 |
+
denoiser_inputs (DenoiserInput): Inputs passed to the denoiser model.
|
| 265 |
+
|
| 266 |
+
Returns:
|
| 267 |
+
LossAndNllOutput: loss (FloatTensor) and nlls (FloatTensor).
|
| 268 |
+
"""
|
| 269 |
+
raise NotImplementedError("Denoiser subclasses must implement _compute_loss")
|
| 270 |
+
|
| 271 |
+
def _forward(
|
| 272 |
+
self,
|
| 273 |
+
backbone_output: torch.FloatTensor,
|
| 274 |
+
denoiser_inputs: DenoiserInput,
|
| 275 |
+
**kwargs: Any,
|
| 276 |
+
) -> torch.FloatTensor:
|
| 277 |
+
"""
|
| 278 |
+
Forward pass for the denoiser model returns probabilities over denoised
|
| 279 |
+
sequence.
|
| 280 |
+
|
| 281 |
+
Some classes may need to override this method.
|
| 282 |
+
|
| 283 |
+
Parameters:
|
| 284 |
+
backbone_output (FloatTensor): Output tensor from the backbone model.
|
| 285 |
+
denoiser_inputs (DenoiserInput): Inputs passed to the denoiser model.
|
| 286 |
+
|
| 287 |
+
Returns:
|
| 288 |
+
Model outputs (FloatTensor).
|
| 289 |
+
"""
|
| 290 |
+
return torch.log_softmax(backbone_output, dim=-1) # type: ignore
|
| 291 |
+
|
| 292 |
+
def _backbone_forward(
|
| 293 |
+
self,
|
| 294 |
+
denoiser_inputs: DenoiserInput,
|
| 295 |
+
**backbone_kwargs: Any,
|
| 296 |
+
) -> ModelOutput:
|
| 297 |
+
"""Forward pass for the backbone model (should return logits).
|
| 298 |
+
|
| 299 |
+
Some classes may need to override this method.
|
| 300 |
+
|
| 301 |
+
Parameters:
|
| 302 |
+
denoiser_inputs (DenoiserInput): Inputs passed to the denoiser model.
|
| 303 |
+
return_updated_cache (bool): If True, return past_key_values instead of
|
| 304 |
+
logits.
|
| 305 |
+
|
| 306 |
+
Returns:
|
| 307 |
+
Backbone output (ModelOutput instance).
|
| 308 |
+
"""
|
| 309 |
+
if self.time_conditioned_backbone:
|
| 310 |
+
return self.backbone(
|
| 311 |
+
denoiser_inputs.xt,
|
| 312 |
+
attention_mask=denoiser_inputs.attention_mask,
|
| 313 |
+
past_key_values=denoiser_inputs.past_key_values,
|
| 314 |
+
noise=denoiser_inputs.alpha_t,
|
| 315 |
+
**denoiser_inputs.backbone_kwargs,
|
| 316 |
+
**backbone_kwargs,
|
| 317 |
+
)
|
| 318 |
+
return self.backbone(
|
| 319 |
+
denoiser_inputs.xt,
|
| 320 |
+
attention_mask=denoiser_inputs.attention_mask,
|
| 321 |
+
past_key_values=denoiser_inputs.past_key_values,
|
| 322 |
+
**denoiser_inputs.backbone_kwargs,
|
| 323 |
+
**backbone_kwargs,
|
| 324 |
+
)
|
| 325 |
+
|
| 326 |
+
def forward(
|
| 327 |
+
self,
|
| 328 |
+
input_ids: torch.LongTensor,
|
| 329 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 330 |
+
context_mask: Optional[torch.FloatTensor] = None,
|
| 331 |
+
t: Optional[torch.FloatTensor] = None,
|
| 332 |
+
past_key_values: Optional[Cache] = None,
|
| 333 |
+
compute_loss: Optional[bool] = True,
|
| 334 |
+
**kwargs,
|
| 335 |
+
) -> DenoiserOutput:
|
| 336 |
+
"""
|
| 337 |
+
Perform a forward pass through the denoising model and
|
| 338 |
+
(optionally) compute the loss.
|
| 339 |
+
|
| 340 |
+
Parameters:
|
| 341 |
+
input_ids (LongTensor): Input tensor to the model.
|
| 342 |
+
attention_mask (Optional[FloatTensor]): Attention mask for the model.
|
| 343 |
+
context_mask (Optional[FloatTensor]): Indicator for context tokens.
|
| 344 |
+
t (Optional[FloatTensor]): Denoising time step for the model.
|
| 345 |
+
past_key_values (Optional[Cache]): KV cache.
|
| 346 |
+
compute_loss (Optional[bool]): Flag to compute loss.
|
| 347 |
+
|
| 348 |
+
Returns:
|
| 349 |
+
DenoiserOutput
|
| 350 |
+
"""
|
| 351 |
+
denoiser_inputs = self._prepare_inputs(
|
| 352 |
+
input_ids=input_ids,
|
| 353 |
+
attention_mask=attention_mask,
|
| 354 |
+
context_mask=context_mask,
|
| 355 |
+
past_key_values=past_key_values,
|
| 356 |
+
t=t,
|
| 357 |
+
)
|
| 358 |
+
|
| 359 |
+
backbone_output = self._backbone_forward(denoiser_inputs, **kwargs)
|
| 360 |
+
new_past_key_values = getattr(backbone_output, "past_key_values", None)
|
| 361 |
+
backbone_output = getattr(backbone_output, "logits", backbone_output[0])
|
| 362 |
+
denoiser_output = self._forward(
|
| 363 |
+
backbone_output,
|
| 364 |
+
denoiser_inputs,
|
| 365 |
+
**kwargs,
|
| 366 |
+
)
|
| 367 |
+
|
| 368 |
+
if compute_loss:
|
| 369 |
+
loss_and_nll = self._compute_loss(
|
| 370 |
+
model_output=denoiser_output, denoiser_inputs=denoiser_inputs, **kwargs
|
| 371 |
+
)
|
| 372 |
+
loss = loss_and_nll.loss
|
| 373 |
+
nlls = loss_and_nll.nlls
|
| 374 |
+
other_loss_terms = loss_and_nll.other_loss_terms
|
| 375 |
+
else:
|
| 376 |
+
loss, nlls = None, None
|
| 377 |
+
other_loss_terms = {}
|
| 378 |
+
|
| 379 |
+
return DenoiserOutput(
|
| 380 |
+
denoiser_output=denoiser_output,
|
| 381 |
+
logits=backbone_output,
|
| 382 |
+
past_key_values=new_past_key_values,
|
| 383 |
+
tokens_mask=denoiser_inputs.tokens_mask,
|
| 384 |
+
loss=loss,
|
| 385 |
+
nlls=nlls,
|
| 386 |
+
other_loss_terms=other_loss_terms,
|
| 387 |
+
)
|
| 388 |
+
|
| 389 |
+
@staticmethod
|
| 390 |
+
def _sample_categorical(categorical_probs, do_sample=True):
|
| 391 |
+
"""Helper function to sample from a categorical distribution."""
|
| 392 |
+
categorical_probs = categorical_probs.to(torch.float64)
|
| 393 |
+
if not do_sample:
|
| 394 |
+
return categorical_probs.argmax(dim=-1)
|
| 395 |
+
gumbel_norm = (1e-10 - (torch.rand_like(categorical_probs) + 1e-10).log()).to(
|
| 396 |
+
categorical_probs.dtype
|
| 397 |
+
)
|
| 398 |
+
return (categorical_probs / gumbel_norm).argmax(dim=-1)
|
| 399 |
+
|
| 400 |
+
@staticmethod
|
| 401 |
+
def _preprocess_attention_mask(attention_mask, dtype):
|
| 402 |
+
min_dtype = torch.finfo(dtype).min
|
| 403 |
+
attention_mask = torch.where(
|
| 404 |
+
(attention_mask == 0.0).bool(), # type: ignore
|
| 405 |
+
min_dtype,
|
| 406 |
+
0.0,
|
| 407 |
+
).to(dtype)
|
| 408 |
+
return attention_mask
|
| 409 |
+
|
| 410 |
+
@staticmethod
|
| 411 |
+
def _get_past_key_values_seq_length(past_key_values: DynamicCache):
|
| 412 |
+
seq_length = 0
|
| 413 |
+
for i in range(len(past_key_values)):
|
| 414 |
+
if past_key_values[i][0].shape[0] > 0: # type: ignore
|
| 415 |
+
seq_length = max(
|
| 416 |
+
past_key_values[i][0].shape[-2], # type: ignore
|
| 417 |
+
seq_length,
|
| 418 |
+
)
|
| 419 |
+
return seq_length
|
| 420 |
+
|
| 421 |
+
def update_cache(
    self,
    inputs: torch.LongTensor,
    cache: Optional[Dict[str, Any]] = None,
    **backbone_kwargs: Any,
) -> Dict[str, Any]:
    """
    Cache the key-value pairs for the context.
    Args:
        inputs (torch.LongTensor): The context tensor.
        cache (Optional[Dict[str, Any]]): Cache objects, e.g., past_key_values.
    Returns:
        Dict: Updated cache objects, e.g., past_key_values.
    """
    prepared_context, cache = self._prepare_inputs_inference(
        input_ids=inputs, cache=cache, return_updated_cache=True, **backbone_kwargs
    )
    backbone_state = self._backbone_forward(
        prepared_context,
        return_updated_cache=True,  # Will get absorbed in backbone_kwargs
        **cache,
    )
    # Keep everything the backbone returned except logits; they are not
    # stored in the cache.
    cache_updates = {k: v for k, v in backbone_state.items() if k != "logits"}
    return cache | cache_updates
|
| 447 |
+
|
| 448 |
+
@torch.no_grad()
def generate(
    self,
    inputs: Optional[torch.LongTensor] = None,
    generation_config: Optional[GenerationConfig] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    max_length: Optional[int] = None,
    max_new_tokens: Optional[int] = None,
    batch_size: Optional[int] = None,
    device: Optional[str] = None,
    **kwargs: Any,
) -> Union[GenerateOutput, torch.LongTensor]:
    """Generates sample from denoising model.

    Follows signature of transformers.GenerationMixin so denoisers can act
    as drop-in generators; the actual sampling loop is left to subclasses.

    Args:
        inputs: Optional prompt / context token ids.
        generation_config: Optional generation configuration object.
        logits_processor: Optional processors applied to logits each step.
        stopping_criteria: Optional early-stopping criteria.
        max_length: Optional total length cap for generated sequences.
        max_new_tokens: Optional cap on newly generated tokens.
        batch_size: Optional number of sequences to generate.
        device: Optional device string for generation.
        **kwargs: Extra arguments consumed by subclass implementations.

    Raises:
        NotImplementedError: always — concrete subclasses must override.
    """
    raise NotImplementedError("Denoiser subclasses must implement generate")
|
diffusion.py
CHANGED
|
@@ -21,7 +21,7 @@ except ImportError:
|
|
| 21 |
BlockMask, and_masks, create_block_mask = None, None, None
|
| 22 |
|
| 23 |
|
| 24 |
-
from
|
| 25 |
Denoiser,
|
| 26 |
DenoiserConfig,
|
| 27 |
DenoiserInput,
|
|
|
|
| 21 |
BlockMask, and_masks, create_block_mask = None, None, None
|
| 22 |
|
| 23 |
|
| 24 |
+
from .denoiser_base import (
|
| 25 |
Denoiser,
|
| 26 |
DenoiserConfig,
|
| 27 |
DenoiserInput,
|
noise_schedule_noise_schedules.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class Noise(ABC):
|
| 7 |
+
"""
|
| 8 |
+
Baseline forward method to get noise parameters at a timestep
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
def __call__(
|
| 12 |
+
self, t: torch.Tensor | float
|
| 13 |
+
) -> tuple[torch.Tensor | float, torch.Tensor | float]:
|
| 14 |
+
# Assume time goes from 0 to 1
|
| 15 |
+
pass
|
| 16 |
+
|
| 17 |
+
@abstractmethod
|
| 18 |
+
def inverse(self, alpha_t: torch.Tensor) -> torch.Tensor:
|
| 19 |
+
"""
|
| 20 |
+
Inverse function to compute the timestep t from the noise schedule param.
|
| 21 |
+
"""
|
| 22 |
+
raise NotImplementedError("Inverse function not implemented")
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class CosineNoise(Noise):
    """Cosine schedule: alpha_t = (1 - eps) * cos(pi * t / 2)."""

    def __init__(self, eps=1e-3):
        super().__init__()
        self.eps = eps  # keeps alpha_t strictly below 1 at t = 0
        self.name = "cosine"

    def __call__(self, t):
        t = t.to(torch.float32)
        scale = 1 - self.eps
        half_pi_t = t * torch.pi / 2
        neg_cos = -scale * torch.cos(half_pi_t)
        neg_sin = -scale * torch.sin(half_pi_t)
        move_chance = neg_cos + 1  # probability mass moved to noise
        alpha_t_prime = neg_sin * torch.pi / 2  # d(alpha_t)/dt
        return 1 - move_chance, alpha_t_prime
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class ExponentialNoise(Noise):
    """Polynomial ("exponential") schedule: move_chance = t ** exp."""

    def __init__(self, exp=2, eps=1e-3):
        super().__init__()
        self.eps = eps  # floor on move_chance to avoid a zero at t = 0
        self.exp = exp  # polynomial exponent of the schedule
        self.name = f"exp_{exp}"

    def __call__(self, t):
        t = t.to(torch.float32)
        move_chance = torch.pow(t, self.exp)
        move_chance = torch.clamp(move_chance, min=self.eps)
        alpha_t_prime = -self.exp * torch.pow(t, self.exp - 1)  # d(1 - t**exp)/dt
        # Fix: return (alpha_t, alpha_t_prime), matching the ordering used by
        # every other schedule in this module; the original returned the pair
        # swapped as (alpha_t_prime, alpha_t).
        return 1 - move_chance, alpha_t_prime
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class LogarithmicNoise(Noise):
    """Logarithmic schedule: move_chance = log(1 + t) / log(2)."""

    def __init__(self, eps=1e-3):
        super().__init__()
        self.eps = eps
        self.name = "logarithmic"

    def __call__(self, t):
        t = t.to(torch.float32)
        ln2 = torch.log(torch.tensor(2.0))  # normalizes move_chance to 1 at t = 1
        move_chance = torch.log1p(t) / ln2
        alpha_t_prime = -1 / (ln2 * (1 + t))  # d(1 - move_chance)/dt
        return 1 - move_chance, alpha_t_prime
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class LinearNoise(Noise):
    """Linear schedule: alpha_t = 1 - t, with an exact closed-form inverse."""

    def __init__(self):
        super().__init__()
        self.name = "linear"

    def inverse(self, alpha_t):
        # alpha_t = 1 - t  =>  t = 1 - alpha_t
        return 1 - alpha_t

    def __call__(self, t):
        t = t.to(torch.float32)
        derivative = -torch.ones_like(t)  # d(1 - t)/dt is constant -1
        survival = 1 - t
        return survival, derivative
|