""" app.py ====== DataMind Agent — Multi-LLM Streamlit UI Supports: Google Gemini, OpenAI GPT, Anthropic Claude, xAI Grok, Mistral AI, Meta Llama (via Together AI), Alibaba Qwen (via Together AI) Run: streamlit run app.py """ import os import io import streamlit as st import pandas as pd import plotly.express as px from core_agent import ( PROVIDERS, get_llm, validate_llm, load_file, profile_dataframe, profile_to_text, ask_agent, auto_suggest_charts, make_plotly_chart, ai_recommend_chart, ) # ─── Page config ────────────────────────────────────────────────────────────── st.set_page_config( page_title="DataMind Agent", page_icon="🧠", layout="wide", initial_sidebar_state="expanded", ) # ─── CSS ────────────────────────────────────────────────────────────────────── st.markdown(""" """, unsafe_allow_html=True) # ─── Session state ──────────────────────────────────────────────────────────── for key, default in { "df": None, "profile": None, "file_type": None, "chat_history": [], "llm": None, "active_provider": None, "active_model": None, "api_key_set": False, }.items(): if key not in st.session_state: st.session_state[key] = default # ─── Sidebar ────────────────────────────────────────────────────────────────── with st.sidebar: st.markdown("### 🧠 DataMind Agent") st.markdown("---") # ── Provider selector ───────────────────────────────────────────────────── st.markdown("**🤖 Choose AI Provider**") provider_labels = {k: v["name"] for k, v in PROVIDERS.items()} selected_provider = st.selectbox( "Provider", options=list(provider_labels.keys()), format_func=lambda k: provider_labels[k], label_visibility="collapsed", key="provider_select", ) pinfo = PROVIDERS[selected_provider] # Color dot for the selected provider st.markdown( f'' f'● {pinfo["name"]}', unsafe_allow_html=True, ) # Model selector selected_model = st.selectbox( "Model", options=pinfo["models"], index=0, key=f"model_{selected_provider}", ) # Show a note for providers that need a third-party key (e.g. Together AI) if pinfo.get("note"): st.caption(f"ℹ️ {pinfo['note']}") # API key st.markdown(f"**🔑 {pinfo['name']} API Key**") api_key = st.text_input( "API Key", type="password", placeholder=pinfo["key_hint"], label_visibility="collapsed", key=f"apikey_{selected_provider}", ) connect_btn = st.button("🔌 Connect", key="connect_btn") if connect_btn and api_key: with st.spinner(f"Connecting to {pinfo['name']}..."): try: llm, msg = validate_llm(selected_provider, api_key, selected_model) st.session_state.llm = llm st.session_state.api_key_set = True st.session_state.active_provider = selected_provider st.session_state.active_model = selected_model st.session_state.chat_history = [] st.success(msg) except Exception as e: st.session_state.api_key_set = False st.error(f"❌ Connection failed: {e}") elif connect_btn and not api_key: st.warning("⚠️ Please enter your API key first.") # Show active connection status if st.session_state.api_key_set and st.session_state.active_provider: ap = st.session_state.active_provider am = st.session_state.active_model ac = PROVIDERS[ap]["color"] st.markdown( f'
' f' {PROVIDERS[ap]["name"]}
' f'{am}
', unsafe_allow_html=True, ) st.markdown("---") # ── File upload ─────────────────────────────────────────────────────────── st.markdown("**📁 Upload Data File**") uploaded = st.file_uploader( "Upload", type=["csv", "xlsx", "xls", "json"], label_visibility="collapsed", ) if uploaded and st.session_state.api_key_set: with st.spinner("📊 Analyzing your data..."): try: df, ftype = load_file(uploaded) st.session_state.df = df st.session_state.file_type = ftype st.session_state.profile = profile_dataframe(df) st.session_state.chat_history = [] st.success(f"✅ Loaded {ftype} file!") except Exception as e: st.error(f"❌ Error: {e}") elif uploaded and not st.session_state.api_key_set: st.warning("⚠️ Connect to an AI provider first.") st.markdown("---") st.markdown(""" **How to use:** 1. Choose an AI provider 2. Select a model 3. Paste your API key → click Connect 4. Upload CSV, Excel, or JSON 5. Explore · Chat · Visualize --- **Get API keys:** - [Gemini](https://aistudio.google.com/app/apikey) - [OpenAI](https://platform.openai.com/api-keys) - [Claude](https://console.anthropic.com/) - [Grok](https://console.x.ai/) - [Mistral](https://console.mistral.ai/) - [Llama / Qwen → Together AI](https://api.together.ai/) """) # ─── Main content ───────────────────────────────────────────────────────────── st.markdown('
🧠 DataMind Agent
', unsafe_allow_html=True) # Dynamic subtitle showing active provider if st.session_state.api_key_set and st.session_state.active_provider: ap = st.session_state.active_provider am = st.session_state.active_model ac = PROVIDERS[ap]["color"] sub = ( f'AI-powered data analysis · Powered by ' f'' f'{PROVIDERS[ap]["name"]} / {am}' ) else: sub = "AI-powered data analysis · Connect a provider and upload data to begin" st.markdown(f'
{sub}
', unsafe_allow_html=True) # ─── Landing state ──────────────────────────────────────────────────────────── if st.session_state.df is None: col1, col2, col3 = st.columns(3) cards = [ ("🤖", "7 AI Providers", "Gemini, GPT, Claude, Grok, Mistral, Llama, Qwen — latest models"), ("📂", "CSV · Excel · JSON", "Upload any tabular data file — we handle parsing automatically"), ("📊", "Smart Visualizations", "AI picks the right chart for your question automatically"), ] for col, (icon, title, desc) in zip([col1, col2, col3], cards): with col: st.markdown( f'
{icon}
' f'
{title}

' f'

{desc}

', unsafe_allow_html=True, ) st.markdown("
", unsafe_allow_html=True) if not st.session_state.api_key_set: st.info("👈 Choose a provider, enter your API key and click **Connect** in the sidebar.") else: st.info("👈 Upload a data file (CSV, Excel, or JSON) in the sidebar to get started!") else: df = st.session_state.df profile = st.session_state.profile llm = st.session_state.llm tab1, tab2, tab3, tab4 = st.tabs(["📊 Dashboard", "💬 Chat", "🎨 Charts", "🔍 Raw Data"]) # ══════════════════════════════════════════════════════════════════════════ # TAB 1 — Dashboard # ══════════════════════════════════════════════════════════════════════════ with tab1: rows, cols = profile["shape"] nulls = sum(profile["null_counts"].values()) num_c = len(profile["numeric_columns"]) cat_c = len(profile["categorical_columns"]) c1, c2, c3, c4 = st.columns(4) for col_obj, num, label in [ (c1, f"{rows:,}", "Rows"), (c2, str(cols), "Columns"), (c3, str(num_c), "Numeric Cols"), (c4, str(nulls), "Missing Values"), ]: col_obj.markdown( f'
{num}
' f'
{label}
', unsafe_allow_html=True, ) st.markdown("
", unsafe_allow_html=True) st.markdown("#### 📋 Column Overview") col_info = pd.DataFrame({ "Column": df.columns, "Type": df.dtypes.astype(str).values, "Non-Null": df.notnull().sum().values, "Null %": (df.isnull().mean() * 100).round(1).values, "Unique": df.nunique().values, }) st.dataframe(col_info, use_container_width=True, hide_index=True) st.markdown("#### 🤖 Auto-Generated Insights") suggested = auto_suggest_charts(profile)[:3] chart_cols = st.columns(min(len(suggested), 2)) for i, ctype in enumerate(suggested[:2]): with chart_cols[i]: try: fig = make_plotly_chart(ctype, df, profile) st.plotly_chart(fig, use_container_width=True) except Exception as e: st.warning(f"Could not render {ctype}: {e}") if len(suggested) > 2: try: fig = make_plotly_chart(suggested[2], df, profile) st.plotly_chart(fig, use_container_width=True) except Exception: pass st.markdown("#### 🧠 AI Dataset Summary") ap_name = PROVIDERS.get(st.session_state.active_provider, {}).get("name", "AI") if st.button(f"✨ Generate Summary with {ap_name}"): with st.spinner(f"{ap_name} is analyzing your dataset..."): summary = ask_agent( "Give me a concise executive summary of this dataset. " "Highlight key patterns, anomalies, and 3 actionable insights.", df, profile, llm, ) st.markdown(f'
{summary}
', unsafe_allow_html=True) # ══════════════════════════════════════════════════════════════════════════ # TAB 2 — Chat # ══════════════════════════════════════════════════════════════════════════ with tab2: ap_name = PROVIDERS.get(st.session_state.active_provider, {}).get("name", "AI") st.markdown(f"#### 💬 Ask Anything About Your Data") st.markdown(f"*Powered by **{ap_name} / {st.session_state.active_model}***") st.markdown("**Quick questions to try:**") suggestions = [ "What are the top 5 most important patterns in this data?", "Are there any outliers or anomalies I should know about?", "What correlations exist between the numeric columns?", ] q_cols = st.columns(3) for i, s in enumerate(suggestions): with q_cols[i]: if st.button(s, key=f"sug_{i}"): st.session_state["prefill_q"] = s for turn in st.session_state.chat_history: st.markdown(f'
👤 {turn["user"]}
', unsafe_allow_html=True) st.markdown(f'
🧠 {turn["agent"]}
', unsafe_allow_html=True) prefill = st.session_state.pop("prefill_q", "") question = st.text_input( "Ask a question...", value=prefill, placeholder="e.g. What's the average sales by region?", label_visibility="collapsed", ) col_send, col_clear = st.columns([1, 5]) with col_send: send = st.button("Send 🚀") with col_clear: if st.button("Clear Chat"): st.session_state.chat_history = [] st.rerun() if send and question.strip(): with st.spinner(f"🧠 {ap_name} is thinking..."): answer = ask_agent(question, df, profile, llm) chart_rec = ai_recommend_chart(question, profile, llm) st.session_state.chat_history.append({ "user": question, "agent": answer, "chart_rec": chart_rec, }) st.markdown(f'
👤 {question}
', unsafe_allow_html=True) st.markdown(f'
🧠 {answer}
', unsafe_allow_html=True) if chart_rec: st.markdown( f"*📊 Suggested chart: **{chart_rec['chart_type']}** — " f"{chart_rec.get('reason','')}*" ) try: fig = make_plotly_chart( chart_rec["chart_type"], df, profile, x_col=chart_rec.get("x_col"), y_col=chart_rec.get("y_col"), ) st.plotly_chart(fig, use_container_width=True) except Exception: pass # ══════════════════════════════════════════════════════════════════════════ # TAB 3 — Charts # ══════════════════════════════════════════════════════════════════════════ with tab3: st.markdown("#### 🎨 Custom Chart Builder") chart_options = { "Correlation Heatmap": "correlation_heatmap", "Distribution Plot": "distribution_plots", "Box Plots": "box_plots", "Bar Chart": "bar_chart", "Pie Chart": "pie_chart", "Scatter Plot": "scatter", "Line Chart": "line", "Scatter Matrix": "scatter_matrix", } if profile["datetime_columns"]: chart_options["Time Series"] = "time_series" c1, c2, c3 = st.columns(3) with c1: chart_label = st.selectbox("Chart Type", list(chart_options.keys())) with c2: all_cols = ["(auto)"] + df.columns.tolist() x_col = st.selectbox("X Column", all_cols) with c3: y_col = st.selectbox("Y Column", all_cols) x_val = None if x_col == "(auto)" else x_col y_val = None if y_col == "(auto)" else y_col if st.button("🎨 Generate Chart"): with st.spinner("Rendering..."): try: fig = make_plotly_chart( chart_options[chart_label], df, profile, x_col=x_val, y_col=y_val, ) st.plotly_chart(fig, use_container_width=True) except Exception as e: st.error(f"Chart error: {e}") st.markdown("---") st.markdown("#### 📊 All Auto-Suggested Charts") suggested_all = auto_suggest_charts(profile) for i in range(0, len(suggested_all), 2): cols = st.columns(2) for j, ctype in enumerate(suggested_all[i:i+2]): with cols[j]: try: fig = make_plotly_chart(ctype, df, profile) st.plotly_chart(fig, use_container_width=True) except Exception: st.warning(f"Could not render {ctype}") # ══════════════════════════════════════════════════════════════════════════ # TAB 4 — Raw Data # ══════════════════════════════════════════════════════════════════════════ with tab4: st.markdown("#### 🔍 Raw Data Explorer") search = st.text_input("🔎 Filter rows containing...", placeholder="Type to filter...") if search: mask = df.astype(str).apply( lambda row: row.str.contains(search, case=False, na=False) ).any(axis=1) display_df = df[mask] st.info(f"Showing {len(display_df):,} of {len(df):,} rows matching '{search}'") else: display_df = df st.dataframe(display_df, use_container_width=True, height=500) csv_buf = io.StringIO() df.to_csv(csv_buf, index=False) st.download_button( "⬇️ Download as CSV", data=csv_buf.getvalue(), file_name="analyzed_data.csv", mime="text/csv", )