|
|
import asyncio |
|
|
import contextlib |
|
|
import logging |
|
|
import time |
|
|
import uuid |
|
|
from io import BytesIO |
|
|
from typing import Optional, Union |
|
|
from urllib.parse import urlparse |
|
|
|
|
|
import aiohttp |
|
|
import torch |
|
|
from pydantic import BaseModel, Field |
|
|
|
|
|
from comfy_api.latest import IO, Input |
|
|
from comfy_api.util import VideoCodec, VideoContainer |
|
|
from comfy_api_nodes.apis import request_logger |
|
|
|
|
|
from ._helpers import is_processing_interrupted, sleep_with_interrupt |
|
|
from .client import ( |
|
|
ApiEndpoint, |
|
|
_diagnose_connectivity, |
|
|
_display_time_progress, |
|
|
sync_op, |
|
|
) |
|
|
from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInterrupted |
|
|
from .conversions import ( |
|
|
audio_ndarray_to_bytesio, |
|
|
audio_tensor_to_contiguous_ndarray, |
|
|
tensor_to_bytesio, |
|
|
) |
|
|
|
|
|
|
|
|
class UploadRequest(BaseModel):
    # Payload describing the file the caller wants to upload.
    # `file_name` is required; `content_type` is optional and defaults to None.
    file_name: str = Field(default=..., description="Filename to upload")
    content_type: Optional[str] = Field(
        default=None,
        description="Mime type of the file. For example: image/png, image/jpeg, video/mp4, etc.",
    )
|
|
|
|
|
|
|
|
class UploadResponse(BaseModel):
    # Pair of URLs handed back for an upload: one to PUT the bytes to,
    # one to GET the stored file from afterwards. Both are required.
    download_url: str = Field(default=..., description="URL to GET uploaded file")
    upload_url: str = Field(default=..., description="URL to PUT file to upload")
|
|
|
|
|
|
|
|
async def upload_images_to_comfyapi(
    cls: type[IO.ComfyNode],
    image: torch.Tensor,
    *,
    max_images: int = 8,
    mime_type: Optional[str] = None,
    wait_label: Optional[str] = "Uploading",
) -> list[str]:
    """
    Upload a single image, or the leading images of a batch, and return their download URLs.

    A tensor with more than three dimensions is treated as a batch indexed along
    its first dimension; any other tensor is uploaded as one image. At most
    ``max_images`` images are uploaded, sequentially, so the returned URLs are in
    batch order.
    """
    batched = len(image.shape) > 3
    available = image.shape[0] if batched else 1

    urls: list[str] = []
    for index in range(min(available, max_images)):
        frame = image[index] if batched else image
        encoded = tensor_to_bytesio(frame, mime_type=mime_type)
        # The buffer's own ``name`` attribute is used as the uploaded filename.
        urls.append(await upload_file_to_comfyapi(cls, encoded, encoded.name, mime_type, wait_label))
    return urls
|
|
|
|
|
|
|
|
async def upload_audio_to_comfyapi(
    cls: type[IO.ComfyNode],
    audio: Input.Audio,
    *,
    container_format: str = "mp4",
    codec_name: str = "aac",
    mime_type: str = "audio/mp4",
    filename: str = "uploaded_audio.mp4",
) -> str:
    """
    Encode one audio input and upload it, returning the download URL.

    Reads ``audio["sample_rate"]`` and ``audio["waveform"]``, encodes the waveform
    with the requested container/codec, and uploads the resulting buffer under
    ``filename`` with the given ``mime_type``.
    """
    sample_rate: int = audio["sample_rate"]
    samples = audio_tensor_to_contiguous_ndarray(audio["waveform"])
    encoded = audio_ndarray_to_bytesio(samples, sample_rate, container_format, codec_name)
    download_url = await upload_file_to_comfyapi(cls, encoded, filename, mime_type)
    return download_url
|
|
|
|
|
|
|
|
async def upload_video_to_comfyapi(
    cls: type[IO.ComfyNode],
    video: Input.Video,
    *,
    container: VideoContainer = VideoContainer.MP4,
    codec: VideoCodec = VideoCodec.H264,
    max_duration: Optional[int] = None,
) -> str:
    """
    Uploads a single video to ComfyUI API and returns its download URL.

    The video is saved into an in-memory buffer using the given container and
    codec, then uploaded as ``uploaded_video.<container>`` with MIME type
    ``video/<container>`` (container name lower-cased).

    Args:
        cls: Node class forwarded to the upload helper.
        video: Video to upload; must provide ``get_duration()`` and ``save_to()``.
        container: Container format used when saving the video.
        codec: Codec used when saving the video.
        max_duration: Optional upper bound in seconds; ``None`` skips the check.

    Raises:
        ValueError: If the duration cannot be determined or compared, or if it
            exceeds ``max_duration``.
    """
    if max_duration is not None:
        # Only failures to obtain or compare the duration are reported as
        # "could not verify". The over-limit error is raised outside the try
        # so it is not caught here and re-wrapped with a misleading message.
        try:
            actual_duration = video.get_duration()
            too_long = actual_duration > max_duration
        except Exception as e:
            logging.error("Error getting video duration: %s", str(e))
            raise ValueError(f"Could not verify video duration from source: {e}") from e

        if too_long:
            raise ValueError(
                f"Video duration ({actual_duration:.2f}s) exceeds the maximum allowed ({max_duration}s)."
            )

    extension = container.value.lower()
    upload_mime_type = f"video/{extension}"
    filename = f"uploaded_video.{extension}"

    video_bytes_io = BytesIO()
    video.save_to(video_bytes_io, format=container, codec=codec)
    # Rewind so the upload reads the buffer from its first byte.
    video_bytes_io.seek(0)

    return await upload_file_to_comfyapi(cls, video_bytes_io, filename, upload_mime_type)
|
|
|
|
|
|
|
|
async def upload_file_to_comfyapi(
    cls: type[IO.ComfyNode],
    file_bytes_io: BytesIO,
    filename: str,
    upload_mime_type: Optional[str],
    wait_label: Optional[str] = "Uploading",
) -> str:
    """
    Upload one in-memory file and return the URL it can be downloaded from.

    First POSTs an ``UploadRequest`` to ``/customers/storage`` to obtain an
    upload/download URL pair, then PUTs the buffer to the upload URL.
    """
    # Pass ``content_type`` only when a MIME type was supplied, so the field
    # stays unset (rather than explicitly None) otherwise.
    request_fields = {"file_name": filename}
    if upload_mime_type is not None:
        request_fields["content_type"] = upload_mime_type
    request_object = UploadRequest(**request_fields)

    storage_endpoint = ApiEndpoint(path="/customers/storage", method="POST")
    create_resp = await sync_op(
        cls,
        endpoint=storage_endpoint,
        data=request_object,
        response_model=UploadResponse,
        final_label_on_success=None,
        monitor_progress=False,
    )

    await upload_file(
        cls,
        create_resp.upload_url,
        file_bytes_io,
        content_type=upload_mime_type,
        wait_label=wait_label,
    )
    return create_resp.download_url
|
|
|
|
|
|
|
|
async def upload_file(
    cls: type[IO.ComfyNode],
    upload_url: str,
    file: Union[BytesIO, str],
    *,
    content_type: Optional[str] = None,
    max_retries: int = 3,
    retry_delay: float = 1.0,
    retry_backoff: float = 2.0,
    wait_label: Optional[str] = None,
) -> None:
    """
    Upload a file to a signed URL (e.g., S3 pre-signed PUT) with retries, Comfy progress display, and interruption.

    The whole payload is read into memory once and re-sent on every attempt.
    The first attempt is followed by up to ``max_retries`` retries; the wait
    before each retry starts at ``retry_delay`` and is multiplied by
    ``retry_backoff`` after each one.

    Args:
        cls: Node class, passed through to the progress display and sleep helpers.
        upload_url: Pre-signed PUT URL.
        file: BytesIO or path string.
        content_type: Explicit MIME type. If None (or empty), we *suppress* Content-Type.
        max_retries: Maximum retry attempts.
        retry_delay: Initial delay in seconds.
        retry_backoff: Exponential backoff factor.
        wait_label: Progress label shown in Comfy UI; falsy disables progress display.

    Raises:
        ValueError: ``file`` is neither a BytesIO nor a str.
        ProcessingInterrupted: Processing was interrupted or the task was cancelled.
        LocalNetworkError: Retries exhausted and the connectivity check reports no internet access.
        ApiServerError: Retries exhausted while internet access is reported as available.
        Exception: The server answered with HTTP >= 400 and the status is not
            retryable or retries are exhausted.
    """
    if isinstance(file, BytesIO):
        # Best-effort rewind; a failing seek must not abort the upload.
        with contextlib.suppress(Exception):
            file.seek(0)
        data = file.read()
    elif isinstance(file, str):
        with open(file, "rb") as f:
            data = f.read()
    else:
        raise ValueError("file must be a BytesIO or a filesystem path string")

    headers: dict[str, str] = {}
    skip_auto_headers: set[str] = set()
    if content_type:
        headers["Content-Type"] = content_type
    else:
        # Ask aiohttp not to add its own Content-Type header
        # (presumably an unexpected header could invalidate the signed request).
        skip_auto_headers.add("Content-Type")

    attempt = 0
    delay = retry_delay
    start_ts = time.monotonic()
    op_uuid = uuid.uuid4().hex[:8]  # shared by all attempts so their log entries can be correlated
    while True:
        attempt += 1
        operation_id = _generate_operation_id("PUT", upload_url, attempt, op_uuid)
        timeout = aiohttp.ClientTimeout(total=None)  # no overall time limit on the upload
        stop_evt = asyncio.Event()

        async def _monitor():
            # Ticks once per second: ends on interruption, otherwise refreshes
            # the elapsed-time progress display until told to stop.
            try:
                while not stop_evt.is_set():
                    if is_processing_interrupted():
                        return
                    if wait_label:
                        _display_time_progress(cls, wait_label, int(time.monotonic() - start_ts), None)
                    await asyncio.sleep(1.0)
            except asyncio.CancelledError:
                return

        monitor_task = asyncio.create_task(_monitor())
        req_task: Optional[asyncio.Task] = None
        sess: Optional[aiohttp.ClientSession] = None
        try:
            # Request logging is best-effort and must never fail the upload.
            try:
                request_logger.log_request_response(
                    operation_id=operation_id,
                    request_method="PUT",
                    request_url=upload_url,
                    request_headers=headers or None,
                    request_params=None,
                    request_data=f"[File data {len(data)} bytes]",
                )
            except Exception as e:
                logging.debug("[DEBUG] upload request logging failed: %s", e)

            sess = aiohttp.ClientSession(timeout=timeout)
            req = sess.put(upload_url, data=data, headers=headers, skip_auto_headers=skip_auto_headers)
            req_task = asyncio.create_task(req)

            # Race the request against the monitor; the monitor only finishes
            # early when it has seen an interruption.
            done, pending = await asyncio.wait({req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED)

            if monitor_task in done and req_task in pending:
                req_task.cancel()  # awaited in the ``finally`` block below
                raise ProcessingInterrupted("Upload cancelled")

            try:
                resp = await req_task
            except asyncio.CancelledError:
                raise ProcessingInterrupted("Upload cancelled") from None

            async with resp:
                if resp.status >= 400:
                    # Best-effort error logging: prefer a JSON body, fall back to text.
                    with contextlib.suppress(Exception):
                        try:
                            body = await resp.json()
                        except Exception:
                            body = await resp.text()
                        msg = f"Upload failed with status {resp.status}"
                        request_logger.log_request_response(
                            operation_id=operation_id,
                            request_method="PUT",
                            request_url=upload_url,
                            response_status_code=resp.status,
                            response_headers=dict(resp.headers),
                            response_content=body,
                            error_message=msg,
                        )
                    # Only these statuses are retried; everything else fails immediately.
                    if resp.status in {408, 429, 500, 502, 503, 504} and attempt <= max_retries:
                        await sleep_with_interrupt(
                            delay,
                            cls,
                            wait_label,
                            start_ts,
                            None,
                            display_callback=_display_time_progress if wait_label else None,
                        )
                        delay *= retry_backoff
                        continue
                    raise Exception(f"Failed to upload (HTTP {resp.status}).")
                try:
                    request_logger.log_request_response(
                        operation_id=operation_id,
                        request_method="PUT",
                        request_url=upload_url,
                        response_status_code=resp.status,
                        response_headers=dict(resp.headers),
                        response_content="File uploaded successfully.",
                    )
                except Exception as e:
                    logging.debug("[DEBUG] upload response logging failed: %s", e)
                return
        except asyncio.CancelledError:
            raise ProcessingInterrupted("Task cancelled") from None
        except (aiohttp.ClientError, OSError) as e:
            if attempt <= max_retries:
                with contextlib.suppress(Exception):
                    request_logger.log_request_response(
                        operation_id=operation_id,
                        request_method="PUT",
                        request_url=upload_url,
                        request_headers=headers or None,
                        request_data=f"[File data {len(data)} bytes]",
                        error_message=f"{type(e).__name__}: {str(e)} (will retry)",
                    )
                await sleep_with_interrupt(
                    delay,
                    cls,
                    wait_label,
                    start_ts,
                    None,
                    display_callback=_display_time_progress if wait_label else None,
                )
                delay *= retry_backoff
                continue

            # Retries exhausted: tell a local connectivity problem apart from
            # an unreachable service.
            diag = await _diagnose_connectivity()
            if not diag["internet_accessible"]:
                raise LocalNetworkError(
                    "Unable to connect to the network. Please check your internet connection and try again."
                ) from e
            raise ApiServerError("The API service appears unreachable at this time.") from e
        finally:
            stop_evt.set()
            monitor_task.cancel()
            # CancelledError derives from BaseException, so it has to be listed
            # explicitly; otherwise awaiting a task cancelled before its first
            # step would raise out of this cleanup and mask the real outcome.
            with contextlib.suppress(asyncio.CancelledError, Exception):
                await monitor_task
            # asyncio.wait() does not cancel the tasks it waits on, so a request
            # still in flight (interruption, outer cancellation, unexpected
            # error) must be cancelled and awaited here, before the session is
            # closed underneath it.
            if req_task is not None and not req_task.done():
                req_task.cancel()
                with contextlib.suppress(asyncio.CancelledError, Exception):
                    await req_task
            if sess:
                with contextlib.suppress(Exception):
                    await sess.close()
|
|
|
|
|
|
|
|
def _generate_operation_id(method: str, url: str, attempt: int, op_uuid: str) -> str: |
|
|
try: |
|
|
parsed = urlparse(url) |
|
|
slug = (parsed.path.rsplit("/", 1)[-1] or parsed.netloc or "upload").strip("/").replace("/", "_") |
|
|
except Exception: |
|
|
slug = "upload" |
|
|
return f"{method}_{slug}_{op_uuid}_try{attempt}" |
|
|
|