Spaces:
Running
Running
| import streamlit as st | |
| from pathlib import Path | |
| from catboost import CatBoostClassifier | |
| # from xgboost import XGBClassifier | |
| # from lightgbm import LGBMClassifier | |
| from sklearn.ensemble import RandomForestClassifier | |
| MODEL_DIR = Path("src/params") | |
| # MODEL_DIR.mkdir(exist_ok=True) | |
| import yaml | |
def load_model_params(model_type, target="GVHD", mode="ensemble", path=MODEL_DIR / "model_params.yaml"):
    """Load hyperparameters for *model_type* from a YAML params file.

    Validates the target/mode/model_type combination, reads the YAML at
    *path*, and returns the dict stored under ``[model_type][mode]``.
    Side effect: if the params carry a ``random_seed`` key it is mirrored
    into ``st.session_state.random_seed`` for reuse elsewhere in the app.

    Raises ValueError when any of the three selector arguments is invalid.
    """
    if target not in ("GVHD", "Acute GVHD(<100 days)", "Chronic GVHD>100 days"):
        raise ValueError("target must be one of 'GVHD', 'Acute GVHD(<100 days)', or 'Chronic GVHD>100 days'")
    if mode not in ("ensemble", "single_model"):
        raise ValueError("mode must be either 'ensemble' or 'single_model'")
    if model_type not in ("CatBoost", "XGBoost", "LightGBM", "RandomForest"):
        raise ValueError("model_type must be one of 'CatBoost', 'XGBoost', 'LightGBM', or 'RandomForest'")
    with open(path, "r") as fh:
        params = yaml.safe_load(fh)[model_type][mode]
    if "random_seed" in params:
        st.session_state.random_seed = params["random_seed"]
    return params
def get_model(model_type, mode="ensemble", target="GVHD", best_iter=None):
    """Instantiate an untrained classifier configured from the per-target params file.

    Parameters
    ----------
    model_type : one of 'CatBoost', 'XGBoost', 'LightGBM', 'RandomForest'
        (only CatBoost and RandomForest are currently enabled).
    mode : 'ensemble' or 'single_model' — selects the param set in the YAML.
    target : outcome label; chooses which params YAML to read.
    best_iter : optional iteration count; set in single_model mode so the
        final model trains for exactly the best CV iteration.

    Returns an unfitted estimator; raises ValueError for an unknown
    target or an unsupported model_type.
    """
    param_files = {
        "GVHD": "model_params_gvhd.yaml",
        "Acute GVHD(<100 days)": "model_params_acute.yaml",
        "Chronic GVHD>100 days": "model_params_chronic.yaml",
    }
    # BUG FIX: an unknown target previously left `path` unbound, producing a
    # confusing NameError; fail early with the same ValueError contract as
    # load_model_params.
    if target not in param_files:
        raise ValueError("target must be one of 'GVHD', 'Acute GVHD(<100 days)', or 'Chronic GVHD>100 days'")
    path = MODEL_DIR / param_files[target]
    params = load_model_params(model_type, target, mode, path)
    # In single_model mode the caller supplies the best CV iteration so the
    # final fit uses exactly that many boosting rounds.
    if best_iter is not None:
        params['iterations'] = best_iter
    if model_type == "CatBoost":
        return CatBoostClassifier(**params)
    # XGBoost / LightGBM branches are currently disabled (imports commented out
    # at the top of the file); re-enable both together if those deps return.
    elif model_type == "RandomForest":
        return RandomForestClassifier(**params)
    else:
        raise ValueError(f"Unsupported model type: {model_type}")
def save_model(model, user_model_name, metrics_result_single=None):
    """Pickle a single trained model and upload it to the HF dataset repo.

    The model plus session metadata is pickled, wrapped in a one-row Parquet
    table (with HF schema metadata so the dataset viewer can render it),
    and uploaded to ``models/<filename>.parquet``.

    Parameters
    ----------
    model : trained estimator to persist.
    user_model_name : user-chosen label embedded in the filename.
    metrics_result_single : optional evaluation metrics stored alongside.

    Returns
    -------
    str : the generated filename (no directory, no ``.parquet`` suffix),
        which load_model() uses as the lookup key.
    """
    from datetime import datetime
    import io
    import pickle
    import json
    import pyarrow as pa
    import pyarrow.parquet as pq
    from huggingface_hub import login, CommitScheduler
    import os

    if "HF_TOKEN" in os.environ:
        login(token=os.environ["HF_TOKEN"])
    timestamp = datetime.now().strftime("%y%m%d_%H%M%S")
    # First character of the target column disambiguates targets in listings.
    filename = f"{timestamp}{st.session_state.get('target_col', 'UNKNOWN')[0]}_{user_model_name}_single"
    model_data = {
        "timestamp": timestamp,
        "model_name": user_model_name,
        "target_col": st.session_state.get("target_col", "UNKNOWN"),
        "model": model,
        "best_iteration": st.session_state.get("best_iteration"),
        "metrics_result_single": metrics_result_single,
    }
    model_bytes = pickle.dumps(model_data)
    row = {
        "filename": filename,
        "timestamp": timestamp,
        "type": "single",
        "model_file": {"path": filename, "bytes": model_bytes},
    }
    table = pa.Table.from_pylist([row])
    # Declare column types for the HF dataset viewer.
    table = table.replace_schema_metadata({
        "huggingface": json.dumps({"info": {
            "features": {
                "filename": {"_type": "Value", "dtype": "string"},
                "timestamp": {"_type": "Value", "dtype": "string"},
                "type": {"_type": "Value", "dtype": "string"},
                "model_file": {"_type": "Value", "dtype": "binary"},
            }
        }})
    })
    # Write the table to an in-memory buffer; nothing touches local disk.
    buf = io.BytesIO()
    pq.write_table(table, buf)
    buf.seek(0)
    # NOTE(review): CommitScheduler is used only as a handle to an
    # authenticated HfApi (folder_path is a dummy); a plain HfApi would do.
    scheduler = CommitScheduler(
        repo_id=os.environ["HF_REPO_ID"],
        repo_type="dataset",
        path_in_repo="models",
        token=os.environ["HF_TOKEN"],
        private=True,
        folder_path=Path("/tmp/dummy")
    )
    scheduler.api.upload_file(
        repo_id=os.environ["HF_REPO_ID"],
        repo_type="dataset",
        # BUG FIX: the upload path previously contained a literal placeholder
        # instead of the generated filename, so every save overwrote the same
        # object and load_model() could never locate models by name.
        path_in_repo=f"models/{filename}.parquet",
        path_or_fileobj=buf
    )
    return filename
def save_model_ensemble(models, user_model_name, best_iterations=None, fold_scores=None, metrics_result_ensemble=None):
    """Pickle a list of fold models and upload them to the HF dataset repo.

    Mirrors save_model() but stores the whole ensemble (one model per CV
    fold) plus per-fold diagnostics in a single one-row Parquet file at
    ``models/<filename>.parquet``.

    Parameters
    ----------
    models : list of trained fold estimators.
    user_model_name : user-chosen label embedded in the filename.
    best_iterations : optional per-fold best iteration counts.
    fold_scores : optional per-fold validation scores.
    metrics_result_ensemble : optional aggregate evaluation metrics.

    Returns
    -------
    str : the generated filename (no directory, no ``.parquet`` suffix),
        which load_model_ensemble() uses as the lookup key.
    """
    from datetime import datetime
    import io
    import pickle
    import json
    import pyarrow as pa
    import pyarrow.parquet as pq
    from huggingface_hub import login, CommitScheduler
    import os

    if "HF_TOKEN" in os.environ:
        login(token=os.environ["HF_TOKEN"])
    timestamp = datetime.now().strftime("%y%m%d_%H%M%S")
    # First character of the target column disambiguates targets in listings.
    filename = f"{timestamp}{st.session_state.get('target_col', 'UNKNOWN')[0]}_{user_model_name}_ensemble"
    ensemble_data = {
        "timestamp": timestamp,
        "model_name": user_model_name,
        "target_col": st.session_state.get("target_col", "UNKNOWN"),
        "model": models,
        "best_iterations": best_iterations,
        "fold_scores": fold_scores,
        "metrics_result_ensemble": metrics_result_ensemble,
    }
    model_bytes = pickle.dumps(ensemble_data)
    row = {
        "filename": filename,
        "timestamp": timestamp,
        "type": "ensemble",
        "model_file": {"path": filename, "bytes": model_bytes},
    }
    table = pa.Table.from_pylist([row])
    # Declare column types for the HF dataset viewer.
    table = table.replace_schema_metadata({
        "huggingface": json.dumps({"info": {
            "features": {
                "filename": {"_type": "Value", "dtype": "string"},
                "timestamp": {"_type": "Value", "dtype": "string"},
                "type": {"_type": "Value", "dtype": "string"},
                "model_file": {"_type": "Value", "dtype": "binary"},
            }
        }})
    })
    # Write the table to an in-memory buffer; nothing touches local disk.
    buf = io.BytesIO()
    pq.write_table(table, buf)
    buf.seek(0)
    # NOTE(review): CommitScheduler is used only as a handle to an
    # authenticated HfApi (folder_path is a dummy); a plain HfApi would do.
    scheduler = CommitScheduler(
        repo_id=os.environ["HF_REPO_ID"],
        repo_type="dataset",
        path_in_repo="models",
        token=os.environ["HF_TOKEN"],
        private=True,
        folder_path=Path("/tmp/dummy")
    )
    scheduler.api.upload_file(
        repo_id=os.environ["HF_REPO_ID"],
        repo_type="dataset",
        # BUG FIX: the upload path previously contained a literal placeholder
        # instead of the generated filename, so every save overwrote the same
        # object and lookups by name failed.
        path_in_repo=f"models/{filename}.parquet",
        path_or_fileobj=buf
    )
    return filename
def load_model(model_name):
    """Find and unpickle a saved model by its stored filename key.

    Scans every ``models/*.parquet`` file in the HF dataset repo,
    downloading each until one whose first row's ``filename`` column
    matches *model_name* is found, then unpickles the embedded bytes.

    Parameters
    ----------
    model_name : the filename key returned by save_model / save_model_ensemble.

    Returns the dict that was pickled at save time (model, metadata, metrics).
    Raises FileNotFoundError when no stored file matches.
    """
    from huggingface_hub import login, hf_hub_download, HfApi
    import pyarrow.parquet as pq
    import pickle
    import os

    if "HF_TOKEN" in os.environ:
        login(token=os.environ["HF_TOKEN"])
    repo_id = os.environ["HF_REPO_ID"]
    token = os.environ["HF_TOKEN"]
    api = HfApi(token=token)
    all_files = api.list_repo_files(repo_id=repo_id, repo_type="dataset")
    model_files = [f for f in all_files if f.startswith("models/") and f.endswith(".parquet")]
    # Linear scan: each candidate must be downloaded to inspect its filename
    # column. Acceptable while the repo holds few models.
    matched_row = None
    for f in model_files:
        downloaded = hf_hub_download(
            repo_id=repo_id,
            repo_type="dataset",
            filename=f,
            token=token
        )
        rows = pq.read_table(downloaded).to_pylist()
        # BUG FIX: a zero-row parquet file previously raised IndexError and
        # aborted the whole search; skip it instead.
        if not rows:
            continue
        if rows[0]["filename"] == model_name:
            matched_row = rows[0]
            break
    if matched_row is None:
        raise FileNotFoundError(f"Model {model_name} not found in repo.")
    # NOTE(security): pickle.loads on repo content — acceptable only because
    # the dataset repo is private and written solely by this app.
    return pickle.loads(matched_row["model_file"]["bytes"])
def load_model_ensemble(filename):
    """Load a saved ensemble bundle by name.

    Thin alias kept for call-site symmetry with save_model_ensemble();
    single and ensemble models share the same storage format, so this
    simply delegates to load_model().
    """
    return load_model(filename)
def ensemble_predict(models, X, cat_features):
    """Average positive-class probabilities over an ensemble.

    Each model's ``predict_proba(X)[:, 1]`` (probability of class 1) is
    computed and the element-wise mean across models is returned.
    ``cat_features`` is accepted for interface compatibility with the
    training code path but is not used here.
    """
    per_model_probs = (model.predict_proba(X)[:, 1] for model in models)
    return sum(per_model_probs) / len(models)