| from __future__ import annotations |
|
|
| from pathlib import Path |
| from typing import Any |
|
|
| import gradio as gr |
| from fastapi.responses import FileResponse |
| from fastapi.staticfiles import StaticFiles |
| from pydantic import BaseModel, ConfigDict, Field |
| from starlette.responses import StreamingResponse |
|
|
| from .backends.image import DemoImageBackend, FluxImageBackend |
| from .backends.text import DemoTextBackend, LlamaCppTextBackend |
| from .config import AppConfig |
| from .orchestrator import ForestOrchestrator |
| from .trace import TraceRecorder |
|
|
|
|
class ForestRequest(BaseModel):
    """JSON body accepted by the ``POST /api/forest`` endpoint.

    Unknown keys are rejected and surrounding whitespace is stripped from
    string fields. ``seed`` is optional; when it is omitted the endpoint
    substitutes the configured default seed.
    """

    # Reject unexpected keys; trim whitespace from every str field.
    model_config = ConfigDict(str_strip_whitespace=True, extra="forbid")

    # Display name of the person being encouraged (1-80 characters).
    name: str = Field(max_length=80, min_length=1)
    # Free-text description of their situation (1-1200 characters).
    situation: str = Field(max_length=1200, min_length=1)
    # Optional generation seed, bounded to the non-negative signed 32-bit range.
    seed: int | None = Field(None, le=2**31 - 1, ge=0)
|
|
|
|
def build_orchestrator(config: AppConfig) -> ForestOrchestrator:
    """Assemble a ForestOrchestrator from the backends named in *config*.

    Text: ``"llama_cpp"`` selects ``LlamaCppTextBackend``; any other value
    falls back to ``DemoTextBackend``.
    Image: ``"flux"`` selects ``FluxImageBackend``; any other value falls
    back to ``DemoImageBackend``.
    A ``TraceRecorder`` is attached only when ``config.trace_path`` is truthy.
    """
    # Backends are constructed in a fixed order (text, image, trace) because
    # their constructors may do real work such as connecting or loading models.
    if config.text_backend != "llama_cpp":
        text_side = DemoTextBackend()
    else:
        text_side = LlamaCppTextBackend(
            base_url=config.llama_base_url,
            model=config.llama_model,
        )

    if config.image_backend != "flux":
        image_side = DemoImageBackend()
    else:
        image_side = FluxImageBackend(
            model_id=config.flux_model_id,
            lora_id=config.flux_lora_id,
            local_files_only=config.local_files_only,
        )

    recorder = None
    if config.trace_path:
        recorder = TraceRecorder(config.trace_path)

    return ForestOrchestrator(
        text_backend=text_side,
        image_backend=image_side,
        trace_recorder=recorder,
    )
|
|
|
|
def create_app(
    *,
    config: AppConfig | None = None,
    orchestrator: Any | None = None,
    frontend_dir: str | Path | None = None,
) -> gr.Server:
    """Build the server that hosts the static frontend and the forest API.

    Args:
        config: Runtime configuration. When ``None``, it is loaded with
            ``AppConfig.from_env()``.
        orchestrator: Object exposing ``generate(name, situation, seed)``
            that yields events with ``model_dump_json()``. When ``None``,
            one is built from the runtime config via ``build_orchestrator``.
        frontend_dir: Directory containing ``index.html``, ``styles.css``,
            ``app.js`` and an optional ``assets/`` subdirectory. Defaults to
            ``Path(__file__).resolve().parents[2] / "frontend"``.

    Returns:
        The configured ``gr.Server`` instance.
    """
    # Explicit ``is None`` checks: a caller-supplied object must never be
    # discarded just because it happens to be falsy.
    runtime = config if config is not None else AppConfig.from_env()
    forest = (
        orchestrator if orchestrator is not None else build_orchestrator(runtime)
    )
    frontend = (
        Path(frontend_dir)
        if frontend_dir is not None
        else Path(__file__).resolve().parents[2] / "frontend"
    )

    # NOTE(review): docs_url/redoc_url=None presumably disables the
    # auto-generated API docs pages -- confirm against gr.Server.
    app = gr.Server(
        title="The Compliment Forest",
        description="A progressive path of grounded encouragement.",
        docs_url=None,
        redoc_url=None,
    )

    @app.get("/")
    def index() -> FileResponse:
        return FileResponse(frontend / "index.html")

    @app.get("/styles.css")
    def styles() -> FileResponse:
        return FileResponse(frontend / "styles.css", media_type="text/css")

    @app.get("/app.js")
    def javascript() -> FileResponse:
        return FileResponse(frontend / "app.js", media_type="text/javascript")

    # StaticFiles requires a real directory and raises otherwise, so check
    # is_dir() rather than exists(): a stray *file* named "assets" must not
    # crash app construction.
    assets = frontend / "assets"
    if assets.is_dir():
        app.mount("/assets", StaticFiles(directory=assets), name="assets")

    @app.get("/health")
    def health() -> dict[str, object]:
        # Reports the configured backend names. off_grid and the parameter
        # budget are hard-coded constants, not derived from config.
        return {
            "status": "ok",
            "text_backend": runtime.text_backend,
            "image_backend": runtime.image_backend,
            "off_grid": True,
            "model_parameter_budget_billions": 18,
        }

    @app.post("/api/forest")
    def generate_forest(request: ForestRequest) -> StreamingResponse:
        # Fall back to the configured default when the client sent no seed.
        seed = request.seed if request.seed is not None else runtime.default_seed

        def stream():
            # One JSON document per line (NDJSON), emitted as events arrive.
            for event in forest.generate(request.name, request.situation, seed):
                yield event.model_dump_json() + "\n"

        return StreamingResponse(stream(), media_type="application/x-ndjson")

    return app
|
|