Spaces:
Build error
Build error
Update app/app_chat.py
Browse files- app/app_chat.py +94 -69
app/app_chat.py
CHANGED
|
@@ -8,42 +8,66 @@ import os
|
|
| 8 |
import boto3
|
| 9 |
from botocore.config import Config
|
| 10 |
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
"bedrock-runtime",
|
| 15 |
-
region_name=AWS_REGION,
|
| 16 |
-
config=Config(read_timeout=60, connect_timeout=60, retries={"max_attempts": 3}),
|
| 17 |
-
)
|
| 18 |
|
| 19 |
-
|
| 20 |
-
# Common examples include Nova Lite / Pro in Bedrock.
|
| 21 |
NOVA_MODEL_ID = os.getenv("NOVA_MODEL_ID", "us.amazon.nova-lite-v1:0")
|
| 22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
st.set_page_config(page_title="EchoML", page_icon="π¬", layout="wide")
|
| 26 |
-
st.title("π¬ Chat with Your Model(IRIS Edition)")
|
| 27 |
|
| 28 |
# Sidebar configuration
|
| 29 |
with st.sidebar:
|
| 30 |
st.header("Settings")
|
| 31 |
-
api_url = st.text_input(
|
| 32 |
-
|
|
|
|
|
|
|
|
|
|
| 33 |
feat_names_str = st.text_input(
|
| 34 |
"Feature names (comma-separated)",
|
| 35 |
-
value=
|
|
|
|
|
|
|
|
|
|
| 36 |
)
|
| 37 |
-
namespace = st.text_input("Namespace", value="Query_Your_Model/data/base_indices/iris_global")
|
| 38 |
alpha = st.slider("Alpha (retrieval weight)", 0.0, 1.0, 0.7, 0.05)
|
| 39 |
k = st.slider("Top-K similar to retrieve", 1, 10, 5)
|
| 40 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
-
feat_names = [s.strip() for s in feat_names_str.split(",")]
|
| 43 |
-
|
| 44 |
-
|
| 45 |
|
|
|
|
| 46 |
# Helpers
|
|
|
|
|
|
|
| 47 |
def label_from_pred(y_pred):
|
| 48 |
try:
|
| 49 |
num = int(round(float(y_pred)))
|
|
@@ -66,10 +90,10 @@ def summarize_prediction(res):
|
|
| 66 |
|
| 67 |
def show_similar_cases(res, n_display, feat_names):
|
| 68 |
sims = safe_similar_cases(res)
|
| 69 |
-
if not sims:
|
| 70 |
return "No similar cases were retrieved."
|
| 71 |
n = min(n_display, len(sims))
|
| 72 |
-
lines = [f"It found **{len(sims)}** similar
|
| 73 |
for case in sims[:n]:
|
| 74 |
features_named = ", ".join([f"{name} = {val:.2f}" for name, val in zip(feat_names, case["features"])])
|
| 75 |
lines.append(f"- **{case['case_id']}** β {features_named}, predicted as **{case['y_pred']}**.")
|
|
@@ -82,7 +106,7 @@ def plot_shap_bar(topk):
|
|
| 82 |
feats = [f["feature"] for f in topk]
|
| 83 |
shap_vals = [f["shap"] for f in topk]
|
| 84 |
fig, ax = plt.subplots()
|
| 85 |
-
ax.barh(feats, shap_vals) # default colors
|
| 86 |
ax.set_xlabel("SHAP value (impact on prediction)")
|
| 87 |
ax.set_title("Feature importance for this prediction")
|
| 88 |
st.pyplot(fig)
|
|
@@ -119,7 +143,29 @@ def explain_in_words(res, n_display, feat_names):
|
|
| 119 |
st.markdown("\n".join(msg))
|
| 120 |
plot_shap_bar(topk)
|
| 121 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
def llm_explain(res, feat_names, extra_context=None):
|
|
|
|
|
|
|
|
|
|
| 123 |
try:
|
| 124 |
pred = label_from_pred(res["prediction"]["y_pred"])
|
| 125 |
proba = res["prediction"]["proba"]
|
|
@@ -141,20 +187,14 @@ def llm_explain(res, feat_names, extra_context=None):
|
|
| 141 |
"- Why the model made the prediction\n"
|
| 142 |
"- Which features mattered\n"
|
| 143 |
"- Why those features mattered\n"
|
| 144 |
-
"-
|
| 145 |
-
)
|
| 146 |
-
|
| 147 |
-
response = bedrock.converse(
|
| 148 |
-
modelId=NOVA_MODEL_ID,
|
| 149 |
-
messages=[{"role": "user", "content": [{"text": prompt}]}],
|
| 150 |
)
|
| 151 |
|
| 152 |
-
return
|
| 153 |
|
| 154 |
except Exception as e:
|
| 155 |
return f"LLM explanation failed: {e}"
|
| 156 |
|
| 157 |
-
|
| 158 |
def interpret_question(user_q):
|
| 159 |
q = (user_q or "").lower()
|
| 160 |
if any(w in q for w in ["what if", "increase", "decrease", "set ", "make ", "higher", "lower", "raise", "reduce", "change"]):
|
|
@@ -173,9 +213,13 @@ def perform_api_call(features):
|
|
| 173 |
"namespace": namespace,
|
| 174 |
"retrieval": {"alpha": alpha, "k": k, "use_retrieval": True, "namespace": namespace},
|
| 175 |
}
|
| 176 |
-
|
|
|
|
|
|
|
| 177 |
|
|
|
|
| 178 |
# What-if parsing
|
|
|
|
| 179 |
|
| 180 |
FEATURE_NAME_PAT = re.compile(r"([a-zA-Z][a-zA-Z0-9 _\-\(\)]*)")
|
| 181 |
|
|
@@ -204,24 +248,13 @@ def default_delta(curr):
|
|
| 204 |
def apply_what_if(user_q, feat_names, current_features):
|
| 205 |
"""
|
| 206 |
Returns (new_features, change_text) or (None, error_msg)
|
| 207 |
-
|
| 208 |
-
Handles:
|
| 209 |
-
- "increase sepal length to 5.8"
|
| 210 |
-
- "decrease petal width by 0.2"
|
| 211 |
-
- "increase sepal width by 10%"
|
| 212 |
-
- "make petal length higher"
|
| 213 |
-
- "slightly increase sepal width"
|
| 214 |
-
- "increase sepal length to 5.8, decrease petal width to 0.2 and reduce sepal width a bit"
|
| 215 |
"""
|
| 216 |
-
|
| 217 |
q = user_q.lower()
|
| 218 |
new = current_features.copy()
|
| 219 |
changes = []
|
| 220 |
|
| 221 |
-
# Split query by commas or 'and'
|
| 222 |
parts = re.split(r",| and ", q)
|
| 223 |
|
| 224 |
-
# Define modifiers with scaling factors (relative to default 10%)
|
| 225 |
modifier_scale = {
|
| 226 |
"slightly": 0.5,
|
| 227 |
"a bit": 0.5,
|
|
@@ -238,7 +271,6 @@ def apply_what_if(user_q, feat_names, current_features):
|
|
| 238 |
if not part:
|
| 239 |
continue
|
| 240 |
|
| 241 |
-
# Detect intensity modifier
|
| 242 |
scale = 1.0
|
| 243 |
for mod, factor in modifier_scale.items():
|
| 244 |
if mod in part:
|
|
@@ -246,7 +278,7 @@ def apply_what_if(user_q, feat_names, current_features):
|
|
| 246 |
part = part.replace(mod, "")
|
| 247 |
break
|
| 248 |
|
| 249 |
-
# 1)
|
| 250 |
m = re.search(r"(?:set|what if|change|increase|decrease|raise|reduce)\s+(.*?)\s*(?:=|to)\s*([-+]?\d*\.?\d+)", part)
|
| 251 |
if m:
|
| 252 |
feat_frag, val_str = m.group(1), m.group(2)
|
|
@@ -263,7 +295,7 @@ def apply_what_if(user_q, feat_names, current_features):
|
|
| 263 |
changes.append(f"Set **{fname}** to **{val:.2f}**.")
|
| 264 |
continue
|
| 265 |
|
| 266 |
-
# 2)
|
| 267 |
m = re.search(r"(increase|decrease|raise|reduce)\s+(.*?)\s+by\s+([-+]?\d*\.?\d+)\b(?!%)", part)
|
| 268 |
if m:
|
| 269 |
op, feat_frag, val_str = m.groups()
|
|
@@ -276,16 +308,13 @@ def apply_what_if(user_q, feat_names, current_features):
|
|
| 276 |
changes.append(f"Couldn't parse a number from: '{val_str}'.")
|
| 277 |
continue
|
| 278 |
delta *= scale
|
| 279 |
-
if op in ["decrease", "reduce"]
|
| 280 |
-
delta = -abs(delta)
|
| 281 |
-
else:
|
| 282 |
-
delta = abs(delta)
|
| 283 |
idx = feat_names.index(fname)
|
| 284 |
new[idx] = new[idx] + delta
|
| 285 |
changes.append(f"{'Increased' if delta>0 else 'Decreased'} **{fname}** by **{abs(delta):.2f}** β **{new[idx]:.2f}**.")
|
| 286 |
continue
|
| 287 |
|
| 288 |
-
# 3)
|
| 289 |
m = re.search(r"(increase|decrease|raise|reduce)\s+(.*?)\s+by\s+([-+]?\d*\.?\d+)\s*%", part)
|
| 290 |
if m:
|
| 291 |
op, feat_frag, perc_str = m.groups()
|
|
@@ -299,12 +328,12 @@ def apply_what_if(user_q, feat_names, current_features):
|
|
| 299 |
continue
|
| 300 |
perc *= scale
|
| 301 |
idx = feat_names.index(fname)
|
| 302 |
-
factor = 1.0 + (abs(perc)/100.0 if op in ["increase","raise"] else -abs(perc)/100.0)
|
| 303 |
new[idx] = new[idx] * factor
|
| 304 |
changes.append(f"{op.title()}d **{fname}** by **{abs(perc):.0f}%** β **{new[idx]:.2f}**.")
|
| 305 |
continue
|
| 306 |
|
| 307 |
-
# 4)
|
| 308 |
m = re.search(r"(make|set)?\s*(.*?)\s*(higher|lower|increase|decrease|raise|reduce)", part)
|
| 309 |
if m:
|
| 310 |
_, feat_frag, direction = m.groups()
|
|
@@ -327,9 +356,11 @@ def apply_what_if(user_q, feat_names, current_features):
|
|
| 327 |
|
| 328 |
return new, "\n".join(changes)
|
| 329 |
|
| 330 |
-
|
| 331 |
# Step 1: Enter features & predict
|
| 332 |
-
|
|
|
|
|
|
|
| 333 |
user_features = st.text_input("Enter feature values (comma-separated)", "")
|
| 334 |
predict_btn = st.button("π Predict and Explain")
|
| 335 |
|
|
@@ -345,19 +376,21 @@ if predict_btn:
|
|
| 345 |
st.warning(f"Expected {len(feat_names)} values ({', '.join(feat_names)}), but got {len(features)}.")
|
| 346 |
else:
|
| 347 |
st.session_state["input_features"] = features
|
| 348 |
-
# Show entered features
|
| 349 |
st.markdown("### β¨ Entered Features")
|
| 350 |
st.markdown("\n".join([f"- **{n}** = {v:.2f}" for n, v in zip(feat_names, features)]))
|
| 351 |
-
|
| 352 |
res = perform_api_call(features)
|
| 353 |
st.session_state["prediction_result"] = res
|
| 354 |
st.session_state["messages"] = []
|
| 355 |
st.success(summarize_prediction(res))
|
| 356 |
-
st.info("Scroll down to explore similar cases or chat
|
| 357 |
except Exception as e:
|
| 358 |
-
st.error(f"Error
|
|
|
|
|
|
|
|
|
|
|
|
|
| 359 |
|
| 360 |
-
# Step 2: Similar cases
|
| 361 |
if st.session_state["prediction_result"]:
|
| 362 |
st.divider()
|
| 363 |
st.subheader("Step 2 β Explore similar cases")
|
|
@@ -376,26 +409,22 @@ if st.session_state["prediction_result"]:
|
|
| 376 |
else:
|
| 377 |
st.write("No similar cases retrieved.")
|
| 378 |
|
| 379 |
-
# Step 3: Explanation Mode + Chat
|
| 380 |
st.divider()
|
| 381 |
st.subheader("Step 3 β Chat with the model about this prediction")
|
| 382 |
|
| 383 |
-
# Choose explanation mode BEFORE asking questions
|
| 384 |
if "chat_mode" not in st.session_state:
|
| 385 |
st.session_state["chat_mode"] = "System"
|
| 386 |
st.session_state["chat_mode"] = st.radio(
|
| 387 |
"How should explanations be generated?",
|
| 388 |
["System", "LLM (Natural language)"],
|
| 389 |
-
index=0 if
|
| 390 |
horizontal=True,
|
| 391 |
)
|
| 392 |
|
| 393 |
-
# Show previous messages
|
| 394 |
for role, content in st.session_state["messages"]:
|
| 395 |
with st.chat_message(role):
|
| 396 |
st.markdown(content)
|
| 397 |
|
| 398 |
-
# Chat input
|
| 399 |
if user_q := st.chat_input("Ask e.g. 'Why this prediction?' or 'Increase petal length by 0.3' or 'set sepal width to 3.8'"):
|
| 400 |
st.session_state["messages"].append(("user", user_q))
|
| 401 |
with st.chat_message("user"):
|
|
@@ -403,7 +432,6 @@ if st.session_state["prediction_result"]:
|
|
| 403 |
|
| 404 |
intent = interpret_question(user_q)
|
| 405 |
|
| 406 |
-
# Current base result
|
| 407 |
base_res = st.session_state["prediction_result"]
|
| 408 |
base_pred = base_res["prediction"]["y_pred"]
|
| 409 |
base_proba = base_res["prediction"]["proba"]
|
|
@@ -411,7 +439,7 @@ if st.session_state["prediction_result"]:
|
|
| 411 |
|
| 412 |
if intent == "explain":
|
| 413 |
if "LLM" in st.session_state["chat_mode"]:
|
| 414 |
-
with st.spinner("
|
| 415 |
answer = llm_explain(base_res, feat_names)
|
| 416 |
st.session_state["messages"].append(("assistant", answer))
|
| 417 |
with st.chat_message("assistant"):
|
|
@@ -428,7 +456,6 @@ if st.session_state["prediction_result"]:
|
|
| 428 |
st.markdown(text)
|
| 429 |
|
| 430 |
elif intent == "what_if":
|
| 431 |
-
# Parse & apply change, recompute, compare
|
| 432 |
if st.session_state["input_features"] is None:
|
| 433 |
msg = "Please run a prediction first (Step 1) so I know your starting feature values."
|
| 434 |
st.session_state["messages"].append(("assistant", msg))
|
|
@@ -452,15 +479,14 @@ if st.session_state["prediction_result"]:
|
|
| 452 |
ctx = {
|
| 453 |
"change_applied": status,
|
| 454 |
"before": {"features": st.session_state["input_features"], "label": base_label, "proba": base_proba},
|
| 455 |
-
"after": {"features": new_feats, "label": new_label, "proba": new_proba}
|
| 456 |
}
|
| 457 |
-
with st.spinner("
|
| 458 |
answer = llm_explain(new_res, feat_names, extra_context=ctx)
|
| 459 |
st.session_state["messages"].append(("assistant", answer))
|
| 460 |
with st.chat_message("assistant"):
|
| 461 |
st.markdown(answer)
|
| 462 |
else:
|
| 463 |
-
# System comparison summary + new SHAP chart
|
| 464 |
lines = [
|
| 465 |
f"**Change applied:** {status}",
|
| 466 |
f"**Before:** {base_label} (class `{base_pred}`) β confidence **{base_proba:.2f}**",
|
|
@@ -473,7 +499,6 @@ if st.session_state["prediction_result"]:
|
|
| 473 |
st.session_state["messages"].append(("assistant", "What-if comparison + SHAP shown above."))
|
| 474 |
|
| 475 |
else:
|
| 476 |
-
# Summary fallback
|
| 477 |
summary = summarize_prediction(base_res)
|
| 478 |
st.session_state["messages"].append(("assistant", summary))
|
| 479 |
with st.chat_message("assistant"):
|
|
|
|
| 8 |
import boto3
|
| 9 |
from botocore.config import Config
|
| 10 |
|
| 11 |
+
# ============================================================
|
| 12 |
+
# Bedrock (Amazon Nova) setup
|
| 13 |
+
# ============================================================
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
+
AWS_REGION = os.getenv("AWS_REGION", "us-east-1")
|
|
|
|
| 16 |
NOVA_MODEL_ID = os.getenv("NOVA_MODEL_ID", "us.amazon.nova-lite-v1:0")
|
| 17 |
|
| 18 |
+
def make_bedrock_client():
|
| 19 |
+
"""
|
| 20 |
+
Creates a Bedrock Runtime client.
|
| 21 |
+
Auth is provided via Hugging Face Secrets / env vars:
|
| 22 |
+
- AWS_BEARER_TOKEN_BEDROCK (Bedrock API key)
|
| 23 |
+
- AWS_REGION
|
| 24 |
+
"""
|
| 25 |
+
# If key not set, we'll still create client, but calls will fail with a clear error later.
|
| 26 |
+
return boto3.client(
|
| 27 |
+
"bedrock-runtime",
|
| 28 |
+
region_name=AWS_REGION,
|
| 29 |
+
config=Config(read_timeout=60, connect_timeout=60, retries={"max_attempts": 3}),
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
bedrock = make_bedrock_client()
|
| 33 |
|
| 34 |
+
# ============================================================
|
| 35 |
+
# Streamlit App
|
| 36 |
+
# ============================================================
|
| 37 |
|
| 38 |
st.set_page_config(page_title="EchoML", page_icon="π¬", layout="wide")
|
| 39 |
+
st.title("π¬ Chat with Your Model (IRIS Edition)")
|
| 40 |
|
| 41 |
# Sidebar configuration
|
| 42 |
with st.sidebar:
|
| 43 |
st.header("Settings")
|
| 44 |
+
api_url = st.text_input(
|
| 45 |
+
"FastAPI endpoint",
|
| 46 |
+
value=os.getenv("ECHO_API_URL", "https://query-your-model-api-784882848382.us-central1.run.app/explain"),
|
| 47 |
+
)
|
| 48 |
+
model_path = st.text_input("Model path", value=os.getenv("MODEL_PATH", "Query_Your_Model/model_data/model.pkl"))
|
| 49 |
feat_names_str = st.text_input(
|
| 50 |
"Feature names (comma-separated)",
|
| 51 |
+
value=os.getenv(
|
| 52 |
+
"FEATURE_NAMES",
|
| 53 |
+
"sepal length (cm),sepal width (cm),petal length (cm),petal width (cm)"
|
| 54 |
+
),
|
| 55 |
)
|
| 56 |
+
namespace = st.text_input("Namespace", value=os.getenv("NAMESPACE", "Query_Your_Model/data/base_indices/iris_global"))
|
| 57 |
alpha = st.slider("Alpha (retrieval weight)", 0.0, 1.0, 0.7, 0.05)
|
| 58 |
k = st.slider("Top-K similar to retrieve", 1, 10, 5)
|
| 59 |
|
| 60 |
+
st.divider()
|
| 61 |
+
st.caption("LLM provider: Amazon Bedrock (Nova)")
|
| 62 |
+
st.caption(f"AWS_REGION: `{AWS_REGION}`")
|
| 63 |
+
st.caption(f"NOVA_MODEL_ID: `{NOVA_MODEL_ID}`")
|
| 64 |
|
| 65 |
+
feat_names = [s.strip() for s in feat_names_str.split(",") if s.strip()]
|
|
|
|
|
|
|
| 66 |
|
| 67 |
+
# ============================================================
|
| 68 |
# Helpers
|
| 69 |
+
# ============================================================
|
| 70 |
+
|
| 71 |
def label_from_pred(y_pred):
|
| 72 |
try:
|
| 73 |
num = int(round(float(y_pred)))
|
|
|
|
| 90 |
|
| 91 |
def show_similar_cases(res, n_display, feat_names):
|
| 92 |
sims = safe_similar_cases(res)
|
| 93 |
+
if not sims:
|
| 94 |
return "No similar cases were retrieved."
|
| 95 |
n = min(n_display, len(sims))
|
| 96 |
+
lines = [f"It found **{len(sims)}** similar cases (showing **{n}**):"]
|
| 97 |
for case in sims[:n]:
|
| 98 |
features_named = ", ".join([f"{name} = {val:.2f}" for name, val in zip(feat_names, case["features"])])
|
| 99 |
lines.append(f"- **{case['case_id']}** β {features_named}, predicted as **{case['y_pred']}**.")
|
|
|
|
| 106 |
feats = [f["feature"] for f in topk]
|
| 107 |
shap_vals = [f["shap"] for f in topk]
|
| 108 |
fig, ax = plt.subplots()
|
| 109 |
+
ax.barh(feats, shap_vals) # default colors
|
| 110 |
ax.set_xlabel("SHAP value (impact on prediction)")
|
| 111 |
ax.set_title("Feature importance for this prediction")
|
| 112 |
st.pyplot(fig)
|
|
|
|
| 143 |
st.markdown("\n".join(msg))
|
| 144 |
plot_shap_bar(topk)
|
| 145 |
|
| 146 |
+
def bedrock_llm(prompt: str) -> str:
|
| 147 |
+
"""
|
| 148 |
+
Calls Amazon Nova via Bedrock Converse API.
|
| 149 |
+
Requires Hugging Face Secret: AWS_BEARER_TOKEN_BEDROCK
|
| 150 |
+
"""
|
| 151 |
+
if not os.getenv("AWS_BEARER_TOKEN_BEDROCK"):
|
| 152 |
+
return (
|
| 153 |
+
"LLM explanation is not available because `AWS_BEARER_TOKEN_BEDROCK` is not set.\n\n"
|
| 154 |
+
"In Hugging Face Spaces β Settings β Variables and secrets β Secrets, add:\n"
|
| 155 |
+
"- Name: AWS_BEARER_TOKEN_BEDROCK\n"
|
| 156 |
+
"- Value: (your Bedrock API key)\n"
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
resp = bedrock.converse(
|
| 160 |
+
modelId=NOVA_MODEL_ID,
|
| 161 |
+
messages=[{"role": "user", "content": [{"text": prompt}]}],
|
| 162 |
+
)
|
| 163 |
+
return resp["output"]["message"]["content"][0]["text"]
|
| 164 |
+
|
| 165 |
def llm_explain(res, feat_names, extra_context=None):
|
| 166 |
+
"""
|
| 167 |
+
LLM explanation: can handle 'why' and 'what-if' using the provided context (old/new).
|
| 168 |
+
"""
|
| 169 |
try:
|
| 170 |
pred = label_from_pred(res["prediction"]["y_pred"])
|
| 171 |
proba = res["prediction"]["proba"]
|
|
|
|
| 187 |
"- Why the model made the prediction\n"
|
| 188 |
"- Which features mattered\n"
|
| 189 |
"- Why those features mattered\n"
|
| 190 |
+
"- 2-3 experiments the user could perform: tell them which feature values to increase/decrease"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 191 |
)
|
| 192 |
|
| 193 |
+
return bedrock_llm(prompt)
|
| 194 |
|
| 195 |
except Exception as e:
|
| 196 |
return f"LLM explanation failed: {e}"
|
| 197 |
|
|
|
|
| 198 |
def interpret_question(user_q):
|
| 199 |
q = (user_q or "").lower()
|
| 200 |
if any(w in q for w in ["what if", "increase", "decrease", "set ", "make ", "higher", "lower", "raise", "reduce", "change"]):
|
|
|
|
| 213 |
"namespace": namespace,
|
| 214 |
"retrieval": {"alpha": alpha, "k": k, "use_retrieval": True, "namespace": namespace},
|
| 215 |
}
|
| 216 |
+
r = requests.post(api_url, json=payload, timeout=60)
|
| 217 |
+
r.raise_for_status()
|
| 218 |
+
return r.json()
|
| 219 |
|
| 220 |
+
# ============================================================
|
| 221 |
# What-if parsing
|
| 222 |
+
# ============================================================
|
| 223 |
|
| 224 |
FEATURE_NAME_PAT = re.compile(r"([a-zA-Z][a-zA-Z0-9 _\-\(\)]*)")
|
| 225 |
|
|
|
|
| 248 |
def apply_what_if(user_q, feat_names, current_features):
|
| 249 |
"""
|
| 250 |
Returns (new_features, change_text) or (None, error_msg)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 251 |
"""
|
|
|
|
| 252 |
q = user_q.lower()
|
| 253 |
new = current_features.copy()
|
| 254 |
changes = []
|
| 255 |
|
|
|
|
| 256 |
parts = re.split(r",| and ", q)
|
| 257 |
|
|
|
|
| 258 |
modifier_scale = {
|
| 259 |
"slightly": 0.5,
|
| 260 |
"a bit": 0.5,
|
|
|
|
| 271 |
if not part:
|
| 272 |
continue
|
| 273 |
|
|
|
|
| 274 |
scale = 1.0
|
| 275 |
for mod, factor in modifier_scale.items():
|
| 276 |
if mod in part:
|
|
|
|
| 278 |
part = part.replace(mod, "")
|
| 279 |
break
|
| 280 |
|
| 281 |
+
# 1) Set to value
|
| 282 |
m = re.search(r"(?:set|what if|change|increase|decrease|raise|reduce)\s+(.*?)\s*(?:=|to)\s*([-+]?\d*\.?\d+)", part)
|
| 283 |
if m:
|
| 284 |
feat_frag, val_str = m.group(1), m.group(2)
|
|
|
|
| 295 |
changes.append(f"Set **{fname}** to **{val:.2f}**.")
|
| 296 |
continue
|
| 297 |
|
| 298 |
+
# 2) Increase/decrease by absolute value
|
| 299 |
m = re.search(r"(increase|decrease|raise|reduce)\s+(.*?)\s+by\s+([-+]?\d*\.?\d+)\b(?!%)", part)
|
| 300 |
if m:
|
| 301 |
op, feat_frag, val_str = m.groups()
|
|
|
|
| 308 |
changes.append(f"Couldn't parse a number from: '{val_str}'.")
|
| 309 |
continue
|
| 310 |
delta *= scale
|
| 311 |
+
delta = -abs(delta) if op in ["decrease", "reduce"] else abs(delta)
|
|
|
|
|
|
|
|
|
|
| 312 |
idx = feat_names.index(fname)
|
| 313 |
new[idx] = new[idx] + delta
|
| 314 |
changes.append(f"{'Increased' if delta>0 else 'Decreased'} **{fname}** by **{abs(delta):.2f}** β **{new[idx]:.2f}**.")
|
| 315 |
continue
|
| 316 |
|
| 317 |
+
# 3) Increase/decrease by percent
|
| 318 |
m = re.search(r"(increase|decrease|raise|reduce)\s+(.*?)\s+by\s+([-+]?\d*\.?\d+)\s*%", part)
|
| 319 |
if m:
|
| 320 |
op, feat_frag, perc_str = m.groups()
|
|
|
|
| 328 |
continue
|
| 329 |
perc *= scale
|
| 330 |
idx = feat_names.index(fname)
|
| 331 |
+
factor = 1.0 + (abs(perc) / 100.0 if op in ["increase", "raise"] else -abs(perc) / 100.0)
|
| 332 |
new[idx] = new[idx] * factor
|
| 333 |
changes.append(f"{op.title()}d **{fname}** by **{abs(perc):.0f}%** β **{new[idx]:.2f}**.")
|
| 334 |
continue
|
| 335 |
|
| 336 |
+
# 4) Make higher/lower (no number)
|
| 337 |
m = re.search(r"(make|set)?\s*(.*?)\s*(higher|lower|increase|decrease|raise|reduce)", part)
|
| 338 |
if m:
|
| 339 |
_, feat_frag, direction = m.groups()
|
|
|
|
| 356 |
|
| 357 |
return new, "\n".join(changes)
|
| 358 |
|
| 359 |
+
# ============================================================
|
| 360 |
# Step 1: Enter features & predict
|
| 361 |
+
# ============================================================
|
| 362 |
+
|
| 363 |
+
st.subheader("Step 1 β Enter features to generate a prediction in the order: [sepal length, sepal width, petal length, petal width]")
|
| 364 |
user_features = st.text_input("Enter feature values (comma-separated)", "")
|
| 365 |
predict_btn = st.button("π Predict and Explain")
|
| 366 |
|
|
|
|
| 376 |
st.warning(f"Expected {len(feat_names)} values ({', '.join(feat_names)}), but got {len(features)}.")
|
| 377 |
else:
|
| 378 |
st.session_state["input_features"] = features
|
|
|
|
| 379 |
st.markdown("### β¨ Entered Features")
|
| 380 |
st.markdown("\n".join([f"- **{n}** = {v:.2f}" for n, v in zip(feat_names, features)]))
|
| 381 |
+
|
| 382 |
res = perform_api_call(features)
|
| 383 |
st.session_state["prediction_result"] = res
|
| 384 |
st.session_state["messages"] = []
|
| 385 |
st.success(summarize_prediction(res))
|
| 386 |
+
st.info("Scroll down to explore similar cases or chat.")
|
| 387 |
except Exception as e:
|
| 388 |
+
st.error(f"Error generating prediction: {e}")
|
| 389 |
+
|
| 390 |
+
# ============================================================
|
| 391 |
+
# Step 2 + 3
|
| 392 |
+
# ============================================================
|
| 393 |
|
|
|
|
| 394 |
if st.session_state["prediction_result"]:
|
| 395 |
st.divider()
|
| 396 |
st.subheader("Step 2 β Explore similar cases")
|
|
|
|
| 409 |
else:
|
| 410 |
st.write("No similar cases retrieved.")
|
| 411 |
|
|
|
|
| 412 |
st.divider()
|
| 413 |
st.subheader("Step 3 β Chat with the model about this prediction")
|
| 414 |
|
|
|
|
| 415 |
if "chat_mode" not in st.session_state:
|
| 416 |
st.session_state["chat_mode"] = "System"
|
| 417 |
st.session_state["chat_mode"] = st.radio(
|
| 418 |
"How should explanations be generated?",
|
| 419 |
["System", "LLM (Natural language)"],
|
| 420 |
+
index=0 if st.session_state["chat_mode"] == "System" else 1,
|
| 421 |
horizontal=True,
|
| 422 |
)
|
| 423 |
|
|
|
|
| 424 |
for role, content in st.session_state["messages"]:
|
| 425 |
with st.chat_message(role):
|
| 426 |
st.markdown(content)
|
| 427 |
|
|
|
|
| 428 |
if user_q := st.chat_input("Ask e.g. 'Why this prediction?' or 'Increase petal length by 0.3' or 'set sepal width to 3.8'"):
|
| 429 |
st.session_state["messages"].append(("user", user_q))
|
| 430 |
with st.chat_message("user"):
|
|
|
|
| 432 |
|
| 433 |
intent = interpret_question(user_q)
|
| 434 |
|
|
|
|
| 435 |
base_res = st.session_state["prediction_result"]
|
| 436 |
base_pred = base_res["prediction"]["y_pred"]
|
| 437 |
base_proba = base_res["prediction"]["proba"]
|
|
|
|
| 439 |
|
| 440 |
if intent == "explain":
|
| 441 |
if "LLM" in st.session_state["chat_mode"]:
|
| 442 |
+
with st.spinner("Generating LLM explanation (Nova)..."):
|
| 443 |
answer = llm_explain(base_res, feat_names)
|
| 444 |
st.session_state["messages"].append(("assistant", answer))
|
| 445 |
with st.chat_message("assistant"):
|
|
|
|
| 456 |
st.markdown(text)
|
| 457 |
|
| 458 |
elif intent == "what_if":
|
|
|
|
| 459 |
if st.session_state["input_features"] is None:
|
| 460 |
msg = "Please run a prediction first (Step 1) so I know your starting feature values."
|
| 461 |
st.session_state["messages"].append(("assistant", msg))
|
|
|
|
| 479 |
ctx = {
|
| 480 |
"change_applied": status,
|
| 481 |
"before": {"features": st.session_state["input_features"], "label": base_label, "proba": base_proba},
|
| 482 |
+
"after": {"features": new_feats, "label": new_label, "proba": new_proba},
|
| 483 |
}
|
| 484 |
+
with st.spinner("Summarizing the effect (Nova)..."):
|
| 485 |
answer = llm_explain(new_res, feat_names, extra_context=ctx)
|
| 486 |
st.session_state["messages"].append(("assistant", answer))
|
| 487 |
with st.chat_message("assistant"):
|
| 488 |
st.markdown(answer)
|
| 489 |
else:
|
|
|
|
| 490 |
lines = [
|
| 491 |
f"**Change applied:** {status}",
|
| 492 |
f"**Before:** {base_label} (class `{base_pred}`) β confidence **{base_proba:.2f}**",
|
|
|
|
| 499 |
st.session_state["messages"].append(("assistant", "What-if comparison + SHAP shown above."))
|
| 500 |
|
| 501 |
else:
|
|
|
|
| 502 |
summary = summarize_prediction(base_res)
|
| 503 |
st.session_state["messages"].append(("assistant", summary))
|
| 504 |
with st.chat_message("assistant"):
|