Yuchan
committed on
Update Model.py
Browse files
Model.py
CHANGED
|
@@ -103,7 +103,7 @@ def txt_stream(file_path, num_lines=None):
|
|
| 103 |
|
| 104 |
# Dataset 생성 (예: 처음 10,000라인만)
|
| 105 |
dataset = tf.data.Dataset.from_generator(
|
| 106 |
-
lambda: txt_stream(DATA_PATH, num_lines=
|
| 107 |
output_signature=(
|
| 108 |
tf.TensorSpec(shape=(max_len,), dtype=tf.int32),
|
| 109 |
tf.TensorSpec(shape=(max_len,), dtype=tf.int32),
|
|
@@ -129,7 +129,7 @@ class SwiGLU(layers.Layer):
|
|
| 129 |
return tf.cast(out, x.dtype)
|
| 130 |
|
| 131 |
class SparseCausalAttention(tf.keras.layers.Layer):
|
| 132 |
-
def __init__(self, num_heads, head_dim, window_size=
|
| 133 |
super().__init__(**kwargs)
|
| 134 |
self.num_heads = num_heads
|
| 135 |
self.head_dim = head_dim
|
|
|
|
| 103 |
|
| 104 |
# Dataset 생성 (예: 처음 10,000라인만)
|
| 105 |
dataset = tf.data.Dataset.from_generator(
|
| 106 |
+
lambda: txt_stream(DATA_PATH, num_lines=100000),
|
| 107 |
output_signature=(
|
| 108 |
tf.TensorSpec(shape=(max_len,), dtype=tf.int32),
|
| 109 |
tf.TensorSpec(shape=(max_len,), dtype=tf.int32),
|
|
|
|
| 129 |
return tf.cast(out, x.dtype)
|
| 130 |
|
| 131 |
class SparseCausalAttention(tf.keras.layers.Layer):
|
| 132 |
+
def __init__(self, num_heads, head_dim, window_size=8, **kwargs):
|
| 133 |
super().__init__(**kwargs)
|
| 134 |
self.num_heads = num_heads
|
| 135 |
self.head_dim = head_dim
|