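"""
Data Analysis API -- FastAPI backend.

Accepts a CSV/Excel upload, cleans it with pandas, asks Google Gemini for a
JSON summary plus chart recommendations, and returns chart-ready data grouped
by chart type.
"""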
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
import pandas as pd
import os
import json
import tempfile
import shutil
from typing import Optional
from pydantic import BaseModel
from google import genai
from google.genai import types
import logging

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="Data Analysis API", version="1.0.0")

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, replace with your frontend domain
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Response models
class AnalysisResponse(BaseModel):
    summary: dict
    chart_data: dict
    metadata: dict

class ErrorResponse(BaseModel):
    error: str
    details: Optional[str] = None

# Ensure tmp directory exists
os.makedirs("/tmp", exist_ok=True)

def load_file_from_upload(file_path: str, original_filename: str):
    """Load file from uploaded temporary file"""
    try:
        ext = os.path.splitext(original_filename)[-1].lower()
        if ext == ".csv":
            df = pd.read_csv(file_path)
        elif ext in [".xls", ".xlsx"]:
            # For Excel files, we take the first sheet by default.
            # In a production app, you might want to let users choose.
            df = pd.read_excel(file_path, sheet_name=0)
        else:
            raise ValueError(f"Unsupported file type: {ext}")
        return df.copy()
    except Exception as e:
        logger.error(f"Error loading file: {str(e)}")
        raise HTTPException(status_code=400, detail=f"Error loading file: {str(e)}")
def preprocess(df, drop_thresh=0.5):
    """Clean the dataframe: normalize column names, drop sparse columns,
    fill missing values, coerce numeric strings, and drop duplicates."""
    try:
        df = df.copy()
        # Normalize column names: strip, lowercase, snake_case
        df.columns = [str(c).strip().lower().replace(" ", "_") for c in df.columns]
        # Drop columns whose null fraction meets or exceeds drop_thresh
        df = df.loc[:, df.isnull().mean() < drop_thresh]
        # Fill missing values by dtype
        for col in df.columns:
            if pd.api.types.is_numeric_dtype(df[col]):
                df.loc[:, col] = df[col].fillna(df[col].median())
            elif pd.api.types.is_datetime64_any_dtype(df[col]):
                df.loc[:, col] = df[col].fillna(pd.Timestamp('1970-01-01'))
            else:
                df.loc[:, col] = df[col].fillna("Unknown")
        # Coerce object columns that are actually numeric
        for col in df.columns:
            if df[col].dtype == 'object':
                try:
                    df.loc[:, col] = pd.to_numeric(df[col])
                except (ValueError, TypeError):
                    pass
        df = df.drop_duplicates()
        return df
    except Exception as e:
        logger.error(f"Error preprocessing data: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error preprocessing data: {str(e)}")
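
# Example: a column named "Order Date" becomes "order_date"; any column that is
# >= 50% null (the default drop_thresh) is dropped before value imputation.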
def get_metadata(df):
    """Get dataframe metadata"""
    return {
        "rows": df.shape[0],
        "columns": df.shape[1],
        "column_names": list(df.columns),
        "column_types": df.dtypes.astype(str).to_dict(),
        "unique_values": {col: df[col].nunique() for col in df.columns},
    }
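
# Illustrative output for a hypothetical two-column dataset:
#   {"rows": 100, "columns": 2, "column_names": ["region", "revenue"],
#    "column_types": {"region": "object", "revenue": "float64"},
#    "unique_values": {"region": 4, "revenue": 97}}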
def generate_summary(meta, fiverow):
    """Generate AI summary using Google Gemini"""
    try:
        # Get API key from environment variable
        api_key = os.getenv("GEMINI_API_KEY")
        if not api_key:
            raise HTTPException(status_code=500, detail="GEMINI_API_KEY environment variable not set")

        client = genai.Client(api_key=api_key)
        model = "gemini-2.5-flash-lite"

        system_prompt = """
You are a strict JSON generator.

Input contains:
- meta: dataframe metadata
- fiverow: first 5 records of dataframe

You must output JSON with the following structure:
{
  "summary": "<short natural language overview of dataset>",
  "recommended_charts": [
    {
      "type": "<one of: bar, pie, timeseries, histogram, scatter, multiple_columns, stacked_bar, heatmap>",
      "title": "<short title for chart>",
      "columns": ["<col1>", "<col2>", "..."],
      "python_code": "<full runnable Python code using seaborn/matplotlib that produces the chart>"
    },
    ...
  ]
}

Mandatory rules:
- Always produce syntactically valid JSON ONLY. No text outside the JSON object.
- Provide at least these chart types somewhere in recommended_charts: bar, pie, timeseries, histogram, scatter, multiple_columns, stacked_bar, heatmap.
- Use only column names that appear in meta['column_names'].
- The python_code string must be self-contained and runnable assuming a variable `df` exists containing the full cleaned DataFrame. Start the code with imports:
    import pandas as pd
    import seaborn as sns
    import matplotlib.pyplot as plt
  and include any necessary preprocessing steps (e.g., parsing dates).
- For timeseries charts ensure the datetime column is parsed (`pd.to_datetime`) before plotting.
- For multiple_columns provide a pairplot or facetgrid example that uses up to 4 numeric columns or sensible categorical splits.
- For stacked_bar, show aggregation code (groupby + unstack) and plotting with df.plot(kind='bar', stacked=True).
- For heatmap, compute correlation matrix and plot sns.heatmap with annotations.
- For pie charts, ensure grouping/aggregation when there are >20 unique categories (group small categories into 'Other').
- For histogram and scatter include axis labels and tight_layout; include plt.show() at the end.
- Keep code minimal but complete so a user can copy-paste and run (assume seaborn, matplotlib, pandas installed).
- For each chart add a sensible "columns" list showing which columns the code uses.
- Do not include examples using columns not present in meta.
- Do not include more than 10 recommended_charts.
- Ensure strings inside the JSON are escaped properly so the JSON parses.

Produce a concise natural-language one-line summary in "summary". Ensure JSON is parseable by json.loads in Python.
"""
        user_prompt = {
            "meta": meta,
            "fiverow": fiverow,
        }
        contents = [
            types.Content(
                role="user",
                parts=[types.Part.from_text(text=str(user_prompt))],
            ),
        ]
        generate_content_config = types.GenerateContentConfig(
            thinking_config=types.ThinkingConfig(thinking_budget=0),
            response_mime_type="application/json",
            system_instruction=[types.Part.from_text(text=system_prompt)],
        )

        # Stream the response and accumulate the JSON text
        response = ""
        for chunk in client.models.generate_content_stream(
            model=model,
            contents=contents,
            config=generate_content_config,
        ):
            if chunk.text:
                response += chunk.text
        return response
    except Exception as e:
        logger.error(f"Error generating summary: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error generating AI summary: {str(e)}")
def flatten_columns(df):
    """Flatten MultiIndex columns"""
    if isinstance(df.columns, pd.MultiIndex):
        df.columns = ['_'.join(map(str, col)).strip() for col in df.columns.values]
    return df
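
# Example: after an unstack(), a MultiIndex column ("revenue", "2023")
# becomes the single flat name "revenue_2023".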
def extract_chart_data_json_by_type(summary_json: str, df):
    """Extract chart data grouped by type"""
    try:
        data = json.loads(summary_json)
        result = {}
        for chart in data.get("recommended_charts", []):
            chart_type = chart.get("type")
            columns = chart.get("columns", [])
            title = chart.get("title", "unnamed_chart")
            if chart_type not in result:
                result[chart_type] = []
            try:
                if chart_type == "bar":
                    df_agg = df[columns].groupby(columns[0]).sum(numeric_only=True).reset_index()
                    chart_data = df_agg.to_dict(orient="records")
                elif chart_type == "stacked_bar":
                    df_agg = df.groupby(columns).sum(numeric_only=True).unstack()
                    df_agg = flatten_columns(df_agg)
                    # reset_index keeps the outer group labels in the records
                    chart_data = df_agg.fillna(0).reset_index().to_dict(orient="records")
                elif chart_type == "pie":
                    col = columns[0]
                    counts = df[col].value_counts()
                    if len(counts) > 20:
                        # Keep the 19 largest categories; lump the rest into 'Other'
                        top = counts.nlargest(19)
                        others = counts.iloc[19:].sum()
                        counts = pd.concat([top, pd.Series({'Other': others})])
                    # rename_axis/reset_index is stable across pandas versions
                    chart_data = counts.rename_axis(col).reset_index(name='value').to_dict(orient="records")
                elif chart_type == "histogram":
                    chart_data = df[columns[0]].dropna().tolist()
                elif chart_type == "scatter":
                    chart_data = df[columns].to_dict(orient="records")
                elif chart_type == "timeseries":
                    df_copy = df[columns].copy()
                    for c in columns:
                        # Only parse non-numeric columns as dates; numeric columns are values
                        if not pd.api.types.is_numeric_dtype(df_copy[c]):
                            df_copy[c] = pd.to_datetime(df_copy[c], errors='coerce')
                    chart_data = df_copy.astype(str).to_dict(orient="records")
                elif chart_type == "multiple_columns":
                    chart_data = df[columns].to_dict(orient="records")
                elif chart_type == "heatmap":
                    corr_df = df[columns].corr(numeric_only=True).fillna(0)
                    chart_data = flatten_columns(corr_df).to_dict()
                else:
                    chart_data = []
            except Exception as e:
                # Record per-chart failures without aborting the whole response
                chart_data = {"error": str(e)}
            result[chart_type].append({"title": title, "data": chart_data})
        return result
    except Exception as e:
        logger.error(f"Error extracting chart data: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error extracting chart data: {str(e)}")
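
# Shape of the result (illustrative): chart type -> list of chart payloads, e.g.
#   {"bar": [{"title": "Revenue by Region", "data": [...]}], "pie": [...], ...}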
# Liveness endpoints (paths "/" and "/health" assumed by convention).
@app.get("/")
async def root():
    return {"message": "Data Analysis API is running"}

@app.get("/health")
async def health_check():
    return {"status": "healthy"}
# File-upload analysis endpoint (path "/analyze" assumed from the API's purpose).
@app.post("/analyze")
async def analyze_data(file: UploadFile = File(...)):
    """
    Analyze uploaded CSV/Excel file and return AI-generated summary with chart recommendations
    """
    if not file.filename:
        raise HTTPException(status_code=400, detail="No file provided")

    # Check file type
    allowed_extensions = ['.csv', '.xls', '.xlsx']
    file_ext = os.path.splitext(file.filename)[-1].lower()
    if file_ext not in allowed_extensions:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported file type. Allowed: {', '.join(allowed_extensions)}"
        )

    # Create temporary file
    with tempfile.NamedTemporaryFile(delete=False, suffix=file_ext) as tmp_file:
        tmp_file_path = tmp_file.name
        try:
            # Save uploaded file to temporary location and flush before reading
            shutil.copyfileobj(file.file, tmp_file)
            tmp_file.flush()

            # Process the file
            df = load_file_from_upload(tmp_file_path, file.filename)
            df_clean = preprocess(df)

            # Generate metadata
            meta = get_metadata(df_clean)
            fiverow = df_clean.head(5).to_dict(orient="records")

            # Generate AI summary
            summary_json = generate_summary(meta, fiverow)
            summary_data = json.loads(summary_json)

            # Extract chart data by type
            chart_data = extract_chart_data_json_by_type(summary_json, df_clean)

            return AnalysisResponse(
                summary=summary_data,
                chart_data=chart_data,
                metadata=meta
            )
        except HTTPException:
            # Preserve the status codes raised by the helper functions
            raise
        except Exception as e:
            logger.error(f"Error processing file: {str(e)}")
            raise HTTPException(status_code=500, detail=str(e))
        finally:
            # Clean up temporary file
            try:
                os.unlink(tmp_file_path)
            except OSError:
                pass
# Handlers that keep error responses in the ErrorResponse JSON shape.
@app.exception_handler(HTTPException)
async def http_exception_handler(request, exc):
    return JSONResponse(
        status_code=exc.status_code,
        content={"error": exc.detail}
    )

@app.exception_handler(Exception)
async def general_exception_handler(request, exc):
    logger.error(f"Unhandled exception: {str(exc)}")
    return JSONResponse(
        status_code=500,
        content={"error": "Internal server error", "details": str(exc)}
    )
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
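
# Example request (assuming the /analyze route above; port 7860 as configured):
#   curl -X POST http://localhost:7860/analyze -F "file=@sales.csv"
# The response body contains "summary", "chart_data", and "metadata" keys.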