# data-mind-ultra / app.py
# Uploaded by sanjaystarc ("Upload 3 files", commit 7ae981c, verified)
"""
app.py
======
DataMind Agent — Multi-LLM Streamlit UI
Supports: Google Gemini, OpenAI GPT, Anthropic Claude, xAI Grok,
Mistral AI, Meta Llama (via Together AI), Alibaba Qwen (via Together AI)
Run: streamlit run app.py
"""
import html
import io
import os

import pandas as pd
import plotly.express as px
import streamlit as st

from core_agent import (
    PROVIDERS,
    get_llm, validate_llm,
    load_file, profile_dataframe, profile_to_text,
    ask_agent, auto_suggest_charts, make_plotly_chart, ai_recommend_chart,
)
# ─── Page config ──────────────────────────────────────────────────────────────
# Browser tab title/icon, wide layout and an initially open sidebar.
# This is the first st.* command executed by the script.
_PAGE_CONFIG = dict(
    page_title="DataMind Agent",
    page_icon="🧠",
    layout="wide",
    initial_sidebar_state="expanded",
)
st.set_page_config(**_PAGE_CONFIG)
# ─── CSS ──────────────────────────────────────────────────────────────────────
# Global theme injected as raw HTML: web fonts, dark palette, hero title,
# stat cards, provider badge, chat bubbles and restyled Streamlit widgets.
# The class names defined here (hero-title, stat-card, user-bubble, ...) are
# referenced by the HTML snippets rendered further down the script.
_GLOBAL_CSS = """
<style>
@import url('https://fonts.googleapis.com/css2?family=Syne:wght@400;700;800&family=DM+Sans:wght@300;400;500&display=swap');
html, body, [class*="css"] {
font-family: 'DM Sans', sans-serif;
background-color: #0a0a12;
color: #e8e8ff;
}
.main { background-color: #0a0a12; }
.hero-title {
font-family: 'Syne', sans-serif;
font-size: 2.8rem; font-weight: 800;
background: linear-gradient(135deg, #e8e8ff 0%, #6C63FF 50%, #43E97B 100%);
-webkit-background-clip: text; -webkit-text-fill-color: transparent;
background-clip: text; margin-bottom: 0.2rem;
}
.hero-sub { color: #6a6a9a; font-size: 1rem; margin-bottom: 2rem; }
.stat-card {
background: #1a1a2e; border: 1px solid #2a2a45;
border-radius: 16px; padding: 1.2rem 1.5rem; text-align: center;
}
.stat-num { font-family: 'Syne', sans-serif; font-size: 2rem; font-weight: 800; color: #6C63FF; }
.stat-label { color: #6a6a9a; font-size: 0.8rem; text-transform: uppercase; letter-spacing: 0.1em; }
/* Provider badge */
.provider-badge {
display: inline-block;
padding: 3px 10px; border-radius: 20px;
font-size: 0.72rem; font-weight: 700;
letter-spacing: 0.05em;
margin-bottom: 0.5rem;
}
.user-bubble {
background: rgba(108,99,255,0.15); border: 1px solid rgba(108,99,255,0.3);
border-radius: 18px 18px 4px 18px; padding: 0.9rem 1.2rem;
margin: 0.5rem 0; font-size: 0.95rem;
}
.agent-bubble {
background: #1a1a2e; border: 1px solid #2a2a45;
border-radius: 18px 18px 18px 4px; padding: 0.9rem 1.2rem;
margin: 0.5rem 0; font-size: 0.95rem; line-height: 1.6;
}
section[data-testid="stSidebar"] {
background: #10101e; border-right: 1px solid #2a2a45;
}
.stButton > button {
background: linear-gradient(135deg, #6C63FF, #43E97B);
color: white; border: none; border-radius: 12px;
font-family: 'Syne', sans-serif; font-weight: 700;
padding: 0.6rem 1.5rem; transition: opacity 0.2s;
}
.stButton > button:hover { opacity: 0.85; color: white; }
.stTextInput > div > div > input {
background: #1a1a2e; border: 1px solid #2a2a45;
border-radius: 12px; color: #e8e8ff;
}
.stSelectbox > div > div {
background: #1a1a2e; border: 1px solid #2a2a45; border-radius: 12px;
}
.stTabs [data-baseweb="tab-list"] {
background: #10101e; border-radius: 12px; gap: 0.3rem;
}
.stTabs [data-baseweb="tab"] {
background: transparent; color: #6a6a9a;
border-radius: 10px; font-family: 'Syne', sans-serif;
}
.stTabs [aria-selected="true"] {
background: rgba(108,99,255,0.2) !important; color: #6C63FF !important;
}
</style>
"""
st.markdown(_GLOBAL_CSS, unsafe_allow_html=True)
# ─── Session state ────────────────────────────────────────────────────────────
# Seed every key the app reads. A key that already exists (from an earlier
# rerun) is left untouched, so only missing keys receive these defaults.
_SESSION_DEFAULTS = {
    "df": None,               # dataset returned by load_file
    "profile": None,          # result of profile_dataframe(df)
    "file_type": None,        # file-type label returned by load_file
    "chat_history": [],       # list of {"user", "agent", "chart_rec"} turns
    "llm": None,              # client returned by validate_llm
    "active_provider": None,  # PROVIDERS key of the connected provider
    "active_model": None,     # model name chosen at connect time
    "api_key_set": False,     # True after a successful Connect
}
for state_key, initial in _SESSION_DEFAULTS.items():
    if state_key in st.session_state:
        continue
    st.session_state[state_key] = initial
# ─── Sidebar ──────────────────────────────────────────────────────────────────
# Streamlit re-executes this script top to bottom on every widget interaction.
# Anything that mutates session state here is therefore gated on an explicit
# action (the Connect button) or on the uploaded file actually changing.
with st.sidebar:
    st.markdown("### 🧠 DataMind Agent")
    st.markdown("---")

    # ── Provider selector ─────────────────────────────────────────────────────
    st.markdown("**🤖 Choose AI Provider**")
    provider_labels = {k: v["name"] for k, v in PROVIDERS.items()}
    selected_provider = st.selectbox(
        "Provider",
        options=list(provider_labels.keys()),
        format_func=lambda k: provider_labels[k],
        label_visibility="collapsed",
        key="provider_select",
    )
    pinfo = PROVIDERS[selected_provider]

    # Colour badge for the selected provider. "22"/"55" are appended as a hex
    # alpha channel, which assumes PROVIDERS colours are "#RRGGBB" strings
    # (TODO confirm in core_agent).
    st.markdown(
        f'<span class="provider-badge" style="background:{pinfo["color"]}22;'
        f'color:{pinfo["color"]};border:1px solid {pinfo["color"]}55;">'
        f'● {pinfo["name"]}</span>',
        unsafe_allow_html=True,
    )

    # Model selector; the widget key includes the provider, so each provider
    # gets its own model widget.
    selected_model = st.selectbox(
        "Model",
        options=pinfo["models"],
        index=0,
        key=f"model_{selected_provider}",
    )

    # Some providers need a third-party key (e.g. Together AI); show their note.
    if pinfo.get("note"):
        st.caption(f"ℹ️ {pinfo['note']}")

    # API key (one widget per provider, see key).
    st.markdown(f"**🔑 {pinfo['name']} API Key**")
    api_key = st.text_input(
        "API Key",
        type="password",
        placeholder=pinfo["key_hint"],
        label_visibility="collapsed",
        key=f"apikey_{selected_provider}",
    )

    connect_btn = st.button("🔌 Connect", key="connect_btn")
    if connect_btn and not api_key:
        st.warning("⚠️ Please enter your API key first.")
    elif connect_btn:
        with st.spinner(f"Connecting to {pinfo['name']}..."):
            try:
                llm, msg = validate_llm(selected_provider, api_key, selected_model)
            except Exception as e:
                # Forget any previous connection as well: otherwise the UI
                # says "not connected" while the old client keeps answering.
                st.session_state.llm = None
                st.session_state.api_key_set = False
                st.session_state.active_provider = None
                st.session_state.active_model = None
                st.error(f"❌ Connection failed: {e}")
            else:
                st.session_state.llm = llm
                st.session_state.api_key_set = True
                st.session_state.active_provider = selected_provider
                st.session_state.active_model = selected_model
                # A new model starts with a fresh conversation.
                st.session_state.chat_history = []
                st.success(msg)

    # Show the active connection status.
    if st.session_state.api_key_set and st.session_state.active_provider:
        ap = st.session_state.active_provider
        am = st.session_state.active_model
        ac = PROVIDERS[ap]["color"]
        st.markdown(
            f'<div style="margin-top:8px;padding:8px 12px;border-radius:10px;'
            f'background:{ac}15;border:1px solid {ac}40;font-size:0.78rem;">'
            f'<span style="color:{ac}">●</span> <b>{PROVIDERS[ap]["name"]}</b><br/>'
            f'<span style="color:#6a6a9a">{am}</span></div>',
            unsafe_allow_html=True,
        )
    st.markdown("---")

    # ── File upload ───────────────────────────────────────────────────────────
    st.markdown("**📁 Upload Data File**")
    uploaded = st.file_uploader(
        "Upload",
        type=["csv", "xlsx", "xls", "json"],
        label_visibility="collapsed",
    )
    if uploaded is not None and not st.session_state.api_key_set:
        st.warning("⚠️ Connect to an AI provider first.")
    elif uploaded is not None:
        # The uploader hands back the file on *every* rerun. Reload (and reset
        # the chat) only when it is a different file. Otherwise every click
        # anywhere in the app would re-profile the data and wipe the chat
        # history. file_id is used when this Streamlit version provides it.
        file_key = (getattr(uploaded, "file_id", None), uploaded.name, uploaded.size)
        if st.session_state.get("loaded_file_key") != file_key:
            with st.spinner("📊 Analyzing your data..."):
                try:
                    loaded_df, loaded_type = load_file(uploaded)
                    loaded_profile = profile_dataframe(loaded_df)
                except Exception as e:
                    st.error(f"❌ Error: {e}")
                else:
                    # Commit all-or-nothing so a profiling failure cannot
                    # leave a new df paired with a stale profile.
                    st.session_state.df = loaded_df
                    st.session_state.file_type = loaded_type
                    st.session_state.profile = loaded_profile
                    st.session_state.chat_history = []
                    st.session_state.loaded_file_key = file_key
        if st.session_state.get("loaded_file_key") == file_key:
            st.success(f"✅ Loaded {st.session_state.file_type} file!")
    st.markdown("---")

    st.markdown(
        "**How to use:**\n"
        "1. Choose an AI provider\n"
        "2. Select a model\n"
        "3. Paste your API key → click Connect\n"
        "4. Upload CSV, Excel, or JSON\n"
        "5. Explore · Chat · Visualize\n"
        "\n"
        "---\n"
        "**Get API keys:**\n"
        "- [Gemini](https://aistudio.google.com/app/apikey)\n"
        "- [OpenAI](https://platform.openai.com/api-keys)\n"
        "- [Claude](https://console.anthropic.com/)\n"
        "- [Grok](https://console.x.ai/)\n"
        "- [Mistral](https://console.mistral.ai/)\n"
        "- [Llama / Qwen → Together AI](https://api.together.ai/)\n"
    )
# ─── Main content ─────────────────────────────────────────────────────────────
st.markdown('<div class="hero-title">🧠 DataMind Agent</div>', unsafe_allow_html=True)

# Subtitle: name the connected provider/model (in the provider's colour), or
# tell the user what to do first.
if st.session_state.api_key_set and st.session_state.active_provider:
    ap = st.session_state.active_provider
    am = st.session_state.active_model
    ac = PROVIDERS[ap]["color"]
    sub = (
        "AI-powered data analysis · Powered by "
        f'<span style="color:{ac};font-weight:600">'
        f'{PROVIDERS[ap]["name"]} / {am}</span>'
    )
else:
    sub = "AI-powered data analysis · Connect a provider and upload data to begin"
st.markdown(f'<div class="hero-sub">{sub}</div>', unsafe_allow_html=True)
# ─── Landing state / analysis tabs ────────────────────────────────────────────
# Each screen is a small render function so it reads top to bottom. The
# functions only draw widgets and read/write st.session_state; the dispatch
# at the bottom of this section decides which ones run.


def _active_provider_name():
    """Return the display name of the connected provider, or "AI" if none."""
    return PROVIDERS.get(st.session_state.active_provider, {}).get("name", "AI")


def _render_bubble(css_class, text, icon=""):
    """Draw *text* inside a styled chat bubble ("user-bubble"/"agent-bubble").

    The text comes from the user or from the LLM, which in turn sees uploaded
    data. It is rendered with ``unsafe_allow_html=True``, so it is
    HTML-escaped first: a stray ``<`` must not break the layout or inject
    markup.

    Args:
        css_class: CSS class of the bubble div.
        text: message to show; converted with ``str()``.
        icon: optional prefix (e.g. an emoji) placed before the text.
    """
    prefix = f"{icon} " if icon else ""
    safe = html.escape(str(text), quote=False)
    st.markdown(f'<div class="{css_class}">{prefix}{safe}</div>', unsafe_allow_html=True)


def _plot_or_warn(chart_type, df, profile, **axes):
    """Render one chart with make_plotly_chart, or show a warning on failure.

    Args:
        chart_type: chart identifier passed to ``make_plotly_chart``.
        df: dataset to plot.
        profile: dict produced by ``profile_dataframe``.
        **axes: optional ``x_col`` / ``y_col`` forwarded unchanged.
    """
    try:
        fig = make_plotly_chart(chart_type, df, profile, **axes)
        st.plotly_chart(fig, use_container_width=True)
    except Exception as e:  # one bad chart must not take down the whole page
        st.warning(f"Could not render {chart_type}: {e}")


def _render_landing():
    """Show the feature cards and a hint for the next step (connect / upload)."""
    cards = [
        ("🤖", "7 AI Providers", "Gemini, GPT, Claude, Grok, Mistral, Llama, Qwen — latest models"),
        ("📂", "CSV · Excel · JSON", "Upload any tabular data file — we handle parsing automatically"),
        ("📊", "Smart Visualizations", "AI picks the right chart for your question automatically"),
    ]
    for col, (icon, title, desc) in zip(st.columns(3), cards):
        with col:
            st.markdown(
                f'<div class="stat-card"><div class="stat-num">{icon}</div>'
                f'<div class="stat-label">{title}</div><br>'
                f'<p style="color:#6a6a9a;font-size:0.85rem">{desc}</p></div>',
                unsafe_allow_html=True,
            )
    st.markdown("<br>", unsafe_allow_html=True)
    if not st.session_state.api_key_set:
        st.info("👈 Choose a provider, enter your API key and click **Connect** in the sidebar.")
    else:
        st.info("👈 Upload a data file (CSV, Excel, or JSON) in the sidebar to get started!")


# ══════════════════════════════════════════════════════════════════════════════
# TAB 1 — Dashboard
# ══════════════════════════════════════════════════════════════════════════════
def _render_dashboard_tab(df, profile, llm):
    """Show headline stats, a per-column overview, auto charts and an AI summary."""
    rows, cols = profile["shape"]
    nulls = sum(profile["null_counts"].values())
    num_c = len(profile["numeric_columns"])
    stats = [
        (f"{rows:,}", "Rows"),
        (str(cols), "Columns"),
        (str(num_c), "Numeric Cols"),
        (str(nulls), "Missing Values"),
    ]
    for col_obj, (num, label) in zip(st.columns(4), stats):
        col_obj.markdown(
            f'<div class="stat-card"><div class="stat-num">{num}</div>'
            f'<div class="stat-label">{label}</div></div>',
            unsafe_allow_html=True,
        )
    st.markdown("<br>", unsafe_allow_html=True)

    st.markdown("#### 📋 Column Overview")
    col_info = pd.DataFrame({
        "Column": df.columns,
        "Type": df.dtypes.astype(str).values,
        "Non-Null": df.notnull().sum().values,
        "Null %": (df.isnull().mean() * 100).round(1).values,
        "Unique": df.nunique().values,
    })
    st.dataframe(col_info, use_container_width=True, hide_index=True)

    st.markdown("#### 🤖 Auto-Generated Insights")
    suggested = auto_suggest_charts(profile)[:3]
    if suggested:
        # Up to two charts side by side, then an optional third at full width.
        side_by_side, full_width = suggested[:2], suggested[2:]
        for chart_col, ctype in zip(st.columns(len(side_by_side)), side_by_side):
            with chart_col:
                _plot_or_warn(ctype, df, profile)
        for ctype in full_width:
            _plot_or_warn(ctype, df, profile)
    else:
        # st.columns needs a positive count, so an empty list is handled here.
        st.caption("No automatic charts are available for this dataset.")

    st.markdown("#### 🧠 AI Dataset Summary")
    ap_name = _active_provider_name()
    if not st.button(f"✨ Generate Summary with {ap_name}"):
        return
    if llm is None:
        st.warning("⚠️ Connect to an AI provider first.")
        return
    with st.spinner(f"{ap_name} is analyzing your dataset..."):
        try:
            summary = ask_agent(
                "Give me a concise executive summary of this dataset. "
                "Highlight key patterns, anomalies, and 3 actionable insights.",
                df, profile, llm,
            )
        except Exception as e:
            st.error(f"❌ {ap_name} request failed: {e}")
            return
    _render_bubble("agent-bubble", summary)


# ══════════════════════════════════════════════════════════════════════════════
# TAB 2 — Chat
# ══════════════════════════════════════════════════════════════════════════════
def _render_chat_tab(df, profile, llm):
    """Free-form Q&A with the LLM, plus a suggested chart for each answer."""
    ap_name = _active_provider_name()
    st.markdown("#### 💬 Ask Anything About Your Data")
    if st.session_state.active_model:
        st.markdown(f"*Powered by **{ap_name} / {st.session_state.active_model}***")

    st.markdown("**Quick questions to try:**")
    suggestions = [
        "What are the top 5 most important patterns in this data?",
        "Are there any outliers or anomalies I should know about?",
        "What correlations exist between the numeric columns?",
    ]
    for i, (q_col, suggestion) in enumerate(zip(st.columns(3), suggestions)):
        with q_col:
            if st.button(suggestion, key=f"sug_{i}"):
                # Write into the keyed text input below *before* it is created
                # in this run. A changing value= default would instead give
                # the widget a new identity, and the text would be lost on
                # the next rerun (the one triggered by clicking Send).
                st.session_state["chat_question"] = suggestion

    for turn in st.session_state.chat_history:
        _render_bubble("user-bubble", turn["user"], "👤")
        _render_bubble("agent-bubble", turn["agent"], "🧠")

    question = st.text_input(
        "Ask a question...",
        key="chat_question",
        placeholder="e.g. What's the average sales by region?",
        label_visibility="collapsed",
    )
    col_send, col_clear = st.columns([1, 5])
    with col_send:
        send = st.button("Send 🚀")
    with col_clear:
        if st.button("Clear Chat"):
            st.session_state.chat_history = []
            st.rerun()

    if not (send and question.strip()):
        return
    if llm is None:
        st.warning("⚠️ Connect to an AI provider first.")
        return
    with st.spinner(f"🧠 {ap_name} is thinking..."):
        try:
            answer = ask_agent(question, df, profile, llm)
        except Exception as e:
            st.error(f"❌ {ap_name} request failed: {e}")
            return
        try:
            chart_rec = ai_recommend_chart(question, profile, llm)
        except Exception:
            # The chart is a bonus; keep the text answer even if this fails.
            chart_rec = None
    st.session_state.chat_history.append({
        "user": question, "agent": answer, "chart_rec": chart_rec,
    })
    _render_bubble("user-bubble", question, "👤")
    _render_bubble("agent-bubble", answer, "🧠")
    # chart_rec is read as a dict with chart_type / reason / x_col / y_col keys.
    if chart_rec and chart_rec.get("chart_type"):
        st.markdown(
            f"*📊 Suggested chart: **{chart_rec['chart_type']}** — "
            f"{chart_rec.get('reason', '')}*"
        )
        _plot_or_warn(
            chart_rec["chart_type"], df, profile,
            x_col=chart_rec.get("x_col"),
            y_col=chart_rec.get("y_col"),
        )


# ══════════════════════════════════════════════════════════════════════════════
# TAB 3 — Charts
# ══════════════════════════════════════════════════════════════════════════════
def _render_charts_tab(df, profile):
    """Provide a user-driven chart builder and a grid of all suggested charts."""
    st.markdown("#### 🎨 Custom Chart Builder")
    chart_options = {
        "Correlation Heatmap": "correlation_heatmap",
        "Distribution Plot": "distribution_plots",
        "Box Plots": "box_plots",
        "Bar Chart": "bar_chart",
        "Pie Chart": "pie_chart",
        "Scatter Plot": "scatter",
        "Line Chart": "line",
        "Scatter Matrix": "scatter_matrix",
    }
    if profile["datetime_columns"]:
        chart_options["Time Series"] = "time_series"

    all_cols = ["(auto)"] + df.columns.tolist()
    c1, c2, c3 = st.columns(3)
    with c1:
        chart_label = st.selectbox("Chart Type", list(chart_options.keys()))
    with c2:
        x_col = st.selectbox("X Column", all_cols)
    with c3:
        y_col = st.selectbox("Y Column", all_cols)
    # "(auto)" is forwarded as None, leaving the column choice to make_plotly_chart.
    x_val = None if x_col == "(auto)" else x_col
    y_val = None if y_col == "(auto)" else y_col
    if st.button("🎨 Generate Chart"):
        with st.spinner("Rendering..."):
            try:
                fig = make_plotly_chart(
                    chart_options[chart_label], df, profile,
                    x_col=x_val, y_col=y_val,
                )
                st.plotly_chart(fig, use_container_width=True)
            except Exception as e:
                st.error(f"Chart error: {e}")

    st.markdown("---")
    st.markdown("#### 📊 All Auto-Suggested Charts")
    suggested_all = auto_suggest_charts(profile)
    for start in range(0, len(suggested_all), 2):
        for grid_col, ctype in zip(st.columns(2), suggested_all[start:start + 2]):
            with grid_col:
                _plot_or_warn(ctype, df, profile)


# ══════════════════════════════════════════════════════════════════════════════
# TAB 4 — Raw Data
# ══════════════════════════════════════════════════════════════════════════════
def _render_raw_data_tab(df):
    """Filter rows by case-insensitive substring over every cell; offer CSV export."""
    st.markdown("#### 🔍 Raw Data Explorer")
    search = st.text_input("🔎 Filter rows containing...", placeholder="Type to filter...")
    if search:
        # regex=False: the user types plain text, so "(", "[" or "+" must
        # match literally instead of raising re.error.
        mask = df.astype(str).apply(
            lambda column: column.str.contains(search, case=False, na=False, regex=False)
        ).any(axis=1)
        display_df = df[mask]
        st.info(f"Showing {len(display_df):,} of {len(df):,} rows matching '{search}'")
    else:
        display_df = df
    st.dataframe(display_df, use_container_width=True, height=500)

    # Exports the full dataset, not only the filtered view shown above.
    st.download_button(
        "⬇️ Download as CSV",
        data=df.to_csv(index=False),
        file_name="analyzed_data.csv",
        mime="text/csv",
    )


# ─── Dispatch ─────────────────────────────────────────────────────────────────
if st.session_state.df is None:
    _render_landing()
else:
    df = st.session_state.df
    profile = st.session_state.profile
    llm = st.session_state.llm
    tab1, tab2, tab3, tab4 = st.tabs(["📊 Dashboard", "💬 Chat", "🎨 Charts", "🔍 Raw Data"])
    with tab1:
        _render_dashboard_tab(df, profile, llm)
    with tab2:
        _render_chat_tab(df, profile, llm)
    with tab3:
        _render_charts_tab(df, profile)
    with tab4:
        _render_raw_data_tab(df)