Spaces:
Paused
Paused
| from playwright.async_api import Page as AsyncPage | |
| from api_utils.server_state import state | |
| from logging_utils import set_request_id | |
| from .context_types import RequestContext | |
| async def analyze_model_requirements( | |
| req_id: str, context: RequestContext, requested_model: str, proxy_model_name: str | |
| ) -> RequestContext: | |
| set_request_id(req_id) | |
| logger = context["logger"] | |
| current_ai_studio_model_id = context["current_ai_studio_model_id"] | |
| parsed_model_list = context["parsed_model_list"] | |
| if requested_model and requested_model != proxy_model_name: | |
| requested_model_id = requested_model.split("/")[-1] | |
| logger.info(f"[{req_id}] Requesting model: {requested_model_id}") | |
| if parsed_model_list: | |
| valid_model_ids = [ | |
| str(m.get("id")) for m in parsed_model_list if m.get("id") | |
| ] | |
| if requested_model_id not in valid_model_ids: | |
| from .error_utils import bad_request | |
| raise bad_request( | |
| req_id, | |
| f"Invalid model '{requested_model_id}'. Available models: {', '.join(valid_model_ids)}", | |
| ) | |
| context["model_id_to_use"] = requested_model_id | |
| if current_ai_studio_model_id != requested_model_id: | |
| context["needs_model_switching"] = True | |
| logger.info( | |
| f"[{req_id}] Model switch needed: Current={current_ai_studio_model_id} -> Target={requested_model_id}" | |
| ) | |
| return context | |
| async def handle_model_switching( | |
| req_id: str, context: RequestContext | |
| ) -> RequestContext: | |
| set_request_id(req_id) | |
| if not context["needs_model_switching"]: | |
| return context | |
| logger = context["logger"] | |
| page = context["page"] | |
| model_switching_lock = context["model_switching_lock"] | |
| model_id_to_use = context["model_id_to_use"] | |
| # Assert non-None values required for model switching | |
| assert page is not None, "Page must be ready for model switching" | |
| assert model_id_to_use is not None, "Target model ID must be set" | |
| async with model_switching_lock: | |
| if state.current_ai_studio_model_id != model_id_to_use: | |
| logger.info( | |
| f"[{req_id}] Preparing to switch model: {state.current_ai_studio_model_id} -> {model_id_to_use}" | |
| ) | |
| from browser_utils import switch_ai_studio_model | |
| switch_success = await switch_ai_studio_model(page, model_id_to_use, req_id) | |
| if switch_success: | |
| state.current_ai_studio_model_id = model_id_to_use | |
| context["model_actually_switched"] = True | |
| context["current_ai_studio_model_id"] = model_id_to_use | |
| logger.info( | |
| f"[{req_id}] ✅ Model switched successfully: {state.current_ai_studio_model_id}" | |
| ) | |
| else: | |
| # Current model ID should exist when switching fails | |
| current_model = state.current_ai_studio_model_id or "unknown" | |
| await _handle_model_switch_failure( | |
| req_id, | |
| page, | |
| model_id_to_use, | |
| current_model, | |
| logger, | |
| ) | |
| return context | |
| async def _handle_model_switch_failure( | |
| req_id: str, page: AsyncPage, model_id_to_use: str, model_before_switch: str, logger | |
| ) -> None: | |
| logger.warning(f"[{req_id}] ❌ Failed to switch to model {model_id_to_use}.") | |
| state.current_ai_studio_model_id = model_before_switch | |
| from .error_utils import http_error | |
| raise http_error( | |
| 422, | |
| f"[{req_id}] Failed to switch to model '{model_id_to_use}'. Ensure model is available.", | |
| ) | |
| async def handle_parameter_cache(req_id: str, context: RequestContext) -> None: | |
| set_request_id(req_id) | |
| logger = context["logger"] | |
| params_cache_lock = context["params_cache_lock"] | |
| page_params_cache = context["page_params_cache"] | |
| current_ai_studio_model_id = context["current_ai_studio_model_id"] | |
| model_actually_switched = context["model_actually_switched"] | |
| async with params_cache_lock: | |
| cached_model_for_params = page_params_cache.get( | |
| "last_known_model_id_for_params" | |
| ) | |
| if model_actually_switched or ( | |
| current_ai_studio_model_id != cached_model_for_params | |
| ): | |
| logger.info(f"[{req_id}] Model changed, parameter cache invalidated.") | |
| page_params_cache.clear() | |
| page_params_cache["last_known_model_id_for_params"] = ( | |
| current_ai_studio_model_id | |
| ) | |