Feat: Add dataset loading from S3, GCS (#765)
* Feat: Add dataset loading from S3, GCS
* chore: update docs
* chore: add more info on cloud loading

Files changed:
- README.md +7 -1
- requirements.txt +6 -1
- src/axolotl/utils/data.py +97 -19
README.md
CHANGED

@@ -426,6 +426,12 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
   - path: knowrohit07/know_sql
     type: context_qa.load_v2
     train_on_split: validation
+
+  # loading from s3 or gcs
+  # s3 creds will be loaded from the system default and gcs only supports public access
+  dataset:
+  - path: s3://path_to_ds # Accepts folder with arrow/parquet or file path like above. Supports s3, gcs.
+  ...
 ```
 
 - loading
@@ -520,7 +526,7 @@ float16: true
 
 # A list of one or more datasets to finetune the model with
 datasets:
-  # HuggingFace dataset repo | "json" for local dataset, make sure to fill data_files
+  # HuggingFace dataset repo | s3://,gs:// path | "json" for local dataset, make sure to fill data_files
   - path: vicgalle/alpaca-gpt4
     # The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
     type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
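The new README comment states that S3 credentials come from the system default chain and that GCS only supports public access. A pre-flight check along the lines of the sketch below (not part of this PR; the path is the placeholder from the example above) can confirm a remote path is reachable with those same defaults before pointing the dataset config at it:

# Hypothetical pre-flight check, not part of this change: confirm the remote path is
# readable with the same default credentials the loader will pick up.
import s3fs

fs = s3fs.S3FileSystem()  # default AWS credential chain (~/.aws/credentials, env vars, IAM role)
path = "s3://path_to_ds"  # placeholder path from the README example above
if fs.exists(path):
    kind = "directory" if fs.isdir(path) else "file"
    print(f"{kind}: {path}")
    print(fs.ls(path)[:5])  # peek at the first few entries
else:
    print(f"not found or not accessible: {path}")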
requirements.txt
CHANGED

@@ -11,7 +11,7 @@ deepspeed
 addict
 fire
 PyYAML>=6.0
-datasets
+datasets>=2.14.0
 flash-attn>=2.3.0
 sentencepiece
 wandb
@@ -33,3 +33,8 @@ art
 fschat==0.2.29
 gradio
 tensorboard
+
+# remote filesystems
+s3fs
+gcsfs
+# adlfs
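The `datasets` pin is bumped to 2.14.0, presumably because the fsspec `storage_options` plumbing used below needs a recent release, and the optional remote-filesystem backends are added. A minimal environment check (a sketch, not part of the change) might look like:

# Minimal sketch: verify the new requirements are satisfied before using remote paths.
import importlib
import importlib.metadata

print("datasets", importlib.metadata.version("datasets"))  # should be >= 2.14.0 per the new pin

for module in ("s3fs", "gcsfs"):  # adlfs stays commented out in requirements.txt
    try:
        importlib.import_module(module)
        print(module, "available")
    except ImportError:
        print(module, "missing; install it to use the matching path scheme")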
src/axolotl/utils/data.py
CHANGED

@@ -170,30 +170,74 @@ def load_tokenized_prepared_datasets(
         except (FileNotFoundError, ConnectionError):
             pass
 
+        ds_from_cloud = False
+        storage_options = {}
+        remote_file_system = None
+        if config_dataset.path.startswith("s3://"):
+            try:
+                import aiobotocore.session  # type: ignore
+                import s3fs  # type: ignore
+            except ImportError as exc:
+                raise ImportError(
+                    "s3:// paths require aiobotocore and s3fs to be installed"
+                ) from exc
+
+            # Takes credentials from ~/.aws/credentials for default profile
+            s3_session = aiobotocore.session.AioSession(profile="default")
+            storage_options = {"session": s3_session}
+            remote_file_system = s3fs.S3FileSystem(**storage_options)
+        elif config_dataset.path.startswith(
+            "gs://"
+        ) or config_dataset.path.startswith("gcs://"):
+            try:
+                import gcsfs  # type: ignore
+            except ImportError as exc:
+                raise ImportError(
+                    "gs:// or gcs:// paths require gcsfs to be installed"
+                ) from exc
+
+            # gcsfs will use default credentials from the environment else anon
+            # https://gcsfs.readthedocs.io/en/latest/#credentials
+            storage_options = {"token": None}
+            remote_file_system = gcsfs.GCSFileSystem(**storage_options)
+        # TODO: Figure out how to get auth creds passed
+        # elif config_dataset.path.startswith("adl://") or config_dataset.path.startswith("abfs://"):
+        #     try:
+        #         import adlfs
+        #     except ImportError as exc:
+        #         raise ImportError(
+        #             "adl:// or abfs:// paths require adlfs to be installed"
+        #         ) from exc
+
+        #     # Gen 1
+        #     storage_options = {
+        #         "tenant_id": TENANT_ID,
+        #         "client_id": CLIENT_ID,
+        #         "client_secret": CLIENT_SECRET,
+        #     }
+        #     # Gen 2
+        #     storage_options = {
+        #         "account_name": ACCOUNT_NAME,
+        #         "account_key": ACCOUNT_KEY,
+        #     }
+
+        #     remote_file_system = adlfs.AzureBlobFileSystem(**storage_options)
+        try:
+            if remote_file_system and remote_file_system.exists(
+                config_dataset.path
+            ):
+                ds_from_cloud = True
+        except (FileNotFoundError, ConnectionError):
+            pass
+
         # prefer local dataset, even if hub exists
         local_path = Path(config_dataset.path)
         if local_path.exists():
             if local_path.is_dir():
-
-                ds = load_dataset(
-                    config_dataset.path,
-                    name=config_dataset.name,
-                    data_files=config_dataset.data_files,
-                    streaming=False,
-                    split=None,
-                )
+                ds = load_from_disk(config_dataset.path)
             elif local_path.is_file():
-                ds_type = "json"
-                if config_dataset.ds_type:
-                    ds_type = config_dataset.ds_type
-                elif ".parquet" in config_dataset.path:
-                    ds_type = "parquet"
-                elif ".arrow" in config_dataset.path:
-                    ds_type = "arrow"
-                elif ".csv" in config_dataset.path:
-                    ds_type = "csv"
-                elif ".txt" in config_dataset.path:
-                    ds_type = "text"
+                ds_type = get_ds_type(config_dataset)
+
                 ds = load_dataset(
                     ds_type,
                     name=config_dataset.name,
@@ -213,6 +257,22 @@ def load_tokenized_prepared_datasets(
                     data_files=config_dataset.data_files,
                     token=use_auth_token,
                 )
+        elif ds_from_cloud and remote_file_system:
+            if remote_file_system.isdir(config_dataset.path):
+                ds = load_from_disk(
+                    config_dataset.path,
+                    storage_options=storage_options,
+                )
+            elif remote_file_system.isfile(config_dataset.path):
+                ds_type = get_ds_type(config_dataset)
+                ds = load_dataset(
+                    ds_type,
+                    name=config_dataset.name,
+                    data_files=config_dataset.path,
+                    streaming=False,
+                    split=None,
+                    storage_options=storage_options,
+                )
         else:
             if isinstance(config_dataset.data_files, str):
                 fp = hf_hub_download(
@@ -304,6 +364,24 @@ def load_tokenized_prepared_datasets(
     return dataset, prompters
 
 
+def get_ds_type(config_dataset: DictDefault):
+    """
+    Get the dataset type from the path if it's not specified
+    """
+    ds_type = "json"
+    if config_dataset.ds_type:
+        ds_type = config_dataset.ds_type
+    elif ".parquet" in config_dataset.path:
+        ds_type = "parquet"
+    elif ".arrow" in config_dataset.path:
+        ds_type = "arrow"
+    elif ".csv" in config_dataset.path:
+        ds_type = "csv"
+    elif ".txt" in config_dataset.path:
+        ds_type = "text"
+    return ds_type
+
+
 def load_prepare_datasets(
     tokenizer: PreTrainedTokenizerBase,
     cfg,
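For reference, a standalone sketch of the mechanism the new cloud branch relies on: infer the builder name from the file extension (mirroring `get_ds_type`) and pass fsspec `storage_options` through to `datasets.load_dataset`. The bucket path here is made up and the anonymous token mirrors the `gs://` branch above; this is an illustration, not code from the PR.

# Standalone sketch (assumes a public gs:// parquet file and gcsfs installed).
from datasets import load_dataset

path = "gs://some-public-bucket/train.parquet"  # hypothetical path

# Mirror get_ds_type(): pick a datasets builder from the file extension.
if ".parquet" in path:
    ds_type = "parquet"
elif ".arrow" in path:
    ds_type = "arrow"
elif ".csv" in path:
    ds_type = "csv"
elif ".txt" in path:
    ds_type = "text"
else:
    ds_type = "json"

ds = load_dataset(
    ds_type,
    data_files=path,
    split=None,
    storage_options={"token": None},  # default/anonymous gcsfs credentials, as in the gs:// branch
)
print(ds)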