Spaces:
Build error
Build error
| import os | |
| import shutil | |
| import gradio as gr | |
| from transformers import AutoModelForSeq2SeqLM, AutoTokenizer | |
| import pandas as pd | |
| import torch | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
| import base64 | |
# --- Paths and model configuration ------------------------------------------
MODEL_NAME = "facebook/bart-large-cnn"  # summarization checkpoint used for all generation
FIGURES_DIR = "./figures"  # generated plots are written here
EXAMPLE_DIR = "./example"  # the bundled example CSV lives here
EXAMPLE_FILE = os.path.join(EXAMPLE_DIR, "titanic.csv")

# Create the output directories up front; exist_ok keeps this idempotent
# across restarts.
for _required_dir in (FIGURES_DIR, EXAMPLE_DIR):
    os.makedirs(_required_dir, exist_ok=True)
del _required_dir
# Provide a ready-made example dataset the first time the app starts.
# This is best effort: on failure the problem is reported and start-up
# continues without the example file.
if not os.path.isfile(EXAMPLE_FILE):
    print("Downloading the Titanic dataset for examples...")
    try:
        # seaborn ships a loader for its sample datasets (fetched remotely
        # unless already cached).
        titanic = sns.load_dataset('titanic')
        titanic.to_csv(EXAMPLE_FILE, index=False)
    except Exception as e:
        print(f"Failed to download the Titanic dataset: {e}")
        print("Please ensure the 'example/titanic.csv' file exists.")
    else:
        print(f"Example dataset saved to {EXAMPLE_FILE}.")
# Load the tokenizer and summarization model once, at import time. The app
# cannot do anything useful without them, so a failure aborts start-up.
print("Loading model and tokenizer...")
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
    model.to('cpu')  # Ensure the model runs on CPU
    print("Model and tokenizer loaded successfully.")
except Exception as e:
    print(f"Error loading model: {e}")
    # `exit()` is a convenience injected by the `site` module for the
    # interactive shell and is not guaranteed to exist in programs; raising
    # SystemExit is the portable way to stop with status 1.
    raise SystemExit(1) from e
# Instruction template for the model. The `{data_description}` and
# `{additional_notes}` placeholders are filled in with str.format() before
# generation; the template ends with a newline.
base_prompt = "\n".join([
    "You are an expert data analyst.",
    "Based on the following data description, determine an appropriate target feature.",
    "List 3 insightful questions regarding the data.",
    "Provide detailed answers to each question with relevant statistics.",
    "Summarize the findings with real-world insights.",
    "Data Description:",
    "{data_description}",
    "Additional Notes:",
    "{additional_notes}",
    "Please provide your response in a structured and detailed manner.",
    "",  # produces the trailing newline
])

# Column glossary offered as the "Additional Notes" for the bundled Titanic
# example (no trailing newline).
example_notes = "\n".join([
    "This data is about the Titanic wreck in 1912.",
    "The target figure is the survival of passengers, noted by 'Survived'.",
    "pclass: A proxy for socio-economic status (SES)",
    "1st = Upper",
    "2nd = Middle",
    "3rd = Lower",
    "age: Age is fractional if less than 1. If the age is estimated, it is in the form of xx.5",
    "sibsp: Number of siblings/spouses aboard",
    "parch: Number of parents/children aboard",
])
def get_images_in_directory(directory):
    """Collect the paths of image files found anywhere under *directory*.

    The tree is walked recursively with ``os.walk``. A file is kept when its
    extension, compared case-insensitively, is one of .png, .jpg, .jpeg,
    .gif, .bmp or .tiff.

    Args:
        directory: Root directory to search.

    Returns:
        list[str]: ``os.path.join(containing_dir, file_name)`` for each match,
        in ``os.walk`` traversal order (empty if *directory* does not exist).
    """
    wanted_suffixes = {'.png', '.jpg', '.jpeg', '.gif', '.bmp', '.tiff'}
    return [
        os.path.join(folder, name)
        for folder, _subdirs, names in os.walk(directory)
        for name in names
        if os.path.splitext(name)[1].lower() in wanted_suffixes
    ]
def generate_summary(prompt, max_length=500, num_beams=4):
    """Run the seq2seq model on *prompt* and return the decoded output text.

    The prompt is truncated to the tokenizer's maximum input length before
    generation. The data description built from ``DataFrame.describe()`` can
    be longer than the encoder accepts (1024 positions for this BART
    checkpoint), and an over-long input makes ``generate`` fail.

    Args:
        prompt: Full text to feed to the model.
        max_length: Maximum length, in tokens, of the generated sequence.
        num_beams: Beam-search width.

    Returns:
        str: The generated text with special tokens removed.
    """
    # Calling the tokenizer (instead of .encode) also returns the attention
    # mask, which is forwarded to generate() below.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
    inputs = inputs.to('cpu')  # Ensure the model runs on CPU
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        summary_ids = model.generate(
            **inputs,
            max_length=max_length,
            num_beams=num_beams,
            early_stopping=True,
        )
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
# A target with more distinct values than this cannot be drawn readably as one
# bar per value, so such targets get a histogram / top-N bar chart instead.
_MAX_COUNTPLOT_CATEGORIES = 20


def _frame_to_text(frame):
    """Render *frame* as a Markdown table, falling back to plain text."""
    try:
        return frame.to_markdown()
    except ImportError:
        # DataFrame.to_markdown needs the optional `tabulate` package.
        return frame.to_string()


def _pick_target(data):
    """Return the column to treat as the analysis target.

    Preference order: a column named exactly 'Survived'; a column named
    'survived' in any letter case (seaborn's Titanic export is lowercase);
    the first numeric column; otherwise the first column.
    """
    if 'Survived' in data.columns:
        return 'Survived'
    for column in data.columns:
        if str(column).lower() == 'survived':
            return column
    numeric_cols = data.select_dtypes(include='number').columns
    return numeric_cols[0] if len(numeric_cols) > 0 else data.columns[0]


def _safe_filename_part(name):
    """Map a column name to a string that is safe inside a file name."""
    cleaned = "".join(ch if ch.isalnum() or ch in "-_." else "_" for ch in str(name))
    return cleaned or "target"


def analyze_data(data_file_path):
    """Load a CSV file, describe it, and write summary plots to FIGURES_DIR.

    Args:
        data_file_path: Path of the CSV file to analyse.

    Returns:
        tuple: On success ``(data_description, visualization_paths, target)``:
        a Markdown (or plain-text) summary, the list of PNG files written, and
        the chosen target column. If the CSV cannot be read, returns
        ``(None, error_message, None)``; the message is in the *second* slot.
    """
    try:
        data = pd.read_csv(data_file_path)
    except Exception as e:
        return None, f"Error loading CSV file: {e}", None

    # Text summary that gets embedded into the model prompt.
    data_description = f"- **Data Summary (.describe()):**\n{_frame_to_text(data.describe())}\n\n"
    data_description += f"- **Data Types:**\n{_frame_to_text(data.dtypes.to_frame())}\n"

    target = _pick_target(data)
    visualization_paths = []

    # 1) Correlation heatmap over numeric columns only. Since pandas 2.0,
    #    DataFrame.corr() raises on text columns such as Titanic's 'sex'.
    corr = data.select_dtypes(include='number').corr()
    if corr.notna().to_numpy().any():  # nothing to draw without a defined correlation
        heatmap_path = os.path.join(FIGURES_DIR, "correlation_heatmap.png")
        fig = plt.figure(figsize=(10, 8))
        try:
            sns.heatmap(corr, annot=True, fmt=".2f", cmap='coolwarm')
            plt.title("Correlation Heatmap")
            fig.savefig(heatmap_path)
        finally:
            # close (not just clf) so figures don't accumulate across requests
            plt.close(fig)
        visualization_paths.append(heatmap_path)

    # 2) Distribution of the target column.
    target_values = data[target]
    distribution_path = os.path.join(
        FIGURES_DIR, f"{_safe_filename_part(target)}_distribution.png"
    )
    fig = plt.figure(figsize=(8, 6))
    try:
        if target_values.nunique() <= _MAX_COUNTPLOT_CATEGORIES:
            sns.countplot(x=target, data=data)
        elif pd.api.types.is_numeric_dtype(target_values):
            # Many distinct numbers: a histogram stays readable.
            sns.histplot(x=target, data=data)
        else:
            # Many distinct labels: show only the most frequent ones.
            top_values = target_values.value_counts().index[:_MAX_COUNTPLOT_CATEGORIES]
            sns.countplot(x=target, data=data, order=top_values)
        plt.title(f"Distribution of {target}")
        fig.savefig(distribution_path)
    finally:
        plt.close(fig)
    visualization_paths.append(distribution_path)

    # 3) Pairplot, limited to the first 5 numeric columns for performance.
    numeric_cols = data.select_dtypes(include='number').columns[:5]
    pair_data = data[numeric_cols].dropna()
    if len(numeric_cols) >= 2 and not pair_data.empty:
        pairplot_path = os.path.join(FIGURES_DIR, "pairplot.png")
        grid = sns.pairplot(pair_data)  # creates its own figure
        try:
            grid.figure.savefig(pairplot_path)
        finally:
            plt.close(grid.figure)
        visualization_paths.append(pairplot_path)

    return data_description, visualization_paths, target
def interact_with_agent(file_input, additional_notes):
    """Analyse an uploaded CSV file and build the chat transcript to display.

    Args:
        file_input: Value of the ``gr.File`` component, or ``None`` when
            nothing was uploaded. Depending on the Gradio version this is a
            file-path string (Gradio 4+ default) or a tempfile-like object
            exposing the path as ``.name``; both are accepted.
        additional_notes: Free-text notes from the user (may be empty).

    Returns:
        list[dict]: Chat messages in Gradio's ``type='messages'`` format
        (``{"role": ..., "content": ...}``). Plots are embedded as Markdown
        images with base64 data URIs.
    """
    import mimetypes  # local import: only used to label embedded images

    # Start every run from an empty figures directory so plots from a
    # previous upload are never shown.
    if os.path.exists(FIGURES_DIR):
        shutil.rmtree(FIGURES_DIR)
    os.makedirs(FIGURES_DIR, exist_ok=True)

    if file_input is None:
        return [{"role": "assistant", "content": "❌ No file uploaded. Please upload a CSV file to proceed."}]

    # Gradio 4+ hands over a path string; older versions pass an object whose
    # `.name` is the temp-file path.
    file_path = file_input if isinstance(file_input, (str, os.PathLike)) else file_input.name

    data_description, visualization_paths, _ = analyze_data(file_path)
    if data_description is None:
        # On failure analyze_data returns (None, error_message, None), so the
        # message is in the second slot.
        return [{"role": "assistant", "content": visualization_paths}]

    prompt = base_prompt.format(
        data_description=data_description,
        additional_notes=additional_notes if additional_notes else "None."
    )
    summary = generate_summary(prompt)

    messages = [
        {"role": "user", "content": "I have uploaded a CSV file for analysis."},
        {"role": "assistant", "content": "⏳ _Analyzing the data..._"},
        {"role": "assistant", "content": summary},
    ]

    # Embed each plot inline as a Markdown image backed by a base64 data URI.
    for image_path in visualization_paths:
        if not os.path.isfile(image_path):
            messages.append({"role": "assistant", "content": f"⚠️ Unable to find image: {image_path}"})
            continue
        with open(image_path, "rb") as img_file:
            encoded_img = base64.b64encode(img_file.read()).decode("ascii")
        mime_type = mimetypes.guess_type(image_path)[0] or "image/png"
        # Brackets would terminate the Markdown alt text early, so drop them.
        alt_text = os.path.splitext(os.path.basename(image_path))[0].replace("[", "").replace("]", "")
        img_md = f"![{alt_text}](data:{mime_type};base64,{encoded_img})"
        messages.append({"role": "assistant", "content": img_md})
    return messages
# Define the Gradio interface
# NOTE(review): the original indentation of this section was lost. The nesting
# below (Button, Chatbot, Examples and the click handler outside the Row but
# inside the Blocks context) is a reconstruction; confirm it is the intended
# layout. Components are laid out in creation order, so statement order
# inside the `with` blocks matters.
with gr.Blocks(
    theme=gr.themes.Soft(
        primary_hue=gr.themes.colors.blue,
        secondary_hue=gr.themes.colors.orange,
    )
) as demo:
    # Page header. NOTE(review): the relative link to ./example/titanic.csv
    # may not be served by Gradio; confirm it resolves in the deployed app.
    gr.Markdown("""# 📊 Data Analyst Assistant
Upload a `.csv` file, add any additional notes, and **the assistant will analyze the data and generate visualizations and insights for you!**
**Example:** [Titanic Dataset](./example/titanic.csv)
""")
    # Inputs: the CSV upload (picker restricted to .csv) and optional notes
    # that are passed into the prompt.
    with gr.Row():
        file_input = gr.File(label="Upload CSV File", file_types=[".csv"])
        text_input = gr.Textbox(
            label="Additional Notes",
            placeholder="Enter any additional notes or leave blank..."
        )
    submit = gr.Button("Run Analysis", variant="primary")
    # type='messages': the handler must return a list of
    # {"role": ..., "content": ...} dicts.
    chatbot = gr.Chatbot(label="Data Analyst Agent", type='messages', height=500)
    # Handle examples only if the example file exists; it can be missing when
    # the start-up download of the Titanic dataset failed.
    if os.path.isfile(EXAMPLE_FILE):
        gr.Examples(
            examples=[[EXAMPLE_FILE, example_notes]],
            inputs=[file_input, text_input],
            label="Examples",
            cache_examples=False  # no fn is given, so selecting an example only fills the inputs
        )
    else:
        gr.Markdown("**No example files available.** Please upload your own CSV files.")
    # Connect the submit button to the interact_with_agent function;
    # api_name also exposes the handler as the "run_analysis" API endpoint.
    submit.click(
        interact_with_agent,
        inputs=[file_input, text_input],
        outputs=[chatbot],
        api_name="run_analysis"
    )
# Launch the Gradio app
if __name__ == "__main__":
    # share=True asks Gradio for a public share link in addition to the local
    # server. NOTE(review): on hosted platforms such as HF Spaces this may be
    # ignored; confirm whether it is needed.
    demo.launch(share=True)