mazesmazes committed · Commit 8c76fb5 · verified · 1 parent: 3a57705

Training in progress - step 1000

Files changed (3):
  1. config.json +1 -1
  2. model.safetensors +2 -2
  3. projectors.py +9 -0
config.json CHANGED
@@ -179,7 +179,7 @@
   "projector_hidden_dim": null,
   "projector_init_std": 0.02,
   "projector_num_layers": 2,
-  "projector_pool_stride": 2,
+  "projector_pool_stride": 6,
   "projector_type": "mosa",
   "qformer_hidden_size": null,
   "qformer_intermediate_size": null,
model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:7c92dfb13651aa8309de3cde22e20237407f7593163b9c3d668c5be22a99be7e
-size 152361992
+oid sha256:2427b0eb3d43c1582d4e52125d39843fa62119c65b9d4b79f410f0de1c4386da
+size 320134160
projectors.py CHANGED
@@ -104,6 +104,11 @@ class MOSAProjector(nn.Module):
         self.num_experts = getattr(config, "num_experts", None) or 4  # MOSA-Base uses 4
         adapter_hidden = getattr(config, "adapter_hidden_dim", None) or 4096
 
+        # Optional pre-norm before projection
+        self.use_pre_norm = getattr(config, "projector_pre_norm", False)
+        if self.use_pre_norm:
+            self.pre_norm = LlamaRMSNorm(self.encoder_dim, eps=1e-8)
+
         # Frame stacking: concat k adjacent frames then project
         in_dim = self.encoder_dim * self.k
 
@@ -126,6 +131,10 @@
         # x: (B, S, encoder_dim)
         batch_size, seq_len, dim = x.shape
 
+        # Apply pre-norm if enabled
+        if self.use_pre_norm:
+            x = self.pre_norm(x)
+
         # --- 1. Router Branch ---
         # Mean pool encoder outputs for routing decisions
         x_pooled = x.reshape(batch_size, -1, self.k, self.encoder_dim).mean(dim=2)  # (B, S//k, D)
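
The new code path gates a `LlamaRMSNorm` over the encoder features before frame stacking and routing. For reference, a self-contained sketch of the same RMS normalization (illustrative, not this repo's code; `transformers`' `LlamaRMSNorm` additionally upcasts to float32 internally):

```python
# Minimal RMSNorm sketch matching the semantics of transformers' LlamaRMSNorm
# (class and variable names here are illustrative, not from this repo).
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-8):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))  # learnable per-channel gain
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Scale by the reciprocal root-mean-square over the feature dim,
        # then apply the learned gain; no mean subtraction, no bias.
        rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        return x * rms * self.weight

x = torch.randn(2, 10, 768)  # (B, S, encoder_dim)
norm = RMSNorm(768)
print(norm(x).shape)         # torch.Size([2, 10, 768])
```

Unlike LayerNorm, RMSNorm skips mean subtraction and the bias term, so the optional pre-norm adds only `encoder_dim` parameters per instance. The diff's eps of 1e-8 is tighter than the 1e-6 commonly used in Llama checkpoints.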