File size: 4,956 Bytes
ebf2eae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import threading
import requests
from flask import Flask, request, jsonify, abort, Response

class ProxyServer:
    def __init__(self, host='127.0.0.1', port=7860):
        self.app = Flask(__name__)
        self.host = host
        self.port = port
        self._setup_routes()

    def _setup_routes(self):
        # 定义一个通用的代理路由,捕获所有以 /v1/ 开头的请求路径
        # 例如:/v1/https/open.bigmodel.cn/api/paas/v4/chat/completions
        @self.app.route('/v1/<path:url_path>', methods=['GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'OPTIONS'])
        def proxy_request(url_path):
            print(f"接收到的 url_path: {url_path}")

            # url_path 示例: https/open.bigmodel.cn/api/paas/v4/chat/completions
            # 找到第一个 '/' 的位置,分隔协议和域名
            first_slash_idx = url_path.find('/')
            if first_slash_idx == -1:
                abort(400, description="无效的URL路径格式。期望协议/域名/路径。")
            
            protocol = url_path[:first_slash_idx] # 'https'
            print(f"解析出的协议: {protocol}")

            # 找到第二个 '/' 的位置,分隔域名和实际的路径
            # 从 first_slash_idx + 1 开始查找,即从 'open.bigmodel.cn/api/paas/v4/chat/completions' 开始
            second_slash_idx = url_path.find('/', first_slash_idx + 1)

            if second_slash_idx == -1:
                # 如果没有第二个斜杠,说明只有协议和域名,没有后续路径
                domain = url_path[first_slash_idx + 1:] # 'open.bigmodel.cn'
                remaining_path = ''
            else:
                domain = url_path[first_slash_idx + 1:second_slash_idx] # 'open.bigmodel.cn'
                remaining_path = url_path[second_slash_idx:] # '/api/paas/v4/chat/completions'

            target_url = f"{protocol}://{domain}{remaining_path}"
            print(f"\n\n\n代理请求到 {target_url}")

            # 转发原始请求的头部,排除 'Host' 头部以避免冲突
            headers = {key: value for key, value in request.headers if key.lower() != 'host'}
            
            try:
                # 使用 requests 库向目标 URL 发送请求,并转发原始请求的方法、头部、数据和查询参数
                # stream=True 用于流式传输响应,这对于代理大文件或保持连接非常有用
                resp = requests.request(
                    method=request.method,
                    url=target_url,
                    headers=headers,
                    data=request.get_data(),
                    params=request.args,
                    allow_redirects=False, # 不允许 requests 自动处理重定向
                    verify=True # 禁用 SSL 证书验证,仅用于调试
                )

                # 打印目标 API 返回的实际状态码和响应体,用于调试
                print(f"目标API响应状态码: {resp.status_code}")
                print(f"目标API响应体: {resp.text[:500]}...") # 打印前500个字符,避免过长

                # 构建响应头部,并确保 Content-Length 被移除,因为我们将使用流式传输
                # 如果原始响应是分块编码,requests 会自动处理,但我们通过流式传输来确保 Flask 也以流式方式发送
                excluded_headers = ['content-encoding']
                response_headers = [
                    (name, value) for name, value in resp.headers.items()
                    if name.lower() not in excluded_headers
                ]
                
                # 返回流式响应内容
                # 使用 generate_response 函数来迭代响应内容,实现流式传输
                def generate_response():
                    for chunk in resp.iter_content(chunk_size=8192):
                        yield chunk

                response = Response(generate_response(), status=resp.status_code, headers=response_headers)
                return response

            except requests.exceptions.RequestException as e:
                print(f"代理请求失败: {e}")
                abort(500, description=f"代理请求到 {target_url} 失败: {e}")

    def run(self):
        print(f"代理服务器正在 {self.host}:{self.port} 上启动")
        # Flask 默认在开发模式下会为每个请求创建一个新的线程,实现并发处理。
        # 在生产环境中,建议使用 Gunicorn 或 uWSGI 等 WSGI 服务器来管理多线程/多进程。
        self.app.run(host=self.host, port=self.port, debug=True) # 将 debug 模式设置为 True

if __name__ == '__main__':
    # 提示:请确保您已激活 conda 环境 'any-api' (conda activate any-api)
    # 提示:请确保已安装 Flask 和 requests 库 (pip install Flask requests)
    proxy = ProxyServer(
        host='0.0.0.0',
        port=7860
    )
    proxy.run()