"""
app.py
======
DataMind Agent – Multi-LLM Streamlit UI
Supports: Google Gemini, OpenAI GPT, Anthropic Claude, xAI Grok,
Mistral AI, Meta Llama (via Together AI), Alibaba Qwen (via Together AI)
Run: streamlit run app.py
"""
| import os | |
| import io | |
| import streamlit as st | |
| import pandas as pd | |
| import plotly.express as px | |
| from core_agent import ( | |
| PROVIDERS, | |
| get_llm, validate_llm, | |
| load_file, profile_dataframe, profile_to_text, | |
| ask_agent, auto_suggest_charts, make_plotly_chart, ai_recommend_chart, | |
| ) | |
# ─── Page config ──────────────────────────────────────────────────────────────
# Must run before any other st.* call; collects the settings in one place.
_PAGE_SETTINGS = dict(
    page_title="DataMind Agent",
    page_icon="π§ ",
    layout="wide",
    initial_sidebar_state="expanded",
)
st.set_page_config(**_PAGE_SETTINGS)
# ─── CSS ──────────────────────────────────────────────────────────────────────
# Global theme: dark palette, Syne/DM Sans fonts, hero banner, stat cards,
# provider badge, chat bubbles, and restyled Streamlit widgets/tabs.
_CUSTOM_CSS = """
<style>
@import url('https://fonts.googleapis.com/css2?family=Syne:wght@400;700;800&family=DM+Sans:wght@300;400;500&display=swap');
html, body, [class*="css"] {
font-family: 'DM Sans', sans-serif;
background-color: #0a0a12;
color: #e8e8ff;
}
.main { background-color: #0a0a12; }
.hero-title {
font-family: 'Syne', sans-serif;
font-size: 2.8rem; font-weight: 800;
background: linear-gradient(135deg, #e8e8ff 0%, #6C63FF 50%, #43E97B 100%);
-webkit-background-clip: text; -webkit-text-fill-color: transparent;
background-clip: text; margin-bottom: 0.2rem;
}
.hero-sub { color: #6a6a9a; font-size: 1rem; margin-bottom: 2rem; }
.stat-card {
background: #1a1a2e; border: 1px solid #2a2a45;
border-radius: 16px; padding: 1.2rem 1.5rem; text-align: center;
}
.stat-num { font-family: 'Syne', sans-serif; font-size: 2rem; font-weight: 800; color: #6C63FF; }
.stat-label { color: #6a6a9a; font-size: 0.8rem; text-transform: uppercase; letter-spacing: 0.1em; }
/* Provider badge */
.provider-badge {
display: inline-block;
padding: 3px 10px; border-radius: 20px;
font-size: 0.72rem; font-weight: 700;
letter-spacing: 0.05em;
margin-bottom: 0.5rem;
}
.user-bubble {
background: rgba(108,99,255,0.15); border: 1px solid rgba(108,99,255,0.3);
border-radius: 18px 18px 4px 18px; padding: 0.9rem 1.2rem;
margin: 0.5rem 0; font-size: 0.95rem;
}
.agent-bubble {
background: #1a1a2e; border: 1px solid #2a2a45;
border-radius: 18px 18px 18px 4px; padding: 0.9rem 1.2rem;
margin: 0.5rem 0; font-size: 0.95rem; line-height: 1.6;
}
section[data-testid="stSidebar"] {
background: #10101e; border-right: 1px solid #2a2a45;
}
.stButton > button {
background: linear-gradient(135deg, #6C63FF, #43E97B);
color: white; border: none; border-radius: 12px;
font-family: 'Syne', sans-serif; font-weight: 700;
padding: 0.6rem 1.5rem; transition: opacity 0.2s;
}
.stButton > button:hover { opacity: 0.85; color: white; }
.stTextInput > div > div > input {
background: #1a1a2e; border: 1px solid #2a2a45;
border-radius: 12px; color: #e8e8ff;
}
.stSelectbox > div > div {
background: #1a1a2e; border: 1px solid #2a2a45; border-radius: 12px;
}
.stTabs [data-baseweb="tab-list"] {
background: #10101e; border-radius: 12px; gap: 0.3rem;
}
.stTabs [data-baseweb="tab"] {
background: transparent; color: #6a6a9a;
border-radius: 10px; font-family: 'Syne', sans-serif;
}
.stTabs [aria-selected="true"] {
background: rgba(108,99,255,0.2) !important; color: #6C63FF !important;
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# ─── Session state ────────────────────────────────────────────────────────────
# Seed every key the app reads so later code never has to check existence.
_SESSION_DEFAULTS = {
    "df": None,               # loaded DataFrame
    "profile": None,          # profile_dataframe() result for df
    "file_type": None,        # "csv" / "xlsx" / ... of the loaded file
    "chat_history": [],       # list of {"user", "agent", "chart_rec"} turns
    "llm": None,              # validated LLM client
    "active_provider": None,  # key into PROVIDERS
    "active_model": None,     # model name within the provider
    "api_key_set": False,     # True once a provider connection succeeded
}
for _key, _default in _SESSION_DEFAULTS.items():
    st.session_state.setdefault(_key, _default)
# ─── Sidebar ──────────────────────────────────────────────────────────────────
# Provider/model pickers, API-key entry + connect, file upload, and help text.
with st.sidebar:
    st.markdown("### π§ DataMind Agent")
    st.markdown("---")
    # ── Provider selector ──────────────────────────────────────────────────
    st.markdown("**π€ Choose AI Provider**")
    provider_labels = {k: v["name"] for k, v in PROVIDERS.items()}
    selected_provider = st.selectbox(
        "Provider",
        options=list(provider_labels.keys()),
        format_func=lambda k: provider_labels[k],
        label_visibility="collapsed",
        key="provider_select",
    )
    pinfo = PROVIDERS[selected_provider]
    # Color dot for the selected provider
    st.markdown(
        f'<span class="provider-badge" style="background:{pinfo["color"]}22;'
        f'color:{pinfo["color"]};border:1px solid {pinfo["color"]}55;">'
        f'β {pinfo["name"]}</span>',
        unsafe_allow_html=True,
    )
    # Model selector — keyed per provider so each provider remembers its pick
    selected_model = st.selectbox(
        "Model",
        options=pinfo["models"],
        index=0,
        key=f"model_{selected_provider}",
    )
    # Show a note for providers that need a third-party key (e.g. Together AI)
    if pinfo.get("note"):
        st.caption(f"βΉοΈ {pinfo['note']}")
    # API key — also keyed per provider so switching providers keeps each key
    st.markdown(f"**π {pinfo['name']} API Key**")
    api_key = st.text_input(
        "API Key",
        type="password",
        placeholder=pinfo["key_hint"],
        label_visibility="collapsed",
        key=f"apikey_{selected_provider}",
    )
    connect_btn = st.button("π Connect", key="connect_btn")
    if connect_btn and api_key:
        with st.spinner(f"Connecting to {pinfo['name']}..."):
            try:
                llm, msg = validate_llm(selected_provider, api_key, selected_model)
                st.session_state.llm = llm
                st.session_state.api_key_set = True
                st.session_state.active_provider = selected_provider
                st.session_state.active_model = selected_model
                st.session_state.chat_history = []  # fresh provider, fresh chat
                st.success(msg)
            except Exception as e:
                st.session_state.api_key_set = False
                st.error(f"β Connection failed: {e}")
    elif connect_btn and not api_key:
        st.warning("β οΈ Please enter your API key first.")
    # Show active connection status
    if st.session_state.api_key_set and st.session_state.active_provider:
        ap = st.session_state.active_provider
        am = st.session_state.active_model
        ac = PROVIDERS[ap]["color"]
        st.markdown(
            f'<div style="margin-top:8px;padding:8px 12px;border-radius:10px;'
            f'background:{ac}15;border:1px solid {ac}40;font-size:0.78rem;">'
            f'<span style="color:{ac}">β</span> <b>{PROVIDERS[ap]["name"]}</b><br/>'
            f'<span style="color:#6a6a9a">{am}</span></div>',
            unsafe_allow_html=True,
        )
    st.markdown("---")
    # ── File upload ────────────────────────────────────────────────────────
    st.markdown("**π Upload Data File**")
    uploaded = st.file_uploader(
        "Upload",
        type=["csv", "xlsx", "xls", "json"],
        label_visibility="collapsed",
    )
    if uploaded and st.session_state.api_key_set:
        # BUG FIX: st.file_uploader returns the file object on *every* rerun,
        # so the original re-parsed the file and wiped chat_history on every
        # interaction. Only (re)load when a different file arrives.
        file_sig = (uploaded.name, uploaded.size)
        if st.session_state.get("_loaded_file_sig") != file_sig:
            with st.spinner("π Analyzing your data..."):
                try:
                    df, ftype = load_file(uploaded)
                    st.session_state.df = df
                    st.session_state.file_type = ftype
                    st.session_state.profile = profile_dataframe(df)
                    st.session_state.chat_history = []  # new data, fresh chat
                    st.session_state["_loaded_file_sig"] = file_sig
                    st.success(f"β Loaded {ftype} file!")
                except Exception as e:
                    st.error(f"β Error: {e}")
    elif uploaded and not st.session_state.api_key_set:
        st.warning("β οΈ Connect to an AI provider first.")
    st.markdown("---")
    st.markdown("""
**How to use:**
1. Choose an AI provider
2. Select a model
3. Paste your API key β click Connect
4. Upload CSV, Excel, or JSON
5. Explore Β· Chat Β· Visualize
---
**Get API keys:**
- [Gemini](https://aistudio.google.com/app/apikey)
- [OpenAI](https://platform.openai.com/api-keys)
- [Claude](https://console.anthropic.com/)
- [Grok](https://console.x.ai/)
- [Mistral](https://console.mistral.ai/)
- [Llama / Qwen β Together AI](https://api.together.ai/)
""")
# ─── Main content ─────────────────────────────────────────────────────────────
st.markdown('<div class="hero-title">π§ DataMind Agent</div>', unsafe_allow_html=True)
# Dynamic subtitle: name the connected provider/model, else a call to action.
if st.session_state.api_key_set and st.session_state.active_provider:
    _prov = PROVIDERS[st.session_state.active_provider]
    sub = (
        'AI-powered data analysis Β· Powered by '
        f'<span style="color:{_prov["color"]};font-weight:600">'
        f'{_prov["name"]} / {st.session_state.active_model}</span>'
    )
else:
    sub = "AI-powered data analysis Β· Connect a provider and upload data to begin"
st.markdown(f'<div class="hero-sub">{sub}</div>', unsafe_allow_html=True)
# ─── Landing state ────────────────────────────────────────────────────────────
# No data yet: show feature cards plus a hint for the next step.
if st.session_state.df is None:
    feature_cards = [
        ("π€", "7 AI Providers", "Gemini, GPT, Claude, Grok, Mistral, Llama, Qwen β latest models"),
        ("π", "CSV Β· Excel Β· JSON", "Upload any tabular data file β we handle parsing automatically"),
        ("π", "Smart Visualizations", "AI picks the right chart for your question automatically"),
    ]
    for column, (icon, title, desc) in zip(st.columns(3), feature_cards):
        column.markdown(
            f'<div class="stat-card"><div class="stat-num">{icon}</div>'
            f'<div class="stat-label">{title}</div><br>'
            f'<p style="color:#6a6a9a;font-size:0.85rem">{desc}</p></div>',
            unsafe_allow_html=True,
        )
    st.markdown("<br>", unsafe_allow_html=True)
    if st.session_state.api_key_set:
        st.info("π Upload a data file (CSV, Excel, or JSON) in the sidebar to get started!")
    else:
        st.info("π Choose a provider, enter your API key and click **Connect** in the sidebar.")
| else: | |
| df = st.session_state.df | |
| profile = st.session_state.profile | |
| llm = st.session_state.llm | |
| tab1, tab2, tab3, tab4 = st.tabs(["π Dashboard", "π¬ Chat", "π¨ Charts", "π Raw Data"]) | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # TAB 1 β Dashboard | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| with tab1: | |
| rows, cols = profile["shape"] | |
| nulls = sum(profile["null_counts"].values()) | |
| num_c = len(profile["numeric_columns"]) | |
| cat_c = len(profile["categorical_columns"]) | |
| c1, c2, c3, c4 = st.columns(4) | |
| for col_obj, num, label in [ | |
| (c1, f"{rows:,}", "Rows"), | |
| (c2, str(cols), "Columns"), | |
| (c3, str(num_c), "Numeric Cols"), | |
| (c4, str(nulls), "Missing Values"), | |
| ]: | |
| col_obj.markdown( | |
| f'<div class="stat-card"><div class="stat-num">{num}</div>' | |
| f'<div class="stat-label">{label}</div></div>', | |
| unsafe_allow_html=True, | |
| ) | |
| st.markdown("<br>", unsafe_allow_html=True) | |
| st.markdown("#### π Column Overview") | |
| col_info = pd.DataFrame({ | |
| "Column": df.columns, | |
| "Type": df.dtypes.astype(str).values, | |
| "Non-Null": df.notnull().sum().values, | |
| "Null %": (df.isnull().mean() * 100).round(1).values, | |
| "Unique": df.nunique().values, | |
| }) | |
| st.dataframe(col_info, use_container_width=True, hide_index=True) | |
| st.markdown("#### π€ Auto-Generated Insights") | |
| suggested = auto_suggest_charts(profile)[:3] | |
| chart_cols = st.columns(min(len(suggested), 2)) | |
| for i, ctype in enumerate(suggested[:2]): | |
| with chart_cols[i]: | |
| try: | |
| fig = make_plotly_chart(ctype, df, profile) | |
| st.plotly_chart(fig, use_container_width=True) | |
| except Exception as e: | |
| st.warning(f"Could not render {ctype}: {e}") | |
| if len(suggested) > 2: | |
| try: | |
| fig = make_plotly_chart(suggested[2], df, profile) | |
| st.plotly_chart(fig, use_container_width=True) | |
| except Exception: | |
| pass | |
| st.markdown("#### π§ AI Dataset Summary") | |
| ap_name = PROVIDERS.get(st.session_state.active_provider, {}).get("name", "AI") | |
| if st.button(f"β¨ Generate Summary with {ap_name}"): | |
| with st.spinner(f"{ap_name} is analyzing your dataset..."): | |
| summary = ask_agent( | |
| "Give me a concise executive summary of this dataset. " | |
| "Highlight key patterns, anomalies, and 3 actionable insights.", | |
| df, profile, llm, | |
| ) | |
| st.markdown(f'<div class="agent-bubble">{summary}</div>', unsafe_allow_html=True) | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # TAB 2 β Chat | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| with tab2: | |
| ap_name = PROVIDERS.get(st.session_state.active_provider, {}).get("name", "AI") | |
| st.markdown(f"#### π¬ Ask Anything About Your Data") | |
| st.markdown(f"*Powered by **{ap_name} / {st.session_state.active_model}***") | |
| st.markdown("**Quick questions to try:**") | |
| suggestions = [ | |
| "What are the top 5 most important patterns in this data?", | |
| "Are there any outliers or anomalies I should know about?", | |
| "What correlations exist between the numeric columns?", | |
| ] | |
| q_cols = st.columns(3) | |
| for i, s in enumerate(suggestions): | |
| with q_cols[i]: | |
| if st.button(s, key=f"sug_{i}"): | |
| st.session_state["prefill_q"] = s | |
| for turn in st.session_state.chat_history: | |
| st.markdown(f'<div class="user-bubble">π€ {turn["user"]}</div>', unsafe_allow_html=True) | |
| st.markdown(f'<div class="agent-bubble">π§ {turn["agent"]}</div>', unsafe_allow_html=True) | |
| prefill = st.session_state.pop("prefill_q", "") | |
| question = st.text_input( | |
| "Ask a question...", | |
| value=prefill, | |
| placeholder="e.g. What's the average sales by region?", | |
| label_visibility="collapsed", | |
| ) | |
| col_send, col_clear = st.columns([1, 5]) | |
| with col_send: | |
| send = st.button("Send π") | |
| with col_clear: | |
| if st.button("Clear Chat"): | |
| st.session_state.chat_history = [] | |
| st.rerun() | |
| if send and question.strip(): | |
| with st.spinner(f"π§ {ap_name} is thinking..."): | |
| answer = ask_agent(question, df, profile, llm) | |
| chart_rec = ai_recommend_chart(question, profile, llm) | |
| st.session_state.chat_history.append({ | |
| "user": question, "agent": answer, "chart_rec": chart_rec, | |
| }) | |
| st.markdown(f'<div class="user-bubble">π€ {question}</div>', unsafe_allow_html=True) | |
| st.markdown(f'<div class="agent-bubble">π§ {answer}</div>', unsafe_allow_html=True) | |
| if chart_rec: | |
| st.markdown( | |
| f"*π Suggested chart: **{chart_rec['chart_type']}** β " | |
| f"{chart_rec.get('reason','')}*" | |
| ) | |
| try: | |
| fig = make_plotly_chart( | |
| chart_rec["chart_type"], df, profile, | |
| x_col=chart_rec.get("x_col"), | |
| y_col=chart_rec.get("y_col"), | |
| ) | |
| st.plotly_chart(fig, use_container_width=True) | |
| except Exception: | |
| pass | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # TAB 3 β Charts | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| with tab3: | |
| st.markdown("#### π¨ Custom Chart Builder") | |
| chart_options = { | |
| "Correlation Heatmap": "correlation_heatmap", | |
| "Distribution Plot": "distribution_plots", | |
| "Box Plots": "box_plots", | |
| "Bar Chart": "bar_chart", | |
| "Pie Chart": "pie_chart", | |
| "Scatter Plot": "scatter", | |
| "Line Chart": "line", | |
| "Scatter Matrix": "scatter_matrix", | |
| } | |
| if profile["datetime_columns"]: | |
| chart_options["Time Series"] = "time_series" | |
| c1, c2, c3 = st.columns(3) | |
| with c1: | |
| chart_label = st.selectbox("Chart Type", list(chart_options.keys())) | |
| with c2: | |
| all_cols = ["(auto)"] + df.columns.tolist() | |
| x_col = st.selectbox("X Column", all_cols) | |
| with c3: | |
| y_col = st.selectbox("Y Column", all_cols) | |
| x_val = None if x_col == "(auto)" else x_col | |
| y_val = None if y_col == "(auto)" else y_col | |
| if st.button("π¨ Generate Chart"): | |
| with st.spinner("Rendering..."): | |
| try: | |
| fig = make_plotly_chart( | |
| chart_options[chart_label], df, profile, | |
| x_col=x_val, y_col=y_val, | |
| ) | |
| st.plotly_chart(fig, use_container_width=True) | |
| except Exception as e: | |
| st.error(f"Chart error: {e}") | |
| st.markdown("---") | |
| st.markdown("#### π All Auto-Suggested Charts") | |
| suggested_all = auto_suggest_charts(profile) | |
| for i in range(0, len(suggested_all), 2): | |
| cols = st.columns(2) | |
| for j, ctype in enumerate(suggested_all[i:i+2]): | |
| with cols[j]: | |
| try: | |
| fig = make_plotly_chart(ctype, df, profile) | |
| st.plotly_chart(fig, use_container_width=True) | |
| except Exception: | |
| st.warning(f"Could not render {ctype}") | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # TAB 4 β Raw Data | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| with tab4: | |
| st.markdown("#### π Raw Data Explorer") | |
| search = st.text_input("π Filter rows containing...", placeholder="Type to filter...") | |
| if search: | |
| mask = df.astype(str).apply( | |
| lambda row: row.str.contains(search, case=False, na=False) | |
| ).any(axis=1) | |
| display_df = df[mask] | |
| st.info(f"Showing {len(display_df):,} of {len(df):,} rows matching '{search}'") | |
| else: | |
| display_df = df | |
| st.dataframe(display_df, use_container_width=True, height=500) | |
| csv_buf = io.StringIO() | |
| df.to_csv(csv_buf, index=False) | |
| st.download_button( | |
| "β¬οΈ Download as CSV", | |
| data=csv_buf.getvalue(), | |
| file_name="analyzed_data.csv", | |
| mime="text/csv", | |
| ) | |