| | import asyncio |
| | import contextlib |
| | import json |
| | import logging |
| | import time |
| | import uuid |
| | from dataclasses import dataclass |
| | from enum import Enum |
| | from io import BytesIO |
| | from typing import Any, Callable, Iterable, Literal, Optional, Type, TypeVar, Union |
| | from urllib.parse import urljoin, urlparse |
| |
|
| | import aiohttp |
| | from aiohttp.client_exceptions import ClientError, ContentTypeError |
| | from pydantic import BaseModel |
| |
|
| | from comfy import utils |
| | from comfy_api.latest import IO |
| | from comfy_api_nodes.apis import request_logger |
| | from server import PromptServer |
| |
|
| | from ._helpers import ( |
| | default_base_url, |
| | get_auth_header, |
| | get_node_id, |
| | is_processing_interrupted, |
| | sleep_with_interrupt, |
| | ) |
| | from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInterrupted |
| |
|
# Generic response-model type: ``sync_op`` / ``poll_op`` take ``response_model: Type[M]``
# and return an instance of that same pydantic model class.
M = TypeVar("M", bound=BaseModel)
| |
|
| |
|
class ApiEndpoint:
    """Describe one HTTP endpoint: path, verb, default query parameters and extra headers.

    ``path`` may be relative (the request layer joins it onto the default API base URL)
    or an absolute URL.
    """

    def __init__(
        self,
        path: str,
        method: Literal["GET", "POST", "PUT", "DELETE", "PATCH"] = "GET",
        *,
        query_params: Optional[dict[str, Any]] = None,
        headers: Optional[dict[str, str]] = None,
    ):
        self.path = path
        self.method = method
        # Falsy inputs (None or an empty mapping) become a fresh empty dict;
        # non-empty mappings are stored as passed (not copied).
        self.query_params = query_params if query_params else {}
        self.headers = headers if headers else {}
| |
|
| |
|
@dataclass
class _RequestConfig:
    """Everything ``_request_base`` needs to perform (and retry) one HTTP request."""

    # Node whose UI receives progress text and whose auth header is attached.
    node_cls: type[IO.ComfyNode]
    # Path (relative or absolute URL), method, query params and extra headers.
    endpoint: ApiEndpoint
    # Total aiohttp timeout for a single attempt, in seconds.
    timeout: float
    # Selects body encoding: multipart, url-encoded, or JSON for anything else.
    content_type: str
    # Body fields; for GET requests they are merged into the query string instead.
    data: Optional[dict[str, Any]]
    # Multipart files: field -> file object or (filename, file[, content_type]).
    files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]]
    # Optional builder turning ``data`` into an aiohttp.FormData.
    multipart_parser: Optional[Callable]
    # Number of retries allowed after the first attempt for transient failures.
    max_retries: int
    # Initial sleep between attempts, in seconds.
    retry_delay: float
    # Factor applied to the delay after each retry.
    retry_backoff: float
    # Status text shown while waiting for the server.
    wait_label: str = "Waiting"
    # Show per-second elapsed time and watch for user interruption.
    monitor_progress: bool = True
    # Expected total seconds, used for the "~Ns remaining" hint.
    estimated_total: Optional[int] = None
    # Status shown after success; None skips the final update.
    final_label_on_success: Optional[str] = "Completed"
    # monotonic() origin for elapsed time; None means "when the request starts".
    progress_origin_ts: Optional[float] = None
| |
|
| |
|
| | @dataclass |
| | class _PollUIState: |
| | started: float |
| | status_label: str = "Queued" |
| | is_queued: bool = True |
| | price: Optional[float] = None |
| | estimated_duration: Optional[int] = None |
| | base_processing_elapsed: float = 0.0 |
| | active_since: Optional[float] = None |
| |
|
| |
|
| | _RETRY_STATUS = {408, 429, 500, 502, 503, 504} |
| | COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed"] |
| | FAILED_STATUSES = ["cancelled", "canceled", "failed", "error"] |
| | QUEUED_STATUSES = ["created", "queued", "queueing", "submitted"] |
| |
|
| |
|
async def sync_op(
    cls: type[IO.ComfyNode],
    endpoint: ApiEndpoint,
    *,
    response_model: Type[M],
    data: Optional[BaseModel] = None,
    files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None,
    content_type: str = "application/json",
    timeout: float = 3600.0,
    multipart_parser: Optional[Callable] = None,
    max_retries: int = 3,
    retry_delay: float = 1.0,
    retry_backoff: float = 2.0,
    wait_label: str = "Waiting for server",
    estimated_duration: Optional[int] = None,
    final_label_on_success: Optional[str] = "Completed",
    progress_origin_ts: Optional[float] = None,
    monitor_progress: bool = True,
) -> M:
    """Make a single JSON request and validate the response into ``response_model``.

    All transport options are forwarded unchanged to ``sync_op_raw`` (with
    ``as_binary=False``).

    Raises:
        Exception: the response body was not JSON, or it failed model validation
            (in addition to anything ``sync_op_raw`` raises).
    """
    raw = await sync_op_raw(
        cls,
        endpoint,
        data=data,
        files=files,
        content_type=content_type,
        timeout=timeout,
        multipart_parser=multipart_parser,
        max_retries=max_retries,
        retry_delay=retry_delay,
        retry_backoff=retry_backoff,
        wait_label=wait_label,
        estimated_duration=estimated_duration,
        as_binary=False,
        final_label_on_success=final_label_on_success,
        progress_origin_ts=progress_origin_ts,
        monitor_progress=monitor_progress,
    )
    # With as_binary=False, sync_op_raw yields the decoded JSON, or {"_raw": text} when the
    # body was not JSON. Only reject genuinely non-JSON payloads: JSON arrays are valid
    # input for some models (e.g. a pydantic RootModel[list[...]]), so let the model decide.
    is_raw_text = isinstance(raw, dict) and raw.keys() == {"_raw"}
    if isinstance(raw, (bytes, bytearray)) or is_raw_text:
        raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).")
    return _validate_or_raise(response_model, raw)
| |
|
| |
|
async def poll_op(
    cls: type[IO.ComfyNode],
    poll_endpoint: ApiEndpoint,
    *,
    response_model: Type[M],
    status_extractor: Callable[[M], Optional[Union[str, int]]],
    progress_extractor: Optional[Callable[[M], Optional[int]]] = None,
    price_extractor: Optional[Callable[[M], Optional[float]]] = None,
    completed_statuses: Optional[list[Union[str, int]]] = None,
    failed_statuses: Optional[list[Union[str, int]]] = None,
    queued_statuses: Optional[list[Union[str, int]]] = None,
    data: Optional[BaseModel] = None,
    poll_interval: float = 5.0,
    max_poll_attempts: int = 120,
    timeout_per_poll: float = 120.0,
    max_retries_per_poll: int = 3,
    retry_delay_per_poll: float = 1.0,
    retry_backoff_per_poll: float = 2.0,
    estimated_duration: Optional[int] = None,
    cancel_endpoint: Optional[ApiEndpoint] = None,
    cancel_timeout: float = 10.0,
) -> M:
    """Typed front-end for ``poll_op_raw``.

    The model-level extractors are adapted to the dict-based poller, and the final
    poll response is validated into ``response_model``.
    """
    # Adapt each typed extractor so it can consume the raw JSON dicts the poller sees.
    status_fn = _wrap_model_extractor(response_model, status_extractor)
    progress_fn = _wrap_model_extractor(response_model, progress_extractor)
    price_fn = _wrap_model_extractor(response_model, price_extractor)

    final_json = await poll_op_raw(
        cls,
        poll_endpoint=poll_endpoint,
        status_extractor=status_fn,
        progress_extractor=progress_fn,
        price_extractor=price_fn,
        completed_statuses=completed_statuses,
        failed_statuses=failed_statuses,
        queued_statuses=queued_statuses,
        data=data,
        poll_interval=poll_interval,
        max_poll_attempts=max_poll_attempts,
        timeout_per_poll=timeout_per_poll,
        max_retries_per_poll=max_retries_per_poll,
        retry_delay_per_poll=retry_delay_per_poll,
        retry_backoff_per_poll=retry_backoff_per_poll,
        estimated_duration=estimated_duration,
        cancel_endpoint=cancel_endpoint,
        cancel_timeout=cancel_timeout,
    )
    if isinstance(final_json, dict):
        return _validate_or_raise(response_model, final_json)
    raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).")
| |
|
| |
|
def _enum_values_deep(value: Any) -> Any:
    """Return ``value`` with every Enum member, at any depth of dicts/lists/tuples, replaced by its ``.value``."""
    if isinstance(value, Enum):
        return value.value
    if isinstance(value, dict):
        return {key: _enum_values_deep(item) for key, item in value.items()}
    if isinstance(value, list):
        return [_enum_values_deep(item) for item in value]
    if isinstance(value, tuple):
        return tuple(_enum_values_deep(item) for item in value)
    return value


async def sync_op_raw(
    cls: type[IO.ComfyNode],
    endpoint: ApiEndpoint,
    *,
    data: Optional[Union[dict[str, Any], BaseModel]] = None,
    files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None,
    content_type: str = "application/json",
    timeout: float = 3600.0,
    multipart_parser: Optional[Callable] = None,
    max_retries: int = 3,
    retry_delay: float = 1.0,
    retry_backoff: float = 2.0,
    wait_label: str = "Waiting for server",
    estimated_duration: Optional[int] = None,
    as_binary: bool = False,
    final_label_on_success: Optional[str] = "Completed",
    progress_origin_ts: Optional[float] = None,
    monitor_progress: bool = True,
) -> Union[dict[str, Any], bytes]:
    """
    Make a single network request.
    - If as_binary=False (default): returns the decoded JSON (normally a dict), or
      {'_raw': '<text>'} if the body is not JSON.
    - If as_binary=True: returns bytes.

    A pydantic ``data`` model is dumped with ``exclude_none=True``, and all Enum members
    in the dump (including nested ones) are replaced by their values.
    """
    if isinstance(data, BaseModel):
        # model_dump() in python mode keeps Enum members as-is at every nesting level.
        # Convert them all so the JSON / form encoders only receive plain values.
        data = _enum_values_deep(data.model_dump(exclude_none=True))
    cfg = _RequestConfig(
        node_cls=cls,
        endpoint=endpoint,
        timeout=timeout,
        content_type=content_type,
        data=data,
        files=files,
        multipart_parser=multipart_parser,
        max_retries=max_retries,
        retry_delay=retry_delay,
        retry_backoff=retry_backoff,
        wait_label=wait_label,
        monitor_progress=monitor_progress,
        estimated_total=estimated_duration,
        final_label_on_success=final_label_on_success,
        progress_origin_ts=progress_origin_ts,
    )
    return await _request_base(cfg, expect_binary=as_binary)
| |
|
| |
|
async def poll_op_raw(
    cls: type[IO.ComfyNode],
    poll_endpoint: ApiEndpoint,
    *,
    status_extractor: Callable[[dict[str, Any]], Optional[Union[str, int]]],
    progress_extractor: Optional[Callable[[dict[str, Any]], Optional[int]]] = None,
    price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None,
    completed_statuses: Optional[list[Union[str, int]]] = None,
    failed_statuses: Optional[list[Union[str, int]]] = None,
    queued_statuses: Optional[list[Union[str, int]]] = None,
    data: Optional[Union[dict[str, Any], BaseModel]] = None,
    poll_interval: float = 5.0,
    max_poll_attempts: int = 120,
    timeout_per_poll: float = 120.0,
    max_retries_per_poll: int = 3,
    retry_delay_per_poll: float = 1.0,
    retry_backoff_per_poll: float = 2.0,
    estimated_duration: Optional[int] = None,
    cancel_endpoint: Optional[ApiEndpoint] = None,
    cancel_timeout: float = 10.0,
) -> dict[str, Any]:
    """
    Poll ``poll_endpoint`` until the task reaches a completed or failed status.

    While polling, a background ticker refreshes the node's status text once per second
    (elapsed time, queued/processing state, last known price). If the user interrupts
    processing, ``cancel_endpoint`` (when given) is called best-effort before
    ``ProcessingInterrupted`` is re-raised.

    Statuses are normalized (strings stripped and lower-cased) before comparison. A status
    list argument of ``None`` selects the module defaults (COMPLETED_STATUSES,
    FAILED_STATUSES, QUEUED_STATUSES). Only polls that observe a non-queued status count
    toward ``max_poll_attempts``.

    Failures of ``status_extractor``, ``price_extractor`` and ``progress_extractor`` are
    logged and ignored for that poll. An unknown status simply keeps polling; a missing
    price or progress value only affects UI feedback.

    Returns:
        The final JSON response (a dict) from the poll endpoint.

    Raises:
        ProcessingInterrupted: the user interrupted processing.
        LocalNetworkError, ApiServerError: propagated unchanged from the HTTP layer.
        Exception: the task failed, polling timed out, or another error occurred
            (wrapped as "Polling aborted due to error: ...").
    """
    completed_states = _normalize_statuses(COMPLETED_STATUSES if completed_statuses is None else completed_statuses)
    failed_states = _normalize_statuses(FAILED_STATUSES if failed_statuses is None else failed_statuses)
    queued_states = _normalize_statuses(QUEUED_STATUSES if queued_statuses is None else queued_statuses)
    started = time.monotonic()
    consumed_attempts = 0

    progress_bar = utils.ProgressBar(100) if progress_extractor else None
    last_progress: Optional[int] = None

    state = _PollUIState(started=started, estimated_duration=estimated_duration)
    stop_ticker = asyncio.Event()

    async def _ticker():
        """Emit a UI update every second while polling is in progress."""
        try:
            while not stop_ticker.is_set():
                if is_processing_interrupted():
                    break
                now = time.monotonic()
                proc_elapsed = state.base_processing_elapsed + (
                    (now - state.active_since) if state.active_since is not None else 0.0
                )
                _display_time_progress(
                    cls,
                    status=state.status_label,
                    elapsed_seconds=int(now - state.started),
                    estimated_total=state.estimated_duration,
                    price=state.price,
                    is_queued=state.is_queued,
                    processing_elapsed_seconds=int(proc_elapsed),
                )
                await asyncio.sleep(1.0)
        except Exception as exc:
            logging.debug("Polling ticker exited: %s", exc)

    async def _cancel_remote_task() -> None:
        """Best-effort call to ``cancel_endpoint`` (if any); failures are swallowed."""
        if not cancel_endpoint:
            return
        with contextlib.suppress(Exception):
            await sync_op_raw(
                cls,
                cancel_endpoint,
                timeout=cancel_timeout,
                max_retries=0,
                wait_label="Cancelling task",
                estimated_duration=None,
                as_binary=False,
                final_label_on_success=None,
                monitor_progress=False,
            )

    def _extract_ui_value(extractor: Callable[[dict[str, Any]], Any], resp: dict[str, Any], what: str) -> Any:
        """Run a UI-only extractor; log and return None on failure instead of aborting the poll."""
        try:
            return extractor(resp)
        except Exception as exc:
            logging.warning("%s extraction failed: %s", what, exc)
            return None

    ticker_task = asyncio.create_task(_ticker())
    try:
        while consumed_attempts < max_poll_attempts:
            try:
                resp_json = await sync_op_raw(
                    cls,
                    poll_endpoint,
                    data=data,
                    timeout=timeout_per_poll,
                    max_retries=max_retries_per_poll,
                    retry_delay=retry_delay_per_poll,
                    retry_backoff=retry_backoff_per_poll,
                    wait_label="Checking",
                    estimated_duration=None,
                    as_binary=False,
                    final_label_on_success=None,
                    monitor_progress=False,
                )
                if not isinstance(resp_json, dict):
                    raise Exception("Polling endpoint returned non-JSON response.")
            except ProcessingInterrupted:
                await _cancel_remote_task()
                raise

            try:
                status = _normalize_status_value(status_extractor(resp_json))
            except Exception as e:
                logging.error("Status extraction failed: %s", e)
                status = None

            if price_extractor:
                new_price = _extract_ui_value(price_extractor, resp_json, "Price")
                if new_price is not None:
                    state.price = new_price

            if progress_extractor:
                new_progress = _extract_ui_value(progress_extractor, resp_json, "Progress")
                if new_progress is not None and last_progress != new_progress:
                    progress_bar.update_absolute(new_progress, total=100)
                    last_progress = new_progress

            now_ts = time.monotonic()
            is_queued = status in queued_states

            # Processing time only accrues while the task is active (not queued).
            if is_queued:
                if state.active_since is not None:  # active -> queued: close the span
                    state.base_processing_elapsed += now_ts - state.active_since
                    state.active_since = None
            else:
                if state.active_since is None:  # queued -> active: open a span
                    state.active_since = now_ts

            state.is_queued = is_queued
            state.status_label = status or ("Queued" if is_queued else "Processing")
            if status in completed_states:
                if state.active_since is not None:
                    state.base_processing_elapsed += now_ts - state.active_since
                    state.active_since = None
                stop_ticker.set()
                with contextlib.suppress(Exception):
                    await ticker_task

                if progress_bar and last_progress != 100:
                    progress_bar.update_absolute(100, total=100)

                _display_time_progress(
                    cls,
                    status=status if status else "Completed",
                    elapsed_seconds=int(now_ts - started),
                    estimated_total=estimated_duration,
                    price=state.price,
                    is_queued=False,
                    processing_elapsed_seconds=int(state.base_processing_elapsed),
                )
                return resp_json

            if status in failed_states:
                msg = f"Task failed: {json.dumps(resp_json)}"
                logging.error(msg)
                raise Exception(msg)

            try:
                await sleep_with_interrupt(poll_interval, cls, None, None, None)
            except ProcessingInterrupted:
                await _cancel_remote_task()
                raise
            # Time spent queued does not consume the attempt budget.
            if not is_queued:
                consumed_attempts += 1

        raise Exception(
            f"Polling timed out after {max_poll_attempts} non-queued attempts "
            f"(~{int(max_poll_attempts * poll_interval)}s of active polling)."
        )
    except ProcessingInterrupted:
        raise
    except (LocalNetworkError, ApiServerError):
        raise
    except Exception as e:
        raise Exception(f"Polling aborted due to error: {e}") from e
    finally:
        stop_ticker.set()
        with contextlib.suppress(Exception):
            await ticker_task
| |
|
| |
|
def _display_text(
    node_cls: type[IO.ComfyNode],
    text: Optional[str],
    *,
    status: Optional[Union[str, int]] = None,
    price: Optional[float] = None,
) -> None:
    """Push a multi-line progress message (status, price, free text) to ``node_cls``'s UI.

    Nothing is sent when all three parts are empty.
    """
    parts: list[str] = []
    if status:
        shown_status = status.capitalize() if isinstance(status, str) else status
        parts.append(f"Status: {shown_status}")
    if price is not None:
        parts.append(f"Price: ${float(price):,.4f}")
    if text is not None:
        parts.append(text)
    if not parts:
        return
    PromptServer.instance.send_progress_text("\n".join(parts), get_node_id(node_cls))
| |
|
| |
|
def _display_time_progress(
    node_cls: type[IO.ComfyNode],
    status: Optional[Union[str, int]],
    elapsed_seconds: int,
    estimated_total: Optional[int] = None,
    *,
    price: Optional[float] = None,
    is_queued: Optional[bool] = None,
    processing_elapsed_seconds: Optional[int] = None,
) -> None:
    """Show "Time elapsed: Ns", plus a "~Ns remaining" hint when it can be estimated.

    The hint appears only when a positive ``estimated_total`` is known and the task
    is explicitly not queued (``is_queued is False``). The remaining time is measured
    against ``processing_elapsed_seconds``, falling back to ``elapsed_seconds``.
    """
    with_estimate = estimated_total is not None and estimated_total > 0 and is_queued is False
    if not with_estimate:
        _display_text(node_cls, f"Time elapsed: {int(elapsed_seconds)}s", status=status, price=price)
        return
    spent = elapsed_seconds if processing_elapsed_seconds is None else processing_elapsed_seconds
    left = max(0, int(estimated_total) - int(spent))
    _display_text(
        node_cls,
        f"Time elapsed: {int(elapsed_seconds)}s (~{left}s remaining)",
        status=status,
        price=price,
    )
| |
|
| |
|
async def _diagnose_connectivity() -> dict[str, bool]:
    """Best-effort connectivity diagnostics to distinguish local vs. server issues.

    Returns a dict with two flags:
      * ``internet_accessible``: GET https://www.google.com answered with status < 500.
      * ``api_accessible``: GET ``<scheme>://<host>/health`` of ``default_base_url()``
        answered with status < 500. This is only probed when the internet check passed.

    Probe failures (client/OS errors and the 5-second total timeout) are swallowed and
    leave the corresponding flag False.
    """
    results = {
        "internet_accessible": False,
        "api_accessible": False,
    }
    # When the total timeout expires, aiohttp raises asyncio.TimeoutError. Before Python
    # 3.11 that is not an OSError subclass, so it must be listed explicitly. Otherwise it
    # escapes this "best-effort" probe and masks the caller's original network error.
    probe_errors = (ClientError, OSError, asyncio.TimeoutError)
    timeout = aiohttp.ClientTimeout(total=5.0)
    async with aiohttp.ClientSession(timeout=timeout) as session:
        with contextlib.suppress(*probe_errors):
            async with session.get("https://www.google.com") as resp:
                results["internet_accessible"] = resp.status < 500
        if not results["internet_accessible"]:
            return results

        parsed = urlparse(default_base_url())
        health_url = f"{parsed.scheme}://{parsed.netloc}/health"
        with contextlib.suppress(*probe_errors):
            async with session.get(health_url) as resp:
                results["api_accessible"] = resp.status < 500
    return results
| |
|
| |
|
| | def _unpack_tuple(t: tuple) -> tuple[str, Any, str]: |
| | """Normalize (filename, value, content_type).""" |
| | if len(t) == 2: |
| | return t[0], t[1], "application/octet-stream" |
| | if len(t) == 3: |
| | return t[0], t[1], t[2] |
| | raise ValueError("files tuple must be (filename, file[, content_type])") |
| |
|
| |
|
| | def _merge_params(endpoint_params: dict[str, Any], method: str, data: Optional[dict[str, Any]]) -> dict[str, Any]: |
| | params = dict(endpoint_params or {}) |
| | if method.upper() == "GET" and data: |
| | for k, v in data.items(): |
| | if v is not None: |
| | params[k] = v |
| | return params |
| |
|
| |
|
| | def _friendly_http_message(status: int, body: Any) -> str: |
| | if status == 401: |
| | return "Unauthorized: Please login first to use this node." |
| | if status == 402: |
| | return "Payment Required: Please add credits to your account to use this node." |
| | if status == 409: |
| | return "There is a problem with your account. Please contact support@comfy.org." |
| | if status == 429: |
| | return "Rate Limit Exceeded: Please try again later." |
| | try: |
| | if isinstance(body, dict): |
| | err = body.get("error") |
| | if isinstance(err, dict): |
| | msg = err.get("message") |
| | typ = err.get("type") |
| | if msg and typ: |
| | return f"API Error: {msg} (Type: {typ})" |
| | if msg: |
| | return f"API Error: {msg}" |
| | return f"API Error: {json.dumps(body)}" |
| | else: |
| | txt = str(body) |
| | if len(txt) <= 200: |
| | return f"API Error (raw): {txt}" |
| | return f"API Error (status {status})" |
| | except Exception: |
| | return f"HTTP {status}: Unknown error" |
| |
|
| |
|
| | def _generate_operation_id(method: str, path: str, attempt: int) -> str: |
| | slug = path.strip("/").replace("/", "_") or "op" |
| | return f"{method}_{slug}_try{attempt}_{uuid.uuid4().hex[:8]}" |
| |
|
| |
|
| | def _snapshot_request_body_for_logging( |
| | content_type: str, |
| | method: str, |
| | data: Optional[dict[str, Any]], |
| | files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]], |
| | ) -> Optional[Union[dict[str, Any], str]]: |
| | if method.upper() == "GET": |
| | return None |
| | if content_type == "multipart/form-data": |
| | form_fields = sorted([k for k, v in (data or {}).items() if v is not None]) |
| | file_fields: list[dict[str, str]] = [] |
| | if files: |
| | file_iter = files if isinstance(files, list) else list(files.items()) |
| | for field_name, file_obj in file_iter: |
| | if file_obj is None: |
| | continue |
| | if isinstance(file_obj, tuple): |
| | filename = file_obj[0] |
| | else: |
| | filename = getattr(file_obj, "name", field_name) |
| | file_fields.append({"field": field_name, "filename": str(filename or "")}) |
| | return {"_multipart": True, "form_fields": form_fields, "file_fields": file_fields} |
| | if content_type == "application/x-www-form-urlencoded": |
| | return data or {} |
| | return data or {} |
| |
|
| |
|
async def _request_base(cfg: _RequestConfig, expect_binary: bool):
    """Core request with retries, per-second interruption monitoring, true cancellation, and friendly errors.

    Behaviour:
      * Relative endpoint paths are joined onto ``default_base_url()`` and get the auth
        header. Absolute URLs are used as-is, without it.
      * Statuses in ``_RETRY_STATUS`` and connection-level failures (``ClientError``,
        ``OSError``, request timeouts) are retried up to ``cfg.max_retries`` times. The
        sleep starts at ``cfg.retry_delay`` and is multiplied by ``cfg.retry_backoff``
        after each retry.
      * With ``cfg.monitor_progress``, a monitor task refreshes the elapsed-time display
        every second and aborts the in-flight request when the user interrupts.
      * Each attempt and its outcome are recorded through ``request_logger`` (best-effort).

    Returns:
        ``bytes`` when ``expect_binary`` is true. Otherwise the decoded JSON body,
        ``{"_raw": text}`` for a non-JSON body, or ``{}`` for an empty body.

    Raises:
        ProcessingInterrupted: the user interrupted processing.
        LocalNetworkError: connection retries exhausted and the internet looks unreachable.
        ApiServerError: connection retries exhausted while the internet is reachable.
        Exception: a non-retryable HTTP error status (message from ``_friendly_http_message``).
    """
    url = cfg.endpoint.path
    parsed_url = urlparse(url)
    is_relative = not parsed_url.scheme and not parsed_url.netloc
    if is_relative:
        url = urljoin(default_base_url().rstrip("/") + "/", url.lstrip("/"))

    method = cfg.endpoint.method
    params = _merge_params(cfg.endpoint.query_params, method, cfg.data if method == "GET" else None)

    def _log_exchange(what: str, op_id: str, **fields: Any) -> None:
        """Forward one record to ``request_logger``; a logging failure never affects the request."""
        try:
            request_logger.log_request_response(
                operation_id=op_id,
                request_method=method,
                request_url=url,
                **fields,
            )
        except Exception as log_exc:
            logging.debug("[DEBUG] %s logging failed: %s", what, log_exc)

    async def _monitor(stop_evt: asyncio.Event, start_ts: float):
        """Every second: update elapsed time; return only when the user interrupts.

        The caller treats this task finishing as a user interruption. A failing progress
        display must therefore not end the task, or it would cancel a healthy request as
        "Task cancelled".
        """
        try:
            while not stop_evt.is_set():
                if is_processing_interrupted():
                    return
                if cfg.monitor_progress:
                    try:
                        _display_time_progress(
                            cfg.node_cls, cfg.wait_label, int(time.monotonic() - start_ts), cfg.estimated_total
                        )
                    except Exception as display_exc:
                        logging.debug("Progress display failed: %s", display_exc)
                await asyncio.sleep(1.0)
        except asyncio.CancelledError:
            return

    start_time = cfg.progress_origin_ts if cfg.progress_origin_ts is not None else time.monotonic()
    attempt = 0
    delay = cfg.retry_delay
    operation_succeeded: bool = False
    final_elapsed_seconds: Optional[int] = None
    while True:
        attempt += 1
        stop_event = asyncio.Event()
        monitor_task: Optional[asyncio.Task] = None
        sess: Optional[aiohttp.ClientSession] = None

        operation_id = _generate_operation_id(method, cfg.endpoint.path, attempt)
        logging.debug("[DEBUG] HTTP %s %s (attempt %d)", method, url, attempt)

        payload_headers = {"Accept": "*/*"}
        if is_relative:
            payload_headers.update(get_auth_header(cfg.node_cls))
        if cfg.endpoint.headers:
            payload_headers.update(cfg.endpoint.headers)

        payload_kw: dict[str, Any] = {"headers": payload_headers}
        if method == "GET":
            payload_headers.pop("Content-Type", None)
        request_body_log = _snapshot_request_body_for_logging(cfg.content_type, method, cfg.data, cfg.files)
        try:
            if cfg.monitor_progress:
                monitor_task = asyncio.create_task(_monitor(stop_event, start_time))

            timeout = aiohttp.ClientTimeout(total=cfg.timeout)
            sess = aiohttp.ClientSession(timeout=timeout)

            if cfg.content_type == "multipart/form-data" and method != "GET":
                # aiohttp generates the multipart Content-Type (with boundary) itself.
                payload_headers.pop("Content-Type", None)
                if cfg.multipart_parser and cfg.data:
                    form = cfg.multipart_parser(cfg.data)
                    if not isinstance(form, aiohttp.FormData):
                        raise ValueError("multipart_parser must return aiohttp.FormData")
                else:
                    form = aiohttp.FormData(default_to_multipart=True)
                    if cfg.data:
                        for k, v in cfg.data.items():
                            if v is None:
                                continue
                            form.add_field(k, str(v) if not isinstance(v, (bytes, bytearray)) else v)
                if cfg.files:
                    file_iter = cfg.files if isinstance(cfg.files, list) else cfg.files.items()
                    for field_name, file_obj in file_iter:
                        if file_obj is None:
                            continue
                        if isinstance(file_obj, tuple):
                            filename, file_value, content_type = _unpack_tuple(file_obj)
                        else:
                            filename = getattr(file_obj, "name", field_name)
                            file_value = file_obj
                            content_type = "application/octet-stream"
                        # Rewind in-memory buffers so they are uploaded from the start
                        # (e.g. after a previous attempt consumed them).
                        if isinstance(file_value, BytesIO):
                            with contextlib.suppress(Exception):
                                file_value.seek(0)
                        form.add_field(field_name, file_value, filename=filename, content_type=content_type)
                payload_kw["data"] = form
            elif cfg.content_type == "application/x-www-form-urlencoded" and method != "GET":
                payload_headers["Content-Type"] = "application/x-www-form-urlencoded"
                payload_kw["data"] = cfg.data or {}
            elif method != "GET":
                payload_headers["Content-Type"] = "application/json"
                payload_kw["json"] = cfg.data or {}

            _log_exchange(
                "request",
                operation_id,
                request_headers=dict(payload_headers) if payload_headers else None,
                request_params=dict(params) if params else None,
                request_data=request_body_log,
            )

            req_task = asyncio.create_task(sess.request(method, url, params=params, **payload_kw))

            # Race the request against the interruption monitor (when there is one).
            tasks = {req_task}
            if monitor_task:
                tasks.add(monitor_task)
            done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)

            if monitor_task and monitor_task in done:
                # The monitor only finishes on user interruption: abort the in-flight request.
                if req_task in pending:
                    req_task.cancel()
                raise ProcessingInterrupted("Task cancelled")

            resp = await req_task
            async with resp:
                if resp.status >= 400:
                    try:
                        body = await resp.json()
                    except (ContentTypeError, json.JSONDecodeError):
                        body = await resp.text()
                    if resp.status in _RETRY_STATUS and attempt <= cfg.max_retries:
                        logging.warning(
                            "HTTP %s %s -> %s. Retrying in %.2fs (retry %d of %d).",
                            method,
                            url,
                            resp.status,
                            delay,
                            attempt,
                            cfg.max_retries,
                        )
                        _log_exchange(
                            "response",
                            operation_id,
                            response_status_code=resp.status,
                            response_headers=dict(resp.headers),
                            response_content=body,
                            error_message=_friendly_http_message(resp.status, body),
                        )
                        await sleep_with_interrupt(
                            delay,
                            cfg.node_cls,
                            cfg.wait_label if cfg.monitor_progress else None,
                            start_time if cfg.monitor_progress else None,
                            cfg.estimated_total,
                            display_callback=_display_time_progress if cfg.monitor_progress else None,
                        )
                        delay *= cfg.retry_backoff
                        continue
                    msg = _friendly_http_message(resp.status, body)
                    _log_exchange(
                        "response",
                        operation_id,
                        response_status_code=resp.status,
                        response_headers=dict(resp.headers),
                        response_content=body,
                        error_message=msg,
                    )
                    raise Exception(msg)

                if expect_binary:
                    buff = bytearray()
                    last_tick = time.monotonic()
                    async for chunk in resp.content.iter_chunked(64 * 1024):
                        buff.extend(chunk)
                        now = time.monotonic()
                        # About once per second, honour interruption and refresh the display.
                        if now - last_tick >= 1.0:
                            last_tick = now
                            if is_processing_interrupted():
                                raise ProcessingInterrupted("Task cancelled")
                            if cfg.monitor_progress:
                                _display_time_progress(
                                    cfg.node_cls, cfg.wait_label, int(now - start_time), cfg.estimated_total
                                )
                    bytes_payload = bytes(buff)
                    operation_succeeded = True
                    final_elapsed_seconds = int(time.monotonic() - start_time)
                    _log_exchange(
                        "response",
                        operation_id,
                        response_status_code=resp.status,
                        response_headers=dict(resp.headers),
                        response_content=bytes_payload,
                    )
                    return bytes_payload
                else:
                    try:
                        payload = await resp.json()
                        response_content_to_log: Any = payload
                    except (ContentTypeError, json.JSONDecodeError):
                        text = await resp.text()
                        try:
                            payload = json.loads(text) if text else {}
                        except json.JSONDecodeError:
                            payload = {"_raw": text}
                        response_content_to_log = payload if isinstance(payload, dict) else text
                    operation_succeeded = True
                    final_elapsed_seconds = int(time.monotonic() - start_time)
                    _log_exchange(
                        "response",
                        operation_id,
                        response_status_code=resp.status,
                        response_headers=dict(resp.headers),
                        response_content=response_content_to_log,
                    )
                    return payload

        except ProcessingInterrupted:
            logging.debug("Polling was interrupted by user")
            raise
        except (ClientError, OSError, asyncio.TimeoutError) as e:
            # asyncio.TimeoutError (total-timeout expiry) is listed explicitly because it is
            # only an OSError subclass from Python 3.11 on.
            if attempt <= cfg.max_retries:
                logging.warning(
                    "Connection error calling %s %s. Retrying in %.2fs (%d/%d): %s",
                    method,
                    url,
                    delay,
                    attempt,
                    cfg.max_retries,
                    str(e),
                )
                _log_exchange(
                    "request error",
                    operation_id,
                    request_headers=dict(payload_headers) if payload_headers else None,
                    request_params=dict(params) if params else None,
                    request_data=request_body_log,
                    error_message=f"{type(e).__name__}: {str(e)} (will retry)",
                )
                await sleep_with_interrupt(
                    delay,
                    cfg.node_cls,
                    cfg.wait_label if cfg.monitor_progress else None,
                    start_time if cfg.monitor_progress else None,
                    cfg.estimated_total,
                    display_callback=_display_time_progress if cfg.monitor_progress else None,
                )
                delay *= cfg.retry_backoff
                continue
            diag = await _diagnose_connectivity()
            if not diag["internet_accessible"]:
                _log_exchange(
                    "final error",
                    operation_id,
                    request_headers=dict(payload_headers) if payload_headers else None,
                    request_params=dict(params) if params else None,
                    request_data=request_body_log,
                    error_message=f"LocalNetworkError: {str(e)}",
                )
                raise LocalNetworkError(
                    "Unable to connect to the API server due to local network issues. "
                    "Please check your internet connection and try again."
                ) from e
            _log_exchange(
                "final error",
                operation_id,
                request_headers=dict(payload_headers) if payload_headers else None,
                request_params=dict(params) if params else None,
                request_data=request_body_log,
                error_message=f"ApiServerError: {str(e)}",
            )
            raise ApiServerError(
                f"The API server at {default_base_url()} is currently unreachable. "
                f"The service may be experiencing issues."
            ) from e
        finally:
            stop_event.set()
            if monitor_task:
                monitor_task.cancel()
                # asyncio.wait() never re-raises the task's own outcome. Awaiting the task
                # directly would raise CancelledError (a BaseException) if it was cancelled
                # before its first step, e.g. when payload construction failed synchronously,
                # and that would mask the real error.
                await asyncio.wait({monitor_task})
                if not monitor_task.cancelled():
                    monitor_task.exception()  # mark any unexpected exception as retrieved
            if sess:
                with contextlib.suppress(Exception):
                    await sess.close()
            if operation_succeeded and cfg.monitor_progress and cfg.final_label_on_success:
                _display_time_progress(
                    cfg.node_cls,
                    status=cfg.final_label_on_success,
                    elapsed_seconds=(
                        final_elapsed_seconds
                        if final_elapsed_seconds is not None
                        else int(time.monotonic() - start_time)
                    ),
                    estimated_total=cfg.estimated_total,
                    price=None,
                    is_queued=False,
                    processing_elapsed_seconds=final_elapsed_seconds,
                )
| |
|
| |
|
def _validate_or_raise(response_model: Type[M], payload: Any) -> M:
    """Validate ``payload`` into ``response_model``.

    On failure the error is logged and re-raised as a plain ``Exception`` naming the
    model, chained to the original validation error.
    """
    model_name = getattr(response_model, "__name__", response_model)
    try:
        return response_model.model_validate(payload)
    except Exception as exc:
        logging.error("Response validation failed for %s: %s", model_name, exc)
        raise Exception(f"Response validation failed for {model_name}: {exc}") from exc
| |
|
| |
|
| | def _wrap_model_extractor( |
| | response_model: Type[M], |
| | extractor: Optional[Callable[[M], Any]], |
| | ) -> Optional[Callable[[dict[str, Any]], Any]]: |
| | """Wrap a typed extractor so it can be used by the dict-based poller. |
| | Validates the dict into `response_model` before invoking `extractor`. |
| | Uses a small per-wrapper cache keyed by `id(dict)` to avoid re-validating |
| | the same response for multiple extractors in a single poll attempt. |
| | """ |
| | if extractor is None: |
| | return None |
| | _cache: dict[int, M] = {} |
| |
|
| | def _wrapped(d: dict[str, Any]) -> Any: |
| | try: |
| | key = id(d) |
| | model = _cache.get(key) |
| | if model is None: |
| | model = response_model.model_validate(d) |
| | _cache[key] = model |
| | return extractor(model) |
| | except Exception as e: |
| | logging.error("Extractor failed (typed -> dict wrapper): %s", e) |
| | raise |
| |
|
| | return _wrapped |
| |
|
| |
|
| | def _normalize_statuses(values: Optional[Iterable[Union[str, int]]]) -> set[Union[str, int]]: |
| | if not values: |
| | return set() |
| | out: set[Union[str, int]] = set() |
| | for v in values: |
| | nv = _normalize_status_value(v) |
| | if nv is not None: |
| | out.add(nv) |
| | return out |
| |
|
| |
|
| | def _normalize_status_value(val: Union[str, int, None]) -> Union[str, int, None]: |
| | if isinstance(val, str): |
| | return val.strip().lower() |
| | return val |
| |
|