"""TalkToData — a Streamlit app for chatting with your data.

Users upload SQLite databases, CSV/Excel files, or a Google Sheets URL,
preview the data, explore it with raw SQL, and ask natural-language
questions answered by an LLM-backed SQL agent (for the database) or a
pandas DataFrame agent (for tabular uploads).
"""

import os
import re
import sqlite3
import time
from contextlib import closing
from datetime import datetime

import pandas as pd
import streamlit as st

from agents.dataframe_agent import get_dataframe_agent
from agents.sql_agent.agent import SQLAgent
from agents.tools import PlotSQLTool
from utils.consts import DB_PATH

# Display name of the (single) SQLite database source.
db_name = os.path.basename(DB_PATH)

st.set_page_config(
    page_title="🔍 TalkToData",
    layout="wide",
    initial_sidebar_state="collapsed",
)
# Title markdown removed here to avoid showing the title twice.


def _store_uploaded_df(name: str, df: pd.DataFrame) -> None:
    """Insert or update a named DataFrame in ``st.session_state['uploaded_csvs']``.

    Streamlit reruns the whole script on every interaction; a plain
    ``append`` would duplicate the same source on each rerun, so we
    update in place when the name already exists.
    """
    frames = st.session_state.setdefault('uploaded_csvs', [])
    for entry in frames:
        if entry['name'] == name:
            entry['df'] = df
            return
    frames.append({'name': name, 'df': df})


# Sidebar: static "about" panel.
with st.sidebar:
    st.header("ℹ️ About", anchor=None)
    st.markdown(
        """
        **TalkToData** v0.1.0
        Your personal AI Data Analyst.
        """,
        unsafe_allow_html=True,
    )

# Initialize chat history.
if 'chat_history' not in st.session_state:
    st.session_state.chat_history = []

# Initialize the SQL agent and the state dict fed to its graph.
# agent = get_sql_agent()
agent = SQLAgent()
state = {
    "question": None,
    "db_info": {"tables": [], "columns": {}, "schema": None},
    "sql_query": None,
    "sql_result": None,
    "error": None,
    "step": None,
    "answer": None,
}

# --- Upload Screen State ---
if 'files_uploaded' not in st.session_state:
    st.session_state['files_uploaded'] = False
# TEMP: Bypass landing page.
st.session_state['files_uploaded'] = True

if not st.session_state['files_uploaded']:
    # CSS to center and enlarge only the welcome start button.
    # NOTE(review): the original inline <style> content was lost in
    # extraction — restore it from version control.
    st.markdown("""
    """, unsafe_allow_html=True)

    # Wrap welcome content to scope styling.
    # NOTE(review): original wrapper markup lost — class name is a guess.
    st.markdown("<div class='welcome'>", unsafe_allow_html=True)

    # Title and subtitle.
    # NOTE(review): exact HTML was lost in extraction; text preserved.
    st.markdown(
        """
        <div style="text-align: center">
            <h1>🔍 TalkToData</h1>
            <p>Your Personal AI Data Analyst that instantly answers your data
            questions with clear insights and elegant visualizations.</p>
        </div>
        """,
        unsafe_allow_html=True,
    )

    # Standalone welcome start button.
    if st.button("🚀 Explore now", key="start"):
        st.session_state['files_uploaded'] = True
        # st.experimental_rerun() was removed in the Streamlit versions
        # that provide st.html/st.chat_input used below.
        st.rerun()

    # Close welcome wrapper.
    st.markdown("</div>", unsafe_allow_html=True)
    st.divider()

    # SaaS-style Features section.
    st.markdown("## Features")
    feat_cols = st.columns(3)
    feat_cols[0].markdown("### 🗣 Natural-Language Queries\nAsk your data without SQL knowledge.")
    feat_cols[1].markdown("### 📊 Instant Visualizations\nGet charts from one command.")
    feat_cols[2].markdown("### 🔒 Secure & Local\nYour data stays on your machine.")
    st.divider()

    # How It Works section.
    st.markdown("## How It Works")
    step_cols = st.columns(3)
    step_cols[0].markdown("#### 1️⃣ Upload\nUpload .db or CSV files.")
    step_cols[1].markdown("#### 2️⃣ Chat\nInteract in natural language.")
    step_cols[2].markdown("#### 3️⃣ Visualize\nSee results as tables or charts.")
    st.divider()

    # Use Cases.
    st.markdown("## Use Cases")
    st.markdown("- \"Show me top 5 products by sales\" → Chart")
    st.markdown("- \"List customers from 2020\" → Table")
    st.divider()

    # Testimonials.
    st.markdown("## Testimonials")
    testi_cols = st.columns(2)
    testi_cols[0].markdown("> \"TalkToData transformed our data workflow!\" \n— Jane Doe, Data Analyst")
    testi_cols[1].markdown("> \"The AI assistant is incredibly smart and fast.\" \n— John Smith, Product Manager")
    st.divider()

    # Footer.
    st.markdown("2025 TalkToData. All rights reserved.")
    # NOTE(review): footer markup lost in extraction; text preserved.
    st.markdown(
        """
        <div style="text-align: center">
            TalkToData v0.1.0 - Copyright 2025 by Khanh Pham
        </div>
        """,
        unsafe_allow_html=True,
    )
    # NOTE(review): original st.html markup lost; only "Oops!" survived.
    st.html(
        """
        <div>
            Oops!
        </div>
        """
    )
    st.divider()
else:
    # App title and return button.
    # st.title("🔍 TalkToData")
    st.markdown("### TalkToData")
    # TEMP: back-to-home disabled.
    # if st.button('⬅️ Back to Home', key='back_to_upload'):
    #     st.session_state['files_uploaded'] = False
    #     # Clear stale data.
    #     if 'uploaded_csvs' in st.session_state:
    #         st.session_state['uploaded_csvs'] = []
    #     st.rerun()

    # Layout: data source selector, main content, and chat.
    data_col, left_col, right_col = st.columns([1.5, 3, 2])

    # ------------------------------------------------------------------
    # Data source selection column
    # ------------------------------------------------------------------
    with data_col:
        # st.subheader("Data Sources")
        with st.expander("**Upload Data**", expanded=True):
            st.file_uploader(
                'Select SQLite (.db), CSV or Excel (.xlsx) files',
                type=['db', 'csv', 'xlsx'],
                accept_multiple_files=True,
                key='upload_any_col',
                label_visibility="collapsed",
            )
            gsheet_url = st.text_input('Enter Google Sheets URL (optional)', '', key='gsheet_url')

            upload_status = []
            has_db = False
            has_csv = False

            # Retrieve the uploaded files list safely.
            uploaded_files = st.session_state.get('upload_any_col', [])

            # Process Google Sheets if a URL was provided. The edit URL is
            # rewritten to the sheet's CSV export endpoint.
            url = st.session_state.get('gsheet_url', '').strip()
            if url:
                try:
                    csv_url = url.replace('/edit#gid=', '/export?format=csv&gid=')
                    df_gs = pd.read_csv(csv_url)
                    _store_uploaded_df('GoogleSheets', df_gs)
                    upload_status.append('✅ Google Sheets loaded')
                    has_csv = True
                except Exception as e:
                    upload_status.append(f'❌ Google Sheets error: {e}')

            # Process uploaded files.
            for f in uploaded_files:
                name = f.name.lower()

                # SQLite database: persist the raw bytes to DB_PATH.
                if name.endswith('.db'):
                    try:
                        with open(DB_PATH, "wb") as dbf:
                            dbf.write(f.read())
                        upload_status.append(f"✅ Database: {f.name}")
                        has_db = True
                    except Exception as e:
                        upload_status.append(f"❌ Database error: {e}")

                # CSV / Excel: parse into DataFrames kept in session state.
                if name.endswith('.csv') or name.endswith('.xlsx'):
                    try:
                        if name.endswith('.xlsx'):
                            # Rewind: the buffer may have been consumed above.
                            f.seek(0)
                            xls = pd.ExcelFile(f)
                            sheets = st.multiselect(
                                f"Select sheet(s) from {f.name}",
                                xls.sheet_names,
                                default=xls.sheet_names,
                            )
                            for sheet in sheets:
                                # Read raw (headerless) to detect header rows:
                                # rows with more than one non-null cell are
                                # treated as header candidates.
                                raw = xls.parse(sheet, header=None)
                                non_null = raw.notnull().sum(axis=1)
                                hdr = [i for i, cnt in enumerate(non_null) if cnt > 1]
                                if len(hdr) >= 2:
                                    header = hdr[:2]
                                elif len(hdr) == 1:
                                    header = [hdr[0]]
                                else:
                                    header = [0]
                                df_sheet = xls.parse(sheet, header=header)
                                # Flatten a MultiIndex produced by two header rows.
                                if isinstance(df_sheet.columns, pd.MultiIndex):
                                    df_sheet.columns = [
                                        " ".join(str(x) for x in col if pd.notna(x)).strip()
                                        for col in df_sheet.columns
                                    ]
                                # Store under a "file:sheet" label.
                                sheet_key = f"{f.name}:{sheet}"
                                _store_uploaded_df(sheet_key, df_sheet)
                                upload_status.append(f"✅ Excel: {sheet_key}")
                        else:
                            temp_df = pd.read_csv(f)
                            _store_uploaded_df(f.name, temp_df)
                            upload_status.append(f"✅ CSV/Excel: {f.name}")
                        has_csv = True
                    except Exception as e:
                        upload_status.append(f"❌ CSV/Excel error: {e}")

            # Show upload status messages.
            if upload_status:
                for status in upload_status:
                    st.write(status)

        # After upload, let the user pick which sources are active.
        ds = []
        if os.path.exists(DB_PATH) and os.path.getsize(DB_PATH) > 0:
            ds.append(db_name)
        if 'uploaded_csvs' in st.session_state:
            ds += [csv['name'] for csv in st.session_state['uploaded_csvs']]
        if ds:
            # Default the selection to the database when available.
            if 'selected_sources' not in st.session_state:
                st.session_state['selected_sources'] = [db_name] if db_name in ds else []
            selected_sources = st.multiselect(
                "**Select sources**",
                options=ds,
                key='selected_sources',
            )
        else:
            st.info("Upload a database or CSV/Excel file to select a data source.")

    # ------------------------------------------------------------------
    # Main content column: previews, exploration, results/history
    # ------------------------------------------------------------------
    with left_col:
        # Data Preview: filter sources by user selection.
        selected = st.session_state.get('selected_sources', [])
        preview_db = os.path.exists(DB_PATH) and db_name in selected
        preview_csvs = [
            csv for csv in st.session_state.get('uploaded_csvs', [])
            if csv['name'] in selected
        ]
        if preview_db or preview_csvs:
            with st.container(height=415):
                st.markdown("**Data Preview**")
                # One tab per selected source (database first).
                tab_labels = []
                if preview_db:
                    tab_labels.append(db_name)
                for c in preview_csvs:
                    tab_labels.append(c['name'])
                tabs = st.tabs(tab_labels)
                idx = 0
                # Database preview: one sub-tab per table.
                if preview_db:
                    with tabs[idx]:
                        with closing(sqlite3.connect(DB_PATH)) as conn:
                            tables = conn.execute(
                                "SELECT name FROM sqlite_master WHERE type='table';"
                            ).fetchall()
                            if tables:
                                t_tabs = st.tabs([t[0] for t in tables])
                                for t, tab in zip(tables, t_tabs):
                                    with tab:
                                        # Quote the identifier; names come from sqlite_master.
                                        st.table(pd.read_sql_query(f'SELECT * FROM "{t[0]}"', conn))
                            else:
                                st.info("No tables found.")
                    idx += 1
                # CSV/Excel previews.
                for c in preview_csvs:
                    with tabs[idx]:
                        st.table(c['df'])
                    idx += 1

        # --- Data Exploration Section (Always Visible) ---
        with st.container(height=225):
            # Data Exploration only supports the database source.
            selected = st.session_state.get('selected_sources', [])
            if db_name not in selected:
                st.warning(
                    "⚠️ Data Exploration only supports SQL queries on database .db files. \n"
                    "Please select at least a database to continue."
                )
            else:
                # st.subheader("Data Exploration")
                sql_explore = st.text_area(
                    "Enter SQL query to explore:",
                    value=st.session_state.get('explore_sql', ''),
                    height=100,
                    key='explore_sql',
                )
                if st.button("Run Query", key="explore_run"):
                    try:
                        # closing() guarantees the connection is released.
                        with closing(sqlite3.connect(DB_PATH)) as conn:
                            df_explore = pd.read_sql_query(sql_explore, conn)
                        st.session_state['explore_result'] = df_explore
                        # Record the exploration as a user/assistant pair.
                        history = st.session_state.setdefault('explore_history', [])
                        history.append({
                            'source': 'explore',
                            'role': 'user',
                            'content': sql_explore,
                            'timestamp': datetime.now(),
                        })
                        history.append({
                            'source': 'explore',
                            'role': 'assistant',
                            'content': df_explore.to_csv(index=False),
                            'timestamp': datetime.now(),
                        })
                    except Exception as e:
                        st.error(f"Error: {e}")

        # Results / History tabs in a scrollable container.
        with st.container(height=300):
            tabs = st.tabs(["Results", "History"])

            # Results tab: latest exploration result only.
            with tabs[0]:
                if 'explore_result' in st.session_state:
                    st.table(st.session_state['explore_result'])
                else:
                    st.write("No results yet.")

            # History tab: merged SQL-exploration and chat history.
            with tabs[1]:
                combined = []
                # Exploration history is stored as (user, assistant) pairs.
                explore_hist = st.session_state.get('explore_history', [])
                for i in range(0, len(explore_hist), 2):
                    u = explore_hist[i]
                    a = explore_hist[i + 1] if i + 1 < len(explore_hist) else {}
                    combined.append({
                        'source': db_name,
                        'query_type': 'sql',
                        'query': u.get('content'),
                        'result': a.get('content'),
                        'timestamp': u.get('timestamp'),
                    })
                # Chat history pairs for all sources: each user turn is
                # paired with the immediately following assistant turn.
                for source, chat_hist in st.session_state.get('chat_histories', {}).items():
                    for i, turn in enumerate(chat_hist):
                        if turn.get('role') == 'user':
                            reply = chat_hist[i + 1].get('content') if i + 1 < len(chat_hist) else None
                            combined.append({
                                'source': source,
                                'query_type': 'chat',
                                'query': turn.get('content'),
                                'result': reply,
                                'timestamp': turn.get('timestamp'),
                            })
                if combined:
                    df_history = pd.DataFrame(combined)
                    # Normalize the timestamp column, then show latest first.
                    if not pd.api.types.is_datetime64_any_dtype(df_history['timestamp']):
                        df_history['timestamp'] = pd.to_datetime(df_history['timestamp'])
                    st.table(df_history.sort_values('timestamp', ascending=False))
                else:
                    st.write("No history yet.")
        # NOTE(review): closes a wrapper div whose opening markup was lost
        # in extraction.
        st.markdown("</div>", unsafe_allow_html=True)

    # ------------------------------------------------------------------
    # Chat column
    # ------------------------------------------------------------------
    with right_col:
        # Chat targets the FIRST selected source.
        data_sources = st.session_state.get('selected_sources', [])
        csv_files = st.session_state.get('uploaded_csvs', [])
        selected_source = data_sources[0] if data_sources else None

        # Per-source chat history and archived conversations.
        if 'chat_histories' not in st.session_state:
            st.session_state['chat_histories'] = {}
        if 'all_conversations' not in st.session_state:
            st.session_state['all_conversations'] = {}

        if selected_source is not None:
            st.session_state['chat_histories'].setdefault(selected_source, [])
            st.session_state['all_conversations'].setdefault(selected_source, [])
            chat_history = st.session_state['chat_histories'][selected_source]

        # Only show the chat interface when a data source is selected.
        if selected_source is not None:
            container = st.container(height=700, border=True)
            with container:
                # Header row with the "New Chat" button on the right.
                cols = st.columns([2, 1])
                with cols[0]:
                    st.markdown("**Ask TalkToData**")
                if cols[1].button("New Chat", key=f"new_conv_{selected_source}"):
                    if chat_history:
                        # Archive the current conversation before clearing.
                        conv = chat_history.copy()
                        ts = conv[0].get('timestamp', datetime.now())
                        st.session_state['all_conversations'][selected_source].append(
                            {'messages': conv, 'timestamp': ts}
                        )
                        st.session_state['chat_histories'][selected_source] = []
                        st.rerun()

                # Display chat messages.
                chat_history = st.session_state['chat_histories'][selected_source]
                # Welcome message for a fresh chat.
                if not chat_history:
                    container.chat_message("assistant").write(
                        "👋 Hello! Welcome to TalkToData. \n"
                        "Ask any question about your data to get started."
                    )
                for turn in chat_history:
                    role = turn.get('role', '')
                    content = turn.get('content', '')
                    if role == 'user':
                        container.chat_message("user").write(content)
                    else:
                        container.chat_message("assistant").write(content)

            # Chat input.
            user_input = st.chat_input(f"Ask a question about {selected_source}...")
        else:
            # Placeholder container keeps the layout stable.
            st.container(height=700, border=True)
            user_input = None

        if user_input:
            chat_history.append({"role": "user", "content": user_input, "timestamp": datetime.now()})
            with container.chat_message("user"):
                st.write(user_input)

            # Answer logic.
            with container.chat_message("assistant"):
                with st.spinner("Thinking..."):
                    if selected_source == db_name:
                        # Slice commands from the stripped input so leading
                        # whitespace cannot corrupt the extracted SQL.
                        stripped = user_input.strip()
                        if stripped.lower().startswith('/sql'):
                            sql = stripped[len('/sql'):].strip()
                            try:
                                with closing(sqlite3.connect(DB_PATH)) as conn:
                                    df = pd.read_sql_query(sql, conn)
                                st.write(f"```sql\n{sql}\n```")
                                st.table(df)
                                chat_history.append({
                                    "role": "assistant",
                                    "content": f"```sql\n{sql}\n```",
                                    "timestamp": datetime.now(),
                                })
                            except Exception as e:
                                err = f"SQL Error: {e}"
                                st.error(err)
                                chat_history.append({
                                    "role": "assistant",
                                    "content": err,
                                    "timestamp": datetime.now(),
                                })
                        elif stripped.lower().startswith('/plot'):
                            sql = stripped[len('/plot'):].strip()
                            try:
                                tool = PlotSQLTool()
                                md = tool._run(sql)
                                st.markdown(md)
                                # Extract an embedded markdown image path, if any.
                                m = re.search(r'!\[.*\]\((.*?)\)', md)
                                if m:
                                    st.image(m.group(1))
                                chat_history.append({
                                    "role": "assistant",
                                    "content": md,
                                    "timestamp": datetime.now(),
                                })
                            except Exception as e:
                                err = f"Plot Error: {e}"
                                st.error(err)
                                chat_history.append({
                                    "role": "assistant",
                                    "content": err,
                                    "timestamp": datetime.now(),
                                })
                        else:
                            # Free-form question: stream the SQL agent's graph
                            # and render each step as it arrives.
                            state['question'] = user_input
                            try:
                                for step in agent.graph.stream(state, stream_mode="updates"):
                                    step_name, step_details = next(iter(step.items()))
                                    if step_name == 'generate_sql':
                                        with st.expander("SQL Generated", expanded=False):
                                            st.markdown(f"```sql\n{step_details.get('sql_query', '')}\n```")
                                    elif step_name == 'execute_sql':
                                        with st.expander("SQL Result", expanded=False):
                                            st.table(step_details.get('sql_result', pd.DataFrame()))
                                    elif step_name == 'generate_answer':
                                        st.write(step_details.get('answer', ''))
                                        chat_history.append({
                                            "role": "assistant",
                                            "content": step_details.get('answer', ''),
                                            "timestamp": datetime.now(),
                                        })
                                    elif step_name == 'render_visualization':
                                        try:
                                            viz_path = step_details.get('visualization_output')
                                            if viz_path and os.path.exists(viz_path):
                                                st.image(viz_path)
                                            else:
                                                print("No visualization was generated for this query.")
                                        except Exception as e:
                                            print(f"Could not display visualization: {str(e)}")
                            except Exception as e:
                                err = f"SQL Agent Error: {e}"
                                print(err)
                                chat_history.append({
                                    "role": "assistant",
                                    "content": err,
                                    "timestamp": datetime.now(),
                                })
                    else:
                        # Non-database source: use the DataFrame agent for the
                        # selected CSV/Excel sheet. Agents are cached per source.
                        csv_file = next((csv for csv in csv_files if csv['name'] == selected_source), None)
                        if csv_file:
                            agents_cache = st.session_state.setdefault('csv_agents', {})
                            if selected_source not in agents_cache:
                                agents_cache[selected_source] = get_dataframe_agent(csv_file['df'])
                            # Named csv_agent so the module-level SQL `agent`
                            # is not shadowed.
                            csv_agent = agents_cache[selected_source]
                            try:
                                response = csv_agent.invoke(user_input)
                                answer = (
                                    response["output"]
                                    if isinstance(response, dict) and "output" in response
                                    else str(response)
                                )
                            except Exception as e:
                                answer = f"CSV Agent Error: {e}"
                            st.write(answer)
                            chat_history.append({
                                "role": "assistant",
                                "content": answer,
                                "timestamp": datetime.now(),
                            })
            # Refresh to update History immediately.
            # st.rerun()

        # Past Conversations panel: archived chats, newest first.
        with st.container(height=200):
            st.markdown("**Recent Conversations**")
            entries = [
                (source, conv)
                for source, convs in st.session_state.get('all_conversations', {}).items()
                for conv in convs
            ]
            entries.sort(key=lambda item: item[1]['timestamp'], reverse=True)
            for source, conv in entries:
                label = conv['timestamp'].strftime("%Y-%m-%d %H:%M:%S")
                with st.expander(f"{source} - {label}", expanded=False):
                    for msg in conv['messages']:
                        role = 'user' if msg.get('role') == 'user' else 'assistant'
                        st.chat_message(role).write(msg.get('content'))