Spaces:
Runtime error
Runtime error
dkoshman
committed on
Commit
·
11c4819
1
Parent(s):
a31e03c
fixed image transform
Browse files
- data_preprocessing.py +2 -2
- train.py +4 -5
data_preprocessing.py
CHANGED
|
@@ -74,7 +74,7 @@ class RandomizeImageTransform(object):
|
|
| 74 |
|
| 75 |
def __init__(self, width, height, random_magnitude):
|
| 76 |
self.transform = T.Compose((
|
| 77 |
-
lambda x: x if random_magnitude == 0 else T.ColorJitter(brightness=random_magnitude / 10,
|
| 78 |
contrast=random_magnitude / 10,
|
| 79 |
saturation=random_magnitude / 10,
|
| 80 |
hue=min(0.5, random_magnitude / 10)),
|
|
@@ -83,7 +83,7 @@ class RandomizeImageTransform(object):
|
|
| 83 |
T.functional.invert,
|
| 84 |
T.CenterCrop((height, width)),
|
| 85 |
torch.Tensor.contiguous,
|
| 86 |
-
lambda x: x if random_magnitude == 0 else T.RandAugment(magnitude=random_magnitude),
|
| 87 |
T.ConvertImageDtype(torch.float32)
|
| 88 |
))
|
| 89 |
|
|
|
|
| 74 |
|
| 75 |
def __init__(self, width, height, random_magnitude):
|
| 76 |
self.transform = T.Compose((
|
| 77 |
+
(lambda x: x) if random_magnitude == 0 else T.ColorJitter(brightness=random_magnitude / 10,
|
| 78 |
contrast=random_magnitude / 10,
|
| 79 |
saturation=random_magnitude / 10,
|
| 80 |
hue=min(0.5, random_magnitude / 10)),
|
|
|
|
| 83 |
T.functional.invert,
|
| 84 |
T.CenterCrop((height, width)),
|
| 85 |
torch.Tensor.contiguous,
|
| 86 |
+
(lambda x: x) if random_magnitude == 0 else T.RandAugment(magnitude=random_magnitude),
|
| 87 |
T.ConvertImageDtype(torch.float32)
|
| 88 |
))
|
| 89 |
|
train.py
CHANGED
|
@@ -13,8 +13,7 @@ import torch
|
|
| 13 |
|
| 14 |
|
| 15 |
def check_setup():
|
| 16 |
-
|
| 17 |
-
"Disabling tokenizers parallelism because it can't be used before forking and I didn't bother to figure it out")
|
| 18 |
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 19 |
if not os.path.isfile(DATAMODULE_PATH):
|
| 20 |
print("Generating default datamodule")
|
|
@@ -107,7 +106,7 @@ def main():
|
|
| 107 |
callbacks = [LogImageTexCallback(logger, top_k=10, max_length=100),
|
| 108 |
LearningRateMonitor(logging_interval="step"),
|
| 109 |
ModelCheckpoint(save_top_k=10,
|
| 110 |
-
every_n_train_steps=
|
| 111 |
monitor="val_loss",
|
| 112 |
mode="min",
|
| 113 |
filename="img2tex-{epoch:02d}-{val_loss:.2f}")]
|
|
@@ -135,9 +134,9 @@ def main():
|
|
| 135 |
trainer.fit(transformer, datamodule=datamodule)
|
| 136 |
trainer.test(transformer, datamodule=datamodule)
|
| 137 |
|
| 138 |
-
if args.log:
|
| 139 |
transformer = average_checkpoints(model_type=Transformer, checkpoints_dir=trainer.checkpoint_callback.dirpath)
|
| 140 |
-
transformer_path = os.path.join(RESOURCES, f"{trainer.logger.version}.pt")
|
| 141 |
transformer.eval()
|
| 142 |
transformer.freeze()
|
| 143 |
torch.save(transformer.state_dict(), transformer_path)
|
|
|
|
| 13 |
|
| 14 |
|
| 15 |
def check_setup():
|
| 16 |
+
# Disabling tokenizers parallelism because it can't be used before forking and I didn't bother to figure it out
|
|
|
|
| 17 |
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 18 |
if not os.path.isfile(DATAMODULE_PATH):
|
| 19 |
print("Generating default datamodule")
|
|
|
|
| 106 |
callbacks = [LogImageTexCallback(logger, top_k=10, max_length=100),
|
| 107 |
LearningRateMonitor(logging_interval="step"),
|
| 108 |
ModelCheckpoint(save_top_k=10,
|
| 109 |
+
every_n_train_steps=5,
|
| 110 |
monitor="val_loss",
|
| 111 |
mode="min",
|
| 112 |
filename="img2tex-{epoch:02d}-{val_loss:.2f}")]
|
|
|
|
| 134 |
trainer.fit(transformer, datamodule=datamodule)
|
| 135 |
trainer.test(transformer, datamodule=datamodule)
|
| 136 |
|
| 137 |
+
if args.log and len(os.listdir(trainer.checkpoint_callback.dirpath)):
|
| 138 |
transformer = average_checkpoints(model_type=Transformer, checkpoints_dir=trainer.checkpoint_callback.dirpath)
|
| 139 |
+
transformer_path = os.path.join(RESOURCES, f"model_{trainer.logger.version}.pt")
|
| 140 |
transformer.eval()
|
| 141 |
transformer.freeze()
|
| 142 |
torch.save(transformer.state_dict(), transformer_path)
|