Spaces:
Running on Zero
Running on Zero
| """ | |
| Synced with huggingface/pyspaces:spaces/zero/api.py | |
| """ | |
| from datetime import timedelta | |
| from typing import Any | |
| from typing import Generator | |
| from typing import Literal | |
| from typing import NamedTuple | |
| from typing import Optional | |
| from typing import overload | |
| import httpx | |
| from pydantic import BaseModel | |
| from typing_extensions import assert_never | |
# Response headers inspected by APIClient.schedule.
AUTHENTICATED_HEADER = 'X-Authenticated'
QUEUING_REASON_HEADER = 'X-Queuing-Reason'

# Closed sets of string values: GPUSize is sent as the ``gpuSize`` query
# parameter, AuthLevel / QueuingReason are read from the headers above.
GPUSize = Literal['large', 'xlarge']
AuthLevel = Literal['regular', 'pro']
QueuingReason = Literal['node', 'concurrency']

# Readability aliases for primitive identifier types.
TaskId = int
CGroupPath = str
NvidiaUUID = str
# TODO: Migrate to GpuIndex (less confusing for MIG)
NvidiaIndex = int
AllowToken = str
class ScheduleResponse(BaseModel):
    """Successful scheduling result.

    Parsed from a 200 JSON body of ``POST /schedule`` or from the ``data:``
    line of a ``succeeded`` queue event.
    """
    idle: bool  # NOTE(review): semantics not visible in this module — TODO document
    nvidiaIndex: int  # presumably the NVIDIA device index (cf. the NvidiaIndex alias)
    nvidiaUUID: str  # presumably the NVIDIA device UUID (cf. the NvidiaUUID alias)
    allowToken: str  # presumably the token later passed as ``allow_token`` to APIClient.allow / release
class ScheduleMetadata(BaseModel):
    """Response-header metadata returned alongside every ``schedule`` result."""
    auth: Optional[AuthLevel] = None  # value of the X-Authenticated header, if present
    queuing_reason: Optional[QueuingReason] = None  # value of the X-Queuing-Reason header, if present
class QuotaInfos(BaseModel):
    """Quota details parsed from a 429 (Too Many Requests) ``/schedule`` JSON body."""
    left: int  # presumably the remaining quota — TODO confirm units
    wait: timedelta  # presumably how long until quota is available again — TODO confirm
class QueueEvent(BaseModel):
    """One server-sent event from the ``/schedule`` queue stream."""
    event: Literal['ping', 'failed', 'succeeded']
    data: Optional[ScheduleResponse] = None  # only set for 'succeeded' events (see sse_parse)
def sse_parse(text: str):
    """Parse one server-sent event block into a ``QueueEvent``.

    ``text`` must start with an ``event:`` line. For ``succeeded`` events it
    must be followed by exactly one ``data:`` line holding a JSON
    ``ScheduleResponse``. Surrounding whitespace is ignored.

    Raises:
        AssertionError: if the ``event:`` / ``data:`` prefix is missing or the
            event name is not one of ping/failed/succeeded.
        ValueError: if ``text`` is blank, or a ``succeeded`` event does not
            carry exactly one data line.
        Whatever ``ScheduleResponse.parse_raw`` raises for an invalid payload.
    """
    event_line, *data_lines = text.strip().splitlines()
    # Explicit raises rather than ``assert``: asserts are stripped under
    # ``python -O``, which would let malformed events through. AssertionError
    # is kept so callers see the same exception type as before.
    if not event_line.startswith('event:'):
        raise AssertionError(f'expected an event: line, got {event_line!r}')
    event = event_line[len('event:'):].strip()
    if event in ('ping', 'failed'):
        return QueueEvent(event=event)
    if event != 'succeeded':
        raise AssertionError(f'unknown event {event!r}')
    (data_line,) = data_lines
    if not data_line.startswith('data:'):
        raise AssertionError(f'expected a data: line, got {data_line!r}')
    data = data_line[len('data:'):].strip()
    return QueueEvent(event=event, data=ScheduleResponse.parse_raw(data))
def sse_stream(res: httpx.Response) -> Generator[QueueEvent, Any, None]:
    """Yield ``QueueEvent`` objects parsed from a streaming SSE response.

    Network chunks do not line up with event boundaries: one chunk may carry
    several events, or only part of one. Text is therefore buffered and split
    on the blank line that terminates each server-sent event. Any trailing,
    unterminated text is parsed once the stream ends.

    The response is closed when the generator finishes, fails, or is closed
    by the caller.
    """
    buffer = ''
    try:
        for text in res.iter_text():
            if len(text) == 0:
                break  # pragma: no cover
            # Normalise CRLF over the whole buffer so a '\r' / '\n' pair split
            # across two chunks is still recognised as one line break.
            buffer = (buffer + text).replace('\r\n', '\n')
            # Everything before the last blank line is a complete event; the
            # remainder (possibly empty) waits for more data.
            *complete, buffer = buffer.split('\n\n')
            for block in complete:
                if block.strip():
                    yield sse_parse(block)
        if buffer.strip():
            yield sse_parse(buffer)
    finally:
        # Also covers parse errors and GeneratorExit (caller closed us early).
        res.close()
class APIClient:
    """Thin synchronous wrapper around the scheduler HTTP API.

    Every method sends one request through the injected ``httpx.Client``
    (its base URL and transport are configured by the caller). Each method
    either returns the HTTP status as an ``httpx.codes`` member or parses the
    response body.
    """

    def __init__(self, client: httpx.Client):
        self.client = client

    def startup_report(self) -> httpx.codes:
        """POST ``/startup-report`` and return the response status."""
        res = self.client.post('/startup-report')
        return httpx.codes(res.status_code)

    def schedule(
        self,
        cgroup_path: str,
        task_id: int = 0,
        token: str | None = None,
        token_version: int = 1,
        duration_seconds: int | None = None,
        enable_queue: bool = True,
        gpu_size: GPUSize | None = None,
    ):
        """Request a GPU slot via ``POST /schedule``.

        Args:
            cgroup_path: sent as the ``cgroupPath`` query parameter.
            task_id: sent as ``taskId``.
            token: sent as ``token`` when not None.
            token_version: sent as ``tokenVersion``.
            duration_seconds: sent as ``durationSeconds`` when not None.
            enable_queue: sent as ``enableQueue``.
            gpu_size: sent as ``gpuSize`` when not None.

        Returns:
            A ``(result, metadata)`` tuple. ``metadata`` is a
            ``ScheduleMetadata`` built from the ``X-Authenticated`` and
            ``X-Queuing-Reason`` response headers. ``result`` is one of:

            - the ``httpx.codes`` status, for any status other than 200/429
              (the response is closed);
            - a generator of ``QueueEvent`` (from ``sse_stream``) when the
              body is ``text/event-stream``; the still-open response is
              handed to that generator;
            - ``QuotaInfos`` parsed from a 429 JSON body;
            - ``ScheduleResponse`` parsed from a 200 JSON body.
        """
        params: dict[str, str | int | bool] = {
            'cgroupPath': cgroup_path,
            'taskId': task_id,
            'enableQueue': enable_queue,
            'tokenVersion': token_version,
        }
        if duration_seconds is not None:
            params['durationSeconds'] = duration_seconds
        if gpu_size is not None:
            params['gpuSize'] = gpu_size
        if token is not None:
            params['token'] = token
        res = self.client.send(
            request=self.client.build_request(
                method='POST',
                url='/schedule',
                params=params,
            ),
            stream=True,
        )
        # The response is streamed, so it stays open until it is closed or
        # fully read. Close it if anything below raises before ownership is
        # handed on (unknown status code, invalid header value, bad body...).
        try:
            status = httpx.codes(res.status_code)
            auth: AuthLevel | None = res.headers.get(AUTHENTICATED_HEADER)
            queuing_reason: QueuingReason | None = res.headers.get(QUEUING_REASON_HEADER)
            metadata = ScheduleMetadata(auth=auth, queuing_reason=queuing_reason)
            if (status is not httpx.codes.OK
                    and status is not httpx.codes.TOO_MANY_REQUESTS):
                res.close()
                return status, metadata
            # A missing Content-Type must not raise KeyError; treat it as a
            # plain (non-streaming) JSON body.
            if 'text/event-stream' in res.headers.get('content-type', ''):
                return sse_stream(res), metadata
            # Consume the whole body (httpx closes a streamed response once it
            # has been read to the end).
            res.read()
        except BaseException:
            res.close()
            raise
        if status is httpx.codes.TOO_MANY_REQUESTS:
            return QuotaInfos(**res.json()), metadata  # pragma: no cover
        if status is httpx.codes.OK:
            return ScheduleResponse(**res.json()), metadata
        assert_never(status)

    def allow(
        self,
        allow_token: str,
        pid: int,
    ) -> httpx.codes:
        """POST ``/allow`` with ``allowToken`` and ``pid``; return the status.

        ``pid`` is presumably the OS process id to authorize — TODO confirm
        against the server contract.
        """
        res = self.client.post('/allow', params={
            'allowToken': allow_token,
            'pid': pid,
        })
        return httpx.codes(res.status_code)

    def release(
        self,
        allow_token: str,
        fail: bool = False,
    ) -> httpx.codes:
        """POST ``/release`` with ``allowToken`` and the ``fail`` flag; return the status."""
        res = self.client.post('/release', params={
            'allowToken': allow_token,
            'fail': fail,
        })
        return httpx.codes(res.status_code)

    def get_queue_size(self) -> float:
        """Return the decoded JSON body of ``GET /queue-size``.

        Raises:
            AssertionError: if the server does not answer with HTTP 200.
        """
        res = self.client.get('/queue-size')
        # Explicit raise instead of ``assert`` so the check survives
        # ``python -O``; AssertionError is kept for existing callers.
        if res.status_code != 200:
            raise AssertionError(res.status_code)
        size = res.json()
        return size