# QueryMind / app.py
# Uploaded by RAVENOCC (commit c45a9b8, verified).
# NOTE: these lines were Hugging Face page chrome captured with the source;
# they are kept as comments so the module remains importable.
import os
import streamlit as st
import pandas as pd
import pickle
import base64
from io import BytesIO, StringIO
import sys
import operator
from typing import Literal, Sequence, TypedDict, Annotated, List, Dict, Tuple
import tempfile
import shutil
import plotly.io as pio
import io
import re
import json
import openai
# from fpdf import FPDF
import base64
from datetime import datetime
from reportlab.lib.pagesizes import letter
from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, Image
from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
from reportlab.lib.units import inch
from PIL import Image as PILImage
# Import LangChain and LangGraph components
from langchain_core.messages import AIMessage, ToolMessage, HumanMessage, BaseMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from langchain_experimental.utilities import PythonREPL
from langgraph.prebuilt import ToolInvocation, ToolExecutor
from langchain_core.tools import tool
from langgraph.prebuilt import InjectedState
from langgraph.graph import StateGraph, END
from reportlab.platypus import PageBreak
from PIL import Image as PILImage
# --- Provider/model settings: seed once, survive Streamlit reruns ---
_PROVIDER_DEFAULTS = {
    "ai_provider": "openai",
    "api_key": "",
    "selected_model": "gpt-4",
}
for _key, _default in _PROVIDER_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default

# Model choices offered per provider in the settings UI.
OPENAI_MODELS = ["gpt-4", "gpt-4-turbo", "gpt-4-mini", "gpt-3.5-turbo"]
GROQ_MODELS = ["llama3.3-70b-versatile", "gemma2-9b-it", "llama-3-8b-8192"]

# One temporary workspace per session; pickled plotly figures are written
# under images_dir by the code-execution tool.
if 'temp_dir' not in st.session_state:
    st.session_state.temp_dir = tempfile.mkdtemp()
    st.session_state.images_dir = os.path.join(st.session_state.temp_dir, "images/plotly_figures/pickle")
    os.makedirs(st.session_state.images_dir, exist_ok=True)
    print(f"Created temporary directory: {st.session_state.temp_dir}")
    print(f"Created images directory: {st.session_state.images_dir}")
# System prompt sent as the first message of every agent conversation.
# NOTE: this is runtime data consumed by the LLM — its exact wording is part
# of app behavior, so edit with care.
SYSTEM_PROMPT = """## Role
You are a professional data scientist helping a non-technical user understand, analyze, and visualize their data.
## Capabilities
1. **Execute python code** using the `complete_python_task` tool.
## Goals
1. Understand the user's objectives clearly.
2. Take the user on a data analysis journey, iterating to find the best way to visualize or analyse their data to solve their problems.
3. Investigate if the goal is achievable by running Python code via the `python_code` field.
4. Gain input from the user at every step to ensure the analysis is on the right track and to understand business nuances.
## Code Guidelines
- **ALL INPUT DATA IS LOADED ALREADY**, so use the provided variable names to access the data.
- **VARIABLES PERSIST BETWEEN RUNS**, so reuse previously defined variables if needed.
- **TO SEE CODE OUTPUT**, use `print()` statements. You won't be able to see outputs of `pd.head()`, `pd.describe()` etc. otherwise.
- **ONLY USE THE FOLLOWING LIBRARIES**:
- `pandas`
- `sklearn` (including all major ML models)
- `plotly`
- `numpy`
All these libraries are already imported for you.
## Machine Learning Guidelines
- For regression tasks:
- Linear Regression: `LinearRegression`
- Logistic Regression: `LogisticRegression`
- Ridge Regression: `Ridge`
- Lasso Regression: `Lasso`
- Random Forest Regression: `RandomForestRegressor`
- For classification tasks:
- Logistic Regression: `LogisticRegression`
- Decision Trees: `DecisionTreeClassifier`
- Random Forests: `RandomForestClassifier`
- Support Vector Machines: `SVC`
- K-Nearest Neighbors: `KNeighborsClassifier`
- Naive Bayes: `GaussianNB`
- For clustering:
- K-Means: `KMeans`
- DBSCAN: `DBSCAN`
- For dimensionality reduction:
- PCA: `PCA`
- Always preprocess data appropriately:
- Scale numerical features with `StandardScaler` or `MinMaxScaler`
- Encode categorical variables with `OneHotEncoder` when needed
- Handle missing values with `SimpleImputer`
- Always split data into training and testing sets using `train_test_split`
- Evaluate models using appropriate metrics:
- For regression: `mean_squared_error`, `mean_absolute_error`, `r2_score`
- For classification: `accuracy_score`, `confusion_matrix`, `classification_report`
- For clustering: `silhouette_score`
- Consider using `cross_val_score` for more robust evaluation
- Visualize ML results with plotly when possible
## Plotting Guidelines
- Always use the `plotly` library for plotting.
- Store all plotly figures inside a `plotly_figures` list, they will be saved automatically.
- Do not try and show the plots inline with `fig.show()`.
"""
# Define the State class threaded through the LangGraph workflow.
class AgentState(TypedDict):
    """Shared state of the agent graph.

    Fields annotated with ``operator.add`` are append-merged by LangGraph
    when nodes return partial updates; plain fields are overwritten.
    """
    # Conversation so far (human / AI / tool messages), append-merged.
    messages: Annotated[Sequence[BaseMessage], operator.add]
    # Metadata dicts for the uploaded datasets (variable_name, description, ...).
    input_data: Annotated[List[Dict], operator.add]
    # Per-tool-call records of the form {"thought", "code", "output"}.
    intermediate_outputs: Annotated[List[dict], operator.add]
    # Variables carried between code executions (replaced, not merged).
    current_variables: dict
    # Filenames (relative to images_dir) of pickled plotly figures produced.
    output_image_paths: Annotated[List[str], operator.add]
# Seed the remaining per-session defaults in a single pass.
_SESSION_DEFAULTS = {
    'in_memory_datasets': {},        # variable_name -> DataFrame
    'persistent_vars': {},           # user variables surviving across code runs
    'dataset_metadata_list': [],     # metadata dicts for uploaded datasets
    'chat_history': [],              # list of (user_msg, assistant_msg) pairs
    'dashboard_plots': [None, None, None, None],  # 2x2 dashboard slots
    'columns': ["No columns available"],
    'custom_plots_to_save': {},
}
for _name, _value in _SESSION_DEFAULTS.items():
    if _name not in st.session_state:
        st.session_state[_name] = _value

# Python REPL helper from langchain_experimental (part of the tool setup).
repl = PythonREPL()
plotly_saving_code = """import pickle
import uuid
import os
for figure in plotly_figures:
pickle_filename = f"{images_dir}/{uuid.uuid4()}.pickle"
with open(pickle_filename, 'wb') as f:
pickle.dump(figure, f)
"""
@tool
def complete_python_task(
    graph_state: Annotated[dict, InjectedState],
    thought: str,
    python_code: str
) -> Tuple[str, dict]:
    """Execute Python code for data analysis and visualization.

    Runs ``python_code`` via ``exec`` in a globals dict pre-populated with the
    uploaded datasets, previously persisted variables, numpy, and a wide set of
    scikit-learn classes. Captured stdout is returned as the tool output;
    pickled plotly figures created by the code are reported via
    ``output_image_paths`` in the returned state update.

    Returns a ``(stdout_text, state_update)`` tuple; on failure the error
    string doubles as both elements' payload.
    """
    current_variables = graph_state.get("current_variables", {})
    # Load datasets from in-memory storage (only ones not already present).
    for input_dataset in graph_state.get("input_data", []):
        var_name = input_dataset.get("variable_name")
        if var_name and var_name not in current_variables and var_name in st.session_state.in_memory_datasets:
            print(f"Loading {var_name} from in-memory storage")
            current_variables[var_name] = st.session_state.in_memory_datasets[var_name]
    # Snapshot the images directory so new pickles can be diffed afterwards.
    current_image_pickle_files = os.listdir(st.session_state.images_dir)
    try:
        # Capture stdout so the LLM can see print() output from the user code.
        old_stdout = sys.stdout
        sys.stdout = StringIO()
        # Build the execution environment: module globals + persisted vars +
        # dataset variables (later updates win on name collisions).
        exec_globals = globals().copy()
        exec_globals.update(st.session_state.persistent_vars)
        exec_globals.update(current_variables)
        # Add scikit-learn modules to the execution environment.
        import sklearn
        import numpy as np
        # Import scikit-learn components
        from sklearn.linear_model import LinearRegression, LogisticRegression, Ridge, Lasso  # type: ignore
        from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor, GradientBoostingClassifier  # type: ignore
        from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
        from sklearn.svm import SVC, SVR
        from sklearn.naive_bayes import GaussianNB
        from sklearn.decomposition import PCA
        from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor
        from sklearn.cluster import KMeans, DBSCAN
        from sklearn.preprocessing import StandardScaler, MinMaxScaler, OneHotEncoder
        from sklearn.model_selection import train_test_split, cross_val_score, GridSearchCV
        from sklearn.metrics import (
            accuracy_score, confusion_matrix, classification_report,
            mean_squared_error, r2_score, mean_absolute_error, silhouette_score
        )
        from sklearn.pipeline import Pipeline
        from sklearn.impute import SimpleImputer
        # Expose everything under the exact names promised by SYSTEM_PROMPT.
        exec_globals.update({
            "plotly_figures": [],
            "images_dir": st.session_state.images_dir,
            "np": np,
            # Linear models
            "LinearRegression": LinearRegression,
            "LogisticRegression": LogisticRegression,
            "Ridge": Ridge,
            "Lasso": Lasso,
            # Tree-based models
            "DecisionTreeClassifier": DecisionTreeClassifier,
            "DecisionTreeRegressor": DecisionTreeRegressor,
            "RandomForestClassifier": RandomForestClassifier,
            "RandomForestRegressor": RandomForestRegressor,
            "GradientBoostingClassifier": GradientBoostingClassifier,
            # SVM models
            "SVC": SVC,
            "SVR": SVR,
            # Other models
            "GaussianNB": GaussianNB,
            "PCA": PCA,
            "KNeighborsClassifier": KNeighborsClassifier,
            "KNeighborsRegressor": KNeighborsRegressor,
            "KMeans": KMeans,
            "DBSCAN": DBSCAN,
            # Preprocessing
            "StandardScaler": StandardScaler,
            "MinMaxScaler": MinMaxScaler,
            "OneHotEncoder": OneHotEncoder,
            "SimpleImputer": SimpleImputer,
            # Model selection and evaluation
            "train_test_split": train_test_split,
            "cross_val_score": cross_val_score,
            "GridSearchCV": GridSearchCV,
            "accuracy_score": accuracy_score,
            "confusion_matrix": confusion_matrix,
            "classification_report": classification_report,
            "mean_squared_error": mean_squared_error,
            "r2_score": r2_score,
            "mean_absolute_error": mean_absolute_error,
            "silhouette_score": silhouette_score,
            # Pipeline
            "Pipeline": Pipeline
        })
        # NOTE(review): exec of model-generated code is by design here, but it
        # runs with full process privileges — confirm this is acceptable.
        exec(python_code, exec_globals)
        # Persist only names the user code introduced (not module globals).
        st.session_state.persistent_vars.update({k: v for k, v in exec_globals.items() if k not in globals()})
        # Get the captured stdout, then restore the real stream.
        output = sys.stdout.getvalue()
        sys.stdout = old_stdout
        updated_state = {
            "intermediate_outputs": [{"thought": thought, "code": python_code, "output": output}],
            "current_variables": st.session_state.persistent_vars
        }
        # If the code produced figures, pickle them and diff the directory to
        # discover which files are new this run.
        if 'plotly_figures' in exec_globals and exec_globals['plotly_figures']:
            exec(plotly_saving_code, exec_globals)
            new_image_folder_contents = os.listdir(st.session_state.images_dir)
            new_image_files = [file for file in new_image_folder_contents if file not in current_image_pickle_files]
            if new_image_files:
                updated_state["output_image_paths"] = new_image_files
            # Reset so figures are not re-saved on the next run.
            st.session_state.persistent_vars["plotly_figures"] = []
        return output, updated_state
    except Exception as e:
        sys.stdout = old_stdout  # Restore stdout in case of error
        print(f"Error in complete_python_task: {str(e)}")
        return str(e), {"intermediate_outputs": [{"thought": thought, "code": python_code, "output": str(e)}]}
# Build the chat model for the provider/model/key currently configured.
def initialize_llm():
    """Return a LangChain chat model for the configured provider.

    Reads provider, API key, and model name from Streamlit session state.
    Returns None when no key is set, the provider is unknown, or
    construction fails.
    """
    api_key = st.session_state.api_key
    model = st.session_state.selected_model
    if not api_key:
        return None
    try:
        provider = st.session_state.ai_provider
        if provider == "openai":
            os.environ["OPENAI_API_KEY"] = api_key
            return ChatOpenAI(model=model, temperature=0)
        if provider == "groq":
            os.environ["GROQ_API_KEY"] = api_key
            # Groq ships its own LangChain integration; import lazily so the
            # dependency is only required when Groq is selected.
            from langchain_groq import ChatGroq
            return ChatGroq(model=model, temperature=0)
    except Exception as e:
        print(f"Error initializing LLM: {str(e)}")
    return None
# Register the tools and build a LangGraph executor over them.
tools = [complete_python_task]
tool_executor = ToolExecutor(tools)
# Prompt template: system instructions followed by the running message list.
chat_template = ChatPromptTemplate.from_messages([
    ("system", SYSTEM_PROMPT),
    ("placeholder", "{messages}"),
])
def create_data_summary(state: AgentState) -> str:
    """Render a text summary of every dataset and variable visible to the agent.

    Lists each uploaded dataset (name, description, first five rows) followed
    by any other persisted variables, noting DataFrame shapes where relevant.
    """
    pieces = []
    described = []
    # Add sample data for each uploaded dataset.
    for dataset in state.get("input_data", []):
        var_name = dataset.get("variable_name")
        if not var_name:
            continue
        described.append(var_name)
        pieces.append(f"\n\nVariable: {var_name}\n")
        pieces.append(f"Description: {dataset.get('data_description', 'No description')}\n")
        if var_name in st.session_state.in_memory_datasets:
            frame = st.session_state.in_memory_datasets[var_name]
            pieces.append("\nSample Data (first 5 rows):\n")
            pieces.append(frame.head(5).to_string())
    # Mention remaining (non-private) persisted variables.
    if "current_variables" in state:
        for name in state["current_variables"]:
            if name in described or name.startswith("_"):
                continue
            value = state["current_variables"].get(name)
            if isinstance(value, pd.DataFrame):
                pieces.append(f"\n\nVariable: {name} (DataFrame with shape {value.shape})")
            else:
                pieces.append(f"\n\nVariable: {name}")
    return "".join(pieces)
def route_to_tools(state: AgentState) -> Literal["tools", "__end__"]:
    """Conditional edge: go to the tools node iff the last AI message has tool calls."""
    messages = state.get("messages", [])
    if not messages:
        raise ValueError(f"No messages found in input state to tool_edge: {state}")
    ai_message = messages[-1]
    if hasattr(ai_message, "tool_calls") and len(ai_message.tool_calls) > 0:
        return "tools"
    return "__end__"
def call_model(state: AgentState):
    """Agent node: prepend a data-availability note, invoke the tool-bound LLM.

    Returns a partial state update with the model's reply and the data note
    recorded under intermediate_outputs.
    """
    data_note = HumanMessage(
        content="""The following data is available:\n{data_summary}""".format(
            data_summary=create_data_summary(state)
        )
    )
    conversation = [data_note] + state["messages"]
    llm = initialize_llm()
    if llm is None:
        # Surface a configuration problem to the user instead of crashing.
        return {"messages": [AIMessage(content="Please configure a valid API key and model in the settings tab.")]}
    pipeline = chat_template | llm.bind_tools(tools)
    reply = pipeline.invoke({"messages": conversation})
    return {"messages": [reply], "intermediate_outputs": [data_note.content]}
def call_tools(state: AgentState):
    """Tools node: run every tool call from the last AI message.

    Executes the calls in a batch (exceptions returned, not raised), converts
    each result into a ToolMessage, and merges any state updates the tools
    produced (lists extended, dicts updated, scalars replaced).
    """
    last_message = state["messages"][-1]
    tool_invocations = []
    if isinstance(last_message, AIMessage) and hasattr(last_message, 'tool_calls'):
        tool_invocations = [
            ToolInvocation(
                tool=tool_call["name"],
                # Inject the whole graph state so the tool can read datasets
                # and persisted variables.
                tool_input={**tool_call["args"], "graph_state": state}
            ) for tool_call in last_message.tool_calls
        ]
    responses = tool_executor.batch(tool_invocations, return_exceptions=True)
    tool_messages = []
    state_updates = {}
    for tc, response in zip(last_message.tool_calls, responses):
        if isinstance(response, Exception):
            # Report the failure back to the LLM rather than aborting the run.
            print(f"Exception in tool execution: {str(response)}")
            tool_messages.append(ToolMessage(
                content=f"Error: {str(response)}",
                name=tc["name"],
                tool_call_id=tc["id"]
            ))
            continue
        message, updates = response
        tool_messages.append(ToolMessage(
            content=str(message),
            name=tc["name"],
            tool_call_id=tc["id"]
        ))
        # Merge updates instead of overwriting when multiple tools ran.
        for key, value in updates.items():
            if key in state_updates:
                if isinstance(value, list) and isinstance(state_updates[key], list):
                    state_updates[key].extend(value)
                elif isinstance(value, dict) and isinstance(state_updates[key], dict):
                    state_updates[key].update(value)
                else:
                    state_updates[key] = value
            else:
                state_updates[key] = value
    # FIX: the original first guarded `if 'messages' not in state_updates:
    # state_updates["messages"] = []` and then unconditionally overwrote the
    # key anyway — dead code. The ToolMessages are the messages update.
    state_updates["messages"] = tool_messages
    return state_updates
# Assemble the agent <-> tools graph: the agent decides, the tools execute,
# and control loops back to the agent until it stops requesting tool calls.
workflow = StateGraph(AgentState)
workflow.add_node("agent", call_model)
workflow.add_node("tools", call_tools)
workflow.add_conditional_edges(
    "agent",
    route_to_tools,
    {
        "tools": "tools",
        "__end__": END
    }
)
workflow.add_edge("tools", "agent")
workflow.set_entry_point("agent")
# Compiled runnable used by all chat/dashboard entry points below.
chain = workflow.compile()
def _safe_variable_name(filename):
    """Derive a valid, lowercase Python identifier from an uploaded filename."""
    stem = os.path.splitext(filename)[0]
    # Collapse every run of non-word characters (spaces, '-', '.', etc.).
    name = re.sub(r'\W+', '_', stem).lower().strip('_') or "dataset"
    if name[0].isdigit():
        # Identifiers cannot start with a digit; the exec tool needs one.
        name = f"data_{name}"
    return name

def process_file_upload(files):
    """Process uploaded files and return dataframe previews and column names.

    Loads each CSV/Excel file into session state under a sanitized variable
    name, records metadata, and returns (markdown previews, metadata list,
    unique column names across all datasets).
    """
    st.session_state.in_memory_datasets = {}      # Clear previous datasets
    st.session_state.dataset_metadata_list = []   # Clear previous metadata
    st.session_state.persistent_vars.clear()      # Fresh session for new data
    if not files:
        return "No files uploaded.", [], ["No columns available"]
    results = []
    all_columns = []  # Track all columns from all datasets
    for file in files:
        try:
            # Streamlit's UploadedFile can be passed straight to pandas.
            if file.name.endswith('.csv'):
                df = pd.read_csv(file)
            elif file.name.endswith(('.xls', '.xlsx')):
                df = pd.read_excel(file)
            else:
                results.append(f"Unsupported file format: {file.name}. Please upload CSV or Excel files.")
                continue
            # FIX: the original used file.name.split('.')[0] with ad-hoc
            # replacements, which breaks on dotted filenames and can produce
            # an invalid identifier (e.g. leading digit) that the exec-based
            # analysis tool cannot reference.
            var_name = _safe_variable_name(file.name)
            st.session_state.in_memory_datasets[var_name] = df
            all_columns.extend(df.columns.tolist())
            dataset_metadata = {
                "variable_name": var_name,
                "data_path": "in_memory",
                "data_description": f"Dataset containing {df.shape[0]} rows and {df.shape[1]} columns. Columns: {', '.join(df.columns.tolist())}",
                "original_filename": file.name
            }
            st.session_state.dataset_metadata_list.append(dataset_metadata)
            # Return a markdown preview of the dataset.
            preview = f"### Dataset: {file.name}\nVariable name: `{var_name}`\n\n"
            preview += df.head(10).to_markdown()
            results.append(preview)
            print(f"Successfully processed {file.name}")
        except Exception as e:
            print(f"Error processing {file.name}: {str(e)}")
            results.append(f"Error processing {file.name}: {str(e)}")
    # Deduplicate columns while preserving first-seen order.
    unique_columns = list(dict.fromkeys(all_columns))
    if not unique_columns:
        unique_columns = ["No columns available"]
    print(f"Found {len(unique_columns)} unique columns across datasets")
    return "\n\n".join(results), st.session_state.dataset_metadata_list, unique_columns
def get_columns():
    """Directly gets columns from in-memory datasets"""
    # dict keys give an order-preserving de-duplication of column names.
    ordered = {}
    for frame in st.session_state.in_memory_datasets.values():
        if isinstance(frame, pd.DataFrame):
            for column in frame.columns:
                ordered.setdefault(column, None)
    unique_columns = list(ordered) or ["No columns available"]
    print(f"Populating dropdowns with {len(unique_columns)} columns")
    return unique_columns
# === FUNCTIONS ===
import openai
import pandas as pd
import json
import re
def standard_clean(df):
df.columns = [re.sub(r'\W+', '_', col).strip().lower() for col in df.columns]
df.drop_duplicates(inplace=True)
df.dropna(axis=1, how='all', inplace=True)
df.dropna(axis=0, how='all', inplace=True)
for col in df.select_dtypes(include='object').columns:
df[col] = df[col].astype(str).str.strip()
return df
def query_openai(prompt):
    """Send a single-turn prompt to the configured chat provider.

    Returns the model's reply text, or the string "{}" on any failure so
    that downstream dict-literal parsing always has a safe fallback.
    """
    try:
        # Use the configured API key and model from session state.
        api_key = st.session_state.api_key
        model = st.session_state.selected_model
        if st.session_state.ai_provider == "openai":
            client = openai.OpenAI(api_key=api_key)
        elif st.session_state.ai_provider == "groq":
            # Groq's SDK mirrors the OpenAI chat-completions interface.
            from groq import Groq
            client = Groq(api_key=api_key)
        else:
            # FIX: the original fell through and implicitly returned None for
            # an unrecognized provider; callers parse the result as a dict
            # literal, so keep the documented "{}" contract instead.
            return "{}"
        response = client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.7
        )
        return response.choices[0].message.content
    except Exception as e:
        print(f"API Error: {e}")
        return "{}"
def llm_suggest_cleaning(df):
    """Ask the LLM for dataset-specific cleaning instructions.

    Sends the first ten rows as CSV and expects a Python dict literal with
    keys rename_columns / convert_types / fill_missing / value_map. Returns
    a dict with those keys (all empty when the reply cannot be parsed).
    """
    sample = df.head(10).to_csv(index=False)
    prompt = f"""
You are a professional data wrangler. Below is a sample of a messy dataset.
Return a Python dictionary with the following keys:
1. rename_columns – fix unclear or inconsistent column names
2. convert_types – correct datatypes: int, float, str, or date
3. fill_missing – use 'mean', 'median', 'mode', or a constant like 'Unknown' or 0
4. value_map – map inconsistent values (e.g., yes/Yes/Y → Yes)
Do not drop any rows or columns. Your output must be a valid Python dict.
Example:
{{
"rename_columns": {{"dob": "date_of_birth"}},
"convert_types": {{"age": "int", "salary": "float", "signup_date": "date"}},
"fill_missing": {{"gender": "mode", "salary": -1}},
"value_map": {{
"gender": {{"M": "Male", "F": "Female"}},
"subscribed": {{"Y": "Yes", "N": "No"}}
}}
}}
Apart from these mentioned steps, study the data and also do whatever things are good and needed for that particular dataset and do the cleaning.
Sample data:
{sample}
"""
    raw_response = query_openai(prompt)
    # SECURITY FIX: the original eval()'d the raw LLM response, which would
    # execute arbitrary code returned by the model. ast.literal_eval only
    # accepts Python literals, so a malicious/garbled reply cannot run code.
    import ast
    try:
        # Strip optional markdown code fences before parsing.
        cleaned = re.sub(r'^```(?:python)?\s*|\s*```$', '', raw_response.strip())
        suggestions = ast.literal_eval(cleaned)
        if isinstance(suggestions, dict):
            return suggestions
    except (ValueError, SyntaxError):
        pass
    print("Could not parse suggestions.")
    return {
        "rename_columns": {},
        "convert_types": {},
        "fill_missing": {},
        "value_map": {}
    }
def apply_suggestions(df, suggestions):
    """Apply LLM-proposed cleaning steps to *df* (mutated in place, returned).

    ``suggestions`` keys: rename_columns (old->new), convert_types
    (col -> 'int'/'float'/'str'/'date'), fill_missing (col -> 'mean'/'median'/
    'mode' or a constant), value_map (col -> {old: new}). Unknown columns are
    skipped; individual step failures are logged and never raised.
    """
    df.rename(columns=suggestions.get("rename_columns", {}), inplace=True)
    for col, dtype in suggestions.get("convert_types", {}).items():
        if col not in df.columns:
            continue
        try:
            if dtype == "int":
                # Nullable Int64 keeps missing values representable.
                df[col] = pd.to_numeric(df[col], errors='coerce').astype("Int64")
            elif dtype == "float":
                df[col] = pd.to_numeric(df[col], errors='coerce')
            elif dtype == "str":
                df[col] = df[col].astype(str)
            elif dtype == "date":
                df[col] = pd.to_datetime(df[col], errors='coerce')
        except Exception:
            print(f"Failed to convert {col} to {dtype}")
    for col, method in suggestions.get("fill_missing", {}).items():
        if col not in df.columns:
            continue
        try:
            if method == "mean":
                df[col] = df[col].fillna(df[col].mean())
            elif method == "median":
                df[col] = df[col].fillna(df[col].median())
            elif method == "mode":
                df[col] = df[col].fillna(df[col].mode().iloc[0])
            else:
                # FIX: the original only filled *string* constants
                # (`elif isinstance(method, str)`), silently ignoring numeric
                # constants such as the documented `-1`.
                df[col] = df[col].fillna(method)
        except Exception:
            print(f"Could not fill missing values for {col}")
    for col, mapping in suggestions.get("value_map", {}).items():
        if col in df.columns:
            try:
                df[col] = df[col].replace(mapping)
            except Exception:
                print(f"Could not map values in {col}")
    return df
def capture_dashboard_screenshot():
    """Capture the entire dashboard as a single image.

    Combines the up-to-four session dashboard figures into a 2x2 plotly
    subplot grid and writes it as a PNG under the session temp dir.
    Returns the PNG path, or None if anything fails (failure is logged).
    """
    try:
        # Create a figure that combines all dashboard plots
        import plotly.graph_objects as go
        from plotly.subplots import make_subplots
        # Create a 2x2 subplot
        fig = make_subplots(rows=2, cols=2,
                            subplot_titles=["Visualization 1", "Visualization 2",
                                            "Visualization 3", "Visualization 4"])
        # Add each plot from the dashboard to the combined figure
        for i, plot in enumerate(st.session_state.dashboard_plots):
            if plot is not None:
                row = (i // 2) + 1
                col = (i % 2) + 1
                # Extract traces from the original figure and add to our subplot
                for trace in plot.data:
                    fig.add_trace(trace, row=row, col=col)
                # Copy layout properties for each subplot
                for axis_type in ['xaxis', 'yaxis']:
                    # Source figure's axis: 'xaxis' for i == 0, else 'xaxis{i+1}'.
                    axis_name = f"{axis_type}{i+1 if i > 0 else ''}"
                    # NOTE(review): plotly numbers subplot axes sequentially
                    # ('xaxis', 'xaxis2', ...); '{axis_type}{row}{col}' yields
                    # names like 'xaxis11' which may not match and would land
                    # in the except below — verify against the plotly version.
                    subplot_name = f"{axis_type}{row}{col}"
                    # Copy axis properties if they exist
                    if hasattr(plot.layout, axis_name):
                        axis_props = getattr(plot.layout, axis_name)
                        fig.update_layout({subplot_name: axis_props})
        # Update layout for better appearance
        fig.update_layout(
            height=800,
            width=1000,
            title_text="Dashboard Overview",
            showlegend=False,
        )
        # Save to a temporary file (higher scale for better resolution).
        dashboard_path = f"{st.session_state.temp_dir}/dashboard_combined.png"
        fig.write_image(dashboard_path, scale=2)
        return dashboard_path
    except Exception as e:
        import traceback
        print(f"Error capturing dashboard: {str(e)}")
        print(traceback.format_exc())
        return None
def generate_enhanced_pdf_report():
    """Generate an enhanced PDF report with proper handling of base64 image data.

    Builds a reportlab document containing (1) the chat history, with any
    base64-embedded visualizations decoded and inlined as images, and (2) the
    dashboard, either as one combined screenshot or as individual plots.
    Returns the PDF bytes, or None on failure (failure is logged).
    """
    try:
        # Create a buffer for the PDF
        buffer = io.BytesIO()
        # Create the PDF document
        doc = SimpleDocTemplate(buffer, pagesize=letter,
                                leftMargin=36, rightMargin=36,
                                topMargin=36, bottomMargin=36)
        # Create custom styles with better formatting
        styles = getSampleStyleSheet()
        # Add custom styles with improved formatting
        styles.add(ParagraphStyle(
            name='ReportTitle',
            parent=styles['Heading1'],
            fontSize=24,
            alignment=1,  # Centered
            spaceAfter=20,
            textColor='#2C3E50'  # Dark blue color
        ))
        styles.add(ParagraphStyle(
            name='SectionHeader',
            parent=styles['Heading2'],
            fontSize=16,
            spaceBefore=15,
            spaceAfter=10,
            textColor='#2C3E50',
            borderWidth=1,
            borderColor='#95A5A6',
            borderPadding=5,
            borderRadius=5
        ))
        styles.add(ParagraphStyle(
            name='SubHeader',
            parent=styles['Heading3'],
            fontSize=14,
            spaceBefore=10,
            spaceAfter=8,
            textColor='#34495E',
            fontWeight='bold'
        ))
        styles.add(ParagraphStyle(
            name='UserMessage',
            parent=styles['Normal'],
            fontSize=11,
            leftIndent=10,
            spaceBefore=8,
            spaceAfter=4
        ))
        styles.add(ParagraphStyle(
            name='AssistantMessage',
            parent=styles['Normal'],
            fontSize=11,
            leftIndent=10,
            spaceBefore=4,
            spaceAfter=12,
            textColor='#2980B9'
        ))
        styles.add(ParagraphStyle(
            name='Timestamp',
            parent=styles['Italic'],
            fontSize=10,
            textColor='#7F8C8D',
            alignment=2  # Right aligned
        ))
        # Create the document content
        elements = []
        # Add title
        elements.append(Paragraph('Data Analysis Report', styles['ReportTitle']))
        # Add timestamp
        elements.append(Paragraph(f'Generated on: {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}',
                                  styles['Timestamp']))
        elements.append(Spacer(1, 0.5*inch))
        # Add conversation history with better formatting
        elements.append(Paragraph('Analysis Conversation History', styles['SectionHeader']))
        if st.session_state.chat_history:
            for i, (user_msg, assistant_msg) in enumerate(st.session_state.chat_history):
                # Format user message with proper styling
                elements.append(Paragraph(f'<b>You:</b>', styles['SubHeader']))
                user_msg_formatted = user_msg.replace('\n', '<br/>')
                elements.append(Paragraph(user_msg_formatted, styles['UserMessage']))
                # Process assistant message to handle visualization:
                # look for markdown image syntax with base64 data.
                base64_pattern = r'!\[Visualization\]\(data:image\/png;base64,([^\)]+)\)'
                # Check if the message contains visualizations
                if '### Visualizations' in assistant_msg or re.search(base64_pattern, assistant_msg):
                    # Split the message at the Visualizations header if it exists
                    if '### Visualizations' in assistant_msg:
                        parts = assistant_msg.split('### Visualizations', 1)
                        text_part = parts[0]
                        viz_part = "### Visualizations" + parts[1] if len(parts) > 1 else ""
                    else:
                        # If no header but still has visualization
                        match = re.search(base64_pattern, assistant_msg)
                        text_part = assistant_msg[:match.start()]
                        viz_part = assistant_msg[match.start():]
                    # Format the text part
                    elements.append(Paragraph(f'<b>Assistant:</b>', styles['SubHeader']))
                    text_part = text_part.replace('\n', '<br/>')
                    elements.append(Paragraph(text_part, styles['AssistantMessage']))
                    # Process visualizations: decode each base64 image to a
                    # temp PNG so reportlab can embed it.
                    matches = re.findall(base64_pattern, viz_part)
                    for j, base64_data in enumerate(matches):
                        try:
                            # Decode the base64 image
                            image_data = base64.b64decode(base64_data)
                            # Create a temporary file for the image
                            temp_img_path = f"{st.session_state.temp_dir}/chat_viz_{i}_{j}.png"
                            with open(temp_img_path, 'wb') as f:
                                f.write(image_data)
                            # Add the image to the PDF
                            elements.append(Paragraph(f'<b>Visualization:</b>', styles['SubHeader']))
                            elements.append(Spacer(1, 0.1*inch))
                            img = Image(temp_img_path, width=6*inch, height=4*inch)
                            elements.append(img)
                            elements.append(Spacer(1, 0.2*inch))
                        except Exception as e:
                            print(f"Error processing base64 image: {str(e)}")
                            elements.append(Paragraph(f"[Error displaying visualization: {str(e)}]",
                                                      styles['Normal']))
                else:
                    # No visualizations, just format the text
                    elements.append(Paragraph(f'<b>Assistant:</b>', styles['SubHeader']))
                    assistant_msg_formatted = assistant_msg.replace('\n', '<br/>')
                    # Truncate very long replies to keep the report readable.
                    if len(assistant_msg_formatted) > 1500:
                        assistant_msg_formatted = assistant_msg_formatted[:1500] + '...'
                    elements.append(Paragraph(assistant_msg_formatted, styles['AssistantMessage']))
                elements.append(Spacer(1, 0.2*inch))
        else:
            elements.append(Paragraph('No conversation history available.', styles['Normal']))
        # Force a page break before the dashboard
        elements.append(PageBreak())
        # Add dashboard section header
        elements.append(Paragraph('Dashboard Overview', styles['SectionHeader']))
        elements.append(Spacer(1, 0.2*inch))
        # Capture the dashboard as a single image
        dashboard_img_path = capture_dashboard_screenshot()
        if dashboard_img_path:
            # Calculate available width
            available_width = doc.width
            # Create PIL image to get dimensions
            pil_img = PILImage.open(dashboard_img_path)
            img_width, img_height = pil_img.size
            # Calculate scaling factor to fit within page width
            scale_factor = available_width / img_width
            # Calculate new height based on aspect ratio
            new_height = img_height * scale_factor
            # Add the image with scaled dimensions
            img = Image(dashboard_img_path, width=available_width, height=new_height)
            elements.append(img)
        else:
            # Fallback: Add individual plots if combined dashboard fails
            plot_count = 0
            for i, plot in enumerate(st.session_state.dashboard_plots):
                if plot is not None:
                    plot_count += 1
                    # Convert plotly figure to image
                    img_bytes = io.BytesIO()
                    plot.write_image(img_bytes, format='png', width=500, height=300)
                    img_bytes.seek(0)
                    # Create a temporary file for the image
                    temp_img_path = f"{st.session_state.temp_dir}/plot_{i}.png"
                    with open(temp_img_path, 'wb') as f:
                        f.write(img_bytes.getvalue())
                    # Add to PDF with appropriate caption and formatting
                    elements.append(Paragraph(f'Dashboard Visualization {i+1}', styles['SubHeader']))
                    elements.append(Spacer(1, 0.1*inch))
                    # Add the image with proper scaling
                    img = Image(temp_img_path, width=6.5*inch, height=4*inch)
                    elements.append(img)
                    elements.append(Spacer(1, 0.3*inch))
            if plot_count == 0:
                elements.append(Paragraph('No visualizations have been added to the dashboard.',
                                          styles['Normal']))
        # Build the PDF with improved formatting
        doc.build(elements)
        # Get the value of the buffer
        pdf_value = buffer.getvalue()
        buffer.close()
        return pdf_value
    except Exception as e:
        import traceback
        print(f"Error generating enhanced PDF report: {str(e)}")
        print(traceback.format_exc())
        return None
def chat_with_workflow(message, history, dataset_info):
    """Send user query to the workflow and get response.

    Runs the LangGraph agent with the last few chat exchanges as context,
    cleans the final model reply, and appends any generated plotly figures
    as base64-embedded markdown images. Returns the response text (or an
    instructional/error message).
    """
    if not dataset_info:
        return "Please upload at least one dataset before asking questions."
    # Check if we have a valid API key and model
    if not st.session_state.api_key:
        return "Please set up your API key and model in the Settings tab before chatting."
    print(f"Chat with workflow called with {len(dataset_info)} datasets")
    try:
        # Extract chat history for context (last 3 exchanges)
        max_history = 3
        previous_messages = []
        if history:
            start_idx = max(0, len(history) - max_history)
            recent_history = history[start_idx:]
            for exchange in recent_history:
                if exchange[0]:  # User message
                    previous_messages.append(HumanMessage(content=exchange[0]))
                if exchange[1]:  # AI response
                    previous_messages.append(AIMessage(content=exchange[1]))
        # Initialize the workflow state
        state = AgentState(
            messages=previous_messages + [HumanMessage(content=message)],
            input_data=dataset_info,
            intermediate_outputs=[],
            current_variables=st.session_state.persistent_vars,
            output_image_paths=[]
        )
        # Execute the workflow
        print("Executing workflow...")
        result = chain.invoke(state)
        print("Workflow execution completed")
        # Extract messages from the result
        messages = result["messages"]
        # Format the response - only the latest message matters.
        response = ""
        if messages:
            latest_message = messages[-1]  # Get only the last message
            if hasattr(latest_message, "content"):
                content = latest_message.content
                # Clean up the response:
                # remove any instances where the user's message is repeated.
                if message in content:
                    content = content.split(message)[-1].strip()
                # Remove any chat history markers the model may have echoed.
                content_lines = content.split('\n')
                filtered_lines = [line for line in content_lines
                                  if not line.strip().startswith(("You:", "User:", "Human:", "Assistant:"))]
                content = '\n'.join(filtered_lines)
                response = content.strip() + "\n\n"
        # Handle visualizations: unpickle each new figure, render it to PNG,
        # and inline it as a base64 markdown image.
        if "output_image_paths" in result and result["output_image_paths"]:
            response += "### Visualizations\n\n"
            for img_path in result["output_image_paths"]:
                try:
                    full_path = os.path.join(st.session_state.images_dir, img_path)
                    with open(full_path, 'rb') as f:
                        fig = pickle.load(f)
                    # Convert plotly figure to image
                    img_bytes = BytesIO()
                    fig.update_layout(width=800, height=500)
                    pio.write_image(fig, img_bytes, format='png')
                    img_bytes.seek(0)
                    # Convert to base64 for markdown image
                    b64_img = base64.b64encode(img_bytes.read()).decode()
                    response += f"![Visualization](data:image/png;base64,{b64_img})\n\n"
                except Exception as e:
                    response += f"Error loading visualization: {str(e)}\n\n"
        return response
    except Exception as e:
        import traceback
        print(f"Error in chat_with_workflow: {str(e)}")
        print(traceback.format_exc())
        return f"Error executing workflow: {str(e)}"
def auto_generate_dashboard(dataset_info):
    """Ask the LLM workflow to propose a four-slot dashboard for the data.

    Parameters
    ----------
    dataset_info : metadata describing the uploaded dataset(s); falsy when
        nothing has been uploaded yet.

    Returns
    -------
    tuple
        ``(status_message, figures)`` where ``figures`` is a list of exactly
        four entries, each a Plotly figure or ``None`` for an empty slot.
        The figures are also mirrored into
        ``st.session_state.dashboard_plots`` for the dashboard tab.
    """
    if not dataset_info:
        return "Please upload a dataset first.", [None, None, None, None]
    prompt = """
    You are a data visualization expert. Given a dataset, identify the top 4 most insightful plots using statistical reasoning or patterns (correlation, distribution, trends).
    Use plotly and store the plots in a list named plotly_figures.
    Include multivariate plots using color/size/facets when helpful.
    """
    workflow_state = AgentState(
        messages=[HumanMessage(content=prompt)],
        input_data=dataset_info,
        intermediate_outputs=[],
        current_variables=st.session_state.persistent_vars,
        output_image_paths=[]
    )
    outcome = chain.invoke(workflow_state)
    loaded_figures = []
    # The workflow pickles each Plotly figure to disk; load at most four.
    for pickle_name in outcome.get("output_image_paths", [])[:4]:
        try:
            full_path = os.path.join(st.session_state.images_dir, pickle_name)
            with open(full_path, 'rb') as handle:
                loaded_figures.append(pickle.load(handle))
        except Exception as exc:
            print(f"Error loading figure: {exc}")
    # Pad with None so the dashboard always has exactly four slots.
    loaded_figures.extend([None] * (4 - len(loaded_figures)))
    st.session_state.dashboard_plots = loaded_figures
    return "Dashboard generated!", loaded_figures
def generate_custom_plots_with_llm(dataset_info, x_col, y_col, facet_col):
    """Request three LLM-generated Plotly charts around user-chosen columns.

    Parameters
    ----------
    dataset_info : metadata for the uploaded dataset(s); falsy when absent.
    x_col, y_col : column names selected for the X and Y axes.
    facet_col : optional facet column name, or the literal string ``'None'``.

    Returns
    -------
    list
        Exactly three entries, each a Plotly figure or ``None``.
    """
    # Nothing to plot without a dataset and both axis columns.
    if not (dataset_info and x_col and y_col):
        return [None, None, None]
    prompt = f"""
    You are a data visualization expert.
    Create 3 insightful visualizations using Plotly based on:
    - X-axis: {x_col}
    - Y-axis: {y_col}
    - Facet (optional): {facet_col if facet_col != 'None' else 'None'}
    Try to find interesting relationships, trends, or clusters using appropriate chart types.
    Use `plotly_figures` list and avoid using fig.show().
    """
    workflow_state = AgentState(
        messages=[HumanMessage(content=prompt)],
        input_data=dataset_info,
        intermediate_outputs=[],
        current_variables=st.session_state.persistent_vars,
        output_image_paths=[]
    )
    outcome = chain.invoke(workflow_state)
    charts = []
    # Unpickle at most three figures produced by the workflow.
    for pickle_name in outcome.get("output_image_paths", [])[:3]:
        try:
            full_path = os.path.join(st.session_state.images_dir, pickle_name)
            with open(full_path, 'rb') as handle:
                charts.append(pickle.load(handle))
        except Exception as exc:
            print(f"Error loading figure: {exc}")
    # Pad to exactly three entries so callers can index positions safely.
    charts.extend([None] * (3 - len(charts)))
    return charts
def remove_plot(index):
    """Clear the dashboard slot at *index*; out-of-range indices are ignored."""
    slots = st.session_state.dashboard_plots
    if index < 0 or index >= len(slots):
        return
    # Mutating the list in place also updates the session-state entry.
    slots[index] = None
def respond(message):
    """Answer a chat message, record the exchange, and trigger a rerun.

    Falls back to a fixed prompt-the-user reply when no dataset metadata
    has been loaded yet.
    """
    metadata = st.session_state.dataset_metadata_list
    if metadata:
        reply = chat_with_workflow(message, st.session_state.chat_history, metadata)
    else:
        reply = "Please upload at least one dataset before asking questions."
    st.session_state.chat_history.append((message, reply))
    # Rerun so the freshly appended exchange renders in the chat tab.
    st.rerun()
def save_plot_to_dashboard(plot_index):
    """Callback for the "Add Plot N to Dashboard" buttons.

    Copies the generated custom plot at *plot_index* (from
    ``st.session_state.custom_plots_to_save``) into the first empty slot of
    ``st.session_state.dashboard_plots``.

    Fix: the original silently discarded the plot when all four slots were
    occupied; the user now gets explicit feedback instead.
    """
    for slot, existing in enumerate(st.session_state.dashboard_plots):
        if existing is None:
            # Found an empty slot
            st.session_state.dashboard_plots[slot] = st.session_state.custom_plots_to_save[plot_index]
            return
    # No free slot: tell the user rather than dropping the plot silently.
    st.warning("Dashboard is full — remove a plot before adding another.")
# --- Streamlit UI -----------------------------------------------------------
# Page-level configuration; st.set_page_config must be the first st.* call
# that renders anything.
st.set_page_config(page_title="QueryMind 🧠", layout="wide")
st.title("QueryMind 🧠 - Data Assistant")
st.markdown("Upload your datasets, ask questions, and generate visualizations to gain insights.")
# Create the six top-level tabs that structure the whole app.
tab1, tab2, tab3, tab4, tab5, tab6 = st.tabs(["Upload Datasets", "Data Cleaning", "Chat with AI Assistant", "Auto Dashboard Generator", "Generate Report", "Settings"])
with tab1:
    # Tab 1: dataset upload, parsing, and per-dataset preview.
    st.header("Upload Datasets")
    uploaded_files = st.file_uploader("Upload CSV or Excel Files",
                                      accept_multiple_files=True,
                                      type=['csv', 'xlsx', 'xls'])
    if uploaded_files and st.button("Process Uploaded Files"):
        with st.spinner("Processing files..."):
            # process_file_upload fills st.session_state.in_memory_datasets;
            # its preview/metadata return values are used elsewhere in the app.
            preview, metadata_list, columns = process_file_upload(uploaded_files)
            st.session_state.columns = columns
            # Display basic information about processed files
            st.success(f"✅ Successfully processed {len(uploaded_files)} file(s)")
            # Show detailed preview for each dataset
            st.subheader("Dataset Previews")
            for dataset_name, df in st.session_state.in_memory_datasets.items():
                with st.expander(f"Preview: {dataset_name}"):
                    # Display dataset info
                    st.write(f"**Rows:** {df.shape[0]} | **Columns:** {df.shape[1]}")
                    # Per-column summary: dtype, null counts, and up to three
                    # sample values joined into one display string.
                    col_info = pd.DataFrame({
                        'Column Name': df.columns,
                        'Data Type': df.dtypes.astype(str),
                        'Non-Null Count': df.count().values,
                        'Sample Values': [', '.join(df[col].dropna().astype(str).head(3).tolist()) for col in df.columns]
                    })
                    # Show column information in a compact table
                    st.write("**Column Information:**")
                    st.dataframe(col_info, use_container_width=True)
                    # Show actual data preview
                    st.write("**Data Preview (First 10 rows):**")
                    st.dataframe(df.head(10), use_container_width=True)
            # Provide hint for the next steps
            st.info("👆 Click on the dataset names above to see detailed previews. Then proceed to the Data Cleaning tab to clean your data or Chat with AI Assistant to analyze it.")
with tab2:
    # Tab 2: LLM-assisted data cleaning with natural-language refinements.
    st.header("Data Cleaning")
    # Lazily initialise the cleaning-related session-state slots.
    if 'cleaning_done' not in st.session_state:
        st.session_state.cleaning_done = False
    if 'cleaned_datasets' not in st.session_state:
        st.session_state.cleaned_datasets = {}
    if 'cleaning_summaries' not in st.session_state:
        st.session_state.cleaning_summaries = {}
    if st.session_state.get("in_memory_datasets"):
        if not st.session_state.cleaning_done:
            if st.button("Run Data Cleaning"):
                with st.spinner("Running LLM-assisted cleaning..."):
                    # Per dataset: deterministic standard clean first, then
                    # LLM-suggested actions applied on top of it.
                    for name, df in st.session_state.in_memory_datasets.items():
                        raw_df = df.copy()
                        df_std = standard_clean(raw_df.copy())
                        suggestions = llm_suggest_cleaning(df_std.copy())
                        df_clean = apply_suggestions(df_std.copy(), suggestions)
                        st.session_state.cleaned_datasets[name] = df_clean
                        st.session_state.cleaning_summaries[name] = suggestions
                    st.session_state.cleaning_done = True
                    st.rerun()
            else:
                st.info("Click Run Data Cleaning to clean your datasets using the LLM.")
        else:
            # Cleaning already ran: show before/after previews and allow
            # per-dataset natural-language corrections.
            for name, df_clean in st.session_state.cleaned_datasets.items():
                raw_df = st.session_state.in_memory_datasets[name]
                st.subheader(f"Dataset: {name}")
                col1, col2 = st.columns(2)
                with col1:
                    st.markdown("Original Data (First 5 Rows)")
                    st.dataframe(raw_df.head())
                with col2:
                    st.markdown("Cleaned Data (First 5 Rows)")
                    st.dataframe(df_clean.head())
                st.markdown("Summary of Cleaning Actions")
                suggestions = st.session_state.cleaning_summaries[name]
                summary_text = ""
                if suggestions:
                    for key, value in suggestions.items():
                        summary_text += f"**{key}**: {json.dumps(value, indent=2)}\n\n"
                st.markdown(summary_text)
                st.markdown("Refine the Cleaning (Natural Language Instructions)")
                user_input = st.text_input("Example: Convert 'dob' to datetime and fill missing with '2000-01-01'",
                                           key=f"user_input_{name}")
                if f'corrections_{name}' not in st.session_state:
                    st.session_state[f'corrections_{name}'] = []
                if st.button("Apply Correction", key=f'apply_correction_{name}'):
                    if user_input.strip():
                        correction_prompt = f"""
                        You are a data cleaning expert. Below is a previously cleaned dataset with these actions:
                        {summary_text}
                        The user now wants the following additional instruction:
                        \"{user_input.strip()}\"
                        Write only the Python code that modifies the pandas DataFrame `df` accordingly.
                        Do not include explanations or markdown.
                        """
                        correction_code = query_openai(correction_prompt)
                        try:
                            df = st.session_state.cleaned_datasets[name].copy()
                            local_vars = {"df": df}
                            # SECURITY: exec() runs LLM-generated Python with no
                            # sandboxing — the model's output can do anything
                            # this process can. Acceptable only for trusted,
                            # single-user deployments.
                            exec(correction_code, {}, local_vars)
                            df_updated = local_vars["df"]
                            st.session_state.cleaned_datasets[name] = df_updated
                            st.session_state[f'corrections_{name}'].append((user_input, correction_code))
                            st.success("Correction applied.")
                            st.rerun()
                        except Exception as e:
                            st.error(f"Failed to apply correction: {str(e)}")
                if st.session_state[f'corrections_{name}']:
                    st.markdown("Applied Corrections")
                    # Audit trail: instruction plus the exact code that ran.
                    for i, (msg, code) in enumerate(st.session_state[f'corrections_{name}']):
                        st.markdown(f"**Instruction:** {msg}")
                        st.code(code, language='python')
            # Global controls (outside the per-dataset loop, hence no keys).
            col1, col2 = st.columns([1, 2])
            with col1:
                if st.button("Reset Cleaning and Re-run"):
                    st.session_state.cleaning_done = False
                    st.rerun()
            with col2:
                if st.button("Finalize and Proceed to Visualizations"):
                    st.session_state.cleaning_finalized = True
                    st.rerun()
    else:
        st.info("Please upload and process datasets first.")
with tab3:
    # Tab 3: conversational interface over the LangGraph workflow.
    st.header("Chat with AI Assistant")
    # Show API warning if not set
    if not st.session_state.api_key:
        st.warning("⚠️ Please set up your API key and model in the Settings tab before using the chat.")
    st.markdown("""
    ## Example Questions
    - "What analysis can you perform on this dataset?"
    - "Show me basic statistics for all columns"
    - "Create a correlation heatmap"
    - "Plot the distribution of a specific column"
    - "What is the relationship between two columns?"
    """)
    # Display chat history
    for exchange in st.session_state.chat_history:
        with st.chat_message("user"):
            st.write(exchange[0])
        with st.chat_message("assistant"):
            st.write(exchange[1])
    # Chat input; respond() appends the exchange and reruns the script.
    if prompt := st.chat_input("Your question"):
        with st.spinner("Thinking..."):
            respond(prompt)
with tab4:
    # Tab 4: auto-generated 2x2 dashboard plus a custom-plot generator.
    st.header("Auto Dashboard Generator")
    # Dashboard controls
    # NOTE(review): dashboard_title is collected but never used below — dead
    # input, or wiring to the report generator is missing; confirm intent.
    dashboard_title = st.text_input("Dashboard Title", placeholder="Enter your dashboard title")
    col1, col2 = st.columns(2)
    with col1:
        if st.button("Generate Suggested Dashboard (Auto)"):
            if not st.session_state.api_key:
                st.warning("⚠️ Please set up your API key and model in the Settings tab first.")
            else:
                with st.spinner("Generating dashboard..."):
                    # Fills st.session_state.dashboard_plots as a side effect.
                    message, figures = auto_generate_dashboard(st.session_state.dataset_metadata_list)
                    st.success(message)
    with col2:
        if st.button("Refresh Column Options"):
            st.session_state.columns = get_columns()
            st.rerun()
    # Dashboard display: four fixed slots arranged in two rows of two.
    st.subheader("Dashboard")
    # Row 1
    col1, col2 = st.columns(2)
    with col1:
        if st.session_state.dashboard_plots[0]:
            st.plotly_chart(st.session_state.dashboard_plots[0], use_container_width=True)
            if st.button("Remove Plot 1"):
                remove_plot(0)
                st.rerun()
    with col2:
        if st.session_state.dashboard_plots[1]:
            st.plotly_chart(st.session_state.dashboard_plots[1], use_container_width=True)
            if st.button("Remove Plot 2"):
                remove_plot(1)
                st.rerun()
    # Row 2
    col3, col4 = st.columns(2)
    with col3:
        if st.session_state.dashboard_plots[2]:
            st.plotly_chart(st.session_state.dashboard_plots[2], use_container_width=True)
            if st.button("Remove Plot 3"):
                remove_plot(2)
                st.rerun()
    with col4:
        if st.session_state.dashboard_plots[3]:
            st.plotly_chart(st.session_state.dashboard_plots[3], use_container_width=True)
            if st.button("Remove Plot 4"):
                remove_plot(3)
                st.rerun()
    # Custom plot generator
    st.subheader("Custom Plot Generator")
    # Column selection for the LLM prompt.
    col1, col2, col3 = st.columns(3)
    with col1:
        x_axis = st.selectbox("X-axis Column", options=st.session_state.columns)
    with col2:
        y_axis = st.selectbox("Y-axis Column", options=st.session_state.columns)
    with col3:
        facet = st.selectbox("Facet (optional)", options=["None"] + st.session_state.columns)
    if st.button("Generate Custom Visualizations"):
        if not st.session_state.api_key:
            st.warning("⚠️ Please set up your API key and model in the Settings tab first.")
        else:
            with st.spinner("Generating custom visualizations..."):
                custom_plots = generate_custom_plots_with_llm(st.session_state.dataset_metadata_list, x_axis, y_axis, facet)
                # Persist plots so the on_click callback can retrieve them
                # after the rerun the button press triggers.
                for i, plot in enumerate(custom_plots):
                    if plot:
                        st.session_state.custom_plots_to_save[i] = plot
                # Display custom plots with add buttons
                for i, plot in enumerate(custom_plots):
                    if plot:
                        st.plotly_chart(plot, use_container_width=True)
                        st.button(
                            f"Add Plot {i+1} to Dashboard",
                            key=f"add_plot_{i}",
                            on_click=save_plot_to_dashboard,
                            args=(i,)
                        )
with tab5:
    # Tab 5: PDF report generation and download.
    st.header("Generate Analysis Report")
    st.markdown("""
    Generate a PDF report containing:
    - Dashboard visualizations
    - Chat conversation history
    """)
    # NOTE(review): report_title is collected but never passed to
    # generate_enhanced_pdf_report() — confirm whether it should be.
    report_title = st.text_input("Report Title (Optional)", "Data Analysis Report")
    if st.button("Generate PDF Report"):
        if not st.session_state.api_key:
            st.warning("⚠️ Please set up your API key and model in the Settings tab first.")
        else:
            with st.spinner("Generating report..."):
                pdf_data = generate_enhanced_pdf_report()
                if pdf_data:
                    # Embed the PDF bytes as a base64 data-URI download link.
                    b64_pdf = base64.b64encode(pdf_data).decode('utf-8')
                    # Create download link
                    pdf_download_link = f'<a href="data:application/pdf;base64,{b64_pdf}" download="data_analysis_report.pdf">Download PDF Report</a>'
                    st.markdown("### Your report is ready!")
                    st.markdown(pdf_download_link, unsafe_allow_html=True)
                    # Preview option (simplified)
                    with st.expander("Preview Report"):
                        st.warning("PDF preview is not available in Streamlit, please download the report to view it.")
                else:
                    st.error("Failed to generate the report. Please try again.")
with tab6:
    # Tab 6: AI provider, API key, and model configuration.
    st.header("AI Provider Settings")
    # AI Provider selection
    provider = st.radio("Select AI Provider",
                        options=["OpenAI", "Groq"],
                        index=0 if st.session_state.ai_provider == "openai" else 1,
                        horizontal=True)
    # Persist the (lower-cased) provider immediately so the model list below
    # updates on the same rerun.
    st.session_state.ai_provider = provider.lower()
    # API Key input
    api_key = st.text_input("Enter API Key",
                            value=st.session_state.api_key,
                            type="password",
                            help="Your API key for the selected provider")
    # Display different model options based on provider
    if st.session_state.ai_provider == "openai":
        model_options = OPENAI_MODELS
        model_help = "GPT-4 provides the best results but is slower. GPT-3.5-Turbo is faster but less capable."
    else:  # groq
        model_options = GROQ_MODELS
        model_help = "Llama 3.3 70B is most capable. Gemma 2 9B offers good balance. Llama 3 8B is fastest."
    # Model selection; falls back to the first option if the previously
    # selected model belongs to the other provider.
    selected_model = st.selectbox("Select Model",
                                  options=model_options,
                                  index=model_options.index(st.session_state.selected_model) if st.session_state.selected_model in model_options else 0,
                                  help=model_help)
    # Save button: persists key/model, then smoke-tests the configuration.
    if st.button("Save Settings"):
        st.session_state.api_key = api_key
        st.session_state.selected_model = selected_model
        # Test the API key and model
        try:
            # Initialize LLM using the provided settings
            test_llm = initialize_llm()
            if test_llm:
                st.success(f"✅ Successfully configured {provider} with model: {selected_model}")
            else:
                st.error("Failed to initialize the AI provider. Please check your API key and model selection.")
        except Exception as e:
            st.error(f"Error testing settings: {str(e)}")
    # Display current settings
    st.subheader("Current Settings")
    settings_info = f"""
    - **Provider**: {st.session_state.ai_provider.upper()}
    - **Model**: {st.session_state.selected_model}
    - **API Key**: {'✅ Set' if st.session_state.api_key else '❌ Not Set'}
    """
    st.markdown(settings_info)
    # Provider-specific information
    if st.session_state.ai_provider == "openai":
        st.info("""
        **OpenAI Models Information:**
        - **GPT-4**: Most powerful model, best for complex analysis and detailed explanations
        - **GPT-4-Turbo**: Faster than GPT-4 with similar capabilities
        - **GPT-4-Mini**: Economical option with good performance for standard tasks
        - **GPT-3.5-Turbo**: Fastest option, suitable for basic analysis and visualization
        """)
    else:
        st.info("""
        **Groq Models Information:**
        - **llama3.3-70b-versatile**: Most powerful model for comprehensive analysis
        - **gemma2-9b-it**: Good balance of speed and capabilities
        - **llama-3-8b-8192**: Fastest option for basic analysis tasks
        """)
    # Integration instructions
    with st.expander("How to get API Keys"):
        if st.session_state.ai_provider == "openai":
            st.markdown("""
            ### Getting an OpenAI API Key
            1. Go to [OpenAI's platform](https://platform.openai.com)
            2. Sign up or log in to your account
            3. Navigate to the API section
            4. Create a new API key
            5. Copy the key and paste it above
            Note: OpenAI API usage incurs charges based on tokens used.
            """)
        else:
            st.markdown("""
            ### Getting a Groq API Key
            1. Go to [Groq's website](https://console.groq.com/keys)
            2. Sign up or log in to your account
            3. Navigate to API Keys section
            4. Create a new API key
            5. Copy the key and paste it above
            Note: Check Groq's pricing page for current rates.
            """)
# Cleanup on app exit
def cleanup():
    """Best-effort removal of this session's temporary working directory.

    Any failure (missing session state, permission error, already removed)
    is swallowed and logged — cleanup must never raise at shutdown.
    """
    try:
        temp_dir = st.session_state.temp_dir
        shutil.rmtree(temp_dir)
        print(f"Cleaned up temporary directory: {temp_dir}")
    except Exception as err:
        print(f"Error cleaning up: {err}")
# Register cleanup() for interpreter shutdown.
# NOTE(review): atexit fires once per *process*, not per Streamlit session —
# presumably fine for single-user runs; confirm before multi-user deployment.
import atexit
atexit.register(cleanup)