tiffany101 committed on
Commit
bf4c1bc
·
verified ·
1 Parent(s): 2f8408d

Update app/app_chat.py

Browse files
Files changed (1) hide show
  1. app/app_chat.py +275 -254
app/app_chat.py CHANGED
@@ -1,430 +1,451 @@
1
  import re
2
  import json
3
- import math
4
  import matplotlib.pyplot as plt
5
  import streamlit as st
6
  import requests
7
- from openai import OpenAI
 
 
 
8
 
9
- #titles
10
- st.set_page_config(page_title="Chat with Your Model", page_icon="πŸ’¬", layout="wide")
11
- st.title("πŸ’¬ Chat with Your Model")
12
 
 
 
13
 
14
- # ----------------------------------
15
  # Sidebar configuration
16
- # ----------------------------------
17
  with st.sidebar:
18
  st.header("Settings")
19
- api_url = st.text_input("FastAPI endpoint", value="http://127.0.0.1:8000/explain")
 
 
 
 
20
  model_path = st.text_input("Model path", value="Query_Your_Model/model_data/model.pkl")
 
21
  feat_names_str = st.text_input(
22
  "Feature names (comma-separated)",
23
- value="sepal length (cm),sepal width (cm),petal length (cm),petal width (cm)"
24
  )
 
25
  namespace = st.text_input("Namespace", value="Query_Your_Model/data/base_indices/iris_global")
26
- alpha = st.slider("Alpha (retrieval weight)", 0.0, 1.0, 0.7, 0.05)
 
 
 
27
  k = st.slider("Top-K similar to retrieve", 1, 10, 5)
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  feat_names = [s.strip() for s in feat_names_str.split(",")]
30
 
31
- # OpenAI client (set OPENAI_API_KEY in .streamlit/secrets.toml)
32
- client = OpenAI(api_key="")
 
33
 
34
- # ------------------------------------------------
35
  # Helpers
36
- # ------------------------------------------------
37
  def label_from_pred(y_pred):
38
  try:
39
  num = int(round(float(y_pred)))
40
- mapping = {0: "setosa", 1: "versicolor", 2: "virginica"}
41
- return mapping.get(num, str(y_pred))
42
- except Exception:
43
- return str(y_pred)
44
-
45
- def safe_topk_list(res):
46
- return res.get("explanation", {}).get("topk", []) or []
47
-
48
- def safe_similar_cases(res):
49
- return res.get("similar_cases", []) or []
50
-
51
- def summarize_prediction(res):
52
- pred = res["prediction"]["y_pred"]
53
- proba = res["prediction"]["proba"]
54
- label = label_from_pred(pred)
55
- return f"🌸 The model predicts **{label}** (class `{pred}`) with confidence **{proba:.2f}**."
56
 
57
  def show_similar_cases(res, n_display, feat_names):
58
  sims = safe_similar_cases(res)
59
- if not sims:
60
  return "No similar cases were retrieved."
61
  n = min(n_display, len(sims))
62
- lines = [f"It found **{len(sims)}** similar past cases (showing **{n}**):"]
63
  for case in sims[:n]:
64
  features_named = ", ".join([f"{name} = {val:.2f}" for name, val in zip(feat_names, case["features"])])
65
  lines.append(f"- **{case['case_id']}** β†’ {features_named}, predicted as **{case['y_pred']}**.")
66
- #The above code was generated by Chatgpt on 11/23/2025 at 3pm.
67
  return "\n".join(lines)
68
 
69
  def plot_shap_bar(topk):
70
- if not topk:
71
- st.write("No SHAP details available for this prediction.")
72
- return
73
  feats = [f["feature"] for f in topk]
74
  shap_vals = [f["shap"] for f in topk]
75
  fig, ax = plt.subplots()
76
- ax.barh(feats, shap_vals) # default colors per instructions
77
  ax.set_xlabel("SHAP value (impact on prediction)")
78
  ax.set_title("Feature importance for this prediction")
79
  st.pyplot(fig)
80
 
81
- def explain_in_words(res, n_display, feat_names):
82
- pred = res["prediction"]["y_pred"]
83
- proba = res["prediction"]["proba"]
84
- label = label_from_pred(pred)
85
- topk = safe_topk_list(res)
86
-
87
  msg = [
88
  f"🌸 Based on these features, the model thinks it's **{label}** (class `{pred}`) with confidence **{proba:.2f}**.\n",
89
- "### Key reasons (SHAP):"
90
  ]
91
 
92
  if topk:
93
- for f in topk:
94
- effect = "increased" if f["shap"] > 0 else "decreased"
95
- msg.append(
96
- f"- **{f['feature']} = {f['value']:.2f}** β†’ {effect} the prediction "
97
- f"(impact **{abs(f['shap']):.2f}**)."
98
- )
99
- else:
100
- msg.append("- No SHAP details available.")
101
-
102
- sims = safe_similar_cases(res)
103
- if sims:
104
- n = min(n_display, len(sims))
105
- msg.append(f"\n### Similar cases (showing {n} of {len(sims)})")
106
  for case in sims[:n]:
107
  features_named = ", ".join([f"{name} = {val:.2f}" for name, val in zip(feat_names, case["features"])])
108
  msg.append(f"- **{case['case_id']}** β†’ {features_named}, predicted as **{case['y_pred']}**.")
109
- #The above code was generated by Chatgpt on 11/23/2025 at 4pm.
110
 
111
  st.markdown("\n".join(msg))
112
  plot_shap_bar(topk)
113
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  def llm_explain(res, feat_names, extra_context=None):
115
- """LLM explanation: can handle 'why' and 'what-if' using the provided context (old/new)."""
116
  try:
117
  pred = label_from_pred(res["prediction"]["y_pred"])
118
  proba = res["prediction"]["proba"]
119
- topk = safe_topk_list(res)
120
- sims = safe_similar_cases(res)
121
-
122
- base_prompt = {
123
- "prediction": pred,
124
  "probability": round(proba, 3),
125
  "topk": topk,
126
  "similar_examples_sample": sims[:3],
127
- "extra_context": extra_context or {}
 
128
  }
129
 
130
  prompt = (
131
  "You are an explainability copilot. Explain to a non-technical user.\n\n"
132
  f"DATA:\n{json.dumps(base_prompt, indent=2)}\n\n"
133
  "Write a short, clear answer that covers:\n"
134
- "- Why the model made the prediction\n"
135
- "- Which features mattered\n"
 
 
136
  )
137
 
138
- response = client.chat.completions.create(
139
- model="gpt-4o-mini",
140
- messages=[{"role": "user", "content": prompt}],
141
- )
142
- return response.choices[0].message.content
143
  except Exception as e:
144
  return f"LLM explanation failed: {e}"
145
 
146
- def interpret_question(user_q):
147
- q = (user_q or "").lower()
148
- if any(w in q for w in ["what if", "increase", "decrease", "set ", "make ", "higher", "lower", "raise", "reduce", "change"]):
149
- return "what_if"
150
- if any(w in q for w in ["why", "explain", "reason"]):
151
- return "explain"
152
- if "similar" in q:
153
- return "similar"
154
- return "summary"
 
 
 
 
 
 
 
 
 
 
155
 
156
- def perform_api_call(features):
157
- payload = {
158
- "model_path": model_path,
159
- "feature_names": feat_names,
160
- "features": features,
161
- "namespace": namespace,
162
- "retrieval": {"alpha": alpha, "k": k, "use_retrieval": True, "namespace": namespace},
163
- }
164
- return requests.post(api_url, json=payload).json()
165
 
166
- # --- What-if parsing ---
167
 
168
  FEATURE_NAME_PAT = re.compile(r"([a-zA-Z][a-zA-Z0-9 _\-\(\)]*)")
169
 
170
  def match_feature_name(fragment, feat_names):
171
- """Best-effort fuzzy match: choose feature with max token overlap (case-insensitive)."""
172
  frag = fragment.strip().lower()
173
  best, best_score = None, -1
174
  for name in feat_names:
175
- n = name.lower()
176
- score = sum(tok in n for tok in re.findall(r"[a-z0-9]+", frag))
177
- if score > best_score:
178
- best, best_score = name, score
179
- return best if best_score > 0 else None
180
-
181
- def parse_numeric(val_str):
182
- try:
183
- return float(val_str)
184
- except:
185
  return None
186
 
187
  def default_delta(curr):
188
- # 10% of current value (min 0.1 abs) as a sensible default tweak
189
  base = abs(curr) * 0.10
190
  return base if base >= 0.1 else (0.1 if curr >= 0 else -0.1)
191
 
192
  def apply_what_if(user_q, feat_names, current_features):
193
- """
194
- Returns (new_features, change_text) or (None, error_msg)
195
- Supports:
196
- - "what if petal length (cm) = 2.5"
197
- - "set sepal width to 3.8"
198
- - "increase petal width by 0.2"
199
- - "decrease sepal length by 10%"
200
- - "make petal length higher" (uses default +10%)
201
- - "make petal width lower" (uses default -10%)
202
- """
203
  q = user_q.lower()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
 
205
- # 1) Direct set: "= X" or "to X"
206
- m = re.search(r"(?:set|what if|change)\s+(.*?)\s*(?:=|to)\s*([-+]?\d*\.?\d+)", q)
207
- if m:
208
- feat_frag, val_str = m.group(1), m.group(2)
209
- fname = match_feature_name(feat_frag, feat_names)
210
- if fname is None:
211
- return None, f"Couldn't identify which feature to set from: '{feat_frag}'."
212
- val = parse_numeric(val_str)
213
- if val is None:
214
- return None, f"Couldn't parse a number from: '{val_str}'."
215
- new = current_features.copy()
216
- idx = feat_names.index(fname)
217
- new[idx] = val
218
- return new, f"Set **{fname}** to **{val:.2f}**."
219
- #The above code was generated by Chatgpt on 11/23/2025 at 4:30pm.
220
-
221
- # 2) increase/decrease by absolute amount: "increase X by 0.2"
222
- m = re.search(r"(increase|decrease|raise|reduce)\s+(.*?)\s+by\s+([-+]?\d*\.?\d+)\b(?!%)", q)
223
- if m:
224
- op, feat_frag, val_str = m.groups()
225
- fname = match_feature_name(feat_frag, feat_names)
226
- if fname is None:
227
- return None, f"Couldn't identify which feature to adjust from: '{feat_frag}'."
228
- delta = parse_numeric(val_str)
229
- if delta is None:
230
- return None, f"Couldn't parse a number from: '{val_str}'."
231
- if op in ["decrease", "reduce"]:
232
- delta = -abs(delta)
233
- else:
234
- delta = abs(delta)
235
- new = current_features.copy()
236
- idx = feat_names.index(fname)
237
- new_val = new[idx] + delta
238
- new[idx] = new_val
239
- return new, f"{'Increased' if delta>0 else 'Decreased'} **{fname}** by **{abs(delta):.2f}** β†’ **{new_val:.2f}**."
240
-
241
- # 3) increase/decrease by percent: "decrease X by 10%"
242
- m = re.search(r"(increase|decrease|raise|reduce)\s+(.*?)\s+by\s+([-+]?\d*\.?\d+)\s*%", q)
243
- if m:
244
- op, feat_frag, perc_str = m.groups()
245
- fname = match_feature_name(feat_frag, feat_names)
246
- if fname is None:
247
- return None, f"Couldn't identify which feature to adjust from: '{feat_frag}'."
248
- perc = parse_numeric(perc_str)
249
- if perc is None:
250
- return None, f"Couldn't parse a percentage from: '{perc_str}'."
251
- new = current_features.copy()
252
- idx = feat_names.index(fname)
253
- factor = 1.0 + (abs(perc)/100.0 if op in ["increase","raise"] else -abs(perc)/100.0)
254
- new_val = new[idx] * factor
255
- new[idx] = new_val
256
- return new, f"{op.title()}d **{fname}** by **{abs(perc):.0f}%** β†’ **{new_val:.2f}**."
257
-
258
- # 4) make X higher/lower (no amount) → default ±10%
259
- m = re.search(r"(make|set)?\s*(.*?)\s*(higher|lower|increase|decrease|raise|reduce)", q)
260
- if m:
261
- _, feat_frag, direction = m.groups()
262
- fname = match_feature_name(feat_frag, feat_names)
263
- if fname is None:
264
- return None, f"Couldn't identify which feature to adjust from: '{feat_frag}'."
265
- idx = feat_names.index(fname)
266
- base_delta = default_delta(current_features[idx])
267
- delta = base_delta if direction in ["higher", "increase", "raise"] else -abs(base_delta)
268
- new = current_features.copy()
269
- new_val = new[idx] + delta
270
- new[idx] = new_val
271
- verb = "Increased" if delta>0 else "Decreased"
272
- return new, f"{verb} **{fname}** by **{abs(delta):.2f}** (default) β†’ **{new_val:.2f}**."
273
-
274
- return None, "I couldn’t parse the change. Try: `what if petal length = 2.5`, `increase sepal width by 0.2`, or `decrease petal width by 10%`."
275
-
276
- # ------------------------------------------------
277
- # Step 1: Enter features & predict
278
- # ------------------------------------------------
279
- st.subheader("Step 1 – Enter features to generate a prediction")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
280
  user_features = st.text_input("Enter feature values (comma-separated)", "")
281
  predict_btn = st.button("πŸ” Predict and Explain")
282
 
283
- if "prediction_result" not in st.session_state:
284
- st.session_state["prediction_result"] = None
285
  st.session_state["messages"] = []
286
  st.session_state["input_features"] = None
287
 
 
 
 
288
  if predict_btn:
289
  try:
290
  features = [float(x.strip()) for x in user_features.split(",") if x.strip()]
291
- if len(features) != len(feat_names):
292
  st.warning(f"Expected {len(feat_names)} values ({', '.join(feat_names)}), but got {len(features)}.")
293
  else:
294
  st.session_state["input_features"] = features
295
- # Show entered features
296
  st.markdown("### ✨ Entered Features")
297
  st.markdown("\n".join([f"- **{n}** = {v:.2f}" for n, v in zip(feat_names, features)]))
298
- # Call API
299
  res = perform_api_call(features)
300
  st.session_state["prediction_result"] = res
301
  st.session_state["messages"] = []
302
  st.success(summarize_prediction(res))
303
- st.info("Scroll down to explore similar cases or chat πŸ‘‡")
304
  except Exception as e:
305
  st.error(f"Error contacting API: {e}")
306
 
307
- # ------------------------------------------------
308
- # Step 2: Similar cases
309
- # ------------------------------------------------
310
  if st.session_state["prediction_result"]:
311
  st.divider()
312
  st.subheader("Step 2 – Explore similar cases")
313
- res = st.session_state["prediction_result"]
314
- sims = safe_similar_cases(res)
315
- total_cases = len(sims)
316
-
317
- if total_cases > 0:
318
- options = [str(i) for i in range(1, total_cases + 1)] + ["All"]
319
- chosen = st.selectbox(f"The model found {total_cases} similar cases. How many to view?", options)
320
- n_display = total_cases if chosen == "All" else int(chosen)
321
- st.markdown(f"### Showing {n_display} of {total_cases} similar cases:")
322
- for i, case in enumerate(sims[:n_display], start=1):
323
- features_named = ", ".join([f"{n} = {v:.2f}" for n, v in zip(feat_names, case["features"])])
324
- st.markdown(f"**Case {i} β€” {case['case_id']}** \n{features_named} \n**Predicted as:** {case['y_pred']}")
325
  else:
326
  st.write("No similar cases retrieved.")
327
 
328
- # ------------------------------------------------
329
- # Step 3: Explanation Mode + Chat
330
- # ------------------------------------------------
331
  st.divider()
332
  st.subheader("Step 3 – Chat with the model about this prediction")
333
 
334
- # Choose explanation mode BEFORE asking questions
335
  if "chat_mode" not in st.session_state:
336
- st.session_state["chat_mode"] = "System (SHAP-based)"
 
337
  st.session_state["chat_mode"] = st.radio(
338
  "How should explanations be generated?",
339
- ["System (SHAP-based)", "LLM (Natural language)"],
340
- index=0 if "System" in st.session_state["chat_mode"] else 1,
341
  horizontal=True,
342
  )
343
 
344
- # Show previous messages
345
  for role, content in st.session_state["messages"]:
346
  with st.chat_message(role):
347
  st.markdown(content)
348
 
349
- # Chat input
350
  if user_q := st.chat_input("Ask e.g. 'Why this prediction?' or 'Increase petal length by 0.3' or 'set sepal width to 3.8'"):
351
  st.session_state["messages"].append(("user", user_q))
352
  with st.chat_message("user"):
353
- st.markdown(user_q)
354
 
355
  intent = interpret_question(user_q)
356
 
357
- # Current base result
358
  base_res = st.session_state["prediction_result"]
359
  base_pred = base_res["prediction"]["y_pred"]
360
  base_proba = base_res["prediction"]["proba"]
361
- base_label = label_from_pred(base_pred)
362
 
363
  if intent == "explain":
364
  if "LLM" in st.session_state["chat_mode"]:
365
- with st.spinner("πŸ’‘ Generating LLM explanation..."):
366
  answer = llm_explain(base_res, feat_names)
367
  st.session_state["messages"].append(("assistant", answer))
368
  with st.chat_message("assistant"):
369
- st.markdown(answer)
370
- else:
371
- with st.chat_message("assistant"):
372
- explain_in_words(base_res, total_cases, feat_names)
373
- st.session_state["messages"].append(("assistant", "System explanation shown above."))
374
-
375
- elif intent == "similar":
376
- text = show_similar_cases(base_res, total_cases, feat_names)
377
- st.session_state["messages"].append(("assistant", text))
378
- with st.chat_message("assistant"):
379
  st.markdown(text)
380
 
381
  elif intent == "what_if":
382
- # Parse & apply change, recompute, compare
383
  if st.session_state["input_features"] is None:
384
  msg = "Please run a prediction first (Step 1) so I know your starting feature values."
385
  st.session_state["messages"].append(("assistant", msg))
386
- with st.chat_message("assistant"):
387
- st.markdown(msg)
388
- else:
389
- new_feats, status = apply_what_if(user_q, feat_names, st.session_state["input_features"])
390
- if new_feats is None:
391
- st.session_state["messages"].append(("assistant", status))
392
  with st.chat_message("assistant"):
393
  st.markdown(status)
394
  else:
395
- with st.spinner("πŸ” Recomputing with your change..."):
396
  new_res = perform_api_call(new_feats)
397
 
398
  new_pred = new_res["prediction"]["y_pred"]
399
- new_proba = new_res["prediction"]["proba"]
400
- new_label = label_from_pred(new_pred)
401
-
402
- if "LLM" in st.session_state["chat_mode"]:
403
  ctx = {
404
  "change_applied": status,
405
  "before": {"features": st.session_state["input_features"], "label": base_label, "proba": base_proba},
406
- "after": {"features": new_feats, "label": new_label, "proba": new_proba}
407
  }
408
- with st.spinner("πŸ’‘ Summarizing the effect with LLM..."):
409
  answer = llm_explain(new_res, feat_names, extra_context=ctx)
410
  st.session_state["messages"].append(("assistant", answer))
411
  with st.chat_message("assistant"):
412
  st.markdown(answer)
413
  else:
414
- # System comparison summary + new SHAP chart
415
  lines = [
416
  f"**Change applied:** {status}",
417
  f"**Before:** {base_label} (class `{base_pred}`) β€” confidence **{base_proba:.2f}**",
418
- f"**After:** {new_label} (class `{new_pred}`) β€” confidence **{new_proba:.2f}**",
419
- ]
420
- with st.chat_message("assistant"):
421
- st.markdown("\n\n".join(lines))
422
- st.markdown("**New explanation (SHAP) for the changed input:**")
423
- explain_in_words(new_res, len(safe_similar_cases(new_res)), feat_names)
424
  st.session_state["messages"].append(("assistant", "What-if comparison + SHAP shown above."))
425
 
426
  else:
427
- # Summary fallback
428
  summary = summarize_prediction(base_res)
429
  st.session_state["messages"].append(("assistant", summary))
430
  with st.chat_message("assistant"):
 
1
  import re
2
  import json
3
+
4
  import matplotlib.pyplot as plt
5
  import streamlit as st
6
  import requests
7
+ import os
8
+
9
+ from litellm import completion # pip install litellm
10
+
11
 
 
 
 
12
 
13
+ st.set_page_config(page_title="EchoML", page_icon="πŸ’¬", layout="wide")
14
+ st.title("πŸ’¬ Chat with Your Model (IRIS Edition)")
15
 
16
+ # -----------------------------
17
  # Sidebar configuration
18
+ # -----------------------------
19
  with st.sidebar:
20
  st.header("Settings")
21
+
22
+ api_url = st.text_input(
23
+ "FastAPI endpoint",
24
+ value="https://query-your-model-api-784882848382.us-central1.run.app/explain",
25
+ )
26
  model_path = st.text_input("Model path", value="Query_Your_Model/model_data/model.pkl")
27
+
28
  feat_names_str = st.text_input(
29
  "Feature names (comma-separated)",
30
+ value="sepal length (cm),sepal width (cm),petal length (cm),petal width (cm)",
31
  )
32
+
33
  namespace = st.text_input("Namespace", value="Query_Your_Model/data/base_indices/iris_global")
34
+
35
+ # IMPORTANT: your retrieval.py shows:
36
+ # similarity = alpha * cos(SHAP) + (1 - alpha) * cos(features)
37
+ alpha = st.slider("Alpha (retrieval weight: SHAP vs features)", 0.0, 1.0, 0.7, 0.05)
38
  k = st.slider("Top-K similar to retrieve", 1, 10, 5)
39
 
40
+ st.divider()
41
+ st.subheader("Nova (LLM) Settings")
42
+
43
+ # Model IDs based on your Nova console screenshot
44
+ nova_model_id = st.selectbox(
45
+ "Nova model",
46
+ options=["nova-micro-v1", "nova-lite-v1", "nova-pro-v1", "nova-premier-v1"],
47
+ index=1,
48
+ help="These are the Nova model IDs from the Nova developer console.",
49
+ )
50
+
51
+ temperature = st.slider("LLM temperature", 0.0, 1.0, 0.2, 0.05)
52
+ max_tokens = st.slider("LLM max tokens", 64, 1024, 350, 32)
53
+
54
  feat_names = [s.strip() for s in feat_names_str.split(",")]
55
 
56
+ # -----------------------------
57
+
58
+
59
 
 
60
  # Helpers
61
+ # -----------------------------
62
  def label_from_pred(y_pred):
63
  try:
64
  num = int(round(float(y_pred)))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
  def show_similar_cases(res, n_display, feat_names):
67
  sims = safe_similar_cases(res)
68
+ if not sims:
69
  return "No similar cases were retrieved."
70
  n = min(n_display, len(sims))
71
+ lines = [f"It found **{len(sims)}** similar reference cases (showing **{n}**):"]
72
  for case in sims[:n]:
73
  features_named = ", ".join([f"{name} = {val:.2f}" for name, val in zip(feat_names, case["features"])])
74
  lines.append(f"- **{case['case_id']}** β†’ {features_named}, predicted as **{case['y_pred']}**.")
75
+
76
  return "\n".join(lines)
77
 
78
  def plot_shap_bar(topk):
 
 
 
79
  feats = [f["feature"] for f in topk]
80
  shap_vals = [f["shap"] for f in topk]
81
  fig, ax = plt.subplots()
82
+ ax.barh(feats, shap_vals) # default colors
83
  ax.set_xlabel("SHAP value (impact on prediction)")
84
  ax.set_title("Feature importance for this prediction")
85
  st.pyplot(fig)
86
 
 
 
 
 
 
 
87
  msg = [
88
  f"🌸 Based on these features, the model thinks it's **{label}** (class `{pred}`) with confidence **{proba:.2f}**.\n",
89
+ "### Key reasons (SHAP):",
90
  ]
91
 
92
  if topk:
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  for case in sims[:n]:
94
  features_named = ", ".join([f"{name} = {val:.2f}" for name, val in zip(feat_names, case["features"])])
95
  msg.append(f"- **{case['case_id']}** β†’ {features_named}, predicted as **{case['y_pred']}**.")
96
+
97
 
98
  st.markdown("\n".join(msg))
99
  plot_shap_bar(topk)
100
 
101
+ def interpret_question(user_q):
102
+ q = (user_q or "").lower()
103
+ if any(w in q for w in ["what if", "increase", "decrease", "set ", "make ", "higher", "lower", "raise", "reduce", "change"]):
104
+ return "what_if"
105
+ if any(w in q for w in ["why", "explain", "reason"]):
106
+ return "explain"
107
+ if "similar" in q:
108
+ return "similar"
109
+ return "summary"
110
+
111
+ def perform_api_call(features):
112
+ payload = {
113
+ "model_path": model_path,
114
+ "feature_names": feat_names,
115
+ "features": features,
116
+ "namespace": namespace,
117
+ "retrieval": {"alpha": alpha, "k": k, "use_retrieval": True, "namespace": namespace},
118
+ }
119
+ r = requests.post(api_url, json=payload, timeout=60)
120
+ r.raise_for_status()
121
+ return r.json()
122
+
123
+ # -----------------------------
124
+ # Nova LLM (via LiteLLM)
125
+ # -----------------------------
126
+ def nova_llm_text(prompt: str) -> str:
127
+ """
128
+ Uses Amazon Nova API key from env: AMAZON_NOVA_API_KEY
129
+ Model route: amazon_nova/<model_id>
130
+ """
131
+ api_key = os.getenv("AMAZON_NOVA_API_KEY")
132
+ if not api_key:
133
+ return (
134
+ "Nova API key not found. In Hugging Face β†’ Settings β†’ Variables and secrets, "
135
+ "add a *Secret* named `AMAZON_NOVA_API_KEY`."
136
+ )
137
+
138
+ # LiteLLM expects the provider key in env or passed; env is simplest for Spaces
139
+ os.environ["AMAZON_NOVA_API_KEY"] = api_key
140
+
141
+ try:
142
+ resp = completion(
143
+ model=f"amazon_nova/{nova_model_id}",
144
+ messages=[{"role": "user", "content": prompt}],
145
+ temperature=temperature,
146
+ max_tokens=max_tokens,
147
+ )
148
+ return resp.choices[0].message.content
149
+ except Exception as e:
150
+ return f"LLM explanation failed: {e}"
151
+
152
  def llm_explain(res, feat_names, extra_context=None):
153
+
154
  try:
155
  pred = label_from_pred(res["prediction"]["y_pred"])
156
  proba = res["prediction"]["proba"]
 
 
 
 
 
157
  "probability": round(proba, 3),
158
  "topk": topk,
159
  "similar_examples_sample": sims[:3],
160
+ "extra_context": extra_context or {},
161
+ "retrieval_note": "alpha=1.0 prioritizes SHAP similarity; alpha=0.0 prioritizes feature similarity.",
162
  }
163
 
164
  prompt = (
165
  "You are an explainability copilot. Explain to a non-technical user.\n\n"
166
  f"DATA:\n{json.dumps(base_prompt, indent=2)}\n\n"
167
  "Write a short, clear answer that covers:\n"
168
+ "- Why the model made the prediction (grounded in SHAP)\n"
169
+ "- Which features mattered most\n"
170
+ "- Why those features mattered\n"
171
+ "- 2-3 concrete experiments: tell them which feature values to increase/decrease and what to watch for"
172
  )
173
 
174
+ return nova_llm_text(prompt)
175
+
176
+
177
+
178
+
179
  except Exception as e:
180
  return f"LLM explanation failed: {e}"
181
 
182
+ # -----------------------------
183
+ # What-if parsing
184
+ # -----------------------------
185
+
186
+
187
+
188
+
189
+
190
+
191
+
192
+
193
+
194
+
195
+
196
+
197
+
198
+
199
+
200
+
201
 
 
 
 
 
 
 
 
 
 
202
 
 
203
 
204
  FEATURE_NAME_PAT = re.compile(r"([a-zA-Z][a-zA-Z0-9 _\-\(\)]*)")
205
 
206
  def match_feature_name(fragment, feat_names):
207
+
208
  frag = fragment.strip().lower()
209
  best, best_score = None, -1
210
  for name in feat_names:
 
 
 
 
 
 
 
 
 
 
211
  return None
212
 
213
  def default_delta(curr):
214
+
215
  base = abs(curr) * 0.10
216
  return base if base >= 0.1 else (0.1 if curr >= 0 else -0.1)
217
 
218
  def apply_what_if(user_q, feat_names, current_features):
219
+
220
+
221
+
222
+
223
+
224
+
225
+
226
+
227
+
228
+
229
  q = user_q.lower()
230
+ new = current_features.copy()
231
+ changes = []
232
+
233
+ parts = re.split(r",| and ", q)
234
+
235
+ modifier_scale = {
236
+ "slightly": 0.5,
237
+ "a bit": 0.5,
238
+ "a little": 0.5,
239
+ "somewhat": 0.7,
240
+ "moderately": 1.0,
241
+ "significantly": 1.5,
242
+ "greatly": 2.0,
243
+ "a lot": 2.0,
244
+ }
245
 
246
+ for part in parts:
247
+ part = part.strip()
248
+ if not part:
249
+ continue
250
+
251
+ scale = 1.0
252
+ for mod, factor in modifier_scale.items():
253
+ if mod in part:
254
+ scale = factor
255
+ part = part.replace(mod, "")
256
+ break
257
+
258
+ # set to value
259
+ m = re.search(r"(?:set|what if|change|increase|decrease|raise|reduce)\s+(.*?)\s*(?:=|to)\s*([-+]?\d*\.?\d+)", part)
260
+ if m:
261
+ feat_frag, val_str = m.group(1), m.group(2)
262
+ fname = match_feature_name(feat_frag, feat_names)
263
+ if fname is None:
264
+ changes.append(f"Couldn't identify which feature to set from: '{feat_frag}'.")
265
+ continue
266
+ val = parse_numeric(val_str)
267
+ if val is None:
268
+ changes.append(f"Couldn't parse a number from: '{val_str}'.")
269
+ continue
270
+ idx = feat_names.index(fname)
271
+ new[idx] = val
272
+ changes.append(f"Set **{fname}** to **{val:.2f}**.")
273
+ continue
274
+
275
+ # +/- absolute
276
+ m = re.search(r"(increase|decrease|raise|reduce)\s+(.*?)\s+by\s+([-+]?\d*\.?\d+)\b(?!%)", part)
277
+ if m:
278
+ op, feat_frag, val_str = m.groups()
279
+ fname = match_feature_name(feat_frag, feat_names)
280
+ if fname is None:
281
+ changes.append(f"Couldn't identify which feature to adjust from: '{feat_frag}'.")
282
+ continue
283
+ delta = parse_numeric(val_str)
284
+ if delta is None:
285
+ changes.append(f"Couldn't parse a number from: '{val_str}'.")
286
+ continue
287
+ delta *= scale
288
+ delta = -abs(delta) if op in ["decrease", "reduce"] else abs(delta)
289
+ idx = feat_names.index(fname)
290
+ new[idx] = new[idx] + delta
291
+ changes.append(f"{'Increased' if delta>0 else 'Decreased'} **{fname}** by **{abs(delta):.2f}** β†’ **{new[idx]:.2f}**.")
292
+ continue
293
+
294
+ # +/- percent
295
+ m = re.search(r"(increase|decrease|raise|reduce)\s+(.*?)\s+by\s+([-+]?\d*\.?\d+)\s*%", part)
296
+ if m:
297
+ op, feat_frag, perc_str = m.groups()
298
+ fname = match_feature_name(feat_frag, feat_names)
299
+ if fname is None:
300
+ changes.append(f"Couldn't identify which feature to adjust from: '{feat_frag}'.")
301
+ continue
302
+ perc = parse_numeric(perc_str)
303
+ if perc is None:
304
+ changes.append(f"Couldn't parse a percentage from: '{perc_str}'.")
305
+ continue
306
+ perc *= scale
307
+ idx = feat_names.index(fname)
308
+ factor = 1.0 + (abs(perc)/100.0 if op in ["increase","raise"] else -abs(perc)/100.0)
309
+ new[idx] = new[idx] * factor
310
+ changes.append(f"{op.title()}d **{fname}** by **{abs(perc):.0f}%** β†’ **{new[idx]:.2f}**.")
311
+ continue
312
+
313
+ # make higher/lower
314
+ m = re.search(r"(make|set)?\s*(.*?)\s*(higher|lower|increase|decrease|raise|reduce)", part)
315
+ if m:
316
+ _, feat_frag, direction = m.groups()
317
+ fname = match_feature_name(feat_frag, feat_names)
318
+ if fname is None:
319
+ changes.append(f"Couldn't identify which feature to adjust from: '{feat_frag}'.")
320
+ continue
321
+ idx = feat_names.index(fname)
322
+ base_delta = default_delta(current_features[idx]) * scale
323
+ delta = base_delta if direction in ["higher", "increase", "raise"] else -abs(base_delta)
324
+ new[idx] = new[idx] + delta
325
+ verb = "Increased" if delta > 0 else "Decreased"
326
+ changes.append(f"{verb} **{fname}** by **{abs(delta):.2f}** (scaled {scale:.1f}Γ—) β†’ **{new[idx]:.2f}**.")
327
+ continue
328
+
329
+ changes.append(f"Couldn't parse instruction: '{part}'.")
330
+
331
+ if not changes:
332
+ return None, "No valid feature changes detected."
333
+
334
+ return new, "\n".join(changes)
335
+
336
+ # -----------------------------
337
+ # App state init
338
+ # -----------------------------
339
+ st.subheader("Step 1 – Enter features to generate a prediction in the order: [sepal length, sepal width, petal length, petal width]")
340
  user_features = st.text_input("Enter feature values (comma-separated)", "")
341
  predict_btn = st.button("πŸ” Predict and Explain")
342
 
 
 
343
  st.session_state["messages"] = []
344
  st.session_state["input_features"] = None
345
 
346
+ # -----------------------------
347
+ # Step 1: Predict
348
+ # -----------------------------
349
  if predict_btn:
350
  try:
351
  features = [float(x.strip()) for x in user_features.split(",") if x.strip()]
 
352
  st.warning(f"Expected {len(feat_names)} values ({', '.join(feat_names)}), but got {len(features)}.")
353
  else:
354
  st.session_state["input_features"] = features
355
+
356
  st.markdown("### ✨ Entered Features")
357
  st.markdown("\n".join([f"- **{n}** = {v:.2f}" for n, v in zip(feat_names, features)]))
358
+
359
  res = perform_api_call(features)
360
  st.session_state["prediction_result"] = res
361
  st.session_state["messages"] = []
362
  st.success(summarize_prediction(res))
363
+ st.info("Scroll down to explore similar cases or chat.")
364
  except Exception as e:
365
  st.error(f"Error contacting API: {e}")
366
 
367
+ # -----------------------------
368
+ # Step 2 + Step 3
369
+ # -----------------------------
370
  if st.session_state["prediction_result"]:
371
  st.divider()
372
  st.subheader("Step 2 – Explore similar cases")
 
 
 
 
 
 
 
 
 
 
 
 
373
  else:
374
  st.write("No similar cases retrieved.")
375
 
376
+
377
+
378
+
379
  st.divider()
380
  st.subheader("Step 3 – Chat with the model about this prediction")
381
 
382
+
383
  if "chat_mode" not in st.session_state:
384
+ st.session_state["chat_mode"] = "System"
385
+
386
  st.session_state["chat_mode"] = st.radio(
387
  "How should explanations be generated?",
388
+ ["System", "LLM (Natural language)"],
389
+ index=0 if st.session_state["chat_mode"] == "System" else 1,
390
  horizontal=True,
391
  )
392
 
393
+
394
  for role, content in st.session_state["messages"]:
395
  with st.chat_message(role):
396
  st.markdown(content)
397
 
398
+
399
  if user_q := st.chat_input("Ask e.g. 'Why this prediction?' or 'Increase petal length by 0.3' or 'set sepal width to 3.8'"):
400
  st.session_state["messages"].append(("user", user_q))
401
  with st.chat_message("user"):
 
402
 
403
  intent = interpret_question(user_q)
404
 
405
+
406
  base_res = st.session_state["prediction_result"]
407
  base_pred = base_res["prediction"]["y_pred"]
408
  base_proba = base_res["prediction"]["proba"]
 
409
 
410
  if intent == "explain":
411
  if "LLM" in st.session_state["chat_mode"]:
412
+ with st.spinner("Generating Nova explanation..."):
413
  answer = llm_explain(base_res, feat_names)
414
  st.session_state["messages"].append(("assistant", answer))
415
  with st.chat_message("assistant"):
 
 
 
 
 
 
 
 
 
 
416
  st.markdown(text)
417
 
418
  elif intent == "what_if":
419
+
420
  if st.session_state["input_features"] is None:
421
  msg = "Please run a prediction first (Step 1) so I know your starting feature values."
422
  st.session_state["messages"].append(("assistant", msg))
 
 
 
 
 
 
423
  with st.chat_message("assistant"):
424
  st.markdown(status)
425
  else:
426
+ with st.spinner("Recomputing with your change..."):
427
  new_res = perform_api_call(new_feats)
428
 
429
  new_pred = new_res["prediction"]["y_pred"]
 
 
 
 
430
  ctx = {
431
  "change_applied": status,
432
  "before": {"features": st.session_state["input_features"], "label": base_label, "proba": base_proba},
433
+ "after": {"features": new_feats, "label": new_label, "proba": new_proba},
434
  }
435
+ with st.spinner("Summarizing the effect with Nova..."):
436
  answer = llm_explain(new_res, feat_names, extra_context=ctx)
437
  st.session_state["messages"].append(("assistant", answer))
438
  with st.chat_message("assistant"):
439
  st.markdown(answer)
440
  else:
441
+
442
  lines = [
443
  f"**Change applied:** {status}",
444
  f"**Before:** {base_label} (class `{base_pred}`) β€” confidence **{base_proba:.2f}**",
 
 
 
 
 
 
445
  st.session_state["messages"].append(("assistant", "What-if comparison + SHAP shown above."))
446
 
447
  else:
448
+
449
  summary = summarize_prediction(base_res)
450
  st.session_state["messages"].append(("assistant", summary))
451
  with st.chat_message("assistant"):