# AIstudioProxyAPI / stream /proxy_server.py
# hins111's picture
# Upload 7 files
# 7861a83 verified
import asyncio
from typing import Optional
import json
import logging
import ssl
import multiprocessing
from pathlib import Path
from urllib.parse import urlparse
from stream.cert_manager import CertificateManager
from stream.proxy_connector import ProxyConnector
from stream.interceptors import HttpInterceptor
class ProxyServer:
    """
    Asynchronous HTTPS proxy server with SSL inspection capabilities.

    The server accepts HTTP ``CONNECT`` requests. For hosts matching
    ``intercept_domains`` it terminates TLS itself (using a certificate
    obtained through ``CertificateManager``), opens its own TLS connection
    to the target and passes the decrypted traffic through
    ``HttpInterceptor``. All other hosts are tunnelled byte-for-byte.
    """

    def __init__(self, host='0.0.0.0', port=3120, intercept_domains=None, upstream_proxy=None, queue: Optional[multiprocessing.Queue]=None):
        """
        Create the proxy and its helper components.

        Args:
            host: Interface to listen on.
            port: TCP port to listen on.
            intercept_domains: Host names to intercept. Entries of the form
                ``*.example.com`` match any host ending in ``.example.com``.
            upstream_proxy: Passed unchanged to ``ProxyConnector``.
            queue: Optional queue that receives the JSON-encoded result of
                ``HttpInterceptor.process_response`` for sniffed responses.
        """
        self.host = host
        self.port = port
        self.intercept_domains = intercept_domains or []
        self.upstream_proxy = upstream_proxy
        self.queue = queue

        # Initialize components
        self.cert_manager = CertificateManager()
        self.proxy_connector = ProxyConnector(upstream_proxy)

        # Create logs directory (relative to the current working directory)
        log_dir = Path('logs')
        log_dir.mkdir(exist_ok=True)
        self.interceptor = HttpInterceptor(str(log_dir))

        # Set up logging
        self.logger = logging.getLogger('proxy_server')

    def should_intercept(self, host):
        """
        Determine if the connection to the host should be intercepted.

        Returns True when ``host`` is listed verbatim in
        ``intercept_domains`` or matches a ``*.suffix`` wildcard entry.
        A wildcard does not match the bare suffix itself
        (``*.example.com`` does not match ``example.com``).
        """
        if host in self.intercept_domains:
            return True
        # Wildcard match: "*.example.com" -> suffix ".example.com"
        return any(
            pattern.startswith("*.") and host.endswith(pattern[1:])
            for pattern in self.intercept_domains
        )

    @staticmethod
    def _parse_connect_target(target):
        """
        Split a CONNECT authority (``host:port``) into ``(host, port)``.

        The port is taken from the text after the last colon so that
        bracketed IPv6 literals such as ``[::1]:443`` are handled; the
        brackets are stripped from the returned host.

        Raises:
            ValueError: if the target has no host, no port, or the port is
                not an integer.
        """
        host, sep, port = target.rpartition(':')
        if not sep or not host:
            raise ValueError(f"Invalid CONNECT target: {target!r}")
        if host.startswith('[') and host.endswith(']'):
            host = host[1:-1]
        return host, int(port)

    @staticmethod
    async def _drain_connect_headers(reader):
        """
        Consume the header lines that follow the CONNECT request line.

        Reads up to and including the blank line that ends the request
        (or EOF), so that the next bytes on the stream belong to the tunnel.
        """
        while True:
            line = await reader.readline()
            if line in (b'\r\n', b'\n', b''):
                return

    async def handle_client(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
        """
        Handle a client connection.

        Only the ``CONNECT`` method is served; any other request is closed
        without a response. Errors are logged and never propagate.
        """
        try:
            # Read the initial request line
            raw_line = await reader.readline()
            request_line = raw_line.decode('utf-8').strip()
            if not request_line:
                # Nothing was sent; the ``finally`` clause closes the writer.
                return

            # Parse the request line (a malformed line raises ValueError,
            # which is logged below)
            method, target, _version = request_line.split(' ')
            if method == 'CONNECT':
                # Handle HTTPS connection
                await self._handle_connect(reader, writer, target)
        except Exception as e:
            self.logger.error(f"Error handling client: {e}")
        finally:
            writer.close()

    async def _handle_connect(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, target: str):
        """
        Handle CONNECT method (for HTTPS connections).

        Parses the target, discards the remaining request headers and then
        either intercepts the tunnel or forwards it untouched.

        Raises:
            ValueError: if ``target`` is not a valid ``host:port`` authority.
        """
        host, port = self._parse_connect_target(target)

        # Drop the rest of the CONNECT request before replying, so that no
        # header bytes are left in the stream once the tunnel starts.
        await self._drain_connect_headers(reader)

        if self.should_intercept(host):
            await self._intercept_tunnel(reader, writer, host, port, target)
        else:
            await self._plain_tunnel(reader, writer, host, port)

    async def _intercept_tunnel(self, reader, writer, host, port, target):
        """Terminate the client's TLS locally and relay with interception."""
        self.logger.info(f"Sniff HTTPS requests to : {target}")
        self.cert_manager.get_domain_cert(host)

        # Send 200 Connection Established to the client
        writer.write(b'HTTP/1.1 200 Connection Established\r\n\r\n')
        await writer.drain()

        loop = asyncio.get_running_loop()
        transport = writer.transport  # This is the original client transport
        if transport is None:
            self.logger.warning(f"Client writer transport is None for {host}:{port} before TLS upgrade. Closing.")
            # The writer is likely already closed or in a bad state;
            # start_tls cannot proceed without a transport.
            return

        ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
        ssl_context.load_cert_chain(
            certfile=self.cert_manager.cert_dir / f"{host}.crt",
            keyfile=self.cert_manager.cert_dir / f"{host}.key"
        )

        # Reuse the protocol instance bound to the original transport so
        # that the existing ``reader`` keeps receiving the (now decrypted)
        # data after the upgrade.
        client_protocol = transport.get_protocol()
        new_transport = await loop.start_tls(
            transport=transport,
            protocol=client_protocol,
            sslcontext=ssl_context,
            server_side=True
        )
        # Defensive check (mainly for type checkers)
        if new_transport is None:
            self.logger.error(f"loop.start_tls returned None for {host}:{port}, which is unexpected. Closing connection.")
            writer.close()
            return

        client_reader = reader
        # Build a writer on top of the new TLS transport, sharing the
        # protocol instance that start_tls was given.
        client_writer = asyncio.StreamWriter(
            transport=new_transport,
            protocol=client_protocol,
            reader=client_reader,
            loop=loop
        )

        try:
            # Connect to the target server
            server_reader, server_writer = await self.proxy_connector.create_connection(
                host, port, ssl=ssl.create_default_context()
            )
            # Start bidirectional forwarding with interception
            await self._forward_data_with_interception(
                client_reader, client_writer,
                server_reader, server_writer,
                host
            )
        except Exception as e:
            # Kept at debug level: upstream failures are routine for a proxy.
            self.logger.debug(f"Error connecting to server {host}:{port}: {e}")
            client_writer.close()

    async def _plain_tunnel(self, reader, writer, host, port):
        """Relay the tunnel byte-for-byte without inspecting it."""
        writer.write(b'HTTP/1.1 200 Connection Established\r\n\r\n')
        await writer.drain()

        try:
            # Connect to the target server
            server_reader, server_writer = await self.proxy_connector.create_connection(
                host, port, ssl=None
            )
            # Start bidirectional forwarding without interception
            await self._forward_data(
                reader, writer,
                server_reader, server_writer
            )
        except Exception as e:
            # Kept at debug level: upstream failures are routine for a proxy.
            self.logger.debug(f"Error connecting to server {host}:{port}: {e}")
            writer.close()

    async def _forward_data(self, client_reader, client_writer, server_reader, server_writer):
        """
        Forward data between client and server without interception.

        Each direction runs as its own task and closes its destination
        writer when its source reaches EOF or an error occurs.
        """
        async def _forward(reader, writer):
            try:
                while True:
                    data = await reader.read(8192)
                    if not data:
                        break
                    writer.write(data)
                    await writer.drain()
            except Exception as e:
                self.logger.error(f"Error forwarding data: {e}")
            finally:
                writer.close()

        # Run both directions concurrently and wait for both to finish
        await asyncio.gather(
            asyncio.create_task(_forward(client_reader, server_writer)),
            asyncio.create_task(_forward(server_reader, client_writer)),
        )

    async def _sniff_response(self, server_buffer, host):
        """
        Hand the response accumulated so far to the interceptor.

        Does nothing until the header terminator has been received. The
        interceptor is called with everything buffered after the headers,
        so it is invoked again with a longer body on every new chunk.
        Failures are logged at debug level and otherwise ignored so that
        sniffing never disturbs the forwarded traffic.
        """
        headers_end = server_buffer.find(b'\r\n\r\n')
        if headers_end == -1:
            return
        headers_end += 4
        header_lines = server_buffer[:headers_end].split(b'\r\n')
        body_data = server_buffer[headers_end:]

        # Parse "Key: value" header lines, skipping the status line
        headers = {}
        for raw in header_lines[1:]:
            if not raw:
                continue
            try:
                key, value = raw.decode('utf-8').split(':', 1)
            except ValueError:
                # Undecodable or colon-less line: ignore it
                continue
            headers[key.strip()] = value.strip()

        try:
            resp = await self.interceptor.process_response(
                body_data, host, "", headers
            )
            if self.queue is not None:
                self.queue.put(json.dumps(resp))
        except Exception as e:
            self.logger.debug(f"Response interception skipped for {host}: {e}")

    async def _forward_data_with_interception(self, client_reader, client_writer,
                                              server_reader, server_writer, host):
        """
        Forward data between client and server with interception.

        Client requests whose path contains ``GenerateContent`` have the
        body bytes received so far passed through
        ``HttpInterceptor.process_request``; while such a request is the
        most recent one, server responses are also handed to
        ``HttpInterceptor.process_response``. Everything else is relayed
        unchanged.
        """
        # Buffers to store HTTP request/response data
        client_buffer = bytearray()
        server_buffer = bytearray()
        # True while the latest parsed request was a GenerateContent call
        should_sniff = False

        async def _process_client_data():
            nonlocal should_sniff
            try:
                while True:
                    data = await client_reader.read(8192)
                    if not data:
                        break
                    client_buffer.extend(data)

                    headers_end = client_buffer.find(b'\r\n\r\n')
                    if headers_end == -1:
                        # Not enough data to parse headers, forward as is
                        server_writer.write(data)
                        await server_writer.drain()
                        client_buffer.clear()
                        continue

                    # Split headers and body
                    headers_end += 4
                    headers_data = client_buffer[:headers_end]
                    body_data = client_buffer[headers_end:]

                    try:
                        # Decoding happens inside the try: UnicodeDecodeError
                        # is a ValueError, so binary data is forwarded
                        # instead of aborting the tunnel.
                        request_line = headers_data.split(b'\r\n')[0].decode('utf-8')
                        _method, path, _ = request_line.split(' ')
                    except ValueError:
                        # Not a valid HTTP request, just forward.
                        # bytes() copy: the buffer is cleared right after.
                        server_writer.write(bytes(client_buffer))
                        await server_writer.drain()
                        client_buffer.clear()
                        continue

                    # Check if we should intercept this request
                    if 'GenerateContent' in path:
                        should_sniff = True
                        # Process the request body
                        processed_body = await self.interceptor.process_request(
                            body_data, host, path
                        )
                        # Send the processed request
                        server_writer.write(headers_data)
                        server_writer.write(processed_body)
                    else:
                        should_sniff = False
                        # Forward the request as is
                        server_writer.write(bytes(client_buffer))

                    await server_writer.drain()
                    client_buffer.clear()
            except Exception as e:
                self.logger.error(f"Error processing client data: {e}")
            finally:
                server_writer.close()

        async def _process_server_data():
            try:
                while True:
                    data = await server_reader.read(8192)
                    if not data:
                        break

                    if should_sniff:
                        # Accumulate the response for the interceptor
                        server_buffer.extend(data)
                        await self._sniff_response(server_buffer, host)
                        # A chunked-encoding terminator ends the response
                        if b"0\r\n\r\n" in server_buffer:
                            server_buffer.clear()
                    else:
                        # Nothing reads the buffer for unsniffed responses;
                        # keep it empty so memory stays bounded and the next
                        # sniffed response starts from a clean buffer.
                        server_buffer.clear()

                    # Always relay the raw chunk to the client
                    client_writer.write(data)
                    # await client_writer.drain()
            except Exception as e:
                self.logger.error(f"Error processing server data: {e}")
            finally:
                client_writer.close()

        # Run both directions concurrently and wait for both to finish
        await asyncio.gather(
            asyncio.create_task(_process_client_data()),
            asyncio.create_task(_process_server_data()),
        )

    async def start(self):
        """
        Start the proxy server and serve until cancelled.
        """
        server = await asyncio.start_server(
            self.handle_client, self.host, self.port
        )
        addr = server.sockets[0].getsockname()
        self.logger.info(f'Serving on {addr}')
        async with server:
            await server.serve_forever()