Upload 8 files

- .gitignore +21 -0
- CONTRIBUTING.md +7 -0
- LICENSE +13 -0
- README.md +69 -17
- README_RETRAIN.md +110 -0
- Scenario_heldout_final_PRECISE.py +22 -33
- requirements.txt +8 -0
- retrain_helper.py +232 -0
.gitignore
ADDED
@@ -0,0 +1,21 @@
# Python
__pycache__/
*.py[cod]
*.pyo
.venv/
.env

# Data, models, caches
models_*/
models*/
cache_dir/
*.joblib
*.pkl

# Logs
*.log

# IDE
.vscode/
.idea/
CONTRIBUTING.md
ADDED
@@ -0,0 +1,7 @@
Contribution guidelines

- Fork the repo and open a pull request.
- Add tests for new functionality.
- Keep functions small and well-documented.
- Use the existing coding style.
LICENSE
ADDED
@@ -0,0 +1,13 @@
MIT License

Copyright (c) 2025

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

[...standard MIT license continues...]
README.md
CHANGED
@@ -1,17 +1,69 @@
Project: PRECISE-GBM - Model training & retraining helpers

Overview

This repository contains code to train models (Gaussian Mixture labelling + SVM and ensemble classifiers) and to persist all artifacts required to reproduce or retrain models on new data. It includes:

- `Scenario_heldout_final_PRECISE.py` — training pipeline producing `.joblib` models and metadata JSONs (selected features, best params, CV results).
- `retrain_helper.py` — CLI utility to rebuild pipelines, set best params and retrain using saved selected-features and params JSONs. Supports JSON/YAML config files and auto-detection of model type.
- `README_RETRAIN.md` — detailed retrain examples and a notebook cell.

This repo also includes helper files to make it ready for GitHub:

- `requirements.txt` — Python dependencies
- `.gitignore` — recommended ignores (models, caches, logs)
- `LICENSE` — MIT license
- GitHub Actions workflow for CI (pytest smoke test)

Getting started (Windows PowerShell)

1) Create and activate a virtual environment

```powershell
python -m venv .venv
.\.venv\Scripts\Activate.ps1
```

2) Install dependencies

```powershell
pip install --upgrade pip
pip install -r requirements.txt
```

3) Run training (note: the training script reads data from absolute paths configured in the script — adjust them or run from an environment where those files are present)

```powershell
python Scenario_heldout_final_PRECISE.py
```

The training script will create model files under `models_LM22/` and `models_GBM/` and write metadata JSONs next to each joblib model (selected features, params, CV results) as well as group-level JSON summaries.

Retraining

See `README_RETRAIN.md` for detailed CLI and notebook examples. Short example (PowerShell uses the backtick for line continuation):

```powershell
python retrain_helper.py `
  --model-prefix "models_GBM/scenario_1/GBM_scen1_Tcell" `
  --train-csv "data\new_train.csv" `
  --label-col "label"
```

Notes

- The training script contains hard-coded absolute paths to data files. Before running on another machine, update the `scenarios_*` file paths or place the datasets in the same paths.
- The retrain helper auto-detects the model type when `--model-type` is omitted by looking for `{prefix}_svm_params.json` or `{prefix}_ens_params.json`.
- YAML config support for retrain requires PyYAML (`pip install pyyaml`).

CI

A basic GitHub Actions workflow runs a smoke pytest to ensure the retrain helper imports and basic pipeline construction works. It does not run heavy training.

Contributing

See `CONTRIBUTING.md` for guidance on opening issues and PRs.

License

This project is released under the MIT License — see `LICENSE`.
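The auto-detection rule described in the notes above can be exercised before launching a long retrain. The following is a minimal sketch, not part of the repo: `check_retrain_artifacts` is a hypothetical helper that mirrors the documented file-suffix convention (`{prefix}_selected_features.json` plus either `{prefix}_svm_params.json` or `{prefix}_ens_params.json`).

```python
import os

def check_retrain_artifacts(model_prefix):
    """Return the params JSON for a model prefix, mirroring the documented
    auto-detection rule: the selected-features file must exist, and either
    the SVM or the ensemble params file must sit next to the prefix."""
    sel = model_prefix + '_selected_features.json'
    if not os.path.exists(sel):
        raise FileNotFoundError(f'Missing selected-features file: {sel}')
    for suffix in ('_svm_params.json', '_ens_params.json'):
        candidate = model_prefix + suffix
        if os.path.exists(candidate):
            return candidate
    raise FileNotFoundError(f'No params JSON found for prefix {model_prefix}')
```

Running this once before `retrain_helper.py` gives a fast failure when artifacts were not copied alongside the model.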
README_RETRAIN.md
ADDED
@@ -0,0 +1,110 @@
Retrain helper — quick start

This file shows a minimal end-to-end example (CLI and notebook cell) to retrain a saved model using the `retrain_helper.py` utility.

1) Quick CLI example

From the project root (PowerShell; the backtick is the line-continuation character):

```powershell
# Retrain SVM using explicit CLI args
python retrain_helper.py `
  --model-prefix "models_GBM/scenario_1/GBM_scen1_Tcell" `
  --model-type svm `
  --train-csv "data\new_train.csv" `
  --label-col "label"

# Or let the helper auto-detect the model type (it looks for *_svm_params.json or *_ens_params.json)
python retrain_helper.py `
  --model-prefix "models_GBM/scenario_1/GBM_scen1_Tcell" `
  --train-csv "data\new_train.csv" `
  --label-col "label"

# Using a JSON config file (CLI args override config values)
python retrain_helper.py --config retrain_config.json
```

2) Example config files

JSON (retrain_config.json):

```json
{
  "model_prefix": "models_GBM/scenario_1/GBM_scen1_Tcell",
  "train_csv": "data/new_train.csv",
  "label_col": "label",
  "out_dir": "models_GBM/scenario_1/retrained",
  "overwrite": false
}
```

YAML (retrain_config.yml):

```yaml
model_prefix: models_GBM/scenario_1/GBM_scen1_Tcell
train_csv: data/new_train.csv
label_col: label
out_dir: models_GBM/scenario_1/retrained
overwrite: false
```

3) Notebook / Jupyter cell example (end-to-end)

This cell shows minimal steps to (A) run the CLI retrain helper using Python's subprocess, then (B) load the retrained model and run a quick prediction.

```python
# Notebook cell (Jupyter / Colab / Kaggle) - Python
import glob
import json
import os
import subprocess

import pandas as pd
from joblib import load

# 1) Run retrain (will create a timestamped retrained model and metadata JSON)
cmd = [
    "python", "retrain_helper.py",
    "--model-prefix", "models_GBM/scenario_1/GBM_scen1_Tcell",
    "--train-csv", "data/new_train.csv",
    "--label-col", "label"
]
print('Running:', ' '.join(cmd))
subprocess.check_call(cmd)

# 2) Locate the retrain metadata file in the model-prefix folder
# (it has suffix _retrain_meta_YYYYMMDD_HHMMSS.json). For the demo, search
# the output folder and load the latest metadata to find the retrained model path.
meta_files = sorted(glob.glob(os.path.join('models_GBM', 'scenario_1', 'GBM_scen1_Tcell*_retrain_meta_*.json')))
print('Found meta files:', meta_files[-3:])

with open(meta_files[-1]) as f:
    meta = json.load(f)
model_path = meta['model_file']
print('Retrained model path:', model_path)

# 3) Load retrained model and perform a smoke prediction
pipe = load(model_path)
df = pd.read_csv('data/new_train.csv', index_col=0)
with open('models_GBM/scenario_1/GBM_scen1_Tcell_selected_features.json') as f:
    sel_meta = json.load(f)
selected_features = sel_meta.get('selected_features', sel_meta)
X = df[selected_features]
print('Predict shape', X.shape)
probs = pipe.predict_proba(X)[:5]
print('Example probs (first 5 rows):', probs)
```

4) Notes & troubleshooting

- The retrain helper expects the following files to exist next to the model prefix:
  - `{prefix}_selected_features.json` — produced by the training script; contains the `selected_features` list inside its metadata
  - `{prefix}_svm_params.json` or `{prefix}_ens_params.json` — best-params metadata
- If parameter keys don't map to the pipeline built in `retrain_helper.py`, `pipe.set_params(**best_params)` may raise; in that case the script prints a warning and fits the pipeline with default parameter values.
- If you want to continue training from a saved estimator object instead of rebuilding the pipeline, modify the helper to `load` the .joblib file and call `.fit()` on it.
- YAML support requires PyYAML (`pip install pyyaml`).

5) Example minimal workflow to add to your notebook

- Run your `Scenario_heldout_final_PRECISE.py` training script to produce models and metadata.
- Prepare a CSV of new training data with the same column names as the original radiomics/immune features (index column required).
- Use `retrain_helper.py` through the CLI or config to retrain.
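The "CLI args override config values" rule above can be captured in a few lines. This is a sketch of the assumed merge behaviour, not code from `retrain_helper.py`: config values form the base, and only CLI arguments that were actually supplied (non-None) replace them.

```python
def merge_config(config, cli_args):
    """Merge a config dict with parsed CLI arguments.

    Assumed precedence (matching the README): start from the config file,
    then overwrite with any CLI argument that was explicitly provided,
    i.e. whose value is not None."""
    merged = dict(config)
    merged.update({k: v for k, v in cli_args.items() if v is not None})
    return merged
```

So a `--model-prefix` given on the command line wins over `model_prefix` in the config, while omitted CLI flags fall back to the config file.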
Scenario_heldout_final_PRECISE.py
CHANGED
@@ -21,13 +21,13 @@ from sklearn.metrics import (
 accuracy_score, precision_score, recall_score,
 f1_score, balanced_accuracy_score, matthews_corrcoef
 )
-from joblib import dump
+from joblib import Memory, dump
 
 # -------------------------
 # Logging & warnings
 # -------------------------
 logging.basicConfig(
-    filename='
+    filename='nested_lodo_groups.log',
 level=logging.INFO,
 format='%(asctime)s - %(levelname)s - %(message)s'
 )
@@ -35,22 +35,20 @@ warnings.filterwarnings('ignore', category=UserWarning)
 warnings.filterwarnings('ignore', category=ConvergenceWarning)
 
 # Create directories for saving models if they don't exist
-os.makedirs('
-os.makedirs('
-os.makedirs('
-os.makedirs('
-os.makedirs('
-os.makedirs('
+os.makedirs('models_GBM/scenario_1', exist_ok=True)
+os.makedirs('models_GBM/scenario_2', exist_ok=True)
+os.makedirs('models_GBM/scenario_3', exist_ok=True)
+os.makedirs('models_LM22/scenario_1', exist_ok=True)
+os.makedirs('models_LM22/scenario_2', exist_ok=True)
+os.makedirs('models_LM22/scenario_3', exist_ok=True)
 
 # -------------------------
 # Caching for pipelines
 # -------------------------
-
-# PermissionError race conditions on Windows when using parallel workers.
-memory = None
-logging.info("Joblib Memory disabled; no pipeline caching will be used")
+memory = Memory(location='cache_dir', verbose=0)
 
 # Helper: convert numpy scalars/arrays and dicts into JSON-serializable Python types
+import numpy as _np
 
 def _convert_obj(o):
 """Recursively convert numpy types/arrays to native Python objects for JSON dumping."""
@@ -67,7 +65,7 @@ def _convert_obj(o):
 if isinstance(o, (list, tuple)):
 return [_convert_obj(v) for v in o]
 # numpy scalar -> python native
-    if isinstance(o, (
+    if isinstance(o, (_np.integer, _np.floating, _np.bool_)):
 return o.item()
 # otherwise return as-is
 return o
@@ -88,7 +86,7 @@ def _cv_results_to_serializable(cv_dict):
 # -------------------------
 # Utility: two-step Lasso selection
 # -------------------------
-def select_features(X, y, alphas=(0.1, 0.01), cv=5, max_iter=10000, n_jobs
+def select_features(X, y, alphas=(0.1, 0.01), cv=5, max_iter=10000, n_jobs=-1, random_state=42):
 for alpha in alphas:
 lasso = LassoCV(
 alphas=[alpha], cv=cv,
@@ -217,7 +215,6 @@ for sig_name, scenarios in signature_groups.items():
 y_tr = 1 - y_tr; y_ho = 1 - y_ho
 # save gmm model
 gmm_model_path = f'models_{sig_name}/scenario_{scen_id}/{sig_name}_scen{scen_id}_{col}_gmm_model.joblib'
-os.makedirs(os.path.dirname(gmm_model_path), exist_ok=True)
 dump(gmm, gmm_model_path)
 logging.info(f"Saved GMM model to {gmm_model_path}")
 logging.info(f"GMM means for {sig_name}:{scen_id}, col {col}: {gmm.means_.flatten().tolist()}")
@@ -239,19 +236,14 @@ for sig_name, scenarios in signature_groups.items():
 json.dump(meta, _f, indent=2)
 
 # SVM nested CV
-# Avoid using joblib.Memory at the Pipeline level when running parallel CV (n_jobs != 1).
-# Joblib's Memory can hit race conditions on Windows when multiple workers try to
-# read/write the same cache files which leads to PermissionError (output.pkl).
-# We therefore disable pipeline caching here (memory=None). This does NOT affect
-# saving final models or params (those are written explicitly with dump/json below).
 pipe_svm = Pipeline([
 ('scaler', StandardScaler()),
 ('clf', SVC(class_weight='balanced', probability=True, random_state=42))
-], memory=
+], memory=memory)
 search_svm = RandomizedSearchCV(
 pipe_svm, param_dist_svm, n_iter=5,
 cv=inner_cv, scoring='balanced_accuracy',
-n_jobs
+n_jobs=-1, refit=True, error_score='raise'
 )
 search_svm.fit(X_tr_sel, y_tr)
 y_pred_svm = search_svm.predict(X_ho_sel)
@@ -259,7 +251,6 @@ for sig_name, scenarios in signature_groups.items():
 for k, v in search_svm.cv_results_.items()}
 # save SVM model
 svm_model_path = f'models_{sig_name}/scenario_{scen_id}/{sig_name}_scen{scen_id}_{col}_svm_model.joblib'
-os.makedirs(os.path.dirname(svm_model_path), exist_ok=True)
 dump(search_svm.best_estimator_, svm_model_path)
 logging.info(f"Saved SVM model to {svm_model_path}")
 logging.info(f"SVM best params for {sig_name}:{scen_id}, col {col}: {search_svm.best_params_}")
@@ -287,20 +278,20 @@ for sig_name, scenarios in signature_groups.items():
 base_pipe = Pipeline([
 ('scaler', StandardScaler()),
 ('classifier', SVC(class_weight='balanced', probability=True, random_state=42))
-], memory=
+], memory=memory)
 ensemble = VotingClassifier([
 ('svm', base_pipe),
 ('rf', RandomForestClassifier(class_weight='balanced', random_state=42)),
 ('gb', HistGradientBoostingClassifier(random_state=42))
-], voting='soft', weights=[1,1,1], n_jobs
+], voting='soft', weights=[1,1,1], n_jobs=-1)
 pipe_ens = Pipeline([
 ('scaler', StandardScaler()),
 ('ensemble', ensemble)
-], memory=
+], memory=memory)
 search_ens = RandomizedSearchCV(
 pipe_ens, param_dist_ensemble, n_iter=3,
 cv=inner_cv, scoring='balanced_accuracy',
-n_jobs
+n_jobs=-1, refit=True, error_score='raise'
 )
 search_ens.fit(X_tr_sel, y_tr)
 y_pred_ens = search_ens.predict(X_ho_sel)
@@ -308,7 +299,6 @@ for sig_name, scenarios in signature_groups.items():
 for k, v in search_ens.cv_results_.items()}
 # save Ensemble model
 ens_model_path = f'models_{sig_name}/scenario_{scen_id}/{sig_name}_scen{scen_id}_{col}_ens_model.joblib'
-os.makedirs(os.path.dirname(ens_model_path), exist_ok=True)
 dump(search_ens.best_estimator_, ens_model_path)
 logging.info(f"Saved Ensemble model to {ens_model_path}")
 logging.info(f"Ensemble best params for {sig_name}:{scen_id}, col {col}: {search_ens.best_params_}")
@@ -347,8 +337,7 @@ for sig_name, scenarios in signature_groups.items():
 scen_cv[col] = {'svm_cv': cv_svm, 'ensemble_cv': cv_ens}
 
 except Exception as e:
-
-logging.exception(f"{sig_name}:{scen_id}, col {col}: unexpected error")
+logging.error(f"{sig_name}:{scen_id}, col {col}: {e}")
 print(f"[ERROR] {sig_name}:{scen_id}, column {col}: {e}")
 
 # Save for this scenario
@@ -358,11 +347,11 @@ for sig_name, scenarios in signature_groups.items():
 logging.info(f"[{sig_name}] {scen_id} done in {time.time()-t0:.1f}s")
 
 # Write group-level JSONs
-with open(f'
+with open(f'nested_results111_{sig_name}.json', 'w') as f:
 json.dump(all_results, f, indent=2)
-with open(f'
+with open(f'nested_features111_{sig_name}.json', 'w') as f:
 json.dump(all_features, f, indent=2)
-with open(f'
+with open(f'nested_cv111_{sig_name}.json', 'w') as f:
 json.dump(all_cv, f, indent=2)
 print(f"✅ {sig_name} group complete: scenarios={list(all_results.keys())}")
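The `_convert_obj` helper changed above exists because numpy scalars and arrays (as produced by scikit-learn's `cv_results_` and `best_params_`) are not JSON-serializable. A self-contained sketch of the same idea (`convert` here is an illustrative stand-in, not the repo's function, and it also handles `ndarray` via `tolist()`):

```python
import json

import numpy as np

def convert(o):
    """Recursively convert numpy containers/scalars to native Python objects
    so the result can be passed to json.dump()."""
    if isinstance(o, dict):
        return {k: convert(v) for k, v in o.items()}
    if isinstance(o, (list, tuple)):
        return [convert(v) for v in o]
    if isinstance(o, np.ndarray):
        return o.tolist()
    if isinstance(o, (np.integer, np.floating, np.bool_)):
        return o.item()  # numpy scalar -> python native
    return o

meta = {'best_score': np.float64(0.87), 'cv_scores': np.array([0.8, 0.9])}
# json.dumps(meta) would raise TypeError; the converted dict serializes fine.
print(json.dumps(convert(meta)))
```

This is why the training script converts metadata before writing the `_params.json` and `_selected_features.json` files.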
requirements.txt
ADDED
@@ -0,0 +1,8 @@
numpy
pandas
scikit-learn
joblib
tqdm
pyyaml
pytest
retrain_helper.py
ADDED
@@ -0,0 +1,232 @@
"""retrain_helper.py
Small CLI to retrain saved SVM or Ensemble models using saved metadata.

Enhancements in this version:
- Accept a JSON or YAML config file via --config with keys: model_prefix, model_type (optional), train_csv, label_col, out_dir (optional)
- If model_type is omitted, auto-detect by checking for *_svm_params.json or *_ens_params.json next to the prefix
- CLI arguments override config values

Usage (from project root):
  python retrain_helper.py --model-prefix "models_GBM/scenario_1/GBM_scen1_Tcell" --model-type svm --train-csv new_train.csv --label-col label
or using config.json/yaml:
  python retrain_helper.py --config retrain_config.json

The script expects files with these suffixes next to the prefix:
- _selected_features.json (contains metadata.selected_features list)
- _svm_params.json or _ens_params.json (contains metadata.best_params)

It builds pipelines matching the original script, sets the best params, fits on the provided CSV using the selected features, and saves a retrained joblib model and a metadata JSON.
"""

import argparse
import json
import os
from datetime import datetime, timezone

import pandas as pd
from joblib import dump
from sklearn.ensemble import HistGradientBoostingClassifier, RandomForestClassifier, VotingClassifier
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

# optional yaml support
try:
    import yaml
    _HAS_YAML = True
except Exception:
    _HAS_YAML = False


def load_json_meta(path):
    with open(path, 'r') as f:
        return json.load(f)


def load_config(path):
    """Load JSON or YAML config file into a dict."""
    if path.lower().endswith(('.yaml', '.yml')):
        if not _HAS_YAML:
            raise RuntimeError('PyYAML is not installed, cannot read YAML config')
        with open(path, 'r') as f:
            return yaml.safe_load(f)
    else:
        with open(path, 'r') as f:
            return json.load(f)


def build_svm_pipeline():
    pipe = Pipeline([
        ('scaler', StandardScaler()),
        ('clf', SVC(class_weight='balanced', probability=True, random_state=42))
    ])
    return pipe


def build_ensemble_pipeline():
    # base pipe inside voting ensemble should be named and structured like in training script
    base_pipe = Pipeline([
        ('scaler', StandardScaler()),
        ('classifier', SVC(class_weight='balanced', probability=True, random_state=42))
    ])
    ensemble = VotingClassifier([
        ('svm', base_pipe),
        ('rf', RandomForestClassifier(class_weight='balanced', random_state=42)),
        ('gb', HistGradientBoostingClassifier(random_state=42))
    ], voting='soft', weights=[1, 1, 1])
    pipe = Pipeline([
        ('scaler', StandardScaler()),
        ('ensemble', ensemble)
    ])
    return pipe


def _auto_detect_model_type(model_prefix):
    """Return 'svm' or 'ens' based on presence of params files next to the prefix.
    If both are present, prefer 'svm' and warn."""
    svm_path = model_prefix + '_svm_params.json'
    ens_path = model_prefix + '_ens_params.json'
    svm_exists = os.path.exists(svm_path)
    ens_exists = os.path.exists(ens_path)
    if svm_exists and not ens_exists:
        return 'svm'
    if ens_exists and not svm_exists:
        return 'ens'
    if svm_exists and ens_exists:
        print('Warning: both SVM and Ensemble params found; defaulting to SVM')
        return 'svm'
    # if neither exists, raise
    raise FileNotFoundError(f'Neither {svm_path} nor {ens_path} found for auto-detection')


def retrain(model_prefix, model_type=None, train_csv=None, label_col=None, out_dir=None, overwrite=False):
    """Retrain a saved model using the saved selected-features and best-params metadata.

    model_type can be 'svm' or 'ens' (ensemble). If None, the function will try to auto-detect.
    """
    if model_type is None:
        model_type = _auto_detect_model_type(model_prefix)

    # Resolve file paths
    sel_path = model_prefix + '_selected_features.json'
    if model_type.lower() == 'svm':
        params_path = model_prefix + '_svm_params.json'
    elif model_type.lower() in ('ens', 'ensemble'):
        params_path = model_prefix + '_ens_params.json'
    else:
        raise ValueError('model_type must be "svm" or "ens"')

    if not os.path.exists(sel_path):
        raise FileNotFoundError(f'Selected-features file not found: {sel_path}')
    if not os.path.exists(params_path):
        raise FileNotFoundError(f'Params file not found: {params_path}')
    if train_csv is None or not os.path.exists(train_csv):
        raise FileNotFoundError(f'Train CSV not found: {train_csv}')

    sel_meta = load_json_meta(sel_path)
    # selected features are stored under top-level key 'selected_features' (script writes metadata)
    if isinstance(sel_meta, dict) and 'selected_features' in sel_meta:
        sel_features = sel_meta['selected_features']
    elif isinstance(sel_meta, list):
        sel_features = sel_meta
    else:
        raise ValueError('Unexpected selected features file format')

    params_meta = load_json_meta(params_path)
    # params saved under 'best_params' inside metadata
    if isinstance(params_meta, dict) and 'best_params' in params_meta:
        best_params = params_meta['best_params']
    else:
        # fallback: file may contain bare params
        best_params = params_meta

    # load training data and subset columns
    df = pd.read_csv(train_csv, index_col=0)

    missing = [c for c in sel_features if c not in df.columns]
    if missing:
        raise ValueError(f'The following selected features are missing from training CSV: {missing}')

    X = df[sel_features].values
    y = df[label_col].values

    # Build pipeline and set params
    if model_type.lower() == 'svm':
        pipe = build_svm_pipeline()
    else:
        pipe = build_ensemble_pipeline()

    # set params (keys should match the original training param names)
    try:
        pipe.set_params(**best_params)
    except Exception as e:
        print('Warning: failed to set all params on pipeline:', e)
        # continue anyway

    # Fit
    print(f'Fitting {model_type} on {X.shape[0]} samples with {X.shape[1]} features...')
    pipe.fit(X, y)

    # Save retrained model
    ts = datetime.now(timezone.utc).strftime('%Y%m%d_%H%M%S')
    if out_dir is None:
        out_dir = os.path.dirname(model_prefix) or '.'
    os.makedirs(out_dir, exist_ok=True)
    model_out_path = os.path.join(out_dir, os.path.basename(model_prefix) + f'_{model_type}_retrained_{ts}.joblib')

    # respect overwrite flag
    if os.path.exists(model_out_path) and not overwrite:
        raise FileExistsError(f'Model output already exists: {model_out_path}. Use overwrite=True to overwrite.')

    dump(pipe, model_out_path)

    # Save retrain metadata
    meta = {
        'retrained_at': datetime.now(timezone.utc).isoformat(),
|
| 185 |
+
'version': ts,
|
| 186 |
+
'model_type': model_type,
|
| 187 |
+
'n_samples': int(X.shape[0]),
|
| 188 |
+
'n_features': int(X.shape[1]),
|
| 189 |
+
'selected_features_file': os.path.abspath(sel_path),
|
| 190 |
+
'params_file': os.path.abspath(params_path),
|
| 191 |
+
'model_file': os.path.abspath(str(model_out_path))
|
| 192 |
+
}
|
| 193 |
+
meta_out = os.path.join(out_dir, os.path.basename(model_prefix) + f'_{model_type}_retrain_meta_{ts}.json')
|
| 194 |
+
with open(meta_out, 'w') as f:
|
| 195 |
+
json.dump(meta, f, indent=2)
|
| 196 |
+
|
| 197 |
+
print('Retrained model saved to:', model_out_path)
|
| 198 |
+
print('Retrain metadata saved to:', meta_out)
|
| 199 |
+
return model_out_path, meta_out
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def main():
    p = argparse.ArgumentParser(description='Retrain a saved model using saved selected features and best params')
    p.add_argument('--config', required=False, help='Path to JSON or YAML config file with keys: model_prefix, model_type (optional), train_csv, label_col, out_dir')
    p.add_argument('--model-prefix', required=False, help='Path prefix to model files (without suffix), e.g. models_GBM/scenario_1/GBM_scen1_Tcell')
    p.add_argument('--model-type', required=False, choices=['svm', 'ens', 'ensemble'], help='svm or ens (if omitted, auto-detect)')
    p.add_argument('--train-csv', required=False, help='CSV with training data (index column present). Must contain the selected features and the label column')
    p.add_argument('--label-col', required=False, help='Name of the label column in the train CSV')
    p.add_argument('--out-dir', default=None, help='Output directory (defaults to the model-prefix directory)')
    p.add_argument('--overwrite', action='store_true', help='Overwrite existing output files')
    args = p.parse_args()

    cfg = {}
    if args.config:
        cfg = load_config(args.config) or {}

    # Merge config and CLI args; CLI takes precedence
    model_prefix = args.model_prefix or cfg.get('model_prefix')
    model_type = args.model_type or cfg.get('model_type')
    train_csv = args.train_csv or cfg.get('train_csv')
    label_col = args.label_col or cfg.get('label_col')
    out_dir = args.out_dir or cfg.get('out_dir')
    overwrite = args.overwrite or cfg.get('overwrite', False)

    if model_prefix is None or train_csv is None or label_col is None:
        raise ValueError('model_prefix, train_csv and label_col must be provided either via --config or CLI args')

    retrain(model_prefix, model_type=model_type, train_csv=train_csv, label_col=label_col, out_dir=out_dir, overwrite=overwrite)


if __name__ == '__main__':
    main()
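The metadata-loading conventions that `retrain` relies on (a `_selected_features.json` file with a top-level `selected_features` key, and a params file with a `best_params` key or a bare dict) can be exercised standalone. A minimal sketch, using only the standard library; the file-name prefix and the tiny JSON payloads are made up for illustration:

```python
import json
import os
import tempfile

def extract_selected_features(sel_meta):
    # Mirrors retrain(): accept {'selected_features': [...]} metadata or a bare list
    if isinstance(sel_meta, dict) and 'selected_features' in sel_meta:
        return sel_meta['selected_features']
    if isinstance(sel_meta, list):
        return sel_meta
    raise ValueError('Unexpected selected-features file format')

def extract_best_params(params_meta):
    # Mirrors retrain(): prefer the 'best_params' key, fall back to a bare params dict
    if isinstance(params_meta, dict) and 'best_params' in params_meta:
        return params_meta['best_params']
    return params_meta

with tempfile.TemporaryDirectory() as d:
    prefix = os.path.join(d, 'GBM_scen1_Tcell')  # illustrative model_prefix
    with open(prefix + '_selected_features.json', 'w') as f:
        json.dump({'selected_features': ['gene_a', 'gene_b']}, f)
    with open(prefix + '_svm_params.json', 'w') as f:
        json.dump({'best_params': {'svc__C': 1.0}}, f)

    with open(prefix + '_selected_features.json') as f:
        feats = extract_selected_features(json.load(f))
    with open(prefix + '_svm_params.json') as f:
        params = extract_best_params(json.load(f))

print(feats, params)  # → ['gene_a', 'gene_b'] {'svc__C': 1.0}
```

Because both loaders tolerate the bare-format fallback, older metadata files written without the wrapping keys should still be usable.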