| | """ |
| | API Client Framework for api.comfy.org. |
| | |
| | This module provides a flexible framework for making API requests from ComfyUI nodes. |
| | It supports both synchronous and asynchronous API operations with proper type validation. |
| | |
| | Key Components: |
| | -------------- |
| | 1. ApiClient - Handles HTTP requests with authentication and error handling |
| | 2. ApiEndpoint - Defines a single HTTP endpoint with its request/response models |
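| | 3. SynchronousOperation - Executes a single synchronous API operation |
| | 4. PollingOperation - Polls an endpoint until the task reports a completed or failed status |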
| | |
| | Usage Examples: |
| | -------------- |
| | |
| | # Example 1: Synchronous API Operation |
| | # ------------------------------------ |
| | # For a simple API call that returns the result immediately: |
| | |
| | # 1. Create the API client |
| | api_client = ApiClient( |
| | base_url="https://api.example.com", |
| | auth_token="your_auth_token_here", |
| | comfy_api_key="your_comfy_api_key_here", |
| | timeout=30.0, |
| | verify_ssl=True |
| | ) |
| | |
| | # 2. Define the endpoint |
| | user_info_endpoint = ApiEndpoint( |
| | path="/v1/users/me", |
| | method=HttpMethod.GET, |
| | request_model=EmptyRequest, # No request body needed |
| | response_model=UserProfile, # Pydantic model for the response |
| | query_params=None |
| | ) |
| | |
| | # 3. Create the request object |
| | request = EmptyRequest() |
| | |
| | # 4. Create and execute the operation |
| | operation = SynchronousOperation( |
| | endpoint=user_info_endpoint, |
| | request=request |
| | ) |
| | user_profile = operation.execute(client=api_client) # Returns immediately with the result |
| | |
| | |
| | # Example 2: Asynchronous API Operation with Polling |
| | # ------------------------------------------------- |
| | # For an API that starts a task and requires polling for completion: |
| | |
| | # 1. Define the endpoints (initial request and polling) |
| | generate_image_endpoint = ApiEndpoint( |
| | path="/v1/images/generate", |
| | method=HttpMethod.POST, |
| | request_model=ImageGenerationRequest, |
| | response_model=TaskCreatedResponse, |
| | query_params=None |
| | ) |
| | |
| | check_task_endpoint = ApiEndpoint( |
| | path="/v1/tasks/{task_id}", |
| | method=HttpMethod.GET, |
| | request_model=EmptyRequest, |
| | response_model=ImageGenerationResult, |
| | query_params=None |
| | ) |
| | |
| | # 2. Create the request object |
| | request = ImageGenerationRequest( |
| | prompt="a beautiful sunset over mountains", |
| | width=1024, |
| | height=1024, |
| | num_images=1 |
| | ) |
| | |
| | # 3. Create and execute the polling operation |
| | operation = PollingOperation( |
| | initial_endpoint=generate_image_endpoint, |
| | initial_request=request, |
| | poll_endpoint=check_task_endpoint, |
| | task_id_field="task_id", |
| | status_field="status", |
| | completed_statuses=["completed"], |
| | failed_statuses=["failed", "error"] |
| | ) |
| | |
| | # This will make the initial request and then poll until completion |
| | result = operation.execute(client=api_client) # Returns the final ImageGenerationResult when done |
| | """ |
| |
|
| | from __future__ import annotations |
| | import logging |
| | import time |
| | import io |
| | import socket |
| | from typing import Dict, Type, Optional, Any, TypeVar, Generic, Callable, Tuple |
| | from enum import Enum |
| | import json |
| | import requests |
| | from urllib.parse import urljoin, urlparse |
| | from pydantic import BaseModel, Field |
| | import uuid |
| |
|
| | from server import PromptServer |
| | from comfy.cli_args import args |
| | from comfy import utils |
| | from . import request_logger |
| |
|
# Generic type variables, each bound to pydantic's BaseModel.
# T and R parameterize the request and response models of ApiEndpoint and the
# operation classes defined later in this module.
# NOTE(review): P is not referenced in the visible part of the file - confirm it is still used.
T = TypeVar("T", bound=BaseModel)
R = TypeVar("R", bound=BaseModel)
P = TypeVar("P", bound=BaseModel)


# Maximum value handed to utils.ProgressBar when a polling operation reports progress.
PROGRESS_BAR_MAX = 100
| |
|
| |
|
class NetworkError(Exception):
    """Root of the network failure hierarchy raised by this module.

    Subclasses tell callers where the failure was diagnosed to originate.
    """
| |
|
| |
|
class LocalNetworkError(NetworkError):
    """Raised when a failure is attributed to the local network connection."""
| |
|
| |
|
class ApiServerError(NetworkError):
    """Raised when the internet is reachable but the API server is not."""
| |
|
| |
|
class EmptyRequest(BaseModel):
    """Base class for empty request bodies.
    For GET requests, fields will be sent as query parameters."""

    # Intentionally field-less: SynchronousOperation.execute sends no body
    # (data=None) when the request is an instance of this class.
| |
|
| |
|
# Request model describing a file the caller wants to upload.
# (Documented with comments rather than a docstring so the generated
# pydantic schema stays exactly as before.)
class UploadRequest(BaseModel):
    # Required: name of the file being uploaded.
    file_name: str = Field(description="Filename to upload")
    # Optional mime type; omitted from the payload when left as None.
    content_type: Optional[str] = Field(
        default=None,
        description="Mime type of the file. For example: image/png, image/jpeg, video/mp4, etc.",
    )
| |
|
| |
|
# Response model carrying the two URLs associated with an upload.
# (Documented with comments rather than a docstring so the generated
# pydantic schema stays exactly as before.)
class UploadResponse(BaseModel):
    # Required: where the uploaded file can later be fetched from.
    download_url: str = Field(description="URL to GET uploaded file")
    # Required: where the file bytes must be PUT.
    upload_url: str = Field(description="URL to PUT file to upload")
| |
|
| |
|
class HttpMethod(str, Enum):
    """HTTP verbs supported by the client.

    Members are also ``str`` instances, so they compare equal to their
    plain-text verb.
    """

    GET = "GET"
    POST = "POST"
    PUT = "PUT"
    DELETE = "DELETE"
    PATCH = "PATCH"
| |
|
| |
|
class ApiClient:
    """
    Client for making HTTP requests to an API with authentication, error handling, and retry logic.

    Requests that fail with a connection error, a timeout, or a status code
    listed in ``retry_status_codes`` are retried up to ``max_retries`` times.
    The wait before retry number ``n`` (0-based) is
    ``retry_delay * retry_backoff_factor ** n`` seconds.
    """

    def __init__(
        self,
        base_url: str,
        auth_token: Optional[str] = None,
        comfy_api_key: Optional[str] = None,
        timeout: float = 3600.0,
        verify_ssl: bool = True,
        max_retries: int = 3,
        retry_delay: float = 1.0,
        retry_backoff_factor: float = 2.0,
        retry_status_codes: Optional[Tuple[int, ...]] = None,
    ):
        """Store connection, authentication and retry settings.

        Args:
            base_url: Base URL that request paths are joined onto.
            auth_token: Bearer token; takes precedence over ``comfy_api_key``.
            comfy_api_key: API key sent as ``X-API-KEY`` when no token is set.
            timeout: Per-request timeout in seconds.
            verify_ssl: Whether TLS certificates are verified.
            max_retries: Maximum number of retries after the first attempt.
            retry_delay: Delay before the first retry, in seconds.
            retry_backoff_factor: Multiplier applied to the delay per retry.
            retry_status_codes: Status codes considered transient. Defaults
                to (408, 429, 500, 502, 503, 504).
        """
        self.base_url = base_url
        self.auth_token = auth_token
        self.comfy_api_key = comfy_api_key
        self.timeout = timeout
        self.verify_ssl = verify_ssl
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.retry_backoff_factor = retry_backoff_factor
        self.retry_status_codes = retry_status_codes or (408, 429, 500, 502, 503, 504)

    def _generate_operation_id(self, path: str) -> str:
        """Generates a unique operation ID for logging."""
        return f"{path.strip('/').replace('/', '_')}_{uuid.uuid4().hex[:8]}"

    def _create_json_payload_args(
        self,
        data: Optional[Dict[str, Any]] = None,
        headers: Optional[Dict[str, str]] = None,
    ) -> Dict[str, Any]:
        """Build the ``requests`` keyword arguments for a JSON body."""
        return {
            "json": data,
            "headers": headers,
        }

    def _create_form_data_args(
        self,
        data: Dict[str, Any],
        files: Dict[str, Any],
        headers: Optional[Dict[str, str]] = None,
        multipart_parser=None,
    ) -> Dict[str, Any]:
        """Build the ``requests`` keyword arguments for a multipart form.

        Any ``Content-Type`` entry is removed from ``headers`` (in place) so
        that ``requests`` can generate the header including the boundary.
        When given, ``multipart_parser`` is applied to ``data`` first.
        """
        if headers and "Content-Type" in headers:
            del headers["Content-Type"]

        if multipart_parser:
            data = multipart_parser(data)

        return {
            "data": data,
            "files": files,
            "headers": headers,
        }

    def _create_urlencoded_form_data_args(
        self,
        data: Dict[str, Any],
        headers: Optional[Dict[str, str]] = None,
    ) -> Dict[str, Any]:
        """Build the ``requests`` keyword arguments for a URL-encoded form."""
        headers = headers or {}
        headers["Content-Type"] = "application/x-www-form-urlencoded"

        return {
            "data": data,
            "headers": headers,
        }

    def get_headers(self) -> Dict[str, str]:
        """Get headers for API requests, including authentication if available"""
        headers = {"Content-Type": "application/json", "Accept": "application/json"}

        if self.auth_token:
            headers["Authorization"] = f"Bearer {self.auth_token}"
        elif self.comfy_api_key:
            headers["X-API-KEY"] = self.comfy_api_key

        return headers

    def _check_connectivity(self, target_url: str) -> Dict[str, bool]:
        """
        Check connectivity to determine if network issues are local or server-related.

        Args:
            target_url: URL to check connectivity to

        Returns:
            Dictionary with connectivity status details
        """
        results = {
            "internet_accessible": False,
            "api_accessible": False,
            "is_local_issue": False,
            "is_api_issue": False,
        }

        # Step 1: can we reach a well-known public site at all?
        try:
            check_response = requests.get(
                "https://www.google.com", timeout=5.0, verify=self.verify_ssl
            )
            if check_response.status_code < 500:
                results["internet_accessible"] = True
        except (requests.RequestException, socket.error):
            results["internet_accessible"] = False
            results["is_local_issue"] = True
            return results

        # Step 2: probe the API host's /health endpoint.
        try:
            parsed_url = urlparse(target_url)
            api_base = f"{parsed_url.scheme}://{parsed_url.netloc}"

            api_response = requests.get(
                f"{api_base}/health", timeout=5.0, verify=self.verify_ssl
            )
            if api_response.status_code < 500:
                results["api_accessible"] = True
            else:
                results["api_accessible"] = False
                results["is_api_issue"] = True
        except requests.RequestException:
            results["api_accessible"] = False
            results["is_api_issue"] = True

        return results

    def _backoff_delay(self, retry_count: int) -> float:
        """Return the exponential-backoff delay (seconds) before retry ``retry_count``."""
        return self.retry_delay * (self.retry_backoff_factor ** retry_count)

    def _retry_request(
        self, reason: str, retry_count: int, request_kwargs: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Log ``reason``, sleep for the backoff delay, then re-issue the request."""
        delay = self._backoff_delay(retry_count)
        logging.warning(
            f"{reason}Retrying in {delay:.2f}s ({retry_count + 1}/{self.max_retries})"
        )
        time.sleep(delay)
        return self.request(**request_kwargs, retry_count=retry_count + 1)

    @staticmethod
    def _json_or_raw(response: Any) -> Any:
        """Return the decoded JSON body of ``response``, or its raw bytes if it is not JSON."""
        try:
            return response.json()
        except ValueError:
            # json.JSONDecodeError and the decode errors raised by requests are
            # all ValueError subclasses, whichever JSON backend is in use.
            return response.content

    @staticmethod
    def _describe_http_error(response: Any, status_code: Optional[int], fallback: str) -> str:
        """Derive a user-facing message from an HTTP error response body.

        Args:
            response: Object exposing ``content`` (bytes) and ``json()``, or None.
            status_code: Status code of the response, used for long raw bodies.
            fallback: Message returned when the response has no usable body.

        Returns:
            ``"API Error: ..."`` built from the JSON body, ``"API Error (raw...)"``
            for non-JSON bodies, or ``fallback`` when there is no body.
        """
        if response is None or not response.content:
            return fallback

        try:
            error_json = response.json()
        except ValueError:
            raw_content = response.content.decode(errors="ignore")
            if len(raw_content) < 200:
                return f"API Error (raw): {raw_content}"
            return f"API Error (raw, status {status_code})"

        if not isinstance(error_json, dict):
            return f"API Error: {str(error_json)}"

        # Only treat "error" as structured when it really is a mapping; a bare
        # string or null must not be indexed.
        error_detail = error_json.get("error")
        if isinstance(error_detail, dict) and "message" in error_detail:
            message = f"API Error: {error_detail['message']}"
            if "type" in error_detail:
                message += f" (Type: {error_detail['type']})"
            return message
        return f"API Error: {json.dumps(error_json)}"

    def request(
        self,
        method: str,
        path: str,
        params: Optional[Dict[str, Any]] = None,
        data: Optional[Dict[str, Any]] = None,
        files: Optional[Dict[str, Any]] = None,
        headers: Optional[Dict[str, str]] = None,
        content_type: str = "application/json",
        multipart_parser: Callable = None,
        retry_count: int = 0,
    ) -> Dict[str, Any]:
        """
        Make an HTTP request to the API with automatic retries for transient errors.

        Args:
            method: HTTP method (GET, POST, etc.)
            path: API endpoint path (will be joined with base_url)
            params: Query parameters
            data: body data
            files: Files to upload
            headers: Additional headers
            content_type: Content type of the request. Defaults to application/json.
            multipart_parser: Optional callable applied to ``data`` for multipart/form-data requests
            retry_count: Internal parameter for tracking retries, do not set manually

        Returns:
            Parsed JSON response, or an empty dict when the response has no body

        Raises:
            LocalNetworkError: If local network connectivity issues are detected
            ApiServerError: If the API server is unreachable but internet is working
            Exception: For other request failures
        """
        # The leading slash is stripped so urljoin resolves the path relative to base_url.
        url = urljoin(self.base_url, path.lstrip("/"))
        self.check_auth(self.auth_token, self.comfy_api_key)

        request_headers = self.get_headers()
        if headers:
            request_headers.update(headers)

        if files:
            # Let requests set Content-Type itself for file uploads.
            request_headers.pop("Content-Type", None)

        # NOTE(review): these debug lines include the authentication header value.
        logging.debug(f"[DEBUG] Request Headers: {request_headers}")
        logging.debug(f"[DEBUG] Files: {files}")
        logging.debug(f"[DEBUG] Params: {params}")
        logging.debug(f"[DEBUG] Data: {data}")

        if content_type == "application/x-www-form-urlencoded":
            payload_args = self._create_urlencoded_form_data_args(data, request_headers)
        elif content_type == "multipart/form-data":
            payload_args = self._create_form_data_args(
                data, files, request_headers, multipart_parser
            )
        else:
            payload_args = self._create_json_payload_args(data, request_headers)

        operation_id = self._generate_operation_id(path)
        request_logger.log_request_response(
            operation_id=operation_id,
            request_method=method,
            request_url=url,
            request_headers=request_headers,
            request_params=params,
            request_data=data if content_type == "application/json" else "[form-data or other]",
        )

        # Arguments replayed unchanged on every retry.
        retry_kwargs = {
            "method": method,
            "path": path,
            "params": params,
            "data": data,
            "files": files,
            "headers": headers,
            "content_type": content_type,
            "multipart_parser": multipart_parser,
        }
        can_retry = retry_count < self.max_retries

        try:
            response = requests.request(
                method=method,
                url=url,
                params=params,
                timeout=self.timeout,
                verify=self.verify_ssl,
                **payload_args,
            )

            # Transient status codes are retried before being turned into errors.
            if response.status_code in self.retry_status_codes and can_retry:
                return self._retry_request(
                    f"Request failed with status {response.status_code}. ",
                    retry_count,
                    retry_kwargs,
                )

            response.raise_for_status()

            request_logger.log_request_response(
                operation_id=operation_id,
                request_method=method,
                request_url=url,
                response_status_code=response.status_code,
                response_headers=dict(response.headers),
                response_content=self._json_or_raw(response),
            )

        except requests.ConnectionError as e:
            request_logger.log_request_response(
                operation_id=operation_id,
                request_method=method,
                request_url=url,
                error_message=f"ConnectionError: {str(e)}",
            )

            if can_retry:
                return self._retry_request(
                    f"Connection error: {str(e)}. ", retry_count, retry_kwargs
                )

            # Retries exhausted: diagnose whether the problem is local or on the API side.
            connectivity = self._check_connectivity(self.base_url)
            if connectivity["is_local_issue"]:
                raise LocalNetworkError(
                    "Unable to connect to the API server due to local network issues. "
                    "Please check your internet connection and try again."
                ) from e
            if connectivity["is_api_issue"]:
                raise ApiServerError(
                    f"The API server at {self.base_url} is currently unreachable. "
                    f"The service may be experiencing issues. Please try again later."
                ) from e

            final_error_message = (
                f"Unable to connect to the API server after {self.max_retries} attempts. "
                f"Please check your internet connection or try again later."
            )
            request_logger.log_request_response(
                operation_id=operation_id,
                request_method=method,
                request_url=url,
                error_message=final_error_message,
            )
            raise Exception(final_error_message) from e

        except requests.Timeout as e:
            request_logger.log_request_response(
                operation_id=operation_id,
                request_method=method,
                request_url=url,
                error_message=f"Timeout: {str(e)}",
            )

            if can_retry:
                return self._retry_request("Request timed out. ", retry_count, retry_kwargs)

            final_error_message = (
                f"Request timed out after {self.timeout} seconds and {self.max_retries} retry attempts. "
                f"The server might be experiencing high load or the operation is taking longer than expected."
            )
            request_logger.log_request_response(
                operation_id=operation_id,
                request_method=method,
                request_url=url,
                error_message=final_error_message,
            )
            raise Exception(final_error_message) from e

        except requests.HTTPError as e:
            # Compare against None explicitly: a requests Response is falsy for error statuses.
            error_response = getattr(e, "response", None)
            has_response = error_response is not None
            status_code = error_response.status_code if has_response else None
            original_error_message = f"HTTP Error: {str(e)}"
            user_display_error_message = self._describe_http_error(
                error_response, status_code, original_error_message
            )

            request_logger.log_request_response(
                operation_id=operation_id,
                request_method=method,
                request_url=url,
                response_status_code=status_code,
                response_headers=dict(error_response.headers) if has_response else None,
                response_content=self._json_or_raw(error_response) if has_response else None,
                error_message=original_error_message,
            )

            logging.debug(f"[DEBUG] API Error: {user_display_error_message} (Status: {status_code})")
            if has_response and error_response.content:
                logging.debug(f"[DEBUG] Response content: {error_response.content}")

            # No retry here: a retryable status with retries left has already
            # returned above, before raise_for_status() was reached.

            # Well-known statuses get a fixed, friendlier message.
            friendly_messages = {
                401: "Unauthorized: Please login first to use this node.",
                402: "Payment Required: Please add credits to your account to use this node.",
                409: "There is a problem with your account. Please contact support@comfy.org.",
                429: "Rate Limit Exceeded: Please try again later.",
            }
            user_display_error_message = friendly_messages.get(
                status_code, user_display_error_message
            )

            raise Exception(user_display_error_message) from e

        if response.content:
            return response.json()
        return {}

    def check_auth(self, auth_token, comfy_api_key):
        """Verify that an auth token is present or comfy_api_key is present"""
        if auth_token is None and comfy_api_key is None:
            raise Exception("Unauthorized: Please login first to use this node.")
        return auth_token or comfy_api_key

    @staticmethod
    def upload_file(
        upload_url: str,
        file: io.BytesIO | str,
        content_type: str | None = None,
        max_retries: int = 3,
        retry_delay: float = 1.0,
        retry_backoff_factor: float = 2.0,
    ):
        """Upload a file to the API with retry logic.

        Args:
            upload_url: The URL to upload to
            file: Either a file path string or a BytesIO object
            content_type: Optional mime type to set for the upload
            max_retries: Maximum number of retry attempts
            retry_delay: Initial delay between retries in seconds
            retry_backoff_factor: Multiplier for the delay after each retry

        Returns:
            The response of the successful PUT request.

        Raises:
            ValueError: If ``file`` is neither a BytesIO object nor a path string
            LocalNetworkError: If the upload failed and the connectivity check also fails
            Exception: If the upload failed for any other reason after all attempts
        """
        headers = {}
        if content_type:
            headers["Content-Type"] = content_type

        # Read the whole payload up front so each retry sends the same bytes.
        if isinstance(file, io.BytesIO):
            file.seek(0)
            data = file.read()
        elif isinstance(file, str):
            with open(file, "rb") as f:
                data = f.read()
        else:
            raise ValueError("File must be either a BytesIO object or a file path string")

        last_exception = None
        operation_id = f"upload_{upload_url.split('/')[-1]}_{uuid.uuid4().hex[:8]}"

        request_logger.log_request_response(
            operation_id=operation_id,
            request_method="PUT",
            request_url=upload_url,
            request_headers=headers,
            request_data=f"[File data of type {content_type or 'unknown'}, size {len(data)} bytes]",
        )

        for retry_attempt in range(max_retries + 1):
            try:
                response = requests.put(upload_url, data=data, headers=headers)
                response.raise_for_status()
                request_logger.log_request_response(
                    operation_id=operation_id,
                    request_method="PUT",
                    request_url=upload_url,
                    response_status_code=response.status_code,
                    response_headers=dict(response.headers),
                    response_content="File uploaded successfully.",
                )
                return response

            except (requests.ConnectionError, requests.Timeout, requests.HTTPError) as e:
                last_exception = e
                failed_response = getattr(e, "response", None)
                has_response = failed_response is not None

                request_logger.log_request_response(
                    operation_id=operation_id,
                    request_method="PUT",
                    request_url=upload_url,
                    response_status_code=failed_response.status_code if has_response else None,
                    response_headers=dict(failed_response.headers) if has_response else None,
                    response_content=ApiClient._json_or_raw(failed_response) if has_response else None,
                    error_message=f"{type(e).__name__}: {str(e)}",
                )

                if retry_attempt >= max_retries:
                    break

                delay = retry_delay * (retry_backoff_factor ** retry_attempt)
                logging.warning(
                    f"File upload failed: {str(e)}. "
                    f"Retrying in {delay:.2f}s ({retry_attempt + 1}/{max_retries})"
                )
                time.sleep(delay)

        # Every attempt failed: check general connectivity to refine the error.
        final_error_message = (
            f"Failed to upload file after {max_retries + 1} attempts. Error: {str(last_exception)}"
        )
        try:
            check_response = requests.get("https://www.google.com", timeout=5.0, verify=True)
            if check_response.status_code >= 500:
                final_error_message = (
                    f"Failed to upload file. Internet connectivity check to Google failed "
                    f"(status {check_response.status_code}). Original error: {str(last_exception)}"
                )
        except (requests.RequestException, socket.error) as conn_check_exc:
            final_error_message = (
                f"Failed to upload file due to network connectivity issues "
                f"(cannot reach Google: {str(conn_check_exc)}). "
                f"Original upload error: {str(last_exception)}"
            )
            request_logger.log_request_response(
                operation_id=operation_id,
                request_method="PUT",
                request_url=upload_url,
                error_message=final_error_message,
            )
            raise LocalNetworkError(final_error_message) from last_exception

        request_logger.log_request_response(
            operation_id=operation_id,
            request_method="PUT",
            request_url=upload_url,
            error_message=final_error_message,
        )
        raise Exception(final_error_message) from last_exception
| |
|
| |
|
class ApiEndpoint(Generic[T, R]):
    """Describes one API endpoint: where it lives, how it is called, and its models."""

    def __init__(
        self,
        path: str,
        method: HttpMethod,
        request_model: Type[T],
        response_model: Type[R],
        query_params: Optional[Dict[str, Any]] = None,
    ):
        """Record the endpoint definition.

        Args:
            path: URL path of the endpoint; may contain placeholders such as {id}
            method: HTTP method used to call the endpoint
            request_model: Pydantic model class for requests sent to the endpoint
            response_model: Pydantic model class for responses returned by the endpoint
            query_params: Optional query parameters; stored as an empty dict when omitted or empty
        """
        self.path, self.method = path, method
        self.request_model, self.response_model = request_model, response_model
        # Normalise "no parameters" to an empty dict.
        self.query_params = query_params if query_params else {}
| |
|
| |
|
class SynchronousOperation(Generic[T, R]):
    """
    Represents a single synchronous API operation.

    The operation sends ``request`` to ``endpoint`` once (the client handles
    retries) and validates the JSON reply against the endpoint's response model.
    """

    def __init__(
        self,
        endpoint: ApiEndpoint[T, R],
        request: T,
        files: Optional[Dict[str, Any]] = None,
        api_base: str | None = None,
        auth_token: Optional[str] = None,
        comfy_api_key: Optional[str] = None,
        auth_kwargs: Optional[Dict[str, str]] = None,
        timeout: float = 604800.0,
        verify_ssl: bool = True,
        content_type: str = "application/json",
        multipart_parser: Callable = None,
        max_retries: int = 3,
        retry_delay: float = 1.0,
        retry_backoff_factor: float = 2.0,
    ):
        """Capture everything needed to run the operation later.

        Args:
            endpoint: Endpoint definition to call
            request: Request model instance; an EmptyRequest sends no body
            files: Optional files forwarded to the client
            api_base: Base URL; falls back to ``args.comfy_api_base`` when falsy
            auth_token: Bearer token used when no client is supplied
            comfy_api_key: API key used when no client is supplied
            auth_kwargs: Optional dict whose "auth_token" / "comfy_api_key"
                entries override the two arguments above
            timeout: Request timeout in seconds for an auto-created client
            verify_ssl: Whether an auto-created client verifies TLS certificates
            content_type: Content type passed to the client
            multipart_parser: Optional callable passed to the client for multipart data
            max_retries: Retry limit for an auto-created client
            retry_delay: Initial retry delay for an auto-created client
            retry_backoff_factor: Backoff multiplier for an auto-created client
        """
        self.endpoint = endpoint
        self.request = request
        self.response = None
        self.error = None
        self.api_base: str = api_base or args.comfy_api_base
        self.auth_token = auth_token
        self.comfy_api_key = comfy_api_key
        if auth_kwargs is not None:
            self.auth_token = auth_kwargs.get("auth_token", self.auth_token)
            self.comfy_api_key = auth_kwargs.get("comfy_api_key", self.comfy_api_key)
        self.timeout = timeout
        self.verify_ssl = verify_ssl
        self.files = files
        self.content_type = content_type
        self.multipart_parser = multipart_parser
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.retry_backoff_factor = retry_backoff_factor

    @staticmethod
    def _debug_json(payload: Any) -> str:
        """Render ``payload`` for debug logs; non-JSON values fall back to ``str``."""
        return json.dumps(payload, indent=2, default=str)

    def execute(self, client: Optional[ApiClient] = None) -> R:
        """Execute the API operation using the provided client or create one with retry support

        Returns:
            The response parsed into the endpoint's response model.

        Raises:
            LocalNetworkError: Re-raised unchanged from the client
            ApiServerError: Re-raised unchanged from the client
            Exception: Any other failure, re-wrapped with the original as its cause
        """
        try:
            if client is None:
                client = ApiClient(
                    base_url=self.api_base,
                    auth_token=self.auth_token,
                    comfy_api_key=self.comfy_api_key,
                    timeout=self.timeout,
                    verify_ssl=self.verify_ssl,
                    max_retries=self.max_retries,
                    retry_delay=self.retry_delay,
                    retry_backoff_factor=self.retry_backoff_factor,
                )

            # An EmptyRequest means "no body"; otherwise dump the model without None fields.
            request_dict = (
                None
                if isinstance(self.request, EmptyRequest)
                else self.request.model_dump(exclude_none=True)
            )
            if request_dict:
                # Top-level enum members are replaced by their plain values.
                for key, value in request_dict.items():
                    if isinstance(value, Enum):
                        request_dict[key] = value.value

            # Only serialize payloads when debug output is really emitted, so a
            # diagnostic line can neither cost time nor abort the request.
            debug_enabled = logging.getLogger().isEnabledFor(logging.DEBUG)
            if debug_enabled:
                logging.debug(
                    f"[DEBUG] API Request: {self.endpoint.method.value} {self.endpoint.path}"
                )
                logging.debug(f"[DEBUG] Request Data: {self._debug_json(request_dict)}")
                logging.debug(f"[DEBUG] Query Params: {self.endpoint.query_params}")

            resp = client.request(
                method=self.endpoint.method.value,
                path=self.endpoint.path,
                data=request_dict,
                params=self.endpoint.query_params,
                files=self.files,
                content_type=self.content_type,
                multipart_parser=self.multipart_parser,
            )

            if debug_enabled:
                logging.debug("=" * 50)
                logging.debug("[DEBUG] RESPONSE DETAILS:")
                logging.debug("[DEBUG] Status Code: 200 (Success)")
                logging.debug(f"[DEBUG] Response Body: {self._debug_json(resp)}")
                logging.debug("=" * 50)

            return self._parse_response(resp)

        except LocalNetworkError as e:
            logging.error(f"[ERROR] Local network error: {str(e)}")
            raise

        except ApiServerError as e:
            logging.error(f"[ERROR] API server error: {str(e)}")
            raise

        except Exception as e:
            logging.error(f"[ERROR] API Exception: {str(e)}")
            # Chain explicitly so the original traceback stays attached.
            raise Exception(str(e)) from e

    def _parse_response(self, resp):
        """Parse response data - can be overridden by subclasses"""
        self.response = self.endpoint.response_model.model_validate(resp)
        logging.debug(f"[DEBUG] Parsed Response: {self.response}")
        return self.response
| |
|
| |
|
class TaskStatus(str, Enum):
    """Normalised outcome of a status check on a polled task."""

    COMPLETED = "completed"
    FAILED = "failed"
    PENDING = "pending"
| |
|
| |
|
| | class PollingOperation(Generic[T, R]): |
| | """ |
| | Represents an asynchronous API operation that requires polling for completion. |
| | """ |
| |
|
| | def __init__( |
| | self, |
| | poll_endpoint: ApiEndpoint[EmptyRequest, R], |
| | completed_statuses: list, |
| | failed_statuses: list, |
| | status_extractor: Callable[[R], str], |
| | progress_extractor: Callable[[R], float] = None, |
| | result_url_extractor: Callable[[R], str] = None, |
| | request: Optional[T] = None, |
| | api_base: str | None = None, |
| | auth_token: Optional[str] = None, |
| | comfy_api_key: Optional[str] = None, |
| | auth_kwargs: Optional[Dict[str,str]] = None, |
| | poll_interval: float = 5.0, |
| | max_poll_attempts: int = 120, |
| | max_retries: int = 3, |
| | retry_delay: float = 1.0, |
| | retry_backoff_factor: float = 2.0, |
| | estimated_duration: Optional[float] = None, |
| | node_id: Optional[str] = None, |
| | ): |
| | self.poll_endpoint = poll_endpoint |
| | self.request = request |
| | self.api_base: str = api_base or args.comfy_api_base |
| | self.auth_token = auth_token |
| | self.comfy_api_key = comfy_api_key |
| | if auth_kwargs is not None: |
| | self.auth_token = auth_kwargs.get("auth_token", self.auth_token) |
| | self.comfy_api_key = auth_kwargs.get("comfy_api_key", self.comfy_api_key) |
| | self.poll_interval = poll_interval |
| | self.max_poll_attempts = max_poll_attempts |
| | self.max_retries = max_retries |
| | self.retry_delay = retry_delay |
| | self.retry_backoff_factor = retry_backoff_factor |
| | self.estimated_duration = estimated_duration |
| |
|
| | |
| | self.status_extractor = status_extractor or ( |
| | lambda x: getattr(x, "status", None) |
| | ) |
| | self.progress_extractor = progress_extractor |
| | self.result_url_extractor = result_url_extractor |
| | self.node_id = node_id |
| | self.completed_statuses = completed_statuses |
| | self.failed_statuses = failed_statuses |
| |
|
| | |
| | self.final_response = None |
| | self.error = None |
| |
|
def execute(self, client: Optional[ApiClient] = None) -> R:
    """Run the polling operation and return the final response.

    Args:
        client: Client used to send the poll requests. When ``None``, an
            ``ApiClient`` is created from this operation's ``api_base``,
            ``auth_token``, ``comfy_api_key`` and retry settings.

    Returns:
        The value returned by ``_poll_until_complete``.

    Raises:
        Exception: On any failure. ``LocalNetworkError`` and
            ``ApiServerError`` are re-raised with a user-oriented hint;
            every other error is re-raised with an "Error during polling"
            prefix. The original exception is always attached as the cause.
    """
    try:
        # Client construction stays inside the ``try`` so that a failure to
        # build it is reported through the same wrapped-error path.
        if client is None:
            client = ApiClient(
                base_url=self.api_base,
                auth_token=self.auth_token,
                comfy_api_key=self.comfy_api_key,
                max_retries=self.max_retries,
                retry_delay=self.retry_delay,
                retry_backoff_factor=self.retry_backoff_factor,
            )
        return self._poll_until_complete(client)
    except LocalNetworkError as e:
        raise Exception(
            f"Polling failed due to local network issues. Please check your internet connection. "
            f"Details: {str(e)}"
        ) from e
    except ApiServerError as e:
        raise Exception(
            f"Polling failed due to API server issues. The service may be experiencing problems. "
            f"Please try again later. Details: {str(e)}"
        ) from e
    except Exception as e:
        # Chain explicitly, like the two handlers above, so the traceback
        # marks the original error as the direct cause.
        raise Exception(f"Error during polling: {str(e)}") from e
| |
|
| | def _display_text_on_node(self, text: str): |
| | """Sends text to the client which will be displayed on the node in the UI""" |
| | if not self.node_id: |
| | return |
| |
|
| | PromptServer.instance.send_progress_text(text, self.node_id) |
| |
|
| | def _display_time_progress_on_node(self, time_completed: int): |
| | if not self.node_id: |
| | return |
| |
|
| | if self.estimated_duration is not None: |
| | estimated_time_remaining = max( |
| | 0, int(self.estimated_duration) - int(time_completed) |
| | ) |
| | message = f"Task in progress: {time_completed:.0f}s (~{estimated_time_remaining:.0f}s remaining)" |
| | else: |
| | message = f"Task in progress: {time_completed:.0f}s" |
| | self._display_text_on_node(message) |
| |
|
def _check_task_status(self, response: R) -> TaskStatus:
    """Map a poll response to a ``TaskStatus``.

    The raw status comes from ``self.status_extractor``. It maps to
    ``COMPLETED`` when found in ``self.completed_statuses``, to ``FAILED``
    when found in ``self.failed_statuses``, and to ``PENDING`` otherwise.
    Any exception raised while extracting or matching is logged and also
    reported as ``PENDING``.
    """
    outcome = TaskStatus.PENDING
    try:
        raw_status = self.status_extractor(response)
        if raw_status in self.completed_statuses:
            outcome = TaskStatus.COMPLETED
        elif raw_status in self.failed_statuses:
            outcome = TaskStatus.FAILED
    except Exception as exc:
        logging.error(f"Error extracting status: {exc}")
    return outcome
| |
|
def _poll_until_complete(self, client: ApiClient) -> R:
    """Poll ``self.poll_endpoint`` until the task completes, fails, or times out.

    Each attempt sends the endpoint's method, path and query params together
    with the serialized ``self.request``, validates the reply with the
    endpoint's ``response_model``, and classifies it via
    ``_check_task_status``. Between attempts the method waits
    ``self.poll_interval`` seconds, refreshing the node's elapsed-time text
    once per second.

    Args:
        client: Client used to send every poll request.

    Returns:
        The validated response object of the attempt that reported completion
        (also stored in ``self.final_response``).

    Raises:
        Exception: When the task reports a failed status, when too many
            consecutive attempts raise, or when ``self.max_poll_attempts``
            attempts pass without completion.
    """
    poll_count = 0
    consecutive_errors = 0
    max_consecutive_errors = min(5, self.max_retries * 2)

    if self.progress_extractor:
        progress = utils.ProgressBar(PROGRESS_BAR_MAX)

    while poll_count < self.max_poll_attempts:
        # Reset per attempt: the generic handler below inspects ``status``,
        # and it must be bound even when this attempt fails before a status
        # is obtained (otherwise an UnboundLocalError hides the real error).
        status = None
        try:
            poll_count += 1
            logging.debug(f"[DEBUG] Polling attempt #{poll_count}")

            request_dict = (
                self.request.model_dump(exclude_none=True)
                if self.request is not None
                else None
            )

            # Log the request details only once to keep debug output short.
            if poll_count == 1:
                logging.debug(
                    f"[DEBUG] Poll Request: {self.poll_endpoint.method.value} {self.poll_endpoint.path}"
                )
                logging.debug(
                    f"[DEBUG] Poll Request Data: {json.dumps(request_dict, indent=2) if request_dict else 'None'}"
                )

            resp = client.request(
                method=self.poll_endpoint.method.value,
                path=self.poll_endpoint.path,
                params=self.poll_endpoint.query_params,
                data=request_dict,
            )

            # The request went through, so the error streak is over.
            consecutive_errors = 0

            response_obj = self.poll_endpoint.response_model.model_validate(resp)

            status = self._check_task_status(response_obj)
            logging.debug(f"[DEBUG] Task Status: {status}")

            if self.progress_extractor:
                new_progress = self.progress_extractor(response_obj)
                if new_progress is not None:
                    progress.update_absolute(new_progress, total=PROGRESS_BAR_MAX)

            if status == TaskStatus.COMPLETED:
                message = "Task completed successfully"
                if self.result_url_extractor:
                    result_url = self.result_url_extractor(response_obj)
                    if result_url:
                        message = f"Result URL: {result_url}"
                else:
                    message = "Task completed successfully!"
                logging.debug(f"[DEBUG] {message}")
                self._display_text_on_node(message)
                self.final_response = response_obj
                if self.progress_extractor:
                    progress.update(100)
                return self.final_response
            elif status == TaskStatus.FAILED:
                # Raised inside the ``try`` on purpose: the generic handler
                # sees ``status == FAILED`` and aborts instead of retrying.
                message = f"Task failed: {json.dumps(resp)}"
                logging.error(f"[DEBUG] {message}")
                raise Exception(message)
            else:
                logging.debug("[DEBUG] Task still pending, continuing to poll...")

            logging.debug(
                f"[DEBUG] Waiting {self.poll_interval} seconds before next poll"
            )
            # Sleep in one-second steps so the node's elapsed-time text stays
            # current while waiting.
            whole_seconds = int(self.poll_interval)
            for i in range(whole_seconds):
                time_completed = (poll_count * self.poll_interval) + i
                self._display_time_progress_on_node(time_completed)
                time.sleep(1)
            # Honor the fractional part of the interval as well; without it
            # an interval below one second would poll with no delay at all.
            leftover = self.poll_interval - whole_seconds
            if leftover > 0:
                time.sleep(leftover)

        except (LocalNetworkError, ApiServerError) as e:
            consecutive_errors += 1
            if consecutive_errors >= max_consecutive_errors:
                raise Exception(
                    f"Polling aborted after {consecutive_errors} consecutive network errors: {str(e)}"
                ) from e

            logging.warning(
                f"Network error during polling (attempt {poll_count}/{self.max_poll_attempts}): {str(e)}. "
                f"Will retry in {self.poll_interval} seconds."
            )
            time.sleep(self.poll_interval)

        except Exception as e:
            consecutive_errors += 1
            # A reported task failure is final; other errors are retried
            # until the consecutive-error limit is reached.
            if consecutive_errors >= max_consecutive_errors or status == TaskStatus.FAILED:
                raise Exception(
                    f"Polling aborted after {consecutive_errors} consecutive errors: {str(e)}"
                ) from e

            logging.error(f"[DEBUG] Polling error: {str(e)}")
            logging.warning(
                f"Error during polling (attempt {poll_count}/{self.max_poll_attempts}): {str(e)}. "
                f"Will retry in {self.poll_interval} seconds."
            )
            time.sleep(self.poll_interval)

    # Every attempt was used without the task reaching a terminal status.
    raise Exception(
        f"Polling timed out after {poll_count} attempts ({poll_count * self.poll_interval} seconds). "
        f"The operation may still be running on the server but is taking longer than expected."
    )
| |
|