haukurpj committed
Commit 3d90a1f · 1 Parent(s): 1114b8e

changing interface to accept words instead of a sentence string. Documenting the truncation flag

Files changed (2)
  1. README.md +50 -7
  2. modeling.py +134 -96
README.md CHANGED
@@ -9,10 +9,14 @@ paper: https://arxiv.org/abs/2201.05601
  ---
  ## Prediction Methods

- The model provides two prediction methods:

- - **`predict_labels_from_text()`**: Returns structured predictions as (category, [attributes]) tuples. These can be slightly more readable and more suitable for some applications.
- - **`predict_ifd_labels_from_text()`**: Returns predictions in IFD (Icelandic Frequency Dictionary) format. Use this for evaluation against MIM-GOLD datasets or when you need compatibility with traditional Icelandic POS taggers.

  ```python
  from transformers import AutoModel, AutoTokenizer
@@ -22,16 +26,17 @@ tokenizer = AutoTokenizer.from_pretrained("mideind/IceBERT-PoS")

  # Example sentence
  sentence = "Ég veit að þú kemur í kvöld til mín ."

  # Get predictions in (category, [attributes]) format
- result = model.predict_labels_from_text([sentence], tokenizer)
  expected = [
      [
          ("fp", ["1", "sing", "nom"]),
-         ("sf", ["act", "1", "sing", "pres"]),
          ("c", []),
          ("fp", ["2", "sing", "nom"]),
-         ("sf", ["act", "2", "sing", "pres"]),
          ("af", []),
          ("n", ["neut", "sing", "acc"]),
          ("af", []),
@@ -43,10 +48,48 @@ assert result == expected, f"Expected {expected}, but got {result}"
  print("Test passed successfully!")

  # Get predictions in IFD format (for MIM-GOLD evaluation)
- ifd_result = model.predict_ifd_labels_from_text([sentence], tokenizer)
  ifd_expected = [
      ["fp1en", "sfg1en", "c", "fp2en", "sfg2en", "af", "nheo", "af", "fp1ee", "pl"]
  ]
  assert ifd_result == ifd_expected, f"Expected {ifd_expected}, but got {ifd_result}"
  print("IFD conversion test passed successfully!")
  ```
 
  ---
  ## Prediction Methods

+ The model provides several prediction methods:

+ - **`prepare_inputs(words, tokenizer, truncate=False)`**: Prepares inputs for a single list of words, returning tensors without a batch dimension.
+ - **`predict_labels(input_ids, attention_mask, word_mask)`**: Low-level prediction from prepared tensors with a batch dimension.
+ - **`predict_labels_from_text(sentences, tokenizer, truncate=False)`**: Returns structured predictions as (category, [attributes]) tuples from word lists. This structured format is easier to read and to post-process than raw tag strings.
+ - **`predict_ifd_labels_from_text(sentences, tokenizer, truncate=False)`**: Returns predictions in IFD (Icelandic Frequency Dictionary) format from word lists. Use this for evaluation against MIM-GOLD datasets or when you need compatibility with traditional Icelandic POS taggers.
+
+ All methods accept pre-tokenized word lists rather than raw sentence strings, which gives the caller explicit control over word boundaries.

  ```python
  from transformers import AutoModel, AutoTokenizer
 

  # Example sentence
  sentence = "Ég veit að þú kemur í kvöld til mín ."
+ sentence_words = sentence.split()

  # Get predictions in (category, [attributes]) format
+ result = model.predict_labels_from_text([sentence_words], tokenizer)
  expected = [
      [
          ("fp", ["1", "sing", "nom"]),
+         ("sf", ["sing", "act", "1", "pres"]),
          ("c", []),
          ("fp", ["2", "sing", "nom"]),
+         ("sf", ["sing", "act", "2", "pres"]),
          ("af", []),
          ("n", ["neut", "sing", "acc"]),
          ("af", []),

  print("Test passed successfully!")

  # Get predictions in IFD format (for MIM-GOLD evaluation)
+ ifd_result = model.predict_ifd_labels_from_text([sentence_words], tokenizer)
  ifd_expected = [
      ["fp1en", "sfg1en", "c", "fp2en", "sfg2en", "af", "nheo", "af", "fp1ee", "pl"]
  ]
  assert ifd_result == ifd_expected, f"Expected {ifd_expected}, but got {ifd_result}"
  print("IFD conversion test passed successfully!")
+
+ # Alternative: use prepare_inputs for single-sentence prediction
+ input_ids, attention_mask, word_mask = model.prepare_inputs(sentence_words, tokenizer)
+ single_result = model.predict_labels(input_ids.unsqueeze(0), attention_mask.unsqueeze(0), word_mask.unsqueeze(0))
+ assert single_result == expected, f"Expected {expected}, but got {single_result}"
+ print("Single sentence prediction test passed successfully!")
+ ```
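+
+ Batching works the same way: `predict_labels_from_text` prepares each sentence separately and pads the tensors to a common length internally. A minimal sketch, reusing the `model` and `tokenizer` from above (the second sentence is an arbitrary illustration, not from the test suite):
+
+ ```python
+ # Two pre-tokenized sentences of different lengths
+ batch = [
+     "Ég veit að þú kemur í kvöld til mín .".split(),
+     "Þetta er setning .".split(),
+ ]
+ batch_results = model.predict_labels_from_text(batch, tokenizer)
+ assert len(batch_results) == 2  # one list of (category, [attributes]) per sentence
+ assert len(batch_results[0]) == 10  # one prediction per input word
+ ```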
+
+ ## Handling Long Sequences with Truncation
+
+ By default `truncate=False`, so input is never silently truncated (silent truncation is hard to debug). Sequences whose subword count exceeds the model's maximum input length will instead raise an error:
+
+ ```python
+ from transformers import AutoModel, AutoTokenizer
+
+ model = AutoModel.from_pretrained("mideind/IceBERT-PoS", trust_remote_code=True)
+ tokenizer = AutoTokenizer.from_pretrained("mideind/IceBERT-PoS")
+
+ # Create a very long sentence that exceeds the model's limits
+ words = ["Þetta", "er", "mjög", "löng", "setning"] * 200  # 1000 words
+ print(f"Input length: {len(words)} words")
+
+ # This will crash because the sequence length exceeds the model's limits
+ try:
+     result = model.predict_labels_from_text([words], tokenizer, truncate=False)
+     print("This shouldn't print - sequence was too long!")
+ except Exception as e:
+     print(f"Error as expected: {type(e).__name__}")
+
+ # Use truncate=True for long sequences
+ result_truncated = model.predict_labels_from_text([words], tokenizer, truncate=True)
+ print(f"Truncated result length: {len(result_truncated[0])} words")
+ print("Warning: Output length differs from input length due to truncation!")
+
+ # When using truncation, you must handle the length mismatch carefully:
+ # the output will have fewer predictions than there are input words
+ assert len(result_truncated[0]) < len(words), "Truncation should reduce length"
+ print("Truncation example completed successfully!")
  ```
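+
+ If you need a prediction for every word of a long input, one possible workaround (not part of the model API) is to chunk the word list instead of truncating. A minimal sketch, reusing `words`, `model`, and `tokenizer` from above, and assuming a chunk size whose subwords always fit within the model's maximum length; note that context is lost at chunk boundaries, so accuracy may degrade there:
+
+ ```python
+ CHUNK_SIZE = 128  # assumed safe word count; tune for your data
+
+ chunks = [words[i:i + CHUNK_SIZE] for i in range(0, len(words), CHUNK_SIZE)]
+ chunk_results = model.predict_labels_from_text(chunks, tokenizer)
+ full_result = [label for chunk in chunk_results for label in chunk]
+ assert len(full_result) == len(words)  # one prediction per input word
+ ```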
modeling.py CHANGED
@@ -322,105 +322,167 @@ class IceBertPosForTokenClassification(PreTrainedModel):
              batch_first=True,
          )

      @torch.no_grad()
      def predict_labels(
-         self, input_ids: torch.Tensor, attention_mask: torch.Tensor, word_ids: List[List[int]]
      ) -> List[List[Tuple[str, List[str]]]]:
          """
          Predict POS labels for input sequences.

          Args:
-             input_ids: Token indices
-             attention_mask: Attention mask
-             word_ids: Word boundaries

          Returns:
              List of sequences, each containing (category, [attributes]) per word
          """
-         # Convert word_ids to word_mask
-         word_mask = self._word_ids_to_word_mask(word_ids, input_ids.shape)
-
          cat_logits, attr_logits = self.forward(input_ids=input_ids, attention_mask=attention_mask, word_mask=word_mask)

          return self._logits_to_labels(cat_logits, attr_logits, word_mask)

-     def _word_ids_to_word_mask(self, word_ids: List[List[int]], input_shape: torch.Size) -> torch.Tensor:
          """
-         Convert word_ids to binary mask indicating word boundaries.

-         B = batch_size, L = seq_len

          Args:
-             word_ids: List of word id sequences for each batch item
-             input_shape: Shape of input_ids tensor (B x L)

          Returns:
-             word_mask: Binary tensor where 1 indicates start of word (B x L)
          """
-         batch_size, seq_len = input_shape
-         word_mask = torch.zeros(batch_size, seq_len, dtype=torch.long)  # (B x L)
-
-         for batch_idx, seq_word_ids in enumerate(word_ids):
-             prev_word_id = None
-             for token_idx, word_id in enumerate(seq_word_ids):
-                 # Skip None values (special tokens and padding)
-                 if word_id is not None and word_id != prev_word_id:
-                     word_mask[batch_idx, token_idx] = 1  # Mark word start
-                 # Only update prev_word_id for valid (non-None) word_ids
-                 if word_id is not None:
-                     prev_word_id = word_id
-
-             # Debug logging to match fairseq model
-             logger.debug(f"Word mask: {word_mask[batch_idx]}")

-         return word_mask
-     def predict_labels_from_text(self, sentences: List[str], tokenizer) -> List[List[Tuple[str, List[str]]]]:
          """
-         Predict POS labels from raw text.

-         B = batch_size, L = seq_len

          Args:
-             sentences: List of input sentences (B,)
-             tokenizer: HuggingFace tokenizer

          Returns:
-             List of sequences, each containing (category, [attributes]) per word
          """
-         # Split sentences by spaces to get proper word boundaries
-         # This fixes the issue where tokens like "Kl." get split incorrectly
-         sentences_split = [sentence.split() for sentence in sentences]
-
-         # Use batch_encode_plus with is_split_into_words=True to preserve word boundaries
-         encoding = tokenizer.batch_encode_plus(
-             sentences_split, return_tensors="pt", padding=True, is_split_into_words=True, add_special_tokens=True
-         )

-         batch_input_ids = encoding["input_ids"]
-         batch_attention_mask = encoding["attention_mask"]
-         word_ids_list = [encoding.word_ids(i) for i in range(len(sentences))]

          # Debug logging to match fairseq model
-         for i in range(len(sentences)):
-             logger.debug(f"Encoded tokens: {batch_input_ids[i]}")  # (L,)
-             logger.debug(f"Decoded tokens: {tokenizer.convert_ids_to_tokens(batch_input_ids[i].tolist())}")
-             logger.debug(f"Word IDs: {word_ids_list[i]}")  # (L,)

-         return self.predict_labels(batch_input_ids, batch_attention_mask, word_ids_list)

      def _logits_to_labels(
          self, cat_logits: torch.Tensor, attr_logits: torch.Tensor, word_mask: torch.Tensor
      ) -> List[List[Tuple[str, List[str]]]]:
          """
          Convert logits to human-readable labels using vectorized operations.
-
          Key optimizations:
          1. Flatten batch dimension to process all words simultaneously
          2. Vectorized group processing across all words
          3. Defer string conversion to the very end
          4. Minimize Python loops and tensor-CPU transfers
-
          B = batch_size, W = max_words, C = num_categories, A = num_attributes, G = num_groups
          """
          device = cat_logits.device
@@ -433,54 +495,54 @@ class IceBertPosForTokenClassification(PreTrainedModel):
          batch_word_mask = torch.zeros(bsz, max_words, dtype=torch.bool, device=device)
          for b in range(bsz):
              if nwords[b] > 0:
-                 batch_word_mask[b, :nwords[b]] = True
-
          valid_positions = batch_word_mask.flatten().nonzero(as_tuple=True)[0]  # (total_words,)
          total_words = len(valid_positions)
-
          if total_words == 0:
              return [[] for _ in range(bsz)]

          # Step 2: Vectorized category prediction for all valid words
          flat_cat_logits = cat_logits.view(-1, cat_logits.size(-1))  # (B*W x C)
          flat_attr_logits = attr_logits.view(-1, attr_logits.size(-1))  # (B*W x A)
-
          # Get categories for all valid words: (total_words,)
          all_cat_indices = flat_cat_logits[valid_positions].argmax(dim=-1)
-
          # Step 3: Vectorized group validity for all words: (total_words x G)
          all_valid_groups = self.category_to_groups[all_cat_indices]
-
          # Step 4: Collect attributes using vectorized group processing
          word_to_attrs = {}  # word_idx -> list of attr_indices
-
          # Process each group across all words simultaneously
          for group_idx in range(self.group_sizes.size(0)):
              group_size = self.group_sizes[group_idx].item()
              if group_size == 0:
                  continue
-
              # Find words that have this group valid: (words_with_group,)
              words_with_group = all_valid_groups[:, group_idx].nonzero(as_tuple=True)[0]
              if len(words_with_group) == 0:
                  continue
-
              # Get attribute indices for this group
              group_attr_indices = self.group_attr_indices[group_idx, :group_size]
              valid_attr_indices = group_attr_indices[group_attr_indices >= 0]
              if len(valid_attr_indices) == 0:
                  continue
-
              # Get logits for all words that need this group: (words_with_group x group_size)
              word_positions = valid_positions[words_with_group]
              group_logits = flat_attr_logits[word_positions][:, valid_attr_indices]
-
              if len(valid_attr_indices) == 1:
                  # Binary decision for all words simultaneously: (words_with_group,)
                  decisions = group_logits.sigmoid().squeeze(-1) > 0.5
                  selected_words = words_with_group[decisions]
                  attr_idx = valid_attr_indices[0].item()
-
                  for word_idx in selected_words:
                      word_idx_item = word_idx.item()
                      if word_idx_item not in word_to_attrs:
@@ -489,7 +551,7 @@ class IceBertPosForTokenClassification(PreTrainedModel):
              else:
                  # Multi-class decision for all words: (words_with_group,)
                  best_indices = group_logits.argmax(dim=-1)
-
                  for i, word_idx in enumerate(words_with_group):
                      attr_idx = valid_attr_indices[best_indices[i]].item()
                      word_idx_item = word_idx.item()
@@ -500,22 +562,22 @@ class IceBertPosForTokenClassification(PreTrainedModel):
          # Step 5: Reconstruct batch structure and convert to strings (deferred)
          predictions = []
          word_counter = 0
-
          for seq_idx in range(bsz):
              seq_nwords = nwords[seq_idx].item()
              seq_predictions = []
-
              for _ in range(seq_nwords):
                  # Get category (string conversion deferred)
                  cat_idx = all_cat_indices[word_counter].item()
                  cat_name = schema.label_categories[cat_idx]
-
                  # Get attributes (string conversion deferred)
                  attributes = []
                  if word_counter in word_to_attrs:
                      attr_indices = word_to_attrs[word_counter]
                      attributes = [schema.labels[idx] for idx in attr_indices]
-
                  # Apply post-processing rules
                  if len(attributes) == 1 and attributes[0] == "pos":
                      # This label is used as a default for training but implied in mim format
@@ -523,38 +585,14 @@ class IceBertPosForTokenClassification(PreTrainedModel):
                  elif cat_name == "sl" and "act" in attributes:
                      # Number and tense are not shown for sl act in mim format
                      attributes = [attr for attr in attributes if attr not in ["1", "sing", "pres"]]
-
                  seq_predictions.append((cat_name, attributes))
                  word_counter += 1
-
              predictions.append(seq_predictions)

          return predictions

-     def predict_ifd_labels_from_text(self, sentences: List[str], tokenizer) -> List[List[str]]:
-         """
-         Predict IFD format labels from raw text.
-
-         B = batch_size, Ws = seq_words
-
-         Args:
-             sentences: List of input sentences (B,)
-             tokenizer: HuggingFace tokenizer
-
-         Returns:
-             ifd_predictions: List of IFD labels per sentence (B x Ws)
-         """
-         # Get model predictions in (category, [attributes]) format
-         predictions = self.predict_labels_from_text(sentences, tokenizer)
-
-         # Convert each sentence's predictions to IFD format
-         ifd_predictions = []
-         for sentence_predictions in predictions:
-             ifd_labels = convert_predictions_to_ifd(sentence_predictions)  # (Ws,)
-             ifd_predictions.append(ifd_labels)
-
-         return ifd_predictions
-

  AutoConfig.register("icebert-pos", IceBertPosConfig)
  AutoModel.register(IceBertPosConfig, IceBertPosForTokenClassification)
 
              batch_first=True,
          )

+     def prepare_inputs(
+         self, words: List[str], tokenizer, truncate: bool = False
+     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         """
+         Prepare inputs for a single list of words.
+
+         Args:
+             words: List of words
+             tokenizer: HuggingFace tokenizer
+             truncate: Whether to truncate inputs that are too long
+
+         Returns:
+             Tuple of (input_ids, attention_mask, word_mask) without a batch dimension.
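+             Callers that pass these tensors to predict_labels must first add
+             a batch dimension, e.g. input_ids.unsqueeze(0) (see the README example).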
+         """
+         # Encode with word boundary preservation
+         encoding = tokenizer.encode_plus(
+             words,
+             return_tensors="pt",
+             is_split_into_words=True,
+             add_special_tokens=True,
+             truncation=truncate,
+             # The model was probably trained on much shorter sequences
+             max_length=self.config.max_position_embeddings - 2,
+         )
+
+         input_ids = encoding["input_ids"].squeeze(0)  # (L,)
+         attention_mask = torch.ones_like(input_ids)
+
+         # Get word_ids and convert to word_mask
+         word_ids = encoding.word_ids()
+         word_mask = self._word_ids_to_word_mask(word_ids)
+
+         # Debug logging to match fairseq model
+         logger.debug(f"Encoded tokens: {input_ids}")  # (L,)
+         logger.debug(f"Decoded tokens: {tokenizer.convert_ids_to_tokens(input_ids.tolist())}")
+         logger.debug(f"Word IDs: {word_ids}")  # (L,)
+         logger.debug(f"Word mask: {word_mask}")
+
+         return input_ids, attention_mask, word_mask
+
      @torch.no_grad()
      def predict_labels(
+         self, input_ids: torch.Tensor, attention_mask: torch.Tensor, word_mask: torch.Tensor
      ) -> List[List[Tuple[str, List[str]]]]:
          """
          Predict POS labels for input sequences.

+         B = batch_size, L = seq_len
+
          Args:
+             input_ids: Token indices (B x L)
+             attention_mask: Attention mask (B x L)
+             word_mask: Binary mask indicating word boundaries (B x L)

          Returns:
              List of sequences, each containing (category, [attributes]) per word
          """
          cat_logits, attr_logits = self.forward(input_ids=input_ids, attention_mask=attention_mask, word_mask=word_mask)

          return self._logits_to_labels(cat_logits, attr_logits, word_mask)

+     def predict_labels_from_text(
+         self, sentences: List[List[str]], tokenizer, truncate: bool = False
+     ) -> List[List[Tuple[str, List[str]]]]:
+         """
+         Predict POS labels from a list of word lists.
+
+         Args:
+             sentences: List of sentences, each a list of words
+             tokenizer: HuggingFace tokenizer
+             truncate: Whether to truncate inputs that are too long
+
+         Returns:
+             List of sequences, each containing (category, [attributes]) per word
+         """
+         # Use prepare_inputs for each sentence and batch them
+         all_input_ids = []
+         all_attention_masks = []
+         all_word_masks = []
+
+         for words in sentences:
+             input_ids, attention_mask, word_mask = self.prepare_inputs(words, tokenizer, truncate)
+             all_input_ids.append(input_ids)
+             all_attention_masks.append(attention_mask)
+             all_word_masks.append(word_mask)
+
+         # Pad sequences to the same length
+         batch_input_ids = pad_sequence(all_input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
+         batch_attention_mask = pad_sequence(all_attention_masks, batch_first=True, padding_value=0)
+         batch_word_mask = pad_sequence(all_word_masks, batch_first=True, padding_value=0)
+
+         return self.predict_labels(batch_input_ids, batch_attention_mask, batch_word_mask)
+
+     def predict_ifd_labels_from_text(
+         self, sentences: List[List[str]], tokenizer, truncate: bool = False
+     ) -> List[List[str]]:
+         """
+         Predict IFD format labels from a list of word lists.
+
+         B = batch_size, Ws = seq_words

          Args:
+             sentences: List of sentences, each a list of words
+             tokenizer: HuggingFace tokenizer
+             truncate: Whether to truncate inputs that are too long

          Returns:
+             ifd_predictions: List of IFD labels per sentence (B x Ws)
          """
+         # Get model predictions in (category, [attributes]) format
+         predictions = self.predict_labels_from_text(sentences, tokenizer, truncate)

+         # Convert each sentence's predictions to IFD format
+         ifd_predictions = []
+         for sentence_predictions in predictions:
+             ifd_labels = convert_predictions_to_ifd(sentence_predictions)  # (Ws,)
+             ifd_predictions.append(ifd_labels)

+         return ifd_predictions
+
+     def _word_ids_to_word_mask(self, word_ids: List[int]) -> torch.Tensor:
          """
+         Convert word_ids to a binary mask indicating word boundaries.

+         L = seq_len

          Args:
+             word_ids: Word id sequence for a single sequence

          Returns:
+             word_mask: Binary tensor where 1 indicates the start of a word (L,)
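+
+         Example (hypothetical word_ids as returned by a fast tokenizer,
+         with None marking special tokens):
+             [None, 0, 0, 1, 2, 2, None] -> tensor([0, 1, 0, 1, 1, 0, 0])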
          """
+         word_mask = torch.zeros(len(word_ids), dtype=torch.long)  # (L,)

+         prev_word_id = None
+         for token_idx, word_id in enumerate(word_ids):
+             # Skip None values (special tokens and padding)
+             if word_id is not None and word_id != prev_word_id:
+                 word_mask[token_idx] = 1  # Mark word start
+             # Only update prev_word_id for valid (non-None) word_ids
+             if word_id is not None:
+                 prev_word_id = word_id

          # Debug logging to match fairseq model
+         logger.debug(f"Word mask: {word_mask}")

+         return word_mask
 
      def _logits_to_labels(
          self, cat_logits: torch.Tensor, attr_logits: torch.Tensor, word_mask: torch.Tensor
      ) -> List[List[Tuple[str, List[str]]]]:
          """
          Convert logits to human-readable labels using vectorized operations.
+
          Key optimizations:
          1. Flatten batch dimension to process all words simultaneously
          2. Vectorized group processing across all words
          3. Defer string conversion to the very end
          4. Minimize Python loops and tensor-CPU transfers
+
          B = batch_size, W = max_words, C = num_categories, A = num_attributes, G = num_groups
          """
          device = cat_logits.device
 
          batch_word_mask = torch.zeros(bsz, max_words, dtype=torch.bool, device=device)
          for b in range(bsz):
              if nwords[b] > 0:
+                 batch_word_mask[b, : nwords[b]] = True
+
          valid_positions = batch_word_mask.flatten().nonzero(as_tuple=True)[0]  # (total_words,)
          total_words = len(valid_positions)
+
          if total_words == 0:
              return [[] for _ in range(bsz)]

          # Step 2: Vectorized category prediction for all valid words
          flat_cat_logits = cat_logits.view(-1, cat_logits.size(-1))  # (B*W x C)
          flat_attr_logits = attr_logits.view(-1, attr_logits.size(-1))  # (B*W x A)
+
          # Get categories for all valid words: (total_words,)
          all_cat_indices = flat_cat_logits[valid_positions].argmax(dim=-1)
+
          # Step 3: Vectorized group validity for all words: (total_words x G)
          all_valid_groups = self.category_to_groups[all_cat_indices]
+
          # Step 4: Collect attributes using vectorized group processing
          word_to_attrs = {}  # word_idx -> list of attr_indices
+
          # Process each group across all words simultaneously
          for group_idx in range(self.group_sizes.size(0)):
              group_size = self.group_sizes[group_idx].item()
              if group_size == 0:
                  continue
+
              # Find words that have this group valid: (words_with_group,)
              words_with_group = all_valid_groups[:, group_idx].nonzero(as_tuple=True)[0]
              if len(words_with_group) == 0:
                  continue
+
              # Get attribute indices for this group
              group_attr_indices = self.group_attr_indices[group_idx, :group_size]
              valid_attr_indices = group_attr_indices[group_attr_indices >= 0]
              if len(valid_attr_indices) == 0:
                  continue
+
              # Get logits for all words that need this group: (words_with_group x group_size)
              word_positions = valid_positions[words_with_group]
              group_logits = flat_attr_logits[word_positions][:, valid_attr_indices]
+
              if len(valid_attr_indices) == 1:
                  # Binary decision for all words simultaneously: (words_with_group,)
                  decisions = group_logits.sigmoid().squeeze(-1) > 0.5
                  selected_words = words_with_group[decisions]
                  attr_idx = valid_attr_indices[0].item()
+
                  for word_idx in selected_words:
                      word_idx_item = word_idx.item()
                      if word_idx_item not in word_to_attrs:

              else:
                  # Multi-class decision for all words: (words_with_group,)
                  best_indices = group_logits.argmax(dim=-1)
+
                  for i, word_idx in enumerate(words_with_group):
                      attr_idx = valid_attr_indices[best_indices[i]].item()
                      word_idx_item = word_idx.item()

          # Step 5: Reconstruct batch structure and convert to strings (deferred)
          predictions = []
          word_counter = 0
+
          for seq_idx in range(bsz):
              seq_nwords = nwords[seq_idx].item()
              seq_predictions = []
+
              for _ in range(seq_nwords):
                  # Get category (string conversion deferred)
                  cat_idx = all_cat_indices[word_counter].item()
                  cat_name = schema.label_categories[cat_idx]
+
                  # Get attributes (string conversion deferred)
                  attributes = []
                  if word_counter in word_to_attrs:
                      attr_indices = word_to_attrs[word_counter]
                      attributes = [schema.labels[idx] for idx in attr_indices]
+
                  # Apply post-processing rules
                  if len(attributes) == 1 and attributes[0] == "pos":
                      # This label is used as a default for training but implied in mim format

                  elif cat_name == "sl" and "act" in attributes:
                      # Number and tense are not shown for sl act in mim format
                      attributes = [attr for attr in attributes if attr not in ["1", "sing", "pres"]]
+
                  seq_predictions.append((cat_name, attributes))
                  word_counter += 1
+
              predictions.append(seq_predictions)

          return predictions


  AutoConfig.register("icebert-pos", IceBertPosConfig)
  AutoModel.register(IceBertPosConfig, IceBertPosForTokenClassification)
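  # Registering the config and model classes is what lets
  # AutoModel.from_pretrained("mideind/IceBERT-PoS", trust_remote_code=True)
  # resolve this custom architecture, as used in the README examples above.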