hins111 committed on
Commit
7861a83
·
verified ·
1 Parent(s): 161fc74

Upload 7 files

Browse files
stream/__init__.py CHANGED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import multiprocessing
3
+
4
+ from stream import main
5
+
6
def start(*args, **kwargs):
    """Start the streaming proxy server.

    Supports both calling conventions used by callers:

    Positional (legacy-compatible):
        start(queue, port, proxy)

    Keyword:
        start(queue=queue, port=port, proxy=proxy)

    Missing values default to None; ``main.builtin`` supplies its own
    fallbacks (e.g. port 3120).
    """
    if args:
        # Positional mode: pad missing trailing arguments with None,
        # ignore anything beyond the first three.
        queue, port, proxy = (list(args) + [None] * 3)[:3]
    else:
        # Keyword mode: absent keys simply become None.
        queue = kwargs.get('queue')
        port = kwargs.get('port')
        proxy = kwargs.get('proxy')

    # builtin() is a coroutine; drive it on a fresh event loop.
    asyncio.run(main.builtin(queue=queue, port=port, proxy=proxy))
stream/cert_manager.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import datetime
3
+ from pathlib import Path
4
+ from cryptography import x509
5
+ from cryptography.x509.oid import NameOID
6
+ from cryptography.hazmat.primitives import hashes, serialization
7
+ from cryptography.hazmat.primitives.asymmetric import rsa
8
+ from cryptography.hazmat.backends import default_backend
9
+
10
class CertificateManager:
    """Manage a local CA and per-domain TLS certificates for interception.

    On construction, a self-signed CA key/cert pair is created under
    ``cert_dir`` (unless both files already exist) and loaded into memory.
    Domain certificates signed by that CA are generated on demand and
    cached on disk as ``<domain>.crt`` / ``<domain>.key``.
    """

    def __init__(self, cert_dir='certs'):
        self.cert_dir = Path(cert_dir)
        self.cert_dir.mkdir(exist_ok=True)

        self.ca_key_path = self.cert_dir / 'ca.key'
        self.ca_cert_path = self.cert_dir / 'ca.crt'

        # Regenerate the CA material if either half is missing.
        if not self.ca_cert_path.exists() or not self.ca_key_path.exists():
            self._generate_ca_cert()

        self._load_ca_cert()

    def _generate_ca_cert(self):
        """Generate a self-signed CA certificate valid for ~10 years."""
        # Generate the CA private key.
        private_key = rsa.generate_private_key(
            public_exponent=65537,
            key_size=2048,
            backend=default_backend()
        )

        # Persist the private key unencrypted (PKCS8 PEM).
        with open(self.ca_key_path, 'wb') as f:
            f.write(private_key.private_bytes(
                encoding=serialization.Encoding.PEM,
                format=serialization.PrivateFormat.PKCS8,
                encryption_algorithm=serialization.NoEncryption()
            ))

        # Self-signed: subject and issuer are the same name.
        subject = issuer = x509.Name([
            x509.NameAttribute(NameOID.COUNTRY_NAME, "US"),
            x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, "California"),
            x509.NameAttribute(NameOID.LOCALITY_NAME, "San Francisco"),
            x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Proxy CA"),
            x509.NameAttribute(NameOID.COMMON_NAME, "Proxy CA Root"),
        ])

        # Capture one timezone-aware timestamp so both validity bounds are
        # consistent (datetime.utcnow() is deprecated since Python 3.12).
        now = datetime.datetime.now(datetime.timezone.utc)

        cert = x509.CertificateBuilder().subject_name(
            subject
        ).issuer_name(
            issuer
        ).public_key(
            private_key.public_key()
        ).serial_number(
            x509.random_serial_number()
        ).not_valid_before(
            now
        ).not_valid_after(
            now + datetime.timedelta(days=3650)
        ).add_extension(
            # Mark as a CA so signed leaf certs chain to it.
            x509.BasicConstraints(ca=True, path_length=None), critical=True
        ).add_extension(
            x509.KeyUsage(
                digital_signature=True,
                content_commitment=False,
                key_encipherment=True,
                data_encipherment=False,
                key_agreement=False,
                key_cert_sign=True,
                crl_sign=True,
                encipher_only=False,
                decipher_only=False
            ), critical=True
        ).sign(private_key, hashes.SHA256(), default_backend())

        # Write the CA certificate to disk (PEM).
        with open(self.ca_cert_path, 'wb') as f:
            f.write(cert.public_bytes(serialization.Encoding.PEM))

    def _load_ca_cert(self):
        """Load the CA certificate and private key from disk into memory."""
        with open(self.ca_key_path, 'rb') as f:
            self.ca_key = serialization.load_pem_private_key(
                f.read(),
                password=None,
                backend=default_backend()
            )

        with open(self.ca_cert_path, 'rb') as f:
            self.ca_cert = x509.load_pem_x509_certificate(
                f.read(),
                default_backend()
            )

    def get_domain_cert(self, domain):
        """Return a (private_key, certificate) pair for *domain*.

        Reuses the on-disk pair when both files exist; otherwise generates
        a fresh certificate signed by the CA.
        """
        cert_path = self.cert_dir / f"{domain}.crt"
        key_path = self.cert_dir / f"{domain}.key"

        if cert_path.exists() and key_path.exists():
            # Load the cached certificate and key.
            with open(key_path, 'rb') as f:
                private_key = serialization.load_pem_private_key(
                    f.read(),
                    password=None,
                    backend=default_backend()
                )

            with open(cert_path, 'rb') as f:
                cert = x509.load_pem_x509_certificate(
                    f.read(),
                    default_backend()
                )

            return private_key, cert

        # Nothing cached: generate a new certificate.
        return self._generate_domain_cert(domain)

    def _generate_domain_cert(self, domain):
        """Generate a 1-year certificate for *domain* signed by the CA."""
        # Generate the leaf private key.
        private_key = rsa.generate_private_key(
            public_exponent=65537,
            key_size=2048,
            backend=default_backend()
        )

        # Persist the private key unencrypted (PKCS8 PEM).
        key_path = self.cert_dir / f"{domain}.key"
        with open(key_path, 'wb') as f:
            f.write(private_key.private_bytes(
                encoding=serialization.Encoding.PEM,
                format=serialization.PrivateFormat.PKCS8,
                encryption_algorithm=serialization.NoEncryption()
            ))

        subject = x509.Name([
            x509.NameAttribute(NameOID.COUNTRY_NAME, "US"),
            x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, "California"),
            x509.NameAttribute(NameOID.LOCALITY_NAME, "San Francisco"),
            x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Proxy Server"),
            x509.NameAttribute(NameOID.COMMON_NAME, domain),
        ])

        # One consistent, timezone-aware timestamp for both validity bounds.
        now = datetime.datetime.now(datetime.timezone.utc)

        cert = x509.CertificateBuilder().subject_name(
            subject
        ).issuer_name(
            # Issued by our CA, not self-signed.
            self.ca_cert.subject
        ).public_key(
            private_key.public_key()
        ).serial_number(
            x509.random_serial_number()
        ).not_valid_before(
            now
        ).not_valid_after(
            now + datetime.timedelta(days=365)
        ).add_extension(
            # Modern clients validate the SAN, not the CN.
            x509.SubjectAlternativeName([x509.DNSName(domain)]),
            critical=False
        ).sign(self.ca_key, hashes.SHA256(), default_backend())

        # Write the certificate to disk (PEM).
        cert_path = self.cert_dir / f"{domain}.crt"
        with open(cert_path, 'wb') as f:
            f.write(cert.public_bytes(serialization.Encoding.PEM))

        return private_key, cert
stream/interceptors.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ import re
4
+ import zlib
5
+
6
class HttpInterceptor:
    """Intercept and decode HTTP requests/responses flowing through the proxy.

    Responses are expected to use chunked transfer encoding and may be
    gzip/zlib compressed; ``process_response`` decodes them and extracts
    model output (body text, reasoning text and tool calls) from the wire
    format produced by GenerateContent endpoints.
    """

    def __init__(self, log_dir='logs'):
        # log_dir is stored for reference; logging currently goes to stderr.
        self.log_dir = log_dir
        self.logger = logging.getLogger('http_interceptor')
        self.setup_logging()

    @staticmethod
    def setup_logging():
        """Configure root logging to stream to stderr at INFO level."""
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
            handlers=[
                logging.StreamHandler()
            ]
        )

    @staticmethod
    def should_intercept(host, path):
        """Return True if requests to *host*/*path* should be intercepted.

        Currently only GenerateContent endpoints are of interest; *host*
        is accepted for interface symmetry but not consulted.
        """
        return 'GenerateContent' in path

    async def process_request(self, request_data, host, path):
        """Log an intercepted request; the body is forwarded unmodified."""
        if not self.should_intercept(host, path):
            return request_data

        self.logger.info(f"Intercepted request to {host}{path}")

        # Request bodies are currently passed through untouched.
        return request_data

    async def process_response(self, response_data, host, path, headers):
        """Decode a chunked, possibly compressed response body.

        Returns a dict with keys 'reason', 'body', 'function' and 'done'
        ('done' is True once the terminating zero-length chunk was seen).
        Raises whatever the decoding steps raise on malformed input.
        """
        # Undo transfer-encoding first, then content-encoding.
        decoded_data, is_done = self._decode_chunked(bytes(response_data))
        decoded_data = self._decompress_zlib_stream(decoded_data)
        result = self.parse_response(decoded_data)
        result["done"] = is_done
        return result

    def parse_response(self, response_data):
        """Extract body/reason/function parts from the decoded payload.

        Each model message on the wire is serialized as
        ``[[[null, ...]], "model"]``; the inner payload's length and shape
        determine whether it carries body text, reasoning or a tool call.
        """
        pattern = rb'\[\[\[null,.*?]],"model"]'
        matches = [m.group(0) for m in re.finditer(pattern, response_data)]

        resp = {
            "reason": "",
            "body": "",
            "function": [],
        }

        for match in matches:
            json_data = json.loads(match)

            try:
                payload = json_data[0][0]
            except Exception:
                # Unexpected shape; skip this message.
                continue

            if len(payload) == 2:  # plain body text
                resp["body"] = resp["body"] + payload[1]
            elif len(payload) == 11 and payload[1] is None and isinstance(payload[10], list):
                # Tool call: payload[10] is [name, [params...]].
                array_tool_calls = payload[10]
                func_name = array_tool_calls[0]
                params = self.parse_toolcall_params(array_tool_calls[1])
                resp["function"].append({"name": func_name, "params": params})
            elif len(payload) > 2:  # reasoning text
                resp["reason"] = resp["reason"] + payload[1]

        return resp

    def parse_toolcall_params(self, args):
        """Decode tool-call parameters from their positional wire encoding.

        The slot used inside each value list depends on the value's type:
        index 1 = number, 2 = string, 3 = boolean (encoded 0/1),
        4 = nested object; a single-element list encodes null.
        Raises on malformed input (e.g. empty *args*).
        """
        params = args[0]
        func_params = {}
        for param in params:
            param_name = param[0]
            param_value = param[1]

            if isinstance(param_value, list):
                if len(param_value) == 1:  # null
                    func_params[param_name] = None
                elif len(param_value) == 2:  # number / integer
                    func_params[param_name] = param_value[1]
                elif len(param_value) == 3:  # string
                    func_params[param_name] = param_value[2]
                elif len(param_value) == 4:  # boolean encoded as 0/1
                    func_params[param_name] = param_value[3] == 1
                elif len(param_value) == 5:  # nested object: recurse
                    func_params[param_name] = self.parse_toolcall_params(param_value[4])
        return func_params

    @staticmethod
    def _decompress_zlib_stream(compressed_stream):
        """Decompress a zlib or gzip stream (wbits|32 auto-detects the header)."""
        decompressor = zlib.decompressobj(wbits=zlib.MAX_WBITS | 32)
        decompressed = decompressor.decompress(compressed_stream)
        return decompressed

    @staticmethod
    def _decode_chunked(response_body: bytes) -> tuple[bytes, bool]:
        """Decode HTTP chunked transfer encoding.

        Returns (payload, done): *done* is True once the terminating
        zero-length chunk marker (``0\\r\\n\\r\\n``) has been seen.
        Incomplete trailing chunks are left undecoded (best effort on a
        partially buffered stream).
        """
        chunked_data = bytearray()
        while True:
            # Locate the end of the chunk-size line.
            length_crlf_idx = response_body.find(b"\r\n")
            if length_crlf_idx == -1:
                break  # size line not fully buffered yet

            hex_length = response_body[:length_crlf_idx]
            try:
                length = int(hex_length, 16)
            except ValueError as e:
                logging.error(f"Parsing chunked length failed: {e}")
                break

            if length == 0:
                # Terminal chunk: confirm the full "0\r\n\r\n" marker arrived.
                length_crlf_idx = response_body.find(b"0\r\n\r\n")
                if length_crlf_idx != -1:
                    return chunked_data, True

            if length + 2 > len(response_body):
                break  # chunk body not fully buffered yet

            chunked_data.extend(response_body[length_crlf_idx + 2:length_crlf_idx + 2 + length])
            if length_crlf_idx + 2 + length + 2 > len(response_body):
                break  # trailing CRLF of this chunk not buffered yet

            # Advance past size line + chunk body + trailing CRLF.
            response_body = response_body[length_crlf_idx + 2 + length + 2:]
        return chunked_data, False
stream/main.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import asyncio
3
+ import logging
4
+ import multiprocessing
5
+ import sys
6
+ from pathlib import Path
7
+
8
+ from stream.proxy_server import ProxyServer
9
+
10
def parse_args(argv=None):
    """Parse command line arguments.

    Args:
        argv: Optional list of argument strings; when None (the default),
            argparse reads ``sys.argv[1:]``. Passing an explicit list keeps
            the function usable (and testable) without touching sys.argv.

    Returns:
        argparse.Namespace with ``host``, ``port``, ``domains`` and
        ``proxy`` attributes.
    """
    parser = argparse.ArgumentParser(description='HTTPS Proxy Server with SSL Inspection')

    parser.add_argument('--host', default='127.0.0.1', help='Host to bind the proxy server')
    parser.add_argument('--port', type=int, default=3120, help='Port to bind the proxy server')
    parser.add_argument('--domains', nargs='+', default=['*.google.com'],
                        help='List of domain patterns to intercept (regex)')
    parser.add_argument('--proxy', help='Upstream proxy URL (e.g., http://user:pass@host:port)')

    return parser.parse_args(argv)
21
+
22
+
23
async def main():
    """Command-line entry point: parse args and run the proxy server."""
    args = parse_args()

    # Configure stderr logging once at startup.
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        handlers=[logging.StreamHandler()],
    )
    logger = logging.getLogger('main')

    # Make sure the certificate store exists before the server needs it.
    Path('certs').mkdir(exist_ok=True)

    # Announce the effective configuration.
    logger.info(f"Starting proxy server on {args.host}:{args.port}")
    logger.info(f"Intercepting domains: {args.domains}")
    if args.proxy:
        logger.info(f"Using upstream proxy: {args.proxy}")

    server = ProxyServer(
        host=args.host,
        port=args.port,
        intercept_domains=args.domains,
        upstream_proxy=args.proxy,
        queue=None,
    )

    try:
        await server.start()
    except KeyboardInterrupt:
        logger.info("Shutting down proxy server")
    except Exception as e:
        logger.error(f"Error starting proxy server: {e}")
        sys.exit(1)
64
+
65
+
66
async def builtin(queue: multiprocessing.Queue = None, port=None, proxy=None):
    """Run the proxy server programmatically (in-process entry point).

    Args:
        queue: Optional multiprocessing queue that receives intercepted
            responses as JSON strings.
        port: Port to listen on; falls back to 3120 when None.
        proxy: Optional upstream proxy URL.
    """
    # Configure stderr logging once at startup.
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        handlers=[logging.StreamHandler()],
    )
    logger = logging.getLogger('main')

    # Make sure the certificate store exists before the server needs it.
    Path('certs').mkdir(exist_ok=True)

    server = ProxyServer(
        host="127.0.0.1",
        port=3120 if port is None else port,
        intercept_domains=['*.google.com'],
        upstream_proxy=proxy,
        queue=queue,
    )

    try:
        await server.start()
    except KeyboardInterrupt:
        logger.info("Shutting down proxy server")
    except Exception as e:
        logger.error(f"Error starting proxy server: {e}")
        sys.exit(1)
101
+
102
# Script entry point: run the argparse-driven server loop (see main()).
if __name__ == '__main__':
    asyncio.run(main())
stream/proxy_connector.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import ssl as ssl_module
3
+ import urllib.parse
4
+ from aiohttp import TCPConnector
5
+ from python_socks.async_.asyncio import Proxy
6
+
7
+
8
class ProxyConnector:
    """
    Class to handle connections through different types of proxies.

    When no proxy URL is configured, connections are opened directly.
    When one is configured, connections are tunnelled through it via
    python-socks (``Proxy.from_url`` also accepts http/https URLs).
    """

    def __init__(self, proxy_url=None):
        # Full proxy URL, e.g. "socks5://user:pass@host:port", or None.
        self.proxy_url = proxy_url
        # Truthy marker set by _setup_connector; None means direct connection.
        self.connector = None

        if proxy_url:
            self._setup_connector()

    def _setup_connector(self):
        """Set up the appropriate connector based on the proxy URL.

        Validates the scheme and marks the connector as active; raises
        ValueError for unsupported schemes.
        """
        if not self.proxy_url:
            # NOTE(review): unreachable in practice — _setup_connector is only
            # called from __init__ when proxy_url is truthy. This is also the
            # sole use of aiohttp's TCPConnector.
            self.connector = TCPConnector()
            return

        # Parse the proxy URL to extract its scheme.
        parsed = urllib.parse.urlparse(self.proxy_url)
        proxy_type = parsed.scheme.lower()

        if proxy_type in ('http', 'https', 'socks4', 'socks5'):
            # Used only as a truthy flag in create_connection; the actual
            # tunnel is built with python_socks.Proxy, not aiohttp.
            self.connector = "SocksConnector"
        else:
            raise ValueError(f"Unsupported proxy type: {proxy_type}")

    async def create_connection(self, host, port, ssl=None):
        """Create a connection to the target host through the proxy.

        Args:
            host: Destination hostname.
            port: Destination port.
            ssl: When not None, the proxied connection is wrapped in TLS
                (for direct connections the value is passed straight to
                asyncio.open_connection).

        Returns:
            tuple: (StreamReader, StreamWriter) for the established stream.
        """
        if not self.connector:
            # Direct connection without proxy
            reader, writer = await asyncio.open_connection(host, port, ssl=ssl)
            return reader, writer

        # Proxy connection: let python_socks negotiate the tunnel and hand
        # us an already-connected socket.
        proxy = Proxy.from_url(self.proxy_url)
        sock = await proxy.connect(dest_host=host, dest_port=port)
        if ssl is None:
            # Plain TCP over the tunnel; host/port must be None when an
            # existing socket is supplied.
            reader, writer = await asyncio.open_connection(
                host=None,
                port=None,
                sock=sock,
                ssl=None,
            )
            return reader, writer
        else:
            # TLS over the tunnel. Certificate verification is disabled
            # (check_hostname=False, CERT_NONE) — presumably because the
            # interception use case breaks normal chain validation; confirm
            # before reusing this path for anything security-sensitive.
            ssl_context = ssl_module.SSLContext(ssl_module.PROTOCOL_TLS_CLIENT)
            ssl_context.check_hostname = False
            ssl_context.verify_mode = ssl_module.CERT_NONE
            ssl_context.minimum_version = ssl_module.TLSVersion.TLSv1_2  # Force TLS 1.2 or higher
            ssl_context.maximum_version = ssl_module.TLSVersion.TLSv1_3  # Allow TLS 1.3 if supported
            ssl_context.set_ciphers('DEFAULT@SECLEVEL=2')  # Use secure ciphers

            # server_hostname is required for SNI since ssl is enabled but
            # host= is None (socket already connected).
            reader, writer = await asyncio.open_connection(
                host=None,
                port=None,
                sock=sock,
                ssl=ssl_context,
                server_hostname=host,
            )
            return reader, writer
stream/proxy_server.py ADDED
@@ -0,0 +1,357 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ from typing import Optional
3
+ import json
4
+ import logging
5
+ import ssl
6
+ import multiprocessing
7
+ from pathlib import Path
8
+ from urllib.parse import urlparse
9
+
10
+ from stream.cert_manager import CertificateManager
11
+ from stream.proxy_connector import ProxyConnector
12
+ from stream.interceptors import HttpInterceptor
13
+
14
class ProxyServer:
    """
    Asynchronous HTTPS proxy server with SSL inspection capabilities.

    Accepts CONNECT requests; connections whose host matches one of
    ``intercept_domains`` are TLS-terminated with a locally generated
    certificate (see CertificateManager), sniffed for GenerateContent
    traffic, and re-encrypted towards the real server.  All other
    connections are blindly tunnelled.
    """
    def __init__(self, host='0.0.0.0', port=3120, intercept_domains=None, upstream_proxy=None, queue: Optional[multiprocessing.Queue]=None):
        self.host = host
        self.port = port
        # Exact hostnames or "*.suffix" wildcard patterns.
        self.intercept_domains = intercept_domains or []
        self.upstream_proxy = upstream_proxy
        # Optional queue receiving intercepted responses as JSON strings.
        self.queue = queue

        # Initialize components
        self.cert_manager = CertificateManager()
        self.proxy_connector = ProxyConnector(upstream_proxy)

        # Create logs directory
        log_dir = Path('logs')
        log_dir.mkdir(exist_ok=True)
        self.interceptor = HttpInterceptor(str(log_dir))

        # Set up logging
        self.logger = logging.getLogger('proxy_server')

    def should_intercept(self, host):
        """
        Determine if the connection to the host should be intercepted.

        Exact matches are checked first, then "*.suffix" wildcards
        (note: "*.example.com" matches "a.example.com" via endswith on
        ".example.com", but not the bare "example.com").
        """
        if host in self.intercept_domains:
            return True

        # Wildcard match (e.g. *.example.com)
        for d in self.intercept_domains:
            if d.startswith("*."):
                suffix = d[1:]  # Remove the leading "*", keeping the dot
                if host.endswith(suffix):
                    return True

        return False

    async def handle_client(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
        """
        Handle a client connection.

        Only the CONNECT method is handled; anything else is read and
        silently dropped when the connection closes.
        """
        try:
            # Read the initial request line
            request_line = await reader.readline()
            request_line = request_line.decode('utf-8').strip()

            if not request_line:
                writer.close()
                return

            # Parse the request line
            method, target, version = request_line.split(' ')

            if method == 'CONNECT':
                # Handle HTTPS connection
                await self._handle_connect(reader, writer, target)

        except Exception as e:
            self.logger.error(f"Error handling client: {e}")
        finally:
            writer.close()

    async def _handle_connect(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, target: str):
        """
        Handle CONNECT method (for HTTPS connections).

        For intercepted hosts the client side is upgraded to TLS using a
        locally issued certificate, then traffic is relayed with sniffing;
        otherwise bytes are tunnelled untouched.
        """

        host, port = target.split(':')
        port = int(port)
        # Determine if we should intercept this connection
        intercept = self.should_intercept(host)

        if intercept:
            self.logger.info(f"Sniff HTTPS requests to : {target}")

            # Ensure <host>.crt / <host>.key exist on disk before we load them.
            self.cert_manager.get_domain_cert(host)

            # Send 200 Connection Established to the client
            writer.write(b'HTTP/1.1 200 Connection Established\r\n\r\n')
            await writer.drain()

            # Drop the proxy connect header
            await reader.read(8192)

            loop = asyncio.get_running_loop()
            transport = writer.transport  # This is the original client transport

            if transport is None:
                # Defensive check before the TLS upgrade: the writer is likely
                # already closed or in a bad state, and start_tls cannot work
                # without a transport.
                self.logger.warning(f"Client writer transport is None for {host}:{port} before TLS upgrade. Closing.")
                return

            # Serve the client with the per-domain certificate issued by our CA.
            ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
            ssl_context.load_cert_chain(
                certfile=self.cert_manager.cert_dir / f"{host}.crt",
                keyfile=self.cert_manager.cert_dir / f"{host}.key"
            )

            # 1. Fetch the protocol instance associated with the original
            #    transport ('transport' is writer.transport, checked non-None).
            client_protocol = transport.get_protocol()

            # 2. Hand that protocol instance to start_tls — start_tls mutates
            #    it so it becomes associated with the new TLS transport.
            new_transport = await loop.start_tls(
                transport=transport,
                protocol=client_protocol,  # key: pass the fetched protocol instance
                sslcontext=ssl_context,
                server_side=True
            )

            # 3. Guard against start_tls returning None (type safety / Pylance).
            if new_transport is None:
                self.logger.error(f"loop.start_tls returned None for {host}:{port}, which is unexpected. Closing connection.")
                # Ensure client writer is closed if it was opened or transport was valid before
                writer.close()
                # await writer.wait_closed() # Consider if waiting is necessary here
                return

            client_reader = reader

            # 4. Build the StreamWriter from the protocol instance that
            #    start_tls just updated, paired with the new TLS transport.
            client_writer = asyncio.StreamWriter(
                transport=new_transport,  # the new TLS transport
                protocol=client_protocol,  # key: the protocol updated by start_tls
                reader=client_reader,
                loop=loop
            )

            # Connect to the target server
            try:
                server_reader, server_writer = await self.proxy_connector.create_connection(
                    host, port, ssl=ssl.create_default_context()
                )

                # Start bidirectional forwarding with interception
                await self._forward_data_with_interception(
                    client_reader, client_writer,
                    server_reader, server_writer,
                    host
                )
            except Exception as e:
                # self.logger.error(f"Error connecting to server {host}:{port}: {e}")
                client_writer.close()
                # await client_writer.wait_closed()
        else:
            # No interception, just forward the connection
            writer.write(b'HTTP/1.1 200 Connection Established\r\n\r\n')
            await writer.drain()

            # Drop the proxy connect header
            await reader.read(8192)

            try:
                # Connect to the target server
                server_reader, server_writer = await self.proxy_connector.create_connection(
                    host, port, ssl=None
                )

                # Start bidirectional forwarding without interception
                await self._forward_data(
                    reader, writer,
                    server_reader, server_writer
                )
            except Exception as e:
                # self.logger.error(f"Error connecting to server {host}:{port}: {e}")
                writer.close()
                # await writer.wait_closed()

    async def _forward_data(self, client_reader, client_writer, server_reader, server_writer):
        """
        Forward data between client and server without interception.
        """
        async def _forward(reader, writer):
            # Pump bytes one direction until EOF, then close the sink side.
            try:
                while True:
                    data = await reader.read(8192)
                    if not data:
                        break
                    writer.write(data)
                    await writer.drain()
            except Exception as e:
                self.logger.error(f"Error forwarding data: {e}")
            finally:
                writer.close()

        # Create tasks for both directions
        client_to_server = asyncio.create_task(_forward(client_reader, server_writer))
        server_to_client = asyncio.create_task(_forward(server_reader, client_writer))

        # Wait for both tasks to complete
        tasks = [client_to_server, server_to_client]
        await asyncio.gather(*tasks)
        # await asyncio.gather(client_to_server, server_to_client)

    async def _forward_data_with_interception(self, client_reader, client_writer,
                                              server_reader, server_writer, host):
        """
        Forward data between client and server with interception.

        Requests are inspected to flag GenerateContent traffic; flagged
        responses are decoded and pushed onto self.queue as JSON.  The raw
        bytes themselves are always relayed unmodified.
        """
        # Buffer to store HTTP request/response data
        client_buffer = bytearray()
        server_buffer = bytearray()
        # Set per-request by the client pump; read by the server pump.
        should_sniff = False

        # Parse HTTP headers from client
        async def _process_client_data():
            nonlocal client_buffer, should_sniff

            try:
                while True:
                    data = await client_reader.read(8192)
                    if not data:
                        break
                    client_buffer.extend(data)

                    # Try to parse HTTP request
                    if b'\r\n\r\n' in client_buffer:
                        # Split headers and body
                        headers_end = client_buffer.find(b'\r\n\r\n') + 4
                        headers_data = client_buffer[:headers_end]
                        body_data = client_buffer[headers_end:]

                        # Parse request line and headers
                        lines = headers_data.split(b'\r\n')
                        request_line = lines[0].decode('utf-8')

                        try:
                            method, path, _ = request_line.split(' ')
                        except ValueError:
                            # Not a valid HTTP request, just forward
                            server_writer.write(client_buffer)
                            await server_writer.drain()
                            client_buffer.clear()
                            continue

                        # Check if we should intercept this request
                        if 'GenerateContent' in path:
                            should_sniff = True
                            # Process the request body
                            processed_body = await self.interceptor.process_request(
                                body_data, host, path
                            )

                            # Send the processed request
                            server_writer.write(headers_data)
                            server_writer.write(processed_body)
                        else:
                            should_sniff = False
                            # Forward the request as is
                            server_writer.write(client_buffer)

                        await server_writer.drain()
                        client_buffer.clear()
                    else:
                        # Not enough data to parse headers, forward as is
                        server_writer.write(data)
                        await server_writer.drain()
                        client_buffer.clear()
            except Exception as e:
                self.logger.error(f"Error processing client data: {e}")
            finally:
                server_writer.close()
                # await server_writer.wait_closed()

        # Parse HTTP headers from server
        async def _process_server_data():
            nonlocal server_buffer, should_sniff

            try:
                while True:
                    data = await server_reader.read(8192)
                    if not data:
                        break

                    server_buffer.extend(data)
                    if b'\r\n\r\n' in server_buffer:
                        # Split headers and body
                        headers_end = server_buffer.find(b'\r\n\r\n') + 4
                        headers_data = server_buffer[:headers_end]
                        body_data = server_buffer[headers_end:]

                        # Parse status line and headers
                        lines = headers_data.split(b'\r\n')

                        # Parse headers
                        headers = {}
                        for i in range(1, len(lines)):
                            if not lines[i]:
                                continue
                            try:
                                key, value = lines[i].decode('utf-8').split(':', 1)
                                headers[key.strip()] = value.strip()
                            except ValueError:
                                continue

                        # Check if this is a response to a GenerateContent request
                        if should_sniff:
                            try:
                                resp = await self.interceptor.process_response(
                                    body_data, host, "", headers
                                )

                                if self.queue is not None:
                                    self.queue.put(json.dumps(resp))
                            except Exception as e:
                                # Best-effort sniffing: never break the relay
                                # because decoding a partial body failed.
                                pass

                    # Always relay the raw bytes unmodified; sniffing above is
                    # read-only.
                    client_writer.write(data)
                    # await client_writer.drain()
                    # Reset the accumulated response once the terminal chunk
                    # marker has been seen.
                    if b"0\r\n\r\n" in server_buffer:
                        server_buffer.clear()
            except Exception as e:
                self.logger.error(f"Error processing server data: {e}")
            finally:
                client_writer.close()
                # await client_writer.wait_closed()

        # Create tasks for both directions
        client_to_server = asyncio.create_task(_process_client_data())
        server_to_client = asyncio.create_task(_process_server_data())

        # Wait for both tasks to complete
        tasks = [client_to_server, server_to_client]
        await asyncio.gather(*tasks)
        # await asyncio.gather(client_to_server, server_to_client)

    async def start(self):
        """
        Start the proxy server and serve until cancelled.
        """
        server = await asyncio.start_server(
            self.handle_client, self.host, self.port
        )

        addr = server.sockets[0].getsockname()
        self.logger.info(f'Serving on {addr}')

        async with server:
            await server.serve_forever()
stream/utils.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from urllib.parse import urlparse
3
+
4
def is_generate_content_endpoint(url):
    """Return True when *url* points at a GenerateContent endpoint."""
    return url.find('GenerateContent') != -1
9
+
10
def parse_proxy_url(proxy_url):
    """
    Split a proxy URL into its components.

    Returns:
        tuple: (scheme, host, port, username, password) — all None when
        *proxy_url* is empty or None.
    """
    if not proxy_url:
        # Nothing to parse; keep the five-tuple shape for callers.
        return None, None, None, None, None

    parts = urlparse(proxy_url)
    return (
        parts.scheme,
        parts.hostname,
        parts.port,
        parts.username,
        parts.password,
    )
29
+
30
def setup_logger(name, log_file=None, level=logging.INFO):
    """
    Set up a logger with console (and optionally file) output.

    Calling this repeatedly with the same *name* no longer stacks duplicate
    handlers (which previously caused each message to be emitted once per
    call): handlers are attached only on first configuration.

    Args:
        name (str): Logger name
        log_file (str, optional): Path to log file
        level (int, optional): Logging level

    Returns:
        logging.Logger: Configured logger
    """
    logger = logging.getLogger(name)
    logger.setLevel(level)

    if logger.handlers:
        # Loggers are process-wide singletons per name; this one was already
        # configured, so skip re-adding handlers to avoid double logging.
        return logger

    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

    # Add console handler
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

    # Add file handler if specified
    if log_file:
        file_handler = logging.FileHandler(log_file)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

    return logger