Spaces:
Running
Running
File size: 4,133 Bytes
e2ab8a3 1bdfa3a e2ab8a3 1bdfa3a e2ab8a3 1bdfa3a e2ab8a3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
import json
import time
import uuid
import os
from typing import Dict, Tuple, Optional
import requests
def _get_proxies() -> Optional[Dict[str, str]]:
proxy = os.getenv("HTTP_PROXY", "").strip()
if proxy:
return {"http": proxy, "https": proxy}
return None
# OIDC endpoints and constants (aligned with v1/auth_client.py).
# Every endpoint below is a path on the us-east-1 SSO OIDC host.
OIDC_BASE = "https://oidc.us-east-1.amazonaws.com"
REGISTER_URL = OIDC_BASE + "/client/register"
DEVICE_AUTH_URL = OIDC_BASE + "/device_authorization"
TOKEN_URL = OIDC_BASE + "/token"

# Start URL sent in the device authorization request body.
START_URL = "https://view.awsapps.com/start"

# Client-identification header values, copied verbatim from the Amazon Q CLI's
# Rust SDK user agent. NOTE(review): presumably these must stay in sync with
# what the service accepts; confirm before changing.
USER_AGENT = "aws-sdk-rust/1.3.9 os/windows lang/rust/1.87.0"
X_AMZ_USER_AGENT = "aws-sdk-rust/1.3.9 ua/2.1 api/ssooidc/1.88.0 os/windows lang/rust/1.87.0 m/E app/AmazonQ-For-CLI"
AMZ_SDK_REQUEST = "attempt=1; max=3"
def make_headers() -> Dict[str, str]:
    """Return a fresh header dict for a single JSON request to the OIDC service.

    The static client-identification values come from the module constants.
    ``amz-sdk-invocation-id`` is a new random UUID4 on every call, so each
    request gets its own id.
    """
    headers: Dict[str, str] = {
        "content-type": "application/json",
        "user-agent": USER_AGENT,
        "x-amz-user-agent": X_AMZ_USER_AGENT,
        "amz-sdk-request": AMZ_SDK_REQUEST,
    }
    invocation_id = uuid.uuid4()
    headers["amz-sdk-invocation-id"] = str(invocation_id)
    return headers
def post_json(url: str, payload: Dict) -> requests.Response:
    """POST ``payload`` as a JSON body to ``url`` and return the raw response.

    The body is serialized with ``json.dumps`` (dict insertion order kept,
    ``ensure_ascii=False``) to mimic the v1 client's request bodies closely.
    Headers come from ``make_headers()`` and proxies from ``_get_proxies()``.
    The timeout is 15 s to connect and 60 s to read. The status code is NOT
    checked here; callers decide how to handle error responses.

    Args:
        url: Endpoint to POST to.
        payload: JSON-serializable mapping sent as the request body.

    Returns:
        The ``requests.Response`` as received.
    """
    # Encode explicitly to UTF-8 bytes. A ``str`` body may be encoded as
    # ISO-8859-1 by http.client (and older requests versions size
    # Content-Length by character count). Either would corrupt or reject
    # non-ASCII text that ``ensure_ascii=False`` leaves unescaped. JSON on
    # the wire is expected to be UTF-8.
    body = json.dumps(payload, ensure_ascii=False).encode("utf-8")
    return requests.post(
        url,
        headers=make_headers(),
        data=body,
        timeout=(15, 60),
        proxies=_get_proxies(),
    )
def register_client_min() -> Tuple[str, str]:
    """Register a minimal public OIDC client.

    Returns:
        ``(clientId, clientSecret)`` taken from the registration response.

    Raises:
        requests.HTTPError: if the registration endpoint returns an error status.
        KeyError: if the response JSON lacks ``clientId`` or ``clientSecret``.
    """
    requested_scopes = [
        "codewhisperer:completions",
        "codewhisperer:analysis",
        "codewhisperer:conversations",
    ]
    response = post_json(
        REGISTER_URL,
        {
            "clientName": "Amazon Q Developer for command line",
            "clientType": "public",
            "scopes": requested_scopes,
        },
    )
    response.raise_for_status()
    body = response.json()
    client_id, client_secret = body["clientId"], body["clientSecret"]
    return client_id, client_secret
def device_authorize(client_id: str, client_secret: str) -> Dict:
    """Start the OIDC device authorization flow for a registered client.

    Sends the client credentials and ``START_URL`` to the device authorization
    endpoint and returns the decoded JSON response. Callers in this module
    expect it to include ``deviceCode``, ``interval``, ``expiresIn``,
    ``verificationUriComplete`` and ``userCode``.

    Raises:
        requests.HTTPError: if the endpoint returns an error status.
    """
    request_body = {
        "clientId": client_id,
        "clientSecret": client_secret,
        "startUrl": START_URL,
    }
    response = post_json(DEVICE_AUTH_URL, request_body)
    response.raise_for_status()
    authorization = response.json()
    return authorization
def poll_token_device_code(
    client_id: str,
    client_secret: str,
    device_code: str,
    interval: int,
    expires_in: int,
    max_timeout_sec: Optional[int] = 300,
) -> Dict:
    """Poll the token endpoint with a device code until approved or timed out.

    The overall deadline is the earlier of:

    - ``expires_in`` seconds from now (the upstream lifetime, floored at 1 s);
    - ``max_timeout_sec`` seconds from now (default 5 minutes). ``None`` or a
      non-positive value disables this cap.

    HTTP 400 responses are interpreted per the OAuth device grant (RFC 8628
    section 3.5):

    - ``authorization_pending``: wait ``interval`` seconds, then retry.
    - ``slow_down``: increase the interval by 5 s, then retry.
    - any other error: raise.

    Args:
        client_id: Registered OIDC client id.
        client_secret: Registered OIDC client secret.
        device_code: ``deviceCode`` from the device authorization response.
        interval: Suggested polling interval in seconds. A falsy value, or
            one below 1, is treated as 1.
        expires_in: Device code lifetime in seconds.
        max_timeout_sec: Optional cap on the total polling time.

    Returns:
        Token dict with at least 'accessToken' and optionally 'refreshToken'.

    Raises:
        TimeoutError: if the deadline passes before the user approves.
        requests.HTTPError: for non-recoverable HTTP errors, or for an
            unexpected non-200 status below 400.
    """
    payload = {
        "clientId": client_id,
        "clientSecret": client_secret,
        "deviceCode": device_code,
        "grantType": "urn:ietf:params:oauth:grant-type:device_code",
    }
    now = time.time()
    upstream_deadline = now + max(1, int(expires_in))
    if max_timeout_sec and max_timeout_sec > 0:
        deadline = min(upstream_deadline, now + max_timeout_sec)
    else:
        deadline = upstream_deadline
    poll_interval = max(1, int(interval or 1))

    while time.time() < deadline:
        r = post_json(TOKEN_URL, payload)
        if r.status_code == 200:
            return r.json()

        if r.status_code == 400:
            # Pending/slow-down are reported as 400 with an OAuth "error" code.
            try:
                err = r.json()
            except ValueError:  # body is not JSON; fall back to raw text
                err = {"error": r.text}
            error_code = str(err.get("error")) if isinstance(err, dict) else ""
            if error_code == "slow_down":
                # RFC 8628: slow_down means add 5 s to this and all later waits.
                poll_interval += 5
            elif error_code != "authorization_pending":
                # Any other 4xx is a real error; raise_for_status always raises here.
                r.raise_for_status()
        else:
            # Raises for 4xx/5xx. A non-200 status below 400 would otherwise
            # fall through and re-poll with no delay (a busy loop), so reject
            # it explicitly.
            r.raise_for_status()
            raise requests.HTTPError(
                f"Unexpected HTTP {r.status_code} from token endpoint",
                response=r,
            )

        # Never sleep past the deadline.
        remaining = deadline - time.time()
        if remaining <= 0:
            break
        time.sleep(min(poll_interval, remaining))

    raise TimeoutError("Device authorization expired before approval (timeout reached)")