hzruo commited on
Commit
71001db
·
verified ·
1 Parent(s): c338bcb

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +448 -0
app.py ADDED
@@ -0,0 +1,448 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SambaNova OpenAI 接口代理 (支持模型列表透传和自动登录)
3
+ """
4
+
5
+ import os
6
+ import uuid
7
+ import json
8
+ import time
9
+ import asyncio
10
+ import httpx
11
+ import secrets
12
+ import urllib.parse
13
+ from typing import Optional, Dict, Any
14
+ from fastapi import FastAPI, Request, HTTPException, Depends, Header
15
+ from fastapi.responses import StreamingResponse, JSONResponse
16
+ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
17
+ from fake_useragent import UserAgent
18
+ # 修复 Pydantic 导入
19
+ try:
20
+ # 尝试从 pydantic-settings 导入 (Pydantic v2)
21
+ from pydantic_settings import BaseSettings
22
+ except ImportError:
23
+ # 回退到旧版本 (Pydantic v1)
24
+ from pydantic import BaseSettings
25
+
26
+ # ================ 配置 ================
27
+ class Settings(BaseSettings):
28
+ # SambaNova 配置
29
+ SAMBA_EMAIL: str = os.getenv("SAMBA_EMAIL", "")
30
+ SAMBA_PASSWORD: str = os.getenv("SAMBA_PASSWORD", "")
31
+ SAMBA_COMPLETION_URL: str = os.getenv("SAMBA_COMPLETION_URL", "https://cloud.sambanova.ai/api/completion")
32
+ SAMBA_MODELS_URL: str = os.getenv("SAMBA_MODELS_URL", "https://api.sambanova.ai/v1/models")
33
+
34
+ # 本地API密钥配置
35
+ LOCAL_API_KEY: str = os.getenv("LOCAL_API_KEY", secrets.token_urlsafe(32))
36
+
37
+ # 其他配置
38
+ TOKEN_CACHE_TIME: int = int(os.getenv("TOKEN_CACHE_TIME", 3600)) # 默认缓存1小时
39
+ FINGERPRINT_PREFIX: str = os.getenv("FINGERPRINT_PREFIX", "anon_")
40
+
41
+ class Config:
42
+ env_file = ".env"
43
+
44
+ settings = Settings()
45
+ # =====================================
46
+
47
+ app = FastAPI(title="SambaNova OpenAI Proxy with Auto-Login")
48
+ security = HTTPBearer()
49
+
50
+ # 全局变量存储访问令牌和过期时间
51
+ access_token = None
52
+ token_expiry = 0
53
+ token_lock = asyncio.Lock()
54
+
55
+ def generate_fingerprint() -> str:
56
+ """生成符合格式要求的随机指纹"""
57
+ return f"{settings.FINGERPRINT_PREFIX}{uuid.uuid4().hex[:20]}"
58
+
59
+ async def validate_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)) -> str:
60
+ """验证本地API密钥并返回SambaNova访问令牌"""
61
+ api_key = credentials.credentials
62
+
63
+ # 如果未配置本地API密钥或为空,则跳过验证
64
+ if settings.LOCAL_API_KEY and settings.LOCAL_API_KEY.strip():
65
+ # 验证本地API密钥
66
+ if api_key != settings.LOCAL_API_KEY:
67
+ raise HTTPException(
68
+ status_code=401,
69
+ detail="Invalid API key",
70
+ headers={"WWW-Authenticate": "Bearer"},
71
+ )
72
+ else:
73
+ print("[警告] LOCAL_API_KEY未配置或为空,跳过API密钥验证")
74
+
75
+ # 获取或刷新SambaNova访问令牌
76
+ token = await get_samba_token()
77
+ if not token:
78
+ raise HTTPException(
79
+ status_code=500,
80
+ detail="Failed to obtain SambaNova access token. Check server logs for details."
81
+ )
82
+
83
+ return token
84
+
85
+ async def get_samba_token() -> Optional[str]:
86
+ """获取或刷新SambaNova访问令牌"""
87
+ global access_token, token_expiry
88
+
89
+ # 使用锁防止并发请求同时刷新令牌
90
+ async with token_lock:
91
+ current_time = time.time()
92
+
93
+ # 如果令牌有效,直接返回
94
+ if access_token and current_time < token_expiry:
95
+ print(f"[令牌] 使用缓存令牌: {access_token}")
96
+ return access_token
97
+
98
+ # 否则获取新令牌
99
+ try:
100
+ # 检查凭据是否已配置
101
+ if not settings.SAMBA_EMAIL or not settings.SAMBA_PASSWORD:
102
+ print("[错误] 未配置SambaNova凭据,请设置SAMBA_EMAIL和SAMBA_PASSWORD环境变量")
103
+ return None
104
+
105
+ print(f"[令牌] 开始获取新令牌... 邮箱: {settings.SAMBA_EMAIL}")
106
+ auth = SambaAuthAsync(settings.SAMBA_EMAIL, settings.SAMBA_PASSWORD)
107
+ new_token = await auth.login()
108
+
109
+ if new_token:
110
+ access_token = new_token
111
+ token_expiry = current_time + settings.TOKEN_CACHE_TIME
112
+ print(f"[令牌更新成功] 完整令牌: {new_token}")
113
+ print(f"[令牌更新成功] 令牌将在 {settings.TOKEN_CACHE_TIME} 秒后过期")
114
+ return access_token
115
+ else:
116
+ print("[令牌获取失败] 请检查SambaNova凭据是否正确")
117
+ return None
118
+ except Exception as e:
119
+ print(f"[令牌获取异常] {str(e)}")
120
+ return None
121
+
122
+ def reset_token_expiry():
123
+ """重置令牌过期时间,强制下次请求重新获取令牌"""
124
+ global token_expiry
125
+ token_expiry = 0
126
+ print("[令牌] 令牌已过期,将在下次请求时重新获取")
127
+
128
+ async def forward_get_request(url: str, token: str) -> httpx.Response:
129
+ """转发 GET 请求到目标接口"""
130
+ headers = {
131
+ "accept": "application/json",
132
+ "user-agent": "SambaNova-Proxy/1.0",
133
+ "origin": "https://cloud.sambanova.ai",
134
+ "referer": "https://cloud.sambanova.ai/"
135
+ }
136
+
137
+ cookies = {
138
+ "access_token": token
139
+ }
140
+
141
+ async with httpx.AsyncClient() as client:
142
+ try:
143
+ resp = await client.get(
144
+ url,
145
+ headers=headers,
146
+ cookies=cookies,
147
+ timeout=10.0
148
+ )
149
+
150
+ # 检查是否需要刷新令牌
151
+ if resp.status_code == 401:
152
+ # 令牌已过期,需要刷新
153
+ reset_token_expiry()
154
+ raise HTTPException(401, "Token expired, please retry")
155
+
156
+ resp.raise_for_status()
157
+ return resp
158
+ except httpx.HTTPStatusError as e:
159
+ if e.response.status_code == 401:
160
+ # 令牌已过期,需要刷新
161
+ reset_token_expiry()
162
+ raise HTTPException(401, "Token expired, please retry")
163
+ raise HTTPException(e.response.status_code, f"Upstream error: {e.response.text}")
164
+
165
+ async def forward_post_request(url: str, payload: dict, token: str) -> httpx.Response:
166
+ """转发 POST 请求到目标接口"""
167
+ headers = {
168
+ "content-type": "application/json",
169
+ "user-agent": "SambaNova-Proxy/1.0",
170
+ "origin": "https://cloud.sambanova.ai",
171
+ "referer": "https://cloud.sambanova.ai/"
172
+ }
173
+
174
+ cookies = {
175
+ "access_token": token
176
+ }
177
+
178
+ async with httpx.AsyncClient() as client:
179
+ try:
180
+ resp = await client.post(
181
+ url,
182
+ json=payload,
183
+ headers=headers,
184
+ cookies=cookies,
185
+ timeout=30.0
186
+ )
187
+
188
+ # 检查是否需要刷新令牌
189
+ if resp.status_code == 401:
190
+ # 令牌已过期,需要刷新
191
+ reset_token_expiry()
192
+ raise HTTPException(401, "Token expired, please retry")
193
+
194
+ resp.raise_for_status()
195
+ return resp
196
+ except httpx.HTTPStatusError as e:
197
+ if e.response.status_code == 401:
198
+ # 令牌已过期,需要刷新
199
+ reset_token_expiry()
200
+ raise HTTPException(401, "Token expired, please retry")
201
+ raise HTTPException(e.response.status_code, f"Upstream error: {e.response.text}")
202
+
203
+ @app.get("/v1/models")
204
+ async def list_models(token: str = Depends(validate_api_key)):
205
+ """透传模型列表接口"""
206
+ try:
207
+ resp = await forward_get_request(settings.SAMBA_MODELS_URL, token)
208
+ content = resp.json()
209
+ json_str = json.dumps(content, separators=(',', ':'), ensure_ascii=False)
210
+ json_bytes = json_str.encode('utf-8')
211
+ return JSONResponse(
212
+ content=content,
213
+ headers={
214
+ "Content-Type": "application/json",
215
+ "Content-Length": str(len(json_bytes)),
216
+ "Cache-Control": "public, max-age=300"
217
+ }
218
+ )
219
+ except httpx.RequestError as e:
220
+ raise HTTPException(504, f"Gateway timeout: {str(e)}")
221
+ except Exception as e:
222
+ raise HTTPException(500, f"Internal server error: {str(e)}")
223
+
224
+ @app.post("/v1/chat/completions")
225
+ async def chat_completions(
226
+ request: Request,
227
+ token: str = Depends(validate_api_key)
228
+ ):
229
+ """处理对话请求"""
230
+ try:
231
+ openai_payload = await request.json()
232
+ print(f"[请求] 收到聊天请求,模型: {openai_payload.get('model', 'DeepSeek-R1')}")
233
+
234
+ samba_payload = {
235
+ "body": {
236
+ "model": openai_payload.get("model", "DeepSeek-R1"),
237
+ "messages": openai_payload["messages"],
238
+ "stream": True,
239
+ "stop": openai_payload.get("stop", ["<|eot_id|>"]),
240
+ "temperature": openai_payload.get("temperature", 0),
241
+ "max_tokens": openai_payload.get("max_tokens", 2048),
242
+ "do_sample": openai_payload.get("temperature", 0) > 0
243
+ },
244
+ "env_type": "text",
245
+ "fingerprint": generate_fingerprint()
246
+ }
247
+
248
+ print(f"[转发] 使用令牌 {token[:10]}... 转发请求到 SambaNova")
249
+ resp = await forward_post_request(settings.SAMBA_COMPLETION_URL, samba_payload, token)
250
+ print(f"[响应] 成功获取响应,开始流式传输")
251
+
252
+ return StreamingResponse(
253
+ resp.aiter_bytes(),
254
+ media_type="text/event-stream",
255
+ headers={
256
+ "X-Proxy-Version": "1.0",
257
+ "X-Request-ID": str(uuid.uuid4())
258
+ }
259
+ )
260
+ except HTTPException as e:
261
+ print(f"[错误] HTTP异常: {e.detail}")
262
+ raise
263
+ except httpx.RequestError as e:
264
+ print(f"[错误] 请求错误: {str(e)}")
265
+ raise HTTPException(504, f"Gateway timeout: {str(e)}")
266
+ except Exception as e:
267
+ print(f"[错误] 未处理异常: {str(e)}")
268
+ raise HTTPException(500, f"Internal server error: {str(e)}")
269
+
270
+ @app.get("/info")
271
+ async def get_info():
272
+ """获取服务信息"""
273
+ return {
274
+ "status": "running",
275
+ "api_key_configured": bool(settings.LOCAL_API_KEY),
276
+ "samba_credentials_configured": bool(settings.SAMBA_EMAIL and settings.SAMBA_PASSWORD),
277
+ "token_status": "active" if access_token and time.time() < token_expiry else "not_available",
278
+ "token_expires_in": max(0, int(token_expiry - time.time())) if access_token else 0
279
+ }
280
+
281
+ @app.get("/debug/token", include_in_schema=False)
282
+ async def debug_token():
283
+ """调试端点:检查当前令牌状态"""
284
+ global access_token, token_expiry
285
+ current_time = time.time()
286
+
287
+ return {
288
+ "token_exists": access_token is not None,
289
+ "token_prefix": access_token[:10] + "..." if access_token else None,
290
+ "token_valid": access_token is not None and current_time < token_expiry,
291
+ "expires_in_seconds": max(0, int(token_expiry - current_time)) if access_token else 0,
292
+ "current_time": current_time,
293
+ "expiry_time": token_expiry,
294
+ }
295
+
296
+ class SambaAuthAsync:
297
+ def __init__(self, email, password):
298
+ self.email = email
299
+ self.password = password
300
+ self.client = httpx.AsyncClient()
301
+ self.ua = UserAgent()
302
+ self.base_headers = {
303
+ "accept": "*/*",
304
+ "accept-language": "zh-CN,zh;q=0.9,en-US;q=0.8,en;q=0.7",
305
+ "origin": "https://cloud.sambanova.ai",
306
+ "referer": "https://cloud.sambanova.ai/",
307
+ "user-agent": self.ua.random
308
+ }
309
+ self.config = None
310
+ self.nonce = None # 确保nonce属性存在
311
+
312
+ async def _get_config(self):
313
+ """获取动态配置信息"""
314
+ config_url = "https://cloud.sambanova.ai/api/config"
315
+ response = await self.client.get(config_url, headers=self.base_headers)
316
+ response.raise_for_status()
317
+ self.config = response.json()
318
+ print(f"[配置获取成功] ClientID: {self.config['clientId']}")
319
+
320
+ async def _get_login_ticket(self):
321
+ """获取登录票据"""
322
+ auth_url = f"https://{self.config['issuerBaseUrl']}/co/authenticate"
323
+ payload = {
324
+ "client_id": self.config["clientId"],
325
+ "username": self.email,
326
+ "password": self.password,
327
+ "realm": "Username-Password-Authentication",
328
+ "credential_type": "http://auth0.com/oauth/grant-type/password-realm"
329
+ }
330
+
331
+ headers = {**self.base_headers, "content-type": "application/json"}
332
+
333
+ response = await self.client.post(auth_url, headers=headers, json=payload)
334
+ response.raise_for_status()
335
+ return response.json()["login_ticket"]
336
+
337
+ async def _get_auth_code(self, login_ticket: str):
338
+ """获取授权码"""
339
+ state = secrets.token_urlsafe(32)
340
+ self.nonce = secrets.token_urlsafe(32) # 保存nonce到实例变量
341
+
342
+ params = {
343
+ "client_id": self.config["clientId"],
344
+ "response_type": "code",
345
+ "redirect_uri": self.config["redirectURL"],
346
+ "scope": "openid profile email",
347
+ "nonce": self.nonce,
348
+ "state": state,
349
+ "login_ticket": login_ticket,
350
+ "realm": "Username-Password-Authentication",
351
+ "auth0Client": "eyJuYW1lIjoibG9jay5qcyIsInZlcnNpb24iOiIxMi4zLjAiLCJlbnYiOnsiYXV0aDAuanMiOiI5LjIyLjEiLCJhdXRoMC5qcy11bHAiOiI5LjIyLjEifX0="
352
+ }
353
+
354
+ auth_url = f"https://{self.config['issuerBaseUrl']}/authorize"
355
+ response = await self.client.get(
356
+ auth_url,
357
+ params=params,
358
+ follow_redirects=False
359
+ )
360
+
361
+ if response.status_code == 302:
362
+ location = response.headers["location"]
363
+ parsed = urllib.parse.urlparse(location)
364
+ query = urllib.parse.parse_qs(parsed.query)
365
+ return query.get("code", [None])[0], state
366
+ raise Exception(f"未收到302重定向,实际状态码:{response.status_code}")
367
+
368
+ async def _exchange_token(self, code: str, state: str):
369
+ """交换访问令牌"""
370
+ # 设置必要的cookies
371
+ self.client.cookies.set("nonce", self.nonce, domain="cloud.sambanova.ai")
372
+
373
+ callback_url = f"{self.config['redirectURL']}?code={code}&state={state}"
374
+ response = await self.client.get(
375
+ callback_url,
376
+ headers={
377
+ **self.base_headers,
378
+ "sec-fetch-site": "same-site",
379
+ "sec-fetch-mode": "navigate",
380
+ "sec-fetch-user": "?1",
381
+ "sec-fetch-dest": "document"
382
+ },
383
+ follow_redirects=True
384
+ )
385
+
386
+ # 从cookies中提取access_token
387
+ for cookie in self.client.cookies.jar:
388
+ if cookie.name == "access_token" and "sambanova.ai" in cookie.domain:
389
+ return cookie.value
390
+ raise Exception("未找到access_token")
391
+
392
+ async def login(self):
393
+ """完整登录流程"""
394
+ try:
395
+ await self._get_config()
396
+ login_ticket = await self._get_login_ticket()
397
+ print(f"[登录票据获取成功] 完整票据: {login_ticket}")
398
+
399
+ auth_code, state = await self._get_auth_code(login_ticket)
400
+ if not auth_code:
401
+ raise Exception("授权码获取失败")
402
+ print(f"[授权码获取成功] 完整授权码: {auth_code}")
403
+ print(f"[授权状态] state: {state}")
404
+
405
+ token = await self._exchange_token(auth_code, state)
406
+ print(f"[令牌获取成功] 完整令牌: {token}")
407
+ return token
408
+
409
+ except Exception as e:
410
+ print(f"[登录失败] 详细错误: {str(e)}")
411
+ return None
412
+ finally:
413
+ await self.client.aclose()
414
+
415
+ @app.on_event("startup")
416
+ async def startup_event():
417
+ """应用启动时预获取令牌"""
418
+ print("\n" + "="*50)
419
+ print("[启动] SambaNova OpenAI 代理服务启动")
420
+ print("="*50)
421
+
422
+ # 检查环境变量
423
+ print(f"[环境] SAMBA_EMAIL: {'已设置' if settings.SAMBA_EMAIL else '未设置'}")
424
+ print(f"[环境] SAMBA_PASSWORD: {'已设置' if settings.SAMBA_PASSWORD else '未设置'}")
425
+ print(f"[环境] LOCAL_API_KEY: {'已设置' if settings.LOCAL_API_KEY else '未设置'}")
426
+
427
+ # 尝试直接登录
428
+ print("[登录] 开始尝试登录...")
429
+ try:
430
+ auth = SambaAuthAsync(settings.SAMBA_EMAIL, settings.SAMBA_PASSWORD)
431
+ token = await auth.login()
432
+
433
+ if token:
434
+ global access_token, token_expiry
435
+ access_token = token
436
+ token_expiry = time.time() + settings.TOKEN_CACHE_TIME
437
+ print(f"[登录] 登录成功! 令牌: {token}")
438
+ print(f"[登录] 令牌将在 {settings.TOKEN_CACHE_TIME} 秒后过期")
439
+ else:
440
+ print("[登录] 登录失败,未获取到令牌")
441
+ except Exception as e:
442
+ print(f"[登录] 登录过程发生异常: {str(e)}")
443
+
444
+ print("="*50 + "\n")
445
+
446
+ if __name__ == "__main__":
447
+ import uvicorn
448
+ uvicorn.run(app, host="0.0.0.0", port=7860)