import torch
import torch.nn as nn
from mamba_ssm import Mamba


class MambaHypernetwork(nn.Module):
    """Predict per-layer LoRA factor tensors from persona and history text.

    Persona and history token sequences are each run through a shared
    embedding + Mamba encoder. Each is pooled (masked mean and masked max),
    projected back to ``hidden_dim``, and the two are fused by a small MLP.
    The fused vector feeds one set of linear "delta heads" per target LLM
    layer.

    Config keys read: ``vocab_size``, ``hidden_dim``, ``state_dim``,
    ``expand``, ``num_llm_layers``, ``lora_rank``, ``q_proj_dim``,
    ``v_proj_dim``, plus the optional ``d_conv`` (Mamba local conv width,
    default 4).
    """

    def __init__(self, config):
        super().__init__()
        vocab_size = config["vocab_size"]
        hidden_dim = config["hidden_dim"]
        state_dim = config["state_dim"]
        expand = config["expand"]
        num_llm_layers = config["num_llm_layers"]
        lora_rank = config["lora_rank"]
        q_proj_dim = config["q_proj_dim"]
        v_proj_dim = config["v_proj_dim"]
        # Optional; defaults to the previously hard-coded value.
        d_conv = config.get("d_conv", 4)

        self.hidden_dim = hidden_dim
        self.num_llm_layers = num_llm_layers
        self.lora_rank = lora_rank
        self.q_proj_dim = q_proj_dim
        self.v_proj_dim = v_proj_dim

        self.embedding = nn.Embedding(vocab_size, hidden_dim)
        self.mamba = Mamba(
            d_model=hidden_dim, d_state=state_dim, d_conv=d_conv, expand=expand
        )

        # encode_text returns [mean ; max] pooling, i.e. 2 * hidden_dim features.
        self.persona_proj = nn.Linear(2 * hidden_dim, hidden_dim)
        self.history_proj = nn.Linear(2 * hidden_dim, hidden_dim)

        # Fuses the concatenated persona and history features.
        self.combine = nn.Sequential(
            nn.Linear(2 * hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

        # One ModuleDict of heads per target LLM layer. Each head outputs a
        # flat vector that the caller presumably reshapes into a LoRA factor.
        # NOTE(review): both "*_A" heads size their output with q_proj_dim.
        # That is only right if q_proj_dim equals the LLM hidden size (the
        # input dim of v_proj). Confirm against the consumer before changing:
        # the sizes are part of the checkpoint / downstream contract.
        self.delta_heads = nn.ModuleList([
            nn.ModuleDict({
                "q_proj_A": nn.Linear(hidden_dim, q_proj_dim * lora_rank),
                "q_proj_B": nn.Linear(hidden_dim, lora_rank * q_proj_dim),
                "v_proj_A": nn.Linear(hidden_dim, lora_rank * q_proj_dim),
                "v_proj_B": nn.Linear(hidden_dim, v_proj_dim * lora_rank),
            })
            for _ in range(num_llm_layers)
        ])

    def encode_text(self, input_ids, attention_mask):
        """Encode a batch of token sequences into pooled features.

        Args:
            input_ids: Token ids; assumed ``(batch, seq_len)`` -- TODO confirm.
            attention_mask: Same leading shape as ``input_ids``. Nonzero marks
                real tokens and 0 marks padding. Values are used as pooling
                weights for the mean.

        Returns:
            Tensor ``(batch, 2 * hidden_dim)``: masked mean pooling concatenated
            with masked max pooling. A row with no real tokens pools to zeros
            in both halves, rather than ``-inf`` in the max half.
        """
        emb = self.embedding(input_ids)
        # NOTE(review): Mamba scans causally, so with left padding the pad
        # embeddings leak into the states of later real tokens. Confirm that
        # inputs are right-padded.
        mamba_out = self.mamba(emb)

        mask = attention_mask.unsqueeze(-1).float()
        masked_out = mamba_out * mask
        count = mask.sum(dim=1)  # real tokens per row, shape (batch, 1)

        # clamp keeps empty rows at 0 / 1 = 0 without perturbing real rows.
        mean_pooled = masked_out.sum(dim=1) / count.clamp(min=1.0)

        # Exclude padding from the max. Out-of-place, so no clone is needed.
        max_pooled = masked_out.masked_fill(mask == 0, float("-inf")).max(dim=1).values
        # An all-padding row would otherwise stay at -inf and turn into
        # -inf/NaN in the projection layers, poisoning the whole batch.
        max_pooled = max_pooled.masked_fill(count == 0, 0.0)

        return torch.cat([mean_pooled, max_pooled], dim=-1)

    def forward(self, persona_ids, persona_mask, history_ids, history_mask):
        """Produce LoRA delta predictions for every target LLM layer.

        Args:
            persona_ids, persona_mask: Persona tokens and mask, as accepted by
                ``encode_text``.
            history_ids, history_mask: Dialogue-history tokens and mask, as
                accepted by ``encode_text``.

        Returns:
            List of length ``num_llm_layers``. Each element is a dict mapping
            "q_proj_A", "q_proj_B", "v_proj_A" and "v_proj_B" to that head's
            flat output for the batch.
        """
        persona_feat = self.persona_proj(self.encode_text(persona_ids, persona_mask))
        history_feat = self.history_proj(self.encode_text(history_ids, history_mask))

        combined = self.combine(torch.cat([persona_feat, history_feat], dim=-1))

        return [
            {name: head[name](combined) for name in head}
            for head in self.delta_heads
        ]