"""Pocket Studio model delivery service.

A small FastAPI app that proxies model files out of a private Hugging Face
bucket. Requests for ``/models/{store}/{path}`` are checked against an
allowlist (exact manifest paths when a manifest is deployed, otherwise path
prefixes) and then streamed from
``{HF_ENDPOINT}/buckets/{HF_BUCKET_ID}/resolve/...`` with the HF token
attached server-side; only a fixed set of upstream headers is forwarded back.
"""

import copy
import json
import os
from pathlib import Path
from typing import Any
from urllib.parse import quote

import httpx
from fastapi import BackgroundTasks, FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse

# Origin configuration, read once at import time.
HF_ENDPOINT = os.getenv("HF_ENDPOINT", "https://huggingface.co").rstrip("/")
HF_BUCKET_ID = os.getenv("HF_BUCKET_ID", "acul3/acul")
HF_TOKEN = os.getenv("HF_TOKEN")
# max-age advertised on every successfully proxied file.
CACHE_SECONDS = int(os.getenv("CACHE_SECONDS", "3600"))
# When set, used verbatim as the origin of URLs emitted by /manifest instead of
# deriving it from request headers.
PUBLIC_BASE_URL = os.getenv("PUBLIC_BASE_URL", "").rstrip("/")
# Fallback allowlist, consulted only when no manifest paths are loaded.
MODEL_ALLOWED_PREFIXES = tuple(
    prefix.strip().strip("/")
    for prefix in os.getenv(
        "MODEL_ALLOWED_PREFIXES", "MusicModels,DreamLiteModels"
    ).split(",")
    if prefix.strip()
)
MODEL_MANIFEST_PATH = Path(os.getenv("MODEL_MANIFEST_PATH", "/app/model_manifest.json"))

app = FastAPI(title="Pocket Studio Model Delivery")


def _load_manifest() -> dict[str, Any] | None:
    """Return the first manifest JSON found among the candidate locations.

    Candidates, in order: ``MODEL_MANIFEST_PATH``, ``model_manifest.json``
    next to this file, and ``Config/suare-models.json`` under
    ``Path(__file__).resolve().parents[2]`` (only when that ancestor exists).
    Returns ``None`` when no candidate exists. Unreadable or malformed files
    raise, which fails the import rather than silently falling back to the
    broader prefix allowlist.
    """
    candidates = [
        MODEL_MANIFEST_PATH,
        Path(__file__).with_name("model_manifest.json"),
    ]
    resolved = Path(__file__).resolve()
    if len(resolved.parents) > 2:
        candidates.append(resolved.parents[2] / "Config" / "suare-models.json")
    for candidate in candidates:
        if candidate.exists():
            with candidate.open("r", encoding="utf-8") as handle:
                return json.load(handle)
    return None


MANIFEST = _load_manifest()


def _entry_location(entry: dict[str, Any]) -> tuple[str, str]:
    """Return ``(store, relative_path)`` of a manifest file entry, outer slashes trimmed.

    Missing keys come back as empty strings; callers skip such entries.
    """
    store = str(entry.get("store", "")).strip("/")
    relative_path = str(entry.get("relativePath", "")).strip("/")
    return store, relative_path


def _manifest_paths() -> set[str]:
    """Return every ``store/relativePath`` listed under ``MANIFEST["files"]``.

    Entries lacking either part are ignored. Returns an empty set when no
    manifest is loaded, which switches ``_is_allowed`` to prefix mode.
    """
    if not MANIFEST:
        return set()
    return {
        f"{store}/{relative_path}"
        for store, relative_path in map(_entry_location, MANIFEST.get("files", {}).values())
        if store and relative_path
    }


ALLOWED_PATHS = _manifest_paths()


def _normalize_model_path(store: str, path: str) -> str:
    """Join ``store`` and ``path`` into a bucket-relative path, rejecting unsafe input.

    Raises:
        HTTPException(404): either part is empty, or any segment is empty,
            ``.`` or ``..`` (blocks traversal out of the allowed prefixes).
    """
    store = store.strip("/")
    path = path.strip("/")
    if not store or not path:
        raise HTTPException(status_code=404, detail="Model file not found.")
    parts = f"{store}/{path}".split("/")
    if any(part in ("", ".", "..") for part in parts):
        raise HTTPException(status_code=404, detail="Model file not found.")
    return "/".join(parts)


def _is_allowed(remote_path: str) -> bool:
    """Return whether ``remote_path`` may be served.

    With a manifest loaded, only exact manifest paths are allowed; otherwise
    any path equal to or nested under one of ``MODEL_ALLOWED_PREFIXES``.
    """
    if ALLOWED_PATHS:
        return remote_path in ALLOWED_PATHS
    return any(
        remote_path == prefix or remote_path.startswith(prefix + "/")
        for prefix in MODEL_ALLOWED_PREFIXES
    )


def _bucket_url(remote_path: str) -> str:
    """Build the HF bucket resolve URL; the whole path is percent-encoded, including ``/``."""
    encoded_path = quote(remote_path, safe="")
    return f"{HF_ENDPOINT}/buckets/{HF_BUCKET_ID}/resolve/{encoded_path}"


def _request_headers(request: Request) -> dict[str, str]:
    """Build upstream request headers: the bearer token plus the client's ``Range``, if any.

    Raises:
        HTTPException(503): ``HF_TOKEN`` is not configured.
    """
    if not HF_TOKEN:
        raise HTTPException(status_code=503, detail="Model delivery is missing its HF_TOKEN secret.")
    headers = {
        "Authorization": f"Bearer {HF_TOKEN}",
        # httpx advertises gzip/deflate by default and aiter_bytes() decodes the
        # body, which would make the forwarded content-length/content-range
        # describe encoded bytes rather than the bytes we stream. Ask for the
        # raw representation so those headers (and Range offsets) stay accurate.
        "Accept-Encoding": "identity",
    }
    if range_header := request.headers.get("range"):
        headers["Range"] = range_header
    return headers


def _content_disposition(filename: str) -> str:
    """Return an ``attachment`` Content-Disposition value for ``filename``.

    Plain printable-ASCII names keep the simple quoted form. Anything else
    (non-ASCII, quotes, backslashes, control characters) uses the RFC 5987
    ``filename*`` form, so the header stays well-formed and latin-1 encodable.
    """
    if (
        filename.isascii()
        and filename.isprintable()
        and '"' not in filename
        and "\\" not in filename
    ):
        return f"attachment; filename=\"{filename}\""
    return f"attachment; filename*=UTF-8''{quote(filename, safe='')}"


def _response_headers(upstream: httpx.Response, remote_path: str) -> dict[str, str]:
    """Select the client-facing headers for a proxied file.

    Copies a fixed allowlist of upstream headers (nothing else, so no auth or
    origin-specific headers leak), defaults the content type, and adds caching
    and download headers.
    """
    forwarded = {}
    for key in (
        "accept-ranges",
        "content-length",
        "content-range",
        "content-type",
        "etag",
        "last-modified",
    ):
        value = upstream.headers.get(key)
        if value:
            forwarded[key] = value
    forwarded.setdefault("content-type", "application/octet-stream")
    forwarded["cache-control"] = f"public, max-age={CACHE_SECONDS}"
    forwarded["content-disposition"] = _content_disposition(Path(remote_path).name)
    return forwarded


def _first_header_value(request: Request, name: str) -> str | None:
    """Return the first comma-separated value of header ``name``, or ``None`` if absent/blank.

    Proxy chains may append to X-Forwarded-* headers (``"a.example, b.internal"``);
    the first entry is the client-facing one.
    """
    raw = request.headers.get(name)
    if not raw:
        return None
    return raw.split(",", 1)[0].strip() or None


def _public_base_url(request: Request) -> str:
    """Return the externally visible origin used to build manifest URLs.

    Preference: ``PUBLIC_BASE_URL``; then X-Forwarded-Host / Host with
    X-Forwarded-Proto (default ``https``); finally the request's own base URL
    forced to https.
    """
    if PUBLIC_BASE_URL:
        return PUBLIC_BASE_URL
    host = _first_header_value(request, "x-forwarded-host") or request.headers.get("host")
    proto = _first_header_value(request, "x-forwarded-proto") or "https"
    if host:
        return f"{proto}://{host}".rstrip("/")
    return str(request.base_url).rstrip("/").replace("http://", "https://", 1)


async def _open_upstream(request: Request, remote_path: str, method: str) -> httpx.Response:
    """Send ``method`` for ``remote_path`` to the bucket and return the streaming response.

    The response owns a dedicated ``AsyncClient`` (stored in
    ``response.extensions["delivery_client"]``); release both with
    ``_close_upstream``.

    Raises:
        HTTPException(404): the path is not allowlisted.
        HTTPException(503): ``HF_TOKEN`` is not configured.
        HTTPException(502): the origin could not be reached (connect/timeout/redirect errors).
    """
    if not _is_allowed(remote_path):
        raise HTTPException(status_code=404, detail="Model file not found.")
    # Build headers before creating the client: this can raise (missing token),
    # and doing it first means there is no client to leak on that path.
    headers = _request_headers(request)
    client = httpx.AsyncClient(timeout=httpx.Timeout(30.0, read=None), follow_redirects=True)
    try:
        upstream_request = client.build_request(method, _bucket_url(remote_path), headers=headers)
        response = await client.send(upstream_request, stream=True)
    except httpx.HTTPError as exc:
        await client.aclose()
        raise HTTPException(
            status_code=502,
            detail=f"Model origin unreachable for {remote_path}.",
        ) from exc
    except Exception:
        await client.aclose()
        raise
    response.extensions["delivery_client"] = client
    return response


async def _close_upstream(response: httpx.Response) -> None:
    """Close an upstream response and the client ``_open_upstream`` attached to it.

    The client is closed even if closing the response fails. httpx makes both
    closes no-ops when already closed, so calling this twice is safe.
    """
    try:
        await response.aclose()
    finally:
        client = response.extensions.get("delivery_client")
        if isinstance(client, httpx.AsyncClient):
            await client.aclose()


def _raise_for_upstream(response: httpx.Response, remote_path: str) -> None:
    """Translate a non-success upstream status into a client-facing HTTPException.

    200 and 206 pass. 401/403 become 503 (our token is the problem, not the
    client's), 404 stays 404, 416 is passed through (the client's Range was
    unsatisfiable), and anything else becomes 502.
    """
    status = response.status_code
    if status in (200, 206):
        return
    if status in (401, 403):
        raise HTTPException(status_code=503, detail="Private model origin rejected the delivery token.")
    if status == 404:
        raise HTTPException(status_code=404, detail="Model file not found.")
    if status == 416:
        content_range = response.headers.get("content-range")
        raise HTTPException(
            status_code=416,
            detail="Requested range not satisfiable.",
            headers={"content-range": content_range} if content_range else None,
        )
    raise HTTPException(
        status_code=502,
        detail=f"Model origin failed for {remote_path} (HTTP {status}).",
    )


@app.get("/health")
async def health() -> dict[str, Any]:
    """Report liveness and which allowlist mode (manifest paths vs prefixes) is in effect."""
    return {
        "ok": True,
        "bucket": HF_BUCKET_ID,
        "manifestLoaded": MANIFEST is not None,
        "allowlistSize": len(ALLOWED_PATHS),
        "allowedPrefixes": list(MODEL_ALLOWED_PREFIXES),
    }


@app.get("/manifest")
async def manifest(request: Request) -> JSONResponse:
    """Return the deployed manifest with a public download ``url`` added to each file entry.

    The module-level MANIFEST is deep-copied so per-request URLs never mutate it.

    Raises:
        HTTPException(404): no manifest was loaded.
    """
    if not MANIFEST:
        raise HTTPException(status_code=404, detail="No manifest is deployed with this Space.")
    base_url = _public_base_url(request)
    payload = copy.deepcopy(MANIFEST)
    for entry in payload.get("files", {}).values():
        store, relative_path = _entry_location(entry)
        if store and relative_path:
            # Both segments are quoted ("/" kept) so spaces or reserved
            # characters in either part still yield a valid URL.
            entry["url"] = f"{base_url}/models/{quote(store)}/{quote(relative_path)}"
    return JSONResponse(payload)


@app.head("/models/{store}/{path:path}")
async def head_model(request: Request, store: str, path: str) -> Response:
    """Answer HEAD for a model file with the upstream status and forwarded headers, no body."""
    remote_path = _normalize_model_path(store, path)
    response = await _open_upstream(request, remote_path, "HEAD")
    try:
        _raise_for_upstream(response, remote_path)
        return Response(status_code=response.status_code, headers=_response_headers(response, remote_path))
    finally:
        await _close_upstream(response)


@app.get("/models/{store}/{path:path}")
async def get_model(request: Request, store: str, path: str) -> StreamingResponse:
    """Stream a model file (or the requested byte range) from the private bucket.

    The upstream response is closed when the body is exhausted or its
    generator is closed, with a background-task backstop for abandoned streams.
    """
    remote_path = _normalize_model_path(store, path)
    response = await _open_upstream(request, remote_path, "GET")
    try:
        _raise_for_upstream(response, remote_path)
    except Exception:
        await _close_upstream(response)
        raise

    async def body():
        try:
            async for chunk in response.aiter_bytes(chunk_size=1024 * 1024):
                yield chunk
        finally:
            await _close_upstream(response)

    # Best-effort backstop: Starlette runs `background` after sending the
    # response, which covers some paths where the generator is abandoned (e.g.
    # a client disconnect) before its `finally` runs. _close_upstream is safe
    # to run twice.
    cleanup = BackgroundTasks()
    cleanup.add_task(_close_upstream, response)

    return StreamingResponse(
        body(),
        status_code=response.status_code,
        headers=_response_headers(response, remote_path),
        media_type=response.headers.get("content-type", "application/octet-stream"),
        background=cleanup,
    )