import streamlit as st
import pandas as pd
from utils.consts import DB_PATH
import sqlite3
import re
import os
from agents.sql_agent.agent import SQLAgent
import time
from agents.tools import PlotSQLTool
from agents.dataframe_agent import get_dataframe_agent
from datetime import datetime
# Display name of the bundled SQLite database (used as a data-source label).
db_name = os.path.basename(DB_PATH)

st.set_page_config(page_title="🔍 TalkToData", layout="wide", initial_sidebar_state="collapsed")
# Title markdown removed to avoid duplicate rendering.

# Sidebar for settings
with st.sidebar:
    st.header("ℹ️ About", anchor=None)
    st.markdown("""
**TalkToData** v0.1.0
Your personal AI Data Analyst.
""", unsafe_allow_html=True)

# Initialize chat history
if 'chat_history' not in st.session_state:
    st.session_state.chat_history = []

# Initialize SQL agent (rebuilt on every rerun; Streamlit re-executes the script top to bottom)
# agent = get_sql_agent()
agent = SQLAgent()
# Fresh pipeline state handed to the SQL agent graph on each question:
# the question text, discovered DB metadata, the generated SQL, its result,
# any error, the current step name, and the final natural-language answer.
state = {
    "question": None,
    "db_info": {"tables": [], "columns": {}, "schema": None},
    "sql_query": None,
    "sql_result": None,
    "error": None,
    "step": None,
    "answer": None,
}
# --- Upload Screen State ---
# Seed the landing-page flag once per session (st.session_state is dict-like).
st.session_state.setdefault('files_uploaded', False)
# TEMP: Bypass landing page
st.session_state['files_uploaded'] = True
if not st.session_state['files_uploaded']:
    # --- Landing page (currently bypassed by the TEMP flag above) ---
    # CSS to center and enlarge only the welcome start button
    # NOTE(review): the original inline <style> markup appears to have been lost —
    # restore the intended CSS inside this markdown call if needed.
    st.markdown("""
""", unsafe_allow_html=True)
    # Wrap welcome content to scope styling
    # NOTE(review): the original HTML wrapper markup was stripped — confirm intended markup.
    st.markdown("", unsafe_allow_html=True)
    # Title and subtitle
    st.markdown("""
🔍 TalkToData
Your Personal AI Data Analyst that instantly answers your data questions with clear insights and elegant visualizations.
""", unsafe_allow_html=True)
    # Standalone welcome start button
    if st.button("🚀 Explore now", key="start"):
        st.session_state['files_uploaded'] = True
        # st.experimental_rerun() was removed in modern Streamlit; st.rerun() is the replacement.
        st.rerun()
    # Close welcome wrapper
    st.markdown("", unsafe_allow_html=True)
    st.divider()
    # SaaS-style Features section
    st.markdown("## Features")
    feat_cols = st.columns(3)
    feat_cols[0].markdown("### 🗣 Natural-Language Queries\nAsk your data without SQL knowledge.")
    feat_cols[1].markdown("### 📊 Instant Visualizations\nGet charts from one command.")
    feat_cols[2].markdown("### 🔒 Secure & Local\nYour data stays on your machine.")
    st.divider()
    # How It Works section
    st.markdown("## How It Works")
    step_cols = st.columns(3)
    step_cols[0].markdown("#### 1️⃣ Upload\nUpload .db or CSV files.")
    step_cols[1].markdown("#### 2️⃣ Chat\nInteract in natural language.")
    step_cols[2].markdown("#### 3️⃣ Visualize\nSee results as tables or charts.")
    st.divider()
    # Use Cases
    st.markdown("## Use Cases")
    st.markdown("- \"Show me top 5 products by sales\" → Chart")
    st.markdown("- \"List customers from 2020\" → Table")
    st.divider()
    # Testimonials
    st.markdown("## Testimonials")
    testi_cols = st.columns(2)
    testi_cols[0].markdown("> \"TalkToData transformed our data workflow!\" \n— Jane Doe, Data Analyst")
    testi_cols[1].markdown("> \"The AI assistant is incredibly smart and fast.\" \n— John Smith, Product Manager")
    st.divider()
    # Footer
    st.markdown("2025 TalkToData. All rights reserved.")
    st.markdown("TalkToData v0.1.0 - Copyright 2025 by Khanh Pham", unsafe_allow_html=True)
    # NOTE(review): the original st.html() content was stripped to bare text — confirm intended markup.
    st.html("Oops!")
    st.divider()
else:
# App title and return button
# st.title("🔍 TalkToData")
st.markdown("### TalkToData")
# TEMP: Commented out back-to-home
# if st.button('⬅️ Back to Home', key='back_to_upload'):
# st.session_state['files_uploaded'] = False
# # Xóa dữ liệu cũ
# if 'uploaded_csvs' in st.session_state:
# st.session_state['uploaded_csvs'] = []
# st.experimental_rerun()
# Layout: Data source selector, main content, and chat
data_col, left_col, right_col = st.columns([1.5, 3, 2])
# Data source selection
with data_col:
# st.subheader("Data Sources")
# Upload data
with st.expander("**Upload Data**", expanded=True):
st.file_uploader('Select SQLite (.db), CSV or Excel (.xlsx) files',
type=['db', 'csv', 'xlsx'],
accept_multiple_files=True,
key='upload_any_col',
label_visibility="collapsed")
gsheet_url = st.text_input('Enter Google Sheets URL (optional)', '', key='gsheet_url')
upload_status = []
has_db = False
has_csv = False
# Retrieve uploaded files list safely
uploaded_files = st.session_state.get('upload_any_col', [])
# Process Google Sheets if URL provided
url = st.session_state.get('gsheet_url', '').strip()
if url:
try:
csv_url = url.replace('/edit#gid=', '/export?format=csv&gid=')
df_gs = pd.read_csv(csv_url)
if 'uploaded_csvs' not in st.session_state:
st.session_state['uploaded_csvs'] = []
st.session_state['uploaded_csvs'].append({'name': 'GoogleSheets', 'df': df_gs})
upload_status.append('✅ Google Sheets loaded')
has_csv = True
except Exception as e:
upload_status.append(f'❌ Google Sheets error: {e}')
# Process files
for f in uploaded_files:
if f.name.lower().endswith('.db'):
try:
with open(DB_PATH, "wb") as dbf:
dbf.write(f.read())
upload_status.append(f"✅ Database: {f.name}")
has_db = True
except Exception as e:
upload_status.append(f"❌ Database error: {e}")
# Process CSV and Excel
name = f.name.lower()
if name.endswith('.csv') or name.endswith('.xlsx'):
try:
if name.endswith('.xlsx'):
# Process each sheet in Excel
f.seek(0)
xls = pd.ExcelFile(f)
sheets = st.multiselect(f"Select sheet(s) from {f.name}", xls.sheet_names, default=xls.sheet_names)
for sheet in sheets:
# Read raw to detect header rows
raw = xls.parse(sheet, header=None)
nn = raw.notnull().sum(axis=1)
hdr = [i for i, cnt in enumerate(nn) if cnt > 1]
if len(hdr) >= 2:
header = hdr[:2]
elif len(hdr) == 1:
header = [hdr[0]]
else:
header = [0]
df_sheet = xls.parse(sheet, header=header)
# Flatten MultiIndex if needed
if isinstance(df_sheet.columns, pd.MultiIndex):
df_sheet.columns = [" ".join([str(x) for x in col if pd.notna(x)]).strip() for col in df_sheet.columns]
# Store with sheet label
sheet_key = f"{f.name}:{sheet}"
if 'uploaded_csvs' not in st.session_state:
st.session_state['uploaded_csvs'] = []
st.session_state['uploaded_csvs'].append({'name': sheet_key, 'df': df_sheet})
upload_status.append(f"✅ Excel: {sheet_key}")
else:
temp_df = pd.read_csv(f)
if 'uploaded_csvs' not in st.session_state:
st.session_state['uploaded_csvs'] = []
# Check existing and update
csv_exists = False
for i, csv in enumerate(st.session_state['uploaded_csvs']):
if csv['name'] == f.name:
st.session_state['uploaded_csvs'][i]['df'] = temp_df
csv_exists = True
break
if not csv_exists:
st.session_state['uploaded_csvs'].append({'name': f.name, 'df': temp_df})
upload_status.append(f"✅ CSV/Excel: {f.name}")
has_csv = True
except Exception as e:
upload_status.append(f"❌ CSV/Excel error: {e}")
# Hiển thị trạng thái upload
if upload_status:
for status in upload_status:
st.write(status)
# After upload, select data sources
ds = []
if os.path.exists(DB_PATH) and os.path.getsize(DB_PATH) > 0:
ds.append(db_name)
if 'uploaded_csvs' in st.session_state:
ds += [csv['name'] for csv in st.session_state['uploaded_csvs']]
if ds:
# Initialize selected_sources session state to default to db_name
if 'selected_sources' not in st.session_state:
st.session_state['selected_sources'] = [db_name] if db_name in ds else []
selected_sources = st.multiselect(
"**Select sources**", options=ds,
key='selected_sources'
)
else:
st.info("Upload a database or CSV/Excel file to select a data source.")
with left_col:
# Data Preview: filter sources by user selection
selected = st.session_state.get('selected_sources', [])
preview_db = os.path.exists(DB_PATH) and db_name in selected
# Filter CSV/Excel previews
preview_csvs = [csv for csv in st.session_state.get('uploaded_csvs', []) if csv['name'] in selected]
if preview_db or preview_csvs:
# Display previews
with st.container(height=415):
st.markdown("**Data Preview**")
# Build tab labels
tab_labels = []
if preview_db:
tab_labels.append(db_name)
for c in preview_csvs:
tab_labels.append(c['name'])
tabs = st.tabs(tab_labels)
idx = 0
# Database preview
if preview_db:
with tabs[idx]:
conn = sqlite3.connect(DB_PATH)
tables = conn.execute("SELECT name FROM sqlite_master WHERE type='table';").fetchall()
if tables:
t_tabs = st.tabs([t[0] for t in tables])
for t, tab in zip(tables, t_tabs):
with tab:
st.table(pd.read_sql_query(f"SELECT * FROM {t[0]}", conn))
else:
st.info("No tables found.")
conn.close()
idx += 1
# CSV/Excel previews
for c in preview_csvs:
with tabs[idx]:
st.table(c['df'])
idx += 1
# --- Data Exploration Section (Always Visible) ---
with st.container(height=225):
# Data Exploration: only support Database source
selected = st.session_state.get('selected_sources', [])
if db_name not in selected:
st.warning(f"⚠️ Data Exploration only supports SQL queries on database .db files. Please select at least a database to continue.")
else:
# st.subheader("Data Exploration")
sql_explore = st.text_area(
"Enter SQL query to explore:",
value=st.session_state.get('explore_sql', ''),
height=100,
key='explore_sql'
)
if st.button("Run Query", key="explore_run"):
try:
df_explore = pd.read_sql_query(sql_explore, sqlite3.connect(DB_PATH))
st.session_state['explore_result'] = df_explore
# Record exploration history
if 'explore_history' not in st.session_state:
st.session_state['explore_history'] = []
# User query
st.session_state['explore_history'].append({
'source': 'explore', 'role': 'user', 'content': sql_explore, 'timestamp': datetime.now()
})
# Assistant result as CSV
res_str = df_explore.to_csv(index=False)
st.session_state['explore_history'].append({
'source': 'explore', 'role': 'assistant', 'content': res_str, 'timestamp': datetime.now()
})
except Exception as e:
st.error(f"Error: {e}")
# Wrap tabs in scrollable container
with st.container(height=300):
# st.markdown("", unsafe_allow_html=True)
tabs = st.tabs(["Results", "History"])
# Results tab: show explore_result only
with tabs[0]:
if 'explore_result' in st.session_state:
# st.subheader("Results")
st.table(st.session_state['explore_result'])
else:
st.write("No results yet.")
# History tab: Query history
with tabs[1]:
# st.subheader("History")
# Build paired history entries
combined = []
# Exploration history pairs
explore_hist = st.session_state.get('explore_history', [])
for i in range(0, len(explore_hist), 2):
u = explore_hist[i] if i < len(explore_hist) else {}
a = explore_hist[i+1] if i+1 < len(explore_hist) else {}
combined.append({
'source': db_name,
'query_type': 'sql',
'query': u.get('content'),
'result': a.get('content'),
'timestamp': u.get('timestamp')
})
# Chat history pairs for all sources
for source, chat_hist in st.session_state.get('chat_histories', {}).items():
for idx in range(len(chat_hist)):
if chat_hist[idx].get('role') == 'user':
q = chat_hist[idx].get('content')
r = chat_hist[idx+1].get('content') if idx+1 < len(chat_hist) else None
combined.append({
'source': source,
'query_type': 'chat',
'query': q,
'result': r,
'timestamp': chat_hist[idx].get('timestamp')
})
if combined:
df_history = pd.DataFrame(combined)
# ensure timestamp column is datetime
if not pd.api.types.is_datetime64_any_dtype(df_history['timestamp']):
df_history['timestamp'] = pd.to_datetime(df_history['timestamp'])
# sort latest first
df_history = df_history.sort_values('timestamp', ascending=False)
st.table(df_history)
else:
st.write("No history yet.")
st.markdown("
", unsafe_allow_html=True)
with right_col:
# Use selected_sources from left data selector
data_sources = st.session_state.get('selected_sources', [])
csv_files = st.session_state.get('uploaded_csvs', [])
selected_source = data_sources[0] if data_sources else None
# Chat history per source (only if a source is selected)
if 'chat_histories' not in st.session_state:
st.session_state['chat_histories'] = {}
# Initialize past conversations container
if 'all_conversations' not in st.session_state:
st.session_state['all_conversations'] = {}
# Only proceed with chat if a data source is selected
if selected_source is not None:
if selected_source not in st.session_state['chat_histories']:
st.session_state['chat_histories'][selected_source] = []
if selected_source not in st.session_state['all_conversations']:
st.session_state['all_conversations'][selected_source] = []
chat_history = st.session_state['chat_histories'][selected_source]
# Only show chat interface if a data source is selected
if selected_source is not None:
container = st.container(height=700, border=True)
# Align New Conversation button top-right
with container:
cols = st.columns([2, 1])
with cols[0]:
st.markdown("**Ask TalkToData**")
if cols[1].button("New Chat", key=f"new_conv_{selected_source}"):
if chat_history:
conv = chat_history.copy()
ts = conv[0].get('timestamp', datetime.now())
st.session_state['all_conversations'][selected_source].append({'messages':conv, 'timestamp':ts})
st.session_state['chat_histories'][selected_source] = []
st.experimental_rerun()
# Display chat messages
chat_history = st.session_state['chat_histories'][selected_source]
# Welcome message for new chat
if not chat_history:
container.chat_message("assistant").write("👋 Hello! Welcome to TalkToData. Ask any question about your data to get started.")
for turn in chat_history:
role = turn.get('role', '')
content = turn.get('content', '')
if role == 'user':
container.chat_message("user").write(content)
else:
container.chat_message("assistant").write(content)
# Chat input
user_input = st.chat_input(f"Ask a question about {selected_source}...")
else:
# Placeholder to maintain layout
st.container(height=700, border=True)
user_input = None
if user_input:
chat_history.append({"role": "user", "content": user_input, "timestamp": datetime.now()})
with container.chat_message("user"):
st.write(user_input)
# Answer logic
with container.chat_message("assistant"):
with st.spinner("Thinking..."):
if selected_source == db_name:
# Handle /sql and /plot commands
if user_input.strip().lower().startswith('/sql'):
sql = user_input[len('/sql'):].strip()
try:
df = pd.read_sql_query(sql, sqlite3.connect(DB_PATH))
st.write(f"```sql\n{sql}\n```")
st.table(df)
chat_history.append({"role": "assistant", "content": f"```sql\n{sql}\n```", "timestamp": datetime.now()})
except Exception as e:
err = f"SQL Error: {e}"
st.error(err)
chat_history.append({"role": "assistant", "content": err, "timestamp": datetime.now()})
elif user_input.strip().lower().startswith('/plot'):
sql = user_input[len('/plot'):].strip()
try:
tool = PlotSQLTool()
md = tool._run(sql)
st.markdown(md)
m = re.search(r'!\[.*\]\((.*?)\)', md)
if m:
st.image(m.group(1))
chat_history.append({"role": "assistant", "content": md, "timestamp": datetime.now()})
except Exception as e:
err = f"Plot Error: {e}"
st.error(err)
chat_history.append({"role": "assistant", "content": err, "timestamp": datetime.now()})
else:
# Use SQL agent as before
state['question'] = user_input
try:
for step in agent.graph.stream(state, stream_mode="updates"):
step_name, step_details = next(iter(step.items()))
if step_name == 'generate_sql':
with st.expander("SQL Generated", expanded=False):
st.markdown(f"```sql\n{step_details.get('sql_query', '')}\n```")
elif step_name == 'execute_sql':
with st.expander("SQL Result", expanded=False):
st.table(step_details.get('sql_result', pd.DataFrame()))
elif step_name == 'generate_answer':
st.write(step_details.get('answer', ''))
chat_history.append({"role": "assistant", "content": step_details.get('answer', ''), "timestamp": datetime.now()})
elif step_name == 'render_visualization':
try:
visualization_output = step_details.get('visualization_output')
if visualization_output and os.path.exists(visualization_output):
st.image(visualization_output)
else:
print("No visualization was generated for this query.")
except Exception as e:
print(f"Could not display visualization: {str(e)}")
except Exception as e:
err = f"SQL Agent Error: {e}"
print(err)
chat_history.append({"role": "assistant", "content": err, "timestamp": datetime.now()})
else:
# Use DataFrame agent for selected CSV
csv_file = next((csv for csv in csv_files if csv['name'] == selected_source), None)
if csv_file:
if 'csv_agents' not in st.session_state:
st.session_state['csv_agents'] = {}
if selected_source not in st.session_state['csv_agents']:
st.session_state['csv_agents'][selected_source] = get_dataframe_agent(csv_file['df'])
agent = st.session_state['csv_agents'][selected_source]
try:
response = agent.invoke(user_input)
answer = response["output"] if isinstance(response, dict) and "output" in response else str(response)
except Exception as e:
answer = f"CSV Agent Error: {e}"
st.write(answer)
chat_history.append({"role": "assistant", "content": answer, "timestamp": datetime.now()})
# Refresh to update History immediately
# st.experimental_rerun()
# Past Conversations Panel
with st.container(height=200):
st.markdown("**Recent Conversations**")
# Flatten and sort conversations by most recent first
entries = []
for source, convs in st.session_state.get('all_conversations', {}).items():
for conv in convs:
entries.append((source, conv))
entries = sorted(entries, key=lambda x: x[1]['timestamp'], reverse=True)
for source, conv in entries:
label = conv['timestamp'].strftime("%Y-%m-%d %H:%M:%S")
with st.expander(f"{source} - {label}", expanded=False):
for msg in conv['messages']:
if msg.get('role') == 'user':
st.chat_message('user').write(msg.get('content'))
else:
st.chat_message('assistant').write(msg.get('content'))