youbiaokachi commited on
Commit
5f5c9df
·
verified ·
1 Parent(s): 01b03a3

Upload 18 files

Browse files
Files changed (1) hide show
  1. warp2protobuf/core/auth.py +73 -36
warp2protobuf/core/auth.py CHANGED
@@ -15,22 +15,29 @@ import httpx
15
  import asyncio
16
  from dotenv import load_dotenv, set_key
17
 
18
- from ..config.settings import REFRESH_TOKEN_B64, REFRESH_URL, CLIENT_VERSION, OS_CATEGORY, OS_NAME, OS_VERSION
 
 
 
 
 
 
 
19
  from .logging import logger, log
20
 
21
 
22
  def decode_jwt_payload(token: str) -> dict:
23
  """Decode JWT payload to check expiration"""
24
  try:
25
- parts = token.split('.')
26
  if len(parts) != 3:
27
  return {}
28
  payload_b64 = parts[1]
29
  padding = 4 - len(payload_b64) % 4
30
  if padding != 4:
31
- payload_b64 += '=' * padding
32
  payload_bytes = base64.urlsafe_b64decode(payload_b64)
33
- payload = json.loads(payload_bytes.decode('utf-8'))
34
  return payload
35
  except Exception as e:
36
  logger.debug(f"Error decoding JWT: {e}")
@@ -39,9 +46,9 @@ def decode_jwt_payload(token: str) -> dict:
39
 
40
  def is_token_expired(token: str, buffer_minutes: int = 5) -> bool:
41
  payload = decode_jwt_payload(token)
42
- if not payload or 'exp' not in payload:
43
  return True
44
- expiry_time = payload['exp']
45
  current_time = time.time()
46
  buffer_time = buffer_minutes * 60
47
  return (expiry_time - current_time) <= buffer_time
@@ -57,7 +64,9 @@ async def refresh_jwt_token() -> dict:
57
  # Prefer dynamic refresh token from environment if present
58
  env_refresh = os.getenv("WARP_REFRESH_TOKEN")
59
  if env_refresh:
60
- payload = f"grant_type=refresh_token&refresh_token={env_refresh}".encode("utf-8")
 
 
61
  else:
62
  payload = base64.b64decode(REFRESH_TOKEN_B64)
63
  headers = {
@@ -68,15 +77,11 @@ async def refresh_jwt_token() -> dict:
68
  "content-type": "application/x-www-form-urlencoded",
69
  "accept": "*/*",
70
  "accept-encoding": "gzip, br",
71
- "content-length": str(len(payload))
72
  }
73
  try:
74
  async with httpx.AsyncClient(timeout=30.0) as client:
75
- response = await client.post(
76
- REFRESH_URL,
77
- headers=headers,
78
- content=payload
79
- )
80
  if response.status_code == 200:
81
  token_data = response.json()
82
  logger.info("Token refresh successful")
@@ -91,9 +96,11 @@ async def refresh_jwt_token() -> dict:
91
 
92
 
93
  def update_env_file(new_jwt: str) -> bool:
94
- env_path = Path(".env")
 
 
95
  try:
96
- set_key(str(env_path), "WARP_JWT", new_jwt)
97
  logger.info("Updated .env file with new JWT token")
98
  return True
99
  except Exception as e:
@@ -102,9 +109,12 @@ def update_env_file(new_jwt: str) -> bool:
102
 
103
 
104
  def update_env_refresh_token(refresh_token: str) -> bool:
105
- env_path = Path(".env")
 
 
 
106
  try:
107
- set_key(str(env_path), "WARP_REFRESH_TOKEN", refresh_token)
108
  logger.info("Updated .env with WARP_REFRESH_TOKEN")
109
  return True
110
  except Exception as e:
@@ -137,11 +147,13 @@ async def check_and_refresh_token() -> bool:
137
  return False
138
  else:
139
  payload = decode_jwt_payload(current_jwt)
140
- if payload and 'exp' in payload:
141
- expiry_time = payload['exp']
142
  time_left = expiry_time - time.time()
143
  hours_left = time_left / 3600
144
- logger.debug(f"Current token is still valid ({hours_left:.1f} hours remaining)")
 
 
145
  else:
146
  logger.debug("Current token appears valid")
147
  return True
@@ -149,6 +161,7 @@ async def check_and_refresh_token() -> bool:
149
 
150
  async def get_valid_jwt() -> str:
151
  from dotenv import load_dotenv as _load
 
152
  _load(override=True)
153
  jwt = os.getenv("WARP_JWT")
154
  if not jwt:
@@ -164,14 +177,19 @@ async def get_valid_jwt() -> str:
164
  _load(override=True)
165
  jwt = os.getenv("WARP_JWT")
166
  if not jwt or is_token_expired(jwt, buffer_minutes=0):
167
- logger.warning("Warning: New token has short expiry but proceeding anyway")
 
 
168
  else:
169
- logger.warning("Warning: JWT token refresh failed, trying to use existing token")
 
 
170
  return jwt
171
 
172
 
173
  def get_jwt_token() -> str:
174
  from dotenv import load_dotenv as _load
 
175
  _load()
176
  return os.getenv("WARP_JWT", "")
177
 
@@ -187,13 +205,16 @@ async def refresh_jwt_if_needed() -> bool:
187
  # ============ Anonymous token acquisition (quota refresh) ============
188
 
189
  _ANON_GQL_URL = "https://app.warp.dev/graphql/v2?op=CreateAnonymousUser"
190
- _IDENTITY_TOOLKIT_BASE = "https://identitytoolkit.googleapis.com/v1/accounts:signInWithCustomToken"
 
 
191
 
192
 
193
  def _extract_google_api_key_from_refresh_url() -> str:
194
  try:
195
  # REFRESH_URL like: https://app.warp.dev/proxy/token?key=API_KEY
196
  from urllib.parse import urlparse, parse_qs
 
197
  parsed = urlparse(REFRESH_URL)
198
  qs = parse_qs(parsed.query)
199
  key = qs.get("key", [""])[0]
@@ -235,7 +256,7 @@ async def _create_anonymous_user() -> dict:
235
  "input": {
236
  "anonymousUserType": "NATIVE_CLIENT_ANONYMOUS_USER_FEATURE_GATED",
237
  "expirationType": "NO_EXPIRATION",
238
- "referralCode": None
239
  },
240
  "requestContext": {
241
  "clientContext": {"version": CLIENT_VERSION},
@@ -244,21 +265,31 @@ async def _create_anonymous_user() -> dict:
244
  "linuxKernelVersion": None,
245
  "name": OS_NAME,
246
  "version": OS_VERSION,
247
- }
248
- }
 
 
 
 
 
249
  }
250
- body = {"query": query, "variables": variables, "operationName": "CreateAnonymousUser"}
251
  async with httpx.AsyncClient(timeout=httpx.Timeout(30.0)) as client:
252
  resp = await client.post(_ANON_GQL_URL, headers=headers, json=body)
253
  if resp.status_code != 200:
254
- raise RuntimeError(f"CreateAnonymousUser failed: HTTP {resp.status_code} {resp.text[:200]}")
 
 
255
  data = resp.json()
256
  return data
257
 
258
 
259
  async def _exchange_id_token_for_refresh_token(id_token: str) -> dict:
260
  key = _extract_google_api_key_from_refresh_url()
261
- url = f"{_IDENTITY_TOOLKIT_BASE}?key={key}" if key else f"{_IDENTITY_TOOLKIT_BASE}?key=AIzaSyBdy3O3S9hrdayLJxJ7mriBR4qgUaUygAs"
 
 
 
 
262
  headers = {
263
  "accept-encoding": "gzip, br",
264
  "content-type": "application/x-www-form-urlencoded",
@@ -274,7 +305,9 @@ async def _exchange_id_token_for_refresh_token(id_token: str) -> dict:
274
  async with httpx.AsyncClient(timeout=httpx.Timeout(30.0)) as client:
275
  resp = await client.post(url, headers=headers, data=form)
276
  if resp.status_code != 200:
277
- raise RuntimeError(f"signInWithCustomToken failed: HTTP {resp.status_code} {resp.text[:200]}")
 
 
278
  return resp.json()
279
 
280
 
@@ -296,7 +329,9 @@ async def acquire_anonymous_access_token() -> str:
296
  signin = await _exchange_id_token_for_refresh_token(id_token)
297
  refresh_token = signin.get("refreshToken")
298
  if not refresh_token:
299
- raise RuntimeError(f"signInWithCustomToken did not return refreshToken: {signin}")
 
 
300
 
301
  # Persist refresh token for future time-based refreshes
302
  update_env_refresh_token(refresh_token)
@@ -311,12 +346,14 @@ async def acquire_anonymous_access_token() -> str:
311
  "content-type": "application/x-www-form-urlencoded",
312
  "accept": "*/*",
313
  "accept-encoding": "gzip, br",
314
- "content-length": str(len(payload))
315
  }
316
  async with httpx.AsyncClient(timeout=httpx.Timeout(30.0)) as client:
317
  resp = await client.post(REFRESH_URL, headers=headers, content=payload)
318
  if resp.status_code != 200:
319
- raise RuntimeError(f"Acquire access_token failed: HTTP {resp.status_code} {resp.text[:200]}")
 
 
320
  token_data = resp.json()
321
  access = token_data.get("access_token")
322
  if not access:
@@ -335,7 +372,7 @@ def print_token_info():
335
  logger.info("Cannot decode JWT token")
336
  return
337
  logger.info("=== JWT Token Information ===")
338
- if 'email' in payload:
339
  logger.info(f"Email: {payload['email']}")
340
- if 'user_id' in payload:
341
- logger.info(f"User ID: {payload['user_id']}")
 
15
  import asyncio
16
  from dotenv import load_dotenv, set_key
17
 
18
+ from ..config.settings import (
19
+ REFRESH_TOKEN_B64,
20
+ REFRESH_URL,
21
+ CLIENT_VERSION,
22
+ OS_CATEGORY,
23
+ OS_NAME,
24
+ OS_VERSION,
25
+ )
26
  from .logging import logger, log
27
 
28
 
29
  def decode_jwt_payload(token: str) -> dict:
30
  """Decode JWT payload to check expiration"""
31
  try:
32
+ parts = token.split(".")
33
  if len(parts) != 3:
34
  return {}
35
  payload_b64 = parts[1]
36
  padding = 4 - len(payload_b64) % 4
37
  if padding != 4:
38
+ payload_b64 += "=" * padding
39
  payload_bytes = base64.urlsafe_b64decode(payload_b64)
40
+ payload = json.loads(payload_bytes.decode("utf-8"))
41
  return payload
42
  except Exception as e:
43
  logger.debug(f"Error decoding JWT: {e}")
 
46
 
47
  def is_token_expired(token: str, buffer_minutes: int = 5) -> bool:
48
  payload = decode_jwt_payload(token)
49
+ if not payload or "exp" not in payload:
50
  return True
51
+ expiry_time = payload["exp"]
52
  current_time = time.time()
53
  buffer_time = buffer_minutes * 60
54
  return (expiry_time - current_time) <= buffer_time
 
64
  # Prefer dynamic refresh token from environment if present
65
  env_refresh = os.getenv("WARP_REFRESH_TOKEN")
66
  if env_refresh:
67
+ payload = f"grant_type=refresh_token&refresh_token={env_refresh}".encode(
68
+ "utf-8"
69
+ )
70
  else:
71
  payload = base64.b64decode(REFRESH_TOKEN_B64)
72
  headers = {
 
77
  "content-type": "application/x-www-form-urlencoded",
78
  "accept": "*/*",
79
  "accept-encoding": "gzip, br",
80
+ "content-length": str(len(payload)),
81
  }
82
  try:
83
  async with httpx.AsyncClient(timeout=30.0) as client:
84
+ response = await client.post(REFRESH_URL, headers=headers, content=payload)
 
 
 
 
85
  if response.status_code == 200:
86
  token_data = response.json()
87
  logger.info("Token refresh successful")
 
96
 
97
 
98
  def update_env_file(new_jwt: str) -> bool:
99
+ # 如果在Hugging Face环境(/tmp/.env存在),就用它,否则用本地的.env
100
+ ENV_FILE_PATH = "/tmp/.env" if os.path.exists("/tmp/.env") else ".env"
101
+ # env_path = Path(".env")
102
  try:
103
+ set_key(str(ENV_FILE_PATH), "WARP_JWT", new_jwt)
104
  logger.info("Updated .env file with new JWT token")
105
  return True
106
  except Exception as e:
 
109
 
110
 
111
  def update_env_refresh_token(refresh_token: str) -> bool:
112
+ # 如果在Hugging Face环境(/tmp/.env存在),就用它,否则用本地的.env
113
+ ENV_FILE_PATH = "/tmp/.env" if os.path.exists("/tmp/.env") else ".env"
114
+ # env_path = Path(".env")
115
+
116
  try:
117
+ set_key(str(ENV_FILE_PATH), "WARP_REFRESH_TOKEN", refresh_token)
118
  logger.info("Updated .env with WARP_REFRESH_TOKEN")
119
  return True
120
  except Exception as e:
 
147
  return False
148
  else:
149
  payload = decode_jwt_payload(current_jwt)
150
+ if payload and "exp" in payload:
151
+ expiry_time = payload["exp"]
152
  time_left = expiry_time - time.time()
153
  hours_left = time_left / 3600
154
+ logger.debug(
155
+ f"Current token is still valid ({hours_left:.1f} hours remaining)"
156
+ )
157
  else:
158
  logger.debug("Current token appears valid")
159
  return True
 
161
 
162
  async def get_valid_jwt() -> str:
163
  from dotenv import load_dotenv as _load
164
+
165
  _load(override=True)
166
  jwt = os.getenv("WARP_JWT")
167
  if not jwt:
 
177
  _load(override=True)
178
  jwt = os.getenv("WARP_JWT")
179
  if not jwt or is_token_expired(jwt, buffer_minutes=0):
180
+ logger.warning(
181
+ "Warning: New token has short expiry but proceeding anyway"
182
+ )
183
  else:
184
+ logger.warning(
185
+ "Warning: JWT token refresh failed, trying to use existing token"
186
+ )
187
  return jwt
188
 
189
 
190
  def get_jwt_token() -> str:
191
  from dotenv import load_dotenv as _load
192
+
193
  _load()
194
  return os.getenv("WARP_JWT", "")
195
 
 
205
  # ============ Anonymous token acquisition (quota refresh) ============
206
 
207
  _ANON_GQL_URL = "https://app.warp.dev/graphql/v2?op=CreateAnonymousUser"
208
+ _IDENTITY_TOOLKIT_BASE = (
209
+ "https://identitytoolkit.googleapis.com/v1/accounts:signInWithCustomToken"
210
+ )
211
 
212
 
213
  def _extract_google_api_key_from_refresh_url() -> str:
214
  try:
215
  # REFRESH_URL like: https://app.warp.dev/proxy/token?key=API_KEY
216
  from urllib.parse import urlparse, parse_qs
217
+
218
  parsed = urlparse(REFRESH_URL)
219
  qs = parse_qs(parsed.query)
220
  key = qs.get("key", [""])[0]
 
256
  "input": {
257
  "anonymousUserType": "NATIVE_CLIENT_ANONYMOUS_USER_FEATURE_GATED",
258
  "expirationType": "NO_EXPIRATION",
259
+ "referralCode": None,
260
  },
261
  "requestContext": {
262
  "clientContext": {"version": CLIENT_VERSION},
 
265
  "linuxKernelVersion": None,
266
  "name": OS_NAME,
267
  "version": OS_VERSION,
268
+ },
269
+ },
270
+ }
271
+ body = {
272
+ "query": query,
273
+ "variables": variables,
274
+ "operationName": "CreateAnonymousUser",
275
  }
 
276
  async with httpx.AsyncClient(timeout=httpx.Timeout(30.0)) as client:
277
  resp = await client.post(_ANON_GQL_URL, headers=headers, json=body)
278
  if resp.status_code != 200:
279
+ raise RuntimeError(
280
+ f"CreateAnonymousUser failed: HTTP {resp.status_code} {resp.text[:200]}"
281
+ )
282
  data = resp.json()
283
  return data
284
 
285
 
286
  async def _exchange_id_token_for_refresh_token(id_token: str) -> dict:
287
  key = _extract_google_api_key_from_refresh_url()
288
+ url = (
289
+ f"{_IDENTITY_TOOLKIT_BASE}?key={key}"
290
+ if key
291
+ else f"{_IDENTITY_TOOLKIT_BASE}?key=AIzaSyBdy3O3S9hrdayLJxJ7mriBR4qgUaUygAs"
292
+ )
293
  headers = {
294
  "accept-encoding": "gzip, br",
295
  "content-type": "application/x-www-form-urlencoded",
 
305
  async with httpx.AsyncClient(timeout=httpx.Timeout(30.0)) as client:
306
  resp = await client.post(url, headers=headers, data=form)
307
  if resp.status_code != 200:
308
+ raise RuntimeError(
309
+ f"signInWithCustomToken failed: HTTP {resp.status_code} {resp.text[:200]}"
310
+ )
311
  return resp.json()
312
 
313
 
 
329
  signin = await _exchange_id_token_for_refresh_token(id_token)
330
  refresh_token = signin.get("refreshToken")
331
  if not refresh_token:
332
+ raise RuntimeError(
333
+ f"signInWithCustomToken did not return refreshToken: {signin}"
334
+ )
335
 
336
  # Persist refresh token for future time-based refreshes
337
  update_env_refresh_token(refresh_token)
 
346
  "content-type": "application/x-www-form-urlencoded",
347
  "accept": "*/*",
348
  "accept-encoding": "gzip, br",
349
+ "content-length": str(len(payload)),
350
  }
351
  async with httpx.AsyncClient(timeout=httpx.Timeout(30.0)) as client:
352
  resp = await client.post(REFRESH_URL, headers=headers, content=payload)
353
  if resp.status_code != 200:
354
+ raise RuntimeError(
355
+ f"Acquire access_token failed: HTTP {resp.status_code} {resp.text[:200]}"
356
+ )
357
  token_data = resp.json()
358
  access = token_data.get("access_token")
359
  if not access:
 
372
  logger.info("Cannot decode JWT token")
373
  return
374
  logger.info("=== JWT Token Information ===")
375
+ if "email" in payload:
376
  logger.info(f"Email: {payload['email']}")
377
+ if "user_id" in payload:
378
+ logger.info(f"User ID: {payload['user_id']}")