Training in progress - step 1000
Browse files- config.json +1 -1
- model.safetensors +2 -2
- projectors.py +9 -0
config.json
CHANGED
|
@@ -179,7 +179,7 @@
|
|
| 179 |
"projector_hidden_dim": null,
|
| 180 |
"projector_init_std": 0.02,
|
| 181 |
"projector_num_layers": 2,
|
| 182 |
-
"projector_pool_stride":
|
| 183 |
"projector_type": "mosa",
|
| 184 |
"qformer_hidden_size": null,
|
| 185 |
"qformer_intermediate_size": null,
|
|
|
|
| 179 |
"projector_hidden_dim": null,
|
| 180 |
"projector_init_std": 0.02,
|
| 181 |
"projector_num_layers": 2,
|
| 182 |
+
"projector_pool_stride": 6,
|
| 183 |
"projector_type": "mosa",
|
| 184 |
"qformer_hidden_size": null,
|
| 185 |
"qformer_intermediate_size": null,
|
model.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2427b0eb3d43c1582d4e52125d39843fa62119c65b9d4b79f410f0de1c4386da
|
| 3 |
+
size 320134160
|
projectors.py
CHANGED
|
@@ -104,6 +104,11 @@ class MOSAProjector(nn.Module):
|
|
| 104 |
self.num_experts = getattr(config, "num_experts", None) or 4 # MOSA-Base uses 4
|
| 105 |
adapter_hidden = getattr(config, "adapter_hidden_dim", None) or 4096
|
| 106 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
# Frame stacking: concat k adjacent frames then project
|
| 108 |
in_dim = self.encoder_dim * self.k
|
| 109 |
|
|
@@ -126,6 +131,10 @@ class MOSAProjector(nn.Module):
|
|
| 126 |
# x: (B, S, encoder_dim)
|
| 127 |
batch_size, seq_len, dim = x.shape
|
| 128 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
# --- 1. Router Branch ---
|
| 130 |
# Mean pool encoder outputs for routing decisions
|
| 131 |
x_pooled = x.reshape(batch_size, -1, self.k, self.encoder_dim).mean(dim=2) # (B, S//k, D)
|
|
|
|
| 104 |
self.num_experts = getattr(config, "num_experts", None) or 4 # MOSA-Base uses 4
|
| 105 |
adapter_hidden = getattr(config, "adapter_hidden_dim", None) or 4096
|
| 106 |
|
| 107 |
+
# Optional pre-norm before projection
|
| 108 |
+
self.use_pre_norm = getattr(config, "projector_pre_norm", False)
|
| 109 |
+
if self.use_pre_norm:
|
| 110 |
+
self.pre_norm = LlamaRMSNorm(self.encoder_dim, eps=1e-8)
|
| 111 |
+
|
| 112 |
# Frame stacking: concat k adjacent frames then project
|
| 113 |
in_dim = self.encoder_dim * self.k
|
| 114 |
|
|
|
|
| 131 |
# x: (B, S, encoder_dim)
|
| 132 |
batch_size, seq_len, dim = x.shape
|
| 133 |
|
| 134 |
+
# Apply pre-norm if enabled
|
| 135 |
+
if self.use_pre_norm:
|
| 136 |
+
x = self.pre_norm(x)
|
| 137 |
+
|
| 138 |
# --- 1. Router Branch ---
|
| 139 |
# Mean pool encoder outputs for routing decisions
|
| 140 |
x_pooled = x.reshape(batch_size, -1, self.k, self.encoder_dim).mean(dim=2) # (B, S//k, D)
|