"""
ARKitScenes subset loader for the depth-aware scene description evaluation.

Downloads a small, reproducible set of indoor RGB frames from the ARKitScenes
Validation split [Baruch et al., NeurIPS 2021].  Frames are the evaluation
dataset for all three ablation metrics (STD, SFS, Preamble BERTScore).

Dataset:
    Baruch, G., Chen, Z., Dehghan, A., et al. (2021). ARKitScenes - A Diverse
    Real-World Dataset for 3D Indoor Scene Understanding.
    NeurIPS 2021 Datasets and Benchmarks Track.
    https://github.com/apple/ARKitScenes

Prerequisites
-------------
Clone the ARKitScenes repository and set the ``ARKITSCENES_REPO`` environment
variable to its absolute path, OR pass ``--repo-dir`` on the CLI::

    git clone https://github.com/apple/ARKitScenes /path/to/ARKitScenes
    set ARKITSCENES_REPO=C:\\path\\to\\ARKitScenes        (Windows CMD)
    $env:ARKITSCENES_REPO = "C:\\path\\to\\ARKitScenes"   (Windows PowerShell)
    export ARKITSCENES_REPO=/path/to/ARKitScenes           (Linux/macOS)

The repo is cloned only to read the splits CSV.  Frame downloads go directly
to the Apple CDN via Python builtins (no curl or unzip shell commands).

Each video's zip is extracted to the directory below, which is then searched
recursively for image files::

    <tmp_dir>/<video_id>/

Usage (library)::

    from src.data.arkitscenes_loader import fetch_arkitscenes_subset
    paths = fetch_arkitscenes_subset(n_videos=20, frames_per_video=10)

CLI::

    python -m src.data.arkitscenes_loader \\
        --n-videos 20 --frames-per-video 10 --output data/test_images/ \\
        --repo-dir C:\\path\\to\\ARKitScenes
"""

from __future__ import annotations

import argparse
import csv
import os
import random
import tempfile
import urllib.request
import zipfile
from pathlib import Path
from typing import Optional

from ..config import TEST_IMAGES_DIR

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

# Path of the splits manifest *inside* the ARKitScenes repo clone.
_SPLITS_REL = Path("raw") / "raw_train_val_splits.csv"

# Column names inside the CSV (ARKitScenes convention).
_COL_VIDEO_ID = "video_id"
_COL_FOLD = "fold"
_VALIDATION_FOLD = "Validation"

# ARKitScenes CDN base URL (public, no auth required).
_CDN_BASE = "https://docs-assets.developer.apple.com/ml-research/datasets/arkitscenes/v1"

# lowres_wide = 256×192 wide-angle RGB frames, sufficient for VLM evaluation.
_RGB_ASSET = "lowres_wide"

# Image extensions to accept when scanning downloaded frames.
_IMAGE_EXTS = {".png", ".jpg", ".jpeg"}

# Environment variable users set to point at their ARKitScenes clone.
_REPO_ENV_VAR = "ARKITSCENES_REPO"


def _default_tmp_dir() -> Path:
    """Return a scratch directory under the system temp directory.

    ``download_data.py`` invokes curl/unzip via shell strings without quoting
    paths, which breaks on Windows when the project directory contains spaces
    (e.g. ``C:\\Users\\risha\\New folder\\...``).  We avoid that entirely by
    downloading and extracting with Python builtins, but we still want the
    scratch directory to be clean and outside the project tree.
    """
    return Path(tempfile.gettempdir()) / "arkit_tmp"


# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------

def _read_splits_csv(repo_dir: Path) -> list[dict[str, str]]:
    """Read the train/val splits manifest from the local ARKitScenes clone.

    Args:
        repo_dir: Root of the cloned ARKitScenes repository.

    Returns:
        List of row dicts, one per video in the manifest.

    Raises:
        FileNotFoundError: If the CSV is not found at the expected path inside
            the repo clone.
    """
    csv_path = repo_dir / _SPLITS_REL
    if not csv_path.exists():
        raise FileNotFoundError(
            f"Splits CSV not found at {csv_path}.\n"
            f"Make sure --repo-dir (or ${_REPO_ENV_VAR}) points to the root "
            f"of a complete ARKitScenes clone."
        )
    print(f"  Reading splits CSV: {csv_path}")
    with open(csv_path, newline="", encoding="utf-8") as fh:
        reader = csv.DictReader(fh)
        rows = list(reader)
    return rows


def _validation_video_ids(rows: list[dict[str, str]]) -> list[str]:
    """Extract video IDs for the Validation fold.

    Args:
        rows: All rows from the splits CSV.

    Returns:
        Sorted list of video ID strings marked as Validation.

    Raises:
        ValueError: If the expected column names are absent.
    """
    if not rows:
        raise ValueError("Splits CSV is empty.")

    first = rows[0]
    if _COL_VIDEO_ID not in first or _COL_FOLD not in first:
        raise ValueError(
            f"Splits CSV columns do not match expected schema. "
            f"Got: {list(first.keys())}. "
            f"Expected '{_COL_VIDEO_ID}' and '{_COL_FOLD}'."
        )

    ids = [r[_COL_VIDEO_ID] for r in rows if r[_COL_FOLD] == _VALIDATION_FOLD]
    if not ids:
        raise ValueError(
            f"No rows with fold='{_VALIDATION_FOLD}' found in splits CSV."
        )
    return sorted(ids)
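

# Illustrative sketch, not used by the loader: the fold filter above applied to
# a tiny in-memory CSV.  The rows and the name ``_example_filter_validation``
# are made up for illustration; only the ``video_id`` and ``fold`` column names
# and the ``Validation`` label come from this module.
def _example_filter_validation() -> list:
    """Return the sorted Validation IDs from a toy two-column manifest."""
    import csv
    import io

    toy_csv = "video_id,fold\n1003,Validation\n1002,Training\n1001,Validation\n"
    rows = list(csv.DictReader(io.StringIO(toy_csv)))
    # Same comparison as above: an exact, case-sensitive match on the fold name.
    return sorted(r["video_id"] for r in rows if r["fold"] == "Validation")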


def _resolve_repo_dir(repo_dir: Optional[Path]) -> Path:
    """Return the ARKitScenes repo directory, checking env var as fallback.

    The repo is only needed to read the splits CSV.  It is NOT used for
    downloading frames (we fetch directly from the Apple CDN with Python
    builtins to avoid Windows curl/unzip path-with-spaces issues).

    Args:
        repo_dir: Explicit path supplied by the caller (may be None).

    Returns:
        Resolved Path to the ARKitScenes clone root.

    Raises:
        RuntimeError: If neither argument nor env var points to a valid dir.
    """
    if repo_dir is not None:
        resolved = Path(repo_dir).expanduser().resolve()
        if resolved.is_dir():
            return resolved
        raise RuntimeError(f"--repo-dir does not exist: {resolved}")

    env_val = os.environ.get(_REPO_ENV_VAR)
    if env_val:
        resolved = Path(env_val).expanduser().resolve()
        if resolved.is_dir():
            return resolved
        raise RuntimeError(
            f"${_REPO_ENV_VAR} is set to '{env_val}' but that directory "
            f"does not exist."
        )

    raise RuntimeError(
        f"ARKitScenes repo directory not found.\n"
        f"Either:\n"
        f"  1. Pass --repo-dir on the command line:\n"
        f"       --repo-dir \"C:\\path\\to\\ARKitScenes\"\n"
        f"  2. Set the {_REPO_ENV_VAR} environment variable:\n"
        f"       Windows CMD:        set {_REPO_ENV_VAR}=C:\\path\\to\\ARKitScenes\n"
        f"       Windows PowerShell: $env:{_REPO_ENV_VAR} = \"C:\\path\\to\\ARKitScenes\"\n"
        f"       Linux/macOS:        export {_REPO_ENV_VAR}=/path/to/ARKitScenes\n"
        f"\n"
        f"Clone with:  git clone https://github.com/apple/ARKitScenes"
    )


def _download_video(video_id: str, tmp_dir: Path) -> Optional[Path]:
    """Download and extract the lowres_wide RGB frames for one video_id.

    Downloads the zip directly from the Apple ARKitScenes CDN using
    ``urllib.request`` and extracts it with ``zipfile``.  This is fully
    cross-platform and handles paths with spaces correctly, unlike the
    ``download_data.py`` approach which uses unquoted curl/unzip shell
    commands that break on Windows paths containing spaces.

    The zip is streamed to ``<tmp_dir>/<video_id>.zip``, then extracted to
    ``<tmp_dir>/<video_id>/``.

    Args:
        video_id: ARKitScenes video identifier string (e.g. ``"42444596"``).
        tmp_dir:  Scratch directory to write the zip and extracted frames.
                  Spaces in the path are fine (no shell commands are used).

    Returns:
        Path to the extracted frame directory, or None on failure.
    """
    url = f"{_CDN_BASE}/raw/{_VALIDATION_FOLD}/{video_id}/{_RGB_ASSET}.zip"
    zip_path = tmp_dir / f"{video_id}.zip"
    extract_dir = tmp_dir / video_id

    # ── Download ──────────────────────────────────────────────────────────────
    if not zip_path.exists():
        print(f"  Downloading {video_id} from CDN …", flush=True)
        try:
            def _progress(block_num: int, block_size: int, total: int) -> None:
                if total > 0:
                    pct = min(100, block_num * block_size * 100 // total)
                    print(f"\r    {pct:3d}%  ({block_num * block_size / 1e6:.1f} MB)", end="", flush=True)

            urllib.request.urlretrieve(url, zip_path, reporthook=_progress)
            print()   # newline after progress
        except Exception as exc:
            print(f"\n    [WARN] Download failed for {video_id}: {exc}")
            if zip_path.exists():
                zip_path.unlink()
            return None
    else:
        print(f"  {video_id}.zip already downloaded, skipping.")

    # ── Extract ───────────────────────────────────────────────────────────────
    if not extract_dir.exists():
        print(f"  Extracting {video_id}.zip …", flush=True)
        try:
            with zipfile.ZipFile(zip_path, "r") as zf:
                zf.extractall(extract_dir)
        except zipfile.BadZipFile as exc:
            print(f"    [WARN] Bad zip for {video_id}: {exc}")
            zip_path.unlink(missing_ok=True)
            return None
    else:
        print(f"  {video_id}/ already extracted, skipping.")

    return extract_dir


def _collect_frames(video_dir: Path) -> list[Path]:
    """Recursively collect all image files under a video directory, sorted.

    Args:
        video_dir: Root directory to search.

    Returns:
        Sorted list of image Paths.
    """
    return sorted(
        p for p in video_dir.rglob("*")
        if p.suffix.lower() in _IMAGE_EXTS and p.is_file()
    )


def _evenly_spaced(items: list, n: int) -> list:
    """Return n evenly-spaced elements from items (not random).

    Uses integer linspace so the selection is deterministic and reproducible
    regardless of list length.  If ``len(items) <= n``, returns all items.

    Args:
        items: Source list.
        n:     Desired number of elements.

    Returns:
        Selected sub-list, preserving original order.
    """
    if n <= 0:
        return []  # nothing requested; also avoids dividing by zero below
    if len(items) <= n:
        return list(items)
    step = len(items) / n
    indices = [int(i * step) for i in range(n)]
    return [items[idx] for idx in indices]
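

# Illustrative sketch, not used by the loader: the index arithmetic behind
# ``_evenly_spaced`` written out on its own so a selection can be checked by
# hand.  ``_example_spaced_indices`` is a hypothetical name introduced here.
# For 10 items and n=3 the stride is 10 / 3 (about 3.33), giving indices
# 0, 3, 6: the first item is always picked, the last one generally is not.
def _example_spaced_indices(length: int, n: int) -> list:
    """Return the indices that would be picked from ``length`` items."""
    if n <= 0:
        return []
    if length <= n:
        return list(range(length))
    step = length / n  # float stride; int() floors each position
    return [int(i * step) for i in range(n)]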


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------

def fetch_arkitscenes_subset(
    n_videos: int = 20,
    frames_per_video: int = 10,
    seed: int = 42,
    output_dir: Path = TEST_IMAGES_DIR,
    repo_dir: Optional[Path] = None,
    tmp_dir: Optional[Path] = None,
    skip_existing: bool = True,
) -> list[Path]:
    """Download a small, reproducible ARKitScenes Validation subset.

    Reads the official Validation split manifest from the local ARKitScenes
    clone, samples ``n_videos`` video IDs with a fixed seed, downloads the
    ``lowres_wide`` RGB asset for each video directly from the Apple CDN
    using Python builtins (no curl, no unzip shell commands), then copies
    ``frames_per_video`` evenly-spaced frames to ``output_dir`` with
    canonical names (``arkit_<video_id>_<frame_idx:04d>.jpg``).

    Args:
        n_videos:         Number of video sequences to sample.
        frames_per_video: Frames to keep per sequence (evenly spaced,
                          deterministic, not random).
        seed:             Random seed for video sampling.  Fixed at 42 in
                          all published results for reproducibility.
        output_dir:       Destination directory for the sampled JPEG frames.
        repo_dir:         Path to a local ARKitScenes clone (needed only to
                          read the splits CSV).  If None, the
                          ``ARKITSCENES_REPO`` environment variable is used.
        tmp_dir:          Scratch directory for zip downloads and extraction.
                          Defaults to ``<system_temp>/arkit_tmp`` to avoid
                          path-with-spaces issues on Windows.  Safe to delete
                          after this function returns.
        skip_existing:    If True, skip frames whose destination path already
                          exists (allows resuming interrupted runs).

    Returns:
        Sorted list of output JPEG Paths (up to n_videos × frames_per_video).

    Raises:
        RuntimeError: If the ARKitScenes repo directory cannot be resolved.
        FileNotFoundError: If the splits CSV is missing from the repo clone.
    """
    output_dir = Path(output_dir)
    # Default tmp_dir to the system temp dir so scratch files stay outside the project tree.
    resolved_tmp = Path(tmp_dir) if tmp_dir is not None else _default_tmp_dir()
    output_dir.mkdir(parents=True, exist_ok=True)
    resolved_tmp.mkdir(parents=True, exist_ok=True)

    resolved_repo = _resolve_repo_dir(repo_dir)
    print(f"ARKitScenes repo : {resolved_repo}")
    print(f"Output directory : {output_dir}")
    print(f"Temp directory   : {resolved_tmp}")
    print(f"Seed={seed}  n_videos={n_videos}  frames_per_video={frames_per_video}")
    print()

    # ── 1. Read validation split from local repo ──────────────────────────────
    rows = _read_splits_csv(resolved_repo)
    val_ids = _validation_video_ids(rows)
    print(f"  Validation videos in manifest: {len(val_ids)}")

    # ── 2. Sample n_videos with fixed seed ────────────────────────────────────
    rng = random.Random(seed)
    sampled_ids = sorted(rng.sample(val_ids, min(n_videos, len(val_ids))))
    print(f"  Sampled {len(sampled_ids)} videos (seed={seed}):")
    for vid in sampled_ids:
        print(f"    {vid}")
    print()

    # ── 3. Download + sample each video ──────────────────────────────────────
    output_paths: list[Path] = []
    skipped_videos = 0

    for idx, video_id in enumerate(sampled_ids, start=1):
        print(f"[{idx}/{len(sampled_ids)}] video_id={video_id}", flush=True)

        frames_dir = _download_video(video_id, resolved_tmp)
        if frames_dir is None:
            print("  Skipped (download failed).")
            skipped_videos += 1
            continue

        all_frames = _collect_frames(frames_dir)
        if not all_frames:
            print(f"  Skipped (no image files found under {frames_dir}).")
            skipped_videos += 1
            continue

        print(f"  {len(all_frames)} frames found → keeping {frames_per_video}")
        selected = _evenly_spaced(all_frames, frames_per_video)

        for frame_path in selected:
            frame_idx = all_frames.index(frame_path)
            dest_name = f"arkit_{video_id}_{frame_idx:04d}.jpg"
            dest = output_dir / dest_name

            if skip_existing and dest.exists():
                print(f"    {dest_name}  (exists, skipped)")
                output_paths.append(dest)
                continue

            try:
                from PIL import Image as _PILImage  # noqa: PLC0415
                img = _PILImage.open(frame_path).convert("RGB")
                img.save(dest, format="JPEG", quality=95)
                print(f"    {dest_name}")
                output_paths.append(dest)
            except Exception as exc:
                print(f"    [WARN] Could not save {frame_path.name}: {exc}")

    # ── 4. Summary ────────────────────────────────────────────────────────────
    print()
    print(f"Done.  {len(output_paths)} frames available in {output_dir}")
    if skipped_videos:
        print(f"  ({skipped_videos} videos skipped)")

    return sorted(output_paths)
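

# Illustrative sketch, not used by the loader: the seeded sampling from step 2
# above, isolated.  A private ``random.Random(seed)`` leaves the global random
# state untouched, and sorting the IDs first makes the draw independent of the
# order in which they were read.  ``_example_seeded_sample`` is a hypothetical
# name introduced here.
def _example_seeded_sample(ids: list, k: int, seed: int = 42) -> list:
    """Return ``k`` IDs drawn reproducibly from ``ids`` (all of them if fewer)."""
    import random

    rng = random.Random(seed)
    return sorted(rng.sample(sorted(ids), min(k, len(ids))))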


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------

def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    p = argparse.ArgumentParser(
        description=(
            "Download a reproducible ARKitScenes Validation subset for "
            "depth-aware scene description evaluation."
        )
    )
    p.add_argument(
        "--n-videos", type=int, default=20,
        help="Number of video sequences to sample (default: 20).",
    )
    p.add_argument(
        "--frames-per-video", type=int, default=10,
        help="Evenly-spaced frames to keep per video (default: 10).",
    )
    p.add_argument(
        "--seed", type=int, default=42,
        help="Random seed for video sampling (default: 42). "
             "Keep at 42 to match published results.",
    )
    p.add_argument(
        "--output", type=Path, default=TEST_IMAGES_DIR,
        help=f"Output directory for JPEG frames (default: {TEST_IMAGES_DIR}).",
    )
    p.add_argument(
        "--repo-dir", type=Path, default=None,
        help=(
            "Path to the cloned ARKitScenes repo. "
            f"Overrides the ${_REPO_ENV_VAR} environment variable."
        ),
    )
    p.add_argument(
        "--tmp-dir", type=Path, default=None,
        help=(
            "Scratch directory for zip downloads and extraction. "
            "Defaults to <system_temp>/arkit_tmp to avoid Windows path-with-spaces issues."
        ),
    )
    p.add_argument(
        "--no-skip-existing", action="store_true",
        help="Overwrite frames that already exist in --output.",
    )
    return p.parse_args(argv)


def main(argv: list[str] | None = None) -> None:
    """CLI entry point."""
    args = _parse_args(argv)
    paths = fetch_arkitscenes_subset(
        n_videos=args.n_videos,
        frames_per_video=args.frames_per_video,
        seed=args.seed,
        output_dir=args.output,
        repo_dir=args.repo_dir,
        tmp_dir=args.tmp_dir,
        skip_existing=not args.no_skip_existing,
    )
    print(f"\n{len(paths)} evaluation frames ready.")


if __name__ == "__main__":
    main()