""" SWAN Menopause Stage Prediction & Forecasting — Gradio UI Hugging Face Spaces deployment-ready. Run locally: python app.py Deploy: Push to a HF Space with SDK=gradio Output structure (per execution): swan_ml_output/ / charts/ ← PNG visualizations predictions/ ← CSV result files reports/ ← TXT summary reports """ import os import json import warnings from datetime import datetime from pathlib import Path from typing import Optional import numpy as np import pandas as pd import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt warnings.filterwarnings("ignore") # ── Gradio ──────────────────────────────────────────────────────────────────── import gradio as gr # ── Local ML module ─────────────────────────────────────────────────────────── try: from menopause import ( MenopauseForecast, SymptomCycleForecaster, load_forecast_model, ) _MODULE_AVAILABLE = True except ImportError: _MODULE_AVAILABLE = False # ── Model loading ───────────────────────────────────────────────────────────── FORECAST_DIR = os.environ.get("FORECAST_DIR", "swan_ml_output") OUTPUT_BASE = Path(FORECAST_DIR) _forecast: Optional[MenopauseForecast] = None # type: ignore[type-arg] _metadata: dict = {} def _load_models(): """Attempt to load saved joblib pipelines. Returns (success, message).""" global _forecast, _metadata if not _MODULE_AVAILABLE: return False, "menopause.py not found. Make sure it is in the same directory." meta_path = Path(FORECAST_DIR) / "forecast_metadata.json" rf_path = Path(FORECAST_DIR) / "rf_pipeline.pkl" lr_path = Path(FORECAST_DIR) / "lr_pipeline.pkl" if not all(p.exists() for p in (meta_path, rf_path, lr_path)): return ( False, f"Model artifacts not found in '{FORECAST_DIR}'. " "Run `python menopause.py` to train and save the models first.", ) try: _forecast = load_forecast_model(FORECAST_DIR) with open(meta_path) as fh: _metadata = json.load(fh) return True, f"✅ Models loaded — {len(_metadata.get('feature_names', []))} features" except Exception as exc: return False, f"Error loading models: {exc}" _MODEL_OK, _MODEL_MSG = _load_models() # ── Output directory management ─────────────────────────────────────────────── def _make_run_dir() -> Path: """Create and return a unique timestamped run directory under swan_ml_output/.""" ts = datetime.now().strftime("%Y%m%d_%H%M%S") run_dir = OUTPUT_BASE / ts (run_dir / "charts").mkdir(parents=True, exist_ok=True) (run_dir / "predictions").mkdir(parents=True, exist_ok=True) (run_dir / "reports").mkdir(parents=True, exist_ok=True) return run_dir def _get_file_path(file_obj) -> Optional[str]: """ Safely extract a file-system path from a Gradio file component value. Gradio ≤ 3.x → returns a file-like object with a .name attribute. Gradio 4.x → returns a str path (or NamedString subclass). This helper handles both. """ if file_obj is None: return None if hasattr(file_obj, "name"): return file_obj.name return str(file_obj) # ── Constants & helpers ─────────────────────────────────────────────────────── STAGE_COLORS = {"pre": "#16a34a", "peri": "#d97706", "post": "#7c3aed"} STAGE_EMOJI = {"pre": "🟢", "peri": "🟡", "post": "🟣"} STAGE_LABELS = { "pre": "Pre-Menopausal", "peri": "Peri-Menopausal", "post": "Post-Menopausal", } STAGE_INFO = { "pre": { "title": "Pre-Menopausal", "description": "Regular menstrual cycles with typical hormonal fluctuations. Ovarian function is normal.", "symptoms": ["Regular periods", "Normal hormone levels", "Potential mild PMS"], "guidance": "Maintain regular check-ups. Track your cycle and note any changes.", }, "peri": { "title": "Peri-Menopausal (Transition)", "description": "Hormonal changes begin — estrogen and progesterone levels fluctuate. Cycles become irregular.", "symptoms": ["Irregular periods", "Hot flashes", "Sleep disturbances", "Mood changes", "Night sweats"], "guidance": "Consult your healthcare provider. Lifestyle adjustments (diet, exercise, sleep) can help.", }, "post": { "title": "Post-Menopausal", "description": "12+ months since last menstrual period. Estrogen remains at consistently lower levels.", "symptoms": ["No periods", "Possible continued hot flashes", "Vaginal dryness", "Bone density changes"], "guidance": "Focus on bone health, cardiovascular health, and regular screenings. Discuss HRT options.", }, } # Feature descriptions keyed by the model's canonical feature names FEATURE_DESCRIPTIONS = { "PAIN17": "Pain indicator (visit-specific)", "PAINTW17": "Pain two-week indicator", "PAIN27": "Secondary pain indicator", "PAINTW27": "Secondary pain two-week indicator", "SLEEP17": "Sleep disturbance pattern 1", "SLEEP27": "Sleep disturbance pattern 2", "BCOHOTH7": "Birth control — other method", "EXERCIS7": "General exercise indicator", "EXERHAR7": "Vigorous exercise", "EXEROST7": "Osteoporosis exercise", "EXERMEN7": "Exercise — mental health", "EXERLOO7": "Exercise lookalike", "EXERMEM7": "Exercise — memory", "EXERPER7": "Exercise perception", "EXERGEN7": "General exercise type", "EXERWGH7": "Weight exercise", "EXERADV7": "Exercise advice indicator", "EXEROTH7": "Other exercise", "EXERSPE7": "Specific exercise", "ABBLEED7": "Abnormal bleeding (0=no, 1=yes)", # ← correct feature name "BLEEDNG7": "Bleeding pattern", "LMPDAY7": "Last menstrual period day", "DEPRESS7": "Depression indicator", "SEX17": "Sexual activity indicator 1", "SEX27": "Sexual activity indicator 2", "SEX37": "Sexual activity indicator 3", "SEX47": "Sexual activity indicator 4", "SEX57": "Sexual activity indicator 5", "SEX67": "Sexual activity indicator 6", "SEX77": "Sexual activity indicator 7", "SEX87": "Sexual activity indicator 8", "SEX97": "Sexual activity indicator 9", "SEX107": "Sexual activity indicator 10", "SEX117": "Sexual activity indicator 11", "SEX127": "Sexual activity indicator 12", "SMOKERE7": "Smoking status", "HOTFLAS7": "Hot flash severity (1=none, 5=very severe)", "NUMHOTF7": "Number of hot flashes per week", "BOTHOTF7": "How bothersome are hot flashes", "IRRITAB7": "Irritability level", "VAGINDR7": "Vaginal dryness", "MOODCHG7": "Mood change frequency", "SLEEPQL7": "Sleep quality score", "PHYSILL7": "Physical illness indicators", "HOTHEAD7": "Hot flashes with headache", "EXER12H7": "Exercise in last 12 hours", "ALCO24H7": "Alcohol in last 24h", "AGE7": "Age (years)", "RACE": "Race (1=White, 2=Black, 3=Chinese, 4=Japanese, 5=Hispanic)", "LANGINT7": "Interview language indicator", } def _confidence_color(conf: float) -> str: if conf >= 0.8: return "#16a34a" elif conf >= 0.6: return "#d97706" return "#dc2626" # ── Chart builders ──────────────────────────────────────────────────────────── def _make_proba_chart( probabilities: dict, predicted_stage: str, save_path: Optional[Path] = None, ) -> plt.Figure: """Horizontal bar chart for stage probabilities. Optionally saves PNG.""" fig, ax = plt.subplots(figsize=(6, 3.5)) fig.patch.set_facecolor("#1a1a2e") ax.set_facecolor("#16213e") stages = list(probabilities.keys()) probs = [probabilities[s] * 100 for s in stages] colors = [STAGE_COLORS.get(s, "#607d8b") for s in stages] edge_colors = ["white" if s == predicted_stage else "none" for s in stages] lws = [2.5 if s == predicted_stage else 0 for s in stages] bars = ax.barh(stages, probs, color=colors, edgecolor=edge_colors, linewidth=lws, height=0.5, zorder=3) for bar, prob in zip(bars, probs): ax.text( min(prob + 1, 98), bar.get_y() + bar.get_height() / 2, f"{prob:.1f}%", va="center", ha="left", color="white", fontsize=11, fontweight="bold", ) labels = [STAGE_LABELS.get(s, s) for s in stages] ax.set_yticks(range(len(stages))) ax.set_yticklabels(labels, color="white", fontsize=10) ax.set_xlim(0, 105) ax.tick_params(colors="white", labelsize=11) ax.spines[["top", "right", "left", "bottom"]].set_visible(False) ax.xaxis.set_visible(False) for spine in ax.spines.values(): spine.set_color("#333") ax.set_title("Stage Probabilities", color="white", fontsize=12, pad=10, fontweight="bold") ax.grid(axis="x", color="#333", linestyle="--", linewidth=0.5, zorder=0) fig.tight_layout() if save_path: fig.savefig(save_path, dpi=150, bbox_inches="tight", facecolor=fig.get_facecolor()) return fig def _make_cycle_chart( cycle_day: int, cycle_length: int = 28, hot_prob: float = None, mood_prob: float = None, save_path: Optional[Path] = None, ) -> plt.Figure: """Circular cycle-day visualization. Optionally saves PNG.""" fig, ax = plt.subplots(figsize=(5, 5), subplot_kw=dict(polar=True)) fig.patch.set_facecolor("#1a1a2e") ax.set_facecolor("#16213e") days = np.linspace(0, 2 * np.pi, cycle_length, endpoint=False) for i, d in enumerate(days): phase = i / cycle_length color = plt.cm.RdYlGn(1 - phase) ax.bar(d, 1, width=2 * np.pi / cycle_length * 0.9, bottom=0.5, color=color, alpha=0.4, zorder=1) if cycle_day is not None: angle = (cycle_day - 1) / cycle_length * 2 * np.pi ax.scatter([angle], [1.05], s=200, color="#ff6b6b", zorder=5, linewidths=2) ax.annotate( f"Day {cycle_day}", xy=(angle, 1.05), xytext=(0, 0), textcoords="offset points", ha="center", va="center", color="white", fontsize=12, fontweight="bold", ) ax.set_rticks([]) ax.set_xticks([i * 2 * np.pi / 4 for i in range(4)]) ax.set_xticklabels(["Day 1", "Day 7", "Day 14", "Day 21"], color="#aaa", fontsize=9) ax.set_yticklabels([]) ax.spines["polar"].set_color("#333") ax.grid(color="#333", linewidth=0.5) title = "Cycle Position" if hot_prob is not None: title += f"\n🔥 {hot_prob:.0%} 😤 {mood_prob:.0%}" ax.set_title(title, color="white", fontsize=11, pad=20, fontweight="bold") fig.tight_layout() if save_path: fig.savefig(save_path, dpi=150, bbox_inches="tight", facecolor=fig.get_facecolor()) return fig def _make_batch_summary_chart(results_df: pd.DataFrame, save_path: Optional[Path] = None) -> None: """Stage distribution + confidence histogram for batch runs. Saves PNG.""" fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4)) fig.patch.set_facecolor("#1a1a2e") # Stage distribution pie stage_counts = results_df["predicted_stage"].value_counts() colors = [STAGE_COLORS.get(s, "#607d8b") for s in stage_counts.index] ax1.set_facecolor("#16213e") wedges, texts, autotexts = ax1.pie( stage_counts.values, labels=stage_counts.index, colors=colors, autopct="%1.0f%%", textprops={"color": "white", "fontsize": 10}, ) for at in autotexts: at.set_color("white") ax1.set_title("Stage Distribution", color="white", fontsize=11, fontweight="bold") # Confidence histogram ax2.set_facecolor("#16213e") if "confidence" in results_df.columns: conf = results_df["confidence"].dropna() ax2.hist(conf, bins=min(10, len(conf)), color="#3B82F6", edgecolor="#1a1a2e", alpha=0.8) ax2.axvline(0.8, color="#4CAF50", linestyle="--", linewidth=1.5, label="High (0.80)") ax2.axvline(0.6, color="#FF9800", linestyle="--", linewidth=1.5, label="Med (0.60)") ax2.legend(fontsize=8, labelcolor="white", facecolor="#0d0d1a") ax2.set_xlabel("Confidence", color="#aaa", fontsize=9) ax2.set_ylabel("Count", color="#aaa", fontsize=9) ax2.tick_params(colors="white", labelsize=9) for sp in ["top", "right"]: ax2.spines[sp].set_visible(False) for sp in ["left", "bottom"]: ax2.spines[sp].set_color("#333") ax2.set_title("Confidence Distribution", color="white", fontsize=11, fontweight="bold") fig.tight_layout() if save_path: fig.savefig(save_path, dpi=150, bbox_inches="tight", facecolor=fig.get_facecolor()) plt.close(fig) # ── Text report writers ─────────────────────────────────────────────────────── def _write_single_stage_report( path: Path, stage: str, confidence: float, probabilities: dict, model: str, comparison: dict, input_features: dict, ): lines = [ "=" * 60, "SWAN MENOPAUSE STAGE PREDICTION REPORT", f"Generated : {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}", "=" * 60, "", f"Predicted Stage : {STAGE_LABELS.get(stage, stage)}", f"Model : {model}", f"Confidence : {confidence:.1%}", "", "Stage Probabilities:", ] for s, p in probabilities.items(): bar = "█" * int(p * 20) lines.append(f" {s:<6} : {p:.4f} {bar}") lines += [ "", "Model Comparison:", f" RandomForest → {comparison['RandomForest']['stage']}" f" ({comparison['RandomForest'].get('confidence', 0):.1%})", f" LogisticRegression → {comparison['LogisticRegression']['stage']}" f" ({comparison['LogisticRegression'].get('confidence', 0):.1%})", "", "Input Features (non-NaN):", ] for k, v in input_features.items(): if v is not None and not (isinstance(v, float) and np.isnan(v)): lines.append(f" {k:<12} = {v}") lines += [ "", "⚠️ For research/educational use only. Not a clinical diagnosis.", "=" * 60, ] path.write_text("\n".join(lines), encoding="utf-8") def _write_batch_report( path: Path, results: pd.DataFrame, model: str, run_dir: Path, ): total = len(results) dist = results["predicted_stage"].value_counts().to_dict() \ if "predicted_stage" in results.columns else {} if "confidence" in results.columns: conf = results["confidence"] mean_c = conf.mean(); min_c = conf.min(); max_c = conf.max() high = int((conf > 0.8).sum()) medium = int(((conf > 0.6) & (conf <= 0.8)).sum()) low = int((conf <= 0.6).sum()) else: mean_c = min_c = max_c = high = medium = low = 0 lines = [ "=" * 60, "SWAN BATCH STAGE PREDICTION REPORT", f"Generated : {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}", f"Model : {model}", "=" * 60, "", f"Total Individuals : {total}", "", "Stage Distribution:", ] for stage in ["pre", "peri", "post"]: count = dist.get(stage, 0) pct = count / total * 100 if total else 0 lines.append(f" {stage:<6} : {count} ({pct:.1f}%)") lines += [ "", "Confidence Scores:", f" Mean : {mean_c:.4f}", f" Min : {min_c:.4f}", f" Max : {max_c:.4f}", "", "Confidence Distribution:", f" High (>0.80) : {high}/{total} ({high/total*100:.1f}%)" if total else " N/A", f" Medium (0.60-0.80) : {medium}/{total} ({medium/total*100:.1f}%)" if total else " N/A", f" Low (≤0.60) : {low}/{total} ({low/total*100:.1f}%)" if total else " N/A", "", f"Output Directory : {run_dir}", "", "⚠️ For research/educational use only. Not a clinical diagnosis.", "=" * 60, ] path.write_text("\n".join(lines), encoding="utf-8") def _write_symptom_report( path: Path, individual_id: str, lmp: str, target_date: str, cycle_day: int, cycle_length: int, hot_prob: float, hot_pred: bool, mood_prob: float, mood_pred: bool, ): hp = float(hot_prob) if (hot_prob is not None and not np.isnan(hot_prob)) else 0.0 mp = float(mood_prob) if (mood_prob is not None and not np.isnan(mood_prob)) else 0.0 lines = [ "=" * 60, "SWAN SYMPTOM CYCLE FORECAST REPORT", f"Generated : {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}", "=" * 60, "", f"Individual : {individual_id or 'N/A'}", f"LMP : {lmp}", f"Target Date : {target_date or 'Today'}", f"Cycle Length : {cycle_length} days", f"Cycle Day : {cycle_day}", "", "Symptom Probabilities:", f" Hot Flash : {hp:.4f} {'[ELEVATED RISK]' if hot_pred else '[LOW RISK]'}", f" Mood Change : {mp:.4f} {'[ELEVATED RISK]' if mood_pred else '[LOW RISK]'}", "", "⚠️ For research/educational use only. Not a clinical diagnosis.", "=" * 60, ] path.write_text("\n".join(lines), encoding="utf-8") def _write_batch_symptom_report( path: Path, results: pd.DataFrame, cycle_length: int, run_dir: Path, ): total = len(results) hot_flags = int(results["hotflash_pred"].sum()) \ if "hotflash_pred" in results.columns else 0 mood_flags = int(results["mood_pred"].sum()) \ if "mood_pred" in results.columns else 0 mean_hot = float(results["hotflash_prob"].mean()) \ if "hotflash_prob" in results.columns else 0.0 mean_mood = float(results["mood_prob"].mean()) \ if "mood_prob" in results.columns else 0.0 lines = [ "=" * 60, "SWAN BATCH SYMPTOM FORECAST REPORT", f"Generated : {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}", f"Cycle Length : {cycle_length} days", "=" * 60, "", f"Total Individuals : {total}", f"Hot Flash Risk : {hot_flags}/{total} elevated", f"Mood Change Risk : {mood_flags}/{total} elevated", f"Avg Hot Flash Prob : {mean_hot:.4f}", f"Avg Mood Prob : {mean_mood:.4f}", "", f"Output Directory : {run_dir}", "", "⚠️ For research/educational use only. Not a clinical diagnosis.", "=" * 60, ] path.write_text("\n".join(lines), encoding="utf-8") # ── Core prediction functions ───────────────────────────────────────────────── def predict_single_stage( age, race, langint, hot_flash, num_hot_flash, bothersome_hf, sleep_quality, depression_indicator, mood_change, irritability, pain_indicator, abbleed, vaginal_dryness, lmp_day, model_choice, ): """ Single-person stage prediction. Returns (stage_html, chart_fig, conf_note, compare_html, csv_download_path). """ if not _MODEL_OK: return f"⚠️ {_MODEL_MSG}", None, "Models unavailable.", "", None # Build feature dict using the model's canonical feature names def _v(x): return float(x) if x is not None else np.nan feature_dict = { "AGE7": _v(age), "RACE": _v(race), "LANGINT7": _v(langint), "HOTFLAS7": _v(hot_flash), "NUMHOTF7": _v(num_hot_flash), "BOTHOTF7": _v(bothersome_hf), "SLEEPQL7": _v(sleep_quality), "DEPRESS7": _v(depression_indicator), "MOODCHG7": _v(mood_change), "IRRITAB7": _v(irritability), "PAIN17": _v(pain_indicator), "ABBLEED7": _v(abbleed), # ← correct feature name (was ABLEED7) "VAGINDR7": _v(vaginal_dryness), "LMPDAY7": _v(lmp_day) if lmp_day else np.nan, } try: result = _forecast.predict_single(feature_dict, model=model_choice, return_proba=True) stage = result["stage"] confidence = result.get("confidence") or 0.0 proba = result.get("probabilities") or {} # ── Create timestamped run directory ────────────────────────────────── run_dir = _make_run_dir() # ── Save probability chart (PNG) ────────────────────────────────────── chart_path = run_dir / "charts" / "stage_probabilities.png" chart_fig = _make_proba_chart(proba, stage, save_path=chart_path) if proba else None # ── Save prediction CSV ─────────────────────────────────────────────── pred_row = { "predicted_stage": stage, "model": model_choice, "confidence": round(confidence, 4), **{f"prob_{k}": round(v, 4) for k, v in proba.items()}, "timestamp": datetime.now().isoformat(), } csv_path = run_dir / "predictions" / "stage_prediction.csv" pd.DataFrame([pred_row]).to_csv(csv_path, index=False) # ── Model comparison ────────────────────────────────────────────────── comparison = _forecast.compare_models(feature_dict) rf_stage = comparison["RandomForest"]["stage"] lr_stage = comparison["LogisticRegression"]["stage"] agree = rf_stage == lr_stage # ── Save text report ────────────────────────────────────────────────── txt_path = run_dir / "reports" / "prediction_summary.txt" _write_single_stage_report( txt_path, stage, confidence, proba, model_choice, comparison, feature_dict, ) # ── Build result card HTML ──────────────────────────────────────────── info = STAGE_INFO.get(stage, {}) emoji = STAGE_EMOJI.get(stage, "⚪") color = STAGE_COLORS.get(stage, "#607d8b") conf_color = _confidence_color(confidence) symptom_tags = "".join( f'{s}' for s in info.get("symptoms", []) ) stage_html = f"""
{emoji}
Predicted Stage
{STAGE_LABELS.get(stage, stage)}
Confidence
{confidence:.0%}

{info.get('description', '')}

Common Symptoms
{symptom_tags}
💡 Guidance: {info.get('guidance', '')}
Model: {model_choice} · {datetime.now().strftime('%Y-%m-%d %H:%M')}
""" # Confidence note if confidence >= 0.8: conf_note = "✅ High confidence — the model is quite certain about this stage." elif confidence >= 0.6: conf_note = ("⚠️ Moderate confidence — consider providing more feature values " "or consulting a clinician.") else: conf_note = ("🔴 Low confidence — prediction is uncertain; " "clinical consultation is strongly recommended.") # Model comparison panel + run-dir info compare_html = f"""
Model Comparison
Random Forest
{STAGE_EMOJI.get(rf_stage,'')} {STAGE_LABELS.get(rf_stage, rf_stage)}
{comparison['RandomForest'].get('confidence', 0):.0%} confidence
Logistic Regression
{STAGE_EMOJI.get(lr_stage,'')} {STAGE_LABELS.get(lr_stage, lr_stage)}
{comparison['LogisticRegression'].get('confidence', 0):.0%} confidence
{"✅ Both models agree — prediction is robust" if agree else "⚠️ Models disagree — interpret with caution"}
📁 Outputs saved to:
{run_dir}/
charts/stage_probabilities.png
predictions/stage_prediction.csv
reports/prediction_summary.txt
""" return stage_html, chart_fig, conf_note, compare_html, str(csv_path) except Exception as exc: return f"❌ Prediction error: {exc}", None, "", "", None def predict_batch_stage(file, model_choice): """ Batch stage prediction from uploaded CSV. Returns (csv_download_path, summary_html, preview_df). """ if not _MODEL_OK: return None, f"⚠️ {_MODEL_MSG}", None if file is None: return None, "Please upload a CSV file.", None file_path = _get_file_path(file) try: df = pd.read_csv(file_path) except Exception as exc: return None, f"Could not read CSV: {exc}", None if df.empty: return None, "Uploaded CSV is empty.", None # Identify ID column id_col_candidates = ["individual", "Individual", "ID", "id", "SWANID", "subject", "Subject"] id_col = next((c for c in id_col_candidates if c in df.columns), None) # Validate features feature_names = _metadata.get("feature_names", []) matching = [c for c in df.columns if c in feature_names] missing_pct = 1 - len(matching) / max(len(feature_names), 1) warnings_list = [] if not matching: return None, ( "❌ No matching feature columns found. " "Please include columns from the training feature set " "(see 'Feature Reference' tab)." ), None if missing_pct > 0.5: warnings_list.append( f"⚠️ {missing_pct:.0%} of training features are missing — " "prediction accuracy may be reduced." ) try: results = _forecast.predict_batch(df, model=model_choice, return_proba=True) # Insert individual ID if id_col: results.insert(0, "individual", df[id_col].values) else: results.insert(0, "individual", [f"Row_{i+1}" for i in range(len(results))]) results["model"] = model_choice results["notes"] = "" if "confidence" in results.columns: low_mask = results["confidence"] < 0.6 results.loc[low_mask, "notes"] = "Low confidence — review manually" # ── Create timestamped run directory ────────────────────────────────── run_dir = _make_run_dir() # ── Save predictions CSV ────────────────────────────────────────────── csv_path = run_dir / "predictions" / "batch_stage_predictions.csv" results.to_csv(csv_path, index=False) # ── Save confidence/distribution chart (PNG) ────────────────────────── chart_path = run_dir / "charts" / "batch_summary_chart.png" _make_batch_summary_chart(results, save_path=chart_path) # ── Save text report ────────────────────────────────────────────────── txt_path = run_dir / "reports" / "batch_summary.txt" _write_batch_report(txt_path, results, model_choice, run_dir) # ── Build summary HTML ──────────────────────────────────────────────── total = len(results) dist = results["predicted_stage"].value_counts().to_dict() mean_conf = results["confidence"].mean() \ if "confidence" in results.columns else 0.0 high_conf = int((results["confidence"] > 0.8).sum()) \ if "confidence" in results.columns else 0 dist_bars = "" for stage in ["pre", "peri", "post"]: count = dist.get(stage, 0) pct = count / total * 100 dist_bars += f"""
{STAGE_EMOJI.get(stage,'')} {STAGE_LABELS.get(stage, stage)} {count} ({pct:.0f}%)
""" warn_html = "".join( f'
{w}
' for w in warnings_list ) summary_html = f"""
📊 Batch Results — {total} individuals
{warn_html}
Total
{total}
Avg Confidence
{mean_conf:.0%}
High Conf (>80%)
{high_conf}/{total}
{dist_bars}
📁 Outputs saved to:
{run_dir}/
predictions/batch_stage_predictions.csv
charts/batch_summary_chart.png
reports/batch_summary.txt
""" return str(csv_path), summary_html, results.head(20) except Exception as exc: return None, f"❌ Batch prediction error: {exc}", None def predict_symptoms(individual_id, lmp_input, target_date_input, cycle_length): """ Cycle-based symptom forecasting (single person). Returns (result_html, chart_fig, csv_download_path). """ if not lmp_input: return "Please enter your Last Menstrual Period date.", None, None try: cycle_length = int(cycle_length) if cycle_length else 28 fore = SymptomCycleForecaster(cycle_length=cycle_length) target_date = target_date_input if target_date_input else None result = fore.predict_single(lmp=lmp_input, target_date=target_date) cycle_day = result.get("cycle_day") hot_prob = result.get("hotflash_prob", 0) hot_pred = result.get("hotflash_pred", False) mood_prob = result.get("mood_prob", 0) mood_pred = result.get("mood_pred", False) # Safe float helpers hp = float(hot_prob) if (hot_prob is not None and not np.isnan(hot_prob)) else 0.0 mp = float(mood_prob) if (mood_prob is not None and not np.isnan(mood_prob)) else 0.0 # ── Create timestamped run directory ────────────────────────────────── run_dir = _make_run_dir() # ── Save cycle chart (PNG) ──────────────────────────────────────────── chart_path = run_dir / "charts" / "cycle_position.png" chart_fig = _make_cycle_chart( cycle_day, cycle_length, hp, mp, save_path=chart_path ) # ── Save forecast CSV ───────────────────────────────────────────────── csv_path = run_dir / "predictions" / "symptom_forecast.csv" lmp_note = "" try: int(str(lmp_input).strip()) lmp_note = "LMP inferred as day-of-month; interpret with caution" except (ValueError, TypeError): pass pd.DataFrame([{ "individual": individual_id or "N/A", "LMP": lmp_input, "date": target_date_input or datetime.now().strftime("%Y-%m-%d"), "cycle_day": cycle_day, "hotflash_prob": round(hp, 6), "hotflash_pred": bool(hot_pred), "mood_prob": round(mp, 6), "mood_pred": bool(mood_pred), "notes": lmp_note, }]).to_csv(csv_path, index=False) # ── Save text report ────────────────────────────────────────────────── txt_path = run_dir / "reports" / "symptom_summary.txt" _write_symptom_report( txt_path, individual_id, lmp_input, target_date_input, cycle_day, cycle_length, hp, hot_pred, mp, mood_pred, ) # ── Build result HTML ───────────────────────────────────────────────── def _prob_bar(prob, label, color): pct = min(prob * 100, 100) return f"""
{label} {pct:.0f}%
""" hot_alert = "🔴 Elevated risk" if hot_pred else "🟢 Low risk" mood_alert = "🔴 Elevated risk" if mood_pred else "🟢 Low risk" html = f"""
{individual_id or 'Forecast'} — Cycle Day {cycle_day or '?'}
LMP: {lmp_input} | Target: {target_date_input or 'Today'} | Cycle: {cycle_length} days
{_prob_bar(hp, '🔥 Hot Flash Probability', '#ef4444')}
{hot_alert}
{_prob_bar(mp, '😤 Mood Change Probability', '#7c3aed')}
{mood_alert}
ℹ️ Probabilities are computed from a cycle-phase model (Gaussian heuristic). They represent symptom likelihood based on cycle day, not a clinical diagnosis.
📁 Outputs saved to:
{run_dir}/
charts/cycle_position.png
predictions/symptom_forecast.csv
reports/symptom_summary.txt
""" return html, chart_fig, str(csv_path) except Exception as exc: return f"❌ Error: {exc}", None, None def predict_symptoms_batch(file, lmp_col_name, date_col_name, cycle_length): """ Batch symptom forecasting from CSV. Returns (csv_download_path, summary_html, preview_df). """ if file is None: return None, "Please upload a CSV file.", None file_path = _get_file_path(file) try: df = pd.read_csv(file_path) except Exception as exc: return None, f"Could not read CSV: {exc}", None if lmp_col_name not in df.columns: return None, ( f"LMP column '{lmp_col_name}' not found in CSV. " f"Columns present: {list(df.columns)}" ), None try: cycle_length = int(cycle_length) if cycle_length else 28 fore = SymptomCycleForecaster(cycle_length=cycle_length) date_col = date_col_name \ if (date_col_name and date_col_name in df.columns) else None results = fore.predict_df(df, lmp_col=lmp_col_name, date_col=date_col) # ── Add notes column (flag day-of-month LMP rows) ───────────────────── def _lmp_note(val): try: int(str(val).strip()) return "LMP inferred as day-of-month; interpret with caution" except (ValueError, TypeError): return "" results["notes"] = df[lmp_col_name].apply(_lmp_note) # ── Create timestamped run directory ────────────────────────────────── run_dir = _make_run_dir() # ── Save predictions CSV ────────────────────────────────────────────── csv_path = run_dir / "predictions" / "batch_symptom_forecast.csv" results.to_csv(csv_path, index=False) # ── Save text report ────────────────────────────────────────────────── txt_path = run_dir / "reports" / "batch_symptom_summary.txt" _write_batch_symptom_report(txt_path, results, cycle_length, run_dir) # ── Build summary HTML ──────────────────────────────────────────────── total = len(results) hot_flags = int(results["hotflash_pred"].sum()) \ if "hotflash_pred" in results.columns else 0 mood_flags = int(results["mood_pred"].sum()) \ if "mood_pred" in results.columns else 0 mean_hot = float(results["hotflash_prob"].mean()) \ if "hotflash_prob" in results.columns else 0.0 mean_mood = float(results["mood_prob"].mean()) \ if "mood_prob" in results.columns else 0.0 summary_html = f"""
🌊 Symptom Forecast — {total} individuals
Total
{total}
🔥 Hot Flash Risk
{hot_flags}
😤 Mood Risk
{mood_flags}
Avg Hot Flash Prob
{mean_hot:.1%}
Avg Mood Prob
{mean_mood:.1%}
📁 Outputs saved to:
{run_dir}/
predictions/batch_symptom_forecast.csv
reports/batch_symptom_summary.txt
""" return str(csv_path), summary_html, results except Exception as exc: return None, f"❌ Error: {exc}", None # ── Feature reference & model status ───────────────────────────────────────── def get_feature_reference() -> str: feature_names = _metadata.get("feature_names", list(FEATURE_DESCRIPTIONS.keys())) rows = "" for i, f in enumerate(feature_names[:60]): desc = FEATURE_DESCRIPTIONS.get(f, f.split("_")[0]) rows += f""" {i + 1} {f} {desc} """ remaining = len(feature_names) - 60 if remaining > 0: rows += f""" … and {remaining} more features (one-hot encoded categories) """ return f"""
📋 Training Features ({len(feature_names)} total after encoding)
{rows}
# Feature Description
""" def get_model_status() -> str: if _MODEL_OK: fc = len(_metadata.get("feature_names", [])) sc = _metadata.get("stage_classes", ["pre", "peri", "post"]) badges = "".join( f'{STAGE_EMOJI.get(s,"")} {s}' for s in sc ) return f"""
Models Loaded Successfully
Ready for predictions
Features
{fc}
Models
2
Stages
{len(sc)}
Available Stages
{badges}
""" return f"""
⚠️
Models Not Loaded
{_MODEL_MSG}
To train and save models:
python menopause.py

This generates swan_ml_output/rf_pipeline.pkl, lr_pipeline.pkl, and forecast_metadata.json.
""" # ── Education content ───────────────────────────────────────────────────────── EDUCATION_HTML = """

🌸 Understanding Menopause

Menopause is a natural biological process marking the end of menstrual cycles. It is officially diagnosed after 12 consecutive months without a menstrual period and typically occurs in women in their late 40s to early 50s.

Three Stages

🟢 Pre-Menopause

Regular ovarian function. Periods are predictable. Hormones (estrogen, progesterone) follow a consistent monthly pattern.

🟡 Peri-Menopause

Transition phase — usually begins in the mid-40s. Hormone levels fluctuate. Periods become irregular. Hot flashes and sleep issues may begin.

🟣 Post-Menopause

12+ months after the last period. Lower estrogen levels. Risk factors for osteoporosis and cardiovascular disease increase.

Common Symptoms by Stage

Symptom Pre Peri Post
Hot flashes
Irregular periods N/A
Sleep disturbances Mild
Mood changes PMS Possible
Vaginal dryness Possible
Bone density changes Begins

About This Tool

This application uses machine learning models trained on the SWAN (Study of Women's Health Across the Nation) dataset — a landmark multisite, multiethnic longitudinal study. The models were trained on self-reported symptom and behavioral data to predict menopausal stage.

⚠️ Disclaimer: This tool is for educational and research purposes only. Predictions should not substitute clinical diagnosis. Always consult a qualified healthcare provider for medical advice.
""" # ── Gradio UI ───────────────────────────────────────────────────────────────── CUSTOM_CSS = """ /* ── Force light mode — disable Gradio dark theme entirely ───────────── */ :root { color-scheme: light only !important; } /* Fallback: if Gradio somehow sets .dark, override every key variable */ body.dark, body.dark .gradio-container { --body-background-fill: #f0f4f8 !important; --background-fill-primary: #ffffff !important; --background-fill-secondary: #f8fafc !important; --border-color-primary: #e2e8f0 !important; --border-color-accent: #3b82f6 !important; --color-accent: #2563eb !important; --color-accent-soft: #eff6ff !important; --input-background-fill: #ffffff !important; --input-border-color: #d1d5db !important; --label-text-color: #374151 !important; --block-label-text-color: #374151 !important; --block-title-text-color: #111827 !important; --body-text-color: #111827 !important; --body-text-color-subdued: #6b7280 !important; --link-text-color: #2563eb !important; --button-primary-background-fill: #2563eb !important; --button-primary-text-color: #ffffff !important; --button-secondary-background-fill: #ffffff !important; --button-secondary-text-color: #374151 !important; --tab-text-color: #374151 !important; --tab-text-color-selected: #2563eb !important; color: #111827 !important; background-color: #f0f4f8 !important; } /* ── Core ────────────────────────────────────────────────────────────── */ .gradio-container { max-width: 1200px !important; margin: 0 auto !important; font-family: 'Segoe UI', system-ui, -apple-system, sans-serif !important; background: #f0f4f8 !important; } /* ── Header banner ──────────────────────────────────────────────────── */ .header-banner { background: linear-gradient(135deg, #faf5ff 0%, #fff0f9 50%, #eff6ff 100%); border: 1px solid #e9d5ff; border-radius: 16px; padding: 28px 32px; margin-bottom: 20px; box-shadow: 0 2px 8px rgba(139,92,246,0.08); position: relative; overflow: hidden; } .header-banner::before { content: ''; position: absolute; top: -40%; right: -5%; width: 280px; height: 280px; background: radial-gradient(circle, rgba(139,92,246,0.08) 0%, transparent 70%); pointer-events: none; } /* ── Reusable info boxes ─────────────────────────────────────────────── */ .info-box { background: #f8fafc; border: 1px solid #e2e8f0; border-left: 3px solid #3b82f6; border-radius: 8px; padding: 12px 16px; color: #475569; font-size: 13px; margin-bottom: 16px; line-height: 1.5; } .info-box code { background: #e2e8f0; color: #1e293b; padding: 1px 5px; border-radius: 3px; font-family: monospace; font-size: 0.9em; } .section-label { color: #2563eb; font-size: 12px; font-weight: 700; text-transform: uppercase; letter-spacing: 0.6px; margin-bottom: 10px; margin-top: 10px; } .format-hint { background: #f8fafc; border: 1px solid #e2e8f0; border-radius: 8px; padding: 14px; margin-top: 10px; font-size: 12px; color: #475569; } .format-hint-title { color: #2563eb; font-weight: 600; margin-bottom: 6px; } .format-hint pre { color: #475569; margin: 0; font-size: 11px; white-space: pre-wrap; } .format-hint-note { color: #94a3b8; font-size: 11px; margin-top: 8px; } .placeholder-msg { color: #9ca3af; text-align: center; padding: 40px; font-size: 14px; } .section-divider { border: none; border-top: 1px solid #e2e8f0; margin: 24px 0; } .batch-section-label { color: #2563eb; font-size: 14px; font-weight: 600; margin-bottom: 12px; } /* ── Result & summary cards ─────────────────────────────────────────── */ .result-card { background: #ffffff; border: 1px solid #e2e8f0; border-radius: 16px; padding: 24px; box-shadow: 0 1px 4px rgba(0,0,0,0.06); font-family: 'Segoe UI', system-ui, sans-serif; } .stat-grid-3 { display:grid; grid-template-columns:repeat(3,1fr); gap:12px; margin:14px 0; } .stat-grid-2 { display:grid; grid-template-columns:1fr 1fr; gap:10px; margin-top:10px; } .stat-item { background:#f8fafc; border:1px solid #e2e8f0; padding:12px; border-radius:8px; text-align:center; } .stat-label { color:#6b7280; font-size:11px; text-transform:uppercase; letter-spacing:0.4px; } .stat-value { color:#111827; font-size:22px; font-weight:700; line-height:1.2; margin-top:2px; } .output-path-box { background:#f0fdf4; border:1px solid #bbf7d0; border-radius:8px; padding:10px 14px; margin-top:12px; font-family:monospace; } .output-path-title { color:#059669; font-size:12px; font-weight:600; } .output-path-dir { color:#065f46; font-size:11px; margin-top:4px; } .output-path-files { color:#6b7280; font-size:10px; margin-top:4px; line-height:1.6; } /* ── Code blocks ────────────────────────────────────────────────────── */ .code-block { background: #1e293b; color: #a3e635; border-radius: 8px; padding: 12px; font-size: 12px; font-family: monospace; white-space: pre; overflow-x: auto; } /* ── Setup instructions card ─────────────────────────────────────────── */ .setup-card { background:#ffffff; border:1px solid #e2e8f0; border-radius:12px; padding:20px; margin-top:16px; font-family:'Segoe UI',system-ui,sans-serif; } .setup-title { color:#111827; font-size:15px; font-weight:700; margin-bottom:12px; } .setup-step { color:#374151; font-size:13px; line-height:1.8; } .setup-step strong { color:#2563eb; } /* ── Education ──────────────────────────────────────────────────────── */ .edu-card { background:#ffffff; border:1px solid #e2e8f0; border-radius:16px; padding:28px; font-family:'Segoe UI',system-ui,sans-serif; color:#374151; line-height:1.7; } .edu-card h2 { color:#111827; font-size:22px; margin-top:0; } .edu-card h3 { color:#7c3aed; font-size:16px; margin-top:20px; } .stage-cards-grid { display:grid; grid-template-columns:repeat(3,1fr); gap:16px; margin:14px 0; } .stage-card-pre { background:#f0fdf4; border-top:4px solid #16a34a; padding:16px; border-radius:10px; } .stage-card-peri { background:#fffbeb; border-top:4px solid #d97706; padding:16px; border-radius:10px; } .stage-card-post { background:#faf5ff; border-top:4px solid #7c3aed; padding:16px; border-radius:10px; } .disclaimer-box { background:#fffbeb; border-left:3px solid #d97706; padding:12px 16px; border-radius:0 8px 8px 0; margin-top:14px; font-size:13px; color:#374151; } /* ── Feature reference table ────────────────────────────────────────── */ .feature-table-wrap { background:#ffffff; border:1px solid #e2e8f0; border-radius:12px; padding:20px; max-height:500px; overflow-y:auto; font-family:'Segoe UI',system-ui,sans-serif; } .feature-table-wrap table { width:100%; border-collapse:collapse; } .feature-table-wrap thead tr { background:#f8fafc; } .feature-table-wrap th { padding:8px; color:#6b7280; font-size:11px; text-align:left; text-transform:uppercase; letter-spacing:0.4px; } .feature-table-wrap tr { border-bottom:1px solid #e2e8f0; } .feature-table-wrap td { padding:8px; } .feature-code { color:#2563eb; font-family:monospace; font-size:13px; } .feature-desc { color:#374151; font-size:12px; } .feature-num { color:#9ca3af; font-size:12px; } /* ── Model status card ──────────────────────────────────────────────── */ .status-card { background:#ffffff; border:1px solid #e2e8f0; border-radius:12px; padding:20px; font-family:'Segoe UI',system-ui,sans-serif; } /* ── Footer ─────────────────────────────────────────────────────────── */ .app-footer { text-align:center; color:#9ca3af; font-size:11px; margin-top:24px; padding:16px; border-top:1px solid #e2e8f0; } .app-footer a { color:#2563eb; text-decoration:none; } /* ── Responsive — Tablet (≤ 768 px) ────────────────────────────────── */ @media (max-width: 768px) { .gradio-container { padding: 8px !important; } .header-banner { padding: 16px 20px !important; margin-bottom: 12px !important; } .header-status-badge { display: none !important; } .stat-grid-3 { grid-template-columns: 1fr !important; } .stat-grid-2 { grid-template-columns: 1fr !important; } .stage-cards-grid { grid-template-columns: 1fr !important; } } /* ── Responsive — Mobile (≤ 480 px) ────────────────────────────────── */ @media (max-width: 480px) { .header-banner h1 { font-size: 18px !important; } .result-card { padding: 16px !important; } .edu-card { padding: 16px !important; } .setup-card { padding: 14px !important; } } """ HEADER_HTML = """
🌸

SWAN Menopause Prediction

AI-powered menopausal stage prediction & symptom forecasting · Based on the SWAN dataset

Status
{status}
""".format( color = "#059669" if _MODEL_OK else "#dc2626", status = "Models Ready ✅" if _MODEL_OK else "Models Needed ⚠️", ) # ── Force-light-mode JS (runs on every page load) ───────────────────────────── # Removes Gradio's .dark class, locks localStorage to "light", and uses a # MutationObserver to prevent the class from being re-applied — works on # HuggingFace Spaces regardless of the user's OS/browser dark-mode setting. FORCE_LIGHT_JS = """ function() { const forceLightMode = () => { if (document.body.classList.contains('dark')) { document.body.classList.remove('dark'); } }; // Apply immediately forceLightMode(); // Lock Gradio's stored preference try { localStorage.setItem('theme', 'light'); } catch(e) {} // Watch for Gradio trying to re-add .dark and block it new MutationObserver(function(mutations) { mutations.forEach(function(m) { if (m.attributeName === 'class') forceLightMode(); }); }).observe(document.body, { attributes: true, attributeFilter: ['class'] }); } """ # ── App builder ─────────────────────────────────────────────────────────────── def build_app(): with gr.Blocks( css = CUSTOM_CSS, js = FORCE_LIGHT_JS, title = "SWAN Menopause Prediction", theme = gr.themes.Soft( primary_hue = "blue", neutral_hue = "slate", ).set( # ── Body ────────────────────────────────────────────────────── body_background_fill = "#f0f4f8", body_background_fill_dark = "#f0f4f8", body_text_color = "#111827", body_text_color_dark = "#111827", body_text_color_subdued = "#6b7280", body_text_color_subdued_dark = "#6b7280", # ── Panel / block backgrounds ────────────────────────────────── background_fill_primary = "#ffffff", background_fill_primary_dark = "#ffffff", background_fill_secondary = "#f8fafc", background_fill_secondary_dark = "#f8fafc", block_background_fill = "#ffffff", block_background_fill_dark = "#ffffff", block_border_color = "#e2e8f0", block_border_color_dark = "#e2e8f0", block_label_background_fill = "#f8fafc", block_label_background_fill_dark= "#f8fafc", block_label_text_color = "#374151", block_label_text_color_dark = "#374151", block_title_text_color = "#111827", block_title_text_color_dark = "#111827", # ── Inputs ──────────────────────────────────────────────────── input_background_fill = "#ffffff", input_background_fill_dark = "#ffffff", input_background_fill_focus = "#ffffff", input_background_fill_focus_dark= "#ffffff", input_border_color = "#d1d5db", input_border_color_dark = "#d1d5db", input_border_color_focus = "#3b82f6", input_border_color_focus_dark = "#3b82f6", input_placeholder_color = "#9ca3af", input_placeholder_color_dark = "#9ca3af", # ── Borders ──────────────────────────────────────────────────── border_color_primary = "#e2e8f0", border_color_primary_dark = "#e2e8f0", border_color_accent = "#3b82f6", border_color_accent_dark = "#3b82f6", # ── Buttons ──────────────────────────────────────────────────── button_primary_background_fill = "#2563eb", button_primary_background_fill_dark = "#2563eb", button_primary_background_fill_hover = "#1d4ed8", button_primary_background_fill_hover_dark = "#1d4ed8", button_primary_text_color = "#ffffff", button_primary_text_color_dark = "#ffffff", button_secondary_background_fill = "#ffffff", button_secondary_background_fill_dark = "#ffffff", button_secondary_background_fill_hover = "#f1f5f9", button_secondary_background_fill_hover_dark="#f1f5f9", button_secondary_text_color = "#374151", button_secondary_text_color_dark = "#374151", button_secondary_border_color = "#e2e8f0", button_secondary_border_color_dark = "#e2e8f0", # ── Checkbox / Radio ────────────────────────────────────────── checkbox_background_color = "#ffffff", checkbox_background_color_dark = "#ffffff", checkbox_background_color_selected = "#2563eb", checkbox_background_color_selected_dark = "#2563eb", checkbox_border_color = "#d1d5db", checkbox_border_color_dark = "#d1d5db", checkbox_border_color_focus = "#3b82f6", checkbox_border_color_focus_dark = "#3b82f6", # ── Slider ──────────────────────────────────────────────────── slider_color = "#2563eb", slider_color_dark = "#2563eb", # ── Table ───────────────────────────────────────────────────── table_odd_background_fill = "#f8fafc", table_odd_background_fill_dark = "#f8fafc", table_even_background_fill = "#ffffff", table_even_background_fill_dark = "#ffffff", table_border_color = "#e2e8f0", table_border_color_dark = "#e2e8f0", # ── Links ───────────────────────────────────────────────────── link_text_color = "#2563eb", link_text_color_dark = "#2563eb", link_text_color_hover = "#1d4ed8", link_text_color_hover_dark = "#1d4ed8", link_text_color_visited = "#7c3aed", link_text_color_visited_dark = "#7c3aed", # ── Accent ──────────────────────────────────────────────────── color_accent_soft = "#eff6ff", color_accent_soft_dark = "#eff6ff", ), ) as app: gr.HTML(HEADER_HTML) with gr.Tabs(): # ── TAB 1: Single Stage Prediction ──────────────────────────────── with gr.Tab("🔮 Stage Prediction"): gr.HTML("""
Fill in the fields below to predict menopausal stage for a single individual. All fields are optional — the pipeline handles missing values automatically. A timestamped output folder is created in swan_ml_output/ for every run.
""") with gr.Row(): # ── Input column ────────────────────────────────────────── with gr.Column(scale=2): with gr.Group(): gr.HTML('
Demographics
') with gr.Row(): age = gr.Slider( minimum=35, maximum=75, value=48, step=1, label="Age (AGE7)", ) race = gr.Dropdown( choices=[1, 2, 3, 4, 5], value=1, label="Race (RACE)", info="1=White, 2=Black, 3=Chinese, 4=Japanese, 5=Hispanic", ) langint = gr.Dropdown( choices=[1, 2, 3], value=1, label="Interview Language (LANGINT7)", info="1=English, 2=Spanish, 3=Other", ) with gr.Group(): gr.HTML('
Vasomotor Symptoms
') with gr.Row(): hot_flash = gr.Slider( minimum=1, maximum=5, value=1, step=1, label="Hot Flash Severity (HOTFLAS7)", info="1=None, 5=Very severe", ) num_hot_flash = gr.Slider( minimum=0, maximum=15, value=0, step=1, label="# Hot Flashes/Week (NUMHOTF7)", ) bothersome_hf = gr.Slider( minimum=1, maximum=4, value=1, step=1, label="How Bothersome (BOTHOTF7)", info="1=Not at all, 4=Extremely", ) with gr.Group(): gr.HTML('
Sleep & Mood
') with gr.Row(): sleep_quality = gr.Slider( minimum=1, maximum=5, value=2, step=1, label="Sleep Quality (SLEEPQL7)", info="1=Very good, 5=Very poor", ) depression = gr.Slider( minimum=0, maximum=4, value=0, step=1, label="Depression Indicator (DEPRESS7)", info="0=No, higher=more severe", ) with gr.Row(): mood_change = gr.Slider( minimum=1, maximum=5, value=1, step=1, label="Mood Changes (MOODCHG7)", info="1=None, 5=Severe", ) irritability = gr.Slider( minimum=1, maximum=5, value=1, step=1, label="Irritability (IRRITAB7)", ) with gr.Group(): gr.HTML('
Physical & Gynaecological
') with gr.Row(): pain = gr.Slider( minimum=0, maximum=5, value=0, step=1, label="Pain Indicator (PAIN17)", ) abbleed = gr.Dropdown( choices=[0, 1, 2], value=0, label="Abnormal Bleeding (ABBLEED7)", info="0=No, 1=Yes, 2=Unsure", ) with gr.Row(): vaginal_dryness = gr.Slider( minimum=0, maximum=5, value=0, step=1, label="Vaginal Dryness (VAGINDR7)", ) lmp_day = gr.Number( value=None, label="LMP Day (LMPDAY7)", info="Day of last menstrual period (optional)", ) model_choice = gr.Radio( choices=["RandomForest", "LogisticRegression"], value="RandomForest", label="Model", info="RandomForest: higher accuracy | " "LogisticRegression: more interpretable", ) predict_btn = gr.Button( "🔮 Predict Stage", variant="primary", size="lg" ) # ── Output column ───────────────────────────────────────── with gr.Column(scale=3): result_html = gr.HTML( '
Fill in the form and click Predict Stage
' ) result_chart = gr.Plot(label="Stage Probabilities") confidence_note = gr.Textbox( label="Confidence Note", interactive=False, lines=2 ) compare_html = gr.HTML() stage_download = gr.File( label="Download Prediction CSV", interactive=False ) predict_btn.click( fn = predict_single_stage, inputs = [ age, race, langint, hot_flash, num_hot_flash, bothersome_hf, sleep_quality, depression, mood_change, irritability, pain, abbleed, vaginal_dryness, lmp_day, model_choice, ], outputs = [ result_html, result_chart, confidence_note, compare_html, stage_download, ], ) # ── TAB 2: Batch Stage Prediction ───────────────────────────────── with gr.Tab("📁 Batch Stage Prediction"): gr.HTML("""
Upload a CSV file with individual feature values for batch prediction. Results + charts + a summary report are saved to a timestamped folder inside swan_ml_output/.
""") with gr.Row(): with gr.Column(scale=1): batch_file = gr.File( label="Upload stage_input.csv", file_types=[".csv"], ) batch_model = gr.Radio( choices=["RandomForest", "LogisticRegression"], value="RandomForest", label="Model", ) gr.HTML("""
Expected CSV Format
individual,AGE7,RACE,HOTFLAS7,...
Person_001,48,1,2,...
Person_002,52,2,1,...
See the test-csv/ folder for an approved example.
""") batch_predict_btn = gr.Button( "🚀 Run Batch Prediction", variant="primary" ) with gr.Column(scale=2): batch_summary_html = gr.HTML( '
Upload a CSV to begin
' ) batch_download = gr.File( label="Download Predictions CSV", interactive=False ) batch_results_df = gr.DataFrame( label="Results Preview (first 20 rows)", interactive=False, ) batch_predict_btn.click( fn = predict_batch_stage, inputs = [batch_file, batch_model], outputs = [batch_download, batch_summary_html, batch_results_df], ) # ── TAB 3: Symptom Forecast ─────────────────────────────────────── with gr.Tab("🌊 Symptom Forecast"): gr.HTML("""
Predict hot flash and mood change probability based on cycle day (calculated from Last Menstrual Period date). All outputs are saved to a timestamped folder inside swan_ml_output/.
""") with gr.Row(): with gr.Column(scale=1): sym_individual = gr.Textbox( label="Individual ID (optional)", placeholder="e.g., Patient_001", ) sym_lmp = gr.Textbox( label="Last Menstrual Period (LMP)", placeholder="2026-01-15 or 15 (day of month)", info="Full date (YYYY-MM-DD) or day-of-month integer", ) sym_date = gr.Textbox( label="Target Date (optional)", placeholder="2026-02-27 (defaults to today)", info="Date to forecast for (YYYY-MM-DD)", ) sym_cycle = gr.Slider( minimum=21, maximum=40, value=28, step=1, label="Cycle Length (days)", ) sym_predict_btn = gr.Button( "🌊 Forecast Symptoms", variant="primary" ) with gr.Column(scale=2): sym_result_html = gr.HTML( '
Enter LMP date and click Forecast
' ) sym_chart = gr.Plot(label="Cycle Position") sym_download = gr.File( label="Download Forecast CSV", interactive=False ) sym_predict_btn.click( fn = predict_symptoms, inputs = [sym_individual, sym_lmp, sym_date, sym_cycle], outputs = [sym_result_html, sym_chart, sym_download], ) gr.HTML('
') gr.HTML('
📁 Batch Symptom Forecasting
') with gr.Row(): with gr.Column(scale=1): sym_batch_file = gr.File( label="Upload symptoms_input.csv", file_types=[".csv"], ) sym_lmp_col = gr.Textbox( label="LMP Column Name", value="LMP" ) sym_date_col = gr.Textbox( label="Date Column Name (optional)", value="date" ) sym_cycle_batch = gr.Slider( minimum=21, maximum=40, value=28, step=1, label="Default Cycle Length", ) sym_batch_btn = gr.Button( "🌊 Run Batch Forecast", variant="primary" ) with gr.Column(scale=2): sym_batch_summary = gr.HTML( '
Upload a CSV to begin
' ) sym_batch_download = gr.File( label="Download Symptom Forecast CSV", interactive=False ) sym_batch_df = gr.DataFrame( label="Results Preview", interactive=False, ) sym_batch_btn.click( fn = predict_symptoms_batch, inputs = [ sym_batch_file, sym_lmp_col, sym_date_col, sym_cycle_batch, ], outputs = [sym_batch_download, sym_batch_summary, sym_batch_df], ) # ── TAB 4: Education ────────────────────────────────────────────── with gr.Tab("📚 Menopause Education"): gr.HTML(EDUCATION_HTML) # ── TAB 5: Feature Reference ────────────────────────────────────── with gr.Tab("🔬 Feature Reference"): gr.HTML("""
Canonical list of features used by the trained models (from forecast_metadata.json). For batch CSV uploads, column names must match these feature names.
""") gr.HTML(get_feature_reference()) # ── TAB 6: Model Status ─────────────────────────────────────────── with gr.Tab("⚙️ Model Status"): gr.HTML(get_model_status()) gr.HTML("""
🚀 Setup Instructions

Step 1 — Train models:

python menopause.py

Step 2 — Verify artifacts:

ls swan_ml_output/
# rf_pipeline.pkl  lr_pipeline.pkl  forecast_metadata.json

Step 3 — Run this app:

python app.py

Step 4 — Deploy on Hugging Face Spaces:

git lfs install
git lfs track "*.pkl"
git add .
git commit -m "SWAN menopause prediction app"
git push

Output folder structure (per run):

swan_ml_output/
  <YYYYMMDD_HHMMSS>/
    charts/       ← PNG visualizations
    predictions/  ← CSV result files
    reports/      ← TXT summary reports
""") gr.HTML(""" """) return app # ── Entry point ─────────────────────────────────────────────────────────────── if __name__ == "__main__": demo = build_app() demo.launch( server_name = "0.0.0.0", server_port = int(os.environ.get("PORT", 7860)), share = False, show_error = True, )