File size: 5,752 Bytes
7861a83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import json
import logging
import re
import zlib

class HttpInterceptor:
    """Intercept and process HTTP requests/responses.

    Responses are expected to arrive as chunked, zlib/gzip-compressed
    streams whose decompressed payload contains JSON fragments of the form
    ``[[[null, ...]], "model"]``; these fragments are parsed into a dict
    with ``reason``, ``body`` and ``function`` (tool-call) entries.
    """

    def __init__(self, log_dir='logs'):
        # NOTE(review): log_dir is stored but currently unused by
        # setup_logging — kept for interface compatibility.
        self.log_dir = log_dir
        self.logger = logging.getLogger('http_interceptor')
        self.setup_logging()

    @staticmethod
    def setup_logging():
        """Set up logging configuration.

        Uses ``logging.basicConfig`` with a stream handler; this is a
        no-op if the root logger is already configured.
        """
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
            handlers=[
                logging.StreamHandler()
            ]
        )

    @staticmethod
    def should_intercept(host, path):
        """Return True if a request to ``host``/``path`` should be intercepted.

        Currently only paths containing 'GenerateContent' are intercepted;
        extend with further conditions as needed.
        """
        return 'GenerateContent' in path

    async def process_request(self, request_data, host, path):
        """Process the request data before sending to the server.

        Currently a pass-through that only logs intercepted requests;
        always returns ``request_data`` unchanged.
        """
        if self.should_intercept(host, path):
            # Lazy %-formatting: message is only built if the record is emitted.
            self.logger.info("Intercepted request to %s%s", host, path)
        return request_data

    async def process_response(self, response_data, host, path, headers):
        """Process the response data before sending to the client.

        ``response_data`` is the raw chunked HTTP body; ``host``, ``path``
        and ``headers`` are request context (currently unused here).

        Returns a dict with keys ``reason``, ``body``, ``function`` and
        ``done`` (True once the terminating zero-length chunk was seen).
        Exceptions from decoding/parsing propagate to the caller.
        """
        # Handle chunked transfer encoding, then the compression wrapper.
        decoded_data, is_done = self._decode_chunked(bytes(response_data))
        decoded_data = self._decompress_zlib_stream(decoded_data)
        result = self.parse_response(decoded_data)
        result["done"] = is_done
        return result

    def parse_response(self, response_data):
        """Extract body text, reasoning text and tool calls from the payload.

        Each regex match is a JSON fragment whose first inner element
        (``payload``) is dispatched on by length:
          * length 2                         -> body text at payload[1]
          * length 11, null at [1], list [10] -> a function/tool call
          * length > 2 otherwise             -> reasoning text at payload[1]
        """
        pattern = rb'\[\[\[null,.*?]],"model"]'

        resp = {
            "reason": "",
            "body": "",
            "function": [],
        }

        for match_obj in re.finditer(pattern, response_data):
            json_data = json.loads(match_obj.group(0))

            try:
                payload = json_data[0][0]
            except (IndexError, KeyError, TypeError):
                # Fragment lacks the expected nesting — skip it.
                continue

            if len(payload) == 2:  # body text
                resp["body"] += payload[1]
            elif (len(payload) == 11 and payload[1] is None
                    and isinstance(payload[10], list)):  # function call
                array_tool_calls = payload[10]
                func_name = array_tool_calls[0]
                params = self.parse_toolcall_params(array_tool_calls[1])
                resp["function"].append({"name": func_name, "params": params})
            elif len(payload) > 2:  # reasoning text
                resp["reason"] += payload[1]

        return resp

    def parse_toolcall_params(self, args):
        """Decode the positional tool-call parameter encoding into a dict.

        ``args[0]`` is a list of ``[name, value]`` pairs where ``value`` is
        a list whose length encodes the type:
          1 -> null, 2 -> number, 3 -> string,
          4 -> boolean (encoded as 0/1), 5 -> nested object (recursive).
        Non-list values are skipped.
        """
        func_params = {}
        for param in args[0]:
            param_name = param[0]
            param_value = param[1]

            if not isinstance(param_value, list):
                continue

            size = len(param_value)
            if size == 1:  # null
                func_params[param_name] = None
            elif size == 2:  # number / integer
                func_params[param_name] = param_value[1]
            elif size == 3:  # string
                func_params[param_name] = param_value[2]
            elif size == 4:  # boolean encoded as 0/1
                func_params[param_name] = param_value[3] == 1
            elif size == 5:  # nested object
                func_params[param_name] = self.parse_toolcall_params(param_value[4])
        return func_params

    @staticmethod
    def _decompress_zlib_stream(compressed_stream):
        """Decompress a zlib- or gzip-wrapped stream.

        ``wbits=MAX_WBITS | 32`` auto-detects the zlib or gzip header.
        Returns whatever bytes are decodable from the given input.
        """
        decompressor = zlib.decompressobj(wbits=zlib.MAX_WBITS | 32)
        return decompressor.decompress(compressed_stream)

    @staticmethod
    def _decode_chunked(response_body: bytes) -> tuple[bytes, bool]:
        """Decode an HTTP/1.1 chunked transfer-encoded body.

        Returns ``(data, done)`` where ``done`` is True once the terminating
        zero-length chunk (``0\\r\\n\\r\\n``) has been seen. Incomplete
        trailing chunks are excluded from the result (with ``done=False``)
        so a caller can retry once more bytes arrive.
        """
        chunked_data = bytearray()
        while True:
            length_crlf_idx = response_body.find(b"\r\n")
            if length_crlf_idx == -1:
                break

            hex_length = response_body[:length_crlf_idx]
            try:
                length = int(hex_length, 16)
            except ValueError as e:
                logging.error(f"Parsing chunked length failed: {e}")
                break

            if length == 0:
                # Terminating chunk: require the full "0\r\n\r\n" marker.
                if response_body.find(b"0\r\n\r\n") != -1:
                    return bytes(chunked_data), True

            data_start = length_crlf_idx + 2
            data_end = data_start + length

            # Bug fix: consume the chunk only when its data is fully present.
            # The old check (length + 2 > len) ignored the header length and
            # could append a truncated chunk to the output.
            if data_end > len(response_body):
                break

            chunked_data.extend(response_body[data_start:data_end])
            if data_end + 2 > len(response_body):
                # Trailing CRLF not yet received.
                break

            response_body = response_body[data_end + 2:]
        return bytes(chunked_data), False