Spaces:
Paused
Paused
File size: 5,752 Bytes
7861a83 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 |
import json
import logging
import re
import zlib
class HttpInterceptor:
    """
    Intercept and process HTTP requests and responses.

    Responses to GenerateContent endpoints are chunked-decoded,
    decompressed, and parsed into a {"reason", "body", "function", "done"}
    dict; all other traffic passes through untouched.
    """

    def __init__(self, log_dir='logs'):
        # log_dir is stored for callers; only stream logging is configured here.
        self.log_dir = log_dir
        self.logger = logging.getLogger('http_interceptor')
        self.setup_logging()

    @staticmethod
    def setup_logging():
        """Set up logging configuration (INFO level, stderr stream handler)."""
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
            handlers=[
                logging.StreamHandler()
            ]
        )

    @staticmethod
    def should_intercept(host, path):
        """
        Return True if the request should be intercepted based on host and path.

        Currently only paths containing 'GenerateContent' are intercepted;
        add further conditions here as needed.
        """
        return 'GenerateContent' in path

    async def process_request(self, request_data, host, path):
        """
        Process the request data before sending to the server.

        Intercepted requests are logged; the payload itself is currently
        passed through unmodified in every case.
        """
        if not self.should_intercept(host, path):
            return request_data
        # Log the request
        self.logger.info(f"Intercepted request to {host}{path}")
        # NOTE: the original wrapped this return in a try/except for
        # JSONDecodeError/UnicodeDecodeError, but nothing in the body could
        # raise them — plain pass-through is equivalent.
        return request_data

    async def process_response(self, response_data, host, path, headers):
        """
        Process the response data before sending to the client.

        Undoes HTTP chunked transfer encoding, then zlib/gzip content
        encoding, and parses the result. Exceptions from the helpers
        propagate to the caller unchanged.
        """
        # Handle chunked encoding first ...
        decoded_data, is_done = self._decode_chunked(bytes(response_data))
        # ... then the compressed content encoding.
        decoded_data = self._decompress_zlib_stream(decoded_data)
        result = self.parse_response(decoded_data)
        result["done"] = is_done
        return result

    def parse_response(self, response_data):
        """
        Extract reason/body/function-call fragments from the raw response.

        response_data is scanned for JSON arrays shaped like
        [[[null, ...]], "model"]; each match's first payload element is
        classified by length:
          * 2 elements                        -> body text (payload[1])
          * 11 elements with a list at [10]   -> tool/function call
          * >2 elements otherwise             -> reasoning text (payload[1])
        Malformed or unexpectedly shaped matches are skipped.
        """
        pattern = rb'\[\[\[null,.*?]],"model"]'
        resp = {
            "reason": "",
            "body": "",
            "function": [],
        }
        for match_obj in re.finditer(pattern, response_data):
            try:
                json_data = json.loads(match_obj.group(0))
                payload = json_data[0][0]
            except (json.JSONDecodeError, IndexError, TypeError, KeyError):
                # Regex match was not valid JSON / not the expected shape.
                continue
            if len(payload) == 2:  # body
                resp["body"] = resp["body"] + payload[1]
            elif len(payload) == 11 and payload[1] is None and isinstance(payload[10], list):  # function
                array_tool_calls = payload[10]
                func_name = array_tool_calls[0]
                params = self.parse_toolcall_params(array_tool_calls[1])
                resp["function"].append({"name": func_name, "params": params})
            elif len(payload) > 2:  # reason
                resp["reason"] = resp["reason"] + payload[1]
        return resp

    def parse_toolcall_params(self, args):
        """
        Recursively decode the positional tool-call parameter encoding.

        args is [[[name, value], ...]]; the length of each value list
        encodes its type:
          1 -> null, 2 -> number, 3 -> string, 4 -> boolean (1 == True),
          5 -> nested object (decoded recursively).
        Malformed input raises IndexError/TypeError to the caller.
        """
        params = args[0]
        func_params = {}
        for param in params:
            param_name = param[0]
            param_value = param[1]
            if isinstance(param_value, list):
                if len(param_value) == 1:  # null
                    func_params[param_name] = None
                elif len(param_value) == 2:  # number and integer
                    func_params[param_name] = param_value[1]
                elif len(param_value) == 3:  # string
                    func_params[param_name] = param_value[2]
                elif len(param_value) == 4:  # boolean
                    func_params[param_name] = param_value[3] == 1
                elif len(param_value) == 5:  # object
                    func_params[param_name] = self.parse_toolcall_params(param_value[4])
        return func_params

    @staticmethod
    def _decompress_zlib_stream(compressed_stream):
        """Decompress a zlib- or gzip-encoded stream (header auto-detected)."""
        # MAX_WBITS | 32 tells zlib to sniff zlib vs. gzip headers itself.
        decompressor = zlib.decompressobj(wbits=zlib.MAX_WBITS | 32)
        return decompressor.decompress(compressed_stream)

    @staticmethod
    def _decode_chunked(response_body: bytes) -> tuple[bytes, bool]:
        """
        Decode an HTTP chunked-transfer body.

        Returns (data, done): the concatenated chunk payloads and whether
        the terminal 0-length chunk (with its final CRLF) was seen.
        Only complete chunks are consumed — a truncated trailing chunk is
        left out rather than partially emitted.
        """
        chunked_data = bytearray()
        while True:
            length_crlf_idx = response_body.find(b"\r\n")
            if length_crlf_idx == -1:
                break  # no complete chunk-size line yet
            hex_length = response_body[:length_crlf_idx]
            try:
                length = int(hex_length, 16)
            except ValueError as e:
                logging.error(f"Parsing chunked length failed: {e}")
                break
            if length == 0:
                # Terminal chunk: done only once the trailing CRLF pair arrived.
                if response_body.find(b"0\r\n\r\n") != -1:
                    return chunked_data, True
                break
            chunk_end = length_crlf_idx + 2 + length
            if chunk_end + 2 > len(response_body):
                # Chunk payload (or its trailing CRLF) is incomplete; stop
                # without emitting partial payload bytes.
                break
            chunked_data.extend(response_body[length_crlf_idx + 2:chunk_end])
            response_body = response_body[chunk_end + 2:]
        return chunked_data, False
|