triflix committed on
Commit 167ca93 · verified · 1 Parent(s): c68da81

Create app.py

Files changed (1)
  1. app.py +468 -0
app.py ADDED
@@ -0,0 +1,468 @@
+ # fastapi_snapshot_app_improved.py
+ from fastapi import FastAPI, UploadFile, File, HTTPException, Query
+ from fastapi.middleware.cors import CORSMiddleware
+ from fastapi.responses import JSONResponse
+ import pandas as pd
+ import os
+ import json
+ import tempfile
+ from typing import Optional, List
+ from pydantic import BaseModel
+ from google import genai
+ from google.genai import types
+ import logging
+ import hashlib
+ import uuid
+ from datetime import datetime, timezone
+ import motor.motor_asyncio
+ import asyncio
+ from concurrent.futures import ThreadPoolExecutor
+
+ # ----------------------------
+ # Configuration
+ # ----------------------------
+ MONGO_URI = os.getenv("MONGO_URI", "mongodb://localhost:27017")
+ DB_NAME = os.getenv("DB_NAME", "data_analysis")
+ SNAPSHOT_BUCKET = os.getenv("SNAPSHOT_DIR", "/tmp/snapshots")
+ os.makedirs(SNAPSHOT_BUCKET, exist_ok=True)
+ MAX_UPLOAD_SIZE = int(os.getenv("MAX_UPLOAD_SIZE_BYTES", 200 * 1024 * 1024))  # 200MB default
+ METADATA_ONLY_FALLBACK = os.getenv("METADATA_ONLY_FALLBACK", "true").lower() == "true"
+ TTL_DAYS = int(os.getenv("SNAPSHOT_TTL_DAYS", "0"))  # 0 = no TTL
+
+ # Setup logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ # FastAPI app
+ app = FastAPI(title="Data Analysis API with Snapshotting", version="3.0.0")
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ # Mongo client (async)
+ mongo_client = motor.motor_asyncio.AsyncIOMotorClient(MONGO_URI)
+ db = mongo_client[DB_NAME]
+ snapshots = db.snapshots
+
+ # Thread pool for blocking tasks (AI calls, heavy pandas ops)
+ EXECUTOR = ThreadPoolExecutor(max_workers=int(os.getenv("EXECUTOR_WORKERS", "2")))
+
+
+ # ---------- Helpers ----------
+ def sha256_bytes(data: bytes) -> str:
+     h = hashlib.sha256()
+     h.update(data)
+     return h.hexdigest()
+
+
+ def sha256_text(text: str) -> str:
+     return sha256_bytes(text.encode("utf-8"))
+
+
+ def sha256_obj(obj) -> str:
+     text = json.dumps(obj, sort_keys=True, default=str)
+     return sha256_text(text)
+
+
+ def canonical_types(df: pd.DataFrame) -> dict:
+     def map_type(dtype):
+         if pd.api.types.is_integer_dtype(dtype) or pd.api.types.is_float_dtype(dtype):
+             return "numeric"
+         if pd.api.types.is_datetime64_any_dtype(dtype):
+             return "datetime"
+         return "object"
+     return {col: map_type(dtype) for col, dtype in df.dtypes.items()}
+
+
+ async def save_preprocessed_df(df: pd.DataFrame, snapshot_id: str) -> str:
+     path = os.path.join(SNAPSHOT_BUCKET, f"{snapshot_id}.csv")
+     # use pandas to_csv which is blocking; run in executor to avoid blocking event loop
+     loop = asyncio.get_running_loop()
+     await loop.run_in_executor(EXECUTOR, lambda: df.to_csv(path, index=False))
+     return path
+
+
+ def load_file_from_path(file_path: str, original_filename: str) -> pd.DataFrame:
+     ext = os.path.splitext(original_filename)[-1].lower()
+     if ext == ".csv":
+         # try common encodings; let pandas infer by default
+         return pd.read_csv(file_path)
+     elif ext in [".xls", ".xlsx"]:
+         return pd.read_excel(file_path, sheet_name=0)
+     else:
+         raise ValueError(f"Unsupported file type: {ext}")
+
+
+ def preprocess(df: pd.DataFrame, drop_thresh=0.5) -> pd.DataFrame:
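+     # Cleaning steps: normalize column names, drop columns that are mostly null,
+     # impute remaining nulls by dtype, coerce numeric-looking object columns,
+     # and drop duplicate rows.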
+     df = df.copy()
+     df.columns = [str(c).strip().lower().replace(" ", "_") for c in df.columns]
+     df = df.loc[:, df.isnull().mean() < drop_thresh]
+
+     for col in df.columns:
+         if pd.api.types.is_numeric_dtype(df[col]):
+             df.loc[:, col] = df[col].fillna(df[col].median())
+         elif pd.api.types.is_datetime64_any_dtype(df[col]):
+             df.loc[:, col] = df[col].fillna(pd.Timestamp('1970-01-01'))
+         else:
+             df.loc[:, col] = df[col].fillna("Unknown")
+
+     for col in df.columns:
+         if df[col].dtype == 'object':
+             try:
+                 df.loc[:, col] = pd.to_numeric(df[col])
+             except Exception:
+                 pass
+
+     df = df.drop_duplicates()
+     return df
+
+
+ def get_metadata(df: pd.DataFrame) -> dict:
+     return {
+         "rows": int(df.shape[0]),
+         "columns": int(df.shape[1]),
+         "column_names": list(df.columns),
+         "column_types": {col: str(dtype) for col, dtype in df.dtypes.items()},
+         "unique_values": {col: int(df[col].nunique()) for col in df.columns}
+     }
+
+
+ def data_fingerprint(df: pd.DataFrame, n_sample_rows: int = 100) -> str:
+     # Deterministic fingerprint: canonical column order, sample head & tail JSON + aggregated stats
+     df2 = df.copy()
+     df2 = df2.reindex(sorted(df2.columns), axis=1)
+     head = df2.head(n_sample_rows).to_json(orient="split", date_format="iso", force_ascii=False)
+     tail = df2.tail(n_sample_rows).to_json(orient="split", date_format="iso", force_ascii=False)
+     col_aggs = {c: {"nunique": int(df2[c].nunique()), "nulls": int(df2[c].isnull().sum())} for c in df2.columns}
+     text = head + tail + json.dumps(col_aggs, sort_keys=True, default=str)
+     return hashlib.sha256(text.encode("utf-8")).hexdigest()
+
+
+ def stream_save_and_hash(upload_file: UploadFile, tmp_path: str, size_limit: Optional[int] = None) -> str:
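+     # Stream the upload to disk in 8 KB chunks, hashing as we go and enforcing
+     # the optional size limit so the whole file never has to fit in memory.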
+     h = hashlib.sha256()
+     total = 0
+     with open(tmp_path, "wb") as f:
+         while True:
+             chunk = upload_file.file.read(8192)
+             if not chunk:
+                 break
+             f.write(chunk)
+             h.update(chunk)
+             total += len(chunk)
+             if size_limit and total > size_limit:
+                 raise HTTPException(status_code=413, detail="Uploaded file exceeds maximum allowed size")
+     return h.hexdigest()
+
+
+ # ---------- AI interaction (blocking) ----------
+ def generate_summary_blocking(meta, fiverow) -> str:
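+     # Streams the Gemini response, concatenates the chunks, and verifies that the
+     # result parses as valid JSON before returning it to the caller.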
+     api_key = os.getenv("GEMINI_API_KEY")
+     if not api_key:
+         raise RuntimeError("GEMINI_API_KEY not set")
+     client = genai.Client(api_key=api_key)
+     model = "gemini-2.5-flash-lite"
+     system_prompt = """
+     You are a strict JSON generator.
+     Input contains:
+     - meta: dataframe metadata
+     - fiverow: first 5 records of dataframe
+     You must output JSON with the following structure:
+     { "summary": "<short natural language overview>", "recommended_charts": [ ... ] }
+     Always produce syntactically valid JSON ONLY.
+     """
+     user_prompt = {"meta": meta, "fiverow": fiverow}
+     contents = [
+         types.Content(
+             role="user",
+             parts=[types.Part.from_text(text=str(user_prompt))],
+         ),
+     ]
+     generate_content_config = types.GenerateContentConfig(
+         thinking_config=types.ThinkingConfig(thinking_budget=0),
+         response_mime_type="application/json",
+         system_instruction=[types.Part.from_text(text=system_prompt)],
+     )
+     response = ""
+     for chunk in client.models.generate_content_stream(
+         model=model,
+         contents=contents,
+         config=generate_content_config,
+     ):
+         if chunk.text:
+             response += chunk.text
+     try:
+         _ = json.loads(response)
+     except Exception as e:
+         logger.error("AI returned invalid JSON: %s", str(e))
+         raise RuntimeError("AI returned invalid JSON")
+     return response
+
+
+ async def generate_summary_async(meta, fiverow) -> str:
+     loop = asyncio.get_running_loop()
+     return await loop.run_in_executor(EXECUTOR, generate_summary_blocking, meta, fiverow)
+
+
+ # ---------- API Models ----------
+ class DrillRequest(BaseModel):
+     snapshot_id: str
+     filter_column: str
+     filter_value: str
+     limit: Optional[int] = 100
+     offset: Optional[int] = 0
+     highlight_columns: Optional[List[str]] = None
+
+
+ # ---------- Startup: indexes ----------
+ @app.on_event("startup")
+ async def create_indexes():
+     try:
+         await snapshots.create_index("file_hash")
+         await snapshots.create_index("data_hash")
+         await snapshots.create_index("meta_hash")
+         await snapshots.create_index("snapshot_id", unique=True)
+         if TTL_DAYS > 0:
+             await snapshots.create_index("created_at_dt", expireAfterSeconds=TTL_DAYS * 24 * 3600)
+         logger.info("Indexes ensured on snapshots collection")
+     except Exception:
+         logger.exception("Error creating indexes")
+
+
+ # ---------- Routes ----------
+ @app.get("/")
+ async def root():
+     return {"message": "Data Analysis API with snapshotting is running"}
+
+
+ @app.post("/analyze")
+ async def analyze(file: UploadFile = File(...)):
+     if not file.filename:
+         raise HTTPException(status_code=400, detail="No file provided")
+     allowed_extensions = ['.csv', '.xls', '.xlsx']
+     file_ext = os.path.splitext(file.filename)[-1].lower()
+     if file_ext not in allowed_extensions:
+         raise HTTPException(status_code=400, detail=f"Unsupported file type. Allowed: {', '.join(allowed_extensions)}")
+
+     # stream save + file hash (prevents OOM)
+     with tempfile.NamedTemporaryFile(delete=False, suffix=file_ext) as tmp_file:
+         tmp_path = tmp_file.name
+     try:
+         file_hash = stream_save_and_hash(file, tmp_path, size_limit=MAX_UPLOAD_SIZE)
+     except HTTPException:
+         try:
+             os.unlink(tmp_path)
+         except Exception:
+             pass
+         raise
+     except Exception as e:
+         try:
+             os.unlink(tmp_path)
+         except Exception:
+             pass
+         logger.exception("Error saving uploaded file")
+         raise HTTPException(status_code=500, detail=str(e))
+
+     try:
+         # load and preprocess (blocking; small files ok). For very large files, consider streaming/parsing.
+         df = load_file_from_path(tmp_path, file.filename)
+         df_clean = preprocess(df)
+         meta = get_metadata(df_clean)
+         fiverow = df_clean.head(5).to_dict(orient="records")
+
+         # compute hashes: data_hash first, canonical meta_hash
+         data_hash = data_fingerprint(df_clean)
+         meta_hash = sha256_obj({
+             "rows": meta["rows"],
+             "columns": meta["columns"],
+             "column_names": meta["column_names"],
+             "column_types": canonical_types(df_clean),
+         })
+
+         # search order: exact file -> data_hash -> meta_hash (if allowed)
+         existing = await snapshots.find_one({"file_hash": file_hash})
+         cache_hit = None
+         if not existing:
+             existing = await snapshots.find_one({"data_hash": data_hash})
+             if existing:
+                 cache_hit = "data"
+         if not existing and METADATA_ONLY_FALLBACK:
+             existing = await snapshots.find_one({"meta_hash": meta_hash})
+             if existing:
+                 cache_hit = "meta"
+
+         if existing:
+             # return consistent snapshot_id
+             snapshot_id_return = existing.get("snapshot_id") or str(existing.get("_id"))
+             return {
+                 "id": snapshot_id_return,
+                 "summary": existing.get("summary"),
+                 "chart_data": existing.get("chart_data"),
+                 "metadata": existing.get("metadata"),
+                 "created_at": existing.get("created_at"),
+                 "cached": True,
+                 "cache_hit": cache_hit or "file",
+             }
+
+         # Not found -> create processing snapshot doc with status
+         snapshot_id = uuid.uuid4().hex
+         created_at_iso = datetime.now(timezone.utc).isoformat()
+         created_at_dt = datetime.now(timezone.utc)
+
+         doc = {
+             "snapshot_id": snapshot_id,
+             "filename": file.filename,
+             "file_hash": file_hash,
+             "data_hash": data_hash,
+             "meta_hash": meta_hash,
+             "metadata": meta,
+             "summary": None,
+             "chart_data": None,
+             "preprocessed_path": None,
+             "status": "processing",
+             "created_at": created_at_iso,
+             "created_at_dt": created_at_dt,
+         }
+         await snapshots.insert_one(doc)
+
+         # Generate summary (blocking AI call offloaded to executor)
+         try:
+             summary_json = await generate_summary_async(meta, fiverow)
+             summary_obj = json.loads(summary_json)
+             chart_data = summary_obj.get("recommended_charts")
+         except Exception as e:
+             await snapshots.update_one({"snapshot_id": snapshot_id}, {"$set": {"status": "failed", "error": str(e)}})
+             raise
+
+         # save preprocessed csv for drilling and later retrieval (non-blocking via executor)
+         preprocessed_path = await save_preprocessed_df(df_clean, snapshot_id)
+
+         # finalize doc
+         await snapshots.update_one(
+             {"snapshot_id": snapshot_id},
+             {"$set": {
+                 "summary": summary_obj,
+                 "chart_data": chart_data,
+                 "preprocessed_path": preprocessed_path,
+                 "status": "done",
+                 "completed_at": datetime.now(timezone.utc).isoformat()
+             }}
+         )
+
+         return {
+             "id": snapshot_id,
+             "summary": summary_obj,
+             "chart_data": chart_data,
+             "metadata": meta,
+             "created_at": created_at_iso,
+             "cached": False,
+         }
+
+     except HTTPException:
+         raise
+     except Exception as e:
+         logger.exception("Error processing file")
+         raise HTTPException(status_code=500, detail=str(e))
+     finally:
+         try:
+             os.unlink(tmp_path)
+         except Exception:
+             pass
+
+
+ @app.get("/snapshots")
+ async def list_snapshots(limit: int = Query(20, ge=1, le=100), offset: int = Query(0, ge=0)):
+     cursor = snapshots.find({}, {"preprocessed_path": 0, "summary": 0, "chart_data": 0}).sort("created_at_dt", -1).skip(offset).limit(limit)
+     items = []
+     async for doc in cursor:
+         items.append({
+             "id": doc.get("snapshot_id") or str(doc.get("_id")),
+             "filename": doc.get("filename"),
+             "metadata": doc.get("metadata"),
+             "status": doc.get("status"),
+             "created_at": doc.get("created_at"),
+         })
+     return {"count": len(items), "items": items}
+
+
+ @app.get("/snapshot/{snapshot_id}")
+ async def get_snapshot(snapshot_id: str):
+     doc = await snapshots.find_one({"snapshot_id": snapshot_id})
+     if not doc:
+         raise HTTPException(status_code=404, detail="Snapshot not found")
+     return {
+         "id": doc["snapshot_id"],
+         "filename": doc.get("filename"),
+         "metadata": doc.get("metadata"),
+         "summary": doc.get("summary"),
+         "chart_data": doc.get("chart_data"),
+         "status": doc.get("status"),
+         "created_at": doc.get("created_at"),
+     }
+
+
+ @app.get("/preprocessed/{snapshot_id}")
+ async def get_preprocessed(snapshot_id: str, limit: int = 100, offset: int = 0):
+     doc = await snapshots.find_one({"snapshot_id": snapshot_id})
+     if not doc:
+         raise HTTPException(status_code=404, detail="Snapshot not found")
+     path = doc.get("preprocessed_path")
+     if not path or not os.path.exists(path):
+         raise HTTPException(status_code=404, detail="Preprocessed data not available")
+     df = pd.read_csv(path)
+     total = len(df)
+     rows = df.iloc[offset: offset + limit].to_dict(orient="records")
+     return {"total": total, "offset": offset, "limit": limit, "rows": rows}
+
+
+ @app.post("/drill")
+ async def drill(req: DrillRequest):
+     doc = await snapshots.find_one({"snapshot_id": req.snapshot_id})
+     if not doc:
+         raise HTTPException(status_code=404, detail="Snapshot not found")
+     path = doc.get("preprocessed_path")
+     if not path or not os.path.exists(path):
+         raise HTTPException(status_code=404, detail="Preprocessed data not available")
+     df = pd.read_csv(path)
+     if req.filter_column not in df.columns:
+         raise HTTPException(status_code=400, detail=f"Column {req.filter_column} not found in preprocessed data")
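+     # Compare against the column's native dtype first; if nothing matches (or the
+     # comparison fails), fall back to comparing both sides as strings.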
+     try:
+         filtered = df[df[req.filter_column] == req.filter_value]
+         if filtered.empty:
+             filtered = df[df[req.filter_column].astype(str) == str(req.filter_value)]
+     except Exception:
+         filtered = df[df[req.filter_column].astype(str) == str(req.filter_value)]
+     total = len(filtered)
+     rows = filtered.iloc[req.offset: req.offset + req.limit].to_dict(orient="records")
+     highlights = req.highlight_columns or [req.filter_column]
+     highlights = [c for c in highlights if c in df.columns]
+     return {
+         "snapshot_id": req.snapshot_id,
+         "filter_column": req.filter_column,
+         "filter_value": req.filter_value,
+         "total_matches": total,
+         "offset": req.offset,
+         "limit": req.limit,
+         "rows": rows,
+         "highlight_columns": highlights,
+     }
+
+
+ # Global exception handlers
+ @app.exception_handler(HTTPException)
+ async def http_exception_handler(request, exc):
+     return JSONResponse(status_code=exc.status_code, content={"error": exc.detail})
+
+
+ @app.exception_handler(Exception)
+ async def general_exception_handler(request, exc):
+     logger.exception("Unhandled exception")
+     return JSONResponse(status_code=500, content={"error": "Internal server error", "details": str(exc)})
+
+
+ if __name__ == "__main__":
+     import uvicorn
+     uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "7860")))
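
For context, a minimal client sketch (not part of the committed file) showing how the endpoints above could be exercised. It assumes the API is running locally on the default port 7860 and that the requests package is installed; the input file name and the drill column/value are placeholders.

# example_client.py - illustrative sketch only; not part of this commit
import requests

BASE = "http://localhost:7860"  # assumes the default PORT from app.py

# Upload a spreadsheet for analysis; the response carries the snapshot id,
# the AI-generated summary, and whether a cached snapshot was reused.
with open("sales.csv", "rb") as f:  # hypothetical input file
    resp = requests.post(f"{BASE}/analyze", files={"file": ("sales.csv", f, "text/csv")})
resp.raise_for_status()
snapshot = resp.json()
print(snapshot["id"], snapshot["cached"])

# Drill into the preprocessed rows for one column/value pair.
drill = requests.post(f"{BASE}/drill", json={
    "snapshot_id": snapshot["id"],
    "filter_column": "region",   # hypothetical column name
    "filter_value": "EMEA",
    "limit": 10,
})
drill.raise_for_status()
print(drill.json()["total_matches"])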