File size: 4,055 Bytes
b279884 ae347c6 b279884 ae347c6 b279884 ae347c6 b279884 ae347c6 b279884 ae347c6 b279884 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 | from __future__ import annotations
from utils import datasets
class _Progress:
def __init__(self) -> None:
self.updates: list[tuple[float, str | None]] = []
def progress(self, value: float, *, text: str | None = None) -> None:
self.updates.append((value, text))
class _Notice:
def __init__(self) -> None:
self.messages: list[str] = []
self.empty_calls = 0
def warning(self, message: str) -> None:
self.messages.append(message)
def empty(self) -> None:
self.empty_calls += 1
def test_download_missing_startup_files_only_fetches_uncached_files(monkeypatch):
    """Only uncached files are downloaded, with visible first-time feedback.

    ``already.jsonl`` is reported as cached, so the single expected download is
    ``missing.jsonl`` (as a dataset repo). The notice placeholder must receive
    a "First-time setup for Example" warning and have ``empty()`` called on it
    exactly once, and the last progress update must report 1 of 1 files.
    """
    notice = _Notice()
    bar = _Progress()
    fetched: list[tuple[str, str, str]] = []

    def fake_is_cached(_repo, filename):
        return filename == "already.jsonl"

    def fake_progress(value, *, text=None):
        # The initial value/text are ignored; updates go through ``bar``.
        return bar

    def fake_download(repo, filename, *, repo_type):
        fetched.append((repo, filename, repo_type))

    monkeypatch.setattr(datasets, "_is_cached", fake_is_cached)
    monkeypatch.setattr(datasets.st, "empty", lambda: notice)
    monkeypatch.setattr(datasets.st, "progress", fake_progress)
    monkeypatch.setattr(datasets, "hf_hub_download", fake_download)

    datasets._download_missing_startup_files_if_needed(
        "org/repo", ("already.jsonl", "missing.jsonl"), "Example"
    )

    assert notice.messages
    assert "First-time setup for Example" in notice.messages[0]
    assert notice.empty_calls == 1
    assert fetched == [("org/repo", "missing.jsonl", "dataset")]
    assert bar.updates[-1] == (1.0, "Downloaded missing.jsonl (1/1)")
def test_download_missing_startup_files_stays_quiet_when_cached(monkeypatch):
    """A fully warm cache must not render any download UI or hit the Hub.

    Every file reports as cached; ``st.empty``, ``st.progress`` and
    ``hf_hub_download`` are all replaced by a stub that fails the test if
    called.
    """

    def _fail(*_args, **_kwargs):
        raise AssertionError("cold-download UI should not render for warm cache")

    monkeypatch.setattr(datasets, "_is_cached", lambda *_args: True)
    for ui_attr in ("empty", "progress"):
        monkeypatch.setattr(datasets.st, ui_attr, _fail)
    monkeypatch.setattr(datasets, "hf_hub_download", _fail)

    datasets._download_missing_startup_files_if_needed(
        "org/repo", ("cached.jsonl",), "Example"
    )
def test_prepare_nemotron_prefetches_first_parquet_shard(monkeypatch):
    """Only the ``train-00000`` parquet shard is handed to the downloader.

    The fake repo listing is deliberately out of order and contains a
    non-parquet file; the helper must forward the USA repo id, the single
    first shard, and the caller's label unchanged.
    """
    repo_listing = (
        "README.md",
        "data/train-00001-of-00002.parquet",
        "data/train-00000-of-00002.parquet",
    )
    recorded: list[tuple[str, tuple[str, ...], str]] = []

    def fake_download(repo, filenames, label):
        recorded.append((repo, filenames, label))

    monkeypatch.setattr(
        datasets, "list_repo_files", lambda *_args, **_kwargs: repo_listing
    )
    monkeypatch.setattr(
        datasets, "_download_missing_startup_files_if_needed", fake_download
    )

    datasets._prepare_nemotron_startup_download(
        datasets.DatasetSource.NEMOTRON_USA.value, "Nemotron USA"
    )

    expected = (
        "nvidia/Nemotron-Personas-USA",
        ("data/train-00000-of-00002.parquet",),
        "Nemotron USA",
    )
    assert recorded == [expected]
def test_warm_qa_makes_synth_qa_download_visible_before_thread(monkeypatch):
    """The SynthPersona QA download runs in the foreground before the thread.

    ``threading.Thread`` is replaced by a dummy that never runs its target.
    The download is still expected to be recorded, so it must be triggered
    synchronously by ``warm_qa_in_background``. A shared event timeline
    additionally checks that this download happens *before* the thread's
    ``start()``, which is what the test name promises.
    """
    calls: list[tuple[str, tuple[str, ...], str]] = []
    started: list[bool] = []
    # Ordered log of "download" / "start" events. Asserting on the two
    # separate lists alone would not detect a reversed order.
    events: list[str] = []

    class DummySynth:
        def prefetch_qa(self) -> None:
            pass

    class DummyThread:
        # Accepts and discards Thread(...) arguments; the target never runs.
        def __init__(self, *args, **kwargs) -> None:
            pass

        def start(self) -> None:
            started.append(True)
            events.append("start")

    def fake_download(repo, filenames, label):
        calls.append((repo, filenames, label))
        events.append("download")

    monkeypatch.setattr(datasets, "SynthPersonaDataset", DummySynth)
    monkeypatch.setattr(
        datasets, "_download_missing_startup_files_if_needed", fake_download
    )
    monkeypatch.setattr(datasets.threading, "Thread", DummyThread)

    datasets.warm_qa_in_background(DummySynth())

    assert calls == [
        (
            "implicit-personalization/synth-persona",
            ("dataset_qa.jsonl",),
            "SynthPersona QA",
        )
    ]
    assert started == [True]
    assert events == ["download", "start"]
|