Spaces:
Paused
Paused
| import asyncio | |
| import json | |
| import queue | |
| import re | |
| import time | |
| from typing import Any, AsyncGenerator, Callable, List, Optional, Tuple | |
| from config.settings import FUNCTION_CALLING_DEBUG | |
| from logging_utils import set_request_id | |
| # [REFAC-01] Structural Boundary Pattern | |
| TOOL_STRUCTURE_PATTERN = re.compile( | |
| r"(?:^|\n)\s*(?:```[a-zA-Z0-9]*\s*)?<[a-zA-Z0-9_\-]+(?:\s|>)" | |
| ) | |
| async def use_stream_response( | |
| req_id: str, | |
| timeout: float = 5.0, | |
| silence_threshold: float = 60.0, | |
| page: Any = None, | |
| check_client_disconnected: Optional[Callable] = None, | |
| stream_start_time: float = 0.0, | |
| enable_silence_detection: bool = True, | |
| ) -> AsyncGenerator[Any, None]: | |
| """Enhanced stream response handler with UI-based generation active checks.""" | |
| from api_utils.server_state import state | |
| STREAM_QUEUE = state.STREAM_QUEUE | |
| logger = state.logger | |
| from browser_utils.page_controller import PageController | |
| from config import ( | |
| CHAT_SESSION_CONTENT_SELECTOR, | |
| LAST_CHAT_TURN_SELECTOR, | |
| SCROLL_CONTAINER_SELECTOR, | |
| UI_GENERATION_WAIT_TIMEOUT_MS, | |
| ) | |
| from config.global_state import GlobalState | |
| from models import ( | |
| ClientDisconnectedError, | |
| QuotaExceededError, | |
| UpstreamError, | |
| ) | |
| set_request_id(req_id) | |
| if STREAM_QUEUE is None: | |
| logger.warning(f"[{req_id}] STREAM_QUEUE is None, cannot use stream response") | |
| return | |
| if stream_start_time == 0.0: | |
| stream_start_time = time.time() - 10.0 | |
| accumulated_body = "" | |
| accumulated_reason_len = 0 | |
| total_reason_processed = 0 | |
| total_body_processed = 0 | |
| boundary_transitions = 0 | |
| boundary_buffer = "" | |
| acc_reason_state = "" | |
| acc_body_state = "" | |
| force_body_mode = False | |
| split_index = -1 | |
| empty_count = 0 | |
| initial_wait_limit = int(timeout * 10) | |
| silence_wait_limit = int(silence_threshold * 10) | |
| max_empty_retries = max(silence_wait_limit, initial_wait_limit) | |
| hard_timeout_limit = int(timeout * 10 * 3) | |
| _data_received = False | |
| has_content = False | |
| has_seen_functions = False | |
| received_items_count = 0 | |
| stale_done_ignored = False | |
| last_ui_check_time = 0 | |
| ui_check_interval = int(UI_GENERATION_WAIT_TIMEOUT_MS / 100) | |
| if ui_check_interval <= 0: | |
| ui_check_interval = 1 | |
| last_packet_time = time.time() | |
| min_items_before_silence_check = 10 | |
| async def check_ui_generation_active(): | |
| if not page: | |
| return False | |
| try: | |
| stop_button = page.locator('button[aria-label="Stop generating"]') | |
| if await stop_button.is_visible(timeout=1000): | |
| return True | |
| from config.selectors import SUBMIT_BUTTON_SELECTOR | |
| submit_button = page.locator(SUBMIT_BUTTON_SELECTOR) | |
| if await submit_button.count() > 0: | |
| try: | |
| if await submit_button.first.is_disabled(timeout=2000): | |
| return True | |
| except Exception: | |
| return False | |
| return False | |
| except Exception: | |
| return False | |
| try: | |
| while True: | |
| if ( | |
| GlobalState.CURRENT_STREAM_REQ_ID | |
| and GlobalState.CURRENT_STREAM_REQ_ID != req_id | |
| ): | |
| logger.warning(f"[{req_id}] Zombie Stream detected. Terminating.") | |
| yield { | |
| "done": True, | |
| "reason": "zombie_stream_aborted", | |
| "body": "", | |
| "function": [], | |
| } | |
| return | |
| if page: | |
| try: | |
| await page.evaluate( | |
| """([scrollSel, contentSel, lastTurnSel]) => { | |
| const scrollContainer = document.querySelector(scrollSel); | |
| if (scrollContainer) scrollContainer.scrollTop = scrollContainer.scrollHeight; | |
| const sessionContent = document.querySelector(contentSel); | |
| if (sessionContent) sessionContent.scrollTop = sessionContent.scrollHeight; | |
| const lastTurn = document.querySelector(lastTurnSel); | |
| if (lastTurn) lastTurn.scrollIntoView({behavior: "instant", block: "end"}); | |
| window.scrollTo(0, document.body.scrollHeight); | |
| }""", | |
| [ | |
| SCROLL_CONTAINER_SELECTOR, | |
| CHAT_SESSION_CONTENT_SELECTOR, | |
| LAST_CHAT_TURN_SELECTOR, | |
| ], | |
| ) | |
| except Exception: | |
| pass | |
| if GlobalState.IS_QUOTA_EXCEEDED and not GlobalState.IS_RECOVERING: | |
| logger.warning(f"[{req_id}] Quota detected. Pausing...") | |
| try: | |
| start_wait = time.time() | |
| while time.time() - start_wait < 2.0: | |
| if GlobalState.IS_RECOVERING: | |
| break | |
| await asyncio.sleep(0.2) | |
| except Exception: | |
| pass | |
| if GlobalState.IS_RECOVERING: | |
| logger.info(f"[{req_id}] 🔄 Recovery mode detected. Holding...") | |
| elif not GlobalState.IS_QUOTA_EXCEEDED: | |
| logger.info(f"[{req_id}] ✅ Recovery completed. Resuming...") | |
| else: | |
| logger.warning(f"[{req_id}] ⛔ Quota exceeded, waiting...") | |
| await asyncio.sleep(1.0) | |
| continue | |
| if GlobalState.IS_SHUTTING_DOWN.is_set(): | |
| logger.warning(f"[{req_id}] 🛑 Global Shutdown. Aborting.") | |
| yield { | |
| "done": True, | |
| "reason": "global_shutdown", | |
| "body": "", | |
| "function": [], | |
| } | |
| return | |
| try: | |
| data = STREAM_QUEUE.get_nowait() | |
| if data is None: | |
| logger.info(f"[{req_id}] 🔴 Received termination signal.") | |
| break | |
| if isinstance(data, dict) and data.get("done") is True: | |
| logger.info(f"[{req_id}] ✅ Explicit DONE received.") | |
| yield data | |
| break | |
| empty_count = 0 | |
| _data_received = True | |
| received_items_count += 1 | |
| last_packet_time = time.time() | |
| actual_data = data | |
| if isinstance(data, str): | |
| try: | |
| parsed_wrapper = json.loads(data) | |
| if ( | |
| isinstance(parsed_wrapper, dict) | |
| and "ts" in parsed_wrapper | |
| and "data" in parsed_wrapper | |
| ): | |
| if parsed_wrapper["ts"] < stream_start_time: | |
| logger.warning(f"[{req_id}] 🗑️ Stale data ignored.") | |
| continue | |
| actual_data = parsed_wrapper["data"] | |
| else: | |
| actual_data = parsed_wrapper | |
| except json.JSONDecodeError: | |
| pass | |
| if isinstance(actual_data, dict): | |
| if actual_data.get("error"): | |
| status = actual_data.get("status", 500) | |
| message = actual_data.get("message", "Unknown error") | |
| if status == 429 or "quota" in message.lower(): | |
| raise QuotaExceededError( | |
| f"AI Studio quota exceeded: {message}", req_id=req_id | |
| ) | |
| else: | |
| raise UpstreamError( | |
| f"AI Studio error: {message}", | |
| status_code=status, | |
| ) | |
| parsed_data = actual_data | |
| p_reason = str(parsed_data.get("reason", "")) | |
| p_body = str(parsed_data.get("body", "")) | |
| if ( | |
| p_reason | |
| and acc_reason_state | |
| and p_reason.startswith(acc_reason_state) | |
| ): | |
| new_reason_delta = p_reason[len(acc_reason_state) :] | |
| acc_reason_state = p_reason | |
| else: | |
| acc_reason_state += p_reason | |
| new_reason_delta = p_reason | |
| if p_body and acc_body_state and p_body.startswith(acc_body_state): | |
| acc_body_state = p_body | |
| else: | |
| acc_body_state += p_body | |
| if force_body_mode: | |
| parsed_data["reason"] = acc_reason_state[:split_index] | |
| parsed_data["body"] = ( | |
| acc_body_state + acc_reason_state[split_index:] | |
| ) | |
| else: | |
| text_to_check = boundary_buffer + new_reason_delta | |
| match = TOOL_STRUCTURE_PATTERN.search(text_to_check) | |
| if match: | |
| offset = len(acc_reason_state) - len(text_to_check) | |
| split_index = offset + match.start() | |
| force_body_mode = True | |
| boundary_transitions += 1 | |
| parsed_data["reason"] = acc_reason_state[:split_index] | |
| parsed_data["body"] = ( | |
| acc_body_state + acc_reason_state[split_index:] | |
| ) | |
| logger.info(f"[{req_id}] ✂️ Boundary Split Applied.") | |
| else: | |
| parsed_data["reason"] = acc_reason_state | |
| parsed_data["body"] = acc_body_state | |
| boundary_buffer = (boundary_buffer + new_reason_delta)[ | |
| -100: | |
| ] | |
| accumulated_body += str(parsed_data.get("body", "")) | |
| accumulated_reason_len += len(str(parsed_data.get("reason", ""))) | |
| total_body_processed += len(str(parsed_data.get("body", ""))) | |
| total_reason_processed += len(str(parsed_data.get("reason", ""))) | |
| if parsed_data.get("body") or parsed_data.get("reason"): | |
| has_content = True | |
| if parsed_data.get("function"): | |
| has_seen_functions = True | |
| # Track if any function call has empty arguments (potential parse failure) | |
| for fc in parsed_data.get("function", []): | |
| fc_params = fc.get("params") or fc.get("arguments") or {} | |
| if not fc_params: | |
| if FUNCTION_CALLING_DEBUG: | |
| logger.warning( | |
| f"[{req_id}] ⚠️ Wire format returned '{fc.get('name')}' with empty args - will try DOM fallback" | |
| ) | |
| has_seen_functions = False # Force DOM fallback | |
| break | |
| if parsed_data.get("done") is True: | |
| if GlobalState.IS_QUOTA_EXCEEDED or GlobalState.IS_RECOVERING: | |
| logger.info( | |
| f"[{req_id}] 🛡️ Quota/Recovery active: Holding stream open." | |
| ) | |
| continue | |
| just_rotated = ( | |
| time.time() - GlobalState.LAST_ROTATION_TIMESTAMP < 15.0 | |
| ) | |
| recently_recovered = ( | |
| time.time() - GlobalState.LAST_ROTATION_TIMESTAMP < 30.0 | |
| ) | |
| if ( | |
| not has_content | |
| and received_items_count == 1 | |
| and not stale_done_ignored | |
| and not GlobalState.IS_QUOTA_EXCEEDED | |
| and (just_rotated or recently_recovered) | |
| ): | |
| logger.info( | |
| f"[{req_id}] 🔄 Post-rotation empty DONE detected. Ignoring." | |
| ) | |
| stale_done_ignored = True | |
| continue | |
| if ( | |
| parsed_data.get("done") is True | |
| and not has_seen_functions | |
| and page | |
| ): | |
| # Retry loop for DOM function call detection - UI elements may not render immediately | |
| # Similar to body text retry loop below, but shorter timeout for function calls | |
| dom_functions = [] | |
| dom_text = "" | |
| max_fc_retries = 10 # 10 retries * 0.3s = 3 seconds max wait | |
| for fc_retry in range(max_fc_retries): | |
| ( | |
| dom_functions, | |
| dom_text, | |
| ) = await detect_function_calls_from_dom( | |
| page, req_id, logger | |
| ) | |
| if dom_functions: | |
| if FUNCTION_CALLING_DEBUG: | |
| logger.info( | |
| f"[{req_id}] ✅ DOM captured function calls after {fc_retry + 1} attempts" | |
| ) | |
| break | |
| # Only retry if we haven't found functions and body is also empty | |
| # (indicates potential race condition with UI rendering) | |
| if accumulated_body: | |
| break # We have body text, no need to wait for functions | |
| await asyncio.sleep(0.3) | |
| if dom_functions: | |
| parsed_data["function"] = dom_functions | |
| has_seen_functions = True | |
| # If we have DOM text and accumulated body is empty, inject it to final chunk | |
| if dom_text and not accumulated_body: | |
| parsed_data["body"] = dom_text | |
| accumulated_body = dom_text | |
| yield parsed_data | |
| if parsed_data.get("done") is True: | |
| if ( | |
| accumulated_reason_len > 0 | |
| and len(accumulated_body) == 0 | |
| and not has_seen_functions | |
| ): | |
| logger.info( | |
| f"[{req_id}] ⚠️ Thinking-Only response detected. Waiting for DOM..." | |
| ) | |
| try: | |
| if page: | |
| pc = PageController(page, logger, req_id) | |
| for _ in range(20): | |
| await asyncio.sleep(0.5) | |
| dom_text = ( | |
| await pc.get_body_text_only_from_dom() | |
| ) | |
| if dom_text and len(dom_text.strip()) > 0: | |
| logger.info( | |
| f"[{req_id}] ✅ DOM captured body: {len(dom_text)} chars" | |
| ) | |
| yield { | |
| "body": dom_text, | |
| "reason": "", | |
| "done": False, | |
| } | |
| break | |
| except Exception as e: | |
| logger.error(f"[{req_id}] DOM Wait Error: {e}") | |
| break | |
| else: | |
| stale_done_ignored = False | |
| continue | |
| except (queue.Empty, asyncio.QueueEmpty): | |
| empty_count += 1 | |
| if ( | |
| enable_silence_detection | |
| and received_items_count >= min_items_before_silence_check | |
| and time.time() - last_packet_time > silence_threshold | |
| ): | |
| logger.info(f"[{req_id}] 🔇 Stream silence detected.") | |
| yield { | |
| "done": True, | |
| "reason": "silence_detected", | |
| "body": "", | |
| "function": [], | |
| } | |
| return | |
| if empty_count % 50 == 0: | |
| logger.info( | |
| f"[{req_id}] Waiting for data... ({empty_count}/{max_empty_retries})" | |
| ) | |
| if empty_count >= max_empty_retries: | |
| if GlobalState.IS_RECOVERING: | |
| empty_count = 0 | |
| continue | |
| if ( | |
| await check_ui_generation_active() | |
| and empty_count < hard_timeout_limit | |
| ): | |
| logger.warning(f"[{req_id}] Timeout but UI active. Snoozing...") | |
| empty_count = max(0, empty_count - int(max_empty_retries * 0.5)) | |
| continue | |
| elif empty_count >= hard_timeout_limit: | |
| logger.error(f"[{req_id}] HARD TIMEOUT REACHED!") | |
| yield { | |
| "done": True, | |
| "reason": "hard_timeout", | |
| "body": "", | |
| "function": [], | |
| } | |
| return | |
| yield { | |
| "done": True, | |
| "reason": "internal_timeout", | |
| "body": "", | |
| "function": [], | |
| } | |
| return | |
| if check_client_disconnected: | |
| try: | |
| check_client_disconnected(f"Stream Queue Wait ({req_id})") | |
| except ClientDisconnectedError: | |
| raise | |
| if received_items_count == 0 and empty_count >= initial_wait_limit: | |
| logger.error(f"[{req_id}] Stream has no data (TTFB Timeout).") | |
| try: | |
| from api_utils.server_state import state | |
| page_instance = state.page_instance | |
| if page_instance: | |
| await page_instance.reload() | |
| except Exception: | |
| pass | |
| yield { | |
| "done": True, | |
| "reason": "ttfb_timeout", | |
| "body": "", | |
| "function": [], | |
| } | |
| return | |
| if empty_count - last_ui_check_time >= ui_check_interval: | |
| if await check_ui_generation_active(): | |
| logger.info(f"[{req_id}] UI detected still generating...") | |
| last_ui_check_time = empty_count | |
| await asyncio.sleep(0.1) | |
| continue | |
| except asyncio.CancelledError: | |
| raise | |
| except Exception as e: | |
| if isinstance(e, ClientDisconnectedError): | |
| raise e | |
| logger.error(f"[{req_id}] Error in stream generator: {e}", exc_info=True) | |
| raise | |
| finally: | |
| logger.info( | |
| f"[{req_id}] Stream response completed. Items: {received_items_count}" | |
| ) | |
| await clear_stream_queue() | |
| async def clear_stream_queue(): | |
| import queue | |
| from api_utils.server_state import state | |
| STREAM_QUEUE = state.STREAM_QUEUE | |
| logger = state.logger | |
| if STREAM_QUEUE is None: | |
| return | |
| cleared_count = 0 | |
| while True: | |
| try: | |
| await asyncio.to_thread(STREAM_QUEUE.get_nowait) | |
| cleared_count += 1 | |
| except queue.Empty: | |
| break | |
| except Exception: | |
| break | |
| if cleared_count > 0: | |
| logger.info(f"Stream queue cleared. Items: {cleared_count}") | |
| async def detect_function_calls_from_dom( | |
| page: Any, | |
| req_id: str, | |
| logger: Any, | |
| ) -> Tuple[List[dict], str]: | |
| """Fallback function call detection using DOM parsing. | |
| This is used when the network interceptor doesn't capture function calls | |
| (e.g., due to timing issues or format changes). | |
| Args: | |
| page: Playwright page instance. | |
| req_id: Request ID for logging. | |
| logger: Logger instance. | |
| Returns: | |
| Tuple of (List of function call dicts, text content). | |
| """ | |
| if not page: | |
| return [], "" | |
| try: | |
| from api_utils.utils_ext.function_call_response_parser import ( | |
| FunctionCallResponseParser, | |
| ) | |
| parser = FunctionCallResponseParser(page, logger, req_id) | |
| result = await parser.parse_function_calls() | |
| function_calls: List[dict] = [] | |
| if result.has_function_calls and result.function_calls: | |
| # Convert ParsedFunctionCall objects to dict format expected by stream | |
| for fc in result.function_calls: | |
| function_calls.append({"name": fc.name, "params": fc.arguments}) | |
| if FUNCTION_CALLING_DEBUG: | |
| logger.info( | |
| f"[{req_id}] DOM fallback detected {len(function_calls)} function call(s)" | |
| ) | |
| return function_calls, result.text_content | |
| except Exception as e: | |
| if FUNCTION_CALLING_DEBUG: | |
| logger.debug(f"[{req_id}] DOM function call detection failed: {e}") | |
| return [], "" | |