seconds-0 commited on
Commit
883d17e
·
verified ·
1 Parent(s): 622dbbd

Upload modeling_nsa.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_nsa.py +2 -23
modeling_nsa.py CHANGED
@@ -9,12 +9,7 @@ from transformers.generation.utils import GenerationMixin
9
  from transformers.modeling_outputs import CausalLMOutput
10
 
11
  from .configuration_nsa import NSAConfig
12
- _HAS_NSA = False
13
- try:
14
- from .nsa.model.llama_block_nsa import LlamaBlockNSA as _VendorNSABlock
15
- _HAS_NSA = True
16
- except Exception:
17
- _VendorNSABlock = None # type: ignore
18
 
19
 
20
  class RMSNorm(nn.Module):
@@ -262,23 +257,7 @@ class NSATinyLM(nn.Module):
262
  import os as _os
263
  # Allow forcing simple fallback via env for integration tests
264
  _force_simple = _os.getenv('NSA_REMOTE_FORCE_SIMPLE', '0').lower() in ('1','true','yes')
265
- if not _force_simple and _HAS_NSA and _VendorNSABlock is not None:
266
- # Prefer vendored NSA block to match training semantics and map gate weights
267
- self.blocks = nn.ModuleList([
268
- _VendorNSABlock(
269
- dim=self.hidden_size,
270
- n_heads=self.num_attention_heads,
271
- n_kv_groups=self.n_kv_groups,
272
- d_k=self.d_k,
273
- d_v=self.d_v,
274
- l=self.l,
275
- d=self.d,
276
- l_sel=self.l_sel,
277
- n_sel=self.n_sel,
278
- w=self.w,
279
- ) for _ in range(self.num_hidden_layers)
280
- ])
281
- elif not _force_simple:
282
  # Fallback to embedded minimal NSA if vendor import failed
283
  self.blocks = nn.ModuleList([
284
  NSABlockRemote(
 
9
  from transformers.modeling_outputs import CausalLMOutput
10
 
11
  from .configuration_nsa import NSAConfig
12
+ _HAS_NSA = False # avoid nested vendor imports in HF dynamic loader
 
 
 
 
 
13
 
14
 
15
  class RMSNorm(nn.Module):
 
257
  import os as _os
258
  # Allow forcing simple fallback via env for integration tests
259
  _force_simple = _os.getenv('NSA_REMOTE_FORCE_SIMPLE', '0').lower() in ('1','true','yes')
260
+ if not _force_simple:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
261
  # Fallback to embedded minimal NSA if vendor import failed
262
  self.blocks = nn.ModuleList([
263
  NSABlockRemote(