velai-workshop / velai /network_utils.py
kratadata's picture
Upload folder via script
0f8b3a0 verified
from __future__ import annotations
import asyncio
from dataclasses import dataclass
from typing import Mapping, Sequence
import requests
@dataclass(frozen=True, slots=True)
class DownloadResult:
url: str
ok: bool
status_code: int | None
data: bytes | None
error: str | None
def _download_sync(
    url: str,
    *,
    headers: Mapping[str, str] | None,
    timeout_s: float,
    max_bytes: int | None,
) -> bytes:
    """Fetch *url* with a blocking HTTP GET and return the response body.

    Args:
        url: Address to download.
        headers: Optional extra request headers.
        timeout_s: Timeout in seconds, passed straight to ``requests.get``.
        max_bytes: Upper bound on the body size in bytes; ``None`` disables
            the limit.

    Returns:
        The response body as bytes.

    Raises:
        requests.HTTPError: If ``raise_for_status()`` rejects the response.
        ValueError: If the body grows beyond ``max_bytes``.
    """
    # stream=True defers reading the body so the size cap can be enforced
    # while downloading. A streamed response keeps its connection checked
    # out until it is fully consumed or closed, so the ``with`` block makes
    # sure it is closed on every exit path, including the error ones.
    with requests.get(url, headers=headers, timeout=timeout_s, stream=True) as resp:
        resp.raise_for_status()
        if max_bytes is None:
            return resp.content

        chunks: list[bytes] = []
        total = 0
        for chunk in resp.iter_content(chunk_size=8192):
            if not chunk:
                continue
            total += len(chunk)
            # Abort as soon as the cap is exceeded instead of buffering the
            # rest of an oversized body.
            if total > max_bytes:
                raise ValueError(f"Response too large (>{max_bytes} bytes): {url}")
            chunks.append(chunk)
        return b"".join(chunks)
async def download_bytes(
    url: str,
    *,
    headers: Mapping[str, str] | None = None,
    timeout_s: float = 30.0,
    max_bytes: int | None = None,
    retries: int = 2,
    retry_backoff_s: float = 0.5,
) -> bytes:
    """Download *url* in a worker thread, retrying failures with backoff.

    The blocking ``_download_sync`` call runs via ``asyncio.to_thread`` so
    the event loop is not blocked while the request is in flight.

    Args:
        url: Address to download.
        headers: Optional extra request headers.
        timeout_s: Timeout in seconds forwarded to ``_download_sync``.
        max_bytes: Optional cap on the body size, forwarded to
            ``_download_sync``.
        retries: Number of additional attempts after the first one. Zero or
            a negative value means a single attempt with no retries.
        retry_backoff_s: Base delay in seconds; the wait before retry *k*
            (0-based) is ``retry_backoff_s * 2**k``.

    Returns:
        The response body as bytes.

    Raises:
        ValueError: Propagated immediately, without retrying.
        Exception: The error from the final attempt once the retry budget
            is exhausted.
    """
    attempt = 0
    while True:
        try:
            return await asyncio.to_thread(
                _download_sync,
                url,
                headers=headers,
                timeout_s=timeout_s,
                max_bytes=max_bytes,
            )
        except ValueError:
            # Deterministic failures (body over the size cap, and the
            # malformed-URL errors from requests, which subclass ValueError)
            # cannot succeed on a retry; re-downloading only wastes
            # bandwidth and time.
            raise
        except Exception:
            if attempt >= retries:
                raise
        # Exponential backoff: base delay doubled after every failed attempt.
        await asyncio.sleep(retry_backoff_s * (2**attempt))
        attempt += 1
async def download_many_bytes(
    urls: Sequence[str],
    *,
    headers: Mapping[str, str] | None = None,
    timeout_s: float = 30.0,
    max_bytes: int | None = None,
    retries: int = 2,
    retry_backoff_s: float = 0.5,
    concurrency: int = 8,
    return_results: bool = False,
) -> list[bytes] | list[DownloadResult]:
    """Download several URLs concurrently, with a cap on parallelism.

    Args:
        urls: Addresses to download.
        headers: Optional extra request headers used for every URL.
        timeout_s: Timeout in seconds forwarded to ``download_bytes``.
        max_bytes: Optional per-response size cap forwarded to
            ``download_bytes``.
        retries: Retry budget per URL, forwarded to ``download_bytes``.
        retry_backoff_s: Base backoff delay per URL, forwarded to
            ``download_bytes``.
        concurrency: Maximum number of downloads in flight at once; values
            below 1 are treated as 1.
        return_results: When False, return the raw bodies and raise on the
            first failure. When True, return one ``DownloadResult`` per URL
            and record failures in the result instead of raising.

    Returns:
        A list in the same order as ``urls``: ``bytes`` bodies when
        ``return_results`` is False, ``DownloadResult`` records otherwise.

    Raises:
        Exception: The first download error, only when ``return_results``
            is False. The remaining downloads are cancelled before the
            error propagates.
    """
    sem = asyncio.Semaphore(max(1, concurrency))

    async def one(url: str) -> bytes | DownloadResult:
        async with sem:
            try:
                data = await download_bytes(
                    url,
                    headers=headers,
                    timeout_s=timeout_s,
                    max_bytes=max_bytes,
                    retries=retries,
                    retry_backoff_s=retry_backoff_s,
                )
            except Exception as exc:
                if not return_results:
                    raise
                status = None
                if isinstance(exc, requests.HTTPError):
                    status = exc.response.status_code if exc.response is not None else None
                return DownloadResult(
                    url=url,
                    ok=False,
                    status_code=status,
                    data=None,
                    error=str(exc),
                )
            if not return_results:
                return data
            # download_bytes returns only the body, so the status code of a
            # successful response is not available at this level.
            return DownloadResult(
                url=url,
                ok=True,
                status_code=None,
                data=data,
                error=None,
            )

    tasks = [asyncio.create_task(one(url)) for url in urls]
    try:
        return list(await asyncio.gather(*tasks))
    except BaseException:
        # gather() propagates the first failure but leaves the sibling tasks
        # running. Cancel them and wait for them to settle so no download
        # keeps going unattended after this call has already failed.
        for task in tasks:
            task.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)
        raise