File size: 5,641 Bytes
714cf46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from typing import List, Dict, Tuple, Optional, Union, Sequence


class CharTokenizer:
    """
    Character-level tokenizer over a fixed alphabet.

    Maps every symbol in ``alphabet`` to consecutive integers starting at 1.
    Index 0 is reserved for padding; the special ``<cls>`` and ``<eos>``
    tokens take the two ids immediately after the alphabet, so
    ``vocab_size == len(alphabet) + 3``.
    """
    # Class-level defaults; cls/eos ids are finalized per-instance in __init__.
    pad_token_id: int = 0
    pad_token: str = "<pad>"
    cls_token_id: int = -1  # set in __init__ to len(alphabet) + 1
    cls_token: str = "<cls>"
    eos_token_id: int = -2  # set in __init__ to len(alphabet) + 2
    eos_token: str = "<eos>"

    def __init__(self, alphabet: str) -> None:
        """Build the char<->id tables for `alphabet` plus the special tokens."""
        self.alphabet: str = alphabet
        self.char2id: Dict[str, int] = {ch: i + 1 for i, ch in enumerate(alphabet)}
        self.id2char: Dict[int, str] = {i + 1: ch for i, ch in enumerate(alphabet)}

        # Special tokens get the next two consecutive ids after the alphabet.
        next_id = len(alphabet) + 1
        self.cls_token_id = next_id
        self.char2id[self.cls_token] = self.cls_token_id
        self.id2char[self.cls_token_id] = self.cls_token

        next_id += 1
        self.eos_token_id = next_id
        self.char2id[self.eos_token] = self.eos_token_id
        self.id2char[self.eos_token_id] = self.eos_token

        self.vocab_size: int = len(alphabet) + 3  # +3 for pad, cls, and eos

    def encode(self, seq: str) -> torch.Tensor:
        """Convert a single string to a *1-D LongTensor* of ids (no padding).

        The result is ``[<cls>, chars..., <eos>]``.

        Raises:
            ValueError: if `seq` contains a symbol outside the alphabet.
        """
        try:
            # Add cls at the beginning and eos at the end
            ids = [self.cls_token_id] + [self.char2id[ch] for ch in seq] + [self.eos_token_id]
        except KeyError as e:  # unknown character
            raise ValueError(f"Unknown symbol {e} for this alphabet") from None
        return torch.tensor(ids, dtype=torch.long)

    def decode(self, ids: Sequence[int], skip_pad: bool = True, skip_special: bool = True) -> str:
        """Convert a sequence of ids back to a string.

        Args:
            ids: iterable of token ids (ints or 0-d tensors).
            skip_pad: drop padding positions from the output.
            skip_special: drop ``<cls>``/``<eos>`` from the output.
        """
        chars: List[str] = []
        for idx in ids:
            idx = int(idx)
            if skip_pad and idx == self.pad_token_id:
                continue
            if skip_special and idx in (self.cls_token_id, self.eos_token_id):
                continue
            if idx == self.pad_token_id:
                # Fix: pad id 0 is not in id2char, so a kept pad used to
                # render as the generic "?" placeholder instead of "<pad>".
                chars.append(self.pad_token)
            else:
                chars.append(self.id2char.get(idx, "?"))
        return "".join(chars)

    def __call__(
        self,
        sequences: Union[str, List[str], Tuple[str, ...]],
        return_tensors: str = "pt",
        **kwargs,  # accepted (and ignored) for HuggingFace-style call compatibility
    ) -> Dict[str, torch.Tensor]:
        """Tokenize one string or a batch, right-padding to the longest item.

        Returns:
            Dict with ``input_ids`` [B, L] and ``attention_mask`` [B, L]
            (1 for real tokens, 0 for padding).

        Raises:
            ValueError: if `return_tensors` is not ``"pt"``.
        """
        if isinstance(sequences, str):
            sequences = [sequences]

        # Tokenise each sequence
        encoded: List[torch.Tensor] = [self.encode(seq) for seq in sequences]

        # Pad on the **right** with `pad_token_id`
        input_ids: torch.Tensor = pad_sequence(
            encoded, batch_first=True, padding_value=self.pad_token_id
        )

        # Attention mask: 1 for real tokens, 0 for pad
        attention_mask = (input_ids != self.pad_token_id).to(torch.long)

        if return_tensors == "pt":
            return {"input_ids": input_ids, "attention_mask": attention_mask}
        raise ValueError(f"Unsupported tensor type: {return_tensors}")


class OneHotModel(nn.Module):
    """
    Parameter-free one-hot encoder.

    HuggingFace-style forward signature:
        forward(input_ids, attention_mask=None) -> Tensor[B, L, vocab]
    """

    def __init__(self, vocab_size: int) -> None:
        super().__init__()
        self.vocab_size = vocab_size

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Expand token ids to float one-hot vectors, blanking padded slots."""
        encoded = F.one_hot(input_ids, num_classes=self.vocab_size).float()
        if attention_mask is None:
            return encoded
        # Broadcast the mask over the vocab dimension so pad rows become zero.
        return encoded * attention_mask.unsqueeze(-1)


# Alphabets for the supported presets; each character is one token.
# NOTE(review): AA_SET appears to cover the 20 standard amino acids plus
# ambiguity/rare codes and '*' — confirm against the upstream data spec.
AA_SET = 'LAGVSERTIPDKQNFYMHWCXBUOZ*'
# NOTE(review): presumably one printable symbol per codon — verify the
# symbol-to-codon mapping against whatever encoder produces these strings.
CODON_SET = 'aA@bB#$%rRnNdDcCeEqQ^G&ghHiIj+MmlJLkK(fFpPoO=szZwSXTtxWyYuvUV]})'
DNA_SET = 'ATCG'
RNA_SET = 'AUCG'

# Preset name -> alphabet, consumed by build_one_hot_model / get_one_hot_tokenizer.
ALPHABET_DICT = {
    'OneHot-Protein': AA_SET,
    'OneHot-DNA': DNA_SET,
    'OneHot-RNA': RNA_SET,
    'OneHot-Codon': CODON_SET
}


def build_one_hot_model(preset: str = 'OneHot-Protein', model_path: Optional[str] = None, **kwargs) -> Tuple[OneHotModel, CharTokenizer]:
    """Build a (model, tokenizer) pair for one of the preset alphabets.

    Args:
        preset: key into ``ALPHABET_DICT`` (e.g. ``'OneHot-Protein'``).
        model_path: unused; presumably kept for signature parity with other
            model builders — TODO confirm.
        **kwargs: ignored; accepted for call-site compatibility.

    Raises:
        KeyError: if `preset` is not a known alphabet.
    """
    alphabet = str(ALPHABET_DICT[preset])
    tokenizer = CharTokenizer(alphabet)
    model = OneHotModel(tokenizer.vocab_size)
    return model, tokenizer


def get_one_hot_tokenizer(preset: str, model_path: Optional[str] = None) -> CharTokenizer:
    """Return a CharTokenizer for `preset`; `model_path` is unused.

    Raises:
        KeyError: if `preset` is not in ``ALPHABET_DICT``.
    """
    return CharTokenizer(ALPHABET_DICT[preset])


if __name__ == '__main__':
    # Smoke-test demo: py -m base_models.one_hot
    model, tokenizer = build_one_hot_model()          # default: protein

    # Both strings are valid under AA_SET ("ACGT" chars are also amino acids).
    sequences = ["ACGT", "PROTEIN"]
    batch = tokenizer(sequences)

    # Longest item is "PROTEIN" (7 chars) + <cls>/<eos> -> padded length 9.
    one_hot = model(**batch)                          # [2, 9, vocab]
    print(f"input_ids shape      : {batch['input_ids'].shape}")
    print(f"attention_mask shape : {batch['attention_mask'].shape}")
    print(f"one-hot shape        : {one_hot.shape}")

    print("\n--- first sequence, first 5 one-hot rows (trimmed) ---")
    print(one_hot[0, :5])

    # round-trip demo: decode strips pad/cls/eos by default
    decoded = tokenizer.decode(batch["input_ids"][1])
    print("\nDecoded second sequence:", decoded)

    # Show tokens with special tokens
    print("\nInput with special tokens visible:")
    decoded_with_special = tokenizer.decode(batch["input_ids"][0], skip_special=False)
    print(decoded_with_special)

    # Show token IDs
    print("\nToken IDs for first sequence:")
    print(batch["input_ids"][0])