keyon857 commited on
Commit
22c9f7e
·
1 Parent(s): 3106f09

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. rgthree-comfy/web/comfyui/services/bookmarks_services.js +10 -0
  2. rgthree-comfy/web/comfyui/services/config_service.js +28 -0
  3. rgthree-comfy/web/comfyui/services/context_service.js +51 -0
  4. rgthree-comfy/web/comfyui/services/fast_groups_service.js +138 -0
  5. rgthree-comfy/web/common/css/buttons.css +90 -0
  6. rgthree-comfy/web/common/css/dialog.css +124 -0
  7. rgthree-comfy/web/common/css/dialog_model_info.css +333 -0
  8. rgthree-comfy/web/common/css/menu.css +91 -0
  9. rgthree-comfy/web/common/css/pages_base.css +66 -0
  10. rgthree-comfy/web/common/media/rgthree.svg +7 -0
  11. rgthree-comfy/web/common/media/svgs.js +160 -0
  12. rgthree-comfy/web/common/shared_utils.js +142 -0
  13. rgthree-comfy/web/common/utils_dom.js +311 -0
  14. rgthree-comfy/web/common/utils_workflow.js +55 -0
  15. rgthree-comfy/web/link_fixer/link_page.js +195 -0
  16. sd-dynamic-thresholding/.github/FUNDING.yml +1 -0
  17. sd-dynamic-thresholding/.github/workflows/publish.yml +21 -0
  18. sd-dynamic-thresholding/__pycache__/__init__.cpython-312.pyc +0 -0
  19. sd-dynamic-thresholding/__pycache__/dynthres_comfyui.cpython-312.pyc +0 -0
  20. sd-dynamic-thresholding/__pycache__/dynthres_core.cpython-312.pyc +0 -0
  21. sd-dynamic-thresholding/github/comfy_node.png +0 -0
  22. sd-dynamic-thresholding/github/ui.png +0 -0
  23. sd-dynamic-thresholding/javascript/active.js +68 -0
  24. sd-dynamic-thresholding/scripts/dynamic_thresholding.py +270 -0
  25. sigmas_tools_and_the_golden_scheduler/.github/workflows/publish.yml +21 -0
  26. sigmas_tools_and_the_golden_scheduler/__pycache__/__init__.cpython-312.pyc +0 -0
  27. sigmas_tools_and_the_golden_scheduler/__pycache__/sigmas_merge.cpython-312.pyc +0 -0
  28. stable-diffusion-temperature-settings/.github/FUNDING.yml +3 -0
  29. stable-diffusion-temperature-settings/.github/workflows/publish.yml +22 -0
  30. stable-diffusion-temperature-settings/__pycache__/__init__.cpython-312.pyc +0 -0
  31. stable-diffusion-temperature-settings/__pycache__/nodes.cpython-312.pyc +0 -0
  32. stable-diffusion-temperature-settings/workflows/tinybottle.png +0 -0
  33. ultimate-upscale-for-automatic1111/scripts/ultimate-upscale.py +569 -0
  34. was-node-suite-comfyui/.github/workflows/publish_action.yml +20 -0
  35. was-node-suite-comfyui/__pycache__/__init__.cpython-312.pyc +0 -0
  36. was-node-suite-comfyui/modules/BLIP/__init__.py +0 -0
  37. was-node-suite-comfyui/modules/BLIP/blip_med.py +955 -0
  38. was-node-suite-comfyui/modules/BLIP/blip_module.py +423 -0
  39. was-node-suite-comfyui/modules/BLIP/blip_module_license.txt +12 -0
  40. was-node-suite-comfyui/modules/BLIP/blip_vit.py +305 -0
  41. was-node-suite-comfyui/modules/__init__.py +0 -0
  42. was-node-suite-comfyui/repos/SAM/demo/README.md +126 -0
  43. was-node-suite-comfyui/repos/SAM/demo/package.json +62 -0
  44. was-node-suite-comfyui/repos/SAM/demo/postcss.config.js +10 -0
  45. was-node-suite-comfyui/repos/SAM/demo/src/components/helpers/Interfaces.tsx +29 -0
  46. was-node-suite-comfyui/repos/SAM/demo/src/components/helpers/maskUtils.tsx +47 -0
  47. was-node-suite-comfyui/repos/SAM/demo/src/components/helpers/onnxModelAPI.tsx +71 -0
  48. was-node-suite-comfyui/repos/SAM/demo/src/components/helpers/scaleHelper.tsx +18 -0
  49. was-node-suite-comfyui/repos/SAM/demo/src/components/hooks/context.tsx +31 -0
  50. was-node-suite-comfyui/repos/SAM/demo/src/components/hooks/createContext.tsx +27 -0
rgthree-comfy/web/comfyui/services/bookmarks_services.js ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ import { app } from "../../../scripts/app.js";
2
+ import { NodeTypesString } from "../constants.js";
3
+ class BookmarksService {
4
+ getCurrentBookmarks() {
5
+ return app.graph._nodes
6
+ .filter((n) => n.type === NodeTypesString.BOOKMARK)
7
+ .sort((a, b) => a.title.localeCompare(b.title));
8
+ }
9
+ }
10
+ export const SERVICE = new BookmarksService();
rgthree-comfy/web/comfyui/services/config_service.js ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { rgthreeConfig } from "../../../rgthree/config.js";
2
+ import { getObjectValue, setObjectValue } from "../../../rgthree/common/shared_utils.js";
3
+ import { rgthreeApi } from "../../../rgthree/common/rgthree_api.js";
4
+ class ConfigService extends EventTarget {
5
+ getConfigValue(key, def) {
6
+ return getObjectValue(rgthreeConfig, key, def);
7
+ }
8
+ getFeatureValue(key, def) {
9
+ key = "features." + key.replace(/^features\./, "");
10
+ return getObjectValue(rgthreeConfig, key, def);
11
+ }
12
+ async setConfigValues(changed) {
13
+ const body = new FormData();
14
+ body.append("json", JSON.stringify(changed));
15
+ const response = await rgthreeApi.fetchJson("/config", { method: "POST", body });
16
+ if (response.status === "ok") {
17
+ for (const [key, value] of Object.entries(changed)) {
18
+ setObjectValue(rgthreeConfig, key, value);
19
+ this.dispatchEvent(new CustomEvent("config-change", { detail: { key, value } }));
20
+ }
21
+ }
22
+ else {
23
+ return false;
24
+ }
25
+ return true;
26
+ }
27
+ }
28
+ export const SERVICE = new ConfigService();
rgthree-comfy/web/comfyui/services/context_service.js ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { getConnectedOutputNodesAndFilterPassThroughs } from "../utils.js";
2
+ export let SERVICE;
3
+ const OWNED_PREFIX = "+";
4
+ const REGEX_PREFIX = /^[\+⚠️]\s*/;
5
+ const REGEX_EMPTY_INPUT = /^\+\s*$/;
6
+ export function stripContextInputPrefixes(name) {
7
+ return name.replace(REGEX_PREFIX, "");
8
+ }
9
+ export function getContextOutputName(inputName) {
10
+ if (inputName === "base_ctx")
11
+ return "CONTEXT";
12
+ return stripContextInputPrefixes(inputName).toUpperCase();
13
+ }
14
+ export var InputMutationOperation;
15
+ (function (InputMutationOperation) {
16
+ InputMutationOperation[InputMutationOperation["UNKNOWN"] = 0] = "UNKNOWN";
17
+ InputMutationOperation[InputMutationOperation["ADDED"] = 1] = "ADDED";
18
+ InputMutationOperation[InputMutationOperation["REMOVED"] = 2] = "REMOVED";
19
+ InputMutationOperation[InputMutationOperation["RENAMED"] = 3] = "RENAMED";
20
+ })(InputMutationOperation || (InputMutationOperation = {}));
21
+ export class ContextService {
22
+ constructor() {
23
+ if (SERVICE) {
24
+ throw new Error("ContextService was already instantiated.");
25
+ }
26
+ }
27
+ onInputChanges(node, mutation) {
28
+ const childCtxs = getConnectedOutputNodesAndFilterPassThroughs(node, node, 0);
29
+ for (const childCtx of childCtxs) {
30
+ childCtx.handleUpstreamMutation(mutation);
31
+ }
32
+ }
33
+ getDynamicContextInputsData(node) {
34
+ return node
35
+ .getContextInputsList()
36
+ .map((input, index) => ({
37
+ name: stripContextInputPrefixes(input.name),
38
+ type: String(input.type),
39
+ index,
40
+ }))
41
+ .filter((i) => i.type !== "*");
42
+ }
43
+ getDynamicContextOutputsData(node) {
44
+ return node.outputs.map((output, index) => ({
45
+ name: stripContextInputPrefixes(output.name),
46
+ type: String(output.type),
47
+ index,
48
+ }));
49
+ }
50
+ }
51
+ SERVICE = new ContextService();
rgthree-comfy/web/comfyui/services/fast_groups_service.js ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { app } from "../../../scripts/app.js";
2
+ class FastGroupsService {
3
+ constructor() {
4
+ this.msThreshold = 400;
5
+ this.msLastUnsorted = 0;
6
+ this.msLastAlpha = 0;
7
+ this.msLastPosition = 0;
8
+ this.groupsUnsorted = [];
9
+ this.groupsSortedAlpha = [];
10
+ this.groupsSortedPosition = [];
11
+ this.fastGroupNodes = [];
12
+ this.runScheduledForMs = null;
13
+ this.runScheduleTimeout = null;
14
+ this.runScheduleAnimation = null;
15
+ this.cachedNodeBoundings = null;
16
+ }
17
+ addFastGroupNode(node) {
18
+ this.fastGroupNodes.push(node);
19
+ this.scheduleRun(8);
20
+ }
21
+ removeFastGroupNode(node) {
22
+ var _a;
23
+ const index = this.fastGroupNodes.indexOf(node);
24
+ if (index > -1) {
25
+ this.fastGroupNodes.splice(index, 1);
26
+ }
27
+ if (!((_a = this.fastGroupNodes) === null || _a === void 0 ? void 0 : _a.length)) {
28
+ this.clearScheduledRun();
29
+ this.groupsUnsorted = [];
30
+ this.groupsSortedAlpha = [];
31
+ this.groupsSortedPosition = [];
32
+ }
33
+ }
34
+ run() {
35
+ if (!this.runScheduledForMs) {
36
+ return;
37
+ }
38
+ for (const node of this.fastGroupNodes) {
39
+ node.refreshWidgets();
40
+ }
41
+ this.clearScheduledRun();
42
+ this.scheduleRun();
43
+ }
44
+ scheduleRun(ms = 500) {
45
+ if (this.runScheduledForMs && ms < this.runScheduledForMs) {
46
+ this.clearScheduledRun();
47
+ }
48
+ if (!this.runScheduledForMs && this.fastGroupNodes.length) {
49
+ this.runScheduledForMs = ms;
50
+ this.runScheduleTimeout = setTimeout(() => {
51
+ this.runScheduleAnimation = requestAnimationFrame(() => this.run());
52
+ }, ms);
53
+ }
54
+ }
55
+ clearScheduledRun() {
56
+ this.runScheduleTimeout && clearTimeout(this.runScheduleTimeout);
57
+ this.runScheduleAnimation && cancelAnimationFrame(this.runScheduleAnimation);
58
+ this.runScheduleTimeout = null;
59
+ this.runScheduleAnimation = null;
60
+ this.runScheduledForMs = null;
61
+ }
62
+ getBoundingsForAllNodes() {
63
+ if (!this.cachedNodeBoundings) {
64
+ this.cachedNodeBoundings = {};
65
+ for (const node of app.graph._nodes) {
66
+ this.cachedNodeBoundings[node.id] = node.getBounding();
67
+ }
68
+ setTimeout(() => {
69
+ this.cachedNodeBoundings = null;
70
+ }, 50);
71
+ }
72
+ return this.cachedNodeBoundings;
73
+ }
74
+ recomputeInsideNodesForGroup(group) {
75
+ const cachedBoundings = this.getBoundingsForAllNodes();
76
+ const nodes = group.graph._nodes;
77
+ group._nodes.length = 0;
78
+ for (const node of nodes) {
79
+ const node_bounding = cachedBoundings[node.id];
80
+ if (!node_bounding || !LiteGraph.overlapBounding(group._bounding, node_bounding)) {
81
+ continue;
82
+ }
83
+ group._nodes.push(node);
84
+ }
85
+ }
86
+ getGroupsUnsorted(now) {
87
+ const canvas = app.canvas;
88
+ const graph = app.graph;
89
+ if (!canvas.selected_group_moving &&
90
+ (!this.groupsUnsorted.length || now - this.msLastUnsorted > this.msThreshold)) {
91
+ this.groupsUnsorted = [...graph._groups];
92
+ for (const group of this.groupsUnsorted) {
93
+ this.recomputeInsideNodesForGroup(group);
94
+ group._rgthreeHasAnyActiveNode = group._nodes.some((n) => n.mode === LiteGraph.ALWAYS);
95
+ }
96
+ this.msLastUnsorted = now;
97
+ }
98
+ return this.groupsUnsorted;
99
+ }
100
+ getGroupsAlpha(now) {
101
+ const graph = app.graph;
102
+ if (!this.groupsSortedAlpha.length || now - this.msLastAlpha > this.msThreshold) {
103
+ this.groupsSortedAlpha = [...this.getGroupsUnsorted(now)].sort((a, b) => {
104
+ return a.title.localeCompare(b.title);
105
+ });
106
+ this.msLastAlpha = now;
107
+ }
108
+ return this.groupsSortedAlpha;
109
+ }
110
+ getGroupsPosition(now) {
111
+ const graph = app.graph;
112
+ if (!this.groupsSortedPosition.length || now - this.msLastPosition > this.msThreshold) {
113
+ this.groupsSortedPosition = [...this.getGroupsUnsorted(now)].sort((a, b) => {
114
+ const aY = Math.floor(a._pos[1] / 30);
115
+ const bY = Math.floor(b._pos[1] / 30);
116
+ if (aY == bY) {
117
+ const aX = Math.floor(a._pos[0] / 30);
118
+ const bX = Math.floor(b._pos[0] / 30);
119
+ return aX - bX;
120
+ }
121
+ return aY - bY;
122
+ });
123
+ this.msLastPosition = now;
124
+ }
125
+ return this.groupsSortedPosition;
126
+ }
127
+ getGroups(sort) {
128
+ const now = +new Date();
129
+ if (sort === "alphanumeric") {
130
+ return this.getGroupsAlpha(now);
131
+ }
132
+ if (sort === "position") {
133
+ return this.getGroupsPosition(now);
134
+ }
135
+ return this.getGroupsUnsorted(now);
136
+ }
137
+ }
138
+ export const SERVICE = new FastGroupsService();
rgthree-comfy/web/common/css/buttons.css ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ :not(#fakeid) .rgthree-button-reset {
2
+ position: relative;
3
+ appearance: none;
4
+ cursor: pointer;
5
+ border: 0;
6
+ background: transparent;
7
+ color: inherit;
8
+ padding: 0;
9
+ margin: 0;
10
+ }
11
+
12
+ :not(#fakeid) .rgthree-button {
13
+ --padding-top: 7px;
14
+ --padding-bottom: 9px;
15
+ --padding-x: 16px;
16
+ position: relative;
17
+ cursor: pointer;
18
+ border: 0;
19
+ border-radius: 0.25rem;
20
+ background: rgba(0, 0, 0, 0.5);
21
+ color: white;
22
+ font-family: system-ui, sans-serif;
23
+ font-size: 1rem;
24
+ line-height: 1;
25
+ white-space: nowrap;
26
+ text-decoration: none;
27
+ margin: 0.25rem;
28
+ box-shadow: 0px 0px 2px rgb(0, 0, 0);
29
+ background: #212121;
30
+ transition: all 0.1s ease-in-out;
31
+ padding: var(--padding-top) var(--padding-x) var(--padding-bottom);
32
+ display: inline-flex;
33
+ flex-direction: row;
34
+ align-items: center;
35
+ justify-content: center;
36
+ }
37
+ :not(#fakeid) .rgthree-button::before, :not(#fakeid) .rgthree-button::after {
38
+ content: "";
39
+ display: block;
40
+ position: absolute;
41
+ border-radius: 0.25rem;
42
+ left: 0;
43
+ top: 0;
44
+ width: 100%;
45
+ height: 100%;
46
+ box-shadow: inset 1px 1px 0px rgba(255, 255, 255, 0.12), inset -1px -1px 0px rgba(0, 0, 0, 0.75);
47
+ background: linear-gradient(to bottom, rgba(255, 255, 255, 0.06), rgba(0, 0, 0, 0.15));
48
+ mix-blend-mode: screen;
49
+ }
50
+ :not(#fakeid) .rgthree-button::after {
51
+ mix-blend-mode: multiply;
52
+ }
53
+ :not(#fakeid) .rgthree-button:hover {
54
+ background: #303030;
55
+ }
56
+ :not(#fakeid) .rgthree-button:active {
57
+ box-shadow: 0px 0px 0px rgba(0, 0, 0, 0);
58
+ background: #121212;
59
+ padding: calc(var(--padding-top) + 1px) calc(var(--padding-x) - 1px) calc(var(--padding-bottom) - 1px) calc(var(--padding-x) + 1px);
60
+ }
61
+ :not(#fakeid) .rgthree-button:active::before, :not(#fakeid) .rgthree-button:active::after {
62
+ box-shadow: 1px 1px 0px rgba(255, 255, 255, 0.15), inset 1px 1px 0px rgba(0, 0, 0, 0.5), inset 1px 3px 5px rgba(0, 0, 0, 0.33);
63
+ }
64
+ :not(#fakeid) .rgthree-button.-blue {
65
+ background: #346599 !important;
66
+ }
67
+ :not(#fakeid) .rgthree-button.-blue:hover {
68
+ background: #3b77b8 !important;
69
+ }
70
+ :not(#fakeid) .rgthree-button.-blue:active {
71
+ background: #1d5086 !important;
72
+ }
73
+ :not(#fakeid) .rgthree-button.-green {
74
+ background: linear-gradient(to bottom, rgba(255, 255, 255, 0.06), rgba(0, 0, 0, 0.15)), #14580b;
75
+ }
76
+ :not(#fakeid) .rgthree-button.-green:hover {
77
+ background: linear-gradient(to bottom, rgba(255, 255, 255, 0.06), rgba(0, 0, 0, 0.15)), #1a6d0f;
78
+ }
79
+ :not(#fakeid) .rgthree-button.-green:active {
80
+ background: linear-gradient(to bottom, rgba(0, 0, 0, 0.15), rgba(255, 255, 255, 0.06)), #0f3f09;
81
+ }
82
+ :not(#fakeid) .rgthree-button[disabled] {
83
+ box-shadow: none;
84
+ background: #666 !important;
85
+ color: #aaa;
86
+ pointer-events: none;
87
+ }
88
+ :not(#fakeid) .rgthree-button[disabled]::before, :not(#fakeid) .rgthree-button[disabled]::after {
89
+ display: none;
90
+ }
rgthree-comfy/web/common/css/dialog.css ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ @charset "UTF-8";
2
+ .rgthree-dialog {
3
+ outline: 0;
4
+ border: 0;
5
+ border-radius: 6px;
6
+ background: #414141;
7
+ color: #fff;
8
+ box-shadow: inset 1px 1px 0px rgba(255, 255, 255, 0.05), inset -1px -1px 0px rgba(0, 0, 0, 0.5), 2px 2px 20px rgb(0, 0, 0);
9
+ max-width: 800px;
10
+ box-sizing: border-box;
11
+ font-family: "Segoe UI", Tahoma, Geneva, Verdana, sans-serif;
12
+ font-size: 1rem;
13
+ padding: 0;
14
+ max-height: calc(100% - 32px);
15
+ }
16
+ .rgthree-dialog *, .rgthree-dialog *::before, .rgthree-dialog *::after {
17
+ box-sizing: inherit;
18
+ }
19
+
20
+ .rgthree-dialog-container > * {
21
+ padding: 8px 16px;
22
+ }
23
+ .rgthree-dialog-container > *:first-child {
24
+ padding-top: 16px;
25
+ }
26
+ .rgthree-dialog-container > *:last-child {
27
+ padding-bottom: 16px;
28
+ }
29
+
30
+ .rgthree-dialog.-iconed::after {
31
+ content: "";
32
+ font-size: 276px;
33
+ position: absolute;
34
+ right: 0px;
35
+ bottom: 0px;
36
+ opacity: 0.15;
37
+ display: block;
38
+ width: 237px;
39
+ overflow: hidden;
40
+ height: 186px;
41
+ line-height: 1;
42
+ pointer-events: none;
43
+ z-index: -1;
44
+ }
45
+
46
+ .rgthree-dialog.-iconed.-help::after {
47
+ content: "🛟";
48
+ }
49
+
50
+ .rgthree-dialog.-iconed.-settings::after {
51
+ content: "⚙️";
52
+ }
53
+
54
+ @media (max-width: 832px) {
55
+ .rgthree-dialog {
56
+ max-width: calc(100% - 32px);
57
+ }
58
+ }
59
+ .rgthree-dialog-container-title {
60
+ display: flex;
61
+ flex-direction: row;
62
+ align-items: center;
63
+ justify-content: start;
64
+ }
65
+
66
+ .rgthree-dialog-container-title > svg:first-child {
67
+ width: 36px;
68
+ height: 36px;
69
+ margin-right: 16px;
70
+ }
71
+
72
+ .rgthree-dialog-container-title h2 {
73
+ font-size: 1.375rem;
74
+ margin: 0;
75
+ font-weight: bold;
76
+ }
77
+
78
+ .rgthree-dialog-container-title h2 small {
79
+ font-size: 0.8125rem;
80
+ font-weight: normal;
81
+ opacity: 0.75;
82
+ }
83
+
84
+ .rgthree-dialog-container-content {
85
+ overflow: auto;
86
+ max-height: calc(100vh - 200px); /* Arbitrary height to copensate for margin, title, and footer.*/
87
+ }
88
+
89
+ .rgthree-dialog-container-content p {
90
+ font-size: 0.8125rem;
91
+ margin-top: 0;
92
+ }
93
+
94
+ .rgthree-dialog-container-content ul li p {
95
+ margin-bottom: 4px;
96
+ }
97
+
98
+ .rgthree-dialog-container-content ul li p + p {
99
+ margin-top: 0.5em;
100
+ }
101
+
102
+ .rgthree-dialog-container-content ul li ul {
103
+ margin-top: 0.5em;
104
+ margin-bottom: 1em;
105
+ }
106
+
107
+ .rgthree-dialog-container-content p code {
108
+ display: inline-block;
109
+ padding: 2px 4px;
110
+ margin: 0px 2px;
111
+ border: 1px solid rgba(255, 255, 255, 0.25);
112
+ border-radius: 3px;
113
+ background: rgba(255, 255, 255, 0.1);
114
+ }
115
+
116
+ .rgthree-dialog-container-footer {
117
+ display: flex;
118
+ align-items: center;
119
+ justify-content: center;
120
+ }
121
+
122
+ body.rgthree-dialog-open > *:not(.rgthree-dialog):not(.rgthree-top-messages-container) {
123
+ filter: blur(5px);
124
+ }
rgthree-comfy/web/common/css/dialog_model_info.css ADDED
@@ -0,0 +1,333 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .rgthree-info-dialog {
2
+ width: 90vw;
3
+ max-width: 960px;
4
+ }
5
+ .rgthree-info-dialog .rgthree-info-area {
6
+ list-style: none;
7
+ padding: 0;
8
+ margin: 0;
9
+ display: flex;
10
+ }
11
+ .rgthree-info-dialog .rgthree-info-area > li {
12
+ display: inline-flex;
13
+ margin: 0;
14
+ vertical-align: top;
15
+ }
16
+ .rgthree-info-dialog .rgthree-info-area > li + li {
17
+ margin-left: 6px;
18
+ }
19
+ .rgthree-info-dialog .rgthree-info-area > li:not(.-link) + li.-link {
20
+ margin-left: auto;
21
+ }
22
+ .rgthree-info-dialog .rgthree-info-area > li.rgthree-info-tag > * {
23
+ min-height: 24px;
24
+ border-radius: 4px;
25
+ line-height: 1;
26
+ color: rgba(255, 255, 255, 0.85);
27
+ background: rgb(69, 92, 85);
28
+ font-size: 14px;
29
+ font-weight: bold;
30
+ text-decoration: none;
31
+ display: flex;
32
+ height: 1.6em;
33
+ padding-left: 0.5em;
34
+ padding-right: 0.5em;
35
+ padding-bottom: 0.1em;
36
+ align-content: center;
37
+ justify-content: center;
38
+ align-items: center;
39
+ box-shadow: inset 0px 0px 0 1px rgba(0, 0, 0, 0.5);
40
+ }
41
+ .rgthree-info-dialog .rgthree-info-area > li.rgthree-info-tag > * > svg {
42
+ width: 16px;
43
+ height: 16px;
44
+ }
45
+ .rgthree-info-dialog .rgthree-info-area > li.rgthree-info-tag > * > svg:last-child {
46
+ margin-left: 0.5em;
47
+ }
48
+ .rgthree-info-dialog .rgthree-info-area > li.rgthree-info-tag > *[href] {
49
+ box-shadow: inset 0px 1px 0px rgba(255, 255, 255, 0.25), inset 0px -1px 0px rgba(0, 0, 0, 0.66);
50
+ }
51
+ .rgthree-info-dialog .rgthree-info-area > li.rgthree-info-tag > *:empty {
52
+ display: none;
53
+ }
54
+ .rgthree-info-dialog .rgthree-info-area > li.-type > * {
55
+ background: rgb(73, 54, 94);
56
+ color: rgb(228, 209, 248);
57
+ }
58
+ .rgthree-info-dialog .rgthree-info-area > li.rgthree-info-menu {
59
+ margin-left: auto;
60
+ }
61
+ :not(#fakeid) .rgthree-info-dialog .rgthree-info-area > li.rgthree-info-menu .rgthree-button {
62
+ margin: 0;
63
+ min-height: 24px;
64
+ padding: 0 12px;
65
+ }
66
+ .rgthree-info-dialog .rgthree-info-area > li.rgthree-info-menu svg {
67
+ width: 16px;
68
+ height: 16px;
69
+ }
70
+ .rgthree-info-dialog .rgthree-info-table {
71
+ border-collapse: collapse;
72
+ margin: 16px 0px;
73
+ width: 100%;
74
+ font-size: 12px;
75
+ }
76
+ .rgthree-info-dialog .rgthree-info-table tr.editable button {
77
+ display: flex;
78
+ width: 28px;
79
+ height: 28px;
80
+ align-items: center;
81
+ justify-content: center;
82
+ }
83
+ .rgthree-info-dialog .rgthree-info-table tr.editable button svg + svg {
84
+ display: none;
85
+ }
86
+ .rgthree-info-dialog .rgthree-info-table tr.editable.-rgthree-editing button svg {
87
+ display: none;
88
+ }
89
+ .rgthree-info-dialog .rgthree-info-table tr.editable.-rgthree-editing button svg + svg {
90
+ display: inline-block;
91
+ }
92
+ .rgthree-info-dialog .rgthree-info-table td {
93
+ position: relative;
94
+ border: 1px solid rgba(255, 255, 255, 0.25);
95
+ padding: 0;
96
+ vertical-align: top;
97
+ }
98
+ .rgthree-info-dialog .rgthree-info-table td:first-child {
99
+ background: rgba(255, 255, 255, 0.075);
100
+ width: 10px;
101
+ }
102
+ .rgthree-info-dialog .rgthree-info-table td:first-child > *:first-child {
103
+ white-space: nowrap;
104
+ padding-right: 32px;
105
+ }
106
+ .rgthree-info-dialog .rgthree-info-table td:first-child small {
107
+ display: block;
108
+ margin-top: 2px;
109
+ opacity: 0.75;
110
+ }
111
+ .rgthree-info-dialog .rgthree-info-table td:first-child small > [data-action] {
112
+ text-decoration: underline;
113
+ cursor: pointer;
114
+ }
115
+ .rgthree-info-dialog .rgthree-info-table td:first-child small > [data-action]:hover {
116
+ text-decoration: none;
117
+ }
118
+ .rgthree-info-dialog .rgthree-info-table td a, .rgthree-info-dialog .rgthree-info-table td a:hover, .rgthree-info-dialog .rgthree-info-table td a:visited {
119
+ color: inherit;
120
+ }
121
+ .rgthree-info-dialog .rgthree-info-table td svg {
122
+ width: 1.3333em;
123
+ height: 1.3333em;
124
+ vertical-align: -0.285em;
125
+ }
126
+ .rgthree-info-dialog .rgthree-info-table td svg.logo-civitai {
127
+ margin-right: 0.3333em;
128
+ }
129
+ .rgthree-info-dialog .rgthree-info-table td > *:first-child {
130
+ display: block;
131
+ padding: 6px 10px;
132
+ }
133
+ .rgthree-info-dialog .rgthree-info-table td > input, .rgthree-info-dialog .rgthree-info-table td > textarea {
134
+ padding: 5px 10px;
135
+ border: 0;
136
+ box-shadow: inset 1px 1px 5px 0px rgba(0, 0, 0, 0.5);
137
+ font: inherit;
138
+ appearance: none;
139
+ background: #fff;
140
+ color: #121212;
141
+ resize: vertical;
142
+ }
143
+ .rgthree-info-dialog .rgthree-info-table td > input:only-child, .rgthree-info-dialog .rgthree-info-table td > textarea:only-child {
144
+ width: 100%;
145
+ }
146
+ :not(#fakeid) .rgthree-info-dialog .rgthree-info-table td .rgthree-button[data-action=fetch-civitai] {
147
+ font-size: inherit;
148
+ padding: 6px 16px;
149
+ margin: 2px;
150
+ }
151
+ .rgthree-info-dialog .rgthree-info-table tr[data-field-name=userNote] td > span:first-child {
152
+ white-space: pre;
153
+ }
154
+ .rgthree-info-dialog .rgthree-info-table tr.rgthree-info-table-break-row td {
155
+ border: 0;
156
+ background: transparent;
157
+ padding: 12px 4px 4px;
158
+ font-size: 1.2em;
159
+ }
160
+ .rgthree-info-dialog .rgthree-info-table tr.rgthree-info-table-break-row td > small {
161
+ font-style: italic;
162
+ opacity: 0.66;
163
+ }
164
+ .rgthree-info-dialog .rgthree-info-table tr.rgthree-info-table-break-row td:empty {
165
+ padding: 4px;
166
+ }
167
+ .rgthree-info-dialog .rgthree-info-table td .-help {
168
+ border: 1px solid currentColor;
169
+ position: absolute;
170
+ right: 5px;
171
+ top: 6px;
172
+ line-height: 1;
173
+ font-size: 11px;
174
+ width: 12px;
175
+ height: 12px;
176
+ border-radius: 8px;
177
+ display: flex;
178
+ align-content: center;
179
+ justify-content: center;
180
+ cursor: help;
181
+ }
182
+ .rgthree-info-dialog .rgthree-info-table td .-help::before {
183
+ content: "?";
184
+ }
185
+ .rgthree-info-dialog .rgthree-info-table td > ul.rgthree-info-trained-words-list {
186
+ list-style: none;
187
+ padding: 2px 8px;
188
+ margin: 0;
189
+ display: flex;
190
+ flex-direction: row;
191
+ flex-wrap: wrap;
192
+ max-height: 15vh;
193
+ overflow: auto;
194
+ }
195
+ .rgthree-info-dialog .rgthree-info-table td > ul.rgthree-info-trained-words-list > li {
196
+ display: inline-flex;
197
+ margin: 2px;
198
+ vertical-align: top;
199
+ border-radius: 4px;
200
+ line-height: 1;
201
+ color: rgba(255, 255, 255, 0.85);
202
+ background: rgb(73, 91, 106);
203
+ font-size: 1.2em;
204
+ font-weight: 600;
205
+ text-decoration: none;
206
+ display: flex;
207
+ height: 1.6em;
208
+ align-content: center;
209
+ justify-content: center;
210
+ align-items: center;
211
+ box-shadow: inset 0px 0px 0 1px rgba(0, 0, 0, 0.5);
212
+ cursor: pointer;
213
+ white-space: nowrap;
214
+ max-width: 183px;
215
+ }
216
+ .rgthree-info-dialog .rgthree-info-table td > ul.rgthree-info-trained-words-list > li:hover {
217
+ background: rgb(68, 109, 142);
218
+ }
219
+ .rgthree-info-dialog .rgthree-info-table td > ul.rgthree-info-trained-words-list > li > svg {
220
+ width: auto;
221
+ height: 1.2em;
222
+ }
223
+ .rgthree-info-dialog .rgthree-info-table td > ul.rgthree-info-trained-words-list > li > span {
224
+ padding-left: 0.5em;
225
+ padding-right: 0.5em;
226
+ padding-bottom: 0.1em;
227
+ text-overflow: ellipsis;
228
+ overflow: hidden;
229
+ }
230
+ .rgthree-info-dialog .rgthree-info-table td > ul.rgthree-info-trained-words-list > li > small {
231
+ align-self: stretch;
232
+ display: flex;
233
+ align-items: center;
234
+ justify-content: center;
235
+ padding: 0 0.5em;
236
+ background: rgba(0, 0, 0, 0.2);
237
+ }
238
+ .rgthree-info-dialog .rgthree-info-table td > ul.rgthree-info-trained-words-list > li.-rgthree-is-selected {
239
+ background: rgb(42, 126, 193);
240
+ }
241
+ .rgthree-info-dialog .rgthree-info-images {
242
+ list-style: none;
243
+ padding: 0;
244
+ margin: 0;
245
+ scroll-snap-type: x mandatory;
246
+ display: flex;
247
+ flex-direction: row;
248
+ overflow: auto;
249
+ }
250
+ .rgthree-info-dialog .rgthree-info-images > li {
251
+ scroll-snap-align: start;
252
+ max-width: 90%;
253
+ flex: 0 0 auto;
254
+ display: flex;
255
+ align-items: center;
256
+ justify-content: center;
257
+ flex-direction: column;
258
+ overflow: hidden;
259
+ padding: 0;
260
+ margin: 6px;
261
+ font-size: 0;
262
+ position: relative;
263
+ }
264
+ .rgthree-info-dialog .rgthree-info-images > li figure {
265
+ margin: 0;
266
+ position: static;
267
+ }
268
+ .rgthree-info-dialog .rgthree-info-images > li figure figcaption {
269
+ position: absolute;
270
+ left: 0;
271
+ width: 100%;
272
+ bottom: 0;
273
+ padding: 12px;
274
+ font-size: 12px;
275
+ background: rgba(0, 0, 0, 0.85);
276
+ opacity: 0;
277
+ transform: translateY(50px);
278
+ transition: all 0.25s ease-in-out;
279
+ }
280
+ .rgthree-info-dialog .rgthree-info-images > li figure figcaption > span {
281
+ display: inline-block;
282
+ padding: 2px 4px;
283
+ margin: 2px;
284
+ border-radius: 2px;
285
+ border: 1px solid rgba(255, 255, 255, 0.2);
286
+ word-break: break-word;
287
+ }
288
+ .rgthree-info-dialog .rgthree-info-images > li figure figcaption > span label {
289
+ display: inline;
290
+ padding: 0;
291
+ margin: 0;
292
+ opacity: 0.5;
293
+ pointer-events: none;
294
+ user-select: none;
295
+ }
296
+ .rgthree-info-dialog .rgthree-info-images > li figure figcaption > span a {
297
+ color: inherit;
298
+ text-decoration: underline;
299
+ }
300
+ .rgthree-info-dialog .rgthree-info-images > li figure figcaption > span a:hover {
301
+ text-decoration: none;
302
+ }
303
+ .rgthree-info-dialog .rgthree-info-images > li figure figcaption > span a svg {
304
+ height: 10px;
305
+ margin-left: 4px;
306
+ fill: currentColor;
307
+ }
308
+ .rgthree-info-dialog .rgthree-info-images > li figure figcaption:empty {
309
+ text-align: center;
310
+ }
311
+ .rgthree-info-dialog .rgthree-info-images > li figure figcaption:empty::before {
312
+ content: "No data.";
313
+ }
314
+ .rgthree-info-dialog .rgthree-info-images > li:hover figure figcaption {
315
+ opacity: 1;
316
+ transform: translateY(0px);
317
+ }
318
+ .rgthree-info-dialog .rgthree-info-images > li .rgthree-info-table {
319
+ width: calc(100% - 16px);
320
+ }
321
+ .rgthree-info-dialog .rgthree-info-civitai-link {
322
+ margin: 8px;
323
+ color: #eee;
324
+ }
325
+ .rgthree-info-dialog .rgthree-info-civitai-link a, .rgthree-info-dialog .rgthree-info-civitai-link a:hover, .rgthree-info-dialog .rgthree-info-civitai-link a:visited {
326
+ color: inherit;
327
+ text-decoration: none;
328
+ }
329
+ .rgthree-info-dialog .rgthree-info-civitai-link > svg {
330
+ width: 16px;
331
+ height: 16px;
332
+ margin-right: 8px;
333
+ }
rgthree-comfy/web/common/css/menu.css ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .rgthree-menu {
2
+ list-style: none;
3
+ padding: 0;
4
+ margin: 0;
5
+ position: fixed;
6
+ z-index: 999999;
7
+ pointer-events: none;
8
+ opacity: 0;
9
+ transition: opacity 0.08s ease-in-out;
10
+ color: #dde;
11
+ background-color: #111;
12
+ font-size: 12px;
13
+ box-shadow: 0 0 10px black !important;
14
+ }
15
+ .rgthree-menu > li {
16
+ position: relative;
17
+ padding: 4px 6px;
18
+ z-index: 9999;
19
+ white-space: nowrap;
20
+ }
21
+ .rgthree-menu > li[role=button] {
22
+ background-color: var(--comfy-menu-bg) !important;
23
+ color: var(--input-text);
24
+ cursor: pointer;
25
+ }
26
+ .rgthree-menu > li[role=button]:hover {
27
+ filter: brightness(155%);
28
+ }
29
+ .rgthree-menu[state^=measuring] {
30
+ display: block;
31
+ opacity: 0;
32
+ }
33
+ .rgthree-menu[state=open] {
34
+ display: block;
35
+ opacity: 1;
36
+ pointer-events: all;
37
+ }
38
+
39
+ .rgthree-top-menu {
40
+ box-sizing: border-box;
41
+ white-space: nowrap;
42
+ background: var(--content-bg);
43
+ color: var(--content-fg);
44
+ display: flex;
45
+ flex-direction: column;
46
+ }
47
+ .rgthree-top-menu * {
48
+ box-sizing: inherit;
49
+ }
50
+ .rgthree-top-menu menu {
51
+ list-style: none;
52
+ padding: 0;
53
+ margin: 0;
54
+ }
55
+ .rgthree-top-menu menu > li:not(#fakeid) {
56
+ list-style: none;
57
+ padding: 0;
58
+ margin: 0;
59
+ }
60
+ .rgthree-top-menu menu > li:not(#fakeid) > button {
61
+ cursor: pointer;
62
+ padding: 8px 12px 8px 8px;
63
+ width: 100%;
64
+ text-align: start;
65
+ display: flex;
66
+ flex-direction: row;
67
+ align-items: center;
68
+ justify-content: start;
69
+ }
70
+ .rgthree-top-menu menu > li:not(#fakeid) > button:hover {
71
+ background-color: var(--comfy-input-bg);
72
+ }
73
+ .rgthree-top-menu menu > li:not(#fakeid) > button svg {
74
+ height: 16px;
75
+ width: auto;
76
+ margin-inline-end: 0.6em;
77
+ }
78
+ .rgthree-top-menu menu > li:not(#fakeid) > button svg.github-star {
79
+ fill: rgb(227, 179, 65);
80
+ }
81
+ .rgthree-top-menu menu > li:not(#fakeid).rgthree-message {
82
+ min-height: 32px;
83
+ }
84
+ .rgthree-top-menu menu > li:not(#fakeid).rgthree-message > span {
85
+ padding: 8px 12px;
86
+ display: block;
87
+ width: 100%;
88
+ text-align: center;
89
+ font-style: italic;
90
+ font-size: 12px;
91
+ }
rgthree-comfy/web/common/css/pages_base.css ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /* Base page styles shared by rgthree's standalone web pages. */
+ html {
2
+ font-size: 100%;
3
+ overflow-y: scroll;
4
+ -webkit-text-size-adjust: 100%;
5
+ -ms-text-size-adjust: 100%;
6
+ box-sizing: border-box;
7
+ }
8
+
9
+ /* Inherit border-box sizing everywhere (set on html above). */
+ *, *:before, *:after {
10
+ box-sizing: inherit;
11
+ }
12
+
13
+ /* Layout constants; the body's top padding below is derived from these. */
+ :root {
14
+ --header-height: 56px;
15
+ --progress-height: 12px;
16
+ }
17
+
18
+ /* Strip all user-agent button styling; buttons are styled per component. */
+ button {
19
+ all: unset;
20
+ }
21
+
22
+ /* Overlays a 1px beveled edge: light on top/left, dark on bottom/right. */
+ .-bevel {
23
+ position: relative;
24
+ }
25
+
26
+ .-bevel::before {
27
+ content: "";
28
+ position: absolute;
29
+ left: 0;
30
+ top: 0;
31
+ width: 100%;
32
+ height: 100%;
33
+ /* The shorthand's `red` only supplies width/style here; the color is
+ fully overridden by the border-color on the next line. */
+ border: 1px solid red;
34
+ border-color: rgba(255, 255, 255, 0.15) rgba(255, 255, 255, 0.15) rgba(0, 0, 0, 0.5) rgba(0, 0, 0, 0.5);
35
+ z-index: 5;
36
+ /* Purely decorative overlay; never intercept clicks. */
+ pointer-events: none;
37
+ }
38
+
39
+ /* Dark page background; top padding reserves room for the fixed header
+ plus the progress strip. */
+ body {
40
+ background: #202020;
41
+ font-family: Arial, sans-serif;
42
+ font-size: 1rem;
43
+ font-weight: 400;
44
+ margin: 0;
45
+ padding-top: calc(var(--header-height) + var(--progress-height));
46
+ color: #ffffff;
47
+ display: flex;
48
+ flex-direction: column;
49
+ align-items: center;
50
+ justify-content: start;
51
+ }
52
+
53
+ /* Fixed, full-width top bar. */
+ .app-header {
54
+ height: var(--header-height);
55
+ padding: 0;
56
+ position: fixed;
57
+ z-index: 99;
58
+ top: 0;
59
+ left: 0;
60
+ width: 100%;
61
+ background: #353535;
62
+ display: flex;
63
+ flex-direction: row;
64
+ align-items: center;
65
+ justify-content: start;
66
+ }
rgthree-comfy/web/common/media/rgthree.svg ADDED
rgthree-comfy/web/common/media/svgs.js ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { createElement as $el } from "../utils_dom.js";
2
+ export const logoRgthree = `<svg viewBox="0 0 256 256" fill="currentColor" class="rgthree-logo">
3
+ <path d="M88.503,158.997 L152.731,196.103 L152.738,196.092 L152.762,196.103 L152.769,196.106 L152.771,196.103 L183.922,142.084 L174.153,136.437 L148.611,180.676 L101.512,153.484 L132.193,30.415 L156.124,71.869 L165.896,66.225 L128.002,0.59 "></path>
4
+ <path d="M55.586,148.581l13.44,47.521l0.014,0.051l0.168-0.051l10.689-3.022l-6.589-23.313l45.609,26.335l0.087,0.051l0.027-0.051 l5.617-9.718l-42.648-24.622l35.771-143.45L33.232,164.729l9.77,5.645L55.586,148.581z M87.394,93.484l-16.708,67.018l-5.018-17.747 l-8.028,2.27L87.394,93.484z"></path>
5
+ <path d="M189.85,107.717 L137.892,137.718 L143.532,147.49 L185.723,123.133 L231.109,201.746 L24.895,201.746 L37.363,180.146 L27.592,174.505 L5.347,213.03 L250.653,213.03 "></path>
6
+ <path d="M5.347,247.299v8.111h245.307v-8.111l-41.94-0.003c-1.336,0-2.404-1.065-2.441-2.396v-12.14 c0.037-1.315,1.089-2.368,2.41-2.385h41.972v-8.11H5.347v8.11h41.951c1.338,0.017,2.427,1.104,2.427,2.449v12.01 c0,1.365-1.105,2.462-2.457,2.462L5.347,247.299z M139.438,247.296c-1.334,0-2.406-1.065-2.439-2.396v-12.14 c0.033-1.315,1.085-2.368,2.41-2.385h46.415c1.335,0.017,2.425,1.104,2.425,2.449v12.01c0,1.365-1.103,2.462-2.459,2.462H139.438z M70.193,247.296c-1.339,0-2.408-1.065-2.441-2.396v-12.14c0.033-1.315,1.086-2.368,2.407-2.385h46.418 c1.336,0.017,2.425,1.104,2.425,2.449v12.01c0,1.365-1.103,2.462-2.458,2.462H70.193z"></path>
7
+ </svg>`;
8
+ export const github = `<svg viewBox="0 0 16 16" fill="currentColor" class="github-logo">
9
+ <path d="M8 0c4.42 0 8 3.58 8 8a8.013 8.013 0 0 1-5.45 7.59c-.4.08-.55-.17-.55-.38 0-.27.01-1.13.01-2.2 0-.75-.25-1.23-.54-1.48 1.78-.2 3.65-.88 3.65-3.95 0-.88-.31-1.59-.82-2.15.08-.2.36-1.02-.08-2.12 0 0-.67-.22-2.2.82-.64-.18-1.32-.27-2-.27-.68 0-1.36.09-2 .27-1.53-1.03-2.2-.82-2.2-.82-.44 1.1-.16 1.92-.08 2.12-.51.56-.82 1.28-.82 2.15 0 3.06 1.86 3.75 3.64 3.95-.23.2-.44.55-.51 1.07-.46.21-1.61.55-2.33-.66-.15-.24-.6-.83-1.23-.82-.67.01-.27.38.01.53.34.19.73.9.82 1.13.16.45.68 1.31 2.69.94 0 .67.01 1.3.01 1.49 0 .21-.15.45-.55.38A7.995 7.995 0 0 1 0 8c0-4.42 3.58-8 8-8Z"></path>
10
+ </svg>`;
11
+ export const iconStarFilled = `<svg viewBox="0 0 16 16" fill="currentColor" class="github-star">
12
+ <path d="M8 .25a.75.75 0 0 1 .673.418l1.882 3.815 4.21.612a.75.75 0 0 1 .416 1.279l-3.046 2.97.719 4.192a.751.751 0 0 1-1.088.791L8 12.347l-3.766 1.98a.75.75 0 0 1-1.088-.79l.72-4.194L.818 6.374a.75.75 0 0 1 .416-1.28l4.21-.611L7.327.668A.75.75 0 0 1 8 .25Z"></path>
13
+ </svg>`;
14
+ export const iconReplace = `<svg viewBox="0 0 52 52" fill="currentColor">
15
+ <path d="M20,37.5c0-0.8-0.7-1.5-1.5-1.5h-15C2.7,36,2,36.7,2,37.5v11C2,49.3,2.7,50,3.5,50h15c0.8,0,1.5-0.7,1.5-1.5 V37.5z"/>
16
+ <path d="M8.1,22H3.2c-1,0-1.5,0.9-0.9,1.4l8,8.3c0.4,0.3,1,0.3,1.4,0l8-8.3c0.6-0.6,0.1-1.4-0.9-1.4h-4.7 c0-5,4.9-10,9.9-10V6C15,6,8.1,13,8.1,22z"/>
17
+ <path d="M41.8,20.3c-0.4-0.3-1-0.3-1.4,0l-8,8.3c-0.6,0.6-0.1,1.4,0.9,1.4h4.8c0,6-4.1,10-10.1,10v6 c9,0,16.1-7,16.1-16H49c1,0,1.5-0.9,0.9-1.4L41.8,20.3z"/>
18
+ <path d="M50,3.5C50,2.7,49.3,2,48.5,2h-15C32.7,2,32,2.7,32,3.5v11c0,0.8,0.7,1.5,1.5,1.5h15c0.8,0,1.5-0.7,1.5-1.5 V3.5z"/>
19
+ </svg>`;
20
+ export const iconNode = `<svg viewBox="0 -0.5 25 25" fill="none">
21
+ <path fill-rule="evenodd" clip-rule="evenodd" d="M15.5 19H9.5C7.29086 19 5.5 17.2091 5.5 15V9C5.5 6.79086 7.29086 5 9.5 5H15.5C17.7091 5 19.5 6.79086 19.5 9V15C19.5 17.2091 17.7091 19 15.5 19Z" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
22
+ <path d="M19.5 9.75C19.9142 9.75 20.25 9.41421 20.25 9C20.25 8.58579 19.9142 8.25 19.5 8.25V9.75ZM5.5 8.25C5.08579 8.25 4.75 8.58579 4.75 9C4.75 9.41421 5.08579 9.75 5.5 9.75V8.25ZM11.5 14.25C11.0858 14.25 10.75 14.5858 10.75 15C10.75 15.4142 11.0858 15.75 11.5 15.75V14.25ZM13.5 15.75C13.9142 15.75 14.25 15.4142 14.25 15C14.25 14.5858 13.9142 14.25 13.5 14.25V15.75ZM19.5 8.25H5.5V9.75H19.5V8.25ZM11.5 15.75H13.5V14.25H11.5V15.75Z" fill="currentColor" />
23
+ </svg>`;
24
+ export const iconGear = `<svg viewBox="0 0 24 24" fill="currentColor">
25
+ <path fill-rule="evenodd" clip-rule="evenodd" d="M12.7848 0.449982C13.8239 0.449982 14.7167 1.16546 14.9122 2.15495L14.9991 2.59495C15.3408 4.32442 17.1859 5.35722 18.9016 4.7794L19.3383 4.63233C20.3199 4.30175 21.4054 4.69358 21.9249 5.56605L22.7097 6.88386C23.2293 7.75636 23.0365 8.86366 22.2504 9.52253L21.9008 9.81555C20.5267 10.9672 20.5267 13.0328 21.9008 14.1844L22.2504 14.4774C23.0365 15.1363 23.2293 16.2436 22.7097 17.1161L21.925 18.4339C21.4054 19.3064 20.3199 19.6982 19.3382 19.3676L18.9017 19.2205C17.1859 18.6426 15.3408 19.6754 14.9991 21.405L14.9122 21.845C14.7167 22.8345 13.8239 23.55 12.7848 23.55H11.2152C10.1761 23.55 9.28331 22.8345 9.08781 21.8451L9.00082 21.4048C8.65909 19.6754 6.81395 18.6426 5.09822 19.2205L4.66179 19.3675C3.68016 19.6982 2.59465 19.3063 2.07505 18.4338L1.2903 17.1161C0.770719 16.2436 0.963446 15.1363 1.74956 14.4774L2.09922 14.1844C3.47324 13.0327 3.47324 10.9672 2.09922 9.8156L1.74956 9.52254C0.963446 8.86366 0.77072 7.75638 1.2903 6.8839L2.07508 5.56608C2.59466 4.69359 3.68014 4.30176 4.66176 4.63236L5.09831 4.77939C6.81401 5.35722 8.65909 4.32449 9.00082 2.59506L9.0878 2.15487C9.28331 1.16542 10.176 0.449982 11.2152 0.449982H12.7848ZM12 15.3C13.8225 15.3 15.3 13.8225 15.3 12C15.3 10.1774 13.8225 8.69998 12 8.69998C10.1774 8.69998 8.69997 10.1774 8.69997 12C8.69997 13.8225 10.1774 15.3 12 15.3Z" />
26
+ </svg>`;
27
+ export const checkmark = `<svg viewBox="0 0 32 32" fill="currentColor" class="icon-checkmark">
28
+ <g transform="translate(-518.000000, -1039.000000)">
29
+ <path d="M548.783,1040.2 C547.188,1038.57 544.603,1038.57 543.008,1040.2 L528.569,1054.92 L524.96,1051.24 C523.365,1049.62 520.779,1049.62 519.185,1051.24 C517.59,1052.87 517.59,1055.51 519.185,1057.13 L525.682,1063.76 C527.277,1065.39 529.862,1065.39 531.457,1063.76 L548.783,1046.09 C550.378,1044.46 550.378,1041.82 548.783,1040.2"></path>
30
+ </g>
31
+ </svg>`;
32
+ export const logoCivitai = `<svg viewBox="0 0 178 178" class="logo-civitai">
33
+ <defs>
34
+ <linearGradient id="bgblue" gradientUnits="userSpaceOnUse" x1="89.3" y1="-665.5" x2="89.3" y2="-841.1" gradientTransform="matrix(1 0 0 -1 0 -664)">
35
+ <stop offset="0" style="stop-color:#1284F7"/>
36
+ <stop offset="1" style="stop-color:#0A20C9"/>
37
+ </linearGradient>
38
+ </defs>
39
+ <path fill="#000" d="M13.3,45.4v87.7l76,43.9l76-43.9V45.4l-76-43.9L13.3,45.4z"/>
40
+ <path style="fill:url(#bgblue);" d="M89.3,29.2l52,30v60l-52,30l-52-30v-60 L89.3,29.2 M89.3,1.5l-76,43.9v87.8l76,43.9l76-43.9V45.4L89.3,1.5z" />
41
+ <path fill="#FFF" d="M104.1,97.2l-14.9,8.5l-14.9-8.5v-17l14.9-8.5l14.9,8.5h18.2V69.7l-33-19l-33,19v38.1l33,19l33-19V97.2H104.1z" />
42
+ </svg>`;
43
+ export const iconOutLink = `<svg viewBox="0 0 32 32">
44
+ <path d="M 18 5 L 18 7 L 23.5625 7 L 11.28125 19.28125 L 12.71875 20.71875 L 25 8.4375 L 25 14 L 27 14 L 27 5 Z M 5 9 L 5 27 L 23 27 L 23 14 L 21 16 L 21 25 L 7 25 L 7 11 L 16 11 L 18 9 Z"></path>
45
+ </svg>`;
46
+ export const link = `<svg viewBox="0 0 640 512">
47
+ <path d="M598.6 41.41C570.1 13.8 534.8 0 498.6 0s-72.36 13.8-99.96 41.41l-43.36 43.36c15.11 8.012 29.47 17.58 41.91 30.02c3.146 3.146 5.898 6.518 8.742 9.838l37.96-37.96C458.5 72.05 477.1 64 498.6 64c20.67 0 40.1 8.047 54.71 22.66c14.61 14.61 22.66 34.04 22.66 54.71s-8.049 40.1-22.66 54.71l-133.3 133.3C405.5 343.1 386 352 365.4 352s-40.1-8.048-54.71-22.66C296 314.7 287.1 295.3 287.1 274.6s8.047-40.1 22.66-54.71L314.2 216.4C312.1 212.5 309.9 208.5 306.7 205.3C298.1 196.7 286.8 192 274.6 192c-11.93 0-23.1 4.664-31.61 12.97c-30.71 53.96-23.63 123.6 22.39 169.6C293 402.2 329.2 416 365.4 416c36.18 0 72.36-13.8 99.96-41.41L598.6 241.3c28.45-28.45 42.24-66.01 41.37-103.3C639.1 102.1 625.4 68.16 598.6 41.41zM234 387.4L196.1 425.3C181.5 439.1 162 448 141.4 448c-20.67 0-40.1-8.047-54.71-22.66c-14.61-14.61-22.66-34.04-22.66-54.71s8.049-40.1 22.66-54.71l133.3-133.3C234.5 168 253.1 160 274.6 160s40.1 8.048 54.71 22.66c14.62 14.61 22.66 34.04 22.66 54.71s-8.047 40.1-22.66 54.71L325.8 295.6c2.094 3.939 4.219 7.895 7.465 11.15C341.9 315.3 353.3 320 365.4 320c11.93 0 23.1-4.664 31.61-12.97c30.71-53.96 23.63-123.6-22.39-169.6C346.1 109.8 310.8 96 274.6 96C238.4 96 202.3 109.8 174.7 137.4L41.41 270.7c-27.6 27.6-41.41 63.78-41.41 99.96c-.0001 36.18 13.8 72.36 41.41 99.97C69.01 498.2 105.2 512 141.4 512c36.18 0 72.36-13.8 99.96-41.41l43.36-43.36c-15.11-8.012-29.47-17.58-41.91-30.02C239.6 394.1 236.9 390.7 234 387.4z"/>
48
+ </svg>`;
49
+ export const pencil = `<svg viewBox="0 0 24 24">
50
+ <path d="M 16.9375 1.0625 L 3.875 14.125 L 1.0742188 22.925781 L 9.875 20.125 L 22.9375 7.0625 C 22.9375 7.0625 22.8375 4.9615 20.9375 3.0625 C 19.0375 1.1625 16.9375 1.0625 16.9375 1.0625 z M 17.3125 2.6875 C 18.3845 2.8915 19.237984 3.3456094 19.896484 4.0214844 C 20.554984 4.6973594 21.0185 5.595 21.3125 6.6875 L 19.5 8.5 L 15.5 4.5 L 16.9375 3.0625 L 17.3125 2.6875 z M 4.9785156 15.126953 C 4.990338 15.129931 6.1809555 15.430955 7.375 16.625 C 8.675 17.825 8.875 18.925781 8.875 18.925781 L 8.9179688 18.976562 L 5.3691406 20.119141 L 3.8730469 18.623047 L 4.9785156 15.126953 z"/>
51
+ </svg>`;
52
+ export const dotdotdot = `<svg viewBox="0 0 24 24" fill="currentColor">
53
+ <circle cy="12" r="3" cx="3"></circle>
54
+ <circle cy="12" r="3" cx="12"></circle>
55
+ <circle cx="21" cy="12" r="3"></circle>
56
+ </svg>`;
57
+ export const models = `<svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
58
+ <path d="M4 4h6v6h-6z"></path>
59
+ <path d="M14 4h6v6h-6z"></path>
60
+ <path d="M4 14h6v6h-6z"></path>
61
+ <path d="M17 17m-3 0a3 3 0 1 0 6 0a3 3 0 1 0 -6 0"></path>
62
+ </svg>`;
63
+ export const pencilColored = `<svg viewBox="0 0 64 64">
64
+ <path fill="#ffce31" d="M7.934 41.132L39.828 9.246l14.918 14.922l-31.895 31.886z"></path>
65
+ <path d="M61.3 4.6l-1.9-1.9C55.8-.9 50-.9 46.3 2.7l-6.5 6.5l15 15l6.5-6.5c3.6-3.6 3.6-9.5 0-13.1" fill="#ed4c5c"></path>
66
+ <path fill="#93a2aa" d="M35.782 13.31l4.1-4.102l14.92 14.92l-4.1 4.101z"></path>
67
+ <path fill="#c7d3d8" d="M37.338 14.865l4.1-4.101l11.739 11.738l-4.102 4.1z"></path>
68
+ <path fill="#fed0ac" d="M7.9 41.1l-6.5 17l4.5 4.5l17-6.5z"/>
69
+ <path d="M.3 61.1c-.9 2.4.3 3.5 2.7 2.6l8.2-3.1l-7.7-7.7l-3.2 8.2" fill="#333"></path>
70
+ <path fill="#ffdf85" d="M7.89 41.175l27.86-27.86l4.95 4.95l-27.86 27.86z"/>
71
+ <path fill="#ff8736" d="M17.904 51.142l27.86-27.86l4.95 4.95l-27.86 27.86z"></path>
72
+ </svg>`;
73
+ export const diskColored = `<svg viewBox="-0.01 -0.008 100.016 100.016">
74
+ <path fill="#26f" fill_="#23475F" d="M88.555-.008H83v.016a2 2 0 0 1-2 2H19a2 2 0 0 1-2-2v-.016H4a4 4 0 0 0-4 4v92.016a4 4 0 0 0 4 4h92a4 4 0 0 0 4-4V11.517c.049-.089-11.436-11.454-11.445-11.525z"/>
75
+ <path fill="#04d" fill_="#1C3C50" d="M81.04 53.008H18.96a2 2 0 0 0-2 2v45h66.08v-45c0-1.106-.895-2-2-2zm-61.957-10h61.834a2 2 0 0 0 2-2V.547A1.993 1.993 0 0 1 81 2.007H19c-.916 0-1.681-.62-1.917-1.46v40.46a2 2 0 0 0 2 2.001z"/>
76
+ <path fill="#EBF0F1" d="M22 55.977h56a2 2 0 0 1 2 2v37.031a2 2 0 0 1-2 2H22c-1.104 0-2-.396-2-1.5V57.977a2 2 0 0 1 2-2z"/>
77
+ <path fill="#BCC4C8" d="M25 77.008h50v1H25v-1zm0 10h50v1H25v-1z"/>
78
+ <path fill="#1C3C50" d="M7 84.008h3a2 2 0 0 1 2 2v3a2 2 0 0 1-2 2H7a2 2 0 0 1-2-2v-3a2 2 0 0 1 2-2zm83 0h3a2 2 0 0 1 2 2v3a2 2 0 0 1-2 2h-3a2 2 0 0 1-2-2v-3a2 2 0 0 1 2-2z"/>
79
+ <path fill="#BCC4C8" d="M37 1.981v36.026a2 2 0 0 0 2 2h39a2 2 0 0 0 2-2V1.981c0 .007-42.982.007-43 0zm37 29.027a2 2 0 0 1-2 2h-6a2 2 0 0 1-2-2V10.981a2 2 0 0 1 2-2h6a2 2 0 0 1 2 2v20.027z"/>
80
+ <path fill="#FF9D00" d="M78 55.977H22a2 2 0 0 0-2 2v10.031h60V57.977a2 2 0 0 0-2-2z"/>
81
+ </svg>`;
82
+ export const folderColored = `<svg viewBox="0 0 501.379 501.379">
83
+ <path style="fill:#EF9F2C;" d="M406.423,93.889H205.889c-17.067,0-30.933-13.867-30.933-30.933s-13.867-30.933-30.933-30.933H30.956
84
+ c-17.067,0-30.933,13.867-30.933,30.933v375.467c0,17.067,13.867,30.933,30.933,30.933h375.467
85
+ c17.067,0,30.933-13.867,30.933-30.933v-313.6C436.289,107.756,422.423,93.889,406.423,93.889z"/>
86
+ <path style="fill:#FEC656;" d="M470.423,157.889H97.089c-13.867,0-26.667,9.6-29.867,22.4l-66.133,249.6
87
+ c-5.333,19.2,9.6,38.4,29.867,38.4h373.333c13.867,0,26.667-9.6,29.867-22.4l66.133-248.533
88
+ C505.623,177.089,490.689,157.889,470.423,157.889z"/>
89
+ </svg>`;
90
+ export const modelsColored = `<svg viewBox="0 0 24 24">
91
+ <path fill="#aa3366" d="M0 0h10v10h-10z"></path>
92
+ <path d="M14 0h10v10h-10z" fill="#3366aa"></path>
93
+ <path d="M0 14h10v10h-10z" fill="#66aa33"></path>
94
+ <path fill="#dd9922" d="M19 19m-5 0 a5 5 0 1 0 10 0 a5 5 0 1 0 -10 0"></path>
95
+ </svg>`;
96
+ export const legoBlocksColored = `<svg viewBox="0 0 512 512">
97
+ <g>
98
+ <rect x="57.67" style="fill:#00BAB9;" width="101.275" height="78.769"/>
99
+ <rect x="205.363" style="fill:#00BAB9;" width="101.275" height="78.769"/>
100
+ <rect x="353.055" style="fill:#00BAB9;" width="101.275" height="78.769"/>
101
+ </g>
102
+ <polygon style="fill:#B8DE6F;" points="478.242,289.758 478.242,512 33.758,512 33.758,289.758 256,267.253 "/>
103
+ <polygon style="fill:#41D4D3;" points="478.242,67.516 478.242,289.758 33.758,289.758 33.758,67.516 57.67,67.516 158.945,67.516
104
+ 205.363,67.516 306.637,67.516 353.055,67.516 454.33,67.516 "/>
105
+ <g>
106
+ <circle style="fill:#00BAB9;" cx="402.286" cy="143.473" r="8.44"/>
107
+ <circle style="fill:#00BAB9;" cx="368.527" cy="177.231" r="8.44"/>
108
+ </g>
109
+ <circle style="fill:#7BD288;" cx="109.714" cy="436.044" r="8.44"/>
110
+ </svg>`;
111
+ export const legoBlockColored = `<svg viewBox="0 0 256 256">
112
+ <style>
113
+ .s0 { fill: #ff0000 }
114
+ .s1 { fill: #c30000 }
115
+ .s2 { fill: #800000 }
116
+ .s3 { fill: #cc0000 }
117
+ .s4 { fill: #e00000 }
118
+ </style>
119
+ <g id="Folder 2">
120
+ <path id="Shape 1 copy 2" class="s0" d="m128 61l116 45-116 139-116-139z"/>
121
+ <path id="Shape 1" class="s1" d="m12 106l116 45v95l-116-45z"/>
122
+ <path id="Shape 1 copy" class="s2" d="m244 106l-116 45v95l116-45z"/>
123
+ <g id="Folder 1">
124
+ <path id="Shape 2" class="s3" d="m102 111.2c0-6.1 11.4-9.9 25.5-9.9 14.1 0 25.5 3.8 25.5 9.9 0 3.3 0 13.3 0 16.6 0 6.1-11.4 10.9-25.5 10.9-14.1 0-25.5-4.8-25.5-10.9 0-3.3 0-13.3 0-16.6z"/>
125
+ <path id="Shape 2 copy 4" class="s1" d="m102 111.2c0-6.1 11.4-9.9 25.5-9.9 14.1 0 25.5 3.8 25.5 9.9 0 3.3 0 13.3 0 16.6 0 6.1-11.4 10.9-25.5 10.9-14.1 0-25.5-4.8-25.5-10.9 0-3.3 0-13.3 0-16.6z"/>
126
+ <path id="Shape 2 copy 2" class="s2" d="m127.5 101.3c14.1 0 25.5 3.8 25.5 9.9 0 3.3 0 13.3 0 16.6 0 6.1-11.4 10.9-25.5 10.9 0-13.1 0-25.7 0-37.4z"/>
127
+ <path id="Shape 2 copy" class="s0" d="m127.5 118.8c-12.2 0-22-3.4-22-7.6 0-4.2 9.8-7.7 22-7.7 12.2 0 22 3.5 22 7.7 0 4.2-9.8 7.6-22 7.6zm0 0c-12.2 0-22-3.4-22-7.6 0-4.2 9.8-7.7 22-7.7 12.2 0 22 3.5 22 7.7 0 4.2-9.8 7.6-22 7.6z"/>
128
+ </g>
129
+ <g id="Folder 1 copy">
130
+ <path id="Shape 2" class="s4" d="m103 67.5c0-5.8 11-9.5 24.5-9.5 13.5 0 24.5 3.7 24.5 9.5 0 3.2 0 12.8 0 16 0 5.8-11 10.5-24.5 10.5-13.5 0-24.5-4.7-24.5-10.5 0-3.2 0-12.8 0-16z"/>
131
+ <path id="Shape 2 copy 4" class="s1" d="m103 67.5c0-5.8 11-9.5 24.5-9.5 13.5 0 24.5 3.7 24.5 9.5 0 3.2 0 12.8 0 16 0 5.8-11 10.5-24.5 10.5-13.5 0-24.5-4.7-24.5-10.5 0-3.2 0-12.8 0-16z"/>
132
+ <path id="Shape 2 copy 2" class="s2" d="m127.5 58c13.5 0 24.5 3.7 24.5 9.5 0 3.2 0 12.8 0 16 0 5.8-11 10.5-24.5 10.5 0-12.6 0-24.8 0-36z"/>
133
+ <path id="Shape 2 copy" class="s0" d="m127.5 74.9c-11.7 0-21.2-3.3-21.2-7.4 0-4.1 9.5-7.4 21.2-7.4 11.7 0 21.2 3.3 21.2 7.4 0 4.1-9.5 7.4-21.2 7.4zm0 0c-11.7 0-21.2-3.3-21.2-7.4 0-4.1 9.5-7.4 21.2-7.4 11.7 0 21.2 3.3 21.2 7.4 0 4.1-9.5 7.4-21.2 7.4z"/>
134
+ </g>
135
+ <g id="Folder 1 copy 2">
136
+ <path id="Shape 2" class="s4" d="m161 89.5c0-5.8 11-9.5 24.5-9.5 13.5 0 24.5 3.7 24.5 9.5 0 3.2 0 12.8 0 16 0 5.8-11 10.5-24.5 10.5-13.5 0-24.5-4.7-24.5-10.5 0-3.2 0-12.8 0-16z"/>
137
+ <path id="Shape 2 copy 4" class="s1" d="m161 89.5c0-5.8 11-9.5 24.5-9.5 13.5 0 24.5 3.7 24.5 9.5 0 3.2 0 12.8 0 16 0 5.8-11 10.5-24.5 10.5-13.5 0-24.5-4.7-24.5-10.5 0-3.2 0-12.8 0-16z"/>
138
+ <path id="Shape 2 copy 2" class="s2" d="m185.5 80c13.5 0 24.5 3.7 24.5 9.5 0 3.2 0 12.8 0 16 0 5.8-11 10.5-24.5 10.5 0-12.6 0-24.8 0-36z"/>
139
+ <path id="Shape 2 copy" class="s0" d="m185.5 96.9c-11.7 0-21.2-3.3-21.2-7.4 0-4.1 9.5-7.4 21.2-7.4 11.7 0 21.2 3.3 21.2 7.4 0 4.1-9.5 7.4-21.2 7.4zm0 0c-11.7 0-21.2-3.3-21.2-7.4 0-4.1 9.5-7.4 21.2-7.4 11.7 0 21.2 3.3 21.2 7.4 0 4.1-9.5 7.4-21.2 7.4z"/>
140
+ </g>
141
+ <g id="Folder 1 copy 3">
142
+ <path id="Shape 2" class="s4" d="m45 89.5c0-5.8 11-9.5 24.5-9.5 13.5 0 24.5 3.7 24.5 9.5 0 3.2 0 12.8 0 16 0 5.8-11 10.5-24.5 10.5-13.5 0-24.5-4.7-24.5-10.5 0-3.2 0-12.8 0-16z"/>
143
+ <path id="Shape 2 copy 4" class="s1" d="m45 89.5c0-5.8 11-9.5 24.5-9.5 13.5 0 24.5 3.7 24.5 9.5 0 3.2 0 12.8 0 16 0 5.8-11 10.5-24.5 10.5-13.5 0-24.5-4.7-24.5-10.5 0-3.2 0-12.8 0-16z"/>
144
+ <path id="Shape 2 copy 2" class="s2" d="m69.5 80c13.5 0 24.5 3.7 24.5 9.5 0 3.2 0 12.8 0 16 0 5.8-11 10.5-24.5 10.5 0-12.6 0-24.8 0-36z"/>
145
+ <path id="Shape 2 copy" class="s0" d="m69.5 96.9c-11.7 0-21.2-3.3-21.2-7.4 0-4.1 9.5-7.4 21.2-7.4 11.7 0 21.2 3.3 21.2 7.4 0 4.1-9.5 7.4-21.2 7.4zm0 0c-11.7 0-21.2-3.3-21.2-7.4 0-4.1 9.5-7.4 21.2-7.4 11.7 0 21.2 3.3 21.2 7.4 0 4.1-9.5 7.4-21.2 7.4z"/>
146
+ </g>
147
+ </g>
148
+ </svg>`;
149
+ export const gearColored = `<svg viewBox="0 0 128 128" preserveAspectRatio="xMidYMid meet">
150
+ <path d="M124 71.85v-15.7c0-.59-.45-1.09-1.03-1.15l-17.83-1.89c-.47-.05-.85-.38-.98-.83c-.86-2.95-2.03-5.76-3.48-8.39c-.23-.41-.19-.92.11-1.28l11.28-13.94c.37-.46.34-1.13-.08-1.54l-11.1-11.1a1.15 1.15 0 0 0-1.54-.08L85.39 27.22c-.37.3-.87.33-1.28.11a41.796 41.796 0 0 0-8.39-3.48c-.45-.13-.78-.51-.83-.98L73 5.03C72.94 4.45 72.44 4 71.85 4h-15.7c-.59 0-1.09.45-1.15 1.03l-1.89 17.83c-.05.47-.38.85-.83.98c-2.95.86-5.76 2.03-8.39 3.48c-.41.23-.92.19-1.28-.11L28.67 15.94a1.15 1.15 0 0 0-1.54.08l-11.1 11.1a1.15 1.15 0 0 0-.08 1.54L27.23 42.6c.3.37.33.87.11 1.28a41.796 41.796 0 0 0-3.48 8.39c-.13.45-.51.78-.98.83L5.03 55c-.58.06-1.03.56-1.03 1.15v15.7c0 .59.45 1.09 1.03 1.15l17.83 1.89c.47.05.85.38.98.83c.86 2.95 2.03 5.76 3.48 8.39c.23.41.19.92-.11 1.28L15.94 99.33c-.37.46-.34 1.13.08 1.54l11.1 11.1c.42.42 1.08.45 1.54.08l13.94-11.28c.37-.3.87-.33 1.28-.11c2.64 1.45 5.45 2.62 8.39 3.48c.45.13.78.51.83.98l1.9 17.85c.06.59.56 1.03 1.15 1.03h15.7c.59 0 1.09-.45 1.15-1.03l1.89-17.83c.05-.47.38-.85.83-.98c2.95-.86 5.76-2.03 8.39-3.48c.41-.23.92-.19 1.28.11l13.94 11.28c.46.37 1.13.34 1.54-.08l11.1-11.1c.42-.42.45-1.08.08-1.54l-11.28-13.94c-.3-.37-.33-.87-.11-1.28c1.45-2.64 2.62-5.45 3.48-8.39c.13-.45.51-.78.98-.83L122.97 73c.58-.06 1.03-.56 1.03-1.15zm-60 3.43c-6.23 0-11.28-5.05-11.28-11.28S57.77 52.72 64 52.72S75.28 57.77 75.28 64S70.23 75.28 64 75.28z" fill="#82aec0"></path>
151
+ <path d="M80.56 49.48c3.67 4.18 5.78 9.77 5.43 15.85c-.65 11.16-9.83 20.19-21 20.68c-4.75.21-9.18-1.09-12.86-3.45c-.28-.18-.58.2-.34.44a22.412 22.412 0 0 0 17.85 6.67c10.78-.85 19.56-9.5 20.55-20.27c.77-8.36-3.06-15.87-9.23-20.33c-.29-.2-.62.15-.4.41z" fill="#2f7889"></path>
152
+ <path d="M43.87 65.32c-.67-13.15 7.83-22.79 20.01-22.79c.65 0 1.68 0 2.48.92c1.01 1.18 1.1 2.6 0 3.77c-.81.86-1.95.92-2.53 1c-12.3 1.59-15.18 9.35-15.83 16.77c-.03.33.06 2.35-1.71 2.56c-2.15.25-2.41-1.91-2.42-2.23z" fill="#b9e4ea"></path>
153
+ <path d="M25.24 65.87c-.01-22.03 15.9-40.19 38.13-41.05c.68-.03 2.45 0 3.55.99c1.01.91 1.38 2.51.79 3.82c-.95 2.11-2.85 2.07-3.36 2.09c-18.51.66-34.18 15.73-34.19 33.95c0 .29-.05.58-.15.84l-.1.25c-.76 1.98-3.52 2.09-4.43.18c-.15-.34-.24-.7-.24-1.07z" fill="#94d1e0"></path>
154
+ </svg>`;
155
/**
 * Builds an SVG element from a markup string via the shared `$el` helper.
 * Throws if the string is not `<svg>` markup, since `$el` would otherwise
 * silently treat it as a selector.
 */
export function $svg(markup, attrs) {
    if (!/^\s*<svg/.test(markup)) {
        throw new Error("Cannot call $svg with non-svg markup.");
    }
    return $el(markup, attrs || {});
}
rgthree-comfy/web/common/shared_utils.js ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
/**
 * Creates a "deferred"-style resolver: a promise plus externally callable
 * resolve/reject handles and completion flags, which auto-rejects if neither
 * handle is called within `timeout` ms.
 *
 * Fix: the watchdog timer is now cancelled as soon as the resolver completes,
 * so it no longer lingers for the full `timeout` (keeping node-like event
 * loops alive) after an early resolve/reject.
 */
export function getResolver(timeout = 5000) {
    const resolver = {};
    resolver.id = generateId(8);
    resolver.completed = false;
    resolver.resolved = false;
    resolver.rejected = false;
    resolver.promise = new Promise((resolve, reject) => {
        resolver.reject = (e) => {
            resolver.completed = true;
            resolver.rejected = true;
            // Cancel the pending watchdog; it is a no-op once completed anyway,
            // but clearing it releases the timer immediately.
            clearTimeout(resolver.timeout);
            reject(e);
        };
        resolver.resolve = (data) => {
            resolver.completed = true;
            resolver.resolved = true;
            clearTimeout(resolver.timeout);
            resolve(data);
        };
    });
    // Watchdog: reject (with no reason, as callers may rely on an undefined
    // rejection value) if nothing has completed the resolver in time.
    resolver.timeout = setTimeout(() => {
        if (!resolver.completed) {
            resolver.reject();
        }
    }, timeout);
    return resolver;
}
26
// Tracks the in-flight debounce promise per function.
const DEBOUNCE_FN_TO_PROMISE = new WeakMap();
/**
 * Coalesces repeated calls: while a delay for `fn` is pending, subsequent
 * calls return the same promise (their `ms` argument is ignored). `fn` runs
 * once when the delay elapses.
 */
export function debounce(fn, ms = 64) {
    let pending = DEBOUNCE_FN_TO_PROMISE.get(fn);
    if (!pending) {
        pending = wait(ms).then(() => {
            DEBOUNCE_FN_TO_PROMISE.delete(fn);
            fn();
        });
        DEBOUNCE_FN_TO_PROMISE.set(fn, pending);
    }
    return pending;
}
36
/**
 * Resolves after `ms` milliseconds. The default of exactly 16ms (~one frame)
 * is special-cased to defer via requestAnimationFrame instead of setTimeout.
 */
export function wait(ms = 16) {
    if (ms === 16) {
        return new Promise((resolve) => {
            requestAnimationFrame(() => resolve());
        });
    }
    return new Promise((resolve) => setTimeout(resolve, ms));
}
50
+ function dec2hex(dec) {
51
+ return dec.toString(16).padStart(2, "0");
52
+ }
53
+ export function generateId(length) {
54
+ const arr = new Uint8Array(length / 2);
55
+ crypto.getRandomValues(arr);
56
+ return Array.from(arr, dec2hex).join("");
57
+ }
58
/**
 * Reads a dot-delimited path (e.g. "a.b.c") out of `obj`, returning `def`
 * when any segment of the path is missing.
 *
 * Fix: previously a missing *leaf* key returned `undefined` while a missing
 * *intermediate* key returned `def` (the recursion bottomed out through the
 * `!obj` guard). The default is now applied consistently in both cases.
 */
export function getObjectValue(obj, objKey, def) {
    if (!obj || !objKey)
        return def;
    const keys = objKey.split(".");
    const key = keys.shift();
    const found = obj[key];
    if (keys.length) {
        return getObjectValue(found, keys.join("."), def);
    }
    return found === undefined ? def : found;
}
69
/**
 * Writes `value` at a dot-delimited path within `obj`, creating intermediate
 * objects (and clobbering non-object intermediates) along the way. Returns
 * `obj`. NOTE(review): with createMissingObjects=false, a missing segment —
 * even the final one — aborts and returns undefined; presumably intended as
 * "only update existing paths", confirm with callers.
 */
export function setObjectValue(obj, objKey, value, createMissingObjects = true) {
    if (!obj || !objKey)
        return obj;
    const [head, ...rest] = objKey.split(".");
    if (obj[head] === undefined) {
        if (!createMissingObjects) {
            return;
        }
        obj[head] = {};
    }
    if (rest.length === 0) {
        obj[head] = value;
    }
    else {
        // Replace a primitive intermediate so the recursion has an object.
        if (typeof obj[head] != "object") {
            obj[head] = {};
        }
        setObjectValue(obj[head], rest.join("."), value, createMissingObjects);
    }
    return obj;
}
91
/**
 * Moves an element — given either by value or by index — to index `to`,
 * mutating `arr` in place.
 */
export function moveArrayItem(arr, itemOrFrom, to) {
    const from = typeof itemOrFrom === "number" ? itemOrFrom : arr.indexOf(itemOrFrom);
    const [item] = arr.splice(from, 1);
    arr.splice(to, 0, item);
}
95
+ export function removeArrayItem(arr, itemOrIndex) {
96
+ const index = typeof itemOrIndex === "number" ? itemOrIndex : arr.indexOf(itemOrIndex);
97
+ arr.splice(index, 1);
98
+ }
99
/**
 * Injects a stylesheet <link> for `href` once per document; resolves when it
 * loads, or after a 1s fallback so callers are never blocked indefinitely.
 * Resolves immediately if a link with that href prefix already exists.
 */
export function injectCss(href) {
    if (document.querySelector(`link[href^="${href}"]`)) {
        return Promise.resolve();
    }
    return new Promise((resolve) => {
        const link = document.createElement("link");
        link.setAttribute("rel", "stylesheet");
        link.setAttribute("type", "text/css");
        // Fallback: resolve after 1s even if the load event never fires
        // (e.g. the stylesheet 404s).
        const timeout = setTimeout(resolve, 1000);
        link.addEventListener("load", (e) => {
            // Fix: this is a setTimeout handle, so cancel it with clearTimeout
            // (clearInterval happened to work since browsers share timer IDs,
            // but it was the wrong API).
            clearTimeout(timeout);
            resolve();
        });
        link.href = href;
        document.head.appendChild(link);
    });
}
116
+ export function defineProperty(instance, property, desc) {
117
+ var _a, _b, _c, _d, _e, _f;
118
+ const existingDesc = Object.getOwnPropertyDescriptor(instance, property);
119
+ if ((existingDesc === null || existingDesc === void 0 ? void 0 : existingDesc.configurable) === false) {
120
+ throw new Error(`Error: rgthree-comfy cannot define un-configurable property "${property}"`);
121
+ }
122
+ if ((existingDesc === null || existingDesc === void 0 ? void 0 : existingDesc.get) && desc.get) {
123
+ const descGet = desc.get;
124
+ desc.get = () => {
125
+ existingDesc.get.apply(instance, []);
126
+ return descGet.apply(instance, []);
127
+ };
128
+ }
129
+ if ((existingDesc === null || existingDesc === void 0 ? void 0 : existingDesc.set) && desc.set) {
130
+ const descSet = desc.set;
131
+ desc.set = (v) => {
132
+ existingDesc.set.apply(instance, [v]);
133
+ return descSet.apply(instance, [v]);
134
+ };
135
+ }
136
+ desc.enumerable = (_b = (_a = desc.enumerable) !== null && _a !== void 0 ? _a : existingDesc === null || existingDesc === void 0 ? void 0 : existingDesc.enumerable) !== null && _b !== void 0 ? _b : true;
137
+ desc.configurable = (_d = (_c = desc.configurable) !== null && _c !== void 0 ? _c : existingDesc === null || existingDesc === void 0 ? void 0 : existingDesc.configurable) !== null && _d !== void 0 ? _d : true;
138
+ if (!desc.get && !desc.set) {
139
+ desc.writable = (_f = (_e = desc.writable) !== null && _e !== void 0 ? _e : existingDesc === null || existingDesc === void 0 ? void 0 : existingDesc.writable) !== null && _f !== void 0 ? _f : true;
140
+ }
141
+ return Object.defineProperty(instance, property, desc);
142
+ }
rgthree-comfy/web/common/utils_dom.js ADDED
@@ -0,0 +1,311 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Attributes that must be assigned as camelCase DOM *properties* rather than
+ // via Element.setAttribute. NOTE(review): presumably consumed by the
+ // setAttribute helper further down this file (outside this view) — confirm.
+ const DIRECT_ATTRIBUTE_MAP = {
2
+ cellpadding: 'cellPadding',
3
+ cellspacing: 'cellSpacing',
4
+ colspan: 'colSpan',
5
+ frameborder: 'frameBorder',
6
+ height: 'height',
7
+ maxlength: 'maxLength',
8
+ nonce: 'nonce',
9
+ role: 'role',
10
+ rowspan: 'rowSpan',
11
+ type: 'type',
12
+ usemap: 'useMap',
13
+ valign: 'vAlign',
14
+ width: 'width',
15
+ };
16
+ // Unit appended to bare numeric values of the style properties matched below.
+ const RGX_NUMERIC_STYLE_UNIT = 'px';
17
+ // Style properties whose bare numeric values are interpreted in px.
+ const RGX_NUMERIC_STYLE = /^((max|min)?(width|height)|margin|padding|(margin|padding)?(left|top|bottom|right)|fontsize|borderwidth)$/i;
18
+ // Elements whose "default" pseudo-attribute maps to their `value` property
+ // (see setAttribute below this view).
+ const RGX_DEFAULT_VALUE_PROP = /input|textarea|select/i;
19
/**
 * Returns `input` unchanged when it is neither null nor undefined (the loose
 * `!= null` matches both); throws `errorMsg` otherwise. Other falsy values
 * (0, '', false) pass through.
 */
function localAssertNotFalsy(input, errorMsg = `Input is not of type.`) {
    if (input != null) {
        return input;
    }
    throw new Error(errorMsg);
}
25
+ // Building blocks for the selector mini-parser used by createElement below.
+ const RGX_STRING_VALID = '[a-z0-9_-]';
26
+ // Leading tag name of a simple selector, e.g. "div" in "div.foo#bar".
+ const RGX_TAG = new RegExp(`^([a-z]${RGX_STRING_VALID}*)(\\.|\\[|\\#|$)`, 'i');
27
+ // "#id" tokens within a selector.
+ const RGX_ATTR_ID = new RegExp(`#(${RGX_STRING_VALID}+)`, 'gi');
28
+ // ".cls(.cls2…)" runs within a selector (captures the preceding non-space char).
+ const RGX_ATTR_CLASS = new RegExp(`(^|\\S)\\.([a-z0-9_\\-\\.]+)`, 'gi');
29
+ // Content up to the next "[" or "]" — used below to balance nested brackets.
+ const RGX_STRING_CONTENT_TO_SQUARES = '(.*?)(\\[|\\])';
30
+ const RGX_ATTRS_MAYBE_OPEN = new RegExp(`\\[${RGX_STRING_CONTENT_TO_SQUARES}`, 'gi');
31
+ const RGX_ATTRS_FOLLOW_OPEN = new RegExp(`^${RGX_STRING_CONTENT_TO_SQUARES}`, 'gi');
32
/**
 * Returns all matches of `selectors` within `parent` (document by default)
 * as a real Array, dropping any falsy entries.
 */
export function query(selectors, parent = document) {
    const nodes = Array.from(parent.querySelectorAll(selectors));
    return nodes.filter((node) => Boolean(node));
}
35
/**
 * Returns the first match of `selectors` within `parent` (document by
 * default), or null when there is none.
 */
export function queryOne(selectors, parent = document) {
    return parent.querySelector(selectors) ?? null;
}
39
+ // Creates a detached DOM Text node wrapping `text`.
+ export function createText(text) {
40
+ return document.createTextNode(text);
41
+ }
42
+ export function getClosestOrSelf(element, query) {
43
+ const el = element;
44
+ return ((el === null || el === void 0 ? void 0 : el.closest) && (el.matches(query) && el || el.closest(query))) || null;
45
+ }
46
+ export function containsOrSelf(parent, contained) {
47
+ var _a;
48
+ return parent === contained || ((_a = parent === null || parent === void 0 ? void 0 : parent.contains) === null || _a === void 0 ? void 0 : _a.call(parent, contained)) || false;
49
+ }
50
+ // Creates an element from either (a) an HTML markup string — the parsed
+ // fragment's first element is used — or (b) a CSS-like selector string such
+ // as "div.cls#id[attr=val]", whose tag/classes/id/attributes are parsed and
+ // applied. The optional `attrs` map is applied afterwards via setAttributes.
+ export function createElement(selectorOrMarkup, attrs) {
51
+ const frag = getHtmlFragment(selectorOrMarkup);
52
+ let element = frag === null || frag === void 0 ? void 0 : frag.firstElementChild;
53
+ let selector = "";
54
+ if (!element) {
55
+ // Not markup: treat the input as a selector. Collapse newlines plus any
+ // following indentation so multi-line selectors work.
+ selector = selectorOrMarkup.replace(/[\r\n]\s*/g, "");
56
+ const tag = getSelectorTag(selector) || "div";
57
+ element = document.createElement(tag);
58
+ // Normalize the "#id" and ".cls" sugar into "[id=…]"/"[class=…]" so the
+ // bracketed-attribute loop below handles everything uniformly.
+ selector = selector.replace(RGX_TAG, "$2");
59
+ selector = selector.replace(RGX_ATTR_ID, '[id="$1"]');
60
+ selector = selector.replace(RGX_ATTR_CLASS, (match, p1, p2) => `${p1}[class="${p2.replace(/\./g, " ")}"]`);
61
+ }
62
+ const selectorAttrs = getSelectorAttributes(selector);
63
+ if (selectorAttrs) {
64
+ for (const attr of selectorAttrs) {
65
+ // `attr` is "[key=value]" (brackets included); split the key off.
+ let matches = attr.substring(1, attr.length - 1).split("=");
66
+ let key = localAssertNotFalsy(matches.shift());
67
+ let value = matches.join("=");
68
+ // NOTE(review): join("=") yields "" (never undefined) for a value-less
+ // "[key]", so this branch appears unreachable and such attributes end up
+ // set to "" rather than boolean true — confirm intent against setAttribute.
+ if (value === undefined) {
69
+ setAttribute(element, key, true);
70
+ }
71
+ else {
72
+ // Strip one layer of surrounding quotes, if present.
+ value = value.replace(/^['"](.*)['"]$/, "$1");
73
+ setAttribute(element, key, value);
74
+ }
75
+ }
76
+ }
77
+ if (attrs) {
78
+ setAttributes(element, attrs);
79
+ }
80
+ return element;
81
+ }
82
+ // Short alias used throughout rgthree's web code.
+ export const $el = createElement;
83
+ // Returns the leading tag name of a simple selector string, or '' if none.
+ function getSelectorTag(str) {
84
+ return tryMatch(str, RGX_TAG);
85
+ }
86
+ // Extracts every bracketed "[key=value]" chunk from `selector`, handling
+ // nested square brackets inside attribute values by delegating to
+ // getOpenAttributesRecursive.
+ function getSelectorAttributes(selector) {
87
+ // Global ('g') regex keeps state across exec() calls; reset it in case a
+ // previous scan was abandoned mid-way.
+ RGX_ATTRS_MAYBE_OPEN.lastIndex = 0;
88
+ let attrs = [];
89
+ let result;
90
+ while (result = RGX_ATTRS_MAYBE_OPEN.exec(selector)) {
91
+ let attr = result[0];
92
+ if (attr.endsWith(']')) {
93
+ attrs.push(attr);
94
+ }
95
+ else {
96
+ // Hit a nested "[" before the closing "]": keep consuming until the
+ // brackets balance again.
+ attr = result[0]
97
+ + getOpenAttributesRecursive(selector.substr(RGX_ATTRS_MAYBE_OPEN.lastIndex), 2);
98
+ // Advance the regex cursor past whatever the recursive scan consumed.
+ RGX_ATTRS_MAYBE_OPEN.lastIndex += (attr.length - result[0].length);
99
+ attrs.push(attr);
100
+ }
101
+ }
102
+ return attrs;
103
+ }
104
+ function getOpenAttributesRecursive(selectorSubstring, openCount) {
105
+ let matches = selectorSubstring.match(RGX_ATTRS_FOLLOW_OPEN);
106
+ let result = '';
107
+ if (matches && matches.length) {
108
+ result = matches[0];
109
+ openCount += result.endsWith(']') ? -1 : 1;
110
+ if (openCount > 0) {
111
+ result += getOpenAttributesRecursive(selectorSubstring.substr(result.length), openCount);
112
+ }
113
+ }
114
+ return result;
115
+ }
116
+ function tryMatch(str, rgx, index = 1) {
117
+ var _a;
118
+ let found = '';
119
+ try {
120
+ found = ((_a = str.match(rgx)) === null || _a === void 0 ? void 0 : _a[index]) || '';
121
+ }
122
+ catch (e) {
123
+ found = '';
124
+ }
125
+ return found;
126
+ }
127
+ export function setAttributes(element, data) {
128
+ let attr;
129
+ for (attr in data) {
130
+ if (data.hasOwnProperty(attr)) {
131
+ setAttribute(element, attr, data[attr]);
132
+ }
133
+ }
134
+ }
135
+ function getHtmlFragment(value) {
136
+ if (value.match(/^\s*<.*?>[\s\S]*<\/[a-z0-9]+>\s*$/)) {
137
+ return document.createRange().createContextualFragment(value.trim());
138
+ }
139
+ return null;
140
+ }
141
+ function getChild(value) {
142
+ if (value instanceof Node) {
143
+ return value;
144
+ }
145
+ if (typeof value === 'string') {
146
+ let child = getHtmlFragment(value);
147
+ if (child) {
148
+ return child;
149
+ }
150
+ if (getSelectorTag(value)) {
151
+ return createElement(value);
152
+ }
153
+ return createText(value);
154
+ }
155
+ if (value && typeof value.toElement === 'function') {
156
+ return value.toElement();
157
+ }
158
+ return null;
159
+ }
160
+ export function setAttribute(element, attribute, value) {
161
+ let isRemoving = value == null;
162
+ if (attribute === 'default') {
163
+ attribute = RGX_DEFAULT_VALUE_PROP.test(element.nodeName) ? 'value' : 'text';
164
+ }
165
+ if (attribute === 'text') {
166
+ empty(element).appendChild(createText(value != null ? String(value) : ''));
167
+ }
168
+ else if (attribute === 'html') {
169
+ empty(element).innerHTML += value != null ? String(value) : '';
170
+ }
171
+ else if (attribute == 'style') {
172
+ if (typeof value === 'string') {
173
+ element.style.cssText = isRemoving ? '' : (value != null ? String(value) : '');
174
+ }
175
+ else {
176
+ for (const [styleKey, styleValue] of Object.entries(value)) {
177
+ element.style[styleKey] = styleValue;
178
+ }
179
+ }
180
+ }
181
+ else if (attribute == 'events') {
182
+ for (const [key, fn] of Object.entries(value)) {
183
+ addEvent(element, key, fn);
184
+ }
185
+ }
186
+ else if (attribute === 'parent') {
187
+ value.appendChild(element);
188
+ }
189
+ else if (attribute === 'child' || attribute === 'children') {
190
+ if (typeof value === 'string' && /^\[[^\[\]]+\]$/.test(value)) {
191
+ const parseable = value.replace(/^\[([^\[\]]+)\]$/, '["$1"]').replace(/,/g, '","');
192
+ try {
193
+ const parsed = JSON.parse(parseable);
194
+ value = parsed;
195
+ }
196
+ catch (e) {
197
+ console.error(e);
198
+ }
199
+ }
200
+ if (attribute === 'children') {
201
+ empty(element);
202
+ }
203
+ let children = value instanceof Array ? value : [value];
204
+ for (let child of children) {
205
+ child = getChild(child);
206
+ if (child instanceof Node) {
207
+ if (element instanceof HTMLTemplateElement) {
208
+ element.content.appendChild(child);
209
+ }
210
+ else {
211
+ element.appendChild(child);
212
+ }
213
+ }
214
+ }
215
+ }
216
+ else if (attribute == 'for') {
217
+ element.htmlFor = value != null ? String(value) : '';
218
+ if (isRemoving) {
219
+ element.removeAttribute('for');
220
+ }
221
+ }
222
+ else if (attribute === 'class' || attribute === 'className' || attribute === 'classes') {
223
+ element.className = isRemoving ? '' : Array.isArray(value) ? value.join(' ') : String(value);
224
+ }
225
+ else if (attribute === 'dataset') {
226
+ if (typeof value !== 'object') {
227
+ console.error('Expecting an object for dataset');
228
+ return;
229
+ }
230
+ for (const [key, val] of Object.entries(value)) {
231
+ element.dataset[key] = String(val);
232
+ }
233
+ }
234
+ else if (attribute.startsWith('on') && typeof value === 'function') {
235
+ element.addEventListener(attribute.substring(2), value);
236
+ }
237
+ else if (['checked', 'disabled', 'readonly', 'required', 'selected'].includes(attribute)) {
238
+ element[attribute] = !!value;
239
+ if (!value) {
240
+ element.removeAttribute(attribute);
241
+ }
242
+ else {
243
+ element.setAttribute(attribute, attribute);
244
+ }
245
+ }
246
+ else if (DIRECT_ATTRIBUTE_MAP.hasOwnProperty(attribute)) {
247
+ if (isRemoving) {
248
+ element.removeAttribute(DIRECT_ATTRIBUTE_MAP[attribute]);
249
+ }
250
+ else {
251
+ element.setAttribute(DIRECT_ATTRIBUTE_MAP[attribute], String(value));
252
+ }
253
+ }
254
+ else if (isRemoving) {
255
+ element.removeAttribute(attribute);
256
+ }
257
+ else {
258
+ let oldVal = element.getAttribute(attribute);
259
+ if (oldVal !== value) {
260
+ element.setAttribute(attribute, String(value));
261
+ }
262
+ }
263
+ }
264
+ function addEvent(element, key, fn) {
265
+ element.addEventListener(key, fn);
266
+ }
267
+ function setStyles(element, styles = null) {
268
+ if (styles) {
269
+ for (let name in styles) {
270
+ setStyle(element, name, styles[name]);
271
+ }
272
+ }
273
+ return element;
274
+ }
275
+ function setStyle(element, name, value) {
276
+ name = (name.indexOf('float') > -1 ? 'cssFloat' : name);
277
+ if (name.indexOf('-') != -1) {
278
+ name = name.replace(/-\D/g, (match) => {
279
+ return match.charAt(1).toUpperCase();
280
+ });
281
+ }
282
+ if (value == String(Number(value)) && RGX_NUMERIC_STYLE.test(name)) {
283
+ value = value + RGX_NUMERIC_STYLE_UNIT;
284
+ }
285
+ if (name === 'display' && typeof value !== 'string') {
286
+ value = !!value ? null : 'none';
287
+ }
288
+ element.style[name] = value === null ? null : String(value);
289
+ return element;
290
+ }
291
+ ;
292
+ export function empty(element) {
293
+ while (element.firstChild) {
294
+ element.removeChild(element.firstChild);
295
+ }
296
+ return element;
297
+ }
298
+ export function appendChildren(el, children) {
299
+ children = !Array.isArray(children) ? [children] : children;
300
+ for (let child of children) {
301
+ child = getChild(child);
302
+ if (child instanceof Node) {
303
+ if (el instanceof HTMLTemplateElement) {
304
+ el.content.appendChild(child);
305
+ }
306
+ else {
307
+ el.appendChild(child);
308
+ }
309
+ }
310
+ }
311
+ }
rgthree-comfy/web/common/utils_workflow.js ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { getResolver } from "./shared_utils.js";
2
+ import { getPngMetadata, getWebpMetadata } from "../../scripts/pnginfo.js";
3
+ function parseWorkflowJson(stringJson) {
4
+ stringJson = stringJson || "null";
5
+ stringJson = stringJson.replace(/:\s*NaN/g, ": null");
6
+ return JSON.parse(stringJson);
7
+ }
8
+ export async function tryToGetWorkflowDataFromEvent(e) {
9
+ var _a, _b, _c, _d;
10
+ let work;
11
+ for (const file of ((_a = e.dataTransfer) === null || _a === void 0 ? void 0 : _a.files) || []) {
12
+ const data = await tryToGetWorkflowDataFromFile(file);
13
+ if (data.workflow || data.prompt) {
14
+ return data;
15
+ }
16
+ }
17
+ const validTypes = ["text/uri-list", "text/x-moz-url"];
18
+ const match = (((_b = e.dataTransfer) === null || _b === void 0 ? void 0 : _b.types) || []).find((t) => validTypes.find((v) => t === v));
19
+ if (match) {
20
+ const uri = (_d = (_c = e.dataTransfer.getData(match)) === null || _c === void 0 ? void 0 : _c.split("\n")) === null || _d === void 0 ? void 0 : _d[0];
21
+ if (uri) {
22
+ return tryToGetWorkflowDataFromFile(await (await fetch(uri)).blob());
23
+ }
24
+ }
25
+ return { workflow: null, prompt: null };
26
+ }
27
+ export async function tryToGetWorkflowDataFromFile(file) {
28
+ var _a;
29
+ if (file.type === "image/png") {
30
+ const pngInfo = await getPngMetadata(file);
31
+ return {
32
+ workflow: parseWorkflowJson(pngInfo === null || pngInfo === void 0 ? void 0 : pngInfo.workflow),
33
+ prompt: parseWorkflowJson(pngInfo === null || pngInfo === void 0 ? void 0 : pngInfo.prompt),
34
+ };
35
+ }
36
+ if (file.type === "image/webp") {
37
+ const pngInfo = await getWebpMetadata(file);
38
+ const workflow = parseWorkflowJson((pngInfo === null || pngInfo === void 0 ? void 0 : pngInfo.workflow) || (pngInfo === null || pngInfo === void 0 ? void 0 : pngInfo.Workflow) || "null");
39
+ const prompt = parseWorkflowJson((pngInfo === null || pngInfo === void 0 ? void 0 : pngInfo.prompt) || (pngInfo === null || pngInfo === void 0 ? void 0 : pngInfo.Prompt) || "null");
40
+ return { workflow, prompt };
41
+ }
42
+ if (file.type === "application/json" || ((_a = file.name) === null || _a === void 0 ? void 0 : _a.endsWith(".json"))) {
43
+ const resolver = getResolver();
44
+ const reader = new FileReader();
45
+ reader.onload = async () => {
46
+ const json = parseWorkflowJson(reader.result);
47
+ const isApiJson = Object.values(json).every((v) => v.class_type);
48
+ const prompt = isApiJson ? json : null;
49
+ const workflow = !isApiJson && !(json === null || json === void 0 ? void 0 : json.templates) ? json : null;
50
+ return { workflow, prompt };
51
+ };
52
+ return resolver.promise;
53
+ }
54
+ return { workflow: null, prompt: null };
55
+ }
rgthree-comfy/web/link_fixer/link_page.js ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { fixBadLinks } from "../common/link_fixer.js";
2
+ import { getPngMetadata } from "../../scripts/pnginfo.js";
3
+ function wait(ms = 16, value) {
4
+ return new Promise((resolve) => {
5
+ setTimeout(() => {
6
+ resolve(value);
7
+ }, ms);
8
+ });
9
+ }
10
+ const logger = {
11
+ logTo: console,
12
+ log: (...args) => {
13
+ logger.logTo === console
14
+ ? console.log(...args)
15
+ : (logger.logTo.innerText += args.join(",") + "\n");
16
+ },
17
+ };
18
+ const findBadLinksLogger = {
19
+ log: async (...args) => {
20
+ logger.log(...args);
21
+ },
22
+ };
23
+ export class LinkPage {
24
+ constructor() {
25
+ this.containerEl = document.querySelector(".box");
26
+ this.figcaptionEl = document.querySelector("figcaption");
27
+ this.outputeMessageEl = document.querySelector(".output");
28
+ this.outputImageEl = document.querySelector(".output-image");
29
+ this.btnFix = document.querySelector(".btn-fix");
30
+ document.addEventListener("dragover", (e) => {
31
+ e.preventDefault();
32
+ }, false);
33
+ document.addEventListener("drop", (e) => {
34
+ this.onDrop(e);
35
+ });
36
+ this.btnFix.addEventListener("click", (e) => {
37
+ this.onFixClick(e);
38
+ });
39
+ }
40
+ async onFixClick(e) {
41
+ if (!this.graphResults || !this.graph) {
42
+ this.updateUi("⛔ Fix button click without results.");
43
+ return;
44
+ }
45
+ let graphFinalResults = fixBadLinks(this.graph, true);
46
+ graphFinalResults = fixBadLinks(graphFinalResults.graph, true);
47
+ if (graphFinalResults.patched || graphFinalResults.deleted) {
48
+ graphFinalResults = fixBadLinks(graphFinalResults.graph, true);
49
+ }
50
+ this.graphFinalResults = graphFinalResults;
51
+ await this.saveFixedWorkflow();
52
+ if (graphFinalResults.hasBadLinks) {
53
+ this.updateUi("⛔ Hmm... Still detecting bad links. Can you file an issue at https://github.com/rgthree/rgthree-comfy/issues with your image/workflow.");
54
+ }
55
+ else {
56
+ this.updateUi("✅ Workflow fixed.<br><br><small>Please load new saved workflow json and double check linking and execution.</small>");
57
+ }
58
+ }
59
+ async onDrop(event) {
60
+ var _a, _b, _c, _d;
61
+ if (!event.dataTransfer) {
62
+ return;
63
+ }
64
+ this.reset();
65
+ event.preventDefault();
66
+ event.stopPropagation();
67
+ if (event.dataTransfer.files.length && ((_b = (_a = event.dataTransfer.files) === null || _a === void 0 ? void 0 : _a[0]) === null || _b === void 0 ? void 0 : _b.type) !== "image/bmp") {
68
+ await this.handleFile(event.dataTransfer.files[0]);
69
+ return;
70
+ }
71
+ const validTypes = ["text/uri-list", "text/x-moz-url"];
72
+ const match = [...event.dataTransfer.types].find((t) => validTypes.find((v) => t === v));
73
+ if (match) {
74
+ const uri = (_d = (_c = event.dataTransfer.getData(match)) === null || _c === void 0 ? void 0 : _c.split("\n")) === null || _d === void 0 ? void 0 : _d[0];
75
+ if (uri) {
76
+ await this.handleFile(await (await fetch(uri)).blob());
77
+ }
78
+ }
79
+ }
80
+ reset() {
81
+ this.file = undefined;
82
+ this.graph = undefined;
83
+ this.graphResults = undefined;
84
+ this.graphFinalResults = undefined;
85
+ this.updateUi();
86
+ }
87
+ updateUi(msg) {
88
+ this.outputeMessageEl.innerHTML = "";
89
+ if (this.file && !this.containerEl.classList.contains("-has-file")) {
90
+ this.containerEl.classList.add("-has-file");
91
+ this.figcaptionEl.innerHTML = this.file.name || this.file.type;
92
+ if (this.file.type === "application/json") {
93
+ this.outputImageEl.src = "icon_file_json.png";
94
+ }
95
+ else {
96
+ const reader = new FileReader();
97
+ reader.onload = () => (this.outputImageEl.src = reader.result);
98
+ reader.readAsDataURL(this.file);
99
+ }
100
+ }
101
+ else if (!this.file && this.containerEl.classList.contains("-has-file")) {
102
+ this.containerEl.classList.remove("-has-file");
103
+ this.outputImageEl.src = "";
104
+ this.outputImageEl.removeAttribute("src");
105
+ }
106
+ if (this.graphResults) {
107
+ this.containerEl.classList.add("-has-results");
108
+ if (!this.graphResults.patched && !this.graphResults.deleted) {
109
+ this.outputeMessageEl.innerHTML = "✅ No bad links detected in the workflow.";
110
+ }
111
+ else {
112
+ this.containerEl.classList.add("-has-fixable-results");
113
+ this.outputeMessageEl.innerHTML = `⚠️ Found ${this.graphResults.patched} links to fix, and ${this.graphResults.deleted} to be removed.`;
114
+ }
115
+ }
116
+ else {
117
+ this.containerEl.classList.remove("-has-results");
118
+ this.containerEl.classList.remove("-has-fixable-results");
119
+ }
120
+ if (msg) {
121
+ this.outputeMessageEl.innerHTML = msg;
122
+ }
123
+ }
124
+ async handleFile(file) {
125
+ this.file = file;
126
+ this.updateUi();
127
+ let workflow = null;
128
+ if (file.type.startsWith("image/")) {
129
+ const pngInfo = await getPngMetadata(file);
130
+ workflow = pngInfo === null || pngInfo === void 0 ? void 0 : pngInfo.workflow;
131
+ }
132
+ else if (file.type === "application/json" ||
133
+ (file instanceof File && file.name.endsWith(".json"))) {
134
+ workflow = await new Promise((resolve) => {
135
+ const reader = new FileReader();
136
+ reader.onload = () => {
137
+ resolve(reader.result);
138
+ };
139
+ reader.readAsText(file);
140
+ });
141
+ }
142
+ if (!workflow) {
143
+ this.updateUi("⛔ No workflow found in dropped item.");
144
+ }
145
+ else {
146
+ try {
147
+ this.graph = JSON.parse(workflow);
148
+ }
149
+ catch (e) {
150
+ this.graph = undefined;
151
+ }
152
+ if (!this.graph) {
153
+ this.updateUi("⛔ Invalid workflow found in dropped item.");
154
+ }
155
+ else {
156
+ this.loadGraphData(this.graph);
157
+ }
158
+ }
159
+ }
160
+ async loadGraphData(graphData) {
161
+ this.graphResults = await fixBadLinks(graphData);
162
+ this.updateUi();
163
+ }
164
+ async saveFixedWorkflow() {
165
+ if (!this.graphFinalResults) {
166
+ this.updateUi("⛔ Save w/o final graph patched.");
167
+ return false;
168
+ }
169
+ let filename = this.file.name || "workflow.json";
170
+ let filenames = filename.split(".");
171
+ filenames.pop();
172
+ filename = filenames.join(".");
173
+ filename += "_fixed.json";
174
+ filename = prompt("Save workflow as:", filename);
175
+ if (!filename)
176
+ return false;
177
+ if (!filename.toLowerCase().endsWith(".json")) {
178
+ filename += ".json";
179
+ }
180
+ const json = JSON.stringify(this.graphFinalResults.graph, null, 2);
181
+ const blob = new Blob([json], { type: "application/json" });
182
+ const url = URL.createObjectURL(blob);
183
+ const anchor = document.createElement("a");
184
+ anchor.download = filename;
185
+ anchor.href = url;
186
+ anchor.style.display = "none";
187
+ document.body.appendChild(anchor);
188
+ await wait();
189
+ anchor.click();
190
+ await wait();
191
+ anchor.remove();
192
+ window.URL.revokeObjectURL(url);
193
+ return true;
194
+ }
195
+ }
sd-dynamic-thresholding/.github/FUNDING.yml ADDED
@@ -0,0 +1 @@
 
 
1
+ github: mcmonkey4eva
sd-dynamic-thresholding/.github/workflows/publish.yml ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Publish to Comfy registry
2
+ on:
3
+ workflow_dispatch:
4
+ push:
5
+ branches:
6
+ - master
7
+ paths:
8
+ - "pyproject.toml"
9
+
10
+ jobs:
11
+ publish-node:
12
+ name: Publish Custom Node to registry
13
+ runs-on: ubuntu-latest
14
+ steps:
15
+ - name: Check out code
16
+ uses: actions/checkout@v4
17
+ - name: Publish Custom Node
18
+ uses: Comfy-Org/publish-node-action@main
19
+ with:
20
+ ## Add your own personal access token to your Github Repository secrets and reference it here.
21
+ personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }}
sd-dynamic-thresholding/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (435 Bytes). View file
 
sd-dynamic-thresholding/__pycache__/dynthres_comfyui.cpython-312.pyc ADDED
Binary file (4.19 kB). View file
 
sd-dynamic-thresholding/__pycache__/dynthres_core.cpython-312.pyc ADDED
Binary file (9.12 kB). View file
 
sd-dynamic-thresholding/github/comfy_node.png ADDED
sd-dynamic-thresholding/github/ui.png ADDED
sd-dynamic-thresholding/javascript/active.js ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ let dynthres_update_enabled = function() {
2
+ return Array.from(arguments);
3
+ };
4
+
5
+ (function(){
6
+ let accordions = {};
7
+ let enabled = {};
8
+ onUiUpdate(() => {
9
+ let accordion_id_prefix = "#dynthres_";
10
+ let extension_checkbox_class = ".dynthres-enabled";
11
+
12
+ dynthres_update_enabled = function() {
13
+ let res = Array.from(arguments);
14
+ let tabname = res[1] ? "img2img" : "txt2img";
15
+
16
+ let checkbox = accordions[tabname]?.querySelector(extension_checkbox_class + ' input');
17
+ checkbox?.dispatchEvent(new Event('change'));
18
+
19
+ return res;
20
+ };
21
+
22
+ function attachEnabledButtonListener(checkbox, accordion) {
23
+ let span = accordion.querySelector('.label-wrap span');
24
+ let badge = document.createElement('input');
25
+ badge.type = "checkbox";
26
+ badge.checked = checkbox.checked;
27
+ badge.addEventListener('click', (e) => {
28
+ checkbox.checked = !checkbox.checked;
29
+ badge.checked = checkbox.checked;
30
+ checkbox.dispatchEvent(new Event('change'));
31
+ e.stopPropagation();
32
+ });
33
+
34
+ badge.className = checkbox.className;
35
+ badge.classList.add('primary');
36
+ span.insertBefore(badge, span.firstChild);
37
+ let space = document.createElement('span');
38
+ space.innerHTML = "&nbsp;";
39
+ span.insertBefore(space, badge.nextSibling);
40
+
41
+ checkbox.addEventListener('change', () => {
42
+ let badge = accordion.querySelector('.label-wrap span input');
43
+ badge.checked = checkbox.checked;
44
+ });
45
+ checkbox.parentNode.style.display = "none";
46
+ }
47
+
48
+ if (Object.keys(accordions).length < 2) {
49
+ let accordion = gradioApp().querySelector(accordion_id_prefix + 'txt2img');
50
+ if (accordion) {
51
+ accordions.txt2img = accordion;
52
+ }
53
+ accordion = gradioApp().querySelector(accordion_id_prefix + 'img2img');
54
+ if (accordion) {
55
+ accordions.img2img = accordion;
56
+ }
57
+ }
58
+
59
+ if (Object.keys(accordions).length > 0 && accordions.txt2img && !enabled.txt2img) {
60
+ enabled.txt2img = accordions.txt2img.querySelector(extension_checkbox_class + ' input');
61
+ attachEnabledButtonListener(enabled.txt2img, accordions.txt2img);
62
+ }
63
+ if (Object.keys(accordions).length > 0 && accordions.img2img && !enabled.img2img) {
64
+ enabled.img2img = accordions.img2img.querySelector(extension_checkbox_class + ' input');
65
+ attachEnabledButtonListener(enabled.img2img, accordions.img2img);
66
+ }
67
+ });
68
+ })();
sd-dynamic-thresholding/scripts/dynamic_thresholding.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ##################
2
+ # Stable Diffusion Dynamic Thresholding (CFG Scale Fix)
3
+ #
4
+ # Author: Alex 'mcmonkey' Goodwin
5
+ # GitHub URL: https://github.com/mcmonkeyprojects/sd-dynamic-thresholding
6
+ # Created: 2022/01/26
7
+ # Last updated: 2023/01/30
8
+ #
9
+ # For usage help, view the README.md file in the extension root, or via the GitHub page.
10
+ #
11
+ ##################
12
+
13
+ import gradio as gr
14
+ import torch, traceback
15
+ import dynthres_core
16
+ from modules import scripts, script_callbacks, sd_samplers, sd_samplers_compvis, sd_samplers_common
17
+ try:
18
+ import dynthres_unipc
19
+ except Exception as e:
20
+ print(f"\n\n======\nError! UniPC sampler support failed to load! Is your WebUI up to date?\n(Error: {e})\n======")
21
+ try:
22
+ from modules.sd_samplers_kdiffusion import CFGDenoiserKDiffusion as cfgdenoisekdiff
23
+ IS_AUTO_16 = True
24
+ except Exception as e:
25
+ print(f"\n\n======\nWarning! Using legacy KDiff version! Is your WebUI up to date?\n======")
26
+ from modules.sd_samplers_kdiffusion import CFGDenoiser as cfgdenoisekdiff
27
+ IS_AUTO_16 = False
28
+
29
+ DISABLE_VISIBILITY = True
30
+
31
+ ######################### Data values #########################
32
+ MODES_WITH_VALUE = ["Power Up", "Power Down", "Linear Repeating", "Cosine Repeating", "Sawtooth"]
33
+
34
+ ######################### Script class entrypoint #########################
35
+ class Script(scripts.Script):
36
+
37
+ def title(self):
38
+ return "Dynamic Thresholding (CFG Scale Fix)"
39
+
40
+ def show(self, is_img2img):
41
+ return scripts.AlwaysVisible
42
+
43
+ def ui(self, is_img2img):
44
+ def vis_change(is_vis):
45
+ return {"visible": is_vis, "__type__": "update"}
46
+ # "Dynamic Thresholding (CFG Scale Fix)"
47
+ dtrue = gr.Checkbox(value=True, visible=False)
48
+ dfalse = gr.Checkbox(value=False, visible=False)
49
+ with gr.Accordion("Dynamic Thresholding (CFG Scale Fix)", open=False, elem_id="dynthres_" + ("img2img" if is_img2img else "txt2img")):
50
+ with gr.Row():
51
+ enabled = gr.Checkbox(value=False, label="Enable Dynamic Thresholding (CFG Scale Fix)", elem_classes=["dynthres-enabled"], elem_id='dynthres_enabled')
52
+ with gr.Group():
53
+ gr.HTML(value=f"View <a style=\"border-bottom: 1px #00ffff dotted;\" href=\"https://github.com/mcmonkeyprojects/sd-dynamic-thresholding/wiki/Usage-Tips\">the wiki for usage tips.</a><br><br>", elem_id='dynthres_wiki_link')
54
+ mimic_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='Mimic CFG Scale', value=7.0, elem_id='dynthres_mimic_scale')
55
+ with gr.Accordion("Advanced Options", open=False, elem_id='dynthres_advanced_opts'):
56
+ with gr.Row():
57
+ threshold_percentile = gr.Slider(minimum=90.0, value=100.0, maximum=100.0, step=0.05, label='Top percentile of latents to clamp', elem_id='dynthres_threshold_percentile')
58
+ interpolate_phi = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Interpolate Phi", value=1.0, elem_id='dynthres_interpolate_phi')
59
+ with gr.Row():
60
+ mimic_mode = gr.Dropdown(dynthres_core.DynThresh.Modes, value="Constant", label="Mimic Scale Scheduler", elem_id='dynthres_mimic_mode')
61
+ cfg_mode = gr.Dropdown(dynthres_core.DynThresh.Modes, value="Constant", label="CFG Scale Scheduler", elem_id='dynthres_cfg_mode')
62
+ mimic_scale_min = gr.Slider(minimum=0.0, maximum=30.0, step=0.5, visible=DISABLE_VISIBILITY, label="Minimum value of the Mimic Scale Scheduler", elem_id='dynthres_mimic_scale_min')
63
+ cfg_scale_min = gr.Slider(minimum=0.0, maximum=30.0, step=0.5, visible=DISABLE_VISIBILITY, label="Minimum value of the CFG Scale Scheduler", elem_id='dynthres_cfg_scale_min')
64
+ sched_val = gr.Slider(minimum=0.0, maximum=40.0, step=0.5, value=4.0, visible=DISABLE_VISIBILITY, label="Scheduler Value", info="Value unique to the scheduler mode - for Power Up/Down, this is the power. For Linear/Cosine Repeating, this is the number of repeats per image.", elem_id='dynthres_sched_val')
65
+ with gr.Row():
66
+ separate_feature_channels = gr.Checkbox(value=True, label="Separate Feature Channels", elem_id='dynthres_separate_feature_channels')
67
+ scaling_startpoint = gr.Radio(["ZERO", "MEAN"], value="MEAN", label="Scaling Startpoint")
68
+ variability_measure = gr.Radio(["STD", "AD"], value="AD", label="Variability Measure")
69
+ def should_show_scheduler_value(cfg_mode, mimic_mode):
70
+ sched_vis = cfg_mode in MODES_WITH_VALUE or mimic_mode in MODES_WITH_VALUE or DISABLE_VISIBILITY
71
+ return vis_change(sched_vis), vis_change(mimic_mode != "Constant" or DISABLE_VISIBILITY), vis_change(cfg_mode != "Constant" or DISABLE_VISIBILITY)
72
+ cfg_mode.change(should_show_scheduler_value, inputs=[cfg_mode, mimic_mode], outputs=[sched_val, mimic_scale_min, cfg_scale_min])
73
+ mimic_mode.change(should_show_scheduler_value, inputs=[cfg_mode, mimic_mode], outputs=[sched_val, mimic_scale_min, cfg_scale_min])
74
+ enabled.change(
75
+ _js="dynthres_update_enabled",
76
+ fn=None,
77
+ inputs=[enabled, dtrue if is_img2img else dfalse],
78
+ show_progress = False)
79
+ self.infotext_fields = (
80
+ (enabled, lambda d: gr.Checkbox.update(value="Dynamic thresholding enabled" in d)),
81
+ (mimic_scale, "Mimic scale"),
82
+ (separate_feature_channels, "Separate Feature Channels"),
83
+ (scaling_startpoint, lambda d: gr.Radio.update(value=d.get("Scaling Startpoint", "MEAN"))),
84
+ (variability_measure, lambda d: gr.Radio.update(value=d.get("Variability Measure", "AD"))),
85
+ (interpolate_phi, "Interpolate Phi"),
86
+ (threshold_percentile, "Threshold percentile"),
87
+ (mimic_scale_min, "Mimic scale minimum"),
88
+ (mimic_mode, lambda d: gr.Dropdown.update(value=d.get("Mimic mode", "Constant"))),
89
+ (cfg_mode, lambda d: gr.Dropdown.update(value=d.get("CFG mode", "Constant"))),
90
+ (cfg_scale_min, "CFG scale minimum"),
91
+ (sched_val, "Scheduler value"))
92
+ return [enabled, mimic_scale, threshold_percentile, mimic_mode, mimic_scale_min, cfg_mode, cfg_scale_min, sched_val, separate_feature_channels, scaling_startpoint, variability_measure, interpolate_phi]
93
+
94
+ last_id = 0
95
+
96
+ def process_batch(self, p, enabled, mimic_scale, threshold_percentile, mimic_mode, mimic_scale_min, cfg_mode, cfg_scale_min, sched_val, separate_feature_channels, scaling_startpoint, variability_measure, interpolate_phi, batch_number, prompts, seeds, subseeds):
97
+ enabled = getattr(p, 'dynthres_enabled', enabled)
98
+ if not enabled:
99
+ return
100
+ orig_sampler_name = p.sampler_name
101
+ orig_latent_sampler_name = getattr(p, 'latent_sampler', None)
102
+ if orig_sampler_name in ["DDIM", "PLMS"]:
103
+ raise RuntimeError(f"Cannot use sampler {orig_sampler_name} with Dynamic Thresholding")
104
+ if orig_latent_sampler_name in ["DDIM", "PLMS"]:
105
+ raise RuntimeError(f"Cannot use secondary sampler {orig_latent_sampler_name} with Dynamic Thresholding")
106
+ if 'UniPC' in (orig_sampler_name, orig_latent_sampler_name) and p.enable_hr:
107
+ raise RuntimeError(f"UniPC does not support Hires Fix. Auto WebUI silently swaps to DDIM for this, which DynThresh does not support. Please swap to a sampler capable of img2img processing for HR Fix to work.")
108
+ mimic_scale = getattr(p, 'dynthres_mimic_scale', mimic_scale)
109
+ separate_feature_channels = getattr(p, 'dynthres_separate_feature_channels', separate_feature_channels)
110
+ scaling_startpoint = getattr(p, 'dynthres_scaling_startpoint', scaling_startpoint)
111
+ variability_measure = getattr(p, 'dynthres_variability_measure', variability_measure)
112
+ interpolate_phi = getattr(p, 'dynthres_interpolate_phi', interpolate_phi)
113
+ threshold_percentile = getattr(p, 'dynthres_threshold_percentile', threshold_percentile)
114
+ mimic_mode = getattr(p, 'dynthres_mimic_mode', mimic_mode)
115
+ mimic_scale_min = getattr(p, 'dynthres_mimic_scale_min', mimic_scale_min)
116
+ cfg_mode = getattr(p, 'dynthres_cfg_mode', cfg_mode)
117
+ cfg_scale_min = getattr(p, 'dynthres_cfg_scale_min', cfg_scale_min)
118
+ experiment_mode = getattr(p, 'dynthres_experiment_mode', 0)
119
+ sched_val = getattr(p, 'dynthres_scheduler_val', sched_val)
120
+ p.extra_generation_params["Dynamic thresholding enabled"] = True
121
+ p.extra_generation_params["Mimic scale"] = mimic_scale
122
+ p.extra_generation_params["Separate Feature Channels"] = separate_feature_channels
123
+ p.extra_generation_params["Scaling Startpoint"] = scaling_startpoint
124
+ p.extra_generation_params["Variability Measure"] = variability_measure
125
+ p.extra_generation_params["Interpolate Phi"] = interpolate_phi
126
+ p.extra_generation_params["Threshold percentile"] = threshold_percentile
127
+ p.extra_generation_params["Sampler"] = orig_sampler_name
128
+ if mimic_mode != "Constant":
129
+ p.extra_generation_params["Mimic mode"] = mimic_mode
130
+ p.extra_generation_params["Mimic scale minimum"] = mimic_scale_min
131
+ if cfg_mode != "Constant":
132
+ p.extra_generation_params["CFG mode"] = cfg_mode
133
+ p.extra_generation_params["CFG scale minimum"] = cfg_scale_min
134
+ if cfg_mode in MODES_WITH_VALUE or mimic_mode in MODES_WITH_VALUE:
135
+ p.extra_generation_params["Scheduler value"] = sched_val
136
+ # Note: the ID number is to protect the edge case of multiple simultaneous runs with different settings
137
+ Script.last_id += 1
138
+ # Percentage to portion
139
+ threshold_percentile *= 0.01
140
+
141
+ def make_sampler(orig_sampler_name):
142
+ fixed_sampler_name = f"{orig_sampler_name}_dynthres{Script.last_id}"
143
+
144
+ # Make a placeholder sampler
145
+ sampler = sd_samplers.all_samplers_map[orig_sampler_name]
146
+ dt_data = dynthres_core.DynThresh(mimic_scale, threshold_percentile, mimic_mode, mimic_scale_min, cfg_mode, cfg_scale_min, sched_val, experiment_mode, p.steps, separate_feature_channels, scaling_startpoint, variability_measure, interpolate_phi)
147
+ if orig_sampler_name == "UniPC":
148
+ def unipc_constructor(model):
149
+ return CustomVanillaSDSampler(dynthres_unipc.CustomUniPCSampler, model, dt_data)
150
+ new_sampler = sd_samplers_common.SamplerData(fixed_sampler_name, unipc_constructor, sampler.aliases, sampler.options)
151
+ else:
152
+ def new_constructor(model):
153
+ result = sampler.constructor(model)
154
+ cfg = CustomCFGDenoiser(result if IS_AUTO_16 else result.model_wrap_cfg.inner_model, dt_data)
155
+ result.model_wrap_cfg = cfg
156
+ return result
157
+ new_sampler = sd_samplers_common.SamplerData(fixed_sampler_name, new_constructor, sampler.aliases, sampler.options)
158
+ return fixed_sampler_name, new_sampler
159
+
160
+ # Apply for usage
161
+ p.orig_sampler_name = orig_sampler_name
162
+ p.orig_latent_sampler_name = orig_latent_sampler_name
163
+ p.fixed_samplers = []
164
+
165
+ if orig_latent_sampler_name:
166
+ latent_sampler_name, latent_sampler = make_sampler(orig_latent_sampler_name)
167
+ sd_samplers.all_samplers_map[latent_sampler_name] = latent_sampler
168
+ p.fixed_samplers.append(latent_sampler_name)
169
+ p.latent_sampler = latent_sampler_name
170
+
171
+ if orig_sampler_name != orig_latent_sampler_name:
172
+ p.sampler_name, new_sampler = make_sampler(orig_sampler_name)
173
+ sd_samplers.all_samplers_map[p.sampler_name] = new_sampler
174
+ p.fixed_samplers.append(p.sampler_name)
175
+ else:
176
+ p.sampler_name = p.latent_sampler
177
+
178
+ if p.sampler is not None:
179
+ p.sampler = sd_samplers.create_sampler(p.sampler_name, p.sd_model)
180
+
181
+ def postprocess_batch(self, p, enabled, mimic_scale, threshold_percentile, mimic_mode, mimic_scale_min, cfg_mode, cfg_scale_min, sched_val, separate_feature_channels, scaling_startpoint, variability_measure, interpolate_phi, batch_number, images):
182
+ if not enabled or not hasattr(p, 'orig_sampler_name'):
183
+ return
184
+ p.sampler_name = p.orig_sampler_name
185
+ if p.orig_latent_sampler_name:
186
+ p.latent_sampler = p.orig_latent_sampler_name
187
+ for added_sampler in p.fixed_samplers:
188
+ del sd_samplers.all_samplers_map[added_sampler]
189
+ del p.fixed_samplers
190
+ del p.orig_sampler_name
191
+ del p.orig_latent_sampler_name
192
+
193
+ ######################### CompVis Implementation logic #########################
194
+
195
+ class CustomVanillaSDSampler(sd_samplers_compvis.VanillaStableDiffusionSampler):
196
+ def __init__(self, constructor, sd_model, dt_data):
197
+ super().__init__(constructor, sd_model)
198
+ self.sampler.main_class = dt_data
199
+
200
+ ######################### K-Diffusion Implementation logic #########################
201
+
202
+ class CustomCFGDenoiser(cfgdenoisekdiff):
203
+ def __init__(self, model, dt_data):
204
+ super().__init__(model)
205
+ self.main_class = dt_data
206
+
207
+ def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
208
+ if isinstance(uncond, dict) and 'crossattn' in uncond:
209
+ uncond = uncond['crossattn']
210
+ denoised_uncond = x_out[-uncond.shape[0]:]
211
+ # conds_list shape is (batch, cond, 2)
212
+ weights = torch.tensor(conds_list, device=uncond.device).select(2, 1)
213
+ weights = weights.reshape(*weights.shape, 1, 1, 1)
214
+ self.main_class.step = self.step
215
+ if hasattr(self, 'total_steps'):
216
+ self.main_class.max_steps = self.total_steps
217
+
218
+ if self.main_class.experiment_mode >= 4 and self.main_class.experiment_mode <= 5:
219
+ # https://arxiv.org/pdf/2305.08891.pdf "Rescale CFG". It's not good, but if you want to test it, just set experiment_mode = 4 + phi.
220
+ denoised = torch.clone(denoised_uncond)
221
+ fi = self.main_class.experiment_mode - 4.0
222
+ for i, conds in enumerate(conds_list):
223
+ for cond_index, weight in conds:
224
+ xcfg = (denoised_uncond[i] + (x_out[cond_index] - denoised_uncond[i]) * (cond_scale * weight))
225
+ xrescaled = xcfg * (torch.std(x_out[cond_index]) / torch.std(xcfg))
226
+ xfinal = fi * xrescaled + (1.0 - fi) * xcfg
227
+ denoised[i] = xfinal
228
+ return denoised
229
+
230
+ return self.main_class.dynthresh(x_out[:-uncond.shape[0]], denoised_uncond, cond_scale, weights)
231
+
232
+ ######################### XYZ Plot Script Support logic #########################
233
+
234
+ def make_axis_options():
235
+ xyz_grid = [x for x in scripts.scripts_data if x.script_class.__module__ in ("xyz_grid.py", "scripts.xyz_grid")][0].module
236
+ def apply_mimic_scale(p, x, xs):
237
+ if x != 0:
238
+ setattr(p, "dynthres_enabled", True)
239
+ setattr(p, "dynthres_mimic_scale", x)
240
+ else:
241
+ setattr(p, "dynthres_enabled", False)
242
+ def confirm_scheduler(p, xs):
243
+ for x in xs:
244
+ if x not in dynthres_core.DynThresh.Modes:
245
+ raise RuntimeError(f"Unknown Scheduler: {x}")
246
+ extra_axis_options = [
247
+ xyz_grid.AxisOption("[DynThres] Mimic Scale", float, apply_mimic_scale),
248
+ xyz_grid.AxisOption("[DynThres] Separate Feature Channels", int,
249
+ xyz_grid.apply_field("dynthres_separate_feature_channels")),
250
+ xyz_grid.AxisOption("[DynThres] Scaling Startpoint", str, xyz_grid.apply_field("dynthres_scaling_startpoint"), choices=lambda:['ZERO', 'MEAN']),
251
+ xyz_grid.AxisOption("[DynThres] Variability Measure", str, xyz_grid.apply_field("dynthres_variability_measure"), choices=lambda:['STD', 'AD']),
252
+ xyz_grid.AxisOption("[DynThres] Interpolate Phi", float, xyz_grid.apply_field("dynthres_interpolate_phi")),
253
+ xyz_grid.AxisOption("[DynThres] Threshold Percentile", float, xyz_grid.apply_field("dynthres_threshold_percentile")),
254
+ xyz_grid.AxisOption("[DynThres] Mimic Scheduler", str, xyz_grid.apply_field("dynthres_mimic_mode"), confirm=confirm_scheduler, choices=lambda: dynthres_core.DynThresh.Modes),
255
+ xyz_grid.AxisOption("[DynThres] Mimic minimum", float, xyz_grid.apply_field("dynthres_mimic_scale_min")),
256
+ xyz_grid.AxisOption("[DynThres] CFG Scheduler", str, xyz_grid.apply_field("dynthres_cfg_mode"), confirm=confirm_scheduler, choices=lambda: dynthres_core.DynThresh.Modes),
257
+ xyz_grid.AxisOption("[DynThres] CFG minimum", float, xyz_grid.apply_field("dynthres_cfg_scale_min")),
258
+ xyz_grid.AxisOption("[DynThres] Scheduler value", float, xyz_grid.apply_field("dynthres_scheduler_val"))
259
+ ]
260
+ if not any("[DynThres]" in x.label for x in xyz_grid.axis_options):
261
+ xyz_grid.axis_options.extend(extra_axis_options)
262
+
263
+ def callback_before_ui():
264
+ try:
265
+ make_axis_options()
266
+ except Exception as e:
267
+ traceback.print_exc()
268
+ print(f"Failed to add support for X/Y/Z Plot Script because: {e}")
269
+
270
+ script_callbacks.on_before_ui(callback_before_ui)
sigmas_tools_and_the_golden_scheduler/.github/workflows/publish.yml ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Publish to Comfy registry
2
+ on:
3
+ workflow_dispatch:
4
+ push:
5
+ branches:
6
+ - main
7
+ paths:
8
+ - "pyproject.toml"
9
+
10
+ jobs:
11
+ publish-node:
12
+ name: Publish Custom Node to registry
13
+ runs-on: ubuntu-latest
14
+ steps:
15
+ - name: Check out code
16
+ uses: actions/checkout@v4
17
+ - name: Publish Custom Node
18
+ uses: Comfy-Org/publish-node-action@main
19
+ with:
20
+ ## Add your own personal access token to your Github Repository secrets and reference it here.
21
+ personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }}
sigmas_tools_and_the_golden_scheduler/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (273 Bytes). View file
 
sigmas_tools_and_the_golden_scheduler/__pycache__/sigmas_merge.cpython-312.pyc ADDED
Binary file (20.2 kB). View file
 
stable-diffusion-temperature-settings/.github/FUNDING.yml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # These are supported funding model platforms
2
+
3
+ patreon: extraltodeus
stable-diffusion-temperature-settings/.github/workflows/publish.yml ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Publish to Comfy registry
2
+ on:
3
+ workflow_dispatch:
4
+ push:
5
+ branches:
6
+ - main
7
+ - master
8
+ paths:
9
+ - "pyproject.toml"
10
+
11
+ jobs:
12
+ publish-node:
13
+ name: Publish Custom Node to registry
14
+ runs-on: ubuntu-latest
15
+ steps:
16
+ - name: Check out code
17
+ uses: actions/checkout@v4
18
+ - name: Publish Custom Node
19
+ uses: Comfy-Org/publish-node-action@main
20
+ with:
21
+ ## Add your own personal access token to your Github Repository secrets and reference it here.
22
+ personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }}
stable-diffusion-temperature-settings/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (342 Bytes). View file
 
stable-diffusion-temperature-settings/__pycache__/nodes.cpython-312.pyc ADDED
Binary file (14.5 kB). View file
 
stable-diffusion-temperature-settings/workflows/tinybottle.png ADDED
ultimate-upscale-for-automatic1111/scripts/ultimate-upscale.py ADDED
@@ -0,0 +1,569 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import gradio as gr
3
+ from PIL import Image, ImageDraw, ImageOps
4
+ from modules import processing, shared, images, devices, scripts
5
+ from modules.processing import StableDiffusionProcessing
6
+ from modules.processing import Processed
7
+ from modules.shared import opts, state
8
+ from enum import Enum
9
+
10
+ elem_id_prefix = "ultimateupscale"
11
+
12
+ class USDUMode(Enum):
13
+ LINEAR = 0
14
+ CHESS = 1
15
+ NONE = 2
16
+
17
+ class USDUSFMode(Enum):
18
+ NONE = 0
19
+ BAND_PASS = 1
20
+ HALF_TILE = 2
21
+ HALF_TILE_PLUS_INTERSECTIONS = 3
22
+
23
+ class USDUpscaler():
24
+
25
+ def __init__(self, p, image, upscaler_index:int, save_redraw, save_seams_fix, tile_width, tile_height) -> None:
26
+ self.p:StableDiffusionProcessing = p
27
+ self.image:Image = image
28
+ self.scale_factor = math.ceil(max(p.width, p.height) / max(image.width, image.height))
29
+ self.upscaler = shared.sd_upscalers[upscaler_index]
30
+ self.redraw = USDURedraw()
31
+ self.redraw.save = save_redraw
32
+ self.redraw.tile_width = tile_width if tile_width > 0 else tile_height
33
+ self.redraw.tile_height = tile_height if tile_height > 0 else tile_width
34
+ self.seams_fix = USDUSeamsFix()
35
+ self.seams_fix.save = save_seams_fix
36
+ self.seams_fix.tile_width = tile_width if tile_width > 0 else tile_height
37
+ self.seams_fix.tile_height = tile_height if tile_height > 0 else tile_width
38
+ self.initial_info = None
39
+ self.rows = math.ceil(self.p.height / self.redraw.tile_height)
40
+ self.cols = math.ceil(self.p.width / self.redraw.tile_width)
41
+
42
+ def get_factor(self, num):
43
+ # Its just return, don't need elif
44
+ if num == 1:
45
+ return 2
46
+ if num % 4 == 0:
47
+ return 4
48
+ if num % 3 == 0:
49
+ return 3
50
+ if num % 2 == 0:
51
+ return 2
52
+ return 0
53
+
54
+ def get_factors(self):
55
+ scales = []
56
+ current_scale = 1
57
+ current_scale_factor = self.get_factor(self.scale_factor)
58
+ while current_scale_factor == 0:
59
+ self.scale_factor += 1
60
+ current_scale_factor = self.get_factor(self.scale_factor)
61
+ while current_scale < self.scale_factor:
62
+ current_scale_factor = self.get_factor(self.scale_factor // current_scale)
63
+ scales.append(current_scale_factor)
64
+ current_scale = current_scale * current_scale_factor
65
+ if current_scale_factor == 0:
66
+ break
67
+ self.scales = enumerate(scales)
68
+
69
+ def upscale(self):
70
+ # Log info
71
+ print(f"Canva size: {self.p.width}x{self.p.height}")
72
+ print(f"Image size: {self.image.width}x{self.image.height}")
73
+ print(f"Scale factor: {self.scale_factor}")
74
+ # Check upscaler is not empty
75
+ if self.upscaler.name == "None":
76
+ self.image = self.image.resize((self.p.width, self.p.height), resample=Image.LANCZOS)
77
+ return
78
+ # Get list with scale factors
79
+ self.get_factors()
80
+ # Upscaling image over all factors
81
+ for index, value in self.scales:
82
+ print(f"Upscaling iteration {index+1} with scale factor {value}")
83
+ self.image = self.upscaler.scaler.upscale(self.image, value, self.upscaler.data_path)
84
+ # Resize image to set values
85
+ self.image = self.image.resize((self.p.width, self.p.height), resample=Image.LANCZOS)
86
+
87
+ def setup_redraw(self, redraw_mode, padding, mask_blur):
88
+ self.redraw.mode = USDUMode(redraw_mode)
89
+ self.redraw.enabled = self.redraw.mode != USDUMode.NONE
90
+ self.redraw.padding = padding
91
+ self.p.mask_blur = mask_blur
92
+
93
+ def setup_seams_fix(self, padding, denoise, mask_blur, width, mode):
94
+ self.seams_fix.padding = padding
95
+ self.seams_fix.denoise = denoise
96
+ self.seams_fix.mask_blur = mask_blur
97
+ self.seams_fix.width = width
98
+ self.seams_fix.mode = USDUSFMode(mode)
99
+ self.seams_fix.enabled = self.seams_fix.mode != USDUSFMode.NONE
100
+
101
+ def save_image(self):
102
+ if type(self.p.prompt) != list:
103
+ images.save_image(self.image, self.p.outpath_samples, "", self.p.seed, self.p.prompt, opts.samples_format, info=self.initial_info, p=self.p)
104
+ else:
105
+ images.save_image(self.image, self.p.outpath_samples, "", self.p.seed, self.p.prompt[0], opts.samples_format, info=self.initial_info, p=self.p)
106
+
107
+ def calc_jobs_count(self):
108
+ redraw_job_count = (self.rows * self.cols) if self.redraw.enabled else 0
109
+ seams_job_count = 0
110
+ if self.seams_fix.mode == USDUSFMode.BAND_PASS:
111
+ seams_job_count = self.rows + self.cols - 2
112
+ elif self.seams_fix.mode == USDUSFMode.HALF_TILE:
113
+ seams_job_count = self.rows * (self.cols - 1) + (self.rows - 1) * self.cols
114
+ elif self.seams_fix.mode == USDUSFMode.HALF_TILE_PLUS_INTERSECTIONS:
115
+ seams_job_count = self.rows * (self.cols - 1) + (self.rows - 1) * self.cols + (self.rows - 1) * (self.cols - 1)
116
+
117
+ state.job_count = redraw_job_count + seams_job_count
118
+
119
+ def print_info(self):
120
+ print(f"Tile size: {self.redraw.tile_width}x{self.redraw.tile_height}")
121
+ print(f"Tiles amount: {self.rows * self.cols}")
122
+ print(f"Grid: {self.rows}x{self.cols}")
123
+ print(f"Redraw enabled: {self.redraw.enabled}")
124
+ print(f"Seams fix mode: {self.seams_fix.mode.name}")
125
+
126
+ def add_extra_info(self):
127
+ self.p.extra_generation_params["Ultimate SD upscale upscaler"] = self.upscaler.name
128
+ self.p.extra_generation_params["Ultimate SD upscale tile_width"] = self.redraw.tile_width
129
+ self.p.extra_generation_params["Ultimate SD upscale tile_height"] = self.redraw.tile_height
130
+ self.p.extra_generation_params["Ultimate SD upscale mask_blur"] = self.p.mask_blur
131
+ self.p.extra_generation_params["Ultimate SD upscale padding"] = self.redraw.padding
132
+
133
+ def process(self):
134
+ state.begin()
135
+ self.calc_jobs_count()
136
+ self.result_images = []
137
+ if self.redraw.enabled:
138
+ self.image = self.redraw.start(self.p, self.image, self.rows, self.cols)
139
+ self.initial_info = self.redraw.initial_info
140
+ self.result_images.append(self.image)
141
+ if self.redraw.save:
142
+ self.save_image()
143
+
144
+ if self.seams_fix.enabled:
145
+ self.image = self.seams_fix.start(self.p, self.image, self.rows, self.cols)
146
+ self.initial_info = self.seams_fix.initial_info
147
+ self.result_images.append(self.image)
148
+ if self.seams_fix.save:
149
+ self.save_image()
150
+ state.end()
151
+
152
+ class USDURedraw():
153
+
154
+ def init_draw(self, p, width, height):
155
+ p.inpaint_full_res = True
156
+ p.inpaint_full_res_padding = self.padding
157
+ p.width = math.ceil((self.tile_width+self.padding) / 64) * 64
158
+ p.height = math.ceil((self.tile_height+self.padding) / 64) * 64
159
+ mask = Image.new("L", (width, height), "black")
160
+ draw = ImageDraw.Draw(mask)
161
+ return mask, draw
162
+
163
+ def calc_rectangle(self, xi, yi):
164
+ x1 = xi * self.tile_width
165
+ y1 = yi * self.tile_height
166
+ x2 = xi * self.tile_width + self.tile_width
167
+ y2 = yi * self.tile_height + self.tile_height
168
+
169
+ return x1, y1, x2, y2
170
+
171
+ def linear_process(self, p, image, rows, cols):
172
+ mask, draw = self.init_draw(p, image.width, image.height)
173
+ for yi in range(rows):
174
+ for xi in range(cols):
175
+ if state.interrupted:
176
+ break
177
+ draw.rectangle(self.calc_rectangle(xi, yi), fill="white")
178
+ p.init_images = [image]
179
+ p.image_mask = mask
180
+ processed = processing.process_images(p)
181
+ draw.rectangle(self.calc_rectangle(xi, yi), fill="black")
182
+ if (len(processed.images) > 0):
183
+ image = processed.images[0]
184
+
185
+ p.width = image.width
186
+ p.height = image.height
187
+ self.initial_info = processed.infotext(p, 0)
188
+
189
+ return image
190
+
191
+ def chess_process(self, p, image, rows, cols):
192
+ mask, draw = self.init_draw(p, image.width, image.height)
193
+ tiles = []
194
+ # calc tiles colors
195
+ for yi in range(rows):
196
+ for xi in range(cols):
197
+ if state.interrupted:
198
+ break
199
+ if xi == 0:
200
+ tiles.append([])
201
+ color = xi % 2 == 0
202
+ if yi > 0 and yi % 2 != 0:
203
+ color = not color
204
+ tiles[yi].append(color)
205
+
206
+ for yi in range(len(tiles)):
207
+ for xi in range(len(tiles[yi])):
208
+ if state.interrupted:
209
+ break
210
+ if not tiles[yi][xi]:
211
+ tiles[yi][xi] = not tiles[yi][xi]
212
+ continue
213
+ tiles[yi][xi] = not tiles[yi][xi]
214
+ draw.rectangle(self.calc_rectangle(xi, yi), fill="white")
215
+ p.init_images = [image]
216
+ p.image_mask = mask
217
+ processed = processing.process_images(p)
218
+ draw.rectangle(self.calc_rectangle(xi, yi), fill="black")
219
+ if (len(processed.images) > 0):
220
+ image = processed.images[0]
221
+
222
+ for yi in range(len(tiles)):
223
+ for xi in range(len(tiles[yi])):
224
+ if state.interrupted:
225
+ break
226
+ if not tiles[yi][xi]:
227
+ continue
228
+ draw.rectangle(self.calc_rectangle(xi, yi), fill="white")
229
+ p.init_images = [image]
230
+ p.image_mask = mask
231
+ processed = processing.process_images(p)
232
+ draw.rectangle(self.calc_rectangle(xi, yi), fill="black")
233
+ if (len(processed.images) > 0):
234
+ image = processed.images[0]
235
+
236
+ p.width = image.width
237
+ p.height = image.height
238
+ self.initial_info = processed.infotext(p, 0)
239
+
240
+ return image
241
+
242
+ def start(self, p, image, rows, cols):
243
+ self.initial_info = None
244
+ if self.mode == USDUMode.LINEAR:
245
+ return self.linear_process(p, image, rows, cols)
246
+ if self.mode == USDUMode.CHESS:
247
+ return self.chess_process(p, image, rows, cols)
248
+
249
+ class USDUSeamsFix():
250
+
251
+ def init_draw(self, p):
252
+ self.initial_info = None
253
+ p.width = math.ceil((self.tile_width+self.padding) / 64) * 64
254
+ p.height = math.ceil((self.tile_height+self.padding) / 64) * 64
255
+
256
+ def half_tile_process(self, p, image, rows, cols):
257
+
258
+ self.init_draw(p)
259
+ processed = None
260
+
261
+ gradient = Image.linear_gradient("L")
262
+ row_gradient = Image.new("L", (self.tile_width, self.tile_height), "black")
263
+ row_gradient.paste(gradient.resize(
264
+ (self.tile_width, self.tile_height//2), resample=Image.BICUBIC), (0, 0))
265
+ row_gradient.paste(gradient.rotate(180).resize(
266
+ (self.tile_width, self.tile_height//2), resample=Image.BICUBIC),
267
+ (0, self.tile_height//2))
268
+ col_gradient = Image.new("L", (self.tile_width, self.tile_height), "black")
269
+ col_gradient.paste(gradient.rotate(90).resize(
270
+ (self.tile_width//2, self.tile_height), resample=Image.BICUBIC), (0, 0))
271
+ col_gradient.paste(gradient.rotate(270).resize(
272
+ (self.tile_width//2, self.tile_height), resample=Image.BICUBIC), (self.tile_width//2, 0))
273
+
274
+ p.denoising_strength = self.denoise
275
+ p.mask_blur = self.mask_blur
276
+
277
+ for yi in range(rows-1):
278
+ for xi in range(cols):
279
+ if state.interrupted:
280
+ break
281
+ p.width = self.tile_width
282
+ p.height = self.tile_height
283
+ p.inpaint_full_res = True
284
+ p.inpaint_full_res_padding = self.padding
285
+ mask = Image.new("L", (image.width, image.height), "black")
286
+ mask.paste(row_gradient, (xi*self.tile_width, yi*self.tile_height + self.tile_height//2))
287
+
288
+ p.init_images = [image]
289
+ p.image_mask = mask
290
+ processed = processing.process_images(p)
291
+ if (len(processed.images) > 0):
292
+ image = processed.images[0]
293
+
294
+ for yi in range(rows):
295
+ for xi in range(cols-1):
296
+ if state.interrupted:
297
+ break
298
+ p.width = self.tile_width
299
+ p.height = self.tile_height
300
+ p.inpaint_full_res = True
301
+ p.inpaint_full_res_padding = self.padding
302
+ mask = Image.new("L", (image.width, image.height), "black")
303
+ mask.paste(col_gradient, (xi*self.tile_width+self.tile_width//2, yi*self.tile_height))
304
+
305
+ p.init_images = [image]
306
+ p.image_mask = mask
307
+ processed = processing.process_images(p)
308
+ if (len(processed.images) > 0):
309
+ image = processed.images[0]
310
+
311
+ p.width = image.width
312
+ p.height = image.height
313
+ if processed is not None:
314
+ self.initial_info = processed.infotext(p, 0)
315
+
316
+ return image
317
+
318
+ def half_tile_process_corners(self, p, image, rows, cols):
319
+ fixed_image = self.half_tile_process(p, image, rows, cols)
320
+ processed = None
321
+ self.init_draw(p)
322
+ gradient = Image.radial_gradient("L").resize(
323
+ (self.tile_width, self.tile_height), resample=Image.BICUBIC)
324
+ gradient = ImageOps.invert(gradient)
325
+ p.denoising_strength = self.denoise
326
+ #p.mask_blur = 0
327
+ p.mask_blur = self.mask_blur
328
+
329
+ for yi in range(rows-1):
330
+ for xi in range(cols-1):
331
+ if state.interrupted:
332
+ break
333
+ p.width = self.tile_width
334
+ p.height = self.tile_height
335
+ p.inpaint_full_res = True
336
+ p.inpaint_full_res_padding = 0
337
+ mask = Image.new("L", (fixed_image.width, fixed_image.height), "black")
338
+ mask.paste(gradient, (xi*self.tile_width + self.tile_width//2,
339
+ yi*self.tile_height + self.tile_height//2))
340
+
341
+ p.init_images = [fixed_image]
342
+ p.image_mask = mask
343
+ processed = processing.process_images(p)
344
+ if (len(processed.images) > 0):
345
+ fixed_image = processed.images[0]
346
+
347
+ p.width = fixed_image.width
348
+ p.height = fixed_image.height
349
+ if processed is not None:
350
+ self.initial_info = processed.infotext(p, 0)
351
+
352
+ return fixed_image
353
+
354
+ def band_pass_process(self, p, image, cols, rows):
355
+
356
+ self.init_draw(p)
357
+ processed = None
358
+
359
+ p.denoising_strength = self.denoise
360
+ p.mask_blur = 0
361
+
362
+ gradient = Image.linear_gradient("L")
363
+ mirror_gradient = Image.new("L", (256, 256), "black")
364
+ mirror_gradient.paste(gradient.resize((256, 128), resample=Image.BICUBIC), (0, 0))
365
+ mirror_gradient.paste(gradient.rotate(180).resize((256, 128), resample=Image.BICUBIC), (0, 128))
366
+
367
+ row_gradient = mirror_gradient.resize((image.width, self.width), resample=Image.BICUBIC)
368
+ col_gradient = mirror_gradient.rotate(90).resize((self.width, image.height), resample=Image.BICUBIC)
369
+
370
+ for xi in range(1, rows):
371
+ if state.interrupted:
372
+ break
373
+ p.width = self.width + self.padding * 2
374
+ p.height = image.height
375
+ p.inpaint_full_res = True
376
+ p.inpaint_full_res_padding = self.padding
377
+ mask = Image.new("L", (image.width, image.height), "black")
378
+ mask.paste(col_gradient, (xi * self.tile_width - self.width // 2, 0))
379
+
380
+ p.init_images = [image]
381
+ p.image_mask = mask
382
+ processed = processing.process_images(p)
383
+ if (len(processed.images) > 0):
384
+ image = processed.images[0]
385
+ for yi in range(1, cols):
386
+ if state.interrupted:
387
+ break
388
+ p.width = image.width
389
+ p.height = self.width + self.padding * 2
390
+ p.inpaint_full_res = True
391
+ p.inpaint_full_res_padding = self.padding
392
+ mask = Image.new("L", (image.width, image.height), "black")
393
+ mask.paste(row_gradient, (0, yi * self.tile_height - self.width // 2))
394
+
395
+ p.init_images = [image]
396
+ p.image_mask = mask
397
+ processed = processing.process_images(p)
398
+ if (len(processed.images) > 0):
399
+ image = processed.images[0]
400
+
401
+ p.width = image.width
402
+ p.height = image.height
403
+ if processed is not None:
404
+ self.initial_info = processed.infotext(p, 0)
405
+
406
+ return image
407
+
408
+ def start(self, p, image, rows, cols):
409
+ if USDUSFMode(self.mode) == USDUSFMode.BAND_PASS:
410
+ return self.band_pass_process(p, image, rows, cols)
411
+ elif USDUSFMode(self.mode) == USDUSFMode.HALF_TILE:
412
+ return self.half_tile_process(p, image, rows, cols)
413
+ elif USDUSFMode(self.mode) == USDUSFMode.HALF_TILE_PLUS_INTERSECTIONS:
414
+ return self.half_tile_process_corners(p, image, rows, cols)
415
+ else:
416
+ return image
417
+
418
+ class Script(scripts.Script):
419
+ def title(self):
420
+ return "Ultimate SD upscale"
421
+
422
+ def show(self, is_img2img):
423
+ return is_img2img
424
+
425
+ def ui(self, is_img2img):
426
+
427
+ target_size_types = [
428
+ "From img2img2 settings",
429
+ "Custom size",
430
+ "Scale from image size"
431
+ ]
432
+
433
+ seams_fix_types = [
434
+ "None",
435
+ "Band pass",
436
+ "Half tile offset pass",
437
+ "Half tile offset pass + intersections"
438
+ ]
439
+
440
+ redrow_modes = [
441
+ "Linear",
442
+ "Chess",
443
+ "None"
444
+ ]
445
+
446
+ info = gr.HTML(
447
+ "<p style=\"margin-bottom:0.75em\">Will upscale the image depending on the selected target size type</p>")
448
+
449
+ with gr.Row():
450
+ target_size_type = gr.Dropdown(label="Target size type", elem_id=f"{elem_id_prefix}_target_size_type", choices=[k for k in target_size_types], type="index",
451
+ value=next(iter(target_size_types)))
452
+
453
+ custom_width = gr.Slider(label='Custom width', elem_id=f"{elem_id_prefix}_custom_width", minimum=64, maximum=8192, step=64, value=2048, visible=False, interactive=True)
454
+ custom_height = gr.Slider(label='Custom height', elem_id=f"{elem_id_prefix}_custom_height", minimum=64, maximum=8192, step=64, value=2048, visible=False, interactive=True)
455
+ custom_scale = gr.Slider(label='Scale', elem_id=f"{elem_id_prefix}_custom_scale", minimum=1, maximum=16, step=0.01, value=2, visible=False, interactive=True)
456
+
457
+ gr.HTML("<p style=\"margin-bottom:0.75em\">Redraw options:</p>")
458
+ with gr.Row():
459
+ upscaler_index = gr.Radio(label='Upscaler', elem_id=f"{elem_id_prefix}_upscaler_index", choices=[x.name for x in shared.sd_upscalers],
460
+ value=shared.sd_upscalers[0].name, type="index")
461
+ with gr.Row():
462
+ redraw_mode = gr.Dropdown(label="Type", elem_id=f"{elem_id_prefix}_redraw_mode", choices=[k for k in redrow_modes], type="index", value=next(iter(redrow_modes)))
463
+ tile_width = gr.Slider(elem_id=f"{elem_id_prefix}_tile_width", minimum=0, maximum=2048, step=64, label='Tile width', value=512)
464
+ tile_height = gr.Slider(elem_id=f"{elem_id_prefix}_tile_height", minimum=0, maximum=2048, step=64, label='Tile height', value=0)
465
+ mask_blur = gr.Slider(elem_id=f"{elem_id_prefix}_mask_blur", label='Mask blur', minimum=0, maximum=64, step=1, value=8)
466
+ padding = gr.Slider(elem_id=f"{elem_id_prefix}_padding", label='Padding', minimum=0, maximum=512, step=1, value=32)
467
+ gr.HTML("<p style=\"margin-bottom:0.75em\">Seams fix:</p>")
468
+ with gr.Row():
469
+ seams_fix_type = gr.Dropdown(label="Type", elem_id=f"{elem_id_prefix}_seams_fix_type", choices=[k for k in seams_fix_types], type="index", value=next(iter(seams_fix_types)))
470
+ seams_fix_denoise = gr.Slider(label='Denoise', elem_id=f"{elem_id_prefix}_seams_fix_denoise", minimum=0, maximum=1, step=0.01, value=0.35, visible=False, interactive=True)
471
+ seams_fix_width = gr.Slider(label='Width', elem_id=f"{elem_id_prefix}_seams_fix_width", minimum=0, maximum=128, step=1, value=64, visible=False, interactive=True)
472
+ seams_fix_mask_blur = gr.Slider(label='Mask blur', elem_id=f"{elem_id_prefix}_seams_fix_mask_blur", minimum=0, maximum=64, step=1, value=4, visible=False, interactive=True)
473
+ seams_fix_padding = gr.Slider(label='Padding', elem_id=f"{elem_id_prefix}_seams_fix_padding", minimum=0, maximum=128, step=1, value=16, visible=False, interactive=True)
474
+ gr.HTML("<p style=\"margin-bottom:0.75em\">Save options:</p>")
475
+ with gr.Row():
476
+ save_upscaled_image = gr.Checkbox(label="Upscaled", elem_id=f"{elem_id_prefix}_save_upscaled_image", value=True)
477
+ save_seams_fix_image = gr.Checkbox(label="Seams fix", elem_id=f"{elem_id_prefix}_save_seams_fix_image", value=False)
478
+
479
+ def select_fix_type(fix_index):
480
+ all_visible = fix_index != 0
481
+ mask_blur_visible = fix_index == 2 or fix_index == 3
482
+ width_visible = fix_index == 1
483
+
484
+ return [gr.update(visible=all_visible),
485
+ gr.update(visible=width_visible),
486
+ gr.update(visible=mask_blur_visible),
487
+ gr.update(visible=all_visible)]
488
+
489
+ seams_fix_type.change(
490
+ fn=select_fix_type,
491
+ inputs=seams_fix_type,
492
+ outputs=[seams_fix_denoise, seams_fix_width, seams_fix_mask_blur, seams_fix_padding]
493
+ )
494
+
495
+ def select_scale_type(scale_index):
496
+ is_custom_size = scale_index == 1
497
+ is_custom_scale = scale_index == 2
498
+
499
+ return [gr.update(visible=is_custom_size),
500
+ gr.update(visible=is_custom_size),
501
+ gr.update(visible=is_custom_scale),
502
+ ]
503
+
504
+ target_size_type.change(
505
+ fn=select_scale_type,
506
+ inputs=target_size_type,
507
+ outputs=[custom_width, custom_height, custom_scale]
508
+ )
509
+
510
+ def init_field(scale_name):
511
+ try:
512
+ scale_index = target_size_types.index(scale_name)
513
+ custom_width.visible = custom_height.visible = scale_index == 1
514
+ custom_scale.visible = scale_index == 2
515
+ except:
516
+ pass
517
+
518
+ target_size_type.init_field = init_field
519
+
520
+ return [info, tile_width, tile_height, mask_blur, padding, seams_fix_width, seams_fix_denoise, seams_fix_padding,
521
+ upscaler_index, save_upscaled_image, redraw_mode, save_seams_fix_image, seams_fix_mask_blur,
522
+ seams_fix_type, target_size_type, custom_width, custom_height, custom_scale]
523
+
524
+ def run(self, p, _, tile_width, tile_height, mask_blur, padding, seams_fix_width, seams_fix_denoise, seams_fix_padding,
525
+ upscaler_index, save_upscaled_image, redraw_mode, save_seams_fix_image, seams_fix_mask_blur,
526
+ seams_fix_type, target_size_type, custom_width, custom_height, custom_scale):
527
+
528
+ # Init
529
+ processing.fix_seed(p)
530
+ devices.torch_gc()
531
+
532
+ p.do_not_save_grid = True
533
+ p.do_not_save_samples = True
534
+ p.inpaint_full_res = False
535
+
536
+ p.inpainting_fill = 1
537
+ p.n_iter = 1
538
+ p.batch_size = 1
539
+
540
+ seed = p.seed
541
+
542
+ # Init image
543
+ init_img = p.init_images[0]
544
+ if init_img == None:
545
+ return Processed(p, [], seed, "Empty image")
546
+ init_img = images.flatten(init_img, opts.img2img_background_color)
547
+
548
+ #override size
549
+ if target_size_type == 1:
550
+ p.width = custom_width
551
+ p.height = custom_height
552
+ if target_size_type == 2:
553
+ p.width = math.ceil((init_img.width * custom_scale) / 64) * 64
554
+ p.height = math.ceil((init_img.height * custom_scale) / 64) * 64
555
+
556
+ # Upscaling
557
+ upscaler = USDUpscaler(p, init_img, upscaler_index, save_upscaled_image, save_seams_fix_image, tile_width, tile_height)
558
+ upscaler.upscale()
559
+
560
+ # Drawing
561
+ upscaler.setup_redraw(redraw_mode, padding, mask_blur)
562
+ upscaler.setup_seams_fix(seams_fix_padding, seams_fix_denoise, seams_fix_mask_blur, seams_fix_width, seams_fix_type)
563
+ upscaler.print_info()
564
+ upscaler.add_extra_info()
565
+ upscaler.process()
566
+ result_images = upscaler.result_images
567
+
568
+ return Processed(p, result_images, seed, upscaler.initial_info if upscaler.initial_info is not None else "")
569
+
was-node-suite-comfyui/.github/workflows/publish_action.yml ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Publish to Comfy registry
2
+ on:
3
+ workflow_dispatch:
4
+ push:
5
+ branches:
6
+ - main
7
+ paths:
8
+ - "pyproject.toml"
9
+
10
+ jobs:
11
+ publish-node:
12
+ name: Publish Custom Node to registry
13
+ runs-on: ubuntu-latest
14
+ steps:
15
+ - name: Check out code
16
+ uses: actions/checkout@v4
17
+ - name: Publish Custom Node
18
+ uses: Comfy-Org/publish-node-action@main
19
+ with:
20
+ personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }} ## Add your own personal access token to your Github Repository secrets and reference it here.
was-node-suite-comfyui/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (260 Bytes). View file
 
was-node-suite-comfyui/modules/BLIP/__init__.py ADDED
File without changes
was-node-suite-comfyui/modules/BLIP/blip_med.py ADDED
@@ -0,0 +1,955 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ * Copyright (c) 2022, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ * Based on huggingface code base
8
+ * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
9
+ '''
10
+
11
+ import math
12
+ import os
13
+ import warnings
14
+ from dataclasses import dataclass
15
+ from typing import Optional, Tuple
16
+
17
+ import torch
18
+ from torch import Tensor, device, dtype, nn
19
+ import torch.utils.checkpoint
20
+ from torch import nn
21
+ from torch.nn import CrossEntropyLoss
22
+ import torch.nn.functional as F
23
+
24
+ from transformers.activations import ACT2FN
25
+ from transformers.file_utils import (
26
+ ModelOutput,
27
+ )
28
+ from transformers.modeling_outputs import (
29
+ BaseModelOutputWithPastAndCrossAttentions,
30
+ BaseModelOutputWithPoolingAndCrossAttentions,
31
+ CausalLMOutputWithCrossAttentions,
32
+ MaskedLMOutput,
33
+ MultipleChoiceModelOutput,
34
+ NextSentencePredictorOutput,
35
+ QuestionAnsweringModelOutput,
36
+ SequenceClassifierOutput,
37
+ TokenClassifierOutput,
38
+ )
39
+ from transformers.modeling_utils import (
40
+ PreTrainedModel,
41
+ apply_chunking_to_forward,
42
+ find_pruneable_heads_and_indices,
43
+ prune_linear_layer,
44
+ )
45
+ from transformers.utils import logging
46
+ from transformers.models.bert.configuration_bert import BertConfig
47
+
48
+
49
+ logger = logging.get_logger(__name__)
50
+
51
+
52
+ class BertEmbeddings(nn.Module):
53
+ """Construct the embeddings from word and position embeddings."""
54
+
55
+ def __init__(self, config):
56
+ super().__init__()
57
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
58
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
59
+
60
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
61
+ # any TensorFlow checkpoint file
62
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
63
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
64
+
65
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
66
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
67
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
68
+
69
+ self.config = config
70
+
71
+ def forward(
72
+ self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
73
+ ):
74
+ if input_ids is not None:
75
+ input_shape = input_ids.size()
76
+ else:
77
+ input_shape = inputs_embeds.size()[:-1]
78
+
79
+ seq_length = input_shape[1]
80
+
81
+ if position_ids is None:
82
+ position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
83
+
84
+ if inputs_embeds is None:
85
+ inputs_embeds = self.word_embeddings(input_ids)
86
+
87
+ embeddings = inputs_embeds
88
+
89
+ if self.position_embedding_type == "absolute":
90
+ position_embeddings = self.position_embeddings(position_ids)
91
+ embeddings += position_embeddings
92
+ embeddings = self.LayerNorm(embeddings)
93
+ embeddings = self.dropout(embeddings)
94
+ return embeddings
95
+
96
+
97
+ class BertSelfAttention(nn.Module):
98
+ def __init__(self, config, is_cross_attention):
99
+ super().__init__()
100
+ self.config = config
101
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
102
+ raise ValueError(
103
+ "The hidden size (%d) is not a multiple of the number of attention "
104
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads)
105
+ )
106
+
107
+ self.num_attention_heads = config.num_attention_heads
108
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
109
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
110
+
111
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
112
+ if is_cross_attention:
113
+ self.key = nn.Linear(config.encoder_width, self.all_head_size)
114
+ self.value = nn.Linear(config.encoder_width, self.all_head_size)
115
+ else:
116
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
117
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
118
+
119
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
120
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
121
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
122
+ self.max_position_embeddings = config.max_position_embeddings
123
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
124
+ self.save_attention = False
125
+
126
+ def save_attn_gradients(self, attn_gradients):
127
+ self.attn_gradients = attn_gradients
128
+
129
+ def get_attn_gradients(self):
130
+ return self.attn_gradients
131
+
132
+ def save_attention_map(self, attention_map):
133
+ self.attention_map = attention_map
134
+
135
+ def get_attention_map(self):
136
+ return self.attention_map
137
+
138
+ def transpose_for_scores(self, x):
139
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
140
+ x = x.view(*new_x_shape)
141
+ return x.permute(0, 2, 1, 3)
142
+
143
+ def forward(
144
+ self,
145
+ hidden_states,
146
+ attention_mask=None,
147
+ head_mask=None,
148
+ encoder_hidden_states=None,
149
+ encoder_attention_mask=None,
150
+ past_key_value=None,
151
+ output_attentions=False,
152
+ ):
153
+ mixed_query_layer = self.query(hidden_states)
154
+
155
+ # If this is instantiated as a cross-attention module, the keys
156
+ # and values come from an encoder; the attention mask needs to be
157
+ # such that the encoder's padding tokens are not attended to.
158
+ is_cross_attention = encoder_hidden_states is not None
159
+
160
+ if is_cross_attention:
161
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
162
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
163
+ attention_mask = encoder_attention_mask
164
+ elif past_key_value is not None:
165
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
166
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
167
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
168
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
169
+ else:
170
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
171
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
172
+
173
+ query_layer = self.transpose_for_scores(mixed_query_layer)
174
+
175
+ past_key_value = (key_layer, value_layer)
176
+
177
+ # Take the dot product between "query" and "key" to get the raw attention scores.
178
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
179
+
180
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
181
+ seq_length = hidden_states.size()[1]
182
+ position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
183
+ position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
184
+ distance = position_ids_l - position_ids_r
185
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
186
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
187
+
188
+ if self.position_embedding_type == "relative_key":
189
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
190
+ attention_scores = attention_scores + relative_position_scores
191
+ elif self.position_embedding_type == "relative_key_query":
192
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
193
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
194
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
195
+
196
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
197
+ if attention_mask is not None:
198
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
199
+ attention_scores = attention_scores + attention_mask
200
+
201
+ # Normalize the attention scores to probabilities.
202
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
203
+
204
+ if is_cross_attention and self.save_attention:
205
+ self.save_attention_map(attention_probs)
206
+ attention_probs.register_hook(self.save_attn_gradients)
207
+
208
+ # This is actually dropping out entire tokens to attend to, which might
209
+ # seem a bit unusual, but is taken from the original Transformer paper.
210
+ attention_probs_dropped = self.dropout(attention_probs)
211
+
212
+ # Mask heads if we want to
213
+ if head_mask is not None:
214
+ attention_probs_dropped = attention_probs_dropped * head_mask
215
+
216
+ context_layer = torch.matmul(attention_probs_dropped, value_layer)
217
+
218
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
219
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
220
+ context_layer = context_layer.view(*new_context_layer_shape)
221
+
222
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
223
+
224
+ outputs = outputs + (past_key_value,)
225
+ return outputs
226
+
227
+
228
+ class BertSelfOutput(nn.Module):
229
+ def __init__(self, config):
230
+ super().__init__()
231
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
232
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
233
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
234
+
235
+ def forward(self, hidden_states, input_tensor):
236
+ hidden_states = self.dense(hidden_states)
237
+ hidden_states = self.dropout(hidden_states)
238
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
239
+ return hidden_states
240
+
241
+
242
+ class BertAttention(nn.Module):
243
+ def __init__(self, config, is_cross_attention=False):
244
+ super().__init__()
245
+ self.self = BertSelfAttention(config, is_cross_attention)
246
+ self.output = BertSelfOutput(config)
247
+ self.pruned_heads = set()
248
+
249
+ def prune_heads(self, heads):
250
+ if len(heads) == 0:
251
+ return
252
+ heads, index = find_pruneable_heads_and_indices(
253
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
254
+ )
255
+
256
+ # Prune linear layers
257
+ self.self.query = prune_linear_layer(self.self.query, index)
258
+ self.self.key = prune_linear_layer(self.self.key, index)
259
+ self.self.value = prune_linear_layer(self.self.value, index)
260
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
261
+
262
+ # Update hyper params and store pruned heads
263
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
264
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
265
+ self.pruned_heads = self.pruned_heads.union(heads)
266
+
267
+ def forward(
268
+ self,
269
+ hidden_states,
270
+ attention_mask=None,
271
+ head_mask=None,
272
+ encoder_hidden_states=None,
273
+ encoder_attention_mask=None,
274
+ past_key_value=None,
275
+ output_attentions=False,
276
+ ):
277
+ self_outputs = self.self(
278
+ hidden_states,
279
+ attention_mask,
280
+ head_mask,
281
+ encoder_hidden_states,
282
+ encoder_attention_mask,
283
+ past_key_value,
284
+ output_attentions,
285
+ )
286
+ attention_output = self.output(self_outputs[0], hidden_states)
287
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
288
+ return outputs
289
+
290
+
291
+ class BertIntermediate(nn.Module):
292
+ def __init__(self, config):
293
+ super().__init__()
294
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
295
+ if isinstance(config.hidden_act, str):
296
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
297
+ else:
298
+ self.intermediate_act_fn = config.hidden_act
299
+
300
+ def forward(self, hidden_states):
301
+ hidden_states = self.dense(hidden_states)
302
+ hidden_states = self.intermediate_act_fn(hidden_states)
303
+ return hidden_states
304
+
305
+
306
+ class BertOutput(nn.Module):
307
+ def __init__(self, config):
308
+ super().__init__()
309
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
310
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
311
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
312
+
313
+ def forward(self, hidden_states, input_tensor):
314
+ hidden_states = self.dense(hidden_states)
315
+ hidden_states = self.dropout(hidden_states)
316
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
317
+ return hidden_states
318
+
319
+
320
+ class BertLayer(nn.Module):
321
+ def __init__(self, config, layer_num):
322
+ super().__init__()
323
+ self.config = config
324
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
325
+ self.seq_len_dim = 1
326
+ self.attention = BertAttention(config)
327
+ self.layer_num = layer_num
328
+ if self.config.add_cross_attention:
329
+ self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention)
330
+ self.intermediate = BertIntermediate(config)
331
+ self.output = BertOutput(config)
332
+
333
+ def forward(
334
+ self,
335
+ hidden_states,
336
+ attention_mask=None,
337
+ head_mask=None,
338
+ encoder_hidden_states=None,
339
+ encoder_attention_mask=None,
340
+ past_key_value=None,
341
+ output_attentions=False,
342
+ mode=None,
343
+ ):
344
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
345
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
346
+ self_attention_outputs = self.attention(
347
+ hidden_states,
348
+ attention_mask,
349
+ head_mask,
350
+ output_attentions=output_attentions,
351
+ past_key_value=self_attn_past_key_value,
352
+ )
353
+ attention_output = self_attention_outputs[0]
354
+
355
+ outputs = self_attention_outputs[1:-1]
356
+ present_key_value = self_attention_outputs[-1]
357
+
358
+ if mode=='multimodal':
359
+ assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
360
+
361
+ cross_attention_outputs = self.crossattention(
362
+ attention_output,
363
+ attention_mask,
364
+ head_mask,
365
+ encoder_hidden_states,
366
+ encoder_attention_mask,
367
+ output_attentions=output_attentions,
368
+ )
369
+ attention_output = cross_attention_outputs[0]
370
+ outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
371
+ layer_output = apply_chunking_to_forward(
372
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
373
+ )
374
+ outputs = (layer_output,) + outputs
375
+
376
+ outputs = outputs + (present_key_value,)
377
+
378
+ return outputs
379
+
380
+ def feed_forward_chunk(self, attention_output):
381
+ intermediate_output = self.intermediate(attention_output)
382
+ layer_output = self.output(intermediate_output, attention_output)
383
+ return layer_output
384
+
385
+
386
+ class BertEncoder(nn.Module):
387
+ def __init__(self, config):
388
+ super().__init__()
389
+ self.config = config
390
+ self.layer = nn.ModuleList([BertLayer(config,i) for i in range(config.num_hidden_layers)])
391
+ self.gradient_checkpointing = False
392
+
393
+ def forward(
394
+ self,
395
+ hidden_states,
396
+ attention_mask=None,
397
+ head_mask=None,
398
+ encoder_hidden_states=None,
399
+ encoder_attention_mask=None,
400
+ past_key_values=None,
401
+ use_cache=None,
402
+ output_attentions=False,
403
+ output_hidden_states=False,
404
+ return_dict=True,
405
+ mode='multimodal',
406
+ ):
407
+ all_hidden_states = () if output_hidden_states else None
408
+ all_self_attentions = () if output_attentions else None
409
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
410
+
411
+ next_decoder_cache = () if use_cache else None
412
+
413
+ for i in range(self.config.num_hidden_layers):
414
+ layer_module = self.layer[i]
415
+ if output_hidden_states:
416
+ all_hidden_states = all_hidden_states + (hidden_states,)
417
+
418
+ layer_head_mask = head_mask[i] if head_mask is not None else None
419
+ past_key_value = past_key_values[i] if past_key_values is not None else None
420
+
421
+ if self.gradient_checkpointing and self.training:
422
+
423
+ if use_cache:
424
+ logger.warn(
425
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
426
+ )
427
+ use_cache = False
428
+
429
+ def create_custom_forward(module):
430
+ def custom_forward(*inputs):
431
+ return module(*inputs, past_key_value, output_attentions)
432
+
433
+ return custom_forward
434
+
435
+ layer_outputs = torch.utils.checkpoint.checkpoint(
436
+ create_custom_forward(layer_module),
437
+ hidden_states,
438
+ attention_mask,
439
+ layer_head_mask,
440
+ encoder_hidden_states,
441
+ encoder_attention_mask,
442
+ mode=mode,
443
+ )
444
+ else:
445
+ layer_outputs = layer_module(
446
+ hidden_states,
447
+ attention_mask,
448
+ layer_head_mask,
449
+ encoder_hidden_states,
450
+ encoder_attention_mask,
451
+ past_key_value,
452
+ output_attentions,
453
+ mode=mode,
454
+ )
455
+
456
+ hidden_states = layer_outputs[0]
457
+ if use_cache:
458
+ next_decoder_cache += (layer_outputs[-1],)
459
+ if output_attentions:
460
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
461
+
462
+ if output_hidden_states:
463
+ all_hidden_states = all_hidden_states + (hidden_states,)
464
+
465
+ if not return_dict:
466
+ return tuple(
467
+ v
468
+ for v in [
469
+ hidden_states,
470
+ next_decoder_cache,
471
+ all_hidden_states,
472
+ all_self_attentions,
473
+ all_cross_attentions,
474
+ ]
475
+ if v is not None
476
+ )
477
+ return BaseModelOutputWithPastAndCrossAttentions(
478
+ last_hidden_state=hidden_states,
479
+ past_key_values=next_decoder_cache,
480
+ hidden_states=all_hidden_states,
481
+ attentions=all_self_attentions,
482
+ cross_attentions=all_cross_attentions,
483
+ )
484
+
485
+
486
+ class BertPooler(nn.Module):
487
+ def __init__(self, config):
488
+ super().__init__()
489
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
490
+ self.activation = nn.Tanh()
491
+
492
+ def forward(self, hidden_states):
493
+ # We "pool" the model by simply taking the hidden state corresponding
494
+ # to the first token.
495
+ first_token_tensor = hidden_states[:, 0]
496
+ pooled_output = self.dense(first_token_tensor)
497
+ pooled_output = self.activation(pooled_output)
498
+ return pooled_output
499
+
500
+
501
+ class BertPredictionHeadTransform(nn.Module):
502
+ def __init__(self, config):
503
+ super().__init__()
504
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
505
+ if isinstance(config.hidden_act, str):
506
+ self.transform_act_fn = ACT2FN[config.hidden_act]
507
+ else:
508
+ self.transform_act_fn = config.hidden_act
509
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
510
+
511
+ def forward(self, hidden_states):
512
+ hidden_states = self.dense(hidden_states)
513
+ hidden_states = self.transform_act_fn(hidden_states)
514
+ hidden_states = self.LayerNorm(hidden_states)
515
+ return hidden_states
516
+
517
+
518
+ class BertLMPredictionHead(nn.Module):
519
+ def __init__(self, config):
520
+ super().__init__()
521
+ self.transform = BertPredictionHeadTransform(config)
522
+
523
+ # The output weights are the same as the input embeddings, but there is
524
+ # an output-only bias for each token.
525
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
526
+
527
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
528
+
529
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
530
+ self.decoder.bias = self.bias
531
+
532
+ def forward(self, hidden_states):
533
+ hidden_states = self.transform(hidden_states)
534
+ hidden_states = self.decoder(hidden_states)
535
+ return hidden_states
536
+
537
+
538
+ class BertOnlyMLMHead(nn.Module):
539
+ def __init__(self, config):
540
+ super().__init__()
541
+ self.predictions = BertLMPredictionHead(config)
542
+
543
+ def forward(self, sequence_output):
544
+ prediction_scores = self.predictions(sequence_output)
545
+ return prediction_scores
546
+
547
+
548
+ class BertPreTrainedModel(PreTrainedModel):
549
+ """
550
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
551
+ models.
552
+ """
553
+
554
+ config_class = BertConfig
555
+ base_model_prefix = "bert"
556
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
557
+
558
+ def _init_weights(self, module):
559
+ """ Initialize the weights """
560
+ if isinstance(module, (nn.Linear, nn.Embedding)):
561
+ # Slightly different from the TF version which uses truncated_normal for initialization
562
+ # cf https://github.com/pytorch/pytorch/pull/5617
563
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
564
+ elif isinstance(module, nn.LayerNorm):
565
+ module.bias.data.zero_()
566
+ module.weight.data.fill_(1.0)
567
+ if isinstance(module, nn.Linear) and module.bias is not None:
568
+ module.bias.data.zero_()
569
+
570
+
571
+ class BertModel(BertPreTrainedModel):
572
+ """
573
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
574
+ cross-attention is added between the self-attention layers, following the architecture described in `Attention is
575
+ all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
576
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
577
+ argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
578
+ input to the forward pass.
579
+ """
580
+
581
+ def __init__(self, config, add_pooling_layer=True):
582
+ super().__init__(config)
583
+ self.config = config
584
+
585
+ self.embeddings = BertEmbeddings(config)
586
+
587
+ self.encoder = BertEncoder(config)
588
+
589
+ self.pooler = BertPooler(config) if add_pooling_layer else None
590
+
591
+ self.init_weights()
592
+
593
+
594
+ def get_input_embeddings(self):
595
+ return self.embeddings.word_embeddings
596
+
597
+ def set_input_embeddings(self, value):
598
+ self.embeddings.word_embeddings = value
599
+
600
+ def _prune_heads(self, heads_to_prune):
601
+ """
602
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
603
+ class PreTrainedModel
604
+ """
605
+ for layer, heads in heads_to_prune.items():
606
+ self.encoder.layer[layer].attention.prune_heads(heads)
607
+
608
+
609
+ def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor:
610
+ """
611
+ Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
612
+
613
+ Arguments:
614
+ attention_mask (:obj:`torch.Tensor`):
615
+ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
616
+ input_shape (:obj:`Tuple[int]`):
617
+ The shape of the input to the model.
618
+ device: (:obj:`torch.device`):
619
+ The device of the input to the model.
620
+
621
+ Returns:
622
+ :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
623
+ """
624
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
625
+ # ourselves in which case we just need to make it broadcastable to all heads.
626
+ if attention_mask.dim() == 3:
627
+ extended_attention_mask = attention_mask[:, None, :, :]
628
+ elif attention_mask.dim() == 2:
629
+ # Provided a padding mask of dimensions [batch_size, seq_length]
630
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
631
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
632
+ if is_decoder:
633
+ batch_size, seq_length = input_shape
634
+
635
+ seq_ids = torch.arange(seq_length, device=device)
636
+ causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
637
+ # in case past_key_values are used we need to add a prefix ones mask to the causal mask
638
+ # causal and attention masks must have same type with pytorch version < 1.3
639
+ causal_mask = causal_mask.to(attention_mask.dtype)
640
+
641
+ if causal_mask.shape[1] < attention_mask.shape[1]:
642
+ prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
643
+ causal_mask = torch.cat(
644
+ [
645
+ torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
646
+ causal_mask,
647
+ ],
648
+ axis=-1,
649
+ )
650
+
651
+ extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
652
+ else:
653
+ extended_attention_mask = attention_mask[:, None, None, :]
654
+ else:
655
+ raise ValueError(
656
+ "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
657
+ input_shape, attention_mask.shape
658
+ )
659
+ )
660
+
661
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
662
+ # masked positions, this operation will create a tensor which is 0.0 for
663
+ # positions we want to attend and -10000.0 for masked positions.
664
+ # Since we are adding it to the raw scores before the softmax, this is
665
+ # effectively the same as removing these entirely.
666
+ extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
667
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
668
+ return extended_attention_mask
669
+
670
+ def forward(
671
+ self,
672
+ input_ids=None,
673
+ attention_mask=None,
674
+ position_ids=None,
675
+ head_mask=None,
676
+ inputs_embeds=None,
677
+ encoder_embeds=None,
678
+ encoder_hidden_states=None,
679
+ encoder_attention_mask=None,
680
+ past_key_values=None,
681
+ use_cache=None,
682
+ output_attentions=None,
683
+ output_hidden_states=None,
684
+ return_dict=None,
685
+ is_decoder=False,
686
+ mode='multimodal',
687
+ ):
688
+ r"""
689
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
690
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
691
+ the model is configured as a decoder.
692
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
693
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
694
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
695
+ - 1 for tokens that are **not masked**,
696
+ - 0 for tokens that are **masked**.
697
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
698
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
699
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
700
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
701
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
702
+ use_cache (:obj:`bool`, `optional`):
703
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
704
+ decoding (see :obj:`past_key_values`).
705
+ """
706
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
707
+ output_hidden_states = (
708
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
709
+ )
710
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
711
+
712
+ if is_decoder:
713
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
714
+ else:
715
+ use_cache = False
716
+
717
+ if input_ids is not None and inputs_embeds is not None:
718
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
719
+ elif input_ids is not None:
720
+ input_shape = input_ids.size()
721
+ batch_size, seq_length = input_shape
722
+ device = input_ids.device
723
+ elif inputs_embeds is not None:
724
+ input_shape = inputs_embeds.size()[:-1]
725
+ batch_size, seq_length = input_shape
726
+ device = inputs_embeds.device
727
+ elif encoder_embeds is not None:
728
+ input_shape = encoder_embeds.size()[:-1]
729
+ batch_size, seq_length = input_shape
730
+ device = encoder_embeds.device
731
+ else:
732
+ raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds")
733
+
734
+ # past_key_values_length
735
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
736
+
737
+ if attention_mask is None:
738
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
739
+
740
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
741
+ # ourselves in which case we just need to make it broadcastable to all heads.
742
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape,
743
+ device, is_decoder)
744
+
745
+ # If a 2D or 3D attention mask is provided for the cross-attention
746
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
747
+ if encoder_hidden_states is not None:
748
+ if type(encoder_hidden_states) == list:
749
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
750
+ else:
751
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
752
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
753
+
754
+ if type(encoder_attention_mask) == list:
755
+ encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
756
+ elif encoder_attention_mask is None:
757
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
758
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
759
+ else:
760
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
761
+ else:
762
+ encoder_extended_attention_mask = None
763
+
764
+ # Prepare head mask if needed
765
+ # 1.0 in head_mask indicate we keep the head
766
+ # attention_probs has shape bsz x n_heads x N x N
767
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
768
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
769
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
770
+
771
+ if encoder_embeds is None:
772
+ embedding_output = self.embeddings(
773
+ input_ids=input_ids,
774
+ position_ids=position_ids,
775
+ inputs_embeds=inputs_embeds,
776
+ past_key_values_length=past_key_values_length,
777
+ )
778
+ else:
779
+ embedding_output = encoder_embeds
780
+
781
+ encoder_outputs = self.encoder(
782
+ embedding_output,
783
+ attention_mask=extended_attention_mask,
784
+ head_mask=head_mask,
785
+ encoder_hidden_states=encoder_hidden_states,
786
+ encoder_attention_mask=encoder_extended_attention_mask,
787
+ past_key_values=past_key_values,
788
+ use_cache=use_cache,
789
+ output_attentions=output_attentions,
790
+ output_hidden_states=output_hidden_states,
791
+ return_dict=return_dict,
792
+ mode=mode,
793
+ )
794
+ sequence_output = encoder_outputs[0]
795
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
796
+
797
+ if not return_dict:
798
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
799
+
800
+ return BaseModelOutputWithPoolingAndCrossAttentions(
801
+ last_hidden_state=sequence_output,
802
+ pooler_output=pooled_output,
803
+ past_key_values=encoder_outputs.past_key_values,
804
+ hidden_states=encoder_outputs.hidden_states,
805
+ attentions=encoder_outputs.attentions,
806
+ cross_attentions=encoder_outputs.cross_attentions,
807
+ )
808
+
809
+
810
+
811
class BertLMHeadModel(BertPreTrainedModel):
    """BERT with a language-modeling head, run as a causal (left-to-right)
    decoder. BLIP uses it as the text decoder for captioning and VQA answers,
    cross-attending to image features via ``encoder_hidden_states``."""

    # The pooler is unused for LM; position ids / decoder bias are rebuilt on load.
    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]

    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config, add_pooling_layer=False)
        self.cls = BertOnlyMLMHead(config)  # vocabulary prediction head

        self.init_weights()

    def get_output_embeddings(self):
        # Exposed so HF generation utilities / weight tying can locate the LM matrix.
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        self.cls.predictions.decoder = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        return_logits=False,
        is_decoder=True,
        reduction='mean',
        mode='multimodal',
    ):
        r"""
        encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
            ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
            ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
        use_cache (:obj:`bool`, `optional`):
            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
            decoding (see :obj:`past_key_values`).
        Returns:
        Example::
            >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
            >>> import torch
            >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
            >>> config = BertConfig.from_pretrained("bert-base-cased")
            >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
            >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
            >>> outputs = model(**inputs)
            >>> prediction_logits = outputs.logits
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if labels is not None:
            # Training with labels is incompatible with incremental decoding caches.
            use_cache = False

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            is_decoder=is_decoder,
            mode=mode,
        )

        sequence_output = outputs[0]
        prediction_scores = self.cls(sequence_output)

        if return_logits:
            # Shortcut for callers that only need logits aligned with next-token targets
            # (last position dropped to match the shifted labels below).
            return prediction_scores[:, :-1, :].contiguous()

        lm_loss = None
        if labels is not None:
            # we are doing next-token prediction; shift prediction scores and input ids by one
            shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
            labels = labels[:, 1:].contiguous()
            # Label smoothing of 0.1 is baked in here (BLIP training recipe).
            loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
            lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
            if reduction=='none':
                # Collapse per-token losses into one loss per sequence.
                lm_loss = lm_loss.view(prediction_scores.size(0),-1).sum(1)

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((lm_loss,) + output) if lm_loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=lm_loss,
            logits=prediction_scores,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )

    def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
        # Called by HF `generate()` before each decoding step.
        input_shape = input_ids.shape
        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_shape)

        # cut decoder_input_ids if past is used
        if past is not None:
            input_ids = input_ids[:, -1:]

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "past_key_values": past,
            "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
            "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
            "is_decoder": True,
        }

    def _reorder_cache(self, past, beam_idx):
        # Beam search bookkeeping: permute every cached key/value tensor along
        # the batch dimension to follow the selected beams.
        reordered_past = ()
        for layer_past in past:
            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
        return reordered_past
was-node-suite-comfyui/modules/BLIP/blip_module.py ADDED
@@ -0,0 +1,423 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ * Copyright (c) 2022, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ '''
8
+ import warnings
9
+ warnings.filterwarnings("ignore")
10
+
11
+ from .blip_vit import VisionTransformer, interpolate_pos_embed
12
+ from .blip_med import BertConfig, BertModel, BertLMHeadModel
13
+ from transformers import BertTokenizer
14
+
15
+ import torch
16
+ from torch import nn
17
+ import torch.nn.functional as F
18
+
19
+ import os
20
+ from urllib.parse import urlparse
21
+ from timm.models.hub import download_cached_file
22
+ import numpy as np
23
+
24
+ from pathlib import Path
25
+ LOCAL_PATH = os.path.dirname(os.path.abspath(__file__))
26
+
27
+ # BLIP
28
+
29
class BLIP_Base(nn.Module):
    """BLIP feature extractor: ViT image encoder + BERT text encoder that can
    return unimodal (image or text) or fused multimodal features."""

    def __init__(self,
                 med_config = Path(LOCAL_PATH, 'blip_configs/med_config.json'),
                 image_size = 224,
                 vit = 'base',
                 vit_grad_ckpt = False,
                 vit_ckpt_layer = 0,
                 ):
        """
        Args:
            med_config (str): path for the mixture of encoder-decoder model's configuration file
            image_size (int): input image size
            vit (str): model size of vision transformer ('base' or 'large')
        """
        super().__init__()

        self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
        self.tokenizer = init_tokenizer()
        med_config = BertConfig.from_json_file(med_config)
        # Cross-attention width must match the vision transformer's embed dim.
        med_config.encoder_width = vision_width
        self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)


    def forward(self, image, caption, mode):
        """Return features for `image`, `caption`, or the fused pair,
        depending on `mode` ('image' | 'text' | 'multimodal')."""

        assert mode in ['image', 'text', 'multimodal'], "mode parameter must be image, text, or multimodal"
        text = self.tokenizer(caption, return_tensors="pt").to(image.device)

        if mode=='image':
            # return image features
            image_embeds = self.visual_encoder(image)
            return image_embeds

        elif mode=='text':
            # return text features (text-only path, no cross-attention)
            text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
                                            return_dict = True, mode = 'text')
            return text_output.last_hidden_state

        elif mode=='multimodal':
            # return multimodel features: text encoder cross-attends to image patches
            image_embeds = self.visual_encoder(image)
            image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)

            # Replace [CLS] with the [ENC] token so the encoder runs in fusion mode.
            text.input_ids[:,0] = self.tokenizer.enc_token_id
            output = self.text_encoder(text.input_ids,
                                       attention_mask = text.attention_mask,
                                       encoder_hidden_states = image_embeds,
                                       encoder_attention_mask = image_atts,
                                       return_dict = True,
                                      )
            return output.last_hidden_state
81
+
82
+
83
+
84
class BLIP_Decoder(nn.Module):
    """BLIP captioner: ViT image encoder + BertLMHeadModel text decoder.
    `forward` computes the captioning LM loss; `generate` produces captions."""

    def __init__(self,
                 med_config = Path(LOCAL_PATH, 'blip_configs/med_config.json'),
                 image_size = 384,
                 vit = 'base',
                 vit_grad_ckpt = False,
                 vit_ckpt_layer = 0,
                 prompt = 'a picture of ',
                 ):
        """
        Args:
            med_config (str): path for the mixture of encoder-decoder model's configuration file
            image_size (int): input image size
            vit (str): model size of vision transformer ('base' or 'large')
        """
        super().__init__()

        self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
        self.tokenizer = init_tokenizer()
        med_config = BertConfig.from_json_file(med_config)
        med_config.encoder_width = vision_width
        self.text_decoder = BertLMHeadModel(config=med_config)

        self.prompt = prompt
        # Number of prompt tokens (excluding the final token position) whose
        # loss is masked out during training and stripped from generations.
        self.prompt_length = len(self.tokenizer(self.prompt).input_ids)-1


    def forward(self, image, caption):
        """Return the captioning LM loss for an (image, caption) batch."""

        image_embeds = self.visual_encoder(image)
        image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)

        text = self.tokenizer(caption, padding='longest', truncation=True, max_length=40, return_tensors="pt").to(image.device)

        # Replace [CLS] with the decoder's beginning-of-sequence token [DEC].
        text.input_ids[:,0] = self.tokenizer.bos_token_id

        # Ignore padding and the prompt tokens in the loss (-100 = ignore_index).
        decoder_targets = text.input_ids.masked_fill(text.input_ids == self.tokenizer.pad_token_id, -100)
        decoder_targets[:,:self.prompt_length] = -100

        decoder_output = self.text_decoder(text.input_ids,
                                           attention_mask = text.attention_mask,
                                           encoder_hidden_states = image_embeds,
                                           encoder_attention_mask = image_atts,
                                           labels = decoder_targets,
                                           return_dict = True,
                                          )
        loss_lm = decoder_output.loss

        return loss_lm

    def generate(self, image, sample=False, num_beams=3, max_length=30, min_length=10, top_p=0.9, repetition_penalty=1.0):
        """Generate one caption per image, via beam search (default) or
        nucleus sampling (`sample=True`). Returns a list of strings with the
        prompt prefix stripped."""
        image_embeds = self.visual_encoder(image)

        if not sample:
            # Beam search: HF generate expects encoder states replicated per beam.
            image_embeds = image_embeds.repeat_interleave(num_beams,dim=0)

        image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
        model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask":image_atts}

        prompt = [self.prompt] * image.size(0)
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(image.device)
        input_ids[:,0] = self.tokenizer.bos_token_id
        # Drop the tokenizer's trailing [SEP] so decoding continues from the prompt.
        input_ids = input_ids[:, :-1]

        if sample:
            #nucleus sampling
            # NOTE(review): this path hard-codes repetition_penalty=1.1 and
            # ignores the `repetition_penalty` argument — confirm if intended.
            outputs = self.text_decoder.generate(input_ids=input_ids,
                                                  max_length=max_length,
                                                  min_length=min_length,
                                                  do_sample=True,
                                                  top_p=top_p,
                                                  num_return_sequences=1,
                                                  eos_token_id=self.tokenizer.sep_token_id,
                                                  pad_token_id=self.tokenizer.pad_token_id,
                                                  repetition_penalty=1.1,
                                                  **model_kwargs)
        else:
            #beam search
            outputs = self.text_decoder.generate(input_ids=input_ids,
                                                  max_length=max_length,
                                                  min_length=min_length,
                                                  num_beams=num_beams,
                                                  eos_token_id=self.tokenizer.sep_token_id,
                                                  pad_token_id=self.tokenizer.pad_token_id,
                                                  repetition_penalty=repetition_penalty,
                                                  **model_kwargs)

        captions = []
        for output in outputs:
            caption = self.tokenizer.decode(output, skip_special_tokens=True)
            # Strip the fixed prompt prefix from the decoded text.
            captions.append(caption[len(self.prompt):])
        return captions
176
+
177
+
178
def blip_decoder(pretrained='', **kwargs):
    """Build a BLIP_Decoder; when `pretrained` is a URL/path, load its
    checkpoint and require that no expected keys were missing."""
    model = BLIP_Decoder(**kwargs)
    if not pretrained:
        return model
    model, msg = load_checkpoint(model, pretrained)
    assert len(msg.missing_keys) == 0
    return model
184
+
185
def blip_feature_extractor(pretrained='', **kwargs):
    """Build a BLIP_Base feature extractor; when `pretrained` is a URL/path,
    load its checkpoint and require that no expected keys were missing."""
    model = BLIP_Base(**kwargs)
    if not pretrained:
        return model
    model, msg = load_checkpoint(model, pretrained)
    assert len(msg.missing_keys) == 0
    return model
191
+
192
def init_tokenizer():
    """Return a bert-base-uncased tokenizer extended with BLIP's special
    tokens: [DEC] as bos_token and [ENC] as an additional special token."""
    tok = BertTokenizer.from_pretrained('bert-base-uncased')
    for extra in ({'bos_token': '[DEC]'},
                  {'additional_special_tokens': ['[ENC]']}):
        tok.add_special_tokens(extra)
    # Cache the [ENC] id so callers can overwrite input_ids[:, 0] with it.
    tok.enc_token_id = tok.additional_special_tokens_ids[0]
    return tok
198
+
199
+
200
def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):
    """Build the BLIP vision transformer.

    Args:
        vit: model size, 'base' (width 768, depth 12) or 'large' (width 1024, depth 24).
        image_size: input image resolution.
        use_grad_checkpointing: enable activation checkpointing on the last
            `ckpt_layer` blocks.
        ckpt_layer: see above.
        drop_path_rate: stochastic-depth rate; 0 keeps base at 0 and large at
            its 0.1 default (backward compatible).

    Returns:
        (visual_encoder, vision_width) tuple.

    Bug fix: the original wrote ``drop_path_rate=0.1 or drop_path_rate`` for
    'large', which always evaluates to 0.1 and silently discarded the caller's
    `drop_path_rate`; the operands are now ordered so the argument wins.
    """
    assert vit in ['base', 'large'], "vit parameter must be base or large"
    if vit == 'base':
        vision_width = 768
        visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12,
                                           num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
                                           # `0 or x` == x: base defaults to no stochastic depth.
                                           drop_path_rate=drop_path_rate or 0
                                           )
    elif vit == 'large':
        vision_width = 1024
        visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24,
                                           num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
                                           # Honor an explicit rate; fall back to the 0.1 default for large.
                                           drop_path_rate=drop_path_rate or 0.1
                                           )
    return visual_encoder, vision_width
216
+
217
def is_url(url_or_filename):
    """Return True when the string parses as an http or https URL."""
    scheme = urlparse(url_or_filename).scheme
    return scheme in ("http", "https")
220
+
221
def load_checkpoint(model,url_or_filename):
    """Load BLIP weights from a URL or a local file into `model`.

    Positional embeddings are re-interpolated to the model's resolution and
    shape-mismatched tensors are dropped, so checkpoints trained at a
    different image size still load. Returns (model, msg) where msg is the
    `load_state_dict(..., strict=False)` result (missing/unexpected keys).

    Raises:
        RuntimeError: if `url_or_filename` is neither a URL nor an existing file.
    """
    if is_url(url_or_filename):
        cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
        # NOTE(review): torch.load unpickles arbitrary data — only load trusted checkpoints.
        checkpoint = torch.load(cached_file, map_location='cpu')
    elif os.path.isfile(url_or_filename):
        checkpoint = torch.load(url_or_filename, map_location='cpu')
    else:
        raise RuntimeError('checkpoint url or path is invalid')

    state_dict = checkpoint['model']

    # Resize the ViT positional embedding grid to match this model's image size.
    state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)
    # Momentum-encoder variant (only present in some BLIP models).
    if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
        state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
                                                                         model.visual_encoder_m)
    # Drop any checkpoint tensor whose shape no longer matches the model,
    # leaving those parameters at their freshly initialized values.
    for key in model.state_dict().keys():
        if key in state_dict.keys():
            if state_dict[key].shape!=model.state_dict()[key].shape:
                del state_dict[key]

    msg = model.load_state_dict(state_dict,strict=False)
    print('load checkpoint from %s'%url_or_filename)
    return model,msg
244
+
245
+ # BLIP VQA
246
+
247
class BLIP_VQA(nn.Module):
    """BLIP visual question answering model: ViT image encoder, BERT question
    encoder (cross-attending to image patches), and BertLMHeadModel answer
    decoder. Supports training, free-form generation, and answer ranking."""

    def __init__(self,
                 med_config = Path(LOCAL_PATH, 'blip_configs/med_config.json'),
                 image_size = 480,
                 vit = 'base',
                 vit_grad_ckpt = False,
                 vit_ckpt_layer = 0,
                 ):
        """
        Args:
            med_config (str): path for the mixture of encoder-decoder model's configuration file
            image_size (int): input image size
            vit (str): model size of vision transformer ('base' or 'large')
        """
        super().__init__()

        self.visual_encoder, vision_width = create_vit(vit, image_size, vit_grad_ckpt, vit_ckpt_layer, drop_path_rate=0.1)
        self.tokenizer = init_tokenizer()

        encoder_config = BertConfig.from_json_file(med_config)
        encoder_config.encoder_width = vision_width
        self.text_encoder = BertModel(config=encoder_config, add_pooling_layer=False)

        decoder_config = BertConfig.from_json_file(med_config)
        self.text_decoder = BertLMHeadModel(config=decoder_config)


    def forward(self, image, question, answer=None, n=None, weights=None, train=True, inference='rank', k_test=128):
        """Train: return the weighted answer LM loss. Eval: return generated
        answer strings (`inference='generate'`) or indices of the best
        candidate answers (`inference='rank'`).

        NOTE(review): in the eval 'rank' branch, `answer` is used as
        `answer.input_ids` without tokenization here — presumably the caller
        passes a pre-tokenized candidate-answer batch; confirm against callers.
        """

        image_embeds = self.visual_encoder(image)
        image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)

        question = self.tokenizer(question, padding='longest', truncation=True, max_length=35,
                                  return_tensors="pt").to(image.device)
        # Replace [CLS] with [ENC] so the text encoder runs in fusion mode.
        question.input_ids[:,0] = self.tokenizer.enc_token_id

        if train:
            '''
            n: number of answers for each question
            weights: weight for each answer
            '''
            answer = self.tokenizer(answer, padding='longest', return_tensors="pt").to(image.device)
            answer.input_ids[:,0] = self.tokenizer.bos_token_id
            # Padding positions are ignored by the LM loss (-100).
            answer_targets = answer.input_ids.masked_fill(answer.input_ids == self.tokenizer.pad_token_id, -100)

            question_output = self.text_encoder(question.input_ids,
                                                attention_mask = question.attention_mask,
                                                encoder_hidden_states = image_embeds,
                                                encoder_attention_mask = image_atts,
                                                return_dict = True)

            # Replicate each question's states once per candidate answer.
            # NOTE(review): the loop variable `n` shadows the argument `n`.
            question_states = []
            question_atts = []
            for b, n in enumerate(n):
                question_states += [question_output.last_hidden_state[b]]*n
                question_atts += [question.attention_mask[b]]*n
            question_states = torch.stack(question_states,0)
            question_atts = torch.stack(question_atts,0)

            # reduction='none' yields one loss per answer so it can be weighted.
            answer_output = self.text_decoder(answer.input_ids,
                                              attention_mask = answer.attention_mask,
                                              encoder_hidden_states = question_states,
                                              encoder_attention_mask = question_atts,
                                              labels = answer_targets,
                                              return_dict = True,
                                              reduction = 'none',
                                             )

            loss = weights * answer_output.loss
            # Average over images (not over answer candidates).
            loss = loss.sum()/image.size(0)

            return loss


        else:
            question_output = self.text_encoder(question.input_ids,
                                                attention_mask = question.attention_mask,
                                                encoder_hidden_states = image_embeds,
                                                encoder_attention_mask = image_atts,
                                                return_dict = True)

            if inference=='generate':
                # Free-form decoding with beam search over the fused question states.
                num_beams = 3
                question_states = question_output.last_hidden_state.repeat_interleave(num_beams,dim=0)
                question_atts = torch.ones(question_states.size()[:-1],dtype=torch.long).to(question_states.device)
                model_kwargs = {"encoder_hidden_states": question_states, "encoder_attention_mask":question_atts}

                bos_ids = torch.full((image.size(0),1),fill_value=self.tokenizer.bos_token_id,device=image.device)

                outputs = self.text_decoder.generate(input_ids=bos_ids,
                                                     max_length=10,
                                                     min_length=1,
                                                     num_beams=num_beams,
                                                     eos_token_id=self.tokenizer.sep_token_id,
                                                     pad_token_id=self.tokenizer.pad_token_id,
                                                     **model_kwargs)

                answers = []
                for output in outputs:
                    answer = self.tokenizer.decode(output, skip_special_tokens=True)
                    answers.append(answer)
                return answers

            elif inference=='rank':
                # Score a fixed candidate-answer list and return the best index per question.
                max_ids = self.rank_answer(question_output.last_hidden_state, question.attention_mask,
                                           answer.input_ids, answer.attention_mask, k_test)
                return max_ids



    def rank_answer(self, question_states, question_atts, answer_ids, answer_atts, k):
        """For each question, keep the k candidate answers whose first token is
        most probable, then rank those k by full-sequence log-likelihood and
        return the index (into the candidate list) of the best one."""

        num_ques = question_states.size(0)
        start_ids = answer_ids[0,0].repeat(num_ques,1) # bos token

        # One decoder step from <bos> to score every candidate's first token cheaply.
        start_output = self.text_decoder(start_ids,
                                         encoder_hidden_states = question_states,
                                         encoder_attention_mask = question_atts,
                                         return_dict = True,
                                         reduction = 'none')
        logits = start_output.logits[:,0,:] # first token's logit

        # topk_probs: top-k probability
        # topk_ids: [num_question, k]
        answer_first_token = answer_ids[:,1]
        prob_first_token = F.softmax(logits,dim=1).index_select(dim=1, index=answer_first_token)
        topk_probs, topk_ids = prob_first_token.topk(k,dim=1)

        # answer input: [num_question*k, answer_len]
        input_ids = []
        input_atts = []
        for b, topk_id in enumerate(topk_ids):
            input_ids.append(answer_ids.index_select(dim=0, index=topk_id))
            input_atts.append(answer_atts.index_select(dim=0, index=topk_id))
        input_ids = torch.cat(input_ids,dim=0)
        input_atts = torch.cat(input_atts,dim=0)

        targets_ids = input_ids.masked_fill(input_ids == self.tokenizer.pad_token_id, -100)

        # repeat encoder's output for top-k answers
        question_states = tile(question_states, 0, k)
        question_atts = tile(question_atts, 0, k)

        output = self.text_decoder(input_ids,
                                   attention_mask = input_atts,
                                   encoder_hidden_states = question_states,
                                   encoder_attention_mask = question_atts,
                                   labels = targets_ids,
                                   return_dict = True,
                                   reduction = 'none')

        # reduction='none' gives per-sequence NLL; negate for log-likelihood.
        log_probs_sum = -output.loss
        log_probs_sum = log_probs_sum.view(num_ques,k)

        max_topk_ids = log_probs_sum.argmax(dim=1)
        # `max_topk_ids>=0` is always true, so this selects topk_ids[row, argmax] per row.
        max_ids = topk_ids[max_topk_ids>=0,max_topk_ids]

        return max_ids
405
+
406
+
407
def blip_vqa(pretrained='', **kwargs):
    """Build a BLIP_VQA model; when `pretrained` is a URL/path, load its
    checkpoint (missing keys are tolerated here, unlike blip_decoder)."""
    model = BLIP_VQA(**kwargs)
    if pretrained:
        model, msg = load_checkpoint(model, pretrained)
        # Deliberately no assert on msg.missing_keys (matches upstream).
    return model
413
+
414
+
415
def tile(x, dim, n_tile):
    """Repeat every slice of `x` along `dim` `n_tile` times, interleaved
    (element-wise equivalent of torch.repeat_interleave on that dim)."""
    size = x.size(dim)
    reps = [1] * x.dim()
    reps[dim] = n_tile
    tiled = x.repeat(*reps)
    # Build an index that regroups the block-repeated copies into
    # interleaved order: [0, size, 2*size, ..., 1, size+1, ...].
    order = np.concatenate([size * np.arange(n_tile) + i for i in range(size)])
    order_index = torch.LongTensor(order)
    return torch.index_select(tiled, dim, order_index.to(x.device))
422
+
423
+
was-node-suite-comfyui/modules/BLIP/blip_module_license.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Copyright (c) 2022, Salesforce.com, Inc.
2
+ All rights reserved.
3
+
4
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
5
+
6
+ * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
7
+
8
+ * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
9
+
10
+ * Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
11
+
12
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
was-node-suite-comfyui/modules/BLIP/blip_vit.py ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ * Copyright (c) 2022, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ * Based on timm code base
8
+ * https://github.com/rwightman/pytorch-image-models/tree/master/timm
9
+ '''
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from functools import partial
15
+
16
+ from timm.models.vision_transformer import _cfg, PatchEmbed
17
+ from timm.models.registry import register_model
18
+ from timm.models.layers import trunc_normal_, DropPath
19
+ from timm.models.helpers import named_apply, adapt_input_conv
20
+
21
+ from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
22
+
23
class Mlp(nn.Module):
    """Feed-forward block (Linear -> activation -> dropout -> Linear ->
    dropout) as used in Vision Transformer, MLP-Mixer and related networks."""

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        # Hidden and output widths default to the input width.
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
+ return x
42
+
43
+
44
class Attention(nn.Module):
    """Multi-head self-attention. When `register_hook=True` in forward, the
    attention map and its gradient are cached on the module (used for
    attention visualization/analysis)."""

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5
        # Single projection producing q, k and v stacked along the last dim.
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        # Caches populated only when forward() is called with register_hook=True.
        self.attn_gradients = None
        self.attention_map = None

    def save_attn_gradients(self, attn_gradients):
        self.attn_gradients = attn_gradients

    def get_attn_gradients(self):
        return self.attn_gradients

    def save_attention_map(self, attention_map):
        self.attention_map = attention_map

    def get_attention_map(self):
        return self.attention_map

    def forward(self, x, register_hook=False):
        B, N, C = x.shape
        # (B, N, 3*C) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)

        # Scaled dot-product attention over the sequence dimension.
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        if register_hook:
            # Cache the attention map and register a backward hook to
            # capture its gradient during backprop.
            self.save_attention_map(attn)
            attn.register_hook(self.save_attn_gradients)

        # Merge heads back into the channel dimension.
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
+ return x
87
+
88
+
89
class Block(nn.Module):
    """Pre-norm transformer encoder block: LayerNorm -> Attention -> DropPath
    residual, then LayerNorm -> MLP -> DropPath residual. Optionally wraps
    attention and MLP in fairscale activation checkpointing."""

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        if use_grad_checkpointing:
            # Trade compute for memory: recompute activations in backward.
            self.attn = checkpoint_wrapper(self.attn)
            self.mlp = checkpoint_wrapper(self.mlp)

    def forward(self, x, register_hook=False):
        # register_hook is forwarded to Attention to cache attention maps/grads.
        x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
+ return x
111
+
112
+
113
+ class VisionTransformer(nn.Module):
114
+ """ Vision Transformer
115
+ A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
116
+ https://arxiv.org/abs/2010.11929
117
+ """
118
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
119
+ num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
120
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None,
121
+ use_grad_checkpointing=False, ckpt_layer=0):
122
+ """
123
+ Args:
124
+ img_size (int, tuple): input image size
125
+ patch_size (int, tuple): patch size
126
+ in_chans (int): number of input channels
127
+ num_classes (int): number of classes for classification head
128
+ embed_dim (int): embedding dimension
129
+ depth (int): depth of transformer
130
+ num_heads (int): number of attention heads
131
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
132
+ qkv_bias (bool): enable bias for qkv if True
133
+ qk_scale (float): override default qk scale of head_dim ** -0.5 if set
134
+ representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
135
+ drop_rate (float): dropout rate
136
+ attn_drop_rate (float): attention dropout rate
137
+ drop_path_rate (float): stochastic depth rate
138
+ norm_layer: (nn.Module): normalization layer
139
+ """
140
+ super().__init__()
141
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
142
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
143
+
144
+ self.patch_embed = PatchEmbed(
145
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
146
+
147
+ num_patches = self.patch_embed.num_patches
148
+
149
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
150
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
151
+ self.pos_drop = nn.Dropout(p=drop_rate)
152
+
153
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
154
+ self.blocks = nn.ModuleList([
155
+ Block(
156
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
157
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
158
+ use_grad_checkpointing=(use_grad_checkpointing and i>=depth-ckpt_layer)
159
+ )
160
+ for i in range(depth)])
161
+ self.norm = norm_layer(embed_dim)
162
+
163
+ trunc_normal_(self.pos_embed, std=.02)
164
+ trunc_normal_(self.cls_token, std=.02)
165
+ self.apply(self._init_weights)
166
+
167
+ def _init_weights(self, m):
168
+ if isinstance(m, nn.Linear):
169
+ trunc_normal_(m.weight, std=.02)
170
+ if isinstance(m, nn.Linear) and m.bias is not None:
171
+ nn.init.constant_(m.bias, 0)
172
+ elif isinstance(m, nn.LayerNorm):
173
+ nn.init.constant_(m.bias, 0)
174
+ nn.init.constant_(m.weight, 1.0)
175
+
176
+ @torch.jit.ignore
177
+ def no_weight_decay(self):
178
+ return {'pos_embed', 'cls_token'}
179
+
180
+ def forward(self, x, register_blk=-1):
181
+ B = x.shape[0]
182
+ x = self.patch_embed(x)
183
+
184
+ cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
185
+ x = torch.cat((cls_tokens, x), dim=1)
186
+
187
+ x = x + self.pos_embed[:,:x.size(1),:]
188
+ x = self.pos_drop(x)
189
+
190
+ for i,blk in enumerate(self.blocks):
191
+ x = blk(x, register_blk==i)
192
+ x = self.norm(x)
193
+
194
+ return x
195
+
196
+ @torch.jit.ignore()
197
+ def load_pretrained(self, checkpoint_path, prefix=''):
198
+ _load_weights(self, checkpoint_path, prefix)
199
+
200
+
201
+ @torch.no_grad()
202
+ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
203
+ """ Load weights from .npz checkpoints for official Google Brain Flax implementation
204
+ """
205
+ import numpy as np
206
+
207
+ def _n2p(w, t=True):
208
+ if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
209
+ w = w.flatten()
210
+ if t:
211
+ if w.ndim == 4:
212
+ w = w.transpose([3, 2, 0, 1])
213
+ elif w.ndim == 3:
214
+ w = w.transpose([2, 0, 1])
215
+ elif w.ndim == 2:
216
+ w = w.transpose([1, 0])
217
+ return torch.from_numpy(w)
218
+
219
+ w = np.load(checkpoint_path)
220
+ if not prefix and 'opt/target/embedding/kernel' in w:
221
+ prefix = 'opt/target/'
222
+
223
+ if hasattr(model.patch_embed, 'backbone'):
224
+ # hybrid
225
+ backbone = model.patch_embed.backbone
226
+ stem_only = not hasattr(backbone, 'stem')
227
+ stem = backbone if stem_only else backbone.stem
228
+ stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
229
+ stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
230
+ stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
231
+ if not stem_only:
232
+ for i, stage in enumerate(backbone.stages):
233
+ for j, block in enumerate(stage.blocks):
234
+ bp = f'{prefix}block{i + 1}/unit{j + 1}/'
235
+ for r in range(3):
236
+ getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
237
+ getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
238
+ getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
239
+ if block.downsample is not None:
240
+ block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
241
+ block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
242
+ block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
243
+ embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
244
+ else:
245
+ embed_conv_w = adapt_input_conv(
246
+ model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
247
+ model.patch_embed.proj.weight.copy_(embed_conv_w)
248
+ model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
249
+ model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
250
+ pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
251
+ if pos_embed_w.shape != model.pos_embed.shape:
252
+ pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
253
+ pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
254
+ model.pos_embed.copy_(pos_embed_w)
255
+ model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
256
+ model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
257
+ # if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
258
+ # model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
259
+ # model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
260
+ # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
261
+ # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
262
+ # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
263
+ for i, block in enumerate(model.blocks.children()):
264
+ block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
265
+ mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
266
+ block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
267
+ block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
268
+ block.attn.qkv.weight.copy_(torch.cat([
269
+ _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
270
+ block.attn.qkv.bias.copy_(torch.cat([
271
+ _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
272
+ block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
273
+ block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
274
+ for r in range(2):
275
+ getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
276
+ getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
277
+ block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
278
+ block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
279
+
280
+
281
+ def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder):
282
+ # interpolate position embedding
283
+ embedding_size = pos_embed_checkpoint.shape[-1]
284
+ num_patches = visual_encoder.patch_embed.num_patches
285
+ num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches
286
+ # height (== width) for the checkpoint position embedding
287
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
288
+ # height (== width) for the new position embedding
289
+ new_size = int(num_patches ** 0.5)
290
+
291
+ if orig_size!=new_size:
292
+ # class_token and dist_token are kept unchanged
293
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
294
+ # only the position tokens are interpolated
295
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
296
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
297
+ pos_tokens = torch.nn.functional.interpolate(
298
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
299
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
300
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
301
+ print('reshape position embedding from %d to %d'%(orig_size ** 2,new_size ** 2))
302
+
303
+ return new_pos_embed
304
+ else:
305
+ return pos_embed_checkpoint
was-node-suite-comfyui/modules/__init__.py ADDED
File without changes
was-node-suite-comfyui/repos/SAM/demo/README.md ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Segment Anything Simple Web demo
2
+
3
+ This **front-end only** React based web demo shows how to load a fixed image and corresponding `.npy` file of the SAM image embedding, and run the SAM ONNX model in the browser using Web Assembly with mulithreading enabled by `SharedArrayBuffer`, Web Worker, and SIMD128.
4
+
5
+ <img src="https://github.com/facebookresearch/segment-anything/raw/main/assets/minidemo.gif" width="500"/>
6
+
7
+ ## Run the app
8
+
9
+ Install Yarn
10
+
11
+ ```
12
+ npm install --g yarn
13
+ ```
14
+
15
+ Build and run:
16
+
17
+ ```
18
+ yarn && yarn start
19
+ ```
20
+
21
+ Navigate to [`http://localhost:8081/`](http://localhost:8081/)
22
+
23
+ Move your cursor around to see the mask prediction update in real time.
24
+
25
+ ## Export the image embedding
26
+
27
+ In the [ONNX Model Example notebook](https://github.com/facebookresearch/segment-anything/blob/main/notebooks/onnx_model_example.ipynb) upload the image of your choice and generate and save corresponding embedding.
28
+
29
+ Initialize the predictor:
30
+
31
+ ```python
32
+ checkpoint = "sam_vit_h_4b8939.pth"
33
+ model_type = "vit_h"
34
+ sam = sam_model_registry[model_type](checkpoint=checkpoint)
35
+ sam.to(device='cuda')
36
+ predictor = SamPredictor(sam)
37
+ ```
38
+
39
+ Set the new image and export the embedding:
40
+
41
+ ```
42
+ image = cv2.imread('src/assets/dogs.jpg')
43
+ predictor.set_image(image)
44
+ image_embedding = predictor.get_image_embedding().cpu().numpy()
45
+ np.save("dogs_embedding.npy", image_embedding)
46
+ ```
47
+
48
+ Save the new image and embedding in `src/assets/data`.
49
+
50
+ ## Export the ONNX model
51
+
52
+ You also need to export the quantized ONNX model from the [ONNX Model Example notebook](https://github.com/facebookresearch/segment-anything/blob/main/notebooks/onnx_model_example.ipynb).
53
+
54
+ Run the cell in the notebook which saves the `sam_onnx_quantized_example.onnx` file, download it and copy it to the path `/model/sam_onnx_quantized_example.onnx`.
55
+
56
+ Here is a snippet of the export/quantization code:
57
+
58
+ ```
59
+ onnx_model_path = "sam_onnx_example.onnx"
60
+ onnx_model_quantized_path = "sam_onnx_quantized_example.onnx"
61
+ quantize_dynamic(
62
+ model_input=onnx_model_path,
63
+ model_output=onnx_model_quantized_path,
64
+ optimize_model=True,
65
+ per_channel=False,
66
+ reduce_range=False,
67
+ weight_type=QuantType.QUInt8,
68
+ )
69
+ ```
70
+
71
+ **NOTE: if you change the ONNX model by using a new checkpoint you need to also re-export the embedding.**
72
+
73
+ ## Update the image, embedding, model in the app
74
+
75
+ Update the following file paths at the top of`App.tsx`:
76
+
77
+ ```py
78
+ const IMAGE_PATH = "/assets/data/dogs.jpg";
79
+ const IMAGE_EMBEDDING = "/assets/data/dogs_embedding.npy";
80
+ const MODEL_DIR = "/model/sam_onnx_quantized_example.onnx";
81
+ ```
82
+
83
+ ## ONNX multithreading with SharedArrayBuffer
84
+
85
+ To use multithreading, the appropriate headers need to be set to create a cross origin isolation state which will enable use of `SharedArrayBuffer` (see this [blog post](https://cloudblogs.microsoft.com/opensource/2021/09/02/onnx-runtime-web-running-your-machine-learning-model-in-browser/) for more details)
86
+
87
+ The headers below are set in `configs/webpack/dev.js`:
88
+
89
+ ```js
90
+ headers: {
91
+ "Cross-Origin-Opener-Policy": "same-origin",
92
+ "Cross-Origin-Embedder-Policy": "credentialless",
93
+ }
94
+ ```
95
+
96
+ ## Structure of the app
97
+
98
+ **`App.tsx`**
99
+
100
+ - Initializes ONNX model
101
+ - Loads image embedding and image
102
+ - Runs the ONNX model based on input prompts
103
+
104
+ **`Stage.tsx`**
105
+
106
+ - Handles mouse move interaction to update the ONNX model prompt
107
+
108
+ **`Tool.tsx`**
109
+
110
+ - Renders the image and the mask prediction
111
+
112
+ **`helpers/maskUtils.tsx`**
113
+
114
+ - Conversion of ONNX model output from array to an HTMLImageElement
115
+
116
+ **`helpers/onnxModelAPI.tsx`**
117
+
118
+ - Formats the inputs for the ONNX model
119
+
120
+ **`helpers/scaleHelper.tsx`**
121
+
122
+ - Handles image scaling logic for SAM (longest size 1024)
123
+
124
+ **`hooks/`**
125
+
126
+ - Handle shared state for the app
was-node-suite-comfyui/repos/SAM/demo/package.json ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "name": "segment-anything-mini-demo",
3
+ "version": "0.1.0",
4
+ "license": "MIT",
5
+ "scripts": {
6
+ "build": "yarn run clean-dist && webpack --config=configs/webpack/prod.js && mv dist/*.wasm dist/js",
7
+ "clean-dist": "rimraf dist/*",
8
+ "lint": "eslint './src/**/*.{js,ts,tsx}' --quiet",
9
+ "start": "yarn run start-dev",
10
+ "test": "yarn run start-model-test",
11
+ "start-dev": "webpack serve --config=configs/webpack/dev.js"
12
+ },
13
+ "devDependencies": {
14
+ "@babel/core": "^7.18.13",
15
+ "@babel/preset-env": "^7.18.10",
16
+ "@babel/preset-react": "^7.18.6",
17
+ "@babel/preset-typescript": "^7.18.6",
18
+ "@pmmmwh/react-refresh-webpack-plugin": "^0.5.7",
19
+ "@testing-library/react": "^13.3.0",
20
+ "@types/node": "^18.7.13",
21
+ "@types/react": "^18.0.17",
22
+ "@types/react-dom": "^18.0.6",
23
+ "@types/underscore": "^1.11.4",
24
+ "@typescript-eslint/eslint-plugin": "^5.35.1",
25
+ "@typescript-eslint/parser": "^5.35.1",
26
+ "babel-loader": "^8.2.5",
27
+ "copy-webpack-plugin": "^11.0.0",
28
+ "css-loader": "^6.7.1",
29
+ "dotenv": "^16.0.2",
30
+ "dotenv-webpack": "^8.0.1",
31
+ "eslint": "^8.22.0",
32
+ "eslint-plugin-react": "^7.31.0",
33
+ "file-loader": "^6.2.0",
34
+ "fork-ts-checker-webpack-plugin": "^7.2.13",
35
+ "friendly-errors-webpack-plugin": "^1.7.0",
36
+ "html-webpack-plugin": "^5.5.0",
37
+ "image-webpack-loader": "^8.1.0",
38
+ "postcss-loader": "^7.0.1",
39
+ "postcss-preset-env": "^7.8.0",
40
+ "process": "^0.11.10",
41
+ "rimraf": "^3.0.2",
42
+ "sass": "^1.54.5",
43
+ "sass-loader": "^13.0.2",
44
+ "style-loader": "^3.3.1",
45
+ "tailwindcss": "^3.1.8",
46
+ "ts-loader": "^9.3.1",
47
+ "typescript": "^4.8.2",
48
+ "webpack": "^5.74.0",
49
+ "webpack-cli": "^4.10.0",
50
+ "webpack-dev-server": "^4.10.0",
51
+ "webpack-dotenv-plugin": "^2.1.0",
52
+ "webpack-merge": "^5.8.0"
53
+ },
54
+ "dependencies": {
55
+ "npyjs": "^0.4.0",
56
+ "onnxruntime-web": "^1.14.0",
57
+ "react": "^18.2.0",
58
+ "react-dom": "^18.2.0",
59
+ "underscore": "^1.13.6",
60
+ "react-refresh": "^0.14.0"
61
+ }
62
+ }
was-node-suite-comfyui/repos/SAM/demo/postcss.config.js ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ // All rights reserved.
3
+
4
+ // This source code is licensed under the license found in the
5
+ // LICENSE file in the root directory of this source tree.
6
+
7
+ const tailwindcss = require("tailwindcss");
8
+ module.exports = {
9
+ plugins: ["postcss-preset-env", 'tailwindcss/nesting', tailwindcss],
10
+ };
was-node-suite-comfyui/repos/SAM/demo/src/components/helpers/Interfaces.tsx ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ // All rights reserved.
3
+
4
+ // This source code is licensed under the license found in the
5
+ // LICENSE file in the root directory of this source tree.
6
+
7
+ import { Tensor } from "onnxruntime-web";
8
+
9
+ export interface modelScaleProps {
10
+ samScale: number;
11
+ height: number;
12
+ width: number;
13
+ }
14
+
15
+ export interface modelInputProps {
16
+ x: number;
17
+ y: number;
18
+ clickType: number;
19
+ }
20
+
21
+ export interface modeDataProps {
22
+ clicks?: Array<modelInputProps>;
23
+ tensor: Tensor;
24
+ modelScale: modelScaleProps;
25
+ }
26
+
27
+ export interface ToolProps {
28
+ handleMouseMove: (e: any) => void;
29
+ }
was-node-suite-comfyui/repos/SAM/demo/src/components/helpers/maskUtils.tsx ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ // All rights reserved.
3
+
4
+ // This source code is licensed under the license found in the
5
+ // LICENSE file in the root directory of this source tree.
6
+
7
+ // Convert the onnx model mask prediction to ImageData
8
+ function arrayToImageData(input: any, width: number, height: number) {
9
+ const [r, g, b, a] = [0, 114, 189, 255]; // the masks's blue color
10
+ const arr = new Uint8ClampedArray(4 * width * height).fill(0);
11
+ for (let i = 0; i < input.length; i++) {
12
+
13
+ // Threshold the onnx model mask prediction at 0.0
14
+ // This is equivalent to thresholding the mask using predictor.model.mask_threshold
15
+ // in python
16
+ if (input[i] > 0.0) {
17
+ arr[4 * i + 0] = r;
18
+ arr[4 * i + 1] = g;
19
+ arr[4 * i + 2] = b;
20
+ arr[4 * i + 3] = a;
21
+ }
22
+ }
23
+ return new ImageData(arr, height, width);
24
+ }
25
+
26
+ // Use a Canvas element to produce an image from ImageData
27
+ function imageDataToImage(imageData: ImageData) {
28
+ const canvas = imageDataToCanvas(imageData);
29
+ const image = new Image();
30
+ image.src = canvas.toDataURL();
31
+ return image;
32
+ }
33
+
34
+ // Canvas elements can be created from ImageData
35
+ function imageDataToCanvas(imageData: ImageData) {
36
+ const canvas = document.createElement("canvas");
37
+ const ctx = canvas.getContext("2d");
38
+ canvas.width = imageData.width;
39
+ canvas.height = imageData.height;
40
+ ctx?.putImageData(imageData, 0, 0);
41
+ return canvas;
42
+ }
43
+
44
+ // Convert the onnx model mask output to an HTMLImageElement
45
+ export function onnxMaskToImage(input: any, width: number, height: number) {
46
+ return imageDataToImage(arrayToImageData(input, width, height));
47
+ }
was-node-suite-comfyui/repos/SAM/demo/src/components/helpers/onnxModelAPI.tsx ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ // All rights reserved.
3
+
4
+ // This source code is licensed under the license found in the
5
+ // LICENSE file in the root directory of this source tree.
6
+
7
+ import { Tensor } from "onnxruntime-web";
8
+ import { modeDataProps } from "./Interfaces";
9
+
10
+ const modelData = ({ clicks, tensor, modelScale }: modeDataProps) => {
11
+ const imageEmbedding = tensor;
12
+ let pointCoords;
13
+ let pointLabels;
14
+ let pointCoordsTensor;
15
+ let pointLabelsTensor;
16
+
17
+ // Check there are input click prompts
18
+ if (clicks) {
19
+ let n = clicks.length;
20
+
21
+ // If there is no box input, a single padding point with
22
+ // label -1 and coordinates (0.0, 0.0) should be concatenated
23
+ // so initialize the array to support (n + 1) points.
24
+ pointCoords = new Float32Array(2 * (n + 1));
25
+ pointLabels = new Float32Array(n + 1);
26
+
27
+ // Add clicks and scale to what SAM expects
28
+ for (let i = 0; i < n; i++) {
29
+ pointCoords[2 * i] = clicks[i].x * modelScale.samScale;
30
+ pointCoords[2 * i + 1] = clicks[i].y * modelScale.samScale;
31
+ pointLabels[i] = clicks[i].clickType;
32
+ }
33
+
34
+ // Add in the extra point/label when only clicks and no box
35
+ // The extra point is at (0, 0) with label -1
36
+ pointCoords[2 * n] = 0.0;
37
+ pointCoords[2 * n + 1] = 0.0;
38
+ pointLabels[n] = -1.0;
39
+
40
+ // Create the tensor
41
+ pointCoordsTensor = new Tensor("float32", pointCoords, [1, n + 1, 2]);
42
+ pointLabelsTensor = new Tensor("float32", pointLabels, [1, n + 1]);
43
+ }
44
+ const imageSizeTensor = new Tensor("float32", [
45
+ modelScale.height,
46
+ modelScale.width,
47
+ ]);
48
+
49
+ if (pointCoordsTensor === undefined || pointLabelsTensor === undefined)
50
+ return;
51
+
52
+ // There is no previous mask, so default to an empty tensor
53
+ const maskInput = new Tensor(
54
+ "float32",
55
+ new Float32Array(256 * 256),
56
+ [1, 1, 256, 256]
57
+ );
58
+ // There is no previous mask, so default to 0
59
+ const hasMaskInput = new Tensor("float32", [0]);
60
+
61
+ return {
62
+ image_embeddings: imageEmbedding,
63
+ point_coords: pointCoordsTensor,
64
+ point_labels: pointLabelsTensor,
65
+ orig_im_size: imageSizeTensor,
66
+ mask_input: maskInput,
67
+ has_mask_input: hasMaskInput,
68
+ };
69
+ };
70
+
71
+ export { modelData };
was-node-suite-comfyui/repos/SAM/demo/src/components/helpers/scaleHelper.tsx ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ // All rights reserved.
3
+
4
+ // This source code is licensed under the license found in the
5
+ // LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ // Helper function for handling image scaling needed for SAM
9
+ const handleImageScale = (image: HTMLImageElement) => {
10
+ // Input images to SAM must be resized so the longest side is 1024
11
+ const LONG_SIDE_LENGTH = 1024;
12
+ let w = image.naturalWidth;
13
+ let h = image.naturalHeight;
14
+ const samScale = LONG_SIDE_LENGTH / Math.max(h, w);
15
+ return { height: h, width: w, samScale };
16
+ };
17
+
18
+ export { handleImageScale };
was-node-suite-comfyui/repos/SAM/demo/src/components/hooks/context.tsx ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ // All rights reserved.
3
+
4
+ // This source code is licensed under the license found in the
5
+ // LICENSE file in the root directory of this source tree.
6
+
7
+ import React, { useState } from "react";
8
+ import { modelInputProps } from "../helpers/Interfaces";
9
+ import AppContext from "./createContext";
10
+
11
+ const AppContextProvider = (props: {
12
+ children: React.ReactElement<any, string | React.JSXElementConstructor<any>>;
13
+ }) => {
14
+ const [clicks, setClicks] = useState<Array<modelInputProps> | null>(null);
15
+ const [image, setImage] = useState<HTMLImageElement | null>(null);
16
+ const [maskImg, setMaskImg] = useState<HTMLImageElement | null>(null);
17
+
18
+ return (
19
+ <AppContext.Provider
20
+ value={{
21
+ clicks: [clicks, setClicks],
22
+ image: [image, setImage],
23
+ maskImg: [maskImg, setMaskImg],
24
+ }}
25
+ >
26
+ {props.children}
27
+ </AppContext.Provider>
28
+ );
29
+ };
30
+
31
+ export default AppContextProvider;
was-node-suite-comfyui/repos/SAM/demo/src/components/hooks/createContext.tsx ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ // All rights reserved.
3
+
4
+ // This source code is licensed under the license found in the
5
+ // LICENSE file in the root directory of this source tree.
6
+
7
+ import { createContext } from "react";
8
+ import { modelInputProps } from "../helpers/Interfaces";
9
+
10
+ interface contextProps {
11
+ clicks: [
12
+ clicks: modelInputProps[] | null,
13
+ setClicks: (e: modelInputProps[] | null) => void
14
+ ];
15
+ image: [
16
+ image: HTMLImageElement | null,
17
+ setImage: (e: HTMLImageElement | null) => void
18
+ ];
19
+ maskImg: [
20
+ maskImg: HTMLImageElement | null,
21
+ setMaskImg: (e: HTMLImageElement | null) => void
22
+ ];
23
+ }
24
+
25
+ const AppContext = createContext<contextProps | null>(null);
26
+
27
+ export default AppContext;