File size: 2,502 Bytes
cb2428f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
from pathlib import Path

from datasets import load_dataset

from seamless_interaction.fs import SeamlessInteractionFS


def main():
    """
    Demonstrate webdataset loading for both local and remote datasets.

    Loads a single dataset archive in webdataset format in one of two ways:

    * ``local``: download the archive (without extracting it) via
      ``SeamlessInteractionFS`` into ``~/datasets/seamless_interaction``,
      then stream it from disk.
    * ``hf``: stream the archive directly from the HuggingFace Hub; nothing
      is downloaded up front.

    The keys of the first sample in the archive are printed.

    Command-line arguments:

    --mode         Loading mode, 'local' or 'hf' (default: 'local').
    --label        Dataset label, e.g. 'improvised' or 'naturalistic'
                   (default: 'improvised').
    --split        Data split, e.g. 'dev', 'test' or 'train' (default: 'dev').
    --batch_idx    Batch index number (default: 0).
    --archive_idx  Archive index within the batch (default: 23).

    Exits with an error message if the archive yields no samples.
    """
    parser = argparse.ArgumentParser(
        description="Load one seamless-interaction webdataset archive."
    )
    # Restricting choices prevents falling through both branches below with
    # no dataset defined.
    parser.add_argument(
        "--mode",
        type=str,
        default="local",
        choices=("local", "hf"),
        help="'local' downloads then reads from disk; 'hf' streams from the Hub.",
    )
    parser.add_argument(
        "--label",
        type=str,
        default="improvised",
        help="Dataset label, e.g. 'improvised' or 'naturalistic'.",
    )
    parser.add_argument(
        "--split",
        type=str,
        default="dev",
        help="Data split, e.g. 'dev', 'test' or 'train'.",
    )
    parser.add_argument(
        "--batch_idx", type=int, default=0, help="Batch index number."
    )
    parser.add_argument(
        "--archive_idx",
        type=int,
        default=23,
        help="Archive index within the batch.",
    )
    args = parser.parse_args()

    label = args.label
    split = args.split
    batch_idx = args.batch_idx
    archive_idx = args.archive_idx

    # The same relative layout is used on disk and in the Hub repository.
    archive_relpath = f"{label}/{split}/{batch_idx:04d}/{archive_idx:04d}.tar"

    if args.mode == "local":
        local_dir = Path.home() / "datasets/seamless_interaction"
        # Only local mode needs the archive on disk; 'hf' mode streams it.
        fs = SeamlessInteractionFS()
        # NOTE(review): both `idx` and `batch` receive batch_idx, as in the
        # original script -- confirm against download_archive_from_hf.
        fs.download_archive_from_hf(
            idx=batch_idx,
            archive=archive_idx,
            label=label,
            split=split,
            batch=batch_idx,
            local_dir=local_dir,
            extract=False,
        )
        data_files = [str(local_dir / archive_relpath)]
    else:  # args.mode == "hf" (enforced by argparse choices)
        data_files = [
            "https://huggingface.co/datasets/facebook/"
            f"seamless-interaction/resolve/main/{archive_relpath}"
        ]

    dataset = load_dataset(
        "webdataset", data_files={split: data_files}, split=split, streaming=True
    )

    # Pull just the first sample; a default avoids an unbound name when the
    # archive is empty.
    item = next(iter(dataset), None)
    if item is None:
        raise SystemExit(f"No samples found in archive {archive_relpath}")

    print(item.keys())


if __name__ == "__main__":
    main()