| from fastapi import Request |
| from starlette.middleware.base import BaseHTTPMiddleware |
| from app.config.config import settings |
| from app.log.logger import get_main_logger |
| import re |
|
|
# Module-level logger shared by the middleware below.
logger = get_main_logger()
|
|
class SmartRoutingMiddleware(BaseHTTPMiddleware):
    """Normalize incoming API request paths to their canonical endpoint form.

    When ``settings.URL_NORMALIZATION_ENABLED`` is true, each request's path
    is inspected and, if it is a recognizable but non-canonical variant of a
    known endpoint (Gemini native, OpenAI-compatible, vertex-express), the
    path in ``request.scope`` is rewritten before the downstream app routes
    it.  Paths that already match a canonical pattern pass through untouched.
    """

    # Paths that are already canonical and must never be rewritten.
    # Compiled once at class-creation time instead of being re-matched from
    # raw strings on every request (previous behavior).
    _CORRECT_PATTERNS = [
        re.compile(pattern)
        for pattern in (
            r"^/v1beta/models/[^/:]+:(generate|streamGenerate)Content$",
            r"^/gemini/v1beta/models/[^/:]+:(generate|streamGenerate)Content$",
            r"^/v1beta/models$",
            r"^/gemini/v1beta/models$",
            r"^/v1/(chat/completions|models|embeddings|images/generations|audio/speech)$",
            r"^/openai/v1/(chat/completions|models|embeddings|images/generations|audio/speech)$",
            r"^/hf/v1/(chat/completions|models|embeddings|images/generations|audio/speech)$",
            r"^/vertex-express/v1beta/models/[^/:]+:(generate|streamGenerate)Content$",
            r"^/vertex-express/v1beta/models$",
            r"^/vertex-express/v1/(chat/completions|models|embeddings|images/generations)$",
        )
    ]

    def __init__(self, app):
        # No middleware-local state; dead ``pass`` removed.
        super().__init__(app)

    async def dispatch(self, request: Request, call_next):
        """Rewrite ``request.scope['path']`` in place when normalization applies."""
        if not settings.URL_NORMALIZATION_ENABLED:
            return await call_next(request)
        logger.debug(f"request: {request}")
        original_path = str(request.url.path)
        method = request.method

        fixed_path, fix_info = self.fix_request_url(original_path, method, request)

        if fixed_path != original_path:
            logger.info(f"URL fixed: {method} {original_path} → {fixed_path}")
            if fix_info:
                logger.debug(f"Fix details: {fix_info}")

            # Mutate the ASGI scope so downstream routing sees the fixed path.
            # NOTE(review): ASGI ``raw_path`` normally keeps percent-encoding;
            # re-encoding the decoded path is acceptable for these ASCII-only
            # API routes.
            request.scope["path"] = fixed_path
            request.scope["raw_path"] = fixed_path.encode()

        return await call_next(request)

    def fix_request_url(self, path: str, method: str, request: Request) -> tuple:
        """Return ``(fixed_path, fix_info)`` for *path*.

        ``fix_info`` is a dict describing the applied rule, or ``None`` when
        the path is already canonical or no rewriting rule matched.
        """
        # Already canonical → leave untouched.
        if self.is_already_correct_format(path):
            return path, None

        lowered = path.lower()

        # Gemini-native endpoints (generateContent / model listing).
        if "generatecontent" in lowered or "v1beta/models" in lowered:
            return self.fix_gemini_by_operation(path, method, request)

        # OpenAI-prefixed endpoints.
        if "/openai/" in lowered:
            return self.fix_openai_by_operation(path, method)

        # Bare /v1/ endpoints.
        if "/v1/" in lowered:
            return self.fix_v1_by_operation(path, method)

        # Chat-completions path without any version prefix.
        if "/chat/completions" in lowered:
            return "/v1/chat/completions", {"type": "v1_chat"}

        return path, None

    def is_already_correct_format(self, path: str) -> bool:
        """Return True if *path* already matches a canonical API pattern."""
        return any(pattern.match(path) for pattern in self._CORRECT_PATTERNS)

    def fix_gemini_by_operation(
        self, path: str, method: str, request: Request
    ) -> tuple:
        """Build the canonical Gemini URL, honoring the vertex-express prefix."""
        if method == "GET":
            # Model listing.  Preserve the /vertex-express/ prefix so vertex
            # list requests are not silently re-routed to the plain Gemini
            # endpoint (previously the prefix was dropped here, inconsistent
            # with the POST branch below and with _CORRECT_PATTERNS).
            if "/vertex-express/" in path.lower():
                return "/vertex-express/v1beta/models", {
                    "role": "gemini_models",
                }
            return "/v1beta/models", {
                "role": "gemini_models",
            }

        try:
            model_name = self.extract_model_name(path, request)
        except ValueError:
            # Without a model name we cannot build a generateContent URL.
            return path, None

        is_stream = self.detect_stream_request(path, request)

        if "/vertex-express/" in path.lower():
            if is_stream:
                target_url = (
                    f"/vertex-express/v1beta/models/{model_name}:streamGenerateContent"
                )
            else:
                target_url = (
                    f"/vertex-express/v1beta/models/{model_name}:generateContent"
                )

            fix_info = {
                "rule": (
                    "vertex_express_generate"
                    if not is_stream
                    else "vertex_express_stream"
                ),
                "preference": "vertex_express_format",
                "is_stream": is_stream,
                "model": model_name,
            }
        else:
            if is_stream:
                target_url = f"/v1beta/models/{model_name}:streamGenerateContent"
            else:
                target_url = f"/v1beta/models/{model_name}:generateContent"

            fix_info = {
                "rule": "gemini_generate" if not is_stream else "gemini_stream",
                "preference": "gemini_format",
                "is_stream": is_stream,
                "model": model_name,
            }

        return target_url, fix_info

    def fix_openai_by_operation(self, path: str, method: str) -> tuple:
        """Map a loose /openai/ path onto its canonical /openai/v1/ endpoint."""
        lowered = path.lower()
        if method == "POST":
            if "chat" in lowered or "completion" in lowered:
                return "/openai/v1/chat/completions", {"type": "openai_chat"}
            elif "embedding" in lowered:
                return "/openai/v1/embeddings", {"type": "openai_embeddings"}
            elif "image" in lowered:
                return "/openai/v1/images/generations", {"type": "openai_images"}
            elif "audio" in lowered:
                return "/openai/v1/audio/speech", {"type": "openai_audio"}
        elif method == "GET":
            if "model" in lowered:
                return "/openai/v1/models", {"type": "openai_models"}

        return path, None

    def fix_v1_by_operation(self, path: str, method: str) -> tuple:
        """Map a loose /v1/ path onto its canonical /v1/ endpoint."""
        lowered = path.lower()
        if method == "POST":
            if "chat" in lowered or "completion" in lowered:
                return "/v1/chat/completions", {"type": "v1_chat"}
            elif "embedding" in lowered:
                return "/v1/embeddings", {"type": "v1_embeddings"}
            elif "image" in lowered:
                return "/v1/images/generations", {"type": "v1_images"}
            elif "audio" in lowered:
                return "/v1/audio/speech", {"type": "v1_audio"}
        elif method == "GET":
            if "model" in lowered:
                return "/v1/models", {"type": "v1_models"}

        return path, None

    def detect_stream_request(self, path: str, request: Request) -> bool:
        """Return True when the request asks for a streaming response."""
        # Path variants such as ...:streamGenerateContent contain "stream".
        if "stream" in path.lower():
            return True

        # Explicit ?stream=true query parameter (exact lowercase "true" only).
        if request.query_params.get("stream") == "true":
            return True

        return False

    def extract_model_name(self, path: str, request: Request) -> str:
        """Extract the model name from the body, query params, or the path.

        Sources are tried in order: cached JSON body, ``?model=`` query
        parameter, then the ``/models/<name>`` path segment.

        Raises:
            ValueError: if no model name can be found in any source.
        """
        # 1) JSON request body, if Starlette has already cached it.
        # NOTE(review): relies on the private ``request._body`` attribute,
        # which is only populated after the body has been read elsewhere.
        try:
            if hasattr(request, "_body") and request._body:
                import json

                body = json.loads(request._body.decode())
                if "model" in body and body["model"]:
                    return body["model"]
        except Exception:
            # Best effort: fall through to the other sources on any parse
            # or decode error.
            pass

        # 2) Explicit ?model=... query parameter.
        model_param = request.query_params.get("model")
        if model_param:
            return model_param

        # 3) The /models/<name> segment of the path itself.
        match = re.search(r"/models/([^/:]+)", path, re.IGNORECASE)
        if match:
            return match.group(1)

        raise ValueError("Unable to extract model name from request")
|
|