CassiopeiaCode commited on
Commit
1bdfa3a
·
1 Parent(s): 59d6efd

feat: 添加HTTP代理支持

Browse files

- 添加HTTP_PROXY环境变量配置
- 在所有requests请求中使用代理(留空则不使用)
- 支持replicate.py、app.py和auth_flow.py中的所有HTTP请求

Files changed (4) hide show
  1. .env.example +5 -1
  2. app.py +7 -1
  3. auth_flow.py +8 -1
  4. replicate.py +9 -1
.env.example CHANGED
@@ -4,4 +4,8 @@
4
  OPENAI_KEYS=""
5
 
6
  # 出错次数阈值,超过此值自动禁用账号
7
- MAX_ERROR_COUNT=100
 
 
 
 
 
4
  OPENAI_KEYS=""
5
 
6
  # 出错次数阈值,超过此值自动禁用账号
7
+ MAX_ERROR_COUNT=100
8
+
9
+ # HTTP代理设置(留空不使用代理)
10
+ # 例如:HTTP_PROXY="http://127.0.0.1:7890"
11
+ HTTP_PROXY=""
app.py CHANGED
@@ -193,6 +193,12 @@ class ChatCompletionRequest(BaseModel):
193
  OIDC_BASE = "https://oidc.us-east-1.amazonaws.com"
194
  TOKEN_URL = f"{OIDC_BASE}/token"
195
 
 
 
 
 
 
 
196
  def _oidc_headers() -> Dict[str, str]:
197
  return {
198
  "content-type": "application/json",
@@ -220,7 +226,7 @@ def refresh_access_token_in_db(account_id: str) -> Dict[str, Any]:
220
  }
221
 
222
  try:
223
- r = requests.post(TOKEN_URL, headers=_oidc_headers(), json=payload, timeout=(15, 60))
224
  r.raise_for_status()
225
  data = r.json()
226
  new_access = data.get("accessToken")
 
193
  OIDC_BASE = "https://oidc.us-east-1.amazonaws.com"
194
  TOKEN_URL = f"{OIDC_BASE}/token"
195
 
196
+ def _get_proxies() -> Optional[Dict[str, str]]:
197
+ proxy = os.getenv("HTTP_PROXY", "").strip()
198
+ if proxy:
199
+ return {"http": proxy, "https": proxy}
200
+ return None
201
+
202
  def _oidc_headers() -> Dict[str, str]:
203
  return {
204
  "content-type": "application/json",
 
226
  }
227
 
228
  try:
229
+ r = requests.post(TOKEN_URL, headers=_oidc_headers(), json=payload, timeout=(15, 60), proxies=_get_proxies())
230
  r.raise_for_status()
231
  data = r.json()
232
  new_access = data.get("accessToken")
auth_flow.py CHANGED
@@ -1,10 +1,17 @@
1
  import json
2
  import time
3
  import uuid
 
4
  from typing import Dict, Tuple, Optional
5
 
6
  import requests
7
 
 
 
 
 
 
 
8
  # OIDC endpoints and constants (aligned with v1/auth_client.py)
9
  OIDC_BASE = "https://oidc.us-east-1.amazonaws.com"
10
  REGISTER_URL = f"{OIDC_BASE}/client/register"
@@ -31,7 +38,7 @@ def post_json(url: str, payload: Dict) -> requests.Response:
31
  # Keep JSON order and mimic body closely to v1
32
  payload_str = json.dumps(payload, ensure_ascii=False)
33
  headers = make_headers()
34
- resp = requests.post(url, headers=headers, data=payload_str, timeout=(15, 60))
35
  return resp
36
 
37
 
 
1
  import json
2
  import time
3
  import uuid
4
+ import os
5
  from typing import Dict, Tuple, Optional
6
 
7
  import requests
8
 
9
+ def _get_proxies() -> Optional[Dict[str, str]]:
10
+ proxy = os.getenv("HTTP_PROXY", "").strip()
11
+ if proxy:
12
+ return {"http": proxy, "https": proxy}
13
+ return None
14
+
15
  # OIDC endpoints and constants (aligned with v1/auth_client.py)
16
  OIDC_BASE = "https://oidc.us-east-1.amazonaws.com"
17
  REGISTER_URL = f"{OIDC_BASE}/client/register"
 
38
  # Keep JSON order and mimic body closely to v1
39
  payload_str = json.dumps(payload, ensure_ascii=False)
40
  headers = make_headers()
41
+ resp = requests.post(url, headers=headers, data=payload_str, timeout=(15, 60), proxies=_get_proxies())
42
  return resp
43
 
44
 
replicate.py CHANGED
@@ -1,5 +1,6 @@
1
  import json
2
  import uuid
 
3
  from pathlib import Path
4
  from typing import Dict, Optional, Tuple, Iterator, List, Generator, Any
5
  import struct
@@ -15,6 +16,12 @@ class StreamTracker:
15
  self.has_content = True
16
  yield item
17
 
 
 
 
 
 
 
18
  BASE_DIR = Path(__file__).resolve().parent
19
  TEMPLATE_PATH = BASE_DIR / "templates" / "streaming_request.json"
20
 
@@ -198,7 +205,8 @@ def send_chat_request(access_token: str, messages: List[Dict[str, Any]], model:
198
  payload_str = json.dumps(body_json, ensure_ascii=False)
199
  headers = _merge_headers(headers_from_log, access_token)
200
  session = requests.Session()
201
- resp = session.post(url, headers=headers, data=payload_str, stream=True, timeout=timeout)
 
202
  if resp.status_code >= 400:
203
  try:
204
  err = resp.text
 
1
  import json
2
  import uuid
3
+ import os
4
  from pathlib import Path
5
  from typing import Dict, Optional, Tuple, Iterator, List, Generator, Any
6
  import struct
 
16
  self.has_content = True
17
  yield item
18
 
19
+ def _get_proxies() -> Optional[Dict[str, str]]:
20
+ proxy = os.getenv("HTTP_PROXY", "").strip()
21
+ if proxy:
22
+ return {"http": proxy, "https": proxy}
23
+ return None
24
+
25
  BASE_DIR = Path(__file__).resolve().parent
26
  TEMPLATE_PATH = BASE_DIR / "templates" / "streaming_request.json"
27
 
 
205
  payload_str = json.dumps(body_json, ensure_ascii=False)
206
  headers = _merge_headers(headers_from_log, access_token)
207
  session = requests.Session()
208
+ proxies = _get_proxies()
209
+ resp = session.post(url, headers=headers, data=payload_str, stream=True, timeout=timeout, proxies=proxies)
210
  if resp.status_code >= 400:
211
  try:
212
  err = resp.text