# EchoML — app/app_chat.py
# (Hugging Face page header preserved as comments: uploaded by tiffany101,
#  "Update app/app_chat.py", commit bf4c1bc, verified.)
import re
import json
import matplotlib.pyplot as plt
import streamlit as st
import requests
import os
from litellm import completion  # pip install litellm

# Page chrome must be configured before any other st.* call is made.
# NOTE(review): "πŸ’¬" looks like UTF-8 mojibake of the 💬 emoji — the file's
# encoding appears mangled; confirm and restore the original characters.
st.set_page_config(page_title="EchoML", page_icon="πŸ’¬", layout="wide")
st.title("πŸ’¬ Chat with Your Model (IRIS Edition)")
# -----------------------------
# Sidebar configuration
# -----------------------------
# All user-tunable knobs live in the sidebar; the resulting module-level
# globals (api_url, model_path, feat_names, namespace, alpha, k,
# nova_model_id, temperature, max_tokens) are read by the helpers below.
with st.sidebar:
    st.header("Settings")
    api_url = st.text_input(
        "FastAPI endpoint",
        value="https://query-your-model-api-784882848382.us-central1.run.app/explain",
    )
    model_path = st.text_input("Model path", value="Query_Your_Model/model_data/model.pkl")
    feat_names_str = st.text_input(
        "Feature names (comma-separated)",
        value="sepal length (cm),sepal width (cm),petal length (cm),petal width (cm)",
    )
    namespace = st.text_input("Namespace", value="Query_Your_Model/data/base_indices/iris_global")
    # IMPORTANT: your retrieval.py shows:
    #   similarity = alpha * cos(SHAP) + (1 - alpha) * cos(features)
    alpha = st.slider("Alpha (retrieval weight: SHAP vs features)", 0.0, 1.0, 0.7, 0.05)
    k = st.slider("Top-K similar to retrieve", 1, 10, 5)
    st.divider()
    st.subheader("Nova (LLM) Settings")
    # Model IDs based on your Nova console screenshot
    nova_model_id = st.selectbox(
        "Nova model",
        options=["nova-micro-v1", "nova-lite-v1", "nova-pro-v1", "nova-premier-v1"],
        index=1,
        help="These are the Nova model IDs from the Nova developer console.",
    )
    temperature = st.slider("LLM temperature", 0.0, 1.0, 0.2, 0.05)
    max_tokens = st.slider("LLM max tokens", 64, 1024, 350, 32)

# Split the comma-separated names once; index-aligned with the feature vector.
feat_names = [s.strip() for s in feat_names_str.split(",")]
# -----------------------------
# Helpers
# -----------------------------
def label_from_pred(y_pred):
    """Map a numeric model prediction to a human-readable iris class label.

    The function body after the int conversion was lost in this copy of the
    file; reconstructed as the standard iris class mapping (the app is the
    "IRIS Edition" and downstream text reads "thinks it's **{label}**").

    Parameters
    ----------
    y_pred : int | float | str
        Raw prediction from the API; anything coercible to float.

    Returns
    -------
    str
        "setosa" / "versicolor" / "virginica" for classes 0/1/2, otherwise
        the stringified input (also used when coercion fails).
    """
    try:
        num = int(round(float(y_pred)))
    except (TypeError, ValueError):
        # Non-numeric prediction: fall back to showing it verbatim.
        return str(y_pred)
    iris_labels = {0: "setosa", 1: "versicolor", 2: "virginica"}
    return iris_labels.get(num, str(y_pred))
def show_similar_cases(res, n_display, feat_names):
    """Format the retrieved reference cases as a markdown bullet list.

    Returns a fallback sentence when retrieval produced nothing; otherwise a
    header line plus one bullet per displayed case.
    """
    cases = safe_similar_cases(res)
    if not cases:
        return "No similar cases were retrieved."
    shown = min(n_display, len(cases))
    out = [f"It found **{len(cases)}** similar reference cases (showing **{shown}**):"]
    for case in cases[:shown]:
        described = ", ".join(
            f"{fname} = {fval:.2f}" for fname, fval in zip(feat_names, case["features"])
        )
        out.append(f"- **{case['case_id']}** β†’ {described}, predicted as **{case['y_pred']}**.")
    return "\n".join(out)
def plot_shap_bar(topk):
    """Render a horizontal bar chart of the top-k SHAP values into Streamlit.

    `topk` is a list of dicts with "feature" and "shap" keys, as produced by
    the /explain API response.
    """
    names = [entry["feature"] for entry in topk]
    values = [entry["shap"] for entry in topk]
    figure, axis = plt.subplots()
    axis.barh(names, values)  # matplotlib default palette
    axis.set_xlabel("SHAP value (impact on prediction)")
    axis.set_title("Feature importance for this prediction")
    st.pyplot(figure)
def summarize_prediction(res):
    """Render a markdown + SHAP-chart summary of a prediction and return the text.

    NOTE(review): reconstructed. These lines were an orphaned function body in
    this copy (its `def` line and variable bindings were lost); the name and
    signature are inferred from the call sites `st.success(summarize_prediction(res))`
    and `summary = summarize_prediction(base_res)` later in the file — verify
    against the original app.

    Parameters
    ----------
    res : dict
        /explain API response with res["prediction"]["y_pred"] and ["proba"].

    Returns
    -------
    str
        The markdown summary (also written to the page via st.markdown, with
        a SHAP bar chart when attributions are present).
    """
    pred = res["prediction"]["y_pred"]
    proba = res["prediction"]["proba"]
    label = label_from_pred(pred)
    # NOTE(review): assumes top-k SHAP entries live under res["topk"] or
    # res["explanation"]["topk"] — confirm against the /explain schema.
    topk = res.get("topk") or res.get("explanation", {}).get("topk") or []
    msg = [
        f"🌸 Based on these features, the model thinks it's **{label}** (class `{pred}`) with confidence **{proba:.2f}**.\n",
        "### Key reasons (SHAP):",
    ]
    if topk:
        for item in topk:
            msg.append(f"- **{item['feature']}** (SHAP = {item['shap']:+.3f})")
    else:
        msg.append("- No SHAP attributions were returned for this prediction.")
    text = "\n".join(msg)
    st.markdown(text)
    if topk:
        plot_shap_bar(topk)
    return text
def interpret_question(user_q):
    """Classify a chat question into an intent: what_if, explain, similar, or summary.

    Intents are checked in priority order, so a question containing both
    what-if and why-style cues is treated as a what-if. None/empty input
    falls through to "summary".
    """
    text = (user_q or "").lower()
    intent_cues = [
        ("what_if", ["what if", "increase", "decrease", "set ", "make ", "higher", "lower", "raise", "reduce", "change"]),
        ("explain", ["why", "explain", "reason"]),
        ("similar", ["similar"]),
    ]
    for intent, cues in intent_cues:
        if any(cue in text for cue in cues):
            return intent
    return "summary"
def perform_api_call(features):
    """POST a feature vector to the configured /explain endpoint.

    Uses the sidebar globals (api_url, model_path, feat_names, namespace,
    alpha, k) to build the request. Raises requests.HTTPError on a non-2xx
    response; returns the parsed JSON body otherwise.
    """
    request_body = {
        "model_path": model_path,
        "feature_names": feat_names,
        "features": features,
        "namespace": namespace,
        "retrieval": {"alpha": alpha, "k": k, "use_retrieval": True, "namespace": namespace},
    }
    response = requests.post(api_url, json=request_body, timeout=60)
    response.raise_for_status()
    return response.json()
# -----------------------------
# Nova LLM (via LiteLLM)
# -----------------------------
def nova_llm_text(prompt: str) -> str:
    """Send `prompt` to an Amazon Nova model via LiteLLM and return the reply text.

    Reads the provider key from the AMAZON_NOVA_API_KEY environment variable.
    Never raises: when the key is missing or the call fails, a human-readable
    error string is returned instead.
    """
    api_key = os.getenv("AMAZON_NOVA_API_KEY")
    if not api_key:
        return (
            "Nova API key not found. In Hugging Face β†’ Settings β†’ Variables and secrets, "
            "add a *Secret* named `AMAZON_NOVA_API_KEY`."
        )
    # LiteLLM reads the provider key from the environment; simplest on Spaces.
    os.environ["AMAZON_NOVA_API_KEY"] = api_key
    try:
        reply = completion(
            model=f"amazon_nova/{nova_model_id}",
            messages=[{"role": "user", "content": prompt}],
            temperature=temperature,
            max_tokens=max_tokens,
        )
        return reply.choices[0].message.content
    except Exception as exc:
        return f"LLM explanation failed: {exc}"
def llm_explain(res, feat_names, extra_context=None):
    """Ask Nova for a plain-language explanation of an /explain response.

    NOTE(review): partially reconstructed — the original `base_prompt = {`
    header and the `topk`/`sims` bindings were lost in this copy; the dict
    keys below survive from the original. Verify field sources against the
    /explain schema.

    Parameters
    ----------
    res : dict
        /explain response; must contain res["prediction"]["y_pred"] / ["proba"].
    feat_names : list[str]
        Feature names (currently only passed for context symmetry).
    extra_context : dict | None
        Optional before/after info for what-if summaries.

    Returns
    -------
    str
        The LLM's answer, or an "LLM explanation failed: ..." string on error.
    """
    try:
        pred = label_from_pred(res["prediction"]["y_pred"])
        proba = res["prediction"]["proba"]
        # NOTE(review): assumes SHAP top-k lives under "topk" (or nested under
        # "explanation") — confirm against the API response.
        topk = res.get("topk") or res.get("explanation", {}).get("topk") or []
        sims = safe_similar_cases(res)
        base_prompt = {
            "prediction": pred,
            "probability": round(proba, 3),
            "topk": topk,
            "similar_examples_sample": sims[:3],
            "extra_context": extra_context or {},
            "retrieval_note": "alpha=1.0 prioritizes SHAP similarity; alpha=0.0 prioritizes feature similarity.",
        }
        prompt = (
            "You are an explainability copilot. Explain to a non-technical user.\n\n"
            f"DATA:\n{json.dumps(base_prompt, indent=2)}\n\n"
            "Write a short, clear answer that covers:\n"
            "- Why the model made the prediction (grounded in SHAP)\n"
            "- Which features mattered most\n"
            "- Why those features mattered\n"
            "- 2-3 concrete experiments: tell them which feature values to increase/decrease and what to watch for"
        )
        return nova_llm_text(prompt)
    except Exception as e:
        # Explanations are best-effort: surface the error as chat text.
        return f"LLM explanation failed: {e}"
# -----------------------------
# What-if parsing
# -----------------------------
# Loose "looks like a feature name" pattern: a letter followed by letters,
# digits, spaces, underscores, hyphens, or parentheses.
# NOTE(review): not referenced anywhere in this file's visible code — it may
# have been used by a portion of match_feature_name that was lost from this
# copy; confirm before removing.
FEATURE_NAME_PAT = re.compile(r"([a-zA-Z][a-zA-Z0-9 _\-\(\)]*)")
def match_feature_name(fragment, feat_names):
    """Best-effort map of a free-text fragment onto one of `feat_names`.

    NOTE(review): the scoring loop body was lost in this copy; reconstructed
    as word-overlap scoring with a substring bonus, which matches the
    surviving skeleton (lowercased fragment, best/best_score accumulators,
    None when nothing matches).

    Parameters
    ----------
    fragment : str
        User text such as "petal length".
    feat_names : list[str]
        Candidate names such as "petal length (cm)".

    Returns
    -------
    str | None
        The best-scoring feature name, or None when nothing plausibly matches.
    """
    frag = fragment.strip().lower()
    if not frag:
        return None
    frag_words = set(re.findall(r"[a-z]+", frag))
    best, best_score = None, 0  # require a strictly positive score to match
    for name in feat_names:
        lname = name.lower()
        score = len(frag_words & set(re.findall(r"[a-z]+", lname)))
        # Substring containment in either direction is strong evidence.
        if frag in lname or lname in frag:
            score += 2
        if score > best_score:
            best, best_score = name, score
    return best
def default_delta(curr):
    """Default nudge for a what-if request that gives no explicit amount.

    Uses 10% of the current value's magnitude; when that would be tiny
    (< 0.1), falls back to a fixed 0.1 step pointed toward the value's sign.
    """
    magnitude = 0.10 * abs(curr)
    if magnitude >= 0.1:
        return magnitude
    return 0.1 if curr >= 0 else -0.1
def apply_what_if(user_q, feat_names, current_features):
    """Parse a natural-language what-if request and apply it to a feature vector.

    Each comma-/"and"-separated clause is tried against four patterns, in
    order: "set X to V", "increase X by D", "increase X by P%", and a bare
    "make X higher/lower" (which uses default_delta). Adverbs like "slightly"
    or "greatly" scale relative changes.

    Parameters
    ----------
    user_q : str
        Instruction text, e.g. "increase petal length by 0.3 and set sepal width to 3.8".
    feat_names : list[str]
        Feature names, index-aligned with `current_features`.
    current_features : list[float]
        Starting feature vector; a copy is edited, the input is not mutated.

    Returns
    -------
    tuple[list[float] | None, str]
        (edited copy, markdown changelog) on success, or
        (None, "No valid feature changes detected.") when nothing parsed.
    """
    q = user_q.lower()
    new = current_features.copy()
    changes = []
    # Split the request into independent clauses.
    parts = re.split(r",| and ", q)
    # Adverb -> multiplier for relative (non-"set") changes.
    modifier_scale = {
        "slightly": 0.5,
        "a bit": 0.5,
        "a little": 0.5,
        "somewhat": 0.7,
        "moderately": 1.0,
        "significantly": 1.5,
        "greatly": 2.0,
        "a lot": 2.0,
    }
    for part in parts:
        part = part.strip()
        if not part:
            continue
        scale = 1.0
        # Strip the first recognized modifier and remember its scale factor.
        for mod, factor in modifier_scale.items():
            if mod in part:
                scale = factor
                part = part.replace(mod, "")
                break
        # set to value
        m = re.search(r"(?:set|what if|change|increase|decrease|raise|reduce)\s+(.*?)\s*(?:=|to)\s*([-+]?\d*\.?\d+)", part)
        if m:
            feat_frag, val_str = m.group(1), m.group(2)
            fname = match_feature_name(feat_frag, feat_names)
            if fname is None:
                changes.append(f"Couldn't identify which feature to set from: '{feat_frag}'.")
                continue
            # NOTE(review): parse_numeric is not defined in this file's visible
            # code — presumably a helper defined elsewhere (or lost); confirm.
            val = parse_numeric(val_str)
            if val is None:
                changes.append(f"Couldn't parse a number from: '{val_str}'.")
                continue
            idx = feat_names.index(fname)
            new[idx] = val
            changes.append(f"Set **{fname}** to **{val:.2f}**.")
            continue
        # +/- absolute
        m = re.search(r"(increase|decrease|raise|reduce)\s+(.*?)\s+by\s+([-+]?\d*\.?\d+)\b(?!%)", part)
        if m:
            op, feat_frag, val_str = m.groups()
            fname = match_feature_name(feat_frag, feat_names)
            if fname is None:
                changes.append(f"Couldn't identify which feature to adjust from: '{feat_frag}'.")
                continue
            delta = parse_numeric(val_str)
            if delta is None:
                changes.append(f"Couldn't parse a number from: '{val_str}'.")
                continue
            delta *= scale
            # Direction comes from the verb, not from the sign the user typed.
            delta = -abs(delta) if op in ["decrease", "reduce"] else abs(delta)
            idx = feat_names.index(fname)
            new[idx] = new[idx] + delta
            changes.append(f"{'Increased' if delta>0 else 'Decreased'} **{fname}** by **{abs(delta):.2f}** β†’ **{new[idx]:.2f}**.")
            continue
        # +/- percent
        m = re.search(r"(increase|decrease|raise|reduce)\s+(.*?)\s+by\s+([-+]?\d*\.?\d+)\s*%", part)
        if m:
            op, feat_frag, perc_str = m.groups()
            fname = match_feature_name(feat_frag, feat_names)
            if fname is None:
                changes.append(f"Couldn't identify which feature to adjust from: '{feat_frag}'.")
                continue
            perc = parse_numeric(perc_str)
            if perc is None:
                changes.append(f"Couldn't parse a percentage from: '{perc_str}'.")
                continue
            perc *= scale
            idx = feat_names.index(fname)
            factor = 1.0 + (abs(perc)/100.0 if op in ["increase","raise"] else -abs(perc)/100.0)
            new[idx] = new[idx] * factor
            changes.append(f"{op.title()}d **{fname}** by **{abs(perc):.0f}%** β†’ **{new[idx]:.2f}**.")
            continue
        # make higher/lower (no explicit amount: use default_delta, scaled)
        m = re.search(r"(make|set)?\s*(.*?)\s*(higher|lower|increase|decrease|raise|reduce)", part)
        if m:
            _, feat_frag, direction = m.groups()
            fname = match_feature_name(feat_frag, feat_names)
            if fname is None:
                changes.append(f"Couldn't identify which feature to adjust from: '{feat_frag}'.")
                continue
            idx = feat_names.index(fname)
            base_delta = default_delta(current_features[idx]) * scale
            delta = base_delta if direction in ["higher", "increase", "raise"] else -abs(base_delta)
            new[idx] = new[idx] + delta
            verb = "Increased" if delta > 0 else "Decreased"
            changes.append(f"{verb} **{fname}** by **{abs(delta):.2f}** (scaled {scale:.1f}Γ—) β†’ **{new[idx]:.2f}**.")
            continue
        # Nothing matched: record the failure so the user sees what was ignored.
        changes.append(f"Couldn't parse instruction: '{part}'.")
    if not changes:
        return None, "No valid feature changes detected."
    return new, "\n".join(changes)
# -----------------------------
# App state init
# -----------------------------
st.subheader("Step 1 – Enter features to generate a prediction in the order: [sepal length, sepal width, petal length, petal width]")
user_features = st.text_input("Enter feature values (comma-separated)", "")
predict_btn = st.button("🔍 Predict and Explain")
# Bug fix: the original assigned [] / None unconditionally, which wiped the
# chat history and entered features on EVERY Streamlit rerun (i.e. after each
# user interaction). setdefault initializes once per session instead.
st.session_state.setdefault("messages", [])
st.session_state.setdefault("input_features", None)
# Also seed the prediction slot read by Steps 2-3 so the first rerun does not
# hit a missing key.
st.session_state.setdefault("prediction_result", None)
# -----------------------------
# Step 1: Predict
# -----------------------------
if predict_btn:
    try:
        # Parse the comma-separated input, skipping empty fragments.
        features = [float(x.strip()) for x in user_features.split(",") if x.strip()]
        # NOTE(review): reconstructed guard — the original length-check line
        # was lost in this copy; the warning message below survives from it.
        if len(features) != len(feat_names):
            st.warning(f"Expected {len(feat_names)} values ({', '.join(feat_names)}), but got {len(features)}.")
        else:
            st.session_state["input_features"] = features
            st.markdown("### ✨ Entered Features")
            st.markdown("\n".join([f"- **{n}** = {v:.2f}" for n, v in zip(feat_names, features)]))
            res = perform_api_call(features)
            st.session_state["prediction_result"] = res
            st.session_state["messages"] = []  # a new prediction starts a fresh chat
            st.success(summarize_prediction(res))
            st.info("Scroll down to explore similar cases or chat.")
    except Exception as e:
        # Covers float() parse errors and requests failures alike.
        st.error(f"Error contacting API: {e}")
# -----------------------------
# Step 2 + Step 3
# -----------------------------
# NOTE(review): this section is reconstructed — several original lines (the
# similar-case rendering, the System-mode explain branch, the what-if
# recompute wiring and the before/after variable bindings) were missing from
# this copy. Control flow inferred from the surviving fragments; verify
# against the original app.
if st.session_state.get("prediction_result"):
    base_res = st.session_state["prediction_result"]

    st.divider()
    st.subheader("Step 2 – Explore similar cases")
    if safe_similar_cases(base_res):
        st.markdown(show_similar_cases(base_res, k, feat_names))
    else:
        st.write("No similar cases retrieved.")

    st.divider()
    st.subheader("Step 3 – Chat with the model about this prediction")
    if "chat_mode" not in st.session_state:
        st.session_state["chat_mode"] = "System"
    st.session_state["chat_mode"] = st.radio(
        "How should explanations be generated?",
        ["System", "LLM (Natural language)"],
        index=0 if st.session_state["chat_mode"] == "System" else 1,
        horizontal=True,
    )

    # Replay the conversation so far.
    for role, content in st.session_state["messages"]:
        with st.chat_message(role):
            st.markdown(content)

    if user_q := st.chat_input("Ask e.g. 'Why this prediction?' or 'Increase petal length by 0.3' or 'set sepal width to 3.8'"):
        st.session_state["messages"].append(("user", user_q))
        with st.chat_message("user"):
            st.markdown(user_q)

        intent = interpret_question(user_q)
        base_pred = base_res["prediction"]["y_pred"]
        base_proba = base_res["prediction"]["proba"]
        base_label = label_from_pred(base_pred)

        if intent == "explain":
            if "LLM" in st.session_state["chat_mode"]:
                with st.spinner("Generating Nova explanation..."):
                    answer = llm_explain(base_res, feat_names)
            else:
                # System mode: deterministic SHAP-grounded summary.
                answer = summarize_prediction(base_res)
            st.session_state["messages"].append(("assistant", answer))
            with st.chat_message("assistant"):
                st.markdown(answer)

        elif intent == "what_if":
            if st.session_state["input_features"] is None:
                msg = "Please run a prediction first (Step 1) so I know your starting feature values."
                st.session_state["messages"].append(("assistant", msg))
                with st.chat_message("assistant"):
                    st.markdown(msg)
            else:
                new_feats, status = apply_what_if(user_q, feat_names, st.session_state["input_features"])
                if new_feats is None:
                    # Nothing parseable: echo the parser's explanation.
                    st.session_state["messages"].append(("assistant", status))
                    with st.chat_message("assistant"):
                        st.markdown(status)
                else:
                    with st.spinner("Recomputing with your change..."):
                        new_res = perform_api_call(new_feats)
                    new_pred = new_res["prediction"]["y_pred"]
                    new_proba = new_res["prediction"]["proba"]
                    new_label = label_from_pred(new_pred)
                    if "LLM" in st.session_state["chat_mode"]:
                        ctx = {
                            "change_applied": status,
                            "before": {"features": st.session_state["input_features"], "label": base_label, "proba": base_proba},
                            "after": {"features": new_feats, "label": new_label, "proba": new_proba},
                        }
                        with st.spinner("Summarizing the effect with Nova..."):
                            answer = llm_explain(new_res, feat_names, extra_context=ctx)
                        st.session_state["messages"].append(("assistant", answer))
                        with st.chat_message("assistant"):
                            st.markdown(answer)
                    else:
                        lines = [
                            f"**Change applied:** {status}",
                            f"**Before:** {base_label} (class `{base_pred}`) — confidence **{base_proba:.2f}**",
                            f"**After:** {new_label} (class `{new_pred}`) — confidence **{new_proba:.2f}**",
                        ]
                        with st.chat_message("assistant"):
                            st.markdown("\n\n".join(lines))
                        st.success(summarize_prediction(new_res))
                        st.session_state["messages"].append(("assistant", "What-if comparison + SHAP shown above."))

        else:
            # "similar" and generic questions fall back to the system summary.
            summary = summarize_prediction(base_res)
            st.session_state["messages"].append(("assistant", summary))
            with st.chat_message("assistant"):
                st.markdown(summary)