Spaces:
Running
Running
| from __future__ import annotations | |
| from fastapi import FastAPI, Request, Response, Depends | |
| from fastapi.responses import StreamingResponse | |
| from fastapi.security import HTTPBasic, HTTPBasicCredentials | |
| from contextlib import asynccontextmanager | |
| from urllib.parse import quote | |
| import httpx | |
| import asyncio | |
| import os | |
| import secrets | |
# ============== Keep-alive configuration ==============
# Host that keep_alive_task() requests via https://; set KEEPALIVE_HOST to
# override. An empty value makes the task log and skip instead of pinging.
SPACE_HOST = os.environ.get("KEEPALIVE_HOST", "doasyousay-googleapis.hf.space")
# Seconds to wait between keep-alive pings (5 minutes).
KEEP_ALIVE_INTERVAL = 5 * 60
async def keep_alive_task():
    """Background loop that periodically issues a GET to https://SPACE_HOST.

    Presumably meant to keep the hosting Space from idling. After an initial
    30-second delay, it sends one request every KEEP_ALIVE_INTERVAL seconds and
    logs the response status code. Any ``Exception`` raised during a ping is
    logged and the loop continues. When SPACE_HOST is empty it only logs that
    it is skipping.
    """
    await asyncio.sleep(30)  # grace period before the first ping
    while True:
        try:
            if not SPACE_HOST:
                print("[Keep-Alive] SPACE_HOST not set, skipping...")
            else:
                async with httpx.AsyncClient() as http:
                    reply = await http.get("https://" + SPACE_HOST, timeout=30)
                    print(f"[Keep-Alive] Ping -> {reply.status_code}")
        except Exception as err:
            print(f"[Keep-Alive] Error: {err}")
        await asyncio.sleep(KEEP_ALIVE_INTERVAL)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run keep_alive_task() in the background for the lifetime of the app.

    Startup: the keep-alive task is created and a log line is printed.
    Shutdown: the task is cancelled and then awaited, so it is not left
    pending when the event loop shuts down.

    Args:
        app: The FastAPI application (unused; required by the lifespan protocol).
    """
    task = asyncio.create_task(keep_alive_task())
    print("[Startup] Keep-alive task started")
    try:
        yield
    finally:
        task.cancel()
        # gather(return_exceptions=True) absorbs the task's own CancelledError
        # (or any crash it had) while still propagating a cancellation of
        # this coroutine itself.
        await asyncio.gather(task, return_exceptions=True)
        print("[Shutdown] Keep-alive task cancelled")
# HTTP Basic scheme with auto_error=False: a request without credentials gives
# the dependency None instead of an automatic 401, so verify_auth() decides.
security = HTTPBasic(auto_error=False)

# Expected Basic-auth credentials, overridable via PROXY_USER / PROXY_PASS.
# NOTE(review): the fallback values are hard-coded in source; a real
# deployment should supply these through the environment instead.
USERNAME = os.environ.get("PROXY_USER", "IulHnU")
PASSWORD = os.environ.get("PROXY_PASS", "TtLOY2")

app = FastAPI(lifespan=lifespan)
def parse_proxy_header(proxy_value: str | None) -> str | None:
    """Translate the client's ``proxy`` header into an httpx proxy URL.

    Accepted formats (colon-separated):

    1. ``host:port:username:password`` (4 or more parts)
       -> ``http://username:password@host:port``.
       Everything after the third colon is the password, so passwords may
       contain ``:``. Username and password are percent-encoded, so characters
       such as ``@`` or ``/`` cannot break the URL.
    2. ``host:port`` -> ``http://host:port``.
       A value with exactly three parts also takes this path; its third part
       is ignored.

    Both formats produce an *HTTP* proxy URL.

    Args:
        proxy_value: Raw header value, or None when the header is absent.

    Returns:
        The proxy URL, or None when the value is empty, has fewer than two
        parts, or the port is not numeric. Invalid values are logged.
    """
    if not proxy_value:
        return None

    parts = proxy_value.split(":")
    # At least host:port is required.
    if len(parts) < 2:
        print(f"[Proxy] Invalid format (need host:port): {proxy_value}")
        return None

    host, port = parts[0], parts[1]
    if not port.isdigit():
        print(f"[Proxy] Invalid port: {port}")
        return None

    if len(parts) >= 4:
        username = parts[2]
        password = ":".join(parts[3:])  # the password itself may contain colons
        encoded_user = quote(username, safe="")
        encoded_pass = quote(password, safe="")
        print(f"[Proxy] Using HTTP proxy: {host}:{port} (user={username})")
        return f"http://{encoded_user}:{encoded_pass}@{host}:{port}"

    print(f"[Proxy] Using HTTP proxy: {host}:{port}")
    return f"http://{host}:{port}"
def verify_auth(credentials: HTTPBasicCredentials = Depends(security)):
    """Optionally check HTTP Basic credentials without ever rejecting a request.

    When the client supplied credentials, they are compared in constant time
    with USERNAME / PASSWORD and a mismatch is logged. Missing or wrong
    credentials are still let through: this dependency always returns None.
    """
    if not credentials:
        # No Authorization header: HTTPBasic(auto_error=False) yields None.
        return None
    # secrets.compare_digest raises TypeError for str arguments containing
    # non-ASCII characters. Comparing UTF-8 bytes keeps the check constant-time
    # and avoids a 500 when PROXY_USER / PROXY_PASS are non-ASCII.
    is_user = secrets.compare_digest(
        credentials.username.encode("utf-8"), USERNAME.encode("utf-8")
    )
    is_pass = secrets.compare_digest(
        credentials.password.encode("utf-8"), PASSWORD.encode("utf-8")
    )
    if not (is_user and is_pass):
        print(f"[Auth] Invalid credentials provided (user={credentials.username})")
    return None  # always allow the request through
# Request headers never forwarded upstream: host/proxy are consumed by this
# service itself; the rest could reveal the original client's IP address.
_STRIP_REQUEST_HEADERS = frozenset({
    b"host", b"proxy",
    b"x-forwarded-for", b"x-forwarded-host", b"x-forwarded-proto",
    b"x-real-ip", b"forwarded", b"via",
    b"cf-connecting-ip", b"true-client-ip",
})
# Upstream framing headers that are not copied onto the streamed response.
_SKIP_RESPONSE_HEADERS = frozenset({b"transfer-encoding", b"connection"})


async def _aclose_all(*closables) -> None:
    """Close each httpx response/client in order, skipping None; log failures."""
    for obj in closables:
        if obj is None:
            continue
        try:
            await obj.aclose()
        except Exception as exc:  # cleanup must not mask the caller's outcome
            print(f"Proxy Error: cleanup failed: {exc}")


async def _stream_and_close(resp, client):
    """Yield the upstream body as raw bytes, then close the response and client."""
    try:
        async for chunk in resp.aiter_raw():
            yield chunk
    finally:
        await _aclose_all(resp, client)


async def proxy(target_url: str, request: Request, _: None = Depends(verify_auth)):
    """Forward the incoming request to ``target_url`` and stream the reply back.

    The request body is read in full. Headers are copied except those in
    _STRIP_REQUEST_HEADERS. The request is sent through the proxy described by
    the optional ``proxy`` header (see parse_proxy_header), or directly when
    that header is absent or invalid. Redirects are not followed.

    The upstream status and headers (minus _SKIP_RESPONSE_HEADERS) are
    returned. The body is streamed chunk-by-chunk as raw bytes, i.e. still
    content-encoded exactly as upstream sent it. Any error raised before the
    response is handed back becomes a 500 containing the error text.

    NOTE(review): no route decorator for this handler is visible in this copy
    of the file; confirm how ``target_url`` is bound.
    """
    client = None
    resp = None
    try:
        body = await request.body()
        proxy_url = parse_proxy_header(request.headers.get("proxy"))
        forward_headers = [
            (k, v) for k, v in request.headers.raw
            if k.lower() not in _STRIP_REQUEST_HEADERS
        ]

        client = httpx.AsyncClient(proxy=proxy_url)
        req = client.build_request(
            method=request.method,
            url=target_url,
            headers=forward_headers,
            content=body if body else None,
            timeout=30,
        )
        resp = await client.send(req, stream=True, follow_redirects=False)

        # Stream as data arrives; the generator closes resp and client when done.
        response = StreamingResponse(
            content=_stream_and_close(resp, client),
            status_code=resp.status_code,
        )
        for k, v in resp.headers.raw:
            if k.lower() not in _SKIP_RESPONSE_HEADERS:
                # latin-1 maps every byte to one code point and back, matching
                # Starlette's latin-1 header encoding. A UTF-8 decode would
                # corrupt, or fail on, non-ASCII header values.
                response.headers.append(k.decode("latin-1"), v.decode("latin-1"))
        return response
    except Exception as e:
        import traceback
        print(f"Proxy Error: {e}")
        traceback.print_exc()
        # The streaming generator never started, so its cleanup will not run;
        # release the upstream response and client here to avoid a leak.
        await _aclose_all(resp, client)
        return Response(content=f"Proxy Error: {e}", status_code=500)
if __name__ == "__main__":
    # Local entry point: serve on all interfaces at port 7860.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)