multimodalart's picture
multimodalart HF Staff
Upload 75 files
a34bca4 verified
"""
Synced with huggingface/pyspaces:spaces/zero/api.py
"""
from datetime import timedelta
from typing import Any
from typing import Generator
from typing import Literal
from typing import NamedTuple
from typing import Optional
from typing import overload
import httpx
from pydantic import BaseModel
from typing_extensions import assert_never
AllowToken = str
NvidiaIndex = int # TODO: Migrate to GpuIndex (less confusing for MIG)
NvidiaUUID = str
CGroupPath = str
TaskId = int
GPUSize = Literal['large', 'xlarge']
AuthLevel = Literal['regular', 'pro']
QueuingReason = Literal['node', 'concurrency']
AUTHENTICATED_HEADER = 'X-Authenticated'
QUEUING_REASON_HEADER = 'X-Queuing-Reason'
class ScheduleResponse(BaseModel):
idle: bool
nvidiaIndex: int
nvidiaUUID: str
allowToken: str
class ScheduleMetadata(BaseModel):
auth: Optional[AuthLevel] = None
queuing_reason: Optional[QueuingReason] = None
class QuotaInfos(BaseModel):
left: int
wait: timedelta
class QueueEvent(BaseModel):
event: Literal['ping', 'failed', 'succeeded']
data: Optional[ScheduleResponse] = None
def sse_parse(text: str):
event, *data = text.strip().splitlines()
assert event.startswith('event:')
event = event[6:].strip()
if event in ('ping', 'failed'):
return QueueEvent(event=event)
assert event == 'succeeded'
(data,) = data
assert data.startswith('data:')
data = data[5:].strip()
return QueueEvent(event=event, data=ScheduleResponse.parse_raw(data))
def sse_stream(res: httpx.Response) -> Generator[QueueEvent, Any, None]:
for text in res.iter_text():
if len(text) == 0:
break # pragma: no cover
try:
yield sse_parse(text)
except GeneratorExit:
res.close()
break
class APIClient:
def __init__(self, client: httpx.Client):
self.client = client
def startup_report(self) -> httpx.codes:
res = self.client.post('/startup-report')
return httpx.codes(res.status_code)
def schedule(
self,
cgroup_path: str,
task_id: int = 0,
token: str | None = None,
token_version: int = 1,
duration_seconds: int | None = None,
enable_queue: bool = True,
gpu_size: GPUSize | None = None,
):
params: dict[str, str | int | bool] = {
'cgroupPath': cgroup_path,
'taskId': task_id,
'enableQueue': enable_queue,
'tokenVersion': token_version,
}
if duration_seconds is not None:
params['durationSeconds'] = duration_seconds
if gpu_size is not None:
params['gpuSize'] = gpu_size
if token is not None:
params['token'] = token
res = self.client.send(
request=self.client.build_request(
method='POST',
url='/schedule',
params=params,
),
stream=True,
)
status = httpx.codes(res.status_code)
auth: AuthLevel | None = res.headers.get(AUTHENTICATED_HEADER)
queuing_reason: QueuingReason | None = res.headers.get(QUEUING_REASON_HEADER)
metadata = ScheduleMetadata(auth=auth, queuing_reason=queuing_reason)
if (status is not httpx.codes.OK and
status is not httpx.codes.TOO_MANY_REQUESTS
):
res.close()
return status, metadata
if "text/event-stream" in res.headers['content-type']:
return sse_stream(res), metadata
res.read()
if status is httpx.codes.TOO_MANY_REQUESTS:
return QuotaInfos(**res.json()), metadata # pragma: no cover
if status is httpx.codes.OK:
return ScheduleResponse(**res.json()), metadata
assert_never(status)
def allow(
self,
allow_token: str,
pid: int,
):
res = self.client.post('/allow', params={
'allowToken': allow_token,
'pid': pid,
})
return httpx.codes(res.status_code)
def release(
self,
allow_token: str,
fail: bool = False,
) -> httpx.codes:
res = self.client.post('/release', params={
'allowToken': allow_token,
'fail': fail,
})
return httpx.codes(res.status_code)
def get_queue_size(self) -> float:
res = self.client.get('/queue-size')
assert res.status_code == 200, res.status_code
size = res.json()
return size