Yuchan committed on
Commit
88215f4
·
verified ·
1 Parent(s): e1bb994

Update Model.py

Browse files
Files changed (1) hide show
  1. Model.py +2 -2
Model.py CHANGED
@@ -235,7 +235,7 @@ class LoSoU(layers.Layer):
235
  return tf.cast(out, x.dtype)
236
 
237
  class Block(layers.Layer):
238
- def __init__(self, d_model, r, hyper_n, num_heads, num_groups):
239
  super().__init__()
240
  self.losou = [LoSoU(d_model) for _ in range(hyper_n)]
241
 
@@ -249,7 +249,7 @@ class ReLaM(tf.keras.Model):
249
  super().__init__()
250
  self.token_embedding = layers.Embedding(vocab_size, d_model)
251
  self.pos_embedding = layers.Embedding(max_seq_len, d_model)
252
- self.blocks = [Block(d_model, r=204, hyper_n=3, num_heads=8, num_groups=2) for _ in range(n_layers)]
253
 
254
  # LayerNormalization은 float32로 해서 정밀도 문제 방지
255
  self.ln_f = layers.LayerNormalization(epsilon=1e-5, dtype="float32")
 
235
  return tf.cast(out, x.dtype)
236
 
237
  class Block(layers.Layer):
238
+ def __init__(self, d_model, hyper_n):
239
  super().__init__()
240
  self.losou = [LoSoU(d_model) for _ in range(hyper_n)]
241
 
 
249
  super().__init__()
250
  self.token_embedding = layers.Embedding(vocab_size, d_model)
251
  self.pos_embedding = layers.Embedding(max_seq_len, d_model)
252
+ self.blocks = [Block(d_model, hyper_n=3) for _ in range(n_layers)]
253
 
254
  # LayerNormalization은 float32로 해서 정밀도 문제 방지
255
  self.ln_f = layers.LayerNormalization(epsilon=1e-5, dtype="float32")