Spaces:
Sleeping
Sleeping
| import asyncio | |
| import json | |
| import os | |
| import traceback | |
| from typing import Optional | |
| from uuid import uuid4 | |
| import httpx | |
| from fastapi import File, UploadFile | |
| from fastapi import APIRouter, Response, Request, Depends, HTTPException | |
| from fastapi.responses import StreamingResponse | |
| from core.auth import verify_app_secret | |
| from core.config import get_settings | |
| from core.logger import setup_logger | |
| from core.models import ChatRequest | |
| from core.utils import process_streaming_response | |
| from playsound import playsound # 用于播放音频 | |
| # from rich import print | |
# Logger and router shared by every endpoint defined in this module.
logger = setup_logger(__name__)
router = APIRouter()

# Model catalogue read from settings once, at import time; served by
# list_models and used by chat_completions to validate requested models.
ALLOWED_MODELS = get_settings().ALLOWED_MODELS

# Round-robin cursor into the comma-separated TOKEN environment variable.
# Shared by the endpoints that call the upstream API, which read and advance it.
current_index: int = 0
async def list_models():
    """Return the configured model catalogue as an OpenAI-style list payload."""
    catalogue = ALLOWED_MODELS
    return dict(object="list", data=catalogue, success=True)
async def chat_completions_options():
    """Answer a CORS preflight request for the chat completions route.

    Allows any origin to send POST/OPTIONS requests carrying
    ``Content-Type`` and ``Authorization`` headers.
    """
    cors_headers = {
        "Access-Control-Allow-Origin": "*",
        "Access-Control-Allow-Methods": "POST, OPTIONS",
        "Access-Control-Allow-Headers": "Content-Type, Authorization",
    }
    return Response(headers=cors_headers, status_code=200)
# Image recognition
# Text-to-speech


def _save_audio_bytes(path: str, data: bytes) -> None:
    """Write *data* to *path*, replacing any existing file (blocking call)."""
    with open(path, 'wb') as f:
        f.write(data)


async def speech(request: Request):
    """Proxy a text-to-speech request to the upstream TTS endpoint.

    The request's JSON body is forwarded unchanged, for example
    ``{"input": "<text>", "voice": "nova"}`` (the voice is presumably one of
    alloy, echo, fable, onyx, nova, shimmer — confirm against upstream docs).
    On an HTTP 200 reply the response bytes are saved to ``output.mp3`` in
    the process working directory.

    Returns:
        True if the audio was saved; False if the upstream call failed,
        returned a 2xx status other than 200, or any other error occurred.
        A request body that is not valid JSON is not caught and propagates.
    """
    global current_index
    url = 'https://api.thinkbuddy.ai/v1/content/speech/tts'

    # str.split always returns at least one element, so the modulo is safe;
    # an unset TOKEN yields an empty bearer token.
    token_array = os.getenv('TOKEN', '').split(',')
    current_index = current_index % len(token_array)
    logger.info(f"speech current index is {current_index}")
    request_headers = {
        **get_settings().HEADERS,
        'authorization': f"Bearer {token_array[current_index]}",
        'Accept': 'application/json, text/plain, */*',
    }
    body = await request.json()
    try:
        async with httpx.AsyncClient(http2=True) as client:
            response = await client.post(url, headers=request_headers, json=body)
            response.raise_for_status()
            if response.status_code != 200:
                # raise_for_status() passed, so this is some other 2xx status.
                logger.warning(f"请求失败,状态码: {response.status_code}")
                logger.warning(f"响应内容: {response.text}")
                return False
            # The response is assumed to be audio data. Write it on a worker
            # thread so the event loop is not blocked by file I/O.
            await asyncio.to_thread(_save_audio_bytes, 'output.mp3', response.content)
            logger.info("音频文件已保存为 output.mp3")
            # Local playback is commented out; to enable it:
            # await asyncio.to_thread(playsound, 'output.mp3')
            return True
    except httpx.RequestError as e:
        logger.exception(f"请求错误: {e}")
        return False
    except httpx.HTTPStatusError as e:
        logger.exception(f"HTTP 错误: {e}")
        return False
    except Exception as e:
        logger.exception(f"发生错误: {e}")
        return False
    finally:
        # Rotate to the next token whatever the outcome.
        current_index += 1
# Speech-to-text
async def transcriptions(request: Request, file: UploadFile = File(...)):
    """Proxy an uploaded audio file to the upstream transcription endpoint.

    The upload is re-sent as multipart form data with a fixed filename
    (``file.mp4``), MIME type ``audio/mp4`` and a ``model`` field of
    ``whisper-1``; ``enhance=true`` is passed as a query parameter.

    Returns:
        The upstream JSON response.

    Raises:
        HTTPException: 400 if the upload cannot be read, 504 on an upstream
            timeout, the upstream status code on an upstream HTTP error,
            500 for anything else.
    """
    global current_index
    url = 'https://api.thinkbuddy.ai/v1/content/transcribe'
    params = {'enhance': 'true'}
    try:
        content = await safe_read_file(file)
        if content is None:
            # safe_read_file has already logged why the read failed.
            raise HTTPException(status_code=400, detail="读取上传文件失败")

        # Incoming Content-Type is forwarded unchanged.
        # NOTE(review): httpx builds its own multipart body here; this relies on
        # httpx reusing the boundary from this header — confirm for the pinned
        # httpx version.
        content_type = request.headers.get('content-type')
        # The upload's real name and MIME type are ignored; upstream is always
        # told it is an MP4 audio file.
        files = {
            'file': ('file.mp4', content, 'audio/mp4'),
            'model': (None, 'whisper-1'),
        }
        logger.info(f"Received upload request for file: {file.filename}")
        logger.info(f"Content-Type: {content_type}")

        # str.split always returns at least one element, so the modulo is safe.
        token_array = os.getenv('TOKEN', '').split(',')
        current_index = current_index % len(token_array)
        logger.info(f"transcriptions current index is {current_index}")
        request_headers = {
            **get_settings().HEADERS,
            'authorization': f"Bearer {token_array[current_index]}",
            'Accept': 'application/json, text/plain, */*',
            'Content-Type': content_type,
        }
        # Longer-than-default timeouts; reads may take up to 300 s.
        timeout = httpx.Timeout(connect=30.0, read=300.0, write=30.0, pool=30.0)
        async with httpx.AsyncClient(http2=True, timeout=timeout) as client:
            response = await client.post(
                url,
                params=params,
                headers=request_headers,
                files=files,
            )
            # Advance the token cursor once a response has been received.
            current_index += 1
            response.raise_for_status()
            return response.json()
    except HTTPException:
        # Already a client-facing error; don't rewrap it as a 500 below.
        raise
    except httpx.TimeoutException as e:
        raise HTTPException(status_code=504, detail="请求目标服务器超时") from e
    except httpx.HTTPStatusError as e:
        raise HTTPException(status_code=e.response.status_code, detail=str(e)) from e
    except Exception as e:
        logger.exception(f"transcriptions failed: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Release the upload's underlying temporary file.
        await file.close()
async def safe_read_file(file: UploadFile) -> Optional[bytes]:
    """Read an upload's full contents, returning None instead of raising.

    Any exception raised while reading is logged at error level and
    swallowed, so callers must check for a None result.
    """
    data: Optional[bytes] = None
    try:
        data = await file.read()
    except Exception as exc:
        logger.error(f"Error reading file: {str(exc)}")
    return data
# File upload
async def upload_file(request: Request, file: UploadFile = File(...)):
    """Proxy an uploaded file to the upstream image-upload endpoint.

    The file is re-sent as multipart form data under the ``file`` field,
    keeping its original filename and content type.

    Returns:
        The upstream JSON response.

    Raises:
        HTTPException: 400 if the upload cannot be read, 504 on an upstream
            timeout, the upstream status code on an upstream HTTP error,
            500 for anything else.
    """
    global current_index
    try:
        content = await safe_read_file(file)
        if content is None:
            # safe_read_file has already logged why the read failed.
            raise HTTPException(status_code=400, detail="读取上传文件失败")

        # Incoming Content-Type is forwarded unchanged to upstream.
        content_type = request.headers.get('content-type')
        files = {
            'file': (file.filename, content, file.content_type),
        }
        logger.info(f"Received upload request for file: {file.filename}")
        logger.info(f"Content-Type: {content_type}")

        # str.split always returns at least one element, so the modulo is safe.
        token_array = os.getenv('TOKEN', '').split(',')
        current_index = current_index % len(token_array)
        logger.info(f"upload_file current index is {current_index}")
        request_headers = {
            **get_settings().HEADERS,
            'authorization': f"Bearer {token_array[current_index]}",
            'Accept': 'application/json, text/plain, */*',
            'Content-Type': content_type,
        }
        async with httpx.AsyncClient() as client:
            response = await client.post(
                "https://api.thinkbuddy.ai/v1/uploads/images",
                headers=request_headers,
                files=files,
                timeout=100,
            )
            # Advance the token cursor once a response has been received.
            current_index += 1
            response.raise_for_status()
            return response.json()
    except HTTPException:
        # Already a client-facing error; don't rewrap it as a 500 below.
        raise
    except httpx.TimeoutException as e:
        raise HTTPException(status_code=504, detail="请求目标服务器超时") from e
    except httpx.HTTPStatusError as e:
        # Surface upstream failures to the client instead of answering
        # 200 with a null body.
        logger.exception(f"HTTPStatusError发生错误: {e}")
        raise HTTPException(status_code=e.response.status_code, detail=str(e)) from e
    except Exception as e:
        logger.exception(f"发生错误: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Release the upload's underlying temporary file.
        await file.close()
async def chat_completions(
    request: ChatRequest, app_secret: str = Depends(verify_app_secret)
):
    """Handle an OpenAI-style chat completion request.

    Validates the requested model against ``ALLOWED_MODELS`` and, for
    streaming requests, returns a ``text/event-stream`` StreamingResponse
    whose items come from ``process_streaming_response``.

    Raises:
        HTTPException: 400 if the model is not in ``ALLOWED_MODELS``;
            501 for non-streaming requests, which are not implemented.
    """
    global current_index
    logger.info("Entering chat_completions route")
    # The app secret is deliberately not logged: it is a credential.
    logger.info(f"Received chat completion request for model: {request.model}")

    allowed_ids = [model["id"] for model in ALLOWED_MODELS]
    if request.model not in allowed_ids:
        raise HTTPException(
            status_code=400,
            detail=f"Model {request.model} is not allowed. Allowed models are: {', '.join(allowed_ids)}",
        )

    if not request.stream:
        logger.info("Non-streaming response")
        # Previously this branch fell through and returned None (200 null).
        raise HTTPException(
            status_code=501,
            detail="Non-streaming responses are not supported",
        )

    logger.info("Streaming response")
    # Snapshot the cursor now: the generator body only starts running after
    # this function has returned, by which point current_index has already
    # been advanced (and possibly advanced again by concurrent requests).
    # NOTE(review): no modulo is applied here; presumably
    # process_streaming_response wraps the index — confirm.
    token_index = current_index
    current_index += 1

    async def content_generator():
        try:
            async for item in process_streaming_response(request, app_secret, token_index):
                yield item
        except Exception as e:
            logger.error(f"Error in streaming response: {e}")
            raise

    return StreamingResponse(
        content_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "Transfer-Encoding": "chunked",
        },
    )
async def health_check(request: Request):
    """Liveness probe: always answer with the JSON body ``{"status": "ok"}``."""
    payload = json.dumps(dict(status="ok"))
    return Response(media_type="application/json", content=payload)
async def environment(app_secret: str = Depends(verify_app_secret)):
    """Return the token-related environment variables as JSON.

    ``length`` is the number of comma-separated entries in ``TOKEN``; an
    unset or empty TOKEN still counts as one (empty) entry.

    NOTE(review): this returns TOKEN, REFRESH_TOKEN and FIREBASE_API_KEY in
    clear text to any caller that passes verify_app_secret — confirm that
    exposure is intended.
    """
    token = os.getenv("TOKEN", "")
    info = {
        "token": token,
        "length": len(token.split(',')),
        "refresh_token": os.getenv("REFRESH_TOKEN", ""),
        "key": os.getenv("FIREBASE_API_KEY", ""),
    }
    return Response(content=json.dumps(info), media_type="application/json")