File size: 2,925 Bytes
8faef29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73

import torch
import torch.nn as nn
from mamba_ssm import Mamba

class MambaHypernetwork(nn.Module):
    """Hypernetwork mapping persona and history token sequences to per-layer deltas.

    A shared ``nn.Embedding`` + ``Mamba`` encoder pools each text into a
    ``2 * hidden_dim`` vector (masked mean concatenated with masked max).
    The persona and history vectors are projected, concatenated and fused by a
    small MLP. One ``nn.ModuleDict`` of linear heads per LLM layer then emits
    flat tensors whose names and sizes suggest LoRA A/B factors for
    ``q_proj``/``v_proj``. Reshaping them into matrices is left to the caller.

    Args:
        config: Mapping with keys ``vocab_size``, ``hidden_dim``,
            ``state_dim``, ``expand``, ``num_llm_layers``, ``lora_rank``,
            ``q_proj_dim`` and ``v_proj_dim``. The optional key ``d_conv``
            sets the Mamba convolution width (default 4).
    """

    def __init__(self, config):
        super().__init__()

        vocab_size = config["vocab_size"]
        hidden_dim = config["hidden_dim"]
        state_dim = config["state_dim"]
        expand = config["expand"]
        num_llm_layers = config["num_llm_layers"]
        lora_rank = config["lora_rank"]
        q_proj_dim = config["q_proj_dim"]
        v_proj_dim = config["v_proj_dim"]
        # Optional; the default reproduces the previously hard-coded width.
        d_conv = config.get("d_conv", 4)

        self.hidden_dim = hidden_dim
        self.num_llm_layers = num_llm_layers
        self.lora_rank = lora_rank
        self.q_proj_dim = q_proj_dim
        self.v_proj_dim = v_proj_dim

        self.embedding = nn.Embedding(vocab_size, hidden_dim)
        self.mamba = Mamba(d_model=hidden_dim, d_state=state_dim, d_conv=d_conv, expand=expand)

        # encode_text returns 2 * hidden_dim features (mean half + max half).
        self.persona_proj = nn.Linear(2 * hidden_dim, hidden_dim)
        self.history_proj = nn.Linear(2 * hidden_dim, hidden_dim)
        # Fuses the concatenated persona and history projections.
        self.combine = nn.Sequential(
            nn.Linear(2 * hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

        # NOTE(review): v_proj_A is sized with q_proj_dim rather than v_proj_dim.
        # This is consistent if q_proj_dim is the LLM hidden size (the input
        # width of v_proj, e.g. under grouped-query attention where v_proj's
        # output v_proj_dim is smaller) -- confirm against the code that
        # reshapes these outputs.
        self.delta_heads = nn.ModuleList([
            nn.ModuleDict({
                "q_proj_A": nn.Linear(hidden_dim, q_proj_dim * lora_rank),
                "q_proj_B": nn.Linear(hidden_dim, lora_rank * q_proj_dim),
                "v_proj_A": nn.Linear(hidden_dim, lora_rank * q_proj_dim),
                "v_proj_B": nn.Linear(hidden_dim, v_proj_dim * lora_rank),
            })
            for _ in range(num_llm_layers)
        ])

    def encode_text(self, input_ids, attention_mask):
        """Encode token ids with the Mamba backbone and pool over valid tokens.

        Args:
            input_ids: Integer token ids, assumed ``(batch, seq_len)``.
            attention_mask: Same shape as ``input_ids``; zero marks padding.
                Non-zero values weight the mean pooling.

        Returns:
            ``(batch, 2 * hidden_dim)`` tensor: masked mean pooling
            concatenated with masked max pooling over ``dim=1``. Rows whose
            mask is all zero pool to zeros rather than ``-inf``/NaN.
        """
        hidden = self.mamba(self.embedding(input_ids))
        valid = attention_mask.unsqueeze(-1) != 0  # (..., seq_len, 1) bool
        # Follow the encoder's dtype so half/bf16 activations are not promoted.
        weights = attention_mask.unsqueeze(-1).to(hidden.dtype)
        masked = hidden * weights

        count = weights.sum(dim=1)
        # Empty rows have a zero sum; dividing by 1 keeps them at 0 instead of
        # 0/0. (A tiny epsilon would underflow to 0 in half precision.)
        mean_pooled = masked.sum(dim=1) / count.masked_fill(count == 0, 1.0)

        max_pooled = masked.masked_fill(~valid, float("-inf")).max(dim=1).values
        # An all-padding row would otherwise max-pool to -inf and poison the
        # downstream Linear/LayerNorm layers with inf/NaN.
        has_tokens = valid.any(dim=1)
        max_pooled = max_pooled.masked_fill(~has_tokens, 0.0)

        return torch.cat([mean_pooled, max_pooled], dim=-1)

    def forward(self, persona_ids, persona_mask, history_ids, history_mask):
        """Predict per-layer delta tensors from persona and history text.

        Args:
            persona_ids, persona_mask: Persona tokens and padding mask, as
                accepted by ``encode_text``.
            history_ids, history_mask: History tokens and padding mask, as
                accepted by ``encode_text``.

        Returns:
            A list with one dict per module in ``self.delta_heads``. Each dict
            maps a head name (``q_proj_A``, ``q_proj_B``, ``v_proj_A``,
            ``v_proj_B``) to that head's flat output, presumably
            ``(batch, out_features)`` for batched inputs.
        """
        persona_feat = self.persona_proj(self.encode_text(persona_ids, persona_mask))
        history_feat = self.history_proj(self.encode_text(history_ids, history_mask))
        combined = self.combine(torch.cat([persona_feat, history_feat], dim=-1))
        return [
            {name: head[name](combined) for name in head}
            for head in self.delta_heads
        ]