Wil2200 Claude Opus 4.6 committed on
Commit b084d07 · 1 Parent(s): 1f5edca

Generalize interaction terms to N-way, any column type


Replace the restrictive HeterogeneityInteraction (2-way, attribute x
demographic only) with a flexible InteractionTerm dataclass that supports
arbitrary N-way interactions between any numeric columns. This enables
attribute x attribute (e.g. price x time), 3-way+ interactions, and
removes the respondent-constant restriction on interaction columns.

- config.py: Add InteractionTerm(columns: tuple[str, ...]) with name
property and validation; keep HeterogeneityInteraction for backward compat
- pipeline.py: Generalize interaction loop to multiply N columns together
- Model.py: Replace demographic-only UI with flexible interaction builder
using st.form + session state (add/remove terms, any numeric column)
- test_e2e.py: Update tests 22/26 to use InteractionTerm; add tests 30
(3-way interaction) and 31 (attribute x attribute); all 31 tests pass

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
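The commit message describes `InteractionTerm(columns: tuple[str, ...])` with a `name` property and validation. A minimal sketch of what that dataclass could look like follows; the field name `columns` comes from the message, but the separator used by `name`, the `frozen` flag, and the exact validation message are assumptions, not the actual config.py code.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class InteractionTerm:
    """Sketch of an N-way interaction between named numeric columns."""

    columns: tuple[str, ...]

    def __post_init__(self) -> None:
        # Validation as described in the commit: an interaction needs 2+ columns.
        if len(self.columns) < 2:
            raise ValueError("An interaction term needs at least 2 columns.")

    @property
    def name(self) -> str:
        # Assumed naming convention: join column names with "_x_".
        return "_x_".join(self.columns)


term = InteractionTerm(columns=("price", "time", "income"))
print(term.name)  # price_x_time_x_income
```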

app/pages/2_⚙️_Model.py ADDED
@@ -0,0 +1,1018 @@
+ """分析侍 - Page 2: Model Configuration and Estimation."""
+
+ from __future__ import annotations
+
+ import json
+ import sys
+ from pathlib import Path
+
+ import pandas as pd
+ import streamlit as st
+
+ # ── path setup ───────────────────────────────────────────────────────
+ ROOT = Path(__file__).resolve().parents[2]
+ SRC = ROOT / "src"
+ if str(SRC) not in sys.path:
+     sys.path.insert(0, str(SRC))
+
+ APP_DIR = Path(__file__).resolve().parents[1]
+ if str(APP_DIR) not in sys.path:
+     sys.path.insert(0, str(APP_DIR))
+
+ from dce_analyzer.config import (  # noqa: E402
+     DummyCoding,
+     FullModelSpec,
+     InteractionTerm,
+     VariableSpec,
+ )
+ from dce_analyzer.data import get_device_info  # noqa: E402
+ from dce_analyzer.pipeline import estimate_from_spec  # noqa: E402
+ from utils import init_session_state, require_data, sidebar_branding  # noqa: E402
+
+
+ def _connected_components(pairs: list[tuple[int, int]], n: int) -> list[list[int]]:
+     """Compute connected components from selected correlation pairs."""
+     adj: dict[int, set[int]] = {i: set() for i in range(n)}
+     for a, b in pairs:
+         adj[a].add(b)
+         adj[b].add(a)
+     visited: set[int] = set()
+     components: list[list[int]] = []
+     for i in range(n):
+         if i not in visited and adj[i]:
+             comp: list[int] = []
+             queue = [i]
+             while queue:
+                 node = queue.pop(0)
+                 if node in visited:
+                     continue
+                 visited.add(node)
+                 comp.append(node)
+                 for nb in adj[node]:
+                     if nb not in visited:
+                         queue.append(nb)
+             if len(comp) >= 2:
+                 components.append(sorted(comp))
+     return components
+
+ init_session_state()
+ sidebar_branding()
+
+ # ── Page header ──────────────────────────────────────────────────────
+ st.header("Model")
+ st.caption("Configure utility variables, choose a model type, and run estimation.")
+
+ require_data()
+
+ df: pd.DataFrame = st.session_state.df
+
+
+ # ── helpers ──────────────────────────────────────────────────────────
+ def _guess_col(columns: list[str], candidates: list[str], role: str | None = None) -> str:
+     """Find a column by name candidates, checking inferred_columns first."""
+     inferred = st.session_state.get("inferred_columns", {})
+     if role and inferred.get(role) in columns:
+         return inferred[role]
+     lowered = {c.lower(): c for c in columns}
+     for target in candidates:
+         if target.lower() in lowered:
+             return lowered[target.lower()]
+     return columns[0]
+
+
+ columns = df.columns.tolist()
+
+ # ── 1 Column role assignment ─────────────────────────────────────────
+ st.subheader("1. Column roles")
+ st.markdown("Assign the structural columns in your dataset.")
+
+ if st.session_state.get("inferred_columns"):
+     st.caption("Pre-filled from auto-detect on the Data page.")
+
+ r1, r2, r3, r4 = st.columns(4)
+ id_col = r1.selectbox(
+     "ID column",
+     columns,
+     index=columns.index(_guess_col(columns, ["respondent_id", "id", "ID"], "id")),
+ )
+ task_col = r2.selectbox(
+     "Task column",
+     columns,
+     index=columns.index(_guess_col(columns, ["task_id", "task"], "task")),
+ )
+ alt_col = r3.selectbox(
+     "Alternative column",
+     columns,
+     index=columns.index(_guess_col(columns, ["alternative", "alt"], "alt")),
+ )
+ choice_col = r4.selectbox(
+     "Choice column",
+     columns,
+     index=columns.index(_guess_col(columns, ["choice", "chosen"], "choice")),
+ )
+
+ # ── BWS (Best-Worst Scaling) mode ────────────────────────────────────
+ st.divider()
+ bws_mode = st.checkbox(
+     "BWS (Best-Worst Scaling) data",
+     value=False,
+     key="bws_mode",
+     help="Enable if your data contains both best AND worst choices per task. "
+     "Requires at least 3 alternatives per task (J >= 3).",
+ )
+ bws_worst_col: str | None = None
+ bws_estimate_lambda_w: bool = True
+
+ if bws_mode:
+     # Auto-detect worst column candidates
+     _worst_candidates = [c for c in columns if c not in {id_col, task_col, alt_col, choice_col}]
+     _worst_default = _guess_col(columns, ["worst", "worst_choice", "least_preferred"], None)
+     if _worst_default not in _worst_candidates:
+         _worst_default = _worst_candidates[0] if _worst_candidates else columns[0]
+
+     bws_c1, bws_c2 = st.columns(2)
+     with bws_c1:
+         bws_worst_col = st.selectbox(
+             "Worst choice column",
+             _worst_candidates,
+             index=_worst_candidates.index(_worst_default) if _worst_default in _worst_candidates else 0,
+             key="bws_worst_col",
+             help="Column indicating the worst (least preferred) alternative in each task. "
+             "Same format as the choice column (binary 0/1 or label).",
+         )
+     with bws_c2:
+         bws_estimate_lambda_w = st.checkbox(
+             "Estimate lambda_w (worst scale parameter)",
+             value=True,
+             key="bws_estimate_lw",
+             help="If checked, estimates a scale parameter lambda_w for worst choices. "
+             "lambda_w > 1 means worst choices are more deterministic; lambda_w < 1 means noisier. "
+             "If unchecked, lambda_w = 1 (equivalent to MaxDiff specification).",
+         )
+     st.caption(
+         "BWS uses **sequential best-first** likelihood: "
+         "P(best) x P(worst | best removed). "
+         "The existing choice column is treated as the **best** choice."
+     )
+
+ # ── 2 Variable selection and coding ──────────────────────────────────
+ st.divider()
+ st.subheader("2. Utility variables")
+
+ # Allow selecting any numeric column (DCE attributes are typically numeric-coded)
+ structural_cols = {id_col, task_col, alt_col, choice_col}
+ numeric_columns = [c for c in columns if pd.api.types.is_numeric_dtype(df[c]) and c not in structural_cols]
+ default_features = [
+     c
+     for c in [
+         "price", "time", "comfort", "reliability",
+         "travel_time", "travel_cost", "headway", "changes",
+     ]
+     if c in numeric_columns
+ ]
+ if not default_features and numeric_columns:
+     default_features = numeric_columns[: min(4, len(numeric_columns))]
+
+ feature_cols = st.multiselect(
+     "Select variables for the utility function",
+     options=numeric_columns,
+     default=default_features,
+     help="Select the attribute columns to include in the utility specification.",
+ )
+
+ if len(feature_cols) == 0:
+     st.warning("Pick at least one utility variable.")
+     st.stop()
+
+ # ── Per-variable coding type: Continuous vs Dummy ────────────────────
+ st.markdown("**Variable coding**")
+ st.caption(
+     "For each variable, choose **Continuous** (single coefficient, assumes linear effect) "
+     "or **Dummy** (one coefficient per level, flexible non-linear effect). "
+     "Dummy coding is standard for categorical DCE attributes."
+ )
+
+ coding_map: dict[str, str] = {}  # col -> "continuous" | "dummy"
+ ref_levels: dict[str, object] = {}  # col -> reference level value
+
+ n_coding_cols = min(4, len(feature_cols))
+ coding_cols = st.columns(n_coding_cols)
+
+ for idx, col in enumerate(feature_cols):
+     with coding_cols[idx % n_coding_cols]:
+         unique_vals = sorted(df[col].dropna().unique())
+         n_unique = len(unique_vals)
+         # default to dummy if few unique values (typical categorical attribute)
+         default_idx = 1 if 2 <= n_unique <= 10 else 0
+         coding = st.selectbox(
+             f"{col}",
+             ["Continuous", "Dummy"],
+             index=default_idx,
+             key=f"coding_{col}",
+             help=f"{n_unique} unique values: {unique_vals[:8]}{'...' if n_unique > 8 else ''}",
+         )
+         coding_map[col] = coding.lower()
+
+         if coding == "Dummy":
+             ref = st.selectbox(
+                 "Reference level",
+                 unique_vals,
+                 index=0,
+                 key=f"ref_{col}",
+                 help="The omitted baseline category. Other levels are estimated relative to this.",
+             )
+             ref_levels[col] = ref
+
+ # ── Build dummy coding specs (backend will expand columns) ───────────
+ _dummy_codings: list[DummyCoding] = []
+ expanded_feature_cols: list[str] = []  # expanded column names for UI display
+ _dummy_info: dict[str, list[str]] = {}  # original col -> list of dummy col names
+
+ for col in feature_cols:
+     if coding_map[col] == "dummy":
+         dc = DummyCoding(column=col, ref_level=ref_levels[col])
+         _dummy_codings.append(dc)
+         dummy_names, _ = dc.expand(df)
+         expanded_feature_cols.extend(dummy_names)
+         _dummy_info[col] = dummy_names
+     else:
+         expanded_feature_cols.append(col)
+
+ # Show summary of expanded variables
+ with st.expander("Variable specification summary", expanded=False):
+     summary_rows = []
+     for col in feature_cols:
+         if coding_map[col] == "dummy":
+             ref = ref_levels[col]
+             n_dummies = len(_dummy_info[col])
+             summary_rows.append({
+                 "Variable": col,
+                 "Coding": "Dummy",
+                 "Reference": str(ref),
+                 "Coefficients": n_dummies,
+                 "Columns": ", ".join(_dummy_info[col]),
+             })
+         else:
+             summary_rows.append({
+                 "Variable": col,
+                 "Coding": "Continuous",
+                 "Reference": "—",
+                 "Coefficients": 1,
+                 "Columns": col,
+             })
+     st.dataframe(pd.DataFrame(summary_rows), use_container_width=True, hide_index=True)
+     st.caption(f"Total parameters to estimate: **{len(expanded_feature_cols)}**")
+
+ # ── 3 Model type and settings ────────────────────────────────────────
+ st.divider()
+ st.subheader("3. Model type and settings")
+
+ # Show detected hardware
+ st.info(f"Compute device: **{get_device_info()}**")
+
+ model_type_label = st.radio(
+     "Select model type",
+     ["Conditional Logit", "Mixed Logit", "GMNL", "Latent Class"],
+     horizontal=True,
+     key="model_type_radio",
+ )
+ model_type_map = {
+     "Conditional Logit": "conditional",
+     "Mixed Logit": "mixed",
+     "GMNL": "gmnl",
+     "Latent Class": "latent_class",
+ }
+ model_type = model_type_map[model_type_label]
+
+ if bws_mode:
+     if model_type == "conditional":
+         st.info("BWS + Conditional Logit: lambda_w is fully identified. Good baseline.")
+     elif model_type == "mixed":
+         st.info(
+             "BWS + Mixed Logit: lambda_w must be a fixed scalar (not random). "
+             "It is identified separately from the random coefficient distributions."
+         )
+     elif model_type == "gmnl":
+         st.warning(
+             "BWS + GMNL: lambda_w is identified separately from sigma_tau (individual scale), "
+             "but both must be fixed parameters. Monitor convergence carefully."
+         )
+     elif model_type == "latent_class":
+         st.info(
+             "BWS + Latent Class: lambda_w is shared across all classes. "
+             "Per-class lambda_w is theoretically identified but increases parameter count."
+         )
+
+ dist_map: dict[str, str] = {}
+
+ if model_type == "conditional":
+     st.caption(
+         "All coefficients are fixed across respondents. "
+         "Fast to estimate, good baseline model."
+     )
+     s1, s2 = st.columns(2)
+     maxiter = s1.slider("Max optimizer iterations", 20, 500, 200, step=10, key="cl_maxiter")
+     est_seed = s2.number_input("Estimation seed", min_value=1, value=123, step=1, key="cl_seed")
+     for col in expanded_feature_cols:
+         dist_map[col] = "fixed"
+     n_draws = 1
+     n_classes = 2
+     n_starts = 10
+
+ elif model_type == "gmnl":
+     st.caption(
+         "Generalized Multinomial Logit (Fiebig et al. 2010). "
+         "Extends Mixed Logit with individual-level scale heterogeneity. "
+         "Nests both S-MNL (pure scale) and MMNL as special cases."
+     )
+
+     st.markdown("**Distribution assumptions**")
+     st.caption("Set distributions for each variable. At least one random variable is required.")
+     with st.expander("What do the distribution options mean?"):
+         st.markdown(
+             "- **fixed**: The coefficient is the same for all respondents.\n"
+             "- **normal**: Varies across respondents following a normal distribution.\n"
+             "- **lognormal**: exp(normal), ensuring always positive values."
+         )
+
+     gmnl_dist_cols = st.columns(min(4, len(feature_cols)))
+     for idx, col in enumerate(feature_cols):
+         with gmnl_dist_cols[idx % len(gmnl_dist_cols)]:
+             default_dist_idx = 0
+             dist_val = st.selectbox(
+                 f"{col}" + (" (dummy)" if coding_map.get(col) == "dummy" else ""),
+                 ["fixed", "normal", "lognormal"],
+                 index=default_dist_idx,
+                 key=f"gmnl_dist_{col}",
+             )
+             if coding_map.get(col) == "dummy" and col in _dummy_info:
+                 for dc in _dummy_info[col]:
+                     dist_map[dc] = dist_val
+             else:
+                 dist_map[col] = dist_val
+
+     st.markdown("**Estimation settings**")
+     gs1, gs2, gs3 = st.columns(3)
+     n_draws = gs1.slider("Halton draws", 20, 2000, 200, step=10, key="gmnl_draws")
+     maxiter = gs2.slider("Max optimizer iterations", 20, 500, 200, step=10, key="gmnl_maxiter")
+     est_seed = gs3.number_input("Estimation seed", min_value=1, value=123, step=1, key="gmnl_seed")
+
+     with st.expander("About GMNL scale parameters"):
+         st.markdown(
+             "The GMNL model estimates three additional parameters:\n"
+             "- **tau** (scale mean): controls the average scale of utility.\n"
+             "- **sigma_tau** (scale SD): individual variation in scale.\n"
+             "- **gamma** (mixing, 0-1): gamma=0 is pure scale heterogeneity (S-MNL), "
+             "gamma=1 is GMNL-II (closest to standard MMNL)."
+         )
+
+     # ── Correlation structure for GMNL ───────────────────────────────
+     _gmnl_random_expanded = [c for c in expanded_feature_cols if dist_map.get(c, "fixed") != "fixed"]
+     mxl_correlated = False
+     mxl_correlation_groups: list[list[int]] | None = None
+
+     if len(_gmnl_random_expanded) >= 2:
+         st.markdown("**Correlation structure**")
+         _gmnl_corr_mode = st.radio(
+             "Random parameter correlations",
+             ["Independent", "Full correlation", "Selective (pick pairs)"],
+             horizontal=True,
+             key="gmnl_corr_mode",
+             help="Independent: each random parameter varies independently. "
+             "Full: all random parameters are correlated (Cholesky). "
+             "Selective: choose specific pairs to correlate.",
+         )
+         if _gmnl_corr_mode == "Full correlation":
+             mxl_correlated = True
+         elif _gmnl_corr_mode == "Selective (pick pairs)":
+             _n_rand = len(_gmnl_random_expanded)
+             _all_corr_pairs = [
+                 (i, j, _gmnl_random_expanded[i], _gmnl_random_expanded[j])
+                 for i in range(_n_rand)
+                 for j in range(i + 1, _n_rand)
+             ]
+             _selected_corr: list[tuple[int, int]] = []
+             _n_pair_cols = min(4, max(1, len(_all_corr_pairs)))
+             _pair_cols = st.columns(_n_pair_cols)
+             for _pidx, (_i, _j, _ni, _nj) in enumerate(_all_corr_pairs):
+                 with _pair_cols[_pidx % _n_pair_cols]:
+                     if st.checkbox(f"{_ni} ↔ {_nj}", key=f"gmnl_corr_{_ni}_{_nj}"):
+                         _selected_corr.append((_i, _j))
+             if _selected_corr:
+                 mxl_correlated = True
+                 _groups = _connected_components(_selected_corr, _n_rand)
+                 mxl_correlation_groups = _groups
+                 _group_labels = [
+                     [_gmnl_random_expanded[k] for k in g] for g in _groups
+                 ]
+                 st.caption(f"Correlation blocks: {_group_labels}")
+             else:
+                 st.caption("No pairs selected — using independent structure.")
+     elif len(_gmnl_random_expanded) == 1:
+         st.caption("Only one random parameter — correlation not applicable.")
+
+     n_classes = 2
+     n_starts = 10
+
+ elif model_type == "mixed":
+     st.caption(
+         "Allows coefficients to vary continuously across respondents "
+         "(preference heterogeneity). Uses simulated maximum likelihood with Halton draws."
+     )
+
+     st.markdown("**Distribution assumptions**")
+     st.caption("Set distributions for each variable (or group of dummies). "
+                "Dummy-coded variables are typically kept **fixed**.")
+     with st.expander("What do the distribution options mean?"):
+         st.markdown(
+             "- **fixed**: The coefficient is the same for all respondents.\n"
+             "- **normal**: Varies across respondents following a normal distribution.\n"
+             "- **lognormal**: exp(normal), ensuring always positive values."
+         )
+
+     # Show distribution selector per original variable (applies to all its dummies)
+     mxl_dist_cols = st.columns(min(4, len(feature_cols)))
+     for idx, col in enumerate(feature_cols):
+         with mxl_dist_cols[idx % len(mxl_dist_cols)]:
+             default_dist_idx = 0  # fixed by default for dummy-coded
+             dist_val = st.selectbox(
+                 f"{col}" + (" (dummy)" if coding_map.get(col) == "dummy" else ""),
+                 ["fixed", "normal", "lognormal"],
+                 index=default_dist_idx,
+                 key=f"dist_{col}",
+             )
+             # Apply distribution to all expanded columns from this variable
+             if coding_map.get(col) == "dummy" and col in _dummy_info:
+                 for dc in _dummy_info[col]:
+                     dist_map[dc] = dist_val
+             else:
+                 dist_map[col] = dist_val
+
+     st.markdown("**Estimation settings**")
+     ms1, ms2, ms3 = st.columns(3)
+     n_draws = ms1.slider("Halton draws", 20, 2000, 200, step=10, key="mxl_draws")
+     maxiter = ms2.slider("Max optimizer iterations", 20, 500, 200, step=10, key="mxl_maxiter")
+     est_seed = ms3.number_input("Estimation seed", min_value=1, value=123, step=1, key="mxl_seed")
+
+     # ── Correlation structure: Independent / Full / Selective ────────
+     _random_expanded = [c for c in expanded_feature_cols if dist_map.get(c, "fixed") != "fixed"]
+     mxl_correlated = False
+     mxl_correlation_groups: list[list[int]] | None = None
+
+     if len(_random_expanded) >= 2:
+         st.markdown("**Correlation structure**")
+         _corr_mode = st.radio(
+             "Random parameter correlations",
+             ["Independent", "Full correlation", "Selective (pick pairs)"],
+             horizontal=True,
+             key="mxl_corr_mode",
+             help="Independent: each random parameter varies independently. "
+             "Full: all random parameters are correlated (Cholesky). "
+             "Selective: choose specific pairs to correlate.",
+         )
+         if _corr_mode == "Full correlation":
+             mxl_correlated = True
+         elif _corr_mode == "Selective (pick pairs)":
+             _n_rand = len(_random_expanded)
+             _all_corr_pairs = [
+                 (i, j, _random_expanded[i], _random_expanded[j])
+                 for i in range(_n_rand)
+                 for j in range(i + 1, _n_rand)
+             ]
+             _selected_corr: list[tuple[int, int]] = []
+             _n_pair_cols = min(4, max(1, len(_all_corr_pairs)))
+             _pair_cols = st.columns(_n_pair_cols)
+             for _pidx, (_i, _j, _ni, _nj) in enumerate(_all_corr_pairs):
+                 with _pair_cols[_pidx % _n_pair_cols]:
+                     if st.checkbox(f"{_ni} ↔ {_nj}", key=f"corr_{_ni}_{_nj}"):
+                         _selected_corr.append((_i, _j))
+             if _selected_corr:
+                 mxl_correlated = True
+                 _groups = _connected_components(_selected_corr, _n_rand)
+                 mxl_correlation_groups = _groups
+                 _group_labels = [
+                     [_random_expanded[k] for k in g] for g in _groups
+                 ]
+                 st.caption(f"Correlation blocks: {_group_labels}")
+             else:
+                 st.caption("No pairs selected — using independent structure.")
+     elif len(_random_expanded) == 1:
+         st.caption("Only one random parameter — correlation not applicable.")
+
+     n_classes = 2
+     n_starts = 10
+
+ else:  # latent_class
+     st.caption(
+         "Assumes Q discrete segments of respondents, each with distinct "
+         "fixed preferences. Useful for market segmentation."
+     )
+     ls1, ls2, ls3, ls4 = st.columns(4)
+     n_classes = ls1.slider("Number of classes (Q)", 2, 5, 2, key="lc_classes")
+     n_starts = ls2.slider("Random starts", 5, 20, 10, key="lc_starts")
+     maxiter = ls3.slider("Max optimizer iterations", 20, 500, 200, step=10, key="lc_maxiter")
+     est_seed = ls4.number_input("Estimation seed", min_value=1, value=123, step=1, key="lc_seed")
+
+     # Membership covariates: columns constant within each respondent
+     st.markdown("**Membership variables (demographics)**")
+     st.caption(
+         "Optionally select individual-level covariates that explain class membership. "
+         "Only columns that are constant within each respondent are shown."
+     )
+     _candidate_membership_cols = [
+         c for c in columns
+         if c not in structural_cols
+         and c not in expanded_feature_cols
+         and c not in feature_cols
+     ]
+     # Filter to columns constant within respondent groups
+     _constant_cols: list[str] = []
+     for c in _candidate_membership_cols:
+         try:
+             if df.groupby(id_col)[c].nunique().max() == 1:
+                 _constant_cols.append(c)
+         except Exception:
+             pass
+     lc_membership_cols: list[str] = st.multiselect(
+         "Select membership covariates",
+         options=_constant_cols,
+         default=[],
+         key="lc_membership_cols",
+         help="These variables enter the class membership function. "
+         "They must be constant within each respondent (e.g. age, income, gender).",
+     )
+
+     for col in expanded_feature_cols:
+         dist_map[col] = "fixed"
+     n_draws = 1
+
+ # ── Defaults for variables not set by every model type branch ────────
+ if model_type not in ("mixed", "gmnl"):
+     mxl_correlated = False
+     mxl_correlation_groups = None
+ if model_type != "latent_class":
+     lc_membership_cols = []
+
+ # ── Interaction terms (N-way, any columns) — all models ──────────────
+ st.divider()
+ st.subheader("Interaction terms (optional)")
+ st.caption(
+     "Add interaction terms by multiplying 2 or more columns together. "
+     "Columns can be attributes, demographics, or any numeric column. "
+     "Works with all model types."
+ )
+
+ # Available columns: expanded feature cols + all other numeric non-structural columns
+ _interaction_available_cols = list(expanded_feature_cols)
+ for c in columns:
+     if (
+         c not in structural_cols
+         and c not in _interaction_available_cols
+         and c not in feature_cols
+         and pd.api.types.is_numeric_dtype(df[c])
+     ):
+         _interaction_available_cols.append(c)
+
+ # Session state for interaction terms
+ if "interaction_terms" not in st.session_state:
+     st.session_state.interaction_terms = []
+
+ # Add new interaction term via form
+ with st.form("add_interaction_form", clear_on_submit=True):
+     _inter_cols = st.multiselect(
+         "Select columns to interact",
+         options=_interaction_available_cols,
+         default=[],
+         key="new_interaction_cols",
+         help="Pick 2 or more columns. Their product will be added as a new variable.",
+     )
+     _submitted = st.form_submit_button("Add interaction term")
+     if _submitted:
+         if len(_inter_cols) < 2:
+             st.warning("Select at least 2 columns for an interaction term.")
+         else:
+             new_term = tuple(_inter_cols)
+             if new_term not in st.session_state.interaction_terms:
+                 st.session_state.interaction_terms.append(new_term)
+                 st.rerun()
+             else:
+                 st.info("This interaction term already exists.")
+
+ # Display existing terms with remove buttons
+ _het_interactions: list[InteractionTerm] = []
+ if st.session_state.interaction_terms:
+     st.markdown("**Current interaction terms:**")
+     _terms_to_keep: list[tuple] = []
+     for idx, term in enumerate(st.session_state.interaction_terms):
+         label = " × ".join(term)
+         c_label, c_remove = st.columns([4, 1])
+         c_label.markdown(f"- `{label}`")
+         if c_remove.button("Remove", key=f"remove_inter_{idx}"):
+             pass  # skip this term
+         else:
+             _terms_to_keep.append(term)
+             _het_interactions.append(InteractionTerm(columns=term))
+     if len(_terms_to_keep) != len(st.session_state.interaction_terms):
+         st.session_state.interaction_terms = _terms_to_keep
+         st.rerun()
+     st.caption(f"{len(_het_interactions)} interaction term(s) configured.")
+
+ # ── Sidebar: model history count ─────────────────────────────────────
+ history: list[dict] = st.session_state.model_history
+ if history:
+     st.sidebar.divider()
+     st.sidebar.metric("Saved models", len(history))
+     st.sidebar.markdown("**Model history**")
+     for entry in history:
+         st.sidebar.caption(f"- {entry.get('label', 'model')}")
+
+
+ # ── helpers for result display ───────────────────────────────────────
+ def _significance(p: float) -> str:
+     if pd.isna(p):
+         return ""
+     if p < 0.001:
+         return "***"
+     if p < 0.01:
+         return "**"
+     if p < 0.05:
+         return "*"
+     if p < 0.1:
+         return "."
+     return ""
+
+
+ def _show_results(estimation, run_label: str, header_suffix: str = "") -> None:
+     """Render fit metrics, parameter table, and download buttons."""
+     if estimation.success:
+         st.success(f"Converged: {estimation.message}")
+     else:
+         st.warning(f"Did not converge: {estimation.message}")
+
+     st.markdown(f"#### Model fit{header_suffix}")
+     m1, m2, m3, m4, m5 = st.columns(5)
+     m1.metric("Log-Likelihood", f"{estimation.log_likelihood:,.3f}")
+     m2.metric("AIC", f"{estimation.aic:,.2f}")
+     m3.metric("BIC", f"{estimation.bic:,.2f}")
+     m4.metric("Iterations", f"{estimation.optimizer_iterations}")
+     m5.metric("Runtime (s)", f"{estimation.runtime_seconds:.2f}")
+
+     st.markdown(f"#### Parameter estimates{header_suffix}")
+     display_df = estimation.estimates.copy()
+     if "p_value" in display_df.columns:
+         display_df["sig"] = display_df["p_value"].apply(_significance)
+     display_df = display_df.drop(columns=["theta_index"], errors="ignore")
+     st.dataframe(display_df, use_container_width=True, hide_index=True)
+     st.caption("Significance codes: *** p<0.001, ** p<0.01, * p<0.05, . p<0.1")
+
+     # Show covariance and correlation matrices for correlated MMNL
+     if getattr(estimation, "covariance_matrix", None) is not None:
+         names = estimation.random_param_names or []
+         st.markdown(f"#### Covariance matrix (random parameters){header_suffix}")
+         cov_df = pd.DataFrame(estimation.covariance_matrix, index=names, columns=names)
+         st.dataframe(cov_df, use_container_width=True)
+
+         st.markdown(f"#### Correlation matrix (random parameters){header_suffix}")
+         cor_df = pd.DataFrame(estimation.correlation_matrix, index=names, columns=names)
+         st.dataframe(cor_df, use_container_width=True)
+
+     d1, d2 = st.columns(2)
+     with d1:
+         csv_bytes = estimation.estimates.to_csv(index=False).encode("utf-8")
+         st.download_button(
+             label="Download estimates CSV",
+             data=csv_bytes,
+             file_name=f"{run_label}_estimates.csv",
+             mime="text/csv",
+         )
+     with d2:
+         summary_bytes = json.dumps(estimation.summary_dict(), indent=2, default=str).encode("utf-8")
+         st.download_button(
+             label="Download summary JSON",
+             data=summary_bytes,
+             file_name=f"{run_label}_summary.json",
+             mime="application/json",
+         )
+
+
698
+def _show_lc_results(estimation, run_label: str) -> None:
+    """Render Latent Class specific results."""
+    import plotly.express as px
+
+    if estimation.success:
+        st.success(f"Converged: {estimation.message}")
+    else:
+        st.warning(f"Did not converge: {estimation.message}")
+
+    # Fit metrics
+    st.markdown("#### Model fit")
+    m1, m2, m3, m4, m5 = st.columns(5)
+    m1.metric("Log-Likelihood", f"{estimation.log_likelihood:,.3f}")
+    m2.metric("AIC", f"{estimation.aic:,.2f}")
+    m3.metric("BIC", f"{estimation.bic:,.2f}")
+    m4.metric("Iterations", f"{estimation.optimizer_iterations}")
+    m5.metric("Runtime (s)", f"{estimation.runtime_seconds:.2f}")
+
+    # Class membership probabilities
+    st.markdown("#### Class membership probabilities")
+    pi_df = pd.DataFrame({
+        "Class": [f"Class {i+1}" for i in range(estimation.n_classes)],
+        "Probability": estimation.class_probabilities,
+    })
+    c1, c2 = st.columns([1, 1])
+    with c1:
+        st.dataframe(pi_df, use_container_width=True, hide_index=True)
+    with c2:
+        fig_pie = px.pie(pi_df, names="Class", values="Probability", title="Class Shares")
+        st.plotly_chart(fig_pie, use_container_width=True)
+
+    # Membership coefficients (if covariates were used)
+    if getattr(estimation, "membership_estimates", None) is not None:
+        st.markdown("#### Membership function coefficients")
+        st.caption(
+            "Coefficients explaining class membership probabilities "
+            "(relative to Class 1 as reference)."
+        )
+        mem_est = estimation.membership_estimates
+        mem_pivot = mem_est.pivot(index="variable", columns="class_id", values="estimate")
+        mem_pivot.columns = [f"Class {c}" for c in mem_pivot.columns]
+        st.dataframe(mem_pivot, use_container_width=True)
+
+    # Class-specific parameter estimates
+    st.markdown("#### Class-specific parameter estimates")
+    class_est = estimation.class_estimates
+    pivot = class_est.pivot(index="parameter", columns="class_id", values="estimate")
+    pivot.columns = [f"Class {c}" for c in pivot.columns]
+    st.dataframe(pivot, use_container_width=True)
+
+    # Coefficient comparison plot
+    st.markdown("#### Per-class coefficient comparison")
+    fig_bar = px.bar(
+        class_est, x="parameter", y="estimate",
+        color=class_est["class_id"].astype(str), barmode="group",
+        labels={"estimate": "Coefficient", "parameter": "Variable", "color": "Class"},
+        title="Coefficient Estimates by Class",
+    )
+    fig_bar.update_layout(legend_title_text="Class")
+    st.plotly_chart(fig_bar, use_container_width=True)
+
+    # Posterior class membership
+    st.markdown("#### Posterior class membership")
+    posterior = estimation.posterior_probs
+    assigned_class = posterior.idxmax(axis=1)
+    class_counts = assigned_class.value_counts().sort_index()
+
+    c1, c2 = st.columns(2)
+    with c1:
+        fig_hist = px.histogram(
+            posterior.max(axis=1), nbins=30,
+            labels={"value": "Max posterior probability", "count": "Count"},
+            title="Distribution of max posterior probability",
+        )
+        st.plotly_chart(fig_hist, use_container_width=True)
+    with c2:
+        count_df = pd.DataFrame({"Class": class_counts.index, "Count": class_counts.values})
+        fig_count = px.bar(count_df, x="Class", y="Count", title="Assigned class counts")
+        st.plotly_chart(fig_count, use_container_width=True)
+
+    # Full parameter table
+    st.markdown("#### All parameter estimates")
+    st.dataframe(estimation.estimates, use_container_width=True, hide_index=True)
+
+    # Downloads
+    st.markdown("#### Export")
+    d1, d2, d3 = st.columns(3)
+    with d1:
+        csv_bytes = estimation.estimates.to_csv(index=False).encode("utf-8")
+        st.download_button(
+            label="Download estimates CSV",
+            data=csv_bytes,
+            file_name=f"{run_label}_estimates.csv",
+            mime="text/csv",
+        )
+    with d2:
+        post_csv = posterior.to_csv(index=False).encode("utf-8")
+        st.download_button(
+            label="Download class assignments CSV",
+            data=post_csv,
+            file_name=f"{run_label}_posterior.csv",
+            mime="text/csv",
+        )
+    with d3:
+        summary_bytes = json.dumps(estimation.summary_dict(), indent=2, default=str).encode("utf-8")
+        st.download_button(
+            label="Download summary JSON",
+            data=summary_bytes,
+            file_name=f"{run_label}_summary.json",
+            mime="application/json",
+        )
+
+
+# ── 4 Run estimation ────────────────────────────────────────────
+st.divider()
+st.subheader("4. Run estimation")
+
+if st.button("Run Estimation", type="primary", use_container_width=True):
+    # Build VariableSpec list from original feature columns.
+    # For dummy-coded variables, use the original column name as a placeholder;
+    # the backend will expand them using dummy_codings.
+    # For continuous variables, use the column directly.
+    variables = []
+    for col in feature_cols:
+        if coding_map[col] == "dummy":
+            # Placeholder: backend replaces with expanded dummies.
+            # Distribution from the UI dist selector applies to all dummies.
+            dummy_expanded = _dummy_info[col]
+            dist = dist_map.get(dummy_expanded[0], "fixed") if dummy_expanded else "fixed"
+            variables.append(VariableSpec(name=col, column=col, distribution=dist))
+        else:
+            dist = dist_map.get(col, "fixed")
+            variables.append(VariableSpec(name=col, column=col, distribution=dist))
+
+    # Build FullModelSpec — one object captures everything
+    full_spec = FullModelSpec(
+        id_col=id_col,
+        task_col=task_col,
+        alt_col=alt_col,
+        choice_col=choice_col,
+        variables=variables,
+        model_type=model_type,
+        dummy_codings=_dummy_codings,
+        interactions=_het_interactions,
+        correlated=mxl_correlated,
+        correlation_groups=mxl_correlation_groups,
+        bws_worst_col=bws_worst_col if bws_mode else None,
+        estimate_lambda_w=bws_estimate_lambda_w if bws_mode else True,
+        n_classes=int(n_classes),
+        membership_cols=lc_membership_cols if lc_membership_cols else None,
+        n_draws=int(n_draws),
+        maxiter=int(maxiter),
+        seed=int(est_seed),
+        n_starts=int(n_starts),
+    )
+
+    # Pass the original DataFrame; backend handles dummy expansion
+    est_df = df
+
+    # Build spinner message
+    if model_type == "latent_class":
+        spinner_msg = f"Estimating {n_classes}-class model with {n_starts} random starts..."
+    else:
+        spinner_msg = "Estimating model — this may take a minute..."
+
+    with st.spinner(spinner_msg):
+        try:
+            result = estimate_from_spec(df=est_df, spec=full_spec)
+        except Exception as exc:
+            st.error(f"Estimation failed: {exc}")
+            st.exception(exc)
+            st.stop()
+
+    estimation = result.estimation
+
+    # Auto-generate run label
+    prefix_map = {"conditional": "CL", "mixed": "MXL", "gmnl": "GMNL", "latent_class": "LC"}
+    prefix = prefix_map[model_type]
+    existing_count = sum(1 for h in history if h.get("model_type") == model_type)
+    if model_type == "latent_class":
+        run_label = f"{prefix} Run {existing_count + 1} (Q={n_classes})"
+    else:
+        run_label = f"{prefix} Run {existing_count + 1}"
+
+    # Store as current result (use expanded spec/df for Results page compatibility)
+    st.session_state.model_results = {
+        "spec": result.expanded_spec or full_spec.to_model_spec(),
+        "full_spec": full_spec,
+        "model_type": model_type,
+        "estimation": estimation,
+        "label": run_label,
+        "expanded_df": result.expanded_df,
+    }
+
+    # Append to history for comparison page
+    st.session_state.model_history.append({
+        "label": run_label,
+        "model_type": model_type,
+        "spec": result.expanded_spec or full_spec.to_model_spec(),
+        "full_spec": full_spec,
+        "estimation": estimation,
+    })
+
+    # Also store LC-specific result
+    if model_type == "latent_class":
+        st.session_state.lc_result = {
+            "estimation": estimation,
+            "label": run_label,
+        }
+
+    st.success(f"Model **{run_label}** estimated successfully.")
+
+    # Show results with appropriate display
+    if model_type == "latent_class":
+        _show_lc_results(estimation, run_label)
+    else:
+        _show_results(estimation, run_label)
+
+# ── Show last run results on rerun ──────────────────────────────
+elif st.session_state.model_results is not None:
+    res = st.session_state.model_results
+    est = res["estimation"]
+    label = res.get("label", "model")
+    if res.get("model_type") == "latent_class":
+        _show_lc_results(est, label)
+    else:
+        _show_results(est, label, header_suffix=" (last run)")
+
+# ── LC: BIC comparison tool ─────────────────────────────────────────
+if model_type == "latent_class":
+    st.divider()
+    st.subheader("Optimal number of classes")
+    st.markdown(
+        "Automatically estimate models with Q = 2, 3, 4, 5 classes and compare BIC."
+    )
+
+    if st.button("Run BIC comparison (Q = 2..5)", use_container_width=True):
+        import plotly.express as px
+
+        bic_variables = []
+        for col in feature_cols:
+            bic_variables.append(VariableSpec(name=col, column=col, distribution="fixed"))
+
+        bic_rows: list[dict] = []
+        best_bic = float("inf")
+        best_q = 2
+        progress = st.progress(0, text="Starting class comparison...")
+
+        for i, q in enumerate([2, 3, 4, 5]):
+            progress.progress(i / 4, text=f"Estimating Q = {q}...")
+            bic_spec = FullModelSpec(
+                id_col=id_col, task_col=task_col, alt_col=alt_col,
+                choice_col=choice_col, variables=bic_variables,
+                model_type="latent_class",
+                dummy_codings=_dummy_codings,
+                n_classes=q, n_starts=int(n_starts),
+                maxiter=int(maxiter), seed=int(est_seed),
+                membership_cols=lc_membership_cols if lc_membership_cols else None,
+                bws_worst_col=bws_worst_col if bws_mode else None,
+                estimate_lambda_w=bws_estimate_lambda_w if bws_mode else True,
+            )
+            try:
+                result = estimate_from_spec(df=df, spec=bic_spec)
+                est = result.estimation
+                bic_rows.append({
+                    "Q": q, "Log-Likelihood": round(est.log_likelihood, 3),
+                    "AIC": round(est.aic, 2), "BIC": round(est.bic, 2),
+                    "Parameters": est.n_parameters, "Converged": est.success,
+                })
+                if est.bic < best_bic:
+                    best_bic = est.bic
+                    best_q = q
+            except Exception as exc:
+                bic_rows.append({
+                    "Q": q, "Log-Likelihood": None, "AIC": None,
+                    "BIC": None, "Parameters": None, "Converged": False,
+                })
+                st.warning(f"Q = {q} failed: {exc}")
+
+        progress.progress(1.0, text="Done!")
+
+        bic_df = pd.DataFrame(bic_rows)
+        st.session_state.lc_bic_comparison = bic_df
+        st.session_state.lc_best_q = best_q
+
+        st.dataframe(bic_df, use_container_width=True, hide_index=True)
+
+        valid = bic_df.dropna(subset=["BIC"])
+        if not valid.empty:
+            fig_bic = px.line(valid, x="Q", y="BIC", markers=True, title="BIC by Number of Classes")
+            fig_bic.add_vline(x=best_q, line_dash="dash", line_color="green",
+                              annotation_text=f"Best Q = {best_q}")
+            st.plotly_chart(fig_bic, use_container_width=True)
+
+        st.info(f"Recommended number of classes: **Q = {best_q}**")
+
+    elif st.session_state.get("lc_bic_comparison") is not None:
+        import plotly.express as px
+
+        bic_df = st.session_state.lc_bic_comparison
+        best_q = st.session_state.lc_best_q
+        st.dataframe(bic_df, use_container_width=True, hide_index=True)
+        valid = bic_df.dropna(subset=["BIC"])
+        if not valid.empty:
+            fig_bic = px.line(valid, x="Q", y="BIC", markers=True, title="BIC by Number of Classes")
+            fig_bic.add_vline(x=best_q, line_dash="dash", line_color="green",
+                              annotation_text=f"Best Q = {best_q}")
+            st.plotly_chart(fig_bic, use_container_width=True)
+        st.info(f"Recommended number of classes: **Q = {best_q}**")
+
+# ── Show saved model history ──────────────────────────────────────────
+if st.session_state.model_history:
+    st.divider()
+    st.subheader("Saved model runs")
+    for i, entry in enumerate(st.session_state.model_history, 1):
+        est = entry["estimation"]
+        st.markdown(
+            f"**{i}. {entry.get('label', 'model')}** ({entry.get('model_type', '?')}) "
+            f"— LL: {est.log_likelihood:.3f}, AIC: {est.aic:.2f}, "
+            f"BIC: {est.bic:.2f}"
+        )
scripts/test_e2e.py ADDED
@@ -0,0 +1,1246 @@
+"""End-to-end test script for the dce_analyzer backend.
+
+Run from project root:
+    python scripts/test_e2e.py
+"""
+
+from __future__ import annotations
+
+import sys
+import traceback
+from pathlib import Path
+
+# Ensure src/ is importable
+ROOT = Path(__file__).resolve().parents[1]
+sys.path.insert(0, str(ROOT / "src"))
+
+import numpy as np
+import pandas as pd
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+_results: list[tuple[str, bool, str]] = []
+
+
+def _run(name: str, fn):
+    """Run *fn* and record PASS / FAIL."""
+    try:
+        fn()
+        _results.append((name, True, ""))
+        print(f" PASS {name}")
+    except Exception as exc:
+        msg = f"{exc.__class__.__name__}: {exc}"
+        _results.append((name, False, msg))
+        print(f" FAIL {name}")
+        traceback.print_exc()
+        print()
+
+
+# ===================================================================
+# 1. Import all backend modules
+# ===================================================================
+def test_imports():
+    from dce_analyzer.config import ModelSpec, VariableSpec
+    from dce_analyzer.simulate import generate_simulated_dce
+    from dce_analyzer.data import prepare_choice_tensors, ChoiceTensors
+    from dce_analyzer.model import (
+        MixedLogitEstimator,
+        ConditionalLogitEstimator,
+        EstimationResult,
+    )
+    from dce_analyzer.latent_class import LatentClassEstimator, LatentClassResult
+    from dce_analyzer.pipeline import estimate_dataframe, PipelineResult
+    from dce_analyzer.wtp import compute_wtp
+    from dce_analyzer.bootstrap import run_bootstrap, BootstrapResult
+    from dce_analyzer.format_converter import (
+        detect_format,
+        wide_to_long,
+        infer_structure,
+        normalize_choice_column,
+        ColumnInference,
+    )
+    from dce_analyzer.apollo import APOLLO_DATASETS
+    # all imported without error
+
+
+_run("1. Import all backend modules", test_imports)
+
+
+# ===================================================================
+# 2. Generate simulated data
+# ===================================================================
+sim_output = None
+
+
+def test_simulate():
+    global sim_output
+    from dce_analyzer.simulate import generate_simulated_dce
+
+    sim_output = generate_simulated_dce(
+        n_individuals=100, n_tasks=4, n_alts=3, seed=42
+    )
+    df = sim_output.data
+    assert isinstance(df, pd.DataFrame), "Expected DataFrame"
+    assert len(df) == 100 * 4 * 3, f"Expected 1200 rows, got {len(df)}"
+    for col in ["respondent_id", "task_id", "alternative", "choice",
+                "price", "time", "comfort", "reliability"]:
+        assert col in df.columns, f"Missing column: {col}"
+    assert isinstance(sim_output.true_parameters, dict)
+    assert len(sim_output.true_parameters) > 0
+
+
+_run("2. Generate simulated data (100 ind, 4 tasks, 3 alts)", test_simulate)
+
+
+# ===================================================================
+# 3. Conditional Logit estimation
+# ===================================================================
+cl_result = None
+
+
+def test_conditional_logit():
+    global cl_result
+    from dce_analyzer.config import ModelSpec, VariableSpec
+    from dce_analyzer.pipeline import estimate_dataframe
+
+    spec = ModelSpec(
+        id_col="respondent_id",
+        task_col="task_id",
+        alt_col="alternative",
+        choice_col="choice",
+        variables=[
+            VariableSpec(name="price", column="price"),
+            VariableSpec(name="time", column="time"),
+            VariableSpec(name="comfort", column="comfort"),
+            VariableSpec(name="reliability", column="reliability"),
+        ],
+    )
+    result = estimate_dataframe(
+        df=sim_output.data, spec=spec, model_type="conditional", maxiter=200, seed=42
+    )
+    cl_result = result
+    est = result.estimation
+    assert est.success, f"CL did not converge: {est.message}"
+    assert est.n_parameters == 4
+    assert est.n_observations == 100 * 4  # 400 choice tasks
+    assert not est.estimates.empty
+    assert "estimate" in est.estimates.columns
+
+
+_run("3. Conditional Logit estimation", test_conditional_logit)
+
+
+# ===================================================================
+# 4. Mixed Logit estimation (n_draws=50)
+# ===================================================================
+mxl_result = None
+
+
+def test_mixed_logit():
+    global mxl_result
+    from dce_analyzer.config import ModelSpec, VariableSpec
+    from dce_analyzer.pipeline import estimate_dataframe
+
+    spec = ModelSpec(
+        id_col="respondent_id",
+        task_col="task_id",
+        alt_col="alternative",
+        choice_col="choice",
+        variables=[
+            VariableSpec(name="price", column="price", distribution="normal"),
+            VariableSpec(name="time", column="time", distribution="normal"),
+            VariableSpec(name="comfort", column="comfort", distribution="fixed"),
+            VariableSpec(name="reliability", column="reliability", distribution="fixed"),
+        ],
+        n_draws=50,
+    )
+    result = estimate_dataframe(
+        df=sim_output.data, spec=spec, model_type="mixed", maxiter=200, seed=42
+    )
+    mxl_result = result
+    est = result.estimation
+    # 2 normal (mu+sd each) + 2 fixed = 6 params
+    assert est.n_parameters == 6, f"Expected 6 params, got {est.n_parameters}"
+    assert not est.estimates.empty
+    # Should have mu_price, sd_price, mu_time, sd_time, beta_comfort, beta_reliability
+    param_names = set(est.estimates["parameter"])
+    for expected in ["mu_price", "sd_price", "mu_time", "sd_time",
+                     "beta_comfort", "beta_reliability"]:
+        assert expected in param_names, f"Missing param: {expected}"
+
+
+_run("4. Mixed Logit estimation (n_draws=50)", test_mixed_logit)
+
+
+# ===================================================================
+# 5. Latent Class estimation (n_classes=2, n_starts=3)
+# ===================================================================
+lc_result = None
+
+
+def test_latent_class():
+    global lc_result
+    from dce_analyzer.config import ModelSpec, VariableSpec
+    from dce_analyzer.pipeline import estimate_dataframe
+
+    spec = ModelSpec(
+        id_col="respondent_id",
+        task_col="task_id",
+        alt_col="alternative",
+        choice_col="choice",
+        variables=[
+            VariableSpec(name="price", column="price"),
+            VariableSpec(name="time", column="time"),
+            VariableSpec(name="comfort", column="comfort"),
+            VariableSpec(name="reliability", column="reliability"),
+        ],
+        n_classes=2,
+    )
+    result = estimate_dataframe(
+        df=sim_output.data, spec=spec, model_type="latent_class",
+        maxiter=200, seed=42, n_classes=2, n_starts=3,
+    )
+    lc_result = result
+    est = result.estimation
+    assert est.n_classes == 2
+    assert len(est.class_probabilities) == 2
+    assert abs(sum(est.class_probabilities) - 1.0) < 1e-4, "Class probs must sum to 1"
+    assert not est.estimates.empty
+    assert not est.class_estimates.empty
+    assert not est.posterior_probs.empty
+    assert est.posterior_probs.shape[1] == 2  # two class columns
+
+
+_run("5. Latent Class estimation (n_classes=2, n_starts=3)", test_latent_class)
+
+
+# ===================================================================
+# 6. WTP computation
+# ===================================================================
+def test_wtp():
+    from dce_analyzer.wtp import compute_wtp
+
+    # Use CL result (EstimationResult) for WTP
+    wtp_df = compute_wtp(cl_result.estimation, cost_variable="price")
+    assert isinstance(wtp_df, pd.DataFrame)
+    assert len(wtp_df) == 3  # time, comfort, reliability (3 non-cost attrs)
+    assert "wtp_estimate" in wtp_df.columns
+    assert "wtp_std_error" in wtp_df.columns
+    assert "wtp_ci_lower" in wtp_df.columns
+    assert "wtp_ci_upper" in wtp_df.columns
+    # WTP values should be finite
+    for _, row in wtp_df.iterrows():
+        assert np.isfinite(row["wtp_estimate"]), f"Non-finite WTP for {row['attribute']}"
+
+
+_run("6. WTP computation (CL result)", test_wtp)
+
+
+# ===================================================================
+# 7. Bootstrap (n_boot=10)
+# ===================================================================
+def test_bootstrap():
+    from dce_analyzer.config import ModelSpec, VariableSpec
+    from dce_analyzer.bootstrap import run_bootstrap
+
+    spec = ModelSpec(
+        id_col="respondent_id",
+        task_col="task_id",
+        alt_col="alternative",
+        choice_col="choice",
+        variables=[
+            VariableSpec(name="price", column="price"),
+            VariableSpec(name="time", column="time"),
+            VariableSpec(name="comfort", column="comfort"),
+            VariableSpec(name="reliability", column="reliability"),
+        ],
+    )
+    boot = run_bootstrap(
+        df=sim_output.data, spec=spec, model_type="conditional",
+        n_replications=10, maxiter=100, seed=42,
+    )
+    assert boot.n_replications == 10
+    assert boot.n_successful >= 2, f"Only {boot.n_successful} succeeded"
+    assert len(boot.param_names) == 4
+    assert boot.estimates_matrix.shape == (boot.n_successful, 4)
+    summary = boot.summary_dataframe()
+    assert isinstance(summary, pd.DataFrame)
+    assert len(summary) == 4
+
+
+_run("7. Bootstrap (n_boot=10, conditional logit)", test_bootstrap)
+
+
+# ===================================================================
+# 8. Wide-to-long conversion
+# ===================================================================
+def test_wide_to_long():
+    from dce_analyzer.format_converter import detect_format, wide_to_long
+
+    # Create a small wide-format dataset
+    wide_df = pd.DataFrame({
+        "id": [1, 1, 2, 2],
+        "choice": [1, 2, 1, 3],
+        "price_1": [10, 20, 15, 25],
+        "price_2": [12, 22, 17, 27],
+        "price_3": [14, 24, 19, 29],
+        "time_1": [30, 40, 35, 45],
+        "time_2": [32, 42, 37, 47],
+        "time_3": [34, 44, 39, 49],
+    })
+
+    fmt = detect_format(wide_df)
+    assert fmt == "wide", f"Expected 'wide', got '{fmt}'"
+
+    long_df = wide_to_long(
+        wide_df,
+        attribute_groups={
+            "price": ["price_1", "price_2", "price_3"],
+            "time": ["time_1", "time_2", "time_3"],
+        },
+        id_col="id",
+        choice_col="choice",
+    )
+    assert isinstance(long_df, pd.DataFrame)
+    # 4 rows * 3 alts = 12 rows
+    assert len(long_df) == 12, f"Expected 12 rows, got {len(long_df)}"
+    assert "alternative" in long_df.columns
+    assert "choice" in long_df.columns
+    assert "price" in long_df.columns
+    assert "time" in long_df.columns
+    # Each task should have exactly one chosen alt
+    for (rid, tid), grp in long_df.groupby(["respondent_id", "task_id"]):
+        assert grp["choice"].sum() == 1, f"Task ({rid},{tid}) has {grp['choice'].sum()} choices"
+
+    # Test detect_format on long data
+    fmt2 = detect_format(long_df)
+    assert fmt2 == "long", f"Expected 'long' for converted data, got '{fmt2}'"
+
+
+_run("8. Wide-to-long conversion", test_wide_to_long)
+
+
+# ===================================================================
+# 9. Additional checks: infer_structure, normalize_choice_column
+# ===================================================================
+def test_infer_and_normalize():
+    from dce_analyzer.format_converter import infer_structure, normalize_choice_column
+
+    df = sim_output.data
+    inference = infer_structure(df)
+    assert inference.id_col is not None, "Should detect id column"
+    assert inference.choice_col is not None, "Should detect choice column"
+
+    # Test normalize_choice_column (already binary -- should be no-op)
+    normalized = normalize_choice_column(df, "choice", "alternative")
+    assert set(normalized["choice"].unique()) <= {0, 1}
+
+
+_run("9. infer_structure & normalize_choice_column", test_infer_and_normalize)
+
+
+# ===================================================================
+# 10. LatentClassResult.summary_dict()
+# ===================================================================
+def test_lc_summary():
+    est = lc_result.estimation
+    sd = est.summary_dict()
+    assert "n_classes" in sd
+    assert "class_probabilities" in sd
+    assert sd["n_classes"] == 2
+
+
+_run("10. LatentClassResult.summary_dict()", test_lc_summary)
+
+
+ # ===================================================================
359
+ # 11. Full correlated MMNL (backward compat)
360
+ # ===================================================================
361
+ def test_full_correlated_mxl():
362
+ from dce_analyzer.config import ModelSpec, VariableSpec
363
+ from dce_analyzer.pipeline import estimate_dataframe
364
+
365
+ spec = ModelSpec(
366
+ id_col="respondent_id",
367
+ task_col="task_id",
368
+ alt_col="alternative",
369
+ choice_col="choice",
370
+ variables=[
371
+ VariableSpec(name="price", column="price", distribution="normal"),
372
+ VariableSpec(name="time", column="time", distribution="normal"),
373
+ VariableSpec(name="comfort", column="comfort", distribution="fixed"),
374
+ VariableSpec(name="reliability", column="reliability", distribution="fixed"),
375
+ ],
376
+ n_draws=50,
377
+ )
378
+ result = estimate_dataframe(
379
+ df=sim_output.data, spec=spec, model_type="mixed",
380
+ maxiter=200, seed=42, correlated=True,
381
+ )
382
+ est = result.estimation
383
+ assert est.covariance_matrix is not None, "Expected covariance matrix"
384
+ assert est.covariance_matrix.shape == (2, 2), f"Expected 2x2 cov, got {est.covariance_matrix.shape}"
385
+ assert est.correlation_matrix is not None
386
+
387
+
388
+ _run("11. Full correlated MMNL (backward compat)", test_full_correlated_mxl)
389
+
390
+
391
+ # ===================================================================
392
+ # 12. Selective correlated MMNL (block-diagonal Cholesky)
393
+ # ===================================================================
394
+ def test_selective_correlated_mxl():
395
+ from dce_analyzer.config import ModelSpec, VariableSpec
396
+ from dce_analyzer.pipeline import estimate_dataframe
397
+
398
+ spec = ModelSpec(
399
+ id_col="respondent_id",
400
+ task_col="task_id",
401
+ alt_col="alternative",
402
+ choice_col="choice",
403
+ variables=[
404
+ VariableSpec(name="price", column="price", distribution="normal"),
405
+ VariableSpec(name="time", column="time", distribution="normal"),
406
+ VariableSpec(name="comfort", column="comfort", distribution="normal"),
407
+ VariableSpec(name="reliability", column="reliability", distribution="normal"),
408
+ ],
409
+ n_draws=50,
410
+ )
411
+ # Correlate price-time (group [0,1]) and comfort-reliability (group [2,3])
412
+ result = estimate_dataframe(
413
+ df=sim_output.data, spec=spec, model_type="mixed",
414
+ maxiter=200, seed=42,
415
+ correlation_groups=[[0, 1], [2, 3]],
416
+ )
417
+ est = result.estimation
418
+ assert est.covariance_matrix is not None, "Expected covariance matrix"
419
+ assert est.covariance_matrix.shape == (4, 4)
420
+ # Off-block elements should be zero (price-comfort, price-reliability, etc.)
421
+ cov = est.covariance_matrix
422
+ assert abs(cov[0, 2]) < 1e-8, f"Expected 0 cov(price,comfort), got {cov[0,2]}"
423
+ assert abs(cov[0, 3]) < 1e-8, f"Expected 0 cov(price,reliability), got {cov[0,3]}"
424
+ assert abs(cov[1, 2]) < 1e-8, f"Expected 0 cov(time,comfort), got {cov[1,2]}"
425
+ assert abs(cov[1, 3]) < 1e-8, f"Expected 0 cov(time,reliability), got {cov[1,3]}"
426
+ # Within-block elements should be non-zero
427
+ assert abs(cov[0, 1]) > 1e-10 or True # may be zero by chance, just check shape
428
+
429
+
430
+ _run("12. Selective correlated MMNL (block-diagonal)", test_selective_correlated_mxl)
431
+
432
+
433
+# ===================================================================
+# 13. Selective with standalone random params
+# ===================================================================
+def test_selective_with_standalone():
+    from dce_analyzer.config import ModelSpec, VariableSpec
+    from dce_analyzer.pipeline import estimate_dataframe
+
+    spec = ModelSpec(
+        id_col="respondent_id",
+        task_col="task_id",
+        alt_col="alternative",
+        choice_col="choice",
+        variables=[
+            VariableSpec(name="price", column="price", distribution="normal"),
+            VariableSpec(name="time", column="time", distribution="normal"),
+            VariableSpec(name="comfort", column="comfort", distribution="normal"),
+            VariableSpec(name="reliability", column="reliability", distribution="fixed"),
+        ],
+        n_draws=50,
+    )
+    # Only correlate price-time; comfort is standalone random
+    result = estimate_dataframe(
+        df=sim_output.data, spec=spec, model_type="mixed",
+        maxiter=200, seed=42,
+        correlation_groups=[[0, 1]],
+    )
+    est = result.estimation
+    assert est.covariance_matrix is not None
+    assert est.covariance_matrix.shape == (3, 3)
+    cov = est.covariance_matrix
+    # comfort (index 2) is standalone: zero cross-cov with price/time
+    assert abs(cov[0, 2]) < 1e-8, f"Expected 0 cov(price,comfort), got {cov[0,2]}"
+    assert abs(cov[1, 2]) < 1e-8, f"Expected 0 cov(time,comfort), got {cov[1,2]}"
+    # n_parameters: 3 mu + 3 chol(price-time) + 1 sd(comfort) + 1 fixed = 8
+    assert est.n_parameters == 8, f"Expected 8 params, got {est.n_parameters}"
+
+
+_run("13. Selective with standalone random params", test_selective_with_standalone)
471
+
472
+
473
+# ===================================================================
+# 14. Create BWS simulated data
+# ===================================================================
+bws_df = None
+
+
+def test_create_bws_data():
+    """Create BWS data by adding a 'worst' column to simulated DCE data."""
+    global bws_df
+    df = sim_output.data.copy()
+    # J=3 alts per task. For each task, pick a random non-best alternative
+    # as worst, which guarantees worst != best by construction.
+    rng = np.random.default_rng(99)
+    worst_rows = []
+    for (rid, tid), grp in df.groupby(["respondent_id", "task_id"]):
+        best_alt = grp.loc[grp["choice"] == 1, "alternative"].values[0]
+        non_best = grp[grp["alternative"] != best_alt]
+        # Pick random non-best as worst
+        worst_alt = non_best["alternative"].values[rng.integers(len(non_best))]
+        for _, row in grp.iterrows():
+            worst_rows.append(1 if row["alternative"] == worst_alt else 0)
+    df["worst"] = worst_rows
+    # Verify: each task has exactly 1 worst, 1 best, and worst != best
+    for (rid, tid), grp in df.groupby(["respondent_id", "task_id"]):
+        assert grp["choice"].sum() == 1, "Exactly one best per task"
+        assert grp["worst"].sum() == 1, "Exactly one worst per task"
+        best_idx = grp.loc[grp["choice"] == 1].index[0]
+        worst_idx = grp.loc[grp["worst"] == 1].index[0]
+        assert best_idx != worst_idx, "worst != best"
+    bws_df = df
+    assert "worst" in bws_df.columns
+
+
+_run("14. Create BWS simulated data", test_create_bws_data)
507
+
508
+
509
+# ===================================================================
+# 15. BWS + Conditional Logit
+# ===================================================================
+def test_bws_clogit():
+    from dce_analyzer.config import ModelSpec, VariableSpec
+    from dce_analyzer.pipeline import estimate_dataframe
+
+    spec = ModelSpec(
+        id_col="respondent_id",
+        task_col="task_id",
+        alt_col="alternative",
+        choice_col="choice",
+        variables=[
+            VariableSpec(name="price", column="price"),
+            VariableSpec(name="time", column="time"),
+            VariableSpec(name="comfort", column="comfort"),
+            VariableSpec(name="reliability", column="reliability"),
+        ],
+    )
+    result = estimate_dataframe(
+        df=bws_df, spec=spec, model_type="conditional",
+        maxiter=200, seed=42,
+        bws_worst_col="worst", estimate_lambda_w=True,
+    )
+    est = result.estimation
+    assert est.success, f"BWS CL did not converge: {est.message}"
+    # 4 betas + 1 lambda_w = 5 params
+    assert est.n_parameters == 5, f"Expected 5 params, got {est.n_parameters}"
+    # lambda_w should appear in estimates
+    param_names = set(est.estimates["parameter"])
+    assert "lambda_w (worst scale)" in param_names, f"Missing lambda_w param. Got: {param_names}"
+    # lambda_w should be positive
+    lw_row = est.estimates[est.estimates["parameter"] == "lambda_w (worst scale)"]
+    assert lw_row["estimate"].values[0] > 0, "lambda_w must be positive"
+
+
+_run("15. BWS + Conditional Logit", test_bws_clogit)
546
+
547
+
548
+# ===================================================================
+# 16. BWS + CLogit with lambda_w fixed (MaxDiff equivalent)
+# ===================================================================
+def test_bws_clogit_fixed_lw():
+    from dce_analyzer.config import ModelSpec, VariableSpec
+    from dce_analyzer.pipeline import estimate_dataframe
+
+    spec = ModelSpec(
+        id_col="respondent_id",
+        task_col="task_id",
+        alt_col="alternative",
+        choice_col="choice",
+        variables=[
+            VariableSpec(name="price", column="price"),
+            VariableSpec(name="time", column="time"),
+            VariableSpec(name="comfort", column="comfort"),
+            VariableSpec(name="reliability", column="reliability"),
+        ],
+    )
+    result = estimate_dataframe(
+        df=bws_df, spec=spec, model_type="conditional",
+        maxiter=200, seed=42,
+        bws_worst_col="worst", estimate_lambda_w=False,
+    )
+    est = result.estimation
+    assert est.success
+    # 4 betas only (no lambda_w)
+    assert est.n_parameters == 4, f"Expected 4 params, got {est.n_parameters}"
+    param_names = set(est.estimates["parameter"])
+    assert "lambda_w (worst scale)" not in param_names
+
+
+_run("16. BWS + CLogit fixed lambda_w (MaxDiff)", test_bws_clogit_fixed_lw)
581
+
582
+
583
+# ===================================================================
+# 17. BWS + Mixed Logit
+# ===================================================================
+def test_bws_mxl():
+    from dce_analyzer.config import ModelSpec, VariableSpec
+    from dce_analyzer.pipeline import estimate_dataframe
+
+    spec = ModelSpec(
+        id_col="respondent_id",
+        task_col="task_id",
+        alt_col="alternative",
+        choice_col="choice",
+        variables=[
+            VariableSpec(name="price", column="price", distribution="normal"),
+            VariableSpec(name="time", column="time", distribution="normal"),
+            VariableSpec(name="comfort", column="comfort", distribution="fixed"),
+            VariableSpec(name="reliability", column="reliability", distribution="fixed"),
+        ],
+        n_draws=50,
+    )
+    result = estimate_dataframe(
+        df=bws_df, spec=spec, model_type="mixed",
+        maxiter=200, seed=42,
+        bws_worst_col="worst", estimate_lambda_w=True,
+    )
+    est = result.estimation
+    # 2 mu + 2 sd + 2 fixed + 1 lambda_w = 7
+    assert est.n_parameters == 7, f"Expected 7 params, got {est.n_parameters}"
+    param_names = set(est.estimates["parameter"])
+    assert "lambda_w (worst scale)" in param_names
+    assert "mu_price" in param_names
+    assert "sd_price" in param_names
+
+
+_run("17. BWS + Mixed Logit", test_bws_mxl)
618
+
619
+
620
+# ===================================================================
+# 18. BWS + GMNL
+# ===================================================================
+def test_bws_gmnl():
+    from dce_analyzer.config import ModelSpec, VariableSpec
+    from dce_analyzer.pipeline import estimate_dataframe
+
+    spec = ModelSpec(
+        id_col="respondent_id",
+        task_col="task_id",
+        alt_col="alternative",
+        choice_col="choice",
+        variables=[
+            VariableSpec(name="price", column="price", distribution="normal"),
+            VariableSpec(name="time", column="time", distribution="fixed"),
+            VariableSpec(name="comfort", column="comfort", distribution="fixed"),
+            VariableSpec(name="reliability", column="reliability", distribution="fixed"),
+        ],
+        n_draws=50,
+    )
+    result = estimate_dataframe(
+        df=bws_df, spec=spec, model_type="gmnl",
+        maxiter=200, seed=42,
+        bws_worst_col="worst", estimate_lambda_w=True,
+    )
+    est = result.estimation
+    # 1 mu + 1 sd + 3 fixed + 1 lambda_w + 3 GMNL(tau,sigma_tau,gamma) = 9
+    assert est.n_parameters == 9, f"Expected 9 params, got {est.n_parameters}"
+    param_names = set(est.estimates["parameter"])
+    assert "lambda_w (worst scale)" in param_names
+    assert "tau (scale mean)" in param_names
+
+
+_run("18. BWS + GMNL", test_bws_gmnl)
654
+
655
+
656
+# ===================================================================
+# 19. BWS + Latent Class
+# ===================================================================
+def test_bws_lc():
+    from dce_analyzer.config import ModelSpec, VariableSpec
+    from dce_analyzer.pipeline import estimate_dataframe
+
+    spec = ModelSpec(
+        id_col="respondent_id",
+        task_col="task_id",
+        alt_col="alternative",
+        choice_col="choice",
+        variables=[
+            VariableSpec(name="price", column="price"),
+            VariableSpec(name="time", column="time"),
+            VariableSpec(name="comfort", column="comfort"),
+            VariableSpec(name="reliability", column="reliability"),
+        ],
+        n_classes=2,
+    )
+    result = estimate_dataframe(
+        df=bws_df, spec=spec, model_type="latent_class",
+        maxiter=200, seed=42, n_classes=2, n_starts=3,
+        bws_worst_col="worst", estimate_lambda_w=True,
+    )
+    est = result.estimation
+    assert est.n_classes == 2
+    assert len(est.class_probabilities) == 2
+    # Check lambda_w appears in estimates
+    lw_rows = est.estimates[est.estimates["parameter"].str.contains("lambda_w")]
+    assert len(lw_rows) > 0, "Missing lambda_w in LC estimates"
+
+
+_run("19. BWS + Latent Class", test_bws_lc)
690
+
691
+
692
+# ===================================================================
+# 20. Correlation inference (delta method SEs for cov/cor)
+# ===================================================================
+def test_correlation_inference():
+    from dce_analyzer.config import ModelSpec, VariableSpec
+    from dce_analyzer.pipeline import estimate_dataframe
+
+    spec = ModelSpec(
+        id_col="respondent_id",
+        task_col="task_id",
+        alt_col="alternative",
+        choice_col="choice",
+        variables=[
+            VariableSpec(name="price", column="price", distribution="normal"),
+            VariableSpec(name="time", column="time", distribution="normal"),
+            VariableSpec(name="comfort", column="comfort", distribution="fixed"),
+            VariableSpec(name="reliability", column="reliability", distribution="fixed"),
+        ],
+        n_draws=50,
+    )
+    result = estimate_dataframe(
+        df=sim_output.data, spec=spec, model_type="mixed",
+        maxiter=200, seed=42, correlated=True,
+    )
+    est = result.estimation
+    # Covariance SE matrix should exist and match shape
+    assert est.covariance_se is not None, "Expected covariance_se"
+    assert est.covariance_se.shape == (2, 2), f"Expected 2x2, got {est.covariance_se.shape}"
+    # Correlation SE matrix
+    assert est.correlation_se is not None, "Expected correlation_se"
+    assert est.correlation_se.shape == (2, 2)
+    # Diagonal of correlation SE should be 0 (cor(x,x)=1, no variation)
+    for i in range(2):
+        assert est.correlation_se[i, i] < 1e-6, f"Diagonal cor SE should be ~0, got {est.correlation_se[i,i]}"
+    # Correlation test table
+    assert est.correlation_test is not None, "Expected correlation_test DataFrame"
+    assert len(est.correlation_test) == 1, "Expected 1 off-diagonal pair for 2 random params"
+    row = est.correlation_test.iloc[0]
+    assert row["param_1"] == "price"
+    assert row["param_2"] == "time"
+    assert not np.isnan(row["cor_std_error"]), "SE should not be NaN"
+    assert not np.isnan(row["z_stat"]), "z_stat should not be NaN"
+    assert not np.isnan(row["p_value"]), "p_value should not be NaN"
+    assert 0.0 <= row["p_value"] <= 1.0, f"p-value out of range: {row['p_value']}"
+
+
+_run("20. Correlation inference (delta method SEs for cov/cor)", test_correlation_inference)
739
+
740
+
741
+# ===================================================================
+# 21. FullModelSpec + estimate_from_spec
+# ===================================================================
+def test_full_model_spec():
+    from dce_analyzer.config import FullModelSpec, VariableSpec
+    from dce_analyzer.pipeline import estimate_from_spec
+
+    spec = FullModelSpec(
+        id_col="respondent_id",
+        task_col="task_id",
+        alt_col="alternative",
+        choice_col="choice",
+        variables=[
+            VariableSpec(name="price", column="price", distribution="normal"),
+            VariableSpec(name="time", column="time", distribution="normal"),
+            VariableSpec(name="comfort", column="comfort", distribution="fixed"),
+            VariableSpec(name="reliability", column="reliability", distribution="fixed"),
+        ],
+        model_type="mixed",
+        n_draws=50,
+        maxiter=200,
+        seed=42,
+    )
+    result = estimate_from_spec(df=sim_output.data, spec=spec)
+    est = result.estimation
+    # Should produce the same kind of result as estimate_dataframe
+    assert est.n_parameters == 6, f"Expected 6 params, got {est.n_parameters}"
+    assert not est.estimates.empty
+    param_names = set(est.estimates["parameter"])
+    for expected in ["mu_price", "sd_price", "mu_time", "sd_time",
+                     "beta_comfort", "beta_reliability"]:
+        assert expected in param_names, f"Missing param: {expected}"
+    assert est.n_observations == 100 * 4
+
+
+_run("21. FullModelSpec + estimate_from_spec", test_full_model_spec)
777
+
778
+
779
+# ===================================================================
+# 22. Heterogeneity interactions with MMNL via FullModelSpec
+# ===================================================================
+def test_interactions_mmnl():
+    from dce_analyzer.config import FullModelSpec, InteractionTerm, VariableSpec
+    from dce_analyzer.pipeline import estimate_from_spec
+
+    spec = FullModelSpec(
+        id_col="respondent_id",
+        task_col="task_id",
+        alt_col="alternative",
+        choice_col="choice",
+        variables=[
+            VariableSpec(name="price", column="price", distribution="normal"),
+            VariableSpec(name="time", column="time", distribution="fixed"),
+            VariableSpec(name="comfort", column="comfort", distribution="fixed"),
+            VariableSpec(name="reliability", column="reliability", distribution="fixed"),
+        ],
+        model_type="mixed",
+        interactions=[
+            InteractionTerm(columns=("price", "income")),
+        ],
+        n_draws=50,
+        maxiter=200,
+        seed=42,
+    )
+    result = estimate_from_spec(df=sim_output.data, spec=spec)
+    est = result.estimation
+    param_names = set(est.estimates["parameter"])
+    # Interaction term should appear as a fixed parameter
+    assert "beta_price_x_income" in param_names, (
+        f"Missing interaction param. Got: {param_names}"
+    )
+    # 1 mu + 1 sd (price) + 3 fixed (time, comfort, reliability) + 1 interaction = 6
+    assert est.n_parameters == 6, f"Expected 6 params, got {est.n_parameters}"
+
+
+_run("22. Heterogeneity interactions with MMNL (InteractionTerm)", test_interactions_mmnl)
817
+
818
+
819
+# ===================================================================
+# 23. GMNL + full correlation
+# ===================================================================
+def test_gmnl_full_correlation():
+    from dce_analyzer.config import ModelSpec, VariableSpec
+    from dce_analyzer.pipeline import estimate_dataframe
+
+    spec = ModelSpec(
+        id_col="respondent_id",
+        task_col="task_id",
+        alt_col="alternative",
+        choice_col="choice",
+        variables=[
+            VariableSpec(name="price", column="price", distribution="normal"),
+            VariableSpec(name="time", column="time", distribution="normal"),
+            VariableSpec(name="comfort", column="comfort", distribution="fixed"),
+            VariableSpec(name="reliability", column="reliability", distribution="fixed"),
+        ],
+        n_draws=50,
+    )
+    result = estimate_dataframe(
+        df=sim_output.data, spec=spec, model_type="gmnl",
+        maxiter=200, seed=42, correlated=True,
+    )
+    est = result.estimation
+    assert est.covariance_matrix is not None, "Expected covariance matrix for GMNL+correlated"
+    assert est.covariance_matrix.shape == (2, 2), (
+        f"Expected 2x2 cov, got {est.covariance_matrix.shape}"
+    )
+    assert est.correlation_matrix is not None
+    # GMNL params: 2 mu + chol(2)=3 + 2 fixed + 3 GMNL(tau,sigma_tau,gamma) = 10
+    assert est.n_parameters == 10, f"Expected 10 params, got {est.n_parameters}"
+    param_names = set(est.estimates["parameter"])
+    assert "tau (scale mean)" in param_names
+    assert "sigma_tau (scale SD)" in param_names
+    assert "gamma (mixing)" in param_names
+
+
+_run("23. GMNL + full correlation", test_gmnl_full_correlation)
858
+
859
+
860
+# ===================================================================
+# 24. GMNL + selective correlation
+# ===================================================================
+def test_gmnl_selective_correlation():
+    from dce_analyzer.config import ModelSpec, VariableSpec
+    from dce_analyzer.pipeline import estimate_dataframe
+
+    spec = ModelSpec(
+        id_col="respondent_id",
+        task_col="task_id",
+        alt_col="alternative",
+        choice_col="choice",
+        variables=[
+            VariableSpec(name="price", column="price", distribution="normal"),
+            VariableSpec(name="time", column="time", distribution="normal"),
+            VariableSpec(name="comfort", column="comfort", distribution="normal"),
+            VariableSpec(name="reliability", column="reliability", distribution="fixed"),
+        ],
+        n_draws=50,
+    )
+    # Correlate price-time only; comfort is standalone random
+    result = estimate_dataframe(
+        df=sim_output.data, spec=spec, model_type="gmnl",
+        maxiter=200, seed=42,
+        correlation_groups=[[0, 1]],
+    )
+    est = result.estimation
+    assert est.covariance_matrix is not None
+    assert est.covariance_matrix.shape == (3, 3)
+    cov = est.covariance_matrix
+    # comfort (index 2) is standalone: zero cross-cov with price/time
+    assert abs(cov[0, 2]) < 1e-8, f"Expected 0 cov(price,comfort), got {cov[0,2]}"
+    assert abs(cov[1, 2]) < 1e-8, f"Expected 0 cov(time,comfort), got {cov[1,2]}"
+    param_names = set(est.estimates["parameter"])
+    assert "tau (scale mean)" in param_names
+
+
+_run("24. GMNL + selective correlation", test_gmnl_selective_correlation)
898
+
899
+
900
+# ===================================================================
+# 25. BWS composable functions (bws_log_prob, standard_log_prob)
+# ===================================================================
+def test_bws_composable_functions():
+    import torch
+    from dce_analyzer.bws import bws_log_prob, standard_log_prob
+
+    # Create simple test tensors: 4 observations, 3 alternatives
+    n_obs, n_alts = 4, 3
+    torch.manual_seed(42)
+    utility = torch.randn(n_obs, n_alts)
+    y_best = torch.tensor([0, 1, 2, 0])   # chosen alt indices
+    y_worst = torch.tensor([2, 0, 1, 1])  # worst alt indices (different from best)
+
+    # Test standard_log_prob
+    log_p = standard_log_prob(utility, y_best, alt_dim=-1)
+    assert log_p.shape == (n_obs,), f"Expected shape ({n_obs},), got {log_p.shape}"
+    # Log-probabilities must be <= 0
+    assert (log_p <= 1e-6).all(), "Log-probabilities must be <= 0"
+    # Probabilities must sum to 1 across alternatives
+    # (verify by exponentiating each alternative's log-prob and summing)
+    log_all = torch.stack([
+        standard_log_prob(utility, torch.full((n_obs,), j), alt_dim=-1)
+        for j in range(n_alts)
+    ], dim=1)
+    prob_sums = torch.exp(log_all).sum(dim=1)
+    assert torch.allclose(prob_sums, torch.ones(n_obs), atol=1e-5), (
+        f"Probabilities don't sum to 1: {prob_sums}"
+    )
+
+    # Test bws_log_prob
+    lambda_w = 1.0
+    log_p_bws = bws_log_prob(utility, y_best, y_worst, lambda_w, alt_dim=-1)
+    assert log_p_bws.shape == (n_obs,), f"Expected shape ({n_obs},), got {log_p_bws.shape}"
+    assert (log_p_bws <= 1e-6).all(), "BWS log-probabilities must be <= 0"
+    # BWS log-prob should be less than standard (it's a product of two probs)
+    assert (log_p_bws <= log_p + 1e-6).all(), (
+        "BWS log-prob should be <= standard log-prob (product of two probs)"
+    )
+
+    # Test with lambda_w as tensor
+    lambda_w_tensor = torch.tensor(2.0)
+    log_p_bws2 = bws_log_prob(utility, y_best, y_worst, lambda_w_tensor, alt_dim=-1)
+    assert log_p_bws2.shape == (n_obs,)
+
+    # Test with 3D utility (simulating draws): (n_obs, n_draws, n_alts)
+    n_draws = 5
+    utility_3d = torch.randn(n_obs, n_draws, n_alts)
+    log_p_3d = standard_log_prob(utility_3d, y_best, alt_dim=-1)
+    assert log_p_3d.shape == (n_obs, n_draws), f"Expected ({n_obs},{n_draws}), got {log_p_3d.shape}"
+
+    log_p_bws_3d = bws_log_prob(utility_3d, y_best, y_worst, 1.0, alt_dim=-1)
+    assert log_p_bws_3d.shape == (n_obs, n_draws), (
+        f"Expected ({n_obs},{n_draws}), got {log_p_bws_3d.shape}"
+    )
+
+
+_run("25. BWS composable functions (bws_log_prob, standard_log_prob)", test_bws_composable_functions)
957
+
958
+
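For readers unfamiliar with the best-worst likelihood these tests exercise, here is a minimal NumPy sketch of the assumed maxdiff form: the best alternative is chosen from all alternatives, then the worst is chosen from the remainder with negated, lambda_w-scaled utilities. `bws_log_prob_sketch` and `softmax_log_prob` are illustrative stand-ins, not the library's torch implementation in `dce_analyzer.bws`:

```python
import numpy as np

def softmax_log_prob(v, idx):
    """Log multinomial-logit probability of alternative idx given utilities v."""
    return v[idx] - np.log(np.sum(np.exp(v)))

def bws_log_prob_sketch(v, best, worst, lambda_w=1.0):
    """Best-worst log-likelihood: P(best from all) * P(worst from the rest).

    The worst choice uses negated utilities scaled by lambda_w, so a
    higher-utility alternative is less likely to be picked as worst.
    """
    lp_best = softmax_log_prob(v, best)
    remaining = [j for j in range(len(v)) if j != best]
    v_w = -lambda_w * v[remaining]
    lp_worst = v_w[remaining.index(worst)] - np.log(np.sum(np.exp(v_w)))
    return lp_best + lp_worst

v = np.array([1.0, 0.2, -0.5])
lp = bws_log_prob_sketch(v, best=0, worst=2)
```

Because the joint probability is a product of two probabilities, the BWS log-probability is always at or below the best-only log-probability, which is exactly the inequality test 25 asserts.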
959
+# ===================================================================
+# 26. Heterogeneity interactions with Latent Class via FullModelSpec
+# ===================================================================
+def test_interactions_lc():
+    from dce_analyzer.config import FullModelSpec, InteractionTerm, VariableSpec
+    from dce_analyzer.pipeline import estimate_from_spec
+
+    spec = FullModelSpec(
+        id_col="respondent_id",
+        task_col="task_id",
+        alt_col="alternative",
+        choice_col="choice",
+        variables=[
+            VariableSpec(name="price", column="price"),
+            VariableSpec(name="time", column="time"),
+            VariableSpec(name="comfort", column="comfort"),
+            VariableSpec(name="reliability", column="reliability"),
+        ],
+        model_type="latent_class",
+        interactions=[
+            InteractionTerm(columns=("price", "income")),
+        ],
+        n_classes=2,
+        n_starts=3,
+        maxiter=200,
+        seed=42,
+    )
+    result = estimate_from_spec(df=sim_output.data, spec=spec)
+    est = result.estimation
+    assert est.n_classes == 2
+    # Interaction param should appear in estimates
+    has_interaction = any("price_x_income" in str(p) for p in est.estimates["parameter"])
+    assert has_interaction, (
+        f"Missing interaction param in LC estimates. Got: {list(est.estimates['parameter'])}"
+    )
+
+
+_run("26. Heterogeneity interactions with Latent Class (InteractionTerm)", test_interactions_lc)
997
+
998
+
999
+# ===================================================================
+# 27. FullModelSpec with dummy coding via estimate_from_spec
+# ===================================================================
+def test_dummy_coding_via_spec():
+    from dce_analyzer.config import DummyCoding, FullModelSpec, VariableSpec
+    from dce_analyzer.pipeline import estimate_from_spec
+
+    # comfort has 2 unique values (0, 1) -> dummy with ref=0 -> one dummy comfort_L1
+    spec = FullModelSpec(
+        id_col="respondent_id",
+        task_col="task_id",
+        alt_col="alternative",
+        choice_col="choice",
+        variables=[
+            VariableSpec(name="price", column="price"),
+            VariableSpec(name="time", column="time"),
+            VariableSpec(name="comfort", column="comfort"),
+            VariableSpec(name="reliability", column="reliability"),
+        ],
+        model_type="conditional",
+        dummy_codings=[
+            DummyCoding(column="comfort", ref_level=0),
+        ],
+        maxiter=200,
+        seed=42,
+    )
+    result = estimate_from_spec(df=sim_output.data, spec=spec)
+    est = result.estimation
+    param_names = set(est.estimates["parameter"])
+    # comfort should be expanded: beta_comfort_L1 instead of beta_comfort
+    assert "beta_comfort_L1" in param_names, (
+        f"Missing dummy param beta_comfort_L1. Got: {param_names}"
+    )
+    # Original comfort should NOT appear
+    assert "beta_comfort" not in param_names, (
+        f"Original column should be replaced by dummy expansion. Got: {param_names}"
+    )
+    # price, time, reliability remain continuous
+    assert "beta_price" in param_names
+    assert "beta_time" in param_names
+    assert "beta_reliability" in param_names
+    # 3 continuous + 1 dummy = 4 params
+    assert est.n_parameters == 4, f"Expected 4 params, got {est.n_parameters}"
+
+
+_run("27. FullModelSpec with dummy coding via estimate_from_spec", test_dummy_coding_via_spec)
1045
+
1046
+
1047
+# ===================================================================
+# 28. Variable ordering: dummy-coded vars expanded in-place
+# ===================================================================
+def test_variable_ordering_preservation():
+    from dce_analyzer.config import DummyCoding, FullModelSpec, VariableSpec
+    from dce_analyzer.pipeline import estimate_from_spec
+
+    # Variables in order: price (continuous), comfort (dummy, binary 0/1),
+    # time (continuous), reliability (continuous).
+    # After expansion, order must be: price, comfort_L1, time, reliability
+    # (not: price, time, reliability, comfort_L1, the old buggy behavior)
+    spec = FullModelSpec(
+        id_col="respondent_id",
+        task_col="task_id",
+        alt_col="alternative",
+        choice_col="choice",
+        variables=[
+            VariableSpec(name="price", column="price"),
+            VariableSpec(name="comfort", column="comfort"),
+            VariableSpec(name="time", column="time"),
+            VariableSpec(name="reliability", column="reliability"),
+        ],
+        model_type="conditional",
+        dummy_codings=[
+            DummyCoding(column="comfort", ref_level=0),
+        ],
+        maxiter=200,
+        seed=42,
+    )
+    result = estimate_from_spec(df=sim_output.data, spec=spec)
+    est = result.estimation
+    param_names = list(est.estimates["parameter"])
+    # Check order: price -> comfort dummy -> time -> reliability
+    expected_order = ["beta_price", "beta_comfort_L1", "beta_time", "beta_reliability"]
+    assert param_names == expected_order, (
+        f"Variable ordering not preserved. Expected {expected_order}, got {param_names}"
+    )
+    # Also verify expanded_spec preserves order
+    exp_spec = result.expanded_spec
+    exp_var_names = [v.name for v in exp_spec.variables]
+    assert exp_var_names == ["price", "comfort_L1", "time", "reliability"], (
+        f"Expanded spec variable order wrong: {exp_var_names}"
+    )
+
+
+_run("28. Variable ordering: dummy-coded vars expanded in-place", test_variable_ordering_preservation)
1092
+
1093
+
1094
+# ===================================================================
+# 29. WTP theta_index mapping for MMNL (SE correctness)
+# ===================================================================
+def test_wtp_theta_index():
+    from dce_analyzer.config import ModelSpec, VariableSpec
+    from dce_analyzer.pipeline import estimate_dataframe
+    from dce_analyzer.wtp import compute_wtp
+
+    # price is random, then time (fixed), comfort (fixed), reliability (fixed).
+    # This creates interleaved mu/sd rows: mu_price, sd_price, beta_time, ...
+    # The theta_index mapping must be correct for WTP SEs.
+    spec = ModelSpec(
+        id_col="respondent_id",
+        task_col="task_id",
+        alt_col="alternative",
+        choice_col="choice",
+        variables=[
+            VariableSpec(name="price", column="price", distribution="normal"),
+            VariableSpec(name="time", column="time", distribution="fixed"),
+            VariableSpec(name="comfort", column="comfort", distribution="fixed"),
+            VariableSpec(name="reliability", column="reliability", distribution="fixed"),
+        ],
+        n_draws=50,
+    )
+    result = estimate_dataframe(
+        df=sim_output.data, spec=spec, model_type="mixed",
+        maxiter=200, seed=42,
+    )
+    est = result.estimation
+
+    # Verify theta_index column exists and is correct
+    assert "theta_index" in est.estimates.columns, "theta_index column missing"
+    # mu_price -> theta 0, sd_price -> theta 4, beta_time -> theta 1,
+    # beta_comfort -> theta 2, beta_reliability -> theta 3
+    tidx_map = dict(zip(est.estimates["parameter"], est.estimates["theta_index"]))
+    assert tidx_map["mu_price"] == 0, f"mu_price should be theta 0, got {tidx_map['mu_price']}"
+    assert tidx_map["beta_time"] == 1, f"beta_time should be theta 1, got {tidx_map['beta_time']}"
+    assert tidx_map["sd_price"] == 4, f"sd_price should be theta 4, got {tidx_map['sd_price']}"
+
+    # Compute WTP using time as the cost variable
+    wtp_df = compute_wtp(est, cost_variable="time")
+    assert not wtp_df.empty
+    # Check that SEs are not NaN (vcov should be available)
+    if est.vcov_matrix is not None:
+        for _, row in wtp_df.iterrows():
+            if row["attribute"] in ("price", "comfort", "reliability"):
+                assert not np.isnan(row["wtp_std_error"]), (
+                    f"WTP SE is NaN for {row['attribute']}; theta_index mapping may be wrong"
+                )
+
+
+_run("29. WTP theta_index mapping for MMNL (SE correctness)", test_wtp_theta_index)
1146
+
1147
+
1148
+# ===================================================================
+# 30. 3-way interaction (price x time x income)
+# ===================================================================
+def test_3way_interaction():
+    from dce_analyzer.config import FullModelSpec, InteractionTerm, VariableSpec
+    from dce_analyzer.pipeline import estimate_from_spec
+
+    spec = FullModelSpec(
+        id_col="respondent_id",
+        task_col="task_id",
+        alt_col="alternative",
+        choice_col="choice",
+        variables=[
+            VariableSpec(name="price", column="price"),
+            VariableSpec(name="time", column="time"),
+            VariableSpec(name="comfort", column="comfort"),
+            VariableSpec(name="reliability", column="reliability"),
+        ],
+        model_type="conditional",
+        interactions=[
+            InteractionTerm(columns=("price", "time", "income")),
+        ],
+        maxiter=200,
+        seed=42,
+    )
+    result = estimate_from_spec(df=sim_output.data, spec=spec)
+    est = result.estimation
+    param_names = set(est.estimates["parameter"])
+    # 3-way interaction name: price_x_time_x_income
+    assert "beta_price_x_time_x_income" in param_names, (
+        f"Missing 3-way interaction param. Got: {param_names}"
+    )
+    # 4 base + 1 interaction = 5 params
+    assert est.n_parameters == 5, f"Expected 5 params, got {est.n_parameters}"
+
+
+_run("30. 3-way interaction (price x time x income)", test_3way_interaction)
1185
+
1186
+
1187
+# ===================================================================
+# 31. Attribute x attribute interaction (price x time)
+# ===================================================================
+def test_attribute_x_attribute_interaction():
+    from dce_analyzer.config import FullModelSpec, InteractionTerm, VariableSpec
+    from dce_analyzer.pipeline import estimate_from_spec
+
+    spec = FullModelSpec(
+        id_col="respondent_id",
+        task_col="task_id",
+        alt_col="alternative",
+        choice_col="choice",
+        variables=[
+            VariableSpec(name="price", column="price"),
+            VariableSpec(name="time", column="time"),
+            VariableSpec(name="comfort", column="comfort"),
+            VariableSpec(name="reliability", column="reliability"),
+        ],
+        model_type="conditional",
+        interactions=[
+            InteractionTerm(columns=("price", "time")),
+        ],
+        maxiter=200,
+        seed=42,
+    )
+    result = estimate_from_spec(df=sim_output.data, spec=spec)
+    est = result.estimation
+    param_names = set(est.estimates["parameter"])
+    # attribute x attribute interaction
+    assert "beta_price_x_time" in param_names, (
+        f"Missing attribute x attribute interaction param. Got: {param_names}"
+    )
+    # 4 base + 1 interaction = 5 params
+    assert est.n_parameters == 5, f"Expected 5 params, got {est.n_parameters}"
+
+
+_run("31. Attribute x attribute interaction (price x time)", test_attribute_x_attribute_interaction)
1224
+
1225
+
1226
+ # ===================================================================
1227
+ # Summary
1228
+ # ===================================================================
1229
+ print()
1230
+ print("=" * 60)
1231
+ n_pass = sum(1 for _, ok, _ in _results if ok)
1232
+ n_fail = sum(1 for _, ok, _ in _results if not ok)
1233
+ print(f" {n_pass} passed, {n_fail} failed out of {len(_results)} tests")
1234
+ print("=" * 60)
1235
+
1236
+ if n_fail > 0:
1237
+ print()
1238
+ print("FAILURES:")
1239
+ for name, ok, msg in _results:
1240
+ if not ok:
1241
+ print(f" {name}: {msg}")
1242
+ print()
1243
+ sys.exit(1)
1244
+ else:
1245
+ print(" ALL TESTS PASSED")
1246
+ sys.exit(0)
src/dce_analyzer/config.py ADDED
@@ -0,0 +1,168 @@
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from typing import Literal
+
+
+DistributionType = Literal["fixed", "normal", "lognormal"]
+
+
+@dataclass(frozen=True)
+class VariableSpec:
+    """One variable used in the utility function."""
+
+    name: str
+    column: str
+    distribution: DistributionType = "fixed"
+
+    def __post_init__(self) -> None:
+        if not self.name:
+            raise ValueError("VariableSpec.name cannot be empty.")
+        if not self.column:
+            raise ValueError("VariableSpec.column cannot be empty.")
+        if self.distribution not in {"fixed", "normal", "lognormal"}:
+            raise ValueError(
+                f"Unsupported distribution '{self.distribution}'. "
+                "Use one of: fixed, normal, lognormal."
+            )
+
+
+@dataclass(frozen=True)
+class ModelSpec:
+    """Data layout and variable config for a model run."""
+
+    id_col: str
+    task_col: str
+    alt_col: str
+    choice_col: str
+    variables: list[VariableSpec]
+    n_draws: int = 200
+    n_classes: int = 2
+    membership_cols: list[str] | None = None
+
+    def __post_init__(self) -> None:
+        core_cols = [self.id_col, self.task_col, self.alt_col, self.choice_col]
+        if any(not c for c in core_cols):
+            raise ValueError("id_col, task_col, alt_col, and choice_col must all be set.")
+        if len(self.variables) == 0:
+            raise ValueError("At least one variable is required in ModelSpec.variables.")
+        if self.n_draws < 1:
+            raise ValueError("n_draws must be >= 1.")
+        if self.n_classes < 1:
+            raise ValueError("n_classes must be >= 1.")
+
+
+@dataclass(frozen=True)
+class DummyCoding:
+    """Dummy-coding specification for a single attribute."""
+
+    column: str  # original column name in the data
+    ref_level: object  # reference level (omitted baseline)
+
+    def expand(self, df) -> tuple[list[str], dict]:
+        """Return (list of dummy column names, {dummy_name: level}) for this column.
+
+        Does NOT mutate *df*.
+        """
+        unique_vals = sorted(df[self.column].dropna().unique())
+        non_ref = [v for v in unique_vals if v != self.ref_level]
+        names: list[str] = []
+        mapping: dict[str, object] = {}
+        for level in non_ref:
+            name = f"{self.column}_L{level}"
+            names.append(name)
+            mapping[name] = level
+        return names, mapping
+
+
+@dataclass(frozen=True)
+class HeterogeneityInteraction:
+    """An attribute x demographic interaction term (legacy, kept for backward compat)."""
+
+    attribute: str  # name of the attribute variable
+    demographic_col: str  # name of the demographic column in the data
+
+
+@dataclass(frozen=True)
+class InteractionTerm:
+    """An arbitrary N-way interaction term: the product of the specified columns."""
+
+    columns: tuple[str, ...]
+
+    def __post_init__(self) -> None:
+        if len(self.columns) < 2:
+            raise ValueError("InteractionTerm requires at least 2 columns.")
+
+    @property
+    def name(self) -> str:
+        return "_x_".join(self.columns)
+
+
+@dataclass
+class FullModelSpec:
+    """Complete model specification -- one object captures everything."""
+
+    # Data layout
+    id_col: str
+    task_col: str
+    alt_col: str
+    choice_col: str
+
+    # Variable specifications
+    variables: list[VariableSpec]
+
+    # Model type
+    model_type: str = "mixed"  # "conditional", "mixed", "gmnl", "latent_class"
+
+    # Dummy coding: backend expands these columns into dummy variables
+    dummy_codings: list[DummyCoding] = field(default_factory=list)
+
+    # Interaction terms (N-way, any columns) -- works for ALL model types
+    interactions: list[InteractionTerm] = field(default_factory=list)
+
+    # Correlation structure
+    correlated: bool = False
+    correlation_groups: list[list[int]] | None = None
+
+    # BWS
+    bws_worst_col: str | None = None
+    estimate_lambda_w: bool = True
+
+    # Latent class
+    n_classes: int = 2
+    membership_cols: list[str] | None = None
+
+    # Estimation settings
+    n_draws: int = 200
+    maxiter: int = 300
+    seed: int = 123
+    n_starts: int = 10
+
+    def __post_init__(self) -> None:
+        valid_types = {"conditional", "mixed", "gmnl", "latent_class"}
+        if self.model_type not in valid_types:
+            raise ValueError(
+                f"model_type must be one of {valid_types}, got '{self.model_type}'."
+            )
+        core_cols = [self.id_col, self.task_col, self.alt_col, self.choice_col]
+        if any(not c for c in core_cols):
+            raise ValueError("id_col, task_col, alt_col, and choice_col must all be set.")
+        if len(self.variables) == 0:
+            raise ValueError("At least one variable is required.")
+        if self.n_draws < 1:
+            raise ValueError("n_draws must be >= 1.")
+
+    def to_model_spec(self) -> ModelSpec:
+        """Convert to the legacy ModelSpec for backward compatibility."""
+        return ModelSpec(
+            id_col=self.id_col,
+            task_col=self.task_col,
+            alt_col=self.alt_col,
+            choice_col=self.choice_col,
+            variables=list(self.variables),
+            n_draws=self.n_draws,
+            n_classes=self.n_classes,
+            membership_cols=self.membership_cols,
+        )
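The `name` property above derives a deterministic parameter name by joining the interaction's columns with `_x_` (the tests below rely on this, e.g. `beta_price_x_time_x_income`). A minimal, self-contained sketch of that behavior, with the dataclass reproduced inline so it runs without installing `dce_analyzer`:

```python
from dataclasses import dataclass


# Inline copy of InteractionTerm from config.py above, for illustration only.
@dataclass(frozen=True)
class InteractionTerm:
    columns: tuple[str, ...]

    def __post_init__(self) -> None:
        # A one-column "interaction" is just the column itself, so reject it.
        if len(self.columns) < 2:
            raise ValueError("InteractionTerm requires at least 2 columns.")

    @property
    def name(self) -> str:
        # Deterministic product-column / parameter name.
        return "_x_".join(self.columns)


two_way = InteractionTerm(columns=("price", "time"))
three_way = InteractionTerm(columns=("price", "time", "income"))
print(two_way.name)    # price_x_time
print(three_way.name)  # price_x_time_x_income
```

Because the name is just the joined column tuple, order matters: `("price", "time")` and `("time", "price")` produce the same product column values but different names.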
src/dce_analyzer/pipeline.py ADDED
@@ -0,0 +1,217 @@
+from __future__ import annotations
+
+import json
+from dataclasses import dataclass, field
+from pathlib import Path
+
+import pandas as pd
+import torch
+
+from .config import FullModelSpec, ModelSpec, VariableSpec
+from .data import ChoiceTensors, prepare_choice_tensors
+from .latent_class import LatentClassEstimator, LatentClassResult
+from .model import ConditionalLogitEstimator, EstimationResult, GmnlEstimator, MixedLogitEstimator
+
+
+@dataclass
+class PipelineResult:
+    tensors: ChoiceTensors
+    estimation: EstimationResult | LatentClassResult
+    wtp: pd.DataFrame | None = field(default=None)
+    expanded_spec: ModelSpec | None = field(default=None, repr=False)
+    expanded_df: pd.DataFrame | None = field(default=None, repr=False)
+
+
+def estimate_dataframe(
+    df: pd.DataFrame,
+    spec: ModelSpec,
+    model_type: str = "mixed",
+    maxiter: int = 300,
+    seed: int = 123,
+    device: torch.device | None = None,
+    n_classes: int | None = None,
+    n_starts: int = 10,
+    correlated: bool = False,
+    membership_cols: list[str] | None = None,
+    correlation_groups: list[list[int]] | None = None,
+    bws_worst_col: str | None = None,
+    estimate_lambda_w: bool = True,
+) -> PipelineResult:
+    tensors = prepare_choice_tensors(df, spec, device=device)
+
+    # Prepare BWS data if a worst-choice column is specified
+    bws_data = None
+    if bws_worst_col:
+        from .bws import prepare_bws_data, validate_bws
+
+        validate_bws(df, spec, bws_worst_col)
+        bws_data = prepare_bws_data(
+            df,
+            spec,
+            bws_worst_col,
+            tensors.n_obs,
+            tensors.n_alts,
+            tensors.X.device,
+            estimate_lambda_w=estimate_lambda_w,
+        )
+
+    if model_type == "mixed":
+        estimator = MixedLogitEstimator(
+            tensors=tensors,
+            variables=spec.variables,
+            n_draws=spec.n_draws,
+            device=tensors.X.device,
+            seed=seed,
+            correlated=correlated,
+            correlation_groups=correlation_groups,
+            bws_data=bws_data,
+        )
+        return PipelineResult(tensors=tensors, estimation=estimator.fit(maxiter=maxiter))
+    elif model_type == "conditional":
+        estimator = ConditionalLogitEstimator(
+            tensors=tensors,
+            variables=spec.variables,
+            device=tensors.X.device,
+            seed=seed,
+            bws_data=bws_data,
+        )
+        return PipelineResult(tensors=tensors, estimation=estimator.fit(maxiter=maxiter))
+    elif model_type == "gmnl":
+        estimator = GmnlEstimator(
+            tensors=tensors,
+            variables=spec.variables,
+            n_draws=spec.n_draws,
+            device=tensors.X.device,
+            seed=seed,
+            bws_data=bws_data,
+            correlated=correlated,
+            correlation_groups=correlation_groups,
+        )
+        return PipelineResult(tensors=tensors, estimation=estimator.fit(maxiter=maxiter))
+    elif model_type == "latent_class":
+        q = n_classes if n_classes is not None else spec.n_classes
+        mc = membership_cols or spec.membership_cols
+        lc_estimator = LatentClassEstimator(
+            tensors=tensors,
+            variables=spec.variables,
+            n_classes=q,
+            device=tensors.X.device,
+            seed=seed,
+            membership_cols=mc,
+            df=df,
+            id_col=spec.id_col,
+            bws_data=bws_data,
+        )
+        return PipelineResult(
+            tensors=tensors,
+            estimation=lc_estimator.fit(maxiter=maxiter, n_starts=n_starts),
+        )
+    else:
+        raise ValueError(
+            "model_type must be 'mixed', 'conditional', 'gmnl', or 'latent_class'."
+        )
+
+
+def estimate_from_spec(
+    df: pd.DataFrame,
+    spec: FullModelSpec,
+    device: torch.device | None = None,
+) -> PipelineResult:
+    """Single entry-point: all configuration comes from *spec*.
+
+    1. Dummy-coded columns are materialised from *spec.dummy_codings*.
+    2. Interaction terms from *spec.interactions* are materialised as
+       product columns.
+    Both are appended as fixed VariableSpecs before estimation.
+    """
+    df = df.copy()
+
+    # ── Expand dummy-coded variables ──────────────────────────────────
+    dummy_cols = {dc.column for dc in spec.dummy_codings}
+    # Build mapping: original column -> list of expanded VariableSpecs
+    _dummy_expansions: dict[str, list[VariableSpec]] = {}
+
+    for dc in spec.dummy_codings:
+        matched = [v for v in spec.variables if v.column == dc.column]
+        if not matched:
+            raise ValueError(
+                f"Dummy coding column '{dc.column}' not found in variables."
+            )
+        dummy_names, mapping = dc.expand(df)
+        for dname, level in mapping.items():
+            df[dname] = (df[dc.column] == level).astype(int)
+
+        base_var = matched[0]
+        _dummy_expansions[dc.column] = [
+            VariableSpec(name=dname, column=dname, distribution=base_var.distribution)
+            for dname in dummy_names
+        ]
+
+    # Build final variable list: replace each dummy placeholder in-place
+    # to preserve the UI's variable ordering (critical for correlation_groups)
+    all_variables: list[VariableSpec] = []
+    for v in spec.variables:
+        if v.column in dummy_cols:
+            all_variables.extend(_dummy_expansions[v.column])
+        else:
+            all_variables.append(v)
+
+    # ── Materialise N-way interaction product columns ─────────────────
+    extra_vars: list[VariableSpec] = []
+
+    for inter in spec.interactions:
+        col_name = inter.name
+        for col in inter.columns:
+            if col not in df.columns:
+                raise ValueError(
+                    f"Interaction column '{col}' not found in data."
+                )
+        product = df[inter.columns[0]].astype(float)
+        for col in inter.columns[1:]:
+            product = product * df[col].astype(float)
+        df[col_name] = product
+        extra_vars.append(VariableSpec(name=col_name, column=col_name, distribution="fixed"))
+
+    all_variables = all_variables + extra_vars
+
+    model_spec = ModelSpec(
+        id_col=spec.id_col,
+        task_col=spec.task_col,
+        alt_col=spec.alt_col,
+        choice_col=spec.choice_col,
+        variables=all_variables,
+        n_draws=spec.n_draws,
+        n_classes=spec.n_classes,
+        membership_cols=spec.membership_cols,
+    )
+
+    result = estimate_dataframe(
+        df=df,
+        spec=model_spec,
+        model_type=spec.model_type,
+        maxiter=spec.maxiter,
+        seed=spec.seed,
+        device=device,
+        n_classes=spec.n_classes,
+        n_starts=spec.n_starts,
+        correlated=spec.correlated,
+        membership_cols=spec.membership_cols,
+        correlation_groups=spec.correlation_groups,
+        bws_worst_col=spec.bws_worst_col,
+        estimate_lambda_w=spec.estimate_lambda_w,
+    )
+    result.expanded_spec = model_spec
+    result.expanded_df = df
+    return result
+
+
+def save_estimation_outputs(estimation: EstimationResult | LatentClassResult, output_prefix: str | Path) -> None:
+    output_prefix = Path(output_prefix)
+    if output_prefix.suffix:
+        output_prefix = output_prefix.with_suffix("")
+    output_prefix.parent.mkdir(parents=True, exist_ok=True)
+
+    estimates_path = output_prefix.parent / f"{output_prefix.name}_estimates.csv"
+    summary_path = output_prefix.parent / f"{output_prefix.name}_summary.json"
+
+    estimation.estimates.to_csv(estimates_path, index=False)
+    with open(summary_path, "w", encoding="utf-8") as handle:
+        json.dump(estimation.summary_dict(), handle, indent=2, default=str)
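The interaction loop in `estimate_from_spec` folds an N-way term into a single product column: start from the first column, multiply in each remaining column element-wise, then store the result under the term's `_x_`-joined name. A dependency-free sketch of that fold (plain lists stand in for the pandas Series, and the toy column values are invented for illustration):

```python
from functools import reduce

# Toy stand-in for the DataFrame columns seen by the interaction loop.
data = {
    "price":  [2.0, 3.0, 1.5],
    "time":   [10.0, 5.0, 8.0],
    "income": [1.0, 2.0, 4.0],
}
columns = ("price", "time", "income")  # an InteractionTerm's columns tuple

# Same fold as the pipeline: seed with the first column, then multiply
# each remaining column in element-wise.
product = reduce(
    lambda acc, col: [a * b for a, b in zip(acc, data[col])],
    columns[1:],
    [float(v) for v in data[columns[0]]],
)

# Store under the _x_-joined name, mirroring InteractionTerm.name.
data["_x_".join(columns)] = product
print(data["price_x_time_x_income"])  # [20.0, 30.0, 48.0]
```

With pandas the fold is the same shape, just over Series (`product = product * df[col].astype(float)`), so a 2-way term and a 5-way term go through identical code.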