Update modeling_gad2.py
modeling_gad2.py  +12 -32
@@ -16,6 +16,11 @@ class GAD2Config(PretrainedConfig):
         self.num_intents = num_intents
         self.memory_size = memory_size
         self.seq_len = seq_len
+
+        # --- Bug fix: add the standard attribute names the library looks for ---
+        self.num_hidden_layers = num_layers
+        self.hidden_size = d_model
+        self.num_attention_heads = nhead
 
 class RMSNorm(nn.Module):
     def __init__(self, dim, eps=1e-6):
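Why the aliases matter: transformers utilities (generation setup, parameter counting, model summaries) read the canonical config attributes num_hidden_layers, hidden_size, and num_attention_heads, so a config that only exposes num_layers, d_model, and nhead can raise AttributeError inside the library. A minimal sketch of the resulting config; the defaults and model_type below are hypothetical, since the full GAD2Config.__init__ is not shown in this diff:

from transformers import PretrainedConfig

class GAD2Config(PretrainedConfig):
    model_type = "gad2"  # hypothetical; the registration name is not shown in the diff

    def __init__(self, vocab_size=32000, d_model=512, nhead=8, num_layers=6,
                 num_intents=16, memory_size=64, seq_len=1024, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.nhead = nhead
        self.num_layers = num_layers
        self.num_intents = num_intents
        self.memory_size = memory_size
        self.seq_len = seq_len
        # Aliases onto the canonical names that transformers utilities look up.
        self.num_hidden_layers = num_layers
        self.hidden_size = d_model
        self.num_attention_heads = nhead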
@@ -47,7 +52,6 @@ class AdaptiveMemory(nn.Module):
         q = self.q_proj(x)
         kv = self.kv_proj(self.memory).view(-1, 2, d)
         k, v = kv[:, 0].unsqueeze(0).expand(b, -1, -1), kv[:, 1].unsqueeze(0).expand(b, -1, -1)
-        # Scaled Dot Product Attention
         attn_out = F.scaled_dot_product_attention(q.unsqueeze(1), k.unsqueeze(1), v.unsqueeze(1))
         return self.out_proj(attn_out.squeeze(1))
 
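For context on the hunk above: AdaptiveMemory cross-attends from the input sequence into a learned memory bank, with q projected from x while k and v are projected from self.memory and broadcast over the batch; unsqueeze(1) adds a single-head axis for scaled_dot_product_attention. A runnable sketch of that shape flow; the __init__ here is assumed, since only the forward lines appear in the diff:

import torch
import torch.nn as nn
import torch.nn.functional as F

class AdaptiveMemory(nn.Module):
    def __init__(self, d_model=512, memory_size=64):  # assumed constructor
        super().__init__()
        self.memory = nn.Parameter(torch.randn(memory_size, d_model))
        self.q_proj = nn.Linear(d_model, d_model)
        self.kv_proj = nn.Linear(d_model, 2 * d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x):
        b, t, d = x.size()
        q = self.q_proj(x)                             # (b, t, d)
        kv = self.kv_proj(self.memory).view(-1, 2, d)  # (memory_size, 2, d)
        k = kv[:, 0].unsqueeze(0).expand(b, -1, -1)    # (b, memory_size, d)
        v = kv[:, 1].unsqueeze(0).expand(b, -1, -1)
        # unsqueeze(1) inserts the single-head axis SDPA expects: (b, 1, seq, d)
        attn_out = F.scaled_dot_product_attention(q.unsqueeze(1), k.unsqueeze(1), v.unsqueeze(1))
        return self.out_proj(attn_out.squeeze(1))      # (b, t, d)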
@@ -63,18 +67,13 @@ class GAD2_Block(nn.Module):
         h = self.norm1(x)
         b, t, d = h.size()
         qkv = self.qkv(h).view(b, t, 3, self.nhead, self.head_dim).permute(2, 0, 3, 1, 4)
-
-
-        attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
+        attn_out = F.scaled_dot_product_attention(qkv[0], qkv[1], qkv[2], is_causal=True)
         x = x + self.proj(attn_out.transpose(1, 2).reshape(b, t, d))
-
         h = self.norm2(x)
-
-        return x
+        return x + self.w3(F.silu(self.w1(h)) * self.w2(h))
 
 class GAD2ForAgenticModeling(PreTrainedModel):
     config_class = GAD2Config
-
     def __init__(self, config):
         super().__init__(config)
         self.token_emb = nn.Embedding(config.vocab_size, config.d_model)
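Two fixes land in this hunk. Attention: permute(2, 0, 3, 1, 4) arranges qkv as (3, batch, heads, seq, head_dim), so qkv[0], qkv[1], qkv[2] are the q, k, v tensors in exactly the layout scaled_dot_product_attention expects, whereas the deleted line referenced q, k, v names that were never unpacked. Feed-forward: the old block computed h = self.norm2(x) and then returned x unchanged, skipping the FFN entirely; the new return applies a SwiGLU feed-forward, w3(silu(w1(h)) * w2(h)). A quick shape check of the attention layout, with illustrative dimensions:

import torch
import torch.nn.functional as F

b, t, d, nhead = 2, 16, 512, 8
head_dim = d // nhead
qkv = torch.randn(b, t, 3 * d).view(b, t, 3, nhead, head_dim).permute(2, 0, 3, 1, 4)
print(qkv.shape)       # torch.Size([3, 2, 8, 16, 64])  -> index 0/1/2 are q/k/v
attn_out = F.scaled_dot_product_attention(qkv[0], qkv[1], qkv[2], is_causal=True)
print(attn_out.shape)  # torch.Size([2, 8, 16, 64])     -> (batch, heads, seq, head_dim)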
@@ -84,33 +83,14 @@ class GAD2ForAgenticModeling(PreTrainedModel):
         self.post_init()
 
     def forward(self, input_ids, labels=None, attention_mask=None, **kwargs):
-        # input_ids may arrive as part of kwargs in generate
-        if input_ids is None and "inputs_embeds" in kwargs:
-            return None  # handling edge cases of inputs_embeds
-
         x = self.token_emb(input_ids)
-
-        # Agentic Integration
         x = x + self.intent_evolver(x) + self.memory_module(x)
-
-        # Transformer Blocks
-        for block in self.blocks:
+        for block in self.blocks:
             x = block(x)
-
         logits = self.output_head(self.final_norm(x))
-
-
-        if labels is not None:
-            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1))
-
-        return CausalLMOutputWithPast(
-            loss=loss,
-            logits=logits,
-            past_key_values=None,
-            hidden_states=None,
-            attentions=None,
-        )
+        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1)) if labels is not None else None
+        return CausalLMOutputWithPast(loss=loss, logits=logits)
 
     def prepare_inputs_for_generation(self, input_ids, **kwargs):
-        #
-        return {"input_ids": input_ids}
+        # Very important: pass kwargs through so that generate does not get confused
+        return {"input_ids": input_ids, **kwargs}