Update V2.py
Browse files
V2.py
CHANGED
|
@@ -128,7 +128,7 @@ ds = ds.batch(BATCH_SIZE, drop_remainder=True)
|
|
| 128 |
ds = ds.map(lambda v1, v2: ((v1, v2), tf.zeros([BATCH_SIZE], dtype=tf.float32)), num_parallel_calls=tf.data.AUTOTUNE)
|
| 129 |
ds = ds.prefetch(tf.data.AUTOTUNE)
|
| 130 |
|
| 131 |
-
class
|
| 132 |
def __init__(self, d_model, k=7, hyper_dim=128, dropout=0.0):
|
| 133 |
super().__init__()
|
| 134 |
assert k % 2 == 1
|
|
|
|
| 128 |
ds = ds.map(lambda v1, v2: ((v1, v2), tf.zeros([BATCH_SIZE], dtype=tf.float32)), num_parallel_calls=tf.data.AUTOTUNE)
|
| 129 |
ds = ds.prefetch(tf.data.AUTOTUNE)
|
| 130 |
|
| 131 |
+
class HyperConv1D(layers.Layer):
|
| 132 |
def __init__(self, d_model, k=7, hyper_dim=128, dropout=0.0):
|
| 133 |
super().__init__()
|
| 134 |
assert k % 2 == 1
|