|
|
|
|
| const { start } = require("../utils");
|
| const lg = require("../utils/litegraph");
|
|
|
/**
 * Tests for the extension hook lifecycle.
 *
 * Each test boots the app through the `start` helper, registering a mock
 * extension from the `preSetup` callback, and then inspects which hooks were
 * invoked, how often, with which arguments and in which order.
 */
describe("extensions", () => {
  // The litegraph helper is set up against `global` before every test and
  // torn down afterwards so no state leaks between tests.
  beforeEach(() => {
    lg.setup(global);
  });

  afterEach(() => {
    lg.teardown(global);
  });

  it("calls each extension hook", async () => {
    // Every hook is a bare mock so call counts, arguments and ordering can be inspected.
    const mockExtension = {
      name: "TestExtension",
      init: jest.fn(),
      setup: jest.fn(),
      addCustomNodeDefs: jest.fn(),
      getCustomWidgets: jest.fn(),
      beforeRegisterNodeDef: jest.fn(),
      registerCustomNodes: jest.fn(),
      loadedGraphNode: jest.fn(),
      nodeCreated: jest.fn(),
      beforeConfigureGraph: jest.fn(),
      afterConfigureGraph: jest.fn(),
    };

    const { app, ez, graph } = await start({
      async preSetup(app) {
        app.registerExtension(mockExtension);
      },
    });

    // init: exactly once, receiving the app.
    expect(mockExtension.init).toHaveBeenCalledTimes(1);
    expect(mockExtension.init).toHaveBeenCalledWith(app);

    // addCustomNodeDefs: exactly once, receiving (defs, app); the defs are
    // expected to contain at least these two node types.
    expect(mockExtension.addCustomNodeDefs).toHaveBeenCalledTimes(1);
    expect(mockExtension.addCustomNodeDefs.mock.calls[0][1]).toStrictEqual(app);
    const defs = mockExtension.addCustomNodeDefs.mock.calls[0][0];
    expect(defs).toHaveProperty("KSampler");
    expect(defs).toHaveProperty("LoadImage");

    // getCustomWidgets: exactly once, receiving the app.
    expect(mockExtension.getCustomWidgets).toHaveBeenCalledTimes(1);
    expect(mockExtension.getCustomWidgets).toHaveBeenCalledWith(app);

    // beforeRegisterNodeDef: once per node definition.
    const nodeNames = Object.keys(defs);
    const nodeCount = nodeNames.length;
    expect(mockExtension.beforeRegisterNodeDef).toHaveBeenCalledTimes(nodeCount);

    // Only the first few registrations are inspected in detail. The bound is
    // clamped to the number of defs so a small def set cannot index past the
    // recorded calls and crash with a TypeError.
    const inspectedCount = Math.min(10, nodeCount);
    for (let i = 0; i < inspectedCount; i++) {
      // Each call receives the node class followed by its definition.
      const [nodeClass, nodeDef] = mockExtension.beforeRegisterNodeDef.mock.calls[i];

      expect(nodeClass.name).toBe("ComfyNode");
      expect(nodeClass.comfyClass).toBe(nodeNames[i]);
      expect(nodeDef.name).toBe(nodeNames[i]);
      expect(nodeDef).toHaveProperty("input");
      expect(nodeDef).toHaveProperty("output");
    }

    // registerCustomNodes: exactly once.
    expect(mockExtension.registerCustomNodes).toHaveBeenCalledTimes(1);

    // beforeConfigureGraph: exactly once; its first argument is the graph
    // data used as the reference for the per-node hooks below.
    expect(mockExtension.beforeConfigureGraph).toHaveBeenCalledTimes(1);
    const graphData = mockExtension.beforeConfigureGraph.mock.calls[0][0];

    // nodeCreated: once per node in the graph data, in the same order.
    expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(graphData.nodes.length);
    for (let i = 0; i < graphData.nodes.length; i++) {
      expect(mockExtension.nodeCreated.mock.calls[i][0].type).toBe(graphData.nodes[i].type);
    }

    // loadedGraphNode: once per node in the graph data, in the same order.
    expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(graphData.nodes.length);
    for (let i = 0; i < graphData.nodes.length; i++) {
      expect(mockExtension.loadedGraphNode.mock.calls[i][0].type).toBe(graphData.nodes[i].type);
    }

    // afterConfigureGraph: exactly once.
    expect(mockExtension.afterConfigureGraph).toHaveBeenCalledTimes(1);

    // setup: exactly once, receiving the app.
    expect(mockExtension.setup).toHaveBeenCalledTimes(1);
    expect(mockExtension.setup).toHaveBeenCalledWith(app);

    // The first invocation of each hook must precede the first invocation of
    // the next hook in this list.
    const callOrder = [
      "init",
      "addCustomNodeDefs",
      "getCustomWidgets",
      "beforeRegisterNodeDef",
      "registerCustomNodes",
      "beforeConfigureGraph",
      "nodeCreated",
      "loadedGraphNode",
      "afterConfigureGraph",
      "setup",
    ];
    for (let i = 1; i < callOrder.length; i++) {
      const fn1 = mockExtension[callOrder[i - 1]];
      const fn2 = mockExtension[callOrder[i]];
      expect(fn1.mock.invocationCallOrder[0]).toBeLessThan(fn2.mock.invocationCallOrder[0]);
    }

    graph.clear();

    // A node added directly fires nodeCreated but not loadedGraphNode.
    ez.LoadImage();
    expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(graphData.nodes.length);
    expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(graphData.nodes.length + 1);
    expect(mockExtension.nodeCreated.mock.lastCall[0].type).toBe("LoadImage");

    await graph.reload();

    // The one-time hooks must not run again on reload.
    expect(mockExtension.init).toHaveBeenCalledTimes(1);
    expect(mockExtension.addCustomNodeDefs).toHaveBeenCalledTimes(1);
    expect(mockExtension.getCustomWidgets).toHaveBeenCalledTimes(1);
    expect(mockExtension.registerCustomNodes).toHaveBeenCalledTimes(1);
    expect(mockExtension.beforeRegisterNodeDef).toHaveBeenCalledTimes(nodeCount);
    expect(mockExtension.setup).toHaveBeenCalledTimes(1);

    // The graph-configuration hooks run a second time, and the per-node
    // hooks each gain one more call.
    expect(mockExtension.beforeConfigureGraph).toHaveBeenCalledTimes(2);
    expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(graphData.nodes.length + 2);
    expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(graphData.nodes.length + 1);
    expect(mockExtension.afterConfigureGraph).toHaveBeenCalledTimes(2);
  }, 15000);

  it("allows custom nodeDefs and widgets to be registered", async () => {
    // Captured in preSetup so the widget factory can verify that it is handed
    // the very same app instance. Comparing the factory's own parameter with
    // itself would always pass.
    let registeredApp;

    // Widget factory for the custom "CUSTOMWIDGET" input type. It checks the
    // arguments it is invoked with and adds a button widget to the node.
    const widgetMock = jest.fn((node, inputName, inputData, widgetApp) => {
      expect(node.constructor.comfyClass).toBe("TestNode");
      expect(inputName).toBe("test_input");
      expect(inputData[0]).toBe("CUSTOMWIDGET");
      expect(inputData[1]?.hello).toBe("world");
      expect(widgetApp).toBe(registeredApp);

      return {
        widget: node.addWidget("button", inputName, "hello", () => {}),
      };
    });

    // The extension injects a node definition whose only required input uses
    // the custom widget type, and exposes the factory for that type.
    const mockExtension = {
      name: "TestExtension",
      addCustomNodeDefs: (nodeDefs) => {
        nodeDefs["TestNode"] = {
          output: [],
          output_name: [],
          output_is_list: [],
          name: "TestNode",
          display_name: "TestNode",
          category: "Test",
          input: {
            required: {
              test_input: ["CUSTOMWIDGET", { hello: "world" }],
            },
          },
        };
      },
      getCustomWidgets: jest.fn(() => {
        return {
          CUSTOMWIDGET: widgetMock,
        };
      }),
    };

    const { graph, ez } = await start({
      async preSetup(app) {
        registeredApp = app;
        app.registerExtension(mockExtension);
      },
    });

    expect(mockExtension.getCustomWidgets).toHaveBeenCalledTimes(1);

    // The factory must not run until a node of the custom type is created.
    graph.clear();
    expect(widgetMock).toHaveBeenCalledTimes(0);
    const node = ez.TestNode();
    expect(widgetMock).toHaveBeenCalledTimes(1);

    // The custom input must surface as a widget rather than as an input slot.
    expect(node.inputs.length).toBe(0);
    expect(node.widgets.length).toBe(1);
    const w = node.widgets[0].widget;
    expect(w.name).toBe("test_input");
    expect(w.type).toBe("button");
  });
});
|
|
|