3v324v23's picture
lfs
1e3b872
import { app } from "../../../scripts/app.js";
import { addMenuHandler } from "./common/utils.js";
import { findWidgetByName } from "./common/utils.js";
function replaceNode(oldNode, newNodeName) {
// Create new node
const newNode = LiteGraph.createNode(newNodeName);
if (!newNode) {
return;
}
app.graph.add(newNode);
// Position new node at the same position as the old node
newNode.pos = oldNode.pos.slice();
// Define widget mappings
const mappings = {
"KSampler (Efficient) <-> KSampler Adv. (Efficient)": {
seed: "noise_seed",
cfg: "cfg",
sampler_name: "sampler_name",
scheduler: "scheduler",
preview_method: "preview_method",
vae_decode: "vae_decode"
},
"KSampler (Efficient) <-> KSampler SDXL (Eff.)": {
seed: "noise_seed",
cfg: "cfg",
sampler_name: "sampler_name",
scheduler: "scheduler",
preview_method: "preview_method",
vae_decode: "vae_decode"
},
"KSampler Adv. (Efficient) <-> KSampler SDXL (Eff.)": {
noise_seed: "noise_seed",
steps: "steps",
cfg: "cfg",
sampler_name: "sampler_name",
scheduler: "scheduler",
start_at_step: "start_at_step",
preview_method: "preview_method",
vae_decode: "vae_decode"}
};
const swapKey = `${oldNode.type} <-> ${newNodeName}`;
let widgetMapping = {};
// Check if a reverse mapping is needed
if (!mappings[swapKey]) {
const reverseKey = `${newNodeName} <-> ${oldNode.type}`;
const reverseMapping = mappings[reverseKey];
if (reverseMapping) {
widgetMapping = Object.entries(reverseMapping).reduce((acc, [key, value]) => {
acc[value] = key;
return acc;
}, {});
}
} else {
widgetMapping = mappings[swapKey];
}
if (oldNode.type === "KSampler (Efficient)" && (newNodeName === "KSampler Adv. (Efficient)" || newNodeName === "KSampler SDXL (Eff.)")) {
const denoise = Math.min(Math.max(findWidgetByName(oldNode, "denoise").value, 0), 1); // Ensure denoise is between 0 and 1
const steps = Math.min(Math.max(findWidgetByName(oldNode, "steps").value, 0), 10000); // Ensure steps is between 0 and 10000
const total_steps = Math.floor(steps / denoise);
const start_at_step = total_steps - steps;
findWidgetByName(newNode, "steps").value = Math.min(Math.max(total_steps, 0), 10000); // Ensure total_steps is between 0 and 10000
findWidgetByName(newNode, "start_at_step").value = Math.min(Math.max(start_at_step, 0), 10000); // Ensure start_at_step is between 0 and 10000
}
else if ((oldNode.type === "KSampler Adv. (Efficient)" || oldNode.type === "KSampler SDXL (Eff.)") && newNodeName === "KSampler (Efficient)") {
const stepsAdv = Math.min(Math.max(findWidgetByName(oldNode, "steps").value, 0), 10000); // Ensure stepsAdv is between 0 and 10000
const start_at_step = Math.min(Math.max(findWidgetByName(oldNode, "start_at_step").value, 0), 10000); // Ensure start_at_step is between 0 and 10000
const denoise = Math.min(Math.max((stepsAdv - start_at_step) / stepsAdv, 0), 1); // Ensure denoise is between 0 and 1
const stepsTotal = stepsAdv - start_at_step;
findWidgetByName(newNode, "denoise").value = denoise;
findWidgetByName(newNode, "steps").value = Math.min(Math.max(stepsTotal, 0), 10000); // Ensure stepsTotal is between 0 and 10000
}
// Transfer widget values from old node to new node
oldNode.widgets.forEach(widget => {
const newName = widgetMapping[widget.name];
if (newName) {
const newWidget = findWidgetByName(newNode, newName);
if (newWidget) {
newWidget.value = widget.value;
}
}
});
// Determine the starting indices based on the node types
let oldNodeInputStartIndex = 0;
let newNodeInputStartIndex = 0;
let oldNodeOutputStartIndex = 0;
let newNodeOutputStartIndex = 0;
if (oldNode.type === "KSampler SDXL (Eff.)" || newNodeName === "KSampler SDXL (Eff.)") {
oldNodeInputStartIndex = (oldNode.type === "KSampler SDXL (Eff.)") ? 1 : 3;
newNodeInputStartIndex = (newNodeName === "KSampler SDXL (Eff.)") ? 1 : 3;
oldNodeOutputStartIndex = (oldNode.type === "KSampler SDXL (Eff.)") ? 1 : 3;
newNodeOutputStartIndex = (newNodeName === "KSampler SDXL (Eff.)") ? 1 : 3;
}
// Transfer connections from old node to new node
oldNode.inputs.slice(oldNodeInputStartIndex).forEach((input, index) => {
if (input && input.link !== null) {
const originLinkInfo = oldNode.graph.links[input.link];
if (originLinkInfo) {
const originNode = oldNode.graph.getNodeById(originLinkInfo.origin_id);
if (originNode) {
originNode.connect(originLinkInfo.origin_slot, newNode, index + newNodeInputStartIndex);
}
}
}
});
oldNode.outputs.slice(oldNodeOutputStartIndex).forEach((output, index) => {
if (output && output.links) {
output.links.forEach(link => {
const targetLinkInfo = oldNode.graph.links[link];
if (targetLinkInfo) {
const targetNode = oldNode.graph.getNodeById(targetLinkInfo.target_id);
if (targetNode) {
newNode.connect(index + newNodeOutputStartIndex, targetNode, targetLinkInfo.target_slot);
}
}
});
}
});
// Remove old node
app.graph.remove(oldNode);
}
function replaceNodeMenuCallback(currentNode, targetNodeName) {
return function() {
replaceNode(currentNode, targetNodeName);
};
}
function showSwapMenu(value, options, e, menu, node) {
const swapOptions = [];
if (node.type !== "KSampler (Efficient)") {
swapOptions.push({
content: "KSampler (Efficient)",
callback: replaceNodeMenuCallback(node, "KSampler (Efficient)")
});
}
if (node.type !== "KSampler Adv. (Efficient)") {
swapOptions.push({
content: "KSampler Adv. (Efficient)",
callback: replaceNodeMenuCallback(node, "KSampler Adv. (Efficient)")
});
}
if (node.type !== "KSampler SDXL (Eff.)") {
swapOptions.push({
content: "KSampler SDXL (Eff.)",
callback: replaceNodeMenuCallback(node, "KSampler SDXL (Eff.)")
});
}
new LiteGraph.ContextMenu(swapOptions, {
event: e,
callback: null,
parentMenu: menu,
node: node
});
return false; // This ensures the original context menu doesn't proceed
}
// Extension Definition
app.registerExtension({
name: "efficiency.SwapSamplers",
async beforeRegisterNodeDef(nodeType, nodeData, app) {
if (["KSampler (Efficient)", "KSampler Adv. (Efficient)", "KSampler SDXL (Eff.)"].includes(nodeData.name)) {
addMenuHandler(nodeType, function (insertOption) {
insertOption({
content: "🔄 Swap with...",
has_submenu: true,
callback: showSwapMenu
});
});
}
},
});