Upload modeling_nsa.py with huggingface_hub
Browse files
- modeling_nsa.py +2 -23
modeling_nsa.py
CHANGED
|
@@ -9,12 +9,7 @@ from transformers.generation.utils import GenerationMixin
|
|
| 9 |
from transformers.modeling_outputs import CausalLMOutput
|
| 10 |
|
| 11 |
from .configuration_nsa import NSAConfig
|
| 12 |
-
_HAS_NSA = False
|
| 13 |
-
try:
|
| 14 |
-
from .nsa.model.llama_block_nsa import LlamaBlockNSA as _VendorNSABlock
|
| 15 |
-
_HAS_NSA = True
|
| 16 |
-
except Exception:
|
| 17 |
-
_VendorNSABlock = None # type: ignore
|
| 18 |
|
| 19 |
|
| 20 |
class RMSNorm(nn.Module):
|
|
@@ -262,23 +257,7 @@ class NSATinyLM(nn.Module):
|
|
| 262 |
import os as _os
|
| 263 |
# Allow forcing simple fallback via env for integration tests
|
| 264 |
_force_simple = _os.getenv('NSA_REMOTE_FORCE_SIMPLE', '0').lower() in ('1','true','yes')
|
| 265 |
-
if _HAS_NSA and not _force_simple:
|
| 266 |
-
# Prefer vendored NSA block to match training semantics and map gate weights
|
| 267 |
-
self.blocks = nn.ModuleList([
|
| 268 |
-
_VendorNSABlock(
|
| 269 |
-
dim=self.hidden_size,
|
| 270 |
-
n_heads=self.num_attention_heads,
|
| 271 |
-
n_kv_groups=self.n_kv_groups,
|
| 272 |
-
d_k=self.d_k,
|
| 273 |
-
d_v=self.d_v,
|
| 274 |
-
l=self.l,
|
| 275 |
-
d=self.d,
|
| 276 |
-
l_sel=self.l_sel,
|
| 277 |
-
n_sel=self.n_sel,
|
| 278 |
-
w=self.w,
|
| 279 |
-
) for _ in range(self.num_hidden_layers)
|
| 280 |
-
])
|
| 281 |
-
elif not _force_simple:
|
| 282 |
# Fallback to embedded minimal NSA if vendor import failed
|
| 283 |
self.blocks = nn.ModuleList([
|
| 284 |
NSABlockRemote(
|
|
|
|
| 9 |
from transformers.modeling_outputs import CausalLMOutput
|
| 10 |
|
| 11 |
from .configuration_nsa import NSAConfig
|
| 12 |
+
_HAS_NSA = False # avoid nested vendor imports in HF dynamic loader
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
|
| 15 |
class RMSNorm(nn.Module):
|
|
|
|
| 257 |
import os as _os
|
| 258 |
# Allow forcing simple fallback via env for integration tests
|
| 259 |
_force_simple = _os.getenv('NSA_REMOTE_FORCE_SIMPLE', '0').lower() in ('1','true','yes')
|
| 260 |
+
if not _force_simple:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 261 |
# Fallback to embedded minimal NSA if vendor import failed
|
| 262 |
self.blocks = nn.ModuleList([
|
| 263 |
NSABlockRemote(
|