| | import gradio as gr |
| | import pandas as pd |
| | import numpy as np |
| | from transformers import pipeline |
| | import re |
| | import datetime |
| | import warnings |
| | warnings.filterwarnings('ignore') |
| |
|
| | |
# Load a small instruction-tuned seq2seq model used to answer natural-language
# questions about the uploaded data. device=-1 forces CPU inference.
qa_model = pipeline(
    "text2text-generation",
    model="google/flan-t5-small",
    device=-1
)


# Module-level state shared across Gradio callbacks: the most recently
# uploaded dataframe and its plain-text summary fed into the LLM prompt.
df = None
df_text_representation = None
| |
|
def preprocess_dataframe(df):
    """Coerce object columns to datetimes where possible and build a text summary.

    Args:
        df: The dataframe to describe. Object-dtype columns that parse cleanly
            as datetimes are converted in place.

    Returns:
        A ``(df, description)`` tuple where ``description`` is a plain-text
        summary of the dataframe's shape and per-column contents, suitable for
        embedding in an LLM prompt.
    """
    # Try datetime conversion on object columns. pd.to_datetime's
    # errors='ignore' mode is deprecated (removed in pandas 3.0), so emulate
    # it explicitly: keep the original column when parsing fails.
    for col in df.columns:
        if df[col].dtype == 'object':
            try:
                df[col] = pd.to_datetime(df[col])
            except (ValueError, TypeError):
                pass

    column_descriptions = []
    for col in df.columns:
        dtype = df[col].dtype
        all_na = df[col].isna().all()  # hoisted: was computed up to 4x per column
        if pd.api.types.is_numeric_dtype(dtype):
            if pd.api.types.is_integer_dtype(dtype):
                type_desc = "integer"
                min_val = df[col].min() if not all_na else "unknown"
                max_val = df[col].max() if not all_na else "unknown"
            else:
                type_desc = "decimal"
                # Round floats so the summary stays readable.
                min_val = round(df[col].min(), 2) if not all_na else "unknown"
                max_val = round(df[col].max(), 2) if not all_na else "unknown"
            column_descriptions.append(
                f"Column '{col}' contains {type_desc} values ranging from {min_val} to {max_val}"
            )
        elif pd.api.types.is_datetime64_dtype(dtype):
            earliest = df[col].min() if not all_na else "unknown"
            latest = df[col].max() if not all_na else "unknown"
            column_descriptions.append(
                f"Column '{col}' contains date/time values from {earliest} to {latest}"
            )
        else:
            unique_count = df[col].nunique()
            if unique_count <= 10:
                # Low cardinality: enumerate the values themselves.
                unique_vals = ", ".join(f"'{val}'" for val in df[col].dropna().unique()[:10])
                column_descriptions.append(
                    f"Column '{col}' contains text values including: {unique_vals}"
                )
            else:
                column_descriptions.append(
                    f"Column '{col}' contains text values with {unique_count} unique entries"
                )

    df_desc = [
        f"This is a dataset with {df.shape[0]} rows and {df.shape[1]} columns.",
        f"The columns are: {', '.join(str(col) for col in df.columns)}.",
        "Column details:"
    ]
    return df, "\n".join(df_desc + column_descriptions)
| |
|
def extract_numeric_values(text):
    """Pull numbers out of *text*: floats when decimals appear, ints otherwise."""
    # Decimal numbers take priority: when any are present, plain integers are
    # ignored entirely (otherwise "1.5" would also yield 1 and 5).
    decimals = re.findall(r'-?\d+\.\d+', text)
    if decimals:
        return [float(d) for d in decimals]

    integers = re.findall(r'-?\d+', text)
    return [int(i) for i in integers]
| |
|
def extract_dates(text):
    """Return every substring of *text* that matches a known date pattern."""
    patterns = (
        r'\d{4}-\d{1,2}-\d{1,2}',      # ISO-like: 2023-01-15
        r'\d{1,2}/\d{1,2}/\d{2,4}',    # slash form: 1/15/2023
        r'\d{1,2}-\d{1,2}-\d{2,4}',    # dash form: 15-01-23
        r'\b(?:Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)[a-z]* \d{1,2},? \d{4}\b',
    )

    found = []
    for pat in patterns:
        found += re.findall(pat, text)
    return found
| |
|
def interpret_query(query, df_text):
    """Use the language model to interpret the query about the dataframe."""
    # Single prompt embedding the dataset summary, instructing the model to
    # answer only from the provided description rather than general knowledge.
    prompt = f"""
    You are an AI assistant that helps analyze data. Given the following dataset description and a user question,
    provide a clear and concise answer based only on the information in the dataset.

    Dataset information:
    {df_text}

    User question: {query}

    Your answer should be informative and direct. If the question cannot be answered with the available information,
    explain why. If the question requires finding specific data, provide the relevant values or data points.
    """

    # Greedy decoding (do_sample=False) keeps answers deterministic; uses the
    # module-level qa_model pipeline.
    result = qa_model(prompt, max_length=512, do_sample=False)
    return result[0]['generated_text']
| |
|
def get_data_insight(query, df):
    """Try to derive insights from the dataframe based on the query.

    Heuristically matches column names and filter/aggregation keywords in
    *query* and, where possible, appends computed statistics or filtered rows
    to the returned text. Returns an empty string when no heuristic applies.
    """
    response = ""
    query_lower = query.lower()  # hoisted: was recomputed on every check

    # Columns whose names appear verbatim (case-insensitively) in the query.
    column_matches = [col for col in df.columns if col.lower() in query_lower]

    numeric_values = extract_numeric_values(query)
    date_values = extract_dates(query)

    # Aggregation intent: any statistics keyword present.
    compute_stats = any(term in query_lower for term in
                        ['average', 'mean', 'max', 'min', 'sum', 'count', 'total'])

    # Filtering intent: comparison / selection keywords present.
    filtering_query = any(term in query_lower for term in
                          ['where', 'greater than', 'less than', 'equal to', 'between', 'filter', 'show'])

    try:
        if column_matches and compute_stats:
            # Report a standard battery of statistics for each matched numeric column.
            for col in column_matches:
                if pd.api.types.is_numeric_dtype(df[col].dtype):
                    response += f"\nStatistics for {col}:\n"
                    response += f"Mean: {df[col].mean():.2f}\n"
                    response += f"Min: {df[col].min()}\n"
                    response += f"Max: {df[col].max()}\n"
                    response += f"Sum: {df[col].sum()}\n"
                    response += f"Count: {df[col].count()}\n"

        elif filtering_query and column_matches:
            filtered_df = df.copy()
            filter_applied = False

            for col in column_matches:
                # Numeric comparisons use the first number(s) found in the query.
                if pd.api.types.is_numeric_dtype(df[col].dtype) and numeric_values:
                    if "greater than" in query_lower or "more than" in query_lower or "above" in query_lower:
                        filtered_df = filtered_df[filtered_df[col] > numeric_values[0]]
                        filter_applied = True
                    elif "less than" in query_lower or "smaller than" in query_lower or "below" in query_lower:
                        filtered_df = filtered_df[filtered_df[col] < numeric_values[0]]
                        filter_applied = True
                    elif "equal" in query_lower:
                        filtered_df = filtered_df[filtered_df[col] == numeric_values[0]]
                        filter_applied = True
                    elif "between" in query_lower and len(numeric_values) >= 2:
                        filtered_df = filtered_df[(filtered_df[col] >= min(numeric_values)) &
                                                  (filtered_df[col] <= max(numeric_values))]
                        filter_applied = True

                # Date comparisons use the first date-like substring found.
                elif pd.api.types.is_datetime64_dtype(df[col].dtype) and date_values:
                    try:
                        date_obj = pd.to_datetime(date_values[0])
                        if "after" in query_lower or "later than" in query_lower:
                            filtered_df = filtered_df[filtered_df[col] > date_obj]
                            filter_applied = True
                        elif "before" in query_lower or "earlier than" in query_lower:
                            filtered_df = filtered_df[filtered_df[col] < date_obj]
                            filter_applied = True
                        # BUGFIX: a bare substring test for "on" fired on words
                        # like "month" or "containing"; require a whole word.
                        elif re.search(r'\bon\b', query_lower) or "equal" in query_lower:
                            filtered_df = filtered_df[filtered_df[col].dt.date == date_obj.date()]
                            filter_applied = True
                    except (ValueError, TypeError):
                        # Unparseable date string: skip this column silently.
                        pass

                # Text columns: substring-match the remaining significant words.
                elif df[col].dtype == 'object':
                    stopwords = {"show", "where", "with", "contains", "containing", "that", "have"}
                    # BUGFIX: also skip matched column names, otherwise a query
                    # like "show city London" demands each cell contain "city".
                    skip = stopwords | {c.lower() for c in column_matches}
                    for word in query_lower.split():
                        if len(word) > 3 and word not in skip:
                            # BUGFIX: regex=False makes the match literal, so a
                            # word with regex metacharacters (e.g. "(urgent)")
                            # cannot raise and abort the whole insight.
                            filtered_df = filtered_df[filtered_df[col].str.lower().str.contains(word, na=False, regex=False)]
                            filter_applied = True

            if filter_applied:
                if len(filtered_df) > 0:
                    if len(filtered_df) <= 10:
                        response += f"\nFound {len(filtered_df)} matching rows:\n{filtered_df.to_string()}"
                    else:
                        response += f"\nFound {len(filtered_df)} matching rows. Here are the first 5:\n{filtered_df.head().to_string()}"
                else:
                    response += "\nNo data found matching your criteria."
    except Exception as e:
        # Boundary handler: report rather than crash the Gradio callback.
        response += f"\nError processing query: {str(e)}"

    return response
| |
|
def process_file_and_query(file, query):
    """Load the uploaded CSV, then answer *query* with the LLM plus heuristics."""
    global df, df_text_representation

    # Guard clauses: both an uploaded file and a non-blank question are required.
    if file is None:
        return "Please upload a CSV file first."
    if not query.strip():
        return "Please enter a question about the data."

    try:
        # Re-read the CSV on every request so a newly uploaded file takes effect.
        loaded = pd.read_csv(file.name)
        df, df_text_representation = preprocess_dataframe(loaded)

        # Two answer sources: the language model and the rule-based analyzer.
        llm_response = interpret_query(query, df_text_representation)
        data_insights = get_data_insight(query, df)

        parts = [f"Response: {llm_response}"]
        if data_insights:
            parts.append(f"\n\nAdditional Data Analysis:{data_insights}")
        return "".join(parts)

    except Exception as e:
        return f"Error processing file or query: {str(e)}"
| |
|
| | |
# --- Gradio UI -------------------------------------------------------------
with gr.Blocks(title="CSV Question Answering Bot") as demo:
    gr.Markdown("# CSV Question Answering Bot")
    gr.Markdown("Upload a CSV file and ask questions about it in natural language.")

    with gr.Row():
        with gr.Column():
            # Left column: inputs.
            file_input = gr.File(label="Upload CSV File")
            query_input = gr.Textbox(label="Ask a question about your data", placeholder="Example: What's the average sales value?")
            submit_btn = gr.Button("Submit")

        with gr.Column():
            # Right column: combined LLM + heuristic answer.
            output = gr.Textbox(label="Answer", lines=10)

    # Wire the button to the end-to-end pipeline (load CSV -> LLM -> heuristics).
    submit_btn.click(fn=process_file_and_query, inputs=[file_input, query_input], outputs=output)

    # Clickable example questions that pre-fill the query textbox.
    gr.Examples(
        examples=[
            ["What columns are in this dataset?"],
            ["What's the average of the numeric columns?"],
            ["Show me the first 5 rows"],
            ["How many rows have values greater than 100?"],
            ["What's the date range in this dataset?"]
        ],
        inputs=query_input
    )

    gr.Markdown("""
    ## How to use this app
    1. Upload a CSV file of any kind (numerical data, dates, text, etc.)
    2. Type a natural language question about the data
    3. Click 'Submit' to get your answer

    Examples of questions you can ask:
    - "What are the columns in this dataset?"
    - "What's the average salary?"
    - "Show me rows where age is greater than 30"
    - "What's the minimum and maximum price?"
    - "How many entries were recorded after January 2023?"
    - "Find all products in the Electronics category"
    """)


# Launch the app only when run as a script (not when imported).
if __name__ == "__main__":
    demo.launch()