phammminhhieu committed on
Commit
8faef29
·
verified ·
1 Parent(s): f0dfd5e

Add model code

Browse files
Files changed (1) hide show
  1. modeling_mamba_hypernetwork.py +72 -0
modeling_mamba_hypernetwork.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ import torch.nn as nn
4
+ from mamba_ssm import Mamba
5
+
6
class MambaHypernetwork(nn.Module):
    """Hypernetwork that turns persona and history token ids into LoRA deltas.

    Both inputs are embedded with a shared embedding table, run through a
    shared Mamba block, and pooled (mean + max) into one vector per example.
    The two pooled vectors are projected, concatenated and fused by a small
    MLP. One group of four linear heads per target LLM layer then maps the
    fused vector to flat ``q_proj`` / ``v_proj`` LoRA ``A`` / ``B`` outputs.
    """

    def __init__(self, config):
        """Build the encoder, fusion MLP and per-layer delta heads.

        Args:
            config: Mapping read with ``config[key]``. Required keys:
                ``vocab_size``, ``hidden_dim``, ``state_dim``, ``expand``,
                ``num_llm_layers``, ``lora_rank``, ``q_proj_dim``,
                ``v_proj_dim``.

        Raises:
            KeyError: If a required key is missing from ``config``.
        """
        super().__init__()

        vocab_size = config["vocab_size"]
        hidden_dim = config["hidden_dim"]
        state_dim = config["state_dim"]
        expand = config["expand"]
        num_llm_layers = config["num_llm_layers"]
        lora_rank = config["lora_rank"]
        q_proj_dim = config["q_proj_dim"]
        v_proj_dim = config["v_proj_dim"]

        self.hidden_dim = hidden_dim
        self.num_llm_layers = num_llm_layers
        self.lora_rank = lora_rank
        self.q_proj_dim = q_proj_dim
        self.v_proj_dim = v_proj_dim

        # Text encoder shared by the persona and the history inputs.
        self.embedding = nn.Embedding(vocab_size, hidden_dim)
        self.mamba = Mamba(d_model=hidden_dim, d_state=state_dim, d_conv=4, expand=expand)

        # encode_text returns mean- and max-pooled features concatenated,
        # hence the 2 * hidden_dim input width of both projections.
        self.persona_proj = nn.Linear(2 * hidden_dim, hidden_dim)
        self.history_proj = nn.Linear(2 * hidden_dim, hidden_dim)
        self.combine = nn.Sequential(
            nn.Linear(2 * hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

        # One group of heads per LLM layer; each head emits a flat vector.
        # NOTE(review): "v_proj_A" is sized with q_proj_dim rather than
        # v_proj_dim -- presumably because the LoRA A matrix acts on the LLM
        # input width; confirm against the code that consumes these deltas.
        self.delta_heads = nn.ModuleList([
            nn.ModuleDict({
                "q_proj_A": nn.Linear(hidden_dim, q_proj_dim * lora_rank),
                "q_proj_B": nn.Linear(hidden_dim, lora_rank * q_proj_dim),
                "v_proj_A": nn.Linear(hidden_dim, lora_rank * q_proj_dim),
                "v_proj_B": nn.Linear(hidden_dim, v_proj_dim * lora_rank),
            })
            for _ in range(num_llm_layers)
        ])

    def encode_text(self, input_ids, attention_mask):
        """Encode token ids into one pooled feature vector per example.

        Args:
            input_ids: Integer token ids; assumed ``(batch, seq_len)``.
            attention_mask: Same leading shape as ``input_ids``; entries equal
                to 0 mark positions excluded from pooling.

        Returns:
            Tensor whose last dimension is ``2 * hidden_dim``: the masked mean
            over dim 1 followed by the masked max over dim 1. For a row with no
            valid token both halves are zero.
        """
        emb = self.embedding(input_ids)
        # The Mamba block sees padded positions too; the mask is only applied
        # when pooling its outputs.
        mamba_out = self.mamba(emb)

        mask_expanded = attention_mask.unsqueeze(-1).float()
        masked_out = mamba_out * mask_expanded
        count = mask_expanded.sum(dim=1)
        mean_pooled = masked_out.sum(dim=1) / (count + 1e-8)

        # Padded positions must never win the max, so push them to -inf.
        is_padding = (attention_mask == 0).unsqueeze(-1)
        max_pooled = masked_out.masked_fill(is_padding, float("-inf")).max(dim=1).values
        # A fully padded row would otherwise yield -inf in every channel and
        # turn into NaN in the downstream Linear/LayerNorm layers; use zeros
        # there, matching what the mean pooling already produces.
        max_pooled = torch.where(count > 0, max_pooled, torch.zeros_like(max_pooled))

        return torch.cat([mean_pooled, max_pooled], dim=-1)

    def forward(self, persona_ids, persona_mask, history_ids, history_mask):
        """Generate LoRA delta vectors for every target LLM layer.

        Args:
            persona_ids: Token ids of the persona text.
            persona_mask: Attention mask for ``persona_ids``.
            history_ids: Token ids of the history text.
            history_mask: Attention mask for ``history_ids``.

        Returns:
            A list with ``num_llm_layers`` entries. Each entry is a dict with
            keys ``"q_proj_A"``, ``"q_proj_B"``, ``"v_proj_A"`` and
            ``"v_proj_B"`` mapping to the raw (flat, un-reshaped) output of
            the corresponding linear head.
        """
        persona_feat = self.persona_proj(self.encode_text(persona_ids, persona_mask))
        history_feat = self.history_proj(self.encode_text(history_ids, history_mask))

        combined = self.combine(torch.cat([persona_feat, history_feat], dim=-1))

        return [
            {name: proj(combined) for name, proj in head.items()}
            for head in self.delta_heads
        ]