Add code/cube3d/model/gpt/dual_stream_roformer.py
code/cube3d/model/gpt/dual_stream_roformer.py
ADDED
@@ -0,0 +1,367 @@
from dataclasses import dataclass
from typing import Optional

import torch
from torch import nn

from cube3d.model.transformers.cache import Cache
from cube3d.model.transformers.dual_stream_attention import (
    DualStreamDecoderLayerWithRotaryEmbedding,
)
from cube3d.model.transformers.norm import LayerNorm
from cube3d.model.transformers.roformer import DecoderLayerWithRotaryEmbedding
from cube3d.model.transformers.rope import precompute_freqs_cis


class DualStreamRoformer(nn.Module):
    @dataclass
    class Config:
        checkpoint_path: str = ""
        n_layer: int = 12
        n_single_layer: int = 0
        rope_theta: float = 1000

        n_head: int = 16
        n_embd: int = 2048
        bias: bool = False  # bias in Linears and LayerNorms
        eps: float = 1e-6  # norm eps

        shape_model_vocab_size: int = 4096
        shape_model_embed_dim: int = 16

        text_model_embed_dim: int = 512
        use_pooled_text_embed: bool = False

        encoder_with_cls_token: bool = True

        use_bbox: bool = False

        ldr_in_embed_dim: int = 2048
        ldr_out_embed_dim: int = 2048

    def __init__(self, cfg: Config) -> None:
        """
        Initializes the DualStreamRoformer model.

        Args:
            cfg (Config): Configuration object containing model parameters.

        Attributes:
            cfg (Config): Stores the configuration object.
            text_proj (nn.Linear): Linear layer projecting text model embeddings to the model embedding dimension.
            shape_proj (nn.Linear): Linear layer projecting shape model embeddings to the model embedding dimension.
            ldr_proj (nn.Linear): Linear layer projecting ldr embeddings to the model embedding dimension.
            vocab_size (int): Vocabulary size for the shape model, including special tokens.
            shape_bos_id (int): Token ID for the beginning-of-sequence (BOS) token for the shape model.
            shape_eos_id (int): Token ID for the end-of-sequence (EOS) token for the shape model.
            padding_id (int): Token ID for the padding token.
            dte, rte, xte, yte, zte (nn.Embedding): Embedding tables for part, rotation, and x/y/z position bins.
            transformer (nn.ModuleDict): Dictionary containing the following components:
                - wte (nn.Embedding): Embedding layer for the vocabulary.
                - dual_blocks (nn.ModuleList): List of dual-stream decoder layers with rotary embeddings.
                - single_blocks (nn.ModuleList): List of single-stream decoder layers with rotary embeddings.
                - ln_f (LayerNorm): Layer normalization applied to the final output.
            lm_head (nn.Linear): Linear layer mapping the final embeddings to the vocabulary size for language modeling.
            ldr_head (nn.Linear): Linear layer mapping the final embeddings to the ldr output dimension.
        """
        super().__init__()

        self.cfg = cfg

        self.text_proj = nn.Linear(
            in_features=self.cfg.text_model_embed_dim,
            out_features=self.cfg.n_embd,
            bias=self.cfg.bias,
        )

        self.shape_proj = nn.Linear(self.cfg.shape_model_embed_dim, self.cfg.n_embd)

        self.ldr_proj = nn.Linear(self.cfg.ldr_in_embed_dim, self.cfg.n_embd)

        self.vocab_size = self.cfg.shape_model_vocab_size

        # Number of discrete bins for x/y/z positions and rotations.
        x_num = 251
        y_num = 215
        z_num = 525
        rot_num = 24

        self.x_num = x_num
        self.y_num = y_num
        self.z_num = z_num
        self.rot_num = rot_num

        # Cumulative offsets over the position/rotation bins.
        self.x = x_num
        self.xy = x_num + y_num + rot_num
        self.xyz = x_num + y_num + z_num + rot_num
        self.dat_num = 1217

        # Embedding tables; each reserves one index (padding_idx) for padding.
        self.dte = nn.Embedding(
            self.dat_num + 1,
            self.cfg.n_embd,
            padding_idx=self.dat_num,
        )

        self.rte = nn.Embedding(
            self.rot_num + 2,
            self.cfg.n_embd,
            padding_idx=self.rot_num,
        )

        self.xte = nn.Embedding(
            self.x_num + 2,
            self.cfg.n_embd,
            padding_idx=self.x_num,
        )

        self.yte = nn.Embedding(
            self.y_num + 2,
            self.cfg.n_embd,
            padding_idx=self.y_num,
        )

        self.zte = nn.Embedding(
            self.z_num + 2,
            self.cfg.n_embd,
            padding_idx=self.z_num,
        )

        self.is_compute = False

        def add_special_token():
            token_id = self.vocab_size
            self.vocab_size += 1
            return token_id

        self.shape_bos_id = add_special_token()  # 16384
        self.shape_eos_id = add_special_token()  # 16385
        self.padding_id = add_special_token()  # 16386

        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Embedding(
                    self.vocab_size,
                    self.cfg.n_embd,
                    padding_idx=self.padding_id,
                ),
                dual_blocks=nn.ModuleList(
                    [
                        DualStreamDecoderLayerWithRotaryEmbedding.from_config(
                            self.cfg, cond_pre_only=(i == self.cfg.n_layer - 1)
                        )
                        for i in range(self.cfg.n_layer)
                    ]
                ),
                single_blocks=nn.ModuleList(
                    [
                        DecoderLayerWithRotaryEmbedding.from_config(self.cfg)
                        for _ in range(self.cfg.n_single_layer)
                    ]
                ),
                ln_f=LayerNorm(
                    self.cfg.n_embd, elementwise_affine=False, eps=self.cfg.eps
                ),
            )
        )

        self.lm_head = nn.Linear(self.cfg.n_embd, self.vocab_size, bias=False)
        self.ldr_head = nn.Linear(self.cfg.n_embd, self.cfg.ldr_out_embed_dim, bias=False)

        if self.cfg.use_bbox:
            self.bbox_proj = nn.Linear(3, self.cfg.n_embd)

    def encode_embed(self, ldr_embed):
        """
        Encodes the given ldr embeddings by projecting them through a linear transformation.

        Args:
            ldr_embed (torch.Tensor): A tensor representing the ldr embeddings to be encoded.

        Returns:
            torch.Tensor: The projected ldr embeddings after applying the linear transformation.
        """
        return self.ldr_proj(ldr_embed)

    def encode_text(self, text_embed):
        """
        Encodes the given text embeddings by projecting them through a linear transformation.

        Args:
            text_embed (torch.Tensor): A tensor representing the text embeddings to be encoded.

        Returns:
            torch.Tensor: The projected text embeddings after applying the linear transformation.
        """
        return self.text_proj(text_embed)

    def encode_token(self, tokens):
        """
        Encodes the input tokens using the word token embedding layer of the transformer model.

        Args:
            tokens (torch.Tensor): A tensor containing the input tokens to be encoded.

        Returns:
            torch.Tensor: A tensor containing the encoded token embeddings.
        """
        return self.transformer.wte(tokens)

    def init_kv_cache(
        self,
        batch_size: int,
        cond_len: int,
        max_shape_tokens: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> list[Cache]:
        """
        Initializes the key-value cache for the transformer model.

        This method creates a list of `Cache` objects to store the key and value
        states for both dual-stream and single-stream transformer blocks. The
        cache is pre-allocated with zeros and is used to optimize the computation
        of attention mechanisms during model inference.

        Args:
            batch_size (int): The batch size for the input data.
            cond_len (int): The length of the conditioning sequence.
            max_shape_tokens (int): The maximum number of tokens in the shape sequence.
            dtype (torch.dtype): The data type for the tensors (e.g., torch.float32).
            device (torch.device): The device on which the tensors will be allocated
                (e.g., torch.device('cuda') or torch.device('cpu')).

        Returns:
            list[Cache]: A list of `Cache` objects containing pre-allocated key and
                value states for each transformer block.
        """
        num_heads = self.cfg.n_head
        max_all_tokens = cond_len + max_shape_tokens
        per_head_dim = self.cfg.n_embd // num_heads

        # Dual-stream blocks attend over conditioning + shape tokens.
        kv_cache = [
            Cache(
                key_states=torch.zeros(
                    (batch_size, num_heads, max_all_tokens, per_head_dim),
                    dtype=dtype,
                    device=device,
                ),
                value_states=torch.zeros(
                    (batch_size, num_heads, max_all_tokens, per_head_dim),
                    dtype=dtype,
                    device=device,
                ),
            )
            for _ in range(len(self.transformer.dual_blocks))
        ]
        # Single-stream blocks attend over shape tokens only.
        kv_cache += [
            Cache(
                key_states=torch.zeros(
                    (batch_size, num_heads, max_shape_tokens, per_head_dim),
                    dtype=dtype,
                    device=device,
                ),
                value_states=torch.zeros(
                    (batch_size, num_heads, max_shape_tokens, per_head_dim),
                    dtype=dtype,
                    device=device,
                ),
            )
            for _ in range(len(self.transformer.single_blocks))
        ]
        return kv_cache

    def forward(
        self,
        embed: torch.Tensor,
        cond: torch.Tensor,
        kv_cache: Optional[list[Cache]] = None,
        curr_pos_id: Optional[torch.Tensor] = None,
        decode: bool = False,
        **kwargs,
    ):
        """
        Forward pass for the dual-stream RoFormer model.

        Args:
            embed (torch.Tensor): The input embedding tensor.
            cond (torch.Tensor): The conditioning tensor.
            kv_cache (Optional[list[Cache]]): A list of key-value caches for each layer, used for decoding. Default is None.
            curr_pos_id (Optional[torch.Tensor]): The current position ID tensor of shape (batch_size,). Required if `decode` is True. Default is None.
            decode (bool): Whether the model is in decoding mode. Default is False.

        Returns:
            torch.Tensor: The output logits tensor.
        """
        b, l = embed.shape[:2]
        s = cond.shape[1]
        device = embed.device

        # Causal attention mask (disabled):
        # attn_mask = torch.tril(
        #     torch.ones(s + l, s + l, dtype=torch.bool, device=device)
        # )
        # Full attention mask (no masking):
        attn_mask = torch.ones(s + l, s + l, dtype=torch.bool, device=device)

        # Optional masking of selected positions (disabled):
        # positions = torch.arange(s + l, device=device)
        # mask_1d = (positions > 1) & ((positions % 5 == 0) | (positions % 5 == 1) | (positions % 5 == 4))
        # attn_mask[mask_1d, :] = False
        # attn_mask[:, mask_1d] = False

        # Rotary position ids for the shape stream only, shape (b, l).
        position_ids = torch.arange(l, dtype=torch.long, device=device)
        position_ids = position_ids.unsqueeze_(0).expand(b, -1)

        s_freqs_cis = precompute_freqs_cis(
            dim=self.cfg.n_embd // self.cfg.n_head,
            t=position_ids,
            theta=self.cfg.rope_theta,
        )

        # Full position ids: conditioning tokens are pinned to position 0.
        position_ids = torch.cat(
            [
                torch.zeros([b, s], dtype=torch.long, device=position_ids.device),
                position_ids,
            ],
            dim=1,
        )

        # Full position embedding over conditioning + shape tokens.
        d_freqs_cis = precompute_freqs_cis(
            dim=self.cfg.n_embd // self.cfg.n_head,
            t=position_ids,
            theta=self.cfg.rope_theta,
        )

        if kv_cache is not None and decode:
            assert curr_pos_id is not None
            embed = embed[:, curr_pos_id, :]

        h = embed
        c = cond

        layer_idx = 0
        for block in self.transformer.dual_blocks:
            h, c = block(
                h,
                c=c,
                freqs_cis=d_freqs_cis,
                attn_mask=attn_mask,
                is_causal=True,
                kv_cache=kv_cache[layer_idx] if kv_cache is not None else None,
                curr_pos_id=curr_pos_id + s if curr_pos_id is not None else None,
                decode=decode,
            )
            layer_idx += 1

        for block in self.transformer.single_blocks:
            h = block(
                h,
                freqs_cis=s_freqs_cis,
                attn_mask=None,
                is_causal=True,
                kv_cache=kv_cache[layer_idx] if kv_cache is not None else None,
                curr_pos_id=curr_pos_id,
                decode=decode,
            )
            layer_idx += 1

        # Final normalization and output projection.
        h = self.transformer.ln_f(h)
        logits = self.ldr_head(h)

        return logits
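
Usage sketch (not part of the added file): a minimal, hypothetical example of driving DualStreamRoformer through a prefill pass and one cached decode step. It assumes the cube3d package is importable and that the imported decoder-layer and Cache classes accept the interfaces used by forward() above; the sizes below are illustrative only.

import torch

from cube3d.model.gpt.dual_stream_roformer import DualStreamRoformer

# Small illustrative config; real checkpoints use the defaults in Config.
cfg = DualStreamRoformer.Config(
    n_layer=2,
    n_single_layer=1,
    n_head=4,
    n_embd=64,
    text_model_embed_dim=32,
    ldr_in_embed_dim=64,
    ldr_out_embed_dim=64,
)
model = DualStreamRoformer(cfg).eval()

b, s, l = 2, 8, 16  # batch size, conditioning length, shape-sequence length
cond = model.encode_text(torch.randn(b, s, cfg.text_model_embed_dim))
embed = model.encode_embed(torch.randn(b, l, cfg.ldr_in_embed_dim))

# Prefill pass without a cache; logits cover every shape position.
with torch.no_grad():
    logits = model(embed, cond)  # (b, l, cfg.ldr_out_embed_dim)

# One decode step with a pre-allocated KV cache; only the current position is processed.
kv_cache = model.init_kv_cache(
    batch_size=b,
    cond_len=s,
    max_shape_tokens=l,
    dtype=embed.dtype,
    device=embed.device,
)
with torch.no_grad():
    curr_pos_id = torch.tensor([0], device=embed.device)
    step_logits = model(
        embed, cond, kv_cache=kv_cache, curr_pos_id=curr_pos_id, decode=True
    )  # (b, 1, cfg.ldr_out_embed_dim)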