File size: 8,339 Bytes
bdc2878
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b1e4ece
 
 
bdc2878
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b1e4ece
 
 
bdc2878
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b1e4ece
bdc2878
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b1e4ece
 
 
 
bdc2878
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
"""XAI rate-limits API protocol — fetch live quota data per mode."""

import asyncio

import orjson

from app.platform.errors import UpstreamError
from app.platform.logging.logger import logger
from app.platform.runtime.clock import now_ms


# ---------------------------------------------------------------------------
# Mode → request model name
# ---------------------------------------------------------------------------

# Maps the integer mode id to the ``modelName`` string sent in the request
# payload for that mode.
# Each mode requires a separate POST request.
# The API returns a flat single-mode quota object per call.
_MODE_NAMES: dict[int, str] = {
    0: "auto",
    1: "fast",
    2: "expert",
    3: "heavy",
    4: "grok-420-computer-use-sa",
}

# Default window durations (in seconds, keyed by mode id) used as a fallback
# when the live API value is unavailable.
# NOTE(review): the label on mode 4 ("grok_4_3") differs from the request name
# registered for mode 4 in ``_MODE_NAMES`` — confirm which one is current.
_DEFAULT_WINDOW_SECS: dict[int, int] = {
    0: 7_200,  # auto   — 2 h  (super/heavy only)
    1: 86_400,  # fast   — 24 h (basic; real value overrides for super/heavy)
    2: 7_200,  # expert — 2 h  (super/heavy only)
    3: 7_200,  # heavy  — 2 h  (heavy-pool only)
    4: 7_200,  # grok_4_3 — 2 h  (super/heavy only)
}


def _build_payload(mode_name: str) -> bytes:
    """Serialize the rate-limits request body for one mode.

    The result is the JSON encoding of ``{"modelName": <mode_name>}`` as
    produced by ``orjson.dumps`` (bytes).
    """
    request = {"modelName": mode_name}
    return orjson.dumps(request)


# ---------------------------------------------------------------------------
# Response parser
# ---------------------------------------------------------------------------


def parse_rate_limits(
    body: dict, *, default_window_seconds: int = 72_000
) -> dict | None:
    """Parse flat rate-limits response.

    Expected format::

        {
            "windowSizeSeconds": 72000,
            "remainingQueries":  20,
            "totalQueries":      20,
            "lowEffortRateLimits":  null,
            "highEffortRateLimits": null
        }

    Returns a dict with keys ``remaining``, ``total``, ``window_seconds``
    or ``None`` if the required fields are absent.
    """
    remaining = body.get("remainingQueries")
    if remaining is None:
        return None
    total = body.get("totalQueries")
    window_secs = body.get("windowSizeSeconds")
    return {
        "remaining": int(remaining),
        "total": int(total) if total is not None else int(remaining),
        "window_seconds": int(window_secs) if window_secs else default_window_seconds,
    }


# ---------------------------------------------------------------------------
# QuotaWindow builder
# ---------------------------------------------------------------------------


def _to_quota_window(data: dict, synced_at: int) -> object:
    """Wrap parsed rate-limit data in a ``QuotaWindow``.

    *data* must contain the keys ``window_seconds``, ``remaining`` and
    ``total``. *synced_at* is stored as-is and also serves as the base for
    the estimated reset time. The window is always tagged with
    ``QuotaSource.REAL``.
    """
    from app.control.account.models import QuotaWindow
    from app.control.account.enums import QuotaSource

    window_seconds = data["window_seconds"]
    # Estimated end of the current window. The factor 1000 presumably converts
    # seconds to the millisecond unit of ``synced_at`` — TODO confirm.
    estimated_reset = synced_at + window_seconds * 1000
    return QuotaWindow(
        remaining=data["remaining"],
        total=data["total"],
        window_seconds=window_seconds,
        reset_at=estimated_reset,
        synced_at=synced_at,
        source=QuotaSource.REAL,
    )


# ---------------------------------------------------------------------------
# HTTP fetch
# ---------------------------------------------------------------------------


async def _do_fetch(token: str, mode_name: str) -> dict:
    """POST the rate-limits endpoint for one mode and return the JSON body.

    A lease is acquired from the proxy runtime and passed to ``post_json``.
    On success a ``SUCCESS`` feedback (status 200) is reported for the lease.
    On any ``Exception`` the status is read from the exception's ``status``
    or ``status_code`` attribute, mapped to a feedback kind, reported for the
    lease, and the exception is re-raised unchanged.
    """
    from app.dataplane.reverse.transport.http import post_json
    from app.dataplane.proxy import get_proxy_runtime
    from app.control.proxy.models import ProxyFeedback, ProxyFeedbackKind

    runtime = await get_proxy_runtime()
    lease = await runtime.acquire()
    try:
        response = await post_json(
            "https://grok.com/rest/rate-limits",
            token,
            _build_payload(mode_name),
            lease=lease,
            timeout_s=20.0,
        )
        success = ProxyFeedback(kind=ProxyFeedbackKind.SUCCESS, status_code=200)
        await runtime.feedback(lease, success)
    except Exception as err:
        # Exceptions may expose the HTTP status under either attribute name.
        status = getattr(err, "status", None) or getattr(err, "status_code", None)
        failure_kind = _proxy_feedback_kind_for_error(err, status=status)
        failure = ProxyFeedback(kind=failure_kind, status_code=status)
        await runtime.feedback(lease, failure)
        raise
    return response


async def _fetch_one(token: str, mode_id: int) -> object | None:
    """Fetch and parse the quota window for one mode.

    Unknown mode ids are queried under the ``"auto"`` request name. Returns
    the built quota window, or ``None`` when the request times out (25 s),
    fails, or the response lacks the quota fields. Errors recognised by
    ``is_invalid_credentials_error`` are re-raised instead of being swallowed.
    """
    mode_name = _MODE_NAMES.get(mode_id, "auto")
    # Only a short prefix of the token is ever written to the log.
    token_prefix = token[:10]

    try:
        body = await asyncio.wait_for(_do_fetch(token, mode_name), timeout=25.0)
    except asyncio.TimeoutError:
        logger.debug(
            "rate-limits fetch timed out: token={}... mode={}",
            token_prefix,
            mode_name,
        )
        return None
    except Exception as err:
        if not is_invalid_credentials_error(err):
            logger.debug(
                "rate-limits fetch failed: token={}... mode={} error={}",
                token_prefix,
                mode_name,
                err,
            )
            return None
        # Invalid/blocked account: let the caller see the original error.
        raise

    fallback_window = _DEFAULT_WINDOW_SECS.get(mode_id, 72_000)
    parsed = parse_rate_limits(body, default_window_seconds=fallback_window)
    if parsed is not None:
        return _to_quota_window(parsed, now_ms())

    logger.debug(
        "rate-limits response missing quota fields: token={}... mode={} body={}",
        token_prefix,
        mode_name,
        body,
    )
    return None


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------


async def fetch_all_quotas(
    token: str, mode_ids: tuple[int, ...] | None = None
) -> dict[int, object] | None:
    """Fetch quota windows for the requested modes concurrently.

    Args:
        token: Account token forwarded to every per-mode request.
        mode_ids: Modes to query. ``None`` or an empty value selects
            ``(0, 1, 2, 3, 4)``.

    Returns:
        ``{mode_id: QuotaWindow}`` for every mode that responded
        successfully, or ``None`` if every requested mode failed.

    Raises:
        The first per-mode error recognised by
        ``is_invalid_credentials_error``; all other per-mode failures are
        dropped from the result.
    """
    # Materialise once so the same sequence drives both the requests and the
    # pairing below, even if a one-shot iterable was passed.
    requested = tuple(mode_ids) if mode_ids else (0, 1, 2, 3, 4)
    results = await asyncio.gather(
        *(_fetch_one(token, mode_id) for mode_id in requested),
        return_exceptions=True,
    )

    for result in results:
        if isinstance(result, BaseException) and is_invalid_credentials_error(result):
            raise result

    # With ``return_exceptions=True`` the results may also hold
    # ``CancelledError``, which is a ``BaseException`` but not an ``Exception``;
    # filter on ``BaseException`` so it is never returned as a quota window.
    windows = {
        mode_id: result
        for mode_id, result in zip(requested, results, strict=True)
        if result is not None and not isinstance(result, BaseException)
    }
    return windows or None


async def fetch_mode_quota(token: str, mode_id: int) -> object | None:
    """Public single-mode entry point: return the QuotaWindow for *mode_id* or None.

    Delegates directly to ``_fetch_one``.
    """
    window = await _fetch_one(token, mode_id)
    return window


def is_invalid_credentials_body(body: str) -> bool:
    """Tell whether *body* carries a Grok invalid/blocked account marker.

    Matching is a case-insensitive substring search; a falsy *body* never
    matches.
    """
    markers = (
        "invalid-credentials",
        "bad-credentials",
        "failed to look up session id",
        "blocked-user",
        "email-domain-rejected",
        "session not found",
        "account suspended",
        "token revoked",
        "token expired",
    )
    haystack = str(body or "").lower()
    return any(marker in haystack for marker in markers)


def is_invalid_credentials_error(exc: BaseException) -> bool:
    """Tell whether *exc* signals an invalid or blocked account.

    True only for an ``UpstreamError`` whose status is 400, 401 or 403 and
    whose ``details["body"]`` contains an invalid-credentials marker.
    """
    if isinstance(exc, UpstreamError) and exc.status in (400, 401, 403):
        raw_body = exc.details.get("body", "") or ""
        return is_invalid_credentials_body(str(raw_body))
    return False


def _proxy_feedback_kind_for_error(
    exc: BaseException,
    *,
    status: int | None,
):
    """Pick the proxy feedback kind for a failed quota fetch.

    The mapping avoids burning healthy clearance: an invalid or blocked
    account is an account problem rather than a proxy problem, so it is
    checked first and reported as ``FORBIDDEN`` whatever its status code.
    Otherwise 429, 403 and 401 map to ``RATE_LIMITED``, ``CHALLENGE`` and
    ``UNAUTHORIZED``, any status >= 500 maps to ``UPSTREAM_5XX``, and
    everything else (including a missing status) is ``TRANSPORT_ERROR``.
    """
    from app.control.proxy.models import ProxyFeedbackKind

    if is_invalid_credentials_error(exc):
        return ProxyFeedbackKind.FORBIDDEN

    exact_matches = {
        429: ProxyFeedbackKind.RATE_LIMITED,
        403: ProxyFeedbackKind.CHALLENGE,
        401: ProxyFeedbackKind.UNAUTHORIZED,
    }
    if status in exact_matches:
        return exact_matches[status]

    is_server_error = bool(status) and status >= 500
    return (
        ProxyFeedbackKind.UPSTREAM_5XX
        if is_server_error
        else ProxyFeedbackKind.TRANSPORT_ERROR
    )


# Public names exported by ``from <module> import *``; the underscore-prefixed
# helpers above are intentionally not listed.
__all__ = [
    "parse_rate_limits",
    "fetch_all_quotas",
    "fetch_mode_quota",
    "is_invalid_credentials_body",
    "is_invalid_credentials_error",
]