import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast

|
class GAD2Config(PretrainedConfig):
    model_type = "gad_v2"

    def __init__(self, vocab_size=50257, d_model=768, nhead=12, num_layers=12,
                 num_intents=3, memory_size=512, seq_len=1024, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.nhead = nhead
        self.num_layers = num_layers
        self.num_intents = num_intents
        self.memory_size = memory_size
        self.seq_len = seq_len

        # Aliases under the standard attribute names that Hugging Face
        # utilities expect to find on a config.
        self.num_hidden_layers = num_layers
        self.hidden_size = d_model
        self.num_attention_heads = nhead

|
class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Root-mean-square normalization over the last dimension, then a learned scale.
        norm_x = x.pow(2).mean(-1, keepdim=True)
        return (x * torch.rsqrt(norm_x + self.eps)) * self.weight

|
class MultiIntentEvolver(nn.Module):
    """Runs several parallel GRUs over the sequence and fuses their outputs."""

    def __init__(self, config):
        super().__init__()
        self.grus = nn.ModuleList(
            [nn.GRU(config.d_model, config.d_model, batch_first=True) for _ in range(config.num_intents)]
        )
        self.fusion = nn.Linear(config.d_model * config.num_intents, config.d_model)

    def forward(self, x):
        # One output sequence per "intent" GRU; concatenate and project back to d_model.
        intents = [gru(x)[0] for gru in self.grus]
        return self.fusion(torch.cat(intents, dim=-1))

|
class AdaptiveMemory(nn.Module):
    """Cross-attention from the token sequence into a learned memory bank."""

    def __init__(self, config):
        super().__init__()
        self.memory = nn.Parameter(torch.randn(config.memory_size, config.d_model))
        self.q_proj = nn.Linear(config.d_model, config.d_model)
        self.kv_proj = nn.Linear(config.d_model, config.d_model * 2)
        self.out_proj = nn.Linear(config.d_model, config.d_model)

    def forward(self, x):
        b, t, d = x.size()
        q = self.q_proj(x)                              # (b, t, d)
        kv = self.kv_proj(self.memory).view(-1, 2, d)   # (memory_size, 2, d)
        k = kv[:, 0].unsqueeze(0).expand(b, -1, -1)     # (b, memory_size, d)
        v = kv[:, 1].unsqueeze(0).expand(b, -1, -1)     # (b, memory_size, d)
        # Single-head attention: add a head dimension of 1 for scaled_dot_product_attention.
        attn_out = F.scaled_dot_product_attention(q.unsqueeze(1), k.unsqueeze(1), v.unsqueeze(1))
        return self.out_proj(attn_out.squeeze(1))

|
class GAD2_Block(nn.Module):
    """Pre-norm transformer block: causal self-attention followed by a SwiGLU MLP."""

    def __init__(self, config):
        super().__init__()
        self.nhead, self.head_dim = config.nhead, config.d_model // config.nhead
        self.norm1, self.norm2 = RMSNorm(config.d_model), RMSNorm(config.d_model)
        self.qkv = nn.Linear(config.d_model, config.d_model * 3, bias=False)
        self.proj = nn.Linear(config.d_model, config.d_model)
        # SwiGLU feed-forward with the usual 2/3 * 4d inner width.
        inner_dim = int(config.d_model * 4 * (2 / 3))
        self.w1 = nn.Linear(config.d_model, inner_dim, bias=False)
        self.w2 = nn.Linear(config.d_model, inner_dim, bias=False)
        self.w3 = nn.Linear(inner_dim, config.d_model, bias=False)

    def forward(self, x):
        h = self.norm1(x)
        b, t, d = h.size()
        # (b, t, 3, nhead, head_dim) -> (3, b, nhead, t, head_dim)
        qkv = self.qkv(h).view(b, t, 3, self.nhead, self.head_dim).permute(2, 0, 3, 1, 4)
        attn_out = F.scaled_dot_product_attention(qkv[0], qkv[1], qkv[2], is_causal=True)
        x = x + self.proj(attn_out.transpose(1, 2).reshape(b, t, d))
        h = self.norm2(x)
        return x + self.w3(F.silu(self.w1(h)) * self.w2(h))

|
class GAD2ForAgenticModeling(PreTrainedModel):
    config_class = GAD2Config

    def __init__(self, config):
        super().__init__(config)
        self.token_emb = nn.Embedding(config.vocab_size, config.d_model)
        self.intent_evolver = MultiIntentEvolver(config)
        self.memory_module = AdaptiveMemory(config)
        self.blocks = nn.ModuleList([GAD2_Block(config) for _ in range(config.num_layers)])
        self.final_norm = RMSNorm(config.d_model)
        self.output_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.post_init()

    def forward(self, input_ids, labels=None, attention_mask=None, **kwargs):
        # Note: attention_mask is accepted for API compatibility but not applied;
        # attention inside the blocks is purely causal.
        x = self.token_emb(input_ids)
        x = x + self.intent_evolver(x) + self.memory_module(x)
        for block in self.blocks:
            x = block(x)
        logits = self.output_head(self.final_norm(x))
        loss = None
        if labels is not None:
            # Shift so that position t predicts token t+1 (standard causal-LM loss).
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        return CausalLMOutputWithPast(loss=loss, logits=logits)

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        # No KV cache is implemented, so the full sequence is re-run at every generation step.
        return {"input_ids": input_ids, **kwargs}