File size: 5,641 Bytes
714cf46 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 | import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from typing import List, Dict, Tuple, Optional, Union, Sequence
class CharTokenizer:
    """
    Character-level tokenizer with ``<pad>``, ``<cls>`` and ``<eos>`` tokens.

    Every symbol in ``alphabet`` is mapped to the smallest consecutive
    integers starting at 1; index 0 is reserved for padding.  The two special
    tokens take the next two ids after the alphabet, giving the layout::

        0                      <pad>
        1 .. len(alphabet)     alphabet symbols, in order
        len(alphabet) + 1      <cls>
        len(alphabet) + 2      <eos>

    so ``vocab_size == len(alphabet) + 3`` and every id this tokenizer emits
    lies in ``range(vocab_size)``.  ``alphabet`` is expected to contain no
    repeated characters (a repeat would leave an id that is never emitted).
    """

    pad_token_id: int = 0
    pad_token: str = "<pad>"
    cls_token_id: int = -1  # placeholder; real value assigned in __init__
    cls_token: str = "<cls>"
    eos_token_id: int = -2  # placeholder; real value assigned in __init__
    eos_token: str = "<eos>"

    def __init__(self, alphabet: str) -> None:
        self.alphabet: str = alphabet
        self.char2id: Dict[str, int] = {ch: i for i, ch in enumerate(alphabet, start=1)}
        self.id2char: Dict[int, str] = {i: ch for i, ch in enumerate(alphabet, start=1)}

        # Special tokens take the two ids right after the alphabet.  They are
        # multi-character strings, so they can never collide with the single
        # characters looked up in ``encode``.
        self.cls_token_id = len(alphabet) + 1
        self.eos_token_id = len(alphabet) + 2
        for token, token_id in (
            (self.cls_token, self.cls_token_id),
            (self.eos_token, self.eos_token_id),
        ):
            self.char2id[token] = token_id
            self.id2char[token_id] = token

        self.vocab_size: int = len(alphabet) + 3  # +3 for pad, cls and eos

    def encode(self, seq: str, add_special_tokens: bool = True) -> torch.Tensor:
        """
        Convert a single string to a *1-D LongTensor* of ids (no padding).

        Args:
            seq: string whose characters must all belong to ``alphabet``.
            add_special_tokens: when True (default) the ids are wrapped as
                ``[cls, ..., eos]``; when False only the character ids are
                returned.

        Raises:
            ValueError: if ``seq`` contains a character outside the alphabet.
        """
        try:
            ids = [self.char2id[ch] for ch in seq]
        except KeyError as e:  # unknown character
            raise ValueError(f"Unknown symbol {e} for this alphabet") from None
        if add_special_tokens:
            ids = [self.cls_token_id, *ids, self.eos_token_id]
        return torch.tensor(ids, dtype=torch.long)

    def decode(self, ids: Sequence[int], skip_pad: bool = True, skip_special: bool = True) -> str:
        """
        Convert a 1-D sequence of ids (list or tensor) back to a string.

        Args:
            ids: token ids; a tensor is converted to a Python list first.
            skip_pad: drop padding ids; when False they are rendered as
                ``<pad>``.
            skip_special: drop ``<cls>`` / ``<eos>``; when False they are
                rendered as their token strings.

        Ids that are not in the vocabulary are rendered as ``"?"``.
        """
        # Convert once up front instead of comparing 0-d tensors element-wise.
        if isinstance(ids, torch.Tensor):
            ids = ids.tolist()
        special_ids = (self.cls_token_id, self.eos_token_id)
        chars: List[str] = []
        for idx in ids:
            tok = int(idx)
            if tok == self.pad_token_id:
                # Pad is a known token; show it by name rather than as "?".
                if not skip_pad:
                    chars.append(self.pad_token)
                continue
            if skip_special and tok in special_ids:
                continue
            chars.append(self.id2char.get(tok, "?"))
        return "".join(chars)

    def __call__(
        self,
        sequences: Union[str, List[str], Tuple[str, ...]],
        return_tensors: str = "pt",
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """
        Tokenise one string or a batch of strings, right-padded with ``pad_token_id``.

        Args:
            sequences: a single string or a list/tuple of strings.
            return_tensors: only ``"pt"`` is supported.
            **kwargs: HuggingFace-style options.  ``add_special_tokens``
                (default True) is honoured; any other option is ignored.

        Returns:
            ``{"input_ids": LongTensor[B, L], "attention_mask": LongTensor[B, L]}``
            where the mask is 1 for real tokens and 0 for padding.

        Raises:
            ValueError: for an unsupported ``return_tensors`` value or an
                unknown character.
        """
        # Validate before doing any work.
        if return_tensors != "pt":
            raise ValueError(f"Unsupported tensor type: {return_tensors}")
        if isinstance(sequences, str):
            sequences = [sequences]
        add_special_tokens = bool(kwargs.get("add_special_tokens", True))

        encoded: List[torch.Tensor] = [
            self.encode(seq, add_special_tokens=add_special_tokens) for seq in sequences
        ]
        # Pad on the **right** with `pad_token_id`.
        input_ids: torch.Tensor = pad_sequence(
            encoded, batch_first=True, padding_value=self.pad_token_id
        )
        # Real tokens never use id 0, so "not pad" identifies them exactly.
        attention_mask = (input_ids != self.pad_token_id).to(torch.long)
        return {"input_ids": input_ids, "attention_mask": attention_mask}
class OneHotModel(nn.Module):
    """
    Fast, parameter-free one-hot projection.

    Forward signature follows HuggingFace style:
        forward(input_ids, attention_mask=None) -> Tensor[B, L, vocab]

    The module holds no learnable parameters; it expands every token id into
    a one-hot vector of length ``vocab_size``.
    """

    def __init__(self, vocab_size: int) -> None:
        super().__init__()
        self.vocab_size = vocab_size

    def extra_repr(self) -> str:
        # Shown inside ``repr(model)`` for easier debugging.
        return f"vocab_size={self.vocab_size}"

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            input_ids: integer tensor of token ids, each in ``range(vocab_size)``.
                Any integer dtype is accepted.
            attention_mask: optional tensor with the same shape as
                ``input_ids`` (1/True for real tokens, 0/False for padding).

        Returns:
            float32 tensor of shape ``(*input_ids.shape, vocab_size)``.  Rows
            at masked-out positions are all zeros.
        """
        # F.one_hot only accepts int64 indices; cast so int32/uint8 ids work too.
        one_hot = F.one_hot(input_ids.long(), num_classes=self.vocab_size).float()
        if attention_mask is not None:
            # Zero out pads; cast the mask so the result stays float32 for
            # bool, integer or float masks alike.
            one_hot = one_hot * attention_mask.unsqueeze(-1).to(one_hot.dtype)
        return one_hot
# Alphabets for the built-in one-hot presets.  CharTokenizer assigns one id per
# character in the order written here, so reordering a string changes the ids.
AA_SET = 'LAGVSERTIPDKQNFYMHWCXBUOZ*'  # 26 protein symbols
CODON_SET = 'aA@bB#$%rRnNdDcCeEqQ^G&ghHiIj+MmlJLkK(fFpPoO=szZwSXTtxWyYuvUV]})'  # 64 symbols, presumably one per codon
DNA_SET = 'ATCG'
RNA_SET = 'AUCG'

# Preset name -> alphabet, looked up by build_one_hot_model / get_one_hot_tokenizer.
ALPHABET_DICT: Dict[str, str] = dict(
    [
        ('OneHot-Protein', AA_SET),
        ('OneHot-DNA', DNA_SET),
        ('OneHot-RNA', RNA_SET),
        ('OneHot-Codon', CODON_SET),
    ]
)
def build_one_hot_model(
    preset: str = 'OneHot-Protein',
    model_path: Optional[str] = None,
    **kwargs,
) -> "Tuple[OneHotModel, CharTokenizer]":
    """
    Build a parameter-free one-hot model together with its matching tokenizer.

    Args:
        preset: key of ``ALPHABET_DICT`` selecting the alphabet.
        model_path: unused here; presumably accepted so the signature matches
            other model builders (confirm against callers).
        **kwargs: ignored, for the same reason.

    Returns:
        ``(model, tokenizer)`` with ``model.vocab_size == tokenizer.vocab_size``.

    Raises:
        KeyError: if ``preset`` is not a known preset.  The message lists the
            valid presets.
    """
    try:
        alphabet = ALPHABET_DICT[preset]
    except KeyError:
        raise KeyError(
            f"Unknown one-hot preset {preset!r}; expected one of {list(ALPHABET_DICT)}"
        ) from None
    tokenizer = CharTokenizer(str(alphabet))
    model = OneHotModel(tokenizer.vocab_size)
    return model, tokenizer
def get_one_hot_tokenizer(preset: str, model_path: Optional[str] = None) -> "CharTokenizer":
    """
    Return the CharTokenizer for a one-hot preset.

    Args:
        preset: key of ``ALPHABET_DICT`` selecting the alphabet.
        model_path: unused here; presumably kept for signature parity with
            other tokenizer loaders (confirm against callers).

    Raises:
        KeyError: if ``preset`` is not a known preset.  The message lists the
            valid presets.
    """
    try:
        alphabet = ALPHABET_DICT[preset]
    except KeyError:
        raise KeyError(
            f"Unknown one-hot preset {preset!r}; expected one of {list(ALPHABET_DICT)}"
        ) from None
    return CharTokenizer(alphabet)
if __name__ == '__main__':
    # Demo.  Run as a module: py -m base_models.one_hot
    model, tokenizer = build_one_hot_model()  # default: protein
    sequences = ["ACGT", "PROTEIN"]
    batch = tokenizer(sequences)
    # Longest input "PROTEIN" has 7 chars; with <cls> and <eos> that is L = 9.
    one_hot = model(**batch)  # [2, 9, vocab]
    print(f"input_ids shape : {batch['input_ids'].shape}")
    print(f"attention_mask shape : {batch['attention_mask'].shape}")
    print(f"one-hot shape : {one_hot.shape}")
    print("\n--- first sequence, first 5 one-hot rows (trimmed) ---")
    print(one_hot[0, :5])

    # Round-trip demo: special tokens and padding are dropped by default.
    decoded = tokenizer.decode(batch["input_ids"][1])
    print("\nDecoded second sequence:", decoded)

    # Show tokens with special tokens (padding is still skipped).
    print("\nInput with special tokens visible:")
    decoded_with_special = tokenizer.decode(batch["input_ids"][0], skip_special=False)
    print(decoded_with_special)

    # Show token IDs (first sequence is right-padded with 0s).
    print("\nToken IDs for first sequence:")
    print(batch["input_ids"][0])
|