changing interface to accept words instead of a sentence string. Documenting the truncation flag

Files changed:
- README.md (+50, -7)
- modeling.py (+134, -96)

README.md
CHANGED

@@ -9,10 +9,14 @@ paper: https://arxiv.org/abs/2201.05601

---
## Prediction Methods

The model provides several prediction methods:

- **`prepare_inputs(words, tokenizer, truncate=False)`**: Prepares inputs for a single list of words, returning tensors without a batch dimension.
- **`predict_labels(input_ids, attention_mask, word_mask)`**: Low-level prediction from prepared tensors with a batch dimension.
- **`predict_labels_from_text(sentences, tokenizer, truncate=False)`**: Returns structured predictions as (category, [attributes]) tuples from word lists. These are more readable than raw tag strings and better suited to further processing.
- **`predict_ifd_labels_from_text(sentences, tokenizer, truncate=False)`**: Returns predictions in IFD (Icelandic Frequency Dictionary) format from word lists. Use this for evaluation against MIM-GOLD datasets or when you need compatibility with traditional Icelandic POS taggers.

All methods accept pre-tokenized word lists rather than raw sentences, which gives you full control over tokenization.

```python
from transformers import AutoModel, AutoTokenizer

@@ -22,16 +26,17 @@ tokenizer = AutoTokenizer.from_pretrained("mideind/IceBERT-PoS")

# Example sentence
sentence = "Ég veit að þú kemur í kvöld til mín ."
sentence_words = sentence.split()

# Get predictions in (category, [attributes]) format
result = model.predict_labels_from_text([sentence_words], tokenizer)
expected = [
    [
        ("fp", ["1", "sing", "nom"]),
        ("sf", ["sing", "act", "1", "pres"]),
        ("c", []),
        ("fp", ["2", "sing", "nom"]),
        ("sf", ["sing", "act", "2", "pres"]),
        ("af", []),
        ("n", ["neut", "sing", "acc"]),
        ("af", []),

@@ -43,10 +48,48 @@ assert result == expected, f"Expected {expected}, but got {result}"
print("Test passed successfully!")

# Get predictions in IFD format (for MIM-GOLD evaluation)
ifd_result = model.predict_ifd_labels_from_text([sentence_words], tokenizer)
ifd_expected = [
    ["fp1en", "sfg1en", "c", "fp2en", "sfg2en", "af", "nheo", "af", "fp1ee", "pl"]
]
assert ifd_result == ifd_expected, f"Expected {ifd_expected}, but got {ifd_result}"
print("IFD conversion test passed successfully!")

# Alternative: use prepare_inputs for single sentence prediction
input_ids, attention_mask, word_mask = model.prepare_inputs(sentence_words, tokenizer)
single_result = model.predict_labels(input_ids.unsqueeze(0), attention_mask.unsqueeze(0), word_mask.unsqueeze(0))
assert single_result == expected, f"Expected {expected}, but got {single_result}"
print("Single sentence prediction test passed successfully!")
```
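
Since `predict_labels_from_text` takes a list of word lists, several sentences can also be tagged in one batched call. A minimal sketch follows; the second sentence is only an illustrative example and its expected tags are not shown here.

```python
from transformers import AutoModel, AutoTokenizer

model = AutoModel.from_pretrained("mideind/IceBERT-PoS", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("mideind/IceBERT-PoS")

# Batch of pre-tokenized sentences; the second one is a made-up example.
batch = [
    "Ég veit að þú kemur í kvöld til mín .".split(),
    "Hún las bókina í gær .".split(),
]
batch_result = model.predict_labels_from_text(batch, tokenizer)
assert len(batch_result) == len(batch)

# One (category, [attributes]) tuple per word, per sentence
for words, labels in zip(batch, batch_result):
    for word, (category, attributes) in zip(words, labels):
        print(word, category, attributes)
```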

## Handling Long Sequences with Truncation

By default `truncate=False`, so input is never silently cut off; this avoids hard-to-debug issues. However, very long sequences will then raise an error:

```python
from transformers import AutoModel, AutoTokenizer

model = AutoModel.from_pretrained("mideind/IceBERT-PoS", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("mideind/IceBERT-PoS")

# Create a very long sentence that exceeds the model's limits
words = ["Þetta", "er", "mjög", "löng", "setning"] * 200  # Very long sentence
print(f"Input length: {len(words)} words")

# This will crash because the sequence length exceeds the model's limits
try:
    result = model.predict_labels_from_text([words], tokenizer, truncate=False)
    print("This shouldn't print - sequence was too long!")
except Exception as e:
    print(f"Error as expected: {type(e).__name__}")

# Use truncate=True for long sequences
result_truncated = model.predict_labels_from_text([words], tokenizer, truncate=True)
print(f"Truncated result length: {len(result_truncated[0])} predictions")
print("Warning: Output length differs from input length due to truncation!")

# When using truncation, you must handle the length mismatch carefully:
# the output will have fewer predictions than there are input words.
assert len(result_truncated[0]) < len(words), "Truncation should reduce length"
print("Truncation example completed successfully!")
```
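
When a full set of predictions is needed for a long text, an alternative to `truncate=True` is to split the word list into chunks that each stay under the model's input limit. A minimal sketch: the chunk size of 200 words is a heuristic, not a documented constant (each word can expand to several subword tokens, so chunks should stay comfortably below the `max_position_embeddings - 2` token limit used by `prepare_inputs`), and cutting at arbitrary word boundaries discards cross-chunk context, so splitting on sentence boundaries is preferable when possible.

```python
from transformers import AutoModel, AutoTokenizer

model = AutoModel.from_pretrained("mideind/IceBERT-PoS", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("mideind/IceBERT-PoS")

def predict_long(words, chunk_size=200):
    """Tag a long word list chunk by chunk so that no predictions are dropped."""
    predictions = []
    for start in range(0, len(words), chunk_size):
        chunk = words[start:start + chunk_size]
        predictions.extend(model.predict_labels_from_text([chunk], tokenizer)[0])
    return predictions

long_words = ["Þetta", "er", "mjög", "löng", "setning"] * 200
long_result = predict_long(long_words)
assert len(long_result) == len(long_words)
print(f"Tagged {len(long_result)} words without truncation.")
```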

modeling.py
CHANGED

@@ -322,105 +322,167 @@ class IceBertPosForTokenClassification(PreTrainedModel):

            batch_first=True,
        )

    def prepare_inputs(
        self, words: List[str], tokenizer, truncate: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Prepare inputs for a single list of words.

        Args:
            words: List of words
            tokenizer: HuggingFace tokenizer
            truncate: Whether to truncate if too long

        Returns:
            Tuple of (input_ids, attention_mask, word_mask) without a batch dimension.
        """
        # Encode with word boundary preservation
        encoding = tokenizer.encode_plus(
            words,
            return_tensors="pt",
            is_split_into_words=True,
            add_special_tokens=True,
            truncation=truncate,
            # The model was probably trained on much shorter sequences
            max_length=self.config.max_position_embeddings - 2,
        )

        input_ids = encoding["input_ids"].squeeze(0)  # (L,)
        attention_mask = torch.ones_like(input_ids)

        # Get word_ids and convert to word_mask
        word_ids = encoding.word_ids()
        word_mask = self._word_ids_to_word_mask(word_ids)

        # Debug logging to match fairseq model
        logger.debug(f"Encoded tokens: {input_ids}")  # (L,)
        logger.debug(f"Decoded tokens: {tokenizer.convert_ids_to_tokens(input_ids.tolist())}")
        logger.debug(f"Word IDs: {word_ids}")  # (L,)
        logger.debug(f"Word mask: {word_mask}")

        return input_ids, attention_mask, word_mask

    @torch.no_grad()
    def predict_labels(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor, word_mask: torch.Tensor
    ) -> List[List[Tuple[str, List[str]]]]:
        """
        Predict POS labels for input sequences.

        B = batch_size, L = seq_len

        Args:
            input_ids: Token indices (B x L)
            attention_mask: Attention mask (B x L)
            word_mask: Binary mask indicating word boundaries (B x L)

        Returns:
            List of sequences, each containing (category, [attributes]) per word
        """
        cat_logits, attr_logits = self.forward(input_ids=input_ids, attention_mask=attention_mask, word_mask=word_mask)

        return self._logits_to_labels(cat_logits, attr_logits, word_mask)

    def predict_labels_from_text(
        self, sentences: List[List[str]], tokenizer, truncate: bool = False
    ) -> List[List[Tuple[str, List[str]]]]:
        """
        Predict POS labels from a list of word lists.

        Args:
            sentences: List of sentences, each a list of words
            tokenizer: HuggingFace tokenizer
            truncate: Whether to truncate if too long

        Returns:
            List of sequences, each containing (category, [attributes]) per word
        """
        # Use prepare_inputs for each sentence and batch them
        all_input_ids = []
        all_attention_masks = []
        all_word_masks = []

        for words in sentences:
            input_ids, attention_mask, word_mask = self.prepare_inputs(words, tokenizer, truncate)
            all_input_ids.append(input_ids)
            all_attention_masks.append(attention_mask)
            all_word_masks.append(word_mask)

        # Pad sequences to the same length
        batch_input_ids = pad_sequence(all_input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
        batch_attention_mask = pad_sequence(all_attention_masks, batch_first=True, padding_value=0)
        batch_word_mask = pad_sequence(all_word_masks, batch_first=True, padding_value=0)

        return self.predict_labels(batch_input_ids, batch_attention_mask, batch_word_mask)

    def predict_ifd_labels_from_text(
        self, sentences: List[List[str]], tokenizer, truncate: bool = False
    ) -> List[List[str]]:
        """
        Predict IFD format labels from a list of word lists.

        B = batch_size, Ws = seq_words

        Args:
            sentences: List of sentences, each a list of words
            tokenizer: HuggingFace tokenizer
            truncate: Whether to truncate if too long

        Returns:
            ifd_predictions: List of IFD labels per sentence (B x Ws)
        """
        # Get model predictions in (category, [attributes]) format
        predictions = self.predict_labels_from_text(sentences, tokenizer, truncate)

        # Convert each sentence's predictions to IFD format
        ifd_predictions = []
        for sentence_predictions in predictions:
            ifd_labels = convert_predictions_to_ifd(sentence_predictions)  # (Ws,)
            ifd_predictions.append(ifd_labels)

        return ifd_predictions

    def _word_ids_to_word_mask(self, word_ids: List[int]) -> torch.Tensor:
        """
        Convert word_ids to a binary mask indicating word boundaries.

        L = seq_len

        Args:
            word_ids: Word id sequence for a single sequence (None for special tokens)

        Returns:
            word_mask: Binary tensor where 1 indicates the start of a word (L,)
        """
        word_mask = torch.zeros(len(word_ids), dtype=torch.long)  # (L,)

        prev_word_id = None
        for token_idx, word_id in enumerate(word_ids):
            # Skip None values (special tokens and padding)
            if word_id is not None and word_id != prev_word_id:
                word_mask[token_idx] = 1  # Mark word start
            # Only update prev_word_id for valid (non-None) word_ids
            if word_id is not None:
                prev_word_id = word_id

        # Debug logging to match fairseq model
        logger.debug(f"Word mask: {word_mask}")

        return word_mask

    def _logits_to_labels(
        self, cat_logits: torch.Tensor, attr_logits: torch.Tensor, word_mask: torch.Tensor
    ) -> List[List[Tuple[str, List[str]]]]:
        """
        Convert logits to human-readable labels using vectorized operations.

        Key optimizations:
        1. Flatten batch dimension to process all words simultaneously
        2. Vectorized group processing across all words
        3. Defer string conversion to the very end
        4. Minimize Python loops and tensor-CPU transfers

        B = batch_size, W = max_words, C = num_categories, A = num_attributes, G = num_groups
        """
        device = cat_logits.device

@@ -433,54 +495,54 @@ class IceBertPosForTokenClassification(PreTrainedModel):

        batch_word_mask = torch.zeros(bsz, max_words, dtype=torch.bool, device=device)
        for b in range(bsz):
            if nwords[b] > 0:
                batch_word_mask[b, : nwords[b]] = True

        valid_positions = batch_word_mask.flatten().nonzero(as_tuple=True)[0]  # (total_words,)
        total_words = len(valid_positions)

        if total_words == 0:
            return [[] for _ in range(bsz)]

        # Step 2: Vectorized category prediction for all valid words
        flat_cat_logits = cat_logits.view(-1, cat_logits.size(-1))  # (B*W x C)
        flat_attr_logits = attr_logits.view(-1, attr_logits.size(-1))  # (B*W x A)

        # Get categories for all valid words: (total_words,)
        all_cat_indices = flat_cat_logits[valid_positions].argmax(dim=-1)

        # Step 3: Vectorized group validity for all words: (total_words x G)
        all_valid_groups = self.category_to_groups[all_cat_indices]

        # Step 4: Collect attributes using vectorized group processing
        word_to_attrs = {}  # word_idx -> list of attr_indices

        # Process each group across all words simultaneously
        for group_idx in range(self.group_sizes.size(0)):
            group_size = self.group_sizes[group_idx].item()
            if group_size == 0:
                continue

            # Find words that have this group valid: (words_with_group,)
            words_with_group = all_valid_groups[:, group_idx].nonzero(as_tuple=True)[0]
            if len(words_with_group) == 0:
                continue

            # Get attribute indices for this group
            group_attr_indices = self.group_attr_indices[group_idx, :group_size]
            valid_attr_indices = group_attr_indices[group_attr_indices >= 0]
            if len(valid_attr_indices) == 0:
                continue

            # Get logits for all words that need this group: (words_with_group x group_size)
            word_positions = valid_positions[words_with_group]
            group_logits = flat_attr_logits[word_positions][:, valid_attr_indices]

            if len(valid_attr_indices) == 1:
                # Binary decision for all words simultaneously: (words_with_group,)
                decisions = group_logits.sigmoid().squeeze(-1) > 0.5
                selected_words = words_with_group[decisions]
                attr_idx = valid_attr_indices[0].item()

                for word_idx in selected_words:
                    word_idx_item = word_idx.item()
                    if word_idx_item not in word_to_attrs:

@@ -489,7 +551,7 @@ class IceBertPosForTokenClassification(PreTrainedModel):

            else:
                # Multi-class decision for all words: (words_with_group,)
                best_indices = group_logits.argmax(dim=-1)

                for i, word_idx in enumerate(words_with_group):
                    attr_idx = valid_attr_indices[best_indices[i]].item()
                    word_idx_item = word_idx.item()

@@ -500,22 +562,22 @@ class IceBertPosForTokenClassification(PreTrainedModel):

        # Step 5: Reconstruct batch structure and convert to strings (deferred)
        predictions = []
        word_counter = 0

        for seq_idx in range(bsz):
            seq_nwords = nwords[seq_idx].item()
            seq_predictions = []

            for _ in range(seq_nwords):
                # Get category (string conversion deferred)
                cat_idx = all_cat_indices[word_counter].item()
                cat_name = schema.label_categories[cat_idx]

                # Get attributes (string conversion deferred)
                attributes = []
                if word_counter in word_to_attrs:
                    attr_indices = word_to_attrs[word_counter]
                    attributes = [schema.labels[idx] for idx in attr_indices]

                # Apply post-processing rules
                if len(attributes) == 1 and attributes[0] == "pos":
                    # This label is used as a default for training but implied in mim format

@@ -523,38 +585,14 @@ class IceBertPosForTokenClassification(PreTrainedModel):

                elif cat_name == "sl" and "act" in attributes:
                    # Number and tense are not shown for sl act in mim format
                    attributes = [attr for attr in attributes if attr not in ["1", "sing", "pres"]]

                seq_predictions.append((cat_name, attributes))
                word_counter += 1

            predictions.append(seq_predictions)

        return predictions


AutoConfig.register("icebert-pos", IceBertPosConfig)
AutoModel.register(IceBertPosConfig, IceBertPosForTokenClassification)
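
For intuition about the word masks that `prepare_inputs` builds, a standalone sketch of the `_word_ids_to_word_mask` logic follows; the subword split in the example is hypothetical, not the actual IceBERT tokenization.

```python
from typing import List, Optional

import torch


def word_ids_to_word_mask(word_ids: List[Optional[int]]) -> torch.Tensor:
    """Mark the first subword token of each word; special tokens (None) stay 0."""
    word_mask = torch.zeros(len(word_ids), dtype=torch.long)
    prev_word_id = None
    for token_idx, word_id in enumerate(word_ids):
        if word_id is not None and word_id != prev_word_id:
            word_mask[token_idx] = 1
        if word_id is not None:
            prev_word_id = word_id
    return word_mask


# Hypothetical word_ids for ["Ég", "kemur", "."], with "kemur" split into two
# subword pieces and <s>/</s> special tokens at either end:
print(word_ids_to_word_mask([None, 0, 1, 1, 2, None]))  # tensor([0, 1, 1, 0, 1, 0])
```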
|