# Hugging Face Space status when this file was captured: Runtime error
import base64
import io
import itertools
import json
import logging
import os
import tempfile
import warnings
from datetime import datetime

import google.generativeai as genai
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import seaborn as sns
from dotenv import load_dotenv
from plotly.subplots import make_subplots
from scipy import stats
# Suppress every Python warning for the whole process.
warnings.filterwarnings('ignore')

# Root logger: INFO and above, one timestamped line per record.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# The Gemini API key is deliberately NOT loaded from the environment
# (.env / GEMINI_API_KEY): each request passes the key the user typed into the
# UI, and generate_insights_with_gemini() configures the client with it.
def analyze_dataset_overview(file_obj, api_key) -> tuple:
    """
    Analyze an uploaded CSV with Gemini and build the Overview tab outputs.

    Args:
        file_obj: The uploaded CSV. Either an object exposing its path as
            ``.name`` (what Gradio's ``gr.File`` passes) or a plain path string.
        api_key: Gemini API key typed by the user; surrounding whitespace is
            ignored.

    Returns:
        tuple: ``(story_text, basic_info_text, data_quality_score)``.
        On a validation failure or any error the first element is a "❌"
        message, the second is ``""`` and the score is ``0``.
    """
    if file_obj is None:
        return "❌ Please upload a CSV file first.", "", 0
    # Strip once so a key pasted with a trailing space/newline still works.
    api_key = (api_key or "").strip()
    if not api_key:
        return "❌ Please enter your Gemini API key first.", "", 0
    try:
        # Accept plain path strings as well as objects carrying the path in `.name`.
        df = pd.read_csv(getattr(file_obj, "name", file_obj))
        metadata = extract_dataset_metadata(df)
        gemini_prompt = create_insights_prompt(metadata)
        story = generate_insights_with_gemini(gemini_prompt, api_key)
        basic_info = create_basic_info_summary(metadata)
        return story, basic_info, metadata['data_quality']
    except Exception as e:
        # UI boundary: show the message to the user, keep the traceback in the logs.
        logging.getLogger(__name__).exception("Dataset overview failed")
        return f"❌ Error loading data: {str(e)}", "", 0
def extract_dataset_metadata(df: pd.DataFrame, corr_threshold: float = 0.7,
                             max_correlations: int = 5,
                             max_categorical: int = 5) -> dict:
    """
    Summarise a DataFrame into the metadata used by the prompt and info panel.

    Args:
        df (pd.DataFrame): DataFrame to analyze.
        corr_threshold (float): A pair of numeric columns is reported when the
            absolute value of its pairwise ``DataFrame.corr()`` coefficient is
            strictly greater than this.
        max_correlations (int): Keep at most this many pairs, strongest
            (largest absolute coefficient) first.
        max_categorical (int): Profile only the first N object-dtype columns
            in ``categorical_info``.

    Returns:
        dict: Keys ``shape``, ``columns``, ``numeric_cols``,
        ``categorical_cols``, ``datetime_cols``, ``missing_data`` (count per
        column), ``missing_percentage`` (per column; NaN when ``df`` has no
        rows), ``numeric_stats`` (``describe()`` as a dict), ``categorical_info``,
        ``correlations`` (always a list of ``{'var1', 'var2', 'correlation'}``
        with the signed coefficient rounded to 3 places) and ``data_quality``
        (percentage of non-null cells; 0.0 when there are no cells).
    """
    rows, cols = df.shape

    # Column groups by dtype.
    numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
    categorical_cols = df.select_dtypes(include=['object']).columns.tolist()
    datetime_cols = df.select_dtypes(include=['datetime64']).columns.tolist()

    # Missing values (percentages are NaN for a zero-row frame).
    missing_data = df.isnull().sum()
    missing_percentage = (missing_data / len(df) * 100).round(2)

    numeric_stats = df[numeric_cols].describe().to_dict() if numeric_cols else {}

    categorical_info = {
        col: {
            'unique_count': df[col].nunique(),
            'top_values': df[col].value_counts().head(3).to_dict(),
        }
        for col in categorical_cols[:max_categorical]
    }

    # Strongly correlated numeric pairs. The sign is kept so the direction of
    # the relationship is not lost; the filter and ranking use the magnitude.
    correlations = []
    if len(numeric_cols) > 1:
        corr_matrix = df[numeric_cols].corr()
        names = corr_matrix.columns
        high_corr = []
        for i, j in itertools.combinations(range(len(names)), 2):
            corr_val = corr_matrix.iat[i, j]
            # NaN (e.g. a constant column) compares False and is skipped.
            if abs(corr_val) > corr_threshold:
                high_corr.append({
                    'var1': names[i],
                    'var2': names[j],
                    'correlation': round(float(corr_val), 3),
                })
        # Stable sort: equally strong pairs keep their column order.
        high_corr.sort(key=lambda pair: abs(pair['correlation']), reverse=True)
        correlations = high_corr[:max_correlations]

    total_cells = rows * cols
    # Guard the empty-frame case, which would otherwise be 0/0 -> NaN.
    data_quality = (
        float(round(df.notna().sum().sum() / total_cells * 100, 1))
        if total_cells else 0.0
    )

    return {
        'shape': (rows, cols),
        'columns': df.columns.tolist(),
        'numeric_cols': numeric_cols,
        'categorical_cols': categorical_cols,
        'datetime_cols': datetime_cols,
        'missing_data': missing_data.to_dict(),
        'missing_percentage': missing_percentage.to_dict(),
        'numeric_stats': numeric_stats,
        'categorical_info': categorical_info,
        'correlations': correlations,
        'data_quality': data_quality,
    }
def create_insights_prompt(metadata: dict) -> str:
    """
    Build the Gemini prompt that asks for a markdown "data story".

    Args:
        metadata (dict): Dataset metadata with keys ``shape``, ``columns``,
            ``numeric_cols``, ``categorical_cols``, ``data_quality``,
            ``categorical_info`` (col -> {'unique_count', 'top_values'}),
            ``correlations`` (iterable of {'var1', 'var2', 'correlation'}) and,
            optionally, ``datetime_cols``.

    Returns:
        str: Prompt text for Gemini.
    """
    def _names(values) -> str:
        # Column labels are not guaranteed to be strings; an empty list reads "None".
        return ', '.join(str(value) for value in values) or 'None'

    rows, cols = metadata['shape']
    # Render details as readable lines rather than dict reprs, which (with
    # NumPy 2 scalars) would leak wrappers such as "np.int64(3)" into the prompt.
    categorical_text = '\n'.join(
        f"- {col}: {info['unique_count']} unique values; top values: "
        + (', '.join(f"{value} ({count})" for value, count in info['top_values'].items()) or 'None')
        for col, info in metadata['categorical_info'].items()
    ) or 'None'
    correlation_text = '\n'.join(
        f"- {pair['var1']} ↔ {pair['var2']}: {float(pair['correlation']):.3f}"
        for pair in metadata['correlations']
    ) or 'None'

    prompt = f"""
You are an expert data analyst and storyteller. Using the following dataset information,
predict what this dataset is about and tell a story about it.

DATASET INFORMATION:
- Size: {rows:,} rows, {cols} columns
- Columns: {_names(metadata['columns'])}
- Numeric columns: {_names(metadata['numeric_cols'])}
- Categorical columns: {_names(metadata['categorical_cols'])}
- DateTime columns: {_names(metadata.get('datetime_cols', []))}
- Data quality: {metadata['data_quality']}%

CATEGORICAL VARIABLE DETAILS:
{categorical_text}

HIGH CORRELATIONS:
{correlation_text}

Please create a story in the following format:

# Dataset Overview
## What is this dataset about?
[Your prediction about the dataset]
## Which sector/domain does it belong to?
[Your sector analysis]
## Potential Use Cases
- [Use case 1]
- [Use case 2]
- [Use case 3]
## Interesting Findings
- [Finding 1]
- [Finding 2]
- [Finding 3]
## What Can We Do With This Data?
- [Potential analysis 1]
- [Potential analysis 2]
- [Potential analysis 3]

Make your story visual and engaging using emojis!
Keep it in English and make it professional yet accessible.
Use proper markdown formatting for headers and lists.
"""
    return prompt
def generate_insights_with_gemini(prompt: str, api_key: str,
                                  model_name: "str | None" = None) -> str:
    """
    Send *prompt* to Gemini and return the generated text.

    Args:
        prompt (str): Prepared prompt for Gemini.
        api_key (str): Gemini API key; surrounding whitespace is stripped.
        model_name (str | None): Gemini model to use. When omitted, the
            ``GEMINI_MODEL`` environment variable is used, falling back to
            ``'gemini-1.5-flash'`` (the previous hard-coded value).

    Returns:
        str: The model's text. If anything in the call fails, a static
        fallback story that includes the error message is returned instead
        (the error is also logged).
    """
    model_name = model_name or os.getenv("GEMINI_MODEL", "gemini-1.5-flash")
    try:
        # NOTE(review): genai.configure() appears to set module-global client
        # state; concurrent requests with different user keys may interfere —
        # confirm before serving many simultaneous users.
        genai.configure(api_key=(api_key or "").strip())
        model = genai.GenerativeModel(model_name)
        response = model.generate_content(prompt)
        return response.text
    except Exception as e:
        logging.getLogger(__name__).warning(
            "Gemini request failed (model=%s): %s", model_name, e
        )
        # Fallback story if the Gemini API fails. Blank lines keep the
        # sections as separate paragraphs when rendered as markdown.
        return f"""
🔍 **DATA DISCOVERY STORY**

⚠️ Gemini API Error: {str(e)}

📊 **Fallback Analysis**:
This dataset appears to be a fascinating collection of information!

🎯 **Prediction**: Based on the structure, this could be business, e-commerce, or customer behavior data.

🏢 **Sector**: Likely used in retail, digital marketing, or analytics domain.

✨ **Potential Stories**:
• 🛒 Customer journey analysis
• 📈 Seasonal trends and patterns
• 👥 Customer segmentation
• 💡 Recommendation systems
• 🎯 Marketing campaign optimization

🔮 **What We Can Do**:
• Customer lifetime value prediction
• Churn prediction modeling
• Pricing strategy optimization
• Market basket analysis
• A/B testing insights

📊 The data quality looks promising for analysis!
"""
def create_basic_info_summary(metadata: dict) -> str:
    """
    Render the "Basic Information" panel text from dataset metadata.

    Args:
        metadata (dict): Needs ``shape``, ``numeric_cols``,
            ``categorical_cols``, ``datetime_cols``, ``data_quality``,
            ``missing_data`` (column -> missing count) and ``correlations``.

    Returns:
        str: One line per fact, with a leading and a trailing newline.
    """
    n_rows, n_cols = metadata['shape']
    total_missing = sum(metadata['missing_data'].values())
    lines = [
        "",
        "📋 **Dataset Overview**",
        f"📊 **Size**: {n_rows:,} rows × {n_cols} columns",
        "🔢 **Data Types**:",
        f"• Numeric variables: {len(metadata['numeric_cols'])}",
        f"• Categorical variables: {len(metadata['categorical_cols'])}",
        f"• DateTime variables: {len(metadata['datetime_cols'])}",
        f"🎯 **Data Quality**: {metadata['data_quality']}%",
        f"📈 **Missing Data**: {total_missing} total missing values",
        f"🔗 **High Correlations Found**: {len(metadata['correlations'])} pairs",
        "",
    ]
    return "\n".join(lines)
def generate_data_profiling(file_obj) -> tuple:
    """
    Build the three tables shown on the "Data Profiling" tab.

    Args:
        file_obj: The uploaded CSV — an object exposing its path as ``.name``
            (what Gradio's ``gr.File`` passes) or a plain path string.

    Returns:
        tuple: ``(missing_df, numeric_stats_df, categorical_stats_df)``.
        ``missing_df`` has columns Column / Missing Count / Missing Percentage,
        sorted by count descending. ``numeric_stats_df`` is ``describe()``
        with the statistic names in an ``index`` column, or ``None`` when there
        are no numeric columns. ``categorical_stats_df`` has one row per
        object-dtype column, or is ``None`` when there are none. All three are
        ``None`` without a file; on error the first is a one-cell "Error" frame.
    """
    if file_obj is None:
        return None, None, None
    try:
        # Accept plain path strings as well as objects carrying the path in `.name`.
        df = pd.read_csv(getattr(file_obj, "name", file_obj))

        # Missing data analysis
        missing_data = df.isnull().sum()
        missing_df = pd.DataFrame({
            'Column': missing_data.index,
            'Missing Count': missing_data.values,
            'Missing Percentage': (missing_data / len(df) * 100).round(2).values,
        }).sort_values('Missing Count', ascending=False)

        # Numeric statistics
        numeric_cols = df.select_dtypes(include=[np.number]).columns
        numeric_stats_df = (
            df[numeric_cols].describe().round(3).reset_index()
            if len(numeric_cols) > 0 else None
        )

        # Categorical statistics: mode() and value_counts() computed once per column.
        categorical_stats = []
        for col in df.select_dtypes(include=['object']).columns:
            series = df[col]
            modes = series.mode()
            counts = series.value_counts()
            categorical_stats.append({
                'Column': col,
                'Unique Values': series.nunique(),
                'Most Frequent': modes.iloc[0] if not modes.empty else 'N/A',
                'Frequency': counts.iloc[0] if not counts.empty else 0,
            })
        categorical_stats_df = pd.DataFrame(categorical_stats) if categorical_stats else None

        return missing_df, numeric_stats_df, categorical_stats_df
    except Exception as e:
        # UI boundary: surface the message in the table, keep the traceback in the logs.
        logging.getLogger(__name__).exception("Data profiling failed")
        error_df = pd.DataFrame({'Error': [f"Error in profiling: {str(e)}"]})
        return error_df, None, None
def create_smart_visualizations(file_obj) -> tuple:
    """
    Build the four Plotly charts shown on the "Smart Visualizations" tab.

    Args:
        file_obj: The uploaded CSV — an object exposing its path as ``.name``
            (what Gradio's ``gr.File`` passes) or a plain path string.

    Returns:
        tuple: ``(dtype_fig, missing_fig, correlation_fig, distribution_fig)``.
        ``correlation_fig`` is ``None`` with fewer than two numeric columns and
        ``distribution_fig`` is ``None`` with none. All four are ``None``
        without a file; on error the first is an empty figure whose title
        carries the error message and the rest are ``None``.
    """
    if file_obj is None:
        return None, None, None, None
    try:
        # Accept plain path strings as well as objects carrying the path in `.name`.
        df = pd.read_csv(getattr(file_obj, "name", file_obj))

        # 1. Share of columns per dtype.
        dtype_counts = df.dtypes.value_counts()
        dtype_fig = px.pie(
            values=dtype_counts.values,
            names=[str(dtype) for dtype in dtype_counts.index],  # dtype objects -> text labels
            title="🔍 Data Type Distribution",
        )
        dtype_fig.update_traces(textposition='inside', textinfo='percent+label')

        # 2. Missing values per column (bar chart).
        missing_data = df.isnull().sum()
        missing_fig = px.bar(
            x=missing_data.index,
            y=missing_data.values,
            title="🔴 Missing Data by Column",
            labels={'x': 'Columns', 'y': 'Missing Count'},
        )
        missing_fig.update_xaxes(tickangle=45)

        # 3. Correlation heatmap. The colour range is pinned to [-1, 1] so that
        #    0 sits at the neutral midpoint of the diverging RdBu scale instead
        #    of wherever the data's own min/max would put it.
        numeric_cols = df.select_dtypes(include=[np.number]).columns
        correlation_fig = None
        if len(numeric_cols) > 1:
            correlation_fig = px.imshow(
                df[numeric_cols].corr(),
                text_auto='.2f',
                aspect="auto",
                title="🔗 Correlation Matrix",
                color_continuous_scale='RdBu',
                zmin=-1,
                zmax=1,
            )

        # 4. Histograms of (up to) the first four numeric columns.
        distribution_fig = None
        cols_to_plot = list(numeric_cols[:4])
        if len(cols_to_plot) == 1:
            distribution_fig = px.histogram(
                df, x=cols_to_plot[0],
                title=f"📊 Distribution of {cols_to_plot[0]}",
            )
        elif cols_to_plot:
            # Two panels per row, and only as many rows as needed.
            n_rows = (len(cols_to_plot) + 1) // 2
            fig = make_subplots(
                rows=n_rows, cols=2,
                subplot_titles=[f"{col} Distribution" for col in cols_to_plot],
            )
            for i, col in enumerate(cols_to_plot):
                fig.add_trace(
                    go.Histogram(x=df[col].values, name=str(col), showlegend=False),
                    row=i // 2 + 1, col=i % 2 + 1,
                )
            fig.update_layout(title="📊 Numeric Variable Distributions")
            distribution_fig = fig

        return dtype_fig, missing_fig, correlation_fig, distribution_fig
    except Exception as e:
        # UI boundary: show the message as a figure title, keep the traceback in the logs.
        logging.getLogger(__name__).exception("Visualization failed")
        error_fig = go.Figure()
        error_fig.update_layout(title=f"❌ Visualization Error: {str(e)}")
        return error_fig, None, None, None
# Gradio UI: one shared CSV upload plus three tabs, each with a button bound
# to its handler. The tab builders below run inside the caller's gr.Tabs()
# context, so the components they create are nested there.
def _add_overview_tab(upload):
    """Lay out the "Overview" tab and bind its button to analyze_dataset_overview."""
    with gr.Tab("🔍 Overview"):
        gr.Markdown("### AI-Powered Data Insights")
        with gr.Row():
            key_box = gr.Textbox(
                label="🔑 Gemini API Key",
                placeholder="Enter your Gemini API key here...",
                type="password",
            )
        with gr.Row():
            story_button = gr.Button("🎯 Generate Story", variant="primary")
        with gr.Row():
            with gr.Column():
                story_md = gr.Markdown(label="📖 Data Insights", value="")
            with gr.Column():
                info_md = gr.Markdown(label="📋 Basic Information", value="")
        with gr.Row():
            score_box = gr.Number(label="🎯 Data Quality Score (%)", precision=1)
        story_button.click(
            fn=analyze_dataset_overview,
            inputs=[upload, key_box],
            outputs=[story_md, info_md, score_box],
        )


def _add_profiling_tab(upload):
    """Lay out the "Data Profiling" tab and bind its button to generate_data_profiling."""
    with gr.Tab("📊 Data Profiling"):
        gr.Markdown("### Automated Data Profiling")
        with gr.Row():
            profile_button = gr.Button("🔍 Generate Profiling", variant="secondary")
        with gr.Row():
            with gr.Column():
                missing_table = gr.Dataframe(label="🔴 Missing Data Analysis", interactive=False)
            with gr.Column():
                numeric_table = gr.Dataframe(label="🔢 Numeric Statistics", interactive=False)
        with gr.Row():
            categorical_table = gr.Dataframe(label="📝 Categorical Statistics", interactive=False)
        profile_button.click(
            fn=generate_data_profiling,
            inputs=[upload],
            outputs=[missing_table, numeric_table, categorical_table],
        )


def _add_visualization_tab(upload):
    """Lay out the "Smart Visualizations" tab and bind its button to create_smart_visualizations."""
    with gr.Tab("📈 Smart Visualizations"):
        gr.Markdown("### Automated Data Visualizations")
        with gr.Row():
            chart_button = gr.Button("🎨 Create Visualizations", variant="secondary")
        with gr.Row():
            with gr.Column():
                dtype_chart = gr.Plot(label="🔍 Data Types")
                missing_chart = gr.Plot(label="🔴 Missing Data")
            with gr.Column():
                corr_chart = gr.Plot(label="🔗 Correlations")
                dist_chart = gr.Plot(label="📊 Distributions")
        chart_button.click(
            fn=create_smart_visualizations,
            inputs=[upload],
            outputs=[dtype_chart, missing_chart, corr_chart, dist_chart],
        )


def create_gradio_interface():
    """Build and return the AutoEDA ``gr.Blocks`` app (it is not launched here)."""
    with gr.Blocks(title="🚀 AI Data Explorer", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🚀 AutoEDA")
        gr.Markdown("Upload your CSV file and get AI-powered analysis reports!")
        with gr.Row():
            upload = gr.File(label="📁 Upload CSV File", file_types=[".csv"])
        with gr.Tabs():
            _add_overview_tab(upload)
            _add_profiling_tab(upload)
            _add_visualization_tab(upload)
        # Footer
        gr.Markdown("---")
        gr.Markdown("💡 **Tip**: Get your free Gemini API key from [Google AI Studio](https://aistudio.google.com/)")
    return demo
# Script entry point: build the UI and start serving it.
if __name__ == "__main__":
    # mcp_server=True asks Gradio to also serve the app as an MCP server
    # (needs a Gradio release/extra with MCP support installed).
    launch_options = {"mcp_server": True}
    demo = create_gradio_interface()
    demo.launch(**launch_options)