File size: 9,640 Bytes
77169b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
本地中转代理(forward proxy)。

用途:
- 浏览器只配置无鉴权的本地代理:127.0.0.1:<port>
- 本地代理再转发到“带用户名密码鉴权”的上游代理(HTTP proxy)

实现重点:
- 支持 CONNECT(HTTPS 隧道)——浏览器最常见的代理用法
- 兼容少量 HTTP 明文请求(GET http://... 这种 absolute-form)
"""

from __future__ import annotations

import base64
import contextlib
import select
import socket
import socketserver
import threading
from dataclasses import dataclass
from typing import Callable, Optional
from urllib.parse import urlparse


def _basic_proxy_auth(username: str, password: str) -> str:
    raw = f"{username}:{password}".encode("utf-8")
    return "Basic " + base64.b64encode(raw).decode("ascii")


def _recv_until(
    sock: socket.socket, marker: bytes, max_bytes: int = 256 * 1024
) -> bytes:
    data = bytearray()
    while marker not in data:
        chunk = sock.recv(4096)
        if not chunk:
            break
        data += chunk
        if len(data) > max_bytes:
            break
    return bytes(data)


def _split_headers(data: bytes) -> tuple[bytes, bytes]:
    idx = data.find(b"\r\n\r\n")
    if idx < 0:
        return data, b""
    return data[: idx + 4], data[idx + 4 :]


def _parse_first_line(header_bytes: bytes) -> tuple[str, str, str]:
    # e.g. "CONNECT example.com:443 HTTP/1.1"
    first = header_bytes.split(b"\r\n", 1)[0].decode("latin-1", errors="replace")
    parts = first.strip().split()
    if len(parts) >= 3:
        return parts[0].upper(), parts[1], parts[2]
    if len(parts) == 2:
        return parts[0].upper(), parts[1], "HTTP/1.1"
    return "GET", "/", "HTTP/1.1"


def _remove_hop_by_hop_headers(header_bytes: bytes) -> bytes:
    # 仅做最小处理:去掉 Proxy-Authorization / Proxy-Connection,避免重复/冲突
    lines = header_bytes.split(b"\r\n")
    if not lines:
        return header_bytes
    out = [lines[0]]
    for line in lines[1:]:
        lower = line.lower()
        if lower.startswith(b"proxy-authorization:"):
            continue
        if lower.startswith(b"proxy-connection:"):
            continue
        out.append(line)
    return b"\r\n".join(out)


def _relay_bidi(a: socket.socket, b: socket.socket, stop_evt: threading.Event) -> None:
    a.setblocking(False)
    b.setblocking(False)
    socks = [a, b]
    try:
        while not stop_evt.is_set():
            r, _, _ = select.select(socks, [], [], 0.5)
            if not r:
                continue
            for s in r:
                try:
                    data = s.recv(65536)
                except BlockingIOError:
                    continue
                if not data:
                    stop_evt.set()
                    break
                other = b if s is a else a
                try:
                    other.sendall(data)
                except OSError:
                    stop_evt.set()
                    break
    finally:
        with contextlib.suppress(Exception):
            a.close()
        with contextlib.suppress(Exception):
            b.close()


class _ThreadingTCPServer(socketserver.ThreadingMixIn, socketserver.TCPServer):
    allow_reuse_address = True
    daemon_threads = True


@dataclass(frozen=True)
class UpstreamProxy:
    host: str
    port: int
    username: str
    password: str

    @property
    def auth_header_value(self) -> str:
        return _basic_proxy_auth(self.username, self.password)


def parse_proxy_server(proxy_server: str) -> tuple[str, int]:
    """
    支持:
    - http://host:port
    - host:port
    """
    s = (proxy_server or "").strip()
    if not s:
        raise ValueError("proxy_server 为空")
    if "://" not in s:
        s = "http://" + s
    u = urlparse(s)
    if not u.hostname or not u.port:
        raise ValueError(f"无法解析 proxy_server: {proxy_server!r}")
    return u.hostname, int(u.port)


class LocalProxyForwarder:
    """
    启动一个本地 HTTP 代理,并把请求/隧道转发到上游代理(带 Basic 鉴权)。
    """

    def __init__(
        self,
        upstream: UpstreamProxy,
        *,
        listen_host: str = "127.0.0.1",
        listen_port: int = 0,
        on_log: Optional[Callable[[str], None]] = None,
    ) -> None:
        self._upstream = upstream
        self._listen_host = listen_host
        self._listen_port = listen_port
        self._on_log = on_log

        self._server: _ThreadingTCPServer | None = None
        self._thread: threading.Thread | None = None

    @property
    def port(self) -> int:
        if not self._server:
            raise RuntimeError("forwarder 尚未启动")
        return int(self._server.server_address[1])

    @property
    def proxy_url(self) -> str:
        return f"http://{self._listen_host}:{self.port}"

    def _log(self, msg: str) -> None:
        if self._on_log:
            try:
                self._on_log(msg)
            except Exception:
                pass

    def start(self) -> "LocalProxyForwarder":
        if self._server is not None:
            return self

        upstream = self._upstream
        parent = self

        class Handler(socketserver.BaseRequestHandler):
            def handle(self) -> None:
                client = self.request
                try:
                    data = _recv_until(client, b"\r\n\r\n")
                    if not data:
                        return
                    header, rest = _split_headers(data)
                    method, target, _ver = _parse_first_line(header)

                    upstream_sock = socket.create_connection(
                        (upstream.host, upstream.port), timeout=15
                    )
                    upstream_sock.settimeout(20)

                    if method == "CONNECT":
                        # 通过上游代理建立到 target 的隧道
                        connect_req = (
                            f"CONNECT {target} HTTP/1.1\r\n"
                            f"Host: {target}\r\n"
                            f"Proxy-Authorization: {upstream.auth_header_value}\r\n"
                            f"Proxy-Connection: keep-alive\r\n"
                            f"Connection: keep-alive\r\n"
                            f"\r\n"
                        ).encode("latin-1", errors="ignore")
                        upstream_sock.sendall(connect_req)
                        upstream_resp = _recv_until(upstream_sock, b"\r\n\r\n")
                        if not upstream_resp:
                            client.sendall(b"HTTP/1.1 502 Bad Gateway\r\n\r\n")
                            return
                        # 将上游响应直接回给浏览器(一般是 200 Connection Established)
                        client.sendall(upstream_resp)

                        # CONNECT 时,header 后可能不会有 body;但如果有残留,丢给上游
                        if rest:
                            upstream_sock.sendall(rest)

                        stop_evt = threading.Event()
                        _relay_bidi(client, upstream_sock, stop_evt)
                        return

                    # 非 CONNECT:把请求转发给上游代理(absolute-form 请求)
                    # 注:这里只做最小实现,主要为兼容偶发 http:// 明文请求
                    filtered = _remove_hop_by_hop_headers(header)
                    # 插入 Proxy-Authorization
                    parts = filtered.split(b"\r\n")
                    out_lines = [parts[0]]
                    inserted = False
                    for line in parts[1:]:
                        if not inserted and line == b"":
                            out_lines.append(
                                f"Proxy-Authorization: {upstream.auth_header_value}".encode(
                                    "latin-1", errors="ignore"
                                )
                            )
                            inserted = True
                        out_lines.append(line)
                    new_header = b"\r\n".join(out_lines)
                    upstream_sock.sendall(new_header)
                    if rest:
                        upstream_sock.sendall(rest)

                    # 单向把响应回写给客户端直到连接关闭
                    while True:
                        chunk = upstream_sock.recv(65536)
                        if not chunk:
                            break
                        client.sendall(chunk)
                except Exception as e:
                    parent._log(f"[proxy] handler error: {e}")
                    with contextlib.suppress(Exception):
                        client.sendall(b"HTTP/1.1 502 Bad Gateway\r\n\r\n")
                finally:
                    with contextlib.suppress(Exception):
                        client.close()

        self._server = _ThreadingTCPServer(
            (self._listen_host, self._listen_port), Handler
        )
        self._thread = threading.Thread(target=self._server.serve_forever, daemon=True)
        self._thread.start()
        return self

    def stop(self) -> None:
        if self._server is None:
            return
        with contextlib.suppress(Exception):
            self._server.shutdown()
        with contextlib.suppress(Exception):
            self._server.server_close()
        self._server = None
        self._thread = None

    def __enter__(self) -> "LocalProxyForwarder":
        return self.start()

    def __exit__(self, exc_type, exc, tb) -> None:
        self.stop()