"""Pre-process survey data and generate config artifacts.
Validates the raw CSV, applies the same data-cleaning steps used by
src/train.py, then writes:
- config/valid_categories.yaml — valid input values for runtime guardrails
- config/currency_rates.yaml — per-country median currency conversion rates
Run before ``make train`` to validate data and pre-generate configs, or
standalone to inspect what categories the current dataset supports.
"""
import sys
from pathlib import Path
import pandas as pd
import yaml
from src.train import (
CATEGORICAL_FEATURES,
apply_cardinality_reduction,
compute_currency_rates,
drop_other_rows,
extract_valid_categories,
filter_salaries,
)
# Columns that must exist in the raw survey CSV. ``validate_columns`` exits
# with an error if any is missing, and ``main`` loads exactly this subset
# via ``pd.read_csv(..., usecols=REQUIRED_COLUMNS)``.
REQUIRED_COLUMNS = [
    "Country",
    "YearsCode",
    "WorkExp",
    "EdLevel",
    "DevType",
    "Industry",
    "Age",
    "ICorPM",
    "OrgSize",
    "Employment",
    "Currency",
    "CompTotal",
    "ConvertedCompYearly",
]
def validate_columns(data_path: Path) -> None:
    """Exit with status 1 if any required column is absent from the CSV header.

    Only the header row is read (``nrows=0``), so this is cheap even for a
    large survey dump.
    """
    present = set(pd.read_csv(data_path, nrows=0).columns)
    missing = [name for name in REQUIRED_COLUMNS if name not in present]
    if missing:
        print(f"Error: missing required columns: {missing}")
        sys.exit(1)
    print(f"All {len(REQUIRED_COLUMNS)} required columns present.")
def print_category_summary(df: pd.DataFrame) -> None:
    """Print the number of unique categories per categorical feature."""
    for feature in CATEGORICAL_FEATURES:
        unique_count = df[feature].dropna().nunique()
        print(f" {feature}: {unique_count} categories")
def main() -> None:
    """Validate data, apply preprocessing, and write config artifacts."""

    def banner(title: str, *, gap: bool = True) -> None:
        # Print a boxed section header; every step after the first is
        # preceded by a blank line.
        rule = "=" * 60
        print("\n" + rule if gap else rule)
        print(title)
        print(rule)

    with Path("config/model_parameters.yaml").open() as fh:
        config = yaml.safe_load(fh)
    data_path = Path("data/survey_results_public.csv")

    # Step 1: make sure the raw CSV exists and carries every required column.
    banner("STEP 1 β Validate data file", gap=False)
    if not data_path.exists():
        print(f"Error: {data_path} not found.")
        print("Download from: https://insights.stackoverflow.com/survey")
        sys.exit(1)
    print(f"Checking columns in {data_path} ...")
    validate_columns(data_path)

    # Step 2: load only the needed columns, then drop out-of-range salaries.
    banner("STEP 2 β Load and filter salaries")
    survey = pd.read_csv(data_path, usecols=REQUIRED_COLUMNS)
    print(f"Loaded {len(survey):,} rows")
    survey = filter_salaries(survey, config)
    print(f"After salary filtering: {len(survey):,} rows")

    # Step 3: collapse rare categories; optionally drop rows whose value
    # became 'Other' in the configured columns.
    banner("STEP 3 β Cardinality reduction")
    survey = apply_cardinality_reduction(survey)
    rows_before = len(survey)
    survey = drop_other_rows(survey, config)
    drop_cols = config["features"]["cardinality"].get("drop_other_from", [])
    if drop_cols:
        print(f"Dropped {rows_before - len(survey):,} rows with 'Other' in {drop_cols}")
    print(f"Final dataset: {len(survey):,} rows")

    # Step 4: report how many categories each categorical feature ends up with.
    banner("STEP 4 β Category summary")
    print_category_summary(survey)

    # Step 5: persist the artifacts consumed at runtime.
    banner("STEP 5 β Write config artifacts")
    valid_categories = extract_valid_categories(survey)
    vc_path = Path("config/valid_categories.yaml")
    with vc_path.open("w") as fh:
        yaml.dump(valid_categories, fh, default_flow_style=False, sort_keys=False)
    value_total = sum(len(values) for values in valid_categories.values())
    print(f"Saved {vc_path} ({value_total} total valid values)")

    currency_rates = compute_currency_rates(survey, valid_categories["Country"])
    cr_path = Path("config/currency_rates.yaml")
    with cr_path.open("w") as fh:
        yaml.dump(
            currency_rates,
            fh,
            default_flow_style=False,
            sort_keys=True,
            allow_unicode=True,
        )
    print(f"Saved {cr_path} ({len(currency_rates)} countries)")

    print("\nPre-processing complete. Ready for `make train`.")
# Script entry point: run the full validate / preprocess / write pipeline.
if __name__ == "__main__":
    main()