# acul / app.py (author: acul3)
# Commit 42841d5 (verified): "Deploy Pocket Studio model delivery proxy"
# File size at capture: 7.59 kB
import copy
import json
import os
from pathlib import Path
from typing import Any
from urllib.parse import quote
import httpx
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse
# Origin and bucket that hold the private model files.
HF_ENDPOINT = os.getenv("HF_ENDPOINT", "https://huggingface.co").rstrip("/")
HF_BUCKET_ID = os.getenv("HF_BUCKET_ID", "acul3/acul")
# Bearer token sent to the origin; None when the secret is not configured.
HF_TOKEN = os.getenv("HF_TOKEN")
# Value used for "max-age" in the Cache-Control header of relayed files.
CACHE_SECONDS = int(os.getenv("CACHE_SECONDS", "3600"))
# Optional fixed origin for manifest download URLs; empty means "derive from the request".
PUBLIC_BASE_URL = os.getenv("PUBLIC_BASE_URL", "").rstrip("/")
# Comma-separated fallback allowlist of path prefixes: each item has whitespace
# and surrounding slashes removed, and blank items are dropped. It is only
# consulted when the manifest yields no explicit allowlist.
_RAW_ALLOWED_PREFIXES = os.getenv(
    "MODEL_ALLOWED_PREFIXES", "MusicModels,DreamLiteModels"
).split(",")
MODEL_ALLOWED_PREFIXES = tuple(
    raw.strip().strip("/") for raw in _RAW_ALLOWED_PREFIXES if raw.strip()
)
# Primary manifest location; _load_manifest also tries fallbacks.
MODEL_MANIFEST_PATH = Path(os.getenv("MODEL_MANIFEST_PATH", "/app/model_manifest.json"))

app = FastAPI(title="Pocket Studio Model Delivery")
def _load_manifest() -> dict[str, Any] | None:
    """Return the first model manifest found on disk, or None if there is none.

    Search order:
      1. MODEL_MANIFEST_PATH;
      2. ``model_manifest.json`` next to this file;
      3. ``Config/suare-models.json`` two directories above this file's
         resolved location, only when that ancestor exists in the path.

    The first existing candidate is parsed as UTF-8 JSON; parse errors
    propagate to the caller.
    """
    this_file = Path(__file__)
    search_order = [MODEL_MANIFEST_PATH, this_file.with_name("model_manifest.json")]
    ancestors = this_file.resolve().parents
    if len(ancestors) > 2:
        search_order.append(ancestors[2] / "Config" / "suare-models.json")
    found = next((candidate for candidate in search_order if candidate.exists()), None)
    if found is None:
        return None
    return json.loads(found.read_text(encoding="utf-8"))


MANIFEST = _load_manifest()
def _manifest_paths() -> set[str]:
    """Collect the ``"store/relativePath"`` keys declared in MANIFEST["files"].

    Leading and trailing slashes are trimmed from both parts. Entries missing
    either part are skipped. A missing or empty manifest yields an empty set,
    which makes _is_allowed fall back to prefix matching.
    """
    if not MANIFEST:
        return set()
    pairs = (
        (
            str(item.get("store", "")).strip("/"),
            str(item.get("relativePath", "")).strip("/"),
        )
        for item in MANIFEST.get("files", {}).values()
    )
    return {f"{store_name}/{rel}" for store_name, rel in pairs if store_name and rel}


ALLOWED_PATHS = _manifest_paths()
def _normalize_model_path(store: str, path: str) -> str:
store = store.strip("/")
path = path.strip("/")
if not store or not path:
raise HTTPException(status_code=404, detail="Model file not found.")
parts = f"{store}/{path}".split("/")
if any(part in ("", ".", "..") for part in parts):
raise HTTPException(status_code=404, detail="Model file not found.")
return "/".join(parts)
def _is_allowed(remote_path: str) -> bool:
    """Return whether *remote_path* may be proxied.

    When the manifest produced an allowlist, only exact entries are allowed.
    Otherwise the path must equal one of MODEL_ALLOWED_PREFIXES or lie below
    it (``prefix + "/"``), so ``"MusicModelsX"`` does not match ``"MusicModels"``.
    """
    if not ALLOWED_PATHS:
        return any(
            remote_path == allowed or remote_path.startswith(f"{allowed}/")
            for allowed in MODEL_ALLOWED_PREFIXES
        )
    return remote_path in ALLOWED_PATHS
def _bucket_url(remote_path: str) -> str:
    """Build the bucket "resolve" URL used to download *remote_path*."""
    # safe="" percent-encodes every reserved character, including "/" (sent as %2F).
    # NOTE(review): the /manifest route keeps slashes literal; confirm the bucket
    # resolve endpoint intentionally expects the fully encoded form.
    escaped = quote(remote_path, safe="")
    return "{}/buckets/{}/resolve/{}".format(HF_ENDPOINT, HF_BUCKET_ID, escaped)
def _request_headers(request: Request) -> dict[str, str]:
    """Build the headers for the upstream bucket request.

    Always authenticates with HF_TOKEN. The client's Range header, if any, is
    forwarded so partial and resumed downloads work.

    The origin is asked for the identity encoding. httpx advertises gzip/deflate
    by default and transparently decodes the body it yields, but the proxy
    relays the origin's Content-Length/Content-Range verbatim. An encoded
    upstream body would therefore make those headers disagree with the bytes
    actually sent, and byte ranges would refer to the compressed representation.

    Raises:
        HTTPException: 503 when the HF_TOKEN secret is not configured.
    """
    if not HF_TOKEN:
        raise HTTPException(status_code=503, detail="Model delivery is missing its HF_TOKEN secret.")
    headers = {
        "Authorization": f"Bearer {HF_TOKEN}",
        "Accept-Encoding": "identity",
    }
    range_header = request.headers.get("range")
    if range_header:
        headers["Range"] = range_header
    return headers
def _attachment_disposition(filename: str) -> str:
    """Return a Content-Disposition "attachment" value that is safe for any filename.

    Header values must be latin-1 encodable, and a raw '"' or '\\' would break
    the quoted-string. The plain ``filename`` parameter therefore carries an
    ASCII-printable fallback, with every other character replaced by "_".
    ``filename*`` carries the exact name, UTF-8 percent-encoded (RFC 5987 / RFC 6266).
    """
    fallback = "".join(
        ch if " " <= ch <= "~" and ch not in '"\\' else "_" for ch in filename
    )
    encoded = quote(filename, safe="")
    return f"attachment; filename=\"{fallback}\"; filename*=UTF-8''{encoded}"


def _response_headers(upstream: httpx.Response, remote_path: str) -> dict[str, str]:
    """Select the upstream headers to relay to the client and add delivery headers.

    Copies the range, length, type and validator headers the origin sent
    (skipping empty ones). Content-Type defaults to application/octet-stream,
    Cache-Control is public with max-age=CACHE_SECONDS, and the body is marked
    as an attachment named after the last segment of *remote_path*.
    """
    forwarded: dict[str, str] = {}
    for key in (
        "accept-ranges",
        "content-length",
        "content-range",
        "content-type",
        "etag",
        "last-modified",
    ):
        value = upstream.headers.get(key)
        if value:
            forwarded[key] = value
    forwarded.setdefault("content-type", "application/octet-stream")
    forwarded["cache-control"] = f"public, max-age={CACHE_SECONDS}"
    forwarded["content-disposition"] = _attachment_disposition(Path(remote_path).name)
    return forwarded
def _first_header_value(value: str | None) -> str:
    """Return the leftmost comma-separated item of a header value, stripped ("" if absent)."""
    return (value or "").split(",", 1)[0].strip()


def _public_base_url(request: Request) -> str:
    """Return the externally visible origin used to build absolute model URLs.

    PUBLIC_BASE_URL wins when configured. Otherwise the origin is rebuilt from
    X-Forwarded-Host (or Host) and X-Forwarded-Proto, with the scheme
    defaulting to https. When a chain of proxies appended several values
    (``"a, b"``), only the leftmost one is used; by the usual proxy convention
    that is the value the client originally requested. If no host header is
    present, the ASGI base URL is used with http upgraded to https.
    """
    if PUBLIC_BASE_URL:
        return PUBLIC_BASE_URL
    host = _first_header_value(request.headers.get("x-forwarded-host")) or _first_header_value(
        request.headers.get("host")
    )
    proto = _first_header_value(request.headers.get("x-forwarded-proto")) or "https"
    if host:
        return f"{proto}://{host}".rstrip("/")
    return str(request.base_url).rstrip("/").replace("http://", "https://", 1)
async def _open_upstream(request: Request, remote_path: str, method: str) -> httpx.Response:
    """Open a streaming *method* request to the bucket for *remote_path*.

    The returned response owns a dedicated AsyncClient, stored in
    ``response.extensions["delivery_client"]``. Callers must release both
    with ``_close_upstream``.

    Raises:
        HTTPException: 404 when the path is not allowlisted, 503 when the
            HF_TOKEN secret is missing, 502 when the request to the origin
            fails at the transport level (connect/timeout/protocol/redirects).
    """
    if not _is_allowed(remote_path):
        raise HTTPException(status_code=404, detail="Model file not found.")
    # Build headers before creating the client: _request_headers can raise,
    # and raising after AsyncClient() would leave the client unclosed.
    headers = _request_headers(request)
    client = httpx.AsyncClient(timeout=httpx.Timeout(30.0, read=None), follow_redirects=True)
    try:
        upstream_request = client.build_request(
            method,
            _bucket_url(remote_path),
            headers=headers,
        )
        response = await client.send(upstream_request, stream=True)
    except httpx.RequestError as exc:
        await client.aclose()
        # An unreachable origin is a gateway failure, not an internal server error.
        raise HTTPException(status_code=502, detail="Could not reach the model origin.") from exc
    except Exception:
        await client.aclose()
        raise
    response.extensions["delivery_client"] = client
    return response
async def _close_upstream(response: httpx.Response) -> None:
    """Close an upstream response and the client _open_upstream created for it.

    The client is closed even if closing the response raises, so a failed
    stream shutdown cannot leak the client's connection pool.
    """
    try:
        await response.aclose()
    finally:
        client = response.extensions.get("delivery_client")
        if isinstance(client, httpx.AsyncClient):
            await client.aclose()
def _raise_for_upstream(response: httpx.Response, remote_path: str) -> None:
    """Translate a non-success upstream status into a client-facing HTTPException.

    200 and 206 pass through silently. Other statuses are mapped as follows:
    - 401/403 become 503: the origin rejected the proxy's own delivery token,
      which is a server misconfiguration, not a client permission problem.
    - 404 stays 404.
    - 416 (Range Not Satisfiable) is relayed, with the origin's Content-Range
      if present, so range-requesting clients (e.g. resuming a finished
      download) get a meaningful answer instead of a generic 502.
    - Everything else becomes 502.
    """
    status = response.status_code
    if status in (200, 206):
        return
    if status in (401, 403):
        raise HTTPException(status_code=503, detail="Private model origin rejected the delivery token.")
    if status == 404:
        raise HTTPException(status_code=404, detail="Model file not found.")
    if status == 416:
        content_range = response.headers.get("content-range")
        raise HTTPException(
            status_code=416,
            detail="Requested range not satisfiable.",
            headers={"content-range": content_range} if content_range else None,
        )
    raise HTTPException(
        status_code=502,
        detail=f"Model origin failed for {remote_path} (HTTP {status}).",
    )
@app.get("/health")
async def health() -> dict[str, Any]:
    """Liveness probe that also reports how access control is configured."""
    report: dict[str, Any] = {"ok": True, "bucket": HF_BUCKET_ID}
    report["manifestLoaded"] = MANIFEST is not None
    report["allowlistSize"] = len(ALLOWED_PATHS)
    report["allowedPrefixes"] = [*MODEL_ALLOWED_PREFIXES]
    return report
@app.get("/manifest")
async def manifest(request: Request) -> JSONResponse:
    """Serve the deployed manifest with absolute download URLs filled in.

    A deep copy is edited so the module-level MANIFEST is never mutated. Each
    file entry that has both "store" and "relativePath" gains a "url" pointing
    at this service's /models route; the relative path is URL-quoted with
    slashes kept.

    Raises:
        HTTPException: 404 when no manifest (or an empty one) was loaded.
    """
    if not MANIFEST:
        raise HTTPException(status_code=404, detail="No manifest is deployed with this Space.")
    origin = _public_base_url(request)
    document = copy.deepcopy(MANIFEST)
    for item in document.get("files", {}).values():
        store_name = str(item.get("store", "")).strip("/")
        rel = str(item.get("relativePath", "")).strip("/")
        if not store_name or not rel:
            continue
        item["url"] = "/".join((origin, "models", store_name, quote(rel)))
    return JSONResponse(document)
@app.head("/models/{store}/{path:path}")
async def head_model(request: Request, store: str, path: str) -> Response:
    """Answer HEAD for a model file by relaying the origin's metadata headers.

    The upstream response and its client are always closed before this
    returns, whether or not the origin reported success.
    """
    remote_path = _normalize_model_path(store, path)
    upstream = await _open_upstream(request, remote_path, "HEAD")
    try:
        _raise_for_upstream(upstream, remote_path)
        relayed = _response_headers(upstream, remote_path)
        reply = Response(status_code=upstream.status_code, headers=relayed)
    finally:
        await _close_upstream(upstream)
    return reply
@app.get("/models/{store}/{path:path}")
async def get_model(request: Request, store: str, path: str) -> StreamingResponse:
    """Stream a model file from the private bucket to the client.

    The upstream status (200 or 206) and the selected headers are relayed.
    The body is re-streamed in 1 MiB chunks, and the upstream response and
    client are closed once streaming ends or fails.

    If anything fails before the StreamingResponse is handed back, the
    upstream is closed immediately instead of being leaked. That covers a
    non-success upstream status and any error raised while building the
    response, such as header encoding. The body generator only starts when
    the server begins sending, so it cannot clean up in those cases.
    """
    remote_path = _normalize_model_path(store, path)
    response = await _open_upstream(request, remote_path, "GET")

    async def body():
        # Owns cleanup once iteration has started.
        try:
            async for chunk in response.aiter_bytes(chunk_size=1024 * 1024):
                yield chunk
        finally:
            await _close_upstream(response)

    try:
        _raise_for_upstream(response, remote_path)
        return StreamingResponse(
            body(),
            status_code=response.status_code,
            headers=_response_headers(response, remote_path),
            media_type=response.headers.get("content-type", "application/octet-stream"),
        )
    except Exception:
        await _close_upstream(response)
        raise