Spaces:
Sleeping
Sleeping
| import codecs | |
| import hashlib | |
| import json | |
| import os | |
| import ssl | |
| import uuid | |
| from datetime import datetime | |
| # from http.client import HTTPException | |
| from typing import Dict, Any, Optional | |
| import httpx | |
| from fastapi import HTTPException | |
| from httpx import ConnectError, TransportError | |
| from starlette import status | |
| from core.config import get_settings | |
| from core.logger import setup_logger | |
| from core.models import ChatRequest | |
| # from rich import print | |
# Module-wide application settings and a logger named after this module,
# both created once when the module is imported.
settings = get_settings()
logger = setup_logger(__name__)
def decode_unicode_escape(s):
    r"""Decode backslash escape sequences (e.g. ``\u4e2d``, ``\n``) in *s*.

    - ``dict`` input is returned unchanged.
    - Any other input that is neither ``str`` nor ``bytes`` is converted with
      ``str()`` first.
    - ``str`` input: characters that are already decoded (including non-ASCII
      text such as CJK) are preserved; only escape sequences are decoded.
    - ``bytes`` input is passed to the ``unicode_escape`` codec directly, which
      interprets non-escape bytes as Latin-1.

    Returns:
        The decoded ``str``, or the original object for ``dict`` input.
    """
    if isinstance(s, dict):
        return s
    if not isinstance(s, (str, bytes)):
        s = str(s)
    if isinstance(s, str):
        # The 'unicode_escape' codec reads raw bytes as Latin-1, so encoding a
        # str as UTF-8 first would turn e.g. "é" into "Ã©". Encoding as Latin-1
        # with 'backslashreplace' keeps Latin-1 characters as single bytes and
        # turns every other character into a \uXXXX / \UXXXXXXXX escape that
        # the codec decodes back to the same character.
        s = s.encode('latin-1', 'backslashreplace')
    return codecs.decode(s, 'unicode_escape')
# Firebase Web API key, read once from settings; passed as the `key` query
# parameter by the Firebase REST calls in this module.
FIREBASE_API_KEY = settings.FIREBASE_API_KEY
async def refresh_token_via_rest(refresh_token):
    """Exchange comma-separated Firebase refresh tokens for fresh ID tokens.

    Each non-blank entry of *refresh_token* is posted to the Firebase Secure
    Token REST endpoint with ``grant_type=refresh_token``. The operation is
    all-or-nothing: the first failure aborts the whole batch.

    Args:
        refresh_token: comma-separated refresh tokens. Surrounding whitespace
            and blank entries (e.g. from a trailing comma) are ignored.

    Returns:
        The resulting ``id_token`` values joined with ``','`` in input order,
        or ``None`` if there are no tokens, any request returns a non-200
        status, or any other error occurs (network, JSON, missing key).
    """
    if not refresh_token:
        return None
    refresh_tokens = [t.strip() for t in refresh_token.split(',') if t.strip()]
    if not refresh_tokens:
        return None
    logger.info(f"refresh token length is {len(refresh_tokens)}")

    # Firebase Auth REST API endpoint; identical for every token, so built once.
    url = f"https://securetoken.googleapis.com/v1/token?key={FIREBASE_API_KEY}"
    id_tokens = []
    try:
        # One client for the whole batch so its connection pool is reused.
        async with httpx.AsyncClient() as client:
            for token in refresh_tokens:
                payload = {
                    'grant_type': 'refresh_token',
                    'refresh_token': token,
                }
                response = await client.post(url, json=payload)
                if response.status_code != 200:
                    logger.error(f"刷新失败: {response.text}")
                    return None
                # The response body holds live id/refresh tokens, so it is
                # deliberately not logged.
                id_tokens.append(response.json()['id_token'])
    except Exception as exc:
        # Best-effort by design: any failure maps to None for the caller.
        logger.error(f"请求异常: {exc}")
        return None
    return ','.join(id_tokens)
async def sign_in_with_idp():
    """Sign in to Firebase with the configured Google ID token via signInWithIdp.

    Posts ``settings.AUTHORIZATION_TOKEN`` as a ``google.com`` ``id_token`` to
    the Identity Toolkit ``accounts:signInWithIdp`` endpoint, authenticated
    with ``FIREBASE_API_KEY``.

    Returns:
        The parsed JSON response body when the endpoint answers HTTP 200.

    Raises:
        Exception: for any non-200 status; the message includes the status
            code and the response body.
    """
    url = "https://identitytoolkit.googleapis.com/v1/accounts:signInWithIdp"
    params = {"key": FIREBASE_API_KEY}
    # These header values appear to mimic the Firebase JS SDK web client.
    headers = {
        "X-Client-Version": "Node/JsCore/10.5.2/FirebaseCore-web",
        "X-Firebase-gmpid": "1:123807869619:web:43b278a622ed6322789ec6",
        "Content-Type": "application/json",
        "User-Agent": "node-fetch/1.0 (+https://github.com/bitinn/node-fetch)",
    }
    data = {
        "requestUri": "http://localhost",
        "returnSecureToken": True,
        "postBody": f"&id_token={settings.AUTHORIZATION_TOKEN}&providerId=google.com",
    }
    # Neither `data` (contains the Google ID token) nor `params` (contains the
    # API key) is printed: both are credentials.
    async with httpx.AsyncClient() as client:
        response = await client.post(url, params=params, headers=headers, json=data)

    if response.status_code == 200:
        return response.json()
    raise Exception(
        f"Request failed with status code: {response.status_code}, body: {response.text}"
    )
def _id_token_from_payload(data, case_label):
    """Store ``refreshToken`` (if present) in the environment and return ``idToken``.

    Raises:
        ValueError: if *data* has no ``idToken``; *case_label* prefixes the message.
    """
    if 'refreshToken' in data:
        # NOTE(review): this overwrites REFRESH_TOKEN with a single token, while
        # other code in this module splits it on ',' as a list — confirm intent.
        os.environ["REFRESH_TOKEN"] = data['refreshToken']
    if 'idToken' in data:
        return data['idToken']
    raise ValueError(f"{case_label} case Response does not contain idToken")


async def handle_firebase_response(response) -> str:
    """Extract the Firebase ``idToken`` from an auth response.

    Accepts either an already-parsed JSON ``dict`` or a response object that
    exposes ``status_code`` and ``json()``. When the payload contains a
    ``refreshToken`` it is written to ``os.environ["REFRESH_TOKEN"]``.

    Returns:
        The ``idToken`` string.

    Raises:
        ValueError: for error payloads, non-200 statuses, a missing ``idToken``,
            invalid JSON, unexpected response types, or any other processing
            error. Errors raised here are not wrapped a second time.
    """
    try:
        if isinstance(response, dict):
            error = response.get('error')
            if isinstance(error, dict) and error.get('code') == 400:
                raise ValueError(
                    f"Invalid id_token in IdP response: {error.get('message', 'Unknown error')}"
                )
            return _id_token_from_payload(response, "dict")

        if not hasattr(response, 'status_code'):
            raise ValueError(f"Unexpected response type: {type(response)}")

        code = response.status_code
        if code == 200:
            return _id_token_from_payload(response.json(), "response")
        if code == 400:
            error_data = response.json()
            raise ValueError(
                f"Bad Request: {error_data.get('error', {}).get('message', 'Unknown error')}"
            )
        fixed_messages = {
            401: "Unauthorized: Invalid credentials",
            403: "Forbidden: Insufficient permissions",
            404: "Not Found: Resource doesn't exist",
        }
        raise ValueError(fixed_messages.get(code, f"Unexpected status code: {code}"))
    except json.JSONDecodeError as exc:
        # Must precede `except ValueError`: JSONDecodeError is a ValueError subclass.
        raise ValueError("Invalid JSON response") from exc
    except ValueError:
        # Already a descriptive error from above; re-raise unchanged.
        raise
    except Exception as exc:
        raise ValueError(f"Error processing response: {exc}") from exc
| # SHA-256 | |
| def _sha256_hash(text): | |
| sha256 = hashlib.sha256() | |
| sha256.update(text.encode('utf-8')) | |
| return sha256.hexdigest() | |
def sha256_hash_messages(messages):
    """Return a SHA-256 hex digest identifying the user turns of a conversation.

    Only messages whose ``role`` is ``"user"`` contribute. Their ``content``
    values are converted with ``str()``, serialized as a JSON list in their
    original order, and hashed as UTF-8.

    Args:
        messages: iterable of mapping-like messages with ``'role'`` and
            ``'content'`` keys (a missing key raises ``KeyError``).
    """
    user_contents = [str(msg['content']) for msg in messages if msg['role'] == "user"]
    # User prompt text is intentionally not printed (the former debug output
    # leaked every conversation to stdout on each call).
    serialized = json.dumps(user_contents, sort_keys=True)
    return hashlib.sha256(serialized.encode('utf-8')).hexdigest()
def create_chat_completion_data(
    content: str, model: str, timestamp: int, finish_reason: Optional[str] = None
) -> Dict[str, Any]:
    """Build one OpenAI-style ``chat.completion.chunk`` payload.

    Args:
        content: text placed in the single choice's ``delta``.
        model: model name echoed in the chunk.
        timestamp: value stored in ``created``.
        finish_reason: ``None`` for intermediate chunks, or e.g. ``"stop"``.

    Returns:
        A dict with a fresh ``chatcmpl-<uuid4>`` id, one choice at index 0
        whose delta has role ``"assistant"``, and ``usage`` set to ``None``.
    """
    single_choice = {
        "index": 0,
        "delta": {"content": content, "role": "assistant"},
        "finish_reason": finish_reason,
    }
    chunk_id = "chatcmpl-" + str(uuid.uuid4())
    chunk = {"id": chunk_id, "object": "chat.completion.chunk"}
    chunk["created"] = timestamp
    chunk["model"] = model
    chunk["choices"] = [single_choice]
    chunk["usage"] = None
    return chunk
def _evict_token_at(index: int) -> None:
    """Remove entry *index* from the comma-separated TOKEN and REFRESH_TOKEN env vars.

    The two variables are treated as parallel lists. An out-of-range index
    (e.g. a list already shortened by another request) is skipped instead of
    raising IndexError, which would otherwise mask the HTTP error being reported.
    """
    for env_name in ('TOKEN', 'REFRESH_TOKEN'):
        entries = os.getenv(env_name, '').split(',')
        if 0 <= index < len(entries):
            entries.pop(index)
            os.environ[env_name] = ','.join(entries)


async def process_streaming_response(request: ChatRequest, app_secret: str, current_index: int):
    """Proxy a chat request to api.thinkbuddy.ai and re-emit it as OpenAI-style SSE.

    Args:
        request: incoming chat request; ``request.model_dump()`` is forwarded
            as the upstream JSON body and ``request.model`` is echoed in chunks.
        app_secret: not used by this function (kept for interface compatibility).
        current_index: selects the bearer token from the comma-separated
            ``TOKEN`` environment variable (taken modulo the number of entries).

    Yields:
        SSE ``data:`` lines, each holding a ``create_chat_completion_data``
        chunk for upstream delta content, then a final chunk with
        ``finish_reason='stop'``, then the ``[DONE]`` sentinel.

    Raises:
        HTTPException: 503 on connection errors, 502 on other transport errors,
            the upstream status code on HTTP status errors, 500 on other request
            errors. On an upstream 429 the token at ``current_index`` is also
            removed from ``TOKEN`` and ``REFRESH_TOKEN``.
    """
    # create_default_context() already enables hostname checking and
    # CERT_REQUIRED verification.
    ssl_context = ssl.create_default_context()

    async with httpx.AsyncClient(verify=ssl_context) as client:
        try:
            # NOTE(review): TOKEN/REFRESH_TOKEN are process-wide env vars that
            # this module mutates without locking; concurrent requests may race.
            token_array = os.getenv('TOKEN', '').split(',')
            # str.split() always yields at least one item, so the modulo is safe.
            current_index = current_index % len(token_array)
            logger.info(f"completions current index is {current_index}")
            # Read from the environment each call so refreshed tokens are used.
            request_headers = {**settings.HEADERS, 'authorization': f"Bearer {token_array[current_index]}"}
            request_data = request.model_dump()
            # request_headers is deliberately not logged: it carries the bearer token.

            async with client.stream(
                "POST",
                "https://api.thinkbuddy.ai/v1/chat/completions",
                headers=request_headers,
                json=request_data,
                timeout=100,
            ) as response:
                response.raise_for_status()
                logger.info(f"Response status code: {response.status_code}")
                timestamp = int(datetime.now().timestamp())
                async for line in response.aiter_lines():
                    if not line or not line.startswith("data: "):
                        continue
                    if line.strip() == 'data: [DONE]':
                        # Leaving the `async with` closes the upstream response.
                        break
                    try:
                        json_data = json.loads(line[6:])  # drop the 'data: ' prefix
                    except json.JSONDecodeError as e:
                        logger.error(f"JSON解析错误: {e}")
                        logger.error(f"原始数据: {line}")
                        continue
                    if 'choices' in json_data and len(json_data['choices']) > 0:
                        delta = json_data['choices'][0].get('delta', {})
                        if 'content' in delta:
                            chunk = create_chat_completion_data(delta['content'], request.model, timestamp)
                            yield f"data: {json.dumps(chunk)}\n\n"

            yield f"data: {json.dumps(create_chat_completion_data('', request.model, timestamp, 'stop'))}\n\n"
            yield "data: [DONE]\n\n"
        # No `finally: response.aclose()` here: `async with client.stream(...)`
        # already closes the response, and `response` is unbound when the error
        # happens before the stream opens (e.g. ConnectError).
        except ConnectError as e:
            logger.error(f"Connection error details: {str(e)}")
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail="Service temporarily unavailable. Please try again later."
            ) from e
        except TransportError as e:
            logger.error(f"Transport error details: {str(e)}")
            raise HTTPException(
                status_code=status.HTTP_502_BAD_GATEWAY,
                detail="Network transport error occurred."
            ) from e
        except httpx.HTTPStatusError as e:
            # TODO: a 401 (expired token) is not handled specially yet.
            if e.response.status_code == 429:
                # Rate-limited: retire this token (and its refresh token).
                _evict_token_at(current_index)
            logger.error(f"HTTP error occurred: {e}")
            raise HTTPException(status_code=e.response.status_code, detail=str(e)) from e
        except httpx.RequestError as e:
            logger.error(f"Error occurred during request: {e}")
            raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)) from e