|
|
import os |
|
|
from typing import Any, Dict, List, Optional |
|
|
|
|
|
import pandas as pd |
|
|
from langchain_core.messages import HumanMessage |
|
|
from langchain_google_genai import ChatGoogleGenerativeAI |
|
|
|
|
|
from agent.agent_graph import build_app |
|
|
from pipeline.utils_cool import df_to_payload, parse_user_choice |
|
|
|
|
|
from .runtime_ctx import get_df_summary |
|
|
from dotenv import load_dotenv |
|
|
load_dotenv()  # load variables from a local .env at import time (e.g. GOOGLE_API_KEY read below)
|
|
|
|
|
class ChatbotHandler:
    """Bridge between a chat UI and a LangGraph agent workflow over an uploaded dataset.

    Owns a single mutable context dict (``self.ctx``) holding the compiled
    graph app and the agent state that is passed to — and merged back from —
    every graph invocation. ``update_context`` bootstraps the workflow when a
    dataset is uploaded; ``respond`` runs one chat turn through the graph.

    NOTE(review): ``respond`` returns ``(history, "")`` with history entries as
    ``(user, bot)`` tuples — presumably the Gradio tuple-style Chatbot format;
    confirm against the UI wiring.
    """

    def __init__(self) -> None:
        # Shared context: the compiled graph app (built lazily once a dataset
        # exists) plus the agent state dict threaded through each invocation.
        self.ctx: Dict[str, Any] = {
            "graph_app": None,  # set by update_context() via build_app(self.llm)
            "state": {
                "df_payload": None,       # serialized dataframe (df_to_payload)
                "results": [],            # accumulated step/tool results
                "steps_taken": 0,
                "confirmed_step": None,   # step explicitly chosen by the user
                "confirmed_params": {},   # params parsed from user messages
                "last_task": None,
                "plan": None,
                "messages": [],           # LangChain message history
                "max_steps": 8,           # per-invocation step budget
            },
        }

        # Dataset-summary text surfaced once as the first chat entry in
        # respond(), then cleared.
        self._boot_text: Optional[str] = None

        # NOTE(review): if GOOGLE_API_KEY is unset this passes api_key=None and
        # the failure surfaces on first LLM call, not here — consider failing fast.
        self.llm = ChatGoogleGenerativeAI(
            model="gemini-2.5-flash-lite",
            temperature=0,
            api_key=os.getenv("GOOGLE_API_KEY"),
        )

    def _format_summary(self, s: Dict[str, Any]) -> str:
        """Render the dataset-summary dict from ``get_df_summary()`` as Markdown.

        Missing keys fall back to empty/placeholder values, so an empty dict
        yields a valid (if sparse) summary block.
        """
        cols = s.get("columns") or []
        dtypes = s.get("dtypes") or {}
        shape = s.get("shape") or (None, None)
        label_guess = s.get("label_guess") or "None"
        task_guess = s.get("task_guess") or "Unknown"
        issues = s.get("issues") or []

        # Show at most 8 dtype pairs; flag truncation with an ellipsis.
        dt_pairs = [f"{k}: {v}" for k, v in list(dtypes.items())[:8]]
        if len(dtypes) > 8:
            dt_pairs.append("…")

        lines = [
            "### Dataset summary",
            f"- Shape: {shape[0]} rows × {shape[1]} columns",
            # Columns capped at 10, with an ellipsis marker when truncated.
            f"- Columns: {', '.join(map(str, cols[:10]))}{'…' if len(cols) > 10 else ''}",
            f"- Dtypes: {', '.join(dt_pairs)}",
            f"- Label guess: {label_guess}",
            f"- Task guess: {task_guess}",
        ]
        if issues:
            # Cap at three issues to keep the boot message compact.
            lines.append(f"- Potential issues: {('; '.join(issues[:3]))}{'…' if len(issues) > 3 else ''}")
        return "\n".join(lines)

    def update_context(self, file_path: Optional[str], data_type: Optional[str], df: Optional["pd.DataFrame"]) -> str:
        """Reset the workflow for a newly uploaded dataframe and build the boot message.

        Compiles a fresh graph app, reinitializes the agent state, runs one
        graph invocation to ingest the dataset, then formats a summary plus a
        task/label confirmation prompt. Returns the boot text (also cached in
        ``self._boot_text``), or ``""`` when ``df`` is None.

        NOTE(review): ``file_path`` and ``data_type`` are unused here —
        presumably kept to match the upload-callback signature; confirm.
        """
        if df is None:
            return ""

        # Rebuild the graph and reset all per-dataset state for the new upload.
        self.ctx["graph_app"] = build_app(self.llm)
        df_payload = df_to_payload(df)
        st = self.ctx["state"]
        st.update({
            "df_payload": df_payload,
            "results": [],
            "steps_taken": 0,
            "confirmed_step": None,
            "confirmed_params": {},
            "last_task": None,
            "plan": None,
            "messages": [HumanMessage(content="A new dataset was uploaded. Start the workflow.")],
            "max_steps": 8,
        })

        # First invocation: let the graph ingest the dataset, then merge the
        # resulting state keys back into the persistent state dict.
        final = self.ctx["graph_app"].invoke(st)
        for k in ["df_payload","results","steps_taken","confirmed_step","confirmed_params","last_task","plan","messages"]:
            st[k] = final.get(k, st.get(k))

        # Summary comes from module-level runtime context, populated as a side
        # effect of the invocation above (see runtime_ctx.get_df_summary).
        s = get_df_summary() or {}
        summary_text = self._format_summary(s)

        # Decide what confirmation to ask for: an unrecognized task guess needs
        # a task choice; a supervised task guess without a label needs a label.
        task_guess = (s.get("task_guess") or "").lower()
        label_guess = s.get("label_guess")
        needs_task = task_guess not in {"classification", "regression", "unsupervised"}
        needs_label = (task_guess in {"classification", "regression"}) and (not label_guess)

        if needs_task or needs_label:
            ask = "\n\nPlease confirm the task" + (" and label column" if needs_label else "") + \
                ". For example: `task=classification label=noisy_letter_grade`."
        else:
            ask = f"\n\nIf that looks right, say `confirm task={task_guess}" + \
                (f" label={label_guess}`" if label_guess else "`") + \
                " and I’ll fetch SOTA and propose a plan."

        # Cache so respond() can prepend it to the first chat turn.
        self._boot_text = summary_text + ask
        return self._boot_text

    def respond(self, message: str, history: List):
        """Handle one chat turn: parse the user's choice, invoke the graph, append the reply.

        Returns ``(history, "")`` — the updated chat history and an empty
        string to clear the input box. Empty messages are ignored; if no
        dataset/graph exists yet the user is told to upload one first.
        """
        if history is None:
            history = []
        msg = (message or "").strip()
        if not msg:
            return history, ""

        # Surface the cached boot summary exactly once, at the start of the
        # first turn, then drop it.
        if self._boot_text and len(history) == 0:
            history.append(("[system]", self._boot_text))
            self._boot_text = None

        # Guard: no graph means update_context() has not run with a dataset.
        if self.ctx.get("graph_app") is None:
            history.append((msg, "Please upload a dataset first."))
            return history, ""

        st = self.ctx["state"]

        # If the message names a workflow step, record it and merge any params
        # on top of previously confirmed ones.
        step, params = parse_user_choice(msg)
        if step:
            st["confirmed_step"] = step
            st["confirmed_params"] = {**(st.get("confirmed_params") or {}), **params}

        # Build the per-turn state snapshot from persistent state plus the new
        # user message.
        messages = (st.get("messages") or []) + [HumanMessage(content=msg)]
        turn_state = {
            "messages": messages,
            "df_payload": st.get("df_payload"),
            "results": st.get("results", []),
            "steps_taken": st.get("steps_taken", 0),
            # Extend the budget so this turn always has ≥4 steps of headroom.
            "max_steps": max(8, st.get("steps_taken", 0) + 4),
            "confirmed_step": st.get("confirmed_step"),
            "confirmed_params": st.get("confirmed_params", {}),
            "last_task": st.get("last_task"),
            "plan": st.get("plan"),
        }

        final = self.ctx["graph_app"].invoke(turn_state)

        # Merge the invocation result back into persistent state, falling back
        # to the turn snapshot, then the prior persistent value.
        for k in ["df_payload","results","steps_taken","confirmed_step","confirmed_params","last_task","plan","messages"]:
            st[k] = final.get(k, turn_state.get(k, st.get(k)))

        # Last assistant message becomes the reply; "Done." if none found.
        reply = self._extract_ai_text(final.get("messages", [])) or "Done."
        history.append((msg, reply))
        return history, ""

    def _extract_ai_text(self, messages: List[Any]) -> str:
        """Return the text of the most recent assistant message, or ``""``.

        Handles both LangChain message objects (``.type``/``.role`` attrs) and
        plain dicts, and normalizes content that arrives as a string, a list of
        content parts, or an arbitrary object.
        """
        def coerce_text(content: Any) -> str:
            # Flatten any message-content shape into a single plain string.
            if content is None: return ""
            if isinstance(content, str): return content
            if isinstance(content, list):
                parts = []
                for c in content:
                    if isinstance(c, dict):
                        # Content-part dicts may carry text under several keys.
                        parts.append(str(c.get("text") or c.get("content") or c.get("data") or ""))
                    else:
                        parts.append(str(c))
                return " ".join(p for p in parts if p)
            return str(content)

        # Scan newest-first for the first assistant-authored message.
        for m in reversed(messages or []):
            role = getattr(m, "type", None) or getattr(m, "role", None)
            if role in ("ai", "assistant", "aimessage"):
                return coerce_text(getattr(m, "content", None))
            if isinstance(m, dict):
                r = (m.get("role") or m.get("type") or "").lower()
                if r in ("assistant", "ai", "aimessage"):
                    return coerce_text(m.get("content"))
        return ""
|
|
|