Spaces:
Paused
Paused
File size: 4,613 Bytes
a5784e9 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
from typing import NoReturn, Optional

from playwright.async_api import Page as AsyncPage

from api_utils.server_state import state
from logging_utils import set_request_id

from .context_types import RequestContext
async def analyze_model_requirements(
    req_id: str, context: RequestContext, requested_model: str, proxy_model_name: str
) -> RequestContext:
    """Decide which AI Studio model this request should run on.

    An empty ``requested_model``, or one equal to ``proxy_model_name``, leaves
    ``context`` untouched. Otherwise the last ``/``-separated segment is used
    as the model ID. That ID is validated against
    ``context["parsed_model_list"]`` (only when that list is non-empty) and
    stored in ``context["model_id_to_use"]``.
    ``context["needs_model_switching"]`` is set to ``True`` when the ID differs
    from ``context["current_ai_studio_model_id"]``.

    Args:
        req_id: Request identifier used for log correlation.
        context: Per-request context, mutated in place.
        requested_model: Model name supplied by the client.
        proxy_model_name: The proxy's own model name, which needs no switch.

    Returns:
        The same ``context`` object.

    Raises:
        The error built by ``bad_request`` when the model ID is not among the
        IDs in ``parsed_model_list``.
    """
    set_request_id(req_id)
    log = context["logger"]
    active_model_id = context["current_ai_studio_model_id"]
    known_models = context["parsed_model_list"]

    # Nothing to do for an empty name or the proxy's own model name.
    if not requested_model or requested_model == proxy_model_name:
        return context

    # Keep only the last '/'-separated segment (drops any path-like prefix).
    target_id = requested_model.split("/")[-1]
    log.info(f"[{req_id}] Requesting model: {target_id}")

    if known_models:
        # A list (not a set) so the error message keeps the original order.
        allowed_ids = [
            str(entry.get("id")) for entry in known_models if entry.get("id")
        ]
        if target_id not in allowed_ids:
            from .error_utils import bad_request

            raise bad_request(
                req_id,
                f"Invalid model '{target_id}'. Available models: {', '.join(allowed_ids)}",
            )

    context["model_id_to_use"] = target_id
    if target_id == active_model_id:
        return context

    context["needs_model_switching"] = True
    log.info(
        f"[{req_id}] Model switch needed: Current={active_model_id} -> Target={target_id}"
    )
    return context
async def handle_model_switching(
    req_id: str, context: RequestContext
) -> RequestContext:
    """Switch the AI Studio page to ``context["model_id_to_use"]`` if required.

    Does nothing unless ``context["needs_model_switching"]`` is truthy. Both
    the comparison against the shared ``state`` and the switch itself run
    under ``context["model_switching_lock"]``. A switch that another request
    completed while this one waited on the lock is therefore not repeated.

    Context updates:
        * Switch performed: ``state.current_ai_studio_model_id``,
          ``context["current_ai_studio_model_id"]`` and
          ``context["model_actually_switched"]`` are set for the new model.
        * Page already on the target: only
          ``context["current_ai_studio_model_id"]`` is refreshed, so later
          steps do not act on the value captured before the lock was taken.

    Args:
        req_id: Request identifier used for log correlation.
        context: Per-request context; reads ``needs_model_switching``,
            ``logger``, ``page``, ``model_switching_lock`` and
            ``model_id_to_use``.

    Returns:
        The same ``context`` object, mutated in place.

    Raises:
        AssertionError: If ``page`` or ``model_id_to_use`` is ``None`` (only
            while assertions are enabled).
        The exception raised by ``_handle_model_switch_failure`` when
        ``switch_ai_studio_model`` reports failure.
    """
    set_request_id(req_id)
    if not context["needs_model_switching"]:
        return context
    logger = context["logger"]
    page = context["page"]
    model_switching_lock = context["model_switching_lock"]
    model_id_to_use = context["model_id_to_use"]
    # Narrow Optional values for the type checker; a switch is impossible
    # without a page and a target model ID.
    assert page is not None, "Page must be ready for model switching"
    assert model_id_to_use is not None, "Target model ID must be set"
    async with model_switching_lock:
        if state.current_ai_studio_model_id == model_id_to_use:
            # Another request switched while we waited for the lock. Refresh
            # the stale snapshot so handle_parameter_cache compares against
            # the model the page is actually on.
            context["current_ai_studio_model_id"] = model_id_to_use
            return context
        # Snapshot the model *before* the attempt so a failure restores the
        # true pre-switch value (possibly None). Reading it after the call
        # could pick up changes made by the failed attempt. The old
        # `or "unknown"` fallback also leaked a display placeholder into
        # shared state.
        model_before_switch = state.current_ai_studio_model_id
        logger.info(
            f"[{req_id}] Preparing to switch model: {model_before_switch} -> {model_id_to_use}"
        )
        from browser_utils import switch_ai_studio_model

        switch_success = await switch_ai_studio_model(page, model_id_to_use, req_id)
        if switch_success:
            state.current_ai_studio_model_id = model_id_to_use
            context["model_actually_switched"] = True
            context["current_ai_studio_model_id"] = model_id_to_use
            logger.info(
                f"[{req_id}] ✅ Model switched successfully: {state.current_ai_studio_model_id}"
            )
        else:
            await _handle_model_switch_failure(
                req_id,
                page,
                model_id_to_use,
                model_before_switch,
                logger,
            )
    return context
async def _handle_model_switch_failure(
    req_id: str,
    page: AsyncPage,
    model_id_to_use: str,
    model_before_switch: Optional[str],
    logger,
) -> NoReturn:
    """Restore the recorded model after a failed switch and abort the request.

    Logs a warning, writes ``model_before_switch`` back into
    ``state.current_ai_studio_model_id``, then raises. This function never
    returns normally.

    Args:
        req_id: Request identifier used in the log line and error detail.
        page: The browser page the switch was attempted on; not used by this
            body.
        model_id_to_use: The model ID the switch tried to reach.
        model_before_switch: Value to restore into shared state. May be
            ``None`` when no model had been recorded before the attempt.
        logger: Request-scoped logger.

    Raises:
        The exception built by ``http_error`` with status 422.
    """
    logger.warning(f"[{req_id}] ❌ Failed to switch to model {model_id_to_use}.")
    state.current_ai_studio_model_id = model_before_switch
    from .error_utils import http_error

    raise http_error(
        422,
        f"[{req_id}] Failed to switch to model '{model_id_to_use}'. Ensure model is available.",
    )
async def handle_parameter_cache(req_id: str, context: RequestContext) -> None:
    """Invalidate the page parameter cache when the active model changed.

    The cache is cleared and re-tagged with the current model ID in two cases:
    a switch happened during this request (``context["model_actually_switched"]``),
    or the model ID recorded under ``"last_known_model_id_for_params"``
    differs from ``context["current_ai_studio_model_id"]``. The check and the
    update both happen while holding ``context["params_cache_lock"]``.

    Args:
        req_id: Request identifier used for log correlation.
        context: Per-request context holding the cache, its lock, and the
            model bookkeeping values.
    """
    set_request_id(req_id)
    log = context["logger"]
    cache_lock = context["params_cache_lock"]
    params_cache = context["page_params_cache"]
    active_model_id = context["current_ai_studio_model_id"]
    switched_this_request = context["model_actually_switched"]

    async with cache_lock:
        tagged_model_id = params_cache.get("last_known_model_id_for_params")
        cache_is_stale = switched_this_request or active_model_id != tagged_model_id
        if not cache_is_stale:
            return
        log.info(f"[{req_id}] Model changed, parameter cache invalidated.")
        params_cache.clear()
        # Re-tag the now-empty cache with the model it will hold values for.
        params_cache["last_known_model_id_for_params"] = active_model_id
|