Spaces:
Paused
Paused
| import asyncio | |
| from typing import Optional | |
| import json | |
| import logging | |
| import ssl | |
| import multiprocessing | |
| from pathlib import Path | |
| from urllib.parse import urlparse | |
| from stream.cert_manager import CertificateManager | |
| from stream.proxy_connector import ProxyConnector | |
| from stream.interceptors import HttpInterceptor | |
| class ProxyServer: | |
| """ | |
| Asynchronous HTTPS proxy server with SSL inspection capabilities | |
| """ | |
| def __init__(self, host='0.0.0.0', port=3120, intercept_domains=None, upstream_proxy=None, queue: Optional[multiprocessing.Queue]=None): | |
| self.host = host | |
| self.port = port | |
| self.intercept_domains = intercept_domains or [] | |
| self.upstream_proxy = upstream_proxy | |
| self.queue = queue | |
| # Initialize components | |
| self.cert_manager = CertificateManager() | |
| self.proxy_connector = ProxyConnector(upstream_proxy) | |
| # Create logs directory | |
| log_dir = Path('logs') | |
| log_dir.mkdir(exist_ok=True) | |
| self.interceptor = HttpInterceptor(str(log_dir)) | |
| # Set up logging | |
| self.logger = logging.getLogger('proxy_server') | |
| def should_intercept(self, host): | |
| """ | |
| Determine if the connection to the host should be intercepted | |
| """ | |
| if host in self.intercept_domains: | |
| return True | |
| # Wildcard match (e.g. *.example.com) | |
| for d in self.intercept_domains: | |
| if d.startswith("*."): | |
| suffix = d[1:] # Remove * | |
| if host.endswith(suffix): | |
| return True | |
| return False | |
| async def handle_client(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter): | |
| """ | |
| Handle a client connection | |
| """ | |
| try: | |
| # Read the initial request line | |
| request_line = await reader.readline() | |
| request_line = request_line.decode('utf-8').strip() | |
| if not request_line: | |
| writer.close() | |
| return | |
| # Parse the request line | |
| method, target, version = request_line.split(' ') | |
| if method == 'CONNECT': | |
| # Handle HTTPS connection | |
| await self._handle_connect(reader, writer, target) | |
| except Exception as e: | |
| self.logger.error(f"Error handling client: {e}") | |
| finally: | |
| writer.close() | |
| async def _handle_connect(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, target: str): | |
| """ | |
| Handle CONNECT method (for HTTPS connections) | |
| """ | |
| host, port = target.split(':') | |
| port = int(port) | |
| # Determine if we should intercept this connection | |
| intercept = self.should_intercept(host) | |
| if intercept: | |
| self.logger.info(f"Sniff HTTPS requests to : {target}") | |
| self.cert_manager.get_domain_cert(host) | |
| # Send 200 Connection Established to the client | |
| writer.write(b'HTTP/1.1 200 Connection Established\r\n\r\n') | |
| await writer.drain() | |
| # Drop the proxy connect header | |
| await reader.read(8192) | |
| loop = asyncio.get_running_loop() | |
| transport = writer.transport # This is the original client transport | |
| if transport is None: # 新增检查块开始 | |
| self.logger.warning(f"Client writer transport is None for {host}:{port} before TLS upgrade. Closing.") | |
| # writer is likely already closed or in a bad state. | |
| # We can't proceed with start_tls if transport is None. | |
| return # Exit _handle_connect for this client # 新增检查块结束 | |
| ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) | |
| ssl_context.load_cert_chain( | |
| certfile=self.cert_manager.cert_dir / f"{host}.crt", | |
| keyfile=self.cert_manager.cert_dir / f"{host}.key" | |
| ) | |
| # 1. 正确获取与原始 transport 关联的协议实例 | |
| # 'transport' here is 'writer.transport' from line 101, now checked not to be None | |
| client_protocol = transport.get_protocol() | |
| # 2. 将获取到的 client_protocol 实例传递给 start_tls | |
| # loop.start_tls 会修改这个 client_protocol 实例,使其与 new_transport 关联 | |
| new_transport = await loop.start_tls( | |
| transport=transport, | |
| protocol=client_protocol, # 关键:传递获取到的协议实例 | |
| sslcontext=ssl_context, | |
| server_side=True | |
| ) | |
| # 3. 增加对 new_transport 的 None 检查 (主要为了类型安全和 Pylance) | |
| if new_transport is None: | |
| self.logger.error(f"loop.start_tls returned None for {host}:{port}, which is unexpected. Closing connection.") | |
| # Ensure client writer is closed if it was opened or transport was valid before | |
| writer.close() | |
| # await writer.wait_closed() # Consider if waiting is necessary here | |
| return | |
| client_reader = reader | |
| # 4. 创建 StreamWriter 时,使用被 start_tls 正确更新过的 client_protocol | |
| client_writer = asyncio.StreamWriter( | |
| transport=new_transport, # 使用新的 TLS transport | |
| protocol=client_protocol, # 关键:使用被 start_tls 更新过的协议实例 | |
| reader=client_reader, | |
| loop=loop | |
| ) | |
| # Connect to the target server | |
| try: | |
| server_reader, server_writer = await self.proxy_connector.create_connection( | |
| host, port, ssl=ssl.create_default_context() | |
| ) | |
| # Start bidirectional forwarding with interception | |
| await self._forward_data_with_interception( | |
| client_reader, client_writer, | |
| server_reader, server_writer, | |
| host | |
| ) | |
| except Exception as e: | |
| # self.logger.error(f"Error connecting to server {host}:{port}: {e}") | |
| client_writer.close() | |
| # await client_writer.wait_closed() | |
| else: | |
| # No interception, just forward the connection | |
| writer.write(b'HTTP/1.1 200 Connection Established\r\n\r\n') | |
| await writer.drain() | |
| # Drop the proxy connect header | |
| await reader.read(8192) | |
| try: | |
| # Connect to the target server | |
| server_reader, server_writer = await self.proxy_connector.create_connection( | |
| host, port, ssl=None | |
| ) | |
| # Start bidirectional forwarding without interception | |
| await self._forward_data( | |
| reader, writer, | |
| server_reader, server_writer | |
| ) | |
| except Exception as e: | |
| # self.logger.error(f"Error connecting to server {host}:{port}: {e}") | |
| writer.close() | |
| # await writer.wait_closed() | |
| async def _forward_data(self, client_reader, client_writer, server_reader, server_writer): | |
| """ | |
| Forward data between client and server without interception | |
| """ | |
| async def _forward(reader, writer): | |
| try: | |
| while True: | |
| data = await reader.read(8192) | |
| if not data: | |
| break | |
| writer.write(data) | |
| await writer.drain() | |
| except Exception as e: | |
| self.logger.error(f"Error forwarding data: {e}") | |
| finally: | |
| writer.close() | |
| # Create tasks for both directions | |
| client_to_server = asyncio.create_task(_forward(client_reader, server_writer)) | |
| server_to_client = asyncio.create_task(_forward(server_reader, client_writer)) | |
| # Wait for both tasks to complete | |
| tasks = [client_to_server, server_to_client] | |
| await asyncio.gather(*tasks) | |
| # await asyncio.gather(client_to_server, server_to_client) | |
| async def _forward_data_with_interception(self, client_reader, client_writer, | |
| server_reader, server_writer, host): | |
| """ | |
| Forward data between client and server with interception | |
| """ | |
| # Buffer to store HTTP request/response data | |
| client_buffer = bytearray() | |
| server_buffer = bytearray() | |
| should_sniff = False | |
| # Parse HTTP headers from client | |
| async def _process_client_data(): | |
| nonlocal client_buffer, should_sniff | |
| try: | |
| while True: | |
| data = await client_reader.read(8192) | |
| if not data: | |
| break | |
| client_buffer.extend(data) | |
| # Try to parse HTTP request | |
| if b'\r\n\r\n' in client_buffer: | |
| # Split headers and body | |
| headers_end = client_buffer.find(b'\r\n\r\n') + 4 | |
| headers_data = client_buffer[:headers_end] | |
| body_data = client_buffer[headers_end:] | |
| # Parse request line and headers | |
| lines = headers_data.split(b'\r\n') | |
| request_line = lines[0].decode('utf-8') | |
| try: | |
| method, path, _ = request_line.split(' ') | |
| except ValueError: | |
| # Not a valid HTTP request, just forward | |
| server_writer.write(client_buffer) | |
| await server_writer.drain() | |
| client_buffer.clear() | |
| continue | |
| # Check if we should intercept this request | |
| if 'GenerateContent' in path: | |
| should_sniff = True | |
| # Process the request body | |
| processed_body = await self.interceptor.process_request( | |
| body_data, host, path | |
| ) | |
| # Send the processed request | |
| server_writer.write(headers_data) | |
| server_writer.write(processed_body) | |
| else: | |
| should_sniff = False | |
| # Forward the request as is | |
| server_writer.write(client_buffer) | |
| await server_writer.drain() | |
| client_buffer.clear() | |
| else: | |
| # Not enough data to parse headers, forward as is | |
| server_writer.write(data) | |
| await server_writer.drain() | |
| client_buffer.clear() | |
| except Exception as e: | |
| self.logger.error(f"Error processing client data: {e}") | |
| finally: | |
| server_writer.close() | |
| # await server_writer.wait_closed() | |
| # Parse HTTP headers from server | |
| async def _process_server_data(): | |
| nonlocal server_buffer, should_sniff | |
| try: | |
| while True: | |
| data = await server_reader.read(8192) | |
| if not data: | |
| break | |
| server_buffer.extend(data) | |
| if b'\r\n\r\n' in server_buffer: | |
| # Split headers and body | |
| headers_end = server_buffer.find(b'\r\n\r\n') + 4 | |
| headers_data = server_buffer[:headers_end] | |
| body_data = server_buffer[headers_end:] | |
| # Parse status line and headers | |
| lines = headers_data.split(b'\r\n') | |
| # Parse headers | |
| headers = {} | |
| for i in range(1, len(lines)): | |
| if not lines[i]: | |
| continue | |
| try: | |
| key, value = lines[i].decode('utf-8').split(':', 1) | |
| headers[key.strip()] = value.strip() | |
| except ValueError: | |
| continue | |
| # Check if this is a response to a GenerateContent request | |
| if should_sniff: | |
| try: | |
| resp = await self.interceptor.process_response( | |
| body_data, host, "", headers | |
| ) | |
| if self.queue is not None: | |
| self.queue.put(json.dumps(resp)) | |
| except Exception as e: | |
| pass | |
| # Not enough data to parse headers, forward as is | |
| client_writer.write(data) | |
| # await client_writer.drain() | |
| if b"0\r\n\r\n" in server_buffer: | |
| server_buffer.clear() | |
| except Exception as e: | |
| self.logger.error(f"Error processing server data: {e}") | |
| finally: | |
| client_writer.close() | |
| # await client_writer.wait_closed() | |
| # Create tasks for both directions | |
| client_to_server = asyncio.create_task(_process_client_data()) | |
| server_to_client = asyncio.create_task(_process_server_data()) | |
| # Wait for both tasks to complete | |
| tasks = [client_to_server, server_to_client] | |
| await asyncio.gather(*tasks) | |
| # await asyncio.gather(client_to_server, server_to_client) | |
| async def start(self): | |
| """ | |
| Start the proxy server | |
| """ | |
| server = await asyncio.start_server( | |
| self.handle_client, self.host, self.port | |
| ) | |
| addr = server.sockets[0].getsockname() | |
| self.logger.info(f'Serving on {addr}') | |
| async with server: | |
| await server.serve_forever() | |