Add files using upload-large-folder tool
Browse files. This view is limited to 50 files because it contains too many changes.
See raw diff
- rgthree-comfy/web/comfyui/services/bookmarks_services.js +10 -0
- rgthree-comfy/web/comfyui/services/config_service.js +28 -0
- rgthree-comfy/web/comfyui/services/context_service.js +51 -0
- rgthree-comfy/web/comfyui/services/fast_groups_service.js +138 -0
- rgthree-comfy/web/common/css/buttons.css +90 -0
- rgthree-comfy/web/common/css/dialog.css +124 -0
- rgthree-comfy/web/common/css/dialog_model_info.css +333 -0
- rgthree-comfy/web/common/css/menu.css +91 -0
- rgthree-comfy/web/common/css/pages_base.css +66 -0
- rgthree-comfy/web/common/media/rgthree.svg +7 -0
- rgthree-comfy/web/common/media/svgs.js +160 -0
- rgthree-comfy/web/common/shared_utils.js +142 -0
- rgthree-comfy/web/common/utils_dom.js +311 -0
- rgthree-comfy/web/common/utils_workflow.js +55 -0
- rgthree-comfy/web/link_fixer/link_page.js +195 -0
- sd-dynamic-thresholding/.github/FUNDING.yml +1 -0
- sd-dynamic-thresholding/.github/workflows/publish.yml +21 -0
- sd-dynamic-thresholding/__pycache__/__init__.cpython-312.pyc +0 -0
- sd-dynamic-thresholding/__pycache__/dynthres_comfyui.cpython-312.pyc +0 -0
- sd-dynamic-thresholding/__pycache__/dynthres_core.cpython-312.pyc +0 -0
- sd-dynamic-thresholding/github/comfy_node.png +0 -0
- sd-dynamic-thresholding/github/ui.png +0 -0
- sd-dynamic-thresholding/javascript/active.js +68 -0
- sd-dynamic-thresholding/scripts/dynamic_thresholding.py +270 -0
- sigmas_tools_and_the_golden_scheduler/.github/workflows/publish.yml +21 -0
- sigmas_tools_and_the_golden_scheduler/__pycache__/__init__.cpython-312.pyc +0 -0
- sigmas_tools_and_the_golden_scheduler/__pycache__/sigmas_merge.cpython-312.pyc +0 -0
- stable-diffusion-temperature-settings/.github/FUNDING.yml +3 -0
- stable-diffusion-temperature-settings/.github/workflows/publish.yml +22 -0
- stable-diffusion-temperature-settings/__pycache__/__init__.cpython-312.pyc +0 -0
- stable-diffusion-temperature-settings/__pycache__/nodes.cpython-312.pyc +0 -0
- stable-diffusion-temperature-settings/workflows/tinybottle.png +0 -0
- ultimate-upscale-for-automatic1111/scripts/ultimate-upscale.py +569 -0
- was-node-suite-comfyui/.github/workflows/publish_action.yml +20 -0
- was-node-suite-comfyui/__pycache__/__init__.cpython-312.pyc +0 -0
- was-node-suite-comfyui/modules/BLIP/__init__.py +0 -0
- was-node-suite-comfyui/modules/BLIP/blip_med.py +955 -0
- was-node-suite-comfyui/modules/BLIP/blip_module.py +423 -0
- was-node-suite-comfyui/modules/BLIP/blip_module_license.txt +12 -0
- was-node-suite-comfyui/modules/BLIP/blip_vit.py +305 -0
- was-node-suite-comfyui/modules/__init__.py +0 -0
- was-node-suite-comfyui/repos/SAM/demo/README.md +126 -0
- was-node-suite-comfyui/repos/SAM/demo/package.json +62 -0
- was-node-suite-comfyui/repos/SAM/demo/postcss.config.js +10 -0
- was-node-suite-comfyui/repos/SAM/demo/src/components/helpers/Interfaces.tsx +29 -0
- was-node-suite-comfyui/repos/SAM/demo/src/components/helpers/maskUtils.tsx +47 -0
- was-node-suite-comfyui/repos/SAM/demo/src/components/helpers/onnxModelAPI.tsx +71 -0
- was-node-suite-comfyui/repos/SAM/demo/src/components/helpers/scaleHelper.tsx +18 -0
- was-node-suite-comfyui/repos/SAM/demo/src/components/hooks/context.tsx +31 -0
- was-node-suite-comfyui/repos/SAM/demo/src/components/hooks/createContext.tsx +27 -0
rgthree-comfy/web/comfyui/services/bookmarks_services.js
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { app } from "../../../scripts/app.js";
import { NodeTypesString } from "../constants.js";

/**
 * Service that collects the Bookmark nodes currently present in the graph.
 */
class BookmarksService {
  /**
   * Returns every Bookmark node in the current graph, sorted by title.
   * @returns A new, title-sorted array of bookmark nodes.
   */
  getCurrentBookmarks() {
    const bookmarks = app.graph._nodes.filter(
      (node) => node.type === NodeTypesString.BOOKMARK,
    );
    // filter() already returned a fresh array, so sorting in place is safe.
    bookmarks.sort((a, b) => a.title.localeCompare(b.title));
    return bookmarks;
  }
}

export const SERVICE = new BookmarksService();
|
rgthree-comfy/web/comfyui/services/config_service.js
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { rgthreeConfig } from "../../../rgthree/config.js";
import { getObjectValue, setObjectValue } from "../../../rgthree/common/shared_utils.js";
import { rgthreeApi } from "../../../rgthree/common/rgthree_api.js";

/**
 * Service wrapping the rgthree config object. Reads values by dotted key
 * path and persists changes to the backend, dispatching a "config-change"
 * event for every key that was updated.
 */
class ConfigService extends EventTarget {
  /** Reads a config value by dotted key path, with an optional default. */
  getConfigValue(key, def) {
    return getObjectValue(rgthreeConfig, key, def);
  }

  /** Reads a feature value; the "features." prefix on the key is optional. */
  getFeatureValue(key, def) {
    const featureKey = "features." + key.replace(/^features\./, "");
    return getObjectValue(rgthreeConfig, featureKey, def);
  }

  /**
   * Posts the changed values to the backend "/config" endpoint. On a
   * successful response, applies each change to the local config and fires
   * a "config-change" CustomEvent per key.
   * @param changed Plain object of dotted key path -> new value.
   * @returns true when the backend accepted the change, false otherwise.
   */
  async setConfigValues(changed) {
    const body = new FormData();
    body.append("json", JSON.stringify(changed));
    const response = await rgthreeApi.fetchJson("/config", { method: "POST", body });
    if (response.status !== "ok") {
      return false;
    }
    for (const [key, value] of Object.entries(changed)) {
      setObjectValue(rgthreeConfig, key, value);
      this.dispatchEvent(new CustomEvent("config-change", { detail: { key, value } }));
    }
    return true;
  }
}

export const SERVICE = new ConfigService();
|
rgthree-comfy/web/comfyui/services/context_service.js
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { getConnectedOutputNodesAndFilterPassThroughs } from "../utils.js";

// Singleton instance; assigned at the bottom of this module.
export let SERVICE;

// Prefix marking an input the user added ("owns") on a context node.
const OWNED_PREFIX = "+";
// Matches an owned ("+") or warning ("⚠️") prefix at the start of a name.
const REGEX_PREFIX = /^[\+⚠️]\s*/;
// Matches a name consisting only of the owned prefix, with no label.
const REGEX_EMPTY_INPUT = /^\+\s*$/;

/** Removes any owned/warning prefix from a context input name. */
export function stripContextInputPrefixes(name) {
  return name.replace(REGEX_PREFIX, "");
}

/**
 * Maps a context input name to its output name: "base_ctx" becomes
 * "CONTEXT"; anything else is prefix-stripped and upper-cased.
 */
export function getContextOutputName(inputName) {
  return inputName === "base_ctx"
    ? "CONTEXT"
    : stripContextInputPrefixes(inputName).toUpperCase();
}
|
| 14 |
+
/**
 * Describes how a context node's inputs were changed.
 * Compiled-TypeScript enum pattern: maps names to numbers and numbers
 * back to names on the same object.
 */
export var InputMutationOperation;
(function (op) {
  op[op["UNKNOWN"] = 0] = "UNKNOWN";
  op[op["ADDED"] = 1] = "ADDED";
  op[op["REMOVED"] = 2] = "REMOVED";
  op[op["RENAMED"] = 3] = "RENAMED";
})(InputMutationOperation || (InputMutationOperation = {}));
|
| 21 |
+
/**
 * Singleton service coordinating dynamic Context nodes: propagates input
 * mutations to downstream context nodes and exposes normalized
 * input/output metadata for a node.
 */
export class ContextService {
  constructor() {
    // Enforce the singleton; SERVICE is assigned right below this class.
    if (SERVICE) {
      throw new Error("ContextService was already instantiated.");
    }
  }

  /**
   * Notifies each connected downstream context node (pass-through nodes
   * filtered out) that this node's inputs mutated.
   */
  onInputChanges(node, mutation) {
    for (const childCtx of getConnectedOutputNodesAndFilterPassThroughs(node, node, 0)) {
      childCtx.handleUpstreamMutation(mutation);
    }
  }

  /**
   * Returns {name, type, index} for each concrete input on the node.
   * The wildcard ("*") slot is excluded, but `index` still reflects the
   * position within the full, unfiltered input list.
   */
  getDynamicContextInputsData(node) {
    const data = [];
    node.getContextInputsList().forEach((input, index) => {
      const type = String(input.type);
      if (type === "*") {
        return;
      }
      data.push({ name: stripContextInputPrefixes(input.name), type, index });
    });
    return data;
  }

  /** Returns {name, type, index} for every output on the node. */
  getDynamicContextOutputsData(node) {
    const data = [];
    node.outputs.forEach((output, index) => {
      data.push({
        name: stripContextInputPrefixes(output.name),
        type: String(output.type),
        index,
      });
    });
    return data;
  }
}

SERVICE = new ContextService();
|
rgthree-comfy/web/comfyui/services/fast_groups_service.js
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { app } from "../../../scripts/app.js";

/**
 * Singleton service keeping "Fast Group" nodes in sync with the graph's
 * groups. While at least one fast-group node is registered it periodically
 * refreshes their widgets, caching the group lists (unsorted, alphabetical,
 * positional) and node bounding boxes to keep each refresh cheap.
 */
class FastGroupsService {
  constructor() {
    // How long (ms) a cached group list stays fresh.
    this.msThreshold = 400;
    // Timestamps of when each cached list was last computed.
    this.msLastUnsorted = 0;
    this.msLastAlpha = 0;
    this.msLastPosition = 0;
    // Cached group lists.
    this.groupsUnsorted = [];
    this.groupsSortedAlpha = [];
    this.groupsSortedPosition = [];
    // Registered fast-group nodes whose widgets get refreshed.
    this.fastGroupNodes = [];
    // Pending-refresh scheduling state.
    this.runScheduledForMs = null;
    this.runScheduleTimeout = null;
    this.runScheduleAnimation = null;
    // Short-lived cache of node bounding boxes, keyed by node id.
    this.cachedNodeBoundings = null;
  }

  /** Registers a fast-group node and schedules a near-immediate refresh. */
  addFastGroupNode(node) {
    this.fastGroupNodes.push(node);
    this.scheduleRun(8);
  }

  /**
   * Unregisters a fast-group node. When none remain, cancels any pending
   * refresh and drops the cached group lists.
   */
  removeFastGroupNode(node) {
    const index = this.fastGroupNodes.indexOf(node);
    if (index > -1) {
      this.fastGroupNodes.splice(index, 1);
    }
    // fastGroupNodes is always an array (set in the constructor), so a
    // plain length check replaces the compiled optional-chaining guard.
    if (!this.fastGroupNodes.length) {
      this.clearScheduledRun();
      this.groupsUnsorted = [];
      this.groupsSortedAlpha = [];
      this.groupsSortedPosition = [];
    }
  }

  /** Refreshes all registered nodes' widgets, then reschedules itself. */
  run() {
    // A cleared schedule means this callback was cancelled; do nothing.
    if (!this.runScheduledForMs) {
      return;
    }
    for (const node of this.fastGroupNodes) {
      node.refreshWidgets();
    }
    this.clearScheduledRun();
    this.scheduleRun();
  }

  /**
   * Schedules a refresh in `ms` milliseconds. A sooner request replaces a
   * pending later one; an already-pending sooner run is kept as-is.
   * Nothing is scheduled when no fast-group nodes are registered.
   */
  scheduleRun(ms = 500) {
    if (this.runScheduledForMs && ms < this.runScheduledForMs) {
      this.clearScheduledRun();
    }
    if (!this.runScheduledForMs && this.fastGroupNodes.length) {
      this.runScheduledForMs = ms;
      this.runScheduleTimeout = setTimeout(() => {
        // Defer the actual work to an animation frame.
        this.runScheduleAnimation = requestAnimationFrame(() => this.run());
      }, ms);
    }
  }

  /** Cancels any pending timeout / animation-frame refresh. */
  clearScheduledRun() {
    this.runScheduleTimeout && clearTimeout(this.runScheduleTimeout);
    this.runScheduleAnimation && cancelAnimationFrame(this.runScheduleAnimation);
    this.runScheduleTimeout = null;
    this.runScheduleAnimation = null;
    this.runScheduledForMs = null;
  }

  /**
   * Returns a map of node id -> bounding box for every node in the graph,
   * cached for 50ms so repeated group recomputations don't re-measure.
   */
  getBoundingsForAllNodes() {
    if (!this.cachedNodeBoundings) {
      this.cachedNodeBoundings = {};
      for (const node of app.graph._nodes) {
        this.cachedNodeBoundings[node.id] = node.getBounding();
      }
      setTimeout(() => {
        this.cachedNodeBoundings = null;
      }, 50);
    }
    return this.cachedNodeBoundings;
  }

  /** Rebuilds group._nodes with the nodes overlapping the group's bounds. */
  recomputeInsideNodesForGroup(group) {
    const cachedBoundings = this.getBoundingsForAllNodes();
    const nodes = group.graph._nodes;
    group._nodes.length = 0;
    for (const node of nodes) {
      const node_bounding = cachedBoundings[node.id];
      if (!node_bounding || !LiteGraph.overlapBounding(group._bounding, node_bounding)) {
        continue;
      }
      group._nodes.push(node);
    }
  }

  /**
   * Returns the graph's groups in graph order, recomputing membership and
   * the per-group active flag when the cache is stale. Skipped while a
   * group is being dragged so membership doesn't churn mid-move.
   */
  getGroupsUnsorted(now) {
    const canvas = app.canvas;
    const graph = app.graph;
    if (!canvas.selected_group_moving &&
        (!this.groupsUnsorted.length || now - this.msLastUnsorted > this.msThreshold)) {
      this.groupsUnsorted = [...graph._groups];
      for (const group of this.groupsUnsorted) {
        this.recomputeInsideNodesForGroup(group);
        group._rgthreeHasAnyActiveNode = group._nodes.some((n) => n.mode === LiteGraph.ALWAYS);
      }
      this.msLastUnsorted = now;
    }
    return this.groupsUnsorted;
  }

  /** Returns the groups sorted by title (cached). */
  getGroupsAlpha(now) {
    if (!this.groupsSortedAlpha.length || now - this.msLastAlpha > this.msThreshold) {
      this.groupsSortedAlpha = [...this.getGroupsUnsorted(now)].sort((a, b) => {
        return a.title.localeCompare(b.title);
      });
      this.msLastAlpha = now;
    }
    return this.groupsSortedAlpha;
  }

  /**
   * Returns the groups sorted by position (cached): top-to-bottom, then
   * left-to-right, snapped to a 30px grid so near-aligned rows sort together.
   */
  getGroupsPosition(now) {
    if (!this.groupsSortedPosition.length || now - this.msLastPosition > this.msThreshold) {
      this.groupsSortedPosition = [...this.getGroupsUnsorted(now)].sort((a, b) => {
        const aY = Math.floor(a._pos[1] / 30);
        const bY = Math.floor(b._pos[1] / 30);
        if (aY == bY) {
          const aX = Math.floor(a._pos[0] / 30);
          const bX = Math.floor(b._pos[0] / 30);
          return aX - bX;
        }
        return aY - bY;
      });
      this.msLastPosition = now;
    }
    return this.groupsSortedPosition;
  }

  /**
   * Returns the groups in the requested order.
   * @param sort "alphanumeric", "position", or anything else for graph order.
   */
  getGroups(sort) {
    const now = Date.now();
    if (sort === "alphanumeric") {
      return this.getGroupsAlpha(now);
    }
    if (sort === "position") {
      return this.getGroupsPosition(now);
    }
    return this.getGroupsUnsorted(now);
  }
}

export const SERVICE = new FastGroupsService();
|
rgthree-comfy/web/common/css/buttons.css
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/* Button styles for rgthree pages. The leading `:not(#fakeid)` never matches
 * an id, so it exists purely to raise selector specificity above host-page
 * styles. */

/* Strips native button chrome so a <button> renders like inline content. */
:not(#fakeid) .rgthree-button-reset {
  position: relative;
  appearance: none;
  cursor: pointer;
  border: 0;
  background: transparent;
  color: inherit;
  padding: 0;
  margin: 0;
}

/* Base button. Padding is driven by custom properties so the :active state
 * can nudge it to fake a "pressed" shift. */
:not(#fakeid) .rgthree-button {
  --padding-top: 7px;
  --padding-bottom: 9px;
  --padding-x: 16px;
  position: relative;
  cursor: pointer;
  border: 0;
  border-radius: 0.25rem;
  background: rgba(0, 0, 0, 0.5);
  color: white;
  font-family: system-ui, sans-serif;
  font-size: 1rem;
  line-height: 1;
  white-space: nowrap;
  text-decoration: none;
  margin: 0.25rem;
  box-shadow: 0px 0px 2px rgb(0, 0, 0);
  background: #212121; /* overrides the translucent background above */
  transition: all 0.1s ease-in-out;
  padding: var(--padding-top) var(--padding-x) var(--padding-bottom);
  display: inline-flex;
  flex-direction: row;
  align-items: center;
  justify-content: center;
}
/* Overlay layers providing the bevel highlight/shadow; ::before blends with
 * screen, ::after (below) with multiply. */
:not(#fakeid) .rgthree-button::before, :not(#fakeid) .rgthree-button::after {
  content: "";
  display: block;
  position: absolute;
  border-radius: 0.25rem;
  left: 0;
  top: 0;
  width: 100%;
  height: 100%;
  box-shadow: inset 1px 1px 0px rgba(255, 255, 255, 0.12), inset -1px -1px 0px rgba(0, 0, 0, 0.75);
  background: linear-gradient(to bottom, rgba(255, 255, 255, 0.06), rgba(0, 0, 0, 0.15));
  mix-blend-mode: screen;
}
:not(#fakeid) .rgthree-button::after {
  mix-blend-mode: multiply;
}
:not(#fakeid) .rgthree-button:hover {
  background: #303030;
}
/* Pressed state: drops the outer shadow and shifts padding one pixel to
 * simulate the button sinking. */
:not(#fakeid) .rgthree-button:active {
  box-shadow: 0px 0px 0px rgba(0, 0, 0, 0);
  background: #121212;
  padding: calc(var(--padding-top) + 1px) calc(var(--padding-x) - 1px) calc(var(--padding-bottom) - 1px) calc(var(--padding-x) + 1px);
}
:not(#fakeid) .rgthree-button:active::before, :not(#fakeid) .rgthree-button:active::after {
  box-shadow: 1px 1px 0px rgba(255, 255, 255, 0.15), inset 1px 1px 0px rgba(0, 0, 0, 0.5), inset 1px 3px 5px rgba(0, 0, 0, 0.33);
}
/* Color variants (class names start with "-"). */
:not(#fakeid) .rgthree-button.-blue {
  background: #346599 !important;
}
:not(#fakeid) .rgthree-button.-blue:hover {
  background: #3b77b8 !important;
}
:not(#fakeid) .rgthree-button.-blue:active {
  background: #1d5086 !important;
}
:not(#fakeid) .rgthree-button.-green {
  background: linear-gradient(to bottom, rgba(255, 255, 255, 0.06), rgba(0, 0, 0, 0.15)), #14580b;
}
:not(#fakeid) .rgthree-button.-green:hover {
  background: linear-gradient(to bottom, rgba(255, 255, 255, 0.06), rgba(0, 0, 0, 0.15)), #1a6d0f;
}
:not(#fakeid) .rgthree-button.-green:active {
  background: linear-gradient(to bottom, rgba(0, 0, 0, 0.15), rgba(255, 255, 255, 0.06)), #0f3f09;
}
/* Disabled: flat gray, no bevel overlays, no pointer interaction. */
:not(#fakeid) .rgthree-button[disabled] {
  box-shadow: none;
  background: #666 !important;
  color: #aaa;
  pointer-events: none;
}
:not(#fakeid) .rgthree-button[disabled]::before, :not(#fakeid) .rgthree-button[disabled]::after {
  display: none;
}
|
rgthree-comfy/web/common/css/dialog.css
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
@charset "UTF-8";
/* Styles for the rgthree <dialog>-based modal and its container sections. */
.rgthree-dialog {
  outline: 0;
  border: 0;
  border-radius: 6px;
  background: #414141;
  color: #fff;
  box-shadow: inset 1px 1px 0px rgba(255, 255, 255, 0.05), inset -1px -1px 0px rgba(0, 0, 0, 0.5), 2px 2px 20px rgb(0, 0, 0);
  max-width: 800px;
  box-sizing: border-box;
  font-family: "Segoe UI", Tahoma, Geneva, Verdana, sans-serif;
  font-size: 1rem;
  padding: 0;
  max-height: calc(100% - 32px);
}
/* Inherit border-box sizing everywhere inside the dialog. */
.rgthree-dialog *, .rgthree-dialog *::before, .rgthree-dialog *::after {
  box-sizing: inherit;
}

/* Uniform padding for each dialog section, with extra space at the
 * container's top and bottom edges. */
.rgthree-dialog-container > * {
  padding: 8px 16px;
}
.rgthree-dialog-container > *:first-child {
  padding-top: 16px;
}
.rgthree-dialog-container > *:last-child {
  padding-bottom: 16px;
}

/* Large faded watermark icon in the dialog's bottom-right corner; the
 * actual glyph is supplied by the "-help"/"-settings" variants below. */
.rgthree-dialog.-iconed::after {
  content: "";
  font-size: 276px;
  position: absolute;
  right: 0px;
  bottom: 0px;
  opacity: 0.15;
  display: block;
  width: 237px;
  overflow: hidden;
  height: 186px;
  line-height: 1;
  pointer-events: none;
  z-index: -1;
}

.rgthree-dialog.-iconed.-help::after {
  content: "🛟";
}

.rgthree-dialog.-iconed.-settings::after {
  content: "⚙️";
}

/* Below 832px (800px max-width + 2 x 16px margin) let the dialog shrink. */
@media (max-width: 832px) {
  .rgthree-dialog {
    max-width: calc(100% - 32px);
  }
}
/* Title row: optional leading icon followed by the heading. */
.rgthree-dialog-container-title {
  display: flex;
  flex-direction: row;
  align-items: center;
  justify-content: start;
}

.rgthree-dialog-container-title > svg:first-child {
  width: 36px;
  height: 36px;
  margin-right: 16px;
}

.rgthree-dialog-container-title h2 {
  font-size: 1.375rem;
  margin: 0;
  font-weight: bold;
}

.rgthree-dialog-container-title h2 small {
  font-size: 0.8125rem;
  font-weight: normal;
  opacity: 0.75;
}

/* Scrollable content area. */
.rgthree-dialog-container-content {
  overflow: auto;
  max-height: calc(100vh - 200px); /* Arbitrary height to compensate for margin, title, and footer. */
}

.rgthree-dialog-container-content p {
  font-size: 0.8125rem;
  margin-top: 0;
}

.rgthree-dialog-container-content ul li p {
  margin-bottom: 4px;
}

.rgthree-dialog-container-content ul li p + p {
  margin-top: 0.5em;
}

.rgthree-dialog-container-content ul li ul {
  margin-top: 0.5em;
  margin-bottom: 1em;
}

/* Inline code chips inside content paragraphs. */
.rgthree-dialog-container-content p code {
  display: inline-block;
  padding: 2px 4px;
  margin: 0px 2px;
  border: 1px solid rgba(255, 255, 255, 0.25);
  border-radius: 3px;
  background: rgba(255, 255, 255, 0.1);
}

.rgthree-dialog-container-footer {
  display: flex;
  align-items: center;
  justify-content: center;
}

/* While a dialog is open, blur everything behind it except the dialog
 * itself and the toast-message container. */
body.rgthree-dialog-open > *:not(.rgthree-dialog):not(.rgthree-top-messages-container) {
  filter: blur(5px);
}
|
rgthree-comfy/web/common/css/dialog_model_info.css
ADDED
|
@@ -0,0 +1,333 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.rgthree-info-dialog {
|
| 2 |
+
width: 90vw;
|
| 3 |
+
max-width: 960px;
|
| 4 |
+
}
|
| 5 |
+
.rgthree-info-dialog .rgthree-info-area {
|
| 6 |
+
list-style: none;
|
| 7 |
+
padding: 0;
|
| 8 |
+
margin: 0;
|
| 9 |
+
display: flex;
|
| 10 |
+
}
|
| 11 |
+
.rgthree-info-dialog .rgthree-info-area > li {
|
| 12 |
+
display: inline-flex;
|
| 13 |
+
margin: 0;
|
| 14 |
+
vertical-align: top;
|
| 15 |
+
}
|
| 16 |
+
.rgthree-info-dialog .rgthree-info-area > li + li {
|
| 17 |
+
margin-left: 6px;
|
| 18 |
+
}
|
| 19 |
+
.rgthree-info-dialog .rgthree-info-area > li:not(.-link) + li.-link {
|
| 20 |
+
margin-left: auto;
|
| 21 |
+
}
|
| 22 |
+
.rgthree-info-dialog .rgthree-info-area > li.rgthree-info-tag > * {
|
| 23 |
+
min-height: 24px;
|
| 24 |
+
border-radius: 4px;
|
| 25 |
+
line-height: 1;
|
| 26 |
+
color: rgba(255, 255, 255, 0.85);
|
| 27 |
+
background: rgb(69, 92, 85);
|
| 28 |
+
font-size: 14px;
|
| 29 |
+
font-weight: bold;
|
| 30 |
+
text-decoration: none;
|
| 31 |
+
display: flex;
|
| 32 |
+
height: 1.6em;
|
| 33 |
+
padding-left: 0.5em;
|
| 34 |
+
padding-right: 0.5em;
|
| 35 |
+
padding-bottom: 0.1em;
|
| 36 |
+
align-content: center;
|
| 37 |
+
justify-content: center;
|
| 38 |
+
align-items: center;
|
| 39 |
+
box-shadow: inset 0px 0px 0 1px rgba(0, 0, 0, 0.5);
|
| 40 |
+
}
|
| 41 |
+
.rgthree-info-dialog .rgthree-info-area > li.rgthree-info-tag > * > svg {
|
| 42 |
+
width: 16px;
|
| 43 |
+
height: 16px;
|
| 44 |
+
}
|
| 45 |
+
.rgthree-info-dialog .rgthree-info-area > li.rgthree-info-tag > * > svg:last-child {
|
| 46 |
+
margin-left: 0.5em;
|
| 47 |
+
}
|
| 48 |
+
.rgthree-info-dialog .rgthree-info-area > li.rgthree-info-tag > *[href] {
|
| 49 |
+
box-shadow: inset 0px 1px 0px rgba(255, 255, 255, 0.25), inset 0px -1px 0px rgba(0, 0, 0, 0.66);
|
| 50 |
+
}
|
| 51 |
+
.rgthree-info-dialog .rgthree-info-area > li.rgthree-info-tag > *:empty {
|
| 52 |
+
display: none;
|
| 53 |
+
}
|
| 54 |
+
.rgthree-info-dialog .rgthree-info-area > li.-type > * {
|
| 55 |
+
background: rgb(73, 54, 94);
|
| 56 |
+
color: rgb(228, 209, 248);
|
| 57 |
+
}
|
| 58 |
+
.rgthree-info-dialog .rgthree-info-area > li.rgthree-info-menu {
|
| 59 |
+
margin-left: auto;
|
| 60 |
+
}
|
| 61 |
+
:not(#fakeid) .rgthree-info-dialog .rgthree-info-area > li.rgthree-info-menu .rgthree-button {
|
| 62 |
+
margin: 0;
|
| 63 |
+
min-height: 24px;
|
| 64 |
+
padding: 0 12px;
|
| 65 |
+
}
|
| 66 |
+
.rgthree-info-dialog .rgthree-info-area > li.rgthree-info-menu svg {
|
| 67 |
+
width: 16px;
|
| 68 |
+
height: 16px;
|
| 69 |
+
}
|
| 70 |
+
.rgthree-info-dialog .rgthree-info-table {
|
| 71 |
+
border-collapse: collapse;
|
| 72 |
+
margin: 16px 0px;
|
| 73 |
+
width: 100%;
|
| 74 |
+
font-size: 12px;
|
| 75 |
+
}
|
| 76 |
+
.rgthree-info-dialog .rgthree-info-table tr.editable button {
|
| 77 |
+
display: flex;
|
| 78 |
+
width: 28px;
|
| 79 |
+
height: 28px;
|
| 80 |
+
align-items: center;
|
| 81 |
+
justify-content: center;
|
| 82 |
+
}
|
| 83 |
+
.rgthree-info-dialog .rgthree-info-table tr.editable button svg + svg {
|
| 84 |
+
display: none;
|
| 85 |
+
}
|
| 86 |
+
.rgthree-info-dialog .rgthree-info-table tr.editable.-rgthree-editing button svg {
|
| 87 |
+
display: none;
|
| 88 |
+
}
|
| 89 |
+
.rgthree-info-dialog .rgthree-info-table tr.editable.-rgthree-editing button svg + svg {
|
| 90 |
+
display: inline-block;
|
| 91 |
+
}
|
| 92 |
+
.rgthree-info-dialog .rgthree-info-table td {
|
| 93 |
+
position: relative;
|
| 94 |
+
border: 1px solid rgba(255, 255, 255, 0.25);
|
| 95 |
+
padding: 0;
|
| 96 |
+
vertical-align: top;
|
| 97 |
+
}
|
| 98 |
+
.rgthree-info-dialog .rgthree-info-table td:first-child {
|
| 99 |
+
background: rgba(255, 255, 255, 0.075);
|
| 100 |
+
width: 10px;
|
| 101 |
+
}
|
| 102 |
+
.rgthree-info-dialog .rgthree-info-table td:first-child > *:first-child {
|
| 103 |
+
white-space: nowrap;
|
| 104 |
+
padding-right: 32px;
|
| 105 |
+
}
|
| 106 |
+
.rgthree-info-dialog .rgthree-info-table td:first-child small {
|
| 107 |
+
display: block;
|
| 108 |
+
margin-top: 2px;
|
| 109 |
+
opacity: 0.75;
|
| 110 |
+
}
|
| 111 |
+
.rgthree-info-dialog .rgthree-info-table td:first-child small > [data-action] {
|
| 112 |
+
text-decoration: underline;
|
| 113 |
+
cursor: pointer;
|
| 114 |
+
}
|
| 115 |
+
.rgthree-info-dialog .rgthree-info-table td:first-child small > [data-action]:hover {
|
| 116 |
+
text-decoration: none;
|
| 117 |
+
}
|
| 118 |
+
.rgthree-info-dialog .rgthree-info-table td a, .rgthree-info-dialog .rgthree-info-table td a:hover, .rgthree-info-dialog .rgthree-info-table td a:visited {
|
| 119 |
+
color: inherit;
|
| 120 |
+
}
|
| 121 |
+
.rgthree-info-dialog .rgthree-info-table td svg {
|
| 122 |
+
width: 1.3333em;
|
| 123 |
+
height: 1.3333em;
|
| 124 |
+
vertical-align: -0.285em;
|
| 125 |
+
}
|
| 126 |
+
.rgthree-info-dialog .rgthree-info-table td svg.logo-civitai {
|
| 127 |
+
margin-right: 0.3333em;
|
| 128 |
+
}
|
| 129 |
+
.rgthree-info-dialog .rgthree-info-table td > *:first-child {
|
| 130 |
+
display: block;
|
| 131 |
+
padding: 6px 10px;
|
| 132 |
+
}
|
| 133 |
+
.rgthree-info-dialog .rgthree-info-table td > input, .rgthree-info-dialog .rgthree-info-table td > textarea {
|
| 134 |
+
padding: 5px 10px;
|
| 135 |
+
border: 0;
|
| 136 |
+
box-shadow: inset 1px 1px 5px 0px rgba(0, 0, 0, 0.5);
|
| 137 |
+
font: inherit;
|
| 138 |
+
appearance: none;
|
| 139 |
+
background: #fff;
|
| 140 |
+
color: #121212;
|
| 141 |
+
resize: vertical;
|
| 142 |
+
}
|
| 143 |
+
.rgthree-info-dialog .rgthree-info-table td > input:only-child, .rgthree-info-dialog .rgthree-info-table td > textarea:only-child {
|
| 144 |
+
width: 100%;
|
| 145 |
+
}
|
| 146 |
+
:not(#fakeid) .rgthree-info-dialog .rgthree-info-table td .rgthree-button[data-action=fetch-civitai] {
|
| 147 |
+
font-size: inherit;
|
| 148 |
+
padding: 6px 16px;
|
| 149 |
+
margin: 2px;
|
| 150 |
+
}
|
| 151 |
+
.rgthree-info-dialog .rgthree-info-table tr[data-field-name=userNote] td > span:first-child {
|
| 152 |
+
white-space: pre;
|
| 153 |
+
}
|
| 154 |
+
.rgthree-info-dialog .rgthree-info-table tr.rgthree-info-table-break-row td {
|
| 155 |
+
border: 0;
|
| 156 |
+
background: transparent;
|
| 157 |
+
padding: 12px 4px 4px;
|
| 158 |
+
font-size: 1.2em;
|
| 159 |
+
}
|
| 160 |
+
.rgthree-info-dialog .rgthree-info-table tr.rgthree-info-table-break-row td > small {
|
| 161 |
+
font-style: italic;
|
| 162 |
+
opacity: 0.66;
|
| 163 |
+
}
|
| 164 |
+
.rgthree-info-dialog .rgthree-info-table tr.rgthree-info-table-break-row td:empty {
|
| 165 |
+
padding: 4px;
|
| 166 |
+
}
|
| 167 |
+
.rgthree-info-dialog .rgthree-info-table td .-help {
|
| 168 |
+
border: 1px solid currentColor;
|
| 169 |
+
position: absolute;
|
| 170 |
+
right: 5px;
|
| 171 |
+
top: 6px;
|
| 172 |
+
line-height: 1;
|
| 173 |
+
font-size: 11px;
|
| 174 |
+
width: 12px;
|
| 175 |
+
height: 12px;
|
| 176 |
+
border-radius: 8px;
|
| 177 |
+
display: flex;
|
| 178 |
+
align-content: center;
|
| 179 |
+
justify-content: center;
|
| 180 |
+
cursor: help;
|
| 181 |
+
}
|
| 182 |
+
.rgthree-info-dialog .rgthree-info-table td .-help::before {
|
| 183 |
+
content: "?";
|
| 184 |
+
}
|
| 185 |
+
.rgthree-info-dialog .rgthree-info-table td > ul.rgthree-info-trained-words-list {
|
| 186 |
+
list-style: none;
|
| 187 |
+
padding: 2px 8px;
|
| 188 |
+
margin: 0;
|
| 189 |
+
display: flex;
|
| 190 |
+
flex-direction: row;
|
| 191 |
+
flex-wrap: wrap;
|
| 192 |
+
max-height: 15vh;
|
| 193 |
+
overflow: auto;
|
| 194 |
+
}
|
| 195 |
+
.rgthree-info-dialog .rgthree-info-table td > ul.rgthree-info-trained-words-list > li {
|
| 196 |
+
display: inline-flex;
|
| 197 |
+
margin: 2px;
|
| 198 |
+
vertical-align: top;
|
| 199 |
+
border-radius: 4px;
|
| 200 |
+
line-height: 1;
|
| 201 |
+
color: rgba(255, 255, 255, 0.85);
|
| 202 |
+
background: rgb(73, 91, 106);
|
| 203 |
+
font-size: 1.2em;
|
| 204 |
+
font-weight: 600;
|
| 205 |
+
text-decoration: none;
|
| 206 |
+
display: flex;
|
| 207 |
+
height: 1.6em;
|
| 208 |
+
align-content: center;
|
| 209 |
+
justify-content: center;
|
| 210 |
+
align-items: center;
|
| 211 |
+
box-shadow: inset 0px 0px 0 1px rgba(0, 0, 0, 0.5);
|
| 212 |
+
cursor: pointer;
|
| 213 |
+
white-space: nowrap;
|
| 214 |
+
max-width: 183px;
|
| 215 |
+
}
|
| 216 |
+
.rgthree-info-dialog .rgthree-info-table td > ul.rgthree-info-trained-words-list > li:hover {
|
| 217 |
+
background: rgb(68, 109, 142);
|
| 218 |
+
}
|
| 219 |
+
.rgthree-info-dialog .rgthree-info-table td > ul.rgthree-info-trained-words-list > li > svg {
|
| 220 |
+
width: auto;
|
| 221 |
+
height: 1.2em;
|
| 222 |
+
}
|
| 223 |
+
.rgthree-info-dialog .rgthree-info-table td > ul.rgthree-info-trained-words-list > li > span {
|
| 224 |
+
padding-left: 0.5em;
|
| 225 |
+
padding-right: 0.5em;
|
| 226 |
+
padding-bottom: 0.1em;
|
| 227 |
+
text-overflow: ellipsis;
|
| 228 |
+
overflow: hidden;
|
| 229 |
+
}
|
| 230 |
+
.rgthree-info-dialog .rgthree-info-table td > ul.rgthree-info-trained-words-list > li > small {
|
| 231 |
+
align-self: stretch;
|
| 232 |
+
display: flex;
|
| 233 |
+
align-items: center;
|
| 234 |
+
justify-content: center;
|
| 235 |
+
padding: 0 0.5em;
|
| 236 |
+
background: rgba(0, 0, 0, 0.2);
|
| 237 |
+
}
|
| 238 |
+
.rgthree-info-dialog .rgthree-info-table td > ul.rgthree-info-trained-words-list > li.-rgthree-is-selected {
|
| 239 |
+
background: rgb(42, 126, 193);
|
| 240 |
+
}
|
| 241 |
+
.rgthree-info-dialog .rgthree-info-images {
|
| 242 |
+
list-style: none;
|
| 243 |
+
padding: 0;
|
| 244 |
+
margin: 0;
|
| 245 |
+
scroll-snap-type: x mandatory;
|
| 246 |
+
display: flex;
|
| 247 |
+
flex-direction: row;
|
| 248 |
+
overflow: auto;
|
| 249 |
+
}
|
| 250 |
+
.rgthree-info-dialog .rgthree-info-images > li {
|
| 251 |
+
scroll-snap-align: start;
|
| 252 |
+
max-width: 90%;
|
| 253 |
+
flex: 0 0 auto;
|
| 254 |
+
display: flex;
|
| 255 |
+
align-items: center;
|
| 256 |
+
justify-content: center;
|
| 257 |
+
flex-direction: column;
|
| 258 |
+
overflow: hidden;
|
| 259 |
+
padding: 0;
|
| 260 |
+
margin: 6px;
|
| 261 |
+
font-size: 0;
|
| 262 |
+
position: relative;
|
| 263 |
+
}
|
| 264 |
+
.rgthree-info-dialog .rgthree-info-images > li figure {
|
| 265 |
+
margin: 0;
|
| 266 |
+
position: static;
|
| 267 |
+
}
|
| 268 |
+
.rgthree-info-dialog .rgthree-info-images > li figure figcaption {
|
| 269 |
+
position: absolute;
|
| 270 |
+
left: 0;
|
| 271 |
+
width: 100%;
|
| 272 |
+
bottom: 0;
|
| 273 |
+
padding: 12px;
|
| 274 |
+
font-size: 12px;
|
| 275 |
+
background: rgba(0, 0, 0, 0.85);
|
| 276 |
+
opacity: 0;
|
| 277 |
+
transform: translateY(50px);
|
| 278 |
+
transition: all 0.25s ease-in-out;
|
| 279 |
+
}
|
| 280 |
+
.rgthree-info-dialog .rgthree-info-images > li figure figcaption > span {
|
| 281 |
+
display: inline-block;
|
| 282 |
+
padding: 2px 4px;
|
| 283 |
+
margin: 2px;
|
| 284 |
+
border-radius: 2px;
|
| 285 |
+
border: 1px solid rgba(255, 255, 255, 0.2);
|
| 286 |
+
word-break: break-word;
|
| 287 |
+
}
|
| 288 |
+
.rgthree-info-dialog .rgthree-info-images > li figure figcaption > span label {
|
| 289 |
+
display: inline;
|
| 290 |
+
padding: 0;
|
| 291 |
+
margin: 0;
|
| 292 |
+
opacity: 0.5;
|
| 293 |
+
pointer-events: none;
|
| 294 |
+
user-select: none;
|
| 295 |
+
}
|
| 296 |
+
.rgthree-info-dialog .rgthree-info-images > li figure figcaption > span a {
|
| 297 |
+
color: inherit;
|
| 298 |
+
text-decoration: underline;
|
| 299 |
+
}
|
| 300 |
+
.rgthree-info-dialog .rgthree-info-images > li figure figcaption > span a:hover {
|
| 301 |
+
text-decoration: none;
|
| 302 |
+
}
|
| 303 |
+
.rgthree-info-dialog .rgthree-info-images > li figure figcaption > span a svg {
|
| 304 |
+
height: 10px;
|
| 305 |
+
margin-left: 4px;
|
| 306 |
+
fill: currentColor;
|
| 307 |
+
}
|
| 308 |
+
.rgthree-info-dialog .rgthree-info-images > li figure figcaption:empty {
|
| 309 |
+
text-align: center;
|
| 310 |
+
}
|
| 311 |
+
.rgthree-info-dialog .rgthree-info-images > li figure figcaption:empty::before {
|
| 312 |
+
content: "No data.";
|
| 313 |
+
}
|
| 314 |
+
.rgthree-info-dialog .rgthree-info-images > li:hover figure figcaption {
|
| 315 |
+
opacity: 1;
|
| 316 |
+
transform: translateY(0px);
|
| 317 |
+
}
|
| 318 |
+
.rgthree-info-dialog .rgthree-info-images > li .rgthree-info-table {
|
| 319 |
+
width: calc(100% - 16px);
|
| 320 |
+
}
|
| 321 |
+
.rgthree-info-dialog .rgthree-info-civitai-link {
|
| 322 |
+
margin: 8px;
|
| 323 |
+
color: #eee;
|
| 324 |
+
}
|
| 325 |
+
.rgthree-info-dialog .rgthree-info-civitai-link a, .rgthree-info-dialog .rgthree-info-civitai-link a:hover, .rgthree-info-dialog .rgthree-info-civitai-link a:visited {
|
| 326 |
+
color: inherit;
|
| 327 |
+
text-decoration: none;
|
| 328 |
+
}
|
| 329 |
+
.rgthree-info-dialog .rgthree-info-civitai-link > svg {
|
| 330 |
+
width: 16px;
|
| 331 |
+
height: 16px;
|
| 332 |
+
margin-right: 8px;
|
| 333 |
+
}
|
rgthree-comfy/web/common/css/menu.css
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.rgthree-menu {
|
| 2 |
+
list-style: none;
|
| 3 |
+
padding: 0;
|
| 4 |
+
margin: 0;
|
| 5 |
+
position: fixed;
|
| 6 |
+
z-index: 999999;
|
| 7 |
+
pointer-events: none;
|
| 8 |
+
opacity: 0;
|
| 9 |
+
transition: opacity 0.08s ease-in-out;
|
| 10 |
+
color: #dde;
|
| 11 |
+
background-color: #111;
|
| 12 |
+
font-size: 12px;
|
| 13 |
+
box-shadow: 0 0 10px black !important;
|
| 14 |
+
}
|
| 15 |
+
.rgthree-menu > li {
|
| 16 |
+
position: relative;
|
| 17 |
+
padding: 4px 6px;
|
| 18 |
+
z-index: 9999;
|
| 19 |
+
white-space: nowrap;
|
| 20 |
+
}
|
| 21 |
+
.rgthree-menu > li[role=button] {
|
| 22 |
+
background-color: var(--comfy-menu-bg) !important;
|
| 23 |
+
color: var(--input-text);
|
| 24 |
+
cursor: pointer;
|
| 25 |
+
}
|
| 26 |
+
.rgthree-menu > li[role=button]:hover {
|
| 27 |
+
filter: brightness(155%);
|
| 28 |
+
}
|
| 29 |
+
.rgthree-menu[state^=measuring] {
|
| 30 |
+
display: block;
|
| 31 |
+
opacity: 0;
|
| 32 |
+
}
|
| 33 |
+
.rgthree-menu[state=open] {
|
| 34 |
+
display: block;
|
| 35 |
+
opacity: 1;
|
| 36 |
+
pointer-events: all;
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
.rgthree-top-menu {
|
| 40 |
+
box-sizing: border-box;
|
| 41 |
+
white-space: nowrap;
|
| 42 |
+
background: var(--content-bg);
|
| 43 |
+
color: var(--content-fg);
|
| 44 |
+
display: flex;
|
| 45 |
+
flex-direction: column;
|
| 46 |
+
}
|
| 47 |
+
.rgthree-top-menu * {
|
| 48 |
+
box-sizing: inherit;
|
| 49 |
+
}
|
| 50 |
+
.rgthree-top-menu menu {
|
| 51 |
+
list-style: none;
|
| 52 |
+
padding: 0;
|
| 53 |
+
margin: 0;
|
| 54 |
+
}
|
| 55 |
+
.rgthree-top-menu menu > li:not(#fakeid) {
|
| 56 |
+
list-style: none;
|
| 57 |
+
padding: 0;
|
| 58 |
+
margin: 0;
|
| 59 |
+
}
|
| 60 |
+
.rgthree-top-menu menu > li:not(#fakeid) > button {
|
| 61 |
+
cursor: pointer;
|
| 62 |
+
padding: 8px 12px 8px 8px;
|
| 63 |
+
width: 100%;
|
| 64 |
+
text-align: start;
|
| 65 |
+
display: flex;
|
| 66 |
+
flex-direction: row;
|
| 67 |
+
align-items: center;
|
| 68 |
+
justify-content: start;
|
| 69 |
+
}
|
| 70 |
+
.rgthree-top-menu menu > li:not(#fakeid) > button:hover {
|
| 71 |
+
background-color: var(--comfy-input-bg);
|
| 72 |
+
}
|
| 73 |
+
.rgthree-top-menu menu > li:not(#fakeid) > button svg {
|
| 74 |
+
height: 16px;
|
| 75 |
+
width: auto;
|
| 76 |
+
margin-inline-end: 0.6em;
|
| 77 |
+
}
|
| 78 |
+
.rgthree-top-menu menu > li:not(#fakeid) > button svg.github-star {
|
| 79 |
+
fill: rgb(227, 179, 65);
|
| 80 |
+
}
|
| 81 |
+
.rgthree-top-menu menu > li:not(#fakeid).rgthree-message {
|
| 82 |
+
min-height: 32px;
|
| 83 |
+
}
|
| 84 |
+
.rgthree-top-menu menu > li:not(#fakeid).rgthree-message > span {
|
| 85 |
+
padding: 8px 12px;
|
| 86 |
+
display: block;
|
| 87 |
+
width: 100%;
|
| 88 |
+
text-align: center;
|
| 89 |
+
font-style: italic;
|
| 90 |
+
font-size: 12px;
|
| 91 |
+
}
|
rgthree-comfy/web/common/css/pages_base.css
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
html {
|
| 2 |
+
font-size: 100%;
|
| 3 |
+
overflow-y: scroll;
|
| 4 |
+
-webkit-text-size-adjust: 100%;
|
| 5 |
+
-ms-text-size-adjust: 100%;
|
| 6 |
+
box-sizing: border-box;
|
| 7 |
+
}
|
| 8 |
+
|
| 9 |
+
*, *:before, *:after {
|
| 10 |
+
box-sizing: inherit;
|
| 11 |
+
}
|
| 12 |
+
|
| 13 |
+
:root {
|
| 14 |
+
--header-height: 56px;
|
| 15 |
+
--progress-height: 12px;
|
| 16 |
+
}
|
| 17 |
+
|
| 18 |
+
button {
|
| 19 |
+
all: unset;
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
.-bevel {
|
| 23 |
+
position: relative;
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
.-bevel::before {
|
| 27 |
+
content: "";
|
| 28 |
+
position: absolute;
|
| 29 |
+
left: 0;
|
| 30 |
+
top: 0;
|
| 31 |
+
width: 100%;
|
| 32 |
+
height: 100%;
|
| 33 |
+
border: 1px solid red;
|
| 34 |
+
border-color: rgba(255, 255, 255, 0.15) rgba(255, 255, 255, 0.15) rgba(0, 0, 0, 0.5) rgba(0, 0, 0, 0.5);
|
| 35 |
+
z-index: 5;
|
| 36 |
+
pointer-events: none;
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
body {
|
| 40 |
+
background: #202020;
|
| 41 |
+
font-family: Arial, sans-serif;
|
| 42 |
+
font-size: 1rem;
|
| 43 |
+
font-weight: 400;
|
| 44 |
+
margin: 0;
|
| 45 |
+
padding-top: calc(var(--header-height) + var(--progress-height));
|
| 46 |
+
color: #ffffff;
|
| 47 |
+
display: flex;
|
| 48 |
+
flex-direction: column;
|
| 49 |
+
align-items: center;
|
| 50 |
+
justify-content: start;
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
.app-header {
|
| 54 |
+
height: var(--header-height);
|
| 55 |
+
padding: 0;
|
| 56 |
+
position: fixed;
|
| 57 |
+
z-index: 99;
|
| 58 |
+
top: 0;
|
| 59 |
+
left: 0;
|
| 60 |
+
width: 100%;
|
| 61 |
+
background: #353535;
|
| 62 |
+
display: flex;
|
| 63 |
+
flex-direction: row;
|
| 64 |
+
align-items: center;
|
| 65 |
+
justify-content: start;
|
| 66 |
+
}
|
rgthree-comfy/web/common/media/rgthree.svg
ADDED
|
|
rgthree-comfy/web/common/media/svgs.js
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { createElement as $el } from "../utils_dom.js";
|
| 2 |
+
export const logoRgthree = `<svg viewBox="0 0 256 256" fill="currentColor" class="rgthree-logo">
|
| 3 |
+
<path d="M88.503,158.997 L152.731,196.103 L152.738,196.092 L152.762,196.103 L152.769,196.106 L152.771,196.103 L183.922,142.084 L174.153,136.437 L148.611,180.676 L101.512,153.484 L132.193,30.415 L156.124,71.869 L165.896,66.225 L128.002,0.59 "></path>
|
| 4 |
+
<path d="M55.586,148.581l13.44,47.521l0.014,0.051l0.168-0.051l10.689-3.022l-6.589-23.313l45.609,26.335l0.087,0.051l0.027-0.051 l5.617-9.718l-42.648-24.622l35.771-143.45L33.232,164.729l9.77,5.645L55.586,148.581z M87.394,93.484l-16.708,67.018l-5.018-17.747 l-8.028,2.27L87.394,93.484z"></path>
|
| 5 |
+
<path d="M189.85,107.717 L137.892,137.718 L143.532,147.49 L185.723,123.133 L231.109,201.746 L24.895,201.746 L37.363,180.146 L27.592,174.505 L5.347,213.03 L250.653,213.03 "></path>
|
| 6 |
+
<path d="M5.347,247.299v8.111h245.307v-8.111l-41.94-0.003c-1.336,0-2.404-1.065-2.441-2.396v-12.14 c0.037-1.315,1.089-2.368,2.41-2.385h41.972v-8.11H5.347v8.11h41.951c1.338,0.017,2.427,1.104,2.427,2.449v12.01 c0,1.365-1.105,2.462-2.457,2.462L5.347,247.299z M139.438,247.296c-1.334,0-2.406-1.065-2.439-2.396v-12.14 c0.033-1.315,1.085-2.368,2.41-2.385h46.415c1.335,0.017,2.425,1.104,2.425,2.449v12.01c0,1.365-1.103,2.462-2.459,2.462H139.438z M70.193,247.296c-1.339,0-2.408-1.065-2.441-2.396v-12.14c0.033-1.315,1.086-2.368,2.407-2.385h46.418 c1.336,0.017,2.425,1.104,2.425,2.449v12.01c0,1.365-1.103,2.462-2.458,2.462H70.193z"></path>
|
| 7 |
+
</svg>`;
|
| 8 |
+
export const github = `<svg viewBox="0 0 16 16" fill="currentColor" class="github-logo">
|
| 9 |
+
<path d="M8 0c4.42 0 8 3.58 8 8a8.013 8.013 0 0 1-5.45 7.59c-.4.08-.55-.17-.55-.38 0-.27.01-1.13.01-2.2 0-.75-.25-1.23-.54-1.48 1.78-.2 3.65-.88 3.65-3.95 0-.88-.31-1.59-.82-2.15.08-.2.36-1.02-.08-2.12 0 0-.67-.22-2.2.82-.64-.18-1.32-.27-2-.27-.68 0-1.36.09-2 .27-1.53-1.03-2.2-.82-2.2-.82-.44 1.1-.16 1.92-.08 2.12-.51.56-.82 1.28-.82 2.15 0 3.06 1.86 3.75 3.64 3.95-.23.2-.44.55-.51 1.07-.46.21-1.61.55-2.33-.66-.15-.24-.6-.83-1.23-.82-.67.01-.27.38.01.53.34.19.73.9.82 1.13.16.45.68 1.31 2.69.94 0 .67.01 1.3.01 1.49 0 .21-.15.45-.55.38A7.995 7.995 0 0 1 0 8c0-4.42 3.58-8 8-8Z"></path>
|
| 10 |
+
</svg>`;
|
| 11 |
+
export const iconStarFilled = `<svg viewBox="0 0 16 16" fill="currentColor" class="github-star">
|
| 12 |
+
<path d="M8 .25a.75.75 0 0 1 .673.418l1.882 3.815 4.21.612a.75.75 0 0 1 .416 1.279l-3.046 2.97.719 4.192a.751.751 0 0 1-1.088.791L8 12.347l-3.766 1.98a.75.75 0 0 1-1.088-.79l.72-4.194L.818 6.374a.75.75 0 0 1 .416-1.28l4.21-.611L7.327.668A.75.75 0 0 1 8 .25Z"></path>
|
| 13 |
+
</svg>`;
|
| 14 |
+
export const iconReplace = `<svg viewBox="0 0 52 52" fill="currentColor">
|
| 15 |
+
<path d="M20,37.5c0-0.8-0.7-1.5-1.5-1.5h-15C2.7,36,2,36.7,2,37.5v11C2,49.3,2.7,50,3.5,50h15c0.8,0,1.5-0.7,1.5-1.5 V37.5z"/>
|
| 16 |
+
<path d="M8.1,22H3.2c-1,0-1.5,0.9-0.9,1.4l8,8.3c0.4,0.3,1,0.3,1.4,0l8-8.3c0.6-0.6,0.1-1.4-0.9-1.4h-4.7 c0-5,4.9-10,9.9-10V6C15,6,8.1,13,8.1,22z"/>
|
| 17 |
+
<path d="M41.8,20.3c-0.4-0.3-1-0.3-1.4,0l-8,8.3c-0.6,0.6-0.1,1.4,0.9,1.4h4.8c0,6-4.1,10-10.1,10v6 c9,0,16.1-7,16.1-16H49c1,0,1.5-0.9,0.9-1.4L41.8,20.3z"/>
|
| 18 |
+
<path d="M50,3.5C50,2.7,49.3,2,48.5,2h-15C32.7,2,32,2.7,32,3.5v11c0,0.8,0.7,1.5,1.5,1.5h15c0.8,0,1.5-0.7,1.5-1.5 V3.5z"/>
|
| 19 |
+
</svg>`;
|
| 20 |
+
export const iconNode = `<svg viewBox="0 -0.5 25 25" fill="none">
|
| 21 |
+
<path fill-rule="evenodd" clip-rule="evenodd" d="M15.5 19H9.5C7.29086 19 5.5 17.2091 5.5 15V9C5.5 6.79086 7.29086 5 9.5 5H15.5C17.7091 5 19.5 6.79086 19.5 9V15C19.5 17.2091 17.7091 19 15.5 19Z" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
|
| 22 |
+
<path d="M19.5 9.75C19.9142 9.75 20.25 9.41421 20.25 9C20.25 8.58579 19.9142 8.25 19.5 8.25V9.75ZM5.5 8.25C5.08579 8.25 4.75 8.58579 4.75 9C4.75 9.41421 5.08579 9.75 5.5 9.75V8.25ZM11.5 14.25C11.0858 14.25 10.75 14.5858 10.75 15C10.75 15.4142 11.0858 15.75 11.5 15.75V14.25ZM13.5 15.75C13.9142 15.75 14.25 15.4142 14.25 15C14.25 14.5858 13.9142 14.25 13.5 14.25V15.75ZM19.5 8.25H5.5V9.75H19.5V8.25ZM11.5 15.75H13.5V14.25H11.5V15.75Z" fill="currentColor" />
|
| 23 |
+
</svg>`;
|
| 24 |
+
export const iconGear = `<svg viewBox="0 0 24 24" fill="currentColor">
|
| 25 |
+
<path fill-rule="evenodd" clip-rule="evenodd" d="M12.7848 0.449982C13.8239 0.449982 14.7167 1.16546 14.9122 2.15495L14.9991 2.59495C15.3408 4.32442 17.1859 5.35722 18.9016 4.7794L19.3383 4.63233C20.3199 4.30175 21.4054 4.69358 21.9249 5.56605L22.7097 6.88386C23.2293 7.75636 23.0365 8.86366 22.2504 9.52253L21.9008 9.81555C20.5267 10.9672 20.5267 13.0328 21.9008 14.1844L22.2504 14.4774C23.0365 15.1363 23.2293 16.2436 22.7097 17.1161L21.925 18.4339C21.4054 19.3064 20.3199 19.6982 19.3382 19.3676L18.9017 19.2205C17.1859 18.6426 15.3408 19.6754 14.9991 21.405L14.9122 21.845C14.7167 22.8345 13.8239 23.55 12.7848 23.55H11.2152C10.1761 23.55 9.28331 22.8345 9.08781 21.8451L9.00082 21.4048C8.65909 19.6754 6.81395 18.6426 5.09822 19.2205L4.66179 19.3675C3.68016 19.6982 2.59465 19.3063 2.07505 18.4338L1.2903 17.1161C0.770719 16.2436 0.963446 15.1363 1.74956 14.4774L2.09922 14.1844C3.47324 13.0327 3.47324 10.9672 2.09922 9.8156L1.74956 9.52254C0.963446 8.86366 0.77072 7.75638 1.2903 6.8839L2.07508 5.56608C2.59466 4.69359 3.68014 4.30176 4.66176 4.63236L5.09831 4.77939C6.81401 5.35722 8.65909 4.32449 9.00082 2.59506L9.0878 2.15487C9.28331 1.16542 10.176 0.449982 11.2152 0.449982H12.7848ZM12 15.3C13.8225 15.3 15.3 13.8225 15.3 12C15.3 10.1774 13.8225 8.69998 12 8.69998C10.1774 8.69998 8.69997 10.1774 8.69997 12C8.69997 13.8225 10.1774 15.3 12 15.3Z" />
|
| 26 |
+
</svg>`;
|
| 27 |
+
export const checkmark = `<svg viewBox="0 0 32 32" fill="currentColor" class="icon-checkmark">
|
| 28 |
+
<g transform="translate(-518.000000, -1039.000000)">
|
| 29 |
+
<path d="M548.783,1040.2 C547.188,1038.57 544.603,1038.57 543.008,1040.2 L528.569,1054.92 L524.96,1051.24 C523.365,1049.62 520.779,1049.62 519.185,1051.24 C517.59,1052.87 517.59,1055.51 519.185,1057.13 L525.682,1063.76 C527.277,1065.39 529.862,1065.39 531.457,1063.76 L548.783,1046.09 C550.378,1044.46 550.378,1041.82 548.783,1040.2"></path>
|
| 30 |
+
</g>
|
| 31 |
+
</svg>`;
|
| 32 |
+
export const logoCivitai = `<svg viewBox="0 0 178 178" class="logo-civitai">
|
| 33 |
+
<defs>
|
| 34 |
+
<linearGradient id="bgblue" gradientUnits="userSpaceOnUse" x1="89.3" y1="-665.5" x2="89.3" y2="-841.1" gradientTransform="matrix(1 0 0 -1 0 -664)">
|
| 35 |
+
<stop offset="0" style="stop-color:#1284F7"/>
|
| 36 |
+
<stop offset="1" style="stop-color:#0A20C9"/>
|
| 37 |
+
</linearGradient>
|
| 38 |
+
</defs>
|
| 39 |
+
<path fill="#000" d="M13.3,45.4v87.7l76,43.9l76-43.9V45.4l-76-43.9L13.3,45.4z"/>
|
| 40 |
+
<path style="fill:url(#bgblue);" d="M89.3,29.2l52,30v60l-52,30l-52-30v-60 L89.3,29.2 M89.3,1.5l-76,43.9v87.8l76,43.9l76-43.9V45.4L89.3,1.5z" />
|
| 41 |
+
<path fill="#FFF" d="M104.1,97.2l-14.9,8.5l-14.9-8.5v-17l14.9-8.5l14.9,8.5h18.2V69.7l-33-19l-33,19v38.1l33,19l33-19V97.2H104.1z" />
|
| 42 |
+
</svg>`;
|
| 43 |
+
export const iconOutLink = `<svg viewBox="0 0 32 32">
|
| 44 |
+
<path d="M 18 5 L 18 7 L 23.5625 7 L 11.28125 19.28125 L 12.71875 20.71875 L 25 8.4375 L 25 14 L 27 14 L 27 5 Z M 5 9 L 5 27 L 23 27 L 23 14 L 21 16 L 21 25 L 7 25 L 7 11 L 16 11 L 18 9 Z"></path>
|
| 45 |
+
</svg>`;
|
| 46 |
+
export const link = `<svg viewBox="0 0 640 512">
|
| 47 |
+
<path d="M598.6 41.41C570.1 13.8 534.8 0 498.6 0s-72.36 13.8-99.96 41.41l-43.36 43.36c15.11 8.012 29.47 17.58 41.91 30.02c3.146 3.146 5.898 6.518 8.742 9.838l37.96-37.96C458.5 72.05 477.1 64 498.6 64c20.67 0 40.1 8.047 54.71 22.66c14.61 14.61 22.66 34.04 22.66 54.71s-8.049 40.1-22.66 54.71l-133.3 133.3C405.5 343.1 386 352 365.4 352s-40.1-8.048-54.71-22.66C296 314.7 287.1 295.3 287.1 274.6s8.047-40.1 22.66-54.71L314.2 216.4C312.1 212.5 309.9 208.5 306.7 205.3C298.1 196.7 286.8 192 274.6 192c-11.93 0-23.1 4.664-31.61 12.97c-30.71 53.96-23.63 123.6 22.39 169.6C293 402.2 329.2 416 365.4 416c36.18 0 72.36-13.8 99.96-41.41L598.6 241.3c28.45-28.45 42.24-66.01 41.37-103.3C639.1 102.1 625.4 68.16 598.6 41.41zM234 387.4L196.1 425.3C181.5 439.1 162 448 141.4 448c-20.67 0-40.1-8.047-54.71-22.66c-14.61-14.61-22.66-34.04-22.66-54.71s8.049-40.1 22.66-54.71l133.3-133.3C234.5 168 253.1 160 274.6 160s40.1 8.048 54.71 22.66c14.62 14.61 22.66 34.04 22.66 54.71s-8.047 40.1-22.66 54.71L325.8 295.6c2.094 3.939 4.219 7.895 7.465 11.15C341.9 315.3 353.3 320 365.4 320c11.93 0 23.1-4.664 31.61-12.97c30.71-53.96 23.63-123.6-22.39-169.6C346.1 109.8 310.8 96 274.6 96C238.4 96 202.3 109.8 174.7 137.4L41.41 270.7c-27.6 27.6-41.41 63.78-41.41 99.96c-.0001 36.18 13.8 72.36 41.41 99.97C69.01 498.2 105.2 512 141.4 512c36.18 0 72.36-13.8 99.96-41.41l43.36-43.36c-15.11-8.012-29.47-17.58-41.91-30.02C239.6 394.1 236.9 390.7 234 387.4z"/>
|
| 48 |
+
</svg>`;
|
| 49 |
+
export const pencil = `<svg viewBox="0 0 24 24">
|
| 50 |
+
<path d="M 16.9375 1.0625 L 3.875 14.125 L 1.0742188 22.925781 L 9.875 20.125 L 22.9375 7.0625 C 22.9375 7.0625 22.8375 4.9615 20.9375 3.0625 C 19.0375 1.1625 16.9375 1.0625 16.9375 1.0625 z M 17.3125 2.6875 C 18.3845 2.8915 19.237984 3.3456094 19.896484 4.0214844 C 20.554984 4.6973594 21.0185 5.595 21.3125 6.6875 L 19.5 8.5 L 15.5 4.5 L 16.9375 3.0625 L 17.3125 2.6875 z M 4.9785156 15.126953 C 4.990338 15.129931 6.1809555 15.430955 7.375 16.625 C 8.675 17.825 8.875 18.925781 8.875 18.925781 L 8.9179688 18.976562 L 5.3691406 20.119141 L 3.8730469 18.623047 L 4.9785156 15.126953 z"/>
|
| 51 |
+
</svg>`;
|
| 52 |
+
export const dotdotdot = `<svg viewBox="0 0 24 24" fill="currentColor">
|
| 53 |
+
<circle cy="12" r="3" cx="3"></circle>
|
| 54 |
+
<circle cy="12" r="3" cx="12"></circle>
|
| 55 |
+
<circle cx="21" cy="12" r="3"></circle>
|
| 56 |
+
</svg>`;
|
| 57 |
+
export const models = `<svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
|
| 58 |
+
<path d="M4 4h6v6h-6z"></path>
|
| 59 |
+
<path d="M14 4h6v6h-6z"></path>
|
| 60 |
+
<path d="M4 14h6v6h-6z"></path>
|
| 61 |
+
<path d="M17 17m-3 0a3 3 0 1 0 6 0a3 3 0 1 0 -6 0"></path>
|
| 62 |
+
</svg>`;
|
| 63 |
+
export const pencilColored = `<svg viewBox="0 0 64 64">
|
| 64 |
+
<path fill="#ffce31" d="M7.934 41.132L39.828 9.246l14.918 14.922l-31.895 31.886z"></path>
|
| 65 |
+
<path d="M61.3 4.6l-1.9-1.9C55.8-.9 50-.9 46.3 2.7l-6.5 6.5l15 15l6.5-6.5c3.6-3.6 3.6-9.5 0-13.1" fill="#ed4c5c"></path>
|
| 66 |
+
<path fill="#93a2aa" d="M35.782 13.31l4.1-4.102l14.92 14.92l-4.1 4.101z"></path>
|
| 67 |
+
<path fill="#c7d3d8" d="M37.338 14.865l4.1-4.101l11.739 11.738l-4.102 4.1z"></path>
|
| 68 |
+
<path fill="#fed0ac" d="M7.9 41.1l-6.5 17l4.5 4.5l17-6.5z"/>
|
| 69 |
+
<path d="M.3 61.1c-.9 2.4.3 3.5 2.7 2.6l8.2-3.1l-7.7-7.7l-3.2 8.2" fill="#333"></path>
|
| 70 |
+
<path fill="#ffdf85" d="M7.89 41.175l27.86-27.86l4.95 4.95l-27.86 27.86z"/>
|
| 71 |
+
<path fill="#ff8736" d="M17.904 51.142l27.86-27.86l4.95 4.95l-27.86 27.86z"></path>
|
| 72 |
+
</svg>`;
|
| 73 |
+
export const diskColored = `<svg viewBox="-0.01 -0.008 100.016 100.016">
|
| 74 |
+
<path fill="#26f" fill_="#23475F" d="M88.555-.008H83v.016a2 2 0 0 1-2 2H19a2 2 0 0 1-2-2v-.016H4a4 4 0 0 0-4 4v92.016a4 4 0 0 0 4 4h92a4 4 0 0 0 4-4V11.517c.049-.089-11.436-11.454-11.445-11.525z"/>
|
| 75 |
+
<path fill="#04d" fill_="#1C3C50" d="M81.04 53.008H18.96a2 2 0 0 0-2 2v45h66.08v-45c0-1.106-.895-2-2-2zm-61.957-10h61.834a2 2 0 0 0 2-2V.547A1.993 1.993 0 0 1 81 2.007H19c-.916 0-1.681-.62-1.917-1.46v40.46a2 2 0 0 0 2 2.001z"/>
|
| 76 |
+
<path fill="#EBF0F1" d="M22 55.977h56a2 2 0 0 1 2 2v37.031a2 2 0 0 1-2 2H22c-1.104 0-2-.396-2-1.5V57.977a2 2 0 0 1 2-2z"/>
|
| 77 |
+
<path fill="#BCC4C8" d="M25 77.008h50v1H25v-1zm0 10h50v1H25v-1z"/>
|
| 78 |
+
<path fill="#1C3C50" d="M7 84.008h3a2 2 0 0 1 2 2v3a2 2 0 0 1-2 2H7a2 2 0 0 1-2-2v-3a2 2 0 0 1 2-2zm83 0h3a2 2 0 0 1 2 2v3a2 2 0 0 1-2 2h-3a2 2 0 0 1-2-2v-3a2 2 0 0 1 2-2z"/>
|
| 79 |
+
<path fill="#BCC4C8" d="M37 1.981v36.026a2 2 0 0 0 2 2h39a2 2 0 0 0 2-2V1.981c0 .007-42.982.007-43 0zm37 29.027a2 2 0 0 1-2 2h-6a2 2 0 0 1-2-2V10.981a2 2 0 0 1 2-2h6a2 2 0 0 1 2 2v20.027z"/>
|
| 80 |
+
<path fill="#FF9D00" d="M78 55.977H22a2 2 0 0 0-2 2v10.031h60V57.977a2 2 0 0 0-2-2z"/>
|
| 81 |
+
</svg>`;
|
| 82 |
+
export const folderColored = `<svg viewBox="0 0 501.379 501.379">
|
| 83 |
+
<path style="fill:#EF9F2C;" d="M406.423,93.889H205.889c-17.067,0-30.933-13.867-30.933-30.933s-13.867-30.933-30.933-30.933H30.956
|
| 84 |
+
c-17.067,0-30.933,13.867-30.933,30.933v375.467c0,17.067,13.867,30.933,30.933,30.933h375.467
|
| 85 |
+
c17.067,0,30.933-13.867,30.933-30.933v-313.6C436.289,107.756,422.423,93.889,406.423,93.889z"/>
|
| 86 |
+
<path style="fill:#FEC656;" d="M470.423,157.889H97.089c-13.867,0-26.667,9.6-29.867,22.4l-66.133,249.6
|
| 87 |
+
c-5.333,19.2,9.6,38.4,29.867,38.4h373.333c13.867,0,26.667-9.6,29.867-22.4l66.133-248.533
|
| 88 |
+
C505.623,177.089,490.689,157.889,470.423,157.889z"/>
|
| 89 |
+
</svg>`;
|
| 90 |
+
export const modelsColored = `<svg viewBox="0 0 24 24">
|
| 91 |
+
<path fill="#aa3366" d="M0 0h10v10h-10z"></path>
|
| 92 |
+
<path d="M14 0h10v10h-10z" fill="#3366aa"></path>
|
| 93 |
+
<path d="M0 14h10v10h-10z" fill="#66aa33"></path>
|
| 94 |
+
<path fill="#dd9922" d="M19 19m-5 0 a5 5 0 1 0 10 0 a5 5 0 1 0 -10 0"></path>
|
| 95 |
+
</svg>`;
|
| 96 |
+
export const legoBlocksColored = `<svg viewBox="0 0 512 512">
|
| 97 |
+
<g>
|
| 98 |
+
<rect x="57.67" style="fill:#00BAB9;" width="101.275" height="78.769"/>
|
| 99 |
+
<rect x="205.363" style="fill:#00BAB9;" width="101.275" height="78.769"/>
|
| 100 |
+
<rect x="353.055" style="fill:#00BAB9;" width="101.275" height="78.769"/>
|
| 101 |
+
</g>
|
| 102 |
+
<polygon style="fill:#B8DE6F;" points="478.242,289.758 478.242,512 33.758,512 33.758,289.758 256,267.253 "/>
|
| 103 |
+
<polygon style="fill:#41D4D3;" points="478.242,67.516 478.242,289.758 33.758,289.758 33.758,67.516 57.67,67.516 158.945,67.516
|
| 104 |
+
205.363,67.516 306.637,67.516 353.055,67.516 454.33,67.516 "/>
|
| 105 |
+
<g>
|
| 106 |
+
<circle style="fill:#00BAB9;" cx="402.286" cy="143.473" r="8.44"/>
|
| 107 |
+
<circle style="fill:#00BAB9;" cx="368.527" cy="177.231" r="8.44"/>
|
| 108 |
+
</g>
|
| 109 |
+
<circle style="fill:#7BD288;" cx="109.714" cy="436.044" r="8.44"/>
|
| 110 |
+
</svg>`;
|
| 111 |
+
export const legoBlockColored = `<svg viewBox="0 0 256 256">
|
| 112 |
+
<style>
|
| 113 |
+
.s0 { fill: #ff0000 }
|
| 114 |
+
.s1 { fill: #c30000 }
|
| 115 |
+
.s2 { fill: #800000 }
|
| 116 |
+
.s3 { fill: #cc0000 }
|
| 117 |
+
.s4 { fill: #e00000 }
|
| 118 |
+
</style>
|
| 119 |
+
<g id="Folder 2">
|
| 120 |
+
<path id="Shape 1 copy 2" class="s0" d="m128 61l116 45-116 139-116-139z"/>
|
| 121 |
+
<path id="Shape 1" class="s1" d="m12 106l116 45v95l-116-45z"/>
|
| 122 |
+
<path id="Shape 1 copy" class="s2" d="m244 106l-116 45v95l116-45z"/>
|
| 123 |
+
<g id="Folder 1">
|
| 124 |
+
<path id="Shape 2" class="s3" d="m102 111.2c0-6.1 11.4-9.9 25.5-9.9 14.1 0 25.5 3.8 25.5 9.9 0 3.3 0 13.3 0 16.6 0 6.1-11.4 10.9-25.5 10.9-14.1 0-25.5-4.8-25.5-10.9 0-3.3 0-13.3 0-16.6z"/>
|
| 125 |
+
<path id="Shape 2 copy 4" class="s1" d="m102 111.2c0-6.1 11.4-9.9 25.5-9.9 14.1 0 25.5 3.8 25.5 9.9 0 3.3 0 13.3 0 16.6 0 6.1-11.4 10.9-25.5 10.9-14.1 0-25.5-4.8-25.5-10.9 0-3.3 0-13.3 0-16.6z"/>
|
| 126 |
+
<path id="Shape 2 copy 2" class="s2" d="m127.5 101.3c14.1 0 25.5 3.8 25.5 9.9 0 3.3 0 13.3 0 16.6 0 6.1-11.4 10.9-25.5 10.9 0-13.1 0-25.7 0-37.4z"/>
|
| 127 |
+
<path id="Shape 2 copy" class="s0" d="m127.5 118.8c-12.2 0-22-3.4-22-7.6 0-4.2 9.8-7.7 22-7.7 12.2 0 22 3.5 22 7.7 0 4.2-9.8 7.6-22 7.6zm0 0c-12.2 0-22-3.4-22-7.6 0-4.2 9.8-7.7 22-7.7 12.2 0 22 3.5 22 7.7 0 4.2-9.8 7.6-22 7.6z"/>
|
| 128 |
+
</g>
|
| 129 |
+
<g id="Folder 1 copy">
|
| 130 |
+
<path id="Shape 2" class="s4" d="m103 67.5c0-5.8 11-9.5 24.5-9.5 13.5 0 24.5 3.7 24.5 9.5 0 3.2 0 12.8 0 16 0 5.8-11 10.5-24.5 10.5-13.5 0-24.5-4.7-24.5-10.5 0-3.2 0-12.8 0-16z"/>
|
| 131 |
+
<path id="Shape 2 copy 4" class="s1" d="m103 67.5c0-5.8 11-9.5 24.5-9.5 13.5 0 24.5 3.7 24.5 9.5 0 3.2 0 12.8 0 16 0 5.8-11 10.5-24.5 10.5-13.5 0-24.5-4.7-24.5-10.5 0-3.2 0-12.8 0-16z"/>
|
| 132 |
+
<path id="Shape 2 copy 2" class="s2" d="m127.5 58c13.5 0 24.5 3.7 24.5 9.5 0 3.2 0 12.8 0 16 0 5.8-11 10.5-24.5 10.5 0-12.6 0-24.8 0-36z"/>
|
| 133 |
+
<path id="Shape 2 copy" class="s0" d="m127.5 74.9c-11.7 0-21.2-3.3-21.2-7.4 0-4.1 9.5-7.4 21.2-7.4 11.7 0 21.2 3.3 21.2 7.4 0 4.1-9.5 7.4-21.2 7.4zm0 0c-11.7 0-21.2-3.3-21.2-7.4 0-4.1 9.5-7.4 21.2-7.4 11.7 0 21.2 3.3 21.2 7.4 0 4.1-9.5 7.4-21.2 7.4z"/>
|
| 134 |
+
</g>
|
| 135 |
+
<g id="Folder 1 copy 2">
|
| 136 |
+
<path id="Shape 2" class="s4" d="m161 89.5c0-5.8 11-9.5 24.5-9.5 13.5 0 24.5 3.7 24.5 9.5 0 3.2 0 12.8 0 16 0 5.8-11 10.5-24.5 10.5-13.5 0-24.5-4.7-24.5-10.5 0-3.2 0-12.8 0-16z"/>
|
| 137 |
+
<path id="Shape 2 copy 4" class="s1" d="m161 89.5c0-5.8 11-9.5 24.5-9.5 13.5 0 24.5 3.7 24.5 9.5 0 3.2 0 12.8 0 16 0 5.8-11 10.5-24.5 10.5-13.5 0-24.5-4.7-24.5-10.5 0-3.2 0-12.8 0-16z"/>
|
| 138 |
+
<path id="Shape 2 copy 2" class="s2" d="m185.5 80c13.5 0 24.5 3.7 24.5 9.5 0 3.2 0 12.8 0 16 0 5.8-11 10.5-24.5 10.5 0-12.6 0-24.8 0-36z"/>
|
| 139 |
+
<path id="Shape 2 copy" class="s0" d="m185.5 96.9c-11.7 0-21.2-3.3-21.2-7.4 0-4.1 9.5-7.4 21.2-7.4 11.7 0 21.2 3.3 21.2 7.4 0 4.1-9.5 7.4-21.2 7.4zm0 0c-11.7 0-21.2-3.3-21.2-7.4 0-4.1 9.5-7.4 21.2-7.4 11.7 0 21.2 3.3 21.2 7.4 0 4.1-9.5 7.4-21.2 7.4z"/>
|
| 140 |
+
</g>
|
| 141 |
+
<g id="Folder 1 copy 3">
|
| 142 |
+
<path id="Shape 2" class="s4" d="m45 89.5c0-5.8 11-9.5 24.5-9.5 13.5 0 24.5 3.7 24.5 9.5 0 3.2 0 12.8 0 16 0 5.8-11 10.5-24.5 10.5-13.5 0-24.5-4.7-24.5-10.5 0-3.2 0-12.8 0-16z"/>
|
| 143 |
+
<path id="Shape 2 copy 4" class="s1" d="m45 89.5c0-5.8 11-9.5 24.5-9.5 13.5 0 24.5 3.7 24.5 9.5 0 3.2 0 12.8 0 16 0 5.8-11 10.5-24.5 10.5-13.5 0-24.5-4.7-24.5-10.5 0-3.2 0-12.8 0-16z"/>
|
| 144 |
+
<path id="Shape 2 copy 2" class="s2" d="m69.5 80c13.5 0 24.5 3.7 24.5 9.5 0 3.2 0 12.8 0 16 0 5.8-11 10.5-24.5 10.5 0-12.6 0-24.8 0-36z"/>
|
| 145 |
+
<path id="Shape 2 copy" class="s0" d="m69.5 96.9c-11.7 0-21.2-3.3-21.2-7.4 0-4.1 9.5-7.4 21.2-7.4 11.7 0 21.2 3.3 21.2 7.4 0 4.1-9.5 7.4-21.2 7.4zm0 0c-11.7 0-21.2-3.3-21.2-7.4 0-4.1 9.5-7.4 21.2-7.4 11.7 0 21.2 3.3 21.2 7.4 0 4.1-9.5 7.4-21.2 7.4z"/>
|
| 146 |
+
</g>
|
| 147 |
+
</g>
|
| 148 |
+
</svg>`;
|
| 149 |
+
export const gearColored = `<svg viewBox="0 0 128 128" preserveAspectRatio="xMidYMid meet">
|
| 150 |
+
<path d="M124 71.85v-15.7c0-.59-.45-1.09-1.03-1.15l-17.83-1.89c-.47-.05-.85-.38-.98-.83c-.86-2.95-2.03-5.76-3.48-8.39c-.23-.41-.19-.92.11-1.28l11.28-13.94c.37-.46.34-1.13-.08-1.54l-11.1-11.1a1.15 1.15 0 0 0-1.54-.08L85.39 27.22c-.37.3-.87.33-1.28.11a41.796 41.796 0 0 0-8.39-3.48c-.45-.13-.78-.51-.83-.98L73 5.03C72.94 4.45 72.44 4 71.85 4h-15.7c-.59 0-1.09.45-1.15 1.03l-1.89 17.83c-.05.47-.38.85-.83.98c-2.95.86-5.76 2.03-8.39 3.48c-.41.23-.92.19-1.28-.11L28.67 15.94a1.15 1.15 0 0 0-1.54.08l-11.1 11.1a1.15 1.15 0 0 0-.08 1.54L27.23 42.6c.3.37.33.87.11 1.28a41.796 41.796 0 0 0-3.48 8.39c-.13.45-.51.78-.98.83L5.03 55c-.58.06-1.03.56-1.03 1.15v15.7c0 .59.45 1.09 1.03 1.15l17.83 1.89c.47.05.85.38.98.83c.86 2.95 2.03 5.76 3.48 8.39c.23.41.19.92-.11 1.28L15.94 99.33c-.37.46-.34 1.13.08 1.54l11.1 11.1c.42.42 1.08.45 1.54.08l13.94-11.28c.37-.3.87-.33 1.28-.11c2.64 1.45 5.45 2.62 8.39 3.48c.45.13.78.51.83.98l1.9 17.85c.06.59.56 1.03 1.15 1.03h15.7c.59 0 1.09-.45 1.15-1.03l1.89-17.83c.05-.47.38-.85.83-.98c2.95-.86 5.76-2.03 8.39-3.48c.41-.23.92-.19 1.28.11l13.94 11.28c.46.37 1.13.34 1.54-.08l11.1-11.1c.42-.42.45-1.08.08-1.54l-11.28-13.94c-.3-.37-.33-.87-.11-1.28c1.45-2.64 2.62-5.45 3.48-8.39c.13-.45.51-.78.98-.83L122.97 73c.58-.06 1.03-.56 1.03-1.15zm-60 3.43c-6.23 0-11.28-5.05-11.28-11.28S57.77 52.72 64 52.72S75.28 57.77 75.28 64S70.23 75.28 64 75.28z" fill="#82aec0"></path>
|
| 151 |
+
<path d="M80.56 49.48c3.67 4.18 5.78 9.77 5.43 15.85c-.65 11.16-9.83 20.19-21 20.68c-4.75.21-9.18-1.09-12.86-3.45c-.28-.18-.58.2-.34.44a22.412 22.412 0 0 0 17.85 6.67c10.78-.85 19.56-9.5 20.55-20.27c.77-8.36-3.06-15.87-9.23-20.33c-.29-.2-.62.15-.4.41z" fill="#2f7889"></path>
|
| 152 |
+
<path d="M43.87 65.32c-.67-13.15 7.83-22.79 20.01-22.79c.65 0 1.68 0 2.48.92c1.01 1.18 1.1 2.6 0 3.77c-.81.86-1.95.92-2.53 1c-12.3 1.59-15.18 9.35-15.83 16.77c-.03.33.06 2.35-1.71 2.56c-2.15.25-2.41-1.91-2.42-2.23z" fill="#b9e4ea"></path>
|
| 153 |
+
<path d="M25.24 65.87c-.01-22.03 15.9-40.19 38.13-41.05c.68-.03 2.45 0 3.55.99c1.01.91 1.38 2.51.79 3.82c-.95 2.11-2.85 2.07-3.36 2.09c-18.51.66-34.18 15.73-34.19 33.95c0 .29-.05.58-.15.84l-.1.25c-.76 1.98-3.52 2.09-4.43.18c-.15-.34-.24-.7-.24-1.07z" fill="#94d1e0"></path>
|
| 154 |
+
</svg>`;
|
| 155 |
+
/**
 * Creates an element from SVG markup via $el. Guards against being handed
 * non-svg markup (which $el would happily build as regular HTML).
 */
export function $svg(markup, attrs) {
    if (!/^\s*<svg/.test(markup)) {
        throw new Error("Cannot call $svg with non-svg markup.");
    }
    return $el(markup, attrs || {});
}
|
rgthree-comfy/web/common/shared_utils.js
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/**
 * Creates a "resolver" wrapper around a Promise, exposing its `resolve` and
 * `reject` handles along with `completed`/`resolved`/`rejected` state flags.
 * If neither handle is called within `timeout` ms, the promise is rejected
 * automatically.
 *
 * Fixes: the fallback timer is now cleared as soon as the resolver completes
 * (previously it stayed armed for the full timeout), and the timeout path
 * rejects with an Error instead of `undefined`.
 */
export function getResolver(timeout = 5000) {
    const resolver = {};
    resolver.id = generateId(8);
    resolver.completed = false;
    resolver.resolved = false;
    resolver.rejected = false;
    resolver.promise = new Promise((resolve, reject) => {
        resolver.reject = (e) => {
            clearTimeout(resolver.timeout);
            resolver.completed = true;
            resolver.rejected = true;
            reject(e);
        };
        resolver.resolve = (data) => {
            clearTimeout(resolver.timeout);
            resolver.completed = true;
            resolver.resolved = true;
            resolve(data);
        };
    });
    resolver.timeout = setTimeout(() => {
        if (!resolver.completed) {
            resolver.reject(new Error(`Resolver timed out after ${timeout}ms.`));
        }
    }, timeout);
    return resolver;
}
|
| 26 |
+
// Maps a function to its pending debounce promise so repeated calls within the
// window collapse into a single scheduled invocation.
const DEBOUNCE_FN_TO_PROMISE = new WeakMap();
/**
 * Debounces `fn`: calls made while a previous schedule for the same function
 * is pending are ignored. Returns the promise that resolves after `fn` runs.
 */
export function debounce(fn, ms = 64) {
    let pending = DEBOUNCE_FN_TO_PROMISE.get(fn);
    if (!pending) {
        pending = wait(ms).then(() => {
            DEBOUNCE_FN_TO_PROMISE.delete(fn);
            fn();
        });
        DEBOUNCE_FN_TO_PROMISE.set(fn, pending);
    }
    return pending;
}
|
| 36 |
+
/**
 * Returns a promise resolving after roughly `ms` milliseconds. The default of
 * 16 is special-cased to resolve on the next animation frame (about one frame
 * at 60fps) rather than via a timer.
 */
export function wait(ms = 16) {
    return new Promise((resolve) => {
        if (ms === 16) {
            requestAnimationFrame(() => resolve());
        } else {
            setTimeout(() => resolve(), ms);
        }
    });
}
|
| 50 |
+
/** Formats a byte (0-255) as a two-character, zero-padded lowercase hex string. */
function dec2hex(dec) {
    const hex = dec.toString(16);
    return hex.length < 2 ? `0${hex}` : hex;
}
|
| 53 |
+
export function generateId(length) {
|
| 54 |
+
const arr = new Uint8Array(length / 2);
|
| 55 |
+
crypto.getRandomValues(arr);
|
| 56 |
+
return Array.from(arr, dec2hex).join("");
|
| 57 |
+
}
|
| 58 |
+
/**
 * Reads a (possibly nested) value from `obj` using a dot-delimited key path,
 * e.g. getObjectValue(cfg, "a.b.c"). Returns `def` when `obj`/`objKey` is
 * falsy or an intermediate segment resolves to a falsy value; note a missing
 * FINAL segment yields `undefined`, not `def` (preserved behavior).
 */
export function getObjectValue(obj, objKey, def) {
    if (!obj || !objKey)
        return def;
    const [first, ...rest] = objKey.split(".");
    const found = obj[first];
    return rest.length ? getObjectValue(found, rest.join("."), def) : found;
}
|
| 69 |
+
/**
 * Writes `value` into `obj` at a dot-delimited key path, creating missing
 * intermediate objects along the way. With `createMissingObjects` false, a
 * missing segment aborts and returns undefined. Non-object intermediates are
 * replaced with fresh objects. Returns `obj` for chaining.
 */
export function setObjectValue(obj, objKey, value, createMissingObjects = true) {
    if (!obj || !objKey) {
        return obj;
    }
    const [key, ...rest] = objKey.split(".");
    if (obj[key] === undefined) {
        if (!createMissingObjects) {
            return undefined;
        }
        obj[key] = {};
    }
    if (rest.length === 0) {
        obj[key] = value;
        return obj;
    }
    if (typeof obj[key] !== "object") {
        obj[key] = {};
    }
    setObjectValue(obj[key], rest.join("."), value, createMissingObjects);
    return obj;
}
|
| 91 |
+
/**
 * Moves an item (or the item at a given index) to index `to`, mutating `arr`
 * in place.
 *
 * Fixes: when the item was not found, indexOf returned -1 and the original
 * spliced the LAST element out and re-inserted it at `to`; a missing item is
 * now a no-op.
 */
export function moveArrayItem(arr, itemOrFrom, to) {
    const from = typeof itemOrFrom === "number" ? itemOrFrom : arr.indexOf(itemOrFrom);
    if (from === -1) {
        return;
    }
    arr.splice(to, 0, arr.splice(from, 1)[0]);
}
|
| 95 |
+
/**
 * Removes an item (or the item at a given index) from `arr`, mutating it in
 * place.
 *
 * Fixes: when the item was not found, indexOf returned -1 and the original
 * `splice(-1, 1)` silently removed the LAST element; a missing item is now a
 * no-op.
 */
export function removeArrayItem(arr, itemOrIndex) {
    const index = typeof itemOrIndex === "number" ? itemOrIndex : arr.indexOf(itemOrIndex);
    if (index === -1) {
        return;
    }
    arr.splice(index, 1);
}
|
| 99 |
+
/**
 * Injects a stylesheet <link> into the document head, resolving once it loads
 * (or after a 1s fallback so callers never hang). If a link whose href starts
 * with `href` already exists, resolves immediately without re-adding it.
 *
 * Fixes: the fallback timer is now cleared with clearTimeout — the original
 * called clearInterval on a setTimeout handle (works in browsers but is the
 * wrong API).
 */
export function injectCss(href) {
    if (document.querySelector(`link[href^="${href}"]`)) {
        return Promise.resolve();
    }
    return new Promise((resolve) => {
        const link = document.createElement("link");
        link.setAttribute("rel", "stylesheet");
        link.setAttribute("type", "text/css");
        // Fallback: resolve after 1s even if the load event never fires.
        const timeout = setTimeout(resolve, 1000);
        link.addEventListener("load", () => {
            clearTimeout(timeout);
            resolve();
        });
        link.href = href;
        document.head.appendChild(link);
    });
}
|
| 116 |
+
/**
 * Defines (or re-defines) a property on `instance`, chaining onto any existing
 * getter/setter instead of clobbering it: the pre-existing accessor runs
 * first, then the new one, whose result wins. Throws when the existing
 * property is non-configurable. Missing descriptor fields fall back to the
 * existing descriptor's values, then to permissive defaults.
 */
export function defineProperty(instance, property, desc) {
    const existingDesc = Object.getOwnPropertyDescriptor(instance, property);
    if (existingDesc?.configurable === false) {
        throw new Error(`Error: rgthree-comfy cannot define un-configurable property "${property}"`);
    }
    if (existingDesc?.get && desc.get) {
        const newGetter = desc.get;
        desc.get = () => {
            existingDesc.get.apply(instance, []);
            return newGetter.apply(instance, []);
        };
    }
    if (existingDesc?.set && desc.set) {
        const newSetter = desc.set;
        desc.set = (v) => {
            existingDesc.set.apply(instance, [v]);
            return newSetter.apply(instance, [v]);
        };
    }
    desc.enumerable = desc.enumerable ?? existingDesc?.enumerable ?? true;
    desc.configurable = desc.configurable ?? existingDesc?.configurable ?? true;
    // Accessor descriptors must not carry `writable`.
    if (!desc.get && !desc.set) {
        desc.writable = desc.writable ?? existingDesc?.writable ?? true;
    }
    return Object.defineProperty(instance, property, desc);
}
|
rgthree-comfy/web/common/utils_dom.js
ADDED
|
@@ -0,0 +1,311 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Attributes that must be set under a specific (often camelCased) DOM
// attribute name rather than the lowercase key callers pass in.
const DIRECT_ATTRIBUTE_MAP = {
    cellpadding: 'cellPadding',
    cellspacing: 'cellSpacing',
    colspan: 'colSpan',
    frameborder: 'frameBorder',
    height: 'height',
    maxlength: 'maxLength',
    nonce: 'nonce',
    role: 'role',
    rowspan: 'rowSpan',
    type: 'type',
    usemap: 'useMap',
    valign: 'vAlign',
    width: 'width',
};
// Unit appended to bare numeric values for the style names matched below.
const RGX_NUMERIC_STYLE_UNIT = 'px';
// Size-like style property names whose bare numeric values get a "px" suffix.
const RGX_NUMERIC_STYLE = /^((max|min)?(width|height)|margin|padding|(margin|padding)?(left|top|bottom|right)|fontsize|borderwidth)$/i;
// Tag names whose "default" attribute maps to `value` (vs. text content).
const RGX_DEFAULT_VALUE_PROP = /input|textarea|select/i;
|
| 19 |
+
/**
 * Throws `errorMsg` when `input` is null or undefined; otherwise returns the
 * input unchanged. (A non-null assertion helper; falsy-but-present values
 * like '' or 0 pass through.)
 */
function localAssertNotFalsy(input, errorMsg = `Input is not of type.`) {
    if (input === null || input === undefined) {
        throw new Error(errorMsg);
    }
    return input;
}
|
| 25 |
+
// Character class for valid tag/id/class name characters.
const RGX_STRING_VALID = '[a-z0-9_-]';
// Leading tag name of a selector, e.g. "div" in "div.foo#bar".
const RGX_TAG = new RegExp(`^([a-z]${RGX_STRING_VALID}*)(\\.|\\[|\\#|$)`, 'i');
// "#id" tokens within a selector.
const RGX_ATTR_ID = new RegExp(`#(${RGX_STRING_VALID}+)`, 'gi');
// ".class.other" runs (class lists may be dot-chained).
const RGX_ATTR_CLASS = new RegExp(`(^|\\S)\\.([a-z0-9_\\-\\.]+)`, 'gi');
// Content up to the next "[" or "]" — used to balance nested brackets.
const RGX_STRING_CONTENT_TO_SQUARES = '(.*?)(\\[|\\])';
// Finds "[...]" attribute groups, possibly left open by a nested "[".
const RGX_ATTRS_MAYBE_OPEN = new RegExp(`\\[${RGX_STRING_CONTENT_TO_SQUARES}`, 'gi');
// Continues scanning from within a still-open "[" group.
const RGX_ATTRS_FOLLOW_OPEN = new RegExp(`^${RGX_STRING_CONTENT_TO_SQUARES}`, 'gi');
|
| 32 |
+
/** Returns all elements matching `selectors` under `parent` as a real Array. */
export function query(selectors, parent = document) {
    const matches = parent.querySelectorAll(selectors);
    return Array.from(matches).filter((node) => !!node);
}
|
| 35 |
+
/** Returns the first element matching `selectors` under `parent`, or null. */
export function queryOne(selectors, parent = document) {
    return parent.querySelector(selectors) ?? null;
}
|
| 39 |
+
/** Creates a DOM Text node for the given string. */
export function createText(text) {
    return document.createTextNode(text);
}
|
| 42 |
+
/**
 * Returns `element` itself when it matches `query`, otherwise its closest
 * matching ancestor; null when neither matches or `element` has no `closest`
 * (e.g. a text node or non-element).
 */
export function getClosestOrSelf(element, query) {
    const el = element;
    if (!el?.closest) {
        return null;
    }
    if (el.matches(query)) {
        return el;
    }
    return el.closest(query) || null;
}
|
| 46 |
+
/** True when `contained` is `parent` itself or contained within it. */
export function containsOrSelf(parent, contained) {
    if (parent === contained) {
        return true;
    }
    // Tolerates a null/containerless parent.
    return parent?.contains?.(contained) || false;
}
|
| 50 |
+
/**
 * Creates an element either from an HTML markup string or from a CSS-like
 * selector shorthand (e.g. `div#id.cls[attr="v"]`). Markup is parsed via a
 * contextual fragment; otherwise the selector's tag/id/classes/attributes
 * are applied to a freshly created element. `attrs` (see setAttributes) are
 * applied last and can override selector-derived values.
 */
export function createElement(selectorOrMarkup, attrs) {
    const frag = getHtmlFragment(selectorOrMarkup);
    let element = frag === null || frag === void 0 ? void 0 : frag.firstElementChild;
    let selector = "";
    if (!element) {
        // Not markup; treat as a selector. Strip newlines, create the tag, then
        // normalize "#id" and ".class" shorthands into "[attr=...]" form so a
        // single attribute-parsing pass below handles everything.
        selector = selectorOrMarkup.replace(/[\r\n]\s*/g, "");
        const tag = getSelectorTag(selector) || "div";
        element = document.createElement(tag);
        selector = selector.replace(RGX_TAG, "$2");
        selector = selector.replace(RGX_ATTR_ID, '[id="$1"]');
        selector = selector.replace(RGX_ATTR_CLASS, (match, p1, p2) => `${p1}[class="${p2.replace(/\./g, " ")}"]`);
    }
    const selectorAttrs = getSelectorAttributes(selector);
    if (selectorAttrs) {
        for (const attr of selectorAttrs) {
            // Each attr looks like '[key="value"]' or '[key]'.
            let matches = attr.substring(1, attr.length - 1).split("=");
            let key = localAssertNotFalsy(matches.shift());
            let value = matches.join("=");
            if (value === undefined) {
                // NOTE(review): join() always returns a string, so this branch
                // looks unreachable; bare [attr] selectors take the else path
                // with value '' — confirm before relying on boolean handling.
                setAttribute(element, key, true);
            }
            else {
                // Strip surrounding quotes, if any.
                value = value.replace(/^['"](.*)['"]$/, "$1");
                setAttribute(element, key, value);
            }
        }
    }
    if (attrs) {
        setAttributes(element, attrs);
    }
    return element;
}
// Shorthand alias used throughout rgthree's front-end code.
export const $el = createElement;
|
| 83 |
+
// Extracts the leading tag name from a selector string, or '' when absent.
function getSelectorTag(str) {
    return tryMatch(str, RGX_TAG);
}
|
| 86 |
+
/**
 * Collects all bracketed "[...]" attribute substrings from a selector,
 * handling nested square brackets (e.g. values containing '[') by scanning
 * forward until the brackets balance.
 */
function getSelectorAttributes(selector) {
    RGX_ATTRS_MAYBE_OPEN.lastIndex = 0; // global regex: reset between calls
    let attrs = [];
    let result;
    while (result = RGX_ATTRS_MAYBE_OPEN.exec(selector)) {
        let attr = result[0];
        if (attr.endsWith(']')) {
            attrs.push(attr);
        }
        else {
            // Hit a nested '['; keep consuming until brackets balance, then
            // advance the regex cursor past the extra text we consumed.
            attr = result[0]
                + getOpenAttributesRecursive(selector.substr(RGX_ATTRS_MAYBE_OPEN.lastIndex), 2);
            RGX_ATTRS_MAYBE_OPEN.lastIndex += (attr.length - result[0].length);
            attrs.push(attr);
        }
    }
    return attrs;
}
|
| 104 |
+
/**
 * Continues scanning a selector substring that has `openCount` unbalanced '['
 * brackets, returning the consumed text up to (and including) the point where
 * the brackets balance out.
 */
function getOpenAttributesRecursive(selectorSubstring, openCount) {
    let matches = selectorSubstring.match(RGX_ATTRS_FOLLOW_OPEN);
    let result = '';
    if (matches && matches.length) {
        result = matches[0];
        // The match ends at either ']' (closes a level) or '[' (opens one).
        openCount += result.endsWith(']') ? -1 : 1;
        if (openCount > 0) {
            result += getOpenAttributesRecursive(selectorSubstring.substr(result.length), openCount);
        }
    }
    return result;
}
|
| 116 |
+
/**
 * Runs `str.match(rgx)` and returns capture group `index` (default 1).
 * Swallows any error and returns '' on no match, missing group, or failure.
 */
function tryMatch(str, rgx, index = 1) {
    try {
        const match = str.match(rgx);
        return (match && match[index]) || '';
    }
    catch (e) {
        return '';
    }
}
|
| 127 |
+
/** Applies each own key/value pair in `data` to `element` via setAttribute. */
export function setAttributes(element, data) {
    for (const attr of Object.keys(data)) {
        setAttribute(element, attr, data[attr]);
    }
}
|
| 135 |
+
/**
 * Parses `value` into a DocumentFragment when it looks like a complete HTML
 * snippet (opens with a tag and ends with a closing tag); otherwise null.
 */
function getHtmlFragment(value) {
    if (!value.match(/^\s*<.*?>[\s\S]*<\/[a-z0-9]+>\s*$/)) {
        return null;
    }
    return document.createRange().createContextualFragment(value.trim());
}
|
| 141 |
+
/**
 * Normalizes a child-ish value into a Node: Nodes pass through; strings are
 * parsed as markup, then tried as a selector, then wrapped as a text node (in
 * that order); objects exposing toElement() are converted. Anything else
 * yields null.
 */
function getChild(value) {
    if (value instanceof Node) {
        return value;
    }
    if (typeof value === 'string') {
        const fragment = getHtmlFragment(value);
        if (fragment) {
            return fragment;
        }
        return getSelectorTag(value) ? createElement(value) : createText(value);
    }
    if (value && typeof value.toElement === 'function') {
        return value.toElement();
    }
    return null;
}
|
| 160 |
+
/**
 * Swiss-army attribute setter used by createElement/setAttributes. Beyond
 * plain attributes it understands several pseudo-attributes: 'text', 'html',
 * 'style' (string or object), 'events' (map of listeners), 'parent', 'child'/
 * 'children', 'for', 'class'/'className'/'classes', 'dataset', 'on*' handler
 * functions, boolean form attributes, and names in DIRECT_ATTRIBUTE_MAP.
 * A null/undefined value removes the attribute where that makes sense.
 */
export function setAttribute(element, attribute, value) {
    let isRemoving = value == null;
    if (attribute === 'default') {
        // 'default' means value for form controls, text content otherwise.
        attribute = RGX_DEFAULT_VALUE_PROP.test(element.nodeName) ? 'value' : 'text';
    }
    if (attribute === 'text') {
        empty(element).appendChild(createText(value != null ? String(value) : ''));
    }
    else if (attribute === 'html') {
        empty(element).innerHTML += value != null ? String(value) : '';
    }
    else if (attribute == 'style') {
        if (typeof value === 'string') {
            element.style.cssText = isRemoving ? '' : (value != null ? String(value) : '');
        }
        else {
            // Object form: assign each style property individually.
            for (const [styleKey, styleValue] of Object.entries(value)) {
                element.style[styleKey] = styleValue;
            }
        }
    }
    else if (attribute == 'events') {
        for (const [key, fn] of Object.entries(value)) {
            addEvent(element, key, fn);
        }
    }
    else if (attribute === 'parent') {
        value.appendChild(element);
    }
    else if (attribute === 'child' || attribute === 'children') {
        if (typeof value === 'string' && /^\[[^\[\]]+\]$/.test(value)) {
            // A "[a,b,c]" string is coerced into a JSON array of strings.
            const parseable = value.replace(/^\[([^\[\]]+)\]$/, '["$1"]').replace(/,/g, '","');
            try {
                const parsed = JSON.parse(parseable);
                value = parsed;
            }
            catch (e) {
                console.error(e);
            }
        }
        if (attribute === 'children') {
            // 'children' replaces existing content; 'child' appends.
            empty(element);
        }
        let children = value instanceof Array ? value : [value];
        for (let child of children) {
            child = getChild(child);
            if (child instanceof Node) {
                if (element instanceof HTMLTemplateElement) {
                    // Templates receive children on their content fragment.
                    element.content.appendChild(child);
                }
                else {
                    element.appendChild(child);
                }
            }
        }
    }
    else if (attribute == 'for') {
        element.htmlFor = value != null ? String(value) : '';
        if (isRemoving) {
            element.removeAttribute('for');
        }
    }
    else if (attribute === 'class' || attribute === 'className' || attribute === 'classes') {
        element.className = isRemoving ? '' : Array.isArray(value) ? value.join(' ') : String(value);
    }
    else if (attribute === 'dataset') {
        if (typeof value !== 'object') {
            console.error('Expecting an object for dataset');
            return;
        }
        for (const [key, val] of Object.entries(value)) {
            element.dataset[key] = String(val);
        }
    }
    else if (attribute.startsWith('on') && typeof value === 'function') {
        // e.g. onclick: fn — attached as a listener, not an inline attribute.
        element.addEventListener(attribute.substring(2), value);
    }
    else if (['checked', 'disabled', 'readonly', 'required', 'selected'].includes(attribute)) {
        // Boolean attributes: keep the DOM property and markup attribute in sync.
        element[attribute] = !!value;
        if (!value) {
            element.removeAttribute(attribute);
        }
        else {
            element.setAttribute(attribute, attribute);
        }
    }
    else if (DIRECT_ATTRIBUTE_MAP.hasOwnProperty(attribute)) {
        if (isRemoving) {
            element.removeAttribute(DIRECT_ATTRIBUTE_MAP[attribute]);
        }
        else {
            element.setAttribute(DIRECT_ATTRIBUTE_MAP[attribute], String(value));
        }
    }
    else if (isRemoving) {
        element.removeAttribute(attribute);
    }
    else {
        // Generic attribute: only touch the DOM when the value actually changed.
        let oldVal = element.getAttribute(attribute);
        if (oldVal !== value) {
            element.setAttribute(attribute, String(value));
        }
    }
}
|
| 264 |
+
// Thin wrapper over addEventListener, kept for symmetry with setAttribute's
// 'events' handling.
function addEvent(element, key, fn) {
    element.addEventListener(key, fn);
}
|
| 267 |
+
/** Applies each style in `styles` (may be null) to `element`; returns it. */
function setStyles(element, styles = null) {
    if (!styles) {
        return element;
    }
    for (const name in styles) {
        setStyle(element, name, styles[name]);
    }
    return element;
}
|
| 275 |
+
/**
 * Sets a single style on `element`: maps names containing "float" to
 * cssFloat, camelCases dashed names, appends 'px' to bare numeric values for
 * size-like properties, and lets a non-string `display` value act as a
 * visibility toggle (truthy clears the inline value, falsy sets 'none').
 * Returns the element for chaining.
 */
function setStyle(element, name, value) {
    let prop = name.indexOf('float') > -1 ? 'cssFloat' : name;
    // camelCase dashed names ("margin-top" -> "marginTop"); no-op otherwise.
    prop = prop.replace(/-\D/g, (match) => match.charAt(1).toUpperCase());
    if (value == String(Number(value)) && RGX_NUMERIC_STYLE.test(prop)) {
        value = value + RGX_NUMERIC_STYLE_UNIT;
    }
    if (prop === 'display' && typeof value !== 'string') {
        value = !!value ? null : 'none';
    }
    element.style[prop] = value === null ? null : String(value);
    return element;
}
|
| 292 |
+
/** Removes all child nodes from `element`; returns it for chaining. */
export function empty(element) {
    let child;
    while ((child = element.firstChild)) {
        element.removeChild(child);
    }
    return element;
}
|
| 298 |
+
/**
 * Appends one child or an array of children to `el`, normalizing each via
 * getChild(). HTMLTemplateElement targets receive children on their content
 * fragment; values that don't normalize to a Node are skipped.
 */
export function appendChildren(el, children) {
    const list = Array.isArray(children) ? children : [children];
    for (const entry of list) {
        const child = getChild(entry);
        if (!(child instanceof Node)) {
            continue;
        }
        const target = el instanceof HTMLTemplateElement ? el.content : el;
        target.appendChild(child);
    }
}
|
rgthree-comfy/web/common/utils_workflow.js
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { getResolver } from "./shared_utils.js";
|
| 2 |
+
import { getPngMetadata, getWebpMetadata } from "../../scripts/pnginfo.js";
|
| 3 |
+
/**
 * Parses workflow JSON, tolerating a null/empty input (yields null) and bare
 * "NaN" property values (sometimes present in embedded workflow data), which
 * are coerced to null before parsing.
 */
function parseWorkflowJson(stringJson) {
    const sanitized = (stringJson || "null").replace(/:\s*NaN/g, ": null");
    return JSON.parse(sanitized);
}
|
| 8 |
+
/**
 * Attempts to extract ComfyUI workflow and/or API-prompt data from a drop
 * event: first from any dropped files, then from a dragged URL (fetched as a
 * blob). Resolves to { workflow: null, prompt: null } when nothing usable is
 * found.
 */
export async function tryToGetWorkflowDataFromEvent(e) {
    for (const file of e.dataTransfer?.files || []) {
        const data = await tryToGetWorkflowDataFromFile(file);
        if (data.workflow || data.prompt) {
            return data;
        }
    }
    const validTypes = ["text/uri-list", "text/x-moz-url"];
    const match = (e.dataTransfer?.types || []).find((t) => validTypes.find((v) => t === v));
    if (match) {
        // Dragged URL: take only the first line of the (possibly multi-line) data.
        const uri = e.dataTransfer.getData(match)?.split("\n")?.[0];
        if (uri) {
            return tryToGetWorkflowDataFromFile(await (await fetch(uri)).blob());
        }
    }
    return { workflow: null, prompt: null };
}
|
| 27 |
+
/**
 * Extracts workflow/prompt data from a file: PNG and WebP metadata chunks, or
 * raw JSON. API-format json (every top-level value has class_type) becomes the
 * prompt; otherwise it's treated as a workflow unless it's a templates file.
 *
 * Fixes: the JSON branch never worked — it neither started the FileReader
 * (readAsText was never called) nor resolved the resolver (the onload handler
 * returned its result instead of calling resolver.resolve), so callers hung
 * until the resolver timed out. Both are fixed below.
 */
export async function tryToGetWorkflowDataFromFile(file) {
    if (file.type === "image/png") {
        const pngInfo = await getPngMetadata(file);
        return {
            workflow: parseWorkflowJson(pngInfo?.workflow),
            prompt: parseWorkflowJson(pngInfo?.prompt),
        };
    }
    if (file.type === "image/webp") {
        // WebP metadata keys have been seen both lower- and upper-cased.
        const webpInfo = await getWebpMetadata(file);
        const workflow = parseWorkflowJson(webpInfo?.workflow || webpInfo?.Workflow || "null");
        const prompt = parseWorkflowJson(webpInfo?.prompt || webpInfo?.Prompt || "null");
        return { workflow, prompt };
    }
    if (file.type === "application/json" || file.name?.endsWith(".json")) {
        const resolver = getResolver();
        const reader = new FileReader();
        reader.onload = () => {
            const json = parseWorkflowJson(reader.result);
            const isApiJson = Object.values(json).every((v) => v.class_type);
            const prompt = isApiJson ? json : null;
            const workflow = !isApiJson && !json?.templates ? json : null;
            resolver.resolve({ workflow, prompt });
        };
        reader.readAsText(file);
        return resolver.promise;
    }
    return { workflow: null, prompt: null };
}
|
rgthree-comfy/web/link_fixer/link_page.js
ADDED
|
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { fixBadLinks } from "../common/link_fixer.js";
|
| 2 |
+
import { getPngMetadata } from "../../scripts/pnginfo.js";
|
| 3 |
+
/** Resolves with `value` after `ms` milliseconds (default ~one frame, 16ms). */
function wait(ms = 16, value) {
    return new Promise((resolve) => setTimeout(() => resolve(value), ms));
}
|
| 10 |
+
// Simple logger that can be redirected from the console to a DOM element:
// assign an element to `logTo` and lines are appended to its innerText.
const logger = {
    logTo: console,
    log: (...args) => {
        logger.logTo === console
            ? console.log(...args)
            : (logger.logTo.innerText += args.join(",") + "\n");
    },
};
// Async logger facade handed to link-fixing code that awaits its log calls.
const findBadLinksLogger = {
    log: async (...args) => {
        logger.log(...args);
    },
};
|
| 23 |
+
export class LinkPage {
|
| 24 |
+
constructor() {
|
| 25 |
+
this.containerEl = document.querySelector(".box");
|
| 26 |
+
this.figcaptionEl = document.querySelector("figcaption");
|
| 27 |
+
this.outputeMessageEl = document.querySelector(".output");
|
| 28 |
+
this.outputImageEl = document.querySelector(".output-image");
|
| 29 |
+
this.btnFix = document.querySelector(".btn-fix");
|
| 30 |
+
document.addEventListener("dragover", (e) => {
|
| 31 |
+
e.preventDefault();
|
| 32 |
+
}, false);
|
| 33 |
+
document.addEventListener("drop", (e) => {
|
| 34 |
+
this.onDrop(e);
|
| 35 |
+
});
|
| 36 |
+
this.btnFix.addEventListener("click", (e) => {
|
| 37 |
+
this.onFixClick(e);
|
| 38 |
+
});
|
| 39 |
+
}
|
| 40 |
+
async onFixClick(e) {
|
| 41 |
+
if (!this.graphResults || !this.graph) {
|
| 42 |
+
this.updateUi("⛔ Fix button click without results.");
|
| 43 |
+
return;
|
| 44 |
+
}
|
| 45 |
+
let graphFinalResults = fixBadLinks(this.graph, true);
|
| 46 |
+
graphFinalResults = fixBadLinks(graphFinalResults.graph, true);
|
| 47 |
+
if (graphFinalResults.patched || graphFinalResults.deleted) {
|
| 48 |
+
graphFinalResults = fixBadLinks(graphFinalResults.graph, true);
|
| 49 |
+
}
|
| 50 |
+
this.graphFinalResults = graphFinalResults;
|
| 51 |
+
await this.saveFixedWorkflow();
|
| 52 |
+
if (graphFinalResults.hasBadLinks) {
|
| 53 |
+
this.updateUi("⛔ Hmm... Still detecting bad links. Can you file an issue at https://github.com/rgthree/rgthree-comfy/issues with your image/workflow.");
|
| 54 |
+
}
|
| 55 |
+
else {
|
| 56 |
+
this.updateUi("✅ Workflow fixed.<br><br><small>Please load new saved workflow json and double check linking and execution.</small>");
|
| 57 |
+
}
|
| 58 |
+
}
|
| 59 |
+
async onDrop(event) {
|
| 60 |
+
var _a, _b, _c, _d;
|
| 61 |
+
if (!event.dataTransfer) {
|
| 62 |
+
return;
|
| 63 |
+
}
|
| 64 |
+
this.reset();
|
| 65 |
+
event.preventDefault();
|
| 66 |
+
event.stopPropagation();
|
| 67 |
+
if (event.dataTransfer.files.length && ((_b = (_a = event.dataTransfer.files) === null || _a === void 0 ? void 0 : _a[0]) === null || _b === void 0 ? void 0 : _b.type) !== "image/bmp") {
|
| 68 |
+
await this.handleFile(event.dataTransfer.files[0]);
|
| 69 |
+
return;
|
| 70 |
+
}
|
| 71 |
+
const validTypes = ["text/uri-list", "text/x-moz-url"];
|
| 72 |
+
const match = [...event.dataTransfer.types].find((t) => validTypes.find((v) => t === v));
|
| 73 |
+
if (match) {
|
| 74 |
+
const uri = (_d = (_c = event.dataTransfer.getData(match)) === null || _c === void 0 ? void 0 : _c.split("\n")) === null || _d === void 0 ? void 0 : _d[0];
|
| 75 |
+
if (uri) {
|
| 76 |
+
await this.handleFile(await (await fetch(uri)).blob());
|
| 77 |
+
}
|
| 78 |
+
}
|
| 79 |
+
}
|
| 80 |
+
reset() {
|
| 81 |
+
this.file = undefined;
|
| 82 |
+
this.graph = undefined;
|
| 83 |
+
this.graphResults = undefined;
|
| 84 |
+
this.graphFinalResults = undefined;
|
| 85 |
+
this.updateUi();
|
| 86 |
+
}
|
| 87 |
+
updateUi(msg) {
|
| 88 |
+
this.outputeMessageEl.innerHTML = "";
|
| 89 |
+
if (this.file && !this.containerEl.classList.contains("-has-file")) {
|
| 90 |
+
this.containerEl.classList.add("-has-file");
|
| 91 |
+
this.figcaptionEl.innerHTML = this.file.name || this.file.type;
|
| 92 |
+
if (this.file.type === "application/json") {
|
| 93 |
+
this.outputImageEl.src = "icon_file_json.png";
|
| 94 |
+
}
|
| 95 |
+
else {
|
| 96 |
+
const reader = new FileReader();
|
| 97 |
+
reader.onload = () => (this.outputImageEl.src = reader.result);
|
| 98 |
+
reader.readAsDataURL(this.file);
|
| 99 |
+
}
|
| 100 |
+
}
|
| 101 |
+
else if (!this.file && this.containerEl.classList.contains("-has-file")) {
|
| 102 |
+
this.containerEl.classList.remove("-has-file");
|
| 103 |
+
this.outputImageEl.src = "";
|
| 104 |
+
this.outputImageEl.removeAttribute("src");
|
| 105 |
+
}
|
| 106 |
+
if (this.graphResults) {
|
| 107 |
+
this.containerEl.classList.add("-has-results");
|
| 108 |
+
if (!this.graphResults.patched && !this.graphResults.deleted) {
|
| 109 |
+
this.outputeMessageEl.innerHTML = "✅ No bad links detected in the workflow.";
|
| 110 |
+
}
|
| 111 |
+
else {
|
| 112 |
+
this.containerEl.classList.add("-has-fixable-results");
|
| 113 |
+
this.outputeMessageEl.innerHTML = `⚠️ Found ${this.graphResults.patched} links to fix, and ${this.graphResults.deleted} to be removed.`;
|
| 114 |
+
}
|
| 115 |
+
}
|
| 116 |
+
else {
|
| 117 |
+
this.containerEl.classList.remove("-has-results");
|
| 118 |
+
this.containerEl.classList.remove("-has-fixable-results");
|
| 119 |
+
}
|
| 120 |
+
if (msg) {
|
| 121 |
+
this.outputeMessageEl.innerHTML = msg;
|
| 122 |
+
}
|
| 123 |
+
}
|
| 124 |
+
async handleFile(file) {
|
| 125 |
+
this.file = file;
|
| 126 |
+
this.updateUi();
|
| 127 |
+
let workflow = null;
|
| 128 |
+
if (file.type.startsWith("image/")) {
|
| 129 |
+
const pngInfo = await getPngMetadata(file);
|
| 130 |
+
workflow = pngInfo === null || pngInfo === void 0 ? void 0 : pngInfo.workflow;
|
| 131 |
+
}
|
| 132 |
+
else if (file.type === "application/json" ||
|
| 133 |
+
(file instanceof File && file.name.endsWith(".json"))) {
|
| 134 |
+
workflow = await new Promise((resolve) => {
|
| 135 |
+
const reader = new FileReader();
|
| 136 |
+
reader.onload = () => {
|
| 137 |
+
resolve(reader.result);
|
| 138 |
+
};
|
| 139 |
+
reader.readAsText(file);
|
| 140 |
+
});
|
| 141 |
+
}
|
| 142 |
+
if (!workflow) {
|
| 143 |
+
this.updateUi("⛔ No workflow found in dropped item.");
|
| 144 |
+
}
|
| 145 |
+
else {
|
| 146 |
+
try {
|
| 147 |
+
this.graph = JSON.parse(workflow);
|
| 148 |
+
}
|
| 149 |
+
catch (e) {
|
| 150 |
+
this.graph = undefined;
|
| 151 |
+
}
|
| 152 |
+
if (!this.graph) {
|
| 153 |
+
this.updateUi("⛔ Invalid workflow found in dropped item.");
|
| 154 |
+
}
|
| 155 |
+
else {
|
| 156 |
+
this.loadGraphData(this.graph);
|
| 157 |
+
}
|
| 158 |
+
}
|
| 159 |
+
}
|
| 160 |
+
async loadGraphData(graphData) {
|
| 161 |
+
this.graphResults = await fixBadLinks(graphData);
|
| 162 |
+
this.updateUi();
|
| 163 |
+
}
|
| 164 |
+
async saveFixedWorkflow() {
|
| 165 |
+
if (!this.graphFinalResults) {
|
| 166 |
+
this.updateUi("⛔ Save w/o final graph patched.");
|
| 167 |
+
return false;
|
| 168 |
+
}
|
| 169 |
+
let filename = this.file.name || "workflow.json";
|
| 170 |
+
let filenames = filename.split(".");
|
| 171 |
+
filenames.pop();
|
| 172 |
+
filename = filenames.join(".");
|
| 173 |
+
filename += "_fixed.json";
|
| 174 |
+
filename = prompt("Save workflow as:", filename);
|
| 175 |
+
if (!filename)
|
| 176 |
+
return false;
|
| 177 |
+
if (!filename.toLowerCase().endsWith(".json")) {
|
| 178 |
+
filename += ".json";
|
| 179 |
+
}
|
| 180 |
+
const json = JSON.stringify(this.graphFinalResults.graph, null, 2);
|
| 181 |
+
const blob = new Blob([json], { type: "application/json" });
|
| 182 |
+
const url = URL.createObjectURL(blob);
|
| 183 |
+
const anchor = document.createElement("a");
|
| 184 |
+
anchor.download = filename;
|
| 185 |
+
anchor.href = url;
|
| 186 |
+
anchor.style.display = "none";
|
| 187 |
+
document.body.appendChild(anchor);
|
| 188 |
+
await wait();
|
| 189 |
+
anchor.click();
|
| 190 |
+
await wait();
|
| 191 |
+
anchor.remove();
|
| 192 |
+
window.URL.revokeObjectURL(url);
|
| 193 |
+
return true;
|
| 194 |
+
}
|
| 195 |
+
}
|
sd-dynamic-thresholding/.github/FUNDING.yml
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
github: mcmonkey4eva
|
sd-dynamic-thresholding/.github/workflows/publish.yml
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: Publish to Comfy registry
|
| 2 |
+
on:
|
| 3 |
+
workflow_dispatch:
|
| 4 |
+
push:
|
| 5 |
+
branches:
|
| 6 |
+
- master
|
| 7 |
+
paths:
|
| 8 |
+
- "pyproject.toml"
|
| 9 |
+
|
| 10 |
+
jobs:
|
| 11 |
+
publish-node:
|
| 12 |
+
name: Publish Custom Node to registry
|
| 13 |
+
runs-on: ubuntu-latest
|
| 14 |
+
steps:
|
| 15 |
+
- name: Check out code
|
| 16 |
+
uses: actions/checkout@v4
|
| 17 |
+
- name: Publish Custom Node
|
| 18 |
+
uses: Comfy-Org/publish-node-action@main
|
| 19 |
+
with:
|
| 20 |
+
## Add your own personal access token to your Github Repository secrets and reference it here.
|
| 21 |
+
personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }}
|
sd-dynamic-thresholding/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (435 Bytes). View file
|
|
|
sd-dynamic-thresholding/__pycache__/dynthres_comfyui.cpython-312.pyc
ADDED
|
Binary file (4.19 kB). View file
|
|
|
sd-dynamic-thresholding/__pycache__/dynthres_core.cpython-312.pyc
ADDED
|
Binary file (9.12 kB). View file
|
|
|
sd-dynamic-thresholding/github/comfy_node.png
ADDED
|
sd-dynamic-thresholding/github/ui.png
ADDED
|
sd-dynamic-thresholding/javascript/active.js
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
let dynthres_update_enabled = function() {
|
| 2 |
+
return Array.from(arguments);
|
| 3 |
+
};
|
| 4 |
+
|
| 5 |
+
(function(){
|
| 6 |
+
let accordions = {};
|
| 7 |
+
let enabled = {};
|
| 8 |
+
onUiUpdate(() => {
|
| 9 |
+
let accordion_id_prefix = "#dynthres_";
|
| 10 |
+
let extension_checkbox_class = ".dynthres-enabled";
|
| 11 |
+
|
| 12 |
+
dynthres_update_enabled = function() {
|
| 13 |
+
let res = Array.from(arguments);
|
| 14 |
+
let tabname = res[1] ? "img2img" : "txt2img";
|
| 15 |
+
|
| 16 |
+
let checkbox = accordions[tabname]?.querySelector(extension_checkbox_class + ' input');
|
| 17 |
+
checkbox?.dispatchEvent(new Event('change'));
|
| 18 |
+
|
| 19 |
+
return res;
|
| 20 |
+
};
|
| 21 |
+
|
| 22 |
+
function attachEnabledButtonListener(checkbox, accordion) {
|
| 23 |
+
let span = accordion.querySelector('.label-wrap span');
|
| 24 |
+
let badge = document.createElement('input');
|
| 25 |
+
badge.type = "checkbox";
|
| 26 |
+
badge.checked = checkbox.checked;
|
| 27 |
+
badge.addEventListener('click', (e) => {
|
| 28 |
+
checkbox.checked = !checkbox.checked;
|
| 29 |
+
badge.checked = checkbox.checked;
|
| 30 |
+
checkbox.dispatchEvent(new Event('change'));
|
| 31 |
+
e.stopPropagation();
|
| 32 |
+
});
|
| 33 |
+
|
| 34 |
+
badge.className = checkbox.className;
|
| 35 |
+
badge.classList.add('primary');
|
| 36 |
+
span.insertBefore(badge, span.firstChild);
|
| 37 |
+
let space = document.createElement('span');
|
| 38 |
+
space.innerHTML = " ";
|
| 39 |
+
span.insertBefore(space, badge.nextSibling);
|
| 40 |
+
|
| 41 |
+
checkbox.addEventListener('change', () => {
|
| 42 |
+
let badge = accordion.querySelector('.label-wrap span input');
|
| 43 |
+
badge.checked = checkbox.checked;
|
| 44 |
+
});
|
| 45 |
+
checkbox.parentNode.style.display = "none";
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
if (Object.keys(accordions).length < 2) {
|
| 49 |
+
let accordion = gradioApp().querySelector(accordion_id_prefix + 'txt2img');
|
| 50 |
+
if (accordion) {
|
| 51 |
+
accordions.txt2img = accordion;
|
| 52 |
+
}
|
| 53 |
+
accordion = gradioApp().querySelector(accordion_id_prefix + 'img2img');
|
| 54 |
+
if (accordion) {
|
| 55 |
+
accordions.img2img = accordion;
|
| 56 |
+
}
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
if (Object.keys(accordions).length > 0 && accordions.txt2img && !enabled.txt2img) {
|
| 60 |
+
enabled.txt2img = accordions.txt2img.querySelector(extension_checkbox_class + ' input');
|
| 61 |
+
attachEnabledButtonListener(enabled.txt2img, accordions.txt2img);
|
| 62 |
+
}
|
| 63 |
+
if (Object.keys(accordions).length > 0 && accordions.img2img && !enabled.img2img) {
|
| 64 |
+
enabled.img2img = accordions.img2img.querySelector(extension_checkbox_class + ' input');
|
| 65 |
+
attachEnabledButtonListener(enabled.img2img, accordions.img2img);
|
| 66 |
+
}
|
| 67 |
+
});
|
| 68 |
+
})();
|
sd-dynamic-thresholding/scripts/dynamic_thresholding.py
ADDED
|
@@ -0,0 +1,270 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
##################
|
| 2 |
+
# Stable Diffusion Dynamic Thresholding (CFG Scale Fix)
|
| 3 |
+
#
|
| 4 |
+
# Author: Alex 'mcmonkey' Goodwin
|
| 5 |
+
# GitHub URL: https://github.com/mcmonkeyprojects/sd-dynamic-thresholding
|
| 6 |
+
# Created: 2022/01/26
|
| 7 |
+
# Last updated: 2023/01/30
|
| 8 |
+
#
|
| 9 |
+
# For usage help, view the README.md file in the extension root, or via the GitHub page.
|
| 10 |
+
#
|
| 11 |
+
##################
|
| 12 |
+
|
| 13 |
+
import gradio as gr
|
| 14 |
+
import torch, traceback
|
| 15 |
+
import dynthres_core
|
| 16 |
+
from modules import scripts, script_callbacks, sd_samplers, sd_samplers_compvis, sd_samplers_common
|
| 17 |
+
try:
|
| 18 |
+
import dynthres_unipc
|
| 19 |
+
except Exception as e:
|
| 20 |
+
print(f"\n\n======\nError! UniPC sampler support failed to load! Is your WebUI up to date?\n(Error: {e})\n======")
|
| 21 |
+
try:
|
| 22 |
+
from modules.sd_samplers_kdiffusion import CFGDenoiserKDiffusion as cfgdenoisekdiff
|
| 23 |
+
IS_AUTO_16 = True
|
| 24 |
+
except Exception as e:
|
| 25 |
+
print(f"\n\n======\nWarning! Using legacy KDiff version! Is your WebUI up to date?\n======")
|
| 26 |
+
from modules.sd_samplers_kdiffusion import CFGDenoiser as cfgdenoisekdiff
|
| 27 |
+
IS_AUTO_16 = False
|
| 28 |
+
|
| 29 |
+
DISABLE_VISIBILITY = True
|
| 30 |
+
|
| 31 |
+
######################### Data values #########################
|
| 32 |
+
MODES_WITH_VALUE = ["Power Up", "Power Down", "Linear Repeating", "Cosine Repeating", "Sawtooth"]
|
| 33 |
+
|
| 34 |
+
######################### Script class entrypoint #########################
|
| 35 |
+
class Script(scripts.Script):
|
| 36 |
+
|
| 37 |
+
def title(self):
|
| 38 |
+
return "Dynamic Thresholding (CFG Scale Fix)"
|
| 39 |
+
|
| 40 |
+
def show(self, is_img2img):
|
| 41 |
+
return scripts.AlwaysVisible
|
| 42 |
+
|
| 43 |
+
def ui(self, is_img2img):
|
| 44 |
+
def vis_change(is_vis):
|
| 45 |
+
return {"visible": is_vis, "__type__": "update"}
|
| 46 |
+
# "Dynamic Thresholding (CFG Scale Fix)"
|
| 47 |
+
dtrue = gr.Checkbox(value=True, visible=False)
|
| 48 |
+
dfalse = gr.Checkbox(value=False, visible=False)
|
| 49 |
+
with gr.Accordion("Dynamic Thresholding (CFG Scale Fix)", open=False, elem_id="dynthres_" + ("img2img" if is_img2img else "txt2img")):
|
| 50 |
+
with gr.Row():
|
| 51 |
+
enabled = gr.Checkbox(value=False, label="Enable Dynamic Thresholding (CFG Scale Fix)", elem_classes=["dynthres-enabled"], elem_id='dynthres_enabled')
|
| 52 |
+
with gr.Group():
|
| 53 |
+
gr.HTML(value=f"View <a style=\"border-bottom: 1px #00ffff dotted;\" href=\"https://github.com/mcmonkeyprojects/sd-dynamic-thresholding/wiki/Usage-Tips\">the wiki for usage tips.</a><br><br>", elem_id='dynthres_wiki_link')
|
| 54 |
+
mimic_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='Mimic CFG Scale', value=7.0, elem_id='dynthres_mimic_scale')
|
| 55 |
+
with gr.Accordion("Advanced Options", open=False, elem_id='dynthres_advanced_opts'):
|
| 56 |
+
with gr.Row():
|
| 57 |
+
threshold_percentile = gr.Slider(minimum=90.0, value=100.0, maximum=100.0, step=0.05, label='Top percentile of latents to clamp', elem_id='dynthres_threshold_percentile')
|
| 58 |
+
interpolate_phi = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Interpolate Phi", value=1.0, elem_id='dynthres_interpolate_phi')
|
| 59 |
+
with gr.Row():
|
| 60 |
+
mimic_mode = gr.Dropdown(dynthres_core.DynThresh.Modes, value="Constant", label="Mimic Scale Scheduler", elem_id='dynthres_mimic_mode')
|
| 61 |
+
cfg_mode = gr.Dropdown(dynthres_core.DynThresh.Modes, value="Constant", label="CFG Scale Scheduler", elem_id='dynthres_cfg_mode')
|
| 62 |
+
mimic_scale_min = gr.Slider(minimum=0.0, maximum=30.0, step=0.5, visible=DISABLE_VISIBILITY, label="Minimum value of the Mimic Scale Scheduler", elem_id='dynthres_mimic_scale_min')
|
| 63 |
+
cfg_scale_min = gr.Slider(minimum=0.0, maximum=30.0, step=0.5, visible=DISABLE_VISIBILITY, label="Minimum value of the CFG Scale Scheduler", elem_id='dynthres_cfg_scale_min')
|
| 64 |
+
sched_val = gr.Slider(minimum=0.0, maximum=40.0, step=0.5, value=4.0, visible=DISABLE_VISIBILITY, label="Scheduler Value", info="Value unique to the scheduler mode - for Power Up/Down, this is the power. For Linear/Cosine Repeating, this is the number of repeats per image.", elem_id='dynthres_sched_val')
|
| 65 |
+
with gr.Row():
|
| 66 |
+
separate_feature_channels = gr.Checkbox(value=True, label="Separate Feature Channels", elem_id='dynthres_separate_feature_channels')
|
| 67 |
+
scaling_startpoint = gr.Radio(["ZERO", "MEAN"], value="MEAN", label="Scaling Startpoint")
|
| 68 |
+
variability_measure = gr.Radio(["STD", "AD"], value="AD", label="Variability Measure")
|
| 69 |
+
def should_show_scheduler_value(cfg_mode, mimic_mode):
|
| 70 |
+
sched_vis = cfg_mode in MODES_WITH_VALUE or mimic_mode in MODES_WITH_VALUE or DISABLE_VISIBILITY
|
| 71 |
+
return vis_change(sched_vis), vis_change(mimic_mode != "Constant" or DISABLE_VISIBILITY), vis_change(cfg_mode != "Constant" or DISABLE_VISIBILITY)
|
| 72 |
+
cfg_mode.change(should_show_scheduler_value, inputs=[cfg_mode, mimic_mode], outputs=[sched_val, mimic_scale_min, cfg_scale_min])
|
| 73 |
+
mimic_mode.change(should_show_scheduler_value, inputs=[cfg_mode, mimic_mode], outputs=[sched_val, mimic_scale_min, cfg_scale_min])
|
| 74 |
+
enabled.change(
|
| 75 |
+
_js="dynthres_update_enabled",
|
| 76 |
+
fn=None,
|
| 77 |
+
inputs=[enabled, dtrue if is_img2img else dfalse],
|
| 78 |
+
show_progress = False)
|
| 79 |
+
self.infotext_fields = (
|
| 80 |
+
(enabled, lambda d: gr.Checkbox.update(value="Dynamic thresholding enabled" in d)),
|
| 81 |
+
(mimic_scale, "Mimic scale"),
|
| 82 |
+
(separate_feature_channels, "Separate Feature Channels"),
|
| 83 |
+
(scaling_startpoint, lambda d: gr.Radio.update(value=d.get("Scaling Startpoint", "MEAN"))),
|
| 84 |
+
(variability_measure, lambda d: gr.Radio.update(value=d.get("Variability Measure", "AD"))),
|
| 85 |
+
(interpolate_phi, "Interpolate Phi"),
|
| 86 |
+
(threshold_percentile, "Threshold percentile"),
|
| 87 |
+
(mimic_scale_min, "Mimic scale minimum"),
|
| 88 |
+
(mimic_mode, lambda d: gr.Dropdown.update(value=d.get("Mimic mode", "Constant"))),
|
| 89 |
+
(cfg_mode, lambda d: gr.Dropdown.update(value=d.get("CFG mode", "Constant"))),
|
| 90 |
+
(cfg_scale_min, "CFG scale minimum"),
|
| 91 |
+
(sched_val, "Scheduler value"))
|
| 92 |
+
return [enabled, mimic_scale, threshold_percentile, mimic_mode, mimic_scale_min, cfg_mode, cfg_scale_min, sched_val, separate_feature_channels, scaling_startpoint, variability_measure, interpolate_phi]
|
| 93 |
+
|
| 94 |
+
last_id = 0
|
| 95 |
+
|
| 96 |
+
def process_batch(self, p, enabled, mimic_scale, threshold_percentile, mimic_mode, mimic_scale_min, cfg_mode, cfg_scale_min, sched_val, separate_feature_channels, scaling_startpoint, variability_measure, interpolate_phi, batch_number, prompts, seeds, subseeds):
|
| 97 |
+
enabled = getattr(p, 'dynthres_enabled', enabled)
|
| 98 |
+
if not enabled:
|
| 99 |
+
return
|
| 100 |
+
orig_sampler_name = p.sampler_name
|
| 101 |
+
orig_latent_sampler_name = getattr(p, 'latent_sampler', None)
|
| 102 |
+
if orig_sampler_name in ["DDIM", "PLMS"]:
|
| 103 |
+
raise RuntimeError(f"Cannot use sampler {orig_sampler_name} with Dynamic Thresholding")
|
| 104 |
+
if orig_latent_sampler_name in ["DDIM", "PLMS"]:
|
| 105 |
+
raise RuntimeError(f"Cannot use secondary sampler {orig_latent_sampler_name} with Dynamic Thresholding")
|
| 106 |
+
if 'UniPC' in (orig_sampler_name, orig_latent_sampler_name) and p.enable_hr:
|
| 107 |
+
raise RuntimeError(f"UniPC does not support Hires Fix. Auto WebUI silently swaps to DDIM for this, which DynThresh does not support. Please swap to a sampler capable of img2img processing for HR Fix to work.")
|
| 108 |
+
mimic_scale = getattr(p, 'dynthres_mimic_scale', mimic_scale)
|
| 109 |
+
separate_feature_channels = getattr(p, 'dynthres_separate_feature_channels', separate_feature_channels)
|
| 110 |
+
scaling_startpoint = getattr(p, 'dynthres_scaling_startpoint', scaling_startpoint)
|
| 111 |
+
variability_measure = getattr(p, 'dynthres_variability_measure', variability_measure)
|
| 112 |
+
interpolate_phi = getattr(p, 'dynthres_interpolate_phi', interpolate_phi)
|
| 113 |
+
threshold_percentile = getattr(p, 'dynthres_threshold_percentile', threshold_percentile)
|
| 114 |
+
mimic_mode = getattr(p, 'dynthres_mimic_mode', mimic_mode)
|
| 115 |
+
mimic_scale_min = getattr(p, 'dynthres_mimic_scale_min', mimic_scale_min)
|
| 116 |
+
cfg_mode = getattr(p, 'dynthres_cfg_mode', cfg_mode)
|
| 117 |
+
cfg_scale_min = getattr(p, 'dynthres_cfg_scale_min', cfg_scale_min)
|
| 118 |
+
experiment_mode = getattr(p, 'dynthres_experiment_mode', 0)
|
| 119 |
+
sched_val = getattr(p, 'dynthres_scheduler_val', sched_val)
|
| 120 |
+
p.extra_generation_params["Dynamic thresholding enabled"] = True
|
| 121 |
+
p.extra_generation_params["Mimic scale"] = mimic_scale
|
| 122 |
+
p.extra_generation_params["Separate Feature Channels"] = separate_feature_channels
|
| 123 |
+
p.extra_generation_params["Scaling Startpoint"] = scaling_startpoint
|
| 124 |
+
p.extra_generation_params["Variability Measure"] = variability_measure
|
| 125 |
+
p.extra_generation_params["Interpolate Phi"] = interpolate_phi
|
| 126 |
+
p.extra_generation_params["Threshold percentile"] = threshold_percentile
|
| 127 |
+
p.extra_generation_params["Sampler"] = orig_sampler_name
|
| 128 |
+
if mimic_mode != "Constant":
|
| 129 |
+
p.extra_generation_params["Mimic mode"] = mimic_mode
|
| 130 |
+
p.extra_generation_params["Mimic scale minimum"] = mimic_scale_min
|
| 131 |
+
if cfg_mode != "Constant":
|
| 132 |
+
p.extra_generation_params["CFG mode"] = cfg_mode
|
| 133 |
+
p.extra_generation_params["CFG scale minimum"] = cfg_scale_min
|
| 134 |
+
if cfg_mode in MODES_WITH_VALUE or mimic_mode in MODES_WITH_VALUE:
|
| 135 |
+
p.extra_generation_params["Scheduler value"] = sched_val
|
| 136 |
+
# Note: the ID number is to protect the edge case of multiple simultaneous runs with different settings
|
| 137 |
+
Script.last_id += 1
|
| 138 |
+
# Percentage to portion
|
| 139 |
+
threshold_percentile *= 0.01
|
| 140 |
+
|
| 141 |
+
def make_sampler(orig_sampler_name):
|
| 142 |
+
fixed_sampler_name = f"{orig_sampler_name}_dynthres{Script.last_id}"
|
| 143 |
+
|
| 144 |
+
# Make a placeholder sampler
|
| 145 |
+
sampler = sd_samplers.all_samplers_map[orig_sampler_name]
|
| 146 |
+
dt_data = dynthres_core.DynThresh(mimic_scale, threshold_percentile, mimic_mode, mimic_scale_min, cfg_mode, cfg_scale_min, sched_val, experiment_mode, p.steps, separate_feature_channels, scaling_startpoint, variability_measure, interpolate_phi)
|
| 147 |
+
if orig_sampler_name == "UniPC":
|
| 148 |
+
def unipc_constructor(model):
|
| 149 |
+
return CustomVanillaSDSampler(dynthres_unipc.CustomUniPCSampler, model, dt_data)
|
| 150 |
+
new_sampler = sd_samplers_common.SamplerData(fixed_sampler_name, unipc_constructor, sampler.aliases, sampler.options)
|
| 151 |
+
else:
|
| 152 |
+
def new_constructor(model):
|
| 153 |
+
result = sampler.constructor(model)
|
| 154 |
+
cfg = CustomCFGDenoiser(result if IS_AUTO_16 else result.model_wrap_cfg.inner_model, dt_data)
|
| 155 |
+
result.model_wrap_cfg = cfg
|
| 156 |
+
return result
|
| 157 |
+
new_sampler = sd_samplers_common.SamplerData(fixed_sampler_name, new_constructor, sampler.aliases, sampler.options)
|
| 158 |
+
return fixed_sampler_name, new_sampler
|
| 159 |
+
|
| 160 |
+
# Apply for usage
|
| 161 |
+
p.orig_sampler_name = orig_sampler_name
|
| 162 |
+
p.orig_latent_sampler_name = orig_latent_sampler_name
|
| 163 |
+
p.fixed_samplers = []
|
| 164 |
+
|
| 165 |
+
if orig_latent_sampler_name:
|
| 166 |
+
latent_sampler_name, latent_sampler = make_sampler(orig_latent_sampler_name)
|
| 167 |
+
sd_samplers.all_samplers_map[latent_sampler_name] = latent_sampler
|
| 168 |
+
p.fixed_samplers.append(latent_sampler_name)
|
| 169 |
+
p.latent_sampler = latent_sampler_name
|
| 170 |
+
|
| 171 |
+
if orig_sampler_name != orig_latent_sampler_name:
|
| 172 |
+
p.sampler_name, new_sampler = make_sampler(orig_sampler_name)
|
| 173 |
+
sd_samplers.all_samplers_map[p.sampler_name] = new_sampler
|
| 174 |
+
p.fixed_samplers.append(p.sampler_name)
|
| 175 |
+
else:
|
| 176 |
+
p.sampler_name = p.latent_sampler
|
| 177 |
+
|
| 178 |
+
if p.sampler is not None:
|
| 179 |
+
p.sampler = sd_samplers.create_sampler(p.sampler_name, p.sd_model)
|
| 180 |
+
|
| 181 |
+
def postprocess_batch(self, p, enabled, mimic_scale, threshold_percentile, mimic_mode, mimic_scale_min, cfg_mode, cfg_scale_min, sched_val, separate_feature_channels, scaling_startpoint, variability_measure, interpolate_phi, batch_number, images):
|
| 182 |
+
if not enabled or not hasattr(p, 'orig_sampler_name'):
|
| 183 |
+
return
|
| 184 |
+
p.sampler_name = p.orig_sampler_name
|
| 185 |
+
if p.orig_latent_sampler_name:
|
| 186 |
+
p.latent_sampler = p.orig_latent_sampler_name
|
| 187 |
+
for added_sampler in p.fixed_samplers:
|
| 188 |
+
del sd_samplers.all_samplers_map[added_sampler]
|
| 189 |
+
del p.fixed_samplers
|
| 190 |
+
del p.orig_sampler_name
|
| 191 |
+
del p.orig_latent_sampler_name
|
| 192 |
+
|
| 193 |
+
######################### CompVis Implementation logic #########################
|
| 194 |
+
|
| 195 |
+
class CustomVanillaSDSampler(sd_samplers_compvis.VanillaStableDiffusionSampler):
|
| 196 |
+
def __init__(self, constructor, sd_model, dt_data):
|
| 197 |
+
super().__init__(constructor, sd_model)
|
| 198 |
+
self.sampler.main_class = dt_data
|
| 199 |
+
|
| 200 |
+
######################### K-Diffusion Implementation logic #########################
|
| 201 |
+
|
| 202 |
+
class CustomCFGDenoiser(cfgdenoisekdiff):
|
| 203 |
+
def __init__(self, model, dt_data):
|
| 204 |
+
super().__init__(model)
|
| 205 |
+
self.main_class = dt_data
|
| 206 |
+
|
| 207 |
+
def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
|
| 208 |
+
if isinstance(uncond, dict) and 'crossattn' in uncond:
|
| 209 |
+
uncond = uncond['crossattn']
|
| 210 |
+
denoised_uncond = x_out[-uncond.shape[0]:]
|
| 211 |
+
# conds_list shape is (batch, cond, 2)
|
| 212 |
+
weights = torch.tensor(conds_list, device=uncond.device).select(2, 1)
|
| 213 |
+
weights = weights.reshape(*weights.shape, 1, 1, 1)
|
| 214 |
+
self.main_class.step = self.step
|
| 215 |
+
if hasattr(self, 'total_steps'):
|
| 216 |
+
self.main_class.max_steps = self.total_steps
|
| 217 |
+
|
| 218 |
+
if self.main_class.experiment_mode >= 4 and self.main_class.experiment_mode <= 5:
|
| 219 |
+
# https://arxiv.org/pdf/2305.08891.pdf "Rescale CFG". It's not good, but if you want to test it, just set experiment_mode = 4 + phi.
|
| 220 |
+
denoised = torch.clone(denoised_uncond)
|
| 221 |
+
fi = self.main_class.experiment_mode - 4.0
|
| 222 |
+
for i, conds in enumerate(conds_list):
|
| 223 |
+
for cond_index, weight in conds:
|
| 224 |
+
xcfg = (denoised_uncond[i] + (x_out[cond_index] - denoised_uncond[i]) * (cond_scale * weight))
|
| 225 |
+
xrescaled = xcfg * (torch.std(x_out[cond_index]) / torch.std(xcfg))
|
| 226 |
+
xfinal = fi * xrescaled + (1.0 - fi) * xcfg
|
| 227 |
+
denoised[i] = xfinal
|
| 228 |
+
return denoised
|
| 229 |
+
|
| 230 |
+
return self.main_class.dynthresh(x_out[:-uncond.shape[0]], denoised_uncond, cond_scale, weights)
|
| 231 |
+
|
| 232 |
+
######################### XYZ Plot Script Support logic #########################
|
| 233 |
+
|
| 234 |
+
def make_axis_options():
|
| 235 |
+
xyz_grid = [x for x in scripts.scripts_data if x.script_class.__module__ in ("xyz_grid.py", "scripts.xyz_grid")][0].module
|
| 236 |
+
def apply_mimic_scale(p, x, xs):
|
| 237 |
+
if x != 0:
|
| 238 |
+
setattr(p, "dynthres_enabled", True)
|
| 239 |
+
setattr(p, "dynthres_mimic_scale", x)
|
| 240 |
+
else:
|
| 241 |
+
setattr(p, "dynthres_enabled", False)
|
| 242 |
+
def confirm_scheduler(p, xs):
|
| 243 |
+
for x in xs:
|
| 244 |
+
if x not in dynthres_core.DynThresh.Modes:
|
| 245 |
+
raise RuntimeError(f"Unknown Scheduler: {x}")
|
| 246 |
+
extra_axis_options = [
|
| 247 |
+
xyz_grid.AxisOption("[DynThres] Mimic Scale", float, apply_mimic_scale),
|
| 248 |
+
xyz_grid.AxisOption("[DynThres] Separate Feature Channels", int,
|
| 249 |
+
xyz_grid.apply_field("dynthres_separate_feature_channels")),
|
| 250 |
+
xyz_grid.AxisOption("[DynThres] Scaling Startpoint", str, xyz_grid.apply_field("dynthres_scaling_startpoint"), choices=lambda:['ZERO', 'MEAN']),
|
| 251 |
+
xyz_grid.AxisOption("[DynThres] Variability Measure", str, xyz_grid.apply_field("dynthres_variability_measure"), choices=lambda:['STD', 'AD']),
|
| 252 |
+
xyz_grid.AxisOption("[DynThres] Interpolate Phi", float, xyz_grid.apply_field("dynthres_interpolate_phi")),
|
| 253 |
+
xyz_grid.AxisOption("[DynThres] Threshold Percentile", float, xyz_grid.apply_field("dynthres_threshold_percentile")),
|
| 254 |
+
xyz_grid.AxisOption("[DynThres] Mimic Scheduler", str, xyz_grid.apply_field("dynthres_mimic_mode"), confirm=confirm_scheduler, choices=lambda: dynthres_core.DynThresh.Modes),
|
| 255 |
+
xyz_grid.AxisOption("[DynThres] Mimic minimum", float, xyz_grid.apply_field("dynthres_mimic_scale_min")),
|
| 256 |
+
xyz_grid.AxisOption("[DynThres] CFG Scheduler", str, xyz_grid.apply_field("dynthres_cfg_mode"), confirm=confirm_scheduler, choices=lambda: dynthres_core.DynThresh.Modes),
|
| 257 |
+
xyz_grid.AxisOption("[DynThres] CFG minimum", float, xyz_grid.apply_field("dynthres_cfg_scale_min")),
|
| 258 |
+
xyz_grid.AxisOption("[DynThres] Scheduler value", float, xyz_grid.apply_field("dynthres_scheduler_val"))
|
| 259 |
+
]
|
| 260 |
+
if not any("[DynThres]" in x.label for x in xyz_grid.axis_options):
|
| 261 |
+
xyz_grid.axis_options.extend(extra_axis_options)
|
| 262 |
+
|
| 263 |
+
def callback_before_ui():
|
| 264 |
+
try:
|
| 265 |
+
make_axis_options()
|
| 266 |
+
except Exception as e:
|
| 267 |
+
traceback.print_exc()
|
| 268 |
+
print(f"Failed to add support for X/Y/Z Plot Script because: {e}")
|
| 269 |
+
|
| 270 |
+
script_callbacks.on_before_ui(callback_before_ui)
|
sigmas_tools_and_the_golden_scheduler/.github/workflows/publish.yml
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: Publish to Comfy registry
|
| 2 |
+
on:
|
| 3 |
+
workflow_dispatch:
|
| 4 |
+
push:
|
| 5 |
+
branches:
|
| 6 |
+
- main
|
| 7 |
+
paths:
|
| 8 |
+
- "pyproject.toml"
|
| 9 |
+
|
| 10 |
+
jobs:
|
| 11 |
+
publish-node:
|
| 12 |
+
name: Publish Custom Node to registry
|
| 13 |
+
runs-on: ubuntu-latest
|
| 14 |
+
steps:
|
| 15 |
+
- name: Check out code
|
| 16 |
+
uses: actions/checkout@v4
|
| 17 |
+
- name: Publish Custom Node
|
| 18 |
+
uses: Comfy-Org/publish-node-action@main
|
| 19 |
+
with:
|
| 20 |
+
## Add your own personal access token to your Github Repository secrets and reference it here.
|
| 21 |
+
personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }}
|
sigmas_tools_and_the_golden_scheduler/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (273 Bytes). View file
|
|
|
sigmas_tools_and_the_golden_scheduler/__pycache__/sigmas_merge.cpython-312.pyc
ADDED
|
Binary file (20.2 kB). View file
|
|
|
stable-diffusion-temperature-settings/.github/FUNDING.yml
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# These are supported funding model platforms
|
| 2 |
+
|
| 3 |
+
patreon: extraltodeus
|
stable-diffusion-temperature-settings/.github/workflows/publish.yml
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: Publish to Comfy registry
|
| 2 |
+
on:
|
| 3 |
+
workflow_dispatch:
|
| 4 |
+
push:
|
| 5 |
+
branches:
|
| 6 |
+
- main
|
| 7 |
+
- master
|
| 8 |
+
paths:
|
| 9 |
+
- "pyproject.toml"
|
| 10 |
+
|
| 11 |
+
jobs:
|
| 12 |
+
publish-node:
|
| 13 |
+
name: Publish Custom Node to registry
|
| 14 |
+
runs-on: ubuntu-latest
|
| 15 |
+
steps:
|
| 16 |
+
- name: Check out code
|
| 17 |
+
uses: actions/checkout@v4
|
| 18 |
+
- name: Publish Custom Node
|
| 19 |
+
uses: Comfy-Org/publish-node-action@main
|
| 20 |
+
with:
|
| 21 |
+
## Add your own personal access token to your Github Repository secrets and reference it here.
|
| 22 |
+
personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }}
|
stable-diffusion-temperature-settings/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (342 Bytes). View file
|
|
|
stable-diffusion-temperature-settings/__pycache__/nodes.cpython-312.pyc
ADDED
|
Binary file (14.5 kB). View file
|
|
|
stable-diffusion-temperature-settings/workflows/tinybottle.png
ADDED
|
ultimate-upscale-for-automatic1111/scripts/ultimate-upscale.py
ADDED
|
@@ -0,0 +1,569 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Ultimate SD upscale -- tiled img2img upscaling script for the A1111 webui.
import math
from enum import Enum

import gradio as gr
from PIL import Image, ImageDraw, ImageOps
from modules import processing, shared, images, devices, scripts
from modules.processing import StableDiffusionProcessing, Processed
from modules.shared import opts, state

# Prefix shared by every gradio component id this script creates.
elem_id_prefix = "ultimateupscale"

class USDUMode(Enum):
    """Order in which the redraw stage walks the tile grid."""
    LINEAR = 0   # row by row, left to right
    CHESS = 1    # two passes over alternating "chessboard" squares
    NONE = 2     # skip the redraw stage entirely

class USDUSFMode(Enum):
    """Seam-fix strategy applied after the tiled redraw."""
    NONE = 0
    BAND_PASS = 1
    HALF_TILE = 2
    HALF_TILE_PLUS_INTERSECTIONS = 3
class USDUpscaler():
    """Coordinates the full upscale job: the initial upscaler pass, the tiled
    img2img redraw, and the seam-fix pass, per the parameters chosen in ui()."""

    def __init__(self, p, image, upscaler_index: int, save_redraw, save_seams_fix, tile_width, tile_height) -> None:
        self.p: StableDiffusionProcessing = p
        self.image: Image = image
        # Integer ratio between the requested canvas and the source image.
        self.scale_factor = math.ceil(max(p.width, p.height) / max(image.width, image.height))
        self.upscaler = shared.sd_upscalers[upscaler_index]
        self.redraw = USDURedraw()
        self.redraw.save = save_redraw
        # A zero tile dimension borrows the other one, so one slider can define square tiles.
        self.redraw.tile_width = tile_width if tile_width > 0 else tile_height
        self.redraw.tile_height = tile_height if tile_height > 0 else tile_width
        self.seams_fix = USDUSeamsFix()
        self.seams_fix.save = save_seams_fix
        self.seams_fix.tile_width = tile_width if tile_width > 0 else tile_height
        self.seams_fix.tile_height = tile_height if tile_height > 0 else tile_width
        self.initial_info = None
        self.rows = math.ceil(self.p.height / self.redraw.tile_height)
        self.cols = math.ceil(self.p.width / self.redraw.tile_width)

    def get_factor(self, num):
        """Largest supported single-iteration factor (4, 3 or 2) dividing `num`;
        2 for num == 1, 0 when no supported factor divides it."""
        # Its just return, don't need elif
        if num == 1:
            return 2
        if num % 4 == 0:
            return 4
        if num % 3 == 0:
            return 3
        if num % 2 == 0:
            return 2
        return 0

    def get_factors(self):
        """Decompose self.scale_factor into a list of per-iteration factors.

        scale_factor is bumped up until it has at least one supported divisor,
        then factors are peeled off greedily.  The result is stored in
        self.scales as an enumerate() iterator, consumed once by upscale().
        """
        scales = []
        current_scale = 1
        current_scale_factor = self.get_factor(self.scale_factor)
        while current_scale_factor == 0:
            self.scale_factor += 1
            current_scale_factor = self.get_factor(self.scale_factor)
        while current_scale < self.scale_factor:
            current_scale_factor = self.get_factor(self.scale_factor // current_scale)
            scales.append(current_scale_factor)
            current_scale = current_scale * current_scale_factor
            if current_scale_factor == 0:
                break
        self.scales = enumerate(scales)

    def upscale(self):
        """Upscale self.image to the target canvas with the selected upscaler."""
        # Log info
        print(f"Canva size: {self.p.width}x{self.p.height}")
        print(f"Image size: {self.image.width}x{self.image.height}")
        print(f"Scale factor: {self.scale_factor}")
        # "None" upscaler: plain Lanczos resize straight to the target size.
        if self.upscaler.name == "None":
            self.image = self.image.resize((self.p.width, self.p.height), resample=Image.LANCZOS)
            return
        # Get list with scale factors
        self.get_factors()
        # Apply the upscaler once per factor, then snap to the exact target size.
        for index, value in self.scales:
            print(f"Upscaling iteration {index+1} with scale factor {value}")
            self.image = self.upscaler.scaler.upscale(self.image, value, self.upscaler.data_path)
        # Resize image to set values
        self.image = self.image.resize((self.p.width, self.p.height), resample=Image.LANCZOS)

    def setup_redraw(self, redraw_mode, padding, mask_blur):
        """Configure the redraw stage from the raw ui() values."""
        self.redraw.mode = USDUMode(redraw_mode)
        self.redraw.enabled = self.redraw.mode != USDUMode.NONE
        self.redraw.padding = padding
        self.p.mask_blur = mask_blur

    def setup_seams_fix(self, padding, denoise, mask_blur, width, mode):
        """Configure the seam-fix stage from the raw ui() values."""
        self.seams_fix.padding = padding
        self.seams_fix.denoise = denoise
        self.seams_fix.mask_blur = mask_blur
        self.seams_fix.width = width
        self.seams_fix.mode = USDUSFMode(mode)
        self.seams_fix.enabled = self.seams_fix.mode != USDUSFMode.NONE

    def save_image(self):
        """Save the current image; batch prompts use the first prompt for metadata."""
        if not isinstance(self.p.prompt, list):
            images.save_image(self.image, self.p.outpath_samples, "", self.p.seed, self.p.prompt, opts.samples_format, info=self.initial_info, p=self.p)
        else:
            images.save_image(self.image, self.p.outpath_samples, "", self.p.seed, self.p.prompt[0], opts.samples_format, info=self.initial_info, p=self.p)

    def calc_jobs_count(self):
        """Set state.job_count to the total number of img2img passes for progress UI."""
        redraw_job_count = (self.rows * self.cols) if self.redraw.enabled else 0
        seams_job_count = 0
        if self.seams_fix.mode == USDUSFMode.BAND_PASS:
            # One pass per interior row/column seam.
            seams_job_count = self.rows + self.cols - 2
        elif self.seams_fix.mode == USDUSFMode.HALF_TILE:
            # One pass per horizontal and per vertical tile border.
            seams_job_count = self.rows * (self.cols - 1) + (self.rows - 1) * self.cols
        elif self.seams_fix.mode == USDUSFMode.HALF_TILE_PLUS_INTERSECTIONS:
            # Borders plus one pass per interior four-tile intersection.
            seams_job_count = self.rows * (self.cols - 1) + (self.rows - 1) * self.cols + (self.rows - 1) * (self.cols - 1)

        state.job_count = redraw_job_count + seams_job_count

    def print_info(self):
        """Log the tiling configuration to the console."""
        print(f"Tile size: {self.redraw.tile_width}x{self.redraw.tile_height}")
        print(f"Tiles amount: {self.rows * self.cols}")
        print(f"Grid: {self.rows}x{self.cols}")
        print(f"Redraw enabled: {self.redraw.enabled}")
        print(f"Seams fix mode: {self.seams_fix.mode.name}")

    def add_extra_info(self):
        """Record the script parameters in the generation infotext."""
        self.p.extra_generation_params["Ultimate SD upscale upscaler"] = self.upscaler.name
        self.p.extra_generation_params["Ultimate SD upscale tile_width"] = self.redraw.tile_width
        self.p.extra_generation_params["Ultimate SD upscale tile_height"] = self.redraw.tile_height
        self.p.extra_generation_params["Ultimate SD upscale mask_blur"] = self.p.mask_blur
        self.p.extra_generation_params["Ultimate SD upscale padding"] = self.redraw.padding

    def process(self):
        """Run redraw and/or seam fix, collecting results in self.result_images."""
        state.begin()
        self.calc_jobs_count()
        self.result_images = []
        if self.redraw.enabled:
            self.image = self.redraw.start(self.p, self.image, self.rows, self.cols)
            self.initial_info = self.redraw.initial_info
        self.result_images.append(self.image)
        if self.redraw.save:
            self.save_image()

        if self.seams_fix.enabled:
            self.image = self.seams_fix.start(self.p, self.image, self.rows, self.cols)
            self.initial_info = self.seams_fix.initial_info
            self.result_images.append(self.image)
            if self.seams_fix.save:
                self.save_image()
        state.end()
class USDURedraw():
    """Redraws the upscaled image tile by tile via img2img inpainting."""

    def init_draw(self, p, width, height):
        """Size `p` for one padded tile and return a full-canvas (mask, draw) pair."""
        p.inpaint_full_res = True
        p.inpaint_full_res_padding = self.padding
        # Diffusion dimensions must be multiples of 64.
        p.width = math.ceil((self.tile_width + self.padding) / 64) * 64
        p.height = math.ceil((self.tile_height + self.padding) / 64) * 64
        mask = Image.new("L", (width, height), "black")
        draw = ImageDraw.Draw(mask)
        return mask, draw

    def calc_rectangle(self, xi, yi):
        """Pixel bounds (x1, y1, x2, y2) of the tile at grid cell (xi, yi)."""
        x1 = xi * self.tile_width
        y1 = yi * self.tile_height
        x2 = xi * self.tile_width + self.tile_width
        y2 = yi * self.tile_height + self.tile_height

        return x1, y1, x2, y2

    def linear_process(self, p, image, rows, cols):
        """Inpaint every tile in reading order, feeding each result into the next."""
        mask, draw = self.init_draw(p, image.width, image.height)
        processed = None  # guard: an empty grid must not raise NameError below
        for yi in range(rows):
            for xi in range(cols):
                if state.interrupted:
                    break
                draw.rectangle(self.calc_rectangle(xi, yi), fill="white")
                p.init_images = [image]
                p.image_mask = mask
                processed = processing.process_images(p)
                # Erase the tile from the mask before the next iteration.
                draw.rectangle(self.calc_rectangle(xi, yi), fill="black")
                if (len(processed.images) > 0):
                    image = processed.images[0]

        # Restore the full-canvas dimensions that init_draw overwrote.
        p.width = image.width
        p.height = image.height
        if processed is not None:
            self.initial_info = processed.infotext(p, 0)

        return image

    def chess_process(self, p, image, rows, cols):
        """Inpaint tiles in two "chessboard" passes (all of one color, then the
        other) to reduce directional artifacts versus linear order."""
        mask, draw = self.init_draw(p, image.width, image.height)
        processed = None  # guard: an empty grid must not raise NameError below
        tiles = []
        # calc tiles colors
        for yi in range(rows):
            for xi in range(cols):
                if state.interrupted:
                    break
                if xi == 0:
                    tiles.append([])
                color = xi % 2 == 0
                if yi > 0 and yi % 2 != 0:
                    color = not color
                tiles[yi].append(color)

        # First pass: process True tiles; every flag is flipped for the second pass.
        for yi in range(len(tiles)):
            for xi in range(len(tiles[yi])):
                if state.interrupted:
                    break
                if not tiles[yi][xi]:
                    tiles[yi][xi] = not tiles[yi][xi]
                    continue
                tiles[yi][xi] = not tiles[yi][xi]
                draw.rectangle(self.calc_rectangle(xi, yi), fill="white")
                p.init_images = [image]
                p.image_mask = mask
                processed = processing.process_images(p)
                draw.rectangle(self.calc_rectangle(xi, yi), fill="black")
                if (len(processed.images) > 0):
                    image = processed.images[0]

        # Second pass: the remaining (now True) tiles.
        for yi in range(len(tiles)):
            for xi in range(len(tiles[yi])):
                if state.interrupted:
                    break
                if not tiles[yi][xi]:
                    continue
                draw.rectangle(self.calc_rectangle(xi, yi), fill="white")
                p.init_images = [image]
                p.image_mask = mask
                processed = processing.process_images(p)
                draw.rectangle(self.calc_rectangle(xi, yi), fill="black")
                if (len(processed.images) > 0):
                    image = processed.images[0]

        p.width = image.width
        p.height = image.height
        if processed is not None:
            self.initial_info = processed.infotext(p, 0)

        return image

    def start(self, p, image, rows, cols):
        """Dispatch on self.mode; NONE (or anything unknown) passes the image through."""
        self.initial_info = None
        if self.mode == USDUMode.LINEAR:
            return self.linear_process(p, image, rows, cols)
        if self.mode == USDUMode.CHESS:
            return self.chess_process(p, image, rows, cols)
        # Previously fell through returning None; return the input unchanged instead.
        return image
class USDUSeamsFix():
    """Blends away the visible seams left between redrawn tiles."""

    def init_draw(self, p):
        # Reset the infotext and size p to one padded tile, rounded up to /64.
        self.initial_info = None
        p.width = math.ceil((self.tile_width + self.padding) / 64) * 64
        p.height = math.ceil((self.tile_height + self.padding) / 64) * 64

    def half_tile_process(self, p, image, rows, cols):
        # Inpaint every horizontal and vertical tile border under a gradient mask.
        self.init_draw(p)
        result = None

        grad = Image.linear_gradient("L")
        # Horizontal-seam mask: bright band centred vertically within a tile.
        h_grad = Image.new("L", (self.tile_width, self.tile_height), "black")
        h_grad.paste(grad.resize(
            (self.tile_width, self.tile_height // 2), resample=Image.BICUBIC), (0, 0))
        h_grad.paste(grad.rotate(180).resize(
            (self.tile_width, self.tile_height // 2), resample=Image.BICUBIC),
            (0, self.tile_height // 2))
        # Vertical-seam mask: bright band centred horizontally within a tile.
        v_grad = Image.new("L", (self.tile_width, self.tile_height), "black")
        v_grad.paste(grad.rotate(90).resize(
            (self.tile_width // 2, self.tile_height), resample=Image.BICUBIC), (0, 0))
        v_grad.paste(grad.rotate(270).resize(
            (self.tile_width // 2, self.tile_height), resample=Image.BICUBIC), (self.tile_width // 2, 0))

        p.denoising_strength = self.denoise
        p.mask_blur = self.mask_blur

        # Horizontal borders: one pass per interior row boundary, every column.
        for row in range(rows - 1):
            for col in range(cols):
                if state.interrupted:
                    break
                p.width = self.tile_width
                p.height = self.tile_height
                p.inpaint_full_res = True
                p.inpaint_full_res_padding = self.padding
                seam_mask = Image.new("L", (image.width, image.height), "black")
                seam_mask.paste(h_grad, (col * self.tile_width, row * self.tile_height + self.tile_height // 2))

                p.init_images = [image]
                p.image_mask = seam_mask
                result = processing.process_images(p)
                if len(result.images) > 0:
                    image = result.images[0]

        # Vertical borders: one pass per interior column boundary, every row.
        for row in range(rows):
            for col in range(cols - 1):
                if state.interrupted:
                    break
                p.width = self.tile_width
                p.height = self.tile_height
                p.inpaint_full_res = True
                p.inpaint_full_res_padding = self.padding
                seam_mask = Image.new("L", (image.width, image.height), "black")
                seam_mask.paste(v_grad, (col * self.tile_width + self.tile_width // 2, row * self.tile_height))

                p.init_images = [image]
                p.image_mask = seam_mask
                result = processing.process_images(p)
                if len(result.images) > 0:
                    image = result.images[0]

        p.width = image.width
        p.height = image.height
        if result is not None:
            self.initial_info = result.infotext(p, 0)

        return image

    def half_tile_process_corners(self, p, image, rows, cols):
        # First fix the straight seams, then inpaint each four-tile intersection
        # under an inverted radial gradient.
        patched = self.half_tile_process(p, image, rows, cols)
        result = None
        self.init_draw(p)
        grad = Image.radial_gradient("L").resize(
            (self.tile_width, self.tile_height), resample=Image.BICUBIC)
        grad = ImageOps.invert(grad)
        p.denoising_strength = self.denoise
        #p.mask_blur = 0
        p.mask_blur = self.mask_blur

        for row in range(rows - 1):
            for col in range(cols - 1):
                if state.interrupted:
                    break
                p.width = self.tile_width
                p.height = self.tile_height
                p.inpaint_full_res = True
                p.inpaint_full_res_padding = 0
                seam_mask = Image.new("L", (patched.width, patched.height), "black")
                seam_mask.paste(grad, (col * self.tile_width + self.tile_width // 2,
                                       row * self.tile_height + self.tile_height // 2))

                p.init_images = [patched]
                p.image_mask = seam_mask
                result = processing.process_images(p)
                if len(result.images) > 0:
                    patched = result.images[0]

        p.width = patched.width
        p.height = patched.height
        if result is not None:
            self.initial_info = result.infotext(p, 0)

        return patched

    def band_pass_process(self, p, image, cols, rows):
        # NOTE(review): callers pass (rows, cols) positionally, so inside this
        # method `rows` holds the column count and vice versa.  The loop bounds
        # rely on that swap; signature kept as-is for compatibility.
        self.init_draw(p)
        result = None

        p.denoising_strength = self.denoise
        p.mask_blur = 0

        # Build a 256x256 mirrored gradient, then stretch it into seam bands.
        grad = Image.linear_gradient("L")
        mirrored = Image.new("L", (256, 256), "black")
        mirrored.paste(grad.resize((256, 128), resample=Image.BICUBIC), (0, 0))
        mirrored.paste(grad.rotate(180).resize((256, 128), resample=Image.BICUBIC), (0, 128))

        h_band = mirrored.resize((image.width, self.width), resample=Image.BICUBIC)
        v_band = mirrored.rotate(90).resize((self.width, image.height), resample=Image.BICUBIC)

        # Vertical seams (interior column boundaries).
        for i in range(1, rows):
            if state.interrupted:
                break
            p.width = self.width + self.padding * 2
            p.height = image.height
            p.inpaint_full_res = True
            p.inpaint_full_res_padding = self.padding
            seam_mask = Image.new("L", (image.width, image.height), "black")
            seam_mask.paste(v_band, (i * self.tile_width - self.width // 2, 0))

            p.init_images = [image]
            p.image_mask = seam_mask
            result = processing.process_images(p)
            if len(result.images) > 0:
                image = result.images[0]
        # Horizontal seams (interior row boundaries).
        for j in range(1, cols):
            if state.interrupted:
                break
            p.width = image.width
            p.height = self.width + self.padding * 2
            p.inpaint_full_res = True
            p.inpaint_full_res_padding = self.padding
            seam_mask = Image.new("L", (image.width, image.height), "black")
            seam_mask.paste(h_band, (0, j * self.tile_height - self.width // 2))

            p.init_images = [image]
            p.image_mask = seam_mask
            result = processing.process_images(p)
            if len(result.images) > 0:
                image = result.images[0]

        p.width = image.width
        p.height = image.height
        if result is not None:
            self.initial_info = result.infotext(p, 0)

        return image

    def start(self, p, image, rows, cols):
        # Dispatch on the configured seam-fix mode; NONE passes the image through.
        mode = USDUSFMode(self.mode)
        if mode == USDUSFMode.BAND_PASS:
            return self.band_pass_process(p, image, rows, cols)
        elif mode == USDUSFMode.HALF_TILE:
            return self.half_tile_process(p, image, rows, cols)
        elif mode == USDUSFMode.HALF_TILE_PLUS_INTERSECTIONS:
            return self.half_tile_process_corners(p, image, rows, cols)
        else:
            return image
class Script(scripts.Script):
    """img2img script entry point for Ultimate SD upscale."""

    def title(self):
        return "Ultimate SD upscale"

    def show(self, is_img2img):
        # Needs an init image, so only available in img2img.
        return is_img2img

    def ui(self, is_img2img):
        """Build the gradio UI.  The returned components map 1:1, in order, onto
        run()'s parameters; dropdowns use type="index" so run() receives indices."""

        target_size_types = [
            # NOTE(review): "img2img2" is a label typo; left as-is because the
            # label text may be persisted in saved UI state -- confirm before fixing.
            "From img2img2 settings",
            "Custom size",
            "Scale from image size"
        ]

        seams_fix_types = [
            "None",
            "Band pass",
            "Half tile offset pass",
            "Half tile offset pass + intersections"
        ]

        redraw_modes = [
            "Linear",
            "Chess",
            "None"
        ]

        info = gr.HTML(
            "<p style=\"margin-bottom:0.75em\">Will upscale the image depending on the selected target size type</p>")

        with gr.Row():
            target_size_type = gr.Dropdown(label="Target size type", elem_id=f"{elem_id_prefix}_target_size_type", choices=[k for k in target_size_types], type="index",
                                           value=next(iter(target_size_types)))

            custom_width = gr.Slider(label='Custom width', elem_id=f"{elem_id_prefix}_custom_width", minimum=64, maximum=8192, step=64, value=2048, visible=False, interactive=True)
            custom_height = gr.Slider(label='Custom height', elem_id=f"{elem_id_prefix}_custom_height", minimum=64, maximum=8192, step=64, value=2048, visible=False, interactive=True)
            custom_scale = gr.Slider(label='Scale', elem_id=f"{elem_id_prefix}_custom_scale", minimum=1, maximum=16, step=0.01, value=2, visible=False, interactive=True)

        gr.HTML("<p style=\"margin-bottom:0.75em\">Redraw options:</p>")
        with gr.Row():
            upscaler_index = gr.Radio(label='Upscaler', elem_id=f"{elem_id_prefix}_upscaler_index", choices=[x.name for x in shared.sd_upscalers],
                                      value=shared.sd_upscalers[0].name, type="index")
        with gr.Row():
            redraw_mode = gr.Dropdown(label="Type", elem_id=f"{elem_id_prefix}_redraw_mode", choices=[k for k in redraw_modes], type="index", value=next(iter(redraw_modes)))
            tile_width = gr.Slider(elem_id=f"{elem_id_prefix}_tile_width", minimum=0, maximum=2048, step=64, label='Tile width', value=512)
            tile_height = gr.Slider(elem_id=f"{elem_id_prefix}_tile_height", minimum=0, maximum=2048, step=64, label='Tile height', value=0)
            mask_blur = gr.Slider(elem_id=f"{elem_id_prefix}_mask_blur", label='Mask blur', minimum=0, maximum=64, step=1, value=8)
            padding = gr.Slider(elem_id=f"{elem_id_prefix}_padding", label='Padding', minimum=0, maximum=512, step=1, value=32)
        gr.HTML("<p style=\"margin-bottom:0.75em\">Seams fix:</p>")
        with gr.Row():
            seams_fix_type = gr.Dropdown(label="Type", elem_id=f"{elem_id_prefix}_seams_fix_type", choices=[k for k in seams_fix_types], type="index", value=next(iter(seams_fix_types)))
            seams_fix_denoise = gr.Slider(label='Denoise', elem_id=f"{elem_id_prefix}_seams_fix_denoise", minimum=0, maximum=1, step=0.01, value=0.35, visible=False, interactive=True)
            seams_fix_width = gr.Slider(label='Width', elem_id=f"{elem_id_prefix}_seams_fix_width", minimum=0, maximum=128, step=1, value=64, visible=False, interactive=True)
            seams_fix_mask_blur = gr.Slider(label='Mask blur', elem_id=f"{elem_id_prefix}_seams_fix_mask_blur", minimum=0, maximum=64, step=1, value=4, visible=False, interactive=True)
            seams_fix_padding = gr.Slider(label='Padding', elem_id=f"{elem_id_prefix}_seams_fix_padding", minimum=0, maximum=128, step=1, value=16, visible=False, interactive=True)
        gr.HTML("<p style=\"margin-bottom:0.75em\">Save options:</p>")
        with gr.Row():
            save_upscaled_image = gr.Checkbox(label="Upscaled", elem_id=f"{elem_id_prefix}_save_upscaled_image", value=True)
            save_seams_fix_image = gr.Checkbox(label="Seams fix", elem_id=f"{elem_id_prefix}_save_seams_fix_image", value=False)

        def select_fix_type(fix_index):
            # Show only the sliders relevant to the chosen seam-fix mode.
            all_visible = fix_index != 0
            mask_blur_visible = fix_index == 2 or fix_index == 3
            width_visible = fix_index == 1

            return [gr.update(visible=all_visible),
                    gr.update(visible=width_visible),
                    gr.update(visible=mask_blur_visible),
                    gr.update(visible=all_visible)]

        seams_fix_type.change(
            fn=select_fix_type,
            inputs=seams_fix_type,
            outputs=[seams_fix_denoise, seams_fix_width, seams_fix_mask_blur, seams_fix_padding]
        )

        def select_scale_type(scale_index):
            # Show the custom width/height or scale slider matching the size type.
            is_custom_size = scale_index == 1
            is_custom_scale = scale_index == 2

            return [gr.update(visible=is_custom_size),
                    gr.update(visible=is_custom_size),
                    gr.update(visible=is_custom_scale),
                    ]

        target_size_type.change(
            fn=select_scale_type,
            inputs=target_size_type,
            outputs=[custom_width, custom_height, custom_scale]
        )

        def init_field(scale_name):
            # Restore slider visibility when the UI is re-created from saved state.
            try:
                scale_index = target_size_types.index(scale_name)
                custom_width.visible = custom_height.visible = scale_index == 1
                custom_scale.visible = scale_index == 2
            except Exception:
                # Best-effort: an unknown saved label just leaves defaults in place.
                pass

        target_size_type.init_field = init_field

        return [info, tile_width, tile_height, mask_blur, padding, seams_fix_width, seams_fix_denoise, seams_fix_padding,
                upscaler_index, save_upscaled_image, redraw_mode, save_seams_fix_image, seams_fix_mask_blur,
                seams_fix_type, target_size_type, custom_width, custom_height, custom_scale]

    def run(self, p, _, tile_width, tile_height, mask_blur, padding, seams_fix_width, seams_fix_denoise, seams_fix_padding,
            upscaler_index, save_upscaled_image, redraw_mode, save_seams_fix_image, seams_fix_mask_blur,
            seams_fix_type, target_size_type, custom_width, custom_height, custom_scale):
        """Entry point called by the webui with the values from ui()'s components."""

        # Init
        processing.fix_seed(p)
        devices.torch_gc()

        p.do_not_save_grid = True
        p.do_not_save_samples = True
        p.inpaint_full_res = False

        p.inpainting_fill = 1
        p.n_iter = 1
        p.batch_size = 1

        seed = p.seed

        # Init image
        init_img = p.init_images[0]
        if init_img is None:
            return Processed(p, [], seed, "Empty image")
        init_img = images.flatten(init_img, opts.img2img_background_color)

        # Override the target size according to the selected size type.
        if target_size_type == 1:
            p.width = custom_width
            p.height = custom_height
        if target_size_type == 2:
            # Scaled targets are rounded up to multiples of 64.
            p.width = math.ceil((init_img.width * custom_scale) / 64) * 64
            p.height = math.ceil((init_img.height * custom_scale) / 64) * 64

        # Upscaling
        upscaler = USDUpscaler(p, init_img, upscaler_index, save_upscaled_image, save_seams_fix_image, tile_width, tile_height)
        upscaler.upscale()

        # Drawing
        upscaler.setup_redraw(redraw_mode, padding, mask_blur)
        upscaler.setup_seams_fix(seams_fix_padding, seams_fix_denoise, seams_fix_mask_blur, seams_fix_width, seams_fix_type)
        upscaler.print_info()
        upscaler.add_extra_info()
        upscaler.process()
        result_images = upscaler.result_images

        return Processed(p, result_images, seed, upscaler.initial_info if upscaler.initial_info is not None else "")
was-node-suite-comfyui/.github/workflows/publish_action.yml
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: Publish to Comfy registry
|
| 2 |
+
on:
|
| 3 |
+
workflow_dispatch:
|
| 4 |
+
push:
|
| 5 |
+
branches:
|
| 6 |
+
- main
|
| 7 |
+
paths:
|
| 8 |
+
- "pyproject.toml"
|
| 9 |
+
|
| 10 |
+
jobs:
|
| 11 |
+
publish-node:
|
| 12 |
+
name: Publish Custom Node to registry
|
| 13 |
+
runs-on: ubuntu-latest
|
| 14 |
+
steps:
|
| 15 |
+
- name: Check out code
|
| 16 |
+
uses: actions/checkout@v4
|
| 17 |
+
- name: Publish Custom Node
|
| 18 |
+
uses: Comfy-Org/publish-node-action@main
|
| 19 |
+
with:
|
| 20 |
+
personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }} ## Add your own personal access token to your Github Repository secrets and reference it here.
|
was-node-suite-comfyui/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (260 Bytes). View file
|
|
|
was-node-suite-comfyui/modules/BLIP/__init__.py
ADDED
|
File without changes
|
was-node-suite-comfyui/modules/BLIP/blip_med.py
ADDED
|
@@ -0,0 +1,955 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'''
|
| 2 |
+
* Copyright (c) 2022, salesforce.com, inc.
|
| 3 |
+
* All rights reserved.
|
| 4 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
| 6 |
+
* By Junnan Li
|
| 7 |
+
* Based on huggingface code base
|
| 8 |
+
* https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
|
| 9 |
+
'''
|
| 10 |
+
|
| 11 |
+
import math
|
| 12 |
+
import os
|
| 13 |
+
import warnings
|
| 14 |
+
from dataclasses import dataclass
|
| 15 |
+
from typing import Optional, Tuple
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
from torch import Tensor, device, dtype, nn
|
| 19 |
+
import torch.utils.checkpoint
|
| 20 |
+
from torch import nn
|
| 21 |
+
from torch.nn import CrossEntropyLoss
|
| 22 |
+
import torch.nn.functional as F
|
| 23 |
+
|
| 24 |
+
from transformers.activations import ACT2FN
|
| 25 |
+
from transformers.file_utils import (
|
| 26 |
+
ModelOutput,
|
| 27 |
+
)
|
| 28 |
+
from transformers.modeling_outputs import (
|
| 29 |
+
BaseModelOutputWithPastAndCrossAttentions,
|
| 30 |
+
BaseModelOutputWithPoolingAndCrossAttentions,
|
| 31 |
+
CausalLMOutputWithCrossAttentions,
|
| 32 |
+
MaskedLMOutput,
|
| 33 |
+
MultipleChoiceModelOutput,
|
| 34 |
+
NextSentencePredictorOutput,
|
| 35 |
+
QuestionAnsweringModelOutput,
|
| 36 |
+
SequenceClassifierOutput,
|
| 37 |
+
TokenClassifierOutput,
|
| 38 |
+
)
|
| 39 |
+
from transformers.modeling_utils import (
|
| 40 |
+
PreTrainedModel,
|
| 41 |
+
apply_chunking_to_forward,
|
| 42 |
+
find_pruneable_heads_and_indices,
|
| 43 |
+
prune_linear_layer,
|
| 44 |
+
)
|
| 45 |
+
from transformers.utils import logging
|
| 46 |
+
from transformers.models.bert.configuration_bert import BertConfig
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
logger = logging.get_logger(__name__)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class BertEmbeddings(nn.Module):
|
| 53 |
+
"""Construct the embeddings from word and position embeddings."""
|
| 54 |
+
|
| 55 |
+
def __init__(self, config):
|
| 56 |
+
super().__init__()
|
| 57 |
+
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
|
| 58 |
+
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
|
| 59 |
+
|
| 60 |
+
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
|
| 61 |
+
# any TensorFlow checkpoint file
|
| 62 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 63 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 64 |
+
|
| 65 |
+
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
| 66 |
+
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
| 67 |
+
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
| 68 |
+
|
| 69 |
+
self.config = config
|
| 70 |
+
|
| 71 |
+
def forward(
|
| 72 |
+
self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
|
| 73 |
+
):
|
| 74 |
+
if input_ids is not None:
|
| 75 |
+
input_shape = input_ids.size()
|
| 76 |
+
else:
|
| 77 |
+
input_shape = inputs_embeds.size()[:-1]
|
| 78 |
+
|
| 79 |
+
seq_length = input_shape[1]
|
| 80 |
+
|
| 81 |
+
if position_ids is None:
|
| 82 |
+
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
|
| 83 |
+
|
| 84 |
+
if inputs_embeds is None:
|
| 85 |
+
inputs_embeds = self.word_embeddings(input_ids)
|
| 86 |
+
|
| 87 |
+
embeddings = inputs_embeds
|
| 88 |
+
|
| 89 |
+
if self.position_embedding_type == "absolute":
|
| 90 |
+
position_embeddings = self.position_embeddings(position_ids)
|
| 91 |
+
embeddings += position_embeddings
|
| 92 |
+
embeddings = self.LayerNorm(embeddings)
|
| 93 |
+
embeddings = self.dropout(embeddings)
|
| 94 |
+
return embeddings
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
class BertSelfAttention(nn.Module):
|
| 98 |
+
def __init__(self, config, is_cross_attention):
|
| 99 |
+
super().__init__()
|
| 100 |
+
self.config = config
|
| 101 |
+
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
| 102 |
+
raise ValueError(
|
| 103 |
+
"The hidden size (%d) is not a multiple of the number of attention "
|
| 104 |
+
"heads (%d)" % (config.hidden_size, config.num_attention_heads)
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
self.num_attention_heads = config.num_attention_heads
|
| 108 |
+
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
| 109 |
+
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
| 110 |
+
|
| 111 |
+
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
| 112 |
+
if is_cross_attention:
|
| 113 |
+
self.key = nn.Linear(config.encoder_width, self.all_head_size)
|
| 114 |
+
self.value = nn.Linear(config.encoder_width, self.all_head_size)
|
| 115 |
+
else:
|
| 116 |
+
self.key = nn.Linear(config.hidden_size, self.all_head_size)
|
| 117 |
+
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
| 118 |
+
|
| 119 |
+
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
| 120 |
+
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
| 121 |
+
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
| 122 |
+
self.max_position_embeddings = config.max_position_embeddings
|
| 123 |
+
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
| 124 |
+
self.save_attention = False
|
| 125 |
+
|
| 126 |
+
def save_attn_gradients(self, attn_gradients):
|
| 127 |
+
self.attn_gradients = attn_gradients
|
| 128 |
+
|
| 129 |
+
def get_attn_gradients(self):
|
| 130 |
+
return self.attn_gradients
|
| 131 |
+
|
| 132 |
+
def save_attention_map(self, attention_map):
|
| 133 |
+
self.attention_map = attention_map
|
| 134 |
+
|
| 135 |
+
def get_attention_map(self):
|
| 136 |
+
return self.attention_map
|
| 137 |
+
|
| 138 |
+
def transpose_for_scores(self, x):
|
| 139 |
+
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
| 140 |
+
x = x.view(*new_x_shape)
|
| 141 |
+
return x.permute(0, 2, 1, 3)
|
| 142 |
+
|
| 143 |
+
def forward(
|
| 144 |
+
self,
|
| 145 |
+
hidden_states,
|
| 146 |
+
attention_mask=None,
|
| 147 |
+
head_mask=None,
|
| 148 |
+
encoder_hidden_states=None,
|
| 149 |
+
encoder_attention_mask=None,
|
| 150 |
+
past_key_value=None,
|
| 151 |
+
output_attentions=False,
|
| 152 |
+
):
|
| 153 |
+
mixed_query_layer = self.query(hidden_states)
|
| 154 |
+
|
| 155 |
+
# If this is instantiated as a cross-attention module, the keys
|
| 156 |
+
# and values come from an encoder; the attention mask needs to be
|
| 157 |
+
# such that the encoder's padding tokens are not attended to.
|
| 158 |
+
is_cross_attention = encoder_hidden_states is not None
|
| 159 |
+
|
| 160 |
+
if is_cross_attention:
|
| 161 |
+
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
|
| 162 |
+
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
|
| 163 |
+
attention_mask = encoder_attention_mask
|
| 164 |
+
elif past_key_value is not None:
|
| 165 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
| 166 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
| 167 |
+
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
|
| 168 |
+
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
|
| 169 |
+
else:
|
| 170 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
| 171 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
| 172 |
+
|
| 173 |
+
query_layer = self.transpose_for_scores(mixed_query_layer)
|
| 174 |
+
|
| 175 |
+
past_key_value = (key_layer, value_layer)
|
| 176 |
+
|
| 177 |
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
| 178 |
+
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
| 179 |
+
|
| 180 |
+
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
| 181 |
+
seq_length = hidden_states.size()[1]
|
| 182 |
+
position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
|
| 183 |
+
position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
|
| 184 |
+
distance = position_ids_l - position_ids_r
|
| 185 |
+
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
|
| 186 |
+
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
|
| 187 |
+
|
| 188 |
+
if self.position_embedding_type == "relative_key":
|
| 189 |
+
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
| 190 |
+
attention_scores = attention_scores + relative_position_scores
|
| 191 |
+
elif self.position_embedding_type == "relative_key_query":
|
| 192 |
+
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
| 193 |
+
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
|
| 194 |
+
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
|
| 195 |
+
|
| 196 |
+
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
| 197 |
+
if attention_mask is not None:
|
| 198 |
+
# Apply the attention mask is (precomputed for all layers in BertModel forward() function)
|
| 199 |
+
attention_scores = attention_scores + attention_mask
|
| 200 |
+
|
| 201 |
+
# Normalize the attention scores to probabilities.
|
| 202 |
+
attention_probs = nn.Softmax(dim=-1)(attention_scores)
|
| 203 |
+
|
| 204 |
+
if is_cross_attention and self.save_attention:
|
| 205 |
+
self.save_attention_map(attention_probs)
|
| 206 |
+
attention_probs.register_hook(self.save_attn_gradients)
|
| 207 |
+
|
| 208 |
+
# This is actually dropping out entire tokens to attend to, which might
|
| 209 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
| 210 |
+
attention_probs_dropped = self.dropout(attention_probs)
|
| 211 |
+
|
| 212 |
+
# Mask heads if we want to
|
| 213 |
+
if head_mask is not None:
|
| 214 |
+
attention_probs_dropped = attention_probs_dropped * head_mask
|
| 215 |
+
|
| 216 |
+
context_layer = torch.matmul(attention_probs_dropped, value_layer)
|
| 217 |
+
|
| 218 |
+
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
| 219 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
| 220 |
+
context_layer = context_layer.view(*new_context_layer_shape)
|
| 221 |
+
|
| 222 |
+
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
| 223 |
+
|
| 224 |
+
outputs = outputs + (past_key_value,)
|
| 225 |
+
return outputs
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
class BertSelfOutput(nn.Module):
|
| 229 |
+
def __init__(self, config):
|
| 230 |
+
super().__init__()
|
| 231 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
| 232 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 233 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 234 |
+
|
| 235 |
+
def forward(self, hidden_states, input_tensor):
|
| 236 |
+
hidden_states = self.dense(hidden_states)
|
| 237 |
+
hidden_states = self.dropout(hidden_states)
|
| 238 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
| 239 |
+
return hidden_states
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
class BertAttention(nn.Module):
|
| 243 |
+
def __init__(self, config, is_cross_attention=False):
|
| 244 |
+
super().__init__()
|
| 245 |
+
self.self = BertSelfAttention(config, is_cross_attention)
|
| 246 |
+
self.output = BertSelfOutput(config)
|
| 247 |
+
self.pruned_heads = set()
|
| 248 |
+
|
| 249 |
+
def prune_heads(self, heads):
|
| 250 |
+
if len(heads) == 0:
|
| 251 |
+
return
|
| 252 |
+
heads, index = find_pruneable_heads_and_indices(
|
| 253 |
+
heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
|
| 254 |
+
)
|
| 255 |
+
|
| 256 |
+
# Prune linear layers
|
| 257 |
+
self.self.query = prune_linear_layer(self.self.query, index)
|
| 258 |
+
self.self.key = prune_linear_layer(self.self.key, index)
|
| 259 |
+
self.self.value = prune_linear_layer(self.self.value, index)
|
| 260 |
+
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
|
| 261 |
+
|
| 262 |
+
# Update hyper params and store pruned heads
|
| 263 |
+
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
|
| 264 |
+
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
| 265 |
+
self.pruned_heads = self.pruned_heads.union(heads)
|
| 266 |
+
|
| 267 |
+
def forward(
|
| 268 |
+
self,
|
| 269 |
+
hidden_states,
|
| 270 |
+
attention_mask=None,
|
| 271 |
+
head_mask=None,
|
| 272 |
+
encoder_hidden_states=None,
|
| 273 |
+
encoder_attention_mask=None,
|
| 274 |
+
past_key_value=None,
|
| 275 |
+
output_attentions=False,
|
| 276 |
+
):
|
| 277 |
+
self_outputs = self.self(
|
| 278 |
+
hidden_states,
|
| 279 |
+
attention_mask,
|
| 280 |
+
head_mask,
|
| 281 |
+
encoder_hidden_states,
|
| 282 |
+
encoder_attention_mask,
|
| 283 |
+
past_key_value,
|
| 284 |
+
output_attentions,
|
| 285 |
+
)
|
| 286 |
+
attention_output = self.output(self_outputs[0], hidden_states)
|
| 287 |
+
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
| 288 |
+
return outputs
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
class BertIntermediate(nn.Module):
|
| 292 |
+
def __init__(self, config):
|
| 293 |
+
super().__init__()
|
| 294 |
+
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
| 295 |
+
if isinstance(config.hidden_act, str):
|
| 296 |
+
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
| 297 |
+
else:
|
| 298 |
+
self.intermediate_act_fn = config.hidden_act
|
| 299 |
+
|
| 300 |
+
def forward(self, hidden_states):
|
| 301 |
+
hidden_states = self.dense(hidden_states)
|
| 302 |
+
hidden_states = self.intermediate_act_fn(hidden_states)
|
| 303 |
+
return hidden_states
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
class BertOutput(nn.Module):
|
| 307 |
+
def __init__(self, config):
|
| 308 |
+
super().__init__()
|
| 309 |
+
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
| 310 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 311 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 312 |
+
|
| 313 |
+
def forward(self, hidden_states, input_tensor):
|
| 314 |
+
hidden_states = self.dense(hidden_states)
|
| 315 |
+
hidden_states = self.dropout(hidden_states)
|
| 316 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
| 317 |
+
return hidden_states
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
class BertLayer(nn.Module):
|
| 321 |
+
def __init__(self, config, layer_num):
|
| 322 |
+
super().__init__()
|
| 323 |
+
self.config = config
|
| 324 |
+
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
| 325 |
+
self.seq_len_dim = 1
|
| 326 |
+
self.attention = BertAttention(config)
|
| 327 |
+
self.layer_num = layer_num
|
| 328 |
+
if self.config.add_cross_attention:
|
| 329 |
+
self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention)
|
| 330 |
+
self.intermediate = BertIntermediate(config)
|
| 331 |
+
self.output = BertOutput(config)
|
| 332 |
+
|
| 333 |
+
def forward(
|
| 334 |
+
self,
|
| 335 |
+
hidden_states,
|
| 336 |
+
attention_mask=None,
|
| 337 |
+
head_mask=None,
|
| 338 |
+
encoder_hidden_states=None,
|
| 339 |
+
encoder_attention_mask=None,
|
| 340 |
+
past_key_value=None,
|
| 341 |
+
output_attentions=False,
|
| 342 |
+
mode=None,
|
| 343 |
+
):
|
| 344 |
+
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
| 345 |
+
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
| 346 |
+
self_attention_outputs = self.attention(
|
| 347 |
+
hidden_states,
|
| 348 |
+
attention_mask,
|
| 349 |
+
head_mask,
|
| 350 |
+
output_attentions=output_attentions,
|
| 351 |
+
past_key_value=self_attn_past_key_value,
|
| 352 |
+
)
|
| 353 |
+
attention_output = self_attention_outputs[0]
|
| 354 |
+
|
| 355 |
+
outputs = self_attention_outputs[1:-1]
|
| 356 |
+
present_key_value = self_attention_outputs[-1]
|
| 357 |
+
|
| 358 |
+
if mode=='multimodal':
|
| 359 |
+
assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
|
| 360 |
+
|
| 361 |
+
cross_attention_outputs = self.crossattention(
|
| 362 |
+
attention_output,
|
| 363 |
+
attention_mask,
|
| 364 |
+
head_mask,
|
| 365 |
+
encoder_hidden_states,
|
| 366 |
+
encoder_attention_mask,
|
| 367 |
+
output_attentions=output_attentions,
|
| 368 |
+
)
|
| 369 |
+
attention_output = cross_attention_outputs[0]
|
| 370 |
+
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
|
| 371 |
+
layer_output = apply_chunking_to_forward(
|
| 372 |
+
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
|
| 373 |
+
)
|
| 374 |
+
outputs = (layer_output,) + outputs
|
| 375 |
+
|
| 376 |
+
outputs = outputs + (present_key_value,)
|
| 377 |
+
|
| 378 |
+
return outputs
|
| 379 |
+
|
| 380 |
+
def feed_forward_chunk(self, attention_output):
|
| 381 |
+
intermediate_output = self.intermediate(attention_output)
|
| 382 |
+
layer_output = self.output(intermediate_output, attention_output)
|
| 383 |
+
return layer_output
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
class BertEncoder(nn.Module):
|
| 387 |
+
def __init__(self, config):
|
| 388 |
+
super().__init__()
|
| 389 |
+
self.config = config
|
| 390 |
+
self.layer = nn.ModuleList([BertLayer(config,i) for i in range(config.num_hidden_layers)])
|
| 391 |
+
self.gradient_checkpointing = False
|
| 392 |
+
|
| 393 |
+
def forward(
|
| 394 |
+
self,
|
| 395 |
+
hidden_states,
|
| 396 |
+
attention_mask=None,
|
| 397 |
+
head_mask=None,
|
| 398 |
+
encoder_hidden_states=None,
|
| 399 |
+
encoder_attention_mask=None,
|
| 400 |
+
past_key_values=None,
|
| 401 |
+
use_cache=None,
|
| 402 |
+
output_attentions=False,
|
| 403 |
+
output_hidden_states=False,
|
| 404 |
+
return_dict=True,
|
| 405 |
+
mode='multimodal',
|
| 406 |
+
):
|
| 407 |
+
all_hidden_states = () if output_hidden_states else None
|
| 408 |
+
all_self_attentions = () if output_attentions else None
|
| 409 |
+
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
| 410 |
+
|
| 411 |
+
next_decoder_cache = () if use_cache else None
|
| 412 |
+
|
| 413 |
+
for i in range(self.config.num_hidden_layers):
|
| 414 |
+
layer_module = self.layer[i]
|
| 415 |
+
if output_hidden_states:
|
| 416 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 417 |
+
|
| 418 |
+
layer_head_mask = head_mask[i] if head_mask is not None else None
|
| 419 |
+
past_key_value = past_key_values[i] if past_key_values is not None else None
|
| 420 |
+
|
| 421 |
+
if self.gradient_checkpointing and self.training:
|
| 422 |
+
|
| 423 |
+
if use_cache:
|
| 424 |
+
logger.warn(
|
| 425 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
| 426 |
+
)
|
| 427 |
+
use_cache = False
|
| 428 |
+
|
| 429 |
+
def create_custom_forward(module):
|
| 430 |
+
def custom_forward(*inputs):
|
| 431 |
+
return module(*inputs, past_key_value, output_attentions)
|
| 432 |
+
|
| 433 |
+
return custom_forward
|
| 434 |
+
|
| 435 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
| 436 |
+
create_custom_forward(layer_module),
|
| 437 |
+
hidden_states,
|
| 438 |
+
attention_mask,
|
| 439 |
+
layer_head_mask,
|
| 440 |
+
encoder_hidden_states,
|
| 441 |
+
encoder_attention_mask,
|
| 442 |
+
mode=mode,
|
| 443 |
+
)
|
| 444 |
+
else:
|
| 445 |
+
layer_outputs = layer_module(
|
| 446 |
+
hidden_states,
|
| 447 |
+
attention_mask,
|
| 448 |
+
layer_head_mask,
|
| 449 |
+
encoder_hidden_states,
|
| 450 |
+
encoder_attention_mask,
|
| 451 |
+
past_key_value,
|
| 452 |
+
output_attentions,
|
| 453 |
+
mode=mode,
|
| 454 |
+
)
|
| 455 |
+
|
| 456 |
+
hidden_states = layer_outputs[0]
|
| 457 |
+
if use_cache:
|
| 458 |
+
next_decoder_cache += (layer_outputs[-1],)
|
| 459 |
+
if output_attentions:
|
| 460 |
+
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
| 461 |
+
|
| 462 |
+
if output_hidden_states:
|
| 463 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 464 |
+
|
| 465 |
+
if not return_dict:
|
| 466 |
+
return tuple(
|
| 467 |
+
v
|
| 468 |
+
for v in [
|
| 469 |
+
hidden_states,
|
| 470 |
+
next_decoder_cache,
|
| 471 |
+
all_hidden_states,
|
| 472 |
+
all_self_attentions,
|
| 473 |
+
all_cross_attentions,
|
| 474 |
+
]
|
| 475 |
+
if v is not None
|
| 476 |
+
)
|
| 477 |
+
return BaseModelOutputWithPastAndCrossAttentions(
|
| 478 |
+
last_hidden_state=hidden_states,
|
| 479 |
+
past_key_values=next_decoder_cache,
|
| 480 |
+
hidden_states=all_hidden_states,
|
| 481 |
+
attentions=all_self_attentions,
|
| 482 |
+
cross_attentions=all_cross_attentions,
|
| 483 |
+
)
|
| 484 |
+
|
| 485 |
+
|
| 486 |
+
class BertPooler(nn.Module):
|
| 487 |
+
def __init__(self, config):
|
| 488 |
+
super().__init__()
|
| 489 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
| 490 |
+
self.activation = nn.Tanh()
|
| 491 |
+
|
| 492 |
+
def forward(self, hidden_states):
|
| 493 |
+
# We "pool" the model by simply taking the hidden state corresponding
|
| 494 |
+
# to the first token.
|
| 495 |
+
first_token_tensor = hidden_states[:, 0]
|
| 496 |
+
pooled_output = self.dense(first_token_tensor)
|
| 497 |
+
pooled_output = self.activation(pooled_output)
|
| 498 |
+
return pooled_output
|
| 499 |
+
|
| 500 |
+
|
| 501 |
+
class BertPredictionHeadTransform(nn.Module):
|
| 502 |
+
def __init__(self, config):
|
| 503 |
+
super().__init__()
|
| 504 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
| 505 |
+
if isinstance(config.hidden_act, str):
|
| 506 |
+
self.transform_act_fn = ACT2FN[config.hidden_act]
|
| 507 |
+
else:
|
| 508 |
+
self.transform_act_fn = config.hidden_act
|
| 509 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 510 |
+
|
| 511 |
+
def forward(self, hidden_states):
|
| 512 |
+
hidden_states = self.dense(hidden_states)
|
| 513 |
+
hidden_states = self.transform_act_fn(hidden_states)
|
| 514 |
+
hidden_states = self.LayerNorm(hidden_states)
|
| 515 |
+
return hidden_states
|
| 516 |
+
|
| 517 |
+
|
| 518 |
+
class BertLMPredictionHead(nn.Module):
|
| 519 |
+
def __init__(self, config):
|
| 520 |
+
super().__init__()
|
| 521 |
+
self.transform = BertPredictionHeadTransform(config)
|
| 522 |
+
|
| 523 |
+
# The output weights are the same as the input embeddings, but there is
|
| 524 |
+
# an output-only bias for each token.
|
| 525 |
+
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 526 |
+
|
| 527 |
+
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
| 528 |
+
|
| 529 |
+
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
| 530 |
+
self.decoder.bias = self.bias
|
| 531 |
+
|
| 532 |
+
def forward(self, hidden_states):
|
| 533 |
+
hidden_states = self.transform(hidden_states)
|
| 534 |
+
hidden_states = self.decoder(hidden_states)
|
| 535 |
+
return hidden_states
|
| 536 |
+
|
| 537 |
+
|
| 538 |
+
class BertOnlyMLMHead(nn.Module):
|
| 539 |
+
def __init__(self, config):
|
| 540 |
+
super().__init__()
|
| 541 |
+
self.predictions = BertLMPredictionHead(config)
|
| 542 |
+
|
| 543 |
+
def forward(self, sequence_output):
|
| 544 |
+
prediction_scores = self.predictions(sequence_output)
|
| 545 |
+
return prediction_scores
|
| 546 |
+
|
| 547 |
+
|
| 548 |
+
class BertPreTrainedModel(PreTrainedModel):
|
| 549 |
+
"""
|
| 550 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
| 551 |
+
models.
|
| 552 |
+
"""
|
| 553 |
+
|
| 554 |
+
config_class = BertConfig
|
| 555 |
+
base_model_prefix = "bert"
|
| 556 |
+
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
| 557 |
+
|
| 558 |
+
def _init_weights(self, module):
|
| 559 |
+
""" Initialize the weights """
|
| 560 |
+
if isinstance(module, (nn.Linear, nn.Embedding)):
|
| 561 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
| 562 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
| 563 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
| 564 |
+
elif isinstance(module, nn.LayerNorm):
|
| 565 |
+
module.bias.data.zero_()
|
| 566 |
+
module.weight.data.fill_(1.0)
|
| 567 |
+
if isinstance(module, nn.Linear) and module.bias is not None:
|
| 568 |
+
module.bias.data.zero_()
|
| 569 |
+
|
| 570 |
+
|
| 571 |
+
class BertModel(BertPreTrainedModel):
    """
    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
    cross-attention is added between the self-attention layers, following the architecture described in `Attention is
    all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.

    To behave as a decoder the model needs to be initialized with the :obj:`is_decoder`
    argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
    input to the forward pass.
    """

    def __init__(self, config, add_pooling_layer=True):
        # config: BertConfig instance; add_pooling_layer: when False, ``self.pooler``
        # is None and ``forward`` returns pooled_output=None.
        super().__init__(config)
        self.config = config

        self.embeddings = BertEmbeddings(config)

        self.encoder = BertEncoder(config)

        self.pooler = BertPooler(config) if add_pooling_layer else None

        self.init_weights()

    def get_input_embeddings(self):
        # Standard HF accessor: the word-embedding table used for input_ids lookup.
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        # Standard HF mutator: replace the word-embedding table (e.g. for weight tying).
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor:
        """
        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.

        Arguments:
            attention_mask (:obj:`torch.Tensor`):
                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
            input_shape (:obj:`Tuple[int]`):
                The shape of the input to the model.
            device: (:obj:`torch.device`):
                The device of the input to the model.
            is_decoder (:obj:`bool`):
                When True and the mask is 2D, a lower-triangular causal mask is combined
                with the padding mask so each position only attends to earlier positions.

        Returns:
            :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
        """
        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        if attention_mask.dim() == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        elif attention_mask.dim() == 2:
            # Provided a padding mask of dimensions [batch_size, seq_length]
            # - if the model is a decoder, apply a causal mask in addition to the padding mask
            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
            if is_decoder:
                batch_size, seq_length = input_shape

                seq_ids = torch.arange(seq_length, device=device)
                # Lower-triangular boolean mask: position i may attend to positions <= i.
                causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
                # in case past_key_values are used we need to add a prefix ones mask to the causal mask
                # causal and attention masks must have same type with pytorch version < 1.3
                causal_mask = causal_mask.to(attention_mask.dtype)

                if causal_mask.shape[1] < attention_mask.shape[1]:
                    # attention_mask covers cached (past) tokens too; prepend an all-ones
                    # prefix so the causal mask spans the full key length.
                    prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
                    causal_mask = torch.cat(
                        [
                            torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
                            causal_mask,
                        ],
                        axis=-1,
                    )

                # Combine causal and padding masks by elementwise product.
                extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
            else:
                extended_attention_mask = attention_mask[:, None, None, :]
        else:
            raise ValueError(
                "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
                    input_shape, attention_mask.shape
                )
            )

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        return extended_attention_mask

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        is_decoder=False,
        mode='multimodal',
    ):
        r"""
        encoder_hidden_states  (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
        use_cache (:obj:`bool`, `optional`):
            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
            decoding (see :obj:`past_key_values`).
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Caching only makes sense for autoregressive (decoder) use.
        if is_decoder:
            use_cache = use_cache if use_cache is not None else self.config.use_cache
        else:
            use_cache = False

        # Exactly one of input_ids / inputs_embeds / encoder_embeds must be provided.
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            batch_size, seq_length = input_shape
            device = input_ids.device
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            batch_size, seq_length = input_shape
            device = inputs_embeds.device
        elif encoder_embeds is not None:
            input_shape = encoder_embeds.size()[:-1]
            batch_size, seq_length = input_shape
            device = encoder_embeds.device
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds")

        # past_key_values_length
        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

        if attention_mask is None:
            # Default: attend to everything, including cached past tokens.
            attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape,
                                                                                 device, is_decoder)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if encoder_hidden_states is not None:
            # A list of encoder states (one per cross-attention source) is also accepted.
            if type(encoder_hidden_states) == list:
                encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
            else:
                encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)

            if type(encoder_attention_mask) == list:
                encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
            elif encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
            else:
                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        if encoder_embeds is None:
            embedding_output = self.embeddings(
                input_ids=input_ids,
                position_ids=position_ids,
                inputs_embeds=inputs_embeds,
                past_key_values_length=past_key_values_length,
            )
        else:
            # encoder_embeds bypasses the embedding layer entirely.
            embedding_output = encoder_embeds

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            mode=mode,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )
|
| 808 |
+
|
| 809 |
+
|
| 810 |
+
|
| 811 |
+
class BertLMHeadModel(BertPreTrainedModel):
    """BERT with a language-modeling head on top, used here as an autoregressive
    (left-to-right) decoder: the loss shifts predictions/labels by one token."""

    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]

    def __init__(self, config):
        super().__init__(config)

        # No pooler: only per-token hidden states feed the LM head.
        self.bert = BertModel(config, add_pooling_layer=False)
        self.cls = BertOnlyMLMHead(config)

        self.init_weights()

    def get_output_embeddings(self):
        # Standard HF accessor used for input/output embedding weight tying.
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        self.cls.predictions.decoder = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        return_logits=False,
        is_decoder=True,
        reduction='mean',
        mode='multimodal',
    ):
        r"""
        encoder_hidden_states  (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
            ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
            ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]``
        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
        use_cache (:obj:`bool`, `optional`):
            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
            decoding (see :obj:`past_key_values`).
        Returns:
        Example::
            >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
            >>> import torch
            >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
            >>> config = BertConfig.from_pretrained("bert-base-cased")
            >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
            >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
            >>> outputs = model(**inputs)
            >>> prediction_logits = outputs.logits
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if labels is not None:
            # Training with labels: the cache is useless, disable it.
            use_cache = False

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            is_decoder=is_decoder,
            mode=mode,
        )

        sequence_output = outputs[0]
        prediction_scores = self.cls(sequence_output)

        if return_logits:
            # Early exit: logits for all positions except the last (next-token scores).
            return prediction_scores[:, :-1, :].contiguous()

        lm_loss = None
        if labels is not None:
            # we are doing next-token prediction; shift prediction scores and input ids by one
            shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
            labels = labels[:, 1:].contiguous()
            loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
            lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
            if reduction=='none':
                # Per-sequence loss: sum the per-token losses within each sample.
                lm_loss = lm_loss.view(prediction_scores.size(0),-1).sum(1)

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((lm_loss,) + output) if lm_loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=lm_loss,
            logits=prediction_scores,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )

    def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
        # Hook consumed by HF ``generate``: assembles the kwargs for each decoding step.
        input_shape = input_ids.shape
        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_shape)

        # cut decoder_input_ids if past is used
        if past is not None:
            input_ids = input_ids[:, -1:]

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "past_key_values": past,
            "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
            "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
            "is_decoder": True,
        }

    def _reorder_cache(self, past, beam_idx):
        # Reorder cached key/value states to follow the surviving beams.
        reordered_past = ()
        for layer_past in past:
            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
        return reordered_past
|
was-node-suite-comfyui/modules/BLIP/blip_module.py
ADDED
|
@@ -0,0 +1,423 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'''
|
| 2 |
+
* Copyright (c) 2022, salesforce.com, inc.
|
| 3 |
+
* All rights reserved.
|
| 4 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
| 6 |
+
* By Junnan Li
|
| 7 |
+
'''
|
| 8 |
+
import warnings
|
| 9 |
+
warnings.filterwarnings("ignore")
|
| 10 |
+
|
| 11 |
+
from .blip_vit import VisionTransformer, interpolate_pos_embed
|
| 12 |
+
from .blip_med import BertConfig, BertModel, BertLMHeadModel
|
| 13 |
+
from transformers import BertTokenizer
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
from torch import nn
|
| 17 |
+
import torch.nn.functional as F
|
| 18 |
+
|
| 19 |
+
import os
|
| 20 |
+
from urllib.parse import urlparse
|
| 21 |
+
from timm.models.hub import download_cached_file
|
| 22 |
+
import numpy as np
|
| 23 |
+
|
| 24 |
+
from pathlib import Path
|
| 25 |
+
LOCAL_PATH = os.path.dirname(os.path.abspath(__file__))
|
| 26 |
+
|
| 27 |
+
# BLIP
|
| 28 |
+
|
| 29 |
+
class BLIP_Base(nn.Module):
    """BLIP feature extractor: a ViT image encoder plus a BERT text encoder that can
    run in unimodal ('image'/'text') or cross-attention ('multimodal') mode."""

    def __init__(self,
                 med_config = Path(LOCAL_PATH, 'blip_configs/med_config.json'),
                 image_size = 224,
                 vit = 'base',
                 vit_grad_ckpt = False,
                 vit_ckpt_layer = 0,
                 ):
        """
        Args:
            med_config (str): path for the mixture of encoder-decoder model's configuration file
            image_size (int): input image size
            vit (str): model size of vision transformer ('base' or 'large')
        """
        super().__init__()

        self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
        self.tokenizer = init_tokenizer()
        # Cross-attention width in the text encoder must match the ViT's output width.
        med_config = BertConfig.from_json_file(med_config)
        med_config.encoder_width = vision_width
        self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)


    def forward(self, image, caption, mode):
        """Return features for ``image``/``caption`` depending on ``mode``.

        mode='image'      -> ViT patch embeddings for the image.
        mode='text'       -> BERT last hidden states for the caption (text-only).
        mode='multimodal' -> text hidden states with cross-attention over image patches.
        """
        assert mode in ['image', 'text', 'multimodal'], "mode parameter must be image, text, or multimodal"
        text = self.tokenizer(caption, return_tensors="pt").to(image.device)

        if mode=='image':
            # return image features
            image_embeds = self.visual_encoder(image)
            return image_embeds

        elif mode=='text':
            # return text features
            text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
                                            return_dict = True, mode = 'text')
            return text_output.last_hidden_state

        elif mode=='multimodal':
            # return multimodel features
            image_embeds = self.visual_encoder(image)
            image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)

            # Replace [CLS] with the [ENC] token to signal encoder (matching) mode.
            text.input_ids[:,0] = self.tokenizer.enc_token_id
            output = self.text_encoder(text.input_ids,
                                       attention_mask = text.attention_mask,
                                       encoder_hidden_states = image_embeds,
                                       encoder_attention_mask = image_atts,
                                       return_dict = True,
                                      )
            return output.last_hidden_state
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class BLIP_Decoder(nn.Module):
    """BLIP image-captioning model: a ViT image encoder plus a BERT LM-head decoder
    that generates captions conditioned on image patches via cross-attention."""

    def __init__(self,
                 med_config = Path(LOCAL_PATH, 'blip_configs/med_config.json'),
                 image_size = 384,
                 vit = 'base',
                 vit_grad_ckpt = False,
                 vit_ckpt_layer = 0,
                 prompt = 'a picture of ',
                 ):
        """
        Args:
            med_config (str): path for the mixture of encoder-decoder model's configuration file
            image_size (int): input image size
            vit (str): model size of vision transformer ('base' or 'large')
        """
        super().__init__()

        self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
        self.tokenizer = init_tokenizer()
        med_config = BertConfig.from_json_file(med_config)
        med_config.encoder_width = vision_width
        self.text_decoder = BertLMHeadModel(config=med_config)

        self.prompt = prompt
        # Number of prompt tokens (minus the trailing [SEP]); used to mask the
        # prompt out of the LM loss and to strip it from generated captions.
        self.prompt_length = len(self.tokenizer(self.prompt).input_ids)-1


    def forward(self, image, caption):
        """Compute the captioning LM loss for a batch of (image, caption) pairs."""
        image_embeds = self.visual_encoder(image)
        image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)

        text = self.tokenizer(caption, padding='longest', truncation=True, max_length=40, return_tensors="pt").to(image.device)

        # Replace [CLS] with [DEC] (bos) to put the text model in decoding mode.
        text.input_ids[:,0] = self.tokenizer.bos_token_id

        # Ignore padding and the prompt prefix in the loss (-100 = ignore_index).
        decoder_targets = text.input_ids.masked_fill(text.input_ids == self.tokenizer.pad_token_id, -100)
        decoder_targets[:,:self.prompt_length] = -100

        decoder_output = self.text_decoder(text.input_ids,
                                           attention_mask = text.attention_mask,
                                           encoder_hidden_states = image_embeds,
                                           encoder_attention_mask = image_atts,
                                           labels = decoder_targets,
                                           return_dict = True,
                                          )
        loss_lm = decoder_output.loss

        return loss_lm

    def generate(self, image, sample=False, num_beams=3, max_length=30, min_length=10, top_p=0.9, repetition_penalty=1.0):
        """Generate one caption per image.

        sample=True uses nucleus sampling (top_p); sample=False uses beam search
        (num_beams). Returns a list of caption strings with the prompt stripped.
        """
        image_embeds = self.visual_encoder(image)

        if not sample:
            # Beam search expects the encoder states repeated once per beam.
            image_embeds = image_embeds.repeat_interleave(num_beams,dim=0)

        image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
        model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask":image_atts}

        # Seed every sequence with the prompt, starting from the bos token and
        # dropping the trailing [SEP] so generation continues the prompt.
        prompt = [self.prompt] * image.size(0)
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(image.device)
        input_ids[:,0] = self.tokenizer.bos_token_id
        input_ids = input_ids[:, :-1]

        if sample:
            #nucleus sampling
            outputs = self.text_decoder.generate(input_ids=input_ids,
                                                  max_length=max_length,
                                                  min_length=min_length,
                                                  do_sample=True,
                                                  top_p=top_p,
                                                  num_return_sequences=1,
                                                  eos_token_id=self.tokenizer.sep_token_id,
                                                  pad_token_id=self.tokenizer.pad_token_id,
                                                  repetition_penalty=1.1,
                                                  **model_kwargs)
        else:
            #beam search
            outputs = self.text_decoder.generate(input_ids=input_ids,
                                                  max_length=max_length,
                                                  min_length=min_length,
                                                  num_beams=num_beams,
                                                  eos_token_id=self.tokenizer.sep_token_id,
                                                  pad_token_id=self.tokenizer.pad_token_id,
                                                  repetition_penalty=repetition_penalty,
                                                  **model_kwargs)

        captions = []
        for output in outputs:
            caption = self.tokenizer.decode(output, skip_special_tokens=True)
            # Strip the fixed prompt prefix from the decoded string.
            captions.append(caption[len(self.prompt):])
        return captions
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def blip_decoder(pretrained='', **kwargs):
    """Factory for :class:`BLIP_Decoder`; loads checkpoint weights when ``pretrained`` is given."""
    model = BLIP_Decoder(**kwargs)
    if not pretrained:
        return model
    model, load_msg = load_checkpoint(model, pretrained)
    # The captioning model must resolve every parameter from the checkpoint.
    assert len(load_msg.missing_keys) == 0
    return model
|
| 184 |
+
|
| 185 |
+
def blip_feature_extractor(pretrained='', **kwargs):
    """Factory for :class:`BLIP_Base`; loads checkpoint weights when ``pretrained`` is given."""
    model = BLIP_Base(**kwargs)
    if not pretrained:
        return model
    model, load_msg = load_checkpoint(model, pretrained)
    # The feature extractor must resolve every parameter from the checkpoint.
    assert len(load_msg.missing_keys) == 0
    return model
|
| 191 |
+
|
| 192 |
+
def init_tokenizer():
    """Build the ``bert-base-uncased`` tokenizer extended with BLIP's special tokens."""
    tok = BertTokenizer.from_pretrained('bert-base-uncased')
    # [DEC] marks decoding (caption generation) mode; [ENC] marks encoding mode.
    tok.add_special_tokens({'bos_token':'[DEC]'})
    tok.add_special_tokens({'additional_special_tokens':['[ENC]']})
    # Cache the [ENC] token id so callers can overwrite input_ids[:, 0] directly.
    tok.enc_token_id = tok.additional_special_tokens_ids[0]
    return tok
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):
    """Instantiate the BLIP vision-transformer backbone.

    Args:
        vit (str): 'base' (ViT-B/16, width 768, depth 12) or
            'large' (ViT-L/16, width 1024, depth 24).
        image_size (int): square input image resolution.
        use_grad_checkpointing (bool): enable gradient checkpointing in the ViT.
        ckpt_layer (int): layer index from which checkpointing applies.
        drop_path_rate (float): stochastic-depth rate. When 0 (the default),
            'large' falls back to 0.1 while 'base' keeps 0.

    Returns:
        tuple: (visual_encoder, vision_width).
    """
    assert vit in ['base', 'large'], "vit parameter must be base or large"
    if vit=='base':
        vision_width = 768
        visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12,
                                           num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
                                           # `0 or drop_path_rate` in the original was just `drop_path_rate`.
                                           drop_path_rate=drop_path_rate
                                          )
    elif vit=='large':
        vision_width = 1024
        visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24,
                                           num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
                                           # BUG FIX: the original wrote `0.1 or drop_path_rate`, which always
                                           # evaluates to 0.1 and silently discarded the caller's value. Keep the
                                           # 0.1 default but honor a non-zero caller-supplied rate.
                                           drop_path_rate=drop_path_rate or 0.1
                                          )
    return visual_encoder, vision_width
|
| 216 |
+
|
| 217 |
+
def is_url(url_or_filename):
    """Return True when the string is an http(s) URL rather than a local file path."""
    scheme = urlparse(url_or_filename).scheme
    return scheme == "http" or scheme == "https"
|
| 220 |
+
|
| 221 |
+
def load_checkpoint(model,url_or_filename):
    """Load BLIP weights into ``model`` from a URL or local file.

    Resizes ViT positional embeddings to match the model and drops any
    checkpoint tensor whose shape disagrees with the model, then loads
    non-strictly. Returns ``(model, msg)`` where ``msg`` is the
    ``load_state_dict`` result (missing/unexpected keys).
    """
    if is_url(url_or_filename):
        cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
        # SECURITY NOTE: torch.load unpickles arbitrary data; only load
        # checkpoints from trusted sources (no weights_only guard here).
        checkpoint = torch.load(cached_file, map_location='cpu')
    elif os.path.isfile(url_or_filename):
        checkpoint = torch.load(url_or_filename, map_location='cpu')
    else:
        raise RuntimeError('checkpoint url or path is invalid')

    state_dict = checkpoint['model']

    # Interpolate positional embeddings so checkpoints trained at a different
    # image resolution still load.
    state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)
    if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
        # Momentum encoder (present in some BLIP variants) gets the same treatment.
        state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
                                                                         model.visual_encoder_m)
    # Drop shape-mismatched tensors so load_state_dict(strict=False) succeeds.
    for key in model.state_dict().keys():
        if key in state_dict.keys():
            if state_dict[key].shape!=model.state_dict()[key].shape:
                del state_dict[key]

    msg = model.load_state_dict(state_dict,strict=False)
    print('load checkpoint from %s'%url_or_filename)
    return model,msg
|
| 244 |
+
|
| 245 |
+
# BLIP VQA
|
| 246 |
+
|
| 247 |
+
class BLIP_VQA(nn.Module):
    """BLIP model for Visual Question Answering.

    A ViT visual encoder feeds image features into a BERT-based question
    encoder (via cross-attention); a BERT LM-head decoder either computes a
    training loss, generates answers with beam search, or ranks candidate
    answers at inference time.
    """
    def __init__(self,
                 med_config = Path(LOCAL_PATH, 'blip_configs/med_config.json'),
                 image_size = 480,
                 vit = 'base',
                 vit_grad_ckpt = False,
                 vit_ckpt_layer = 0,
                 ):
        """
        Args:
            med_config (str): path for the mixture of encoder-decoder model's configuration file
            image_size (int): input image size
            vit (str): model size of vision transformer
            vit_grad_ckpt (bool): enable gradient checkpointing in the ViT
            vit_ckpt_layer (int): number of final ViT layers to checkpoint
        """
        super().__init__()

        self.visual_encoder, vision_width = create_vit(vit, image_size, vit_grad_ckpt, vit_ckpt_layer, drop_path_rate=0.1)
        self.tokenizer = init_tokenizer()

        encoder_config = BertConfig.from_json_file(med_config)
        # Cross-attention width must match the visual feature dimension.
        encoder_config.encoder_width = vision_width
        self.text_encoder = BertModel(config=encoder_config, add_pooling_layer=False)

        decoder_config = BertConfig.from_json_file(med_config)
        self.text_decoder = BertLMHeadModel(config=decoder_config)


    def forward(self, image, question, answer=None, n=None, weights=None, train=True, inference='rank', k_test=128):
        """Compute the VQA training loss, or answer questions at inference.

        Args:
            image: batch of preprocessed images.
            question: question strings; tokenized internally (max 35 tokens).
            answer: answer strings when train=True. NOTE(review): in the
                inference 'rank' path this is accessed as `answer.input_ids`,
                so it is presumably a pre-tokenized candidate batch — confirm
                against callers.
            n: per-question count of ground-truth answers (train only).
            weights: per-answer loss weight (train only).
            train (bool): select training-loss vs inference behavior.
            inference (str): 'rank' or 'generate' (used when train=False).
            k_test (int): candidates kept per question when ranking.
        """
        image_embeds = self.visual_encoder(image)
        image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)

        question = self.tokenizer(question, padding='longest', truncation=True, max_length=35,
                                  return_tensors="pt").to(image.device)
        # Replace [CLS] with the encoder token so BERT runs in encoder mode.
        question.input_ids[:,0] = self.tokenizer.enc_token_id

        if train:
            '''
            n: number of answers for each question
            weights: weight for each answer
            '''
            answer = self.tokenizer(answer, padding='longest', return_tensors="pt").to(image.device)
            answer.input_ids[:,0] = self.tokenizer.bos_token_id
            # -100 is the LM loss ignore index: mask padding positions out.
            answer_targets = answer.input_ids.masked_fill(answer.input_ids == self.tokenizer.pad_token_id, -100)

            question_output = self.text_encoder(question.input_ids,
                                                attention_mask = question.attention_mask,
                                                encoder_hidden_states = image_embeds,
                                                encoder_attention_mask = image_atts,
                                                return_dict = True)

            # Duplicate each question's states once per ground-truth answer.
            # NOTE(review): the loop target shadows the parameter `n`; this
            # works only because `n` is not used after the loop.
            question_states = []
            question_atts = []
            for b, n in enumerate(n):
                question_states += [question_output.last_hidden_state[b]]*n
                question_atts += [question.attention_mask[b]]*n
            question_states = torch.stack(question_states,0)
            question_atts = torch.stack(question_atts,0)

            answer_output = self.text_decoder(answer.input_ids,
                                              attention_mask = answer.attention_mask,
                                              encoder_hidden_states = question_states,
                                              encoder_attention_mask = question_atts,
                                              labels = answer_targets,
                                              return_dict = True,
                                              reduction = 'none',
                                              )

            # Weighted per-answer losses, averaged over the image batch size.
            loss = weights * answer_output.loss
            loss = loss.sum()/image.size(0)

            return loss


        else:
            question_output = self.text_encoder(question.input_ids,
                                                attention_mask = question.attention_mask,
                                                encoder_hidden_states = image_embeds,
                                                encoder_attention_mask = image_atts,
                                                return_dict = True)

            if inference=='generate':
                num_beams = 3
                # Expand question states so each beam sees its own copy.
                question_states = question_output.last_hidden_state.repeat_interleave(num_beams,dim=0)
                question_atts = torch.ones(question_states.size()[:-1],dtype=torch.long).to(question_states.device)
                model_kwargs = {"encoder_hidden_states": question_states, "encoder_attention_mask":question_atts}

                bos_ids = torch.full((image.size(0),1),fill_value=self.tokenizer.bos_token_id,device=image.device)

                outputs = self.text_decoder.generate(input_ids=bos_ids,
                                                     max_length=10,
                                                     min_length=1,
                                                     num_beams=num_beams,
                                                     eos_token_id=self.tokenizer.sep_token_id,
                                                     pad_token_id=self.tokenizer.pad_token_id,
                                                     **model_kwargs)

                answers = []
                for output in outputs:
                    answer = self.tokenizer.decode(output, skip_special_tokens=True)
                    answers.append(answer)
                return answers

            elif inference=='rank':
                # NOTE(review): implicitly returns None for any other
                # `inference` value.
                max_ids = self.rank_answer(question_output.last_hidden_state, question.attention_mask,
                                           answer.input_ids, answer.attention_mask, k_test)
                return max_ids



    def rank_answer(self, question_states, question_atts, answer_ids, answer_atts, k):
        """For each question, pick the best of `k` candidate answers.

        First scores candidates by the probability of their first token, keeps
        the top-k, then re-scores those by full-sequence log-likelihood under
        the decoder. Returns, per question, the index (into the candidate
        list) of the highest-scoring answer.
        """
        num_ques = question_states.size(0)
        start_ids = answer_ids[0,0].repeat(num_ques,1) # bos token

        start_output = self.text_decoder(start_ids,
                                         encoder_hidden_states = question_states,
                                         encoder_attention_mask = question_atts,
                                         return_dict = True,
                                         reduction = 'none')
        logits = start_output.logits[:,0,:] # first token's logit

        # topk_probs: top-k probability
        # topk_ids: [num_question, k]
        answer_first_token = answer_ids[:,1]
        prob_first_token = F.softmax(logits,dim=1).index_select(dim=1, index=answer_first_token)
        topk_probs, topk_ids = prob_first_token.topk(k,dim=1)

        # answer input: [num_question*k, answer_len]
        input_ids = []
        input_atts = []
        for b, topk_id in enumerate(topk_ids):
            input_ids.append(answer_ids.index_select(dim=0, index=topk_id))
            input_atts.append(answer_atts.index_select(dim=0, index=topk_id))
        input_ids = torch.cat(input_ids,dim=0)
        input_atts = torch.cat(input_atts,dim=0)

        # Mask padding out of the LM targets (-100 = ignore index).
        targets_ids = input_ids.masked_fill(input_ids == self.tokenizer.pad_token_id, -100)

        # repeat encoder's output for top-k answers
        question_states = tile(question_states, 0, k)
        question_atts = tile(question_atts, 0, k)

        output = self.text_decoder(input_ids,
                                   attention_mask = input_atts,
                                   encoder_hidden_states = question_states,
                                   encoder_attention_mask = question_atts,
                                   labels = targets_ids,
                                   return_dict = True,
                                   reduction = 'none')

        # With reduction='none' the decoder returns per-sequence NLL, so the
        # negation is the per-sequence log-probability.
        log_probs_sum = -output.loss
        log_probs_sum = log_probs_sum.view(num_ques,k)

        max_topk_ids = log_probs_sum.argmax(dim=1)
        # Advanced indexing: `max_topk_ids>=0` is an all-True mask, so this
        # gathers topk_ids[row, max_topk_ids[row]] for every row.
        max_ids = topk_ids[max_topk_ids>=0,max_topk_ids]

        return max_ids
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
def blip_vqa(pretrained='', **kwargs):
    """Factory for BLIP_VQA; optionally loads pretrained weights.

    Args:
        pretrained (str): checkpoint URL/path; empty string skips loading.
        **kwargs: forwarded to the BLIP_VQA constructor.
    """
    model = BLIP_VQA(**kwargs)
    if not pretrained:
        return model
    model, msg = load_checkpoint(model, pretrained)
    # assert(len(msg.missing_keys)==0)
    return model
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
def tile(x, dim, n_tile):
    """Repeat each slice of `x` along `dim` n_tile times consecutively,
    e.g. [a, b] with n_tile=2 becomes [a, a, b, b]."""
    size = x.size(dim)
    reps = [1] * x.dim()
    reps[dim] = n_tile
    tiled = x.repeat(*reps)
    # repeat() lays copies out as [a, b, a, b, ...]; build the permutation
    # that groups all copies of the same slice together.
    order = np.concatenate([size * np.arange(n_tile) + i for i in range(size)])
    order_index = torch.LongTensor(order)
    return torch.index_select(tiled, dim, order_index.to(x.device))
|
| 422 |
+
|
| 423 |
+
|
was-node-suite-comfyui/modules/BLIP/blip_module_license.txt
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Copyright (c) 2022, Salesforce.com, Inc.
|
| 2 |
+
All rights reserved.
|
| 3 |
+
|
| 4 |
+
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
|
| 5 |
+
|
| 6 |
+
* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
|
| 7 |
+
|
| 8 |
+
* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
|
| 9 |
+
|
| 10 |
+
* Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
|
| 11 |
+
|
| 12 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
was-node-suite-comfyui/modules/BLIP/blip_vit.py
ADDED
|
@@ -0,0 +1,305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'''
|
| 2 |
+
* Copyright (c) 2022, salesforce.com, inc.
|
| 3 |
+
* All rights reserved.
|
| 4 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
| 6 |
+
* By Junnan Li
|
| 7 |
+
* Based on timm code base
|
| 8 |
+
* https://github.com/rwightman/pytorch-image-models/tree/master/timm
|
| 9 |
+
'''
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
from functools import partial
|
| 15 |
+
|
| 16 |
+
from timm.models.vision_transformer import _cfg, PatchEmbed
|
| 17 |
+
from timm.models.registry import register_model
|
| 18 |
+
from timm.models.layers import trunc_normal_, DropPath
|
| 19 |
+
from timm.models.helpers import named_apply, adapt_input_conv
|
| 20 |
+
|
| 21 |
+
from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
|
| 22 |
+
|
| 23 |
+
class Mlp(nn.Module):
    """ MLP as used in Vision Transformer, MLP-Mixer and related networks
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        # Hidden and output widths default to the input width when unset.
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        # fc1 -> activation -> dropout -> fc2 -> dropout
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class Attention(nn.Module):
    """Multi-head self-attention with optional capture of the attention map
    and its gradients (used for attention visualization)."""
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.attn_gradients = None
        self.attention_map = None

    def save_attn_gradients(self, attn_gradients):
        self.attn_gradients = attn_gradients

    def get_attn_gradients(self):
        return self.attn_gradients

    def save_attention_map(self, attention_map):
        self.attention_map = attention_map

    def get_attention_map(self):
        return self.attention_map

    def forward(self, x, register_hook=False):
        batch, tokens, channels = x.shape
        per_head = channels // self.num_heads
        # Project once to q/k/v: (3, B, heads, N, head_dim).
        qkv = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, per_head).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        scores = (q @ k.transpose(-2, -1)) * self.scale
        attn_weights = scores.softmax(dim=-1)
        attn_weights = self.attn_drop(attn_weights)

        if register_hook:
            # Stash the attention map and hook its gradient for inspection.
            self.save_attention_map(attn_weights)
            attn_weights.register_hook(self.save_attn_gradients)

        out = (attn_weights @ v).transpose(1, 2).reshape(batch, tokens, channels)
        out = self.proj(out)
        return self.proj_drop(out)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class Block(nn.Module):
    """Standard ViT transformer block: pre-norm attention and MLP sublayers,
    each with a residual connection and stochastic depth (DropPath)."""

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        if use_grad_checkpointing:
            # Trade compute for memory: recompute attention/MLP activations
            # during the backward pass (fairscale checkpoint_wrapper).
            self.attn = checkpoint_wrapper(self.attn)
            self.mlp = checkpoint_wrapper(self.mlp)

    def forward(self, x, register_hook=False):
        # Pre-norm residual form: x + DropPath(sublayer(LN(x))).
        x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class VisionTransformer(nn.Module):
    """ Vision Transformer
    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
        https://arxiv.org/abs/2010.11929
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None,
                 use_grad_checkpointing=False, ckpt_layer=0):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            num_classes (int): number of classes for classification head
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            qk_scale (float): override default qk scale of head_dim ** -0.5 if set
            representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
            drop_rate (float): dropout rate
            attn_drop_rate (float): attention dropout rate
            drop_path_rate (float): stochastic depth rate
            norm_layer: (nn.Module): normalization layer
            use_grad_checkpointing (bool): wrap the last ckpt_layer blocks in
                activation checkpointing
            ckpt_layer (int): number of final blocks to checkpoint when enabled
        """
        super().__init__()
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)

        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)

        num_patches = self.patch_embed.num_patches

        # Learnable [CLS] token and position embeddings (+1 slot for [CLS]).
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
                # Only the final ckpt_layer blocks are checkpointed.
                use_grad_checkpointing=(use_grad_checkpointing and i>=depth-ckpt_layer)
            )
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)

        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Truncated-normal init for Linear weights; zero bias / unit weight
        # for LayerNorm.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        # Parameter names the optimizer setup should exclude from weight decay.
        return {'pos_embed', 'cls_token'}

    def forward(self, x, register_blk=-1):
        """Encode an image batch; returns per-token features incl. [CLS].

        register_blk: index of the block whose attention map/gradients are
        saved for visualization (-1 disables saving).
        """
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)

        # Slice pos_embed in case the token count is below num_patches + 1.
        x = x + self.pos_embed[:,:x.size(1),:]
        x = self.pos_drop(x)

        for i,blk in enumerate(self.blocks):
            x = blk(x, register_blk==i)
        x = self.norm(x)

        return x

    @torch.jit.ignore()
    def load_pretrained(self, checkpoint_path, prefix=''):
        # Delegates to the module-level .npz checkpoint loader below.
        _load_weights(self, checkpoint_path, prefix)
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
@torch.no_grad()
def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
    """ Load weights from .npz checkpoints for official Google Brain Flax implementation
    """
    import numpy as np

    def _n2p(w, t=True):
        # Convert a Flax ndarray to a torch tensor; when t=True, transpose
        # conv/linear weights from Flax layout to PyTorch layout.
        if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
            w = w.flatten()
        if t:
            if w.ndim == 4:
                w = w.transpose([3, 2, 0, 1])
            elif w.ndim == 3:
                w = w.transpose([2, 0, 1])
            elif w.ndim == 2:
                w = w.transpose([1, 0])
        return torch.from_numpy(w)

    w = np.load(checkpoint_path)
    # Optimizer-wrapped checkpoints nest weights under 'opt/target/'.
    if not prefix and 'opt/target/embedding/kernel' in w:
        prefix = 'opt/target/'

    if hasattr(model.patch_embed, 'backbone'):
        # hybrid
        backbone = model.patch_embed.backbone
        stem_only = not hasattr(backbone, 'stem')
        stem = backbone if stem_only else backbone.stem
        stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
        stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
        stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
        if not stem_only:
            # Copy every residual stage/unit of the CNN backbone.
            for i, stage in enumerate(backbone.stages):
                for j, block in enumerate(stage.blocks):
                    bp = f'{prefix}block{i + 1}/unit{j + 1}/'
                    for r in range(3):
                        getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
                        getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
                        getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
                    if block.downsample is not None:
                        block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
                        block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
                        block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
        embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
    else:
        # Plain ViT: adapt the patch-embedding conv to the model's channels.
        embed_conv_w = adapt_input_conv(
            model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
    model.patch_embed.proj.weight.copy_(embed_conv_w)
    model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
    model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
    pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
    if pos_embed_w.shape != model.pos_embed.shape:
        # NOTE(review): resize_pos_embed is not defined in this module —
        # presumably timm's helper; confirm the import if this path is hit.
        pos_embed_w = resize_pos_embed(  # resize pos embedding when different size from pretrained weights
            pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
    model.pos_embed.copy_(pos_embed_w)
    model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
    model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
    # Classification-head and pre-logits loading intentionally disabled:
    # if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
    #     model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
    #     model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
    # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
    #     model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
    #     model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
    for i, block in enumerate(model.blocks.children()):
        block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
        mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
        block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
        block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
        # Flax stores separate q/k/v matrices; concatenate into the fused qkv.
        block.attn.qkv.weight.copy_(torch.cat([
            _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
        block.attn.qkv.bias.copy_(torch.cat([
            _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
        block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
        block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
        for r in range(2):
            getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
            getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
        block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
        block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder):
    """Resize a checkpoint's position embedding to the encoder's patch grid.

    Extra tokens (e.g. [CLS]) are kept unchanged; only the patch-position
    rows are bicubically interpolated. Returns the checkpoint unchanged when
    the grid sizes already match.
    """
    dim = pos_embed_checkpoint.shape[-1]
    num_patches = visual_encoder.patch_embed.num_patches
    num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches
    # Side lengths of the (square) old and new patch grids.
    orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
    new_size = int(num_patches ** 0.5)

    if orig_size == new_size:
        return pos_embed_checkpoint

    # class_token and dist_token are kept unchanged
    extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
    # only the position tokens are interpolated
    grid = pos_embed_checkpoint[:, num_extra_tokens:]
    grid = grid.reshape(-1, orig_size, orig_size, dim).permute(0, 3, 1, 2)
    grid = torch.nn.functional.interpolate(
        grid, size=(new_size, new_size), mode='bicubic', align_corners=False)
    grid = grid.permute(0, 2, 3, 1).flatten(1, 2)
    print('reshape position embedding from %d to %d'%(orig_size ** 2,new_size ** 2))
    return torch.cat((extra_tokens, grid), dim=1)
|
was-node-suite-comfyui/modules/__init__.py
ADDED
|
File without changes
|
was-node-suite-comfyui/repos/SAM/demo/README.md
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## Segment Anything Simple Web demo
|
| 2 |
+
|
| 3 |
+
This **front-end only** React-based web demo shows how to load a fixed image and corresponding `.npy` file of the SAM image embedding, and run the SAM ONNX model in the browser using Web Assembly with multithreading enabled by `SharedArrayBuffer`, Web Worker, and SIMD128.
|
| 4 |
+
|
| 5 |
+
<img src="https://github.com/facebookresearch/segment-anything/raw/main/assets/minidemo.gif" width="500"/>
|
| 6 |
+
|
| 7 |
+
## Run the app
|
| 8 |
+
|
| 9 |
+
Install Yarn
|
| 10 |
+
|
| 11 |
+
```
|
| 12 |
+
npm install -g yarn
|
| 13 |
+
```
|
| 14 |
+
|
| 15 |
+
Build and run:
|
| 16 |
+
|
| 17 |
+
```
|
| 18 |
+
yarn && yarn start
|
| 19 |
+
```
|
| 20 |
+
|
| 21 |
+
Navigate to [`http://localhost:8081/`](http://localhost:8081/)
|
| 22 |
+
|
| 23 |
+
Move your cursor around to see the mask prediction update in real time.
|
| 24 |
+
|
| 25 |
+
## Export the image embedding
|
| 26 |
+
|
| 27 |
+
In the [ONNX Model Example notebook](https://github.com/facebookresearch/segment-anything/blob/main/notebooks/onnx_model_example.ipynb) upload the image of your choice and generate and save corresponding embedding.
|
| 28 |
+
|
| 29 |
+
Initialize the predictor:
|
| 30 |
+
|
| 31 |
+
```python
|
| 32 |
+
checkpoint = "sam_vit_h_4b8939.pth"
|
| 33 |
+
model_type = "vit_h"
|
| 34 |
+
sam = sam_model_registry[model_type](checkpoint=checkpoint)
|
| 35 |
+
sam.to(device='cuda')
|
| 36 |
+
predictor = SamPredictor(sam)
|
| 37 |
+
```
|
| 38 |
+
|
| 39 |
+
Set the new image and export the embedding:
|
| 40 |
+
|
| 41 |
+
```
|
| 42 |
+
image = cv2.imread('src/assets/dogs.jpg')
|
| 43 |
+
predictor.set_image(image)
|
| 44 |
+
image_embedding = predictor.get_image_embedding().cpu().numpy()
|
| 45 |
+
np.save("dogs_embedding.npy", image_embedding)
|
| 46 |
+
```
|
| 47 |
+
|
| 48 |
+
Save the new image and embedding in `src/assets/data`.
|
| 49 |
+
|
| 50 |
+
## Export the ONNX model
|
| 51 |
+
|
| 52 |
+
You also need to export the quantized ONNX model from the [ONNX Model Example notebook](https://github.com/facebookresearch/segment-anything/blob/main/notebooks/onnx_model_example.ipynb).
|
| 53 |
+
|
| 54 |
+
Run the cell in the notebook which saves the `sam_onnx_quantized_example.onnx` file, download it and copy it to the path `/model/sam_onnx_quantized_example.onnx`.
|
| 55 |
+
|
| 56 |
+
Here is a snippet of the export/quantization code:
|
| 57 |
+
|
| 58 |
+
```
|
| 59 |
+
onnx_model_path = "sam_onnx_example.onnx"
|
| 60 |
+
onnx_model_quantized_path = "sam_onnx_quantized_example.onnx"
|
| 61 |
+
quantize_dynamic(
|
| 62 |
+
model_input=onnx_model_path,
|
| 63 |
+
model_output=onnx_model_quantized_path,
|
| 64 |
+
optimize_model=True,
|
| 65 |
+
per_channel=False,
|
| 66 |
+
reduce_range=False,
|
| 67 |
+
weight_type=QuantType.QUInt8,
|
| 68 |
+
)
|
| 69 |
+
```
|
| 70 |
+
|
| 71 |
+
**NOTE: if you change the ONNX model by using a new checkpoint you need to also re-export the embedding.**
|
| 72 |
+
|
| 73 |
+
## Update the image, embedding, model in the app
|
| 74 |
+
|
| 75 |
+
Update the following file paths at the top of `App.tsx`:
|
| 76 |
+
|
| 77 |
+
```ts
|
| 78 |
+
const IMAGE_PATH = "/assets/data/dogs.jpg";
|
| 79 |
+
const IMAGE_EMBEDDING = "/assets/data/dogs_embedding.npy";
|
| 80 |
+
const MODEL_DIR = "/model/sam_onnx_quantized_example.onnx";
|
| 81 |
+
```
|
| 82 |
+
|
| 83 |
+
## ONNX multithreading with SharedArrayBuffer
|
| 84 |
+
|
| 85 |
+
To use multithreading, the appropriate headers need to be set to create a cross origin isolation state which will enable use of `SharedArrayBuffer` (see this [blog post](https://cloudblogs.microsoft.com/opensource/2021/09/02/onnx-runtime-web-running-your-machine-learning-model-in-browser/) for more details)
|
| 86 |
+
|
| 87 |
+
The headers below are set in `configs/webpack/dev.js`:
|
| 88 |
+
|
| 89 |
+
```js
|
| 90 |
+
headers: {
|
| 91 |
+
"Cross-Origin-Opener-Policy": "same-origin",
|
| 92 |
+
"Cross-Origin-Embedder-Policy": "credentialless",
|
| 93 |
+
}
|
| 94 |
+
```
|
| 95 |
+
|
| 96 |
+
## Structure of the app
|
| 97 |
+
|
| 98 |
+
**`App.tsx`**
|
| 99 |
+
|
| 100 |
+
- Initializes ONNX model
|
| 101 |
+
- Loads image embedding and image
|
| 102 |
+
- Runs the ONNX model based on input prompts
|
| 103 |
+
|
| 104 |
+
**`Stage.tsx`**
|
| 105 |
+
|
| 106 |
+
- Handles mouse move interaction to update the ONNX model prompt
|
| 107 |
+
|
| 108 |
+
**`Tool.tsx`**
|
| 109 |
+
|
| 110 |
+
- Renders the image and the mask prediction
|
| 111 |
+
|
| 112 |
+
**`helpers/maskUtils.tsx`**
|
| 113 |
+
|
| 114 |
+
- Conversion of ONNX model output from array to an HTMLImageElement
|
| 115 |
+
|
| 116 |
+
**`helpers/onnxModelAPI.tsx`**
|
| 117 |
+
|
| 118 |
+
- Formats the inputs for the ONNX model
|
| 119 |
+
|
| 120 |
+
**`helpers/scaleHelper.tsx`**
|
| 121 |
+
|
| 122 |
+
- Handles image scaling logic for SAM (longest size 1024)
|
| 123 |
+
|
| 124 |
+
**`hooks/`**
|
| 125 |
+
|
| 126 |
+
- Handle shared state for the app
|
was-node-suite-comfyui/repos/SAM/demo/package.json
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"name": "segment-anything-mini-demo",
|
| 3 |
+
"version": "0.1.0",
|
| 4 |
+
"license": "MIT",
|
| 5 |
+
"scripts": {
|
| 6 |
+
"build": "yarn run clean-dist && webpack --config=configs/webpack/prod.js && mv dist/*.wasm dist/js",
|
| 7 |
+
"clean-dist": "rimraf dist/*",
|
| 8 |
+
"lint": "eslint './src/**/*.{js,ts,tsx}' --quiet",
|
| 9 |
+
"start": "yarn run start-dev",
|
| 10 |
+
"test": "yarn run start-model-test",
|
| 11 |
+
"start-dev": "webpack serve --config=configs/webpack/dev.js"
|
| 12 |
+
},
|
| 13 |
+
"devDependencies": {
|
| 14 |
+
"@babel/core": "^7.18.13",
|
| 15 |
+
"@babel/preset-env": "^7.18.10",
|
| 16 |
+
"@babel/preset-react": "^7.18.6",
|
| 17 |
+
"@babel/preset-typescript": "^7.18.6",
|
| 18 |
+
"@pmmmwh/react-refresh-webpack-plugin": "^0.5.7",
|
| 19 |
+
"@testing-library/react": "^13.3.0",
|
| 20 |
+
"@types/node": "^18.7.13",
|
| 21 |
+
"@types/react": "^18.0.17",
|
| 22 |
+
"@types/react-dom": "^18.0.6",
|
| 23 |
+
"@types/underscore": "^1.11.4",
|
| 24 |
+
"@typescript-eslint/eslint-plugin": "^5.35.1",
|
| 25 |
+
"@typescript-eslint/parser": "^5.35.1",
|
| 26 |
+
"babel-loader": "^8.2.5",
|
| 27 |
+
"copy-webpack-plugin": "^11.0.0",
|
| 28 |
+
"css-loader": "^6.7.1",
|
| 29 |
+
"dotenv": "^16.0.2",
|
| 30 |
+
"dotenv-webpack": "^8.0.1",
|
| 31 |
+
"eslint": "^8.22.0",
|
| 32 |
+
"eslint-plugin-react": "^7.31.0",
|
| 33 |
+
"file-loader": "^6.2.0",
|
| 34 |
+
"fork-ts-checker-webpack-plugin": "^7.2.13",
|
| 35 |
+
"friendly-errors-webpack-plugin": "^1.7.0",
|
| 36 |
+
"html-webpack-plugin": "^5.5.0",
|
| 37 |
+
"image-webpack-loader": "^8.1.0",
|
| 38 |
+
"postcss-loader": "^7.0.1",
|
| 39 |
+
"postcss-preset-env": "^7.8.0",
|
| 40 |
+
"process": "^0.11.10",
|
| 41 |
+
"rimraf": "^3.0.2",
|
| 42 |
+
"sass": "^1.54.5",
|
| 43 |
+
"sass-loader": "^13.0.2",
|
| 44 |
+
"style-loader": "^3.3.1",
|
| 45 |
+
"tailwindcss": "^3.1.8",
|
| 46 |
+
"ts-loader": "^9.3.1",
|
| 47 |
+
"typescript": "^4.8.2",
|
| 48 |
+
"webpack": "^5.74.0",
|
| 49 |
+
"webpack-cli": "^4.10.0",
|
| 50 |
+
"webpack-dev-server": "^4.10.0",
|
| 51 |
+
"webpack-dotenv-plugin": "^2.1.0",
|
| 52 |
+
"webpack-merge": "^5.8.0"
|
| 53 |
+
},
|
| 54 |
+
"dependencies": {
|
| 55 |
+
"npyjs": "^0.4.0",
|
| 56 |
+
"onnxruntime-web": "^1.14.0",
|
| 57 |
+
"react": "^18.2.0",
|
| 58 |
+
"react-dom": "^18.2.0",
|
| 59 |
+
"underscore": "^1.13.6",
|
| 60 |
+
"react-refresh": "^0.14.0"
|
| 61 |
+
}
|
| 62 |
+
}
|
was-node-suite-comfyui/repos/SAM/demo/postcss.config.js
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.

// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

// PostCSS configuration for the demo build (consumed by postcss-loader,
// which is listed in package.json and wired up in configs/webpack/).
const tailwindcss = require("tailwindcss");
module.exports = {
  // NOTE(review): 'tailwindcss/nesting' is listed before the tailwindcss
  // plugin itself, which matches Tailwind's documented nesting setup —
  // confirm before reordering.
  plugins: ["postcss-preset-env", 'tailwindcss/nesting', tailwindcss],
};
|
was-node-suite-comfyui/repos/SAM/demo/src/components/helpers/Interfaces.tsx
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.

// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

import { Tensor } from "onnxruntime-web";

/**
 * Scaling information for an image fed to SAM.
 * samScale maps original-image coordinates into SAM's model input space
 * (computed in helpers/scaleHelper.tsx from the longest image side).
 */
export interface modelScaleProps {
  samScale: number;
  // Original image dimensions in pixels.
  height: number;
  width: number;
}

/**
 * A single click prompt in original-image pixel coordinates.
 * clickType is the numeric SAM point label; its value is supplied by the
 * caller — NOTE(review): confirm the label convention against Stage.tsx.
 */
export interface modelInputProps {
  x: number;
  y: number;
  clickType: number;
}

/**
 * Inputs required to assemble the ONNX decoder feed
 * (consumed by helpers/onnxModelAPI.tsx).
 */
export interface modeDataProps {
  clicks?: Array<modelInputProps>; // optional click prompts
  tensor: Tensor; // precomputed image embedding
  modelScale: modelScaleProps; // scaling info for the current image
}

/**
 * Props for the Tool component: mouse-move events are forwarded upward so
 * the parent can update the model prompt.
 * NOTE(review): `any` here hides the event type; narrowing to a React
 * mouse event could break existing implementers — verify before changing.
 */
export interface ToolProps {
  handleMouseMove: (e: any) => void;
}
|
was-node-suite-comfyui/repos/SAM/demo/src/components/helpers/maskUtils.tsx
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
|
| 4 |
+
// This source code is licensed under the license found in the
|
| 5 |
+
// LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
// Convert the onnx model mask prediction to ImageData
|
| 8 |
+
function arrayToImageData(input: any, width: number, height: number) {
|
| 9 |
+
const [r, g, b, a] = [0, 114, 189, 255]; // the masks's blue color
|
| 10 |
+
const arr = new Uint8ClampedArray(4 * width * height).fill(0);
|
| 11 |
+
for (let i = 0; i < input.length; i++) {
|
| 12 |
+
|
| 13 |
+
// Threshold the onnx model mask prediction at 0.0
|
| 14 |
+
// This is equivalent to thresholding the mask using predictor.model.mask_threshold
|
| 15 |
+
// in python
|
| 16 |
+
if (input[i] > 0.0) {
|
| 17 |
+
arr[4 * i + 0] = r;
|
| 18 |
+
arr[4 * i + 1] = g;
|
| 19 |
+
arr[4 * i + 2] = b;
|
| 20 |
+
arr[4 * i + 3] = a;
|
| 21 |
+
}
|
| 22 |
+
}
|
| 23 |
+
return new ImageData(arr, height, width);
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
// Use a Canvas element to produce an image from ImageData
|
| 27 |
+
function imageDataToImage(imageData: ImageData) {
|
| 28 |
+
const canvas = imageDataToCanvas(imageData);
|
| 29 |
+
const image = new Image();
|
| 30 |
+
image.src = canvas.toDataURL();
|
| 31 |
+
return image;
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
// Canvas elements can be created from ImageData
|
| 35 |
+
function imageDataToCanvas(imageData: ImageData) {
|
| 36 |
+
const canvas = document.createElement("canvas");
|
| 37 |
+
const ctx = canvas.getContext("2d");
|
| 38 |
+
canvas.width = imageData.width;
|
| 39 |
+
canvas.height = imageData.height;
|
| 40 |
+
ctx?.putImageData(imageData, 0, 0);
|
| 41 |
+
return canvas;
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
// Convert the onnx model mask output to an HTMLImageElement
|
| 45 |
+
export function onnxMaskToImage(input: any, width: number, height: number) {
|
| 46 |
+
return imageDataToImage(arrayToImageData(input, width, height));
|
| 47 |
+
}
|
was-node-suite-comfyui/repos/SAM/demo/src/components/helpers/onnxModelAPI.tsx
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
|
| 4 |
+
// This source code is licensed under the license found in the
|
| 5 |
+
// LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import { Tensor } from "onnxruntime-web";
|
| 8 |
+
import { modeDataProps } from "./Interfaces";
|
| 9 |
+
|
| 10 |
+
const modelData = ({ clicks, tensor, modelScale }: modeDataProps) => {
|
| 11 |
+
const imageEmbedding = tensor;
|
| 12 |
+
let pointCoords;
|
| 13 |
+
let pointLabels;
|
| 14 |
+
let pointCoordsTensor;
|
| 15 |
+
let pointLabelsTensor;
|
| 16 |
+
|
| 17 |
+
// Check there are input click prompts
|
| 18 |
+
if (clicks) {
|
| 19 |
+
let n = clicks.length;
|
| 20 |
+
|
| 21 |
+
// If there is no box input, a single padding point with
|
| 22 |
+
// label -1 and coordinates (0.0, 0.0) should be concatenated
|
| 23 |
+
// so initialize the array to support (n + 1) points.
|
| 24 |
+
pointCoords = new Float32Array(2 * (n + 1));
|
| 25 |
+
pointLabels = new Float32Array(n + 1);
|
| 26 |
+
|
| 27 |
+
// Add clicks and scale to what SAM expects
|
| 28 |
+
for (let i = 0; i < n; i++) {
|
| 29 |
+
pointCoords[2 * i] = clicks[i].x * modelScale.samScale;
|
| 30 |
+
pointCoords[2 * i + 1] = clicks[i].y * modelScale.samScale;
|
| 31 |
+
pointLabels[i] = clicks[i].clickType;
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
// Add in the extra point/label when only clicks and no box
|
| 35 |
+
// The extra point is at (0, 0) with label -1
|
| 36 |
+
pointCoords[2 * n] = 0.0;
|
| 37 |
+
pointCoords[2 * n + 1] = 0.0;
|
| 38 |
+
pointLabels[n] = -1.0;
|
| 39 |
+
|
| 40 |
+
// Create the tensor
|
| 41 |
+
pointCoordsTensor = new Tensor("float32", pointCoords, [1, n + 1, 2]);
|
| 42 |
+
pointLabelsTensor = new Tensor("float32", pointLabels, [1, n + 1]);
|
| 43 |
+
}
|
| 44 |
+
const imageSizeTensor = new Tensor("float32", [
|
| 45 |
+
modelScale.height,
|
| 46 |
+
modelScale.width,
|
| 47 |
+
]);
|
| 48 |
+
|
| 49 |
+
if (pointCoordsTensor === undefined || pointLabelsTensor === undefined)
|
| 50 |
+
return;
|
| 51 |
+
|
| 52 |
+
// There is no previous mask, so default to an empty tensor
|
| 53 |
+
const maskInput = new Tensor(
|
| 54 |
+
"float32",
|
| 55 |
+
new Float32Array(256 * 256),
|
| 56 |
+
[1, 1, 256, 256]
|
| 57 |
+
);
|
| 58 |
+
// There is no previous mask, so default to 0
|
| 59 |
+
const hasMaskInput = new Tensor("float32", [0]);
|
| 60 |
+
|
| 61 |
+
return {
|
| 62 |
+
image_embeddings: imageEmbedding,
|
| 63 |
+
point_coords: pointCoordsTensor,
|
| 64 |
+
point_labels: pointLabelsTensor,
|
| 65 |
+
orig_im_size: imageSizeTensor,
|
| 66 |
+
mask_input: maskInput,
|
| 67 |
+
has_mask_input: hasMaskInput,
|
| 68 |
+
};
|
| 69 |
+
};
|
| 70 |
+
|
| 71 |
+
export { modelData };
|
was-node-suite-comfyui/repos/SAM/demo/src/components/helpers/scaleHelper.tsx
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
|
| 4 |
+
// This source code is licensed under the license found in the
|
| 5 |
+
// LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
// Helper function for handling image scaling needed for SAM
|
| 9 |
+
const handleImageScale = (image: HTMLImageElement) => {
|
| 10 |
+
// Input images to SAM must be resized so the longest side is 1024
|
| 11 |
+
const LONG_SIDE_LENGTH = 1024;
|
| 12 |
+
let w = image.naturalWidth;
|
| 13 |
+
let h = image.naturalHeight;
|
| 14 |
+
const samScale = LONG_SIDE_LENGTH / Math.max(h, w);
|
| 15 |
+
return { height: h, width: w, samScale };
|
| 16 |
+
};
|
| 17 |
+
|
| 18 |
+
export { handleImageScale };
|
was-node-suite-comfyui/repos/SAM/demo/src/components/hooks/context.tsx
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.

// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

import React, { useState } from "react";
import { modelInputProps } from "../helpers/Interfaces";
import AppContext from "./createContext";

// Provides the app-wide shared state — click prompts, the source image,
// and the predicted mask image — to all descendants via AppContext.
// Each context entry is a [value, setter] tuple mirroring useState's
// return shape (see the contextProps type in createContext.tsx).
const AppContextProvider = (props: {
  children: React.ReactElement<any, string | React.JSXElementConstructor<any>>;
}) => {
  // All three pieces of state start as null until the app loads/sets them.
  const [clicks, setClicks] = useState<Array<modelInputProps> | null>(null);
  const [image, setImage] = useState<HTMLImageElement | null>(null);
  const [maskImg, setMaskImg] = useState<HTMLImageElement | null>(null);

  return (
    <AppContext.Provider
      value={{
        clicks: [clicks, setClicks],
        image: [image, setImage],
        maskImg: [maskImg, setMaskImg],
      }}
    >
      {props.children}
    </AppContext.Provider>
  );
};

export default AppContextProvider;
|
was-node-suite-comfyui/repos/SAM/demo/src/components/hooks/createContext.tsx
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.

// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

import { createContext } from "react";
import { modelInputProps } from "../helpers/Interfaces";

// Shape of the shared app state. Each field is a [value, setter] tuple
// produced by useState in the provider (hooks/context.tsx); every value
// starts as null until the app populates it.
interface contextProps {
  // Click prompts for the ONNX model, in image coordinates.
  clicks: [
    clicks: modelInputProps[] | null,
    setClicks: (e: modelInputProps[] | null) => void
  ];
  // The source image being segmented.
  image: [
    image: HTMLImageElement | null,
    setImage: (e: HTMLImageElement | null) => void
  ];
  // The rendered mask prediction for the current prompts.
  maskImg: [
    maskImg: HTMLImageElement | null,
    setMaskImg: (e: HTMLImageElement | null) => void
  ];
}

// Default is null; consumers must be rendered inside AppContextProvider.
const AppContext = createContext<contextProps | null>(null);

export default AppContext;
|