Raziel1234 committed
Commit d563af6 (verified)
1 Parent(s): 1281c27

Update modeling_gad2.py

Files changed (1)
  1. modeling_gad2.py +12 -32
modeling_gad2.py CHANGED
@@ -16,6 +16,11 @@ class GAD2Config(PretrainedConfig):
         self.num_intents = num_intents
         self.memory_size = memory_size
         self.seq_len = seq_len
+
+        # --- Bug fix: add the standard attribute names the library looks for ---
+        self.num_hidden_layers = num_layers
+        self.hidden_size = d_model
+        self.num_attention_heads = nhead
 
 class RMSNorm(nn.Module):
     def __init__(self, dim, eps=1e-6):
@@ -47,7 +52,6 @@ class AdaptiveMemory(nn.Module):
         q = self.q_proj(x)
         kv = self.kv_proj(self.memory).view(-1, 2, d)
         k, v = kv[:, 0].unsqueeze(0).expand(b, -1, -1), kv[:, 1].unsqueeze(0).expand(b, -1, -1)
-        # Scaled Dot Product Attention
         attn_out = F.scaled_dot_product_attention(q.unsqueeze(1), k.unsqueeze(1), v.unsqueeze(1))
         return self.out_proj(attn_out.squeeze(1))
 
@@ -63,18 +67,13 @@ class GAD2_Block(nn.Module):
         h = self.norm1(x)
         b, t, d = h.size()
         qkv = self.qkv(h).view(b, t, 3, self.nhead, self.head_dim).permute(2, 0, 3, 1, 4)
-        q, k, v = qkv[0], qkv[1], qkv[2]
-
-        attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
+        attn_out = F.scaled_dot_product_attention(qkv[0], qkv[1], qkv[2], is_causal=True)
         x = x + self.proj(attn_out.transpose(1, 2).reshape(b, t, d))
-
         h = self.norm2(x)
-        x = x + self.w3(F.silu(self.w1(h)) * self.w2(h))
-        return x
+        return x + self.w3(F.silu(self.w1(h)) * self.w2(h))
 
 class GAD2ForAgenticModeling(PreTrainedModel):
     config_class = GAD2Config
-
     def __init__(self, config):
         super().__init__(config)
         self.token_emb = nn.Embedding(config.vocab_size, config.d_model)
@@ -84,33 +83,14 @@ class GAD2ForAgenticModeling(PreTrainedModel):
         self.post_init()
 
     def forward(self, input_ids, labels=None, attention_mask=None, **kwargs):
-        # input_ids may arrive as part of kwargs during generate
-        if input_ids is None and "inputs_embeds" in kwargs:
-            return None  # handle exceptional inputs_embeds cases
-
         x = self.token_emb(input_ids)
-
-        # Agentic Integration
         x = x + self.intent_evolver(x) + self.memory_module(x)
-
-        # Transformer Blocks
-        for block in self.blocks:
+        for block in self.blocks:
             x = block(x)
-
         logits = self.output_head(self.final_norm(x))
-
-        loss = None
-        if labels is not None:
-            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1))
-
-        return CausalLMOutputWithPast(
-            loss=loss,
-            logits=logits,
-            past_key_values=None,
-            hidden_states=None,
-            attentions=None,
-        )
+        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1)) if labels is not None else None
+        return CausalLMOutputWithPast(loss=loss, logits=logits)
 
     def prepare_inputs_for_generation(self, input_ids, **kwargs):
-        # essential hook for .generate()
-        return {"input_ids": input_ids}
+        # Important: pass kwargs through so that .generate() does not get confused
+        return {"input_ids": input_ids, **kwargs}
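A side note on the config change in the first hunk: transformers reads config.hidden_size, config.num_hidden_layers, and config.num_attention_heads in several generic code paths, which is why the commit mirrors the custom names onto the standard ones. The same effect can be obtained declaratively with PretrainedConfig.attribute_map; the sketch below is illustrative only, and its class name, model_type, and constructor defaults are assumptions rather than code from this repo.

# Minimal sketch, not the committed code: attribute_map aliases the standard
# names to the custom ones, so cfg.hidden_size transparently reads cfg.d_model.
from transformers import PretrainedConfig

class GAD2ConfigAlt(PretrainedConfig):          # hypothetical variant of GAD2Config
    model_type = "gad2"                         # assumed model_type
    attribute_map = {
        "hidden_size": "d_model",
        "num_hidden_layers": "num_layers",
        "num_attention_heads": "nhead",
    }

    def __init__(self, vocab_size=32000, d_model=512, num_layers=8, nhead=8, **kwargs):
        # Placeholder defaults; the real GAD2Config defines its own.
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.num_layers = num_layers
        self.nhead = nhead
        super().__init__(**kwargs)

cfg = GAD2ConfigAlt()
print(cfg.hidden_size, cfg.num_hidden_layers, cfg.num_attention_heads)  # 512 8 8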
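And a hedged usage sketch for the prepare_inputs_for_generation change: with **kwargs passed through, tensors handed to .generate() (for example attention_mask) are forwarded to forward() instead of being dropped by the hook. The repository id below is a placeholder, and loading via trust_remote_code=True is assumed; this is not part of the commit itself.

# Illustration only: placeholder checkpoint id, assuming this modeling file is
# shipped with the repo and loaded with trust_remote_code=True.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "Raziel1234/GAD-2"  # placeholder, not confirmed by the commit
tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

enc = tokenizer("Plan the next step:", return_tensors="pt")
with torch.no_grad():
    # attention_mask travels through prepare_inputs_for_generation's **kwargs into forward()
    out = model.generate(
        enc["input_ids"],
        attention_mask=enc["attention_mask"],
        max_new_tokens=32,
    )
print(tokenizer.decode(out[0], skip_special_tokens=True))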