haukurpj committed on
Commit
bca5d58
·
1 Parent(s): a2c9d48

more vectorization of logits to labels

Browse files
Files changed (1) hide show
  1. modeling.py +122 -54
modeling.py CHANGED
@@ -96,24 +96,62 @@ class IceBertPosForTokenClassification(PreTrainedModel):
96
  schema = self.config.label_schema
97
 
98
  # Create tensors as regular attributes (not buffers to avoid init warnings)
99
- self.group_mask = schema.get_group_masks()
100
- self.group_name_to_group_attr_indices = schema.get_group_name_to_group_attr_indices()
101
-
 
 
102
  # Category name to index mapping (regular dict, no device movement needed)
103
  self.category_name_to_index = schema.get_category_name_to_index()
104
-
105
def _apply(self, fn):  # type: ignore
    """Apply *fn* (a device/dtype move) to the module and to our custom tensors.

    Overridden so that ``model.to(device)`` / ``.cuda()`` also carry along the
    schema-derived tensors that are stored as plain attributes rather than
    registered buffers.
    """
    super()._apply(fn)

    # Guarded with hasattr — presumably _apply can run before these
    # attributes are assigned in __init__; verify against the constructor.
    if hasattr(self, "group_mask"):
        self.group_mask = fn(self.group_mask)

    if hasattr(self, "group_name_to_group_attr_indices"):
        mapping = self.group_name_to_group_attr_indices
        # Re-binding values for existing keys during iteration is safe;
        # iterate a snapshot of the keys anyway for clarity.
        for key in list(mapping):
            mapping[key] = fn(mapping[key])

    return self
118
 
119
  def forward(
@@ -376,70 +414,100 @@ class IceBertPosForTokenClassification(PreTrainedModel):
376
  self, cat_logits: torch.Tensor, attr_logits: torch.Tensor, word_mask: torch.Tensor
377
  ) -> List[List[Tuple[str, List[str]]]]:
378
  """
379
- Convert logits to human-readable labels using schema-based logic.
380
-
381
- B = batch_size, W = max_words, C = num_categories, A = num_attributes, L = seq_len
382
-
 
 
 
 
 
 
383
  Args:
384
  cat_logits: Category logits (B x W x C)
385
  attr_logits: Attribute logits (B x W x A)
386
  word_mask: Binary mask for valid words (B x L)
387
-
388
  Returns:
389
  predictions: List of [(category, [attributes])] for each sequence
390
  """
391
- bsz, _, num_cats = cat_logits.shape
392
- _, _, num_attrs = attr_logits.shape
393
  nwords = word_mask.sum(-1) # (B,)
394
-
395
- assert num_attrs == len(self.config.label_schema.labels)
396
- assert num_cats == len(self.config.label_schema.label_categories)
397
-
398
- predictions = []
399
  schema = self.config.label_schema
400
 
 
 
 
 
 
 
 
401
  for seq_idx in range(bsz):
402
- seq_nwords = nwords[seq_idx]
403
- # (W x C) -> (seq_nwords,)
404
- pred_cat_indices = cat_logits[seq_idx, :seq_nwords].max(dim=-1).indices
405
-
 
 
 
 
 
 
 
 
 
406
  seq_predictions = []
407
  for word_idx in range(seq_nwords):
408
- cat_idx = pred_cat_indices[word_idx].item()
409
  cat_name = schema.label_categories[cat_idx]
410
-
411
- # Get valid groups for this category
412
- valid_groups = schema.category_to_group_names.get(cat_name, [])
413
-
414
- # Collect attributes for this word
 
415
  attributes = []
416
- for group_name in valid_groups:
417
- if group_name in self.group_name_to_group_attr_indices:
418
- group_indices = self.group_name_to_group_attr_indices[group_name] # (Gs,)
419
- if len(group_indices) > 0:
420
- # (A,) -> (Gs,)
421
- group_logits = attr_logits[seq_idx, word_idx, group_indices]
422
- if len(group_indices) == 1:
423
- # Binary decision for single-item groups
424
- if group_logits.sigmoid().item() > 0.5:
425
- attr_idx = group_indices[0].item()
426
- attributes.append(schema.labels[attr_idx])
427
- else:
428
- # Multi-class decision for multi-item groups
429
- best_idx = group_logits.max(dim=-1).indices.item()
430
- attr_idx = group_indices[best_idx].item()
431
- attributes.append(schema.labels[attr_idx])
432
-
433
- # Apply specific rules from original model
 
 
 
 
 
 
 
 
 
 
 
 
434
  if len(attributes) == 1 and attributes[0] == "pos":
435
  # This label is used as a default for training but implied in mim format
436
  attributes = []
437
  elif cat_name == "sl" and "act" in attributes:
438
  # Number and tense are not shown for sl act in mim format
439
  attributes = [attr for attr in attributes if attr not in ["1", "sing", "pres"]]
440
-
441
  seq_predictions.append((cat_name, attributes))
442
-
443
  predictions.append(seq_predictions)
444
 
445
  return predictions
 
96
  schema = self.config.label_schema
97
 
98
  # Create tensors as regular attributes (not buffers to avoid init warnings)
99
+ self.group_mask = schema.get_group_masks() # (C x G)
100
+
101
+ # Convert group mappings to tensor format for GPU operations
102
+ self._create_tensor_group_mappings(schema)
103
+
104
  # Category name to index mapping (regular dict, no device movement needed)
105
  self.category_name_to_index = schema.get_category_name_to_index()
106
+
107
+ def _create_tensor_group_mappings(self, schema):
108
+ """
109
+ Create tensor-based group mappings for efficient GPU operations.
110
+
111
+ Converts Python dict-based schema to tensors to avoid CPU-GPU context switching.
112
+ This optimization replaces dict lookups with tensor indexing for better performance.
113
+
114
+ C = num_categories, G = num_groups, A = num_attributes
115
+ """
116
+ num_groups = len(schema.group_names)
117
+ num_labels = len(schema.labels)
118
+ device = torch.device("cpu") # Will be moved with model
119
+
120
+ # Create group attribute indices tensor: (G x max_group_size)
121
+ # Instead of dict lookups, we can index directly: group_attr_indices[group_id, :]
122
+ max_group_size = max(len(labels) for labels in schema.group_name_to_labels.values())
123
+ self.group_attr_indices = torch.full((num_groups, max_group_size), -1, dtype=torch.long, device=device)
124
+ self.group_sizes = torch.zeros(num_groups, dtype=torch.long, device=device) # (G,)
125
+
126
+ for group_idx, group_name in enumerate(schema.group_names):
127
+ group_labels = schema.group_name_to_labels[group_name]
128
+ group_size = len(group_labels)
129
+ self.group_sizes[group_idx] = group_size
130
+
131
+ for label_idx, label in enumerate(group_labels):
132
+ if label in schema.labels:
133
+ attr_idx = schema.labels.index(label)
134
+ self.group_attr_indices[group_idx, label_idx] = attr_idx
135
+
136
+ # Create category to groups mapping: (C x G) - which groups are valid for each category
137
+ # Replaces dict-based category_to_group_names with tensor indexing
138
+ # Usage: category_to_groups[cat_idx, :] gives valid groups for category cat_idx
139
+ self.category_to_groups = self.group_mask.clone() # (C x G)
140
+
141
def _apply(self, fn):
    """Apply *fn* (a device/dtype move) to the module and to our custom tensors.

    Overridden so that ``model.to(device)`` also carries along the
    schema-derived tensors that are plain attributes, not registered buffers.
    """
    super()._apply(fn)

    # Each attribute may be absent — presumably _apply can fire before
    # __init__ assigns them; verify against the constructor.
    custom_tensor_attrs = (
        "group_mask",
        "group_attr_indices",
        "group_sizes",
        "category_to_groups",
    )
    for name in custom_tensor_attrs:
        if hasattr(self, name):
            setattr(self, name, fn(getattr(self, name)))

    return self
156
 
157
  def forward(
 
414
  self, cat_logits: torch.Tensor, attr_logits: torch.Tensor, word_mask: torch.Tensor
415
  ) -> List[List[Tuple[str, List[str]]]]:
416
  """
417
+ Convert logits to human-readable labels using vectorized operations.
418
+
419
+ Key optimizations:
420
+ 1. Tensor-based schema lookups instead of Python dict access
421
+ 2. Vectorized argmax for categories across entire batch
422
+ 3. Reduced CPU-GPU context switching by batching operations
423
+ 4. Pre-computed tensor mappings for group/attribute relationships
424
+
425
+ B = batch_size, W = max_words, C = num_categories, A = num_attributes, G = num_groups
426
+
427
  Args:
428
  cat_logits: Category logits (B x W x C)
429
  attr_logits: Attribute logits (B x W x A)
430
  word_mask: Binary mask for valid words (B x L)
431
+
432
  Returns:
433
  predictions: List of [(category, [attributes])] for each sequence
434
  """
435
+ device = cat_logits.device
436
+ bsz, max_words = cat_logits.shape[:2]
437
  nwords = word_mask.sum(-1) # (B,)
 
 
 
 
 
438
  schema = self.config.label_schema
439
 
440
+ # Vectorized category prediction: (B x W)
441
+ # Single GPU operation instead of nested loops
442
+ pred_cat_indices = cat_logits.argmax(dim=-1) # (B x W)
443
+
444
+ # Vectorized attribute prediction for all groups
445
+ predictions = []
446
+
447
  for seq_idx in range(bsz):
448
+ seq_nwords = int(nwords[seq_idx].item())
449
+ if seq_nwords == 0:
450
+ predictions.append([])
451
+ continue
452
+
453
+ # Get categories for this sequence: (seq_nwords,)
454
+ seq_cat_indices = pred_cat_indices[seq_idx, :seq_nwords]
455
+
456
+ # Get valid groups for each category: (seq_nwords x G)
457
+ # Tensor lookup replaces dict access: category_to_group_names[cat_name]
458
+ seq_valid_groups = self.category_to_groups[seq_cat_indices] # (seq_nwords x G)
459
+
460
+ # Process attributes for all words in sequence
461
  seq_predictions = []
462
  for word_idx in range(seq_nwords):
463
+ cat_idx = seq_cat_indices[word_idx].item()
464
  cat_name = schema.label_categories[cat_idx]
465
+
466
+ # Get valid groups for this word
467
+ word_valid_groups = seq_valid_groups[word_idx] # (G,)
468
+ valid_group_indices = word_valid_groups.nonzero(as_tuple=True)[0] # (valid_groups,)
469
+
470
+ # Collect attributes using vectorized operations
471
  attributes = []
472
+ for group_idx in valid_group_indices:
473
+ group_idx = group_idx.item()
474
+ group_size = self.group_sizes[group_idx].item()
475
+
476
+ if group_size == 0:
477
+ continue
478
+
479
+ # Get attribute indices for this group: (group_size,)
480
+ # Tensor lookup replaces dict: group_name_to_group_attr_indices[group_name]
481
+ group_attr_indices = self.group_attr_indices[group_idx, :group_size]
482
+ valid_indices = group_attr_indices[group_attr_indices >= 0]
483
+
484
+ if len(valid_indices) == 0:
485
+ continue
486
+
487
+ # Get logits for this group: (group_size,)
488
+ group_logits = attr_logits[seq_idx, word_idx, valid_indices]
489
+
490
+ if len(valid_indices) == 1:
491
+ # Binary decision
492
+ if group_logits.sigmoid().item() > 0.5:
493
+ attr_idx = valid_indices[0].item()
494
+ attributes.append(schema.labels[attr_idx])
495
+ else:
496
+ # Multi-class decision
497
+ best_local_idx = group_logits.argmax().item()
498
+ attr_idx = valid_indices[best_local_idx].item()
499
+ attributes.append(schema.labels[attr_idx])
500
+
501
+ # Apply post-processing rules
502
  if len(attributes) == 1 and attributes[0] == "pos":
503
  # This label is used as a default for training but implied in mim format
504
  attributes = []
505
  elif cat_name == "sl" and "act" in attributes:
506
  # Number and tense are not shown for sl act in mim format
507
  attributes = [attr for attr in attributes if attr not in ["1", "sing", "pres"]]
508
+
509
  seq_predictions.append((cat_name, attributes))
510
+
511
  predictions.append(seq_predictions)
512
 
513
  return predictions