test_489 / trackio /asgi_app.py
abidlabs's picture
abidlabs HF Staff
Upload folder using huggingface_hub
7a943a8 verified
from __future__ import annotations
import inspect
import json
import math
import tempfile
import threading
from pathlib import Path
from typing import Any
from urllib.parse import unquote
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import FileResponse, JSONResponse, Response
from starlette.routing import Route
from trackio.exceptions import TrackioAPIError
from trackio.remote_client import HTTP_API_VERSION
# Version metadata is read once, at import time, from the package.json file
# located in the same directory as this module.
_PACKAGE_JSON_PATH = Path(__file__).parent / "package.json"
# Decode explicitly as UTF-8: without an encoding argument, read_text() uses
# the platform's locale encoding, which is not UTF-8 on every system.
_TRACKIO_PACKAGE_VERSION = json.loads(
    _PACKAGE_JSON_PATH.read_text(encoding="utf-8")
)["version"]
def _normalize_allowed_file_roots(
    allowed_file_roots: list[str | Path] | None,
) -> tuple[Path, ...]:
    """Return the configured file roots as a tuple of resolved paths.

    Args:
        allowed_file_roots: Directories given as strings or ``Path`` objects,
            or ``None``.

    Returns:
        One resolved ``Path`` per entry, in input order. ``None`` or an empty
        list yields an empty tuple.
    """
    if not allowed_file_roots:
        return ()
    return tuple(Path(entry).resolve() for entry in allowed_file_roots)
def _is_allowed_file_path(path: Path, allowed_roots: tuple[Path, ...]) -> bool:
    """Report whether ``path`` lies under at least one of ``allowed_roots``.

    ``path`` is resolved (without requiring it to exist) before comparison, so
    ``..`` segments are collapsed first. The roots are compared exactly as
    given; they are not resolved here.

    Returns:
        ``True`` if the resolved path is equal to or nested inside any root,
        ``False`` otherwise (including when ``allowed_roots`` is empty).
    """
    candidate = path.resolve(strict=False)

    def _contains(base: Path) -> bool:
        # relative_to() raises ValueError when candidate is not under base.
        try:
            candidate.relative_to(base)
        except ValueError:
            return False
        return True

    return any(_contains(base) for base in allowed_roots)
def _json_safe_key(key: Any) -> Any:
    """Return a dict key that a strict JSON encoder will accept.

    ``None``, strings, booleans, integers and finite floats are kept as they
    are. Objects exposing ``item()`` are unwrapped through it. Anything else,
    including non-finite floats, is replaced by ``str(key)``.
    """
    if key is None or isinstance(key, (str, bool, int)):
        return key
    if isinstance(key, float):
        return key if math.isfinite(key) else str(key)
    if hasattr(key, "item"):
        try:
            return _json_safe_key(key.item())
        except Exception:
            # Best effort: fall through to the string form below.
            pass
    return str(key)


def _json_safe(data: Any) -> Any:
    """Recursively convert ``data`` into a JSON-serializable structure.

    Conversion rules, applied in order:

    * ``None``, ``str``, ``bool`` and ``int`` are returned unchanged.
    * Finite floats are returned unchanged; NaN and infinities become ``None``.
    * Dicts are rebuilt with converted values and normalized keys (see
      ``_json_safe_key``). If two distinct keys normalize to the same value,
      the later entry wins.
    * Lists and tuples become lists of converted items.
    * Objects exposing ``item()`` are converted through the value it returns.
    * Everything else, or an object whose ``item()`` call fails, becomes
      ``str(data)``.

    Args:
        data: Arbitrary value, typically the return value of an API handler.

    Returns:
        A value built only from JSON-compatible Python types.
    """
    if data is None or isinstance(data, (str, bool, int)):
        return data
    if isinstance(data, float):
        return data if math.isfinite(data) else None
    if isinstance(data, dict):
        return {_json_safe_key(k): _json_safe(v) for k, v in data.items()}
    if isinstance(data, (list, tuple)):
        return [_json_safe(v) for v in data]
    if hasattr(data, "item"):
        try:
            return _json_safe(data.item())
        except Exception:
            # Deliberate best effort: fall back to the string representation.
            pass
    return str(data)
def register_uploaded_temp_file(request: Request, file_path: str | Path) -> None:
    """Record ``file_path`` as a temp file created by this server.

    The path is resolved (it does not have to exist) and added to the
    ``uploaded_temp_files`` set on the application state while holding
    ``uploaded_temp_files_lock``.

    Args:
        request: Request whose ``app.state`` holds the registry and its lock.
        file_path: Location of the uploaded temp file.
    """
    absolute = Path(file_path).resolve(strict=False)
    state = request.app.state
    with state.uploaded_temp_files_lock:
        state.uploaded_temp_files.add(absolute)
def consume_uploaded_temp_file(request: Request, file_data: Any) -> Path:
    """Validate and claim a temp file previously registered by this server.

    A registered path can be consumed once: it is removed from the registry
    before the file's existence is checked.

    Args:
        request: Request whose ``app.state`` holds the registry and its lock.
        file_data: Upload metadata; expected to be a dict with a non-empty
            string under ``"path"``.

    Returns:
        The resolved path of the uploaded file.

    Raises:
        TrackioAPIError: If the metadata has no usable path, the path was not
            registered, or the file is no longer present on disk.
    """
    raw_path = None
    if isinstance(file_data, dict):
        raw_path = file_data.get("path")
    if not (isinstance(raw_path, str) and raw_path):
        raise TrackioAPIError("Expected uploaded file metadata with a valid path.")

    target = Path(raw_path).resolve(strict=False)
    state = request.app.state
    with state.uploaded_temp_files_lock:
        registered = target in state.uploaded_temp_files
        if not registered:
            raise TrackioAPIError(
                "Uploaded file was not created by this Trackio server."
            )
        state.uploaded_temp_files.remove(target)

    if target.is_file():
        return target
    raise TrackioAPIError("Uploaded file is missing.")
def _invoke_handler(
    fn: Any,
    request: Request,
    args: list[Any] | None = None,
    kwargs: dict[str, Any] | None = None,
) -> Any:
    """Call ``fn`` with values matched against its signature.

    Parameters of ``fn`` are filled in declaration order:

    * a parameter named ``request`` always receives ``request``; a
      client-supplied ``"request"`` entry in ``kwargs`` is discarded so it
      cannot replace the injected object;
    * ``*args``-style parameters receive all positional values not yet used;
    * ``**kwargs``-style parameters receive all keyword values not yet used;
    * a parameter whose name is in ``kwargs`` receives that value;
    * otherwise a positional-capable parameter receives the next unused item
      of ``args``;
    * otherwise the parameter keeps its default, or the call is rejected if
      it has none.

    Surplus positional or keyword values that no parameter can accept are
    ignored.

    Args:
        fn: Callable to invoke.
        request: Request injected into a parameter named ``request``.
        args: Positional values supplied by the caller.
        kwargs: Keyword values supplied by the caller.

    Returns:
        Whatever ``fn`` returns.

    Raises:
        TrackioAPIError: If a parameter without a default receives no value.
    """
    pending_args = list(args or [])
    pending_kwargs = dict(kwargs or {})
    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )

    # (name, value) for every named parameter that received a value, in
    # signature order.
    bound: list[tuple[str, Any]] = []
    extra_positional: list[Any] = []
    extra_keywords: dict[str, Any] = {}
    # Number of leading entries of `bound` that must be passed positionally.
    positional_count = 0
    next_arg = 0

    for param in inspect.signature(fn).parameters.values():
        if param.name == "request":
            pending_kwargs.pop("request", None)
            bound.append((param.name, request))
        elif param.kind is inspect.Parameter.VAR_POSITIONAL:
            extra_positional = pending_args[next_arg:]
            next_arg = len(pending_args)
            if extra_positional:
                # Values can only reach *args if every parameter before it is
                # passed positionally as well.
                positional_count = len(bound)
        elif param.kind is inspect.Parameter.VAR_KEYWORD:
            extra_keywords = pending_kwargs
            pending_kwargs = {}
        elif param.name in pending_kwargs:
            bound.append((param.name, pending_kwargs.pop(param.name)))
        elif param.kind in positional_kinds and next_arg < len(pending_args):
            bound.append((param.name, pending_args[next_arg]))
            next_arg += 1
            # Passing this value positionally requires everything bound before
            # it to be positional too; passing an earlier parameter by keyword
            # would make Python report "multiple values for argument".
            positional_count = len(bound)
        elif param.default is inspect.Parameter.empty:
            raise TrackioAPIError(f"Missing required parameter: {param.name}")

    call_args = [value for _, value in bound[:positional_count]]
    call_args.extend(extra_positional)
    call_kwargs = dict(bound[positional_count:])
    call_kwargs.update(extra_keywords)
    return fn(*call_args, **call_kwargs)
async def version_handler(request: Request) -> Response:
    """Report package and HTTP API version information as JSON.

    The response also states whether MCP is enabled, read from the
    ``mcp_enabled`` attribute of the application state (treated as ``False``
    when absent), and gives ``"/mcp"`` as the MCP path only when it is.
    """
    mcp_on = bool(getattr(request.app.state, "mcp_enabled", False))
    mcp_path = "/mcp" if mcp_on else None
    payload = {
        "version": _TRACKIO_PACKAGE_VERSION,
        "api_version": HTTP_API_VERSION,
        "api_transport": "http",
        "mcp_enabled": mcp_on,
        "mcp_path": mcp_path,
    }
    return JSONResponse(payload)
async def api_handler(request: Request) -> Response:
    """Dispatch a POST to the registered API function named in the URL.

    The JSON body may take any of these shapes:

    * ``{"args": [...], "kwargs": {...}}`` (either key may be omitted);
    * ``{"data": [...]}``, whose list is used as positional arguments;
    * any other JSON object, used as keyword arguments;
    * a JSON array, used as positional arguments;
    * any other non-null JSON value, used as a single positional argument.

    A body that cannot be parsed as JSON is treated as an empty object.

    Returns:
        ``{"data": ...}`` on success, or ``{"error": ...}`` with status 404
        (unknown API name), 400 (``TrackioAPIError``) or 500 (any other
        exception).
    """
    registry = request.app.state.api_registry
    name = request.path_params["api_name"]
    handler = registry.get(name)
    if handler is None:
        return JSONResponse({"error": f"Unknown API: {name}"}, status_code=404)

    try:
        payload = await request.json()
    except Exception:
        # Missing or malformed bodies are deliberately treated as "no arguments".
        payload = {}

    call_args: Any = []
    call_kwargs: Any = {}
    if isinstance(payload, list):
        call_args = payload
    elif isinstance(payload, dict):
        if "args" in payload or "kwargs" in payload:
            call_args = payload.get("args") or []
            call_kwargs = payload.get("kwargs") or {}
        elif isinstance(payload.get("data"), list):
            call_args = payload["data"]
        else:
            call_kwargs = payload
    elif payload is not None:
        call_args = [payload]

    # Coerce client-supplied "args"/"kwargs" of the wrong JSON type.
    call_args = call_args if isinstance(call_args, list) else [call_args]
    call_kwargs = call_kwargs if isinstance(call_kwargs, dict) else {}

    try:
        outcome = _invoke_handler(
            handler, request, args=call_args, kwargs=call_kwargs
        )
        # Built inside the try so serialization failures are reported as 500 too.
        return JSONResponse({"data": _json_safe(outcome)})
    except TrackioAPIError as exc:
        return JSONResponse({"error": str(exc)}, status_code=400)
    except Exception as exc:
        return JSONResponse({"error": str(exc)}, status_code=500)
async def upload_handler(request: Request) -> Response:
    """Save each uploaded ``files`` entry to a temp file and return the paths.

    Every entry of the multipart ``files`` field is written to its own
    temporary file (prefix ``trackio-upload-``, keeping the original file
    extension). Each saved path is registered through
    ``register_uploaded_temp_file``.

    Returns:
        ``{"paths": [...]}`` listing the saved temp file paths in upload
        order, or ``{"error": ...}`` with status 400 when a ``files`` entry is
        not a file upload. In that case nothing is saved.
    """
    form = await request.form()
    uploads = form.getlist("files")

    # Plain text form fields come back as strings with no read() method.
    # Reject them before writing anything so a bad request leaves no files.
    if any(not callable(getattr(upload, "read", None)) for upload in uploads):
        return JSONResponse(
            {"error": "Expected file uploads in the 'files' field."},
            status_code=400,
        )

    saved_paths: list[str] = []
    for upload in uploads:
        suffix = Path(getattr(upload, "filename", "") or "").suffix
        content = await upload.read()
        tmp = tempfile.NamedTemporaryFile(
            delete=False,
            prefix="trackio-upload-",
            suffix=suffix,
        )
        try:
            with tmp:
                tmp.write(content)
        except BaseException:
            # delete=False means nothing else would remove a partial file.
            Path(tmp.name).unlink(missing_ok=True)
            raise
        register_uploaded_temp_file(request, tmp.name)
        saved_paths.append(tmp.name)
    return JSONResponse({"paths": saved_paths})
async def file_handler(request: Request) -> Response:
    """Serve the file named by the ``path`` query parameter.

    The file is served only if it exists as a regular file and lies under one
    of the ``allowed_file_roots`` stored on the application state (an empty
    tuple when the attribute is absent).

    Returns:
        The file contents; a 400 response when ``path`` is missing; a 404
        response when the file does not exist or is outside the allowed roots.
    """
    raw_path = request.query_params.get("path")
    if raw_path is None:
        return Response("Missing path", status_code=400)

    # NOTE(review): query parameter values are presumably already
    # percent-decoded by the framework, so this decodes a second time.
    # Confirm clients rely on that before changing it.
    target = Path(unquote(raw_path)).resolve(strict=False)
    roots = getattr(request.app.state, "allowed_file_roots", ())
    if not target.is_file() or not _is_allowed_file_path(target, roots):
        return Response("Not found", status_code=404)
    return FileResponse(str(target))
def create_trackio_starlette_app(
    oauth_routes: list[Route],
    api_registry: dict[str, Any],
    extra_routes: list[Any] | None = None,
    mcp_lifespan: Any = None,
    mcp_enabled: bool = False,
    allowed_file_roots: list[str | Path] | None = None,
) -> Starlette:
    """Build the Trackio Starlette application and initialize its state.

    Args:
        oauth_routes: Routes placed first in the routing table.
        api_registry: Mapping from API name to the callable served under
            ``/api/{api_name}``.
        extra_routes: Optional routes appended after the built-in ones.
        mcp_lifespan: Passed to Starlette as the ``lifespan`` argument.
        mcp_enabled: Stored on the app state and reported by ``/version``.
        allowed_file_roots: Directories from which ``/file`` may serve files.
            They are normalized to resolved paths.

    Returns:
        The configured application, with an empty uploaded-temp-file registry
        and its lock on ``app.state``.
    """
    # "/api/upload" is listed ahead of the parameterized "/api/{api_name}"
    # route so the literal path is considered first.
    builtin_routes = [
        Route("/version", endpoint=version_handler, methods=["GET"]),
        Route("/api/upload", endpoint=upload_handler, methods=["POST"]),
        Route("/api/{api_name:str}", endpoint=api_handler, methods=["POST"]),
        Route("/file", endpoint=file_handler, methods=["GET"]),
    ]
    all_routes: list[Any] = [*oauth_routes, *builtin_routes, *(extra_routes or [])]

    app = Starlette(routes=all_routes, lifespan=mcp_lifespan)
    state = app.state
    state.api_registry = api_registry
    state.mcp_enabled = mcp_enabled
    state.allowed_file_roots = _normalize_allowed_file_roots(allowed_file_roots)
    state.uploaded_temp_files = set()
    state.uploaded_temp_files_lock = threading.Lock()
    return app