import streamlit as st
from pathlib import Path
import yaml
from catboost import CatBoostClassifier
# from xgboost import XGBClassifier
# from lightgbm import LGBMClassifier
from sklearn.ensemble import RandomForestClassifier

MODEL_DIR = Path("src/params")
# MODEL_DIR.mkdir(exist_ok=True)


def load_model_params(model_type, target="GVHD", mode="ensemble", path=MODEL_DIR / "model_params.yaml"):
    if target not in ["GVHD", "Acute GVHD(<100 days)", "Chronic GVHD>100 days"]:
        raise ValueError("target must be one of 'GVHD', 'Acute GVHD(<100 days)', or 'Chronic GVHD>100 days'")
    if mode not in ["ensemble", "single_model"]:
        raise ValueError("mode must be either 'ensemble' or 'single_model'")
    if model_type not in ["CatBoost", "XGBoost", "LightGBM", "RandomForest"]:
        raise ValueError("model_type must be one of 'CatBoost', 'XGBoost', 'LightGBM', or 'RandomForest'")

    with open(path, "r") as f:
        all_params = yaml.safe_load(f)
    params = all_params[model_type][mode]

    # Keep the seed in session state so other parts of the app can reuse it
    if "random_seed" in params:
        st.session_state.random_seed = params["random_seed"]
    return params


def get_model(model_type, mode="ensemble", target="GVHD", best_iter=None):
    if target == "GVHD":
        path = MODEL_DIR / "model_params_gvhd.yaml"
    elif target == "Acute GVHD(<100 days)":
        path = MODEL_DIR / "model_params_acute.yaml"
    elif target == "Chronic GVHD>100 days":
        path = MODEL_DIR / "model_params_chronic.yaml"
    else:
        raise ValueError(f"Unsupported target: {target}")

    params = load_model_params(model_type, target, mode, path)

    # best_iter (set in single_model mode) overrides the iteration count
    # from the YAML params
    if best_iter is not None:
        params["iterations"] = best_iter

    # if "random_seed" in st.session_state:
    #     random_seed = st.session_state.random_seed

    if model_type == "CatBoost":
        return CatBoostClassifier(**params)
    # elif model_type == "XGBoost":
    #     return XGBClassifier(**params, use_label_encoder=False, eval_metric="logloss")
    # elif model_type == "LightGBM":
    #     return LGBMClassifier(**params)
    elif model_type == "RandomForest":
        return RandomForestClassifier(**params)
    else:
        raise ValueError(f"Unsupported model type: {model_type}")


def save_model(model, user_model_name, metrics_result_single=None):
    from datetime import datetime
    import io
    import pickle
    import json
    import pyarrow as pa
    import pyarrow.parquet as pq
    from huggingface_hub import login, CommitScheduler
    import os

    if "HF_TOKEN" in os.environ:
        login(token=os.environ["HF_TOKEN"])

    # The first letter of the target column is encoded into the filename
    timestamp = datetime.now().strftime("%y%m%d_%H%M%S")
    filename = f"{timestamp}{st.session_state.get('target_col', 'UNKNOWN')[0]}_{user_model_name}_single"

    # Bundle the model with its training context
    model_data = {
        "timestamp": timestamp,
        "model_name": user_model_name,
        "target_col": st.session_state.get("target_col", "UNKNOWN"),
        "model": model,
        "best_iteration": st.session_state.get("best_iteration"),
        "metrics_result_single": metrics_result_single,
    }

    # Serialize (pickle) to bytes
    model_bytes = pickle.dumps(model_data)

    # One Parquet row per saved model; the pickled payload goes in a binary column
    row = {
        "filename": filename,
        "timestamp": timestamp,
        "type": "single",
        "model_file": {"path": filename, "bytes": model_bytes},
    }
    table = pa.Table.from_pylist([row])
    table = table.replace_schema_metadata({
        "huggingface": json.dumps({"info": {
            "features": {
                "filename": {"_type": "Value", "dtype": "string"},
                "timestamp": {"_type": "Value", "dtype": "string"},
                "type": {"_type": "Value", "dtype": "string"},
                "model_file": {"_type": "Value", "dtype": "binary"},
            }
        }})
    })

    # Write to an in-memory buffer
    buf = io.BytesIO()
    pq.write_table(table, buf)
    buf.seek(0)

    # Upload to the HF dataset repo. CommitScheduler is instantiated so the
    # private repo is created if it does not exist yet; the one-off upload
    # itself goes through its HfApi client.
    scheduler = CommitScheduler(
        repo_id=os.environ["HF_REPO_ID"],
        repo_type="dataset",
        path_in_repo="models",
        token=os.environ["HF_TOKEN"],
        private=True,
        folder_path=Path("/tmp/dummy"),
    )
    scheduler.api.upload_file(
        repo_id=os.environ["HF_REPO_ID"],
        repo_type="dataset",
        path_in_repo=f"models/{filename}.parquet",
        path_or_fileobj=buf,
    )
    return filename
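# --- Round-trip sketch (not part of the original module) ---
# A saved payload can be inspected locally without the Hub; the filename
# below is illustrative. This mirrors what load_model() does after
# hf_hub_download():
#
#     import pickle
#     import pyarrow.parquet as pq
#
#     row = pq.read_table("250101_120000G_demo_single.parquet").to_pylist()[0]
#     model_data = pickle.loads(row["model_file"]["bytes"])
#     print(model_data["model_name"], model_data["best_iteration"])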
def save_model_ensemble(models, user_model_name, best_iterations=None, fold_scores=None, metrics_result_ensemble=None):
    from datetime import datetime
    import io
    import pickle
    import json
    import pyarrow as pa
    import pyarrow.parquet as pq
    from huggingface_hub import login, CommitScheduler
    import os

    if "HF_TOKEN" in os.environ:
        login(token=os.environ["HF_TOKEN"])

    timestamp = datetime.now().strftime("%y%m%d_%H%M%S")
    filename = f"{timestamp}{st.session_state.get('target_col', 'UNKNOWN')[0]}_{user_model_name}_ensemble"

    # Same layout as save_model(), but the payload holds the list of fold models
    ensemble_data = {
        "timestamp": timestamp,
        "model_name": user_model_name,
        "target_col": st.session_state.get("target_col", "UNKNOWN"),
        "model": models,
        "best_iterations": best_iterations,
        "fold_scores": fold_scores,
        "metrics_result_ensemble": metrics_result_ensemble,
    }
    model_bytes = pickle.dumps(ensemble_data)

    row = {
        "filename": filename,
        "timestamp": timestamp,
        "type": "ensemble",
        "model_file": {"path": filename, "bytes": model_bytes},
    }
    table = pa.Table.from_pylist([row])
    table = table.replace_schema_metadata({
        "huggingface": json.dumps({"info": {
            "features": {
                "filename": {"_type": "Value", "dtype": "string"},
                "timestamp": {"_type": "Value", "dtype": "string"},
                "type": {"_type": "Value", "dtype": "string"},
                "model_file": {"_type": "Value", "dtype": "binary"},
            }
        }})
    })

    buf = io.BytesIO()
    pq.write_table(table, buf)
    buf.seek(0)

    scheduler = CommitScheduler(
        repo_id=os.environ["HF_REPO_ID"],
        repo_type="dataset",
        path_in_repo="models",
        token=os.environ["HF_TOKEN"],
        private=True,
        folder_path=Path("/tmp/dummy"),
    )
    scheduler.api.upload_file(
        repo_id=os.environ["HF_REPO_ID"],
        repo_type="dataset",
        path_in_repo=f"models/{filename}.parquet",
        path_or_fileobj=buf,
    )
    return filename


def load_model(model_name):
    from huggingface_hub import login, hf_hub_download, HfApi
    import pyarrow.parquet as pq
    import pickle
    import os

    if "HF_TOKEN" in os.environ:
        login(token=os.environ["HF_TOKEN"])

    api = HfApi(token=os.environ["HF_TOKEN"])
    all_files = api.list_repo_files(repo_id=os.environ["HF_REPO_ID"], repo_type="dataset")
    model_files = [f for f in all_files if f.startswith("models/") and f.endswith(".parquet")]

    # Download each candidate file until the row matching model_name is found
    target_row = None
    for f in model_files:
        downloaded = hf_hub_download(
            repo_id=os.environ["HF_REPO_ID"],
            repo_type="dataset",
            filename=f,
            token=os.environ["HF_TOKEN"],
        )
        row = pq.read_table(downloaded).to_pylist()[0]
        if row["filename"] == model_name:
            target_row = row
            break

    if target_row is None:
        raise FileNotFoundError(f"Model {model_name} not found in repo.")

    return pickle.loads(target_row["model_file"]["bytes"])


def load_model_ensemble(filename):
    # Ensemble payloads share the single-model Parquet layout, so loading is identical
    return load_model(filename)


def ensemble_predict(models, X, cat_features):
    # cat_features is accepted for interface symmetry but unused here:
    # positive-class probabilities are simply averaged across the fold models
    preds = sum(model.predict_proba(X)[:, 1] for model in models) / len(models)
    return preds
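
if __name__ == "__main__":
    # Smoke-test sketch (not part of the original module). It sidesteps the
    # YAML params, Streamlit session state, and HF credentials by building
    # the ensemble members directly; the data here is synthetic.
    import numpy as np

    rng = np.random.default_rng(0)
    X = rng.normal(size=(60, 4))
    y = rng.integers(0, 2, size=60)

    models = [
        RandomForestClassifier(n_estimators=50, random_state=seed).fit(X, y)
        for seed in range(3)
    ]
    proba = ensemble_predict(models, X, cat_features=None)
    print(f"averaged P(class=1) across {len(models)} models: mean={proba.mean():.3f}")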