Spaces:
Runtime error
Runtime error
| # train.py | |
| import sys | |
| import tensorflow as tf | |
| # def create_datasets(x_train, y_train, text_vectorizer, batch_size): | |
| # print('Building slices...') | |
| # train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(batch_size) | |
| # print('Mapping...') | |
| # train_dataset = train_dataset.map(lambda x, y: (text_vectorizer(x), y), tf.data.AUTOTUNE) | |
| # print('Prefetching...') | |
| # train_dataset = train_dataset.prefetch(tf.data.AUTOTUNE) | |
| # return train_dataset | |
def sizeof_fmt(num, suffix='B'):
    """Render a size as a short human-readable string using binary prefixes.

    Adapted from Fred Cirera, https://stackoverflow.com/a/1094933/1870254.

    The value is divided by 1024 until its magnitude drops below 1024, then
    formatted with one decimal place followed by the matching prefix
    ('', 'Ki', 'Mi', ... 'Zi') and ``suffix``. Anything that is still at least
    1024 after the 'Zi' step (or is NaN/inf) is reported with the 'Yi' prefix.

    Args:
        num: Number to format; negative values keep their sign.
        suffix: Unit appended after the prefix (default ``'B'``).

    Returns:
        A string such as ``'1.5 KiB'``.
    """
    prefixes = ('', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi')
    scaled = num
    for prefix in prefixes:
        # Keep the `<` comparison: a NaN must fail it and fall through to 'Yi'.
        if abs(scaled) < 1024.0:
            return f"{scaled:3.1f} {prefix}{suffix}"
        scaled /= 1024.0
    return f"{scaled:.1f} Yi{suffix}"
# NOTE(review): debugging aid that runs at import time. It prints the ten
# module-level names with the largest shallow ``sys.getsizeof`` value. Only
# names bound before this statement (imports, ``sizeof_fmt``, module dunders)
# are visible. ``getsizeof`` does not follow references, so the report says
# little about real memory use. The loop also leaves ``name`` and ``size``
# bound as module globals.
for name, size in sorted(
    ((name, sys.getsizeof(value)) for name, value in list(locals().items())),
    # Largest first; reverse=True keeps equal sizes in their original order.
    key=lambda entry: entry[1],
    reverse=True,
)[:10]:
    print(f"{name:>30}: {sizeof_fmt(size):>8}")
def data_generator(x, y):
    """Lazily yield aligned ``(x[i], y[i])`` pairs for ``i`` in ``range(len(x))``.

    The number of pairs is fixed by ``len(x)``, read once when iteration
    starts. Extra items in ``y`` are never visited. If ``y`` is shorter than
    ``x``, indexing it raises ``IndexError`` once its end is passed.
    """
    position = 0
    count = len(x)
    while position < count:
        yield x[position], y[position]
        position += 1
def create_datasets(x, y, text_vectorizer, batch_size: int = 32, shuffle: bool = False,
                    n_repeat: int = 0, buffer_size: int = 1_000_000):
    """Build a vectorized, cached, batched ``tf.data`` pipeline from in-memory data.

    Samples come from ``data_generator(x, y)`` one row at a time. Each text row
    is passed through ``text_vectorizer`` and cast to ``int32``. The result is
    cached, optionally shuffled, batched, optionally repeated and prefetched.

    Args:
        x: Array-like of text features with a ``.shape`` attribute. Each
            element yielded by ``data_generator`` has shape ``x.shape[1:]``.
            Presumably a NumPy array of strings — confirm against callers.
        y: Array-like of integer labels aligned with ``x``. Each element has
            shape ``y.shape[1:]``.
        text_vectorizer: Callable mapping a batch of strings to token ids.
            Assumed to be a Keras ``TextVectorization`` layer — TODO confirm.
        batch_size: Number of samples per batch.
        shuffle: If True, shuffle individual samples (reshuffled every epoch).
        n_repeat: ``0`` iterates the data once, ``-1`` repeats indefinitely,
            ``k > 0`` repeats ``k`` times.
        buffer_size: Shuffle buffer size, counted in samples.

    Returns:
        A prefetched ``tf.data.Dataset`` yielding ``(token_ids, labels)`` batches.

    Raises:
        ValueError: If ``n_repeat`` is less than ``-1``. The previous version
            silently returned ``None`` in that case.
    """
    if n_repeat < -1:
        raise ValueError(f"n_repeat must be -1, 0 or a positive int, got {n_repeat}")

    print('Generating...')
    train_dataset = tf.data.Dataset.from_generator(
        # from_generator needs a callable that returns a *fresh* iterator each
        # time it is invoked. A single shared generator object is exhausted
        # after one pass, so any later invocation would yield nothing.
        lambda: data_generator(x, y),
        output_signature=(
            # data_generator yields one row per element. The per-element shape
            # is therefore the array shape without the leading sample axis.
            tf.TensorSpec(shape=x.shape[1:], dtype=tf.string),
            tf.TensorSpec(shape=y.shape[1:], dtype=tf.int32),
        ),
    )

    print('Mapping...')

    def _vectorize(features, labels):
        # Give the vectorizer a batch of one sample, then drop that batch axis.
        # Labels are already per-sample rows, so they pass through unchanged.
        token_ids = text_vectorizer(tf.expand_dims(features, 0))
        return tf.cast(token_ids, tf.int32)[0], labels

    train_dataset = train_dataset.map(_vectorize, num_parallel_calls=tf.data.AUTOTUNE)
    # Cache before shuffling. The generator + vectorizer pass then runs only
    # once, while the sample order still changes on every epoch. Shuffling
    # before batching mixes samples rather than whole batches.
    train_dataset = train_dataset.cache()
    if shuffle:
        train_dataset = train_dataset.shuffle(buffer_size)
    train_dataset = train_dataset.batch(batch_size)

    if n_repeat == -1:
        train_dataset = train_dataset.repeat()
    elif n_repeat > 0:
        train_dataset = train_dataset.repeat(n_repeat)
    # n_repeat == 0: iterate once. Note that repeat(0) would yield an empty dataset.
    return train_dataset.prefetch(tf.data.AUTOTUNE)