from __future__ import annotations

from collections.abc import Iterator
from pathlib import Path
from typing import Any

import gradio as gr
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel, ConfigDict, Field
from starlette.responses import StreamingResponse

from .backends.image import DemoImageBackend, FluxImageBackend
from .backends.text import DemoTextBackend, LlamaCppTextBackend
from .config import AppConfig
from .orchestrator import ForestOrchestrator
from .trace import TraceRecorder


class ForestRequest(BaseModel):
    """Request body accepted by ``POST /api/forest``.

    Unknown fields are rejected (``extra="forbid"``) and leading/trailing
    whitespace is stripped from string fields.
    """

    model_config = ConfigDict(extra="forbid", str_strip_whitespace=True)

    # 1-80 characters.
    name: str = Field(min_length=1, max_length=80)
    # 1-1200 characters.
    situation: str = Field(min_length=1, max_length=1200)
    # Optional seed in [0, 2**31 - 1]; when omitted, the /api/forest handler
    # falls back to the configured ``default_seed``.
    seed: int | None = Field(default=None, ge=0, le=2_147_483_647)


def build_orchestrator(config: AppConfig) -> ForestOrchestrator:
    """Assemble a ForestOrchestrator from the backends selected in *config*.

    ``config.text_backend == "llama_cpp"`` selects the llama.cpp text backend
    and ``config.image_backend == "flux"`` selects the FLUX image backend; any
    other value falls back to the corresponding demo backend. A trace
    recorder is attached only when ``config.trace_path`` is truthy.
    """
    if config.text_backend == "llama_cpp":
        text_backend = LlamaCppTextBackend(
            base_url=config.llama_base_url,
            model=config.llama_model,
        )
    else:
        text_backend = DemoTextBackend()

    if config.image_backend == "flux":
        image_backend = FluxImageBackend(
            model_id=config.flux_model_id,
            lora_id=config.flux_lora_id,
            local_files_only=config.local_files_only,
        )
    else:
        image_backend = DemoImageBackend()

    trace_recorder = TraceRecorder(config.trace_path) if config.trace_path else None

    return ForestOrchestrator(
        text_backend=text_backend,
        image_backend=image_backend,
        trace_recorder=trace_recorder,
    )


def create_app(
    *,
    config: AppConfig | None = None,
    orchestrator: Any | None = None,
    frontend_dir: str | Path | None = None,
) -> gr.Server:
    """Build the server that hosts the static frontend and the forest API.

    Args:
        config: Runtime configuration. When omitted it is loaded with
            ``AppConfig.from_env()``.
        orchestrator: Object exposing ``generate(name, situation, seed)``
            that yields events with ``model_dump_json()``. When omitted it is
            built from *config* via :func:`build_orchestrator`.
        frontend_dir: Directory containing ``index.html``, ``styles.css``,
            ``app.js`` and an optional ``assets/`` directory. Defaults to
            ``Path(__file__).resolve().parents[2] / "frontend"``.

    Returns:
        The configured ``gr.Server`` with all routes registered.
    """
    # Explicit ``is None`` checks (matching ``frontend_dir`` below): only an
    # omitted argument triggers the default, so a caller-supplied object is
    # never swapped out merely because it evaluates as falsy.
    runtime = config if config is not None else AppConfig.from_env()
    forest = (
        orchestrator if orchestrator is not None else build_orchestrator(runtime)
    )
    frontend = (
        Path(frontend_dir)
        if frontend_dir is not None
        else Path(__file__).resolve().parents[2] / "frontend"
    )

    # NOTE(review): docs_url/redoc_url=None presumably disable the
    # auto-generated API docs pages (FastAPI-style kwargs) — confirm for gr.Server.
    app = gr.Server(
        title="The Compliment Forest",
        description="A progressive path of grounded encouragement.",
        docs_url=None,
        redoc_url=None,
    )

    @app.get("/")
    def index() -> FileResponse:
        """Serve ``index.html`` from the frontend directory."""
        return FileResponse(frontend / "index.html")

    @app.get("/styles.css")
    def styles() -> FileResponse:
        """Serve the frontend stylesheet."""
        return FileResponse(frontend / "styles.css", media_type="text/css")

    @app.get("/app.js")
    def javascript() -> FileResponse:
        """Serve the frontend script."""
        return FileResponse(frontend / "app.js", media_type="text/javascript")

    assets = frontend / "assets"
    # ``is_dir`` rather than ``exists``: StaticFiles requires a directory and
    # raises RuntimeError at construction if given a regular file.
    if assets.is_dir():
        app.mount("/assets", StaticFiles(directory=assets), name="assets")

    @app.get("/health")
    def health() -> dict[str, object]:
        """Report liveness and the configured backend names.

        The backend names come from the runtime config, so they may not
        describe an injected *orchestrator*. ``off_grid`` and
        ``model_parameter_budget_billions`` are fixed literals, not values
        measured from the running backends.
        """
        return {
            "status": "ok",
            "text_backend": runtime.text_backend,
            "image_backend": runtime.image_backend,
            "off_grid": True,
            "model_parameter_budget_billions": 18,
        }

    @app.post("/api/forest")
    def generate_forest(request: ForestRequest) -> StreamingResponse:
        """Stream the orchestrator's events for *request* as NDJSON."""
        seed = request.seed if request.seed is not None else runtime.default_seed

        def stream() -> Iterator[str]:
            for event in forest.generate(request.name, request.situation, seed):
                # One JSON document per line (newline-delimited JSON).
                yield event.model_dump_json() + "\n"

        return StreamingResponse(stream(), media_type="application/x-ndjson")

    return app