Spaces:
Paused
Paused
File size: 9,755 Bytes
a5784e9 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 | """
Model Switching Logic
"""
import asyncio
import json
import logging
import os
from typing import Optional
from playwright.async_api import Page as AsyncPage
from playwright.async_api import expect as expect_async
from config import AI_STUDIO_URL_PATTERN, INPUT_SELECTOR, MODEL_NAME_SELECTOR
from .ui_state import _verify_and_apply_ui_state
logger = logging.getLogger("AIStudioProxyServer")
def _parse_prefs_json(prefs_str: Optional[str], failure_message: str) -> dict:
    """Decode an ``aiStudioUserPreference`` localStorage value into a dict.

    Args:
        prefs_str: Raw value read from localStorage (``None``/empty if unset).
        failure_message: Warning logged when the value is not a JSON object.

    Returns:
        The decoded object, or an empty dict when the value is missing, is not
        valid JSON, or decodes to something other than a JSON object.
    """
    if not prefs_str:
        return {}
    try:
        prefs = json.loads(prefs_str)
    except json.JSONDecodeError:
        logger.warning(failure_message)
        return {}
    if not isinstance(prefs, dict):
        # Valid JSON such as "null" or "[]" has no .get(); treat it as unusable
        # rather than letting an AttributeError abort the whole switch.
        logger.warning(failure_message)
        return {}
    return prefs


async def _run_post_switch_hooks(page: AsyncPage, model_id: str, req_id: str) -> None:
    """Best-effort follow-up after a verified model switch.

    Re-enables temporary chat mode and invalidates the function-calling cache.
    Ordinary exceptions are logged and never fail the switch; cancellation is
    propagated.
    """
    try:
        logger.debug("[Model] Re-enabling temporary chat mode...")
        from browser_utils.initialization import enable_temporary_chat_mode

        enabled = await enable_temporary_chat_mode(page)
        if enabled:
            logger.debug("[Model] Temporary chat mode enabled or already active")
        else:
            logger.warning("[Model] Temporary chat mode state verification failed")
    except asyncio.CancelledError:
        raise
    except Exception as e:
        logger.warning(
            f"Failed to re-enable temporary chat mode after model switching: {e}"
        )

    # Invalidate function calling cache on model switch
    try:
        from api_utils.utils_ext.function_calling_cache import (
            FunctionCallingCache,
        )

        FunctionCallingCache.get_instance().invalidate(
            reason=f"model_switch:{model_id}", req_id=req_id
        )
    except ImportError:
        pass  # Cache module not available
    except Exception as e_cache:
        logger.debug(f"[Model] Failed to invalidate FC cache: {e_cache}")


async def switch_ai_studio_model(page: AsyncPage, model_id: str, req_id: str) -> bool:
    """Switch AI Studio to ``model_id`` via localStorage and a new_chat reload.

    Flow:
      1. Read ``aiStudioUserPreference`` from localStorage. If ``promptModel``
         is already ``models/<model_id>``, make sure the page is on new_chat
         and return True.
      2. Otherwise write the new ``promptModel`` (plus ``isAdvancedOpen`` /
         ``areToolsOpen``), apply the forced UI state, and navigate to new_chat.
      3. Verify that localStorage kept the new ``promptModel`` and that the
         model name shown on the page (``MODEL_NAME_SELECTOR``) equals
         ``model_id``. Only then run the post-switch hooks.

    Args:
        page: Playwright page showing AI Studio.
        model_id: Target model id, without the ``models/`` prefix.
        req_id: Request id, used for UI-state calls, cache invalidation and
            the error snapshot name.

    Returns:
        True on success (or if the model was already selected). False if
        verification fails or an unexpected error occurs; in the error case a
        snapshot is saved first.

    Raises:
        asyncio.CancelledError: always propagated unchanged.
    """
    logger.info(f"[Model] Switching to -> {model_id}")
    new_chat_url = f"https://{AI_STUDIO_URL_PATTERN}prompts/new_chat"
    full_model_path = f"models/{model_id}"
    try:
        original_prefs_str: Optional[str] = await page.evaluate(
            "() => localStorage.getItem('aiStudioUserPreference')"
        )
        # Parsed once; the same dict is modified and written back below.
        prefs = _parse_prefs_json(
            original_prefs_str,
            "Failed to parse original aiStudioUserPreference JSON string.",
        )

        if prefs.get("promptModel") == full_model_path:
            logger.debug(f"[Model] Already at target model {model_id}")
            if page.url != new_chat_url:
                logger.debug(
                    f"[Model] URL is not new_chat, navigating to {new_chat_url}"
                )
                await page.goto(
                    new_chat_url, wait_until="domcontentloaded", timeout=30000
                )
                await expect_async(page.locator(INPUT_SELECTOR)).to_be_visible(
                    timeout=30000
                )
            return True

        logger.debug(
            f"[Model] Updating localStorage.promptModel: {prefs.get('promptModel', 'unknown')} -> {full_model_path}"
        )
        prefs["promptModel"] = full_model_path
        await page.evaluate(
            "(prefsStr) => localStorage.setItem('aiStudioUserPreference', prefsStr)",
            json.dumps(prefs),
        )

        # Use new forced setting feature
        logger.debug("[State] Applying forced UI state settings...")
        ui_state_success = await _verify_and_apply_ui_state(page, req_id)
        if not ui_state_success:
            logger.warning(
                "UI state setting failed, but continuing model switching flow"
            )

        # To maintain compatibility, also persist these flags in the prefs object
        prefs["isAdvancedOpen"] = True
        prefs["areToolsOpen"] = True
        await page.evaluate(
            "(prefsStr) => localStorage.setItem('aiStudioUserPreference', prefsStr)",
            json.dumps(prefs),
        )

        logger.debug(f"[Model] Navigating to {new_chat_url}...")
        await page.goto(new_chat_url, wait_until="domcontentloaded", timeout=30000)
        await expect_async(page.locator(INPUT_SELECTOR)).to_be_visible(timeout=30000)
        logger.debug("[Model] Page navigation complete, input box visible")

        # Verify UI state settings again after page load
        logger.debug("[State] Verifying UI state...")
        if await _verify_and_apply_ui_state(page, req_id):
            logger.debug("[State] UI state verification successful")
        else:
            logger.warning(
                "Final UI state verification failed, but continuing model switching flow"
            )

        final_prefs_str = await page.evaluate(
            "() => localStorage.getItem('aiStudioUserPreference')"
        )
        final_prompt_model_in_storage = _parse_prefs_json(
            final_prefs_str,
            "Failed to parse refreshed aiStudioUserPreference JSON string.",
        ).get("promptModel")

        if final_prompt_model_in_storage != full_model_path:
            logger.error(
                f"AI Studio did not accept model change (localStorage). Expected='{full_model_path}', Actual='{final_prompt_model_in_storage or 'not set or invalid'}'."
            )
            return False
        logger.debug(f"[Model] localStorage set correctly: {full_model_path}")

        # localStorage alone is not enough: confirm the page shows the model.
        page_display_match = False
        try:
            displayed_raw = await page.locator(MODEL_NAME_SELECTOR).first.inner_text(
                timeout=5000
            )
            displayed_model_id = displayed_raw.strip()
            if displayed_model_id == model_id:
                page_display_match = True
                logger.info("[Model] Switching successful")
            else:
                logger.error(
                    f"Page displayed model ID ('{displayed_model_id}') inconsistent with expected ID ('{model_id}')."
                )
        except asyncio.CancelledError:
            raise
        except Exception as e_disp:
            # Reading failed, so page_display_match stays False (assume mismatch).
            logger.warning(
                f"Error reading displayed model ID: {e_disp}. Cannot verify page display."
            )

        if not page_display_match:
            logger.error(
                "Model switching failed because page displayed model does not match expectation (even if localStorage may have changed)."
            )
            return False

        await _run_post_switch_hooks(page, model_id, req_id)
        return True
    except asyncio.CancelledError:
        raise
    except Exception:
        logger.exception("Serious error occurred during model switching")
        from browser_utils.operations import save_error_snapshot

        await save_error_snapshot(f"model_switch_error_{req_id}")
        return False
def load_excluded_models(filename: str):
    """Merge model IDs listed in ``filename`` into ``state.excluded_model_ids``.

    The file is resolved two directories above this module's directory and is
    read as UTF-8 with one model ID per line. Surrounding whitespace is
    stripped and blank lines are skipped. If the file is missing or yields no
    IDs, the shared exclusion set is left as it was. Any error while reading
    is logged with its traceback instead of being raised.
    """
    from api_utils.server_state import state

    exclusion_set = getattr(state, "excluded_model_ids", set())
    source_path = os.path.join(os.path.dirname(__file__), "..", "..", filename)
    try:
        if not os.path.exists(source_path):
            logger.debug(
                f"Model exclusion list file '{filename}' not found, list is empty."
            )
            return

        found_ids = set()
        with open(source_path, "r", encoding="utf-8") as handle:
            for raw_line in handle:
                candidate = raw_line.strip()
                if candidate:
                    found_ids.add(candidate)

        if not found_ids:
            logger.debug(
                f"'{filename}' file is empty or contains no valid model IDs, exclusion list unchanged."
            )
            return

        exclusion_set.update(found_ids)
        state.excluded_model_ids = exclusion_set
        logger.debug(
            f"Loaded {len(found_ids)} models from '{filename}' into exclusion list"
        )
    except Exception as e:
        logger.error(
            f"Error loading excluded models from '{filename}': {e}", exc_info=True
        )
|