Upload folder using huggingface_hub
- __pycache__/configuration_jet_nemotron.cpython-310.pyc +0 -0
- __pycache__/dconv_fwd_cache.cpython-310.pyc +0 -0
- __pycache__/dconv_fwdbwd.cpython-310.pyc +0 -0
- __pycache__/dconv_step.cpython-310.pyc +0 -0
- __pycache__/dynamic_conv.cpython-310.pyc +0 -0
- __pycache__/jet_block.cpython-310.pyc +0 -0
- __pycache__/kv_cache.cpython-310.pyc +0 -0
- jet_block.py +5 -8
__pycache__/configuration_jet_nemotron.cpython-310.pyc ADDED
Binary file (8.29 kB)

__pycache__/dconv_fwd_cache.cpython-310.pyc ADDED
Binary file (6.95 kB)

__pycache__/dconv_fwdbwd.cpython-310.pyc ADDED
Binary file (6.49 kB)

__pycache__/dconv_step.cpython-310.pyc ADDED
Binary file (4.21 kB)

__pycache__/dynamic_conv.cpython-310.pyc ADDED
Binary file (7.45 kB)

__pycache__/jet_block.cpython-310.pyc ADDED
Binary file (7.07 kB)

__pycache__/kv_cache.cpython-310.pyc ADDED
Binary file (6.23 kB)
jet_block.py CHANGED
@@ -46,7 +46,7 @@ class JetBlockConfig():
     head_dim: int = 256
     norm_eps: float = 1e-5
     conv_size: int = 4
-    dconv_generator_reduction: int =
+    dconv_generator_reduction: int = 8
     dconv_implementation: str = 'triton'
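The first hunk pins `dconv_generator_reduction` to a default of 8. In dynamic-convolution blocks of this kind, a reduction factor of this sort typically sets the width of the bottleneck in the kernel-generator MLP, trading generator capacity against parameter count. A minimal sketch of that pattern follows; the class name, layer layout, activation, and softmax normalization below are assumptions for illustration, not the actual contents of this repo's `dynamic_conv.py`:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class DynamicConvKernelGenerator(nn.Module):
    """Hypothetical bottleneck generator: hidden_size -> hidden_size // reduction -> per-head conv kernels."""

    def __init__(self, hidden_size: int, num_heads: int, conv_size: int = 4, reduction: int = 8):
        super().__init__()
        self.num_heads = num_heads
        self.conv_size = conv_size
        # The reduction factor shrinks the intermediate width (e.g. 2048 -> 256 with reduction=8).
        self.down = nn.Linear(hidden_size, hidden_size // reduction)
        self.up = nn.Linear(hidden_size // reduction, num_heads * conv_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, hidden_size) -> kernels: (batch, seq_len, num_heads, conv_size)
        kernels = self.up(F.silu(self.down(x)))
        kernels = kernels.view(*x.shape[:-1], self.num_heads, self.conv_size)
        # Dynamic-conv designs often softmax-normalize the kernel taps.
        return kernels.softmax(dim=-1)

gen = DynamicConvKernelGenerator(hidden_size=2048, num_heads=16, conv_size=4, reduction=8)
kernels = gen(torch.randn(2, 10, 2048))  # -> (2, 10, 16, 4)
```

Under these assumptions, a larger reduction shrinks the generator's parameter count roughly linearly at the cost of a narrower bottleneck.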
@@ -180,24 +180,21 @@ class JetBlock(nn.Module):
         if attention_mask is not None and q_len > 1:
             indices, cu_seqlens, _ = get_unpad_data(attention_mask[:, -q_len:])
 
-        conv_state = None
-
         conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
 
         q = F.silu(self.q_proj(hidden_states))
         k = F.silu(self.k_proj(hidden_states))
 
-
+        conv_state = None
         if last_state is not None:
-
-        v, conv_state_v = self.dynamic_conv1d(
+            conv_state = last_state['conv_state']
+        v, conv_state = self.dynamic_conv1d(
             x=self.v_proj(hidden_states),
             generator_input=hidden_states,
             mask=conv_mask,
-            cache=
+            cache=conv_state,
             output_final_state=use_cache,
         )
-        conv_state = conv_state + (conv_state_v,) if conv_state is not None else (conv_state_v,)
 
         if attention_mask is not None and q_len > 1:
             q = index_first_axis(rearrange(q, "b s ... -> (b s) ..."), indices).unsqueeze(0)
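The second hunk rewires how the convolution state is threaded through the cache during incremental decoding: `conv_state` is now read from `last_state['conv_state']`, passed into `dynamic_conv1d` via `cache=`, and the updated state comes back through the same variable instead of being accumulated into a tuple. A minimal sketch of that cache contract, using a plain depthwise causal convolution as a toy stand-in (the real `dynamic_conv1d` additionally generates its kernels from `generator_input`; all names here are illustrative):

```python
import torch
import torch.nn.functional as F

def causal_conv1d_step(x, weight, cache=None, output_final_state=True):
    """Toy causal depthwise conv with a rolling input cache.

    x:      (batch, seq_len, channels) new inputs (seq_len == 1 at decode time)
    weight: (channels, conv_size) depthwise kernel
    cache:  (batch, conv_size - 1, channels) trailing inputs from earlier steps, or None
    """
    batch, _, channels = x.shape
    conv_size = weight.shape[-1]
    if cache is None:
        # Prefill: no history yet, so left-pad with zeros.
        cache = x.new_zeros(batch, conv_size - 1, channels)
    ctx = torch.cat([cache, x], dim=1)  # prepend cached history
    y = F.conv1d(ctx.transpose(1, 2), weight.unsqueeze(1),
                 groups=channels).transpose(1, 2)  # depthwise causal conv
    # Hand back the last conv_size - 1 inputs as the state for the next call.
    new_cache = ctx[:, -(conv_size - 1):] if output_final_state else None
    return y, new_cache

# Decode loop mirrors the patched call site: thread the state through `cache`.
weight = torch.randn(64, 4)
state = None
for _ in range(3):
    tok = torch.randn(2, 1, 64)  # one new token per step
    out, state = causal_conv1d_step(tok, weight, cache=state)
```

Under this contract the prefill call starts from an empty cache and `output_final_state` (driven by `use_cache` in the patched code) decides whether a state is handed back, while each one-token decode call consumes and re-emits the rolling window, which is why the fix both seeds `conv_state` from `last_state` and overwrites it with the convolution's returned state.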