highscoregames12018's picture
Add/update custom_nodes
ae99871 verified
"""
ComfyUI RunpodDirect - Direct Model Downloads for RunPod
Download models directly to your RunPod instance with multi-connection support
"""
import os
import logging
import asyncio
import folder_paths
from aiohttp import web
from server import PromptServer
# ---------------------------------------------------------------------------
# Shared download state (module globals used by the route handlers below)
# ---------------------------------------------------------------------------

# download_id -> status dict exposed to the frontend
# (url, filename, save_path, output_path, progress, status, priority, ...).
active_downloads = {}

# download_id -> control flags for a running download: "paused",
# "cancelled", a shared "total_downloaded" byte counter and the
# asyncio.Lock guarding it.
download_control = {}

# FIFO of downloads waiting to start; items carry download_id, url and
# output_path.
download_queue = []

# asyncio.Task of the single download currently running, or None.
current_download_task = None

# ---------------------------------------------------------------------------
# Transfer tuning for datacenter links (RunPod)
# ---------------------------------------------------------------------------

# 32 MiB: maximum size of each streamed read, and the minimum file size
# above which a multi-connection download is attempted. Balanced for
# files from ~500MB to 30GB+.
CHUNK_SIZE = 32 << 20

# Number of parallel ranged connections used per download.
NUM_CONNECTIONS = 8
@PromptServer.instance.routes.post("/server_download/start")
async def start_download(request):
    """Validate a download request and add it to the download queue.

    Expects a JSON object body with:
        url: remote URL to fetch.
        save_path: a key of ``folder_paths.folder_names_and_paths``
            (e.g. "checkpoints"); the file is written into the first folder
            registered for that key.
        filename: a bare file name (e.g. "model.safetensors"); anything that
            could address another directory is rejected.

    Responses:
        200 ``{"success": True, "download_id": "<save_path>/<filename>",
            "message": "Download queued"}`` once the download is queued.
        400 ``{"error": ...}`` for a non-object body, missing or non-string
            parameters, an unknown save_path, an unsafe filename, an existing
            target file, or a download with the same id that is still
            queued, downloading or paused.
        500 ``{"error": ...}`` for unexpected failures (logged with traceback).
    """
    # Statuses meaning "this download id is still being handled".
    in_flight_statuses = ("queued", "downloading", "paused")
    try:
        json_data = await request.json()
        if not isinstance(json_data, dict):
            return web.json_response(
                {"error": "Request body must be a JSON object"},
                status=400
            )

        url = json_data.get("url")
        save_path = json_data.get("save_path")  # e.g., "checkpoints"
        filename = json_data.get("filename")  # e.g., "model.safetensors"

        if not url or not save_path or not filename:
            return web.json_response(
                {"error": "Missing required parameters: url, save_path, filename"},
                status=400
            )
        # Non-string values would otherwise surface as confusing 500 errors
        # from the membership and path checks below.
        if not all(isinstance(value, str) for value in (url, save_path, filename)):
            return web.json_response(
                {"error": "Invalid parameters: url, save_path and filename must be strings"},
                status=400
            )

        # Validate save_path
        if save_path not in folder_paths.folder_names_and_paths:
            return web.json_response(
                {"error": f"Invalid save_path: {save_path}. Must be one of: {list(folder_paths.folder_names_and_paths.keys())}"},
                status=400
            )

        # Security: Validate filename to prevent path traversal attacks
        # Check for any directory separators (both Unix and Windows style)
        if "/" in filename or "\\" in filename or os.path.sep in filename:
            return web.json_response(
                {"error": "Invalid filename: must not contain path separators"},
                status=400
            )

        # Additional check for various path traversal patterns
        if ".." in filename or filename.startswith("/") or filename.startswith("~"):
            return web.json_response(
                {"error": "Invalid filename: path traversal patterns detected"},
                status=400
            )

        # Normalize the filename to remove any potential tricks
        safe_filename = os.path.basename(filename)
        if safe_filename != filename:
            return web.json_response(
                {"error": "Invalid filename: must be a simple filename without path components"},
                status=400
            )

        # Get the first folder path for this model type
        output_dir = folder_paths.folder_names_and_paths[save_path][0][0]
        output_path = os.path.join(output_dir, safe_filename)

        # Final security check: ensure the resolved path is within the intended directory
        output_path = os.path.abspath(output_path)
        output_dir = os.path.abspath(output_dir)
        if not output_path.startswith(output_dir + os.sep):
            return web.json_response(
                {"error": "Security error: attempted directory escape"},
                status=400
            )

        download_id = f"{save_path}/{safe_filename}"

        # Reject a duplicate of a download that is still queued, running or
        # paused. Its file may not exist yet (it is created only when the
        # transfer starts), so the existence check below would not catch it,
        # and the duplicate would later re-open the finished file with 'wb'
        # and download it all over again.
        existing = active_downloads.get(download_id)
        if existing is not None and existing.get("status") in in_flight_statuses:
            return web.json_response(
                {"error": f"Download already in progress: {download_id}"},
                status=400
            )

        # Check if file already exists
        if os.path.exists(output_path):
            return web.json_response(
                {"error": f"File already exists: {output_path}"},
                status=400
            )

        # Create directory if it doesn't exist
        os.makedirs(os.path.dirname(output_path), exist_ok=True)

        # Mark as queued
        active_downloads[download_id] = {
            "url": url,
            "filename": safe_filename,
            "save_path": save_path,
            "output_path": output_path,
            "progress": 0,
            "status": "queued",
            "priority": None
        }

        # Add to queue
        download_queue.append({
            "download_id": download_id,
            "url": url,
            "output_path": output_path
        })

        # Process queue (starts this download now if no other one is running)
        asyncio.create_task(process_download_queue())

        return web.json_response({
            "success": True,
            "download_id": download_id,
            "message": "Download queued"
        })
    except Exception as e:
        logging.exception(f"Error starting download: {e}")
        return web.json_response(
            {"error": str(e)},
            status=500
        )
async def process_download_queue():
    """Start the next queued download unless one is already running.

    Downloads run strictly one at a time. The running task is stored in the
    module-level ``current_download_task``. Its done-callback
    (``on_download_complete``) schedules this coroutine again so the queue
    keeps draining. If a download is active, or the queue is empty, this
    only logs and returns.
    """
    global current_download_task

    # Check if already downloading
    if current_download_task is not None and not current_download_task.done():
        logging.info("[RunpodDirect] Download already in progress, waiting...")
        return  # Already downloading

    if not download_queue:
        logging.info("[RunpodDirect] Queue is empty")
        return  # Nothing to process

    # Get next download from queue (FIFO)
    download_item = download_queue.pop(0)
    download_id = download_item["download_id"]
    url = download_item["url"]
    output_path = download_item["output_path"]

    # Set status to downloading
    entry = active_downloads[download_id]
    entry["status"] = "downloading"
    entry["progress"] = 0
    entry["downloaded"] = 0

    logging.info(f"[RunpodDirect] Starting download {download_id} with {NUM_CONNECTIONS} connections (full speed)")

    async def _notify_then_download():
        # Tell the frontend the download is starting before any data moves.
        # A failed notification is logged but does not abort the download.
        try:
            await PromptServer.instance.send("server_download_progress", {
                "download_id": download_id,
                "progress": 0,
                "downloaded": 0,
                "total": 0
            })
        except Exception as e:
            logging.warning(f"[RunpodDirect] Could not send start notification for {download_id}: {e}")
        # A cancel that arrived after the item left the queue, but before
        # download_file registered its control flags, only changed the status.
        # Honour it here instead of downloading anyway.
        if active_downloads.get(download_id, {}).get("status") == "cancelled":
            logging.info(f"[RunpodDirect] Download {download_id} was cancelled before it started")
            return
        await download_file(url, output_path, download_id)

    # Claim the single download slot synchronously, with no await between the
    # "already downloading?" check above and this assignment. Otherwise two
    # overlapping runs of this coroutine (start_download schedules one per
    # request) could both pass the check and start concurrent downloads.
    current_download_task = asyncio.create_task(_notify_then_download())

    # Add completion callback to process next in queue
    current_download_task.add_done_callback(lambda t: on_download_complete(download_id))
def on_download_complete(download_id):
    """Done-callback attached to the running download task.

    Frees the single download slot by resetting ``current_download_task``
    to None, logs the finished id, and schedules ``process_download_queue()``
    so the next queued download (if any) starts.
    """
    global current_download_task
    current_download_task = None

    logging.info(
        f"[RunpodDirect] Download completed: {download_id}, processing next in queue..."
    )

    next_in_line = process_download_queue()
    asyncio.create_task(next_in_line)
async def download_chunk(session, url, start, end, output_path, chunk_index, download_id):
    """Fetch the inclusive byte range ``start..end`` of ``url`` and store it.

    The whole response body is buffered in memory and then written into the
    already-existing file at ``output_path`` (opened 'r+b') at offset
    ``start``.

    Returns:
        Number of bytes written, or None when the server answers with a
        status other than 200/206 or when any exception occurs (exceptions
        are logged, never raised).
    """
    range_header = {'Range': f'bytes={start}-{end}'}
    try:
        async with session.get(url, headers=range_header) as resp:
            if resp.status in (200, 206):
                payload = await resp.read()
                with open(output_path, 'r+b') as fh:
                    fh.seek(start)
                    fh.write(payload)
                return len(payload)
            return None
    except Exception as err:
        logging.error(f"Error downloading chunk {chunk_index} for {download_id}: {err}")
        return None
def _discard_partial_file(output_path, download_id):
    """Best-effort removal of an unfinished download file; never raises."""
    try:
        os.remove(output_path)
    except FileNotFoundError:
        pass
    except OSError as e:
        logging.warning(f"Could not remove partial file {output_path} for {download_id}: {e}")


async def _probe_remote_size(session, url, download_id):
    """Return ``(total_size, supports_range)`` for ``url``.

    Tries a HEAD request first, then a ``Range: bytes=0-0`` GET. A size of 0
    means the size could not be determined. Request failures are logged as
    warnings, not raised.
    """
    total_size = 0
    supports_range = False

    # Try HEAD request first
    try:
        async with session.head(url, allow_redirects=True) as response:
            if response.status == 200:
                total_size = int(response.headers.get('content-length', 0))
                supports_range = response.headers.get('accept-ranges') == 'bytes'
    except Exception as e:
        logging.warning(f"HEAD request failed for {download_id}: {e}")

    # If HEAD didn't give us the size, try GET with Range header
    if total_size == 0:
        logging.info(f"HEAD request didn't return size, trying GET with Range for {download_id}")
        try:
            headers = {'Range': 'bytes=0-0'}
            async with session.get(url, headers=headers, allow_redirects=True) as response:
                if response.status in (200, 206):
                    # Format: "bytes 0-0/12345" where 12345 is total size
                    content_range = response.headers.get('content-range', '')
                    if content_range:
                        parts = content_range.split('/')
                        if len(parts) == 2:
                            total_size = int(parts[1])
                            supports_range = True
                    # For a 206 reply, Content-Length is the length of the
                    # partial body (1 byte here), not the file size. Only a
                    # full 200 reply's Content-Length is the file size.
                    if total_size == 0 and response.status == 200:
                        total_size = int(response.headers.get('content-length', 0))
        except Exception as e:
            logging.warning(f"GET with Range failed for {download_id}: {e}")

    return total_size, supports_range


async def download_file(url, output_path, download_id):
    """Download ``url`` to ``output_path`` and report progress to the frontend.

    Steps:
      1. Probe the remote size and range support (HEAD, then a 1-byte
         ranged GET).
      2. Preallocate the output file at full size.
      3. If ranges are supported and the file is larger than CHUNK_SIZE,
         fetch it over NUM_CONNECTIONS parallel ranged requests
         (``download_chunk_with_progress``). Otherwise stream it over a
         single GET (``download_single_connection``).

    Outcomes:
      * success: status "completed" and a "server_download_complete"
        message is sent.
      * cancelled: the partial file is removed and nothing else is sent
        (the cancel endpoint already updated the status).
      * error: status "error" and a "server_download_error" message is sent.
        The partially written file is removed, so a zero-padded, full-size
        file is not left looking like a finished model (it would also make
        a retry fail with "File already exists").

    ``download_control[download_id]`` is removed on every exit path.
    Errors are logged and reported rather than raised.
    """
    import aiohttp

    logging.info(f"[RunpodDirect] Download {download_id} using {NUM_CONNECTIONS} connections (full speed)")

    # True while output_path holds an unfinished download that must be
    # discarded if something goes wrong.
    partial_file = False
    try:
        # Pause/cancel flags plus a byte counter shared by all chunk
        # coroutines; the asyncio.Lock serialises counter updates.
        download_control[download_id] = {
            "paused": False,
            "cancelled": False,
            "total_downloaded": 0,
            "lock": asyncio.Lock()
        }

        # No overall deadline: large model files (30GB+) can take a long time.
        timeout = aiohttp.ClientTimeout(total=None)
        async with aiohttp.ClientSession(timeout=timeout) as session:
            total_size, supports_range = await _probe_remote_size(session, url, download_id)
            if total_size == 0:
                raise Exception("Could not determine file size from server")

            logging.info(f"File size for {download_id}: {total_size} bytes, supports range: {supports_range}")

            # Preallocate the full size so each connection can write its
            # range at its own offset.
            partial_file = True
            with open(output_path, 'wb') as f:
                f.seek(total_size - 1)
                f.write(b'\0')

            active_downloads[download_id]["total"] = total_size
            active_downloads[download_id]["downloaded"] = 0

            # Use multi-connection download if server supports range requests
            if supports_range and total_size > CHUNK_SIZE:
                logging.info(f"Using {NUM_CONNECTIONS} connections for {download_id}")

                part_size = total_size // NUM_CONNECTIONS
                tasks = []
                for i in range(NUM_CONNECTIONS):
                    start = i * part_size
                    # The last connection also takes the leftover bytes.
                    end = start + part_size - 1 if i < NUM_CONNECTIONS - 1 else total_size - 1
                    tasks.append(download_chunk_with_progress(
                        session, url, start, end, output_path, i, download_id, total_size
                    ))

                # Download all chunks in parallel
                results = await asyncio.gather(*tasks, return_exceptions=True)
                for result in results:
                    # BaseException so a cancelled chunk is not mistaken for success.
                    if isinstance(result, BaseException):
                        raise result
            else:
                # Fallback to single connection download
                logging.info(f"Using single connection for {download_id}")
                await download_single_connection(session, url, output_path, download_id, total_size)

        # Check if cancelled
        if download_control[download_id]["cancelled"]:
            _discard_partial_file(output_path, download_id)
            return

        # The file is complete from here on; never delete it.
        partial_file = False

        # Mark as complete
        entry = active_downloads[download_id]
        entry["status"] = "completed"
        entry["progress"] = 100
        entry["downloaded"] = total_size

        # Send completion message
        await PromptServer.instance.send("server_download_complete", {
            "download_id": download_id,
            "path": output_path,
            "size": total_size
        })

        logging.info(f"Successfully downloaded {download_id} to {output_path}")
    except Exception as e:
        logging.error(f"Error downloading {download_id}: {e}")
        if partial_file:
            _discard_partial_file(output_path, download_id)
        active_downloads[download_id]["status"] = "error"
        active_downloads[download_id]["error"] = str(e)

        await PromptServer.instance.send("server_download_error", {
            "download_id": download_id,
            "error": str(e)
        })
    finally:
        # Cleanup on every exit path (success, cancel, error).
        download_control.pop(download_id, None)
async def download_chunk_with_progress(session, url, start, end, output_path, chunk_index, download_id, total_size):
    """Stream the inclusive byte range ``start..end`` of ``url`` into ``output_path``.

    Used for multi-connection downloads. Each call writes its range into the
    preallocated file at offset ``start``. It honours the ``paused`` and
    ``cancelled`` flags in ``download_control[download_id]`` and adds to that
    entry's shared ``total_downloaded`` counter under its lock.

    Progress (``active_downloads`` plus a "server_download_progress"
    message) is published at most every 100 ms per download. Any chunk may
    publish it; a timestamp shared through the control entry throttles the
    updates.

    Returns:
        None. Also returns early, without error, when the download is
        cancelled.

    Raises:
        Exception: if the status is not 206 Partial Content. A 200 reply is
            accepted only when the range spans the whole file, because a 200
            for a partial range means the server ignored ``Range`` and is
            sending the full file, which would corrupt the output. Also raised
            if the server sends more or fewer bytes than the range requested.
    """
    import time

    headers = {'Range': f'bytes={start}-{end}'}
    expected_len = end - start + 1
    downloaded = 0
    covers_whole_file = start == 0 and end == total_size - 1

    def _flag(name):
        return download_control.get(download_id, {}).get(name, False)

    try:
        async with session.get(url, headers=headers) as response:
            if response.status == 200 and not covers_whole_file:
                raise Exception(
                    f"Server ignored Range request (HTTP 200) for chunk {chunk_index}"
                )
            if response.status not in (200, 206):
                raise Exception(f"HTTP {response.status} for chunk {chunk_index}")

            with open(output_path, 'r+b') as f:
                f.seek(start)

                async for chunk in response.content.iter_chunked(CHUNK_SIZE):
                    # Wait while paused, but stop waiting once cancelled.
                    # Otherwise a paused download could never be cancelled
                    # and would block the download queue forever.
                    while _flag("paused") and not _flag("cancelled"):
                        await asyncio.sleep(0.5)

                    if _flag("cancelled"):
                        return

                    chunk_len = len(chunk)
                    if downloaded + chunk_len > expected_len:
                        # Writing past `end` would overwrite the next range.
                        raise Exception(
                            f"Chunk {chunk_index} received more data than requested "
                            f"({downloaded + chunk_len} > {expected_len} bytes)"
                        )
                    f.write(chunk)
                    downloaded += chunk_len

                    # Update the shared counter and claim the progress-report
                    # slot atomically across all chunks of this download.
                    control = download_control[download_id]
                    async with control["lock"]:
                        control["total_downloaded"] += chunk_len
                        total_downloaded = control["total_downloaded"]
                        now = time.monotonic()
                        should_report = now - control.get("last_report_time", float("-inf")) >= 0.1
                        if should_report:
                            control["last_report_time"] = now

                    if should_report:
                        progress = (total_downloaded / total_size) * 100
                        active_downloads[download_id]["progress"] = progress
                        active_downloads[download_id]["downloaded"] = total_downloaded
                        await PromptServer.instance.send("server_download_progress", {
                            "download_id": download_id,
                            "progress": progress,
                            "downloaded": total_downloaded,
                            "total": total_size
                        })

            # A short body would leave zero bytes in the file and look like
            # a successful download.
            if downloaded != expected_len:
                raise Exception(
                    f"Chunk {chunk_index} incomplete: received {downloaded} of {expected_len} bytes"
                )
    except Exception as e:
        logging.error(f"Error in chunk {chunk_index} for {download_id}: {e}")
        raise
async def download_single_connection(session, url, output_path, download_id, total_size):
    """Fallback: stream the whole file over one plain GET into ``output_path``.

    Used when the server lacks range support or the file is not larger than
    CHUNK_SIZE. The file is opened with 'wb', replacing the preallocated
    file. The ``paused`` and ``cancelled`` flags in
    ``download_control[download_id]`` are honoured. Progress
    (``active_downloads`` plus a "server_download_progress" message) is
    published at most every 100 ms, the same throttle the multi-connection
    path uses, plus once after the last byte.

    Args:
        total_size: expected size in bytes, used for the percentage; must be
            non-zero.

    Returns:
        None. Also returns early when the download is cancelled mid-stream.

    Raises:
        Exception: if the response status is not 200.
    """
    import time

    def _flag(name):
        return download_control.get(download_id, {}).get(name, False)

    async def _publish(done):
        progress = (done / total_size) * 100
        active_downloads[download_id]["progress"] = progress
        active_downloads[download_id]["downloaded"] = done
        await PromptServer.instance.send("server_download_progress", {
            "download_id": download_id,
            "progress": progress,
            "downloaded": done,
            "total": total_size
        })

    downloaded_size = 0
    last_report_time = float("-inf")

    async with session.get(url) as response:
        if response.status != 200:
            raise Exception(f"HTTP {response.status}")

        with open(output_path, 'wb') as f:
            async for chunk in response.content.iter_chunked(CHUNK_SIZE):
                # Wait while paused, but stop waiting once cancelled so a
                # paused download can still be cancelled.
                while _flag("paused") and not _flag("cancelled"):
                    await asyncio.sleep(0.5)

                if _flag("cancelled"):
                    return

                f.write(chunk)
                downloaded_size += len(chunk)

                # Throttle progress messages instead of sending one per read.
                now = time.monotonic()
                if now - last_report_time >= 0.1:
                    last_report_time = now
                    await _publish(downloaded_size)

    # Final update so the reported byte count matches what was written even
    # if the last reads fell inside the throttle window.
    await _publish(downloaded_size)
@PromptServer.instance.routes.get("/server_download/status")
async def get_download_status(request):
    """Return every known download's status dict, keyed by download_id."""
    every_download = active_downloads
    return web.json_response(every_download)
@PromptServer.instance.routes.get("/server_download/status/{download_id:.*}")
async def get_single_download_status(request):
    """Return the status dict of one download, or a 404 JSON error.

    The ``{download_id:.*}`` route pattern lets the id contain "/", which
    ids of the form "<save_path>/<filename>" do.
    """
    wanted_id = request.match_info.get("download_id", "")
    record = active_downloads.get(wanted_id)

    if record is None:
        return web.json_response({"error": "Download not found"}, status=404)

    return web.json_response(record)
@PromptServer.instance.routes.post("/server_download/pause")
async def pause_download(request):
    """Pause a running download.

    Body: JSON object with ``download_id``. Sets the download's ``paused``
    control flag (the transfer loops sleep while it is set), marks its
    status "paused" and notifies the frontend with "server_download_paused".

    Responses: 200 on success; 400 when ``download_id`` is missing; 404 when
    the id has no entry in ``download_control``; 500 with the exception text
    on any other failure.
    """
    try:
        payload = await request.json()
        download_id = payload.get("download_id")

        if not download_id:
            return web.json_response({"error": "Missing download_id"}, status=400)

        flags = download_control.get(download_id)
        if flags is None:
            return web.json_response(
                {"error": "Download not found or already completed"},
                status=404,
            )

        flags["paused"] = True
        active_downloads[download_id]["status"] = "paused"

        await PromptServer.instance.send(
            "server_download_paused", {"download_id": download_id}
        )
        return web.json_response({"success": True, "message": "Download paused"})
    except Exception as exc:
        return web.json_response({"error": str(exc)}, status=500)
@PromptServer.instance.routes.post("/server_download/resume")
async def resume_download(request):
    """Resume a paused download.

    Body: JSON object with ``download_id``. Clears the download's ``paused``
    control flag so the transfer loops continue, marks its status
    "downloading" and notifies the frontend with "server_download_resumed".

    Responses: 200 on success; 400 when ``download_id`` is missing; 404 when
    the id has no entry in ``download_control``; 500 with the exception text
    on any other failure.
    """
    try:
        payload = await request.json()
        download_id = payload.get("download_id")

        if not download_id:
            return web.json_response({"error": "Missing download_id"}, status=400)

        flags = download_control.get(download_id)
        if flags is None:
            return web.json_response(
                {"error": "Download not found or already completed"},
                status=404,
            )

        flags["paused"] = False
        active_downloads[download_id]["status"] = "downloading"

        await PromptServer.instance.send(
            "server_download_resumed", {"download_id": download_id}
        )
        return web.json_response({"success": True, "message": "Download resumed"})
    except Exception as exc:
        return web.json_response({"error": str(exc)}, status=500)
@PromptServer.instance.routes.post("/server_download/cancel")
async def cancel_download(request):
    """Cancel a queued or running download.

    Body: JSON object with ``download_id``.

    * A queued download is removed from ``download_queue``.
    * A running download gets its ``cancelled`` flag set and its ``paused``
      flag cleared, so a download that is currently paused leaves its
      pause-wait loop and sees the cancellation instead of hanging.
    * The status becomes "cancelled", except for downloads that already
      finished ("completed" or "error"), which keep their final status.
    * The frontend is notified with "server_download_cancelled".

    Responses: 200 ``{"success": True, ...}`` for a well-formed request;
    400 when ``download_id`` is missing; 500 with the exception text on any
    other failure.
    """
    try:
        json_data = await request.json()
        download_id = json_data.get("download_id")

        if not download_id:
            return web.json_response(
                {"error": "Missing download_id"},
                status=400
            )

        # Drop it from the queue if it has not started yet (slice assignment
        # mutates the shared list in place).
        download_queue[:] = [d for d in download_queue if d["download_id"] != download_id]

        # Signal an active download to stop.
        control = download_control.get(download_id)
        if control is not None:
            control["cancelled"] = True
            control["paused"] = False

        # Update status, but never relabel a finished download.
        entry = active_downloads.get(download_id)
        if entry is not None and entry.get("status") not in ("completed", "error"):
            entry["status"] = "cancelled"

        await PromptServer.instance.send("server_download_cancelled", {
            "download_id": download_id
        })

        return web.json_response({"success": True, "message": "Download cancelled"})
    except Exception as e:
        return web.json_response({"error": str(e)}, status=500)
@PromptServer.instance.routes.get("/extensions/ComfyUI-RunpodDirect/serverDownload.js")
async def serve_js_with_version(request):
    """Serve web/serverDownload.js with headers that disable browser caching.

    The response also carries an ``X-Version`` header set to the module's
    ``__version__``.
    """
    package_dir = os.path.dirname(__file__)
    script_file = os.path.join(package_dir, "web", "serverDownload.js")
    response = web.FileResponse(script_file)

    # Force the browser to revalidate instead of using a cached copy.
    cache_busting_headers = (
        ('Cache-Control', 'no-cache, must-revalidate'),
        ('Pragma', 'no-cache'),
        ('Expires', '0'),
        ('X-Version', __version__),
    )
    for header_name, header_value in cache_busting_headers:
        response.headers[header_name] = header_value

    return response
# Directory holding this extension's frontend files.
WEB_DIRECTORY = "./web"

# Extension version; increment it whenever the JS changes. It is sent as the
# X-Version header when serverDownload.js is served (cache busting).
__version__ = "1.0.6"

# No custom node classes are defined here, so both mappings stay empty.
NODE_CLASS_MAPPINGS = {}
NODE_DISPLAY_NAME_MAPPINGS = {}

__all__ = [
    "NODE_CLASS_MAPPINGS",
    "NODE_DISPLAY_NAME_MAPPINGS",
    "WEB_DIRECTORY",
]