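# Scraped file-view header (Hugging Face Spaces), kept as comments so the module parses:
# peijun1's picture
# Deploy AI Studio Proxy API to Hugging Face Spaces
# a5784e9
# Raw
# History Blame Contribute Delete
# 9.76 kB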
"""
Model Switching Logic
"""
import asyncio
import json
import logging
import os
from typing import Optional
from playwright.async_api import Page as AsyncPage
from playwright.async_api import expect as expect_async
from config import AI_STUDIO_URL_PATTERN, INPUT_SELECTOR, MODEL_NAME_SELECTOR
from .ui_state import _verify_and_apply_ui_state
# Uses a fixed logger name rather than __name__, so this module's records go to the
# "AIStudioProxyServer" logger (presumably shared with the other server modules — TODO confirm).
logger = logging.getLogger("AIStudioProxyServer")
def _parse_user_prefs(raw: Optional[str], label: str) -> dict:
    """Parse a raw ``aiStudioUserPreference`` string into a dict.

    Missing/empty input, invalid JSON, and JSON that is not an object all
    yield ``{}`` (the latter two with a warning), so callers can always
    use ``.get`` on the result.

    Args:
        raw: Value read from localStorage (may be None).
        label: Word used in the warning, e.g. "original" or "refreshed".
    """
    if not raw:
        return {}
    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError:
        logger.warning(f"Failed to parse {label} aiStudioUserPreference JSON string.")
        return {}
    if not isinstance(parsed, dict):
        # Valid JSON such as `null` or `[]` has no .get(); treat it as unusable prefs.
        logger.warning(
            f"The {label} aiStudioUserPreference value is not a JSON object; ignoring it."
        )
        return {}
    return parsed


async def _read_user_prefs(page: AsyncPage) -> Optional[str]:
    """Return the raw ``aiStudioUserPreference`` localStorage string (None if unset)."""
    return await page.evaluate("() => localStorage.getItem('aiStudioUserPreference')")


async def _write_user_prefs(page: AsyncPage, prefs: dict) -> None:
    """Serialize ``prefs`` to JSON and store it as ``aiStudioUserPreference``."""
    await page.evaluate(
        "(prefsStr) => localStorage.setItem('aiStudioUserPreference', prefsStr)",
        json.dumps(prefs),
    )


async def _displayed_model_matches(page: AsyncPage, model_id: str) -> bool:
    """Return True when the page's model-name element shows exactly ``model_id``."""
    try:
        displayed_raw = await page.locator(MODEL_NAME_SELECTOR).first.inner_text(
            timeout=5000
        )
    except asyncio.CancelledError:
        raise
    except Exception as e_disp:
        # Reading failed: we cannot confirm the switch, so report a mismatch.
        logger.warning(
            f"Error reading displayed model ID: {e_disp}. Cannot verify page display."
        )
        return False
    displayed = displayed_raw.strip()
    if displayed == model_id:
        logger.info("[Model] Switching successful")
        return True
    logger.error(
        f"Page displayed model ID ('{displayed}') inconsistent with expected ID ('{model_id}')."
    )
    return False


async def _finish_model_switch(page: AsyncPage, model_id: str, req_id: str) -> None:
    """Best-effort post-switch housekeeping: re-enable temporary chat, drop the FC cache."""
    try:
        logger.debug("[Model] Re-enabling temporary chat mode...")
        from browser_utils.initialization import enable_temporary_chat_mode

        if await enable_temporary_chat_mode(page):
            logger.debug("[Model] Temporary chat mode enabled or already active")
        else:
            logger.warning("[Model] Temporary chat mode state verification failed")
    except asyncio.CancelledError:
        raise
    except Exception as e:
        logger.warning(
            f"Failed to re-enable temporary chat mode after model switching: {e}"
        )
    # Invalidate function calling cache on model switch
    try:
        from api_utils.utils_ext.function_calling_cache import FunctionCallingCache

        FunctionCallingCache.get_instance().invalidate(
            reason=f"model_switch:{model_id}", req_id=req_id
        )
    except ImportError:
        pass  # Cache module not available
    except Exception as e_cache:
        logger.debug(f"[Model] Failed to invalidate FC cache: {e_cache}")


async def _save_model_switch_snapshot(req_id: str) -> None:
    """Try to save an error snapshot; a failure here is logged, never raised."""
    try:
        from browser_utils.operations import save_error_snapshot

        await save_error_snapshot(f"model_switch_error_{req_id}")
    except asyncio.CancelledError:
        raise
    except Exception as e_snap:
        logger.warning(f"[Model] Failed to save error snapshot: {e_snap}")


async def switch_ai_studio_model(page: AsyncPage, model_id: str, req_id: str) -> bool:
    """Switch the AI Studio model used by ``page`` to ``model_id``.

    The switch rewrites ``promptModel`` in the ``aiStudioUserPreference``
    localStorage entry and reloads the new-chat page. It then verifies that
    localStorage kept the value and that the on-page model label matches.

    Args:
        page: Playwright page for AI Studio.
        model_id: Target model ID, without the ``models/`` prefix.
        req_id: Request ID. It is passed to the UI-state and cache helpers
            and used in the error-snapshot name.

    Returns:
        True if localStorage already points at the model, or if the switch
        was verified. False if verification fails or an unexpected error
        occurs; in that case an error snapshot is attempted.
        ``asyncio.CancelledError`` is re-raised.
    """
    logger.info(f"[Model] Switching to -> {model_id}")
    new_chat_url = f"https://{AI_STUDIO_URL_PATTERN}prompts/new_chat"
    full_model_path = f"models/{model_id}"
    try:
        prefs = _parse_user_prefs(await _read_user_prefs(page), "original")

        if prefs.get("promptModel") == full_model_path:
            # Already selected (per localStorage): just make sure we are on new_chat.
            logger.debug(f"[Model] Already at target model {model_id}")
            if page.url != new_chat_url:
                logger.debug(
                    f"[Model] URL is not new_chat, navigating to {new_chat_url}"
                )
                await page.goto(
                    new_chat_url, wait_until="domcontentloaded", timeout=30000
                )
                await expect_async(page.locator(INPUT_SELECTOR)).to_be_visible(
                    timeout=30000
                )
            return True

        logger.debug(
            f"[Model] Updating localStorage.promptModel: {prefs.get('promptModel', 'unknown')} -> {full_model_path}"
        )
        prefs["promptModel"] = full_model_path
        await _write_user_prefs(page, prefs)

        # Use new forced setting feature
        logger.debug("[State] Applying forced UI state settings...")
        if not await _verify_and_apply_ui_state(page, req_id):
            logger.warning(
                "UI state setting failed, but continuing model switching flow"
            )
        # Also set the panel flags on our local copy before writing it back.
        # NOTE(review): this write replaces whatever localStorage holds now, which
        # may include changes made by _verify_and_apply_ui_state — TODO confirm.
        prefs["isAdvancedOpen"] = True
        prefs["areToolsOpen"] = True
        await _write_user_prefs(page, prefs)

        logger.debug(f"[Model] Navigating to {new_chat_url}...")
        await page.goto(new_chat_url, wait_until="domcontentloaded", timeout=30000)
        await expect_async(page.locator(INPUT_SELECTOR)).to_be_visible(timeout=30000)
        logger.debug("[Model] Page navigation complete, input box visible")

        # Verify UI state settings again after page load
        logger.debug("[State] Verifying UI state...")
        if await _verify_and_apply_ui_state(page, req_id):
            logger.debug("[State] UI state verification successful")
        else:
            logger.warning(
                "Final UI state verification failed, but continuing model switching flow"
            )

        final_prefs = _parse_user_prefs(await _read_user_prefs(page), "refreshed")
        final_prompt_model_in_storage = final_prefs.get("promptModel")
        if final_prompt_model_in_storage != full_model_path:
            logger.error(
                f"AI Studio did not accept model change (localStorage). Expected='{full_model_path}', Actual='{final_prompt_model_in_storage or 'not set or invalid'}'."
            )
            return False
        logger.debug(f"[Model] localStorage set correctly: {full_model_path}")

        if not await _displayed_model_matches(page, model_id):
            logger.error(
                "Model switching failed because page displayed model does not match expectation (even if localStorage may have changed)."
            )
            return False

        await _finish_model_switch(page, model_id, req_id)
        return True
    except asyncio.CancelledError:
        raise
    except Exception:
        logger.exception("Serious error occurred during model switching")
        await _save_model_switch_snapshot(req_id)
        return False
def load_excluded_models(filename: str):
    """Merge the model IDs listed in ``filename`` into ``state.excluded_model_ids``.

    The file is resolved relative to two directories above this module and is
    read as UTF-8 with one model ID per line. Surrounding whitespace is
    trimmed and blank lines are skipped. IDs are added to the existing set
    held by ``state`` (or to a new set if ``state`` has none yet). The set is
    stored back on ``state`` only when at least one ID was read.

    A missing file or a file with no IDs is logged at debug level. Any other
    error while reading is logged with a traceback and swallowed.
    """
    from api_utils.server_state import state

    current_ids = getattr(state, "excluded_model_ids", set())
    list_path = os.path.join(os.path.dirname(__file__), "..", "..", filename)
    try:
        if not os.path.exists(list_path):
            logger.debug(
                f"Model exclusion list file '{filename}' not found, list is empty."
            )
            return

        with open(list_path, "r", encoding="utf-8") as handle:
            new_ids = set()
            for raw_line in handle:
                candidate = raw_line.strip()
                if candidate:
                    new_ids.add(candidate)

        if not new_ids:
            logger.debug(
                f"'{filename}' file is empty or contains no valid model IDs, exclusion list unchanged."
            )
            return

        current_ids.update(new_ids)
        state.excluded_model_ids = current_ids
        logger.debug(
            f"Loaded {len(new_ids)} models from '{filename}' into exclusion list"
        )
    except Exception as e:
        logger.error(
            f"Error loading excluded models from '{filename}': {e}", exc_info=True
        )