File size: 4,586 Bytes
2ee3bb2
 
 
 
26c3b1f
2ee3bb2
2280bd6
cb2432b
b1049e3
2ee3bb2
 
 
 
 
 
 
 
 
 
 
 
47121b2
2ee3bb2
 
 
 
26c3b1f
 
 
 
 
 
 
940b595
3b27764
2ee3bb2
 
 
 
 
 
 
 
cb2432b
2ee3bb2
26c3b1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7dbc742
 
 
 
 
26c3b1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2ee3bb2
 
00e65da
2ee3bb2
00e65da
3b27764
00e65da
 
940b595
 
 
 
 
00e65da
940b595
00e65da
940b595
 
 
 
 
 
2ee3bb2
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
from fastapi import FastAPI, Request, Response
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
import httpx
import json
from typing import Optional

app = FastAPI()

# Cross-origin policy: every origin, method and header is accepted.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True means
# any website can issue credentialed cross-origin requests through this proxy
# (Starlette's CORSMiddleware reflects the caller's Origin instead of "*" when
# cookies are present — confirm against the installed version). Narrow the
# origin list if the proxy should not be callable from arbitrary sites.
_CORS_SETTINGS = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)

# Upstream origin that all proxy routes forward to.
TARGET_URL = "http://beibeioo.top"

async def get_http_client():
    """Create a fresh ``httpx.AsyncClient`` for a single proxied request.

    The caller owns the returned client and must close it (``aclose()`` or
    ``async with``). Redirects are followed and requests time out after
    30 seconds.

    HTTP/2 is enabled only when the optional ``h2`` package is importable:
    httpx raises ``ImportError`` when ``http2=True`` is requested without it,
    which would otherwise turn every proxied request into a 500 error.
    """
    import importlib.util

    http2_available = importlib.util.find_spec("h2") is not None
    return httpx.AsyncClient(
        timeout=30.0,
        follow_redirects=True,
        http2=http2_available,
    )

async def is_stream_request(request: Request) -> bool:
    """Return True when *request* asks for a streamed (event-stream) reply.

    A request counts as streaming if its ``Accept`` header contains
    ``text/event-stream``, or if its ``stream`` query parameter equals
    ``true`` (case-insensitive). Otherwise it is treated as a normal request.
    """
    # Clients such as EventSource announce streaming via the Accept header.
    if 'text/event-stream' in request.headers.get('accept', ''):
        return True
    # Otherwise fall back to an explicit ?stream=true opt-in.
    return request.query_params.get('stream', '').lower() == 'true'

# Request headers that describe the client->proxy hop rather than the
# proxy->upstream request. httpx recomputes host/length/framing from `url` and
# `content`. accept-encoding is dropped so httpx advertises only encodings it
# can decode: the body relayed to the client is always decoded (see below).
_STRIPPED_REQUEST_HEADERS = (
    'host',
    'connection',
    'content-length',
    'transfer-encoding',
    'accept-encoding',
)

# Upstream response headers that no longer describe the bytes we relay: the
# body is decoded by httpx and re-framed by the ASGI server.
_STRIPPED_RESPONSE_HEADERS = (
    'transfer-encoding',
    'content-encoding',
    'content-length',
)


def _forwardable_request_headers(request: Request) -> dict:
    """Copy the incoming request headers minus the hop-specific ones."""
    headers = dict(request.headers)
    for name in _STRIPPED_REQUEST_HEADERS:
        headers.pop(name, None)
    return headers


def _client_response_headers(response: httpx.Response) -> dict:
    """Copy the upstream response headers minus encoding/framing headers."""
    headers = dict(response.headers)
    for name in _STRIPPED_RESPONSE_HEADERS:
        headers.pop(name, None)
    return headers


async def proxy_handler(url: str, request: Request):
    """Forward *request* to *url* and relay the upstream reply to the caller.

    Method, query parameters (including repeated keys), body and most headers
    are forwarded as-is. When ``is_stream_request`` reports a streaming
    request, the upstream body is relayed chunk by chunk through a
    ``StreamingResponse``. Otherwise the whole body is read and returned in a
    plain ``Response``.

    Upstream failures are mapped to error responses instead of raising:
    504 on ``httpx.TimeoutException``, 502 on any other ``httpx.RequestError``,
    and 500 (with the exception text) for anything else.
    """
    client = None
    # Once True, the stream generator (not this function) closes the client.
    owned_by_stream = False
    try:
        headers = _forwardable_request_headers(request)
        # multi_items() keeps repeated keys (?a=1&a=2) that dict() would drop.
        params = list(request.query_params.multi_items())
        body = await request.body()

        client = await get_http_client()
        upstream_request = client.build_request(
            method=request.method,
            url=url,
            params=params,
            headers=headers,
            content=body if body else None,
        )

        if await is_stream_request(request):
            # Streaming request. AsyncClient.stream() is an async context
            # manager and cannot be awaited. send(stream=True) returns a
            # response whose body is read lazily, and it must outlive this
            # function, so the client cannot live in an `async with` here.
            response = await client.send(upstream_request, stream=True)

            async def stream_generator():
                try:
                    # aiter_bytes() yields *decoded* bytes, consistent with the
                    # stripped content-encoding header (aiter_raw would relay
                    # compressed bytes the client is no longer told about).
                    async for chunk in response.aiter_bytes():
                        # Send each chunk on as soon as it arrives.
                        if chunk:
                            yield chunk
                except Exception as e:
                    # Best effort: the status line was already sent, so just
                    # log and end the stream.
                    print(f"Streaming error: {e}")
                finally:
                    await response.aclose()
                    await client.aclose()

            streaming = StreamingResponse(
                stream_generator(),
                status_code=response.status_code,
                headers=_client_response_headers(response),
                media_type=response.headers.get('content-type'),
            )
            owned_by_stream = True
            return streaming

        # Non-streaming request: read the whole upstream body, then reply.
        response = await client.send(upstream_request)
        return Response(
            content=response.content,
            status_code=response.status_code,
            headers=_client_response_headers(response),
            media_type=response.headers.get('content-type'),
        )

    except httpx.TimeoutException:
        return Response(content="请求超时", status_code=504)
    except httpx.RequestError:
        return Response(content="无法连接到目标服务器", status_code=502)
    except Exception as e:
        return Response(content=f"服务器错误: {str(e)}", status_code=500)
    finally:
        if client is not None and not owned_by_stream:
            await client.aclose()

# HTTP methods relayed by the proxy routes below.
_PROXY_METHODS = ["GET", "POST", "PUT", "DELETE", "OPTIONS", "HEAD", "PATCH"]


# Registered first on purpose: routes are matched in registration order, and
# the catch-all "/{path:path}" route below would otherwise capture
# GET /healthcheck and forward it upstream, making this route unreachable.
@app.get("/healthcheck")
async def healthcheck():
    """Report that the proxy process is running (does not contact upstream)."""
    return {"status": "healthy"}


# /api/v1/<path> is forwarded to <TARGET_URL>/v1/<path> (the /api prefix is dropped).
@app.api_route("/api/v1/{path:path}", methods=_PROXY_METHODS)
async def api_v1_proxy(path: str, request: Request):
    """Proxy an /api/v1/* request to the upstream /v1/* endpoint."""
    url = f"{TARGET_URL}/v1/{path}"
    return await proxy_handler(url, request)


# Root path and every other path are forwarded to the same path on TARGET_URL.
@app.api_route("/{path:path}", methods=_PROXY_METHODS)
async def root_proxy(path: str, request: Request):
    """Proxy any other request to the same path on the upstream server."""
    if not path or path == "/":
        url = TARGET_URL
    else:
        url = f"{TARGET_URL}/{path}"
    return await proxy_handler(url, request)

if __name__ == "__main__":
    # Running the module directly starts uvicorn on all interfaces, port 7860.
    # uvicorn is imported here so that merely importing this module (e.g.
    # under an external ASGI server) does not import it.
    import uvicorn

    _listen = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **_listen)