more vectorization
Browse files- modeling.py +100 -81
modeling.py
CHANGED
|
@@ -97,61 +97,60 @@ class IceBertPosForTokenClassification(PreTrainedModel):
|
|
| 97 |
|
| 98 |
# Create tensors as regular attributes (not buffers to avoid init warnings)
|
| 99 |
self.group_mask = schema.get_group_masks() # (C x G)
|
| 100 |
-
|
| 101 |
# Convert group mappings to tensor format for GPU operations
|
| 102 |
self._create_tensor_group_mappings(schema)
|
| 103 |
-
|
| 104 |
# Category name to index mapping (regular dict, no device movement needed)
|
| 105 |
self.category_name_to_index = schema.get_category_name_to_index()
|
| 106 |
-
|
| 107 |
def _create_tensor_group_mappings(self, schema):
|
| 108 |
"""
|
| 109 |
Create tensor-based group mappings for efficient GPU operations.
|
| 110 |
-
|
| 111 |
Converts Python dict-based schema to tensors to avoid CPU-GPU context switching.
|
| 112 |
This optimization replaces dict lookups with tensor indexing for better performance.
|
| 113 |
-
|
| 114 |
C = num_categories, G = num_groups, A = num_attributes
|
| 115 |
"""
|
| 116 |
num_groups = len(schema.group_names)
|
| 117 |
-
num_labels = len(schema.labels)
|
| 118 |
device = torch.device("cpu") # Will be moved with model
|
| 119 |
-
|
| 120 |
# Create group attribute indices tensor: (G x max_group_size)
|
| 121 |
# Instead of dict lookups, we can index directly: group_attr_indices[group_id, :]
|
| 122 |
max_group_size = max(len(labels) for labels in schema.group_name_to_labels.values())
|
| 123 |
self.group_attr_indices = torch.full((num_groups, max_group_size), -1, dtype=torch.long, device=device)
|
| 124 |
self.group_sizes = torch.zeros(num_groups, dtype=torch.long, device=device) # (G,)
|
| 125 |
-
|
| 126 |
for group_idx, group_name in enumerate(schema.group_names):
|
| 127 |
group_labels = schema.group_name_to_labels[group_name]
|
| 128 |
group_size = len(group_labels)
|
| 129 |
self.group_sizes[group_idx] = group_size
|
| 130 |
-
|
| 131 |
for label_idx, label in enumerate(group_labels):
|
| 132 |
if label in schema.labels:
|
| 133 |
attr_idx = schema.labels.index(label)
|
| 134 |
self.group_attr_indices[group_idx, label_idx] = attr_idx
|
| 135 |
-
|
| 136 |
# Create category to groups mapping: (C x G) - which groups are valid for each category
|
| 137 |
# Replaces dict-based category_to_group_names with tensor indexing
|
| 138 |
# Usage: category_to_groups[cat_idx, :] gives valid groups for category cat_idx
|
| 139 |
self.category_to_groups = self.group_mask.clone() # (C x G)
|
| 140 |
|
| 141 |
-
def _apply(self, fn):
|
| 142 |
"""Override _apply to move our custom tensors with the model."""
|
| 143 |
super()._apply(fn)
|
| 144 |
-
|
| 145 |
# Move our custom tensors when model.to(device) is called
|
| 146 |
-
if hasattr(self,
|
| 147 |
self.group_mask = fn(self.group_mask)
|
| 148 |
-
if hasattr(self,
|
| 149 |
self.group_attr_indices = fn(self.group_attr_indices)
|
| 150 |
-
if hasattr(self,
|
| 151 |
self.group_sizes = fn(self.group_sizes)
|
| 152 |
-
if hasattr(self,
|
| 153 |
self.category_to_groups = fn(self.category_to_groups)
|
| 154 |
-
|
| 155 |
return self
|
| 156 |
|
| 157 |
def forward(
|
|
@@ -285,7 +284,7 @@ class IceBertPosForTokenClassification(PreTrainedModel):
|
|
| 285 |
valid_output = flat_output[valid_word_tokens] # (valid_word_tokens x H)
|
| 286 |
valid_word_indices = flat_word_indices[valid_word_tokens] # (valid_word_tokens,)
|
| 287 |
|
| 288 |
-
total_words = max_words_per_seq.sum()
|
| 289 |
if total_words == 0:
|
| 290 |
return torch.empty(0, hidden_size, device=device)
|
| 291 |
|
|
@@ -372,7 +371,7 @@ class IceBertPosForTokenClassification(PreTrainedModel):
|
|
| 372 |
prev_word_id = word_id
|
| 373 |
|
| 374 |
# Debug logging to match fairseq model
|
| 375 |
-
logger.debug(f"Word mask: {word_mask[batch_idx]
|
| 376 |
|
| 377 |
return word_mask
|
| 378 |
|
|
@@ -417,86 +416,105 @@ class IceBertPosForTokenClassification(PreTrainedModel):
|
|
| 417 |
Convert logits to human-readable labels using vectorized operations.
|
| 418 |
|
| 419 |
Key optimizations:
|
| 420 |
-
1.
|
| 421 |
-
2. Vectorized
|
| 422 |
-
3.
|
| 423 |
-
4.
|
| 424 |
|
| 425 |
B = batch_size, W = max_words, C = num_categories, A = num_attributes, G = num_groups
|
| 426 |
-
|
| 427 |
-
Args:
|
| 428 |
-
cat_logits: Category logits (B x W x C)
|
| 429 |
-
attr_logits: Attribute logits (B x W x A)
|
| 430 |
-
word_mask: Binary mask for valid words (B x L)
|
| 431 |
-
|
| 432 |
-
Returns:
|
| 433 |
-
predictions: List of [(category, [attributes])] for each sequence
|
| 434 |
"""
|
| 435 |
device = cat_logits.device
|
| 436 |
bsz, max_words = cat_logits.shape[:2]
|
| 437 |
nwords = word_mask.sum(-1) # (B,)
|
| 438 |
schema = self.config.label_schema
|
| 439 |
|
| 440 |
-
#
|
| 441 |
-
#
|
| 442 |
-
|
|
|
|
|
|
|
|
|
|
| 443 |
|
| 444 |
-
|
| 445 |
-
|
| 446 |
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 451 |
continue
|
| 452 |
|
| 453 |
-
# Get
|
| 454 |
-
|
|
|
|
|
|
|
|
|
|
| 455 |
|
| 456 |
-
# Get
|
| 457 |
-
|
| 458 |
-
|
| 459 |
|
| 460 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 461 |
seq_predictions = []
|
| 462 |
-
|
| 463 |
-
|
|
|
|
|
|
|
| 464 |
cat_name = schema.label_categories[cat_idx]
|
| 465 |
|
| 466 |
-
# Get
|
| 467 |
-
word_valid_groups = seq_valid_groups[word_idx] # (G,)
|
| 468 |
-
valid_group_indices = word_valid_groups.nonzero(as_tuple=True)[0] # (valid_groups,)
|
| 469 |
-
|
| 470 |
-
# Collect attributes using vectorized operations
|
| 471 |
attributes = []
|
| 472 |
-
|
| 473 |
-
|
| 474 |
-
|
| 475 |
-
|
| 476 |
-
if group_size == 0:
|
| 477 |
-
continue
|
| 478 |
-
|
| 479 |
-
# Get attribute indices for this group: (group_size,)
|
| 480 |
-
# Tensor lookup replaces dict: group_name_to_group_attr_indices[group_name]
|
| 481 |
-
group_attr_indices = self.group_attr_indices[group_idx, :group_size]
|
| 482 |
-
valid_indices = group_attr_indices[group_attr_indices >= 0]
|
| 483 |
-
|
| 484 |
-
if len(valid_indices) == 0:
|
| 485 |
-
continue
|
| 486 |
-
|
| 487 |
-
# Get logits for this group: (group_size,)
|
| 488 |
-
group_logits = attr_logits[seq_idx, word_idx, valid_indices]
|
| 489 |
-
|
| 490 |
-
if len(valid_indices) == 1:
|
| 491 |
-
# Binary decision
|
| 492 |
-
if group_logits.sigmoid().item() > 0.5:
|
| 493 |
-
attr_idx = valid_indices[0].item()
|
| 494 |
-
attributes.append(schema.labels[attr_idx])
|
| 495 |
-
else:
|
| 496 |
-
# Multi-class decision
|
| 497 |
-
best_local_idx = group_logits.argmax().item()
|
| 498 |
-
attr_idx = valid_indices[best_local_idx].item()
|
| 499 |
-
attributes.append(schema.labels[attr_idx])
|
| 500 |
|
| 501 |
# Apply post-processing rules
|
| 502 |
if len(attributes) == 1 and attributes[0] == "pos":
|
|
@@ -507,6 +525,7 @@ class IceBertPosForTokenClassification(PreTrainedModel):
|
|
| 507 |
attributes = [attr for attr in attributes if attr not in ["1", "sing", "pres"]]
|
| 508 |
|
| 509 |
seq_predictions.append((cat_name, attributes))
|
|
|
|
| 510 |
|
| 511 |
predictions.append(seq_predictions)
|
| 512 |
|
|
|
|
| 97 |
|
| 98 |
# Create tensors as regular attributes (not buffers to avoid init warnings)
|
| 99 |
self.group_mask = schema.get_group_masks() # (C x G)
|
| 100 |
+
|
| 101 |
# Convert group mappings to tensor format for GPU operations
|
| 102 |
self._create_tensor_group_mappings(schema)
|
| 103 |
+
|
| 104 |
# Category name to index mapping (regular dict, no device movement needed)
|
| 105 |
self.category_name_to_index = schema.get_category_name_to_index()
|
| 106 |
+
|
| 107 |
def _create_tensor_group_mappings(self, schema):
|
| 108 |
"""
|
| 109 |
Create tensor-based group mappings for efficient GPU operations.
|
| 110 |
+
|
| 111 |
Converts Python dict-based schema to tensors to avoid CPU-GPU context switching.
|
| 112 |
This optimization replaces dict lookups with tensor indexing for better performance.
|
| 113 |
+
|
| 114 |
C = num_categories, G = num_groups, A = num_attributes
|
| 115 |
"""
|
| 116 |
num_groups = len(schema.group_names)
|
|
|
|
| 117 |
device = torch.device("cpu") # Will be moved with model
|
| 118 |
+
|
| 119 |
# Create group attribute indices tensor: (G x max_group_size)
|
| 120 |
# Instead of dict lookups, we can index directly: group_attr_indices[group_id, :]
|
| 121 |
max_group_size = max(len(labels) for labels in schema.group_name_to_labels.values())
|
| 122 |
self.group_attr_indices = torch.full((num_groups, max_group_size), -1, dtype=torch.long, device=device)
|
| 123 |
self.group_sizes = torch.zeros(num_groups, dtype=torch.long, device=device) # (G,)
|
| 124 |
+
|
| 125 |
for group_idx, group_name in enumerate(schema.group_names):
|
| 126 |
group_labels = schema.group_name_to_labels[group_name]
|
| 127 |
group_size = len(group_labels)
|
| 128 |
self.group_sizes[group_idx] = group_size
|
| 129 |
+
|
| 130 |
for label_idx, label in enumerate(group_labels):
|
| 131 |
if label in schema.labels:
|
| 132 |
attr_idx = schema.labels.index(label)
|
| 133 |
self.group_attr_indices[group_idx, label_idx] = attr_idx
|
| 134 |
+
|
| 135 |
# Create category to groups mapping: (C x G) - which groups are valid for each category
|
| 136 |
# Replaces dict-based category_to_group_names with tensor indexing
|
| 137 |
# Usage: category_to_groups[cat_idx, :] gives valid groups for category cat_idx
|
| 138 |
self.category_to_groups = self.group_mask.clone() # (C x G)
|
| 139 |
|
| 140 |
+
def _apply(self, fn): # type: ignore[override]
|
| 141 |
"""Override _apply to move our custom tensors with the model."""
|
| 142 |
super()._apply(fn)
|
| 143 |
+
|
| 144 |
# Move our custom tensors when model.to(device) is called
|
| 145 |
+
if hasattr(self, "group_mask"):
|
| 146 |
self.group_mask = fn(self.group_mask)
|
| 147 |
+
if hasattr(self, "group_attr_indices"):
|
| 148 |
self.group_attr_indices = fn(self.group_attr_indices)
|
| 149 |
+
if hasattr(self, "group_sizes"):
|
| 150 |
self.group_sizes = fn(self.group_sizes)
|
| 151 |
+
if hasattr(self, "category_to_groups"):
|
| 152 |
self.category_to_groups = fn(self.category_to_groups)
|
| 153 |
+
|
| 154 |
return self
|
| 155 |
|
| 156 |
def forward(
|
|
|
|
| 284 |
valid_output = flat_output[valid_word_tokens] # (valid_word_tokens x H)
|
| 285 |
valid_word_indices = flat_word_indices[valid_word_tokens] # (valid_word_tokens,)
|
| 286 |
|
| 287 |
+
total_words = max_words_per_seq.sum()
|
| 288 |
if total_words == 0:
|
| 289 |
return torch.empty(0, hidden_size, device=device)
|
| 290 |
|
|
|
|
| 371 |
prev_word_id = word_id
|
| 372 |
|
| 373 |
# Debug logging to match fairseq model
|
| 374 |
+
logger.debug(f"Word mask: {word_mask[batch_idx]}")
|
| 375 |
|
| 376 |
return word_mask
|
| 377 |
|
|
|
|
| 416 |
Convert logits to human-readable labels using vectorized operations.
|
| 417 |
|
| 418 |
Key optimizations:
|
| 419 |
+
1. Flatten batch dimension to process all words simultaneously
|
| 420 |
+
2. Vectorized group processing across all words
|
| 421 |
+
3. Defer string conversion to the very end
|
| 422 |
+
4. Minimize Python loops and tensor-CPU transfers
|
| 423 |
|
| 424 |
B = batch_size, W = max_words, C = num_categories, A = num_attributes, G = num_groups
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 425 |
"""
|
| 426 |
device = cat_logits.device
|
| 427 |
bsz, max_words = cat_logits.shape[:2]
|
| 428 |
nwords = word_mask.sum(-1) # (B,)
|
| 429 |
schema = self.config.label_schema
|
| 430 |
|
| 431 |
+
# Step 1: Create valid word mask and flatten batch dimension
|
| 432 |
+
# (B x W) -> (total_words,) to process all words simultaneously
|
| 433 |
+
batch_word_mask = torch.zeros(bsz, max_words, dtype=torch.bool, device=device)
|
| 434 |
+
for b in range(bsz):
|
| 435 |
+
if nwords[b] > 0:
|
| 436 |
+
batch_word_mask[b, :nwords[b]] = True
|
| 437 |
|
| 438 |
+
valid_positions = batch_word_mask.flatten().nonzero(as_tuple=True)[0] # (total_words,)
|
| 439 |
+
total_words = len(valid_positions)
|
| 440 |
|
| 441 |
+
if total_words == 0:
|
| 442 |
+
return [[] for _ in range(bsz)]
|
| 443 |
+
|
| 444 |
+
# Step 2: Vectorized category prediction for all valid words
|
| 445 |
+
flat_cat_logits = cat_logits.view(-1, cat_logits.size(-1)) # (B*W x C)
|
| 446 |
+
flat_attr_logits = attr_logits.view(-1, attr_logits.size(-1)) # (B*W x A)
|
| 447 |
+
|
| 448 |
+
# Get categories for all valid words: (total_words,)
|
| 449 |
+
all_cat_indices = flat_cat_logits[valid_positions].argmax(dim=-1)
|
| 450 |
+
|
| 451 |
+
# Step 3: Vectorized group validity for all words: (total_words x G)
|
| 452 |
+
all_valid_groups = self.category_to_groups[all_cat_indices]
|
| 453 |
+
|
| 454 |
+
# Step 4: Collect attributes using vectorized group processing
|
| 455 |
+
word_to_attrs = {} # word_idx -> list of attr_indices
|
| 456 |
+
|
| 457 |
+
# Process each group across all words simultaneously
|
| 458 |
+
for group_idx in range(self.group_sizes.size(0)):
|
| 459 |
+
group_size = self.group_sizes[group_idx].item()
|
| 460 |
+
if group_size == 0:
|
| 461 |
+
continue
|
| 462 |
+
|
| 463 |
+
# Find words that have this group valid: (words_with_group,)
|
| 464 |
+
words_with_group = all_valid_groups[:, group_idx].nonzero(as_tuple=True)[0]
|
| 465 |
+
if len(words_with_group) == 0:
|
| 466 |
continue
|
| 467 |
|
| 468 |
+
# Get attribute indices for this group
|
| 469 |
+
group_attr_indices = self.group_attr_indices[group_idx, :group_size]
|
| 470 |
+
valid_attr_indices = group_attr_indices[group_attr_indices >= 0]
|
| 471 |
+
if len(valid_attr_indices) == 0:
|
| 472 |
+
continue
|
| 473 |
|
| 474 |
+
# Get logits for all words that need this group: (words_with_group x group_size)
|
| 475 |
+
word_positions = valid_positions[words_with_group]
|
| 476 |
+
group_logits = flat_attr_logits[word_positions][:, valid_attr_indices]
|
| 477 |
|
| 478 |
+
if len(valid_attr_indices) == 1:
|
| 479 |
+
# Binary decision for all words simultaneously: (words_with_group,)
|
| 480 |
+
decisions = group_logits.sigmoid().squeeze(-1) > 0.5
|
| 481 |
+
selected_words = words_with_group[decisions]
|
| 482 |
+
attr_idx = valid_attr_indices[0].item()
|
| 483 |
+
|
| 484 |
+
for word_idx in selected_words:
|
| 485 |
+
word_idx_item = word_idx.item()
|
| 486 |
+
if word_idx_item not in word_to_attrs:
|
| 487 |
+
word_to_attrs[word_idx_item] = []
|
| 488 |
+
word_to_attrs[word_idx_item].append(attr_idx)
|
| 489 |
+
else:
|
| 490 |
+
# Multi-class decision for all words: (words_with_group,)
|
| 491 |
+
best_indices = group_logits.argmax(dim=-1)
|
| 492 |
+
|
| 493 |
+
for i, word_idx in enumerate(words_with_group):
|
| 494 |
+
attr_idx = valid_attr_indices[best_indices[i]].item()
|
| 495 |
+
word_idx_item = word_idx.item()
|
| 496 |
+
if word_idx_item not in word_to_attrs:
|
| 497 |
+
word_to_attrs[word_idx_item] = []
|
| 498 |
+
word_to_attrs[word_idx_item].append(attr_idx)
|
| 499 |
+
|
| 500 |
+
# Step 5: Reconstruct batch structure and convert to strings (deferred)
|
| 501 |
+
predictions = []
|
| 502 |
+
word_counter = 0
|
| 503 |
+
|
| 504 |
+
for seq_idx in range(bsz):
|
| 505 |
+
seq_nwords = nwords[seq_idx].item()
|
| 506 |
seq_predictions = []
|
| 507 |
+
|
| 508 |
+
for _ in range(seq_nwords):
|
| 509 |
+
# Get category (string conversion deferred)
|
| 510 |
+
cat_idx = all_cat_indices[word_counter].item()
|
| 511 |
cat_name = schema.label_categories[cat_idx]
|
| 512 |
|
| 513 |
+
# Get attributes (string conversion deferred)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 514 |
attributes = []
|
| 515 |
+
if word_counter in word_to_attrs:
|
| 516 |
+
attr_indices = word_to_attrs[word_counter]
|
| 517 |
+
attributes = [schema.labels[idx] for idx in attr_indices]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 518 |
|
| 519 |
# Apply post-processing rules
|
| 520 |
if len(attributes) == 1 and attributes[0] == "pos":
|
|
|
|
| 525 |
attributes = [attr for attr in attributes if attr not in ["1", "sing", "pres"]]
|
| 526 |
|
| 527 |
seq_predictions.append((cat_name, attributes))
|
| 528 |
+
word_counter += 1
|
| 529 |
|
| 530 |
predictions.append(seq_predictions)
|
| 531 |
|