# Copyright (C) Miðeind ehf.
# This file is part of IceBERT POS model conversion.

import logging
import time
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoConfig, AutoModel, PreTrainedModel, RobertaModel

from .configuration import IceBertPosConfig
from .ifd_utils import convert_predictions_to_ifd

logger = logging.getLogger(__name__)


class MultiLabelTokenClassificationHead(nn.Module):
    """Head for multilabel word-level classification tasks."""

    def __init__(self, config: IceBertPosConfig):
        super().__init__()
        self.num_categories = config.num_categories
        self.num_labels = config.num_labels
        self.hidden_size = config.hidden_size

        # (*, H) -> (*, H)
        self.dense = nn.Linear(self.hidden_size, self.hidden_size)
        self.activation_fn = F.relu
        self.dropout = nn.Dropout(p=config.classifier_dropout)
        self.layer_norm = nn.LayerNorm(self.hidden_size)

        # Projection heads for multilabel classification
        # (*, H) -> (*, C)
        self.cat_proj = nn.Linear(self.hidden_size, self.num_categories)
        # (*, H + C) -> (*, A)
        self.out_proj = nn.Linear(self.hidden_size + self.num_categories, self.num_labels)

    def forward(self, features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        H = hidden_size, C = num_categories, A = num_attributes, Wt = total_words

        Args:
            features: Word-level features (Wt x H)

        Returns:
            cat_logits: Category logits (Wt x C)
            attr_logits: Attribute logits (Wt x A)
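
        Example (illustrative; hypothetical sizes H=768, C=35, A=60):
            features of shape (5, 768) -> cat_logits (5, 35), attr_logits (5, 60)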
        """
        x = self.dropout(features)  # (Wt x H)
        x = self.dense(x)  # (Wt x H)
        x = self.layer_norm(x)  # (Wt x H)
        x = self.activation_fn(x)  # (Wt x H)

        # (Wt x H) -> (Wt x C)
        cat_logits = self.cat_proj(x)
        cat_probs = torch.softmax(cat_logits, dim=-1)  # (Wt x C)

        # (Wt x H) + (Wt x C) -> (Wt x H+C)
        attr_input = torch.cat((cat_probs, x), dim=-1)
        # (Wt x H+C) -> (Wt x A)
        attr_logits = self.out_proj(attr_input)

        return cat_logits, attr_logits


class IceBertPosForTokenClassification(PreTrainedModel):
    """
    IceBERT model for multilabel token classification (POS tagging).

    This model performs word-level POS tagging by:
    1. Encoding input with RoBERTa
    2. Aggregating subword tokens to word-level representations
    3. Predicting both categories and attributes for each word
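
    Example (illustrative sketch; the checkpoint path is a placeholder, not a
    published model id):

        >>> from transformers import AutoModel, AutoTokenizer
        >>> tok = AutoTokenizer.from_pretrained("path/to/icebert-pos")
        >>> model = AutoModel.from_pretrained("path/to/icebert-pos", trust_remote_code=True)
        >>> model.predict_ifd_labels_from_text([["Hún", "las", "bókina"]], tok)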
    """

    config_class = IceBertPosConfig

    def __init__(self, config: IceBertPosConfig):
        super().__init__(config)
        self.config = config
        self.num_categories = config.num_categories
        self.num_labels = config.num_labels

        self.roberta = RobertaModel(config, add_pooling_layer=False)
        self.classifier = MultiLabelTokenClassificationHead(config)

        self._setup_label_mappings()

        # Initialize weights and apply final processing
        self.post_init()

    def _setup_label_mappings(self):
        """Setup label mappings using schema methods."""
        schema = self.config.label_schema

        # Create tensors as regular attributes (not buffers to avoid init warnings)
        self.group_mask = schema.get_group_masks()  # (C x G)

        # Convert group mappings to tensor format for GPU operations
        self._create_tensor_group_mappings(schema)

        # Category name to index mapping (regular dict, no device movement needed)
        self.category_name_to_index = schema.get_category_name_to_index()

    def _create_tensor_group_mappings(self, schema):
        """
        Create tensor-based group mappings for efficient GPU operations.

        Converts Python dict-based schema to tensors to avoid CPU-GPU context switching.
        This optimization replaces dict lookups with tensor indexing for better performance.

        C = num_categories, G = num_groups, A = num_attributes
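
        Example (illustrative; hypothetical groups): with groups ["gender", "number"]
        of sizes 3 and 2 and max_group_size 3, group_attr_indices could be
        [[4, 5, 6], [7, 8, -1]], where -1 marks unused slots and group_sizes is [3, 2].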
        """
        num_groups = len(schema.group_names)
        device = torch.device("cpu")  # Will be moved with model

        # Create group attribute indices tensor: (G x max_group_size)
        # Instead of dict lookups, we can index directly: group_attr_indices[group_id, :]
        max_group_size = max(len(labels) for labels in schema.group_name_to_labels.values())
        self.group_attr_indices = torch.full((num_groups, max_group_size), -1, dtype=torch.long, device=device)
        self.group_sizes = torch.zeros(num_groups, dtype=torch.long, device=device)  # (G,)

        for group_idx, group_name in enumerate(schema.group_names):
            group_labels = schema.group_name_to_labels[group_name]
            group_size = len(group_labels)
            self.group_sizes[group_idx] = group_size

            for label_idx, label in enumerate(group_labels):
                if label in schema.labels:
                    attr_idx = schema.labels.index(label)
                    self.group_attr_indices[group_idx, label_idx] = attr_idx

        # Create category to groups mapping: (C x G) - which groups are valid for each category
        # Replaces dict-based category_to_group_names with tensor indexing
        # Usage: category_to_groups[cat_idx, :] gives valid groups for category cat_idx
        self.category_to_groups = self.group_mask.clone()  # (C x G)

    def _apply(self, fn):  # type: ignore[override]
        """Override _apply to move our custom tensors with the model."""
        super()._apply(fn)

        # Move our custom tensors when model.to(device) is called
        if hasattr(self, "group_mask"):
            self.group_mask = fn(self.group_mask)
        if hasattr(self, "group_attr_indices"):
            self.group_attr_indices = fn(self.group_attr_indices)
        if hasattr(self, "group_sizes"):
            self.group_sizes = fn(self.group_sizes)
        if hasattr(self, "category_to_groups"):
            self.category_to_groups = fn(self.category_to_groups)

        return self

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        word_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        B = batch_size, L = seq_len, H = hidden_size, C = num_categories, A = num_attributes, W = max_words

        Args:
            input_ids: Token indices (B x L)
            attention_mask: Attention mask (B x L)
            word_mask: Binary mask indicating word boundaries, 1 = word start (B x L)

        Returns:
            cat_logits: Category logits (B x W x C)
            attr_logits: Attribute logits (B x W x A)
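
        Example (illustrative): for B=2, L=16 with 3 and 2 words respectively,
        W=3 and cat_logits has shape (2 x 3 x C), zero-padded past each
        sequence's word count.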
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Get RoBERTa outputs
        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]  # (B x L x H)

        # (B x L x H) -> (Wt x H)
        word_embeddings = self._aggregate_subword_tokens(hidden_states, word_mask, attention_mask)

        # (Wt x H) -> (Wt x C), (Wt x A)
        cat_logits, attr_logits = self.classifier(word_embeddings)

        # (Wt x C) -> (B x W x C), (Wt x A) -> (B x W x A)
        nwords = word_mask.sum(dim=-1)  # (B,)
        cat_logits = self._reshape_to_batch_format(cat_logits, nwords)
        attr_logits = self._reshape_to_batch_format(attr_logits, nwords)
        return cat_logits, attr_logits

    def _aggregate_subword_tokens(
        self, sequence_output: torch.Tensor, word_mask: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """
        Average subword tokens within each word to get word-level representations.
        Word indices are built per sequence; the averaging itself is vectorized with scatter operations.

        B = batch_size, L = seq_len, H = hidden_size, Wt = total_words

        Args:
            sequence_output: Subword token representations (B x L x H)
            word_mask: Binary mask where 1 indicates start of word (B x L)
            attention_mask: Attention mask to exclude padding tokens (B x L)

        Returns:
            word_features: Concatenated word-level features (Wt x H)
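
        Example (illustrative): with word_mask [1, 0, 1, 1] over four valid
        tokens, tokens 0-1 average into word 0, token 2 becomes word 1 and
        token 3 becomes word 2, so Wt = 3.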
        """
        batch_size, seq_len, hidden_size = sequence_output.shape
        device = sequence_output.device

        # Create word indices mapping each token to its word
        # Strategy: assign each token to a word ID, then use scatter operations to sum/average
        # Only tokens that belong to actual words get valid indices
        word_indices = torch.full_like(word_mask, -1, dtype=torch.long)  # (B x L)

        # Build word indices by finding word boundaries
        # Each token gets assigned to a word index (0, 1, 2, ...) within its sequence
        for b in range(batch_size):
            valid_mask = attention_mask[b].bool()  # (L,) - exclude padding tokens
            if not valid_mask.any():
                continue

            # Get word starts for this sequence
            seq_word_mask = word_mask[b, valid_mask]  # (Lv,) - only valid positions
            word_starts = seq_word_mask.nonzero(as_tuple=True)[0]  # (Ws,) - positions where words start

            if len(word_starts) == 0:
                continue

            # Assign each token to its word within this sequence
            seq_word_indices = torch.full((len(seq_word_mask),), -1, dtype=torch.long, device=device)

            for i, start_pos in enumerate(word_starts):
                # Find end position (next word start or end of sequence)
                if i + 1 < len(word_starts):
                    end_pos = word_starts[i + 1]  # Next word boundary
                else:
                    end_pos = len(seq_word_mask)  # End of sequence

                # All tokens from start_pos to end_pos belong to word i
                seq_word_indices[start_pos:end_pos] = i

            # Store the word indices for this sequence
            word_indices[b, valid_mask] = seq_word_indices

        # Create global word indices across the entire batch
        # Convert local word indices (0,1,2... per sequence) to global indices (0,1,2...total_words-1)
        # This allows us to use scatter operations across the entire batch
        words_per_seq = word_mask.sum(dim=-1)  # (B,) - words per sequence
        word_offset = torch.cat(
            [torch.zeros(1, device=device, dtype=torch.long), words_per_seq.cumsum(dim=0)[:-1]]
        )  # (B,) - cumulative word offsets

        # Add batch offsets to make global unique indices
        # E.g., if batch has [3,2] words: seq0=[0,1,2], seq1=[3,4]
        global_word_indices = word_indices + word_offset.unsqueeze(1)  # (B x L)

        # Flatten everything for scatter operations
        flat_output = sequence_output.view(-1, hidden_size)  # (B*L x H)
        flat_word_indices = global_word_indices.view(-1)  # (B*L,)
        flat_attention = attention_mask.view(-1)  # (B*L,)

        # Only use tokens that belong to words (not padding and not before first word)
        valid_word_tokens = (flat_attention.bool()) & (flat_word_indices >= 0)  # (B*L,)
        valid_output = flat_output[valid_word_tokens]  # (valid_word_tokens x H)
        valid_word_indices = flat_word_indices[valid_word_tokens]  # (valid_word_tokens,)

        total_words = words_per_seq.sum()
        if total_words == 0:
            return torch.empty(0, hidden_size, device=device)

        # Vectorized aggregation using scatter operations
        # Sum all token embeddings that belong to the same word
        word_sums = torch.zeros(total_words, hidden_size, device=device)  # (Wt x H)
        word_sums.scatter_add_(0, valid_word_indices.unsqueeze(1).expand(-1, hidden_size), valid_output)

        # Count how many tokens belong to each word (for averaging)
        word_counts = torch.zeros(total_words, device=device)  # (Wt,)
        word_counts.scatter_add_(0, valid_word_indices, torch.ones_like(valid_word_indices, dtype=torch.float))

        # Compute average: word_embedding = sum_of_tokens / count_of_tokens
        word_counts = torch.clamp(word_counts, min=1.0)  # Prevent division by zero
        word_features = word_sums / word_counts.unsqueeze(1)  # (Wt x H)

        return word_features

    def _reshape_to_batch_format(self, logits: torch.Tensor, nwords: torch.Tensor) -> torch.Tensor:
        """
        Reshape concatenated word predictions back to padded batch format.

        B = batch_size, W = max_words, Wt = total_words, K = num_classes

        Args:
            logits: Concatenated word predictions (Wt x K)
            nwords: Number of words per sequence (B,)

        Returns:
            batch_logits: Batched predictions (B x W x K)
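
        Example (illustrative): logits of shape (5 x K) with nwords = [3, 2]
        become a (2 x 3 x K) tensor whose second row is zero-padded.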
        """
        return pad_sequence(
            logits.split(nwords.tolist()),
            padding_value=0,
            batch_first=True,
        )

    def prepare_inputs(
        self, words: List[str], tokenizer, truncate: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Prepare inputs for a list of words.

        Args:
            words: List of words
            tokenizer: HuggingFace tokenizer
            truncate: Whether to truncate if too long

        Returns:
            Tuple of (input_ids, attention_mask, word_mask) without batch dimension.
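
        Example (illustrative; hypothetical subword split): ["Hún", "las"] might
        encode to [<s>, _Hún, _las, </s>] with word_mask [0, 1, 1, 0].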
        """
        # Encode with word boundary preservation
        encoding = tokenizer.encode_plus(
            words,
            return_tensors="pt",
            is_split_into_words=True,
            add_special_tokens=True,
            truncation=truncate,
            # The model was likely trained on much shorter sequences
            max_length=self.config.max_position_embeddings - 2,
        )

        input_ids = encoding["input_ids"].squeeze(0)  # (L,)
        attention_mask = torch.ones_like(input_ids)

        # Get word_ids and convert to word_mask
        word_ids = encoding.word_ids()
        word_mask = self._word_ids_to_word_mask(word_ids)

        # Debug logging to match fairseq model
        logger.debug(f"Encoded tokens: {input_ids}")  # (L,)
        logger.debug(f"Decoded tokens: {tokenizer.convert_ids_to_tokens(input_ids.tolist())}")
        logger.debug(f"Word IDs: {word_ids}")  # (L,)
        logger.debug(f"Word mask: {word_mask}")

        return input_ids, attention_mask, word_mask

    @torch.no_grad()
    def predict_labels(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor, word_mask: torch.Tensor
    ) -> List[List[Tuple[str, List[str]]]]:
        """
        Predict POS labels for input sequences.

        B = batch_size, L = seq_len

        Args:
            input_ids: Token indices (B x L)
            attention_mask: Attention mask (B x L)
            word_mask: Binary mask indicating word boundaries (B x L)

        Returns:
            List of sequences, each containing (category, [attributes]) per word
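
        Example (illustrative; hypothetical tag names):
            [[("no", ["fem", "sing", "nom"]), ("sl", ["act", "past"])]]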
        """
        # Time the forward pass
        start_time = time.perf_counter()
        cat_logits, attr_logits = self.forward(input_ids=input_ids, attention_mask=attention_mask, word_mask=word_mask)
        forward_time = time.perf_counter() - start_time
        logger.debug(f"Forward pass took {forward_time:.4f} seconds")

        # Time the logits to labels conversion
        start_time = time.perf_counter()
        result = self._logits_to_labels(cat_logits, attr_logits, word_mask)
        logits_to_labels_time = time.perf_counter() - start_time
        logger.debug(f"Logits to labels conversion took {logits_to_labels_time:.4f} seconds")

        return result

    def predict_labels_from_text(
        self, sentences: List[List[str]], tokenizer, truncate: bool = False
    ) -> List[List[Tuple[str, List[str]]]]:
        """
        Predict POS labels from list of word lists.

        Args:
            sentences: List of sentences, each a list of words
            tokenizer: HuggingFace tokenizer
            truncate: Whether to truncate if too long

        Returns:
            List of sequences, each containing (category, [attributes]) per word
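
        Example (illustrative): two sentences of 4 and 2 words are padded to a
        common token length before a single batched forward pass.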
        """
        # Use prepare_inputs for each sentence and batch them
        all_input_ids = []
        all_attention_masks = []
        all_word_masks = []

        for words in sentences:
            input_ids, attention_mask, word_mask = self.prepare_inputs(words, tokenizer, truncate)
            all_input_ids.append(input_ids)
            all_attention_masks.append(attention_mask)
            all_word_masks.append(word_mask)

        # Pad sequences to same length
        batch_input_ids = pad_sequence(all_input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
        batch_attention_mask = pad_sequence(all_attention_masks, batch_first=True, padding_value=0)
        batch_word_mask = pad_sequence(all_word_masks, batch_first=True, padding_value=0)

        return self.predict_labels(batch_input_ids, batch_attention_mask, batch_word_mask)

    def convert_labels_to_ifd(self, predictions: List[List[Tuple[str, List[str]]]]) -> List[List[str]]:
        """
        Convert model predictions to IFD format labels.

        Args:
            predictions: List of sequences, each containing (category, [attributes]) per word

        Returns:
            List of IFD format labels per sentence
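
        Example (illustrative; hypothetical tags): ("no", ["fem", "sing", "nom"]) -> "nven"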
        """
        # Time the IFD conversion
        start_time = time.perf_counter()
        ifd_labels = []
        for sentence_predictions in predictions:
            ifd_labels.append(convert_predictions_to_ifd(sentence_predictions))
        ifd_conversion_time = time.perf_counter() - start_time
        logger.debug(f"IFD conversion took {ifd_conversion_time:.4f} seconds")
        return ifd_labels

    def predict_ifd_labels_from_text(
        self, sentences: List[List[str]], tokenizer, truncate: bool = False
    ) -> List[List[str]]:
        """
        Predict IFD format labels from list of word lists.

        B = batch_size, Ws = seq_words

        Args:
            sentences: List of sentences, each a list of words
            tokenizer: HuggingFace tokenizer
            truncate: Whether to truncate if too long

        Returns:
            ifd_predictions: List of IFD labels per sentence (B x Ws)
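
        Example (illustrative; hypothetical IFD tags): [["fpven", "sfg3eþ", "nveog"]]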
        """
        # Get model predictions in (category, [attributes]) format
        predictions = self.predict_labels_from_text(sentences, tokenizer, truncate)
        return self.convert_labels_to_ifd(predictions)

    def _word_ids_to_word_mask(self, word_ids: List[Optional[int]]) -> torch.Tensor:
        """
        Convert word_ids to binary mask indicating word boundaries.

        L = seq_len

        Args:
            word_ids: Word ids for a single sequence (None for special tokens)

        Returns:
            word_mask: Binary tensor where 1 indicates start of word (L,)
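
        Example (illustrative):
            word_ids [None, 0, 0, 1, 2, None] -> word_mask [0, 1, 0, 1, 1, 0]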
        """
        word_mask = torch.zeros(len(word_ids), dtype=torch.long)  # (L,)

        prev_word_id = None
        for token_idx, word_id in enumerate(word_ids):
            # Skip None values (special tokens and padding)
            if word_id is not None and word_id != prev_word_id:
                word_mask[token_idx] = 1  # Mark word start
            # Only update prev_word_id for valid (non-None) word_ids
            if word_id is not None:
                prev_word_id = word_id

        # Debug logging to match fairseq model
        logger.debug(f"Word mask: {word_mask}")

        return word_mask

    def _logits_to_labels(
        self, cat_logits: torch.Tensor, attr_logits: torch.Tensor, word_mask: torch.Tensor
    ) -> List[List[Tuple[str, List[str]]]]:
        """
        Convert logits to human-readable labels using vectorized operations.

        Key optimizations:
        1. Flatten batch dimension to process all words simultaneously
        2. Vectorized group processing across all words
        3. Defer string conversion to the very end
        4. Minimize Python loops and tensor-CPU transfers

        B = batch_size, W = max_words, C = num_categories, A = num_attributes, G = num_groups
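
        Example (illustrative): a word whose predicted category licenses the
        "gender" and "number" groups gets one attribute per group: argmax
        within groups of two or more attributes, a sigmoid > 0.5 decision
        for singleton groups.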
        """
        device = cat_logits.device
        bsz, max_words = cat_logits.shape[:2]
        nwords = word_mask.sum(-1)  # (B,)
        schema = self.config.label_schema

        # Step 1: Create valid word mask and flatten batch dimension
        # (B x W) -> (total_words,) to process all words simultaneously
        batch_word_mask = torch.zeros(bsz, max_words, dtype=torch.bool, device=device)
        for b in range(bsz):
            if nwords[b] > 0:
                batch_word_mask[b, : nwords[b]] = True

        valid_positions = batch_word_mask.flatten().nonzero(as_tuple=True)[0]  # (total_words,)
        total_words = len(valid_positions)

        if total_words == 0:
            return [[] for _ in range(bsz)]

        # Step 2: Vectorized category prediction for all valid words
        flat_cat_logits = cat_logits.view(-1, cat_logits.size(-1))  # (B*W x C)
        flat_attr_logits = attr_logits.view(-1, attr_logits.size(-1))  # (B*W x A)

        # Get categories for all valid words: (total_words,)
        all_cat_indices = flat_cat_logits[valid_positions].argmax(dim=-1)

        # Step 3: Vectorized group validity for all words: (total_words x G)
        all_valid_groups = self.category_to_groups[all_cat_indices]

        # Step 4: Collect attributes using vectorized group processing
        word_to_attrs = {}  # word_idx -> list of attr_indices

        # Process each group across all words simultaneously
        for group_idx in range(self.group_sizes.size(0)):
            group_size = self.group_sizes[group_idx].item()
            if group_size == 0:
                continue

            # Find words that have this group valid: (words_with_group,)
            words_with_group = all_valid_groups[:, group_idx].nonzero(as_tuple=True)[0]
            if len(words_with_group) == 0:
                continue

            # Get attribute indices for this group
            group_attr_indices = self.group_attr_indices[group_idx, :group_size]
            valid_attr_indices = group_attr_indices[group_attr_indices >= 0]
            if len(valid_attr_indices) == 0:
                continue

            # Get logits for all words that need this group: (words_with_group x group_size)
            word_positions = valid_positions[words_with_group]
            group_logits = flat_attr_logits[word_positions][:, valid_attr_indices]

            if len(valid_attr_indices) == 1:
                # Binary decision for all words simultaneously: (words_with_group,)
                decisions = group_logits.sigmoid().squeeze(-1) > 0.5
                selected_words = words_with_group[decisions]
                attr_idx = valid_attr_indices[0].item()

                for word_idx in selected_words:
                    word_idx_item = word_idx.item()
                    if word_idx_item not in word_to_attrs:
                        word_to_attrs[word_idx_item] = []
                    word_to_attrs[word_idx_item].append(attr_idx)
            else:
                # Multi-class decision for all words: (words_with_group,)
                best_indices = group_logits.argmax(dim=-1)

                for i, word_idx in enumerate(words_with_group):
                    attr_idx = valid_attr_indices[best_indices[i]].item()
                    word_idx_item = word_idx.item()
                    if word_idx_item not in word_to_attrs:
                        word_to_attrs[word_idx_item] = []
                    word_to_attrs[word_idx_item].append(attr_idx)

        # Step 5: Reconstruct batch structure and convert to strings (deferred)
        predictions = []
        word_counter = 0

        for seq_idx in range(bsz):
            seq_nwords = nwords[seq_idx].item()
            seq_predictions = []

            for _ in range(seq_nwords):
                # Get category (string conversion deferred)
                cat_idx = all_cat_indices[word_counter].item()
                cat_name = schema.label_categories[cat_idx]

                # Get attributes (string conversion deferred)
                attributes = []
                if word_counter in word_to_attrs:
                    attr_indices = word_to_attrs[word_counter]
                    attributes = [schema.labels[idx] for idx in attr_indices]

                # Apply post-processing rules
                if len(attributes) == 1 and attributes[0] == "pos":
                    # This label is used as a default during training but is implied in the MIM format
                    attributes = []
                elif cat_name == "sl" and "act" in attributes:
                    # Person, number and tense are not shown for "sl" + "act" in the MIM format
                    attributes = [attr for attr in attributes if attr not in ["1", "sing", "pres"]]

                seq_predictions.append((cat_name, attributes))
                word_counter += 1

            predictions.append(seq_predictions)

        return predictions


AutoConfig.register("icebert-pos", IceBertPosConfig)
AutoModel.register(IceBertPosConfig, IceBertPosForTokenClassification)
IceBertPosConfig.register_for_auto_class()
IceBertPosForTokenClassification.register_for_auto_class("AutoModel")
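

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch). The checkpoint path is a
    # placeholder, not a published model id; because this module uses relative
    # imports, run it as `python -m <package>.<module>`.
    from transformers import AutoTokenizer

    logging.basicConfig(level=logging.INFO)
    _ckpt = "path/to/icebert-pos"  # hypothetical local checkpoint
    _tokenizer = AutoTokenizer.from_pretrained(_ckpt)
    _model = AutoModel.from_pretrained(_ckpt, trust_remote_code=True)
    _model.eval()
    print(_model.predict_ifd_labels_from_text([["Hún", "las", "bókina"]], _tokenizer))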