Update app.py
Browse files
app.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
|
| 2 |
from fastapi import FastAPI, UploadFile, File, HTTPException, Query
|
| 3 |
from fastapi.middleware.cors import CORSMiddleware
|
| 4 |
from fastapi.responses import JSONResponse
|
|
@@ -16,6 +16,7 @@ from google.genai import types
|
|
| 16 |
import logging
|
| 17 |
import asyncio
|
| 18 |
from concurrent.futures import ThreadPoolExecutor
|
|
|
|
| 19 |
import re
|
| 20 |
import traceback
|
| 21 |
import motor.motor_asyncio
|
|
@@ -116,7 +117,6 @@ def _find_balanced_json(s: str):
|
|
| 116 |
stack.append('{')
|
| 117 |
elif ch == '}':
|
| 118 |
if not stack:
|
| 119 |
-
# malformed, but continue
|
| 120 |
return None
|
| 121 |
stack.pop()
|
| 122 |
if not stack:
|
|
@@ -125,7 +125,6 @@ def _find_balanced_json(s: str):
|
|
| 125 |
|
| 126 |
|
| 127 |
def _escape_problematic_backslashes(s: str) -> str:
|
| 128 |
-
# Escape backslashes that are not followed by valid JSON escape char
|
| 129 |
return re.sub(r'\\(?!["\\/bfnrtu])', r'\\\\', s)
|
| 130 |
|
| 131 |
|
|
@@ -189,8 +188,8 @@ def stream_save_and_hash(upload_file: UploadFile, tmp_path: str, size_limit: Opt
|
|
| 189 |
async def save_preprocessed_df(df: pd.DataFrame, snapshot_id: str) -> str:
|
| 190 |
path = os.path.join(SNAPSHOT_BUCKET, f"{snapshot_id}.csv")
|
| 191 |
loop = asyncio.get_running_loop()
|
| 192 |
-
#
|
| 193 |
-
await loop.run_in_executor(EXECUTOR, df.to_csv, path, False
|
| 194 |
return path
|
| 195 |
|
| 196 |
|
|
@@ -240,7 +239,8 @@ def get_metadata(df: pd.DataFrame) -> dict:
|
|
| 240 |
|
| 241 |
# ---------- AI generation (blocking) ----------
|
| 242 |
def generate_summary_blocking(meta, fiverow, system_prompt_override: Optional[str] = None):
|
| 243 |
-
|
|
|
|
| 244 |
if not api_key:
|
| 245 |
raise RuntimeError("GEMINI_API_KEY environment variable not set")
|
| 246 |
client = genai.Client(api_key=api_key)
|
|
@@ -254,9 +254,37 @@ Input contains:
|
|
| 254 |
You must output JSON with the following structure:
|
| 255 |
{
|
| 256 |
"summary": "<short natural language overview of dataset>",
|
| 257 |
-
"recommended_charts": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 258 |
}
|
| 259 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 260 |
"""
|
| 261 |
|
| 262 |
user_prompt = {"meta": meta, "fiverow": fiverow}
|
|
@@ -344,7 +372,6 @@ def extract_chart_data_json_by_type(parsed_summary: dict, df: pd.DataFrame):
|
|
| 344 |
elif chart_type == "timeseries":
|
| 345 |
df_copy = df[columns].copy()
|
| 346 |
for c in columns:
|
| 347 |
-
# attempt parse only if not already datetime dtype
|
| 348 |
if not pd.api.types.is_datetime64_any_dtype(df_copy[c]):
|
| 349 |
df_copy[c] = pd.to_datetime(df_copy[c], errors='coerce')
|
| 350 |
chart_data = df_copy.astype(str).to_dict(orient="records")
|
|
@@ -460,7 +487,6 @@ async def analyze_data(file: UploadFile = File(...)):
|
|
| 460 |
|
| 461 |
if existing:
|
| 462 |
snapshot_id_return = existing.get("snapshot_id") or str(existing.get("_id"))
|
| 463 |
-
# ensure structure matches response_model
|
| 464 |
summary = existing.get("summary") or {}
|
| 465 |
chart_data = existing.get("chart_data") or {}
|
| 466 |
metadata = existing.get("metadata") or meta
|
|
|
|
| 1 |
+
# app.py
|
| 2 |
from fastapi import FastAPI, UploadFile, File, HTTPException, Query
|
| 3 |
from fastapi.middleware.cors import CORSMiddleware
|
| 4 |
from fastapi.responses import JSONResponse
|
|
|
|
| 16 |
import logging
|
| 17 |
import asyncio
|
| 18 |
from concurrent.futures import ThreadPoolExecutor
|
| 19 |
+
from functools import partial
|
| 20 |
import re
|
| 21 |
import traceback
|
| 22 |
import motor.motor_asyncio
|
|
|
|
| 117 |
stack.append('{')
|
| 118 |
elif ch == '}':
|
| 119 |
if not stack:
|
|
|
|
| 120 |
return None
|
| 121 |
stack.pop()
|
| 122 |
if not stack:
|
|
|
|
| 125 |
|
| 126 |
|
| 127 |
def _escape_problematic_backslashes(s: str) -> str:
|
|
|
|
| 128 |
return re.sub(r'\\(?!["\\/bfnrtu])', r'\\\\', s)
|
| 129 |
|
| 130 |
|
|
|
|
| 188 |
async def save_preprocessed_df(df: pd.DataFrame, snapshot_id: str) -> str:
    """Write ``df`` as ``<snapshot_id>.csv`` under ``SNAPSHOT_BUCKET``.

    The blocking ``DataFrame.to_csv`` call is handed to the module's
    ``EXECUTOR`` through the running event loop, so the coroutine awaits the
    write instead of performing it inline.

    Args:
        df: The DataFrame to persist.
        snapshot_id: Identifier used as the CSV file's base name.

    Returns:
        The path of the CSV file that was written.
    """
    destination = os.path.join(SNAPSHOT_BUCKET, f"{snapshot_id}.csv")
    # ``index`` is bound by keyword via partial because run_in_executor only
    # forwards positional arguments to the callable.
    write_csv = partial(df.to_csv, destination, index=False)
    await asyncio.get_running_loop().run_in_executor(EXECUTOR, write_csv)
    return destination
|
| 194 |
|
| 195 |
|
|
|
|
| 239 |
|
| 240 |
# ---------- AI generation (blocking) ----------
|
| 241 |
def generate_summary_blocking(meta, fiverow, system_prompt_override: Optional[str] = None):
|
| 242 |
+
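# SECURITY(review): the literal fallback on the next statement is a committed credential.
# Read the key only from the GEMINI_API_KEY environment variable and rotate the exposed key;
# with the fallback in place the `if not api_key` guard below can never trigger.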
|
| 243 |
+
api_key = os.getenv("GEMINI_API_KEY") or "AIzaSyB1jgGCuzg7ELPwNEEwaluQZoZhxhgLmAs"
|
| 244 |
if not api_key:
|
| 245 |
raise RuntimeError("GEMINI_API_KEY environment variable not set")
|
| 246 |
client = genai.Client(api_key=api_key)
|
|
|
|
| 254 |
You must output JSON with the following structure:
|
| 255 |
{
|
| 256 |
"summary": "<short natural language overview of dataset>",
|
| 257 |
+
"recommended_charts": [
|
| 258 |
+
{
|
| 259 |
+
"type": "<one of: bar, pie, timeseries, histogram, scatter, multiple_columns, stacked_bar, heatmap>",
|
| 260 |
+
"title": "<short title for chart>",
|
| 261 |
+
"columns": ["<col1>", "<col2>", "..."],
|
| 262 |
+
"python_code": "<full runnable Python code using seaborn/matplotlib that produces the chart>"
|
| 263 |
+
},
|
| 264 |
+
...
|
| 265 |
+
]
|
| 266 |
}
|
| 267 |
+
Mandatory rules:
|
| 268 |
+
- Always produce syntactically valid JSON ONLY. No text outside the JSON object.
|
| 269 |
+
- Provide at least these chart types somewhere in recommended_charts: bar, pie, timeseries, histogram, scatter, multiple_columns, stacked_bar, heatmap.
|
| 270 |
+
- Use only column names that appear in meta['column_names'].
|
| 271 |
+
- The python_code string must be self-contained and runnable assuming a variable `df` exists containing the full cleaned DataFrame. Start the code with imports:
|
| 272 |
+
import pandas as pd
|
| 273 |
+
import seaborn as sns
|
| 274 |
+
import matplotlib.pyplot as plt
|
| 275 |
+
and include any necessary preprocessing steps (e.g., parsing dates).
|
| 276 |
+
- For timeseries charts ensure the datetime column is parsed (`pd.to_datetime`) before plotting.
|
| 277 |
+
- For multiple_columns provide a pairplot or facetgrid example that uses up to 4 numeric columns or sensible categorical splits.
|
| 278 |
+
- For stacked_bar, show aggregation code (groupby + unstack) and plotting with df.plot(kind='bar', stacked=True).
|
| 279 |
+
- For heatmap, compute correlation matrix and plot sns.heatmap with annotations.
|
| 280 |
+
- For pie charts, ensure grouping/aggregation when there are >20 unique categories (group small categories into 'Other').
|
| 281 |
+
- For histogram and scatter include axis labels and tight_layout; include plt.show() at the end.
|
| 282 |
+
- Keep code minimal but complete so a user can copy-paste and run (assume seaborn, matplotlib, pandas installed).
|
| 283 |
+
- For each chart add a sensible "columns" list showing which columns the code uses.
|
| 284 |
+
- Do not include examples using columns not present in meta.
|
| 285 |
+
- Do not include more than 10 recommended_charts.
|
| 286 |
+
- Ensure strings inside the JSON are escaped properly so the JSON parses.
|
| 287 |
+
Produce concise natural-language one-line summary in "summary". Ensure JSON is parseable by json.loads in Python.
|
| 288 |
"""
|
| 289 |
|
| 290 |
user_prompt = {"meta": meta, "fiverow": fiverow}
|
|
|
|
| 372 |
elif chart_type == "timeseries":
|
| 373 |
df_copy = df[columns].copy()
|
| 374 |
for c in columns:
|
|
|
|
| 375 |
if not pd.api.types.is_datetime64_any_dtype(df_copy[c]):
|
| 376 |
df_copy[c] = pd.to_datetime(df_copy[c], errors='coerce')
|
| 377 |
chart_data = df_copy.astype(str).to_dict(orient="records")
|
|
|
|
| 487 |
|
| 488 |
if existing:
|
| 489 |
snapshot_id_return = existing.get("snapshot_id") or str(existing.get("_id"))
|
|
|
|
| 490 |
summary = existing.get("summary") or {}
|
| 491 |
chart_data = existing.get("chart_data") or {}
|
| 492 |
metadata = existing.get("metadata") or meta
|