| import copy |
| import json |
| import os |
| from pathlib import Path |
| from typing import Any |
| from urllib.parse import quote |
|
|
| import httpx |
| from fastapi import FastAPI, HTTPException, Request |
| from fastapi.responses import JSONResponse, Response, StreamingResponse |
|
|
|
|
def _parse_prefixes(raw: str) -> tuple[str, ...]:
    """Split a comma-separated prefix list, trimming whitespace and slashes.

    Entries that are empty *after* trimming (e.g. ``""`` or ``"/"``) are
    dropped, so the result never contains an empty prefix.
    """
    return tuple(
        cleaned
        for item in raw.split(",")
        if (cleaned := item.strip().strip("/"))
    )


# Base URL of the Hugging Face origin (trailing slash removed).
HF_ENDPOINT = os.getenv("HF_ENDPOINT", "https://huggingface.co").rstrip("/")
# Bucket identifier inserted into origin URLs.
HF_BUCKET_ID = os.getenv("HF_BUCKET_ID", "acul3/acul")
# Bearer token for the origin. Stripped so a trailing newline in the secret
# cannot produce an invalid Authorization header; blank is treated as missing.
HF_TOKEN = (os.getenv("HF_TOKEN") or "").strip() or None
# max-age (seconds) advertised in Cache-Control on delivered files.
CACHE_SECONDS = int(os.getenv("CACHE_SECONDS", "3600"))
# Optional override for the externally visible base URL used in /manifest links.
PUBLIC_BASE_URL = os.getenv("PUBLIC_BASE_URL", "").rstrip("/")
# Path prefixes that may be served when the manifest allowlist is empty.
MODEL_ALLOWED_PREFIXES = _parse_prefixes(
    os.getenv("MODEL_ALLOWED_PREFIXES", "MusicModels,DreamLiteModels")
)
# First location searched for the manifest JSON.
MODEL_MANIFEST_PATH = Path(os.getenv("MODEL_MANIFEST_PATH", "/app/model_manifest.json"))
|
|
# ASGI application; the route handlers below register on it via decorators.
app = FastAPI(
    title="Pocket Studio Model Delivery",
)
|
|
|
|
def _load_manifest() -> dict[str, Any] | None:
    """Load the first manifest JSON file found among the known locations.

    Candidates, in priority order:
      1. ``MODEL_MANIFEST_PATH``;
      2. ``model_manifest.json`` next to this module;
      3. ``Config/suare-models.json`` in the directory two levels above this
         module's directory, when the path is deep enough for that to exist.

    Returns:
        The parsed manifest object, or None when no candidate file exists.

    Raises:
        ValueError: if the first file found is not valid JSON or is not a
            JSON object (``json.JSONDecodeError`` is a ``ValueError``).
        OSError: if the file exists but cannot be read.
    """
    candidates = [
        MODEL_MANIFEST_PATH,
        Path(__file__).with_name("model_manifest.json"),
    ]
    resolved = Path(__file__).resolve()
    if len(resolved.parents) > 2:
        candidates.append(resolved.parents[2] / "Config" / "suare-models.json")
    for candidate in candidates:
        # is_file() rather than exists(): a directory sitting at a candidate
        # path is skipped instead of crashing startup with IsADirectoryError.
        if not candidate.is_file():
            continue
        manifest = json.loads(candidate.read_text(encoding="utf-8"))
        if not isinstance(manifest, dict):
            # Fail loudly here; downstream code calls .get() on the manifest.
            raise ValueError(f"Model manifest {candidate} must be a JSON object.")
        return manifest
    return None


# Loaded once at import time; None means no manifest is deployed.
MANIFEST = _load_manifest()
|
|
|
|
def _manifest_paths() -> set[str]:
    """Return the ``store/relativePath`` keys listed in the manifest's files.

    Leading/trailing slashes are trimmed from both fields. Entries that are not
    JSON objects, or that lack either field, are skipped. No manifest, or a
    manifest whose ``files`` is missing or null, yields an empty set.
    """
    if not MANIFEST:
        return set()
    files = MANIFEST.get("files") or {}
    paths: set[str] = set()
    for entry in files.values():
        if not isinstance(entry, dict):
            continue
        # `or ""` so a JSON null becomes empty rather than the string "None".
        store = str(entry.get("store") or "").strip("/")
        relative_path = str(entry.get("relativePath") or "").strip("/")
        if store and relative_path:
            paths.add(f"{store}/{relative_path}")
    return paths


# Exact-match allowlist built once at import; _is_allowed falls back to the
# prefix list when this is empty.
ALLOWED_PATHS = _manifest_paths()
|
|
|
|
def _normalize_model_path(store: str, path: str) -> str:
    """Join *store* and *path* into a clean ``store/relative/path`` key.

    Leading and trailing slashes are trimmed from both pieces. Empty input, or
    any empty, ``.`` or ``..`` segment, is rejected with a 404, so
    traversal-style paths never reach the origin.
    """
    cleaned_store = store.strip("/")
    cleaned_path = path.strip("/")
    if cleaned_store and cleaned_path:
        segments = [*cleaned_store.split("/"), *cleaned_path.split("/")]
        if all(segment not in ("", ".", "..") for segment in segments):
            return "/".join(segments)
    raise HTTPException(status_code=404, detail="Model file not found.")
|
|
|
|
def _is_allowed(remote_path: str) -> bool:
    """Return True when *remote_path* may be proxied to the origin.

    A non-empty manifest allowlist is authoritative (exact match only).
    Otherwise the path must equal a configured prefix or live beneath one.
    """
    if not ALLOWED_PATHS:
        for prefix in MODEL_ALLOWED_PREFIXES:
            if remote_path == prefix or remote_path.startswith(f"{prefix}/"):
                return True
        return False
    return remote_path in ALLOWED_PATHS
|
|
|
|
def _bucket_url(remote_path: str) -> str:
    """Build the origin URL that resolves *remote_path* inside the HF bucket.

    The whole path is percent-encoded as a single segment (``safe=""`` also
    escapes ``/`` as ``%2F``); the bucket id is inserted verbatim.
    NOTE(review): confirm the origin accepts %2F-escaped separators.
    """
    resolve_base = f"{HF_ENDPOINT}/buckets/{HF_BUCKET_ID}/resolve"
    return resolve_base + "/" + quote(remote_path, safe="")
|
|
|
|
def _request_headers(request: Request) -> dict[str, str]:
    """Build the headers for the authenticated request to the model origin.

    The client's ``Range`` header, if any, is forwarded unchanged.

    Raises:
        HTTPException: 503 when the HF_TOKEN secret is not configured.
    """
    if not HF_TOKEN:
        raise HTTPException(status_code=503, detail="Model delivery is missing its HF_TOKEN secret.")
    headers = {
        "Authorization": f"Bearer {HF_TOKEN}",
        # httpx advertises gzip/deflate by default and transparently decodes
        # the streamed body, which would leave the forwarded Content-Length /
        # Content-Range describing compressed bytes. Ask for the raw bytes.
        "Accept-Encoding": "identity",
    }
    if range_header := request.headers.get("range"):
        # Pass byte-range requests through so partial downloads work.
        headers["Range"] = range_header
    return headers
|
|
|
|
def _content_disposition(filename: str) -> str:
    """Return an ``attachment`` Content-Disposition value safe for any filename.

    Starlette encodes header values as latin-1, so names outside that range
    cannot be sent as-is, and a raw ``"`` or ``\\`` would break the quoted
    string. Plain printable-ASCII names produce exactly
    ``attachment; filename="<name>"``. Any other name gets an ASCII fallback in
    ``filename`` plus the exact name in the RFC 5987 ``filename*`` parameter.
    """
    fallback = "".join(
        char if " " <= char <= "~" and char not in '"\\' else "_"
        for char in filename
    )
    value = f'attachment; filename="{fallback}"'
    if fallback != filename:
        value += f"; filename*=UTF-8''{quote(filename, safe='')}"
    return value


def _response_headers(upstream: httpx.Response, remote_path: str) -> dict[str, str]:
    """Build the client-facing headers for a proxied model response.

    Copies a fixed set of entity headers (non-empty values only) from the
    origin response and defaults Content-Type to ``application/octet-stream``.
    Adds Cache-Control with ``max-age=CACHE_SECONDS`` and an ``attachment``
    Content-Disposition named after the last path segment.
    """
    forwarded: dict[str, str] = {}
    for key in (
        "accept-ranges",
        "content-length",
        "content-range",
        "content-type",
        "etag",
        "last-modified",
    ):
        value = upstream.headers.get(key)
        if value:
            forwarded[key] = value
    encoding = upstream.headers.get("content-encoding", "").strip().lower()
    if encoding and encoding != "identity":
        # The GET body is streamed through httpx's decoding iterator, so the
        # origin's (compressed) length would not match the bytes actually sent.
        forwarded.pop("content-length", None)
    forwarded.setdefault("content-type", "application/octet-stream")
    forwarded["cache-control"] = f"public, max-age={CACHE_SECONDS}"
    forwarded["content-disposition"] = _content_disposition(Path(remote_path).name)
    return forwarded
|
|
|
|
def _first_forwarded_value(value: str | None) -> str:
    """Return the first comma-separated element of a proxy header, stripped."""
    return (value or "").split(",")[0].strip()


def _public_base_url(request: Request) -> str:
    """Return the externally visible base URL of this service (no trailing slash).

    ``PUBLIC_BASE_URL`` wins when configured. Otherwise the URL is built from
    the forwarded host/proto headers (falling back to ``Host`` and ``https``).
    As a last resort it uses the request's base URL upgraded to https.
    Proxy chains may append to X-Forwarded-* as comma-separated lists; only
    the first (client-facing) entry is used.
    """
    if PUBLIC_BASE_URL:
        return PUBLIC_BASE_URL
    host = _first_forwarded_value(request.headers.get("x-forwarded-host")) or request.headers.get("host")
    proto = _first_forwarded_value(request.headers.get("x-forwarded-proto")) or "https"
    if host:
        return f"{proto}://{host}".rstrip("/")
    return str(request.base_url).rstrip("/").replace("http://", "https://", 1)
|
|
|
|
async def _open_upstream(request: Request, remote_path: str, method: str) -> httpx.Response:
    """Open a streaming *method* request for *remote_path* against the origin.

    A dedicated AsyncClient is created per call and stored on the response
    under ``extensions["delivery_client"]``, so ``_close_upstream`` can close
    both together. The caller owns the returned (unread) response.

    Raises:
        HTTPException: 404 if the path is not allowlisted, 503 if HF_TOKEN is
            missing, 502 if the origin cannot be reached.
    """
    if not _is_allowed(remote_path):
        raise HTTPException(status_code=404, detail="Model file not found.")
    # Compute everything that can raise *before* creating the client, so a
    # failure here (e.g. the 503 for a missing token) cannot leak an open client.
    url = _bucket_url(remote_path)
    headers = _request_headers(request)
    client = httpx.AsyncClient(timeout=httpx.Timeout(30.0, read=None), follow_redirects=True)
    try:
        response = await client.send(client.build_request(method, url, headers=headers), stream=True)
    except httpx.HTTPError as exc:
        await client.aclose()
        raise HTTPException(status_code=502, detail="Model origin is unreachable.") from exc
    except BaseException:
        # Also covers task cancellation (CancelledError is a BaseException).
        await client.aclose()
        raise
    response.extensions["delivery_client"] = client
    return response
|
|
|
|
async def _close_upstream(response: httpx.Response) -> None:
    """Close *response* and any AsyncClient stored in its ``delivery_client`` extension."""
    try:
        await response.aclose()
    finally:
        # Close the client even if closing the response failed, so its
        # connections are not leaked.
        client = response.extensions.get("delivery_client")
        if isinstance(client, httpx.AsyncClient):
            await client.aclose()
|
|
|
|
def _raise_for_upstream(response: httpx.Response, remote_path: str) -> None:
    """Translate a non-success origin status into a client-facing HTTPException.

    200 and 206 pass silently. Auth failures (401/403) become 503 because they
    are a deployment problem, not the client's. 404 and 416 are passed through.
    Any other status becomes 502.
    """
    status = response.status_code
    if status in (200, 206):
        return
    if status in (401, 403):
        raise HTTPException(status_code=503, detail="Private model origin rejected the delivery token.")
    if status == 404:
        raise HTTPException(status_code=404, detail="Model file not found.")
    if status == 416:
        # The client's forwarded Range header is unsatisfiable; report that
        # instead of masking it as an origin failure. Content-Range (if sent)
        # tells the client the real size.
        content_range = response.headers.get("content-range")
        raise HTTPException(
            status_code=416,
            detail="Requested range not satisfiable.",
            headers={"content-range": content_range} if content_range else None,
        )
    raise HTTPException(
        status_code=502,
        detail=f"Model origin failed for {remote_path} (HTTP {status}).",
    )
|
|
|
|
@app.get("/health")
async def health() -> dict[str, Any]:
    """Report liveness plus the delivery configuration currently in effect."""
    status: dict[str, Any] = {"ok": True}
    status["bucket"] = HF_BUCKET_ID
    status["manifestLoaded"] = MANIFEST is not None
    status["allowlistSize"] = len(ALLOWED_PATHS)
    status["allowedPrefixes"] = list(MODEL_ALLOWED_PREFIXES)
    return status
|
|
|
|
@app.get("/manifest")
async def manifest(request: Request) -> JSONResponse:
    """Return the deployed manifest with a delivery ``url`` added to each file.

    The stored manifest is deep-copied, so the module-level MANIFEST is never
    mutated. Each URL targets this service's ``/models/{store}/{path}`` route,
    with both components percent-encoded (``/`` kept). Entries that are not
    objects, or that lack ``store``/``relativePath``, get no URL.

    Raises:
        HTTPException: 404 when no manifest is deployed.
    """
    if not MANIFEST:
        raise HTTPException(status_code=404, detail="No manifest is deployed with this Space.")
    base_url = _public_base_url(request)
    payload = copy.deepcopy(MANIFEST)
    for entry in (payload.get("files") or {}).values():
        if not isinstance(entry, dict):
            continue
        # `or ""` so a JSON null becomes empty rather than the string "None".
        store = str(entry.get("store") or "").strip("/")
        relative_path = str(entry.get("relativePath") or "").strip("/")
        if store and relative_path:
            entry["url"] = f"{base_url}/models/{quote(store)}/{quote(relative_path)}"
    return JSONResponse(payload)
|
|
|
|
@app.head("/models/{store}/{path:path}")
async def head_model(request: Request, store: str, path: str) -> Response:
    """Answer HEAD for a model file by mirroring the origin's metadata headers.

    The upstream response is always closed, whether the status check passes
    or raises.
    """
    remote_path = _normalize_model_path(store, path)
    upstream = await _open_upstream(request, remote_path, "HEAD")
    try:
        _raise_for_upstream(upstream, remote_path)
        status = upstream.status_code
        headers = _response_headers(upstream, remote_path)
    finally:
        await _close_upstream(upstream)
    return Response(status_code=status, headers=headers)
|
|
|
|
@app.get("/models/{store}/{path:path}")
async def get_model(request: Request, store: str, path: str) -> StreamingResponse:
    """Stream a model file (or the requested byte range) from the private bucket.

    If the origin status check fails, the upstream response is closed here and
    the error propagates. Otherwise the streaming generator takes over the
    response and closes it in its ``finally`` once iteration ends.
    """
    remote_path = _normalize_model_path(store, path)
    upstream = await _open_upstream(request, remote_path, "GET")
    try:
        _raise_for_upstream(upstream, remote_path)
    except Exception:
        # Nothing will be streamed, so release the upstream right away.
        await _close_upstream(upstream)
        raise

    chunk_size = 1024 * 1024  # bytes requested per yielded chunk

    async def stream_chunks():
        try:
            async for piece in upstream.aiter_bytes(chunk_size=chunk_size):
                yield piece
        finally:
            await _close_upstream(upstream)

    default_type = "application/octet-stream"
    return StreamingResponse(
        stream_chunks(),
        status_code=upstream.status_code,
        headers=_response_headers(upstream, remote_path),
        media_type=upstream.headers.get("content-type", default_type),
    )
|
|