Spaces:
Runtime error
Runtime error
feat(train): save to bucket
Browse files
- tools/train/train.py (+70 -47)
tools/train/train.py
CHANGED
|
@@ -18,7 +18,7 @@ Training DALL·E Mini.
|
|
| 18 |
Script adapted from run_summarization_flax.py
|
| 19 |
"""
|
| 20 |
|
| 21 |
-
import
|
| 22 |
import logging
|
| 23 |
import os
|
| 24 |
import sys
|
|
@@ -41,6 +41,7 @@ from flax.core.frozen_dict import FrozenDict, freeze
|
|
| 41 |
from flax.serialization import from_bytes, to_bytes
|
| 42 |
from flax.training import train_state
|
| 43 |
from flax.training.common_utils import onehot
|
|
|
|
| 44 |
from jax.experimental import PartitionSpec, maps
|
| 45 |
from jax.experimental.compilation_cache import compilation_cache as cc
|
| 46 |
from jax.experimental.pjit import pjit, with_sharding_constraint
|
|
@@ -59,7 +60,6 @@ cc.initialize_cache(
|
|
| 59 |
"/home/boris/dalle-mini/jax_cache", max_cache_size_bytes=5 * 2 ** 30
|
| 60 |
)
|
| 61 |
|
| 62 |
-
|
| 63 |
logger = logging.getLogger(__name__)
|
| 64 |
|
| 65 |
|
|
@@ -123,17 +123,20 @@ class ModelArguments:
|
|
| 123 |
else:
|
| 124 |
return dict()
|
| 125 |
|
| 126 |
-
def get_opt_state(self
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
|
|
|
|
|
|
|
|
|
| 137 |
|
| 138 |
|
| 139 |
@dataclass
|
|
@@ -785,10 +788,9 @@ def main():
|
|
| 785 |
|
| 786 |
else:
|
| 787 |
# load opt_state
|
| 788 |
-
|
| 789 |
-
|
| 790 |
-
|
| 791 |
-
opt_state_file.close()
|
| 792 |
|
| 793 |
# restore other attributes
|
| 794 |
attr_state = {
|
|
@@ -1034,42 +1036,60 @@ def main():
|
|
| 1034 |
|
| 1035 |
def run_save_model(state, eval_metrics=None):
|
| 1036 |
if jax.process_index() == 0:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1037 |
params = jax.device_get(state.params)
|
| 1038 |
-
# save model locally
|
| 1039 |
model.save_pretrained(
|
| 1040 |
-
|
| 1041 |
params=params,
|
| 1042 |
)
|
| 1043 |
|
| 1044 |
# save tokenizer
|
| 1045 |
-
tokenizer.save_pretrained(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1046 |
|
| 1047 |
# save state
|
| 1048 |
opt_state = jax.device_get(state.opt_state)
|
| 1049 |
-
|
| 1050 |
-
|
| 1051 |
-
|
| 1052 |
-
|
| 1053 |
-
|
| 1054 |
-
|
| 1055 |
-
|
| 1056 |
-
"w"
|
| 1057 |
-
) as f:
|
| 1058 |
-
json.dump(
|
| 1059 |
-
state_dict,
|
| 1060 |
-
f,
|
| 1061 |
-
)
|
| 1062 |
|
| 1063 |
# save to W&B
|
| 1064 |
if training_args.log_model:
|
| 1065 |
# save some space
|
| 1066 |
c = wandb.wandb_sdk.wandb_artifacts.get_artifacts_cache()
|
| 1067 |
-
c.cleanup(wandb.util.from_human_size("
|
| 1068 |
|
| 1069 |
-
metadata =
|
|
|
|
|
|
|
|
|
|
| 1070 |
metadata["num_params"] = num_params
|
| 1071 |
if eval_metrics is not None:
|
| 1072 |
metadata["eval"] = eval_metrics
|
|
|
|
|
|
|
| 1073 |
|
| 1074 |
# create model artifact
|
| 1075 |
artifact = wandb.Artifact(
|
|
@@ -1077,16 +1097,19 @@ def main():
|
|
| 1077 |
type="DalleBart_model",
|
| 1078 |
metadata=metadata,
|
| 1079 |
)
|
| 1080 |
-
|
| 1081 |
-
|
| 1082 |
-
|
| 1083 |
-
|
| 1084 |
-
|
| 1085 |
-
|
| 1086 |
-
|
| 1087 |
-
|
| 1088 |
-
|
| 1089 |
-
|
|
|
|
|
|
|
|
|
|
| 1090 |
wandb.run.log_artifact(artifact)
|
| 1091 |
|
| 1092 |
# create state artifact
|
|
@@ -1095,9 +1118,9 @@ def main():
|
|
| 1095 |
type="DalleBart_state",
|
| 1096 |
metadata=metadata,
|
| 1097 |
)
|
| 1098 |
-
|
| 1099 |
artifact_state.add_file(
|
| 1100 |
-
f"{Path(training_args.output_dir) /
|
| 1101 |
)
|
| 1102 |
wandb.run.log_artifact(artifact_state)
|
| 1103 |
|
|
|
|
| 18 |
Script adapted from run_summarization_flax.py
|
| 19 |
"""
|
| 20 |
|
| 21 |
+
import io
|
| 22 |
import logging
|
| 23 |
import os
|
| 24 |
import sys
|
|
|
|
| 41 |
from flax.serialization import from_bytes, to_bytes
|
| 42 |
from flax.training import train_state
|
| 43 |
from flax.training.common_utils import onehot
|
| 44 |
+
from google.cloud import storage
|
| 45 |
from jax.experimental import PartitionSpec, maps
|
| 46 |
from jax.experimental.compilation_cache import compilation_cache as cc
|
| 47 |
from jax.experimental.pjit import pjit, with_sharding_constraint
|
|
|
|
| 60 |
"/home/boris/dalle-mini/jax_cache", max_cache_size_bytes=5 * 2 ** 30
|
| 61 |
)
|
| 62 |
|
|
|
|
| 63 |
logger = logging.getLogger(__name__)
|
| 64 |
|
| 65 |
|
|
|
|
| 123 |
else:
|
| 124 |
return dict()
|
| 125 |
|
| 126 |
+
def get_opt_state(self):
    """Return the serialized optimizer state as a readable binary stream.

    If ``self.restore_state`` is ``True``, the state is taken from the W&B
    artifact whose reference is ``self.model_name_or_path`` with the first
    ``"/model-"`` replaced by ``"/state-"``; the file ``opt_state.msgpack``
    inside that artifact is read. Otherwise ``self.restore_state`` is used
    as the path of the file to open.

    ``self.restore_state`` is left unchanged, so the method can be called
    more than once.

    Returns:
        A binary file-like object supporting ``read()`` and ``close()``.
        The caller is responsible for closing it.
    """
    import io

    if self.restore_state is not True:
        # Explicit path given: hand the opened file straight to the caller.
        return Path(self.restore_state).open("rb")

    # wandb artifact
    state_artifact = self.model_name_or_path.replace("/model-", "/state-", 1)
    with tempfile.TemporaryDirectory() as tmp_dir:  # avoid multiple artifact copies
        # Process 0 goes through the active run (`use_artifact`); the other
        # processes fetch the same artifact through the W&B API object.
        if jax.process_index() == 0:
            artifact = wandb.run.use_artifact(state_artifact)
        else:
            artifact = wandb.Api().artifact(state_artifact)
        artifact_dir = artifact.download(tmp_dir)
        # Read the whole file while the temporary directory still exists:
        # returning a handle into a directory that is removed on exit only
        # works where open files may be unlinked, and would leave a dangling
        # path behind.
        state_bytes = (Path(artifact_dir) / "opt_state.msgpack").read_bytes()
    return io.BytesIO(state_bytes)
|
| 140 |
|
| 141 |
|
| 142 |
@dataclass
|
|
|
|
| 788 |
|
| 789 |
else:
|
| 790 |
# load opt_state
|
| 791 |
+
opt_state_file = model_args.get_opt_state()
|
| 792 |
+
opt_state = from_bytes(opt_state_shape, opt_state_file.read())
|
| 793 |
+
opt_state_file.close()
|
|
|
|
| 794 |
|
| 795 |
# restore other attributes
|
| 796 |
attr_state = {
|
|
|
|
| 1036 |
|
| 1037 |
def run_save_model(state, eval_metrics=None):
|
| 1038 |
if jax.process_index() == 0:
|
| 1039 |
+
|
| 1040 |
+
output_dir = training_args.output_dir
|
| 1041 |
+
use_bucket = output_dir.startswith("gs://")
|
| 1042 |
+
if use_bucket:
|
| 1043 |
+
bucket_path = Path(output_dir[5:]) / wandb.run.id / f"step_{state.step}"
|
| 1044 |
+
bucket, dir_path = str(bucket_path).split("/", 1)
|
| 1045 |
+
tmp_dir = tempfile.TemporaryDirectory()
|
| 1046 |
+
output_dir = tmp_dir.name
|
| 1047 |
+
|
| 1048 |
+
# save model
|
| 1049 |
params = jax.device_get(state.params)
|
|
|
|
| 1050 |
model.save_pretrained(
|
| 1051 |
+
output_dir,
|
| 1052 |
params=params,
|
| 1053 |
)
|
| 1054 |
|
| 1055 |
# save tokenizer
|
| 1056 |
+
tokenizer.save_pretrained(output_dir)
|
| 1057 |
+
|
| 1058 |
+
# copy to bucket
|
| 1059 |
+
if use_bucket:
|
| 1060 |
+
client = storage.Client()
|
| 1061 |
+
bucket = client.bucket(bucket)
|
| 1062 |
+
for filename in Path(output_dir).glob("*"):
|
| 1063 |
+
blob_name = str(Path(dir_path) / filename.name)
|
| 1064 |
+
blob = bucket.blob(blob_name)
|
| 1065 |
+
blob.upload_from_filename(str(filename))
|
| 1066 |
+
tmp_dir.cleanup()
|
| 1067 |
|
| 1068 |
# save state
|
| 1069 |
opt_state = jax.device_get(state.opt_state)
|
| 1070 |
+
if use_bucket:
|
| 1071 |
+
blob_name = str(Path(dir_path) / "opt_state.msgpack")
|
| 1072 |
+
blob = bucket.blob(blob_name)
|
| 1073 |
+
blob.upload_from_file(io.BytesIO(to_bytes(opt_state)))
|
| 1074 |
+
else:
|
| 1075 |
+
with (Path(output_dir) / "opt_state.msgpack").open("wb") as f:
|
| 1076 |
+
f.write(to_bytes(opt_state))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1077 |
|
| 1078 |
# save to W&B
|
| 1079 |
if training_args.log_model:
|
| 1080 |
# save some space
|
| 1081 |
c = wandb.wandb_sdk.wandb_artifacts.get_artifacts_cache()
|
| 1082 |
+
c.cleanup(wandb.util.from_human_size("20GB"))
|
| 1083 |
|
| 1084 |
+
metadata = {
|
| 1085 |
+
k: jax.device_get(getattr(state, k)).item()
|
| 1086 |
+
for k in ["step", "epoch", "train_time", "train_samples"]
|
| 1087 |
+
}
|
| 1088 |
metadata["num_params"] = num_params
|
| 1089 |
if eval_metrics is not None:
|
| 1090 |
metadata["eval"] = eval_metrics
|
| 1091 |
+
if use_bucket:
|
| 1092 |
+
metadata["bucket_path"] = bucket_path
|
| 1093 |
|
| 1094 |
# create model artifact
|
| 1095 |
artifact = wandb.Artifact(
|
|
|
|
| 1097 |
type="DalleBart_model",
|
| 1098 |
metadata=metadata,
|
| 1099 |
)
|
| 1100 |
+
if not use_bucket:
|
| 1101 |
+
for filename in [
|
| 1102 |
+
"config.json",
|
| 1103 |
+
"flax_model.msgpack",
|
| 1104 |
+
"merges.txt",
|
| 1105 |
+
"special_tokens_map.json",
|
| 1106 |
+
"tokenizer.json",
|
| 1107 |
+
"tokenizer_config.json",
|
| 1108 |
+
"vocab.json",
|
| 1109 |
+
]:
|
| 1110 |
+
artifact.add_file(
|
| 1111 |
+
f"{Path(training_args.output_dir) / filename}"
|
| 1112 |
+
)
|
| 1113 |
wandb.run.log_artifact(artifact)
|
| 1114 |
|
| 1115 |
# create state artifact
|
|
|
|
| 1118 |
type="DalleBart_state",
|
| 1119 |
metadata=metadata,
|
| 1120 |
)
|
| 1121 |
+
if not use_bucket:
|
| 1122 |
artifact_state.add_file(
|
| 1123 |
+
f"{Path(training_args.output_dir) / 'opt_state.msgpack'}"
|
| 1124 |
)
|
| 1125 |
wandb.run.log_artifact(artifact_state)
|
| 1126 |
|