PeteBleackley committed on
Commit
36f55fa
·
verified ·
1 Parent(s): 2837752

Upload DisamBert

Browse files
Files changed (2) hide show
  1. DisamBert.py +65 -64
  2. model.safetensors +1 -1
DisamBert.py CHANGED
@@ -30,9 +30,10 @@ class PaddedBatch:
30
  class DisamBert(PreTrainedModel):
31
  def __init__(self, config:PreTrainedConfig):
32
  super().__init__(config)
33
- self.BaseModel = AutoModel.from_pretrained(config.name_or_path).to("cuda")
34
  self.tokenizer = AutoTokenizer.from_pretrained(config.name_or_path)
35
- self.classifier_head = nn.UninitializedParameter(device="cuda")
 
36
  self.__entities = None
37
 
38
  @classmethod
@@ -45,29 +46,30 @@ class DisamBert(PreTrainedModel):
45
  vectors = []
46
  batch = []
47
  n = 0
48
- for entity in entities:
49
- entity_ids.append(entity.concept)
50
- batch.append(entity.definition)
51
-
52
- n += 1
53
- if n == BATCH_SIZE:
 
 
 
 
 
 
 
 
 
54
  tokens = self.tokenizer(batch, padding=True, return_tensors="pt")
55
  encoding = self.BaseModel(
56
- tokens["input_ids"].to("cuda"), tokens["attention_mask"].to("cuda")
57
  )
58
  vectors.append(encoding.last_hidden_state.detach()[:, 0])
59
- n = 0
60
- batch = []
61
- if n > 0:
62
- tokens = self.tokenizer(batch, padding=True, return_tensors="pt")
63
- encoding = self.BaseModel(
64
- tokens["input_ids"].to("cuda"), tokens["attention_mask"].to("cuda")
65
- )
66
- vectors.append(encoding.last_hidden_state.detach()[:, 0])
67
-
68
- self.__entities = pd.Series(entity_ids)
69
- self.config.entities = entity_ids
70
- self.classifier_head = nn.Parameter(torch.cat(vectors, dim=0))
71
 
72
  @property
73
  def entities(self) -> pd.Series:
@@ -81,45 +83,45 @@ class DisamBert(PreTrainedModel):
81
  )
82
  all_indices = []
83
  all_tokens = []
84
-
85
- for sentence, span_indices in zip(sentences, indices, strict=True):
86
- indices = []
87
- tokens = []
88
- last_span = len(span_indices) - 2
89
- for i, position in enumerate(span_indices[:-1]):
90
- span = sentence[position : span_indices[i + 1]]
91
- span_tokens = self.tokenizer([span], padding=False)["input_ids"][0]
92
- if i > 0:
93
- span_tokens = span_tokens[1:]
94
- if i < last_span:
95
- span_tokens = span_tokens[:-1]
96
- indices.append(len(span_tokens))
97
- tokens.extend(span_tokens)
98
- all_indices.append(indices)
99
- all_tokens.append(tokens)
100
- sentence_lengths = [len(boundaries) for boundaries in all_indices]
101
- maxlen = max(sentence_lengths)
102
- batch = self.pad(all_tokens)
103
- token_vectors = self.BaseModel(batch.input_ids, batch.attention_mask).last_hidden_state
104
- span_vectors = torch.cat(
105
- [
106
- torch.vstack(
107
- [
108
- torch.sum(chunk, dim=0)
109
- for chunk in self.split(token_vectors[i], sentence_indices)
110
- ]
111
- )
112
- for (i, sentence_indices) in enumerate(all_indices)
113
- ]
114
- )
115
- logits = torch.einsum("ij,kj->ki", span_vectors, self.classifier_head)
116
- split_logits = torch.split(logits, sentence_lengths, dim=1)
117
- return torch.stack(
118
- [
119
- self.extend_to_max_length(sentence, length, maxlen)
120
- for (sentence, length) in zip(split_logits, sentence_lengths, strict=True)
121
- ]
122
- )
123
 
124
  def split(self, vectors: torch.Tensor, lengths: list[int]) -> tuple[torch.Tensor, ...]:
125
  maxlen = vectors.shape[0]
@@ -135,13 +137,12 @@ class DisamBert(PreTrainedModel):
135
  [
136
  sentence + [self.config.pad_token_id] * (maxlen - length)
137
  for (sentence, length) in zip(tokens, lengths, strict=True)
138
- ],
139
- device="cuda",
140
  )
141
  attention_mask = torch.vstack(
142
  [
143
  torch.cat(
144
- (torch.ones(length, device="cuda"), torch.zeros(maxlen - length, device="cuda"))
145
  )
146
  for length in lengths
147
  ]
@@ -155,7 +156,7 @@ class DisamBert(PreTrainedModel):
155
  torch.cat(
156
  [
157
  sentence,
158
- torch.zeros((self.__entities.shape[0], maxlength - length), device="cuda"),
159
  ],
160
  dim=1,
161
  )
 
30
  class DisamBert(PreTrainedModel):
31
  def __init__(self, config:PreTrainedConfig):
32
  super().__init__(config)
33
+ self.BaseModel = AutoModel.from_pretrained(config.name_or_path,device_map="auto")
34
  self.tokenizer = AutoTokenizer.from_pretrained(config.name_or_path)
35
+ with self.BaseModel.device:
36
+ self.classifier_head = nn.UninitializedParameter()
37
  self.__entities = None
38
 
39
  @classmethod
 
46
  vectors = []
47
  batch = []
48
  n = 0
49
+ with self.BaseModel.device:
50
+ for entity in entities:
51
+ entity_ids.append(entity.concept)
52
+ batch.append(entity.definition)
53
+
54
+ n += 1
55
+ if n == BATCH_SIZE:
56
+ tokens = self.tokenizer(batch, padding=True, return_tensors="pt")
57
+ encoding = self.BaseModel(
58
+ tokens["input_ids"], tokens["attention_mask"]
59
+ )
60
+ vectors.append(encoding.last_hidden_state.detach()[:, 0])
61
+ n = 0
62
+ batch = []
63
+ if n > 0:
64
  tokens = self.tokenizer(batch, padding=True, return_tensors="pt")
65
  encoding = self.BaseModel(
66
+ tokens["input_ids"], tokens["attention_mask"]
67
  )
68
  vectors.append(encoding.last_hidden_state.detach()[:, 0])
69
+
70
+ self.__entities = pd.Series(entity_ids)
71
+ self.config.entities = entity_ids
72
+ self.classifier_head = nn.Parameter(torch.cat(vectors, dim=0))
 
 
 
 
 
 
 
 
73
 
74
  @property
75
  def entities(self) -> pd.Series:
 
83
  )
84
  all_indices = []
85
  all_tokens = []
86
+ with self.BaseModel.device:
87
+ for sentence, span_indices in zip(sentences, indices, strict=True):
88
+ indices = []
89
+ tokens = []
90
+ last_span = len(span_indices) - 2
91
+ for i, position in enumerate(span_indices[:-1]):
92
+ span = sentence[position : span_indices[i + 1]]
93
+ span_tokens = self.tokenizer([span], padding=False)["input_ids"][0]
94
+ if i > 0:
95
+ span_tokens = span_tokens[1:]
96
+ if i < last_span:
97
+ span_tokens = span_tokens[:-1]
98
+ indices.append(len(span_tokens))
99
+ tokens.extend(span_tokens)
100
+ all_indices.append(indices)
101
+ all_tokens.append(tokens)
102
+ sentence_lengths = [len(boundaries) for boundaries in all_indices]
103
+ maxlen = max(sentence_lengths)
104
+ batch = self.pad(all_tokens)
105
+ token_vectors = self.BaseModel(batch.input_ids, batch.attention_mask).last_hidden_state
106
+ span_vectors = torch.cat(
107
+ [
108
+ torch.vstack(
109
+ [
110
+ torch.sum(chunk, dim=0)
111
+ for chunk in self.split(token_vectors[i], sentence_indices)
112
+ ]
113
+ )
114
+ for (i, sentence_indices) in enumerate(all_indices)
115
+ ]
116
+ )
117
+ logits = torch.einsum("ij,kj->ki", span_vectors, self.classifier_head)
118
+ split_logits = torch.split(logits, sentence_lengths, dim=1)
119
+ return torch.stack(
120
+ [
121
+ self.extend_to_max_length(sentence, length, maxlen)
122
+ for (sentence, length) in zip(split_logits, sentence_lengths, strict=True)
123
+ ]
124
+ )
125
 
126
  def split(self, vectors: torch.Tensor, lengths: list[int]) -> tuple[torch.Tensor, ...]:
127
  maxlen = vectors.shape[0]
 
137
  [
138
  sentence + [self.config.pad_token_id] * (maxlen - length)
139
  for (sentence, length) in zip(tokens, lengths, strict=True)
140
+ ]
 
141
  )
142
  attention_mask = torch.vstack(
143
  [
144
  torch.cat(
145
+ (torch.ones(length), torch.zeros(maxlen - length))
146
  )
147
  for length in lengths
148
  ]
 
156
  torch.cat(
157
  [
158
  sentence,
159
+ torch.zeros((self.__entities.shape[0], maxlength - length)),
160
  ],
161
  dim=1,
162
  )
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:765767f2ce35a2118f15cef212da9c3e5159a114e6e1aa080942d3e256b12c22
3
  size 957523088
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5d2a927c475b82fe97cb22c4f9e8367a186e66d17a7716fd6fd231d190684f5d
3
  size 957523088