Add model code
Browse files
modeling_mamba_hypernetwork.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from mamba_ssm import Mamba
|
| 5 |
+
|
| 6 |
+
class MambaHypernetwork(nn.Module):
|
| 7 |
+
def __init__(self, config):
|
| 8 |
+
super().__init__()
|
| 9 |
+
|
| 10 |
+
vocab_size = config["vocab_size"]
|
| 11 |
+
hidden_dim = config["hidden_dim"]
|
| 12 |
+
state_dim = config["state_dim"]
|
| 13 |
+
expand = config["expand"]
|
| 14 |
+
num_llm_layers = config["num_llm_layers"]
|
| 15 |
+
lora_rank = config["lora_rank"]
|
| 16 |
+
q_proj_dim = config["q_proj_dim"]
|
| 17 |
+
v_proj_dim = config["v_proj_dim"]
|
| 18 |
+
|
| 19 |
+
self.hidden_dim = hidden_dim
|
| 20 |
+
self.num_llm_layers = num_llm_layers
|
| 21 |
+
self.lora_rank = lora_rank
|
| 22 |
+
self.q_proj_dim = q_proj_dim
|
| 23 |
+
self.v_proj_dim = v_proj_dim
|
| 24 |
+
|
| 25 |
+
self.embedding = nn.Embedding(vocab_size, hidden_dim)
|
| 26 |
+
self.mamba = Mamba(d_model=hidden_dim, d_state=state_dim, d_conv=4, expand=expand)
|
| 27 |
+
|
| 28 |
+
self.persona_proj = nn.Linear(2 * hidden_dim, hidden_dim)
|
| 29 |
+
self.history_proj = nn.Linear(2 * hidden_dim, hidden_dim)
|
| 30 |
+
self.combine = nn.Sequential(
|
| 31 |
+
nn.Linear(2 * hidden_dim, hidden_dim),
|
| 32 |
+
nn.LayerNorm(hidden_dim),
|
| 33 |
+
nn.ReLU(),
|
| 34 |
+
nn.Linear(hidden_dim, hidden_dim),
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
self.delta_heads = nn.ModuleList([
|
| 38 |
+
nn.ModuleDict({
|
| 39 |
+
"q_proj_A": nn.Linear(hidden_dim, q_proj_dim * lora_rank),
|
| 40 |
+
"q_proj_B": nn.Linear(hidden_dim, lora_rank * q_proj_dim),
|
| 41 |
+
"v_proj_A": nn.Linear(hidden_dim, lora_rank * q_proj_dim),
|
| 42 |
+
"v_proj_B": nn.Linear(hidden_dim, v_proj_dim * lora_rank),
|
| 43 |
+
})
|
| 44 |
+
for _ in range(num_llm_layers)
|
| 45 |
+
])
|
| 46 |
+
|
| 47 |
+
def encode_text(self, input_ids, attention_mask):
|
| 48 |
+
emb = self.embedding(input_ids)
|
| 49 |
+
mamba_out = self.mamba(emb)
|
| 50 |
+
mask_expanded = attention_mask.unsqueeze(-1).float()
|
| 51 |
+
masked_out = mamba_out * mask_expanded
|
| 52 |
+
sum_out = masked_out.sum(dim=1)
|
| 53 |
+
count = mask_expanded.sum(dim=1)
|
| 54 |
+
mean_pooled = sum_out / (count + 1e-8)
|
| 55 |
+
masked_out_for_max = masked_out.clone()
|
| 56 |
+
masked_out_for_max[attention_mask == 0] = float('-inf')
|
| 57 |
+
max_pooled = masked_out_for_max.max(dim=1).values
|
| 58 |
+
pooled = torch.cat([mean_pooled, max_pooled], dim=-1)
|
| 59 |
+
return pooled
|
| 60 |
+
|
| 61 |
+
def forward(self, persona_ids, persona_mask, history_ids, history_mask):
|
| 62 |
+
persona_feat = self.encode_text(persona_ids, persona_mask)
|
| 63 |
+
persona_feat = self.persona_proj(persona_feat)
|
| 64 |
+
history_feat = self.encode_text(history_ids, history_mask)
|
| 65 |
+
history_feat = self.history_proj(history_feat)
|
| 66 |
+
combined = torch.cat([persona_feat, history_feat], dim=-1)
|
| 67 |
+
combined = self.combine(combined)
|
| 68 |
+
all_deltas = []
|
| 69 |
+
for head in self.delta_heads:
|
| 70 |
+
layer_deltas = {name: head[name](combined) for name in head}
|
| 71 |
+
all_deltas.append(layer_deltas)
|
| 72 |
+
return all_deltas
|