import asyncio
import json
import logging
import socket
import ssl
import time
from pathlib import Path
from typing import Any, List, Optional

from stream.cert_manager import CertificateManager
from stream.interceptors import HttpInterceptor
from stream.proxy_connector import ProxyConnector


class ProxyServer:
    """
    Asynchronous HTTPS proxy server with SSL inspection capabilities.

    Only ``CONNECT`` requests are served.

    Hosts selected by :meth:`should_intercept` are handled as follows:
    - TLS is terminated locally with a per-host certificate from
      :class:`CertificateManager`.
    - The decrypted HTTP/1.1 traffic is relayed through
      :class:`HttpInterceptor`.

    All other hosts are tunnelled byte-for-byte via :class:`ProxyConnector`.
    """

    # Upper bound on header lines accepted after a CONNECT request line, so a
    # misbehaving client cannot keep a handler reading indefinitely.
    _MAX_CONNECT_HEADER_LINES = 100

    _CONNECTION_ESTABLISHED = b"HTTP/1.1 200 Connection Established\r\n\r\n"

    def __init__(
        self,
        host: str = "0.0.0.0",
        port: int = 3120,
        intercept_domains: Optional[List[str]] = None,
        upstream_proxy: Optional[str] = None,
        queue: Optional[Any] = None,
    ):
        """
        Args:
            host: Interface to listen on.
            port: TCP port to listen on.
            intercept_domains: Hosts to TLS-intercept. Entries of the form
                ``"*.example.com"`` match any subdomain of ``example.com``.
            upstream_proxy: Optional upstream proxy, passed to
                :class:`ProxyConnector`.
            queue: Optional object with a ``put`` method. It receives a
                ``"READY"`` string and JSON-encoded interception payloads.
        """
        self.host = host
        self.port = port
        self.intercept_domains = intercept_domains or []
        # Always tunnelled, even if they also match intercept_domains.
        self.passthrough_domains = [
            "feedback-pa.clients6.google.com",
            "play.google.com",
            "apis.google.com",
            "accounts.google.com",
        ]
        self.upstream_proxy = upstream_proxy
        self.queue = queue

        # Initialize components
        self.cert_manager = CertificateManager()
        self.proxy_connector = ProxyConnector(upstream_proxy)

        # Create logs directory (relative to the current working directory)
        log_dir = Path("logs")
        log_dir.mkdir(exist_ok=True)
        self.interceptor = HttpInterceptor(str(log_dir))

        # Set up logging
        self.logger = logging.getLogger("proxy_server")

        # Running client-handler tasks. Each handle_client registers itself
        # here and is removed when it completes.
        self.background_tasks = set()

    def _safe_close(self, writer):
        """
        Safely close a writer with robust error handling for SSL shutdown timeouts.

        The underlying socket, when available, is shut down in both
        directions before ``writer.close()``. Every error is suppressed:
        this is best-effort cleanup and may run on already-closed writers.
        """
        if not writer:
            return
        try:
            sock = writer.get_extra_info("socket")
            if sock:
                try:
                    sock.shutdown(socket.SHUT_RDWR)
                except (OSError, ssl.SSLError):
                    pass
        except Exception:
            pass
        finally:
            try:
                writer.close()
            except Exception:
                pass

    def should_intercept(self, host):
        """
        Determine if the connection to the host should be intercepted.

        Passthrough domains always win. Otherwise *host* must equal an entry
        of ``intercept_domains``, or end with the suffix of a ``"*.domain"``
        entry. A wildcard matches subdomains only, not the bare domain.
        """
        if host in self.passthrough_domains:
            return False

        if host in self.intercept_domains:
            return True

        for d in self.intercept_domains:
            if d.startswith("*."):
                suffix = d[1:]  # "*.example.com" -> ".example.com"
                if host.endswith(suffix):
                    return True

        return False

    async def handle_client(
        self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
    ):
        """
        Handle a client connection.

        Reads the request line and dispatches ``CONNECT`` requests to
        :meth:`_handle_connect`. Any other request is closed without a
        response. The client writer is always closed when this returns.
        """
        current_task = asyncio.current_task()
        if current_task:
            self.background_tasks.add(current_task)
            current_task.add_done_callback(self.background_tasks.discard)

        try:
            raw_line = await reader.readline()
            # Decode leniently: a request line that is not valid UTF-8 cannot
            # be a CONNECT we serve. It is then rejected quietly below
            # instead of being logged as an internal error.
            request_line = raw_line.decode("utf-8", errors="replace").strip()
            if not request_line:
                return

            parts = request_line.split(" ")
            if len(parts) < 2:
                return

            method, target = parts[0], parts[1]
            if method == "CONNECT":
                await self._handle_connect(reader, writer, target)
        except asyncio.CancelledError:
            raise
        except Exception as e:
            self.logger.error(f"Error handling client: {e}", exc_info=True)
        finally:
            self._safe_close(writer)

    @staticmethod
    def _parse_connect_target(target: str, default_port: int = 443):
        """
        Split a CONNECT authority (``host:port``) into ``(host, port)``.

        Bracketed IPv6 literals (``[::1]:443``) are unwrapped. A missing port
        defaults to *default_port*.

        Raises:
            ValueError: if the port is not an integer in 0-65535.
        """
        if target.startswith("["):
            host, _, rest = target[1:].partition("]")
            port_str = rest[1:] if rest.startswith(":") else ""
        else:
            host, sep, port_str = target.rpartition(":")
            if not sep:
                host, port_str = target, ""
        port = int(port_str) if port_str else default_port
        if not 0 <= port <= 65535:
            raise ValueError(f"Invalid port in CONNECT target: {target!r}")
        return host, port

    async def _drain_connect_headers(self, reader: asyncio.StreamReader) -> None:
        """
        Consume the header lines that follow the CONNECT request line.

        Reading up to the terminating blank line, before replying, has two
        guarantees. No header bytes are left in the stream to be mistaken
        for tunnelled data. No tunnelled bytes the client sent early are
        swallowed.

        Raises:
            ValueError: if more than ``_MAX_CONNECT_HEADER_LINES`` lines arrive.
        """
        for _ in range(self._MAX_CONNECT_HEADER_LINES):
            line = await reader.readline()
            if line in (b"\r\n", b"\n", b""):
                return
        raise ValueError("Too many header lines in CONNECT request")

    async def _handle_connect(
        self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, target: str
    ):
        """
        Handle CONNECT method (for HTTPS connections).

        The connection is either TLS-intercepted or tunnelled unchanged,
        depending on :meth:`should_intercept`.
        """
        host, port = self._parse_connect_target(target)
        await self._drain_connect_headers(reader)

        if self.should_intercept(host):
            await self._handle_intercepted_connect(reader, writer, host, port)
        else:
            await self._handle_tunnel_connect(reader, writer, host, port)

    async def _handle_intercepted_connect(
        self,
        reader: asyncio.StreamReader,
        writer: asyncio.StreamWriter,
        host: str,
        port: int,
    ) -> None:
        """Terminate the client's TLS locally and relay traffic with interception."""
        # Ensure a certificate for this host exists before the handshake.
        self.cert_manager.get_domain_cert(host)

        writer.write(self._CONNECTION_ESTABLISHED)
        await writer.drain()

        loop = asyncio.get_running_loop()
        transport = writer.transport
        if transport is None:
            self.logger.warning(
                f"Client writer transport is None for {host}:{port} before TLS upgrade. Closing."
            )
            return

        ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
        ssl_context.load_cert_chain(
            certfile=self.cert_manager.cert_dir / f"{host}.crt",
            keyfile=self.cert_manager.cert_dir / f"{host}.key",
        )

        # Upgrade the existing client connection in place. The original
        # stream protocol keeps feeding (now decrypted) data into `reader`.
        client_protocol = transport.get_protocol()
        new_transport = await loop.start_tls(
            transport=transport,
            protocol=client_protocol,
            sslcontext=ssl_context,
            server_side=True,
        )
        if new_transport is None:
            self.logger.error(
                f"loop.start_tls returned None for {host}:{port}, which is unexpected. Closing connection."
            )
            writer.close()
            return

        client_writer = asyncio.StreamWriter(
            transport=new_transport,
            protocol=client_protocol,
            reader=reader,
            loop=loop,
        )

        try:
            server_reader, server_writer = await self.proxy_connector.create_connection(
                host, port, ssl=ssl.create_default_context()
            )
            await self._forward_data_with_interception(
                reader, client_writer, server_reader, server_writer, host
            )
        except asyncio.CancelledError:
            raise
        except Exception as e:
            self.logger.error(
                f"Error connecting to server {host}:{port}: {e}", exc_info=True
            )
            client_writer.close()
            try:
                await client_writer.wait_closed()
            except Exception:
                pass

    async def _handle_tunnel_connect(
        self,
        reader: asyncio.StreamReader,
        writer: asyncio.StreamWriter,
        host: str,
        port: int,
    ) -> None:
        """Open a raw tunnel to ``host:port`` and relay bytes unchanged."""
        writer.write(self._CONNECTION_ESTABLISHED)
        await writer.drain()

        try:
            server_reader, server_writer = await self.proxy_connector.create_connection(
                host, port, ssl=None
            )
            await self._forward_data(reader, writer, server_reader, server_writer)
        except asyncio.CancelledError:
            raise
        except Exception as e:
            self.logger.error(
                f"Error connecting to server {host}:{port}: {e}", exc_info=True
            )
            writer.close()
            try:
                await writer.wait_closed()
            except Exception:
                pass

    @staticmethod
    async def _run_until_first_done(*coros) -> None:
        """
        Run *coros* as tasks, return once any finishes, and cancel the rest.

        If this coroutine is itself cancelled, every task is cancelled and
        awaited before the cancellation propagates. ``gather`` is used to
        reap pending tasks so that an outer cancellation arriving during
        cleanup is not swallowed.
        """
        tasks = [asyncio.create_task(coro) for coro in coros]
        try:
            _done, pending = await asyncio.wait(
                tasks, return_when=asyncio.FIRST_COMPLETED
            )
        except asyncio.CancelledError:
            for task in tasks:
                task.cancel()
            await asyncio.gather(*tasks, return_exceptions=True)
            raise

        for task in pending:
            task.cancel()
        await asyncio.gather(*pending, return_exceptions=True)

    async def _forward_data(
        self,
        client_reader: asyncio.StreamReader,
        client_writer: asyncio.StreamWriter,
        server_reader: asyncio.StreamReader,
        server_writer: asyncio.StreamWriter,
    ) -> None:
        """Relay bytes in both directions until either direction ends."""

        async def _forward(
            reader: asyncio.StreamReader, writer: asyncio.StreamWriter
        ) -> None:
            try:
                while True:
                    data = await reader.read(8192)
                    if not data:
                        break
                    writer.write(data)
                    await writer.drain()
            except ConnectionResetError:
                self.logger.debug("Connection reset by peer.")
            except Exception as e:
                self.logger.error(f"Error forwarding data: {e}", exc_info=True)
            finally:
                self._safe_close(writer)

        await self._run_until_first_done(
            _forward(client_reader, server_writer),
            _forward(server_reader, client_writer),
        )

    @staticmethod
    def _parse_response_head(buffer):
        """
        Parse the status line and headers at the start of *buffer*.

        Assumes *buffer* contains ``b"\\r\\n\\r\\n"``.

        Returns:
            ``(headers_end, status_code, status_message, headers)``.
            ``headers_end`` is the offset just past the blank line. The status
            defaults to ``200 "OK"`` when the status line cannot be parsed.
            Header lines without a colon, or that are not valid UTF-8, are
            skipped.
        """
        headers_end = buffer.find(b"\r\n\r\n") + 4
        lines = bytes(buffer[:headers_end]).split(b"\r\n")

        status_code = 200
        status_message = "OK"
        if lines and lines[0]:
            try:
                parts = lines[0].decode("utf-8").split(" ", 2)
                if len(parts) >= 2:
                    status_code = int(parts[1])
                    status_message = parts[2] if len(parts) > 2 else ""
            except (ValueError, UnicodeDecodeError):
                pass

        headers: dict[str, str] = {}
        for line in lines[1:]:
            if not line:
                continue
            try:
                key, value = line.decode("utf-8").split(":", 1)
            except ValueError:  # also covers UnicodeDecodeError
                continue
            headers[key.strip()] = value.strip()

        return headers_end, status_code, status_message, headers

    async def _forward_data_with_interception(
        self,
        client_reader: asyncio.StreamReader,
        client_writer: asyncio.StreamWriter,
        server_reader: asyncio.StreamReader,
        server_writer: asyncio.StreamWriter,
        host: str,
    ) -> None:
        """
        Relay decrypted HTTP/1.1 traffic for *host*, inspecting GenerateContent calls.

        Request side: a request whose path contains ``GenerateContent`` or
        ``generateContent`` switches sniffing on and resets the interceptor.
        Its body is passed through ``interceptor.process_request``.

        Response side: while sniffing, the response received so far is
        passed through ``interceptor.process_response``. Results are pushed
        to ``self.queue`` as JSON. Upstream errors (status >= 400) are
        reported once as an error payload.

        NOTE(review): only the body bytes that arrived in the same read as the
        request headers are given to ``process_request``. The rest of a large
        body is forwarded unchanged. Pipelined requests are not supported.
        """
        client_buffer = bytearray()
        server_buffer = bytearray()
        should_sniff = False
        request_context = {"request_ts": 0.0}

        async def _process_client_data() -> None:
            nonlocal should_sniff
            try:
                while True:
                    data = await client_reader.read(8192)
                    if not data:
                        break
                    client_buffer.extend(data)

                    if b"\r\n\r\n" not in client_buffer:
                        # No complete header block: pass the bytes through.
                        server_writer.write(data)
                        await server_writer.drain()
                        client_buffer.clear()
                        continue

                    headers_end = client_buffer.find(b"\r\n\r\n") + 4
                    headers_data = bytes(client_buffer[:headers_end])
                    body_data = bytes(client_buffer[headers_end:])
                    first_line = headers_data.split(b"\r\n", 1)[0]
                    try:
                        # UnicodeDecodeError is a ValueError subclass, so
                        # non-text bytes are handled like a malformed line.
                        _method, path, _ = first_line.decode("utf-8").split(" ")
                    except ValueError:
                        # Not a request line (e.g. body bytes containing
                        # CRLFCRLF): forward untouched.
                        server_writer.write(bytes(client_buffer))
                        await server_writer.drain()
                        client_buffer.clear()
                        continue

                    if "GenerateContent" in path or "generateContent" in path:
                        should_sniff = True
                        request_context["request_ts"] = time.time()
                        # Reset interceptor state for new request to prevent
                        # state leakage from previous requests
                        self.interceptor.reset_for_new_request()
                        self.logger.debug(
                            f"[Proxy] Detected GenerateContent request: {path[:60]}..."
                        )
                        processed_body = await self.interceptor.process_request(
                            body_data, host, path
                        )
                        server_writer.write(headers_data)
                        # Never drop the request body: fall back to the
                        # original bytes if the interceptor returned non-bytes.
                        if isinstance(processed_body, bytes):
                            server_writer.write(processed_body)
                        else:
                            server_writer.write(body_data)
                    else:
                        should_sniff = False
                        server_writer.write(bytes(client_buffer))
                    await server_writer.drain()
                    client_buffer.clear()
            except ConnectionResetError:
                self.logger.debug("Connection reset by peer processing client data.")
            except Exception as e:
                if "Broken pipe" in str(e) or "Connection reset" in str(e):
                    self.logger.debug(f"[Proxy] Client disconnected: {e}")
                else:
                    self.logger.error(
                        f"Error processing client data: {e}", exc_info=True
                    )
            finally:
                self._safe_close(server_writer)

        async def _sniff_response() -> None:
            """Inspect the response head/body buffered so far and publish results."""
            nonlocal should_sniff
            headers_end, status_code, status_message, headers = (
                self._parse_response_head(server_buffer)
            )
            try:
                if status_code >= 400:
                    # Report an upstream error once per response. Later
                    # chunks of the same error response would otherwise
                    # enqueue duplicate payloads.
                    should_sniff = False
                    server_buffer.clear()
                    self.logger.error(
                        f"[UPSTREAM ERROR] {status_code} {status_message}"
                    )
                    if self.queue is not None:
                        error_payload = {
                            "error": True,
                            "status": status_code,
                            "message": f"{status_code} {status_message}",
                            "done": True,
                        }
                        self.queue.put(json.dumps(error_payload))
                    return

                # The interceptor receives the cumulative body received so far.
                resp = await self.interceptor.process_response(
                    bytes(server_buffer[headers_end:]), host, "", headers
                )
                if self.queue is not None:
                    payload = {
                        "ts": request_context.get("request_ts", 0),
                        "data": resp,
                    }
                    self.queue.put(json.dumps(payload))
                if resp.get("done", False):
                    self.logger.debug(
                        f"[Proxy] Stream complete: body={len(resp.get('body', ''))}"
                    )
            except asyncio.CancelledError:
                raise
            except Exception as e:
                self.logger.error(
                    f"Error during response interception: {e}",
                    exc_info=True,
                )

        async def _process_server_data() -> None:
            try:
                while True:
                    data = await server_reader.read(8192)
                    if not data:
                        break

                    if should_sniff:
                        server_buffer.extend(data)
                        if b"\r\n\r\n" in server_buffer:
                            await _sniff_response()
                        # Terminating chunk of a chunked body: start fresh.
                        if b"0\r\n\r\n" in server_buffer:
                            server_buffer.clear()
                    else:
                        # Not inspecting: don't accumulate unrelated responses.
                        # They would grow memory and be parsed as the head of
                        # the next sniffed response.
                        server_buffer.clear()

                    client_writer.write(data)
                    # Backpressure: don't buffer the whole upstream response
                    # in memory when the client reads slowly.
                    await client_writer.drain()
            except ConnectionResetError:
                self.logger.debug("Connection reset by peer processing server data.")
            except Exception as e:
                self.logger.error(f"Error processing server data: {e}", exc_info=True)
            finally:
                self._safe_close(client_writer)

        await self._run_until_first_done(
            _process_client_data(),
            _process_server_data(),
        )

    async def start(self) -> None:
        """
        Start the proxy server.

        Binds the listening socket, sends ``"READY"`` on ``self.queue`` (if
        set), then serves until cancelled.
        """
        server = await asyncio.start_server(self.handle_client, self.host, self.port)

        addr = server.sockets[0].getsockname()
        self.logger.debug(f"[Proxy] Serving on: {addr}")

        if self.queue is not None:
            try:
                self.queue.put("READY")
                self.logger.debug("[Proxy] Sent READY signal")
            except Exception as e:
                self.logger.error(f"Failed to send 'READY' signal: {e}", exc_info=True)

        async with server:
            await server.serve_forever()