# Author: thangvip
# Commit 9dad6a7 (verified): feat: deploy complete Compliment Forest app
from __future__ import annotations
from pathlib import Path
from typing import Any
import gradio as gr
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel, ConfigDict, Field
from starlette.responses import StreamingResponse
from .backends.image import DemoImageBackend, FluxImageBackend
from .backends.text import DemoTextBackend, LlamaCppTextBackend
from .config import AppConfig
from .orchestrator import ForestOrchestrator
from .trace import TraceRecorder
class ForestRequest(BaseModel):
    # Body of POST /api/forest.
    # Unknown keys are rejected, and leading/trailing whitespace is stripped
    # from string fields.
    model_config = ConfigDict(str_strip_whitespace=True, extra="forbid")

    # Name of the person the forest is generated for (1-80 characters).
    name: str = Field(..., min_length=1, max_length=80)
    # Free-text description of their situation (1-1200 characters).
    situation: str = Field(..., min_length=1, max_length=1200)
    # Optional non-negative seed, capped at the signed 32-bit maximum.
    # When omitted, the /api/forest handler substitutes runtime.default_seed.
    seed: int | None = Field(None, ge=0, le=2**31 - 1)
def build_orchestrator(config: AppConfig) -> ForestOrchestrator:
    """Assemble a ForestOrchestrator from the backends selected in *config*.

    ``config.text_backend == "llama_cpp"`` selects ``LlamaCppTextBackend``;
    any other value selects ``DemoTextBackend``. Likewise
    ``config.image_backend == "flux"`` selects ``FluxImageBackend`` and any
    other value selects ``DemoImageBackend``. A ``TraceRecorder`` is attached
    only when ``config.trace_path`` is truthy.

    NOTE(review): an unrecognized backend name (e.g. a typo) silently falls
    back to the demo backend, while /health still reports the configured
    string. Consider validating these values in AppConfig.
    """
    # Construction order is kept as text -> image -> trace, because backend
    # constructors may do heavy work (e.g. loading model weights).
    if config.text_backend != "llama_cpp":
        text = DemoTextBackend()
    else:
        text = LlamaCppTextBackend(
            base_url=config.llama_base_url,
            model=config.llama_model,
        )

    if config.image_backend != "flux":
        image = DemoImageBackend()
    else:
        image = FluxImageBackend(
            model_id=config.flux_model_id,
            lora_id=config.flux_lora_id,
            local_files_only=config.local_files_only,
        )

    recorder = None
    if config.trace_path:
        recorder = TraceRecorder(config.trace_path)

    return ForestOrchestrator(
        text_backend=text,
        image_backend=image,
        trace_recorder=recorder,
    )
def create_app(
    *,
    config: AppConfig | None = None,
    orchestrator: Any | None = None,
    frontend_dir: str | Path | None = None,
) -> gr.Server:
    """Build the Compliment Forest web server.

    Args:
        config: Runtime configuration. Defaults to ``AppConfig.from_env()``.
        orchestrator: Object exposing ``generate(name, situation, seed)``,
            which must yield events that have a ``model_dump_json()`` method.
            Defaults to ``build_orchestrator(config)``.
        frontend_dir: Directory containing ``index.html``, ``styles.css``,
            ``app.js`` and an optional ``assets/`` subdirectory. Defaults to
            ``frontend`` inside the grandparent of this module's directory.

    Returns:
        A ``gr.Server`` that serves the static frontend, ``GET /health`` and
        the NDJSON-streaming ``POST /api/forest`` endpoint. The interactive
        docs pages (``docs_url``/``redoc_url``) are disabled.
    """
    # Compare against None explicitly. A caller-supplied object that happens
    # to be falsy (e.g. a test double defining __len__ or __bool__) must still
    # be used, rather than being silently replaced by a freshly built one.
    runtime = config if config is not None else AppConfig.from_env()
    forest = (
        orchestrator if orchestrator is not None else build_orchestrator(runtime)
    )
    frontend = (
        Path(frontend_dir)
        if frontend_dir is not None
        else Path(__file__).resolve().parents[2] / "frontend"
    )

    app = gr.Server(
        title="The Compliment Forest",
        description="A progressive path of grounded encouragement.",
        docs_url=None,
        redoc_url=None,
    )

    @app.get("/")
    def index() -> FileResponse:
        return FileResponse(frontend / "index.html")

    @app.get("/styles.css")
    def styles() -> FileResponse:
        return FileResponse(frontend / "styles.css", media_type="text/css")

    @app.get("/app.js")
    def javascript() -> FileResponse:
        return FileResponse(frontend / "app.js", media_type="text/javascript")

    # Mount assets only if they exist as a *directory*. StaticFiles checks its
    # directory on construction, so a stray regular file named "assets" would
    # otherwise make create_app raise.
    assets = frontend / "assets"
    if assets.is_dir():
        app.mount("/assets", StaticFiles(directory=assets), name="assets")

    @app.get("/health")
    def health() -> dict[str, object]:
        # Backend names come from the configuration. If a custom orchestrator
        # was injected, they may not describe the backends actually in use.
        return {
            "status": "ok",
            "text_backend": runtime.text_backend,
            "image_backend": runtime.image_backend,
            "off_grid": True,
            "model_parameter_budget_billions": 18,
        }

    @app.post("/api/forest")
    def generate_forest(request: ForestRequest) -> StreamingResponse:
        # Resolve the seed before streaming starts, so a configuration problem
        # surfaces as a normal error response rather than a truncated stream.
        seed = request.seed if request.seed is not None else runtime.default_seed

        def stream():
            # One JSON document per line (NDJSON), emitted as events arrive.
            for event in forest.generate(request.name, request.situation, seed):
                yield event.model_dump_json() + "\n"

        return StreamingResponse(stream(), media_type="application/x-ndjson")

    return app