"""FastAPI server for the Adaptive Presentation Engine backend. Endpoints --------- - GET / : Vue SPA entry - GET /api/healthz : process liveness (no external calls) - GET /api/health : LLM connectivity check (use sparingly) - GET /api/state : bandit posteriors - POST /api/chat_plain : baseline chat (no strategy selection) - POST /api/chat : adaptive chat (non-streaming) - POST /api/chat_stream : adaptive chat (SSE streaming — Claude-style) - POST /api/rate : explicit feedback (reward update) - POST /api/preference : user preference warm-start - POST /api/reset : clear user session SSE stream events (`evt.type`) ------------------------------ - strategy : bandit selected strategy + posteriors (emitted first) - response_delta : token chunk inside - widget_start : model opened - widget_delta : raw token chunk inside (live, Claude-style) - done : final assembled payload (canonical response + polished widget) """ from __future__ import annotations import json import os import re import time import threading import uuid from typing import Iterable import numpy as np from fastapi import BackgroundTasks, Depends, FastAPI, HTTPException, Request, status from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.gzip import GZipMiddleware from fastapi.security import HTTPBearer from fastapi.staticfiles import StaticFiles from fastapi.responses import HTMLResponse, JSONResponse, StreamingResponse from pydantic import BaseModel from . import config, llm from .combined_prompt import ( build_combined_system_prompt, build_combined_user_prompt, finalize_widget_schema_json, is_social_or_greeting_turn, parse_combined_output, strip_leaked_widget_json, widget_schema_json_is_valid, ) from .engine import engine, USERB_ID from .auth import ( authenticate_login, create_access_token, decode_access_token, get_user_by_id, get_user_by_email, register_user, seed_users_from_db, update_password, ) from . 
import db as persistence from .utils import ( detect_explore_trigger, enforce_response, fast_valence, negative_strength, ) from .widget_stream import parse_combined_stream # --------------------------------------------------------------------------- # Enforcement helpers # --------------------------------------------------------------------------- def _maybe_enforce_primitive(user_message: str, strategy: str, response: str) -> str: """Apply strict format except for short greeting/thanks-only turns.""" if not config.STRICT_PRIMITIVES or not response: return response if is_social_or_greeting_turn(user_message): return response return enforce_response(strategy, response) # --------------------------------------------------------------------------- # Auth plumbing # --------------------------------------------------------------------------- bearer_scheme = HTTPBearer(auto_error=False) _password_reset_lock = threading.Lock() _password_reset_tokens: dict[str, dict] = {} _PASSWORD_RESET_TTL_SECONDS = int(os.getenv("PASSWORD_RESET_TTL_SECONDS", "900")) def require_user_id(credentials=Depends(bearer_scheme)) -> str: if credentials is None or not getattr(credentials, "credentials", None): raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated") token = credentials.credentials try: return decode_access_token(token) except Exception as e: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=f"Invalid token: {str(e)}") def _is_admin_user(user_id: str) -> bool: """ Admin check. Previous behavior silently promoted everyone to admin when ADMIN_* config was missing; that's a production footgun. We now require explicit admin configuration — no config means no admin. 
""" if not getattr(config, "ADMIN_CONFIGURED", False): return False if user_id in set(map(str, getattr(config, "ADMIN_USER_IDS", []) or [])): return True rec = get_user_by_id(user_id) if not rec: return False username_n = (rec.username or "").strip().lower() email_n = (rec.email or "").strip().lower() return username_n in {u.lower() for u in (config.ADMIN_USERNAMES or [])} or email_n in { e.lower() for e in (config.ADMIN_EMAILS or []) } def require_admin_user_id(user_id: str = Depends(require_user_id)) -> str: if not _is_admin_user(user_id): raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Admin only") return user_id # --------------------------------------------------------------------------- # SSE helpers # --------------------------------------------------------------------------- def sse_pack(evt: dict) -> str: """Pack a JSON event into an SSE data frame.""" return f"data: {json.dumps(evt, ensure_ascii=False)}\n\n" def _json_layout_is_only_numeric_index_arrays(schema_str: str) -> bool: """ Detect bogus JSON widgets where every block is text like '[0,1,2]' (e.g. tic-tac-toe win lines) instead of real UI. Those pass structural validation but are not widgets. """ import json as _json import re as _re try: o = _json.loads(schema_str) except _json.JSONDecodeError: return False layout = o.get("layout") if not isinstance(layout, list) or len(layout) < 2: return False pat = _re.compile(r"^\s*\[\s*\d+(\s*,\s*\d+)*\s*\]\s*$") for item in layout: if not isinstance(item, dict): return False if str(item.get("type", "")).lower() != "text": return False if not pat.match(str(item.get("content", ""))): return False return True def _dispatch_json_mode_widget(widget_payload_raw: str) -> tuple[str, str, int, str]: """ STRICT components-only: always resolve to a JSON schema rendered by the registry components. The HTML escape-hatch is DISABLED — model-generated HTML is never rendered. 
Returns (widget_schema, widget_html, widget_height, widget_debug_tag); widget_html is always "". """ raw = (widget_payload_raw or "").strip() if not raw: return "", "", 0, "" finalized = finalize_widget_schema_json(raw) if widget_schema_json_is_valid(finalized) and not _json_layout_is_only_numeric_index_arrays(finalized): return finalized, "", 0, "json_schema_ok" if _json_layout_is_only_numeric_index_arrays(finalized): return "", "", 0, "json_degenerate_layout" return finalized, "", 0, "json_schema_invalid" # NOTE: widget generation is NOT gated by keyword matching. The synthesizer LLM # decides per turn whether a widget helps AND whether it can populate it with real # values using the allowed block types (see the warrant rubric in combined_prompt). # A keyword list both over-fires ("bar exam") and misses ("how has revenue moved?"), # and it can't know whether the data exists to fill a chart — the model can. # --------------------------------------------------------------------------- # Request/response schemas # --------------------------------------------------------------------------- class ChatPlainReq(BaseModel): uid: str | None = None message: str class ChatReq(BaseModel): uid: str | None = None message: str class RateReq(BaseModel): uid: str | None = None strategy: str x_vec: list[float] reward: float class PreferenceReq(BaseModel): uid: str | None = None strategies: list[str] = [] lock: bool = False class ResetReq(BaseModel): uid: str | None = None class PrimitiveCreateReq(BaseModel): name: str instruction: str class PrimitiveUpdateReq(BaseModel): name: str instruction: str class StrategyCreateReq(BaseModel): id: str label: str instruction: str enabled: bool = True event_name: str = "" class StrategyUpdateReq(BaseModel): label: str instruction: str enabled: bool = True event_name: str = "" class AuthRegisterReq(BaseModel): username: str email: str password: str class AuthLoginReq(BaseModel): username_or_email: str password: str class 
ForgotPasswordReq(BaseModel): email: str class ResetPasswordReq(BaseModel): token: str new_password: str # --------------------------------------------------------------------------- # App # --------------------------------------------------------------------------- app = FastAPI(title="Adaptive Presentation Engine") app.add_middleware( CORSMiddleware, # allow_credentials=True is incompatible with "*" per CORS spec; # keep it False since we use bearer tokens, not cookies. allow_origins=["*"], allow_credentials=False, allow_methods=["*"], allow_headers=["*"], ) app.add_middleware(GZipMiddleware, minimum_size=1000) @app.on_event("startup") def _on_startup(): persistence.init_db() # RAG: only when enabled — ingest the financial corpus into ChromaDB on first boot # (the binary store is not shipped). When USE_RAG is off, skip entirely (no model download). if getattr(config, "USE_RAG", False): try: from rag_finance.retriever import _load as _rag_load from rag_finance.ingest import ingest as _rag_ingest col, _ = _rag_load() if col is None or col.count() == 0: stats = _rag_ingest() print(f"[rag] ingested {stats.get('documents')} financial docs into ChromaDB") except Exception as e: print(f"[rag] ingest skipped: {e}") if getattr(config, "ADMIN_CONFIGURED", False) and config.ADMIN_USERNAME and config.ADMIN_PASSWORD and config.ADMIN_EMAIL: try: register_user(username=config.ADMIN_USERNAME, email=config.ADMIN_EMAIL, password=config.ADMIN_PASSWORD) except Exception: pass try: users = persistence.load_users_from_db() seed_users_from_db(users) except Exception: pass try: persistence.load_global_state(engine) persistence.load_user_states(engine) except Exception: pass # --------------------------------------------------------------------------- # Static / SPA # --------------------------------------------------------------------------- @app.get("/", response_class=HTMLResponse) def index() -> HTMLResponse: html = config.INDEX_HTML.read_bytes() return HTMLResponse(content=html, 
media_type="text/html; charset=utf-8") @app.get("/api/healthz") def healthz(): """Process liveness — cheap, no external dependencies.""" return {"ok": True} @app.get("/api/health") def health(): if config.LLM_MODE == "openai_compat": h = llm.openai_health() return { "server": "ok", "mode": config.LLM_MODE, "openai_base_url": config.OPENAI_BASE_URL, "model": config.OPENAI_MODEL, **h, } if config.LLM_MODE == "anthropic": h = llm.anthropic_health() return {"server": "ok", "mode": config.LLM_MODE, **h} return { "server": "ok", "mode": config.LLM_MODE, "ok": False, "reachable": False, "error": "Unsupported LLM_MODE (expected openai_compat or anthropic)", } # --------------------------------------------------------------------------- # Auth # --------------------------------------------------------------------------- @app.post("/api/auth/register") def auth_register(req: AuthRegisterReq): try: rec = register_user(username=req.username, email=req.email, password=req.password) except ValueError as e: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) token = create_access_token(user_id=rec.user_id) return { "access_token": token, "token_type": "bearer", "user": {"user_id": rec.user_id, "username": rec.username, "email": rec.email}, } @app.post("/api/auth/login") def auth_login(req: AuthLoginReq): rec = authenticate_login(username_or_email=req.username_or_email, password=req.password) if not rec: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials") token = create_access_token(user_id=rec.user_id) return {"access_token": token, "token_type": "bearer"} @app.post("/api/auth/forgot-password") def auth_forgot_password(req: ForgotPasswordReq): """Dev/demo only. 
Do NOT return the reset token to the caller in production.""" email = (req.email or "").strip() user = get_user_by_email(email) if not user: return {"ok": True} token = uuid.uuid4().hex now = int(time.time()) expires = now + _PASSWORD_RESET_TTL_SECONDS with _password_reset_lock: _password_reset_tokens[token] = {"user_id": user.user_id, "expires": expires} # Expose reset token only in development to ease local testing. expose_token = (os.getenv("ENV", "development") or "development").lower() != "production" if expose_token: return {"ok": True, "reset_token": token} return {"ok": True} @app.post("/api/auth/reset-password") def auth_reset_password(req: ResetPasswordReq): token = (req.token or "").strip() new_password = (req.new_password or "").strip() if not token or not new_password: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="token and new_password required") now = int(time.time()) with _password_reset_lock: rec = _password_reset_tokens.get(token) if not rec: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid or expired token") if int(rec.get("expires") or 0) < now: _password_reset_tokens.pop(token, None) raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid or expired token") user_id = str(rec.get("user_id") or "") _password_reset_tokens.pop(token, None) try: ok = update_password(user_id=user_id, new_password=new_password) except ValueError as e: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) if not ok: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found") return {"ok": True} @app.get("/api/me") def me(user_id: str = Depends(require_user_id)): rec = get_user_by_id(user_id) if not rec: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found") return { "user_id": rec.user_id, "username": rec.username, "email": rec.email, "is_admin": _is_admin_user(rec.user_id), } # 
# Bandit state + strategies
# ---------------------------------------------------------------------------


@app.get("/api/state")
def state(user_id: str = Depends(require_user_id)):
    """Return the caller's, the global and the USERB_ID posteriors plus counters.

    Posteriors are evaluated at a constant context vector of 0.5 in every
    dimension (``config.D`` dims).
    """
    uid = user_id
    x = np.ones(config.D) * 0.5
    ub = engine.get_user(USERB_ID)
    return {
        "posterior": engine.user_posterior(uid, x),
        "global": engine.global_posterior(x),
        "userb": engine.posterior_summary(ub["mu"], ub["sigma_inv"], x),
        "global_n": engine.global_n,
        "n_users": len(engine.users),
        "msg_count": engine.get_user(uid)["msg_count"],
    }


@app.get("/api/strategies")
def list_strategies(user_id: str = Depends(require_user_id)):
    """List strategies; admins also see disabled ones. Enabled first, then by id."""
    admin = _is_admin_user(user_id)
    items = []
    for sid, it in getattr(config, "STRATEGY_ITEMS", {}).items():
        enabled = bool(it.get("enabled", True))
        if admin or enabled:
            items.append(
                {
                    "id": sid,
                    "label": it.get("label") or sid,
                    "instruction": it.get("instruction") or "",
                    "enabled": enabled,
                    "event_name": it.get("event_name") or "",
                }
            )
    items.sort(key=lambda x: (not bool(x.get("enabled", True)), str(x.get("id") or "")))
    return {"items": items}


@app.options("/api/strategies")
def strategies_preflight():
    return JSONResponse({"ok": True})


@app.options("/api/strategies/")
def strategies_preflight_slash():
    return JSONResponse({"ok": True})


@app.get("/api/strategies/usage")
def strategies_usage(user_id: str = Depends(require_admin_user_id)):
    """Admin: aggregated strategy usage keyed by strategy id."""
    return {"by_id": persistence.aggregate_strategy_usage()}


@app.get("/api/strategies/usage/")
def strategies_usage_slash(user_id: str = Depends(require_admin_user_id)):
    return strategies_usage(user_id=user_id)


@app.get("/api/strategies/{sid}/analytics")
def strategy_analytics(
    sid: str,
    days: int = 30,
    user_id: str = Depends(require_admin_user_id),
):
    """Admin: analytics for one strategy over the last *days* days; 400/404 on bad sid."""
    sid = str(sid or "").strip()
    if not sid:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="sid is required")
    if sid not in getattr(config, "STRATEGY_ITEMS", {}):
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Strategy not found")
    return persistence.get_strategy_analytics(sid, days=days)


@app.post("/api/strategies")
def create_strategy(req: StrategyCreateReq, user_id: str = Depends(require_admin_user_id)):
    """Admin: append a new strategy, persist, and reconcile the engine.

    Raises HTTPException(400) for an empty id and 409 for a duplicate id.
    """
    sid = str(req.id or "").strip()
    if not sid:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="id is required")
    if sid in getattr(config, "STRATEGY_ITEMS", {}):
        raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="strategy id already exists")
    item = {
        "id": sid,
        "label": req.label,
        "instruction": req.instruction,
        "enabled": bool(req.enabled),
        "event_name": str(req.event_name or "").strip(),
    }
    # Persist a copy so the response is independent of whatever persistence does with it.
    items = [*getattr(config, "STRATEGY_ITEMS", {}).values(), dict(item)]
    config.persist_strategies(items)
    engine.reconcile_strategies()
    return {"item": item}


@app.put("/api/strategies/{sid}")
def update_strategy(sid: str, req: StrategyUpdateReq, user_id: str = Depends(require_admin_user_id)):
    """Admin: replace a strategy's label/instruction/enabled.

    An empty ``event_name`` in the request keeps the existing one.
    Raises HTTPException(400) for an empty sid and 404 for an unknown one.
    """
    sid = str(sid or "").strip()
    if not sid:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="sid is required")
    cur = getattr(config, "STRATEGY_ITEMS", {}).get(sid)
    if not cur:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Strategy not found")
    event_name = str(req.event_name or "").strip() or str(cur.get("event_name") or "").strip()
    item = {
        "id": sid,
        "label": req.label,
        "instruction": req.instruction,
        "enabled": bool(req.enabled),
        "event_name": event_name,
    }
    items = [
        dict(item) if it.get("id") == sid else it
        for it in getattr(config, "STRATEGY_ITEMS", {}).values()
    ]
    config.persist_strategies(items)
    engine.reconcile_strategies()
    return {"item": item}


def _set_strategy_enabled(sid: str, enabled: bool) -> dict:
    """Set one strategy's ``enabled`` flag, persist, and reconcile the engine.

    All other fields of the stored item (e.g. ``event_name``) are preserved;
    ``label`` falls back to the id and ``instruction`` to "" when empty.
    Raises HTTPException(404) when *sid* is unknown.
    """
    sid = str(sid or "").strip()
    cur = getattr(config, "STRATEGY_ITEMS", {}).get(sid)
    if not cur:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Strategy not found")
    items = []
    for it in getattr(config, "STRATEGY_ITEMS", {}).values():
        if it.get("id") == sid:
            items.append(
                {
                    **it,
                    "id": sid,
                    "label": it.get("label") or sid,
                    "instruction": it.get("instruction") or "",
                    "enabled": enabled,
                }
            )
        else:
            items.append(it)
    config.persist_strategies(items)
    engine.reconcile_strategies()
    return {"ok": True}


@app.post("/api/strategies/{sid}/enable")
def enable_strategy(sid: str, user_id: str = Depends(require_admin_user_id)):
    """Admin: enable a strategy (404 if unknown)."""
    return _set_strategy_enabled(sid, True)


@app.post("/api/strategies/{sid}/disable")
def disable_strategy(sid: str, user_id: str = Depends(require_admin_user_id)):
    """Admin: disable a strategy (404 if unknown)."""
    return _set_strategy_enabled(sid, False)


@app.delete("/api/strategies/{sid}")
def delete_strategy(sid: str, user_id: str = Depends(require_admin_user_id)):
    """Admin: delete a strategy; refuses (400) to delete the last remaining one."""
    sid = str(sid or "").strip()
    cur = getattr(config, "STRATEGY_ITEMS", {}).get(sid)
    if not cur:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Strategy not found")
    items = [it for it in getattr(config, "STRATEGY_ITEMS", {}).values() if it.get("id") != sid]
    if not items:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot delete all strategies")
    config.persist_strategies(items)
    engine.reconcile_strategies()
    return {"ok": True}


# ---------------------------------------------------------------------------
# Primitives
# ---------------------------------------------------------------------------


@app.get("/api/primitives")
def list_primitives(user_id: str = Depends(require_admin_user_id)):
    """Admin: list the caller's primitives."""
    return {"items": persistence.list_user_primitives(user_id)}


@app.post("/api/primitives")
def create_primitive(req: PrimitiveCreateReq, user_id: str = Depends(require_admin_user_id)):
    """Admin: create a primitive; name and instruction must be non-blank (400)."""
    name = (req.name or "").strip()
    inst = (req.instruction or "").strip()
    if not name or not inst:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="name and instruction are required")
    row = persistence.create_user_primitive(user_id=user_id, name=name, instruction=inst)
    return {"item": row}


@app.put("/api/primitives/{prim_id}")
def update_primitive(prim_id: int, req: PrimitiveUpdateReq, user_id: str = Depends(require_admin_user_id)):
    """Admin: update one of the caller's primitives (400 blank fields, 404 unknown id)."""
    name = (req.name or "").strip()
    inst = (req.instruction or "").strip()
    if not name or not inst:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="name and instruction are required")
    row = persistence.update_user_primitive(user_id=user_id, prim_id=int(prim_id), name=name, instruction=inst)
    if row is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Primitive not found")
    return {"item": row}


@app.delete("/api/primitives/{prim_id}")
def delete_primitive(prim_id: int, user_id: str = Depends(require_admin_user_id)):
    """Admin: delete one of the caller's primitives (404 unknown id)."""
    ok = persistence.delete_user_primitive(user_id=user_id, prim_id=int(prim_id))
    if not ok:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Primitive not found")
    return {"ok": True}


# ---------------------------------------------------------------------------
# Conversation history
# ---------------------------------------------------------------------------


@app.get("/api/conversation")
def conversation(limit: int = 20, user_id: str = Depends(require_user_id)):
    """Return up to *limit* (clamped to 1..50) recent user/assistant pairs per pane.

    Pairs come from the persisted log; a pane with no persisted pairs falls back
    to the in-memory engine history (``uid`` for adaptive, ``uid + "_plain"`` for
    baseline).
    """
    uid = user_id
    lim = max(1, min(int(limit), 50))

    def _pairs_from_msgs(msgs: list[dict]) -> list[dict]:
        # Pair each assistant message with the most recent unpaired user message;
        # assistant messages with no preceding user message are skipped.
        out: list[dict] = []
        last_user: str | None = None
        for m in msgs:
            if m.get("role") == "user":
                last_user = str(m.get("content") or "")
            elif m.get("role") == "assistant":
                if last_user is None:
                    continue
                out.append(
                    {
                        "user": last_user,
                        "assistant": str(m.get("content") or ""),
                        "widget_html": str(m.get("widget_html") or ""),
                        "widget_schema": str(m.get("widget_schema") or ""),
                        "widget_height": int(m.get("widget_height") or 0),
                    }
                )
                last_user = None
        return out[-lim:]

    # Over-fetch raw messages so roughly `lim` complete pairs survive pairing.
    adaptive_msgs = persistence.get_recent_conversation_messages(user_id=uid, pane="adaptive", limit=lim * 4)
    baseline_msgs = persistence.get_recent_conversation_messages(user_id=uid, pane="baseline", limit=lim * 4)
    adaptive_pairs = _pairs_from_msgs(adaptive_msgs)
    baseline_pairs = _pairs_from_msgs(baseline_msgs)
    if not adaptive_pairs:
        adaptive_user = engine.get_user(uid)
        adaptive_pairs = (adaptive_user.get("history") or [])[-lim:]
    if not baseline_pairs:
        baseline_user = engine.get_user(uid + "_plain")
        baseline_pairs = (baseline_user.get("history") or [])[-lim:]
    return {"adaptive": {"history": adaptive_pairs}, "baseline": {"history": baseline_pairs}}


# ---------------------------------------------------------------------------
# Background persistence (off the hot path)
# ---------------------------------------------------------------------------


def _persist_after_response(
    *,
    user_id: str,
    uid: str,
    msg: str,
    response: str,
    strategy: str | None,
    elapsed: float | None,
    widget_present: bool,
    pane: str,
    widget_html: str = "",
    widget_schema: str = "",
    widget_height: int = 0,
) -> None:
    """Persist bandit state + conversation log after the response has been sent.

    Runs in FastAPI BackgroundTasks so HTTP latency is not gated by SQLite commits.
    The two parts are independent best-effort steps: a failure in one is printed
    and does not prevent the other. Nothing is ever raised to the caller.
    """
    is_adaptive = pane == "adaptive"

    try:
        persistence.persist_user_state(engine, uid)
        if is_adaptive:
            persistence.persist_global_state(engine)
    except Exception as e:
        # Persistence failures must never surface to the user.
        print(f"[persist] bandit state for {uid!r} not saved: {e}")

    logged_strategy = strategy if is_adaptive else None
    try:
        persistence.log_conversation_message(
            user_id=user_id,
            pane=pane,
            role="user",
            content=msg,
            strategy=logged_strategy,
        )
        persistence.log_conversation_message(
            user_id=user_id,
            pane=pane,
            role="assistant",
            content=response,
            strategy=logged_strategy,
            elapsed=elapsed,
            widget=widget_present,
            widget_html=widget_html or "",
            widget_schema=widget_schema or "",
            widget_height=int(widget_height or 0),
        )
    except Exception as e:
        # Persistence failures must never surface to the user.
        print(f"[persist] conversation log for {user_id!r} ({pane}) not saved: {e}")


# ---------------------------------------------------------------------------
# Baseline chat
# ---------------------------------------------------------------------------
# Baseline is the CONTRAST to the adaptive engine: instead of governed components,
# the model writes a raw, self-contained interactive visualization as HTML/JS that
# the frontend renders inside a sandboxed