# raster2seq/datasets/discrete_tokenizer.py
# Author: anas
# Commit fadb92b: "Initial deployment of Raster2Seq floor plan vectorization API"
import numpy as np
import torch
class DiscreteTokenizer:
    """Pack sequences of discrete token ids into fixed-length tensors.

    The base vocabulary holds ``num_bins * num_bins`` ids
    (``0 .. num_bins**2 - 1``). Special tokens are numbered directly after it:
    ``bos``, ``eos``, ``sep``, ``pad`` and, only when ``add_cls`` is true,
    ``cls``.

    NOTE(review): the squared base vocabulary presumably encodes a quantized
    2-D coordinate as a single id — confirm against the dataset code.
    """

    def __init__(self, num_bins, seq_len, add_cls=False):
        """Create a tokenizer.

        Args:
            num_bins: number of quantization bins; the base vocabulary size is
                ``num_bins * num_bins``.
            seq_len: fixed length of the tensors produced by ``__call__`` and
                ``_padding``.
            add_cls: if true, a ``cls`` token is emitted after every
                sub-sequence and reserved in the vocabulary.
        """
        self.num_bins = num_bins
        self.seq_len = seq_len
        self.add_cls = add_cls

        base = num_bins * num_bins
        # Special tokens are appended directly after the base vocabulary.
        self.bos = base + 0
        self.eos = base + 1
        self.sep = base + 2
        self.pad = base + 3
        if add_cls:
            # `cls` only exists when requested; otherwise the attribute is absent.
            self.cls = base + 4
            self.vocab_size = base + 5
        else:
            self.vocab_size = base + 4

    def __len__(self):
        """Return the total vocabulary size, special tokens included."""
        return self.vocab_size

    def _padding(self, seq, pad_value, dtype):
        """Right-pad ``seq`` with ``pad_value`` up to ``seq_len``.

        The input is copied, so the caller's sequence is never modified.
        Sequences already at least ``seq_len`` long are returned unpadded and
        untruncated.

        Args:
            seq: any iterable of values convertible by ``np.array``.
            pad_value: value used to fill the missing positions.
            dtype: torch dtype of the returned tensor.

        Returns:
            A tensor built from the padded copy of ``seq``.
        """
        padded = list(seq)  # copy: padding must not leak back into the caller's list
        missing = self.seq_len - len(padded)
        if missing > 0:
            padded.extend([pad_value] * missing)
        return torch.tensor(np.array(padded), dtype=dtype)

    def __call__(self, seq, add_bos, add_eos, dtype, return_indices=False):
        """Concatenate sub-sequences into one padded token tensor.

        Layout: ``[bos] sub_a [cls] sep sub_b [cls] sep ... sub_z [cls] pad ...``.
        The separator after the last kept sub-sequence is removed. When
        ``add_eos`` is true, the final position is overwritten with ``eos``.

        A sub-sequence is kept only if it and its trailing marker(s) still
        fit within ``seq_len``. Sub-sequences that do not fit are skipped, but
        later, shorter ones may still be kept.

        Args:
            seq: iterable of sub-sequences; each must support ``len()`` and
                iteration over token ids.
            add_bos: prepend the ``bos`` token.
            add_eos: write ``eos`` into the last position of the padded output.
            dtype: torch dtype of the returned tensor.
            return_indices: also return the indices (into ``seq``) of the
                sub-sequences that were kept.

        Returns:
            A 1-D tensor of length ``seq_len`` (for ``seq_len >= 1``).
            If ``return_indices`` is true, returns ``(tensor, indices)``.
        """
        out = [self.bos] if add_bos else []
        # Slots a sub-sequence needs after itself: `sep`, plus `cls` if enabled.
        # Popping the final `sep` below frees exactly one trailing slot.
        num_extra = 2 if self.add_cls else 1
        indices = []
        for i, sub in enumerate(seq):
            if len(out) + len(sub) + num_extra > self.seq_len:
                continue
            out.extend(sub)
            indices.append(i)
            if self.add_cls:
                out.append(self.cls)
            out.append(self.sep)

        # Drop the separator that trails the last kept sub-sequence.
        if out and out[-1] == self.sep:
            out.pop()

        if self.seq_len > len(out):
            out.extend([self.pad] * (self.seq_len - len(out)))

        if add_eos:
            # NOTE(review): eos replaces the *last slot of the padded sequence*,
            # so it lands after any padding rather than right after the last
            # token. Left unchanged because it defines the data layout — confirm
            # this is intended.
            out[-1] = self.eos

        tokens = torch.tensor(out, dtype=dtype)
        if return_indices:
            return tokens, indices
        return tokens