# Source-page residue (Hugging Face commit header), kept as a comment:
# KhangTruong's picture
# Super-squash branch 'main' using huggingface_hub
# 6f31f53 verified
from abc import ABC, abstractmethod

from ..imports import *
class ImageTextLoader:
def __init__(self):
with open(self.getDirectory()) as f:
data = json.load(f)
self.data = data
self.batches = [list(range(i * BATCH_SIZE, (i + 1) * BATCH_SIZE))
for i in range(len(self.data) // BATCH_SIZE)]
@abstractmethod
def getDirectory(self) -> str:
pass
def __len__(self):
return len(self.data)
def __getitem__(self, item):
img, inp, label = self.getData(item)
return (img, inp), label
def __iter__(self):
return (self[i] for i in range(len(self)))
def getData(self, item):
directory, vectors = self.data[item]
vectors: list
good_vectors = [vector[:-1] for vector in vectors]
[good_vector.extend([0] * (MAXIMUM_LENGTH - len(good_vector))) for good_vector in good_vectors]
[vector.extend([0] * (MAXIMUM_LENGTH - len(vector))) for vector in vectors]
img = tf.constant(cv2.resize(cv2.imread(directory), IDEAL_SHAPE[:-1]), dtype=tf.float32)
textTensor = tf.constant(good_vectors, dtype=tf.int32)
shiftedVectors = [vector[1:] + [0] for vector in vectors]
shiftedTensor = tf.constant(shiftedVectors, dtype=tf.int32)
return img, textTensor, shiftedTensor
def getDirWithCorpus(self, item: str | int):
directory = self.data[item][0] if type(item) is int else item
all_vector = [value[1] for value in self.data if value[0] == directory]
return directory, all_vector
class TrainDataset(ImageTextLoader):
    """Training split read from ``train.json``, shuffled when constructed."""

    def getDirectory(self) -> str:
        return 'train.json'

    def __init__(self):
        super().__init__()
        # Every new instance reorders its freshly loaded entries in place.
        entries = self.data
        random.shuffle(entries)
class TestDataset(ImageTextLoader):
    """Evaluation split read from ``test.json``, shuffled when constructed."""

    def getDirectory(self) -> str:
        return 'test.json'

    def __init__(self):
        super().__init__()
        # Every new instance reorders its freshly loaded entries in place.
        entries = self.data
        random.shuffle(entries)
def get_dataset(train=True):
    """Build a batched, prefetched ``tf.data`` pipeline over one split.

    Args:
        train: use ``TrainDataset`` when true, otherwise ``TestDataset``.

    Returns:
        A ``tf.data.Dataset`` of ``((image, input_tokens), target_tokens)``
        elements, batched by ``BATCH_SIZE`` and prefetched with AUTOTUNE.
    """
    loader_cls = TrainDataset if train else TestDataset
    # The loader yields a float32 image and int32 token tensors. TensorSpec
    # defaults to float32, so the token specs must say int32 explicitly;
    # otherwise from_generator fails converting the yielded int32 tensors.
    signature = (
        (
            tf.TensorSpec((None, None, 3), dtype=tf.float32),
            tf.TensorSpec((None, MAXIMUM_LENGTH), dtype=tf.int32),
        ),
        tf.TensorSpec((None, MAXIMUM_LENGTH), dtype=tf.int32),
    )
    # from_generator calls loader_cls() to get a fresh iterable, so the JSON
    # index is re-read (and reshuffled by the split's __init__) per instance.
    ds = tf.data.Dataset.from_generator(loader_cls, output_signature=signature)
    # NOTE(review): batch() needs equal element shapes, so this assumes every
    # image has the same number of token vectors -- confirm with the data.
    return ds.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)