import warnings
from typing import NamedTuple

import torch
import torch.nn as nn

PromptType = str | list[str]

class ClassTokenizerOutput(NamedTuple):
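    """Padded class-ID tensor and its 0/1 attention mask, both (batch, max_length)."""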
class_ids: torch.Tensor
attention_mask: torch.Tensor


class ClassTokenizer:
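    """Tokenizes whitespace-separated class-name prompts into padded ID tensors.

    IDs 0..len(label2id) - 1 are class labels; len(label2id) is the padding ID.
    """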
def __init__(
self,
label2id: dict[str, int],
splitter: str = " ",
) -> None:
self.label2id = label2id
self.id2label = {v: k for k, v in label2id.items()}
self.splitter = splitter
self.pad_token_id = len(label2id)
        assert all(label_id < len(label2id) for label_id in label2id.values()), (
            "All label IDs must be less than the number of classes, "
            "since len(label2id) is reserved as the padding ID."
        )

def normalize_prompts(
self,
class_names: PromptType,
) -> list[str]:
_class_names: list[str] = (
class_names if isinstance(class_names, list) else [class_names]
)
return _class_names

    def tokenize(
self,
prompts: PromptType,
max_length: int = 32,
) -> ClassTokenizerOutput:
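        """Split each prompt on `self.splitter`, map every label to its ID, and
        pad or truncate each sequence to `max_length` with a matching 0/1 mask.
        Unknown labels are skipped with a warning.
        """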
# 1. Normalize class names
_prompts = self.normalize_prompts(prompts)
        # 2. Convert each prompt to a list of label IDs
        class_ids = []
        for text in _prompts:
            ids = []
            for label in text.split(self.splitter):
                label = label.strip()
                if label == "":
                    continue
                label_id = self.label2id.get(label)
                if label_id is not None:  # 0 is a valid ID, so test against None
                    ids.append(label_id)
                else:
                    warnings.warn(f"Label '{label}' not found in label2id mapping.")
            class_ids.append(ids)
        # 3. Pad (or truncate) every sequence to max_length
        padded_class_ids = []
        padded_masks = []
        for ids in class_ids:
            if len(ids) < max_length:
                mask = [1] * len(ids) + [0] * (max_length - len(ids))
                ids = ids + [self.pad_token_id] * (max_length - len(ids))
            else:
                mask = [1] * max_length
                ids = ids[:max_length]
            padded_class_ids.append(ids)
            padded_masks.append(mask)
return ClassTokenizerOutput(
class_ids=torch.tensor(padded_class_ids, dtype=torch.long),
attention_mask=torch.tensor(padded_masks, dtype=torch.long),
)


class ClassEncoderOutput(NamedTuple):
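    """Class embeddings of shape (batch, length, dim) and their attention mask."""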
embeddings: torch.Tensor
attention_mask: torch.Tensor


class ClassEncoder(nn.Module):
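    """Embedding table over class IDs, with one extra zero row reserved for padding."""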
def __init__(
self,
label2id: dict[str, int],
embedding_dim: int,
    ) -> None:
super().__init__()
self.num_classes = len(label2id)
self.pad_token_id = self.num_classes # padding idx
self.embedding = nn.Embedding(
self.num_classes + 1, # +1 for padding idx
embedding_dim,
padding_idx=self.num_classes,
)
self.tokenizer = ClassTokenizer(label2id)

    def initialize_weights(self) -> None:
        nn.init.normal_(self.embedding.weight, mean=0.0, std=0.02)
        # normal_ also fills the padding row; gradients never reach padding_idx,
        # so zero it back out explicitly
        with torch.no_grad():
            self.embedding.weight[self.pad_token_id].fill_(0)

    def encode_prompts(
self,
prompts: PromptType,
max_token_length: int = 32,
    ) -> ClassEncoderOutput:
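        """Tokenize `prompts` and look up their embeddings.

        Returns embeddings of shape (batch, max_token_length, embedding_dim)
        together with the attention mask from tokenization.
        """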
# 1. Tokenize prompts
class_ids, attention_mask = self.tokenizer.tokenize(
prompts,
max_length=max_token_length,
)
        # 2. Look up embeddings on the embedding table's device, and keep the
        # attention mask on the same device for downstream use
        device = self.embedding.weight.device
        embeddings = self.embedding(class_ids.to(device))
        return ClassEncoderOutput(
            embeddings=embeddings,
            attention_mask=attention_mask.to(device),
        )
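

# Minimal usage sketch: the label names and embedding size below are
# illustrative assumptions, not values from this module.
if __name__ == "__main__":
    label2id = {"cat": 0, "dog": 1, "bird": 2}
    encoder = ClassEncoder(label2id, embedding_dim=16)
    encoder.initialize_weights()

    out = encoder.encode_prompts(["cat dog", "bird"], max_token_length=4)
    print(out.embeddings.shape)   # torch.Size([2, 4, 16])
    print(out.attention_mask)     # tensor([[1, 1, 0, 0], [1, 0, 0, 0]])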