Spaces:
Runtime error
Runtime error
dkoshman
commited on
Commit
·
ae308b4
1
Parent(s):
e949d7b
script now generates a small dataset
Browse files- data_generator.py +3 -5
- data_preprocessing.py +20 -20
- train.py +14 -7
data_generator.py
CHANGED
|
@@ -153,14 +153,12 @@ def generate_image(directory: str, latex_path: str, filename: str, max_length=20
|
|
| 153 |
assert (pr2.returncode == 0)
|
| 154 |
|
| 155 |
|
| 156 |
-
def
|
| 157 |
filenames: iter(str),
|
| 158 |
-
directory: str
|
| 159 |
-
latex_path: str
|
| 160 |
overwrite: bool = False
|
| 161 |
) -> None:
|
| 162 |
-
|
| 163 |
-
|
| 164 |
"""
|
| 165 |
Generates a latex dataset in given directory
|
| 166 |
-------
|
|
|
|
| 153 |
assert (pr2.returncode == 0)
|
| 154 |
|
| 155 |
|
| 156 |
+
def generate_data(
|
| 157 |
filenames: iter(str),
|
| 158 |
+
directory: str,
|
| 159 |
+
latex_path: str,
|
| 160 |
overwrite: bool = False
|
| 161 |
) -> None:
|
|
|
|
|
|
|
| 162 |
"""
|
| 163 |
Generates a latex dataset in given directory
|
| 164 |
-------
|
data_preprocessing.py
CHANGED
|
@@ -67,26 +67,6 @@ class TexImageDataset(Dataset):
|
|
| 67 |
else:
|
| 68 |
self.image_transform = normalize
|
| 69 |
|
| 70 |
-
def subjoin_tex_tokenize_transform(self, texs, vocab_size=300):
|
| 71 |
-
"""Returns a tokenizer trained on given tex strings"""
|
| 72 |
-
|
| 73 |
-
# os.environ['TOKENIZERS_PARALLELISM'] = 'false'
|
| 74 |
-
tokenizer = tokenizers.Tokenizer(tokenizers.models.BPE(unk_token="[UNK]"))
|
| 75 |
-
tokenizer_trainer = tokenizers.trainers.BpeTrainer(
|
| 76 |
-
vocab_size=vocab_size,
|
| 77 |
-
special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
|
| 78 |
-
)
|
| 79 |
-
tokenizer.pre_tokenizer = tokenizers.pre_tokenizers.Whitespace()
|
| 80 |
-
tokenizer.train_from_iterator(texs, trainer=tokenizer_trainer)
|
| 81 |
-
tokenizer.post_processor = tokenizers.processors.TemplateProcessing(
|
| 82 |
-
single="$A [SEP]",
|
| 83 |
-
special_tokens=[("[SEP]", tokenizer.token_to_id("[SEP]"))]
|
| 84 |
-
)
|
| 85 |
-
tokenizer.enable_padding(pad_id=tokenizer.token_to_id("[PAD]"), pad_token="[PAD]")
|
| 86 |
-
|
| 87 |
-
self.tokenizer = tokenizer
|
| 88 |
-
return tokenizer
|
| 89 |
-
|
| 90 |
|
| 91 |
class BatchCollator(object):
|
| 92 |
"""Image, tex batch collator"""
|
|
@@ -156,3 +136,23 @@ class ExtractEquationFromTexTransform(object):
|
|
| 156 |
equation = equation.strip()
|
| 157 |
equation = self.spaces.sub(' ', equation)
|
| 158 |
return equation
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
else:
|
| 68 |
self.image_transform = normalize
|
| 69 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
|
| 71 |
class BatchCollator(object):
|
| 72 |
"""Image, tex batch collator"""
|
|
|
|
| 136 |
equation = equation.strip()
|
| 137 |
equation = self.spaces.sub(' ', equation)
|
| 138 |
return equation
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def generate_tex_tokenizer(texs, vocab_size=300):
|
| 142 |
+
"""Returns a tokenizer trained on given tex strings"""
|
| 143 |
+
|
| 144 |
+
# os.environ['TOKENIZERS_PARALLELISM'] = 'false'
|
| 145 |
+
tokenizer = tokenizers.Tokenizer(tokenizers.models.BPE(unk_token="[UNK]"))
|
| 146 |
+
tokenizer_trainer = tokenizers.trainers.BpeTrainer(
|
| 147 |
+
vocab_size=vocab_size,
|
| 148 |
+
special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
|
| 149 |
+
)
|
| 150 |
+
tokenizer.pre_tokenizer = tokenizers.pre_tokenizers.Whitespace()
|
| 151 |
+
tokenizer.train_from_iterator(texs, trainer=tokenizer_trainer)
|
| 152 |
+
tokenizer.post_processor = tokenizers.processors.TemplateProcessing(
|
| 153 |
+
single="$A [SEP]",
|
| 154 |
+
special_tokens=[("[SEP]", tokenizer.token_to_id("[SEP]"))]
|
| 155 |
+
)
|
| 156 |
+
tokenizer.enable_padding(pad_id=tokenizer.token_to_id("[PAD]"), pad_token="[PAD]")
|
| 157 |
+
|
| 158 |
+
return tokenizer
|
train.py
CHANGED
|
@@ -1,23 +1,30 @@
|
|
|
|
|
| 1 |
from data_preprocessing import TexImageDataset, RandomizeImageTransform, ExtractEquationFromTexTransform, \
|
| 2 |
-
|
| 3 |
|
| 4 |
import torch
|
| 5 |
from torch.utils.data import DataLoader
|
| 6 |
-
|
|
|
|
|
|
|
| 7 |
|
| 8 |
if __name__ == '__main__':
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
image_transform = RandomizeImageTransform()
|
| 10 |
tex_transform = ExtractEquationFromTexTransform()
|
| 11 |
-
dataset = TexImageDataset(
|
| 12 |
dataset.subjoin_image_normalize_transform()
|
| 13 |
train_dataset, test_dataset = torch.utils.data.random_split(
|
| 14 |
dataset,
|
| 15 |
[len(dataset) * 9 // 10, len(dataset) // 10]
|
| 16 |
)
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
tokenizer = generate_tex_tokenizer(texs)
|
| 20 |
-
collate_fn = BatchCollator(tokenizer)
|
| 21 |
|
| 22 |
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=16,
|
| 23 |
collate_fn=collate_fn)
|
|
|
|
| 1 |
+
from data_generator import generate_data
|
| 2 |
from data_preprocessing import TexImageDataset, RandomizeImageTransform, ExtractEquationFromTexTransform, \
|
| 3 |
+
BatchCollator, generate_tex_tokenizer
|
| 4 |
|
| 5 |
import torch
|
| 6 |
from torch.utils.data import DataLoader
|
| 7 |
+
|
| 8 |
+
DATA_DIR = 'data'
|
| 9 |
+
LATEX_PATH = 'resources/latex.json'
|
| 10 |
|
| 11 |
if __name__ == '__main__':
|
| 12 |
+
generate_data(
|
| 13 |
+
filenames=map(str, range(1000)),
|
| 14 |
+
directory=DATA_DIR,
|
| 15 |
+
latex_path=LATEX_PATH,
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
image_transform = RandomizeImageTransform()
|
| 19 |
tex_transform = ExtractEquationFromTexTransform()
|
| 20 |
+
dataset = TexImageDataset(DATA_DIR, image_transform=image_transform, tex_transform=tex_transform)
|
| 21 |
dataset.subjoin_image_normalize_transform()
|
| 22 |
train_dataset, test_dataset = torch.utils.data.random_split(
|
| 23 |
dataset,
|
| 24 |
[len(dataset) * 9 // 10, len(dataset) // 10]
|
| 25 |
)
|
| 26 |
+
tex_tokenizer = generate_tex_tokenizer(dataset.texs)
|
| 27 |
+
collate_fn = BatchCollator(tex_tokenizer)
|
|
|
|
|
|
|
| 28 |
|
| 29 |
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=16,
|
| 30 |
collate_fn=collate_fn)
|