import json
import re

from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware

from app.config.config import settings
from app.log.logger import get_main_logger

logger = get_main_logger()


class SmartRoutingMiddleware(BaseHTTPMiddleware):
    """Normalize malformed API request paths before they reach the router."""

    def __init__(self, app):
        super().__init__(app)

    async def dispatch(self, request: Request, call_next):
        if not settings.URL_NORMALIZATION_ENABLED:
            return await call_next(request)

        original_path = str(request.url.path)
        method = request.method
        logger.debug(f"Incoming request: {method} {original_path}")

        fixed_path, fix_info = self.fix_request_url(original_path, method, request)

        if fixed_path != original_path:
            logger.info(f"URL fixed: {method} {original_path} → {fixed_path}")
            if fix_info:
                logger.debug(f"Fix details: {fix_info}")

            # Rewrite the ASGI scope so downstream routing sees the fixed
            # path; both "path" and "raw_path" are consulted by Starlette.
            request.scope["path"] = fixed_path
            request.scope["raw_path"] = fixed_path.encode()

        return await call_next(request)

    def fix_request_url(self, path: str, method: str, request: Request) -> tuple:
        """Simplified URL-fixing logic; returns (fixed_path, fix_info)."""
        # Leave paths that already match a known-correct format untouched.
        if self.is_already_correct_format(path):
            return path, None

        # Gemini-style requests (generateContent / model listing).
        if "generatecontent" in path.lower() or "v1beta/models" in path.lower():
            return self.fix_gemini_by_operation(path, method, request)

        # OpenAI-prefixed requests.
        if "/openai/" in path.lower():
            return self.fix_openai_by_operation(path, method)

        # Bare /v1/ requests.
        if "/v1/" in path.lower():
            return self.fix_v1_by_operation(path, method)

        # Chat-completions requests missing a version prefix.
        if "/chat/completions" in path.lower():
            return "/v1/chat/completions", {"type": "v1_chat"}

        return path, None

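    # Illustrative rewrites (hypothetical request paths, traced through the
    # branches above):
    #   POST /v1beta/models/gemini-pro:generate -> /v1beta/models/gemini-pro:generateContent
    #   POST /openai/chat                       -> /openai/v1/chat/completions
    #   GET  /v1/list_models                    -> /v1/models
    #   POST /chat/completions                  -> /v1/chat/completions
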
    def is_already_correct_format(self, path: str) -> bool:
        """Check whether the path is already in a correct API format."""
        correct_patterns = [
            r"^/v1beta/models/[^/:]+:(generate|streamGenerate)Content$",
            r"^/gemini/v1beta/models/[^/:]+:(generate|streamGenerate)Content$",
            r"^/v1beta/models$",
            r"^/gemini/v1beta/models$",
            r"^/v1/(chat/completions|models|embeddings|images/generations|audio/speech)$",
            r"^/openai/v1/(chat/completions|models|embeddings|images/generations|audio/speech)$",
            r"^/hf/v1/(chat/completions|models|embeddings|images/generations|audio/speech)$",
            r"^/vertex-express/v1beta/models/[^/:]+:(generate|streamGenerate)Content$",
            r"^/vertex-express/v1beta/models$",
            r"^/vertex-express/v1/(chat/completions|models|embeddings|images/generations)$",
        ]
        return any(re.match(pattern, path) for pattern in correct_patterns)

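    # For example, "/v1beta/models/gemini-pro:generateContent" and
    # "/openai/v1/models" match; "/v1beta/models/gemini-pro:countTokens"
    # does not, so it falls through to the fixing logic.
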
    def fix_gemini_by_operation(
        self, path: str, method: str, request: Request
    ) -> tuple:
        """Fix Gemini-format paths by operation, honoring endpoint preference."""
        # GET requests map to the model-listing endpoint.
        if method == "GET":
            return "/v1beta/models", {
                "role": "gemini_models",
            }

        try:
            model_name = self.extract_model_name(path, request)
        except ValueError:
            # Without a model name the generate URL cannot be built.
            return path, None

        is_stream = self.detect_stream_request(path, request)
        action = "streamGenerateContent" if is_stream else "generateContent"

        if "/vertex-express/" in path.lower():
            target_url = f"/vertex-express/v1beta/models/{model_name}:{action}"
            fix_info = {
                "rule": "vertex_express_stream" if is_stream else "vertex_express_generate",
                "preference": "vertex_express_format",
                "is_stream": is_stream,
                "model": model_name,
            }
        else:
            target_url = f"/v1beta/models/{model_name}:{action}"
            fix_info = {
                "rule": "gemini_stream" if is_stream else "gemini_generate",
                "preference": "gemini_format",
                "is_stream": is_stream,
                "model": model_name,
            }

        return target_url, fix_info

    def fix_openai_by_operation(self, path: str, method: str) -> tuple:
        """Fix OpenAI-format paths based on the operation type."""
        if method == "POST":
            if "chat" in path.lower() or "completion" in path.lower():
                return "/openai/v1/chat/completions", {"type": "openai_chat"}
            elif "embedding" in path.lower():
                return "/openai/v1/embeddings", {"type": "openai_embeddings"}
            elif "image" in path.lower():
                return "/openai/v1/images/generations", {"type": "openai_images"}
            elif "audio" in path.lower():
                return "/openai/v1/audio/speech", {"type": "openai_audio"}
        elif method == "GET":
            if "model" in path.lower():
                return "/openai/v1/models", {"type": "openai_models"}

        return path, None

    def fix_v1_by_operation(self, path: str, method: str) -> tuple:
        """Fix v1-format paths based on the operation type."""
        if method == "POST":
            if "chat" in path.lower() or "completion" in path.lower():
                return "/v1/chat/completions", {"type": "v1_chat"}
            elif "embedding" in path.lower():
                return "/v1/embeddings", {"type": "v1_embeddings"}
            elif "image" in path.lower():
                return "/v1/images/generations", {"type": "v1_images"}
            elif "audio" in path.lower():
                return "/v1/audio/speech", {"type": "v1_audio"}
        elif method == "GET":
            if "model" in path.lower():
                return "/v1/models", {"type": "v1_models"}

        return path, None

    def detect_stream_request(self, path: str, request: Request) -> bool:
        """Detect whether this is a streaming request."""
        # The path itself mentions streaming (e.g. streamGenerateContent).
        if "stream" in path.lower():
            return True

        # An explicit ?stream=true query parameter.
        if request.query_params.get("stream") == "true":
            return True

        return False

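    # Note: only the path and query string are inspected here; a JSON body
    # carrying "stream": true is not detected at this point.
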
    def extract_model_name(self, path: str, request: Request) -> str:
        """Extract the model name from the request to build the Gemini API URL."""
        # Try the request body first. request._body is a private Starlette
        # attribute that is only populated once the body has been read.
        try:
            if hasattr(request, "_body") and request._body:
                body = json.loads(request._body.decode())
                if "model" in body and body["model"]:
                    return body["model"]
        except Exception:
            pass

        # Fall back to an explicit ?model= query parameter.
        model_param = request.query_params.get("model")
        if model_param:
            return model_param

        # Finally, try to pull the model name out of the path itself.
        match = re.search(r"/models/([^/:]+)", path, re.IGNORECASE)
        if match:
            return match.group(1)

        raise ValueError("Unable to extract model name from request")
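

# A minimal usage sketch (hypothetical wiring, assuming the surrounding app
# exposes a standard FastAPI instance and that URL_NORMALIZATION_ENABLED is
# defined on the settings object in app.config.config):
#
#     from fastapi import FastAPI
#
#     app = FastAPI()
#     app.add_middleware(SmartRoutingMiddleware)
#
# add_middleware is FastAPI's standard hook for BaseHTTPMiddleware subclasses;
# once registered, the middleware rewrites each request's path before routing.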