import warnings
from typing import NamedTuple

import torch
import torch.nn as nn

PromptType = str | list[str]


class ClassTokenizerOutput(NamedTuple):
    class_ids: torch.Tensor
    attention_mask: torch.Tensor


class ClassTokenizer:
    def __init__(
        self,
        label2id: dict[str, int],
        splitter: str = " ",
    ) -> None:
        self.label2id = label2id
        self.id2label = {v: k for k, v in label2id.items()}
        self.splitter = splitter

        self.pad_token_id = len(label2id)

        assert all(0 <= label_id < len(label2id) for label_id in label2id.values()), (
            "All label IDs must lie in [0, num_classes) so pad_token_id stays unique."
        )

    def normalize_prompts(
        self,
        class_names: PromptType,
    ) -> list[str]:
        _class_names: list[str] = (
            class_names if isinstance(class_names, list) else [class_names]
        )
        return _class_names

    def tokenize(
        self,
        prompts: PromptType,
        max_length: int = 32,
    ) -> ClassTokenizerOutput:
        # 1. Normalize class names
        _prompts = self.normalize_prompts(prompts)

        # 2. Convert to IDs
        class_ids = []
        for text in _prompts:
            ids = []
            for token in text.split(self.splitter):
                label = token.strip()
                if label == "":
                    continue
                label_id = self.label2id.get(label)
                if label_id is not None:  # compare with None, since 0 is a valid ID
                    ids.append(label_id)
                else:
                    warnings.warn(f"Label '{label}' not found in label2id mapping.")
            class_ids.append(ids)

        # 3. Pad to max_length
        padded_class_ids = []
        padded_masks = []

        for ids in class_ids:
            if len(ids) < max_length:
                mask = [1] * len(ids) + [0] * (max_length - len(ids))
                ids = ids + [self.pad_token_id] * (max_length - len(ids))  # padding idx
            else:
                mask = [1] * max_length
                ids = ids[:max_length]

            padded_class_ids.append(ids)
            padded_masks.append(mask)

        return ClassTokenizerOutput(
            class_ids=torch.tensor(padded_class_ids, dtype=torch.long),
            attention_mask=torch.tensor(padded_masks, dtype=torch.long),
        )

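# Worked example (a sketch with illustrative labels, not part of the original
# module): with label2id = {"cat": 0, "dog": 1}, pad_token_id becomes 2, so
#   tok = ClassTokenizer({"cat": 0, "dog": 1})
#   out = tok.tokenize("cat dog", max_length=4)
#   out.class_ids      -> tensor([[0, 1, 2, 2]])
#   out.attention_mask -> tensor([[1, 1, 0, 0]])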

class ClassEncoderOutput(NamedTuple):
    embeddings: torch.Tensor
    attention_mask: torch.Tensor


class ClassEncoder(nn.Module):
    def __init__(
        self,
        label2id: dict[str, int],
        embedding_dim: int,
    ) -> None:
        super().__init__()

        self.num_classes = len(label2id)

        self.pad_token_id = self.num_classes  # padding idx

        self.embedding = nn.Embedding(
            self.num_classes + 1,  # +1 for padding idx
            embedding_dim,
            padding_idx=self.num_classes,
        )

        self.tokenizer = ClassTokenizer(label2id)

    def initialize_weights(self) -> None:
        nn.init.normal_(self.embedding.weight, mean=0.0, std=0.02)
        with torch.no_grad():  # normal_ overwrote the padding row; re-zero it
            self.embedding.weight[self.embedding.padding_idx].zero_()

    def encode_prompts(
        self,
        prompts: PromptType,
        max_token_length: int = 32,
    ) -> ClassEncoderOutput:
        # 1. Tokenize prompts
        class_ids, attention_mask = self.tokenizer.tokenize(
            prompts,
            max_length=max_token_length,
        )

        # 2. Look up embeddings
        embeddings = self.embedding(class_ids.to(self.embedding.weight.device))

        return ClassEncoderOutput(
            embeddings=embeddings,
            attention_mask=attention_mask,
        )