tiffany101 committed
Commit e471533 · verified · 1 Parent(s): e66f927

Update app/app_chat.py

Files changed (1)
  1. app/app_chat.py +94 -69
app/app_chat.py CHANGED
@@ -8,42 +8,66 @@ import os
 import boto3
 from botocore.config import Config
 
-AWS_REGION = os.getenv("AWS_REGION", "us-east-1")
-
-bedrock = boto3.client(
-    "bedrock-runtime",
-    region_name=AWS_REGION,
-    config=Config(read_timeout=60, connect_timeout=60, retries={"max_attempts": 3}),
-)
+# ============================================================
+# Bedrock (Amazon Nova) setup
+# ============================================================
 
-# Choose a Nova model ID available in your Bedrock account.
-# Common examples include Nova Lite / Pro in Bedrock.
+AWS_REGION = os.getenv("AWS_REGION", "us-east-1")
 NOVA_MODEL_ID = os.getenv("NOVA_MODEL_ID", "us.amazon.nova-lite-v1:0")
 
+def make_bedrock_client():
+    """
+    Creates a Bedrock Runtime client.
+    Auth is provided via Hugging Face Secrets / env vars:
+      - AWS_BEARER_TOKEN_BEDROCK (Bedrock API key)
+      - AWS_REGION
+    """
+    # If key not set, we'll still create client, but calls will fail with a clear error later.
+    return boto3.client(
+        "bedrock-runtime",
+        region_name=AWS_REGION,
+        config=Config(read_timeout=60, connect_timeout=60, retries={"max_attempts": 3}),
+    )
+
+bedrock = make_bedrock_client()
 
+# ============================================================
+# Streamlit App
+# ============================================================
 
 st.set_page_config(page_title="EchoML", page_icon="💬", layout="wide")
-st.title("💬 Chat with Your Model(IRIS Edition)")
+st.title("💬 Chat with Your Model (IRIS Edition)")
 
 # Sidebar configuration
 with st.sidebar:
     st.header("Settings")
-    api_url = st.text_input("FastAPI endpoint", value="https://query-your-model-api-784882848382.us-central1.run.app/explain")
-    model_path = st.text_input("Model path", value="Query_Your_Model/model_data/model.pkl")
+    api_url = st.text_input(
+        "FastAPI endpoint",
+        value=os.getenv("ECHO_API_URL", "https://query-your-model-api-784882848382.us-central1.run.app/explain"),
+    )
+    model_path = st.text_input("Model path", value=os.getenv("MODEL_PATH", "Query_Your_Model/model_data/model.pkl"))
     feat_names_str = st.text_input(
         "Feature names (comma-separated)",
-        value="sepal length (cm),sepal width (cm),petal length (cm),petal width (cm)"
+        value=os.getenv(
+            "FEATURE_NAMES",
+            "sepal length (cm),sepal width (cm),petal length (cm),petal width (cm)"
+        ),
     )
-    namespace = st.text_input("Namespace", value="Query_Your_Model/data/base_indices/iris_global")
+    namespace = st.text_input("Namespace", value=os.getenv("NAMESPACE", "Query_Your_Model/data/base_indices/iris_global"))
     alpha = st.slider("Alpha (retrieval weight)", 0.0, 1.0, 0.7, 0.05)
     k = st.slider("Top-K similar to retrieve", 1, 10, 5)
 
+    st.divider()
+    st.caption("LLM provider: Amazon Bedrock (Nova)")
+    st.caption(f"AWS_REGION: `{AWS_REGION}`")
+    st.caption(f"NOVA_MODEL_ID: `{NOVA_MODEL_ID}`")
 
-feat_names = [s.strip() for s in feat_names_str.split(",")]
-
-
+feat_names = [s.strip() for s in feat_names_str.split(",") if s.strip()]
 
+# ============================================================
 # Helpers
+# ============================================================
+
 def label_from_pred(y_pred):
     try:
         num = int(round(float(y_pred)))
@@ -66,10 +90,10 @@ def summarize_prediction(res):
 
 def show_similar_cases(res, n_display, feat_names):
     sims = safe_similar_cases(res)
-    if not sims:
+    if not sims:
         return "No similar cases were retrieved."
     n = min(n_display, len(sims))
-    lines = [f"It found **{len(sims)}** similar past cases (showing **{n}**):"]
+    lines = [f"It found **{len(sims)}** similar cases (showing **{n}**):"]
     for case in sims[:n]:
         features_named = ", ".join([f"{name} = {val:.2f}" for name, val in zip(feat_names, case["features"])])
         lines.append(f"- **{case['case_id']}** → {features_named}, predicted as **{case['y_pred']}**.")
@@ -82,7 +106,7 @@ def plot_shap_bar(topk):
     feats = [f["feature"] for f in topk]
     shap_vals = [f["shap"] for f in topk]
     fig, ax = plt.subplots()
-    ax.barh(feats, shap_vals)  # default colors per instructions
+    ax.barh(feats, shap_vals)  # default colors
     ax.set_xlabel("SHAP value (impact on prediction)")
     ax.set_title("Feature importance for this prediction")
     st.pyplot(fig)
@@ -119,7 +143,29 @@ def explain_in_words(res, n_display, feat_names):
     st.markdown("\n".join(msg))
     plot_shap_bar(topk)
 
+def bedrock_llm(prompt: str) -> str:
+    """
+    Calls Amazon Nova via the Bedrock Converse API.
+    Requires Hugging Face Secret: AWS_BEARER_TOKEN_BEDROCK
+    """
+    if not os.getenv("AWS_BEARER_TOKEN_BEDROCK"):
+        return (
+            "LLM explanation is not available because `AWS_BEARER_TOKEN_BEDROCK` is not set.\n\n"
+            "In Hugging Face Spaces → Settings → Variables and secrets → Secrets, add:\n"
+            "- Name: AWS_BEARER_TOKEN_BEDROCK\n"
+            "- Value: (your Bedrock API key)\n"
+        )
+
+    resp = bedrock.converse(
+        modelId=NOVA_MODEL_ID,
+        messages=[{"role": "user", "content": [{"text": prompt}]}],
+    )
+    return resp["output"]["message"]["content"][0]["text"]
+
 def llm_explain(res, feat_names, extra_context=None):
+    """
+    LLM explanation: can handle 'why' and 'what-if' using the provided context (old/new).
+    """
     try:
         pred = label_from_pred(res["prediction"]["y_pred"])
         proba = res["prediction"]["proba"]
@@ -141,20 +187,14 @@ def llm_explain(res, feat_names, extra_context=None):
             "- Why the model made the prediction\n"
             "- Which features mattered\n"
             "- Why those features mattered\n"
-            "- Experiments the user could perform: tell them which feature values to increase/decrease"
-        )
-
-        response = bedrock.converse(
-            modelId=NOVA_MODEL_ID,
-            messages=[{"role": "user", "content": [{"text": prompt}]}],
+            "- 2-3 experiments the user could perform: tell them which feature values to increase/decrease"
         )
 
-        return response["output"]["message"]["content"][0]["text"]
+        return bedrock_llm(prompt)
 
     except Exception as e:
         return f"LLM explanation failed: {e}"
 
-
 def interpret_question(user_q):
     q = (user_q or "").lower()
     if any(w in q for w in ["what if", "increase", "decrease", "set ", "make ", "higher", "lower", "raise", "reduce", "change"]):
@@ -173,9 +213,13 @@ def perform_api_call(features):
         "namespace": namespace,
         "retrieval": {"alpha": alpha, "k": k, "use_retrieval": True, "namespace": namespace},
     }
-    return requests.post(api_url, json=payload).json()
+    r = requests.post(api_url, json=payload, timeout=60)
+    r.raise_for_status()
+    return r.json()
 
+# ============================================================
 # What-if parsing
+# ============================================================
 
 FEATURE_NAME_PAT = re.compile(r"([a-zA-Z][a-zA-Z0-9 _\-\(\)]*)")
 
@@ -204,24 +248,13 @@ def default_delta(curr):
 def apply_what_if(user_q, feat_names, current_features):
     """
     Returns (new_features, change_text) or (None, error_msg)
-
-    Handles:
-    - "increase sepal length to 5.8"
-    - "decrease petal width by 0.2"
-    - "increase sepal width by 10%"
-    - "make petal length higher"
-    - "slightly increase sepal width"
-    - "increase sepal length to 5.8, decrease petal width to 0.2 and reduce sepal width a bit"
     """
-
     q = user_q.lower()
     new = current_features.copy()
     changes = []
 
-    # Split query by commas or 'and'
    parts = re.split(r",| and ", q)
 
-    # Define modifiers with scaling factors (relative to default 10%)
     modifier_scale = {
         "slightly": 0.5,
         "a bit": 0.5,
@@ -238,7 +271,6 @@ def apply_what_if(user_q, feat_names, current_features):
         if not part:
             continue
 
-        # Detect intensity modifier
         scale = 1.0
         for mod, factor in modifier_scale.items():
             if mod in part:
@@ -246,7 +278,7 @@ def apply_what_if(user_q, feat_names, current_features):
                 part = part.replace(mod, "")
                 break
 
-        # 1) Direct set: "= X" or "to X"
+        # 1) Set to value
         m = re.search(r"(?:set|what if|change|increase|decrease|raise|reduce)\s+(.*?)\s*(?:=|to)\s*([-+]?\d*\.?\d+)", part)
         if m:
             feat_frag, val_str = m.group(1), m.group(2)
@@ -263,7 +295,7 @@ def apply_what_if(user_q, feat_names, current_features):
             changes.append(f"Set **{fname}** to **{val:.2f}**.")
             continue
 
-        # 2) increase/decrease by absolute value
+        # 2) Increase/decrease by absolute value
         m = re.search(r"(increase|decrease|raise|reduce)\s+(.*?)\s+by\s+([-+]?\d*\.?\d+)\b(?!%)", part)
         if m:
             op, feat_frag, val_str = m.groups()
@@ -276,16 +308,13 @@ def apply_what_if(user_q, feat_names, current_features):
                 changes.append(f"Couldn't parse a number from: '{val_str}'.")
                 continue
             delta *= scale
-            if op in ["decrease", "reduce"]:
-                delta = -abs(delta)
-            else:
-                delta = abs(delta)
+            delta = -abs(delta) if op in ["decrease", "reduce"] else abs(delta)
             idx = feat_names.index(fname)
             new[idx] = new[idx] + delta
             changes.append(f"{'Increased' if delta>0 else 'Decreased'} **{fname}** by **{abs(delta):.2f}** → **{new[idx]:.2f}**.")
             continue
 
-        # 3) increase/decrease by percent
+        # 3) Increase/decrease by percent
         m = re.search(r"(increase|decrease|raise|reduce)\s+(.*?)\s+by\s+([-+]?\d*\.?\d+)\s*%", part)
         if m:
             op, feat_frag, perc_str = m.groups()
@@ -299,12 +328,12 @@ def apply_what_if(user_q, feat_names, current_features):
                 continue
             perc *= scale
             idx = feat_names.index(fname)
-            factor = 1.0 + (abs(perc)/100.0 if op in ["increase","raise"] else -abs(perc)/100.0)
+            factor = 1.0 + (abs(perc) / 100.0 if op in ["increase", "raise"] else -abs(perc) / 100.0)
             new[idx] = new[idx] * factor
             changes.append(f"{op.title()}d **{fname}** by **{abs(perc):.0f}%** → **{new[idx]:.2f}**.")
             continue
 
-        # 4) make X higher/lower (no number) → ±10% default * modifier scale
+        # 4) Make higher/lower (no number)
         m = re.search(r"(make|set)?\s*(.*?)\s*(higher|lower|increase|decrease|raise|reduce)", part)
         if m:
             _, feat_frag, direction = m.groups()
@@ -327,9 +356,11 @@ def apply_what_if(user_q, feat_names, current_features):
 
     return new, "\n".join(changes)
 
-
+# ============================================================
 # Step 1: Enter features & predict
-st.subheader("Step 1 – Enter features to generate a prediction in the order: [sepal length,sepal width,petal length and petal width]")
+# ============================================================
+
+st.subheader("Step 1 – Enter features to generate a prediction in the order: [sepal length, sepal width, petal length, petal width]")
 user_features = st.text_input("Enter feature values (comma-separated)", "")
 predict_btn = st.button("🔍 Predict and Explain")
 
@@ -345,19 +376,21 @@ if predict_btn:
             st.warning(f"Expected {len(feat_names)} values ({', '.join(feat_names)}), but got {len(features)}.")
         else:
             st.session_state["input_features"] = features
-            # Show entered features
             st.markdown("### ✨ Entered Features")
             st.markdown("\n".join([f"- **{n}** = {v:.2f}" for n, v in zip(feat_names, features)]))
-            # Call API
+
             res = perform_api_call(features)
             st.session_state["prediction_result"] = res
             st.session_state["messages"] = []
             st.success(summarize_prediction(res))
-            st.info("Scroll down to explore similar cases or chat ")
+            st.info("Scroll down to explore similar cases or chat.")
     except Exception as e:
-        st.error(f"Error contacting API: {e}")
+        st.error(f"Error generating prediction: {e}")
+
+# ============================================================
+# Step 2 + 3
+# ============================================================
 
-# Step 2: Similar cases
 if st.session_state["prediction_result"]:
     st.divider()
     st.subheader("Step 2 – Explore similar cases")
@@ -376,26 +409,22 @@ if st.session_state["prediction_result"]:
     else:
         st.write("No similar cases retrieved.")
 
-    # Step 3: Explanation Mode + Chat
     st.divider()
     st.subheader("Step 3 – Chat with the model about this prediction")
 
-    # Choose explanation mode BEFORE asking questions
     if "chat_mode" not in st.session_state:
         st.session_state["chat_mode"] = "System"
     st.session_state["chat_mode"] = st.radio(
         "How should explanations be generated?",
         ["System", "LLM (Natural language)"],
-        index=0 if "System" in st.session_state["chat_mode"] else 1,
+        index=0 if st.session_state["chat_mode"] == "System" else 1,
         horizontal=True,
     )
 
-    # Show previous messages
     for role, content in st.session_state["messages"]:
         with st.chat_message(role):
             st.markdown(content)
 
-    # Chat input
     if user_q := st.chat_input("Ask e.g. 'Why this prediction?' or 'Increase petal length by 0.3' or 'set sepal width to 3.8'"):
         st.session_state["messages"].append(("user", user_q))
         with st.chat_message("user"):
@@ -403,7 +432,6 @@ if st.session_state["prediction_result"]:
 
         intent = interpret_question(user_q)
 
-        # Current base result
         base_res = st.session_state["prediction_result"]
         base_pred = base_res["prediction"]["y_pred"]
         base_proba = base_res["prediction"]["proba"]
@@ -411,7 +439,7 @@ if st.session_state["prediction_result"]:
 
         if intent == "explain":
             if "LLM" in st.session_state["chat_mode"]:
-                with st.spinner(" Generating LLM explanation..."):
+                with st.spinner("Generating LLM explanation (Nova)..."):
                     answer = llm_explain(base_res, feat_names)
                 st.session_state["messages"].append(("assistant", answer))
                 with st.chat_message("assistant"):
@@ -428,7 +456,6 @@ if st.session_state["prediction_result"]:
                     st.markdown(text)
 
         elif intent == "what_if":
-            # Parse & apply change, recompute, compare
             if st.session_state["input_features"] is None:
                 msg = "Please run a prediction first (Step 1) so I know your starting feature values."
                 st.session_state["messages"].append(("assistant", msg))
@@ -452,15 +479,14 @@ if st.session_state["prediction_result"]:
                     ctx = {
                         "change_applied": status,
                         "before": {"features": st.session_state["input_features"], "label": base_label, "proba": base_proba},
-                        "after": {"features": new_feats, "label": new_label, "proba": new_proba}
+                        "after": {"features": new_feats, "label": new_label, "proba": new_proba},
                    }
-                    with st.spinner(" Summarizing the effect with LLM..."):
+                    with st.spinner("Summarizing the effect (Nova)..."):
                         answer = llm_explain(new_res, feat_names, extra_context=ctx)
                     st.session_state["messages"].append(("assistant", answer))
                     with st.chat_message("assistant"):
                         st.markdown(answer)
                 else:
-                    # System comparison summary + new SHAP chart
                     lines = [
                         f"**Change applied:** {status}",
                         f"**Before:** {base_label} (class `{base_pred}`) — confidence **{base_proba:.2f}**",
@@ -473,7 +499,6 @@
                     st.session_state["messages"].append(("assistant", "What-if comparison + SHAP shown above."))
 
         else:
-            # Summary fallback
            summary = summarize_prediction(base_res)
            st.session_state["messages"].append(("assistant", summary))
            with st.chat_message("assistant"):
 
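Note on the new auth path: this commit moves from ambient AWS credentials to a Bedrock API key read from `AWS_BEARER_TOKEN_BEDROCK`, so it is worth smoke-testing the Converse call outside Streamlit before deploying the Space. A minimal sketch under the same client settings as `make_bedrock_client()` above; it assumes the env vars from the diff are set and a boto3 version recent enough to honor Bedrock API keys (the commit does not pin one):

    import os
    import boto3
    from botocore.config import Config

    # Same client settings as make_bedrock_client() in the diff above.
    client = boto3.client(
        "bedrock-runtime",
        region_name=os.getenv("AWS_REGION", "us-east-1"),
        config=Config(read_timeout=60, connect_timeout=60, retries={"max_attempts": 3}),
    )

    # Mirrors the guard in bedrock_llm(): without the secret, converse() cannot authenticate.
    if not os.getenv("AWS_BEARER_TOKEN_BEDROCK"):
        raise SystemExit("Set AWS_BEARER_TOKEN_BEDROCK first.")

    resp = client.converse(
        modelId=os.getenv("NOVA_MODEL_ID", "us.amazon.nova-lite-v1:0"),
        messages=[{"role": "user", "content": [{"text": "Reply with OK."}]}],
    )
    print(resp["output"]["message"]["content"][0]["text"])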
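For the tightened `apply_what_if()` parser, a usage sketch of how a compound query flows through the four rules. The feature names are the sidebar defaults; the starting values are a hypothetical iris sample, and the fuzzy matching of feature fragments depends on helpers outside the hunks shown here:

    feat_names = ["sepal length (cm)", "sepal width (cm)", "petal length (cm)", "petal width (cm)"]
    current = [5.1, 3.5, 1.4, 0.2]

    # The query is split on "," / " and "; each part is then matched in order against
    # rule 1 (set to value), rule 2 (by absolute amount), rule 3 (by percent),
    # and rule 4 (direction only, scaled by modifiers like "slightly").
    new_feats, change_text = apply_what_if(
        "increase sepal length to 5.8 and decrease petal width by 10%",
        feat_names,
        current,
    )
    # Expect one "Set **sepal length (cm)** to **5.80**." line from rule 1 and one
    # percent-decrease line from rule 3 in change_text.
    print(change_text)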