# Hugging Face Space: Prior Authorization MLOps pipeline demo
| # app.py | |
| import os | |
| import json | |
| import uuid | |
| import shutil | |
| from datetime import datetime | |
| from pathlib import Path | |
| import numpy as np | |
| import pandas as pd | |
| import joblib | |
| from sklearn.ensemble import RandomForestClassifier | |
| from sklearn.model_selection import train_test_split | |
| from sklearn.metrics import classification_report, roc_auc_score | |
| from sklearn.pipeline import Pipeline | |
| from sklearn.preprocessing import StandardScaler | |
| import gradio as gr | |
| from evidently.report import Report | |
| from evidently.metrics import DataDriftTable | |
# -----------------------
# Config / paths
# -----------------------
# All artifacts (model files, registry JSON, reference data) live under ./models.
BASE_DIR = Path("models")
BASE_DIR.mkdir(exist_ok=True)
# JSON list of registered model entries (id, path, created_at, metrics).
REGISTRY_FILE = BASE_DIR / "registry.json"
# First training batch, saved once and used as the drift-monitoring baseline.
REFERENCE_DATA = BASE_DIR / "reference_data.csv"
ACTIVE_SYMLINK = BASE_DIR / "active_model.joblib"  # our "active" model file (a plain copy, not a real symlink)
| # ----------------------- | |
| # Utilities | |
| # ----------------------- | |
def load_registry():
    """Return the model registry as a list of entries ([] when no registry file exists)."""
    if not REGISTRY_FILE.exists():
        return []
    return json.loads(REGISTRY_FILE.read_text())
def save_registry(reg):
    """Persist the registry list to REGISTRY_FILE as pretty-printed JSON."""
    serialized = json.dumps(reg, indent=2)
    REGISTRY_FILE.write_text(serialized)
def register_model(model_path, metadata):
    """Append a new entry (fresh id, path, UTC timestamp, plus `metadata`) to the registry.

    Returns the entry dict that was stored.
    """
    entry = {
        "id": str(uuid.uuid4()),
        "path": str(model_path),
        "created_at": datetime.utcnow().isoformat() + "Z",
    }
    # Metadata is merged last, matching the original `**metadata` precedence.
    entry.update(metadata)
    registry = load_registry()
    registry.append(entry)
    save_registry(registry)
    return entry
def list_models():
    """Return every registered model entry (thin alias for load_registry)."""
    return load_registry()
def set_active_model(model_path):
    """Mark a model as active for inference; returns the active artifact path.

    Despite the ACTIVE_SYMLINK name, this performs a plain file copy.
    """
    shutil.copy(model_path, ACTIVE_SYMLINK)
    return str(ACTIVE_SYMLINK)
def load_active_model():
    """Load and return the active model pipeline, or None when no model is active."""
    if ACTIVE_SYMLINK.exists():
        return joblib.load(ACTIVE_SYMLINK)
    return None
| # ----------------------- | |
| # Synthetic Prior Auth Data generator | |
| # ----------------------- | |
def generate_synthetic_data(n=2000, seed=42):
    """Create a synthetic prior-authorization claims DataFrame with an `approved` label.

    Columns: age, prior_auth_count, chronic_conditions_count, severity_score,
    cost_estimate, approved (1 = approved, 0 = denied). Deterministic for a
    given (n, seed) pair.
    """
    np.random.seed(seed)
    # Draw each feature; the order of RNG calls is fixed for reproducibility.
    cols = {}
    cols["age"] = np.random.randint(18, 90, size=n)
    cols["prior_auth_count"] = np.random.poisson(0.5, size=n)
    cols["chronic_conditions_count"] = np.random.poisson(1.0, size=n)
    cols["severity_score"] = np.clip(np.random.normal(loc=2.0, scale=1.0, size=n), 0, 5)
    cols["cost_estimate"] = np.round(np.random.exponential(scale=1200, size=n), 2)
    # Synthetic approval rule: more severe cases, fewer prior auths, and lower
    # cost push the logit up; Gaussian noise keeps the boundary fuzzy.
    logit = (
        -0.02 * cols["age"]
        - 0.5 * cols["prior_auth_count"]
        - 0.7 * cols["chronic_conditions_count"]
        + 1.5 * (5 - cols["severity_score"])
        - 0.001 * cols["cost_estimate"]
        + np.random.normal(0, 0.5, size=n)
    )
    approval_prob = 1 / (1 + np.exp(-logit))
    cols["approved"] = (approval_prob > 0.5).astype(int)
    return pd.DataFrame(cols)
| # ----------------------- | |
| # Preprocess / Train / Evaluate | |
| # ----------------------- | |
def train_and_register(df, test_size=0.2, random_state=42):
    """Train a scaler + RandomForest pipeline on `df`, persist it, and register it.

    Parameters
    ----------
    df : DataFrame with the feature columns plus an `approved` label column.
    test_size : fraction of rows held out for evaluation.
    random_state : seed for both the split and the forest.

    Returns
    -------
    (registry_entry, auc, report) where `report` is the dict produced by
    sklearn's classification_report.
    """
    X = df.drop(columns=["approved"])
    y = df["approved"]
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=random_state, stratify=y
    )
    pipeline = Pipeline([
        ("scaler", StandardScaler()),
        ("clf", RandomForestClassifier(n_estimators=100, random_state=random_state)),
    ])
    pipeline.fit(X_train, y_train)
    # Evaluate on the held-out split.
    y_proba = pipeline.predict_proba(X_test)[:, 1]
    y_pred = pipeline.predict(X_test)
    auc = roc_auc_score(y_test, y_proba)
    report_dict = classification_report(y_test, y_pred, output_dict=True)
    # Persist the model artifact, versioned by a UTC timestamp.
    # NOTE: datetime.utcnow() is deprecated in 3.12; kept so version strings
    # stay consistent with register_model's created_at field.
    version = datetime.utcnow().strftime("%Y%m%dT%H%M%SZ")
    model_path = BASE_DIR / f"pa_model_{version}.joblib"
    joblib.dump(pipeline, model_path)
    # Save the first training batch as the drift-monitoring reference set.
    # Fix: build a separate frame instead of mutating X_train in place — the
    # original reset X_train's index with inplace=True and appended the label
    # column directly onto the training features.
    if not REFERENCE_DATA.exists():
        reference = X_train.reset_index(drop=True)
        reference["approved"] = y_train.reset_index(drop=True)
        reference.to_csv(REFERENCE_DATA, index=False)
    metadata = {
        "auc": float(auc),
        "report": report_dict,
        "version": version,
    }
    entry = register_model(str(model_path), metadata)
    return entry, float(auc), report_dict
def evaluate_model_on_df(model, df):
    """Score `df` with `model`.

    Returns the feature columns plus `pred_proba` and `pred`; when `df`
    contains an `approved` column it is echoed back as `true`.
    """
    features = df.drop(columns=["approved"], errors="ignore")
    labels = df["approved"] if "approved" in df.columns else None
    scores = model.predict_proba(features)[:, 1]
    result = features.copy()
    result["pred_proba"] = scores
    result["pred"] = (scores > 0.5).astype(int)
    if labels is not None:
        result["true"] = labels.values
    return result
| # ----------------------- | |
| # Monitoring (Evidently) | |
| # ----------------------- | |
def run_evidently_report(reference_df, current_df, out_html="drift_report.html"):
    """Build an Evidently data-drift report (current vs. reference) and save it as HTML.

    Returns the path of the written HTML file.
    """
    drift_report = Report(metrics=[DataDriftTable()])
    drift_report.run(reference_data=reference_df, current_data=current_df)
    drift_report.save_html(out_html)
    return out_html
# very simple heuristic for retrain decision (mean-shift based)
def detect_drift_heuristic(reference_df, current_df, threshold=0.5):
    """Flag drift when any shared numeric column's mean shifts by more than
    `threshold` reference standard deviations.

    Columns with fewer than 20 non-null values on either side — or absent from
    `current_df` — are skipped.

    Returns
    -------
    (drift_detected: bool, reason: str)
    """
    numeric_cols = reference_df.select_dtypes(include=[np.number]).columns
    for c in numeric_cols:
        # Fix: the original indexed current_df[c] unconditionally, raising
        # KeyError when a reference column was missing from the batch.
        if c not in current_df.columns:
            continue
        ref = reference_df[c].dropna()
        curr = current_df[c].dropna()
        if len(ref) < 20 or len(curr) < 20:
            continue
        ref_std = ref.std()  # computed once (original called .std() twice)
        if not ref_std > 0:  # guards both 0 and NaN
            ref_std = 1.0
        z = abs(curr.mean() - ref.mean()) / ref_std
        if z > threshold:
            return True, f"Column {c} shifted (z={z:.2f})"
    return False, "No significant mean shifts detected"
| # ----------------------- | |
| # Gradio app: UI actions | |
| # ----------------------- | |
def action_generate_data(samples=2000):
    """UI action: write a fresh synthetic claims CSV and return its path for download."""
    out_path = "synthetic_claims.csv"
    generate_synthetic_data(n=samples).to_csv(out_path, index=False)
    return out_path
def action_train(samples=2000):
    """UI action: train on a fresh synthetic batch, register the model, and activate it.

    Returns (status message, registry entry, AUC).
    """
    batch = generate_synthetic_data(n=samples)
    entry, auc, report = train_and_register(batch)
    # The newly trained model immediately becomes the serving model.
    set_active_model(entry["path"])
    return f"Trained and registered model version {entry['version']} (AUC={auc:.3f})", entry, auc
def action_list_models():
    """UI action: return the registry contents for display."""
    return list_models()
def action_set_active(model_id):
    """UI action: activate the registry entry whose id equals `model_id`.

    Returns a human-readable status string.
    """
    matches = [r for r in load_registry() if r["id"] == model_id]
    if not matches:
        return "Model id not found"
    chosen = matches[0]
    set_active_model(chosen["path"])
    return f"Set active model to {chosen['path']}"
def action_predict(csv_file):
    """UI action: score an uploaded claims CSV with the active model.

    Returns the predictions CSV path on success, or an error string when no
    model is active or required columns are missing.
    """
    model = load_active_model()
    if model is None:
        return "No active model. Please train and set an active model first."
    # Gradio may hand us a file wrapper (with .name) or a plain path string.
    source = csv_file.name if hasattr(csv_file, "name") else csv_file
    df = pd.read_csv(source)
    # Reject uploads lacking any feature column the model expects.
    required = ["age","prior_auth_count","chronic_conditions_count","severity_score","cost_estimate"]
    missing = [c for c in required if c not in df.columns]
    if missing:
        return f"Uploaded CSV is missing columns: {missing}"
    out_csv = "predictions.csv"
    evaluate_model_on_df(model, df).to_csv(out_csv, index=False)
    return out_csv
def action_monitor_and_maybe_retrain(uploaded_csv=None):
    """UI action: run drift monitoring against the reference batch; retrain on drift.

    Returns (report_html_path, status_message) on the normal path.
    NOTE(review): the early return below yields a single string while the
    success path yields two values — the Gradio handler expects two outputs;
    confirm this mismatch is handled upstream.
    """
    # reference data (from first training batch)
    if not REFERENCE_DATA.exists():
        return "No reference data available. Train a model first."
    ref_df = pd.read_csv(REFERENCE_DATA)
    if uploaded_csv is None:
        # simulate production batch from generator
        curr_df = generate_synthetic_data(n=500)
    else:
        curr_df = pd.read_csv(uploaded_csv.name if hasattr(uploaded_csv, "name") else uploaded_csv)
    # Ensure same columns: drop any label column carried by the production batch.
    for c in ["approved"]:
        if c in curr_df.columns:
            curr_df.drop(columns=[c], inplace=True)
    # add placeholder approved when running Evidently (it expects both sides similar)
    # use reference approved distribution as filler (not used by our mean-shift heuristic)
    curr_for_evidently = curr_df.copy()
    curr_for_evidently["approved"] = np.random.choice(ref_df["approved"].values, size=len(curr_for_evidently))
    # run evidently report
    report_path = "monitoring_report.html"
    run_evidently_report(ref_df, curr_for_evidently, out_html=report_path)
    # heuristic drift detection (labels excluded on both sides)
    drift_detected, reason = detect_drift_heuristic(ref_df.drop(columns=["approved"], errors="ignore"),
                                                    curr_df.drop(columns=["approved"], errors="ignore"))
    retrain_message = "No retraining triggered."
    if drift_detected:
        # retrain quickly on combined data (ref + curr)
        # NOTE(review): curr_for_evidently carries randomly sampled filler
        # labels, so this retrain partly fits fabricated targets — confirm
        # that is intended for the demo.
        combined = pd.concat([ref_df, curr_for_evidently], ignore_index=True)
        # keep recent training small for demo
        entry, auc, rep = train_and_register(combined, test_size=0.2)
        set_active_model(entry["path"])
        retrain_message = f"Drift detected ({reason}). Retrained and registered new model {entry['version']} (AUC={auc:.3f})."
    return report_path, retrain_message
# -----------------------
# Build Gradio UI
# -----------------------
# Top-level script: builds the tabbed demo UI and launches the server.
with gr.Blocks(title="Prior Authorization MLOps Pipeline Demo") as demo:
    gr.Markdown("# Prior Authorization — MLOps Pipeline (Demo)")
    gr.Markdown("This demo shows a lightweight MLOps pipeline: data generation, training, model registry, inference, monitoring, and automatic retrain trigger.")
    # Tab 1: synthetic dataset generation / download.
    with gr.Tab("Data"):
        gr.Markdown("Generate synthetic prior authorization claim dataset for training or upload your CSV.")
        gen_btn = gr.Button("Generate Synthetic Data (CSV)")
        gen_file = gr.File()
        # NOTE(review): the gr.Number here is instantiated inside the event
        # call rather than placed in the layout — confirm it renders as intended.
        gen_btn.click(action_generate_data, inputs=[gr.Number(value=2000, label="Samples")], outputs=gen_file)
        gr.Markdown("CSV format must include columns: age, prior_auth_count, chronic_conditions_count, severity_score, cost_estimate")
    # Tab 2: training plus registry listing.
    with gr.Tab("Train / Registry"):
        tr_btn = gr.Button("Train & Register Model (on synthetic data)")
        tr_out = gr.Textbox()
        reg_table = gr.Dataframe(headers=["id", "path", "created_at", "version", "auc"], interactive=False)
        # NOTE(review): gr.JSON()/gr.Number() are created inline in outputs,
        # outside the layout — confirm Gradio renders them.
        tr_btn.click(action_train, outputs=[tr_out, gr.JSON(), gr.Number()])
        list_btn = gr.Button("List Registered Models")
        list_btn.click(fn=action_list_models, outputs=reg_table)
    # Tab 3: promote a registered model to "active".
    with gr.Tab("Set Active Model"):
        gr.Markdown("Choose a model ID from the registry to mark as active for inference")
        active_id = gr.Textbox(label="Model ID to activate")
        set_btn = gr.Button("Set Active")
        set_out = gr.Textbox()
        set_btn.click(action_set_active, inputs=[active_id], outputs=[set_out])
    # Tab 4: batch scoring with the active model.
    with gr.Tab("Inference"):
        gr.Markdown("Upload a CSV (claims) to score with the active model")
        upload_infer = gr.File(label="Upload claims CSV")
        infer_btn = gr.Button("Run Inference")
        infer_out = gr.File()
        infer_btn.click(action_predict, inputs=[upload_infer], outputs=[infer_out])
    # Tab 5: drift monitoring with automatic retrain.
    with gr.Tab("Monitoring"):
        gr.Markdown("Run monitoring on a sample production batch or upload a production CSV. If drift is detected the demo will retrain and register a new model.")
        upload_prod = gr.File(label="(Optional) Upload production CSV")
        mon_btn = gr.Button("Run Monitoring (+ auto retrain if drift)")
        mon_report = gr.File()
        mon_msg = gr.Textbox()
        mon_btn.click(action_monitor_and_maybe_retrain, inputs=[upload_prod], outputs=[mon_report, mon_msg])
    # Tab 6: show which artifact is currently serving.
    with gr.Tab("Active Model Info"):
        info_btn = gr.Button("Show Active Model Path")
        info_out = gr.Textbox()
        def show_active():
            # Report the active artifact path, or a friendly message when unset.
            if ACTIVE_SYMLINK.exists():
                return str(ACTIVE_SYMLINK)
            return "No active model set."
        info_btn.click(show_active, outputs=info_out)
demo.launch(share=True)