File size: 6,006 Bytes
c5081c8
 
 
 
0cdb887
c5081c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0cdb887
c5081c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
from transformers import DebertaV2TokenizerFast
import torch

from multi_head_model import MultiHeadModel
from utils import get_torch_device


class MultiHeadPredictor:
    """Token-level multi-head classifier built on a fine-tuned DeBERTa-v2 backbone.

    Loads the tokenizer and `MultiHeadModel` once at construction, moves the
    model to the best available device, and keeps it in eval mode. Each call to
    :meth:`predict` tokenizes a whitespace-split text with a 512-token sliding
    window (stride 128) and returns one label per original token per head.
    """

    def __init__(self, model_name_or_path: str):
        """Load tokenizer + model from `model_name_or_path` and prepare for inference.

        :param model_name_or_path: HF-style model directory or hub identifier.
        """
        # add_prefix_space=True is required when feeding pre-split words
        # (is_split_into_words=True) to a byte/sentencepiece tokenizer.
        self.tokenizer = DebertaV2TokenizerFast.from_pretrained(model_name_or_path, add_prefix_space=True)
        self.model = MultiHeadModel.from_pretrained(model_name_or_path)
        # {head_name: {label_id: label_str}} — persisted on the model config
        # by MultiHeadModel (assumed from usage below; confirm against model code).
        self.id2label = self.model.config.label_maps

        self.device = get_torch_device()
        self.model.to(self.device)
        self.model.eval()

    def predict(self, text: str) -> dict:
        """
        Perform multi-headed token classification on a single piece of text.

        :param text: The raw text string (tokenized by whitespace split).

        :return: A dict containing:
                 - "text": the input string,
                 - "tokens": the whitespace-split tokens,
                 - one entry per head: {head_name: [label_for_each_token]},
                   defaulting to "O" for tokens never covered by any chunk.
        """
        raw_tokens = text.split()

        # Single-example "batch" that replicates the training chunking logic.
        # With return_overflowing_tokens=True a long input yields several
        # overlapping 512-token chunks (stride 128).
        encoded = self.tokenizer(
            raw_tokens,
            is_split_into_words=True,
            max_length=512,
            stride=128,
            truncation=True,
            return_overflowing_tokens=True,
            return_offsets_mapping=False,
            padding="max_length"
        )

        # Per-chunk predicted label IDs ({head: [id, ...]}) and the matching
        # word_ids lists used to map subword positions back to token indices.
        chunk_preds = []
        chunk_word_ids = []

        # Run each chunk through the model as a batch of size 1.
        # Hoist no_grad around the whole loop — no gradients are ever needed here.
        with torch.no_grad():
            for i in range(len(encoded["input_ids"])):
                input_ids_tensor = torch.tensor([encoded["input_ids"][i]], dtype=torch.long).to(self.device)
                attention_mask_tensor = torch.tensor([encoded["attention_mask"][i]], dtype=torch.long).to(self.device)

                # Without labels, the model forward returns a dict of logits:
                # {head_name: Tensor of shape (1, seq_len, num_labels)}.
                logits_dict = self.model(
                    input_ids=input_ids_tensor,
                    attention_mask=attention_mask_tensor
                )

                # Argmax each head's logits down to per-position label IDs.
                pred_ids_dict = {}
                for head_name, logits in logits_dict.items():
                    preds = torch.argmax(logits, dim=-1)  # (1, seq_len)
                    pred_ids_dict[head_name] = preds[0].cpu().numpy().tolist()

                chunk_preds.append(pred_ids_dict)

                # word_ids(batch_index=i) maps each subword position to the index
                # of the original token it came from (None for special/pad tokens).
                # Only available on fast-tokenizer batched encodings.
                chunk_word_ids.append(encoded.word_ids(batch_index=i))

        # Reconcile overlapping chunks into one token-level label sequence.
        # Strategy: first-come-first-served — the earliest chunk (and the first
        # subword within it) that covers a token determines that token's label.
        final_pred_labels = {**{
            "text": text,
            "tokens": raw_tokens,
        }, **{
            head: ["O"] * len(raw_tokens)  # "O" placeholder for uncovered tokens
            for head in self.id2label.keys()
        }}

        # Tokens already assigned by an earlier chunk / earlier subword.
        assigned_tokens = set()

        for i, pred_dict in enumerate(chunk_preds):
            w_ids = chunk_word_ids[i]
            for pos, w_id in enumerate(w_ids):
                if w_id is None:
                    # Special token (CLS, SEP) or padding — no source token.
                    continue
                if w_id in assigned_tokens:
                    # Already labeled from a previous chunk or earlier subword.
                    continue

                # First subword of a not-yet-seen token: record each head's label.
                for head_name, pred_ids in pred_dict.items():
                    label_id = pred_ids[pos]
                    label_str = self.id2label[head_name][label_id]
                    final_pred_labels[head_name][w_id] = label_str

                assigned_tokens.add(w_id)

        return final_pred_labels


if __name__ == "__main__":
    # Quick smoke test: load the exported checkpoint and print per-head labels
    # for a single hand-picked example.
    predictor = MultiHeadPredictor("./o3-mini_20250218_final")

    samples = [
        "How to convince my parents to let me get a Ball python?",
    ]
    for sample in samples:
        predictions = predictor.predict(sample)
        for head_name, labels in predictions.items():
            print(f"{head_name}: {labels}")