haukurpj committed
Commit a2c9d48 · 1 Parent(s): 3bce6c8

add vectorized and more stable way to aggregate subwords to words
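
The gist of the change, as a minimal standalone sketch (toy shapes and variable names are illustrative, not the code in modeling.py): subword vectors are summed per word with scatter_add_ and divided by per-word token counts, rather than averaging each word span in a Python loop.

import torch

hidden = 4
# Six subword tokens belonging to three words: [w0, w0, w1, w2, w2, w2]
token_embeddings = torch.randn(6, hidden)    # (tokens x H)
word_ids = torch.tensor([0, 0, 1, 2, 2, 2])  # token -> word index

num_words = int(word_ids.max().item()) + 1
word_sums = torch.zeros(num_words, hidden)
word_sums.scatter_add_(0, word_ids.unsqueeze(1).expand(-1, hidden), token_embeddings)

word_counts = torch.zeros(num_words)
word_counts.scatter_add_(0, word_ids, torch.ones_like(word_ids, dtype=torch.float))

word_embeddings = word_sums / word_counts.clamp(min=1.0).unsqueeze(1)  # (words x H)
print(word_embeddings.shape)  # torch.Size([3, 4])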

Files changed (1): modeling.py (+98, -51)
modeling.py CHANGED
@@ -40,7 +40,7 @@ class MultiLabelTokenClassificationHead(nn.Module):
     def forward(self, features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         H = hidden_size, C = num_categories, A = num_attributes, Wt = total_words
-
+
         Args:
             features: Word-level features (Wt x H)
 
@@ -131,7 +131,7 @@ class IceBertPosForTokenClassification(PreTrainedModel):
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         B = batch_size, L = seq_len, H = hidden_size, C = num_categories, A = num_attributes, W = max_words
-
+
         Args:
             input_ids: Token indices (B x L)
             attention_mask: Attention mask (B x L)
@@ -157,13 +157,13 @@ class IceBertPosForTokenClassification(PreTrainedModel):
         )
 
         hidden_states = outputs[0]  # (B x L x H)
-
+
         # (B x L x H) -> (Wt x H)
         word_embeddings = self._aggregate_subword_tokens(hidden_states, word_mask, attention_mask)
-
+
         # (Wt x H) -> (Wt x C), (Wt x A)
         cat_logits, attr_logits = self.classifier(word_embeddings)
-
+
         # (Wt x C) -> (B x W x C), (Wt x A) -> (B x W x A)
         nwords = word_mask.sum(dim=-1)  # (B,)
         cat_logits = self._reshape_to_batch_format(cat_logits, nwords)
@@ -175,9 +175,10 @@ class IceBertPosForTokenClassification(PreTrainedModel):
     ) -> torch.Tensor:
         """
         Average subword tokens within each word to get word-level representations.
-
+        Vectorized implementation using scatter operations for efficiency.
+
         B = batch_size, L = seq_len, H = hidden_size, Wt = total_words
-
+
         Args:
             sequence_output: Subword token representations (B x L x H)
             word_mask: Binary mask where 1 indicates start of word (B x L)
@@ -187,48 +188,94 @@ class IceBertPosForTokenClassification(PreTrainedModel):
             word_features: Concatenated word-level features (Wt x H)
         """
         batch_size, seq_len, hidden_size = sequence_output.shape
-        mean_words = []
-
-        for batch_idx in range(batch_size):
-            # Get valid (non-padding) tokens for this sequence
-            valid_mask = attention_mask[batch_idx].bool()  # (L,) -> (Lv,)
-            seq_output = sequence_output[batch_idx, valid_mask]  # (Lv x H)
-            seq_word_mask = word_mask[batch_idx, valid_mask]  # (Lv,)
-
-            # Find word start positions
-            word_starts = seq_word_mask.nonzero(as_tuple=True)[0]  # (Ws,)
-
+        device = sequence_output.device
+
+        # Create word indices mapping each token to its word
+        # Strategy: assign each token to a word ID, then use scatter operations to sum/average
+        # Only tokens that belong to actual words get valid indices
+        word_indices = torch.full_like(word_mask, -1, dtype=torch.long)  # (B x L)
+
+        # Build word indices by finding word boundaries
+        # Each token gets assigned to a word index (0, 1, 2, ...) within its sequence
+        for b in range(batch_size):
+            valid_mask = attention_mask[b].bool()  # (L,) - exclude padding tokens
+            if not valid_mask.any():
+                continue
+
+            # Get word starts for this sequence
+            seq_word_mask = word_mask[b, valid_mask]  # (Lv,) - only valid positions
+            word_starts = seq_word_mask.nonzero(as_tuple=True)[0]  # (Ws,) - positions where words start
+
             if len(word_starts) == 0:
                 continue
-
-            # For each word, find its token span and average
+
+            # Assign each token to its word within this sequence
+            seq_word_indices = torch.full((len(seq_word_mask),), -1, dtype=torch.long, device=device)
+
             for i, start_pos in enumerate(word_starts):
-                # Find end position (start of next word or end of valid sequence)
+                # Find end position (next word start or end of sequence)
                 if i + 1 < len(word_starts):
-                    end_pos = word_starts[i + 1]
+                    end_pos = word_starts[i + 1]  # Next word boundary
                 else:
-                    end_pos = len(seq_output)
-
-                # Average tokens within this word (excluding padding)
-                word_tokens = seq_output[start_pos:end_pos]  # (Tw x H)
-                word_repr = word_tokens.mean(dim=0)  # (H,)
-                mean_words.append(word_repr)
+                    end_pos = len(seq_word_mask)  # End of sequence
+
+                # All tokens from start_pos to end_pos belong to word i
+                seq_word_indices[start_pos:end_pos] = i
+
+            # Store the word indices for this sequence
+            word_indices[b, valid_mask] = seq_word_indices
+
+        # Create global word indices across the entire batch
+        # Convert local word indices (0,1,2... per sequence) to global indices (0,1,2...total_words-1)
+        # This allows us to use scatter operations across the entire batch
+        max_words_per_seq = word_mask.sum(dim=-1)  # (B,) - words per sequence
+        word_offset = torch.cat(
+            [torch.zeros(1, device=device, dtype=torch.long), max_words_per_seq.cumsum(dim=0)[:-1]]
+        )  # (B,) - cumulative word offsets
+
+        # Add batch offsets to make global unique indices
+        # E.g., if batch has [3,2] words: seq0=[0,1,2], seq1=[3,4]
+        global_word_indices = word_indices + word_offset.unsqueeze(1)  # (B x L)
+
+        # Flatten everything for scatter operations
+        flat_output = sequence_output.view(-1, hidden_size)  # (B*L x H)
+        flat_word_indices = global_word_indices.view(-1)  # (B*L,)
+        flat_attention = attention_mask.view(-1)  # (B*L,)
+
+        # Only use tokens that belong to words (not padding and not before first word)
+        valid_word_tokens = (flat_attention.bool()) & (flat_word_indices >= 0)  # (B*L,)
+        valid_output = flat_output[valid_word_tokens]  # (valid_word_tokens x H)
+        valid_word_indices = flat_word_indices[valid_word_tokens]  # (valid_word_tokens,)
+
+        total_words = max_words_per_seq.sum().item()
+        if total_words == 0:
+            return torch.empty(0, hidden_size, device=device)
+
+        # Vectorized aggregation using scatter operations
+        # Sum all token embeddings that belong to the same word
+        word_sums = torch.zeros(total_words, hidden_size, device=device)  # (Wt x H)
+        word_sums.scatter_add_(0, valid_word_indices.unsqueeze(1).expand(-1, hidden_size), valid_output)
+
+        # Count how many tokens belong to each word (for averaging)
+        word_counts = torch.zeros(total_words, device=device)  # (Wt,)
+        word_counts.scatter_add_(0, valid_word_indices, torch.ones_like(valid_word_indices, dtype=torch.float))
 
-        if len(mean_words) == 0:
-            return torch.empty(0, hidden_size, device=sequence_output.device)
+        # Compute average: word_embedding = sum_of_tokens / count_of_tokens
+        word_counts = torch.clamp(word_counts, min=1.0)  # Prevent division by zero
+        word_features = word_sums / word_counts.unsqueeze(1)  # (Wt x H)
 
-        return torch.stack(mean_words)  # (Wt x H)
+        return word_features
 
     def _reshape_to_batch_format(self, logits: torch.Tensor, nwords: torch.Tensor) -> torch.Tensor:
         """
         Reshape concatenated word predictions back to padded batch format.
-
+
         B = batch_size, W = max_words, Wt = total_words, K = num_classes
-
+
         Args:
             logits: Concatenated word predictions (Wt x K)
             nwords: Number of words per sequence (B,)
-
+
         Returns:
             batch_logits: Batched predictions (B x W x K)
         """
@@ -265,7 +312,7 @@ class IceBertPosForTokenClassification(PreTrainedModel):
         Convert word_ids to binary mask indicating word boundaries.
 
         B = batch_size, L = seq_len
-
+
         Args:
             word_ids: List of word id sequences for each batch item
            input_shape: Shape of input_ids tensor (B x L)
@@ -296,7 +343,7 @@ class IceBertPosForTokenClassification(PreTrainedModel):
         Predict POS labels from raw text.
 
         B = batch_size, L = seq_len
-
+
         Args:
             sentences: List of input sentences (B,)
             tokenizer: HuggingFace tokenizer
@@ -330,14 +377,14 @@ class IceBertPosForTokenClassification(PreTrainedModel):
     ) -> List[List[Tuple[str, List[str]]]]:
         """
         Convert logits to human-readable labels using schema-based logic.
-
+
         B = batch_size, W = max_words, C = num_categories, A = num_attributes, L = seq_len
-
+
         Args:
             cat_logits: Category logits (B x W x C)
             attr_logits: Attribute logits (B x W x A)
             word_mask: Binary mask for valid words (B x L)
-
+
         Returns:
             predictions: List of [(category, [attributes])] for each sequence
         """
@@ -350,20 +397,20 @@ class IceBertPosForTokenClassification(PreTrainedModel):
 
         predictions = []
         schema = self.config.label_schema
-
+
         for seq_idx in range(bsz):
             seq_nwords = nwords[seq_idx]
             # (W x C) -> (seq_nwords,)
             pred_cat_indices = cat_logits[seq_idx, :seq_nwords].max(dim=-1).indices
-
+
             seq_predictions = []
             for word_idx in range(seq_nwords):
                 cat_idx = pred_cat_indices[word_idx].item()
                 cat_name = schema.label_categories[cat_idx]
-
+
                 # Get valid groups for this category
                 valid_groups = schema.category_to_group_names.get(cat_name, [])
-
+
                 # Collect attributes for this word
                 attributes = []
                 for group_name in valid_groups:
@@ -382,7 +429,7 @@ class IceBertPosForTokenClassification(PreTrainedModel):
                     best_idx = group_logits.max(dim=-1).indices.item()
                     attr_idx = group_indices[best_idx].item()
                     attributes.append(schema.labels[attr_idx])
-
+
                 # Apply specific rules from original model
                 if len(attributes) == 1 and attributes[0] == "pos":
                     # This label is used as a default for training but implied in mim format
@@ -390,9 +437,9 @@ class IceBertPosForTokenClassification(PreTrainedModel):
                 elif cat_name == "sl" and "act" in attributes:
                     # Number and tense are not shown for sl act in mim format
                     attributes = [attr for attr in attributes if attr not in ["1", "sing", "pres"]]
-
+
                 seq_predictions.append((cat_name, attributes))
-
+
             predictions.append(seq_predictions)
 
         return predictions
@@ -400,25 +447,25 @@ class IceBertPosForTokenClassification(PreTrainedModel):
     def predict_ifd_labels_from_text(self, sentences: List[str], tokenizer) -> List[List[str]]:
         """
         Predict IFD format labels from raw text.
-
+
         B = batch_size, Ws = seq_words
-
+
         Args:
             sentences: List of input sentences (B,)
             tokenizer: HuggingFace tokenizer
-
+
         Returns:
             ifd_predictions: List of IFD labels per sentence (B x Ws)
         """
         # Get model predictions in (category, [attributes]) format
         predictions = self.predict_labels_from_text(sentences, tokenizer)
-
+
         # Convert each sentence's predictions to IFD format
         ifd_predictions = []
         for sentence_predictions in predictions:
             ifd_labels = convert_predictions_to_ifd(sentence_predictions)  # (Ws,)
             ifd_predictions.append(ifd_labels)
-
+
         return ifd_predictions
 
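
Reviewer-style sanity check, added for context: a self-contained toy comparison of the old loop-based span averaging against the new scatter-based averaging. The two helpers below are re-implementations for illustration only (toy shapes, made-up names), not imports from modeling.py.

import torch

def aggregate_loop(seq_out, word_mask, attn_mask):
    # Old approach: Python loop over sequences and word spans, mean per span.
    words = []
    for b in range(seq_out.size(0)):
        valid = attn_mask[b].bool()
        out, wm = seq_out[b, valid], word_mask[b, valid]
        starts = wm.nonzero(as_tuple=True)[0]
        for i, s in enumerate(starts):
            e = starts[i + 1] if i + 1 < len(starts) else len(out)
            words.append(out[s:e].mean(dim=0))
    return torch.stack(words)  # (Wt x H)

def aggregate_scatter(seq_out, word_mask, attn_mask):
    # New approach: map every token to a global word id, then one scatter_add_ per batch.
    B, L, H = seq_out.shape
    word_idx = torch.full((B, L), -1, dtype=torch.long)
    for b in range(B):
        valid = attn_mask[b].bool()
        wm = word_mask[b, valid]
        starts = wm.nonzero(as_tuple=True)[0]
        idx = torch.full((len(wm),), -1, dtype=torch.long)
        for i, s in enumerate(starts):
            e = starts[i + 1] if i + 1 < len(starts) else len(wm)
            idx[s:e] = i
        word_idx[b, valid] = idx
    nwords = word_mask.sum(dim=-1)  # words per sequence, e.g. [3, 2]
    offsets = torch.cat([torch.zeros(1, dtype=torch.long), nwords.cumsum(0)[:-1]])
    flat_idx = (word_idx + offsets.unsqueeze(1)).view(-1)
    keep = attn_mask.view(-1).bool() & (flat_idx >= 0)
    total = int(nwords.sum())
    sums = torch.zeros(total, H).scatter_add_(
        0, flat_idx[keep].unsqueeze(1).expand(-1, H), seq_out.view(-1, H)[keep]
    )
    counts = torch.zeros(total).scatter_add_(
        0, flat_idx[keep], torch.ones_like(flat_idx[keep], dtype=torch.float)
    )
    return sums / counts.clamp(min=1.0).unsqueeze(1)  # (Wt x H)

torch.manual_seed(0)
seq_out = torch.randn(2, 6, 4)                                  # (B x L x H)
attn = torch.tensor([[1, 1, 1, 1, 1, 0], [1, 1, 1, 0, 0, 0]])   # padding mask
wmask = torch.tensor([[1, 0, 1, 1, 0, 0], [1, 1, 0, 0, 0, 0]])  # word-start mask
assert torch.allclose(aggregate_loop(seq_out, wmask, attn),
                      aggregate_scatter(seq_out, wmask, attn), atol=1e-5)
print("loop-based and scatter-based aggregation agree")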