"""Cached loaders for the two ELSI waves with encoding handling and value labels.

Performance note: pyreadstat.set_value_labels iterates every column and is
prohibitively slow on ELSI's 1000+ column files. We build our own labeled
view that maps only columns with declared value labels using vectorized
ops, then pickle the result so subsequent loads are < 1 second.
"""
from __future__ import annotations

import functools
import pickle
from pathlib import Path

import pandas as pd
import pyreadstat

from app.config import OUTPUT_DIR, WAVE_ENCODING, WAVE_FILES, WAVE_LABEL


def _label_columns(df: pd.DataFrame, value_labels: dict[str, dict]) -> pd.DataFrame:
    """Vectorized: replace numeric codes with their string labels for columns
    that have value labels declared. Other columns pass through unchanged."""
    out = df.copy()
    for col, mapping in value_labels.items():
        if col not in out.columns:
            continue
        s = out[col]
        if pd.api.types.is_numeric_dtype(s):
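            # Cast keys to float so lookups match the numeric codes whether the
            # column is stored as integers or floats.
            m = {float(k): str(v) for k, v in mapping.items()}
            mapped = s.map(m)
            # Codes without a declared label fall back to their string form;
            # missing values stay missing.
            fallback = s.where(s.isna(), s.astype("Float64").astype(str))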
            out[col] = mapped.where(mapped.notna() | s.isna(), fallback)
        else:
            m = {str(k): str(v) for k, v in mapping.items()}
            out[col] = s.map(m).fillna(s)
    return out


class WaveBundle:
    """One ELSI wave: the raw frame, its labeled view, and the label metadata."""

    def __init__(self, wave: int, df: pd.DataFrame, df_labeled: pd.DataFrame, meta_dict: dict):
        self.wave = wave
        self.df = df
        self.df_labeled = df_labeled
        self.meta = meta_dict
        self.label = WAVE_LABEL[wave]
        self.var_labels: dict[str, str] = meta_dict.get("var_labels", {})
        self.value_labels: dict[str, dict] = meta_dict.get("value_labels", {})

    @property
    def n(self) -> int:
        return len(self.df)

    def label_for(self, var: str) -> str:
        return self.var_labels.get(var, "") or ""

    def value_map(self, var: str) -> dict | None:
        return self.value_labels.get(var)

    def variables(self) -> list[str]:
        return list(self.df.columns)


def _cache_path(wave: int) -> Path:
    return OUTPUT_DIR / f"_cache_wave{wave}.pkl"


@functools.lru_cache(maxsize=4)
def load_wave(wave: int) -> WaveBundle:
    """Load one wave, using the pickle cache when it is newer than the source file.

    Results are also memoized in-process by ``lru_cache``, so repeated calls
    with the same wave return the same ``WaveBundle`` instance.
    """
    if wave not in WAVE_FILES:
        raise ValueError(f"Unknown wave {wave}")
    cache = _cache_path(wave)
    src = Path(WAVE_FILES[wave])
    if cache.exists() and cache.stat().st_mtime > src.stat().st_mtime:
        with open(cache, "rb") as f:
            obj = pickle.load(f)
        return WaveBundle(wave, obj["df"], obj["df_labeled"], obj["meta"])

    encoding = WAVE_ENCODING[wave]
    df, meta = pyreadstat.read_dta(str(src), encoding=encoding)
    var_labels = dict(meta.column_names_to_labels or {})
    value_labels = dict(meta.variable_value_labels or {})
    df_labeled = _label_columns(df, value_labels)

    meta_dict = {"var_labels": var_labels, "value_labels": value_labels}
    cache.parent.mkdir(parents=True, exist_ok=True)
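    # Write to a temporary file first so an interrupted dump never leaves a
    # truncated cache that looks fresh.
    tmp = cache.with_name(cache.name + ".tmp")
    with open(tmp, "wb") as f:
        pickle.dump({"df": df, "df_labeled": df_labeled, "meta": meta_dict}, f, protocol=4)
    tmp.replace(cache)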
    return WaveBundle(wave, df, df_labeled, meta_dict)


def load_all() -> dict[int, WaveBundle]:
    """Load every configured wave, keyed by wave number."""
    return {w: load_wave(w) for w in WAVE_FILES}