|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
|
from pathlib import Path |
|
|
|
|
|
from datasets import load_dataset |
|
|
|
|
|
from seamless_interaction.fs import SeamlessInteractionFS |
|
|
|
|
|
|
|
|
def main():
    """
    Demonstrate webdataset loading for both local and remote datasets.

    This script shows how to download and load dataset archives using
    webdataset format, supporting both local file access and direct
    HuggingFace Hub streaming.

    Command-line arguments:
        --mode: Loading mode ('local' or 'hf')
        --label: Dataset label ('improvised' or 'naturalistic')
        --split: Data split ('dev', 'test', 'train')
        --batch_idx: Batch index number
        --archive_idx: Archive index within the batch
    """
    parser = argparse.ArgumentParser()
    # choices= makes argparse reject invalid modes up front; previously an
    # unknown mode left `dataset` unbound and crashed later with NameError.
    parser.add_argument("--mode", type=str, default="local", choices=["local", "hf"])
    parser.add_argument("--label", type=str, default="improvised")
    parser.add_argument("--split", type=str, default="dev")
    parser.add_argument("--batch_idx", type=int, default=0)
    parser.add_argument("--archive_idx", type=int, default=23)
    args = parser.parse_args()

    mode = args.mode
    label = args.label
    split = args.split
    batch_idx = args.batch_idx
    archive_idx = args.archive_idx

    if mode == "local":
        # Only download the archive when we actually read it from disk;
        # 'hf' mode streams directly from the Hub and never uses a local copy.
        fs = SeamlessInteractionFS()
        local_dir = Path.home() / "datasets/seamless_interaction"
        fs.download_archive_from_hf(
            idx=batch_idx,
            archive=archive_idx,
            label=label,
            split=split,
            batch=batch_idx,
            local_dir=local_dir,
            extract=False,
        )
        local_path = (
            local_dir / f"{label}/{split}/{batch_idx:04d}/{archive_idx:04d}.tar"
        )
        dataset = load_dataset(
            "webdataset",
            # load_dataset expects string paths in data_files, not Path objects
            data_files={split: str(local_path)},
            split=split,
            streaming=True,
        )
    else:  # mode == "hf": stream the tar directly from the Hub
        # The f-string is fully interpolated here; the previous
        # base_url.format(...) call was a no-op and has been removed.
        url = (
            f"https://huggingface.co/datasets/facebook/"
            f"seamless-interaction/resolve/main/{label}/{split}/"
            f"{batch_idx:04d}/{archive_idx:04d}.tar"
        )
        dataset = load_dataset(
            "webdataset", data_files={split: [url]}, split=split, streaming=True
        )

    # Pull a single sample to verify the archive loads; next(iter(...)) fails
    # loudly on an empty dataset instead of leaving `item` undefined.
    item = next(iter(dataset))
    print(item.keys())
|
|
|
|
|
|
|
|
# Script entry point: run the demo only when executed directly, not on import.
if __name__ == "__main__":


    main()
|
|
|