"""
Synced with huggingface/pyspaces:spaces/zero/api.py
"""
from datetime import timedelta
from typing import Any
from typing import Generator
from typing import Literal
from typing import NamedTuple
from typing import Optional
from typing import overload
import httpx
from pydantic import BaseModel
from typing_extensions import assert_never
# Type aliases giving descriptive names to the primitive values used by this
# module's API.
AllowToken = str
NvidiaIndex = int # TODO: Migrate to GpuIndex (less confusing for MIG)
NvidiaUUID = str
CGroupPath = str
TaskId = int

# Closed sets of string values: `GPUSize` types the `gpu_size` argument of
# `APIClient.schedule`; `AuthLevel` and `QueuingReason` type the fields of
# `ScheduleMetadata`.
GPUSize = Literal['large', 'xlarge']
AuthLevel = Literal['regular', 'pro']
QueuingReason = Literal['node', 'concurrency']

# Names of the response headers that `APIClient.schedule` reads to build the
# `ScheduleMetadata` it returns.
AUTHENTICATED_HEADER = 'X-Authenticated'
QUEUING_REASON_HEADER = 'X-Queuing-Reason'
class ScheduleResponse(BaseModel):
    """Payload describing a successful scheduling.

    Instances are built from the JSON body of a 200 `/schedule` response
    (`ScheduleResponse(**res.json())`) and from the `data:` line of a
    `succeeded` server-sent event. Field names are camelCase because they
    are filled directly from the JSON keys.
    """

    idle: bool
    nvidiaIndex: int
    nvidiaUUID: str
    # NOTE(review): presumably the token later sent as `allowToken` to
    # `/allow` and `/release` -- confirm against the callers.
    allowToken: str
class ScheduleMetadata(BaseModel):
    """Header-derived information returned next to every `schedule` result.

    `APIClient.schedule` fills both fields from response headers; a field
    stays `None` when its header is absent.
    """

    # Value of the `X-Authenticated` response header.
    auth: Optional[AuthLevel] = None
    # Value of the `X-Queuing-Reason` response header.
    queuing_reason: Optional[QueuingReason] = None
class QuotaInfos(BaseModel):
    """JSON body of a 429 (Too Many Requests) `/schedule` response."""

    # NOTE(review): presumably the remaining quota; unit is not visible
    # in this file -- confirm.
    left: int
    # NOTE(review): presumably how long to wait before quota is available
    # again -- confirm.
    wait: timedelta
class QueueEvent(BaseModel):
    """One server-sent event read from a streaming `/schedule` response."""

    event: Literal['ping', 'failed', 'succeeded']
    # Only populated by `sse_parse` for `succeeded` events; `ping` and
    # `failed` events carry no payload.
    data: Optional[ScheduleResponse] = None
def sse_parse(text: str):
    """Parse the text of a single server-sent event into a `QueueEvent`.

    The first non-blank line must be an `event:` line. `ping` and `failed`
    events are returned without a payload (any further lines are ignored).
    A `succeeded` event must be followed by exactly one `data:` line whose
    content is a JSON-encoded `ScheduleResponse`.

    Args:
        text: Raw text of one event; surrounding whitespace is ignored.

    Returns:
        The parsed `QueueEvent`.

    Raises:
        AssertionError: If the first line is not an `event:` line, the event
            name is not one of `ping` / `failed` / `succeeded`, or the
            payload line does not start with `data:`.
        ValueError: If `text` is blank, or a `succeeded` event does not have
            exactly one line after the `event:` line.
    """
    event_line, *payload_lines = text.strip().splitlines()

    # Validation is raised explicitly rather than with `assert` statements so
    # that it still runs under `python -O`; AssertionError is kept so callers
    # observe the same exception type as before.
    if not event_line.startswith('event:'):
        raise AssertionError(f'expected an "event:" line, got {event_line!r}')
    event = event_line[len('event:'):].strip()

    if event in ('ping', 'failed'):
        return QueueEvent(event=event)
    if event != 'succeeded':
        raise AssertionError(f'unexpected SSE event {event!r}')

    # Exactly one payload line is expected; anything else raises ValueError.
    (data_line,) = payload_lines
    if not data_line.startswith('data:'):
        raise AssertionError(f'expected a "data:" line, got {data_line!r}')
    payload = data_line[len('data:'):].strip()

    return QueueEvent(event=event, data=ScheduleResponse.parse_raw(payload))
def sse_stream(res: httpx.Response) -> Generator[QueueEvent, Any, None]:
    """Yield a `QueueEvent` for every text chunk of a streaming response.

    Each chunk produced by `res.iter_text()` is handed to `sse_parse` as-is;
    iteration stops at the first empty chunk.

    The response is closed whenever the generator finishes: on normal
    exhaustion, when the consumer closes or drops the generator, and when
    `sse_parse` raises. (Previously it was closed only on `GeneratorExit`,
    leaving the streamed response open on a parse error or an early stop.)

    Args:
        res: A response obtained with `stream=True`.

    Yields:
        The parsed events, in arrival order.
    """
    # NOTE(review): this treats every chunk as exactly one complete event;
    # `iter_text()` does not promise event-aligned chunks -- confirm the
    # server/transport guarantees that.
    try:
        for text in res.iter_text():
            if not text:
                break # pragma: no cover
            yield sse_parse(text)
    finally:
        res.close()
class APIClient:
    """Synchronous client for the scheduling HTTP API.

    Every method issues one request through the wrapped `httpx.Client`
    using a relative URL (`/schedule`, `/allow`, ...).
    """

    def __init__(self, client: httpx.Client):
        """Store the `httpx.Client` used for all requests.

        NOTE(review): only relative URLs are used, so the client is
        presumably configured with a `base_url` -- confirm at the call site.
        """
        self.client = client

    def startup_report(self) -> httpx.codes:
        """POST `/startup-report` and return the response status."""
        res = self.client.post('/startup-report')
        return httpx.codes(res.status_code)

    def schedule(
        self,
        cgroup_path: str,
        task_id: int = 0,
        token: str | None = None,
        token_version: int = 1,
        duration_seconds: int | None = None,
        enable_queue: bool = True,
        gpu_size: GPUSize | None = None,
    ) -> tuple[
        httpx.codes | Generator[QueueEvent, Any, None] | QuotaInfos | ScheduleResponse,
        ScheduleMetadata,
    ]:
        """POST `/schedule` and decode the reply.

        `cgroupPath`, `taskId`, `enableQueue` and `tokenVersion` are always
        sent as query parameters; `durationSeconds`, `gpuSize` and `token`
        are sent only when the matching argument is not `None`.

        Returns:
            A `(result, metadata)` pair. `metadata` is built from the
            `X-Authenticated` and `X-Queuing-Reason` response headers.
            `result` is one of:

            * the `httpx.codes` status, when it is neither 200 nor 429
              (the response is closed before returning);
            * a generator of `QueueEvent`, when the content type contains
              `text/event-stream` (the generator owns the open response);
            * a `QuotaInfos`, for a non-streaming 429 reply;
            * a `ScheduleResponse`, for a non-streaming 200 reply.

        Raises:
            Whatever `httpx` or the pydantic models raise; the streamed
            response is closed before the exception propagates.
        """
        params: dict[str, str | int | bool] = {
            'cgroupPath': cgroup_path,
            'taskId': task_id,
            'enableQueue': enable_queue,
            'tokenVersion': token_version,
        }
        if duration_seconds is not None:
            params['durationSeconds'] = duration_seconds
        if gpu_size is not None:
            params['gpuSize'] = gpu_size
        if token is not None:
            params['token'] = token

        # Streamed so that a server-sent-event body can be consumed lazily.
        res = self.client.send(
            request=self.client.build_request(
                method='POST',
                url='/schedule',
                params=params,
            ),
            stream=True,
        )

        try:
            status = httpx.codes(res.status_code)
            auth: AuthLevel | None = res.headers.get(AUTHENTICATED_HEADER)
            queuing_reason: QueuingReason | None = res.headers.get(QUEUING_REASON_HEADER)
            metadata = ScheduleMetadata(auth=auth, queuing_reason=queuing_reason)

            if (
                status is not httpx.codes.OK
                and status is not httpx.codes.TOO_MANY_REQUESTS
            ):
                res.close()
                return status, metadata

            # `.get` with a default: a reply without a Content-Type header is
            # treated as non-streaming instead of raising KeyError.
            if "text/event-stream" in res.headers.get('content-type', ''):
                return sse_stream(res), metadata

            res.read()

            if status is httpx.codes.TOO_MANY_REQUESTS:
                return QuotaInfos(**res.json()), metadata # pragma: no cover
            if status is httpx.codes.OK:
                return ScheduleResponse(**res.json()), metadata

            assert_never(status)
        except Exception:
            # Do not leak the streamed response when decoding fails.
            res.close()
            raise

    def allow(
        self,
        allow_token: str,
        pid: int,
    ) -> httpx.codes:
        """POST `/allow` with `allowToken` and `pid`; return the status."""
        res = self.client.post('/allow', params={
            'allowToken': allow_token,
            'pid': pid,
        })
        return httpx.codes(res.status_code)

    def release(
        self,
        allow_token: str,
        fail: bool = False,
    ) -> httpx.codes:
        """POST `/release` with `allowToken` and `fail`; return the status."""
        res = self.client.post('/release', params={
            'allowToken': allow_token,
            'fail': fail,
        })
        return httpx.codes(res.status_code)

    def get_queue_size(self) -> float:
        """GET `/queue-size` and return the decoded JSON body.

        Raises:
            AssertionError: If the response status is not 200; the status
                code is the exception argument.
        """
        res = self.client.get('/queue-size')
        # Raised explicitly (not via `assert`) so the check survives
        # `python -O`.
        if res.status_code != 200:
            raise AssertionError(res.status_code)
        return res.json()
|