Commit: Upload DisamBert
Files changed: DisamBert.py (+18 / −25), config.json (+1 / −1), model.safetensors (+1 / −1)
DisamBert.py
CHANGED
|
@@ -5,7 +5,7 @@ from enum import StrEnum
|
|
| 5 |
import pandas as pd
|
| 6 |
import torch
|
| 7 |
import torch.nn as nn
|
| 8 |
-
from transformers import AutoConfig, AutoModel, AutoTokenizer,
|
| 9 |
|
| 10 |
BATCH_SIZE = 64
|
| 11 |
|
|
@@ -25,29 +25,30 @@ class LexicalExample:
|
|
| 25 |
class PaddedBatch:
|
| 26 |
input_ids: torch.Tensor
|
| 27 |
attention_mask: torch.Tensor
|
| 28 |
-
|
| 29 |
|
| 30 |
class DisamBert(PreTrainedModel):
|
| 31 |
-
def __init__(self, config:PreTrainedConfig):
|
| 32 |
super().__init__(config)
|
| 33 |
if config.init_basemodel:
|
| 34 |
-
self.BaseModel = AutoModel.from_pretrained(config.name_or_path,device_map="auto")
|
| 35 |
-
|
| 36 |
-
self.classifier_head = nn.UninitializedParameter()
|
| 37 |
self.__entities = None
|
| 38 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_path)
|
| 40 |
self.post_init()
|
| 41 |
-
|
| 42 |
-
|
| 43 |
@classmethod
|
| 44 |
def from_base(cls, base_id: ModelURI):
|
| 45 |
config = AutoConfig.from_pretrained(base_id)
|
| 46 |
config.init_basemodel = True
|
| 47 |
config.tokenizer_path = base_id
|
| 48 |
return cls(config)
|
| 49 |
-
|
| 50 |
-
|
| 51 |
def init_classifier(self, entities: Generator[LexicalExample]) -> None:
|
| 52 |
entity_ids = []
|
| 53 |
vectors = []
|
|
@@ -57,25 +58,22 @@ class DisamBert(PreTrainedModel):
|
|
| 57 |
for entity in entities:
|
| 58 |
entity_ids.append(entity.concept)
|
| 59 |
batch.append(entity.definition)
|
| 60 |
-
|
| 61 |
n += 1
|
| 62 |
if n == BATCH_SIZE:
|
| 63 |
tokens = self.tokenizer(batch, padding=True, return_tensors="pt")
|
| 64 |
-
encoding = self.BaseModel(
|
| 65 |
-
tokens["input_ids"], tokens["attention_mask"]
|
| 66 |
-
)
|
| 67 |
vectors.append(encoding.last_hidden_state.detach()[:, 0])
|
| 68 |
n = 0
|
| 69 |
batch = []
|
| 70 |
if n > 0:
|
| 71 |
tokens = self.tokenizer(batch, padding=True, return_tensors="pt")
|
| 72 |
-
encoding = self.BaseModel(
|
| 73 |
-
tokens["input_ids"], tokens["attention_mask"]
|
| 74 |
-
)
|
| 75 |
vectors.append(encoding.last_hidden_state.detach()[:, 0])
|
| 76 |
-
|
| 77 |
self.__entities = pd.Series(entity_ids)
|
| 78 |
self.config.entities = entity_ids
|
|
|
|
| 79 |
self.classifier_head = nn.Parameter(torch.cat(vectors, dim=0))
|
| 80 |
|
| 81 |
@property
|
|
@@ -147,12 +145,7 @@ class DisamBert(PreTrainedModel):
|
|
| 147 |
]
|
| 148 |
)
|
| 149 |
attention_mask = torch.vstack(
|
| 150 |
-
[
|
| 151 |
-
torch.cat(
|
| 152 |
-
(torch.ones(length), torch.zeros(maxlen - length))
|
| 153 |
-
)
|
| 154 |
-
for length in lengths
|
| 155 |
-
]
|
| 156 |
)
|
| 157 |
return PaddedBatch(input_ids, attention_mask)
|
| 158 |
|
|
|
|
| 5 |
import pandas as pd
|
| 6 |
import torch
|
| 7 |
import torch.nn as nn
|
| 8 |
+
from transformers import AutoConfig, AutoModel, AutoTokenizer, ModernBertModel, PreTrainedConfig, PreTrainedModel
|
| 9 |
|
| 10 |
BATCH_SIZE = 64
|
| 11 |
|
|
|
|
| 25 |
class PaddedBatch:
    # Container for one tokenized batch. It is constructed positionally as
    # PaddedBatch(input_ids, attention_mask) elsewhere in the file, so this is
    # presumably decorated as a @dataclass just above this view — TODO confirm.
    input_ids: torch.Tensor  # padded token-id matrix
    attention_mask: torch.Tensor  # 1 for real tokens, 0 for padding positions
|
| 28 |
+
|
| 29 |
|
| 30 |
class DisamBert(PreTrainedModel):
|
| 31 |
+
def __init__(self, config: PreTrainedConfig):
    """Build a DisamBert model.

    Two construction paths, selected by ``config.init_basemodel``:

    * ``True``  -- fresh model (the ``from_base`` path): download pretrained
      encoder weights; the classifier head stays uninitialized until
      ``init_classifier`` sees the entity set.
    * ``False`` -- reload of a saved checkpoint: build the architecture only;
      weights arrive through the usual ``PreTrainedModel`` state-dict loading.

    Args:
        config: model configuration. Must carry ``init_basemodel`` and
            ``tokenizer_path``; reloads additionally use ``vocab_size``,
            ``hidden_size`` and the ``entities`` list saved by
            ``init_classifier``.
    """
    super().__init__(config)
    if config.init_basemodel:
        # Fresh start: pull pretrained encoder weights for the base model.
        self.BaseModel = AutoModel.from_pretrained(config.name_or_path, device_map="auto")
        # Head shape is unknown until init_classifier() runs.
        self.classifier_head = nn.UninitializedParameter()
        self.__entities = None
    else:
        # Reload path: architecture only; from_pretrained fills the weights.
        self.BaseModel = ModernBertModel(config)
        self.classifier_head = nn.Parameter(torch.empty((config.vocab_size, config.hidden_size)))
        # BUGFIX: the original line was the bare expression `self._entities`
        # — a no-op attribute access (raising AttributeError) under the wrong,
        # un-mangled name. Restore the entity index that init_classifier()
        # stored in the config instead.
        saved_entities = getattr(config, "entities", None)
        self.__entities = pd.Series(saved_entities) if saved_entities is not None else None
    # BUGFIX: clear the flag unconditionally (it was cleared only in the else
    # branch), so a checkpoint saved right after from_base() does not try to
    # re-download the base encoder when it is reloaded.
    config.init_basemodel = False
    self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_path)
    self.post_init()
|
| 44 |
+
|
|
|
|
| 45 |
@classmethod
def from_base(cls, base_id: ModelURI):
    """Alternate constructor: start a fresh DisamBert from a base encoder.

    Fetches the base model's configuration, marks it so ``__init__`` downloads
    the pretrained encoder weights, and records where the tokenizer lives.

    Args:
        base_id: identifier of the pretrained base model.

    Returns:
        A newly constructed instance of this class.
    """
    cfg = AutoConfig.from_pretrained(base_id)
    cfg.init_basemodel = True
    cfg.tokenizer_path = base_id
    return cls(cfg)
|
| 51 |
+
|
|
|
|
| 52 |
def init_classifier(self, entities: Generator[LexicalExample]) -> None:
|
| 53 |
entity_ids = []
|
| 54 |
vectors = []
|
|
|
|
| 58 |
for entity in entities:
|
| 59 |
entity_ids.append(entity.concept)
|
| 60 |
batch.append(entity.definition)
|
| 61 |
+
|
| 62 |
n += 1
|
| 63 |
if n == BATCH_SIZE:
|
| 64 |
tokens = self.tokenizer(batch, padding=True, return_tensors="pt")
|
| 65 |
+
encoding = self.BaseModel(tokens["input_ids"], tokens["attention_mask"])
|
|
|
|
|
|
|
| 66 |
vectors.append(encoding.last_hidden_state.detach()[:, 0])
|
| 67 |
n = 0
|
| 68 |
batch = []
|
| 69 |
if n > 0:
|
| 70 |
tokens = self.tokenizer(batch, padding=True, return_tensors="pt")
|
| 71 |
+
encoding = self.BaseModel(tokens["input_ids"], tokens["attention_mask"])
|
|
|
|
|
|
|
| 72 |
vectors.append(encoding.last_hidden_state.detach()[:, 0])
|
| 73 |
+
|
| 74 |
self.__entities = pd.Series(entity_ids)
|
| 75 |
self.config.entities = entity_ids
|
| 76 |
+
self.config.vocab_size = len(entity_ids)
|
| 77 |
self.classifier_head = nn.Parameter(torch.cat(vectors, dim=0))
|
| 78 |
|
| 79 |
@property
|
|
|
|
| 145 |
]
|
| 146 |
)
|
| 147 |
attention_mask = torch.vstack(
|
| 148 |
+
[torch.cat((torch.ones(length), torch.zeros(maxlen - length))) for length in lengths]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
)
|
| 150 |
return PaddedBatch(input_ids, attention_mask)
|
| 151 |
|
config.json
CHANGED
|
@@ -117741,5 +117741,5 @@
|
|
| 117741 |
"tie_word_embeddings": true,
|
| 117742 |
"tokenizer_path": "answerdotai/ModernBERT-base",
|
| 117743 |
"transformers_version": "5.0.0",
|
| 117744 |
-
"vocab_size":
|
| 117745 |
}
|
|
|
|
| 117741 |
"tie_word_embeddings": true,
|
| 117742 |
"tokenizer_path": "answerdotai/ModernBERT-base",
|
| 117743 |
"transformers_version": "5.0.0",
|
| 117744 |
+
"vocab_size": 117660
|
| 117745 |
}
|
model.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 957523088
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:79d0851573b5002b29d196af74a0b87c06e774b30889fe729bd17f323af7fc2f
|
| 3 |
size 957523088
|