techatcreated commited on
Commit
66d45ea
·
verified ·
1 Parent(s): 6f5fd0f
app.py ADDED
@@ -0,0 +1,1823 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SWAN Menopause Stage Prediction & Forecasting — Gradio UI
3
+ Hugging Face Spaces deployment-ready.
4
+
5
+ Run locally: python app.py
6
+ Deploy: Push to a HF Space with SDK=gradio
7
+
8
+ Output structure (per execution):
9
+ swan_ml_output/
10
+ <YYYYMMDD_HHMMSS>/
11
+ charts/ ← PNG visualizations
12
+ predictions/ ← CSV result files
13
+ reports/ ← TXT summary reports
14
+ """
15
+
16
+ import os
17
+ import json
18
+ import warnings
19
+ from datetime import datetime
20
+ from pathlib import Path
21
+ from typing import Optional
22
+
23
+ import numpy as np
24
+ import pandas as pd
25
+ import matplotlib
26
+ matplotlib.use("Agg")
27
+ import matplotlib.pyplot as plt
28
+
29
+ warnings.filterwarnings("ignore")
30
+
31
+ # ── Gradio ────────────────────────────────────────────────────────────────────
32
+ import gradio as gr
33
+
34
+ # ── Local ML module ───────────────────────────────────────────────────────────
35
+ try:
36
+ from menopause import (
37
+ MenopauseForecast,
38
+ SymptomCycleForecaster,
39
+ load_forecast_model,
40
+ )
41
+ _MODULE_AVAILABLE = True
42
+ except ImportError:
43
+ _MODULE_AVAILABLE = False
44
+
45
+ # ── Model loading ─────────────────────────────────────────────────────────────
46
+ FORECAST_DIR = os.environ.get("FORECAST_DIR", "swan_ml_output")
47
+ OUTPUT_BASE = Path(FORECAST_DIR)
48
+
49
+ _forecast: Optional[MenopauseForecast] = None # type: ignore[type-arg]
50
+ _metadata: dict = {}
51
+
52
+
53
def _load_models():
    """Attempt to load saved joblib pipelines. Returns (success, message)."""
    global _forecast, _metadata

    if not _MODULE_AVAILABLE:
        return False, "menopause.py not found. Make sure it is in the same directory."

    base = Path(FORECAST_DIR)
    artifacts = {
        "meta": base / "forecast_metadata.json",
        "rf": base / "rf_pipeline.pkl",
        "lr": base / "lr_pipeline.pkl",
    }

    # All three artifacts must exist before we attempt any deserialization.
    if any(not p.exists() for p in artifacts.values()):
        message = (
            f"Model artifacts not found in '{FORECAST_DIR}'. "
            "Run `python menopause.py` to train and save the models first."
        )
        return False, message

    try:
        _forecast = load_forecast_model(FORECAST_DIR)
        with open(artifacts["meta"]) as fh:
            _metadata = json.load(fh)
        feature_count = len(_metadata.get("feature_names", []))
        return True, f"✅ Models loaded — {feature_count} features"
    except Exception as exc:
        return False, f"Error loading models: {exc}"
78
+
79
+
80
+ _MODEL_OK, _MODEL_MSG = _load_models()
81
+
82
+
83
+ # ── Output directory management ───────────────────────────────────────────────
84
+
85
def _make_run_dir() -> Path:
    """Create and return a unique timestamped run directory under swan_ml_output/.

    Bug fix: the original used a second-granularity timestamp with
    ``exist_ok=True``, so two runs started within the same second silently
    shared one directory and overwrote each other's charts/predictions/reports.
    A numeric suffix is now appended until an unused directory name is found.
    """
    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    run_dir = OUTPUT_BASE / ts
    suffix = 1
    while run_dir.exists():
        run_dir = OUTPUT_BASE / f"{ts}_{suffix}"
        suffix += 1
    for sub in ("charts", "predictions", "reports"):
        (run_dir / sub).mkdir(parents=True, exist_ok=True)
    return run_dir
93
+
94
+
95
+ def _get_file_path(file_obj) -> Optional[str]:
96
+ """
97
+ Safely extract a file-system path from a Gradio file component value.
98
+
99
+ Gradio ≤ 3.x → returns a file-like object with a .name attribute.
100
+ Gradio 4.x → returns a str path (or NamedString subclass).
101
+ This helper handles both.
102
+ """
103
+ if file_obj is None:
104
+ return None
105
+ if hasattr(file_obj, "name"):
106
+ return file_obj.name
107
+ return str(file_obj)
108
+
109
+
110
+ # ── Constants & helpers ───────────────────────────────────────────────────────
111
+
112
+ STAGE_COLORS = {"pre": "#16a34a", "peri": "#d97706", "post": "#7c3aed"}
113
+ STAGE_EMOJI = {"pre": "🟢", "peri": "🟡", "post": "🟣"}
114
+ STAGE_LABELS = {
115
+ "pre": "Pre-Menopausal",
116
+ "peri": "Peri-Menopausal",
117
+ "post": "Post-Menopausal",
118
+ }
119
+
120
+ STAGE_INFO = {
121
+ "pre": {
122
+ "title": "Pre-Menopausal",
123
+ "description": "Regular menstrual cycles with typical hormonal fluctuations. Ovarian function is normal.",
124
+ "symptoms": ["Regular periods", "Normal hormone levels", "Potential mild PMS"],
125
+ "guidance": "Maintain regular check-ups. Track your cycle and note any changes.",
126
+ },
127
+ "peri": {
128
+ "title": "Peri-Menopausal (Transition)",
129
+ "description": "Hormonal changes begin — estrogen and progesterone levels fluctuate. Cycles become irregular.",
130
+ "symptoms": ["Irregular periods", "Hot flashes", "Sleep disturbances", "Mood changes", "Night sweats"],
131
+ "guidance": "Consult your healthcare provider. Lifestyle adjustments (diet, exercise, sleep) can help.",
132
+ },
133
+ "post": {
134
+ "title": "Post-Menopausal",
135
+ "description": "12+ months since last menstrual period. Estrogen remains at consistently lower levels.",
136
+ "symptoms": ["No periods", "Possible continued hot flashes", "Vaginal dryness", "Bone density changes"],
137
+ "guidance": "Focus on bone health, cardiovascular health, and regular screenings. Discuss HRT options.",
138
+ },
139
+ }
140
+
141
+ # Feature descriptions keyed by the model's canonical feature names
142
+ FEATURE_DESCRIPTIONS = {
143
+ "PAIN17": "Pain indicator (visit-specific)",
144
+ "PAINTW17": "Pain two-week indicator",
145
+ "PAIN27": "Secondary pain indicator",
146
+ "PAINTW27": "Secondary pain two-week indicator",
147
+ "SLEEP17": "Sleep disturbance pattern 1",
148
+ "SLEEP27": "Sleep disturbance pattern 2",
149
+ "BCOHOTH7": "Birth control — other method",
150
+ "EXERCIS7": "General exercise indicator",
151
+ "EXERHAR7": "Vigorous exercise",
152
+ "EXEROST7": "Osteoporosis exercise",
153
+ "EXERMEN7": "Exercise — mental health",
154
+ "EXERLOO7": "Exercise lookalike",
155
+ "EXERMEM7": "Exercise — memory",
156
+ "EXERPER7": "Exercise perception",
157
+ "EXERGEN7": "General exercise type",
158
+ "EXERWGH7": "Weight exercise",
159
+ "EXERADV7": "Exercise advice indicator",
160
+ "EXEROTH7": "Other exercise",
161
+ "EXERSPE7": "Specific exercise",
162
+ "ABBLEED7": "Abnormal bleeding (0=no, 1=yes)", # ← correct feature name
163
+ "BLEEDNG7": "Bleeding pattern",
164
+ "LMPDAY7": "Last menstrual period day",
165
+ "DEPRESS7": "Depression indicator",
166
+ "SEX17": "Sexual activity indicator 1",
167
+ "SEX27": "Sexual activity indicator 2",
168
+ "SEX37": "Sexual activity indicator 3",
169
+ "SEX47": "Sexual activity indicator 4",
170
+ "SEX57": "Sexual activity indicator 5",
171
+ "SEX67": "Sexual activity indicator 6",
172
+ "SEX77": "Sexual activity indicator 7",
173
+ "SEX87": "Sexual activity indicator 8",
174
+ "SEX97": "Sexual activity indicator 9",
175
+ "SEX107": "Sexual activity indicator 10",
176
+ "SEX117": "Sexual activity indicator 11",
177
+ "SEX127": "Sexual activity indicator 12",
178
+ "SMOKERE7": "Smoking status",
179
+ "HOTFLAS7": "Hot flash severity (1=none, 5=very severe)",
180
+ "NUMHOTF7": "Number of hot flashes per week",
181
+ "BOTHOTF7": "How bothersome are hot flashes",
182
+ "IRRITAB7": "Irritability level",
183
+ "VAGINDR7": "Vaginal dryness",
184
+ "MOODCHG7": "Mood change frequency",
185
+ "SLEEPQL7": "Sleep quality score",
186
+ "PHYSILL7": "Physical illness indicators",
187
+ "HOTHEAD7": "Hot flashes with headache",
188
+ "EXER12H7": "Exercise in last 12 hours",
189
+ "ALCO24H7": "Alcohol in last 24h",
190
+ "AGE7": "Age (years)",
191
+ "RACE": "Race (1=White, 2=Black, 3=Chinese, 4=Japanese, 5=Hispanic)",
192
+ "LANGINT7": "Interview language indicator",
193
+ }
194
+
195
+
196
+ def _confidence_color(conf: float) -> str:
197
+ if conf >= 0.8:
198
+ return "#16a34a"
199
+ elif conf >= 0.6:
200
+ return "#d97706"
201
+ return "#dc2626"
202
+
203
+
204
+ # ── Chart builders ────────────────────────────────────────────────────────────
205
+
206
def _make_proba_chart(
    probabilities: dict,
    predicted_stage: str,
    save_path: Optional[Path] = None,
) -> plt.Figure:
    """Render stage probabilities as a horizontal bar chart.

    The bar for *predicted_stage* gets a white outline. When *save_path*
    is given, the figure is additionally written to disk as a PNG.
    """
    fig, ax = plt.subplots(figsize=(6, 3.5))
    fig.patch.set_facecolor("#1a1a2e")
    ax.set_facecolor("#16213e")

    stage_names = list(probabilities.keys())
    pct_values = [probabilities[name] * 100 for name in stage_names]
    bar_colors = [STAGE_COLORS.get(name, "#607d8b") for name in stage_names]
    outlines = ["white" if name == predicted_stage else "none" for name in stage_names]
    outline_widths = [2.5 if name == predicted_stage else 0 for name in stage_names]

    bars = ax.barh(stage_names, pct_values, color=bar_colors, edgecolor=outlines,
                   linewidth=outline_widths, height=0.5, zorder=3)

    # Percentage labels; x-position clamped at 98 so text stays inside the axis.
    for rect, pct in zip(bars, pct_values):
        y_mid = rect.get_y() + rect.get_height() / 2
        ax.text(min(pct + 1, 98), y_mid, f"{pct:.1f}%",
                va="center", ha="left", color="white",
                fontsize=11, fontweight="bold")

    ax.set_yticks(range(len(stage_names)))
    ax.set_yticklabels([STAGE_LABELS.get(name, name) for name in stage_names],
                       color="white", fontsize=10)
    ax.set_xlim(0, 105)
    ax.tick_params(colors="white", labelsize=11)
    ax.spines[["top", "right", "left", "bottom"]].set_visible(False)
    ax.xaxis.set_visible(False)
    for spine in ax.spines.values():
        spine.set_color("#333")
    ax.set_title("Stage Probabilities", color="white", fontsize=12,
                 pad=10, fontweight="bold")
    ax.grid(axis="x", color="#333", linestyle="--", linewidth=0.5, zorder=0)
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches="tight",
                    facecolor=fig.get_facecolor())
    return fig
250
+
251
+
252
def _make_cycle_chart(
    cycle_day: int,
    cycle_length: int = 28,
    hot_prob: Optional[float] = None,
    mood_prob: Optional[float] = None,
    save_path: Optional[Path] = None,
) -> plt.Figure:
    """Circular (polar) visualization of the current cycle day.

    Parameters
    ----------
    cycle_day : 1-based day within the cycle; the marker is skipped when None.
    cycle_length : total days in the cycle (number of colored wedges).
    hot_prob, mood_prob : optional symptom probabilities appended to the title.
    save_path : when given, the figure is also saved as a PNG.
    """
    fig, ax = plt.subplots(figsize=(5, 5), subplot_kw=dict(polar=True))
    fig.patch.set_facecolor("#1a1a2e")
    ax.set_facecolor("#16213e")

    # One translucent wedge per cycle day, colored along a RdYlGn gradient.
    days = np.linspace(0, 2 * np.pi, cycle_length, endpoint=False)
    for i, d in enumerate(days):
        phase = i / cycle_length
        color = plt.cm.RdYlGn(1 - phase)
        ax.bar(d, 1, width=2 * np.pi / cycle_length * 0.9,
               bottom=0.5, color=color, alpha=0.4, zorder=1)

    if cycle_day is not None:
        angle = (cycle_day - 1) / cycle_length * 2 * np.pi
        ax.scatter([angle], [1.05], s=200, color="#ff6b6b", zorder=5, linewidths=2)
        ax.annotate(
            f"Day {cycle_day}",
            xy=(angle, 1.05), xytext=(0, 0),
            textcoords="offset points", ha="center", va="center",
            color="white", fontsize=12, fontweight="bold",
        )

    ax.set_rticks([])
    ax.set_xticks([i * 2 * np.pi / 4 for i in range(4)])
    ax.set_xticklabels(["Day 1", "Day 7", "Day 14", "Day 21"],
                       color="#aaa", fontsize=9)
    ax.set_yticklabels([])
    ax.spines["polar"].set_color("#333")
    ax.grid(color="#333", linewidth=0.5)

    title = "Cycle Position"
    # Bug fix: the original guarded only on hot_prob but formatted mood_prob
    # too, raising TypeError when mood_prob was None. Each probability is now
    # appended independently; both present reproduces the original title.
    extras = []
    if hot_prob is not None:
        extras.append(f"🔥 {hot_prob:.0%}")
    if mood_prob is not None:
        extras.append(f"😤 {mood_prob:.0%}")
    if extras:
        title += "\n" + " ".join(extras)
    ax.set_title(title, color="white", fontsize=11, pad=20, fontweight="bold")
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches="tight",
                    facecolor=fig.get_facecolor())
    return fig
299
+
300
+
301
def _make_batch_summary_chart(results_df: pd.DataFrame,
                              save_path: Optional[Path] = None) -> None:
    """Stage-distribution pie + confidence histogram for batch runs.

    Saves a PNG when *save_path* is given and always closes the figure
    (this chart is only written to disk, never returned to the UI).
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
    fig.patch.set_facecolor("#1a1a2e")

    # Stage distribution pie
    stage_counts = results_df["predicted_stage"].value_counts()
    colors = [STAGE_COLORS.get(s, "#607d8b") for s in stage_counts.index]
    ax1.set_facecolor("#16213e")
    wedges, texts, autotexts = ax1.pie(
        stage_counts.values, labels=stage_counts.index,
        colors=colors, autopct="%1.0f%%",
        textprops={"color": "white", "fontsize": 10},
    )
    for at in autotexts:
        at.set_color("white")
    ax1.set_title("Stage Distribution", color="white", fontsize=11, fontweight="bold")

    # Confidence histogram
    ax2.set_facecolor("#16213e")
    if "confidence" in results_df.columns:
        conf = results_df["confidence"].dropna()
        # Bug fix: with an all-NaN confidence column `conf` is empty and
        # bins=min(10, 0) == 0 made ax2.hist raise ValueError. Only draw
        # the histogram when at least one value survives dropna().
        if not conf.empty:
            ax2.hist(conf, bins=min(10, len(conf)), color="#3B82F6",
                     edgecolor="#1a1a2e", alpha=0.8)
            ax2.axvline(0.8, color="#4CAF50", linestyle="--",
                        linewidth=1.5, label="High (0.80)")
            ax2.axvline(0.6, color="#FF9800", linestyle="--",
                        linewidth=1.5, label="Med (0.60)")
            ax2.legend(fontsize=8, labelcolor="white", facecolor="#0d0d1a")
    ax2.set_xlabel("Confidence", color="#aaa", fontsize=9)
    ax2.set_ylabel("Count", color="#aaa", fontsize=9)
    ax2.tick_params(colors="white", labelsize=9)
    for sp in ["top", "right"]:
        ax2.spines[sp].set_visible(False)
    for sp in ["left", "bottom"]:
        ax2.spines[sp].set_color("#333")
    ax2.set_title("Confidence Distribution", color="white",
                  fontsize=11, fontweight="bold")

    fig.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches="tight",
                    facecolor=fig.get_facecolor())
    plt.close(fig)
346
+
347
+
348
+ # ── Text report writers ───────────────────────────────────────────────────────
349
+
350
def _write_single_stage_report(
    path: Path,
    stage: str,
    confidence: float,
    probabilities: dict,
    model: str,
    comparison: dict,
    input_features: dict,
):
    """Write the single-prediction TXT summary report to *path* (UTF-8)."""
    rule = "=" * 60
    stamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    out = [
        rule,
        "SWAN MENOPAUSE STAGE PREDICTION REPORT",
        f"Generated : {stamp}",
        rule,
        "",
        f"Predicted Stage : {STAGE_LABELS.get(stage, stage)}",
        f"Model : {model}",
        f"Confidence : {confidence:.1%}",
        "",
        "Stage Probabilities:",
    ]
    # One ASCII-bar line per stage (20 chars == probability 1.0).
    for name, prob in probabilities.items():
        out.append(f" {name:<6} : {prob:.4f} {'█' * int(prob * 20)}")

    rf = comparison["RandomForest"]
    lr = comparison["LogisticRegression"]
    out += [
        "",
        "Model Comparison:",
        f" RandomForest → {rf['stage']} ({rf.get('confidence', 0):.1%})",
        f" LogisticRegression → {lr['stage']} ({lr.get('confidence', 0):.1%})",
        "",
        "Input Features (non-NaN):",
    ]
    for key, value in input_features.items():
        missing = value is None or (isinstance(value, float) and np.isnan(value))
        if not missing:
            out.append(f" {key:<12} = {value}")

    out += [
        "",
        "⚠️ For research/educational use only. Not a clinical diagnosis.",
        rule,
    ]
    path.write_text("\n".join(out), encoding="utf-8")
393
+
394
+
395
+ def _write_batch_report(
396
+ path: Path,
397
+ results: pd.DataFrame,
398
+ model: str,
399
+ run_dir: Path,
400
+ ):
401
+ total = len(results)
402
+ dist = results["predicted_stage"].value_counts().to_dict() \
403
+ if "predicted_stage" in results.columns else {}
404
+ if "confidence" in results.columns:
405
+ conf = results["confidence"]
406
+ mean_c = conf.mean(); min_c = conf.min(); max_c = conf.max()
407
+ high = int((conf > 0.8).sum())
408
+ medium = int(((conf > 0.6) & (conf <= 0.8)).sum())
409
+ low = int((conf <= 0.6).sum())
410
+ else:
411
+ mean_c = min_c = max_c = high = medium = low = 0
412
+
413
+ lines = [
414
+ "=" * 60,
415
+ "SWAN BATCH STAGE PREDICTION REPORT",
416
+ f"Generated : {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
417
+ f"Model : {model}",
418
+ "=" * 60,
419
+ "",
420
+ f"Total Individuals : {total}",
421
+ "",
422
+ "Stage Distribution:",
423
+ ]
424
+ for stage in ["pre", "peri", "post"]:
425
+ count = dist.get(stage, 0)
426
+ pct = count / total * 100 if total else 0
427
+ lines.append(f" {stage:<6} : {count} ({pct:.1f}%)")
428
+ lines += [
429
+ "",
430
+ "Confidence Scores:",
431
+ f" Mean : {mean_c:.4f}",
432
+ f" Min : {min_c:.4f}",
433
+ f" Max : {max_c:.4f}",
434
+ "",
435
+ "Confidence Distribution:",
436
+ f" High (>0.80) : {high}/{total} ({high/total*100:.1f}%)" if total else " N/A",
437
+ f" Medium (0.60-0.80) : {medium}/{total} ({medium/total*100:.1f}%)" if total else " N/A",
438
+ f" Low (≤0.60) : {low}/{total} ({low/total*100:.1f}%)" if total else " N/A",
439
+ "",
440
+ f"Output Directory : {run_dir}",
441
+ "",
442
+ "⚠️ For research/educational use only. Not a clinical diagnosis.",
443
+ "=" * 60,
444
+ ]
445
+ path.write_text("\n".join(lines), encoding="utf-8")
446
+
447
+
448
+ def _write_symptom_report(
449
+ path: Path,
450
+ individual_id: str,
451
+ lmp: str,
452
+ target_date: str,
453
+ cycle_day: int,
454
+ cycle_length: int,
455
+ hot_prob: float,
456
+ hot_pred: bool,
457
+ mood_prob: float,
458
+ mood_pred: bool,
459
+ ):
460
+ hp = float(hot_prob) if (hot_prob is not None and not np.isnan(hot_prob)) else 0.0
461
+ mp = float(mood_prob) if (mood_prob is not None and not np.isnan(mood_prob)) else 0.0
462
+ lines = [
463
+ "=" * 60,
464
+ "SWAN SYMPTOM CYCLE FORECAST REPORT",
465
+ f"Generated : {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
466
+ "=" * 60,
467
+ "",
468
+ f"Individual : {individual_id or 'N/A'}",
469
+ f"LMP : {lmp}",
470
+ f"Target Date : {target_date or 'Today'}",
471
+ f"Cycle Length : {cycle_length} days",
472
+ f"Cycle Day : {cycle_day}",
473
+ "",
474
+ "Symptom Probabilities:",
475
+ f" Hot Flash : {hp:.4f} {'[ELEVATED RISK]' if hot_pred else '[LOW RISK]'}",
476
+ f" Mood Change : {mp:.4f} {'[ELEVATED RISK]' if mood_pred else '[LOW RISK]'}",
477
+ "",
478
+ "⚠️ For research/educational use only. Not a clinical diagnosis.",
479
+ "=" * 60,
480
+ ]
481
+ path.write_text("\n".join(lines), encoding="utf-8")
482
+
483
+
484
+ def _write_batch_symptom_report(
485
+ path: Path,
486
+ results: pd.DataFrame,
487
+ cycle_length: int,
488
+ run_dir: Path,
489
+ ):
490
+ total = len(results)
491
+ hot_flags = int(results["hotflash_pred"].sum()) \
492
+ if "hotflash_pred" in results.columns else 0
493
+ mood_flags = int(results["mood_pred"].sum()) \
494
+ if "mood_pred" in results.columns else 0
495
+ mean_hot = float(results["hotflash_prob"].mean()) \
496
+ if "hotflash_prob" in results.columns else 0.0
497
+ mean_mood = float(results["mood_prob"].mean()) \
498
+ if "mood_prob" in results.columns else 0.0
499
+ lines = [
500
+ "=" * 60,
501
+ "SWAN BATCH SYMPTOM FORECAST REPORT",
502
+ f"Generated : {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
503
+ f"Cycle Length : {cycle_length} days",
504
+ "=" * 60,
505
+ "",
506
+ f"Total Individuals : {total}",
507
+ f"Hot Flash Risk : {hot_flags}/{total} elevated",
508
+ f"Mood Change Risk : {mood_flags}/{total} elevated",
509
+ f"Avg Hot Flash Prob : {mean_hot:.4f}",
510
+ f"Avg Mood Prob : {mean_mood:.4f}",
511
+ "",
512
+ f"Output Directory : {run_dir}",
513
+ "",
514
+ "⚠️ For research/educational use only. Not a clinical diagnosis.",
515
+ "=" * 60,
516
+ ]
517
+ path.write_text("\n".join(lines), encoding="utf-8")
518
+
519
+
520
+ # ── Core prediction functions ─────────────────────────────────────────────────
521
+
522
def predict_single_stage(
    age, race, langint,
    hot_flash, num_hot_flash, bothersome_hf,
    sleep_quality, depression_indicator, mood_change, irritability,
    pain_indicator, abbleed, vaginal_dryness, lmp_day,
    model_choice,
):
    """
    Single-person stage prediction.

    Maps the UI inputs onto the model's canonical SWAN feature names,
    predicts with the chosen pipeline, persists chart/CSV/TXT artifacts
    into a fresh timestamped run directory, and builds the result HTML.

    Returns (stage_html, chart_fig, conf_note, compare_html, csv_download_path).
    On any prediction error, returns an error string with None/empty fillers
    in the remaining slots so the Gradio outputs stay well-formed.
    """
    # Models failed to load at import time — surface the stored message.
    if not _MODEL_OK:
        return f"⚠️ {_MODEL_MSG}", None, "Models unavailable.", "", None

    # Build feature dict using the model's canonical feature names
    def _v(x):
        # Unset UI inputs become NaN so the pipeline's imputer handles them.
        return float(x) if x is not None else np.nan

    feature_dict = {
        "AGE7": _v(age),
        "RACE": _v(race),
        "LANGINT7": _v(langint),
        "HOTFLAS7": _v(hot_flash),
        "NUMHOTF7": _v(num_hot_flash),
        "BOTHOTF7": _v(bothersome_hf),
        "SLEEPQL7": _v(sleep_quality),
        "DEPRESS7": _v(depression_indicator),
        "MOODCHG7": _v(mood_change),
        "IRRITAB7": _v(irritability),
        "PAIN17": _v(pain_indicator),
        "ABBLEED7": _v(abbleed),  # ← correct feature name (was ABLEED7)
        "VAGINDR7": _v(vaginal_dryness),
        # NOTE(review): a lmp_day of 0 is treated as missing here (falsy) —
        # confirm day 0 is never a legitimate value.
        "LMPDAY7": _v(lmp_day) if lmp_day else np.nan,
    }

    try:
        result = _forecast.predict_single(feature_dict, model=model_choice, return_proba=True)
        stage = result["stage"]
        # `or 0.0` also coerces an explicit 0/None confidence to 0.0.
        confidence = result.get("confidence") or 0.0
        proba = result.get("probabilities") or {}

        # ── Create timestamped run directory ──────────────────────────────────
        run_dir = _make_run_dir()

        # ── Save probability chart (PNG) ──────────────────────────────────────
        chart_path = run_dir / "charts" / "stage_probabilities.png"
        chart_fig = _make_proba_chart(proba, stage, save_path=chart_path) if proba else None

        # ── Save prediction CSV ───────────────────────────────────────────────
        pred_row = {
            "predicted_stage": stage,
            "model": model_choice,
            "confidence": round(confidence, 4),
            **{f"prob_{k}": round(v, 4) for k, v in proba.items()},
            "timestamp": datetime.now().isoformat(),
        }
        csv_path = run_dir / "predictions" / "stage_prediction.csv"
        pd.DataFrame([pred_row]).to_csv(csv_path, index=False)

        # ── Model comparison ──────────────────────────────────────────────────
        # Runs BOTH pipelines regardless of model_choice so the UI can show
        # whether they agree.
        comparison = _forecast.compare_models(feature_dict)
        rf_stage = comparison["RandomForest"]["stage"]
        lr_stage = comparison["LogisticRegression"]["stage"]
        agree = rf_stage == lr_stage

        # ── Save text report ──────────────────────────────────────────────────
        txt_path = run_dir / "reports" / "prediction_summary.txt"
        _write_single_stage_report(
            txt_path, stage, confidence, proba,
            model_choice, comparison, feature_dict,
        )

        # ── Build result card HTML ────────────────────────────────────────────
        info = STAGE_INFO.get(stage, {})
        emoji = STAGE_EMOJI.get(stage, "⚪")
        color = STAGE_COLORS.get(stage, "#607d8b")
        conf_color = _confidence_color(confidence)

        # One pill-style <span> per known symptom for the predicted stage.
        symptom_tags = "".join(
            f'<span style="background:{color}14;color:{color};padding:4px 10px;'
            f'border-radius:20px;border:1px solid {color}44;font-size:12px;'
            f'font-weight:500">{s}</span>'
            for s in info.get("symptoms", [])
        )

        stage_html = f"""
        <div class="result-card" style="border-left:4px solid {color}">
          <div style="display:flex;align-items:center;gap:12px;margin-bottom:16px;flex-wrap:wrap">
            <span style="font-size:40px;flex-shrink:0">{emoji}</span>
            <div style="flex:1;min-width:140px">
              <div style="color:#6b7280;font-size:12px;text-transform:uppercase;letter-spacing:2px">
                Predicted Stage
              </div>
              <div style="color:{color};font-size:26px;font-weight:700">
                {STAGE_LABELS.get(stage, stage)}
              </div>
            </div>
            <div style="text-align:right;flex-shrink:0">
              <div style="color:#6b7280;font-size:11px">Confidence</div>
              <div style="color:{conf_color};font-size:28px;font-weight:700">
                {confidence:.0%}
              </div>
            </div>
          </div>
          <hr style="border:none;border-top:1px solid #e2e8f0;margin:12px 0">
          <p style="color:#374151;font-size:14px;margin:8px 0">
            {info.get('description', '')}
          </p>
          <div style="margin-top:12px">
            <div style="color:#6b7280;font-size:11px;text-transform:uppercase;
                        letter-spacing:1px;margin-bottom:6px">Common Symptoms</div>
            <div style="display:flex;flex-wrap:wrap;gap:6px">{symptom_tags}</div>
          </div>
          <div style="background:{color}0d;border-left:3px solid {color};
                      padding:10px 14px;margin-top:14px;border-radius:0 8px 8px 0">
            <span style="color:{color};font-size:12px;font-weight:600">💡 Guidance: </span>
            <span style="color:#374151;font-size:13px">{info.get('guidance', '')}</span>
          </div>
          <div style="color:#9ca3af;font-size:11px;margin-top:12px">
            Model: {model_choice} · {datetime.now().strftime('%Y-%m-%d %H:%M')}
          </div>
        </div>
        """

        # Confidence note — thresholds match _confidence_color (0.8 / 0.6).
        if confidence >= 0.8:
            conf_note = "✅ High confidence — the model is quite certain about this stage."
        elif confidence >= 0.6:
            conf_note = ("⚠️ Moderate confidence — consider providing more feature values "
                         "or consulting a clinician.")
        else:
            conf_note = ("🔴 Low confidence — prediction is uncertain; "
                         "clinical consultation is strongly recommended.")

        # Model comparison panel + run-dir info
        compare_html = f"""
        <div class="result-card" style="margin-top:0">
          <div style="color:#6b7280;font-size:11px;text-transform:uppercase;
                      letter-spacing:1px;margin-bottom:10px;font-weight:600">
            Model Comparison
          </div>
          <div class="stat-grid-2">
            <div class="stat-item" style="border-top:3px solid #16a34a">
              <div style="color:#16a34a;font-size:11px;font-weight:600">Random Forest</div>
              <div style="color:#111827;font-size:17px;margin-top:4px">
                {STAGE_EMOJI.get(rf_stage,'')} {STAGE_LABELS.get(rf_stage, rf_stage)}
              </div>
              <div style="color:#6b7280;font-size:12px">
                {comparison['RandomForest'].get('confidence', 0):.0%} confidence
              </div>
            </div>
            <div class="stat-item" style="border-top:3px solid #2563eb">
              <div style="color:#2563eb;font-size:11px;font-weight:600">
                Logistic Regression
              </div>
              <div style="color:#111827;font-size:17px;margin-top:4px">
                {STAGE_EMOJI.get(lr_stage,'')} {STAGE_LABELS.get(lr_stage, lr_stage)}
              </div>
              <div style="color:#6b7280;font-size:12px">
                {comparison['LogisticRegression'].get('confidence', 0):.0%} confidence
              </div>
            </div>
          </div>
          <div style="margin-top:10px;padding:8px;border-radius:8px;
                      background:{'#d1fae5' if agree else '#fef2f2'};
                      color:{'#065f46' if agree else '#9f1239'};
                      font-size:13px;text-align:center;font-weight:500">
            {"✅ Both models agree — prediction is robust"
             if agree else
             "⚠️ Models disagree — interpret with caution"}
          </div>
          <div class="output-path-box">
            <div class="output-path-title">📁 Outputs saved to:</div>
            <div class="output-path-dir">{run_dir}/</div>
            <div class="output-path-files">
              charts/stage_probabilities.png<br>
              predictions/stage_prediction.csv<br>
              reports/prediction_summary.txt
            </div>
          </div>
        </div>
        """

        return stage_html, chart_fig, conf_note, compare_html, str(csv_path)

    except Exception as exc:
        # Broad catch is deliberate here: any failure must degrade to an
        # error card rather than crash the Gradio callback.
        return f"❌ Prediction error: {exc}", None, "", "", None
710
+
711
+
712
def predict_batch_stage(file, model_choice):
    """
    Batch stage prediction from uploaded CSV.

    Parameters
    ----------
    file : gradio file object (or path-like accepted by ``_get_file_path``)
        Uploaded CSV; one row per individual, columns drawn from the
        training feature set.
    model_choice : str
        Model key passed through to ``_forecast.predict_batch``
        (e.g. "RandomForest" or "LogisticRegression").

    Returns (csv_download_path, summary_html, preview_df).
    On any failure the first element is None and the second is an
    error/warning message string.
    """
    # Models must be loaded at import time; _MODEL_MSG carries the reason
    # when they are not.
    if not _MODEL_OK:
        return None, f"⚠️ {_MODEL_MSG}", None

    if file is None:
        return None, "Please upload a CSV file.", None

    file_path = _get_file_path(file)
    try:
        df = pd.read_csv(file_path)
    except Exception as exc:
        return None, f"Could not read CSV: {exc}", None

    if df.empty:
        return None, "Uploaded CSV is empty.", None

    # Identify ID column — first match from a list of common spellings wins.
    id_col_candidates = ["individual", "Individual", "ID", "id",
                         "SWANID", "subject", "Subject"]
    id_col = next((c for c in id_col_candidates if c in df.columns), None)

    # Validate features: compare uploaded columns against the training
    # feature list stored in the model metadata.
    feature_names = _metadata.get("feature_names", [])
    matching = [c for c in df.columns if c in feature_names]
    # max(..., 1) guards against division by zero when metadata is empty.
    missing_pct = 1 - len(matching) / max(len(feature_names), 1)

    warnings_list = []
    if not matching:
        return None, (
            "❌ No matching feature columns found. "
            "Please include columns from the training feature set "
            "(see 'Feature Reference' tab)."
        ), None
    if missing_pct > 0.5:
        # Soft warning: prediction still runs with >50% features missing.
        warnings_list.append(
            f"⚠️ {missing_pct:.0%} of training features are missing — "
            "prediction accuracy may be reduced."
        )

    try:
        results = _forecast.predict_batch(df, model=model_choice, return_proba=True)

        # Insert individual ID as the first column; synthesize Row_N labels
        # when no ID column was found.
        # NOTE(review): assumes predict_batch preserves row count/order of
        # df — confirm against the forecaster implementation.
        if id_col:
            results.insert(0, "individual", df[id_col].values)
        else:
            results.insert(0, "individual",
                           [f"Row_{i+1}" for i in range(len(results))])

        results["model"] = model_choice
        results["notes"] = ""
        # Flag rows under the 0.6 confidence threshold for manual review.
        if "confidence" in results.columns:
            low_mask = results["confidence"] < 0.6
            results.loc[low_mask, "notes"] = "Low confidence — review manually"

        # ── Create timestamped run directory ──────────────────────────────
        # (presumably _make_run_dir also creates the predictions/charts/
        # reports subdirectories — verify, the writes below rely on it)
        run_dir = _make_run_dir()

        # ── Save predictions CSV ──────────────────────────────────────────
        csv_path = run_dir / "predictions" / "batch_stage_predictions.csv"
        results.to_csv(csv_path, index=False)

        # ── Save confidence/distribution chart (PNG) ──────────────────────
        chart_path = run_dir / "charts" / "batch_summary_chart.png"
        _make_batch_summary_chart(results, save_path=chart_path)

        # ── Save text report ──────────────────────────────────────────────
        txt_path = run_dir / "reports" / "batch_summary.txt"
        _write_batch_report(txt_path, results, model_choice, run_dir)

        # ── Build summary HTML ────────────────────────────────────────────
        total = len(results)
        dist = results["predicted_stage"].value_counts().to_dict()
        mean_conf = results["confidence"].mean() \
            if "confidence" in results.columns else 0.0
        high_conf = int((results["confidence"] > 0.8).sum()) \
            if "confidence" in results.columns else 0

        # One progress-bar row per stage, in fixed pre/peri/post order.
        dist_bars = ""
        for stage in ["pre", "peri", "post"]:
            count = dist.get(stage, 0)
            pct = count / total * 100
            dist_bars += f"""
            <div style="margin:6px 0">
              <div style="display:flex;justify-content:space-between;margin-bottom:2px">
                <span style="color:#374151;font-size:13px">
                  {STAGE_EMOJI.get(stage,'')} {STAGE_LABELS.get(stage, stage)}
                </span>
                <span style="color:#6b7280;font-size:12px">{count} ({pct:.0f}%)</span>
              </div>
              <div style="background:#e2e8f0;border-radius:4px;height:8px">
                <div style="background:{STAGE_COLORS.get(stage,'#6b7280')};
                            width:{pct}%;height:8px;border-radius:4px"></div>
              </div>
            </div>"""

        warn_html = "".join(
            f'<div style="color:#d97706;font-size:12px;margin-top:4px">{w}</div>'
            for w in warnings_list
        )

        summary_html = f"""
        <div class="result-card">
          <div style="color:#111827;font-size:16px;font-weight:700;margin-bottom:14px">
            📊 Batch Results — {total} individuals
          </div>
          {warn_html}
          <div class="stat-grid-3">
            <div class="stat-item">
              <div class="stat-label">Total</div>
              <div class="stat-value">{total}</div>
            </div>
            <div class="stat-item">
              <div class="stat-label">Avg Confidence</div>
              <div class="stat-value" style="color:{_confidence_color(mean_conf)}">
                {mean_conf:.0%}
              </div>
            </div>
            <div class="stat-item">
              <div class="stat-label">High Conf (&gt;80%)</div>
              <div class="stat-value" style="color:#16a34a">{high_conf}/{total}</div>
            </div>
          </div>
          <div style="margin-top:12px">{dist_bars}</div>
          <div class="output-path-box">
            <div class="output-path-title">📁 Outputs saved to:</div>
            <div class="output-path-dir">{run_dir}/</div>
            <div class="output-path-files">
              predictions/batch_stage_predictions.csv<br>
              charts/batch_summary_chart.png<br>
              reports/batch_summary.txt
            </div>
          </div>
        </div>
        """

        # Preview is capped at 20 rows for the UI table.
        return str(csv_path), summary_html, results.head(20)

    except Exception as exc:
        # Broad catch is deliberate: any failure in prediction or file
        # writing surfaces as a user-visible message, never a traceback.
        return None, f"❌ Batch prediction error: {exc}", None
857
+
858
+
859
def predict_symptoms(individual_id, lmp_input, target_date_input, cycle_length):
    """
    Cycle-based symptom forecasting (single person).

    Parameters
    ----------
    individual_id : str or None
        Optional display label; falls back to "N/A"/"Forecast" in outputs.
    lmp_input : str
        Last Menstrual Period date (required; format expected by
        ``SymptomCycleForecaster.predict_single`` — verify against that class).
    target_date_input : str or None
        Date to forecast for; None means "today".
    cycle_length : int-like or falsy
        Cycle length in days; defaults to 28 when empty.

    Returns (result_html, chart_fig, csv_download_path).
    On error returns (error_message, None, None).
    """
    if not lmp_input:
        return "Please enter your Last Menstrual Period date.", None, None

    try:
        cycle_length = int(cycle_length) if cycle_length else 28
        fore = SymptomCycleForecaster(cycle_length=cycle_length)
        target_date = target_date_input if target_date_input else None
        result = fore.predict_single(lmp=lmp_input, target_date=target_date)

        cycle_day = result.get("cycle_day")
        hot_prob = result.get("hotflash_prob", 0)
        hot_pred = result.get("hotflash_pred", False)
        mood_prob = result.get("mood_prob", 0)
        mood_pred = result.get("mood_pred", False)

        # Safe float helpers: coerce to float, treating None/NaN as 0.0 so
        # the percentage bars and CSV never receive NaN.
        hp = float(hot_prob) if (hot_prob is not None and not np.isnan(hot_prob)) else 0.0
        mp = float(mood_prob) if (mood_prob is not None and not np.isnan(mood_prob)) else 0.0

        # ── Create timestamped run directory ──────────────────────────────
        run_dir = _make_run_dir()

        # ── Save cycle chart (PNG) ────────────────────────────────────────
        chart_path = run_dir / "charts" / "cycle_position.png"
        chart_fig = _make_cycle_chart(
            cycle_day, cycle_length, hp, mp, save_path=chart_path
        )

        # ── Save forecast CSV (single-row) ────────────────────────────────
        csv_path = run_dir / "predictions" / "symptom_forecast.csv"
        pd.DataFrame([{
            "individual": individual_id or "N/A",
            "LMP": lmp_input,
            "date": target_date_input or datetime.now().strftime("%Y-%m-%d"),
            "cycle_day": cycle_day,
            "hotflash_prob": round(hp, 6),
            "hotflash_pred": bool(hot_pred),
            "mood_prob": round(mp, 6),
            "mood_pred": bool(mood_pred),
        }]).to_csv(csv_path, index=False)

        # ── Save text report ──────────────────────────────────────────────
        txt_path = run_dir / "reports" / "symptom_summary.txt"
        _write_symptom_report(
            txt_path, individual_id, lmp_input, target_date_input,
            cycle_day, cycle_length, hp, hot_pred, mp, mood_pred,
        )

        # ── Build result HTML ─────────────────────────────────────────────
        def _prob_bar(prob, label, color):
            # Render one labelled horizontal probability bar; width is
            # clamped to 100% in case prob > 1.
            pct = min(prob * 100, 100)
            return f"""
            <div style="margin:10px 0">
              <div style="display:flex;justify-content:space-between;margin-bottom:4px">
                <span style="color:#374151;font-size:14px">{label}</span>
                <span style="color:{color};font-size:16px;font-weight:700">{pct:.0f}%</span>
              </div>
              <div style="background:#e2e8f0;border-radius:6px;height:10px">
                <div style="background:{color};width:{pct}%;height:10px;
                            border-radius:6px;transition:width 0.5s"></div>
              </div>
            </div>"""

        hot_alert = "🔴 Elevated risk" if hot_pred else "🟢 Low risk"
        mood_alert = "🔴 Elevated risk" if mood_pred else "🟢 Low risk"

        html = f"""
        <div class="result-card">
          <div style="color:#111827;font-size:18px;font-weight:700;margin-bottom:4px">
            {individual_id or 'Forecast'} — Cycle Day {cycle_day or '?'}
          </div>
          <div style="color:#6b7280;font-size:13px;margin-bottom:20px">
            LMP: {lmp_input} | Target: {target_date_input or 'Today'}
            | Cycle: {cycle_length} days
          </div>
          {_prob_bar(hp, '🔥 Hot Flash Probability', '#ef4444')}
          <div style="color:#6b7280;font-size:12px;margin:-6px 0 10px 2px">{hot_alert}</div>
          {_prob_bar(mp, '😤 Mood Change Probability', '#7c3aed')}
          <div style="color:#6b7280;font-size:12px;margin:-6px 0 10px 2px">{mood_alert}</div>
          <div style="background:#f8fafc;border:1px solid #e2e8f0;border-radius:8px;
                      padding:12px;margin-top:14px;font-size:12px;color:#6b7280">
            ℹ️ Probabilities are computed from a cycle-phase model (Gaussian heuristic).
            They represent symptom likelihood based on cycle day, not a clinical diagnosis.
          </div>
          <div class="output-path-box">
            <div class="output-path-title">📁 Outputs saved to:</div>
            <div class="output-path-dir">{run_dir}/</div>
            <div class="output-path-files">
              charts/cycle_position.png<br>
              predictions/symptom_forecast.csv<br>
              reports/symptom_summary.txt
            </div>
          </div>
        </div>
        """

        return html, chart_fig, str(csv_path)

    except Exception as exc:
        # Deliberate catch-all: surface any failure as a user-visible string.
        return f"❌ Error: {exc}", None, None
965
+
966
+
967
def predict_symptoms_batch(file, lmp_col_name, date_col_name, cycle_length):
    """
    Batch symptom forecasting from CSV.

    Parameters
    ----------
    file : gradio file object (or path-like accepted by ``_get_file_path``)
        Uploaded CSV with one row per individual.
    lmp_col_name : str
        Name of the column holding Last Menstrual Period dates (required).
    date_col_name : str or None
        Optional column with per-row target dates; ignored when absent
        from the CSV.
    cycle_length : int-like or falsy
        Cycle length in days; defaults to 28 when empty.

    Returns (csv_download_path, summary_html, preview_df).
    On any failure the first element is None and the second is an
    error message string.
    """
    if file is None:
        return None, "Please upload a CSV file.", None

    file_path = _get_file_path(file)
    try:
        df = pd.read_csv(file_path)
    except Exception as exc:
        return None, f"Could not read CSV: {exc}", None

    # Guard against an empty upload (consistent with predict_batch_stage);
    # without this the summary would render NaN averages for 0 rows.
    if df.empty:
        return None, "Uploaded CSV is empty.", None

    if lmp_col_name not in df.columns:
        return None, (
            f"LMP column '{lmp_col_name}' not found in CSV. "
            f"Columns present: {list(df.columns)}"
        ), None

    try:
        cycle_length = int(cycle_length) if cycle_length else 28
        fore = SymptomCycleForecaster(cycle_length=cycle_length)
        # Only pass the date column through when it actually exists.
        date_col = date_col_name \
            if (date_col_name and date_col_name in df.columns) else None
        results = fore.predict_df(df, lmp_col=lmp_col_name, date_col=date_col)

        # ── Create timestamped run directory ──────────────────────────────
        run_dir = _make_run_dir()

        # ── Save predictions CSV ──────────────────────────────────────────
        csv_path = run_dir / "predictions" / "batch_symptom_forecast.csv"
        results.to_csv(csv_path, index=False)

        # ── Save text report ──────────────────────────────────────────────
        txt_path = run_dir / "reports" / "batch_symptom_summary.txt"
        _write_batch_symptom_report(txt_path, results, cycle_length, run_dir)

        # ── Build summary HTML ────────────────────────────────────────────
        # Each stat tolerates a missing column by falling back to zero.
        total = len(results)
        hot_flags = int(results["hotflash_pred"].sum()) \
            if "hotflash_pred" in results.columns else 0
        mood_flags = int(results["mood_pred"].sum()) \
            if "mood_pred" in results.columns else 0
        mean_hot = float(results["hotflash_prob"].mean()) \
            if "hotflash_prob" in results.columns else 0.0
        mean_mood = float(results["mood_prob"].mean()) \
            if "mood_prob" in results.columns else 0.0

        summary_html = f"""
        <div class="result-card">
          <div style="color:#111827;font-size:16px;font-weight:700;margin-bottom:14px">
            🌊 Symptom Forecast — {total} individuals
          </div>
          <div class="stat-grid-3">
            <div class="stat-item">
              <div class="stat-label">Total</div>
              <div class="stat-value">{total}</div>
            </div>
            <div class="stat-item">
              <div class="stat-label">🔥 Hot Flash Risk</div>
              <div class="stat-value" style="color:#ef4444">{hot_flags}</div>
            </div>
            <div class="stat-item">
              <div class="stat-label">😤 Mood Risk</div>
              <div class="stat-value" style="color:#7c3aed">{mood_flags}</div>
            </div>
          </div>
          <div class="stat-grid-2">
            <div class="stat-item">
              <div class="stat-label">Avg Hot Flash Prob</div>
              <div class="stat-value" style="color:#ef4444;font-size:18px">
                {mean_hot:.1%}
              </div>
            </div>
            <div class="stat-item">
              <div class="stat-label">Avg Mood Prob</div>
              <div class="stat-value" style="color:#7c3aed;font-size:18px">
                {mean_mood:.1%}
              </div>
            </div>
          </div>
          <div class="output-path-box">
            <div class="output-path-title">📁 Outputs saved to:</div>
            <div class="output-path-dir">{run_dir}/</div>
            <div class="output-path-files">
              predictions/batch_symptom_forecast.csv<br>
              reports/batch_symptom_summary.txt
            </div>
          </div>
        </div>
        """

        return str(csv_path), summary_html, results

    except Exception as exc:
        # Deliberate catch-all: surface any failure as a user-visible string.
        return None, f"❌ Error: {exc}", None
1065
+
1066
+
1067
+ # ── Feature reference & model status ─────────────────────────────────────────
1068
+
1069
def get_feature_reference() -> str:
    """Render the training-feature reference table as an HTML fragment.

    Shows at most the first 60 features from the model metadata (falling
    back to the keys of FEATURE_DESCRIPTIONS) plus a trailing row noting
    how many one-hot encoded features were omitted.
    """
    feature_names = _metadata.get("feature_names", list(FEATURE_DESCRIPTIONS.keys()))

    # Build each <tr> as a list entry and join once at the end.
    row_chunks = []
    for num, feat in enumerate(feature_names[:60], start=1):
        description = FEATURE_DESCRIPTIONS.get(feat, feat.split("_")[0])
        row_chunks.append(f"""
        <tr>
          <td class="feature-num">{num}</td>
          <td class="feature-code">{feat}</td>
          <td class="feature-desc">{description}</td>
        </tr>""")

    # Trailing note for anything beyond the 60-row display cap.
    remaining = len(feature_names) - 60
    if remaining > 0:
        row_chunks.append(f"""
        <tr>
          <td colspan="3" style="padding:8px;color:#9ca3af;font-size:12px;text-align:center">
            … and {remaining} more features (one-hot encoded categories)
          </td>
        </tr>""")

    rows = "".join(row_chunks)

    return f"""
    <div class="feature-table-wrap">
      <div style="color:#111827;font-size:16px;font-weight:700;margin-bottom:14px">
        📋 Training Features ({len(feature_names)} total after encoding)
      </div>
      <table>
        <thead>
          <tr>
            <th>#</th>
            <th>Feature</th>
            <th>Description</th>
          </tr>
        </thead>
        <tbody>{rows}</tbody>
      </table>
    </div>
    """
1108
+
1109
+
1110
def get_model_status() -> str:
    """Render the model-status card as an HTML fragment.

    Returns a "loaded" card with feature/stage counts when _MODEL_OK is
    True, otherwise a "not loaded" card containing _MODEL_MSG plus the
    training instructions for regenerating the pickled pipelines.
    """
    if _MODEL_OK:
        fc = len(_metadata.get("feature_names", []))
        sc = _metadata.get("stage_classes", ["pre", "peri", "post"])
        # One colored pill badge per stage class; hex suffixes ("18", "44")
        # add alpha to the base stage color.
        badges = "".join(
            f'<span style="background:{STAGE_COLORS.get(s,"#607d8b")}18;'
            f'color:{STAGE_COLORS.get(s,"#555")};padding:4px 12px;'
            f'border-radius:20px;border:1px solid {STAGE_COLORS.get(s,"#607d8b")}44;'
            f'font-size:13px;font-weight:600">{STAGE_EMOJI.get(s,"")} {s}</span>'
            for s in sc
        )
        return f"""
        <div class="status-card">
          <div style="display:flex;align-items:center;gap:10px;margin-bottom:14px">
            <span style="font-size:24px">✅</span>
            <div>
              <div style="color:#059669;font-size:16px;font-weight:700">
                Models Loaded Successfully
              </div>
              <div style="color:#6b7280;font-size:12px">Ready for predictions</div>
            </div>
          </div>
          <div class="stat-grid-3">
            <div class="stat-item">
              <div class="stat-label">Features</div>
              <div class="stat-value">{fc}</div>
            </div>
            <div class="stat-item">
              <div class="stat-label">Models</div>
              <div class="stat-value">2</div>
            </div>
            <div class="stat-item">
              <div class="stat-label">Stages</div>
              <div class="stat-value">{len(sc)}</div>
            </div>
          </div>
          <div style="margin-top:14px">
            <div style="color:#6b7280;font-size:11px;text-transform:uppercase;
                        letter-spacing:0.5px;margin-bottom:6px">Available Stages</div>
            <div style="display:flex;gap:8px;flex-wrap:wrap">{badges}</div>
          </div>
        </div>
        """
    # Fallback card: models missing — show the reason and setup steps.
    return f"""
    <div class="status-card">
      <div style="display:flex;align-items:center;gap:10px;margin-bottom:10px">
        <span style="font-size:24px">⚠️</span>
        <div>
          <div style="color:#dc2626;font-size:16px;font-weight:700">
            Models Not Loaded
          </div>
          <div style="color:#6b7280;font-size:12px">{_MODEL_MSG}</div>
        </div>
      </div>
      <div style="background:#fef2f2;border:1px solid #fecaca;border-radius:8px;
                  padding:12px;color:#9f1239;font-size:13px">
        To train and save models:<br>
        <code style="background:#1e293b;color:#a3e635;padding:4px 8px;border-radius:4px;
                     margin-top:6px;display:inline-block">python menopause.py</code>
        <br><br>
        This generates <code style="background:#e2e8f0;padding:2px 5px;border-radius:3px;
        color:#1e293b">swan_ml_output/rf_pipeline.pkl</code>,
        <code style="background:#e2e8f0;padding:2px 5px;border-radius:3px;
        color:#1e293b">lr_pipeline.pkl</code>, and
        <code style="background:#e2e8f0;padding:2px 5px;border-radius:3px;
        color:#1e293b">forecast_metadata.json</code>.
      </div>
    </div>
    """
1179
+
1180
+
1181
+ # ── Education content ─────────────────────────────────────────────────────────
1182
+ EDUCATION_HTML = """
1183
+ <div class="edu-card">
1184
+ <h2>🌸 Understanding Menopause</h2>
1185
+ <p>Menopause is a natural biological process marking the end of menstrual cycles.
1186
+ It is officially diagnosed after 12 consecutive months without a menstrual period
1187
+ and typically occurs in women in their late 40s to early 50s.</p>
1188
+
1189
+ <h3>Three Stages</h3>
1190
+ <div class="stage-cards-grid">
1191
+ <div class="stage-card-pre">
1192
+ <div style="color:#16a34a;font-weight:700;margin-bottom:8px">🟢 Pre-Menopause</div>
1193
+ <p style="font-size:13px;margin:0;color:#374151">Regular ovarian function. Periods are predictable.
1194
+ Hormones (estrogen, progesterone) follow a consistent monthly pattern.</p>
1195
+ </div>
1196
+ <div class="stage-card-peri">
1197
+ <div style="color:#d97706;font-weight:700;margin-bottom:8px">🟡 Peri-Menopause</div>
1198
+ <p style="font-size:13px;margin:0;color:#374151">Transition phase — usually begins in the mid-40s.
1199
+ Hormone levels fluctuate. Periods become irregular.
1200
+ Hot flashes and sleep issues may begin.</p>
1201
+ </div>
1202
+ <div class="stage-card-post">
1203
+ <div style="color:#7c3aed;font-weight:700;margin-bottom:8px">🟣 Post-Menopause</div>
1204
+ <p style="font-size:13px;margin:0;color:#374151">12+ months after the last period.
1205
+ Lower estrogen levels. Risk factors for osteoporosis and
1206
+ cardiovascular disease increase.</p>
1207
+ </div>
1208
+ </div>
1209
+
1210
+ <h3>Common Symptoms by Stage</h3>
1211
+ <table style="width:100%;border-collapse:collapse;font-size:13px">
1212
+ <thead>
1213
+ <tr style="background:#f8fafc">
1214
+ <th style="padding:8px;text-align:left;color:#6b7280;font-weight:600">Symptom</th>
1215
+ <th style="padding:8px;text-align:center;color:#16a34a;font-weight:600">Pre</th>
1216
+ <th style="padding:8px;text-align:center;color:#d97706;font-weight:600">Peri</th>
1217
+ <th style="padding:8px;text-align:center;color:#7c3aed;font-weight:600">Post</th>
1218
+ </tr>
1219
+ </thead>
1220
+ <tbody>
1221
+ <tr style="border-bottom:1px solid #e2e8f0">
1222
+ <td style="padding:8px;color:#374151">Hot flashes</td>
1223
+ <td style="text-align:center;color:#9ca3af">–</td>
1224
+ <td style="text-align:center">✅</td>
1225
+ <td style="text-align:center">✅</td>
1226
+ </tr>
1227
+ <tr style="border-bottom:1px solid #e2e8f0">
1228
+ <td style="padding:8px;color:#374151">Irregular periods</td>
1229
+ <td style="text-align:center;color:#9ca3af">–</td>
1230
+ <td style="text-align:center">✅</td>
1231
+ <td style="text-align:center;color:#9ca3af">N/A</td>
1232
+ </tr>
1233
+ <tr style="border-bottom:1px solid #e2e8f0">
1234
+ <td style="padding:8px;color:#374151">Sleep disturbances</td>
1235
+ <td style="text-align:center;color:#6b7280">Mild</td>
1236
+ <td style="text-align:center">✅</td>
1237
+ <td style="text-align:center">✅</td>
1238
+ </tr>
1239
+ <tr style="border-bottom:1px solid #e2e8f0">
1240
+ <td style="padding:8px;color:#374151">Mood changes</td>
1241
+ <td style="text-align:center;color:#6b7280">PMS</td>
1242
+ <td style="text-align:center">✅</td>
1243
+ <td style="text-align:center;color:#6b7280">Possible</td>
1244
+ </tr>
1245
+ <tr style="border-bottom:1px solid #e2e8f0">
1246
+ <td style="padding:8px;color:#374151">Vaginal dryness</td>
1247
+ <td style="text-align:center;color:#9ca3af">–</td>
1248
+ <td style="text-align:center;color:#6b7280">Possible</td>
1249
+ <td style="text-align:center">✅</td>
1250
+ </tr>
1251
+ <tr>
1252
+ <td style="padding:8px;color:#374151">Bone density changes</td>
1253
+ <td style="text-align:center;color:#9ca3af">–</td>
1254
+ <td style="text-align:center;color:#6b7280">Begins</td>
1255
+ <td style="text-align:center">✅</td>
1256
+ </tr>
1257
+ </tbody>
1258
+ </table>
1259
+
1260
+ <h3>About This Tool</h3>
1261
+ <p style="font-size:13px">This application uses machine learning models trained on the
1262
+ SWAN (Study of Women's Health Across the Nation) dataset — a landmark multisite,
1263
+ multiethnic longitudinal study. The models were trained on self-reported symptom and
1264
+ behavioral data to predict menopausal stage.</p>
1265
+ <div class="disclaimer-box">
1266
+ ⚠️ <strong style="color:#d97706">Disclaimer:</strong>
1267
+ This tool is for educational and research purposes only.
1268
+ Predictions should not substitute clinical diagnosis.
1269
+ Always consult a qualified healthcare provider for medical advice.
1270
+ </div>
1271
+ </div>
1272
+ """
1273
+
1274
+
1275
+ # ── Gradio UI ─────────────────────────────────────────────────────────────────
1276
+ CUSTOM_CSS = """
1277
+ /* ── Core ────────────────────────────────────────────────────────────── */
1278
+ .gradio-container {
1279
+ max-width: 1200px !important;
1280
+ margin: 0 auto !important;
1281
+ font-family: 'Segoe UI', system-ui, -apple-system, sans-serif !important;
1282
+ background: #f0f4f8 !important;
1283
+ }
1284
+
1285
+ /* ── Header banner ──────────────────────────────────────────────────── */
1286
+ .header-banner {
1287
+ background: linear-gradient(135deg, #faf5ff 0%, #fff0f9 50%, #eff6ff 100%);
1288
+ border: 1px solid #e9d5ff;
1289
+ border-radius: 16px;
1290
+ padding: 28px 32px;
1291
+ margin-bottom: 20px;
1292
+ box-shadow: 0 2px 8px rgba(139,92,246,0.08);
1293
+ position: relative;
1294
+ overflow: hidden;
1295
+ }
1296
+ .header-banner::before {
1297
+ content: '';
1298
+ position: absolute;
1299
+ top: -40%; right: -5%;
1300
+ width: 280px; height: 280px;
1301
+ background: radial-gradient(circle, rgba(139,92,246,0.08) 0%, transparent 70%);
1302
+ pointer-events: none;
1303
+ }
1304
+
1305
+ /* ── Reusable info boxes ─────────────────────────────────────────────── */
1306
+ .info-box {
1307
+ background: #f8fafc;
1308
+ border: 1px solid #e2e8f0;
1309
+ border-left: 3px solid #3b82f6;
1310
+ border-radius: 8px;
1311
+ padding: 12px 16px;
1312
+ color: #475569;
1313
+ font-size: 13px;
1314
+ margin-bottom: 16px;
1315
+ line-height: 1.5;
1316
+ }
1317
+ .info-box code {
1318
+ background: #e2e8f0;
1319
+ color: #1e293b;
1320
+ padding: 1px 5px;
1321
+ border-radius: 3px;
1322
+ font-family: monospace;
1323
+ font-size: 0.9em;
1324
+ }
1325
+ .section-label {
1326
+ color: #2563eb;
1327
+ font-size: 12px;
1328
+ font-weight: 700;
1329
+ text-transform: uppercase;
1330
+ letter-spacing: 0.6px;
1331
+ margin-bottom: 10px;
1332
+ margin-top: 10px;
1333
+ }
1334
+ .format-hint {
1335
+ background: #f8fafc;
1336
+ border: 1px solid #e2e8f0;
1337
+ border-radius: 8px;
1338
+ padding: 14px;
1339
+ margin-top: 10px;
1340
+ font-size: 12px;
1341
+ color: #475569;
1342
+ }
1343
+ .format-hint-title { color: #2563eb; font-weight: 600; margin-bottom: 6px; }
1344
+ .format-hint pre { color: #475569; margin: 0; font-size: 11px; white-space: pre-wrap; }
1345
+ .format-hint-note { color: #94a3b8; font-size: 11px; margin-top: 8px; }
1346
+ .placeholder-msg { color: #9ca3af; text-align: center; padding: 40px; font-size: 14px; }
1347
+ .section-divider { border: none; border-top: 1px solid #e2e8f0; margin: 24px 0; }
1348
+ .batch-section-label { color: #2563eb; font-size: 14px; font-weight: 600; margin-bottom: 12px; }
1349
+
1350
+ /* ── Result & summary cards ─────────────────────────────────────────── */
1351
+ .result-card {
1352
+ background: #ffffff;
1353
+ border: 1px solid #e2e8f0;
1354
+ border-radius: 16px;
1355
+ padding: 24px;
1356
+ box-shadow: 0 1px 4px rgba(0,0,0,0.06);
1357
+ font-family: 'Segoe UI', system-ui, sans-serif;
1358
+ }
1359
+ .stat-grid-3 { display:grid; grid-template-columns:repeat(3,1fr); gap:12px; margin:14px 0; }
1360
+ .stat-grid-2 { display:grid; grid-template-columns:1fr 1fr; gap:10px; margin-top:10px; }
1361
+ .stat-item { background:#f8fafc; border:1px solid #e2e8f0; padding:12px; border-radius:8px; text-align:center; }
1362
+ .stat-label { color:#6b7280; font-size:11px; text-transform:uppercase; letter-spacing:0.4px; }
1363
+ .stat-value { color:#111827; font-size:22px; font-weight:700; line-height:1.2; margin-top:2px; }
1364
+ .output-path-box { background:#f0fdf4; border:1px solid #bbf7d0; border-radius:8px; padding:10px 14px; margin-top:12px; font-family:monospace; }
1365
+ .output-path-title { color:#059669; font-size:12px; font-weight:600; }
1366
+ .output-path-dir { color:#065f46; font-size:11px; margin-top:4px; }
1367
+ .output-path-files { color:#6b7280; font-size:10px; margin-top:4px; line-height:1.6; }
1368
+
1369
+ /* ── Code blocks ────────────────────────────────────────────────────── */
1370
+ .code-block {
1371
+ background: #1e293b;
1372
+ color: #a3e635;
1373
+ border-radius: 8px;
1374
+ padding: 12px;
1375
+ font-size: 12px;
1376
+ font-family: monospace;
1377
+ white-space: pre;
1378
+ overflow-x: auto;
1379
+ }
1380
+
1381
+ /* ── Setup instructions card ─────────────────────────────────────────── */
1382
+ .setup-card { background:#ffffff; border:1px solid #e2e8f0; border-radius:12px; padding:20px; margin-top:16px; font-family:'Segoe UI',system-ui,sans-serif; }
1383
+ .setup-title { color:#111827; font-size:15px; font-weight:700; margin-bottom:12px; }
1384
+ .setup-step { color:#374151; font-size:13px; line-height:1.8; }
1385
+ .setup-step strong { color:#2563eb; }
1386
+
1387
+ /* ── Education ──────────────────────────────────────────────────────── */
1388
+ .edu-card { background:#ffffff; border:1px solid #e2e8f0; border-radius:16px; padding:28px; font-family:'Segoe UI',system-ui,sans-serif; color:#374151; line-height:1.7; }
1389
+ .edu-card h2 { color:#111827; font-size:22px; margin-top:0; }
1390
+ .edu-card h3 { color:#7c3aed; font-size:16px; margin-top:20px; }
1391
+ .stage-cards-grid { display:grid; grid-template-columns:repeat(3,1fr); gap:16px; margin:14px 0; }
1392
+ .stage-card-pre { background:#f0fdf4; border-top:4px solid #16a34a; padding:16px; border-radius:10px; }
1393
+ .stage-card-peri { background:#fffbeb; border-top:4px solid #d97706; padding:16px; border-radius:10px; }
1394
+ .stage-card-post { background:#faf5ff; border-top:4px solid #7c3aed; padding:16px; border-radius:10px; }
1395
+ .disclaimer-box { background:#fffbeb; border-left:3px solid #d97706; padding:12px 16px; border-radius:0 8px 8px 0; margin-top:14px; font-size:13px; color:#374151; }
1396
+
1397
+ /* ── Feature reference table ────────────────────────────────────────── */
1398
+ .feature-table-wrap { background:#ffffff; border:1px solid #e2e8f0; border-radius:12px; padding:20px; max-height:500px; overflow-y:auto; font-family:'Segoe UI',system-ui,sans-serif; }
1399
+ .feature-table-wrap table { width:100%; border-collapse:collapse; }
1400
+ .feature-table-wrap thead tr { background:#f8fafc; }
1401
+ .feature-table-wrap th { padding:8px; color:#6b7280; font-size:11px; text-align:left; text-transform:uppercase; letter-spacing:0.4px; }
1402
+ .feature-table-wrap tr { border-bottom:1px solid #e2e8f0; }
1403
+ .feature-table-wrap td { padding:8px; }
1404
+ .feature-code { color:#2563eb; font-family:monospace; font-size:13px; }
1405
+ .feature-desc { color:#374151; font-size:12px; }
1406
+ .feature-num { color:#9ca3af; font-size:12px; }
1407
+
1408
+ /* ── Model status card ──────────────────────────────────────────────── */
1409
+ .status-card { background:#ffffff; border:1px solid #e2e8f0; border-radius:12px; padding:20px; font-family:'Segoe UI',system-ui,sans-serif; }
1410
+
1411
+ /* ── Footer ─────────────────────────────────────────────────────────── */
1412
+ .app-footer { text-align:center; color:#9ca3af; font-size:11px; margin-top:24px; padding:16px; border-top:1px solid #e2e8f0; }
1413
+ .app-footer a { color:#2563eb; text-decoration:none; }
1414
+
1415
+ /* ── Responsive — Tablet (≤ 768 px) ────────────────────────────────── */
1416
+ @media (max-width: 768px) {
1417
+ .gradio-container { padding: 8px !important; }
1418
+ .header-banner { padding: 16px 20px !important; margin-bottom: 12px !important; }
1419
+ .header-status-badge { display: none !important; }
1420
+ .stat-grid-3 { grid-template-columns: 1fr !important; }
1421
+ .stat-grid-2 { grid-template-columns: 1fr !important; }
1422
+ .stage-cards-grid { grid-template-columns: 1fr !important; }
1423
+ }
1424
+
1425
+ /* ── Responsive — Mobile (≤ 480 px) ────────────────────────────────── */
1426
+ @media (max-width: 480px) {
1427
+ .header-banner h1 { font-size: 18px !important; }
1428
+ .result-card { padding: 16px !important; }
1429
+ .edu-card { padding: 16px !important; }
1430
+ .setup-card { padding: 14px !important; }
1431
+ }
1432
+ """
1433
+
1434
+ HEADER_HTML = """
1435
+ <div class="header-banner">
1436
+ <div style="display:flex;align-items:center;gap:16px;flex-wrap:wrap">
1437
+ <div style="font-size:48px;flex-shrink:0">🌸</div>
1438
+ <div style="flex:1;min-width:200px">
1439
+ <h1 style="margin:0;font-size:26px;font-weight:800;
1440
+ background:linear-gradient(135deg,#7c3aed,#db2777);
1441
+ -webkit-background-clip:text;-webkit-text-fill-color:transparent">
1442
+ SWAN Menopause Prediction
1443
+ </h1>
1444
+ <p style="margin:4px 0 0;color:#6b7280;font-size:13px">
1445
+ AI-powered menopausal stage prediction &amp; symptom forecasting ·
1446
+ Based on the SWAN dataset
1447
+ </p>
1448
+ </div>
1449
+ <div class="header-status-badge" style="text-align:right;flex-shrink:0">
1450
+ <div style="background:#ffffff;border:1px solid #e2e8f0;border-radius:8px;
1451
+ padding:8px 16px;display:inline-block;box-shadow:0 1px 3px rgba(0,0,0,0.06)">
1452
+ <div style="color:#9ca3af;font-size:10px;text-transform:uppercase;letter-spacing:1px">
1453
+ Status
1454
+ </div>
1455
+ <div style="color:{color};font-size:13px;font-weight:600">{status}</div>
1456
+ </div>
1457
+ </div>
1458
+ </div>
1459
+ </div>
1460
+ """.format(
1461
+ color = "#059669" if _MODEL_OK else "#dc2626",
1462
+ status = "Models Ready ✅" if _MODEL_OK else "Models Needed ⚠️",
1463
+ )
1464
+
1465
+
1466
+ # ── App builder ───────────────────────────────────────────────────────────────
1467
+
1468
def build_app():
    """Construct and return the Gradio Blocks UI for the SWAN app.

    Wires six tabs (single stage prediction, batch stage prediction,
    symptom forecast + batch forecast, education, feature reference,
    model status) to module-level callbacks. Assumes ``gr``,
    ``CUSTOM_CSS``, ``HEADER_HTML``, ``EDUCATION_HTML``,
    ``get_feature_reference``, ``get_model_status`` and the
    ``predict_*`` callbacks are defined earlier in this file
    (not visible in this chunk).

    Returns:
        gr.Blocks: the assembled (but not yet launched) application.
    """
    with gr.Blocks(
        css = CUSTOM_CSS,
        title = "SWAN Menopause Prediction",
        theme = gr.themes.Soft(primary_hue="blue", neutral_hue="slate"),
    ) as app:

        gr.HTML(HEADER_HTML)

        with gr.Tabs():

            # ── TAB 1: Single Stage Prediction ────────────────────────────────
            with gr.Tab("🔮 Stage Prediction"):
                gr.HTML("""
                <div class="info-box">
                Fill in the fields below to predict menopausal stage for a single individual.
                All fields are optional — the pipeline handles missing values automatically.
                A timestamped output folder is created in
                <code>swan_ml_output/</code> for every run.
                </div>""")

                with gr.Row():
                    # ── Input column ──────────────────────────────────────────
                    with gr.Column(scale=2):

                        with gr.Group():
                            gr.HTML('<div class="section-label">Demographics</div>')
                            with gr.Row():
                                age = gr.Slider(
                                    minimum=35, maximum=75, value=48, step=1,
                                    label="Age (AGE7)",
                                )
                                race = gr.Dropdown(
                                    choices=[1, 2, 3, 4, 5], value=1,
                                    label="Race (RACE)",
                                    info="1=White, 2=Black, 3=Chinese, 4=Japanese, 5=Hispanic",
                                )
                                langint = gr.Dropdown(
                                    choices=[1, 2, 3], value=1,
                                    label="Interview Language (LANGINT7)",
                                    info="1=English, 2=Spanish, 3=Other",
                                )

                        with gr.Group():
                            gr.HTML('<div class="section-label">Vasomotor Symptoms</div>')
                            with gr.Row():
                                hot_flash = gr.Slider(
                                    minimum=1, maximum=5, value=1, step=1,
                                    label="Hot Flash Severity (HOTFLAS7)",
                                    info="1=None, 5=Very severe",
                                )
                                num_hot_flash = gr.Slider(
                                    minimum=0, maximum=15, value=0, step=1,
                                    label="# Hot Flashes/Week (NUMHOTF7)",
                                )
                                bothersome_hf = gr.Slider(
                                    minimum=1, maximum=4, value=1, step=1,
                                    label="How Bothersome (BOTHOTF7)",
                                    info="1=Not at all, 4=Extremely",
                                )

                        with gr.Group():
                            gr.HTML('<div class="section-label">Sleep &amp; Mood</div>')
                            with gr.Row():
                                sleep_quality = gr.Slider(
                                    minimum=1, maximum=5, value=2, step=1,
                                    label="Sleep Quality (SLEEPQL7)",
                                    info="1=Very good, 5=Very poor",
                                )
                                depression = gr.Slider(
                                    minimum=0, maximum=4, value=0, step=1,
                                    label="Depression Indicator (DEPRESS7)",
                                    info="0=No, higher=more severe",
                                )
                            with gr.Row():
                                mood_change = gr.Slider(
                                    minimum=1, maximum=5, value=1, step=1,
                                    label="Mood Changes (MOODCHG7)",
                                    info="1=None, 5=Severe",
                                )
                                irritability = gr.Slider(
                                    minimum=1, maximum=5, value=1, step=1,
                                    label="Irritability (IRRITAB7)",
                                )

                        with gr.Group():
                            gr.HTML('<div class="section-label">Physical &amp; Gynaecological</div>')
                            with gr.Row():
                                pain = gr.Slider(
                                    minimum=0, maximum=5, value=0, step=1,
                                    label="Pain Indicator (PAIN17)",
                                )
                                abbleed = gr.Dropdown(
                                    choices=[0, 1, 2], value=0,
                                    label="Abnormal Bleeding (ABBLEED7)",
                                    info="0=No, 1=Yes, 2=Unsure",
                                )
                            with gr.Row():
                                vaginal_dryness = gr.Slider(
                                    minimum=0, maximum=5, value=0, step=1,
                                    label="Vaginal Dryness (VAGINDR7)",
                                )
                                lmp_day = gr.Number(
                                    value=None,
                                    label="LMP Day (LMPDAY7)",
                                    info="Day of last menstrual period (optional)",
                                )

                        model_choice = gr.Radio(
                            choices=["RandomForest", "LogisticRegression"],
                            value="RandomForest",
                            label="Model",
                            info="RandomForest: higher accuracy | "
                                 "LogisticRegression: more interpretable",
                        )
                        predict_btn = gr.Button(
                            "🔮 Predict Stage", variant="primary", size="lg"
                        )

                    # ── Output column ─────────────────────────────────────────
                    with gr.Column(scale=3):
                        result_html = gr.HTML(
                            '<div class="placeholder-msg">Fill in the form and click Predict Stage</div>'
                        )
                        result_chart = gr.Plot(label="Stage Probabilities")
                        confidence_note = gr.Textbox(
                            label="Confidence Note", interactive=False, lines=2
                        )
                        compare_html = gr.HTML()
                        stage_download = gr.File(
                            label="Download Prediction CSV", interactive=False
                        )

                # NOTE: input order must match predict_single_stage's signature.
                predict_btn.click(
                    fn = predict_single_stage,
                    inputs = [
                        age, race, langint,
                        hot_flash, num_hot_flash, bothersome_hf,
                        sleep_quality, depression, mood_change, irritability,
                        pain, abbleed, vaginal_dryness, lmp_day,
                        model_choice,
                    ],
                    outputs = [
                        result_html, result_chart, confidence_note,
                        compare_html, stage_download,
                    ],
                )

            # ── TAB 2: Batch Stage Prediction ─────────────────────────────────
            with gr.Tab("📁 Batch Stage Prediction"):
                gr.HTML("""
                <div class="info-box">
                Upload a CSV file with individual feature values for batch prediction.
                Results + charts + a summary report are saved to a timestamped folder
                inside <code>swan_ml_output/</code>.
                </div>""")

                with gr.Row():
                    with gr.Column(scale=1):
                        batch_file = gr.File(
                            label="Upload stage_input.csv",
                            file_types=[".csv"],
                        )
                        batch_model = gr.Radio(
                            choices=["RandomForest", "LogisticRegression"],
                            value="RandomForest",
                            label="Model",
                        )
                        gr.HTML("""
                        <div class="format-hint">
                        <div class="format-hint-title">Expected CSV Format</div>
                        <pre>individual,AGE7,RACE,HOTFLAS7,...
Person_001,48,1,2,...
Person_002,52,2,1,...</pre>
                        <div class="format-hint-note">
                        See the test-csv/ folder for an approved example.
                        </div>
                        </div>""")
                        batch_predict_btn = gr.Button(
                            "🚀 Run Batch Prediction", variant="primary"
                        )

                    with gr.Column(scale=2):
                        batch_summary_html = gr.HTML(
                            '<div class="placeholder-msg">Upload a CSV to begin</div>'
                        )
                        batch_download = gr.File(
                            label="Download Predictions CSV", interactive=False
                        )
                        batch_results_df = gr.DataFrame(
                            label="Results Preview (first 20 rows)",
                            interactive=False,
                        )

                batch_predict_btn.click(
                    fn = predict_batch_stage,
                    inputs = [batch_file, batch_model],
                    outputs = [batch_download, batch_summary_html, batch_results_df],
                )

            # ── TAB 3: Symptom Forecast ───────────────────────────────────────
            with gr.Tab("🌊 Symptom Forecast"):
                gr.HTML("""
                <div class="info-box">
                Predict hot flash and mood change probability based on cycle day
                (calculated from Last Menstrual Period date).
                All outputs are saved to a timestamped folder inside
                <code>swan_ml_output/</code>.
                </div>""")

                with gr.Row():
                    with gr.Column(scale=1):
                        sym_individual = gr.Textbox(
                            label="Individual ID (optional)",
                            placeholder="e.g., Patient_001",
                        )
                        sym_lmp = gr.Textbox(
                            label="Last Menstrual Period (LMP)",
                            placeholder="2026-01-15 or 15 (day of month)",
                            info="Full date (YYYY-MM-DD) or day-of-month integer",
                        )
                        sym_date = gr.Textbox(
                            label="Target Date (optional)",
                            placeholder="2026-02-27 (defaults to today)",
                            info="Date to forecast for (YYYY-MM-DD)",
                        )
                        sym_cycle = gr.Slider(
                            minimum=21, maximum=40, value=28, step=1,
                            label="Cycle Length (days)",
                        )
                        sym_predict_btn = gr.Button(
                            "🌊 Forecast Symptoms", variant="primary"
                        )

                    with gr.Column(scale=2):
                        sym_result_html = gr.HTML(
                            '<div class="placeholder-msg">Enter LMP date and click Forecast</div>'
                        )
                        sym_chart = gr.Plot(label="Cycle Position")
                        sym_download = gr.File(
                            label="Download Forecast CSV", interactive=False
                        )

                sym_predict_btn.click(
                    fn = predict_symptoms,
                    inputs = [sym_individual, sym_lmp, sym_date, sym_cycle],
                    outputs = [sym_result_html, sym_chart, sym_download],
                )

                gr.HTML('<hr class="section-divider">')
                gr.HTML('<div class="batch-section-label">📁 Batch Symptom Forecasting</div>')

                with gr.Row():
                    with gr.Column(scale=1):
                        sym_batch_file = gr.File(
                            label="Upload symptoms_input.csv",
                            file_types=[".csv"],
                        )
                        sym_lmp_col = gr.Textbox(
                            label="LMP Column Name", value="LMP"
                        )
                        sym_date_col = gr.Textbox(
                            label="Date Column Name (optional)", value="date"
                        )
                        sym_cycle_batch = gr.Slider(
                            minimum=21, maximum=40, value=28, step=1,
                            label="Default Cycle Length",
                        )
                        sym_batch_btn = gr.Button(
                            "🌊 Run Batch Forecast", variant="primary"
                        )

                    with gr.Column(scale=2):
                        sym_batch_summary = gr.HTML(
                            '<div class="placeholder-msg">Upload a CSV to begin</div>'
                        )
                        sym_batch_download = gr.File(
                            label="Download Symptom Forecast CSV", interactive=False
                        )
                        sym_batch_df = gr.DataFrame(
                            label="Results Preview",
                            interactive=False,
                        )

                sym_batch_btn.click(
                    fn = predict_symptoms_batch,
                    inputs = [
                        sym_batch_file, sym_lmp_col,
                        sym_date_col, sym_cycle_batch,
                    ],
                    outputs = [sym_batch_download, sym_batch_summary, sym_batch_df],
                )

            # ── TAB 4: Education ──────────────────────────────────────────────
            with gr.Tab("📚 Menopause Education"):
                gr.HTML(EDUCATION_HTML)

            # ── TAB 5: Feature Reference ──────────────────────────────────────
            with gr.Tab("🔬 Feature Reference"):
                gr.HTML("""
                <div class="info-box">
                Canonical list of features used by the trained models
                (from <code>forecast_metadata.json</code>).
                For batch CSV uploads, column names must match these feature names.
                </div>""")
                gr.HTML(get_feature_reference())

            # ── TAB 6: Model Status ───────────────────────────────────────────
            with gr.Tab("⚙️ Model Status"):
                gr.HTML(get_model_status())
                gr.HTML("""
                <div class="setup-card">
                <div class="setup-title">🚀 Setup Instructions</div>
                <div class="setup-step">
                <p><strong>Step 1 — Train models:</strong></p>
                <pre class="code-block">python menopause.py</pre>
                <p><strong>Step 2 — Verify artifacts:</strong></p>
                <pre class="code-block">ls swan_ml_output/
# rf_pipeline.pkl lr_pipeline.pkl forecast_metadata.json</pre>
                <p><strong>Step 3 — Run this app:</strong></p>
                <pre class="code-block">python app.py</pre>
                <p><strong>Step 4 — Deploy on Hugging Face Spaces:</strong></p>
                <pre class="code-block">git lfs install
git lfs track "*.pkl"
git add .
git commit -m "SWAN menopause prediction app"
git push</pre>
                <p><strong>Output folder structure (per run):</strong></p>
                <pre class="code-block">swan_ml_output/
  &lt;YYYYMMDD_HHMMSS&gt;/
    charts/ &larr; PNG visualizations
    predictions/ &larr; CSV result files
    reports/ &larr; TXT summary reports</pre>
                </div>
                </div>
                """)

        gr.HTML("""
        <div class="app-footer">
        SWAN Menopause Prediction App · Built with Gradio ·
        For research &amp; educational use only · Not for clinical diagnosis ·
        <a href="https://www.swanstudy.org/" target="_blank">SWAN Study</a>
        </div>""")

    return app
1813
+
1814
+
1815
# ── Entry point ───────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Bind to all interfaces; honor a PORT env var (Spaces/containers) with
    # Gradio's conventional 7860 as the fallback.
    port = int(os.environ.get("PORT", 7860))
    build_app().launch(
        server_name = "0.0.0.0",
        server_port = port,
        share = False,
        show_error = True,
    )
menopause.py ADDED
@@ -0,0 +1,1383 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SWAN Menopause Stage Prediction (pre / peri / post) using self-reported features
3
+ Uses only the uploaded SWAN TSV file (no synthetic data, no external datasets).
4
+
5
+ Outputs:
6
+ - saved artifacts in ./swan_ml_output/
7
+ - documentation.md summarizing steps and results
8
+ - optional CSV outputs for stage predictions and symptom predictions (separate files)
9
+
10
+ Notes:
11
+ - The script attempts to locate a menopause-stage column heuristically (common names like MENOSTAT,
12
+ MENO, MENOSYM, MENOP etc.). Please verify the chosen stage column against the codebook.
13
+ - Self-reported features are identified using name-pattern heuristics (VMS/HOT/SLEEP/CESD/STRESS/MOOD/SMOK/ALCOH/EXER/PHYS/VAG/URINE/SEX/PAIN etc).
14
+ - Duplicate column names are tolerantly handled by renaming duplicates.
15
+ """
16
+
17
+ import os, re, sys, argparse
18
+ import numpy as np
19
+ import pandas as pd
20
+ import importlib
21
+ import sklearn
22
+ import matplotlib
23
+ # Use a non-interactive backend by default so the script can run on servers/CI
24
+ matplotlib.use('Agg')
25
+ import matplotlib.pyplot as plt
26
+ from datetime import datetime, timedelta
27
+
28
+ from sklearn.model_selection import train_test_split
29
+ from sklearn.impute import SimpleImputer
30
+ from sklearn.preprocessing import OneHotEncoder, StandardScaler, LabelEncoder
31
+ from sklearn.compose import ColumnTransformer
32
+ from sklearn.pipeline import Pipeline
33
+ from sklearn.ensemble import RandomForestClassifier
34
+ from sklearn.linear_model import LogisticRegression
35
+ from sklearn.metrics import classification_report, confusion_matrix, roc_curve, auc
36
+ from sklearn.inspection import permutation_importance
37
+ from sklearn.preprocessing import label_binarize
38
+
39
# --------------------------
# Environment / CLI defaults
# --------------------------
# Defaults may be overridden by environment variables or CLI args below
DATA_PATH = os.environ.get('MENOPAUSE_DATA', "ICPSR_31901/DS0001/31901-0001-Data.tsv")
OUTPUT_DIR = os.environ.get('MENOPAUSE_OUT', "swan_ml_output")

# Parse CLI args (safe to parse here for a script; this will be ignored when imported)
parser = argparse.ArgumentParser(description='Run menopause stage prediction pipeline')
parser.add_argument('--data', '-d', default=DATA_PATH, help='Path to SWAN TSV file')
parser.add_argument('--output', '-o', default=OUTPUT_DIR, help='Output directory for artifacts')
parser.add_argument('--show', action='store_true', help='Show plots interactively (default: off)')
parser.add_argument('--stage-col', default=None, help='Override detected stage column name')
# Symptom cycle prediction CLI options
parser.add_argument('--predict-symptoms', action='store_true', help='Run symptom cycle prediction from CSV input')
parser.add_argument('--symptoms-input', default=None, help='Input CSV for symptom predictions')
parser.add_argument('--symptoms-output', default=None, help='Output CSV to write symptom predictions')
parser.add_argument('--lmp-col', default='LMP', help='Column name used as LMP (date string or day-of-month integer)')
parser.add_argument('--date-col', default=None, help='Column name for target date; if omitted, uses today or VISIT date if present')
parser.add_argument('--cycle-length', type=int, default=28, help='Average cycle length in days for symptom prediction')
# Dual prediction CLI options (separate inputs/outputs for each model)
parser.add_argument('--predict-dual', action='store_true', help='Run stage + symptom predictions using separate input/output files')
parser.add_argument('--stage-input', default=None, help='Input CSV for menopause stage predictions')
parser.add_argument('--stage-output', default=None, help='Output CSV for menopause stage predictions')
parser.add_argument('--stage-model', default='RandomForest', help='Model for stage prediction: RandomForest or LogisticRegression')
parser.add_argument('--forecast-dir', default=OUTPUT_DIR, help='Directory containing saved forecast models')
parser.add_argument('--menopause-stage-col', default=None, help='(Deprecated) Kept for backward compatibility; symptom forecasting no longer uses menopause stage')
# Parse CLI args only when script is run directly; when imported (e.g., during testing), avoid consuming external argv
if __name__ == '__main__':
    args = parser.parse_args()
else:
    # Use defaults when module is imported to avoid interfering with external CLI (pytest, etc.)
    args = parser.parse_args([])

# Re-bind module constants from parsed args so CLI overrides env/defaults.
DATA_PATH = args.data
OUTPUT_DIR = args.output
SHOW_PLOTS = bool(args.show)
STAGE_COL_OVERRIDE = args.stage_col
77
+
78
+ # If user only wants symptom-cycle predictions, provide a fast-path before loading the large TSV
79
+ # Define a light-weight cycle-based symptom forecaster and CSV helper so users can run predictions
80
+ # without training the menopause models (useful for small CSV inputs).
81
class SymptomCycleForecaster:
    """Heuristic cycle-phase symptom forecaster.

    Models hot-flash and mood-change probability as Gaussian bumps over the
    day of the menstrual cycle (day 1 = first day after LMP). No training
    data is needed, so this runs without loading the SWAN TSV.
    """

    def __init__(self, cycle_length=28, hot_mu=14, hot_sigma=5, mood_mu=26, mood_sigma=4,
                 base_hot=0.1, amp_hot=0.4, base_mood=0.1, amp_mood=0.45, threshold=0.5):
        # Gaussian bump parameters: mu/sigma locate the peak on the cycle,
        # base is the floor probability, amp the peak height above the floor.
        self.cycle_length = cycle_length
        self.hot_mu = hot_mu
        self.hot_sigma = hot_sigma
        self.mood_mu = mood_mu
        self.mood_sigma = mood_sigma
        self.base_hot = base_hot
        self.amp_hot = amp_hot
        self.base_mood = base_mood
        self.amp_mood = amp_mood
        self.threshold = threshold  # probability cutoff for the boolean *_pred flags

    def _parse_lmp(self, lmp, reference_date=None):
        """Resolve *lmp* to a datetime, or None when unparseable.

        Integers are treated as a day-of-month anchored to the reference
        month (clamped to 1..28 to stay valid in every month); anything
        else is parsed as a date string.
        """
        if pd.isna(lmp):
            return None
        try:
            day_of_month = int(lmp)
            if reference_date is None:
                anchor = pd.Timestamp(datetime.today()).to_pydatetime()
            else:
                parsed = pd.to_datetime(reference_date, errors='coerce')
                if pd.isna(parsed):
                    anchor = pd.Timestamp(datetime.today()).to_pydatetime()
                else:
                    anchor = parsed.to_pydatetime()
            return datetime(anchor.year, anchor.month, max(1, min(day_of_month, 28)))
        except Exception:
            # Not an integer: fall back to generic date parsing.
            try:
                return pd.to_datetime(lmp, errors='coerce').to_pydatetime()
            except Exception:
                return None

    def compute_cycle_day(self, lmp, target_date=None):
        """Return the 1-based cycle day at *target_date*, or None if LMP is unknown."""
        if target_date is None:
            when = datetime.today()
        else:
            parsed = pd.to_datetime(target_date, errors='coerce')
            when = datetime.today() if pd.isna(parsed) else parsed.to_pydatetime()
        start = self._parse_lmp(lmp, reference_date=when)
        if start is None:
            return None
        elapsed = (when - start).days
        if elapsed < 0:
            # LMP lands after the target date; assume it belongs to the
            # previous cycle and shift back one full cycle length.
            start = start - timedelta(days=self.cycle_length)
            elapsed = (when - start).days
        return int((elapsed % self.cycle_length) + 1)

    def _gauss_prob(self, day, mu, sigma, base, amp):
        """Gaussian-bump probability for *day*; NaN when the day is unknown."""
        if day is None:
            return np.nan
        bump = amp * np.exp(-0.5 * ((day - mu) / float(sigma)) ** 2)
        return float(min(max(base + bump, 0.0), 1.0))

    def predict_single(self, lmp, target_date=None):
        """Forecast both symptoms for one individual; returns a plain dict."""
        day = self.compute_cycle_day(lmp, target_date=target_date)
        out = {'cycle_day': day}
        specs = (
            ('hotflash', self.hot_mu, self.hot_sigma, self.base_hot, self.amp_hot),
            ('mood', self.mood_mu, self.mood_sigma, self.base_mood, self.amp_mood),
        )
        for name, mu, sigma, base, amp in specs:
            prob = self._gauss_prob(day, mu, sigma, base, amp)
            out[f'{name}_prob'] = prob
            out[f'{name}_pred'] = None if np.isnan(prob) else prob >= self.threshold
        return out

    def predict_df(self, df, lmp_col='LMP', date_col=None, menopause_stage_col=None):
        """Forecast row-wise over a DataFrame; appends prediction columns.

        ``menopause_stage_col`` is accepted for backward compatibility but
        is not used by the forecaster.
        """
        frame = df.copy()

        def _forecast_row(row):
            target = row.get(date_col) if date_col is not None else None
            return pd.Series(self.predict_single(lmp=row.get(lmp_col), target_date=target))

        predictions = frame.apply(_forecast_row, axis=1)
        return pd.concat(
            [frame.reset_index(drop=True), predictions.reset_index(drop=True)],
            axis=1,
        )
163
+
164
+
165
def predict_symptoms_from_csv(input_csv, output_csv, lmp_col='LMP', date_col=None,
                              menopause_stage_col=None, cycle_length=28, **kwargs):
    """Run the cycle-based symptom forecaster over a CSV and write results.

    Reads *input_csv*, appends cycle-day / probability / boolean prediction
    columns, writes *output_csv*, and prints a short preview.
    ``menopause_stage_col`` is passed through for backward compatibility only.
    """
    frame = pd.read_csv(input_csv)
    forecaster = SymptomCycleForecaster(cycle_length=cycle_length)
    predictions = forecaster.predict_df(
        frame, lmp_col=lmp_col, date_col=date_col, menopause_stage_col=menopause_stage_col
    )
    predictions.to_csv(output_csv, index=False)
    print(f"Wrote symptom predictions for {predictions.shape[0]} rows to {output_csv}")
    print("Sample predictions (first 5 rows):")
    preview = [lmp_col] + ['cycle_day','hotflash_prob','hotflash_pred','mood_prob','mood_pred']
    print(predictions[preview].head().to_string())
174
+
175
# If the user requested only symptom predictions from a CSV, run fast-path and exit
# (this avoids loading the large SWAN TSV or training any models).
if args.predict_symptoms:
    if not args.symptoms_input or not args.symptoms_output:
        print("Error: --symptoms-input and --symptoms-output are required when --predict-symptoms is set")
        sys.exit(1)
    else:
        predict_symptoms_from_csv(
            input_csv=args.symptoms_input,
            output_csv=args.symptoms_output,
            lmp_col=args.lmp_col,
            date_col=args.date_col,
            menopause_stage_col=None,  # deprecated: forecasting ignores stage
            cycle_length=args.cycle_length
        )
        sys.exit(0)
190
+
191
# Fast-path for dual predictions (separate stage + symptoms) without loading large TSV
if args.predict_dual:
    if not args.stage_input or not args.stage_output or not args.symptoms_input or not args.symptoms_output:
        print("Error: --stage-input, --stage-output, --symptoms-input, and --symptoms-output are required when --predict-dual is set")
        sys.exit(1)

    # Load saved pipeline directly via joblib to avoid initializing full training pipeline
    import joblib
    model_file = os.path.join(args.forecast_dir, 'rf_pipeline.pkl' if args.stage_model == 'RandomForest' else 'lr_pipeline.pkl')
    try:
        pipeline = joblib.load(model_file)
    except Exception as e:
        print(f"ERROR: Could not load model file '{model_file}': {e}")
        print("Please train the models first (run the script without --predict-dual) or provide correct --forecast-dir")
        sys.exit(1)

    # Stage predictions
    try:
        stage_data = pd.read_csv(args.stage_input)
    except Exception as e:
        print(f"ERROR: Could not read stage input CSV '{args.stage_input}': {e}")
        sys.exit(1)

    # Columns treated as row identifiers (passed through, never fed to the model)
    id_cols = ['ID', 'id', 'SWANID', 'individual', 'Individual', 'subject', 'Subject']
    feature_cols = [c for c in stage_data.columns if c not in id_cols]

    # Attempt to load feature metadata so we can reindex inputs to expected features
    import json
    metadata_path = os.path.join(args.forecast_dir, 'forecast_metadata.json')
    try:
        with open(metadata_path, 'r') as f:
            metadata = json.load(f)
        expected_features = metadata.get('feature_names', feature_cols)
    except Exception:
        # Metadata missing/unreadable: fall back to whatever columns the CSV has
        expected_features = feature_cols

    # Reindex so the pipeline sees the training-time feature set (absent columns -> NaN)
    X = stage_data.reindex(columns=expected_features, fill_value=np.nan)
    preds = pd.DataFrame({'predicted_stage': pipeline.predict(X), 'model': args.stage_model})
    try:
        proba = pipeline.predict_proba(X)
        # Last named step of the sklearn Pipeline holds the fitted classifier (for classes_)
        final_est = pipeline.named_steps[list(pipeline.named_steps.keys())[-1]]
        preds['confidence'] = np.max(proba, axis=1)
        for i, cls in enumerate(final_est.classes_):
            preds[f'prob_{cls}'] = proba[:, i]
    except Exception:
        # Estimator may not support predict_proba; still emit class predictions
        preds['confidence'] = np.nan

    id_data = stage_data[[c for c in id_cols if c in stage_data.columns]] if any(c in stage_data.columns for c in id_cols) else None
    if id_data is not None:
        stage_results = pd.concat([id_data.reset_index(drop=True), preds.reset_index(drop=True)], axis=1)
    else:
        # No identifier columns supplied: synthesize a 1-based 'individual' index
        stage_results = preds.reset_index(drop=True)
        stage_results.insert(0, 'individual', range(1, len(stage_results) + 1))

    stage_results.to_csv(args.stage_output, index=False)
    print(f"Wrote stage predictions for {stage_results.shape[0]} rows to {args.stage_output}")

    # Symptom predictions (independent input/output)
    try:
        symptom_data = pd.read_csv(args.symptoms_input)
    except Exception as e:
        print(f"ERROR: Could not read symptom input CSV '{args.symptoms_input}': {e}")
        sys.exit(1)

    # Prefer an explicit --date-col; otherwise use a 'date' column when present
    date_col = args.date_col if args.date_col else ('date' if 'date' in symptom_data.columns else None)
    fore = SymptomCycleForecaster(cycle_length=args.cycle_length)
    symptom_results = fore.predict_df(symptom_data, lmp_col=args.lmp_col, date_col=date_col)
    symptom_results.to_csv(args.symptoms_output, index=False)
    print(f"Wrote symptom predictions for {symptom_results.shape[0]} rows to {args.symptoms_output}")
    sys.exit(0)
261
+
262
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
263
+
264
+ # --------------------------
265
+ # Utility: make column names unique (pandas allows duplicates)
266
+ # --------------------------
267
def make_unique_columns(cols):
    """Return *cols* with duplicate names disambiguated.

    The first occurrence keeps its name; later duplicates become
    ``name__dup1``, ``name__dup2``, ... in order of appearance.
    """
    seen = {}
    unique = []
    for name in cols:
        if name in seen:
            seen[name] += 1
            unique.append(f"{name}__dup{seen[name]}")
        else:
            seen[name] = 0
            unique.append(name)
    return unique
278
+
279
+ # --------------------------
280
+ # 1. Load data
281
+ # --------------------------
282
+ # Guard: only run training and heavy data loading when script is executed directly
283
+ if __name__ == '__main__' and os.path.exists(DATA_PATH):
284
+ print("Loading data from:", DATA_PATH)
285
+ df = pd.read_csv(DATA_PATH, sep='\t', low_memory=False)
286
+ print("Original shape:", df.shape)
287
+
288
+ # make column names unique for robust selection (duplicates -> __dup1, __dup2)
289
+ df.columns = make_unique_columns(df.columns.tolist())
290
+
291
+ # Show a few columns (first 40) so user can inspect if running interactively
292
+ print("First 40 column names (for inspection):")
293
+ print(df.columns[:40].tolist())
294
+
295
+ # --------------------------
296
+ # 2. Identify candidate self-reported features and menopause-stage variable
297
+ # --------------------------
298
+ # Heuristic patterns for self-report variables (adjust if you'd like to include additional columns)
299
+ selfreport_patterns = [
300
+ r'VMS', r'HOT', r'HOTFL', r'NIGHTSW', r'SLEEP', r'CESD', r'STRESS', r'MOOD',
301
+ r'SMOK', r'ALCOH', r'ALCO', r'EXER', r'PHYS', r'ACTIV', r'VAG', r'URINE', r'SEX', r'PAIN',
302
+ r'FATIG', r'IRRIT', r'ANXI', r'DEPRESS', r'BLEED', r'MENSE', r'PERIOD', r'LMP',
303
+ r'HOTSW', r'QOL', r'DRY'
304
+ ]
305
+ # Exclude laboratory/biomarker variable name patterns
306
+ biomarker_exclude = r'E2|FSH|GLUCOSE|CHOLESTEROL|HDL|TRIG|SHBG|DHEAS|INSULIN|BMD|BP|HEIGHT|WEIGHT'
307
+
308
+ upper_cols = {c: c.upper() for c in df.columns}
309
+
310
+ selfreport_cols = []
311
+ for orig, up in upper_cols.items():
312
+ for pat in selfreport_patterns:
313
+ if re.search(pat, up):
314
+ # skip biomarkers that match both symptom patterns and biomarker patterns
315
+ if re.search(biomarker_exclude, up):
316
+ continue
317
+ selfreport_cols.append(orig)
318
+ break
319
+
320
+ # Also include basic self-report demographics commonly present (AGE, RACE)
321
+ for dem in ['AGE7','AGE','RACE','LANGINT7','LANGINT']:
322
+ if dem in df.columns and dem not in selfreport_cols:
323
+ selfreport_cols.append(dem)
324
+
325
+ # Deduplicate preserving order
326
+ seen=set()
327
+ selfreport_cols = [x for x in selfreport_cols if not (x in seen or seen.add(x))]
328
+
329
+ print(f"Found {len(selfreport_cols)} candidate self-reported columns (first 50 shown):")
330
+ print(selfreport_cols[:50])
331
+
332
+ # Identify menopause-stage variable heuristically
333
+ stage_cand_patterns = [r'MENOSTAT', r'MENOSYM', r'MENO', r'MENOP', r'MENST', r'MENSE', r'STATUS']
334
+ stage_candidates = [c for c in df.columns if any(re.search(p, c, flags=re.I) for p in stage_cand_patterns)]
335
+ print("Stage-like candidate columns (found):", stage_candidates[:10])
336
+
337
+ # If user provided an override for stage column via CLI, honor it (if present in data)
338
+ if STAGE_COL_OVERRIDE:
339
+ if STAGE_COL_OVERRIDE in df.columns:
340
+ print(f"Using overridden stage column: {STAGE_COL_OVERRIDE}")
341
+ stage_candidates = [STAGE_COL_OVERRIDE]
342
+ else:
343
+ print(f"Warning: requested stage column '{STAGE_COL_OVERRIDE}' not present in data; proceeding with heuristic detection")
344
+
345
+ # If multiple candidates choose one with few unique values (likely coded categories)
346
+ stage_col = None
347
+ for c in stage_candidates:
348
+ nunique = df[c].nunique(dropna=True)
349
+ # prefer small discrete sets (e.g., 2-6 categories)
350
+ if 1 < nunique <= 20:
351
+ stage_col = c
352
+ break
353
+
354
+ if stage_col is None and stage_candidates:
355
+ # fallback to first candidate
356
+ stage_col = stage_candidates[0]
357
+
358
+ if stage_col is None:
359
+ raise RuntimeError("No menopause-stage-like column found automatically. Inspect df.columns and pick the proper variable (e.g., MENOSTAT).")
360
+
361
+ print("Selected stage column:", stage_col, " unique values:", df[stage_col].nunique(dropna=True))
362
+ print("Sample raw counts:")
363
+ print(df[stage_col].value_counts(dropna=False).head(20))
364
+
365
+ # --------------------------
366
+ # 3. Create working dataframe with self-report features + stage
367
+ # --------------------------
368
+ use_cols = [stage_col] + [c for c in selfreport_cols if c in df.columns and c != stage_col]
369
+ data = df[use_cols].copy()
370
+
371
+ # Replace common SWAN missing codes with NaN
372
+ missing_values = [-9, -8, -7, -1, '.', 'NA', 'N/A', '999', 9999]
373
+ data.replace(missing_values, np.nan, inplace=True)
374
+
375
+ # Try convert object columns to numeric when appropriate
376
+ for col in data.columns:
377
+ if data[col].dtype == object:
378
+ coerced = pd.to_numeric(data[col].astype(str).str.strip(), errors='coerce')
379
+ # If many values become numeric, use numeric version; else leave as categorical string
380
+ if coerced.notna().sum() > len(coerced) * 0.5:
381
+ data[col] = coerced
382
+ else:
383
+ # replace blank/'nan' strings with np.nan
384
+ data[col] = data[col].astype(str).str.strip().replace({'nan': np.nan, '': np.nan})
385
+
386
+ # --------------------------
387
+ # 4. Map stage variable to standardized labels {pre, peri, post}
388
+ # *Important*: this is heuristic. Verify using the codebook and adjust mapping if needed.
389
+ # --------------------------
390
def map_stage_to_labels(series):
    """Heuristically standardize a raw stage column to {'pre','peri','post'}.

    Tries textual labels first; otherwise maps numeric codes by ordering
    (min -> 'pre', middle -> 'peri', max -> 'post'; two levels -> pre/post).
    Returns an all-NaN series when no mapping applies. Verify against the
    SWAN codebook before trusting this mapping.
    """
    values = series.copy()
    try:
        observed = {str(v).lower() for v in values.dropna().unique()}
    except Exception:
        observed = set()

    # Textual mapping path: any recognizable "pre" spelling triggers it.
    if observed & {'pre', 'premenopausal', 'premenopause', 'pre-menopausal'}:
        lowered = values.astype(str).str.lower()
        lowered = lowered.replace({
            'premenopausal': 'pre', 'pre-menopausal': 'pre', 'pre-menopause': 'pre', 'pre': 'pre',
            'perimenopausal': 'peri', 'peri-menopausal': 'peri', 'peri': 'peri',
            'postmenopausal': 'post', 'post-menopausal': 'post', 'post': 'post',
        })
        return lowered.map({'pre': 'pre', 'peri': 'peri', 'post': 'post'})

    # Numeric mapping heuristic: ordered codes -> pre / peri / post.
    numeric = pd.to_numeric(values, errors='coerce')
    levels = sorted(numeric.dropna().unique().tolist())
    if len(levels) >= 3:
        coding = {levels[0]: 'pre', levels[len(levels) // 2]: 'peri', levels[-1]: 'post'}
        return numeric.map(coding)
    if len(levels) == 2:
        # Two-level fallback: assume lower code is pre, higher is post.
        return numeric.map({levels[0]: 'pre', levels[1]: 'post'})
    # Nothing mappable: all-NaN result with the original index.
    return pd.Series([np.nan] * len(values), index=values.index)
415
+
416
+ mapped_stage = map_stage_to_labels(data[stage_col])
417
+ # If mapping failed (too many NaNs), attempt a simple bleed-based heuristic (last menstrual period)
418
+ if mapped_stage.isna().mean() > 0.9:
419
+ bleed_candidates = [c for c in data.columns if re.search(r'LMP|BLEED|PERIOD|MENSTR', c, flags=re.I)]
420
+ if len(bleed_candidates) > 0:
421
+ lcol = bleed_candidates[0]
422
+ lnum = pd.to_numeric(data[lcol], errors='coerce')
423
+ mapped_stage = pd.Series(index=data.index, dtype=object)
424
+ mapped_stage[lnum.isna()] = 'post'
425
+ mapped_stage[lnum.notna()] = 'pre'
426
+ else:
427
+ raise RuntimeError("Failed to map stage variable to pre/peri/post and no bleed/LMP variable found.")
428
+
429
+ data['_menopause_stage'] = mapped_stage
430
+ print("Mapped stage counts (after heuristic mapping):")
431
+ print(data['_menopause_stage'].value_counts(dropna=False))
432
+
433
+ # Drop rows with no mapped stage
434
+ data = data[~data['_menopause_stage'].isna()].copy()
435
+ print("Rows available for modeling:", data.shape[0])
436
+
437
+ # --------------------------
438
+ # 5. Feature selection for modeling
439
+ # Keep only self-report fields with enough non-missing values and >1 unique value
440
+ # --------------------------
441
+ feature_candidates = [c for c in use_cols if c != stage_col]
442
+ selected_features = []
443
+ for c in feature_candidates:
444
+ non_null = data[c].notna().sum()
445
+ # require at least 2% nonmissing or minimum 50 observations
446
+ if non_null < max(50, len(data) * 0.02):
447
+ continue
448
+ if data[c].nunique(dropna=True) <= 1:
449
+ continue
450
+ selected_features.append(c)
451
+
452
+ print("Number of features selected for modeling:", len(selected_features))
453
+ print("First 40 features (if many):", selected_features[:40])
454
+
455
+ # --------------------------
456
+ # 6. Preprocessing pipeline
457
+ # Numeric features: impute mean
458
+ # Categorical features: impute most frequent + one-hot encode
459
+ # Normalization: only added for logistic regression pipeline (tree-based RF doesn't need scaling)
460
+ # --------------------------
461
+ numeric_feats = [c for c in selected_features if pd.api.types.is_numeric_dtype(data[c])]
462
+ cat_feats = [c for c in selected_features if c not in numeric_feats]
463
+
464
+ from sklearn.pipeline import Pipeline
465
+ from sklearn.compose import ColumnTransformer
466
+
467
+ numeric_transformer = Pipeline(steps=[
468
+ ('imputer', SimpleImputer(strategy='mean'))
469
+ ])
470
+
471
+ # Construct OneHotEncoder in a sklearn-version compatible way
472
+ try:
473
+ ohe = OneHotEncoder(handle_unknown='ignore', sparse_output=False)
474
+ except TypeError:
475
+ # older sklearn versions use `sparse` kwarg
476
+ ohe = OneHotEncoder(handle_unknown='ignore', sparse=False)
477
+
478
+ categorical_transformer = Pipeline(steps=[
479
+ ('imputer', SimpleImputer(strategy='most_frequent')),
480
+ ('onehot', ohe)
481
+ ])
482
+
483
+ preprocessor = ColumnTransformer(transformers=[
484
+ ('num', numeric_transformer, numeric_feats),
485
+ ('cat', categorical_transformer, cat_feats)
486
+ ], remainder='drop')
487
+
488
# Two pipelines: RandomForest (no scaling) and LogisticRegression (scaling)
# NOTE(review): both pipelines share the SAME `preprocessor` object; Pipeline.fit
# fits steps in place, so fitting the second pipeline refits it (on the same
# data here, so the result is unchanged) — confirm this is intended.
rf_pipeline = Pipeline(steps=[
    ('pre', preprocessor),
    ('rf', RandomForestClassifier(n_estimators=200, random_state=42, n_jobs=-1))
])

lr_pipeline = Pipeline(steps=[
    ('pre', preprocessor),
    ('scaler', StandardScaler()),
    ('lr', LogisticRegression(solver='lbfgs', max_iter=1000))
])

# --------------------------
# 7. Prepare data, train/test split
# --------------------------
X = data[selected_features].copy()
y = data['_menopause_stage'].copy().astype(str)  # values: 'pre','peri','post' (hopefully)

print("Target class distribution:")
print(y.value_counts())

# Stratified split preserves class proportions across train/test
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, stratify=y, random_state=42)
print("Train / test sizes:", X_train.shape[0], X_test.shape[0])

# --------------------------
# 8. Train models
# --------------------------
print("Training RandomForest...")
rf_pipeline.fit(X_train, y_train)
print("RandomForest trained.")

print("Training LogisticRegression (multinomial)...")
lr_pipeline.fit(X_train, y_train)
print("LogisticRegression trained.")
523
+
524
+ # --------------------------
525
+ # 9. Predictions and assessment
526
+ # --------------------------
527
+ def evaluate_model(pipeline, X_test, y_test, model_name, output_dir=OUTPUT_DIR):
528
+ y_pred = pipeline.predict(X_test)
529
+ report = classification_report(y_test, y_pred)
530
+ print(f"\n=== {model_name} Classification Report ===\n{report}")
531
+ # confusion matrix
532
+ labels = sorted(y_test.unique())
533
+ cm = confusion_matrix(y_test, y_pred, labels=labels)
534
+ print(f"{model_name} Confusion Matrix (rows=true, cols=pred):\nLabels: {labels}\n{cm}")
535
+ # Save classification report
536
+ with open(os.path.join(output_dir, f"classification_report_{model_name.replace(' ','_')}.txt"), "w") as f:
537
+ f.write(report)
538
+ # Plot confusion matrix with matplotlib
539
+ fig, ax = plt.subplots(figsize=(5,4))
540
+ im = ax.imshow(cm, interpolation='nearest')
541
+ ax.set_xticks(range(len(labels))); ax.set_xticklabels(labels, rotation=45)
542
+ ax.set_yticks(range(len(labels))); ax.set_yticklabels(labels)
543
+ ax.set_title(f"{model_name} Confusion Matrix")
544
+ for i in range(cm.shape[0]):
545
+ for j in range(cm.shape[1]):
546
+ ax.text(j, i, format(cm[i, j], 'd'), ha="center", va="center")
547
+ plt.tight_layout()
548
+ plt.savefig(os.path.join(output_dir, f"{model_name.replace(' ','_')}_confusion_matrix.png"))
549
+ # Show plots only when requested; otherwise close to free resources (non-interactive default)
550
+ if SHOW_PLOTS:
551
+ plt.show()
552
+ else:
553
+ plt.close('all')
554
+ return y_pred, cm
555
+
556
rf_pred, rf_cm = evaluate_model(rf_pipeline, X_test, y_test, "RandomForest")
lr_pred, lr_cm = evaluate_model(lr_pipeline, X_test, y_test, "LogisticRegression")

# 10. Feature importance
# Extract feature names after preprocessing (numerics stay same; categorical one-hot create names)
pre = rf_pipeline.named_steps['pre']
# Get numeric feature names
feature_names = []
if len(numeric_feats) > 0:
    feature_names.extend(numeric_feats)
if len(cat_feats) > 0:
    # Get onehot output names
    ohe = pre.named_transformers_['cat'].named_steps['onehot']
    try:
        cat_onehot_names = list(ohe.get_feature_names_out(cat_feats))
    except Exception:
        try:
            # sklearn < 1.0 spelling of the same accessor
            cat_onehot_names = list(ohe.get_feature_names(cat_feats))
        except Exception:
            # Last resort: rebuild "<column>_<category>" names from the fitted
            # encoder. An empty fallback here would leave feature_names shorter
            # than feature_importances_ and make the DataFrame constructor
            # below fail on mismatched lengths.
            cat_onehot_names = [
                f"{col}_{val}"
                for col, cats in zip(cat_feats, ohe.categories_)
                for val in cats
            ]
    feature_names.extend(cat_onehot_names)
# Feature importances from RandomForest
rf_model = rf_pipeline.named_steps['rf']
importances = rf_model.feature_importances_
imp_df = pd.DataFrame({'feature': feature_names, 'importance': importances}).sort_values('importance', ascending=False)
imp_df.to_csv(os.path.join(OUTPUT_DIR, "rf_feature_importances.csv"), index=False)
print("\nTop 20 RF feature importances:")
print(imp_df.head(20).to_string(index=False))
582
+
583
# Permutation importance (robust)
print("Computing permutation importance (this can take some time)...")
perm = permutation_importance(rf_pipeline, X_test, y_test, n_repeats=10, random_state=42, n_jobs=-1)
# argsort()[::-1] orders features from most to least important
perm_idx = perm.importances_mean.argsort()[::-1]
perm_df = pd.DataFrame({
    'feature': np.array(feature_names)[perm_idx],
    'importance_mean': perm.importances_mean[perm_idx],
    'importance_std': perm.importances_std[perm_idx]
})
perm_df.to_csv(os.path.join(OUTPUT_DIR, "rf_permutation_importances.csv"), index=False)
print("Top 20 permutation importances:")
print(perm_df.head(20).to_string(index=False))

# Plot RF top features
topn = min(20, imp_df.shape[0])
fig, ax = plt.subplots(figsize=(8,6))
# [::-1] reverses the order so the largest bar is drawn at the top
ax.barh(imp_df['feature'].head(topn)[::-1], imp_df['importance'].head(topn)[::-1])
ax.set_title("RandomForest: Top feature importances")
ax.set_xlabel("Importance")
plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, "rf_top_feature_importances.png"))
if SHOW_PLOTS:
    plt.show()
else:
    plt.close('all')
608
+
609
# 11. ROC curves (one-vs-rest) if predict_proba available
def plot_multiclass_roc(pipeline, X_test, y_test, model_name):
    """Plot and save one-vs-rest ROC curves, one figure per class.

    Skips with a message when the pipeline cannot produce probabilities.
    Figures are saved to OUTPUT_DIR and shown only when SHOW_PLOTS is set.
    """
    if not hasattr(pipeline, "predict_proba"):
        print(f"{model_name} has no predict_proba; skipping ROC plot.")
        return
    # Must use same class order as pipeline's final estimator
    final_est = pipeline.named_steps[list(pipeline.named_steps.keys())[-1]]
    classes = final_est.classes_
    y_test_bin = label_binarize(y_test, classes=classes)
    y_score = pipeline.predict_proba(X_test)
    for i, cls in enumerate(classes):
        fpr, tpr, _ = roc_curve(y_test_bin[:, i], y_score[:, i])
        roc_auc = auc(fpr, tpr)
        plt.figure()
        plt.plot(fpr, tpr, label=f"AUC = {roc_auc:.3f}")
        # Dashed diagonal = chance-level classifier
        plt.plot([0,1],[0,1], linestyle='--')
        plt.title(f"{model_name} ROC for class {cls}")
        plt.xlabel("False Positive Rate"); plt.ylabel("True Positive Rate")
        plt.legend(loc='lower right')
        plt.savefig(os.path.join(OUTPUT_DIR, f"{model_name.replace(' ','_')}_ROC_{cls}.png"))
        if SHOW_PLOTS:
            plt.show()
        else:
            plt.close('all')
633
+
634
print("Plotting ROC curves for RandomForest and LogisticRegression (if available)...")
# Only plot when running as a script AND training actually happened in this
# process (the globals() check guards against import-time execution paths).
if __name__ == '__main__' and 'rf_pipeline' in globals():
    plot_multiclass_roc(rf_pipeline, X_test, y_test, "RandomForest")
    plot_multiclass_roc(lr_pipeline, X_test, y_test, "LogisticRegression")
638
+
639
+ # ==========================================================================================
640
+ # 12. FORECASTING MODULE: Predict menopausal stage for new individuals
641
+ # ==========================================================================================
642
+ class MenopauseForecast:
643
+ """
644
+ Forecasting module for predicting menopausal stage (pre/peri/post) given self-reported features.
645
+
646
+ This class encapsulates the trained models and preprocessing pipeline to make predictions
647
+ on new data with the same features used during training.
648
+ """
649
+
650
+ def __init__(self, rf_pipeline, lr_pipeline, feature_names, stage_classes):
651
+ """
652
+ Initialize the forecaster with trained pipelines.
653
+
654
+ Parameters:
655
+ -----------
656
+ rf_pipeline : sklearn Pipeline
657
+ Trained RandomForest pipeline
658
+ lr_pipeline : sklearn Pipeline
659
+ Trained LogisticRegression pipeline
660
+ feature_names : list
661
+ List of feature column names used for training
662
+ stage_classes : list
663
+ List of possible menopause stage classes (e.g., ['pre', 'peri', 'post'])
664
+ """
665
+ self.rf_pipeline = rf_pipeline
666
+ self.lr_pipeline = lr_pipeline
667
+ self.feature_names = feature_names
668
+ self.stage_classes = stage_classes
669
+ self.models = {
670
+ 'RandomForest': rf_pipeline,
671
+ 'LogisticRegression': lr_pipeline
672
+ }
673
+
674
+ def predict_single(self, feature_dict, model='RandomForest', return_proba=True):
675
+ """
676
+ Predict menopausal stage for a single individual.
677
+
678
+ Parameters:
679
+ -----------
680
+ feature_dict : dict
681
+ Dictionary with feature names as keys and values for prediction.
682
+ Example: {'HOT7': 1, 'SLEEP7': 2, 'CESD': 10, ...}
683
+ model : str
684
+ Which model to use for prediction: 'RandomForest' or 'LogisticRegression'
685
+ return_proba : bool
686
+ If True, return prediction probabilities; otherwise just the class label
687
+
688
+ Returns:
689
+ --------
690
+ dict : Contains 'stage', 'confidence', and optionally 'probabilities'
691
+ """
692
+ if model not in self.models:
693
+ raise ValueError(f"Model '{model}' not found. Available: {list(self.models.keys())}")
694
+
695
+ # Create DataFrame with single row, reindex to match training features
696
+ X = pd.DataFrame([feature_dict]).reindex(columns=self.feature_names, fill_value=np.nan)
697
+
698
+ pipeline = self.models[model]
699
+ prediction = pipeline.predict(X)[0]
700
+
701
+ result = {
702
+ 'stage': prediction,
703
+ 'model': model,
704
+ 'confidence': None,
705
+ 'probabilities': None
706
+ }
707
+
708
+ if return_proba:
709
+ try:
710
+ proba = pipeline.predict_proba(X)[0]
711
+ result['confidence'] = float(np.max(proba))
712
+ result['probabilities'] = {
713
+ cls: float(prob)
714
+ for cls, prob in zip(pipeline.named_steps[list(pipeline.named_steps.keys())[-1]].classes_, proba)
715
+ }
716
+ except Exception as e:
717
+ print(f"Warning: Could not compute probabilities: {e}")
718
+
719
+ return result
720
+
721
+ def predict_batch(self, df, model='RandomForest', return_proba=True):
722
+ """
723
+ Predict menopausal stage for multiple individuals (batch prediction).
724
+
725
+ Parameters:
726
+ -----------
727
+ df : pd.DataFrame
728
+ DataFrame with feature columns matching training features.
729
+ Missing values will be handled by the preprocessing pipeline.
730
+ model : str
731
+ Which model to use: 'RandomForest' or 'LogisticRegression'
732
+ return_proba : bool
733
+ If True, return prediction probabilities
734
+
735
+ Returns:
736
+ --------
737
+ pd.DataFrame : Contains 'predicted_stage', 'confidence', and probability columns
738
+ """
739
+ if model not in self.models:
740
+ raise ValueError(f"Model '{model}' not found. Available: {list(self.models.keys())}")
741
+
742
+ # Reindex to match training features
743
+ X = df.reindex(columns=self.feature_names, fill_value=np.nan)
744
+
745
+ pipeline = self.models[model]
746
+ predictions = pipeline.predict(X)
747
+
748
+ result_df = pd.DataFrame({
749
+ 'predicted_stage': predictions,
750
+ 'model': model
751
+ })
752
+
753
+ if return_proba:
754
+ try:
755
+ proba = pipeline.predict_proba(X)
756
+ final_est = pipeline.named_steps[list(pipeline.named_steps.keys())[-1]]
757
+ result_df['confidence'] = np.max(proba, axis=1)
758
+
759
+ # Add probability column for each class
760
+ for i, cls in enumerate(final_est.classes_):
761
+ result_df[f'prob_{cls}'] = proba[:, i]
762
+ except Exception as e:
763
+ print(f"Warning: Could not compute probabilities: {e}")
764
+
765
+ return result_df
766
+
767
+ def compare_models(self, feature_dict):
768
+ """
769
+ Compare predictions from both RandomForest and LogisticRegression models.
770
+
771
+ Parameters:
772
+ -----------
773
+ feature_dict : dict
774
+ Feature values for the individual
775
+
776
+ Returns:
777
+ --------
778
+ dict : Predictions and probabilities from both models
779
+ """
780
+ rf_result = self.predict_single(feature_dict, model='RandomForest', return_proba=True)
781
+ lr_result = self.predict_single(feature_dict, model='LogisticRegression', return_proba=True)
782
+
783
+ return {
784
+ 'RandomForest': rf_result,
785
+ 'LogisticRegression': lr_result
786
+ }
787
+
788
+ def get_feature_info(self):
789
+ """Return information about required features."""
790
+ return {
791
+ 'num_features': len(self.feature_names),
792
+ 'feature_names': self.feature_names,
793
+ 'stage_classes': self.stage_classes
794
+ }
795
+
796
+
797
def create_forecast_example():
    """
    Create an example forecast instance and demonstrate usage.

    This function is robust: if the training artifacts (`rf_pipeline`, `lr_pipeline`,
    `selected_features`, `X_train`, `X_test`) are not available in memory (e.g., when
    the module is imported in another process), it attempts to load saved pipelines
    from `OUTPUT_DIR` via `load_forecast_model()` and uses placeholder inputs.

    Returns:
    --------
    MenopauseForecast : the initialized (and demonstrated) forecaster

    Raises:
    -------
    RuntimeError : when neither in-memory artifacts nor saved models are available
    """
    print("\n" + "="*80)
    print("FORECASTING MODULE EXAMPLE: Predicting Menopausal Stage")
    print("="*80)

    # Determine pipelines and feature metadata (use in-memory if available, else load from disk)
    try:
        _rf = rf_pipeline
        _lr = lr_pipeline
        _features = selected_features
        _stage_classes = sorted(y.unique().tolist())
        has_training = True
    except NameError:
        # Module-level training globals are missing: fall back to artifacts on disk.
        print("Training artifacts not present in memory; attempting to load from disk...")
        try:
            _loaded = load_forecast_model(OUTPUT_DIR)
            _rf = _loaded.rf_pipeline
            _lr = _loaded.lr_pipeline
            _features = _loaded.feature_names
            _stage_classes = _loaded.stage_classes
            has_training = False
        except Exception as e:
            raise RuntimeError(f"Failed to initialize forecaster from disk: {e}")

    forecast = MenopauseForecast(
        rf_pipeline=_rf,
        lr_pipeline=_lr,
        feature_names=_features,
        stage_classes=_stage_classes
    )

    print(f"\nForecaster initialized with {len(_features)} features")
    print(f"Predicting stages: {_stage_classes}")

    # Example 1: Single individual prediction
    print("\n--- Example 1: Predict for a single individual ---")
    example_individual = {}
    n_example_feats = min(10, len(_features))

    if has_training:
        # Use per-feature medians from the training split as a plausible individual.
        for feat in _features[:n_example_feats]:
            try:
                example_individual[feat] = float(pd.to_numeric(X_train[feat], errors='coerce').median())
            except Exception:
                # Fallback to mode or NaN
                try:
                    example_individual[feat] = X_train[feat].mode().iloc[0]
                except Exception:
                    example_individual[feat] = np.nan
    else:
        # No training DF available; provide NaN placeholders to let pipeline impute
        for feat in _features[:n_example_feats]:
            example_individual[feat] = np.nan

    result = forecast.predict_single(example_individual, model='RandomForest', return_proba=True)
    print(f"Predicted stage: {result.get('stage')}")
    print(f"Confidence: {result.get('confidence'):.3f}" if result.get('confidence') is not None else "Confidence: None")
    if result.get('probabilities'):
        print("Stage probabilities:")
        for stage, prob in sorted(result['probabilities'].items()):
            print(f"  {stage}: {prob:.3f}")

    # Example 2: Compare models
    print("\n--- Example 2: Compare RandomForest vs LogisticRegression ---")
    comparison = forecast.compare_models(example_individual)
    for model_name, cres in comparison.items():
        print(f"\n{model_name}:")
        print(f"  Predicted stage: {cres.get('stage')}")
        print(f"  Confidence: {cres.get('confidence'):.3f}" if cres.get('confidence') is not None else "  Confidence: None")

    # Example 3: Batch prediction on a small sample (either X_test if available or placeholder rows)
    print("\n--- Example 3: Batch prediction (small sample) ---")
    if has_training:
        try:
            test_sample = X_test.iloc[:5].copy()
            batch_results = forecast.predict_batch(test_sample, model='RandomForest', return_proba=True)
            print(batch_results.to_string())
        except Exception as e:
            print(f"Batch prediction failed on training sample: {e}")
    else:
        # Create a small placeholder DataFrame with feature columns filled with NaN
        placeholder = pd.DataFrame([{f: np.nan for f in _features[:n_example_feats]}])
        batch_results = forecast.predict_batch(placeholder, model='RandomForest', return_proba=True)
        print(batch_results.to_string())

    return forecast
891
+
892
+
893
def save_forecast_model(forecast_instance, output_dir=OUTPUT_DIR):
    """
    Save the forecast model instance for later use (optional: can use joblib for production).

    Saves feature/class metadata as JSON plus both trained pipelines via joblib,
    so the forecaster can be fully reconstructed by `load_forecast_model()`.

    Parameters:
    -----------
    forecast_instance : MenopauseForecast
        The forecaster to save
    output_dir : str
        Directory to save metadata; created if it does not already exist
    """
    import json
    import joblib

    # `output_dir` is caller-supplied: make sure it exists before writing,
    # otherwise open()/joblib.dump below fail with FileNotFoundError.
    os.makedirs(output_dir, exist_ok=True)

    metadata = {
        'feature_names': forecast_instance.feature_names,
        'stage_classes': forecast_instance.stage_classes,
        'num_features': len(forecast_instance.feature_names)
    }

    # Save metadata as JSON
    with open(os.path.join(output_dir, 'forecast_metadata.json'), 'w') as f:
        json.dump(metadata, f, indent=2)

    # Save trained pipelines using joblib (allows full reuse)
    joblib.dump(forecast_instance.rf_pipeline, os.path.join(output_dir, 'rf_pipeline.pkl'))
    joblib.dump(forecast_instance.lr_pipeline, os.path.join(output_dir, 'lr_pipeline.pkl'))

    print(f"Forecast model saved to {output_dir}")
    print("  - forecast_metadata.json")
    print("  - rf_pipeline.pkl")
    print("  - lr_pipeline.pkl")
928
+
929
+
930
def load_forecast_model(output_dir=OUTPUT_DIR):
    """
    Load a previously saved forecast model.

    Counterpart of `save_forecast_model()`: reads the JSON metadata and the
    two joblib-pickled pipelines and rebuilds a MenopauseForecast.

    Parameters:
    -----------
    output_dir : str
        Directory containing saved models

    Returns:
    --------
    MenopauseForecast : The loaded forecaster

    Raises:
    -------
    FileNotFoundError : when the metadata or pickle files are missing
    """
    import json
    import joblib

    # Load metadata (feature names and stage classes)
    with open(os.path.join(output_dir, 'forecast_metadata.json'), 'r') as f:
        metadata = json.load(f)

    # Load pipelines
    rf_pipeline_loaded = joblib.load(os.path.join(output_dir, 'rf_pipeline.pkl'))
    lr_pipeline_loaded = joblib.load(os.path.join(output_dir, 'lr_pipeline.pkl'))

    # Recreate forecaster
    forecast = MenopauseForecast(
        rf_pipeline=rf_pipeline_loaded,
        lr_pipeline=lr_pipeline_loaded,
        feature_names=metadata['feature_names'],
        stage_classes=metadata['stage_classes']
    )

    print(f"Forecast model loaded from {output_dir}")
    return forecast
964
+
965
+
966
+ # Initialize and demonstrate the forecasting module
967
+
968
+ # Symptom cycle forecasting (defined earlier near CLI args)
969
class SymptomCycleForecaster:
    """
    Predicts the probability of hot flashes and mood changes within a menstrual cycle
    based on last menstrual period (LMP) date and target date.
    """
    def __init__(self, cycle_length=28, hot_mu=14, hot_sigma=5, mood_mu=26, mood_sigma=4,
                 base_hot=0.1, amp_hot=0.4, base_mood=0.1, amp_mood=0.45, threshold=0.5):
        # Cycle geometry
        self.cycle_length = cycle_length
        # Gaussian bump parameters per symptom: peak day (mu) and spread (sigma)
        self.hot_mu, self.hot_sigma = hot_mu, hot_sigma
        self.mood_mu, self.mood_sigma = mood_mu, mood_sigma
        # Probability floor (base) and bump height (amp) per symptom
        self.base_hot, self.amp_hot = base_hot, amp_hot
        self.base_mood, self.amp_mood = base_mood, amp_mood
        # Cut-off applied to probabilities for the boolean *_pred outputs
        self.threshold = threshold

    def _parse_lmp(self, lmp, reference_date=None):
        """Parse LMP input which may be a full date string or an integer day-of-month."""
        if pd.isna(lmp):
            return None
        try:
            # Integer day-of-month path: anchor it in the reference date's month.
            day_of_month = int(lmp)
            if reference_date is None:
                anchor = pd.Timestamp(datetime.today()).to_pydatetime()
            else:
                anchor = pd.to_datetime(reference_date, errors='coerce')
                if pd.isna(anchor):
                    anchor = pd.Timestamp(datetime.today()).to_pydatetime()
                else:
                    anchor = anchor.to_pydatetime()
            # Clamp the day into [1, 28] so it is valid in any month
            return datetime(anchor.year, anchor.month, max(1, min(day_of_month, 28)))
        except Exception:
            # Not int-like: try to parse as a full date string instead.
            try:
                return pd.to_datetime(lmp, errors='coerce').to_pydatetime()
            except Exception:
                return None

    def compute_cycle_day(self, lmp, target_date=None):
        """Return 1-based cycle day (1..cycle_length) or None if cannot compute."""
        if target_date is None:
            when = datetime.today()
        else:
            when = pd.to_datetime(target_date, errors='coerce')
            when = datetime.today() if pd.isna(when) else when.to_pydatetime()
        start = self._parse_lmp(lmp, reference_date=when)
        if start is None:
            return None
        elapsed = (when - start).days
        if elapsed < 0:
            # LMP in the future: treat it as belonging to the previous cycle.
            start = start - timedelta(days=self.cycle_length)
            elapsed = (when - start).days
        return int((elapsed % self.cycle_length) + 1)

    def _gauss_prob(self, day, mu, sigma, base, amp):
        """Baseline plus a Gaussian bump centred on ``mu``, clipped to [0, 1]."""
        if day is None:
            return np.nan
        z = (day - mu) / float(sigma)
        raw = base + amp * np.exp(-0.5 * z * z)
        return float(min(max(raw, 0.0), 1.0))

    def predict_single(self, lmp, target_date=None):
        """Predict hot-flash and mood probabilities for one LMP + target date."""
        cycle_day = self.compute_cycle_day(lmp, target_date=target_date)
        p_hot = self._gauss_prob(cycle_day, self.hot_mu, self.hot_sigma, self.base_hot, self.amp_hot)
        p_mood = self._gauss_prob(cycle_day, self.mood_mu, self.mood_sigma, self.base_mood, self.amp_mood)
        return {
            'cycle_day': cycle_day,
            'hotflash_prob': p_hot,
            'hotflash_pred': None if np.isnan(p_hot) else p_hot >= self.threshold,
            'mood_prob': p_mood,
            'mood_pred': None if np.isnan(p_mood) else p_mood >= self.threshold
        }

    def predict_df(self, df, lmp_col='LMP', date_col=None, menopause_stage_col=None):
        """Apply predict_single row-wise and append the result columns."""
        frame = df.copy()
        per_row = frame.apply(
            lambda row: pd.Series(self.predict_single(
                lmp=row.get(lmp_col),
                target_date=(row.get(date_col) if date_col is not None else None)
            )), axis=1
        )
        return pd.concat([frame.reset_index(drop=True), per_row.reset_index(drop=True)], axis=1)
1061
+
1062
+
1063
def predict_symptoms_from_csv(input_csv, output_csv, lmp_col='LMP', date_col=None,
                              menopause_stage_col=None, cycle_length=28, **kwargs):
    """Read input CSV, predict hot flashes/mood by cycle day, and write output CSV."""
    frame = pd.read_csv(input_csv)
    forecaster = SymptomCycleForecaster(cycle_length=cycle_length)
    predictions = forecaster.predict_df(frame, lmp_col=lmp_col, date_col=date_col,
                                        menopause_stage_col=menopause_stage_col)
    predictions.to_csv(output_csv, index=False)
    # Print a brief summary
    print(f"Wrote symptom predictions for {predictions.shape[0]} rows to {output_csv}")
    print("Sample predictions (first 5 rows):")
    summary_cols = [lmp_col] + ['cycle_day', 'hotflash_prob', 'hotflash_pred', 'mood_prob', 'mood_pred']
    print(predictions[summary_cols].head().to_string())
1074
+
1075
# CLI integration: run symptom prediction if requested
if __name__ == '__main__':
    # If symptom prediction requested via CLI, run fast-path and exit
    # (avoids the expensive model-training path above for this mode).
    if args.predict_symptoms:
        if not args.symptoms_input or not args.symptoms_output:
            print("Error: --symptoms-input and --symptoms-output are required when --predict-symptoms is set")
            sys.exit(1)
        else:
            predict_symptoms_from_csv(
                input_csv=args.symptoms_input,
                output_csv=args.symptoms_output,
                lmp_col=args.lmp_col,
                date_col=args.date_col,
                cycle_length=args.cycle_length
            )
            sys.exit(0)
1091
+
1092
    # Dual predictions are handled in the early fast-path above to avoid training.

    # Default behavior: create demo forecaster, save trained models and show summary
    forecast_model = create_forecast_example()
    save_forecast_model(forecast_model)

    print("\n" + "="*80)
    print("FORECASTING MODULE SUMMARY")
    print("="*80)
    print("""
    The MenopauseForecast class provides three main methods for predictions:

    1. predict_single(feature_dict, model='RandomForest', return_proba=True)
       - Predict stage for one individual given feature values
       - Returns predicted stage and confidence scores

    2. predict_batch(df, model='RandomForest', return_proba=True)
       - Predict stages for multiple individuals
       - Returns DataFrame with predictions and probabilities for each stage

    3. compare_models(feature_dict)
       - Compare predictions from both RandomForest and LogisticRegression
       - Useful for validating model agreement

    Usage in your own code:
        from menopause import load_forecast_model

        # Load the trained forecaster
        forecast = load_forecast_model('swan_ml_output')

        # Predict for an individual
        features = {'HOT7': 1, 'SLEEP7': 2, 'CESD': 10, ...}
        result = forecast.predict_single(features, model='RandomForest')

        # Predict for multiple individuals
        results_df = forecast.predict_batch(your_dataframe, model='RandomForest')
    """)
1129
+
1130
+
1131
+ # ==========================================================================================
1132
+ # 13. CSV INPUT/OUTPUT FUNCTIONALITY: Batch prediction from CSV files
1133
+ # ==========================================================================================
1134
+
1135
+ def predict_from_csv(input_csv, forecast_instance, output_csv=None, model='RandomForest', output_dir=OUTPUT_DIR):
1136
+ """
1137
+ Read individual data from CSV, make predictions, and save results.
1138
+
1139
+ Parameters:
1140
+ -----------
1141
+ input_csv : str
1142
+ Path to input CSV file with feature columns for individuals
1143
+ CSV should have columns matching training features (or subset)
1144
+ forecast_instance : MenopauseForecast
1145
+ The trained forecaster instance
1146
+ output_csv : str
1147
+ Path to output CSV file (default: input_csv with '_predictions' appended)
1148
+ model : str
1149
+ Which model to use ('RandomForest' or 'LogisticRegression')
1150
+ output_dir : str
1151
+ Directory to save results (for metadata)
1152
+
1153
+ Returns:
1154
+ --------
1155
+ pd.DataFrame : Results with predictions and confidence scores
1156
+
1157
+ Example:
1158
+ --------
1159
+ forecast = load_forecast_model('swan_ml_output')
1160
+ results = predict_from_csv('individuals.csv', forecast)
1161
+ # Results saved to 'individuals_predictions.csv'
1162
+ """
1163
+ import os
1164
+
1165
+ # Read input CSV
1166
+ print(f"Reading input data from: {input_csv}")
1167
+ try:
1168
+ data = pd.read_csv(input_csv)
1169
+ except FileNotFoundError:
1170
+ print(f"ERROR: File not found: {input_csv}")
1171
+ return None
1172
+
1173
+ n_samples = len(data)
1174
+ print(f"Loaded {n_samples} individuals")
1175
+
1176
+ # Identify feature columns (exclude ID columns)
1177
+ id_cols = ['ID', 'id', 'SWANID', 'individual', 'Individual', 'subject', 'Subject']
1178
+ feature_cols = [c for c in data.columns if c not in id_cols]
1179
+
1180
+ # Separate ID columns from features
1181
+ id_data = data[[c for c in id_cols if c in data.columns]] if any(c in data.columns for c in id_cols) else None
1182
+
1183
+ # Make predictions
1184
+ print(f"Making predictions using {model}...")
1185
+ predictions = forecast_instance.predict_batch(
1186
+ data[feature_cols],
1187
+ model=model,
1188
+ return_proba=True
1189
+ )
1190
+
1191
+ # Combine with original data
1192
+ if id_data is not None:
1193
+ results = pd.concat([id_data.reset_index(drop=True), predictions.reset_index(drop=True)], axis=1)
1194
+ else:
1195
+ results = predictions.reset_index(drop=True)
1196
+
1197
+ # Add individual index if no ID column
1198
+ if id_data is None:
1199
+ results.insert(0, 'individual', range(1, n_samples + 1))
1200
+
1201
+ # Set output file path
1202
+ if output_csv is None:
1203
+ base, ext = os.path.splitext(input_csv)
1204
+ output_csv = f"{base}_predictions{ext}"
1205
+
1206
+ # Save results
1207
+ print(f"Saving predictions to: {output_csv}")
1208
+ results.to_csv(output_csv, index=False)
1209
+ return results
1210
+
1211
+
1212
def predict_dual_from_csv(stage_input_csv, stage_output_csv, symptoms_input_csv, symptoms_output_csv,
                          forecast_dir=OUTPUT_DIR, model='RandomForest', lmp_col='LMP',
                          date_col=None, cycle_length=28):
    """Run menopause stage prediction and symptom-cycle prediction using separate
    input and output files for each model.

    Parameters:
    -----------
    stage_input_csv : str
        CSV with feature columns for the menopause-stage model.
    stage_output_csv : str or None
        Where to write stage predictions (None -> '<stage_input>_stage_predictions<ext>').
    symptoms_input_csv : str
        CSV with an LMP column (and optionally a date column) for the symptom model.
    symptoms_output_csv : str or None
        Where to write symptom predictions (None -> '<symptoms_input>_symptom_predictions<ext>').
    forecast_dir : str
        Directory containing the trained stage-model artifacts.
    model : str
        Which stage model to use ('RandomForest' or 'LogisticRegression').
    lmp_col : str
        Name of the last-menstrual-period column in the symptom input.
    date_col : str or None
        Reference-date column; a column literally named 'date' is auto-detected when None.
    cycle_length : int
        Assumed menstrual cycle length in days for the symptom forecaster.

    Returns:
    --------
    dict : {'stage': stage_results_df, 'symptoms': symptom_results_df}
        None when an input file is missing or the stage model cannot be loaded.
    """
    import os  # local import for self-containment, matching predict_from_csv

    print(f"Reading stage input data from: {stage_input_csv}")
    try:
        stage_data = pd.read_csv(stage_input_csv)
    except FileNotFoundError:
        print(f"ERROR: File not found: {stage_input_csv}")
        return None

    # Load forecast model; abort early when artifacts are missing or incompatible.
    try:
        forecast = load_forecast_model(output_dir=forecast_dir)
    except Exception as e:
        print(f"ERROR: Could not load forecast model from '{forecast_dir}': {e}")
        return None

    # Identify id and feature columns
    id_cols = ['ID', 'id', 'SWANID', 'individual', 'Individual', 'subject', 'Subject']
    feature_cols = [c for c in stage_data.columns if c not in id_cols]

    # Make stage predictions
    print(f"Making menopause stage predictions using {model}...")
    stage_preds = forecast.predict_batch(stage_data[feature_cols], model=model, return_proba=True)

    # Carry through any recognized ID columns; otherwise add a 1-based index.
    id_data = stage_data[[c for c in id_cols if c in stage_data.columns]] if any(c in stage_data.columns for c in id_cols) else None
    if id_data is not None:
        stage_results = pd.concat([id_data.reset_index(drop=True), stage_preds.reset_index(drop=True)], axis=1)
    else:
        stage_results = stage_preds.reset_index(drop=True)
        stage_results.insert(0, 'individual', range(1, len(stage_results) + 1))

    # Default stage output path if not provided
    if stage_output_csv is None:
        base, ext = os.path.splitext(stage_input_csv)
        stage_output_csv = f"{base}_stage_predictions{ext}"

    print(f"Saving stage predictions to: {stage_output_csv}")
    stage_results.to_csv(stage_output_csv, index=False)

    # Symptom predictions (independent of the stage model)
    print(f"Reading symptom input data from: {symptoms_input_csv}")
    try:
        symptom_data = pd.read_csv(symptoms_input_csv)
    except FileNotFoundError:
        print(f"ERROR: File not found: {symptoms_input_csv}")
        return None

    # Auto-detect a 'date' column when the caller did not name one.
    if date_col is None and 'date' in symptom_data.columns:
        date_col = 'date'

    fore = SymptomCycleForecaster(cycle_length=cycle_length)
    symptom_results = fore.predict_df(symptom_data, lmp_col=lmp_col, date_col=date_col)

    # Default symptom output path if not provided
    if symptoms_output_csv is None:
        base, ext = os.path.splitext(symptoms_input_csv)
        symptoms_output_csv = f"{base}_symptom_predictions{ext}"

    print(f"Saving symptom predictions to: {symptoms_output_csv}")
    symptom_results.to_csv(symptoms_output_csv, index=False)

    return {'stage': stage_results, 'symptoms': symptom_results}
1282
+
1283
+
1284
def predict_combined_from_csv(*args, **kwargs):
    """Deprecated: combined predictions are removed in favor of separate input/output files."""
    # Accepts (and ignores) any legacy call signature, then fails loudly.
    message = (
        "Combined predictions are deprecated. Use predict_dual_from_csv() with separate stage and symptom input/output files."
    )
    raise ValueError(message)
1289
+
1290
+
1291
def create_demo_csv(forecast_instance, num_individuals=5, output_file='demo_individuals.csv', output_dir=OUTPUT_DIR):
    """
    Write a demo CSV of sample individuals for testing predictions.

    Feature values are drawn uniformly from 1-5 (a typical Likert range
    for SWAN questionnaire items) with a fixed seed, so repeated calls
    produce identical files.

    Parameters:
    -----------
    forecast_instance : MenopauseForecast
        The trained forecaster (supplies the feature names)
    num_individuals : int
        Number of demo individuals to generate
    output_file : str
        Name of the CSV file to create
    output_dir : str
        Directory in which the demo file is written

    Returns:
    --------
    str : Path to the created CSV file
    """

    feature_names = forecast_instance.feature_names

    # Fixed seed keeps the demo file reproducible across runs.
    np.random.seed(42)

    # First column identifies the row; remaining columns are random features.
    columns = {'individual': [f"Individual_{i+1}" for i in range(num_individuals)]}
    for name in feature_names:
        columns[name] = np.random.randint(1, 6, size=num_individuals)

    demo_df = pd.DataFrame(columns)

    full_path = os.path.join(output_dir, output_file)
    os.makedirs(output_dir, exist_ok=True)
    demo_df.to_csv(full_path, index=False)

    print(f"✅ Demo CSV created: {full_path}")
    print(f"   Individuals: {num_individuals}")
    print(f"   Features: {len(feature_names)}")
    print(f"   File shape: {demo_df.shape}")

    return full_path
1345
+
1346
+
1347
def add_performance_metrics_to_csv(results_df, y_test=None, model_name='RandomForest'):
    """
    Compute performance metrics for a predictions dataframe.

    When true labels are supplied, computes accuracy and weighted
    precision/recall/F1 for the 'predicted_stage' column and formats them
    as '#'-prefixed comment lines suitable for appending to a predictions
    CSV. The dataframe itself is returned unchanged.

    Parameters:
    -----------
    results_df : pd.DataFrame
        Results dataframe containing a 'predicted_stage' column
    y_test : array-like, optional
        True labels; when None, no metrics are computed
    model_name : str
        Name of the model used (for labeling the metrics block)

    Returns:
    --------
    tuple : (results_df, metrics_text)
        metrics_text is a string of comment lines, or None when y_test
        was not provided. (Docstring fix: the previous version claimed a
        bare DataFrame was returned, but both paths return a 2-tuple.)
    """

    if y_test is not None:
        # Local import keeps sklearn optional when metrics are not requested.
        from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

        acc = accuracy_score(y_test, results_df['predicted_stage'])
        prec = precision_score(y_test, results_df['predicted_stage'], average='weighted', zero_division=0)
        recall = recall_score(y_test, results_df['predicted_stage'], average='weighted', zero_division=0)
        f1 = f1_score(y_test, results_df['predicted_stage'], average='weighted', zero_division=0)

        # Formatted as CSV comment lines so callers can append them after the data rows.
        metrics_text = f"\n# Performance Metrics ({model_name})\n"
        metrics_text += f"# Accuracy: {acc:.3f}\n"
        metrics_text += f"# Precision (weighted): {prec:.3f}\n"
        metrics_text += f"# Recall (weighted): {recall:.3f}\n"
        metrics_text += f"# F1-Score (weighted): {f1:.3f}\n"

        return results_df, metrics_text

    return results_df, None
predict_csv.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ """
3
+ CSV Prediction Script for SWAN Menopause Stage Forecasting
4
+
5
+ This script demonstrates how to use the trained forecasting module to make predictions
6
+ on a batch of individuals from a CSV file and save results with confidence scores
7
+ and performance metrics.
8
+
9
+ Usage:
10
+ python predict_csv.py --input demo_individuals.csv --model RandomForest
11
+ python predict_csv.py --input individuals.csv --output results.csv --model LogisticRegression
12
+
13
+ The script will:
14
+ 1. Read input CSV with individual feature values
15
+ 2. Make predictions using trained model
16
+ 3. Save results with predicted stage, confidence, and probabilities
17
+ 4. Display summary statistics
18
+ """
19
+
20
+ import os
21
+ import sys
22
+ import argparse
23
+ import pandas as pd
24
+ import numpy as np
25
+ from pathlib import Path
26
+
27
+
28
def main():
    """Parse CLI arguments, run batch predictions from a CSV, and print a summary.

    Exits with status 1 when the input file, the trained models, or the
    menopause module are missing, or when prediction fails.
    """

    parser = argparse.ArgumentParser(
        description='Make menopause stage predictions from CSV file'
    )
    parser.add_argument(
        '--input', '-i',
        required=True,
        help='Path to input CSV file with individual feature values'
    )
    parser.add_argument(
        '--output', '-o',
        default=None,
        help='Path to output CSV file (default: input_predictions.csv)'
    )
    parser.add_argument(
        '--model', '-m',
        choices=['RandomForest', 'LogisticRegression'],
        default='RandomForest',
        help='Which model to use for predictions'
    )
    parser.add_argument(
        '--forecast-dir',
        default='swan_ml_output',
        help='Directory containing trained forecast models'
    )

    args = parser.parse_args()

    # Import after parsing args so `--help` still works without the module.
    try:
        from menopause import load_forecast_model, predict_from_csv
    except ImportError:
        print("ERROR: Could not import menopause module.")
        print("Make sure you're in the correct directory and menopause.py is available.")
        sys.exit(1)

    # Check if input file exists
    if not os.path.exists(args.input):
        print(f"ERROR: Input file not found: {args.input}")
        sys.exit(1)

    # Check if forecast models exist
    forecast_dir = args.forecast_dir
    if not os.path.exists(os.path.join(forecast_dir, 'rf_pipeline.pkl')):
        print(f"ERROR: Forecast models not found in {forecast_dir}")
        print("Please run 'python menopause.py' first to train models.")
        sys.exit(1)

    print("="*80)
    print("MENOPAUSE STAGE PREDICTION FROM CSV")
    print("="*80)

    # Load forecaster
    print(f"\nLoading forecaster from {forecast_dir}...")
    forecast = load_forecast_model(forecast_dir)

    # Make predictions
    print(f"\nUsing model: {args.model}")
    results = predict_from_csv(
        args.input,
        forecast,
        output_csv=args.output,
        model=args.model,
        output_dir='.'
    )

    if results is not None:
        print("\n" + "="*80)
        print("PREDICTION RESULTS")
        print("="*80)

        # Display results table
        print("\nDetailed Results:")
        print(results.to_string(index=False))

        # Display performance metrics
        print("\n" + "="*80)
        print("PERFORMANCE SUMMARY")
        print("="*80)

        print(f"\nTotal Individuals: {len(results)}")
        print(f"\nStage Distribution:")
        for stage, count in results['predicted_stage'].value_counts().items():
            pct = count / len(results) * 100
            print(f"  {stage}: {count} ({pct:.1f}%)")

        print(f"\nConfidence Scores:")
        print(f"  Mean: {results['confidence'].mean():.3f}")
        print(f"  Min: {results['confidence'].min():.3f}")
        print(f"  Max: {results['confidence'].max():.3f}")
        print(f"  Std Dev: {results['confidence'].std():.3f}")

        # Confidence distribution
        high_conf = (results['confidence'] > 0.8).sum()
        med_conf = ((results['confidence'] > 0.6) & (results['confidence'] <= 0.8)).sum()
        low_conf = (results['confidence'] <= 0.6).sum()

        print(f"\nConfidence Distribution:")
        print(f"  High (>0.80): {high_conf}/{len(results)} ({high_conf/len(results)*100:.1f}%)")
        print(f"  Medium (0.60-0.80): {med_conf}/{len(results)} ({med_conf/len(results)*100:.1f}%)")
        print(f"  Low (≤0.60): {low_conf}/{len(results)} ({low_conf/len(results)*100:.1f}%)")

        # Output file confirmation.
        # BUGFIX: mirror predict_from_csv's default naming via os.path.splitext.
        # The previous Path(args.input).stem dropped the input's directory, so the
        # printed path disagreed with the actual save location for nested inputs.
        base, ext = os.path.splitext(args.input)
        output_path = args.output if args.output else f"{base}_predictions{ext}"
        print(f"\n✅ Results saved to: {output_path}")
    else:
        print("ERROR: Prediction failed.")
        sys.exit(1)

    print("\n" + "="*80)
140
+
141
+
142
+ if __name__ == '__main__':
143
+ main()
requirements.txt ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SWAN Menopause Prediction - Gradio App
2
+ # Python 3.10+ recommended
3
+
4
+ # ── UI Framework ─────────────────────────────────────────────────────────────
5
+ gradio>=4.0.0
6
+
7
+ # ── Data & ML ────────────────────────────────────────────────────────────────
8
+ pandas>=1.3.0
9
+ numpy>=1.20.0
10
+ scikit-learn==1.7.2 # Must match version used to train saved .pkl artifacts
11
+ joblib>=1.0.0
12
+ python-dateutil>=2.8.0
13
+
14
+ # ── Visualization ────────────────────────────────────────────────────────────
15
+ matplotlib>=3.3.0
16
+ seaborn>=0.11.0
17
+
18
+ # ── Notes ─────────────────────────────────────────────────────────────────────
19
+ # scikit-learn version is pinned because the .pkl pipelines (rf_pipeline.pkl,
20
+ # lr_pipeline.pkl) were serialized with scikit-learn 1.7.2. Using a different
21
+ # version may cause pickle incompatibility errors.
22
+ #
23
+ # To install locally:
24
+ # python -m venv .venv
25
+ # source .venv/bin/activate # or .venv\Scripts\activate on Windows
26
+ # pip install -r requirements.txt
swan_ml_output/forecast_metadata.json ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "feature_names": [
3
+ "PAIN17",
4
+ "PAINTW17",
5
+ "PAIN27",
6
+ "PAINTW27",
7
+ "SLEEP17",
8
+ "SLEEP27",
9
+ "BCOHOTH7",
10
+ "EXERCIS7",
11
+ "EXERHAR7",
12
+ "EXEROST7",
13
+ "EXERMEN7",
14
+ "EXERLOO7",
15
+ "EXERMEM7",
16
+ "EXERPER7",
17
+ "EXERGEN7",
18
+ "EXERWGH7",
19
+ "EXERADV7",
20
+ "EXEROTH7",
21
+ "EXERSPE7",
22
+ "ABBLEED7",
23
+ "BLEEDNG7",
24
+ "LMPDAY7",
25
+ "DEPRESS7",
26
+ "SEX17",
27
+ "SEX27",
28
+ "SEX37",
29
+ "SEX47",
30
+ "SEX57",
31
+ "SEX67",
32
+ "SEX77",
33
+ "SEX87",
34
+ "SEX97",
35
+ "SEX107",
36
+ "SEX117",
37
+ "SEX127",
38
+ "SMOKERE7",
39
+ "HOTFLAS7",
40
+ "NUMHOTF7",
41
+ "BOTHOTF7",
42
+ "IRRITAB7",
43
+ "VAGINDR7",
44
+ "MOODCHG7",
45
+ "SLEEPQL7",
46
+ "PHYSILL7",
47
+ "HOTHEAD7",
48
+ "EXER12H7",
49
+ "ALCO24H7",
50
+ "AGE7",
51
+ "RACE",
52
+ "LANGINT7"
53
+ ],
54
+ "stage_classes": [
55
+ "peri",
56
+ "post",
57
+ "pre"
58
+ ],
59
+ "num_features": 50
60
+ }
swan_ml_output/lr_pipeline.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2d9a1f99d0fc278ba57c7d21f0de5a0d4f2d88e7a79e4647c5d6f9b0cb925f9e
3
+ size 61178
swan_ml_output/rf_pipeline.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5e8a2e356ca8e17972073da38902d5cefe824693bba6b3206316956dafbd64a7
3
+ size 4274787