more vectorization of logits to labels
Browse files- modeling.py +122 -54
modeling.py
CHANGED
|
@@ -96,24 +96,62 @@ class IceBertPosForTokenClassification(PreTrainedModel):
|
|
| 96 |
schema = self.config.label_schema
|
| 97 |
|
| 98 |
# Create tensors as regular attributes (not buffers to avoid init warnings)
|
| 99 |
-
self.group_mask = schema.get_group_masks()
|
| 100 |
-
|
| 101 |
-
|
|
|
|
|
|
|
| 102 |
# Category name to index mapping (regular dict, no device movement needed)
|
| 103 |
self.category_name_to_index = schema.get_category_name_to_index()
|
| 104 |
-
|
| 105 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
"""Override _apply to move our custom tensors with the model."""
|
| 107 |
super()._apply(fn)
|
| 108 |
-
|
| 109 |
# Move our custom tensors when model.to(device) is called
|
| 110 |
-
if hasattr(self,
|
| 111 |
self.group_mask = fn(self.group_mask)
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
|
|
|
|
|
|
| 117 |
return self
|
| 118 |
|
| 119 |
def forward(
|
|
@@ -376,70 +414,100 @@ class IceBertPosForTokenClassification(PreTrainedModel):
|
|
| 376 |
self, cat_logits: torch.Tensor, attr_logits: torch.Tensor, word_mask: torch.Tensor
|
| 377 |
) -> List[List[Tuple[str, List[str]]]]:
|
| 378 |
"""
|
| 379 |
-
Convert logits to human-readable labels using
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 383 |
Args:
|
| 384 |
cat_logits: Category logits (B x W x C)
|
| 385 |
attr_logits: Attribute logits (B x W x A)
|
| 386 |
word_mask: Binary mask for valid words (B x L)
|
| 387 |
-
|
| 388 |
Returns:
|
| 389 |
predictions: List of [(category, [attributes])] for each sequence
|
| 390 |
"""
|
| 391 |
-
|
| 392 |
-
|
| 393 |
nwords = word_mask.sum(-1) # (B,)
|
| 394 |
-
|
| 395 |
-
assert num_attrs == len(self.config.label_schema.labels)
|
| 396 |
-
assert num_cats == len(self.config.label_schema.label_categories)
|
| 397 |
-
|
| 398 |
-
predictions = []
|
| 399 |
schema = self.config.label_schema
|
| 400 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 401 |
for seq_idx in range(bsz):
|
| 402 |
-
seq_nwords = nwords[seq_idx]
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 406 |
seq_predictions = []
|
| 407 |
for word_idx in range(seq_nwords):
|
| 408 |
-
cat_idx =
|
| 409 |
cat_name = schema.label_categories[cat_idx]
|
| 410 |
-
|
| 411 |
-
# Get valid groups for this
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
|
|
|
|
| 415 |
attributes = []
|
| 416 |
-
for
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
|
| 428 |
-
|
| 429 |
-
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 434 |
if len(attributes) == 1 and attributes[0] == "pos":
|
| 435 |
# This label is used as a default for training but implied in mim format
|
| 436 |
attributes = []
|
| 437 |
elif cat_name == "sl" and "act" in attributes:
|
| 438 |
# Number and tense are not shown for sl act in mim format
|
| 439 |
attributes = [attr for attr in attributes if attr not in ["1", "sing", "pres"]]
|
| 440 |
-
|
| 441 |
seq_predictions.append((cat_name, attributes))
|
| 442 |
-
|
| 443 |
predictions.append(seq_predictions)
|
| 444 |
|
| 445 |
return predictions
|
|
|
|
| 96 |
schema = self.config.label_schema
|
| 97 |
|
| 98 |
# Create tensors as regular attributes (not buffers to avoid init warnings)
|
| 99 |
+
self.group_mask = schema.get_group_masks() # (C x G)
|
| 100 |
+
|
| 101 |
+
# Convert group mappings to tensor format for GPU operations
|
| 102 |
+
self._create_tensor_group_mappings(schema)
|
| 103 |
+
|
| 104 |
# Category name to index mapping (regular dict, no device movement needed)
|
| 105 |
self.category_name_to_index = schema.get_category_name_to_index()
|
| 106 |
+
|
| 107 |
+
def _create_tensor_group_mappings(self, schema):
|
| 108 |
+
"""
|
| 109 |
+
Create tensor-based group mappings for efficient GPU operations.
|
| 110 |
+
|
| 111 |
+
Converts Python dict-based schema to tensors to avoid CPU-GPU context switching.
|
| 112 |
+
This optimization replaces dict lookups with tensor indexing for better performance.
|
| 113 |
+
|
| 114 |
+
C = num_categories, G = num_groups, A = num_attributes
|
| 115 |
+
"""
|
| 116 |
+
num_groups = len(schema.group_names)
|
| 117 |
+
num_labels = len(schema.labels)
|
| 118 |
+
device = torch.device("cpu") # Will be moved with model
|
| 119 |
+
|
| 120 |
+
# Create group attribute indices tensor: (G x max_group_size)
|
| 121 |
+
# Instead of dict lookups, we can index directly: group_attr_indices[group_id, :]
|
| 122 |
+
max_group_size = max(len(labels) for labels in schema.group_name_to_labels.values())
|
| 123 |
+
self.group_attr_indices = torch.full((num_groups, max_group_size), -1, dtype=torch.long, device=device)
|
| 124 |
+
self.group_sizes = torch.zeros(num_groups, dtype=torch.long, device=device) # (G,)
|
| 125 |
+
|
| 126 |
+
for group_idx, group_name in enumerate(schema.group_names):
|
| 127 |
+
group_labels = schema.group_name_to_labels[group_name]
|
| 128 |
+
group_size = len(group_labels)
|
| 129 |
+
self.group_sizes[group_idx] = group_size
|
| 130 |
+
|
| 131 |
+
for label_idx, label in enumerate(group_labels):
|
| 132 |
+
if label in schema.labels:
|
| 133 |
+
attr_idx = schema.labels.index(label)
|
| 134 |
+
self.group_attr_indices[group_idx, label_idx] = attr_idx
|
| 135 |
+
|
| 136 |
+
# Create category to groups mapping: (C x G) - which groups are valid for each category
|
| 137 |
+
# Replaces dict-based category_to_group_names with tensor indexing
|
| 138 |
+
# Usage: category_to_groups[cat_idx, :] gives valid groups for category cat_idx
|
| 139 |
+
self.category_to_groups = self.group_mask.clone() # (C x G)
|
| 140 |
+
|
| 141 |
+
def _apply(self, fn):
    """Override _apply to move our custom tensors with the model."""
    super()._apply(fn)

    # These schema tensors are plain attributes (not registered buffers), so
    # model.to(device) / .cuda() / .half() would otherwise leave them behind.
    # hasattr guards keep this safe if called before the mappings exist.
    custom_tensor_attrs = (
        "group_mask",
        "group_attr_indices",
        "group_sizes",
        "category_to_groups",
    )
    for attr_name in custom_tensor_attrs:
        if hasattr(self, attr_name):
            setattr(self, attr_name, fn(getattr(self, attr_name)))

    return self
|
| 156 |
|
| 157 |
def forward(
|
|
|
|
| 414 |
self, cat_logits: torch.Tensor, attr_logits: torch.Tensor, word_mask: torch.Tensor
|
| 415 |
) -> List[List[Tuple[str, List[str]]]]:
|
| 416 |
"""
|
| 417 |
+
Convert logits to human-readable labels using vectorized operations.
|
| 418 |
+
|
| 419 |
+
Key optimizations:
|
| 420 |
+
1. Tensor-based schema lookups instead of Python dict access
|
| 421 |
+
2. Vectorized argmax for categories across entire batch
|
| 422 |
+
3. Reduced CPU-GPU context switching by batching operations
|
| 423 |
+
4. Pre-computed tensor mappings for group/attribute relationships
|
| 424 |
+
|
| 425 |
+
B = batch_size, W = max_words, C = num_categories, A = num_attributes, G = num_groups
|
| 426 |
+
|
| 427 |
Args:
|
| 428 |
cat_logits: Category logits (B x W x C)
|
| 429 |
attr_logits: Attribute logits (B x W x A)
|
| 430 |
word_mask: Binary mask for valid words (B x L)
|
| 431 |
+
|
| 432 |
Returns:
|
| 433 |
predictions: List of [(category, [attributes])] for each sequence
|
| 434 |
"""
|
| 435 |
+
device = cat_logits.device
|
| 436 |
+
bsz, max_words = cat_logits.shape[:2]
|
| 437 |
nwords = word_mask.sum(-1) # (B,)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 438 |
schema = self.config.label_schema
|
| 439 |
|
| 440 |
+
# Vectorized category prediction: (B x W)
|
| 441 |
+
# Single GPU operation instead of nested loops
|
| 442 |
+
pred_cat_indices = cat_logits.argmax(dim=-1) # (B x W)
|
| 443 |
+
|
| 444 |
+
# Vectorized attribute prediction for all groups
|
| 445 |
+
predictions = []
|
| 446 |
+
|
| 447 |
for seq_idx in range(bsz):
|
| 448 |
+
seq_nwords = int(nwords[seq_idx].item())
|
| 449 |
+
if seq_nwords == 0:
|
| 450 |
+
predictions.append([])
|
| 451 |
+
continue
|
| 452 |
+
|
| 453 |
+
# Get categories for this sequence: (seq_nwords,)
|
| 454 |
+
seq_cat_indices = pred_cat_indices[seq_idx, :seq_nwords]
|
| 455 |
+
|
| 456 |
+
# Get valid groups for each category: (seq_nwords x G)
|
| 457 |
+
# Tensor lookup replaces dict access: category_to_group_names[cat_name]
|
| 458 |
+
seq_valid_groups = self.category_to_groups[seq_cat_indices] # (seq_nwords x G)
|
| 459 |
+
|
| 460 |
+
# Process attributes for all words in sequence
|
| 461 |
seq_predictions = []
|
| 462 |
for word_idx in range(seq_nwords):
|
| 463 |
+
cat_idx = seq_cat_indices[word_idx].item()
|
| 464 |
cat_name = schema.label_categories[cat_idx]
|
| 465 |
+
|
| 466 |
+
# Get valid groups for this word
|
| 467 |
+
word_valid_groups = seq_valid_groups[word_idx] # (G,)
|
| 468 |
+
valid_group_indices = word_valid_groups.nonzero(as_tuple=True)[0] # (valid_groups,)
|
| 469 |
+
|
| 470 |
+
# Collect attributes using vectorized operations
|
| 471 |
attributes = []
|
| 472 |
+
for group_idx in valid_group_indices:
|
| 473 |
+
group_idx = group_idx.item()
|
| 474 |
+
group_size = self.group_sizes[group_idx].item()
|
| 475 |
+
|
| 476 |
+
if group_size == 0:
|
| 477 |
+
continue
|
| 478 |
+
|
| 479 |
+
# Get attribute indices for this group: (group_size,)
|
| 480 |
+
# Tensor lookup replaces dict: group_name_to_group_attr_indices[group_name]
|
| 481 |
+
group_attr_indices = self.group_attr_indices[group_idx, :group_size]
|
| 482 |
+
valid_indices = group_attr_indices[group_attr_indices >= 0]
|
| 483 |
+
|
| 484 |
+
if len(valid_indices) == 0:
|
| 485 |
+
continue
|
| 486 |
+
|
| 487 |
+
# Get logits for this group: (group_size,)
|
| 488 |
+
group_logits = attr_logits[seq_idx, word_idx, valid_indices]
|
| 489 |
+
|
| 490 |
+
if len(valid_indices) == 1:
|
| 491 |
+
# Binary decision
|
| 492 |
+
if group_logits.sigmoid().item() > 0.5:
|
| 493 |
+
attr_idx = valid_indices[0].item()
|
| 494 |
+
attributes.append(schema.labels[attr_idx])
|
| 495 |
+
else:
|
| 496 |
+
# Multi-class decision
|
| 497 |
+
best_local_idx = group_logits.argmax().item()
|
| 498 |
+
attr_idx = valid_indices[best_local_idx].item()
|
| 499 |
+
attributes.append(schema.labels[attr_idx])
|
| 500 |
+
|
| 501 |
+
# Apply post-processing rules
|
| 502 |
if len(attributes) == 1 and attributes[0] == "pos":
|
| 503 |
# This label is used as a default for training but implied in mim format
|
| 504 |
attributes = []
|
| 505 |
elif cat_name == "sl" and "act" in attributes:
|
| 506 |
# Number and tense are not shown for sl act in mim format
|
| 507 |
attributes = [attr for attr in attributes if attr not in ["1", "sing", "pres"]]
|
| 508 |
+
|
| 509 |
seq_predictions.append((cat_name, attributes))
|
| 510 |
+
|
| 511 |
predictions.append(seq_predictions)
|
| 512 |
|
| 513 |
return predictions
|