CharRNN / src /preprocessing.py
hoom4n's picture
Upload 18 files
b6447fa verified
from typing import Dict, Tuple

import torch
def character_lookup(
    text: str, *, verbose: bool = True
) -> Tuple[int, Dict[str, int], Dict[int, str]]:
    """Build a character vocabulary and ID lookup tables from input text.

    The text is lowercased before the vocabulary is collected. IDs are
    assigned in sorted (code-point) order of the unique characters, so the
    same text always yields the same mapping.

    Args:
        text: Source corpus to extract the vocabulary from.
        verbose: When True (the default, matching the previous behavior),
            print the vocabulary size and its characters to stdout.

    Returns:
        A tuple ``(vocab_len, char2id, id2char)`` where ``char2id`` maps each
        character to its integer ID and ``id2char`` is the inverse mapping.
    """
    assert isinstance(text, str), "text must be a string object"
    vocab = sorted(set(text.lower()))
    vocab_len = len(vocab)
    if verbose:
        print(f"vocab len: {vocab_len}")
        print(f"vocab characters: {repr(''.join(vocab))}")
    char2id = {char: idx for idx, char in enumerate(vocab)}
    # enumerate already yields (id, char) pairs, exactly the inverse table.
    id2char = dict(enumerate(vocab))
    return vocab_len, char2id, id2char
def text_encoder(text: str, char2id: Dict[str, int]) -> torch.Tensor:
    """Turn a string into a 1-D ``torch.long`` tensor of character IDs.

    The string is lowercased first, so ``char2id`` is expected to use
    lowercase keys. Any character missing from ``char2id`` raises KeyError.
    """
    assert isinstance(text, str), "text must be a string object"
    normalized = text.lower()
    ids = list(map(char2id.__getitem__, normalized))
    return torch.tensor(ids, dtype=torch.long)
def text_decoder(token_ids: torch.Tensor, id2char: Dict[int, str]) -> str:
    """Decode a tensor of character IDs back into a text string.

    Args:
        token_ids: Tensor of character IDs. It is flattened in row-major
            order, so 1-D, ``(T, 1)`` and ``(1, T)`` inputs all decode to a
            single string.
        id2char: Mapping from integer ID to character.

    Returns:
        The concatenated characters. An empty tensor yields ``""``.

    Raises:
        KeyError: If an ID is not present in ``id2char``.
    """
    assert isinstance(token_ids, torch.Tensor), "token_ids must be a torch tensor object"
    # One bulk conversion instead of a per-element `.item()` call avoids a
    # Python/tensor round-trip (and, for CUDA tensors, a host transfer) per ID.
    ids = token_ids.reshape(-1).tolist()
    return "".join(id2char[id_] for id_ in ids)