dima806 committed
Commit 1a584f9 · verified · 1 parent: 561c0bd

Upload 38 files

.github/workflows/ci.yml ADDED
@@ -0,0 +1,29 @@
+ name: CI
+
+ on:
+   push:
+     branches: ["**"]
+
+ jobs:
+   lint-and-test:
+     runs-on: ubuntu-latest
+
+     steps:
+       - uses: actions/checkout@v4
+
+       - name: Set up Python
+         uses: actions/setup-python@v5
+         with:
+           python-version: "3.12"
+
+       - name: Install uv
+         uses: astral-sh/setup-uv@v5
+
+       - name: Install dependencies
+         run: uv sync --all-extras
+
+       - name: Lint
+         run: make lint
+
+       - name: Test
+         run: make test
.gitignore CHANGED
@@ -217,4 +217,4 @@ data/*.zip
  # models/*.joblib

  # LLM
- .llm/
+ .llm/
.pre-commit-config.yaml ADDED
@@ -0,0 +1,31 @@
+ repos:
+   - repo: https://github.com/pre-commit/pre-commit-hooks
+     rev: v5.0.0
+     hooks:
+       - id: trailing-whitespace
+       - id: end-of-file-fixer
+       - id: mixed-line-ending
+         args: [--fix=lf]
+       - id: check-yaml
+       - id: check-toml
+       - id: check-json
+       - id: check-added-large-files
+         args: [--maxkb=1000]
+       - id: check-merge-conflict
+       - id: debug-statements
+
+   - repo: local
+     hooks:
+       - id: format
+         name: ruff format
+         entry: make format
+         language: system
+         types: [python]
+         pass_filenames: false
+
+       - id: lint
+         name: ruff lint
+         entry: make lint
+         language: system
+         types: [python]
+         pass_filenames: false
Claude.md CHANGED
@@ -1,229 +1,303 @@
  # Claude Development Guide

  ## Project Overview
- This is a minimal, local-first ML application built in Python that predicts developer salaries using Stack Overflow Developer Survey data. The project emphasizes clarity and simplicity over production completeness.

  ## Tech Stack
- - **Python 3.11+**
- - **uv** - Package & virtual environment management
- - **pandas** - Data manipulation
- - **scikit-learn** - ML modeling
- - **pydantic** - Input validation
- - **streamlit** - Web UI
- - **xgboost** - Advanced gradient boosting (optional)

  ## Project Structure
- ```
  .
  ├── data/
- │   └── survey_results_public.csv  # Stack Overflow survey data
  ├── models/
- │   └── model.pkl                  # Serialized trained model
  ├── src/
- │   ├── schema.py                  # Pydantic validation models
- │   ├── train.py                   # Model training script
- │   └── infer.py                   # Inference utilities
- ├── app.py                         # Streamlit web application
- ├── example_inference.py           # Example inference script
- ├── pyproject.toml                 # Project dependencies (uv)
- ├── uv.lock                        # Locked dependencies
- └── README.md                      # Project documentation
  ```

  ## Setup & Installation

- ### Initial Setup
  ```bash
- # The virtual environment is already created at .venv/
- # Activate it:
- source .venv/bin/activate  # On Linux/Mac
- # or
- .venv\Scripts\activate  # On Windows
-
- # Install/sync dependencies with uv:
  uv sync
  ```

- ### Adding New Dependencies
  ```bash
- uv add <package-name>
  ```

- ## Key Workflows

- ### Training the Model
  ```bash
- python src/train.py
  ```
- This will:
- - Load data from `data/survey_results_public.csv`
- - Clean and preprocess features
- - Train the regression model
- - Save model to `models/model.pkl`

- ### Running the Streamlit App
  ```bash
- streamlit run app.py
  ```
- Opens a browser interface for salary predictions.

- ### Running Inference Programmatically
  ```python
  from src.schema import SalaryInput
  from src.infer import predict_salary

  input_data = SalaryInput(
-     country="United States",
      years_code=5.0,
-     education_level="Bachelor's degree",
-     dev_type="Developer, back-end",
      industry="Software Development",
-     age="25-34 years old"
  )
  salary = predict_salary(input_data)
  ```

  ## Key Files

  ### [src/schema.py](src/schema.py)
- Contains Pydantic models for:
- - Input validation (`SalaryInput`)
- - Type safety across the application

  ### [src/train.py](src/train.py)
- Training pipeline:
- - Data loading and cleaning
- - Feature engineering
- - Model training
- - Model persistence

  ### [src/infer.py](src/infer.py)
- Inference utilities:
- - Model loading
- - Prediction logic
- - Validated input processing

  ### [app.py](app.py)
- Streamlit UI:
- - User input forms
- - Real-time predictions
- - Results visualization
-
- ## Development Guidelines
-
- ### Code Style
- - Keep code simple and readable
- - Total codebase should remain under ~200 lines
- - Focus on clarity over cleverness
- - Use type hints where helpful
-
- ### Data Requirements
- The dataset must include these columns:
- - `Country` - Developer location
- - `YearsCode` - Total years of coding (including education)
- - `EdLevel` - Education level
- - `DevType` - Developer type
- - `Industry` - Industry the developer works in
- - `Age` - Developer's age range
- - `ConvertedCompYearly` - Annual salary (target variable)
-
- ### Model Expectations
- - Basic regression model (LinearRegression or similar)
- - Simple feature encoding (one-hot for categoricals)
- - No hyperparameter tuning required
- - Focus on working end-to-end pipeline
-
- ## Common Tasks
-
- ### Debugging Training Issues
- 1. Check if data file exists: `ls -la data/`
- 2. Verify CSV columns: `head -1 data/survey_results_public.csv`
- 3. Check for missing values in target column
- 4. Review data types and encoding
-
- ### Updating Features
- 1. Modify `SalaryInput` schema in [src/schema.py](src/schema.py)
- 2. Update feature extraction in [src/train.py](src/train.py)
- 3. Update inference logic in [src/infer.py](src/infer.py)
- 4. Update UI inputs in [app.py](app.py)
- 5. Retrain the model
-
- ### Testing Predictions
- ```python
- # Quick test in Python REPL
- from src.infer import predict_salary
- from src.schema import SalaryInput

- test_input = SalaryInput(
-     country="United States",
-     years_code=3.0,
-     education_level="Bachelor's degree",
-     dev_type="Developer, back-end",
-     industry="Software Development",
-     age="25-34 years old"
- )
- print(predict_salary(test_input))
- ```

- ## Non-Goals (Intentionally Excluded)
- - Cloud deployment or serving
- - Hyperparameter tuning
- - Model registry or experiment tracking
- - Advanced feature engineering
- - Production monitoring
- - API endpoints (beyond Streamlit)

- ## Useful Commands

- ```bash
- # Check environment
- which python
- python --version

- # Verify uv installation
- uv --version

- # List installed packages
- uv pip list

- # Run with specific Python version
- uv run python src/train.py

- # Clean generated files
- rm -f models/model.pkl

- # Check data file size
- du -h data/survey_results_public.csv
  ```

- ## Troubleshooting

- ### Model file not found
- - Run training first: `python src/train.py`
- - Check file exists: `ls -la models/model.pkl`

- ### Missing dependencies
- - Sync environment: `uv sync`
- - Verify pyproject.toml has all required packages

- ### Data file issues
- - Ensure CSV is in `data/` directory
- - Check file encoding (should be UTF-8)
- - Verify required columns exist

- ### Streamlit won't start
- - Check port 8501 is available
- - Try specifying port: `streamlit run app.py --server.port 8502`

  ## Additional Resources
- - [PRD](.llm/prd.md) - Full product requirements
- - [README.md](README.md) - Project readme
- - [Stack Overflow Survey](https://insights.stackoverflow.com/survey) - Data source
-
- ## Working with Claude Code
- When asking Claude to help with this project:
- - Reference specific files using markdown links: [filename](path)
- - Be specific about which component needs changes
- - Mention if you need training, inference, or UI updates
- - Provide error messages in full when debugging
- - Ask for explanations of model choices if unclear
  # Claude Development Guide

  ## Project Overview
+
+ A minimal, local-first ML application that predicts developer salaries using Stack Overflow
+ Developer Survey data. Built with Python 3.12, XGBoost, Pydantic v2, and Streamlit. Emphasises
+ clarity and simplicity over production completeness.

  ## Tech Stack
+
+ - **Python 3.12+**
+ - **uv** — Package & virtual environment management
+ - **pandas** — Data manipulation
+ - **xgboost** — Gradient boosting model (primary model)
+ - **scikit-learn** — Cross-validation and train/test split
+ - **optuna** — Hyperparameter optimisation
+ - **pydantic** — Input schema validation (v2)
+ - **streamlit** — Web UI
+ - **ruff** — Linting and formatting
+ - **radon** — Cyclomatic complexity and maintainability metrics
+ - **bandit** — Static security analysis
+ - **pip-audit** — Dependency vulnerability scanning
+ - **pre-commit** — Git hook management

  ## Project Structure
+
+ ```text
  .
+ ├── .github/
+ │   └── workflows/
+ │       └── ci.yml                 # GitHub Actions CI (lint + test on every push)
+ ├── config/
+ │   ├── model_parameters.yaml      # Model config, guardrails, cardinality settings
+ │   ├── optuna_config.yaml         # Optuna search space and trial settings
+ │   ├── valid_categories.yaml      # Valid input categories (generated by training)
+ │   └── currency_rates.yaml        # Per-country currency rates (generated by training)
  ├── data/
+ │   └── survey_results_public.csv  # Stack Overflow survey data (download required)
  ├── models/
+ │   └── model.pkl                  # Trained model artifact (generated by training)
  ├── src/
+ │   ├── __init__.py
+ │   ├── schema.py                  # Pydantic input model (SalaryInput)
+ │   ├── preprocessing.py           # Feature engineering (one-hot encoding, scaling)
+ │   ├── train.py                   # Training pipeline
+ │   ├── tune.py                    # Optuna hyperparameter optimisation
+ │   └── infer.py                   # Inference with runtime category guardrails
+ ├── tests/
+ │   ├── conftest.py                # Shared pytest fixtures
+ │   ├── test_schema.py             # Pydantic validation tests
+ │   ├── test_infer.py              # Inference and guardrail tests
+ │   ├── test_train.py              # Training pipeline helper tests
+ │   ├── test_preprocessing.py      # Feature engineering tests
+ │   ├── test_tune.py               # Optuna tuning tests
+ │   └── test_feature_impact.py     # Model sanity — each feature affects predictions
+ ├── app.py                         # Streamlit web app
+ ├── example_inference.py           # Programmatic usage examples
+ ├── Makefile                       # Developer workflow commands
+ ├── .pre-commit-config.yaml        # Pre-commit hooks (format, lint, standard checks)
+ ├── pyproject.toml                 # Project metadata and dependencies (uv)
+ └── README.md                      # Project documentation + HuggingFace Space config
  ```

  ## Setup & Installation

  ```bash
+ # Install dependencies
  uv sync
+
+ # Install pre-commit hooks (once, after cloning)
+ uv run pre-commit install
  ```

+ ## Key Workflows
+
+ All common tasks are available via the Makefile. Run `make help` to list targets.
+
+ ### Full quality check
+
  ```bash
+ make check  # lint + test + complexity + maintainability + audit + security
  ```

+ ### Individual targets
+
+ | Target | What it does |
+ | ------ | ------------ |
+ | `make lint` | ruff check (style + errors) |
+ | `make format` | ruff format (auto-format) |
+ | `make test` | pytest — all tests |
+ | `make coverage` | pytest with HTML coverage report |
+ | `make complexity` | radon cyclomatic complexity |
+ | `make maintainability` | radon maintainability index |
+ | `make audit` | pip-audit dependency vulnerability scan |
+ | `make security` | bandit static security analysis |
+ | `make tune` | Optuna hyperparameter search |
+
+ ### Training the model
+
+ ```bash
+ uv run python -m src.train
+ ```
+
+ Generates:
+ - `models/model.pkl` — trained XGBoost model
+ - `config/valid_categories.yaml` — valid input values for runtime guardrails
+ - `config/currency_rates.yaml` — per-country median currency conversion rates
+
+ ### Hyperparameter tuning (optional, run before training)

  ```bash
+ make tune
+ # or
+ uv run python -m src.tune --n-trials 50
  ```

+ Reads search space from `config/optuna_config.yaml`, writes best parameters back into
+ `config/model_parameters.yaml`.
+
+ ### Running the Streamlit app
+
  ```bash
+ uv run streamlit run app.py
  ```

+ ### Running inference programmatically
+
  ```python
  from src.schema import SalaryInput
  from src.infer import predict_salary

  input_data = SalaryInput(
+     country="United States of America",
      years_code=5.0,
+     work_exp=3.0,
+     education_level="Bachelor's degree (B.A., B.S., B.Eng., etc.)",
+     dev_type="Developer, full-stack",
      industry="Software Development",
+     age="25-34 years old",
+     ic_or_pm="Individual contributor",
+     org_size="20 to 99 employees",
  )
  salary = predict_salary(input_data)
  ```

+ Valid values for each categorical field are listed in `config/valid_categories.yaml`
+ (generated at training time).
+
+ ## Data Requirements
+
+ The `survey_results_public.csv` must include these columns:
+
+ | Column | Description |
+ | ------ | ----------- |
+ | `Country` | Developer's country of residence |
+ | `YearsCode` | Total years coding (including education) |
+ | `WorkExp` | Years of professional work experience |
+ | `EdLevel` | Highest education level |
+ | `DevType` | Primary developer role |
+ | `Industry` | Industry the developer works in |
+ | `Age` | Age range |
+ | `ICorPM` | Individual contributor or people manager |
+ | `OrgSize` | Organisation size (number of employees) |
+ | `ConvertedCompYearly` | Annual salary in USD (target variable) |
+
+ ## Input Validation (Two Layers)
+
+ ### Layer 1 — Pydantic schema (`src/schema.py`)
+
+ All 9 fields are required. `years_code` and `work_exp` must be `>= 0`. Validated at
+ object construction time — raises `ValidationError` on failure.
+
+ ### Layer 2 — Runtime guardrails (`src/infer.py`)
+
+ Each categorical field is checked against `config/valid_categories.yaml` at inference
+ time. Raises `ValueError` with a clear message on invalid input.
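The Layer 2 check can be sketched in a few lines (illustrative category values only — the real sets are generated into `config/valid_categories.yaml` by training, and the actual helper names in `src/infer.py` may differ):

```python
# Layer 2 guardrail sketch. Illustrative values only — the real code
# loads these sets from config/valid_categories.yaml at inference time.
VALID_CATEGORIES = {
    "country": {"United States of America", "Germany"},
    "age": {"25-34 years old", "35-44 years old"},
}


def check_category(field: str, value: str) -> None:
    """Raise ValueError if `value` was never seen during training."""
    if value not in VALID_CATEGORIES[field]:
        raise ValueError(
            f"Invalid {field}: {value!r}. "
            "See config/valid_categories.yaml for accepted values."
        )


check_category("country", "Germany")  # a known category passes silently
```

A check like this runs for every categorical field before the model input is built.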
+
  ## Key Files

  ### [src/schema.py](src/schema.py)
+
+ Pydantic v2 `SalaryInput` model — defines all 9 required input fields, types, and
+ constraints. The JSON schema example in the docstring is the canonical usage example.
+
+ ### [src/preprocessing.py](src/preprocessing.py)
+
+ `prepare_features(df)` — takes a raw DataFrame and returns an encoded feature matrix:
+
+ - Unicode apostrophe normalisation on all categorical columns
+ - Rare category → "Other" normalisation
+ - Missing numeric values → 0
+ - Missing categoricals → "Unknown"
+ - One-hot encoding (training uses `drop_first=True`; inference uses `drop_first=False`
+   then reindexes to match training columns)
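The train/inference column alignment in that last bullet can be sketched with pandas (toy column names, not the project's real feature set):

```python
import pandas as pd

# Training side: one-hot encode with drop_first=True and remember columns.
train = pd.DataFrame({"country": ["US", "DE", "US"], "years_code": [5, 8, 2]})
train_X = pd.get_dummies(train, columns=["country"], drop_first=True)
train_columns = list(train_X.columns)  # persisted alongside the model

# Inference side: a single row only produces dummies for its own values...
row = pd.DataFrame({"country": ["DE"], "years_code": [3]})
row_X = pd.get_dummies(row, columns=["country"], drop_first=False)

# ...so reindex to the training columns, filling absent dummies with 0
# (this also drops any dummy the training set never produced).
row_X = row_X.reindex(columns=train_columns, fill_value=0)
```

Without the `reindex`, the inference matrix would have different columns than the model was trained on.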

  ### [src/train.py](src/train.py)
+
+ Full training pipeline:
+
+ 1. Load CSV, filter salaries (min/max percentile per country)
+ 2. `apply_cardinality_reduction` — collapse rare categories to "Other"
+ 3. `drop_other_rows` — remove rows with "Other" in specified columns
+ 4. `prepare_features` — encode features
+ 5. 5-fold CV with MAPE metric
+ 6. Train final XGBoost model on full data with early stopping
+ 7. Save `models/model.pkl`, `config/valid_categories.yaml`, `config/currency_rates.yaml`
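The CV metric in step 5 is mean absolute percentage error; a plain-Python sketch (the pipeline itself may use the scikit-learn equivalent):

```python
def mape(y_true, y_pred):
    """Mean absolute percentage error, in percent."""
    errors = [abs(t - p) / abs(t) for t, p in zip(y_true, y_pred)]
    return 100.0 * sum(errors) / len(errors)


# Two predictions, each off by 10% of the true salary.
print(mape([100_000, 50_000], [110_000, 45_000]))  # 10.0
```

MAPE is a natural fit here because salaries span orders of magnitude across countries, so a relative error is more comparable than an absolute one.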
+
+ ### [src/tune.py](src/tune.py)
+
+ Optuna study over the search space in `config/optuna_config.yaml`. Writes best
+ parameters back into `config/model_parameters.yaml` after the study completes.

  ### [src/infer.py](src/infer.py)
+
+ `predict_salary(SalaryInput)` validates categories, builds a single-row DataFrame,
+ runs `prepare_features`, reindexes to training columns, returns float USD salary.
+ `get_local_currency(country, salary)` converts to local currency using rates from
+ `config/currency_rates.yaml`.
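The currency step amounts to a per-country rate lookup; a sketch with a hypothetical helper and made-up rates (the real values are generated into `config/currency_rates.yaml`, and `get_local_currency`'s exact return type may differ):

```python
# Hypothetical (currency, rate) table for illustration only; the real
# per-country median rates are written by training.
RATES = {
    "United States of America": ("USD", 1.0),
    "Germany": ("EUR", 0.9),
}


def to_local(country: str, salary_usd: float) -> str:
    # Fall back to USD for countries without a stored rate.
    currency, rate = RATES.get(country, ("USD", 1.0))
    return f"{salary_usd * rate:,.0f} {currency}"


print(to_local("Germany", 80_000))  # 72,000 EUR
```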

  ### [app.py](app.py)

+ Streamlit UI with dropdowns populated from `config/valid_categories.yaml` (only values
+ that appeared in the training data). Shows USD + local currency side by side.

+ ## Updating Features

+ When adding a new input feature, update **all** of the following in order:

+ 1. `config/model_parameters.yaml` — add to `features.cardinality.drop_other_from` if applicable
+ 2. `src/schema.py` — add field to `SalaryInput`
+ 3. `src/preprocessing.py` — add to `_categorical_cols` (or numeric handling)
+ 4. `src/train.py` — add to `CATEGORICAL_FEATURES` and `usecols`
+ 5. `src/infer.py` — add validation block and DataFrame column
+ 6. `app.py` — add selectbox, default, sidebar entry, `SalaryInput` construction
+ 7. `tests/conftest.py` — add to `sample_salary_input` fixture
+ 8. `tests/test_schema.py` — assert field, add missing-field test
+ 9. `tests/test_infer.py` — add invalid-value test
+ 10. `tests/test_feature_impact.py` — add to all `base_input` dicts, add impact test
+ 11. `tests/test_preprocessing.py` — add column to all `pd.DataFrame(...)` fixtures
+ 12. `tests/test_train.py` — add column to `_make_salary_df` and all test DataFrames
+ 13. `README.md` — required columns, valid categories list, code example
+ 14. `example_inference.py` — add to all `SalaryInput` calls
+ 15. Retrain: `uv run python -m src.train`
+
+ ## Versioning

+ Follows [Semantic Versioning](https://semver.org/):

+ - **MAJOR** — new required input field, incompatible model artifact, renamed API
+ - **MINOR** — new optional field, new supported country, new Makefile target
+ - **PATCH** — bug fix, model retrain with same schema, config tuning

+ Current version: `2.0.0` (added `OrgSize` required field).

+ Update `pyproject.toml` before tagging:

+ ```bash
+ git tag v2.0.0
+ git push origin v2.0.0
  ```

+ ## Code Style

+ - Line length: **79 characters** (enforced by ruff `E501`)
+ - Formatter: `ruff format` (run via `make format`)
+ - Linter: `ruff check` (run via `make lint`)
+ - No `# noqa` suppressions — fix the code instead

+ ## Common Debugging
+
+ ```bash
+ # Check data columns
+ head -1 data/survey_results_public.csv | tr ',' '\n'

+ # Verify model exists
+ ls -lh models/model.pkl

+ # Check valid categories after training
+ cat config/valid_categories.yaml
+
+ # Run a single test file
+ uv run pytest tests/test_infer.py -v
+
+ # Run pre-commit manually
+ uv run pre-commit run --all-files
+ ```
+
+ ## Troubleshooting
+
+ | Symptom | Fix |
+ | ------- | --- |
+ | `FileNotFoundError: model.pkl` | Run `uv run python -m src.train` |
+ | `FileNotFoundError: valid_categories.yaml` | Same — generated by training |
+ | `ValidationError` on `SalaryInput` | Check all 9 fields are present and non-negative numerics |
+ | `ValueError: Invalid ...` at inference | Value not in `config/valid_categories.yaml`; retrain or use a listed value |
+ | `E501` ruff errors | Lines > 79 chars — split strings, use variables, or wrap lists |
+ | Tests fail after adding a feature | Check the "Updating Features" checklist above |

  ## Additional Resources
+
+ - [README.md](README.md) — User-facing documentation and HuggingFace Space config
+ - [Stack Overflow Survey](https://insights.stackoverflow.com/survey) — Data source
+ - [semver.org](https://semver.org/) — Versioning reference
README.md CHANGED
@@ -14,14 +14,14 @@ license: apache-2.0

  # Developer Salary Prediction

- A minimal, local-first ML application that predicts developer salaries using Stack Overflow Developer Survey data. Built with Python, scikit-learn, Pydantic, and Streamlit.

  ## Features

  - 🎯 XGBoost (gradient boosting) model for salary prediction
- - ✅ Input validation with Pydantic
  - 🌐 Interactive web UI with Streamlit
- - 📊 Trained on Stack Overflow Developer Survey data
  - 🔧 Easy setup with `uv` package manager

  ## Quick Start
@@ -37,14 +37,15 @@ uv sync
  Download the Stack Overflow Developer Survey CSV file:

  1. Visit: https://insights.stackoverflow.com/survey
- 2. Download the latest survey results (2024 or 2025)
  3. Extract the `survey_results_public.csv` file
  4. Place it in the `data/` directory:
- ```
  data/survey_results_public.csv
  ```

- **Required columns:** `Country`, `YearsCode`, `WorkExp`, `EdLevel`, `DevType`, `Industry`, `Age`, `ICorPM`, `ConvertedCompYearly`

  ### 3. Train the Model

@@ -53,11 +54,14 @@ uv run python -m src.train
  ```

  This will:
  - Load configuration from `config/model_parameters.yaml`
- - Load and preprocess the survey data (with cardinality reduction)
- - Train an XGBoost model with early stopping
- - Save the model to `models/model.pkl`
- - Generate `config/valid_categories.yaml` with valid country, education, developer type, industry, age, and IC/PM values

  ### 4. Run the Streamlit App

@@ -67,11 +71,68 @@ uv run streamlit run app.py
  ```

  The app will open in your browser at `http://localhost:8501`

  ## Usage

  ### Web Interface

  Launch the Streamlit app and enter:
  - **Country**: Developer's country
  - **Years of Coding (Total)**: Total years coding including education
  - **Years of Professional Work Experience**: Years of professional work experience
@@ -80,18 +141,17 @@ Launch the Streamlit app and enter:
  - **Industry**: Industry the developer works in
  - **Age**: Developer's age range
  - **IC or PM**: Individual contributor or people manager

- Click "Predict Salary" to see the estimated annual salary.

  ### Programmatic Usage

- **Quick example:**
-
  ```python
  from src.schema import SalaryInput
  from src.infer import predict_salary

- # Create input
  input_data = SalaryInput(
      country="United States of America",
      years_code=5.0,
@@ -100,10 +160,10 @@ input_data = SalaryInput(
      dev_type="Developer, full-stack",
      industry="Software Development",
      age="25-34 years old",
-     ic_or_pm="Individual contributor"
  )

- # Get prediction
  salary = predict_salary(input_data)
  print(f"Estimated salary: ${salary:,.0f}")
  ```
@@ -114,186 +174,365 @@ print(f"Estimated salary: ${salary:,.0f}")
  uv run python example_inference.py
  ```

- This will show predictions for multiple sample scenarios (junior, mid-level, senior developers, different countries).

- ## Input Validation

- The model validates inputs against actual training data categories:

- - **Valid Countries**: Only countries from `config/valid_categories.yaml` (~21 countries)
- - **Valid Education Levels**: Only education levels from training data (~9 levels)
- - **Valid Developer Types**: Only developer types from training data (~20 types)
- - **Valid Industries**: Only industries from training data (~15 industries)
- - **Valid Age Ranges**: Only age ranges from training data (~7 ranges)
- - **Valid IC/PM Values**: Only IC/PM values from training data (~3 values)

- The Streamlit app uses dropdown menus with only valid options. If you use the programmatic API with invalid values, you'll get a helpful error message pointing to the valid categories file.

- **Example validation:**
  ```python
  from src.infer import predict_salary
  from src.schema import SalaryInput

- # This will raise ValueError - Japan not in training data after cardinality reduction
- invalid_input = SalaryInput(
-     country="Japan",  # Invalid!
-     years_code=5.0,
-     work_exp=3.0,
-     education_level="Bachelor's degree (B.A., B.S., B.Eng., etc.)",
-     dev_type="Developer, back-end",
-     industry="Software Development",
-     age="25-34 years old",
-     ic_or_pm="Individual contributor"
- )
  ```

  **View valid categories:**
  ```bash
  cat config/valid_categories.yaml
  ```

- ## Configuration

- Model parameters are centralized in [config/model_parameters.yaml](config/model_parameters.yaml). You can customize:

- - **Data Processing**: Salary thresholds, percentile bounds, train/test split ratio
- - **Feature Engineering**: Cardinality reduction settings (max categories, min frequency)
- - **Model Hyperparameters**: Learning rate, tree depth, early stopping, etc.
- - **Training Settings**: Verbosity, model save path

- **To modify parameters:**

  ```bash
- # Edit the config file
- nano config/model_parameters.yaml

- # Then retrain the model
- uv run python -m src.train
  ```

  **Example parameter changes:**
  ```yaml
  # Increase model complexity
  model:
-   max_depth: 8          # Default: 6
    n_estimators: 10000   # Default: 5000

  # Keep more categories
  features:
    cardinality:
-     max_categories: 30  # Default: 20
-     min_frequency: 100  # Default: 50
  ```

  ## Project Structure

- ```
  .
  ├── config/
- │   ├── model_parameters.yaml      # Model configuration
- │   └── valid_categories.yaml      # Valid input categories (generated)
  ├── data/
  │   └── survey_results_public.csv  # Stack Overflow survey data (download required)
  ├── models/
- │   └── model.pkl                  # Trained model (generated)
  ├── src/
- │   ├── __init__.py                # Package initialization
- │   ├── schema.py                  # Pydantic models
- │   ├── preprocessing.py           # Feature engineering utilities
- │   ├── train.py                   # Training script
- │   └── infer.py                   # Inference utilities
  ├── app.py                         # Streamlit web app
- ├── example_inference.py           # Example inference script
  ├── pyproject.toml                 # Project dependencies
- └── README.md                      # This file
  ```

  ## Tech Stack

  - **Python 3.12+**
- - **uv** - Package manager
- - **pandas** - Data manipulation
- - **xgboost** - Gradient boosting model
- - **scikit-learn** - ML utilities (train/test split)
- - **pydantic** - Data validation
- - **streamlit** - Web UI

  ## Development

  For detailed development information, see [Claude.md](Claude.md).

  ### Re-training the Model

  If you want to use a different survey year or update the model:

  ```bash
- # Place new CSV in data/ directory
  uv run python -m src.train
  ```

  ### Running Tests

- **Quick one-liner test:**
  ```bash
- uv run python -c "from src.schema import SalaryInput; from src.infer import predict_salary; test = SalaryInput(country='United States of America', years_code=5.0, work_exp=3.0, education_level='Bachelor'\''s degree (B.A., B.S., B.Eng., etc.)', dev_type='Developer, full-stack', industry='Software Development', age='25-34 years old', ic_or_pm='Individual contributor'); print(f'Prediction: \${predict_salary(test):,.0f}')"
- ```

- **Or run the full example script:**
- ```bash
- uv run python example_inference.py
  ```

- ## Deployment

- ### Hugging Face Spaces

- This application is Docker-ready for deployment on Hugging Face Spaces:

- **1. Build the Docker image:**
  ```bash
- docker build -t developer-salary-predictor .
  ```

- **2. Test locally:**
  ```bash
- docker run -p 8501:8501 developer-salary-predictor
  ```

- Then visit `http://localhost:8501`

- **3. Deploy to Hugging Face:**

- 1. Create a new Space on [Hugging Face](https://huggingface.co/new-space)
- 2. Select "Docker" as the SDK
- 3. Clone your Space repository
- 4. Copy these files to your Space:

- ```text
- Dockerfile
- requirements.txt
- app.py
- src/
- config/
- models/
- ```

- 5. Push to your Space:
  ```bash
- git add .
- git commit -m "Initial deployment"
- git push
  ```

- **Note:** The pre-trained model (`models/model.pkl`) and configuration (`config/valid_categories.yaml`) are included in the Docker image. If you want to use a different model, retrain locally first, then rebuild the Docker image.

- ### Alternative: Local Deployment

  **Using uv (recommended for development):**
  ```bash
  uv run streamlit run app.py
  ```

  **Using pip:**
  ```bash
  pip install -r requirements.txt
  streamlit run app.py
@@ -302,24 +541,34 @@ streamlit run app.py
  ## Troubleshooting

  ### "Model file not found"
  - Run `uv run python -m src.train` first to generate the model

  ### "Data file not found"
  - Download the Stack Overflow survey CSV and place it in `data/`

  ### "Configuration file not found"
  - The `config/model_parameters.yaml` file should exist in the project root
  - Check that you're running commands from the project root directory

  ### Dependencies issues
  - Run `uv sync` to ensure all packages are installed

  ## Design Principles

- - **Simplicity**: Under 200 lines of code total
- - **Clarity**: Easy to understand and modify
- - **Local-first**: No cloud dependencies
- - **Hackable**: Plain Python, no complex frameworks

  ## License
# Developer Salary Prediction

+ A minimal, local-first ML application that predicts developer salaries using Stack Overflow Developer Survey data. Built with Python, XGBoost, Pydantic, and Streamlit.

## Features

- 🎯 XGBoost (gradient boosting) model for salary prediction
+ - ✅ Input validation with Pydantic (schema) and runtime guardrails (valid categories)
- 🌐 Interactive web UI with Streamlit
+ - 📊 Trained on Stack Overflow 2025 Developer Survey data
- 🔧 Easy setup with `uv` package manager

## Quick Start
 
Download the Stack Overflow Developer Survey CSV file:

1. Visit: https://insights.stackoverflow.com/survey
+ 2. Download the latest survey results (2025)
3. Extract the `survey_results_public.csv` file
4. Place it in the `data/` directory:
+
+ ```text
data/survey_results_public.csv
```

+ **Required columns:** `Country`, `YearsCode`, `WorkExp`, `EdLevel`, `DevType`, `Industry`, `Age`, `ICorPM`, `OrgSize`, `ConvertedCompYearly`
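Before training, it is worth confirming a freshly downloaded CSV actually carries these columns. A minimal stdlib sketch (the real pipeline reads the file with pandas; `missing_columns` is an illustrative helper, not part of the repo):

```python
import csv
import io

# Columns the training pipeline expects (listed in the README above)
REQUIRED = {
    "Country", "YearsCode", "WorkExp", "EdLevel", "DevType",
    "Industry", "Age", "ICorPM", "OrgSize", "ConvertedCompYearly",
}

def missing_columns(csv_text: str) -> set[str]:
    """Return the required columns absent from the CSV's header row."""
    header = set(next(csv.reader(io.StringIO(csv_text))))
    return REQUIRED - header

# Example: a header that forgot WorkExp
sample = "Country,YearsCode,EdLevel,DevType,Industry,Age,ICorPM,OrgSize,ConvertedCompYearly\n"
```

For a real file, pass `open(path).read()` (or just its first line) instead of `sample`.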
### 3. Train the Model

```bash
uv run python -m src.train
```

This will:
+
- Load configuration from `config/model_parameters.yaml`
+ - Filter salaries and reduce cardinality of categorical features
+ - Run 5-fold cross-validation and report mean MAPE per fold
+ - Train a final XGBoost model on the full dataset with early stopping
+ - Save the model artifact to `models/model.pkl`
+ - Generate `config/valid_categories.yaml` — valid input values for runtime guardrails
+ - Generate `config/currency_rates.yaml` — per-country median currency conversion rates
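The last step above boils down to "group conversion rates by country, take the median". A stdlib sketch of that idea (the row structure here is hypothetical; the real code in `src/train.py` derives rates from the survey's compensation columns):

```python
from collections import defaultdict
from statistics import median

def median_rates(rows: list[dict]) -> dict[str, float]:
    """Per-country median of observed currency conversion rates."""
    by_country: dict[str, list[float]] = defaultdict(list)
    for row in rows:
        by_country[row["Country"]].append(row["rate"])
    # Median is robust to the occasional mis-reported rate
    return {country: median(r) for country, r in by_country.items()}

rates = median_rates([
    {"Country": "Germany", "rate": 0.92},
    {"Country": "Germany", "rate": 0.94},
    {"Country": "Germany", "rate": 0.90},
])
# median of [0.90, 0.92, 0.94] -> 0.92
```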
### 4. Run the Streamlit App

```bash
uv run streamlit run app.py
```

The app will open in your browser at `http://localhost:8501`

+ ## Development Cycle
+
+ The full development workflow from data to deployment:
+
+ ```text
+ data/ ──► (optional) tune ──► train ──► test ──► commit ──► CI passes ──► deploy
+ ```
+
+ ### Step-by-step
+
+ #### 1. (Optional) Tune hyperparameters
+
+ Run Optuna to search for optimal XGBoost hyperparameters. The search space is
+ defined in `config/optuna_config.yaml`. Best parameters are written directly
+ back into `config/model_parameters.yaml`.
+
+ ```bash
+ make tune
+ # or with a custom number of trials:
+ uv run python -m src.tune --n-trials 50
+ ```
+
+ #### 2. Train the model
+
+ ```bash
+ uv run python -m src.train
+ ```
+
+ #### 3. Check code quality (lint + test + complexity + security)
+
+ ```bash
+ make check
+ ```
+
+ This runs all quality gates in sequence:
+
+ | Target | Tool | What it checks |
+ | ------ | ---- | -------------- |
+ | `make lint` | ruff | Style and linting errors |
+ | `make format` | ruff | Auto-formats code |
+ | `make test` | pytest | Unit and integration tests |
+ | `make coverage` | pytest-cov | Test coverage report |
+ | `make complexity` | radon CC | Cyclomatic complexity |
+ | `make maintainability` | radon MI | Maintainability index |
+ | `make audit` | pip-audit | Dependency vulnerability scan |
+ | `make security` | bandit | Static security analysis |
+
+ `make check` runs lint, test, complexity, maintainability, audit, and security together.
+ `make all` is an alias for `make check`.
+
+ #### 4. Run all pre-commit checks manually
+
+ ```bash
+ uv run pre-commit run --all-files
+ ```
+
## Usage

### Web Interface

Launch the Streamlit app and enter:
+
- **Country**: Developer's country
- **Years of Coding (Total)**: Total years coding including education
- **Years of Professional Work Experience**: Years of professional work experience
- **Industry**: Industry the developer works in
- **Age**: Developer's age range
- **IC or PM**: Individual contributor or people manager
+ - **Organization Size**: Approximate number of employees at the developer's company

+ Click "Predict Salary" to see the estimated annual salary in USD plus a local
+ currency equivalent where available.

### Programmatic Usage

```python
from src.schema import SalaryInput
from src.infer import predict_salary

input_data = SalaryInput(
    country="United States of America",
    years_code=5.0,
    dev_type="Developer, full-stack",
    industry="Software Development",
    age="25-34 years old",
+     ic_or_pm="Individual contributor",
+     org_size="20 to 99 employees",
)

salary = predict_salary(input_data)
print(f"Estimated salary: ${salary:,.0f}")
```
 
uv run python example_inference.py
```

+ ## Input Validation and Guardrails
+
+ Validation is enforced at two layers:
+
+ ### Layer 1 — Pydantic schema (`src/schema.py`)
+
+ Checked at object construction time:

+ - All 9 fields are required
+ - `years_code` must be `>= 0`
+ - `work_exp` must be `>= 0`
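`SalaryInput` expresses these constraints with Pydantic `Field(..., ge=0)`; the same pattern, sketched with only the stdlib for illustration (`SalaryInputSketch` is a hypothetical stand-in, not the real model):

```python
from dataclasses import dataclass

@dataclass
class SalaryInputSketch:
    """Stdlib stand-in for the Pydantic model (illustration only)."""
    years_code: float
    work_exp: float

    def __post_init__(self) -> None:
        # Mirrors Pydantic's Field(..., ge=0): reject negative experience
        for name in ("years_code", "work_exp"):
            if getattr(self, name) < 0:
                raise ValueError(f"{name} must be >= 0")

ok = SalaryInputSketch(years_code=5.0, work_exp=3.0)  # constructs fine
```

With the real model, the same violation raises a `pydantic.ValidationError` instead of a plain `ValueError`.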
+ ### Layer 2 — Runtime category guardrails (`src/infer.py`)

+ Checked at inference time against `config/valid_categories.yaml`, which is
+ generated during training to reflect only categories that appeared frequently
+ enough in the training data (controlled by `features.cardinality.min_frequency`
+ in `config/model_parameters.yaml`):

+ - **Valid Countries** (~21) — low-frequency countries collapsed to `Other`, which is then dropped
+ - **Valid Education Levels** (~9)
+ - **Valid Developer Types** (~20) — `Other` dropped
+ - **Valid Industries** (~15) — `Other` dropped
+ - **Valid Age Ranges** (~7) — `Other` dropped
+ - **Valid IC/PM Values** (~3) — `Other` dropped
+ - **Valid Organization Sizes** (~8) — `Other` dropped
+
+ Passing an invalid value raises a `ValueError` with a message pointing to
+ `config/valid_categories.yaml`.
+
+ **Example:**

```python
from src.infer import predict_salary
from src.schema import SalaryInput

+ # Raises ValueError: "Invalid country: 'Japan'. Check config/valid_categories.yaml"
+ predict_salary(SalaryInput(country="Japan", ...))
```

**View valid categories:**
+
```bash
cat config/valid_categories.yaml
```

+ ### Model guardrails (`config/model_parameters.yaml`)

+ The `guardrails` section defines thresholds used during training evaluation:

+ ```yaml
+ guardrails:
+   max_mape_per_category: 100  # max acceptable MAPE per category (%)
+   max_abs_pct_diff: 100       # max acceptable absolute % difference
+ ```
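How such a threshold gets applied can be sketched in plain Python (the actual logic lives in `guardrail_evaluation.py`; the sample salaries below are made up):

```python
def mape(actual: list[float], predicted: list[float]) -> float:
    """Mean absolute percentage error, in percent."""
    return 100 * sum(abs(a - p) / a for a, p in zip(actual, predicted)) / len(actual)

MAX_MAPE_PER_CATEGORY = 100  # threshold from the guardrails config above

# Hypothetical (actual, predicted) salary pairs per category
by_category = {
    "Germany": ([80_000.0, 90_000.0], [88_000.0, 81_000.0]),
}

# Flag any category whose MAPE exceeds the configured ceiling
violations = {
    cat: mape(a, p)
    for cat, (a, p) in by_category.items()
    if mape(a, p) > MAX_MAPE_PER_CATEGORY
}
```

With the default threshold of 100%, only categories the model gets badly wrong on average would appear in `violations`.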
+
+ ## Testing

+ Tests live in `tests/` and cover all major modules:
+
+ | File | What it tests |
+ | ---- | ------------- |
+ | `test_schema.py` | Pydantic validation — required fields, `ge=0` constraints |
+ | `test_infer.py` | Inference pipeline — valid predictions, `ValueError` on invalid categories, currency lookup |
+ | `test_train.py` | Training helpers — salary filtering, cardinality reduction, valid category extraction, currency rate computation |
+ | `test_preprocessing.py` | Feature engineering — one-hot encoding, numeric transforms |
+ | `test_tune.py` | Optuna helpers — parameter sampling, objective function construction, best-param saving |
+ | `test_feature_impact.py` | Model sanity — changing each input feature (country, education, dev type, etc.) produces a distinct prediction |
+
+ Run all tests:

```bash
+ make test
+ ```

+ Run with coverage:
+
+ ```bash
+ make coverage
```

+ ## Configuration
+
+ All runtime parameters are centralised in two YAML files:
+
+ ### `config/model_parameters.yaml`
+
+ Controls data processing, feature engineering, model hyperparameters, training
+ settings, and guardrail thresholds. You can customise:
+
+ - **Data Processing**: Salary thresholds, percentile bounds, train/test split ratio
+ - **Feature Engineering**: Cardinality reduction settings (max categories, min frequency)
+ - **Model Hyperparameters**: Learning rate, tree depth, early stopping, etc.
+ - **Training Settings**: Verbosity, model save path
+ - **Guardrails**: MAPE thresholds for model evaluation
+
**Example parameter changes:**
+
```yaml
# Increase model complexity
model:
+   max_depth: 8          # Default: 3
  n_estimators: 10000   # Default: 5000

# Keep more categories
features:
  cardinality:
+     max_categories: 50  # Default: 30
+     min_frequency: 20   # Default: 50
```
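The cardinality settings above collapse rare category values into `Other` before encoding. The core idea, sketched in plain Python (the real implementation lives in `src/train.py`; `reduce_cardinality` here is illustrative):

```python
from collections import Counter

def reduce_cardinality(values, max_categories=30, min_frequency=50):
    """Keep frequent categories; map everything else to 'Other'."""
    counts = Counter(values)
    # At most `max_categories` survivors, each seen at least `min_frequency` times
    keep = {
        cat
        for cat, n in counts.most_common(max_categories)
        if n >= min_frequency
    }
    return [v if v in keep else "Other" for v in values]

vals = ["US"] * 60 + ["DE"] * 55 + ["XY"] * 3
reduced = reduce_cardinality(vals)
# "US" and "DE" survive; the 3 "XY" rows collapse to "Other"
```

This is why `config/valid_categories.yaml`, generated at train time, only lists the surviving categories.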

+ ### `config/optuna_config.yaml`
+
+ Controls the Optuna hyperparameter search — search space (type, bounds, log
+ scale), number of trials, CV folds, and fixed parameters that are not tuned
+ (e.g. `n_estimators`, `random_state`).
+
## Project Structure

+ ```text
.
+ ├── .github/
+ │   └── workflows/
+ │       └── ci.yml                # GitHub Actions CI (lint + test)
├── config/
+ │   ├── model_parameters.yaml     # Model configuration and guardrails
+ │   ├── optuna_config.yaml        # Optuna hyperparameter search space
+ │   ├── valid_categories.yaml     # Valid input categories (generated by training)
+ │   └── currency_rates.yaml       # Per-country currency rates (generated by training)
├── data/
│   └── survey_results_public.csv  # Stack Overflow survey data (download required)
├── models/
+ │   └── model.pkl                 # Trained model artifact (generated by training)
├── src/
+ │   ├── __init__.py
+ │   ├── schema.py                 # Pydantic input model
+ │   ├── preprocessing.py          # Feature engineering (one-hot encoding, scaling)
+ │   ├── train.py                  # Training pipeline
+ │   ├── tune.py                   # Optuna hyperparameter optimisation
+ │   └── infer.py                  # Inference with runtime guardrails
├── tests/
+ │   ├── conftest.py               # Shared pytest fixtures
+ │   ├── test_schema.py
+ │   ├── test_infer.py
+ │   ├── test_train.py
+ │   ├── test_preprocessing.py
+ │   ├── test_tune.py
+ │   └── test_feature_impact.py
├── app.py                         # Streamlit web app
+ ├── example_inference.py          # Inference usage examples
+ ├── Makefile                      # Developer workflow commands
+ ├── .pre-commit-config.yaml       # Pre-commit hooks
├── pyproject.toml                 # Project dependencies
+ └── README.md                     # This file (also Hugging Face Space config)
```

## Tech Stack

- **Python 3.12+**
+ - **uv** — Package manager
+ - **pandas** — Data manipulation
+ - **xgboost** — Gradient boosting model
+ - **scikit-learn** — Cross-validation and train/test split
+ - **optuna** — Hyperparameter optimisation
+ - **pydantic** — Input schema validation
+ - **streamlit** — Web UI
+ - **ruff** — Linting and formatting
+ - **radon** — Complexity and maintainability metrics
+ - **bandit** — Static security analysis
+ - **pip-audit** — Dependency vulnerability scanning

## Development

For detailed development information, see [Claude.md](Claude.md).
 
+ ### Code Quality
+
+ #### Pre-commit hooks
+
+ The project uses [pre-commit](https://pre-commit.com) to enforce code quality checks before each commit. Hooks are defined in [.pre-commit-config.yaml](.pre-commit-config.yaml) and run:
+
+ - **ruff format** — auto-formats Python files (`make format`)
+ - **ruff lint** — checks for linting errors (`make lint`)
+ - **Standard checks** — trailing whitespace, end-of-file newline, LF line endings, valid YAML/TOML/JSON, large files, merge conflict markers, stray debug statements
+
+ **Install hooks** (once, after cloning):
+
+ ```bash
+ uv run pre-commit install
+ ```
+
+ Hooks will then run automatically on every `git commit`. To run them manually against all files:
+
+ ```bash
+ uv run pre-commit run --all-files
+ ```
+
+ #### GitHub Actions CI
+
+ A CI workflow ([.github/workflows/ci.yml](.github/workflows/ci.yml)) runs automatically on every push to any branch. It:
+
+ 1. Sets up Python 3.12 and installs `uv`
+ 2. Installs all dependencies (`uv sync --all-extras`)
+ 3. Runs `make lint` — ruff linting
+ 4. Runs `make test` — full pytest suite
+
+ The workflow must pass before merging changes.
+
### Re-training the Model

If you want to use a different survey year or update the model:

```bash
+ # 1. Place new CSV in data/
+ # 2. (Optional) tune first
+ make tune
+ # 3. Retrain
uv run python -m src.train
```

### Running Tests

```bash
+ # Run all tests
+ make test
+
+ # Run with coverage report
+ make coverage
+
+ # Run a specific test file
+ uv run pytest tests/test_infer.py -v
```

+ ## Versioning

+ This project follows [Semantic Versioning](https://semver.org/) (`MAJOR.MINOR.PATCH`):
+
+ | Version bump | When to use | Examples |
+ | --- | --- | --- |
+ | **MAJOR** | Breaking changes to the public interface | New required input field, incompatible model artifact format, renamed API |
+ | **MINOR** | Backward-compatible new features | New optional input field, new supported country, new Makefile target, UI addition |
+ | **PATCH** | Backward-compatible fixes and improvements | Bug fixes, model retrain with same schema, config tuning, dependency updates |

+ **Pre-release suffixes** (for work in progress):
+
+ ```text
+ v1.0.0-alpha.1  # early development, unstable
+ v1.0.0-beta.1   # feature-complete, under testing
+ v1.0.0-rc.1     # release candidate, final validation
+ ```
+
+ Tags are applied on `main` after a successful CI run:

```bash
+ git tag v2.0.0
+ git push origin v2.0.0
+ ```
+
+ ## Branching Strategy
+
+ The project uses a **GitFlow-inspired** branching model:
+
+ ```text
+ main ◄──────────────────────────────────── hotfix/v2.0.1
+   ▲                                             │
+   │ merge + tag                                 │
+   │                                      (urgent fix)
+ develop ◄──── feature/add-currency-display
+         ◄──── feature/new-dev-types
+         ◄──── fix/invalid-category-message
+
+         └──► release/v2.1.0 ──► (final testing) ──► main + tag v2.1.0
  ```

+ ### Branches
+
+ | Branch | Purpose | Merges into |
+ | ------ | ------- | ----------- |
+ | `main` | Production-ready code, always deployable. Tagged on every release. | — |
+ | `develop` | Integration branch for completed features. Base for new work. | `main` via release branch |
+ | `feature/<name>` | New features or improvements (e.g. `feature/add-local-currency`) | `develop` |
+ | `fix/<name>` | Non-urgent bug fixes (e.g. `fix/guardrail-error-message`) | `develop` |
+ | `release/v<semver>` | Release preparation — version bump, changelog, final QA | `main` and back to `develop` |
+ | `hotfix/v<semver>` | Urgent production fixes (e.g. `hotfix/v2.0.1`) | `main` and back to `develop` |
+
+ ### Rules
+
+ - **`main`** is protected — no direct pushes; merge only via PR after CI passes
+ - **`develop`** is the default branch for day-to-day work
+ - Branch names use lowercase kebab-case: `feature/optuna-cv-splits`
+ - Every merge to `main` is tagged with a semver version
+ - Hotfixes branch off `main` directly and merge back to both `main` and `develop`
+
+ ### Typical workflow
+
```bash
+ # Start a new feature
+ git checkout develop
+ git pull origin develop
+ git checkout -b feature/add-local-currency
+
+ # ... work, commit, push ...
+ git push -u origin feature/add-local-currency
+
+ # Open a PR into develop, CI must pass before merging
+
+ # Prepare a release
+ git checkout -b release/v2.1.0 develop
+ # bump version in pyproject.toml, update changelog
+ git push -u origin release/v2.1.0
+ # Open PR into main, merge, tag
+
+ git tag v2.1.0
+ git push origin v2.1.0
```

+ ## Deployment
+
+ ### Hugging Face Spaces

+ The app is deployed on [Hugging Face Spaces](https://huggingface.co/spaces) using the Docker SDK. The Space configuration is embedded in the frontmatter at the top of this README, which Hugging Face reads automatically:

+ - **SDK**: Docker (runs the `Dockerfile` in the repo root)
+ - **Port**: 8501 (Streamlit default)
+ - **License**: Apache 2.0

+ To deploy your own copy:
+
+ 1. Create a new Space on [Hugging Face](https://huggingface.co/new-space) and select "Docker" as the SDK
+ 2. Push this repository to your Space:

```bash
+ git remote add space https://huggingface.co/spaces/<your-username>/<your-space-name>
+ git push space main
```

+ **Note:** The pre-trained model (`models/model.pkl`) and configuration (`config/valid_categories.yaml`, `config/currency_rates.yaml`) must be present before building the Docker image. Train locally first if needed.

+ ### Local Docker
+
+ **Build and run:**
+
+ ```bash
+ docker build -t developer-salary-predictor .
+ docker run -p 8501:8501 developer-salary-predictor
+ ```
+
+ Then visit `http://localhost:8501`
+
+ ### Local (without Docker)

**Using uv (recommended for development):**
+
```bash
uv run streamlit run app.py
```

**Using pip:**
+
```bash
pip install -r requirements.txt
streamlit run app.py
```
 
## Troubleshooting

### "Model file not found"
+
- Run `uv run python -m src.train` first to generate the model

+ ### "Valid categories file not found"
+
+ - Run `uv run python -m src.train` — training generates both `models/model.pkl`
+   and `config/valid_categories.yaml`
+
### "Data file not found"
+
- Download the Stack Overflow survey CSV and place it in `data/`

### "Configuration file not found"
+
- The `config/model_parameters.yaml` file should exist in the project root
- Check that you're running commands from the project root directory

### Dependencies issues
+
- Run `uv sync` to ensure all packages are installed

## Design Principles

+ - **Simplicity**: Minimal codebase, easy to read and modify
+ - **Separation of concerns**: Schema validation, preprocessing, training, and inference are distinct modules
+ - **Config-driven**: All tunable parameters in YAML — no magic numbers in code
+ - **Local-first**: No cloud dependencies for training or inference
+ - **Testable**: Every public function has unit tests; model sanity covered by feature-impact tests

## License
 
app.py CHANGED
@@ -33,6 +33,7 @@ with st.sidebar:
        - Industry
        - Age
        - Individual contributor or people manager
+         - Organization size
        """
    )
    st.info("💡 Tip: Results are estimates based on survey averages.")
@@ -45,6 +46,7 @@ with st.sidebar:
    st.write(f"**Industries:** {len(valid_categories['Industry'])} available")
    st.write(f"**Age Ranges:** {len(valid_categories['Age'])} available")
    st.write(f"**IC/PM Roles:** {len(valid_categories['ICorPM'])} available")
+     st.write(f"**Org Sizes:** {len(valid_categories['OrgSize'])} available")
    st.caption("Only values from the training data are shown in the dropdowns.")

# Main input form
@@ -59,6 +61,7 @@ valid_dev_types = valid_categories["DevType"]
valid_industries = valid_categories["Industry"]
valid_ages = valid_categories["Age"]
valid_icorpm = valid_categories["ICorPM"]
+ valid_org_sizes = valid_categories["OrgSize"]

# Set default values (if available)
default_country = (
@@ -87,6 +90,11 @@ default_icorpm = (
    if "Individual contributor" in valid_icorpm
    else valid_icorpm[0]
)
+ default_org_size = (
+     "20 to 99 employees"
+     if "20 to 99 employees" in valid_org_sizes
+     else valid_org_sizes[0]
+ )

with col1:
    country = st.selectbox(
@@ -150,6 +158,13 @@ ic_or_pm = st.selectbox(
    help="Are you an individual contributor or people manager?",
)

+ org_size = st.selectbox(
+     "Organization Size",
+     options=valid_org_sizes,
+     index=valid_org_sizes.index(default_org_size),
+     help="Approximate number of employees at the developer's company",
+ )
+
# Prediction button
if st.button("🔮 Predict Salary", type="primary", use_container_width=True):
    try:
@@ -163,6 +178,7 @@ if st.button("🔮 Predict Salary", type="primary", use_container_width=True):
            industry=industry,
            age=age,
            ic_or_pm=ic_or_pm,
+             org_size=org_size,
        )

        # Make prediction
config/model_parameters.yaml CHANGED
@@ -16,6 +16,7 @@ features:
    - Industry
    - Age
    - ICorPM
+     - OrgSize
  encoding:
    drop_first: true
model:
config/valid_categories.yaml CHANGED
@@ -93,3 +93,13 @@ Age:
ICorPM:
- Individual contributor
- People manager
+ OrgSize:
+ - 1,000 to 4,999 employees
+ - 10,000 or more employees
+ - 100 to 499 employees
+ - 20 to 99 employees
+ - 5,000 to 9,999 employees
+ - 500 to 999 employees
+ - I don't know
+ - Just me - I am a freelancer, sole proprietor, etc.
+ - Less than 20 employees
example_inference.py CHANGED
@@ -24,6 +24,7 @@ def main():
        industry="Software Development",
        age="25-34 years old",
        ic_or_pm="Individual contributor",
+         org_size="20 to 99 employees",
    )

    print(f"Country: {input_data_1.country}")
@@ -34,6 +35,7 @@ def main():
    print(f"Industry: {input_data_1.industry}")
    print(f"Age: {input_data_1.age}")
    print(f"IC or PM: {input_data_1.ic_or_pm}")
+     print(f"Organization Size: {input_data_1.org_size}")

    salary_1 = predict_salary(input_data_1)
    print(f"💰 Predicted Salary: ${salary_1:,.2f} USD/year")
@@ -51,6 +53,7 @@ def main():
        industry="Fintech",
        age="18-24 years old",
        ic_or_pm="Individual contributor",
+         org_size="20 to 99 employees",
    )

    print(f"Country: {input_data_2.country}")
@@ -61,6 +64,7 @@ def main():
    print(f"Industry: {input_data_2.industry}")
    print(f"Age: {input_data_2.age}")
    print(f"IC or PM: {input_data_2.ic_or_pm}")
+     print(f"Organization Size: {input_data_2.org_size}")

    salary_2 = predict_salary(input_data_2)
    print(f"💰 Predicted Salary: ${salary_2:,.2f} USD/year")
@@ -78,6 +82,7 @@ def main():
        industry="Banking/Financial Services",
        age="35-44 years old",
        ic_or_pm="People manager",
+         org_size="1,000 to 4,999 employees",
    )

    print(f"Country: {input_data_3.country}")
@@ -88,6 +93,7 @@ def main():
    print(f"Industry: {input_data_3.industry}")
    print(f"Age: {input_data_3.age}")
    print(f"IC or PM: {input_data_3.ic_or_pm}")
+     print(f"Organization Size: {input_data_3.org_size}")

    salary_3 = predict_salary(input_data_3)
    print(f"💰 Predicted Salary: ${salary_3:,.2f} USD/year")
@@ -105,6 +111,7 @@ def main():
        industry="Manufacturing",
        age="25-34 years old",
        ic_or_pm="Individual contributor",
+         org_size="100 to 499 employees",
    )

    print(f"Country: {input_data_4.country}")
@@ -115,6 +122,7 @@ def main():
    print(f"Industry: {input_data_4.industry}")
    print(f"Age: {input_data_4.age}")
    print(f"IC or PM: {input_data_4.ic_or_pm}")
+     print(f"Organization Size: {input_data_4.org_size}")

    salary_4 = predict_salary(input_data_4)
    print(f"💰 Predicted Salary: ${salary_4:,.2f} USD/year")
guardrail_evaluation.py CHANGED
@@ -172,12 +172,8 @@ def compute_category_metrics(
def format_table(metrics_df: pd.DataFrame) -> str:
    """Format metrics DataFrame as a markdown table."""
    lines = []
- header = (
-     "| Category | Count | MAPE (%) | Mean Actual ($) | Mean Predicted ($) | Abs % Diff |"
- )
- sep = (
-     "|----------|------:|---------:|----------------:|-------------------:|-----------:|"
- )
+ header = "| Category | Count | MAPE (%) | Mean Actual ($) | Mean Predicted ($) | Abs % Diff |"
+ sep = "|----------|------:|---------:|----------------:|-------------------:|-----------:|"
    lines.append(header)
    lines.append(sep)
models/model.pkl CHANGED
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
- oid sha256:7a22e7a728aeb84f766e9acbef698afe0b4733a3385eed44d1663dc771d68be2
- size 1851836
+ oid sha256:295c1996202ba1a93f502705986ac96ff3e802d4009479a22d82d52b0b5e7f42
+ size 1830437
pyproject.toml CHANGED
@@ -1,6 +1,6 @@
[project]
name = "developer-salary-prediction"
- version = "0.1.0"
+ version = "2.0.0"
description = "Simple ML app for predicting developer salaries using Stack Overflow survey data"
readme = "README.md"
requires-python = ">=3.12"
@@ -17,6 +17,7 @@ dependencies = [
    "radon>=6.0.1",
    "pip-audit>=2.10.0",
    "bandit>=1.9.3",
+     "pre-commit>=4.5.1",
]

[project.optional-dependencies]
src/infer.py CHANGED
@@ -113,6 +113,13 @@ def predict_salary(data: SalaryInput) -> float:
            f"Check config/valid_categories.yaml for all valid values."
        )

+     if data.org_size not in valid_categories["OrgSize"]:
+         raise ValueError(
+             f"Invalid organization size: '{data.org_size}'. "
+             f"Must be one of {len(valid_categories['OrgSize'])} valid sizes. "
+             f"Check config/valid_categories.yaml for all valid values."
+         )
+
    # Create a DataFrame with the input data
    input_df = pd.DataFrame(
        {
@@ -124,6 +131,7 @@ def predict_salary(data: SalaryInput) -> float:
            "Industry": [data.industry],
            "Age": [data.age],
            "ICorPM": [data.ic_or_pm],
+             "OrgSize": [data.org_size],
        }
    )
src/preprocessing.py CHANGED
@@ -78,7 +78,8 @@ def prepare_features(df: pd.DataFrame) -> pd.DataFrame:
     during training and inference, preventing data leakage and inconsistencies.

     Args:
-        df: DataFrame with columns: Country, YearsCode, WorkExp, EdLevel, DevType, Industry, Age, ICorPM
+        df: DataFrame with columns: Country, YearsCode, WorkExp, EdLevel,
+            DevType, Industry, Age, ICorPM, OrgSize.
         NOTE: During training, cardinality reduction should be applied to df
         BEFORE calling this function. During inference, valid_categories.yaml
         ensures only valid (already-reduced) categories are used.
@@ -98,14 +99,23 @@ def prepare_features(df: pd.DataFrame) -> pd.DataFrame:

     # Normalize Unicode apostrophes to regular apostrophes for consistency
     # This handles cases where data has \u2019 (') instead of '
-    for col in ["Country", "EdLevel", "DevType", "Industry", "Age", "ICorPM"]:
+    _categorical_cols = [
+        "Country",
+        "EdLevel",
+        "DevType",
+        "Industry",
+        "Age",
+        "ICorPM",
+        "OrgSize",
+    ]
+    for col in _categorical_cols:
         if col in df_processed.columns:
             df_processed[col] = df_processed[col].str.replace(
                 "\u2019", "'", regex=False
             )

     # Normalize "Other" category variants (e.g. "Other (please specify):" -> "Other")
-    for col in ["Country", "EdLevel", "DevType", "Industry", "Age", "ICorPM"]:
+    for col in _categorical_cols:
         if col in df_processed.columns:
             df_processed[col] = normalize_other_categories(df_processed[col])

@@ -125,6 +135,7 @@ def prepare_features(df: pd.DataFrame) -> pd.DataFrame:
     df_processed["Industry"] = df_processed["Industry"].fillna("Unknown")
     df_processed["Age"] = df_processed["Age"].fillna("Unknown")
     df_processed["ICorPM"] = df_processed["ICorPM"].fillna("Unknown")
+    df_processed["OrgSize"] = df_processed["OrgSize"].fillna("Unknown")

     # NOTE: Cardinality reduction is NOT applied here
     # It should be applied during training BEFORE calling this function
@@ -140,6 +151,7 @@ def prepare_features(df: pd.DataFrame) -> pd.DataFrame:
         "Industry",
         "Age",
         "ICorPM",
+        "OrgSize",
     ]
     df_features = df_processed[feature_cols]

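The preprocessing hunks above thread `OrgSize` through the same cleanup as the other categorical columns: normalize the Unicode right single quote (`\u2019`) to `'` and fill missing values with `"Unknown"`. A minimal stdlib sketch of that cleanup — `clean_value` is a hypothetical stand-in for the pandas `str.replace`/`fillna` operations the real `prepare_features` uses:

```python
# The full categorical column list after this commit, including OrgSize.
CATEGORICAL_COLS = [
    "Country", "EdLevel", "DevType", "Industry", "Age", "ICorPM", "OrgSize",
]

def clean_value(value):
    # Missing values become the "Unknown" sentinel, as in the fillna calls.
    if value is None:
        return "Unknown"
    # Normalize the Unicode right single quote to a plain apostrophe.
    return value.replace("\u2019", "'")

row = {"Country": "C\u00f4te d\u2019Ivoire", "OrgSize": None}
cleaned = {col: clean_value(row[col]) for col in row}
```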
src/schema.py CHANGED
@@ -18,6 +18,7 @@ class SalaryInput(BaseModel):
                 "industry": "Software Development",
                 "age": "25-34 years old",
                 "ic_or_pm": "Individual contributor",
+                "org_size": "20 to 99 employees",
             }
         ]
     }
@@ -39,3 +40,6 @@ class SalaryInput(BaseModel):
     industry: str = Field(..., description="Industry the developer works in")
     age: str = Field(..., description="Developer's age range")
     ic_or_pm: str = Field(..., description="Individual contributor or people manager")
+    org_size: str = Field(
+        ..., description="Size of the organisation the developer works for"
+    )
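Because `org_size` is declared with `Field(...)`, it is a required field: omitting it fails at construction time. A rough stdlib analogue using dataclasses — `SalaryInputSketch` is a hypothetical stand-in for the pydantic model, and raises `TypeError` where pydantic would raise `ValidationError`:

```python
from dataclasses import dataclass

@dataclass
class SalaryInputSketch:
    country: str
    org_size: str  # newly required field, mirroring the schema change above

ok = SalaryInputSketch(country="Poland", org_size="20 to 99 employees")

try:
    SalaryInputSketch(country="Poland")  # org_size omitted
    raised = False
except TypeError:
    raised = True
```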
src/train.py CHANGED
@@ -11,7 +11,15 @@ from sklearn.model_selection import KFold, train_test_split

 from src.preprocessing import prepare_features, reduce_cardinality

-CATEGORICAL_FEATURES = ["Country", "EdLevel", "DevType", "Industry", "Age", "ICorPM"]
+CATEGORICAL_FEATURES = [
+    "Country",
+    "EdLevel",
+    "DevType",
+    "Industry",
+    "Age",
+    "ICorPM",
+    "OrgSize",
+]


 def filter_salaries(df: pd.DataFrame, config: dict) -> pd.DataFrame:
@@ -160,6 +168,7 @@ def main():
         "Industry",
         "Age",
         "ICorPM",
+        "OrgSize",
         "Currency",
         "CompTotal",
         "ConvertedCompYearly",
src/tune.py CHANGED
@@ -36,15 +36,11 @@ def sample_params(trial: optuna.Trial, search_space: dict) -> dict:
             params[name] = trial.suggest_int(name, spec["low"], spec["high"])
         elif param_type == "float":
             log = spec.get("log", False)
-            params[name] = trial.suggest_float(
-                name, spec["low"], spec["high"], log=log
-            )
+            params[name] = trial.suggest_float(name, spec["low"], spec["high"], log=log)
     return params


-def build_objective(
-    X: pd.DataFrame, y: pd.Series, optuna_config: dict
-) -> callable:
+def build_objective(X: pd.DataFrame, y: pd.Series, optuna_config: dict) -> callable:
    """Build an Optuna objective function for XGBoost CV evaluation.

    Args:
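The `sample_params` function reformatted above dispatches on the search-space spec's `type` to the matching `trial.suggest_*` call. A self-contained sketch with a stand-in trial object instead of optuna — `FakeTrial` and the spec values are illustrative assumptions, not the project's real search space:

```python
import math
import random

class FakeTrial:
    """Stand-in for optuna.Trial, mimicking the two suggest methods used."""

    def suggest_int(self, name, low, high):
        return random.randint(low, high)

    def suggest_float(self, name, low, high, log=False):
        if log:
            # Sample uniformly in log space, as optuna does with log=True.
            return math.exp(random.uniform(math.log(low), math.log(high)))
        return random.uniform(low, high)

def sample_params(trial, search_space):
    # Same dispatch shape as src/tune.py: spec type selects the suggest call.
    params = {}
    for name, spec in search_space.items():
        if spec["type"] == "int":
            params[name] = trial.suggest_int(name, spec["low"], spec["high"])
        elif spec["type"] == "float":
            params[name] = trial.suggest_float(
                name, spec["low"], spec["high"], log=spec.get("log", False)
            )
    return params

space = {
    "max_depth": {"type": "int", "low": 3, "high": 10},
    "learning_rate": {"type": "float", "low": 1e-3, "high": 0.3, "log": True},
}
params = sample_params(FakeTrial(), space)
```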
tests/conftest.py CHANGED
@@ -18,6 +18,7 @@ def sample_salary_input():
         "industry": "Software Development",
         "age": "25-34 years old",
         "ic_or_pm": "Individual contributor",
+        "org_size": "20 to 99 employees",
     }

tests/test_feature_impact.py CHANGED
@@ -14,6 +14,7 @@ def test_years_experience_impact():
         "industry": "Software Development",
         "age": "25-34 years old",
         "ic_or_pm": "Individual contributor",
+        "org_size": "20 to 99 employees",
     }

     years_tests = [0, 2, 5, 10, 20]
@@ -37,6 +38,7 @@ def test_country_impact():
         "industry": "Software Development",
         "age": "25-34 years old",
         "ic_or_pm": "Individual contributor",
+        "org_size": "20 to 99 employees",
     }

     test_countries = [
@@ -71,12 +73,14 @@ def test_education_impact():
         "industry": "Software Development",
         "age": "25-34 years old",
         "ic_or_pm": "Individual contributor",
+        "org_size": "20 to 99 employees",
     }

     test_education = [
         e
         for e in [
-            "Secondary school (e.g. American high school, German Realschule or Gymnasium, etc.)",
+            "Secondary school (e.g. American high school, "
+            "German Realschule or Gymnasium, etc.)",
             "Some college/university study without earning a degree",
             "Associate degree (A.A., A.S., etc.)",
             "Bachelor's degree (B.A., B.S., B.Eng., etc.)",
@@ -106,6 +110,7 @@ def test_devtype_impact():
         "industry": "Software Development",
         "age": "25-34 years old",
         "ic_or_pm": "Individual contributor",
+        "org_size": "20 to 99 employees",
     }

     test_devtypes = [
@@ -141,6 +146,7 @@ def test_industry_impact():
         "dev_type": "Developer, full-stack",
         "age": "25-34 years old",
         "ic_or_pm": "Individual contributor",
+        "org_size": "20 to 99 employees",
     }

     test_industries = [
@@ -176,6 +182,7 @@ def test_age_impact():
         "dev_type": "Developer, full-stack",
         "industry": "Software Development",
         "ic_or_pm": "Individual contributor",
+        "org_size": "20 to 99 employees",
     }

     test_ages = [
@@ -210,6 +217,7 @@ def test_work_exp_impact():
         "industry": "Software Development",
         "age": "25-34 years old",
         "ic_or_pm": "Individual contributor",
+        "org_size": "20 to 99 employees",
     }

     work_exp_tests = [0, 1, 3, 5, 10, 20]
@@ -219,7 +227,8 @@ def test_work_exp_impact():
         predictions.append(predict_salary(input_data))

     assert len(set(predictions)) >= len(predictions) - 1, (
-        f"Expected at least {len(predictions) - 1} unique predictions, got {len(set(predictions))}"
+        f"Expected at least {len(predictions) - 1} unique predictions, "
+        f"got {len(set(predictions))}"
     )


@@ -233,6 +242,7 @@ def test_icorpm_impact():
         "dev_type": "Developer, full-stack",
         "industry": "Software Development",
         "age": "25-34 years old",
+        "org_size": "20 to 99 employees",
     }

     test_icorpm = [
@@ -251,6 +261,31 @@ def test_icorpm_impact():
     )


+def test_org_size_impact():
+    """Test that changing organization size changes prediction."""
+    base_input = {
+        "country": "United States of America",
+        "years_code": 5.0,
+        "work_exp": 3.0,
+        "education_level": "Bachelor's degree (B.A., B.S., B.Eng., etc.)",
+        "dev_type": "Developer, full-stack",
+        "industry": "Software Development",
+        "age": "25-34 years old",
+        "ic_or_pm": "Individual contributor",
+    }
+
+    test_org_sizes = valid_categories["OrgSize"][:5]
+
+    predictions = []
+    for org_size in test_org_sizes:
+        input_data = SalaryInput(**base_input, org_size=org_size)
+        predictions.append(predict_salary(input_data))
+
+    assert len(set(predictions)) == len(predictions), (
+        f"Expected {len(predictions)} unique predictions, got {len(set(predictions))}"
+    )
+
+
 def test_combined_features():
     """Test that combining different features produces expected variations."""
     test_cases = [
@@ -263,6 +298,7 @@ def test_combined_features():
             "Software Development",
             "18-24 years old",
             "Individual contributor",
+            "20 to 99 employees",
         ),
         (
             "Germany",
@@ -273,6 +309,7 @@ def test_combined_features():
             "Manufacturing",
             "25-34 years old",
             "Individual contributor",
+            "100 to 499 employees",
         ),
         (
             "United States of America",
@@ -283,6 +320,7 @@ def test_combined_features():
             "Fintech",
             "35-44 years old",
             "People manager",
+            "1,000 to 4,999 employees",
         ),
         (
             "Poland",
@@ -293,6 +331,7 @@ def test_combined_features():
             "Healthcare",
             "45-54 years old",
             "Individual contributor",
+            "20 to 99 employees",
         ),
         (
             "Brazil",
@@ -303,6 +342,7 @@ def test_combined_features():
             "Government",
             "25-34 years old",
             "Individual contributor",
+            "20 to 99 employees",
         ),
     ]

@@ -316,6 +356,7 @@ def test_combined_features():
         industry,
         age,
         icorpm,
+        org_size,
     ) in test_cases:
         if (
             country not in valid_categories["Country"]
@@ -324,6 +365,7 @@ def test_combined_features():
             or industry not in valid_categories["Industry"]
             or age not in valid_categories["Age"]
             or icorpm not in valid_categories["ICorPM"]
+            or org_size not in valid_categories["OrgSize"]
         ):
             continue

@@ -336,6 +378,7 @@ def test_combined_features():
             industry=industry,
             age=age,
             ic_or_pm=icorpm,
+            org_size=org_size,
         )
         predictions.append(predict_salary(input_data))

tests/test_infer.py CHANGED
@@ -55,6 +55,13 @@ def test_invalid_ic_or_pm(sample_salary_input):
         predict_salary(SalaryInput(**sample_salary_input))


+def test_invalid_org_size(sample_salary_input):
+    """Invalid organization size raises ValueError."""
+    sample_salary_input["org_size"] = "Megacorp 10M+"
+    with pytest.raises(ValueError, match="Invalid organization size"):
+        predict_salary(SalaryInput(**sample_salary_input))
+
+
 def test_get_local_currency_unknown_country():
     """get_local_currency returns None for unknown country."""
     result = get_local_currency("Narnia", 100000)
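`test_invalid_org_size` expects `predict_salary` to reject an out-of-vocabulary `OrgSize` with a `ValueError` whose message matches "Invalid organization size". A sketch of the kind of membership check that would satisfy it — `VALID_ORG_SIZES` and `check_org_size` are hypothetical stand-ins for `valid_categories["OrgSize"]` and the validation inside `predict_salary`:

```python
# Hypothetical subset of the valid category values loaded at inference time.
VALID_ORG_SIZES = {"20 to 99 employees", "100 to 499 employees"}

def check_org_size(org_size):
    # Reject any value not seen during training, with the message the
    # test's pytest.raises(match=...) pattern would match on.
    if org_size not in VALID_ORG_SIZES:
        raise ValueError(f"Invalid organization size: {org_size!r}")

try:
    check_org_size("Megacorp 10M+")
    raised = False
except ValueError:
    raised = True
```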
tests/test_preprocessing.py CHANGED
@@ -86,6 +86,7 @@ class TestPrepareFeatures:
                 "Industry": ["Software Development"],
                 "Age": ["25-34 years old"],
                 "ICorPM": ["Individual contributor"],
+                "OrgSize": ["20 to 99 employees"],
             }
         )
         result = prepare_features(df)
@@ -104,6 +105,7 @@ class TestPrepareFeatures:
                 "Industry": ["Software Development"],
                 "Age": ["25-34 years old"],
                 "ICorPM": ["Individual contributor"],
+                "OrgSize": ["20 to 99 employees"],
             }
         )
         result = prepare_features(df)
@@ -122,6 +124,7 @@ class TestPrepareFeatures:
                 "Industry": ["Software Development", "Healthcare"],
                 "Age": ["25-34 years old", "35-44 years old"],
                 "ICorPM": ["Individual contributor", "People manager"],
+                "OrgSize": ["20 to 99 employees", "100 to 499 employees"],
             }
         )
         result = prepare_features(df)
@@ -143,6 +146,7 @@ class TestPrepareFeatures:
                 "Industry": ["Software Development"],
                 "Age": ["25-34 years old"],
                 "ICorPM": ["Individual contributor"],
+                "OrgSize": ["20 to 99 employees"],
             }
         )
         result = prepare_features(df)
@@ -161,6 +165,7 @@ class TestPrepareFeatures:
                 "Industry": [None],
                 "Age": [None],
                 "ICorPM": [None],
+                "OrgSize": [None],
             }
         )
         result = prepare_features(df)
@@ -181,6 +186,7 @@ class TestPrepareFeatures:
                 "Industry": ["Software Development"],
                 "Age": ["25-34 years old"],
                 "ICorPM": ["Individual contributor"],
+                "OrgSize": ["20 to 99 employees"],
             }
         )
         original_country = df["Country"].iloc[0]
tests/test_schema.py CHANGED
@@ -17,6 +17,7 @@ def test_valid_input(sample_salary_input):
     assert result.industry == sample_salary_input["industry"]
     assert result.age == sample_salary_input["age"]
     assert result.ic_or_pm == sample_salary_input["ic_or_pm"]
+    assert result.org_size == sample_salary_input["org_size"]


 def test_negative_years_code(sample_salary_input):
@@ -44,6 +45,7 @@ def test_missing_country():
             industry="Software Development",
             age="25-34 years old",
             ic_or_pm="Individual contributor",
+            org_size="20 to 99 employees",
         )


@@ -58,6 +60,22 @@ def test_missing_education_level():
             industry="Software Development",
             age="25-34 years old",
             ic_or_pm="Individual contributor",
+            org_size="20 to 99 employees",
+        )
+
+
+def test_missing_org_size():
+    """Missing org_size raises ValidationError."""
+    with pytest.raises(ValidationError):
+        SalaryInput(
+            country="United States of America",
+            years_code=5.0,
+            work_exp=3.0,
+            education_level="Bachelor's degree (B.A., B.S., B.Eng., etc.)",
+            dev_type="Developer, full-stack",
+            industry="Software Development",
+            age="25-34 years old",
+            ic_or_pm="Individual contributor",
         )

tests/test_train.py CHANGED
@@ -34,6 +34,7 @@ def _make_salary_df(countries=None, salaries=None, n=100) -> pd.DataFrame:
             "Industry": ["Software Development"] * n,
             "Age": ["25-34 years old"] * n,
             "ICorPM": ["Individual contributor"] * n,
+            "OrgSize": ["20 to 99 employees"] * n,
             "Currency": ["USD United States Dollar"] * n,
             "CompTotal": salaries,
             "ConvertedCompYearly": salaries,
@@ -140,6 +141,7 @@ class TestDropOtherRows:
                 "Industry": ["SW", "SW", "SW"],
                 "Age": ["25-34", "25-34", "25-34"],
                 "ICorPM": ["IC", "IC", "IC"],
+                "OrgSize": ["Small", "Small", "Small"],
             }
         )
         config = {
@@ -164,6 +166,7 @@ class TestDropOtherRows:
                 "Industry": ["SW", "SW"],
                 "Age": ["25-34", "25-34"],
                 "ICorPM": ["IC", "IC"],
+                "OrgSize": ["Small", "Small"],
             }
         )
         config = {
@@ -187,6 +190,7 @@ class TestDropOtherRows:
                 "Industry": ["SW", "SW"],
                 "Age": ["25-34", "25-34"],
                 "ICorPM": ["IC", "IC"],
+                "OrgSize": ["Small", "Small"],
             }
         )
         config = {
@@ -214,15 +218,17 @@ class TestExtractValidCategories:
                 "Industry": ["SW", "Fin", "SW"],
                 "Age": ["25-34", "35-44", "25-34"],
                 "ICorPM": ["IC", "PM", "IC"],
+                "OrgSize": ["Small", "Large", "Small"],
             }
         )
         result = extract_valid_categories(df)
         assert result["Country"] == ["Germany", "USA"]
         assert result["EdLevel"] == ["BS", "MS"]
         assert result["ICorPM"] == ["IC", "PM"]
+        assert result["OrgSize"] == ["Large", "Small"]

     def test_all_categorical_features_present(self):
-        """All 6 categorical features are present as keys."""
+        """All 7 categorical features are present as keys."""
         df = pd.DataFrame(
             {
                 "Country": ["USA"],
@@ -231,6 +237,7 @@ class TestExtractValidCategories:
                 "Industry": ["SW"],
                 "Age": ["25-34"],
                 "ICorPM": ["IC"],
+                "OrgSize": ["Small"],
             }
         )
         result = extract_valid_categories(df)
@@ -241,6 +248,7 @@ class TestExtractValidCategories:
             "Industry",
             "Age",
             "ICorPM",
+            "OrgSize",
         }

     def test_excludes_nan_values(self):
@@ -253,6 +261,7 @@ class TestExtractValidCategories:
                 "Industry": ["SW", "SW"],
                 "Age": ["25-34", "25-34"],
                 "ICorPM": ["IC", "IC"],
+                "OrgSize": ["Small", "Small"],
             }
         )
         result = extract_valid_categories(df)
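The updated `TestExtractValidCategories` assertions imply that `extract_valid_categories` returns the sorted unique non-null values per column (so `["Small", "Large", "Small"]` becomes `["Large", "Small"]`). A stdlib sketch of that behaviour over a dict-of-lists "frame" — the real function operates on a pandas DataFrame, so this is an illustration of the contract, not the implementation:

```python
def extract_valid_categories(frame):
    # Sorted unique values per column, skipping missing entries.
    return {
        col: sorted({v for v in values if v is not None})
        for col, values in frame.items()
    }

frame = {
    "ICorPM": ["IC", "PM", "IC"],
    "OrgSize": ["Small", "Large", None],
}
result = extract_valid_categories(frame)
# result["OrgSize"] sorts alphabetically: ["Large", "Small"]
```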
tests/test_tune.py CHANGED
@@ -87,9 +87,7 @@ class TestSaveBestParams:
                 "max_depth": 6,
             },
         }
-        with tempfile.NamedTemporaryFile(
-            mode="w", suffix=".yaml", delete=False
-        ) as f:
+        with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
             yaml.dump(config, f)
             tmp_path = Path(f.name)

@@ -110,9 +108,7 @@ class TestSaveBestParams:
             "model": {"n_estimators": 5000, "max_depth": 6},
             "training": {"verbose": False},
         }
-        with tempfile.NamedTemporaryFile(
-            mode="w", suffix=".yaml", delete=False
-        ) as f:
+        with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
             yaml.dump(config, f)
             tmp_path = Path(f.name)

@@ -134,9 +130,7 @@ class TestSaveBestParams:
                 "n_jobs": -1,
             },
         }
-        with tempfile.NamedTemporaryFile(
-            mode="w", suffix=".yaml", delete=False
-        ) as f:
+        with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
             yaml.dump(config, f)
             tmp_path = Path(f.name)

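The reformatted tests above all use the same `NamedTemporaryFile` pattern: write a config with `delete=False`, keep the path, and read the file back after the `with` block. A stdlib sketch of that round trip — plain text stands in here for the `yaml.dump` call these tests use:

```python
import tempfile
from pathlib import Path

with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
    f.write("n_estimators: 5000\n")
    tmp_path = Path(f.name)

# delete=False keeps the file after the with block closes it, so it can
# still be read by path...
content = tmp_path.read_text()
tmp_path.unlink()  # ...but it must then be removed manually.
```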
uv.lock CHANGED
@@ -127,6 +127,15 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/e6/ad/3cc14f097111b4de0040c83a525973216457bbeeb63739ef1ed275c1c021/certifi-2026.1.4-py3-none-any.whl", hash = "sha256:9943707519e4add1115f44c2bc244f782c0249876bf51b6599fee1ffbedd685c", size = 152900, upload-time = "2026-01-04T02:42:40.15Z" },
 ]

+[[package]]
+name = "cfgv"
+version = "3.5.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/4e/b5/721b8799b04bf9afe054a3899c6cf4e880fcf8563cc71c15610242490a0c/cfgv-3.5.0.tar.gz", hash = "sha256:d5b1034354820651caa73ede66a6294d6e95c1b00acc5e9b098e917404669132", size = 7334, upload-time = "2025-11-19T20:55:51.612Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/db/3c/33bac158f8ab7f89b2e59426d5fe2e4f63f7ed25df84c036890172b412b5/cfgv-3.5.0-py2.py3-none-any.whl", hash = "sha256:a8dc6b26ad22ff227d2634a65cb388215ce6cc96bbcc5cfde7641ae87e8dacc0", size = 7445, upload-time = "2025-11-19T20:55:50.744Z" },
+]
+
 [[package]]
 name = "charset-normalizer"
 version = "3.4.4"
@@ -328,7 +337,7 @@ wheels = [

 [[package]]
 name = "developer-salary-prediction"
-version = "0.1.0"
+version = "1.0.0"
 source = { virtual = "." }
 dependencies = [
     { name = "bandit" },
@@ -336,6 +345,7 @@ dependencies = [
     { name = "optuna" },
     { name = "pandas" },
     { name = "pip-audit" },
+    { name = "pre-commit" },
     { name = "pydantic" },
     { name = "pyyaml" },
     { name = "radon" },
@@ -363,6 +373,7 @@ requires-dist = [
     { name = "pandas", specifier = ">=2.0.0" },
     { name = "pip-audit", specifier = ">=2.10.0" },
     { name = "pip-audit", marker = "extra == 'dev'", specifier = ">=2.7.0" },
+    { name = "pre-commit", specifier = ">=4.5.1" },
     { name = "pydantic", specifier = ">=2.0.0" },
     { name = "pytest", marker = "extra == 'dev'", specifier = ">=8.0.0" },
     { name = "pytest-cov", marker = "extra == 'dev'", specifier = ">=6.0.0" },
@@ -376,13 +387,22 @@ requires-dist = [
 ]
 provides-extras = ["dev"]

+[[package]]
+name = "distlib"
+version = "0.4.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/96/8e/709914eb2b5749865801041647dc7f4e6d00b549cfe88b65ca192995f07c/distlib-0.4.0.tar.gz", hash = "sha256:feec40075be03a04501a973d81f633735b4b69f98b05450592310c0f401a4e0d", size = 614605, upload-time = "2025-07-17T16:52:00.465Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/33/6b/e0547afaf41bf2c42e52430072fa5658766e3d65bd4b03a563d1b6336f57/distlib-0.4.0-py2.py3-none-any.whl", hash = "sha256:9659f7d87e46584a30b5780e43ac7a2143098441670ff0a49d5f9034c54a6c16", size = 469047, upload-time = "2025-07-17T16:51:58.613Z" },
+]
+
 [[package]]
 name = "filelock"
-version = "3.24.0"
+version = "3.24.3"
 source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/00/cd/fa3ab025a8f9772e8a9146d8fd8eef6d62649274d231ca84249f54a0de4a/filelock-3.24.0.tar.gz", hash = "sha256:aeeab479339ddf463a1cdd1f15a6e6894db976071e5883efc94d22ed5139044b", size = 37166, upload-time = "2026-02-14T16:05:28.723Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/73/92/a8e2479937ff39185d20dd6a851c1a63e55849e447a55e798cc2e1f49c65/filelock-3.24.3.tar.gz", hash = "sha256:011a5644dc937c22699943ebbfc46e969cdde3e171470a6e40b9533e5a72affa", size = 37935, upload-time = "2026-02-19T00:48:20.543Z" }
 wheels = [
-    { url = "https://files.pythonhosted.org/packages/d9/dd/d7e7f4f49180e8591c9e1281d15ecf8e7f25eb2c829771d9682f1f9fe0c8/filelock-3.24.0-py3-none-any.whl", hash = "sha256:eebebb403d78363ef7be8e236b63cc6760b0004c7464dceaba3fd0afbd637ced", size = 23977, upload-time = "2026-02-14T16:05:27.578Z" },
+    { url = "https://files.pythonhosted.org/packages/9c/0f/5d0c71a1aefeb08efff26272149e07ab922b64f46c63363756224bd6872e/filelock-3.24.3-py3-none-any.whl", hash = "sha256:426e9a4660391f7f8a810d71b0555bce9008b0a1cc342ab1f6947d37639e002d", size = 24331, upload-time = "2026-02-19T00:48:18.465Z" },
 ]

 [[package]]
@@ -448,6 +468,15 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/e1/2b/98c7f93e6db9977aaee07eb1e51ca63bd5f779b900d362791d3252e60558/greenlet-3.3.1-cp314-cp314t-win_amd64.whl", hash = "sha256:301860987846c24cb8964bdec0e31a96ad4a2a801b41b4ef40963c1b44f33451", size = 233181, upload-time = "2026-01-23T15:33:00.29Z" },
 ]

+[[package]]
+name = "identify"
+version = "2.6.16"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/5b/8d/e8b97e6bd3fb6fb271346f7981362f1e04d6a7463abd0de79e1fda17c067/identify-2.6.16.tar.gz", hash = "sha256:846857203b5511bbe94d5a352a48ef2359532bc8f6727b5544077a0dcfb24980", size = 99360, upload-time = "2026-01-12T18:58:58.201Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/b8/58/40fbbcefeda82364720eba5cf2270f98496bdfa19ea75b4cccae79c698e6/identify-2.6.16-py2.py3-none-any.whl", hash = "sha256:391ee4d77741d994189522896270b787aed8670389bfd60f326d677d64a6dfb0", size = 99202, upload-time = "2026-01-12T18:58:56.627Z" },
+]
+
 [[package]]
 name = "idna"
 version = "3.11"
@@ -687,6 +716,15 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/03/cc/7cb74758e6df95e0c4e1253f203b6dd7f348bf2f29cf89e9210a2416d535/narwhals-2.16.0-py3-none-any.whl", hash = "sha256:846f1fd7093ac69d63526e50732033e86c30ea0026a44d9b23991010c7d1485d", size = 443951, upload-time = "2026-02-02T10:30:58.635Z" },
 ]

+[[package]]
+name = "nodeenv"
+version = "1.10.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/24/bf/d1bda4f6168e0b2e9e5958945e01910052158313224ada5ce1fb2e1113b8/nodeenv-1.10.0.tar.gz", hash = "sha256:996c191ad80897d076bdfba80a41994c2b47c68e224c542b48feba42ba00f8bb", size = 55611, upload-time = "2025-12-20T14:08:54.006Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/88/b2/d0896bdcdc8d28a7fc5717c305f1a861c26e18c05047949fb371034d98bd/nodeenv-1.10.0-py2.py3-none-any.whl", hash = "sha256:5bb13e3eed2923615535339b3c620e76779af4cb4c6a90deccc9e36b274d3827", size = 23438, upload-time = "2025-12-20T14:08:52.782Z" },
+]
+
 [[package]]
 name = "numpy"
 version = "2.4.2"
@@ -982,6 +1020,22 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" },
 ]

+[[package]]
+name = "pre-commit"
+version = "4.5.1"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "cfgv" },

 [[package]]
 name = "protobuf"
 version = "6.33.5"
@@ -1790,6 +1844,20 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/39/08/aaaad47bc4e9dc8c725e68f9d04865dbcb2052843ff09c97b08904852d84/urllib3-2.6.3-py3-none-any.whl", hash = "sha256:bf272323e553dfb2e87d9bfd225ca7b0f467b919d7bbd355436d3fd37cb0acd4", size = 131584, upload-time = "2026-01-07T16:24:42.685Z" },
 ]

 [[package]]
 name = "watchdog"
 version = "6.0.0"
1029
+ { name = "identify" },
1030
+ { name = "nodeenv" },
1031
+ { name = "pyyaml" },
1032
+ { name = "virtualenv" },
1033
+ ]
1034
+ sdist = { url = "https://files.pythonhosted.org/packages/40/f1/6d86a29246dfd2e9b6237f0b5823717f60cad94d47ddc26afa916d21f525/pre_commit-4.5.1.tar.gz", hash = "sha256:eb545fcff725875197837263e977ea257a402056661f09dae08e4b149b030a61", size = 198232, upload-time = "2025-12-16T21:14:33.552Z" }
1035
+ wheels = [
1036
+ { url = "https://files.pythonhosted.org/packages/5d/19/fd3ef348460c80af7bb4669ea7926651d1f95c23ff2df18b9d24bab4f3fa/pre_commit-4.5.1-py2.py3-none-any.whl", hash = "sha256:3b3afd891e97337708c1674210f8eba659b52a38ea5f822ff142d10786221f77", size = 226437, upload-time = "2025-12-16T21:14:32.409Z" },
1037
+ ]
1038
+
1039
  [[package]]
1040
  name = "protobuf"
1041
  version = "6.33.5"
 
1844
  { url = "https://files.pythonhosted.org/packages/39/08/aaaad47bc4e9dc8c725e68f9d04865dbcb2052843ff09c97b08904852d84/urllib3-2.6.3-py3-none-any.whl", hash = "sha256:bf272323e553dfb2e87d9bfd225ca7b0f467b919d7bbd355436d3fd37cb0acd4", size = 131584, upload-time = "2026-01-07T16:24:42.685Z" },
1845
  ]
1846
 
1847
+ [[package]]
1848
+ name = "virtualenv"
1849
+ version = "20.38.0"
1850
+ source = { registry = "https://pypi.org/simple" }
1851
+ dependencies = [
1852
+ { name = "distlib" },
1853
+ { name = "filelock" },
1854
+ { name = "platformdirs" },
1855
+ ]
1856
+ sdist = { url = "https://files.pythonhosted.org/packages/d2/03/a94d404ca09a89a7301a7008467aed525d4cdeb9186d262154dd23208709/virtualenv-20.38.0.tar.gz", hash = "sha256:94f39b1abaea5185bf7ea5a46702b56f1d0c9aa2f41a6c2b8b0af4ddc74c10a7", size = 5864558, upload-time = "2026-02-19T07:48:02.385Z" }
1857
+ wheels = [
1858
+ { url = "https://files.pythonhosted.org/packages/42/d7/394801755d4c8684b655d35c665aea7836ec68320304f62ab3c94395b442/virtualenv-20.38.0-py3-none-any.whl", hash = "sha256:d6e78e5889de3a4742df2d3d44e779366325a90cf356f15621fddace82431794", size = 5837778, upload-time = "2026-02-19T07:47:59.778Z" },
1859
+ ]
1860
+
1861
  [[package]]
1862
  name = "watchdog"
1863
  version = "6.0.0"