pravah / app /main.py
triflix's picture
Upload 17 files
ba3dc51 verified
import os
import time
from datetime import datetime, timezone
from html import escape
from fastapi import Depends, FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from .adapter import build_visualization_from_intent
from .config import settings
from .db import DatabaseHandler
from .knowledge import KB_CONFIG, KnowledgeManager
from .models import AiQueryRequest, AiQueryResponse
from .security import require_api_key
from .ai_engine import AIEngine
app = FastAPI(title="Pravah AI Backend", version="1.0.0")

# Comma-separated origin list from settings; blank entries are dropped.
origins = [o.strip() for o in settings.allowed_origins.split(",") if o.strip()]

# A "*" anywhere in the list means "any origin". Credentials must never be
# combined with a wildcard: Starlette's CORSMiddleware treats any "*" entry as
# allow-all and, with credentials enabled, echoes the caller's Origin back, which
# would let every site make credentialed requests. Credentials are therefore only
# enabled for an explicit allow-list.
if "*" in origins:
    allow_origins = ["*"]
    allow_credentials = False
else:
    allow_origins = origins
    allow_credentials = True

app.add_middleware(
    CORSMiddleware,
    allow_origins=allow_origins,
    allow_credentials=allow_credentials,
    allow_methods=["GET", "POST", "OPTIONS"],
    # Without an API key any header is accepted; with one, preflight is limited to
    # the content type plus the auth headers (presumably the ones require_api_key
    # reads — confirm against .security).
    allow_headers=["*"]
    if not settings.backend_api_key
    else ["Content-Type", "Authorization", "X-API-Key"],
)
# Process-wide service objects; the AI engine is built on top of the database
# handler and the knowledge manager, so it must be constructed last.
db = DatabaseHandler()
kb = KnowledgeManager(KB_CONFIG)
ai = AIEngine(db, kb)

# Bounded in-memory buffer of recent /api/ai/query log entries (oldest first).
# It lives only in this process, so entries are not persisted anywhere.
_MAX_LOGS = 500
_LOGS: list[dict] = []
def _append_log(entry: dict) -> None:
    """Append *entry* to the in-memory log, keeping at most ``_MAX_LOGS`` entries.

    When the buffer overflows, the oldest entries (at the front) are discarded
    so the newest ones survive.
    """
    _LOGS.append(entry)
    overflow = len(_LOGS) - _MAX_LOGS
    if overflow > 0:
        del _LOGS[:overflow]
@app.get("/health")
def health():
    """Report liveness, the build identifier, feature flags and dependency status.

    The build identifier is the first non-empty value among the revision
    environment variables below, or ``"local"`` if none is set.
    """
    revision_vars = ("SPACE_REVISION", "HF_SPACE_REVISION", "GIT_SHA")
    # Lazily look up each variable and stop at the first truthy value.
    build = next(
        (value for value in map(os.environ.get, revision_vars) if value),
        "local",
    )
    feature_flags = {
        "autoInsights": True,
        "sqlIdentifierNormalization": True,
        "rag": True,
    }
    return {
        "ok": True,
        "build": build,
        "features": feature_flags,
        "dbConfigured": db.is_configured(),
        "db": db.healthcheck(),
        "groqConfigured": ai.is_configured(),
    }
# Placeholder cell used for missing/empty values on the admin logs page.
_EMPTY_LOG_CELL = "<td class='muted'>—</td>"


def _log_cell(value) -> str:
    """Render *value* as an HTML-escaped ``<td>``; ``None`` becomes a muted dash."""
    if value is None:
        return _EMPTY_LOG_CELL
    return f"<td>{escape(str(value))}</td>"


def _render_log_row(entry: dict) -> str:
    """Render one log entry as an HTML ``<tr>``; every dynamic value is escaped."""
    status = entry.get("status")
    if status == 200:
        status_html = "<span class='ok'>200</span>"
    else:
        status_html = f"<span class='err'>{escape(str(status))}</span>"

    tokens = entry.get("tokens")
    sql = entry.get("sql")
    err = entry.get("error")

    cells = [
        _log_cell(entry.get("ts")),
        f"<td>{status_html}</td>",
        _log_cell(entry.get("intent")),
        _log_cell(entry.get("model")),
        _log_cell(entry.get("latency")),
        # Tokens use an explicit None check so a count of 0 is still shown.
        _EMPTY_LOG_CELL
        if tokens is None
        else f"<td><span class='pill'>{escape(str(tokens))}</span></td>",
        _log_cell(entry.get("prompt")),
        # SQL and error use truthiness: empty strings render as the placeholder.
        f"<td><code>{escape(str(sql))}</code></td>" if sql else _EMPTY_LOG_CELL,
        f"<td class='err'>{escape(str(err))}</td>" if err else _EMPTY_LOG_CELL,
    ]
    return "<tr>" + "".join(cells) + "</tr>"


@app.get("/admin/logs", response_class=HTMLResponse, dependencies=[Depends(require_api_key)])
def admin_logs():
    """Render the in-memory request log as an HTML table, newest entry first.

    Requires the API key (via ``require_api_key``). The page is built with a
    list of fragments joined once, instead of repeated string concatenation.
    """
    # Snapshot once so the header count always matches the rows rendered.
    rows = list(reversed(_LOGS))

    head = """
<html>
<head>
<meta charset='utf-8' />
<meta name='viewport' content='width=device-width, initial-scale=1' />
<title>Pravah Backend Logs</title>
<style>
body{font-family:ui-sans-serif,system-ui,-apple-system,Segoe UI,Roboto,Arial; background:#0b1220; color:#e5e7eb; padding:16px;}
.card{background:#0f172a; border:1px solid #1f2937; border-radius:12px; padding:14px;}
h1{font-size:18px; margin:0 0 10px 0;}
.sub{color:#9ca3af; font-size:12px; margin-bottom:12px;}
table{width:100%; border-collapse:collapse; font-size:12px;}
th,td{border-bottom:1px solid #1f2937; padding:8px; vertical-align:top;}
th{color:#cbd5e1; text-align:left; font-weight:700; position:sticky; top:0; background:#0f172a;}
.ok{color:#34d399; font-weight:700;}
.err{color:#fb7185; font-weight:700;}
.muted{color:#94a3b8;}
code{color:#a5b4fc;}
.pill{display:inline-block; padding:2px 8px; border-radius:999px; background:#111827; border:1px solid #1f2937; font-size:11px; color:#cbd5e1;}
</style>
</head>
<body>
<div class='card'>
<h1>Pravah Backend Logs</h1>
<div class='sub'>Showing last <span class='pill'>""" + str(len(rows)) + """</span> requests (max """ + str(_MAX_LOGS) + """ in memory). Use <code>/admin/logs.json</code> for raw data.</div>
<table>
<thead>
<tr>
<th>Time (UTC)</th>
<th>Status</th>
<th>Intent</th>
<th>Model</th>
<th>Latency (s)</th>
<th>Tokens</th>
<th>Prompt</th>
<th>SQL</th>
<th>Error</th>
</tr>
</thead>
<tbody>
"""
    tail = """
</tbody>
</table>
</div>
</body>
</html>
"""
    parts = [head]
    parts.extend(_render_log_row(r) for r in rows)
    parts.append(tail)
    return HTMLResponse(content="".join(parts))
@app.get("/admin/logs.json", dependencies=[Depends(require_api_key)])
def admin_logs_json():
    """Return the buffered request logs as JSON, oldest entry first.

    ``items`` is a copy of at most ``_MAX_LOGS`` entries; ``count`` is the
    current length of the in-memory buffer.
    """
    recent = _LOGS[-_MAX_LOGS:]
    buffered = len(_LOGS)
    return {"items": recent, "count": buffered}
@app.post("/api/ai/query", response_model=AiQueryResponse, dependencies=[Depends(require_api_key)])
def ai_query(payload: AiQueryRequest):
    """Answer a natural-language prompt, optionally executing generated SQL.

    Flow:
      1. ``ai.process_query`` produces a plan (intent, message, sql, chart_config).
      2. For ``sql``/``analytics`` intents with non-empty SQL, the query runs
         against ``db``. On failure, ``ai.repair_sql`` gets one chance to
         produce a different query, which is then executed once.
      3. The reply and any result frame are turned into an answer and
         visualization by ``build_visualization_from_intent``.

    SQL failures and a missing database configuration are reported in the
    answer text (HTTP 200). Every request that gets past the configuration
    check is recorded via ``_append_log``, including its error, if any.

    Raises:
        HTTPException: 503 when the AI engine is not configured; any
            HTTPException raised while processing is re-raised after logging.
    """
    if not ai.is_configured():
        raise HTTPException(status_code=503, detail="Groq is not configured. Set GROQ_API_KEY.")

    # perf_counter is monotonic, so the logged latency is immune to clock jumps.
    started = time.perf_counter()
    sql = None
    intent = None
    model = None
    tokens = None
    error = None
    status = 200
    try:
        plan, metrics = ai.process_query(payload.prompt, history=[])
        model = metrics.get("model")
        tokens = metrics.get("tokens")
        intent = plan.get("intent", "chat")
        reply = plan.get("message", "")
        sql = plan.get("sql")
        chart_config = plan.get("chart_config")
        df = None

        if intent in ("sql", "analytics") and isinstance(sql, str) and sql.strip():
            if not db.is_configured():
                error = "Database not configured. Set POSTGRES_* settings."
                return AiQueryResponse(
                    answer=f"{reply}\n\nDatabase not configured. Set POSTGRES_* settings.",
                    visualization=None,
                )
            df, err = db.execute_query(sql)
            if err:
                fixed_sql = ai.repair_sql(payload.prompt, sql, err)
                # Only retry when the repair produced a genuinely different query.
                if fixed_sql and fixed_sql.strip() and fixed_sql.strip() != sql.strip():
                    sql = fixed_sql
                    df, err2 = db.execute_query(sql)
                    if err2:
                        error = f"SQL Error: {err2}"
                        return AiQueryResponse(
                            answer=f"{reply}\n\nSQL Error: {err2}",
                            visualization=None,
                        )
                else:
                    error = f"SQL Error: {err}"
                    return AiQueryResponse(answer=f"{reply}\n\nSQL Error: {err}", visualization=None)

        answer, visualization = build_visualization_from_intent(intent, reply, df, sql, chart_config)
        return AiQueryResponse(answer=answer, visualization=visualization)
    except HTTPException as e:
        # Log the real HTTP status instead of the default 200.
        status = e.status_code
        error = str(e.detail)
        raise
    except Exception as e:
        status = 500
        error = str(e)
        raise
    finally:
        latency = round(time.perf_counter() - started, 3)
        _append_log(
            {
                "ts": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
                "status": status,
                "intent": intent,
                "model": model,
                "latency": latency,
                "tokens": tokens,
                "prompt": (payload.prompt or "")[:2000],
                "sql": (sql or "")[:4000] if isinstance(sql, str) else None,
                "error": error,
            }
        )