Benedict King committed on
Commit ·
2ec384d
1
Parent(s): d4d650a
feat: add TextToSpeechRequest model and implement audio speech endpoint with processing logic
Browse files
- main.py +17 -3
- models.py +9 -1
- request.py +28 -1
- response.py +4 -1
- utils.py +1 -0
main.py
CHANGED
|
@@ -15,7 +15,7 @@ from starlette.responses import StreamingResponse as StarletteStreamingResponse
|
|
| 15 |
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
| 16 |
from fastapi.exceptions import RequestValidationError
|
| 17 |
|
| 18 |
-
from models import RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, UnifiedRequest
|
| 19 |
from request import get_payload
|
| 20 |
from response import fetch_response, fetch_response_stream
|
| 21 |
from utils import error_handling_wrapper, post_all_models, load_config, safe_get, circular_list_encoder
|
|
@@ -360,6 +360,9 @@ class StatsMiddleware(BaseHTTPMiddleware):
|
|
| 360 |
moderated_content = request_model.get_last_text_message()
|
| 361 |
elif request_model.request_type == "image":
|
| 362 |
moderated_content = request_model.prompt
|
|
|
|
|
|
|
|
|
|
| 363 |
if moderated_content:
|
| 364 |
current_info["text"] = moderated_content
|
| 365 |
|
|
@@ -521,6 +524,10 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest, A
|
|
| 521 |
engine = "moderation"
|
| 522 |
request.stream = False
|
| 523 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 524 |
if provider.get("engine"):
|
| 525 |
engine = provider["engine"]
|
| 526 |
|
|
@@ -662,7 +669,7 @@ class ModelRequestHandler:
|
|
| 662 |
logger.info("available provider: %s", json.dumps(provider, indent=4, ensure_ascii=False, default=circular_list_encoder))
|
| 663 |
return provider_list
|
| 664 |
|
| 665 |
-
async def request_model(self, request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest], token: str, endpoint=None):
|
| 666 |
config = app.state.config
|
| 667 |
# api_keys_db = app.state.api_keys_db
|
| 668 |
api_list = app.state.api_list
|
|
@@ -705,7 +712,7 @@ class ModelRequestHandler:
|
|
| 705 |
return await self.try_all_providers(request, matching_providers, use_round_robin, auto_retry, endpoint, token)
|
| 706 |
|
| 707 |
# 在 try_all_providers 函数中处理失败的情况
|
| 708 |
-
async def try_all_providers(self, request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest], providers: List[Dict], use_round_robin: bool, auto_retry: bool, endpoint: str = None, token: str = None):
|
| 709 |
status_code = 500
|
| 710 |
error_message = None
|
| 711 |
num_providers = len(providers)
|
|
@@ -866,6 +873,13 @@ async def images_generations(
|
|
| 866 |
):
|
| 867 |
return await model_handler.request_model(request, token, endpoint="/v1/images/generations")
|
| 868 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 869 |
@app.post("/v1/moderations", dependencies=[Depends(rate_limit_dependency)])
|
| 870 |
async def moderations(
|
| 871 |
request: ModerationRequest,
|
|
|
|
| 15 |
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
| 16 |
from fastapi.exceptions import RequestValidationError
|
| 17 |
|
| 18 |
+
from models import RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, TextToSpeechRequest, UnifiedRequest
|
| 19 |
from request import get_payload
|
| 20 |
from response import fetch_response, fetch_response_stream
|
| 21 |
from utils import error_handling_wrapper, post_all_models, load_config, safe_get, circular_list_encoder
|
|
|
|
| 360 |
moderated_content = request_model.get_last_text_message()
|
| 361 |
elif request_model.request_type == "image":
|
| 362 |
moderated_content = request_model.prompt
|
| 363 |
+
elif model.startswith("tts"):
|
| 364 |
+
moderated_content = request_model.input
|
| 365 |
+
|
| 366 |
if moderated_content:
|
| 367 |
current_info["text"] = moderated_content
|
| 368 |
|
|
|
|
| 524 |
engine = "moderation"
|
| 525 |
request.stream = False
|
| 526 |
|
| 527 |
+
if endpoint == "/v1/audio/speech":
|
| 528 |
+
engine = "tts"
|
| 529 |
+
request.stream = False
|
| 530 |
+
|
| 531 |
if provider.get("engine"):
|
| 532 |
engine = provider["engine"]
|
| 533 |
|
|
|
|
| 669 |
logger.info("available provider: %s", json.dumps(provider, indent=4, ensure_ascii=False, default=circular_list_encoder))
|
| 670 |
return provider_list
|
| 671 |
|
| 672 |
+
async def request_model(self, request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, TextToSpeechRequest], token: str, endpoint=None):
|
| 673 |
config = app.state.config
|
| 674 |
# api_keys_db = app.state.api_keys_db
|
| 675 |
api_list = app.state.api_list
|
|
|
|
| 712 |
return await self.try_all_providers(request, matching_providers, use_round_robin, auto_retry, endpoint, token)
|
| 713 |
|
| 714 |
# 在 try_all_providers 函数中处理失败的情况
|
| 715 |
+
async def try_all_providers(self, request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, TextToSpeechRequest], providers: List[Dict], use_round_robin: bool, auto_retry: bool, endpoint: str = None, token: str = None):
|
| 716 |
status_code = 500
|
| 717 |
error_message = None
|
| 718 |
num_providers = len(providers)
|
|
|
|
| 873 |
):
|
| 874 |
return await model_handler.request_model(request, token, endpoint="/v1/images/generations")
|
| 875 |
|
| 876 |
+
@app.post("/v1/audio/speech", dependencies=[Depends(rate_limit_dependency)])
async def audio_speech(
    request: TextToSpeechRequest,
    token: str = Depends(verify_api_key),
):
    """Serve POST /v1/audio/speech by delegating to the shared model handler.

    Args:
        request: Parsed text-to-speech request body.
        token: Value produced by the ``verify_api_key`` dependency.

    Returns:
        Whatever ``model_handler.request_model`` returns for this request.
    """
    # The endpoint path is handed to the handler alongside the request.
    speech_endpoint = "/v1/audio/speech"
    result = await model_handler.request_model(request, token, endpoint=speech_endpoint)
    return result
|
| 882 |
+
|
| 883 |
@app.post("/v1/moderations", dependencies=[Depends(rate_limit_dependency)])
|
| 884 |
async def moderations(
|
| 885 |
request: ModerationRequest,
|
models.py
CHANGED
|
@@ -134,4 +134,12 @@ class UnifiedRequest(BaseModel):
|
|
| 134 |
values["data"].request_type = "moderation"
|
| 135 |
else:
|
| 136 |
raise ValueError("无法确定请求类型")
|
| 137 |
-
return values
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
values["data"].request_type = "moderation"
|
| 135 |
else:
|
| 136 |
raise ValueError("无法确定请求类型")
|
| 137 |
+
return values
|
| 138 |
+
|
| 139 |
+
class TextToSpeechRequest(BaseRequest):
    """Request body for a text-to-speech call."""

    # Model name as sent by the client.
    model: str
    # Text payload of the request.
    input: str
    # Voice identifier as sent by the client.
    voice: str
    # Requested output format; "mp3" when the client omits it.
    response_format: Optional[str] = "mp3"
    # Requested speed; 1.0 when the client omits it.
    speed: Optional[float] = 1.0
    # Streaming flag; False when the client omits it.
    stream: Optional[bool] = False
|
request.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
import os
|
| 2 |
import re
|
| 3 |
import json
|
|
|
|
| 4 |
import httpx
|
| 5 |
import base64
|
| 6 |
import urllib.parse
|
|
@@ -1134,7 +1135,33 @@ async def get_payload(request: RequestModel, engine, provider):
|
|
| 1134 |
return await get_dalle_payload(request, engine, provider)
|
| 1135 |
elif engine == "whisper":
|
| 1136 |
return await get_whisper_payload(request, engine, provider)
|
|
|
|
|
|
|
| 1137 |
elif engine == "moderation":
|
| 1138 |
return await get_moderation_payload(request, engine, provider)
|
| 1139 |
else:
|
| 1140 |
-
raise ValueError("Unknown payload")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
import re
|
| 3 |
import json
|
| 4 |
+
from venv import logger
|
| 5 |
import httpx
|
| 6 |
import base64
|
| 7 |
import urllib.parse
|
|
|
|
| 1135 |
return await get_dalle_payload(request, engine, provider)
|
| 1136 |
elif engine == "whisper":
|
| 1137 |
return await get_whisper_payload(request, engine, provider)
|
| 1138 |
+
elif engine == "tts":
|
| 1139 |
+
return await get_tts_payload(request, engine, provider)
|
| 1140 |
elif engine == "moderation":
|
| 1141 |
return await get_moderation_payload(request, engine, provider)
|
| 1142 |
else:
|
| 1143 |
+
raise ValueError("Unknown payload")
|
| 1144 |
+
|
| 1145 |
+
async def get_tts_payload(request, engine, provider):
    """Assemble the URL, headers and JSON body for a text-to-speech call.

    Args:
        request: Object exposing ``model``, ``input``, ``voice``,
            ``response_format``, ``speed`` and ``stream`` attributes.
        engine: Engine name supplied by the caller; not read here.
        provider: Mapping that must contain ``base_url`` and ``model`` and
            may contain ``api``, an object whose ``next()`` result is sent
            as the bearer token.

    Returns:
        A ``(url, headers, payload)`` tuple.
    """
    request_headers = {"Content-Type": "application/json"}
    key_source = provider.get("api")
    if key_source:
        request_headers['Authorization'] = f"Bearer {key_source.next()}"

    # BaseAPI derives the audio-speech endpoint from the configured base URL.
    speech_url = BaseAPI(provider['base_url']).audio_speech

    body = {
        # The client-facing model name is mapped through the provider's table.
        "model": provider['model'][request.model],
        "input": request.input,
        "voice": request.voice,
    }

    # These two are forwarded only when truthy, so None, "" and 0 are omitted.
    for optional_field in ("response_format", "speed"):
        value = getattr(request, optional_field)
        if value:
            body[optional_field] = value

    # stream is forwarded whenever it is set, including an explicit False.
    stream_flag = request.stream
    if stream_flag is not None:
        body["stream"] = stream_flag

    return speech_url, request_headers, body
|
response.py
CHANGED
|
@@ -285,7 +285,10 @@ async def fetch_response(client, url, headers, payload):
|
|
| 285 |
if error_message:
|
| 286 |
yield error_message
|
| 287 |
return
|
| 288 |
-
|
|
|
|
|
|
|
|
|
|
| 289 |
|
| 290 |
async def fetch_response_stream(client, url, headers, payload, engine, model):
|
| 291 |
try:
|
|
|
|
| 285 |
if error_message:
|
| 286 |
yield error_message
|
| 287 |
return
|
| 288 |
+
if url.endswith("/v1/audio/speech"):
|
| 289 |
+
yield response.read()
|
| 290 |
+
else:
|
| 291 |
+
yield response.json()
|
| 292 |
|
| 293 |
async def fetch_response_stream(client, url, headers, payload, engine, model):
|
| 294 |
try:
|
utils.py
CHANGED
|
@@ -313,6 +313,7 @@ class BaseAPI:
|
|
| 313 |
self.image_url: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/images/generations",) + ("",) * 3)
|
| 314 |
self.audio_transcriptions: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/audio/transcriptions",) + ("",) * 3)
|
| 315 |
self.moderations: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/moderations",) + ("",) * 3)
|
|
|
|
| 316 |
|
| 317 |
def safe_get(data, *keys, default=None):
|
| 318 |
for key in keys:
|
|
|
|
| 313 |
self.image_url: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/images/generations",) + ("",) * 3)
|
| 314 |
self.audio_transcriptions: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/audio/transcriptions",) + ("",) * 3)
|
| 315 |
self.moderations: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/moderations",) + ("",) * 3)
|
| 316 |
+
self.audio_speech: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/audio/speech",) + ("",) * 3)
|
| 317 |
|
| 318 |
def safe_get(data, *keys, default=None):
|
| 319 |
for key in keys:
|