Upload DisamBert
Browse files- DisamBert.py +3 -1
- model.safetensors +1 -1
DisamBert.py
CHANGED
|
@@ -3,6 +3,7 @@ from dataclasses import dataclass
|
|
| 3 |
from enum import StrEnum
|
| 4 |
from itertools import chain
|
| 5 |
|
|
|
|
| 6 |
import pandas as pd
|
| 7 |
import torch
|
| 8 |
import torch.nn as nn
|
|
@@ -91,7 +92,8 @@ class DisamBert(PreTrainedModel):
|
|
| 91 |
self.classifier_head = nn.Parameter(torch.cat(vectors, dim=0))
|
| 92 |
self.bias = nn.Parameter(
|
| 93 |
torch.nn.init.normal_(
|
| 94 |
-
torch.empty((self.config.ontology_size, 1)),
|
|
|
|
| 95 |
)
|
| 96 |
)
|
| 97 |
|
|
|
|
| 3 |
from enum import StrEnum
|
| 4 |
from itertools import chain
|
| 5 |
|
| 6 |
+
import numpy as np
|
| 7 |
import pandas as pd
|
| 8 |
import torch
|
| 9 |
import torch.nn as nn
|
|
|
|
| 92 |
self.classifier_head = nn.Parameter(torch.cat(vectors, dim=0))
|
| 93 |
self.bias = nn.Parameter(
|
| 94 |
torch.nn.init.normal_(
|
| 95 |
+
torch.empty((self.config.ontology_size, 1)),
|
| 96 |
+
std=self.classifier_head.std().item() * np.sqrt(self.config.hidden_size)
|
| 97 |
)
|
| 98 |
)
|
| 99 |
|
model.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 957993808
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ba02b2a39f89f4b20322e713d0882548617ed5657678db44d8ef6af92839f9f0
|
| 3 |
size 957993808
|