han-xudong
commited on
Commit
·
4d63463
1
Parent(s):
0be4aac
modified: config.json
Browse files- config.json +4 -4
- modeling.py +4 -4
config.json
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
{
|
| 2 |
-
"_name_or_path": "asRobotics/
|
| 3 |
-
"architectures": ["
|
| 4 |
-
"model_type": "
|
| 5 |
"x_dim": [6],
|
| 6 |
-
"y_dim": [6,
|
| 7 |
"h1_dim": [100, 1000],
|
| 8 |
"h2_dim": [100, 1000],
|
| 9 |
"torch_dtype": "float32",
|
|
|
|
| 1 |
{
|
| 2 |
+
"_name_or_path": "asRobotics/ballnet",
|
| 3 |
+
"architectures": ["BallNet"],
|
| 4 |
+
"model_type": "ballnet",
|
| 5 |
"x_dim": [6],
|
| 6 |
+
"y_dim": [6, 2931],
|
| 7 |
"h1_dim": [100, 1000],
|
| 8 |
"h2_dim": [100, 1000],
|
| 9 |
"torch_dtype": "float32",
|
modeling.py
CHANGED
|
@@ -2,8 +2,8 @@ import torch
|
|
| 2 |
import torch.nn as nn
|
| 3 |
from transformers import PreTrainedModel, PretrainedConfig
|
| 4 |
|
| 5 |
-
class
|
| 6 |
-
model_type = "
|
| 7 |
|
| 8 |
def __init__(
|
| 9 |
self,
|
|
@@ -20,8 +20,8 @@ class FingerNetConfig(PretrainedConfig):
|
|
| 20 |
self.h2_dim = h2_dim
|
| 21 |
|
| 22 |
|
| 23 |
-
class
|
| 24 |
-
config_class =
|
| 25 |
|
| 26 |
def __init__(self, config):
|
| 27 |
super().__init__(config)
|
|
|
|
| 2 |
import torch.nn as nn
|
| 3 |
from transformers import PreTrainedModel, PretrainedConfig
|
| 4 |
|
| 5 |
+
class BallNetConfig(PretrainedConfig):
|
| 6 |
+
model_type = "ballnet"
|
| 7 |
|
| 8 |
def __init__(
|
| 9 |
self,
|
|
|
|
| 20 |
self.h2_dim = h2_dim
|
| 21 |
|
| 22 |
|
| 23 |
+
class BallNet(PreTrainedModel):
|
| 24 |
+
config_class = BallNetConfig
|
| 25 |
|
| 26 |
def __init__(self, config):
|
| 27 |
super().__init__(config)
|