Yuchan commited on
Commit
e33316b
verified
1 Parent(s): 7d7e323

Update Mo_jax.py

Browse files
Files changed (1) hide show
  1. Mo_jax.py +58 -48
Mo_jax.py CHANGED
@@ -89,43 +89,49 @@ def create_batch_iter(inputs, targets, batch_size, rng):
89
 
90
  def shard(xs): return xs.reshape(NUM_DEVICES, -1, xs.shape[1])
91
 
92
- # ------------------
93
- # Model
94
- # ------------------
95
  class SwiGLU(nn.Module):
96
  d_model: int
97
- dtype: Any = DTYPE
98
  @nn.compact
99
- def __call__(self,x):
100
- proj = nn.Dense(self.d_model*2,dtype=self.dtype)(x)
101
- x_val, x_gate = jnp.split(proj,2,-1)
 
102
  out = x_val * nn.silu(x_gate)
103
- return nn.Dense(self.d_model,dtype=self.dtype)(out)
 
104
 
105
  class LoU(nn.Module):
106
- d_model:int
107
- dtype:Any=DTYPE
 
108
  @nn.compact
109
- def __call__(self,x):
110
- residual = x
111
- x_norm = nn.LayerNorm(epsilon=1e-5,dtype=self.dtype)(x)
112
- Q=nn.Dense(self.d_model,dtype=self.dtype)
113
- K=nn.Dense(self.d_model,dtype=self.dtype)
114
- V=nn.Dense(self.d_model,dtype=self.dtype)
115
- q,k,v = Q(x_norm),K(x_norm),V(x_norm)
116
- g_q = (jnp.tanh(q)+1)/2; g_k=(jnp.tanh(k)+1)/2
117
- score = g_q*g_k
118
- alpha_dynamic = nn.Dense(1,dtype=self.dtype)(x_norm)
 
 
119
  # EMA scan along seq axis
120
  score_t = jnp.transpose(score,(1,0,2))
121
  alpha_t = jnp.transpose(alpha_dynamic,(1,0,2))
122
- def step(prev,cur): s,a=cur; new=a*s+(1-a)*prev; return new,new
123
- init = score_t[0]; _,ema_seq=jax.lax.scan(step,init,(score_t[1:],alpha_t[1:]))
124
- ema_full=jnp.concatenate([init[None,...],ema_seq],0)
 
 
 
 
125
  ema = jnp.transpose(ema_full,(1,0,2))
126
- out = v*ema + residual
127
- out = nn.LayerNorm(epsilon=1e-5,dtype=self.dtype)(out)
128
- return SwiGLU(self.d_model,self.dtype)(out)
 
129
 
130
  class Lo(nn.Module):
131
  d_model:int
@@ -161,28 +167,26 @@ class ReLM(nn.Module):
161
  logits=jnp.einsum("bld,vd->blv",x,self.token_embed.embedding)
162
  return logits
163
 
164
- # ------------------
165
- # Loss & metrics
166
- # ------------------
167
- def smoothed_ce(logits,targets,pad_id,eps=0.1):
168
- vocab=logits.shape[-1]
169
- logits=logits.reshape(-1,vocab)
170
- targets=targets.reshape(-1)
171
- mask=(targets!=pad_id).astype(jnp.float32)
172
- one_hot=jax.nn.one_hot(targets,vocab)
173
- smooth=(1-eps)*one_hot+eps/vocab
174
- log_probs=jax.nn.log_softmax(logits)
175
- loss=-jnp.sum(smooth*log_probs,axis=-1)*mask
176
- return jnp.sum(loss)/(jnp.sum(mask)+1e-8)
177
 
178
- def masked_ppl(logits,targets,pad_id,eps=0.1):
179
- vocab=logits.shape[-1]
180
- logits=logits.reshape(-1,vocab)
181
- targets=targets.reshape(-1)
182
- mask=(targets!=pad_id).astype(jnp.float32)
183
- one_hot=jax.nn.one_hot(targets,vocab)
184
- smooth=(1-eps)*one_hot+eps/vocab
185
- loss=-jnp.sum(smooth*jax.nn.log_softmax(logits),axis=-1)*mask
 
186
  return jnp.exp(jnp.sum(loss)/(jnp.sum(mask)+1e-8))
187
 
188
  # ------------------
@@ -264,7 +268,13 @@ for epoch in range(EPOCHS):
264
  # ------------------
265
  save_dir="./checkpoints"
266
  os.makedirs(save_dir,exist_ok=True)
267
- checkpoints.save_checkpoint(save_dir,jax.tree_map(lambda x:np.array(x),state),step=global_step,keep=3)
 
 
 
 
 
 
268
  print("Saved checkpoint to",save_dir)
269
 
270
  # ------------------
 
89
 
90
  def shard(xs): return xs.reshape(NUM_DEVICES, -1, xs.shape[1])
91
 
 
 
 
92
  class SwiGLU(nn.Module):
93
  d_model: int
 
94
  @nn.compact
95
+ def __call__(self, x):
96
+ x_f32 = x.astype(jnp.float32)
97
+ proj = nn.Dense(self.d_model*2, dtype=jnp.float32)(x_f32)
98
+ x_val, x_gate = jnp.split(proj, 2, axis=-1)
99
  out = x_val * nn.silu(x_gate)
100
+ out = nn.Dense(self.d_model, dtype=jnp.float32)(out)
101
+ return out.astype(x.dtype)
102
 
103
  class LoU(nn.Module):
104
+ d_model: int
105
+ clip_value: float = 5.0
106
+ eps: float = 1e-6
107
  @nn.compact
108
+ def __call__(self, x):
109
+ x_f32 = x.astype(jnp.float32)
110
+ residual = x_f32
111
+ x_norm = nn.LayerNorm(epsilon=1e-5, dtype=jnp.float32)(x_f32)
112
+ Q = nn.Dense(self.d_model, dtype=jnp.float32)
113
+ K = nn.Dense(self.d_model, dtype=jnp.float32)
114
+ V = nn.Dense(self.d_model, dtype=jnp.float32)
115
+ q,k,v = Q(x_norm), K(x_norm), V(x_norm)
116
+ g_q = (jnp.tanh(q)+1)/2
117
+ g_k = (jnp.tanh(k)+1)/2
118
+ score = g_q * g_k
119
+ alpha_dynamic = nn.Dense(1, dtype=jnp.float32)(x_norm)
120
  # EMA scan along seq axis
121
  score_t = jnp.transpose(score,(1,0,2))
122
  alpha_t = jnp.transpose(alpha_dynamic,(1,0,2))
123
+ def step(prev, cur):
124
+ s, a = cur
125
+ new = a*s + (1-a)*prev
126
+ return new,new
127
+ init = score_t[0]
128
+ _, ema_seq = jax.lax.scan(step, init, (score_t[1:], alpha_t[1:]))
129
+ ema_full = jnp.concatenate([init[None,...], ema_seq], 0)
130
  ema = jnp.transpose(ema_full,(1,0,2))
131
+ out = v * ema + residual
132
+ out = nn.LayerNorm(epsilon=1e-5, dtype=jnp.float32)(out)
133
+ return SwiGLU(self.d_model)(out).astype(x.dtype)
134
+
135
 
136
  class Lo(nn.Module):
137
  d_model:int
 
167
  logits=jnp.einsum("bld,vd->blv",x,self.token_embed.embedding)
168
  return logits
169
 
170
+ def smoothed_ce(logits, targets, pad_id, eps=0.1):
171
+ logits = logits.astype(jnp.float32)
172
+ targets = targets.astype(jnp.int32)
173
+ vocab = logits.shape[-1]
174
+ mask = (targets != pad_id).astype(jnp.float32)
175
+ one_hot = jax.nn.one_hot(targets, vocab)
176
+ smooth = (1-eps)*one_hot + eps/vocab
177
+ log_probs = jax.nn.log_softmax(logits, axis=-1)
178
+ loss = -jnp.sum(smooth * log_probs, axis=-1) * mask
179
+ return jnp.sum(loss) / (jnp.sum(mask)+1e-8)
 
 
 
180
 
181
+ def masked_ppl(logits, targets, pad_id, eps=0.1):
182
+ logits = logits.astype(jnp.float32)
183
+ targets = targets.astype(jnp.int32)
184
+ vocab = logits.shape[-1]
185
+ mask = (targets != pad_id).astype(jnp.float32)
186
+ one_hot = jax.nn.one_hot(targets, vocab)
187
+ smooth = (1-eps)*one_hot + eps/vocab
188
+ log_probs = jax.nn.log_softmax(logits, axis=-1)
189
+ loss = -jnp.sum(smooth*log_probs, axis=-1) * mask
190
  return jnp.exp(jnp.sum(loss)/(jnp.sum(mask)+1e-8))
191
 
192
  # ------------------
 
268
  # ------------------
269
  save_dir="./checkpoints"
270
  os.makedirs(save_dir,exist_ok=True)
271
+ # 旮办〈
272
+ # checkpoints.save_checkpoint(save_dir,jax.tree_map(lambda x:np.array(x),state),step=global_step,keep=3)
273
+
274
+ # 靾橃爼
275
+ import jax.tree_util
276
+ checkpoints.save_checkpoint(save_dir, jax.tree_util.tree_map(lambda x: np.array(x), state), step=global_step, keep=3)
277
+
278
  print("Saved checkpoint to",save_dir)
279
 
280
  # ------------------