Spaces:
Runtime error
Runtime error
dkoshman
commited on
Commit
·
1b4da0d
1
Parent(s):
96feb73
improved interface
Browse files- app.py +18 -9
- data_preprocessing.py +7 -4
app.py
CHANGED
|
@@ -9,7 +9,6 @@ import torchvision.transforms as T
|
|
| 9 |
|
| 10 |
MODEL_PATH = RESOURCES + "/model_2tcuvfsj.pt"
|
| 11 |
|
| 12 |
-
# TODO: make faster
|
| 13 |
transformer = torch.load(MODEL_PATH)
|
| 14 |
image_transform = T.Compose((
|
| 15 |
T.ToTensor(),
|
|
@@ -18,12 +17,22 @@ image_transform = T.Compose((
|
|
| 18 |
random_magnitude=0)
|
| 19 |
))
|
| 20 |
|
| 21 |
-
st.
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
image = image.convert("RGB")
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
MODEL_PATH = RESOURCES + "/model_2tcuvfsj.pt"
|
| 11 |
|
|
|
|
| 12 |
transformer = torch.load(MODEL_PATH)
|
| 13 |
image_transform = T.Compose((
|
| 14 |
T.ToTensor(),
|
|
|
|
| 17 |
random_magnitude=0)
|
| 18 |
))
|
| 19 |
|
| 20 |
+
st.title("Image to TeX")
|
| 21 |
+
|
| 22 |
+
st.image("resources/frontend/fraction_derivative.png", width=500)
|
| 23 |
+
st.image("resources/frontend/positional_encoding.png")
|
| 24 |
+
st.image("resources/frontend/taylor_sequence_expanded.png")
|
| 25 |
+
# st.image("resources/frontend/taylor_sequence.png")
|
| 26 |
+
# st.image("resources/frontend/maclaurin_series.png")
|
| 27 |
+
# st.image("resources/frontend/gauss_distribution.png")
|
| 28 |
+
|
| 29 |
+
image_file = st.file_uploader("Upload an image with equation", type=([".png", ".jpg", ".jpeg"]))
|
| 30 |
+
|
| 31 |
+
if image_file is not None:
|
| 32 |
+
image = PIL.Image.open(image_file)
|
| 33 |
image = image.convert("RGB")
|
| 34 |
+
texs = beam_search_decode(transformer, image, image_transform=image_transform)
|
| 35 |
+
# streamlit latex doesn't support boldmath
|
| 36 |
+
tex = texs[0].replace("\\boldmath", "")
|
| 37 |
+
st.latex(tex)
|
| 38 |
+
st.markdown(tex)
|
data_preprocessing.py
CHANGED
|
@@ -74,14 +74,16 @@ class RandomizeImageTransform(object):
|
|
| 74 |
|
| 75 |
def __init__(self, width, height, random_magnitude):
|
| 76 |
self.transform = T.Compose((
|
| 77 |
-
T.ColorJitter(brightness=random_magnitude / 10,
|
| 78 |
-
|
|
|
|
|
|
|
| 79 |
T.Resize(height, max_size=width),
|
| 80 |
T.Grayscale(),
|
| 81 |
T.functional.invert,
|
| 82 |
T.CenterCrop((height, width)),
|
| 83 |
torch.Tensor.contiguous,
|
| 84 |
-
T.RandAugment(magnitude=random_magnitude),
|
| 85 |
T.ConvertImageDtype(torch.float32)
|
| 86 |
))
|
| 87 |
|
|
@@ -133,7 +135,8 @@ class LatexImageDataModule(pl.LightningDataModule):
|
|
| 133 |
super().__init__()
|
| 134 |
|
| 135 |
dataset = TexImageDataset(root_dir=DATA_DIR,
|
| 136 |
-
image_transform=RandomizeImageTransform(image_width, image_height,
|
|
|
|
| 137 |
tex_transform=ExtractEquationFromTexTransform())
|
| 138 |
self.train_dataset, self.val_dataset, self.test_dataset = torch.utils.data.random_split(
|
| 139 |
dataset, [len(dataset) * 18 // 20, len(dataset) // 20, len(dataset) // 20])
|
|
|
|
| 74 |
|
| 75 |
def __init__(self, width, height, random_magnitude):
|
| 76 |
self.transform = T.Compose((
|
| 77 |
+
lambda x: x if random_magnitude == 0 else T.ColorJitter(brightness=random_magnitude / 10,
|
| 78 |
+
contrast=random_magnitude / 10,
|
| 79 |
+
saturation=random_magnitude / 10,
|
| 80 |
+
hue=min(0.5, random_magnitude / 10)),
|
| 81 |
T.Resize(height, max_size=width),
|
| 82 |
T.Grayscale(),
|
| 83 |
T.functional.invert,
|
| 84 |
T.CenterCrop((height, width)),
|
| 85 |
torch.Tensor.contiguous,
|
| 86 |
+
lambda x: x if random_magnitude == 0 else T.RandAugment(magnitude=random_magnitude),
|
| 87 |
T.ConvertImageDtype(torch.float32)
|
| 88 |
))
|
| 89 |
|
|
|
|
| 135 |
super().__init__()
|
| 136 |
|
| 137 |
dataset = TexImageDataset(root_dir=DATA_DIR,
|
| 138 |
+
image_transform=RandomizeImageTransform(image_width, image_height,
|
| 139 |
+
random_magnitude),
|
| 140 |
tex_transform=ExtractEquationFromTexTransform())
|
| 141 |
self.train_dataset, self.val_dataset, self.test_dataset = torch.utils.data.random_split(
|
| 142 |
dataset, [len(dataset) * 18 // 20, len(dataset) // 20, len(dataset) // 20])
|