doc-ingestion / src /api /main.py
vampokala's picture
Added additional embedding models to the list
76de0b2
"""FastAPI application for citation-aware querying."""
from __future__ import annotations
import asyncio
import json
import logging
import os
import time
import uuid
from collections import defaultdict, deque
from contextlib import asynccontextmanager, suppress
from datetime import datetime
from pathlib import Path
from typing import Any, Deque, Dict, Iterator, cast
from fastapi import BackgroundTasks, FastAPI, File, Form, Header, HTTPException, Request, Response, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from fastapi.staticfiles import StaticFiles
from redis import Redis
from redis.exceptions import RedisError
from src.api.models import (
CitationModel,
HealthModel,
LLMConfigModel,
MetricsModel,
QueryRequestModel,
QueryResponseModel,
RetrievedChunkModel,
RuntimeConfigModel,
TruthfulnessModel,
)
from src.core.observability import get_observer
from src.core.rag_orchestrator import (
COLLECTION_NAME,
QueryRequest,
QueryResponse,
RAGOrchestrator,
StreamingQuerySession,
)
from src.monitoring.metrics import RequestMetrics, get_metrics_collector
from src.utils.config import load_config
from src.web import session_corpus
from src.web.ingestion_service import (
MAX_FILE_BYTES,
MAX_FILES_PER_SESSION,
MAX_SESSION_BYTES,
run_ingest,
save_uploaded_files,
)
_cfg = load_config("config.yaml")
_orchestrator = RAGOrchestrator(_cfg)
_metrics_collector = get_metrics_collector()
_rate_window: Dict[str, Deque[float]] = defaultdict(deque)
_redis_client: "Redis | None" = None
# Redis is a base dependency in this project; keep a feature flag for defensive gating.
_REDIS_AVAILABLE = True
_logger = logging.getLogger("api.audit")
def _frontend_origins() -> list[str]:
raw = os.getenv("DOC_FRONTEND_ORIGINS", "").strip()
if raw:
return [p.strip() for p in raw.split(",") if p.strip()]
if os.getenv("DOC_PROFILE", "").strip().lower() == "demo":
return ["*"]
return [
"http://localhost:5173",
"http://127.0.0.1:5173",
"http://localhost:3000",
"http://localhost:8000",
"http://127.0.0.1:8000",
]
def _demo_uploads_enabled() -> bool:
return (
os.getenv("DOC_PROFILE", "").strip().lower() == "demo"
and os.getenv("DOC_DEMO_UPLOADS", "0").strip() == "1"
)
@asynccontextmanager
async def _lifespan(_: FastAPI):
stop = asyncio.Event()
async def _janitor_loop() -> None:
while not stop.is_set():
try:
session_corpus.janitor_sweep()
except Exception:
pass
try:
await asyncio.wait_for(stop.wait(), timeout=60.0)
except asyncio.TimeoutError:
pass
task = asyncio.create_task(_janitor_loop())
try:
yield
finally:
stop.set()
task.cancel()
with suppress(Exception):
await task
app = FastAPI(title="Doc Ingestion Citation API", version="0.1.0", lifespan=_lifespan)
app.add_middleware(
CORSMiddleware,
allow_origins=_frontend_origins(),
allow_credentials=False,
allow_methods=["*"],
allow_headers=["*"],
)
def _resolve_client_key(request: Request, api_key: str | None) -> str:
if api_key:
return f"key:{api_key}"
if request.client and request.client.host:
return f"ip:{request.client.host}"
return "ip:unknown"
def _audit_log(event: str, request: Request, **fields: object) -> None:
payload = {
"event": event,
"path": request.url.path,
"method": request.method,
"client_ip": request.client.host if request.client else "unknown",
"ts": int(time.time()),
}
payload.update(fields)
_logger.info(json.dumps(payload, separators=(",", ":")))
def _get_redis() -> "Redis | None":
global _redis_client
if not _REDIS_AVAILABLE or not _cfg.api.redis_rate_limit_enabled:
return None
if _redis_client is not None:
return _redis_client
try:
_redis_client = Redis.from_url(_cfg.api.redis_url, decode_responses=True)
_redis_client.ping()
return _redis_client
except (RedisError, Exception):
return None
def _enforce_rate_limit(client_key: str) -> None:
limit = int(_cfg.api.rate_limit_per_minute)
redis_client = _get_redis()
if redis_client is not None:
now_bucket = int(time.time() // 60)
redis_key = f"ratelimit:{client_key}:{now_bucket}"
try:
current = int(cast(Any, redis_client.incr(redis_key)))
if current == 1:
redis_client.expire(redis_key, 120)
if current > limit:
raise HTTPException(status_code=429, detail="Rate limit exceeded")
return
except RedisError:
# fallback to in-memory limiter when Redis has transient issues
pass
now = time.time()
q = _rate_window[client_key]
while q and (now - q[0]) > 60.0:
q.popleft()
if len(q) >= limit:
raise HTTPException(status_code=429, detail="Rate limit exceeded")
q.append(now)
def _is_local_provider(provider: str | None) -> bool:
selected = (provider or _cfg.llm.default_provider or "").strip().lower()
if selected == "claude":
selected = "anthropic"
return selected == "ollama"
def _verify_auth(api_key: str | None, provider: str | None = None) -> None:
if not _cfg.api.auth_enabled:
return
if os.getenv("DOC_PROFILE", "").strip().lower() == "demo":
return
# Local Ollama queries do not require API auth by default.
if _is_local_provider(provider):
return
valid_keys = _cfg.api.resolved_api_keys()
if not valid_keys:
# secure-by-default for production paths
raise HTTPException(status_code=503, detail="API auth enabled but no keys configured")
if not api_key or api_key not in valid_keys:
raise HTTPException(status_code=401, detail="Invalid API key")
def _retrieved_chunks_json(final: QueryResponse) -> list[dict[str, Any]]:
"""Serialize retrieval results for SSE / JSON (same shape as non-streaming QueryResponseModel)."""
out: list[dict[str, Any]] = []
for item in final.retrieved:
legacy = item.to_legacy_dict()
out.append(
RetrievedChunkModel(
id=legacy["id"],
score=float(legacy.get("score") or 0.0),
source=str(legacy.get("source") or "hybrid"),
confidence=float(legacy.get("confidence") or 0.0),
metadata=dict(legacy.get("metadata") or {}),
preview=(legacy.get("text") or "")[:240],
).model_dump(),
)
return out
def _calculate_cost(provider: str, model: str, answer_tokens: int = 0, query_tokens: int = 0) -> float:
"""Calculate USD cost of request based on tokens and provider pricing.
Rough estimates — update as provider pricing changes.
"""
if provider == "openai":
return answer_tokens * 0.000002 + query_tokens * 0.000001
elif provider == "anthropic":
return answer_tokens * 0.0000024 + query_tokens * 0.0000008
elif provider == "gemini":
return answer_tokens * 0.000001 + query_tokens * 0.0000005
else:
return 0.0
def _build_request_metrics(request_id: str, out: QueryResponse) -> RequestMetrics:
"""Build metrics with explicit cache-hit step latency semantics."""
step_latencies = out.step_latencies or {}
is_cached = bool(out.cached)
return RequestMetrics(
request_id=request_id,
total_latency_ms=out.processing_time_ms,
retrieval_latency_ms=None if is_cached else step_latencies.get("retrieval", 0.0),
reranking_latency_ms=None if is_cached else step_latencies.get("reranking", 0.0),
generation_latency_ms=None if is_cached else step_latencies.get("generation", 0.0),
citation_latency_ms=None if is_cached else step_latencies.get("citation_verification", 0.0),
truthfulness_latency_ms=None if is_cached else step_latencies.get("truthfulness_scoring", 0.0),
cost_usd=_calculate_cost(out.provider, out.model),
citation_groundedness=out.truthfulness.citation_groundedness if out.truthfulness else 0.0,
nli_faithfulness=out.truthfulness.nli_faithfulness if out.truthfulness else 0.0,
uncited_claims=out.truthfulness.uncited_claims if out.truthfulness else 0,
timestamp=datetime.utcnow().isoformat(),
cached=is_cached,
)
@app.get("/health", response_model=HealthModel)
def health() -> HealthModel:
return HealthModel(status="ok", collection=COLLECTION_NAME)
@app.get("/config/llm", response_model=LLMConfigModel)
def llm_config() -> LLMConfigModel:
"""Allowed providers/models and defaults from server config (for UI dropdowns)."""
llm = _cfg.llm
provider_key_configured = {
provider: llm.provider_has_key(provider)
for provider in llm.allowed_models_by_provider.keys()
}
return LLMConfigModel(
default_provider=llm.default_provider,
default_model_by_provider=dict(llm.default_model_by_provider),
allowed_models_by_provider={k: list(v) for k, v in llm.allowed_models_by_provider.items()},
provider_key_configured=provider_key_configured,
demo_mode=os.getenv("DOC_PROFILE", "").strip().lower() == "demo",
)
@app.get("/config/runtime", response_model=RuntimeConfigModel)
def runtime_config() -> RuntimeConfigModel:
embedding_profiles = {
name: profile.model_dump()
for name, profile in _cfg.embeddings.profiles.items()
}
return RuntimeConfigModel(
chunking_default_strategy=_cfg.chunking.default_strategy,
chunking_allowed_strategies=list(_cfg.chunking.allowed_strategies),
embedding_default_profile=_cfg.embeddings.default_profile,
embedding_profiles=embedding_profiles,
)
@app.get("/metrics", response_model=MetricsModel)
def metrics(x_api_key: str | None = Header(default=None, alias="X-API-Key")) -> MetricsModel:
# no request object here; audit logs are captured in protected endpoints with request context
_verify_auth(x_api_key, provider="metrics")
return MetricsModel(
cache_ttl_seconds=int(_cfg.generation.cache_ttl),
available_providers=sorted(_cfg.llm.allowed_models_by_provider.keys()),
)
def _session_summary(session: session_corpus.SessionCorpus) -> dict[str, Any]:
files = []
for p in sorted(session.upload_dir.glob("*")):
if p.is_file():
files.append({"name": p.name, "size_bytes": p.stat().st_size})
touched_path = session.upload_dir.parent / ".touched"
mtime = touched_path.stat().st_mtime if touched_path.exists() else time.time()
ttl = session_corpus._session_ttl_seconds()
expires_at = int(mtime + ttl)
return {
"session_id": session.session_id,
"files": files,
"total_bytes": session_corpus.total_bytes(session),
"max_session_bytes": MAX_SESSION_BYTES,
"max_files": MAX_FILES_PER_SESSION,
"expires_at": expires_at,
}
def _raise_sessions_demo_disabled() -> None:
"""Avoid POST /sessions falling through to StaticFiles (which responds with 405)."""
raise HTTPException(
status_code=404,
detail=(
"Demo sessions are disabled on this server. "
"Set DOC_PROFILE=demo and DOC_DEMO_UPLOADS=1 on the API process."
),
)
if _demo_uploads_enabled():
@app.post("/sessions")
def create_session(response: Response) -> dict[str, Any]:
sid = session_corpus.new_session_id()
session = session_corpus.get_or_create(sid)
session_corpus.touch(sid)
expires_at = int(time.time() + session_corpus._session_ttl_seconds())
response.headers["X-Demo-Session-Id"] = sid
return {"session_id": session.session_id, "expires_at": expires_at}
@app.get("/sessions/{sid}")
def get_session(sid: str, response: Response) -> dict[str, Any]:
session = session_corpus.get_or_create(sid)
session_corpus.touch(sid)
response.headers["X-Demo-Session-Id"] = sid
return _session_summary(session)
@app.post("/sessions/{sid}/documents")
def upload_session_documents(
sid: str,
request: Request,
files: list[UploadFile] = File(default_factory=list),
chunk_strategy: str | None = Form(default=None),
embedding_profile: str | None = Form(default=None),
x_api_key: str | None = Header(default=None, alias="X-API-Key"),
) -> dict[str, Any]:
_verify_auth(x_api_key, provider="uploads")
client_key = _resolve_client_key(request, x_api_key)
_enforce_rate_limit(client_key)
if not files:
raise HTTPException(status_code=400, detail="No files uploaded")
session = session_corpus.get_or_create(sid)
existing = session_corpus.total_bytes(session)
class _UploadBlob:
def __init__(self, name: str, payload: bytes) -> None:
self.name = name
self._payload = payload
def getvalue(self) -> bytes:
return self._payload
staged_objs = []
for f in files:
payload = f.file.read()
staged_objs.append(_UploadBlob(f.filename or "unknown", payload))
staged = save_uploaded_files(
str(session.upload_dir),
staged_objs,
existing_bytes=existing,
max_files=MAX_FILES_PER_SESSION,
max_file_bytes=MAX_FILE_BYTES,
max_session_bytes=MAX_SESSION_BYTES,
)
try:
run_ingest(
str(session.upload_dir),
bm25_index_path=str(session.bm25_index_path),
collection_name=session.collection_name,
chroma_path=str(session.chroma_path),
chunk_strategy=chunk_strategy,
embedding_profile=embedding_profile,
)
except Exception as exc:
detail = str(exc)
if "Failed to connect to Ollama" in detail:
detail = (
"Embedding backend unavailable: Ollama is not reachable. "
"Select a non-Ollama embedding profile (for example 'st_minilm') "
"or configure Ollama on the server."
)
raise HTTPException(status_code=400, detail=detail) from exc
session_corpus.touch(sid)
return {"session_id": sid, "results": [r.__dict__ for r in staged], **_session_summary(session)}
@app.delete("/sessions/{sid}")
def delete_session(sid: str, response: Response) -> dict[str, Any]:
session_corpus.delete_session(sid)
new_sid = session_corpus.new_session_id()
session_corpus.get_or_create(new_sid)
session_corpus.touch(new_sid)
response.headers["X-Demo-Session-Id"] = new_sid
return {"deleted_session_id": sid, "session_id": new_sid}
else:
# Register these so session paths are not handled by the SPA/static mount below.
@app.post("/sessions")
def _sessions_off_create():
_raise_sessions_demo_disabled()
@app.get("/sessions/{sid}")
def _sessions_off_get(_sid: str):
_raise_sessions_demo_disabled()
@app.post("/sessions/{sid}/documents")
def _sessions_off_documents(_sid: str):
_raise_sessions_demo_disabled()
@app.delete("/sessions/{sid}")
def _sessions_off_delete(_sid: str):
_raise_sessions_demo_disabled()
@app.post("/query", response_model=QueryResponseModel)
def query(
req: QueryRequestModel,
request: Request,
background_tasks: BackgroundTasks,
x_api_key: str | None = Header(default=None, alias="X-API-Key"),
) -> QueryResponseModel:
request_id = str(uuid.uuid4())
try:
_verify_auth(x_api_key, req.provider)
except HTTPException as exc:
_audit_log("auth_failed", request, status=exc.status_code, reason=exc.detail)
raise
client_key = _resolve_client_key(request, x_api_key)
_enforce_rate_limit(client_key)
_audit_log("auth_success", request, client_key=client_key)
session_kwargs: dict[str, Any] = {}
knowledge_scope = (req.knowledge_scope or "global").strip().lower()
if req.session_id:
if not _demo_uploads_enabled():
raise HTTPException(status_code=400, detail="Session querying is only enabled in demo uploads mode")
session = session_corpus.get_or_create(req.session_id)
if not session.upload_dir.exists():
raise HTTPException(status_code=404, detail="Session not found")
session_corpus.touch(req.session_id)
has_uploads = any(p.is_file() for p in session.upload_dir.glob("*"))
if knowledge_scope == "session" and not has_uploads:
raise HTTPException(status_code=409, detail="No uploaded documents in this session yet. Upload first.")
if knowledge_scope == "both" and not has_uploads:
knowledge_scope = "global"
session_kwargs = {
"session_bm25_index_path": str(session.bm25_index_path),
"session_collection_name": session.collection_name,
"session_chroma_path": str(session.chroma_path),
"knowledge_scope": knowledge_scope,
}
try:
out = _orchestrator.run(
QueryRequest(
query_text=req.query,
top_k=req.top_k,
use_llm=req.use_llm,
use_rerank=req.use_rerank,
stream=req.stream,
provider=req.provider,
model=req.model,
provider_api_key=req.provider_api_key,
reranker_model=req.reranker_model,
include_citations=req.include_citations,
embedding_profile=req.embedding_profile,
**session_kwargs,
)
)
except Exception as exc: # keep API error shape stable
_audit_log("query_failed", request, client_key=client_key, error=str(exc))
_logger.exception("query_failed provider=%s model=%s client_key=%s", req.provider, req.model, client_key)
raise HTTPException(status_code=400, detail=str(exc)) from exc
_audit_log(
"query_success",
request,
client_key=client_key,
provider=out.provider,
model=out.model,
latency_ms=round(out.processing_time_ms, 2),
cached=out.cached,
)
retrieved = []
for item in out.retrieved:
legacy = item.to_legacy_dict()
retrieved.append(
RetrievedChunkModel(
id=legacy["id"],
score=float(legacy.get("score") or 0.0),
source=str(legacy.get("source") or "hybrid"),
confidence=float(legacy.get("confidence") or 0.0),
metadata=dict(legacy.get("metadata") or {}),
preview=(legacy.get("text") or "")[:240],
)
)
truthfulness_model = None
if out.truthfulness is not None:
t = out.truthfulness
truthfulness_model = TruthfulnessModel(
nli_faithfulness=t.nli_faithfulness,
citation_groundedness=t.citation_groundedness,
uncited_claims=t.uncited_claims,
score=t.score,
)
citation_models = [CitationModel.model_validate(c) for c in out.citations]
_metrics_collector.record_request(_build_request_metrics(request_id, out))
observer = get_observer()
background_tasks.add_task(observer.flush_async)
return QueryResponseModel(
query=out.query,
provider=out.provider,
model=out.model,
answer=out.answer,
processing_time_ms=out.processing_time_ms,
cached=out.cached,
validation_issues=out.validation_issues,
citations=citation_models,
retrieved=retrieved,
truthfulness=truthfulness_model,
embedding_profile=out.embedding_profile,
)
@app.post("/query/stream")
def query_stream(
req: QueryRequestModel,
request: Request,
x_api_key: str | None = Header(default=None, alias="X-API-Key"),
) -> StreamingResponse:
try:
_verify_auth(x_api_key, req.provider)
except HTTPException as exc:
_audit_log("auth_failed", request, status=exc.status_code, reason=exc.detail)
raise
client_key = _resolve_client_key(request, x_api_key)
_enforce_rate_limit(client_key)
_audit_log("stream_auth_success", request, client_key=client_key)
session_kwargs: dict[str, Any] = {}
knowledge_scope = (req.knowledge_scope or "global").strip().lower()
if req.session_id:
if not _demo_uploads_enabled():
raise HTTPException(status_code=400, detail="Session querying is only enabled in demo uploads mode")
session = session_corpus.get_or_create(req.session_id)
if not session.upload_dir.exists():
raise HTTPException(status_code=404, detail="Session not found")
session_corpus.touch(req.session_id)
has_uploads = any(p.is_file() for p in session.upload_dir.glob("*"))
if knowledge_scope == "session" and not has_uploads:
raise HTTPException(status_code=409, detail="No uploaded documents in this session yet. Upload first.")
if knowledge_scope == "both" and not has_uploads:
knowledge_scope = "global"
session_kwargs = {
"session_bm25_index_path": str(session.bm25_index_path),
"session_collection_name": session.collection_name,
"session_chroma_path": str(session.chroma_path),
"knowledge_scope": knowledge_scope,
}
def _gen() -> Iterator[str]:
request_id = str(uuid.uuid4())
try:
stream_req = QueryRequest(
query_text=req.query,
top_k=req.top_k,
use_llm=req.use_llm,
use_rerank=req.use_rerank,
stream=True,
provider=req.provider,
model=req.model,
provider_api_key=req.provider_api_key,
reranker_model=req.reranker_model,
include_citations=req.include_citations,
embedding_profile=req.embedding_profile,
**session_kwargs,
)
with StreamingQuerySession(_orchestrator, stream_req) as session:
for piece in session.iter_tokens():
yield f"data: {json.dumps({'type': 'token', 'text': piece})}\n\n"
final = session.finalize()
_audit_log(
"stream_success",
request,
client_key=client_key,
provider=final.provider,
model=final.model,
)
final_payload: dict[str, Any] = {
"type": "final",
"citations": final.citations,
"retrieved": _retrieved_chunks_json(final),
"truthfulness": final.truthfulness.to_dict() if final.truthfulness is not None else None,
"provider": final.provider,
"model": final.model,
"processing_time_ms": final.processing_time_ms,
"cached": final.cached,
"validation_issues": final.validation_issues,
"embedding_profile": final.embedding_profile,
}
_metrics_collector.record_request(_build_request_metrics(request_id, final))
yield f"data: {json.dumps(final_payload)}\n\n"
yield "data: [DONE]\n\n"
except Exception as exc:
_audit_log("stream_failed", request, client_key=client_key, error=str(exc))
_logger.exception("stream_failed provider=%s model=%s client_key=%s", req.provider, req.model, client_key)
yield f"data: {json.dumps({'type': 'error', 'message': str(exc)})}\n\n"
return StreamingResponse(_gen(), media_type="text/event-stream")
@app.get("/observability/dashboard")
def observability_dashboard(x_api_key: str | None = Header(default=None, alias="X-API-Key")):
"""Return real-time observability metrics for dashboarding."""
_verify_auth(x_api_key, provider="metrics")
return _metrics_collector.get_dashboard_metrics()
_ui_static = Path(__file__).resolve().parent.parent.parent / "static"
if _ui_static.is_dir() and (_ui_static / "index.html").is_file():
app.mount("/", StaticFiles(directory=str(_ui_static), html=True), name="ui")