ishaq101's picture
feat/Analysis State & Report Rework (#4)
0e02a0f
Raw
History Blame
6.56 kB
"""Analysis session API — create a new analysis (the per-session workspace).
An analysis IS the chat session: the `analysis_states` row and the chat `rooms`
row share one id (`analysis_id == room_id`), so the existing `room_id` on the chat
request doubles as the `analysis_id`. Creating an analysis enforces the data-first
gate (>=1 bound source) and seeds the state with a title + an optional problem
statement (validated later by the Problem Statement skill).
"""
import uuid
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from src.db.postgres.connection import get_db
from src.db.postgres.models import AnalysisDataSourceRow, AnalysisStateRow, Room
from src.middlewares.logging import get_logger, log_execution
# Module logger (named "analysis_api"); also handed to @log_execution on each route below.
logger = get_logger("analysis_api")
# Every route in this module is served under /api/v1 and grouped under the
# "Analysis" tag in the generated OpenAPI docs.
router = APIRouter(prefix="/api/v1", tags=["Analysis"])
def _serialize_state(row: AnalysisStateRow, data_source_ids: list[str]) -> dict:
    """Build the API payload for one analysis.

    Merges the persisted state columns of ``row`` with ``data_source_ids`` (the
    ids of the sources bound to it). ``created_at``/``updated_at`` are rendered
    as ISO-8601 strings, or ``None`` when the column is unset.
    """

    def _iso(stamp):
        # An unset (falsy) timestamp becomes null instead of raising on .isoformat().
        return stamp.isoformat() if stamp else None

    payload = {
        "id": row.id,
        "analysis_title": row.analysis_title,
        "problem_statement": row.problem_statement,
        "problem_validated": row.problem_validated,
        "user_id": row.user_id,
        "report_id": row.report_id,
        "data_source_ids": data_source_ids,
        "created_at": _iso(row.created_at),
        "updated_at": _iso(row.updated_at),
    }
    return payload
async def _bound_source_ids(db: AsyncSession, analysis_id: str) -> list[str]:
    """Return the ``reference_id`` of every binding row for ``analysis_id``.

    The query applies no ORDER BY, so the list follows whatever order the
    database returns the rows in.
    """
    stmt = select(AnalysisDataSourceRow.reference_id).where(
        AnalysisDataSourceRow.analysis_id == analysis_id
    )
    rows = await db.execute(stmt)
    return list(rows.scalars().all())
async def _sources_by_id(user_id: str) -> dict:
    """Map catalog ``source_id`` -> catalog source entry for ``user_id``.

    Used to snapshot a source's ``source_type``/``name`` onto binding rows.
    Best-effort: if importing or reading the catalog raises, the error is logged
    and an empty map is returned, so binding rows fall back to type='unknown' /
    name=reference_id. A falsy catalog also yields an empty map.
    """
    try:
        # Imported inside the try, so an import failure is handled like a read failure.
        from src.catalog.store import CatalogStore

        catalog = await CatalogStore().get(user_id)
    except Exception as exc:  # noqa: BLE001 — binding must not fail on catalog read
        logger.warning("analysis: catalog read failed for binding", user_id=user_id, error=str(exc))
        return {}
    if not catalog:
        return {}
    return {entry.source_id: entry for entry in catalog.sources}
class CreateAnalysisRequest(BaseModel):
    """Request body for creating a new analysis session."""

    # Owner of the analysis; copied onto the state row, the chat room and each binding row.
    user_id: str
    # Used as both the analysis title and the chat room title.
    analysis_title: str = "New analysis"
    # Optional at creation time; stored as given.
    problem_statement: str = ""
    # Catalog source ids to bind. The model accepts an empty list; the create
    # endpoint is what rejects it with HTTP 400 (data-first gate).
    data_source_ids: list[str] = Field(default_factory=list)
@router.post("/analysis/create")
@log_execution(logger)
async def create_analysis(
    request: CreateAnalysisRequest,
    db: AsyncSession = Depends(get_db),
):
    """Create a new analysis session: one shared id for its state + chat room.

    Data-first gate (decision #2): an analysis requires >=1 bound data source.
    Blank ids (empty or whitespace-only strings) do not count as sources and are
    dropped before the gate is applied; duplicates collapse to their first
    occurrence, preserving request order.

    The bound sources are persisted as dedorch `data_sources` rows (#10) in the same
    transaction as the state + room, so the analysis is scoped to exactly the sources
    the user picked. `structured_flow` and the report read this binding back.

    Raises:
        HTTPException: 400 when the request names no non-blank data source id.
    """
    # Normalize first, then gate on the normalized list: a payload such as [""] or
    # ["  "] must not slip past the data-first gate and create a binding to a
    # non-existent source. dict.fromkeys dedupes while preserving order.
    bound_ids = list(
        dict.fromkeys(sid for sid in request.data_source_ids if sid.strip())
    )
    if not bound_ids:
        raise HTTPException(
            status_code=400,
            detail="An analysis requires at least one bound data source.",
        )

    analysis_id = str(uuid.uuid4())
    # The analysis IS the session: state row + chat room + source bindings share one
    # id, created atomically in one transaction.
    state_row = AnalysisStateRow(
        id=analysis_id,
        user_id=request.user_id,
        analysis_title=request.analysis_title,
        problem_statement=request.problem_statement,
        problem_validated=False,
    )
    db.add(Room(id=analysis_id, user_id=request.user_id, title=request.analysis_title))
    db.add(state_row)

    # Each binding row snapshots the source's type + name from the catalog
    # (reference_id = catalog source id); sources missing from the catalog fall back
    # to type="unknown" / name=<source id>. bound_at/created_at default to now() in dedorch.
    src_by_id = await _sources_by_id(request.user_id)
    for source_id in bound_ids:
        src = src_by_id.get(source_id)
        db.add(
            AnalysisDataSourceRow(
                id=str(uuid.uuid4()),
                analysis_id=analysis_id,
                type=src.source_type if src else "unknown",
                name=src.name if src else source_id,
                reference_id=source_id,
                bound_by=request.user_id,
            )
        )

    await db.commit()
    # Reload so server-side defaults (e.g. created_at/updated_at) appear in the payload.
    await db.refresh(state_row)
    logger.info(
        "analysis created",
        analysis_id=analysis_id,
        user_id=request.user_id,
        sources=len(bound_ids),
    )
    return {
        "status": "success",
        "message": "Analysis created successfully",
        "data": _serialize_state(state_row, bound_ids),
    }
@router.get("/analysis")
@log_execution(logger)
async def list_analyses(user_id: str, db: AsyncSession = Depends(get_db)):
    """List a user's analyses, most-recently-updated first (Analysis sidebar).

    Summary fields only (no per-row source bindings — fetch those via the detail
    endpoint) to keep the list a single query.

    Rows with a NULL ``updated_at`` sort last; PostgreSQL would otherwise place
    NULLs first under DESC, ahead of genuinely recent analyses. ``created_at`` and
    then ``id`` break ties so the order is stable across requests.
    """
    result = await db.execute(
        select(AnalysisStateRow)
        .where(AnalysisStateRow.user_id == user_id)
        .order_by(
            AnalysisStateRow.updated_at.desc().nulls_last(),
            AnalysisStateRow.created_at.desc().nulls_last(),
            AnalysisStateRow.id,
        )
    )
    rows = result.scalars().all()
    return {
        "status": "success",
        "data": [
            {
                "id": r.id,
                "analysis_title": r.analysis_title,
                "problem_validated": r.problem_validated,
                "report_id": r.report_id,
                "updated_at": r.updated_at.isoformat() if r.updated_at else None,
            }
            for r in rows
        ],
    }
@router.get("/analysis/{analysis_id}")
@log_execution(logger)
async def get_analysis(analysis_id: str, db: AsyncSession = Depends(get_db)):
    """Return one analysis's state plus its bound data source ids (FE workspace render).

    Raises:
        HTTPException: 404 when no analysis state row exists for ``analysis_id``.
    """
    state = await db.get(AnalysisStateRow, analysis_id)
    if state is not None:
        bound = await _bound_source_ids(db, analysis_id)
        return {"status": "success", "data": _serialize_state(state, bound)}
    raise HTTPException(status_code=404, detail=f"Analysis {analysis_id!r} not found.")