"""
API Versioning Middleware

Ensures consistent API versioning across all endpoints.
"""
import re
from typing import Optional, Callable
from fastapi import Request, Response
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from backend.logging.logger import get_logger
# Version reported as "current"; also the middleware's default when a request
# names no version.
CURRENT_VERSION = "1.0.0"

# Every version string the API accepts.
SUPPORTED_VERSIONS = [CURRENT_VERSION, "1.0.0-beta", "0.9.0"]

# Per-version compatibility matrix, keyed by API version string.
#   "features"   - feature names available in that version.
#   "breaking"   - in the data below, exactly the current-version features that
#                  are missing from that version's "features".
#   "min_client" - presumably the oldest client version that may use this API
#                  version; nothing in the visible code reads it (TODO confirm).
VERSION_COMPATIBILITY = {
    CURRENT_VERSION: {
        "min_client": "1.0.0",
        "features": [
            "evaluation",
            "certification",
            "monitoring",
            "leaderboard",
            "risk_passport",
        ],
        "breaking": [],
    },
    "1.0.0-beta": {
        "min_client": "0.9.0",
        "features": [
            "evaluation",
            "certification",
            "monitoring",
            "leaderboard",
        ],
        "breaking": [
            "risk_passport",
        ],
    },
    "0.9.0": {
        "min_client": "0.9.0",
        "features": [
            "evaluation",
            "certification",
        ],
        "breaking": [
            "monitoring",
            "leaderboard",
            "risk_passport",
        ],
    },
}
class APIVersioningMiddleware(BaseHTTPMiddleware):
    """
    Middleware to handle API versioning.

    - Extracts version from Accept-Version header or URL path
    - Validates version compatibility
    - Rejects unsupported versions with a 400 before the request is handled
    - Adds version headers to responses
    - Handles deprecation warnings
    """

    # Deprecated versions mapped to their deprecation date. Single source of
    # truth for both _is_deprecated() and _get_deprecation_date().
    _DEPRECATION_DATES = {"0.9.0": "2024-12-31"}

    # Matches "/api/v1", "/api/v1.0", "/api/v1.0.0-beta", ... at the start of
    # a path and captures everything after the leading "v".
    _PATH_VERSION_RE = re.compile(r"/api/v(\d+(?:\.\d+){0,2}(?:-[a-z]+)?)")

    # A possibly abbreviated version: optional "v", one to three numeric
    # components, optional lowercase pre-release suffix such as "-beta".
    _VERSION_RE = re.compile(r"[vV]?(\d+(?:\.\d+){0,2})(-[a-z]+)?")

    def __init__(self, app, default_version: str = CURRENT_VERSION):
        """
        Args:
            app: The ASGI application being wrapped.
            default_version: Version assumed when a request names none in
                either the Accept-Version header or the URL path.
        """
        super().__init__(app)
        self.default_version = default_version
        self.logger = get_logger("api.versioning")

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """
        Resolve the request's API version and annotate the response.

        The version is taken from the Accept-Version header, then from the
        URL path, then from ``self.default_version``. An unsupported version
        is answered with a 400 JSON error and the request is NOT forwarded.
        Otherwise the effective and client-requested versions are stored on
        ``request.state`` (``api_version`` / ``client_version``) and version
        headers are added to the downstream response.

        Args:
            request: Incoming request.
            call_next: Callable that forwards the request downstream.

        Returns:
            The downstream response with version headers, or a 400
            JSONResponse when the requested version is unsupported.
        """
        # Header wins; tolerate surrounding whitespace and empty values.
        client_version = (request.headers.get("Accept-Version") or "").strip()
        if not client_version:
            client_version = self._extract_version_from_path(request.url.path)
        if not client_version:
            client_version = self.default_version

        is_supported, version_to_use = self._validate_version(client_version)

        # Reject BEFORE calling downstream: running the handler and then
        # discarding its response would still execute its side effects.
        if not is_supported:
            return JSONResponse(
                status_code=400,
                content={
                    "error": "Unsupported API Version",
                    "message": f"Version {client_version} is not supported",
                    "supported_versions": SUPPORTED_VERSIONS,
                    "current_version": CURRENT_VERSION,
                },
                headers={
                    "X-API-Version": version_to_use,
                    "X-Supported-Versions": ",".join(SUPPORTED_VERSIONS),
                },
            )

        # Expose the negotiated version to downstream handlers.
        request.state.api_version = version_to_use
        request.state.client_version = client_version

        response = await call_next(request)

        response.headers["X-API-Version"] = version_to_use
        response.headers["X-Supported-Versions"] = ",".join(SUPPORTED_VERSIONS)

        if self._is_deprecated(version_to_use):
            response.headers["X-API-Deprecated"] = "true"
            response.headers["X-API-Deprecation-Date"] = self._get_deprecation_date(
                version_to_use
            )

        return response

    def _extract_version_from_path(self, path: str) -> Optional[str]:
        """
        Extract version from URL path like /api/v1/...

        Only the leading "v" marker is dropped, so a suffix containing the
        letter "v" (e.g. "-dev") is preserved intact.

        Returns:
            The version as written in the path without the "v" (e.g. "1",
            "1.0", "1.0.0-beta"), or None when the path has no version.
        """
        match = self._PATH_VERSION_RE.match(path)
        if match:
            return match.group(1)
        return None

    def _normalize_version(self, version: str) -> Optional[str]:
        """
        Expand an abbreviated version to MAJOR.MINOR.PATCH[-suffix] form.

        Missing numeric components are filled with "0" and a leading "v" is
        dropped, e.g. "1" -> "1.0.0", "v0.9" -> "0.9.0", "1-beta" ->
        "1.0.0-beta".

        Returns:
            The expanded version string, or None when ``version`` does not
            look like a version at all.
        """
        match = self._VERSION_RE.fullmatch(version.strip())
        if match is None:
            return None
        components = match.group(1).split(".")
        components += ["0"] * (3 - len(components))
        return ".".join(components) + (match.group(2) or "")

    def _validate_version(self, version: str) -> tuple[bool, str]:
        """
        Validate if version is supported, return (is_supported, effective_version)

        An exact match against SUPPORTED_VERSIONS is tried first. Failing
        that, the version is expanded (so a path version such as "1" can
        match "1.0.0") and tried again. When nothing matches, the result is
        ``(False, CURRENT_VERSION)``.
        """
        if version in SUPPORTED_VERSIONS:
            return True, version

        expanded = self._normalize_version(version)
        if expanded is not None and expanded in SUPPORTED_VERSIONS:
            return True, expanded

        return False, CURRENT_VERSION

    def _is_deprecated(self, version: str) -> bool:
        """Check if version is deprecated"""
        return version in self._DEPRECATION_DATES

    def _get_deprecation_date(self, version: str) -> str:
        """Get deprecation date for version, or "Unknown" if none is recorded"""
        return self._DEPRECATION_DATES.get(version, "Unknown")
def get_api_version(request: Request) -> str:
    """
    Return the API version recorded on a request.

    Reads ``request.state.api_version``; when that attribute has not been
    set, CURRENT_VERSION is returned instead.

    Args:
        request: FastAPI request object

    Returns:
        API version string
    """
    state = request.state
    try:
        return state.api_version
    except AttributeError:
        # No version was recorded on this request's state.
        return CURRENT_VERSION
def require_feature(request: Request, feature: str) -> bool:
    """
    Report whether a feature is listed for the request's API version.

    The version comes from get_api_version(); a version with no entry in
    VERSION_COMPATIBILITY is treated like CURRENT_VERSION.

    Args:
        request: FastAPI request object
        feature: Feature name to check

    Returns:
        True if the feature appears in that version's "features" list,
        False otherwise
    """
    active_version = get_api_version(request)
    fallback_entry = VERSION_COMPATIBILITY[CURRENT_VERSION]
    entry = VERSION_COMPATIBILITY.get(active_version, fallback_entry)
    enabled_features = entry.get("features", [])
    return feature in enabled_features
# Public API of this module: the middleware class, the version constants,
# and the two request helpers.
__all__ = [
"APIVersioningMiddleware",
"CURRENT_VERSION",
"SUPPORTED_VERSIONS",
"VERSION_COMPATIBILITY",
"get_api_version",
"require_feature",
]
|