GeminiFan207 commited on
Commit
ab79917
·
verified ·
1 Parent(s): 072b27b

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +91 -3
model.py CHANGED
@@ -1,6 +1,7 @@
1
- # model.py - Complete TinyState Model Implementation
2
  import torch
3
  import torch.nn as nn
 
4
  from transformers import PreTrainedModel
5
  from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
6
  from typing import Optional, Tuple, Union, List
@@ -28,6 +29,7 @@ class TinyStateConfig:
28
  attention_dropout=0.0,
29
  num_experts=8,
30
  num_experts_per_tok=2,
 
31
  **kwargs,
32
  ):
33
  self.vocab_size = vocab_size
@@ -45,6 +47,7 @@ class TinyStateConfig:
45
  self.attention_dropout = attention_dropout
46
  self.num_experts = num_experts
47
  self.num_experts_per_tok = num_experts_per_tok
 
48
  super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs)
49
 
50
  class RMSNorm(nn.Module):
@@ -66,15 +69,44 @@ def rotate_half(x):
66
  return torch.cat((-x2, x1), dim=-1)
67
 
68
  def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
 
 
69
  cos = cos[position_ids].unsqueeze(1)
70
  sin = sin[position_ids].unsqueeze(1)
71
  q_embed = (q * cos) + (rotate_half(q) * sin)
72
  k_embed = (k * cos) + (rotate_half(k) * sin)
73
  return q_embed, k_embed
74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  class TinyStateMLP(nn.Module):
76
  def __init__(self, config):
77
  super().__init__()
 
78
  self.hidden_size = config.hidden_size
79
  self.intermediate_size = config.intermediate_size
80
  self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
@@ -86,9 +118,40 @@ class TinyStateMLP(nn.Module):
86
  down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
87
  return down_proj
88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  class TinyStateAttention(nn.Module):
90
  def __init__(self, config):
91
  super().__init__()
 
92
  self.hidden_size = config.hidden_size
93
  self.num_heads = config.num_attention_heads
94
  self.head_dim = self.hidden_size // self.num_heads
@@ -99,6 +162,12 @@ class TinyStateAttention(nn.Module):
99
  self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
100
  self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
101
  self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
 
 
 
 
 
 
102
 
103
  def forward(
104
  self,
@@ -124,7 +193,18 @@ class TinyStateAttention(nn.Module):
124
  if past_key_value is not None:
125
  kv_seq_len += past_key_value[0].shape[-2]
126
 
127
- # Simplified attention computation
 
 
 
 
 
 
 
 
 
 
 
128
  attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
129
 
130
  if attention_mask is not None:
@@ -132,6 +212,7 @@ class TinyStateAttention(nn.Module):
132
 
133
  attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
134
  attn_output = torch.matmul(attn_weights, value_states)
 
135
  attn_output = attn_output.transpose(1, 2).contiguous()
136
  attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
137
  attn_output = self.o_proj(attn_output)
@@ -141,12 +222,19 @@ class TinyStateAttention(nn.Module):
141
 
142
  return attn_output, attn_weights, past_key_value
143
 
 
 
 
 
 
 
 
144
  class TinyStateDecoderLayer(nn.Module):
145
  def __init__(self, config):
146
  super().__init__()
147
  self.hidden_size = config.hidden_size
148
  self.self_attn = TinyStateAttention(config)
149
- self.mlp = TinyStateMLP(config)
150
  self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
151
  self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
152
 
 
1
+ # model.py - COMPLETE IMPROVED TinyState Model with MoE
2
  import torch
3
  import torch.nn as nn
4
+ import torch.nn.functional as F
5
  from transformers import PreTrainedModel
6
  from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
7
  from typing import Optional, Tuple, Union, List
 
29
  attention_dropout=0.0,
30
  num_experts=8,
31
  num_experts_per_tok=2,
32
+ moe_active=True,
33
  **kwargs,
34
  ):
35
  self.vocab_size = vocab_size
 
47
  self.attention_dropout = attention_dropout
48
  self.num_experts = num_experts
49
  self.num_experts_per_tok = num_experts_per_tok
50
+ self.moe_active = moe_active
51
  super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs)
52
 
53
  class RMSNorm(nn.Module):
 
69
  return torch.cat((-x2, x1), dim=-1)
70
 
71
  def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
72
+ cos = cos.squeeze(1).squeeze(0)
73
+ sin = sin.squeeze(1).squeeze(0)
74
  cos = cos[position_ids].unsqueeze(1)
75
  sin = sin[position_ids].unsqueeze(1)
76
  q_embed = (q * cos) + (rotate_half(q) * sin)
77
  k_embed = (k * cos) + (rotate_half(k) * sin)
78
  return q_embed, k_embed
79
 
80
+ class TinyStateRotaryEmbedding(nn.Module):
81
+ def __init__(self, dim, max_position_embeddings=2048, base=10000):
82
+ super().__init__()
83
+ self.dim = dim
84
+ self.max_position_embeddings = max_position_embeddings
85
+ self.base = base
86
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
87
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
88
+ self._set_cos_sin_cache(seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype())
89
+
90
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
91
+ self.max_seq_len_cached = seq_len
92
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
93
+ freqs = torch.outer(t, self.inv_freq)
94
+ emb = torch.cat((freqs, freqs), dim=-1)
95
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
96
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
97
+
98
+ def forward(self, x, seq_len=None):
99
+ if seq_len > self.max_seq_len_cached:
100
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
101
+ return (
102
+ self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
103
+ self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
104
+ )
105
+
106
  class TinyStateMLP(nn.Module):
107
  def __init__(self, config):
108
  super().__init__()
109
+ self.config = config
110
  self.hidden_size = config.hidden_size
111
  self.intermediate_size = config.intermediate_size
112
  self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
 
118
  down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
119
  return down_proj
120
 
121
+ class TinyStateMoE(nn.Module):
122
+ def __init__(self, config):
123
+ super().__init__()
124
+ self.num_experts = config.num_experts
125
+ self.num_experts_per_tok = config.num_experts_per_tok
126
+ self.hidden_size = config.hidden_size
127
+ self.intermediate_size = config.intermediate_size
128
+ self.experts = nn.ModuleList([
129
+ TinyStateMLP(config) for _ in range(self.num_experts)
130
+ ])
131
+ self.gate = nn.Linear(self.hidden_size, self.num_experts, bias=False)
132
+
133
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
134
+ batch_size, sequence_length, hidden_dim = hidden_states.shape
135
+ hidden_states = hidden_states.view(-1, hidden_dim)
136
+ router_logits = self.gate(hidden_states)
137
+ routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
138
+ routing_weights, selected_experts = torch.topk(routing_weights, self.num_experts_per_tok, dim=-1)
139
+ routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
140
+ final_hidden_states = torch.zeros_like(hidden_states)
141
+
142
+ for expert_idx in range(self.num_experts):
143
+ expert_mask = (selected_experts == expert_idx)
144
+ expert_weights = routing_weights * expert_mask
145
+ expert_weights = expert_weights.sum(dim=-1, keepdim=True)
146
+ expert_output = self.experts[expert_idx](hidden_states)
147
+ final_hidden_states += expert_output * expert_weights
148
+
149
+ return final_hidden_states.view(batch_size, sequence_length, hidden_dim)
150
+
151
  class TinyStateAttention(nn.Module):
152
  def __init__(self, config):
153
  super().__init__()
154
+ self.config = config
155
  self.hidden_size = config.hidden_size
156
  self.num_heads = config.num_attention_heads
157
  self.head_dim = self.hidden_size // self.num_heads
 
162
  self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
163
  self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
164
  self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
165
+
166
+ self.rotary_emb = TinyStateRotaryEmbedding(
167
+ self.head_dim,
168
+ max_position_embeddings=config.max_position_embeddings,
169
+ base=config.rope_theta,
170
+ )
171
 
172
  def forward(
173
  self,
 
193
  if past_key_value is not None:
194
  kv_seq_len += past_key_value[0].shape[-2]
195
 
196
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
197
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
198
+
199
+ if past_key_value is not None:
200
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
201
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
202
+
203
+ past_key_value = (key_states, value_states) if use_cache else None
204
+
205
+ key_states = self._repeat_kv(key_states)
206
+ value_states = self._repeat_kv(value_states)
207
+
208
  attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
209
 
210
  if attention_mask is not None:
 
212
 
213
  attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
214
  attn_output = torch.matmul(attn_weights, value_states)
215
+
216
  attn_output = attn_output.transpose(1, 2).contiguous()
217
  attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
218
  attn_output = self.o_proj(attn_output)
 
222
 
223
  return attn_output, attn_weights, past_key_value
224
 
225
+ def _repeat_kv(self, hidden_states: torch.Tensor) -> torch.Tensor:
226
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
227
+ if num_key_value_heads == self.num_heads:
228
+ return hidden_states
229
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, self.num_key_value_groups, slen, head_dim)
230
+ return hidden_states.reshape(batch, self.num_heads, slen, head_dim)
231
+
232
  class TinyStateDecoderLayer(nn.Module):
233
  def __init__(self, config):
234
  super().__init__()
235
  self.hidden_size = config.hidden_size
236
  self.self_attn = TinyStateAttention(config)
237
+ self.mlp = TinyStateMoE(config) if config.moe_active else TinyStateMLP(config)
238
  self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
239
  self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
240