Upload model
Browse files- config.json +1 -1
- vit_fnqs_config.py +1 -1
- vit_fnqs_model.py +1 -1
config.json
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
{
|
| 2 |
-
"L_eff":
|
| 3 |
"architectures": [
|
| 4 |
"NQSModel"
|
| 5 |
],
|
|
|
|
| 1 |
{
|
| 2 |
+
"L_eff": 8,
|
| 3 |
"architectures": [
|
| 4 |
"NQSModel"
|
| 5 |
],
|
vit_fnqs_config.py
CHANGED
|
@@ -7,7 +7,7 @@ class ViTFNQSConfig(PretrainedConfig):
|
|
| 7 |
|
| 8 |
def __init__(
|
| 9 |
self,
|
| 10 |
-
L_eff=
|
| 11 |
num_layers = 6,
|
| 12 |
d_model = 72,
|
| 13 |
heads = 12,
|
|
|
|
| 7 |
|
| 8 |
def __init__(
|
| 9 |
self,
|
| 10 |
+
L_eff=8,
|
| 11 |
num_layers = 6,
|
| 12 |
d_model = 72,
|
| 13 |
heads = 12,
|
vit_fnqs_model.py
CHANGED
|
@@ -9,7 +9,7 @@ class ViTFNQSModel(FlaxPreTrainedModel):
|
|
| 9 |
def __init__(
|
| 10 |
self,
|
| 11 |
config: ViTFNQSConfig,
|
| 12 |
-
input_shape = (jnp.zeros((1,
|
| 13 |
seed: int = 0,
|
| 14 |
dtype: jnp.dtype = jnp.float64,
|
| 15 |
_do_init: bool = True,
|
|
|
|
| 9 |
def __init__(
|
| 10 |
self,
|
| 11 |
config: ViTFNQSConfig,
|
| 12 |
+
input_shape = (jnp.zeros((1, 32)), jnp.zeros((1, 32))),
|
| 13 |
seed: int = 0,
|
| 14 |
dtype: jnp.dtype = jnp.float64,
|
| 15 |
_do_init: bool = True,
|