Spaces:
Paused
Paused
| import base64 | |
| import json | |
| import logging | |
| import time | |
| from urllib.parse import urlencode | |
| from Crypto.Cipher import AES | |
| from Crypto.Random import get_random_bytes | |
| from Crypto.Util.Padding import pad, unpad | |
| from fastapi import HTTPException, Request | |
| from starlette.middleware.base import BaseHTTPMiddleware | |
| from starlette.responses import JSONResponse | |
| from mediaflow_proxy.configs import settings | |
| class EncryptionHandler: | |
| def __init__(self, secret_key: str): | |
| self.secret_key = secret_key.encode("utf-8").ljust(32)[:32] | |
| def encrypt_data(self, data: dict, expiration: int = None, ip: str = None) -> str: | |
| if expiration: | |
| data["exp"] = int(time.time()) + expiration | |
| if ip: | |
| data["ip"] = ip | |
| json_data = json.dumps(data).encode("utf-8") | |
| iv = get_random_bytes(16) | |
| cipher = AES.new(self.secret_key, AES.MODE_CBC, iv) | |
| encrypted_data = cipher.encrypt(pad(json_data, AES.block_size)) | |
| return base64.urlsafe_b64encode(iv + encrypted_data).decode("utf-8") | |
| def decrypt_data(self, token: str, client_ip: str) -> dict: | |
| try: | |
| encrypted_data = base64.urlsafe_b64decode(token.encode("utf-8")) | |
| iv = encrypted_data[:16] | |
| cipher = AES.new(self.secret_key, AES.MODE_CBC, iv) | |
| decrypted_data = unpad(cipher.decrypt(encrypted_data[16:]), AES.block_size) | |
| data = json.loads(decrypted_data) | |
| if "exp" in data: | |
| if data["exp"] < time.time(): | |
| raise HTTPException(status_code=401, detail="Token has expired") | |
| del data["exp"] # Remove expiration from the data | |
| if "ip" in data: | |
| if data["ip"] != client_ip: | |
| raise HTTPException(status_code=403, detail="IP address mismatch") | |
| del data["ip"] # Remove IP from the data | |
| return data | |
| except Exception as e: | |
| raise HTTPException(status_code=401, detail="Invalid or expired token") | |
class EncryptionMiddleware(BaseHTTPMiddleware):
    """Expand an encrypted ``token`` query parameter into plain query parameters.

    When a request carries ``?token=...``, the token is decrypted with the
    module-level ``encryption_handler`` and its fields replace the ``token``
    parameter in the request's query string before the request is passed on.
    Requests without a token are passed through untouched.
    """

    def __init__(self, app):
        """Wrap ``app`` and bind the module-level ``encryption_handler``."""
        super().__init__(app)
        self.encryption_handler = encryption_handler

    async def dispatch(self, request: Request, call_next):
        """Decrypt the ``token`` parameter, if any, then forward the request.

        Returns:
            The downstream response; a JSON ``{"error": ...}`` response with
            the ``HTTPException`` status code when the token is rejected; or a
            JSON 500 response when the downstream call raises.
        """
        encrypted_token = request.query_params.get("token")
        if encrypted_token:
            # Only these two calls can raise HTTPException, so keep the try narrow.
            try:
                client_ip = self.get_client_ip(request)
                decrypted_data = self.encryption_handler.decrypt_data(encrypted_token, client_ip)
            except HTTPException as e:
                return JSONResponse(content={"error": str(e.detail)}, status_code=e.status_code)

            # Modify request query parameters with decrypted data
            query_params = dict(request.query_params)
            query_params.pop("token")  # Remove the encrypted token from query params
            query_params.update(decrypted_data)  # Add decrypted data to query params
            query_params["has_encrypted"] = True

            # Rewrite the ASGI scope so downstream handlers parse the new query string
            new_query_string = urlencode(query_params)
            request.scope["query_string"] = new_query_string.encode()
            request._query_params = query_params

        try:
            response = await call_next(request)
        except Exception as e:
            logging.exception("An error occurred while processing the request")
            # NOTE(review): str(e) exposes internal error text to the client;
            # kept because existing clients may rely on this response body.
            return JSONResponse(content={"error": str(e)}, status_code=500)
        return response

    @staticmethod
    def get_client_ip(request: Request) -> str | None:
        """
        Extract the client's real IP address from the request headers or fallback to the client host.

        Lookup order: first entry of ``X-Forwarded-For``, then ``X-Real-IP``,
        then ``request.client.host``, and finally ``"127.0.0.1"`` when the
        request has no client information.
        """
        # NOTE(review): both headers are supplied by the peer; presumably a
        # trusted reverse proxy overwrites them - confirm, otherwise the
        # token's IP binding can be spoofed.
        x_forwarded_for = request.headers.get("X-Forwarded-For")
        if x_forwarded_for:
            # In some cases, this header can contain multiple IPs
            # separated by commas.
            # The first one is the original client's IP.
            return x_forwarded_for.split(",")[0].strip()

        # Fallback to X-Real-IP if X-Forwarded-For is not available
        x_real_ip = request.headers.get("X-Real-IP")
        if x_real_ip:
            return x_real_ip

        return request.client.host if request.client else "127.0.0.1"
# Module-wide handler keyed by the configured API password; EncryptionMiddleware
# reads this name when it is instantiated.
encryption_handler = EncryptionHandler(secret_key=settings.api_password)