# insightgenai/modules/chat_engine.py
# Uploaded by mohsinbhatti — "Initial commit - InsightGenAI files" (commit e478478)
"""
Chat Engine Module - InsightGenAI
=================================
Natural language interface for data querying.
Converts natural language questions to pandas queries.
Includes fallback to LLM API for complex queries.
Author: InsightGenAI Team
Version: 1.0.0
"""
import pandas as pd
import numpy as np
import re
from typing import Dict, List, Tuple, Optional, Any, Callable
import streamlit as st
import json
import os
class ChatEngine:
"""
Natural language chat interface for data analysis.
Supports:
- Pattern-based query parsing
- Pandas code generation
- LLM API fallback for complex queries
"""
# Query patterns for common data operations
QUERY_PATTERNS = {
# Summary queries
'show_head': {
'patterns': [
r'show (?:me )?(?:the )?(?:first )?(\d+ )?rows?',
r'display (?:the )?(?:first )?(\d+ )?rows?',
r'head (?:of )?(?:the )?data',
r'show (?:me )?the (?:beginning|start)'
],
'handler': '_handle_show_head'
},
'show_tail': {
'patterns': [
r'show (?:me )?(?:the )?last (\d+ )?rows?',
r'display (?:the )?last (\d+ )?rows?',
r'tail (?:of )?(?:the )?data',
r'show (?:me )?the end'
],
'handler': '_handle_show_tail'
},
'show_shape': {
'patterns': [
r'how many rows',
r'how many columns',
r'what is the shape',
r'size of (?:the )?data',
r'dimensions? of (?:the )?data'
],
'handler': '_handle_show_shape'
},
'show_info': {
'patterns': [
r'show (?:me )?info',
r'data types?',
r'column types?',
r'what columns',
r'list columns'
],
'handler': '_handle_show_info'
},
'show_describe': {
'patterns': [
r'describe (?:the )?data',
r'summary statistics?',
r'statistical summary',
r'basic statistics?'
],
'handler': '_handle_show_describe'
},
# Column-specific queries
'column_stats': {
'patterns': [
r'stats (?:for |of )?(?:column )?([\w\s]+)',
r'statistics (?:for |of )?(?:column )?([\w\s]+)',
r'describe (?:column )?([\w\s]+)',
r'info (?:about |on )?(?:column )?([\w\s]+)'
],
'handler': '_handle_column_stats'
},
'column_mean': {
'patterns': [
r'(?:what is |calculate )?(?:the )?mean (?:of |for )?(?:column )?([\w\s]+)',
r'(?:what is |calculate )?(?:the )?average (?:of |for )?(?:column )?([\w\s]+)',
r'average (?:of |for )?([\w\s]+)'
],
'handler': '_handle_column_mean'
},
'column_sum': {
'patterns': [
r'(?:what is |calculate )?(?:the )?sum (?:of |for )?(?:column )?([\w\s]+)',
r'total (?:of |for )?([\w\s]+)',
r'sum (?:of |for )?([\w\s]+)'
],
'handler': '_handle_column_sum'
},
'column_max': {
'patterns': [
r'(?:what is |find )?(?:the )?max(?:imum)? (?:of |for )?(?:column )?([\w\s]+)',
r'highest (?:value (?:in |of )?)?([\w\s]+)',
r'max (?:of |for )?([\w\s]+)'
],
'handler': '_handle_column_max'
},
'column_min': {
'patterns': [
r'(?:what is |find )?(?:the )?min(?:imum)? (?:of |for )?(?:column )?([\w\s]+)',
r'lowest (?:value (?:in |of )?)?([\w\s]+)',
r'min (?:of |for )?([\w\s]+)'
],
'handler': '_handle_column_min'
},
'value_counts': {
'patterns': [
r'value counts? (?:for |of )?(?:column )?([\w\s]+)',
r'unique values? (?:in |of )?([\w\s]+)',
r'how many unique (?:values )?(?:in )?([\w\s]+)',
r'frequency (?:of |for )?([\w\s]+)'
],
'handler': '_handle_value_counts'
},
# Filtering queries
'filter_greater': {
'patterns': [
r'show (?:rows? )?where ([\w\s]+) (?:is )?greater than (\d+\.?\d*)',
r'show (?:rows? )?where ([\w\s]+) (?:is )?more than (\d+\.?\d*)',
r'show (?:rows? )?where ([\w\s]+) > (\d+\.?\d*)',
r'filter ([\w\s]+) > (\d+\.?\d*)'
],
'handler': '_handle_filter_greater'
},
'filter_less': {
'patterns': [
r'show (?:rows? )?where ([\w\s]+) (?:is )?less than (\d+\.?\d*)',
r'show (?:rows? )?where ([\w\s]+) (?:is )?fewer than (\d+\.?\d*)',
r'show (?:rows? )?where ([\w\s]+) < (\d+\.?\d*)',
r'filter ([\w\s]+) < (\d+\.?\d*)'
],
'handler': '_handle_filter_less'
},
'filter_equal': {
'patterns': [
r'show (?:rows? )?where ([\w\s]+) (?:is |equals? )?([\w\s]+)',
r'show (?:rows? )?where ([\w\s]+) = ([\w\s]+)',
r'filter ([\w\s]+) = ([\w\s]+)'
],
'handler': '_handle_filter_equal'
},
'top_n': {
'patterns': [
r'top (\d+) (?:by |sorted by )?([\w\s]+)',
r'show (?:me )?top (\d+)',
r'highest (\d+) (?:by )?([\w\s]+)'
],
'handler': '_handle_top_n'
},
# Grouping queries
'group_by': {
'patterns': [
r'group (?:by )?([\w\s]+) (?:and )?(?:calculate )?(?:the )?(mean|sum|count|avg|average|max|min)?',
r'aggregate (?:by )?([\w\s]+)',
r'([\w\s]+) (?:grouped |aggregation )by ([\w\s]+)'
],
'handler': '_handle_group_by'
},
# Correlation queries
'correlation': {
'patterns': [
r'correlation (?:between )?([\w\s]+) (?:and )?([\w\s]+)',
r'correlate ([\w\s]+) (?:with |and )?([\w\s]+)',
r'how (?:are |is )?([\w\s]+) (?:and )?([\w\s]+) related'
],
'handler': '_handle_correlation'
},
# Missing values
'missing_values': {
'patterns': [
r'missing values?',
r'null values?',
r'how many missing',
r'na values?'
],
'handler': '_handle_missing_values'
},
# Duplicates
'duplicates': {
'patterns': [
r'duplicate rows?',
r'how many duplicates',
r'are there duplicates'
],
'handler': '_handle_duplicates'
}
}
def __init__(self, df: pd.DataFrame, column_types: Optional[Dict[str, str]] = None):
"""
Initialize the Chat Engine.
Args:
df: Dataset to query
column_types: Dictionary of column types
"""
self.df = df.copy()
self.column_types = column_types or {}
self.chat_history: List[Dict[str, str]] = []
self.llm_api_key: Optional[str] = None
self.llm_provider: str = 'openai' # or 'huggingface'
def set_llm_config(self, api_key: str, provider: str = 'openai') -> None:
"""
Configure LLM API for fallback queries.
Args:
api_key: API key for the LLM service
provider: LLM provider ('openai' or 'huggingface')
"""
self.llm_api_key = api_key
self.llm_provider = provider
def process_query(self, query: str) -> Dict[str, Any]:
"""
Process a natural language query.
Args:
query: Natural language query string
Returns:
Dict with response data
"""
query_lower = query.lower().strip()
# Try pattern matching first
for query_type, config in self.QUERY_PATTERNS.items():
for pattern in config['patterns']:
match = re.search(pattern, query_lower)
if match:
handler = getattr(self, config['handler'])
result = handler(match)
# Add to chat history
self.chat_history.append({
'query': query,
'response_type': 'pattern',
'result': result
})
return {
'success': True,
'type': query_type,
'result': result,
'method': 'pattern'
}
# Fallback to LLM if configured
if self.llm_api_key:
return self._query_llm(query)
# No match found
return {
'success': False,
'error': "I couldn't understand that query. Try rephrasing or use simpler terms.",
'suggestions': self._get_suggestions()
}
def _get_suggestions(self) -> List[str]:
"""Get query suggestions for the user."""
return [
"Show me the first 10 rows",
"What is the average of [column_name]?",
"Show rows where [column] > 100",
"Group by [column] and calculate mean",
"What is the correlation between [col1] and [col2]?",
"Show missing values"
]
# Pattern handlers
def _handle_show_head(self, match) -> Dict:
"""Handle show head query."""
n = int(match.group(1)) if match.group(1) else 5
return {
'data': self.df.head(n),
'message': f"Showing first {min(n, len(self.df))} rows"
}
def _handle_show_tail(self, match) -> Dict:
"""Handle show tail query."""
n = int(match.group(1)) if match.group(1) else 5
return {
'data': self.df.tail(n),
'message': f"Showing last {min(n, len(self.df))} rows"
}
def _handle_show_shape(self, match) -> Dict:
"""Handle shape query."""
rows, cols = self.df.shape
return {
'message': f"The dataset has {rows:,} rows and {cols} columns",
'shape': (rows, cols)
}
def _handle_show_info(self, match) -> Dict:
"""Handle info query."""
info_df = pd.DataFrame({
'Column': self.df.columns,
'Type': self.df.dtypes.values,
'Non-Null Count': self.df.count().values,
'Null Count': self.df.isnull().sum().values
})
return {
'data': info_df,
'message': f"Dataset has {len(self.df.columns)} columns"
}
def _handle_show_describe(self, match) -> Dict:
"""Handle describe query."""
return {
'data': self.df.describe(),
'message': "Statistical summary of numeric columns"
}
def _handle_column_stats(self, match) -> Dict:
"""Handle column stats query."""
col = match.group(1).strip()
# Find closest column name
col = self._find_column(col)
if col and col in self.df.columns:
stats = self.df[col].describe()
return {
'data': stats,
'message': f"Statistics for column '{col}'"
}
return {'error': f"Column '{col}' not found"}
def _handle_column_mean(self, match) -> Dict:
"""Handle column mean query."""
col = match.group(1).strip()
col = self._find_column(col)
if col and col in self.df.columns:
mean_val = self.df[col].mean()
return {
'message': f"Mean of '{col}': {mean_val:.4f}",
'value': mean_val
}
return {'error': f"Column '{col}' not found"}
def _handle_column_sum(self, match) -> Dict:
"""Handle column sum query."""
col = match.group(1).strip()
col = self._find_column(col)
if col and col in self.df.columns:
sum_val = self.df[col].sum()
return {
'message': f"Sum of '{col}': {sum_val:,.2f}",
'value': sum_val
}
return {'error': f"Column '{col}' not found"}
def _handle_column_max(self, match) -> Dict:
"""Handle column max query."""
col = match.group(1).strip()
col = self._find_column(col)
if col and col in self.df.columns:
max_val = self.df[col].max()
return {
'message': f"Maximum of '{col}': {max_val}",
'value': max_val
}
return {'error': f"Column '{col}' not found"}
def _handle_column_min(self, match) -> Dict:
"""Handle column min query."""
col = match.group(1).strip()
col = self._find_column(col)
if col and col in self.df.columns:
min_val = self.df[col].min()
return {
'message': f"Minimum of '{col}': {min_val}",
'value': min_val
}
return {'error': f"Column '{col}' not found"}
def _handle_value_counts(self, match) -> Dict:
"""Handle value counts query."""
col = match.group(1).strip()
col = self._find_column(col)
if col and col in self.df.columns:
counts = self.df[col].value_counts().head(10)
return {
'data': counts,
'message': f"Top 10 values in '{col}'"
}
return {'error': f"Column '{col}' not found"}
def _handle_filter_greater(self, match) -> Dict:
"""Handle filter greater than query."""
col = match.group(1).strip()
value = float(match.group(2))
col = self._find_column(col)
if col and col in self.df.columns:
filtered = self.df[self.df[col] > value]
return {
'data': filtered.head(20),
'message': f"Found {len(filtered)} rows where '{col}' > {value}"
}
return {'error': f"Column '{col}' not found"}
def _handle_filter_less(self, match) -> Dict:
"""Handle filter less than query."""
col = match.group(1).strip()
value = float(match.group(2))
col = self._find_column(col)
if col and col in self.df.columns:
filtered = self.df[self.df[col] < value]
return {
'data': filtered.head(20),
'message': f"Found {len(filtered)} rows where '{col}' < {value}"
}
return {'error': f"Column '{col}' not found"}
def _handle_filter_equal(self, match) -> Dict:
"""Handle filter equal query."""
col = match.group(1).strip()
value = match.group(2).strip()
col = self._find_column(col)
if col and col in self.df.columns:
# Try to convert value to appropriate type
try:
value = float(value)
except:
pass
filtered = self.df[self.df[col] == value]
return {
'data': filtered.head(20),
'message': f"Found {len(filtered)} rows where '{col}' = '{value}'"
}
return {'error': f"Column '{col}' not found"}
def _handle_top_n(self, match) -> Dict:
"""Handle top N query."""
n = int(match.group(1))
col = match.group(2).strip() if match.group(2) else self.df.columns[0]
col = self._find_column(col)
if col and col in self.df.columns:
top_n = self.df.nlargest(n, col)
return {
'data': top_n,
'message': f"Top {n} rows by '{col}'"
}
return {'error': f"Column '{col}' not found"}
def _handle_group_by(self, match) -> Dict:
"""Handle group by query."""
col = match.group(1).strip()
agg_func = match.group(2) if match.group(2) else 'mean'
col = self._find_column(col)
if col and col in self.df.columns:
agg_map = {
'mean': 'mean', 'avg': 'mean', 'average': 'mean',
'sum': 'sum', 'count': 'count',
'max': 'max', 'min': 'min'
}
func = agg_map.get(agg_func, 'mean')
numeric_cols = self.df.select_dtypes(include=[np.number]).columns
grouped = self.df.groupby(col)[numeric_cols].agg(func)
return {
'data': grouped.head(20),
'message': f"Grouped by '{col}' with {func} aggregation"
}
return {'error': f"Column '{col}' not found"}
def _handle_correlation(self, match) -> Dict:
"""Handle correlation query."""
col1 = match.group(1).strip()
col2 = match.group(2).strip()
col1 = self._find_column(col1)
col2 = self._find_column(col2)
if col1 in self.df.columns and col2 in self.df.columns:
corr = self.df[col1].corr(self.df[col2])
return {
'message': f"Correlation between '{col1}' and '{col2}': {corr:.4f}",
'value': corr
}
return {'error': f"One or both columns not found"}
def _handle_missing_values(self, match) -> Dict:
"""Handle missing values query."""
missing = self.df.isnull().sum()
missing = missing[missing > 0]
if len(missing) > 0:
return {
'data': missing,
'message': f"Found missing values in {len(missing)} columns"
}
return {'message': "No missing values found! 🎉"}
def _handle_duplicates(self, match) -> Dict:
"""Handle duplicates query."""
n_duplicates = self.df.duplicated().sum()
return {
'message': f"Found {n_duplicates} duplicate rows",
'count': n_duplicates
}
def _find_column(self, col_name: str) -> Optional[str]:
"""
Find the closest matching column name.
Args:
col_name: Column name to find
Returns:
Actual column name or None
"""
col_name = col_name.lower().strip()
# Exact match
for col in self.df.columns:
if col.lower() == col_name:
return col
# Substring match
for col in self.df.columns:
if col_name in col.lower() or col.lower() in col_name:
return col
return None
def _query_llm(self, query: str) -> Dict[str, Any]:
"""
Query LLM API for complex questions.
Args:
query: Natural language query
Returns:
Dict with LLM response
"""
if self.llm_provider == 'openai':
return self._query_openai(query)
else:
return self._query_huggingface(query)
def _query_openai(self, query: str) -> Dict[str, Any]:
"""Query OpenAI API."""
try:
import openai
openai.api_key = self.llm_api_key
# Create context about the dataset
columns_info = "\n".join([
f"- {col} ({self.df[col].dtype})"
for col in self.df.columns[:20] # Limit to first 20 columns
])
prompt = f"""You are a data analysis assistant. Answer the following question about a dataset.
Dataset Information:
- Shape: {self.df.shape}
- Columns:
{columns_info}
User Question: {query}
Provide a clear, concise answer based on the dataset structure."""
response = openai.ChatCompletion.create(
model="gpt-3.5-turbo",
messages=[
{"role": "system", "content": "You are a helpful data analysis assistant."},
{"role": "user", "content": prompt}
],
max_tokens=500
)
answer = response.choices[0].message.content
return {
'success': True,
'type': 'llm_response',
'result': {'message': answer},
'method': 'llm'
}
except Exception as e:
return {
'success': False,
'error': f"LLM query failed: {str(e)}"
}
def _query_huggingface(self, query: str) -> Dict[str, Any]:
"""Query HuggingFace Inference API."""
try:
import requests
API_URL = "https://api-inference.huggingface.co/models/google/flan-t5-large"
headers = {"Authorization": f"Bearer {self.llm_api_key}"}
payload = {
"inputs": f"Answer this data question: {query}",
"parameters": {"max_length": 200}
}
response = requests.post(API_URL, headers=headers, json=payload)
result = response.json()
if isinstance(result, list) and len(result) > 0:
answer = result[0].get('generated_text', 'No response')
else:
answer = str(result)
return {
'success': True,
'type': 'llm_response',
'result': {'message': answer},
'method': 'llm'
}
except Exception as e:
return {
'success': False,
'error': f"HuggingFace query failed: {str(e)}"
}
def get_chat_history(self) -> List[Dict[str, str]]:
"""Get the chat history."""
return self.chat_history
def clear_history(self) -> None:
"""Clear the chat history."""
self.chat_history = []
# Streamlit display functions
def display_chat_interface(df: pd.DataFrame, column_types: Optional[Dict[str, str]] = None):
    """Display the chat interface in Streamlit.

    Renders an optional LLM-provider configuration expander, a free-text
    query box wired to :class:`ChatEngine`, and an example-query expander.

    Args:
        df: Dataset to query.
        column_types: Optional column-type mapping forwarded to ChatEngine.
    """
    st.subheader("💬 Chat With Your Data")

    # Re-create the engine whenever the dataframe changes. Caching purely on
    # key presence (the previous behavior) kept answering questions about a
    # stale dataset after the caller loaded a new one.
    df_fingerprint = (df.shape, tuple(df.columns))
    if ('chat_engine' not in st.session_state
            or st.session_state.get('chat_engine_df_fingerprint') != df_fingerprint):
        st.session_state.chat_engine = ChatEngine(df, column_types)
        st.session_state.chat_engine_df_fingerprint = df_fingerprint
    chat_engine = st.session_state.chat_engine

    # LLM configuration
    with st.expander("⚙️ LLM Configuration (Optional)"):
        col1, col2 = st.columns(2)
        with col1:
            provider = st.selectbox(
                "LLM Provider",
                options=['None', 'openai', 'huggingface'],
                help="Select LLM provider for complex queries"
            )
        with col2:
            if provider != 'None':
                api_key = st.text_input(
                    "API Key",
                    type="password",
                    help=f"Enter your {provider} API key"
                )
                if api_key:
                    chat_engine.set_llm_config(api_key, provider)

    # Chat input
    query = st.text_input(
        "Ask a question about your data",
        placeholder="e.g., 'What is the average age?' or 'Show rows where salary > 50000'"
    )
    if st.button("Ask", type="primary") and query:
        with st.spinner("Processing..."):
            response = chat_engine.process_query(query)
        if response['success']:
            result = response['result']
            # Display message
            if 'message' in result:
                st.info(result['message'])
            # Display data
            if 'data' in result:
                st.dataframe(result['data'], use_container_width=True)
            # Display single value (floats formatted, everything else raw)
            if 'value' in result:
                st.metric(
                    "Result",
                    f"{result['value']:.4f}" if isinstance(result['value'], float) else result['value']
                )
        else:
            st.error(response.get('error', 'Unknown error'))
            if 'suggestions' in response:
                st.write("Try these queries:")
                for suggestion in response['suggestions']:
                    st.code(suggestion)

    # Example queries
    with st.expander("📖 Example Queries"):
        st.markdown("""
**Basic Queries:**
- `show me the first 10 rows`
- `how many rows and columns?`
- `describe the data`

**Column Queries:**
- `what is the average of [column]?`
- `what is the maximum of [column]?`
- `show value counts for [column]`

**Filtering:**
- `show rows where [column] > 100`
- `show rows where [column] = value`
- `top 10 by [column]`

**Aggregation:**
- `group by [column] and calculate mean`
- `correlation between [col1] and [col2]`

**Data Quality:**
- `show missing values`
- `how many duplicates?`
""")