anthonym21 committed on
Commit
59e4dc1
·
verified ·
1 Parent(s): 02d6021

Upload modeling_eve.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_eve.py +16 -16
modeling_eve.py CHANGED
@@ -72,26 +72,18 @@ class SharedMoE(nn.Module):
72
  routed_out.view(-1, C).index_add_(0, batch_idx, expert_output * weight)
73
  return shared_out + routed_out, aux_loss
74
 
75
- class Block(nn.Module):
76
  def __init__(self, config):
77
  super().__init__()
78
- self.ln_1 = RMSNorm(config.n_embd)
79
- self.ln_2 = RMSNorm(config.n_embd)
80
-
81
- # Attention components
82
  self.n_head = config.n_head
83
  self.head_dim = config.head_dim
84
  self.n_embd = config.n_embd
85
  self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False)
86
  self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
87
-
88
- self.mlp = SharedMoE(config)
89
 
90
  def forward(self, x, freqs_cis):
91
- # Attention Block
92
  B, T, C = x.shape
93
- h = self.ln_1(x)
94
- qkv = self.c_attn(h)
95
  q, k, v = qkv.split(self.n_embd, dim=2)
96
  q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
97
  k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
@@ -100,17 +92,26 @@ class Block(nn.Module):
100
  k = apply_rope(k, freqs_cis)
101
  y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
102
  y = y.transpose(1, 2).contiguous().view(B, T, C)
103
- attn_out = self.c_proj(y)
 
 
 
 
 
 
 
 
 
 
 
104
  x = x + attn_out
105
-
106
- # MoE Block
107
  mlp_out, aux_loss = self.mlp(self.ln_2(x))
108
  x = x + mlp_out
109
  return x, aux_loss
110
 
111
  class DeepSeekMoE(PreTrainedModel):
112
  config_class = EveConfig
113
- _tied_weights_keys = ["lm_head.weight"] # <--- THE FIX IS HERE
114
 
115
  def __init__(self, config):
116
  super().__init__(config)
@@ -122,7 +123,7 @@ class DeepSeekMoE(PreTrainedModel):
122
  ))
123
  self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
124
 
125
- # Tie weights manually (HF checks this flag)
126
  self.transformer.wte.weight = self.lm_head.weight
127
 
128
  freqs_cis = precompute_rope_freqs(config.head_dim, config.block_size, config.rope_theta)
@@ -148,7 +149,6 @@ class DeepSeekMoE(PreTrainedModel):
148
  x = self.transformer.wte(idx)
149
  total_aux_loss = 0.0
150
 
151
- # Ensure rope freqs are on correct device
152
  freqs_cis = self.freqs_cis.to(x.device)
153
 
154
  for block in self.transformer.h:
 
72
  routed_out.view(-1, C).index_add_(0, batch_idx, expert_output * weight)
73
  return shared_out + routed_out, aux_loss
74
 
75
+ class CausalSelfAttention(nn.Module):
76
  def __init__(self, config):
77
  super().__init__()
 
 
 
 
78
  self.n_head = config.n_head
79
  self.head_dim = config.head_dim
80
  self.n_embd = config.n_embd
81
  self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False)
82
  self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
 
 
83
 
84
  def forward(self, x, freqs_cis):
 
85
  B, T, C = x.shape
86
+ qkv = self.c_attn(x)
 
87
  q, k, v = qkv.split(self.n_embd, dim=2)
88
  q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
89
  k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
 
92
  k = apply_rope(k, freqs_cis)
93
  y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
94
  y = y.transpose(1, 2).contiguous().view(B, T, C)
95
+ return self.c_proj(y)
96
+
97
+ class Block(nn.Module):
98
+ def __init__(self, config):
99
+ super().__init__()
100
+ self.ln_1 = RMSNorm(config.n_embd)
101
+ self.ln_2 = RMSNorm(config.n_embd)
102
+ self.attn = CausalSelfAttention(config) # Named 'attn' to match safetensors
103
+ self.mlp = SharedMoE(config)
104
+
105
+ def forward(self, x, freqs_cis):
106
+ attn_out = self.attn(self.ln_1(x), freqs_cis)
107
  x = x + attn_out
 
 
108
  mlp_out, aux_loss = self.mlp(self.ln_2(x))
109
  x = x + mlp_out
110
  return x, aux_loss
111
 
112
  class DeepSeekMoE(PreTrainedModel):
113
  config_class = EveConfig
114
+ _tied_weights_keys = ["lm_head.weight"]
115
 
116
  def __init__(self, config):
117
  super().__init__(config)
 
123
  ))
124
  self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
125
 
126
+ # Tie weights
127
  self.transformer.wte.weight = self.lm_head.weight
128
 
129
  freqs_cis = precompute_rope_freqs(config.head_dim, config.block_size, config.rope_theta)
 
149
  x = self.transformer.wte(idx)
150
  total_aux_loss = 0.0
151
 
 
152
  freqs_cis = self.freqs_cis.to(x.device)
153
 
154
  for block in self.transformer.h: