File size: 12,720 Bytes
359fa44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
import asyncio
import contextlib
import logging
import time
import uuid
from io import BytesIO
from typing import Optional, Union
from urllib.parse import urlparse

import aiohttp
import torch
from pydantic import BaseModel, Field

from comfy_api.latest import IO, Input
from comfy_api.util import VideoCodec, VideoContainer
from comfy_api_nodes.apis import request_logger

from ._helpers import is_processing_interrupted, sleep_with_interrupt
from .client import (
    ApiEndpoint,
    _diagnose_connectivity,
    _display_time_progress,
    sync_op,
)
from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInterrupted
from .conversions import (
    audio_ndarray_to_bytesio,
    audio_tensor_to_contiguous_ndarray,
    tensor_to_bytesio,
)


class UploadRequest(BaseModel):
    """Request body for POST /customers/storage: asks the API for a signed upload slot."""

    file_name: str = Field(..., description="Filename to upload")
    content_type: Optional[str] = Field(
        None,
        description="Mime type of the file. For example: image/png, image/jpeg, video/mp4, etc.",
    )


class UploadResponse(BaseModel):
    """Response from POST /customers/storage: a pre-signed PUT URL plus the eventual download URL."""

    download_url: str = Field(..., description="URL to GET uploaded file")
    upload_url: str = Field(..., description="URL to PUT file to upload")


async def upload_images_to_comfyapi(
    cls: type[IO.ComfyNode],
    image: torch.Tensor,
    *,
    max_images: int = 8,
    mime_type: Optional[str] = None,
    wait_label: Optional[str] = "Uploading",
) -> list[str]:
    """
    Uploads images to ComfyUI API and returns download URLs.
    To upload multiple images, stack them in the batch dimension first.

    A 4-D tensor is treated as a batch along dim 0; at most ``max_images``
    entries are uploaded, one file per image.
    """
    batched = image.dim() > 3
    total = image.shape[0] if batched else 1
    urls: list[str] = []
    for i in range(min(total, max_images)):
        single = image[i] if batched else image
        buffer = tensor_to_bytesio(single, mime_type=mime_type)
        urls.append(await upload_file_to_comfyapi(cls, buffer, buffer.name, mime_type, wait_label))
    return urls


async def upload_audio_to_comfyapi(
    cls: type[IO.ComfyNode],
    audio: Input.Audio,
    *,
    container_format: str = "mp4",
    codec_name: str = "aac",
    mime_type: str = "audio/mp4",
    filename: str = "uploaded_audio.mp4",
) -> str:
    """
    Uploads a single audio input to ComfyUI API and returns its download URL.
    Encodes the raw waveform into the specified format before uploading.
    """
    waveform: torch.Tensor = audio["waveform"]
    rate: int = audio["sample_rate"]
    # Encode the raw waveform into the requested container/codec before upload.
    contiguous = audio_tensor_to_contiguous_ndarray(waveform)
    encoded = audio_ndarray_to_bytesio(contiguous, rate, container_format, codec_name)
    return await upload_file_to_comfyapi(cls, encoded, filename, mime_type)


async def upload_video_to_comfyapi(
    cls: type[IO.ComfyNode],
    video: Input.Video,
    *,
    container: VideoContainer = VideoContainer.MP4,
    codec: VideoCodec = VideoCodec.H264,
    max_duration: Optional[int] = None,
) -> str:
    """
    Uploads a single video to ComfyUI API and returns its download URL.
    Uses the specified container and codec for saving the video before upload.

    Raises:
        ValueError: if the duration cannot be read, or exceeds ``max_duration``.
    """
    if max_duration is not None:
        # Only the duration probe belongs in the try: previously the
        # "duration exceeds max" ValueError was raised inside it and got
        # re-wrapped by the except as "Could not verify video duration".
        try:
            actual_duration = video.get_duration()
        except Exception as e:
            logging.error("Error getting video duration: %s", str(e))
            raise ValueError(f"Could not verify video duration from source: {e}") from e
        if actual_duration > max_duration:
            raise ValueError(
                f"Video duration ({actual_duration:.2f}s) exceeds the maximum allowed ({max_duration}s)."
            )

    upload_mime_type = f"video/{container.value.lower()}"
    filename = f"uploaded_video.{container.value.lower()}"

    # Convert VideoInput to BytesIO using specified container/codec
    video_bytes_io = BytesIO()
    video.save_to(video_bytes_io, format=container, codec=codec)
    video_bytes_io.seek(0)

    return await upload_file_to_comfyapi(cls, video_bytes_io, filename, upload_mime_type)


async def upload_file_to_comfyapi(
    cls: type[IO.ComfyNode],
    file_bytes_io: BytesIO,
    filename: str,
    upload_mime_type: Optional[str],
    wait_label: Optional[str] = "Uploading",
) -> str:
    """Uploads a single file to ComfyUI API and returns its download URL."""
    # Request a signed upload slot; content_type is only sent when known.
    fields: dict = {"file_name": filename}
    if upload_mime_type is not None:
        fields["content_type"] = upload_mime_type
    storage = await sync_op(
        cls,
        endpoint=ApiEndpoint(path="/customers/storage", method="POST"),
        data=UploadRequest(**fields),
        response_model=UploadResponse,
        final_label_on_success=None,
        monitor_progress=False,
    )
    # PUT the bytes to the signed URL, then hand back the GET URL.
    await upload_file(
        cls,
        storage.upload_url,
        file_bytes_io,
        content_type=upload_mime_type,
        wait_label=wait_label,
    )
    return storage.download_url


async def upload_file(
    cls: type[IO.ComfyNode],
    upload_url: str,
    file: Union[BytesIO, str],
    *,
    content_type: Optional[str] = None,
    max_retries: int = 3,
    retry_delay: float = 1.0,
    retry_backoff: float = 2.0,
    wait_label: Optional[str] = None,
) -> None:
    """
    Upload a file to a signed URL (e.g., S3 pre-signed PUT) with retries, Comfy progress display, and interruption.

    A monitor task runs beside the PUT request so the upload can be cancelled
    as soon as the user interrupts processing, and so elapsed time can be
    shown in the UI while the transfer is in flight. Transient HTTP statuses
    (408/429/5xx) and client/OS errors are retried with exponential backoff;
    ``max_retries`` retries are allowed after the initial attempt.

    Args:
        cls: Node class (provides auth context + UI progress hooks).
        upload_url: Pre-signed PUT URL.
        file: BytesIO or path string.
        content_type: Explicit MIME type. If None, we *suppress* Content-Type.
        max_retries: Maximum retry attempts.
        retry_delay: Initial delay in seconds.
        retry_backoff: Exponential backoff factor.
        wait_label: Progress label shown in Comfy UI.

    Raises:
        ProcessingInterrupted, LocalNetworkError, ApiServerError, Exception
    """
    # Materialize the payload once up front so every retry re-sends identical bytes.
    if isinstance(file, BytesIO):
        with contextlib.suppress(Exception):
            file.seek(0)
        data = file.read()
    elif isinstance(file, str):
        with open(file, "rb") as f:
            data = f.read()
    else:
        raise ValueError("file must be a BytesIO or a filesystem path string")

    headers: dict[str, str] = {}
    skip_auto_headers: set[str] = set()
    if content_type:
        headers["Content-Type"] = content_type
    else:
        skip_auto_headers.add("Content-Type")  # Don't let aiohttp add Content-Type, it can break the signed request

    attempt = 0
    delay = retry_delay
    start_ts = time.monotonic()
    op_uuid = uuid.uuid4().hex[:8]  # shared across attempts; operation_id distinguishes tries
    while True:
        attempt += 1
        operation_id = _generate_operation_id("PUT", upload_url, attempt, op_uuid)
        timeout = aiohttp.ClientTimeout(total=None)  # no overall deadline; cancellation handles hangs
        stop_evt = asyncio.Event()

        async def _monitor():
            # Polls for user interruption once per second and refreshes the
            # elapsed-time progress label; returning early signals cancellation.
            try:
                while not stop_evt.is_set():
                    if is_processing_interrupted():
                        return
                    if wait_label:
                        _display_time_progress(cls, wait_label, int(time.monotonic() - start_ts), None)
                    await asyncio.sleep(1.0)
            except asyncio.CancelledError:
                return

        monitor_task = asyncio.create_task(_monitor())
        sess: Optional[aiohttp.ClientSession] = None
        try:
            # Best-effort request logging; failures here must not break the upload.
            try:
                request_logger.log_request_response(
                    operation_id=operation_id,
                    request_method="PUT",
                    request_url=upload_url,
                    request_headers=headers or None,
                    request_params=None,
                    request_data=f"[File data {len(data)} bytes]",
                )
            except Exception as e:
                logging.debug("[DEBUG] upload request logging failed: %s", e)

            sess = aiohttp.ClientSession(timeout=timeout)
            req = sess.put(upload_url, data=data, headers=headers, skip_auto_headers=skip_auto_headers)
            req_task = asyncio.create_task(req)

            # Race the PUT against the interruption monitor.
            done, pending = await asyncio.wait({req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED)

            # Monitor finishing first means the user interrupted: abort the PUT.
            if monitor_task in done and req_task in pending:
                req_task.cancel()
                raise ProcessingInterrupted("Upload cancelled")

            try:
                resp = await req_task
            except asyncio.CancelledError:
                raise ProcessingInterrupted("Upload cancelled") from None

            async with resp:
                if resp.status >= 400:
                    # Best-effort capture of the error body for the request log.
                    with contextlib.suppress(Exception):
                        try:
                            body = await resp.json()
                        except Exception:
                            body = await resp.text()
                        msg = f"Upload failed with status {resp.status}"
                        request_logger.log_request_response(
                            operation_id=operation_id,
                            request_method="PUT",
                            request_url=upload_url,
                            response_status_code=resp.status,
                            response_headers=dict(resp.headers),
                            response_content=body,
                            error_message=msg,
                        )
                    # Retry only transient statuses, with interruptible backoff.
                    if resp.status in {408, 429, 500, 502, 503, 504} and attempt <= max_retries:
                        await sleep_with_interrupt(
                            delay,
                            cls,
                            wait_label,
                            start_ts,
                            None,
                            display_callback=_display_time_progress if wait_label else None,
                        )
                        delay *= retry_backoff
                        continue
                    raise Exception(f"Failed to upload (HTTP {resp.status}).")
                try:
                    request_logger.log_request_response(
                        operation_id=operation_id,
                        request_method="PUT",
                        request_url=upload_url,
                        response_status_code=resp.status,
                        response_headers=dict(resp.headers),
                        response_content="File uploaded successfully.",
                    )
                except Exception as e:
                    logging.debug("[DEBUG] upload response logging failed: %s", e)
                return
        except asyncio.CancelledError:
            raise ProcessingInterrupted("Task cancelled") from None
        except (aiohttp.ClientError, OSError) as e:
            # Network-level failure: retry with backoff while attempts remain.
            if attempt <= max_retries:
                with contextlib.suppress(Exception):
                    request_logger.log_request_response(
                        operation_id=operation_id,
                        request_method="PUT",
                        request_url=upload_url,
                        request_headers=headers or None,
                        request_data=f"[File data {len(data)} bytes]",
                        error_message=f"{type(e).__name__}: {str(e)} (will retry)",
                    )
                await sleep_with_interrupt(
                    delay,
                    cls,
                    wait_label,
                    start_ts,
                    None,
                    display_callback=_display_time_progress if wait_label else None,
                )
                delay *= retry_backoff
                continue

            # Out of retries: distinguish "no internet" from "API unreachable".
            diag = await _diagnose_connectivity()
            if not diag["internet_accessible"]:
                raise LocalNetworkError(
                    "Unable to connect to the network. Please check your internet connection and try again."
                ) from e
            raise ApiServerError("The API service appears unreachable at this time.") from e
        finally:
            # Always tear down the monitor and session, on every path (success,
            # retry continue, or raise), before the next attempt starts.
            stop_evt.set()
            if monitor_task:
                monitor_task.cancel()
                with contextlib.suppress(Exception):
                    await monitor_task
            if sess:
                with contextlib.suppress(Exception):
                    await sess.close()


def _generate_operation_id(method: str, url: str, attempt: int, op_uuid: str) -> str:
    try:
        parsed = urlparse(url)
        slug = (parsed.path.rsplit("/", 1)[-1] or parsed.netloc or "upload").strip("/").replace("/", "_")
    except Exception:
        slug = "upload"
    return f"{method}_{slug}_{op_uuid}_try{attempt}"