triflix committed on
Commit
1befb1d
·
verified ·
1 Parent(s): eecfd06

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -10
app.py CHANGED
@@ -1,4 +1,4 @@
1
-
2
  from fastapi import FastAPI, UploadFile, File, HTTPException, Query
3
  from fastapi.middleware.cors import CORSMiddleware
4
  from fastapi.responses import JSONResponse
@@ -16,6 +16,7 @@ from google.genai import types
16
  import logging
17
  import asyncio
18
  from concurrent.futures import ThreadPoolExecutor
 
19
  import re
20
  import traceback
21
  import motor.motor_asyncio
@@ -116,7 +117,6 @@ def _find_balanced_json(s: str):
116
  stack.append('{')
117
  elif ch == '}':
118
  if not stack:
119
- # malformed, but continue
120
  return None
121
  stack.pop()
122
  if not stack:
@@ -125,7 +125,6 @@ def _find_balanced_json(s: str):
125
 
126
 
127
  def _escape_problematic_backslashes(s: str) -> str:
128
- # Escape backslashes that are not followed by valid JSON escape char
129
  return re.sub(r'\\(?!["\\/bfnrtu])', r'\\\\', s)
130
 
131
 
@@ -189,8 +188,8 @@ def stream_save_and_hash(upload_file: UploadFile, tmp_path: str, size_limit: Opt
189
  async def save_preprocessed_df(df: pd.DataFrame, snapshot_id: str) -> str:
190
  path = os.path.join(SNAPSHOT_BUCKET, f"{snapshot_id}.csv")
191
  loop = asyncio.get_running_loop()
192
- # df.to_csv(path, index=False) is blocking -> run in executor
193
- await loop.run_in_executor(EXECUTOR, df.to_csv, path, False, None)
194
  return path
195
 
196
 
@@ -240,7 +239,8 @@ def get_metadata(df: pd.DataFrame) -> dict:
240
 
241
  # ---------- AI generation (blocking) ----------
242
  def generate_summary_blocking(meta, fiverow, system_prompt_override: Optional[str] = None):
243
- api_key = "AIzaSyB1jgGCuzg7ELPwNEEwaluQZoZhxhgLmAs"
 
244
  if not api_key:
245
  raise RuntimeError("GEMINI_API_KEY environment variable not set")
246
  client = genai.Client(api_key=api_key)
@@ -254,9 +254,37 @@ Input contains:
254
  You must output JSON with the following structure:
255
  {
256
  "summary": "<short natural language overview of dataset>",
257
- "recommended_charts": [ ... ]
 
 
 
 
 
 
 
 
258
  }
259
- Always produce syntactically valid JSON ONLY.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
260
  """
261
 
262
  user_prompt = {"meta": meta, "fiverow": fiverow}
@@ -344,7 +372,6 @@ def extract_chart_data_json_by_type(parsed_summary: dict, df: pd.DataFrame):
344
  elif chart_type == "timeseries":
345
  df_copy = df[columns].copy()
346
  for c in columns:
347
- # attempt parse only if not already datetime dtype
348
  if not pd.api.types.is_datetime64_any_dtype(df_copy[c]):
349
  df_copy[c] = pd.to_datetime(df_copy[c], errors='coerce')
350
  chart_data = df_copy.astype(str).to_dict(orient="records")
@@ -460,7 +487,6 @@ async def analyze_data(file: UploadFile = File(...)):
460
 
461
  if existing:
462
  snapshot_id_return = existing.get("snapshot_id") or str(existing.get("_id"))
463
- # ensure structure matches response_model
464
  summary = existing.get("summary") or {}
465
  chart_data = existing.get("chart_data") or {}
466
  metadata = existing.get("metadata") or meta
 
1
+ # app.py
2
  from fastapi import FastAPI, UploadFile, File, HTTPException, Query
3
  from fastapi.middleware.cors import CORSMiddleware
4
  from fastapi.responses import JSONResponse
 
16
  import logging
17
  import asyncio
18
  from concurrent.futures import ThreadPoolExecutor
19
+ from functools import partial
20
  import re
21
  import traceback
22
  import motor.motor_asyncio
 
117
  stack.append('{')
118
  elif ch == '}':
119
  if not stack:
 
120
  return None
121
  stack.pop()
122
  if not stack:
 
125
 
126
 
127
  def _escape_problematic_backslashes(s: str) -> str:
 
128
  return re.sub(r'\\(?!["\\/bfnrtu])', r'\\\\', s)
129
 
130
 
 
188
  async def save_preprocessed_df(df: pd.DataFrame, snapshot_id: str) -> str:
189
  path = os.path.join(SNAPSHOT_BUCKET, f"{snapshot_id}.csv")
190
  loop = asyncio.get_running_loop()
191
+ # Use functools.partial to pass keyword args to to_csv (handles pandas 3.0+ keyword-only changes)
192
+ await loop.run_in_executor(EXECUTOR, partial(df.to_csv, path, index=False))
193
  return path
194
 
195
 
 
239
 
240
  # ---------- AI generation (blocking) ----------
241
  def generate_summary_blocking(meta, fiverow, system_prompt_override: Optional[str] = None):
242
+ # using provided API key (kept as-is per deployment)
243
+ api_key = os.getenv("GEMINI_API_KEY") or "AIzaSyB1jgGCuzg7ELPwNEEwaluQZoZhxhgLmAs"
244
  if not api_key:
245
  raise RuntimeError("GEMINI_API_KEY environment variable not set")
246
  client = genai.Client(api_key=api_key)
 
254
  You must output JSON with the following structure:
255
  {
256
  "summary": "<short natural language overview of dataset>",
257
+ "recommended_charts": [
258
+ {
259
+ "type": "<one of: bar, pie, timeseries, histogram, scatter, multiple_columns, stacked_bar, heatmap>",
260
+ "title": "<short title for chart>",
261
+ "columns": ["<col1>", "<col2>", "..."],
262
+ "python_code": "<full runnable Python code using seaborn/matplotlib that produces the chart>"
263
+ },
264
+ ...
265
+ ]
266
  }
267
+ Mandatory rules:
268
+ - Always produce syntactically valid JSON ONLY. No text outside the JSON object.
269
+ - Provide at least these chart types somewhere in recommended_charts: bar, pie, timeseries, histogram, scatter, multiple_columns, stacked_bar, heatmap.
270
+ - Use only column names that appear in meta['column_names'].
271
+ - The python_code string must be self-contained and runnable assuming a variable `df` exists containing the full cleaned DataFrame. Start the code with imports:
272
+ import pandas as pd
273
+ import seaborn as sns
274
+ import matplotlib.pyplot as plt
275
+ and include any necessary preprocessing steps (e.g., parsing dates).
276
+ - For timeseries charts ensure the datetime column is parsed (`pd.to_datetime`) before plotting.
277
+ - For multiple_columns provide a pairplot or facetgrid example that uses up to 4 numeric columns or sensible categorical splits.
278
+ - For stacked_bar, show aggregation code (groupby + unstack) and plotting with df.plot(kind='bar', stacked=True).
279
+ - For heatmap, compute correlation matrix and plot sns.heatmap with annotations.
280
+ - For pie charts, ensure grouping/aggregation when there are >20 unique categories (group small categories into 'Other').
281
+ - For histogram and scatter include axis labels and tight_layout; include plt.show() at the end.
282
+ - Keep code minimal but complete so a user can copy-paste and run (assume seaborn, matplotlib, pandas installed).
283
+ - For each chart add a sensible "columns" list showing which columns the code uses.
284
+ - Do not include examples using columns not present in meta.
285
+ - Do not include more than 10 recommended_charts.
286
+ - Ensure strings inside the JSON are escaped properly so the JSON parses.
287
+ Produce concise natural-language one-line summary in "summary". Ensure JSON is parseable by json.loads in Python.
288
  """
289
 
290
  user_prompt = {"meta": meta, "fiverow": fiverow}
 
372
  elif chart_type == "timeseries":
373
  df_copy = df[columns].copy()
374
  for c in columns:
 
375
  if not pd.api.types.is_datetime64_any_dtype(df_copy[c]):
376
  df_copy[c] = pd.to_datetime(df_copy[c], errors='coerce')
377
  chart_data = df_copy.astype(str).to_dict(orient="records")
 
487
 
488
  if existing:
489
  snapshot_id_return = existing.get("snapshot_id") or str(existing.get("_id"))
 
490
  summary = existing.get("summary") or {}
491
  chart_data = existing.get("chart_data") or {}
492
  metadata = existing.get("metadata") or meta