# src/model_utils.py
import yaml
import streamlit as st
from pathlib import Path
from catboost import CatBoostClassifier
# from xgboost import XGBClassifier
# from lightgbm import LGBMClassifier
from sklearn.ensemble import RandomForestClassifier

MODEL_DIR = Path("src/params")
# MODEL_DIR.mkdir(exist_ok=True)


def load_model_params(model_type, target="GVHD", mode="ensemble", path=MODEL_DIR / "model_params.yaml"):
    """Load hyperparameters for the given model type, target, and training mode from a YAML file."""
    if target not in ["GVHD", "Acute GVHD(<100 days)", "Chronic GVHD>100 days"]:
        raise ValueError("target must be one of 'GVHD', 'Acute GVHD(<100 days)', or 'Chronic GVHD>100 days'")
    if mode not in ["ensemble", "single_model"]:
        raise ValueError("mode must be either 'ensemble' or 'single_model'")
    if model_type not in ["CatBoost", "XGBoost", "LightGBM", "RandomForest"]:
        raise ValueError("model_type must be one of 'CatBoost', 'XGBoost', 'LightGBM', or 'RandomForest'")
    with open(path, "r") as f:
        all_params = yaml.safe_load(f)
    params = all_params[model_type][mode]
    # Keep the seed in session state so later steps reuse the same randomness
    if "random_seed" in params:
        st.session_state.random_seed = params["random_seed"]
    return params
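
# Example (sketch): loading CatBoost ensemble parameters for the GVHD target.
# Assumes src/params/model_params_gvhd.yaml contains a CatBoost -> ensemble section;
# adjust the file and keys to your actual config.
#
#   params = load_model_params("CatBoost", target="GVHD", mode="ensemble",
#                               path=MODEL_DIR / "model_params_gvhd.yaml")
#   clf = CatBoostClassifier(**params)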


def get_model(model_type, mode="ensemble", target="GVHD", best_iter=None):
    """Instantiate an unfitted classifier using the parameters stored for the given target."""
    if target == "GVHD":
        path = MODEL_DIR / "model_params_gvhd.yaml"
    elif target == "Acute GVHD(<100 days)":
        path = MODEL_DIR / "model_params_acute.yaml"
    elif target == "Chronic GVHD>100 days":
        path = MODEL_DIR / "model_params_chronic.yaml"
    else:
        raise ValueError(f"Unsupported target: {target}")
    params = load_model_params(model_type, target, mode, path)
    # best_iter is passed in single_model mode so the final model uses the tuned iteration count
    if best_iter is not None:
        params['iterations'] = best_iter
    # if "random_seed" in st.session_state:
    #     random_seed = st.session_state.random_seed
    if model_type == "CatBoost":
        return CatBoostClassifier(**params)
    # elif model_type == "XGBoost":
    #     return XGBClassifier(**params, use_label_encoder=False, eval_metric="logloss")
    # elif model_type == "LightGBM":
    #     return LGBMClassifier(**params)
    elif model_type == "RandomForest":
        return RandomForestClassifier(**params)
    else:
        raise ValueError(f"Unsupported model type: {model_type}")


def save_model(model, user_model_name, metrics_result_single=None):
    """Pickle a single trained model, wrap it in a Parquet file, and upload it to the HF dataset repo."""
    from datetime import datetime
    import io
    import pickle
    import json
    import pyarrow as pa
    import pyarrow.parquet as pq
    from huggingface_hub import login, CommitScheduler
    import os

    if "HF_TOKEN" in os.environ:
        login(token=os.environ["HF_TOKEN"])
    timestamp = datetime.now().strftime("%y%m%d_%H%M%S")
    filename = f"{timestamp}{st.session_state.get('target_col', 'UNKNOWN')[0]}_{user_model_name}_single"
    # Prepare model dict (same as before)
    model_data = {
        "timestamp": timestamp,
        "model_name": user_model_name,
        "target_col": st.session_state.get("target_col", "UNKNOWN"),
        "model": model,
        "best_iteration": st.session_state.get("best_iteration"),
        "metrics_result_single": metrics_result_single,
    }
    # Serialize (pickle) to bytes
    model_bytes = pickle.dumps(model_data)
    # Prepare Parquet row
    row = {
        "filename": filename,
        "timestamp": timestamp,
        "type": "single",
        "model_file": {"path": filename, "bytes": model_bytes},
    }
    table = pa.Table.from_pylist([row])
    table = table.replace_schema_metadata({
        "huggingface": json.dumps({"info": {
            "features": {
                "filename": {"_type": "Value", "dtype": "string"},
                "timestamp": {"_type": "Value", "dtype": "string"},
                "type": {"_type": "Value", "dtype": "string"},
                "model_file": {"_type": "Value", "dtype": "binary"},
            }
        }})
    })
    # Write to in-memory buffer
    buf = io.BytesIO()
    pq.write_table(table, buf)
    buf.seek(0)
    # Upload to HF dataset
    scheduler = CommitScheduler(
        repo_id=os.environ["HF_REPO_ID"],
        repo_type="dataset",
        path_in_repo="models",
        token=os.environ["HF_TOKEN"],
        private=True,
        folder_path=Path("/tmp/dummy")
    )
    scheduler.api.upload_file(
        repo_id=os.environ["HF_REPO_ID"],
        repo_type="dataset",
        path_in_repo=f"models/{filename}.parquet",
        path_or_fileobj=buf
    )
    return filename
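
# Example (sketch): persisting a trained single model to the Hugging Face dataset repo.
# Assumes HF_TOKEN and HF_REPO_ID are set in the environment and that
# st.session_state["target_col"] was populated earlier in the app; metrics is
# your own evaluation dict.
#
#   clf.fit(X_train, y_train)
#   saved_name = save_model(clf, "catboost_baseline", metrics_result_single=metrics)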


def save_model_ensemble(models, user_model_name, best_iterations=None, fold_scores=None, metrics_result_ensemble=None):
    """Pickle a list of fold models plus their CV metadata, wrap them in a Parquet file, and upload to the HF dataset repo."""
    from datetime import datetime
    import io
    import pickle
    import json
    import pyarrow as pa
    import pyarrow.parquet as pq
    from huggingface_hub import login, CommitScheduler
    import os

    if "HF_TOKEN" in os.environ:
        login(token=os.environ["HF_TOKEN"])
    timestamp = datetime.now().strftime("%y%m%d_%H%M%S")
    filename = f"{timestamp}{st.session_state.get('target_col', 'UNKNOWN')[0]}_{user_model_name}_ensemble"
    ensemble_data = {
        "timestamp": timestamp,
        "model_name": user_model_name,
        "target_col": st.session_state.get("target_col", "UNKNOWN"),
        "model": models,
        "best_iterations": best_iterations,
        "fold_scores": fold_scores,
        "metrics_result_ensemble": metrics_result_ensemble,
    }
    model_bytes = pickle.dumps(ensemble_data)
    row = {
        "filename": filename,
        "timestamp": timestamp,
        "type": "ensemble",
        "model_file": {"path": filename, "bytes": model_bytes},
    }
    table = pa.Table.from_pylist([row])
    table = table.replace_schema_metadata({
        "huggingface": json.dumps({"info": {
            "features": {
                "filename": {"_type": "Value", "dtype": "string"},
                "timestamp": {"_type": "Value", "dtype": "string"},
                "type": {"_type": "Value", "dtype": "string"},
                "model_file": {"_type": "Value", "dtype": "binary"},
            }
        }})
    })
    buf = io.BytesIO()
    pq.write_table(table, buf)
    buf.seek(0)
    scheduler = CommitScheduler(
        repo_id=os.environ["HF_REPO_ID"],
        repo_type="dataset",
        path_in_repo="models",
        token=os.environ["HF_TOKEN"],
        private=True,
        folder_path=Path("/tmp/dummy")
    )
    scheduler.api.upload_file(
        repo_id=os.environ["HF_REPO_ID"],
        repo_type="dataset",
        path_in_repo=f"models/{filename}.parquet",
        path_or_fileobj=buf
    )
    return filename
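
# Example (sketch): saving a k-fold ensemble together with its CV metadata.
# models, iters, scores, and metrics come from your own cross-validation loop.
#
#   saved_name = save_model_ensemble(models, "catboost_cv5",
#                                    best_iterations=iters,
#                                    fold_scores=scores,
#                                    metrics_result_ensemble=metrics)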


def load_model(model_name):
    """Download model Parquet files from the HF dataset repo and unpickle the one whose row matches model_name."""
    from huggingface_hub import login, hf_hub_download, HfApi
    import pyarrow.parquet as pq
    import pickle
    import os

    if "HF_TOKEN" in os.environ:
        login(token=os.environ["HF_TOKEN"])
    api = HfApi(token=os.environ["HF_TOKEN"])
    all_files = api.list_repo_files(repo_id=os.environ["HF_REPO_ID"], repo_type="dataset")
    model_files = [f for f in all_files if f.startswith("models/") and f.endswith(".parquet")]
    # Download each candidate file until a row whose "filename" matches model_name is found
    target_file = None
    row = None
    for f in model_files:
        downloaded = hf_hub_download(
            repo_id=os.environ["HF_REPO_ID"],
            repo_type="dataset",
            filename=f,
            token=os.environ["HF_TOKEN"]
        )
        table = pq.read_table(downloaded)
        row = table.to_pylist()[0]
        if row["filename"] == model_name:
            target_file = downloaded
            break
    if not target_file:
        raise FileNotFoundError(f"Model {model_name} not found in repo.")
    model_bytes = row["model_file"]["bytes"]
    return pickle.loads(model_bytes)


def load_model_ensemble(filename):
    """Ensembles are stored the same way as single models, so loading is identical."""
    return load_model(filename)


def ensemble_predict(models, X, cat_features):
    """Average the positive-class probabilities across fold models (cat_features is currently unused)."""
    preds = sum([model.predict_proba(X)[:, 1] for model in models]) / len(models)
    return preds
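
# Example (sketch): loading a saved ensemble and scoring new patients.
# The filename below is illustrative; use whatever save_model_ensemble returned.
# X_new must contain the same feature columns the fold models were trained on.
#
#   bundle = load_model_ensemble("250101_120000G_catboost_cv5_ensemble")
#   proba = ensemble_predict(bundle["model"], X_new, cat_features=None)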