# csv_bot_2 / app.py — CSV question-answering Gradio app
# Hugging Face Space by Chamin09 ("Create app.py", commit a878557, verified)
import gradio as gr
import pandas as pd
import numpy as np
from transformers import pipeline
import re
import datetime
import warnings
# Silence library warnings (pandas parsing, transformers) in the hosted UI.
warnings.filterwarnings('ignore')

# Initialize the text-to-text model for question answering.
qa_model = pipeline(
    "text2text-generation",
    model="google/flan-t5-small",  # Using a smaller model for free tier compatibility
    device=-1  # Use CPU (no GPU on the free tier)
)

# Module-level state shared between Gradio callbacks: the most recently
# uploaded dataframe and the text summary fed to the language model.
df = None
df_text_representation = None
def preprocess_dataframe(df):
    """Normalize column dtypes and build a text summary of the dataframe.

    Object columns that parse cleanly as dates are converted to datetime,
    then a natural-language description of the shape, columns and per-column
    value ranges is assembled for use as LLM context.

    Args:
        df: Dataframe to process (dtype conversions happen in place).

    Returns:
        Tuple of (processed dataframe, text description string).
    """
    # Convert date-like object columns to datetime. errors='raise' with a
    # targeted except replaces the deprecated errors='ignore' option
    # (removed in pandas 3.0): columns that fail to parse are left as-is,
    # which is exactly what errors='ignore' used to do.
    for col in df.columns:
        if df[col].dtype == 'object':
            try:
                df[col] = pd.to_datetime(df[col], errors='raise')
            except (ValueError, TypeError):
                pass  # not a date column; keep original values

    # Describe each column according to its (possibly converted) dtype.
    column_descriptions = []
    for col in df.columns:
        dtype = df[col].dtype
        if pd.api.types.is_numeric_dtype(dtype):
            all_missing = df[col].isna().all()  # hoisted: used for both bounds
            if pd.api.types.is_integer_dtype(dtype):
                type_desc = "integer"
                min_val = df[col].min() if not all_missing else "unknown"
                max_val = df[col].max() if not all_missing else "unknown"
            else:
                type_desc = "decimal"
                min_val = round(df[col].min(), 2) if not all_missing else "unknown"
                max_val = round(df[col].max(), 2) if not all_missing else "unknown"
            column_descriptions.append(f"Column '{col}' contains {type_desc} values ranging from {min_val} to {max_val}")
        elif pd.api.types.is_datetime64_dtype(dtype):
            all_missing = df[col].isna().all()
            earliest = df[col].min() if not all_missing else "unknown"
            latest = df[col].max() if not all_missing else "unknown"
            column_descriptions.append(f"Column '{col}' contains date/time values from {earliest} to {latest}")
        else:
            unique_count = df[col].nunique()
            if unique_count <= 10:  # For columns with few unique values, list them
                unique_vals = ", ".join([f"'{str(val)}'" for val in df[col].dropna().unique()[:10]])
                column_descriptions.append(f"Column '{col}' contains text values including: {unique_vals}")
            else:
                column_descriptions.append(f"Column '{col}' contains text values with {unique_count} unique entries")

    # General dataframe description, followed by the per-column details.
    df_desc = [
        f"This is a dataset with {df.shape[0]} rows and {df.shape[1]} columns.",
        f"The columns are: {', '.join([f'{col}' for col in df.columns])}.",
        "Column details:"
    ]
    return df, "\n".join(df_desc + column_descriptions)
def extract_numeric_values(text):
    """Pull numbers out of a string, preferring decimals over integers.

    If any decimal numbers are present, only those are returned (as floats);
    otherwise any integers are returned; otherwise an empty list.
    """
    # Try decimal numbers first, fall back to plain integers.
    for pattern, cast in ((r'-?\d+\.\d+', float), (r'-?\d+', int)):
        tokens = re.findall(pattern, text)
        if tokens:
            return [cast(tok) for tok in tokens]
    return []
def extract_dates(text):
    """Collect substrings of *text* that look like dates in common formats.

    Matches are gathered pattern by pattern, so the same date written once
    may appear more than once if several patterns overlap on it.
    """
    patterns = (
        r'\d{4}-\d{1,2}-\d{1,2}',    # YYYY-MM-DD
        r'\d{1,2}/\d{1,2}/\d{2,4}',  # MM/DD/YYYY or DD/MM/YYYY
        r'\d{1,2}-\d{1,2}-\d{2,4}',  # MM-DD-YYYY or DD-MM-YYYY
        r'\b(?:Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)[a-z]* \d{1,2},? \d{4}\b',  # Month DD, YYYY
    )
    found = []
    for pat in patterns:
        found += re.findall(pat, text)
    return found
def interpret_query(query, df_text):
    """Use the language model to interpret the query about the dataframe.

    Builds a prompt that pairs the dataset's text summary with the user's
    question and runs it through the flan-t5 text2text pipeline.

    Args:
        query: The user's natural-language question.
        df_text: Text description of the dataframe (from preprocess_dataframe).

    Returns:
        The model's generated answer string.
    """
    # NOTE(review): the full dataset summary is inlined verbatim, so very wide
    # dataframes may exceed flan-t5-small's input window — the tail of the
    # prompt (including the question) would then be truncated by the tokenizer.
    prompt = f"""
You are an AI assistant that helps analyze data. Given the following dataset description and a user question,
provide a clear and concise answer based only on the information in the dataset.
Dataset information:
{df_text}
User question: {query}
Your answer should be informative and direct. If the question cannot be answered with the available information,
explain why. If the question requires finding specific data, provide the relevant values or data points.
"""
    # do_sample=False -> deterministic (greedy) decoding.
    result = qa_model(prompt, max_length=512, do_sample=False)
    return result[0]['generated_text']
def get_data_insight(query, df):
    """Derive direct statistics or filtered rows from *df* based on *query*.

    Columns whose names appear (case-insensitively) in the query are either
    summarized with basic statistics or used to filter rows, depending on
    keywords found in the query text.

    Args:
        query: The user's natural-language question.
        df: The dataframe to analyze.

    Returns:
        A text report of the findings, or "" if nothing applied.
    """
    query_lower = query.lower()

    # Columns mentioned verbatim in the query.
    column_matches = [col for col in df.columns if col.lower() in query_lower]

    # Candidate filter values pulled from the query text.
    numeric_values = extract_numeric_values(query)
    date_values = extract_dates(query)

    compute_stats = any(term in query_lower for term in
                        ['average', 'mean', 'max', 'min', 'sum', 'count', 'total'])
    filtering_query = any(term in query_lower for term in
                          ['where', 'greater than', 'less than', 'equal to', 'between', 'filter', 'show'])

    response = ""
    try:
        if column_matches and compute_stats:
            response += _describe_stats(df, column_matches)
        elif filtering_query and column_matches:
            response += _apply_filters(df, query_lower, column_matches,
                                       numeric_values, date_values)
    except Exception as e:
        # Surface the failure in the report rather than crashing the callback.
        response += f"\nError processing query: {str(e)}"
    return response


def _describe_stats(df, columns):
    """Basic summary statistics for each mentioned numeric column."""
    report = ""
    for col in columns:
        if pd.api.types.is_numeric_dtype(df[col].dtype):
            report += f"\nStatistics for {col}:\n"
            report += f"Mean: {df[col].mean():.2f}\n"
            report += f"Min: {df[col].min()}\n"
            report += f"Max: {df[col].max()}\n"
            report += f"Sum: {df[col].sum()}\n"
            report += f"Count: {df[col].count()}\n"
    return report


def _apply_filters(df, query_lower, columns, numeric_values, date_values):
    """Apply per-column filters implied by the query; return a row report."""
    filtered_df = df.copy()
    filter_applied = False
    for col in columns:
        if pd.api.types.is_numeric_dtype(df[col].dtype) and numeric_values:
            filtered_df, applied = _numeric_filter(filtered_df, col, query_lower, numeric_values)
        elif pd.api.types.is_datetime64_dtype(df[col].dtype) and date_values:
            filtered_df, applied = _date_filter(filtered_df, col, query_lower, date_values)
        elif df[col].dtype == 'object':
            filtered_df, applied = _text_filter(filtered_df, col, query_lower)
        else:
            applied = False
        filter_applied = filter_applied or applied

    if not filter_applied:
        return ""
    if len(filtered_df) == 0:
        return "\nNo data found matching your criteria."
    if len(filtered_df) <= 10:  # Show full results if 10 or fewer rows
        return f"\nFound {len(filtered_df)} matching rows:\n{filtered_df.to_string()}"
    return f"\nFound {len(filtered_df)} matching rows. Here are the first 5:\n{filtered_df.head().to_string()}"


def _numeric_filter(frame, col, query_lower, values):
    """Compare a numeric column against the value(s) found in the query."""
    if any(t in query_lower for t in ("greater than", "more than", "above")):
        return frame[frame[col] > values[0]], True
    if any(t in query_lower for t in ("less than", "smaller than", "below")):
        return frame[frame[col] < values[0]], True
    if "equal" in query_lower:
        return frame[frame[col] == values[0]], True
    if "between" in query_lower and len(values) >= 2:
        lo, hi = min(values), max(values)
        return frame[(frame[col] >= lo) & (frame[col] <= hi)], True
    return frame, False


def _date_filter(frame, col, query_lower, date_values):
    """Compare a datetime column against the first date found in the query."""
    try:
        date_obj = pd.to_datetime(date_values[0])
    except (ValueError, TypeError):
        return frame, False  # extracted text wasn't actually a parseable date
    if "after" in query_lower or "later than" in query_lower:
        return frame[frame[col] > date_obj], True
    if "before" in query_lower or "earlier than" in query_lower:
        return frame[frame[col] < date_obj], True
    # \bon\b: a plain substring test ("on" in query) falsely fired on words
    # like "month", "column" or "containing".
    if re.search(r'\bon\b', query_lower) or "equal" in query_lower:
        return frame[frame[col].dt.date == date_obj.date()], True
    return frame, False


def _text_filter(frame, col, query_lower):
    """Keep rows whose text column contains any significant query word."""
    stop_words = {"show", "where", "with", "contains", "containing", "that", "have"}
    applied = False
    for word in query_lower.split():
        if len(word) > 3 and word not in stop_words:
            # regex=False: query words are matched as plain substrings so
            # characters like '(' or '$' can't raise as a malformed regex.
            frame = frame[frame[col].str.lower().str.contains(word, na=False, regex=False)]
            applied = True
    return frame, applied
def process_file_and_query(file, query):
    """Gradio callback: load the uploaded CSV and answer the question.

    Args:
        file: The upload — a tempfile-like object exposing .name (older
              Gradio) or a plain filepath string (newer Gradio).
        query: The user's natural-language question.

    Returns:
        The combined LLM answer and direct data analysis as a string,
        or a validation/error message.
    """
    global df, df_text_representation

    if file is None:
        return "Please upload a CSV file first."
    if query.strip() == "":
        return "Please enter a question about the data."

    try:
        # Newer Gradio File components deliver a plain filepath string;
        # older versions wrap it in an object with a .name attribute.
        path = getattr(file, "name", file)
        df = pd.read_csv(path)

        # Preprocess and summarize for the language model.
        df, df_text_representation = preprocess_dataframe(df)

        # LLM interpretation plus direct pandas-derived insights.
        llm_response = interpret_query(query, df_text_representation)
        data_insights = get_data_insight(query, df)

        final_response = f"Response: {llm_response}"
        if data_insights:
            final_response += f"\n\nAdditional Data Analysis:{data_insights}"
        return final_response
    except Exception as e:
        # Report the failure to the UI instead of crashing the callback.
        return f"Error processing file or query: {str(e)}"
# Create Gradio interface: file upload + question box on the left,
# answer textbox on the right.
with gr.Blocks(title="CSV Question Answering Bot") as demo:
    gr.Markdown("# CSV Question Answering Bot")
    gr.Markdown("Upload a CSV file and ask questions about it in natural language.")
    with gr.Row():
        with gr.Column():
            # Input column
            file_input = gr.File(label="Upload CSV File")
            query_input = gr.Textbox(label="Ask a question about your data", placeholder="Example: What's the average sales value?")
            submit_btn = gr.Button("Submit")
        with gr.Column():
            # Output column
            output = gr.Textbox(label="Answer", lines=10)
    # Wire the button to the end-to-end CSV/QA pipeline.
    submit_btn.click(fn=process_file_and_query, inputs=[file_input, query_input], outputs=output)
    # Clickable example questions (they populate the query box only).
    gr.Examples(
        examples=[
            ["What columns are in this dataset?"],
            ["What's the average of the numeric columns?"],
            ["Show me the first 5 rows"],
            ["How many rows have values greater than 100?"],
            ["What's the date range in this dataset?"]
        ],
        inputs=query_input
    )
    gr.Markdown("""
## How to use this app
1. Upload a CSV file of any kind (numerical data, dates, text, etc.)
2. Type a natural language question about the data
3. Click 'Submit' to get your answer
Examples of questions you can ask:
- "What are the columns in this dataset?"
- "What's the average salary?"
- "Show me rows where age is greater than 30"
- "What's the minimum and maximum price?"
- "How many entries were recorded after January 2023?"
- "Find all products in the Electronics category"
""")

# Launch the app when run as a script (Spaces imports call launch elsewhere).
if __name__ == "__main__":
    demo.launch()