GAD-2 / modeling_gad2.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
class GAD2Config(PretrainedConfig):
    model_type = "gad_v2"

    def __init__(self, vocab_size=50257, d_model=768, nhead=12, num_layers=12,
                 num_intents=3, memory_size=512, seq_len=1024, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.nhead = nhead
        self.num_layers = num_layers
        self.num_intents = num_intents
        self.memory_size = memory_size
        self.seq_len = seq_len
        # --- Bug fix: also expose the standard attribute names that transformers looks for ---
        self.num_hidden_layers = num_layers
        self.hidden_size = d_model
        self.num_attention_heads = nhead
class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # x / sqrt(mean(x^2) + eps), scaled by a learned per-dimension weight.
        norm_x = x.pow(2).mean(-1, keepdim=True)
        return (x * torch.rsqrt(norm_x + self.eps)) * self.weight
class MultiIntentEvolver(nn.Module):
    def __init__(self, config):
        super().__init__()
        # One GRU per "intent" track; their outputs are concatenated and fused back to d_model.
        self.grus = nn.ModuleList([nn.GRU(config.d_model, config.d_model, batch_first=True) for _ in range(config.num_intents)])
        self.fusion = nn.Linear(config.d_model * config.num_intents, config.d_model)

    def forward(self, x):
        intents = [gru(x)[0] for gru in self.grus]
        return self.fusion(torch.cat(intents, dim=-1))
class AdaptiveMemory(nn.Module):
    # Single-head cross-attention from the token sequence into a learned memory bank.
    def __init__(self, config):
        super().__init__()
        self.memory = nn.Parameter(torch.randn(config.memory_size, config.d_model))
        self.q_proj = nn.Linear(config.d_model, config.d_model)
        self.kv_proj = nn.Linear(config.d_model, config.d_model * 2)
        self.out_proj = nn.Linear(config.d_model, config.d_model)

    def forward(self, x):
        b, t, d = x.size()
        q = self.q_proj(x)                             # (b, t, d)
        kv = self.kv_proj(self.memory).view(-1, 2, d)  # (memory_size, 2, d)
        k, v = kv[:, 0].unsqueeze(0).expand(b, -1, -1), kv[:, 1].unsqueeze(0).expand(b, -1, -1)
        # Add a singleton head dimension for scaled_dot_product_attention.
        attn_out = F.scaled_dot_product_attention(q.unsqueeze(1), k.unsqueeze(1), v.unsqueeze(1))
        return self.out_proj(attn_out.squeeze(1))
class GAD2_Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.nhead, self.head_dim = config.nhead, config.d_model // config.nhead
        self.norm1, self.norm2 = RMSNorm(config.d_model), RMSNorm(config.d_model)
        self.qkv, self.proj = nn.Linear(config.d_model, config.d_model * 3, bias=False), nn.Linear(config.d_model, config.d_model)
        # SwiGLU feed-forward with the usual 2/3 * 4d inner width.
        inner_dim = int(config.d_model * 4 * (2 / 3))
        self.w1, self.w2, self.w3 = nn.Linear(config.d_model, inner_dim, bias=False), nn.Linear(config.d_model, inner_dim, bias=False), nn.Linear(inner_dim, config.d_model, bias=False)

    def forward(self, x):
        # Pre-norm causal self-attention.
        h = self.norm1(x)
        b, t, d = h.size()
        qkv = self.qkv(h).view(b, t, 3, self.nhead, self.head_dim).permute(2, 0, 3, 1, 4)
        attn_out = F.scaled_dot_product_attention(qkv[0], qkv[1], qkv[2], is_causal=True)
        x = x + self.proj(attn_out.transpose(1, 2).reshape(b, t, d))
        # Pre-norm SwiGLU feed-forward.
        h = self.norm2(x)
        return x + self.w3(F.silu(self.w1(h)) * self.w2(h))
class GAD2ForAgenticModeling(PreTrainedModel):
    config_class = GAD2Config

    def __init__(self, config):
        super().__init__(config)
        self.token_emb = nn.Embedding(config.vocab_size, config.d_model)
        self.intent_evolver, self.memory_module = MultiIntentEvolver(config), AdaptiveMemory(config)
        self.blocks = nn.ModuleList([GAD2_Block(config) for _ in range(config.num_layers)])
        self.final_norm, self.output_head = RMSNorm(config.d_model), nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.post_init()

    def forward(self, input_ids, labels=None, attention_mask=None, **kwargs):
        # Note: attention_mask is accepted for API compatibility but not applied inside the blocks.
        x = self.token_emb(input_ids)
        x = x + self.intent_evolver(x) + self.memory_module(x)
        for block in self.blocks:
            x = block(x)
        logits = self.output_head(self.final_norm(x))
        loss = None
        if labels is not None:
            # Shift so that position i predicts token i + 1 (standard causal-LM objective).
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        return CausalLMOutputWithPast(loss=loss, logits=logits)

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        # Important: pass kwargs through so that generate() does not get confused.
        return {"input_ids": input_ids, **kwargs}