# menopause-ml / app.py
# Uploaded by techatcreated — commit 9aa4b61 ("Upload app.py", verified)
"""
SWAN Menopause Stage Prediction & Forecasting — Gradio UI
Hugging Face Spaces deployment-ready.
Run locally: python app.py
Deploy: Push to a HF Space with SDK=gradio
Output structure (per execution):
    swan_ml_output/
        <YYYYMMDD_HHMMSS>/
            charts/        ← PNG visualizations
            predictions/   ← CSV result files
            reports/       ← TXT summary reports
"""
import os
import json
import warnings
from datetime import datetime
from pathlib import Path
from typing import Optional
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
warnings.filterwarnings("ignore")
# ── Gradio ────────────────────────────────────────────────────────────────────
import gradio as gr
# ── Local ML module ───────────────────────────────────────────────────────────
# The ML code lives in the sibling ``menopause`` module. The import is guarded
# so the UI can still start and show a readable message when it is missing.
try:
    from menopause import (
        MenopauseForecast,
        SymptomCycleForecaster,
        load_forecast_model,
    )
    _MODULE_AVAILABLE = True
except ImportError:
    _MODULE_AVAILABLE = False

# ── Model loading ─────────────────────────────────────────────────────────────
# Directory holding the saved model artifacts; override with the FORECAST_DIR
# environment variable.
FORECAST_DIR = os.environ.get("FORECAST_DIR", "swan_ml_output")
# Per-run output folders are created beneath the same directory.
OUTPUT_BASE = Path(FORECAST_DIR)

# The annotation is deliberately a string: module-level annotations are
# evaluated at import time, so naming MenopauseForecast directly raises
# NameError whenever the guarded import above has failed.
_forecast: Optional["MenopauseForecast"] = None
_metadata: dict = {}
def _load_models():
    """Load the saved forecasting pipelines into the module-level globals.

    Verifies that the ``menopause`` module was importable and that the three
    artifacts (``forecast_metadata.json``, ``rf_pipeline.pkl`` and
    ``lr_pipeline.pkl``) exist in ``FORECAST_DIR`` before calling
    ``load_forecast_model``. ``_forecast`` and ``_metadata`` are only assigned
    when every step succeeded, so a failed load never leaves them half-set.

    Returns:
        tuple[bool, str]: ``(success, message)``; the message is meant for
        display in the UI.
    """
    global _forecast, _metadata

    if not _MODULE_AVAILABLE:
        return False, "menopause.py not found. Make sure it is in the same directory."

    artifact_dir = Path(FORECAST_DIR)
    meta_path = artifact_dir / "forecast_metadata.json"
    required = (
        meta_path,
        artifact_dir / "rf_pipeline.pkl",
        artifact_dir / "lr_pipeline.pkl",
    )
    if not all(p.exists() for p in required):
        return (
            False,
            f"Model artifacts not found in '{FORECAST_DIR}'. "
            "Run `python menopause.py` to train and save the models first.",
        )

    try:
        # NOTE(review): the .pkl artifacts are presumably unpickled inside
        # load_forecast_model — only point FORECAST_DIR at trusted directories.
        forecast = load_forecast_model(FORECAST_DIR)
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(meta_path, encoding="utf-8") as fh:
            metadata = json.load(fh)
        n_features = len(metadata.get("feature_names", []))
    except Exception as exc:  # UI boundary: report the failure, do not crash
        return False, f"Error loading models: {exc}"

    _forecast, _metadata = forecast, metadata
    return True, f"✅ Models loaded — {n_features} features"
# Load the models once at import time. The stage-prediction callbacks check
# _MODEL_OK and show _MODEL_MSG to the user when loading failed.
_MODEL_OK, _MODEL_MSG = _load_models()
# ── Output directory management ───────────────────────────────────────────────
def _make_run_dir() -> Path:
    """Create and return a unique timestamped run directory under ``OUTPUT_BASE``.

    The directory is named ``YYYYMMDD_HHMMSS``. When that name is already taken
    (for example two requests within the same second) a numeric suffix ``_1``,
    ``_2``, … is appended, so separate runs never write into the same folder.
    The ``charts``, ``predictions`` and ``reports`` sub-folders are created
    inside the new directory.

    Returns:
        Path: the newly created run directory.
    """
    OUTPUT_BASE.mkdir(parents=True, exist_ok=True)
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    attempt = 0
    while True:
        run_dir = OUTPUT_BASE / (stamp if attempt == 0 else f"{stamp}_{attempt}")
        try:
            # mkdir without exist_ok fails if the name exists, so only one
            # caller can ever claim a given directory name.
            run_dir.mkdir()
            break
        except FileExistsError:
            attempt += 1

    for sub in ("charts", "predictions", "reports"):
        (run_dir / sub).mkdir(exist_ok=True)
    return run_dir
def _get_file_path(file_obj) -> Optional[str]:
    """
    Return the file-system path behind a Gradio file component value.

    ``None`` is passed through unchanged. A value exposing a ``.name``
    attribute (the file-like object older Gradio versions hand over) yields
    that attribute; anything else (a plain path string or a ``str`` subclass)
    is converted with ``str()``.
    """
    if file_obj is None:
        return None
    try:
        return file_obj.name
    except AttributeError:
        # No .name attribute: the value is already path-like text.
        return str(file_obj)
# ── Constants & helpers ───────────────────────────────────────────────────────
# All four tables below are keyed by the stage codes "pre", "peri" and "post".

# Hex colour used for each stage in the charts and HTML result cards.
STAGE_COLORS = {"pre": "#16a34a", "peri": "#d97706", "post": "#7c3aed"}
# Emoji shown next to each stage in the HTML output.
STAGE_EMOJI = {"pre": "🟢", "peri": "🟡", "post": "🟣"}
# Human-readable label for each stage (chart tick labels, reports, HTML).
STAGE_LABELS = {
    "pre": "Pre-Menopausal",
    "peri": "Peri-Menopausal",
    "post": "Post-Menopausal",
}
# Descriptive copy for each stage; predict_single_stage reads the
# "description", "symptoms" and "guidance" entries for its result card.
STAGE_INFO = {
    "pre": {
        "title": "Pre-Menopausal",
        "description": "Regular menstrual cycles with typical hormonal fluctuations. Ovarian function is normal.",
        "symptoms": ["Regular periods", "Normal hormone levels", "Potential mild PMS"],
        "guidance": "Maintain regular check-ups. Track your cycle and note any changes.",
    },
    "peri": {
        "title": "Peri-Menopausal (Transition)",
        "description": "Hormonal changes begin — estrogen and progesterone levels fluctuate. Cycles become irregular.",
        "symptoms": ["Irregular periods", "Hot flashes", "Sleep disturbances", "Mood changes", "Night sweats"],
        "guidance": "Consult your healthcare provider. Lifestyle adjustments (diet, exercise, sleep) can help.",
    },
    "post": {
        "title": "Post-Menopausal",
        "description": "12+ months since last menstrual period. Estrogen remains at consistently lower levels.",
        "symptoms": ["No periods", "Possible continued hot flashes", "Vaginal dryness", "Bone density changes"],
        "guidance": "Focus on bone health, cardiovascular health, and regular screenings. Discuss HRT options.",
    },
}
# Feature descriptions keyed by the model's canonical feature names.
# NOTE(review): presumably rendered in the "Feature Reference" tab that
# predict_batch_stage's error message mentions — the consumer is not visible
# in this part of the file.
FEATURE_DESCRIPTIONS = {
    "PAIN17": "Pain indicator (visit-specific)",
    "PAINTW17": "Pain two-week indicator",
    "PAIN27": "Secondary pain indicator",
    "PAINTW27": "Secondary pain two-week indicator",
    "SLEEP17": "Sleep disturbance pattern 1",
    "SLEEP27": "Sleep disturbance pattern 2",
    "BCOHOTH7": "Birth control — other method",
    "EXERCIS7": "General exercise indicator",
    "EXERHAR7": "Vigorous exercise",
    "EXEROST7": "Osteoporosis exercise",
    "EXERMEN7": "Exercise — mental health",
    "EXERLOO7": "Exercise lookalike",
    "EXERMEM7": "Exercise — memory",
    "EXERPER7": "Exercise perception",
    "EXERGEN7": "General exercise type",
    "EXERWGH7": "Weight exercise",
    "EXERADV7": "Exercise advice indicator",
    "EXEROTH7": "Other exercise",
    "EXERSPE7": "Specific exercise",
    "ABBLEED7": "Abnormal bleeding (0=no, 1=yes)",  # spelled with a double B; predict_single_stage uses this same key
    "BLEEDNG7": "Bleeding pattern",
    "LMPDAY7": "Last menstrual period day",
    "DEPRESS7": "Depression indicator",
    "SEX17": "Sexual activity indicator 1",
    "SEX27": "Sexual activity indicator 2",
    "SEX37": "Sexual activity indicator 3",
    "SEX47": "Sexual activity indicator 4",
    "SEX57": "Sexual activity indicator 5",
    "SEX67": "Sexual activity indicator 6",
    "SEX77": "Sexual activity indicator 7",
    "SEX87": "Sexual activity indicator 8",
    "SEX97": "Sexual activity indicator 9",
    "SEX107": "Sexual activity indicator 10",
    "SEX117": "Sexual activity indicator 11",
    "SEX127": "Sexual activity indicator 12",
    "SMOKERE7": "Smoking status",
    "HOTFLAS7": "Hot flash severity (1=none, 5=very severe)",
    "NUMHOTF7": "Number of hot flashes per week",
    "BOTHOTF7": "How bothersome are hot flashes",
    "IRRITAB7": "Irritability level",
    "VAGINDR7": "Vaginal dryness",
    "MOODCHG7": "Mood change frequency",
    "SLEEPQL7": "Sleep quality score",
    "PHYSILL7": "Physical illness indicators",
    "HOTHEAD7": "Hot flashes with headache",
    "EXER12H7": "Exercise in last 12 hours",
    "ALCO24H7": "Alcohol in last 24h",
    "AGE7": "Age (years)",
    "RACE": "Race (1=White, 2=Black, 3=Chinese, 4=Japanese, 5=Hispanic)",
    "LANGINT7": "Interview language indicator",
}
def _confidence_color(conf: float) -> str:
    """Map a confidence score to a traffic-light hex colour.

    Scores of at least 0.8 are green, at least 0.6 amber, and everything else
    (including NaN, which fails both comparisons) red.
    """
    tiers = ((0.8, "#16a34a"), (0.6, "#d97706"))
    for floor, hex_code in tiers:
        if conf >= floor:
            return hex_code
    return "#dc2626"
# ── Chart builders ────────────────────────────────────────────────────────────
def _make_proba_chart(
    probabilities: dict,
    predicted_stage: str,
    save_path: Optional[Path] = None,
) -> plt.Figure:
    """Draw the per-stage probabilities as horizontal bars on a dark theme.

    Each key of ``probabilities`` becomes one bar, scaled to percent; the bar
    for ``predicted_stage`` gets a white outline. When ``save_path`` is given
    the figure is also written there as a PNG. The figure is returned still
    open, so the caller owns it.
    """
    figure, axis = plt.subplots(figsize=(6, 3.5))
    figure.patch.set_facecolor("#1a1a2e")
    axis.set_facecolor("#16213e")

    stage_keys = list(probabilities.keys())
    percentages = [probabilities[key] * 100 for key in stage_keys]
    is_predicted = [key == predicted_stage for key in stage_keys]

    rects = axis.barh(
        stage_keys,
        percentages,
        color=[STAGE_COLORS.get(key, "#607d8b") for key in stage_keys],
        edgecolor=["white" if hit else "none" for hit in is_predicted],
        linewidth=[2.5 if hit else 0 for hit in is_predicted],
        height=0.5,
        zorder=3,
    )

    # Percentage label just right of each bar, clamped so it stays on the axes.
    for rect, pct in zip(rects, percentages):
        label_x = min(pct + 1, 98)
        label_y = rect.get_y() + rect.get_height() / 2
        axis.text(
            label_x, label_y,
            f"{pct:.1f}%",
            va="center", ha="left", color="white", fontsize=11, fontweight="bold",
        )

    axis.set_yticks(range(len(stage_keys)))
    axis.set_yticklabels(
        [STAGE_LABELS.get(key, key) for key in stage_keys],
        color="white", fontsize=10,
    )
    axis.set_xlim(0, 105)
    axis.tick_params(colors="white", labelsize=11)
    axis.spines[["top", "right", "left", "bottom"]].set_visible(False)
    axis.xaxis.set_visible(False)
    for spine in axis.spines.values():
        spine.set_color("#333")
    axis.set_title("Stage Probabilities", color="white", fontsize=12,
                   pad=10, fontweight="bold")
    axis.grid(axis="x", color="#333", linestyle="--", linewidth=0.5, zorder=0)
    figure.tight_layout()

    if save_path:
        figure.savefig(save_path, dpi=150, bbox_inches="tight",
                       facecolor=figure.get_facecolor())
    return figure
def _make_cycle_chart(
    cycle_day: Optional[int],
    cycle_length: int = 28,
    hot_prob: Optional[float] = None,
    mood_prob: Optional[float] = None,
    save_path: Optional[Path] = None,
) -> plt.Figure:
    """Draw a polar "clock" of the cycle with the current day marked.

    Args:
        cycle_day: 1-based day within the cycle to highlight; ``None`` draws
            the ring without a marker.
        cycle_length: number of days in one full turn of the ring.
        hot_prob: optional hot-flash probability shown in the title.
        mood_prob: optional mood-change probability shown in the title.
        save_path: when given, the figure is also written there as a PNG.

    Returns:
        The open matplotlib figure (the caller owns it).

    Raises:
        ValueError: if ``cycle_length`` is smaller than 1.
    """
    if cycle_length < 1:
        raise ValueError(f"cycle_length must be a positive integer, got {cycle_length}")

    fig, ax = plt.subplots(figsize=(5, 5), subplot_kw=dict(polar=True))
    fig.patch.set_facecolor("#1a1a2e")
    ax.set_facecolor("#16213e")

    full_turn = 2 * np.pi
    day_angles = np.linspace(0, full_turn, cycle_length, endpoint=False)
    for i, d in enumerate(day_angles):
        phase = i / cycle_length
        color = plt.cm.RdYlGn(1 - phase)
        ax.bar(d, 1, width=full_turn / cycle_length * 0.9,
               bottom=0.5, color=color, alpha=0.4, zorder=1)

    if cycle_day is not None:
        angle = (cycle_day - 1) / cycle_length * full_turn
        ax.scatter([angle], [1.05], s=200, color="#ff6b6b", zorder=5, linewidths=2)
        ax.annotate(
            f"Day {cycle_day}",
            xy=(angle, 1.05), xytext=(0, 0),
            textcoords="offset points", ha="center", va="center",
            color="white", fontsize=12, fontweight="bold",
        )

    ax.set_rticks([])
    # Ticks sit on real day positions, using the same day→angle mapping as the
    # marker above, so the labels are right for any cycle length (28 days gives
    # Day 1 / 8 / 15 / 22). The set removes duplicates for very short cycles.
    tick_days = sorted({1 + (q * cycle_length) // 4 for q in range(4)})
    ax.set_xticks([(day - 1) / cycle_length * full_turn for day in tick_days])
    ax.set_xticklabels([f"Day {day}" for day in tick_days],
                       color="#aaa", fontsize=9)
    ax.set_yticklabels([])
    ax.spines["polar"].set_color("#333")
    ax.grid(color="#333", linewidth=0.5)

    # Only mention the probabilities that were actually supplied; formatting
    # None with a % spec would raise TypeError.
    title = "Cycle Position"
    extras = []
    if hot_prob is not None:
        extras.append(f"🔥 {hot_prob:.0%}")
    if mood_prob is not None:
        extras.append(f"😤 {mood_prob:.0%}")
    if extras:
        title += "\n" + " ".join(extras)
    ax.set_title(title, color="white", fontsize=11, pad=20, fontweight="bold")
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches="tight",
                    facecolor=fig.get_facecolor())
    return fig
def _make_batch_summary_chart(results_df: pd.DataFrame,
                              save_path: Optional[Path] = None) -> None:
    """Render the batch summary figure (stage pie + confidence histogram).

    Args:
        results_df: batch results; must contain ``predicted_stage`` and may
            contain ``confidence``.
        save_path: when given, the figure is written there as a PNG.

    The figure is always closed before returning, including when drawing or
    saving raises, so failed runs do not accumulate open figures.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
    try:
        fig.patch.set_facecolor("#1a1a2e")

        # Stage distribution pie
        stage_counts = results_df["predicted_stage"].value_counts()
        colors = [STAGE_COLORS.get(s, "#607d8b") for s in stage_counts.index]
        ax1.set_facecolor("#16213e")
        wedges, texts, autotexts = ax1.pie(
            stage_counts.values, labels=stage_counts.index,
            colors=colors, autopct="%1.0f%%",
            textprops={"color": "white", "fontsize": 10},
        )
        for at in autotexts:
            at.set_color("white")
        ax1.set_title("Stage Distribution", color="white", fontsize=11, fontweight="bold")

        # Confidence histogram
        ax2.set_facecolor("#16213e")
        if "confidence" in results_df.columns:
            conf = results_df["confidence"].dropna()
            # With no usable scores min(10, len(conf)) would be 0 bins, which
            # the histogram rejects — draw only the reference lines instead.
            if not conf.empty:
                ax2.hist(conf, bins=min(10, len(conf)), color="#3B82F6",
                         edgecolor="#1a1a2e", alpha=0.8)
            ax2.axvline(0.8, color="#4CAF50", linestyle="--",
                        linewidth=1.5, label="High (0.80)")
            ax2.axvline(0.6, color="#FF9800", linestyle="--",
                        linewidth=1.5, label="Med (0.60)")
            ax2.legend(fontsize=8, labelcolor="white", facecolor="#0d0d1a")
        ax2.set_xlabel("Confidence", color="#aaa", fontsize=9)
        ax2.set_ylabel("Count", color="#aaa", fontsize=9)
        ax2.tick_params(colors="white", labelsize=9)
        for sp in ["top", "right"]:
            ax2.spines[sp].set_visible(False)
        for sp in ["left", "bottom"]:
            ax2.spines[sp].set_color("#333")
        ax2.set_title("Confidence Distribution", color="white",
                      fontsize=11, fontweight="bold")
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=150, bbox_inches="tight",
                        facecolor=fig.get_facecolor())
    finally:
        plt.close(fig)
# ── Text report writers ───────────────────────────────────────────────────────
def _write_single_stage_report(
    path: Path,
    stage: str,
    confidence: float,
    probabilities: dict,
    model: str,
    comparison: dict,
    input_features: dict,
):
    """Write the plain-text summary of a single stage prediction to ``path``.

    Args:
        path: destination file (written as UTF-8).
        stage: predicted stage key; shown via ``STAGE_LABELS`` when known.
        confidence: confidence of the prediction; ``None`` is shown as 0.
        probabilities: mapping of stage key to probability.
        model: name of the model that produced the prediction.
        comparison: per-model results; must hold ``"RandomForest"`` and
            ``"LogisticRegression"`` entries, each with a ``"stage"`` and
            optionally a ``"confidence"``.
        input_features: feature values; ``None``/NaN entries are left out.
    """

    def _num(value) -> float:
        # A confidence may be absent or explicitly None; both print as 0 so
        # the percent format below never receives None.
        return value or 0.0

    lines = [
        "=" * 60,
        "SWAN MENOPAUSE STAGE PREDICTION REPORT",
        f"Generated : {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
        "=" * 60,
        "",
        f"Predicted Stage : {STAGE_LABELS.get(stage, stage)}",
        f"Model : {model}",
        f"Confidence : {_num(confidence):.1%}",
        "",
        "Stage Probabilities:",
    ]
    for s, p in probabilities.items():
        # int() cannot convert NaN; an undefined probability gets no bar.
        bar = "" if np.isnan(p) else "█" * int(p * 20)
        lines.append(f" {s:<6} : {p:.4f} {bar}")

    lines += ["", "Model Comparison:"]
    for model_name in ("RandomForest", "LogisticRegression"):
        entry = comparison[model_name]
        lines.append(
            f" {model_name} → {entry['stage']}"
            f" ({_num(entry.get('confidence')):.1%})"
        )

    lines += ["", "Input Features (non-NaN):"]
    for k, v in input_features.items():
        if v is not None and not (isinstance(v, float) and np.isnan(v)):
            lines.append(f" {k:<12} = {v}")

    lines += [
        "",
        "⚠️ For research/educational use only. Not a clinical diagnosis.",
        "=" * 60,
    ]
    path.write_text("\n".join(lines), encoding="utf-8")
def _write_batch_report(
    path: Path,
    results: pd.DataFrame,
    model: str,
    run_dir: Path,
):
    """Write the plain-text summary of a batch stage-prediction run to ``path``.

    The report lists the number of rows, the count and share of each of the
    ``pre``/``peri``/``post`` stages, mean/min/max confidence, and how many
    rows fall into the high (>0.80), medium (0.60-0.80) and low (≤0.60)
    confidence buckets. Missing ``predicted_stage``/``confidence`` columns are
    reported as zeros.
    """
    n_rows = len(results)

    if "predicted_stage" in results.columns:
        stage_counts = results["predicted_stage"].value_counts().to_dict()
    else:
        stage_counts = {}

    if "confidence" in results.columns:
        scores = results["confidence"]
        mean_c, min_c, max_c = scores.mean(), scores.min(), scores.max()
        high = int((scores > 0.8).sum())
        medium = int(((scores > 0.6) & (scores <= 0.8)).sum())
        low = int((scores <= 0.6).sum())
    else:
        mean_c = min_c = max_c = high = medium = low = 0

    def bucket_line(label: str, count: int) -> str:
        # An empty batch has no meaningful share to print.
        if not n_rows:
            return " N/A"
        return f" {label} : {count}/{n_rows} ({count/n_rows*100:.1f}%)"

    report = [
        "=" * 60,
        "SWAN BATCH STAGE PREDICTION REPORT",
        f"Generated : {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
        f"Model : {model}",
        "=" * 60,
        "",
        f"Total Individuals : {n_rows}",
        "",
        "Stage Distribution:",
    ]
    for stage_key in ("pre", "peri", "post"):
        occurrences = stage_counts.get(stage_key, 0)
        share = occurrences / n_rows * 100 if n_rows else 0
        report.append(f" {stage_key:<6} : {occurrences} ({share:.1f}%)")

    report.extend([
        "",
        "Confidence Scores:",
        f" Mean : {mean_c:.4f}",
        f" Min : {min_c:.4f}",
        f" Max : {max_c:.4f}",
        "",
        "Confidence Distribution:",
        bucket_line("High (>0.80)", high),
        bucket_line("Medium (0.60-0.80)", medium),
        bucket_line("Low (≤0.60)", low),
        "",
        f"Output Directory : {run_dir}",
        "",
        "⚠️ For research/educational use only. Not a clinical diagnosis.",
        "=" * 60,
    ])
    path.write_text("\n".join(report), encoding="utf-8")
def _write_symptom_report(
    path: Path,
    individual_id: str,
    lmp: str,
    target_date: str,
    cycle_day: int,
    cycle_length: int,
    hot_prob: float,
    hot_pred: bool,
    mood_prob: float,
    mood_pred: bool,
):
    """Write the plain-text summary of a single symptom forecast to ``path``.

    A missing individual id is shown as ``N/A`` and a missing target date as
    ``Today``. Probabilities that are ``None`` or NaN are printed as 0.
    """

    def _as_float(prob) -> float:
        # None is tested first so np.isnan never receives it.
        if prob is None or np.isnan(prob):
            return 0.0
        return float(prob)

    def _risk_tag(flag) -> str:
        return "[ELEVATED RISK]" if flag else "[LOW RISK]"

    hp = _as_float(hot_prob)
    mp = _as_float(mood_prob)
    stamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    rule = "=" * 60

    body = [
        rule,
        "SWAN SYMPTOM CYCLE FORECAST REPORT",
        f"Generated : {stamp}",
        rule,
        "",
        f"Individual : {individual_id or 'N/A'}",
        f"LMP : {lmp}",
        f"Target Date : {target_date or 'Today'}",
        f"Cycle Length : {cycle_length} days",
        f"Cycle Day : {cycle_day}",
        "",
        "Symptom Probabilities:",
        f" Hot Flash : {hp:.4f} {_risk_tag(hot_pred)}",
        f" Mood Change : {mp:.4f} {_risk_tag(mood_pred)}",
        "",
        "⚠️ For research/educational use only. Not a clinical diagnosis.",
        rule,
    ]
    path.write_text("\n".join(body), encoding="utf-8")
def _write_batch_symptom_report(
    path: Path,
    results: pd.DataFrame,
    cycle_length: int,
    run_dir: Path,
):
    """Write the plain-text summary of a batch symptom forecast to ``path``.

    Reports the number of rows, how many are flagged for hot-flash and
    mood-change risk, and the mean of each probability column. A column that
    is missing from ``results`` contributes 0.
    """
    available = results.columns

    def flagged(column: str) -> int:
        return int(results[column].sum()) if column in available else 0

    def average(column: str) -> float:
        return float(results[column].mean()) if column in available else 0.0

    n_rows = len(results)
    hot_flags = flagged("hotflash_pred")
    mood_flags = flagged("mood_pred")
    mean_hot = average("hotflash_prob")
    mean_mood = average("mood_prob")

    rule = "=" * 60
    report = [
        rule,
        "SWAN BATCH SYMPTOM FORECAST REPORT",
        f"Generated : {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
        f"Cycle Length : {cycle_length} days",
        rule,
        "",
        f"Total Individuals : {n_rows}",
        f"Hot Flash Risk : {hot_flags}/{n_rows} elevated",
        f"Mood Change Risk : {mood_flags}/{n_rows} elevated",
        f"Avg Hot Flash Prob : {mean_hot:.4f}",
        f"Avg Mood Prob : {mean_mood:.4f}",
        "",
        f"Output Directory : {run_dir}",
        "",
        "⚠️ For research/educational use only. Not a clinical diagnosis.",
        rule,
    ]
    path.write_text("\n".join(report), encoding="utf-8")
# ── Core prediction functions ─────────────────────────────────────────────────
def predict_single_stage(
    age, race, langint,
    hot_flash, num_hot_flash, bothersome_hf,
    sleep_quality, depression_indicator, mood_change, irritability,
    pain_indicator, abbleed, vaginal_dryness, lmp_day,
    model_choice,
):
    """
    Single-person stage prediction.

    Builds a feature dict from the form values, asks the loaded forecast
    object for a prediction with ``model_choice``, writes a chart, a CSV and a
    text report into a fresh run directory, and renders two HTML panels.

    Returns (stage_html, chart_fig, conf_note, compare_html, csv_download_path).
    When the models are unavailable or the prediction raises, the first
    element carries the message and the chart and CSV path are None.
    """
    if not _MODEL_OK:
        return f"⚠️ {_MODEL_MSG}", None, "Models unavailable.", "", None
    # Build feature dict using the model's canonical feature names

    def _v(x):
        # A value of None becomes NaN; anything else is cast to float.
        return float(x) if x is not None else np.nan
    feature_dict = {
        "AGE7": _v(age),
        "RACE": _v(race),
        "LANGINT7": _v(langint),
        "HOTFLAS7": _v(hot_flash),
        "NUMHOTF7": _v(num_hot_flash),
        "BOTHOTF7": _v(bothersome_hf),
        "SLEEPQL7": _v(sleep_quality),
        "DEPRESS7": _v(depression_indicator),
        "MOODCHG7": _v(mood_change),
        "IRRITAB7": _v(irritability),
        "PAIN17": _v(pain_indicator),
        "ABBLEED7": _v(abbleed),  # ← correct feature name (was ABLEED7)
        "VAGINDR7": _v(vaginal_dryness),
        # Any falsy lmp_day (None, 0, empty) is treated as missing.
        "LMPDAY7": _v(lmp_day) if lmp_day else np.nan,
    }
    try:
        result = _forecast.predict_single(feature_dict, model=model_choice, return_proba=True)
        stage = result["stage"]
        # "or" also covers a confidence/probabilities entry that is present but None.
        confidence = result.get("confidence") or 0.0
        proba = result.get("probabilities") or {}
        # ── Create timestamped run directory ──────────────────────────────────
        run_dir = _make_run_dir()
        # ── Save probability chart (PNG) ──────────────────────────────────────
        chart_path = run_dir / "charts" / "stage_probabilities.png"
        # No chart is drawn (or saved) when the model returned no probabilities.
        chart_fig = _make_proba_chart(proba, stage, save_path=chart_path) if proba else None
        # ── Save prediction CSV ───────────────────────────────────────────────
        pred_row = {
            "predicted_stage": stage,
            "model": model_choice,
            "confidence": round(confidence, 4),
            **{f"prob_{k}": round(v, 4) for k, v in proba.items()},
            "timestamp": datetime.now().isoformat(),
        }
        csv_path = run_dir / "predictions" / "stage_prediction.csv"
        pd.DataFrame([pred_row]).to_csv(csv_path, index=False)
        # ── Model comparison ──────────────────────────────────────────────────
        comparison = _forecast.compare_models(feature_dict)
        rf_stage = comparison["RandomForest"]["stage"]
        lr_stage = comparison["LogisticRegression"]["stage"]
        agree = rf_stage == lr_stage
        # ── Save text report ──────────────────────────────────────────────────
        txt_path = run_dir / "reports" / "prediction_summary.txt"
        _write_single_stage_report(
            txt_path, stage, confidence, proba,
            model_choice, comparison, feature_dict,
        )
        # ── Build result card HTML ────────────────────────────────────────────
        info = STAGE_INFO.get(stage, {})
        emoji = STAGE_EMOJI.get(stage, "⚪")
        color = STAGE_COLORS.get(stage, "#607d8b")
        conf_color = _confidence_color(confidence)
        # One pill per symptom; the two hex digits appended to the colour are alpha values.
        symptom_tags = "".join(
            f'<span style="background:{color}14;color:{color};padding:4px 10px;'
            f'border-radius:20px;border:1px solid {color}44;font-size:12px;'
            f'font-weight:500">{s}</span>'
            for s in info.get("symptoms", [])
        )
        stage_html = f"""
<div class="result-card" style="border-left:4px solid {color}">
<div style="display:flex;align-items:center;gap:12px;margin-bottom:16px;flex-wrap:wrap">
<span style="font-size:40px;flex-shrink:0">{emoji}</span>
<div style="flex:1;min-width:140px">
<div style="color:#6b7280;font-size:12px;text-transform:uppercase;letter-spacing:2px">
Predicted Stage
</div>
<div style="color:{color};font-size:26px;font-weight:700">
{STAGE_LABELS.get(stage, stage)}
</div>
</div>
<div style="text-align:right;flex-shrink:0">
<div style="color:#6b7280;font-size:11px">Confidence</div>
<div style="color:{conf_color};font-size:28px;font-weight:700">
{confidence:.0%}
</div>
</div>
</div>
<hr style="border:none;border-top:1px solid #e2e8f0;margin:12px 0">
<p style="color:#374151;font-size:14px;margin:8px 0">
{info.get('description', '')}
</p>
<div style="margin-top:12px">
<div style="color:#6b7280;font-size:11px;text-transform:uppercase;
letter-spacing:1px;margin-bottom:6px">Common Symptoms</div>
<div style="display:flex;flex-wrap:wrap;gap:6px">{symptom_tags}</div>
</div>
<div style="background:{color}0d;border-left:3px solid {color};
padding:10px 14px;margin-top:14px;border-radius:0 8px 8px 0">
<span style="color:{color};font-size:12px;font-weight:600">💡 Guidance: </span>
<span style="color:#374151;font-size:13px">{info.get('guidance', '')}</span>
</div>
<div style="color:#9ca3af;font-size:11px;margin-top:12px">
Model: {model_choice} · {datetime.now().strftime('%Y-%m-%d %H:%M')}
</div>
</div>
"""
        # Confidence note (same 0.8 / 0.6 thresholds as _confidence_color)
        if confidence >= 0.8:
            conf_note = "✅ High confidence — the model is quite certain about this stage."
        elif confidence >= 0.6:
            conf_note = ("⚠️ Moderate confidence — consider providing more feature values "
                         "or consulting a clinician.")
        else:
            conf_note = ("🔴 Low confidence — prediction is uncertain; "
                         "clinical consultation is strongly recommended.")
        # Model comparison panel + run-dir info
        compare_html = f"""
<div class="result-card" style="margin-top:0">
<div style="color:#6b7280;font-size:11px;text-transform:uppercase;
letter-spacing:1px;margin-bottom:10px;font-weight:600">
Model Comparison
</div>
<div class="stat-grid-2">
<div class="stat-item" style="border-top:3px solid #16a34a">
<div style="color:#16a34a;font-size:11px;font-weight:600">Random Forest</div>
<div style="color:#111827;font-size:17px;margin-top:4px">
{STAGE_EMOJI.get(rf_stage,'')} {STAGE_LABELS.get(rf_stage, rf_stage)}
</div>
<div style="color:#6b7280;font-size:12px">
{comparison['RandomForest'].get('confidence', 0):.0%} confidence
</div>
</div>
<div class="stat-item" style="border-top:3px solid #2563eb">
<div style="color:#2563eb;font-size:11px;font-weight:600">
Logistic Regression
</div>
<div style="color:#111827;font-size:17px;margin-top:4px">
{STAGE_EMOJI.get(lr_stage,'')} {STAGE_LABELS.get(lr_stage, lr_stage)}
</div>
<div style="color:#6b7280;font-size:12px">
{comparison['LogisticRegression'].get('confidence', 0):.0%} confidence
</div>
</div>
</div>
<div style="margin-top:10px;padding:8px;border-radius:8px;
background:{'#d1fae5' if agree else '#fef2f2'};
color:{'#065f46' if agree else '#9f1239'};
font-size:13px;text-align:center;font-weight:500">
{"✅ Both models agree — prediction is robust"
if agree else
"⚠️ Models disagree — interpret with caution"}
</div>
<div class="output-path-box">
<div class="output-path-title">📁 Outputs saved to:</div>
<div class="output-path-dir">{run_dir}/</div>
<div class="output-path-files">
charts/stage_probabilities.png<br>
predictions/stage_prediction.csv<br>
reports/prediction_summary.txt
</div>
</div>
</div>
"""
        return stage_html, chart_fig, conf_note, compare_html, str(csv_path)
    except Exception as exc:
        # UI boundary: any failure is reported in the first output, not raised.
        return f"❌ Prediction error: {exc}", None, "", "", None
def predict_batch_stage(file, model_choice):
    """
    Batch stage prediction from uploaded CSV.

    Reads the CSV, checks that at least one column matches the feature names
    stored in the model metadata, runs the batch prediction with
    ``model_choice``, and writes a predictions CSV, a summary chart and a text
    report into a fresh run directory.

    Returns (csv_download_path, summary_html, preview_df); the preview is the
    first 20 result rows. On any problem the path and preview are None and
    the message is in the second element.
    """
    if not _MODEL_OK:
        return None, f"⚠️ {_MODEL_MSG}", None
    if file is None:
        return None, "Please upload a CSV file.", None
    file_path = _get_file_path(file)
    try:
        df = pd.read_csv(file_path)
    except Exception as exc:
        return None, f"Could not read CSV: {exc}", None
    if df.empty:
        return None, "Uploaded CSV is empty.", None
    # Identify ID column (the first candidate present in the CSV wins)
    id_col_candidates = ["individual", "Individual", "ID", "id",
                         "SWANID", "subject", "Subject"]
    id_col = next((c for c in id_col_candidates if c in df.columns), None)
    # Validate features
    feature_names = _metadata.get("feature_names", [])
    matching = [c for c in df.columns if c in feature_names]
    # max(..., 1) guards the division when the metadata lists no features.
    missing_pct = 1 - len(matching) / max(len(feature_names), 1)
    warnings_list = []
    if not matching:
        return None, (
            "❌ No matching feature columns found. "
            "Please include columns from the training feature set "
            "(see 'Feature Reference' tab)."
        ), None
    if missing_pct > 0.5:
        warnings_list.append(
            f"⚠️ {missing_pct:.0%} of training features are missing — "
            "prediction accuracy may be reduced."
        )
    try:
        results = _forecast.predict_batch(df, model=model_choice, return_proba=True)
        # Insert individual ID (synthetic "Row_N" labels when the CSV has no ID column)
        if id_col:
            results.insert(0, "individual", df[id_col].values)
        else:
            results.insert(0, "individual",
                           [f"Row_{i+1}" for i in range(len(results))])
        results["model"] = model_choice
        results["notes"] = ""
        if "confidence" in results.columns:
            low_mask = results["confidence"] < 0.6
            results.loc[low_mask, "notes"] = "Low confidence — review manually"
        # ── Create timestamped run directory ──────────────────────────────────
        run_dir = _make_run_dir()
        # ── Save predictions CSV ──────────────────────────────────────────────
        csv_path = run_dir / "predictions" / "batch_stage_predictions.csv"
        results.to_csv(csv_path, index=False)
        # ── Save confidence/distribution chart (PNG) ──────────────────────────
        chart_path = run_dir / "charts" / "batch_summary_chart.png"
        _make_batch_summary_chart(results, save_path=chart_path)
        # ── Save text report ──────────────────────────────────────────────────
        txt_path = run_dir / "reports" / "batch_summary.txt"
        _write_batch_report(txt_path, results, model_choice, run_dir)
        # ── Build summary HTML ────────────────────────────────────────────────
        total = len(results)
        dist = results["predicted_stage"].value_counts().to_dict()
        mean_conf = results["confidence"].mean() \
            if "confidence" in results.columns else 0.0
        high_conf = int((results["confidence"] > 0.8).sum()) \
            if "confidence" in results.columns else 0
        # One labelled progress bar per stage, always in pre → peri → post order.
        dist_bars = ""
        for stage in ["pre", "peri", "post"]:
            count = dist.get(stage, 0)
            pct = count / total * 100
            dist_bars += f"""
<div style="margin:6px 0">
<div style="display:flex;justify-content:space-between;margin-bottom:2px">
<span style="color:#374151;font-size:13px">
{STAGE_EMOJI.get(stage,'')} {STAGE_LABELS.get(stage, stage)}
</span>
<span style="color:#6b7280;font-size:12px">{count} ({pct:.0f}%)</span>
</div>
<div style="background:#e2e8f0;border-radius:4px;height:8px">
<div style="background:{STAGE_COLORS.get(stage,'#6b7280')};
width:{pct}%;height:8px;border-radius:4px"></div>
</div>
</div>"""
        warn_html = "".join(
            f'<div style="color:#d97706;font-size:12px;margin-top:4px">{w}</div>'
            for w in warnings_list
        )
        summary_html = f"""
<div class="result-card">
<div style="color:#111827;font-size:16px;font-weight:700;margin-bottom:14px">
📊 Batch Results — {total} individuals
</div>
{warn_html}
<div class="stat-grid-3">
<div class="stat-item">
<div class="stat-label">Total</div>
<div class="stat-value">{total}</div>
</div>
<div class="stat-item">
<div class="stat-label">Avg Confidence</div>
<div class="stat-value" style="color:{_confidence_color(mean_conf)}">
{mean_conf:.0%}
</div>
</div>
<div class="stat-item">
<div class="stat-label">High Conf (&gt;80%)</div>
<div class="stat-value" style="color:#16a34a">{high_conf}/{total}</div>
</div>
</div>
<div style="margin-top:12px">{dist_bars}</div>
<div class="output-path-box">
<div class="output-path-title">📁 Outputs saved to:</div>
<div class="output-path-dir">{run_dir}/</div>
<div class="output-path-files">
predictions/batch_stage_predictions.csv<br>
charts/batch_summary_chart.png<br>
reports/batch_summary.txt
</div>
</div>
</div>
"""
        return str(csv_path), summary_html, results.head(20)
    except Exception as exc:
        # UI boundary: report the failure in the summary slot instead of raising.
        return None, f"❌ Batch prediction error: {exc}", None
def predict_symptoms(individual_id, lmp_input, target_date_input, cycle_length):
    """
    Cycle-based symptom forecasting (single person).

    Runs ``SymptomCycleForecaster`` for the given last-menstrual-period value
    and optional target date, then writes a cycle chart, a one-row CSV and a
    text report into a fresh run directory.

    Args:
        individual_id: optional free-text label for the person.
        lmp_input: last menstrual period as entered by the user (required).
        target_date_input: optional forecast date; empty means "today".
        cycle_length: cycle length in days; empty/0 falls back to 28.

    Returns (result_html, chart_fig, csv_download_path). On missing input or
    any error the chart and CSV path are None and the first element carries
    the message. All user-entered text and error text is HTML-escaped before
    it is placed in the returned markup.
    """
    # Local import keeps this block self-contained; the markup below
    # interpolates free-text user input and must escape it.
    from html import escape

    if not lmp_input:
        return "Please enter your Last Menstrual Period date.", None, None
    try:
        cycle_length = int(cycle_length) if cycle_length else 28
        fore = SymptomCycleForecaster(cycle_length=cycle_length)
        target_date = target_date_input if target_date_input else None
        result = fore.predict_single(lmp=lmp_input, target_date=target_date)
        cycle_day = result.get("cycle_day")
        hot_prob = result.get("hotflash_prob", 0)
        hot_pred = result.get("hotflash_pred", False)
        mood_prob = result.get("mood_prob", 0)
        mood_pred = result.get("mood_pred", False)
        # Safe float helpers: None/NaN probabilities are shown as 0.
        hp = float(hot_prob) if (hot_prob is not None and not np.isnan(hot_prob)) else 0.0
        mp = float(mood_prob) if (mood_prob is not None and not np.isnan(mood_prob)) else 0.0
        # ── Create timestamped run directory ──────────────────────────────────
        run_dir = _make_run_dir()
        # ── Save cycle chart (PNG) ────────────────────────────────────────────
        chart_path = run_dir / "charts" / "cycle_position.png"
        chart_fig = _make_cycle_chart(
            cycle_day, cycle_length, hp, mp, save_path=chart_path
        )
        # ── Save forecast CSV ─────────────────────────────────────────────────
        csv_path = run_dir / "predictions" / "symptom_forecast.csv"
        # An LMP that parses as a bare integer is flagged in the notes column.
        lmp_note = ""
        try:
            int(str(lmp_input).strip())
            lmp_note = "LMP inferred as day-of-month; interpret with caution"
        except (ValueError, TypeError):
            pass
        pd.DataFrame([{
            "individual": individual_id or "N/A",
            "LMP": lmp_input,
            "date": target_date_input or datetime.now().strftime("%Y-%m-%d"),
            "cycle_day": cycle_day,
            "hotflash_prob": round(hp, 6),
            "hotflash_pred": bool(hot_pred),
            "mood_prob": round(mp, 6),
            "mood_pred": bool(mood_pred),
            "notes": lmp_note,
        }]).to_csv(csv_path, index=False)
        # ── Save text report ──────────────────────────────────────────────────
        txt_path = run_dir / "reports" / "symptom_summary.txt"
        _write_symptom_report(
            txt_path, individual_id, lmp_input, target_date_input,
            cycle_day, cycle_length, hp, hot_pred, mp, mood_pred,
        )

        # ── Build result HTML ─────────────────────────────────────────────────
        def _prob_bar(prob, label, color):
            pct = min(prob * 100, 100)
            return f"""
<div style="margin:10px 0">
<div style="display:flex;justify-content:space-between;margin-bottom:4px">
<span style="color:#374151;font-size:14px">{label}</span>
<span style="color:{color};font-size:16px;font-weight:700">{pct:.0f}%</span>
</div>
<div style="background:#e2e8f0;border-radius:6px;height:10px">
<div style="background:{color};width:{pct}%;height:10px;
border-radius:6px;transition:width 0.5s"></div>
</div>
</div>"""

        hot_alert = "🔴 Elevated risk" if hot_pred else "🟢 Low risk"
        mood_alert = "🔴 Elevated risk" if mood_pred else "🟢 Low risk"
        # Escape every piece of user-entered text before it reaches the markup.
        safe_heading = escape(str(individual_id)) if individual_id else 'Forecast'
        safe_lmp = escape(str(lmp_input))
        safe_target = escape(str(target_date_input)) if target_date_input else 'Today'
        result_html = f"""
<div class="result-card">
<div style="color:#111827;font-size:18px;font-weight:700;margin-bottom:4px">
{safe_heading} — Cycle Day {cycle_day or '?'}
</div>
<div style="color:#6b7280;font-size:13px;margin-bottom:20px">
LMP: {safe_lmp} | Target: {safe_target}
| Cycle: {cycle_length} days
</div>
{_prob_bar(hp, '🔥 Hot Flash Probability', '#ef4444')}
<div style="color:#6b7280;font-size:12px;margin:-6px 0 10px 2px">{hot_alert}</div>
{_prob_bar(mp, '😤 Mood Change Probability', '#7c3aed')}
<div style="color:#6b7280;font-size:12px;margin:-6px 0 10px 2px">{mood_alert}</div>
<div style="background:#f8fafc;border:1px solid #e2e8f0;border-radius:8px;
padding:12px;margin-top:14px;font-size:12px;color:#6b7280">
ℹ️ Probabilities are computed from a cycle-phase model (Gaussian heuristic).
They represent symptom likelihood based on cycle day, not a clinical diagnosis.
</div>
<div class="output-path-box">
<div class="output-path-title">📁 Outputs saved to:</div>
<div class="output-path-dir">{run_dir}/</div>
<div class="output-path-files">
charts/cycle_position.png<br>
predictions/symptom_forecast.csv<br>
reports/symptom_summary.txt
</div>
</div>
</div>
"""
        return result_html, chart_fig, str(csv_path)
    except Exception as exc:
        # UI boundary: report instead of raising. The message may echo user
        # input (for example an unparsable date), so it is escaped as well.
        return f"❌ Error: {escape(str(exc))}", None, None
def predict_symptoms_batch(file, lmp_col_name, date_col_name, cycle_length):
    """
    Batch symptom forecasting from CSV.

    Reads the uploaded CSV, runs ``SymptomCycleForecaster.predict_df`` on it,
    writes the predictions CSV and a text report into a fresh run directory,
    and builds an HTML summary card.

    Args:
        file: Value of the upload component, or ``None`` when nothing was
            uploaded. It is resolved to a path with ``_get_file_path``.
        lmp_col_name: Name of the CSV column that holds the LMP value.
        date_col_name: Name of the optional target-date column. It is ignored
            when empty or when the CSV has no such column.
        cycle_length: Cycle length in days; a falsy value falls back to 28.

    Returns:
        (csv_download_path, summary_html, preview_df). On failure the first
        and last items are ``None`` and the middle item is an error message.
        On success ``preview_df`` is the full results DataFrame.
    """
    # Function-scope import keeps this block self-contained.
    import html

    def _esc(value):
        # The message is rendered as HTML, so text that originates from the
        # user (column names, parser errors) must not be able to inject
        # markup. quote=False leaves ordinary messages byte-identical.
        return html.escape(str(value), quote=False)

    if file is None:
        return None, "Please upload a CSV file.", None

    file_path = _get_file_path(file)
    try:
        df = pd.read_csv(file_path)
    except Exception as exc:
        return None, f"Could not read CSV: {_esc(exc)}", None

    if lmp_col_name not in df.columns:
        return None, (
            f"LMP column '{_esc(lmp_col_name)}' not found in CSV. "
            f"Columns present: {_esc(list(df.columns))}"
        ), None

    try:
        cycle_length = int(cycle_length) if cycle_length else 28
        fore = SymptomCycleForecaster(cycle_length=cycle_length)
        # Only forward the date column when it really exists in the upload.
        date_col = date_col_name \
            if (date_col_name and date_col_name in df.columns) else None
        results = fore.predict_df(df, lmp_col=lmp_col_name, date_col=date_col)

        # ── Add notes column (flag day-of-month LMP rows) ─────────────────────
        def _lmp_note(val):
            # A value that parses as a bare integer is flagged as a
            # day-of-month entry; anything else gets an empty note.
            try:
                int(str(val).strip())
                return "LMP inferred as day-of-month; interpret with caution"
            except (ValueError, TypeError):
                return ""

        results["notes"] = df[lmp_col_name].apply(_lmp_note)

        # ── Create timestamped run directory ──────────────────────────────────
        run_dir = _make_run_dir()

        # ── Save predictions CSV ──────────────────────────────────────────────
        csv_path = run_dir / "predictions" / "batch_symptom_forecast.csv"
        results.to_csv(csv_path, index=False)

        # ── Save text report ──────────────────────────────────────────────────
        txt_path = run_dir / "reports" / "batch_symptom_summary.txt"
        _write_batch_symptom_report(txt_path, results, cycle_length, run_dir)

        # ── Build summary HTML ────────────────────────────────────────────────
        # Every statistic falls back to zero when its column is missing.
        total = len(results)
        hot_flags = int(results["hotflash_pred"].sum()) \
            if "hotflash_pred" in results.columns else 0
        mood_flags = int(results["mood_pred"].sum()) \
            if "mood_pred" in results.columns else 0
        mean_hot = float(results["hotflash_prob"].mean()) \
            if "hotflash_prob" in results.columns else 0.0
        mean_mood = float(results["mood_prob"].mean()) \
            if "mood_prob" in results.columns else 0.0

        summary_html = f"""
<div class="result-card">
<div style="color:#111827;font-size:16px;font-weight:700;margin-bottom:14px">
🌊 Symptom Forecast — {total} individuals
</div>
<div class="stat-grid-3">
<div class="stat-item">
<div class="stat-label">Total</div>
<div class="stat-value">{total}</div>
</div>
<div class="stat-item">
<div class="stat-label">🔥 Hot Flash Risk</div>
<div class="stat-value" style="color:#ef4444">{hot_flags}</div>
</div>
<div class="stat-item">
<div class="stat-label">😤 Mood Risk</div>
<div class="stat-value" style="color:#7c3aed">{mood_flags}</div>
</div>
</div>
<div class="stat-grid-2">
<div class="stat-item">
<div class="stat-label">Avg Hot Flash Prob</div>
<div class="stat-value" style="color:#ef4444;font-size:18px">
{mean_hot:.1%}
</div>
</div>
<div class="stat-item">
<div class="stat-label">Avg Mood Prob</div>
<div class="stat-value" style="color:#7c3aed;font-size:18px">
{mean_mood:.1%}
</div>
</div>
</div>
<div class="output-path-box">
<div class="output-path-title">📁 Outputs saved to:</div>
<div class="output-path-dir">{run_dir}/</div>
<div class="output-path-files">
predictions/batch_symptom_forecast.csv<br>
reports/batch_symptom_summary.txt
</div>
</div>
</div>
"""
        return str(csv_path), summary_html, results
    except Exception as exc:
        # UI boundary: report the failure in the summary panel instead of
        # letting the exception propagate to the caller.
        return None, f"❌ Error: {_esc(exc)}", None
# ── Feature reference & model status ─────────────────────────────────────────
def get_feature_reference() -> str:
    """
    Render the training-feature table as an HTML fragment.

    Feature names come from ``_metadata["feature_names"]``; when that key is
    absent the keys of ``FEATURE_DESCRIPTIONS`` are used instead. Only the
    first 60 features get their own row; any remainder is summarised in a
    single trailing row. A feature without a description falls back to the
    part of its name before the first underscore.
    """
    feature_names = _metadata.get("feature_names", list(FEATURE_DESCRIPTIONS.keys()))
    visible_limit = 60

    def _row(position, name):
        # One table row: running number, feature code, description.
        description = FEATURE_DESCRIPTIONS.get(name, name.split("_")[0])
        return (
            "\n<tr>\n"
            f'<td class="feature-num">{position}</td>\n'
            f'<td class="feature-code">{name}</td>\n'
            f'<td class="feature-desc">{description}</td>\n'
            "</tr>"
        )

    rows = "".join(
        _row(position, name)
        for position, name in enumerate(feature_names[:visible_limit], start=1)
    )

    hidden_count = len(feature_names) - visible_limit
    if hidden_count > 0:
        rows += (
            "\n<tr>\n"
            '<td colspan="3" style="padding:8px;color:#9ca3af;font-size:12px;text-align:center">\n'
            f"… and {hidden_count} more features (one-hot encoded categories)\n"
            "</td>\n"
            "</tr>"
        )

    return f"""
<div class="feature-table-wrap">
<div style="color:#111827;font-size:16px;font-weight:700;margin-bottom:14px">
📋 Training Features ({len(feature_names)} total after encoding)
</div>
<table>
<thead>
<tr>
<th>#</th>
<th>Feature</th>
<th>Description</th>
</tr>
</thead>
<tbody>{rows}</tbody>
</table>
</div>
"""
def get_model_status() -> str:
    """
    Render the model-status card as an HTML fragment.

    When ``_MODEL_OK`` is truthy the card shows the feature count, a fixed
    model count of 2, the number of stage classes and one coloured badge per
    stage (all read from ``_metadata``). Otherwise it shows ``_MODEL_MSG``
    together with the training instructions.

    ``_MODEL_MSG`` is HTML-escaped before being embedded: it is free-form
    text, and any ``<``, ``>`` or ``&`` in it would otherwise be interpreted
    as markup and garble the card.
    """
    # Function-scope import keeps this block self-contained.
    import html

    if _MODEL_OK:
        fc = len(_metadata.get("feature_names", []))
        sc = _metadata.get("stage_classes", ["pre", "peri", "post"])
        # The two-digit suffixes appended to the colour ("18", "44") turn the
        # 6-digit hex colour into an 8-digit hex value with an alpha channel.
        badges = "".join(
            f'<span style="background:{STAGE_COLORS.get(s,"#607d8b")}18;'
            f'color:{STAGE_COLORS.get(s,"#555")};padding:4px 12px;'
            f'border-radius:20px;border:1px solid {STAGE_COLORS.get(s,"#607d8b")}44;'
            f'font-size:13px;font-weight:600">{STAGE_EMOJI.get(s,"")} {s}</span>'
            for s in sc
        )
        return f"""
<div class="status-card">
<div style="display:flex;align-items:center;gap:10px;margin-bottom:14px">
<span style="font-size:24px">✅</span>
<div>
<div style="color:#059669;font-size:16px;font-weight:700">
Models Loaded Successfully
</div>
<div style="color:#6b7280;font-size:12px">Ready for predictions</div>
</div>
</div>
<div class="stat-grid-3">
<div class="stat-item">
<div class="stat-label">Features</div>
<div class="stat-value">{fc}</div>
</div>
<div class="stat-item">
<div class="stat-label">Models</div>
<div class="stat-value">2</div>
</div>
<div class="stat-item">
<div class="stat-label">Stages</div>
<div class="stat-value">{len(sc)}</div>
</div>
</div>
<div style="margin-top:14px">
<div style="color:#6b7280;font-size:11px;text-transform:uppercase;
letter-spacing:0.5px;margin-bottom:6px">Available Stages</div>
<div style="display:flex;gap:8px;flex-wrap:wrap">{badges}</div>
</div>
</div>
"""

    # quote=False keeps ordinary messages byte-identical; only <, > and &
    # are rewritten.
    safe_msg = html.escape(str(_MODEL_MSG), quote=False)
    return f"""
<div class="status-card">
<div style="display:flex;align-items:center;gap:10px;margin-bottom:10px">
<span style="font-size:24px">⚠️</span>
<div>
<div style="color:#dc2626;font-size:16px;font-weight:700">
Models Not Loaded
</div>
<div style="color:#6b7280;font-size:12px">{safe_msg}</div>
</div>
</div>
<div style="background:#fef2f2;border:1px solid #fecaca;border-radius:8px;
padding:12px;color:#9f1239;font-size:13px">
To train and save models:<br>
<code style="background:#1e293b;color:#a3e635;padding:4px 8px;border-radius:4px;
margin-top:6px;display:inline-block">python menopause.py</code>
<br><br>
This generates <code style="background:#e2e8f0;padding:2px 5px;border-radius:3px;
color:#1e293b">swan_ml_output/rf_pipeline.pkl</code>,
<code style="background:#e2e8f0;padding:2px 5px;border-radius:3px;
color:#1e293b">lr_pipeline.pkl</code>, and
<code style="background:#e2e8f0;padding:2px 5px;border-radius:3px;
color:#1e293b">forecast_metadata.json</code>.
</div>
</div>
"""
# ── Education content ─────────────────────────────────────────────────────────
# Static HTML fragment shown in the "📚 Menopause Education" tab through
# gr.HTML(EDUCATION_HTML) in build_app(). It contains no placeholders and is
# used verbatim. Its layout relies on the .edu-card, .stage-cards-grid,
# .stage-card-pre / -peri / -post and .disclaimer-box rules in CUSTOM_CSS.
EDUCATION_HTML = """
<div class="edu-card">
<h2>🌸 Understanding Menopause</h2>
<p>Menopause is a natural biological process marking the end of menstrual cycles.
It is officially diagnosed after 12 consecutive months without a menstrual period
and typically occurs in women in their late 40s to early 50s.</p>
<h3>Three Stages</h3>
<div class="stage-cards-grid">
<div class="stage-card-pre">
<div style="color:#16a34a;font-weight:700;margin-bottom:8px">🟢 Pre-Menopause</div>
<p style="font-size:13px;margin:0;color:#374151">Regular ovarian function. Periods are predictable.
Hormones (estrogen, progesterone) follow a consistent monthly pattern.</p>
</div>
<div class="stage-card-peri">
<div style="color:#d97706;font-weight:700;margin-bottom:8px">🟡 Peri-Menopause</div>
<p style="font-size:13px;margin:0;color:#374151">Transition phase — usually begins in the mid-40s.
Hormone levels fluctuate. Periods become irregular.
Hot flashes and sleep issues may begin.</p>
</div>
<div class="stage-card-post">
<div style="color:#7c3aed;font-weight:700;margin-bottom:8px">🟣 Post-Menopause</div>
<p style="font-size:13px;margin:0;color:#374151">12+ months after the last period.
Lower estrogen levels. Risk factors for osteoporosis and
cardiovascular disease increase.</p>
</div>
</div>
<h3>Common Symptoms by Stage</h3>
<table style="width:100%;border-collapse:collapse;font-size:13px">
<thead>
<tr style="background:#f8fafc">
<th style="padding:8px;text-align:left;color:#6b7280;font-weight:600">Symptom</th>
<th style="padding:8px;text-align:center;color:#16a34a;font-weight:600">Pre</th>
<th style="padding:8px;text-align:center;color:#d97706;font-weight:600">Peri</th>
<th style="padding:8px;text-align:center;color:#7c3aed;font-weight:600">Post</th>
</tr>
</thead>
<tbody>
<tr style="border-bottom:1px solid #e2e8f0">
<td style="padding:8px;color:#374151">Hot flashes</td>
<td style="text-align:center;color:#9ca3af">–</td>
<td style="text-align:center">✅</td>
<td style="text-align:center">✅</td>
</tr>
<tr style="border-bottom:1px solid #e2e8f0">
<td style="padding:8px;color:#374151">Irregular periods</td>
<td style="text-align:center;color:#9ca3af">–</td>
<td style="text-align:center">✅</td>
<td style="text-align:center;color:#9ca3af">N/A</td>
</tr>
<tr style="border-bottom:1px solid #e2e8f0">
<td style="padding:8px;color:#374151">Sleep disturbances</td>
<td style="text-align:center;color:#6b7280">Mild</td>
<td style="text-align:center">✅</td>
<td style="text-align:center">✅</td>
</tr>
<tr style="border-bottom:1px solid #e2e8f0">
<td style="padding:8px;color:#374151">Mood changes</td>
<td style="text-align:center;color:#6b7280">PMS</td>
<td style="text-align:center">✅</td>
<td style="text-align:center;color:#6b7280">Possible</td>
</tr>
<tr style="border-bottom:1px solid #e2e8f0">
<td style="padding:8px;color:#374151">Vaginal dryness</td>
<td style="text-align:center;color:#9ca3af">–</td>
<td style="text-align:center;color:#6b7280">Possible</td>
<td style="text-align:center">✅</td>
</tr>
<tr>
<td style="padding:8px;color:#374151">Bone density changes</td>
<td style="text-align:center;color:#9ca3af">–</td>
<td style="text-align:center;color:#6b7280">Begins</td>
<td style="text-align:center">✅</td>
</tr>
</tbody>
</table>
<h3>About This Tool</h3>
<p style="font-size:13px">This application uses machine learning models trained on the
SWAN (Study of Women's Health Across the Nation) dataset — a landmark multisite,
multiethnic longitudinal study. The models were trained on self-reported symptom and
behavioral data to predict menopausal stage.</p>
<div class="disclaimer-box">
⚠️ <strong style="color:#d97706">Disclaimer:</strong>
This tool is for educational and research purposes only.
Predictions should not substitute clinical diagnosis.
Always consult a qualified healthcare provider for medical advice.
</div>
</div>
"""
# ── Gradio UI ─────────────────────────────────────────────────────────────────
# Stylesheet handed to gr.Blocks(css=CUSTOM_CSS) in build_app(). It pins a
# light colour scheme (including overrides for body.dark) and defines the
# classes used by the HTML fragments generated in this file (.result-card,
# .stat-grid-*, .output-path-*, .edu-card, .feature-table-wrap, .status-card,
# .app-footer, ...), plus two responsive breakpoints at 768 px and 480 px.
CUSTOM_CSS = """
/* ── Force light mode — disable Gradio dark theme entirely ───────────── */
:root {
color-scheme: light only !important;
}
/* Fallback: if Gradio somehow sets .dark, override every key variable */
body.dark,
body.dark .gradio-container {
--body-background-fill: #f0f4f8 !important;
--background-fill-primary: #ffffff !important;
--background-fill-secondary: #f8fafc !important;
--border-color-primary: #e2e8f0 !important;
--border-color-accent: #3b82f6 !important;
--color-accent: #2563eb !important;
--color-accent-soft: #eff6ff !important;
--input-background-fill: #ffffff !important;
--input-border-color: #d1d5db !important;
--label-text-color: #374151 !important;
--block-label-text-color: #374151 !important;
--block-title-text-color: #111827 !important;
--body-text-color: #111827 !important;
--body-text-color-subdued: #6b7280 !important;
--link-text-color: #2563eb !important;
--button-primary-background-fill: #2563eb !important;
--button-primary-text-color: #ffffff !important;
--button-secondary-background-fill: #ffffff !important;
--button-secondary-text-color: #374151 !important;
--tab-text-color: #374151 !important;
--tab-text-color-selected: #2563eb !important;
color: #111827 !important;
background-color: #f0f4f8 !important;
}
/* ── Core ────────────────────────────────────────────────────────────── */
.gradio-container {
max-width: 1200px !important;
margin: 0 auto !important;
font-family: 'Segoe UI', system-ui, -apple-system, sans-serif !important;
background: #f0f4f8 !important;
}
/* ── Header banner ──────────────────────────────────────────────────── */
.header-banner {
background: linear-gradient(135deg, #faf5ff 0%, #fff0f9 50%, #eff6ff 100%);
border: 1px solid #e9d5ff;
border-radius: 16px;
padding: 28px 32px;
margin-bottom: 20px;
box-shadow: 0 2px 8px rgba(139,92,246,0.08);
position: relative;
overflow: hidden;
}
.header-banner::before {
content: '';
position: absolute;
top: -40%; right: -5%;
width: 280px; height: 280px;
background: radial-gradient(circle, rgba(139,92,246,0.08) 0%, transparent 70%);
pointer-events: none;
}
/* ── Reusable info boxes ─────────────────────────────────────────────── */
.info-box {
background: #f8fafc;
border: 1px solid #e2e8f0;
border-left: 3px solid #3b82f6;
border-radius: 8px;
padding: 12px 16px;
color: #475569;
font-size: 13px;
margin-bottom: 16px;
line-height: 1.5;
}
.info-box code {
background: #e2e8f0;
color: #1e293b;
padding: 1px 5px;
border-radius: 3px;
font-family: monospace;
font-size: 0.9em;
}
.section-label {
color: #2563eb;
font-size: 12px;
font-weight: 700;
text-transform: uppercase;
letter-spacing: 0.6px;
margin-bottom: 10px;
margin-top: 10px;
}
.format-hint {
background: #f8fafc;
border: 1px solid #e2e8f0;
border-radius: 8px;
padding: 14px;
margin-top: 10px;
font-size: 12px;
color: #475569;
}
.format-hint-title { color: #2563eb; font-weight: 600; margin-bottom: 6px; }
.format-hint pre { color: #475569; margin: 0; font-size: 11px; white-space: pre-wrap; }
.format-hint-note { color: #94a3b8; font-size: 11px; margin-top: 8px; }
.placeholder-msg { color: #9ca3af; text-align: center; padding: 40px; font-size: 14px; }
.section-divider { border: none; border-top: 1px solid #e2e8f0; margin: 24px 0; }
.batch-section-label { color: #2563eb; font-size: 14px; font-weight: 600; margin-bottom: 12px; }
/* ── Result & summary cards ─────────────────────────────────────────── */
.result-card {
background: #ffffff;
border: 1px solid #e2e8f0;
border-radius: 16px;
padding: 24px;
box-shadow: 0 1px 4px rgba(0,0,0,0.06);
font-family: 'Segoe UI', system-ui, sans-serif;
}
.stat-grid-3 { display:grid; grid-template-columns:repeat(3,1fr); gap:12px; margin:14px 0; }
.stat-grid-2 { display:grid; grid-template-columns:1fr 1fr; gap:10px; margin-top:10px; }
.stat-item { background:#f8fafc; border:1px solid #e2e8f0; padding:12px; border-radius:8px; text-align:center; }
.stat-label { color:#6b7280; font-size:11px; text-transform:uppercase; letter-spacing:0.4px; }
.stat-value { color:#111827; font-size:22px; font-weight:700; line-height:1.2; margin-top:2px; }
.output-path-box { background:#f0fdf4; border:1px solid #bbf7d0; border-radius:8px; padding:10px 14px; margin-top:12px; font-family:monospace; }
.output-path-title { color:#059669; font-size:12px; font-weight:600; }
.output-path-dir { color:#065f46; font-size:11px; margin-top:4px; }
.output-path-files { color:#6b7280; font-size:10px; margin-top:4px; line-height:1.6; }
/* ── Code blocks ────────────────────────────────────────────────────── */
.code-block {
background: #1e293b;
color: #a3e635;
border-radius: 8px;
padding: 12px;
font-size: 12px;
font-family: monospace;
white-space: pre;
overflow-x: auto;
}
/* ── Setup instructions card ─────────────────────────────────────────── */
.setup-card { background:#ffffff; border:1px solid #e2e8f0; border-radius:12px; padding:20px; margin-top:16px; font-family:'Segoe UI',system-ui,sans-serif; }
.setup-title { color:#111827; font-size:15px; font-weight:700; margin-bottom:12px; }
.setup-step { color:#374151; font-size:13px; line-height:1.8; }
.setup-step strong { color:#2563eb; }
/* ── Education ──────────────────────────────────────────────────────── */
.edu-card { background:#ffffff; border:1px solid #e2e8f0; border-radius:16px; padding:28px; font-family:'Segoe UI',system-ui,sans-serif; color:#374151; line-height:1.7; }
.edu-card h2 { color:#111827; font-size:22px; margin-top:0; }
.edu-card h3 { color:#7c3aed; font-size:16px; margin-top:20px; }
.stage-cards-grid { display:grid; grid-template-columns:repeat(3,1fr); gap:16px; margin:14px 0; }
.stage-card-pre { background:#f0fdf4; border-top:4px solid #16a34a; padding:16px; border-radius:10px; }
.stage-card-peri { background:#fffbeb; border-top:4px solid #d97706; padding:16px; border-radius:10px; }
.stage-card-post { background:#faf5ff; border-top:4px solid #7c3aed; padding:16px; border-radius:10px; }
.disclaimer-box { background:#fffbeb; border-left:3px solid #d97706; padding:12px 16px; border-radius:0 8px 8px 0; margin-top:14px; font-size:13px; color:#374151; }
/* ── Feature reference table ────────────────────────────────────────── */
.feature-table-wrap { background:#ffffff; border:1px solid #e2e8f0; border-radius:12px; padding:20px; max-height:500px; overflow-y:auto; font-family:'Segoe UI',system-ui,sans-serif; }
.feature-table-wrap table { width:100%; border-collapse:collapse; }
.feature-table-wrap thead tr { background:#f8fafc; }
.feature-table-wrap th { padding:8px; color:#6b7280; font-size:11px; text-align:left; text-transform:uppercase; letter-spacing:0.4px; }
.feature-table-wrap tr { border-bottom:1px solid #e2e8f0; }
.feature-table-wrap td { padding:8px; }
.feature-code { color:#2563eb; font-family:monospace; font-size:13px; }
.feature-desc { color:#374151; font-size:12px; }
.feature-num { color:#9ca3af; font-size:12px; }
/* ── Model status card ──────────────────────────────────────────────── */
.status-card { background:#ffffff; border:1px solid #e2e8f0; border-radius:12px; padding:20px; font-family:'Segoe UI',system-ui,sans-serif; }
/* ── Footer ─────────────────────────────────────────────────────────── */
.app-footer { text-align:center; color:#9ca3af; font-size:11px; margin-top:24px; padding:16px; border-top:1px solid #e2e8f0; }
.app-footer a { color:#2563eb; text-decoration:none; }
/* ── Responsive — Tablet (≤ 768 px) ────────────────────────────────── */
@media (max-width: 768px) {
.gradio-container { padding: 8px !important; }
.header-banner { padding: 16px 20px !important; margin-bottom: 12px !important; }
.header-status-badge { display: none !important; }
.stat-grid-3 { grid-template-columns: 1fr !important; }
.stat-grid-2 { grid-template-columns: 1fr !important; }
.stage-cards-grid { grid-template-columns: 1fr !important; }
}
/* ── Responsive — Mobile (≤ 480 px) ────────────────────────────────── */
@media (max-width: 480px) {
.header-banner h1 { font-size: 18px !important; }
.result-card { padding: 16px !important; }
.edu-card { padding: 16px !important; }
.setup-card { padding: 14px !important; }
}
"""
# Banner rendered at the top of the page through gr.HTML(HEADER_HTML) in
# build_app(). The {color} / {status} placeholders are filled by str.format()
# right here, when the module is executed, based on _MODEL_OK; the badge is
# therefore fixed at that moment and is not recomputed afterwards.
# NOTE: because str.format() processes this template, any literal brace added
# to it later must be doubled ({{ / }}).
HEADER_HTML = """
<div class="header-banner">
<div style="display:flex;align-items:center;gap:16px;flex-wrap:wrap">
<div style="font-size:48px;flex-shrink:0">🌸</div>
<div style="flex:1;min-width:200px">
<h1 style="margin:0;font-size:26px;font-weight:800;
background:linear-gradient(135deg,#7c3aed,#db2777);
-webkit-background-clip:text;-webkit-text-fill-color:transparent">
SWAN Menopause Prediction
</h1>
<p style="margin:4px 0 0;color:#6b7280;font-size:13px">
AI-powered menopausal stage prediction &amp; symptom forecasting ·
Based on the SWAN dataset
</p>
</div>
<div class="header-status-badge" style="text-align:right;flex-shrink:0">
<div style="background:#ffffff;border:1px solid #e2e8f0;border-radius:8px;
padding:8px 16px;display:inline-block;box-shadow:0 1px 3px rgba(0,0,0,0.06)">
<div style="color:#9ca3af;font-size:10px;text-transform:uppercase;letter-spacing:1px">
Status
</div>
<div style="color:{color};font-size:13px;font-weight:600">{status}</div>
</div>
</div>
</div>
</div>
""".format(
    # Green badge when the models loaded, red otherwise.
    color = "#059669" if _MODEL_OK else "#dc2626",
    status = "Models Ready ✅" if _MODEL_OK else "Models Needed ⚠️",
)
# ── Force-light-mode JS (runs on every page load) ─────────────────────────────
# Removes Gradio's .dark class, locks localStorage to "light", and uses a
# MutationObserver to prevent the class from being re-applied — works on
# HuggingFace Spaces regardless of the user's OS/browser dark-mode setting.
# JavaScript function source passed as gr.Blocks(js=FORCE_LIGHT_JS) in
# build_app(). The script strips the "dark" class from <body>, writes "light"
# under the localStorage key "theme", and installs a MutationObserver that
# strips the class again whenever <body>'s class attribute changes.
FORCE_LIGHT_JS = """
function() {
const forceLightMode = () => {
if (document.body.classList.contains('dark')) {
document.body.classList.remove('dark');
}
};
// Apply immediately
forceLightMode();
// Lock Gradio's stored preference
try { localStorage.setItem('theme', 'light'); } catch(e) {}
// Watch for Gradio trying to re-add .dark and block it
new MutationObserver(function(mutations) {
mutations.forEach(function(m) {
if (m.attributeName === 'class') forceLightMode();
});
}).observe(document.body, { attributes: true, attributeFilter: ['class'] });
}
"""
# ── App builder ───────────────────────────────────────────────────────────────
def _build_light_theme():
    """Return the Soft theme with one light palette applied to both colour modes."""
    # Single source of truth for the palette. Every entry is applied twice:
    # once under its own name and once under "<name>_dark", so the light and
    # dark variants cannot drift apart.
    palette = {
        # ── Body ──────────────────────────────────────────────────────
        "body_background_fill": "#f0f4f8",
        "body_text_color": "#111827",
        "body_text_color_subdued": "#6b7280",
        # ── Panel / block backgrounds ──────────────────────────────────
        "background_fill_primary": "#ffffff",
        "background_fill_secondary": "#f8fafc",
        "block_background_fill": "#ffffff",
        "block_border_color": "#e2e8f0",
        "block_label_background_fill": "#f8fafc",
        "block_label_text_color": "#374151",
        "block_title_text_color": "#111827",
        # ── Inputs ────────────────────────────────────────────────────
        "input_background_fill": "#ffffff",
        "input_background_fill_focus": "#ffffff",
        "input_border_color": "#d1d5db",
        "input_border_color_focus": "#3b82f6",
        "input_placeholder_color": "#9ca3af",
        # ── Borders ────────────────────────────────────────────────────
        "border_color_primary": "#e2e8f0",
        "border_color_accent": "#3b82f6",
        # ── Buttons ────────────────────────────────────────────────────
        "button_primary_background_fill": "#2563eb",
        "button_primary_background_fill_hover": "#1d4ed8",
        "button_primary_text_color": "#ffffff",
        "button_secondary_background_fill": "#ffffff",
        "button_secondary_background_fill_hover": "#f1f5f9",
        "button_secondary_text_color": "#374151",
        "button_secondary_border_color": "#e2e8f0",
        # ── Checkbox / Radio ──────────────────────────────────────────
        "checkbox_background_color": "#ffffff",
        "checkbox_background_color_selected": "#2563eb",
        "checkbox_border_color": "#d1d5db",
        "checkbox_border_color_focus": "#3b82f6",
        # ── Slider ────────────────────────────────────────────────────
        "slider_color": "#2563eb",
        # ── Table ─────────────────────────────────────────────────────
        "table_odd_background_fill": "#f8fafc",
        "table_even_background_fill": "#ffffff",
        "table_border_color": "#e2e8f0",
        # ── Links ─────────────────────────────────────────────────────
        "link_text_color": "#2563eb",
        "link_text_color_hover": "#1d4ed8",
        "link_text_color_visited": "#7c3aed",
        # ── Accent ────────────────────────────────────────────────────
        "color_accent_soft": "#eff6ff",
    }
    overrides = dict(palette)
    overrides.update({f"{name}_dark": value for name, value in palette.items()})
    return gr.themes.Soft(
        primary_hue="blue",
        neutral_hue="slate",
    ).set(**overrides)


def _build_stage_tab():
    """Lay out the single-individual stage form and wire its predict button."""
    gr.HTML("""
<div class="info-box">
Fill in the fields below to predict menopausal stage for a single individual.
All fields are optional — the pipeline handles missing values automatically.
A timestamped output folder is created in
<code>swan_ml_output/</code> for every run.
</div>""")
    with gr.Row():
        # ── Input column ──────────────────────────────────────────
        with gr.Column(scale=2):
            with gr.Group():
                gr.HTML('<div class="section-label">Demographics</div>')
                with gr.Row():
                    age = gr.Slider(
                        minimum=35, maximum=75, value=48, step=1,
                        label="Age (AGE7)",
                    )
                    race = gr.Dropdown(
                        choices=[1, 2, 3, 4, 5], value=1,
                        label="Race (RACE)",
                        info="1=White, 2=Black, 3=Chinese, 4=Japanese, 5=Hispanic",
                    )
                langint = gr.Dropdown(
                    choices=[1, 2, 3], value=1,
                    label="Interview Language (LANGINT7)",
                    info="1=English, 2=Spanish, 3=Other",
                )
            with gr.Group():
                gr.HTML('<div class="section-label">Vasomotor Symptoms</div>')
                with gr.Row():
                    hot_flash = gr.Slider(
                        minimum=1, maximum=5, value=1, step=1,
                        label="Hot Flash Severity (HOTFLAS7)",
                        info="1=None, 5=Very severe",
                    )
                    num_hot_flash = gr.Slider(
                        minimum=0, maximum=15, value=0, step=1,
                        label="# Hot Flashes/Week (NUMHOTF7)",
                    )
                bothersome_hf = gr.Slider(
                    minimum=1, maximum=4, value=1, step=1,
                    label="How Bothersome (BOTHOTF7)",
                    info="1=Not at all, 4=Extremely",
                )
            with gr.Group():
                gr.HTML('<div class="section-label">Sleep &amp; Mood</div>')
                with gr.Row():
                    sleep_quality = gr.Slider(
                        minimum=1, maximum=5, value=2, step=1,
                        label="Sleep Quality (SLEEPQL7)",
                        info="1=Very good, 5=Very poor",
                    )
                    depression = gr.Slider(
                        minimum=0, maximum=4, value=0, step=1,
                        label="Depression Indicator (DEPRESS7)",
                        info="0=No, higher=more severe",
                    )
                with gr.Row():
                    mood_change = gr.Slider(
                        minimum=1, maximum=5, value=1, step=1,
                        label="Mood Changes (MOODCHG7)",
                        info="1=None, 5=Severe",
                    )
                    irritability = gr.Slider(
                        minimum=1, maximum=5, value=1, step=1,
                        label="Irritability (IRRITAB7)",
                    )
            with gr.Group():
                gr.HTML('<div class="section-label">Physical &amp; Gynaecological</div>')
                with gr.Row():
                    pain = gr.Slider(
                        minimum=0, maximum=5, value=0, step=1,
                        label="Pain Indicator (PAIN17)",
                    )
                    abbleed = gr.Dropdown(
                        choices=[0, 1, 2], value=0,
                        label="Abnormal Bleeding (ABBLEED7)",
                        info="0=No, 1=Yes, 2=Unsure",
                    )
                with gr.Row():
                    vaginal_dryness = gr.Slider(
                        minimum=0, maximum=5, value=0, step=1,
                        label="Vaginal Dryness (VAGINDR7)",
                    )
                    lmp_day = gr.Number(
                        value=None,
                        label="LMP Day (LMPDAY7)",
                        info="Day of last menstrual period (optional)",
                    )
            model_choice = gr.Radio(
                choices=["RandomForest", "LogisticRegression"],
                value="RandomForest",
                label="Model",
                info="RandomForest: higher accuracy | "
                     "LogisticRegression: more interpretable",
            )
            predict_btn = gr.Button(
                "🔮 Predict Stage", variant="primary", size="lg"
            )
        # ── Output column ─────────────────────────────────────────
        with gr.Column(scale=3):
            result_html = gr.HTML(
                '<div class="placeholder-msg">Fill in the form and click Predict Stage</div>'
            )
            result_chart = gr.Plot(label="Stage Probabilities")
            confidence_note = gr.Textbox(
                label="Confidence Note", interactive=False, lines=2
            )
            compare_html = gr.HTML()
            stage_download = gr.File(
                label="Download Prediction CSV", interactive=False
            )
    # The input order must match predict_single_stage's positional parameters.
    predict_btn.click(
        fn=predict_single_stage,
        inputs=[
            age, race, langint,
            hot_flash, num_hot_flash, bothersome_hf,
            sleep_quality, depression, mood_change, irritability,
            pain, abbleed, vaginal_dryness, lmp_day,
            model_choice,
        ],
        outputs=[
            result_html, result_chart, confidence_note,
            compare_html, stage_download,
        ],
    )


def _build_batch_stage_tab():
    """Lay out the CSV batch stage prediction controls and wire the run button."""
    gr.HTML("""
<div class="info-box">
Upload a CSV file with individual feature values for batch prediction.
Results + charts + a summary report are saved to a timestamped folder
inside <code>swan_ml_output/</code>.
</div>""")
    with gr.Row():
        with gr.Column(scale=1):
            batch_file = gr.File(
                label="Upload stage_input.csv",
                file_types=[".csv"],
            )
            batch_model = gr.Radio(
                choices=["RandomForest", "LogisticRegression"],
                value="RandomForest",
                label="Model",
            )
            gr.HTML("""
<div class="format-hint">
<div class="format-hint-title">Expected CSV Format</div>
<pre>individual,AGE7,RACE,HOTFLAS7,...
Person_001,48,1,2,...
Person_002,52,2,1,...</pre>
<div class="format-hint-note">
See the test-csv/ folder for an approved example.
</div>
</div>""")
            batch_predict_btn = gr.Button(
                "🚀 Run Batch Prediction", variant="primary"
            )
        with gr.Column(scale=2):
            batch_summary_html = gr.HTML(
                '<div class="placeholder-msg">Upload a CSV to begin</div>'
            )
            batch_download = gr.File(
                label="Download Predictions CSV", interactive=False
            )
            batch_results_df = gr.DataFrame(
                label="Results Preview (first 20 rows)",
                interactive=False,
            )
    batch_predict_btn.click(
        fn=predict_batch_stage,
        inputs=[batch_file, batch_model],
        outputs=[batch_download, batch_summary_html, batch_results_df],
    )


def _build_symptom_tab():
    """Lay out the single and batch symptom forecast sections and wire both buttons."""
    gr.HTML("""
<div class="info-box">
Predict hot flash and mood change probability based on cycle day
(calculated from Last Menstrual Period date).
All outputs are saved to a timestamped folder inside
<code>swan_ml_output/</code>.
</div>""")
    # Single-individual forecast.
    with gr.Row():
        with gr.Column(scale=1):
            sym_individual = gr.Textbox(
                label="Individual ID (optional)",
                placeholder="e.g., Patient_001",
            )
            sym_lmp = gr.Textbox(
                label="Last Menstrual Period (LMP)",
                placeholder="2026-01-15 or 15 (day of month)",
                info="Full date (YYYY-MM-DD) or day-of-month integer",
            )
            sym_date = gr.Textbox(
                label="Target Date (optional)",
                placeholder="2026-02-27 (defaults to today)",
                info="Date to forecast for (YYYY-MM-DD)",
            )
            sym_cycle = gr.Slider(
                minimum=21, maximum=40, value=28, step=1,
                label="Cycle Length (days)",
            )
            sym_predict_btn = gr.Button(
                "🌊 Forecast Symptoms", variant="primary"
            )
        with gr.Column(scale=2):
            sym_result_html = gr.HTML(
                '<div class="placeholder-msg">Enter LMP date and click Forecast</div>'
            )
            sym_chart = gr.Plot(label="Cycle Position")
            sym_download = gr.File(
                label="Download Forecast CSV", interactive=False
            )
    sym_predict_btn.click(
        fn=predict_symptoms,
        inputs=[sym_individual, sym_lmp, sym_date, sym_cycle],
        outputs=[sym_result_html, sym_chart, sym_download],
    )

    # Batch forecast from an uploaded CSV.
    gr.HTML('<hr class="section-divider">')
    gr.HTML('<div class="batch-section-label">📁 Batch Symptom Forecasting</div>')
    with gr.Row():
        with gr.Column(scale=1):
            sym_batch_file = gr.File(
                label="Upload symptoms_input.csv",
                file_types=[".csv"],
            )
            sym_lmp_col = gr.Textbox(
                label="LMP Column Name", value="LMP"
            )
            sym_date_col = gr.Textbox(
                label="Date Column Name (optional)", value="date"
            )
            sym_cycle_batch = gr.Slider(
                minimum=21, maximum=40, value=28, step=1,
                label="Default Cycle Length",
            )
            sym_batch_btn = gr.Button(
                "🌊 Run Batch Forecast", variant="primary"
            )
        with gr.Column(scale=2):
            sym_batch_summary = gr.HTML(
                '<div class="placeholder-msg">Upload a CSV to begin</div>'
            )
            sym_batch_download = gr.File(
                label="Download Symptom Forecast CSV", interactive=False
            )
            sym_batch_df = gr.DataFrame(
                label="Results Preview",
                interactive=False,
            )
    sym_batch_btn.click(
        fn=predict_symptoms_batch,
        inputs=[
            sym_batch_file, sym_lmp_col,
            sym_date_col, sym_cycle_batch,
        ],
        outputs=[sym_batch_download, sym_batch_summary, sym_batch_df],
    )


def _build_feature_reference_tab():
    """Show the intro note followed by the generated feature table."""
    gr.HTML("""
<div class="info-box">
Canonical list of features used by the trained models
(from <code>forecast_metadata.json</code>).
For batch CSV uploads, column names must match these feature names.
</div>""")
    gr.HTML(get_feature_reference())


def _build_model_status_tab():
    """Show the model status card followed by the static setup instructions."""
    gr.HTML(get_model_status())
    gr.HTML("""
<div class="setup-card">
<div class="setup-title">🚀 Setup Instructions</div>
<div class="setup-step">
<p><strong>Step 1 — Train models:</strong></p>
<pre class="code-block">python menopause.py</pre>
<p><strong>Step 2 — Verify artifacts:</strong></p>
<pre class="code-block">ls swan_ml_output/
# rf_pipeline.pkl lr_pipeline.pkl forecast_metadata.json</pre>
<p><strong>Step 3 — Run this app:</strong></p>
<pre class="code-block">python app.py</pre>
<p><strong>Step 4 — Deploy on Hugging Face Spaces:</strong></p>
<pre class="code-block">git lfs install
git lfs track "*.pkl"
git add .
git commit -m "SWAN menopause prediction app"
git push</pre>
<p><strong>Output folder structure (per run):</strong></p>
<pre class="code-block">swan_ml_output/
&lt;YYYYMMDD_HHMMSS&gt;/
charts/ &larr; PNG visualizations
predictions/ &larr; CSV result files
reports/ &larr; TXT summary reports</pre>
</div>
</div>
""")


def build_app():
    """
    Assemble the Gradio Blocks application and return it.

    The page consists of the header banner, six tabs (stage prediction, batch
    stage prediction, symptom forecast, education, feature reference, model
    status) and a footer. Each interactive tab is laid out by its own private
    builder, which creates the components inside the currently open tab and
    wires the button to its handler. The feature-reference and model-status
    HTML is generated once, while the app is being built.

    Returns:
        The constructed ``gr.Blocks`` instance (not yet launched).
    """
    with gr.Blocks(
        css=CUSTOM_CSS,
        js=FORCE_LIGHT_JS,
        title="SWAN Menopause Prediction",
        theme=_build_light_theme(),
    ) as app:
        gr.HTML(HEADER_HTML)
        with gr.Tabs():
            # ── TAB 1: Single Stage Prediction ────────────────────────────────
            with gr.Tab("🔮 Stage Prediction"):
                _build_stage_tab()
            # ── TAB 2: Batch Stage Prediction ─────────────────────────────────
            with gr.Tab("📁 Batch Stage Prediction"):
                _build_batch_stage_tab()
            # ── TAB 3: Symptom Forecast ───────────────────────────────────────
            with gr.Tab("🌊 Symptom Forecast"):
                _build_symptom_tab()
            # ── TAB 4: Education ──────────────────────────────────────────────
            with gr.Tab("📚 Menopause Education"):
                gr.HTML(EDUCATION_HTML)
            # ── TAB 5: Feature Reference ──────────────────────────────────────
            with gr.Tab("🔬 Feature Reference"):
                _build_feature_reference_tab()
            # ── TAB 6: Model Status ───────────────────────────────────────────
            with gr.Tab("⚙️ Model Status"):
                _build_model_status_tab()
        gr.HTML("""
<div class="app-footer">
SWAN Menopause Prediction App · Built with Gradio ·
For research &amp; educational use only · Not for clinical diagnosis ·
<a href="https://www.swanstudy.org/" target="_blank">SWAN Study</a>
</div>""")
    return app
# ── Entry point ───────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Script entry point: build the Gradio UI, then serve it.
    demo = build_app()

    # Launch settings are gathered in one place so they are easy to scan.
    launch_options = {
        # Listen on every interface (needed when running inside a container).
        "server_name": "0.0.0.0",
        # Honour $PORT when the host provides one; otherwise fall back to 7860.
        "server_port": int(os.environ.get("PORT", 7860)),
        # Do not request a public share link.
        "share": False,
        # Ask Gradio to surface handler errors instead of hiding them.
        "show_error": True,
    }
    demo.launch(**launch_options)