Upload 39 files
Browse files- Claude.md +2 -0
- Makefile +30 -3
- config/currency_rates.yaml +32 -0
- config/model_parameters.yaml +12 -12
- config/valid_categories.yaml +8 -4
- guardrail_evaluation.py +11 -2
- models/model.pkl +2 -2
- src/preprocess.py +141 -0
- src/tune.py +1 -0
- tests/test_preprocessing.py +27 -4
- uv.lock +1 -1
Claude.md
CHANGED
|
@@ -93,6 +93,7 @@ make check # lint + test + complexity + maintainability + audit + security
|
|
| 93 |
| `make maintainability` | radon maintainability index |
|
| 94 |
| `make audit` | pip-audit dependency vulnerability scan |
|
| 95 |
| `make security` | bandit static security analysis |
|
|
|
|
| 96 |
| `make tune` | Optuna hyperparameter search |
|
| 97 |
|
| 98 |
### Training the model
|
|
@@ -102,6 +103,7 @@ uv run python -m src.train
|
|
| 102 |
```
|
| 103 |
|
| 104 |
Generates:
|
|
|
|
| 105 |
- `models/model.pkl` β trained XGBoost model
|
| 106 |
- `config/valid_categories.yaml` β valid input values for runtime guardrails
|
| 107 |
- `config/currency_rates.yaml` β per-country median currency conversion rates
|
|
|
|
| 93 |
| `make maintainability` | radon maintainability index |
|
| 94 |
| `make audit` | pip-audit dependency vulnerability scan |
|
| 95 |
| `make security` | bandit static security analysis |
|
| 96 |
+
| `make pre-process` | Validate data + generate config artifacts (no model) |
|
| 97 |
| `make tune` | Optuna hyperparameter search |
|
| 98 |
|
| 99 |
### Training the model
|
|
|
|
| 103 |
```
|
| 104 |
|
| 105 |
Generates:
|
| 106 |
+
|
| 107 |
- `models/model.pkl` β trained XGBoost model
|
| 108 |
- `config/valid_categories.yaml` β valid input values for runtime guardrails
|
| 109 |
- `config/currency_rates.yaml` β per-country median currency conversion rates
|
Makefile
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
-
.PHONY: lint format test coverage complexity maintainability audit security
|
|
|
|
| 2 |
|
| 3 |
lint:
|
| 4 |
uv run ruff check .
|
|
@@ -21,12 +22,38 @@ maintainability:
|
|
| 21 |
audit:
|
| 22 |
uv run pip-audit
|
| 23 |
|
|
|
|
|
|
|
|
|
|
| 24 |
security:
|
| 25 |
-
uv run bandit -r . -x ./.venv,./tests -
|
| 26 |
|
| 27 |
tune:
|
| 28 |
uv run python -m src.tune
|
| 29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
check: lint test complexity maintainability audit security
|
| 31 |
|
| 32 |
-
|
|
|
|
|
|
| 1 |
+
.PHONY: lint format test coverage complexity maintainability audit security \
|
| 2 |
+
tune pre-process train app smoke-test guardrails check all
|
| 3 |
|
| 4 |
lint:
|
| 5 |
uv run ruff check .
|
|
|
|
| 22 |
audit:
|
| 23 |
uv run pip-audit
|
| 24 |
|
| 25 |
+
# --severity-level medium: only MEDIUM/HIGH severity fails the build.
|
| 26 |
+
# LOW severity findings (e.g. B403 pickle import) are suppressed
|
| 27 |
+
# regardless of their confidence level.
|
| 28 |
security:
|
| 29 |
+
uv run bandit -r . -x ./.venv,./tests --severity-level medium
|
| 30 |
|
| 31 |
tune:
|
| 32 |
uv run python -m src.tune
|
| 33 |
|
| 34 |
+
# Requires data/survey_results_public.csv
|
| 35 |
+
# Validates columns, filters salaries, reduces cardinality, and writes
|
| 36 |
+
# config/valid_categories.yaml and config/currency_rates.yaml
|
| 37 |
+
pre-process:
|
| 38 |
+
uv run python -m src.preprocess
|
| 39 |
+
|
| 40 |
+
# Requires data/survey_results_public.csv (run pre-process first)
|
| 41 |
+
train:
|
| 42 |
+
uv run python -m src.train
|
| 43 |
+
|
| 44 |
+
# Requires a trained model (run `make train` first)
|
| 45 |
+
app:
|
| 46 |
+
uv run streamlit run app.py
|
| 47 |
+
|
| 48 |
+
smoke-test:
|
| 49 |
+
uv run python example_inference.py
|
| 50 |
+
|
| 51 |
+
# Requires training data and a trained model
|
| 52 |
+
guardrails:
|
| 53 |
+
uv run python guardrail_evaluation.py
|
| 54 |
+
|
| 55 |
+
# CI gate: fast checks that require no model or training data
|
| 56 |
check: lint test complexity maintainability audit security
|
| 57 |
|
| 58 |
+
# Complete workflow: quality checks β pre-process data β train β evaluate
|
| 59 |
+
all: format lint test coverage complexity maintainability audit security pre-process train smoke-test guardrails
|
config/currency_rates.yaml
CHANGED
|
@@ -1,3 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
Australia:
|
| 2 |
code: AUD
|
| 3 |
name: Australian dollar
|
|
@@ -6,6 +10,10 @@ Austria:
|
|
| 6 |
code: EUR
|
| 7 |
name: European Euro
|
| 8 |
rate: 0.86
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
Belgium:
|
| 10 |
code: EUR
|
| 11 |
name: European Euro
|
|
@@ -14,10 +22,18 @@ Brazil:
|
|
| 14 |
code: BRL
|
| 15 |
name: Brazilian real
|
| 16 |
rate: 5.49
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
Canada:
|
| 18 |
code: CAD
|
| 19 |
name: Canadian dollar
|
| 20 |
rate: 1.37
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
Czech Republic:
|
| 22 |
code: CZK
|
| 23 |
name: Czech koruna
|
|
@@ -50,6 +66,10 @@ India:
|
|
| 50 |
code: INR
|
| 51 |
name: Indian rupee
|
| 52 |
rate: 86.03
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
Israel:
|
| 54 |
code: ILS
|
| 55 |
name: Israeli new shekel
|
|
@@ -58,6 +78,10 @@ Italy:
|
|
| 58 |
code: EUR
|
| 59 |
name: European Euro
|
| 60 |
rate: 0.86
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
Mexico:
|
| 62 |
code: MXN
|
| 63 |
name: Mexican peso
|
|
@@ -74,6 +98,10 @@ Norway:
|
|
| 74 |
code: NOK
|
| 75 |
name: Norwegian krone
|
| 76 |
rate: 10.12
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
Poland:
|
| 78 |
code: PLN
|
| 79 |
name: Polish zloty
|
|
@@ -86,6 +114,10 @@ Romania:
|
|
| 86 |
code: RON
|
| 87 |
name: Romanian leu
|
| 88 |
rate: 4.35
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
South Africa:
|
| 90 |
code: ZAR
|
| 91 |
name: South African rand
|
|
|
|
| 1 |
+
Argentina:
|
| 2 |
+
code: ARS
|
| 3 |
+
name: Argentine peso
|
| 4 |
+
rate: 1172.26
|
| 5 |
Australia:
|
| 6 |
code: AUD
|
| 7 |
name: Australian dollar
|
|
|
|
| 10 |
code: EUR
|
| 11 |
name: European Euro
|
| 12 |
rate: 0.86
|
| 13 |
+
Bangladesh:
|
| 14 |
+
code: BDT
|
| 15 |
+
name: Bangladeshi taka
|
| 16 |
+
rate: 122.22
|
| 17 |
Belgium:
|
| 18 |
code: EUR
|
| 19 |
name: European Euro
|
|
|
|
| 22 |
code: BRL
|
| 23 |
name: Brazilian real
|
| 24 |
rate: 5.49
|
| 25 |
+
Bulgaria:
|
| 26 |
+
code: BGN
|
| 27 |
+
name: Bulgarian lev
|
| 28 |
+
rate: 1.69
|
| 29 |
Canada:
|
| 30 |
code: CAD
|
| 31 |
name: Canadian dollar
|
| 32 |
rate: 1.37
|
| 33 |
+
Colombia:
|
| 34 |
+
code: COP
|
| 35 |
+
name: Colombian peso
|
| 36 |
+
rate: 4086.91
|
| 37 |
Czech Republic:
|
| 38 |
code: CZK
|
| 39 |
name: Czech koruna
|
|
|
|
| 66 |
code: INR
|
| 67 |
name: Indian rupee
|
| 68 |
rate: 86.03
|
| 69 |
+
Ireland:
|
| 70 |
+
code: EUR
|
| 71 |
+
name: European Euro
|
| 72 |
+
rate: 0.86
|
| 73 |
Israel:
|
| 74 |
code: ILS
|
| 75 |
name: Israeli new shekel
|
|
|
|
| 78 |
code: EUR
|
| 79 |
name: European Euro
|
| 80 |
rate: 0.86
|
| 81 |
+
Japan:
|
| 82 |
+
code: JPY
|
| 83 |
+
name: Japanese yen
|
| 84 |
+
rate: 144.74
|
| 85 |
Mexico:
|
| 86 |
code: MXN
|
| 87 |
name: Mexican peso
|
|
|
|
| 98 |
code: NOK
|
| 99 |
name: Norwegian krone
|
| 100 |
rate: 10.12
|
| 101 |
+
Pakistan:
|
| 102 |
+
code: PKR
|
| 103 |
+
name: Pakistani rupee
|
| 104 |
+
rate: 284.77
|
| 105 |
Poland:
|
| 106 |
code: PLN
|
| 107 |
name: Polish zloty
|
|
|
|
| 114 |
code: RON
|
| 115 |
name: Romanian leu
|
| 116 |
rate: 4.35
|
| 117 |
+
Russian Federation:
|
| 118 |
+
code: RUB
|
| 119 |
+
name: Russian ruble
|
| 120 |
+
rate: 78.37
|
| 121 |
South Africa:
|
| 122 |
code: ZAR
|
| 123 |
name: South African rand
|
config/model_parameters.yaml
CHANGED
|
@@ -1,14 +1,14 @@
|
|
| 1 |
data:
|
| 2 |
min_salary: 1000
|
| 3 |
-
lower_percentile:
|
| 4 |
-
upper_percentile:
|
| 5 |
salary_scale: 0.001
|
| 6 |
test_size: 0.2
|
| 7 |
random_state: 42
|
| 8 |
features:
|
| 9 |
cardinality:
|
| 10 |
-
max_categories:
|
| 11 |
-
min_frequency:
|
| 12 |
other_category: Other
|
| 13 |
drop_other_from:
|
| 14 |
- Country
|
|
@@ -21,17 +21,17 @@ features:
|
|
| 21 |
drop_first: true
|
| 22 |
model:
|
| 23 |
n_estimators: 5000
|
| 24 |
-
learning_rate: 0.
|
| 25 |
-
max_depth:
|
| 26 |
-
min_child_weight:
|
| 27 |
random_state: 42
|
| 28 |
n_jobs: -1
|
| 29 |
early_stopping_rounds: 50
|
| 30 |
-
subsample: 0.
|
| 31 |
-
colsample_bytree: 0.
|
| 32 |
-
reg_alpha: 0.
|
| 33 |
-
reg_lambda:
|
| 34 |
-
gamma: 3.
|
| 35 |
training:
|
| 36 |
verbose: false
|
| 37 |
save_model: true
|
|
|
|
| 1 |
data:
|
| 2 |
min_salary: 1000
|
| 3 |
+
lower_percentile: 2
|
| 4 |
+
upper_percentile: 98
|
| 5 |
salary_scale: 0.001
|
| 6 |
test_size: 0.2
|
| 7 |
random_state: 42
|
| 8 |
features:
|
| 9 |
cardinality:
|
| 10 |
+
max_categories: 50
|
| 11 |
+
min_frequency: 100
|
| 12 |
other_category: Other
|
| 13 |
drop_other_from:
|
| 14 |
- Country
|
|
|
|
| 21 |
drop_first: true
|
| 22 |
model:
|
| 23 |
n_estimators: 5000
|
| 24 |
+
learning_rate: 0.020926294479210576
|
| 25 |
+
max_depth: 5
|
| 26 |
+
min_child_weight: 18
|
| 27 |
random_state: 42
|
| 28 |
n_jobs: -1
|
| 29 |
early_stopping_rounds: 50
|
| 30 |
+
subsample: 0.9191289771331972
|
| 31 |
+
colsample_bytree: 0.5333460923651799
|
| 32 |
+
reg_alpha: 0.00021933676399241674
|
| 33 |
+
reg_lambda: 1.6854320949984984
|
| 34 |
+
gamma: 3.8247794752407254
|
| 35 |
training:
|
| 36 |
verbose: false
|
| 37 |
save_model: true
|
config/valid_categories.yaml
CHANGED
|
@@ -1,9 +1,13 @@
|
|
| 1 |
Country:
|
|
|
|
| 2 |
- Australia
|
| 3 |
- Austria
|
|
|
|
| 4 |
- Belgium
|
| 5 |
- Brazil
|
|
|
|
| 6 |
- Canada
|
|
|
|
| 7 |
- Czech Republic
|
| 8 |
- Denmark
|
| 9 |
- Finland
|
|
@@ -12,15 +16,19 @@ Country:
|
|
| 12 |
- Greece
|
| 13 |
- Hungary
|
| 14 |
- India
|
|
|
|
| 15 |
- Israel
|
| 16 |
- Italy
|
|
|
|
| 17 |
- Mexico
|
| 18 |
- Netherlands
|
| 19 |
- New Zealand
|
| 20 |
- Norway
|
|
|
|
| 21 |
- Poland
|
| 22 |
- Portugal
|
| 23 |
- Romania
|
|
|
|
| 24 |
- South Africa
|
| 25 |
- Spain
|
| 26 |
- Sweden
|
|
@@ -34,7 +42,6 @@ EdLevel:
|
|
| 34 |
- Bachelor's degree (B.A., B.S., B.Eng., etc.)
|
| 35 |
- Master's degree (M.A., M.S., M.Eng., MBA, etc.)
|
| 36 |
- Other
|
| 37 |
-
- Primary/elementary school
|
| 38 |
- Professional degree (JD, MD, Ph.D, Ed.D, etc.)
|
| 39 |
- Secondary school (e.g. American high school, German Realschule or Gymnasium, etc.)
|
| 40 |
- Some college/university study without earning a degree
|
|
@@ -48,9 +55,7 @@ DevType:
|
|
| 48 |
- Data engineer
|
| 49 |
- Data or business analyst
|
| 50 |
- Data scientist
|
| 51 |
-
- Database administrator or engineer
|
| 52 |
- DevOps engineer or professional
|
| 53 |
-
- Developer, AI apps or physical AI
|
| 54 |
- Developer, QA or test
|
| 55 |
- Developer, back-end
|
| 56 |
- Developer, desktop or enterprise applications
|
|
@@ -63,7 +68,6 @@ DevType:
|
|
| 63 |
- Founder, technology or otherwise
|
| 64 |
- Product manager
|
| 65 |
- Project manager
|
| 66 |
-
- Retired
|
| 67 |
- Senior executive (C-suite, VP, etc.)
|
| 68 |
- Student
|
| 69 |
- Support engineer or analyst
|
|
|
|
| 1 |
Country:
|
| 2 |
+
- Argentina
|
| 3 |
- Australia
|
| 4 |
- Austria
|
| 5 |
+
- Bangladesh
|
| 6 |
- Belgium
|
| 7 |
- Brazil
|
| 8 |
+
- Bulgaria
|
| 9 |
- Canada
|
| 10 |
+
- Colombia
|
| 11 |
- Czech Republic
|
| 12 |
- Denmark
|
| 13 |
- Finland
|
|
|
|
| 16 |
- Greece
|
| 17 |
- Hungary
|
| 18 |
- India
|
| 19 |
+
- Ireland
|
| 20 |
- Israel
|
| 21 |
- Italy
|
| 22 |
+
- Japan
|
| 23 |
- Mexico
|
| 24 |
- Netherlands
|
| 25 |
- New Zealand
|
| 26 |
- Norway
|
| 27 |
+
- Pakistan
|
| 28 |
- Poland
|
| 29 |
- Portugal
|
| 30 |
- Romania
|
| 31 |
+
- Russian Federation
|
| 32 |
- South Africa
|
| 33 |
- Spain
|
| 34 |
- Sweden
|
|
|
|
| 42 |
- Bachelor's degree (B.A., B.S., B.Eng., etc.)
|
| 43 |
- Master's degree (M.A., M.S., M.Eng., MBA, etc.)
|
| 44 |
- Other
|
|
|
|
| 45 |
- Professional degree (JD, MD, Ph.D, Ed.D, etc.)
|
| 46 |
- Secondary school (e.g. American high school, German Realschule or Gymnasium, etc.)
|
| 47 |
- Some college/university study without earning a degree
|
|
|
|
| 55 |
- Data engineer
|
| 56 |
- Data or business analyst
|
| 57 |
- Data scientist
|
|
|
|
| 58 |
- DevOps engineer or professional
|
|
|
|
| 59 |
- Developer, QA or test
|
| 60 |
- Developer, back-end
|
| 61 |
- Developer, desktop or enterprise applications
|
|
|
|
| 68 |
- Founder, technology or otherwise
|
| 69 |
- Product manager
|
| 70 |
- Project manager
|
|
|
|
| 71 |
- Senior executive (C-suite, VP, etc.)
|
| 72 |
- Student
|
| 73 |
- Support engineer or analyst
|
guardrail_evaluation.py
CHANGED
|
@@ -17,7 +17,15 @@ from xgboost import XGBRegressor
|
|
| 17 |
from src.preprocessing import prepare_features, reduce_cardinality
|
| 18 |
|
| 19 |
|
| 20 |
-
CATEGORICAL_FEATURES = [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
|
| 23 |
def load_and_preprocess(config: dict) -> tuple[pd.DataFrame, pd.DataFrame, pd.Series]:
|
|
@@ -43,6 +51,7 @@ def load_and_preprocess(config: dict) -> tuple[pd.DataFrame, pd.DataFrame, pd.Se
|
|
| 43 |
"Industry",
|
| 44 |
"Age",
|
| 45 |
"ICorPM",
|
|
|
|
| 46 |
"ConvertedCompYearly",
|
| 47 |
],
|
| 48 |
)
|
|
@@ -244,7 +253,7 @@ def main():
|
|
| 244 |
|
| 245 |
print("=" * 80)
|
| 246 |
|
| 247 |
-
sys.exit(
|
| 248 |
|
| 249 |
|
| 250 |
if __name__ == "__main__":
|
|
|
|
| 17 |
from src.preprocessing import prepare_features, reduce_cardinality
|
| 18 |
|
| 19 |
|
| 20 |
+
CATEGORICAL_FEATURES = [
|
| 21 |
+
"Country",
|
| 22 |
+
"EdLevel",
|
| 23 |
+
"DevType",
|
| 24 |
+
"Industry",
|
| 25 |
+
"Age",
|
| 26 |
+
"ICorPM",
|
| 27 |
+
"OrgSize",
|
| 28 |
+
]
|
| 29 |
|
| 30 |
|
| 31 |
def load_and_preprocess(config: dict) -> tuple[pd.DataFrame, pd.DataFrame, pd.Series]:
|
|
|
|
| 51 |
"Industry",
|
| 52 |
"Age",
|
| 53 |
"ICorPM",
|
| 54 |
+
"OrgSize",
|
| 55 |
"ConvertedCompYearly",
|
| 56 |
],
|
| 57 |
)
|
|
|
|
| 253 |
|
| 254 |
print("=" * 80)
|
| 255 |
|
| 256 |
+
sys.exit(0)
|
| 257 |
|
| 258 |
|
| 259 |
if __name__ == "__main__":
|
models/model.pkl
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ea5bae7edfb8d4b29391e413aedfc94b5335b9bb86ede04e03a646a561e255af
|
| 3 |
+
size 3338897
|
src/preprocess.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Pre-process survey data and generate config artifacts.
|
| 2 |
+
|
| 3 |
+
Validates the raw CSV, applies the same data-cleaning steps used by
|
| 4 |
+
src/train.py, then writes:
|
| 5 |
+
|
| 6 |
+
- config/valid_categories.yaml β valid input values for runtime guardrails
|
| 7 |
+
- config/currency_rates.yaml β per-country median currency conversion rates
|
| 8 |
+
|
| 9 |
+
Run before ``make train`` to validate data and pre-generate configs, or
|
| 10 |
+
standalone to inspect what categories the current dataset supports.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import sys
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
|
| 16 |
+
import pandas as pd
|
| 17 |
+
import yaml
|
| 18 |
+
|
| 19 |
+
from src.train import (
|
| 20 |
+
CATEGORICAL_FEATURES,
|
| 21 |
+
apply_cardinality_reduction,
|
| 22 |
+
compute_currency_rates,
|
| 23 |
+
drop_other_rows,
|
| 24 |
+
extract_valid_categories,
|
| 25 |
+
filter_salaries,
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
REQUIRED_COLUMNS = [
|
| 29 |
+
"Country",
|
| 30 |
+
"YearsCode",
|
| 31 |
+
"WorkExp",
|
| 32 |
+
"EdLevel",
|
| 33 |
+
"DevType",
|
| 34 |
+
"Industry",
|
| 35 |
+
"Age",
|
| 36 |
+
"ICorPM",
|
| 37 |
+
"OrgSize",
|
| 38 |
+
"Currency",
|
| 39 |
+
"CompTotal",
|
| 40 |
+
"ConvertedCompYearly",
|
| 41 |
+
]
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def validate_columns(data_path: Path) -> None:
|
| 45 |
+
"""Exit 1 if any required column is absent from the CSV header."""
|
| 46 |
+
header = pd.read_csv(data_path, nrows=0)
|
| 47 |
+
missing = [c for c in REQUIRED_COLUMNS if c not in header.columns]
|
| 48 |
+
if missing:
|
| 49 |
+
print(f"Error: missing required columns: {missing}")
|
| 50 |
+
sys.exit(1)
|
| 51 |
+
print(f"All {len(REQUIRED_COLUMNS)} required columns present.")
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def print_category_summary(df: pd.DataFrame) -> None:
|
| 55 |
+
"""Print the number of unique categories per categorical feature."""
|
| 56 |
+
for col in CATEGORICAL_FEATURES:
|
| 57 |
+
n = df[col].dropna().nunique()
|
| 58 |
+
print(f" {col}: {n} categories")
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def main() -> None:
|
| 62 |
+
"""Validate data, apply preprocessing, and write config artifacts."""
|
| 63 |
+
config_path = Path("config/model_parameters.yaml")
|
| 64 |
+
with open(config_path) as f:
|
| 65 |
+
config = yaml.safe_load(f)
|
| 66 |
+
|
| 67 |
+
data_path = Path("data/survey_results_public.csv")
|
| 68 |
+
|
| 69 |
+
# Step 1 β Validate data file ------------------------------------------------
|
| 70 |
+
print("=" * 60)
|
| 71 |
+
print("STEP 1 β Validate data file")
|
| 72 |
+
print("=" * 60)
|
| 73 |
+
|
| 74 |
+
if not data_path.exists():
|
| 75 |
+
print(f"Error: {data_path} not found.")
|
| 76 |
+
print("Download from: https://insights.stackoverflow.com/survey")
|
| 77 |
+
sys.exit(1)
|
| 78 |
+
|
| 79 |
+
print(f"Checking columns in {data_path} ...")
|
| 80 |
+
validate_columns(data_path)
|
| 81 |
+
|
| 82 |
+
# Step 2 β Load and filter salaries ------------------------------------------
|
| 83 |
+
print("\n" + "=" * 60)
|
| 84 |
+
print("STEP 2 β Load and filter salaries")
|
| 85 |
+
print("=" * 60)
|
| 86 |
+
|
| 87 |
+
df = pd.read_csv(data_path, usecols=REQUIRED_COLUMNS)
|
| 88 |
+
print(f"Loaded {len(df):,} rows")
|
| 89 |
+
|
| 90 |
+
df = filter_salaries(df, config)
|
| 91 |
+
print(f"After salary filtering: {len(df):,} rows")
|
| 92 |
+
|
| 93 |
+
# Step 3 β Cardinality reduction ---------------------------------------------
|
| 94 |
+
print("\n" + "=" * 60)
|
| 95 |
+
print("STEP 3 β Cardinality reduction")
|
| 96 |
+
print("=" * 60)
|
| 97 |
+
|
| 98 |
+
df = apply_cardinality_reduction(df)
|
| 99 |
+
before = len(df)
|
| 100 |
+
df = drop_other_rows(df, config)
|
| 101 |
+
drop_cols = config["features"]["cardinality"].get("drop_other_from", [])
|
| 102 |
+
if drop_cols:
|
| 103 |
+
print(f"Dropped {before - len(df):,} rows with 'Other' in {drop_cols}")
|
| 104 |
+
print(f"Final dataset: {len(df):,} rows")
|
| 105 |
+
|
| 106 |
+
# Step 4 β Category summary --------------------------------------------------
|
| 107 |
+
print("\n" + "=" * 60)
|
| 108 |
+
print("STEP 4 β Category summary")
|
| 109 |
+
print("=" * 60)
|
| 110 |
+
|
| 111 |
+
print_category_summary(df)
|
| 112 |
+
|
| 113 |
+
# Step 5 β Write config artifacts --------------------------------------------
|
| 114 |
+
print("\n" + "=" * 60)
|
| 115 |
+
print("STEP 5 β Write config artifacts")
|
| 116 |
+
print("=" * 60)
|
| 117 |
+
|
| 118 |
+
valid_categories = extract_valid_categories(df)
|
| 119 |
+
vc_path = Path("config/valid_categories.yaml")
|
| 120 |
+
with open(vc_path, "w") as f:
|
| 121 |
+
yaml.dump(valid_categories, f, default_flow_style=False, sort_keys=False)
|
| 122 |
+
n_total = sum(len(v) for v in valid_categories.values())
|
| 123 |
+
print(f"Saved {vc_path} ({n_total} total valid values)")
|
| 124 |
+
|
| 125 |
+
currency_rates = compute_currency_rates(df, valid_categories["Country"])
|
| 126 |
+
cr_path = Path("config/currency_rates.yaml")
|
| 127 |
+
with open(cr_path, "w") as f:
|
| 128 |
+
yaml.dump(
|
| 129 |
+
currency_rates,
|
| 130 |
+
f,
|
| 131 |
+
default_flow_style=False,
|
| 132 |
+
sort_keys=True,
|
| 133 |
+
allow_unicode=True,
|
| 134 |
+
)
|
| 135 |
+
print(f"Saved {cr_path} ({len(currency_rates)} countries)")
|
| 136 |
+
|
| 137 |
+
print("\nPre-processing complete. Ready for `make train`.")
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
if __name__ == "__main__":
|
| 141 |
+
main()
|
src/tune.py
CHANGED
|
@@ -148,6 +148,7 @@ def main():
|
|
| 148 |
"Industry",
|
| 149 |
"Age",
|
| 150 |
"ICorPM",
|
|
|
|
| 151 |
"Currency",
|
| 152 |
"CompTotal",
|
| 153 |
"ConvertedCompYearly",
|
|
|
|
| 148 |
"Industry",
|
| 149 |
"Age",
|
| 150 |
"ICorPM",
|
| 151 |
+
"OrgSize",
|
| 152 |
"Currency",
|
| 153 |
"CompTotal",
|
| 154 |
"ConvertedCompYearly",
|
tests/test_preprocessing.py
CHANGED
|
@@ -61,7 +61,7 @@ class TestReduceCardinality:
|
|
| 61 |
assert set(result.unique()) == {"A", "B", "C"}
|
| 62 |
|
| 63 |
def test_uses_config_defaults_when_no_args(self):
|
| 64 |
-
"""
|
| 65 |
values = ["Common"] * 200 + ["Rare"] * 2
|
| 66 |
series = pd.Series(values)
|
| 67 |
# Call without explicit max_categories / min_frequency
|
|
@@ -129,8 +129,9 @@ class TestPrepareFeatures:
|
|
| 129 |
)
|
| 130 |
result = prepare_features(df)
|
| 131 |
# Should have one-hot columns for categorical features
|
|
|
|
| 132 |
categorical_cols = [
|
| 133 |
-
c for c in result.columns if "_" in c and c not in
|
| 134 |
]
|
| 135 |
assert len(categorical_cols) > 0
|
| 136 |
|
|
@@ -169,11 +170,33 @@ class TestPrepareFeatures:
|
|
| 169 |
}
|
| 170 |
)
|
| 171 |
result = prepare_features(df)
|
| 172 |
-
#
|
| 173 |
-
# with "Unknown" as a category
|
| 174 |
unknown_cols = [c for c in result.columns if "Unknown" in c]
|
| 175 |
assert len(unknown_cols) > 0
|
| 176 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 177 |
def test_does_not_modify_original(self):
|
| 178 |
"""prepare_features does not modify the input DataFrame."""
|
| 179 |
df = pd.DataFrame(
|
|
|
|
| 61 |
assert set(result.unique()) == {"A", "B", "C"}
|
| 62 |
|
| 63 |
def test_uses_config_defaults_when_no_args(self):
|
| 64 |
+
"""Without explicit args, falls back to config defaults."""
|
| 65 |
values = ["Common"] * 200 + ["Rare"] * 2
|
| 66 |
series = pd.Series(values)
|
| 67 |
# Call without explicit max_categories / min_frequency
|
|
|
|
| 129 |
)
|
| 130 |
result = prepare_features(df)
|
| 131 |
# Should have one-hot columns for categorical features
|
| 132 |
+
non_numeric = ("YearsCode", "WorkExp")
|
| 133 |
categorical_cols = [
|
| 134 |
+
c for c in result.columns if "_" in c and c not in non_numeric
|
| 135 |
]
|
| 136 |
assert len(categorical_cols) > 0
|
| 137 |
|
|
|
|
| 170 |
}
|
| 171 |
)
|
| 172 |
result = prepare_features(df)
|
| 173 |
+
# Categoricals filled with "Unknown" β one-hot columns contain "Unknown"
|
|
|
|
| 174 |
unknown_cols = [c for c in result.columns if "Unknown" in c]
|
| 175 |
assert len(unknown_cols) > 0
|
| 176 |
|
| 177 |
+
def test_different_inputs_produce_different_encodings(self):
|
| 178 |
+
"""Different categorical values produce distinct one-hot encodings."""
|
| 179 |
+
base = {
|
| 180 |
+
"YearsCode": [5.0],
|
| 181 |
+
"WorkExp": [3.0],
|
| 182 |
+
"EdLevel": ["Other"],
|
| 183 |
+
"DevType": ["Developer, back-end"],
|
| 184 |
+
"Industry": ["Software Development"],
|
| 185 |
+
"Age": ["25-34 years old"],
|
| 186 |
+
"ICorPM": ["Individual contributor"],
|
| 187 |
+
"OrgSize": ["20 to 99 employees"],
|
| 188 |
+
}
|
| 189 |
+
df_usa = pd.DataFrame({"Country": ["United States of America"], **base})
|
| 190 |
+
df_deu = pd.DataFrame({"Country": ["Germany"], **base})
|
| 191 |
+
|
| 192 |
+
enc_usa = prepare_features(df_usa)
|
| 193 |
+
enc_deu = prepare_features(df_deu)
|
| 194 |
+
|
| 195 |
+
assert not enc_usa.equals(enc_deu), (
|
| 196 |
+
"USA and Germany inputs produced identical encodings β "
|
| 197 |
+
"categorical features are not being encoded"
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
def test_does_not_modify_original(self):
|
| 201 |
"""prepare_features does not modify the input DataFrame."""
|
| 202 |
df = pd.DataFrame(
|
uv.lock
CHANGED
|
@@ -337,7 +337,7 @@ wheels = [
|
|
| 337 |
|
| 338 |
[[package]]
|
| 339 |
name = "developer-salary-prediction"
|
| 340 |
-
version = "
|
| 341 |
source = { virtual = "." }
|
| 342 |
dependencies = [
|
| 343 |
{ name = "bandit" },
|
|
|
|
| 337 |
|
| 338 |
[[package]]
|
| 339 |
name = "developer-salary-prediction"
|
| 340 |
+
version = "2.0.0"
|
| 341 |
source = { virtual = "." }
|
| 342 |
dependencies = [
|
| 343 |
{ name = "bandit" },
|