haukurpj committed on
Commit
1114b8e
·
1 Parent(s): bca5d58

more vectorization

Browse files
Files changed (1) hide show
  1. modeling.py +100 -81
modeling.py CHANGED
@@ -97,61 +97,60 @@ class IceBertPosForTokenClassification(PreTrainedModel):
97
 
98
  # Create tensors as regular attributes (not buffers to avoid init warnings)
99
  self.group_mask = schema.get_group_masks() # (C x G)
100
-
101
  # Convert group mappings to tensor format for GPU operations
102
  self._create_tensor_group_mappings(schema)
103
-
104
  # Category name to index mapping (regular dict, no device movement needed)
105
  self.category_name_to_index = schema.get_category_name_to_index()
106
-
107
  def _create_tensor_group_mappings(self, schema):
108
  """
109
  Create tensor-based group mappings for efficient GPU operations.
110
-
111
  Converts Python dict-based schema to tensors to avoid CPU-GPU context switching.
112
  This optimization replaces dict lookups with tensor indexing for better performance.
113
-
114
  C = num_categories, G = num_groups, A = num_attributes
115
  """
116
  num_groups = len(schema.group_names)
117
- num_labels = len(schema.labels)
118
  device = torch.device("cpu") # Will be moved with model
119
-
120
  # Create group attribute indices tensor: (G x max_group_size)
121
  # Instead of dict lookups, we can index directly: group_attr_indices[group_id, :]
122
  max_group_size = max(len(labels) for labels in schema.group_name_to_labels.values())
123
  self.group_attr_indices = torch.full((num_groups, max_group_size), -1, dtype=torch.long, device=device)
124
  self.group_sizes = torch.zeros(num_groups, dtype=torch.long, device=device) # (G,)
125
-
126
  for group_idx, group_name in enumerate(schema.group_names):
127
  group_labels = schema.group_name_to_labels[group_name]
128
  group_size = len(group_labels)
129
  self.group_sizes[group_idx] = group_size
130
-
131
  for label_idx, label in enumerate(group_labels):
132
  if label in schema.labels:
133
  attr_idx = schema.labels.index(label)
134
  self.group_attr_indices[group_idx, label_idx] = attr_idx
135
-
136
  # Create category to groups mapping: (C x G) - which groups are valid for each category
137
  # Replaces dict-based category_to_group_names with tensor indexing
138
  # Usage: category_to_groups[cat_idx, :] gives valid groups for category cat_idx
139
  self.category_to_groups = self.group_mask.clone() # (C x G)
140
 
141
- def _apply(self, fn):
142
  """Override _apply to move our custom tensors with the model."""
143
  super()._apply(fn)
144
-
145
  # Move our custom tensors when model.to(device) is called
146
- if hasattr(self, 'group_mask'):
147
  self.group_mask = fn(self.group_mask)
148
- if hasattr(self, 'group_attr_indices'):
149
  self.group_attr_indices = fn(self.group_attr_indices)
150
- if hasattr(self, 'group_sizes'):
151
  self.group_sizes = fn(self.group_sizes)
152
- if hasattr(self, 'category_to_groups'):
153
  self.category_to_groups = fn(self.category_to_groups)
154
-
155
  return self
156
 
157
  def forward(
@@ -285,7 +284,7 @@ class IceBertPosForTokenClassification(PreTrainedModel):
285
  valid_output = flat_output[valid_word_tokens] # (valid_word_tokens x H)
286
  valid_word_indices = flat_word_indices[valid_word_tokens] # (valid_word_tokens,)
287
 
288
- total_words = max_words_per_seq.sum().item()
289
  if total_words == 0:
290
  return torch.empty(0, hidden_size, device=device)
291
 
@@ -372,7 +371,7 @@ class IceBertPosForTokenClassification(PreTrainedModel):
372
  prev_word_id = word_id
373
 
374
  # Debug logging to match fairseq model
375
- logger.debug(f"Word mask: {word_mask[batch_idx].tolist()}")
376
 
377
  return word_mask
378
 
@@ -417,86 +416,105 @@ class IceBertPosForTokenClassification(PreTrainedModel):
417
  Convert logits to human-readable labels using vectorized operations.
418
 
419
  Key optimizations:
420
- 1. Tensor-based schema lookups instead of Python dict access
421
- 2. Vectorized argmax for categories across entire batch
422
- 3. Reduced CPU-GPU context switching by batching operations
423
- 4. Pre-computed tensor mappings for group/attribute relationships
424
 
425
  B = batch_size, W = max_words, C = num_categories, A = num_attributes, G = num_groups
426
-
427
- Args:
428
- cat_logits: Category logits (B x W x C)
429
- attr_logits: Attribute logits (B x W x A)
430
- word_mask: Binary mask for valid words (B x L)
431
-
432
- Returns:
433
- predictions: List of [(category, [attributes])] for each sequence
434
  """
435
  device = cat_logits.device
436
  bsz, max_words = cat_logits.shape[:2]
437
  nwords = word_mask.sum(-1) # (B,)
438
  schema = self.config.label_schema
439
 
440
- # Vectorized category prediction: (B x W)
441
- # Single GPU operation instead of nested loops
442
- pred_cat_indices = cat_logits.argmax(dim=-1) # (B x W)
 
 
 
443
 
444
- # Vectorized attribute prediction for all groups
445
- predictions = []
446
 
447
- for seq_idx in range(bsz):
448
- seq_nwords = int(nwords[seq_idx].item())
449
- if seq_nwords == 0:
450
- predictions.append([])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
451
  continue
452
 
453
- # Get categories for this sequence: (seq_nwords,)
454
- seq_cat_indices = pred_cat_indices[seq_idx, :seq_nwords]
 
 
 
455
 
456
- # Get valid groups for each category: (seq_nwords x G)
457
- # Tensor lookup replaces dict access: category_to_group_names[cat_name]
458
- seq_valid_groups = self.category_to_groups[seq_cat_indices] # (seq_nwords x G)
459
 
460
- # Process attributes for all words in sequence
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
461
  seq_predictions = []
462
- for word_idx in range(seq_nwords):
463
- cat_idx = seq_cat_indices[word_idx].item()
 
 
464
  cat_name = schema.label_categories[cat_idx]
465
 
466
- # Get valid groups for this word
467
- word_valid_groups = seq_valid_groups[word_idx] # (G,)
468
- valid_group_indices = word_valid_groups.nonzero(as_tuple=True)[0] # (valid_groups,)
469
-
470
- # Collect attributes using vectorized operations
471
  attributes = []
472
- for group_idx in valid_group_indices:
473
- group_idx = group_idx.item()
474
- group_size = self.group_sizes[group_idx].item()
475
-
476
- if group_size == 0:
477
- continue
478
-
479
- # Get attribute indices for this group: (group_size,)
480
- # Tensor lookup replaces dict: group_name_to_group_attr_indices[group_name]
481
- group_attr_indices = self.group_attr_indices[group_idx, :group_size]
482
- valid_indices = group_attr_indices[group_attr_indices >= 0]
483
-
484
- if len(valid_indices) == 0:
485
- continue
486
-
487
- # Get logits for this group: (group_size,)
488
- group_logits = attr_logits[seq_idx, word_idx, valid_indices]
489
-
490
- if len(valid_indices) == 1:
491
- # Binary decision
492
- if group_logits.sigmoid().item() > 0.5:
493
- attr_idx = valid_indices[0].item()
494
- attributes.append(schema.labels[attr_idx])
495
- else:
496
- # Multi-class decision
497
- best_local_idx = group_logits.argmax().item()
498
- attr_idx = valid_indices[best_local_idx].item()
499
- attributes.append(schema.labels[attr_idx])
500
 
501
  # Apply post-processing rules
502
  if len(attributes) == 1 and attributes[0] == "pos":
@@ -507,6 +525,7 @@ class IceBertPosForTokenClassification(PreTrainedModel):
507
  attributes = [attr for attr in attributes if attr not in ["1", "sing", "pres"]]
508
 
509
  seq_predictions.append((cat_name, attributes))
 
510
 
511
  predictions.append(seq_predictions)
512
 
 
97
 
98
  # Create tensors as regular attributes (not buffers to avoid init warnings)
99
  self.group_mask = schema.get_group_masks() # (C x G)
100
+
101
  # Convert group mappings to tensor format for GPU operations
102
  self._create_tensor_group_mappings(schema)
103
+
104
  # Category name to index mapping (regular dict, no device movement needed)
105
  self.category_name_to_index = schema.get_category_name_to_index()
106
+
107
  def _create_tensor_group_mappings(self, schema):
108
  """
109
  Create tensor-based group mappings for efficient GPU operations.
110
+
111
  Converts Python dict-based schema to tensors to avoid CPU-GPU context switching.
112
  This optimization replaces dict lookups with tensor indexing for better performance.
113
+
114
  C = num_categories, G = num_groups, A = num_attributes
115
  """
116
  num_groups = len(schema.group_names)
 
117
  device = torch.device("cpu") # Will be moved with model
118
+
119
  # Create group attribute indices tensor: (G x max_group_size)
120
  # Instead of dict lookups, we can index directly: group_attr_indices[group_id, :]
121
  max_group_size = max(len(labels) for labels in schema.group_name_to_labels.values())
122
  self.group_attr_indices = torch.full((num_groups, max_group_size), -1, dtype=torch.long, device=device)
123
  self.group_sizes = torch.zeros(num_groups, dtype=torch.long, device=device) # (G,)
124
+
125
  for group_idx, group_name in enumerate(schema.group_names):
126
  group_labels = schema.group_name_to_labels[group_name]
127
  group_size = len(group_labels)
128
  self.group_sizes[group_idx] = group_size
129
+
130
  for label_idx, label in enumerate(group_labels):
131
  if label in schema.labels:
132
  attr_idx = schema.labels.index(label)
133
  self.group_attr_indices[group_idx, label_idx] = attr_idx
134
+
135
  # Create category to groups mapping: (C x G) - which groups are valid for each category
136
  # Replaces dict-based category_to_group_names with tensor indexing
137
  # Usage: category_to_groups[cat_idx, :] gives valid groups for category cat_idx
138
  self.category_to_groups = self.group_mask.clone() # (C x G)
139
 
140
+ def _apply(self, fn): # type: ignore[override]
141
  """Override _apply to move our custom tensors with the model."""
142
  super()._apply(fn)
143
+
144
  # Move our custom tensors when model.to(device) is called
145
+ if hasattr(self, "group_mask"):
146
  self.group_mask = fn(self.group_mask)
147
+ if hasattr(self, "group_attr_indices"):
148
  self.group_attr_indices = fn(self.group_attr_indices)
149
+ if hasattr(self, "group_sizes"):
150
  self.group_sizes = fn(self.group_sizes)
151
+ if hasattr(self, "category_to_groups"):
152
  self.category_to_groups = fn(self.category_to_groups)
153
+
154
  return self
155
 
156
  def forward(
 
284
  valid_output = flat_output[valid_word_tokens] # (valid_word_tokens x H)
285
  valid_word_indices = flat_word_indices[valid_word_tokens] # (valid_word_tokens,)
286
 
287
+ total_words = max_words_per_seq.sum()
288
  if total_words == 0:
289
  return torch.empty(0, hidden_size, device=device)
290
 
 
371
  prev_word_id = word_id
372
 
373
  # Debug logging to match fairseq model
374
+ logger.debug(f"Word mask: {word_mask[batch_idx]}")
375
 
376
  return word_mask
377
 
 
416
  Convert logits to human-readable labels using vectorized operations.
417
 
418
  Key optimizations:
419
+ 1. Flatten batch dimension to process all words simultaneously
420
+ 2. Vectorized group processing across all words
421
+ 3. Defer string conversion to the very end
422
+ 4. Minimize Python loops and tensor-CPU transfers
423
 
424
  B = batch_size, W = max_words, C = num_categories, A = num_attributes, G = num_groups
 
 
 
 
 
 
 
 
425
  """
426
  device = cat_logits.device
427
  bsz, max_words = cat_logits.shape[:2]
428
  nwords = word_mask.sum(-1) # (B,)
429
  schema = self.config.label_schema
430
 
431
+ # Step 1: Create valid word mask and flatten batch dimension
432
+ # (B x W) -> (total_words,) to process all words simultaneously
433
+ batch_word_mask = torch.zeros(bsz, max_words, dtype=torch.bool, device=device)
434
+ for b in range(bsz):
435
+ if nwords[b] > 0:
436
+ batch_word_mask[b, :nwords[b]] = True
437
 
438
+ valid_positions = batch_word_mask.flatten().nonzero(as_tuple=True)[0] # (total_words,)
439
+ total_words = len(valid_positions)
440
 
441
+ if total_words == 0:
442
+ return [[] for _ in range(bsz)]
443
+
444
+ # Step 2: Vectorized category prediction for all valid words
445
+ flat_cat_logits = cat_logits.view(-1, cat_logits.size(-1)) # (B*W x C)
446
+ flat_attr_logits = attr_logits.view(-1, attr_logits.size(-1)) # (B*W x A)
447
+
448
+ # Get categories for all valid words: (total_words,)
449
+ all_cat_indices = flat_cat_logits[valid_positions].argmax(dim=-1)
450
+
451
+ # Step 3: Vectorized group validity for all words: (total_words x G)
452
+ all_valid_groups = self.category_to_groups[all_cat_indices]
453
+
454
+ # Step 4: Collect attributes using vectorized group processing
455
+ word_to_attrs = {} # word_idx -> list of attr_indices
456
+
457
+ # Process each group across all words simultaneously
458
+ for group_idx in range(self.group_sizes.size(0)):
459
+ group_size = self.group_sizes[group_idx].item()
460
+ if group_size == 0:
461
+ continue
462
+
463
+ # Find words that have this group valid: (words_with_group,)
464
+ words_with_group = all_valid_groups[:, group_idx].nonzero(as_tuple=True)[0]
465
+ if len(words_with_group) == 0:
466
  continue
467
 
468
+ # Get attribute indices for this group
469
+ group_attr_indices = self.group_attr_indices[group_idx, :group_size]
470
+ valid_attr_indices = group_attr_indices[group_attr_indices >= 0]
471
+ if len(valid_attr_indices) == 0:
472
+ continue
473
 
474
+ # Get logits for all words that need this group: (words_with_group x group_size)
475
+ word_positions = valid_positions[words_with_group]
476
+ group_logits = flat_attr_logits[word_positions][:, valid_attr_indices]
477
 
478
+ if len(valid_attr_indices) == 1:
479
+ # Binary decision for all words simultaneously: (words_with_group,)
480
+ decisions = group_logits.sigmoid().squeeze(-1) > 0.5
481
+ selected_words = words_with_group[decisions]
482
+ attr_idx = valid_attr_indices[0].item()
483
+
484
+ for word_idx in selected_words:
485
+ word_idx_item = word_idx.item()
486
+ if word_idx_item not in word_to_attrs:
487
+ word_to_attrs[word_idx_item] = []
488
+ word_to_attrs[word_idx_item].append(attr_idx)
489
+ else:
490
+ # Multi-class decision for all words: (words_with_group,)
491
+ best_indices = group_logits.argmax(dim=-1)
492
+
493
+ for i, word_idx in enumerate(words_with_group):
494
+ attr_idx = valid_attr_indices[best_indices[i]].item()
495
+ word_idx_item = word_idx.item()
496
+ if word_idx_item not in word_to_attrs:
497
+ word_to_attrs[word_idx_item] = []
498
+ word_to_attrs[word_idx_item].append(attr_idx)
499
+
500
+ # Step 5: Reconstruct batch structure and convert to strings (deferred)
501
+ predictions = []
502
+ word_counter = 0
503
+
504
+ for seq_idx in range(bsz):
505
+ seq_nwords = nwords[seq_idx].item()
506
  seq_predictions = []
507
+
508
+ for _ in range(seq_nwords):
509
+ # Get category (string conversion deferred)
510
+ cat_idx = all_cat_indices[word_counter].item()
511
  cat_name = schema.label_categories[cat_idx]
512
 
513
+ # Get attributes (string conversion deferred)
 
 
 
 
514
  attributes = []
515
+ if word_counter in word_to_attrs:
516
+ attr_indices = word_to_attrs[word_counter]
517
+ attributes = [schema.labels[idx] for idx in attr_indices]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
518
 
519
  # Apply post-processing rules
520
  if len(attributes) == 1 and attributes[0] == "pos":
 
525
  attributes = [attr for attr in attributes if attr not in ["1", "sing", "pres"]]
526
 
527
  seq_predictions.append((cat_name, attributes))
528
+ word_counter += 1
529
 
530
  predictions.append(seq_predictions)
531