peijun1's picture
Deploy AI Studio Proxy API to Hugging Face Spaces
a5784e9
Raw
History Blame Contribute Delete
11.1 kB
import asyncio
from typing import Any, Callable, Dict, List, Optional, Tuple
from playwright.async_api import expect as expect_async
from browser_utils.operations import (
_get_final_response_content,
_wait_for_response_completion,
save_error_snapshot,
)
from config import (
EDIT_MESSAGE_BUTTON_SELECTOR,
PROMPT_TEXTAREA_SELECTOR,
RESPONSE_CONTAINER_SELECTOR,
RESPONSE_TEXT_SELECTOR,
SUBMIT_BUTTON_SELECTOR,
)
from config.settings import FUNCTION_CALLING_DEBUG
from logging_utils import set_request_id
from models import ClientDisconnectedError
from .base import BaseController
class ResponseController(BaseController):
"""Handles retrieval of AI responses."""
async def get_response(
self,
check_client_disconnected: Callable,
prompt_length: int = 0,
timeout: Optional[float] = None,
) -> str:
"""Retrieve response content."""
set_request_id(self.req_id)
self.logger.debug("[Response] Waiting for and retrieving response...")
try:
# Wait for response container
response_container_locator = self.page.locator(
RESPONSE_CONTAINER_SELECTOR
).last
response_element_locator = response_container_locator.locator(
RESPONSE_TEXT_SELECTOR
)
self.logger.debug(
"[Response] Waiting for response element to be attached to DOM..."
)
await expect_async(response_element_locator).to_be_attached(timeout=90000)
await self._check_disconnect(
check_client_disconnected,
"Retrieve Response - Response element attached",
)
# Wait for response completion
submit_button_locator = self.page.locator(SUBMIT_BUTTON_SELECTOR)
edit_button_locator = self.page.locator(EDIT_MESSAGE_BUTTON_SELECTOR)
input_field_locator = self.page.locator(PROMPT_TEXTAREA_SELECTOR)
self.logger.debug("[Response] Waiting for response completion...")
completion_detected = await _wait_for_response_completion(
self.page,
input_field_locator,
submit_button_locator,
edit_button_locator,
self.req_id,
check_client_disconnected,
None,
prompt_length=prompt_length,
timeout=timeout,
)
if not completion_detected:
self.logger.warning(
"Response completion detection failed, attempting to retrieve current content"
)
else:
self.logger.debug("[Response] Response completion detection successful")
# Get final response content
final_content = await _get_final_response_content(
self.page, self.req_id, check_client_disconnected
)
if not final_content or not final_content.strip():
self.logger.warning("Retrieved response content is empty")
await save_error_snapshot(f"empty_response_{self.req_id}")
# Do not raise exception, return empty content to let caller handle
return ""
self.logger.debug(
f"[Response] Successfully retrieved content ({len(final_content)} chars)"
)
return final_content
except Exception as e:
if isinstance(e, asyncio.CancelledError):
self.logger.info("Retrieve response task cancelled")
raise
self.logger.error(f"Error retrieving response: {e}")
if not isinstance(e, ClientDisconnectedError):
await save_error_snapshot(f"get_response_error_{self.req_id}")
raise
async def ensure_generation_stopped(
self, check_client_disconnected: Callable
) -> None:
"""
Ensure generation has stopped.
If submit button is still enabled, click it to stop generation.
Wait until submit button becomes disabled.
"""
submit_button_locator = self.page.locator(SUBMIT_BUTTON_SELECTOR)
# Check client connection status
check_client_disconnected("Ensure generation stopped - pre-check")
await asyncio.sleep(0.5) # Give UI time to update
# Check if button is still enabled, if so click to stop
try:
is_button_enabled = await submit_button_locator.is_enabled(timeout=2000)
if is_button_enabled:
# Button still enabled after stream completion, click to stop
self.logger.debug(
"[Cleanup] Submit button state: ENABLED -> Clicking stop"
)
await submit_button_locator.click(timeout=5000, force=True)
else:
self.logger.debug(
"[Cleanup] Submit button state: DISABLED (no action needed)"
)
except Exception as button_check_err:
if isinstance(button_check_err, asyncio.CancelledError):
raise
self.logger.warning(f"Failed to check button state: {button_check_err}")
# Wait for button to be disabled
try:
await expect_async(submit_button_locator).to_be_disabled(timeout=30000)
except Exception as e:
if isinstance(e, asyncio.CancelledError):
raise
self.logger.warning(f"Timeout or error ensuring generation stopped: {e}")
# Do not raise even on timeout as this is just a cleanup step
async def detect_function_calls(
self,
check_client_disconnected: Callable,
) -> bool:
"""
Detect if the current response contains function calls.
This is a fast check without full parsing. Uses the
FunctionCallResponseParser for DOM-based detection.
Args:
check_client_disconnected: Callback to check client connection.
Returns:
True if function calls are detected, False otherwise.
"""
try:
from api_utils.utils_ext.function_call_response_parser import (
FunctionCallResponseParser,
)
parser = FunctionCallResponseParser(
page=self.page,
logger=self.logger,
req_id=self.req_id,
)
return await parser.detect_function_calls(check_client_disconnected)
except Exception as e:
if FUNCTION_CALLING_DEBUG:
self.logger.debug(
f"[{self.req_id}] Error detecting function calls: {e}"
)
return False
async def parse_function_calls(
self,
check_client_disconnected: Callable,
) -> Tuple[bool, List[Dict[str, Any]], str]:
"""
Parse function calls from the current response.
Args:
check_client_disconnected: Callback to check client connection.
Returns:
Tuple of:
- has_function_calls: Whether any function calls were found
- function_calls: List of function call dicts with 'name' and 'params'
- text_content: Any remaining text content
"""
try:
from api_utils.utils_ext.function_call_response_parser import (
FunctionCallResponseParser,
)
parser = FunctionCallResponseParser(
page=self.page,
logger=self.logger,
req_id=self.req_id,
)
result = await parser.parse_function_calls(check_client_disconnected)
# Convert ParsedFunctionCall to dict format
function_calls: List[Dict[str, Any]] = []
for fc in result.function_calls:
function_calls.append(
{
"name": fc.name,
"params": fc.arguments,
}
)
return result.has_function_calls, function_calls, result.text_content
except Exception as e:
if FUNCTION_CALLING_DEBUG:
self.logger.error(f"[{self.req_id}] Error parsing function calls: {e}")
return False, [], ""
async def get_response_with_function_calls(
self,
check_client_disconnected: Callable,
prompt_length: int = 0,
timeout: Optional[float] = None,
) -> Dict[str, Any]:
"""
Retrieve response content with function call detection.
This method extends get_response by also checking for function calls
in the response and returning them in a structured format.
Args:
check_client_disconnected: Callback to check client connection.
prompt_length: Length of the prompt for timeout calculation.
timeout: Optional explicit timeout.
Returns:
Dict with:
- content: Text content of the response
- has_function_calls: Whether function calls were detected
- function_calls: List of function call dicts
- raw_content: The raw content string
"""
set_request_id(self.req_id)
if FUNCTION_CALLING_DEBUG:
self.logger.debug(
"[Response] Getting response with function call detection..."
)
result: Dict[str, Any] = {
"content": "",
"has_function_calls": False,
"function_calls": [],
"raw_content": "",
}
try:
# Get the raw response content first
raw_content = await self.get_response(
check_client_disconnected,
prompt_length=prompt_length,
timeout=timeout,
)
result["raw_content"] = raw_content
# Check for function calls
has_fc, function_calls, text_content = await self.parse_function_calls(
check_client_disconnected
)
if has_fc:
if FUNCTION_CALLING_DEBUG:
self.logger.info(
f"[{self.req_id}] Detected {len(function_calls)} function call(s) in response"
)
result["has_function_calls"] = True
result["function_calls"] = function_calls
result["content"] = text_content
else:
result["content"] = raw_content
except Exception as e:
if isinstance(e, asyncio.CancelledError):
raise
if FUNCTION_CALLING_DEBUG:
self.logger.error(
f"[{self.req_id}] Error getting response with FC: {e}"
)
if not isinstance(e, ClientDisconnectedError):
await save_error_snapshot(f"get_response_fc_error_{self.req_id}")
raise
return result