| | """
|
| | ComfyUI RunpodDirect - Direct Model Downloads for RunPod
|
| | Download models directly to your RunPod instance with multi-connection support
|
| | """
|
| |
|
| | import os
|
| | import logging
|
| | import asyncio
|
| | import folder_paths
|
| | from aiohttp import web
|
| | from server import PromptServer
|
| |
|
| |
|
# Status records reported to the frontend, keyed by "<save_path>/<filename>".
# Entries are added when a download is queued and are never removed here.
active_downloads = {}

# Runtime control state ("paused", "cancelled", "total_downloaded", "lock")
# for downloads whose download_file() call is currently running.
download_control = {}

# Pending downloads, consumed front-first by process_download_queue().
download_queue = []
# Task running the current download, or None when idle.
current_download_task = None

# Max bytes requested per read from the HTTP stream (32 MiB); files larger
# than this are also the ones eligible for multi-connection download.
CHUNK_SIZE = 32 * 2 ** 20
# Number of parallel range requests used for a multi-connection download.
NUM_CONNECTIONS = 8
|
| |
|
| |
|
@PromptServer.instance.routes.post("/server_download/start")
async def start_download(request):
    """Validate a download request and append it to the download queue.

    Expects a JSON object body with:
        url: source URL to fetch.
        save_path: a key of ``folder_paths.folder_names_and_paths``; the file
            goes into the first directory registered for that key.
        filename: a bare file name (no directories, no traversal patterns).

    Returns a JSON response:
        200 ``{"success", "download_id", "message"}`` once queued;
        400 for missing/invalid parameters or when the target file exists;
        409 when the same download is already queued, running or paused;
        500 for unexpected errors.
    """
    try:
        json_data = await request.json()
        if not isinstance(json_data, dict):
            return web.json_response(
                {"error": "Request body must be a JSON object"},
                status=400
            )

        url = json_data.get("url")
        save_path = json_data.get("save_path")
        filename = json_data.get("filename")

        # Non-string values are treated as missing; otherwise they would crash
        # the string checks below and surface as a 500.
        if not all(isinstance(value, str) and value for value in (url, save_path, filename)):
            return web.json_response(
                {"error": "Missing required parameters: url, save_path, filename"},
                status=400
            )

        if save_path not in folder_paths.folder_names_and_paths:
            return web.json_response(
                {"error": f"Invalid save_path: {save_path}. Must be one of: {list(folder_paths.folder_names_and_paths.keys())}"},
                status=400
            )

        # Defence in depth: reject separators, traversal patterns and anything
        # that is not already a plain file name.
        if "/" in filename or "\\" in filename or os.path.sep in filename:
            return web.json_response(
                {"error": "Invalid filename: must not contain path separators"},
                status=400
            )

        if ".." in filename or filename.startswith("/") or filename.startswith("~"):
            return web.json_response(
                {"error": "Invalid filename: path traversal patterns detected"},
                status=400
            )

        safe_filename = os.path.basename(filename)
        if safe_filename != filename:
            return web.json_response(
                {"error": "Invalid filename: must be a simple filename without path components"},
                status=400
            )

        # The first element of the registry entry is the list of directories;
        # an empty list would otherwise raise IndexError (-> 500).
        directories = folder_paths.folder_names_and_paths[save_path][0]
        if not directories:
            return web.json_response(
                {"error": f"No directory configured for save_path: {save_path}"},
                status=400
            )

        output_dir = os.path.abspath(directories[0])
        output_path = os.path.abspath(os.path.join(output_dir, safe_filename))
        if not output_path.startswith(output_dir + os.sep):
            return web.json_response(
                {"error": "Security error: attempted directory escape"},
                status=400
            )

        if os.path.exists(output_path):
            return web.json_response(
                {"error": f"File already exists: {output_path}"},
                status=400
            )

        download_id = f"{save_path}/{safe_filename}"

        # The file only appears on disk once its download actually starts, so
        # a repeat request for a still-queued download passes the check above
        # and would be queued (and downloaded) twice.
        existing = active_downloads.get(download_id)
        if existing is not None and existing.get("status") in ("queued", "downloading", "paused"):
            return web.json_response(
                {"error": f"Download already queued or in progress: {download_id}"},
                status=409
            )

        os.makedirs(os.path.dirname(output_path), exist_ok=True)

        active_downloads[download_id] = {
            "url": url,
            "filename": safe_filename,
            "save_path": save_path,
            "output_path": output_path,
            "progress": 0,
            "status": "queued",
            "priority": None
        }

        download_queue.append({
            "download_id": download_id,
            "url": url,
            "output_path": output_path
        })

        # Kick the queue; it is a no-op if a download is already running.
        asyncio.create_task(process_download_queue())

        return web.json_response({
            "success": True,
            "download_id": download_id,
            "message": "Download queued"
        })

    except Exception as e:
        logging.error(f"Error starting download: {e}")
        return web.json_response(
            {"error": str(e)},
            status=500
        )
|
| |
|
| |
|
async def process_download_queue():
    """Start the next queued download unless one is already running.

    Downloads run strictly one at a time. The next item is popped from
    ``download_queue``, marked ``"downloading"`` in ``active_downloads`` and
    handed to ``download_file`` in a new task kept in
    ``current_download_task``. When that task finishes,
    ``on_download_complete`` is invoked, which schedules this coroutine again
    for the next item.

    This coroutine never awaits between checking ``current_download_task``
    and assigning it, so concurrent invocations cannot both start a download.
    """
    global current_download_task

    if current_download_task is not None and not current_download_task.done():
        logging.info("[RunpodDirect] Download already in progress, waiting...")
        return

    if not download_queue:
        logging.info("[RunpodDirect] Queue is empty")
        return

    download_item = download_queue.pop(0)
    download_id = download_item["download_id"]
    url = download_item["url"]
    output_path = download_item["output_path"]

    entry = active_downloads[download_id]
    entry["status"] = "downloading"
    entry["progress"] = 0
    entry["downloaded"] = 0

    logging.info(f"[RunpodDirect] Starting download {download_id} with {NUM_CONNECTIONS} connections (full speed)")

    async def _run_download():
        # The start notification is sent from inside the task, after the slot
        # in current_download_task has been claimed. A failed notification
        # must not prevent the download itself.
        try:
            await PromptServer.instance.send("server_download_progress", {
                "download_id": download_id,
                "progress": 0,
                "downloaded": 0,
                "total": 0
            })
        except Exception as e:
            logging.warning(f"[RunpodDirect] Could not send start notification for {download_id}: {e}")
        await download_file(url, output_path, download_id)

    def _on_task_done(task):
        # Retrieve and log unexpected failures (avoids "Task exception was
        # never retrieved"), then advance the queue.
        if not task.cancelled() and task.exception() is not None:
            logging.error(f"[RunpodDirect] Download task for {download_id} failed: {task.exception()}")
        on_download_complete(download_id)

    current_download_task = asyncio.create_task(_run_download())
    current_download_task.add_done_callback(_on_task_done)
|
| |
|
| |
|
def on_download_complete(download_id):
    """Done-callback for a download task: free the slot and start the next download.

    Args:
        download_id: id of the download whose task just finished (used for
            logging only).
    """
    global current_download_task

    # Done-callbacks are scheduled via the event loop, so by the time this
    # runs another process_download_queue() call may already have started a
    # newer download. Only release the slot if it holds no running task;
    # clearing a live reference would let a second download start in parallel.
    if current_download_task is None or current_download_task.done():
        current_download_task = None

    logging.info(f"[RunpodDirect] Download completed: {download_id}, processing next in queue...")

    asyncio.create_task(process_download_queue())
|
| |
|
| |
|
async def download_chunk(session, url, start, end, output_path, chunk_index, download_id):
    """Fetch the inclusive byte range ``start..end`` of *url* and write it at offset ``start``.

    Best-effort helper: failures are logged and reported by returning ``None``
    rather than raising. The whole range is buffered in memory before being
    written. ``download_file`` uses ``download_chunk_with_progress`` instead;
    this helper is not called anywhere else in this module.

    Args:
        session: HTTP client session used for the request.
        url: source URL.
        start, end: inclusive byte offsets of the range.
        output_path: existing file to patch (opened ``r+b``).
        chunk_index: index used in log messages.
        download_id: id used in log messages.

    Returns:
        Number of bytes written, or ``None`` if the request failed or the
        server did not honor the range.
    """
    headers = {'Range': f'bytes={start}-{end}'}

    try:
        async with session.get(url, headers=headers) as response:
            # Only 206 Partial Content guarantees the body is the requested
            # range. A 200 carries the whole file, which written at `start`
            # would corrupt the output.
            if response.status != 206:
                return None

            chunk_data = await response.read()

        with open(output_path, 'r+b') as f:
            f.seek(start)
            f.write(chunk_data)

        return len(chunk_data)
    except Exception as e:
        logging.error(f"Error downloading chunk {chunk_index} for {download_id}: {e}")
        return None
|
| |
|
| |
|
async def _probe_remote_file(session, url, download_id):
    """Discover the remote file size and whether byte-range requests work.

    Tries a HEAD request first, then falls back to a one-byte ranged GET.

    Returns:
        ``(total_size, supports_range)``; ``total_size`` is 0 when the size
        could not be determined.
    """
    total_size = 0
    supports_range = False

    try:
        async with session.head(url, allow_redirects=True) as response:
            if response.status == 200:
                total_size = int(response.headers.get('content-length', 0))
                supports_range = response.headers.get('accept-ranges', '').strip().lower() == 'bytes'
    except Exception as e:
        logging.warning(f"HEAD request failed for {download_id}: {e}")

    if total_size == 0:
        logging.info(f"HEAD request didn't return size, trying GET with Range for {download_id}")
        try:
            headers = {'Range': 'bytes=0-0'}
            async with session.get(url, headers=headers, allow_redirects=True) as response:
                if response.status == 206:
                    # Content-Range looks like "bytes 0-0/<total>"; the total
                    # may be "*" (unknown). Content-Length of a 206 is only the
                    # 1-byte slice, so it must not be used as the file size.
                    _, sep, size = response.headers.get('content-range', '').rpartition('/')
                    if sep and size.strip().isdigit():
                        total_size = int(size)
                        supports_range = True
                elif response.status == 200:
                    # Range was ignored: Content-Length describes the full file.
                    total_size = int(response.headers.get('content-length', 0))
        except Exception as e:
            logging.warning(f"GET with Range failed for {download_id}: {e}")

    return total_size, supports_range


async def download_file(url, output_path, download_id):
    """Download *url* to *output_path*, using parallel range requests when possible.

    Registers a ``download_control`` entry (paused/cancelled flags, shared
    byte counter and lock) that the chunk workers and the pause/resume/cancel
    handlers use. Files larger than ``CHUNK_SIZE`` on servers that support
    ranges are fetched with ``NUM_CONNECTIONS`` parallel range requests into a
    pre-allocated file; otherwise a single streaming connection is used.

    Outcomes:
        success: status "completed", progress 100, ``server_download_complete`` sent.
        cancelled: the partial file is removed; status is left as set by the
            cancel handler.
        failure: the partial file is removed, status "error" with the message,
            ``server_download_error`` sent.

    The ``download_control`` entry is removed when this coroutine returns.
    """
    import aiohttp

    logging.info(f"[RunpodDirect] Download {download_id} using {NUM_CONNECTIONS} connections (full speed)")

    control = {
        "paused": False,
        "cancelled": False,
        "total_downloaded": 0,
        "lock": asyncio.Lock()
    }
    download_control[download_id] = control
    # True while output_path holds an incomplete file created by this call.
    partial_file = False

    def _discard_partial_file():
        # Leaving a pre-allocated, partly-written file behind would block
        # retries ("File already exists") and look like a valid model file.
        if partial_file:
            try:
                os.remove(output_path)
            except OSError as err:
                logging.warning(f"Could not remove partial file {output_path}: {err}")

    try:
        timeout = aiohttp.ClientTimeout(total=None)
        async with aiohttp.ClientSession(timeout=timeout) as session:
            total_size, supports_range = await _probe_remote_file(session, url, download_id)

            if total_size == 0:
                raise Exception("Could not determine file size from server")

            logging.info(f"File size for {download_id}: {total_size} bytes, supports range: {supports_range}")

            # Pre-allocate the full size so each range worker can write at its
            # own offset.
            partial_file = True
            with open(output_path, 'wb') as f:
                f.seek(total_size - 1)
                f.write(b'\0')

            active_downloads[download_id]["total"] = total_size
            active_downloads[download_id]["downloaded"] = 0

            if supports_range and total_size > CHUNK_SIZE:
                logging.info(f"Using {NUM_CONNECTIONS} connections for {download_id}")

                part_size = total_size // NUM_CONNECTIONS
                tasks = []
                for i in range(NUM_CONNECTIONS):
                    start = i * part_size
                    # The last range absorbs the remainder of the division.
                    end = total_size - 1 if i == NUM_CONNECTIONS - 1 else start + part_size - 1
                    tasks.append(download_chunk_with_progress(
                        session, url, start, end, output_path, i, download_id, total_size
                    ))

                results = await asyncio.gather(*tasks, return_exceptions=True)

                for index, result in enumerate(results):
                    if isinstance(result, Exception):
                        raise result
                    # e.g. CancelledError: not an Exception subclass, but the
                    # range was still not written.
                    if isinstance(result, BaseException):
                        raise Exception(f"Chunk {index} for {download_id} did not finish: {result!r}")
            else:
                logging.info(f"Using single connection for {download_id}")
                await download_single_connection(session, url, output_path, download_id, total_size)

        if control["cancelled"]:
            _discard_partial_file()
            return

        # The file is complete from here on; never delete it.
        partial_file = False

        active_downloads[download_id]["status"] = "completed"
        active_downloads[download_id]["progress"] = 100

        await PromptServer.instance.send("server_download_complete", {
            "download_id": download_id,
            "path": output_path,
            "size": total_size
        })

        logging.info(f"Successfully downloaded {download_id} to {output_path}")

    except Exception as e:
        _discard_partial_file()

        if control["cancelled"]:
            # A failure after cancellation is not an error; keep the
            # "cancelled" status set by the cancel handler.
            logging.info(f"Download {download_id} cancelled ({e})")
            return

        logging.error(f"Error downloading {download_id}: {e}")
        active_downloads[download_id]["status"] = "error"
        active_downloads[download_id]["error"] = str(e)

        await PromptServer.instance.send("server_download_error", {
            "download_id": download_id,
            "error": str(e)
        })

    finally:
        # Only drop our own entry (never one registered by a later run).
        if download_control.get(download_id) is control:
            del download_control[download_id]
|
| |
|
| |
|
async def download_chunk_with_progress(session, url, start, end, output_path, chunk_index, download_id, total_size):
    """Download the inclusive byte range ``start..end`` of *url* into *output_path*.

    The file must already exist (it is opened ``r+b``); bytes are written
    starting at offset ``start``. Received bytes are added to
    ``download_control[download_id]["total_downloaded"]``. A
    ``server_download_progress`` event is sent at most every 0.1 s per
    download, shared across all of its range workers.

    Honors the ``paused`` / ``cancelled`` flags in ``download_control``. A
    cancel, including one issued while paused, makes this return early
    without raising.

    Raises:
        Exception: if the server does not answer 206 Partial Content, or on
            any network/file error (logged, then re-raised).
    """
    import time

    def _flag(name):
        return download_control.get(download_id, {}).get(name, False)

    headers = {'Range': f'bytes={start}-{end}'}

    try:
        async with session.get(url, headers=headers) as response:
            # Only 206 guarantees the body is the requested range. A 200 is
            # the whole file, which written at `start` would corrupt the output.
            if response.status != 206:
                raise Exception(f"HTTP {response.status} for chunk {chunk_index}")

            with open(output_path, 'r+b') as f:
                f.seek(start)

                async for chunk in response.content.iter_chunked(CHUNK_SIZE):
                    # A cancel must also end a pause; otherwise a download
                    # cancelled while paused would wait here forever.
                    while _flag("paused") and not _flag("cancelled"):
                        await asyncio.sleep(0.5)

                    if _flag("cancelled"):
                        return

                    f.write(chunk)

                    control = download_control[download_id]
                    async with control["lock"]:
                        control["total_downloaded"] += len(chunk)
                        total_downloaded = control["total_downloaded"]

                    # Any worker may report, so updates continue after the
                    # first range finishes; the shared timestamp throttles them.
                    now = time.monotonic()
                    last_report = control.get("last_progress_report")
                    if last_report is None or now - last_report >= 0.1:
                        control["last_progress_report"] = now
                        progress = (total_downloaded / total_size) * 100
                        active_downloads[download_id]["progress"] = progress
                        active_downloads[download_id]["downloaded"] = total_downloaded

                        await PromptServer.instance.send("server_download_progress", {
                            "download_id": download_id,
                            "progress": progress,
                            "downloaded": total_downloaded,
                            "total": total_size
                        })

    except Exception as e:
        logging.error(f"Error in chunk {chunk_index} for {download_id}: {e}")
        raise
|
| |
|
| |
|
async def download_single_connection(session, url, output_path, download_id, total_size):
    """Fallback: stream *url* into *output_path* over a single connection.

    Truncates/creates *output_path* and writes the body sequentially. The
    ``active_downloads`` record is updated for every block received, while
    ``server_download_progress`` events are throttled to one per 0.1 s.
    Honors the ``paused`` / ``cancelled`` flags in ``download_control``. A
    cancel, including one issued while paused, returns early without raising.

    Raises:
        Exception: if the response status is not 200.
    """
    import time

    def _flag(name):
        return download_control.get(download_id, {}).get(name, False)

    downloaded_size = 0
    last_report = None

    async with session.get(url) as response:
        if response.status != 200:
            raise Exception(f"HTTP {response.status}")

        with open(output_path, 'wb') as f:
            async for chunk in response.content.iter_chunked(CHUNK_SIZE):
                # A cancel must also end a pause; otherwise a download
                # cancelled while paused would wait here forever.
                while _flag("paused") and not _flag("cancelled"):
                    await asyncio.sleep(0.5)

                if _flag("cancelled"):
                    return

                f.write(chunk)
                downloaded_size += len(chunk)

                # Guard against an unknown (0) size instead of dividing by zero.
                progress = (downloaded_size / total_size) * 100 if total_size else 0
                active_downloads[download_id]["progress"] = progress
                active_downloads[download_id]["downloaded"] = downloaded_size

                # Reads can be small; don't flood clients with one message each.
                now = time.monotonic()
                if last_report is None or now - last_report >= 0.1:
                    last_report = now
                    await PromptServer.instance.send("server_download_progress", {
                        "download_id": download_id,
                        "progress": progress,
                        "downloaded": downloaded_size,
                        "total": total_size
                    })
|
| |
|
| |
|
@PromptServer.instance.routes.get("/server_download/status")
async def get_download_status(request):
    """Report every download this server has recorded.

    Responds with the whole ``active_downloads`` mapping (download id ->
    status record) as JSON. Entries are never removed, so queued, running,
    paused, completed, failed and cancelled downloads are all included.
    """
    snapshot = active_downloads
    return web.json_response(snapshot)
|
| |
|
| |
|
@PromptServer.instance.routes.get("/server_download/status/{download_id:.*}")
async def get_single_download_status(request):
    """Report the status record of one download.

    The id is the rest of the URL path after ``/status/`` (it contains a
    ``/``, e.g. ``<save_path>/<filename>``). Responds 404 with an error
    object when the id is unknown.
    """
    wanted_id = request.match_info.get("download_id", "")

    if wanted_id not in active_downloads:
        return web.json_response(
            {"error": "Download not found"},
            status=404
        )

    record = active_downloads[wanted_id]
    return web.json_response(record)
|
| |
|
| |
|
@PromptServer.instance.routes.post("/server_download/pause")
async def pause_download(request):
    """Pause a running download.

    Expects a JSON body ``{"download_id": ...}``. Sets the ``paused`` flag
    polled by the download workers, marks the record ``"paused"`` and
    broadcasts ``server_download_paused``.

    Returns 400 without an id, 404 if the download is not running, 409 if it
    has already been cancelled, and 500 on unexpected errors.
    """
    try:
        json_data = await request.json()
        download_id = json_data.get("download_id")

        if not download_id:
            return web.json_response(
                {"error": "Missing download_id"},
                status=400
            )

        control = download_control.get(download_id)
        if control is None:
            return web.json_response(
                {"error": "Download not found or already completed"},
                status=404
            )

        # The control entry survives a cancel until the worker notices it;
        # pausing then would overwrite the "cancelled" status.
        if control.get("cancelled"):
            return web.json_response(
                {"error": "Download has been cancelled"},
                status=409
            )

        control["paused"] = True
        active_downloads[download_id]["status"] = "paused"

        await PromptServer.instance.send("server_download_paused", {
            "download_id": download_id
        })

        return web.json_response({"success": True, "message": "Download paused"})

    except Exception as e:
        return web.json_response({"error": str(e)}, status=500)
|
| |
|
| |
|
@PromptServer.instance.routes.post("/server_download/resume")
async def resume_download(request):
    """Resume a paused download.

    Expects a JSON body ``{"download_id": ...}``. Clears the ``paused`` flag
    polled by the download workers, marks the record ``"downloading"`` and
    broadcasts ``server_download_resumed``.

    Returns 400 without an id, 404 if the download is not running, 409 if it
    has already been cancelled, and 500 on unexpected errors.
    """
    try:
        json_data = await request.json()
        download_id = json_data.get("download_id")

        if not download_id:
            return web.json_response(
                {"error": "Missing download_id"},
                status=400
            )

        control = download_control.get(download_id)
        if control is None:
            return web.json_response(
                {"error": "Download not found or already completed"},
                status=404
            )

        # Resuming a cancelled download would overwrite the "cancelled"
        # status with "downloading" while the worker is shutting down.
        if control.get("cancelled"):
            return web.json_response(
                {"error": "Download has been cancelled"},
                status=409
            )

        control["paused"] = False
        active_downloads[download_id]["status"] = "downloading"

        await PromptServer.instance.send("server_download_resumed", {
            "download_id": download_id
        })

        return web.json_response({"success": True, "message": "Download resumed"})

    except Exception as e:
        return web.json_response({"error": str(e)}, status=500)
|
| |
|
| |
|
@PromptServer.instance.routes.post("/server_download/cancel")
async def cancel_download(request):
    """Cancel a queued or running download.

    Expects a JSON body ``{"download_id": ...}``.
    - A queued download is removed from ``download_queue``.
    - A running download gets its ``cancelled`` flag set (and ``paused``
      cleared, so a paused worker can observe the cancel); the worker then
      stops and removes the partial file.
    - The record is marked ``"cancelled"`` and ``server_download_cancelled``
      is broadcast.

    Returns 400 without an id, 404 if the download is neither queued nor
    running (consistent with pause/resume), and 500 on unexpected errors.
    """
    try:
        json_data = await request.json()
        download_id = json_data.get("download_id")

        if not download_id:
            return web.json_response(
                {"error": "Missing download_id"},
                status=400
            )

        # Slice assignment mutates the shared list in place.
        queued_before = len(download_queue)
        download_queue[:] = [d for d in download_queue if d["download_id"] != download_id]
        was_queued = len(download_queue) != queued_before

        control = download_control.get(download_id)
        if control is not None:
            control["cancelled"] = True
            # Workers only re-check "cancelled" once they leave the pause
            # wait, so a paused download would otherwise never stop.
            control["paused"] = False

        # Don't relabel completed/failed downloads as cancelled.
        if not was_queued and control is None:
            return web.json_response(
                {"error": "Download not found or already completed"},
                status=404
            )

        if download_id in active_downloads:
            active_downloads[download_id]["status"] = "cancelled"

        await PromptServer.instance.send("server_download_cancelled", {
            "download_id": download_id
        })

        return web.json_response({"success": True, "message": "Download cancelled"})

    except Exception as e:
        return web.json_response({"error": str(e)}, status=500)
|
| |
|
| |
|
@PromptServer.instance.routes.get("/extensions/ComfyUI-RunpodDirect/serverDownload.js")
async def serve_js_with_version(request):
    """Serve ``web/serverDownload.js`` with headers that disable browser caching.

    The package version is exposed in an ``X-Version`` response header.
    """
    package_dir = os.path.dirname(__file__)
    script_file = os.path.join(package_dir, "web", "serverDownload.js")

    response = web.FileResponse(script_file)

    no_cache_headers = (
        ('Cache-Control', 'no-cache, must-revalidate'),
        ('Pragma', 'no-cache'),
        ('Expires', '0'),
        ('X-Version', __version__),
    )
    for header_name, header_value in no_cache_headers:
        response.headers[header_name] = header_value

    return response
|
| |
|
| |
|
| |
|
# Frontend asset directory, relative to this package (presumably picked up
# by ComfyUI's extension loader — confirm against the host version).
WEB_DIRECTORY = "./web"

# Package version; also sent as the X-Version header of the JS route.
__version__ = "1.0.6"

# This package registers HTTP routes only; it defines no custom nodes.
NODE_CLASS_MAPPINGS = dict()
NODE_DISPLAY_NAME_MAPPINGS = dict()

__all__ = [
    "NODE_CLASS_MAPPINGS",
    "NODE_DISPLAY_NAME_MAPPINGS",
    "WEB_DIRECTORY",
]
|
| |
|