File size: 3,781 Bytes
a5784e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
# --- browser_utils/initialization/network.py ---
import asyncio
import json
import logging

from playwright.async_api import BrowserContext as AsyncBrowserContext

from config import settings

from .scripts import add_init_scripts_to_context

logger = logging.getLogger("AIStudioProxyServer")


async def setup_network_interception_and_scripts(context: AsyncBrowserContext):
    """Configure a browser context according to the feature toggles in settings.

    Installs the model-list route interceptor when
    ``settings.NETWORK_INTERCEPTION_ENABLED`` is truthy. Then it injects the init
    scripts when ``settings.ENABLE_SCRIPT_INJECTION`` is truthy. Each disabled
    feature is reported at debug level.

    Any error other than ``asyncio.CancelledError`` is logged and suppressed.
    Cancellation always propagates to the caller.
    """
    try:
        if not settings.NETWORK_INTERCEPTION_ENABLED:
            logger.debug("[Network] Network interception disabled")
        else:
            await _setup_model_list_interception(context)

        if not settings.ENABLE_SCRIPT_INJECTION:
            logger.debug("[Network] Script injection disabled")
        else:
            # Init scripts are kept as an optional fallback alongside interception.
            await add_init_scripts_to_context(context)
    except asyncio.CancelledError:
        raise
    except Exception as exc:
        logger.error(f"Error setting up network interception and scripts: {exc}")


async def _setup_model_list_interception(context: AsyncBrowserContext):
    """Register a context-wide route handler for AI Studio model-list requests.

    Requests whose URL contains both "alkalimakersuite" and "ListModels" are:

    * fetched from the server;
    * passed through ``_modify_model_list_response``;
    * fulfilled with the resulting body.

    Every other request is continued unchanged.

    Errors while registering the route are logged and suppressed, so browser
    initialization can proceed without interception. ``asyncio.CancelledError``
    always propagates.
    """
    try:

        async def handle_model_list_route(route):
            """Fetch, post-process and fulfill one intercepted request."""
            request = route.request

            # Not a model-list request: let it through untouched.
            if not ("alkalimakersuite" in request.url and "ListModels" in request.url):
                await route.continue_()
                return

            logger.info(f"Intercepted model list request: {request.url}")

            try:
                response = await route.fetch()
                original_body = await response.body()
            except asyncio.CancelledError:
                raise
            except Exception as e:
                # If the handler raised here, the route would never be resolved
                # and the page's request would hang; hand it back to the browser.
                logger.warning(
                    f"Failed to fetch model list response, continuing unmodified: {e}"
                )
                await route.continue_()
                return

            modified_body = await _modify_model_list_response(
                original_body, request.url
            )

            # Reuse status/headers from the real response, swap in the new body.
            await route.fulfill(response=response, body=modified_body)

        # Register route interceptor
        await context.route("**/*", handle_model_list_route)
        logger.info("Model list network interception setup")

    except asyncio.CancelledError:
        raise
    except Exception as e:
        logger.error(f"Error setting up model list network interception: {e}")


async def _modify_model_list_response(original_body: bytes, url: str) -> bytes:
    """Normalize a model-list response body; currently a validated pass-through.

    Processing steps:

    1. Decode the body as UTF-8.
    2. Strip an optional anti-hijack prefix line (``)]}'`` plus newline).
    3. Parse the remainder as JSON and re-serialize it compactly.
    4. Restore the prefix.

    No models are injected.

    Args:
        original_body: Raw response bytes as returned by the server.
        url: Request URL. Unused here; kept for interface compatibility.

    Returns:
        The re-serialized body. If the body cannot be decoded, parsed, or
        re-serialized as strict JSON, ``original_body`` is returned unchanged.
    """
    try:
        # Decode response body
        original_text = original_body.decode("utf-8")

        # Handle anti-hijack prefix
        ANTI_HIJACK_PREFIX = ")]}'\n"
        has_prefix = original_text.startswith(ANTI_HIJACK_PREFIX)
        if has_prefix:
            original_text = original_text[len(ANTI_HIJACK_PREFIX) :]

        # Parse JSON to ensure it's valid, but we don't inject models anymore
        try:
            json_data = json.loads(original_text)
        except json.JSONDecodeError as json_err:
            logger.error(f"Failed to parse model list response JSON: {json_err}")
            return original_body

        # allow_nan=False: json.loads turns out-of-range numbers such as 1e400
        # into inf. The default dumps would write them back as "Infinity",
        # which strict JSON parsers reject. Raising ValueError instead lands
        # in the handler below, which returns the original body.
        # ensure_ascii=False: keep non-ASCII text as UTF-8 rather than
        # rewriting it as escape sequences, since nothing is being modified.
        modified_text = json.dumps(
            json_data, separators=(",", ":"), ensure_ascii=False, allow_nan=False
        )

        # Add prefix back
        if has_prefix:
            modified_text = ANTI_HIJACK_PREFIX + modified_text

        return modified_text.encode("utf-8")

    except asyncio.CancelledError:
        raise
    except Exception as e:
        logger.error(f"Error processing model list response: {e}")
        return original_body