Feature Extraction
Transformers
PyTorch
e2d2
custom_code
yairschiff committed on
Commit
6ba80c2
·
verified ·
1 Parent(s): 30e8556

Add model and code

Browse files
.gitignore ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ __pycache__/
2
+ .hf_cache
3
+ .idea
4
+ .ipynb_checkpoints/
5
+ .pytest_cache/
6
+ .ruff_cache/
7
+ .DS_Store
8
+ outputs/
9
+ watch_folder
__init__.py ADDED
File without changes
backbone_automodel.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Literal
2
+
3
+ import torch
4
+ from torch import nn
5
+ from transformers import (
6
+ AutoConfig,
7
+ AutoModel,
8
+ AutoModelForCausalLM,
9
+ AutoModelForMaskedLM,
10
+ DynamicCache,
11
+ )
12
+ from transformers.modeling_outputs import (
13
+ BaseModelOutputWithPast,
14
+ CausalLMOutputWithPast,
15
+ )
16
+
17
+ from .backbone_custom_modeling_qwen3 import CustomQwen3ForCausalLM
18
+
19
# `BlockMask` only exists in PyTorch builds that ship flex-attention;
# fall back to None so the annotations below still evaluate on older builds.
try:
    from torch.nn.attention.flex_attention import BlockMask
except ImportError:
    BlockMask = None

# Maps the string accepted by `AutoModelFromPreTrained.__init__` to the
# corresponding Hugging Face auto-model factory class.
AUTO_MODEL_CLS = {
    "AutoModel": AutoModel,
    "AutoModelForCausalLM": AutoModelForCausalLM,
    "AutoModelForMaskedLM": AutoModelForMaskedLM,
}
29
+
30
+
31
class AutoModelFromPreTrained(nn.Module):
    """Simple wrapper class that enables using AutoModel from pre-trained.

    Wraps an HF model (loaded from a checkpoint or freshly initialized from
    its config) and post-processes the forward pass so the KV cache length
    can be held fixed across calls (used by diffusion-style decoding).
    """

    def __init__(
        self,
        automodel_cls: Literal[
            "AutoModel",
            "AutoModelForCausalLM",
            "AutoModelForMaskedLM",
        ],
        pretrained_model_name_or_path: str,
        trust_remote_code: bool = True,
        num_layers: int = -1,
        keep_top_layers: bool = False,
        reinit_model: bool = False,
        use_causal_mask: bool = False,
        **automodel_init_kwargs,
    ):
        """Load (or re-initialize) the wrapped model.

        Args:
            automodel_cls: Auto-model factory used when loading pre-trained
                weights (ignored on the ``reinit_model`` path — see NOTE).
            pretrained_model_name_or_path: HF hub id or local checkpoint path.
            trust_remote_code: Forwarded to the HF loader.
            num_layers: Number of transformer layers to keep; -1 keeps all.
            keep_top_layers: Keep the last ``num_layers`` layers instead of
                the first ones when truncating a pre-trained stack.
            reinit_model: Build a randomly initialized model from the
                checkpoint's config instead of loading its weights.
            use_causal_mask: If True, ``forward`` discards any provided
                attention mask so the backbone applies its default causal mask.
            **automodel_init_kwargs: Extra kwargs for the config/model loader.
        """
        super().__init__()
        self.use_causal_mask = use_causal_mask
        if reinit_model:
            auto_config = AutoConfig.from_pretrained(
                pretrained_model_name_or_path,
                num_hidden_layers=num_layers,
                trust_remote_code=trust_remote_code,
                **automodel_init_kwargs,
            )
            # NOTE: the reinit path is hard-wired to the custom Qwen3 subclass
            # rather than the `automodel_cls` factory (see commented line).
            self.model = CustomQwen3ForCausalLM(auto_config)
            # self.model = AUTO_MODEL_CLS[automodel_cls].from_config(auto_config)
        else:
            self.model = AUTO_MODEL_CLS[automodel_cls].from_pretrained(
                pretrained_model_name_or_path,
                trust_remote_code=trust_remote_code,
                **automodel_init_kwargs,
            )
            # Optionally truncate the loaded layer stack.
            num_layers = (
                len(self.model.model.layers) if num_layers == -1 else num_layers
            )
            if keep_top_layers:
                self.model.model.layers = self.model.model.layers[-num_layers:]
            else:
                self.model.model.layers = self.model.model.layers[:num_layers]

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: torch.FloatTensor | BlockMask | None = None,
        position_ids: torch.LongTensor | None = None,
        cache_position: torch.LongTensor | None = None,
        past_key_values: DynamicCache | None = None,
        fix_cache_length: bool = False,  # False for AR, True for diffusion models
        return_updated_cache=False,
        **kwargs,
    ) -> CausalLMOutputWithPast | BaseModelOutputWithPast:
        """Run the wrapped model.

        When ``fix_cache_length`` is True the per-layer KV cache is truncated
        back to its pre-call length after the forward pass (a DynamicCache
        otherwise grows in place). With ``return_updated_cache`` only the
        updated cache is returned, wrapped in a BaseModelOutputWithPast.
        """
        prev_cache_len = None
        if past_key_values is not None and fix_cache_length:
            # Snapshot each layer's cache length before the model extends it.
            prev_cache_len = [
                past_key_values[i][0].shape[-2]  # type: ignore
                for i in range(len(past_key_values))
            ]
        if self.use_causal_mask:
            attention_mask = None  # None --> enforces use of causal mask
        model_output = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            cache_position=cache_position,
            past_key_values=past_key_values,
            **kwargs,
        )
        if return_updated_cache:
            return BaseModelOutputWithPast(past_key_values=model_output.past_key_values)
        if (
            prev_cache_len is not None
            and model_output.get("past_key_values", None) is not None
        ):
            # DynamicCache extends along sequence dimension by default;
            # truncate back to original cache len
            for i, cache_len in enumerate(prev_cache_len):
                model_output.past_key_values.key_cache[i] = (
                    model_output.past_key_values.key_cache[i][..., :cache_len, :]
                )
                model_output.past_key_values.value_cache[i] = (
                    model_output.past_key_values.value_cache[i][..., :cache_len, :]
                )
        return model_output
backbone_custom_modeling_qwen3.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable, Optional, Tuple
2
+
3
+ import torch
4
+ from torch import nn
5
+ from transformers.models.qwen3.modeling_qwen3 import (
6
+ ALL_ATTENTION_FUNCTIONS,
7
+ Cache,
8
+ FlashAttentionKwargs,
9
+ Qwen3Attention,
10
+ Qwen3Config,
11
+ Qwen3DecoderLayer,
12
+ Qwen3ForCausalLM,
13
+ Qwen3Model,
14
+ eager_attention_forward,
15
+ rotate_half,
16
+ )
17
+ from transformers.processing_utils import Unpack
18
+ from transformers.utils import logging
19
+
20
+ logger = logging.get_logger(__name__)
21
+
22
+
23
def custom_apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1, q_start_idx=0):
    """Apply Rotary Position Embedding (RoPE) to query and key tensors.

    Unlike the stock helper, ``q`` may cover only a suffix of the sequence
    (starting at ``q_start_idx``): the cos/sin tables are sliced to match the
    query, while ``k`` always uses the full tables.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)
    q_cos = cos_b[..., q_start_idx:, :]
    q_sin = sin_b[..., q_start_idx:, :]
    q_embed = q * q_cos + rotate_half(q) * q_sin
    k_embed = k * cos_b + rotate_half(k) * sin_b
    return q_embed, k_embed
32
+
33
+
34
class CustomQwen3Attention(Qwen3Attention):
    """Qwen3 attention where queries may cover only a suffix of the sequence.

    ``q_start_idx > 0`` marks the leading positions of ``hidden_states`` as
    encoder context: they contribute keys/values but no queries, so the
    attention output spans only the trailing ``seq_len - q_start_idx``
    positions.
    """

    def __init__(self, config: Qwen3Config, layer_idx: int):
        super().__init__(config, layer_idx=layer_idx)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        q_start_idx: int = 0,  # > 0: decoder pass w/encoder inputs in hidden_states
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)
        # Queries are computed only for the non-context (suffix) positions.
        sa_hidden_sates = hidden_states[:, q_start_idx:, :]
        query_input_shape = sa_hidden_sates.shape[:-1]
        query_hidden_shape = (*query_input_shape, -1, self.head_dim)

        query_states = self.q_norm(
            self.q_proj(sa_hidden_sates).reshape(query_hidden_shape)
        ).transpose(1, 2)
        # Keys/values span the full input, including the encoder context.
        key_states = self.k_norm(
            self.k_proj(hidden_states).view(hidden_shape)
        ).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        # custom_apply_rotary_pos_emb slices the RoPE tables for the
        # (shorter) query tensor; keys use the full tables.
        query_states, key_states = custom_apply_rotary_pos_emb(
            query_states, key_states, cos, sin, q_start_idx=q_start_idx
        )

        if past_key_value is not None:
            # sin and cos are specific to RoPE models
            # cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )

        # NOTE: downcast for flex-attention compatibility
        query_states, key_states = (
            query_states.to(value_states.dtype),
            key_states.to(value_states.dtype),
        )

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[
                self.config._attn_implementation
            ]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=self.sliding_window,  # diff with Llama
            **kwargs,
        )

        # Output covers only the query (suffix) positions.
        attn_output = attn_output.reshape(*query_input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
104
+
105
+
106
class CustomQwen3DecoderLayer(Qwen3DecoderLayer):
    """Qwen3 decoder layer whose self-attention supports ``q_start_idx``.

    When ``q_start_idx > 0`` the first ``q_start_idx`` positions of
    ``hidden_states`` act only as keys/values (encoder context), and the
    layer's output covers just the remaining (query) positions.
    """

    def __init__(self, config: Qwen3Config, layer_idx: int):
        super().__init__(config, layer_idx=layer_idx)
        # Replace the stock attention module with the q_start_idx-aware one.
        self.self_attn = CustomQwen3Attention(config=config, layer_idx=layer_idx)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        q_start_idx: int = 0,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[
        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
    ]:
        # Residual covers only the query positions, matching the shape of
        # the self-attention output when q_start_idx > 0.
        residual = hidden_states[:, q_start_idx:, ...]

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            q_start_idx=q_start_idx,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights,)

        return outputs
157
+
158
+
159
class CustomQwen3Model(Qwen3Model):
    """Qwen3 backbone whose decoder stack uses :class:`CustomQwen3DecoderLayer`."""

    def __init__(self, config: Qwen3Config):
        super().__init__(config)
        custom_layers = [
            CustomQwen3DecoderLayer(config, idx)
            for idx in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(custom_layers)
        # Initialize weights and apply final processing
        self.post_init()
170
+
171
+
172
class CustomQwen3ForCausalLM(Qwen3ForCausalLM):
    """Causal-LM head on top of :class:`CustomQwen3Model`."""

    def __init__(self, config: Qwen3Config):
        super().__init__(config)
        # Swap the stock backbone for the custom one (its layers accept
        # q_start_idx), then re-run weight init / final processing.
        self.model = CustomQwen3Model(config)
        self.post_init()
backbone_encoder_decoder.py ADDED
@@ -0,0 +1,654 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from functools import partial
3
+ from typing import Optional, Tuple, Union
4
+
5
+ import torch
6
+ from torch import nn
7
+ from transformers import AutoConfig, AutoModelForCausalLM
8
+ from transformers.cache_utils import DynamicCache
9
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
10
+ from transformers.modeling_outputs import (
11
+ BaseModelOutputWithPast,
12
+ CausalLMOutputWithPast,
13
+ ModelOutput,
14
+ )
15
+ from transformers.processing_utils import Unpack
16
+ from transformers.utils import logging
17
+
18
+ from .backbone_custom_modeling_qwen3 import CustomQwen3ForCausalLM
19
+
20
+ try:
21
+ from torch.nn.attention.flex_attention import BlockMask
22
+ except ImportError:
23
+ BlockMask = None
24
+
25
+
26
+ logger = logging.get_logger(__name__)
27
+
28
+
29
@dataclass
class EncoderBaseModelOutputWithPast(ModelOutput):
    """Custom (encoder) model output.
    Stores previous decoder and updated encoder cache and encoder last hidden state.
    """

    # Decoder-side cache, passed through unchanged by the encoder pass so
    # generation loops can round-trip both caches.
    past_key_values: Optional[Union[Tuple[Tuple[torch.FloatTensor]], DynamicCache]] = (
        None
    )
    # Last-layer hidden states for the clean (conditioning) tokens.
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
    # Encoder KV cache updated with the latest encoder inputs.
    encoder_past_key_values: Optional[
        Union[Tuple[Tuple[torch.FloatTensor]], DynamicCache]
    ] = None
42
+
43
+
44
@dataclass
class DecoderCausalLMOutputWithPast(ModelOutput):
    """Custom (decoder) model output.
    Stores previous encoder and updated decoder cache and decoder logits.
    """

    # LM-head logits over the decoder (noisy-token) positions.
    logits: Optional[torch.FloatTensor] = None
    # Decoder KV cache (updated in place during the decoder pass).
    past_key_values: Optional[Union[Tuple[Tuple[torch.FloatTensor]], DynamicCache]] = (
        None
    )
    # Encoder cache, passed through unchanged for round-tripping.
    encoder_past_key_values: Optional[
        Union[Tuple[Tuple[torch.FloatTensor]], DynamicCache]
    ] = None
57
+
58
+
59
class LLMasEncoderDecoder(nn.Module):
    """Encoder-decoder model assembled from a pre-trained decoder-only LLM.

    The encoder processes "clean" (conditioning) tokens; the decoder prepends
    the encoder's last hidden states to its own input embeddings and attends
    to them via the ``q_start_idx`` mechanism of ``CustomQwen3ForCausalLM``.
    """

    def __init__(
        self,
        pretrained_model_name_or_path: str,
        max_length: int,
        attn_backend: str = "sdpa",
        freeze_encoder: bool = False,
        reinit_encoder: bool = False,
        reinit_decoder: bool = False,
        tie_encoder_decoder_weights: bool = False,
        use_encoder_causal_mask: bool = False,
        num_encoder_layers: int = -1,
        num_decoder_layers: int = -1,
        keep_top_encoder_layers: bool = False,
        keep_top_decoder_layers: bool = False,
        use_gradient_checkpointing: bool = False,
        **llm_init_kwargs,
    ):
        """Build the encoder and decoder stacks.

        Args:
            pretrained_model_name_or_path: HF hub id or local checkpoint path.
            max_length: Maximum sequence length, stored on the module.
            attn_backend: Attention implementation passed to the HF loader.
            freeze_encoder: Freeze encoder weights except the token embeddings.
            reinit_encoder / reinit_decoder: Random-init from config instead
                of loading pre-trained weights.
            tie_encoder_decoder_weights: Use the encoder object as the decoder
                (the top ``num_decoder_layers`` layers act as the decoder).
            use_encoder_causal_mask: Force the encoder's default causal mask.
            num_encoder_layers / num_decoder_layers: Layers to keep; -1 = all.
            keep_top_encoder_layers / keep_top_decoder_layers: Keep the last
                N layers instead of the first N when truncating.
            use_gradient_checkpointing: Enable HF gradient checkpointing.
            **llm_init_kwargs: Extra kwargs for the config/model loaders.
        """
        assert not (tie_encoder_decoder_weights and reinit_decoder), (
            "Cannot tie encoder-decoder weights and reinitialize decoder."
        )
        assert not (tie_encoder_decoder_weights and freeze_encoder), (
            "Cannot freeze encoder weights when tying encoder-decoder weights."
        )
        super().__init__()
        self.use_encoder_causal_mask = use_encoder_causal_mask
        self.tie_encoder_decoder_weights = tie_encoder_decoder_weights

        if reinit_encoder:
            assert num_encoder_layers > 0
            encoder_config = AutoConfig.from_pretrained(
                pretrained_model_name_or_path,
                trust_remote_code=True,
                num_hidden_layers=num_encoder_layers,
                attn_implementation=attn_backend,
                **llm_init_kwargs,
            )
            self.encoder = CustomQwen3ForCausalLM(encoder_config)
        else:
            self.encoder = CustomQwen3ForCausalLM.from_pretrained(
                pretrained_model_name_or_path,
                trust_remote_code=True,
                attn_implementation=attn_backend,
                **llm_init_kwargs,
            )
            assert num_encoder_layers <= len(self.encoder.model.layers), (
                f"Cannot keep {num_encoder_layers} layers. "
                f"Pre-trained model only has {len(self.encoder.model.layers)} layers."
            )
            num_encoder_layers = (
                len(self.encoder.model.layers)
                if num_encoder_layers == -1
                else num_encoder_layers
            )
            if keep_top_encoder_layers:
                self.encoder.model.layers = self.encoder.model.layers[
                    -num_encoder_layers:
                ]
            else:
                self.encoder.model.layers = self.encoder.model.layers[
                    :num_encoder_layers
                ]

        if freeze_encoder:
            # Token embeddings stay trainable even when the encoder is frozen.
            for name, param in self.encoder.named_parameters():
                if "embed_tokens" not in name:
                    param.requires_grad = False
        if use_gradient_checkpointing:
            self.encoder.gradient_checkpointing_enable()

        if tie_encoder_decoder_weights:
            self.decoder = self.encoder
            num_decoder_layers = (
                len(self.decoder.model.layers)
                if num_decoder_layers == -1
                else num_decoder_layers
            )
            assert num_decoder_layers <= len(self.decoder.model.layers), (
                f"Cannot keep {num_decoder_layers} layers. "
                f"Pre-trained model only has {len(self.decoder.model.layers)} layers."
            )
            # Keep **top** layers when tying weights
            self.decoder_layer_idxs = list(range(len(self.encoder.model.layers)))[
                -num_decoder_layers:
            ]

        else:
            if reinit_decoder:
                assert num_decoder_layers > 0
                decoder_config = AutoConfig.from_pretrained(
                    pretrained_model_name_or_path,
                    trust_remote_code=True,
                    num_hidden_layers=num_decoder_layers,
                    attn_implementation=attn_backend,
                    **llm_init_kwargs,
                )
                self.decoder = CustomQwen3ForCausalLM(decoder_config)
            else:
                self.decoder = CustomQwen3ForCausalLM.from_pretrained(
                    pretrained_model_name_or_path,
                    trust_remote_code=True,
                    attn_implementation=attn_backend,
                    **llm_init_kwargs,
                )
                # BUGFIX: the message previously read `self.decoder.layers`,
                # an attribute that does not exist (layers live on
                # `self.decoder.model`), so a failing assert raised
                # AttributeError instead of the intended message.
                assert num_decoder_layers <= len(self.decoder.model.layers), (
                    f"Cannot keep {num_decoder_layers} layers. "
                    f"Pre-trained model only has "
                    f"{len(self.decoder.model.layers)} layers."
                )
                if keep_top_decoder_layers:
                    self.decoder.model.layers = self.decoder.model.layers[
                        -num_decoder_layers:
                    ]
                else:
                    self.decoder.model.layers = self.decoder.model.layers[
                        :num_decoder_layers
                    ]
            # Decoder re-uses the encoder's embedding table (see `forward`,
            # which embeds via `self.encoder.model.embed_tokens`).
            del self.decoder.model.embed_tokens
            # if in the original LM, the lm_head is weight-tied to embedding,
            # point decoder lm_head to encoder's (instead of initializing separately)
            if (
                self.encoder.lm_head.weight.data_ptr()
                == self.encoder.model.embed_tokens.weight.data_ptr()
            ):
                self.decoder.lm_head = self.encoder.lm_head
            else:
                del self.encoder.lm_head
            if use_gradient_checkpointing:
                self.decoder.gradient_checkpointing_enable()
        self.max_length = max_length
188
+
189
+ def freeze_encoder(self):
190
+ for p in self.encoder.model.parameters():
191
+ p.requires_grad = False
192
+
193
+ def unfreeze_encoder(self):
194
+ for p in self.encoder.model.parameters():
195
+ p.requires_grad = True
196
+
197
    # noinspection PyUnusedLocal
    def forward(
        self,
        # Decoder inputs
        input_ids: torch.LongTensor,
        attention_mask: Optional[Union[torch.FloatTensor, BlockMask]] = None,
        position_ids: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        past_key_values: Optional[DynamicCache] = None,
        encoder_last_hidden_state: Optional[torch.FloatTensor] = None,
        # Encoder inputs
        encoder_input_ids: Optional[torch.LongTensor] = None,
        encoder_attention_mask: Optional[Union[torch.FloatTensor, BlockMask]] = None,
        encoder_position_ids: Optional[torch.LongTensor] = None,
        encoder_cache_position: Optional[torch.LongTensor] = None,
        encoder_past_key_values: Optional[DynamicCache] = None,
        # Additional args
        fix_cache_length: bool = True,  # Not used; compatibility with other backbones
        return_updated_cache: bool = False,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> Union[DecoderCausalLMOutputWithPast, EncoderBaseModelOutputWithPast]:
        """Optionally encode clean tokens, then decode noisy tokens.

        Call patterns:
        1. ``encoder_input_ids`` + ``return_updated_cache=True``: run the
           encoder only; return its last hidden state and updated cache.
        2. ``encoder_input_ids`` (training/eval): encode, then decode with
           the fresh encoder states prepended to the decoder embeddings.
        3. ``encoder_last_hidden_state`` (generation): skip the encoder and
           decode against the provided states; their length is added to the
           retained decoder-cache length (``new_seen_tokens``).
        The decoder KV cache is updated in place by each layer and truncated
        afterwards so it keeps at most previous length + ``new_seen_tokens``.
        """
        # During training/eval encoder_last_hidden_state is None.
        # During generation encoder_last_hidden_state can be not None.
        # NOTE: new_seen_tokens reflects only a *passed-in*
        # encoder_last_hidden_state; states freshly computed below do not
        # count toward the retained cache length.
        new_seen_tokens = (
            0
            if encoder_last_hidden_state is None
            else encoder_last_hidden_state.shape[1]
        )
        # Encode clean tokens
        if encoder_input_ids is not None:
            if self.use_encoder_causal_mask:
                encoder_attention_mask = None  # None --> enforces use of causal mask
            if encoder_cache_position is None and encoder_position_ids is not None:
                encoder_cache_position = encoder_position_ids[0]
            encoder_output = self.encoder.model(
                input_ids=encoder_input_ids,
                attention_mask=encoder_attention_mask,
                position_ids=encoder_position_ids,
                use_cache=True,
                past_key_values=encoder_past_key_values,
                cache_position=encoder_cache_position,
            )
            if return_updated_cache:
                # encoder_output.past_key_values now contains latest encoder input
                return EncoderBaseModelOutputWithPast(
                    encoder_last_hidden_state=encoder_output.last_hidden_state,
                    encoder_past_key_values=encoder_output.past_key_values,
                    past_key_values=past_key_values,
                )
            encoder_last_hidden_state = encoder_output.last_hidden_state

        # Run decoder with xattn to clean token hidden states
        if encoder_last_hidden_state is None:  # No new encoder tokens
            q_start_idx = 0
            # Decoder shares the encoder's embedding table.
            decoder_hidden_states = self.encoder.model.embed_tokens(input_ids)
            if cache_position is None:
                if position_ids is not None:
                    cache_position = position_ids[0]
                else:
                    past_seen_tokens = (
                        past_key_values.get_seq_length()
                        if past_key_values is not None
                        else 0
                    )
                    cache_position = torch.arange(
                        past_seen_tokens,
                        past_seen_tokens + decoder_hidden_states.shape[1],
                        device=decoder_hidden_states.device,
                    )
            if position_ids is None:
                position_ids = cache_position.unsqueeze(0)
            decoder_position_embeddings = self.decoder.model.rotary_emb(
                decoder_hidden_states, position_ids
            )
        else:
            # Prepend encoder states; layers treat the first q_start_idx
            # positions as key/value-only context.
            q_start_idx = encoder_last_hidden_state.shape[1]
            decoder_hidden_states = self.encoder.model.embed_tokens(input_ids)
            decoder_hidden_states = torch.cat(
                [
                    encoder_last_hidden_state,
                    decoder_hidden_states,
                ],
                dim=1,
            )
            if cache_position is None:
                if position_ids is not None:
                    cache_position = position_ids[0]
                else:
                    past_seen_tokens = (
                        past_key_values.get_seq_length()
                        if past_key_values is not None
                        else 0
                    )
                    cache_position = torch.cat(
                        [
                            torch.arange(  # clean token position ids
                                past_seen_tokens,
                                past_seen_tokens + encoder_last_hidden_state.shape[1],
                                device=decoder_hidden_states.device,
                            ),
                            torch.arange(  # noisy position ids
                                past_seen_tokens + new_seen_tokens,
                                past_seen_tokens + new_seen_tokens + input_ids.shape[1],
                                device=decoder_hidden_states.device,
                            ),
                        ],
                        dim=-1,
                    )
            if position_ids is None:
                position_ids = cache_position.unsqueeze(0)
            decoder_position_embeddings = self.decoder.model.rotary_emb(
                decoder_hidden_states, position_ids
            )

        if hasattr(self.decoder.model, "_update_causal_mask"):  # bc on transformers
            # noinspection PyProtectedMember
            attention_mask = self.decoder.model._update_causal_mask(
                attention_mask=attention_mask,
                input_tensor=decoder_hidden_states,
                cache_position=cache_position,
                past_key_values=past_key_values,
                output_attentions=False,
            )
        for decoder_layer in self.decoder.model.layers:
            layer_idx = decoder_layer.self_attn.layer_idx
            # With tied weights only the designated top layers act as decoder.
            if (
                self.tie_encoder_decoder_weights
                and layer_idx not in self.decoder_layer_idxs
            ):
                continue
            # past_key_values gets updated in-place.
            # Record previous length to re-truncate after each layer forward
            if past_key_values is not None and len(past_key_values) > layer_idx:
                prev_cache_len = past_key_values[layer_idx][0].shape[-2]  # type: ignore
            else:
                prev_cache_len = 0
            cache_len = prev_cache_len + new_seen_tokens

            if self.decoder.model.gradient_checkpointing and self.training:
                # Positional args must match CustomQwen3DecoderLayer.forward.
                # noinspection PyProtectedMember
                decoder_hidden_states = self.decoder._gradient_checkpointing_func(
                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                    decoder_hidden_states,  # hidden_states=,
                    attention_mask,  # attention_mask=,
                    position_ids,  # position_ids=,
                    past_key_values,  # past_key_value=,
                    False,  # output_attentions=,
                    True,  # use_cache=,
                    cache_position,  # cache_position=,
                    decoder_position_embeddings,  # position_embeddings=,
                    q_start_idx,  # q_start_idx=
                )[0]  # Shape: (input_ids.shape[0], input_ids.shape[1], hidden_dim)
            else:
                decoder_hidden_states = decoder_layer(
                    hidden_states=decoder_hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=False,
                    use_cache=True,
                    cache_position=cache_position,
                    position_embeddings=decoder_position_embeddings,
                    q_start_idx=q_start_idx,  # Indicates where to slice output
                    **flash_attn_kwargs,
                )[0]  # Shape: (input_ids.shape[0], input_ids.shape[1], hidden_dim)
            # Update decoder_hidden_states
            if q_start_idx > 0:
                # Layer output covers only the suffix; re-prepend the encoder
                # states so the next layer sees the full sequence again.
                decoder_hidden_states = torch.cat(
                    [
                        encoder_last_hidden_state,
                        decoder_hidden_states,
                    ],
                    dim=1,
                )

            if past_key_values is not None:
                # DynamicCache extends along sequence dimension by default;
                # truncate back to original cache len + encoder output length
                past_key_values.key_cache[layer_idx] = past_key_values.key_cache[
                    layer_idx
                ][..., :cache_len, :]
                past_key_values.value_cache[layer_idx] = past_key_values.value_cache[
                    layer_idx
                ][..., :cache_len, :]
        decoder_hidden_states = self.decoder.model.norm(
            decoder_hidden_states[:, q_start_idx:, :]
        )
        logits = self.decoder.lm_head(decoder_hidden_states)
        return DecoderCausalLMOutputWithPast(
            logits=logits,
            past_key_values=past_key_values,
            encoder_past_key_values=encoder_past_key_values,
            # Do not need to store encoder_last_hidden_state.
            # If it was passed in, then it has become part of the past_key_values cache.
        )
392
+
393
+
394
class LLMasEncoderDecoderShareKV(nn.Module):
    """Encoder-decoder built from a decoder-only LLM, sharing the KV cache.

    Variant of ``LLMasEncoderDecoder`` where the decoder consumes the KV
    cache produced by the encoder pass directly (see ``forward``), instead of
    attending to encoder hidden states prepended to its input.
    """

    def __init__(
        self,
        pretrained_model_name_or_path: str,
        max_length: int,
        attn_backend: str = "sdpa",
        freeze_encoder: bool = False,
        reinit_encoder: bool = False,
        reinit_decoder: bool = False,
        tie_encoder_decoder_weights: bool = False,
        use_encoder_causal_mask: bool = False,
        num_encoder_layers: int = -1,
        num_decoder_layers: int = -1,
        keep_top_encoder_layers: bool = False,
        keep_top_decoder_layers: bool = False,
        use_gradient_checkpointing: bool = False,
        **llm_init_kwargs,
    ):
        """Build the encoder and decoder stacks (stock HF causal-LM models).

        Args mirror ``LLMasEncoderDecoder.__init__``; both models are created
        through ``AutoModelForCausalLM`` rather than the custom Qwen3 class.
        """
        assert not (tie_encoder_decoder_weights and reinit_decoder), (
            "Cannot tie encoder-decoder weights and reinitialize decoder."
        )
        assert not (tie_encoder_decoder_weights and freeze_encoder), (
            "Cannot freeze encoder weights when tying encoder-decoder weights."
        )
        super().__init__()
        self.use_encoder_causal_mask = use_encoder_causal_mask
        self.tie_encoder_decoder_weights = tie_encoder_decoder_weights

        if reinit_encoder:
            assert num_encoder_layers > 0
            encoder_config = AutoConfig.from_pretrained(
                pretrained_model_name_or_path,
                trust_remote_code=True,
                num_hidden_layers=num_encoder_layers,
                attn_implementation=attn_backend,
                **llm_init_kwargs,
            )
            self.encoder = AutoModelForCausalLM.from_config(encoder_config)
        else:
            self.encoder = AutoModelForCausalLM.from_pretrained(
                pretrained_model_name_or_path,
                trust_remote_code=True,
                attn_implementation=attn_backend,
                **llm_init_kwargs,
            )
            assert num_encoder_layers <= len(self.encoder.model.layers), (
                f"Cannot keep {num_encoder_layers} layers. "
                f"Pre-trained model only has {len(self.encoder.model.layers)} layers."
            )
            num_encoder_layers = (
                len(self.encoder.model.layers)
                if num_encoder_layers == -1
                else num_encoder_layers
            )
            if keep_top_encoder_layers:
                self.encoder.model.layers = self.encoder.model.layers[
                    -num_encoder_layers:
                ]
            else:
                self.encoder.model.layers = self.encoder.model.layers[
                    :num_encoder_layers
                ]

        if freeze_encoder:
            # Token embeddings stay trainable even when the encoder is frozen.
            for name, param in self.encoder.named_parameters():
                if "embed_tokens" not in name:
                    param.requires_grad = False
        if use_gradient_checkpointing:
            self.encoder.gradient_checkpointing_enable()

        if tie_encoder_decoder_weights:
            self.decoder = self.encoder
            num_decoder_layers = (
                len(self.decoder.model.layers)
                if num_decoder_layers == -1
                else num_decoder_layers
            )
            assert num_decoder_layers <= len(self.decoder.model.layers), (
                f"Cannot keep {num_decoder_layers} layers. "
                f"Pre-trained model only has {len(self.decoder.model.layers)} layers."
            )
            # Keep **top** layers when tying weights
            self.decoder_layer_idxs = list(range(len(self.encoder.model.layers)))[
                -num_decoder_layers:
            ]

        else:
            if reinit_decoder:
                assert num_decoder_layers > 0
                decoder_config = AutoConfig.from_pretrained(
                    pretrained_model_name_or_path,
                    trust_remote_code=True,
                    num_hidden_layers=num_decoder_layers,
                    attn_implementation=attn_backend,
                    **llm_init_kwargs,
                )
                # BUGFIX: was `AutoModelForCausalLM(decoder_config)`, which
                # raises EnvironmentError — HF auto classes must be built via
                # `from_config` (as done for the encoder above).
                self.decoder = AutoModelForCausalLM.from_config(decoder_config)
            else:
                self.decoder = AutoModelForCausalLM.from_pretrained(
                    pretrained_model_name_or_path,
                    trust_remote_code=True,
                    attn_implementation=attn_backend,
                    **llm_init_kwargs,
                )
                # BUGFIX: the message previously read `self.decoder.layers`,
                # which does not exist (layers live on `self.decoder.model`),
                # so a failing assert raised AttributeError instead.
                assert num_decoder_layers <= len(self.decoder.model.layers), (
                    f"Cannot keep {num_decoder_layers} layers. "
                    f"Pre-trained model only has "
                    f"{len(self.decoder.model.layers)} layers."
                )
                if keep_top_decoder_layers:
                    self.decoder.model.layers = self.decoder.model.layers[
                        -num_decoder_layers:
                    ]
                else:
                    self.decoder.model.layers = self.decoder.model.layers[
                        :num_decoder_layers
                    ]
            del self.decoder.model.embed_tokens
            # Even for frozen encoder, ensure embedding tokens are trainable
            self.encoder.model.embed_tokens.requires_grad_(True)
            # Freeze the last encoder layer's modules that (per the
            # `unused_*` naming) the share-KV path does not exercise.
            unused_self_attn_params = ["o_proj", "q_norm", "q_proj"]
            unused_layernorm_params = ["input_layernorm", "post_attention_layernorm"]
            for unused_param in unused_self_attn_params:
                if hasattr(self.encoder.model.layers[-1].self_attn, unused_param):
                    getattr(
                        self.encoder.model.layers[-1].self_attn, unused_param
                    ).requires_grad_(False)
            self.encoder.model.layers[-1].mlp.requires_grad_(False)
            self.encoder.model.norm.requires_grad_(False)
            for unused_param in unused_layernorm_params:
                if hasattr(self.encoder.model.layers[-1], unused_param):
                    getattr(self.encoder.model.layers[-1], unused_param).requires_grad_(
                        False
                    )
            # if in the original LM, the lm_head is weight-tied to embedding,
            # point decoder lm_head to encoder's (instead of initializing separately)
            if (
                self.encoder.lm_head.weight.data_ptr()
                == self.encoder.model.embed_tokens.weight.data_ptr()
            ):
                self.decoder.lm_head = self.encoder.lm_head
            else:
                del self.encoder.lm_head
            if use_gradient_checkpointing:
                self.decoder.gradient_checkpointing_enable()
        self.max_length = max_length
539
+
540
+ def freeze_encoder(self):
541
+ for p in self.encoder.model.parameters():
542
+ p.requires_grad = False
543
+
544
+ def unfreeze_encoder(self):
545
+ for p in self.encoder.model.parameters():
546
+ p.requires_grad = True
547
+
548
    # noinspection PyUnusedLocal
    def forward(
        self,
        # Decoder inputs
        input_ids: torch.LongTensor,
        attention_mask: Optional[Union[torch.FloatTensor, BlockMask]] = None,
        position_ids: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        past_key_values: Optional[DynamicCache] = None,
        encoder_last_hidden_state: Optional[torch.FloatTensor] = None,  # Not used
        # Encoder inputs
        encoder_input_ids: Optional[torch.LongTensor] = None,
        encoder_attention_mask: Optional[Union[torch.FloatTensor, BlockMask]] = None,
        encoder_position_ids: Optional[torch.LongTensor] = None,
        encoder_cache_position: Optional[torch.LongTensor] = None,
        encoder_past_key_values: Optional[DynamicCache] = None,  # Not used
        # Additional args
        fix_cache_length: bool = True,  # Not used; compatibility with other backbones
        return_updated_cache: bool = False,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> Union[CausalLMOutputWithPast, BaseModelOutputWithPast]:
        """Run the encoder (optionally) to extend the KV cache, then run the
        decoder layers with cross-attention into that cache and return logits.

        When `return_updated_cache` is True, only the refreshed cache is
        returned (no decoder pass output) wrapped in a BaseModelOutputWithPast.
        """
        # Encode clean tokens: the encoder writes its keys/values into
        # `past_key_values`, which the decoder layers then attend over.
        if encoder_input_ids is not None:
            if self.use_encoder_causal_mask:
                encoder_attention_mask = None  # None --> enforces use of causal mask
            if encoder_cache_position is None and encoder_position_ids is not None:
                encoder_cache_position = encoder_position_ids[0]
            past_key_values = self.encoder.model(
                input_ids=encoder_input_ids,
                attention_mask=encoder_attention_mask,
                position_ids=encoder_position_ids,
                use_cache=True,
                past_key_values=past_key_values,
                cache_position=encoder_cache_position,
            ).past_key_values
        if return_updated_cache:
            # past_key_values now contains latest encoder input
            # NOTE(review): reconstructed indentation places this check at method
            # level (so it also short-circuits when no encoder input was given) —
            # confirm against the original source.
            return BaseModelOutputWithPast(
                past_key_values=past_key_values,
            )

        # Run decoder with xattn to clean token hidden states.
        # Decoder shares the encoder's input embedding table.
        decoder_hidden_states = self.encoder.model.embed_tokens(input_ids)
        if cache_position is None:
            if position_ids is not None:
                cache_position = position_ids[0]
            else:  # During training / validation position_ids are not provided
                cache_position = torch.arange(
                    decoder_hidden_states.shape[1],
                    device=decoder_hidden_states.device,
                )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)
        decoder_position_embeddings = self.decoder.model.rotary_emb(
            decoder_hidden_states, position_ids
        )

        if hasattr(self.decoder.model, "_update_causal_mask"):  # bc on transformers
            # noinspection PyProtectedMember
            attention_mask = self.decoder.model._update_causal_mask(
                attention_mask=attention_mask,
                input_tensor=decoder_hidden_states,
                cache_position=cache_position,
                past_key_values=past_key_values,
                output_attentions=False,
            )
        for decoder_layer in self.decoder.model.layers:
            layer_idx = decoder_layer.self_attn.layer_idx
            if (
                self.tie_encoder_decoder_weights
                and layer_idx not in self.decoder_layer_idxs
            ):
                # When weights are tied, only a subset of layers acts as decoder.
                continue
            # past_key_values gets updated in-place.
            # Record previous length to truncate after each layer forward
            # (prev_cache_len is measured after the encoder pass, so it already
            # includes the encoder's cache entries).
            if past_key_values is not None and len(past_key_values) > layer_idx:
                prev_cache_len = past_key_values[layer_idx][0].shape[-2]  # type: ignore
            else:
                prev_cache_len = 0

            decoder_hidden_states = decoder_layer(
                hidden_states=decoder_hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=False,
                use_cache=True,
                cache_position=position_ids[0],
                position_embeddings=decoder_position_embeddings,
                **flash_attn_kwargs,
            )[0]  # Shape: (input_ids.shape[0], input_ids.shape[1], hidden_dim)

            if past_key_values is not None:
                # DynamicCache extends along sequence dimension by default;
                # truncate back to the length recorded before this layer so the
                # decoder's own (noisy) tokens are never persisted in the cache.
                past_key_values.key_cache[layer_idx] = past_key_values.key_cache[
                    layer_idx
                ][..., :prev_cache_len, :]
                past_key_values.value_cache[layer_idx] = past_key_values.value_cache[
                    layer_idx
                ][..., :prev_cache_len, :]
        decoder_hidden_states = self.decoder.model.norm(decoder_hidden_states)
        logits = self.decoder.lm_head(decoder_hidden_states)
        return CausalLMOutputWithPast(
            logits=logits,
            past_key_values=past_key_values,
        )
denoiser_base.py ADDED
@@ -0,0 +1,464 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import inspect
3
+ import sys
4
+ from abc import ABC, abstractmethod
5
+ from collections import OrderedDict
6
+ from dataclasses import dataclass, field
7
+ from typing import Any, Dict, Optional, Tuple, Union
8
+
9
+ import hydra.utils
10
+ import torch
11
+ from hydra.errors import InstantiationException
12
+ from transformers import (
13
+ AutoTokenizer,
14
+ DynamicCache,
15
+ GenerationConfig,
16
+ LogitsProcessorList,
17
+ PretrainedConfig,
18
+ PreTrainedModel,
19
+ StoppingCriteriaList,
20
+ )
21
+ from transformers.cache_utils import Cache
22
+ from transformers.generation.utils import GenerateOutput
23
+ from transformers.modeling_outputs import ModelOutput
24
+
25
+ # Local imports not used, but added here so that HF push_to_hub adds them to model repo
26
+ # noinspection PyUnresolvedReferences
27
+ from .backbone_automodel import AutoModelFromPreTrained # noqa: F401
28
+ from .backbone_encoder_decoder import ( # noqa: F401
29
+ LLMasEncoderDecoder,
30
+ LLMasEncoderDecoderShareKV,
31
+ )
32
+ from .noise_schedule_noise_schedules import ( # noqa: F401
33
+ CosineNoise,
34
+ ExponentialNoise,
35
+ LinearNoise,
36
+ LogarithmicNoise,
37
+ )
38
+
39
+
40
@dataclass
class DenoiserInput(OrderedDict):
    """Input to the denoiser model."""

    xt: torch.LongTensor  # (B, L) token_ids (noised sequence)
    x0: Optional[torch.LongTensor] = None  # (B, L) token_ids (not used in gen.)
    attention_mask: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Union[torch.FloatTensor, Cache]] = None  # KV cache
    context_mask: Optional[torch.FloatTensor] = None  # indicator of context tokens
    tokens_mask: Optional[torch.FloatTensor] = None  # (B, L)
    t: Optional[torch.FloatTensor] = None  # (B,) | # (B, L)
    alpha_t: Optional[torch.FloatTensor] = None  # (B,) | (B, 1|L) | (B, 1|L, 1)
    alpha_t_prime: Optional[torch.FloatTensor] = None  # (B,) | (B, 1|L) | (B, 1|L, 1)
    # Extra keyword arguments forwarded verbatim to the backbone's forward.
    backbone_kwargs: dict[str, Any] = field(default_factory=dict)
54
+
55
+
56
@dataclass
class LossAndNllOutput(OrderedDict):
    """Loss output for denoiser models."""

    loss: torch.FloatTensor  # scalar training loss
    nlls: torch.FloatTensor  # per-token negative log-likelihoods
    # Any auxiliary loss terms a subclass wants to log/track.
    other_loss_terms: dict = field(default_factory=dict)
63
+
64
+
65
@dataclass
class DenoiserOutput(ModelOutput):
    """Output of the denoiser model."""

    denoiser_output: Optional[torch.FloatTensor] = None  # log-probs from _forward
    logits: Optional[torch.FloatTensor] = None  # raw backbone logits
    tokens_mask: Optional[torch.FloatTensor] = None  # Which tokens contribute to loss
    past_key_values: Optional[Cache] = None  # updated KV cache, if any
    loss: Optional[torch.FloatTensor] = None  # None when compute_loss=False
    nlls: Optional[torch.FloatTensor] = None  # None when compute_loss=False
    other_loss_terms: Optional[dict[str, Any]] = None
76
+
77
+
78
class DenoiserConfig(PretrainedConfig):
    """Configuration class for Denoiser models.

    This class is used to initialize the model and contains all the necessary
    parameters for the model's architecture.
    """

    model_type = "denoiser"

    def __init__(
        self,
        length: Optional[int] = None,
        backbone_config: Optional[Dict[str, Any]] = None,
        noise_config: Optional[Dict[str, Any]] = None,
        tokenization_config: Optional[Dict[str, Any]] = None,
        time_conditioned_backbone: Optional[bool] = None,
        attn_backend: str = "sdpa",  # "sdpa", "flash_attention_2", "flex_attention"
        train_on_context: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Mirror tokenization-derived attributes onto the config. A value from
        # tokenization_config wins when the attribute is unset or the key exists.
        for v in [
            "vocab_size",
            "mask_token_id",
            "pad_token_id",
            "bos_token_id",
            "eos_token_id",
            "pad_vocab_size_multiple",
        ]:
            if tokenization_config is not None and (
                getattr(self, v, None) is None or v in tokenization_config
            ):
                setattr(self, v, tokenization_config.get(v, None))
            else:
                # NOTE(review): this branch also runs when the attribute was
                # already set (e.g. via **kwargs through super().__init__) but
                # the key is absent from tokenization_config, overwriting the
                # existing value with None — confirm this clobbering is intended.
                setattr(self, v, None)
        self.backbone_config = backbone_config
        self.noise_config = noise_config
        self.tokenization_config = tokenization_config
        self.length = length
        self.time_conditioned_backbone = time_conditioned_backbone
        self.attn_backend = attn_backend
        self.train_on_context = train_on_context
120
+
121
+
122
class Denoiser(ABC, PreTrainedModel):
    """Abstract base class for denoising models.

    This class defines the interface for AR, Diffusion, and Flow-based
    parametrizations. Subclasses must implement `_prepare_inputs`,
    `_compute_loss`, and `generate`.
    """

    config_class = DenoiserConfig

    def __init__(
        self,
        config: DenoiserConfig,
        **kwargs,
    ):
        """
        Initialize the Denoiser with a configuration and optional dataset type.

        Parameters:
            config (DenoiserConfig): Configuration object for the model.
        """
        super().__init__(config)
        self.config = config
        # Mirror tokenization ids from the config for convenient access.
        self.vocab_size = config.vocab_size
        self.mask_token_id = config.mask_token_id
        self.pad_token_id = config.pad_token_id
        self.bos_token_id = config.bos_token_id
        self.eos_token_id = config.eos_token_id
        try:
            self.backbone = hydra.utils.instantiate(config.backbone_config)
        except InstantiationException:
            # When using HF and `from_pretrained`, the modules specified in `_target_`
            # fields in our configs are already being imported under a name with the
            # following format: transformers_modules.<repo_id>.<commit_id>.
            # When hydra attempts to instantiate and calls importlib under the hood, the
            # desired module is not found.
            # The snippet below aliases the desired module, enabling seamless use of
            # `hydra.utils.instantiate`.
            sys_modules = copy.deepcopy(list(sys.modules.keys()))
            repo_root_module = ".".join(__name__.split(".")[:-1])
            for name in sys_modules:
                if name.startswith(repo_root_module):
                    short = name.split(".")[-1]
                    if short not in sys.modules:
                        sys.modules[short] = sys.modules[name]
            del sys_modules
            self.backbone = hydra.utils.instantiate(config.backbone_config)
        # NOTE(review): `tokenizer_name` is not declared by DenoiserConfig.__init__;
        # presumably it arrives via **kwargs on the config — confirm.
        self.tokenizer = AutoTokenizer.from_pretrained(
            config.tokenizer_name,
            trust_remote_code=True,
        )
        self.noise_schedule = (
            hydra.utils.instantiate(config.noise_config)
            if config.noise_config is not None
            else None
        )
        # If not explicitly configured, infer time conditioning from whether the
        # backbone's forward signature accepts a `noise` argument.
        self.time_conditioned_backbone = (
            config.time_conditioned_backbone
            if config.time_conditioned_backbone is not None
            else "noise" in inspect.getfullargspec(self.backbone.forward).args
        )
        # List that can contain any parameters that should not be pushed to HF,
        # e.g., registered buffers for static attention masks
        self.skip_params_for_push = []

    @abstractmethod
    def _prepare_inputs(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        context_mask: Optional[torch.FloatTensor] = None,
        t: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Cache] = None,
    ) -> DenoiserInput:
        """
        Prepare inputs for the model.

        Parameters:
            input_ids (LongTensor): Input tensor to the model.
            attention_mask (Optional[FloatTensor]): Attention mask for the model.
            context_mask (Optional[FloatTensor]): Indicator for context tokens.
            t (Optional[FloatTensor]): Time step for the model.
            past_key_values (Optional[Cache]): Past key values for the model.
        Returns:
            Denoiser inputs.
        """
        raise NotImplementedError("Denoiser subclasses must implement _prepare_inputs")

    def _prepare_inputs_inference(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        context: Optional[torch.LongTensor] = None,
        context_mask: Optional[torch.FloatTensor] = None,
        cache: Optional[Dict[str, Any]] = None,
        **backbone_kwargs: Any,
    ) -> Tuple[DenoiserInput, Dict[str, Any]]:
        """Prepare inputs at inference time (subclass responsibility).

        The commented-out reference implementation below sketches the intended
        default behavior for subclasses.
        """
        raise NotImplementedError(
            "Denoiser subclasses must implement _prepare_inputs_inference"
        )
        # assert input_ids is not None or context is not None, (
        #     "Must provide either input_ids or context."
        # )
        # cache = cache if cache is not None else {}
        # past_key_values = cache.pop("past_key_values", DynamicCache())
        # if context is not None:
        #     if input_ids is not None:
        #         if context_mask is None:
        #             context_mask = torch.cat(
        #                 [torch.ones_like(context), torch.zeros_like(input_ids)], dim=-1
        #             )
        #         input_ids = torch.cat([context, input_ids], dim=-1)
        #     else:
        #         input_ids = context
        #         context_mask = torch.ones_like(input_ids)
        # if attention_mask is None:
        #     cache_length = self._get_past_key_values_seq_length(past_key_values)
        #     full_seq_length = cache_length + input_ids.shape[-1]
        #     attention_mask = torch.ones(
        #         (input_ids.shape[0], 1, input_ids.shape[1], full_seq_length),
        #         device=input_ids.device,
        #     )  # Make attention mask 4D
        # attention_mask = self._preprocess_attention_mask(
        #     attention_mask, dtype=torch.float
        # )
        # return DenoiserInput(
        #     xt=input_ids,
        #     attention_mask=attention_mask,
        #     past_key_values=past_key_values,
        #     context_mask=context_mask,
        #     backbone_kwargs=backbone_kwargs,
        # ), cache

    @abstractmethod
    def _compute_loss(
        self,
        model_output: torch.FloatTensor,
        denoiser_inputs: DenoiserInput,
        **kwargs: Any,
    ) -> LossAndNllOutput:
        """
        Compute the loss for the denoising model.

        Parameters:
            model_output (FloatTensor): Output tensor from self.forward.
            denoiser_inputs (DenoiserInput): Inputs passed to the denoiser model.

        Returns:
            LossAndNllOutput: loss (FloatTensor) and nlls (FloatTensor).
        """
        raise NotImplementedError("Denoiser subclasses must implement _compute_loss")

    def _forward(
        self,
        backbone_output: torch.FloatTensor,
        denoiser_inputs: DenoiserInput,
        **kwargs: Any,
    ) -> torch.FloatTensor:
        """
        Forward pass for the denoiser model returns probabilities over denoised
        sequence.

        Some classes may need to override this method.

        Parameters:
            backbone_output (FloatTensor): Output tensor from the backbone model.
            denoiser_inputs (DenoiserInput): Inputs passed to the denoiser model.

        Returns:
            Model outputs (FloatTensor): log-probabilities over the vocabulary.
        """
        return torch.log_softmax(backbone_output, dim=-1)  # type: ignore

    def _backbone_forward(
        self,
        denoiser_inputs: DenoiserInput,
        **backbone_kwargs: Any,
    ) -> ModelOutput:
        """Forward pass for the backbone model (should return logits).

        Some classes may need to override this method.

        Parameters:
            denoiser_inputs (DenoiserInput): Inputs passed to the denoiser model.
            return_updated_cache (bool): If True, return past_key_values instead of
                logits.

        Returns:
            Backbone output (ModelOutput instance).
        """
        # Time-conditioned backbones additionally receive alpha_t as `noise`.
        if self.time_conditioned_backbone:
            return self.backbone(
                denoiser_inputs.xt,
                attention_mask=denoiser_inputs.attention_mask,
                past_key_values=denoiser_inputs.past_key_values,
                noise=denoiser_inputs.alpha_t,
                **denoiser_inputs.backbone_kwargs,
                **backbone_kwargs,
            )
        return self.backbone(
            denoiser_inputs.xt,
            attention_mask=denoiser_inputs.attention_mask,
            past_key_values=denoiser_inputs.past_key_values,
            **denoiser_inputs.backbone_kwargs,
            **backbone_kwargs,
        )

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        context_mask: Optional[torch.FloatTensor] = None,
        t: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Cache] = None,
        compute_loss: Optional[bool] = True,
        **kwargs,
    ) -> DenoiserOutput:
        """
        Perform a forward pass through the denoising model and
        (optionally) compute the loss.

        Parameters:
            input_ids (LongTensor): Input tensor to the model.
            attention_mask (Optional[FloatTensor]): Attention mask for the model.
            context_mask (Optional[FloatTensor]): Indicator for context tokens.
            t (Optional[FloatTensor]): Denoising time step for the model.
            past_key_values (Optional[Cache]): KV cache.
            compute_loss (Optional[bool]): Flag to compute loss.

        Returns:
            DenoiserOutput
        """
        denoiser_inputs = self._prepare_inputs(
            input_ids=input_ids,
            attention_mask=attention_mask,
            context_mask=context_mask,
            past_key_values=past_key_values,
            t=t,
        )

        backbone_output = self._backbone_forward(denoiser_inputs, **kwargs)
        # Extract cache and logits robustly from either a ModelOutput or tuple.
        new_past_key_values = getattr(backbone_output, "past_key_values", None)
        backbone_output = getattr(backbone_output, "logits", backbone_output[0])
        denoiser_output = self._forward(
            backbone_output,
            denoiser_inputs,
            **kwargs,
        )

        if compute_loss:
            loss_and_nll = self._compute_loss(
                model_output=denoiser_output, denoiser_inputs=denoiser_inputs, **kwargs
            )
            loss = loss_and_nll.loss
            nlls = loss_and_nll.nlls
            other_loss_terms = loss_and_nll.other_loss_terms
        else:
            loss, nlls = None, None
            other_loss_terms = {}

        return DenoiserOutput(
            denoiser_output=denoiser_output,
            logits=backbone_output,
            past_key_values=new_past_key_values,
            tokens_mask=denoiser_inputs.tokens_mask,
            loss=loss,
            nlls=nlls,
            other_loss_terms=other_loss_terms,
        )

    @staticmethod
    def _sample_categorical(categorical_probs, do_sample=True):
        """Helper function to sample from a categorical distribution.

        Uses the Gumbel-max trick: dividing probabilities by exponential noise
        and taking argmax is equivalent to sampling from the distribution.
        When do_sample is False, returns the argmax (greedy) instead.
        """
        categorical_probs = categorical_probs.to(torch.float64)
        if not do_sample:
            return categorical_probs.argmax(dim=-1)
        gumbel_norm = (1e-10 - (torch.rand_like(categorical_probs) + 1e-10).log()).to(
            categorical_probs.dtype
        )
        return (categorical_probs / gumbel_norm).argmax(dim=-1)

    @staticmethod
    def _preprocess_attention_mask(attention_mask, dtype):
        """Convert a 0/1 attention mask into an additive mask in `dtype`:
        masked (0) positions become the dtype's most-negative value, kept
        positions become 0.0."""
        min_dtype = torch.finfo(dtype).min
        attention_mask = torch.where(
            (attention_mask == 0.0).bool(),  # type: ignore
            min_dtype,
            0.0,
        ).to(dtype)
        return attention_mask

    @staticmethod
    def _get_past_key_values_seq_length(past_key_values: DynamicCache):
        """Return the maximum cached sequence length across all layers
        (0 when the cache is empty)."""
        seq_length = 0
        for i in range(len(past_key_values)):
            if past_key_values[i][0].shape[0] > 0:  # type: ignore
                seq_length = max(
                    past_key_values[i][0].shape[-2],  # type: ignore
                    seq_length,
                )
        return seq_length

    def update_cache(
        self,
        inputs: torch.LongTensor,
        cache: Optional[Dict[str, Any]] = None,
        **backbone_kwargs: Any,
    ) -> Dict[str, Any]:
        """
        Cache the key-value pairs for the context.
        Args:
            inputs (torch.LongTensor): The context tensor.
            cache (Dict[str, Any] | None): Cache objects, e.g., past_key_values.
        Returns:
            Dict: Updated cache objects, e.g., past_key_values.
        """
        context_input, cache = self._prepare_inputs_inference(
            input_ids=inputs, cache=cache, return_updated_cache=True, **backbone_kwargs
        )
        backbone_output = self._backbone_forward(
            context_input,
            return_updated_cache=True,  # Will get absorbed in backbone_kwargs
            **cache,
        )
        backbone_output = {k: v for k, v in backbone_output.items()}
        backbone_output.pop("logits", None)  # Do not store logits in cache
        # Merge new cache entries over the existing ones (dict union).
        cache = cache | backbone_output
        return cache

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.LongTensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        max_length: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
        batch_size: Optional[int] = None,
        device: Optional[str] = None,
        **kwargs: Any,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        """Generates sample from denoising model.
        Follows signature of transformers.GenerationMixin.
        """
        raise NotImplementedError("Denoiser subclasses must implement generate")
diffusion.py CHANGED
@@ -21,7 +21,7 @@ except ImportError:
21
  BlockMask, and_masks, create_block_mask = None, None, None
22
 
23
 
24
- from src.denoiser.base import (
25
  Denoiser,
26
  DenoiserConfig,
27
  DenoiserInput,
 
21
  BlockMask, and_masks, create_block_mask = None, None, None
22
 
23
 
24
+ from .denoiser_base import (
25
  Denoiser,
26
  DenoiserConfig,
27
  DenoiserInput,
noise_schedule_noise_schedules.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+
3
+ import torch
4
+
5
+
6
class Noise(ABC):
    """
    Baseline forward method to get noise parameters at a timestep.

    Subclasses implement `__call__(t)` returning a pair; the concrete
    schedules in this module return `(alpha_t, alpha_t_prime)` where
    `alpha_t = 1 - move_chance`.
    """

    def __call__(
        self, t: torch.Tensor | float
    ) -> tuple[torch.Tensor | float, torch.Tensor | float]:
        # Assume time goes from 0 to 1
        pass

    @abstractmethod
    def inverse(self, alpha_t: torch.Tensor) -> torch.Tensor:
        """
        Inverse function to compute the timestep t from the noise schedule param.

        NOTE(review): because this is abstract, any subclass that does not
        override it cannot be instantiated (ABC rules) — confirm every
        schedule used in configs implements it.
        """
        raise NotImplementedError("Inverse function not implemented")
23
+
24
+
25
class CosineNoise(Noise):
    """Cosine noise schedule: alpha_t = (1 - eps) * cos(t * pi / 2).

    Fix: the abstract `inverse` was not implemented, so instantiating this
    class raised TypeError (ABC with unimplemented abstract method). The
    analytic inverse of the schedule is added below.
    """

    def __init__(self, eps=1e-3):
        super().__init__()
        self.eps = eps  # keeps alpha_t strictly below 1 at t=0
        self.name = "cosine"

    def inverse(self, alpha_t):
        """Return t such that alpha_t = (1 - eps) * cos(t * pi / 2).

        The ratio is clamped to [-1, 1] to guard arccos against values of
        alpha_t marginally outside the schedule's range.
        """
        ratio = torch.clamp(alpha_t / (1 - self.eps), -1.0, 1.0)
        return torch.arccos(ratio) * 2 / torch.pi

    def __call__(self, t):
        """Return (alpha_t, alpha_t_prime) at timestep t in [0, 1]."""
        t = t.to(torch.float32)
        cos = -(1 - self.eps) * torch.cos(t * torch.pi / 2)
        sin = -(1 - self.eps) * torch.sin(t * torch.pi / 2)
        move_chance = cos + 1
        # d/dt alpha_t = -(1 - eps) * sin(t * pi / 2) * pi / 2
        alpha_t_prime = sin * torch.pi / 2
        return 1 - move_chance, alpha_t_prime
38
+
39
+
40
class ExponentialNoise(Noise):
    """Polynomial ("exponential") schedule: alpha_t = 1 - clamp(t**exp, min=eps).

    Fixes:
    - `__call__` returned `(alpha_t_prime, alpha_t)`, swapped relative to every
      sibling schedule (Cosine/Logarithmic/Linear all return
      `(1 - move_chance, alpha_t_prime)`); the order is corrected.
    - The abstract `inverse` was not implemented, so instantiation raised
      TypeError; the analytic inverse is added.
    """

    def __init__(self, exp=2, eps=1e-3):
        super().__init__()
        self.eps = eps  # lower clamp on move_chance
        self.exp = exp  # polynomial degree of the schedule
        self.name = f"exp_{exp}"

    def inverse(self, alpha_t):
        """Return t such that alpha_t = 1 - t**exp (ignoring the eps clamp)."""
        return torch.pow(1 - alpha_t, 1.0 / self.exp)

    def __call__(self, t):
        """Return (alpha_t, alpha_t_prime) at timestep t in [0, 1]."""
        t = t.to(torch.float32)
        move_chance = torch.pow(t, self.exp)
        move_chance = torch.clamp(move_chance, min=self.eps)
        # d/dt alpha_t = -exp * t**(exp - 1)
        alpha_t_prime = -self.exp * torch.pow(t, self.exp - 1)
        return 1 - move_chance, alpha_t_prime
53
+
54
+
55
class LogarithmicNoise(Noise):
    """Logarithmic schedule: alpha_t = 1 - log2(1 + t).

    Fix: the abstract `inverse` was not implemented, so instantiating this
    class raised TypeError (ABC with unimplemented abstract method). The
    analytic inverse is added below.
    """

    def __init__(self, eps=1e-3):
        super().__init__()
        self.eps = eps  # NOTE(review): unused by this schedule — kept for API parity
        self.name = "logarithmic"

    def inverse(self, alpha_t):
        """Return t such that alpha_t = 1 - log2(1 + t), i.e. t = 2**(1 - alpha_t) - 1."""
        return torch.exp2(1 - alpha_t) - 1

    def __call__(self, t):
        """Return (alpha_t, alpha_t_prime) at timestep t in [0, 1]."""
        t = t.to(torch.float32)
        move_chance = torch.log1p(t) / torch.log(torch.tensor(2.0))
        # d/dt alpha_t = -1 / (ln(2) * (1 + t))
        alpha_t_prime = -1 / (torch.log(torch.tensor(2.0)) * (1 + t))
        return 1 - move_chance, alpha_t_prime
66
+
67
+
68
class LinearNoise(Noise):
    """Linear schedule: alpha_t = 1 - t, with constant derivative -1."""

    def __init__(self):
        super().__init__()
        self.name = "linear"

    def inverse(self, alpha_t):
        """Return t such that alpha_t = 1 - t."""
        return 1 - alpha_t

    def __call__(self, t):
        """Return (alpha_t, alpha_t_prime) at timestep t in [0, 1]."""
        t = t.to(torch.float32)
        return 1 - t, -torch.ones_like(t)