t1101675 committed on
Commit
cf9bc10
·
verified ·
1 Parent(s): f62ec09

Upload folder using huggingface_hub

Browse files
__pycache__/configuration_jet_nemotron.cpython-310.pyc ADDED
Binary file (8.29 kB). View file
 
__pycache__/dconv_fwd_cache.cpython-310.pyc ADDED
Binary file (6.95 kB). View file
 
__pycache__/dconv_fwdbwd.cpython-310.pyc ADDED
Binary file (6.49 kB). View file
 
__pycache__/dconv_step.cpython-310.pyc ADDED
Binary file (4.21 kB). View file
 
__pycache__/dynamic_conv.cpython-310.pyc ADDED
Binary file (7.45 kB). View file
 
__pycache__/jet_block.cpython-310.pyc ADDED
Binary file (7.07 kB). View file
 
__pycache__/kv_cache.cpython-310.pyc ADDED
Binary file (6.23 kB). View file
 
jet_block.py CHANGED
@@ -46,7 +46,7 @@ class JetBlockConfig():
46
  head_dim: int = 256
47
  norm_eps: float = 1e-5
48
  conv_size: int = 4
49
- dconv_generator_reduction: int = None
50
  dconv_implementation: str = 'triton'
51
 
52
 
@@ -180,24 +180,21 @@ class JetBlock(nn.Module):
180
  if attention_mask is not None and q_len > 1:
181
  indices, cu_seqlens, _ = get_unpad_data(attention_mask[:, -q_len:])
182
 
183
- conv_state = None
184
-
185
  conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
186
 
187
  q = F.silu(self.q_proj(hidden_states))
188
  k = F.silu(self.k_proj(hidden_states))
189
 
190
- conv_state_v = None
191
  if last_state is not None:
192
- conv_state_v = last_state['conv_state'][-1]
193
- v, conv_state_v = self.dynamic_conv1d(
194
  x=self.v_proj(hidden_states),
195
  generator_input=hidden_states,
196
  mask=conv_mask,
197
- cache=conv_state_v,
198
  output_final_state=use_cache,
199
  )
200
- conv_state = conv_state + (conv_state_v,) if conv_state is not None else (conv_state_v,)
201
 
202
  if attention_mask is not None and q_len > 1:
203
  q = index_first_axis(rearrange(q, "b s ... -> (b s) ..."), indices).unsqueeze(0)
 
46
  head_dim: int = 256
47
  norm_eps: float = 1e-5
48
  conv_size: int = 4
49
+ dconv_generator_reduction: int = 8
50
  dconv_implementation: str = 'triton'
51
 
52
 
 
180
  if attention_mask is not None and q_len > 1:
181
  indices, cu_seqlens, _ = get_unpad_data(attention_mask[:, -q_len:])
182
 
 
 
183
  conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
184
 
185
  q = F.silu(self.q_proj(hidden_states))
186
  k = F.silu(self.k_proj(hidden_states))
187
 
188
+ conv_state = None
189
  if last_state is not None:
190
+ conv_state = last_state['conv_state']
191
+ v, conv_state = self.dynamic_conv1d(
192
  x=self.v_proj(hidden_states),
193
  generator_input=hidden_states,
194
  mask=conv_mask,
195
+ cache=conv_state,
196
  output_final_state=use_cache,
197
  )
 
198
 
199
  if attention_mask is not None and q_len > 1:
200
  q = index_first_axis(rearrange(q, "b s ... -> (b s) ..."), indices).unsqueeze(0)