File size: 7,712 Bytes
458c8e2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 |
import os
from typing import Any, Dict, List, Optional
import pandas as pd
from langchain_core.messages import HumanMessage
from langchain_google_genai import ChatGoogleGenerativeAI
from agent.agent_graph import build_app
from pipeline.utils_cool import df_to_payload, parse_user_choice
from .runtime_ctx import get_df_summary
from dotenv import load_dotenv
load_dotenv()
class ChatbotHandler:
    """Glue between a chat UI and the LangGraph app returned by ``build_app``.

    The handler keeps one mutable state dict in ``self.ctx["state"]``. That
    state is seeded when a dataset is uploaded (``update_context``) and then
    threaded through one graph invocation per chat message (``respond``).
    """

    # Keys copied from the graph's final state back into the persistent state
    # after every ``invoke``.
    _PERSISTED_KEYS = (
        "df_payload",
        "results",
        "steps_taken",
        "confirmed_step",
        "confirmed_params",
        "last_task",
        "plan",
        "messages",
    )

    # Baseline step budget handed to the graph.
    _DEFAULT_MAX_STEPS = 8

    # Lower-cased role/type markers that identify an assistant message.
    _AI_ROLES = frozenset({"ai", "assistant", "aimessage"})

    # Task names for which no further task confirmation is requested.
    _KNOWN_TASKS = frozenset({"classification", "regression", "unsupervised"})

    # Tasks that need a label column.
    _LABELLED_TASKS = frozenset({"classification", "regression"})

    def __init__(self):
        """Create an empty context and the LLM that is handed to the graph."""
        self.ctx: Dict[str, Any] = {
            "graph_app": None,  # LangGraph app, built on first upload
            "state": {
                "df_payload": None,
                "results": [],
                "steps_taken": 0,
                "confirmed_step": None,
                "confirmed_params": {},
                "last_task": None,
                "plan": None,
                "messages": [],
                "max_steps": self._DEFAULT_MAX_STEPS,
            },
        }
        # Chat history lives in the UI component; this handler only appends
        # its reply to the history it is given. ``_boot_text`` holds the first
        # reply prepared by ``update_context`` until it has been shown.
        self._boot_text: Optional[str] = None

        # LLM for the graph; the API key is read from the environment.
        self.llm = ChatGoogleGenerativeAI(
            model="gemini-2.5-flash-lite",
            temperature=0,
            api_key=os.getenv("GOOGLE_API_KEY"),
        )

    def _format_summary(self, s: Dict[str, Any]) -> str:
        """Render a dataset summary dict as a short Markdown bullet list.

        Reads the keys ``columns``, ``dtypes``, ``shape``, ``label_guess``,
        ``task_guess`` and ``issues``; every key is optional. At most 10
        columns, 8 dtypes and 3 issues are listed, with an ellipsis marking
        truncation.
        """
        cols = list(s.get("columns") or [])
        dtypes = s.get("dtypes") or {}
        shape = tuple(s.get("shape") or ())
        label_guess = s.get("label_guess") or "None"
        task_guess = s.get("task_guess") or "Unknown"
        issues = list(s.get("issues") or [])

        # A missing or short shape must not raise; show None for unknown dims.
        n_rows = shape[0] if len(shape) > 0 else None
        n_cols = shape[1] if len(shape) > 1 else None

        # Keep it concise but helpful.
        dt_pairs = [f"{k}: {v}" for k, v in list(dtypes.items())[:8]]
        if len(dtypes) > 8:
            dt_pairs.append("…")

        col_text = ", ".join(map(str, cols[:10])) + ("…" if len(cols) > 10 else "")
        lines = [
            "### Dataset summary",
            f"- Shape: {n_rows} rows × {n_cols} columns",
            f"- Columns: {col_text}",
            f"- Dtypes: {', '.join(dt_pairs)}",
            f"- Label guess: {label_guess}",
            f"- Task guess: {task_guess}",
        ]
        if issues:
            # str() each entry so non-string issue objects cannot break join().
            issue_text = "; ".join(map(str, issues[:3]))
            lines.append(
                f"- Potential issues: {issue_text}{'…' if len(issues) > 3 else ''}"
            )
        return "\n".join(lines)

    def _build_confirmation_prompt(self, s: Dict[str, Any]) -> str:
        """Return the follow-up question appended to the dataset summary.

        Asks the user to supply the task (and label column, when the guessed
        task needs one but none was guessed); otherwise asks the user to
        confirm the guessed task/label.
        """
        task_guess = (s.get("task_guess") or "").lower()
        label_guess = s.get("label_guess")

        needs_task = task_guess not in self._KNOWN_TASKS
        needs_label = task_guess in self._LABELLED_TASKS and not label_guess

        if needs_task or needs_label:
            label_part = " and label column" if needs_label else ""
            return (
                "\n\nPlease confirm the task"
                + label_part
                + ". For example: `task=classification label=noisy_letter_grade`."
            )

        closing = f" label={label_guess}`" if label_guess else "`"
        return (
            f"\n\nIf that looks right, say `confirm task={task_guess}"
            + closing
            + " and I’ll fetch SOTA and propose a plan."
        )

    # ---------------- Boot on upload: run the graph once ----------------
    def update_context(self, file_path: Optional[str], data_type: Optional[str], df: Optional["pd.DataFrame"]) -> str:
        """Rebuild the graph for a newly uploaded dataset and return the boot text.

        ``file_path`` and ``data_type`` are accepted but not used here.

        Returns an empty string when ``df`` is None. Otherwise returns the
        dataset summary followed by a confirmation prompt; the same text is
        also stored in ``self._boot_text``.
        """
        if df is None:
            return ""

        # (Re)build graph and reset the persistent state for the new dataset.
        self.ctx["graph_app"] = build_app(self.llm)
        st = self.ctx["state"]
        st.update({
            "df_payload": df_to_payload(df),
            "results": [],
            "steps_taken": 0,
            "confirmed_step": None,
            "confirmed_params": {},
            "last_task": None,
            "plan": None,
            "messages": [HumanMessage(content="A new dataset was uploaded. Start the workflow.")],
            "max_steps": self._DEFAULT_MAX_STEPS,
        })

        final = self.ctx["graph_app"].invoke(st)
        for key in self._PERSISTED_KEYS:
            st[key] = final.get(key, st.get(key))

        # Build the boot text from the stored summary rather than from the
        # graph's messages.
        summary = get_df_summary() or {}
        self._boot_text = self._format_summary(summary) + self._build_confirmation_prompt(summary)
        return self._boot_text

    # ---------------- One chat turn -> one graph turn ----------------
    def respond(self, message: str, history: List):
        """Handle one user message and return ``(history, "")``.

        ``history`` is mutated in place: a ``(user_text, reply)`` tuple is
        appended. A blank message leaves the history untouched. If no graph
        has been built yet the reply asks the user to upload a dataset.
        """
        if history is None:
            history = []
        msg = (message or "").strip()
        if not msg:
            return history, ""

        # If a boot reply was prepared on upload and the chat is still empty,
        # show it before processing the user's first message.
        if self._boot_text and len(history) == 0:
            history.append(("[system]", self._boot_text))
            self._boot_text = None

        # Require a booted graph.
        graph_app = self.ctx.get("graph_app")
        if graph_app is None:
            history.append((msg, "Please upload a dataset first."))
            return history, ""

        st = self.ctx["state"]

        # Quick "run X a=b" parsing before calling the graph: a recognised
        # step is recorded and its params are merged over earlier ones.
        step, params = parse_user_choice(msg)
        if step:
            st["confirmed_step"] = step
            st["confirmed_params"] = {**(st.get("confirmed_params") or {}), **(params or {})}

        # ``or 0`` also covers an explicit None written back by the graph.
        steps_taken = st.get("steps_taken") or 0

        # Build this turn's input state.
        turn_state = {
            "messages": (st.get("messages") or []) + [HumanMessage(content=msg)],
            "df_payload": st.get("df_payload"),
            "results": st.get("results", []),
            "steps_taken": steps_taken,
            # Always leave room for at least four more steps this turn.
            "max_steps": max(self._DEFAULT_MAX_STEPS, steps_taken + 4),
            "confirmed_step": st.get("confirmed_step"),
            "confirmed_params": st.get("confirmed_params", {}),
            "last_task": st.get("last_task"),
            "plan": st.get("plan"),
        }

        # Invoke graph for this turn.
        final = graph_app.invoke(turn_state)

        # Persist state back, preferring the graph's output, then this turn's
        # input, then whatever was stored before.
        for key in self._PERSISTED_KEYS:
            st[key] = final.get(key, turn_state.get(key, st.get(key)))

        # Extract assistant text; fall back to a generic acknowledgement.
        reply = self._extract_ai_text(final.get("messages", [])) or "Done."
        history.append((msg, reply))
        return history, ""

    # ---------------- helpers: extract last AI string ----------------
    @staticmethod
    def _coerce_text(content: Any) -> str:
        """Flatten message content (None, str, list of parts, other) to a string."""
        if content is None:
            return ""
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts = []
            for part in content:
                if isinstance(part, dict):
                    parts.append(str(part.get("text") or part.get("content") or part.get("data") or ""))
                else:
                    parts.append(str(part))
            return " ".join(p for p in parts if p)
        return str(content)

    def _extract_ai_text(self, messages: List[Any]) -> str:
        """Return the text of the most recent assistant message, or ``""``.

        Messages may be objects exposing ``type``/``role``/``content``
        attributes or plain dicts with the same keys. Both ``type`` and
        ``role`` are checked case-insensitively, so an object whose ``type``
        is a generic value (e.g. ``"chat"``) but whose ``role`` is
        ``"assistant"`` is still recognised.
        """
        for m in reversed(messages or []):
            if isinstance(m, dict):
                markers = (m.get("role"), m.get("type"))
                content = m.get("content")
            else:
                markers = (getattr(m, "type", None), getattr(m, "role", None))
                content = getattr(m, "content", None)
            # Non-string markers are ignored instead of crashing on .lower().
            if any(isinstance(v, str) and v.lower() in self._AI_ROLES for v in markers):
                return self._coerce_text(content)
        return ""
|