toto10 committed on
Commit 65f2719 · 1 Parent(s): 39cf5c7

aa926620ca6701673dec0f931c22897dc1c41bb94a3cec6e3b01f666d53a734a

Files changed (50)
  1. microsoftexcel-supermerger/elemental_ja.md +119 -0
  2. microsoftexcel-supermerger/install.py +7 -0
  3. microsoftexcel-supermerger/sample.txt +95 -0
  4. microsoftexcel-supermerger/scripts/__pycache__/supermerger.cpython-310.pyc +0 -0
  5. microsoftexcel-supermerger/scripts/mbwpresets.txt +39 -0
  6. microsoftexcel-supermerger/scripts/mbwpresets_master.txt +39 -0
  7. microsoftexcel-supermerger/scripts/mergers/__pycache__/mergers.cpython-310.pyc +0 -0
  8. microsoftexcel-supermerger/scripts/mergers/__pycache__/model_util.cpython-310.pyc +0 -0
  9. microsoftexcel-supermerger/scripts/mergers/__pycache__/pluslora.cpython-310.pyc +0 -0
  10. microsoftexcel-supermerger/scripts/mergers/__pycache__/xyplot.cpython-310.pyc +0 -0
  11. microsoftexcel-supermerger/scripts/mergers/mergers.py +699 -0
  12. microsoftexcel-supermerger/scripts/mergers/model_util.py +928 -0
  13. microsoftexcel-supermerger/scripts/mergers/pluslora.py +1298 -0
  14. microsoftexcel-supermerger/scripts/mergers/xyplot.py +513 -0
  15. microsoftexcel-supermerger/scripts/supermerger.py +552 -0
  16. microsoftexcel-tunnels/.gitignore +176 -0
  17. microsoftexcel-tunnels/.pre-commit-config.yaml +25 -0
  18. microsoftexcel-tunnels/LICENSE.md +22 -0
  19. microsoftexcel-tunnels/README.md +21 -0
  20. microsoftexcel-tunnels/__pycache__/preload.cpython-310.pyc +0 -0
  21. microsoftexcel-tunnels/install.py +4 -0
  22. microsoftexcel-tunnels/preload.py +21 -0
  23. microsoftexcel-tunnels/pyproject.toml +25 -0
  24. microsoftexcel-tunnels/scripts/__pycache__/ssh_tunnel.cpython-310.pyc +0 -0
  25. microsoftexcel-tunnels/scripts/__pycache__/try_cloudflare.cpython-310.pyc +0 -0
  26. microsoftexcel-tunnels/scripts/ssh_tunnel.py +81 -0
  27. microsoftexcel-tunnels/scripts/try_cloudflare.py +15 -0
  28. microsoftexcel-tunnels/ssh_tunnel.py +86 -0
  29. openpose-editor/.github/CODEOWNERS +2 -0
  30. openpose-editor/.github/ISSUE_TEMPLATE/bug_report.md +31 -0
  31. openpose-editor/.github/ISSUE_TEMPLATE/feature_request.md +17 -0
  32. openpose-editor/.github/workflows/typos.yml +21 -0
  33. openpose-editor/.gitignore +2 -0
  34. openpose-editor/.vscode/settings.json +5 -0
  35. openpose-editor/LICENSE +21 -0
  36. openpose-editor/README.en.md +40 -0
  37. openpose-editor/README.md +41 -0
  38. openpose-editor/README.zh-cn.md +41 -0
  39. openpose-editor/_typos.toml +11 -0
  40. openpose-editor/configs/.gitkeep +0 -0
  41. openpose-editor/images/スクリーンショット 2023-02-19 131430.png +0 -0
  42. openpose-editor/javascript/fabric.js +0 -0
  43. openpose-editor/javascript/main.js +691 -0
  44. openpose-editor/scripts/__pycache__/main.cpython-310.pyc +0 -0
  45. openpose-editor/scripts/main.py +143 -0
  46. openpose-editor/scripts/openpose/__pycache__/body.cpython-310.pyc +0 -0
  47. openpose-editor/scripts/openpose/__pycache__/model.cpython-310.pyc +0 -0
  48. openpose-editor/scripts/openpose/__pycache__/util.cpython-310.pyc +0 -0
  49. openpose-editor/scripts/openpose/body.py +220 -0
  50. openpose-editor/scripts/openpose/model.py +221 -0
microsoftexcel-supermerger/elemental_ja.md ADDED
@@ -0,0 +1,119 @@
+ # Elemental Merge
+ - Block merging, taken one level deeper
+
+ Block merging lets you set a separate merge ratio for each of the 25 blocks, but every block is itself built from multiple elements, and in principle the ratio can be set per element as well. With more than 600 elements it was unclear whether they could be managed by hand, but it has been implemented anyway. Going straight to element merging is not recommended; use it as a last-resort adjustment when a problem cannot be solved by block merging alone.
+ The following image shows the result of changing the elements of the OUT05 block. The leftmost image has no merging; the second changes all of OUT05 (that is, ordinary block merging); the rest are element merges. As the table below shows, elements such as attn2 contain further sub-elements.
+ ![](https://raw.githubusercontent.com/hako-mikan/sd-webui-supermerger/images/sample1.jpg)
+
+ ## Usage
+ Element merging is available during both normal and block merging. Because it is applied last, note that it overrides any values set by block merging.
+
+ Configure it in the Elemental Merge field. Be careful: whenever text is entered there, it is applied automatically. The elements are listed in the table below, but you do not need to type an element's full name.
+ You can check that the settings actually take effect by enabling the print change checkbox; when it is on, the elements that were changed are printed to the command prompt during the merge.
+ Partial matches are accepted.
+ ### Format
+ Block:Element:Ratio,Block:Element:Ratio,...
+ or
+ Block:Element:Ratio
+ Block:Element:Ratio
+ Block:Element:Ratio
+
+ Multiple specifications can be separated by commas or newlines, and the two may be mixed freely.
+ Blocks are written in upper case and range from BASE and IN00 through M00 to OUT11. Leaving the block field empty applies the setting to every block. Multiple blocks can be given, separated by spaces.
+ Likewise, multiple elements can be given, separated by spaces.
+ Matching is partial, so entering "attn" changes both attn1 and attn2, while "attn2" changes only attn2. To be more specific, enter something like "attn2.to_out".
+
+ OUT03 OUT04 OUT05:attn2 attn1.to_out:0.5
+
+ sets the ratio of every element containing attn2, as well as attn1.to_out, in the OUT03, OUT04, and OUT05 blocks to 0.5.
+ Leaving the element field empty changes every element in the specified blocks, giving the same effect as block merging.
+ When specifications overlap, the one entered later takes precedence.
+
+ OUT06:attn:0.5,OUT06:attn2.to_k:0.2
+
+ sets every attn element in OUT06 except attn2.to_k to 0.5, and attn2.to_k alone to 0.2.
+
+ Prefixing an entry with NOT inverts its scope. This can be done separately for blocks and for elements.
+
+ NOT OUT04:attn:1
+
+ sets attn in every block except OUT04 to a ratio of 1.
+
+ OUT05:NOT attn proj:0.2
+
+ sets every element of OUT05 except attn and proj to 0.2.
+
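The selection rules above (comma or newline separation, space-separated multi-values, partial matching, NOT inversion, later-entry precedence) can be sketched as a small resolver. This is an illustrative reading of the rules, not the extension's actual implementation; the function name `resolve_ratio` and the default of 0.0 are made up for the example.

```python
import re

def resolve_ratio(key: str, block: str, specs: str, default: float = 0.0) -> float:
    """Return the merge ratio for one weight `key` in `block` per the spec rules."""
    ratio = default
    for spec in re.split(r"[,\n]", specs):          # commas and newlines mix freely
        spec = spec.strip()
        if not spec:
            continue
        blocks, elements, value = spec.split(":")
        b_not = blocks.startswith("NOT ")            # NOT inverts the block scope
        e_not = elements.startswith("NOT ")          # NOT inverts the element scope
        b_list = blocks.removeprefix("NOT ").split()
        e_list = elements.removeprefix("NOT ").split()
        b_hit = not b_list or block in b_list        # empty block field = all blocks
        if b_not:
            b_hit = not b_hit
        e_hit = not e_list or any(e in key for e in e_list)  # partial match
        if e_not:
            e_hit = not e_hit
        if b_hit and e_hit:
            ratio = float(value)                     # later specs override earlier ones
    return ratio
```

For instance, with `"OUT06:attn:0.5,OUT06:attn2.to_k:0.2"`, an OUT06 key containing `attn2.to_k` resolves to 0.2 while other attn keys resolve to 0.5, matching the precedence rule described above.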
+ ## XY plot
+ Several XY plot types are provided for elemental merging. Input examples are given in sample.txt.
+ #### elemental
+ Creates an XY plot over multiple element-merge specifications. Separate the specifications with blank lines.
+ The image at the top of this document is the result of running sample1 from sample.txt.
+
+ #### pinpoint element
+ Creates an XY plot that varies the value of particular elements; it does for elements what pinpoint Blocks does for blocks. Specify alpha on the opposite axis. Separate elements with newlines or commas.
+ The following image is the result of running sample3 from sample.txt.
+ ![](https://raw.githubusercontent.com/hako-mikan/sd-webui-supermerger/images/sample3.jpg)
+
+ #### effective elemental checker
+ Outputs the influence of each element as a difference image. Optionally, an animated GIF and a CSV file can be produced. The GIF and CSV files are created in a diff folder, which is placed under the folder named after ModelA and ModelB inside the output folder. If file names collide, the files are saved under modified names, but since this becomes confusing as they accumulate, renaming the diff folder to something suitable is recommended.
+ Separate entries with newlines or commas. Use alpha on the opposite axis and enter a single value. This is useful for examining the effect of individual elements, but by leaving the element unspecified you can also examine the effect of whole blocks, which may well be the more common way to use it.
+ The following images are the result of running sample5 from sample.txt.
+ ![](https://raw.githubusercontent.com/hako-mikan/sd-webui-supermerger/images/sample5-1.jpg)
+ ![](https://raw.githubusercontent.com/hako-mikan/sd-webui-supermerger/images/sample5-2.jpg)
+ ### Element list
+ Broadly speaking, attn appears to carry the information about faces and clothing, with the IN07, OUT03, OUT04, and OUT05 blocks having a particularly strong influence. Since the degree of influence often differs between blocks, changing the same element in several blocks at once seems to serve little purpose.
+ Cells marked null have no corresponding element.
+
+ ||IN00|IN01|IN02|IN03|IN04|IN05|IN06|IN07|IN08|IN09|IN10|IN11|M00|M00|OUT00|OUT01|OUT02|OUT03|OUT04|OUT05|OUT06|OUT07|OUT08|OUT09|OUT10|OUT11
+ |-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|
+ op.bias|null|null|null||null|null||null|null||null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null
+ op.weight|null|null|null||null|null||null|null||null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null
+ emb_layers.1.bias|null|||null|||null|||null|null|||||||||||||||
+ emb_layers.1.weight|null|||null|||null|||null|null|||||||||||||||
+ in_layers.0.bias|null|||null|||null|||null|null|||||||||||||||
+ in_layers.0.weight|null|||null|||null|||null|null|||||||||||||||
+ in_layers.2.bias|null|||null|||null|||null|null|||||||||||||||
+ in_layers.2.weight|null|||null|||null|||null|null|||||||||||||||
+ out_layers.0.bias|null|||null|||null|||null|null|||||||||||||||
+ out_layers.0.weight|null|||null|||null|||null|null|||||||||||||||
+ out_layers.3.bias|null|||null|||null|||null|null|||||||||||||||
+ out_layers.3.weight|null|||null|||null|||null|null|||||||||||||||
+ skip_connection.bias|null|null|null|null||null|null||null|null|null|null|null|null||||||||||||
+ skip_connection.weight|null|null|null|null||null|null||null|null|null|null|null|null||||||||||||
+ norm.bias|null|||null|||null|||null|null|null||null|null|null|null|||||||||
+ norm.weight|null|||null|||null|||null|null|null||null|null|null|null|||||||||
+ proj_in.bias|null|||null|||null|||null|null|null||null|null|null|null|||||||||
+ proj_in.weight|null|||null|||null|||null|null|null||null|null|null|null|||||||||
+ proj_out.bias|null|||null|||null|||null|null|null||null|null|null|null|||||||||
+ proj_out.weight|null|||null|||null|||null|null|null||null|null|null|null|||||||||
+ transformer_blocks.0.attn1.to_k.weight|null|||null|||null|||null|null|null||null|null|null|null|||||||||
+ transformer_blocks.0.attn1.to_out.0.bias|null|||null|||null|||null|null|null||null|null|null|null|||||||||
+ transformer_blocks.0.attn1.to_out.0.weight|null|||null|||null|||null|null|null||null|null|null|null|||||||||
+ transformer_blocks.0.attn1.to_q.weight|null|||null|||null|||null|null|null||null|null|null|null|||||||||
+ transformer_blocks.0.attn1.to_v.weight|null|||null|||null|||null|null|null||null|null|null|null|||||||||
+ transformer_blocks.0.attn2.to_k.weight|null|||null|||null|||null|null|null||null|null|null|null|||||||||
+ transformer_blocks.0.attn2.to_out.0.bias|null|||null|||null|||null|null|null||null|null|null|null|||||||||
+ transformer_blocks.0.attn2.to_out.0.weight|null|||null|||null|||null|null|null||null|null|null|null|||||||||
+ transformer_blocks.0.attn2.to_q.weight|null|||null|||null|||null|null|null||null|null|null|null|||||||||
+ transformer_blocks.0.attn2.to_v.weight|null|||null|||null|||null|null|null||null|null|null|null|||||||||
+ transformer_blocks.0.ff.net.0.proj.bias|null|||null|||null|||null|null|null||null|null|null|null|||||||||
+ transformer_blocks.0.ff.net.0.proj.weight|null|||null|||null|||null|null|null||null|null|null|null|||||||||
+ transformer_blocks.0.ff.net.2.bias|null|||null|||null|||null|null|null||null|null|null|null|||||||||
+ transformer_blocks.0.ff.net.2.weight|null|||null|||null|||null|null|null||null|null|null|null|||||||||
+ transformer_blocks.0.norm1.bias|null|||null|||null|||null|null|null||null|null|null|null|||||||||
+ transformer_blocks.0.norm1.weight|null|||null|||null|||null|null|null||null|null|null|null|||||||||
+ transformer_blocks.0.norm2.bias|null|||null|||null|||null|null|null||null|null|null|null|||||||||
+ transformer_blocks.0.norm2.weight|null|||null|||null|||null|null|null||null|null|null|null|||||||||
+ transformer_blocks.0.norm3.bias|null|||null|||null|||null|null|null||null|null|null|null|||||||||
+ transformer_blocks.0.norm3.weight|null|||null|||null|||null|null|null||null|null|null|null|||||||||
+ conv.bias|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null||null|null||null|null||null|null|null
+ conv.weight|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null||null|null||null|null||null|null|null
+ 0.bias||null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|
+ 0.weight||null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|
+ 2.bias|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|
+ 2.weight|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|
+ time_embed.0.weight||null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|
+ time_embed.0.bias||null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|
+ time_embed.2.weight||null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|
+ time_embed.2.bias||null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|
microsoftexcel-supermerger/install.py ADDED
@@ -0,0 +1,7 @@
+ import launch
+
+ if not launch.is_installed("sklearn"):
+     launch.run_pip("install scikit-learn", "scikit-learn")
+
+ if not launch.is_installed("diffusers"):
+     launch.run_pip("install diffusers", "diffusers")
microsoftexcel-supermerger/sample.txt ADDED
@@ -0,0 +1,95 @@
+ Examples of XY plot
+ ***************************************************************
+ for elemental
+ Each value is separated by a blank line
+ Element merge values are comma or newline delimited
+ Commas and newlines can also be mixed
+ You can insert a run with no elemental merge by placing two blank lines at the beginning
+
+ **sample1*******************************************************
+
+
+ OUT05::1
+
+ OUT05:layers:1
+
+ OUT05:attn1:1
+
+ OUT05:attn2:1
+
+ OUT05:ff.net:1
+ **sample2*******************************************************
+
+
+ IN07:attn1:0.5,IN07:attn2:0.2
+ OUT03:NOT attn1.to_k.weight:0.5
+ OUT03:attn1.to_k.weight:0.3
+
+ OUT04:NOT ff.net:0.5
+ OUT04:attn1.to_k.weight:0.3
+ OUT04:attn1.to_out.0.weight:0.3
+
+ OUT05:NOT ff.net:0.5
+ OUT05:attn1.to_k.weight:0.3
+ OUT05:attn1.to_out.0.weight:0.3
+ ***************************************************************
+ ***************************************************************
+ for pinpoint element
+ Each value is separated by a comma or a blank line
+ Do not enter a ratio
+ **sample3******************************************************
+ IN07:,IN07:layers,IN07:attn1,IN07:attn2,IN07:ff.net
+ **sample4******************************************************
+ OUT04:NOT attn2.to_q
+ OUT04:attn1.to_k.weight
+ OUT04:ff.net
+ OUT04:attn
+ ***************************************************************
+ ***************************************************************
+ for effective elemental checker
+
+ Examine the effect of each block
+ Output results can be split by inserting "|" partway through
+ **sample5*************************************************************
+ IN00:,IN01:,IN02:,IN03:,IN04:,IN05:,IN06:,IN07:,IN08:,IN09:,IN10:,IN11:,M00:|OUT00:,OUT01:,OUT02:,OUT03:,OUT04:,OUT05:,OUT06:,OUT07:,OUT08:,OUT09:,OUT10:,OUT11:
+
+ Examine the effect of all elements in the IN01 block
+ The entries below correspond to IN02 and later
+ **sample6*************************************************************
+ IN01:emb_layers.1.bias,IN01:emb_layers.1.weight,IN01:in_layers.0.bias,IN01:in_layers.0.weight,IN01:in_layers.2.bias,IN01:in_layers.2.weight,IN01:out_layers.0.bias,IN01:out_layers.0.weight,IN01:out_layers.3.bias,IN01:out_layers.3.weight,IN01:skip_connection.bias,IN01:skip_connection.weight,IN01:norm.bias,IN01:norm.weight,IN01:proj_in.bias,IN01:proj_in.weight,IN01:proj_out.bias,IN01:proj_out.weight,IN01:transformer_blocks.0.attn1.to_k.weight,IN01:transformer_blocks.0.attn1.to_out.0.bias,IN01:transformer_blocks.0.attn1.to_out.0.weight,IN01:transformer_blocks.0.attn1.to_q.weight,IN01:transformer_blocks.0.attn1.to_v.weight,IN01:transformer_blocks.0.attn2.to_k.weight,IN01:transformer_blocks.0.attn2.to_out.0.bias,IN01:transformer_blocks.0.attn2.to_out.0.weight,IN01:transformer_blocks.0.attn2.to_q.weight,IN01:transformer_blocks.0.attn2.to_v.weight,IN01:transformer_blocks.0.ff.net.0.proj.bias,IN01:transformer_blocks.0.ff.net.0.proj.weight,IN01:transformer_blocks.0.ff.net.2.bias,IN01:transformer_blocks.0.ff.net.2.weight,IN01:transformer_blocks.0.norm1.bias,IN01:transformer_blocks.0.norm1.weight,IN01:transformer_blocks.0.norm2.bias,IN01:transformer_blocks.0.norm2.weight,IN01:transformer_blocks.0.norm3.bias,IN01:transformer_blocks.0.norm3.weight
+
+ IN02:emb_layers.1.bias,IN02:emb_layers.1.weight,IN02:in_layers.0.bias,IN02:in_layers.0.weight,IN02:in_layers.2.bias,IN02:in_layers.2.weight,IN02:out_layers.0.bias,IN02:out_layers.0.weight,IN02:out_layers.3.bias,IN02:out_layers.3.weight,IN02:skip_connection.bias,IN02:skip_connection.weight,IN02:norm.bias,IN02:norm.weight,IN02:proj_in.bias,IN02:proj_in.weight,IN02:proj_out.bias,IN02:proj_out.weight,IN02:transformer_blocks.0.attn1.to_k.weight,IN02:transformer_blocks.0.attn1.to_out.0.bias,IN02:transformer_blocks.0.attn1.to_out.0.weight,IN02:transformer_blocks.0.attn1.to_q.weight,IN02:transformer_blocks.0.attn1.to_v.weight,IN02:transformer_blocks.0.attn2.to_k.weight,IN02:transformer_blocks.0.attn2.to_out.0.bias,IN02:transformer_blocks.0.attn2.to_out.0.weight,IN02:transformer_blocks.0.attn2.to_q.weight,IN02:transformer_blocks.0.attn2.to_v.weight,IN02:transformer_blocks.0.ff.net.0.proj.bias,IN02:transformer_blocks.0.ff.net.0.proj.weight,IN02:transformer_blocks.0.ff.net.2.bias,IN02:transformer_blocks.0.ff.net.2.weight,IN02:transformer_blocks.0.norm1.bias,IN02:transformer_blocks.0.norm1.weight,IN02:transformer_blocks.0.norm2.bias,IN02:transformer_blocks.0.norm2.weight,IN02:transformer_blocks.0.norm3.bias,IN02:transformer_blocks.0.norm3.weight
+
+ IN00:bias,IN00:weight,IN03:op.bias,IN03:op.weight,IN06:op.bias,IN06:op.weight,IN09:op.bias,IN09:op.weight
+
+ IN04:emb_layers.1.bias,IN04:emb_layers.1.weight,IN04:in_layers.0.bias,IN04:in_layers.0.weight,IN04:in_layers.2.bias,IN04:in_layers.2.weight,IN04:out_layers.0.bias,IN04:out_layers.0.weight,IN04:out_layers.3.bias,IN04:out_layers.3.weight,IN04:skip_connection.bias,IN04:skip_connection.weight,IN04:norm.bias,IN04:norm.weight,IN04:proj_in.bias,IN04:proj_in.weight,IN04:proj_out.bias,IN04:proj_out.weight,IN04:transformer_blocks.0.attn1.to_k.weight,IN04:transformer_blocks.0.attn1.to_out.0.bias,IN04:transformer_blocks.0.attn1.to_out.0.weight,IN04:transformer_blocks.0.attn1.to_q.weight,IN04:transformer_blocks.0.attn1.to_v.weight,IN04:transformer_blocks.0.attn2.to_k.weight,IN04:transformer_blocks.0.attn2.to_out.0.bias,IN04:transformer_blocks.0.attn2.to_out.0.weight,IN04:transformer_blocks.0.attn2.to_q.weight,IN04:transformer_blocks.0.attn2.to_v.weight,IN04:transformer_blocks.0.ff.net.0.proj.bias,IN04:transformer_blocks.0.ff.net.0.proj.weight,IN04:transformer_blocks.0.ff.net.2.bias,IN04:transformer_blocks.0.ff.net.2.weight,IN04:transformer_blocks.0.norm1.bias,IN04:transformer_blocks.0.norm1.weight,IN04:transformer_blocks.0.norm2.bias,IN04:transformer_blocks.0.norm2.weight,IN04:transformer_blocks.0.norm3.bias,IN04:transformer_blocks.0.norm3.weight
+
+ IN05:emb_layers.1.bias,IN05:emb_layers.1.weight,IN05:in_layers.0.bias,IN05:in_layers.0.weight,IN05:in_layers.2.bias,IN05:in_layers.2.weight,IN05:out_layers.0.bias,IN05:out_layers.0.weight,IN05:out_layers.3.bias,IN05:out_layers.3.weight,IN05:skip_connection.bias,IN05:skip_connection.weight,IN05:norm.bias,IN05:norm.weight,IN05:proj_in.bias,IN05:proj_in.weight,IN05:proj_out.bias,IN05:proj_out.weight,IN05:transformer_blocks.0.attn1.to_k.weight,IN05:transformer_blocks.0.attn1.to_out.0.bias,IN05:transformer_blocks.0.attn1.to_out.0.weight,IN05:transformer_blocks.0.attn1.to_q.weight,IN05:transformer_blocks.0.attn1.to_v.weight,IN05:transformer_blocks.0.attn2.to_k.weight,IN05:transformer_blocks.0.attn2.to_out.0.bias,IN05:transformer_blocks.0.attn2.to_out.0.weight,IN05:transformer_blocks.0.attn2.to_q.weight,IN05:transformer_blocks.0.attn2.to_v.weight,IN05:transformer_blocks.0.ff.net.0.proj.bias,IN05:transformer_blocks.0.ff.net.0.proj.weight,IN05:transformer_blocks.0.ff.net.2.bias,IN05:transformer_blocks.0.ff.net.2.weight,IN05:transformer_blocks.0.norm1.bias,IN05:transformer_blocks.0.norm1.weight,IN05:transformer_blocks.0.norm2.bias,IN05:transformer_blocks.0.norm2.weight,IN05:transformer_blocks.0.norm3.bias,IN05:transformer_blocks.0.norm3.weight
+
+ IN07:emb_layers.1.bias,IN07:emb_layers.1.weight,IN07:in_layers.0.bias,IN07:in_layers.0.weight,IN07:in_layers.2.bias,IN07:in_layers.2.weight,IN07:out_layers.0.bias,IN07:out_layers.0.weight,IN07:out_layers.3.bias,IN07:out_layers.3.weight,IN07:skip_connection.bias,IN07:skip_connection.weight,IN07:norm.bias,IN07:norm.weight,IN07:proj_in.bias,IN07:proj_in.weight,IN07:proj_out.bias,IN07:proj_out.weight,IN07:transformer_blocks.0.attn1.to_k.weight,IN07:transformer_blocks.0.attn1.to_out.0.bias,IN07:transformer_blocks.0.attn1.to_out.0.weight,IN07:transformer_blocks.0.attn1.to_q.weight,IN07:transformer_blocks.0.attn1.to_v.weight,IN07:transformer_blocks.0.attn2.to_k.weight,IN07:transformer_blocks.0.attn2.to_out.0.bias,IN07:transformer_blocks.0.attn2.to_out.0.weight,IN07:transformer_blocks.0.attn2.to_q.weight,IN07:transformer_blocks.0.attn2.to_v.weight,IN07:transformer_blocks.0.ff.net.0.proj.bias,IN07:transformer_blocks.0.ff.net.0.proj.weight,IN07:transformer_blocks.0.ff.net.2.bias,IN07:transformer_blocks.0.ff.net.2.weight,IN07:transformer_blocks.0.norm1.bias,IN07:transformer_blocks.0.norm1.weight,IN07:transformer_blocks.0.norm2.bias,IN07:transformer_blocks.0.norm2.weight,IN07:transformer_blocks.0.norm3.bias,IN07:transformer_blocks.0.norm3.weight
+
+ IN08:emb_layers.1.bias,IN08:emb_layers.1.weight,IN08:in_layers.0.bias,IN08:in_layers.0.weight,IN08:in_layers.2.bias,IN08:in_layers.2.weight,IN08:out_layers.0.bias,IN08:out_layers.0.weight,IN08:out_layers.3.bias,IN08:out_layers.3.weight,IN08:skip_connection.bias,IN08:skip_connection.weight,IN08:norm.bias,IN08:norm.weight,IN08:proj_in.bias,IN08:proj_in.weight,IN08:proj_out.bias,IN08:proj_out.weight,IN08:transformer_blocks.0.attn1.to_k.weight,IN08:transformer_blocks.0.attn1.to_out.0.bias,IN08:transformer_blocks.0.attn1.to_out.0.weight,IN08:transformer_blocks.0.attn1.to_q.weight,IN08:transformer_blocks.0.attn1.to_v.weight,IN08:transformer_blocks.0.attn2.to_k.weight,IN08:transformer_blocks.0.attn2.to_out.0.bias,IN08:transformer_blocks.0.attn2.to_out.0.weight,IN08:transformer_blocks.0.attn2.to_q.weight,IN08:transformer_blocks.0.attn2.to_v.weight,IN08:transformer_blocks.0.ff.net.0.proj.bias,IN08:transformer_blocks.0.ff.net.0.proj.weight,IN08:transformer_blocks.0.ff.net.2.bias,IN08:transformer_blocks.0.ff.net.2.weight,IN08:transformer_blocks.0.norm1.bias,IN08:transformer_blocks.0.norm1.weight,IN08:transformer_blocks.0.norm2.bias,IN08:transformer_blocks.0.norm2.weight,IN08:transformer_blocks.0.norm3.bias,IN08:transformer_blocks.0.norm3.weight
+
+ IN10:emb_layers.1.bias,IN10:emb_layers.1.weight,IN10:in_layers.0.bias,IN10:in_layers.0.weight,IN10:in_layers.2.bias,IN10:in_layers.2.weight,IN10:out_layers.0.bias,IN10:out_layers.0.weight,IN10:out_layers.3.bias,IN10:out_layers.3.weight,IN11:emb_layers.1.bias,IN11:emb_layers.1.weight,IN11:in_layers.0.bias,IN11:in_layers.0.weight,IN11:in_layers.2.bias,IN11:in_layers.2.weight,IN11:out_layers.0.bias,IN11:out_layers.0.weight,IN11:out_layers.3.bias,IN11:out_layers.3.weight
+
+ M00:0.emb_layers.1.bias,M00:0.emb_layers.1.weight,M00:0.in_layers.0.bias,M00:0.in_layers.0.weight,M00:0.in_layers.2.bias,M00:0.in_layers.2.weight,M00:0.out_layers.0.bias,M00:0.out_layers.0.weight,M00:0.out_layers.3.bias,M00:0.out_layers.3.weight,M00:1.norm.bias,M00:1.norm.weight,M00:1.proj_in.bias,M00:1.proj_in.weight,M00:1.proj_out.bias,M00:1.proj_out.weight,M00:1.transformer_blocks.0.attn1.to_k.weight,M00:1.transformer_blocks.0.attn1.to_out.0.bias,M00:1.transformer_blocks.0.attn1.to_out.0.weight,M00:1.transformer_blocks.0.attn1.to_q.weight,M00:1.transformer_blocks.0.attn1.to_v.weight,M00:1.transformer_blocks.0.attn2.to_k.weight,M00:1.transformer_blocks.0.attn2.to_out.0.bias,M00:1.transformer_blocks.0.attn2.to_out.0.weight,M00:1.transformer_blocks.0.attn2.to_q.weight,M00:1.transformer_blocks.0.attn2.to_v.weight,M00:1.transformer_blocks.0.ff.net.0.proj.bias,M00:1.transformer_blocks.0.ff.net.0.proj.weight,M00:1.transformer_blocks.0.ff.net.2.bias,M00:1.transformer_blocks.0.ff.net.2.weight,M00:1.transformer_blocks.0.norm1.bias,M00:1.transformer_blocks.0.norm1.weight,M00:1.transformer_blocks.0.norm2.bias,M00:1.transformer_blocks.0.norm2.weight,M00:1.transformer_blocks.0.norm3.bias,M00:1.transformer_blocks.0.norm3.weight,M00:2.emb_layers.1.bias,M00:2.emb_layers.1.weight,M00:2.in_layers.0.bias,M00:2.in_layers.0.weight,M00:2.in_layers.2.bias,M00:2.in_layers.2.weight,M00:2.out_layers.0.bias,M00:2.out_layers.0.weight,M00:2.out_layers.3.bias,M00:2.out_layers.3.weight
+
+ OUT00:emb_layers.1.bias,OUT00:emb_layers.1.weight,OUT00:in_layers.0.bias,OUT00:in_layers.0.weight,OUT00:in_layers.2.bias,OUT00:in_layers.2.weight,OUT00:out_layers.0.bias,OUT00:out_layers.0.weight,OUT00:out_layers.3.bias,OUT00:out_layers.3.weight,OUT00:skip_connection.bias,OUT00:skip_connection.weight,OUT01:emb_layers.1.bias,OUT01:emb_layers.1.weight,OUT01:in_layers.0.bias,OUT01:in_layers.0.weight,OUT01:in_layers.2.bias,OUT01:in_layers.2.weight,OUT01:out_layers.0.bias,OUT01:out_layers.0.weight,OUT01:out_layers.3.bias,OUT01:out_layers.3.weight,OUT01:skip_connection.bias,OUT01:skip_connection.weight,OUT02:emb_layers.1.bias,OUT02:emb_layers.1.weight,OUT02:in_layers.0.bias,OUT02:in_layers.0.weight,OUT02:in_layers.2.bias,OUT02:in_layers.2.weight,OUT02:out_layers.0.bias,OUT02:out_layers.0.weight,OUT02:out_layers.3.bias,OUT02:out_layers.3.weight,OUT02:skip_connection.bias,OUT02:skip_connection.weight,OUT02:conv.bias,OUT02:conv.weight
+
+ OUT03:emb_layers.1.bias,OUT03:emb_layers.1.weight,OUT03:in_layers.0.bias,OUT03:in_layers.0.weight,OUT03:in_layers.2.bias,OUT03:in_layers.2.weight,OUT03:out_layers.0.bias,OUT03:out_layers.0.weight,OUT03:out_layers.3.bias,OUT03:out_layers.3.weight,OUT03:skip_connection.bias,OUT03:skip_connection.weight,OUT03:norm.bias,OUT03:norm.weight,OUT03:proj_in.bias,OUT03:proj_in.weight,OUT03:proj_out.bias,OUT03:proj_out.weight,OUT03:transformer_blocks.0.attn1.to_k.weight,OUT03:transformer_blocks.0.attn1.to_out.0.bias,OUT03:transformer_blocks.0.attn1.to_out.0.weight,OUT03:transformer_blocks.0.attn1.to_q.weight,OUT03:transformer_blocks.0.attn1.to_v.weight,OUT03:transformer_blocks.0.attn2.to_k.weight,OUT03:transformer_blocks.0.attn2.to_out.0.bias,OUT03:transformer_blocks.0.attn2.to_out.0.weight,OUT03:transformer_blocks.0.attn2.to_q.weight,OUT03:transformer_blocks.0.attn2.to_v.weight,OUT03:transformer_blocks.0.ff.net.0.proj.bias,OUT03:transformer_blocks.0.ff.net.0.proj.weight,OUT03:transformer_blocks.0.ff.net.2.bias,OUT03:transformer_blocks.0.ff.net.2.weight,OUT03:transformer_blocks.0.norm1.bias,OUT03:transformer_blocks.0.norm1.weight,OUT03:transformer_blocks.0.norm2.bias,OUT03:transformer_blocks.0.norm2.weight,OUT03:transformer_blocks.0.norm3.bias,OUT03:transformer_blocks.0.norm3.weight
+
+ OUT04:emb_layers.1.bias,OUT04:emb_layers.1.weight,OUT04:in_layers.0.bias,OUT04:in_layers.0.weight,OUT04:in_layers.2.bias,OUT04:in_layers.2.weight,OUT04:out_layers.0.bias,OUT04:out_layers.0.weight,OUT04:out_layers.3.bias,OUT04:out_layers.3.weight,OUT04:skip_connection.bias,OUT04:skip_connection.weight,OUT04:norm.bias,OUT04:norm.weight,OUT04:proj_in.bias,OUT04:proj_in.weight,OUT04:proj_out.bias,OUT04:proj_out.weight,OUT04:transformer_blocks.0.attn1.to_k.weight,OUT04:transformer_blocks.0.attn1.to_out.0.bias,OUT04:transformer_blocks.0.attn1.to_out.0.weight,OUT04:transformer_blocks.0.attn1.to_q.weight,OUT04:transformer_blocks.0.attn1.to_v.weight,OUT04:transformer_blocks.0.attn2.to_k.weight,OUT04:transformer_blocks.0.attn2.to_out.0.bias,OUT04:transformer_blocks.0.attn2.to_out.0.weight,OUT04:transformer_blocks.0.attn2.to_q.weight,OUT04:transformer_blocks.0.attn2.to_v.weight,OUT04:transformer_blocks.0.ff.net.0.proj.bias,OUT04:transformer_blocks.0.ff.net.0.proj.weight,OUT04:transformer_blocks.0.ff.net.2.bias,OUT04:transformer_blocks.0.ff.net.2.weight,OUT04:transformer_blocks.0.norm1.bias,OUT04:transformer_blocks.0.norm1.weight,OUT04:transformer_blocks.0.norm2.bias,OUT04:transformer_blocks.0.norm2.weight,OUT04:transformer_blocks.0.norm3.bias,OUT04:transformer_blocks.0.norm3.weight
+
+ OUT05:emb_layers.1.bias,OUT05:emb_layers.1.weight,OUT05:in_layers.0.bias,OUT05:in_layers.0.weight,OUT05:in_layers.2.bias,OUT05:in_layers.2.weight,OUT05:out_layers.0.bias,OUT05:out_layers.0.weight,OUT05:out_layers.3.bias,OUT05:out_layers.3.weight,OUT05:skip_connection.bias,OUT05:skip_connection.weight,OUT05:norm.bias,OUT05:norm.weight,OUT05:proj_in.bias,OUT05:proj_in.weight,OUT05:proj_out.bias,OUT05:proj_out.weight,OUT05:transformer_blocks.0.attn1.to_k.weight,OUT05:transformer_blocks.0.attn1.to_out.0.bias,OUT05:transformer_blocks.0.attn1.to_out.0.weight,OUT05:transformer_blocks.0.attn1.to_q.weight,OUT05:transformer_blocks.0.attn1.to_v.weight,OUT05:transformer_blocks.0.attn2.to_k.weight,OUT05:transformer_blocks.0.attn2.to_out.0.bias,OUT05:transformer_blocks.0.attn2.to_out.0.weight,OUT05:transformer_blocks.0.attn2.to_q.weight,OUT05:transformer_blocks.0.attn2.to_v.weight,OUT05:transformer_blocks.0.ff.net.0.proj.bias,OUT05:transformer_blocks.0.ff.net.0.proj.weight,OUT05:transformer_blocks.0.ff.net.2.bias,OUT05:transformer_blocks.0.ff.net.2.weight,OUT05:transformer_blocks.0.norm1.bias,OUT05:transformer_blocks.0.norm1.weight,OUT05:transformer_blocks.0.norm2.bias,OUT05:transformer_blocks.0.norm2.weight,OUT05:transformer_blocks.0.norm3.bias,OUT05:transformer_blocks.0.norm3.weight,OUT05:conv.bias,OUT05:conv.weight
+
+ ,OUT06:emb_layers.1.bias,OUT06:emb_layers.1.weight,OUT06:in_layers.0.bias,OUT06:in_layers.0.weight,OUT06:in_layers.2.bias,OUT06:in_layers.2.weight,OUT06:out_layers.0.bias,OUT06:out_layers.0.weight,OUT06:out_layers.3.bias,OUT06:out_layers.3.weight,OUT06:skip_connection.bias,OUT06:skip_connection.weight,OUT06:norm.bias,OUT06:norm.weight,OUT06:proj_in.bias,OUT06:proj_in.weight,OUT06:proj_out.bias,OUT06:proj_out.weight,OUT06:transformer_blocks.0.attn1.to_k.weight,OUT06:transformer_blocks.0.attn1.to_out.0.bias,OUT06:transformer_blocks.0.attn1.to_out.0.weight,OUT06:transformer_blocks.0.attn1.to_q.weight,OUT06:transformer_blocks.0.attn1.to_v.weight,OUT06:transformer_blocks.0.attn2.to_k.weight,OUT06:transformer_blocks.0.attn2.to_out.0.bias,OUT06:transformer_blocks.0.attn2.to_out.0.weight,OUT06:transformer_blocks.0.attn2.to_q.weight,OUT06:transformer_blocks.0.attn2.to_v.weight,OUT06:transformer_blocks.0.ff.net.0.proj.bias,OUT06:transformer_blocks.0.ff.net.0.proj.weight,OUT06:transformer_blocks.0.ff.net.2.bias,OUT06:transformer_blocks.0.ff.net.2.weight,OUT06:transformer_blocks.0.norm1.bias,OUT06:transformer_blocks.0.norm1.weight,OUT06:transformer_blocks.0.norm2.bias,OUT06:transformer_blocks.0.norm2.weight,OUT06:transformer_blocks.0.norm3.bias,OUT06:transformer_blocks.0.norm3.weight
+
+ OUT07:emb_layers.1.bias,OUT07:emb_layers.1.weight,OUT07:in_layers.0.bias,OUT07:in_layers.0.weight,OUT07:in_layers.2.bias,OUT07:in_layers.2.weight,OUT07:out_layers.0.bias,OUT07:out_layers.0.weight,OUT07:out_layers.3.bias,OUT07:out_layers.3.weight,OUT07:skip_connection.bias,OUT07:skip_connection.weight,OUT07:norm.bias,OUT07:norm.weight,OUT07:proj_in.bias,OUT07:proj_in.weight,OUT07:proj_out.bias,OUT07:proj_out.weight,OUT07:transformer_blocks.0.attn1.to_k.weight,OUT07:transformer_blocks.0.attn1.to_out.0.bias,OUT07:transformer_blocks.0.attn1.to_out.0.weight,OUT07:transformer_blocks.0.attn1.to_q.weight,OUT07:transformer_blocks.0.attn1.to_v.weight,OUT07:transformer_blocks.0.attn2.to_k.weight,OUT07:transformer_blocks.0.attn2.to_out.0.bias,OUT07:transformer_blocks.0.attn2.to_out.0.weight,OUT07:transformer_blocks.0.attn2.to_q.weight,OUT07:transformer_blocks.0.attn2.to_v.weight,OUT07:transformer_blocks.0.ff.net.0.proj.bias,OUT07:transformer_blocks.0.ff.net.0.proj.weight,OUT07:transformer_blocks.0.ff.net.2.bias,OUT07:transformer_blocks.0.ff.net.2.weight,OUT07:transformer_blocks.0.norm1.bias,OUT07:transformer_blocks.0.norm1.weight,OUT07:transformer_blocks.0.norm2.bias,OUT07:transformer_blocks.0.norm2.weight,OUT07:transformer_blocks.0.norm3.bias,OUT07:transformer_blocks.0.norm3.weight
+
+ ,OUT08:emb_layers.1.bias,OUT08:emb_layers.1.weight,OUT08:in_layers.0.bias,OUT08:in_layers.0.weight,OUT08:in_layers.2.bias,OUT08:in_layers.2.weight,OUT08:out_layers.0.bias,OUT08:out_layers.0.weight,OUT08:out_layers.3.bias,OUT08:out_layers.3.weight,OUT08:skip_connection.bias,OUT08:skip_connection.weight,OUT08:norm.bias,OUT08:norm.weight,OUT08:proj_in.bias,OUT08:proj_in.weight,OUT08:proj_out.bias,OUT08:proj_out.weight,OUT08:transformer_blocks.0.attn1.to_k.weight,OUT08:transformer_blocks.0.attn1.to_out.0.bias,OUT08:transformer_blocks.0.attn1.to_out.0.weight,OUT08:transformer_blocks.0.attn1.to_q.weight,OUT08:transformer_blocks.0.attn1.to_v.weight,OUT08:transformer_blocks.0.attn2.to_k.weight,OUT08:transformer_blocks.0.attn2.to_out.0.bias,OUT08:transformer_blocks.0.attn2.to_out.0.weight,OUT08:transformer_blocks.0.attn2.to_q.weight,OUT08:transformer_blocks.0.attn2.to_v.weight,OUT08:transformer_blocks.0.ff.net.0.proj.bias,OUT08:transformer_blocks.0.ff.net.0.proj.weight,OUT08:transformer_blocks.0.ff.net.2.bias,OUT08:transformer_blocks.0.ff.net.2.weight,OUT08:transformer_blocks.0.norm1.bias,OUT08:transformer_blocks.0.norm1.weight,OUT08:transformer_blocks.0.norm2.bias,OUT08:transformer_blocks.0.norm2.weight,OUT08:transformer_blocks.0.norm3.bias,OUT08:transformer_blocks.0.norm3.weight,OUT08:conv.bias,OUT08:conv.weight
+
+ OUT09:emb_layers.1.bias,OUT09:emb_layers.1.weight,OUT09:in_layers.0.bias,OUT09:in_layers.0.weight,OUT09:in_layers.2.bias,OUT09:in_layers.2.weight,OUT09:out_layers.0.bias,OUT09:out_layers.0.weight,OUT09:out_layers.3.bias,OUT09:out_layers.3.weight,OUT09:skip_connection.bias,OUT09:skip_connection.weight,OUT09:norm.bias,OUT09:norm.weight,OUT09:proj_in.bias,OUT09:proj_in.weight,OUT09:proj_out.bias,OUT09:proj_out.weight,OUT09:transformer_blocks.0.attn1.to_k.weight,OUT09:transformer_blocks.0.attn1.to_out.0.bias,OUT09:transformer_blocks.0.attn1.to_out.0.weight,OUT09:transformer_blocks.0.attn1.to_q.weight,OUT09:transformer_blocks.0.attn1.to_v.weight,OUT09:transformer_blocks.0.attn2.to_k.weight,OUT09:transformer_blocks.0.attn2.to_out.0.bias,OUT09:transformer_blocks.0.attn2.to_out.0.weight,OUT09:transformer_blocks.0.attn2.to_q.weight,OUT09:transformer_blocks.0.attn2.to_v.weight,OUT09:transformer_blocks.0.ff.net.0.proj.bias,OUT09:transformer_blocks.0.ff.net.0.proj.weight,OUT09:transformer_blocks.0.ff.net.2.bias,OUT09:transformer_blocks.0.ff.net.2.weight,OUT09:transformer_blocks.0.norm1.bias,OUT09:transformer_blocks.0.norm1.weight,OUT09:transformer_blocks.0.norm2.bias,OUT09:transformer_blocks.0.norm2.weight,OUT09:transformer_blocks.0.norm3.bias,OUT09:transformer_blocks.0.norm3.weight
+
+ OUT10:emb_layers.1.bias,OUT10:emb_layers.1.weight,OUT10:in_layers.0.bias,OUT10:in_layers.0.weight,OUT10:in_layers.2.bias,OUT10:in_layers.2.weight,OUT10:out_layers.0.bias,OUT10:out_layers.0.weight,OUT10:out_layers.3.bias,OUT10:out_layers.3.weight,OUT10:skip_connection.bias,OUT10:skip_connection.weight,OUT10:norm.bias,OUT10:norm.weight,OUT10:proj_in.bias,OUT10:proj_in.weight,OUT10:proj_out.bias,OUT10:proj_out.weight,OUT10:transformer_blocks.0.attn1.to_k.weight,OUT10:transformer_blocks.0.attn1.to_out.0.bias,OUT10:transformer_blocks.0.attn1.to_out.0.weight,OUT10:transformer_blocks.0.attn1.to_q.weight,OUT10:transformer_blocks.0.attn1.to_v.weight,OUT10:transformer_blocks.0.attn2.to_k.weight,OUT10:transformer_blocks.0.attn2.to_out.0.bias,OUT10:transformer_blocks.0.attn2.to_out.0.weight,OUT10:transformer_blocks.0.attn2.to_q.weight,OUT10:transformer_blocks.0.attn2.to_v.weight,OUT10:transformer_blocks.0.ff.net.0.proj.bias,OUT10:transformer_blocks.0.ff.net.0.proj.weight,OUT10:transformer_blocks.0.ff.net.2.bias,OUT10:transformer_blocks.0.ff.net.2.weight,OUT10:transformer_blocks.0.norm1.bias,OUT10:transformer_blocks.0.norm1.weight,OUT10:transformer_blocks.0.norm2.bias,OUT10:transformer_blocks.0.norm2.weight,OUT10:transformer_blocks.0.norm3.bias,OUT10:transformer_blocks.0.norm3.weight
+
+ OUT11:emb_layers.1.bias,OUT11:emb_layers.1.weight,OUT11:in_layers.0.bias,OUT11:in_layers.0.weight,OUT11:in_layers.2.bias,OUT11:in_layers.2.weight,OUT11:out_layers.0.bias,OUT11:out_layers.0.weight,OUT11:out_layers.3.bias,OUT11:out_layers.3.weight,OUT11:skip_connection.bias,OUT11:skip_connection.weight,OUT11:norm.bias,OUT11:norm.weight,OUT11:proj_in.bias,OUT11:proj_in.weight,OUT11:proj_out.bias,OUT11:proj_out.weight,OUT11:transformer_blocks.0.attn1.to_k.weight,OUT11:transformer_blocks.0.attn1.to_out.0.bias,OUT11:transformer_blocks.0.attn1.to_out.0.weight,OUT11:transformer_blocks.0.attn1.to_q.weight,OUT11:transformer_blocks.0.attn1.to_v.weight,OUT11:transformer_blocks.0.attn2.to_k.weight,OUT11:transformer_blocks.0.attn2.to_out.0.bias,OUT11:transformer_blocks.0.attn2.to_out.0.weight,OUT11:transformer_blocks.0.attn2.to_q.weight,OUT11:transformer_blocks.0.attn2.to_v.weight,OUT11:transformer_blocks.0.ff.net.0.proj.bias,OUT11:transformer_blocks.0.ff.net.0.proj.weight,OUT11:transformer_blocks.0.ff.net.2.bias,OUT11:transformer_blocks.0.ff.net.2.weight,OUT11:transformer_blocks.0.norm1.bias,OUT11:transformer_blocks.0.norm1.weight,OUT11:transformer_blocks.0.norm2.bias,OUT11:transformer_blocks.0.norm2.weight,OUT11:transformer_blocks.0.norm3.bias,OUT11:transformer_blocks.0.norm3.weight,OUT11:0.bias,OUT11:0.weight,OUT11:2.bias,OUT11:2.weight
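The lines above enumerate the per-block element keys, in `BLOCKID:element` form, that the elemental merge can target. As a quick illustration of that key format (the sample string and variable names below are illustrative, not taken from the extension's code):

```python
# Group "BLOCK:element" entries, as listed above, by block id.
# The entries string is a tiny hand-picked excerpt for illustration only.
entries = "OUT07:norm.bias,OUT07:norm.weight,OUT08:conv.bias"

by_block = {}
for item in entries.split(","):
    block, element = item.split(":", 1)  # split only on the first colon
    by_block.setdefault(block, []).append(element)
```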
microsoftexcel-supermerger/scripts/__pycache__/supermerger.cpython-310.pyc ADDED
Binary file (22.2 kB).
 
microsoftexcel-supermerger/scripts/mbwpresets.txt ADDED
@@ -0,0 +1,39 @@
+ preset_name preset_weights
+ GRAD_V 0,1,0.9166666667,0.8333333333,0.75,0.6666666667,0.5833333333,0.5,0.4166666667,0.3333333333,0.25,0.1666666667,0.0833333333,0,0.0833333333,0.1666666667,0.25,0.3333333333,0.4166666667,0.5,0.5833333333,0.6666666667,0.75,0.8333333333,0.9166666667,1.0
+ GRAD_A 0,0,0.0833333333,0.1666666667,0.25,0.3333333333,0.4166666667,0.5,0.5833333333,0.6666666667,0.75,0.8333333333,0.9166666667,1.0,0.9166666667,0.8333333333,0.75,0.6666666667,0.5833333333,0.5,0.4166666667,0.3333333333,0.25,0.1666666667,0.0833333333,0
+ FLAT_25 0,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25
+ FLAT_75 0,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75
+ WRAP08 0,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1
+ WRAP12 0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1
+ WRAP14 0,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1
+ WRAP16 0,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1
+ MID12_50 0,0,0,0,0,0,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0,0,0,0,0,0
+ OUT07 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1
+ OUT12 0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1
+ OUT12_5 0,0,0,0,0,0,0,0,0,0,0,0,0,0.5,1,1,1,1,1,1,1,1,1,1,1,1
+ RING08_SOFT 0,0,0,0,0,0,0.5,1,1,1,0.5,0,0,0,0,0,0.5,1,1,1,0.5,0,0,0,0,0
+ RING08_5 0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0
+ RING10_5 0,0,0,0,0,0,1,1,1,1,1,0,0,0,0,0,1,1,1,1,1,0,0,0,0,0
+ RING10_3 0,0,0,0,0,0,0,1,1,1,1,1,0,0,0,1,1,1,1,1,0,0,0,0,0,0
+ SMOOTHSTEP 0,0,0.00506365740740741,0.0196759259259259,0.04296875,0.0740740740740741,0.112123842592593,0.15625,0.205584490740741,0.259259259259259,0.31640625,0.376157407407407,0.437644675925926,0.5,0.562355324074074,0.623842592592592,0.68359375,0.740740740740741,0.794415509259259,0.84375,0.887876157407408,0.925925925925926,0.95703125,0.980324074074074,0.994936342592593,1
+ REVERSE-SMOOTHSTEP 0,1,0.994936342592593,0.980324074074074,0.95703125,0.925925925925926,0.887876157407407,0.84375,0.794415509259259,0.740740740740741,0.68359375,0.623842592592593,0.562355324074074,0.5,0.437644675925926,0.376157407407408,0.31640625,0.259259259259259,0.205584490740741,0.15625,0.112123842592592,0.0740740740740742,0.0429687499999996,0.0196759259259258,0.00506365740740744,0
+ SMOOTHSTEP*2 0,0,0.0101273148148148,0.0393518518518519,0.0859375,0.148148148148148,0.224247685185185,0.3125,0.411168981481482,0.518518518518519,0.6328125,0.752314814814815,0.875289351851852,1.,0.875289351851852,0.752314814814815,0.6328125,0.518518518518519,0.411168981481481,0.3125,0.224247685185184,0.148148148148148,0.0859375,0.0393518518518512,0.0101273148148153,0
+ R_SMOOTHSTEP*2 0,1,0.989872685185185,0.960648148148148,0.9140625,0.851851851851852,0.775752314814815,0.6875,0.588831018518519,0.481481481481481,0.3671875,0.247685185185185,0.124710648148148,0.,0.124710648148148,0.247685185185185,0.3671875,0.481481481481481,0.588831018518519,0.6875,0.775752314814816,0.851851851851852,0.9140625,0.960648148148149,0.989872685185185,1
+ SMOOTHSTEP*3 0,0,0.0151909722222222,0.0590277777777778,0.12890625,0.222222222222222,0.336371527777778,0.46875,0.616753472222222,0.777777777777778,0.94921875,0.871527777777778,0.687065972222222,0.5,0.312934027777778,0.128472222222222,0.0507812500000004,0.222222222222222,0.383246527777778,0.53125,0.663628472222223,0.777777777777778,0.87109375,0.940972222222222,0.984809027777777,1
+ R_SMOOTHSTEP*3 0,1,0.984809027777778,0.940972222222222,0.87109375,0.777777777777778,0.663628472222222,0.53125,0.383246527777778,0.222222222222222,0.05078125,0.128472222222222,0.312934027777778,0.5,0.687065972222222,0.871527777777778,0.94921875,0.777777777777778,0.616753472222222,0.46875,0.336371527777777,0.222222222222222,0.12890625,0.0590277777777777,0.0151909722222232,0
+ SMOOTHSTEP*4 0,0,0.0202546296296296,0.0787037037037037,0.171875,0.296296296296296,0.44849537037037,0.625,0.822337962962963,0.962962962962963,0.734375,0.49537037037037,0.249421296296296,0.,0.249421296296296,0.495370370370371,0.734375000000001,0.962962962962963,0.822337962962962,0.625,0.448495370370369,0.296296296296297,0.171875,0.0787037037037024,0.0202546296296307,0
+ R_SMOOTHSTEP*4 0,1,0.97974537037037,0.921296296296296,0.828125,0.703703703703704,0.55150462962963,0.375,0.177662037037037,0.0370370370370372,0.265625,0.50462962962963,0.750578703703704,1.,0.750578703703704,0.504629629629629,0.265624999999999,0.0370370370370372,0.177662037037038,0.375,0.551504629629631,0.703703703703703,0.828125,0.921296296296298,0.979745370370369,1
+ SMOOTHSTEP/2 0,0,0.0196759259259259,0.0740740740740741,0.15625,0.259259259259259,0.376157407407407,0.5,0.623842592592593,0.740740740740741,0.84375,0.925925925925926,0.980324074074074,1.,0.980324074074074,0.925925925925926,0.84375,0.740740740740741,0.623842592592593,0.5,0.376157407407407,0.259259259259259,0.15625,0.0740740740740741,0.0196759259259259,0
+ R_SMOOTHSTEP/2 0,1,0.980324074074074,0.925925925925926,0.84375,0.740740740740741,0.623842592592593,0.5,0.376157407407407,0.259259259259259,0.15625,0.0740740740740742,0.0196759259259256,0.,0.0196759259259256,0.0740740740740742,0.15625,0.259259259259259,0.376157407407407,0.5,0.623842592592593,0.740740740740741,0.84375,0.925925925925926,0.980324074074074,1
+ SMOOTHSTEP/3 0,0,0.04296875,0.15625,0.31640625,0.5,0.68359375,0.84375,0.95703125,1.,0.95703125,0.84375,0.68359375,0.5,0.31640625,0.15625,0.04296875,0.,0.04296875,0.15625,0.31640625,0.5,0.68359375,0.84375,0.95703125,1
+ R_SMOOTHSTEP/3 0,1,0.95703125,0.84375,0.68359375,0.5,0.31640625,0.15625,0.04296875,0.,0.04296875,0.15625,0.31640625,0.5,0.68359375,0.84375,0.95703125,1.,0.95703125,0.84375,0.68359375,0.5,0.31640625,0.15625,0.04296875,0
+ SMOOTHSTEP/4 0,0,0.0740740740740741,0.259259259259259,0.5,0.740740740740741,0.925925925925926,1.,0.925925925925926,0.740740740740741,0.5,0.259259259259259,0.0740740740740741,0.,0.0740740740740741,0.259259259259259,0.5,0.740740740740741,0.925925925925926,1.,0.925925925925926,0.740740740740741,0.5,0.259259259259259,0.0740740740740741,0
+ R_SMOOTHSTEP/4 0,1,0.925925925925926,0.740740740740741,0.5,0.259259259259259,0.0740740740740742,0.,0.0740740740740742,0.259259259259259,0.5,0.740740740740741,0.925925925925926,1.,0.925925925925926,0.740740740740741,0.5,0.259259259259259,0.0740740740740742,0.,0.0740740740740742,0.259259259259259,0.5,0.740740740740741,0.925925925925926,1
+ COSINE 0,1,0.995722430686905,0.982962913144534,0.961939766255643,0.933012701892219,0.896676670145617,0.853553390593274,0.80438071450436,0.75,0.691341716182545,0.62940952255126,0.565263096110026,0.5,0.434736903889974,0.37059047744874,0.308658283817455,0.25,0.195619285495639,0.146446609406726,0.103323329854382,0.0669872981077805,0.0380602337443566,0.0170370868554658,0.00427756931309475,0
+ REVERSE_COSINE 0,0,0.00427756931309475,0.0170370868554659,0.0380602337443566,0.0669872981077808,0.103323329854383,0.146446609406726,0.19561928549564,0.25,0.308658283817455,0.37059047744874,0.434736903889974,0.5,0.565263096110026,0.62940952255126,0.691341716182545,0.75,0.804380714504361,0.853553390593274,0.896676670145618,0.933012701892219,0.961939766255643,0.982962913144534,0.995722430686905,1
+ TRUE_CUBIC_HERMITE 0,0,0.199031876929012,0.325761959876543,0.424641927083333,0.498456790123457,0.549991560570988,0.58203125,0.597360869984568,0.598765432098765,0.589029947916667,0.570939429012346,0.547278886959876,0.520833333333333,0.49438777970679,0.470727237654321,0.45263671875,0.442901234567901,0.444305796682099,0.459635416666667,0.491675106095678,0.543209876543211,0.617024739583333,0.715904706790124,0.842634789737655,1
+ TRUE_REVERSE_CUBIC_HERMITE 0,1,0.800968123070988,0.674238040123457,0.575358072916667,0.501543209876543,0.450008439429012,0.41796875,0.402639130015432,0.401234567901235,0.410970052083333,0.429060570987654,0.452721113040124,0.479166666666667,0.50561222029321,0.529272762345679,0.54736328125,0.557098765432099,0.555694203317901,0.540364583333333,0.508324893904322,0.456790123456789,0.382975260416667,0.284095293209876,0.157365210262345,0
+ FAKE_CUBIC_HERMITE 0,0,0.157576195987654,0.28491512345679,0.384765625,0.459876543209877,0.512996720679012,0.546875,0.564260223765432,0.567901234567901,0.560546875,0.544945987654321,0.523847415123457,0.5,0.476152584876543,0.455054012345679,0.439453125,0.432098765432099,0.435739776234568,0.453125,0.487003279320987,0.540123456790124,0.615234375,0.71508487654321,0.842423804012347,1
+ FAKE_REVERSE_CUBIC_HERMITE 0,1,0.842423804012346,0.71508487654321,0.615234375,0.540123456790123,0.487003279320988,0.453125,0.435739776234568,0.432098765432099,0.439453125,0.455054012345679,0.476152584876543,0.5,0.523847415123457,0.544945987654321,0.560546875,0.567901234567901,0.564260223765432,0.546875,0.512996720679013,0.459876543209876,0.384765625,0.28491512345679,0.157576195987653,0
+ ALL_A 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
+ ALL_B 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1
microsoftexcel-supermerger/scripts/mbwpresets_master.txt ADDED
@@ -0,0 +1,39 @@
+ preset_name preset_weights
+ GRAD_V 0,1,0.9166666667,0.8333333333,0.75,0.6666666667,0.5833333333,0.5,0.4166666667,0.3333333333,0.25,0.1666666667,0.0833333333,0,0.0833333333,0.1666666667,0.25,0.3333333333,0.4166666667,0.5,0.5833333333,0.6666666667,0.75,0.8333333333,0.9166666667,1.0
+ GRAD_A 0,0,0.0833333333,0.1666666667,0.25,0.3333333333,0.4166666667,0.5,0.5833333333,0.6666666667,0.75,0.8333333333,0.9166666667,1.0,0.9166666667,0.8333333333,0.75,0.6666666667,0.5833333333,0.5,0.4166666667,0.3333333333,0.25,0.1666666667,0.0833333333,0
+ FLAT_25 0,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25
+ FLAT_75 0,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75
+ WRAP08 0,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1
+ WRAP12 0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1
+ WRAP14 0,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1
+ WRAP16 0,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1
+ MID12_50 0,0,0,0,0,0,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0,0,0,0,0,0
+ OUT07 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1
+ OUT12 0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1
+ OUT12_5 0,0,0,0,0,0,0,0,0,0,0,0,0,0.5,1,1,1,1,1,1,1,1,1,1,1,1
+ RING08_SOFT 0,0,0,0,0,0,0.5,1,1,1,0.5,0,0,0,0,0,0.5,1,1,1,0.5,0,0,0,0,0
+ RING08_5 0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0
+ RING10_5 0,0,0,0,0,0,1,1,1,1,1,0,0,0,0,0,1,1,1,1,1,0,0,0,0,0
+ RING10_3 0,0,0,0,0,0,0,1,1,1,1,1,0,0,0,1,1,1,1,1,0,0,0,0,0,0
+ SMOOTHSTEP 0,0,0.00506365740740741,0.0196759259259259,0.04296875,0.0740740740740741,0.112123842592593,0.15625,0.205584490740741,0.259259259259259,0.31640625,0.376157407407407,0.437644675925926,0.5,0.562355324074074,0.623842592592592,0.68359375,0.740740740740741,0.794415509259259,0.84375,0.887876157407408,0.925925925925926,0.95703125,0.980324074074074,0.994936342592593,1
+ REVERSE-SMOOTHSTEP 0,1,0.994936342592593,0.980324074074074,0.95703125,0.925925925925926,0.887876157407407,0.84375,0.794415509259259,0.740740740740741,0.68359375,0.623842592592593,0.562355324074074,0.5,0.437644675925926,0.376157407407408,0.31640625,0.259259259259259,0.205584490740741,0.15625,0.112123842592592,0.0740740740740742,0.0429687499999996,0.0196759259259258,0.00506365740740744,0
+ SMOOTHSTEP*2 0,0,0.0101273148148148,0.0393518518518519,0.0859375,0.148148148148148,0.224247685185185,0.3125,0.411168981481482,0.518518518518519,0.6328125,0.752314814814815,0.875289351851852,1.,0.875289351851852,0.752314814814815,0.6328125,0.518518518518519,0.411168981481481,0.3125,0.224247685185184,0.148148148148148,0.0859375,0.0393518518518512,0.0101273148148153,0
+ R_SMOOTHSTEP*2 0,1,0.989872685185185,0.960648148148148,0.9140625,0.851851851851852,0.775752314814815,0.6875,0.588831018518519,0.481481481481481,0.3671875,0.247685185185185,0.124710648148148,0.,0.124710648148148,0.247685185185185,0.3671875,0.481481481481481,0.588831018518519,0.6875,0.775752314814816,0.851851851851852,0.9140625,0.960648148148149,0.989872685185185,1
+ SMOOTHSTEP*3 0,0,0.0151909722222222,0.0590277777777778,0.12890625,0.222222222222222,0.336371527777778,0.46875,0.616753472222222,0.777777777777778,0.94921875,0.871527777777778,0.687065972222222,0.5,0.312934027777778,0.128472222222222,0.0507812500000004,0.222222222222222,0.383246527777778,0.53125,0.663628472222223,0.777777777777778,0.87109375,0.940972222222222,0.984809027777777,1
+ R_SMOOTHSTEP*3 0,1,0.984809027777778,0.940972222222222,0.87109375,0.777777777777778,0.663628472222222,0.53125,0.383246527777778,0.222222222222222,0.05078125,0.128472222222222,0.312934027777778,0.5,0.687065972222222,0.871527777777778,0.94921875,0.777777777777778,0.616753472222222,0.46875,0.336371527777777,0.222222222222222,0.12890625,0.0590277777777777,0.0151909722222232,0
+ SMOOTHSTEP*4 0,0,0.0202546296296296,0.0787037037037037,0.171875,0.296296296296296,0.44849537037037,0.625,0.822337962962963,0.962962962962963,0.734375,0.49537037037037,0.249421296296296,0.,0.249421296296296,0.495370370370371,0.734375000000001,0.962962962962963,0.822337962962962,0.625,0.448495370370369,0.296296296296297,0.171875,0.0787037037037024,0.0202546296296307,0
+ R_SMOOTHSTEP*4 0,1,0.97974537037037,0.921296296296296,0.828125,0.703703703703704,0.55150462962963,0.375,0.177662037037037,0.0370370370370372,0.265625,0.50462962962963,0.750578703703704,1.,0.750578703703704,0.504629629629629,0.265624999999999,0.0370370370370372,0.177662037037038,0.375,0.551504629629631,0.703703703703703,0.828125,0.921296296296298,0.979745370370369,1
+ SMOOTHSTEP/2 0,0,0.0196759259259259,0.0740740740740741,0.15625,0.259259259259259,0.376157407407407,0.5,0.623842592592593,0.740740740740741,0.84375,0.925925925925926,0.980324074074074,1.,0.980324074074074,0.925925925925926,0.84375,0.740740740740741,0.623842592592593,0.5,0.376157407407407,0.259259259259259,0.15625,0.0740740740740741,0.0196759259259259,0
+ R_SMOOTHSTEP/2 0,1,0.980324074074074,0.925925925925926,0.84375,0.740740740740741,0.623842592592593,0.5,0.376157407407407,0.259259259259259,0.15625,0.0740740740740742,0.0196759259259256,0.,0.0196759259259256,0.0740740740740742,0.15625,0.259259259259259,0.376157407407407,0.5,0.623842592592593,0.740740740740741,0.84375,0.925925925925926,0.980324074074074,1
+ SMOOTHSTEP/3 0,0,0.04296875,0.15625,0.31640625,0.5,0.68359375,0.84375,0.95703125,1.,0.95703125,0.84375,0.68359375,0.5,0.31640625,0.15625,0.04296875,0.,0.04296875,0.15625,0.31640625,0.5,0.68359375,0.84375,0.95703125,1
+ R_SMOOTHSTEP/3 0,1,0.95703125,0.84375,0.68359375,0.5,0.31640625,0.15625,0.04296875,0.,0.04296875,0.15625,0.31640625,0.5,0.68359375,0.84375,0.95703125,1.,0.95703125,0.84375,0.68359375,0.5,0.31640625,0.15625,0.04296875,0
+ SMOOTHSTEP/4 0,0,0.0740740740740741,0.259259259259259,0.5,0.740740740740741,0.925925925925926,1.,0.925925925925926,0.740740740740741,0.5,0.259259259259259,0.0740740740740741,0.,0.0740740740740741,0.259259259259259,0.5,0.740740740740741,0.925925925925926,1.,0.925925925925926,0.740740740740741,0.5,0.259259259259259,0.0740740740740741,0
+ R_SMOOTHSTEP/4 0,1,0.925925925925926,0.740740740740741,0.5,0.259259259259259,0.0740740740740742,0.,0.0740740740740742,0.259259259259259,0.5,0.740740740740741,0.925925925925926,1.,0.925925925925926,0.740740740740741,0.5,0.259259259259259,0.0740740740740742,0.,0.0740740740740742,0.259259259259259,0.5,0.740740740740741,0.925925925925926,1
+ COSINE 0,1,0.995722430686905,0.982962913144534,0.961939766255643,0.933012701892219,0.896676670145617,0.853553390593274,0.80438071450436,0.75,0.691341716182545,0.62940952255126,0.565263096110026,0.5,0.434736903889974,0.37059047744874,0.308658283817455,0.25,0.195619285495639,0.146446609406726,0.103323329854382,0.0669872981077805,0.0380602337443566,0.0170370868554658,0.00427756931309475,0
+ REVERSE_COSINE 0,0,0.00427756931309475,0.0170370868554659,0.0380602337443566,0.0669872981077808,0.103323329854383,0.146446609406726,0.19561928549564,0.25,0.308658283817455,0.37059047744874,0.434736903889974,0.5,0.565263096110026,0.62940952255126,0.691341716182545,0.75,0.804380714504361,0.853553390593274,0.896676670145618,0.933012701892219,0.961939766255643,0.982962913144534,0.995722430686905,1
+ TRUE_CUBIC_HERMITE 0,0,0.199031876929012,0.325761959876543,0.424641927083333,0.498456790123457,0.549991560570988,0.58203125,0.597360869984568,0.598765432098765,0.589029947916667,0.570939429012346,0.547278886959876,0.520833333333333,0.49438777970679,0.470727237654321,0.45263671875,0.442901234567901,0.444305796682099,0.459635416666667,0.491675106095678,0.543209876543211,0.617024739583333,0.715904706790124,0.842634789737655,1
+ TRUE_REVERSE_CUBIC_HERMITE 0,1,0.800968123070988,0.674238040123457,0.575358072916667,0.501543209876543,0.450008439429012,0.41796875,0.402639130015432,0.401234567901235,0.410970052083333,0.429060570987654,0.452721113040124,0.479166666666667,0.50561222029321,0.529272762345679,0.54736328125,0.557098765432099,0.555694203317901,0.540364583333333,0.508324893904322,0.456790123456789,0.382975260416667,0.284095293209876,0.157365210262345,0
+ FAKE_CUBIC_HERMITE 0,0,0.157576195987654,0.28491512345679,0.384765625,0.459876543209877,0.512996720679012,0.546875,0.564260223765432,0.567901234567901,0.560546875,0.544945987654321,0.523847415123457,0.5,0.476152584876543,0.455054012345679,0.439453125,0.432098765432099,0.435739776234568,0.453125,0.487003279320987,0.540123456790124,0.615234375,0.71508487654321,0.842423804012347,1
+ FAKE_REVERSE_CUBIC_HERMITE 0,1,0.842423804012346,0.71508487654321,0.615234375,0.540123456790123,0.487003279320988,0.453125,0.435739776234568,0.432098765432099,0.439453125,0.455054012345679,0.476152584876543,0.5,0.523847415123457,0.544945987654321,0.560546875,0.567901234567901,0.564260223765432,0.546875,0.512996720679013,0.459876543209876,0.384765625,0.28491512345679,0.157576195987653,0
+ ALL_A 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
+ ALL_B 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1
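Each non-header row above is a preset name followed by 26 comma-separated weights: one BASE value plus weights for the 12 IN blocks, M00, and the 12 OUT blocks. A minimal sketch of parsing such a row (`parse_preset` is a hypothetical helper, not part of the extension; the rows appear to be whitespace-separated):

```python
# Parse one preset row ("name<whitespace>w0,w1,...,w25") into (name, weights).
# parse_preset is a hypothetical helper; the row format mirrors the table above.
def parse_preset(line: str):
    name, raw = line.split(None, 1)  # split on the first run of whitespace
    weights = [float(w) for w in raw.split(",")]
    if len(weights) != 26:
        raise ValueError("expected BASE + 25 block weights (26 values)")
    return name, weights

# Example: the FLAT_25 row is BASE=0 followed by 25 block weights of 0.25.
name, weights = parse_preset("FLAT_25 " + ",".join(["0"] + ["0.25"] * 25))
```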
microsoftexcel-supermerger/scripts/mergers/__pycache__/mergers.cpython-310.pyc ADDED
Binary file (19.9 kB).
 
microsoftexcel-supermerger/scripts/mergers/__pycache__/model_util.cpython-310.pyc ADDED
Binary file (24.8 kB).
 
microsoftexcel-supermerger/scripts/mergers/__pycache__/pluslora.cpython-310.pyc ADDED
Binary file (34.5 kB).
 
microsoftexcel-supermerger/scripts/mergers/__pycache__/xyplot.cpython-310.pyc ADDED
Binary file (15.2 kB).
 
microsoftexcel-supermerger/scripts/mergers/mergers.py ADDED
@@ -0,0 +1,699 @@
+ from linecache import clearcache
+
+ import os
+ import gc
+ import numpy as np
+ import os.path
+ import re
+ import torch
+ import tqdm
+ import datetime
+ import csv
+ import json
+ import torch.nn as nn
+ import scipy.ndimage
+ from scipy.ndimage.filters import median_filter as filter
+ from PIL import Image, ImageFont, ImageDraw
+ from tqdm import tqdm
+ from modules import shared, processing, sd_models, images, sd_samplers,scripts
+ from modules.ui import plaintext_to_html
+ from modules.shared import opts
+ from modules.processing import create_infotext,Processed
+ from modules.sd_models import load_model,checkpoints_loaded
+ from scripts.mergers.model_util import usemodelgen,filenamecutter,savemodel
+
+ from inspect import currentframe
+
+ stopmerge = False
+
+ def freezemtime():
+ global stopmerge
+ stopmerge = True
+
+ mergedmodel=[]
+ TYPESEG = ["none","alpha","beta (if Triple or Twice is not selected, Twice is enabled automatically)","alpha and beta","seed", "mbw alpha","mbw beta","mbw alpha and beta", "model_A","model_B","model_C","pinpoint blocks (alpha or beta must be selected for another axis)","elemental","pinpoint element","effective elemental checker","tensors","calcmode","prompt"]
+ TYPES = ["none","alpha","beta","alpha and beta","seed", "mbw alpha ","mbw beta","mbw alpha and beta", "model_A","model_B","model_C","pinpoint blocks","elemental","pinpoint element","effective","tensor","calcmode","prompt"]
+ MODES=["Weight" ,"Add" ,"Triple","Twice"]
+ SAVEMODES=["save model", "overwrite"]
+ #type[0:alpha,1:beta,2:seed,3:mbw,4:model_A,5:model_B,6:model_C]
+ #msettings=[0 weights_a,1 weights_b,2 model_a,3 model_b,4 model_c,5 base_alpha,6 base_beta,7 mode,8 useblocks,9 custom_name,10 save_sets,11 id_sets,12 wpresets]
+ #id sets "image", "PNG info","XY grid"
+
+ hear = False
+ hearm = False
+ non4 = [None]*4
+
+ def caster(news,hear):
+ if hear: print(news)
+
+ def casterr(*args,hear=hear):
+ if hear:
+ names = {id(v): k for k, v in currentframe().f_back.f_locals.items()}
+ print('\n'.join([names.get(id(arg), '???') + ' = ' + repr(arg) for arg in args]))
+
+ #msettings=[weights_a,weights_b,model_a,model_b,model_c,device,base_alpha,base_beta,mode,loranames,useblocks,custom_name,save_sets,id_sets,wpresets,deep]
+ def smergegen(weights_a,weights_b,model_a,model_b,model_c,base_alpha,base_beta,mode,
+ calcmode,useblocks,custom_name,save_sets,id_sets,wpresets,deep,tensor,
+ esettings,
+ prompt,nprompt,steps,sampler,cfg,seed,w,h,
+ hireson,hrupscaler,hr2ndsteps,denoise_str,hr_scale,batch_size,
+ currentmodel,imggen):
+
+ deepprint = True if "print change" in esettings else False
+
+ result,currentmodel,modelid,theta_0,metadata = smerge(
+ weights_a,weights_b,model_a,model_b,model_c,base_alpha,base_beta,mode,calcmode,
+ useblocks,custom_name,save_sets,id_sets,wpresets,deep,tensor,deepprint=deepprint
+ )
+
+ if "ERROR" in result or "STOPPED" in result:
+ return result,"not loaded",*non4
+
+ usemodelgen(theta_0,model_a,currentmodel)
+
+ save = True if SAVEMODES[0] in save_sets else False
+
+ result = savemodel(theta_0,currentmodel,custom_name,save_sets,model_a,metadata) if save else "Merged model loaded:"+currentmodel
+ del theta_0
+ gc.collect()
+
+ if imggen :
+ images = simggen(prompt,nprompt,steps,sampler,cfg,seed,w,h,hireson,hrupscaler,hr2ndsteps,denoise_str,hr_scale,batch_size,currentmodel,id_sets,modelid)
+ return result,currentmodel,*images[:4]
+ else:
+ return result,currentmodel
+
+ NUM_INPUT_BLOCKS = 12
+ NUM_MID_BLOCK = 1
+ NUM_OUTPUT_BLOCKS = 12
+ NUM_TOTAL_BLOCKS = NUM_INPUT_BLOCKS + NUM_MID_BLOCK + NUM_OUTPUT_BLOCKS
+ blockid=["BASE","IN00","IN01","IN02","IN03","IN04","IN05","IN06","IN07","IN08","IN09","IN10","IN11","M00","OUT00","OUT01","OUT02","OUT03","OUT04","OUT05","OUT06","OUT07","OUT08","OUT09","OUT10","OUT11"]
+
+ def smerge(weights_a,weights_b,model_a,model_b,model_c,base_alpha,base_beta,mode,calcmode,
+ useblocks,custom_name,save_sets,id_sets,wpresets,deep,tensor,deepprint = False):
+ caster("merge start",hearm)
+ global hear,mergedmodel,stopmerge
+ stopmerge = False
+
+ gc.collect()
+
+ # for from file
+ if type(useblocks) is str:
+ useblocks = True if useblocks =="True" else False
+ if type(base_alpha) == str:base_alpha = float(base_alpha)
+ if type(base_beta) == str:base_beta = float(base_beta)
+
+ weights_a_orig = weights_a
+ weights_b_orig = weights_b
+
+ # preset to weights
+ if wpresets != False and useblocks:
+ weights_a = wpreseter(weights_a,wpresets)
+ weights_b = wpreseter(weights_b,wpresets)
+
+ # mode select booleans
+ save = True if SAVEMODES[0] in save_sets else False
+ usebeta = MODES[2] in mode or MODES[3] in mode or calcmode == "tensor"
+ save_metadata = "save metadata" in save_sets
+ metadata = {"format": "pt"}
+
+ if not useblocks:
+ weights_a = weights_b = ""
+ #for save log and save current model
+ mergedmodel =[weights_a,weights_b,
+ hashfromname(model_a),hashfromname(model_b),hashfromname(model_c),
+ base_alpha,base_beta,mode,useblocks,custom_name,save_sets,id_sets,deep,calcmode,tensor].copy()
+
+ model_a = namefromhash(model_a)
+ model_b = namefromhash(model_b)
+ model_c = namefromhash(model_c)
+ theta_2 = {}
+
+ caster(mergedmodel,False)
+
+ if len(deep) > 0:
+ deep = deep.replace("\n",",")
+ deep = deep.split(",")
+
+ #format check
+ if model_a =="" or model_b =="" or ((not MODES[0] in mode) and model_c=="") :
+ return "ERROR: Necessary model is not selected",*non4
+
+ #for MBW text to list
+ if useblocks:
+ weights_a_t=weights_a.split(',',1)
+ weights_b_t=weights_b.split(',',1)
+ base_alpha = float(weights_a_t[0])
+ weights_a = [float(w) for w in weights_a_t[1].split(',')]
+ caster(f"from {weights_a_t}, alpha = {base_alpha},weights_a ={weights_a}",hearm)
149
+ if len(weights_a) != 25:return "ERROR: weights alpha must have 25 values.",*non4
150
+ if usebeta:
151
+ base_beta = float(weights_b_t[0])
152
+ weights_b = [float(w) for w in weights_b_t[1].split(',')]
153
+ caster(f"from {weights_b_t}, beta = {base_beta},weights_b ={weights_b}",hearm)
154
+ if len(weights_b) != 25: return "ERROR: weights beta must have 25 values.",*non4
155
+
156
+ caster("model load start",hearm)
157
+
158
+ print(f" model A \t: {model_a}")
159
+ print(f" model B \t: {model_b}")
160
+ print(f" model C \t: {model_c}")
161
+ print(f" alpha,beta\t: {base_alpha,base_beta}")
162
+ print(f" weights_alpha\t: {weights_a}")
163
+ print(f" weights_beta\t: {weights_b}")
164
+ print(f" mode\t\t: {mode}")
165
+ print(f" MBW \t\t: {useblocks}")
166
+ print(f" CalcMode \t: {calcmode}")
167
+ print(f" Elemental \t: {deep}")
168
+ print(f" Tensors \t: {tensor}")
169
+
170
+ theta_1=load_model_weights_m(model_b,False,True,save).copy()
171
+
172
+ if MODES[1] in mode:#Add
173
+ if stopmerge: return "STOPPED", *non4
174
+ theta_2 = load_model_weights_m(model_c,False,False,save).copy()
175
+ for key in tqdm(theta_1.keys()):
176
+ if 'model' in key:
177
+ if key in theta_2:
178
+ t2 = theta_2.get(key, torch.zeros_like(theta_1[key]))
179
+ theta_1[key] = theta_1[key]- t2
180
+ else:
181
+ theta_1[key] = torch.zeros_like(theta_1[key])
182
+ del theta_2
183
+
184
+ if stopmerge: return "STOPPED", *non4
185
+
186
+ if calcmode == "tensor":
187
+ theta_t = load_model_weights_m(model_a,True,False,save).copy()
188
+ theta_0 ={}
189
+ for key in theta_t:
190
+ theta_0[key] = theta_t[key].clone()
191
+ del theta_t
192
+ else:
193
+ theta_0=load_model_weights_m(model_a,True,False,save).copy()
194
+
195
+
196
+ if MODES[2] in mode or MODES[3] in mode:#Triple or Twice
197
+ theta_2 = load_model_weights_m(model_c,False,False,save).copy()
198
+
199
+ alpha = base_alpha
200
+ beta = base_beta
201
+
202
+ re_inp = re.compile(r'\.input_blocks\.(\d+)\.') # 12
203
+ re_mid = re.compile(r'\.middle_block\.(\d+)\.') # 1
204
+ re_out = re.compile(r'\.output_blocks\.(\d+)\.') # 12
205
+
206
+ chckpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"]
207
+ count_target_of_basealpha = 0
208
+
209
+ if calcmode =="cosineA": #favors modelA's structure with details from B
210
+ if stopmerge: return "STOPPED", *non4
211
+ sim = torch.nn.CosineSimilarity(dim=0)
212
+ sims = np.array([], dtype=np.float64)
213
+ for key in (tqdm(theta_0.keys(), desc="Stage 0/2")):
214
+ # skip VAE model parameters to get better results
215
+ if "first_stage_model" in key: continue
216
+ if "model" in key and key in theta_1:
217
+ theta_0_norm = nn.functional.normalize(theta_0[key].to(torch.float32), p=2, dim=0)
218
+ theta_1_norm = nn.functional.normalize(theta_1[key].to(torch.float32), p=2, dim=0)
219
+ simab = sim(theta_0_norm, theta_1_norm)
220
+ sims = np.append(sims,simab.numpy())
221
+ sims = sims[~np.isnan(sims)]
222
+ sims = np.delete(sims, np.where(sims<np.percentile(sims, 1 ,method = 'midpoint')))
223
+ sims = np.delete(sims, np.where(sims>np.percentile(sims, 99 ,method = 'midpoint')))
224
+
225
+ if calcmode =="cosineB": #favors modelB's structure with details from A
226
+ if stopmerge: return "STOPPED", *non4
227
+ sim = torch.nn.CosineSimilarity(dim=0)
228
+ sims = np.array([], dtype=np.float64)
229
+ for key in (tqdm(theta_0.keys(), desc="Stage 0/2")):
230
+ # skip VAE model parameters to get better results
231
+ if "first_stage_model" in key: continue
232
+ if "model" in key and key in theta_1:
233
+ simab = sim(theta_0[key].to(torch.float32), theta_1[key].to(torch.float32))
234
+ dot_product = torch.dot(theta_0[key].view(-1).to(torch.float32), theta_1[key].view(-1).to(torch.float32))
235
+ magnitude_similarity = dot_product / (torch.norm(theta_0[key].to(torch.float32)) * torch.norm(theta_1[key].to(torch.float32)))
236
+ combined_similarity = (simab + magnitude_similarity) / 2.0
237
+ sims = np.append(sims, combined_similarity.numpy())
238
+ sims = sims[~np.isnan(sims)]
239
+ sims = np.delete(sims, np.where(sims < np.percentile(sims, 1, method='midpoint')))
240
+ sims = np.delete(sims, np.where(sims > np.percentile(sims, 99, method='midpoint')))
241
+
242
+ for key in tqdm(theta_0.keys(), desc="Stage 1/2"):
243
+ if stopmerge: return "STOPPED", *non4
244
+ if "model" in key and key in theta_1:
245
+ if usebeta and (not key in theta_2) and (not theta_2 == {}) :
246
+ continue
247
+
248
+ weight_index = -1
249
+ current_alpha = alpha
250
+ current_beta = beta
251
+
252
+ if key in chckpoint_dict_skip_on_merge:
253
+ continue
254
+
255
+ # check weighted and U-Net or not
256
+ if weights_a is not None and 'model.diffusion_model.' in key:
257
+ # check block index
258
+ weight_index = -1
259
+
260
+ if 'time_embed' in key:
261
+ weight_index = 0 # before input blocks
262
+ elif '.out.' in key:
263
+ weight_index = NUM_TOTAL_BLOCKS - 1 # after output blocks
264
+ else:
265
+ m = re_inp.search(key)
266
+ if m:
267
+ inp_idx = int(m.groups()[0])
268
+ weight_index = inp_idx
269
+ else:
270
+ m = re_mid.search(key)
271
+ if m:
272
+ weight_index = NUM_INPUT_BLOCKS
273
+ else:
274
+ m = re_out.search(key)
275
+ if m:
276
+ out_idx = int(m.groups()[0])
277
+ weight_index = NUM_INPUT_BLOCKS + NUM_MID_BLOCK + out_idx
278
+
279
+ if weight_index >= NUM_TOTAL_BLOCKS:
280
+ print(f"ERROR: illegal block index: {key}")
281
+ return f"ERROR: illegal block index: {key}",*non4
282
+
283
+ if weight_index >= 0 and useblocks:
284
+ current_alpha = weights_a[weight_index]
285
+ if usebeta: current_beta = weights_b[weight_index]
286
+ else:
287
+ count_target_of_basealpha = count_target_of_basealpha + 1
288
+
289
+ if len(deep) > 0:
290
+ skey = key + blockid[weight_index+1]
291
+ for d in deep:
292
+ if d.count(":") != 2 :continue
293
+ dbs,dws,dr = d.split(":")
294
+ dbs,dws = dbs.split(" "), dws.split(" ")
295
+ dbn,dbs = (True,dbs[1:]) if dbs[0] == "NOT" else (False,dbs)
296
+ dwn,dws = (True,dws[1:]) if dws[0] == "NOT" else (False,dws)
297
+ flag = dbn
298
+ for db in dbs:
299
+ if db in skey:
300
+ flag = not dbn
301
+ if flag:flag = dwn
302
+ else:continue
303
+ for dw in dws:
304
+ if dw in skey:
305
+ flag = not dwn
306
+ if flag:
307
+ dr = float(dr)
308
+ if deepprint :print(dbs,dws,key,dr)
309
+ current_alpha = dr
310
+
311
+ if calcmode == "normal":
312
+ if MODES[1] in mode:#Add
313
+ caster(f"model A[{key}] + {current_alpha} * (model B - model C)[{key}]",hear)
314
+ theta_0[key] = theta_0[key] + current_alpha * theta_1[key]
315
+ elif MODES[2] in mode:#Triple
316
+ caster(f"model A[{key}] * {1-current_alpha-current_beta} + model B[{key}] * {current_alpha} + model C[{key}] * {current_beta}",hear)
317
+ theta_0[key] = (1 - current_alpha-current_beta) * theta_0[key] + current_alpha * theta_1[key]+current_beta * theta_2[key]
318
+ elif MODES[3] in mode:#Twice
319
+ caster(f"model A[{key}] * {1-current_alpha} + model B[{key}] * {current_alpha}",hear)
320
+ caster(f"(model A+B)[{key}] * {1-current_beta} + model C[{key}] * {current_beta}",hear)
321
+ theta_0[key] = (1 - current_alpha) * theta_0[key] + current_alpha * theta_1[key]
322
+ theta_0[key] = (1 - current_beta) * theta_0[key] + current_beta * theta_2[key]
323
+ else:#Weight
324
+ if current_alpha == 1:
325
+ caster(f"alpha = 1, model A[{key}] = model B[{key}]",hear)
326
+ theta_0[key] = theta_1[key]
327
+ elif current_alpha !=0:
328
+ caster(f"model A[{key}] * {1-current_alpha} + model B[{key}] * {current_alpha}",hear)
329
+ theta_0[key] = (1 - current_alpha) * theta_0[key] + current_alpha * theta_1[key]
330
+
331
+ elif calcmode == "cosineA": #favors modelA's structure with details from B
332
+ # skip VAE model parameters to get better results
333
+ if "first_stage_model" in key: continue
334
+ if "model" in key and key in theta_0:
335
+ # Normalize the vectors before merging
336
+ theta_0_norm = nn.functional.normalize(theta_0[key].to(torch.float32), p=2, dim=0)
337
+ theta_1_norm = nn.functional.normalize(theta_1[key].to(torch.float32), p=2, dim=0)
338
+ simab = sim(theta_0_norm, theta_1_norm)
339
+ dot_product = torch.dot(theta_0_norm.view(-1), theta_1_norm.view(-1))
340
+ magnitude_similarity = dot_product / (torch.norm(theta_0_norm) * torch.norm(theta_1_norm))
341
+ combined_similarity = (simab + magnitude_similarity) / 2.0
342
+ k = (combined_similarity - sims.min()) / (sims.max() - sims.min())
343
+ k = k - current_alpha
344
+ k = k.clip(min=.0,max=1.)
345
+ caster(f"model B[{key}] * {1-k} + model A[{key}] * {k}",hear)
346
+ theta_0[key] = theta_1[key] * (1 - k) + theta_0[key] * k
347
+
348
+ elif calcmode == "cosineB": #favors modelB's structure with details from A
349
+ # skip VAE model parameters to get better results
350
+ if "first_stage_model" in key: continue
351
+ if "model" in key and key in theta_0:
352
+ simab = sim(theta_0[key].to(torch.float32), theta_1[key].to(torch.float32))
353
+ dot_product = torch.dot(theta_0[key].view(-1).to(torch.float32), theta_1[key].view(-1).to(torch.float32))
354
+ magnitude_similarity = dot_product / (torch.norm(theta_0[key].to(torch.float32)) * torch.norm(theta_1[key].to(torch.float32)))
355
+ combined_similarity = (simab + magnitude_similarity) / 2.0
356
+ k = (combined_similarity - sims.min()) / (sims.max() - sims.min())
357
+ k = k - current_alpha
358
+ k = k.clip(min=.0,max=1.)
359
+ caster(f"model B[{key}] * {1-k} + model A[{key}] * {k}",hear)
360
+ theta_0[key] = theta_1[key] * (1 - k) + theta_0[key] * k
361
+
362
+ elif calcmode == "smoothAdd":
363
+ caster(f"model A[{key}] + {current_alpha} + * (model B - model C)[{key}]", hear)
364
+ # Apply median filter to the weight differences
365
+ filtered_diff = scipy.ndimage.median_filter(theta_1[key].to(torch.float32).cpu().numpy(), size=3)
366
+ # Apply Gaussian filter to the filtered differences
367
+ filtered_diff = scipy.ndimage.gaussian_filter(filtered_diff, sigma=1)
368
+ theta_1[key] = torch.tensor(filtered_diff)
369
+ # Add the filtered differences to the original weights
370
+ theta_0[key] = theta_0[key] + current_alpha * theta_1[key]
371
+
372
+ elif calcmode == "tensor":
373
+ dim = theta_0[key].dim()
374
+ if dim == 0 : continue
375
+ if current_alpha+current_beta <= 1 :
376
+ talphas = int(theta_0[key].shape[0]*(current_beta))
377
+ talphae = int(theta_0[key].shape[0]*(current_alpha+current_beta))
378
+ if dim == 1:
379
+ theta_0[key][talphas:talphae] = theta_1[key][talphas:talphae].clone()
380
+
381
+ elif dim == 2:
382
+ theta_0[key][talphas:talphae,:] = theta_1[key][talphas:talphae,:].clone()
383
+
384
+ elif dim == 3:
385
+ theta_0[key][talphas:talphae,:,:] = theta_1[key][talphas:talphae,:,:].clone()
386
+
387
+ elif dim == 4:
388
+ theta_0[key][talphas:talphae,:,:,:] = theta_1[key][talphas:talphae,:,:,:].clone()
389
+
390
+ else:
391
+ talphas = int(theta_0[key].shape[0]*(current_alpha+current_beta-1))
392
+ talphae = int(theta_0[key].shape[0]*(current_beta))
393
+ theta_t = theta_1[key].clone()
394
+ if dim == 1:
395
+ theta_t[talphas:talphae] = theta_0[key][talphas:talphae].clone()
396
+
397
+ elif dim == 2:
398
+ theta_t[talphas:talphae,:] = theta_0[key][talphas:talphae,:].clone()
399
+
400
+ elif dim == 3:
401
+ theta_t[talphas:talphae,:,:] = theta_0[key][talphas:talphae,:,:].clone()
402
+
403
+ elif dim == 4:
404
+ theta_t[talphas:talphae,:,:,:] = theta_0[key][talphas:talphae,:,:,:].clone()
405
+ theta_0[key] = theta_t
406
+
407
+ currentmodel = makemodelname(weights_a,weights_b,model_a, model_b,model_c, base_alpha,base_beta,useblocks,mode)
408
+
409
+ for key in tqdm(theta_1.keys(), desc="Stage 2/2"):
410
+ if key in chckpoint_dict_skip_on_merge:
411
+ continue
412
+ if "model" in key and key not in theta_0:
413
+ theta_0.update({key:theta_1[key]})
414
+
415
+ del theta_1
416
+
417
+ modelid = rwmergelog(currentmodel,mergedmodel)
418
+
419
+ caster(mergedmodel,False)
420
+
421
+ if save_metadata:
422
+ merge_recipe = {
423
+ "type": "sd-webui-supermerger",
424
+ "weights_alpha": weights_a if useblocks else None,
425
+ "weights_beta": weights_b if useblocks else None,
426
+ "weights_alpha_orig": weights_a_orig if useblocks else None,
427
+ "weights_beta_orig": weights_b_orig if useblocks else None,
428
+ "model_a": longhashfromname(model_a),
429
+ "model_b": longhashfromname(model_b),
430
+ "model_c": longhashfromname(model_c),
431
+ "base_alpha": base_alpha,
432
+ "base_beta": base_beta,
433
+ "mode": mode,
434
+ "mbw": useblocks,
435
+ "elemental_merge": deep,
436
+ "calcmode" : calcmode
437
+ }
438
+ metadata["sd_merge_recipe"] = json.dumps(merge_recipe)
439
+ metadata["sd_merge_models"] = {}
440
+
441
+ def add_model_metadata(checkpoint_name):
442
+ checkpoint_info = sd_models.get_closet_checkpoint_match(checkpoint_name)
443
+ checkpoint_info.calculate_shorthash()
444
+ metadata["sd_merge_models"][checkpoint_info.sha256] = {
445
+ "name": checkpoint_name,
446
+ "legacy_hash": checkpoint_info.hash
447
+ }
448
+
449
+ #metadata["sd_merge_models"].update(checkpoint_info.metadata.get("sd_merge_models", {}))
450
+
451
+ if model_a:
452
+ add_model_metadata(model_a)
453
+ if model_b:
454
+ add_model_metadata(model_b)
455
+ if model_c:
456
+ add_model_metadata(model_c)
457
+
458
+ metadata["sd_merge_models"] = json.dumps(metadata["sd_merge_models"])
459
+
460
+ return "",currentmodel,modelid,theta_0,metadata
461
+ def forkforker(filename):
462
+ try:
463
+ return sd_models.read_state_dict(filename,"cuda")
464
+ except:
465
+ return sd_models.read_state_dict(filename)
466
+
467
+ def load_model_weights_m(model,model_a,model_b,save):
468
+ checkpoint_info = sd_models.get_closet_checkpoint_match(model)
469
+ sd_model_name = checkpoint_info.model_name
470
+
471
+ cachenum = shared.opts.sd_checkpoint_cache
472
+
473
+ if save:
474
+ if model_a:
475
+ load_model(checkpoint_info)
476
+ print(f"Loading weights [{sd_model_name}] from file")
477
+ return forkforker(checkpoint_info.filename)
478
+
479
+ if checkpoint_info in checkpoints_loaded:
480
+ print(f"Loading weights [{sd_model_name}] from cache")
481
+ return checkpoints_loaded[checkpoint_info]
482
+ elif cachenum>0 and model_a:
483
+ load_model(checkpoint_info)
484
+ print(f"Loading weights [{sd_model_name}] from cache")
485
+ return checkpoints_loaded[checkpoint_info]
486
+ elif cachenum>1 and model_b:
487
+ load_model(checkpoint_info)
488
+ print(f"Loading weights [{sd_model_name}] from cache")
489
+ return checkpoints_loaded[checkpoint_info]
490
+ elif cachenum>2:
491
+ load_model(checkpoint_info)
492
+ print(f"Loading weights [{sd_model_name}] from cache")
493
+ return checkpoints_loaded[checkpoint_info]
494
+ else:
495
+ if model_a:
496
+ load_model(checkpoint_info)
497
+ print(f"Loading weights [{sd_model_name}] from file")
498
+ return forkforker(checkpoint_info.filename)
499
+
500
+ def makemodelname(weights_a,weights_b,model_a, model_b,model_c, alpha,beta,useblocks,mode):
501
+ model_a=filenamecutter(model_a)
502
+ model_b=filenamecutter(model_b)
503
+ model_c=filenamecutter(model_c)
504
+
505
+ if type(alpha) == str:alpha = float(alpha)
506
+ if type(beta)== str:beta = float(beta)
507
+
508
+ if useblocks:
509
+ if MODES[1] in mode:#add
510
+ currentmodel =f"{model_a} + ({model_b} - {model_c}) x alpha ({str(round(alpha,3))},{','.join(str(s) for s in weights_a)}"
511
+ elif MODES[2] in mode:#triple
512
+ currentmodel =f"{model_a} x (1-alpha-beta) + {model_b} x alpha + {model_c} x beta (alpha = {str(round(alpha,3))},{','.join(str(s) for s in weights_a)},beta = {beta},{','.join(str(s) for s in weights_b)})"
513
+ elif MODES[3] in mode:#twice
514
+ currentmodel =f"({model_a} x (1-alpha) + {model_b} x alpha)x(1-beta)+ {model_c} x beta ({str(round(alpha,3))},{','.join(str(s) for s in weights_a)})_({str(round(beta,3))},{','.join(str(s) for s in weights_b)})"
515
+ else:
516
+ currentmodel =f"{model_a} x (1-alpha) + {model_b} x alpha ({str(round(alpha,3))},{','.join(str(s) for s in weights_a)})"
517
+ else:
518
+ if MODES[1] in mode:#add
519
+ currentmodel =f"{model_a} + ({model_b} - {model_c}) x {str(round(alpha,3))}"
520
+ elif MODES[2] in mode:#triple
521
+ currentmodel =f"{model_a} x {str(round(1-alpha-beta,3))} + {model_b} x {str(round(alpha,3))} + {model_c} x {str(round(beta,3))}"
522
+ elif MODES[3] in mode:#twice
523
+ currentmodel =f"({model_a} x {str(round(1-alpha,3))} +{model_b} x {str(round(alpha,3))}) x {str(round(1-beta,3))} + {model_c} x {str(round(beta,3))}"
524
+ else:
525
+ currentmodel =f"{model_a} x {str(round(1-alpha,3))} + {model_b} x {str(round(alpha,3))}"
526
+ return currentmodel
527
+
528
+ path_root = scripts.basedir()
529
+
530
+ def rwmergelog(mergedname = "",settings= [],id = 0):
531
+ setting = settings.copy()
532
+ filepath = os.path.join(path_root, "mergehistory.csv")
533
+ is_file = os.path.isfile(filepath)
534
+ if not is_file:
535
+ with open(filepath, 'a') as f:
536
+ #msettings=[0 weights_a,1 weights_b,2 model_a,3 model_b,4 model_c,5 base_alpha,6 base_beta,7 mode,8 useblocks,9 custom_name,10 save_sets,11 id_sets, 12 deep 13 calcmode]
537
+ f.writelines('"ID","time","name","weights alpha","weights beta","model A","model B","model C","alpha","beta","mode","use MBW","plus lora","custom name","save setting","use ID"\n')
538
+ with open(filepath, 'r+') as f:
539
+ reader = csv.reader(f)
540
+ mlist = [raw for raw in reader]
541
+ if mergedname != "":
542
+ mergeid = len(mlist)
543
+ setting.insert(0,mergedname)
544
+ for i,x in enumerate(setting):
545
+ if "," in str(x):setting[i] = f'"{str(setting[i])}"'
546
+ text = ",".join(map(str, setting))
547
+ text=str(mergeid)+","+datetime.datetime.now().strftime('%Y.%m.%d %H.%M.%S.%f')[:-7]+"," + text + "\n"
548
+ f.writelines(text)
549
+ return mergeid
550
+ try:
551
+ out = mlist[int(id)]
552
+ except:
553
+ out = "ERROR: OUT of ID index"
554
+ return out
555
+
556
+ def draw_origin(grid, text,width,height,width_one):
557
+ grid_d= Image.new("RGB", (grid.width,grid.height), "white")
558
+ grid_d.paste(grid,(0,0))
559
+ def get_font(fontsize):
560
+ try:
561
+ from fonts.ttf import Roboto
562
+ try:
563
+ return ImageFont.truetype(opts.font or Roboto, fontsize)
564
+ except Exception:
565
+ return ImageFont.truetype(Roboto, fontsize)
566
+ except Exception:
567
+ try:
568
+ return ImageFont.truetype(shared.opts.font or 'javascript/roboto.ttf', fontsize)
569
+ except Exception:
570
+ return ImageFont.truetype('javascript/roboto.ttf', fontsize)
571
+
572
+ d= ImageDraw.Draw(grid_d)
573
+ color_active = (0, 0, 0)
574
+ fontsize = (width+height)//25
575
+ fnt = get_font(fontsize)
576
+
577
+ if grid.width != width_one:
578
+ while d.multiline_textsize(text, font=fnt)[0] > width_one*0.75 and fontsize > 0:
579
+ fontsize -=1
580
+ fnt = get_font(fontsize)
581
+ d.multiline_text((0,0), text, font=fnt, fill=color_active,align="center")
582
+ return grid_d
583
+
584
+ def wpreseter(w,presets):
585
+ if "," not in w and w != "":
586
+ presets=presets.splitlines()
587
+ wdict={}
588
+ for l in presets:
589
+ if ":" in l :
590
+ key = l.split(":",1)[0]
591
+ wdict[key.strip()]=l.split(":",1)[1]
592
+ if "\t" in l:
593
+ key = l.split("\t",1)[0]
594
+ wdict[key.strip()]=l.split("\t",1)[1]
595
+ if w.strip() in wdict:
596
+ name = w
597
+ w = wdict[w.strip()]
598
+ print(f"weights {name} imported from presets : {w}")
599
+ return w
600
+
601
+ def fullpathfromname(name):
602
+ if name == "" or name ==[]: return ""
603
+ checkpoint_info = sd_models.get_closet_checkpoint_match(name)
604
+ return checkpoint_info.filename
605
+
606
+ def namefromhash(hash):
607
+ if hash == "" or hash ==[]: return ""
608
+ checkpoint_info = sd_models.get_closet_checkpoint_match(hash)
609
+ return checkpoint_info.model_name
610
+
611
+ def hashfromname(name):
612
+ from modules import sd_models
613
+ if name == "" or name ==[]: return ""
614
+ checkpoint_info = sd_models.get_closet_checkpoint_match(name)
615
+ if checkpoint_info.shorthash is not None:
616
+ return checkpoint_info.shorthash
617
+ return checkpoint_info.calculate_shorthash()
618
+
619
+ def longhashfromname(name):
620
+ from modules import sd_models
621
+ if name == "" or name ==[]: return ""
622
+ checkpoint_info = sd_models.get_closet_checkpoint_match(name)
623
+ if checkpoint_info.sha256 is not None:
624
+ return checkpoint_info.sha256
625
+ checkpoint_info.calculate_shorthash()
626
+ return checkpoint_info.sha256
627
+
628
+ def simggen(prompt, nprompt, steps, sampler, cfg, seed, w, h,genoptions,hrupscaler,hr2ndsteps,denoise_str,hr_scale,batch_size,mergeinfo="",id_sets=[],modelid = "no id"):
629
+ shared.state.begin()
630
+ p = processing.StableDiffusionProcessingTxt2Img(
631
+ sd_model=shared.sd_model,
632
+ do_not_save_grid=True,
633
+ do_not_save_samples=True,
634
+ do_not_reload_embeddings=True,
635
+ )
636
+ p.batch_size = int(batch_size)
637
+ p.prompt = prompt
638
+ p.negative_prompt = nprompt
639
+ p.steps = steps
640
+ p.sampler_name = sd_samplers.samplers[sampler].name
641
+ p.cfg_scale = cfg
642
+ p.seed = seed
643
+ p.width = w
644
+ p.height = h
645
+ p.seed_resize_from_w=0
646
+ p.seed_resize_from_h=0
647
+ p.denoising_strength=None
648
+
649
+ #"Restore faces", "Tiling", "Hires. fix"
650
+
651
+ if "Hires. fix" in genoptions:
652
+ p.enable_hr = True
653
+ p.denoising_strength = denoise_str
654
+ p.hr_upscaler = hrupscaler
655
+ p.hr_second_pass_steps = hr2ndsteps
656
+ p.hr_scale = hr_scale
657
+
658
+ if "Tiling" in genoptions:
659
+ p.tiling = True
660
+
661
+ if "Restore faces" in genoptions:
662
+ p.restore_faces = True
663
+
664
+ if type(p.prompt) == list:
665
+ p.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, p.styles) for x in p.prompt]
666
+ else:
667
+ p.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(p.prompt, p.styles)]
668
+
669
+ if type(p.negative_prompt) == list:
670
+ p.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, p.styles) for x in p.negative_prompt]
671
+ else:
672
+ p.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(p.negative_prompt, p.styles)]
673
+
674
+ processed:Processed = processing.process_images(p)
675
+ if "image" in id_sets:
676
+ for i, image in enumerate(processed.images):
677
+ processed.images[i] = draw_origin(image, str(modelid),w,h,w)
678
+
679
+ if "PNG info" in id_sets:mergeinfo = mergeinfo + " ID " + str(modelid)
680
+
681
+ infotext = create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds)
682
+ if infotext.count("Steps: ")>1:
683
+ infotext = infotext[:infotext.rindex("Steps")]
684
+
685
+ infotexts = infotext.split(",")
686
+ for i,x in enumerate(infotexts):
687
+ if "Model:"in x:
688
+ infotexts[i] = " Model: "+mergeinfo.replace(","," ")
689
+ infotext= ",".join(infotexts)
690
+
691
+ for i, image in enumerate(processed.images):
692
+ images.save_image(image, opts.outdir_txt2img_samples, "",p.seed, p.prompt,shared.opts.samples_format, p=p,info=infotext)
693
+
694
+ if batch_size > 1:
695
+ grid = images.image_grid(processed.images, p.batch_size)
696
+ processed.images.insert(0, grid)
697
+ images.save_image(grid, opts.outdir_txt2img_grids, "grid", p.seed, p.prompt, opts.grid_format, info=infotext, short_filename=not opts.grid_extended_filename, p=p, grid=True)
698
+ shared.state.end()
699
+ return processed.images,infotext,plaintext_to_html(processed.info), plaintext_to_html(processed.comments),p
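The per-block weighting logic in the merge loop above (mapping each `model.diffusion_model.` key to one of the 25 MBW block indices, then taking a weighted sum) can be sketched in isolation. This is a minimal illustrative version, not the extension's actual API: the regexes and block constants mirror the diff, while `block_index` and `merge` are hypothetical helper names, and plain floats stand in for tensors.

```python
import re

# Block layout assumed by the diff: 12 input blocks, 1 middle block, 12 output blocks.
NUM_INPUT_BLOCKS, NUM_MID_BLOCK, NUM_TOTAL_BLOCKS = 12, 1, 25

re_inp = re.compile(r'\.input_blocks\.(\d+)\.')
re_mid = re.compile(r'\.middle_block\.(\d+)\.')
re_out = re.compile(r'\.output_blocks\.(\d+)\.')

def block_index(key):
    """Map a U-Net parameter name to its MBW block index, or -1 for non-U-Net keys."""
    if 'model.diffusion_model.' not in key:
        return -1
    if 'time_embed' in key:
        return 0                      # grouped with the first input block
    if '.out.' in key:
        return NUM_TOTAL_BLOCKS - 1   # grouped with the last output block
    m = re_inp.search(key)
    if m:
        return int(m.group(1))
    if re_mid.search(key):
        return NUM_INPUT_BLOCKS
    m = re_out.search(key)
    if m:
        return NUM_INPUT_BLOCKS + NUM_MID_BLOCK + int(m.group(1))
    return -1

def merge(theta_a, theta_b, base_alpha, weights):
    """Weighted sum merge: theta = (1 - alpha) * A + alpha * B, alpha chosen per block."""
    merged = {}
    for key, a in theta_a.items():
        idx = block_index(key)
        alpha = weights[idx] if idx >= 0 else base_alpha
        merged[key] = (1 - alpha) * a + alpha * theta_b[key]
    return merged
```

Keys outside the U-Net (e.g. the text encoder) fall back to `base_alpha`, matching how the diff only consults `weights_a` when `'model.diffusion_model.'` appears in the key.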
microsoftexcel-supermerger/scripts/mergers/model_util.py ADDED
@@ -0,0 +1,928 @@
1
+ import os
2
+ import torch
3
+ from transformers import CLIPTextModel, CLIPTextConfig
4
+ from safetensors.torch import load_file
5
+ import safetensors.torch
6
+ from modules.sd_models import read_state_dict
7
+
8
+ # model parameters of the Diffusers version of Stable Diffusion
9
+ NUM_TRAIN_TIMESTEPS = 1000
10
+ BETA_START = 0.00085
11
+ BETA_END = 0.0120
12
+
13
+ UNET_PARAMS_MODEL_CHANNELS = 320
14
+ UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
15
+ UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
16
+ UNET_PARAMS_IMAGE_SIZE = 64 # fixed from old invalid value `32`
17
+ UNET_PARAMS_IN_CHANNELS = 4
18
+ UNET_PARAMS_OUT_CHANNELS = 4
19
+ UNET_PARAMS_NUM_RES_BLOCKS = 2
20
+ UNET_PARAMS_CONTEXT_DIM = 768
21
+ UNET_PARAMS_NUM_HEADS = 8
22
+
23
+ VAE_PARAMS_Z_CHANNELS = 4
24
+ VAE_PARAMS_RESOLUTION = 256
25
+ VAE_PARAMS_IN_CHANNELS = 3
26
+ VAE_PARAMS_OUT_CH = 3
27
+ VAE_PARAMS_CH = 128
28
+ VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
29
+ VAE_PARAMS_NUM_RES_BLOCKS = 2
30
+
31
+ # V2
32
+ V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
33
+ V2_UNET_PARAMS_CONTEXT_DIM = 1024
34
+
35
+ # reference models used to load Diffusers configs
36
+ DIFFUSERS_REF_MODEL_ID_V1 = "runwayml/stable-diffusion-v1-5"
37
+ DIFFUSERS_REF_MODEL_ID_V2 = "stabilityai/stable-diffusion-2-1"
38
+
39
+
40
+ # region StableDiffusion -> Diffusers conversion code
41
+ # copied from convert_original_stable_diffusion_to_diffusers and modified (ASL 2.0)
42
+
43
+
44
+ def shave_segments(path, n_shave_prefix_segments=1):
45
+ """
46
+ Removes segments. Positive values shave the first segments, negative shave the last segments.
47
+ """
48
+ if n_shave_prefix_segments >= 0:
49
+ return ".".join(path.split(".")[n_shave_prefix_segments:])
50
+ else:
51
+ return ".".join(path.split(".")[:n_shave_prefix_segments])
52
+
53
+
54
+ def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
55
+ """
56
+ Updates paths inside resnets to the new naming scheme (local renaming)
57
+ """
58
+ mapping = []
59
+ for old_item in old_list:
60
+ new_item = old_item.replace("in_layers.0", "norm1")
61
+ new_item = new_item.replace("in_layers.2", "conv1")
62
+
63
+ new_item = new_item.replace("out_layers.0", "norm2")
64
+ new_item = new_item.replace("out_layers.3", "conv2")
65
+
66
+ new_item = new_item.replace("emb_layers.1", "time_emb_proj")
67
+ new_item = new_item.replace("skip_connection", "conv_shortcut")
68
+
69
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
70
+
71
+ mapping.append({"old": old_item, "new": new_item})
72
+
73
+ return mapping
74
+
75
+
76
+ def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
77
+ """
78
+ Updates paths inside resnets to the new naming scheme (local renaming)
79
+ """
80
+ mapping = []
81
+ for old_item in old_list:
82
+ new_item = old_item
83
+
84
+ new_item = new_item.replace("nin_shortcut", "conv_shortcut")
85
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
86
+
87
+ mapping.append({"old": old_item, "new": new_item})
88
+
89
+ return mapping
90
+
91
+
92
+ def renew_attention_paths(old_list, n_shave_prefix_segments=0):
93
+ """
94
+ Updates paths inside attentions to the new naming scheme (local renaming)
95
+ """
96
+ mapping = []
97
+ for old_item in old_list:
98
+ new_item = old_item
99
+
100
+ # new_item = new_item.replace('norm.weight', 'group_norm.weight')
101
+ # new_item = new_item.replace('norm.bias', 'group_norm.bias')
102
+
103
+ # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
104
+ # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
105
+
106
+ # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
107
+
108
+ mapping.append({"old": old_item, "new": new_item})
109
+
110
+ return mapping
111
+
112
+
113
+ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
114
+ """
115
+ Updates paths inside attentions to the new naming scheme (local renaming)
116
+ """
117
+ mapping = []
118
+ for old_item in old_list:
119
+ new_item = old_item
120
+
121
+ new_item = new_item.replace("norm.weight", "group_norm.weight")
122
+ new_item = new_item.replace("norm.bias", "group_norm.bias")
123
+
124
+ new_item = new_item.replace("q.weight", "query.weight")
125
+ new_item = new_item.replace("q.bias", "query.bias")
126
+
127
+ new_item = new_item.replace("k.weight", "key.weight")
128
+ new_item = new_item.replace("k.bias", "key.bias")
129
+
130
+ new_item = new_item.replace("v.weight", "value.weight")
131
+ new_item = new_item.replace("v.bias", "value.bias")
132
+
133
+ new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
134
+ new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
135
+
136
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
137
+
138
+ mapping.append({"old": old_item, "new": new_item})
139
+
140
+ return mapping
141
+
142
+
143
+def assign_to_checkpoint(
+    paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
+):
+    """
+    This does the final conversion step: take locally converted weights and apply a global renaming
+    to them. It splits attention layers, and takes into account additional replacements
+    that may arise.
+
+    Assigns the weights to the new checkpoint.
+    """
+    assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
+
+    # Splits the attention layers into three variables.
+    if attention_paths_to_split is not None:
+        for path, path_map in attention_paths_to_split.items():
+            old_tensor = old_checkpoint[path]
+            channels = old_tensor.shape[0] // 3
+
+            target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
+
+            num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
+
+            old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
+            query, key, value = old_tensor.split(channels // num_heads, dim=1)
+
+            checkpoint[path_map["query"]] = query.reshape(target_shape)
+            checkpoint[path_map["key"]] = key.reshape(target_shape)
+            checkpoint[path_map["value"]] = value.reshape(target_shape)
+
+    for path in paths:
+        new_path = path["new"]
+
+        # These have already been assigned
+        if attention_paths_to_split is not None and new_path in attention_paths_to_split:
+            continue
+
+        # Global renaming happens here
+        new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
+        new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
+        new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
+
+        if additional_replacements is not None:
+            for replacement in additional_replacements:
+                new_path = new_path.replace(replacement["old"], replacement["new"])
+
+        # proj_attn.weight has to be converted from conv 1D to linear
+        if "proj_attn.weight" in new_path:
+            checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
+        else:
+            checkpoint[new_path] = old_checkpoint[path["old"]]
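The global renaming applied in the loop above is plain string substitution on state-dict key names, so it can be exercised without any model weights. A minimal standalone sketch of that logic (the key names below are illustrative, not taken from a real checkpoint):

```python
# Sketch of the key-renaming step of assign_to_checkpoint, isolated from the
# tensor handling. Key names here are hypothetical examples.
def rename(new_path, additional_replacements=None):
    # Global renaming of the LDM middle block to the diffusers mid block
    new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
    new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
    new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
    # Per-block replacements passed in as meta_path dicts by the callers
    if additional_replacements is not None:
        for replacement in additional_replacements:
            new_path = new_path.replace(replacement["old"], replacement["new"])
    return new_path

print(rename("middle_block.1.norm.weight"))
# mid_block.attentions.0.norm.weight
meta = [{"old": "input_blocks.4.0", "new": "down_blocks.1.resnets.0"}]
print(rename("input_blocks.4.0.in_layers.0.weight", meta))
# down_blocks.1.resnets.0.in_layers.0.weight
```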
+
+
+def conv_attn_to_linear(checkpoint):
+    keys = list(checkpoint.keys())
+    attn_keys = ["query.weight", "key.weight", "value.weight"]
+    for key in keys:
+        if ".".join(key.split(".")[-2:]) in attn_keys:
+            if checkpoint[key].ndim > 2:
+                checkpoint[key] = checkpoint[key][:, :, 0, 0]
+        elif "proj_attn.weight" in key:
+            if checkpoint[key].ndim > 2:
+                checkpoint[key] = checkpoint[key][:, :, 0]
+
+
+def linear_transformer_to_conv(checkpoint):
+    keys = list(checkpoint.keys())
+    tf_keys = ["proj_in.weight", "proj_out.weight"]
+    for key in keys:
+        if ".".join(key.split(".")[-2:]) in tf_keys:
+            if checkpoint[key].ndim == 2:
+                checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2)
+
+
+def convert_ldm_unet_checkpoint(v2, checkpoint, config):
+    """
+    Takes a state dict and a config, and returns a converted checkpoint.
+    """
+
+    # extract state_dict for UNet
+    unet_state_dict = {}
+    unet_key = "model.diffusion_model."
+    keys = list(checkpoint.keys())
+    for key in keys:
+        if key.startswith(unet_key):
+            unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
+
+    new_checkpoint = {}
+
+    new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
+    new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
+    new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
+    new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
+
+    new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
+    new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
+
+    new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
+    new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
+    new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
+    new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
+
+    # Retrieves the keys for the input blocks only
+    num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
+    input_blocks = {
+        layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key]
+        for layer_id in range(num_input_blocks)
+    }
+
+    # Retrieves the keys for the middle blocks only
+    num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
+    middle_blocks = {
+        layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key]
+        for layer_id in range(num_middle_blocks)
+    }
+
+    # Retrieves the keys for the output blocks only
+    num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
+    output_blocks = {
+        layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key]
+        for layer_id in range(num_output_blocks)
+    }
+
+    for i in range(1, num_input_blocks):
+        block_id = (i - 1) // (config["layers_per_block"] + 1)
+        layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
+
+        resnets = [
+            key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
+        ]
+        attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
+
+        if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
+            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
+                f"input_blocks.{i}.0.op.weight"
+            )
+            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
+                f"input_blocks.{i}.0.op.bias"
+            )
+
+        paths = renew_resnet_paths(resnets)
+        meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
+        assign_to_checkpoint(
+            paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+        )
+
+        if len(attentions):
+            paths = renew_attention_paths(attentions)
+            meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
+            assign_to_checkpoint(
+                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+            )
+
+    resnet_0 = middle_blocks[0]
+    attentions = middle_blocks[1]
+    resnet_1 = middle_blocks[2]
+
+    resnet_0_paths = renew_resnet_paths(resnet_0)
+    assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
+
+    resnet_1_paths = renew_resnet_paths(resnet_1)
+    assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
+
+    attentions_paths = renew_attention_paths(attentions)
+    meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
+    assign_to_checkpoint(
+        attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+    )
+
+    for i in range(num_output_blocks):
+        block_id = i // (config["layers_per_block"] + 1)
+        layer_in_block_id = i % (config["layers_per_block"] + 1)
+        output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
+        output_block_list = {}
+
+        for layer in output_block_layers:
+            layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
+            if layer_id in output_block_list:
+                output_block_list[layer_id].append(layer_name)
+            else:
+                output_block_list[layer_id] = [layer_name]
+
+        if len(output_block_list) > 1:
+            resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
+            attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
+
+            resnet_0_paths = renew_resnet_paths(resnets)
+            paths = renew_resnet_paths(resnets)
+
+            meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
+            assign_to_checkpoint(
+                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+            )
+
+            # Original:
+            # if ["conv.weight", "conv.bias"] in output_block_list.values():
+            #     index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
+
+            # avoid depending on the order of bias and weight; there is probably a better way
+            for l in output_block_list.values():
+                l.sort()
+
+            if ["conv.bias", "conv.weight"] in output_block_list.values():
+                index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
+                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
+                    f"output_blocks.{i}.{index}.conv.bias"
+                ]
+                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
+                    f"output_blocks.{i}.{index}.conv.weight"
+                ]
+
+                # Clear attentions as they have been attributed above.
+                if len(attentions) == 2:
+                    attentions = []
+
+            if len(attentions):
+                paths = renew_attention_paths(attentions)
+                meta_path = {
+                    "old": f"output_blocks.{i}.1",
+                    "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
+                }
+                assign_to_checkpoint(
+                    paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+                )
+        else:
+            resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
+            for path in resnet_0_paths:
+                old_path = ".".join(["output_blocks", str(i), path["old"]])
+                new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
+
+                new_checkpoint[new_path] = unet_state_dict[old_path]
+
+    # in SD v2, the 1x1 conv2d layers were changed to linear, so convert linear -> conv
+    if v2:
+        linear_transformer_to_conv(new_checkpoint)
+
+    return new_checkpoint
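The index arithmetic at the top of the input-block loop can be checked in isolation: with `layers_per_block = 2` (the Stable Diffusion value), every `(layers_per_block + 1)`-th `input_blocks` entry lands in the downsampler slot of its down block. A small sketch:

```python
# How the loop above maps the flat LDM input_blocks index i to the diffusers
# (block_id, layer_in_block_id) pair. layers_per_block = 2 for SD.
layers_per_block = 2

def locate(i):
    block_id = (i - 1) // (layers_per_block + 1)
    layer_in_block_id = (i - 1) % (layers_per_block + 1)
    return block_id, layer_in_block_id

print(locate(1))   # (0, 0) - first resnet of down block 0
print(locate(3))   # (0, 2) - downsampler slot of down block 0
print(locate(4))   # (1, 0) - first resnet of down block 1
```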
+
+
+def convert_ldm_vae_checkpoint(checkpoint, config):
+    # extract state dict for VAE
+    vae_state_dict = {}
+    vae_key = "first_stage_model."
+    keys = list(checkpoint.keys())
+    for key in keys:
+        if key.startswith(vae_key):
+            vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
+    # if len(vae_state_dict) == 0:
+    #     # the passed checkpoint is a VAE state_dict, not a checkpoint loaded from a .ckpt
+    #     vae_state_dict = checkpoint
+
+    new_checkpoint = {}
+
+    new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
+    new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
+    new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
+    new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
+    new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
+    new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
+
+    new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
+    new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
+    new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
+    new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
+    new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
+    new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
+
+    new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
+    new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
+    new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
+    new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
+
+    # Retrieves the keys for the encoder down blocks only
+    num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
+    down_blocks = {
+        layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
+    }
+
+    # Retrieves the keys for the decoder up blocks only
+    num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
+    up_blocks = {
+        layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
+    }
+
+    for i in range(num_down_blocks):
+        resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
+
+        if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
+            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
+                f"encoder.down.{i}.downsample.conv.weight"
+            )
+            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
+                f"encoder.down.{i}.downsample.conv.bias"
+            )
+
+        paths = renew_vae_resnet_paths(resnets)
+        meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
+        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+
+    mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
+    num_mid_res_blocks = 2
+    for i in range(1, num_mid_res_blocks + 1):
+        resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
+
+        paths = renew_vae_resnet_paths(resnets)
+        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
+        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+
+    mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
+    paths = renew_vae_attention_paths(mid_attentions)
+    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
+    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+    conv_attn_to_linear(new_checkpoint)
+
+    for i in range(num_up_blocks):
+        block_id = num_up_blocks - 1 - i
+        resnets = [
+            key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
+        ]
+
+        if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
+            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
+                f"decoder.up.{block_id}.upsample.conv.weight"
+            ]
+            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
+                f"decoder.up.{block_id}.upsample.conv.bias"
+            ]
+
+        paths = renew_vae_resnet_paths(resnets)
+        meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
+        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+
+    mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
+    num_mid_res_blocks = 2
+    for i in range(1, num_mid_res_blocks + 1):
+        resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
+
+        paths = renew_vae_resnet_paths(resnets)
+        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
+        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+
+    mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
+    paths = renew_vae_attention_paths(mid_attentions)
+    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
+    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+    conv_attn_to_linear(new_checkpoint)
+    return new_checkpoint
+
+
+def create_unet_diffusers_config(v2):
+    """
+    Creates a config for the diffusers based on the config of the LDM model.
+    """
+    # unet_params = original_config.model.params.unet_config.params
+
+    block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT]
+
+    down_block_types = []
+    resolution = 1
+    for i in range(len(block_out_channels)):
+        block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D"
+        down_block_types.append(block_type)
+        if i != len(block_out_channels) - 1:
+            resolution *= 2
+
+    up_block_types = []
+    for i in range(len(block_out_channels)):
+        block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D"
+        up_block_types.append(block_type)
+        resolution //= 2
+
+    config = dict(
+        sample_size=UNET_PARAMS_IMAGE_SIZE,
+        in_channels=UNET_PARAMS_IN_CHANNELS,
+        out_channels=UNET_PARAMS_OUT_CHANNELS,
+        down_block_types=tuple(down_block_types),
+        up_block_types=tuple(up_block_types),
+        block_out_channels=tuple(block_out_channels),
+        layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
+        cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM,
+        attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM,
+    )
+
+    return config
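Assuming the SD v1 constants defined elsewhere in this file take their usual values (`UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]`, `UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]`), the down-block loop above resolves as follows; the constants are inlined here so the sketch runs standalone:

```python
# Standalone trace of the down_block_types loop, with the SD v1 constants
# inlined (these values are assumed, not read from this file at runtime).
channel_mult = [1, 2, 4, 4]          # UNET_PARAMS_CHANNEL_MULT
attention_resolutions = [4, 2, 1]    # UNET_PARAMS_ATTENTION_RESOLUTIONS

down_block_types = []
resolution = 1
for i in range(len(channel_mult)):
    block_type = "CrossAttnDownBlock2D" if resolution in attention_resolutions else "DownBlock2D"
    down_block_types.append(block_type)
    if i != len(channel_mult) - 1:
        resolution *= 2

print(down_block_types)
# ['CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'DownBlock2D']
```

Only the deepest block (resolution 8, i.e. 8x downsampled) loses cross-attention, which matches the SD v1 UNet layout.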
+
+
+def create_vae_diffusers_config():
+    """
+    Creates a config for the diffusers based on the config of the LDM model.
+    """
+    # vae_params = original_config.model.params.first_stage_config.params.ddconfig
+    # _ = original_config.model.params.first_stage_config.params.embed_dim
+    block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT]
+    down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
+    up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
+
+    config = dict(
+        sample_size=VAE_PARAMS_RESOLUTION,
+        in_channels=VAE_PARAMS_IN_CHANNELS,
+        out_channels=VAE_PARAMS_OUT_CH,
+        down_block_types=tuple(down_block_types),
+        up_block_types=tuple(up_block_types),
+        block_out_channels=tuple(block_out_channels),
+        latent_channels=VAE_PARAMS_Z_CHANNELS,
+        layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS,
+    )
+    return config
+
+
+def convert_ldm_clip_checkpoint_v1(checkpoint):
+    keys = list(checkpoint.keys())
+    text_model_dict = {}
+    for key in keys:
+        if key.startswith("cond_stage_model.transformer"):
+            text_model_dict[key[len("cond_stage_model.transformer."):]] = checkpoint[key]
+    return text_model_dict
+
+
+def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
+    # the v2 layout is maddeningly different!
+    def convert_key(key):
+        if not key.startswith("cond_stage_model"):
+            return None
+
+        # common conversion
+        key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.")
+        key = key.replace("cond_stage_model.model.", "text_model.")
+
+        if "resblocks" in key:
+            # resblocks conversion
+            key = key.replace(".resblocks.", ".layers.")
+            if ".ln_" in key:
+                key = key.replace(".ln_", ".layer_norm")
+            elif ".mlp." in key:
+                key = key.replace(".c_fc.", ".fc1.")
+                key = key.replace(".c_proj.", ".fc2.")
+            elif '.attn.out_proj' in key:
+                key = key.replace(".attn.out_proj.", ".self_attn.out_proj.")
+            elif '.attn.in_proj' in key:
+                key = None  # special case, handled separately below
+            else:
+                raise ValueError(f"unexpected key in SD: {key}")
+        elif '.positional_embedding' in key:
+            key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
+        elif '.text_projection' in key:
+            key = None  # apparently unused
+        elif '.logit_scale' in key:
+            key = None  # apparently unused
+        elif '.token_embedding' in key:
+            key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
+        elif '.ln_final' in key:
+            key = key.replace(".ln_final", ".final_layer_norm")
+        return key
+
+    keys = list(checkpoint.keys())
+    new_sd = {}
+    for key in keys:
+        # remove resblocks 23
+        if '.resblocks.23.' in key:
+            continue
+        new_key = convert_key(key)
+        if new_key is None:
+            continue
+        new_sd[new_key] = checkpoint[key]
+
+    # convert the attention weights
+    for key in keys:
+        if '.resblocks.23.' in key:
+            continue
+        if '.resblocks' in key and '.attn.in_proj_' in key:
+            # split into three
+            values = torch.chunk(checkpoint[key], 3)
+
+            key_suffix = ".weight" if "weight" in key else ".bias"
+            key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.")
+            key_pfx = key_pfx.replace("_weight", "")
+            key_pfx = key_pfx.replace("_bias", "")
+            key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.")
+            new_sd[key_pfx + "q_proj" + key_suffix] = values[0]
+            new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
+            new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
+
+    # rename or add position_ids
+    ANOTHER_POSITION_IDS_KEY = "text_model.encoder.text_model.embeddings.position_ids"
+    if ANOTHER_POSITION_IDS_KEY in new_sd:
+        # waifu diffusion v1.4
+        position_ids = new_sd[ANOTHER_POSITION_IDS_KEY]
+        del new_sd[ANOTHER_POSITION_IDS_KEY]
+    else:
+        position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
+
+    new_sd["text_model.embeddings.position_ids"] = position_ids
+    return new_sd
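The `convert_key` logic above is pure string rewriting, so it can be exercised without any checkpoint. A trimmed copy covering only the prefix, layer-norm, and MLP branches (the key names below are illustrative OpenCLIP-style names):

```python
# Trimmed sketch of convert_key: only the resblocks branches for layer norms
# and MLP weights are reproduced here; key names are hypothetical examples.
def convert_key(key):
    if not key.startswith("cond_stage_model"):
        return None
    key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.")
    key = key.replace("cond_stage_model.model.", "text_model.")
    if "resblocks" in key:
        key = key.replace(".resblocks.", ".layers.")
        if ".ln_" in key:
            # ln_1 / ln_2 -> layer_norm1 / layer_norm2
            key = key.replace(".ln_", ".layer_norm")
        elif ".mlp." in key:
            key = key.replace(".c_fc.", ".fc1.")
            key = key.replace(".c_proj.", ".fc2.")
    return key

print(convert_key("cond_stage_model.model.transformer.resblocks.0.ln_1.weight"))
# text_model.encoder.layers.0.layer_norm1.weight
print(convert_key("cond_stage_model.model.transformer.resblocks.5.mlp.c_fc.weight"))
# text_model.encoder.layers.5.mlp.fc1.weight
```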
+
+
+def is_safetensors(path):
+    return os.path.splitext(path)[1].lower() == '.safetensors'
+
+
+def load_checkpoint_with_text_encoder_conversion(ckpt_path):
+    # support models that store the text encoder differently (missing the 'text_model' level)
+    TEXT_ENCODER_KEY_REPLACEMENTS = [
+        ('cond_stage_model.transformer.embeddings.', 'cond_stage_model.transformer.text_model.embeddings.'),
+        ('cond_stage_model.transformer.encoder.', 'cond_stage_model.transformer.text_model.encoder.'),
+        ('cond_stage_model.transformer.final_layer_norm.', 'cond_stage_model.transformer.text_model.final_layer_norm.')
+    ]
+
+    state_dict = read_state_dict(ckpt_path)
+
+    key_reps = []
+    for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
+        for key in state_dict.keys():
+            if key.startswith(rep_from):
+                new_key = rep_to + key[len(rep_from):]
+                key_reps.append((key, new_key))
+
+    for key, new_key in key_reps:
+        state_dict[new_key] = state_dict[key]
+        del state_dict[key]
+
+    return state_dict
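The replacement pass is again just prefix rewriting on key names. A standalone sketch with a one-entry dummy state dict (the `pop`-based update is a simplification of the assign-then-delete above, and the dummy value is arbitrary):

```python
# Sketch of the TEXT_ENCODER_KEY_REPLACEMENTS pass with a dummy state dict.
replacements = [
    ("cond_stage_model.transformer.embeddings.",
     "cond_stage_model.transformer.text_model.embeddings."),
]
state_dict = {"cond_stage_model.transformer.embeddings.token_embedding.weight": 0}

# collect renames first, then apply, so the dict is not mutated while iterating
key_reps = []
for rep_from, rep_to in replacements:
    for key in state_dict.keys():
        if key.startswith(rep_from):
            key_reps.append((key, rep_to + key[len(rep_from):]))

for key, new_key in key_reps:
    state_dict[new_key] = state_dict.pop(key)

print(list(state_dict))
# ['cond_stage_model.transformer.text_model.embeddings.token_embedding.weight']
```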
+
+
+def to_half(sd):
+    for key in sd.keys():
+        if 'model' in key and sd[key].dtype == torch.float:
+            sd[key] = sd[key].half()
+    return sd
+
+
+def savemodel(state_dict, currentmodel, fname, savesets, model_a, metadata={}):
+    from modules import sd_models, shared
+    if "fp16" in savesets:
+        state_dict = to_half(state_dict)
+        pre = "fp16"
+    else:
+        pre = ""
+    ext = ".safetensors" if "safetensors" in savesets else ".ckpt"
+
+    checkpoint_info = sd_models.get_closet_checkpoint_match(model_a)
+    model_a_path = checkpoint_info.filename
+    modeldir = os.path.split(model_a_path)[0]
+
+    if not fname or fname == "":
+        fname = currentmodel.replace(" ", "").replace(",", "_").replace("(", "_").replace(")", "_") + pre + ext
+        if fname[0] == "_":
+            fname = fname[1:]
+    else:
+        fname = fname if ext in fname else fname + pre + ext
+
+    fname = os.path.join(modeldir, fname)
+
+    if len(fname) > 255:
+        fname = fname.replace(ext, "")
+        fname = fname[:240] + ext
+
+    # check if output file already exists
+    if os.path.isfile(fname) and "overwrite" not in savesets:
+        _err_msg = f"Output file ({fname}) already exists and was not saved"
+        print(_err_msg)
+        return _err_msg
+
+    print("Saving...")
+    if ext == ".safetensors":
+        safetensors.torch.save_file(state_dict, fname, metadata=metadata)
+    else:
+        torch.save(state_dict, fname)
+    print("Done!")
+    return "Merged model saved in " + fname
+
+
+def filenamecutter(name, model_a=False):
+    from modules import sd_models
+    if name == "" or name == []:
+        return
+    checkpoint_info = sd_models.get_closet_checkpoint_match(name)
+    name = os.path.splitext(checkpoint_info.filename)[0]
+
+    if not model_a:
+        name = os.path.basename(name)
+    return name
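The default-filename branch of `savemodel` can be traced with a plain string. This sketch inlines that logic (the merged-model name is made up for illustration):

```python
# Sketch of the default output-filename construction in savemodel
# (hypothetical currentmodel string).
def default_fname(currentmodel, pre, ext):
    fname = currentmodel.replace(" ", "").replace(",", "_").replace("(", "_").replace(")", "_") + pre + ext
    if fname[0] == "_":
        fname = fname[1:]
    return fname

print(default_fname("modelA x 0.5 + modelB x 0.5", "fp16", ".safetensors"))
# modelAx0.5+modelBx0.5fp16.safetensors
```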
+
+
+# TODO: behavior with a specified dtype is questionable and needs checking;
+# whether the text_encoder can be built in the specified dtype is unverified
+def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
+    import diffusers
+    print("diffusers version : ", diffusers.__version__)
+    state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
+    if dtype is not None:
+        for k, v in state_dict.items():
+            if type(v) is torch.Tensor:
+                state_dict[k] = v.to(dtype)
+
+    # Convert the UNet2DConditionModel model.
+    unet_config = create_unet_diffusers_config(v2)
+    converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config)
+
+    unet = diffusers.UNet2DConditionModel(**unet_config)
+    info = unet.load_state_dict(converted_unet_checkpoint)
+    print("loading u-net:", info)
+
+    # Convert the VAE model.
+    vae_config = create_vae_diffusers_config()
+    converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config)
+
+    vae = diffusers.AutoencoderKL(**vae_config)
+    info = vae.load_state_dict(converted_vae_checkpoint)
+    print("loading vae:", info)
+
+    # convert text_model
+    if v2:
+        converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77)
+        cfg = CLIPTextConfig(
+            vocab_size=49408,
+            hidden_size=1024,
+            intermediate_size=4096,
+            num_hidden_layers=23,
+            num_attention_heads=16,
+            max_position_embeddings=77,
+            hidden_act="gelu",
+            layer_norm_eps=1e-05,
+            dropout=0.0,
+            attention_dropout=0.0,
+            initializer_range=0.02,
+            initializer_factor=1.0,
+            pad_token_id=1,
+            bos_token_id=0,
+            eos_token_id=2,
+            model_type="clip_text_model",
+            projection_dim=512,
+            torch_dtype="float32",
+            transformers_version="4.25.0.dev0",
+        )
+        text_model = CLIPTextModel._from_config(cfg)
+        info = text_model.load_state_dict(converted_text_encoder_checkpoint)
+    else:
+        converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
+        text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
+        info = text_model.load_state_dict(converted_text_encoder_checkpoint)
+    print("loading text encoder:", info)
+
+    return text_model, vae, unet
+
+
+def usemodelgen(theta_0, model_a, model_name):
+    from modules import lowvram, devices, sd_hijack, shared, sd_vae
+    sd_hijack.model_hijack.undo_hijack(shared.sd_model)
+
+    model = shared.sd_model
+    model.load_state_dict(theta_0, strict=False)
+    del theta_0
+    if shared.cmd_opts.opt_channelslast:
+        model.to(memory_format=torch.channels_last)
+
+    if not shared.cmd_opts.no_half:
+        vae = model.first_stage_model
+
+        # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
+        if shared.cmd_opts.no_half_vae:
+            model.first_stage_model = None
+
+        model.half()
+        model.first_stage_model = vae
+
+    devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
+    devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
+    devices.dtype_unet = model.model.diffusion_model.dtype
+
+    if hasattr(shared.cmd_opts, "upcast_sampling"):
+        devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
+    else:
+        devices.unet_needs_upcast = devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
+
+    model.first_stage_model.to(devices.dtype_vae)
+    sd_hijack.model_hijack.hijack(model)
+
+    model.logvar = shared.sd_model.logvar.to(devices.device)
+
+    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
+        setup_for_low_vram_s(model, shared.cmd_opts.medvram)
+    else:
+        model.to(shared.device)
+
+    model.eval()
+
+    shared.sd_model = model
+    try:
+        sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)
+    except Exception:
+        pass
+    # shared.sd_model.sd_checkpoint_info.model_name = model_name
+
+    def _setvae():
+        sd_vae.delete_base_vae()
+        sd_vae.clear_loaded_vae()
+        vae_file, vae_source = sd_vae.resolve_vae(model_a)
+        sd_vae.load_vae(shared.sd_model, vae_file, vae_source)
+
+    try:
+        _setvae()
+    except Exception:
+        print("ERROR: setting VAE skipped")
+
+
+import torch
+from modules import devices
+
+module_in_gpu = None
+cpu = torch.device("cpu")
+
+
+def send_everything_to_cpu():
+    global module_in_gpu
+
+    if module_in_gpu is not None:
+        module_in_gpu.to(cpu)
+
+    module_in_gpu = None
+
+
+def setup_for_low_vram_s(sd_model, use_medvram):
+    parents = {}
+
+    def send_me_to_gpu(module, _):
+        """send this module to GPU; send whatever tracked module was previously in GPU to CPU;
+        we add this as forward_pre_hook to a lot of modules, and this way all but one of them will
+        be in CPU
+        """
+        global module_in_gpu
+
+        module = parents.get(module, module)
+
+        if module_in_gpu == module:
+            return
+
+        if module_in_gpu is not None:
+            module_in_gpu.to(cpu)
+
+        module.to(devices.device)
+        module_in_gpu = module
+
+    # see below for register_forward_pre_hook;
+    # first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is
+    # useless here, and we just replace those methods
+
+    first_stage_model = sd_model.first_stage_model
+    first_stage_model_encode = sd_model.first_stage_model.encode
+    first_stage_model_decode = sd_model.first_stage_model.decode
+
+    def first_stage_model_encode_wrap(x):
+        send_me_to_gpu(first_stage_model, None)
+        return first_stage_model_encode(x)
+
+    def first_stage_model_decode_wrap(z):
+        send_me_to_gpu(first_stage_model, None)
+        return first_stage_model_decode(z)
+
+    # for SD1, cond_stage_model is CLIP and its NN is in the transformer field, but for SD2, it's open clip, and it's in the model field
+    if hasattr(sd_model.cond_stage_model, 'model'):
+        sd_model.cond_stage_model.transformer = sd_model.cond_stage_model.model
+
+    # remove the four big modules (cond, first_stage, depth if applicable, and unet) from the model, then
+    # send the model to GPU. Then put the modules back. The modules will be in CPU.
+    stored = sd_model.first_stage_model, getattr(sd_model, 'depth_model', None), sd_model.model
+    sd_model.first_stage_model, sd_model.depth_model, sd_model.model = None, None, None
+    sd_model.to(devices.device)
+    sd_model.first_stage_model, sd_model.depth_model, sd_model.model = stored
+
+    # register hooks for the first three models
+    sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
+    sd_model.first_stage_model.encode = first_stage_model_encode_wrap
+    sd_model.first_stage_model.decode = first_stage_model_decode_wrap
+    if sd_model.depth_model:
+        sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu)
+
+    if hasattr(sd_model.cond_stage_model, 'model'):
+        sd_model.cond_stage_model.model = sd_model.cond_stage_model.transformer
+        del sd_model.cond_stage_model.transformer
+
+    if use_medvram:
+        sd_model.model.register_forward_pre_hook(send_me_to_gpu)
+    else:
+        diff_model = sd_model.model.diffusion_model
+
+        # the third remaining model is still too big for 4 GB, so we also do the same for its submodules
+        # so that only one of them is in GPU at a time
+        stored = diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed
+        diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = None, None, None, None
+        sd_model.model.to(devices.device)
+        diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored
+
+        # install hooks for bits of the third model
+        diff_model.time_embed.register_forward_pre_hook(send_me_to_gpu)
+        for block in diff_model.input_blocks:
+            block.register_forward_pre_hook(send_me_to_gpu)
+        diff_model.middle_block.register_forward_pre_hook(send_me_to_gpu)
+        for block in diff_model.output_blocks:
+            block.register_forward_pre_hook(send_me_to_gpu)
microsoftexcel-supermerger/scripts/mergers/pluslora.py ADDED
@@ -0,0 +1,1298 @@
+ import re
+ import torch
+ import math
+ import os
+ import gc
+ import gradio as gr
+ import modules.shared as shared
+ from safetensors.torch import load_file, save_file
+ from typing import List
+ from tqdm import tqdm
+ from modules import sd_models, scripts
+ from scripts.mergers.model_util import load_models_from_stable_diffusion_checkpoint, filenamecutter, savemodel
+ from modules.ui import create_refresh_button
+
+ def on_ui_tabs():
+     import lora
+     sml_path_root = scripts.basedir()
+     LWEIGHTSPRESETS = "\
+ NONE:0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0\n\
+ ALL:1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1\n\
+ INS:1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0\n\
+ IND:1,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0\n\
+ INALL:1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0\n\
+ MIDD:1,0,0,0,1,1,1,1,1,1,1,1,0,0,0,0,0\n\
+ OUTD:1,0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0\n\
+ OUTS:1,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1\n\
+ OUTALL:1,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1\n\
+ ALL0.5:0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5"
+     sml_filepath = os.path.join(sml_path_root, "scripts", "lbwpresets.txt")
+     sml_lbwpresets = ""
+     try:
+         with open(sml_filepath, encoding="utf-8") as f:
+             sml_lbwpresets = f.read()
+     except OSError:
+         sml_lbwpresets = LWEIGHTSPRESETS
+
+     with gr.Blocks(analytics_enabled=False):
+         sml_submit_result = gr.Textbox(label="Message")
+         with gr.Row().style(equal_height=False):
+             sml_cpmerge = gr.Button(elem_id="model_merger_merge", value="Merge to Checkpoint", variant='primary')
+             sml_makelora = gr.Button(elem_id="model_merger_merge", value="Make LoRA (alpha * A - beta * B)", variant='primary')
+             sml_model_a = gr.Dropdown(sd_models.checkpoint_tiles(), elem_id="model_converter_model_name", label="Checkpoint A", interactive=True)
+             create_refresh_button(sml_model_a, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles()}, "refresh_checkpoint_Z")
+             sml_model_b = gr.Dropdown(sd_models.checkpoint_tiles(), elem_id="model_converter_model_name", label="Checkpoint B", interactive=True)
+             create_refresh_button(sml_model_b, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles()}, "refresh_checkpoint_Z")
+         with gr.Row().style(equal_height=False):
+             sml_merge = gr.Button(elem_id="model_merger_merge", value="Merge LoRAs", variant='primary')
+             alpha = gr.Slider(label="alpha", minimum=-1.0, maximum=2, step=0.001, value=1)
+             beta = gr.Slider(label="beta", minimum=-1.0, maximum=2, step=0.001, value=1)
+         with gr.Row().style(equal_height=False):
+             sml_settings = gr.CheckboxGroup(["same to Strength", "overwrite"], label="settings")
+             precision = gr.Radio(label="save precision", choices=["float", "fp16", "bf16"], value="fp16", type="value")
+         with gr.Row().style(equal_height=False):
+             sml_dim = gr.Radio(label="remake dimension", choices=["no", "auto", *[2 ** (x + 2) for x in range(9)]], value="no", type="value")
+             sml_filename = gr.Textbox(label="filename (optional)", lines=1, visible=True, interactive=True)
+         sml_loranames = gr.Textbox(label='LoRAname1:ratio1:Blocks1,LoRAname2:ratio2:Blocks2,... (":Blocks" is optional)', lines=1, value="", visible=True)
+         sml_dims = gr.CheckboxGroup(label="limit dimension", choices=[], value=[], type="value", interactive=True, visible=False)
+         with gr.Row().style(equal_height=False):
+             sml_calcdim = gr.Button(elem_id="calcloras", value="calculate dimension of LoRAs (may take a few minutes if there are many LoRAs)", variant='primary')
+             sml_update = gr.Button(elem_id="calcloras", value="update list", variant='primary')
+         sml_loras = gr.CheckboxGroup(label="Lora", choices=[x[0] for x in lora.available_loras.items()], type="value", interactive=True, visible=True)
+         sml_loraratios = gr.TextArea(label="", value=sml_lbwpresets, visible=True, interactive=True)
+
+         sml_merge.click(
+             fn=lmerge,
+             inputs=[sml_loranames, sml_loraratios, sml_settings, sml_filename, sml_dim, precision],
+             outputs=[sml_submit_result]
+         )
+
+         sml_makelora.click(
+             fn=makelora,
+             inputs=[sml_model_a, sml_model_b, sml_dim, sml_filename, sml_settings, alpha, beta, precision],
+             outputs=[sml_submit_result]
+         )
+
+         sml_cpmerge.click(
+             fn=pluslora,
+             inputs=[sml_loranames, sml_loraratios, sml_settings, sml_filename, sml_model_a, precision],
+             outputs=[sml_submit_result]
+         )
+         llist = {}
+         dlist = []
+         dn = []
+
+         def updateloras():
+             lora.list_available_loras()
+             for n in lora.available_loras.items():
+                 if n[0] not in llist: llist[n[0]] = ""
+             return gr.update(choices=[f"{x[0]}({x[1]})" for x in llist.items()])
+
+         sml_update.click(fn=updateloras, outputs=[sml_loras])
+
+         def calculatedim():
+             print("listing dimensions...")
+             for n in tqdm(lora.available_loras.items()):
+                 if n[0] in llist:
+                     if llist[n[0]] != "": continue
+                     c_lora = lora.available_loras.get(n[0], None)
+                     d, t = dimgetter(c_lora.filename)
+                     if t == "LoCon": d = f"{d}:{t}"
+                     if d not in dlist:
+                         if type(d) == int: dlist.append(d)
+                         elif d not in dn: dn.append(d)
+                     llist[n[0]] = d
+             dlist.sort()
+             return gr.update(choices=[f"{x[0]}({x[1]})" for x in llist.items()], value=[]), gr.update(visible=True, choices=[x for x in (dlist + dn)])
+
+         sml_calcdim.click(
+             fn=calculatedim,
+             inputs=[],
+             outputs=[sml_loras, sml_dims]
+         )
+
+         def dimselector(dims):
+             if dims == []: return gr.update(choices=[f"{x[0]}({x[1]})" for x in llist.items()])
+             rl = []
+             for d in dims:
+                 for i in llist.items():
+                     if d == i[1]: rl.append(f"{i[0]}({i[1]})")
+             return gr.update(choices=[l for l in rl], value=[])
+
+         def llister(names):
+             if names == []: return ""
+             for i, n in enumerate(names):
+                 if "(" in n: names[i] = n[:n.rfind("(")]
+             return ":1.0,".join(names) + ":1.0"
+
+         sml_loras.change(fn=llister, inputs=[sml_loras], outputs=[sml_loranames])
+         sml_dims.change(fn=dimselector, inputs=[sml_dims], outputs=[sml_loras])
+
+ def makelora(model_a, model_b, dim, saveto, settings, alpha, beta, precision):
+     print("make LoRA start")
+     if model_a == "" or model_b == "":
+         return "ERROR: No model Selected"
+     gc.collect()
+
+     if saveto == "": saveto = makeloraname(model_a, model_b)
+     if ".safetensors" not in saveto: saveto += ".safetensors"
+     saveto = os.path.join(shared.cmd_opts.lora_dir, saveto)
+
+     dim = 128 if type(dim) != int else int(dim)
+     if os.path.isfile(saveto) and "overwrite" not in settings:
+         _err_msg = f"Output file ({saveto}) already exists; not saved"
+         print(_err_msg)
+         return _err_msg
+
+     svd(fullpathfromname(model_a), fullpathfromname(model_b), False, dim, precision, saveto, alpha, beta)
+     return f"saved to {saveto}"
+
+ def lmerge(loranames, loraratioss, settings, filename, dim, precision):
+     import lora
+     loras_on_disk = [lora.available_loras.get(name, None) for name in loranames]
+     if any(x is None for x in loras_on_disk):
+         lora.list_available_loras()
+         loras_on_disk = [lora.available_loras.get(name, None) for name in loranames]
+
+     lnames = [loranames] if "," not in loranames else loranames.split(",")
+
+     for i, n in enumerate(lnames):
+         lnames[i] = n.split(":")
+
+     loraratios = loraratioss.splitlines()
+     ldict = {}
+
+     for l in loraratios:
+         if ":" not in l or not (l.count(",") == 16 or l.count(",") == 25): continue
+         ldict[l.split(":")[0]] = l.split(":")[1]
+
+     ln = []
+     lr = []
+     ld = []
+     lt = []
+     dmax = 1
+
+     for n in lnames:
+         if len(n) == 3 and n[2].strip() in ldict:
+             ratio = [float(r) * float(n[1]) for r in ldict[n[2]].split(",")]
+         else:
+             ratio = [float(n[1])] * 17
+         c_lora = lora.available_loras.get(n[0], None)
+         ln.append(c_lora.filename)
+         lr.append(ratio)
+         d, t = dimgetter(c_lora.filename)
+         lt.append(t)
+         ld.append(d)
+         if d != "LyCORIS" and d > dmax: dmax = d
+
+     if filename == "": filename = loranames.replace(",", "+").replace(":", "_")
+     if ".safetensors" not in filename: filename += ".safetensors"
+     filename = os.path.join(shared.cmd_opts.lora_dir, filename)
+
+     dim = int(dim) if dim != "no" and dim != "auto" else 0
+
+     if "LyCORIS" in ld or "LoCon" in lt:
+         if len(ld) != 1:
+             return "multiple merge of LyCORIS is not supported"
+         sd = lycomerge(ln[0], lr[0])
+     elif dim > 0:
+         print("change dimension to", dim)
+         sd = merge_lora_models_dim(ln, lr, dim, settings)
+     elif "auto" in settings and ld.count(ld[0]) != len(ld):
+         print("change dimension to", dmax)
+         sd = merge_lora_models_dim(ln, lr, dmax, settings)
+     else:
+         sd = merge_lora_models(ln, lr, settings)
+
+     if os.path.isfile(filename) and "overwrite" not in settings:
+         _err_msg = f"Output file ({filename}) already exists; not saved"
+         print(_err_msg)
+         return _err_msg
+
+     save_to_file(filename, sd, sd, str_to_dtype(precision))
+     return "saved : " + filename
+
+ def pluslora(lnames, loraratios, settings, output, model, precision):
+     if model == []:
+         return "ERROR: No model Selected"
+     if lnames == "":
+         return "ERROR: No LoRA Selected"
+
+     print("plus LoRA start")
+     import lora
+     lnames = [lnames] if "," not in lnames else lnames.split(",")
+
+     for i, n in enumerate(lnames):
+         lnames[i] = n.split(":")
+
+     loraratios = loraratios.splitlines()
+     ldict = {}
+
+     for l in loraratios:
+         if ":" not in l or not (l.count(",") == 16 or l.count(",") == 25): continue
+         ldict[l.split(":")[0].strip()] = l.split(":")[1]
+
+     names = []
+     filenames = []
+     lweis = []
+
+     for n in lnames:
+         if len(n) == 3 and n[2].strip() in ldict:
+             ratio = [float(r) * float(n[1]) for r in ldict[n[2]].split(",")]
+         else:
+             ratio = [float(n[1])] * 17
+         c_lora = lora.available_loras.get(n[0], None)
+         names.append(n[0])
+         filenames.append(c_lora.filename)
+         _, t = dimgetter(c_lora.filename)
+         if "LyCORIS" in t: return "LyCORIS merge is not supported"
+         lweis.append(ratio)
+
+     modeln = filenamecutter(model, True)
+     dname = modeln
+     for n in names:
+         dname = dname + "+" + n
+
+     checkpoint_info = sd_models.get_closet_checkpoint_match(model)
+     print(f"Loading {model}")
+     theta_0 = sd_models.read_state_dict(checkpoint_info.filename, "cpu")
+
+     keychanger = {}
+     for key in theta_0.keys():
+         if "model" in key:
+             skey = key.replace(".", "_").replace("_weight", "")
+             keychanger[skey.split("model_", 1)[1]] = key
+
+     for name, filename, lwei in zip(names, filenames, lweis):
+         print(f"loading: {name}")
+         lora_sd = load_state_dict(filename, torch.float)
+
+         print("merging...", lwei)
+         for key in lora_sd.keys():
+             ratio = 1
+
+             fullkey = convert_diffusers_name_to_compvis(key)
+
+             for i, block in enumerate(LORABLOCKS):
+                 if block in fullkey:
+                     ratio = lwei[i]
+
+             msd_key, lora_key = fullkey.split(".", 1)
+
+             if "lora_down" in key:
+                 up_key = key.replace("lora_down", "lora_up")
+                 alpha_key = key[:key.index("lora_down")] + 'alpha'
+
+                 down_weight = lora_sd[key].to(device="cpu")
+                 up_weight = lora_sd[up_key].to(device="cpu")
+
+                 dim = down_weight.size()[0]
+                 alpha = lora_sd.get(alpha_key, dim)
+                 scale = alpha / dim
+                 # W <- W + ratio * (U @ D) * scale
+                 weight = theta_0[keychanger[msd_key]].to(device="cpu")
+
+                 if len(down_weight.size()) != 4:
+                     # linear
+                     weight = weight + ratio * (up_weight @ down_weight) * scale
+                 else:
+                     # conv2d
+                     weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
+                                                ).unsqueeze(2).unsqueeze(3) * scale
+                 theta_0[keychanger[msd_key]] = torch.nn.Parameter(weight)
+     # usemodelgen(theta_0,model)
+     settings.append(precision)
+     result = savemodel(theta_0, dname, output, settings, model)
+     del theta_0
+     gc.collect()
+     return result
+
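The per-key update above folds each LoRA pair into the checkpoint as `W <- W + ratio * (up @ down) * (alpha / dim)`. A small NumPy check of that formula (illustrative only; the dimensions, seed, and names here are made up, and the real code uses torch tensors):

```python
import numpy as np

# Numeric check of the linear-layer update applied in pluslora():
#   W <- W + ratio * (up @ down) * (alpha / dim)
rng = np.random.default_rng(1)
out_dim, in_dim, dim = 8, 6, 4          # dim = LoRA rank
W = rng.standard_normal((out_dim, in_dim))
up = rng.standard_normal((out_dim, dim))
down = rng.standard_normal((dim, in_dim))
ratio, alpha = 0.5, 2.0

merged = W + ratio * (up @ down) * (alpha / dim)

# the delta added to W is exactly the scaled low-rank product
delta = merged - W
print(delta.shape)
```

With `ratio = 0.5` and `alpha / dim = 0.5`, the delta is `0.25 * (up @ down)`, i.e. the rank-`dim` correction scaled by both the user ratio and the LoRA's own alpha scaling.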
+ def save_to_file(file_name, model, state_dict, dtype):
+     if dtype is not None:
+         for key in list(state_dict.keys()):
+             if type(state_dict[key]) == torch.Tensor:
+                 state_dict[key] = state_dict[key].to(dtype)
+
+     if os.path.splitext(file_name)[1] == '.safetensors':
+         save_file(model, file_name)
+     else:
+         torch.save(model, file_name)
+
+ re_digits = re.compile(r"\d+")
+
+ re_unet_down_blocks = re.compile(r"lora_unet_down_blocks_(\d+)_attentions_(\d+)_(.+)")
+ re_unet_mid_blocks = re.compile(r"lora_unet_mid_block_attentions_(\d+)_(.+)")
+ re_unet_up_blocks = re.compile(r"lora_unet_up_blocks_(\d+)_attentions_(\d+)_(.+)")
+
+ re_unet_down_blocks_res = re.compile(r"lora_unet_down_blocks_(\d+)_resnets_(\d+)_(.+)")
+ re_unet_mid_blocks_res = re.compile(r"lora_unet_mid_block_resnets_(\d+)_(.+)")
+ re_unet_up_blocks_res = re.compile(r"lora_unet_up_blocks_(\d+)_resnets_(\d+)_(.+)")
+
+ re_unet_downsample = re.compile(r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv(.+)")
+ re_unet_upsample = re.compile(r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv(.+)")
+
+ re_text_block = re.compile(r"lora_te_text_model_encoder_layers_(\d+)_(.+)")
+
+
+ def convert_diffusers_name_to_compvis(key):
+     def match(match_list, regex):
+         r = re.match(regex, key)
+         if not r:
+             return False
+
+         match_list.clear()
+         match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()])
+         return True
+
+     m = []
+
+     if match(m, re_unet_down_blocks):
+         return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[1]}_1_{m[2]}"
+
+     if match(m, re_unet_mid_blocks):
+         return f"diffusion_model_middle_block_1_{m[1]}"
+
+     if match(m, re_unet_up_blocks):
+         return f"diffusion_model_output_blocks_{m[0] * 3 + m[1]}_1_{m[2]}"
+
+     if match(m, re_unet_down_blocks_res):
+         block = f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[1]}_0_"
+         if m[2].startswith('conv1'):
+             return f"{block}in_layers_2{m[2][len('conv1'):]}"
+         elif m[2].startswith('conv2'):
+             return f"{block}out_layers_3{m[2][len('conv2'):]}"
+         elif m[2].startswith('time_emb_proj'):
+             return f"{block}emb_layers_1{m[2][len('time_emb_proj'):]}"
+         elif m[2].startswith('conv_shortcut'):
+             return f"{block}skip_connection{m[2][len('conv_shortcut'):]}"
+
+     if match(m, re_unet_mid_blocks_res):
+         block = f"diffusion_model_middle_block_{m[0]*2}_"
+         if m[1].startswith('conv1'):
+             return f"{block}in_layers_2{m[1][len('conv1'):]}"
+         elif m[1].startswith('conv2'):
+             return f"{block}out_layers_3{m[1][len('conv2'):]}"
+         elif m[1].startswith('time_emb_proj'):
+             return f"{block}emb_layers_1{m[1][len('time_emb_proj'):]}"
+         elif m[1].startswith('conv_shortcut'):
+             return f"{block}skip_connection{m[1][len('conv_shortcut'):]}"
+
+     if match(m, re_unet_up_blocks_res):
+         block = f"diffusion_model_output_blocks_{m[0] * 3 + m[1]}_0_"
+         if m[2].startswith('conv1'):
+             return f"{block}in_layers_2{m[2][len('conv1'):]}"
+         elif m[2].startswith('conv2'):
+             return f"{block}out_layers_3{m[2][len('conv2'):]}"
+         elif m[2].startswith('time_emb_proj'):
+             return f"{block}emb_layers_1{m[2][len('time_emb_proj'):]}"
+         elif m[2].startswith('conv_shortcut'):
+             return f"{block}skip_connection{m[2][len('conv_shortcut'):]}"
+
+     if match(m, re_unet_downsample):
+         return f"diffusion_model_input_blocks_{m[0]*3+3}_0_op{m[1]}"
+
+     if match(m, re_unet_upsample):
+         return f"diffusion_model_output_blocks_{m[0]*3 + 2}_{1+(m[0]!=0)}_conv{m[1]}"
+
+     if match(m, re_text_block):
+         return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"
+
+     return key
+
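To see what one of these mappings does concretely, here is a standalone re-implementation of just the down-blocks attention case (the regex is copied from above; the helper name is made up for this sketch): a diffusers key in `down_blocks` lands at compvis index `1 + block * 3 + attention`.

```python
import re

# Standalone check of one mapping done by convert_diffusers_name_to_compvis:
# diffusers down_blocks attention keys map to compvis input_blocks.
re_unet_down_blocks = re.compile(r"lora_unet_down_blocks_(\d+)_attentions_(\d+)_(.+)")

def down_block_to_compvis(key):
    r = re_unet_down_blocks.match(key)
    if not r:
        return key  # unmatched keys pass through unchanged, as in the original
    block, attn, rest = int(r.group(1)), int(r.group(2)), r.group(3)
    return f"diffusion_model_input_blocks_{1 + block * 3 + attn}_1_{rest}"

converted = down_block_to_compvis(
    "lora_unet_down_blocks_1_attentions_0_proj_in.lora_down.weight")
print(converted)
```

Block 1, attention 0 gives index `1 + 1*3 + 0 = 4`, so the key is rewritten under `diffusion_model_input_blocks_4_1_`.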
+ CLAMP_QUANTILE = 0.99
+ MIN_DIFF = 1e-6
+
+ def str_to_dtype(p):
+     if p == 'float':
+         return torch.float
+     if p == 'fp16':
+         return torch.float16
+     if p == 'bf16':
+         return torch.bfloat16
+     return None
+
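The `svd()` function below extracts a LoRA by truncating the SVD of each weight difference to the requested rank. The core step, mirrored here in NumPy rather than torch (illustrative sketch; sizes and seed are arbitrary), keeps the top `rank` singular values and folds `S` into `U`:

```python
import numpy as np

# Sketch of the rank-truncation step svd() performs with torch.linalg.svd:
# keep the top `rank` singular values, fold diag(S) into U, and verify the
# low-rank product reproduces the original matrix when its rank <= `rank`.
rng = np.random.default_rng(0)
rank = 4
# construct a matrix that is exactly rank 4, so truncation is lossless
diff = rng.standard_normal((16, 4)) @ rng.standard_normal((4, 16))

U, S, Vh = np.linalg.svd(diff, full_matrices=False)
U = U[:, :rank] @ np.diag(S[:rank])   # U <- U * diag(S), as in the code below
Vh = Vh[:rank, :]

err = np.abs(U @ Vh - diff).max()
print(err)
```

For a real weight diff the truncation is lossy, and the code additionally clamps `U` and `Vh` to the `CLAMP_QUANTILE` quantile of their values to suppress outliers.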
+ def svd(model_a, model_b, v2, dim, save_precision, save_to, alpha, beta):
+     save_dtype = str_to_dtype(save_precision)
+
+     if model_a == model_b:
+         text_encoder_t, _, unet_t = load_models_from_stable_diffusion_checkpoint(v2, model_a)
+         text_encoder_o, unet_o = text_encoder_t, unet_t
+     else:
+         print(f"loading SD model : {model_b}")
+         text_encoder_o, _, unet_o = load_models_from_stable_diffusion_checkpoint(v2, model_b)
+
+         print(f"loading SD model : {model_a}")
+         text_encoder_t, _, unet_t = load_models_from_stable_diffusion_checkpoint(v2, model_a)
+
+     # create LoRA network to extract weights: use dim (rank) as alpha
+     lora_network_o = create_network(1.0, dim, dim, None, text_encoder_o, unet_o)
+     lora_network_t = create_network(1.0, dim, dim, None, text_encoder_t, unet_t)
+     assert len(lora_network_o.text_encoder_loras) == len(
+         lora_network_t.text_encoder_loras), "model version is different (SD1.x vs SD2.x)"
+
+     # get diffs
+     diffs = {}
+     text_encoder_different = False
+     for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.text_encoder_loras, lora_network_t.text_encoder_loras)):
+         lora_name = lora_o.lora_name
+         module_o = lora_o.org_module
+         module_t = lora_t.org_module
+         diff = alpha * module_t.weight - beta * module_o.weight
+
+         # the text encoders might be identical
+         if torch.max(torch.abs(diff)) > MIN_DIFF:
+             text_encoder_different = True
+
+         diffs[lora_name] = diff.float()
+
+     if not text_encoder_different:
+         print("Text encoder is the same. Extracting U-Net only.")
+         lora_network_o.text_encoder_loras = []
+         diffs = {}
+
+     for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.unet_loras, lora_network_t.unet_loras)):
+         lora_name = lora_o.lora_name
+         module_o = lora_o.org_module
+         module_t = lora_t.org_module
+         diff = alpha * module_t.weight - beta * module_o.weight
+
+         diffs[lora_name] = diff.float()
+
+     # make LoRA with SVD
+     print("calculating by svd")
+     rank = dim
+     lora_weights = {}
+     with torch.no_grad():
+         for lora_name, mat in tqdm(list(diffs.items())):
+             conv2d = (len(mat.size()) == 4)
+             if conv2d:
+                 mat = mat.squeeze()
+
+             U, S, Vh = torch.linalg.svd(mat)
+
+             U = U[:, :rank]
+             S = S[:rank]
+             U = U @ torch.diag(S)
+
+             Vh = Vh[:rank, :]
+
+             dist = torch.cat([U.flatten(), Vh.flatten()])
+             hi_val = torch.quantile(dist, CLAMP_QUANTILE)
+             low_val = -hi_val
+
+             U = U.clamp(low_val, hi_val)
+             Vh = Vh.clamp(low_val, hi_val)
+
+             lora_weights[lora_name] = (U, Vh)
+
+     # make state dict for LoRA
+     lora_network_o.apply_to(text_encoder_o, unet_o, text_encoder_different, True)  # to make state dict
+     lora_sd = lora_network_o.state_dict()
+     print(f"LoRA has {len(lora_sd)} weights.")
+
+     for key in list(lora_sd.keys()):
+         if "alpha" in key:
+             continue
+
+         lora_name = key.split('.')[0]
+         i = 0 if "lora_up" in key else 1
+
+         weights = lora_weights[lora_name][i]
+         if len(lora_sd[key].size()) == 4:
+             weights = weights.unsqueeze(2).unsqueeze(3)
+
+         assert weights.size() == lora_sd[key].size(), f"size mismatch: {key}"
+         lora_sd[key] = weights
+
+     # load state dict to LoRA and save it
+     info = lora_network_o.load_state_dict(lora_sd)
+     print(f"Loading extracted LoRA weights: {info}")
+
+     dir_name = os.path.dirname(save_to)
+     if dir_name and not os.path.exists(dir_name):
+         os.makedirs(dir_name, exist_ok=True)
+
+     # minimum metadata
+     metadata = {"ss_network_dim": str(dim), "ss_network_alpha": str(dim)}
+
+     lora_network_o.save_weights(save_to, save_dtype, metadata)
+     print(f"LoRA weights are saved to: {save_to}")
+     return save_to
+
+ def load_state_dict(file_name, dtype):
+     if os.path.splitext(file_name)[1] == '.safetensors':
+         sd = load_file(file_name)
+     else:
+         sd = torch.load(file_name, map_location='cpu')
+     for key in list(sd.keys()):
+         if type(sd[key]) == torch.Tensor:
+             sd[key] = sd[key].to(dtype)
+     return sd
+
+ def dimgetter(filename):
+     lora_sd = load_state_dict(filename, torch.float)
+     alpha = None
+     dim = None
+     ltype = None
+
+     if "lora_unet_down_blocks_0_resnets_0_conv1.lora_down.weight" in lora_sd.keys():
+         ltype = "LoCon"
+
+     for key, value in lora_sd.items():
+         if alpha is None and 'alpha' in key:
+             alpha = value
+         if dim is None and 'lora_down' in key and len(value.size()) == 2:
+             dim = value.size()[0]
+         if "hada_" in key:
+             dim, ltype = "LyCORIS", "LyCORIS"
+         if alpha is not None and dim is not None:
+             break
+
+     if alpha is None:
+         alpha = dim
+     if ltype is None: ltype = "LoRA"
+     if dim:
+         return dim, ltype
+     else:
+         return "unknown", "unknown"
+
+ def blockfromkey(key):
+     fullkey = convert_diffusers_name_to_compvis(key)
+     for i, n in enumerate(LORABLOCKS):
+         if n in fullkey: return i
+     return 0
+
+ def merge_lora_models_dim(models, ratios, new_rank, sets):
+     merged_sd = {}
+     sign = 1  # "fugou" in the original: carries the sign of a negative ratio
+     for model, lratios in zip(models, ratios):
+         merge_dtype = torch.float
+         lora_sd = load_state_dict(model, merge_dtype)
+
+         # merge
+         print(f"merging {model}: {lratios}")
+         for key in tqdm(list(lora_sd.keys())):
+             if 'lora_down' not in key:
+                 continue
+             lora_module_name = key[:key.rfind(".lora_down")]
+
+             down_weight = lora_sd[key]
+             network_dim = down_weight.size()[0]
+
+             up_weight = lora_sd[lora_module_name + '.lora_up.weight']
+             alpha = lora_sd.get(lora_module_name + '.alpha', network_dim)
+
+             in_dim = down_weight.size()[1]
+             out_dim = up_weight.size()[0]
+             conv2d = len(down_weight.size()) == 4
+
+             # make original weight if it does not exist
+             if lora_module_name not in merged_sd:
+                 weight = torch.zeros((out_dim, in_dim, 1, 1) if conv2d else (out_dim, in_dim), dtype=merge_dtype)
+             else:
+                 weight = merged_sd[lora_module_name]
+
+             ratio = lratios[blockfromkey(key)]
+             if "same to Strength" in sets:
+                 ratio, sign = (ratio ** 0.5, 1) if ratio > 0 else (abs(ratio) ** 0.5, -1)
+
+             # W <- W + ratio * (U @ D) * scale
+             scale = (alpha / network_dim)
+             if not conv2d:  # linear
+                 weight = weight + ratio * (up_weight @ down_weight) * scale * sign
+             else:  # conv2d
+                 weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
+                                            ).unsqueeze(2).unsqueeze(3) * scale * sign
+
+             merged_sd[lora_module_name] = weight
+
+     # extract a new LoRA from the merged weights
+     print("extract new lora...")
+     merged_lora_sd = {}
+     with torch.no_grad():
+         for lora_module_name, mat in tqdm(list(merged_sd.items())):
+             conv2d = (len(mat.size()) == 4)
+             if conv2d:
+                 mat = mat.squeeze()
+
+             U, S, Vh = torch.linalg.svd(mat)
+
+             U = U[:, :new_rank]
+             S = S[:new_rank]
+             U = U @ torch.diag(S)
+
+             Vh = Vh[:new_rank, :]
+
+             dist = torch.cat([U.flatten(), Vh.flatten()])
+             hi_val = torch.quantile(dist, CLAMP_QUANTILE)
+             low_val = -hi_val
+
+             U = U.clamp(low_val, hi_val)
+             Vh = Vh.clamp(low_val, hi_val)
+
+             up_weight = U
+             down_weight = Vh
+
+             if conv2d:
+                 up_weight = up_weight.unsqueeze(2).unsqueeze(3)
+                 down_weight = down_weight.unsqueeze(2).unsqueeze(3)
+
+             merged_lora_sd[lora_module_name + '.lora_up.weight'] = up_weight.to("cpu").contiguous()
+             merged_lora_sd[lora_module_name + '.lora_down.weight'] = down_weight.to("cpu").contiguous()
+             merged_lora_sd[lora_module_name + '.alpha'] = torch.tensor(new_rank)
+
+     return merged_lora_sd
+
+ def merge_lora_models(models, ratios, sets):
+     base_alphas = {}  # alpha for merged model
+     base_dims = {}
+     merge_dtype = torch.float
+     merged_sd = {}
+     sign = 1
+     for model, lratios in zip(models, ratios):
+         print(f"merging {model}: {lratios}")
+         lora_sd = load_state_dict(model, merge_dtype)
+
+         # get alpha and dim
+         alphas = {}  # alpha for current model
+         dims = {}  # dims for current model
+         for key in lora_sd.keys():
+             if 'alpha' in key:
+                 lora_module_name = key[:key.rfind(".alpha")]
+                 alpha = float(lora_sd[key].detach().numpy())
+                 alphas[lora_module_name] = alpha
+                 if lora_module_name not in base_alphas:
+                     base_alphas[lora_module_name] = alpha
+             elif "lora_down" in key:
+                 lora_module_name = key[:key.rfind(".lora_down")]
+                 dim = lora_sd[key].size()[0]
+                 dims[lora_module_name] = dim
+                 if lora_module_name not in base_dims:
+                     base_dims[lora_module_name] = dim
+
+         for lora_module_name in dims.keys():
+             if lora_module_name not in alphas:
+                 alpha = dims[lora_module_name]
+                 alphas[lora_module_name] = alpha
+                 if lora_module_name not in base_alphas:
+                     base_alphas[lora_module_name] = alpha
+
+         print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")
+
+         # merge
+         print("merging...")
+         for key in lora_sd.keys():
+             if 'alpha' in key:
+                 continue
+             lora_module_name = key[:key.rfind(".lora_")]
+
+             base_alpha = base_alphas[lora_module_name]
+             alpha = alphas[lora_module_name]
+
+             ratio = lratios[blockfromkey(key)]
+             if "same to Strength" in sets:
+                 ratio, sign = (ratio ** 0.5, 1) if ratio > 0 else (abs(ratio) ** 0.5, -1)
+
+             if "lora_down" in key:
+                 ratio = ratio * sign
+
+             scale = math.sqrt(alpha / base_alpha) * ratio
+
+             if key in merged_sd:
+                 assert merged_sd[key].size() == lora_sd[key].size(), \
+                     "weights shape mismatch when merging; v1 and v2 (or models with different dims) cannot be merged"
+                 merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
+             else:
+                 merged_sd[key] = lora_sd[key] * scale
+
+     # set alpha to sd
+     for lora_module_name, alpha in base_alphas.items():
+         key = lora_module_name + ".alpha"
+         merged_sd[key] = torch.tensor(alpha)
+
+     print("merged model")
+     print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
+
+     return merged_sd
+
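Note why `merge_lora_models` scales each tensor by `sqrt(alpha / base_alpha) * ratio`: both `lora_up` and `lora_down` receive the factor, so the reconstructed product `up @ down` picks it up twice, i.e. `(alpha / base_alpha) * ratio**2`. That squaring is also why the "same to Strength" option pre-applies `sqrt(ratio)`. A quick NumPy check (illustrative values only):

```python
import math
import numpy as np

# Both factors of a LoRA pair are scaled by sqrt(alpha/base_alpha) * ratio,
# so the product up @ down is scaled by (alpha/base_alpha) * ratio**2.
rng = np.random.default_rng(2)
up = rng.standard_normal((8, 4))
down = rng.standard_normal((4, 6))
alpha, base_alpha, ratio = 8.0, 4.0, 0.5

scale = math.sqrt(alpha / base_alpha) * ratio
product = (up * scale) @ (down * scale)

effective = (alpha / base_alpha) * ratio ** 2
print(effective)
```

Here the effective strength is `2 * 0.25 = 0.5`; with "same to Strength", substituting `sqrt(ratio)` for `ratio` would make the effective strength equal `ratio` itself.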
+ def fullpathfromname(name):
+     if name == "" or name == []: return ""
+     checkpoint_info = sd_models.get_closet_checkpoint_match(name)
+     return checkpoint_info.filename
+
+ def makeloraname(model_a, model_b):
+     model_a = filenamecutter(model_a)
+     model_b = filenamecutter(model_b)
+     return "lora_" + model_a + "-" + model_b
+
+ def lycomerge(filename, ratios):
+     sd = load_state_dict(filename, torch.float)
+
+     if len(ratios) == 17:
+         r0 = 1
+         ratios = [ratios[0]] + [r0] + ratios[1:3] + [r0] + ratios[3:5] + [r0] + ratios[5:7] + [r0, r0, r0] + [ratios[7]] + [r0, r0, r0] + ratios[8:]
+
+     print("LyCORIS: ", ratios)
+
+     keys_failed_to_match = []
+
+     for lkey, weight in sd.items():
+         ratio = 1
+         picked = False
+         if 'alpha' in lkey:
+             continue
+
+         fullkey = convert_diffusers_name_to_compvis(lkey)
+         key, lora_key = fullkey.split(".", 1)
+
+         for i, block in enumerate(LYCOBLOCKS):
+             if block in key:
+                 ratio = ratios[i]
+                 picked = True
+         if not picked: keys_failed_to_match.append(key)
+
+         sd[lkey] = weight * math.sqrt(abs(float(ratio)))
+
+         if "down" in lkey and ratio < 0:
+             sd[lkey] = sd[lkey] * -1
+
+     if len(keys_failed_to_match) > 0:
+         print(keys_failed_to_match)
+
+     return sd
+
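The scaling trick in `lycomerge` is the same square-root split seen earlier: every tensor is multiplied by `sqrt(|ratio|)`, and the sign of a negative ratio is pushed onto the "down" factor only, so the reconstructed product `up @ down` ends up scaled by exactly `ratio`, sign included. A NumPy sketch (shapes and seed are arbitrary):

```python
import math
import numpy as np

# Each factor gets sqrt(|ratio|); the down factor also carries the sign,
# so up @ down is scaled by `ratio` exactly, including negative ratios.
rng = np.random.default_rng(3)
up = rng.standard_normal((6, 3))
down = rng.standard_normal((3, 5))
ratio = -0.25

s = math.sqrt(abs(ratio))
up_m = up * s
down_m = down * s * (-1 if ratio < 0 else 1)

print(ratio)
```

Splitting the magnitude evenly across the two factors keeps their individual scales balanced, which matters when the result is later re-decomposed or clamped.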
+ LORABLOCKS = ["encoder",
+     "diffusion_model_input_blocks_1_",
+     "diffusion_model_input_blocks_2_",
+     "diffusion_model_input_blocks_4_",
+     "diffusion_model_input_blocks_5_",
+     "diffusion_model_input_blocks_7_",
+     "diffusion_model_input_blocks_8_",
+     "diffusion_model_middle_block_",
+     "diffusion_model_output_blocks_3_",
+     "diffusion_model_output_blocks_4_",
+     "diffusion_model_output_blocks_5_",
+     "diffusion_model_output_blocks_6_",
+     "diffusion_model_output_blocks_7_",
+     "diffusion_model_output_blocks_8_",
+     "diffusion_model_output_blocks_9_",
+     "diffusion_model_output_blocks_10_",
+     "diffusion_model_output_blocks_11_"]
+
+ LYCOBLOCKS = ["encoder",
+     "diffusion_model_input_blocks_0_",
+     "diffusion_model_input_blocks_1_",
+     "diffusion_model_input_blocks_2_",
+     "diffusion_model_input_blocks_3_",
+     "diffusion_model_input_blocks_4_",
+     "diffusion_model_input_blocks_5_",
+     "diffusion_model_input_blocks_6_",
+     "diffusion_model_input_blocks_7_",
+     "diffusion_model_input_blocks_8_",
+     "diffusion_model_input_blocks_9_",
+     "diffusion_model_input_blocks_10_",
+     "diffusion_model_input_blocks_11_",
+     "diffusion_model_middle_block_",
+     "diffusion_model_output_blocks_0_",
+     "diffusion_model_output_blocks_1_",
+     "diffusion_model_output_blocks_2_",
+     "diffusion_model_output_blocks_3_",
+     "diffusion_model_output_blocks_4_",
+     "diffusion_model_output_blocks_5_",
+     "diffusion_model_output_blocks_6_",
+     "diffusion_model_output_blocks_7_",
+     "diffusion_model_output_blocks_8_",
+     "diffusion_model_output_blocks_9_",
+     "diffusion_model_output_blocks_10_",
+     "diffusion_model_output_blocks_11_"]
+
+ class LoRAModule(torch.nn.Module):
+     """
+     Replaces the forward method of the original Linear/Conv2d, instead of replacing the module itself.
+     """
+
+     def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1):
+         """If alpha is 0 or None, alpha equals the rank (no scaling)."""
+         super().__init__()
+         self.lora_name = lora_name
+
+         if org_module.__class__.__name__ == "Conv2d":
+             in_dim = org_module.in_channels
+             out_dim = org_module.out_channels
+         else:
+             in_dim = org_module.in_features
+             out_dim = org_module.out_features
+
+         # if limit_rank:
+         #     self.lora_dim = min(lora_dim, in_dim, out_dim)
+         #     if self.lora_dim != lora_dim:
+         #         print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
+         # else:
+         self.lora_dim = lora_dim
+
+         if org_module.__class__.__name__ == "Conv2d":
+             kernel_size = org_module.kernel_size
+             stride = org_module.stride
+             padding = org_module.padding
+             self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
+             self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
+         else:
+             self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
+             self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
+
+         if type(alpha) == torch.Tensor:
+             alpha = alpha.detach().float().numpy()  # without casting, bf16 causes an error
+         alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
+         self.scale = alpha / self.lora_dim
+         self.register_buffer("alpha", torch.tensor(alpha))  # can be treated as a constant
+
+         # same initialization as Microsoft's LoRA
+         torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
+         torch.nn.init.zeros_(self.lora_up.weight)
+
+         self.multiplier = multiplier
+         self.org_module = org_module  # removed when applying
+         self.region = None
+         self.region_mask = None
+
+     def apply_to(self):
+         self.org_forward = self.org_module.forward
+         self.org_module.forward = self.forward
+         del self.org_module
+
+     def merge_to(self, sd, dtype, device):
+         # get up/down weight
+         up_weight = sd["lora_up.weight"].to(torch.float).to(device)
+         down_weight = sd["lora_down.weight"].to(torch.float).to(device)
+
+         # extract weight from org_module
+         org_sd = self.org_module.state_dict()
+         weight = org_sd["weight"].to(torch.float)
+
+         # merge weight
+         if len(weight.size()) == 2:
+             # linear
+             weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale
+         elif down_weight.size()[2:4] == (1, 1):
+             # conv2d 1x1
+             weight = (
+                 weight
+                 + self.multiplier
+                 * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
+                 * self.scale
+             )
+         else:
+             # conv2d 3x3
+             conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
+             # print(conved.size(), weight.size(), module.stride, module.padding)
+             weight = weight + self.multiplier * conved * self.scale
+
+         # set weight to org_module
+         org_sd["weight"] = weight.to(dtype)
+         self.org_module.load_state_dict(org_sd)
+
+     def set_region(self, region):
+         self.region = region
+         self.region_mask = None
+
+     def forward(self, x):
+         if self.region is None:
+             return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
+
+         # regional LoRA; FIXME: same as the additional-networks extension
+         if x.size()[1] % 77 == 0:
+             # print(f"LoRA for context: {self.lora_name}")
+             self.region = None
+             return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
+
+         # calculate the region mask on first use
+         if self.region_mask is None:
+             if len(x.size()) == 4:
+                 h, w = x.size()[2:4]
+             else:
+                 seq_len = x.size()[1]
+                 ratio = math.sqrt((self.region.size()[0] * self.region.size()[1]) / seq_len)
+                 h = int(self.region.size()[0] / ratio + 0.5)
+                 w = seq_len // h
+
+             r = self.region.to(x.device)
+             if r.dtype == torch.bfloat16:
+                 r = r.to(torch.float)
+             r = r.unsqueeze(0).unsqueeze(1)
+             # print(self.lora_name, self.region.size(), x.size(), r.size(), h, w)
+             r = torch.nn.functional.interpolate(r, (h, w), mode="bilinear")
+             r = r.to(x.dtype)
+
+             if len(x.size()) == 3:
+                 r = torch.reshape(r, (1, x.size()[1], -1))
+
+             self.region_mask = r
+
+         return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale * self.region_mask
+
+ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
+     if network_dim is None:
+         network_dim = 4  # default
+
+     # extract dim/alpha for conv2d, and block dim
+     conv_dim = kwargs.get("conv_dim", None)
+     conv_alpha = kwargs.get("conv_alpha", None)
+     if conv_dim is not None:
+         conv_dim = int(conv_dim)
+         if conv_alpha is None:
+             conv_alpha = 1.0
+         else:
+             conv_alpha = float(conv_alpha)
+
+     """
+     block_dims = kwargs.get("block_dims")
+     block_alphas = None
+     if block_dims is not None:
+         block_dims = [int(d) for d in block_dims.split(',')]
+         assert len(block_dims) == NUM_BLOCKS, f"Number of block dimensions does not match {NUM_BLOCKS}"
+         block_alphas = kwargs.get("block_alphas")
+         if block_alphas is None:
+             block_alphas = [1] * len(block_dims)
+         else:
+             block_alphas = [int(a) for a in block_alphas.split(',')]
+         assert len(block_alphas) == NUM_BLOCKS, f"Number of block alphas does not match {NUM_BLOCKS}"
+     conv_block_dims = kwargs.get("conv_block_dims")
+     conv_block_alphas = None
+     if conv_block_dims is not None:
+         conv_block_dims = [int(d) for d in conv_block_dims.split(',')]
+         assert len(conv_block_dims) == NUM_BLOCKS, f"Number of block dimensions does not match {NUM_BLOCKS}"
+         conv_block_alphas = kwargs.get("conv_block_alphas")
+         if conv_block_alphas is None:
+             conv_block_alphas = [1] * len(conv_block_dims)
+         else:
+             conv_block_alphas = [int(a) for a in conv_block_alphas.split(',')]
+         assert len(conv_block_alphas) == NUM_BLOCKS, f"Number of block alphas does not match {NUM_BLOCKS}"
+     """
+
+     network = LoRANetwork(
+         text_encoder,
+         unet,
+         multiplier=multiplier,
+         lora_dim=network_dim,
+         alpha=network_alpha,
+         conv_lora_dim=conv_dim,
+         conv_alpha=conv_alpha,
+     )
+     return network
+
+
+
+ class LoRANetwork(torch.nn.Module):
+     # is it possible to apply conv_in and conv_out?
+     UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
+     UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
+     TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
+     LORA_PREFIX_UNET = "lora_unet"
+     LORA_PREFIX_TEXT_ENCODER = "lora_te"
+
+     def __init__(
+         self,
+         text_encoder,
+         unet,
+         multiplier=1.0,
+         lora_dim=4,
+         alpha=1,
+         conv_lora_dim=None,
+         conv_alpha=None,
+         modules_dim=None,
+         modules_alpha=None,
+     ) -> None:
+         super().__init__()
+         self.multiplier = multiplier
+
+         self.lora_dim = lora_dim
+         self.alpha = alpha
+         self.conv_lora_dim = conv_lora_dim
+         self.conv_alpha = conv_alpha
+
+         if modules_dim is not None:
+             print("create LoRA network from weights")
+         else:
+             print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
+
+         self.apply_to_conv2d_3x3 = self.conv_lora_dim is not None
+         if self.apply_to_conv2d_3x3:
+             if self.conv_alpha is None:
+                 self.conv_alpha = self.alpha
+             print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
+
+         # create module instances
+         def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]:
+             loras = []
+             for name, module in root_module.named_modules():
+                 if module.__class__.__name__ in target_replace_modules:
+                     # TODO get block index here
+                     for child_name, child_module in module.named_modules():
+                         is_linear = child_module.__class__.__name__ == "Linear"
+                         is_conv2d = child_module.__class__.__name__ == "Conv2d"
+                         is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
+                         if is_linear or is_conv2d:
+                             lora_name = prefix + "." + name + "." + child_name
+                             lora_name = lora_name.replace(".", "_")
+
+                             if modules_dim is not None:
+                                 if lora_name not in modules_dim:
+                                     continue  # no LoRA module in this weights file
+                                 dim = modules_dim[lora_name]
+                                 alpha = modules_alpha[lora_name]
+                             else:
+                                 if is_linear or is_conv2d_1x1:
+                                     dim = self.lora_dim
+                                     alpha = self.alpha
+                                 elif self.apply_to_conv2d_3x3:
+                                     dim = self.conv_lora_dim
+                                     alpha = self.conv_alpha
+                                 else:
+                                     continue
+
+                             lora = LoRAModule(lora_name, child_module, self.multiplier, dim, alpha)
+                             loras.append(lora)
+             return loras
+
+         self.text_encoder_loras = create_modules(
+             LoRANetwork.LORA_PREFIX_TEXT_ENCODER, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
+         )
+         print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
+
+         # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
+         target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE
+         if modules_dim is not None or self.conv_lora_dim is not None:
+             target_modules = target_modules + LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3  # avoid `+=`, which would mutate the class attribute
+
+         self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, target_modules)
+         print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
+
+         self.weights_sd = None
+
+         # assertion
+         names = set()
+         for lora in self.text_encoder_loras + self.unet_loras:
+             assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
+             names.add(lora.lora_name)
+
+     def set_multiplier(self, multiplier):
+         self.multiplier = multiplier
+         for lora in self.text_encoder_loras + self.unet_loras:
+             lora.multiplier = self.multiplier
+
+     def load_weights(self, file):
+         if os.path.splitext(file)[1] == ".safetensors":
+             from safetensors.torch import load_file, safe_open
+
+             self.weights_sd = load_file(file)
+         else:
+             self.weights_sd = torch.load(file, map_location="cpu")
+
+     def apply_to(self, text_encoder, unet, apply_text_encoder=None, apply_unet=None):
+         if self.weights_sd:
+             weights_has_text_encoder = weights_has_unet = False
+             for key in self.weights_sd.keys():
+                 if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
+                     weights_has_text_encoder = True
+                 elif key.startswith(LoRANetwork.LORA_PREFIX_UNET):
+                     weights_has_unet = True
+
+             if apply_text_encoder is None:
+                 apply_text_encoder = weights_has_text_encoder
+             else:
+                 assert (
+                     apply_text_encoder == weights_has_text_encoder
+                 ), f"text encoder weights: {weights_has_text_encoder} but text encoder flag: {apply_text_encoder} / the weights and the Text Encoder flag are inconsistent"
+
+             if apply_unet is None:
+                 apply_unet = weights_has_unet
+             else:
+                 assert (
+                     apply_unet == weights_has_unet
+                 ), f"u-net weights: {weights_has_unet} but u-net flag: {apply_unet} / the weights and the U-Net flag are inconsistent"
+         else:
+             assert apply_text_encoder is not None and apply_unet is not None, "internal error: flag not set"
+
+         if apply_text_encoder:
+             print("enable LoRA for text encoder")
+         else:
+             self.text_encoder_loras = []
+
+         if apply_unet:
+             print("enable LoRA for U-Net")
+         else:
+             self.unet_loras = []
+
+         for lora in self.text_encoder_loras + self.unet_loras:
+             lora.apply_to()
+             self.add_module(lora.lora_name, lora)
+
+         if self.weights_sd:
+             # if some weights are not in the state dict, that is fine because an initial LoRA does nothing (lora_up is initialized to zeros)
+             info = self.load_state_dict(self.weights_sd, False)
+             print(f"weights are loaded: {info}")
+
+     # TODO refactor to a common function with apply_to
+     def merge_to(self, text_encoder, unet, dtype, device):
+         assert self.weights_sd is not None, "weights are not loaded"
+
+         apply_text_encoder = apply_unet = False
+         for key in self.weights_sd.keys():
+             if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
+                 apply_text_encoder = True
+             elif key.startswith(LoRANetwork.LORA_PREFIX_UNET):
+                 apply_unet = True
+
+         if apply_text_encoder:
+             print("enable LoRA for text encoder")
+         else:
+             self.text_encoder_loras = []
+
+         if apply_unet:
+             print("enable LoRA for U-Net")
+         else:
+             self.unet_loras = []
+
+         for lora in self.text_encoder_loras + self.unet_loras:
+             sd_for_lora = {}
+             for key in self.weights_sd.keys():
+                 if key.startswith(lora.lora_name):
+                     sd_for_lora[key[len(lora.lora_name) + 1:]] = self.weights_sd[key]
+             lora.merge_to(sd_for_lora, dtype, device)
+         print("weights are merged")
+
+     def enable_gradient_checkpointing(self):
+         # not supported
+         pass
+
+     def prepare_optimizer_params(self, text_encoder_lr, unet_lr):
+         def enumerate_params(loras):
+             params = []
+             for lora in loras:
+                 params.extend(lora.parameters())
+             return params
+
+         self.requires_grad_(True)
+         all_params = []
+
+         if self.text_encoder_loras:
+             param_data = {"params": enumerate_params(self.text_encoder_loras)}
+             if text_encoder_lr is not None:
+                 param_data["lr"] = text_encoder_lr
+             all_params.append(param_data)
+
+         if self.unet_loras:
+             param_data = {"params": enumerate_params(self.unet_loras)}
+             if unet_lr is not None:
+                 param_data["lr"] = unet_lr
+             all_params.append(param_data)
+
+         return all_params
+
+     def prepare_grad_etc(self, text_encoder, unet):
+         self.requires_grad_(True)
+
+     def on_epoch_start(self, text_encoder, unet):
+         self.train()
+
+     def get_trainable_params(self):
+         return self.parameters()
+
+     def save_weights(self, file, dtype, metadata):
+         if metadata is not None and len(metadata) == 0:
+             metadata = None
+
+         state_dict = self.state_dict()
+
+         if dtype is not None:
+             for key in list(state_dict.keys()):
+                 v = state_dict[key]
+                 v = v.detach().clone().to("cpu").to(dtype)
+                 state_dict[key] = v
+
+         if os.path.splitext(file)[1] == ".safetensors":
+             from safetensors.torch import save_file
+
+             # Precalculate model hashes to save time on indexing
+             if metadata is None:
+                 metadata = {}
+             model_hash, legacy_hash = precalculate_safetensors_hashes(state_dict, metadata)
+             metadata["sshs_model_hash"] = model_hash
+             metadata["sshs_legacy_hash"] = legacy_hash
+
+             save_file(state_dict, file, metadata)
+         else:
+             torch.save(state_dict, file)
+
+     @staticmethod
+     def set_regions(networks, image):
+         image = image.astype(np.float32) / 255.0
+         for i, network in enumerate(networks[:3]):
+             # NOTE: consider averaging the overlapping area
+             region = image[:, :, i]
+             if region.max() == 0:
+                 continue
+             region = torch.tensor(region)
+             network.set_region(region)
+
+     def set_region(self, region):
+         for lora in self.unet_loras:
+             lora.set_region(region)
+
+ from io import BytesIO
+ import safetensors.torch
+ import hashlib
+
+ def precalculate_safetensors_hashes(tensors, metadata):
+     """Precalculate the model hashes needed by sd-webui-additional-networks to
+     save time on indexing the model later."""
+
+     # Because writing user metadata to the file can change the result of
+     # sd_models.model_hash(), only retain the training metadata for purposes of
+     # calculating the hash, as they are meant to be immutable
+     metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")}
+
+     bytes = safetensors.torch.save(tensors, metadata)
+     b = BytesIO(bytes)
+
+     model_hash = addnet_hash_safetensors(b)
+     legacy_hash = addnet_hash_legacy(b)
+     return model_hash, legacy_hash
+
+ def addnet_hash_safetensors(b):
+     """New model hash used by sd-webui-additional-networks for .safetensors format files"""
+     hash_sha256 = hashlib.sha256()
+     blksize = 1024 * 1024
+
+     b.seek(0)
+     header = b.read(8)
+     n = int.from_bytes(header, "little")
+
+     offset = n + 8
+     b.seek(offset)
+     for chunk in iter(lambda: b.read(blksize), b""):
+         hash_sha256.update(chunk)
+
+     return hash_sha256.hexdigest()
+
+ def addnet_hash_legacy(b):
+     """Old model hash used by sd-webui-additional-networks for .safetensors format files"""
+     m = hashlib.sha256()
+
+     b.seek(0x100000)
+     m.update(b.read(0x10000))
+     return m.hexdigest()[0:8]
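Reviewer note: the linear branch of `LoRAModule.merge_to` above folds the low-rank update into the base weight as W' = W + multiplier * (up @ down) * (alpha / rank). A minimal self-contained sketch of that formula (hypothetical helper name and toy shapes, not part of this extension):

```python
import torch

def merge_linear_lora(weight, up, down, alpha, multiplier=1.0):
    # W' = W + multiplier * (up @ down) * (alpha / rank),
    # mirroring the linear case of LoRAModule.merge_to
    rank = down.shape[0]
    scale = alpha / rank
    return weight + multiplier * (up @ down) * scale

# toy example: an 8x16 base weight with rank-4 factors
w = torch.zeros(8, 16)
up = torch.ones(8, 4)      # (out_dim, rank)
down = torch.ones(4, 16)   # (rank, in_dim)
merged = merge_linear_lora(w, up, down, alpha=4.0)  # scale = 4/4 = 1
```

With all-ones factors, `up @ down` is 4 everywhere, so every merged entry equals 4.0; `alpha / rank` is the same scaling the class stores in `self.scale`.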
microsoftexcel-supermerger/scripts/mergers/xyplot.py ADDED
@@ -0,0 +1,513 @@
1
+ import random
2
+ import cv2
3
+ import numpy as np
4
+ import os
5
+ import copy
6
+ import csv
7
+ from PIL import Image
8
+ from modules import images
9
+ from modules.shared import opts
10
+ from scripts.mergers.mergers import TYPES,smerge,simggen,filenamecutter,draw_origin,wpreseter
11
+ from scripts.mergers.model_util import usemodelgen
12
+
13
+ hear = True
14
+ hearm = False
15
+
16
+ state_mergen = False
17
+
18
+ numadepth = []
19
+
20
+ def freezetime():
21
+ global state_mergen
22
+ state_mergen = True
23
+
24
+ def numanager(normalstart,xtype,xmen,ytype,ymen,esettings,
25
+ weights_a,weights_b,model_a,model_b,model_c,alpha,beta,mode,calcmode,useblocks,custom_name,save_sets,id_sets,wpresets,deep,tensor,
26
+ prompt,nprompt,steps,sampler,cfg,seed,w,h,
27
+ hireson,hrupscaler,hr2ndsteps,denoise_str,hr_scale,batch_size):
28
+ global numadepth
29
+ grids = []
30
+ sep = "|"
31
+
32
+ if sep in xmen:
33
+ xmens = xmen.split(sep)
34
+ xmen = xmens[0]
35
+ if seed =="-1": seed = str(random.randrange(4294967294))
36
+ for men in xmens[1:]:
37
+ numaker(xtype,men,ytype,ymen,esettings,
38
+ weights_a,weights_b,model_a,model_b,model_c,alpha,beta,mode,calcmode,useblocks,custom_name,save_sets,id_sets,wpresets,deep,tensor,
39
+ prompt,nprompt,steps,sampler,cfg,seed,w,h,
40
+ hireson,hrupscaler,hr2ndsteps,denoise_str,hr_scale,batch_size)
41
+ elif sep in ymen:
42
+ ymens = ymen.split(sep)
43
+ ymen = ymens[0]
44
+ if seed =="-1": seed = str(random.randrange(4294967294))
45
+ for men in ymens[1:]:
46
+ numaker(xtype,xmen,ytype,men,esettings,
47
+ weights_a,weights_b,model_a,model_b,model_c,alpha,beta,mode,calcmode,useblocks,custom_name,save_sets,id_sets,wpresets,deep,tensor,
48
+ prompt,nprompt,steps,sampler,cfg,seed,w,h,
49
+ hireson,hrupscaler,hr2ndsteps,denoise_str,hr_scale,batch_size)
50
+
51
+ if normalstart:
52
+ result,currentmodel,xyimage,a,b,c= sgenxyplot(xtype,xmen,ytype,ymen,esettings,
53
+ weights_a,weights_b,model_a,model_b,model_c,alpha,beta,mode,calcmode,
54
+ useblocks,custom_name,save_sets,id_sets,wpresets,deep,tensor,
55
+ prompt,nprompt,steps,sampler,cfg,seed,w,h,
56
+ hireson,hrupscaler,hr2ndsteps,denoise_str,hr_scale,batch_size)
57
+ if xyimage is not None:grids =[xyimage[0]]
58
+ else:print(result)
59
+ else:
60
+ if numadepth ==[]:
61
+ return "no reservation",*[None]*5
62
+ result=currentmodel=xyimage=a=b=c = None
63
+
64
+ while True:
65
+ for i,row in enumerate(numadepth):
66
+ if row[1] =="waiting":
67
+ numadepth[i][1] = "Operating"
68
+ try:
69
+ result,currentmodel,xyimage,a,b,c = sgenxyplot(*row[2:])
70
+ except Exception as e:
71
+ print(e)
72
+ numadepth[i][1] = "Error"
73
+ else:
74
+ if xyimage is not None:
75
+ grids.append(xyimage[0])
76
+ numadepth[i][1] = "Finished"
77
+ else:
78
+ print(result)
79
+ numadepth[i][1] = "Error"
80
+ wcounter = 0
81
+ for row in numadepth:
82
+ if row[1] != "waiting":
83
+ wcounter += 1
84
+ if wcounter == len(numadepth):
85
+ break
86
+
87
+ return result,currentmodel,grids,a,b,c
88
+
89
+ def numaker(xtype,xmen,ytype,ymen,esettings,
90
+ #msettings=[weights_a,weights_b,model_a,model_b,model_c,alpha,beta,mode,calcmode,useblocks,custom_name,save_sets,id_sets,wpresets]
91
+ weights_a,weights_b,model_a,model_b,model_c,alpha,beta,mode,calcmode,
92
+ useblocks,custom_name,save_sets,id_sets,wpresets,deep,tensor,
93
+ prompt,nprompt,steps,sampler,cfg,seed,w,h,
94
+ hireson,hrupscaler,hr2ndsteps,denoise_str,hr_scale,batch_size):
95
+ global numadepth
96
+ numadepth.append([len(numadepth)+1,"waiting",xtype,xmen,ytype,ymen,esettings,
97
+ weights_a,weights_b,model_a,model_b,model_c,alpha,beta,mode,calcmode,
98
+ useblocks,custom_name,save_sets,id_sets,wpresets,deep,tensor,
99
+ prompt,nprompt,steps,sampler,cfg,seed,w,h,
100
+ hireson,hrupscaler,hr2ndsteps,denoise_str,hr_scale,batch_size])
101
+ return numalistmaker(copy.deepcopy(numadepth))
102
+
103
+ def nulister(redel):
104
+ global numadepth
105
+ if redel == False:
106
+ return numalistmaker(copy.deepcopy(numadepth))
107
+ if redel ==-1:
108
+ numadepth = []
109
+ else:
110
+ try:del numadepth[int(redel-1)]
111
+ except Exception as e:print(e)
112
+ return numalistmaker(copy.deepcopy(numadepth))
113
+
114
+ def numalistmaker(numa):
115
+ if numa ==[]: return [["no data","",""],]
116
+ for i,r in enumerate(numa):
117
+ r[2] = TYPES[int(r[2])]
118
+ r[4] = TYPES[int(r[4])]
119
+ numa[i] = r[0:6]+r[8:11]+r[12:16]+r[6:8]
120
+ return numa
121
+
122
+ def caster(news,hear):
123
+ if hear: print(news)
124
+
125
+ def sgenxyplot(xtype,xmen,ytype,ymen,esettings,
126
+ weights_a,weights_b,model_a,model_b,model_c,alpha,beta,mode,calcmode,
127
+ useblocks,custom_name,save_sets,id_sets,wpresets,deep,tensor,
128
+ prompt,nprompt,steps,sampler,cfg,seed,w,h,
129
+ hireson,hrupscaler,hr2ndsteps,denoise_str,hr_scale,batch_size):
130
+ global hear
131
+ esettings = " ".join(esettings)
132
+ #type[0:none,1:aplha,2:beta,3:seed,4:mbw,5:model_A,6:model_B,7:model_C,8:pinpoint 9:deep]
133
+ xtype = TYPES[xtype]
134
+ ytype = TYPES[ytype]
135
+ if ytype == "none": ymen = ""
136
+
137
+ modes=["Weight" ,"Add" ,"Triple","Twice"]
138
+ xs=ys=0
139
+ weights_a_in=weights_b_in="0"
140
+
141
+ deepprint = True if "print change" in esettings else False
142
+
143
+ def castall(hear):
144
+ if hear :print(f"xmen:{xmen}, ymen:{ymen}, xtype:{xtype}, ytype:{ytype}, weights_a:{weights_a_in}, weights_b:{weights_b_in}, model_A:{model_a},model_B :{model_b}, model_C:{model_c}, alpha:{alpha},\
145
+ beta :{beta}, mode:{mode}, blocks:{useblocks}")
146
+
147
+ pinpoint = "pinpoint blocks" in xtype or "pinpoint blocks" in ytype
148
+ usebeta = modes[2] in mode or modes[3] in mode
149
+
150
+ #check and adjust format
151
+ print(f"XY plot start, mode:{mode}, X: {xtype}, Y: {ytype}, MBW: {useblocks}")
152
+ castall(hear)
153
+ None5 = [None,None,None,None,None]
154
+ if xmen =="": return "ERROR: parameter X is empty",*None5
155
+ if ymen =="" and not ytype=="none": return "ERROR: parameter Y is empty",*None5
156
+ if model_a ==[] and not ("model_A" in xtype or "model_A" in ytype):return f"ERROR: model_A is not selected",*None5
157
+ if model_b ==[] and not ("model_B" in xtype or "model_B" in ytype):return f"ERROR: model_B is not selected",*None5
158
+ if model_c ==[] and usebeta and not ("model_C" in xtype or "model_C" in ytype):return "ERROR: model_C is not selected",*None5
159
+ if xtype == ytype: return "ERROR: same type selected for X,Y",*None5
160
+
161
+ if useblocks:
162
+ weights_a_in=wpreseter(weights_a,wpresets)
163
+ weights_b_in=wpreseter(weights_b,wpresets)
164
+
165
+ #for X only plot, use same seed
166
+ if seed == -1: seed = int(random.randrange(4294967294))
167
+
168
+ #for XY plot, use same seed
169
+ def dicedealer(zs):
170
+ for i,z in enumerate(zs):
171
+ if z =="-1": zs[i] = str(random.randrange(4294967294))
172
+ print(f"the die was thrown : {zs}")
173
+
174
+ #adjust parameters, alpha,beta,models,seed: list of single parameters, mbw(no beta):list of text,mbw(usebeta); list of pair text
175
+ def adjuster(zmen,ztype,aztype):
176
+ if "mbw" in ztype or "prompt" in ztype:#men separated by newline
177
+ zs = zmen.splitlines()
178
+ caster(zs,hear)
179
+ if "mbw alpha and beta" in ztype:
180
+ zs = [zs[i:i+2] for i in range(0,len(zs),2)]
181
+ caster(zs,hear)
182
+ elif "elemental" in ztype:
183
+ zs = zmen.split("\n\n")
184
+ else:
185
+ if "pinpoint element" in ztype:
186
+ zmen = zmen.replace("\n",",")
187
+ if "effective" in ztype:
188
+ zmen = ","+zmen
189
+ zmen = zmen.replace("\n",",")
190
+ zs = [z.strip() for z in zmen.split(',')]
191
+ caster(zs,hear)
192
+ if "alpha" in ztype and "effective" in aztype:
193
+ zs = [zs[0]]
194
+ if "seed" in ztype:dicedealer(zs)
195
+ if "alpha" == ztype or "beta" == ztype:
196
+ oz = []
197
+ for z in zs:
198
+ try:
199
+ float(z)
200
+ oz.append(z)
201
+ except:
202
+ pass
203
+ zs = oz
204
+ return zs
205
+
206
+ xs = adjuster(xmen,xtype,ytype)
207
+ ys = adjuster(ymen,ytype,xtype)
208
+
209
+ #in case beta selected but mode is Weight sum or Add or Diff
210
+ if ("beta" in xtype or "beta" in ytype) and (not usebeta and "tensor" not in calcmode):
211
+ mode = modes[3]
212
+ print(f"{modes[3]} mode automatically selected)")
213
+
214
+ #in case mbw or pinpoint selected but useblocks not chekced
215
+ if ("mbw" in xtype or "pinpoint blocks" in xtype) and not useblocks:
216
+ useblocks = True
217
+ print(f"MBW mode enabled")
218
+
219
+ if ("mbw" in ytype or "pinpoint blocks" in ytype) and not useblocks:
220
+ useblocks = True
221
+ print(f"MBW mode enabled")
222
+
223
+ xyimage=[]
224
+ xcount =ycount=0
225
+ allcount = len(xs)*len(ys)
226
+
227
+ #for STOP XY bottun
228
+ flag = False
229
+ global state_mergen
230
+ state_mergen = False
231
+
232
+ #type[0:none,1:aplha,2:beta,3:seed,4:mbw,5:model_A,6:model_B,7:model_C,8:pinpoint ]
233
+ blockid=["BASE","IN00","IN01","IN02","IN03","IN04","IN05","IN06","IN07","IN08","IN09","IN10","IN11","M00","OUT00","OUT01","OUT02","OUT03","OUT04","OUT05","OUT06","OUT07","OUT08","OUT09","OUT10","OUT11"]
234
+ #format ,IN00 IN03,IN04-IN09,OUT4,OUT05
235
+ def weightsdealer(x,xtype,y,weights):
236
+ caster(f"weights from : {weights}",hear)
237
+ zz = x if "pinpoint blocks" in xtype else y
238
+ za = y if "pinpoint blocks" in xtype else x
239
+ zz = [z.strip() for z in zz.split(' ')]
240
+ weights_t = [w.strip() for w in weights.split(',')]
241
+ if zz[0]!="NOT":
242
+ flagger=[False]*26
243
+ changer = True
244
+ else:
245
+ flagger=[True]*26
246
+ changer = False
247
+ for z in zz:
248
+ if z =="NOT":continue
249
+ if "-" in z:
250
+ zt = [zt.strip() for zt in z.split('-')]
251
+ if blockid.index(zt[1]) > blockid.index(zt[0]):
252
+ flagger[blockid.index(zt[0]):blockid.index(zt[1])+1] = [changer]*(blockid.index(zt[1])-blockid.index(zt[0])+1)
253
+ else:
254
+ flagger[blockid.index(zt[1]):blockid.index(zt[0])+1] = [changer]*(blockid.index(zt[0])-blockid.index(zt[1])+1)
255
+ else:
256
+ flagger[blockid.index(z)] =changer
257
+ for i,f in enumerate(flagger):
258
+ if f:weights_t[i]=za
259
+ outext = ",".join(weights_t)
260
+ caster(f"weights changed: {outext}",hear)
261
+ return outext
262
+
263
+ def abdealer(z):
264
+ if " " in z:return z.split(" ")[0],z.split(" ")[1]
265
+ return z,z
266
+
267
+ def xydealer(z,zt,azt):
268
+ nonlocal alpha,beta,seed,weights_a_in,weights_b_in,model_a,model_b,model_c,deep,calcmode,prompt
269
+ if pinpoint or "pinpoint element" in zt or "effective" in zt:return
270
+ if "mbw" in zt:
271
+ def weightser(z):return z, z.split(',',1)[0]
272
+ if "mbw alpha and beta" in zt:
273
+ weights_a_in,alpha = weightser(wpreseter(z[0],wpresets))
274
+ weights_b_in,beta = weightser(wpreseter(z[1],wpresets))
275
+ return
276
+ elif "alpha" in zt:
277
+ weights_a_in,alpha = weightser(wpreseter(z,wpresets))
278
+ return
279
+ else:
280
+ weights_b_in,beta = weightser(wpreseter(z,wpresets))
281
+ return
282
+ if "and" in zt:
283
+ alpha,beta = abdealer(z)
284
+ return
285
+ if "alpha" in zt and not "pinpoint element" in azt:alpha = z
286
+ if "beta" in zt: beta = z
287
+ if "seed" in zt:seed = int(z)
288
+ if "model_A" in zt:model_a = z
289
+ if "model_B" in zt:model_b = z
290
+ if "model_C" in zt:model_c = z
291
+ if "elemental" in zt:deep = z
292
+ if "calcmode" in zt:calcmode = z
293
+ if "prompt" in zt:prompt = z
294
+
295
+ # plot start
296
+ for y in ys:
297
+ xydealer(y,ytype,xtype)
298
+ xcount = 0
299
+ for x in xs:
300
+ xydealer(x,xtype,ytype)
301
+ if ("alpha" in xtype or "alpha" in ytype) and pinpoint:
302
+ weights_a_in = weightsdealer(x,xtype,y,weights_a)
303
+ weights_b_in = weights_b
304
+ if ("beta" in xtype or "beta" in ytype) and pinpoint:
305
+ weights_b_in = weightsdealer(x,xtype,y,weights_b)
306
+ weights_a_in =weights_a
307
+ if "pinpoint element" in xtype or "effective" in xtype:
308
+ deep_in = deep +","+ str(x)+":"+ str(y)
309
+ elif "pinpoint element" in ytype or "effective" in ytype:
310
+ deep_in = deep +","+ str(y)+":"+ str(x)
311
+ else:
312
+ deep_in = deep
313
+
314
+ print(f"XY plot: X: {xtype}, {str(x)}, Y: {ytype}, {str(y)} ({xcount+ycount*len(xs)+1}/{allcount})")
315
+ if not (xtype=="seed" and xcount > 0):
316
+ _ , currentmodel,modelid,theta_0,_=smerge(weights_a_in,weights_b_in, model_a,model_b,model_c, float(alpha),float(beta),mode,calcmode,
317
+ useblocks,"","",id_sets,False,deep_in,tensor,deepprint = deepprint)
318
+ usemodelgen(theta_0,model_a,currentmodel)
319
+ # simggen(prompt, nprompt, steps, sampler, cfg, seed, w, h,mergeinfo="",id_sets=[],modelid = "no id"):
320
+ image_temp = simggen(prompt, nprompt, steps, sampler, cfg, seed, w, h,hireson,hrupscaler,hr2ndsteps,denoise_str,hr_scale,batch_size,currentmodel,id_sets,modelid)
321
+ xyimage.append(image_temp[0][0])
322
+ xcount+=1
323
+ if state_mergen:
324
+ flag = True
325
+ break
326
+ ycount+=1
327
+ if flag:break
328
+
329
+ if flag and ycount ==1:
330
+ xs = xs[:xcount]
331
+ ys = [ys[0],]
332
+ print(f"stopped at x={xcount},y={ycount}")
333
+ elif flag:
334
+ ys=ys[:ycount]
335
+ print(f"stopped at x={xcount},y={ycount}")
336
+
337
+ if "mbw alpha and beta" in xtype: xs = [f"alpha:({x[0]}),beta:({x[1]})" for x in xs ]
338
+ if "mbw alpha and beta" in ytype: ys = [f"alpha:({y[0]}),beta:({y[1]})" for y in ys ]
339
+
340
+ xs[0]=xtype+" = "+xs[0] #draw X label
341
+ if ytype!=TYPES[0] or "model" in ytype:ys[0]=ytype+" = "+ys[0] #draw Y label
342
+
343
+ if ys==[""]:ys = [" "]
344
+
345
+ if "effective" in xtype or "effective" in ytype:
346
+ xyimage,xs,ys = effectivechecker(xyimage,xs,ys,model_a,model_b,esettings)
347
+
348
+ if not "grid" in esettings:
349
+ gridmodel= makegridmodelname(model_a, model_b,model_c, useblocks,mode,xtype,ytype,alpha,beta,weights_a,weights_b,usebeta)
350
+ grid = smakegrid(xyimage,xs,ys,gridmodel,image_temp[4])
351
+ xyimage.insert(0,grid)
352
+
353
+ state_mergen = False
354
+ return "Finished",currentmodel,xyimage,*image_temp[1:4]
355
+
356
+ def smakegrid(imgs,xs,ys,currentmodel,p):
357
+ ver_texts = [[images.GridAnnotation(y)] for y in ys]
358
+ hor_texts = [[images.GridAnnotation(x)] for x in xs]
359
+
360
+ w, h = imgs[0].size
361
+ grid = Image.new('RGB', size=(len(xs) * w, len(ys) * h), color='black')
362
+
363
+ for i, img in enumerate(imgs):
364
+ grid.paste(img, box=(i % len(xs) * w, i // len(xs) * h))
365
+
366
+ grid = images.draw_grid_annotations(grid,w,h, hor_texts, ver_texts)
367
+ grid = draw_origin(grid, currentmodel,w*len(xs),h*len(ys),w)
368
+ if opts.grid_save:
369
+ images.save_image(grid, opts.outdir_txt2img_grids, "xy_grid", extension=opts.grid_format, prompt=p.prompt, seed=p.seed, grid=True, p=p)
370
+
371
+ return grid
372
+
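The row-major paste arithmetic in `smakegrid` can be checked in isolation. A minimal sketch; `cell_box` is a hypothetical helper, not part of the extension:

```python
def cell_box(i, n_cols, w, h):
    """Top-left paste coordinate for the i-th image in a row-major grid.

    Mirrors the placement in smakegrid: column i % n_cols, row i // n_cols.
    """
    return (i % n_cols * w, i // n_cols * h)

print(cell_box(5, 3, 100, 100))  # (200, 100): column 2, row 1
```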
373
+ def makegridmodelname(model_a, model_b,model_c, useblocks,mode,xtype,ytype,alpha,beta,wa,wb,usebeta):
374
+ model_a=filenamecutter(model_a)
375
+ model_b=filenamecutter(model_b)
376
+ model_c=filenamecutter(model_c)
377
+
378
+ if not usebeta:beta,wb = "not used","not used"
379
+ vals = ""
380
+ modes=["Weight" ,"Add" ,"Triple","Twice"]
381
+
382
+ if "mbw" in xtype:
383
+ if "alpha" in xtype:wa = "X"
384
+ if usebeta or " beta" in xtype:wb = "X"
385
+
386
+ if "mbw" in ytype:
387
+ if "alpha" in ytype:wa = "Y"
388
+ if usebeta or " beta" in ytype:wb = "Y"
389
+
390
+ wa = "alpha = " + wa
391
+ wb = "beta = " + wb
392
+
393
+ x = 50
394
+ while len(wa) > x:
395
+ wa = wa[:x] + '\n' + wa[x:]
396
+ x = x + 50
397
+
398
+ x = 50
399
+ while len(wb) > x:
400
+ wb = wb[:x] + '\n' + wb[x:]
401
+ x = x + 50
402
+
403
+ if "model" in xtype:
404
+ if "A" in xtype:model_a = "model A"
405
+ elif "B" in xtype:model_b="model B"
406
+ elif "C" in xtype:model_c="model C"
407
+
408
+ if "model" in ytype:
409
+ if "A" in ytype:model_a = "model A"
410
+ elif "B" in ytype:model_b="model B"
411
+ elif "C" in ytype:model_c="model C"
412
+
413
+ if modes[1] in mode:
414
+ currentmodel =f"{model_a} + \n({model_b} - {model_c})\n x alpha"
415
+ elif modes[2] in mode:
416
+ currentmodel =f"{model_a} x \n(1-alpha-beta) {model_b} x alpha \n+ {model_c} x beta"
417
+ elif modes[3] in mode:
418
+ currentmodel =f"({model_a} x(1-alpha) \n + {model_b} x alpha)*(1-beta)\n+ {model_c} x beta"
419
+ else:
420
+ currentmodel =f"{model_a} x (1-alpha) \n+ {model_b} x alpha"
421
+
422
+ if "alpha" in xtype:alpha = "X"
423
+ if "beta" in xtype:beta = "X"
424
+ if "alpha" in ytype:alpha = "Y"
425
+ if "beta" in ytype:beta = "Y"
426
+
427
+ if "mbw" in xtype:
428
+ if "alpha" in xtype: alpha = "X"
429
+ if "beta" in xtype or usebeta: beta = "X"
430
+
431
+ if "mbw" in ytype:
432
+ if "alpha" in ytype: alpha = "Y"
433
+ if "beta" in ytype or usebeta: beta = "Y"
434
+
435
+ vals = f"\nalpha = {alpha},beta = {beta}" if not useblocks else f"\n{wa}\n{wb}"
436
+
437
+ currentmodel = currentmodel+vals
438
+ return currentmodel
439
+
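The 50-character wrap loop in `makegridmodelname` is easy to verify standalone. A sketch; `wrap_label` is a hypothetical helper (stepping by `width + 1` accounts for each newline the loop itself inserts):

```python
def wrap_label(s, width=50):
    """Insert newlines so no line of `s` exceeds `width` characters.

    Same idea as the wrap loop in makegridmodelname.
    """
    x = width
    while len(s) > x:
        s = s[:x] + "\n" + s[x:]
        x += width + 1  # skip past the newline just inserted
    return s

print(wrap_label("a" * 120).count("\n"))  # 2 breaks: lines of 50, 50, 20
```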
440
+ def effectivechecker(imgs,xs,ys,model_a,model_b,esettings):
441
+ diffs = []
442
+ outnum =[]
443
+ im1 = np.array(imgs[0])
444
+
445
+ model_a = filenamecutter(model_a)
446
+ model_b = filenamecutter(model_b)
447
+ dir = os.path.join(opts.outdir_txt2img_samples,f"{model_a+model_b}","difgif")
448
+
449
+ if "gif" in esettings:
450
+ try:
451
+ os.makedirs(dir)
452
+ except FileExistsError:
453
+ pass
454
+
455
+ ls,ss = (xs.copy(),ys.copy()) if len(xs) > len(ys) else (ys.copy(),xs.copy())
456
+
457
+ for i in range(len(imgs)-1):
458
+ im2 = np.array(imgs[i+1])
459
+
460
+ abs_diff = cv2.absdiff(im2 , im1)
461
+
462
+ abs_diff_t = cv2.threshold(abs_diff, 5, 255, cv2.THRESH_BINARY)[1]
463
+ res = abs_diff_t.astype(np.uint8)
464
+ percentage = (np.count_nonzero(res) * 100)/ res.size
465
+ abs_diff = cv2.bitwise_not(abs_diff)
466
+ outnum.append(percentage)
467
+
468
+ abs_diff = Image.fromarray(abs_diff)
469
+
470
+ diffs.append(abs_diff)
471
+
472
+ if "gif" in esettings:
473
+ gifpath = gifpath_t = os.path.join(dir,ls[i+1].replace(":","_")+".gif")
474
+
475
+ is_file = os.path.isfile(gifpath)
476
+ j = 0
477
+ while is_file:
478
+ gifpath = gifpath_t.replace(".gif",f"_{j}.gif")
479
+ print(gifpath)
480
+ is_file = os.path.isfile(gifpath)
481
+ j = j + 1
482
+
483
+ imgs[0].save(gifpath, save_all=True, append_images=[imgs[i+1]], optimize=False, duration=1000, loop=0)
484
+
485
+ nums = []
486
+ outs = []
487
+
488
+ ls = ls[1:]
489
+ for i in range(len(ls)):
490
+ nums.append([ls[i],outnum[i]])
491
+ ls[i] = ls[i] + "\n Diff : " + str(round(outnum[i],3)) + "%"
492
+
493
+ if "csv" in esettings:
494
+ try:
495
+ os.makedirs(dir)
496
+ except FileExistsError:
497
+ pass
498
+ filepath = os.path.join(dir, f"{model_a+model_b}.csv")
499
+ with open(filepath, "a", newline="") as f:
500
+ writer = csv.writer(f)
501
+ writer.writerows(nums)
502
+
503
+ if len(ys) > len (xs):
504
+ for diff,img in zip(diffs,imgs[1:]):
505
+ outs.append(diff)
506
+ outs.append(img)
507
+ outs.append(imgs[0])
508
+ ss = ["diff",ss[0],"source"]
509
+ return outs,ss,ls
510
+ else:
511
+ outs = [imgs[0]]*len(diffs) + imgs[1:]+ diffs
512
+ ss = ["source",ss[0],"diff"]
513
+ return outs,ls,ss
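The change metric in `effectivechecker` (absolute difference, binary threshold at 5, nonzero-pixel ratio) can be sketched without OpenCV; `diff_percentage` below is a hypothetical NumPy-only stand-in, not the extension's code:

```python
import numpy as np

def diff_percentage(img1, img2, threshold=5):
    """Percent of pixels whose absolute difference exceeds `threshold`,
    the same quantity effectivechecker reports as "Diff : x%"."""
    a = np.asarray(img1, dtype=np.int16)
    b = np.asarray(img2, dtype=np.int16)
    mask = np.abs(a - b) > threshold  # cv2.THRESH_BINARY keeps src > thresh
    return np.count_nonzero(mask) * 100 / mask.size

base = np.zeros((4, 4), dtype=np.uint8)
changed = base.copy()
changed[:2, :] = 10              # alter 8 of 16 pixels
print(diff_percentage(base, changed))  # 50.0
```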
microsoftexcel-supermerger/scripts/supermerger.py ADDED
@@ -0,0 +1,552 @@
1
+ import argparse
2
+ import gc
3
+ import os
4
+ import os.path
5
+ import re
6
+ import shutil
7
+ from importlib import reload
8
+ from pprint import pprint
9
+ import gradio as gr
10
+ from modules import (devices, script_callbacks, scripts, sd_hijack, sd_models,sd_vae, shared)
11
+ from modules.scripts import basedir
12
+ from modules.sd_models import checkpoints_loaded
13
+ from modules.shared import opts
14
+ from modules.ui import create_output_panel, create_refresh_button
15
+ import scripts.mergers.mergers
16
+ import scripts.mergers.pluslora
17
+ import scripts.mergers.xyplot
18
+ reload(scripts.mergers.mergers) # update without restarting web-ui.bat
19
+ reload(scripts.mergers.xyplot)
20
+ reload(scripts.mergers.pluslora)
21
+ import csv
22
+ import scripts.mergers.pluslora as pluslora
23
+ from scripts.mergers.mergers import (TYPESEG, freezemtime, rwmergelog, simggen,smergegen)
24
+ from scripts.mergers.xyplot import freezetime, nulister, numaker, numanager
25
+
26
+ gensets=argparse.Namespace()
27
+
28
+ def on_ui_train_tabs(params):
29
+ txt2img_preview_params=params.txt2img_preview_params
30
+ gensets.txt2img_preview_params=txt2img_preview_params
31
+ return None
32
+
33
+ path_root = basedir()
34
+
35
+ def on_ui_tabs():
36
+ weights_presets=""
37
+ userfilepath = os.path.join(path_root, "scripts","mbwpresets.txt")
38
+ if os.path.isfile(userfilepath):
39
+ try:
40
+ with open(userfilepath) as f:
41
+ weights_presets = f.read()
42
+ filepath = userfilepath
43
+ except OSError as e:
44
+ pass
45
+ else:
46
+ filepath = os.path.join(path_root, "scripts","mbwpresets_master.txt")
47
+ try:
48
+ with open(filepath) as f:
49
+ weights_presets = f.read()
50
+ shutil.copyfile(filepath, userfilepath)
51
+ except OSError as e:
52
+ pass
53
+
54
+ with gr.Blocks() as supermergerui:
55
+ with gr.Tab("Merge"):
56
+ with gr.Row().style(equal_height=False):
57
+ with gr.Column(scale = 3):
58
+ gr.HTML(value="<p>Merge models and load them for generation</p>")
59
+
60
+ with gr.Row():
61
+ model_a = gr.Dropdown(sd_models.checkpoint_tiles(),elem_id="model_converter_model_name",label="Model A",interactive=True)
62
+ create_refresh_button(model_a, sd_models.list_models,lambda: {"choices": sd_models.checkpoint_tiles()},"refresh_checkpoint_Z")
63
+
64
+ model_b = gr.Dropdown(sd_models.checkpoint_tiles(),elem_id="model_converter_model_name",label="Model B",interactive=True)
65
+ create_refresh_button(model_b, sd_models.list_models,lambda: {"choices": sd_models.checkpoint_tiles()},"refresh_checkpoint_Z")
66
+
67
+ model_c = gr.Dropdown(sd_models.checkpoint_tiles(),elem_id="model_converter_model_name",label="Model C",interactive=True)
68
+ create_refresh_button(model_c, sd_models.list_models,lambda: {"choices": sd_models.checkpoint_tiles()},"refresh_checkpoint_Z")
69
+
70
+ mode = gr.Radio(label = "Merge Mode",choices = ["Weight sum:A*(1-alpha)+B*alpha", "Add difference:A+(B-C)*alpha",
71
+ "Triple sum:A*(1-alpha-beta)+B*alpha+C*beta",
72
+ "sum Twice:(A*(1-alpha)+B*alpha)*(1-beta)+C*beta",
73
+ ], value = "Weight sum:A*(1-alpha)+B*alpha")
74
+ calcmode = gr.Radio(label = "Calculation Mode",choices = ["normal", "cosineA", "cosineB", "smoothAdd","tensor"], value = "normal")
75
+ with gr.Row():
76
+ useblocks = gr.Checkbox(label="use MBW")
77
+ base_alpha = gr.Slider(label="alpha", minimum=-1.0, maximum=2, step=0.001, value=0.5)
78
+ base_beta = gr.Slider(label="beta", minimum=-1.0, maximum=2, step=0.001, value=0.25)
79
+ #weights = gr.Textbox(label="weights,base alpha,IN00,IN02,...IN11,M00,OUT00,...,OUT11",lines=2,value="0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5")
80
+
81
+ with gr.Row():
82
+ merge = gr.Button(elem_id="model_merger_merge", value="Merge!",variant='primary')
83
+ mergeandgen = gr.Button(elem_id="model_merger_merge", value="Merge&Gen",variant='primary')
84
+ gen = gr.Button(elem_id="model_merger_merge", value="Gen",variant='primary')
85
+ stopmerge = gr.Button(elem_id="stopmerge", value="Stop",variant='primary')
86
+ with gr.Row():
87
+ with gr.Column(scale = 4):
88
+ save_sets = gr.CheckboxGroup(["save model", "overwrite","safetensors","fp16","save metadata"], value=["safetensors"], label="save settings")
89
+ with gr.Column(scale = 2):
90
+ id_sets = gr.CheckboxGroup(["image", "PNG info"], label="write merged model ID to")
91
+ with gr.Row():
92
+ with gr.Column(min_width = 50, scale=2):
93
+ with gr.Row():
94
+ custom_name = gr.Textbox(label="Custom Name (Optional)", elem_id="model_converter_custom_name")
95
+ mergeid = gr.Textbox(label="merge from ID", elem_id="model_converter_custom_name",value = "-1")
96
+ with gr.Column(min_width = 50, scale=1):
97
+ with gr.Row():s_reverse= gr.Button(value="Set from ID(-1 for last)",variant='primary')
98
+
99
+ with gr.Accordion("Restore faces, Tiling, Hires. fix, Batch size",open = False):
100
+ batch_size = gr.Slider(minimum=0, maximum=8, step=1, label='Batch size', value=1, elem_id="sm_txt2img_batch_size")
101
+ genoptions = gr.CheckboxGroup(label = "Gen Options",choices=["Restore faces", "Tiling", "Hires. fix"], visible = True,interactive=True,type="value")
102
+ with gr.Row(elem_id="txt2img_hires_fix_row1", variant="compact"):
103
+ hrupscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode)
104
+ hr2ndsteps = gr.Slider(minimum=0, maximum=150, step=1, label='Hires steps', value=0, elem_id="txt2img_hires_steps")
105
+ denois_str = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength")
106
+ hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale")
107
+
108
+ hiresfix = [genoptions,hrupscaler,hr2ndsteps,denois_str,hr_scale]
109
+
110
+ with gr.Accordion("Elemental Merge",open = False):
111
+ with gr.Row():
112
+ esettings1 = gr.CheckboxGroup(label = "settings",choices=["print change"],type="value",interactive=True)
113
+ with gr.Row():
114
+ deep = gr.Textbox(label="Blocks:Element:Ratio,Blocks:Element:Ratio,...",lines=2,value="")
115
+
116
+ with gr.Accordion("Tensor Merge",open = False,visible=False):
117
+ tensor = gr.Textbox(label="Blocks:Tensors",lines=2,value="")
118
+
119
+ with gr.Row():
120
+ x_type = gr.Dropdown(label="X type", choices=[x for x in TYPESEG], value="alpha", type="index")
121
+ x_randseednum = gr.Number(value=3, label="number of -1", interactive=True, visible = True)
122
+ xgrid = gr.Textbox(label="Sequential Merge Parameters",lines=3,value="0.25,0.5,0.75")
123
+ y_type = gr.Dropdown(label="Y type", choices=[y for y in TYPESEG], value="none", type="index")
124
+ ygrid = gr.Textbox(label="Y grid (Disabled if blank)",lines=3,value="",visible =False)
125
+ with gr.Row():
126
+ gengrid = gr.Button(elem_id="model_merger_merge", value="Sequential XY Merge and Generation",variant='primary')
127
+ stopgrid = gr.Button(elem_id="model_merger_merge", value="Stop XY",variant='primary')
128
+ s_reserve1 = gr.Button(value="Reserve XY Plot",variant='primary')
129
+ dtrue = gr.Checkbox(value = True, visible = False)
130
+ dfalse = gr.Checkbox(value = False,visible = False)
131
+ dummy_t = gr.Textbox(value = "",visible = False)
132
+ blockid=["BASE","IN00","IN01","IN02","IN03","IN04","IN05","IN06","IN07","IN08","IN09","IN10","IN11","M00","OUT00","OUT01","OUT02","OUT03","OUT04","OUT05","OUT06","OUT07","OUT08","OUT09","OUT10","OUT11"]
133
+
134
+ with gr.Column(scale = 2):
135
+ currentmodel = gr.Textbox(label="Current Model",lines=1,value="")
136
+ submit_result = gr.Textbox(label="Message")
137
+ mgallery, mgeninfo, mhtmlinfo, mhtmllog = create_output_panel("txt2img", opts.outdir_txt2img_samples)
138
+ with gr.Row(visible = False) as row_inputers:
139
+ inputer = gr.Textbox(label="",lines=1,value="")
140
+ addtox = gr.Button(value="Add to Sequence X")
141
+ addtoy = gr.Button(value="Add to Sequence Y")
142
+ with gr.Row(visible = False) as row_blockids:
143
+ blockids = gr.CheckboxGroup(label = "block IDs",choices=[x for x in blockid],type="value",interactive=True)
144
+ with gr.Row(visible = False) as row_calcmode:
145
+ calcmodes = gr.CheckboxGroup(label = "calcmode",choices=["normal", "cosineA", "cosineB", "smoothAdd","tensor"],type="value",interactive=True)
146
+ with gr.Row(visible = False) as row_checkpoints:
147
+ checkpoints = gr.CheckboxGroup(label = "checkpoint",choices=[x.model_name for x in sd_models.checkpoints_list.values()],type="value",interactive=True)
148
+ with gr.Row(visible = False) as row_esets:
149
+ esettings = gr.CheckboxGroup(label = "effective checker settings",choices=["save csv","save anime gif","not save grid","print change"],type="value",interactive=True)
150
+
151
+ with gr.Tab("Weights Setting"):
152
+ with gr.Row():
153
+ setalpha = gr.Button(elem_id="copytogen", value="set to alpha",variant='primary')
154
+ readalpha = gr.Button(elem_id="copytogen", value="read from alpha",variant='primary')
155
+ setbeta = gr.Button(elem_id="copytogen", value="set to beta",variant='primary')
156
+ readbeta = gr.Button(elem_id="copytogen", value="read from beta",variant='primary')
157
+ setx = gr.Button(elem_id="copytogen", value="set to X",variant='primary')
158
+ with gr.Row():
159
+ weights_a = gr.Textbox(label="weights for alpha, base alpha,IN00,IN01,...IN11,M00,OUT00,...,OUT11",value = "0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5")
160
+ weights_b = gr.Textbox(label="weights for beta, base beta,IN00,IN01,...IN11,M00,OUT00,...,OUT11",value = "0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2")
161
+ with gr.Row():
162
+ base= gr.Slider(label="Base", minimum=0, maximum=1, step =0.01, value=0.5)
163
+ in00 = gr.Slider(label="IN00", minimum=0, maximum=1, step=0.01, value=0.5)
164
+ in01 = gr.Slider(label="IN01", minimum=0, maximum=1, step=0.01, value=0.5)
165
+ in02 = gr.Slider(label="IN02", minimum=0, maximum=1, step=0.01, value=0.5)
166
+ in03 = gr.Slider(label="IN03", minimum=0, maximum=1, step=0.01, value=0.5)
167
+ with gr.Row():
168
+ in04 = gr.Slider(label="IN04", minimum=0, maximum=1, step=0.01, value=0.5)
169
+ in05 = gr.Slider(label="IN05", minimum=0, maximum=1, step=0.01, value=0.5)
170
+ in06 = gr.Slider(label="IN06", minimum=0, maximum=1, step=0.01, value=0.5)
171
+ in07 = gr.Slider(label="IN07", minimum=0, maximum=1, step=0.01, value=0.5)
172
+ in08 = gr.Slider(label="IN08", minimum=0, maximum=1, step=0.01, value=0.5)
173
+ in09 = gr.Slider(label="IN09", minimum=0, maximum=1, step=0.01, value=0.5)
174
+ with gr.Row():
175
+ in10 = gr.Slider(label="IN10", minimum=0, maximum=1, step=0.01, value=0.5)
176
+ in11 = gr.Slider(label="IN11", minimum=0, maximum=1, step=0.01, value=0.5)
177
+ mi00 = gr.Slider(label="M00", minimum=0, maximum=1, step=0.01, value=0.5)
178
+ ou00 = gr.Slider(label="OUT00", minimum=0, maximum=1, step=0.01, value=0.5)
179
+ ou01 = gr.Slider(label="OUT01", minimum=0, maximum=1, step=0.01, value=0.5)
180
+ ou02 = gr.Slider(label="OUT02", minimum=0, maximum=1, step=0.01, value=0.5)
181
+ with gr.Row():
182
+ ou03 = gr.Slider(label="OUT03", minimum=0, maximum=1, step=0.01, value=0.5)
183
+ ou04 = gr.Slider(label="OUT04", minimum=0, maximum=1, step=0.01, value=0.5)
184
+ ou05 = gr.Slider(label="OUT05", minimum=0, maximum=1, step=0.01, value=0.5)
185
+ ou06 = gr.Slider(label="OUT06", minimum=0, maximum=1, step=0.01, value=0.5)
186
+ ou07 = gr.Slider(label="OUT07", minimum=0, maximum=1, step=0.01, value=0.5)
187
+ ou08 = gr.Slider(label="OUT08", minimum=0, maximum=1, step=0.01, value=0.5)
188
+ with gr.Row():
189
+ ou09 = gr.Slider(label="OUT09", minimum=0, maximum=1, step=0.01, value=0.5)
190
+ ou10 = gr.Slider(label="OUT10", minimum=0, maximum=1, step=0.01, value=0.5)
191
+ ou11 = gr.Slider(label="OUT11", minimum=0, maximum=1, step=0.01, value=0.5)
192
+ with gr.Tab("Weights Presets"):
193
+ with gr.Row():
194
+ s_reloadtext = gr.Button(value="Reload Presets",variant='primary')
195
+ s_reloadtags = gr.Button(value="Reload Tags",variant='primary')
196
+ s_savetext = gr.Button(value="Save Presets",variant='primary')
197
+ s_openeditor = gr.Button(value="Open TextEditor",variant='primary')
198
+ weightstags= gr.Textbox(label="available",lines = 2,value=tagdicter(weights_presets),visible =True,interactive =True)
199
+ wpresets= gr.TextArea(label="",value=weights_presets,visible =True,interactive = True)
200
+
201
+ with gr.Tab("Reservation"):
202
+ with gr.Row():
203
+ s_reserve = gr.Button(value="Reserve XY Plot",variant='primary')
204
+ s_reloadreserve = gr.Button(value="Reload List",variant='primary')
205
+ s_startreserve = gr.Button(value="Start XY plot",variant='primary')
206
+ s_delreserve = gr.Button(value="Delete list(-1 for all)",variant='primary')
207
+ s_delnum = gr.Number(value=1, label="Delete num : ", interactive=True, visible = True,precision =0)
208
+ with gr.Row():
209
+ numaframe = gr.Dataframe(
210
+ headers=["No.","status","xtype","xmember", "ytype","ymember","model A","model B","model C","alpha","beta","mode","use MBW","weights alpha","weights beta"],
211
+ row_count=5,)
212
+ # with gr.Tab("manual"):
213
+ # with gr.Row():
214
+ # gr.HTML(value="<p> examples: Change base alpha from 0.1 to 0.9 <br>0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9<br>If you want to display the original model as well for comparison<br>0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1</p>")
215
+ # gr.HTML(value="<p> For block-by-block merging <br>0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5<br>1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1<br>0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1</p>")
216
+
217
+ with gr.Row():
218
+
219
+ currentcache = gr.Textbox(label="Current Cache")
220
+ loadcachelist = gr.Button(elem_id="model_merger_merge", value="Reload Cache List",variant='primary')
221
+ unloadmodel = gr.Button(value="unload model",variant='primary')
222
+
223
+
224
+ # main ui end
225
+
226
+ with gr.Tab("LoRA", elem_id="tab_lora"):
227
+ pluslora.on_ui_tabs()
228
+
229
+ with gr.Tab("History", elem_id="tab_history"):
230
+
231
+ with gr.Row():
232
+ load_history = gr.Button(value="load_history",variant='primary')
233
+ searchwrods = gr.Textbox(label="",lines=1,value="")
234
+ search = gr.Button(value="search")
235
+ searchmode = gr.Radio(label = "Search Mode",choices = ["or","and"], value = "or",type = "value")
236
+ with gr.Row():
237
+ history = gr.Dataframe(
238
+ headers=["ID","Time","Name","Weights alpha","Weights beta","Model A","Model B","Model C","alpha","beta","Mode","use MBW","custom name","save setting","use ID"],
239
+ )
240
+
241
+ with gr.Tab("Elements", elem_id="tab_deep"):
242
+ with gr.Row():
243
+ smd_model_a = gr.Dropdown(sd_models.checkpoint_tiles(),elem_id="model_converter_model_name",label="Checkpoint A",interactive=True)
244
+ create_refresh_button(smd_model_a, sd_models.list_models,lambda: {"choices": sd_models.checkpoint_tiles()},"refresh_checkpoint_Z")
245
+ smd_loadkeys = gr.Button(value="load keys",variant='primary')
246
+ with gr.Row():
247
+ keys = gr.Dataframe(headers=["No.","block","key"],)
248
+
249
+ with gr.Tab("Metadata", elem_id="tab_metadata"):
250
+ with gr.Row():
251
+ meta_model_a = gr.Dropdown(sd_models.checkpoint_tiles(),elem_id="model_converter_model_name",label="read metadata",interactive=True)
252
+ create_refresh_button(meta_model_a, sd_models.list_models,lambda: {"choices": sd_models.checkpoint_tiles()},"refresh_checkpoint_Z")
253
+ smd_loadmetadata = gr.Button(value="load keys",variant='primary')
254
+ with gr.Row():
255
+ metadata = gr.TextArea()
256
+
257
+ smd_loadmetadata.click(
258
+ fn=loadmetadata,
259
+ inputs=[meta_model_a],
260
+ outputs=[metadata]
261
+ )
262
+
263
+ smd_loadkeys.click(
264
+ fn=loadkeys,
265
+ inputs=[smd_model_a],
266
+ outputs=[keys]
267
+ )
268
+
269
+ def unload():
270
+ if shared.sd_model is None: return "already unloaded"
271
+ sd_hijack.model_hijack.undo_hijack(shared.sd_model)
272
+ shared.sd_model = None
273
+ gc.collect()
274
+ devices.torch_gc()
275
+ return "model unloaded"
276
+
277
+ unloadmodel.click(fn=unload,outputs=[submit_result])
278
+
279
+ load_history.click(fn=load_historyf,outputs=[history ])
280
+
281
+ msettings=[weights_a,weights_b,model_a,model_b,model_c,base_alpha,base_beta,mode,calcmode,useblocks,custom_name,save_sets,id_sets,wpresets,deep,tensor]
282
+ imagegal = [mgallery,mgeninfo,mhtmlinfo,mhtmllog]
283
+ xysettings=[x_type,xgrid,y_type,ygrid,esettings]
284
+
285
+ s_reverse.click(fn = reversparams,
286
+ inputs =mergeid,
287
+ outputs = [submit_result,*msettings[0:8],*msettings[9:13],deep,calcmode]
288
+ )
289
+
290
+ merge.click(
291
+ fn=smergegen,
292
+ inputs=[*msettings,esettings1,*gensets.txt2img_preview_params,*hiresfix,batch_size,currentmodel,dfalse],
293
+ outputs=[submit_result,currentmodel]
294
+ )
295
+
296
+ mergeandgen.click(
297
+ fn=smergegen,
298
+ inputs=[*msettings,esettings1,*gensets.txt2img_preview_params,*hiresfix,batch_size,currentmodel,dtrue],
299
+ outputs=[submit_result,currentmodel,*imagegal]
300
+ )
301
+
302
+ gen.click(
303
+ fn=simggen,
304
+ inputs=[*gensets.txt2img_preview_params,*hiresfix,batch_size,currentmodel,id_sets],
305
+ outputs=[*imagegal],
306
+ )
307
+
308
+ s_reserve.click(
309
+ fn=numaker,
310
+ inputs=[*xysettings,*msettings,*gensets.txt2img_preview_params,*hiresfix,batch_size],
311
+ outputs=[numaframe]
312
+ )
313
+
314
+ s_reserve1.click(
315
+ fn=numaker,
316
+ inputs=[*xysettings,*msettings,*gensets.txt2img_preview_params,*hiresfix,batch_size],
317
+ outputs=[numaframe]
318
+ )
319
+
320
+ gengrid.click(
321
+ fn=numanager,
322
+ inputs=[dtrue,*xysettings,*msettings,*gensets.txt2img_preview_params,*hiresfix,batch_size],
323
+ outputs=[submit_result,currentmodel,*imagegal],
324
+ )
325
+
326
+ s_startreserve.click(
327
+ fn=numanager,
328
+ inputs=[dfalse,*xysettings,*msettings,*gensets.txt2img_preview_params,*hiresfix,batch_size],
329
+ outputs=[submit_result,currentmodel,*imagegal],
330
+ )
331
+
332
+ search.click(fn = searchhistory,inputs=[searchwrods,searchmode],outputs=[history])
333
+
334
+ s_reloadreserve.click(fn=nulister,inputs=[dfalse],outputs=[numaframe])
335
+ s_delreserve.click(fn=nulister,inputs=[s_delnum],outputs=[numaframe])
336
+ loadcachelist.click(fn=load_cachelist,inputs=[],outputs=[currentcache])
337
+ addtox.click(fn=lambda x:gr.Textbox.update(value = x),inputs=[inputer],outputs=[xgrid])
338
+ addtoy.click(fn=lambda x:gr.Textbox.update(value = x),inputs=[inputer],outputs=[ygrid])
339
+
340
+ stopgrid.click(fn=freezetime)
341
+ stopmerge.click(fn=freezemtime)
342
+
343
+ checkpoints.change(fn=lambda x:",".join(x),inputs=[checkpoints],outputs=[inputer])
344
+ blockids.change(fn=lambda x:" ".join(x),inputs=[blockids],outputs=[inputer])
345
+ calcmodes.change(fn=lambda x:",".join(x),inputs=[calcmodes],outputs=[inputer])
346
+
347
+ menbers = [base,in00,in01,in02,in03,in04,in05,in06,in07,in08,in09,in10,in11,mi00,ou00,ou01,ou02,ou03,ou04,ou05,ou06,ou07,ou08,ou09,ou10,ou11]
348
+
349
+ setalpha.click(fn=slider2text,inputs=menbers,outputs=[weights_a])
350
+ setbeta.click(fn=slider2text,inputs=menbers,outputs=[weights_b])
351
+ setx.click(fn=add_to_seq,inputs=[xgrid,weights_a],outputs=[xgrid])
352
+
353
+ readalpha.click(fn=text2slider,inputs=weights_a,outputs=menbers)
354
+ readbeta.click(fn=text2slider,inputs=weights_b,outputs=menbers)
355
+
356
+ x_type.change(fn=showxy,inputs=[x_type,y_type], outputs=[row_blockids,row_checkpoints,row_inputers,ygrid,row_esets,row_calcmode])
357
+ y_type.change(fn=showxy,inputs=[x_type,y_type], outputs=[row_blockids,row_checkpoints,row_inputers,ygrid,row_esets,row_calcmode])
358
+ x_randseednum.change(fn=makerand,inputs=[x_randseednum],outputs=[xgrid])
359
+
360
+ import subprocess
361
+ def openeditors():
362
+ subprocess.Popen(['start', filepath], shell=True)
363
+
364
+ def reloadpresets():
365
+ try:
366
+ with open(filepath) as f:
367
+ return f.read()
368
+ except OSError as e:
369
+ pass
370
+
371
+ def savepresets(text):
372
+ with open(filepath,mode = 'w') as f:
373
+ f.write(text)
374
+
375
+ s_reloadtext.click(fn=reloadpresets,inputs=[],outputs=[wpresets])
376
+ s_reloadtags.click(fn=tagdicter,inputs=[wpresets],outputs=[weightstags])
377
+ s_savetext.click(fn=savepresets,inputs=[wpresets],outputs=[])
378
+ s_openeditor.click(fn=openeditors,inputs=[],outputs=[])
379
+
380
+ return (supermergerui, "SuperMerger", "supermerger"),
381
+
382
+ msearch = []
383
+ mlist=[]
384
+
385
+ def loadmetadata(model):
386
+ import json
387
+ checkpoint_info = sd_models.get_closet_checkpoint_match(model)
388
+ if ".safetensors" not in checkpoint_info.filename: return "no metadata(not safetensors)"
389
+ sdict = sd_models.read_metadata_from_safetensors(checkpoint_info.filename)
390
+ if sdict == {}: return "no metadata"
391
+ return json.dumps(sdict,indent=4)
392
+
393
+ def load_historyf():
394
+ filepath = os.path.join(path_root,"mergehistory.csv")
395
+ global mlist,msearch
396
+ msearch = []
397
+ mlist=[]
398
+ try:
399
+ with open(filepath, 'r') as f:
400
+ reader = csv.reader(f)
401
+ mlist = [row for row in reader]
402
+ mlist = mlist[1:]
403
+ for m in mlist:
404
+ msearch.append(" ".join(m))
405
+ maxlen = len(mlist[-1][0])
406
+ for i,m in enumerate(mlist):
407
+ mlist[i][0] = mlist[i][0].zfill(maxlen)
408
+ return mlist
409
+ except:
410
+ return [["no data","",""],]
411
+
412
+ def searchhistory(words,searchmode):
413
+ outs =[]
414
+ ando = "and" in searchmode
415
+ words = words.split(" ") if " " in words else [words]
416
+ for i, m in enumerate(msearch):
417
+ hit = ando
418
+ for w in words:
419
+ if ando:
420
+ if w not in m:hit = False
421
+ else:
422
+ if w in m:hit = True
423
+ print(i,len(mlist))
424
+ if hit :outs.append(mlist[i])
425
+
426
+ if outs == []:return [["no result","",""],]
427
+ return outs
428
+
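The or/and matching in `searchhistory` reduces to a small loop; a sketch over hypothetical history rows (`rows` and `search` are illustrative names, not part of the extension):

```python
rows = ["1 modelA modelB 0.5", "2 modelC modelB 0.3"]  # hypothetical history rows

def search(words, mode="or"):
    """Keep rows containing any word ("or") or every word ("and"),
    following the same flag logic as searchhistory."""
    ando = mode == "and"
    hits = []
    for row in rows:
        hit = ando  # "and" starts True and is falsified; "or" starts False
        for w in words.split():
            if ando and w not in row:
                hit = False
            if not ando and w in row:
                hit = True
        if hit:
            hits.append(row)
    return hits

print(len(search("modelB")))                # 2
print(len(search("modelA modelB", "and")))  # 1
```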
429
+ #msettings=[0 weights_a,1 weights_b,2 model_a,3 model_b,4 model_c,5 base_alpha,6 base_beta,7 mode,8 useblocks,9 custom_name,10 save_sets,11 id_sets,12 wpresets]
430
+
431
+ def reversparams(id):
432
+ def selectfromhash(hash):
433
+ for model in sd_models.checkpoint_tiles():
434
+ if hash in model:
435
+ return model
436
+ return ""
437
+ try:
438
+ idsets = rwmergelog(id = id)
439
+ except:
440
+         return [gr.update(value = "ERROR: history file could not be opened"),*[gr.update() for x in range(14)]]
+     if type(idsets) == str:
+         print("ERROR")
+         return [gr.update(value = idsets),*[gr.update() for x in range(14)]]
+     if idsets[0] == "ID":return [gr.update(value ="ERROR: no history"),*[gr.update() for x in range(14)]]
+     mgs = idsets[3:]
+     if mgs[0] == "":mgs[0] = "0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5"
+     if mgs[1] == "":mgs[1] = "0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2"
+     mgs[2] = selectfromhash(mgs[2]) if len(mgs[2]) > 5 else ""
+     mgs[3] = selectfromhash(mgs[3]) if len(mgs[3]) > 5 else ""
+     mgs[4] = selectfromhash(mgs[4]) if len(mgs[4]) > 5 else ""
+     mgs[8] = True if mgs[8] =="True" else False
+     mgs[10] = mgs[10].replace("[","").replace("]","").replace("'", "")
+     mgs[10] = [x.strip() for x in mgs[10].split(",")]
+     mgs[11] = mgs[11].replace("[","").replace("]","").replace("'", "")
+     mgs[11] = [x.strip() for x in mgs[11].split(",")]
+     while len(mgs) < 14:
+         mgs.append("")
+     mgs[13] = "normal" if mgs[13] == "" else mgs[13]
+     return [gr.update(value = "setting loaded") ,*[gr.update(value = x) for x in mgs[0:14]]]
+
+ def add_to_seq(seq,maker):
+     return gr.Textbox.update(value = maker if seq=="" else seq+"\r\n"+maker)
+
+ def load_cachelist():
+     text = ""
+     for x in checkpoints_loaded.keys():
+         text = text +"\r\n"+ x.model_name
+     return text.replace("\r\n","",1)
+
+ def makerand(num):
+     text = ""
+     for x in range(int(num)):
+         text = text +"-1,"
+     text = text[:-1]
+     return text
+
+ #row_blockids,row_checkpoints,row_inputers,ygrid
+ def showxy(x,y):
+     flags =[False]*6
+     t = TYPESEG
+     txy = t[x] + t[y]
+     if "model" in txy : flags[1] = flags[2] = True
+     if "pinpoint" in txy : flags[0] = flags[2] = True
+     if "effective" in txy or "element" in txy : flags[4] = True
+     if "calcmode" in txy : flags[5] = True
+     if not "none" in t[y] : flags[3] = flags[2] = True
+     return [gr.update(visible = x) for x in flags]
+
+ def text2slider(text):
+     vals = [t.strip() for t in text.split(",")]
+     return [gr.update(value = float(v)) for v in vals]
+
+ def slider2text(a,b,c,d,e,f,g,h,i,j,k,l,m,n,o,p,q,r,s,t,u,v,w,x,y,z):
+     numbers = [a,b,c,d,e,f,g,h,i,j,k,l,m,n,o,p,q,r,s,t,u,v,w,x,y,z]
+     numbers = [str(x) for x in numbers]
+     return gr.update(value = ",".join(numbers) )
+
+ def tagdicter(presets):
+     presets=presets.splitlines()
+     wdict={}
+     for l in presets:
+         w=[]
+         if ":" in l :
+             key = l.split(":",1)[0]
+             w = l.split(":",1)[1]
+         if "\t" in l:
+             key = l.split("\t",1)[0]
+             w = l.split("\t",1)[1]
+         if len([w for w in w.split(",")]) == 26:
+             wdict[key.strip()]=w
+     return ",".join(list(wdict.keys()))
+
+ def loadkeys(model_a):
+     checkpoint_info = sd_models.get_closet_checkpoint_match(model_a)
+     sd = sd_models.read_state_dict(checkpoint_info.filename,"cpu")
+     keys = []
+     for i, key in enumerate(sd.keys()):
+         re_inp = re.compile(r'\.input_blocks\.(\d+)\.') # 12
+         re_mid = re.compile(r'\.middle_block\.(\d+)\.') # 1
+         re_out = re.compile(r'\.output_blocks\.(\d+)\.') # 12
+
+         weight_index = -1
+         blockid=["BASE","IN00","IN01","IN02","IN03","IN04","IN05","IN06","IN07","IN08","IN09","IN10","IN11","M00","OUT00","OUT01","OUT02","OUT03","OUT04","OUT05","OUT06","OUT07","OUT08","OUT09","OUT10","OUT11","Not Merge"]
+
+         NUM_INPUT_BLOCKS = 12
+         NUM_MID_BLOCK = 1
+         NUM_OUTPUT_BLOCKS = 12
+         NUM_TOTAL_BLOCKS = NUM_INPUT_BLOCKS + NUM_MID_BLOCK + NUM_OUTPUT_BLOCKS
+
+         if 'time_embed' in key:
+             weight_index = -2 # before input blocks
+         elif '.out.' in key:
+             weight_index = NUM_TOTAL_BLOCKS - 1 # after output blocks
+         else:
+             m = re_inp.search(key)
+             if m:
+                 inp_idx = int(m.groups()[0])
+                 weight_index = inp_idx
+             else:
+                 m = re_mid.search(key)
+                 if m:
+                     weight_index = NUM_INPUT_BLOCKS
+                 else:
+                     m = re_out.search(key)
+                     if m:
+                         out_idx = int(m.groups()[0])
+                         weight_index = NUM_INPUT_BLOCKS + NUM_MID_BLOCK + out_idx
+         keys.append([i,blockid[weight_index+1],key])
+     return keys
+
+ script_callbacks.on_ui_tabs(on_ui_tabs)
+ script_callbacks.on_ui_train_tabs(on_ui_train_tabs)
microsoftexcel-tunnels/.gitignore ADDED
@@ -0,0 +1,176 @@
+ # Created by https://www.toptal.com/developers/gitignore/api/python
+ # Edit at https://www.toptal.com/developers/gitignore?templates=python
+
+ ### Python ###
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/#use-with-ide
+ .pdm.toml
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
+
+ ### Python Patch ###
+ # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
+ poetry.toml
+
+ # ruff
+ .ruff_cache/
+
+ # End of https://www.toptal.com/developers/gitignore/api/python
+
+ id_rsa
+ id_rsa.pub
microsoftexcel-tunnels/.pre-commit-config.yaml ADDED
@@ -0,0 +1,25 @@
+ repos:
+   - repo: https://github.com/pre-commit/pre-commit-hooks
+     rev: v4.4.0
+     hooks:
+       - id: trailing-whitespace
+         args: [--markdown-linebreak-ext=md]
+       - id: end-of-file-fixer
+
+   - repo: https://github.com/asottile/pyupgrade
+     rev: v3.3.1
+     hooks:
+       - id: pyupgrade
+         args: [--py310-plus]
+
+   - repo: https://github.com/psf/black
+     rev: 23.1.0
+     hooks:
+       - id: black
+
+   - repo: https://github.com/charliermarsh/ruff-pre-commit
+     # Ruff version.
+     rev: "v0.0.244"
+     hooks:
+       - id: ruff
+         args: [--fix]
microsoftexcel-tunnels/LICENSE.md ADDED
@@ -0,0 +1,22 @@
+
+ The MIT License (MIT)
+
+ Copyright (c) 2023 Bingsu
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
microsoftexcel-tunnels/README.md ADDED
@@ -0,0 +1,21 @@
+ # sd-webui-tunnels
+
+ Tunneling extension for [AUTOMATIC1111/stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui)
+
+ ## Usage
+
+ ### [cloudflared](https://try.cloudflare.com/)
+
+ Add `--cloudflared` to the command-line options.
+
+ ### [localhost.run](https://localhost.run/)
+
+ Add `--localhostrun` to the command-line options.
+
+ ### [remote.moe](https://github.com/fasmide/remotemoe)
+
+ Add `--remotemoe` to the command-line options.
+
+ A feature of `remote.moe` is that the same URL is generated as long as the same SSH key is used.
+
+ The SSH keys for `localhost.run` and `remote.moe` are created with the name `id_rsa` in the extension's root folder. If the key cannot be written there (for example, due to missing write permission), it is created in a temporary folder instead, so a different URL is generated on each run.
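The key generation described above can be sketched in Python. This mirrors what the extension's `gen_key()` helper in `scripts/ssh_tunnel.py` does: a 4096-bit RSA key with an empty passphrase, generated via the system `ssh-keygen`. The temporary directory here is illustrative; the extension writes `id_rsa` into its own folder so the key (and thus the `remote.moe` URL) is reused across runs.

```python
# Minimal sketch of the extension's gen_key() behavior (assumes `ssh-keygen`
# is available on PATH). Writing to a temp dir for illustration only.
import subprocess
import tempfile
from pathlib import Path

key_path = Path(tempfile.mkdtemp()) / "id_rsa"
subprocess.run(
    ["ssh-keygen", "-t", "rsa", "-b", "4096", "-N", "", "-q", "-f", key_path.as_posix()],
    check=True,
)
key_path.chmod(0o600)  # ssh refuses private keys with loose permissions
```

Because `remote.moe` derives the tunnel hostname from the public key, deleting `id_rsa` forces a new key and therefore a new URL on the next launch.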
microsoftexcel-tunnels/__pycache__/preload.cpython-310.pyc ADDED
Binary file (623 Bytes). View file
 
microsoftexcel-tunnels/install.py ADDED
@@ -0,0 +1,4 @@
+ import launch
+
+ if not launch.is_installed("pycloudflared"):
+     launch.run_pip("install pycloudflared", "pycloudflared")
microsoftexcel-tunnels/preload.py ADDED
@@ -0,0 +1,21 @@
+ import argparse
+
+
+ def preload(parser: argparse.ArgumentParser):
+     parser.add_argument(
+         "--cloudflared",
+         action="store_true",
+         help="use trycloudflare, alternative to gradio --share",
+     )
+
+     parser.add_argument(
+         "--localhostrun",
+         action="store_true",
+         help="use localhost.run, alternative to gradio --share",
+     )
+
+     parser.add_argument(
+         "--remotemoe",
+         action="store_true",
+         help="use remote.moe, alternative to gradio --share",
+     )
microsoftexcel-tunnels/pyproject.toml ADDED
@@ -0,0 +1,25 @@
+ [project]
+ name = "sd-webui-tunnels"
+ version = "23.2.1"
+ description = "Tunneling extension for automatic1111 sd-webui"
+ authors = [
+     {name = "dowon", email = "ks2515@naver.com"},
+ ]
+ requires-python = ">=3.8"
+ readme = "README.md"
+ license = {text = "MIT"}
+
+ [project.urls]
+ repository = "https://github.com/Bing-su/sd-webui-tunnels"
+
+ [tool.isort]
+ profile = "black"
+ known_first_party = ["modules", "launch"]
+
+ [tool.ruff]
+ select = ["A", "B", "C4", "E", "F", "I001", "N", "PT", "UP", "W"]
+ ignore = ["B008", "B905", "E501"]
+ unfixable = ["F401"]
+
+ [tool.ruff.isort]
+ known-first-party = ["modules", "launch"]
microsoftexcel-tunnels/scripts/__pycache__/ssh_tunnel.cpython-310.pyc ADDED
Binary file (2.34 kB). View file
 
microsoftexcel-tunnels/scripts/__pycache__/try_cloudflare.cpython-310.pyc ADDED
Binary file (597 Bytes). View file
 
microsoftexcel-tunnels/scripts/ssh_tunnel.py ADDED
@@ -0,0 +1,81 @@
+ import atexit
+ import re
+ import shlex
+ import subprocess
+ from pathlib import Path
+ from tempfile import TemporaryDirectory
+ from typing import Union
+ from gradio import strings
+ import os
+
+ from modules.shared import cmd_opts
+
+ LOCALHOST_RUN = "localhost.run"
+ REMOTE_MOE = "remote.moe"
+ localhostrun_pattern = re.compile(r"(?P<url>https?://\S+\.lhr\.life)")
+ remotemoe_pattern = re.compile(r"(?P<url>https?://\S+\.remote\.moe)")
+
+
+ def gen_key(path: Union[str, Path]) -> None:
+     path = Path(path)
+     arg_string = f'ssh-keygen -t rsa -b 4096 -N "" -q -f {path.as_posix()}'
+     args = shlex.split(arg_string)
+     subprocess.run(args, check=True)
+     path.chmod(0o600)
+
+
+ def ssh_tunnel(host: str = LOCALHOST_RUN) -> None:
+     ssh_name = "id_rsa"
+     ssh_path = Path(__file__).parent.parent / ssh_name
+
+     tmp = None
+     if not ssh_path.exists():
+         try:
+             gen_key(ssh_path)
+         # write permission error or etc
+         except subprocess.CalledProcessError:
+             tmp = TemporaryDirectory()
+             ssh_path = Path(tmp.name) / ssh_name
+             gen_key(ssh_path)
+
+     port = cmd_opts.port if cmd_opts.port else 7860
+
+     arg_string = f"ssh -R 80:127.0.0.1:{port} -o StrictHostKeyChecking=no -i {ssh_path.as_posix()} {host}"
+     args = shlex.split(arg_string)
+
+     tunnel = subprocess.Popen(
+         args, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, encoding="utf-8"
+     )
+
+     atexit.register(tunnel.terminate)
+     if tmp is not None:
+         atexit.register(tmp.cleanup)
+
+     tunnel_url = ""
+     lines = 27 if host == LOCALHOST_RUN else 5
+     pattern = localhostrun_pattern if host == LOCALHOST_RUN else remotemoe_pattern
+
+     for _ in range(lines):
+         line = tunnel.stdout.readline()
+         if line.startswith("Warning"):
+             print(line, end="")
+
+         url_match = pattern.search(line)
+         if url_match:
+             tunnel_url = url_match.group("url")
+             break
+     else:
+         raise RuntimeError(f"Failed to run {host}")
+
+     # print(f" * Running on {tunnel_url}")
+     os.environ['webui_url'] = tunnel_url
+     colab_url = os.getenv('colab_url')
+     strings.en["SHARE_LINK_MESSAGE"] = f"Running on public URL (recommended): {tunnel_url}"
+
+ if cmd_opts.localhostrun:
+     print("localhost.run detected, trying to connect...")
+     ssh_tunnel(LOCALHOST_RUN)
+
+ if cmd_opts.remotemoe:
+     print("remote.moe detected, trying to connect...")
+     ssh_tunnel(REMOTE_MOE)
microsoftexcel-tunnels/scripts/try_cloudflare.py ADDED
@@ -0,0 +1,15 @@
+ # credit to camenduru senpai
+ from pycloudflared import try_cloudflare
+
+ from modules.shared import cmd_opts
+
+ from gradio import strings
+
+ import os
+
+ if cmd_opts.cloudflared:
+     print("cloudflared detected, trying to connect...")
+     port = cmd_opts.port if cmd_opts.port else 7860
+     tunnel_url = try_cloudflare(port=port, verbose=False)
+     os.environ['webui_url'] = tunnel_url.tunnel
+     strings.en["PUBLIC_SHARE_TRUE"] = f"Running on public URL: {tunnel_url.tunnel}"
microsoftexcel-tunnels/ssh_tunnel.py ADDED
@@ -0,0 +1,86 @@
+ import atexit
+ import re
+ import shlex
+ import subprocess
+ from pathlib import Path
+ from tempfile import TemporaryDirectory
+ from typing import Union
+ from gradio import strings
+ import os
+
+ from modules.shared import cmd_opts
+
+ LOCALHOST_RUN = "localhost.run"
+ REMOTE_MOE = "remote.moe"
+ localhostrun_pattern = re.compile(r"(?P<url>https?://\S+\.lhr\.life)")
+ remotemoe_pattern = re.compile(r"(?P<url>https?://\S+\.remote\.moe)")
+
+
+ def gen_key(path: Union[str, Path]) -> None:
+     path = Path(path)
+     arg_string = f'ssh-keygen -t rsa -b 4096 -N "" -q -f {path.as_posix()}'
+     args = shlex.split(arg_string)
+     subprocess.run(args, check=True)
+     path.chmod(0o600)
+
+
+ def ssh_tunnel(host: str = LOCALHOST_RUN) -> None:
+     ssh_name = "id_rsa"
+     ssh_path = Path(__file__).parent.parent / ssh_name
+
+     tmp = None
+     if not ssh_path.exists():
+         try:
+             gen_key(ssh_path)
+         # write permission error or etc
+         except subprocess.CalledProcessError:
+             tmp = TemporaryDirectory()
+             ssh_path = Path(tmp.name) / ssh_name
+             gen_key(ssh_path)
+
+     port = cmd_opts.port if cmd_opts.port else 7860
+
+     arg_string = f"ssh -R 80:127.0.0.1:{port} -o StrictHostKeyChecking=no -i {ssh_path.as_posix()} {host}"
+     args = shlex.split(arg_string)
+
+     tunnel = subprocess.Popen(
+         args, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, encoding="utf-8"
+     )
+
+     atexit.register(tunnel.terminate)
+     if tmp is not None:
+         atexit.register(tmp.cleanup)
+
+     tunnel_url = ""
+     lines = 27 if host == LOCALHOST_RUN else 5
+     pattern = localhostrun_pattern if host == LOCALHOST_RUN else remotemoe_pattern
+
+     for _ in range(lines):
+         line = tunnel.stdout.readline()
+         if line.startswith("Warning"):
+             print(line, end="")
+
+         url_match = pattern.search(line)
+         if url_match:
+             tunnel_url = url_match.group("url")
+             break
+     else:
+         raise RuntimeError(f"Failed to run {host}")
+
+     # print(f" * Running on {tunnel_url}")
+     os.environ['webui_url'] = tunnel_url
+     colab_url = os.getenv('colab_url')
+     strings.en["SHARE_LINK_MESSAGE"] = f"Public WebUI Colab URL: {tunnel_url}"
+
+
+ def googleusercontent_tunnel():
+     colab_url = os.getenv('colab_url')
+     strings.en["SHARE_LINK_MESSAGE"] = f"WebUI Colab URL: {colab_url}"
+
+ if cmd_opts.localhostrun:
+     print("localhost.run detected, trying to connect...")
+     ssh_tunnel(LOCALHOST_RUN)
+
+ if cmd_opts.remotemoe:
+     print("remote.moe detected, trying to connect...")
+     ssh_tunnel(REMOTE_MOE)
openpose-editor/.github/CODEOWNERS ADDED
@@ -0,0 +1,2 @@
+ # default reviewer
+ * @fkunn1326
openpose-editor/.github/ISSUE_TEMPLATE/bug_report.md ADDED
@@ -0,0 +1,31 @@
+ ---
+ name: Bug report
+ about: Create a report to help us improve
+ title: ''
+ labels: bug
+ assignees: ''
+
+ ---
+
+ **Describe the bug**
+ A clear and concise description of what the bug is.
+
+ **To Reproduce**
+ Steps to reproduce the behavior:
+ 1. Go to '...'
+ 2. Click on '....'
+ 3. Scroll down to '....'
+ 4. See error
+
+ **Expected behavior**
+ A clear and concise description of what you expected to happen.
+
+ **Screenshots**
+ If applicable, add screenshots to help explain your problem.
+
+ **Environment**
+ - OS: [e.g. iOS]
+ - Browser: [e.g. chrome, safari]
+
+ **Additional context**
+ Add any other context about the problem here.
openpose-editor/.github/ISSUE_TEMPLATE/feature_request.md ADDED
@@ -0,0 +1,17 @@
+ ---
+ name: Feature request
+ about: Suggest an idea for this project
+ title: ''
+ labels: enhancement
+ assignees: ''
+
+ ---
+
+ **Is your feature request related to a problem? Please describe.**
+ A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
+
+ **Describe the solution you'd like**
+ A clear and concise description of what you want to happen.
+
+ **Additional context**
+ Add any other context or screenshots about the feature request here.
openpose-editor/.github/workflows/typos.yml ADDED
@@ -0,0 +1,21 @@
+ ---
+ # yamllint disable rule:line-length
+ name: Typos
+
+ on:  # yamllint disable-line rule:truthy
+   push:
+   pull_request:
+     types:
+       - opened
+       - synchronize
+       - reopened
+
+ jobs:
+   build:
+     runs-on: ubuntu-latest
+
+     steps:
+       - uses: actions/checkout@v3
+
+       - name: typos-action
+         uses: crate-ci/typos@v1.13.10
openpose-editor/.gitignore ADDED
@@ -0,0 +1,2 @@
+ __pycache__
+ models/
openpose-editor/.vscode/settings.json ADDED
@@ -0,0 +1,5 @@
+ {
+     "nuxt.isNuxtApp": false,
+     "editor.tabCompletion": "on",
+     "diffEditor.codeLens": true
+ }
openpose-editor/LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2023 Fkunn1326
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
openpose-editor/README.en.md ADDED
@@ -0,0 +1,40 @@
+ ## Openpose Editor
+
+ [日本語](README.md) | English | [简体中文](README.zh-cn.md)
+
+ ![image](https://user-images.githubusercontent.com/92153597/219921945-468b2e4f-a3a0-4d44-a923-13ceb0258ddc.png)
+
+ Openpose Editor for Automatic1111/stable-diffusion-webui
+
+ - Pose editing
+ - Pose detection
+
+ This extension can:
+
+ - Add a new person
+ - Detect a pose from an image
+ - Add a background image
+
+ - Save as a PNG
+ - Send the image to the ControlNet extension
+
+ ## Installation
+
+ 1. Open the "Extension" tab
+ 2. Click on "Install from URL"
+ 3. In "URL for extension's git repository", enter this repository's URL: https://github.com/fkunn1326/openpose-editor.git
+ 4. Click "Install"
+ 5. Restart the WebUI
+
+ ## Attention
+
+ Do not select anything for the Preprocessor in ControlNet.
+
+
+ ## Fixing Errors
+ > urllib.error.URLError: <urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: unable to get local issuer certificate (_ssl.c:997)>
+
+ Run
+ ```
+ /Applications/Python\ $version /Install\ Certificates.command
+ ```
openpose-editor/README.md ADDED
@@ -0,0 +1,41 @@
+ ## Openpose Editor
+
+ 日本語 | [English](README.en.md) | [简体中文](README.zh-cn.md)
+
+ ![image](https://user-images.githubusercontent.com/92153597/219921945-468b2e4f-a3a0-4d44-a923-13ceb0258ddc.png)
+
+ Openpose Editor for Automatic1111/stable-diffusion-webui
+
+ It supports:
+
+ - Pose editing
+ - Pose detection
+
+ Available operations:
+
+ - "Add": add a person
+ - "Detect from image": detect a pose from an image
+ - "Add Background image": add a background image
+
+ - "Save PNG": save as a PNG
+ - "Send to ControlNet": if the ControlNet extension is installed, send the image to it
+
+ ## Installation
+
+ 1. Open the "Extension" tab
+ 2. Open the "Install from URL" tab
+ 3. Enter this repository's URL (https://github.com/fkunn1326/openpose-editor.git) in the "URL for extension's git repository" field
+ 4. Press the "Install" button
+ 5. Restart the WebUI
+
+ ## Attention
+
+ Do not specify anything for the ControlNet "Preprocessor".
+
+ ## Fixing Errors
+
+ > urllib.error.URLError: <urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: unable to get local issuer certificate (_ssl.c:997)>
+
+ Run the following file:
+ ```
+ /Applications/Python\ $version /Install\ Certificates.command
+ ```
openpose-editor/README.zh-cn.md ADDED
@@ -0,0 +1,41 @@
+ ## Openpose Editor
+
+ [日本語](README.md) | [English](README.en.md) | 中文
+
+ ![image](https://user-images.githubusercontent.com/92153597/219921945-468b2e4f-a3a0-4d44-a923-13ceb0258ddc.png)
+
+ Openpose Editor extension for Automatic1111/stable-diffusion-webui.
+
+ Features:
+ - Edit skeleton poses directly
+ - Detect poses from images
+
+ The extension supports the following operations:
+
+ - "Add": add a new skeleton
+ - "Detect from image": detect a pose from an image
+ - "Add Background image": add a background image
+ - "Load JSON": load a JSON file
+
+ - "Save PNG": save as a PNG image
+ - "Send to ControlNet": send the skeleton pose to ControlNet
+ - "Save JSON": save the skeleton as JSON
+
+ ## Installation
+
+ 1. Open the Extension tab
+ 2. Click "Install from URL"
+ 3. Enter https://github.com/fkunn1326/openpose-editor.git in the "URL for extension's git repository" field
+ 4. Click "Install"
+ 5. Restart the WebUI
+
+ ## Attention
+
+ Do not set ControlNet's "Preprocessor" to anything; leave it as none.
+
+ ## FAQ
+ On macOS you may see:
+ > urllib.error.URLError: <urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: unable to get local issuer certificate (_ssl.c:997)>
+
+ Run the file:
+ ```
+ /Applications/Python\ $version /Install\ Certificates.command
+ ```
openpose-editor/_typos.toml ADDED
@@ -0,0 +1,11 @@
+ # Files for typos
+ # Instruction: https://github.com/marketplace/actions/typos-action#getting-started
+
+ [default.extend-identifiers]
+
+ [default.extend-words]
+
+
+ [files]
+ extend-exclude = ["_typos.toml", "fabric.js"]
openpose-editor/configs/.gitkeep ADDED
File without changes
openpose-editor/images/スクリーンショット 2023-02-19 131430.png ADDED
openpose-editor/javascript/fabric.js ADDED
The diff for this file is too large to render. See raw diff
 
openpose-editor/javascript/main.js ADDED
@@ -0,0 +1,691 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fabric.Object.prototype.transparentCorners = false;
2
+ fabric.Object.prototype.cornerColor = '#108ce6';
3
+ fabric.Object.prototype.borderColor = '#108ce6';
4
+ fabric.Object.prototype.cornerSize = 10;
5
+
6
+ let count = 0;
7
+ let executed_openpose_editor = false;
8
+
9
+ let lockMode = false;
10
+ const undo_history = [];
11
+ const redo_history = [];
12
+
13
+ coco_body_keypoints = [
14
+ "nose",
15
+ "neck",
16
+ "right_shoulder",
17
+ "right_elbow",
18
+ "right_wrist",
19
+ "left_shoulder",
20
+ "left_elbow",
21
+ "left_wrist",
22
+ "right_hip",
23
+ "right_knee",
24
+ "right_ankle",
25
+ "left_hip",
26
+ "left_knee",
27
+ "left_ankle",
28
+ "right_eye",
29
+ "left_eye",
30
+ "right_ear",
31
+ "left_ear",
32
+ ]
33
+
34
+ let connect_keypoints = [[0, 1], [1, 2], [2, 3], [3, 4], [1, 5], [5, 6], [6, 7], [1, 8], [8, 9], [9, 10], [1, 11], [11, 12], [12, 13], [14, 0], [14, 16], [15, 0], [15, 17]]
35
+
36
+ let connect_color = [[0, 0, 255], [255, 0, 0], [255, 170, 0], [255, 255, 0], [255, 85, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0],
37
+ [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [85, 0, 255],
38
+ [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
39
+
40
+ let openpose_obj = {
41
+ // width, height
42
+ resolution: [512, 512],
43
+ // fps...?
44
+ fps: 1,
45
+ // frames
46
+ frames: [
47
+ {
48
+ frame_current: 1,
49
+ // armatures
50
+ armatures: {
51
+ },
52
+ }
53
+ ]
54
+ }
55
+
56
+ let visibleEyes = true;
57
+ let flipped = false;
58
+ const default_keypoints = [[241,77],[241,120],[191,118],[177,183],[163,252],[298,118],[317,182],[332,245],[225,241],[213,359],[215,454],[270,240],[282,360],[286,456],[232,59],[253,60],[225,70],[260,72]]
59
+
60
+ async function fileToDataUrl(file) {
61
+ if (file.data) {
62
+ // Gradio version < 3.23
63
+ return file.data
64
+ }
65
+ return await new Promise(r => {let a=new FileReader(); a.onload=r; a.readAsDataURL(file.blob)}).then(e => e.target.result)
66
+ }
67
+
68
+ function calcResolution(width, height){
69
+ const viewportWidth = window.innerWidth / 2.25;
70
+ const viewportHeight = window.innerHeight * 0.75;
71
+ const ratio = Math.min(viewportWidth / width, viewportHeight / height);
72
+ return {width: width * ratio, height: height * ratio}
73
+ }
74
+
75
+ function resizeCanvas(width, height){
76
+ const elem = openpose_editor_elem;
77
+ const canvas = openpose_editor_canvas;
78
+
79
+ let resolution = calcResolution(width, height)
80
+
81
+ canvas.setWidth(width);
82
+ canvas.setHeight(height);
83
+ elem.style.width = resolution["width"] + "px"
84
+ elem.style.height = resolution["height"] + "px"
85
+ elem.nextElementSibling.style.width = resolution["width"] + "px"
86
+ elem.nextElementSibling.style.height = resolution["height"] + "px"
87
+ elem.parentElement.style.width = resolution["width"] + "px"
88
+ elem.parentElement.style.height = resolution["height"] + "px"
89
+ }
90
+
91
+ function undo() {
92
+ const canvas = openpose_editor_canvas;
93
+ if (undo_history.length > 0) {
94
+ lockMode = true;
95
+ if (undo_history.length > 1) redo_history.push(undo_history.pop());
96
+ const content = undo_history[undo_history.length - 1];
97
+ canvas.loadFromJSON(content, function () {
98
+ canvas.renderAll();
99
+ lockMode = false;
100
+ });
101
+ }
102
+ }
103
+
104
+ function redo() {
105
+ const canvas = openpose_editor_canvas;
106
+ if (redo_history.length > 0) {
107
+ lockMode = true;
108
+ const content = redo_history.pop();
109
+ undo_history.push(content);
110
+ canvas.loadFromJSON(content, function () {
111
+ canvas.renderAll();
112
+ lockMode = false;
113
+ });
114
+ }
115
+ }
116
+
+ function setPose(keypoints){
+ const canvas = openpose_editor_canvas;
+
+ canvas.clear()
+
+ canvas.backgroundColor = "#000"
+
+ // Split the flat keypoint list into one 18-point chunk per figure.
+ const res = [];
+ for (let i = 0; i < keypoints.length; i += 18) {
+ const chunk = keypoints.slice(i, i + 18);
+ res.push(chunk);
+ }
+
+ for (const item of res){
+ addPose(item)
+ openpose_editor_canvas.discardActiveObject();
+ }
+ }
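`setPose` above slices the flat keypoint list into one 18-point chunk per figure before drawing each pose. The chunking step can be sketched in Python (names here are illustrative, not part of the extension):

```python
def split_poses(keypoints, size=18):
    # One OpenPose body has 18 keypoints; a flat list of N*18 points
    # therefore encodes N figures, sliced in order.
    return [keypoints[i:i + size] for i in range(0, len(keypoints), size)]

points = [[x, x] for x in range(36)]  # two 18-point figures
poses = split_poses(points)
```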
+
+ function addPose(keypoints=undefined){
+ if (keypoints === undefined){
+ keypoints = default_keypoints;
+ }
+
+ const canvas = openpose_editor_canvas;
+ const group = new fabric.Group()
+
+ function makeCircle(color, left, top, line1, line2, line3, line4, line5) {
+ var c = new fabric.Circle({
+ left: left,
+ top: top,
+ strokeWidth: 1,
+ radius: 5,
+ fill: color,
+ stroke: color,
+ originX: 'center',
+ originY: 'center',
+ });
+ c.hasControls = c.hasBorders = false;
+
+ c.line1 = line1;
+ c.line2 = line2;
+ c.line3 = line3;
+ c.line4 = line4;
+ c.line5 = line5;
+
+ return c;
+ }
+
+ function makeLine(coords, color) {
+ return new fabric.Line(coords, {
+ fill: color,
+ stroke: color,
+ strokeWidth: 10,
+ selectable: false,
+ evented: false,
+ originX: 'center',
+ originY: 'center',
+ });
+ }
+
+ const lines = []
+ const circles = []
+
+ for (let i = 0; i < connect_keypoints.length; i++){
+ // Indices of the keypoints to connect; [0, 1] joins keypoint 0 to keypoint 1.
+ const item = connect_keypoints[i]
+ const line = makeLine(keypoints[item[0]].concat(keypoints[item[1]]), `rgba(${connect_color[i].join(", ")}, 0.7)`)
+ lines.push(line)
+ canvas.add(line)
+ line['id'] = item[0];
+ }
+
+ for (let i = 0; i < keypoints.length; i++){
+ const list = []
+ connect_keypoints.filter((item, idx) => {
+ if(item.includes(i)){
+ list.push(lines[idx])
+ return idx
+ }
+ })
+ const circle = makeCircle(`rgb(${connect_color[i].join(", ")})`, keypoints[i][0], keypoints[i][1], ...list)
+ circle["id"] = i
+ circles.push(circle)
+ // canvas.add(circle)
+ group.addWithUpdate(circle);
+ }
+
+ canvas.discardActiveObject();
+ canvas.setActiveObject(group);
+ canvas.add(group);
+ group.toActiveSelection();
+ canvas.requestRenderAll();
+ }
+
+ function initCanvas(elem){
+ const canvas = window.openpose_editor_canvas = new fabric.Canvas(elem, {
+ backgroundColor: '#000',
+ // selection: false,
+ preserveObjectStacking: true
+ });
+
+ window.openpose_editor_elem = elem
+
+ function updateLines(target) {
+ if ("_objects" in target) {
+ const flipX = target.flipX ? -1 : 1;
+ const flipY = target.flipY ? -1 : 1;
+ const flipped = flipX * flipY === -1;
+ const showEyes = flipped ? !visibleEyes : visibleEyes;
+
+ if (target.angle === 0) {
+ const rtop = target.top
+ const rleft = target.left
+ for (const item of target._objects){
+ let p = item;
+ p.scaleX = 1;
+ p.scaleY = 1;
+ const top = rtop + p.top * target.scaleY * flipY + target.height * target.scaleY / 2;
+ const left = rleft + p.left * target.scaleX * flipX + (target.width * target.scaleX / 2);
+ p['_top'] = top;
+ p['_left'] = left;
+ if (p["id"] === 0) {
+ p.line1 && p.line1.set({ 'x1': left, 'y1': top });
+ }else{
+ p.line1 && p.line1.set({ 'x2': left, 'y2': top });
+ }
+ if (p['id'] === 14 || p['id'] === 15) {
+ p.radius = showEyes ? 5 : 0;
+ if (p.line1) p.line1.strokeWidth = showEyes ? 10 : 0;
+ if (p.line2) p.line2.strokeWidth = showEyes ? 10 : 0;
+ }
+ p.line2 && p.line2.set({ 'x1': left, 'y1': top });
+ p.line3 && p.line3.set({ 'x1': left, 'y1': top });
+ p.line4 && p.line4.set({ 'x1': left, 'y1': top });
+ p.line5 && p.line5.set({ 'x1': left, 'y1': top });
+ }
+ } else {
+ const aCoords = target.aCoords;
+ const center = {'x': (aCoords.tl.x + aCoords.br.x)/2, 'y': (aCoords.tl.y + aCoords.br.y)/2};
+ const rad = target.angle * Math.PI / 180;
+ const sin = Math.sin(rad);
+ const cos = Math.cos(rad);
+
+ for (const item of target._objects){
+ let p = item;
+ const p_top = p.top * target.scaleY * flipY;
+ const p_left = p.left * target.scaleX * flipX;
+ const left = center.x + p_left * cos - p_top * sin;
+ const top = center.y + p_left * sin + p_top * cos;
+ p['_top'] = top;
+ p['_left'] = left;
+ if (p["id"] === 0) {
+ p.line1 && p.line1.set({ 'x1': left, 'y1': top });
+ }else{
+ p.line1 && p.line1.set({ 'x2': left, 'y2': top });
+ }
+ if (p['id'] === 14 || p['id'] === 15) {
+ p.radius = showEyes ? 5 : 0.3;
+ if (p.line1) p.line1.strokeWidth = showEyes ? 10 : 0;
+ if (p.line2) p.line2.strokeWidth = showEyes ? 10 : 0;
+ }
+ p.line2 && p.line2.set({ 'x1': left, 'y1': top });
+ p.line3 && p.line3.set({ 'x1': left, 'y1': top });
+ p.line4 && p.line4.set({ 'x1': left, 'y1': top });
+ p.line5 && p.line5.set({ 'x1': left, 'y1': top });
+ }
+ }
+ } else {
+ var p = target;
+ if (p["id"] === 0) {
+ p.line1 && p.line1.set({ 'x1': p.left, 'y1': p.top });
+ }else{
+ p.line1 && p.line1.set({ 'x2': p.left, 'y2': p.top });
+ }
+ p.line2 && p.line2.set({ 'x1': p.left, 'y1': p.top });
+ p.line3 && p.line3.set({ 'x1': p.left, 'y1': p.top });
+ p.line4 && p.line4.set({ 'x1': p.left, 'y1': p.top });
+ p.line5 && p.line5.set({ 'x1': p.left, 'y1': p.top });
+ }
+ canvas.renderAll();
+ }
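The rotated branch of `updateLines` maps each group-local keypoint into canvas coordinates by rotating the (scaled, possibly flipped) offset around the group's center with the usual cos/sin rotation. A Python sketch of that point transform (function name and defaults are illustrative):

```python
import math

def to_canvas_coords(px, py, cx, cy, angle_deg, scale_x=1.0, scale_y=1.0):
    # Rotate a group-local point (px, py) around the group's center (cx, cy),
    # mirroring the left/top formulas in updateLines() above.
    rad = math.radians(angle_deg)
    sx, sy = px * scale_x, py * scale_y
    left = cx + sx * math.cos(rad) - sy * math.sin(rad)
    top = cy + sx * math.sin(rad) + sy * math.cos(rad)
    return left, top
```

A point 10px to the right of the center, rotated 90 degrees, ends up 10px below it.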
+
+ canvas.on('object:moving', function(e) {
+ updateLines(e.target);
+ });
+
+ canvas.on('object:scaling', function(e) {
+ updateLines(e.target);
+ canvas.renderAll();
+ });
+
+ canvas.on('object:rotating', function(e) {
+ updateLines(e.target);
+ canvas.renderAll();
+ });
+
+ canvas.on("object:modified", function () {
+ if (lockMode) return;
+ undo_history.push(JSON.stringify(canvas));
+ redo_history.length = 0;
+ });
+
+ resizeCanvas(...openpose_obj.resolution)
+
+ setPose(default_keypoints)
+
+ undo_history.push(JSON.stringify(canvas));
+
+ const json_observer = new MutationObserver((m) => {
+ if(gradioApp().querySelector('#tab_openpose_editor').style.display!=='block') return;
+ try {
+ const raw = gradioApp().querySelector("#jsonbox").querySelector("textarea").value
+ if(raw.length!==0) detectImage(raw);
+ } catch(e){console.log(e)}
+ })
+ json_observer.observe(gradioApp().querySelector("#jsonbox"), { "attributes": true })
+
+ // document.addEventListener('keydown', function(e) {
+ // if (e.key !== undefined) {
+ // if((e.key == "z" && (e.metaKey || e.ctrlKey || e.altKey))) undo()
+ // if((e.key == "y" && (e.metaKey || e.ctrlKey || e.altKey))) redo()
+ // }
+ // })
+ }
+
+ function resetCanvas(){
+ const canvas = openpose_editor_canvas;
+ canvas.clear()
+ canvas.backgroundColor = "#000"
+ }
+
+ function savePNG(){
+ openpose_editor_canvas.getObjects("image").forEach((img) => {
+ img.set({
+ opacity: 0
+ });
+ })
+ if (openpose_editor_canvas.backgroundImage) openpose_editor_canvas.backgroundImage.opacity = 0
+ openpose_editor_canvas.discardActiveObject();
+ openpose_editor_canvas.renderAll()
+ openpose_editor_elem.toBlob((blob) => {
+ const a = document.createElement("a");
+ a.href = URL.createObjectURL(blob);
+ a.download = "pose.png";
+ a.click();
+ URL.revokeObjectURL(a.href);
+ });
+ openpose_editor_canvas.getObjects("image").forEach((img) => {
+ img.set({
+ opacity: 1
+ });
+ })
+ if (openpose_editor_canvas.backgroundImage) openpose_editor_canvas.backgroundImage.opacity = 0.5
+ openpose_editor_canvas.renderAll()
+ return openpose_editor_canvas
+ }
+
+ function serializeJSON(){
+ const json = JSON.stringify({
+ "width": openpose_editor_canvas.width,
+ "height": openpose_editor_canvas.height,
+ "keypoints": openpose_editor_canvas.getObjects().filter((item) => {
+ if (item.type === "circle") return item
+ }).map((item) => {
+ return [Math.round(item.left), Math.round(item.top)]
+ })
+ }, null, 4)
+ return json;
+ }
+
+ function saveJSON(){
+ const json = serializeJSON()
+ const blob = new Blob([json], {
+ type: "application/json"
+ });
+ const filename = "pose-" + Date.now().toString() + ".json"
+ const a = document.createElement("a");
+ a.href = URL.createObjectURL(blob);
+ a.download = filename;
+ a.click();
+ URL.revokeObjectURL(a.href);
+ }
+
+ async function loadJSON(file){
+ const url = await fileToDataUrl(file)
+ const response = await fetch(url)
+ const json = await response.json()
+ if (json["width"] && json["height"]) {
+ resizeCanvas(json["width"], json["height"])
+ }else{
+ throw new Error('invalid width or height');
+ }
+ if (json["keypoints"].length % 18 === 0) {
+ setPose(json["keypoints"])
+ }else{
+ throw new Error('invalid keypoints')
+ }
+ return [json["width"], json["height"]]
+ }
+
+ function savePreset(){
+ var name = prompt("Preset Name")
+ const json = serializeJSON()
+ return [name, json]
+ }
+
+ function loadPreset(json){
+ try {
+ json = JSON.parse(json)
+ if (json["width"] && json["height"]) {
+ resizeCanvas(json["width"], json["height"])
+ }else{
+ throw new Error('invalid width or height');
+ }
+ if (json["keypoints"].length % 18 === 0) {
+ setPose(json["keypoints"])
+ }else{
+ throw new Error('invalid keypoints')
+ }
+ return [json["width"], json["height"]]
+ }catch(e){
+ console.error(e)
+ alert("Invalid JSON")
+ }
+ }
+
+ async function addBackground(file){
+ const url = await fileToDataUrl(file)
+ openpose_editor_canvas.setBackgroundImage(url, openpose_editor_canvas.renderAll.bind(openpose_editor_canvas), {
+ opacity: 0.5
+ });
+ const img = new Image();
+ img.src = url;
+ // Awaiting a bare assignment never waits for the image; decode() does,
+ // so width/height are valid before resizing the canvas.
+ await img.decode();
+ resizeCanvas(img.width, img.height)
+ return [img.width, img.height]
+ }
+
+ function detectImage(raw){
+ const json = JSON.parse(raw)
+
+ let candidate = json["candidate"]
+ let subset = json["subset"]
+ const li = []
+ subset = subset.splice(0, 18)
+ for (let i = 0; subset.length > i; i++){
+ if (Number.isInteger(subset[i]) && subset[i] >= 0){
+ li.push(candidate[subset[i]])
+ }else{
+ li.push([-1,-1])
+ }
+ }
+ if(li.length === 0){
+ const bgimage = openpose_editor_canvas.backgroundImage
+ setPose(li);
+ openpose_editor_canvas.backgroundImage = bgimage
+ return;
+ }
+ if(li.every(([x,y])=>x===-1&&y===-1)){
+ const ra_width = Math.floor(Math.random() * openpose_editor_canvas.width)
+ const ra_height = Math.floor(Math.random() * openpose_editor_canvas.height)
+ li[0] = [ra_width, ra_height]
+ }
+ const default_relative_keypoints = []
+ for(let i=0;i<default_keypoints.length;i++){
+ default_relative_keypoints.push([])
+ for(let j=0;j<default_keypoints.length;j++){
+ const x = default_keypoints[j][0] - default_keypoints[i][0];
+ const y = default_keypoints[j][1] - default_keypoints[i][1];
+ default_relative_keypoints[i].push([x,y])
+ }
+ }
+ const kp_connect = (i,j)=>{
+ for(let idx=0;idx<connect_keypoints.length;idx++){
+ const cp = connect_keypoints[idx];
+ if(((cp[0]===i)&&(cp[1]===j)) || ((cp[0]===j)&&(cp[1]===i))){
+ return true;
+ }
+ }
+ return false;
+ }
+ // Propagate known keypoints to missing ones along skeleton connections.
+ while(li.some(([x,y])=>x===-1&&y===-1)){
+ for(let i=0;i<li.length;i++){
+ if(li[i][0]===-1){
+ continue;
+ }
+ for(let j=0;j<li.length;j++){
+ if(li[j][0]===-1 && kp_connect(i,j)){
+ let x = li[i][0] + default_relative_keypoints[i][j][0]
+ let y = li[i][1] + default_relative_keypoints[i][j][1]
+ x = Math.min(Math.max(x, 0), openpose_editor_canvas.width);
+ y = Math.min(Math.max(y, 0), openpose_editor_canvas.height);
+ li[j] = [x,y];
+ }
+ }
+ }
+ }
+ const bgimage = openpose_editor_canvas.backgroundImage
+ setPose(li);
+ openpose_editor_canvas.backgroundImage = bgimage
+ }
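`detectImage` fills in keypoints the detector missed by repeatedly copying the default-skeleton offset across any connection joining a known joint to an unknown one. A pure-Python sketch of that propagation on a toy three-joint chain (all names and data are illustrative; it assumes the skeleton graph is connected, as the real one is):

```python
def fill_missing(points, defaults, edges):
    # Copy the default-skeleton offset across any edge joining a known
    # joint to an unknown ([-1, -1]) one, until every joint is placed.
    points = [list(p) for p in points]
    while any(p == [-1, -1] for p in points):
        for i, j in edges:
            for a, b in ((i, j), (j, i)):
                if points[a] != [-1, -1] and points[b] == [-1, -1]:
                    dx = defaults[b][0] - defaults[a][0]
                    dy = defaults[b][1] - defaults[a][1]
                    points[b] = [points[a][0] + dx, points[a][1] + dy]
    return points

defaults = [[0, 0], [10, 0], [20, 0]]   # default skeleton positions
edges = [(0, 1), (1, 2)]                # which joints are connected
filled = fill_missing([[5, 5], [-1, -1], [-1, -1]], defaults, edges)
```

The one detected joint anchors the figure; the remaining joints inherit the default skeleton's shape around it (the JS version additionally clamps results to the canvas bounds).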
+
+ function sendImage(type, index){
+ openpose_editor_canvas.getObjects("image").forEach((img) => {
+ img.set({
+ opacity: 0
+ });
+ })
+ if (openpose_editor_canvas.backgroundImage) openpose_editor_canvas.backgroundImage.opacity = 0
+ openpose_editor_canvas.discardActiveObject();
+ openpose_editor_canvas.renderAll()
+ openpose_editor_elem.toBlob((blob) => {
+ const file = new File(([blob]), "pose.png")
+ const dt = new DataTransfer();
+ dt.items.add(file);
+ const list = dt.files
+ const selector = type === "txt2img" ? "#txt2img_script_container" : "#img2img_script_container"
+ if (type === "txt2img"){
+ switch_to_txt2img()
+ }else if(type === "img2img"){
+ switch_to_img2img()
+ }
+
+ const isNew = window.gradio_config.version.replace("\n", "") >= "3.23.0"
+ const accordion_selector = isNew ? "#controlnet > .label-wrap > .icon" : "#controlnet .transition"
+ const accordion = gradioApp().querySelector(selector).querySelector(accordion_selector)
+ if (isNew ? accordion.style.transform == "rotate(90deg)" : accordion.classList.contains("rotate-90")) {
+ accordion.click()
+ }
+
+ let input = gradioApp().querySelector(selector).querySelector("#controlnet").querySelector("input[type='file']");
+
+ const input_image = (input) =>{
+ try {
+ if(input.previousElementSibling
+ && input.previousElementSibling.previousElementSibling
+ && input.previousElementSibling.previousElementSibling.querySelector("button[aria-label='Clear']")) {
+ input.previousElementSibling.previousElementSibling.querySelector("button[aria-label='Clear']").click()
+ }
+ } catch (e) {
+ console.error(e)
+ }
+ input.value = "";
+ input.files = list;
+ const event = new Event('change', { 'bubbles': true, "composed": true });
+ input.dispatchEvent(event);
+ }
+
+ if (input == null){
+ // MutationObserver invokes its callback as (mutations, observer).
+ const callback = (mutations, observer) => {
+ input = gradioApp().querySelector(selector).querySelector("#controlnet").querySelector("input[type='file']");
+ if (input == null) {
+ console.error('input[type=file] does not exist')
+ return
+ }else{
+ input_image(input)
+ observer.disconnect()
+ }
+ }
+ const observer = new MutationObserver(callback);
+ observer.observe(gradioApp().querySelector(selector).querySelector("#controlnet"), { childList: true });
+ }else{
+ input_image(input)
+ }
+ });
+ openpose_editor_canvas.getObjects("image").forEach((img) => {
+ img.set({
+ opacity: 1
+ });
+ })
+ if (openpose_editor_canvas.backgroundImage) openpose_editor_canvas.backgroundImage.opacity = 0.5
+ openpose_editor_canvas.renderAll()
+ }
+
+ function canvas_onDragOver(event) {
+ const canvas_drag_overlay = gradioApp().querySelector("#canvas_drag_overlay");
+
+ if (event.dataTransfer.items[0].type.startsWith("image/")) {
+ event.preventDefault();
+ canvas_drag_overlay.textContent = "Add Background";
+ canvas_drag_overlay.style.visibility = "visible";
+ } else if (event.dataTransfer.items[0].type == "application/json") {
+ event.preventDefault();
+ canvas_drag_overlay.textContent = "Load JSON";
+ canvas_drag_overlay.style.visibility = "visible";
+ }
+ }
+
+ function canvas_onDrop(event) {
+ const canvas_drag_overlay = gradioApp().querySelector("#canvas_drag_overlay");
+
+ if (event.dataTransfer.items[0].type.startsWith("image/")) {
+ event.preventDefault();
+ const input = gradioApp().querySelector("#openpose_bg_button").previousElementSibling;
+ input.files = event.dataTransfer.files;
+ const changeEvent = new Event('change', { 'bubbles': true, "composed": true });
+ input.dispatchEvent(changeEvent);
+ canvas_drag_overlay.style.visibility = "hidden";
+ } else if (event.dataTransfer.items[0].type == "application/json") {
+ event.preventDefault();
+ const input = gradioApp().querySelector("#openpose_json_button").previousElementSibling;
+ input.files = event.dataTransfer.files;
+ const changeEvent = new Event('change', { 'bubbles': true, "composed": true });
+ input.dispatchEvent(changeEvent);
+ canvas_drag_overlay.style.visibility = "hidden";
+ }
+ }
+
+ function button_onDragOver(event) {
+ if (((event.target.id == "openpose_detect_button" || event.target.id == "openpose_bg_button") && event.dataTransfer.items[0].type.startsWith("image/")) ||
+ (event.target.id == "openpose_json_button" && event.dataTransfer.items[0].type == "application/json")) {
+ event.preventDefault();
+ event.target.classList.remove("gr-button-secondary");
+ }
+ }
+
+ function button_onDragLeave(event) {
+ event.target.classList.add("gr-button-secondary");
+ }
+
+ function detect_onDrop(event) {
+ if (event.dataTransfer.items[0].type.startsWith("image/")) {
+ event.preventDefault();
+ const input = event.target.previousElementSibling;
+ input.files = event.dataTransfer.files;
+ const changeEvent = new Event('change', { 'bubbles': true, "composed": true });
+ input.dispatchEvent(changeEvent);
+ }
+ }
+
+ function json_onDrop(event) {
+ if (event.dataTransfer.items[0].type == "application/json") {
+ event.preventDefault();
+ const input = event.target.previousElementSibling;
+ input.files = event.dataTransfer.files;
+ const changeEvent = new Event('change', { 'bubbles': true, "composed": true });
+ input.dispatchEvent(changeEvent);
+ }
+ }
+
+ onUiLoaded(function() {
+ initCanvas(gradioApp().querySelector('#openpose_editor_canvas'))
+
+ var canvas_drag_overlay = document.createElement("div");
+ canvas_drag_overlay.id = "canvas_drag_overlay"
+ canvas_drag_overlay.style = "pointer-events: none; visibility: hidden; display: flex; align-items: center; justify-content: center; width: 100%; height: 100%; color: white; font-size: 2.5em; font-family: inherit; font-weight: 600; line-height: 100%; background: rgba(0,0,0,0.5); margin: 0.25rem; border-radius: 0.25rem; border: 0.5px solid; position: absolute;"
+
+ var canvas = gradioApp().querySelector("#tab_openpose_editor .canvas-container")
+ canvas.appendChild(canvas_drag_overlay)
+ canvas.addEventListener("dragover", canvas_onDragOver);
+ canvas.addEventListener("dragleave", () => gradioApp().querySelector("#canvas_drag_overlay").style.visibility = "hidden");
+ canvas.addEventListener("drop", canvas_onDrop);
+
+ var bg_button = gradioApp().querySelector("#openpose_bg_button")
+ bg_button.addEventListener("dragover", button_onDragOver);
+ bg_button.addEventListener("dragleave", button_onDragLeave);
+ bg_button.addEventListener("drop", canvas_onDrop);
+ bg_button.addEventListener("drop", event => event.target.classList.add("gr-button-secondary"));
+ bg_button.classList.add("gr-button-secondary");
+
+ var detect_button = gradioApp().querySelector("#openpose_detect_button")
+ detect_button.addEventListener("dragover", button_onDragOver);
+ detect_button.addEventListener("dragleave", button_onDragLeave);
+ detect_button.addEventListener("drop", detect_onDrop);
+ detect_button.classList.add("gr-button-secondary");
+
+ var json_button = gradioApp().querySelector("#openpose_json_button")
+ json_button.addEventListener("dragover", button_onDragOver);
+ json_button.addEventListener("dragleave", button_onDragLeave);
+ json_button.addEventListener("drop", json_onDrop);
+ json_button.classList.add("gr-button-secondary");
+
+ })
openpose-editor/scripts/__pycache__/main.cpython-310.pyc ADDED
Binary file (5.83 kB).
 
openpose-editor/scripts/main.py ADDED
@@ -0,0 +1,143 @@
+ import os
+ import io
+ import json
+ import numpy as np
+ import cv2
+
+ import gradio as gr
+
+ import modules.scripts as scripts
+ from modules import script_callbacks
+ from modules.shared import opts
+ from modules.paths import models_path
+
+ from basicsr.utils.download_util import load_file_from_url
+
+ from scripts.openpose.body import Body
+
+ from PIL import Image
+
+ body_estimation = None
+ presets_file = os.path.join(scripts.basedir(), "presets.json")
+ presets = {}
+
+ try:
+     with open(presets_file) as file:
+         presets = json.load(file)
+ except FileNotFoundError:
+     pass
+
+ def pil2cv(in_image):
+     out_image = np.array(in_image, dtype=np.uint8)
+
+     if out_image.ndim == 2:  # grayscale input
+         out_image = cv2.cvtColor(out_image, cv2.COLOR_GRAY2BGR)
+     elif out_image.shape[2] == 3:
+         out_image = cv2.cvtColor(out_image, cv2.COLOR_RGB2BGR)
+     elif out_image.shape[2] == 4:  # RGBA input
+         out_image = cv2.cvtColor(out_image, cv2.COLOR_RGBA2BGRA)
+     return out_image
+
+ def candidate2li(li):
+     res = []
+     for x, y, *_ in li:
+         res.append([x, y])
+     return res
+
+ def subset2li(li):
+     res = []
+     for r in li:
+         for c in r:
+             res.append(c)
+     return res
+
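The two helpers above reduce the detector's output to plain lists for the JSON bridge to the frontend: `candidate2li` keeps only the (x, y) of each candidate row, and `subset2li` flattens the per-person index rows. A small standalone usage sketch (toy data, not real detector output):

```python
def candidate2li(li):
    # Keep only (x, y) from each (x, y, score, id) candidate row.
    return [[x, y] for x, y, *_ in li]

def subset2li(li):
    # Flatten the per-person keypoint-index rows into one list.
    return [c for r in li for c in r]

cand = candidate2li([[1, 2, 0.9, 0], [3, 4, 0.8, 1]])
flat = subset2li([[0, 1], [2, 3]])
```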
+ class Script(scripts.Script):
+     def __init__(self) -> None:
+         super().__init__()
+
+     def title(self):
+         return "OpenPose Editor"
+
+     def show(self, is_img2img):
+         return scripts.AlwaysVisible
+
+     def ui(self, is_img2img):
+         return ()
+
+ def on_ui_tabs():
+     with gr.Blocks(analytics_enabled=False) as openpose_editor:
+         with gr.Row():
+             with gr.Column():
+                 width = gr.Slider(label="width", minimum=64, maximum=2048, value=512, step=64, interactive=True)
+                 height = gr.Slider(label="height", minimum=64, maximum=2048, value=512, step=64, interactive=True)
+                 with gr.Row():
+                     add = gr.Button(value="Add", variant="primary")
+                     # delete = gr.Button(value="Delete")
+                 with gr.Row():
+                     reset_btn = gr.Button(value="Reset")
+                     json_input = gr.UploadButton(label="Load from JSON", file_types=[".json"], elem_id="openpose_json_button")
+                     png_input = gr.UploadButton(label="Detect from Image", file_types=["image"], type="bytes", elem_id="openpose_detect_button")
+                     bg_input = gr.UploadButton(label="Add Background Image", file_types=["image"], elem_id="openpose_bg_button")
+                 with gr.Row():
+                     preset_list = gr.Dropdown(label="Presets", choices=sorted(presets.keys()), interactive=True)
+                     preset_load = gr.Button(value="Load Preset")
+                     preset_save = gr.Button(value="Save Preset")
+
+             with gr.Column():
+                 # gradioooooo...
+                 canvas = gr.HTML('<canvas id="openpose_editor_canvas" width="512" height="512" style="margin: 0.25rem; border-radius: 0.25rem; border: 0.5px solid"></canvas>')
+                 jsonbox = gr.Text(label="json", elem_id="jsonbox", visible=False)
+                 with gr.Row():
+                     json_output = gr.Button(value="Save JSON")
+                     png_output = gr.Button(value="Save PNG")
+                     send_t2t = gr.Button(value="Send to txt2img")
+                     send_i2i = gr.Button(value="Send to img2img")
+                 control_net_max_models_num = getattr(opts, 'control_net_max_models_num', 0)
+                 select_target_index = gr.Dropdown([str(i) for i in range(control_net_max_models_num)], label="Send to", value="0", interactive=True, visible=(control_net_max_models_num > 1))
+
+         def estimate(file):
+             global body_estimation
+
+             if body_estimation is None:
+                 model_path = os.path.join(models_path, "openpose", "body_pose_model.pth")
+                 if not os.path.isfile(model_path):
+                     body_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/body_pose_model.pth"
+                     load_file_from_url(body_model_path, model_dir=os.path.join(models_path, "openpose"))
+                 body_estimation = Body(model_path)
+
+             stream = io.BytesIO(file)
+             img = Image.open(stream)
+             candidate, subset = body_estimation(pil2cv(img))
+
+             result = {
+                 "candidate": candidate2li(candidate),
+                 "subset": subset2li(subset),
+             }
+
+             return str(result).replace("'", '"')
+
+         def savePreset(name, data):
+             if name:
+                 presets[name] = json.loads(data)
+                 with open(presets_file, "w") as file:
+                     json.dump(presets, file)
+                 # "data" is already a JSON string; dumping it again would double-encode it.
+                 return gr.update(choices=sorted(presets.keys()), value=name), data
+             return gr.update(), gr.update()
+
+         dummy_component = gr.Label(visible=False)
+         preset = gr.Text(visible=False)
+         width.change(None, [width, height], None, _js="(w, h) => {resizeCanvas(w, h)}")
+         height.change(None, [width, height], None, _js="(w, h) => {resizeCanvas(w, h)}")
+         png_output.click(None, [], None, _js="savePNG")
+         bg_input.upload(None, [bg_input], [width, height], _js="addBackground")
+         png_input.upload(estimate, png_input, jsonbox)
+         png_input.upload(None, png_input, [width, height], _js="addBackground")
+         add.click(None, [], None, _js="addPose")
+         send_t2t.click(None, select_target_index, None, _js="(i) => {sendImage('txt2img', i)}")
+         send_i2i.click(None, select_target_index, None, _js="(i) => {sendImage('img2img', i)}")
+         reset_btn.click(None, [], None, _js="resetCanvas")
+         json_input.upload(None, json_input, [width, height], _js="loadJSON")
+         json_output.click(None, None, None, _js="saveJSON")
+         preset_save.click(savePreset, [dummy_component, dummy_component], [preset_list, preset], _js="savePreset")
+         preset_load.click(None, preset, [width, height], _js="loadPreset")
+         preset_list.change(lambda selected: json.dumps(presets[selected]), preset_list, preset)
+
+     return [(openpose_editor, "OpenPose Editor", "openpose_editor")]
+
+ script_callbacks.on_ui_tabs(on_ui_tabs)
openpose-editor/scripts/openpose/__pycache__/body.cpython-310.pyc ADDED
Binary file (7.35 kB).
 
openpose-editor/scripts/openpose/__pycache__/model.cpython-310.pyc ADDED
Binary file (6.21 kB).
 
openpose-editor/scripts/openpose/__pycache__/util.cpython-310.pyc ADDED
Binary file (5.08 kB).
 
openpose-editor/scripts/openpose/body.py ADDED
@@ -0,0 +1,220 @@
+ # This code from https://github.com/lllyasviel/ControlNet
+
+ import cv2
+ import numpy as np
+ import math
+ import time
+ from scipy.ndimage.filters import gaussian_filter
+ import matplotlib.pyplot as plt
+ import matplotlib
+ import torch
+ from torchvision import transforms
+
+ from . import util
+ from .model import bodypose_model
+
+ class Body(object):
+     def __init__(self, model_path):
+         self.model = bodypose_model()
+         if torch.cuda.is_available():
+             self.model = self.model.cuda()
+         model_dict = util.transfer(self.model, torch.load(model_path))
+         self.model.load_state_dict(model_dict)
+         self.model.eval()
+
+     def __call__(self, oriImg):
+         # scale_search = [0.5, 1.0, 1.5, 2.0]
+         scale_search = [0.5]
+         boxsize = 368
+         stride = 8
+         padValue = 128
+         threshold1 = 0.1
+         threshold2 = 0.05
+         multiplier = [x * boxsize / oriImg.shape[0] for x in scale_search]
+         heatmap_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 19))
+         paf_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 38))
+
+         for m in range(len(multiplier)):
+             scale = multiplier[m]
+             imageToTest = cv2.resize(oriImg, (0, 0), fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)
+             imageToTest_padded, pad = util.padRightDownCorner(imageToTest, stride, padValue)
+             im = np.transpose(np.float32(imageToTest_padded[:, :, :, np.newaxis]), (3, 2, 0, 1)) / 256 - 0.5
+             im = np.ascontiguousarray(im)
+
+             data = torch.from_numpy(im).float()
+             if torch.cuda.is_available():
+                 data = data.cuda()
+             # data = data.permute([2, 0, 1]).unsqueeze(0).float()
+             with torch.no_grad():
+                 Mconv7_stage6_L1, Mconv7_stage6_L2 = self.model(data)
+             Mconv7_stage6_L1 = Mconv7_stage6_L1.cpu().numpy()
+             Mconv7_stage6_L2 = Mconv7_stage6_L2.cpu().numpy()
+
+             # extract outputs, resize, and remove padding
+             # heatmap = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[1]].data), (1, 2, 0))  # output 1 is heatmaps
+             heatmap = np.transpose(np.squeeze(Mconv7_stage6_L2), (1, 2, 0))  # output 1 is heatmaps
+             heatmap = cv2.resize(heatmap, (0, 0), fx=stride, fy=stride, interpolation=cv2.INTER_CUBIC)
+             heatmap = heatmap[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :]
+             heatmap = cv2.resize(heatmap, (oriImg.shape[1], oriImg.shape[0]), interpolation=cv2.INTER_CUBIC)
+
+             # paf = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[0]].data), (1, 2, 0))  # output 0 is PAFs
+             paf = np.transpose(np.squeeze(Mconv7_stage6_L1), (1, 2, 0))  # output 0 is PAFs
+             paf = cv2.resize(paf, (0, 0), fx=stride, fy=stride, interpolation=cv2.INTER_CUBIC)
+             paf = paf[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :]
+             paf = cv2.resize(paf, (oriImg.shape[1], oriImg.shape[0]), interpolation=cv2.INTER_CUBIC)
+
+             # accumulate the per-scale average (the "+=" variants double-counted)
+             heatmap_avg = heatmap_avg + heatmap / len(multiplier)
+             paf_avg = paf_avg + paf / len(multiplier)
+
+         all_peaks = []
+         peak_counter = 0
+
+         for part in range(18):
+             map_ori = heatmap_avg[:, :, part]
+             one_heatmap = gaussian_filter(map_ori, sigma=3)
+
+             map_left = np.zeros(one_heatmap.shape)
+             map_left[1:, :] = one_heatmap[:-1, :]
+             map_right = np.zeros(one_heatmap.shape)
+             map_right[:-1, :] = one_heatmap[1:, :]
+             map_up = np.zeros(one_heatmap.shape)
+             map_up[:, 1:] = one_heatmap[:, :-1]
+             map_down = np.zeros(one_heatmap.shape)
+             map_down[:, :-1] = one_heatmap[:, 1:]
+
+             peaks_binary = np.logical_and.reduce(
+                 (one_heatmap >= map_left, one_heatmap >= map_right, one_heatmap >= map_up, one_heatmap >= map_down, one_heatmap > threshold1))
+             peaks = list(zip(np.nonzero(peaks_binary)[1], np.nonzero(peaks_binary)[0]))  # note reverse
+             peaks_with_score = [x + (map_ori[x[1], x[0]],) for x in peaks]
+             peak_id = range(peak_counter, peak_counter + len(peaks))
+             peaks_with_score_and_id = [peaks_with_score[i] + (peak_id[i],) for i in range(len(peak_id))]
+
+             all_peaks.append(peaks_with_score_and_id)
+             peak_counter += len(peaks)
+
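The loop above marks a heatmap pixel as a keypoint candidate when it exceeds the threshold and is at least as large as its four shifted neighbours. A pure-Python sketch of that local-maximum test (toy data, no Gaussian smoothing; function name is illustrative):

```python
def find_peaks(heatmap, threshold=0.1):
    # A pixel is a peak when it beats the threshold and is >= each of its
    # 4-connected neighbours, mirroring the shifted-map comparison above.
    h, w = len(heatmap), len(heatmap[0])
    peaks = []
    for y in range(h):
        for x in range(w):
            v = heatmap[y][x]
            if v <= threshold:
                continue
            neighbours = [
                heatmap[y - 1][x] if y > 0 else 0.0,
                heatmap[y + 1][x] if y < h - 1 else 0.0,
                heatmap[y][x - 1] if x > 0 else 0.0,
                heatmap[y][x + 1] if x < w - 1 else 0.0,
            ]
            if all(v >= n for n in neighbours):
                peaks.append((x, y, v))
    return peaks

hm = [[0.0, 0.2, 0.0],
      [0.1, 0.9, 0.1],
      [0.0, 0.0, 0.0]]
peaks = find_peaks(hm)
```

Only the 0.9 pixel survives: 0.2 clears the threshold but loses to its 0.9 neighbour.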
+
+         # find connection in the specified sequence, center 29 is in the position 15
+         limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \
+                    [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \
+                    [1, 16], [16, 18], [3, 17], [6, 18]]
+         # the middle joints heatmap correspondence
+         mapIdx = [[31, 32], [39, 40], [33, 34], [35, 36], [41, 42], [43, 44], [19, 20], [21, 22], \
+                   [23, 24], [25, 26], [27, 28], [29, 30], [47, 48], [49, 50], [53, 54], [51, 52], \
+                   [55, 56], [37, 38], [45, 46]]
+
+         connection_all = []
+         special_k = []
+         mid_num = 10
+
+         for k in range(len(mapIdx)):
+             score_mid = paf_avg[:, :, [x - 19 for x in mapIdx[k]]]
+             candA = all_peaks[limbSeq[k][0] - 1]
+             candB = all_peaks[limbSeq[k][1] - 1]
+             nA = len(candA)
+             nB = len(candB)
+             indexA, indexB = limbSeq[k]
+             if (nA != 0 and nB != 0):
+                 connection_candidate = []
+                 for i in range(nA):
+                     for j in range(nB):
+                         vec = np.subtract(candB[j][:2], candA[i][:2])
+                         norm = math.sqrt(vec[0] * vec[0] + vec[1] * vec[1])
+                         norm = max(0.001, norm)
+                         vec = np.divide(vec, norm)
+
+                         startend = list(zip(np.linspace(candA[i][0], candB[j][0], num=mid_num), \
+                                             np.linspace(candA[i][1], candB[j][1], num=mid_num)))
+
+                         vec_x = np.array([score_mid[int(round(startend[I][1])), int(round(startend[I][0])), 0] \
+                                           for I in range(len(startend))])
+                         vec_y = np.array([score_mid[int(round(startend[I][1])), int(round(startend[I][0])), 1] \
+                                           for I in range(len(startend))])
+
+                         score_midpts = np.multiply(vec_x, vec[0]) + np.multiply(vec_y, vec[1])
+                         score_with_dist_prior = sum(score_midpts) / len(score_midpts) + min(
+                             0.5 * oriImg.shape[0] / norm - 1, 0)
+                         criterion1 = len(np.nonzero(score_midpts > threshold2)[0]) > 0.8 * len(score_midpts)
+                         criterion2 = score_with_dist_prior > 0
+                         if criterion1 and criterion2:
+                             connection_candidate.append(
+                                 [i, j, score_with_dist_prior, score_with_dist_prior + candA[i][2] + candB[j][2]])
+
+                 connection_candidate = sorted(connection_candidate, key=lambda x: x[2], reverse=True)
+                 connection = np.zeros((0, 5))
+                 for c in range(len(connection_candidate)):
+                     i, j, s = connection_candidate[c][0:3]
+                     if (i not in connection[:, 3] and j not in connection[:, 4]):
+                         connection = np.vstack([connection, [candA[i][3], candB[j][3], s, i, j]])
+                         if (len(connection) >= min(nA, nB)):
+                             break
+
+                 connection_all.append(connection)
+             else:
+                 special_k.append(k)
+                 connection_all.append([])
+
+         # last number in each row is the total parts number of that person
+         # the second last number in each row is the score of the overall configuration
+         subset = -1 * np.ones((0, 20))
+         candidate = np.array([item for sublist in all_peaks for item in sublist])
+
+         for k in range(len(mapIdx)):
+             if k not in special_k:
+                 partAs = connection_all[k][:, 0]
+                 partBs = connection_all[k][:, 1]
+                 indexA, indexB = np.array(limbSeq[k]) - 1
165
+
166
+ for i in range(len(connection_all[k])): # = 1:size(temp,1)
167
+ found = 0
168
+ subset_idx = [-1, -1]
169
+ for j in range(len(subset)): # 1:size(subset,1):
170
+ if subset[j][indexA] == partAs[i] or subset[j][indexB] == partBs[i]:
171
+ subset_idx[found] = j
172
+ found += 1
173
+
174
+ if found == 1:
175
+ j = subset_idx[0]
176
+ if subset[j][indexB] != partBs[i]:
177
+ subset[j][indexB] = partBs[i]
178
+ subset[j][-1] += 1
179
+ subset[j][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2]
180
+ elif found == 2: # if found 2 and disjoint, merge them
181
+ j1, j2 = subset_idx
182
+ membership = ((subset[j1] >= 0).astype(int) + (subset[j2] >= 0).astype(int))[:-2]
183
+ if len(np.nonzero(membership == 2)[0]) == 0: # merge
184
+ subset[j1][:-2] += (subset[j2][:-2] + 1)
185
+ subset[j1][-2:] += subset[j2][-2:]
186
+ subset[j1][-2] += connection_all[k][i][2]
187
+ subset = np.delete(subset, j2, 0)
188
+ else: # as like found == 1
189
+ subset[j1][indexB] = partBs[i]
190
+ subset[j1][-1] += 1
191
+ subset[j1][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2]
192
+
193
+ # if find no partA in the subset, create a new subset
194
+ elif not found and k < 17:
195
+ row = -1 * np.ones(20)
196
+ row[indexA] = partAs[i]
197
+ row[indexB] = partBs[i]
198
+ row[-1] = 2
199
+ row[-2] = sum(candidate[connection_all[k][i, :2].astype(int), 2]) + connection_all[k][i][2]
200
+ subset = np.vstack([subset, row])
201
+ # delete some rows of subset which has few parts occur
202
+ deleteIdx = []
203
+ for i in range(len(subset)):
204
+ if subset[i][-1] < 4 or subset[i][-2] / subset[i][-1] < 0.4:
205
+ deleteIdx.append(i)
206
+ subset = np.delete(subset, deleteIdx, axis=0)
207
+
208
+ # subset: n*20 array, 0-17 is the index in candidate, 18 is the total score, 19 is the total parts
209
+ # candidate: x, y, score, id
210
+ return candidate, subset
211
+
212
+ if __name__ == "__main__":
213
+ body_estimation = Body('../model/body_pose_model.pth')
214
+
215
+ test_image = '../images/ski.jpg'
216
+ oriImg = cv2.imread(test_image) # B,G,R order
217
+ candidate, subset = body_estimation(oriImg)
218
+ canvas = util.draw_bodypose(oriImg, candidate, subset)
219
+ plt.imshow(canvas[:, :, [2, 1, 0]])
220
+ plt.show()
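As a standalone illustration of the peak-finding step in the diff above: a pixel becomes a keypoint candidate when it is at least as large as each of its four shifted neighbours and above the detection threshold. A minimal numpy sketch (not part of the diff; `find_peaks` is a hypothetical helper name):

```python
import numpy as np

def find_peaks(heatmap, threshold=0.1):
    # Shift the map one pixel in each direction; a peak is a pixel that is
    # >= all four shifted neighbours and above the detection threshold.
    map_left = np.zeros_like(heatmap)
    map_left[1:, :] = heatmap[:-1, :]
    map_right = np.zeros_like(heatmap)
    map_right[:-1, :] = heatmap[1:, :]
    map_up = np.zeros_like(heatmap)
    map_up[:, 1:] = heatmap[:, :-1]
    map_down = np.zeros_like(heatmap)
    map_down[:, :-1] = heatmap[:, 1:]
    peaks_binary = np.logical_and.reduce(
        (heatmap >= map_left, heatmap >= map_right,
         heatmap >= map_up, heatmap >= map_down, heatmap > threshold))
    # (x, y) order -- columns first, matching the "note reverse" comment
    ys, xs = np.nonzero(peaks_binary)
    return [(int(x), int(y)) for x, y in zip(xs, ys)]

heat = np.zeros((5, 5))
heat[2, 3] = 1.0
print(find_peaks(heat))  # [(3, 2)]
```

In `body.py` the same comparison runs once per part heatmap after Gaussian smoothing, and each surviving pixel is then tagged with its score and a running id.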
openpose-editor/scripts/openpose/model.py ADDED
@@ -0,0 +1,221 @@
+ # This code from https://github.com/lllyasviel/ControlNet
+
+ from collections import OrderedDict
+
+ import torch
+ import torch.nn as nn
+
+ def make_layers(block, no_relu_layers):
+     layers = []
+     for layer_name, v in block.items():
+         if 'pool' in layer_name:
+             layer = nn.MaxPool2d(kernel_size=v[0], stride=v[1],
+                                  padding=v[2])
+             layers.append((layer_name, layer))
+         else:
+             conv2d = nn.Conv2d(in_channels=v[0], out_channels=v[1],
+                                kernel_size=v[2], stride=v[3],
+                                padding=v[4])
+             layers.append((layer_name, conv2d))
+             if layer_name not in no_relu_layers:
+                 layers.append(('relu_' + layer_name, nn.ReLU(inplace=True)))
+
+     return nn.Sequential(OrderedDict(layers))
+
+ class bodypose_model(nn.Module):
+     def __init__(self):
+         super(bodypose_model, self).__init__()
+
+         # these layers have no relu layer
+         no_relu_layers = ['conv5_5_CPM_L1', 'conv5_5_CPM_L2', 'Mconv7_stage2_L1',
+                           'Mconv7_stage2_L2', 'Mconv7_stage3_L1', 'Mconv7_stage3_L2',
+                           'Mconv7_stage4_L1', 'Mconv7_stage4_L2', 'Mconv7_stage5_L1',
+                           'Mconv7_stage5_L2', 'Mconv7_stage6_L1', 'Mconv7_stage6_L2']
+         blocks = {}
+         block0 = OrderedDict([
+             ('conv1_1', [3, 64, 3, 1, 1]),
+             ('conv1_2', [64, 64, 3, 1, 1]),
+             ('pool1_stage1', [2, 2, 0]),
+             ('conv2_1', [64, 128, 3, 1, 1]),
+             ('conv2_2', [128, 128, 3, 1, 1]),
+             ('pool2_stage1', [2, 2, 0]),
+             ('conv3_1', [128, 256, 3, 1, 1]),
+             ('conv3_2', [256, 256, 3, 1, 1]),
+             ('conv3_3', [256, 256, 3, 1, 1]),
+             ('conv3_4', [256, 256, 3, 1, 1]),
+             ('pool3_stage1', [2, 2, 0]),
+             ('conv4_1', [256, 512, 3, 1, 1]),
+             ('conv4_2', [512, 512, 3, 1, 1]),
+             ('conv4_3_CPM', [512, 256, 3, 1, 1]),
+             ('conv4_4_CPM', [256, 128, 3, 1, 1])
+         ])
+
+         # Stage 1
+         block1_1 = OrderedDict([
+             ('conv5_1_CPM_L1', [128, 128, 3, 1, 1]),
+             ('conv5_2_CPM_L1', [128, 128, 3, 1, 1]),
+             ('conv5_3_CPM_L1', [128, 128, 3, 1, 1]),
+             ('conv5_4_CPM_L1', [128, 512, 1, 1, 0]),
+             ('conv5_5_CPM_L1', [512, 38, 1, 1, 0])
+         ])
+
+         block1_2 = OrderedDict([
+             ('conv5_1_CPM_L2', [128, 128, 3, 1, 1]),
+             ('conv5_2_CPM_L2', [128, 128, 3, 1, 1]),
+             ('conv5_3_CPM_L2', [128, 128, 3, 1, 1]),
+             ('conv5_4_CPM_L2', [128, 512, 1, 1, 0]),
+             ('conv5_5_CPM_L2', [512, 19, 1, 1, 0])
+         ])
+         blocks['block1_1'] = block1_1
+         blocks['block1_2'] = block1_2
+
+         self.model0 = make_layers(block0, no_relu_layers)
+
+         # Stages 2 - 6
+         for i in range(2, 7):
+             blocks['block%d_1' % i] = OrderedDict([
+                 ('Mconv1_stage%d_L1' % i, [185, 128, 7, 1, 3]),
+                 ('Mconv2_stage%d_L1' % i, [128, 128, 7, 1, 3]),
+                 ('Mconv3_stage%d_L1' % i, [128, 128, 7, 1, 3]),
+                 ('Mconv4_stage%d_L1' % i, [128, 128, 7, 1, 3]),
+                 ('Mconv5_stage%d_L1' % i, [128, 128, 7, 1, 3]),
+                 ('Mconv6_stage%d_L1' % i, [128, 128, 1, 1, 0]),
+                 ('Mconv7_stage%d_L1' % i, [128, 38, 1, 1, 0])
+             ])
+
+             blocks['block%d_2' % i] = OrderedDict([
+                 ('Mconv1_stage%d_L2' % i, [185, 128, 7, 1, 3]),
+                 ('Mconv2_stage%d_L2' % i, [128, 128, 7, 1, 3]),
+                 ('Mconv3_stage%d_L2' % i, [128, 128, 7, 1, 3]),
+                 ('Mconv4_stage%d_L2' % i, [128, 128, 7, 1, 3]),
+                 ('Mconv5_stage%d_L2' % i, [128, 128, 7, 1, 3]),
+                 ('Mconv6_stage%d_L2' % i, [128, 128, 1, 1, 0]),
+                 ('Mconv7_stage%d_L2' % i, [128, 19, 1, 1, 0])
+             ])
+
+         for k in blocks.keys():
+             blocks[k] = make_layers(blocks[k], no_relu_layers)
+
+         self.model1_1 = blocks['block1_1']
+         self.model2_1 = blocks['block2_1']
+         self.model3_1 = blocks['block3_1']
+         self.model4_1 = blocks['block4_1']
+         self.model5_1 = blocks['block5_1']
+         self.model6_1 = blocks['block6_1']
+
+         self.model1_2 = blocks['block1_2']
+         self.model2_2 = blocks['block2_2']
+         self.model3_2 = blocks['block3_2']
+         self.model4_2 = blocks['block4_2']
+         self.model5_2 = blocks['block5_2']
+         self.model6_2 = blocks['block6_2']
+
+     def forward(self, x):
+
+         out1 = self.model0(x)
+
+         out1_1 = self.model1_1(out1)
+         out1_2 = self.model1_2(out1)
+         out2 = torch.cat([out1_1, out1_2, out1], 1)
+
+         out2_1 = self.model2_1(out2)
+         out2_2 = self.model2_2(out2)
+         out3 = torch.cat([out2_1, out2_2, out1], 1)
+
+         out3_1 = self.model3_1(out3)
+         out3_2 = self.model3_2(out3)
+         out4 = torch.cat([out3_1, out3_2, out1], 1)
+
+         out4_1 = self.model4_1(out4)
+         out4_2 = self.model4_2(out4)
+         out5 = torch.cat([out4_1, out4_2, out1], 1)
+
+         out5_1 = self.model5_1(out5)
+         out5_2 = self.model5_2(out5)
+         out6 = torch.cat([out5_1, out5_2, out1], 1)
+
+         out6_1 = self.model6_1(out6)
+         out6_2 = self.model6_2(out6)
+
+         return out6_1, out6_2
+
+ class handpose_model(nn.Module):
+     def __init__(self):
+         super(handpose_model, self).__init__()
+
+         # these layers have no relu layer
+         no_relu_layers = ['conv6_2_CPM', 'Mconv7_stage2', 'Mconv7_stage3',
+                           'Mconv7_stage4', 'Mconv7_stage5', 'Mconv7_stage6']
+         # stage 1
+         block1_0 = OrderedDict([
+             ('conv1_1', [3, 64, 3, 1, 1]),
+             ('conv1_2', [64, 64, 3, 1, 1]),
+             ('pool1_stage1', [2, 2, 0]),
+             ('conv2_1', [64, 128, 3, 1, 1]),
+             ('conv2_2', [128, 128, 3, 1, 1]),
+             ('pool2_stage1', [2, 2, 0]),
+             ('conv3_1', [128, 256, 3, 1, 1]),
+             ('conv3_2', [256, 256, 3, 1, 1]),
+             ('conv3_3', [256, 256, 3, 1, 1]),
+             ('conv3_4', [256, 256, 3, 1, 1]),
+             ('pool3_stage1', [2, 2, 0]),
+             ('conv4_1', [256, 512, 3, 1, 1]),
+             ('conv4_2', [512, 512, 3, 1, 1]),
+             ('conv4_3', [512, 512, 3, 1, 1]),
+             ('conv4_4', [512, 512, 3, 1, 1]),
+             ('conv5_1', [512, 512, 3, 1, 1]),
+             ('conv5_2', [512, 512, 3, 1, 1]),
+             ('conv5_3_CPM', [512, 128, 3, 1, 1])
+         ])
+
+         block1_1 = OrderedDict([
+             ('conv6_1_CPM', [128, 512, 1, 1, 0]),
+             ('conv6_2_CPM', [512, 22, 1, 1, 0])
+         ])
+
+         blocks = {}
+         blocks['block1_0'] = block1_0
+         blocks['block1_1'] = block1_1
+
+         # stage 2-6
+         for i in range(2, 7):
+             blocks['block%d' % i] = OrderedDict([
+                 ('Mconv1_stage%d' % i, [150, 128, 7, 1, 3]),
+                 ('Mconv2_stage%d' % i, [128, 128, 7, 1, 3]),
+                 ('Mconv3_stage%d' % i, [128, 128, 7, 1, 3]),
+                 ('Mconv4_stage%d' % i, [128, 128, 7, 1, 3]),
+                 ('Mconv5_stage%d' % i, [128, 128, 7, 1, 3]),
+                 ('Mconv6_stage%d' % i, [128, 128, 1, 1, 0]),
+                 ('Mconv7_stage%d' % i, [128, 22, 1, 1, 0])
+             ])
+
+         for k in blocks.keys():
+             blocks[k] = make_layers(blocks[k], no_relu_layers)
+
+         self.model1_0 = blocks['block1_0']
+         self.model1_1 = blocks['block1_1']
+         self.model2 = blocks['block2']
+         self.model3 = blocks['block3']
+         self.model4 = blocks['block4']
+         self.model5 = blocks['block5']
+         self.model6 = blocks['block6']
+
+     def forward(self, x):
+         out1_0 = self.model1_0(x)
+         out1_1 = self.model1_1(out1_0)
+         concat_stage2 = torch.cat([out1_1, out1_0], 1)
+         out_stage2 = self.model2(concat_stage2)
+         concat_stage3 = torch.cat([out_stage2, out1_0], 1)
+         out_stage3 = self.model3(concat_stage3)
+         concat_stage4 = torch.cat([out_stage3, out1_0], 1)
+         out_stage4 = self.model4(concat_stage4)
+         concat_stage5 = torch.cat([out_stage4, out1_0], 1)
+         out_stage5 = self.model5(concat_stage5)
+         concat_stage6 = torch.cat([out_stage5, out1_0], 1)
+         out_stage6 = self.model6(concat_stage6)
+         return out_stage6
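One non-obvious pair of constants in `model.py` above is the 185-channel input to each refinement stage of `bodypose_model` and the 150-channel input in `handpose_model`: every stage consumes the previous stage's output(s) concatenated with the shared backbone features. A quick channel-bookkeeping sketch (plain Python, no torch required; the constant names are ours, not from the file):

```python
# Input-channel bookkeeping for the multi-stage models in model.py.

# Body model, stages 2-6: torch.cat([out_L1, out_L2, backbone_out], 1)
PAF_CHANNELS = 38        # L1 branch: part affinity fields (19 limbs * 2)
HEATMAP_CHANNELS = 19    # L2 branch: 18 keypoints + background
BODY_BACKBONE = 128      # conv4_4_CPM output channels

body_stage_in = PAF_CHANNELS + HEATMAP_CHANNELS + BODY_BACKBONE
print(body_stage_in)  # 185, matches Mconv1_stage%d_L1/L2 in_channels

# Hand model, stages 2-6: torch.cat([out_stage, backbone_out], 1)
HAND_MAPS = 22           # conv6_2_CPM / Mconv7 output channels
HAND_BACKBONE = 128      # conv5_3_CPM output channels

hand_stage_in = HAND_MAPS + HAND_BACKBONE
print(hand_stage_in)  # 150, matches Mconv1_stage%d in_channels
```

If either branch's output width is changed, the `Mconv1_stage%d` entries must be updated to the corresponding sum, or stage 2 will fail at the first `torch.cat`-fed convolution.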