"""Ai_analyst / app.py

Streamlit "Senior Data Analyst Agent" application.

Provenance (from repository page): author AJAY KASU, commit aae32bb —
"Generalize Universal Prompt: Remove Wine mapping, enable generic inference".
"""
import streamlit as st
import pandas as pd
import numpy as np
import plotly.express as px
import chromadb
from chromadb.config import Settings
from huggingface_hub import InferenceClient
from sklearn.ensemble import IsolationForest
import io
import contextlib
import re
import time
import os
import traceback
import matplotlib
matplotlib.use('Agg') # Prevent GUI errors
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List, Dict, Any, Optional
# --- Configuration & Constants ---
# Streamlit page setup must run before any other st.* call.
st.set_page_config(page_title="Senior Data Analyst Agent", page_icon="๐Ÿง", layout="wide")
# On-disk directory for the ChromaDB persistent client used by AnalystMemory.
CHROMA_PATH = "chroma_db"
# --- Infrastructure Components ---
class AnalystMemory:
    """Persistent dataset-context memory backed by ChromaDB.

    Stores a serialized context blob about the current dataset so that
    domain-specific knowledge can be recalled across sessions.
    """

    def __init__(self, persistence_path: str):
        # One on-disk client, one collection for all saved context documents.
        self.client = chromadb.PersistentClient(path=persistence_path)
        self.collection = self.client.get_or_create_collection(name="data_context")

    def save_context(self, context_str: str, metadata: Dict[str, Any]):
        """Upsert the session context (best-effort; never raises)."""
        try:
            # A static id is used for simplicity in this demo, effectively
            # overwriting any prior context. A real app would derive the id
            # from a hash of the file signature.
            self.collection.upsert(
                documents=[context_str],
                metadatas=[metadata],
                ids=["current_session_context"],
            )
        except Exception as exc:
            # Non-blocking error: persistence is optional.
            print(f"Memory Save Error: {exc}")

    def retrieve_context(self) -> str:
        """Return the stored session context, or '' when nothing is saved."""
        try:
            found = self.collection.get(ids=["current_session_context"])
            docs = found["documents"]
            if docs:
                return docs[0]
        except Exception:
            # Any storage failure degrades gracefully to an empty context.
            pass
        return ""
class AnalystAgent:
    """LLM-powered senior data analyst.

    Scans a DataFrame for domain context, asks a hosted chat model for a
    [PLAN]/[CODE]/[REPORT] response, then executes the generated code in a
    controlled namespace with a self-healing retry loop.
    """

    def __init__(self, hf_token: str, model_id: str = "Qwen/Qwen2.5-72B-Instruct"):
        self.client = InferenceClient(token=hf_token)
        self.model_id = model_id
        # Persistent cross-session memory of the dataset context (ChromaDB).
        # Kept here even though the planning/execution path does not read it yet.
        self.memory = AnalystMemory(CHROMA_PATH)
        self.context: Dict[str, Any] = {}

    def scan_context(self, df: pd.DataFrame) -> Dict[str, Any]:
        """Input Phase: proactively scan the data.

        Builds and persists a context dict (inferred domain, shape, candidate
        target columns, null/duplicate counts, sample head) used by the UI.
        """
        domain = self.detect_domain(df)
        numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
        # Heuristic for important cols - casting to str for safety with
        # integer column labels (headerless CSVs).
        msg_cols = [c for c in numeric_cols if any(x in str(c).lower() for x in ['sales', 'revenue', 'profit', 'churn', 'cost', 'amount'])]
        target_vars = msg_cols if msg_cols else numeric_cols[:3]
        null_counts = df.isnull().sum()
        quality_issues = null_counts[null_counts > 0].to_dict()
        duplicates = df.duplicated().sum()
        self.context = {
            "domain": domain,
            "columns": list(df.columns),
            "shape": df.shape,
            "targets": target_vars,
            "quality_issues": str(quality_issues),
            "duplicates": int(duplicates),
            "head": df.head(3).to_string()
        }
        # Save to persistent memory (best-effort; see AnalystMemory).
        self.memory.save_context(str(self.context), {"domain": domain})
        return self.context

    def detect_domain(self, df: pd.DataFrame) -> str:
        """Infer the business domain from keyword matches in column names."""
        # str() handles integer column labels from headerless CSVs.
        col_str = " ".join([str(c) for c in df.columns]).lower()
        if any(x in col_str for x in ['sku', 'order', 'shipping', 'customer', 'qty']):
            return "E-commerce/Retail"
        elif any(x in col_str for x in ['ticker', 'volume', 'open', 'close', 'high']):
            return "Financial Markets"
        elif any(x in col_str for x in ['campaign', 'click', 'impression', 'cpc', 'ctr']):
            return "Digital Marketing"
        elif any(x in col_str for x in ['patient', 'diagnosis', 'treatment', 'dose']):
            return "Healthcare"
        elif any(x in col_str for x in ['salary', 'employee', 'department', 'hire']):
            return "HR/People Analytics"
        return "General Business Data"

    def generate_plan(self, query: str, df: pd.DataFrame, system_prompt: Optional[str] = None) -> str:
        """Generate a plan and code with full visibility of the data structure.

        Uses the "Elite" persona by default, or a caller-supplied prompt
        (e.g. the Universal dashboard prompt, whose {df_sample}/{df_stats}
        placeholders are filled in here). Returns the raw LLM text, or ''
        on API failure (the error is surfaced in the UI).
        """
        # Dynamic Context: feed a data sample plus summary statistics.
        df_head = df.head(3).to_string()
        df_stats = df.describe().to_string()
        if system_prompt is None:
            # Default "Elite" Persona.
            # NOTE: this is an f-string, so literal braces must be doubled —
            # `{{...}}` below; a single `{...}` would be evaluated as the
            # Ellipsis expression and render as "Ellipsis".
            system_prompt = f"""You are a Principal Data Scientist and Business Consultant.
### 1. DATA CONTEXT (STRICT)
You are analyzing a dataset with the following characteristics:
- Sample Data:
{df_head}
- Statistics:
{df_stats}
### 2. EXECUTION RULES
- **NO DUMMY DATA:** Do NOT create new dataframes (e.g., `pd.DataFrame({{...}})`). WORK ONLY WITH `df`.
- **NO `to_markdown()`:** Do not use `df.to_markdown()`.
- **RENAME FIRST:** If `df` has integer columns (0, 1, 2...), your FIRST step must be to rename them using the Domain Knowledge below.
- **CONTRACT:** You MUST define `result_text` (str) and `fig` (Plotly figure).
### 3. DOMAIN KNOWLEDGE: WINE DATASET
If you see 14 columns with integer names (0..13), you MUST apply this mapping:
df.columns = [
'Cultivar', 'Alcohol', 'Malic Acid', 'Ash', 'Alcalinity of Ash',
'Magnesium', 'Total Phenols', 'Flavanoids', 'Nonflavanoid Phenols',
'Proanthocyanins', 'Color Intensity', 'Hue', 'OD280', 'Proline'
]
### 4. REASONING & PLANNING PHASE
Before writing code, you must output a [PLAN] section. Break the task into:
- Data Cleaning: Rename columns to professional names.
- Correlation Analysis: Use Pearson's r for chemical relationships.
- Comparative Analysis: Group by Cultivar to find quality profiles.
### 5. OUTPUT STRUCTURE
[PLAN] -> Your multi-step strategy.
[CODE] -> The Python block.
[REPORT] -> A high-level summary (The user will see the 'result_text' from your code as the primary output).
"""
        # Caller-supplied prompts (Universal pattern) carry literal
        # {df_sample}/{df_stats} placeholders: inject the live context.
        if "{df_sample}" in system_prompt:
            system_prompt = system_prompt.format(df_sample=df_head, df_stats=df_stats)
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": f"Objective: {query}"}
        ]
        try:
            response = self.client.chat_completion(
                model=self.model_id,
                messages=messages,
                max_tokens=2000,
                temperature=0.1
            )
            return response.choices[0].message.content
        except Exception as e:
            st.error(f"LLM Error: {e}")
            return ""

    def execute_step(self, query: str, df: pd.DataFrame, system_prompt: Optional[str] = None) -> Dict[str, Any]:
        """The Self-Healing Loop: execute generated code, retrying with error
        feedback (including the failing code itself) when it fails.

        Returns a dict with ``success`` plus either the artifacts
        (``result``, ``fig``, ``code``, ``log``, ``plan``) or an ``error``.
        """
        plan_output = self.generate_plan(query, df, system_prompt)
        code = self.extract_code(plan_output)
        last_error = ""
        max_attempts = 3
        for attempt in range(max_attempts):
            try:
                # Single shared namespace for exec(): with separate
                # globals/locals dicts, functions or comprehensions defined
                # inside the generated code could not resolve `df`/`pd`/etc.
                exec_ns = {
                    "df": df, "pd": pd, "np": np, "px": px,
                    "plt": plt, "sns": sns
                }
                # Capture print statements for logs.
                with contextlib.redirect_stdout(io.StringIO()) as f:
                    # Matplotlib hook: close previous figures to avoid clutter.
                    plt.close('all')
                    exec(code, exec_ns)
                return {
                    "success": True,
                    "result": exec_ns.get("result_text", "No summary provided."),
                    "fig": exec_ns.get("fig", None),
                    "code": code,
                    "log": f.getvalue(),
                    "plan": plan_output
                }
            except Exception:
                last_error = traceback.format_exc()
                if attempt == max_attempts - 1:
                    # No attempts left — skip the (useless) final fix request.
                    break
                # SELF-HEALING: feed the failing code AND the error back so
                # the model can actually see what it is fixing.
                retry_msg = (
                    f"Your code failed with error:\n{last_error}\n\n"
                    f"Failing code:\n```python\n{code}\n```\n\n"
                    f"Actual Columns: {df.columns.tolist()}. "
                    "Please fix the code. Ensure you do not redefine 'df'."
                )
                try:
                    response = self.client.chat_completion(
                        model=self.model_id,
                        messages=[{"role": "user", "content": retry_msg}],
                        max_tokens=2000,
                        temperature=0.1
                    )
                    plan_output = response.choices[0].message.content
                    code = self.extract_code(plan_output)
                except Exception as llm_e:
                    return {"success": False, "error": f"LLM Fix Failed: {llm_e}", "code": code}
        return {"success": False, "error": f"Self-healing failed after 3 attempts. Last error: {last_error}", "code": code}

    def extract_code(self, text: str) -> str:
        """Extract the Python payload from an LLM response.

        Prefers a fenced ```python block; falls back to the raw text when it
        already looks like code (contains imports or defs); else ''.
        """
        match = re.search(r"```python(.*?)```", text, re.DOTALL)
        if match:
            return match.group(1).strip()
        # Fallback: no markdown fences — accept bare code heuristically.
        if "import pandas" in text or "def " in text:
            return text
        return ""
# --- Streamlit UI Components ---
def render_insight_card(title: str, content: Any, icon: str = "๐Ÿ’ก"):
    """Render a titled insight card.

    Strings become info boxes, Plotly figures (anything with a ``show``
    attribute) become charts, and everything else is attempted as a
    Matplotlib figure with a plain ``st.write`` fallback.
    """
    with st.container():
        st.markdown(f"### {icon} {title}")
        if isinstance(content, str):
            st.info(content)
        elif hasattr(content, "show"):  # Plotly fig
            st.plotly_chart(content, use_container_width=True)
        else:
            # Try Matplotlib, else write the raw object.
            try:
                st.pyplot(content)
            except Exception:
                # Narrowed from bare `except:` so SystemExit/KeyboardInterrupt
                # are not swallowed.
                st.write(content)
def _read_uploaded_csv(uploaded_file) -> pd.DataFrame:
    """Robustly read an uploaded CSV, trying several dialects.

    Order: default pandas dialect (re-read with ``header=None`` when the
    header row looks like data), then semicolon-separated, then the sniffing
    python engine with latin1. Stops the Streamlit script with an error
    message when all attempts fail.
    """
    def _is_numeric_header(frame: pd.DataFrame) -> bool:
        # A header whose every name parses as a float is almost certainly a
        # data row (headerless file such as the Wine dataset).
        try:
            [float(c) for c in frame.columns]
            return True
        except (TypeError, ValueError):
            return False

    try:
        # 1. Try standard read
        df = pd.read_csv(uploaded_file)
        if _is_numeric_header(df):
            uploaded_file.seek(0)
            df = pd.read_csv(uploaded_file, header=None)
            st.toast("Detected headerless CSV (e.g. Wine dataset). Reloaded with integer indices.", icon="๐Ÿง ")
        return df
    except Exception:
        uploaded_file.seek(0)
    try:
        # 2. Try semicolon (common in EU data)
        return pd.read_csv(uploaded_file, sep=';')
    except Exception:
        uploaded_file.seek(0)
    try:
        # 3. Last resort: Python engine with separator sniffing
        return pd.read_csv(uploaded_file, sep=None, engine='python', encoding='latin1')
    except Exception as e:
        st.error(f"Could not read file. Error: {e}")
        st.stop()


def main():
    """Streamlit entry point: token setup, upload, context scan, auto-EDA
    dashboard, and the free-form analyst chat interface."""
    # Token resolution: environment first, then Streamlit secrets.
    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        try:
            hf_token = st.secrets.get("HF_TOKEN") or st.secrets.get("huggingface", {}).get("token")
        except Exception:
            # st.secrets raises when no secrets.toml exists; fall back to
            # manual input in the sidebar.
            hf_token = None
    # Sidebar Setup
    with st.sidebar:
        st.header("๐Ÿง  Configuration")
        if not hf_token:
            hf_token = st.text_input("Hugging Face Token", type="password")
            if not hf_token:
                st.warning("Token required to proceed.")
                st.stop()
        st.markdown("---")
        uploaded_file = st.file_uploader("๐Ÿ“‚ Upload Dataset", type="csv")
        st.markdown("---")
        st.markdown("### ๐Ÿ› ๏ธ Capabilities")
        st.caption("โœ… Domain Detection")
        st.caption("โœ… Proactive Outlier Scan")
        st.caption("โœ… Self-Healing Code")
    # Initialize Agent (once per session)
    if "agent" not in st.session_state:
        st.session_state.agent = AnalystAgent(hf_token)
    # Main Canvas
    if uploaded_file:
        df = _read_uploaded_csv(uploaded_file)
        # 1. Context Scanning (only on a newly uploaded file)
        if st.session_state.get("last_file") != uploaded_file.name:
            with st.status("๐Ÿ” Scanning Context...", expanded=True) as status:
                context = st.session_state.agent.scan_context(df)
                st.session_state.context = context
                st.session_state.last_file = uploaded_file.name
                st.session_state.initial_analysis_done = False  # Reset Auto-EDA flag
                st.session_state.initial_analysis_result = None  # Clear previous results
                status.write(f"Inferred Domain: **{context['domain']}**")
                status.write(f"Identified Targets: `{context['targets']}`")
                status.update(label="Context Loaded", state="complete", expanded=False)
            st.toast("Deep scan complete!", icon="๐Ÿง ")
        # 2. Context Display
        c = st.session_state.context
        col1, col2, col3, col4 = st.columns(4)
        col1.metric("Domain", c['domain'])
        col2.metric("Rows", f"{c['shape'][0]:,}")
        col3.metric("Columns", c['shape'][1])
        col4.metric("Duplicates", c['duplicates'])
        # 1.5 Auto-EDA: run the initial dashboard analysis once per file.
        if not st.session_state.get("initial_analysis_done"):
            with st.status("๐Ÿš€ Launching Universal Analyst Dashboard...", expanded=True) as status:
                st.write("Generating 3-Plot Dashboard (Heatmap, Distribution, Split)...")
                # UNIVERSAL ANALYST PROMPT — {df_sample}/{df_stats} are
                # substituted inside AnalystAgent.generate_plan.
                universal_prompt = """You are a Universal Principal Data Scientist. Your goal is to proactively discover the "Story" inside any uploaded dataset.
### PHASE 1: DATA INSPECTION
1. Examine the first few rows and the statistics: {df_sample}, {df_stats}.
2. Identify:
- Categorical Columns (for grouping).
- Numerical Columns (for correlation).
- The likely "Target" or most interesting variable.
### PHASE 2: AUTOMATIC MULTI-EDA DASHBOARD
You MUST generate a 3-plot dashboard using `plt.subplots(1, 3, figsize=(20, 6))`:
1. THE RELATIONSHIP MAP: A Seaborn Heatmap of all numerical correlations.
2. THE DISTRIBUTION: A Box plot or Histogram of the most high-variance variables.
3. THE CATEGORICAL SPLIT: A Scatter plot using a categorical column as the 'hue' to show clustering.
### PHASE 3: INTELLIGENT ERROR HANDLING
- **Headerless Data:** If columns are `0, 1, 2...`, YOU MUST RENAME THEM to `Feature_0, Feature_1...`.
- **Target Inference:** Look for a column with low cardinality (e.g., integers 1-3) and name it `Target_Variable`. Use this for the Categorical Split.
- Check for perfect 1.0 correlationsโ€”if found, it means you are correlating a column with itself. Correct this immediately.
- Use `sns.set_theme(style="whitegrid")`.
### PHASE 4: THE "EXCELLENT" OUTPUT CONTRACT
Your code MUST define:
- `result_text`: An Executive Summary using Markdown and LaTeX ($r$).
- `fig`: The Matplotlib figure object containing the 3-plot dashboard.
### OUTPUT STRUCTURE:
[PLAN] -> Your strategy based on the detected data types.
[CODE] -> The Python execution block.
[REPORT] -> A summary of "Anomalies," "Strongest Drivers," and "Actionable Insights."
"""
                query = "Generate the comprehensive 3-plot dashboard."
                # Execute with the CUSTOM (universal) prompt.
                result = st.session_state.agent.execute_step(query, df, system_prompt=universal_prompt)
                st.session_state.initial_analysis_result = result
                st.session_state.initial_analysis_done = True
                status.update(label="Initial Analysis Complete", state="complete", expanded=False)
        # Display Auto-EDA Results
        if st.session_state.get("initial_analysis_result"):
            res = st.session_state.initial_analysis_result
            if res["success"]:
                st.markdown("### ๐Ÿ“Š Universal Analyst Dashboard")
                if res.get("fig"):
                    # The figure may be Matplotlib (Universal prompt) or
                    # Plotly (Elite prompt): try pyplot first, then plotly.
                    try:
                        st.pyplot(res["fig"])
                    except Exception:
                        st.plotly_chart(res["fig"], use_container_width=True)
                if res.get("result"):
                    st.info(res["result"])
                with st.expander("๐Ÿ”Ž View Analysis Code"):
                    st.code(res["code"], language="python")
                st.divider()
        # 3. Query Interface
        st.markdown("### ๐Ÿ’ฌ Ask the Senior Analyst")
        query = st.chat_input("Tell me about sales trends, outliers, or specific patterns...")
        if query:
            # Plan + execute with the default "Elite" persona prompt.
            with st.status(f"๐Ÿค” Analyzing: '{query}'", expanded=True) as status:
                st.write("Drafting verification plan...")
                # execute_step handles both planning and execution.
                result = st.session_state.agent.execute_step(query, df)
                status.update(label="Analysis Complete", state="complete", expanded=False)
            # Results
            if result["success"]:
                if result.get("result"):
                    render_insight_card("Key Finding", result["result"])
                if result.get("fig"):
                    render_insight_card("Deep Dive", result["fig"], icon="๐Ÿ“Š")
                with st.expander("๐Ÿ“ Audit Trail (Code & Logs)"):
                    st.code(result["code"], language="python")
                    if result.get("log"):
                        st.text(result["log"])
                    if result.get("plan"):
                        st.markdown("**Generated Plan:**")
                        st.text(result["plan"])
            else:
                st.error(f"Analysis Failed: {result.get('error')}")
                with st.expander("Failed Code"):
                    st.code(result["code"], language="python")
    else:
        st.info("๐Ÿ‘‹ Welcome! Please upload a CSV file to begin automation.")
# Script entry point: run the Streamlit app.
if __name__ == "__main__":
    main()