fix modeling_plamo.py
#1
by
ssuzuki65
- opened
- modeling_plamo.py +17 -27
modeling_plamo.py
CHANGED
|
@@ -240,6 +240,8 @@ class PlamoCache(torch.nn.Module):
|
|
| 240 |
|
| 241 |
def append_kv(self, key: torch.Tensor, value: torch.Tensor, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]:
|
| 242 |
c = self.cache[layer_idx]
|
|
|
|
|
|
|
| 243 |
assert isinstance(c, PlamoAttentionCache)
|
| 244 |
|
| 245 |
def _validate(cache: torch.Tensor, new_tensor: torch.Tensor) -> None:
|
|
@@ -257,11 +259,17 @@ class PlamoCache(torch.nn.Module):
|
|
| 257 |
def update_attention(
|
| 258 |
self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int
|
| 259 |
) -> PlamoAttentionCache:
|
|
|
|
|
|
|
|
|
|
| 260 |
if self.cache[layer_idx] is None:
|
| 261 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 262 |
else:
|
| 263 |
-
full_attn = layer_idx in self.config.full_attention_idx
|
| 264 |
-
window_size = self.config.attention_window_size
|
| 265 |
c = self.cache[layer_idx]
|
| 266 |
assert isinstance(c, PlamoAttentionCache)
|
| 267 |
k, v = self.append_kv(key_states, value_states, layer_idx)
|
|
@@ -968,15 +976,6 @@ class Attention(torch.nn.Module):
|
|
| 968 |
query_states = _rms_norm(query_states, None, 1e-6) * self.q_weight[None, :, None]
|
| 969 |
key_states = _rms_norm(key_states, None, 1e-6) * self.k_weight[None, :, None]
|
| 970 |
|
| 971 |
-
if past_states is not None and past_states[self.layer_idx] is None:
|
| 972 |
-
bsz, nhead_k, _, c_k = key_states.shape
|
| 973 |
-
_, nhead_v, _, c_v = value_states.shape
|
| 974 |
-
past_states.update_attention(
|
| 975 |
-
torch.zeros((bsz, nhead_k, 0, c_k), dtype=key_states.dtype, device=key_states.device),
|
| 976 |
-
torch.zeros((bsz, nhead_v, 0, c_v), dtype=value_states.dtype, device=value_states.device),
|
| 977 |
-
self.layer_idx,
|
| 978 |
-
)
|
| 979 |
-
|
| 980 |
if past_states is not None:
|
| 981 |
# reuse k, v, self_attention
|
| 982 |
key_states_new = key_states
|
|
@@ -1154,6 +1153,7 @@ class PlamoDecoder(torch.nn.Module):
|
|
| 1154 |
for i in range(config.num_hidden_layers)
|
| 1155 |
]
|
| 1156 |
)
|
|
|
|
| 1157 |
|
| 1158 |
def forward(self, x: DecoderInput) -> DecoderOutput:
|
| 1159 |
all_hidden_states: Optional[Tuple[torch.Tensor, ...]] = () if x.output_hidden_states else None
|
|
@@ -1166,19 +1166,12 @@ class PlamoDecoder(torch.nn.Module):
|
|
| 1166 |
all_hidden_states += (hidden_states,)
|
| 1167 |
|
| 1168 |
if self.training and x.gradient_checkpointing:
|
| 1169 |
-
|
| 1170 |
-
|
| 1171 |
-
def custom_forward(*inputs): # type: ignore
|
| 1172 |
-
# None for past_key_value
|
| 1173 |
-
return module(*inputs, x.output_attentions, None)
|
| 1174 |
-
|
| 1175 |
-
return custom_forward
|
| 1176 |
-
|
| 1177 |
-
layer_outputs = torch.utils.checkpoint.checkpoint(
|
| 1178 |
-
create_custom_forward(decoder_layer), # type: ignore
|
| 1179 |
hidden_states,
|
| 1180 |
x.attention_mask,
|
| 1181 |
-
|
|
|
|
| 1182 |
)
|
| 1183 |
else:
|
| 1184 |
layer_outputs = decoder_layer(
|
|
@@ -1217,9 +1210,6 @@ class PlamoPreTrainedModel(PreTrainedModel): # type: ignore
|
|
| 1217 |
if module.padding_idx is not None:
|
| 1218 |
module.weight.data[module.padding_idx].zero_()
|
| 1219 |
|
| 1220 |
-
def _set_gradient_checkpointing(self, module: torch.nn.Module, value: bool = False) -> None:
|
| 1221 |
-
module.gradient_checkpointing = value # type: ignore
|
| 1222 |
-
|
| 1223 |
|
| 1224 |
class PlamoModel(PlamoPreTrainedModel):
|
| 1225 |
def __init__(self, config: PlamoConfig):
|
|
@@ -1613,4 +1603,4 @@ class Bias(nn.Module):
|
|
| 1613 |
self,
|
| 1614 |
x: torch.Tensor,
|
| 1615 |
) -> torch.Tensor:
|
| 1616 |
-
return x + self._bias
|
|
|
|
| 240 |
|
| 241 |
def append_kv(self, key: torch.Tensor, value: torch.Tensor, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]:
|
| 242 |
c = self.cache[layer_idx]
|
| 243 |
+
if c is None:
|
| 244 |
+
return key, value
|
| 245 |
assert isinstance(c, PlamoAttentionCache)
|
| 246 |
|
| 247 |
def _validate(cache: torch.Tensor, new_tensor: torch.Tensor) -> None:
|
|
|
|
| 259 |
def update_attention(
|
| 260 |
self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int
|
| 261 |
) -> PlamoAttentionCache:
|
| 262 |
+
full_attn = layer_idx in self.config.full_attention_idx
|
| 263 |
+
window_size = self.config.attention_window_size
|
| 264 |
+
|
| 265 |
if self.cache[layer_idx] is None:
|
| 266 |
+
if full_attn:
|
| 267 |
+
self.cache[layer_idx] = PlamoAttentionCache(key_states, value_states)
|
| 268 |
+
else:
|
| 269 |
+
self.cache[layer_idx] = PlamoAttentionCache(
|
| 270 |
+
key_states[:, :, -window_size:, :], value_states[:, :, -window_size:, :]
|
| 271 |
+
)
|
| 272 |
else:
|
|
|
|
|
|
|
| 273 |
c = self.cache[layer_idx]
|
| 274 |
assert isinstance(c, PlamoAttentionCache)
|
| 275 |
k, v = self.append_kv(key_states, value_states, layer_idx)
|
|
|
|
| 976 |
query_states = _rms_norm(query_states, None, 1e-6) * self.q_weight[None, :, None]
|
| 977 |
key_states = _rms_norm(key_states, None, 1e-6) * self.k_weight[None, :, None]
|
| 978 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 979 |
if past_states is not None:
|
| 980 |
# reuse k, v, self_attention
|
| 981 |
key_states_new = key_states
|
|
|
|
| 1153 |
for i in range(config.num_hidden_layers)
|
| 1154 |
]
|
| 1155 |
)
|
| 1156 |
+
self.gradient_checkpointing = False
|
| 1157 |
|
| 1158 |
def forward(self, x: DecoderInput) -> DecoderOutput:
|
| 1159 |
all_hidden_states: Optional[Tuple[torch.Tensor, ...]] = () if x.output_hidden_states else None
|
|
|
|
| 1166 |
all_hidden_states += (hidden_states,)
|
| 1167 |
|
| 1168 |
if self.training and x.gradient_checkpointing:
|
| 1169 |
+
layer_outputs = self._gradient_checkpointing_func(
|
| 1170 |
+
decoder_layer.__call__,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1171 |
hidden_states,
|
| 1172 |
x.attention_mask,
|
| 1173 |
+
x.past_states,
|
| 1174 |
+
x.output_attentions,
|
| 1175 |
)
|
| 1176 |
else:
|
| 1177 |
layer_outputs = decoder_layer(
|
|
|
|
| 1210 |
if module.padding_idx is not None:
|
| 1211 |
module.weight.data[module.padding_idx].zero_()
|
| 1212 |
|
|
|
|
|
|
|
|
|
|
| 1213 |
|
| 1214 |
class PlamoModel(PlamoPreTrainedModel):
|
| 1215 |
def __init__(self, config: PlamoConfig):
|
|
|
|
| 1603 |
self,
|
| 1604 |
x: torch.Tensor,
|
| 1605 |
) -> torch.Tensor:
|
| 1606 |
+
return x + self._bias
|