Spaces:
Running
Running
| """ | |
| core_agent.py | |
| ============= | |
| DataMind Agent — Multi-LLM Core Logic | |
| Supports: Google Gemini, OpenAI GPT, Anthropic Claude, xAI Grok, | |
| Mistral AI, Meta Llama (via Together AI), Alibaba Qwen (via Together AI) | |
| File formats: CSV, Excel (.xlsx, .xls), JSON | |
| """ | |
| import os | |
| import io | |
| import json | |
| import warnings | |
| import pandas as pd | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| import matplotlib.ticker as mticker | |
| import seaborn as sns | |
| import plotly.express as px | |
| import plotly.graph_objects as go | |
| from plotly.subplots import make_subplots | |
| from dotenv import load_dotenv | |
| warnings.filterwarnings("ignore") | |
| load_dotenv() | |
| # ─── Palette ────────────────────────────────────────────────────────────── | |
# Ordered series colours; make_plotly_chart passes this list as the discrete
# colour sequence and indexes into it modulo its length for box traces.
PALETTE = ["#6C63FF", "#FF6584", "#43E97B", "#F7971E", "#4FC3F7", "#CE93D8"]
# Used as paper_bgcolor (area around the plot) in make_plotly_chart.
DARK_BG = "#0F0F1A"
# Used as plot_bgcolor (the plotting area itself) in make_plotly_chart.
CARD_BG = "#1A1A2E"
| # ─── Provider registry ──────────────────────────────────────────────────── | |
# Registry of supported LLM providers, keyed by the short provider id that
# get_llm() / validate_llm() accept.
#
# Every entry carries:
#   "name"     - human-readable provider name (used in validate_llm's message)
#   "models"   - list of selectable model identifiers
#   "default"  - model used by get_llm() when the caller passes no model
#   "key_hint" - example/prefix of what an API key looks like
#   "color"    - hex colour associated with the provider
#   "key_url"  - page where a key can be obtained
#   "note"     - (optional) extra remark about the provider
#
# NOTE(review): "key_hint", "color", "key_url" and "note" are not read by any
# function visible in this module; presumably they are consumed by the UI
# layer - confirm against the caller.
# NOTE(review): the model identifiers are passed verbatim to the provider
# SDKs; their validity cannot be checked from this file.
PROVIDERS = {
    "gemini": {
        "name": "Google Gemini",
        "models": [
            # Gemini 3 Series
            "gemini-3.1-pro-preview",
            "gemini-3-flash-preview",
            "gemini-3.1-flash-lite-preview",
            # Gemini 2.5 Series
            "gemini-2.5-pro",
            "gemini-2.5-flash",
            "gemini-2.5-flash-lite",
            # Gemini 2.0 Series (Legacy - closing June 2026)
            "gemini-2.0-flash",
            "gemini-2.0-flash-lite",
            # Gemini 1.5 Series (Legacy)
            "gemini-1.5-pro-002",
            "gemini-1.5-flash-002",
        ],
        "default": "gemini-2.5-pro",
        "key_hint": "AIza...",
        "color": "#4285f4",
        "key_url": "https://aistudio.google.com/app/apikey",
    },
    "openai": {
        "name": "OpenAI GPT",
        "models": [
            # GPT-5 Series
            "gpt-5.4-pro",
            "gpt-5.4",
            "gpt-5.3-instant",
            "gpt-5-mini",
            "gpt-5-nano",
            # Reasoning (o-series)
            "o3-deep-research",
            "o3",
            "o4-mini",
            # GPT-4 Series (Legacy)
            "gpt-4.1",
            "gpt-4o",
            "gpt-4o-mini",
            "gpt-4-turbo",
            "gpt-4-0613",
            # GPT-3.5 Series (Legacy)
            "gpt-3.5-turbo-0125",
        ],
        "default": "gpt-5.4-pro",
        "key_hint": "sk-...",
        "color": "#10a37f",
        "key_url": "https://platform.openai.com/api-keys",
    },
    "claude": {
        "name": "Anthropic Claude",
        "models": [
            # Claude 4 Series
            "claude-opus-4-6",
            "claude-sonnet-4-6",
            "claude-haiku-4-5-20251001",
            # Claude 3.5 Series (Legacy)
            "claude-3-5-sonnet-20241022",
            "claude-3-5-haiku-20241022",
            # Claude 3 Series (Legacy)
            "claude-3-opus-20240229",
            "claude-3-sonnet-20240229",
            "claude-3-haiku-20240307",
            # Claude 2 Series (Legacy)
            "claude-2.1",
            "claude-2.0",
        ],
        "default": "claude-sonnet-4-6",
        "key_hint": "sk-ant-...",
        "color": "#d97706",
        "key_url": "https://console.anthropic.com/",
    },
    "grok": {
        "name": "xAI Grok",
        "models": [
            # Grok 4 Series
            "grok-4.2",
            "grok-4.1-fast",
            "grok-code-fast-1",
            # Grok 2 Series (Legacy)
            "grok-2-1212",
            "grok-2-mini",
        ],
        "default": "grok-4.2",
        "key_hint": "xai-...",
        "color": "#9b9b9b",
        "key_url": "https://console.x.ai/",
    },
    "mistral": {
        "name": "Mistral AI",
        "models": [
            # Mistral Frontier
            "mistral-large-2411",
            "mistral-medium-2508",
            "magistral-medium-1.2",
            # Mistral Legacy
            "mistral-large-2407",
            "mistral-small-2409",
            "open-mixtral-8x22b",
            "open-mixtral-8x7b",
        ],
        "default": "mistral-large-2411",
        "key_hint": "...",
        "color": "#ff6b35",
        "key_url": "https://console.mistral.ai/",
    },
    # Served through Together AI's OpenAI-compatible endpoint (see get_llm).
    "llama": {
        "name": "Meta Llama",
        "models": [
            # Llama 4 Series
            "meta-llama/llama-4-behemoth",
            "meta-llama/llama-4-maverick",
            "meta-llama/llama-4-scout",
            # Llama 3 Series (Legacy)
            "meta-llama/llama-3.3-70b-instruct",
            "meta-llama/llama-3.2-90b-vision",
            "meta-llama/llama-3.2-3b",
            "meta-llama/llama-3.1-405b",
            "meta-llama/llama-3.1-70b",
            "meta-llama/llama-3.1-8b",
            # Llama 2 Series (Legacy)
            "meta-llama/llama-2-70b-chat",
            "meta-llama/llama-2-13b-chat",
            "meta-llama/llama-2-7b-chat",
        ],
        "default": "meta-llama/llama-4-maverick",
        "key_hint": "Together AI key...",
        "color": "#0668E1",
        "key_url": "https://api.together.ai/",
        "note": "Requires a Together AI API key (api.together.ai)",
    },
    # Served through Together AI's OpenAI-compatible endpoint (see get_llm).
    "qwen": {
        "name": "Alibaba Qwen",
        "models": [
            # Qwen 3 Series
            "Qwen/qwen-3.5-plus",
            "Qwen/qwen-3.5-flash",
            "Qwen/qwen-3-max-thinking",
            "Qwen/qwen3-coder-480b",
            # Qwen 2 Series (Legacy)
            "Qwen/qwen2.5-72b-instruct",
            "Qwen/qwen2.5-coder-32b",
            "Qwen/qwen2-72b-instruct",
            # Qwen 1.5 Series (Legacy)
            "Qwen/qwen1.5-110b",
            "Qwen/qwen1.5-72b-chat",
        ],
        "default": "Qwen/qwen-3.5-plus",
        "key_hint": "Together AI key...",
        "color": "#6547d4",
        "key_url": "https://api.together.ai/",
        "note": "Requires a Together AI API key (api.together.ai)",
    },
}
| # ─── LLM factory ────────────────────────────────────────────────────────── | |
# Providers that are reached through the OpenAI-compatible client, mapped to
# the base URL that must be passed to it. "openai" itself uses the client's
# default endpoint and is therefore not listed.
# Qwen also works with DashScope:
#   base_url="https://dashscope.aliyuncs.com/compatible-mode/v1"
_OPENAI_COMPATIBLE_BASE_URLS = {
    "grok": "https://api.x.ai/v1",
    "llama": "https://api.together.xyz/v1",
    "qwen": "https://api.together.xyz/v1",
}


def get_llm(provider: str, api_key: str, model: str = None):
    """
    Returns a LangChain chat model for any supported provider.
    All returned objects expose an .invoke(messages) -> response.content interface.

    Args:
        provider: Key into ``PROVIDERS`` (e.g. "gemini", "openai", "claude").
        api_key: API key handed to the provider's client.
        model: Model identifier; when falsy, the provider's ``"default"``
            entry from ``PROVIDERS`` is used.

    Returns:
        The provider-specific chat model, constructed with temperature 0.3.

    Raises:
        ValueError: If ``provider`` is not registered or has no client wired up.
    """
    # Check membership first: indexing PROVIDERS directly would raise KeyError
    # and the documented ValueError below would never be reached.
    if provider not in PROVIDERS:
        raise ValueError(f"Unknown provider: {provider}")
    model = model or PROVIDERS[provider]["default"]

    # Client libraries are imported lazily so that only the package for the
    # selected provider has to be installed.
    if provider == "gemini":
        from langchain_google_genai import ChatGoogleGenerativeAI
        return ChatGoogleGenerativeAI(
            model=model,
            google_api_key=api_key,
            temperature=0.3,
            convert_system_message_to_human=True,
        )

    if provider == "claude":
        from langchain_anthropic import ChatAnthropic
        return ChatAnthropic(
            model=model,
            api_key=api_key,
            temperature=0.3,
        )

    if provider == "mistral":
        from langchain_mistralai import ChatMistralAI
        return ChatMistralAI(
            model=model,
            api_key=api_key,
            temperature=0.3,
        )

    if provider == "openai" or provider in _OPENAI_COMPATIBLE_BASE_URLS:
        from langchain_openai import ChatOpenAI
        client_kwargs = {"model": model, "api_key": api_key, "temperature": 0.3}
        # Only override the endpoint for the OpenAI-compatible third parties.
        if provider in _OPENAI_COMPATIBLE_BASE_URLS:
            client_kwargs["base_url"] = _OPENAI_COMPATIBLE_BASE_URLS[provider]
        return ChatOpenAI(**client_kwargs)

    # Registered in PROVIDERS but no client branch exists for it above.
    raise ValueError(f"Unknown provider: {provider}")
def validate_llm(provider: str, api_key: str, model: str = None) -> tuple:
    """
    Build the chat model for ``provider`` and send it one tiny prompt.

    Nothing is caught here: any error raised while constructing the client or
    during the test call propagates to the caller.

    Returns:
        ``(llm, success_message)`` where the message names the provider.
    """
    from langchain_core.messages import HumanMessage

    client = get_llm(provider, api_key, model)
    ping = [HumanMessage(content="Say OK")]
    client.invoke(ping)  # reply is discarded; only a successful round-trip matters

    provider_name = PROVIDERS[provider]["name"]
    # NOTE(review): the leading "β" looks like a mis-decoded status symbol
    # (possibly a check mark) - confirm against the upstream file.
    return client, f"β Connected to {provider_name}!"
| # ─── File loading ───────────────────────────────────────────────────────── | |
def load_file(file) -> tuple:
    """
    Load an uploaded file into a DataFrame. Returns (df, file_type).

    The format is chosen from the (case-insensitive) extension of ``file.name``.

    Args:
        file: File-like object with a ``name`` attribute, readable by
            ``pd.read_csv`` / ``pd.read_excel`` / ``json.load``.

    Returns:
        ``(df, file_type)`` with ``file_type`` one of "CSV", "Excel", "JSON".

    Raises:
        ValueError: For an unsupported extension, or a JSON document whose
            top level is neither an object nor an array.
    """
    name = file.name.lower()

    if name.endswith(".csv"):
        return pd.read_csv(file), "CSV"

    if name.endswith((".xlsx", ".xls")):
        return pd.read_excel(file), "Excel"

    if name.endswith(".json"):
        content = json.load(file)
        if isinstance(content, list):
            # Array at the top level: one row per element.
            return pd.DataFrame(content), "JSON"
        if isinstance(content, dict):
            # An object holding at least one list is treated as column-oriented
            # data; a flat object becomes a single-row frame.
            has_list_value = any(isinstance(v, list) for v in content.values())
            df = pd.DataFrame(content) if has_list_value else pd.DataFrame([content])
            return df, "JSON"
        # Scalars (number, string, bool, null) cannot be tabulated; previously
        # this path fell through to an unbound local variable.
        raise ValueError(
            f"Unsupported JSON structure in {name}: expected an object or an "
            f"array at the top level, got {type(content).__name__}"
        )

    raise ValueError(f"Unsupported file type: {name}")
| # ─── Data profiling ─────────────────────────────────────────────────────── | |
| def profile_dataframe(df: pd.DataFrame) -> dict: | |
| """Generate a rich statistical profile of the dataframe.""" | |
| numeric_cols = df.select_dtypes(include="number").columns.tolist() | |
| category_cols = df.select_dtypes(include=["object", "category"]).columns.tolist() | |
| datetime_cols = df.select_dtypes(include=["datetime"]).columns.tolist() | |
| profile = { | |
| "shape": df.shape, | |
| "columns": df.columns.tolist(), | |
| "dtypes": df.dtypes.astype(str).to_dict(), | |
| "numeric_columns": numeric_cols, | |
| "categorical_columns": category_cols, | |
| "datetime_columns": datetime_cols, | |
| "null_counts": df.isnull().sum().to_dict(), | |
| "null_pct": (df.isnull().mean() * 100).round(2).to_dict(), | |
| "duplicates": int(df.duplicated().sum()), | |
| } | |
| if numeric_cols: | |
| desc = df[numeric_cols].describe().round(3) | |
| profile["numeric_stats"] = desc.to_dict() | |
| if category_cols: | |
| profile["top_categories"] = { | |
| col: df[col].value_counts().head(5).to_dict() | |
| for col in category_cols | |
| } | |
| return profile | |
def profile_to_text(profile: dict, df: pd.DataFrame) -> str:
    """
    Render a profile dict plus a five-row sample of ``df`` as plain text.

    The result is a newline-joined summary: dataset size, column names per
    kind, total missing values, duplicate count, the first five rows, and -
    when ``profile`` carries a non-empty "numeric_stats" - one line of
    mean/std/min/max per numeric column.
    """
    n_rows, n_cols = profile["shape"]

    def names_or_none(key: str) -> str:
        # An empty column list joins to "", which falls back to the word "None".
        return ", ".join(profile[key]) or "None"

    total_missing = sum(profile["null_counts"].values())
    sample = df.head(5).to_string(index=False)

    out = [
        f"Dataset: {n_rows} rows x {n_cols} columns",
        f"Numeric columns : {names_or_none('numeric_columns')}",
        f"Categorical cols : {names_or_none('categorical_columns')}",
        f"Datetime cols : {names_or_none('datetime_columns')}",
        f"Missing values : {total_missing} total",
        f"Duplicate rows : {profile['duplicates']}",
        "",
        "--- Sample Data (first 5 rows) ---",
        sample,
    ]

    numeric_stats = profile.get("numeric_stats")
    if numeric_stats:
        out.extend(["", "--- Numeric Stats ---"])
        for column, stats in numeric_stats.items():
            # Any statistic missing from the dict is shown as "?".
            mean, std, low, high = (
                stats.get(k, "?") for k in ("mean", "std", "min", "max")
            )
            out.append(
                f" {column}: mean={mean}, std={std}, min={low}, max={high}"
            )

    return "\n".join(out)
| # ─── AI Q&A ─────────────────────────────────────────────────────────────── | |
def ask_agent(question: str, df: pd.DataFrame, profile: dict, llm) -> str:
    """
    Ask ``llm`` a question about the dataset and return the reply text.

    The model receives a fixed system prompt plus one human message containing
    the text summary from ``profile_to_text(profile, df)`` followed by the
    user's question. The return value is the ``content`` of the response.
    """
    from langchain_core.messages import HumanMessage, SystemMessage

    system_prompt = "".join([
        "You are DataMind, an expert data analyst AI. You receive a dataset summary ",
        "and answer questions about it. Be precise, insightful, and helpful. ",
        "When relevant, suggest what visualizations would best illustrate the answer. ",
        "Format your response clearly. Use bullet points for lists. ",
        "Use numbers and percentages when quoting statistics.",
    ])

    # The three sections of the user message are separated by blank lines.
    user_prompt = "\n\n".join([
        "Here is the dataset context:",
        profile_to_text(profile, df),
        f"User question: {question}",
        "Provide a thorough, accurate analysis. "
        "If you perform calculations, show the logic briefly.",
    ])

    reply = llm.invoke([
        SystemMessage(content=system_prompt),
        HumanMessage(content=user_prompt),
    ])
    return reply.content
| # ─── Visualization engine ───────────────────────────────────────────────── | |
def auto_suggest_charts(profile: dict) -> list:
    """
    List the chart types that the profiled columns can support.

    The order of the returned names is fixed: pairwise-numeric charts first,
    then single-numeric, then categorical-vs-numeric, then time series.
    """
    numeric = profile["numeric_columns"]
    categorical = profile["categorical_columns"]
    temporal = profile["datetime_columns"]

    # (condition, charts unlocked by it), evaluated in output order.
    rules = (
        (len(numeric) >= 2, ("correlation_heatmap", "scatter_matrix")),
        (bool(numeric), ("distribution_plots", "box_plots")),
        (bool(categorical and numeric), ("bar_chart", "pie_chart")),
        (bool(temporal and numeric), ("time_series",)),
    )

    picked = []
    for applies, charts in rules:
        if applies:
            picked.extend(charts)
    return picked
def _known_column(df: pd.DataFrame, col):
    """Return ``col`` if it names a column of ``df``, otherwise ``None``."""
    if col is None:
        return None
    try:
        return col if col in df.columns else None
    except TypeError:
        # Unhashable value (e.g. a list from malformed LLM output).
        return None


def make_plotly_chart(
    chart_type: str,
    df: pd.DataFrame,
    profile: dict,
    x_col: str = None,
    y_col: str = None,
    color_col: str = None,
):
    """
    Generate a Plotly figure for the given chart type.

    Args:
        chart_type: One of "correlation_heatmap", "distribution_plots",
            "box_plots", "bar_chart", "pie_chart", "scatter_matrix",
            "time_series", "scatter", "line". Any other value, or a type whose
            column requirements are not met, yields a "Column Means Overview"
            bar chart, or an annotated empty figure when there are no numeric
            columns.
        df: Data to plot.
        profile: Profile dict providing "numeric_columns",
            "categorical_columns" and "datetime_columns".
        x_col, y_col, color_col: Optional column overrides. A name that is not
            a column of ``df`` is ignored and the per-chart default is used,
            so a bad suggestion (e.g. from an LLM) cannot crash the chart.

    Returns:
        A Plotly figure with the dark theme colours applied.
    """
    num_cols = profile["numeric_columns"]
    cat_cols = profile["categorical_columns"]
    dt_cols = profile.get("datetime_columns") or []
    template = "plotly_dark"

    # Drop overrides that do not exist in the frame; every branch below then
    # falls back to its own default through ``override or default``.
    x_col = _known_column(df, x_col)
    y_col = _known_column(df, y_col)
    color_col = _known_column(df, color_col)

    if chart_type == "correlation_heatmap" and len(num_cols) >= 2:
        corr = df[num_cols].corr().round(2)
        fig = px.imshow(
            corr, text_auto=True, color_continuous_scale="RdBu_r",
            title="Correlation Heatmap", template=template,
            color_continuous_midpoint=0,
        )

    elif chart_type == "distribution_plots" and num_cols:
        col = y_col or num_cols[0]
        fig = px.histogram(
            df, x=col, nbins=30, marginal="box",
            title=f"Distribution of {col}",
            color_discrete_sequence=PALETTE, template=template,
        )

    elif chart_type == "box_plots" and num_cols:
        # Capped at six columns; colours cycle through the palette.
        fig = go.Figure()
        for i, col in enumerate(num_cols[:6]):
            fig.add_trace(
                go.Box(y=df[col], name=col, marker_color=PALETTE[i % len(PALETTE)])
            )
        # NOTE(review): the "β" below looks like a mis-decoded dash - confirm
        # against the upstream file before changing this user-facing title.
        fig.update_layout(title="Box Plots β Numeric Columns", template=template)

    elif chart_type == "bar_chart" and cat_cols and num_cols:
        xc = x_col or cat_cols[0]
        yc = y_col or num_cols[0]
        # Mean of yc per category, 15 largest means only.
        agg = (
            df.groupby(xc)[yc].mean()
            .reset_index()
            .sort_values(yc, ascending=False)
            .head(15)
        )
        fig = px.bar(
            agg, x=xc, y=yc, color=yc,
            color_continuous_scale="Viridis",
            title=f"Average {yc} by {xc}", template=template,
        )

    elif chart_type == "pie_chart" and cat_cols:
        col = x_col or cat_cols[0]
        counts = df[col].value_counts().head(8)  # eight most frequent values
        fig = px.pie(
            values=counts.values, names=counts.index,
            title=f"Distribution of {col}",
            color_discrete_sequence=PALETTE, template=template,
        )

    elif chart_type == "scatter_matrix" and len(num_cols) >= 2:
        fig = px.scatter_matrix(
            df, dimensions=num_cols[:4],
            color=cat_cols[0] if cat_cols else None,
            color_discrete_sequence=PALETTE,
            title="Scatter Matrix", template=template,
        )
        fig.update_traces(diagonal_visible=False, showupperhalf=False)

    elif chart_type == "time_series" and dt_cols and num_cols:
        dt_col = dt_cols[0]
        yc = y_col or num_cols[0]
        fig = px.line(
            df.sort_values(dt_col), x=dt_col, y=yc,
            title=f"{yc} over Time",
            color_discrete_sequence=PALETTE, template=template,
        )

    elif chart_type == "scatter" and len(num_cols) >= 2:
        xc = x_col or num_cols[0]
        yc = y_col or num_cols[1]
        scatter_kwargs = dict(
            x=xc, y=yc,
            color=color_col or (cat_cols[0] if cat_cols else None),
            color_discrete_sequence=PALETTE,
            title=f"{xc} vs {yc}", template=template,
        )
        try:
            fig = px.scatter(df, trendline="ols", **scatter_kwargs)
        except ImportError:
            # The OLS trendline needs the optional statsmodels package; draw
            # the plain scatter rather than failing when it is not installed.
            fig = px.scatter(df, **scatter_kwargs)

    elif chart_type == "line" and num_cols:
        xc = x_col or (dt_cols[0] if dt_cols else num_cols[0])
        yc = y_col or num_cols[0]
        fig = px.line(
            df, x=xc, y=yc,
            color_discrete_sequence=PALETTE,
            title=f"{yc} trend", template=template,
        )

    elif num_cols:
        # Unknown chart type or unmet requirements: overview of column means.
        means = df[num_cols[:8]].mean()
        fig = px.bar(
            x=means.index, y=means.values,
            labels={"x": "Column", "y": "Mean Value"},
            color=means.values, color_continuous_scale="Viridis",
            title="Column Means Overview", template=template,
        )

    else:
        fig = go.Figure()
        fig.add_annotation(
            text="No numeric data available for this chart type.",
            showarrow=False, font=dict(size=14),
        )
        fig.update_layout(template=template, title="Chart Unavailable")

    # Shared theme applied to whichever figure was built above.
    fig.update_layout(
        paper_bgcolor=DARK_BG,
        plot_bgcolor=CARD_BG,
        font=dict(family="DM Sans, sans-serif", color="#E0E0FF"),
        margin=dict(l=40, r=40, t=60, b=40),
    )
    return fig
| # ─── AI chart recommendation ────────────────────────────────────────────── | |
# Chart types offered to the LLM in the prompt below; anything else in its
# reply is treated as an invalid recommendation.
_RECOMMENDABLE_CHARTS = (
    "correlation_heatmap", "distribution_plots", "box_plots", "bar_chart",
    "pie_chart", "scatter", "line", "time_series", "scatter_matrix",
)


def _extract_json_payload(text: str):
    """Parse the JSON value in an LLM reply that may be fenced or wrapped in prose."""
    text = text.strip()
    if "```" in text:
        # Keep what sits between the first pair of fences, minus a "json" tag.
        text = text.split("```")[1]
        if text.startswith("json"):
            text = text[4:]
    text = text.strip()
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        # Fall back to the outermost {...} span when the model added prose.
        start, end = text.find("{"), text.rfind("}")
        if start == -1 or end <= start:
            raise
        return json.loads(text[start:end + 1])


def ai_recommend_chart(question: str, profile: dict, llm) -> dict:
    """
    Ask the active LLM which chart best answers the user's question.

    Args:
        question: The user's question, embedded verbatim in the prompt.
        profile: Profile dict providing "numeric_columns",
            "categorical_columns" and "datetime_columns".
        llm: Object exposing ``invoke(messages)`` whose result has ``content``.

    Returns:
        A dict with "chart_type", "x_col", "y_col" and "reason". The default
        recommendation ("distribution_plots", no columns) is returned when the
        call fails, the reply is not a JSON object, or its chart type is not
        one of the offered types. ``x_col`` / ``y_col`` are reset to ``None``
        when they do not name a profiled column.
    """
    from langchain_core.messages import HumanMessage

    num_cols = profile["numeric_columns"]
    cat_cols = profile["categorical_columns"]
    dt_cols = profile["datetime_columns"]

    prompt = (
        f"Given this dataset profile:\n"
        f"- Numeric columns: {num_cols}\n"
        f"- Categorical columns: {cat_cols}\n"
        f"- Datetime columns: {dt_cols}\n\n"
        f'The user asked: "{question}"\n\n'
        "Recommend ONE chart type from this list that best answers their question:\n"
        "[correlation_heatmap, distribution_plots, box_plots, bar_chart, pie_chart, "
        "scatter, line, time_series, scatter_matrix]\n\n"
        "Also suggest the best x_col and y_col from the available columns.\n\n"
        "Respond ONLY in valid JSON like:\n"
        '{"chart_type": "bar_chart", "x_col": "category_col", '
        '"y_col": "numeric_col", "reason": "short explanation"}'
    )

    fallback = {
        "chart_type": "distribution_plots",
        "x_col": None,
        "y_col": None,
        "reason": "Default chart",
    }

    try:
        response = llm.invoke([HumanMessage(content=prompt)])
        recommendation = _extract_json_payload(response.content)
    except Exception:
        # Deliberate best-effort: provider SDKs raise many different error
        # types and a failed recommendation must not break the chart flow.
        return fallback

    if not isinstance(recommendation, dict):
        return fallback
    if recommendation.get("chart_type") not in _RECOMMENDABLE_CHARTS:
        return fallback

    # List membership compares with ==, so an unhashable value cannot raise.
    known_columns = [*num_cols, *cat_cols, *dt_cols]
    for key in ("x_col", "y_col"):
        if recommendation.get(key) not in known_columns:
            recommendation[key] = None
    recommendation.setdefault("reason", "")
    return recommendation