# data-analysis-agent / streamlit_app.py
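"""Streamlit front end for the data-analysis agent.

Flow: an uploaded Excel/CSV file is ingested into a per-file DuckDB database by the
colocated source_to_duckdb.py script, the resulting tables are profiled, an OpenAI
model writes a plain-language overview, and a chat panel (duckdb_react_agent) answers
follow-up questions against the same database.
"""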
import html
import os
import sys
import duckdb
import json
import subprocess
from pathlib import Path
from typing import Dict, List, Tuple
import streamlit as st
try:
from dotenv import load_dotenv, find_dotenv
load_dotenv(find_dotenv())
except Exception:
pass
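# .env is optional: on hosted deployments the OpenAI key can instead come from st.secrets
# (see ai_overview_from_context below).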
# ---------- Basic page setup ----------
st.set_page_config(page_title="Excel → Dataset", page_icon="📊", layout="wide")
PRIMARY_DIR = Path(__file__).parent.resolve()
UPLOAD_DIR = PRIMARY_DIR / "uploads"
DB_DIR = PRIMARY_DIR / "dbs"
SCRIPT_PATH = PRIMARY_DIR / "source_to_duckdb.py" # must be colocated
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
DB_DIR.mkdir(parents=True, exist_ok=True)
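# Layout: uploads/ holds the raw uploaded files; dbs/ holds one DuckDB database per upload,
# named "<file stem>.duckdb" (see the UI flow further down).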
st.markdown(
"""
<style>
    .logbox { border: 1px solid #e5e7eb; background:#fafafa; padding:10px; border-radius:12px; }
    .logbox code { white-space: pre-wrap; font-size: 0.85rem; }
    /* used by the chat panel's "View generated SQL" expander */
    .sqlbox { border: 1px solid #e5e7eb; background:#fafafa; padding:10px; border-radius:12px; white-space: pre-wrap; font-family: monospace; font-size: 0.85rem; }
</style>
""",
unsafe_allow_html=True,
)
st.title("Data Analysis Agent")
# --------- Session state helpers ---------
if "processing" not in st.session_state:
st.session_state.processing = False
if "processed_key" not in st.session_state:
st.session_state.processed_key = None
if "last_overview_md" not in st.session_state:
st.session_state.last_overview_md = None
if "last_preview_items" not in st.session_state:
st.session_state.last_preview_items = [] # list of dicts: {'table_ref': str, 'label': str}
def _file_key(uploaded) -> str:
# unique-ish key per upload (name + size)
try:
size = len(uploaded.getbuffer())
except Exception:
size = 0
return f"{uploaded.name}:{size}"
# ---------- DuckDB helpers ----------
def list_user_tables(con: duckdb.DuckDBPyConnection) -> List[str]:
"""
    Robust discovery, in order:
      1) information_schema.tables for BASE TABLE/VIEW in any schema, excluding __* metadata tables
      2) duckdb_tables() as a secondary path
      3) the __file_tables / __excel_tables metadata tables plus an existence check, as a last resort
"""
# 1) information_schema (most portable)
try:
        q = (
            "SELECT table_schema, table_name "
            "FROM information_schema.tables "
            "WHERE table_type IN ('BASE TABLE','VIEW') "
            "AND table_name NOT LIKE '!_!_%' ESCAPE '!' "  # escape underscores so only __* metadata tables are excluded
            "ORDER BY table_schema, table_name"
        )
rows = con.execute(q).fetchall()
names = []
for schema, name in rows:
if (schema or '').lower() == 'main':
names.append(name)
else:
names.append(f'{schema}."{name}"')
if names:
return names
    except Exception:
        pass  # fall through to duckdb_tables()
# 2) duckdb_tables()
try:
        q2 = (
            "SELECT schema_name, table_name "
            "FROM duckdb_tables() "
            "WHERE NOT internal "  # duckdb_tables() lists base tables only; 'internal' flags system tables
            "AND table_name NOT LIKE '!_!_%' ESCAPE '!' "
            "ORDER BY schema_name, table_name"
        )
rows = con.execute(q2).fetchall()
names = []
for schema, name in rows:
if (schema or '').lower() == 'main':
names.append(name)
else:
names.append(f'{schema}."{name}"')
if names:
return names
    except Exception:
        pass  # fall through to the metadata-table fallback
# 3) Fallback to metadata table
try:
meta = con.execute("SELECT DISTINCT table_name FROM __file_tables").fetchall()
names = []
for (t,) in meta:
try:
con.execute(f'SELECT 1 FROM "{t}" LIMIT 1').fetchone()
names.append(t)
except Exception:
continue
return names
except Exception:
# Fallback to legacy excel metadata table if unified not present
try:
meta = con.execute("SELECT DISTINCT table_name FROM __excel_tables").fetchall()
names = []
for (t,) in meta:
try:
con.execute(f'SELECT 1 FROM "{t}" LIMIT 1').fetchone()
names.append(t)
except Exception:
continue
return names
except Exception:
return []
def get_columns(con: duckdb.DuckDBPyConnection, table: str) -> List[Tuple[str,str]]:
# Normalize table name for information_schema lookup
if table.lower().startswith("main."):
tname = table.split('.', 1)[1].strip('"')
schema_filter = 'main'
elif '.' in table:
schema_filter, tname_raw = table.split('.', 1)
tname = tname_raw.strip('"')
else:
schema_filter = 'main'
tname = table.strip('"')
q = (
"SELECT column_name, data_type "
"FROM information_schema.columns "
"WHERE table_schema=? AND table_name=? "
"ORDER BY ordinal_position"
)
return con.execute(q, [schema_filter, tname]).fetchall()
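# Heuristic: a numeric column is treated as a "year" when over 70% of its values
# cast to integers in the 1900-2100 range.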
def detect_year_column(con, table: str, col: str) -> bool:
try:
sql = (
f'SELECT AVG(CASE WHEN TRY_CAST("{col}" AS INTEGER) BETWEEN 1900 AND 2100 '
f'THEN 1.0 ELSE 0.0 END) FROM {table}'
)
v = con.execute(sql).fetchone()[0]
return (v or 0) > 0.7
except Exception:
return False
def role_of_column(con, table: str, col: str, dtype: str) -> str:
d = (dtype or '').upper()
if any(tok in d for tok in ["DATE", "TIMESTAMP"]):
return "date"
if any(tok in d for tok in ["INT", "BIGINT", "DOUBLE", "DECIMAL", "FLOAT", "HUGEINT", "REAL"]):
if detect_year_column(con, table, col):
return "year"
return "numeric"
if any(tok in d for tok in ["CHAR", "STRING", "TEXT", "VARCHAR"]):
try:
sql = f'SELECT COUNT(*), COUNT(DISTINCT "{col}") FROM {table}'
n, nd = con.execute(sql).fetchone()
if n and nd is not None:
ratio = (nd / n) if n else 0
if ratio > 0.95:
return "id_like"
if 0.01 <= ratio <= 0.35:
return "category"
if ratio < 0.01:
return "binary_flag"
except Exception:
pass
return "text"
return "other"
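# Compact per-table profile: row count plus counts of category / numeric / time columns,
# used to build the overview context passed to the LLM.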
def quick_table_profile(con: duckdb.DuckDBPyConnection, table: str) -> Dict:
rows = con.execute(f'SELECT COUNT(*) FROM {table}').fetchone()[0]
cols = get_columns(con, table)
roles = {"category": [], "numeric": [], "date": [], "year": [], "id_like": [], "text": [], "binary": [], "other": []}
for c, d in cols:
r = role_of_column(con, table, c, d)
if r == "binary_flag":
roles["binary"].append(c)
else:
roles.setdefault(r, []).append(c)
return {
"rows": int(rows or 0),
"n_cols": len(cols),
"n_cat": len(roles["category"]),
"n_num": len(roles["numeric"]),
"n_time": len(roles["year"]) + len(roles["date"]),
}
def table_mapping(con: duckdb.DuckDBPyConnection, user_tables: List[str]) -> Dict[str, Dict]:
"""
    Map db_table (normalized) -> {sheet_name, title} using __file_tables (falling back to the legacy __excel_tables) if present.
"""
normalize = lambda t: t.split('.', 1)[1].strip('"') if '.' in t else t.strip('"')
want_names = {normalize(t) for t in user_tables}
mapping: Dict[str, Dict] = {}
try:
rows = con.execute(
"SELECT sheet_name, table_name, inferred_title, original_title_text, block_index, start_row "
"FROM __file_tables ORDER BY block_index, start_row"
).fetchall()
for sheet_name, table_name, inferred_title, original_title_text, block_index, start_row in rows:
if table_name not in want_names:
continue
title = inferred_title or original_title_text or 'untitled'
mapping[table_name] = {'sheet_name': sheet_name, 'title': title}
except Exception:
try:
rows = con.execute(
"SELECT sheet_name, table_name, inferred_title, original_title_text, block_index, start_row "
"FROM __excel_tables ORDER BY block_index, start_row"
).fetchall()
for sheet_name, table_name, inferred_title, original_title_text, block_index, start_row in rows:
if table_name not in want_names:
continue
title = inferred_title or original_title_text or 'untitled'
mapping[table_name] = {'sheet_name': sheet_name, 'title': title}
except Exception:
pass
return mapping
def excel_schema_samples(con: duckdb.DuckDBPyConnection, mapping: Dict[str, Dict], max_cols: int = 8) -> Dict[str, List[str]]:
""" Return up to max_cols original column names per table_name (normalized) for LLM hints. """
samples: Dict[str, List[str]] = {}
try:
rows = con.execute("SELECT sheet_name, table_name, column_ordinal, original_name FROM __file_schema ORDER BY sheet_name, table_name, column_ordinal").fetchall()
for sheet_name, table_name, ordn, orig in rows:
if table_name not in mapping:
continue
lst = samples.setdefault(table_name, [])
if orig and len(lst) < max_cols:
lst.append(str(orig))
    except Exception:
        pass  # schema hints are optional; skip them if __file_schema is missing
return samples
# ---------- OpenAI ----------
def ai_overview_from_context(context: Dict) -> str:
    try:
        secret_key = st.secrets.get("OPENAI_API_KEY", None)
    except Exception:  # st.secrets can raise when no secrets file exists at all
        secret_key = None
    api_key = os.environ.get("OPENAI_API_KEY") or secret_key
    if not api_key:
        raise RuntimeError("OPENAI_API_KEY is not set. Please add it to .env or Streamlit secrets.")
try:
from openai import OpenAI
client = OpenAI(api_key=api_key)
model = os.environ.get("OPENAI_MODEL", "gpt-4o-mini")
except Exception as e:
raise RuntimeError("OpenAI client not available. Install 'openai' >= 1.0 and try again.") from e
prompt = f'''
Start directly (no greeting). Write a concise, conversational overview (max two short paragraphs) of the dataset created from the uploaded Excel.
Requirements:
- Do NOT mention database engines, schemas, or technical column/table names.
- For each segment, reference it as: Sheet "<sheet_name>" — Table "<title>" (use "untitled" if missing).
- Use any provided original Excel header hints ONLY to infer friendlier human concepts; do not quote them verbatim.
- After the overview, list 6–8 simple questions a user could ask in natural language.
- Output Markdown with headings: "Overview" and "Try These Questions".
Context (JSON):
{json.dumps(context, ensure_ascii=False, indent=2)}
'''
resp = client.chat.completions.create(
model=model,
messages=[{"role": "user", "content": prompt}],
temperature=0.4,
)
return resp.choices[0].message.content.strip()
# ---------- Orchestration ----------
def run_ingestion_pipeline(file_path: Path, db_path: Path, log_placeholder):
# Combined log function
log_lines: List[str] = []
    def _append(line: str):
        log_lines.append(line)
        # Escape log text so arbitrary subprocess output cannot break the HTML log box;
        # only the last 400 lines are rendered to keep the widget responsive.
        escaped = [html.escape(str(l)) for l in log_lines[-400:]]
        log_placeholder.markdown(
            f"<div class='logbox'><code>{'</code><br/><code>'.join(escaped)}</code></div>",
            unsafe_allow_html=True,
        )
# 1) Save (already saved by caller, but we log here for a single place)
_append("[app] Saving file…")
_append("[app] Saved.")
if not SCRIPT_PATH.exists():
_append("[app] ERROR: ingestion component not found next to the app.")
raise FileNotFoundError("Required ingestion component not found.")
# 2) Ingest
_append("[app] Ingesting…")
env = os.environ.copy()
env["PYTHONIOENCODING"] = "utf-8"
cmd = [sys.executable, str(SCRIPT_PATH), "--file", str(file_path), "--duckdb", str(db_path)]
try:
proc = subprocess.Popen(
cmd, cwd=str(PRIMARY_DIR),
stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
text=True, bufsize=1, universal_newlines=True, env=env
)
except Exception as e:
_append(f"[app] ERROR: failed to start ingestion: {e}")
raise
for line in iter(proc.stdout.readline, ""):
_append(line.rstrip("\n"))
proc.wait()
if proc.returncode != 0:
_append("[app] ERROR: ingestion reported a non-zero exit code.")
raise RuntimeError("Ingestion failed. See logs.")
_append("[app] Ingestion complete.")
# 3) Open dataset
_append("[app] Opening dataset…")
con = duckdb.connect(str(db_path))
_append("[app] Dataset open.")
return con, _append
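# Discover the user tables, map them back to sheet/title metadata, profile each one,
# and ask the LLM for a plain-language overview plus suggested questions.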
def analyze_and_summarize(con: duckdb.DuckDBPyConnection):
user_tables = list_user_tables(con)
preview_items = [] # list of {'table_ref': t, 'label': label} for UI
if not user_tables:
# Try to provide metadata if no tables are found
        try:
            try:
                meta_df = con.execute("SELECT * FROM __file_tables").fetchdf()
            except Exception:  # legacy metadata table
                meta_df = con.execute("SELECT * FROM __excel_tables").fetchdf()
            st.warning("No user tables were discovered. Showing ingestion metadata for reference.")
            st.dataframe(meta_df, use_container_width=True, hide_index=True)
except Exception:
st.error("No user tables were discovered and no metadata table is available.")
return "", []
# Build mapping + schema hints
mapping = table_mapping(con, user_tables) # normalized table_name -> {sheet_name, title}
schema_hints = excel_schema_samples(con, mapping, max_cols=8)
# Build compact profiling context for LLM; avoid raw db table names
per_table = []
for idx, t in enumerate(user_tables, start=1):
prof = quick_table_profile(con, t)
norm = t.split('.', 1)[1].strip('"') if '.' in t else t.strip('"')
m = mapping.get(norm, {})
sheet = m.get('sheet_name')
title = m.get('title')
per_table.append({
"idx": idx,
"sheet_name": sheet,
"title": title or "untitled",
"rows": prof["rows"],
"n_cols": prof["n_cols"],
"category_fields": prof["n_cat"],
"numeric_measures": prof["n_num"],
"time_fields": prof["n_time"],
"example_original_headers": schema_hints.get(norm, [])
})
label = f'Sheet "{sheet}" — Table "{title or "untitled"}"' if sheet else f'Table "{title or "untitled"}"'
preview_items.append({'table_ref': t, 'label': label})
context = {
"segments": per_table
}
# Generate overview (LLM only)
overview_md = ai_overview_from_context(context)
return overview_md, preview_items
# ---------- UI flow ----------
file = st.file_uploader("Upload an Excel or CSV file", type=["xlsx", "csv"])
if file is None and not st.session_state.last_overview_md:
st.info("Upload a .xlsx or .csv file to begin.")
# Only show logs AFTER there is an upload or some result to show
logs_placeholder = None
if file is not None or st.session_state.processing or st.session_state.last_overview_md:
logs_exp = st.expander("Processing logs", expanded=False)
logs_placeholder = logs_exp.empty()
if file is not None:
key = _file_key(file)
stem = Path(file.name).stem
saved_file = UPLOAD_DIR / file.name
db_path = DB_DIR / f"{stem}.duckdb"
# --- CLEAR state immediately on new upload ---
if st.session_state.get("processed_key") != key:
st.session_state["last_overview_md"] = None
st.session_state["last_preview_items"] = []
st.session_state["chat_history"] = []
st.session_state["schema_text"] = None
st.session_state["db_path"] = None
# Optional: clear any previous logs shown in UI on rerun
# (no explicit log buffer stored; the log expander will refresh)
# Auto-start ingestion exactly once per unique upload
if (st.session_state.processed_key != key) and (not st.session_state.processing):
st.session_state.processing = True
if logs_placeholder is None:
logs_exp = st.expander("Ingestion logs", expanded=False)
logs_placeholder = logs_exp.empty()
# Save uploaded file
with open(saved_file, "wb") as f:
f.write(file.getbuffer())
try:
con, app_log = run_ingestion_pipeline(saved_file, db_path, logs_placeholder)
# Analyze + overview
app_log("[app] Analyzing data…")
overview_md, preview_items = analyze_and_summarize(con)
app_log("[app] Overview complete.")
con.close()
st.session_state.last_overview_md = overview_md
st.session_state.last_preview_items = preview_items
st.session_state.processed_key = key
st.session_state.processing = False
except Exception as e:
st.session_state.processing = False
            st.error(f"Processing failed. See the logs for details. Error: {e}")
# Display results if available (and avoid re-triggering ingestion)
if st.session_state.last_overview_md:
#st.subheader("Overview")
st.markdown(st.session_state.last_overview_md)
with st.expander("Quick preview (verification only)", expanded=False):
try:
# Reconnect to current dataset path (if present)
if file is not None:
stem = Path(file.name).stem
db_path = DB_DIR / f"{stem}.duckdb"
con = duckdb.connect(str(db_path))
for item in st.session_state.last_preview_items:
t = item['table_ref']
label = item['label']
df = con.execute(f"SELECT * FROM {t} LIMIT 50").df()
st.caption(f"Preview — {label}")
st.dataframe(df, use_container_width=True, hide_index=True)
con.close()
except Exception as e:
st.warning(f"Could not preview tables: {e}")
# =====================
# Chat with your dataset
# (Appends after overview & preview; leaves earlier logic untouched)
# =====================
if st.session_state.get("last_overview_md"):
st.divider()
st.subheader("Chat with your dataset")
# Lazy imports so nothing changes before preview completes
def _lazy_imports():
from duckdb_react_agent import get_schema_summary, make_llm, answer_question # noqa: F401
return get_schema_summary, make_llm, answer_question
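    # NOTE: duckdb_react_agent is assumed to be importable alongside this app and to expose
    # get_schema_summary, make_llm and answer_question; importing lazily means a missing module
    # only surfaces once the chat panel is actually used.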
# Initialize chat memory
st.session_state.setdefault("chat_history", []) # [{role, content, sql?, plot_path?}]
# --- 1) Take input first so the user's question appears immediately ---
user_q = st.chat_input("Ask a question about the dataset…")
if user_q:
st.session_state.chat_history.append({"role": "user", "content": user_q})
# --- 2) Render history in strict User → Assistant order ---
for msg in st.session_state.chat_history:
with st.chat_message("user" if msg["role"] == "user" else "assistant"):
# Always: text first
st.markdown(msg["content"])
# Then: optional plot (below the answer)
plot_path_hist = msg.get("plot_path")
if plot_path_hist:
if not os.path.isabs(plot_path_hist):
plot_path_hist = str((PRIMARY_DIR / plot_path_hist).resolve())
if os.path.exists(plot_path_hist):
st.image(plot_path_hist, caption="Chart", width=520)
# Finally: optional SQL expander
if msg.get("sql"):
with st.expander("View generated SQL", expanded=False):
st.markdown(f"<div class='sqlbox'>{msg['sql']}</div>", unsafe_allow_html=True)
# --- 3) If a new question arrived, stream the assistant answer now ---
if user_q:
with st.chat_message("assistant"):
# Placeholders in sequence: text → plot → SQL
stream_placeholder = st.empty()
plot_placeholder = st.empty()
sql_placeholder = st.empty()
# Show pending immediately
stream_placeholder.markdown("_Answer pending…_")
partial_chunks = []
def on_token(t: str):
partial_chunks.append(t)
stream_placeholder.markdown("".join(partial_chunks))
            # Resolve the DB path: prefer db_path from the upload flow above, else derive it from the
            # uploaded file name, else fall back to the most recently modified .duckdb in DB_DIR.
if 'db_path' in locals():
_db_path = db_path # from the preview scope if defined
else:
if 'file' in locals() and file is not None:
_stem = Path(file.name).stem
_db_path = DB_DIR / f"{_stem}.duckdb"
else:
_candidates = sorted(DB_DIR.glob("*.duckdb"), key=lambda p: p.stat().st_mtime, reverse=True)
_db_path = _candidates[0] if _candidates else None
if not _db_path or not Path(_db_path).exists():
stream_placeholder.error("No dataset found. Please re-upload the file in this session.")
else:
# Call agent lazily
get_schema_summary, make_llm, answer_question = _lazy_imports()
try:
try:
con2 = duckdb.connect(str(_db_path), read_only=True)
except Exception:
con2 = duckdb.connect(str(_db_path))
schema_text = get_schema_summary(con2, allowed_schemas=["main"])
llm = make_llm(model=os.environ.get("OPENAI_MODEL", "gpt-4o-mini"), temperature=0.0)
import inspect as _inspect
_sig = None
try:
_sig = _inspect.signature(answer_question)
except Exception:
_sig = None
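                    # Inspect answer_question's signature so streaming/history keyword arguments
                    # are only passed when the agent version actually supports them.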
def _call_answer():
try:
if _sig and "history" in _sig.parameters and "token_callback" in _sig.parameters and "stream" in _sig.parameters:
return answer_question(con2, llm, schema_text, user_q, stream=True, token_callback=on_token, history=st.session_state.chat_history)
elif _sig and "token_callback" in _sig.parameters and "stream" in _sig.parameters:
return answer_question(con2, llm, schema_text, user_q, stream=True, token_callback=on_token)
else:
return answer_question(con2, llm, schema_text, user_q)
except TypeError:
try:
return answer_question(con2, llm, schema_text, user_q, stream=True, token_callback=on_token)
except TypeError:
return answer_question(con2, llm, schema_text, user_q)
result = _call_answer()
con2.close()
# Finalize text
answer_text = result.get("answer") or "".join(partial_chunks) or "*No answer produced.*"
stream_placeholder.markdown(answer_text)
# Show plot next (slightly larger, below the text)
plot_path = result.get("plot_path")
if plot_path:
if not os.path.isabs(plot_path):
plot_path = str((PRIMARY_DIR / plot_path).resolve())
if os.path.exists(plot_path):
plot_placeholder.image(plot_path, caption="Chart", width=560)
# Finally show SQL
gen_sql = (result.get("sql") or "").strip()
if gen_sql:
with st.expander("View generated SQL", expanded=False):
st.markdown(f"<div class='sqlbox'>{gen_sql}</div>", unsafe_allow_html=True)
# Persist assistant message
st.session_state.chat_history.append({
"role": "assistant",
"content": answer_text,
"sql": gen_sql,
"plot_path": result.get("plot_path")
})
finally:
try:
con2.close()
except Exception:
pass