Spaces:
Sleeping
Sleeping
"""
SWAN Menopause Stage Prediction & Forecasting — Gradio UI
Hugging Face Spaces deployment-ready.

Run locally: python app.py
Deploy: Push to a HF Space with SDK=gradio

Output structure (per execution):
    swan_ml_output/
        <YYYYMMDD_HHMMSS>/
            charts/        ← PNG visualizations
            predictions/   ← CSV result files
            reports/       ← TXT summary reports
"""
| import os | |
| import json | |
| import warnings | |
| from datetime import datetime | |
| from pathlib import Path | |
| from typing import Optional | |
| import numpy as np | |
| import pandas as pd | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| warnings.filterwarnings("ignore") | |
| # ── Gradio ──────────────────────────────────────────────────────────────────── | |
| import gradio as gr | |
| # ── Local ML module ─────────────────────────────────────────────────────────── | |
# The ML module is imported defensively so that a missing ``menopause.py`` is
# reported through ``_MODULE_AVAILABLE`` instead of aborting the import.
try:
    from menopause import (
        MenopauseForecast,
        SymptomCycleForecaster,
        load_forecast_model,
    )
    _MODULE_AVAILABLE = True
except ImportError:
    _MODULE_AVAILABLE = False

# ── Model loading ─────────────────────────────────────────────────────────────
# Directory holding the trained artifacts; it doubles as the root under which
# per-run output folders are created.
FORECAST_DIR = os.environ.get("FORECAST_DIR", "swan_ml_output")
OUTPUT_BASE = Path(FORECAST_DIR)

# The annotation is a string on purpose: module-level annotations are evaluated
# at import time, and ``MenopauseForecast`` does not exist when the import
# above failed — an unquoted name would raise NameError right here.
_forecast: Optional["MenopauseForecast"] = None
_metadata: dict = {}
def _load_models():
    """Load the saved forecast model and its metadata from ``FORECAST_DIR``.

    On success the module globals ``_forecast`` and ``_metadata`` are
    replaced together; on any failure both are left untouched.

    Returns:
        A ``(success, message)`` tuple. ``message`` is a human-readable
        status or error text.
    """
    global _forecast, _metadata
    if not _MODULE_AVAILABLE:
        return False, "menopause.py not found. Make sure it is in the same directory."

    artifact_dir = Path(FORECAST_DIR)
    meta_path = artifact_dir / "forecast_metadata.json"
    required = (
        meta_path,
        artifact_dir / "rf_pipeline.pkl",
        artifact_dir / "lr_pipeline.pkl",
    )
    if not all(p.exists() for p in required):
        return (
            False,
            f"Model artifacts not found in '{FORECAST_DIR}'. "
            "Run `python menopause.py` to train and save the models first.",
        )

    try:
        forecast = load_forecast_model(FORECAST_DIR)
        # JSON is UTF-8 by specification; do not rely on the platform's
        # default text encoding.
        with open(meta_path, encoding="utf-8") as fh:
            metadata = json.load(fh)
    except Exception as exc:  # boundary: any load failure becomes a status message
        return False, f"Error loading models: {exc}"

    # Publish both globals only after everything loaded, so a broken metadata
    # file cannot leave a model without its metadata.
    _forecast, _metadata = forecast, metadata
    return True, f"✅ Models loaded — {len(_metadata.get('feature_names', []))} features"


# Loaded once at import time; the prediction handlers check these flags.
_MODEL_OK, _MODEL_MSG = _load_models()
| # ── Output directory management ─────────────────────────────────────────────── | |
def _make_run_dir() -> Path:
    """Create and return a unique run directory under ``OUTPUT_BASE``.

    The directory is named ``<YYYYMMDD_HHMMSS>``. If that name is already
    taken (two runs within the same second), a numeric suffix ``_1``,
    ``_2``, … is appended so runs never share or overwrite each other's
    files. The ``charts``, ``predictions`` and ``reports`` sub-directories
    are created inside it.

    Returns:
        Path of the newly created run directory.
    """
    OUTPUT_BASE.mkdir(parents=True, exist_ok=True)
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    run_dir = OUTPUT_BASE / stamp
    attempt = 0
    while True:
        try:
            # exist_ok=False makes the mkdir itself the uniqueness check, so
            # two concurrent requests cannot both claim the same name.
            run_dir.mkdir(exist_ok=False)
            break
        except FileExistsError:
            attempt += 1
            run_dir = OUTPUT_BASE / f"{stamp}_{attempt}"

    for sub in ("charts", "predictions", "reports"):
        (run_dir / sub).mkdir(exist_ok=True)
    return run_dir
def _get_file_path(file_obj) -> Optional[str]:
    """Extract a file-system path from a Gradio file component value.

    Gradio ≤ 3.x → returns a file-like object with a .name attribute.
    Gradio 4.x → returns a str path (or NamedString subclass).
    Path-like objects (e.g. ``pathlib.Path``) are accepted as well.

    Args:
        file_obj: Value received from the file component, or ``None``.

    Returns:
        The path as a string, or ``None`` when no file was supplied.
    """
    if file_obj is None:
        return None
    # Path-like objects must be handled before the ``.name`` check:
    # ``pathlib.Path.name`` is only the final component, not the full path.
    if isinstance(file_obj, os.PathLike):
        return os.fsdecode(file_obj)
    if hasattr(file_obj, "name"):
        return file_obj.name
    return str(file_obj)
| # ── Constants & helpers ─────────────────────────────────────────────────────── | |
# Presentation lookups, all keyed by the stage codes "pre" / "peri" / "post".
STAGE_COLORS = dict(pre="#16a34a", peri="#d97706", post="#7c3aed")  # hex colour per stage
STAGE_EMOJI = dict(pre="🟢", peri="🟡", post="🟣")  # marker shown next to the stage
STAGE_LABELS = dict(  # human-readable stage names
    pre="Pre-Menopausal",
    peri="Peri-Menopausal",
    post="Post-Menopausal",
)
# Explanatory copy for each stage code: a title, a short description, a list
# of typical symptoms and a one-line guidance text.
STAGE_INFO = dict(
    pre=dict(
        title="Pre-Menopausal",
        description="Regular menstrual cycles with typical hormonal fluctuations. Ovarian function is normal.",
        symptoms=["Regular periods", "Normal hormone levels", "Potential mild PMS"],
        guidance="Maintain regular check-ups. Track your cycle and note any changes.",
    ),
    peri=dict(
        title="Peri-Menopausal (Transition)",
        description="Hormonal changes begin — estrogen and progesterone levels fluctuate. Cycles become irregular.",
        symptoms=["Irregular periods", "Hot flashes", "Sleep disturbances", "Mood changes", "Night sweats"],
        guidance="Consult your healthcare provider. Lifestyle adjustments (diet, exercise, sleep) can help.",
    ),
    post=dict(
        title="Post-Menopausal",
        description="12+ months since last menstrual period. Estrogen remains at consistently lower levels.",
        symptoms=["No periods", "Possible continued hot flashes", "Vaginal dryness", "Bone density changes"],
        guidance="Focus on bone health, cardiovascular health, and regular screenings. Discuss HRT options.",
    ),
)
| # Feature descriptions keyed by the model's canonical feature names | |
# Human-readable description for each feature column name. Insertion order is
# kept exactly as before, so anything that iterates the mapping sees the same
# sequence.
FEATURE_DESCRIPTIONS = {
    "PAIN17": "Pain indicator (visit-specific)",
    "PAINTW17": "Pain two-week indicator",
    "PAIN27": "Secondary pain indicator",
    "PAINTW27": "Secondary pain two-week indicator",
    "SLEEP17": "Sleep disturbance pattern 1",
    "SLEEP27": "Sleep disturbance pattern 2",
    "BCOHOTH7": "Birth control — other method",
    "EXERCIS7": "General exercise indicator",
    "EXERHAR7": "Vigorous exercise",
    "EXEROST7": "Osteoporosis exercise",
    "EXERMEN7": "Exercise — mental health",
    "EXERLOO7": "Exercise lookalike",
    "EXERMEM7": "Exercise — memory",
    "EXERPER7": "Exercise perception",
    "EXERGEN7": "General exercise type",
    "EXERWGH7": "Weight exercise",
    "EXERADV7": "Exercise advice indicator",
    "EXEROTH7": "Other exercise",
    "EXERSPE7": "Specific exercise",
    "ABBLEED7": "Abnormal bleeding (0=no, 1=yes)",
    "BLEEDNG7": "Bleeding pattern",
    "LMPDAY7": "Last menstrual period day",
    "DEPRESS7": "Depression indicator",
    # SEX17 … SEX127 share one naming pattern: "SEX" + item number + "7".
    **{f"SEX{item}7": f"Sexual activity indicator {item}" for item in range(1, 13)},
    "SMOKERE7": "Smoking status",
    "HOTFLAS7": "Hot flash severity (1=none, 5=very severe)",
    "NUMHOTF7": "Number of hot flashes per week",
    "BOTHOTF7": "How bothersome are hot flashes",
    "IRRITAB7": "Irritability level",
    "VAGINDR7": "Vaginal dryness",
    "MOODCHG7": "Mood change frequency",
    "SLEEPQL7": "Sleep quality score",
    "PHYSILL7": "Physical illness indicators",
    "HOTHEAD7": "Hot flashes with headache",
    "EXER12H7": "Exercise in last 12 hours",
    "ALCO24H7": "Alcohol in last 24h",
    "AGE7": "Age (years)",
    "RACE": "Race (1=White, 2=Black, 3=Chinese, 4=Japanese, 5=Hispanic)",
    "LANGINT7": "Interview language indicator",
}
def _confidence_color(conf: float) -> str:
    """Return the hex colour used to display a confidence score.

    Scores of at least 0.8 are green, scores of at least 0.6 are amber and
    everything else (including NaN, which fails both comparisons) is red.
    """
    thresholds = ((0.8, "#16a34a"), (0.6, "#d97706"))
    for floor, colour in thresholds:
        if conf >= floor:
            return colour
    return "#dc2626"
| # ── Chart builders ──────────────────────────────────────────────────────────── | |
def _make_proba_chart(
    probabilities: dict,
    predicted_stage: str,
    save_path: Optional[Path] = None,
) -> plt.Figure:
    """Draw a horizontal bar chart of the per-stage probabilities.

    Args:
        probabilities: Mapping of stage code to probability. Values are
            multiplied by 100 for display.
        predicted_stage: Stage code whose bar is outlined in white.
        save_path: If given, the chart is also written there as a PNG.

    Returns:
        The Matplotlib figure. It is already released from pyplot's figure
        registry, so callers do not need to close it.
    """
    fig, ax = plt.subplots(figsize=(6, 3.5))
    try:
        fig.patch.set_facecolor("#1a1a2e")
        ax.set_facecolor("#16213e")

        stages = list(probabilities)
        percents = [probabilities[s] * 100 for s in stages]
        is_winner = [s == predicted_stage for s in stages]

        bars = ax.barh(
            stages,
            percents,
            color=[STAGE_COLORS.get(s, "#607d8b") for s in stages],
            edgecolor=["white" if win else "none" for win in is_winner],
            linewidth=[2.5 if win else 0 for win in is_winner],
            height=0.5,
            zorder=3,
        )
        for bar, pct in zip(bars, percents):
            ax.text(
                # Keep the label inside the 0–105 x-range even for ~100% bars.
                min(pct + 1, 98), bar.get_y() + bar.get_height() / 2,
                f"{pct:.1f}%",
                va="center", ha="left", color="white", fontsize=11, fontweight="bold",
            )

        ax.set_yticks(range(len(stages)))
        ax.set_yticklabels([STAGE_LABELS.get(s, s) for s in stages],
                           color="white", fontsize=10)
        ax.set_xlim(0, 105)
        ax.tick_params(colors="white", labelsize=11)
        # All spines and the x axis are hidden, so no further spine or
        # x-grid styling is applied.
        ax.spines[["top", "right", "left", "bottom"]].set_visible(False)
        ax.xaxis.set_visible(False)
        ax.set_title("Stage Probabilities", color="white", fontsize=12,
                     pad=10, fontweight="bold")
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=150, bbox_inches="tight",
                        facecolor=fig.get_facecolor())
    finally:
        # pyplot keeps every figure it creates alive until it is closed; in a
        # long-running app that leaks one figure per request. The Figure
        # object itself stays usable for rendering after this call.
        plt.close(fig)
    return fig
def _make_cycle_chart(
    cycle_day: int,
    cycle_length: int = 28,
    hot_prob: Optional[float] = None,
    mood_prob: Optional[float] = None,
    save_path: Optional[Path] = None,
) -> plt.Figure:
    """Draw a circular (polar) view of the position within the cycle.

    Args:
        cycle_day: 1-based day within the cycle to mark; ``None`` draws the
            ring without a marker.
        cycle_length: Number of days represented by the full circle.
        hot_prob: Hot-flash probability shown in the title, if given.
        mood_prob: Mood-change probability shown in the title, if given.
        save_path: If given, the chart is also written there as a PNG.

    Returns:
        The Matplotlib figure. It is already released from pyplot's figure
        registry, so callers do not need to close it.

    Raises:
        ValueError: If ``cycle_length`` is smaller than 1.
    """
    if cycle_length < 1:
        raise ValueError(f"cycle_length must be at least 1, got {cycle_length}")

    full_turn = 2 * np.pi
    fig, ax = plt.subplots(figsize=(5, 5), subplot_kw=dict(polar=True))
    try:
        fig.patch.set_facecolor("#1a1a2e")
        ax.set_facecolor("#16213e")

        # One coloured wedge per day, shaded along the RdYlGn colormap.
        day_angles = np.linspace(0, full_turn, cycle_length, endpoint=False)
        phases = np.arange(cycle_length) / cycle_length
        ax.bar(day_angles, 1, width=full_turn / cycle_length * 0.9,
               bottom=0.5, color=plt.cm.RdYlGn(1 - phases), alpha=0.4, zorder=1)

        if cycle_day is not None:
            angle = (cycle_day - 1) / cycle_length * full_turn
            ax.scatter([angle], [1.05], s=200, color="#ff6b6b", zorder=5, linewidths=2)
            ax.annotate(
                f"Day {cycle_day}",
                xy=(angle, 1.05), xytext=(0, 0),
                textcoords="offset points", ha="center", va="center",
                color="white", fontsize=12, fontweight="bold",
            )

        ax.set_rticks([])
        # Label the quarter marks with the day that is actually drawn there.
        # Day d sits at angle (d - 1) / cycle_length, so a quarter turn of a
        # 28-day cycle is day 8, not day 7, and other lengths differ too.
        quarter_days = [int(round(q * cycle_length / 4)) + 1 for q in range(4)]
        ax.set_xticks([(d - 1) / cycle_length * full_turn for d in quarter_days])
        ax.set_xticklabels([f"Day {d}" for d in quarter_days],
                           color="#aaa", fontsize=9)
        ax.set_yticklabels([])
        ax.spines["polar"].set_color("#333")
        ax.grid(color="#333", linewidth=0.5)

        # Each probability is optional on its own; formatting None with
        # ``:.0%`` would raise TypeError.
        badges = []
        if hot_prob is not None:
            badges.append(f"🔥 {hot_prob:.0%}")
        if mood_prob is not None:
            badges.append(f"😤 {mood_prob:.0%}")
        title = "Cycle Position"
        if badges:
            title += "\n" + " ".join(badges)
        ax.set_title(title, color="white", fontsize=11, pad=20, fontweight="bold")
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=150, bbox_inches="tight",
                        facecolor=fig.get_facecolor())
    finally:
        # Release the figure from pyplot's registry so figures do not pile up
        # across requests; the Figure object stays usable for rendering.
        plt.close(fig)
    return fig
def _make_batch_summary_chart(results_df: pd.DataFrame,
                              save_path: Optional[Path] = None) -> None:
    """Render the batch summary figure: stage pie chart + confidence histogram.

    Args:
        results_df: Batch results. The ``predicted_stage`` column feeds the
            pie chart and the ``confidence`` column feeds the histogram; a
            missing or empty column leaves its panel without data.
        save_path: If given, the figure is written there as a PNG.

    The figure is always closed before returning, also when drawing or
    saving fails.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
    try:
        fig.patch.set_facecolor("#1a1a2e")

        # Stage distribution pie
        ax1.set_facecolor("#16213e")
        if "predicted_stage" in results_df.columns:
            stage_counts = results_df["predicted_stage"].value_counts()
            if not stage_counts.empty:
                *_, autotexts = ax1.pie(
                    stage_counts.values, labels=stage_counts.index,
                    colors=[STAGE_COLORS.get(s, "#607d8b") for s in stage_counts.index],
                    autopct="%1.0f%%",
                    textprops={"color": "white", "fontsize": 10},
                )
                for label in autotexts:
                    label.set_color("white")
        ax1.set_title("Stage Distribution", color="white", fontsize=11, fontweight="bold")

        # Confidence histogram
        ax2.set_facecolor("#16213e")
        if "confidence" in results_df.columns:
            conf = results_df["confidence"].dropna()
            # A histogram needs at least one bin; with no usable scores
            # ``bins`` would be 0 and the call would raise.
            if len(conf):
                ax2.hist(conf, bins=min(10, len(conf)), color="#3B82F6",
                         edgecolor="#1a1a2e", alpha=0.8)
            ax2.axvline(0.8, color="#4CAF50", linestyle="--",
                        linewidth=1.5, label="High (0.80)")
            ax2.axvline(0.6, color="#FF9800", linestyle="--",
                        linewidth=1.5, label="Med (0.60)")
            ax2.legend(fontsize=8, labelcolor="white", facecolor="#0d0d1a")
        ax2.set_xlabel("Confidence", color="#aaa", fontsize=9)
        ax2.set_ylabel("Count", color="#aaa", fontsize=9)
        ax2.tick_params(colors="white", labelsize=9)
        for side in ("top", "right"):
            ax2.spines[side].set_visible(False)
        for side in ("left", "bottom"):
            ax2.spines[side].set_color("#333")
        ax2.set_title("Confidence Distribution", color="white",
                      fontsize=11, fontweight="bold")
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=150, bbox_inches="tight",
                        facecolor=fig.get_facecolor())
    finally:
        plt.close(fig)
| # ── Text report writers ─────────────────────────────────────────────────────── | |
def _write_single_stage_report(
    path: Path,
    stage: str,
    confidence: float,
    probabilities: dict,
    model: str,
    comparison: dict,
    input_features: dict,
):
    """Write the plain-text report for one stage prediction.

    Args:
        path: Destination file; written as UTF-8 and overwritten if present.
        stage: Predicted stage code.
        confidence: Confidence of the prediction, formatted as a percentage.
        probabilities: Mapping of stage code to probability.
        model: Name of the model that produced the prediction.
        comparison: Mapping with the keys ``"RandomForest"`` and
            ``"LogisticRegression"``; each value needs a ``"stage"`` entry
            and may carry a ``"confidence"`` entry.
        input_features: Feature name → value. Entries that are ``None`` or
            NaN are left out of the report.

    Raises:
        KeyError: If ``comparison`` lacks one of the two model entries.
    """

    def _comparison_line(label: str, key: str) -> str:
        entry = comparison[key]
        # ``.get(..., 0)`` is not enough: the key can be present with a None
        # value, and formatting None with ``:.1%`` raises TypeError.
        model_conf = entry.get("confidence") or 0.0
        return f" {label} → {entry['stage']} ({model_conf:.1%})"

    def _is_missing(value) -> bool:
        return value is None or (isinstance(value, float) and np.isnan(value))

    rule = "=" * 60
    lines = [
        rule,
        "SWAN MENOPAUSE STAGE PREDICTION REPORT",
        f"Generated : {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
        rule,
        "",
        f"Predicted Stage : {STAGE_LABELS.get(stage, stage)}",
        f"Model : {model}",
        f"Confidence : {confidence:.1%}",
        "",
        "Stage Probabilities:",
    ]
    # One row per stage with a 20-character text bar proportional to p.
    lines.extend(
        f" {name:<6} : {p:.4f} {'█' * int(p * 20)}"
        for name, p in probabilities.items()
    )
    lines += [
        "",
        "Model Comparison:",
        _comparison_line("RandomForest", "RandomForest"),
        _comparison_line("LogisticRegression", "LogisticRegression"),
        "",
        "Input Features (non-NaN):",
    ]
    lines.extend(
        f" {name:<12} = {value}"
        for name, value in input_features.items()
        if not _is_missing(value)
    )
    lines += [
        "",
        "⚠️ For research/educational use only. Not a clinical diagnosis.",
        rule,
    ]
    path.write_text("\n".join(lines), encoding="utf-8")
def _write_batch_report(
    path: Path,
    results: pd.DataFrame,
    model: str,
    run_dir: Path,
):
    """Write the plain-text summary for a batch stage-prediction run.

    The report lists the total row count, the pre/peri/post distribution,
    mean/min/max confidence and how many rows fall in the high (>0.80),
    medium (0.60-0.80) and low (≤0.60) confidence bands. Missing
    ``predicted_stage`` / ``confidence`` columns are reported as zeros.

    Args:
        path: Destination file; written as UTF-8.
        results: Batch prediction results.
        model: Model name shown in the header.
        run_dir: Output directory mentioned in the report.
    """
    total = len(results)
    columns = results.columns

    stage_counts = {}
    if "predicted_stage" in columns:
        stage_counts = results["predicted_stage"].value_counts().to_dict()

    mean_c = min_c = max_c = high = medium = low = 0
    if "confidence" in columns:
        scores = results["confidence"]
        mean_c, min_c, max_c = scores.mean(), scores.min(), scores.max()
        high = int((scores > 0.8).sum())
        medium = int(((scores > 0.6) & (scores <= 0.8)).sum())
        low = int((scores <= 0.6).sum())

    def _band(label: str, count: int) -> str:
        # Without rows there is nothing to take a share of.
        if not total:
            return " N/A"
        return f" {label} : {count}/{total} ({count/total*100:.1f}%)"

    rule = "=" * 60
    report = [
        rule,
        "SWAN BATCH STAGE PREDICTION REPORT",
        f"Generated : {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
        f"Model : {model}",
        rule,
        "",
        f"Total Individuals : {total}",
        "",
        "Stage Distribution:",
    ]
    for code in ("pre", "peri", "post"):
        n = stage_counts.get(code, 0)
        share = n / total * 100 if total else 0
        report.append(f" {code:<6} : {n} ({share:.1f}%)")
    report.extend([
        "",
        "Confidence Scores:",
        f" Mean : {mean_c:.4f}",
        f" Min : {min_c:.4f}",
        f" Max : {max_c:.4f}",
        "",
        "Confidence Distribution:",
        _band("High (>0.80)", high),
        _band("Medium (0.60-0.80)", medium),
        _band("Low (≤0.60)", low),
        "",
        f"Output Directory : {run_dir}",
        "",
        "⚠️ For research/educational use only. Not a clinical diagnosis.",
        rule,
    ])
    path.write_text("\n".join(report), encoding="utf-8")
def _write_symptom_report(
    path: Path,
    individual_id: str,
    lmp: str,
    target_date: str,
    cycle_day: int,
    cycle_length: int,
    hot_prob: float,
    hot_pred: bool,
    mood_prob: float,
    mood_pred: bool,
):
    """Write the plain-text report for a single symptom forecast.

    Probabilities that are ``None`` or NaN are printed as 0.0. An empty
    ``individual_id`` is shown as "N/A" and an empty ``target_date`` as
    "Today". The file is written as UTF-8.
    """

    def _clean(prob) -> float:
        if prob is None or np.isnan(prob):
            return 0.0
        return float(prob)

    def _risk(flag) -> str:
        return "[ELEVATED RISK]" if flag else "[LOW RISK]"

    rule = "=" * 60
    generated = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    report = "\n".join((
        rule,
        "SWAN SYMPTOM CYCLE FORECAST REPORT",
        f"Generated : {generated}",
        rule,
        "",
        f"Individual : {individual_id or 'N/A'}",
        f"LMP : {lmp}",
        f"Target Date : {target_date or 'Today'}",
        f"Cycle Length : {cycle_length} days",
        f"Cycle Day : {cycle_day}",
        "",
        "Symptom Probabilities:",
        f" Hot Flash : {_clean(hot_prob):.4f} {_risk(hot_pred)}",
        f" Mood Change : {_clean(mood_prob):.4f} {_risk(mood_pred)}",
        "",
        "⚠️ For research/educational use only. Not a clinical diagnosis.",
        rule,
    ))
    path.write_text(report, encoding="utf-8")
def _write_batch_symptom_report(
    path: Path,
    results: pd.DataFrame,
    cycle_length: int,
    run_dir: Path,
):
    """Write the plain-text summary for a batch symptom forecast.

    Reports how many rows are flagged for hot flashes / mood changes and the
    mean of each probability column. A column that is absent from
    ``results`` contributes 0. The file is written as UTF-8.
    """
    present = set(results.columns)

    def _flag_count(column: str) -> int:
        return int(results[column].sum()) if column in present else 0

    def _average(column: str) -> float:
        return float(results[column].mean()) if column in present else 0.0

    total = len(results)
    hot_flags = _flag_count("hotflash_pred")
    mood_flags = _flag_count("mood_pred")
    mean_hot = _average("hotflash_prob")
    mean_mood = _average("mood_prob")

    rule = "=" * 60
    generated = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    report = "\n".join((
        rule,
        "SWAN BATCH SYMPTOM FORECAST REPORT",
        f"Generated : {generated}",
        f"Cycle Length : {cycle_length} days",
        rule,
        "",
        f"Total Individuals : {total}",
        f"Hot Flash Risk : {hot_flags}/{total} elevated",
        f"Mood Change Risk : {mood_flags}/{total} elevated",
        f"Avg Hot Flash Prob : {mean_hot:.4f}",
        f"Avg Mood Prob : {mean_mood:.4f}",
        "",
        f"Output Directory : {run_dir}",
        "",
        "⚠️ For research/educational use only. Not a clinical diagnosis.",
        rule,
    ))
    path.write_text(report, encoding="utf-8")
| # ── Core prediction functions ───────────────────────────────────────────────── | |
def _single_stage_card_html(stage: str, confidence: float, model_choice) -> str:
    """Build the HTML result card (stage, confidence, symptoms, guidance)."""
    info = STAGE_INFO.get(stage, {})
    emoji = STAGE_EMOJI.get(stage, "⚪")
    color = STAGE_COLORS.get(stage, "#607d8b")
    conf_color = _confidence_color(confidence)
    symptom_tags = "".join(
        f'<span style="background:{color}14;color:{color};padding:4px 10px;'
        f'border-radius:20px;border:1px solid {color}44;font-size:12px;'
        f'font-weight:500">{s}</span>'
        for s in info.get("symptoms", [])
    )
    return f"""
    <div class="result-card" style="border-left:4px solid {color}">
    <div style="display:flex;align-items:center;gap:12px;margin-bottom:16px;flex-wrap:wrap">
    <span style="font-size:40px;flex-shrink:0">{emoji}</span>
    <div style="flex:1;min-width:140px">
    <div style="color:#6b7280;font-size:12px;text-transform:uppercase;letter-spacing:2px">
    Predicted Stage
    </div>
    <div style="color:{color};font-size:26px;font-weight:700">
    {STAGE_LABELS.get(stage, stage)}
    </div>
    </div>
    <div style="text-align:right;flex-shrink:0">
    <div style="color:#6b7280;font-size:11px">Confidence</div>
    <div style="color:{conf_color};font-size:28px;font-weight:700">
    {confidence:.0%}
    </div>
    </div>
    </div>
    <hr style="border:none;border-top:1px solid #e2e8f0;margin:12px 0">
    <p style="color:#374151;font-size:14px;margin:8px 0">
    {info.get('description', '')}
    </p>
    <div style="margin-top:12px">
    <div style="color:#6b7280;font-size:11px;text-transform:uppercase;
    letter-spacing:1px;margin-bottom:6px">Common Symptoms</div>
    <div style="display:flex;flex-wrap:wrap;gap:6px">{symptom_tags}</div>
    </div>
    <div style="background:{color}0d;border-left:3px solid {color};
    padding:10px 14px;margin-top:14px;border-radius:0 8px 8px 0">
    <span style="color:{color};font-size:12px;font-weight:600">💡 Guidance: </span>
    <span style="color:#374151;font-size:13px">{info.get('guidance', '')}</span>
    </div>
    <div style="color:#9ca3af;font-size:11px;margin-top:12px">
    Model: {model_choice} · {datetime.now().strftime('%Y-%m-%d %H:%M')}
    </div>
    </div>
    """


def _single_confidence_note(confidence: float) -> str:
    """Return the one-line note that explains the confidence band."""
    if confidence >= 0.8:
        return "✅ High confidence — the model is quite certain about this stage."
    if confidence >= 0.6:
        return ("⚠️ Moderate confidence — consider providing more feature values "
                "or consulting a clinician.")
    return ("🔴 Low confidence — prediction is uncertain; "
            "clinical consultation is strongly recommended.")


def _single_comparison_html(comparison: dict, run_dir, saved_files) -> str:
    """Build the HTML panel comparing both models and listing the saved files."""
    rf = comparison["RandomForest"]
    lr = comparison["LogisticRegression"]
    rf_stage, lr_stage = rf["stage"], lr["stage"]
    # A confidence entry may be present but None; ``:.0%`` cannot format that.
    rf_conf = rf.get("confidence") or 0.0
    lr_conf = lr.get("confidence") or 0.0

    if rf_stage == lr_stage:
        banner_bg, banner_fg = "#d1fae5", "#065f46"
        banner_text = "✅ Both models agree — prediction is robust"
    else:
        banner_bg, banner_fg = "#fef2f2", "#9f1239"
        banner_text = "⚠️ Models disagree — interpret with caution"
    file_list = "<br>\n    ".join(saved_files)

    return f"""
    <div class="result-card" style="margin-top:0">
    <div style="color:#6b7280;font-size:11px;text-transform:uppercase;
    letter-spacing:1px;margin-bottom:10px;font-weight:600">
    Model Comparison
    </div>
    <div class="stat-grid-2">
    <div class="stat-item" style="border-top:3px solid #16a34a">
    <div style="color:#16a34a;font-size:11px;font-weight:600">Random Forest</div>
    <div style="color:#111827;font-size:17px;margin-top:4px">
    {STAGE_EMOJI.get(rf_stage, '')} {STAGE_LABELS.get(rf_stage, rf_stage)}
    </div>
    <div style="color:#6b7280;font-size:12px">
    {rf_conf:.0%} confidence
    </div>
    </div>
    <div class="stat-item" style="border-top:3px solid #2563eb">
    <div style="color:#2563eb;font-size:11px;font-weight:600">
    Logistic Regression
    </div>
    <div style="color:#111827;font-size:17px;margin-top:4px">
    {STAGE_EMOJI.get(lr_stage, '')} {STAGE_LABELS.get(lr_stage, lr_stage)}
    </div>
    <div style="color:#6b7280;font-size:12px">
    {lr_conf:.0%} confidence
    </div>
    </div>
    </div>
    <div style="margin-top:10px;padding:8px;border-radius:8px;
    background:{banner_bg};
    color:{banner_fg};
    font-size:13px;text-align:center;font-weight:500">
    {banner_text}
    </div>
    <div class="output-path-box">
    <div class="output-path-title">📁 Outputs saved to:</div>
    <div class="output-path-dir">{run_dir}/</div>
    <div class="output-path-files">
    {file_list}
    </div>
    </div>
    </div>
    """


def predict_single_stage(
    age, race, langint,
    hot_flash, num_hot_flash, bothersome_hf,
    sleep_quality, depression_indicator, mood_change, irritability,
    pain_indicator, abbleed, vaginal_dryness, lmp_day,
    model_choice,
):
    """Predict the menopause stage for one person and save the run outputs.

    Each form value is converted to ``float`` (``None`` becomes NaN) and
    passed to the model under its feature name. A falsy ``lmp_day``
    (including 0) is treated as missing. Every successful call creates a new
    run directory containing a probability chart (when probabilities are
    available), a one-row CSV and a text report.

    Args:
        age, race, langint, hot_flash, num_hot_flash, bothersome_hf,
        sleep_quality, depression_indicator, mood_change, irritability,
        pain_indicator, abbleed, vaginal_dryness, lmp_day: Form values for
            the corresponding model features.
        model_choice: Model name forwarded to ``_forecast.predict_single``.

    Returns:
        ``(stage_html, chart_fig, conf_note, compare_html, csv_download_path)``.
        When the models are unavailable or anything fails, the first element
        carries the error text and the chart and CSV path are ``None``.
    """
    if not _MODEL_OK:
        return f"⚠️ {_MODEL_MSG}", None, "Models unavailable.", "", None

    def _v(x):
        return float(x) if x is not None else np.nan

    try:
        # Built inside the try block so that a value that cannot be converted
        # to float is reported like any other prediction error instead of
        # escaping as an unhandled exception.
        feature_dict = {
            "AGE7": _v(age),
            "RACE": _v(race),
            "LANGINT7": _v(langint),
            "HOTFLAS7": _v(hot_flash),
            "NUMHOTF7": _v(num_hot_flash),
            "BOTHOTF7": _v(bothersome_hf),
            "SLEEPQL7": _v(sleep_quality),
            "DEPRESS7": _v(depression_indicator),
            "MOODCHG7": _v(mood_change),
            "IRRITAB7": _v(irritability),
            "PAIN17": _v(pain_indicator),
            "ABBLEED7": _v(abbleed),
            "VAGINDR7": _v(vaginal_dryness),
            "LMPDAY7": _v(lmp_day) if lmp_day else np.nan,
        }

        result = _forecast.predict_single(feature_dict, model=model_choice, return_proba=True)
        stage = result["stage"]
        confidence = result.get("confidence") or 0.0
        proba = result.get("probabilities") or {}

        # ── Create timestamped run directory ──────────────────────────────────
        run_dir = _make_run_dir()
        saved_files = []

        # ── Save probability chart (PNG) ──────────────────────────────────────
        chart_fig = None
        if proba:
            chart_path = run_dir / "charts" / "stage_probabilities.png"
            chart_fig = _make_proba_chart(proba, stage, save_path=chart_path)
            # Only advertise the chart when it was actually written.
            saved_files.append("charts/stage_probabilities.png")

        # ── Save prediction CSV ───────────────────────────────────────────────
        pred_row = {
            "predicted_stage": stage,
            "model": model_choice,
            "confidence": round(confidence, 4),
            **{f"prob_{k}": round(v, 4) for k, v in proba.items()},
            "timestamp": datetime.now().isoformat(),
        }
        csv_path = run_dir / "predictions" / "stage_prediction.csv"
        pd.DataFrame([pred_row]).to_csv(csv_path, index=False)
        saved_files.append("predictions/stage_prediction.csv")

        # ── Model comparison ──────────────────────────────────────────────────
        comparison = _forecast.compare_models(feature_dict)

        # ── Save text report ──────────────────────────────────────────────────
        txt_path = run_dir / "reports" / "prediction_summary.txt"
        _write_single_stage_report(
            txt_path, stage, confidence, proba,
            model_choice, comparison, feature_dict,
        )
        saved_files.append("reports/prediction_summary.txt")

        # ── Build the UI fragments ────────────────────────────────────────────
        stage_html = _single_stage_card_html(stage, confidence, model_choice)
        conf_note = _single_confidence_note(confidence)
        compare_html = _single_comparison_html(comparison, run_dir, saved_files)
        return stage_html, chart_fig, conf_note, compare_html, str(csv_path)
    except Exception as exc:  # UI boundary: show the failure instead of crashing
        return f"❌ Prediction error: {exc}", None, "", "", None
def predict_batch_stage(file, model_choice):
    """Run stage prediction for every row of an uploaded CSV.

    The CSV must contain at least one column listed in the metadata's
    ``feature_names``. A warning is added to the summary when more than half
    of those features are absent. Each successful call creates a new run
    directory with the predictions CSV, a summary chart and a text report.

    Args:
        file: Value of the Gradio file component (see ``_get_file_path``).
        model_choice: Model name forwarded to ``_forecast.predict_batch``.

    Returns:
        ``(csv_download_path, summary_html, preview_df)``. ``preview_df``
        holds the first 20 result rows. On any failure the path and preview
        are ``None`` and ``summary_html`` carries the message.
    """
    if not _MODEL_OK:
        return None, f"⚠️ {_MODEL_MSG}", None
    if file is None:
        return None, "Please upload a CSV file.", None

    try:
        df = pd.read_csv(_get_file_path(file))
    except Exception as exc:
        return None, f"Could not read CSV: {exc}", None
    if df.empty:
        return None, "Uploaded CSV is empty.", None

    # Identify ID column
    id_col_candidates = ["individual", "Individual", "ID", "id",
                         "SWANID", "subject", "Subject"]
    id_col = next((c for c in id_col_candidates if c in df.columns), None)

    # Validate features
    feature_names = _metadata.get("feature_names", [])
    matching = [c for c in df.columns if c in feature_names]
    if not matching:
        return None, (
            "❌ No matching feature columns found. "
            "Please include columns from the training feature set "
            "(see 'Feature Reference' tab)."
        ), None

    warnings_list = []
    missing_pct = 1 - len(matching) / max(len(feature_names), 1)
    if missing_pct > 0.5:
        warnings_list.append(
            f"⚠️ {missing_pct:.0%} of training features are missing — "
            "prediction accuracy may be reduced."
        )

    try:
        results = _forecast.predict_batch(df, model=model_choice, return_proba=True)

        # Insert individual ID. ``DataFrame.insert`` raises if the column is
        # already there, so skip it when the result frame carries one.
        # NOTE(review): whether predict_batch passes input columns through is
        # not visible here — confirm.
        if "individual" not in results.columns:
            if id_col:
                ids = df[id_col].values
            else:
                ids = [f"Row_{i+1}" for i in range(len(results))]
            results.insert(0, "individual", ids)

        results["model"] = model_choice
        results["notes"] = ""
        has_conf = "confidence" in results.columns
        if has_conf:
            low_mask = results["confidence"] < 0.6
            results.loc[low_mask, "notes"] = "Low confidence — review manually"

        # ── Create timestamped run directory ──────────────────────────────────
        run_dir = _make_run_dir()

        # ── Save predictions CSV ──────────────────────────────────────────────
        csv_path = run_dir / "predictions" / "batch_stage_predictions.csv"
        results.to_csv(csv_path, index=False)

        # ── Save confidence/distribution chart (PNG) ──────────────────────────
        chart_path = run_dir / "charts" / "batch_summary_chart.png"
        _make_batch_summary_chart(results, save_path=chart_path)

        # ── Save text report ──────────────────────────────────────────────────
        txt_path = run_dir / "reports" / "batch_summary.txt"
        _write_batch_report(txt_path, results, model_choice, run_dir)

        # ── Build summary HTML ────────────────────────────────────────────────
        total = len(results)
        dist = results["predicted_stage"].value_counts().to_dict()
        mean_conf, high_conf = 0.0, 0
        if has_conf:
            mean_conf = results["confidence"].mean()
            if pd.isna(mean_conf):
                # No usable scores: show 0% rather than "nan%".
                mean_conf = 0.0
            high_conf = int((results["confidence"] > 0.8).sum())

        bar_parts = []
        for stage in ["pre", "peri", "post"]:
            count = dist.get(stage, 0)
            # Guard against an empty result frame, as _write_batch_report does.
            pct = count / total * 100 if total else 0.0
            bar_parts.append(f"""
            <div style="margin:6px 0">
            <div style="display:flex;justify-content:space-between;margin-bottom:2px">
            <span style="color:#374151;font-size:13px">
            {STAGE_EMOJI.get(stage,'')} {STAGE_LABELS.get(stage, stage)}
            </span>
            <span style="color:#6b7280;font-size:12px">{count} ({pct:.0f}%)</span>
            </div>
            <div style="background:#e2e8f0;border-radius:4px;height:8px">
            <div style="background:{STAGE_COLORS.get(stage,'#6b7280')};
            width:{pct}%;height:8px;border-radius:4px"></div>
            </div>
            </div>""")
        dist_bars = "".join(bar_parts)

        warn_html = "".join(
            f'<div style="color:#d97706;font-size:12px;margin-top:4px">{w}</div>'
            for w in warnings_list
        )
        summary_html = f"""
        <div class="result-card">
        <div style="color:#111827;font-size:16px;font-weight:700;margin-bottom:14px">
        📊 Batch Results — {total} individuals
        </div>
        {warn_html}
        <div class="stat-grid-3">
        <div class="stat-item">
        <div class="stat-label">Total</div>
        <div class="stat-value">{total}</div>
        </div>
        <div class="stat-item">
        <div class="stat-label">Avg Confidence</div>
        <div class="stat-value" style="color:{_confidence_color(mean_conf)}">
        {mean_conf:.0%}
        </div>
        </div>
        <div class="stat-item">
        <div class="stat-label">High Conf (>80%)</div>
        <div class="stat-value" style="color:#16a34a">{high_conf}/{total}</div>
        </div>
        </div>
        <div style="margin-top:12px">{dist_bars}</div>
        <div class="output-path-box">
        <div class="output-path-title">📁 Outputs saved to:</div>
        <div class="output-path-dir">{run_dir}/</div>
        <div class="output-path-files">
        predictions/batch_stage_predictions.csv<br>
        charts/batch_summary_chart.png<br>
        reports/batch_summary.txt
        </div>
        </div>
        </div>
        """
        return str(csv_path), summary_html, results.head(20)
    except Exception as exc:  # UI boundary: show the failure instead of crashing
        return None, f"❌ Batch prediction error: {exc}", None
def predict_symptoms(individual_id, lmp_input, target_date_input, cycle_length):
    """
    Cycle-based symptom forecasting (single person).

    Runs ``SymptomCycleForecaster.predict_single`` for one LMP value, then
    writes a cycle chart (PNG), a one-row CSV and a text report into a fresh
    run directory and builds an HTML result card.

    Args:
        individual_id: Optional label for the person. "N/A" is stored in the
            CSV and "Forecast" is shown in the card when it is empty.
        lmp_input: Last Menstrual Period value as entered by the user. A value
            that parses as a plain integer is flagged in the CSV ``notes``
            column as a day-of-month LMP.
        target_date_input: Optional forecast date. When empty the forecaster
            receives ``None`` and the card shows "Today".
        cycle_length: Cycle length in days; falls back to 28 when empty.

    Returns:
        (result_html, chart_fig, csv_download_path). When the LMP is missing
        or anything fails, the first element is a message and the other two
        are ``None``.
    """
    # Aliased local import: a local variable named ``html`` is bound below.
    from html import escape as _esc

    if not lmp_input:
        return "Please enter your Last Menstrual Period date.", None, None

    def _safe_prob(value):
        """Return *value* as a float, mapping None and NaN to 0.0."""
        if value is None:
            return 0.0
        number = float(value)
        return 0.0 if np.isnan(number) else number

    def _prob_bar(prob, label, color):
        """Render one labelled horizontal probability bar."""
        # Clamp on both sides so an out-of-range value cannot produce a
        # negative or >100% CSS width.
        pct = max(0.0, min(prob * 100, 100))
        return f"""
        <div style="margin:10px 0">
          <div style="display:flex;justify-content:space-between;margin-bottom:4px">
            <span style="color:#374151;font-size:14px">{label}</span>
            <span style="color:{color};font-size:16px;font-weight:700">{pct:.0f}%</span>
          </div>
          <div style="background:#e2e8f0;border-radius:6px;height:10px">
            <div style="background:{color};width:{pct}%;height:10px;
                        border-radius:6px;transition:width 0.5s"></div>
          </div>
        </div>"""

    try:
        cycle_length = int(cycle_length) if cycle_length else 28
        fore = SymptomCycleForecaster(cycle_length=cycle_length)
        target_date = target_date_input if target_date_input else None
        result = fore.predict_single(lmp=lmp_input, target_date=target_date)

        cycle_day = result.get("cycle_day")
        hot_pred = result.get("hotflash_pred", False)
        mood_pred = result.get("mood_pred", False)
        hp = _safe_prob(result.get("hotflash_prob", 0))
        mp = _safe_prob(result.get("mood_prob", 0))

        # ── Create timestamped run directory ──────────────────────────────
        run_dir = _make_run_dir()

        # ── Save cycle chart (PNG) ────────────────────────────────────────
        chart_path = run_dir / "charts" / "cycle_position.png"
        chart_fig = _make_cycle_chart(
            cycle_day, cycle_length, hp, mp, save_path=chart_path
        )

        # ── Save forecast CSV ─────────────────────────────────────────────
        csv_path = run_dir / "predictions" / "symptom_forecast.csv"
        lmp_note = ""
        try:
            # A bare integer LMP is flagged so readers of the CSV know the
            # date was not given explicitly.
            int(str(lmp_input).strip())
            lmp_note = "LMP inferred as day-of-month; interpret with caution"
        except (ValueError, TypeError):
            pass
        pd.DataFrame([{
            "individual": individual_id or "N/A",
            "LMP": lmp_input,
            "date": target_date_input or datetime.now().strftime("%Y-%m-%d"),
            "cycle_day": cycle_day,
            "hotflash_prob": round(hp, 6),
            "hotflash_pred": bool(hot_pred),
            "mood_prob": round(mp, 6),
            "mood_pred": bool(mood_pred),
            "notes": lmp_note,
        }]).to_csv(csv_path, index=False)

        # ── Save text report ──────────────────────────────────────────────
        txt_path = run_dir / "reports" / "symptom_summary.txt"
        _write_symptom_report(
            txt_path, individual_id, lmp_input, target_date_input,
            cycle_day, cycle_length, hp, hot_pred, mp, mood_pred,
        )

        # ── Build result HTML ─────────────────────────────────────────────
        # Free-text user inputs are escaped before being placed in markup so
        # they cannot inject HTML into the rendered card. The CSV and report
        # above intentionally keep the raw values.
        safe_title = _esc(str(individual_id or 'Forecast'))
        safe_lmp = _esc(str(lmp_input))
        safe_target = _esc(str(target_date_input or 'Today'))

        hot_alert = "🔴 Elevated risk" if hot_pred else "🟢 Low risk"
        mood_alert = "🔴 Elevated risk" if mood_pred else "🟢 Low risk"
        html = f"""
        <div class="result-card">
          <div style="color:#111827;font-size:18px;font-weight:700;margin-bottom:4px">
            {safe_title} — Cycle Day {cycle_day or '?'}
          </div>
          <div style="color:#6b7280;font-size:13px;margin-bottom:20px">
            LMP: {safe_lmp} | Target: {safe_target}
            | Cycle: {cycle_length} days
          </div>
          {_prob_bar(hp, '🔥 Hot Flash Probability', '#ef4444')}
          <div style="color:#6b7280;font-size:12px;margin:-6px 0 10px 2px">{hot_alert}</div>
          {_prob_bar(mp, '😤 Mood Change Probability', '#7c3aed')}
          <div style="color:#6b7280;font-size:12px;margin:-6px 0 10px 2px">{mood_alert}</div>
          <div style="background:#f8fafc;border:1px solid #e2e8f0;border-radius:8px;
                      padding:12px;margin-top:14px;font-size:12px;color:#6b7280">
            ℹ️ Probabilities are computed from a cycle-phase model (Gaussian heuristic).
            They represent symptom likelihood based on cycle day, not a clinical diagnosis.
          </div>
          <div class="output-path-box">
            <div class="output-path-title">📁 Outputs saved to:</div>
            <div class="output-path-dir">{run_dir}/</div>
            <div class="output-path-files">
              charts/cycle_position.png<br>
              predictions/symptom_forecast.csv<br>
              reports/symptom_summary.txt
            </div>
          </div>
        </div>
        """
        return html, chart_fig, str(csv_path)
    except Exception as exc:
        # UI boundary: report the failure in the result panel instead of
        # raising. Exception text can echo user input, so it is escaped.
        return f"❌ Error: {_esc(str(exc))}", None, None
def predict_symptoms_batch(file, lmp_col_name, date_col_name, cycle_length):
    """
    Batch symptom forecasting from CSV.

    Reads the uploaded CSV, runs ``SymptomCycleForecaster.predict_df`` over
    it, writes the predictions CSV and a text report into a fresh run
    directory, and builds an HTML summary card.

    Args:
        file: Uploaded file value from the UI, resolved to a path with
            ``_get_file_path``. ``None`` means nothing was uploaded.
        lmp_col_name: Name of the CSV column holding LMP values (required).
        date_col_name: Optional name of the target-date column; ignored when
            empty or not present in the CSV.
        cycle_length: Cycle length in days; falls back to 28 when empty.

    Returns:
        (csv_download_path, summary_html, preview_df). ``preview_df`` holds
        the first 20 result rows; the full table is in the CSV. On failure
        the first and last elements are ``None`` and the middle one is a
        message.
    """
    from html import escape

    if file is None:
        return None, "Please upload a CSV file.", None

    file_path = _get_file_path(file)
    try:
        df = pd.read_csv(file_path)
    except Exception as exc:
        return None, f"Could not read CSV: {escape(str(exc))}", None

    if lmp_col_name not in df.columns:
        # Both the requested name and the header row come from the user, so
        # they are escaped before landing in the HTML message.
        return None, (
            f"LMP column '{escape(str(lmp_col_name))}' not found in CSV. "
            f"Columns present: {escape(str(list(df.columns)))}"
        ), None

    def _lmp_note(val):
        """Flag LMP values that are bare integers (day-of-month)."""
        try:
            int(str(val).strip())
            return "LMP inferred as day-of-month; interpret with caution"
        except (ValueError, TypeError):
            return ""

    try:
        cycle_length = int(cycle_length) if cycle_length else 28
        fore = SymptomCycleForecaster(cycle_length=cycle_length)
        date_col = date_col_name \
            if (date_col_name and date_col_name in df.columns) else None
        results = fore.predict_df(df, lmp_col=lmp_col_name, date_col=date_col)

        # ── Add notes column (flag day-of-month LMP rows) ─────────────────
        # NOTE(review): this assignment aligns on index, so it assumes
        # predict_df keeps the input frame's index — confirm.
        results["notes"] = df[lmp_col_name].apply(_lmp_note)

        # ── Create timestamped run directory ──────────────────────────────
        run_dir = _make_run_dir()

        # ── Save predictions CSV ──────────────────────────────────────────
        csv_path = run_dir / "predictions" / "batch_symptom_forecast.csv"
        results.to_csv(csv_path, index=False)

        # ── Save text report ──────────────────────────────────────────────
        txt_path = run_dir / "reports" / "batch_symptom_summary.txt"
        _write_batch_symptom_report(txt_path, results, cycle_length, run_dir)

        # ── Build summary HTML ────────────────────────────────────────────
        def _flag_count(column):
            """Number of truthy predictions in *column* (0 if absent)."""
            if column not in results.columns:
                return 0
            return int(results[column].sum())

        def _mean_or_zero(column):
            """Column mean, or 0.0 when the column is absent or all-missing."""
            if column not in results.columns:
                return 0.0
            value = results[column].mean()
            # An empty or all-NaN column has a NaN mean, which would render
            # as "nan%" in the card.
            return 0.0 if pd.isna(value) else float(value)

        total = len(results)
        hot_flags = _flag_count("hotflash_pred")
        mood_flags = _flag_count("mood_pred")
        mean_hot = _mean_or_zero("hotflash_prob")
        mean_mood = _mean_or_zero("mood_prob")

        summary_html = f"""
        <div class="result-card">
          <div style="color:#111827;font-size:16px;font-weight:700;margin-bottom:14px">
            🌊 Symptom Forecast — {total} individuals
          </div>
          <div class="stat-grid-3">
            <div class="stat-item">
              <div class="stat-label">Total</div>
              <div class="stat-value">{total}</div>
            </div>
            <div class="stat-item">
              <div class="stat-label">🔥 Hot Flash Risk</div>
              <div class="stat-value" style="color:#ef4444">{hot_flags}</div>
            </div>
            <div class="stat-item">
              <div class="stat-label">😤 Mood Risk</div>
              <div class="stat-value" style="color:#7c3aed">{mood_flags}</div>
            </div>
          </div>
          <div class="stat-grid-2">
            <div class="stat-item">
              <div class="stat-label">Avg Hot Flash Prob</div>
              <div class="stat-value" style="color:#ef4444;font-size:18px">
                {mean_hot:.1%}
              </div>
            </div>
            <div class="stat-item">
              <div class="stat-label">Avg Mood Prob</div>
              <div class="stat-value" style="color:#7c3aed;font-size:18px">
                {mean_mood:.1%}
              </div>
            </div>
          </div>
          <div class="output-path-box">
            <div class="output-path-title">📁 Outputs saved to:</div>
            <div class="output-path-dir">{run_dir}/</div>
            <div class="output-path-files">
              predictions/batch_symptom_forecast.csv<br>
              reports/batch_symptom_summary.txt
            </div>
          </div>
        </div>
        """
        # Return only a 20-row preview, matching the batch stage-prediction
        # handler; the complete table is available through the CSV download.
        return str(csv_path), summary_html, results.head(20)
    except Exception as exc:
        # UI boundary: surface the failure in the summary panel, escaped
        # because exception text can echo CSV content.
        return None, f"❌ Error: {escape(str(exc))}", None
| # ── Feature reference & model status ───────────────────────────────────────── | |
def get_feature_reference() -> str:
    """Render the training-feature reference table as an HTML string.

    Feature names come from ``_metadata["feature_names"]``, falling back to
    the keys of ``FEATURE_DESCRIPTIONS``. Only the first 60 are listed; a
    trailing row reports how many were left out. A feature without a
    description is described by the part of its name before the first
    underscore.
    """
    all_features = _metadata.get("feature_names", list(FEATURE_DESCRIPTIONS.keys()))

    def _table_row(position, name):
        description = FEATURE_DESCRIPTIONS.get(name, name.split("_")[0])
        return f"""
<tr>
<td class="feature-num">{position}</td>
<td class="feature-code">{name}</td>
<td class="feature-desc">{description}</td>
</tr>"""

    pieces = [
        _table_row(position, name)
        for position, name in enumerate(all_features[:60], start=1)
    ]

    hidden = len(all_features) - 60
    if hidden > 0:
        pieces.append(f"""
<tr>
<td colspan="3" style="padding:8px;color:#9ca3af;font-size:12px;text-align:center">
… and {hidden} more features (one-hot encoded categories)
</td>
</tr>""")

    rows = "".join(pieces)
    return f"""
<div class="feature-table-wrap">
<div style="color:#111827;font-size:16px;font-weight:700;margin-bottom:14px">
📋 Training Features ({len(all_features)} total after encoding)
</div>
<table>
<thead>
<tr>
<th>#</th>
<th>Feature</th>
<th>Description</th>
</tr>
</thead>
<tbody>{rows}</tbody>
</table>
</div>
"""
def get_model_status() -> str:
    """Build the HTML status card describing whether the models are loaded.

    When ``_MODEL_OK`` is false the card shows ``_MODEL_MSG`` together with
    the instructions for training the models. Otherwise it shows the feature
    count and stage list read from ``_metadata``, with one coloured badge per
    stage.
    """
    if not _MODEL_OK:
        return f"""
<div class="status-card">
<div style="display:flex;align-items:center;gap:10px;margin-bottom:10px">
<span style="font-size:24px">⚠️</span>
<div>
<div style="color:#dc2626;font-size:16px;font-weight:700">
Models Not Loaded
</div>
<div style="color:#6b7280;font-size:12px">{_MODEL_MSG}</div>
</div>
</div>
<div style="background:#fef2f2;border:1px solid #fecaca;border-radius:8px;
padding:12px;color:#9f1239;font-size:13px">
To train and save models:<br>
<code style="background:#1e293b;color:#a3e635;padding:4px 8px;border-radius:4px;
margin-top:6px;display:inline-block">python menopause.py</code>
<br><br>
This generates <code style="background:#e2e8f0;padding:2px 5px;border-radius:3px;
color:#1e293b">swan_ml_output/rf_pipeline.pkl</code>,
<code style="background:#e2e8f0;padding:2px 5px;border-radius:3px;
color:#1e293b">lr_pipeline.pkl</code>, and
<code style="background:#e2e8f0;padding:2px 5px;border-radius:3px;
color:#1e293b">forecast_metadata.json</code>.
</div>
</div>
"""

    def _stage_badge(stage):
        # Background and border share one fallback colour; the text colour
        # has its own fallback.
        tint = STAGE_COLORS.get(stage, "#607d8b")
        text_color = STAGE_COLORS.get(stage, "#555")
        icon = STAGE_EMOJI.get(stage, "")
        return (
            f'<span style="background:{tint}18;'
            f'color:{text_color};padding:4px 12px;'
            f'border-radius:20px;border:1px solid {tint}44;'
            f'font-size:13px;font-weight:600">{icon} {stage}</span>'
        )

    feature_count = len(_metadata.get("feature_names", []))
    stages = _metadata.get("stage_classes", ["pre", "peri", "post"])
    badges = "".join(_stage_badge(stage) for stage in stages)

    return f"""
<div class="status-card">
<div style="display:flex;align-items:center;gap:10px;margin-bottom:14px">
<span style="font-size:24px">✅</span>
<div>
<div style="color:#059669;font-size:16px;font-weight:700">
Models Loaded Successfully
</div>
<div style="color:#6b7280;font-size:12px">Ready for predictions</div>
</div>
</div>
<div class="stat-grid-3">
<div class="stat-item">
<div class="stat-label">Features</div>
<div class="stat-value">{feature_count}</div>
</div>
<div class="stat-item">
<div class="stat-label">Models</div>
<div class="stat-value">2</div>
</div>
<div class="stat-item">
<div class="stat-label">Stages</div>
<div class="stat-value">{len(stages)}</div>
</div>
</div>
<div style="margin-top:14px">
<div style="color:#6b7280;font-size:11px;text-transform:uppercase;
letter-spacing:0.5px;margin-bottom:6px">Available Stages</div>
<div style="display:flex;gap:8px;flex-wrap:wrap">{badges}</div>
</div>
</div>
"""
| # ── Education content ───────────────────────────────────────────────────────── | |
# Static HTML for the educational content: an overview of menopause, the
# three stages, a symptom-by-stage table, a note about this tool and a
# disclaimer. It is a plain (non-f) string with no runtime substitutions.
# Layout relies on the .edu-card, .stage-cards-grid, .stage-card-* and
# .disclaimer-box classes styled in CUSTOM_CSS.
EDUCATION_HTML = """
<div class="edu-card">
<h2>🌸 Understanding Menopause</h2>
<p>Menopause is a natural biological process marking the end of menstrual cycles.
It is officially diagnosed after 12 consecutive months without a menstrual period
and typically occurs in women in their late 40s to early 50s.</p>
<h3>Three Stages</h3>
<div class="stage-cards-grid">
<div class="stage-card-pre">
<div style="color:#16a34a;font-weight:700;margin-bottom:8px">🟢 Pre-Menopause</div>
<p style="font-size:13px;margin:0;color:#374151">Regular ovarian function. Periods are predictable.
Hormones (estrogen, progesterone) follow a consistent monthly pattern.</p>
</div>
<div class="stage-card-peri">
<div style="color:#d97706;font-weight:700;margin-bottom:8px">🟡 Peri-Menopause</div>
<p style="font-size:13px;margin:0;color:#374151">Transition phase — usually begins in the mid-40s.
Hormone levels fluctuate. Periods become irregular.
Hot flashes and sleep issues may begin.</p>
</div>
<div class="stage-card-post">
<div style="color:#7c3aed;font-weight:700;margin-bottom:8px">🟣 Post-Menopause</div>
<p style="font-size:13px;margin:0;color:#374151">12+ months after the last period.
Lower estrogen levels. Risk factors for osteoporosis and
cardiovascular disease increase.</p>
</div>
</div>
<h3>Common Symptoms by Stage</h3>
<table style="width:100%;border-collapse:collapse;font-size:13px">
<thead>
<tr style="background:#f8fafc">
<th style="padding:8px;text-align:left;color:#6b7280;font-weight:600">Symptom</th>
<th style="padding:8px;text-align:center;color:#16a34a;font-weight:600">Pre</th>
<th style="padding:8px;text-align:center;color:#d97706;font-weight:600">Peri</th>
<th style="padding:8px;text-align:center;color:#7c3aed;font-weight:600">Post</th>
</tr>
</thead>
<tbody>
<tr style="border-bottom:1px solid #e2e8f0">
<td style="padding:8px;color:#374151">Hot flashes</td>
<td style="text-align:center;color:#9ca3af">–</td>
<td style="text-align:center">✅</td>
<td style="text-align:center">✅</td>
</tr>
<tr style="border-bottom:1px solid #e2e8f0">
<td style="padding:8px;color:#374151">Irregular periods</td>
<td style="text-align:center;color:#9ca3af">–</td>
<td style="text-align:center">✅</td>
<td style="text-align:center;color:#9ca3af">N/A</td>
</tr>
<tr style="border-bottom:1px solid #e2e8f0">
<td style="padding:8px;color:#374151">Sleep disturbances</td>
<td style="text-align:center;color:#6b7280">Mild</td>
<td style="text-align:center">✅</td>
<td style="text-align:center">✅</td>
</tr>
<tr style="border-bottom:1px solid #e2e8f0">
<td style="padding:8px;color:#374151">Mood changes</td>
<td style="text-align:center;color:#6b7280">PMS</td>
<td style="text-align:center">✅</td>
<td style="text-align:center;color:#6b7280">Possible</td>
</tr>
<tr style="border-bottom:1px solid #e2e8f0">
<td style="padding:8px;color:#374151">Vaginal dryness</td>
<td style="text-align:center;color:#9ca3af">–</td>
<td style="text-align:center;color:#6b7280">Possible</td>
<td style="text-align:center">✅</td>
</tr>
<tr>
<td style="padding:8px;color:#374151">Bone density changes</td>
<td style="text-align:center;color:#9ca3af">–</td>
<td style="text-align:center;color:#6b7280">Begins</td>
<td style="text-align:center">✅</td>
</tr>
</tbody>
</table>
<h3>About This Tool</h3>
<p style="font-size:13px">This application uses machine learning models trained on the
SWAN (Study of Women's Health Across the Nation) dataset — a landmark multisite,
multiethnic longitudinal study. The models were trained on self-reported symptom and
behavioral data to predict menopausal stage.</p>
<div class="disclaimer-box">
⚠️ <strong style="color:#d97706">Disclaimer:</strong>
This tool is for educational and research purposes only.
Predictions should not substitute clinical diagnosis.
Always consult a qualified healthcare provider for medical advice.
</div>
</div>
"""
| # ── Gradio UI ───────────────────────────────────────────────────────────────── | |
# Stylesheet passed to gr.Blocks(css=...) in build_app(). It pins the light
# colour scheme (including overrides for a `body.dark` state), then styles the
# custom HTML fragments this module emits: header banner, info boxes,
# result/stat cards, output-path box, code blocks, setup card, education
# card, feature table, status card and footer. The two @media blocks at the
# end collapse the grids and reduce padding on narrow screens.
CUSTOM_CSS = """
/* ── Force light mode — disable Gradio dark theme entirely ───────────── */
:root {
color-scheme: light only !important;
}
/* Fallback: if Gradio somehow sets .dark, override every key variable */
body.dark,
body.dark .gradio-container {
--body-background-fill: #f0f4f8 !important;
--background-fill-primary: #ffffff !important;
--background-fill-secondary: #f8fafc !important;
--border-color-primary: #e2e8f0 !important;
--border-color-accent: #3b82f6 !important;
--color-accent: #2563eb !important;
--color-accent-soft: #eff6ff !important;
--input-background-fill: #ffffff !important;
--input-border-color: #d1d5db !important;
--label-text-color: #374151 !important;
--block-label-text-color: #374151 !important;
--block-title-text-color: #111827 !important;
--body-text-color: #111827 !important;
--body-text-color-subdued: #6b7280 !important;
--link-text-color: #2563eb !important;
--button-primary-background-fill: #2563eb !important;
--button-primary-text-color: #ffffff !important;
--button-secondary-background-fill: #ffffff !important;
--button-secondary-text-color: #374151 !important;
--tab-text-color: #374151 !important;
--tab-text-color-selected: #2563eb !important;
color: #111827 !important;
background-color: #f0f4f8 !important;
}
/* ── Core ────────────────────────────────────────────────────────────── */
.gradio-container {
max-width: 1200px !important;
margin: 0 auto !important;
font-family: 'Segoe UI', system-ui, -apple-system, sans-serif !important;
background: #f0f4f8 !important;
}
/* ── Header banner ──────────────────────────────────────────────────── */
.header-banner {
background: linear-gradient(135deg, #faf5ff 0%, #fff0f9 50%, #eff6ff 100%);
border: 1px solid #e9d5ff;
border-radius: 16px;
padding: 28px 32px;
margin-bottom: 20px;
box-shadow: 0 2px 8px rgba(139,92,246,0.08);
position: relative;
overflow: hidden;
}
.header-banner::before {
content: '';
position: absolute;
top: -40%; right: -5%;
width: 280px; height: 280px;
background: radial-gradient(circle, rgba(139,92,246,0.08) 0%, transparent 70%);
pointer-events: none;
}
/* ── Reusable info boxes ─────────────────────────────────────────────── */
.info-box {
background: #f8fafc;
border: 1px solid #e2e8f0;
border-left: 3px solid #3b82f6;
border-radius: 8px;
padding: 12px 16px;
color: #475569;
font-size: 13px;
margin-bottom: 16px;
line-height: 1.5;
}
.info-box code {
background: #e2e8f0;
color: #1e293b;
padding: 1px 5px;
border-radius: 3px;
font-family: monospace;
font-size: 0.9em;
}
.section-label {
color: #2563eb;
font-size: 12px;
font-weight: 700;
text-transform: uppercase;
letter-spacing: 0.6px;
margin-bottom: 10px;
margin-top: 10px;
}
.format-hint {
background: #f8fafc;
border: 1px solid #e2e8f0;
border-radius: 8px;
padding: 14px;
margin-top: 10px;
font-size: 12px;
color: #475569;
}
.format-hint-title { color: #2563eb; font-weight: 600; margin-bottom: 6px; }
.format-hint pre { color: #475569; margin: 0; font-size: 11px; white-space: pre-wrap; }
.format-hint-note { color: #94a3b8; font-size: 11px; margin-top: 8px; }
.placeholder-msg { color: #9ca3af; text-align: center; padding: 40px; font-size: 14px; }
.section-divider { border: none; border-top: 1px solid #e2e8f0; margin: 24px 0; }
.batch-section-label { color: #2563eb; font-size: 14px; font-weight: 600; margin-bottom: 12px; }
/* ── Result & summary cards ─────────────────────────────────────────── */
.result-card {
background: #ffffff;
border: 1px solid #e2e8f0;
border-radius: 16px;
padding: 24px;
box-shadow: 0 1px 4px rgba(0,0,0,0.06);
font-family: 'Segoe UI', system-ui, sans-serif;
}
.stat-grid-3 { display:grid; grid-template-columns:repeat(3,1fr); gap:12px; margin:14px 0; }
.stat-grid-2 { display:grid; grid-template-columns:1fr 1fr; gap:10px; margin-top:10px; }
.stat-item { background:#f8fafc; border:1px solid #e2e8f0; padding:12px; border-radius:8px; text-align:center; }
.stat-label { color:#6b7280; font-size:11px; text-transform:uppercase; letter-spacing:0.4px; }
.stat-value { color:#111827; font-size:22px; font-weight:700; line-height:1.2; margin-top:2px; }
.output-path-box { background:#f0fdf4; border:1px solid #bbf7d0; border-radius:8px; padding:10px 14px; margin-top:12px; font-family:monospace; }
.output-path-title { color:#059669; font-size:12px; font-weight:600; }
.output-path-dir { color:#065f46; font-size:11px; margin-top:4px; }
.output-path-files { color:#6b7280; font-size:10px; margin-top:4px; line-height:1.6; }
/* ── Code blocks ────────────────────────────────────────────────────── */
.code-block {
background: #1e293b;
color: #a3e635;
border-radius: 8px;
padding: 12px;
font-size: 12px;
font-family: monospace;
white-space: pre;
overflow-x: auto;
}
/* ── Setup instructions card ─────────────────────────────────────────── */
.setup-card { background:#ffffff; border:1px solid #e2e8f0; border-radius:12px; padding:20px; margin-top:16px; font-family:'Segoe UI',system-ui,sans-serif; }
.setup-title { color:#111827; font-size:15px; font-weight:700; margin-bottom:12px; }
.setup-step { color:#374151; font-size:13px; line-height:1.8; }
.setup-step strong { color:#2563eb; }
/* ── Education ──────────────────────────────────────────────────────── */
.edu-card { background:#ffffff; border:1px solid #e2e8f0; border-radius:16px; padding:28px; font-family:'Segoe UI',system-ui,sans-serif; color:#374151; line-height:1.7; }
.edu-card h2 { color:#111827; font-size:22px; margin-top:0; }
.edu-card h3 { color:#7c3aed; font-size:16px; margin-top:20px; }
.stage-cards-grid { display:grid; grid-template-columns:repeat(3,1fr); gap:16px; margin:14px 0; }
.stage-card-pre { background:#f0fdf4; border-top:4px solid #16a34a; padding:16px; border-radius:10px; }
.stage-card-peri { background:#fffbeb; border-top:4px solid #d97706; padding:16px; border-radius:10px; }
.stage-card-post { background:#faf5ff; border-top:4px solid #7c3aed; padding:16px; border-radius:10px; }
.disclaimer-box { background:#fffbeb; border-left:3px solid #d97706; padding:12px 16px; border-radius:0 8px 8px 0; margin-top:14px; font-size:13px; color:#374151; }
/* ── Feature reference table ────────────────────────────────────────── */
.feature-table-wrap { background:#ffffff; border:1px solid #e2e8f0; border-radius:12px; padding:20px; max-height:500px; overflow-y:auto; font-family:'Segoe UI',system-ui,sans-serif; }
.feature-table-wrap table { width:100%; border-collapse:collapse; }
.feature-table-wrap thead tr { background:#f8fafc; }
.feature-table-wrap th { padding:8px; color:#6b7280; font-size:11px; text-align:left; text-transform:uppercase; letter-spacing:0.4px; }
.feature-table-wrap tr { border-bottom:1px solid #e2e8f0; }
.feature-table-wrap td { padding:8px; }
.feature-code { color:#2563eb; font-family:monospace; font-size:13px; }
.feature-desc { color:#374151; font-size:12px; }
.feature-num { color:#9ca3af; font-size:12px; }
/* ── Model status card ──────────────────────────────────────────────── */
.status-card { background:#ffffff; border:1px solid #e2e8f0; border-radius:12px; padding:20px; font-family:'Segoe UI',system-ui,sans-serif; }
/* ── Footer ─────────────────────────────────────────────────────────── */
.app-footer { text-align:center; color:#9ca3af; font-size:11px; margin-top:24px; padding:16px; border-top:1px solid #e2e8f0; }
.app-footer a { color:#2563eb; text-decoration:none; }
/* ── Responsive — Tablet (≤ 768 px) ────────────────────────────────── */
@media (max-width: 768px) {
.gradio-container { padding: 8px !important; }
.header-banner { padding: 16px 20px !important; margin-bottom: 12px !important; }
.header-status-badge { display: none !important; }
.stat-grid-3 { grid-template-columns: 1fr !important; }
.stat-grid-2 { grid-template-columns: 1fr !important; }
.stage-cards-grid { grid-template-columns: 1fr !important; }
}
/* ── Responsive — Mobile (≤ 480 px) ────────────────────────────────── */
@media (max-width: 480px) {
.header-banner h1 { font-size: 18px !important; }
.result-card { padding: 16px !important; }
.edu-card { padding: 16px !important; }
.setup-card { padding: 14px !important; }
}
"""
# Page header banner. The status badge colour and text are fixed once, at
# import time, from _MODEL_OK.
HEADER_HTML = f"""
<div class="header-banner">
<div style="display:flex;align-items:center;gap:16px;flex-wrap:wrap">
<div style="font-size:48px;flex-shrink:0">🌸</div>
<div style="flex:1;min-width:200px">
<h1 style="margin:0;font-size:26px;font-weight:800;
background:linear-gradient(135deg,#7c3aed,#db2777);
-webkit-background-clip:text;-webkit-text-fill-color:transparent">
SWAN Menopause Prediction
</h1>
<p style="margin:4px 0 0;color:#6b7280;font-size:13px">
AI-powered menopausal stage prediction & symptom forecasting ·
Based on the SWAN dataset
</p>
</div>
<div class="header-status-badge" style="text-align:right;flex-shrink:0">
<div style="background:#ffffff;border:1px solid #e2e8f0;border-radius:8px;
padding:8px 16px;display:inline-block;box-shadow:0 1px 3px rgba(0,0,0,0.06)">
<div style="color:#9ca3af;font-size:10px;text-transform:uppercase;letter-spacing:1px">
Status
</div>
<div style="color:{'#059669' if _MODEL_OK else '#dc2626'};font-size:13px;font-weight:600">{'Models Ready ✅' if _MODEL_OK else 'Models Needed ⚠️'}</div>
</div>
</div>
</div>
</div>
"""
# ── Force-light-mode JS (runs on every page load) ─────────────────────────────
# JavaScript function source passed to gr.Blocks(js=...) in build_app().
# It does three things:
#   1. removes the `dark` class from <body> if it is present;
#   2. writes 'light' to the localStorage key 'theme' (wrapped in try/catch so
#      a blocked localStorage does not break the page);
#   3. installs a MutationObserver on <body>'s class attribute that strips
#      `dark` again whenever the class list changes.
# Intended to keep the light theme on HuggingFace Spaces regardless of the
# user's OS/browser dark-mode setting.
FORCE_LIGHT_JS = """
function() {
const forceLightMode = () => {
if (document.body.classList.contains('dark')) {
document.body.classList.remove('dark');
}
};
// Apply immediately
forceLightMode();
// Lock Gradio's stored preference
try { localStorage.setItem('theme', 'light'); } catch(e) {}
// Watch for Gradio trying to re-add .dark and block it
new MutationObserver(function(mutations) {
mutations.forEach(function(m) {
if (m.attributeName === 'class') forceLightMode();
});
}).observe(document.body, { attributes: true, attributeFilter: ['class'] });
}
"""
| # ── App builder ─────────────────────────────────────────────────────────────── | |
| def build_app(): | |
| with gr.Blocks( | |
| css = CUSTOM_CSS, | |
| js = FORCE_LIGHT_JS, | |
| title = "SWAN Menopause Prediction", | |
| theme = gr.themes.Soft( | |
| primary_hue = "blue", | |
| neutral_hue = "slate", | |
| ).set( | |
| # ── Body ────────────────────────────────────────────────────── | |
| body_background_fill = "#f0f4f8", | |
| body_background_fill_dark = "#f0f4f8", | |
| body_text_color = "#111827", | |
| body_text_color_dark = "#111827", | |
| body_text_color_subdued = "#6b7280", | |
| body_text_color_subdued_dark = "#6b7280", | |
| # ── Panel / block backgrounds ────────────────────────────────── | |
| background_fill_primary = "#ffffff", | |
| background_fill_primary_dark = "#ffffff", | |
| background_fill_secondary = "#f8fafc", | |
| background_fill_secondary_dark = "#f8fafc", | |
| block_background_fill = "#ffffff", | |
| block_background_fill_dark = "#ffffff", | |
| block_border_color = "#e2e8f0", | |
| block_border_color_dark = "#e2e8f0", | |
| block_label_background_fill = "#f8fafc", | |
| block_label_background_fill_dark= "#f8fafc", | |
| block_label_text_color = "#374151", | |
| block_label_text_color_dark = "#374151", | |
| block_title_text_color = "#111827", | |
| block_title_text_color_dark = "#111827", | |
| # ── Inputs ──────────────────────────────────────────────────── | |
| input_background_fill = "#ffffff", | |
| input_background_fill_dark = "#ffffff", | |
| input_background_fill_focus = "#ffffff", | |
| input_background_fill_focus_dark= "#ffffff", | |
| input_border_color = "#d1d5db", | |
| input_border_color_dark = "#d1d5db", | |
| input_border_color_focus = "#3b82f6", | |
| input_border_color_focus_dark = "#3b82f6", | |
| input_placeholder_color = "#9ca3af", | |
| input_placeholder_color_dark = "#9ca3af", | |
| # ── Borders ──────────────────────────────────────────────────── | |
| border_color_primary = "#e2e8f0", | |
| border_color_primary_dark = "#e2e8f0", | |
| border_color_accent = "#3b82f6", | |
| border_color_accent_dark = "#3b82f6", | |
| # ── Buttons ──────────────────────────────────────────────────── | |
| button_primary_background_fill = "#2563eb", | |
| button_primary_background_fill_dark = "#2563eb", | |
| button_primary_background_fill_hover = "#1d4ed8", | |
| button_primary_background_fill_hover_dark = "#1d4ed8", | |
| button_primary_text_color = "#ffffff", | |
| button_primary_text_color_dark = "#ffffff", | |
| button_secondary_background_fill = "#ffffff", | |
| button_secondary_background_fill_dark = "#ffffff", | |
| button_secondary_background_fill_hover = "#f1f5f9", | |
| button_secondary_background_fill_hover_dark="#f1f5f9", | |
| button_secondary_text_color = "#374151", | |
| button_secondary_text_color_dark = "#374151", | |
| button_secondary_border_color = "#e2e8f0", | |
| button_secondary_border_color_dark = "#e2e8f0", | |
| # ── Checkbox / Radio ────────────────────────────────────────── | |
| checkbox_background_color = "#ffffff", | |
| checkbox_background_color_dark = "#ffffff", | |
| checkbox_background_color_selected = "#2563eb", | |
| checkbox_background_color_selected_dark = "#2563eb", | |
| checkbox_border_color = "#d1d5db", | |
| checkbox_border_color_dark = "#d1d5db", | |
| checkbox_border_color_focus = "#3b82f6", | |
| checkbox_border_color_focus_dark = "#3b82f6", | |
| # ── Slider ──────────────────────────────────────────────────── | |
| slider_color = "#2563eb", | |
| slider_color_dark = "#2563eb", | |
| # ── Table ───────────────────────────────────────────────────── | |
| table_odd_background_fill = "#f8fafc", | |
| table_odd_background_fill_dark = "#f8fafc", | |
| table_even_background_fill = "#ffffff", | |
| table_even_background_fill_dark = "#ffffff", | |
| table_border_color = "#e2e8f0", | |
| table_border_color_dark = "#e2e8f0", | |
| # ── Links ───────────────────────────────────────────────────── | |
| link_text_color = "#2563eb", | |
| link_text_color_dark = "#2563eb", | |
| link_text_color_hover = "#1d4ed8", | |
| link_text_color_hover_dark = "#1d4ed8", | |
| link_text_color_visited = "#7c3aed", | |
| link_text_color_visited_dark = "#7c3aed", | |
| # ── Accent ──────────────────────────────────────────────────── | |
| color_accent_soft = "#eff6ff", | |
| color_accent_soft_dark = "#eff6ff", | |
| ), | |
| ) as app: | |
| gr.HTML(HEADER_HTML) | |
| with gr.Tabs(): | |
| # ── TAB 1: Single Stage Prediction ──────────────────────────────── | |
| with gr.Tab("🔮 Stage Prediction"): | |
| gr.HTML(""" | |
| <div class="info-box"> | |
| Fill in the fields below to predict menopausal stage for a single individual. | |
| All fields are optional — the pipeline handles missing values automatically. | |
| A timestamped output folder is created in | |
| <code>swan_ml_output/</code> for every run. | |
| </div>""") | |
| with gr.Row(): | |
| # ── Input column ────────────────────────────────────────── | |
| with gr.Column(scale=2): | |
| with gr.Group(): | |
| gr.HTML('<div class="section-label">Demographics</div>') | |
| with gr.Row(): | |
| age = gr.Slider( | |
| minimum=35, maximum=75, value=48, step=1, | |
| label="Age (AGE7)", | |
| ) | |
| race = gr.Dropdown( | |
| choices=[1, 2, 3, 4, 5], value=1, | |
| label="Race (RACE)", | |
| info="1=White, 2=Black, 3=Chinese, 4=Japanese, 5=Hispanic", | |
| ) | |
| langint = gr.Dropdown( | |
| choices=[1, 2, 3], value=1, | |
| label="Interview Language (LANGINT7)", | |
| info="1=English, 2=Spanish, 3=Other", | |
| ) | |
| with gr.Group(): | |
| gr.HTML('<div class="section-label">Vasomotor Symptoms</div>') | |
| with gr.Row(): | |
| hot_flash = gr.Slider( | |
| minimum=1, maximum=5, value=1, step=1, | |
| label="Hot Flash Severity (HOTFLAS7)", | |
| info="1=None, 5=Very severe", | |
| ) | |
| num_hot_flash = gr.Slider( | |
| minimum=0, maximum=15, value=0, step=1, | |
| label="# Hot Flashes/Week (NUMHOTF7)", | |
| ) | |
| bothersome_hf = gr.Slider( | |
| minimum=1, maximum=4, value=1, step=1, | |
| label="How Bothersome (BOTHOTF7)", | |
| info="1=Not at all, 4=Extremely", | |
| ) | |
| with gr.Group(): | |
| gr.HTML('<div class="section-label">Sleep & Mood</div>') | |
| with gr.Row(): | |
| sleep_quality = gr.Slider( | |
| minimum=1, maximum=5, value=2, step=1, | |
| label="Sleep Quality (SLEEPQL7)", | |
| info="1=Very good, 5=Very poor", | |
| ) | |
| depression = gr.Slider( | |
| minimum=0, maximum=4, value=0, step=1, | |
| label="Depression Indicator (DEPRESS7)", | |
| info="0=No, higher=more severe", | |
| ) | |
| with gr.Row(): | |
| mood_change = gr.Slider( | |
| minimum=1, maximum=5, value=1, step=1, | |
| label="Mood Changes (MOODCHG7)", | |
| info="1=None, 5=Severe", | |
| ) | |
| irritability = gr.Slider( | |
| minimum=1, maximum=5, value=1, step=1, | |
| label="Irritability (IRRITAB7)", | |
| ) | |
| with gr.Group(): | |
| gr.HTML('<div class="section-label">Physical & Gynaecological</div>') | |
| with gr.Row(): | |
| pain = gr.Slider( | |
| minimum=0, maximum=5, value=0, step=1, | |
| label="Pain Indicator (PAIN17)", | |
| ) | |
| abbleed = gr.Dropdown( | |
| choices=[0, 1, 2], value=0, | |
| label="Abnormal Bleeding (ABBLEED7)", | |
| info="0=No, 1=Yes, 2=Unsure", | |
| ) | |
| with gr.Row(): | |
| vaginal_dryness = gr.Slider( | |
| minimum=0, maximum=5, value=0, step=1, | |
| label="Vaginal Dryness (VAGINDR7)", | |
| ) | |
| lmp_day = gr.Number( | |
| value=None, | |
| label="LMP Day (LMPDAY7)", | |
| info="Day of last menstrual period (optional)", | |
| ) | |
| model_choice = gr.Radio( | |
| choices=["RandomForest", "LogisticRegression"], | |
| value="RandomForest", | |
| label="Model", | |
| info="RandomForest: higher accuracy | " | |
| "LogisticRegression: more interpretable", | |
| ) | |
| predict_btn = gr.Button( | |
| "🔮 Predict Stage", variant="primary", size="lg" | |
| ) | |
| # ── Output column ───────────────────────────────────────── | |
| with gr.Column(scale=3): | |
| result_html = gr.HTML( | |
| '<div class="placeholder-msg">Fill in the form and click Predict Stage</div>' | |
| ) | |
| result_chart = gr.Plot(label="Stage Probabilities") | |
| confidence_note = gr.Textbox( | |
| label="Confidence Note", interactive=False, lines=2 | |
| ) | |
| compare_html = gr.HTML() | |
| stage_download = gr.File( | |
| label="Download Prediction CSV", interactive=False | |
| ) | |
| predict_btn.click( | |
| fn = predict_single_stage, | |
| inputs = [ | |
| age, race, langint, | |
| hot_flash, num_hot_flash, bothersome_hf, | |
| sleep_quality, depression, mood_change, irritability, | |
| pain, abbleed, vaginal_dryness, lmp_day, | |
| model_choice, | |
| ], | |
| outputs = [ | |
| result_html, result_chart, confidence_note, | |
| compare_html, stage_download, | |
| ], | |
| ) | |
| # ── TAB 2: Batch Stage Prediction ───────────────────────────────── | |
| with gr.Tab("📁 Batch Stage Prediction"): | |
| gr.HTML(""" | |
| <div class="info-box"> | |
| Upload a CSV file with individual feature values for batch prediction. | |
| Results + charts + a summary report are saved to a timestamped folder | |
| inside <code>swan_ml_output/</code>. | |
| </div>""") | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| batch_file = gr.File( | |
| label="Upload stage_input.csv", | |
| file_types=[".csv"], | |
| ) | |
| batch_model = gr.Radio( | |
| choices=["RandomForest", "LogisticRegression"], | |
| value="RandomForest", | |
| label="Model", | |
| ) | |
| gr.HTML(""" | |
| <div class="format-hint"> | |
| <div class="format-hint-title">Expected CSV Format</div> | |
| <pre>individual,AGE7,RACE,HOTFLAS7,... | |
| Person_001,48,1,2,... | |
| Person_002,52,2,1,...</pre> | |
| <div class="format-hint-note"> | |
| See the test-csv/ folder for an approved example. | |
| </div> | |
| </div>""") | |
| batch_predict_btn = gr.Button( | |
| "🚀 Run Batch Prediction", variant="primary" | |
| ) | |
| with gr.Column(scale=2): | |
| batch_summary_html = gr.HTML( | |
| '<div class="placeholder-msg">Upload a CSV to begin</div>' | |
| ) | |
| batch_download = gr.File( | |
| label="Download Predictions CSV", interactive=False | |
| ) | |
| batch_results_df = gr.DataFrame( | |
| label="Results Preview (first 20 rows)", | |
| interactive=False, | |
| ) | |
| batch_predict_btn.click( | |
| fn = predict_batch_stage, | |
| inputs = [batch_file, batch_model], | |
| outputs = [batch_download, batch_summary_html, batch_results_df], | |
| ) | |
| # ── TAB 3: Symptom Forecast ─────────────────────────────────────── | |
| with gr.Tab("🌊 Symptom Forecast"): | |
| gr.HTML(""" | |
| <div class="info-box"> | |
| Predict hot flash and mood change probability based on cycle day | |
| (calculated from Last Menstrual Period date). | |
| All outputs are saved to a timestamped folder inside | |
| <code>swan_ml_output/</code>. | |
| </div>""") | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| sym_individual = gr.Textbox( | |
| label="Individual ID (optional)", | |
| placeholder="e.g., Patient_001", | |
| ) | |
| sym_lmp = gr.Textbox( | |
| label="Last Menstrual Period (LMP)", | |
| placeholder="2026-01-15 or 15 (day of month)", | |
| info="Full date (YYYY-MM-DD) or day-of-month integer", | |
| ) | |
| sym_date = gr.Textbox( | |
| label="Target Date (optional)", | |
| placeholder="2026-02-27 (defaults to today)", | |
| info="Date to forecast for (YYYY-MM-DD)", | |
| ) | |
| sym_cycle = gr.Slider( | |
| minimum=21, maximum=40, value=28, step=1, | |
| label="Cycle Length (days)", | |
| ) | |
| sym_predict_btn = gr.Button( | |
| "🌊 Forecast Symptoms", variant="primary" | |
| ) | |
| with gr.Column(scale=2): | |
| sym_result_html = gr.HTML( | |
| '<div class="placeholder-msg">Enter LMP date and click Forecast</div>' | |
| ) | |
| sym_chart = gr.Plot(label="Cycle Position") | |
| sym_download = gr.File( | |
| label="Download Forecast CSV", interactive=False | |
| ) | |
| sym_predict_btn.click( | |
| fn = predict_symptoms, | |
| inputs = [sym_individual, sym_lmp, sym_date, sym_cycle], | |
| outputs = [sym_result_html, sym_chart, sym_download], | |
| ) | |
| gr.HTML('<hr class="section-divider">') | |
| gr.HTML('<div class="batch-section-label">📁 Batch Symptom Forecasting</div>') | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| sym_batch_file = gr.File( | |
| label="Upload symptoms_input.csv", | |
| file_types=[".csv"], | |
| ) | |
| sym_lmp_col = gr.Textbox( | |
| label="LMP Column Name", value="LMP" | |
| ) | |
| sym_date_col = gr.Textbox( | |
| label="Date Column Name (optional)", value="date" | |
| ) | |
| sym_cycle_batch = gr.Slider( | |
| minimum=21, maximum=40, value=28, step=1, | |
| label="Default Cycle Length", | |
| ) | |
| sym_batch_btn = gr.Button( | |
| "🌊 Run Batch Forecast", variant="primary" | |
| ) | |
| with gr.Column(scale=2): | |
| sym_batch_summary = gr.HTML( | |
| '<div class="placeholder-msg">Upload a CSV to begin</div>' | |
| ) | |
| sym_batch_download = gr.File( | |
| label="Download Symptom Forecast CSV", interactive=False | |
| ) | |
| sym_batch_df = gr.DataFrame( | |
| label="Results Preview", | |
| interactive=False, | |
| ) | |
| sym_batch_btn.click( | |
| fn = predict_symptoms_batch, | |
| inputs = [ | |
| sym_batch_file, sym_lmp_col, | |
| sym_date_col, sym_cycle_batch, | |
| ], | |
| outputs = [sym_batch_download, sym_batch_summary, sym_batch_df], | |
| ) | |
| # ── TAB 4: Education ────────────────────────────────────────────── | |
| with gr.Tab("📚 Menopause Education"): | |
| gr.HTML(EDUCATION_HTML) | |
| # ── TAB 5: Feature Reference ────────────────────────────────────── | |
| with gr.Tab("🔬 Feature Reference"): | |
| gr.HTML(""" | |
| <div class="info-box"> | |
| Canonical list of features used by the trained models | |
| (from <code>forecast_metadata.json</code>). | |
| For batch CSV uploads, column names must match these feature names. | |
| </div>""") | |
| gr.HTML(get_feature_reference()) | |
| # ── TAB 6: Model Status ─────────────────────────────────────────── | |
| with gr.Tab("⚙️ Model Status"): | |
| gr.HTML(get_model_status()) | |
| gr.HTML(""" | |
| <div class="setup-card"> | |
| <div class="setup-title">🚀 Setup Instructions</div> | |
| <div class="setup-step"> | |
| <p><strong>Step 1 — Train models:</strong></p> | |
| <pre class="code-block">python menopause.py</pre> | |
| <p><strong>Step 2 — Verify artifacts:</strong></p> | |
| <pre class="code-block">ls swan_ml_output/ | |
| # rf_pipeline.pkl lr_pipeline.pkl forecast_metadata.json</pre> | |
| <p><strong>Step 3 — Run this app:</strong></p> | |
| <pre class="code-block">python app.py</pre> | |
| <p><strong>Step 4 — Deploy on Hugging Face Spaces:</strong></p> | |
| <pre class="code-block">git lfs install | |
| git lfs track "*.pkl" | |
| git add . | |
| git commit -m "SWAN menopause prediction app" | |
| git push</pre> | |
| <p><strong>Output folder structure (per run):</strong></p> | |
| <pre class="code-block">swan_ml_output/ | |
| <YYYYMMDD_HHMMSS>/ | |
| charts/ ← PNG visualizations | |
| predictions/ ← CSV result files | |
| reports/ ← TXT summary reports</pre> | |
| </div> | |
| </div> | |
| """) | |
| gr.HTML(""" | |
| <div class="app-footer"> | |
| SWAN Menopause Prediction App · Built with Gradio · | |
| For research & educational use only · Not for clinical diagnosis · | |
| <a href="https://www.swanstudy.org/" target="_blank">SWAN Study</a> | |
| </div>""") | |
| return app | |
| # ── Entry point ─────────────────────────────────────────────────────────────── | |
if __name__ == "__main__":
    # Script entry point: construct the Gradio app and start serving it.
    # "0.0.0.0" makes the server listen on all network interfaces rather than
    # only on localhost; the port is read from the PORT environment variable
    # and falls back to 7860 when that variable is not set.
    demo = build_app()
    demo.launch(
        server_name="0.0.0.0",
        server_port=int(os.getenv("PORT", "7860")),
        share=False,      # do not request a public share link
        show_error=True,  # passed through to Gradio's launch() unchanged
    )