""" Synced with huggingface/pyspaces:spaces/zero/api.py """ from datetime import timedelta from typing import Any from typing import Generator from typing import Literal from typing import NamedTuple from typing import Optional from typing import overload import httpx from pydantic import BaseModel from typing_extensions import assert_never AllowToken = str NvidiaIndex = int # TODO: Migrate to GpuIndex (less confusing for MIG) NvidiaUUID = str CGroupPath = str TaskId = int GPUSize = Literal['large', 'xlarge'] AuthLevel = Literal['regular', 'pro'] QueuingReason = Literal['node', 'concurrency'] AUTHENTICATED_HEADER = 'X-Authenticated' QUEUING_REASON_HEADER = 'X-Queuing-Reason' class ScheduleResponse(BaseModel): idle: bool nvidiaIndex: int nvidiaUUID: str allowToken: str class ScheduleMetadata(BaseModel): auth: Optional[AuthLevel] = None queuing_reason: Optional[QueuingReason] = None class QuotaInfos(BaseModel): left: int wait: timedelta class QueueEvent(BaseModel): event: Literal['ping', 'failed', 'succeeded'] data: Optional[ScheduleResponse] = None def sse_parse(text: str): event, *data = text.strip().splitlines() assert event.startswith('event:') event = event[6:].strip() if event in ('ping', 'failed'): return QueueEvent(event=event) assert event == 'succeeded' (data,) = data assert data.startswith('data:') data = data[5:].strip() return QueueEvent(event=event, data=ScheduleResponse.parse_raw(data)) def sse_stream(res: httpx.Response) -> Generator[QueueEvent, Any, None]: for text in res.iter_text(): if len(text) == 0: break # pragma: no cover try: yield sse_parse(text) except GeneratorExit: res.close() break class APIClient: def __init__(self, client: httpx.Client): self.client = client def startup_report(self) -> httpx.codes: res = self.client.post('/startup-report') return httpx.codes(res.status_code) def schedule( self, cgroup_path: str, task_id: int = 0, token: str | None = None, token_version: int = 1, duration_seconds: int | None = None, enable_queue: bool = True, gpu_size: GPUSize | None = None, ): params: dict[str, str | int | bool] = { 'cgroupPath': cgroup_path, 'taskId': task_id, 'enableQueue': enable_queue, 'tokenVersion': token_version, } if duration_seconds is not None: params['durationSeconds'] = duration_seconds if gpu_size is not None: params['gpuSize'] = gpu_size if token is not None: params['token'] = token res = self.client.send( request=self.client.build_request( method='POST', url='/schedule', params=params, ), stream=True, ) status = httpx.codes(res.status_code) auth: AuthLevel | None = res.headers.get(AUTHENTICATED_HEADER) queuing_reason: QueuingReason | None = res.headers.get(QUEUING_REASON_HEADER) metadata = ScheduleMetadata(auth=auth, queuing_reason=queuing_reason) if (status is not httpx.codes.OK and status is not httpx.codes.TOO_MANY_REQUESTS ): res.close() return status, metadata if "text/event-stream" in res.headers['content-type']: return sse_stream(res), metadata res.read() if status is httpx.codes.TOO_MANY_REQUESTS: return QuotaInfos(**res.json()), metadata # pragma: no cover if status is httpx.codes.OK: return ScheduleResponse(**res.json()), metadata assert_never(status) def allow( self, allow_token: str, pid: int, ): res = self.client.post('/allow', params={ 'allowToken': allow_token, 'pid': pid, }) return httpx.codes(res.status_code) def release( self, allow_token: str, fail: bool = False, ) -> httpx.codes: res = self.client.post('/release', params={ 'allowToken': allow_token, 'fail': fail, }) return httpx.codes(res.status_code) def get_queue_size(self) -> float: res = self.client.get('/queue-size') assert res.status_code == 200, res.status_code size = res.json() return size