File size: 7,589 Bytes
a89a7f1
f4259c0
a89a7f1
9ac8f1c
a89a7f1
 
 
 
 
b279884
e75684b
 
 
 
5bf7fd5
a89a7f1
 
db3d901
a89a7f1
b279884
 
 
 
 
 
 
 
 
 
 
a89a7f1
 
c30bbc5
 
a89a7f1
c30bbc5
e75684b
 
9ac8f1c
 
 
 
 
 
 
 
 
 
 
 
b527c23
9ac8f1c
 
b279884
 
 
 
 
 
 
 
9ac8f1c
 
 
 
b279884
9ac8f1c
 
12cdb17
 
 
 
 
 
 
a89a7f1
 
 
 
 
 
 
 
 
 
 
 
 
f4259c0
12cdb17
 
a89a7f1
 
 
 
 
dc186e4
 
 
 
 
 
 
 
 
 
 
 
db3d901
 
 
 
 
 
dc186e4
 
 
 
 
 
 
db3d901
dc186e4
 
a89a7f1
 
eb41f91
 
e75684b
 
 
 
 
 
 
a89a7f1
 
db3d901
b279884
 
 
 
 
c30bbc5
a89a7f1
db3d901
b279884
c30bbc5
e75684b
db3d901
b279884
c30bbc5
e75684b
a89a7f1
 
 
 
 
12cdb17
b279884
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ae347c6
 
b279884
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ae347c6
b279884
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
import atexit
import hashlib
import shutil
import threading
from pathlib import Path
from tempfile import mkdtemp
from typing import Any

import streamlit as st
from huggingface_hub import hf_hub_download, list_repo_files, try_to_load_from_cache
from persona_data.nemotron_personas import (
    NemotronPersonasFranceDataset,
    NemotronPersonasUSADataset,
)
from persona_data.synth_persona import PersonaDataset as LocalPersonaDataset
from persona_data.synth_persona import SynthPersonaDataset

from .helpers import DatasetSource

# Hugging Face dataset repo that backs the SynthPersona source.
_SYNTH_PERSONA_REPO = "implicit-personalization/synth-persona"
# Files fetched (with visible UI progress) before SynthPersonaDataset is built.
_SYNTH_PERSONA_STARTUP_FILES = (
    "implicit_shared_mc_bank.json",
    "dataset_personas.jsonl",
)
# QA file: not needed at startup, so it is fetched separately and only when
# QA warm-up is requested (see warm_qa_in_background).
_SYNTH_PERSONA_QA_FILE = "dataset_qa.jsonl"
# DatasetSource value -> Hugging Face dataset repo id for the Nemotron sources.
_NEMOTRON_REPOS = {
    DatasetSource.NEMOTRON_FRANCE.value: "nvidia/Nemotron-Personas-France",
    DatasetSource.NEMOTRON_USA.value: "nvidia/Nemotron-Personas-USA",
}


@st.cache_resource(show_spinner=False)
def _cached_dataset(cls: type) -> Any:
    """Instantiate ``cls`` once and share the instance via Streamlit's resource cache.

    ``st.cache_resource`` is global to the server process, not per browser
    session: every session that asks for the same class receives the same
    object. Other helpers in this module store attributes on the returned
    instance (memoized persona list, QA warm-up flag), so treat it as shared,
    mutable state.

    Args:
        cls: A dataset class constructible with no arguments; it is also the
            cache key.

    Returns:
        The cached ``cls()`` instance.
    """

    return cls()


# Serializes the check-and-set of the per-dataset ``_qa_warm_started`` flag so
# concurrent callers sharing one cached dataset start at most one warm thread.
_qa_warm_lock = threading.Lock()


def warm_qa_in_background(dataset: Any) -> None:
    """Trigger the dataset's lazy QA parse on a daemon thread, once.

    QA loading is deferred in persona-data (large, unused outside Extract).
    Kicking it off when the Extract tab opens means the parse overlaps with
    the user picking personas/options instead of blocking the first run.

    Idempotent across Streamlit reruns: guarded per dataset instance via a
    ``_qa_warm_started`` attribute. Datasets without a ``prefetch_qa`` method
    are ignored. For SynthPersona, a missing QA file is downloaded
    synchronously (with visible progress) before the thread starts; only the
    parse itself runs in the background.
    """

    warm = getattr(dataset, "prefetch_qa", None)
    if warm is None:
        return  # persona-only dataset (e.g. Nemotron) has no QA
    # Fast path for reruns: the flag is only ever set after the download step
    # below has returned, so once it is set there is nothing left to do and the
    # Hub cache probe can be skipped. The flag is re-checked under the lock.
    if getattr(dataset, "_qa_warm_started", False):
        return
    if isinstance(dataset, SynthPersonaDataset):
        # Extract will need QA soon. Make the one-time large transfer explicit,
        # then leave the CPU-heavy parse on the background thread.
        _download_missing_startup_files_if_needed(
            _SYNTH_PERSONA_REPO,
            (_SYNTH_PERSONA_QA_FILE,),
            "SynthPersona QA",
        )
    with _qa_warm_lock:
        if getattr(dataset, "_qa_warm_started", False):
            return
        dataset._qa_warm_started = True
    threading.Thread(target=warm, name="persona-ui-warm-qa", daemon=True).start()


@st.cache_resource(show_spinner=False)
def _cached_local_dataset(personas_path: str, qa_path: str) -> LocalPersonaDataset:
    """Build a ``LocalPersonaDataset`` from uploaded files, cached by path.

    The paths arrive as plain strings so they form a simple cache key. Upload
    file names embed a content digest, so re-uploading identical files within
    a session maps to the same cache entry.
    """

    personas = Path(personas_path)
    qa = Path(qa_path)
    return LocalPersonaDataset(personas_path=personas, qa_path=qa)


def _upload_cache_dir() -> Path:
    """Return this session's scratch directory for uploads, creating it lazily.

    The directory is remembered in ``st.session_state`` so reruns reuse it.
    An ``atexit`` hook removes it (best effort) when the server process exits.
    """

    state_key = "_upload_cache_dir"
    existing = st.session_state.get(state_key)
    if existing is not None:
        return Path(existing)

    created = mkdtemp(prefix="persona_vectors_uploads_")
    st.session_state[state_key] = created
    atexit.register(shutil.rmtree, created, ignore_errors=True)
    return Path(created)


def _uploaded_file_to_temp_path(uploaded_file: Any, stem: str) -> Path:
    """Persist an uploaded file into the session upload dir and return its path.

    The stored name embeds a prefix of the content's SHA-256, so re-uploading
    the same bytes reuses the existing file. That gives a stable path (and
    cache key) for ``_cached_local_dataset``. The bytes are first written to
    a sibling ``.part`` file, which is then renamed into place. A failed write
    (e.g. disk full) therefore never leaves a truncated file under the final
    name for a later call to trust.

    Args:
        uploaded_file: Upload object exposing ``.name`` and ``.getvalue()``
            (presumably a Streamlit ``UploadedFile`` — confirm at call sites).
        stem: Prefix for the stored file name, e.g. ``"personas"``.

    Returns:
        Path of the stored copy.
    """

    suffix = Path(uploaded_file.name).suffix or ".jsonl"
    data = uploaded_file.getvalue()
    digest = hashlib.sha256(data).hexdigest()
    temp_path = _upload_cache_dir() / f"{stem}_{digest[:16]}{suffix}"
    if temp_path.exists():
        return temp_path
    partial_path = temp_path.with_name(f"{temp_path.name}.part")
    try:
        partial_path.write_bytes(data)
        # Same-directory rename: atomic, so readers see all bytes or no file.
        partial_path.replace(temp_path)
    finally:
        # No-op after a successful rename; removes leftovers after a failure.
        partial_path.unlink(missing_ok=True)
    return temp_path


def load_persona_list(
    dataset_source: str,
    personas_file: Any = None,
    qa_file: Any = None,
) -> tuple[list, str]:
    """Load a dataset via ``load_dataset`` and return ``(personas, status)``.

    ``personas`` is the materialized persona list, memoized on the (cached)
    dataset instance so Streamlit reruns don't re-iterate the dataset.
    ``status`` is the human-readable source label from ``load_dataset``.
    """

    loaded = load_dataset(dataset_source, personas_file, qa_file)
    dataset, status = loaded
    personas = load_persona_list_from_dataset(dataset)
    return personas, status


def load_persona_list_from_dataset(dataset: Any) -> list:
    """Return the dataset's personas as a list, memoizing it on the dataset.

    The first call iterates ``dataset`` and stores the result in a
    ``_persona_list_cache`` attribute. Later calls return that same list
    object. Objects that reject new attributes are still materialized; they
    just aren't memoized.
    """

    personas = getattr(dataset, "_persona_list_cache", None)
    if personas is not None:
        return personas

    personas = list(dataset)
    try:
        setattr(dataset, "_persona_list_cache", personas)
    except (AttributeError, TypeError):
        # Builtins / slotted objects can't take new attributes: skip caching.
        pass
    return personas


def load_dataset(
    dataset_source: str,
    personas_file: Any = None,
    qa_file: Any = None,
) -> tuple[
    SynthPersonaDataset
    | NemotronPersonasFranceDataset
    | NemotronPersonasUSADataset
    | LocalPersonaDataset,
    str,
]:
    """Resolve the UI's dataset selection to ``(dataset, status_label)``.

    Hub-backed sources (SynthPersona, Nemotron France/USA) first surface any
    first-time download in the UI, then return a cached instance. Any other
    ``dataset_source`` is treated as a local upload, which needs both files.

    Raises:
        ValueError: For a local upload when either file is missing.
    """

    if dataset_source == DatasetSource.SYNTH_PERSONA.value:
        synth_label = "SynthPersona"
        _download_missing_startup_files_if_needed(
            _SYNTH_PERSONA_REPO, _SYNTH_PERSONA_STARTUP_FILES, synth_label
        )
        return _cached_dataset(SynthPersonaDataset), synth_label

    nemotron_choices = (
        (
            DatasetSource.NEMOTRON_FRANCE.value,
            NemotronPersonasFranceDataset,
            "Nemotron France",
        ),
        (
            DatasetSource.NEMOTRON_USA.value,
            NemotronPersonasUSADataset,
            "Nemotron USA",
        ),
    )
    for source_value, dataset_cls, status in nemotron_choices:
        if dataset_source == source_value:
            _prepare_nemotron_startup_download(dataset_source, status)
            return _cached_dataset(dataset_cls), status

    # Fallback: local upload of both JSONL files.
    if personas_file is None or qa_file is None:
        raise ValueError("Upload both personas.jsonl and qa.jsonl files")

    personas_path, qa_path = (
        _uploaded_file_to_temp_path(personas_file, stem="personas"),
        _uploaded_file_to_temp_path(qa_file, stem="qa"),
    )
    local = _cached_local_dataset(str(personas_path), str(qa_path))
    return local, "Local upload"


def _is_cached(repo_id: str, filename: str) -> bool:
    """Report whether ``filename`` of Hub dataset ``repo_id`` is in the local HF cache.

    Only a string result from ``try_to_load_from_cache`` (a path on disk)
    counts as a cache hit. Anything else means "not cached".
    """

    hit = try_to_load_from_cache(repo_id, filename, repo_type="dataset")
    found = isinstance(hit, str)
    return found


def _download_missing_startup_files_if_needed(
    repo_id: str,
    filenames: tuple[str, ...],
    label: str,
) -> None:
    """Make first-time Hub downloads visible before dataset construction blocks.

    Hugging Face handles byte-level transfer internally. We expose file-level
    progress here, which is the useful unit this UI can know in advance.
    Files already in the local HF cache are skipped. When nothing is missing,
    no UI is rendered.

    The "first-time setup" warning is removed once downloads finish *or*
    fail, so a failure does not leave a stale "downloading" notice beside the
    error. The progress bar is left showing the last step reached.

    Args:
        repo_id: Hugging Face dataset repository id.
        filenames: Repo-relative files that must be available locally.
        label: Human-readable dataset name shown in the UI.

    Raises:
        Any exception from ``hf_hub_download`` propagates unchanged.
    """

    missing = tuple(
        filename for filename in filenames if not _is_cached(repo_id, filename)
    )
    if not missing:
        return

    notice = st.empty()
    notice.warning(
        f"First-time setup for {label}: downloading dataset files from Hugging Face. "
        "Later loads should use the local cache."
    )
    progress = st.progress(0.0, text=f"Preparing {label} download…")
    total = len(missing)
    try:
        for index, filename in enumerate(missing, start=1):
            progress.progress(
                (index - 1) / total,
                text=f"Downloading {filename} ({index}/{total})",
            )
            hf_hub_download(repo_id, filename, repo_type="dataset")
            progress.progress(
                index / total,
                text=f"Downloaded {filename} ({index}/{total})",
            )
    finally:
        # Clear the transient notice on success and on failure alike.
        notice.empty()


@st.cache_data(show_spinner=False)
def _first_nemotron_shard(repo_id: str) -> str | None:
    """Return the lexicographically first ``data/train-*.parquet`` file in ``repo_id``.

    Cached by Streamlit so the Hub file listing (a network request) is not
    repeated on every rerun that loads a Nemotron dataset. Exceptions are not
    cached, so a failed listing is retried on the next call. Returns ``None``
    when the repo has no matching shard.
    """

    return min(
        (
            filename
            for filename in list_repo_files(repo_id, repo_type="dataset")
            if filename.startswith("data/train-") and filename.endswith(".parquet")
        ),
        default=None,
    )


def _prepare_nemotron_startup_download(dataset_source: str, label: str) -> None:
    """Prefetch the first parquet shard used by the default Nemotron sample.

    Args:
        dataset_source: A ``DatasetSource`` value present in ``_NEMOTRON_REPOS``
            (any other value raises ``KeyError``).
        label: Human-readable dataset name shown in the download UI.
    """

    repo_id = _NEMOTRON_REPOS[dataset_source]
    first_shard = _first_nemotron_shard(repo_id)
    if first_shard is not None:
        _download_missing_startup_files_if_needed(repo_id, (first_shard,), label)