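"""FastAPI service for AI-assisted dataset analysis.

Accepts CSV/Excel uploads, cleans them with pandas, asks Gemini for a JSON
summary plus chart recommendations, extracts chart-ready data, and caches
results as snapshots in MongoDB (via motor).
"""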
from fastapi import FastAPI, UploadFile, File, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
import pandas as pd
import os
import json
import tempfile
import hashlib
import uuid
from datetime import datetime, timezone
from typing import Optional, List
from pydantic import BaseModel
from google import genai
from google.genai import types
import logging
import asyncio
from concurrent.futures import ThreadPoolExecutor
from functools import partial
import re
import traceback
import motor.motor_asyncio

# Credentials must come from the environment; never hardcode them in source.
MONGO_URI = os.getenv("MONGO_URI", "mongodb://localhost:27017")
DB_NAME = os.getenv("DB_NAME", "data_analysis")
SNAPSHOT_BUCKET = os.getenv("SNAPSHOT_DIR", "/tmp/snapshots")
os.makedirs(SNAPSHOT_BUCKET, exist_ok=True)
MAX_UPLOAD_SIZE = int(os.getenv("MAX_UPLOAD_SIZE_BYTES", 200 * 1024 * 1024))
METADATA_ONLY_FALLBACK = os.getenv("METADATA_ONLY_FALLBACK", "true").lower() == "true"
TTL_DAYS = int(os.getenv("SNAPSHOT_TTL_DAYS", "0"))
EXECUTOR_WORKERS = int(os.getenv("EXECUTOR_WORKERS", "2"))

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("app")

app = FastAPI(title="Data Analysis API", version="1.0.0")
# Note: browsers reject a wildcard origin for credentialed requests; restrict
# allow_origins for production use.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

mongo_client = motor.motor_asyncio.AsyncIOMotorClient(MONGO_URI)
db = mongo_client[DB_NAME]
snapshots = db.snapshots

# Dedicated pool for blocking work (Gemini calls, CSV writes) so the event
# loop stays responsive.
EXECUTOR = ThreadPoolExecutor(max_workers=EXECUTOR_WORKERS)

class AnalysisResponse(BaseModel):
    summary: dict
    chart_data: dict
    metadata: dict


class ErrorResponse(BaseModel):
    error: str
    details: Optional[str] = None


class DrillRequest(BaseModel):
    snapshot_id: str
    filter_column: str
    filter_value: str
    limit: Optional[int] = 100
    offset: Optional[int] = 0
    highlight_columns: Optional[List[str]] = None
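# Example /drill payload (column and value are hypothetical, for illustration):
#   {"snapshot_id": "abc123", "filter_column": "region", "filter_value": "EMEA"}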
def sha256_bytes(data: bytes) -> str:
    h = hashlib.sha256()
    h.update(data)
    return h.hexdigest()


def sha256_text(text: str) -> str:
    return sha256_bytes(text.encode("utf-8"))


def sha256_obj(obj) -> str:
    text = json.dumps(obj, sort_keys=True, default=str)
    return sha256_text(text)

def canonical_types(df: pd.DataFrame) -> dict:
    def map_type(dtype):
        if pd.api.types.is_integer_dtype(dtype) or pd.api.types.is_float_dtype(dtype):
            return "numeric"
        if pd.api.types.is_datetime64_any_dtype(dtype):
            return "datetime"
        return "object"

    return {col: map_type(dtype) for col, dtype in df.dtypes.items()}

def _find_balanced_json(s: str):
    """Return the first balanced {...} substring of s, or None."""
    first = s.find('{')
    if first == -1:
        return None
    stack = []
    for i in range(first, len(s)):
        ch = s[i]
        if ch == '{':
            stack.append('{')
        elif ch == '}':
            if not stack:
                return None
            stack.pop()
            if not stack:
                return s[first:i + 1]
    return None


def _escape_problematic_backslashes(s: str) -> str:
    # Double any backslash that does not begin a valid JSON escape sequence.
    return re.sub(r'\\(?!["\\/bfnrtu])', r'\\\\', s)

def safe_json_loads(raw_text: str):
    """Parse model output as JSON; fall back to the first balanced {...} block,
    then to backslash repair, before raising with diagnostics."""
    try:
        return json.loads(raw_text)
    except Exception as e1:
        err1 = str(e1)

    subset = _find_balanced_json(raw_text)
    if subset:
        try:
            return json.loads(subset)
        except Exception as e2:
            err2 = str(e2)
    else:
        err2 = "no balanced braces found"

    try:
        fixed = _escape_problematic_backslashes(subset or raw_text)
        return json.loads(fixed)
    except Exception as e3:
        err3 = str(e3)

    diagnostic = {
        "direct_error": err1,
        "subset_error": err2,
        "escaped_error": err3,
        "raw_snippet": (raw_text[:4000] + '...') if len(raw_text) > 4000 else raw_text,
    }
    raise ValueError("Unable to parse JSON from model output. Diagnostic: " + json.dumps(diagnostic))
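# Example: safe_json_loads('noise {"a": 1} trailer') returns {"a": 1} via the
# balanced-brace fallback.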
def data_fingerprint(df: pd.DataFrame, n_sample_rows: int = 100) -> str:
    df2 = df.copy()
    df2 = df2.reindex(sorted(df2.columns), axis=1)
    head = df2.head(n_sample_rows).to_json(orient="split", date_format="iso", force_ascii=False)
    tail = df2.tail(n_sample_rows).to_json(orient="split", date_format="iso", force_ascii=False)
    col_aggs = {c: {"nunique": int(df2[c].nunique()), "nulls": int(df2[c].isnull().sum())} for c in df2.columns}
    text = head + tail + json.dumps(col_aggs, sort_keys=True, default=str)
    return hashlib.sha256(text.encode("utf-8")).hexdigest()
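# Note: this fingerprint samples only the head/tail rows plus per-column
# aggregates, so datasets differing only in middle rows can collide.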
def stream_save_and_hash(upload_file: UploadFile, tmp_path: str, size_limit: Optional[int] = None) -> str:
    # Stream the upload to disk in 8 KiB chunks, hashing as we go, so large
    # files never have to fit in memory.
    h = hashlib.sha256()
    total = 0
    upload_file.file.seek(0)
    with open(tmp_path, "wb") as f:
        while True:
            chunk = upload_file.file.read(8192)
            if not chunk:
                break
            f.write(chunk)
            h.update(chunk)
            total += len(chunk)
            if size_limit and total > size_limit:
                raise HTTPException(status_code=413, detail="Uploaded file exceeds maximum allowed size")
    return h.hexdigest()

async def save_preprocessed_df(df: pd.DataFrame, snapshot_id: str) -> str:
    path = os.path.join(SNAPSHOT_BUCKET, f"{snapshot_id}.csv")
    loop = asyncio.get_running_loop()
    # Offload the blocking CSV write to the thread pool.
    await loop.run_in_executor(EXECUTOR, partial(df.to_csv, path, index=False))
    return path

def load_file_from_path(file_path: str, original_filename: str) -> pd.DataFrame:
    ext = os.path.splitext(original_filename)[-1].lower()
    if ext == ".csv":
        return pd.read_csv(file_path)
    elif ext in [".xls", ".xlsx"]:
        return pd.read_excel(file_path, sheet_name=0)
    else:
        raise ValueError(f"Unsupported file type: {ext}")

def preprocess(df: pd.DataFrame, drop_thresh=0.5) -> pd.DataFrame:
    df = df.copy()
    # Normalize column names and drop columns that are mostly null.
    df.columns = [str(c).strip().lower().replace(" ", "_") for c in df.columns]
    df = df.loc[:, df.isnull().mean() < drop_thresh]

    # Impute remaining nulls by dtype.
    for col in df.columns:
        if pd.api.types.is_numeric_dtype(df[col]):
            df.loc[:, col] = df[col].fillna(df[col].median())
        elif pd.api.types.is_datetime64_any_dtype(df[col]):
            df.loc[:, col] = df[col].fillna(pd.Timestamp('1970-01-01'))
        else:
            df.loc[:, col] = df[col].fillna("Unknown")

    # Opportunistically convert numeric-looking object columns.
    for col in df.columns:
        if df[col].dtype == 'object':
            try:
                df.loc[:, col] = pd.to_numeric(df[col])
            except Exception:
                pass

    df = df.drop_duplicates()
    return df
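# Example: a column named " Total Sales " becomes "total_sales"; a column that
# is 60% null is dropped outright under the default drop_thresh of 0.5.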
def get_metadata(df: pd.DataFrame) -> dict:
    return {
        "rows": int(df.shape[0]),
        "columns": int(df.shape[1]),
        "column_names": list(df.columns),
        "column_types": {col: str(dtype) for col, dtype in df.dtypes.items()},
        "unique_values": {col: int(df[col].nunique()) for col in df.columns},
    }

def generate_summary_blocking(meta, fiverow, system_prompt_override: Optional[str] = None):
    # The key must come from the environment; never hardcode it. (A hardcoded
    # fallback would also make the check below unreachable.)
    api_key = os.getenv("GEMINI_API_KEY")
    if not api_key:
        raise RuntimeError("GEMINI_API_KEY environment variable not set")
    client = genai.Client(api_key=api_key)
    model = "gemini-2.5-flash-lite"

    system_prompt = system_prompt_override or """
You are a strict JSON generator.
Input contains:
- meta: dataframe metadata
- fiverow: first 5 records of dataframe
You must output JSON with the following structure:
{
  "summary": "<short natural language overview of dataset>",
  "recommended_charts": [
    {
      "type": "<one of: bar, pie, timeseries, histogram, scatter, multiple_columns, stacked_bar, heatmap>",
      "title": "<short title for chart>",
      "columns": ["<col1>", "<col2>", "..."],
      "python_code": "<full runnable Python code using seaborn/matplotlib that produces the chart>"
    },
    ...
  ]
}
Mandatory rules:
- Always produce syntactically valid JSON ONLY. No text outside the JSON object.
- Provide at least these chart types somewhere in recommended_charts: bar, pie, timeseries, histogram, scatter, multiple_columns, stacked_bar, heatmap.
- Use only column names that appear in meta['column_names'].
- The python_code string must be self-contained and runnable assuming a variable `df` exists containing the full cleaned DataFrame. Start the code with imports:
    import pandas as pd
    import seaborn as sns
    import matplotlib.pyplot as plt
  and include any necessary preprocessing steps (e.g., parsing dates).
- For timeseries charts ensure the datetime column is parsed (`pd.to_datetime`) before plotting.
- For multiple_columns provide a pairplot or facetgrid example that uses up to 4 numeric columns or sensible categorical splits.
- For stacked_bar, show aggregation code (groupby + unstack) and plotting with df.plot(kind='bar', stacked=True).
- For heatmap, compute correlation matrix and plot sns.heatmap with annotations.
- For pie charts, ensure grouping/aggregation when there are >20 unique categories (group small categories into 'Other').
- For histogram and scatter include axis labels and tight_layout; include plt.show() at the end.
- Keep code minimal but complete so a user can copy-paste and run (assume seaborn, matplotlib, pandas installed).
- For each chart add a sensible "columns" list showing which columns the code uses.
- Do not include examples using columns not present in meta.
- Do not include more than 10 recommended_charts.
- Ensure strings inside the JSON are escaped properly so the JSON parses.
Produce a concise natural-language one-line summary in "summary". Ensure JSON is parseable by json.loads in Python.
"""

    user_prompt = {"meta": meta, "fiverow": fiverow}
    contents = [
        types.Content(
            role="user",
            # Send valid JSON rather than a Python-dict repr.
            parts=[types.Part.from_text(text=json.dumps(user_prompt, default=str))],
        ),
    ]
    generate_content_config = types.GenerateContentConfig(
        thinking_config=types.ThinkingConfig(thinking_budget=0),
        response_mime_type="application/json",
        system_instruction=[types.Part.from_text(text=system_prompt)],
    )

raw = "" |
|
|
try: |
|
|
for chunk in client.models.generate_content_stream( |
|
|
model=model, |
|
|
contents=contents, |
|
|
config=generate_content_config, |
|
|
): |
|
|
if chunk.text: |
|
|
raw += chunk.text |
|
|
except Exception as e: |
|
|
logger.error("AI generation stream error: %s\n%s", str(e), traceback.format_exc()) |
|
|
raise RuntimeError("AI generation failed: " + str(e)) |
|
|
|
|
|
logger.debug("AI raw output (trimmed): %s", raw[:2000]) |
|
|
|
|
|
try: |
|
|
parsed = safe_json_loads(raw) |
|
|
except Exception as e: |
|
|
logger.error("Failed to parse AI JSON. Raw (trimmed): %s", raw[:2000]) |
|
|
raise RuntimeError(f"AI JSON parse error: {e}") |
|
|
|
|
|
if not isinstance(parsed, dict) or "summary" not in parsed or "recommended_charts" not in parsed: |
|
|
logger.error("AI output missing required keys. Parsed keys: %s", list(parsed.keys()) if isinstance(parsed, dict) else type(parsed)) |
|
|
raise RuntimeError("AI output missing required keys: 'summary' and 'recommended_charts' required") |
|
|
|
|
|
return parsed |
|
|
|
|
|
|
|
|
async def generate_summary_async(meta, fiverow, system_prompt_override: Optional[str] = None):
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(EXECUTOR, generate_summary_blocking, meta, fiverow, system_prompt_override)


def flatten_columns(df):
    if isinstance(df.columns, pd.MultiIndex):
        df.columns = ['_'.join(map(str, col)).strip() for col in df.columns.values]
    return df

def extract_chart_data_json_by_type(parsed_summary: dict, df: pd.DataFrame):
    try:
        result = {}
        for chart in parsed_summary.get("recommended_charts", []):
            chart_type = chart.get("type")
            columns = chart.get("columns", []) or []
            title = chart.get("title", "unnamed_chart")
            if chart_type not in result:
                result[chart_type] = []
            try:
                if chart_type == "bar":
                    df_agg = df[columns].groupby(columns[0]).sum(numeric_only=True).reset_index()
                    chart_data = df_agg.to_dict(orient="records")
                elif chart_type == "stacked_bar":
                    df_agg = df.groupby(columns).sum(numeric_only=True).unstack()
                    df_agg = flatten_columns(df_agg)
                    # reset_index keeps the primary group key in the records.
                    chart_data = df_agg.fillna(0).reset_index().to_dict(orient="records")
                elif chart_type == "pie":
                    col = columns[0]
                    counts = df[col].value_counts()
                    if len(counts) > 20:
                        top = counts.nlargest(19)
                        others = counts.iloc[19:].sum()
                        counts = pd.concat([top, pd.Series({'Other': others})])
                    # rename_axis/reset_index(name=...) is stable across pandas
                    # versions, unlike renaming the generated 'index' column.
                    chart_data = counts.rename_axis(col).reset_index(name='value').to_dict(orient="records")
                elif chart_type == "histogram":
                    chart_data = df[columns[0]].dropna().tolist()
                elif chart_type == "scatter":
                    chart_data = df[columns].to_dict(orient="records")
                elif chart_type == "timeseries":
                    df_copy = df[columns].copy()
                    # Parse only object-dtype columns as dates so numeric value
                    # columns are not coerced to epoch timestamps.
                    for c in columns:
                        if df_copy[c].dtype == object:
                            df_copy[c] = pd.to_datetime(df_copy[c], errors='coerce')
                    chart_data = df_copy.astype(str).to_dict(orient="records")
                elif chart_type == "multiple_columns":
                    chart_data = df[columns].to_dict(orient="records")
                elif chart_type == "heatmap":
                    corr_df = df[columns].corr().fillna(0)
                    chart_data = flatten_columns(corr_df).to_dict()
                else:
                    chart_data = []
            except Exception as e:
                chart_data = {"error": str(e)}
            result[chart_type].append({"title": title, "data": chart_data})
        return result
    except Exception as e:
        logger.error("Error extracting chart data: %s\n%s", str(e), traceback.format_exc())
        raise RuntimeError(f"Error extracting chart data: {e}")
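# Result shape (one list per chart type), e.g.:
#   {"bar": [{"title": "...", "data": [...]}], "pie": [...], ...}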
@app.on_event("startup")
async def create_indexes():
    try:
        await snapshots.create_index("file_hash")
        await snapshots.create_index("data_hash")
        await snapshots.create_index("meta_hash")
        await snapshots.create_index("snapshot_id", unique=True)
        if TTL_DAYS > 0:
            await snapshots.create_index("created_at_dt", expireAfterSeconds=TTL_DAYS * 24 * 3600)
        logger.info("Mongo indexes ensured.")
    except Exception:
        logger.exception("Error creating indexes")

@app.get("/") |
|
|
async def root(): |
|
|
return {"message": "Data Analysis API is running"} |
|
|
|
|
|
|
|
|
@app.get("/health") |
|
|
async def health_check(): |
|
|
return {"status": "healthy"} |
|
|
|
|
|
|
|
|
@app.post("/analyze", response_model=AnalysisResponse) |
|
|
async def analyze_data(file: UploadFile = File(...)): |
|
|
if not file.filename: |
|
|
raise HTTPException(status_code=400, detail="No file provided") |
|
|
allowed_extensions = ['.csv', '.xls', '.xlsx'] |
|
|
file_ext = os.path.splitext(file.filename)[-1].lower() |
|
|
if file_ext not in allowed_extensions: |
|
|
raise HTTPException(status_code=400, detail=f"Unsupported file type. Allowed: {', '.join(allowed_extensions)}") |
|
|
|
|
|
|
|
|
with tempfile.NamedTemporaryFile(delete=False, suffix=file_ext) as tmp_file: |
|
|
tmp_path = tmp_file.name |
|
|
|
|
|
try: |
|
|
|
|
|
try: |
|
|
file_hash = stream_save_and_hash(file, tmp_path, size_limit=MAX_UPLOAD_SIZE) |
|
|
except HTTPException: |
|
|
try: |
|
|
os.unlink(tmp_path) |
|
|
except Exception: |
|
|
pass |
|
|
raise |
|
|
except Exception as e: |
|
|
try: |
|
|
os.unlink(tmp_path) |
|
|
except Exception: |
|
|
pass |
|
|
logger.exception("Error saving uploaded file") |
|
|
raise HTTPException(status_code=500, detail=str(e)) |
|
|
|
|
|
|
|
|
try: |
|
|
df = load_file_from_path(tmp_path, file.filename) |
|
|
except Exception as e: |
|
|
logger.exception("Error loading file") |
|
|
raise HTTPException(status_code=400, detail=str(e)) |
|
|
|
|
|
try: |
|
|
df_clean = preprocess(df) |
|
|
except Exception as e: |
|
|
logger.exception("Error preprocessing file") |
|
|
raise HTTPException(status_code=500, detail=str(e)) |
|
|
|
|
|
meta = get_metadata(df_clean) |
|
|
fiverow = df_clean.head(5).to_dict(orient="records") |
|
|
|
|
|
|
|
|
data_hash = data_fingerprint(df_clean) |
|
|
meta_hash = sha256_obj({ |
|
|
"rows": meta["rows"], |
|
|
"columns": meta["columns"], |
|
|
"column_names": meta["column_names"], |
|
|
"column_types": canonical_types(df_clean), |
|
|
}) |
|
|
|
|
|
|
|
|
existing = await snapshots.find_one({"file_hash": file_hash}) |
|
|
cache_hit = None |
|
|
if not existing: |
|
|
existing = await snapshots.find_one({"data_hash": data_hash}) |
|
|
if existing: |
|
|
cache_hit = "data" |
|
|
if not existing and METADATA_ONLY_FALLBACK: |
|
|
existing = await snapshots.find_one({"meta_hash": meta_hash}) |
|
|
if existing: |
|
|
cache_hit = "meta" |
|
|
|
|
|
if existing: |
|
|
snapshot_id_return = existing.get("snapshot_id") or str(existing.get("_id")) |
|
|
summary = existing.get("summary") or {} |
|
|
chart_data = existing.get("chart_data") or {} |
|
|
metadata = existing.get("metadata") or meta |
|
|
return AnalysisResponse(summary=summary, chart_data=chart_data, metadata=metadata) |
|
|
|
|
|
|
|
|
snapshot_id = uuid.uuid4().hex |
|
|
created_at_iso = datetime.now(timezone.utc).isoformat() |
|
|
created_at_dt = datetime.now(timezone.utc) |
|
|
doc = { |
|
|
"snapshot_id": snapshot_id, |
|
|
"filename": file.filename, |
|
|
"file_hash": file_hash, |
|
|
"data_hash": data_hash, |
|
|
"meta_hash": meta_hash, |
|
|
"metadata": meta, |
|
|
"summary": None, |
|
|
"chart_data": None, |
|
|
"preprocessed_path": None, |
|
|
"status": "processing", |
|
|
"created_at": created_at_iso, |
|
|
"created_at_dt": created_at_dt, |
|
|
} |
|
|
await snapshots.insert_one(doc) |
|
|
|
|
|
|
|
|
try: |
|
|
summary_obj = await generate_summary_async(meta, fiverow) |
|
|
except Exception as e: |
|
|
await snapshots.update_one({"snapshot_id": snapshot_id}, {"$set": {"status": "failed", "error": str(e)}}) |
|
|
logger.exception("AI generation failed") |
|
|
raise HTTPException(status_code=500, detail=f"AI generation failed: {e}") |
|
|
|
|
|
|
|
|
try: |
|
|
chart_data = extract_chart_data_json_by_type(summary_obj, df_clean) |
|
|
except Exception as e: |
|
|
await snapshots.update_one({"snapshot_id": snapshot_id}, {"$set": {"status": "failed", "error": str(e)}}) |
|
|
logger.exception("Chart extraction failed") |
|
|
raise HTTPException(status_code=500, detail=f"Chart extraction failed: {e}") |
|
|
|
|
|
|
|
|
try: |
|
|
preprocessed_path = await save_preprocessed_df(df_clean, snapshot_id) |
|
|
except Exception as e: |
|
|
await snapshots.update_one({"snapshot_id": snapshot_id}, {"$set": {"status": "failed", "error": str(e)}}) |
|
|
logger.exception("Saving preprocessed failed") |
|
|
raise HTTPException(status_code=500, detail=f"Saving preprocessed failed: {e}") |
|
|
|
|
|
|
|
|
await snapshots.update_one( |
|
|
{"snapshot_id": snapshot_id}, |
|
|
{"$set": { |
|
|
"summary": summary_obj, |
|
|
"chart_data": chart_data, |
|
|
"preprocessed_path": preprocessed_path, |
|
|
"status": "done", |
|
|
"completed_at": datetime.now(timezone.utc).isoformat() |
|
|
}} |
|
|
) |
|
|
|
|
|
return AnalysisResponse(summary=summary_obj, chart_data=chart_data, metadata=meta) |
|
|
|
|
|
finally: |
|
|
try: |
|
|
os.unlink(tmp_path) |
|
|
except Exception: |
|
|
pass |
|
|
|
|
|
|
|
|
@app.get("/snapshots") |
|
|
async def list_snapshots(limit: int = Query(20, ge=1, le=100), offset: int = Query(0, ge=0)): |
|
|
cursor = snapshots.find({}, {"preprocessed_path": 0, "summary": 0, "chart_data": 0}).sort("created_at_dt", -1).skip(offset).limit(limit) |
|
|
items = [] |
|
|
async for doc in cursor: |
|
|
items.append({ |
|
|
"id": doc.get("snapshot_id") or str(doc.get("_id")), |
|
|
"filename": doc.get("filename"), |
|
|
"metadata": doc.get("metadata"), |
|
|
"status": doc.get("status"), |
|
|
"created_at": doc.get("created_at"), |
|
|
}) |
|
|
return {"count": len(items), "items": items} |
|
|
|
|
|
|
|
|
@app.get("/snapshot/{snapshot_id}") |
|
|
async def get_snapshot(snapshot_id: str): |
|
|
doc = await snapshots.find_one({"snapshot_id": snapshot_id}) |
|
|
if not doc: |
|
|
raise HTTPException(status_code=404, detail="Snapshot not found") |
|
|
return { |
|
|
"id": doc["snapshot_id"], |
|
|
"filename": doc.get("filename"), |
|
|
"metadata": doc.get("metadata"), |
|
|
"summary": doc.get("summary"), |
|
|
"chart_data": doc.get("chart_data"), |
|
|
"status": doc.get("status"), |
|
|
"created_at": doc.get("created_at"), |
|
|
} |
|
|
|
|
|
|
|
|
@app.get("/preprocessed/{snapshot_id}") |
|
|
async def get_preprocessed(snapshot_id: str, limit: int = 100, offset: int = 0): |
|
|
doc = await snapshots.find_one({"snapshot_id": snapshot_id}) |
|
|
if not doc: |
|
|
raise HTTPException(status_code=404, detail="Snapshot not found") |
|
|
path = doc.get("preprocessed_path") |
|
|
if not path or not os.path.exists(path): |
|
|
raise HTTPException(status_code=404, detail="Preprocessed data not available") |
|
|
df = pd.read_csv(path) |
|
|
total = len(df) |
|
|
rows = df.iloc[offset: offset + limit].to_dict(orient="records") |
|
|
return {"total": total, "offset": offset, "limit": limit, "rows": rows} |
|
|
|
|
|
|
|
|
@app.post("/drill") |
|
|
async def drill(req: DrillRequest): |
|
|
doc = await snapshots.find_one({"snapshot_id": req.snapshot_id}) |
|
|
if not doc: |
|
|
raise HTTPException(status_code=404, detail="Snapshot not found") |
|
|
path = doc.get("preprocessed_path") |
|
|
if not path or not os.path.exists(path): |
|
|
raise HTTPException(status_code=404, detail="Preprocessed data not available") |
|
|
df = pd.read_csv(path) |
|
|
if req.filter_column not in df.columns: |
|
|
raise HTTPException(status_code=400, detail=f"Column {req.filter_column} not found in preprocessed data") |
|
|
try: |
|
|
filtered = df[df[req.filter_column] == req.filter_value] |
|
|
if filtered.empty: |
|
|
filtered = df[df[req.filter_column].astype(str) == str(req.filter_value)] |
|
|
except Exception: |
|
|
filtered = df[df[req.filter_column].astype(str) == str(req.filter_value)] |
|
|
total = len(filtered) |
|
|
rows = filtered.iloc[req.offset: req.offset + req.limit].to_dict(orient="records") |
|
|
highlights = req.highlight_columns or [req.filter_column] |
|
|
highlights = [c for c in highlights if c in df.columns] |
|
|
return { |
|
|
"snapshot_id": req.snapshot_id, |
|
|
"filter_column": req.filter_column, |
|
|
"filter_value": req.filter_value, |
|
|
"total_matches": total, |
|
|
"offset": req.offset, |
|
|
"limit": req.limit, |
|
|
"rows": rows, |
|
|
"highlight_columns": highlights, |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
@app.exception_handler(HTTPException)
async def http_exception_handler(request, exc):
    return JSONResponse(status_code=exc.status_code, content={"error": exc.detail})


@app.exception_handler(Exception)
async def general_exception_handler(request, exc):
    logger.exception("Unhandled exception")
    return JSONResponse(status_code=500, content={"error": "Internal server error", "details": str(exc)})

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "7860")))
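# Quick smoke test (hypothetical file name; assumes `requests` is installed):
#   import requests
#   with open("sales.csv", "rb") as f:
#       r = requests.post("http://localhost:7860/analyze", files={"file": f})
#   print(r.json()["summary"])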