File size: 4,055 Bytes
b279884
 
 
 
 
 
 
 
 
 
 
 
 
ae347c6
 
 
 
 
 
 
 
 
 
 
 
b279884
ae347c6
b279884
 
 
 
 
 
 
 
ae347c6
b279884
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ae347c6
 
b279884
 
 
 
 
 
 
 
 
 
ae347c6
b279884
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
from __future__ import annotations

from utils import datasets


class _Progress:
    def __init__(self) -> None:
        self.updates: list[tuple[float, str | None]] = []

    def progress(self, value: float, *, text: str | None = None) -> None:
        self.updates.append((value, text))


class _Notice:
    def __init__(self) -> None:
        self.messages: list[str] = []
        self.empty_calls = 0

    def warning(self, message: str) -> None:
        self.messages.append(message)

    def empty(self) -> None:
        self.empty_calls += 1


def test_download_missing_startup_files_only_fetches_uncached_files(monkeypatch):
    """Only the file reported as uncached is downloaded, with setup UI shown."""
    placeholder = _Notice()
    bar = _Progress()
    fetched: list[tuple[str, str, str]] = []

    def fake_is_cached(_repo, filename):
        # Only "already.jsonl" is treated as present in the cache.
        return filename == "already.jsonl"

    def fake_empty():
        return placeholder

    def fake_progress(value, *, text=None):
        return bar

    def fake_download(repo, filename, *, repo_type):
        fetched.append((repo, filename, repo_type))

    monkeypatch.setattr(datasets, "_is_cached", fake_is_cached)
    monkeypatch.setattr(datasets.st, "empty", fake_empty)
    monkeypatch.setattr(datasets.st, "progress", fake_progress)
    monkeypatch.setattr(datasets, "hf_hub_download", fake_download)

    datasets._download_missing_startup_files_if_needed(
        "org/repo",
        ("already.jsonl", "missing.jsonl"),
        "Example",
    )

    # A first-time-setup warning was shown and the placeholder cleared once.
    assert placeholder.messages
    assert "First-time setup for Example" in placeholder.messages[0]
    assert placeholder.empty_calls == 1
    # Exactly one download happened, and only for the uncached file.
    assert fetched == [("org/repo", "missing.jsonl", "dataset")]
    assert bar.updates[-1] == (1.0, "Downloaded missing.jsonl (1/1)")


def test_download_missing_startup_files_stays_quiet_when_cached(monkeypatch):
    """With every file cached, no UI element is created and nothing downloads."""

    def always_cached(*_args):
        return True

    def fail_if_called(*_args, **_kwargs):
        raise AssertionError("cold-download UI should not render for warm cache")

    monkeypatch.setattr(datasets, "_is_cached", always_cached)
    # Any touch of the UI widgets or the downloader trips the assertion above.
    tripwires = (
        (datasets.st, "empty"),
        (datasets.st, "progress"),
        (datasets, "hf_hub_download"),
    )
    for owner, attribute in tripwires:
        monkeypatch.setattr(owner, attribute, fail_if_called)

    datasets._download_missing_startup_files_if_needed(
        "org/repo",
        ("cached.jsonl",),
        "Example",
    )


def test_prepare_nemotron_prefetches_first_parquet_shard(monkeypatch):
    """Only the 00000 parquet shard is requested, despite being listed last."""
    recorded: list[tuple[str, tuple[str, ...], str]] = []
    # Shards are deliberately listed out of order, after a non-parquet file.
    repo_listing = (
        "README.md",
        "data/train-00001-of-00002.parquet",
        "data/train-00000-of-00002.parquet",
    )

    def fake_list_repo_files(*_args, **_kwargs):
        return repo_listing

    def record_download(repo, filenames, label):
        recorded.append((repo, filenames, label))

    monkeypatch.setattr(datasets, "list_repo_files", fake_list_repo_files)
    monkeypatch.setattr(
        datasets,
        "_download_missing_startup_files_if_needed",
        record_download,
    )

    datasets._prepare_nemotron_startup_download(
        datasets.DatasetSource.NEMOTRON_USA.value,
        "Nemotron USA",
    )

    expected_call = (
        "nvidia/Nemotron-Personas-USA",
        ("data/train-00000-of-00002.parquet",),
        "Nemotron USA",
    )
    assert recorded == [expected_call]


def test_warm_qa_makes_synth_qa_download_visible_before_thread(monkeypatch):
    """The QA file download check runs once, before the background thread starts.

    The real thread class is replaced by a dummy that never runs its target,
    so any recorded download call must come from the calling thread. A shared
    event log pins the ordering promised by the test name, which two separate
    lists cannot express.
    """
    calls: list[tuple[str, tuple[str, ...], str]] = []
    started: list[bool] = []
    # Interleaved record of both kinds of event, in the order they occur.
    events: list[str] = []

    class DummySynth:
        def prefetch_qa(self) -> None:
            pass

    class DummyThread:
        def __init__(self, *args, **kwargs) -> None:
            pass

        def start(self) -> None:
            events.append("thread_start")
            started.append(True)

    def record_download(repo, filenames, label):
        events.append("download")
        calls.append((repo, filenames, label))

    monkeypatch.setattr(datasets, "SynthPersonaDataset", DummySynth)
    monkeypatch.setattr(
        datasets,
        "_download_missing_startup_files_if_needed",
        record_download,
    )
    monkeypatch.setattr(datasets.threading, "Thread", DummyThread)

    datasets.warm_qa_in_background(DummySynth())

    assert calls == [
        (
            "implicit-personalization/synth-persona",
            ("dataset_qa.jsonl",),
            "SynthPersona QA",
        )
    ]
    assert started == [True]
    # The download check must be visible before the thread is started.
    assert events == ["download", "thread_start"]