import json
import logging
import re
import zlib


class HttpInterceptor:
    """Intercept and post-process proxied HTTP requests and responses.

    Responses from model-generation endpoints arrive chunked and
    zlib/gzip-compressed; this class de-chunks, decompresses, and extracts
    answer text, reasoning text, and tool calls from the streamed frames.
    """

    # One serialized model-response frame inside the streamed payload.
    # The payload as a whole is not valid JSON, but each match is a valid
    # JSON array of the form [[[null, ...]], "model"].
    # Compiled once (class level) instead of on every parse_response call.
    _RESPONSE_PATTERN = re.compile(rb'\[\[\[null,.*?]],"model"]')

    def __init__(self, log_dir='logs'):
        # Directory intended for traffic dumps; not used by the visible
        # code yet, kept for interface compatibility.
        self.log_dir = log_dir
        self.logger = logging.getLogger('http_interceptor')
        self.setup_logging()

    @staticmethod
    def setup_logging():
        """Set up logging configuration.

        logging.basicConfig is a no-op if the root logger already has
        handlers, so constructing several interceptors is safe.
        """
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
            handlers=[logging.StreamHandler()],
        )

    @staticmethod
    def should_intercept(host, path):
        """Return True if the request to host/path should be intercepted.

        Currently only model-generation endpoints are intercepted.
        """
        return 'GenerateContent' in path

    async def process_request(self, request_data, host, path):
        """Process the request data before sending it to the server.

        Currently a pass-through: intercepted requests are logged but
        forwarded unmodified.  (The original try/except around a bare
        return was dead code and has been removed.)
        """
        if not self.should_intercept(host, path):
            return request_data
        # Lazy %-style args: the message is only formatted if emitted.
        self.logger.info("Intercepted request to %s%s", host, path)
        return request_data

    async def process_response(self, response_data, host, path, headers):
        """Decode a raw response body and parse it into a result dict.

        Pipeline: de-chunk -> decompress (zlib or gzip) -> parse frames.
        Returns a dict with keys 'reason', 'body', 'function', and 'done'
        (True once the terminating zero-length chunk has been seen).
        Decoding errors propagate to the caller.
        """
        decoded_data, is_done = self._decode_chunked(bytes(response_data))
        decoded_data = self._decompress_zlib_stream(decoded_data)
        result = self.parse_response(decoded_data)
        result["done"] = is_done
        return result

    def parse_response(self, response_data):
        """Extract body text, reasoning text, and tool calls from frames.

        Each regex match is parsed as JSON; its element [0][0] is the
        payload, whose length selects the kind:
          len == 2                      -> answer body fragment
          len == 11 with a list at [10] -> tool call [name, args]
          len >  2 otherwise            -> reasoning fragment
        Malformed or unexpectedly-shaped frames are skipped.  (Bug fix:
        json.loads used to sit outside the try, so one bad frame crashed
        the whole parse.)
        """
        resp = {"reason": "", "body": "", "function": []}
        for match_obj in self._RESPONSE_PATTERN.finditer(response_data):
            try:
                json_data = json.loads(match_obj.group(0))
                payload = json_data[0][0]
            except (ValueError, IndexError, TypeError, KeyError):
                # Not valid JSON, or not shaped as expected -- skip frame.
                continue
            if len(payload) == 2:
                resp["body"] += payload[1]
            elif (len(payload) == 11 and payload[1] is None
                  and isinstance(payload[10], list)):
                # payload[10] == [function_name, encoded_args]
                tool_call = payload[10]
                resp["function"].append({
                    "name": tool_call[0],
                    "params": self.parse_toolcall_params(tool_call[1]),
                })
            elif len(payload) > 2:
                resp["reason"] += payload[1]
        return resp

    def parse_toolcall_params(self, args):
        """Decode the nested tool-call argument encoding into a dict.

        args[0] is a list of [name, value] pairs; the length of each
        value list selects the type:
          1 -> null, 2 -> number at [1], 3 -> string at [2],
          4 -> boolean (0/1) at [3], 5 -> nested object at [4] (recurse).
        Values of any other shape are ignored, matching the original
        behavior.  Indexing errors propagate to the caller.
        """
        func_params = {}
        for param in args[0]:
            name, value = param[0], param[1]
            if not isinstance(value, list):
                continue
            n = len(value)
            if n == 1:
                func_params[name] = None
            elif n == 2:
                func_params[name] = value[1]
            elif n == 3:
                func_params[name] = value[2]
            elif n == 4:
                func_params[name] = value[3] == 1
            elif n == 5:
                func_params[name] = self.parse_toolcall_params(value[4])
        return func_params

    @staticmethod
    def _decompress_zlib_stream(compressed_stream):
        """Decompress a zlib- or gzip-wrapped stream.

        wbits = MAX_WBITS | 32 enables automatic zlib/gzip header
        detection.
        """
        decompressor = zlib.decompressobj(wbits=zlib.MAX_WBITS | 32)
        return decompressor.decompress(compressed_stream)

    @staticmethod
    def _decode_chunked(response_body: bytes) -> tuple[bytes, bool]:
        """Decode an HTTP/1.1 chunked transfer-encoded body.

        Returns (payload, done).  done is True only once the terminating
        zero-length chunk ("0\\r\\n\\r\\n") has been fully received; an
        incomplete trailing chunk is left out of the payload so the caller
        can retry with more data.  (Bug fix: the old completeness check
        compared against `length + 2`, ignoring the chunk-size header
        bytes, and could append a truncated chunk.)
        """
        payload = bytearray()
        while True:
            header_end = response_body.find(b"\r\n")
            if header_end == -1:
                break  # chunk-size line not fully received yet
            try:
                length = int(response_body[:header_end], 16)
            except ValueError as e:
                logging.error("Parsing chunked length failed: %s", e)
                break
            if length == 0:
                # Final chunk: done only once the full terminator arrived.
                if response_body.find(b"0\r\n\r\n") != -1:
                    return bytes(payload), True
                break
            chunk_end = header_end + 2 + length
            if chunk_end + 2 > len(response_body):
                # Chunk data or its trailing CRLF is incomplete; do not
                # emit a partial chunk.
                break
            payload.extend(response_body[header_end + 2:chunk_end])
            response_body = response_body[chunk_end + 2:]
        # bytes() so the return matches the annotated type (the original
        # returned a bytearray).
        return bytes(payload), False