# NOTE(review): page-scrape artifact — the "Spaces" header and "Build error"
# status lines from the Hugging Face page were captured along with the app
# source; they are not part of the program.
# Standard library, plotting/UI, HTTP client, and the LiteLLM gateway used to
# reach Amazon Nova. (Original order preserved; scrape pipes removed.)
import re
import json
import matplotlib.pyplot as plt
import streamlit as st
import requests
import os
from litellm import completion  # pip install litellm
st.set_page_config(page_title="EchoML", page_icon="π¬", layout="wide")
st.title("π¬ Chat with Your Model (IRIS Edition)")

# -----------------------------
# Sidebar configuration
# -----------------------------
# The names bound here (api_url, model_path, feat_names, namespace, alpha, k,
# nova_model_id, temperature, max_tokens) are module-level settings read by
# the helper functions below on every Streamlit rerun.
with st.sidebar:
    st.header("Settings")
    # FastAPI endpoint that returns the prediction + SHAP explanation payload.
    api_url = st.text_input(
        "FastAPI endpoint",
        value="https://query-your-model-api-784882848382.us-central1.run.app/explain",
    )
    # Server-side path to the pickled model the API should load.
    model_path = st.text_input("Model path", value="Query_Your_Model/model_data/model.pkl")
    feat_names_str = st.text_input(
        "Feature names (comma-separated)",
        value="sepal length (cm),sepal width (cm),petal length (cm),petal width (cm)",
    )
    namespace = st.text_input("Namespace", value="Query_Your_Model/data/base_indices/iris_global")
    # IMPORTANT: your retrieval.py shows:
    # similarity = alpha * cos(SHAP) + (1 - alpha) * cos(features)
    alpha = st.slider("Alpha (retrieval weight: SHAP vs features)", 0.0, 1.0, 0.7, 0.05)
    k = st.slider("Top-K similar to retrieve", 1, 10, 5)
    st.divider()
    st.subheader("Nova (LLM) Settings")
    # Model IDs based on your Nova console screenshot
    nova_model_id = st.selectbox(
        "Nova model",
        options=["nova-micro-v1", "nova-lite-v1", "nova-pro-v1", "nova-premier-v1"],
        index=1,
        help="These are the Nova model IDs from the Nova developer console.",
    )
    temperature = st.slider("LLM temperature", 0.0, 1.0, 0.2, 0.05)
    max_tokens = st.slider("LLM max tokens", 64, 1024, 350, 32)

# Parsed feature-name list used throughout the app (order matters downstream).
feat_names = [s.strip() for s in feat_names_str.split(",")]
# -----------------------------
# Helpers
# -----------------------------
def label_from_pred(y_pred):
    """
    Map a numeric model prediction to a human-readable iris class name.

    NOTE(review): the original body was truncated in extraction right after
    the int() conversion; the setosa/versicolor/virginica mapping below is
    reconstructed from the app's iris context — confirm against the original
    repository. Falls back to str(y_pred) for non-numeric or out-of-range
    inputs instead of raising.
    """
    try:
        num = int(round(float(y_pred)))
    except (TypeError, ValueError):
        return str(y_pred)
    return {0: "setosa", 1: "versicolor", 2: "virginica"}.get(num, str(y_pred))
def show_similar_cases(res, n_display, feat_names):
    """Render a markdown summary of the retrieved reference cases in *res*."""
    cases = safe_similar_cases(res)
    if not cases:
        return "No similar cases were retrieved."
    shown = min(n_display, len(cases))
    out = [f"It found **{len(cases)}** similar reference cases (showing **{shown}**):"]
    for case in cases[:shown]:
        described = ", ".join(
            f"{name} = {val:.2f}" for name, val in zip(feat_names, case["features"])
        )
        out.append(f"- **{case['case_id']}** β {described}, predicted as **{case['y_pred']}**.")
    return "\n".join(out)
def plot_shap_bar(topk):
    """Draw a horizontal bar chart of per-feature SHAP values into the Streamlit page."""
    labels = [entry["feature"] for entry in topk]
    values = [entry["shap"] for entry in topk]
    figure, axis = plt.subplots()
    axis.barh(labels, values)  # default colors
    axis.set_xlabel("SHAP value (impact on prediction)")
    axis.set_title("Feature importance for this prediction")
    st.pyplot(figure)
# NOTE(review): orphaned fragment — this appears to be the tail of a display
# helper (likely a `summarize_prediction`-style function) whose `def` line and
# the bindings for `label`, `pred`, `proba`, `topk`, `sims`, and `n` were lost
# in extraction. As written at module level it raises NameError; restore the
# enclosing function from the original source before running.
msg = [
    f"πΈ Based on these features, the model thinks it's **{label}** (class `{pred}`) with confidence **{proba:.2f}**.\n",
    "### Key reasons (SHAP):",
]
if topk:
    # NOTE(review): this loop lists similar cases rather than SHAP reasons —
    # presumably more of the original body is missing between the header and
    # this point. TODO confirm.
    for case in sims[:n]:
        features_named = ", ".join([f"{name} = {val:.2f}" for name, val in zip(feat_names, case["features"])])
        msg.append(f"- **{case['case_id']}** β {features_named}, predicted as **{case['y_pred']}**.")
st.markdown("\n".join(msg))
plot_shap_bar(topk)
def interpret_question(user_q):
    """
    Classify a chat question into an intent: "what_if", "explain",
    "similar", or "summary" (the fallback). What-if cues are checked first,
    so a question mixing "why" with "increase" is treated as a what-if.
    """
    text = (user_q or "").lower()
    what_if_cues = (
        "what if", "increase", "decrease", "set ", "make ",
        "higher", "lower", "raise", "reduce", "change",
    )
    if any(cue in text for cue in what_if_cues):
        return "what_if"
    if any(cue in text for cue in ("why", "explain", "reason")):
        return "explain"
    return "similar" if "similar" in text else "summary"
def perform_api_call(features):
    """
    POST the feature vector to the configured FastAPI endpoint and return the
    decoded JSON response. Raises requests.HTTPError on non-2xx status.
    Reads module-level settings: api_url, model_path, feat_names, namespace,
    alpha, k.
    """
    request_body = {
        "model_path": model_path,
        "feature_names": feat_names,
        "features": features,
        "namespace": namespace,
        "retrieval": {"alpha": alpha, "k": k, "use_retrieval": True, "namespace": namespace},
    }
    response = requests.post(api_url, json=request_body, timeout=60)
    response.raise_for_status()
    return response.json()
# -----------------------------
# Nova LLM (via LiteLLM)
# -----------------------------
def nova_llm_text(prompt: str) -> str:
    """
    Send *prompt* to an Amazon Nova model via LiteLLM and return the reply text.

    Reads the API key from the AMAZON_NOVA_API_KEY environment variable and
    returns a human-readable setup hint (instead of raising) when it is absent.
    Model id, temperature, and max_tokens come from the sidebar widgets.
    Any LiteLLM failure is caught and returned as an error string so the chat
    UI never crashes on LLM errors.

    FIX: removed the redundant `os.environ["AMAZON_NOVA_API_KEY"] = api_key`
    line — it wrote back the exact value just read from that same environment
    variable, a no-op (LiteLLM already picks the key up from the env).
    """
    api_key = os.getenv("AMAZON_NOVA_API_KEY")
    if not api_key:
        return (
            "Nova API key not found. In Hugging Face β Settings β Variables and secrets, "
            "add a *Secret* named `AMAZON_NOVA_API_KEY`."
        )
    try:
        resp = completion(
            model=f"amazon_nova/{nova_model_id}",
            messages=[{"role": "user", "content": prompt}],
            temperature=temperature,
            max_tokens=max_tokens,
        )
        return resp.choices[0].message.content
    except Exception as e:
        return f"LLM explanation failed: {e}"
def llm_explain(res, feat_names, extra_context=None):
    """
    Build a JSON-grounded prompt from the /explain API response *res* and
    return Nova's natural-language explanation.

    NOTE(review): the original bindings for `topk`/`sims` and the opening of
    the `base_prompt` dict were lost in extraction; they are reconstructed
    defensively below — confirm the exact payload keys against the FastAPI
    /explain response schema. Any failure is returned as an error string so
    the chat UI never crashes.
    """
    try:
        pred = label_from_pred(res["prediction"]["y_pred"])
        proba = res["prediction"]["proba"]
        # Reconstructed: top-k SHAP entries and similar cases from the payload.
        topk = res.get("topk") or res.get("explanation", {}).get("topk", [])
        sims = safe_similar_cases(res)
        base_prompt = {
            "predicted_label": pred,
            "probability": round(proba, 3),
            "topk": topk,
            "similar_examples_sample": sims[:3],
            "extra_context": extra_context or {},
            "retrieval_note": "alpha=1.0 prioritizes SHAP similarity; alpha=0.0 prioritizes feature similarity.",
        }
        prompt = (
            "You are an explainability copilot. Explain to a non-technical user.\n\n"
            f"DATA:\n{json.dumps(base_prompt, indent=2)}\n\n"
            "Write a short, clear answer that covers:\n"
            "- Why the model made the prediction (grounded in SHAP)\n"
            "- Which features mattered most\n"
            "- Why those features mattered\n"
            "- 2-3 concrete experiments: tell them which feature values to increase/decrease and what to watch for"
        )
        return nova_llm_text(prompt)
    except Exception as e:
        return f"LLM explanation failed: {e}"
# -----------------------------
# What-if parsing
# -----------------------------
# Pattern for a candidate feature-name token (letters, digits, spaces,
# underscores, hyphens, parentheses).
FEATURE_NAME_PAT = re.compile(r"([a-zA-Z][a-zA-Z0-9 _\-\(\)]*)")


def match_feature_name(fragment, feat_names):
    """
    Resolve a free-text fragment (e.g. "petal width") to one of *feat_names*
    (e.g. "petal width (cm)"); returns None when nothing matches.

    NOTE(review): the original scoring loop body was truncated in extraction
    (it initialized `best`/`best_score` then fell straight to `return None`).
    Reconstructed as substring + shared-word-token scoring — verify against
    the original repository if available.
    """
    frag = fragment.strip().lower()
    if not frag:
        return None
    best, best_score = None, -1
    frag_tokens = set(frag.split())
    for name in feat_names:
        candidate = name.lower()
        if candidate == frag:
            return name  # exact match wins immediately
        # Score by shared word tokens, with a bonus for a substring relation.
        score = len(frag_tokens & set(candidate.split()))
        if frag in candidate or candidate in frag:
            score += 2
        if score > 0 and score > best_score:
            best, best_score = name, score
    return best
| def default_delta(curr): | |
| base = abs(curr) * 0.10 | |
| return base if base >= 0.1 else (0.1 if curr >= 0 else -0.1) | |
def apply_what_if(user_q, feat_names, current_features):
    """
    Parse a natural-language what-if request and apply it to a feature vector.

    Returns (new_features, message): new_features is a modified copy of
    current_features (None when no clause parsed) and message is a markdown
    summary of each change or parse failure, one line per clause.

    Relies on module-level helpers match_feature_name and default_delta, and
    on parse_numeric — which is NOT visible in this file; presumably defined
    elsewhere in the project. TODO confirm.
    """
    q = user_q.lower()
    new = current_features.copy()
    changes = []
    # Split compound requests: "increase X by 1, decrease Y and set Z to 3".
    parts = re.split(r",| and ", q)
    # Adverb -> multiplier applied to absolute, percent, and default deltas.
    modifier_scale = {
        "slightly": 0.5,
        "a bit": 0.5,
        "a little": 0.5,
        "somewhat": 0.7,
        "moderately": 1.0,
        "significantly": 1.5,
        "greatly": 2.0,
        "a lot": 2.0,
    }
    for part in parts:
        part = part.strip()
        if not part:
            continue
        scale = 1.0
        # First matching modifier wins; it is stripped from the clause so it
        # cannot confuse the feature-name match below.
        for mod, factor in modifier_scale.items():
            if mod in part:
                scale = factor
                part = part.replace(mod, "")
                break
        # Case 1: set to an explicit value ("set petal width to 1.8").
        m = re.search(r"(?:set|what if|change|increase|decrease|raise|reduce)\s+(.*?)\s*(?:=|to)\s*([-+]?\d*\.?\d+)", part)
        if m:
            feat_frag, val_str = m.group(1), m.group(2)
            fname = match_feature_name(feat_frag, feat_names)
            if fname is None:
                changes.append(f"Couldn't identify which feature to set from: '{feat_frag}'.")
                continue
            val = parse_numeric(val_str)
            if val is None:
                changes.append(f"Couldn't parse a number from: '{val_str}'.")
                continue
            idx = feat_names.index(fname)
            new[idx] = val
            changes.append(f"Set **{fname}** to **{val:.2f}**.")
            continue
        # Case 2: adjust by an absolute amount ("increase X by 0.3");
        # the \b(?!%) keeps percentage clauses for Case 3.
        m = re.search(r"(increase|decrease|raise|reduce)\s+(.*?)\s+by\s+([-+]?\d*\.?\d+)\b(?!%)", part)
        if m:
            op, feat_frag, val_str = m.groups()
            fname = match_feature_name(feat_frag, feat_names)
            if fname is None:
                changes.append(f"Couldn't identify which feature to adjust from: '{feat_frag}'.")
                continue
            delta = parse_numeric(val_str)
            if delta is None:
                changes.append(f"Couldn't parse a number from: '{val_str}'.")
                continue
            delta *= scale
            # Sign comes from the verb, not from the typed number.
            delta = -abs(delta) if op in ["decrease", "reduce"] else abs(delta)
            idx = feat_names.index(fname)
            new[idx] = new[idx] + delta
            changes.append(f"{'Increased' if delta>0 else 'Decreased'} **{fname}** by **{abs(delta):.2f}** β **{new[idx]:.2f}**.")
            continue
        # Case 3: adjust by a percentage ("reduce X by 20%").
        m = re.search(r"(increase|decrease|raise|reduce)\s+(.*?)\s+by\s+([-+]?\d*\.?\d+)\s*%", part)
        if m:
            op, feat_frag, perc_str = m.groups()
            fname = match_feature_name(feat_frag, feat_names)
            if fname is None:
                changes.append(f"Couldn't identify which feature to adjust from: '{feat_frag}'.")
                continue
            perc = parse_numeric(perc_str)
            if perc is None:
                changes.append(f"Couldn't parse a percentage from: '{perc_str}'.")
                continue
            perc *= scale
            idx = feat_names.index(fname)
            # Multiplicative change: verb decides the direction of the factor.
            factor = 1.0 + (abs(perc)/100.0 if op in ["increase","raise"] else -abs(perc)/100.0)
            new[idx] = new[idx] * factor
            changes.append(f"{op.title()}d **{fname}** by **{abs(perc):.0f}%** β **{new[idx]:.2f}**.")
            continue
        # Case 4: vague direction ("make petal length higher") — nudge by
        # default_delta (10% of current, floored at 0.1), scaled by modifier.
        m = re.search(r"(make|set)?\s*(.*?)\s*(higher|lower|increase|decrease|raise|reduce)", part)
        if m:
            _, feat_frag, direction = m.groups()
            fname = match_feature_name(feat_frag, feat_names)
            if fname is None:
                changes.append(f"Couldn't identify which feature to adjust from: '{feat_frag}'.")
                continue
            idx = feat_names.index(fname)
            base_delta = default_delta(current_features[idx]) * scale
            delta = base_delta if direction in ["higher", "increase", "raise"] else -abs(base_delta)
            new[idx] = new[idx] + delta
            verb = "Increased" if delta > 0 else "Decreased"
            changes.append(f"{verb} **{fname}** by **{abs(delta):.2f}** (scaled {scale:.1f}Γ) β **{new[idx]:.2f}**.")
            continue
        changes.append(f"Couldn't parse instruction: '{part}'.")
    if not changes:
        return None, "No valid feature changes detected."
    return new, "\n".join(changes)
# -----------------------------
# App state init
# -----------------------------
st.subheader("Step 1 β Enter features to generate a prediction in the order: [sepal length, sepal width, petal length, petal width]")
user_features = st.text_input("Enter feature values (comma-separated)", "")
predict_btn = st.button("π Predict and Explain")
# BUG FIX: the original unconditionally reset "messages" and "input_features"
# here, which runs on EVERY Streamlit rerun — wiping the chat history after
# each interaction — and never initialized "prediction_result", so Step 2/3
# raised KeyError before the first prediction. setdefault() initializes each
# key exactly once per session and leaves existing values untouched.
st.session_state.setdefault("messages", [])
st.session_state.setdefault("input_features", None)
st.session_state.setdefault("prediction_result", None)
# -----------------------------
# Step 1: Predict
# -----------------------------
if predict_btn:
    try:
        # Parse comma-separated floats, ignoring empty fragments.
        features = [float(x.strip()) for x in user_features.split(",") if x.strip()]
        # NOTE(review): reconstructed guard — the original `if` line was lost
        # in extraction, but the warning/else structure implies a length check
        # of the parsed vector against feat_names.
        if len(features) != len(feat_names):
            st.warning(f"Expected {len(feat_names)} values ({', '.join(feat_names)}), but got {len(features)}.")
        else:
            st.session_state["input_features"] = features
            st.markdown("### β¨ Entered Features")
            st.markdown("\n".join([f"- **{n}** = {v:.2f}" for n, v in zip(feat_names, features)]))
            res = perform_api_call(features)
            st.session_state["prediction_result"] = res
            # A fresh prediction starts a fresh chat transcript.
            st.session_state["messages"] = []
            st.success(summarize_prediction(res))
            st.info("Scroll down to explore similar cases or chat.")
    except Exception as e:
        # Covers float-parse failures and HTTP/network errors alike.
        st.error(f"Error contacting API: {e}")
# -----------------------------
# Step 2 + Step 3
# -----------------------------
# NOTE(review): this whole section was damaged in extraction — several lines
# are missing, leaving undefined names (`text`, `status`, `new_feats`,
# `base_label`, `new_label`, `new_proba`), dangling `else:` branches, and an
# unterminated list literal. The indentation below is a best-effort
# reconstruction of the original structure; restore this section from the
# original source before running.
if st.session_state["prediction_result"]:
    st.divider()
    st.subheader("Step 2 β Explore similar cases")
else:
    # NOTE(review): the similar-cases rendering this `else` paired with
    # (presumably an `if sims:` block calling show_similar_cases) is missing.
    st.write("No similar cases retrieved.")
st.divider()
st.subheader("Step 3 β Chat with the model about this prediction")
if "chat_mode" not in st.session_state:
    st.session_state["chat_mode"] = "System"
st.session_state["chat_mode"] = st.radio(
    "How should explanations be generated?",
    ["System", "LLM (Natural language)"],
    index=0 if st.session_state["chat_mode"] == "System" else 1,
    horizontal=True,
)
# Replay the stored chat transcript on each rerun.
for role, content in st.session_state["messages"]:
    with st.chat_message(role):
        st.markdown(content)
if user_q := st.chat_input("Ask e.g. 'Why this prediction?' or 'Increase petal length by 0.3' or 'set sepal width to 3.8'"):
    st.session_state["messages"].append(("user", user_q))
    with st.chat_message("user"):
        # NOTE(review): the user-message markdown call is missing; the intent
        # handling below was probably outside this `with` in the original.
        intent = interpret_question(user_q)
        base_res = st.session_state["prediction_result"]
        base_pred = base_res["prediction"]["y_pred"]
        base_proba = base_res["prediction"]["proba"]
        if intent == "explain":
            if "LLM" in st.session_state["chat_mode"]:
                with st.spinner("Generating Nova explanation..."):
                    answer = llm_explain(base_res, feat_names)
                st.session_state["messages"].append(("assistant", answer))
                with st.chat_message("assistant"):
                    st.markdown(text)  # NOTE(review): `text` is undefined — presumably should be `answer`
        elif intent == "what_if":
            if st.session_state["input_features"] is None:
                msg = "Please run a prediction first (Step 1) so I know your starting feature values."
                st.session_state["messages"].append(("assistant", msg))
                with st.chat_message("assistant"):
                    st.markdown(status)  # NOTE(review): `status` is undefined here — presumably should be `msg`
            else:
                # NOTE(review): the apply_what_if(...) call that produced
                # `new_feats` and `status` is missing before this point, as
                # are the `base_label`/`new_label`/`new_proba` bindings.
                with st.spinner("Recomputing with your change..."):
                    new_res = perform_api_call(new_feats)
                    new_pred = new_res["prediction"]["y_pred"]
                ctx = {
                    "change_applied": status,
                    "before": {"features": st.session_state["input_features"], "label": base_label, "proba": base_proba},
                    "after": {"features": new_feats, "label": new_label, "proba": new_proba},
                }
                with st.spinner("Summarizing the effect with Nova..."):
                    answer = llm_explain(new_res, feat_names, extra_context=ctx)
                st.session_state["messages"].append(("assistant", answer))
                with st.chat_message("assistant"):
                    st.markdown(answer)
        else:
            # NOTE(review): truncated "System mode" what-if summary — this
            # `else:` likely paired with a missing `if "LLM" in ...:` check,
            # and the `lines` list literal below is unterminated in the
            # source (the append statement sits inside the open bracket).
            lines = [
                f"**Change applied:** {status}",
                f"**Before:** {base_label} (class `{base_pred}`) β confidence **{base_proba:.2f}**",
            st.session_state["messages"].append(("assistant", "What-if comparison + SHAP shown above."))
        else:
            # Fallback intent ("summary"): re-state the base prediction.
            summary = summarize_prediction(base_res)
            st.session_state["messages"].append(("assistant", summary))
            with st.chat_message("assistant"):
                st.markdown(summary)