"""Cached loaders for the two ELSI waves with encoding handling and value labels. Performance note: pyreadstat.set_value_labels iterates every column and is prohibitively slow on ELSI's 1000+ column files. We build our own labeled view that maps only columns with declared value labels using vectorized ops, then pickle the result so subsequent loads are < 1 second. """ from __future__ import annotations import functools import pickle import pandas as pd import pyreadstat from pathlib import Path from typing import Any from app.config import WAVE_FILES, WAVE_ENCODING, WAVE_LABEL, OUTPUT_DIR def _label_columns(df: pd.DataFrame, value_labels: dict[str, dict]) -> pd.DataFrame: """Vectorized: replace numeric codes with their string labels for columns that have value labels declared. Other columns pass through unchanged.""" out = df.copy() for col, mapping in value_labels.items(): if col not in out.columns: continue s = out[col] if pd.api.types.is_numeric_dtype(s): m = {float(k): str(v) for k, v in mapping.items()} mapped = s.map(m) # Where mapped is NaN but original is not, fall back to string fallback = s.where(s.isna(), s.astype("Float64").astype(str)) out[col] = mapped.where(mapped.notna() | s.isna(), fallback) else: m = {str(k): str(v) for k, v in mapping.items()} out[col] = s.map(m).fillna(s) return out class WaveBundle: def __init__(self, wave: int, df: pd.DataFrame, df_labeled: pd.DataFrame, meta_dict: dict): self.wave = wave self.df = df self.df_labeled = df_labeled self.meta = meta_dict self.label = WAVE_LABEL[wave] self.var_labels: dict[str, str] = meta_dict.get("var_labels", {}) self.value_labels: dict[str, dict] = meta_dict.get("value_labels", {}) @property def n(self) -> int: return len(self.df) def label_for(self, var: str) -> str: return self.var_labels.get(var, "") or "" def value_map(self, var: str) -> dict | None: return self.value_labels.get(var) def variables(self) -> list[str]: return list(self.df.columns) def _cache_path(wave: int) -> Path: return OUTPUT_DIR / f"_cache_wave{wave}.pkl" @functools.lru_cache(maxsize=4) def load_wave(wave: int) -> WaveBundle: if wave not in WAVE_FILES: raise ValueError(f"Unknown wave {wave}") cache = _cache_path(wave) src = Path(WAVE_FILES[wave]) if cache.exists() and cache.stat().st_mtime > src.stat().st_mtime: with open(cache, "rb") as f: obj = pickle.load(f) return WaveBundle(wave, obj["df"], obj["df_labeled"], obj["meta"]) encoding = WAVE_ENCODING[wave] df, meta = pyreadstat.read_dta(str(src), encoding=encoding) var_labels = dict(meta.column_names_to_labels or {}) value_labels = dict(meta.variable_value_labels or {}) df_labeled = _label_columns(df, value_labels) meta_dict = {"var_labels": var_labels, "value_labels": value_labels} cache.parent.mkdir(parents=True, exist_ok=True) with open(cache, "wb") as f: pickle.dump({"df": df, "df_labeled": df_labeled, "meta": meta_dict}, f, protocol=4) return WaveBundle(wave, df, df_labeled, meta_dict) def load_all() -> dict[int, WaveBundle]: return {w: load_wave(w) for w in WAVE_FILES}