File size: 7,589 Bytes
a89a7f1 f4259c0 a89a7f1 9ac8f1c a89a7f1 b279884 e75684b 5bf7fd5 a89a7f1 db3d901 a89a7f1 b279884 a89a7f1 c30bbc5 a89a7f1 c30bbc5 e75684b 9ac8f1c b527c23 9ac8f1c b279884 9ac8f1c b279884 9ac8f1c 12cdb17 a89a7f1 f4259c0 12cdb17 a89a7f1 dc186e4 db3d901 dc186e4 db3d901 dc186e4 a89a7f1 eb41f91 e75684b a89a7f1 db3d901 b279884 c30bbc5 a89a7f1 db3d901 b279884 c30bbc5 e75684b db3d901 b279884 c30bbc5 e75684b a89a7f1 12cdb17 b279884 ae347c6 b279884 ae347c6 b279884 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 | import atexit
import hashlib
import shutil
import threading
from pathlib import Path
from tempfile import mkdtemp
from typing import Any
import streamlit as st
from huggingface_hub import hf_hub_download, list_repo_files, try_to_load_from_cache
from persona_data.nemotron_personas import (
NemotronPersonasFranceDataset,
NemotronPersonasUSADataset,
)
from persona_data.synth_persona import PersonaDataset as LocalPersonaDataset
from persona_data.synth_persona import SynthPersonaDataset
from .helpers import DatasetSource
# Hugging Face dataset repository hosting the SynthPersona files.
_SYNTH_PERSONA_REPO = "implicit-personalization/synth-persona"

# Files fetched (with visible progress) before SynthPersonaDataset is built.
_SYNTH_PERSONA_STARTUP_FILES = ("implicit_shared_mc_bank.json", "dataset_personas.jsonl")

# QA file: fetched separately, only when QA warm-up is requested.
_SYNTH_PERSONA_QA_FILE = "dataset_qa.jsonl"

# DatasetSource value -> Hugging Face repo id of each Nemotron persona variant.
_NEMOTRON_REPOS = {
    DatasetSource.NEMOTRON_FRANCE.value: "nvidia/Nemotron-Personas-France",
    DatasetSource.NEMOTRON_USA.value: "nvidia/Nemotron-Personas-USA",
}
@st.cache_resource(show_spinner=False)
def _cached_dataset(cls: type) -> Any:
    """Instantiate ``cls()`` once and return that same instance on later calls.

    ``st.cache_resource`` keys the cache on ``cls`` and shares the returned
    object process-wide: across every browser session and every rerun, not
    just within one session. Anything stored on the instance (warm-up flags,
    memoized persona lists) is therefore shared state visible to all users.
    """
    return cls()
_qa_warm_lock = threading.Lock()
def warm_qa_in_background(dataset: Any) -> None:
"""Trigger the dataset's lazy QA parse on a daemon thread, once.
QA loading is deferred in persona-data (large, unused outside Extract).
Kicking it off when the Extract tab opens means the parse overlaps with
the user picking personas/options instead of blocking the first run.
Idempotent across Streamlit reruns: guarded per cached dataset instance.
"""
warm = getattr(dataset, "prefetch_qa", None)
if warm is None:
return # persona-only dataset (e.g. Nemotron) has no QA
if isinstance(dataset, SynthPersonaDataset):
# Extract will need QA soon. Make the one-time large transfer explicit,
# then leave the CPU-heavy parse on the existing background thread.
_download_missing_startup_files_if_needed(
_SYNTH_PERSONA_REPO,
(_SYNTH_PERSONA_QA_FILE,),
"SynthPersona QA",
)
with _qa_warm_lock:
if getattr(dataset, "_qa_warm_started", False):
return
dataset._qa_warm_started = True
threading.Thread(target=warm, name="persona-ui-warm-qa", daemon=True).start()
@st.cache_resource(show_spinner=False)
def _cached_local_dataset(personas_path: str, qa_path: str) -> LocalPersonaDataset:
    """Build a dataset from uploaded persona/QA files, reusing it per path pair.

    The cache is keyed on the two path strings. Upload paths are derived from
    a hash of the file contents, so re-uploading identical files resolves to
    the same paths and hits the cached instance.
    """
    personas_file_path = Path(personas_path)
    qa_file_path = Path(qa_path)
    return LocalPersonaDataset(personas_path=personas_file_path, qa_path=qa_file_path)
def _upload_cache_dir() -> Path:
    """Return this session's private scratch directory for uploaded files.

    The directory is created lazily with ``mkdtemp`` and its path is kept in
    ``st.session_state``, so each Streamlit session gets its own directory.
    Each created directory is registered for removal at interpreter exit.

    If the remembered directory no longer exists (e.g. a system temp-file
    cleaner removed it on a long-running server), a fresh private directory
    is created. Without this, every later upload write into the stale path
    would fail with ``FileNotFoundError``.
    """
    cache_dir = st.session_state.get("_upload_cache_dir")
    if cache_dir is None or not Path(cache_dir).is_dir():
        # mkdtemp (rather than re-creating the old path) gives a new,
        # owner-only directory name that nobody else can have pre-created.
        cache_dir = mkdtemp(prefix="persona_vectors_uploads_")
        st.session_state["_upload_cache_dir"] = cache_dir
        # Remove the temp dir when the server process exits.
        atexit.register(shutil.rmtree, cache_dir, ignore_errors=True)
    return Path(cache_dir)
def _uploaded_file_to_temp_path(uploaded_file: Any, stem: str) -> Path:
    """Persist an uploaded file under a content-addressed name and return its path.

    ``uploaded_file`` is any object exposing ``name`` and ``getvalue()``
    (presumably a Streamlit ``UploadedFile``). The file is stored in this
    session's upload directory as ``{stem}_{sha256[:16]}{suffix}``, where the
    suffix comes from the uploaded name and defaults to ``.jsonl``. Uploading
    identical bytes again returns the existing file without rewriting it,
    which keeps the path stable for ``_cached_local_dataset``.
    """
    suffix = Path(uploaded_file.name).suffix or ".jsonl"
    data = uploaded_file.getvalue()
    digest = hashlib.sha256(data).hexdigest()
    temp_path = _upload_cache_dir() / f"{stem}_{digest[:16]}{suffix}"
    if temp_path.exists():
        return temp_path
    # The exists() fast path above trusts whatever sits under this name, so
    # never expose a half-written file there. Write to a private sibling,
    # then atomically rename it into place.
    partial_path = temp_path.with_name(
        f".{temp_path.name}.{threading.get_ident()}.part"
    )
    try:
        partial_path.write_bytes(data)
        partial_path.replace(temp_path)
    finally:
        # No-op after a successful rename; removes debris after a failure.
        partial_path.unlink(missing_ok=True)
    return temp_path
def load_persona_list(
    dataset_source: str,
    personas_file: Any = None,
    qa_file: Any = None,
) -> tuple[list, str]:
    """Load a dataset source and return ``(personas_list, status)``.

    Accepts the same arguments and propagates the same errors as
    ``load_dataset``. The materialized list is memoized on the loaded
    dataset instance, so repeated reruns do not iterate the dataset again.
    """
    loaded, status_label = load_dataset(
        dataset_source, personas_file=personas_file, qa_file=qa_file
    )
    personas = load_persona_list_from_dataset(loaded)
    return personas, status_label
def load_persona_list_from_dataset(dataset: Any) -> list:
    """Return ``list(dataset)``, memoized on the dataset object itself.

    The list is stored as ``dataset._persona_list_cache``. Later calls return
    that same list object (not a copy). If the attribute cannot be set, e.g.
    on a ``__slots__`` class or a tuple, the list is still returned but is
    rebuilt on every call.
    """
    memo = getattr(dataset, "_persona_list_cache", None)
    if memo is not None:
        return memo

    personas = list(dataset)
    try:
        dataset._persona_list_cache = personas
    except (AttributeError, TypeError):
        pass  # best-effort memo: the object rejects new attributes
    return personas
def load_dataset(
    dataset_source: str,
    personas_file: Any = None,
    qa_file: Any = None,
) -> tuple[
    SynthPersonaDataset
    | NemotronPersonasFranceDataset
    | NemotronPersonasUSADataset
    | LocalPersonaDataset,
    str,
]:
    """Load the dataset selected in the UI and return ``(dataset, status_label)``.

    Built-in Hub sources (SynthPersona, Nemotron France/USA) first download
    their missing startup files with visible progress, then return the shared
    instance from ``_cached_dataset``. Any other ``dataset_source`` value is
    treated as a local upload built from ``personas_file`` and ``qa_file``.

    Raises:
        ValueError: on the upload path, if either uploaded file is missing.
    """
    if dataset_source == DatasetSource.SYNTH_PERSONA.value:
        _download_missing_startup_files_if_needed(
            _SYNTH_PERSONA_REPO, _SYNTH_PERSONA_STARTUP_FILES, "SynthPersona"
        )
        return _cached_dataset(SynthPersonaDataset), "SynthPersona"

    nemotron_variants = (
        (
            DatasetSource.NEMOTRON_FRANCE.value,
            NemotronPersonasFranceDataset,
            "Nemotron France",
        ),
        (
            DatasetSource.NEMOTRON_USA.value,
            NemotronPersonasUSADataset,
            "Nemotron USA",
        ),
    )
    for source_value, dataset_cls, label in nemotron_variants:
        if dataset_source == source_value:
            _prepare_nemotron_startup_download(dataset_source, label)
            return _cached_dataset(dataset_cls), label

    # Any other source is a user upload; both files are required.
    if any(upload is None for upload in (personas_file, qa_file)):
        raise ValueError("Upload both personas.jsonl and qa.jsonl files")
    personas_path = _uploaded_file_to_temp_path(personas_file, stem="personas")
    qa_path = _uploaded_file_to_temp_path(qa_file, stem="qa")
    return _cached_local_dataset(str(personas_path), str(qa_path)), "Local upload"
def _is_cached(repo_id: str, filename: str) -> bool:
    """Report whether ``filename`` of dataset repo ``repo_id`` is in the local HF cache.

    ``try_to_load_from_cache`` returns a path string on a cache hit; any
    non-``str`` result is treated as "not cached".
    """
    return isinstance(
        try_to_load_from_cache(repo_id, filename, repo_type="dataset"), str
    )
def _download_missing_startup_files_if_needed(
    repo_id: str,
    filenames: tuple[str, ...],
    label: str,
) -> None:
    """Download any ``filenames`` not yet in the local HF cache, showing progress.

    This makes first-time Hub downloads visible before dataset construction
    blocks on them. Hugging Face handles byte-level transfer internally; this
    reports file-level progress, which is the unit the UI can know in advance.
    Nothing is rendered when every file is already cached.

    Args:
        repo_id: Hugging Face dataset repository id.
        filenames: File paths within the repository to ensure are cached.
        label: Human-readable dataset name used in the UI messages.

    Raises:
        Whatever ``hf_hub_download`` raises. The notice and progress bar are
        removed whether the download succeeds or fails.
    """
    missing = tuple(
        filename for filename in filenames if not _is_cached(repo_id, filename)
    )
    if not missing:
        return
    notice = st.empty()
    notice.warning(
        f"First-time setup for {label}: downloading dataset files from Hugging Face. "
        "Later loads should use the local cache."
    )
    progress = st.progress(0.0, text=f"Preparing {label} download…")
    total = len(missing)
    try:
        for index, filename in enumerate(missing, start=1):
            progress.progress(
                (index - 1) / total,
                text=f"Downloading {filename} ({index}/{total})",
            )
            hf_hub_download(repo_id, filename, repo_type="dataset")
            progress.progress(
                index / total,
                text=f"Downloaded {filename} ({index}/{total})",
            )
    finally:
        # Both elements are transient. Clear the progress bar as well as the
        # notice so neither lingers on the page after success or failure.
        progress.empty()
        notice.empty()
@st.cache_data(show_spinner=False)
def _first_nemotron_train_shard(repo_id: str) -> str | None:
    """Return the lexicographically first ``data/train-*.parquet`` file of ``repo_id``.

    Returns ``None`` if the repository has no such file. Memoized with
    ``st.cache_data`` so the Hub file listing (a network request) is not
    repeated on every Streamlit rerun that reloads a Nemotron dataset.
    """
    return min(
        (
            filename
            for filename in list_repo_files(repo_id, repo_type="dataset")
            if filename.startswith("data/train-") and filename.endswith(".parquet")
        ),
        default=None,
    )


def _prepare_nemotron_startup_download(dataset_source: str, label: str) -> None:
    """Prefetch the first parquet shard used by the default Nemotron sample.

    Args:
        dataset_source: A ``DatasetSource`` value present in ``_NEMOTRON_REPOS``.
        label: Human-readable dataset name shown in the download progress UI.

    Raises:
        KeyError: if ``dataset_source`` is not a Nemotron source.
    """
    repo_id = _NEMOTRON_REPOS[dataset_source]
    first_shard = _first_nemotron_train_shard(repo_id)
    if first_shard is not None:
        _download_missing_startup_files_if_needed(repo_id, (first_shard,), label)
|