"""Dataset loading helpers for the Streamlit persona UI.

Covers the Hub-hosted SynthPersona and Nemotron datasets as well as
user-uploaded local files, and surfaces first-time downloads in the UI.
"""

import atexit
import hashlib
import shutil
import threading
from pathlib import Path
from tempfile import mkdtemp
from typing import Any

import streamlit as st
from huggingface_hub import hf_hub_download, list_repo_files, try_to_load_from_cache
from persona_data.nemotron_personas import (
    NemotronPersonasFranceDataset,
    NemotronPersonasUSADataset,
)
from persona_data.synth_persona import PersonaDataset as LocalPersonaDataset
from persona_data.synth_persona import SynthPersonaDataset

from .helpers import DatasetSource

_SYNTH_PERSONA_REPO = "implicit-personalization/synth-persona"
_SYNTH_PERSONA_STARTUP_FILES = (
    "implicit_shared_mc_bank.json",
    "dataset_personas.jsonl",
)
_SYNTH_PERSONA_QA_FILE = "dataset_qa.jsonl"
_NEMOTRON_REPOS = {
    DatasetSource.NEMOTRON_FRANCE.value: "nvidia/Nemotron-Personas-France",
    DatasetSource.NEMOTRON_USA.value: "nvidia/Nemotron-Personas-USA",
}

# How long a memoized Hub file listing stays valid, in seconds.
_REPO_LISTING_TTL_SECONDS = 3600


@st.cache_resource(show_spinner=False)
def _cached_dataset(cls: type) -> Any:
    """Instantiate ``cls`` once and reuse it through Streamlit's resource cache.

    ``st.cache_resource`` entries are shared across reruns and sessions, so
    every caller asking for the same class receives the same instance.
    """
    return cls()


_qa_warm_lock = threading.Lock()


def warm_qa_in_background(dataset: Any) -> None:
    """Trigger the dataset's lazy QA parse on a daemon thread, once.

    QA loading is deferred in persona-data (large, unused outside Extract).
    Kicking it off when the Extract tab opens means the parse overlaps with the
    user picking personas/options instead of blocking the first run. Idempotent
    across Streamlit reruns: guarded per cached dataset instance.
    """
    warm = getattr(dataset, "prefetch_qa", None)
    if warm is None:
        return  # persona-only dataset (e.g. Nemotron) has no QA

    if isinstance(dataset, SynthPersonaDataset):
        # Extract will need QA soon. Make the one-time large transfer explicit,
        # then leave the CPU-heavy parse on the existing background thread.
        _download_missing_startup_files_if_needed(
            _SYNTH_PERSONA_REPO,
            (_SYNTH_PERSONA_QA_FILE,),
            "SynthPersona QA",
        )

    # The flag lives on the dataset instance itself, so each cached dataset is
    # warmed at most once no matter how many reruns reach this point.
    with _qa_warm_lock:
        if getattr(dataset, "_qa_warm_started", False):
            return
        dataset._qa_warm_started = True

    threading.Thread(target=warm, name="persona-ui-warm-qa", daemon=True).start()


@st.cache_resource(show_spinner=False)
def _cached_local_dataset(personas_path: str, qa_path: str) -> LocalPersonaDataset:
    """Instantiate and cache a local upload dataset for stable temp paths."""
    return LocalPersonaDataset(personas_path=Path(personas_path), qa_path=Path(qa_path))


def _upload_cache_dir() -> Path:
    """Return this session's upload directory, creating it on first use."""
    cache_dir = st.session_state.get("_upload_cache_dir")
    if cache_dir is None:
        cache_dir = mkdtemp(prefix="persona_vectors_uploads_")
        st.session_state["_upload_cache_dir"] = cache_dir
        # Register cleanup so the temp dir is removed when the server process exits.
        atexit.register(shutil.rmtree, cache_dir, ignore_errors=True)
    return Path(cache_dir)


def _uploaded_file_to_temp_path(uploaded_file: Any, stem: str) -> Path:
    """Persist an uploaded file under a content-addressed name and return its path.

    The file name embeds the first 16 hex digits of the SHA-256 of the payload,
    so identical content maps to the same path and is written only once.
    """
    suffix = Path(uploaded_file.name).suffix or ".jsonl"
    data = uploaded_file.getvalue()
    digest = hashlib.sha256(data).hexdigest()
    temp_path = _upload_cache_dir() / f"{stem}_{digest[:16]}{suffix}"
    if temp_path.exists():
        return temp_path
    temp_path.write_bytes(data)
    return temp_path


def load_persona_list(
    dataset_source: str,
    personas_file: Any = None,
    qa_file: Any = None,
) -> tuple[list, str]:
    """Like ``load_dataset`` but returns ``(personas_list, status)``.

    The list is memoized on the cached dataset instance so repeated reruns
    don't pay for re-iteration.
    """
    dataset, status = load_dataset(dataset_source, personas_file, qa_file)
    return load_persona_list_from_dataset(dataset), status


def load_persona_list_from_dataset(dataset: Any) -> list:
    """Materialize and cache personas from an already-loaded dataset."""
    cached = getattr(dataset, "_persona_list_cache", None)
    if cached is None:
        cached = list(dataset)
        try:
            dataset._persona_list_cache = cached
        except (AttributeError, TypeError):
            # Best effort: a dataset that rejects new attributes is simply
            # re-iterated on the next call.
            pass
    return cached


def load_dataset(
    dataset_source: str,
    personas_file: Any = None,
    qa_file: Any = None,
) -> tuple[
    SynthPersonaDataset
    | NemotronPersonasFranceDataset
    | NemotronPersonasUSADataset
    | LocalPersonaDataset,
    str,
]:
    """Load the selected dataset source for the UI.

    Args:
        dataset_source: One of the ``DatasetSource`` values; anything that is
            not a Hub-backed source is treated as a local upload.
        personas_file: Uploaded personas file (local upload only).
        qa_file: Uploaded QA file (local upload only).

    Returns:
        ``(dataset, status)`` where ``status`` is a short human-readable label.

    Raises:
        ValueError: For a local upload when either file is missing.
    """
    if dataset_source == DatasetSource.SYNTH_PERSONA.value:
        _download_missing_startup_files_if_needed(
            _SYNTH_PERSONA_REPO,
            _SYNTH_PERSONA_STARTUP_FILES,
            "SynthPersona",
        )
        return _cached_dataset(SynthPersonaDataset), "SynthPersona"

    if dataset_source == DatasetSource.NEMOTRON_FRANCE.value:
        _prepare_nemotron_startup_download(dataset_source, "Nemotron France")
        return _cached_dataset(NemotronPersonasFranceDataset), "Nemotron France"

    if dataset_source == DatasetSource.NEMOTRON_USA.value:
        _prepare_nemotron_startup_download(dataset_source, "Nemotron USA")
        return _cached_dataset(NemotronPersonasUSADataset), "Nemotron USA"

    if personas_file is None or qa_file is None:
        raise ValueError("Upload both personas.jsonl and qa.jsonl files")

    personas_path = _uploaded_file_to_temp_path(personas_file, stem="personas")
    qa_path = _uploaded_file_to_temp_path(qa_file, stem="qa")
    return _cached_local_dataset(str(personas_path), str(qa_path)), "Local upload"


def _is_cached(repo_id: str, filename: str) -> bool:
    """Return whether a Hub dataset file already exists in the local HF cache."""
    cached = try_to_load_from_cache(repo_id, filename, repo_type="dataset")
    # Only a path string counts as a hit; any other result is treated as missing.
    return isinstance(cached, str)


def _download_missing_startup_files_if_needed(
    repo_id: str,
    filenames: tuple[str, ...],
    label: str,
) -> None:
    """Make first-time Hub downloads visible before dataset construction blocks.

    Hugging Face handles byte-level transfer internally. We expose file-level
    progress here, which is the useful unit this UI can know in advance.

    Args:
        repo_id: Hub dataset repository to download from.
        filenames: Files the caller needs; only those not yet cached are fetched.
        label: Human-readable dataset name used in the on-screen messages.

    Raises:
        Whatever ``hf_hub_download`` raises. The transient UI elements are
        cleared before the exception propagates.
    """
    missing = tuple(
        filename for filename in filenames if not _is_cached(repo_id, filename)
    )
    if not missing:
        return

    notice = st.empty()
    notice.warning(
        f"First-time setup for {label}: downloading dataset files from Hugging Face. "
        "Later loads should use the local cache."
    )
    progress = st.progress(0.0, text=f"Preparing {label} download…")
    total = len(missing)
    try:
        for index, filename in enumerate(missing, start=1):
            progress.progress(
                (index - 1) / total,
                text=f"Downloading {filename} ({index}/{total})",
            )
            hf_hub_download(repo_id, filename, repo_type="dataset")
            progress.progress(
                index / total,
                text=f"Downloaded {filename} ({index}/{total})",
            )
    except Exception:
        # A failed transfer must not leave a partly filled bar on the page
        # next to the error that Streamlit renders for the exception.
        progress.empty()
        raise
    finally:
        # The "downloading" notice is stale once the loop ends, either way.
        notice.empty()


@st.cache_data(show_spinner=False, ttl=_REPO_LISTING_TTL_SECONDS)
def _list_train_parquet_shards(repo_id: str) -> tuple[str, ...]:
    """Return the sorted ``data/train-*.parquet`` files of a Hub dataset repo.

    Memoized because ``list_repo_files`` is a Hub API request; without the
    cache it would be repeated on every call that prepares a Nemotron source.
    """
    return tuple(
        sorted(
            filename
            for filename in list_repo_files(repo_id, repo_type="dataset")
            if filename.startswith("data/train-") and filename.endswith(".parquet")
        )
    )


def _prepare_nemotron_startup_download(dataset_source: str, label: str) -> None:
    """Prefetch the first parquet shard used by the default Nemotron sample."""
    repo_id = _NEMOTRON_REPOS[dataset_source]
    parquet_files = _list_train_parquet_shards(repo_id)
    if parquet_files:
        _download_missing_startup_files_if_needed(repo_id, (parquet_files[0],), label)