"""FastAPI application factory.""" from __future__ import annotations import os import time from pathlib import Path import psutil import structlog from fastapi import FastAPI from agent_bench.agents.orchestrator import Orchestrator from agent_bench.core.config import AppConfig, load_config, load_task_config from agent_bench.core.provider import create_provider from agent_bench.rag.embedder import Embedder from agent_bench.rag.retriever import Retriever from agent_bench.rag.store import HybridStore from agent_bench.serving.middleware import MetricsCollector, RateLimitMiddleware, RequestMiddleware from agent_bench.serving.routes import router from agent_bench.tools.calculator import CalculatorTool from agent_bench.tools.registry import ToolRegistry from agent_bench.tools.search import SearchTool def create_app(config: AppConfig | None = None) -> FastAPI: """Create and configure the FastAPI application. Initializes all singletons and attaches them to app.state. """ if config is None: config = load_config() app = FastAPI(title="agent-bench", version="0.1.0") log = structlog.get_logger() # Load task config for system prompt task = load_task_config("tech_docs") # Providers — create all available, keyed by name provider = create_provider(config) providers: dict = {config.provider.default: provider} _alt_providers = {"openai", "anthropic"} - {config.provider.default} for alt in _alt_providers: try: from agent_bench.core.provider import ( AnthropicProvider, OpenAIProvider, ) if alt == "openai" and os.environ.get("OPENAI_API_KEY"): providers["openai"] = OpenAIProvider(config) elif alt == "anthropic" and os.environ.get( "ANTHROPIC_API_KEY", ): providers["anthropic"] = AnthropicProvider(config) except Exception: pass # missing dependency or key — skip # --- Shared RAG components (corpus-independent) --- embedder = Embedder( model_name=config.embedding.model, cache_dir=config.embedding.cache_dir, ) reranker = None if config.rag.reranker.enabled: from agent_bench.rag.reranker import CrossEncoderReranker reranker = CrossEncoderReranker(model_name=config.rag.reranker.model_name) # --- Security components (constructed before tools so PII redactor # can be injected into per-corpus SearchTools) --- from agent_bench.security.audit_logger import AuditLogger from agent_bench.security.injection_detector import InjectionDetector from agent_bench.security.output_validator import OutputValidator from agent_bench.security.pii_redactor import PIIRedactor sec = config.security injection_detector = InjectionDetector( tiers=sec.injection.tiers, classifier_url=sec.injection.classifier_url, enabled=sec.injection.enabled, ) pii_redactor = PIIRedactor( redact_patterns=sec.pii.redact_patterns, mode=sec.pii.mode, use_ner=sec.pii.use_ner, ) output_validator = OutputValidator( pii_check=sec.output.pii_check, url_check=sec.output.url_check, secret_check=sec.output.secret_check, blocklist=sec.output.blocklist, ) audit_logger = AuditLogger( path=sec.audit.path, max_size_bytes=sec.audit.max_size_mb * 1024 * 1024, rotate=sec.audit.rotate, ) # --- Mode-dependent construction: multi-corpus vs legacy single-corpus --- corpus_map: dict[str, dict[str, Orchestrator]] = {} orchestrators: dict[str, Orchestrator] = {} store: HybridStore if config.corpora: # Multi-corpus mode. Skip the legacy single-store path entirely — # each corpus gets its own store / retriever / registry, and the # per-corpus inner dict holds one Orchestrator per available provider. _proc = psutil.Process() _baseline_rss = _proc.memory_info().rss / 1024**2 _default_store: HybridStore | None = None for corpus_name, corpus_cfg in config.corpora.items(): # Skip corpora marked unavailable. They stay in config.corpora # for schema visibility but are not wired into corpus_map, # so routes return 400 via _resolve_orchestrator and the # dashboard can render the toggle as disabled. if not corpus_cfg.available: log.warning( "corpus_skipped_unavailable", name=corpus_name, label=corpus_cfg.label, reason="CorpusConfig.available=False", hint="set available=true once the store is built", ) continue c_store_path = Path(corpus_cfg.store_path) if c_store_path.exists() and (c_store_path / "index.faiss").exists(): c_store = HybridStore.load( str(c_store_path), rrf_k=config.rag.retrieval.rrf_k, ) else: c_store = HybridStore( dimension=384, rrf_k=config.rag.retrieval.rrf_k, ) c_retriever = Retriever( embedder=embedder, store=c_store, default_strategy=config.rag.retrieval.strategy, # type: ignore[arg-type] candidates_per_system=config.rag.retrieval.candidates_per_system, reranker=reranker, reranker_top_k=config.rag.reranker.top_k, ) c_registry = ToolRegistry() c_registry.register( SearchTool( retriever=c_retriever, default_top_k=corpus_cfg.top_k, default_strategy=config.rag.retrieval.strategy, # type: ignore[arg-type] refusal_threshold=corpus_cfg.refusal_threshold, pii_redactor=pii_redactor if sec.pii.enabled else None, ) ) c_registry.register(CalculatorTool()) inner: dict[str, Orchestrator] = {} for p_name, p_prov in providers.items(): inner[p_name] = Orchestrator( provider=p_prov, registry=c_registry, max_iterations=corpus_cfg.max_iterations, temperature=config.agent.temperature, ) corpus_map[corpus_name] = inner if corpus_name == config.default_corpus: _default_store = c_store _rss_mb = _proc.memory_info().rss / 1024**2 log.info( "corpus_loaded", name=corpus_name, label=corpus_cfg.label, store_path=str(c_store_path), providers=list(inner.keys()), rss_mb=round(_rss_mb, 1), rss_delta_mb=round(_rss_mb - _baseline_rss, 1), ) log.info( "multi_corpus_mode", corpora=list(corpus_map.keys()), default=config.default_corpus, providers=list(providers.keys()), ) # Legacy rag.refusal_threshold is ignored in multi-corpus mode; # per-corpus refusal_threshold is authoritative. Only warn when the # legacy value is non-default AND differs from the default corpus's # threshold — that is the actual drift case. A legacy value that # matches the default corpus is benign (someone kept both in sync). legacy_thresh = config.rag.refusal_threshold default_thresh = config.corpora[config.default_corpus].refusal_threshold if legacy_thresh != 0.0 and legacy_thresh != default_thresh: log.warning( "rag_refusal_threshold_drift_in_multi_corpus_mode", legacy_value=legacy_thresh, default_corpus=config.default_corpus, default_corpus_value=default_thresh, hint="rag.refusal_threshold is ignored; " "update corpora..refusal_threshold instead", ) # AppConfig._validate_default_corpus guarantees default_corpus is in # corpora when corpora is non-empty, so _default_store is always set. assert _default_store is not None store = _default_store # orchestrators (flat, per-provider) is the default-corpus inner dict # — keeps /ask's existing provider-switching code path working for # the default corpus. Per-request corpus routing in Task 3 will # consult corpus_map[corpus][provider] directly. orchestrators = dict(corpus_map[config.default_corpus]) orchestrator = orchestrators[config.provider.default] else: # Legacy single-corpus mode. log.info("single_corpus_mode_legacy") store_path = Path(config.rag.store_path) if store_path.exists() and (store_path / "index.faiss").exists(): store = HybridStore.load(str(store_path), rrf_k=config.rag.retrieval.rrf_k) else: store = HybridStore(dimension=384, rrf_k=config.rag.retrieval.rrf_k) retriever = Retriever( embedder=embedder, store=store, default_strategy=config.rag.retrieval.strategy, # type: ignore[arg-type] candidates_per_system=config.rag.retrieval.candidates_per_system, reranker=reranker, reranker_top_k=config.rag.reranker.top_k, ) registry = ToolRegistry() registry.register( SearchTool( retriever=retriever, default_top_k=config.rag.retrieval.top_k, default_strategy=config.rag.retrieval.strategy, # type: ignore[arg-type] refusal_threshold=config.rag.refusal_threshold, pii_redactor=pii_redactor if sec.pii.enabled else None, ) ) registry.register(CalculatorTool()) for name, prov in providers.items(): orchestrators[name] = Orchestrator( provider=prov, registry=registry, max_iterations=config.agent.max_iterations, temperature=config.agent.temperature, ) orchestrator = orchestrators[config.provider.default] # Metrics metrics = MetricsCollector() # Conversation memory (optional, SQLite-backed) conversation_store = None if config.memory.enabled: from agent_bench.memory.store import ConversationStore conversation_store = ConversationStore(db_path=config.memory.db_path) # Attach to app state app.state.orchestrator = orchestrator app.state.orchestrators = orchestrators app.state.corpus_map = corpus_map app.state.store = store app.state.conversation_store = conversation_store app.state.config = config app.state.system_prompt = task.system_prompt app.state.start_time = time.time() app.state.metrics = metrics app.state.injection_detector = injection_detector app.state.pii_redactor = pii_redactor app.state.output_validator = output_validator app.state.audit_logger = audit_logger # Middleware and routes (order matters: rate limit checked first) app.add_middleware(RequestMiddleware) app.add_middleware(RateLimitMiddleware, requests_per_minute=config.serving.rate_limit_rpm) app.include_router(router) # Startup warmup: eager-load models to reduce cold start latency @app.on_event("startup") async def warmup() -> None: log.info("warmup_start") _ = embedder.embed("warmup") if reranker is not None: _ = reranker.model # noqa: F841 log.info("warmup_complete") return app