# File size: 4,548 Bytes
# a34bca4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
"""
Synced with huggingface/pyspaces:spaces/zero/api.py
"""
from datetime import timedelta
from typing import Any
from typing import Generator
from typing import Literal
from typing import NamedTuple
from typing import Optional
from typing import overload

import httpx
from pydantic import BaseModel
from typing_extensions import assert_never


# Semantic aliases for plain builtin types. They carry no runtime checking;
# the code visible in this module annotates with the builtins directly.
AllowToken = str
NvidiaIndex = int # TODO: Migrate to GpuIndex (less confusing for MIG)
NvidiaUUID = str
CGroupPath = str
TaskId = int

# Closed sets of string values used in annotations below.
# GPUSize is sent as the `gpuSize` query parameter of `/schedule`.
GPUSize = Literal['large', 'xlarge']
# AuthLevel / QueuingReason annotate the values read from the two response
# headers named below.
AuthLevel = Literal['regular', 'pro']
QueuingReason = Literal['node', 'concurrency']


# Response header names read from the `/schedule` response to populate
# ScheduleMetadata (`auth` and `queuing_reason` respectively).
AUTHENTICATED_HEADER = 'X-Authenticated'
QUEUING_REASON_HEADER = 'X-Queuing-Reason'


class ScheduleResponse(BaseModel):
    """Payload describing a scheduling result.

    Built either from the JSON body of a non-streaming 200 response to
    ``/schedule`` or from the ``data:`` line of a ``succeeded`` SSE event.
    Field names are camelCase because instances are constructed directly
    from the decoded JSON keys.
    """
    # NOTE(review): presumably whether the assigned GPU was idle — confirm
    # against the server implementation.
    idle: bool
    # Index / UUID of the GPU as reported by the server.
    nvidiaIndex: int
    nvidiaUUID: str
    # NOTE(review): presumably the value later passed as `allowToken` to
    # `/allow` and `/release` — confirm against callers.
    allowToken: str


class ScheduleMetadata(BaseModel):
    """Header-derived information accompanying every ``/schedule`` result."""
    # Value of the `X-Authenticated` response header; None when absent.
    auth: Optional[AuthLevel] = None
    # Value of the `X-Queuing-Reason` response header; None when absent.
    queuing_reason: Optional[QueuingReason] = None


class QuotaInfos(BaseModel):
    """JSON body of a non-streaming 429 (Too Many Requests) ``/schedule`` response."""
    # NOTE(review): presumably the remaining quota — unit not visible here,
    # confirm against the server.
    left: int
    # NOTE(review): presumably how long to wait before retrying — confirm.
    wait: timedelta


class QueueEvent(BaseModel):
    """One parsed server-sent event from the ``/schedule`` queue stream."""
    # Event name taken from the `event:` line.
    event: Literal['ping', 'failed', 'succeeded']
    # Populated by `sse_parse` only for 'succeeded' events; None otherwise.
    data: Optional[ScheduleResponse] = None


def sse_parse(text: str):
    """Parse one server-sent event block into a ``QueueEvent``.

    The stripped text must start with an ``event:`` line. ``ping`` and
    ``failed`` events carry no payload; a ``succeeded`` event must be
    followed by exactly one ``data:`` line holding a JSON-encoded
    ``ScheduleResponse``.

    Raises:
        ValueError: if the text has no lines, or a ``succeeded`` event is
            not followed by exactly one line.
        AssertionError: if the first line is not an ``event:`` line, the
            event name is unknown, or the payload line is not ``data:``.
    """
    header, *payload_lines = text.strip().splitlines()
    assert header.startswith('event:')
    kind = header[len('event:'):].strip()

    # Payload-less events: any trailing lines are ignored.
    if kind == 'ping' or kind == 'failed':
        return QueueEvent(event=kind)

    assert kind == 'succeeded'
    # Exactly one payload line is expected; anything else fails the unpack.
    (payload_line,) = payload_lines
    assert payload_line.startswith('data:')
    raw_json = payload_line[len('data:'):].strip()
    schedule = ScheduleResponse.parse_raw(raw_json)
    return QueueEvent(event=kind, data=schedule)


def sse_stream(res: httpx.Response) -> Generator[QueueEvent, Any, None]:
    """Yield a ``QueueEvent`` for each text chunk of a streaming response.

    Each chunk produced by ``res.iter_text()`` is handed to ``sse_parse``
    as one complete event. Iteration stops at the first empty chunk.

    The response is closed whenever the generator finishes, for any
    reason: stream exhaustion, an empty chunk, the consumer closing the
    generator early, or ``sse_parse`` raising. Previously only the
    early-close case released the response.

    NOTE(review): this assumes the server flushes exactly one event per
    chunk; chunk boundaries are not re-assembled here — confirm.

    Args:
        res: A response obtained with ``stream=True``.

    Yields:
        The parsed event for each non-empty chunk.
    """
    try:
        for text in res.iter_text():
            if len(text) == 0:
                break # pragma: no cover
            yield sse_parse(text)
    finally:
        # Also runs when GeneratorExit is thrown in at the `yield`
        # (consumer called .close() or dropped the generator).
        res.close()


class APIClient:
    """Synchronous wrapper around the scheduler HTTP endpoints.

    Every method issues a single request through the injected
    ``httpx.Client``. Request paths are relative (e.g. ``/schedule``), so
    the client presumably carries the base URL — confirm against the code
    that constructs it.
    """

    def __init__(self, client: httpx.Client):
        """Store the ``httpx.Client`` used for all requests."""
        self.client = client

    def startup_report(self) -> httpx.codes:
        """POST ``/startup-report`` and return the response status code."""
        res = self.client.post('/startup-report')
        return httpx.codes(res.status_code)

    def schedule(
        self,
        cgroup_path: str,
        task_id: int = 0,
        token: str | None = None,
        token_version: int = 1,
        duration_seconds: int | None = None,
        enable_queue: bool = True,
        gpu_size: GPUSize | None = None,
    ):
        """POST ``/schedule`` and return ``(result, metadata)``.

        Args:
            cgroup_path: Sent as the ``cgroupPath`` query parameter.
            task_id: Sent as ``taskId``.
            token: Sent as ``token`` only when not None.
            token_version: Sent as ``tokenVersion``.
            duration_seconds: Sent as ``durationSeconds`` only when not None.
            enable_queue: Sent as ``enableQueue``.
            gpu_size: Sent as ``gpuSize`` only when not None.

        Returns:
            A 2-tuple whose second item is a ``ScheduleMetadata`` built from
            the ``X-Authenticated`` and ``X-Queuing-Reason`` response
            headers, and whose first item is one of:

            * the ``httpx.codes`` status, when it is neither 200 nor 429
              (the response is closed before returning);
            * a generator of ``QueueEvent`` (see ``sse_stream``), when the
              content type contains ``text/event-stream``;
            * a ``QuotaInfos``, for a non-streaming 429 body;
            * a ``ScheduleResponse``, for a non-streaming 200 body.
        """
        params: dict[str, str | int | bool] = {
            'cgroupPath': cgroup_path,
            'taskId': task_id,
            'enableQueue': enable_queue,
            'tokenVersion': token_version,
        }
        # Optional parameters are omitted entirely rather than sent empty.
        if duration_seconds is not None:
            params['durationSeconds'] = duration_seconds
        if gpu_size is not None:
            params['gpuSize'] = gpu_size
        if token is not None:
            params['token'] = token
        # stream=True so the body is not read before we know whether it is
        # an event stream.
        res = self.client.send(
            request=self.client.build_request(
                method='POST',
                url='/schedule',
                params=params,
            ),
            stream=True,
        )
        status = httpx.codes(res.status_code)
        auth: AuthLevel | None = res.headers.get(AUTHENTICATED_HEADER)
        queuing_reason: QueuingReason | None = res.headers.get(QUEUING_REASON_HEADER)
        metadata = ScheduleMetadata(auth=auth, queuing_reason=queuing_reason)
        if (status is not httpx.codes.OK and
            status is not httpx.codes.TOO_MANY_REQUESTS
        ):
            res.close()
            return status, metadata
        # A missing Content-Type header is treated as "not an event stream"
        # instead of raising KeyError with the streamed response left open.
        content_type = res.headers.get('content-type', '')
        if "text/event-stream" in content_type:
            return sse_stream(res), metadata
        res.read()
        if status is httpx.codes.TOO_MANY_REQUESTS:
            return QuotaInfos(**res.json()), metadata # pragma: no cover
        if status is httpx.codes.OK:
            return ScheduleResponse(**res.json()), metadata
        # Unreachable: every other status returned early above.
        assert_never(status)

    def allow(
        self,
        allow_token: str,
        pid: int,
    ) -> httpx.codes:
        """POST ``/allow`` with ``allowToken`` and ``pid``; return the status code."""
        res = self.client.post('/allow', params={
            'allowToken': allow_token,
            'pid': pid,
        })
        return httpx.codes(res.status_code)

    def release(
        self,
        allow_token: str,
        fail: bool = False,
    ) -> httpx.codes:
        """POST ``/release`` with ``allowToken`` and ``fail``; return the status code."""
        res = self.client.post('/release', params={
            'allowToken': allow_token,
            'fail': fail,
        })
        return httpx.codes(res.status_code)

    def get_queue_size(self) -> float:
        """GET ``/queue-size`` and return the decoded JSON body.

        Raises:
            AssertionError: if the response status is not 200 (the status
                code is the assertion message).
        """
        res = self.client.get('/queue-size')
        assert res.status_code == 200, res.status_code
        size = res.json()
        return size