bobinamoe's picture
Upload 1664 files
baac5bb verified
from __future__ import annotations
import asyncio
import json
import os
import shutil
import socket
import subprocess
import sys
import tempfile
import time
import uuid
import zipfile
from fnmatch import fnmatch
from pathlib import Path
from typing import Any, Union
import folder_paths
from aiohttp import web
from server import PromptServer
from comfy_pack.hash import async_batch_get_sha256
from comfy_pack.model_helper import alookup_model_source
from comfy_pack.package import build_bento
# A destination that is either a real filesystem path or a path inside an
# open zip archive; the pack writers in this module accept both.
ZPath = Union[Path, zipfile.Path]
# Scratch directory (two levels above this file) where pack_workspace writes
# its zip archives and from which download_workspace serves them.
TEMP_FOLDER = Path(__file__).parent.parent / "temp"
# The comfy_pack source tree that DevServer.start copies into its run directory.
COMFY_PACK_DIR = Path(__file__).parent.parent / "src" / "comfy_pack"
# fnmatch patterns, matched against normalized package names, of the pip
# entries that _write_snapshot removes from the snapshot.
EXCLUDE_PACKAGES = ["bentoml", "onnxruntime", "conda", "nvidia-*"]
def normalize_name(name: str) -> str:
    """Return *name* lower-cased, with every run of ``-``, ``_`` or ``.`` collapsed to one ``-``."""
    import re

    separator_runs = re.compile(r"[-_.]+")
    collapsed = separator_runs.sub("-", name)
    return collapsed.lower()
def get_snapshot_path() -> Path | None:
    """Return the ``snapshots`` directory below ``<user dir>/default/ComfyUI-Manager``.

    The path is only computed; nothing checks that it exists.
    """
    user_dir = folder_paths.get_user_directory()
    manager_dir = Path(user_dir, "default", "ComfyUI-Manager")
    return manager_dir / "snapshots"
async def _save_snapshot() -> dict[str, Any]:
    """Invoke the registered ``/snapshot/save`` route and load the newest snapshot file.

    Returns:
        The parsed JSON content of the most recently modified ``*.json`` file
        in the snapshot directory.

    Raises:
        RuntimeError: if no ``/snapshot/save`` route is registered, if the
            snapshot directory does not exist afterwards, or if it holds no
            JSON file.
    """
    snapshot_route = None
    for candidate in PromptServer.instance.routes:
        if candidate.path == "/snapshot/save":
            snapshot_route = candidate
            break
    if not snapshot_route:
        raise RuntimeError("ComfyUI-Manager must be installed to save snapshot")

    # The route handler is called directly, without a request object.
    await snapshot_route.handler(None)

    snapshot_dir = get_snapshot_path()
    if not snapshot_dir.exists():
        raise RuntimeError("Snapshot save failed")

    def _modified_at(entry):
        return entry.stat().st_mtime

    newest = max(snapshot_dir.glob("*.json"), key=_modified_at, default=None)
    if not newest:
        raise RuntimeError("Snapshot save failed")

    with newest.open("r") as snapshot_file:
        return json.load(snapshot_file)
async def _write_snapshot(path: ZPath, data: dict, models: list) -> None:
    """Write ``snapshot.json`` under *path* from a freshly saved snapshot.

    Pip entries whose normalized name (the part before ``==``) matches one of
    ``EXCLUDE_PACKAGES`` are removed, then the running Python ``major.minor``
    version and *models* are added before serializing.

    Args:
        path: Destination directory (filesystem or zip path).
        data: Request payload; not used by this function.
        models: Model descriptions stored under the ``"models"`` key.
    """
    snapshot = await _save_snapshot()
    pips = snapshot["pips"]

    def _is_excluded(spec):
        package_name = normalize_name(spec.split("==")[0])
        return any(fnmatch(package_name, pattern) for pattern in EXCLUDE_PACKAGES)

    # Decide on a copy first so entries can be deleted safely afterwards.
    for spec in [entry for entry in list(pips) if _is_excluded(entry)]:
        del pips[spec]

    with path.joinpath("snapshot.json").open("w") as out:
        snapshot["python"] = f"{sys.version_info.major}.{sys.version_info.minor}"
        snapshot["models"] = models
        out.write(json.dumps(snapshot, indent=2))
def _is_port_in_use(port: int | str, host="localhost"):
    """Probe ``host:port`` with a TCP connect and report whether it looks occupied.

    Returns ``True`` when the connection succeeds, ``False`` when it is
    refused, and ``True`` for any other connection error.
    """
    port_number = int(port) if isinstance(port, str) else port
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        try:
            probe.connect((host, port_number))
        except ConnectionRefusedError:
            return False
        except Exception:
            # Anything other than a refusal is treated as "in use".
            return True
        else:
            return True
def _is_file_refered(file_path: Path, workflow_api: dict) -> bool:
    """Tell whether *file_path* occurs in any string-valued node input of the workflow.

    The path is made absolute and expressed relative to
    ``folder_paths.base_path``. Its leading ``input`` component is dropped,
    or its first two components for anything else (models). The remainder is
    then searched as a substring of all string inputs joined by newlines.
    """
    string_inputs = {
        value
        for node in workflow_api.values()
        for value in node["inputs"].values()
        if isinstance(value, str)
    }
    haystack = "\n".join(string_inputs)

    rel_to_base = file_path.absolute().relative_to(folder_paths.base_path)
    leading_parts = 1 if rel_to_base.parts[0] == "input" else 2  # else: models
    needle = Path(*rel_to_base.parts[leading_parts:])
    return str(needle) in haystack
async def _get_models(
    store_models: bool = False,
    workflow_api: dict | None = None,
    model_filter: set[str] | None = None,
    ensure_sha=True,
    ensure_source=True,
) -> list:
    """Describe every untracked file that git reports under the models directory.

    Files come from ``git ls-files --others <models_dir>``; entries whose
    basename starts with ``.`` are skipped.

    Args:
        store_models: When true, copy each file into a BentoML model tagged
            ``cpack-model:<first 16 characters of its sha256>`` unless that
            tag already exists, and record the tag under ``"model_tag"``.
        workflow_api: When truthy, add a ``"refered"`` flag to each entry
            using ``_is_file_refered``.
        model_filter: When not ``None``, entries whose path (relative to
            ``folder_paths.base_path``) is absent from the set are marked
            ``"disabled"``.
        ensure_sha: Hashes are looked up cache-only unless this or
            *store_models* is true.
        ensure_source: Model sources are looked up cache-only unless true.

    Returns:
        A list of dicts with the keys ``filename``, ``size``, ``atime``,
        ``ctime``, ``disabled``, ``sha256`` and ``source`` (plus the optional
        keys described above).
    """
    git_proc = await asyncio.subprocess.create_subprocess_exec(
        "git",
        "ls-files",
        "--others",
        folder_paths.models_dir,
        stdout=subprocess.PIPE,
    )
    listing, _ = await git_proc.communicate()

    model_paths = []
    for line in listing.decode().splitlines():
        if os.path.basename(line).startswith("."):
            continue
        model_paths.append(os.path.abspath(line))

    sha_by_path = await async_batch_get_sha256(
        model_paths,
        cache_only=not (ensure_sha or store_models),
    )

    described = []
    for abs_path in model_paths:
        rel_name = os.path.relpath(abs_path, folder_paths.base_path)
        if model_filter is None:
            disabled = False
        else:
            disabled = rel_name not in model_filter
        entry = {
            "filename": rel_name,
            "size": os.path.getsize(abs_path),
            "atime": os.path.getatime(abs_path),
            "ctime": os.path.getctime(abs_path),
            "disabled": disabled,
            "sha256": sha_by_path.get(abs_path),
        }
        entry["source"] = await alookup_model_source(
            entry["sha256"],
            cache_only=not ensure_source,
        )

        # TODO: sort this out. The original author considered storing only
        # models whose source is not huggingface (or whose repo starts with
        # "datasets/"); for now every model is stored when requested.
        if store_models:
            import bentoml

            tag = f"cpack-model:{entry['sha256'][:16]}"
            try:
                bentoml.models.get(tag)
            except bentoml.exceptions.NotFound:
                with bentoml.models.create(
                    tag, labels={"filename": rel_name}
                ) as stored:
                    shutil.copy(abs_path, stored.path_of("model.bin"))
            entry["model_tag"] = tag

        described.append(entry)

    if workflow_api:
        for entry in described:
            entry["refered"] = _is_file_refered(Path(entry["filename"]), workflow_api)
    return described
async def _write_workflow(path: ZPath, data: dict) -> None:
    """Serialize ``data["workflow_api"]`` and ``data["workflow"]`` to JSON files under *path*.

    The files are named ``workflow_api.json`` and ``workflow.json`` and are
    written with a two-space indent.
    """
    print("Package => Writing workflow")
    targets = (
        ("workflow_api.json", "workflow_api"),
        ("workflow.json", "workflow"),
    )
    for filename, key in targets:
        with path.joinpath(filename).open("w") as out:
            out.write(json.dumps(data[key], indent=2))
async def _write_inputs(path: ZPath, data: dict) -> None:
    """Copy files from the ComfyUI input directory into ``input/`` under *path*.

    When *data* has a ``"files"`` key, only entries whose relative path occurs
    as a substring of the newline-joined selection are copied; without the
    key everything is copied. Directories are created explicitly only when
    *path* is a real filesystem ``Path``.
    """
    print("Package => Writing inputs")
    on_disk = isinstance(path, Path)
    if on_disk:
        path.joinpath("input").mkdir(exist_ok=True)

    input_dir = folder_paths.get_input_directory()
    selected = "\n".join(set(data.get("files", []))) if "files" in data else None

    src_root = Path(input_dir).absolute()
    for src in src_root.glob("**/*"):
        rel = src.relative_to(src_root)
        if selected is not None and str(rel) not in selected:
            continue
        if src.is_dir() and on_disk:
            path.joinpath("input").joinpath(rel).mkdir(parents=True, exist_ok=True)
        if src.is_file():
            destination = path.joinpath("input").joinpath(rel)
            with destination.open("wb") as dst, open(src, "rb") as source:
                shutil.copyfileobj(source, dst)
@PromptServer.instance.routes.post("/bentoml/pack")
async def pack_workspace(request):
    """Pack the workspace described by the JSON request body into a zip archive.

    Files in ``TEMP_FOLDER`` whose ctime is more than an hour old are removed
    first. The archive gets a random UUID name, and the response carries the
    URL from which it can be downloaded.
    """
    data = await request.json()
    TEMP_FOLDER.mkdir(exist_ok=True)

    cutoff = time.time() - 60 * 60
    for entry in TEMP_FOLDER.iterdir():
        if not entry.is_file():
            continue
        if entry.stat().st_ctime < cutoff:
            entry.unlink()

    archive_name = f"{uuid.uuid4()}.zip"
    with zipfile.ZipFile(TEMP_FOLDER / archive_name, "w") as archive:
        await _prepare_pack(zipfile.Path(archive), data)

    return web.json_response({"download_url": f"/bentoml/download/{archive_name}"})
class DevServer:
    """Manager for a single ``bentoml serve`` development subprocess.

    All state is kept on the class itself, so at most one dev server is
    tracked at a time. A watchdog task stops the server once no heartbeat
    has been fed for ``TIMEOUT`` seconds.
    """

    # Seconds without a heartbeat after which the watchdog stops the server.
    TIMEOUT = 3600 * 24
    # ComfyUI port assumed when none can be read from the command line.
    DEFAULT_COMFYUI_PORT = 8188
    proc: Union[None, subprocess.Popen] = None
    watch_dog_task: asyncio.Task | None = None
    last_feed = 0
    run_dir: Path | None = None
    port = 0

    @classmethod
    def _comfyui_port(cls, argv: list[str] | None = None) -> int:
        """Return the port of the hosting ComfyUI process as given in *argv*.

        Recognizes ``--port N``, ``-p N`` and ``--port=N``. Falls back to
        ``DEFAULT_COMFYUI_PORT`` when the option is absent, has no value, or
        has a non-integer value.

        Args:
            argv: Argument list to inspect; defaults to ``sys.argv``.
        """
        args = sys.argv if argv is None else argv
        for index, arg in enumerate(args):
            if arg in ("--port", "-p"):
                # The option may be the last argument, without a value.
                value = args[index + 1] if index + 1 < len(args) else None
            elif arg.startswith("--port="):
                value = arg.partition("=")[2]
            else:
                continue
            try:
                return int(value)
            except (TypeError, ValueError):
                return cls.DEFAULT_COMFYUI_PORT
        return cls.DEFAULT_COMFYUI_PORT

    @classmethod
    def start(cls, workflow_api: dict, port: int = 3000):
        """Stop any running dev server and launch a new one on *port*.

        A temporary run directory is populated with the workflow, the
        comfy_pack ``service.py`` and a copy of the comfy_pack sources, then
        ``python -m bentoml serve service:ComfyService`` is started in it
        with ``COMFYUI_SERVER`` pointing at this ComfyUI instance.

        Args:
            workflow_api: Workflow written to ``workflow_api.json``.
            port: Port passed to ``bentoml serve``.
        """
        from comfy_pack import __file__ as comfy_pack_file

        cls.stop()
        cls.port = port

        # prepare a temporary directory
        cls.run_dir = Path(tempfile.mkdtemp(suffix="-bento", prefix="comfy-pack-"))
        with cls.run_dir.joinpath("workflow_api.json").open("w") as f:
            f.write(json.dumps(workflow_api, indent=2))
        shutil.copy(
            Path(comfy_pack_file).with_name("service.py"),
            cls.run_dir / "service.py",
        )
        shutil.copytree(COMFY_PACK_DIR, cls.run_dir / COMFY_PACK_DIR.name)

        # work out which port this ComfyUI instance was started on
        self_port = cls._comfyui_port()

        print(f"Starting dev server at port {port}, comfyui at port {self_port}")
        cls.proc = subprocess.Popen(
            [
                sys.executable,
                "-m",
                "bentoml",
                "serve",
                "service:ComfyService",
                "--port",
                str(port),
            ],
            cwd=str(cls.run_dir.absolute()),
            env={
                **os.environ,
                "COMFYUI_SERVER": f"localhost:{self_port}",
            },
        )
        cls.last_feed = time.time()
        cls.watch_dog_task = asyncio.create_task(cls.watch_dog())

    @classmethod
    async def watch_dog(cls):
        """Poll every 0.1 s and stop the server once ``TIMEOUT`` has passed since the last feed."""
        while True:
            await asyncio.sleep(0.1)
            if cls.last_feed + cls.TIMEOUT < time.time():
                cls.stop()
                break

    @classmethod
    def feed_watch_dog(cls):
        """Refresh the heartbeat timestamp; return whether the server process is alive.

        If a process is tracked but has already exited, everything is cleaned
        up via ``stop()`` and ``False`` is returned.
        """
        if not cls.proc:
            return False
        if cls.proc.poll() is None:
            cls.last_feed = time.time()
            return True
        cls.stop()
        return False

    @classmethod
    def stop(cls):
        """Terminate the subprocess, cancel the watchdog and delete the run directory.

        Safe to call when nothing is running. Note that waiting for the
        process and the one-second pause block the calling thread.
        """
        if cls.proc:
            cls.proc.terminate()
            cls.proc.wait()
            cls.proc = None
            # NOTE(review): presumably gives the OS time to release the port
            # and files held by the child; confirm before removing.
            time.sleep(1)
            print("Dev server stopped")
        if cls.watch_dog_task:
            cls.watch_dog_task.cancel()
            cls.watch_dog_task = None
            cls.last_feed = 0
        if cls.run_dir:
            shutil.rmtree(cls.run_dir)
            cls.run_dir = None
def _parse_workflow(workflow: dict) -> tuple[dict[str, Any], dict[str, Any]]:
    """Collect the ComfyPack input and output nodes of an API-format workflow.

    Every node is tagged in place with its own key under ``"id"``. Nodes
    whose ``class_type`` starts with ``CPackInput`` or ``CPackOutput`` are
    returned, provided they have a non-empty ``"inputs"`` mapping. Nodes
    without an ``"inputs"`` key are tolerated.

    Args:
        workflow: Mapping of node id to node dict; mutated (``"id"`` added).

    Returns:
        ``(inputs, outputs)``, each mapping node id to its node dict.

    Raises:
        KeyError: if a node has no ``"class_type"``.
    """
    inputs: dict[str, Any] = {}
    outputs: dict[str, Any] = {}
    for node_id, node in workflow.items():
        node["id"] = node_id
        class_type = node["class_type"]
        if class_type.startswith("CPackInput"):
            bucket = inputs
        elif class_type.startswith("CPackOutput"):
            bucket = outputs
        else:
            continue
        # Input/output nodes without any configured inputs are ignored.
        if node.get("inputs"):
            bucket[node_id] = node
    return inputs, outputs
def _validate_workflow(data: dict):
    """Check that ``data["workflow_api"]`` is usable for serving or building.

    Returns:
        A JSON error response when the workflow is empty or lacks a ComfyPack
        input or output node; ``None`` when the workflow passes.
    """

    def _error(message):
        return web.json_response({"result": "error", "error": message})

    workflow = data.get("workflow_api", {})
    if not workflow:
        return _error("empty workflow")

    input_spec, output_spec = _parse_workflow(workflow)
    if not input_spec:
        return _error("At least one ComfyPack input node is required")
    if not output_spec:
        return _error("At least one ComfyPack output node is required")
    return None
@PromptServer.instance.routes.post("/bentoml/serve")
async def serve(request):
    """Validate the posted workflow and (re)start the dev server for it.

    The request body may carry ``port`` (default 3000) and ``host`` (default
    ``localhost``). Any running dev server is stopped first. Failures are
    reported as a JSON payload with ``"result": "error"``.
    """
    data = await request.json()
    error = _validate_workflow(data)
    if error is not None:
        return error

    DevServer.stop()

    port = data.get("port", 3000)
    host = data.get("host", "localhost")
    if _is_port_in_use(port, host=host):
        return web.json_response(
            {"result": "error", "error": "Port is already in use"}
        )

    try:
        DevServer.start(workflow_api=data["workflow_api"], port=port)
        return web.json_response(
            {"result": "success", "url": f"http://{host}:{port}"}
        )
    except Exception as e:
        return web.json_response(
            {
                "result": "error",
                "error": f"Build failed: {e.__class__.__name__}: {e}",
            }
        )
@PromptServer.instance.routes.get("/bentoml/serve/heartbeat")
async def heartbeat(_):
    """Feed the dev-server watchdog and report whether the server is running."""
    if not DevServer.feed_watch_dog():
        return web.json_response({"error": "Server is not running"})
    return web.json_response({"ready": True})
@PromptServer.instance.routes.post("/bentoml/serve/terminate")
async def terminate(_):
    """Stop the dev server, if one is running, and acknowledge with a success payload."""
    DevServer.stop()
    payload = {"result": "success"}
    return web.json_response(payload)
@PromptServer.instance.routes.get("/bentoml/download/{zip_filename}")
async def download_workspace(request):
    """Serve a previously packed archive from ``TEMP_FOLDER``.

    The file name comes straight from the URL, so it is untrusted. Only
    regular files located directly inside ``TEMP_FOLDER`` are served;
    anything else, including names that resolve outside the folder, yields
    a 404.

    Raises:
        web.HTTPNotFound: if the name does not designate a file in ``TEMP_FOLDER``.
    """
    zip_filename = request.match_info["zip_filename"]
    temp_root = TEMP_FOLDER.resolve()
    target = (temp_root / zip_filename).resolve()
    # Resolving first defeats "..", absolute names and symlinks that would
    # otherwise let a request read files outside the temp folder.
    if target.parent != temp_root or not target.is_file():
        raise web.HTTPNotFound()
    return web.FileResponse(target)
async def _prepare_pack(
    working_dir: ZPath,
    data: dict,
    store_models: bool = False,
    ensure_source: bool = True,
) -> None:
    """Write the snapshot, workflow and input files of a pack into *working_dir*.

    Args:
        working_dir: Destination directory (filesystem or zip path).
        data: Request payload; ``data["models"]`` (default: empty) is the set
            of model paths that are not marked disabled.
        store_models: Forwarded to ``_get_models``.
        ensure_source: Forwarded to ``_get_models``.
    """
    selected_models = set(data.get("models", []))
    models = await _get_models(
        store_models=store_models,
        model_filter=selected_models,
        ensure_source=ensure_source,
    )
    await _write_snapshot(working_dir, data, models)
    for write_section in (_write_workflow, _write_inputs):
        await write_section(working_dir, data)
@PromptServer.instance.routes.post("/bentoml/model/query")
async def get_models(request):
    """Return the model list for the posted workflow, using cached hashes and sources only."""
    payload = await request.json()
    workflow_api = payload.get("workflow_api")
    found = await _get_models(
        workflow_api=workflow_api,
        ensure_sha=False,
        ensure_source=False,
    )
    return web.json_response({"models": found})
async def _get_inputs(workflow_api):
    """List every file below the ComfyUI input directory for the selection UI.

    Args:
        workflow_api: API-format workflow used to flag referenced files. May
            be ``None`` or empty, in which case no file is flagged.

    Returns:
        A list of dicts with ``path`` (relative to the input directory),
        ``badges`` (a "Referenced" badge when the workflow uses the file) and
        ``checked`` (true exactly when the file is referenced).
    """
    input_dir = folder_paths.get_input_directory()
    inputs = []
    for src in Path(input_dir).rglob("*"):
        if not src.is_file():
            continue
        # Without a workflow nothing can be referenced; skipping the lookup
        # also avoids calling the helper with None.
        referenced = bool(workflow_api) and _is_file_refered(src, workflow_api)
        inputs.append(
            {
                "path": str(src.relative_to(input_dir)),
                "badges": [{"text": "Referenced"}] if referenced else [],
                "checked": referenced,
            }
        )
    return inputs
@PromptServer.instance.routes.post("/bentoml/file/query")
async def get_inputs(request):
    """Return the input-file listing for the posted workflow as ``{"files": [...]}``."""
    payload = await request.json()
    files = await _get_inputs(workflow_api=payload.get("workflow_api"))
    return web.json_response({"files": files})
@PromptServer.instance.routes.post("/bentoml/build")
async def build_bento_api(request):
    """Build a bento from the posted workflow and optionally push it.

    Request body: {
        workflow_api: dict,
        workflow: dict,
        bento_name: str,
        push?: bool,
        api_key?: str,
        endpoint?: str,
        system_packages?: list[str]
    }

    Responds with ``{"result": "success", "bento": <tag>}`` or, on a
    validation, build or push failure, with a ``"result": "error"`` payload.
    """
    import bentoml

    data = await request.json()
    error = _validate_workflow(data)
    if error is not None:
        return error

    with tempfile.TemporaryDirectory(suffix="-bento", prefix="comfy-pack-") as temp_dir:
        pack_dir = Path(temp_dir)
        await _prepare_pack(pack_dir, data, store_models=True, ensure_source=False)

        # create a bento
        try:
            bento = build_bento(
                data["bento_name"],
                pack_dir,
                system_packages=data.get("system_packages"),
            )
        except bentoml.exceptions.BentoMLException as e:
            return web.json_response(
                {
                    "result": "error",
                    "error": f"Build failed: {e.__class__.__name__}: {e}",
                }
            )

    if data.get("push", False):
        # Only pass the credentials that were actually supplied.
        credentials = {
            key: data[key] for key in ("api_key", "endpoint") if data.get(key)
        }
        client = bentoml.cloud.BentoCloudClient(**credentials)
        try:
            client.bento.push(bento)
        except bentoml.exceptions.BentoMLException as e:
            return web.json_response(
                {
                    "result": "error",
                    "error": f"Push failed: {e.__class__.__name__}: {e}",
                }
            )

    return web.json_response({"result": "success", "bento": str(bento.tag)})