File size: 1,093 Bytes
b6447fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import torch
from typing import Dict

def character_lookup(text: str):
    """Derive a character-level vocabulary from ``text``.

    The text is lowercased, its distinct characters are sorted, and each
    character is assigned its position in that sorted order as its ID.
    The vocabulary size and characters are printed to stdout.

    Args:
        text: Source text to collect characters from.

    Returns:
        A ``(vocab_len, char2id, id2char)`` tuple: the number of distinct
        characters, the character -> ID mapping, and the ID -> character
        mapping.
    """
    assert isinstance(text, str), "text must be a string object"

    unique_chars = set(text.lower())
    alphabet = sorted(unique_chars)
    size = len(alphabet)

    print(f"vocab len: {size}")
    print(f"vocab characters: {repr(''.join(alphabet))}")

    # Build the forward table once, then invert it for the reverse lookup.
    id2char = dict(enumerate(alphabet))
    char2id = {symbol: index for index, symbol in id2char.items()}
    return size, char2id, id2char

def text_encoder(text: str, char2id: Dict[str, int]) -> torch.Tensor:
    """Convert ``text`` into a 1-D ``torch.long`` tensor of character IDs.

    The text is lowercased before lookup, so ``char2id`` only needs entries
    for lowercase characters.

    Args:
        text: String to encode.
        char2id: Mapping from character to integer ID.

    Returns:
        A tensor of dtype ``torch.long`` with one ID per character of the
        lowercased text.

    Raises:
        KeyError: If a character of the lowercased text is not in ``char2id``.
    """
    assert isinstance(text, str), "text must be a string object"

    ids = []
    for symbol in text.lower():
        ids.append(char2id[symbol])
    return torch.tensor(ids, dtype=torch.long)

def text_decoder(token_ids: torch.Tensor, id2char: Dict[int, str]) -> str:
    """Decode tensor of character IDs back into text string.

    Args:
        token_ids: Tensor of character IDs. Any shape is accepted: a 0-d
            tensor decodes to a single character, and multi-dimensional
            tensors are flattened in row-major order before decoding.
        id2char: Mapping from integer ID to character.

    Returns:
        The decoded string; empty if ``token_ids`` has no elements.

    Raises:
        AssertionError: If ``token_ids`` is not a ``torch.Tensor`` (while
            assertions are enabled).
        KeyError: If an ID in ``token_ids`` is missing from ``id2char``.
    """
    assert isinstance(token_ids, torch.Tensor), "token_ids must be a torch tensor object"
    # A single bulk tolist() replaces one .item() call per element, and
    # reshape(-1) lets 0-d and multi-dimensional tensors through, which
    # direct iteration over the tensor could not handle.
    flat_ids = token_ids.reshape(-1).tolist()
    return "".join(id2char[token_id] for token_id in flat_ids)