"""
Chat Engine Module - InsightGenAI
=================================
Natural language interface for data querying.

Converts natural language questions to pandas queries.
Includes fallback to LLM API for complex queries.

Author: InsightGenAI Team
Version: 1.0.0
"""

import json
import os
import re
from typing import Any, Callable, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd

try:
    import streamlit as st
except ImportError:
    # Only display_chat_interface needs Streamlit; ChatEngine itself works
    # without it, so a missing UI dependency must not break the import.
    st = None


class ChatEngine:
    """
    Natural language chat interface for data analysis.

    Supports:
    - Pattern-based query parsing
    - Pandas code generation
    - LLM API fallback for complex queries

    Queries are lower-cased and matched against ``QUERY_PATTERNS`` in
    declaration order; the first matching pattern decides which handler
    runs. If nothing matches and an LLM key is configured, the question is
    forwarded to the configured LLM provider.
    """

    # Query patterns for common data operations.
    # NOTE: order matters - process_query stops at the first match.
    QUERY_PATTERNS = {
        # Summary queries
        'show_head': {
            'patterns': [
                # The lookahead keeps "show rows where ..." from being
                # swallowed here before the filter patterns get a chance.
                r'show (?:me )?(?:the )?(?:first )?(\d+ )?rows?\b(?!\s+where\b)',
                r'display (?:the )?(?:first )?(\d+ )?rows?\b(?!\s+where\b)',
                r'head (?:of )?(?:the )?data',
                r'show (?:me )?the (?:beginning|start)'
            ],
            'handler': '_handle_show_head'
        },
        'show_tail': {
            'patterns': [
                r'show (?:me )?(?:the )?last (\d+ )?rows?',
                r'display (?:the )?last (\d+ )?rows?',
                r'tail (?:of )?(?:the )?data',
                r'show (?:me )?the end'
            ],
            'handler': '_handle_show_tail'
        },
        'show_shape': {
            'patterns': [
                r'how many rows',
                r'how many columns',
                r'what is the shape',
                r'size of (?:the )?data',
                r'dimensions? of (?:the )?data'
            ],
            'handler': '_handle_show_shape'
        },
        'show_info': {
            'patterns': [
                r'show (?:me )?info',
                r'data types?',
                r'column types?',
                r'what columns',
                r'list columns'
            ],
            'handler': '_handle_show_info'
        },
        'show_describe': {
            'patterns': [
                r'describe (?:the )?data',
                r'summary statistics?',
                r'statistical summary',
                r'basic statistics?'
            ],
            'handler': '_handle_show_describe'
        },

        # Column-specific queries
        'column_stats': {
            'patterns': [
                r'stats (?:for |of )?(?:column )?([\w\s]+)',
                r'statistics (?:for |of )?(?:column )?([\w\s]+)',
                r'describe (?:column )?([\w\s]+)',
                r'info (?:about |on )?(?:column )?([\w\s]+)'
            ],
            'handler': '_handle_column_stats'
        },
        'column_mean': {
            'patterns': [
                r'(?:what is |calculate )?(?:the )?mean (?:of |for )?(?:column )?([\w\s]+)',
                r'(?:what is |calculate )?(?:the )?average (?:of |for )?(?:column )?([\w\s]+)',
                r'average (?:of |for )?([\w\s]+)'
            ],
            'handler': '_handle_column_mean'
        },
        'column_sum': {
            'patterns': [
                r'(?:what is |calculate )?(?:the )?sum (?:of |for )?(?:column )?([\w\s]+)',
                r'total (?:of |for )?([\w\s]+)',
                r'sum (?:of |for )?([\w\s]+)'
            ],
            'handler': '_handle_column_sum'
        },
        'column_max': {
            'patterns': [
                r'(?:what is |find )?(?:the )?max(?:imum)? (?:of |for )?(?:column )?([\w\s]+)',
                r'highest (?:value (?:in |of )?)?([\w\s]+)',
                r'max (?:of |for )?([\w\s]+)'
            ],
            'handler': '_handle_column_max'
        },
        'column_min': {
            'patterns': [
                r'(?:what is |find )?(?:the )?min(?:imum)? (?:of |for )?(?:column )?([\w\s]+)',
                r'lowest (?:value (?:in |of )?)?([\w\s]+)',
                r'min (?:of |for )?([\w\s]+)'
            ],
            'handler': '_handle_column_min'
        },
        'value_counts': {
            'patterns': [
                r'value counts? (?:for |of )?(?:column )?([\w\s]+)',
                r'unique values? (?:in |of )?([\w\s]+)',
                r'how many unique (?:values )?(?:in )?([\w\s]+)',
                r'frequency (?:of |for )?([\w\s]+)'
            ],
            'handler': '_handle_value_counts'
        },

        # Filtering queries
        'filter_greater': {
            'patterns': [
                r'show (?:rows? )?where ([\w\s]+) (?:is )?greater than (\d+\.?\d*)',
                r'show (?:rows? )?where ([\w\s]+) (?:is )?more than (\d+\.?\d*)',
                r'show (?:rows? )?where ([\w\s]+) > (\d+\.?\d*)',
                r'filter ([\w\s]+) > (\d+\.?\d*)'
            ],
            'handler': '_handle_filter_greater'
        },
        'filter_less': {
            'patterns': [
                r'show (?:rows? )?where ([\w\s]+) (?:is )?less than (\d+\.?\d*)',
                r'show (?:rows? )?where ([\w\s]+) (?:is )?fewer than (\d+\.?\d*)',
                r'show (?:rows? )?where ([\w\s]+) < (\d+\.?\d*)',
                r'filter ([\w\s]+) < (\d+\.?\d*)'
            ],
            'handler': '_handle_filter_less'
        },
        'filter_equal': {
            'patterns': [
                r'show (?:rows? )?where ([\w\s]+) (?:is |equals? )?([\w\s]+)',
                r'show (?:rows? )?where ([\w\s]+) = ([\w\s]+)',
                r'filter ([\w\s]+) = ([\w\s]+)'
            ],
            'handler': '_handle_filter_equal'
        },
        'top_n': {
            'patterns': [
                r'top (\d+) (?:by |sorted by )?([\w\s]+)',
                r'show (?:me )?top (\d+)',
                r'highest (\d+) (?:by )?([\w\s]+)'
            ],
            'handler': '_handle_top_n'
        },

        # Grouping queries. Named groups are used because the three phrasings
        # put the grouping column at different positions.
        'group_by': {
            'patterns': [
                # The column is matched lazily and the aggregation clause is
                # optional as a whole, so a bare "group by region" works.
                r'group (?:by )?(?P<col>[\w\s]+?)'
                r'(?:\s+(?:and\s+)?(?:calculate\s+)?(?:the\s+)?'
                r'(?P<agg>mean|sum|count|avg|average|max|min))?\s*(?![\w\s])',
                r'aggregate (?:by )?(?P<col>[\w\s]+)',
                r'(?P<value>[\w\s]+) (?:grouped |aggregation )by (?P<col>[\w\s]+)'
            ],
            'handler': '_handle_group_by'
        },

        # Correlation queries
        'correlation': {
            'patterns': [
                r'correlation (?:between )?([\w\s]+) (?:and )?([\w\s]+)',
                r'correlate ([\w\s]+) (?:with |and )?([\w\s]+)',
                r'how (?:are |is )?([\w\s]+) (?:and )?([\w\s]+) related'
            ],
            'handler': '_handle_correlation'
        },

        # Missing values
        'missing_values': {
            'patterns': [
                r'missing values?',
                r'null values?',
                r'how many missing',
                r'na values?'
            ],
            'handler': '_handle_missing_values'
        },

        # Duplicates
        'duplicates': {
            'patterns': [
                r'duplicate rows?',
                r'how many duplicates',
                r'are there duplicates'
            ],
            'handler': '_handle_duplicates'
        }
    }

    def __init__(self, df: pd.DataFrame,
                 column_types: Optional[Dict[str, str]] = None):
        """
        Initialize the Chat Engine.

        Args:
            df: Dataset to query (copied, so the caller's frame is untouched)
            column_types: Dictionary of column types
        """
        self.df = df.copy()
        self.column_types = column_types or {}
        # Each entry holds the raw query, the response type and the result.
        self.chat_history: List[Dict[str, Any]] = []
        self.llm_api_key: Optional[str] = None
        self.llm_provider: str = 'openai'  # or 'huggingface'

    def set_llm_config(self, api_key: str, provider: str = 'openai') -> None:
        """
        Configure LLM API for fallback queries.

        Args:
            api_key: API key for the LLM service
            provider: LLM provider ('openai' or 'huggingface')
        """
        self.llm_api_key = api_key
        self.llm_provider = provider

    def process_query(self, query: str) -> Dict[str, Any]:
        """
        Process a natural language query.

        The lower-cased query is tested against every pattern in
        ``QUERY_PATTERNS``; the first match selects the handler. A handler
        that reports an ``'error'`` (or raises ``TypeError``/``ValueError``,
        e.g. a numeric operation on a text column) yields a failed response
        instead of a "successful" one with nothing to show.

        Args:
            query: Natural language query string

        Returns:
            Dict with response data. ``'success'`` is always present; on
            success ``'result'`` holds the handler output, on failure
            ``'error'`` holds a message and ``'suggestions'`` example queries.
        """
        query_lower = query.lower().strip()

        # Try pattern matching first
        for query_type, config in self.QUERY_PATTERNS.items():
            for pattern in config['patterns']:
                match = re.search(pattern, query_lower)
                if not match:
                    continue

                handler = getattr(self, config['handler'])
                try:
                    result = handler(match)
                except (TypeError, ValueError) as exc:
                    # pandas raises these for operations that do not fit the
                    # column's dtype; report them rather than crash the chat.
                    result = {'error': f"Could not run that query: {exc}"}

                # Add to chat history
                self.chat_history.append({
                    'query': query,
                    'response_type': 'pattern',
                    'result': result
                })

                if 'error' in result:
                    return {
                        'success': False,
                        'type': query_type,
                        'error': result['error'],
                        # Kept so callers inspecting 'result' still work.
                        'result': result,
                        'method': 'pattern',
                        'suggestions': self._get_suggestions()
                    }

                return {
                    'success': True,
                    'type': query_type,
                    'result': result,
                    'method': 'pattern'
                }

        # Fallback to LLM if configured
        if self.llm_api_key:
            return self._query_llm(query)

        # No match found
        return {
            'success': False,
            'error': "I couldn't understand that query. Try rephrasing or use simpler terms.",
            'suggestions': self._get_suggestions()
        }

    def _get_suggestions(self) -> List[str]:
        """Get query suggestions for the user."""
        return [
            "Show me the first 10 rows",
            "What is the average of [column_name]?",
            "Show rows where [column] > 100",
            "Group by [column] and calculate mean",
            "What is the correlation between [col1] and [col2]?",
            "Show missing values"
        ]

    # Shared helpers for the pattern handlers

    @staticmethod
    def _optional_group(match, key) -> Optional[str]:
        """
        Return a stripped capture group, or None when it is absent.

        Several patterns routed to the same handler define fewer groups than
        others; ``match.group`` raises IndexError for an undefined group and
        returns None for an unmatched optional one. Both cases map to None.

        Args:
            match: Regex match object
            key: Group index or group name
        """
        try:
            value = match.group(key)
        except IndexError:
            return None
        if value is None:
            return None
        value = value.strip()
        return value or None

    def _lookup_column(self, match, key=1) -> Tuple[Optional[str], str]:
        """
        Resolve the column named in a capture group.

        Returns:
            (actual column name or None, the text the user typed) - the
            latter is kept so error messages can quote it.
        """
        requested = self._optional_group(match, key) or ''
        return self._find_column(requested), requested

    @staticmethod
    def _column_not_found(requested: str) -> Dict:
        """Build the standard 'column not found' error result."""
        return {'error': f"Column '{requested}' not found"}

    def _non_numeric_error(self, col: str) -> Optional[Dict]:
        """Return an error result if ``col`` is not numeric, else None."""
        if pd.api.types.is_numeric_dtype(self.df[col]):
            return None
        return {'error': f"Column '{col}' is not numeric"}

    def _requested_row_count(self, match, default: int = 5) -> int:
        """Read the optional row count from group 1, falling back to default."""
        raw = self._optional_group(match, 1)
        return int(raw) if raw else default

    # Pattern handlers

    def _handle_show_head(self, match) -> Dict:
        """Handle show head query (defaults to 5 rows when no count given)."""
        n = self._requested_row_count(match)
        return {
            'data': self.df.head(n),
            'message': f"Showing first {min(n, len(self.df))} rows"
        }

    def _handle_show_tail(self, match) -> Dict:
        """Handle show tail query (defaults to 5 rows when no count given)."""
        n = self._requested_row_count(match)
        return {
            'data': self.df.tail(n),
            'message': f"Showing last {min(n, len(self.df))} rows"
        }

    def _handle_show_shape(self, match) -> Dict:
        """Handle shape query."""
        rows, cols = self.df.shape
        return {
            'message': f"The dataset has {rows:,} rows and {cols} columns",
            'shape': (rows, cols)
        }

    def _handle_show_info(self, match) -> Dict:
        """Handle info query: one row per column with dtype and null counts."""
        info_df = pd.DataFrame({
            'Column': self.df.columns,
            'Type': self.df.dtypes.values,
            'Non-Null Count': self.df.count().values,
            'Null Count': self.df.isnull().sum().values
        })
        return {
            'data': info_df,
            'message': f"Dataset has {len(self.df.columns)} columns"
        }

    def _handle_show_describe(self, match) -> Dict:
        """Handle describe query."""
        return {
            'data': self.df.describe(),
            'message': "Statistical summary of numeric columns"
        }

    def _handle_column_stats(self, match) -> Dict:
        """Handle column stats query."""
        col, requested = self._lookup_column(match)
        if col is None:
            return self._column_not_found(requested)
        return {
            'data': self.df[col].describe(),
            'message': f"Statistics for column '{col}'"
        }

    def _handle_column_mean(self, match) -> Dict:
        """Handle column mean query (numeric columns only)."""
        col, requested = self._lookup_column(match)
        if col is None:
            return self._column_not_found(requested)
        error = self._non_numeric_error(col)
        if error:
            return error
        mean_val = self.df[col].mean()
        return {
            'message': f"Mean of '{col}': {mean_val:.4f}",
            'value': mean_val
        }

    def _handle_column_sum(self, match) -> Dict:
        """Handle column sum query (numeric columns only)."""
        col, requested = self._lookup_column(match)
        if col is None:
            return self._column_not_found(requested)
        error = self._non_numeric_error(col)
        if error:
            return error
        sum_val = self.df[col].sum()
        return {
            'message': f"Sum of '{col}': {sum_val:,.2f}",
            'value': sum_val
        }

    def _handle_column_max(self, match) -> Dict:
        """Handle column max query."""
        col, requested = self._lookup_column(match)
        if col is None:
            return self._column_not_found(requested)
        max_val = self.df[col].max()
        return {
            'message': f"Maximum of '{col}': {max_val}",
            'value': max_val
        }

    def _handle_column_min(self, match) -> Dict:
        """Handle column min query."""
        col, requested = self._lookup_column(match)
        if col is None:
            return self._column_not_found(requested)
        min_val = self.df[col].min()
        return {
            'message': f"Minimum of '{col}': {min_val}",
            'value': min_val
        }

    def _handle_value_counts(self, match) -> Dict:
        """Handle value counts query (ten most frequent values)."""
        col, requested = self._lookup_column(match)
        if col is None:
            return self._column_not_found(requested)
        return {
            'data': self.df[col].value_counts().head(10),
            'message': f"Top 10 values in '{col}'"
        }

    def _filter_by_threshold(self, match, symbol: str) -> Dict:
        """
        Shared body of the greater-than / less-than filters.

        Args:
            match: Regex match; group 1 is the column, group 2 the number
            symbol: '>' or '<'

        Returns:
            Up to 20 matching rows plus a message with the total count.
        """
        col, requested = self._lookup_column(match)
        if col is None:
            return self._column_not_found(requested)
        error = self._non_numeric_error(col)
        if error:
            return error

        value = float(match.group(2))
        if symbol == '>':
            mask = self.df[col] > value
        else:
            mask = self.df[col] < value
        filtered = self.df[mask]
        return {
            'data': filtered.head(20),
            'message': f"Found {len(filtered)} rows where '{col}' {symbol} {value}"
        }

    def _handle_filter_greater(self, match) -> Dict:
        """Handle filter greater than query."""
        return self._filter_by_threshold(match, '>')

    def _handle_filter_less(self, match) -> Dict:
        """Handle filter less than query."""
        return self._filter_by_threshold(match, '<')

    def _handle_filter_equal(self, match) -> Dict:
        """
        Handle filter equal query.

        Numeric columns are compared numerically. Everything else is compared
        as text, case-insensitively: the query has already been lower-cased,
        so an exact comparison could never match mixed-case data.
        """
        col, requested = self._lookup_column(match)
        if col is None:
            return self._column_not_found(requested)

        raw_value = self._optional_group(match, 2) or ''
        series = self.df[col]

        try:
            numeric_value: Optional[float] = float(raw_value)
        except ValueError:
            numeric_value = None

        if numeric_value is not None and pd.api.types.is_numeric_dtype(series):
            value: Any = numeric_value
            mask = series == numeric_value
        else:
            value = raw_value
            mask = series.astype(str).str.lower() == raw_value.lower()

        filtered = self.df[mask]
        return {
            'data': filtered.head(20),
            'message': f"Found {len(filtered)} rows where '{col}' = '{value}'"
        }

    def _handle_top_n(self, match) -> Dict:
        """
        Handle top N query.

        When the query names no column ("show me top 5") the first numeric
        column is used, falling back to the first column of the dataset.
        """
        n = int(match.group(1))
        requested = self._optional_group(match, 2)

        if requested is None:
            if len(self.df.columns) == 0:
                return {'error': "The dataset has no columns"}
            numeric_cols = self.df.select_dtypes(include=[np.number]).columns
            col = numeric_cols[0] if len(numeric_cols) else self.df.columns[0]
        else:
            col = self._find_column(requested)
            if col is None:
                return self._column_not_found(requested)

        error = self._non_numeric_error(col)
        if error:
            return error
        return {
            'data': self.df.nlargest(n, col),
            'message': f"Top {n} rows by '{col}'"
        }

    def _handle_group_by(self, match) -> Dict:
        """
        Handle group by query.

        Reads the named groups ``col`` (grouping column), ``agg`` (optional
        aggregation, default mean) and ``value`` (optional column to
        aggregate, from the "X grouped by Y" phrasing). Patterns without
        named groups fall back to positional groups 1 (column) and 2 (agg).
        """
        if 'col' in match.re.groupindex:
            requested = self._optional_group(match, 'col') or ''
            agg_func = self._optional_group(match, 'agg')
            value_requested = self._optional_group(match, 'value')
        else:
            requested = self._optional_group(match, 1) or ''
            agg_func = self._optional_group(match, 2)
            value_requested = None

        col = self._find_column(requested)
        if col is None:
            return self._column_not_found(requested)

        agg_map = {
            'mean': 'mean', 'avg': 'mean', 'average': 'mean',
            'sum': 'sum', 'count': 'count', 'max': 'max', 'min': 'min'
        }
        func = agg_map.get(agg_func or 'mean', 'mean')

        # Aggregating the grouping key against itself is meaningless.
        numeric_cols = [
            c for c in self.df.select_dtypes(include=[np.number]).columns
            if c != col
        ]
        value_col = self._find_column(value_requested) if value_requested else None
        if value_col is not None and value_col in numeric_cols:
            numeric_cols = [value_col]
        if not numeric_cols:
            return {'error': "No numeric columns available to aggregate"}

        grouped = self.df.groupby(col)[numeric_cols].agg(func)
        return {
            'data': grouped.head(20),
            'message': f"Grouped by '{col}' with {func} aggregation"
        }

    def _handle_correlation(self, match) -> Dict:
        """Handle correlation query between two numeric columns."""
        col1 = self._find_column(self._optional_group(match, 1) or '')
        col2 = self._find_column(self._optional_group(match, 2) or '')
        if col1 is None or col2 is None:
            return {'error': "One or both columns not found"}

        for col in (col1, col2):
            error = self._non_numeric_error(col)
            if error:
                return error

        corr = self.df[col1].corr(self.df[col2])
        return {
            'message': f"Correlation between '{col1}' and '{col2}': {corr:.4f}",
            'value': corr
        }

    def _handle_missing_values(self, match) -> Dict:
        """Handle missing values query (lists only columns with nulls)."""
        missing = self.df.isnull().sum()
        missing = missing[missing > 0]
        if len(missing) > 0:
            return {
                'data': missing,
                'message': f"Found missing values in {len(missing)} columns"
            }
        return {'message': "No missing values found! 🎉"}

    def _handle_duplicates(self, match) -> Dict:
        """Handle duplicates query."""
        n_duplicates = int(self.df.duplicated().sum())
        return {
            'message': f"Found {n_duplicates} duplicate rows",
            'count': n_duplicates
        }

    def _find_column(self, col_name: str) -> Optional[str]:
        """
        Find the closest matching column name.

        Matching is case-insensitive: an exact match wins, otherwise the
        first column whose name contains, or is contained in, the requested
        text. Column labels are compared via ``str()`` so non-string labels
        do not break the lookup.

        Args:
            col_name: Column name to find

        Returns:
            Actual column name or None
        """
        if not col_name:
            return None
        wanted = str(col_name).lower().strip()
        if not wanted:
            # An empty string is a substring of everything; never match it.
            return None

        candidates = [(col, str(col).lower()) for col in self.df.columns]

        # Exact match
        for col, lowered in candidates:
            if lowered == wanted:
                return col

        # Substring match
        for col, lowered in candidates:
            if lowered and (wanted in lowered or lowered in wanted):
                return col

        return None

    def _query_llm(self, query: str) -> Dict[str, Any]:
        """
        Query LLM API for complex questions.

        Args:
            query: Natural language query

        Returns:
            Dict with LLM response
        """
        if self.llm_provider == 'openai':
            return self._query_openai(query)
        return self._query_huggingface(query)

    def _query_openai(self, query: str) -> Dict[str, Any]:
        """
        Query OpenAI API.

        Only the dataset's shape and the names/dtypes of its first 20 columns
        are sent as context. Works with both the openai>=1.0 client interface
        and the legacy module-level ``ChatCompletion`` interface.

        Returns:
            Success dict with the answer under result['message'], or a
            failure dict with an 'error' message (never raises).
        """
        try:
            import openai

            # Create context about the dataset
            columns_info = "\n".join([
                f"- {col} ({self.df[col].dtype})"
                for col in self.df.columns[:20]  # Limit to first 20 columns
            ])

            prompt = f"""You are a data analysis assistant. Answer the following question about a dataset.

Dataset Information:
- Shape: {self.df.shape}
- Columns:
{columns_info}

User Question: {query}

Provide a clear, concise answer based on the dataset structure."""

            messages = [
                {"role": "system", "content": "You are a helpful data analysis assistant."},
                {"role": "user", "content": prompt}
            ]

            if hasattr(openai, 'OpenAI'):
                # openai>=1.0 removed openai.ChatCompletion.
                client = openai.OpenAI(api_key=self.llm_api_key)
                response = client.chat.completions.create(
                    model="gpt-3.5-turbo",
                    messages=messages,
                    max_tokens=500
                )
            else:
                openai.api_key = self.llm_api_key
                response = openai.ChatCompletion.create(
                    model="gpt-3.5-turbo",
                    messages=messages,
                    max_tokens=500
                )

            answer = response.choices[0].message.content

            return {
                'success': True,
                'type': 'llm_response',
                'result': {'message': answer},
                'method': 'llm'
            }

        except Exception as e:
            # Boundary with a third-party service: surface any failure to
            # the user as a message instead of crashing the chat.
            return {
                'success': False,
                'error': f"LLM query failed: {str(e)}"
            }

    def _query_huggingface(self, query: str) -> Dict[str, Any]:
        """
        Query HuggingFace Inference API.

        Returns:
            Success dict with the answer under result['message'], or a
            failure dict with an 'error' message (never raises).
        """
        try:
            import requests

            API_URL = "https://api-inference.huggingface.co/models/google/flan-t5-large"
            headers = {"Authorization": f"Bearer {self.llm_api_key}"}

            payload = {
                "inputs": f"Answer this data question: {query}",
                "parameters": {"max_length": 200}
            }

            # Without a timeout a stalled connection would hang the app.
            response = requests.post(API_URL, headers=headers, json=payload,
                                     timeout=30)
            result = response.json()

            # The API reports problems as a JSON object with an 'error' key;
            # that must not be presented as a successful answer.
            if isinstance(result, dict) and 'error' in result:
                return {
                    'success': False,
                    'error': f"HuggingFace query failed: {result['error']}"
                }
            response.raise_for_status()

            if isinstance(result, list) and len(result) > 0:
                answer = result[0].get('generated_text', 'No response')
            else:
                answer = str(result)

            return {
                'success': True,
                'type': 'llm_response',
                'result': {'message': answer},
                'method': 'llm'
            }

        except Exception as e:
            # Boundary with a third-party service: report, do not crash.
            return {
                'success': False,
                'error': f"HuggingFace query failed: {str(e)}"
            }

    def get_chat_history(self) -> List[Dict[str, Any]]:
        """Get the chat history."""
        return self.chat_history

    def clear_history(self) -> None:
        """Clear the chat history."""
        self.chat_history = []


# Streamlit display functions

def display_chat_interface(df: pd.DataFrame,
                           column_types: Optional[Dict[str, str]] = None):
    """
    Display chat interface in Streamlit.

    The ChatEngine is kept in ``st.session_state`` across reruns and is
    rebuilt when the DataFrame passed in no longer equals the one the cached
    engine holds.

    Args:
        df: Dataset to query
        column_types: Dictionary of column types

    Raises:
        ImportError: If Streamlit is not installed.
    """
    if st is None:
        raise ImportError(
            "Streamlit is required for display_chat_interface; "
            "install it with 'pip install streamlit'."
        )

    st.subheader("💬 Chat With Your Data")

    # Initialize chat engine (or replace a stale one built for other data)
    if ('chat_engine' not in st.session_state
            or not st.session_state.chat_engine.df.equals(df)):
        st.session_state.chat_engine = ChatEngine(df, column_types)

    chat_engine = st.session_state.chat_engine

    # LLM configuration
    with st.expander("⚙️ LLM Configuration (Optional)"):
        col1, col2 = st.columns(2)

        with col1:
            provider = st.selectbox(
                "LLM Provider",
                options=['None', 'openai', 'huggingface'],
                help="Select LLM provider for complex queries"
            )

        api_key = None
        with col2:
            if provider != 'None':
                api_key = st.text_input(
                    "API Key",
                    type="password",
                    help=f"Enter your {provider} API key"
                )

        if api_key:
            chat_engine.set_llm_config(api_key, provider)
        else:
            # Switching back to 'None' (or clearing the key) must actually
            # disable the fallback on the cached engine.
            chat_engine.llm_api_key = None

    # Chat input
    query = st.text_input(
        "Ask a question about your data",
        placeholder="e.g., 'What is the average age?' or 'Show rows where salary > 50000'"
    )

    if st.button("Ask", type="primary") and query:
        with st.spinner("Processing..."):
            response = chat_engine.process_query(query)

        if response['success']:
            result = response['result']

            # Display message
            if 'message' in result:
                st.info(result['message'])

            # Display data
            if 'data' in result:
                st.dataframe(result['data'], use_container_width=True)

            # Display single value
            if 'value' in result:
                value = result['value']
                if isinstance(value, (float, np.floating)):
                    shown = f"{value:.4f}"
                else:
                    # numpy integers and other scalars are shown as text.
                    shown = str(value)
                st.metric("Result", shown)
        else:
            st.error(response.get('error', 'Unknown error'))

            if 'suggestions' in response:
                st.write("Try these queries:")
                for suggestion in response['suggestions']:
                    st.code(suggestion)

    # Example queries
    with st.expander("📖 Example Queries"):
        st.markdown("\n".join([
            "**Basic Queries:**",
            "- `show me the first 10 rows`",
            "- `how many rows and columns?`",
            "- `describe the data`",
            "",
            "**Column Queries:**",
            "- `what is the average of [column]?`",
            "- `what is the maximum of [column]?`",
            "- `show value counts for [column]`",
            "",
            "**Filtering:**",
            "- `show rows where [column] > 100`",
            "- `show rows where [column] = value`",
            "- `top 10 by [column]`",
            "",
            "**Aggregation:**",
            "- `group by [column] and calculate mean`",
            "- `correlation between [col1] and [col2]`",
            "",
            "**Data Quality:**",
            "- `show missing values`",
            "- `how many duplicates?`",
        ]))