aa926620ca6701673dec0f931c22897dc1c41bb94a3cec6e3b01f666d53a734a
Browse files- microsoftexcel-supermerger/elemental_ja.md +119 -0
- microsoftexcel-supermerger/install.py +7 -0
- microsoftexcel-supermerger/sample.txt +95 -0
- microsoftexcel-supermerger/scripts/__pycache__/supermerger.cpython-310.pyc +0 -0
- microsoftexcel-supermerger/scripts/mbwpresets.txt +39 -0
- microsoftexcel-supermerger/scripts/mbwpresets_master.txt +39 -0
- microsoftexcel-supermerger/scripts/mergers/__pycache__/mergers.cpython-310.pyc +0 -0
- microsoftexcel-supermerger/scripts/mergers/__pycache__/model_util.cpython-310.pyc +0 -0
- microsoftexcel-supermerger/scripts/mergers/__pycache__/pluslora.cpython-310.pyc +0 -0
- microsoftexcel-supermerger/scripts/mergers/__pycache__/xyplot.cpython-310.pyc +0 -0
- microsoftexcel-supermerger/scripts/mergers/mergers.py +699 -0
- microsoftexcel-supermerger/scripts/mergers/model_util.py +928 -0
- microsoftexcel-supermerger/scripts/mergers/pluslora.py +1298 -0
- microsoftexcel-supermerger/scripts/mergers/xyplot.py +513 -0
- microsoftexcel-supermerger/scripts/supermerger.py +552 -0
- microsoftexcel-tunnels/.gitignore +176 -0
- microsoftexcel-tunnels/.pre-commit-config.yaml +25 -0
- microsoftexcel-tunnels/LICENSE.md +22 -0
- microsoftexcel-tunnels/README.md +21 -0
- microsoftexcel-tunnels/__pycache__/preload.cpython-310.pyc +0 -0
- microsoftexcel-tunnels/install.py +4 -0
- microsoftexcel-tunnels/preload.py +21 -0
- microsoftexcel-tunnels/pyproject.toml +25 -0
- microsoftexcel-tunnels/scripts/__pycache__/ssh_tunnel.cpython-310.pyc +0 -0
- microsoftexcel-tunnels/scripts/__pycache__/try_cloudflare.cpython-310.pyc +0 -0
- microsoftexcel-tunnels/scripts/ssh_tunnel.py +81 -0
- microsoftexcel-tunnels/scripts/try_cloudflare.py +15 -0
- microsoftexcel-tunnels/ssh_tunnel.py +86 -0
- openpose-editor/.github/CODEOWNERS +2 -0
- openpose-editor/.github/ISSUE_TEMPLATE/bug_report.md +31 -0
- openpose-editor/.github/ISSUE_TEMPLATE/feature_request.md +17 -0
- openpose-editor/.github/workflows/typos.yml +21 -0
- openpose-editor/.gitignore +2 -0
- openpose-editor/.vscode/settings.json +5 -0
- openpose-editor/LICENSE +21 -0
- openpose-editor/README.en.md +40 -0
- openpose-editor/README.md +41 -0
- openpose-editor/README.zh-cn.md +41 -0
- openpose-editor/_typos.toml +11 -0
- openpose-editor/configs/.gitkeep +0 -0
- openpose-editor/images//343/202/271/343/202/257/343/203/252/343/203/274/343/203/263/343/202/267/343/203/247/343/203/203/343/203/210 2023-02-19 131430.png +0 -0
- openpose-editor/javascript/fabric.js +0 -0
- openpose-editor/javascript/main.js +691 -0
- openpose-editor/scripts/__pycache__/main.cpython-310.pyc +0 -0
- openpose-editor/scripts/main.py +143 -0
- openpose-editor/scripts/openpose/__pycache__/body.cpython-310.pyc +0 -0
- openpose-editor/scripts/openpose/__pycache__/model.cpython-310.pyc +0 -0
- openpose-editor/scripts/openpose/__pycache__/util.cpython-310.pyc +0 -0
- openpose-editor/scripts/openpose/body.py +220 -0
- openpose-editor/scripts/openpose/model.py +221 -0
microsoftexcel-supermerger/elemental_ja.md
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Elemental Merge
|
| 2 |
+
- 階層マージを越えた階層マージです
|
| 3 |
+
|
| 4 |
+
階層マージでは25の階層ごとにマージ比率を変えることができますが、階層もまた複数の要素で構成されており、要素ごとに比率を変えることも原理的には可能です。可能ですが、要素の数は600以上にもなり人の手で扱えるのかは疑問でしたが実装してみました。いきなり要素ごとのマージは推奨されません。階層マージにおいて解決不可能な問題が生じたときに最終調節手段として使うことをおすすめします。
|
| 5 |
+
次の画像はOUT05層の要素を変えた結果です。左端はマージ無し。2番目はOUT05層すべて(つまりは普通の階層マージ),以降が要素マージです。下表のとおり、attn2などの中にはさらに複数の要素が含まれます。
|
| 6 |
+

|
| 7 |
+
|
| 8 |
+
## 使い方
|
| 9 |
+
要素マージは通常マージ、階層マージ時どちらの場合でも有効で、最後に計算されるために、階層マージで指定した値は上書きされることに注意してください。
|
| 10 |
+
|
| 11 |
+
Elemental Mergeで設定します。ここにテキストが設定されていると自動的に適応されるので注意して下さい。各要素は下表のとおりですが、各要素のフルネームを入力する必要はありません。
|
| 12 |
+
ちゃんと効果が現れるかどうかはprint changeチェックを有効にすることで確認できます。このチェックを有効にするとマージ時にコマンドプロンプト画面に適用された要素が表示されます。
|
| 13 |
+
部分一致で指定が可能です。
|
| 14 |
+
### 書式
|
| 15 |
+
階層:要素:比率,階層:要素:比率,...
|
| 16 |
+
または
|
| 17 |
+
階層:要素:比率
|
| 18 |
+
階層:要素:比率
|
| 19 |
+
階層:要素:比率
|
| 20 |
+
|
| 21 |
+
カンマまたは改行で区切ることで複数の指定が可能です。カンマと改行は混在しても問題ありません。
|
| 22 |
+
階層は大文字でBASE,IN00-M00-OUT11まで指定でます。空欄にするとすべての階層に適用されます。スペースで区切ることで複数の階層を指定できます。
|
| 23 |
+
要素も同様でスペースで区切ることで複数の要素を指定できます。
|
| 24 |
+
部分一致で判断するので、例えば「attn」と入力するとattn1,attn2両方が変化します。「attn2」の場合はattn2のみ。さらに細かく指定したい場合は「attn2.to_out」などと入力します。
|
| 25 |
+
|
| 26 |
+
OUT03 OUT04 OUT05:attn2 attn1.to_out:0.5
|
| 27 |
+
|
| 28 |
+
と入力すると、OUT03,OUT04,OUT05層のattn2が含まれる要素及びattn1.to_outの比率が0.5になります。
|
| 29 |
+
要素の欄を空欄にすると指定階層のすべての要素が変わり、階層マージと同じ効果になります。
|
| 30 |
+
指定が重複する場合、後に入力された方が優先されます。
|
| 31 |
+
|
| 32 |
+
OUT06:attn:0.5,OUT06:attn2.to_k:0.2
|
| 33 |
+
|
| 34 |
+
と入力した場合、OUT06層のattn2.to_k以外のattnは0.5,attn2.to_kのみ0.2となります。
|
| 35 |
+
|
| 36 |
+
最初にNOTと入力することで効果範囲を反転させることができます。
|
| 37 |
+
これは階層・要素別に設定できます。
|
| 38 |
+
|
| 39 |
+
NOT OUT04:attn:1
|
| 40 |
+
|
| 41 |
+
と入力するとOUT04層以外の層のattnに比率1が設定されます。
|
| 42 |
+
|
| 43 |
+
OUT05:NOT attn proj:0.2
|
| 44 |
+
|
| 45 |
+
とすると、OUT05層のattnとproj以外の層が0.2になります。
|
| 46 |
+
|
| 47 |
+
## XY plot
|
| 48 |
+
elemental用のXY plotを複数用意しています。入力例はsample.txtにあります。
|
| 49 |
+
#### elemental
|
| 50 |
+
複数の要素マージについてXY plotを作成します。要素同士は空行で区切ってください。
|
| 51 |
+
トップ画像はsample.txtのsample1を実行した結果です。
|
| 52 |
+
|
| 53 |
+
#### pinpoint element
|
| 54 |
+
特定の要素について値を変えてXY plotを作成します。pinpoint Blocksと同じことを要素で行います。反対の軸にはalphaを指定してください。要素同士は改行またはカンマで区切ります。
|
| 55 |
+
以下の画像はsample.txtのsample3を実行した結果です。
|
| 56 |
+

|
| 57 |
+
|
| 58 |
+
#### effective elenemtal checker
|
| 59 |
+
各要素の影響度を差分として出力します。オプションでanime gif、csvファイルを出力できます。gif.csvファイルはoutputフォルダにModelAとModelBから作られるフォルダ下に作成されるdiffフォルダに作成されます。ファイル名が重複する場合名前を変えて保存しますが、増えてくるとややこしいのでdiffフォルダを適当な名前に変えることをおすすめします。
|
| 60 |
+
改行またはカンマで区切ります。反対の軸はalphaを使用し、単一の値を入力してください。これは要素の効果を見るのにも有効ですが、要素を指定しないことで階層の効果を見ることも可能なので、そちらの使い方をする場合が多いかもしれません。
|
| 61 |
+
以下��画像はsample.txtのsample5を実行した結果です。
|
| 62 |
+

|
| 63 |
+

|
| 64 |
+
### 要素一覧
|
| 65 |
+
基本的にはattnが顔や服装の情報を担っているようです。特にIN07,OUT03,OUT04,OUT05層の影響度が強いようです。階層によって影響度が異なることが多いので複数の層の同じ要素を同時に変化させることは意味が無いように思えます。
|
| 66 |
+
nullと書かれた場所には要素が存在しません。
|
| 67 |
+
|
| 68 |
+
||IN00|IN01|IN02|IN03|IN04|IN05|IN06|IN07|IN08|IN09|IN10|IN11|M00|M00|OUT00|OUT01|OUT02|OUT03|OUT04|OUT05|OUT06|OUT07|OUT08|OUT09|OUT10|OUT11
|
| 69 |
+
|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|
|
| 70 |
+
op.bias|null|null|null||null|null||null|null||null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null
|
| 71 |
+
op.weight|null|null|null||null|null||null|null||null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null
|
| 72 |
+
emb_layers.1.bias|null|||null|||null|||null|null|||||||||||||||
|
| 73 |
+
emb_layers.1.weight|null|||null|||null|||null|null|||||||||||||||
|
| 74 |
+
in_layers.0.bias|null|||null|||null|||null|null|||||||||||||||
|
| 75 |
+
in_layers.0.weight|null|||null|||null|||null|null|||||||||||||||
|
| 76 |
+
in_layers.2.bias|null|||null|||null|||null|null|||||||||||||||
|
| 77 |
+
in_layers.2.weight|null|||null|||null|||null|null|||||||||||||||
|
| 78 |
+
out_layers.0.bias|null|||null|||null|||null|null|||||||||||||||
|
| 79 |
+
out_layers.0.weight|null|||null|||null|||null|null|||||||||||||||
|
| 80 |
+
out_layers.3.bias|null|||null|||null|||null|null|||||||||||||||
|
| 81 |
+
out_layers.3.weight|null|||null|||null|||null|null|||||||||||||||
|
| 82 |
+
skip_connection.bias|null|null|null|null||null|null||null|null|null|null|null|null||||||||||||
|
| 83 |
+
skip_connection.weight|null|null|null|null||null|null||null|null|null|null|null|null||||||||||||
|
| 84 |
+
norm.bias|null|||null|||null|||null|null|null||null|null|null|null|||||||||
|
| 85 |
+
norm.weight|null|||null|||null|||null|null|null||null|null|null|null|||||||||
|
| 86 |
+
proj_in.bias|null|||null|||null|||null|null|null||null|null|null|null|||||||||
|
| 87 |
+
proj_in.weight|null|||null|||null|||null|null|null||null|null|null|null|||||||||
|
| 88 |
+
proj_out.bias|null|||null|||null|||null|null|null||null|null|null|null|||||||||
|
| 89 |
+
proj_out.weight|null|||null|||null|||null|null|null||null|null|null|null|||||||||
|
| 90 |
+
transformer_blocks.0.attn1.to_k.weight|null|||null|||null|||null|null|null||null|null|null|null|||||||||
|
| 91 |
+
transformer_blocks.0.attn1.to_out.0.bias|null|||null|||null|||null|null|null||null|null|null|null|||||||||
|
| 92 |
+
transformer_blocks.0.attn1.to_out.0.weight|null|||null|||null|||null|null|null||null|null|null|null|||||||||
|
| 93 |
+
transformer_blocks.0.attn1.to_q.weight|null|||null|||null|||null|null|null||null|null|null|null|||||||||
|
| 94 |
+
transformer_blocks.0.attn1.to_v.weight|null|||null|||null|||null|null|null||null|null|null|null|||||||||
|
| 95 |
+
transformer_blocks.0.attn2.to_k.weight|null|||null|||null|||null|null|null||null|null|null|null|||||||||
|
| 96 |
+
transformer_blocks.0.attn2.to_out.0.bias|null|||null|||null|||null|null|null||null|null|null|null|||||||||
|
| 97 |
+
transformer_blocks.0.attn2.to_out.0.weight|null|||null|||null|||null|null|null||null|null|null|null|||||||||
|
| 98 |
+
transformer_blocks.0.attn2.to_q.weight|null|||null|||null|||null|null|null||null|null|null|null|||||||||
|
| 99 |
+
transformer_blocks.0.attn2.to_v.weight|null|||null|||null|||null|null|null||null|null|null|null|||||||||
|
| 100 |
+
transformer_blocks.0.ff.net.0.proj.bias|null|||null|||null|||null|null|null||null|null|null|null|||||||||
|
| 101 |
+
transformer_blocks.0.ff.net.0.proj.weight|null|||null|||null|||null|null|null||null|null|null|null|||||||||
|
| 102 |
+
transformer_blocks.0.ff.net.2.bias|null|||null|||null|||null|null|null||null|null|null|null|||||||||
|
| 103 |
+
transformer_blocks.0.ff.net.2.weight|null|||null|||null|||null|null|null||null|null|null|null|||||||||
|
| 104 |
+
transformer_blocks.0.norm1.bias|null|||null|||null|||null|null|null||null|null|null|null|||||||||
|
| 105 |
+
transformer_blocks.0.norm1.weight|null|||null|||null|||null|null|null||null|null|null|null|||||||||
|
| 106 |
+
transformer_blocks.0.norm2.bias|null|||null|||null|||null|null|null||null|null|null|null|||||||||
|
| 107 |
+
transformer_blocks.0.norm2.weight|null|||null|||null|||null|null|null||null|null|null|null|||||||||
|
| 108 |
+
transformer_blocks.0.norm3.bias|null|||null|||null|||null|null|null||null|null|null|null|||||||||
|
| 109 |
+
transformer_blocks.0.norm3.weight|null|||null|||null|||null|null|null||null|null|null|null|||||||||
|
| 110 |
+
conv.bias|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null||null|null||null|null||null|null|null
|
| 111 |
+
conv.weight|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null||null|null||null|null||null|null|null
|
| 112 |
+
0.bias||null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|
|
| 113 |
+
0.weight||null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|
|
| 114 |
+
2.bias|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|
|
| 115 |
+
2.weight|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|
|
| 116 |
+
time_embed.0.weight||null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|
|
| 117 |
+
time_embed.0.bias||null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|
|
| 118 |
+
time_embed.2.weight||null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|
|
| 119 |
+
time_embed.2.bias||null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|
|
microsoftexcel-supermerger/install.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import launch
|
| 2 |
+
|
| 3 |
+
if not launch.is_installed("sklearn"):
|
| 4 |
+
launch.run_pip("install scikit-learn", "scikit-learn")
|
| 5 |
+
|
| 6 |
+
if not launch.is_installed("diffusers"):
|
| 7 |
+
launch.run_pip("install diffusers", "diffusers")
|
microsoftexcel-supermerger/sample.txt
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
eamples of XY plot
|
| 2 |
+
***************************************************************
|
| 3 |
+
for elemental
|
| 4 |
+
Each value is separated by a blank line/空行で区切ります
|
| 5 |
+
Element merge values are comma or newline delimited/要素マージの値はカンマまたは改行区切りです
|
| 6 |
+
Commas and newlines can also be mixed/カンマと改行の混在も可能です
|
| 7 |
+
You can insert no elemental by inserting two blank lines at the beginning/最初に空行をふたつ入れるとelementalのない状態を挿入できます
|
| 8 |
+
|
| 9 |
+
**sample1*******************************************************
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
OUT05::1
|
| 13 |
+
|
| 14 |
+
OUT05:layers:1
|
| 15 |
+
|
| 16 |
+
OUT05:attn1:1
|
| 17 |
+
|
| 18 |
+
OUT05:attn2:1
|
| 19 |
+
|
| 20 |
+
OUT05:ff.net:1
|
| 21 |
+
**sample2*******************************************************
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
IN07:attn1:0.5,IN07:attn2:0.2
|
| 25 |
+
OUT03:NOT attn1.to_k.weight:0.5
|
| 26 |
+
OUT03:attn1.to_k.weight:0.3
|
| 27 |
+
|
| 28 |
+
OUT04:NOT ff.net:0.5
|
| 29 |
+
OUT04:attn1.to_k.weight:0.3
|
| 30 |
+
OUT04:attn1.to_out.0.weight:0.3
|
| 31 |
+
|
| 32 |
+
OUT05:NOT ff.net:0.5
|
| 33 |
+
OUT05:attn1.to_k.weight:0.3
|
| 34 |
+
OUT05:attn1.to_out.0.weight:0.3
|
| 35 |
+
***************************************************************
|
| 36 |
+
***************************************************************
|
| 37 |
+
for pinpoint element
|
| 38 |
+
Each value is separated by comma or a blank line/カンマまたは改行で区切ります
|
| 39 |
+
do not enter ratio/ratioは入力しません
|
| 40 |
+
**sample3******************************************************
|
| 41 |
+
IN07:,IN07:layers,IN07:attn1,IN07:attn2,IN07:ff.net
|
| 42 |
+
**sample4******************************************************
|
| 43 |
+
OUT04:NOT attn2.to_q
|
| 44 |
+
OUT04:attn1.to_k.weight
|
| 45 |
+
OUT04:ff.net
|
| 46 |
+
OUT04:attn
|
| 47 |
+
***************************************************************
|
| 48 |
+
***************************************************************
|
| 49 |
+
for effective elemental cheker
|
| 50 |
+
|
| 51 |
+
Examine the effect of each block/各階層の影響度を調べる
|
| 52 |
+
Output results can be split by inserting "|"/途中「|」を挿入することで出力結果を分割できます。
|
| 53 |
+
**sample5*************************************************************
|
| 54 |
+
IN00:,IN01:,IN02:,IN03:,IN04:,IN05:,IN06:,IN07:,IN08:,IN09:,IN10:,IN11:,M00:|OUT00:,OUT01:,OUT02:,OUT03:,OUT04:,OUT05:,OUT06:,OUT07:,OUT08:,OUT09:,OUT10:,OUT11:
|
| 55 |
+
|
| 56 |
+
Examine the effect of all elements in the IN01/IN01層のすべての要素の影響度を調べる
|
| 57 |
+
Below that corresponds to IN02 or later/その下はIN02以降に対応
|
| 58 |
+
**sample6*************************************************************
|
| 59 |
+
IN01:emb_layers.1.bias,IN01:emb_layers.1.weight,IN01:in_layers.0.bias,IN01:in_layers.0.weight,IN01:in_layers.2.bias,IN01:in_layers.2.weight,IN01:out_layers.0.bias,IN01:out_layers.0.weight,IN01:out_layers.3.bias,IN01:out_layers.3.weight,IN01:skip_connection.bias,IN01:skip_connection.weight,IN01:norm.bias,IN01:norm.weight,IN01:proj_in.bias,IN01:proj_in.weight,IN01:proj_out.bias,IN01:proj_out.weight,IN01:transformer_blocks.0.attn1.to_k.weight,IN01:transformer_blocks.0.attn1.to_out.0.bias,IN01:transformer_blocks.0.attn1.to_out.0.weight,IN01:transformer_blocks.0.attn1.to_q.weight,IN01:transformer_blocks.0.attn1.to_v.weight,IN01:transformer_blocks.0.attn2.to_k.weight,IN01:transformer_blocks.0.attn2.to_out.0.bias,IN01:transformer_blocks.0.attn2.to_out.0.weight,IN01:transformer_blocks.0.attn2.to_q.weight,IN01:transformer_blocks.0.attn2.to_v.weight,IN01:transformer_blocks.0.ff.net.0.proj.bias,IN01:transformer_blocks.0.ff.net.0.proj.weight,IN01:transformer_blocks.0.ff.net.2.bias,IN01:transformer_blocks.0.ff.net.2.weight,IN01:transformer_blocks.0.norm1.bias,IN01:transformer_blocks.0.norm1.weight,IN01:transformer_blocks.0.norm2.bias,IN01:transformer_blocks.0.norm2.weight,IN01:transformer_blocks.0.norm3.bias,IN01:transformer_blocks.0.norm3.weight
|
| 60 |
+
|
| 61 |
+
IN02:emb_layers.1.bias,IN02:emb_layers.1.weight,IN02:in_layers.0.bias,IN02:in_layers.0.weight,IN02:in_layers.2.bias,IN02:in_layers.2.weight,IN02:out_layers.0.bias,IN02:out_layers.0.weight,IN02:out_layers.3.bias,IN02:out_layers.3.weight,IN02:skip_connection.bias,IN02:skip_connection.weight,IN02:norm.bias,IN02:norm.weight,IN02:proj_in.bias,IN02:proj_in.weight,IN02:proj_out.bias,IN02:proj_out.weight,IN02:transformer_blocks.0.attn1.to_k.weight,IN02:transformer_blocks.0.attn1.to_out.0.bias,IN02:transformer_blocks.0.attn1.to_out.0.weight,IN02:transformer_blocks.0.attn1.to_q.weight,IN02:transformer_blocks.0.attn1.to_v.weight,IN02:transformer_blocks.0.attn2.to_k.weight,IN02:transformer_blocks.0.attn2.to_out.0.bias,IN02:transformer_blocks.0.attn2.to_out.0.weight,IN02:transformer_blocks.0.attn2.to_q.weight,IN02:transformer_blocks.0.attn2.to_v.weight,IN02:transformer_blocks.0.ff.net.0.proj.bias,IN02:transformer_blocks.0.ff.net.0.proj.weight,IN02:transformer_blocks.0.ff.net.2.bias,IN02:transformer_blocks.0.ff.net.2.weight,IN02:transformer_blocks.0.norm1.bias,IN02:transformer_blocks.0.norm1.weight,IN02:transformer_blocks.0.norm2.bias,IN02:transformer_blocks.0.norm2.weight,IN02:transformer_blocks.0.norm3.bias,IN02:transformer_blocks.0.norm3.weight
|
| 62 |
+
|
| 63 |
+
IN00:bias,IN00:weight,IN03:op.bias,IN03:op.weight,IN06:op.bias,IN06:op.weight,IN09:op.bias,IN09:op.weight
|
| 64 |
+
|
| 65 |
+
IN04:emb_layers.1.bias,IN04:emb_layers.1.weight,IN04:in_layers.0.bias,IN04:in_layers.0.weight,IN04:in_layers.2.bias,IN04:in_layers.2.weight,IN04:out_layers.0.bias,IN04:out_layers.0.weight,IN04:out_layers.3.bias,IN04:out_layers.3.weight,IN04:skip_connection.bias,IN04:skip_connection.weight,IN04:norm.bias,IN04:norm.weight,IN04:proj_in.bias,IN04:proj_in.weight,IN04:proj_out.bias,IN04:proj_out.weight,IN04:transformer_blocks.0.attn1.to_k.weight,IN04:transformer_blocks.0.attn1.to_out.0.bias,IN04:transformer_blocks.0.attn1.to_out.0.weight,IN04:transformer_blocks.0.attn1.to_q.weight,IN04:transformer_blocks.0.attn1.to_v.weight,IN04:transformer_blocks.0.attn2.to_k.weight,IN04:transformer_blocks.0.attn2.to_out.0.bias,IN04:transformer_blocks.0.attn2.to_out.0.weight,IN04:transformer_blocks.0.attn2.to_q.weight,IN04:transformer_blocks.0.attn2.to_v.weight,IN04:transformer_blocks.0.ff.net.0.proj.bias,IN04:transformer_blocks.0.ff.net.0.proj.weight,IN04:transformer_blocks.0.ff.net.2.bias,IN04:transformer_blocks.0.ff.net.2.weight,IN04:transformer_blocks.0.norm1.bias,IN04:transformer_blocks.0.norm1.weight,IN04:transformer_blocks.0.norm2.bias,IN04:transformer_blocks.0.norm2.weight,IN04:transformer_blocks.0.norm3.bias,IN04:transformer_blocks.0.norm3.weight
|
| 66 |
+
|
| 67 |
+
IN05:emb_layers.1.bias,IN05:emb_layers.1.weight,IN05:in_layers.0.bias,IN05:in_layers.0.weight,IN05:in_layers.2.bias,IN05:in_layers.2.weight,IN05:out_layers.0.bias,IN05:out_layers.0.weight,IN05:out_layers.3.bias,IN05:out_layers.3.weight,IN05:skip_connection.bias,IN05:skip_connection.weight,IN05:norm.bias,IN05:norm.weight,IN05:proj_in.bias,IN05:proj_in.weight,IN05:proj_out.bias,IN05:proj_out.weight,IN05:transformer_blocks.0.attn1.to_k.weight,IN05:transformer_blocks.0.attn1.to_out.0.bias,IN05:transformer_blocks.0.attn1.to_out.0.weight,IN05:transformer_blocks.0.attn1.to_q.weight,IN05:transformer_blocks.0.attn1.to_v.weight,IN05:transformer_blocks.0.attn2.to_k.weight,IN05:transformer_blocks.0.attn2.to_out.0.bias,IN05:transformer_blocks.0.attn2.to_out.0.weight,IN05:transformer_blocks.0.attn2.to_q.weight,IN05:transformer_blocks.0.attn2.to_v.weight,IN05:transformer_blocks.0.ff.net.0.proj.bias,IN05:transformer_blocks.0.ff.net.0.proj.weight,IN05:transformer_blocks.0.ff.net.2.bias,IN05:transformer_blocks.0.ff.net.2.weight,IN05:transformer_blocks.0.norm1.bias,IN05:transformer_blocks.0.norm1.weight,IN05:transformer_blocks.0.norm2.bias,IN05:transformer_blocks.0.norm2.weight,IN05:transformer_blocks.0.norm3.bias,IN05:transformer_blocks.0.norm3.weight
|
| 68 |
+
|
| 69 |
+
IN07:emb_layers.1.bias,IN07:emb_layers.1.weight,IN07:in_layers.0.bias,IN07:in_layers.0.weight,IN07:in_layers.2.bias,IN07:in_layers.2.weight,IN07:out_layers.0.bias,IN07:out_layers.0.weight,IN07:out_layers.3.bias,IN07:out_layers.3.weight,IN07:skip_connection.bias,IN07:skip_connection.weight,IN07:norm.bias,IN07:norm.weight,IN07:proj_in.bias,IN07:proj_in.weight,IN07:proj_out.bias,IN07:proj_out.weight,IN07:transformer_blocks.0.attn1.to_k.weight,IN07:transformer_blocks.0.attn1.to_out.0.bias,IN07:transformer_blocks.0.attn1.to_out.0.weight,IN07:transformer_blocks.0.attn1.to_q.weight,IN07:transformer_blocks.0.attn1.to_v.weight,IN07:transformer_blocks.0.attn2.to_k.weight,IN07:transformer_blocks.0.attn2.to_out.0.bias,IN07:transformer_blocks.0.attn2.to_out.0.weight,IN07:transformer_blocks.0.attn2.to_q.weight,IN07:transformer_blocks.0.attn2.to_v.weight,IN07:transformer_blocks.0.ff.net.0.proj.bias,IN07:transformer_blocks.0.ff.net.0.proj.weight,IN07:transformer_blocks.0.ff.net.2.bias,IN07:transformer_blocks.0.ff.net.2.weight,IN07:transformer_blocks.0.norm1.bias,IN07:transformer_blocks.0.norm1.weight,IN07:transformer_blocks.0.norm2.bias,IN07:transformer_blocks.0.norm2.weight,IN07:transformer_blocks.0.norm3.bias,IN07:transformer_blocks.0.norm3.weight
|
| 70 |
+
|
| 71 |
+
IN08:emb_layers.1.bias,IN08:emb_layers.1.weight,IN08:in_layers.0.bias,IN08:in_layers.0.weight,IN08:in_layers.2.bias,IN08:in_layers.2.weight,IN08:out_layers.0.bias,IN08:out_layers.0.weight,IN08:out_layers.3.bias,IN08:out_layers.3.weight,IN08:skip_connection.bias,IN08:skip_connection.weight,IN08:norm.bias,IN08:norm.weight,IN08:proj_in.bias,IN08:proj_in.weight,IN08:proj_out.bias,IN08:proj_out.weight,IN08:transformer_blocks.0.attn1.to_k.weight,IN08:transformer_blocks.0.attn1.to_out.0.bias,IN08:transformer_blocks.0.attn1.to_out.0.weight,IN08:transformer_blocks.0.attn1.to_q.weight,IN08:transformer_blocks.0.attn1.to_v.weight,IN08:transformer_blocks.0.attn2.to_k.weight,IN08:transformer_blocks.0.attn2.to_out.0.bias,IN08:transformer_blocks.0.attn2.to_out.0.weight,IN08:transformer_blocks.0.attn2.to_q.weight,IN08:transformer_blocks.0.attn2.to_v.weight,IN08:transformer_blocks.0.ff.net.0.proj.bias,IN08:transformer_blocks.0.ff.net.0.proj.weight,IN08:transformer_blocks.0.ff.net.2.bias,IN08:transformer_blocks.0.ff.net.2.weight,IN08:transformer_blocks.0.norm1.bias,IN08:transformer_blocks.0.norm1.weight,IN08:transformer_blocks.0.norm2.bias,IN08:transformer_blocks.0.norm2.weight,IN08:transformer_blocks.0.norm3.bias,IN08:transformer_blocks.0.norm3.weight
|
| 72 |
+
|
| 73 |
+
IN10:emb_layers.1.bias,IN10:emb_layers.1.weight,IN10:in_layers.0.bias,IN10:in_layers.0.weight,IN10:in_layers.2.bias,IN10:in_layers.2.weight,IN10:out_layers.0.bias,IN10:out_layers.0.weight,IN10:out_layers.3.bias,IN10:out_layers.3.weight,IN11:emb_layers.1.bias,IN11:emb_layers.1.weight,IN11:in_layers.0.bias,IN11:in_layers.0.weight,IN11:in_layers.2.bias,IN11:in_layers.2.weight,IN11:out_layers.0.bias,IN11:out_layers.0.weight,IN11:out_layers.3.bias,IN11:out_layers.3.weight
|
| 74 |
+
|
| 75 |
+
M00:0.emb_layers.1.bias,M00:0.emb_layers.1.weight,M00:0.in_layers.0.bias,M00:0.in_layers.0.weight,M00:0.in_layers.2.bias,M00:0.in_layers.2.weight,M00:0.out_layers.0.bias,M00:0.out_layers.0.weight,M00:0.out_layers.3.bias,M00:0.out_layers.3.weight,M00:1.norm.bias,M00:1.norm.weight,M00:1.proj_in.bias,M00:1.proj_in.weight,M00:1.proj_out.bias,M00:1.proj_out.weight,M00:1.transformer_blocks.0.attn1.to_k.weight,M00:1.transformer_blocks.0.attn1.to_out.0.bias,M00:1.transformer_blocks.0.attn1.to_out.0.weight,M00:1.transformer_blocks.0.attn1.to_q.weight,M00:1.transformer_blocks.0.attn1.to_v.weight,M00:1.transformer_blocks.0.attn2.to_k.weight,M00:1.transformer_blocks.0.attn2.to_out.0.bias,M00:1.transformer_blocks.0.attn2.to_out.0.weight,M00:1.transformer_blocks.0.attn2.to_q.weight,M00:1.transformer_blocks.0.attn2.to_v.weight,M00:1.transformer_blocks.0.ff.net.0.proj.bias,M00:1.transformer_blocks.0.ff.net.0.proj.weight,M00:1.transformer_blocks.0.ff.net.2.bias,M00:1.transformer_blocks.0.ff.net.2.weight,M00:1.transformer_blocks.0.norm1.bias,M00:1.transformer_blocks.0.norm1.weight,M00:1.transformer_blocks.0.norm2.bias,M00:1.transformer_blocks.0.norm2.weight,M00:1.transformer_blocks.0.norm3.bias,M00:1.transformer_blocks.0.norm3.weight,M00:2.emb_layers.1.bias,M00:2.emb_layers.1.weight,M00:2.in_layers.0.bias,M00:2.in_layers.0.weight,M00:2.in_layers.2.bias,M00:2.in_layers.2.weight,M00:2.out_layers.0.bias,M00:2.out_layers.0.weight,M00:2.out_layers.3.bias,M00:2.out_layers.3.weight
|
| 76 |
+
|
| 77 |
+
OUT00:emb_layers.1.bias,OUT00:emb_layers.1.weight,OUT00:in_layers.0.bias,OUT00:in_layers.0.weight,OUT00:in_layers.2.bias,OUT00:in_layers.2.weight,OUT00:out_layers.0.bias,OUT00:out_layers.0.weight,OUT00:out_layers.3.bias,OUT00:out_layers.3.weight,OUT00:skip_connection.bias,OUT00:skip_connection.weight,OUT01:emb_layers.1.bias,OUT01:emb_layers.1.weight,OUT01:in_layers.0.bias,OUT01:in_layers.0.weight,OUT01:in_layers.2.bias,OUT01:in_layers.2.weight,OUT01:out_layers.0.bias,OUT01:out_layers.0.weight,OUT01:out_layers.3.bias,OUT01:out_layers.3.weight,OUT01:skip_connection.bias,OUT01:skip_connection.weight,OUT02:emb_layers.1.bias,OUT02:emb_layers.1.weight,OUT02:in_layers.0.bias,OUT02:in_layers.0.weight,OUT02:in_layers.2.bias,OUT02:in_layers.2.weight,OUT02:out_layers.0.bias,OUT02:out_layers.0.weight,OUT02:out_layers.3.bias,OUT02:out_layers.3.weight,OUT02:skip_connection.bias,OUT02:skip_connection.weight,OUT02:conv.bias,OUT02:conv.weight
|
| 78 |
+
|
| 79 |
+
OUT03:emb_layers.1.bias,OUT03:emb_layers.1.weight,OUT03:in_layers.0.bias,OUT03:in_layers.0.weight,OUT03:in_layers.2.bias,OUT03:in_layers.2.weight,OUT03:out_layers.0.bias,OUT03:out_layers.0.weight,OUT03:out_layers.3.bias,OUT03:out_layers.3.weight,OUT03:skip_connection.bias,OUT03:skip_connection.weight,OUT03:norm.bias,OUT03:norm.weight,OUT03:proj_in.bias,OUT03:proj_in.weight,OUT03:proj_out.bias,OUT03:proj_out.weight,OUT03:transformer_blocks.0.attn1.to_k.weight,OUT03:transformer_blocks.0.attn1.to_out.0.bias,OUT03:transformer_blocks.0.attn1.to_out.0.weight,OUT03:transformer_blocks.0.attn1.to_q.weight,OUT03:transformer_blocks.0.attn1.to_v.weight,OUT03:transformer_blocks.0.attn2.to_k.weight,OUT03:transformer_blocks.0.attn2.to_out.0.bias,OUT03:transformer_blocks.0.attn2.to_out.0.weight,OUT03:transformer_blocks.0.attn2.to_q.weight,OUT03:transformer_blocks.0.attn2.to_v.weight,OUT03:transformer_blocks.0.ff.net.0.proj.bias,OUT03:transformer_blocks.0.ff.net.0.proj.weight,OUT03:transformer_blocks.0.ff.net.2.bias,OUT03:transformer_blocks.0.ff.net.2.weight,OUT03:transformer_blocks.0.norm1.bias,OUT03:transformer_blocks.0.norm1.weight,OUT03:transformer_blocks.0.norm2.bias,OUT03:transformer_blocks.0.norm2.weight,OUT03:transformer_blocks.0.norm3.bias,OUT03:transformer_blocks.0.norm3.weight
|
| 80 |
+
|
| 81 |
+
OUT04:emb_layers.1.bias,OUT04:emb_layers.1.weight,OUT04:in_layers.0.bias,OUT04:in_layers.0.weight,OUT04:in_layers.2.bias,OUT04:in_layers.2.weight,OUT04:out_layers.0.bias,OUT04:out_layers.0.weight,OUT04:out_layers.3.bias,OUT04:out_layers.3.weight,OUT04:skip_connection.bias,OUT04:skip_connection.weight,OUT04:norm.bias,OUT04:norm.weight,OUT04:proj_in.bias,OUT04:proj_in.weight,OUT04:proj_out.bias,OUT04:proj_out.weight,OUT04:transformer_blocks.0.attn1.to_k.weight,OUT04:transformer_blocks.0.attn1.to_out.0.bias,OUT04:transformer_blocks.0.attn1.to_out.0.weight,OUT04:transformer_blocks.0.attn1.to_q.weight,OUT04:transformer_blocks.0.attn1.to_v.weight,OUT04:transformer_blocks.0.attn2.to_k.weight,OUT04:transformer_blocks.0.attn2.to_out.0.bias,OUT04:transformer_blocks.0.attn2.to_out.0.weight,OUT04:transformer_blocks.0.attn2.to_q.weight,OUT04:transformer_blocks.0.attn2.to_v.weight,OUT04:transformer_blocks.0.ff.net.0.proj.bias,OUT04:transformer_blocks.0.ff.net.0.proj.weight,OUT04:transformer_blocks.0.ff.net.2.bias,OUT04:transformer_blocks.0.ff.net.2.weight,OUT04:transformer_blocks.0.norm1.bias,OUT04:transformer_blocks.0.norm1.weight,OUT04:transformer_blocks.0.norm2.bias,OUT04:transformer_blocks.0.norm2.weight,OUT04:transformer_blocks.0.norm3.bias,OUT04:transformer_blocks.0.norm3.weight
|
| 82 |
+
|
| 83 |
+
OUT05:emb_layers.1.bias,OUT05:emb_layers.1.weight,OUT05:in_layers.0.bias,OUT05:in_layers.0.weight,OUT05:in_layers.2.bias,OUT05:in_layers.2.weight,OUT05:out_layers.0.bias,OUT05:out_layers.0.weight,OUT05:out_layers.3.bias,OUT05:out_layers.3.weight,OUT05:skip_connection.bias,OUT05:skip_connection.weight,OUT05:norm.bias,OUT05:norm.weight,OUT05:proj_in.bias,OUT05:proj_in.weight,OUT05:proj_out.bias,OUT05:proj_out.weight,OUT05:transformer_blocks.0.attn1.to_k.weight,OUT05:transformer_blocks.0.attn1.to_out.0.bias,OUT05:transformer_blocks.0.attn1.to_out.0.weight,OUT05:transformer_blocks.0.attn1.to_q.weight,OUT05:transformer_blocks.0.attn1.to_v.weight,OUT05:transformer_blocks.0.attn2.to_k.weight,OUT05:transformer_blocks.0.attn2.to_out.0.bias,OUT05:transformer_blocks.0.attn2.to_out.0.weight,OUT05:transformer_blocks.0.attn2.to_q.weight,OUT05:transformer_blocks.0.attn2.to_v.weight,OUT05:transformer_blocks.0.ff.net.0.proj.bias,OUT05:transformer_blocks.0.ff.net.0.proj.weight,OUT05:transformer_blocks.0.ff.net.2.bias,OUT05:transformer_blocks.0.ff.net.2.weight,OUT05:transformer_blocks.0.norm1.bias,OUT05:transformer_blocks.0.norm1.weight,OUT05:transformer_blocks.0.norm2.bias,OUT05:transformer_blocks.0.norm2.weight,OUT05:transformer_blocks.0.norm3.bias,OUT05:transformer_blocks.0.norm3.weight,OUT05:conv.bias,OUT05:conv.weight
|
| 84 |
+
|
| 85 |
+
,OUT06:emb_layers.1.bias,OUT06:emb_layers.1.weight,OUT06:in_layers.0.bias,OUT06:in_layers.0.weight,OUT06:in_layers.2.bias,OUT06:in_layers.2.weight,OUT06:out_layers.0.bias,OUT06:out_layers.0.weight,OUT06:out_layers.3.bias,OUT06:out_layers.3.weight,OUT06:skip_connection.bias,OUT06:skip_connection.weight,OUT06:norm.bias,OUT06:norm.weight,OUT06:proj_in.bias,OUT06:proj_in.weight,OUT06:proj_out.bias,OUT06:proj_out.weight,OUT06:transformer_blocks.0.attn1.to_k.weight,OUT06:transformer_blocks.0.attn1.to_out.0.bias,OUT06:transformer_blocks.0.attn1.to_out.0.weight,OUT06:transformer_blocks.0.attn1.to_q.weight,OUT06:transformer_blocks.0.attn1.to_v.weight,OUT06:transformer_blocks.0.attn2.to_k.weight,OUT06:transformer_blocks.0.attn2.to_out.0.bias,OUT06:transformer_blocks.0.attn2.to_out.0.weight,OUT06:transformer_blocks.0.attn2.to_q.weight,OUT06:transformer_blocks.0.attn2.to_v.weight,OUT06:transformer_blocks.0.ff.net.0.proj.bias,OUT06:transformer_blocks.0.ff.net.0.proj.weight,OUT06:transformer_blocks.0.ff.net.2.bias,OUT06:transformer_blocks.0.ff.net.2.weight,OUT06:transformer_blocks.0.norm1.bias,OUT06:transformer_blocks.0.norm1.weight,OUT06:transformer_blocks.0.norm2.bias,OUT06:transformer_blocks.0.norm2.weight,OUT06:transformer_blocks.0.norm3.bias,OUT06:transformer_blocks.0.norm3.weight
|
| 86 |
+
|
| 87 |
+
OUT07:emb_layers.1.bias,OUT07:emb_layers.1.weight,OUT07:in_layers.0.bias,OUT07:in_layers.0.weight,OUT07:in_layers.2.bias,OUT07:in_layers.2.weight,OUT07:out_layers.0.bias,OUT07:out_layers.0.weight,OUT07:out_layers.3.bias,OUT07:out_layers.3.weight,OUT07:skip_connection.bias,OUT07:skip_connection.weight,OUT07:norm.bias,OUT07:norm.weight,OUT07:proj_in.bias,OUT07:proj_in.weight,OUT07:proj_out.bias,OUT07:proj_out.weight,OUT07:transformer_blocks.0.attn1.to_k.weight,OUT07:transformer_blocks.0.attn1.to_out.0.bias,OUT07:transformer_blocks.0.attn1.to_out.0.weight,OUT07:transformer_blocks.0.attn1.to_q.weight,OUT07:transformer_blocks.0.attn1.to_v.weight,OUT07:transformer_blocks.0.attn2.to_k.weight,OUT07:transformer_blocks.0.attn2.to_out.0.bias,OUT07:transformer_blocks.0.attn2.to_out.0.weight,OUT07:transformer_blocks.0.attn2.to_q.weight,OUT07:transformer_blocks.0.attn2.to_v.weight,OUT07:transformer_blocks.0.ff.net.0.proj.bias,OUT07:transformer_blocks.0.ff.net.0.proj.weight,OUT07:transformer_blocks.0.ff.net.2.bias,OUT07:transformer_blocks.0.ff.net.2.weight,OUT07:transformer_blocks.0.norm1.bias,OUT07:transformer_blocks.0.norm1.weight,OUT07:transformer_blocks.0.norm2.bias,OUT07:transformer_blocks.0.norm2.weight,OUT07:transformer_blocks.0.norm3.bias,OUT07:transformer_blocks.0.norm3.weight
|
| 88 |
+
|
| 89 |
+
,OUT08:emb_layers.1.bias,OUT08:emb_layers.1.weight,OUT08:in_layers.0.bias,OUT08:in_layers.0.weight,OUT08:in_layers.2.bias,OUT08:in_layers.2.weight,OUT08:out_layers.0.bias,OUT08:out_layers.0.weight,OUT08:out_layers.3.bias,OUT08:out_layers.3.weight,OUT08:skip_connection.bias,OUT08:skip_connection.weight,OUT08:norm.bias,OUT08:norm.weight,OUT08:proj_in.bias,OUT08:proj_in.weight,OUT08:proj_out.bias,OUT08:proj_out.weight,OUT08:transformer_blocks.0.attn1.to_k.weight,OUT08:transformer_blocks.0.attn1.to_out.0.bias,OUT08:transformer_blocks.0.attn1.to_out.0.weight,OUT08:transformer_blocks.0.attn1.to_q.weight,OUT08:transformer_blocks.0.attn1.to_v.weight,OUT08:transformer_blocks.0.attn2.to_k.weight,OUT08:transformer_blocks.0.attn2.to_out.0.bias,OUT08:transformer_blocks.0.attn2.to_out.0.weight,OUT08:transformer_blocks.0.attn2.to_q.weight,OUT08:transformer_blocks.0.attn2.to_v.weight,OUT08:transformer_blocks.0.ff.net.0.proj.bias,OUT08:transformer_blocks.0.ff.net.0.proj.weight,OUT08:transformer_blocks.0.ff.net.2.bias,OUT08:transformer_blocks.0.ff.net.2.weight,OUT08:transformer_blocks.0.norm1.bias,OUT08:transformer_blocks.0.norm1.weight,OUT08:transformer_blocks.0.norm2.bias,OUT08:transformer_blocks.0.norm2.weight,OUT08:transformer_blocks.0.norm3.bias,OUT08:transformer_blocks.0.norm3.weight,OUT08:conv.bias,OUT08:conv.weight
|
| 90 |
+
|
| 91 |
+
OUT09:emb_layers.1.bias,OUT09:emb_layers.1.weight,OUT09:in_layers.0.bias,OUT09:in_layers.0.weight,OUT09:in_layers.2.bias,OUT09:in_layers.2.weight,OUT09:out_layers.0.bias,OUT09:out_layers.0.weight,OUT09:out_layers.3.bias,OUT09:out_layers.3.weight,OUT09:skip_connection.bias,OUT09:skip_connection.weight,OUT09:norm.bias,OUT09:norm.weight,OUT09:proj_in.bias,OUT09:proj_in.weight,OUT09:proj_out.bias,OUT09:proj_out.weight,OUT09:transformer_blocks.0.attn1.to_k.weight,OUT09:transformer_blocks.0.attn1.to_out.0.bias,OUT09:transformer_blocks.0.attn1.to_out.0.weight,OUT09:transformer_blocks.0.attn1.to_q.weight,OUT09:transformer_blocks.0.attn1.to_v.weight,OUT09:transformer_blocks.0.attn2.to_k.weight,OUT09:transformer_blocks.0.attn2.to_out.0.bias,OUT09:transformer_blocks.0.attn2.to_out.0.weight,OUT09:transformer_blocks.0.attn2.to_q.weight,OUT09:transformer_blocks.0.attn2.to_v.weight,OUT09:transformer_blocks.0.ff.net.0.proj.bias,OUT09:transformer_blocks.0.ff.net.0.proj.weight,OUT09:transformer_blocks.0.ff.net.2.bias,OUT09:transformer_blocks.0.ff.net.2.weight,OUT09:transformer_blocks.0.norm1.bias,OUT09:transformer_blocks.0.norm1.weight,OUT09:transformer_blocks.0.norm2.bias,OUT09:transformer_blocks.0.norm2.weight,OUT09:transformer_blocks.0.norm3.bias,OUT09:transformer_blocks.0.norm3.weight
|
| 92 |
+
|
| 93 |
+
OUT10:emb_layers.1.bias,OUT10:emb_layers.1.weight,OUT10:in_layers.0.bias,OUT10:in_layers.0.weight,OUT10:in_layers.2.bias,OUT10:in_layers.2.weight,OUT10:out_layers.0.bias,OUT10:out_layers.0.weight,OUT10:out_layers.3.bias,OUT10:out_layers.3.weight,OUT10:skip_connection.bias,OUT10:skip_connection.weight,OUT10:norm.bias,OUT10:norm.weight,OUT10:proj_in.bias,OUT10:proj_in.weight,OUT10:proj_out.bias,OUT10:proj_out.weight,OUT10:transformer_blocks.0.attn1.to_k.weight,OUT10:transformer_blocks.0.attn1.to_out.0.bias,OUT10:transformer_blocks.0.attn1.to_out.0.weight,OUT10:transformer_blocks.0.attn1.to_q.weight,OUT10:transformer_blocks.0.attn1.to_v.weight,OUT10:transformer_blocks.0.attn2.to_k.weight,OUT10:transformer_blocks.0.attn2.to_out.0.bias,OUT10:transformer_blocks.0.attn2.to_out.0.weight,OUT10:transformer_blocks.0.attn2.to_q.weight,OUT10:transformer_blocks.0.attn2.to_v.weight,OUT10:transformer_blocks.0.ff.net.0.proj.bias,OUT10:transformer_blocks.0.ff.net.0.proj.weight,OUT10:transformer_blocks.0.ff.net.2.bias,OUT10:transformer_blocks.0.ff.net.2.weight,OUT10:transformer_blocks.0.norm1.bias,OUT10:transformer_blocks.0.norm1.weight,OUT10:transformer_blocks.0.norm2.bias,OUT10:transformer_blocks.0.norm2.weight,OUT10:transformer_blocks.0.norm3.bias,OUT10:transformer_blocks.0.norm3.weight
|
| 94 |
+
|
| 95 |
+
OUT11:emb_layers.1.bias,OUT11:emb_layers.1.weight,OUT11:in_layers.0.bias,OUT11:in_layers.0.weight,OUT11:in_layers.2.bias,OUT11:in_layers.2.weight,OUT11:out_layers.0.bias,OUT11:out_layers.0.weight,OUT11:out_layers.3.bias,OUT11:out_layers.3.weight,OUT11:skip_connection.bias,OUT11:skip_connection.weight,OUT11:norm.bias,OUT11:norm.weight,OUT11:proj_in.bias,OUT11:proj_in.weight,OUT11:proj_out.bias,OUT11:proj_out.weight,OUT11:transformer_blocks.0.attn1.to_k.weight,OUT11:transformer_blocks.0.attn1.to_out.0.bias,OUT11:transformer_blocks.0.attn1.to_out.0.weight,OUT11:transformer_blocks.0.attn1.to_q.weight,OUT11:transformer_blocks.0.attn1.to_v.weight,OUT11:transformer_blocks.0.attn2.to_k.weight,OUT11:transformer_blocks.0.attn2.to_out.0.bias,OUT11:transformer_blocks.0.attn2.to_out.0.weight,OUT11:transformer_blocks.0.attn2.to_q.weight,OUT11:transformer_blocks.0.attn2.to_v.weight,OUT11:transformer_blocks.0.ff.net.0.proj.bias,OUT11:transformer_blocks.0.ff.net.0.proj.weight,OUT11:transformer_blocks.0.ff.net.2.bias,OUT11:transformer_blocks.0.ff.net.2.weight,OUT11:transformer_blocks.0.norm1.bias,OUT11:transformer_blocks.0.norm1.weight,OUT11:transformer_blocks.0.norm2.bias,OUT11:transformer_blocks.0.norm2.weight,OUT11:transformer_blocks.0.norm3.bias,OUT11:transformer_blocks.0.norm3.weight,OUT11:0.bias,OUT11:0.weight,OUT11:2.bias,OUT11:2.weight
|
microsoftexcel-supermerger/scripts/__pycache__/supermerger.cpython-310.pyc
ADDED
|
Binary file (22.2 kB). View file
|
|
|
microsoftexcel-supermerger/scripts/mbwpresets.txt
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
preset_name preset_weights
|
| 2 |
+
GRAD_V 0,1,0.9166666667,0.8333333333,0.75,0.6666666667,0.5833333333,0.5,0.4166666667,0.3333333333,0.25,0.1666666667,0.0833333333,0,0.0833333333,0.1666666667,0.25,0.3333333333,0.4166666667,0.5,0.5833333333,0.6666666667,0.75,0.8333333333,0.9166666667,1.0
|
| 3 |
+
GRAD_A 0,0,0.0833333333,0.1666666667,0.25,0.3333333333,0.4166666667,0.5,0.5833333333,0.6666666667,0.75,0.8333333333,0.9166666667,1.0,0.9166666667,0.8333333333,0.75,0.6666666667,0.5833333333,0.5,0.4166666667,0.3333333333,0.25,0.1666666667,0.0833333333,0
|
| 4 |
+
FLAT_25 0,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25
|
| 5 |
+
FLAT_75 0,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75
|
| 6 |
+
WRAP08 0,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1
|
| 7 |
+
WRAP12 0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1
|
| 8 |
+
WRAP14 0,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1
|
| 9 |
+
WRAP16 0,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1
|
| 10 |
+
MID12_50 0,0,0,0,0,0,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0,0,0,0,0,0
|
| 11 |
+
OUT07 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1
|
| 12 |
+
OUT12 0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1
|
| 13 |
+
OUT12_5 0,0,0,0,0,0,0,0,0,0,0,0,0,0.5,1,1,1,1,1,1,1,1,1,1,1,1
|
| 14 |
+
RING08_SOFT 0,0,0,0,0,0,0.5,1,1,1,0.5,0,0,0,0,0,0.5,1,1,1,0.5,0,0,0,0,0
|
| 15 |
+
RING08_5 0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0
|
| 16 |
+
RING10_5 0,0,0,0,0,0,1,1,1,1,1,0,0,0,0,0,1,1,1,1,1,0,0,0,0,0
|
| 17 |
+
RING10_3 0,0,0,0,0,0,0,1,1,1,1,1,0,0,0,1,1,1,1,1,0,0,0,0,0,0
|
| 18 |
+
SMOOTHSTEP 0,0,0.00506365740740741,0.0196759259259259,0.04296875,0.0740740740740741,0.112123842592593,0.15625,0.205584490740741,0.259259259259259,0.31640625,0.376157407407407,0.437644675925926,0.5,0.562355324074074,0.623842592592592,0.68359375,0.740740740740741,0.794415509259259,0.84375,0.887876157407408,0.925925925925926,0.95703125,0.980324074074074,0.994936342592593,1
|
| 19 |
+
REVERSE-SMOOTHSTEP 0,1,0.994936342592593,0.980324074074074,0.95703125,0.925925925925926,0.887876157407407,0.84375,0.794415509259259,0.740740740740741,0.68359375,0.623842592592593,0.562355324074074,0.5,0.437644675925926,0.376157407407408,0.31640625,0.259259259259259,0.205584490740741,0.15625,0.112123842592592,0.0740740740740742,0.0429687499999996,0.0196759259259258,0.00506365740740744,0
|
| 20 |
+
SMOOTHSTEP*2 0,0,0.0101273148148148,0.0393518518518519,0.0859375,0.148148148148148,0.224247685185185,0.3125,0.411168981481482,0.518518518518519,0.6328125,0.752314814814815,0.875289351851852,1.,0.875289351851852,0.752314814814815,0.6328125,0.518518518518519,0.411168981481481,0.3125,0.224247685185184,0.148148148148148,0.0859375,0.0393518518518512,0.0101273148148153,0
|
| 21 |
+
R_SMOOTHSTEP*2 0,1,0.989872685185185,0.960648148148148,0.9140625,0.851851851851852,0.775752314814815,0.6875,0.588831018518519,0.481481481481481,0.3671875,0.247685185185185,0.124710648148148,0.,0.124710648148148,0.247685185185185,0.3671875,0.481481481481481,0.588831018518519,0.6875,0.775752314814816,0.851851851851852,0.9140625,0.960648148148149,0.989872685185185,1
|
| 22 |
+
SMOOTHSTEP*3 0,0,0.0151909722222222,0.0590277777777778,0.12890625,0.222222222222222,0.336371527777778,0.46875,0.616753472222222,0.777777777777778,0.94921875,0.871527777777778,0.687065972222222,0.5,0.312934027777778,0.128472222222222,0.0507812500000004,0.222222222222222,0.383246527777778,0.53125,0.663628472222223,0.777777777777778,0.87109375,0.940972222222222,0.984809027777777,1
|
| 23 |
+
R_SMOOTHSTEP*3 0,1,0.984809027777778,0.940972222222222,0.87109375,0.777777777777778,0.663628472222222,0.53125,0.383246527777778,0.222222222222222,0.05078125,0.128472222222222,0.312934027777778,0.5,0.687065972222222,0.871527777777778,0.94921875,0.777777777777778,0.616753472222222,0.46875,0.336371527777777,0.222222222222222,0.12890625,0.0590277777777777,0.0151909722222232,0
|
| 24 |
+
SMOOTHSTEP*4 0,0,0.0202546296296296,0.0787037037037037,0.171875,0.296296296296296,0.44849537037037,0.625,0.822337962962963,0.962962962962963,0.734375,0.49537037037037,0.249421296296296,0.,0.249421296296296,0.495370370370371,0.734375000000001,0.962962962962963,0.822337962962962,0.625,0.448495370370369,0.296296296296297,0.171875,0.0787037037037024,0.0202546296296307,0
|
| 25 |
+
R_SMOOTHSTEP*4 0,1,0.97974537037037,0.921296296296296,0.828125,0.703703703703704,0.55150462962963,0.375,0.177662037037037,0.0370370370370372,0.265625,0.50462962962963,0.750578703703704,1.,0.750578703703704,0.504629629629629,0.265624999999999,0.0370370370370372,0.177662037037038,0.375,0.551504629629631,0.703703703703703,0.828125,0.921296296296298,0.979745370370369,1
|
| 26 |
+
SMOOTHSTEP/2 0,0,0.0196759259259259,0.0740740740740741,0.15625,0.259259259259259,0.376157407407407,0.5,0.623842592592593,0.740740740740741,0.84375,0.925925925925926,0.980324074074074,1.,0.980324074074074,0.925925925925926,0.84375,0.740740740740741,0.623842592592593,0.5,0.376157407407407,0.259259259259259,0.15625,0.0740740740740741,0.0196759259259259,0
|
| 27 |
+
R_SMOOTHSTEP/2 0,1,0.980324074074074,0.925925925925926,0.84375,0.740740740740741,0.623842592592593,0.5,0.376157407407407,0.259259259259259,0.15625,0.0740740740740742,0.0196759259259256,0.,0.0196759259259256,0.0740740740740742,0.15625,0.259259259259259,0.376157407407407,0.5,0.623842592592593,0.740740740740741,0.84375,0.925925925925926,0.980324074074074,1
|
| 28 |
+
SMOOTHSTEP/3 0,0,0.04296875,0.15625,0.31640625,0.5,0.68359375,0.84375,0.95703125,1.,0.95703125,0.84375,0.68359375,0.5,0.31640625,0.15625,0.04296875,0.,0.04296875,0.15625,0.31640625,0.5,0.68359375,0.84375,0.95703125,1
|
| 29 |
+
R_SMOOTHSTEP/3 0,1,0.95703125,0.84375,0.68359375,0.5,0.31640625,0.15625,0.04296875,0.,0.04296875,0.15625,0.31640625,0.5,0.68359375,0.84375,0.95703125,1.,0.95703125,0.84375,0.68359375,0.5,0.31640625,0.15625,0.04296875,0
|
| 30 |
+
SMOOTHSTEP/4 0,0,0.0740740740740741,0.259259259259259,0.5,0.740740740740741,0.925925925925926,1.,0.925925925925926,0.740740740740741,0.5,0.259259259259259,0.0740740740740741,0.,0.0740740740740741,0.259259259259259,0.5,0.740740740740741,0.925925925925926,1.,0.925925925925926,0.740740740740741,0.5,0.259259259259259,0.0740740740740741,0
|
| 31 |
+
R_SMOOTHSTEP/4 0,1,0.925925925925926,0.740740740740741,0.5,0.259259259259259,0.0740740740740742,0.,0.0740740740740742,0.259259259259259,0.5,0.740740740740741,0.925925925925926,1.,0.925925925925926,0.740740740740741,0.5,0.259259259259259,0.0740740740740742,0.,0.0740740740740742,0.259259259259259,0.5,0.740740740740741,0.925925925925926,1
|
| 32 |
+
COSINE 0,1,0.995722430686905,0.982962913144534,0.961939766255643,0.933012701892219,0.896676670145617,0.853553390593274,0.80438071450436,0.75,0.691341716182545,0.62940952255126,0.565263096110026,0.5,0.434736903889974,0.37059047744874,0.308658283817455,0.25,0.195619285495639,0.146446609406726,0.103323329854382,0.0669872981077805,0.0380602337443566,0.0170370868554658,0.00427756931309475,0
|
| 33 |
+
REVERSE_COSINE 0,0,0.00427756931309475,0.0170370868554659,0.0380602337443566,0.0669872981077808,0.103323329854383,0.146446609406726,0.19561928549564,0.25,0.308658283817455,0.37059047744874,0.434736903889974,0.5,0.565263096110026,0.62940952255126,0.691341716182545,0.75,0.804380714504361,0.853553390593274,0.896676670145618,0.933012701892219,0.961939766255643,0.982962913144534,0.995722430686905,1
|
| 34 |
+
TRUE_CUBIC_HERMITE 0,0,0.199031876929012,0.325761959876543,0.424641927083333,0.498456790123457,0.549991560570988,0.58203125,0.597360869984568,0.598765432098765,0.589029947916667,0.570939429012346,0.547278886959876,0.520833333333333,0.49438777970679,0.470727237654321,0.45263671875,0.442901234567901,0.444305796682099,0.459635416666667,0.491675106095678,0.543209876543211,0.617024739583333,0.715904706790124,0.842634789737655,1
|
| 35 |
+
TRUE_REVERSE_CUBIC_HERMITE 0,1,0.800968123070988,0.674238040123457,0.575358072916667,0.501543209876543,0.450008439429012,0.41796875,0.402639130015432,0.401234567901235,0.410970052083333,0.429060570987654,0.452721113040124,0.479166666666667,0.50561222029321,0.529272762345679,0.54736328125,0.557098765432099,0.555694203317901,0.540364583333333,0.508324893904322,0.456790123456789,0.382975260416667,0.284095293209876,0.157365210262345,0
|
| 36 |
+
FAKE_CUBIC_HERMITE 0,0,0.157576195987654,0.28491512345679,0.384765625,0.459876543209877,0.512996720679012,0.546875,0.564260223765432,0.567901234567901,0.560546875,0.544945987654321,0.523847415123457,0.5,0.476152584876543,0.455054012345679,0.439453125,0.432098765432099,0.435739776234568,0.453125,0.487003279320987,0.540123456790124,0.615234375,0.71508487654321,0.842423804012347,1
|
| 37 |
+
FAKE_REVERSE_CUBIC_HERMITE 0,1,0.842423804012346,0.71508487654321,0.615234375,0.540123456790123,0.487003279320988,0.453125,0.435739776234568,0.432098765432099,0.439453125,0.455054012345679,0.476152584876543,0.5,0.523847415123457,0.544945987654321,0.560546875,0.567901234567901,0.564260223765432,0.546875,0.512996720679013,0.459876543209876,0.384765625,0.28491512345679,0.157576195987653,0
|
| 38 |
+
ALL_A 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
|
| 39 |
+
ALL_B 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1
|
microsoftexcel-supermerger/scripts/mbwpresets_master.txt
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
preset_name preset_weights
|
| 2 |
+
GRAD_V 0,1,0.9166666667,0.8333333333,0.75,0.6666666667,0.5833333333,0.5,0.4166666667,0.3333333333,0.25,0.1666666667,0.0833333333,0,0.0833333333,0.1666666667,0.25,0.3333333333,0.4166666667,0.5,0.5833333333,0.6666666667,0.75,0.8333333333,0.9166666667,1.0
|
| 3 |
+
GRAD_A 0,0,0.0833333333,0.1666666667,0.25,0.3333333333,0.4166666667,0.5,0.5833333333,0.6666666667,0.75,0.8333333333,0.9166666667,1.0,0.9166666667,0.8333333333,0.75,0.6666666667,0.5833333333,0.5,0.4166666667,0.3333333333,0.25,0.1666666667,0.0833333333,0
|
| 4 |
+
FLAT_25 0,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25
|
| 5 |
+
FLAT_75 0,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75
|
| 6 |
+
WRAP08 0,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1
|
| 7 |
+
WRAP12 0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1
|
| 8 |
+
WRAP14 0,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1
|
| 9 |
+
WRAP16 0,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1
|
| 10 |
+
MID12_50 0,0,0,0,0,0,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0,0,0,0,0,0
|
| 11 |
+
OUT07 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1
|
| 12 |
+
OUT12 0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1
|
| 13 |
+
OUT12_5 0,0,0,0,0,0,0,0,0,0,0,0,0,0.5,1,1,1,1,1,1,1,1,1,1,1,1
|
| 14 |
+
RING08_SOFT 0,0,0,0,0,0,0.5,1,1,1,0.5,0,0,0,0,0,0.5,1,1,1,0.5,0,0,0,0,0
|
| 15 |
+
RING08_5 0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0
|
| 16 |
+
RING10_5 0,0,0,0,0,0,1,1,1,1,1,0,0,0,0,0,1,1,1,1,1,0,0,0,0,0
|
| 17 |
+
RING10_3 0,0,0,0,0,0,0,1,1,1,1,1,0,0,0,1,1,1,1,1,0,0,0,0,0,0
|
| 18 |
+
SMOOTHSTEP 0,0,0.00506365740740741,0.0196759259259259,0.04296875,0.0740740740740741,0.112123842592593,0.15625,0.205584490740741,0.259259259259259,0.31640625,0.376157407407407,0.437644675925926,0.5,0.562355324074074,0.623842592592592,0.68359375,0.740740740740741,0.794415509259259,0.84375,0.887876157407408,0.925925925925926,0.95703125,0.980324074074074,0.994936342592593,1
|
| 19 |
+
REVERSE-SMOOTHSTEP 0,1,0.994936342592593,0.980324074074074,0.95703125,0.925925925925926,0.887876157407407,0.84375,0.794415509259259,0.740740740740741,0.68359375,0.623842592592593,0.562355324074074,0.5,0.437644675925926,0.376157407407408,0.31640625,0.259259259259259,0.205584490740741,0.15625,0.112123842592592,0.0740740740740742,0.0429687499999996,0.0196759259259258,0.00506365740740744,0
|
| 20 |
+
SMOOTHSTEP*2 0,0,0.0101273148148148,0.0393518518518519,0.0859375,0.148148148148148,0.224247685185185,0.3125,0.411168981481482,0.518518518518519,0.6328125,0.752314814814815,0.875289351851852,1.,0.875289351851852,0.752314814814815,0.6328125,0.518518518518519,0.411168981481481,0.3125,0.224247685185184,0.148148148148148,0.0859375,0.0393518518518512,0.0101273148148153,0
|
| 21 |
+
R_SMOOTHSTEP*2 0,1,0.989872685185185,0.960648148148148,0.9140625,0.851851851851852,0.775752314814815,0.6875,0.588831018518519,0.481481481481481,0.3671875,0.247685185185185,0.124710648148148,0.,0.124710648148148,0.247685185185185,0.3671875,0.481481481481481,0.588831018518519,0.6875,0.775752314814816,0.851851851851852,0.9140625,0.960648148148149,0.989872685185185,1
|
| 22 |
+
SMOOTHSTEP*3 0,0,0.0151909722222222,0.0590277777777778,0.12890625,0.222222222222222,0.336371527777778,0.46875,0.616753472222222,0.777777777777778,0.94921875,0.871527777777778,0.687065972222222,0.5,0.312934027777778,0.128472222222222,0.0507812500000004,0.222222222222222,0.383246527777778,0.53125,0.663628472222223,0.777777777777778,0.87109375,0.940972222222222,0.984809027777777,1
|
| 23 |
+
R_SMOOTHSTEP*3 0,1,0.984809027777778,0.940972222222222,0.87109375,0.777777777777778,0.663628472222222,0.53125,0.383246527777778,0.222222222222222,0.05078125,0.128472222222222,0.312934027777778,0.5,0.687065972222222,0.871527777777778,0.94921875,0.777777777777778,0.616753472222222,0.46875,0.336371527777777,0.222222222222222,0.12890625,0.0590277777777777,0.0151909722222232,0
|
| 24 |
+
SMOOTHSTEP*4 0,0,0.0202546296296296,0.0787037037037037,0.171875,0.296296296296296,0.44849537037037,0.625,0.822337962962963,0.962962962962963,0.734375,0.49537037037037,0.249421296296296,0.,0.249421296296296,0.495370370370371,0.734375000000001,0.962962962962963,0.822337962962962,0.625,0.448495370370369,0.296296296296297,0.171875,0.0787037037037024,0.0202546296296307,0
|
| 25 |
+
R_SMOOTHSTEP*4 0,1,0.97974537037037,0.921296296296296,0.828125,0.703703703703704,0.55150462962963,0.375,0.177662037037037,0.0370370370370372,0.265625,0.50462962962963,0.750578703703704,1.,0.750578703703704,0.504629629629629,0.265624999999999,0.0370370370370372,0.177662037037038,0.375,0.551504629629631,0.703703703703703,0.828125,0.921296296296298,0.979745370370369,1
|
| 26 |
+
SMOOTHSTEP/2 0,0,0.0196759259259259,0.0740740740740741,0.15625,0.259259259259259,0.376157407407407,0.5,0.623842592592593,0.740740740740741,0.84375,0.925925925925926,0.980324074074074,1.,0.980324074074074,0.925925925925926,0.84375,0.740740740740741,0.623842592592593,0.5,0.376157407407407,0.259259259259259,0.15625,0.0740740740740741,0.0196759259259259,0
|
| 27 |
+
R_SMOOTHSTEP/2 0,1,0.980324074074074,0.925925925925926,0.84375,0.740740740740741,0.623842592592593,0.5,0.376157407407407,0.259259259259259,0.15625,0.0740740740740742,0.0196759259259256,0.,0.0196759259259256,0.0740740740740742,0.15625,0.259259259259259,0.376157407407407,0.5,0.623842592592593,0.740740740740741,0.84375,0.925925925925926,0.980324074074074,1
|
| 28 |
+
SMOOTHSTEP/3 0,0,0.04296875,0.15625,0.31640625,0.5,0.68359375,0.84375,0.95703125,1.,0.95703125,0.84375,0.68359375,0.5,0.31640625,0.15625,0.04296875,0.,0.04296875,0.15625,0.31640625,0.5,0.68359375,0.84375,0.95703125,1
|
| 29 |
+
R_SMOOTHSTEP/3 0,1,0.95703125,0.84375,0.68359375,0.5,0.31640625,0.15625,0.04296875,0.,0.04296875,0.15625,0.31640625,0.5,0.68359375,0.84375,0.95703125,1.,0.95703125,0.84375,0.68359375,0.5,0.31640625,0.15625,0.04296875,0
|
| 30 |
+
SMOOTHSTEP/4 0,0,0.0740740740740741,0.259259259259259,0.5,0.740740740740741,0.925925925925926,1.,0.925925925925926,0.740740740740741,0.5,0.259259259259259,0.0740740740740741,0.,0.0740740740740741,0.259259259259259,0.5,0.740740740740741,0.925925925925926,1.,0.925925925925926,0.740740740740741,0.5,0.259259259259259,0.0740740740740741,0
|
| 31 |
+
R_SMOOTHSTEP/4 0,1,0.925925925925926,0.740740740740741,0.5,0.259259259259259,0.0740740740740742,0.,0.0740740740740742,0.259259259259259,0.5,0.740740740740741,0.925925925925926,1.,0.925925925925926,0.740740740740741,0.5,0.259259259259259,0.0740740740740742,0.,0.0740740740740742,0.259259259259259,0.5,0.740740740740741,0.925925925925926,1
|
| 32 |
+
COSINE 0,1,0.995722430686905,0.982962913144534,0.961939766255643,0.933012701892219,0.896676670145617,0.853553390593274,0.80438071450436,0.75,0.691341716182545,0.62940952255126,0.565263096110026,0.5,0.434736903889974,0.37059047744874,0.308658283817455,0.25,0.195619285495639,0.146446609406726,0.103323329854382,0.0669872981077805,0.0380602337443566,0.0170370868554658,0.00427756931309475,0
|
| 33 |
+
REVERSE_COSINE 0,0,0.00427756931309475,0.0170370868554659,0.0380602337443566,0.0669872981077808,0.103323329854383,0.146446609406726,0.19561928549564,0.25,0.308658283817455,0.37059047744874,0.434736903889974,0.5,0.565263096110026,0.62940952255126,0.691341716182545,0.75,0.804380714504361,0.853553390593274,0.896676670145618,0.933012701892219,0.961939766255643,0.982962913144534,0.995722430686905,1
|
| 34 |
+
TRUE_CUBIC_HERMITE 0,0,0.199031876929012,0.325761959876543,0.424641927083333,0.498456790123457,0.549991560570988,0.58203125,0.597360869984568,0.598765432098765,0.589029947916667,0.570939429012346,0.547278886959876,0.520833333333333,0.49438777970679,0.470727237654321,0.45263671875,0.442901234567901,0.444305796682099,0.459635416666667,0.491675106095678,0.543209876543211,0.617024739583333,0.715904706790124,0.842634789737655,1
|
| 35 |
+
TRUE_REVERSE_CUBIC_HERMITE 0,1,0.800968123070988,0.674238040123457,0.575358072916667,0.501543209876543,0.450008439429012,0.41796875,0.402639130015432,0.401234567901235,0.410970052083333,0.429060570987654,0.452721113040124,0.479166666666667,0.50561222029321,0.529272762345679,0.54736328125,0.557098765432099,0.555694203317901,0.540364583333333,0.508324893904322,0.456790123456789,0.382975260416667,0.284095293209876,0.157365210262345,0
|
| 36 |
+
FAKE_CUBIC_HERMITE 0,0,0.157576195987654,0.28491512345679,0.384765625,0.459876543209877,0.512996720679012,0.546875,0.564260223765432,0.567901234567901,0.560546875,0.544945987654321,0.523847415123457,0.5,0.476152584876543,0.455054012345679,0.439453125,0.432098765432099,0.435739776234568,0.453125,0.487003279320987,0.540123456790124,0.615234375,0.71508487654321,0.842423804012347,1
|
| 37 |
+
FAKE_REVERSE_CUBIC_HERMITE 0,1,0.842423804012346,0.71508487654321,0.615234375,0.540123456790123,0.487003279320988,0.453125,0.435739776234568,0.432098765432099,0.439453125,0.455054012345679,0.476152584876543,0.5,0.523847415123457,0.544945987654321,0.560546875,0.567901234567901,0.564260223765432,0.546875,0.512996720679013,0.459876543209876,0.384765625,0.28491512345679,0.157576195987653,0
|
| 38 |
+
ALL_A 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
|
| 39 |
+
ALL_B 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1
|
microsoftexcel-supermerger/scripts/mergers/__pycache__/mergers.cpython-310.pyc
ADDED
|
Binary file (19.9 kB). View file
|
|
|
microsoftexcel-supermerger/scripts/mergers/__pycache__/model_util.cpython-310.pyc
ADDED
|
Binary file (24.8 kB). View file
|
|
|
microsoftexcel-supermerger/scripts/mergers/__pycache__/pluslora.cpython-310.pyc
ADDED
|
Binary file (34.5 kB). View file
|
|
|
microsoftexcel-supermerger/scripts/mergers/__pycache__/xyplot.cpython-310.pyc
ADDED
|
Binary file (15.2 kB). View file
|
|
|
microsoftexcel-supermerger/scripts/mergers/mergers.py
ADDED
|
@@ -0,0 +1,699 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from linecache import clearcache
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import gc
|
| 5 |
+
import numpy as np
|
| 6 |
+
import os.path
|
| 7 |
+
import re
|
| 8 |
+
import torch
|
| 9 |
+
import tqdm
|
| 10 |
+
import datetime
|
| 11 |
+
import csv
|
| 12 |
+
import json
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
import scipy.ndimage
|
| 15 |
+
from scipy.ndimage.filters import median_filter as filter
|
| 16 |
+
from PIL import Image, ImageFont, ImageDraw
|
| 17 |
+
from tqdm import tqdm
|
| 18 |
+
from modules import shared, processing, sd_models, images, sd_samplers,scripts
|
| 19 |
+
from modules.ui import plaintext_to_html
|
| 20 |
+
from modules.shared import opts
|
| 21 |
+
from modules.processing import create_infotext,Processed
|
| 22 |
+
from modules.sd_models import load_model,checkpoints_loaded
|
| 23 |
+
from scripts.mergers.model_util import usemodelgen,filenamecutter,savemodel
|
| 24 |
+
|
| 25 |
+
from inspect import currentframe
|
| 26 |
+
|
| 27 |
+
stopmerge = False
|
| 28 |
+
|
| 29 |
+
def freezemtime():
|
| 30 |
+
global stopmerge
|
| 31 |
+
stopmerge = True
|
| 32 |
+
|
| 33 |
+
mergedmodel=[]
|
| 34 |
+
TYPESEG = ["none","alpha","beta (if Triple or Twice is not selected,Twice automatically enable)","alpha and beta","seed", "mbw alpha","mbw beta","mbw alpha and beta", "model_A","model_B","model_C","pinpoint blocks (alpha or beta must be selected for another axis)","elemental","pinpoint element","effective elemental checker","tensors","calcmode","prompt"]
|
| 35 |
+
TYPES = ["none","alpha","beta","alpha and beta","seed", "mbw alpha ","mbw beta","mbw alpha and beta", "model_A","model_B","model_C","pinpoint blocks","elemental","pinpoint element","effective","tensor","calcmode","prompt"]
|
| 36 |
+
MODES=["Weight" ,"Add" ,"Triple","Twice"]
|
| 37 |
+
SAVEMODES=["save model", "overwrite"]
|
| 38 |
+
#type[0:aplha,1:beta,2:seed,3:mbw,4:model_A,5:model_B,6:model_C]
|
| 39 |
+
#msettings=[0 weights_a,1 weights_b,2 model_a,3 model_b,4 model_c,5 base_alpha,6 base_beta,7 mode,8 useblocks,9 custom_name,10 save_sets,11 id_sets,12 wpresets]
|
| 40 |
+
#id sets "image", "PNG info","XY grid"
|
| 41 |
+
|
| 42 |
+
hear = False
|
| 43 |
+
hearm = False
|
| 44 |
+
non4 = [None]*4
|
| 45 |
+
|
| 46 |
+
def caster(news,hear):
|
| 47 |
+
if hear: print(news)
|
| 48 |
+
|
| 49 |
+
def casterr(*args,hear=hear):
|
| 50 |
+
if hear:
|
| 51 |
+
names = {id(v): k for k, v in currentframe().f_back.f_locals.items()}
|
| 52 |
+
print('\n'.join([names.get(id(arg), '???') + ' = ' + repr(arg) for arg in args]))
|
| 53 |
+
|
| 54 |
+
#msettings=[weights_a,weights_b,model_a,model_b,model_c,device,base_alpha,base_beta,mode,loranames,useblocks,custom_name,save_sets,id_sets,wpresets,deep]
|
| 55 |
+
def smergegen(weights_a,weights_b,model_a,model_b,model_c,base_alpha,base_beta,mode,
|
| 56 |
+
calcmode,useblocks,custom_name,save_sets,id_sets,wpresets,deep,tensor,
|
| 57 |
+
esettings,
|
| 58 |
+
prompt,nprompt,steps,sampler,cfg,seed,w,h,
|
| 59 |
+
hireson,hrupscaler,hr2ndsteps,denoise_str,hr_scale,batch_size,
|
| 60 |
+
currentmodel,imggen):
|
| 61 |
+
|
| 62 |
+
deepprint = True if "print change" in esettings else False
|
| 63 |
+
|
| 64 |
+
result,currentmodel,modelid,theta_0,metadata = smerge(
|
| 65 |
+
weights_a,weights_b,model_a,model_b,model_c,base_alpha,base_beta,mode,calcmode,
|
| 66 |
+
useblocks,custom_name,save_sets,id_sets,wpresets,deep,tensor,deepprint=deepprint
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
if "ERROR" in result or "STOPPED" in result:
|
| 70 |
+
return result,"not loaded",*non4
|
| 71 |
+
|
| 72 |
+
usemodelgen(theta_0,model_a,currentmodel)
|
| 73 |
+
|
| 74 |
+
save = True if SAVEMODES[0] in save_sets else False
|
| 75 |
+
|
| 76 |
+
result = savemodel(theta_0,currentmodel,custom_name,save_sets,model_a,metadata) if save else "Merged model loaded:"+currentmodel
|
| 77 |
+
del theta_0
|
| 78 |
+
gc.collect()
|
| 79 |
+
|
| 80 |
+
if imggen :
|
| 81 |
+
images = simggen(prompt,nprompt,steps,sampler,cfg,seed,w,h,hireson,hrupscaler,hr2ndsteps,denoise_str,hr_scale,batch_size,currentmodel,id_sets,modelid)
|
| 82 |
+
return result,currentmodel,*images[:4]
|
| 83 |
+
else:
|
| 84 |
+
return result,currentmodel
|
| 85 |
+
|
| 86 |
+
NUM_INPUT_BLOCKS = 12
|
| 87 |
+
NUM_MID_BLOCK = 1
|
| 88 |
+
NUM_OUTPUT_BLOCKS = 12
|
| 89 |
+
NUM_TOTAL_BLOCKS = NUM_INPUT_BLOCKS + NUM_MID_BLOCK + NUM_OUTPUT_BLOCKS
|
| 90 |
+
blockid=["BASE","IN00","IN01","IN02","IN03","IN04","IN05","IN06","IN07","IN08","IN09","IN10","IN11","M00","OUT00","OUT01","OUT02","OUT03","OUT04","OUT05","OUT06","OUT07","OUT08","OUT09","OUT10","OUT11"]
|
| 91 |
+
|
| 92 |
+
def smerge(weights_a,weights_b,model_a,model_b,model_c,base_alpha,base_beta,mode,calcmode,
|
| 93 |
+
useblocks,custom_name,save_sets,id_sets,wpresets,deep,tensor,deepprint = False):
|
| 94 |
+
caster("merge start",hearm)
|
| 95 |
+
global hear,mergedmodel,stopmerge
|
| 96 |
+
stopmerge = False
|
| 97 |
+
|
| 98 |
+
gc.collect()
|
| 99 |
+
|
| 100 |
+
# for from file
|
| 101 |
+
if type(useblocks) is str:
|
| 102 |
+
useblocks = True if useblocks =="True" else False
|
| 103 |
+
if type(base_alpha) == str:base_alpha = float(base_alpha)
|
| 104 |
+
if type(base_beta) == str:base_beta = float(base_beta)
|
| 105 |
+
|
| 106 |
+
weights_a_orig = weights_a
|
| 107 |
+
weights_b_orig = weights_b
|
| 108 |
+
|
| 109 |
+
# preset to weights
|
| 110 |
+
if wpresets != False and useblocks:
|
| 111 |
+
weights_a = wpreseter(weights_a,wpresets)
|
| 112 |
+
weights_b = wpreseter(weights_b,wpresets)
|
| 113 |
+
|
| 114 |
+
# mode select booleans
|
| 115 |
+
save = True if SAVEMODES[0] in save_sets else False
|
| 116 |
+
usebeta = MODES[2] in mode or MODES[3] in mode or calcmode == "tensor"
|
| 117 |
+
save_metadata = "save metadata" in save_sets
|
| 118 |
+
metadata = {"format": "pt"}
|
| 119 |
+
|
| 120 |
+
if not useblocks:
|
| 121 |
+
weights_a = weights_b = ""
|
| 122 |
+
#for save log and save current model
|
| 123 |
+
mergedmodel =[weights_a,weights_b,
|
| 124 |
+
hashfromname(model_a),hashfromname(model_b),hashfromname(model_c),
|
| 125 |
+
base_alpha,base_beta,mode,useblocks,custom_name,save_sets,id_sets,deep,calcmode,tensor].copy()
|
| 126 |
+
|
| 127 |
+
model_a = namefromhash(model_a)
|
| 128 |
+
model_b = namefromhash(model_b)
|
| 129 |
+
model_c = namefromhash(model_c)
|
| 130 |
+
theta_2 = {}
|
| 131 |
+
|
| 132 |
+
caster(mergedmodel,False)
|
| 133 |
+
|
| 134 |
+
if len(deep) > 0:
|
| 135 |
+
deep = deep.replace("\n",",")
|
| 136 |
+
deep = deep.split(",")
|
| 137 |
+
|
| 138 |
+
#format check
|
| 139 |
+
if model_a =="" or model_b =="" or ((not MODES[0] in mode) and model_c=="") :
|
| 140 |
+
return "ERROR: Necessary model is not selected",*non4
|
| 141 |
+
|
| 142 |
+
#for MBW text to list
|
| 143 |
+
if useblocks:
|
| 144 |
+
weights_a_t=weights_a.split(',',1)
|
| 145 |
+
weights_b_t=weights_b.split(',',1)
|
| 146 |
+
base_alpha = float(weights_a_t[0])
|
| 147 |
+
weights_a = [float(w) for w in weights_a_t[1].split(',')]
|
| 148 |
+
caster(f"from {weights_a_t}, alpha = {base_alpha},weights_a ={weights_a}",hearm)
|
| 149 |
+
if len(weights_a) != 25:return f"ERROR: weights alpha value must be {26}.",*non4
|
| 150 |
+
if usebeta:
|
| 151 |
+
base_beta = float(weights_b_t[0])
|
| 152 |
+
weights_b = [float(w) for w in weights_b_t[1].split(',')]
|
| 153 |
+
caster(f"from {weights_b_t}, beta = {base_beta},weights_a ={weights_b}",hearm)
|
| 154 |
+
if len(weights_b) != 25: return f"ERROR: weights beta value must be {26}.",*non4
|
| 155 |
+
|
| 156 |
+
caster("model load start",hearm)
|
| 157 |
+
|
| 158 |
+
print(f" model A \t: {model_a}")
|
| 159 |
+
print(f" model B \t: {model_b}")
|
| 160 |
+
print(f" model C \t: {model_c}")
|
| 161 |
+
print(f" alpha,beta\t: {base_alpha,base_beta}")
|
| 162 |
+
print(f" weights_alpha\t: {weights_a}")
|
| 163 |
+
print(f" weights_beta\t: {weights_b}")
|
| 164 |
+
print(f" mode\t\t: {mode}")
|
| 165 |
+
print(f" MBW \t\t: {useblocks}")
|
| 166 |
+
print(f" CalcMode \t: {calcmode}")
|
| 167 |
+
print(f" Elemental \t: {deep}")
|
| 168 |
+
print(f" Tensors \t: {tensor}")
|
| 169 |
+
|
| 170 |
+
theta_1=load_model_weights_m(model_b,False,True,save).copy()
|
| 171 |
+
|
| 172 |
+
if MODES[1] in mode:#Add
|
| 173 |
+
if stopmerge: return "STOPPED", *non4
|
| 174 |
+
theta_2 = load_model_weights_m(model_c,False,False,save).copy()
|
| 175 |
+
for key in tqdm(theta_1.keys()):
|
| 176 |
+
if 'model' in key:
|
| 177 |
+
if key in theta_2:
|
| 178 |
+
t2 = theta_2.get(key, torch.zeros_like(theta_1[key]))
|
| 179 |
+
theta_1[key] = theta_1[key]- t2
|
| 180 |
+
else:
|
| 181 |
+
theta_1[key] = torch.zeros_like(theta_1[key])
|
| 182 |
+
del theta_2
|
| 183 |
+
|
| 184 |
+
if stopmerge: return "STOPPED", *non4
|
| 185 |
+
|
| 186 |
+
if calcmode == "tensor":
|
| 187 |
+
theta_t = load_model_weights_m(model_a,True,False,save).copy()
|
| 188 |
+
theta_0 ={}
|
| 189 |
+
for key in theta_t:
|
| 190 |
+
theta_0[key] = theta_t[key].clone()
|
| 191 |
+
del theta_t
|
| 192 |
+
else:
|
| 193 |
+
theta_0=load_model_weights_m(model_a,True,False,save).copy()
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
if MODES[2] in mode or MODES[3] in mode:#Tripe or Twice
|
| 197 |
+
theta_2 = load_model_weights_m(model_c,False,False,save).copy()
|
| 198 |
+
|
| 199 |
+
alpha = base_alpha
|
| 200 |
+
beta = base_beta
|
| 201 |
+
|
| 202 |
+
re_inp = re.compile(r'\.input_blocks\.(\d+)\.') # 12
|
| 203 |
+
re_mid = re.compile(r'\.middle_block\.(\d+)\.') # 1
|
| 204 |
+
re_out = re.compile(r'\.output_blocks\.(\d+)\.') # 12
|
| 205 |
+
|
| 206 |
+
chckpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"]
|
| 207 |
+
count_target_of_basealpha = 0
|
| 208 |
+
|
| 209 |
+
if calcmode =="cosineA": #favors modelA's structure with details from B
|
| 210 |
+
if stopmerge: return "STOPPED", *non4
|
| 211 |
+
sim = torch.nn.CosineSimilarity(dim=0)
|
| 212 |
+
sims = np.array([], dtype=np.float64)
|
| 213 |
+
for key in (tqdm(theta_0.keys(), desc="Stage 0/2")):
|
| 214 |
+
# skip VAE model parameters to get better results
|
| 215 |
+
if "first_stage_model" in key: continue
|
| 216 |
+
if "model" in key and key in theta_1:
|
| 217 |
+
theta_0_norm = nn.functional.normalize(theta_0[key].to(torch.float32), p=2, dim=0)
|
| 218 |
+
theta_1_norm = nn.functional.normalize(theta_1[key].to(torch.float32), p=2, dim=0)
|
| 219 |
+
simab = sim(theta_0_norm, theta_1_norm)
|
| 220 |
+
sims = np.append(sims,simab.numpy())
|
| 221 |
+
sims = sims[~np.isnan(sims)]
|
| 222 |
+
sims = np.delete(sims, np.where(sims<np.percentile(sims, 1 ,method = 'midpoint')))
|
| 223 |
+
sims = np.delete(sims, np.where(sims>np.percentile(sims, 99 ,method = 'midpoint')))
|
| 224 |
+
|
| 225 |
+
if calcmode =="cosineB": #favors modelB's structure with details from A
|
| 226 |
+
if stopmerge: return "STOPPED", *non4
|
| 227 |
+
sim = torch.nn.CosineSimilarity(dim=0)
|
| 228 |
+
sims = np.array([], dtype=np.float64)
|
| 229 |
+
for key in (tqdm(theta_0.keys(), desc="Stage 0/2")):
|
| 230 |
+
# skip VAE model parameters to get better results
|
| 231 |
+
if "first_stage_model" in key: continue
|
| 232 |
+
if "model" in key and key in theta_1:
|
| 233 |
+
simab = sim(theta_0[key].to(torch.float32), theta_1[key].to(torch.float32))
|
| 234 |
+
dot_product = torch.dot(theta_0[key].view(-1).to(torch.float32), theta_1[key].view(-1).to(torch.float32))
|
| 235 |
+
magnitude_similarity = dot_product / (torch.norm(theta_0[key].to(torch.float32)) * torch.norm(theta_1[key].to(torch.float32)))
|
| 236 |
+
combined_similarity = (simab + magnitude_similarity) / 2.0
|
| 237 |
+
sims = np.append(sims, combined_similarity.numpy())
|
| 238 |
+
sims = sims[~np.isnan(sims)]
|
| 239 |
+
sims = np.delete(sims, np.where(sims < np.percentile(sims, 1, method='midpoint')))
|
| 240 |
+
sims = np.delete(sims, np.where(sims > np.percentile(sims, 99, method='midpoint')))
|
| 241 |
+
|
| 242 |
+
for key in (tqdm(theta_0.keys(), desc="Stage 1/2") if not False else theta_0.keys()):
|
| 243 |
+
if stopmerge: return "STOPPED", *non4
|
| 244 |
+
if "model" in key and key in theta_1:
|
| 245 |
+
if usebeta and (not key in theta_2) and (not theta_2 == {}) :
|
| 246 |
+
continue
|
| 247 |
+
|
| 248 |
+
weight_index = -1
|
| 249 |
+
current_alpha = alpha
|
| 250 |
+
current_beta = beta
|
| 251 |
+
|
| 252 |
+
if key in chckpoint_dict_skip_on_merge:
|
| 253 |
+
continue
|
| 254 |
+
|
| 255 |
+
# check weighted and U-Net or not
|
| 256 |
+
if weights_a is not None and 'model.diffusion_model.' in key:
|
| 257 |
+
# check block index
|
| 258 |
+
weight_index = -1
|
| 259 |
+
|
| 260 |
+
if 'time_embed' in key:
|
| 261 |
+
weight_index = 0 # before input blocks
|
| 262 |
+
elif '.out.' in key:
|
| 263 |
+
weight_index = NUM_TOTAL_BLOCKS - 1 # after output blocks
|
| 264 |
+
else:
|
| 265 |
+
m = re_inp.search(key)
|
| 266 |
+
if m:
|
| 267 |
+
inp_idx = int(m.groups()[0])
|
| 268 |
+
weight_index = inp_idx
|
| 269 |
+
else:
|
| 270 |
+
m = re_mid.search(key)
|
| 271 |
+
if m:
|
| 272 |
+
weight_index = NUM_INPUT_BLOCKS
|
| 273 |
+
else:
|
| 274 |
+
m = re_out.search(key)
|
| 275 |
+
if m:
|
| 276 |
+
out_idx = int(m.groups()[0])
|
| 277 |
+
weight_index = NUM_INPUT_BLOCKS + NUM_MID_BLOCK + out_idx
|
| 278 |
+
|
| 279 |
+
if weight_index >= NUM_TOTAL_BLOCKS:
|
| 280 |
+
print(f"ERROR: illegal block index: {key}")
|
| 281 |
+
return f"ERROR: illegal block index: {key}",*non4
|
| 282 |
+
|
| 283 |
+
if weight_index >= 0 and useblocks:
|
| 284 |
+
current_alpha = weights_a[weight_index]
|
| 285 |
+
if usebeta: current_beta = weights_b[weight_index]
|
| 286 |
+
else:
|
| 287 |
+
count_target_of_basealpha = count_target_of_basealpha + 1
|
| 288 |
+
|
| 289 |
+
if len(deep) > 0:
|
| 290 |
+
skey = key + blockid[weight_index+1]
|
| 291 |
+
for d in deep:
|
| 292 |
+
if d.count(":") != 2 :continue
|
| 293 |
+
dbs,dws,dr = d.split(":")[0],d.split(":")[1],d.split(":")[2]
|
| 294 |
+
dbs,dws = dbs.split(" "), dws.split(" ")
|
| 295 |
+
dbn,dbs = (True,dbs[1:]) if dbs[0] == "NOT" else (False,dbs)
|
| 296 |
+
dwn,dws = (True,dws[1:]) if dws[0] == "NOT" else (False,dws)
|
| 297 |
+
flag = dbn
|
| 298 |
+
for db in dbs:
|
| 299 |
+
if db in skey:
|
| 300 |
+
flag = not dbn
|
| 301 |
+
if flag:flag = dwn
|
| 302 |
+
else:continue
|
| 303 |
+
for dw in dws:
|
| 304 |
+
if dw in skey:
|
| 305 |
+
flag = not dwn
|
| 306 |
+
if flag:
|
| 307 |
+
dr = float(dr)
|
| 308 |
+
if deepprint :print(dbs,dws,key,dr)
|
| 309 |
+
current_alpha = dr
|
| 310 |
+
|
| 311 |
+
if calcmode == "normal":
|
| 312 |
+
if MODES[1] in mode:#Add
|
| 313 |
+
caster(f"model A[{key}] + {current_alpha} + * (model B - model C)[{key}]",hear)
|
| 314 |
+
theta_0[key] = theta_0[key] + current_alpha * theta_1[key]
|
| 315 |
+
elif MODES[2] in mode:#Triple
|
| 316 |
+
caster(f"model A[{key}] + {1-current_alpha-current_beta} + model B[{key}]*{current_alpha} + model C[{key}]*{current_beta}",hear)
|
| 317 |
+
theta_0[key] = (1 - current_alpha-current_beta) * theta_0[key] + current_alpha * theta_1[key]+current_beta * theta_2[key]
|
| 318 |
+
elif MODES[3] in mode:#Twice
|
| 319 |
+
caster(f"model A[{key}] + {1-current_alpha} + * model B[{key}]*{alpha}",hear)
|
| 320 |
+
caster(f"model A+B[{key}] + {1-current_beta} + * model C[{key}]*{beta}",hear)
|
| 321 |
+
theta_0[key] = (1 - current_alpha) * theta_0[key] + current_alpha * theta_1[key]
|
| 322 |
+
theta_0[key] = (1 - current_beta) * theta_0[key] + current_beta * theta_2[key]
|
| 323 |
+
else:#Weight
|
| 324 |
+
if current_alpha == 1:
|
| 325 |
+
caster(f"alpha = 0,model A[{key}=model B[{key}",hear)
|
| 326 |
+
theta_0[key] = theta_1[key]
|
| 327 |
+
elif current_alpha !=0:
|
| 328 |
+
caster(f"model A[{key}] + {1-current_alpha} + * (model B)[{key}]*{alpha}",hear)
|
| 329 |
+
theta_0[key] = (1 - current_alpha) * theta_0[key] + current_alpha * theta_1[key]
|
| 330 |
+
|
| 331 |
+
elif calcmode == "cosineA": #favors modelA's structure with details from B
|
| 332 |
+
# skip VAE model parameters to get better results
|
| 333 |
+
if "first_stage_model" in key: continue
|
| 334 |
+
if "model" in key and key in theta_0:
|
| 335 |
+
# Normalize the vectors before merging
|
| 336 |
+
theta_0_norm = nn.functional.normalize(theta_0[key].to(torch.float32), p=2, dim=0)
|
| 337 |
+
theta_1_norm = nn.functional.normalize(theta_1[key].to(torch.float32), p=2, dim=0)
|
| 338 |
+
simab = sim(theta_0_norm, theta_1_norm)
|
| 339 |
+
dot_product = torch.dot(theta_0_norm.view(-1), theta_1_norm.view(-1))
|
| 340 |
+
magnitude_similarity = dot_product / (torch.norm(theta_0_norm) * torch.norm(theta_1_norm))
|
| 341 |
+
combined_similarity = (simab + magnitude_similarity) / 2.0
|
| 342 |
+
k = (combined_similarity - sims.min()) / (sims.max() - sims.min())
|
| 343 |
+
k = k - current_alpha
|
| 344 |
+
k = k.clip(min=.0,max=1.)
|
| 345 |
+
caster(f"model A[{key}] + {1-k} + * (model B)[{key}]*{k}",hear)
|
| 346 |
+
theta_0[key] = theta_1[key] * (1 - k) + theta_0[key] * k
|
| 347 |
+
|
| 348 |
+
elif calcmode == "cosineB": #favors modelB's structure with details from A
|
| 349 |
+
# skip VAE model parameters to get better results
|
| 350 |
+
if "first_stage_model" in key: continue
|
| 351 |
+
if "model" in key and key in theta_0:
|
| 352 |
+
simab = sim(theta_0[key].to(torch.float32), theta_1[key].to(torch.float32))
|
| 353 |
+
dot_product = torch.dot(theta_0[key].view(-1).to(torch.float32), theta_1[key].view(-1).to(torch.float32))
|
| 354 |
+
magnitude_similarity = dot_product / (torch.norm(theta_0[key].to(torch.float32)) * torch.norm(theta_1[key].to(torch.float32)))
|
| 355 |
+
combined_similarity = (simab + magnitude_similarity) / 2.0
|
| 356 |
+
k = (combined_similarity - sims.min()) / (sims.max() - sims.min())
|
| 357 |
+
k = k - current_alpha
|
| 358 |
+
k = k.clip(min=.0,max=1.)
|
| 359 |
+
caster(f"model A[{key}] + {1-k} + * (model B)[{key}]*{k}",hear)
|
| 360 |
+
theta_0[key] = theta_1[key] * (1 - k) + theta_0[key] * k
|
| 361 |
+
|
| 362 |
+
elif calcmode == "smoothAdd":
|
| 363 |
+
caster(f"model A[{key}] + {current_alpha} + * (model B - model C)[{key}]", hear)
|
| 364 |
+
# Apply median filter to the weight differences
|
| 365 |
+
filtered_diff = scipy.ndimage.median_filter(theta_1[key].to(torch.float32).cpu().numpy(), size=3)
|
| 366 |
+
# Apply Gaussian filter to the filtered differences
|
| 367 |
+
filtered_diff = scipy.ndimage.gaussian_filter(filtered_diff, sigma=1)
|
| 368 |
+
theta_1[key] = torch.tensor(filtered_diff)
|
| 369 |
+
# Add the filtered differences to the original weights
|
| 370 |
+
theta_0[key] = theta_0[key] + current_alpha * theta_1[key]
|
| 371 |
+
|
| 372 |
+
elif calcmode == "tensor":
|
| 373 |
+
dim = theta_0[key].dim()
|
| 374 |
+
if dim == 0 : continue
|
| 375 |
+
if current_alpha+current_beta <= 1 :
|
| 376 |
+
talphas = int(theta_0[key].shape[0]*(current_beta))
|
| 377 |
+
talphae = int(theta_0[key].shape[0]*(current_alpha+current_beta))
|
| 378 |
+
if dim == 1:
|
| 379 |
+
theta_0[key][talphas:talphae] = theta_1[key][talphas:talphae].clone()
|
| 380 |
+
|
| 381 |
+
elif dim == 2:
|
| 382 |
+
theta_0[key][talphas:talphae,:] = theta_1[key][talphas:talphae,:].clone()
|
| 383 |
+
|
| 384 |
+
elif dim == 3:
|
| 385 |
+
theta_0[key][talphas:talphae,:,:] = theta_1[key][talphas:talphae,:,:].clone()
|
| 386 |
+
|
| 387 |
+
elif dim == 4:
|
| 388 |
+
theta_0[key][talphas:talphae,:,:,:] = theta_1[key][talphas:talphae,:,:,:].clone()
|
| 389 |
+
|
| 390 |
+
else:
|
| 391 |
+
talphas = int(theta_0[key].shape[0]*(current_alpha+current_beta-1))
|
| 392 |
+
talphae = int(theta_0[key].shape[0]*(current_beta))
|
| 393 |
+
theta_t = theta_1[key].clone()
|
| 394 |
+
if dim == 1:
|
| 395 |
+
theta_t[talphas:talphae] = theta_0[key][talphas:talphae].clone()
|
| 396 |
+
|
| 397 |
+
elif dim == 2:
|
| 398 |
+
theta_t[talphas:talphae,:] = theta_0[key][talphas:talphae,:].clone()
|
| 399 |
+
|
| 400 |
+
elif dim == 3:
|
| 401 |
+
theta_t[talphas:talphae,:,:] = theta_0[key][talphas:talphae,:,:].clone()
|
| 402 |
+
|
| 403 |
+
elif dim == 4:
|
| 404 |
+
theta_t[talphas:talphae,:,:,:] = theta_0[key][talphas:talphae,:,:,:].clone()
|
| 405 |
+
theta_0[key] = theta_t
|
| 406 |
+
|
| 407 |
+
currentmodel = makemodelname(weights_a,weights_b,model_a, model_b,model_c, base_alpha,base_beta,useblocks,mode)
|
| 408 |
+
|
| 409 |
+
for key in tqdm(theta_1.keys(), desc="Stage 2/2"):
|
| 410 |
+
if key in chckpoint_dict_skip_on_merge:
|
| 411 |
+
continue
|
| 412 |
+
if "model" in key and key not in theta_0:
|
| 413 |
+
theta_0.update({key:theta_1[key]})
|
| 414 |
+
|
| 415 |
+
del theta_1
|
| 416 |
+
|
| 417 |
+
modelid = rwmergelog(currentmodel,mergedmodel)
|
| 418 |
+
|
| 419 |
+
caster(mergedmodel,False)
|
| 420 |
+
|
| 421 |
+
if save_metadata:
|
| 422 |
+
merge_recipe = {
|
| 423 |
+
"type": "sd-webui-supermerger",
|
| 424 |
+
"weights_alpha": weights_a if useblocks else None,
|
| 425 |
+
"weights_beta": weights_b if useblocks else None,
|
| 426 |
+
"weights_alpha_orig": weights_a_orig if useblocks else None,
|
| 427 |
+
"weights_beta_orig": weights_b_orig if useblocks else None,
|
| 428 |
+
"model_a": longhashfromname(model_a),
|
| 429 |
+
"model_b": longhashfromname(model_b),
|
| 430 |
+
"model_c": longhashfromname(model_c),
|
| 431 |
+
"base_alpha": base_alpha,
|
| 432 |
+
"base_beta": base_beta,
|
| 433 |
+
"mode": mode,
|
| 434 |
+
"mbw": useblocks,
|
| 435 |
+
"elemental_merge": deep,
|
| 436 |
+
"calcmode" : calcmode
|
| 437 |
+
}
|
| 438 |
+
metadata["sd_merge_recipe"] = json.dumps(merge_recipe)
|
| 439 |
+
metadata["sd_merge_models"] = {}
|
| 440 |
+
|
| 441 |
+
def add_model_metadata(checkpoint_name):
|
| 442 |
+
checkpoint_info = sd_models.get_closet_checkpoint_match(checkpoint_name)
|
| 443 |
+
checkpoint_info.calculate_shorthash()
|
| 444 |
+
metadata["sd_merge_models"][checkpoint_info.sha256] = {
|
| 445 |
+
"name": checkpoint_name,
|
| 446 |
+
"legacy_hash": checkpoint_info.hash
|
| 447 |
+
}
|
| 448 |
+
|
| 449 |
+
#metadata["sd_merge_models"].update(checkpoint_info.metadata.get("sd_merge_models", {}))
|
| 450 |
+
|
| 451 |
+
if model_a:
|
| 452 |
+
add_model_metadata(model_a)
|
| 453 |
+
if model_b:
|
| 454 |
+
add_model_metadata(model_b)
|
| 455 |
+
if model_c:
|
| 456 |
+
add_model_metadata(model_c)
|
| 457 |
+
|
| 458 |
+
metadata["sd_merge_models"] = json.dumps(metadata["sd_merge_models"])
|
| 459 |
+
|
| 460 |
+
return "",currentmodel,modelid,theta_0,metadata
|
| 461 |
+
def forkforker(filename):
|
| 462 |
+
try:
|
| 463 |
+
return sd_models.read_state_dict(filename,"cuda")
|
| 464 |
+
except:
|
| 465 |
+
return sd_models.read_state_dict(filename)
|
| 466 |
+
|
| 467 |
+
def load_model_weights_m(model,model_a,model_b,save):
|
| 468 |
+
checkpoint_info = sd_models.get_closet_checkpoint_match(model)
|
| 469 |
+
sd_model_name = checkpoint_info.model_name
|
| 470 |
+
|
| 471 |
+
cachenum = shared.opts.sd_checkpoint_cache
|
| 472 |
+
|
| 473 |
+
if save:
|
| 474 |
+
if model_a:
|
| 475 |
+
load_model(checkpoint_info)
|
| 476 |
+
print(f"Loading weights [{sd_model_name}] from file")
|
| 477 |
+
return forkforker(checkpoint_info.filename)
|
| 478 |
+
|
| 479 |
+
if checkpoint_info in checkpoints_loaded:
|
| 480 |
+
print(f"Loading weights [{sd_model_name}] from cache")
|
| 481 |
+
return checkpoints_loaded[checkpoint_info]
|
| 482 |
+
elif cachenum>0 and model_a:
|
| 483 |
+
load_model(checkpoint_info)
|
| 484 |
+
print(f"Loading weights [{sd_model_name}] from cache")
|
| 485 |
+
return checkpoints_loaded[checkpoint_info]
|
| 486 |
+
elif cachenum>1 and model_b:
|
| 487 |
+
load_model(checkpoint_info)
|
| 488 |
+
print(f"Loading weights [{sd_model_name}] from cache")
|
| 489 |
+
return checkpoints_loaded[checkpoint_info]
|
| 490 |
+
elif cachenum>2:
|
| 491 |
+
load_model(checkpoint_info)
|
| 492 |
+
print(f"Loading weights [{sd_model_name}] from cache")
|
| 493 |
+
return checkpoints_loaded[checkpoint_info]
|
| 494 |
+
else:
|
| 495 |
+
if model_a:
|
| 496 |
+
load_model(checkpoint_info)
|
| 497 |
+
print(f"Loading weights [{sd_model_name}] from file")
|
| 498 |
+
return forkforker(checkpoint_info.filename)
|
| 499 |
+
|
| 500 |
+
def makemodelname(weights_a,weights_b,model_a, model_b,model_c, alpha,beta,useblocks,mode):
|
| 501 |
+
model_a=filenamecutter(model_a)
|
| 502 |
+
model_b=filenamecutter(model_b)
|
| 503 |
+
model_c=filenamecutter(model_c)
|
| 504 |
+
|
| 505 |
+
if type(alpha) == str:alpha = float(alpha)
|
| 506 |
+
if type(beta)== str:beta = float(beta)
|
| 507 |
+
|
| 508 |
+
if useblocks:
|
| 509 |
+
if MODES[1] in mode:#add
|
| 510 |
+
currentmodel =f"{model_a} + ({model_b} - {model_c}) x alpha ({str(round(alpha,3))},{','.join(str(s) for s in weights_a)}"
|
| 511 |
+
elif MODES[2] in mode:#triple
|
| 512 |
+
currentmodel =f"{model_a} x (1-alpha-beta) + {model_b} x alpha + {model_c} x beta (alpha = {str(round(alpha,3))},{','.join(str(s) for s in weights_a)},beta = {beta},{','.join(str(s) for s in weights_b)})"
|
| 513 |
+
elif MODES[3] in mode:#twice
|
| 514 |
+
currentmodel =f"({model_a} x (1-alpha) + {model_b} x alpha)x(1-beta)+ {model_c} x beta ({str(round(alpha,3))},{','.join(str(s) for s in weights_a)})_({str(round(beta,3))},{','.join(str(s) for s in weights_b)})"
|
| 515 |
+
else:
|
| 516 |
+
currentmodel =f"{model_a} x (1-alpha) + {model_b} x alpha ({str(round(alpha,3))},{','.join(str(s) for s in weights_a)})"
|
| 517 |
+
else:
|
| 518 |
+
if MODES[1] in mode:#add
|
| 519 |
+
currentmodel =f"{model_a} + ({model_b} - {model_c}) x {str(round(alpha,3))}"
|
| 520 |
+
elif MODES[2] in mode:#triple
|
| 521 |
+
currentmodel =f"{model_a} x {str(round(1-alpha-beta,3))} + {model_b} x {str(round(alpha,3))} + {model_c} x {str(round(beta,3))}"
|
| 522 |
+
elif MODES[3] in mode:#twice
|
| 523 |
+
currentmodel =f"({model_a} x {str(round(1-alpha,3))} +{model_b} x {str(round(alpha,3))}) x {str(round(1-beta,3))} + {model_c} x {str(round(beta,3))}"
|
| 524 |
+
else:
|
| 525 |
+
currentmodel =f"{model_a} x {str(round(1-alpha,3))} + {model_b} x {str(round(alpha,3))}"
|
| 526 |
+
return currentmodel
|
| 527 |
+
|
| 528 |
+
path_root = scripts.basedir()
|
| 529 |
+
|
| 530 |
+
def rwmergelog(mergedname = "",settings= [],id = 0):
|
| 531 |
+
setting = settings.copy()
|
| 532 |
+
filepath = os.path.join(path_root, "mergehistory.csv")
|
| 533 |
+
is_file = os.path.isfile(filepath)
|
| 534 |
+
if not is_file:
|
| 535 |
+
with open(filepath, 'a') as f:
|
| 536 |
+
#msettings=[0 weights_a,1 weights_b,2 model_a,3 model_b,4 model_c,5 base_alpha,6 base_beta,7 mode,8 useblocks,9 custom_name,10 save_sets,11 id_sets, 12 deep 13 calcmode]
|
| 537 |
+
f.writelines('"ID","time","name","weights alpha","weights beta","model A","model B","model C","alpha","beta","mode","use MBW","plus lora","custum name","save setting","use ID"\n')
|
| 538 |
+
with open(filepath, 'r+') as f:
|
| 539 |
+
reader = csv.reader(f)
|
| 540 |
+
mlist = [raw for raw in reader]
|
| 541 |
+
if mergedname != "":
|
| 542 |
+
mergeid = len(mlist)
|
| 543 |
+
setting.insert(0,mergedname)
|
| 544 |
+
for i,x in enumerate(setting):
|
| 545 |
+
if "," in str(x):setting[i] = f'"{str(setting[i])}"'
|
| 546 |
+
text = ",".join(map(str, setting))
|
| 547 |
+
text=str(mergeid)+","+datetime.datetime.now().strftime('%Y.%m.%d %H.%M.%S.%f')[:-7]+"," + text + "\n"
|
| 548 |
+
f.writelines(text)
|
| 549 |
+
return mergeid
|
| 550 |
+
try:
|
| 551 |
+
out = mlist[int(id)]
|
| 552 |
+
except:
|
| 553 |
+
out = "ERROR: OUT of ID index"
|
| 554 |
+
return out
|
| 555 |
+
|
| 556 |
+
def draw_origin(grid, text,width,height,width_one):
|
| 557 |
+
grid_d= Image.new("RGB", (grid.width,grid.height), "white")
|
| 558 |
+
grid_d.paste(grid,(0,0))
|
| 559 |
+
def get_font(fontsize):
|
| 560 |
+
try:
|
| 561 |
+
from fonts.ttf import Roboto
|
| 562 |
+
try:
|
| 563 |
+
return ImageFont.truetype(opts.font or Roboto, fontsize)
|
| 564 |
+
except Exception:
|
| 565 |
+
return ImageFont.truetype(Roboto, fontsize)
|
| 566 |
+
except Exception:
|
| 567 |
+
try:
|
| 568 |
+
return ImageFont.truetype(shared.opts.font or 'javascript/roboto.ttf', fontsize)
|
| 569 |
+
except Exception:
|
| 570 |
+
return ImageFont.truetype('javascript/roboto.ttf', fontsize)
|
| 571 |
+
|
| 572 |
+
d= ImageDraw.Draw(grid_d)
|
| 573 |
+
color_active = (0, 0, 0)
|
| 574 |
+
fontsize = (width+height)//25
|
| 575 |
+
fnt = get_font(fontsize)
|
| 576 |
+
|
| 577 |
+
if grid.width != width_one:
|
| 578 |
+
while d.multiline_textsize(text, font=fnt)[0] > width_one*0.75 and fontsize > 0:
|
| 579 |
+
fontsize -=1
|
| 580 |
+
fnt = get_font(fontsize)
|
| 581 |
+
d.multiline_text((0,0), text, font=fnt, fill=color_active,align="center")
|
| 582 |
+
return grid_d
|
| 583 |
+
|
| 584 |
+
def wpreseter(w,presets):
|
| 585 |
+
if "," not in w and w != "":
|
| 586 |
+
presets=presets.splitlines()
|
| 587 |
+
wdict={}
|
| 588 |
+
for l in presets:
|
| 589 |
+
if ":" in l :
|
| 590 |
+
key = l.split(":",1)[0]
|
| 591 |
+
wdict[key.strip()]=l.split(":",1)[1]
|
| 592 |
+
if "\t" in l:
|
| 593 |
+
key = l.split("\t",1)[0]
|
| 594 |
+
wdict[key.strip()]=l.split("\t",1)[1]
|
| 595 |
+
if w.strip() in wdict:
|
| 596 |
+
name = w
|
| 597 |
+
w = wdict[w.strip()]
|
| 598 |
+
print(f"weights {name} imported from presets : {w}")
|
| 599 |
+
return w
|
| 600 |
+
|
| 601 |
+
def fullpathfromname(name):
|
| 602 |
+
if hash == "" or hash ==[]: return ""
|
| 603 |
+
checkpoint_info = sd_models.get_closet_checkpoint_match(name)
|
| 604 |
+
return checkpoint_info.filename
|
| 605 |
+
|
| 606 |
+
def namefromhash(hash):
|
| 607 |
+
if hash == "" or hash ==[]: return ""
|
| 608 |
+
checkpoint_info = sd_models.get_closet_checkpoint_match(hash)
|
| 609 |
+
return checkpoint_info.model_name
|
| 610 |
+
|
| 611 |
+
def hashfromname(name):
|
| 612 |
+
from modules import sd_models
|
| 613 |
+
if name == "" or name ==[]: return ""
|
| 614 |
+
checkpoint_info = sd_models.get_closet_checkpoint_match(name)
|
| 615 |
+
if checkpoint_info.shorthash is not None:
|
| 616 |
+
return checkpoint_info.shorthash
|
| 617 |
+
return checkpoint_info.calculate_shorthash()
|
| 618 |
+
|
| 619 |
+
def longhashfromname(name):
|
| 620 |
+
from modules import sd_models
|
| 621 |
+
if name == "" or name ==[]: return ""
|
| 622 |
+
checkpoint_info = sd_models.get_closet_checkpoint_match(name)
|
| 623 |
+
if checkpoint_info.sha256 is not None:
|
| 624 |
+
return checkpoint_info.sha256
|
| 625 |
+
checkpoint_info.calculate_shorthash()
|
| 626 |
+
return checkpoint_info.sha256
|
| 627 |
+
|
| 628 |
+
def simggen(prompt, nprompt, steps, sampler, cfg, seed, w, h,genoptions,hrupscaler,hr2ndsteps,denoise_str,hr_scale,batch_size,mergeinfo="",id_sets=[],modelid = "no id"):
|
| 629 |
+
shared.state.begin()
|
| 630 |
+
p = processing.StableDiffusionProcessingTxt2Img(
|
| 631 |
+
sd_model=shared.sd_model,
|
| 632 |
+
do_not_save_grid=True,
|
| 633 |
+
do_not_save_samples=True,
|
| 634 |
+
do_not_reload_embeddings=True,
|
| 635 |
+
)
|
| 636 |
+
p.batch_size = int(batch_size)
|
| 637 |
+
p.prompt = prompt
|
| 638 |
+
p.negative_prompt = nprompt
|
| 639 |
+
p.steps = steps
|
| 640 |
+
p.sampler_name = sd_samplers.samplers[sampler].name
|
| 641 |
+
p.cfg_scale = cfg
|
| 642 |
+
p.seed = seed
|
| 643 |
+
p.width = w
|
| 644 |
+
p.height = h
|
| 645 |
+
p.seed_resize_from_w=0
|
| 646 |
+
p.seed_resize_from_h=0
|
| 647 |
+
p.denoising_strength=None
|
| 648 |
+
|
| 649 |
+
#"Restore faces", "Tiling", "Hires. fix"
|
| 650 |
+
|
| 651 |
+
if "Hires. fix" in genoptions:
|
| 652 |
+
p.enable_hr = True
|
| 653 |
+
p.denoising_strength = denoise_str
|
| 654 |
+
p.hr_upscaler = hrupscaler
|
| 655 |
+
p.hr_second_pass_steps = hr2ndsteps
|
| 656 |
+
p.hr_scale = hr_scale
|
| 657 |
+
|
| 658 |
+
if "Tiling" in genoptions:
|
| 659 |
+
p.tiling = True
|
| 660 |
+
|
| 661 |
+
if "Restore faces" in genoptions:
|
| 662 |
+
p.restore_faces = True
|
| 663 |
+
|
| 664 |
+
if type(p.prompt) == list:
|
| 665 |
+
p.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, p.styles) for x in p.prompt]
|
| 666 |
+
else:
|
| 667 |
+
p.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(p.prompt, p.styles)]
|
| 668 |
+
|
| 669 |
+
if type(p.negative_prompt) == list:
|
| 670 |
+
p.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, p.styles) for x in p.negative_prompt]
|
| 671 |
+
else:
|
| 672 |
+
p.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(p.negative_prompt, p.styles)]
|
| 673 |
+
|
| 674 |
+
processed:Processed = processing.process_images(p)
|
| 675 |
+
if "image" in id_sets:
|
| 676 |
+
for i, image in enumerate(processed.images):
|
| 677 |
+
processed.images[i] = draw_origin(image, str(modelid),w,h,w)
|
| 678 |
+
|
| 679 |
+
if "PNG info" in id_sets:mergeinfo = mergeinfo + " ID " + str(modelid)
|
| 680 |
+
|
| 681 |
+
infotext = create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds)
|
| 682 |
+
if infotext.count("Steps: ")>1:
|
| 683 |
+
infotext = infotext[:infotext.rindex("Steps")]
|
| 684 |
+
|
| 685 |
+
infotexts = infotext.split(",")
|
| 686 |
+
for i,x in enumerate(infotexts):
|
| 687 |
+
if "Model:"in x:
|
| 688 |
+
infotexts[i] = " Model: "+mergeinfo.replace(","," ")
|
| 689 |
+
infotext= ",".join(infotexts)
|
| 690 |
+
|
| 691 |
+
for i, image in enumerate(processed.images):
|
| 692 |
+
images.save_image(image, opts.outdir_txt2img_samples, "",p.seed, p.prompt,shared.opts.samples_format, p=p,info=infotext)
|
| 693 |
+
|
| 694 |
+
if batch_size > 1:
|
| 695 |
+
grid = images.image_grid(processed.images, p.batch_size)
|
| 696 |
+
processed.images.insert(0, grid)
|
| 697 |
+
images.save_image(grid, opts.outdir_txt2img_grids, "grid", p.seed, p.prompt, opts.grid_format, info=infotext, short_filename=not opts.grid_extended_filename, p=p, grid=True)
|
| 698 |
+
shared.state.end()
|
| 699 |
+
return processed.images,infotext,plaintext_to_html(processed.info), plaintext_to_html(processed.comments),p
|
microsoftexcel-supermerger/scripts/mergers/model_util.py
ADDED
|
@@ -0,0 +1,928 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
from transformers import CLIPTextModel, CLIPTextConfig
|
| 4 |
+
from safetensors.torch import load_file
|
| 5 |
+
import safetensors.torch
|
| 6 |
+
from modules.sd_models import read_state_dict
|
| 7 |
+
|
| 8 |
+
# DiffUsers版StableDiffusionのモデルパラメータ
|
| 9 |
+
NUM_TRAIN_TIMESTEPS = 1000
|
| 10 |
+
BETA_START = 0.00085
|
| 11 |
+
BETA_END = 0.0120
|
| 12 |
+
|
| 13 |
+
UNET_PARAMS_MODEL_CHANNELS = 320
|
| 14 |
+
UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
|
| 15 |
+
UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
|
| 16 |
+
UNET_PARAMS_IMAGE_SIZE = 64 # fixed from old invalid value `32`
|
| 17 |
+
UNET_PARAMS_IN_CHANNELS = 4
|
| 18 |
+
UNET_PARAMS_OUT_CHANNELS = 4
|
| 19 |
+
UNET_PARAMS_NUM_RES_BLOCKS = 2
|
| 20 |
+
UNET_PARAMS_CONTEXT_DIM = 768
|
| 21 |
+
UNET_PARAMS_NUM_HEADS = 8
|
| 22 |
+
|
| 23 |
+
VAE_PARAMS_Z_CHANNELS = 4
|
| 24 |
+
VAE_PARAMS_RESOLUTION = 256
|
| 25 |
+
VAE_PARAMS_IN_CHANNELS = 3
|
| 26 |
+
VAE_PARAMS_OUT_CH = 3
|
| 27 |
+
VAE_PARAMS_CH = 128
|
| 28 |
+
VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
|
| 29 |
+
VAE_PARAMS_NUM_RES_BLOCKS = 2
|
| 30 |
+
|
| 31 |
+
# V2
|
| 32 |
+
V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
|
| 33 |
+
V2_UNET_PARAMS_CONTEXT_DIM = 1024
|
| 34 |
+
|
| 35 |
+
# Diffusersの設定を読み込むための参照モデル
|
| 36 |
+
DIFFUSERS_REF_MODEL_ID_V1 = "runwayml/stable-diffusion-v1-5"
|
| 37 |
+
DIFFUSERS_REF_MODEL_ID_V2 = "stabilityai/stable-diffusion-2-1"
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# region StableDiffusion->Diffusersの変換コード
|
| 41 |
+
# convert_original_stable_diffusion_to_diffusers をコピーして修正している(ASL 2.0)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def shave_segments(path, n_shave_prefix_segments=1):
|
| 45 |
+
"""
|
| 46 |
+
Removes segments. Positive values shave the first segments, negative shave the last segments.
|
| 47 |
+
"""
|
| 48 |
+
if n_shave_prefix_segments >= 0:
|
| 49 |
+
return ".".join(path.split(".")[n_shave_prefix_segments:])
|
| 50 |
+
else:
|
| 51 |
+
return ".".join(path.split(".")[:n_shave_prefix_segments])
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
|
| 55 |
+
"""
|
| 56 |
+
Updates paths inside resnets to the new naming scheme (local renaming)
|
| 57 |
+
"""
|
| 58 |
+
mapping = []
|
| 59 |
+
for old_item in old_list:
|
| 60 |
+
new_item = old_item.replace("in_layers.0", "norm1")
|
| 61 |
+
new_item = new_item.replace("in_layers.2", "conv1")
|
| 62 |
+
|
| 63 |
+
new_item = new_item.replace("out_layers.0", "norm2")
|
| 64 |
+
new_item = new_item.replace("out_layers.3", "conv2")
|
| 65 |
+
|
| 66 |
+
new_item = new_item.replace("emb_layers.1", "time_emb_proj")
|
| 67 |
+
new_item = new_item.replace("skip_connection", "conv_shortcut")
|
| 68 |
+
|
| 69 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
| 70 |
+
|
| 71 |
+
mapping.append({"old": old_item, "new": new_item})
|
| 72 |
+
|
| 73 |
+
return mapping
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
|
| 77 |
+
"""
|
| 78 |
+
Updates paths inside resnets to the new naming scheme (local renaming)
|
| 79 |
+
"""
|
| 80 |
+
mapping = []
|
| 81 |
+
for old_item in old_list:
|
| 82 |
+
new_item = old_item
|
| 83 |
+
|
| 84 |
+
new_item = new_item.replace("nin_shortcut", "conv_shortcut")
|
| 85 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
| 86 |
+
|
| 87 |
+
mapping.append({"old": old_item, "new": new_item})
|
| 88 |
+
|
| 89 |
+
return mapping
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def renew_attention_paths(old_list, n_shave_prefix_segments=0):
|
| 93 |
+
"""
|
| 94 |
+
Updates paths inside attentions to the new naming scheme (local renaming)
|
| 95 |
+
"""
|
| 96 |
+
mapping = []
|
| 97 |
+
for old_item in old_list:
|
| 98 |
+
new_item = old_item
|
| 99 |
+
|
| 100 |
+
# new_item = new_item.replace('norm.weight', 'group_norm.weight')
|
| 101 |
+
# new_item = new_item.replace('norm.bias', 'group_norm.bias')
|
| 102 |
+
|
| 103 |
+
# new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
|
| 104 |
+
# new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
|
| 105 |
+
|
| 106 |
+
# new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
| 107 |
+
|
| 108 |
+
mapping.append({"old": old_item, "new": new_item})
|
| 109 |
+
|
| 110 |
+
return mapping
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
|
| 114 |
+
"""
|
| 115 |
+
Updates paths inside attentions to the new naming scheme (local renaming)
|
| 116 |
+
"""
|
| 117 |
+
mapping = []
|
| 118 |
+
for old_item in old_list:
|
| 119 |
+
new_item = old_item
|
| 120 |
+
|
| 121 |
+
new_item = new_item.replace("norm.weight", "group_norm.weight")
|
| 122 |
+
new_item = new_item.replace("norm.bias", "group_norm.bias")
|
| 123 |
+
|
| 124 |
+
new_item = new_item.replace("q.weight", "query.weight")
|
| 125 |
+
new_item = new_item.replace("q.bias", "query.bias")
|
| 126 |
+
|
| 127 |
+
new_item = new_item.replace("k.weight", "key.weight")
|
| 128 |
+
new_item = new_item.replace("k.bias", "key.bias")
|
| 129 |
+
|
| 130 |
+
new_item = new_item.replace("v.weight", "value.weight")
|
| 131 |
+
new_item = new_item.replace("v.bias", "value.bias")
|
| 132 |
+
|
| 133 |
+
new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
|
| 134 |
+
new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
|
| 135 |
+
|
| 136 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
| 137 |
+
|
| 138 |
+
mapping.append({"old": old_item, "new": new_item})
|
| 139 |
+
|
| 140 |
+
return mapping
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def assign_to_checkpoint(
|
| 144 |
+
paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
|
| 145 |
+
):
|
| 146 |
+
"""
|
| 147 |
+
This does the final conversion step: take locally converted weights and apply a global renaming
|
| 148 |
+
to them. It splits attention layers, and takes into account additional replacements
|
| 149 |
+
that may arise.
|
| 150 |
+
|
| 151 |
+
Assigns the weights to the new checkpoint.
|
| 152 |
+
"""
|
| 153 |
+
assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
|
| 154 |
+
|
| 155 |
+
# Splits the attention layers into three variables.
|
| 156 |
+
if attention_paths_to_split is not None:
|
| 157 |
+
for path, path_map in attention_paths_to_split.items():
|
| 158 |
+
old_tensor = old_checkpoint[path]
|
| 159 |
+
channels = old_tensor.shape[0] // 3
|
| 160 |
+
|
| 161 |
+
target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
|
| 162 |
+
|
| 163 |
+
num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
|
| 164 |
+
|
| 165 |
+
old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
|
| 166 |
+
query, key, value = old_tensor.split(channels // num_heads, dim=1)
|
| 167 |
+
|
| 168 |
+
checkpoint[path_map["query"]] = query.reshape(target_shape)
|
| 169 |
+
checkpoint[path_map["key"]] = key.reshape(target_shape)
|
| 170 |
+
checkpoint[path_map["value"]] = value.reshape(target_shape)
|
| 171 |
+
|
| 172 |
+
for path in paths:
|
| 173 |
+
new_path = path["new"]
|
| 174 |
+
|
| 175 |
+
# These have already been assigned
|
| 176 |
+
if attention_paths_to_split is not None and new_path in attention_paths_to_split:
|
| 177 |
+
continue
|
| 178 |
+
|
| 179 |
+
# Global renaming happens here
|
| 180 |
+
new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
|
| 181 |
+
new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
|
| 182 |
+
new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
|
| 183 |
+
|
| 184 |
+
if additional_replacements is not None:
|
| 185 |
+
for replacement in additional_replacements:
|
| 186 |
+
new_path = new_path.replace(replacement["old"], replacement["new"])
|
| 187 |
+
|
| 188 |
+
# proj_attn.weight has to be converted from conv 1D to linear
|
| 189 |
+
if "proj_attn.weight" in new_path:
|
| 190 |
+
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
|
| 191 |
+
else:
|
| 192 |
+
checkpoint[new_path] = old_checkpoint[path["old"]]
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def conv_attn_to_linear(checkpoint):
|
| 196 |
+
keys = list(checkpoint.keys())
|
| 197 |
+
attn_keys = ["query.weight", "key.weight", "value.weight"]
|
| 198 |
+
for key in keys:
|
| 199 |
+
if ".".join(key.split(".")[-2:]) in attn_keys:
|
| 200 |
+
if checkpoint[key].ndim > 2:
|
| 201 |
+
checkpoint[key] = checkpoint[key][:, :, 0, 0]
|
| 202 |
+
elif "proj_attn.weight" in key:
|
| 203 |
+
if checkpoint[key].ndim > 2:
|
| 204 |
+
checkpoint[key] = checkpoint[key][:, :, 0]
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def linear_transformer_to_conv(checkpoint):
|
| 208 |
+
keys = list(checkpoint.keys())
|
| 209 |
+
tf_keys = ["proj_in.weight", "proj_out.weight"]
|
| 210 |
+
for key in keys:
|
| 211 |
+
if ".".join(key.split(".")[-2:]) in tf_keys:
|
| 212 |
+
if checkpoint[key].ndim == 2:
|
| 213 |
+
checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2)
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def convert_ldm_unet_checkpoint(v2, checkpoint, config):
|
| 217 |
+
"""
|
| 218 |
+
Takes a state dict and a config, and returns a converted checkpoint.
|
| 219 |
+
"""
|
| 220 |
+
|
| 221 |
+
# extract state_dict for UNet
|
| 222 |
+
unet_state_dict = {}
|
| 223 |
+
unet_key = "model.diffusion_model."
|
| 224 |
+
keys = list(checkpoint.keys())
|
| 225 |
+
for key in keys:
|
| 226 |
+
if key.startswith(unet_key):
|
| 227 |
+
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
|
| 228 |
+
|
| 229 |
+
new_checkpoint = {}
|
| 230 |
+
|
| 231 |
+
new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
|
| 232 |
+
new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
|
| 233 |
+
new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
|
| 234 |
+
new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
|
| 235 |
+
|
| 236 |
+
new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
|
| 237 |
+
new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
|
| 238 |
+
|
| 239 |
+
new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
|
| 240 |
+
new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
|
| 241 |
+
new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
|
| 242 |
+
new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
|
| 243 |
+
|
| 244 |
+
# Retrieves the keys for the input blocks only
|
| 245 |
+
num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
|
| 246 |
+
input_blocks = {
|
| 247 |
+
layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key]
|
| 248 |
+
for layer_id in range(num_input_blocks)
|
| 249 |
+
}
|
| 250 |
+
|
| 251 |
+
# Retrieves the keys for the middle blocks only
|
| 252 |
+
num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
|
| 253 |
+
middle_blocks = {
|
| 254 |
+
layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key]
|
| 255 |
+
for layer_id in range(num_middle_blocks)
|
| 256 |
+
}
|
| 257 |
+
|
| 258 |
+
# Retrieves the keys for the output blocks only
|
| 259 |
+
num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
|
| 260 |
+
output_blocks = {
|
| 261 |
+
layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key]
|
| 262 |
+
for layer_id in range(num_output_blocks)
|
| 263 |
+
}
|
| 264 |
+
|
| 265 |
+
for i in range(1, num_input_blocks):
|
| 266 |
+
block_id = (i - 1) // (config["layers_per_block"] + 1)
|
| 267 |
+
layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
|
| 268 |
+
|
| 269 |
+
resnets = [
|
| 270 |
+
key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
|
| 271 |
+
]
|
| 272 |
+
attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
|
| 273 |
+
|
| 274 |
+
if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
|
| 275 |
+
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
|
| 276 |
+
f"input_blocks.{i}.0.op.weight"
|
| 277 |
+
)
|
| 278 |
+
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
|
| 279 |
+
f"input_blocks.{i}.0.op.bias"
|
| 280 |
+
)
|
| 281 |
+
|
| 282 |
+
paths = renew_resnet_paths(resnets)
|
| 283 |
+
meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
| 284 |
+
assign_to_checkpoint(
|
| 285 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
| 286 |
+
)
|
| 287 |
+
|
| 288 |
+
if len(attentions):
|
| 289 |
+
paths = renew_attention_paths(attentions)
|
| 290 |
+
meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
|
| 291 |
+
assign_to_checkpoint(
|
| 292 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
| 293 |
+
)
|
| 294 |
+
|
| 295 |
+
resnet_0 = middle_blocks[0]
|
| 296 |
+
attentions = middle_blocks[1]
|
| 297 |
+
resnet_1 = middle_blocks[2]
|
| 298 |
+
|
| 299 |
+
resnet_0_paths = renew_resnet_paths(resnet_0)
|
| 300 |
+
assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
|
| 301 |
+
|
| 302 |
+
resnet_1_paths = renew_resnet_paths(resnet_1)
|
| 303 |
+
assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
|
| 304 |
+
|
| 305 |
+
attentions_paths = renew_attention_paths(attentions)
|
| 306 |
+
meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
|
| 307 |
+
assign_to_checkpoint(
|
| 308 |
+
attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
| 309 |
+
)
|
| 310 |
+
|
| 311 |
+
for i in range(num_output_blocks):
|
| 312 |
+
block_id = i // (config["layers_per_block"] + 1)
|
| 313 |
+
layer_in_block_id = i % (config["layers_per_block"] + 1)
|
| 314 |
+
output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
|
| 315 |
+
output_block_list = {}
|
| 316 |
+
|
| 317 |
+
for layer in output_block_layers:
|
| 318 |
+
layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
|
| 319 |
+
if layer_id in output_block_list:
|
| 320 |
+
output_block_list[layer_id].append(layer_name)
|
| 321 |
+
else:
|
| 322 |
+
output_block_list[layer_id] = [layer_name]
|
| 323 |
+
|
| 324 |
+
if len(output_block_list) > 1:
|
| 325 |
+
resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
|
| 326 |
+
attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
|
| 327 |
+
|
| 328 |
+
resnet_0_paths = renew_resnet_paths(resnets)
|
| 329 |
+
paths = renew_resnet_paths(resnets)
|
| 330 |
+
|
| 331 |
+
meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
| 332 |
+
assign_to_checkpoint(
|
| 333 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
| 334 |
+
)
|
| 335 |
+
|
| 336 |
+
# オリジナル:
|
| 337 |
+
# if ["conv.weight", "conv.bias"] in output_block_list.values():
|
| 338 |
+
# index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
|
| 339 |
+
|
| 340 |
+
# biasとweightの順番に依存しないようにする:もっといいやり方がありそうだが
|
| 341 |
+
for l in output_block_list.values():
|
| 342 |
+
l.sort()
|
| 343 |
+
|
| 344 |
+
if ["conv.bias", "conv.weight"] in output_block_list.values():
|
| 345 |
+
index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
|
| 346 |
+
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
|
| 347 |
+
f"output_blocks.{i}.{index}.conv.bias"
|
| 348 |
+
]
|
| 349 |
+
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
|
| 350 |
+
f"output_blocks.{i}.{index}.conv.weight"
|
| 351 |
+
]
|
| 352 |
+
|
| 353 |
+
# Clear attentions as they have been attributed above.
|
| 354 |
+
if len(attentions) == 2:
|
| 355 |
+
attentions = []
|
| 356 |
+
|
| 357 |
+
if len(attentions):
|
| 358 |
+
paths = renew_attention_paths(attentions)
|
| 359 |
+
meta_path = {
|
| 360 |
+
"old": f"output_blocks.{i}.1",
|
| 361 |
+
"new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
|
| 362 |
+
}
|
| 363 |
+
assign_to_checkpoint(
|
| 364 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
| 365 |
+
)
|
| 366 |
+
else:
|
| 367 |
+
resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
|
| 368 |
+
for path in resnet_0_paths:
|
| 369 |
+
old_path = ".".join(["output_blocks", str(i), path["old"]])
|
| 370 |
+
new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
|
| 371 |
+
|
| 372 |
+
new_checkpoint[new_path] = unet_state_dict[old_path]
|
| 373 |
+
|
| 374 |
+
# SDのv2では1*1のconv2dがlinearに変わっているので、linear->convに変換する
|
| 375 |
+
if v2:
|
| 376 |
+
linear_transformer_to_conv(new_checkpoint)
|
| 377 |
+
|
| 378 |
+
return new_checkpoint
|
| 379 |
+
|
| 380 |
+
|
| 381 |
+
def convert_ldm_vae_checkpoint(checkpoint, config):
|
| 382 |
+
# extract state dict for VAE
|
| 383 |
+
vae_state_dict = {}
|
| 384 |
+
vae_key = "first_stage_model."
|
| 385 |
+
keys = list(checkpoint.keys())
|
| 386 |
+
for key in keys:
|
| 387 |
+
if key.startswith(vae_key):
|
| 388 |
+
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
|
| 389 |
+
# if len(vae_state_dict) == 0:
|
| 390 |
+
# # 渡されたcheckpointは.ckptから読み込んだcheckpointではなくvaeのstate_dict
|
| 391 |
+
# vae_state_dict = checkpoint
|
| 392 |
+
|
| 393 |
+
new_checkpoint = {}
|
| 394 |
+
|
| 395 |
+
new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
|
| 396 |
+
new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
|
| 397 |
+
new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
|
| 398 |
+
new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
|
| 399 |
+
new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
|
| 400 |
+
new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
|
| 401 |
+
|
| 402 |
+
new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
|
| 403 |
+
new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
|
| 404 |
+
new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
|
| 405 |
+
new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
|
| 406 |
+
new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
|
| 407 |
+
new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
|
| 408 |
+
|
| 409 |
+
new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
|
| 410 |
+
new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
|
| 411 |
+
new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
|
| 412 |
+
new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
|
| 413 |
+
|
| 414 |
+
# Retrieves the keys for the encoder down blocks only
|
| 415 |
+
num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
|
| 416 |
+
down_blocks = {
|
| 417 |
+
layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
|
| 418 |
+
}
|
| 419 |
+
|
| 420 |
+
# Retrieves the keys for the decoder up blocks only
|
| 421 |
+
num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
|
| 422 |
+
up_blocks = {
|
| 423 |
+
layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
|
| 424 |
+
}
|
| 425 |
+
|
| 426 |
+
for i in range(num_down_blocks):
|
| 427 |
+
resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
|
| 428 |
+
|
| 429 |
+
if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
|
| 430 |
+
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
|
| 431 |
+
f"encoder.down.{i}.downsample.conv.weight"
|
| 432 |
+
)
|
| 433 |
+
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
|
| 434 |
+
f"encoder.down.{i}.downsample.conv.bias"
|
| 435 |
+
)
|
| 436 |
+
|
| 437 |
+
paths = renew_vae_resnet_paths(resnets)
|
| 438 |
+
meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
|
| 439 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
| 440 |
+
|
| 441 |
+
mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
|
| 442 |
+
num_mid_res_blocks = 2
|
| 443 |
+
for i in range(1, num_mid_res_blocks + 1):
|
| 444 |
+
resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
|
| 445 |
+
|
| 446 |
+
paths = renew_vae_resnet_paths(resnets)
|
| 447 |
+
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
|
| 448 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
| 449 |
+
|
| 450 |
+
mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
|
| 451 |
+
paths = renew_vae_attention_paths(mid_attentions)
|
| 452 |
+
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
|
| 453 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
| 454 |
+
conv_attn_to_linear(new_checkpoint)
|
| 455 |
+
|
| 456 |
+
for i in range(num_up_blocks):
|
| 457 |
+
block_id = num_up_blocks - 1 - i
|
| 458 |
+
resnets = [
|
| 459 |
+
key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
|
| 460 |
+
]
|
| 461 |
+
|
| 462 |
+
if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
|
| 463 |
+
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
|
| 464 |
+
f"decoder.up.{block_id}.upsample.conv.weight"
|
| 465 |
+
]
|
| 466 |
+
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
|
| 467 |
+
f"decoder.up.{block_id}.upsample.conv.bias"
|
| 468 |
+
]
|
| 469 |
+
|
| 470 |
+
paths = renew_vae_resnet_paths(resnets)
|
| 471 |
+
meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
|
| 472 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
| 473 |
+
|
| 474 |
+
mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
|
| 475 |
+
num_mid_res_blocks = 2
|
| 476 |
+
for i in range(1, num_mid_res_blocks + 1):
|
| 477 |
+
resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
|
| 478 |
+
|
| 479 |
+
paths = renew_vae_resnet_paths(resnets)
|
| 480 |
+
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
|
| 481 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
| 482 |
+
|
| 483 |
+
mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
|
| 484 |
+
paths = renew_vae_attention_paths(mid_attentions)
|
| 485 |
+
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
|
| 486 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
| 487 |
+
conv_attn_to_linear(new_checkpoint)
|
| 488 |
+
return new_checkpoint
|
| 489 |
+
|
| 490 |
+
|
| 491 |
+
def create_unet_diffusers_config(v2):
|
| 492 |
+
"""
|
| 493 |
+
Creates a config for the diffusers based on the config of the LDM model.
|
| 494 |
+
"""
|
| 495 |
+
# unet_params = original_config.model.params.unet_config.params
|
| 496 |
+
|
| 497 |
+
block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT]
|
| 498 |
+
|
| 499 |
+
down_block_types = []
|
| 500 |
+
resolution = 1
|
| 501 |
+
for i in range(len(block_out_channels)):
|
| 502 |
+
block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D"
|
| 503 |
+
down_block_types.append(block_type)
|
| 504 |
+
if i != len(block_out_channels) - 1:
|
| 505 |
+
resolution *= 2
|
| 506 |
+
|
| 507 |
+
up_block_types = []
|
| 508 |
+
for i in range(len(block_out_channels)):
|
| 509 |
+
block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D"
|
| 510 |
+
up_block_types.append(block_type)
|
| 511 |
+
resolution //= 2
|
| 512 |
+
|
| 513 |
+
config = dict(
|
| 514 |
+
sample_size=UNET_PARAMS_IMAGE_SIZE,
|
| 515 |
+
in_channels=UNET_PARAMS_IN_CHANNELS,
|
| 516 |
+
out_channels=UNET_PARAMS_OUT_CHANNELS,
|
| 517 |
+
down_block_types=tuple(down_block_types),
|
| 518 |
+
up_block_types=tuple(up_block_types),
|
| 519 |
+
block_out_channels=tuple(block_out_channels),
|
| 520 |
+
layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
|
| 521 |
+
cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM,
|
| 522 |
+
attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM,
|
| 523 |
+
)
|
| 524 |
+
|
| 525 |
+
return config
|
| 526 |
+
|
| 527 |
+
|
| 528 |
+
def create_vae_diffusers_config():
|
| 529 |
+
"""
|
| 530 |
+
Creates a config for the diffusers based on the config of the LDM model.
|
| 531 |
+
"""
|
| 532 |
+
# vae_params = original_config.model.params.first_stage_config.params.ddconfig
|
| 533 |
+
# _ = original_config.model.params.first_stage_config.params.embed_dim
|
| 534 |
+
block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT]
|
| 535 |
+
down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
|
| 536 |
+
up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
|
| 537 |
+
|
| 538 |
+
config = dict(
|
| 539 |
+
sample_size=VAE_PARAMS_RESOLUTION,
|
| 540 |
+
in_channels=VAE_PARAMS_IN_CHANNELS,
|
| 541 |
+
out_channels=VAE_PARAMS_OUT_CH,
|
| 542 |
+
down_block_types=tuple(down_block_types),
|
| 543 |
+
up_block_types=tuple(up_block_types),
|
| 544 |
+
block_out_channels=tuple(block_out_channels),
|
| 545 |
+
latent_channels=VAE_PARAMS_Z_CHANNELS,
|
| 546 |
+
layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS,
|
| 547 |
+
)
|
| 548 |
+
return config
|
| 549 |
+
|
| 550 |
+
|
| 551 |
+
def convert_ldm_clip_checkpoint_v1(checkpoint):
|
| 552 |
+
keys = list(checkpoint.keys())
|
| 553 |
+
text_model_dict = {}
|
| 554 |
+
for key in keys:
|
| 555 |
+
if key.startswith("cond_stage_model.transformer"):
|
| 556 |
+
text_model_dict[key[len("cond_stage_model.transformer."):]] = checkpoint[key]
|
| 557 |
+
return text_model_dict
|
| 558 |
+
|
| 559 |
+
|
| 560 |
+
def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
|
| 561 |
+
# 嫌になるくらい違うぞ!
|
| 562 |
+
def convert_key(key):
|
| 563 |
+
if not key.startswith("cond_stage_model"):
|
| 564 |
+
return None
|
| 565 |
+
|
| 566 |
+
# common conversion
|
| 567 |
+
key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.")
|
| 568 |
+
key = key.replace("cond_stage_model.model.", "text_model.")
|
| 569 |
+
|
| 570 |
+
if "resblocks" in key:
|
| 571 |
+
# resblocks conversion
|
| 572 |
+
key = key.replace(".resblocks.", ".layers.")
|
| 573 |
+
if ".ln_" in key:
|
| 574 |
+
key = key.replace(".ln_", ".layer_norm")
|
| 575 |
+
elif ".mlp." in key:
|
| 576 |
+
key = key.replace(".c_fc.", ".fc1.")
|
| 577 |
+
key = key.replace(".c_proj.", ".fc2.")
|
| 578 |
+
elif '.attn.out_proj' in key:
|
| 579 |
+
key = key.replace(".attn.out_proj.", ".self_attn.out_proj.")
|
| 580 |
+
elif '.attn.in_proj' in key:
|
| 581 |
+
key = None # 特殊なので後で処理する
|
| 582 |
+
else:
|
| 583 |
+
raise ValueError(f"unexpected key in SD: {key}")
|
| 584 |
+
elif '.positional_embedding' in key:
|
| 585 |
+
key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
|
| 586 |
+
elif '.text_projection' in key:
|
| 587 |
+
key = None # 使われない???
|
| 588 |
+
elif '.logit_scale' in key:
|
| 589 |
+
key = None # 使われない???
|
| 590 |
+
elif '.token_embedding' in key:
|
| 591 |
+
key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
|
| 592 |
+
elif '.ln_final' in key:
|
| 593 |
+
key = key.replace(".ln_final", ".final_layer_norm")
|
| 594 |
+
return key
|
| 595 |
+
|
| 596 |
+
keys = list(checkpoint.keys())
|
| 597 |
+
new_sd = {}
|
| 598 |
+
for key in keys:
|
| 599 |
+
# remove resblocks 23
|
| 600 |
+
if '.resblocks.23.' in key:
|
| 601 |
+
continue
|
| 602 |
+
new_key = convert_key(key)
|
| 603 |
+
if new_key is None:
|
| 604 |
+
continue
|
| 605 |
+
new_sd[new_key] = checkpoint[key]
|
| 606 |
+
|
| 607 |
+
# attnの変換
|
| 608 |
+
for key in keys:
|
| 609 |
+
if '.resblocks.23.' in key:
|
| 610 |
+
continue
|
| 611 |
+
if '.resblocks' in key and '.attn.in_proj_' in key:
|
| 612 |
+
# 三つに分割
|
| 613 |
+
values = torch.chunk(checkpoint[key], 3)
|
| 614 |
+
|
| 615 |
+
key_suffix = ".weight" if "weight" in key else ".bias"
|
| 616 |
+
key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.")
|
| 617 |
+
key_pfx = key_pfx.replace("_weight", "")
|
| 618 |
+
key_pfx = key_pfx.replace("_bias", "")
|
| 619 |
+
key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.")
|
| 620 |
+
new_sd[key_pfx + "q_proj" + key_suffix] = values[0]
|
| 621 |
+
new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
|
| 622 |
+
new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
|
| 623 |
+
|
| 624 |
+
# rename or add position_ids
|
| 625 |
+
ANOTHER_POSITION_IDS_KEY = "text_model.encoder.text_model.embeddings.position_ids"
|
| 626 |
+
if ANOTHER_POSITION_IDS_KEY in new_sd:
|
| 627 |
+
# waifu diffusion v1.4
|
| 628 |
+
position_ids = new_sd[ANOTHER_POSITION_IDS_KEY]
|
| 629 |
+
del new_sd[ANOTHER_POSITION_IDS_KEY]
|
| 630 |
+
else:
|
| 631 |
+
position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
|
| 632 |
+
|
| 633 |
+
new_sd["text_model.embeddings.position_ids"] = position_ids
|
| 634 |
+
return new_sd
|
| 635 |
+
|
| 636 |
+
def is_safetensors(path):
|
| 637 |
+
return os.path.splitext(path)[1].lower() == '.safetensors'
|
| 638 |
+
|
| 639 |
+
def load_checkpoint_with_text_encoder_conversion(ckpt_path):
|
| 640 |
+
# text encoderの格納形式が違うモデルに対応する ('text_model'がない)
|
| 641 |
+
TEXT_ENCODER_KEY_REPLACEMENTS = [
|
| 642 |
+
('cond_stage_model.transformer.embeddings.', 'cond_stage_model.transformer.text_model.embeddings.'),
|
| 643 |
+
('cond_stage_model.transformer.encoder.', 'cond_stage_model.transformer.text_model.encoder.'),
|
| 644 |
+
('cond_stage_model.transformer.final_layer_norm.', 'cond_stage_model.transformer.text_model.final_layer_norm.')
|
| 645 |
+
]
|
| 646 |
+
|
| 647 |
+
state_dict = read_state_dict(ckpt_path)
|
| 648 |
+
|
| 649 |
+
key_reps = []
|
| 650 |
+
for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
|
| 651 |
+
for key in state_dict.keys():
|
| 652 |
+
if key.startswith(rep_from):
|
| 653 |
+
new_key = rep_to + key[len(rep_from):]
|
| 654 |
+
key_reps.append((key, new_key))
|
| 655 |
+
|
| 656 |
+
for key, new_key in key_reps:
|
| 657 |
+
state_dict[new_key] = state_dict[key]
|
| 658 |
+
del state_dict[key]
|
| 659 |
+
|
| 660 |
+
return state_dict
|
| 661 |
+
|
| 662 |
+
def to_half(sd):
|
| 663 |
+
for key in sd.keys():
|
| 664 |
+
if 'model' in key and sd[key].dtype == torch.float:
|
| 665 |
+
sd[key] = sd[key].half()
|
| 666 |
+
return sd
|
| 667 |
+
|
| 668 |
+
def savemodel(state_dict,currentmodel,fname,savesets,model_a,metadata={}):
|
| 669 |
+
from modules import sd_models,shared
|
| 670 |
+
if "fp16" in savesets:
|
| 671 |
+
state_dict = to_half(state_dict)
|
| 672 |
+
pre = "fp16"
|
| 673 |
+
else:pre = ""
|
| 674 |
+
ext = ".safetensors" if "safetensors" in savesets else ".ckpt"
|
| 675 |
+
|
| 676 |
+
checkpoint_info = sd_models.get_closet_checkpoint_match(model_a)
|
| 677 |
+
model_a_path= checkpoint_info.filename
|
| 678 |
+
modeldir = os.path.split(model_a_path)[0]
|
| 679 |
+
|
| 680 |
+
if not fname or fname == "":
|
| 681 |
+
fname = currentmodel.replace(" ","").replace(",","_").replace("(","_").replace(")","_")+pre+ext
|
| 682 |
+
if fname[0]=="_":fname = fname[1:]
|
| 683 |
+
else:
|
| 684 |
+
fname = fname if ext in fname else fname +pre+ext
|
| 685 |
+
|
| 686 |
+
fname = os.path.join(modeldir, fname)
|
| 687 |
+
|
| 688 |
+
if len(fname) > 255:
|
| 689 |
+
fname.replace(ext,"")
|
| 690 |
+
fname=fname[:240]+ext
|
| 691 |
+
|
| 692 |
+
# check if output file already exists
|
| 693 |
+
if os.path.isfile(fname) and not "overwrite" in savesets:
|
| 694 |
+
_err_msg = f"Output file ({fname}) existed and was not saved]"
|
| 695 |
+
print(_err_msg)
|
| 696 |
+
return _err_msg
|
| 697 |
+
|
| 698 |
+
print("Saving...")
|
| 699 |
+
if ext == ".safetensors":
|
| 700 |
+
safetensors.torch.save_file(state_dict, fname, metadata=metadata)
|
| 701 |
+
else:
|
| 702 |
+
torch.save(state_dict, fname)
|
| 703 |
+
print("Done!")
|
| 704 |
+
return "Merged model saved in "+fname
|
| 705 |
+
|
| 706 |
+
def filenamecutter(name,model_a = False):
|
| 707 |
+
from modules import sd_models
|
| 708 |
+
if name =="" or name ==[]: return
|
| 709 |
+
checkpoint_info = sd_models.get_closet_checkpoint_match(name)
|
| 710 |
+
name= os.path.splitext(checkpoint_info.filename)[0]
|
| 711 |
+
|
| 712 |
+
if not model_a:
|
| 713 |
+
name = os.path.basename(name)
|
| 714 |
+
return name
|
| 715 |
+
|
| 716 |
+
# TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認
|
| 717 |
+
def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
|
| 718 |
+
import diffusers
|
| 719 |
+
print("diffusers version : ",diffusers.__version__)
|
| 720 |
+
state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
|
| 721 |
+
if dtype is not None:
|
| 722 |
+
for k, v in state_dict.items():
|
| 723 |
+
if type(v) is torch.Tensor:
|
| 724 |
+
state_dict[k] = v.to(dtype)
|
| 725 |
+
|
| 726 |
+
# Convert the UNet2DConditionModel model.
|
| 727 |
+
unet_config = create_unet_diffusers_config(v2)
|
| 728 |
+
converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config)
|
| 729 |
+
|
| 730 |
+
unet = diffusers.UNet2DConditionModel(**unet_config)
|
| 731 |
+
info = unet.load_state_dict(converted_unet_checkpoint)
|
| 732 |
+
print("loading u-net:", info)
|
| 733 |
+
|
| 734 |
+
# Convert the VAE model.
|
| 735 |
+
vae_config = create_vae_diffusers_config()
|
| 736 |
+
converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config)
|
| 737 |
+
|
| 738 |
+
vae = diffusers.AutoencoderKL(**vae_config)
|
| 739 |
+
info = vae.load_state_dict(converted_vae_checkpoint)
|
| 740 |
+
print("loading vae:", info)
|
| 741 |
+
|
| 742 |
+
# convert text_model
|
| 743 |
+
if v2:
|
| 744 |
+
converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77)
|
| 745 |
+
cfg = CLIPTextConfig(
|
| 746 |
+
vocab_size=49408,
|
| 747 |
+
hidden_size=1024,
|
| 748 |
+
intermediate_size=4096,
|
| 749 |
+
num_hidden_layers=23,
|
| 750 |
+
num_attention_heads=16,
|
| 751 |
+
max_position_embeddings=77,
|
| 752 |
+
hidden_act="gelu",
|
| 753 |
+
layer_norm_eps=1e-05,
|
| 754 |
+
dropout=0.0,
|
| 755 |
+
attention_dropout=0.0,
|
| 756 |
+
initializer_range=0.02,
|
| 757 |
+
initializer_factor=1.0,
|
| 758 |
+
pad_token_id=1,
|
| 759 |
+
bos_token_id=0,
|
| 760 |
+
eos_token_id=2,
|
| 761 |
+
model_type="clip_text_model",
|
| 762 |
+
projection_dim=512,
|
| 763 |
+
torch_dtype="float32",
|
| 764 |
+
transformers_version="4.25.0.dev0",
|
| 765 |
+
)
|
| 766 |
+
text_model = CLIPTextModel._from_config(cfg)
|
| 767 |
+
info = text_model.load_state_dict(converted_text_encoder_checkpoint)
|
| 768 |
+
else:
|
| 769 |
+
converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
|
| 770 |
+
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
|
| 771 |
+
info = text_model.load_state_dict(converted_text_encoder_checkpoint)
|
| 772 |
+
print("loading text encoder:", info)
|
| 773 |
+
|
| 774 |
+
return text_model, vae, unet
|
| 775 |
+
|
| 776 |
+
def usemodelgen(theta_0,model_a,model_name):
|
| 777 |
+
from modules import lowvram, devices, sd_hijack,shared, sd_vae
|
| 778 |
+
sd_hijack.model_hijack.undo_hijack(shared.sd_model)
|
| 779 |
+
|
| 780 |
+
model = shared.sd_model
|
| 781 |
+
model.load_state_dict(theta_0, strict=False)
|
| 782 |
+
del theta_0
|
| 783 |
+
if shared.cmd_opts.opt_channelslast:
|
| 784 |
+
model.to(memory_format=torch.channels_last)
|
| 785 |
+
|
| 786 |
+
if not shared.cmd_opts.no_half:
|
| 787 |
+
vae = model.first_stage_model
|
| 788 |
+
|
| 789 |
+
# with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
|
| 790 |
+
if shared.cmd_opts.no_half_vae:
|
| 791 |
+
model.first_stage_model = None
|
| 792 |
+
|
| 793 |
+
model.half()
|
| 794 |
+
model.first_stage_model = vae
|
| 795 |
+
|
| 796 |
+
devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
|
| 797 |
+
devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
|
| 798 |
+
devices.dtype_unet = model.model.diffusion_model.dtype
|
| 799 |
+
|
| 800 |
+
if hasattr(shared.cmd_opts,"upcast_sampling"):
|
| 801 |
+
devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
|
| 802 |
+
else:
|
| 803 |
+
devices.unet_needs_upcast = devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
|
| 804 |
+
|
| 805 |
+
model.first_stage_model.to(devices.dtype_vae)
|
| 806 |
+
sd_hijack.model_hijack.hijack(model)
|
| 807 |
+
|
| 808 |
+
model.logvar = shared.sd_model.logvar.to(devices.device)
|
| 809 |
+
|
| 810 |
+
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
| 811 |
+
setup_for_low_vram_s(model, shared.cmd_opts.medvram)
|
| 812 |
+
else:
|
| 813 |
+
model.to(shared.device)
|
| 814 |
+
|
| 815 |
+
model.eval()
|
| 816 |
+
|
| 817 |
+
shared.sd_model = model
|
| 818 |
+
try:
|
| 819 |
+
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)
|
| 820 |
+
except:
|
| 821 |
+
pass
|
| 822 |
+
#shared.sd_model.sd_checkpoint_info.model_name = model_name
|
| 823 |
+
|
| 824 |
+
def _setvae():
|
| 825 |
+
sd_vae.delete_base_vae()
|
| 826 |
+
sd_vae.clear_loaded_vae()
|
| 827 |
+
vae_file, vae_source = sd_vae.resolve_vae(model_a)
|
| 828 |
+
sd_vae.load_vae(shared.sd_model, vae_file, vae_source)
|
| 829 |
+
|
| 830 |
+
try:
|
| 831 |
+
_setvae()
|
| 832 |
+
except:
|
| 833 |
+
print("ERROR:setting VAE skipped")
|
| 834 |
+
|
| 835 |
+
|
| 836 |
+
import torch
|
| 837 |
+
from modules import devices
|
| 838 |
+
|
| 839 |
+
module_in_gpu = None
|
| 840 |
+
cpu = torch.device("cpu")
|
| 841 |
+
|
| 842 |
+
|
| 843 |
+
def send_everything_to_cpu():
|
| 844 |
+
global module_in_gpu
|
| 845 |
+
|
| 846 |
+
if module_in_gpu is not None:
|
| 847 |
+
module_in_gpu.to(cpu)
|
| 848 |
+
|
| 849 |
+
module_in_gpu = None
|
| 850 |
+
|
| 851 |
+
def setup_for_low_vram_s(sd_model, use_medvram):
|
| 852 |
+
parents = {}
|
| 853 |
+
|
| 854 |
+
def send_me_to_gpu(module, _):
|
| 855 |
+
"""send this module to GPU; send whatever tracked module was previous in GPU to CPU;
|
| 856 |
+
we add this as forward_pre_hook to a lot of modules and this way all but one of them will
|
| 857 |
+
be in CPU
|
| 858 |
+
"""
|
| 859 |
+
global module_in_gpu
|
| 860 |
+
|
| 861 |
+
module = parents.get(module, module)
|
| 862 |
+
|
| 863 |
+
if module_in_gpu == module:
|
| 864 |
+
return
|
| 865 |
+
|
| 866 |
+
if module_in_gpu is not None:
|
| 867 |
+
module_in_gpu.to(cpu)
|
| 868 |
+
|
| 869 |
+
module.to(devices.device)
|
| 870 |
+
module_in_gpu = module
|
| 871 |
+
|
| 872 |
+
# see below for register_forward_pre_hook;
|
| 873 |
+
# first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is
|
| 874 |
+
# useless here, and we just replace those methods
|
| 875 |
+
|
| 876 |
+
first_stage_model = sd_model.first_stage_model
|
| 877 |
+
first_stage_model_encode = sd_model.first_stage_model.encode
|
| 878 |
+
first_stage_model_decode = sd_model.first_stage_model.decode
|
| 879 |
+
|
| 880 |
+
def first_stage_model_encode_wrap(x):
|
| 881 |
+
send_me_to_gpu(first_stage_model, None)
|
| 882 |
+
return first_stage_model_encode(x)
|
| 883 |
+
|
| 884 |
+
def first_stage_model_decode_wrap(z):
|
| 885 |
+
send_me_to_gpu(first_stage_model, None)
|
| 886 |
+
return first_stage_model_decode(z)
|
| 887 |
+
|
| 888 |
+
# for SD1, cond_stage_model is CLIP and its NN is in the tranformer frield, but for SD2, it's open clip, and it's in model field
|
| 889 |
+
if hasattr(sd_model.cond_stage_model, 'model'):
|
| 890 |
+
sd_model.cond_stage_model.transformer = sd_model.cond_stage_model.model
|
| 891 |
+
|
| 892 |
+
# remove four big modules, cond, first_stage, depth (if applicable), and unet from the model and then
|
| 893 |
+
# send the model to GPU. Then put modules back. the modules will be in CPU.
|
| 894 |
+
stored = sd_model.first_stage_model, getattr(sd_model, 'depth_model', None), sd_model.model
|
| 895 |
+
sd_model.first_stage_model, sd_model.depth_model, sd_model.model = None, None, None
|
| 896 |
+
sd_model.to(devices.device)
|
| 897 |
+
sd_model.first_stage_model, sd_model.depth_model, sd_model.model = stored
|
| 898 |
+
|
| 899 |
+
# register hooks for those the first three models
|
| 900 |
+
sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
|
| 901 |
+
sd_model.first_stage_model.encode = first_stage_model_encode_wrap
|
| 902 |
+
sd_model.first_stage_model.decode = first_stage_model_decode_wrap
|
| 903 |
+
if sd_model.depth_model:
|
| 904 |
+
sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu)
|
| 905 |
+
|
| 906 |
+
if hasattr(sd_model.cond_stage_model, 'model'):
|
| 907 |
+
sd_model.cond_stage_model.model = sd_model.cond_stage_model.transformer
|
| 908 |
+
del sd_model.cond_stage_model.transformer
|
| 909 |
+
|
| 910 |
+
if use_medvram:
|
| 911 |
+
sd_model.model.register_forward_pre_hook(send_me_to_gpu)
|
| 912 |
+
else:
|
| 913 |
+
diff_model = sd_model.model.diffusion_model
|
| 914 |
+
|
| 915 |
+
# the third remaining model is still too big for 4 GB, so we also do the same for its submodules
|
| 916 |
+
# so that only one of them is in GPU at a time
|
| 917 |
+
stored = diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed
|
| 918 |
+
diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = None, None, None, None
|
| 919 |
+
sd_model.model.to(devices.device)
|
| 920 |
+
diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored
|
| 921 |
+
|
| 922 |
+
# install hooks for bits of third model
|
| 923 |
+
diff_model.time_embed.register_forward_pre_hook(send_me_to_gpu)
|
| 924 |
+
for block in diff_model.input_blocks:
|
| 925 |
+
block.register_forward_pre_hook(send_me_to_gpu)
|
| 926 |
+
diff_model.middle_block.register_forward_pre_hook(send_me_to_gpu)
|
| 927 |
+
for block in diff_model.output_blocks:
|
| 928 |
+
block.register_forward_pre_hook(send_me_to_gpu)
|
microsoftexcel-supermerger/scripts/mergers/pluslora.py
ADDED
|
@@ -0,0 +1,1298 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
from sklearn.linear_model import PassiveAggressiveClassifier
|
| 3 |
+
import torch
|
| 4 |
+
import math
|
| 5 |
+
import os
|
| 6 |
+
import gc
|
| 7 |
+
import gradio as gr
|
| 8 |
+
from torchmetrics import Precision
|
| 9 |
+
import modules.shared as shared
|
| 10 |
+
import gc
|
| 11 |
+
from safetensors.torch import load_file, save_file
|
| 12 |
+
from typing import List
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
from modules import sd_models,scripts
|
| 15 |
+
from scripts.mergers.model_util import load_models_from_stable_diffusion_checkpoint,filenamecutter,savemodel
|
| 16 |
+
from modules.ui import create_refresh_button
|
| 17 |
+
|
| 18 |
+
def on_ui_tabs():
|
| 19 |
+
import lora
|
| 20 |
+
sml_path_root = scripts.basedir()
|
| 21 |
+
LWEIGHTSPRESETS="\
|
| 22 |
+
NONE:0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0\n\
|
| 23 |
+
ALL:1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1\n\
|
| 24 |
+
INS:1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0\n\
|
| 25 |
+
IND:1,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0\n\
|
| 26 |
+
INALL:1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0\n\
|
| 27 |
+
MIDD:1,0,0,0,1,1,1,1,1,1,1,1,0,0,0,0,0\n\
|
| 28 |
+
OUTD:1,0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0\n\
|
| 29 |
+
OUTS:1,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1\n\
|
| 30 |
+
OUTALL:1,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1\n\
|
| 31 |
+
ALL0.5:0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5"
|
| 32 |
+
sml_filepath = os.path.join(sml_path_root,"scripts", "lbwpresets.txt")
|
| 33 |
+
sml_lbwpresets=""
|
| 34 |
+
try:
|
| 35 |
+
with open(sml_filepath,encoding="utf-8") as f:
|
| 36 |
+
sml_lbwpresets = f.read()
|
| 37 |
+
except OSError as e:
|
| 38 |
+
sml_lbwpresets=LWEIGHTSPRESETS
|
| 39 |
+
|
| 40 |
+
with gr.Blocks(analytics_enabled=False) :
|
| 41 |
+
sml_submit_result = gr.Textbox(label="Message")
|
| 42 |
+
with gr.Row().style(equal_height=False):
|
| 43 |
+
sml_cpmerge = gr.Button(elem_id="model_merger_merge", value="Merge to Checkpoint",variant='primary')
|
| 44 |
+
sml_makelora = gr.Button(elem_id="model_merger_merge", value="Make LoRA (alpha * A - beta * B)",variant='primary')
|
| 45 |
+
sml_model_a = gr.Dropdown(sd_models.checkpoint_tiles(),elem_id="model_converter_model_name",label="Checkpoint A",interactive=True)
|
| 46 |
+
create_refresh_button(sml_model_a, sd_models.list_models,lambda: {"choices": sd_models.checkpoint_tiles()},"refresh_checkpoint_Z")
|
| 47 |
+
sml_model_b = gr.Dropdown(sd_models.checkpoint_tiles(),elem_id="model_converter_model_name",label="Checkpoint B",interactive=True)
|
| 48 |
+
create_refresh_button(sml_model_b, sd_models.list_models,lambda: {"choices": sd_models.checkpoint_tiles()},"refresh_checkpoint_Z")
|
| 49 |
+
with gr.Row().style(equal_height=False):
|
| 50 |
+
sml_merge = gr.Button(elem_id="model_merger_merge", value="Merge LoRAs",variant='primary')
|
| 51 |
+
alpha = gr.Slider(label="alpha", minimum=-1.0, maximum=2, step=0.001, value=1)
|
| 52 |
+
beta = gr.Slider(label="beta", minimum=-1.0, maximum=2, step=0.001, value=1)
|
| 53 |
+
with gr.Row().style(equal_height=False):
|
| 54 |
+
sml_settings = gr.CheckboxGroup(["same to Strength", "overwrite"], label="settings")
|
| 55 |
+
precision = gr.Radio(label = "save precision",choices=["float","fp16","bf16"],value = "fp16",type="value")
|
| 56 |
+
with gr.Row().style(equal_height=False):
|
| 57 |
+
sml_dim = gr.Radio(label = "remake dimension",choices = ["no","auto",*[2**(x+2) for x in range(9)]],value = "no",type = "value")
|
| 58 |
+
sml_filename = gr.Textbox(label="filename(option)",lines=1,visible =True,interactive = True)
|
| 59 |
+
sml_loranames = gr.Textbox(label='LoRAname1:ratio1:Blocks1,LoRAname2:ratio2:Blocks2,...(":blocks" is option, not necessary)',lines=1,value="",visible =True)
|
| 60 |
+
sml_dims = gr.CheckboxGroup(label = "limit dimension",choices=[],value = [],type="value",interactive=True,visible = False)
|
| 61 |
+
with gr.Row().style(equal_height=False):
|
| 62 |
+
sml_calcdim = gr.Button(elem_id="calcloras", value="calculate dimension of LoRAs(It may take a few minutes if there are many LoRAs)",variant='primary')
|
| 63 |
+
sml_update = gr.Button(elem_id="calcloras", value="update list",variant='primary')
|
| 64 |
+
sml_loras = gr.CheckboxGroup(label = "Lora",choices=[x[0] for x in lora.available_loras.items()],type="value",interactive=True,visible = True)
|
| 65 |
+
sml_loraratios = gr.TextArea(label="",value=sml_lbwpresets,visible =True,interactive = True)
|
| 66 |
+
|
| 67 |
+
sml_merge.click(
|
| 68 |
+
fn=lmerge,
|
| 69 |
+
inputs=[sml_loranames,sml_loraratios,sml_settings,sml_filename,sml_dim,precision],
|
| 70 |
+
outputs=[sml_submit_result]
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
sml_makelora.click(
|
| 74 |
+
fn=makelora,
|
| 75 |
+
inputs=[sml_model_a,sml_model_b,sml_dim,sml_filename,sml_settings,alpha,beta,precision],
|
| 76 |
+
outputs=[sml_submit_result]
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
sml_cpmerge.click(
|
| 80 |
+
fn=pluslora,
|
| 81 |
+
inputs=[sml_loranames,sml_loraratios,sml_settings,sml_filename,sml_model_a,precision],
|
| 82 |
+
outputs=[sml_submit_result]
|
| 83 |
+
)
|
| 84 |
+
llist ={}
|
| 85 |
+
dlist =[]
|
| 86 |
+
dn = []
|
| 87 |
+
|
| 88 |
+
def updateloras():
|
| 89 |
+
lora.list_available_loras()
|
| 90 |
+
for n in lora.available_loras.items():
|
| 91 |
+
if n[0] not in llist:llist[n[0]] = ""
|
| 92 |
+
return gr.update(choices = [f"{x[0]}({x[1]})" for x in llist.items()])
|
| 93 |
+
|
| 94 |
+
sml_update.click(fn = updateloras,outputs = [sml_loras])
|
| 95 |
+
|
| 96 |
+
def calculatedim():
|
| 97 |
+
print("listing dimensions...")
|
| 98 |
+
for n in tqdm(lora.available_loras.items()):
|
| 99 |
+
if n[0] in llist:
|
| 100 |
+
if llist[n[0]] !="": continue
|
| 101 |
+
c_lora = lora.available_loras.get(n[0], None)
|
| 102 |
+
d,t = dimgetter(c_lora.filename)
|
| 103 |
+
if t == "LoCon" : d = f"{d}:{t}"
|
| 104 |
+
if d not in dlist:
|
| 105 |
+
if type(d) == int :dlist.append(d)
|
| 106 |
+
elif d not in dn: dn.append(d)
|
| 107 |
+
llist[n[0]] = d
|
| 108 |
+
dlist.sort()
|
| 109 |
+
return gr.update(choices = [f"{x[0]}({x[1]})" for x in llist.items()],value =[]),gr.update(visible =True,choices = [x for x in (dlist+dn)])
|
| 110 |
+
|
| 111 |
+
sml_calcdim.click(
|
| 112 |
+
fn=calculatedim,
|
| 113 |
+
inputs=[],
|
| 114 |
+
outputs=[sml_loras,sml_dims]
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
def dimselector(dims):
|
| 118 |
+
if dims ==[]:return gr.update(choices = [f"{x[0]}({x[1]})" for x in llist.items()])
|
| 119 |
+
rl=[]
|
| 120 |
+
for d in dims:
|
| 121 |
+
for i in llist.items():
|
| 122 |
+
if d == i[1]:rl.append(f"{i[0]}({i[1]})")
|
| 123 |
+
return gr.update(choices = [l for l in rl],value =[])
|
| 124 |
+
|
| 125 |
+
def llister(names):
|
| 126 |
+
if names ==[] : return ""
|
| 127 |
+
else:
|
| 128 |
+
for i,n in enumerate(names):
|
| 129 |
+
if "(" in n:names[i] = n[:n.rfind("(")]
|
| 130 |
+
return ":1.0,".join(names)+":1.0"
|
| 131 |
+
sml_loras.change(fn=llister,inputs=[sml_loras],outputs=[sml_loranames])
|
| 132 |
+
sml_dims.change(fn=dimselector,inputs=[sml_dims],outputs=[sml_loras])
|
| 133 |
+
|
| 134 |
+
def makelora(model_a,model_b,dim,saveto,settings,alpha,beta,precision):
|
| 135 |
+
print("make LoRA start")
|
| 136 |
+
if model_a == "" or model_b =="":
|
| 137 |
+
return "ERROR: No model Selected"
|
| 138 |
+
gc.collect()
|
| 139 |
+
|
| 140 |
+
if saveto =="" : saveto = makeloraname(model_a,model_b)
|
| 141 |
+
if not ".safetensors" in saveto :saveto += ".safetensors"
|
| 142 |
+
saveto = os.path.join(shared.cmd_opts.lora_dir,saveto)
|
| 143 |
+
|
| 144 |
+
dim = 128 if type(dim) != int else int(dim)
|
| 145 |
+
if os.path.isfile(saveto ) and not "overwrite" in settings:
|
| 146 |
+
_err_msg = f"Output file ({saveto}) existed and was not saved"
|
| 147 |
+
print(_err_msg)
|
| 148 |
+
return _err_msg
|
| 149 |
+
|
| 150 |
+
svd(fullpathfromname(model_a),fullpathfromname(model_b),False,dim,precision,saveto,alpha,beta)
|
| 151 |
+
return f"saved to {saveto}"
|
| 152 |
+
|
| 153 |
+
def lmerge(loranames,loraratioss,settings,filename,dim,precision):
|
| 154 |
+
import lora
|
| 155 |
+
loras_on_disk = [lora.available_loras.get(name, None) for name in loranames]
|
| 156 |
+
if any([x is None for x in loras_on_disk]):
|
| 157 |
+
lora.list_available_loras()
|
| 158 |
+
|
| 159 |
+
loras_on_disk = [lora.available_loras.get(name, None) for name in loranames]
|
| 160 |
+
|
| 161 |
+
lnames = [loranames] if "," not in loranames else loranames.split(",")
|
| 162 |
+
|
| 163 |
+
for i, n in enumerate(lnames):
|
| 164 |
+
lnames[i] = n.split(":")
|
| 165 |
+
|
| 166 |
+
loraratios=loraratioss.splitlines()
|
| 167 |
+
ldict ={}
|
| 168 |
+
|
| 169 |
+
for i,l in enumerate(loraratios):
|
| 170 |
+
if ":" not in l or not (l.count(",") == 16 or l.count(",") == 25) : continue
|
| 171 |
+
ldict[l.split(":")[0]]=l.split(":")[1]
|
| 172 |
+
|
| 173 |
+
ln = []
|
| 174 |
+
lr = []
|
| 175 |
+
ld = []
|
| 176 |
+
lt = []
|
| 177 |
+
dmax = 1
|
| 178 |
+
|
| 179 |
+
for i,n in enumerate(lnames):
|
| 180 |
+
if len(n) ==3:
|
| 181 |
+
if n[2].strip() in ldict:
|
| 182 |
+
ratio = [float(r)*float(n[1]) for r in ldict[n[2]].split(",")]
|
| 183 |
+
else:ratio = [float(n[1])]*17
|
| 184 |
+
else:ratio = [float(n[1])]*17
|
| 185 |
+
c_lora = lora.available_loras.get(n[0], None)
|
| 186 |
+
ln.append(c_lora.filename)
|
| 187 |
+
lr.append(ratio)
|
| 188 |
+
d,t = dimgetter(c_lora.filename)
|
| 189 |
+
lt.append(t)
|
| 190 |
+
ld.append(d)
|
| 191 |
+
if d != "LyCORIS":
|
| 192 |
+
if d > dmax : dmax = d
|
| 193 |
+
|
| 194 |
+
if filename =="":filename =loranames.replace(",","+").replace(":","_")
|
| 195 |
+
if not ".safetensors" in filename:filename += ".safetensors"
|
| 196 |
+
filename = os.path.join(shared.cmd_opts.lora_dir,filename)
|
| 197 |
+
|
| 198 |
+
dim = int(dim) if dim != "no" and dim != "auto" else 0
|
| 199 |
+
|
| 200 |
+
if "LyCORIS" in ld or "LoCon" in lt:
|
| 201 |
+
if len(ld) !=1:
|
| 202 |
+
return "multiple merge of LyCORIS is not supported"
|
| 203 |
+
sd = lycomerge(ln[0],lr[0])
|
| 204 |
+
elif dim > 0:
|
| 205 |
+
print("change demension to ", dim)
|
| 206 |
+
sd = merge_lora_models_dim(ln, lr, dim,settings)
|
| 207 |
+
elif "auto" in settings and ld.count(ld[0]) != len(ld):
|
| 208 |
+
print("change demension to ",dmax)
|
| 209 |
+
sd = merge_lora_models_dim(ln, lr, dmax,settings)
|
| 210 |
+
else:
|
| 211 |
+
sd = merge_lora_models(ln, lr,settings)
|
| 212 |
+
|
| 213 |
+
if os.path.isfile(filename) and not "overwrite" in settings:
|
| 214 |
+
_err_msg = f"Output file ({filename}) existed and was not saved"
|
| 215 |
+
print(_err_msg)
|
| 216 |
+
return _err_msg
|
| 217 |
+
|
| 218 |
+
save_to_file(filename,sd,sd, str_to_dtype(precision))
|
| 219 |
+
return "saved : "+filename
|
| 220 |
+
|
| 221 |
+
def pluslora(lnames,loraratios,settings,output,model,precision):
|
| 222 |
+
if model == []:
|
| 223 |
+
return "ERROR: No model Selected"
|
| 224 |
+
if lnames == "":
|
| 225 |
+
return "ERROR: No LoRA Selected"
|
| 226 |
+
|
| 227 |
+
print("plus LoRA start")
|
| 228 |
+
import lora
|
| 229 |
+
lnames = [lnames] if "," not in lnames else lnames.split(",")
|
| 230 |
+
|
| 231 |
+
for i, n in enumerate(lnames):
|
| 232 |
+
lnames[i] = n.split(":")
|
| 233 |
+
|
| 234 |
+
loraratios=loraratios.splitlines()
|
| 235 |
+
ldict ={}
|
| 236 |
+
|
| 237 |
+
for i,l in enumerate(loraratios):
|
| 238 |
+
if ":" not in l or not (l.count(",") == 16 or l.count(",") == 25) : continue
|
| 239 |
+
ldict[l.split(":")[0].strip()]=l.split(":")[1]
|
| 240 |
+
|
| 241 |
+
names=[]
|
| 242 |
+
filenames=[]
|
| 243 |
+
loratypes=[]
|
| 244 |
+
lweis=[]
|
| 245 |
+
|
| 246 |
+
for n in lnames:
|
| 247 |
+
if len(n) ==3:
|
| 248 |
+
if n[2].strip() in ldict:
|
| 249 |
+
ratio = [float(r)*float(n[1]) for r in ldict[n[2]].split(",")]
|
| 250 |
+
else:ratio = [float(n[1])]*17
|
| 251 |
+
else:ratio = [float(n[1])]*17
|
| 252 |
+
c_lora = lora.available_loras.get(n[0], None)
|
| 253 |
+
names.append(n[0])
|
| 254 |
+
filenames.append(c_lora.filename)
|
| 255 |
+
_,t = dimgetter(c_lora.filename)
|
| 256 |
+
if "LyCORIS" in t: return "LyCORIS merge is not supported"
|
| 257 |
+
lweis.append(ratio)
|
| 258 |
+
|
| 259 |
+
modeln=filenamecutter(model,True)
|
| 260 |
+
dname = modeln
|
| 261 |
+
for n in names:
|
| 262 |
+
dname = dname + "+"+n
|
| 263 |
+
|
| 264 |
+
checkpoint_info = sd_models.get_closet_checkpoint_match(model)
|
| 265 |
+
print(f"Loading {model}")
|
| 266 |
+
theta_0 = sd_models.read_state_dict(checkpoint_info.filename,"cpu")
|
| 267 |
+
|
| 268 |
+
keychanger = {}
|
| 269 |
+
for key in theta_0.keys():
|
| 270 |
+
if "model" in key:
|
| 271 |
+
skey = key.replace(".","_").replace("_weight","")
|
| 272 |
+
keychanger[skey.split("model_",1)[1]] = key
|
| 273 |
+
|
| 274 |
+
for name,filename, lwei in zip(names,filenames, lweis):
|
| 275 |
+
print(f"loading: {name}")
|
| 276 |
+
lora_sd = load_state_dict(filename, torch.float)
|
| 277 |
+
|
| 278 |
+
print(f"merging..." ,lwei)
|
| 279 |
+
for key in lora_sd.keys():
|
| 280 |
+
ratio = 1
|
| 281 |
+
|
| 282 |
+
fullkey = convert_diffusers_name_to_compvis(key)
|
| 283 |
+
|
| 284 |
+
for i,block in enumerate(LORABLOCKS):
|
| 285 |
+
if block in fullkey:
|
| 286 |
+
ratio = lwei[i]
|
| 287 |
+
|
| 288 |
+
msd_key, lora_key = fullkey.split(".", 1)
|
| 289 |
+
|
| 290 |
+
if "lora_down" in key:
|
| 291 |
+
up_key = key.replace("lora_down", "lora_up")
|
| 292 |
+
alpha_key = key[:key.index("lora_down")] + 'alpha'
|
| 293 |
+
|
| 294 |
+
# print(f"apply {key} to {module}")
|
| 295 |
+
|
| 296 |
+
down_weight = lora_sd[key].to(device="cpu")
|
| 297 |
+
up_weight = lora_sd[up_key].to(device="cpu")
|
| 298 |
+
|
| 299 |
+
dim = down_weight.size()[0]
|
| 300 |
+
alpha = lora_sd.get(alpha_key, dim)
|
| 301 |
+
scale = alpha / dim
|
| 302 |
+
# W <- W + U * D
|
| 303 |
+
weight = theta_0[keychanger[msd_key]].to(device="cpu")
|
| 304 |
+
|
| 305 |
+
if not len(down_weight.size()) == 4:
|
| 306 |
+
# linear
|
| 307 |
+
weight = weight + ratio * (up_weight @ down_weight) * scale
|
| 308 |
+
else:
|
| 309 |
+
# conv2d
|
| 310 |
+
weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
|
| 311 |
+
).unsqueeze(2).unsqueeze(3) * scale
|
| 312 |
+
theta_0[keychanger[msd_key]] = torch.nn.Parameter(weight)
|
| 313 |
+
#usemodelgen(theta_0,model)
|
| 314 |
+
settings.append(precision)
|
| 315 |
+
result = savemodel(theta_0,dname,output,settings,model)
|
| 316 |
+
del theta_0
|
| 317 |
+
gc.collect()
|
| 318 |
+
return result
|
| 319 |
+
|
| 320 |
+
def save_to_file(file_name, model, state_dict, dtype):
|
| 321 |
+
if dtype is not None:
|
| 322 |
+
for key in list(state_dict.keys()):
|
| 323 |
+
if type(state_dict[key]) == torch.Tensor:
|
| 324 |
+
state_dict[key] = state_dict[key].to(dtype)
|
| 325 |
+
|
| 326 |
+
if os.path.splitext(file_name)[1] == '.safetensors':
|
| 327 |
+
save_file(model, file_name)
|
| 328 |
+
else:
|
| 329 |
+
torch.save(model, file_name)
|
| 330 |
+
|
| 331 |
+
re_digits = re.compile(r"\d+")
|
| 332 |
+
|
| 333 |
+
re_unet_down_blocks = re.compile(r"lora_unet_down_blocks_(\d+)_attentions_(\d+)_(.+)")
|
| 334 |
+
re_unet_mid_blocks = re.compile(r"lora_unet_mid_block_attentions_(\d+)_(.+)")
|
| 335 |
+
re_unet_up_blocks = re.compile(r"lora_unet_up_blocks_(\d+)_attentions_(\d+)_(.+)")
|
| 336 |
+
|
| 337 |
+
re_unet_down_blocks_res = re.compile(r"lora_unet_down_blocks_(\d+)_resnets_(\d+)_(.+)")
|
| 338 |
+
re_unet_mid_blocks_res = re.compile(r"lora_unet_mid_block_resnets_(\d+)_(.+)")
|
| 339 |
+
re_unet_up_blocks_res = re.compile(r"lora_unet_up_blocks_(\d+)_resnets_(\d+)_(.+)")
|
| 340 |
+
|
| 341 |
+
re_unet_downsample = re.compile(r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv(.+)")
|
| 342 |
+
re_unet_upsample = re.compile(r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv(.+)")
|
| 343 |
+
|
| 344 |
+
re_text_block = re.compile(r"lora_te_text_model_encoder_layers_(\d+)_(.+)")
|
| 345 |
+
|
| 346 |
+
|
| 347 |
+
def convert_diffusers_name_to_compvis(key):
|
| 348 |
+
def match(match_list, regex):
|
| 349 |
+
r = re.match(regex, key)
|
| 350 |
+
if not r:
|
| 351 |
+
return False
|
| 352 |
+
|
| 353 |
+
match_list.clear()
|
| 354 |
+
match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()])
|
| 355 |
+
return True
|
| 356 |
+
|
| 357 |
+
m = []
|
| 358 |
+
|
| 359 |
+
if match(m, re_unet_down_blocks):
|
| 360 |
+
return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[1]}_1_{m[2]}"
|
| 361 |
+
|
| 362 |
+
if match(m, re_unet_mid_blocks):
|
| 363 |
+
return f"diffusion_model_middle_block_1_{m[1]}"
|
| 364 |
+
|
| 365 |
+
if match(m, re_unet_up_blocks):
|
| 366 |
+
return f"diffusion_model_output_blocks_{m[0] * 3 + m[1]}_1_{m[2]}"
|
| 367 |
+
|
| 368 |
+
if match(m, re_unet_down_blocks_res):
|
| 369 |
+
block = f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[1]}_0_"
|
| 370 |
+
if m[2].startswith('conv1'):
|
| 371 |
+
return f"{block}in_layers_2{m[2][len('conv1'):]}"
|
| 372 |
+
elif m[2].startswith('conv2'):
|
| 373 |
+
return f"{block}out_layers_3{m[2][len('conv2'):]}"
|
| 374 |
+
elif m[2].startswith('time_emb_proj'):
|
| 375 |
+
return f"{block}emb_layers_1{m[2][len('time_emb_proj'):]}"
|
| 376 |
+
elif m[2].startswith('conv_shortcut'):
|
| 377 |
+
return f"{block}skip_connection{m[2][len('conv_shortcut'):]}"
|
| 378 |
+
|
| 379 |
+
if match(m, re_unet_mid_blocks_res):
|
| 380 |
+
block = f"diffusion_model_middle_block_{m[0]*2}_"
|
| 381 |
+
if m[1].startswith('conv1'):
|
| 382 |
+
return f"{block}in_layers_2{m[1][len('conv1'):]}"
|
| 383 |
+
elif m[1].startswith('conv2'):
|
| 384 |
+
return f"{block}out_layers_3{m[1][len('conv2'):]}"
|
| 385 |
+
elif m[1].startswith('time_emb_proj'):
|
| 386 |
+
return f"{block}emb_layers_1{m[1][len('time_emb_proj'):]}"
|
| 387 |
+
elif m[1].startswith('conv_shortcut'):
|
| 388 |
+
return f"{block}skip_connection{m[1][len('conv_shortcut'):]}"
|
| 389 |
+
|
| 390 |
+
if match(m, re_unet_up_blocks_res):
|
| 391 |
+
block = f"diffusion_model_output_blocks_{m[0] * 3 + m[1]}_0_"
|
| 392 |
+
if m[2].startswith('conv1'):
|
| 393 |
+
return f"{block}in_layers_2{m[2][len('conv1'):]}"
|
| 394 |
+
elif m[2].startswith('conv2'):
|
| 395 |
+
return f"{block}out_layers_3{m[2][len('conv2'):]}"
|
| 396 |
+
elif m[2].startswith('time_emb_proj'):
|
| 397 |
+
return f"{block}emb_layers_1{m[2][len('time_emb_proj'):]}"
|
| 398 |
+
elif m[2].startswith('conv_shortcut'):
|
| 399 |
+
return f"{block}skip_connection{m[2][len('conv_shortcut'):]}"
|
| 400 |
+
|
| 401 |
+
if match(m, re_unet_downsample):
|
| 402 |
+
return f"diffusion_model_input_blocks_{m[0]*3+3}_0_op{m[1]}"
|
| 403 |
+
|
| 404 |
+
if match(m, re_unet_upsample):
|
| 405 |
+
return f"diffusion_model_output_blocks_{m[0]*3 + 2}_{1+(m[0]!=0)}_conv{m[1]}"
|
| 406 |
+
|
| 407 |
+
if match(m, re_text_block):
|
| 408 |
+
return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"
|
| 409 |
+
|
| 410 |
+
return key
|
| 411 |
+
|
| 412 |
+
CLAMP_QUANTILE = 0.99
|
| 413 |
+
MIN_DIFF = 1e-6
|
| 414 |
+
|
| 415 |
+
def str_to_dtype(p):
|
| 416 |
+
if p == 'float':
|
| 417 |
+
return torch.float
|
| 418 |
+
if p == 'fp16':
|
| 419 |
+
return torch.float16
|
| 420 |
+
if p == 'bf16':
|
| 421 |
+
return torch.bfloat16
|
| 422 |
+
return None
|
| 423 |
+
|
| 424 |
+
def svd(model_a,model_b,v2,dim,save_precision,save_to,alpha,beta):
|
| 425 |
+
save_dtype = str_to_dtype(save_precision)
|
| 426 |
+
|
| 427 |
+
if model_a == model_b:
|
| 428 |
+
text_encoder_t, _, unet_t = load_models_from_stable_diffusion_checkpoint(v2, model_a)
|
| 429 |
+
text_encoder_o, _, unet_o = text_encoder_t, _, unet_t
|
| 430 |
+
else:
|
| 431 |
+
print(f"loading SD model : {model_b}")
|
| 432 |
+
text_encoder_o, _, unet_o = load_models_from_stable_diffusion_checkpoint(v2, model_b)
|
| 433 |
+
|
| 434 |
+
print(f"loading SD model : {model_a}")
|
| 435 |
+
text_encoder_t, _, unet_t = load_models_from_stable_diffusion_checkpoint(v2, model_a)
|
| 436 |
+
|
| 437 |
+
# create LoRA network to extract weights: Use dim (rank) as alpha
|
| 438 |
+
lora_network_o = create_network(1.0, dim, dim, None, text_encoder_o, unet_o)
|
| 439 |
+
lora_network_t = create_network(1.0, dim, dim, None, text_encoder_t, unet_t)
|
| 440 |
+
assert len(lora_network_o.text_encoder_loras) == len(
|
| 441 |
+
lora_network_t.text_encoder_loras), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違います(SD1.xベースとSD2.xベース) "
|
| 442 |
+
# get diffs
|
| 443 |
+
diffs = {}
|
| 444 |
+
text_encoder_different = False
|
| 445 |
+
for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.text_encoder_loras, lora_network_t.text_encoder_loras)):
|
| 446 |
+
lora_name = lora_o.lora_name
|
| 447 |
+
module_o = lora_o.org_module
|
| 448 |
+
module_t = lora_t.org_module
|
| 449 |
+
diff = alpha*module_t.weight - beta*module_o.weight
|
| 450 |
+
|
| 451 |
+
# Text Encoder might be same
|
| 452 |
+
if torch.max(torch.abs(diff)) > MIN_DIFF:
|
| 453 |
+
text_encoder_different = True
|
| 454 |
+
|
| 455 |
+
diff = diff.float()
|
| 456 |
+
diffs[lora_name] = diff
|
| 457 |
+
|
| 458 |
+
if not text_encoder_different:
|
| 459 |
+
print("Text encoder is same. Extract U-Net only.")
|
| 460 |
+
lora_network_o.text_encoder_loras = []
|
| 461 |
+
diffs = {}
|
| 462 |
+
|
| 463 |
+
for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.unet_loras, lora_network_t.unet_loras)):
|
| 464 |
+
lora_name = lora_o.lora_name
|
| 465 |
+
module_o = lora_o.org_module
|
| 466 |
+
module_t = lora_t.org_module
|
| 467 |
+
diff = alpha*module_t.weight - beta*module_o.weight
|
| 468 |
+
diff = diff.float()
|
| 469 |
+
|
| 470 |
+
diffs[lora_name] = diff
|
| 471 |
+
|
| 472 |
+
# make LoRA with svd
|
| 473 |
+
print("calculating by svd")
|
| 474 |
+
rank = dim
|
| 475 |
+
lora_weights = {}
|
| 476 |
+
with torch.no_grad():
|
| 477 |
+
for lora_name, mat in tqdm(list(diffs.items())):
|
| 478 |
+
conv2d = (len(mat.size()) == 4)
|
| 479 |
+
if conv2d:
|
| 480 |
+
mat = mat.squeeze()
|
| 481 |
+
|
| 482 |
+
U, S, Vh = torch.linalg.svd(mat)
|
| 483 |
+
|
| 484 |
+
U = U[:, :rank]
|
| 485 |
+
S = S[:rank]
|
| 486 |
+
U = U @ torch.diag(S)
|
| 487 |
+
|
| 488 |
+
Vh = Vh[:rank, :]
|
| 489 |
+
|
| 490 |
+
dist = torch.cat([U.flatten(), Vh.flatten()])
|
| 491 |
+
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
|
| 492 |
+
low_val = -hi_val
|
| 493 |
+
|
| 494 |
+
U = U.clamp(low_val, hi_val)
|
| 495 |
+
Vh = Vh.clamp(low_val, hi_val)
|
| 496 |
+
|
| 497 |
+
lora_weights[lora_name] = (U, Vh)
|
| 498 |
+
|
| 499 |
+
# make state dict for LoRA
|
| 500 |
+
lora_network_o.apply_to(text_encoder_o, unet_o, text_encoder_different, True) # to make state dict
|
| 501 |
+
lora_sd = lora_network_o.state_dict()
|
| 502 |
+
print(f"LoRA has {len(lora_sd)} weights.")
|
| 503 |
+
|
| 504 |
+
for key in list(lora_sd.keys()):
|
| 505 |
+
if "alpha" in key:
|
| 506 |
+
continue
|
| 507 |
+
|
| 508 |
+
lora_name = key.split('.')[0]
|
| 509 |
+
i = 0 if "lora_up" in key else 1
|
| 510 |
+
|
| 511 |
+
weights = lora_weights[lora_name][i]
|
| 512 |
+
# print(key, i, weights.size(), lora_sd[key].size())
|
| 513 |
+
if len(lora_sd[key].size()) == 4:
|
| 514 |
+
weights = weights.unsqueeze(2).unsqueeze(3)
|
| 515 |
+
|
| 516 |
+
assert weights.size() == lora_sd[key].size(), f"size unmatch: {key}"
|
| 517 |
+
lora_sd[key] = weights
|
| 518 |
+
|
| 519 |
+
# load state dict to LoRA and save it
|
| 520 |
+
info = lora_network_o.load_state_dict(lora_sd)
|
| 521 |
+
print(f"Loading extracted LoRA weights: {info}")
|
| 522 |
+
|
| 523 |
+
dir_name = os.path.dirname(save_to)
|
| 524 |
+
if dir_name and not os.path.exists(dir_name):
|
| 525 |
+
os.makedirs(dir_name, exist_ok=True)
|
| 526 |
+
|
| 527 |
+
# minimum metadata
|
| 528 |
+
metadata = {"ss_network_dim": str(dim), "ss_network_alpha": str(dim)}
|
| 529 |
+
|
| 530 |
+
lora_network_o.save_weights(save_to, save_dtype, metadata)
|
| 531 |
+
print(f"LoRA weights are saved to: {save_to}")
|
| 532 |
+
return save_to
|
| 533 |
+
|
| 534 |
+
def load_state_dict(file_name, dtype):
|
| 535 |
+
if os.path.splitext(file_name)[1] == '.safetensors':
|
| 536 |
+
sd = load_file(file_name)
|
| 537 |
+
else:
|
| 538 |
+
sd = torch.load(file_name, map_location='cpu')
|
| 539 |
+
for key in list(sd.keys()):
|
| 540 |
+
if type(sd[key]) == torch.Tensor:
|
| 541 |
+
sd[key] = sd[key].to(dtype)
|
| 542 |
+
return sd
|
| 543 |
+
|
| 544 |
+
def dimgetter(filename):
|
| 545 |
+
lora_sd = load_state_dict(filename, torch.float)
|
| 546 |
+
alpha = None
|
| 547 |
+
dim = None
|
| 548 |
+
type = None
|
| 549 |
+
|
| 550 |
+
if "lora_unet_down_blocks_0_resnets_0_conv1.lora_down.weight" in lora_sd.keys():
|
| 551 |
+
type = "LoCon"
|
| 552 |
+
|
| 553 |
+
for key, value in lora_sd.items():
|
| 554 |
+
|
| 555 |
+
if alpha is None and 'alpha' in key:
|
| 556 |
+
alpha = value
|
| 557 |
+
if dim is None and 'lora_down' in key and len(value.size()) == 2:
|
| 558 |
+
dim = value.size()[0]
|
| 559 |
+
if "hada_" in key:
|
| 560 |
+
dim,type = "LyCORIS","LyCORIS"
|
| 561 |
+
if alpha is not None and dim is not None:
|
| 562 |
+
break
|
| 563 |
+
if alpha is None:
|
| 564 |
+
alpha = dim
|
| 565 |
+
if type == None:type = "LoRA"
|
| 566 |
+
if dim :
|
| 567 |
+
return dim,type
|
| 568 |
+
else:
|
| 569 |
+
return "unknown","unknown"
|
| 570 |
+
|
| 571 |
+
def blockfromkey(key):
|
| 572 |
+
fullkey = convert_diffusers_name_to_compvis(key)
|
| 573 |
+
for i,n in enumerate(LORABLOCKS):
|
| 574 |
+
if n in fullkey: return i
|
| 575 |
+
return 0
|
| 576 |
+
|
| 577 |
+
def merge_lora_models_dim(models, ratios, new_rank,sets):
|
| 578 |
+
merged_sd = {}
|
| 579 |
+
fugou = 1
|
| 580 |
+
for model, ratios in zip(models, ratios):
|
| 581 |
+
merge_dtype = torch.float
|
| 582 |
+
lora_sd = load_state_dict(model, merge_dtype)
|
| 583 |
+
|
| 584 |
+
# merge
|
| 585 |
+
print(f"merging {model}: {ratios}")
|
| 586 |
+
for key in tqdm(list(lora_sd.keys())):
|
| 587 |
+
if 'lora_down' not in key:
|
| 588 |
+
continue
|
| 589 |
+
lora_module_name = key[:key.rfind(".lora_down")]
|
| 590 |
+
|
| 591 |
+
down_weight = lora_sd[key]
|
| 592 |
+
network_dim = down_weight.size()[0]
|
| 593 |
+
|
| 594 |
+
up_weight = lora_sd[lora_module_name + '.lora_up.weight']
|
| 595 |
+
alpha = lora_sd.get(lora_module_name + '.alpha', network_dim)
|
| 596 |
+
|
| 597 |
+
in_dim = down_weight.size()[1]
|
| 598 |
+
out_dim = up_weight.size()[0]
|
| 599 |
+
conv2d = len(down_weight.size()) == 4
|
| 600 |
+
# print(lora_module_name, network_dim, alpha, in_dim, out_dim)
|
| 601 |
+
|
| 602 |
+
# make original weight if not exist
|
| 603 |
+
if lora_module_name not in merged_sd:
|
| 604 |
+
weight = torch.zeros((out_dim, in_dim, 1, 1) if conv2d else (out_dim, in_dim), dtype=merge_dtype)
|
| 605 |
+
else:
|
| 606 |
+
weight = merged_sd[lora_module_name]
|
| 607 |
+
|
| 608 |
+
ratio = ratios[blockfromkey(key)]
|
| 609 |
+
if "same to Strength" in sets:
|
| 610 |
+
ratio, fugou = (ratio**0.5,1) if ratio > 0 else (abs(ratio)**0.5,-1)
|
| 611 |
+
#print(lora_module_name, ratio)
|
| 612 |
+
# W <- W + U * D
|
| 613 |
+
scale = (alpha / network_dim)
|
| 614 |
+
if not conv2d: # linear
|
| 615 |
+
weight = weight + ratio * (up_weight @ down_weight) * scale * fugou
|
| 616 |
+
else:
|
| 617 |
+
weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
|
| 618 |
+
).unsqueeze(2).unsqueeze(3) * scale * fugou
|
| 619 |
+
|
| 620 |
+
merged_sd[lora_module_name] = weight
|
| 621 |
+
|
| 622 |
+
# extract from merged weights
|
| 623 |
+
print("extract new lora...")
|
| 624 |
+
merged_lora_sd = {}
|
| 625 |
+
with torch.no_grad():
|
| 626 |
+
for lora_module_name, mat in tqdm(list(merged_sd.items())):
|
| 627 |
+
conv2d = (len(mat.size()) == 4)
|
| 628 |
+
if conv2d:
|
| 629 |
+
mat = mat.squeeze()
|
| 630 |
+
|
| 631 |
+
U, S, Vh = torch.linalg.svd(mat)
|
| 632 |
+
|
| 633 |
+
U = U[:, :new_rank]
|
| 634 |
+
S = S[:new_rank]
|
| 635 |
+
U = U @ torch.diag(S)
|
| 636 |
+
|
| 637 |
+
Vh = Vh[:new_rank, :]
|
| 638 |
+
|
| 639 |
+
dist = torch.cat([U.flatten(), Vh.flatten()])
|
| 640 |
+
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
|
| 641 |
+
low_val = -hi_val
|
| 642 |
+
|
| 643 |
+
U = U.clamp(low_val, hi_val)
|
| 644 |
+
Vh = Vh.clamp(low_val, hi_val)
|
| 645 |
+
|
| 646 |
+
up_weight = U
|
| 647 |
+
down_weight = Vh
|
| 648 |
+
|
| 649 |
+
if conv2d:
|
| 650 |
+
up_weight = up_weight.unsqueeze(2).unsqueeze(3)
|
| 651 |
+
down_weight = down_weight.unsqueeze(2).unsqueeze(3)
|
| 652 |
+
|
| 653 |
+
merged_lora_sd[lora_module_name + '.lora_up.weight'] = up_weight.to("cpu").contiguous()
|
| 654 |
+
merged_lora_sd[lora_module_name + '.lora_down.weight'] = down_weight.to("cpu").contiguous()
|
| 655 |
+
merged_lora_sd[lora_module_name + '.alpha'] = torch.tensor(new_rank)
|
| 656 |
+
|
| 657 |
+
return merged_lora_sd
|
| 658 |
+
|
| 659 |
+
def merge_lora_models(models, ratios,sets):
|
| 660 |
+
base_alphas = {} # alpha for merged model
|
| 661 |
+
base_dims = {}
|
| 662 |
+
merge_dtype = torch.float
|
| 663 |
+
merged_sd = {}
|
| 664 |
+
fugou = 1
|
| 665 |
+
for model, ratios in zip(models, ratios):
|
| 666 |
+
print(f"merging {model}: {ratios}")
|
| 667 |
+
lora_sd = load_state_dict(model, merge_dtype)
|
| 668 |
+
|
| 669 |
+
# get alpha and dim
|
| 670 |
+
alphas = {} # alpha for current model
|
| 671 |
+
dims = {} # dims for current model
|
| 672 |
+
for key in lora_sd.keys():
|
| 673 |
+
if 'alpha' in key:
|
| 674 |
+
lora_module_name = key[:key.rfind(".alpha")]
|
| 675 |
+
alpha = float(lora_sd[key].detach().numpy())
|
| 676 |
+
alphas[lora_module_name] = alpha
|
| 677 |
+
if lora_module_name not in base_alphas:
|
| 678 |
+
base_alphas[lora_module_name] = alpha
|
| 679 |
+
elif "lora_down" in key:
|
| 680 |
+
lora_module_name = key[:key.rfind(".lora_down")]
|
| 681 |
+
dim = lora_sd[key].size()[0]
|
| 682 |
+
dims[lora_module_name] = dim
|
| 683 |
+
if lora_module_name not in base_dims:
|
| 684 |
+
base_dims[lora_module_name] = dim
|
| 685 |
+
|
| 686 |
+
for lora_module_name in dims.keys():
|
| 687 |
+
if lora_module_name not in alphas:
|
| 688 |
+
alpha = dims[lora_module_name]
|
| 689 |
+
alphas[lora_module_name] = alpha
|
| 690 |
+
if lora_module_name not in base_alphas:
|
| 691 |
+
base_alphas[lora_module_name] = alpha
|
| 692 |
+
|
| 693 |
+
print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")
|
| 694 |
+
|
| 695 |
+
# merge
|
| 696 |
+
print(f"merging...")
|
| 697 |
+
for key in lora_sd.keys():
|
| 698 |
+
if 'alpha' in key:
|
| 699 |
+
continue
|
| 700 |
+
if "lora_down" in key: dwon = True
|
| 701 |
+
lora_module_name = key[:key.rfind(".lora_")]
|
| 702 |
+
|
| 703 |
+
base_alpha = base_alphas[lora_module_name]
|
| 704 |
+
alpha = alphas[lora_module_name]
|
| 705 |
+
|
| 706 |
+
ratio = ratios[blockfromkey(key)]
|
| 707 |
+
if "same to Strength" in sets:
|
| 708 |
+
ratio, fugou = (ratio**0.5,1) if ratio > 0 else (abs(ratio)**0.5,-1)
|
| 709 |
+
|
| 710 |
+
if "lora_down" in key:
|
| 711 |
+
ratio = ratio * fugou
|
| 712 |
+
|
| 713 |
+
scale = math.sqrt(alpha / base_alpha) * ratio
|
| 714 |
+
|
| 715 |
+
if key in merged_sd:
|
| 716 |
+
assert merged_sd[key].size() == lora_sd[key].size(
|
| 717 |
+
), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません"
|
| 718 |
+
merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
|
| 719 |
+
else:
|
| 720 |
+
merged_sd[key] = lora_sd[key] * scale
|
| 721 |
+
|
| 722 |
+
# set alpha to sd
|
| 723 |
+
for lora_module_name, alpha in base_alphas.items():
|
| 724 |
+
key = lora_module_name + ".alpha"
|
| 725 |
+
merged_sd[key] = torch.tensor(alpha)
|
| 726 |
+
|
| 727 |
+
print("merged model")
|
| 728 |
+
print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
|
| 729 |
+
|
| 730 |
+
return merged_sd
|
| 731 |
+
|
| 732 |
+
def fullpathfromname(name):
|
| 733 |
+
if hash == "" or hash ==[]: return ""
|
| 734 |
+
checkpoint_info = sd_models.get_closet_checkpoint_match(name)
|
| 735 |
+
return checkpoint_info.filename
|
| 736 |
+
|
| 737 |
+
def makeloraname(model_a,model_b):
|
| 738 |
+
model_a=filenamecutter(model_a)
|
| 739 |
+
model_b=filenamecutter(model_b)
|
| 740 |
+
return "lora_"+model_a+"-"+model_b
|
| 741 |
+
|
| 742 |
+
def lycomerge(filename,ratios):
|
| 743 |
+
sd = load_state_dict(filename, torch.float)
|
| 744 |
+
|
| 745 |
+
if len(ratios) == 17:
|
| 746 |
+
r0 = 1
|
| 747 |
+
ratios = [ratios[0]] + [r0] + ratios[1:3]+ [r0] + ratios[3:5]+[r0] + ratios[5:7]+[r0,r0,r0] + [ratios[7]] + [r0,r0,r0] + ratios[8:]
|
| 748 |
+
|
| 749 |
+
print("LyCORIS: " , ratios)
|
| 750 |
+
|
| 751 |
+
keys_failed_to_match = []
|
| 752 |
+
|
| 753 |
+
for lkey, weight in sd.items():
|
| 754 |
+
ratio = 1
|
| 755 |
+
picked = False
|
| 756 |
+
if 'alpha' in lkey:
|
| 757 |
+
continue
|
| 758 |
+
|
| 759 |
+
fullkey = convert_diffusers_name_to_compvis(lkey)
|
| 760 |
+
key, lora_key = fullkey.split(".", 1)
|
| 761 |
+
|
| 762 |
+
for i,block in enumerate(LYCOBLOCKS):
|
| 763 |
+
if block in key:
|
| 764 |
+
ratio = ratios[i]
|
| 765 |
+
picked = True
|
| 766 |
+
if not picked: keys_failed_to_match.append(key)
|
| 767 |
+
|
| 768 |
+
sd[lkey] = weight * math.sqrt(abs(float(ratio)))
|
| 769 |
+
|
| 770 |
+
if "down" in lkey and ratio < 0:
|
| 771 |
+
sd[key] = sd[key] * -1
|
| 772 |
+
|
| 773 |
+
if len(keys_failed_to_match) > 0:
|
| 774 |
+
print(keys_failed_to_match)
|
| 775 |
+
|
| 776 |
+
return sd
|
| 777 |
+
|
| 778 |
+
LORABLOCKS=["encoder",
|
| 779 |
+
"diffusion_model_input_blocks_1_",
|
| 780 |
+
"diffusion_model_input_blocks_2_",
|
| 781 |
+
"diffusion_model_input_blocks_4_",
|
| 782 |
+
"diffusion_model_input_blocks_5_",
|
| 783 |
+
"diffusion_model_input_blocks_7_",
|
| 784 |
+
"diffusion_model_input_blocks_8_",
|
| 785 |
+
"diffusion_model_middle_block_",
|
| 786 |
+
"diffusion_model_output_blocks_3_",
|
| 787 |
+
"diffusion_model_output_blocks_4_",
|
| 788 |
+
"diffusion_model_output_blocks_5_",
|
| 789 |
+
"diffusion_model_output_blocks_6_",
|
| 790 |
+
"diffusion_model_output_blocks_7_",
|
| 791 |
+
"diffusion_model_output_blocks_8_",
|
| 792 |
+
"diffusion_model_output_blocks_9_",
|
| 793 |
+
"diffusion_model_output_blocks_10_",
|
| 794 |
+
"diffusion_model_output_blocks_11_"]
|
| 795 |
+
|
| 796 |
+
LYCOBLOCKS=["encoder",
|
| 797 |
+
"diffusion_model_input_blocks_0_",
|
| 798 |
+
"diffusion_model_input_blocks_1_",
|
| 799 |
+
"diffusion_model_input_blocks_2_",
|
| 800 |
+
"diffusion_model_input_blocks_3_",
|
| 801 |
+
"diffusion_model_input_blocks_4_",
|
| 802 |
+
"diffusion_model_input_blocks_5_",
|
| 803 |
+
"diffusion_model_input_blocks_6_",
|
| 804 |
+
"diffusion_model_input_blocks_7_",
|
| 805 |
+
"diffusion_model_input_blocks_8_",
|
| 806 |
+
"diffusion_model_input_blocks_9_",
|
| 807 |
+
"diffusion_model_input_blocks_10_",
|
| 808 |
+
"diffusion_model_input_blocks_11_",
|
| 809 |
+
"diffusion_model_middle_block_",
|
| 810 |
+
"diffusion_model_output_blocks_0_",
|
| 811 |
+
"diffusion_model_output_blocks_1_",
|
| 812 |
+
"diffusion_model_output_blocks_2_",
|
| 813 |
+
"diffusion_model_output_blocks_3_",
|
| 814 |
+
"diffusion_model_output_blocks_4_",
|
| 815 |
+
"diffusion_model_output_blocks_5_",
|
| 816 |
+
"diffusion_model_output_blocks_6_",
|
| 817 |
+
"diffusion_model_output_blocks_7_",
|
| 818 |
+
"diffusion_model_output_blocks_8_",
|
| 819 |
+
"diffusion_model_output_blocks_9_",
|
| 820 |
+
"diffusion_model_output_blocks_10_",
|
| 821 |
+
"diffusion_model_output_blocks_11_"]
|
| 822 |
+
|
| 823 |
+
class LoRAModule(torch.nn.Module):
|
| 824 |
+
"""
|
| 825 |
+
replaces forward method of the original Linear, instead of replacing the original Linear module.
|
| 826 |
+
"""
|
| 827 |
+
|
| 828 |
+
def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1):
|
| 829 |
+
"""if alpha == 0 or None, alpha is rank (no scaling)."""
|
| 830 |
+
super().__init__()
|
| 831 |
+
self.lora_name = lora_name
|
| 832 |
+
|
| 833 |
+
if org_module.__class__.__name__ == "Conv2d":
|
| 834 |
+
in_dim = org_module.in_channels
|
| 835 |
+
out_dim = org_module.out_channels
|
| 836 |
+
else:
|
| 837 |
+
in_dim = org_module.in_features
|
| 838 |
+
out_dim = org_module.out_features
|
| 839 |
+
|
| 840 |
+
# if limit_rank:
|
| 841 |
+
# self.lora_dim = min(lora_dim, in_dim, out_dim)
|
| 842 |
+
# if self.lora_dim != lora_dim:
|
| 843 |
+
# print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
|
| 844 |
+
# else:
|
| 845 |
+
self.lora_dim = lora_dim
|
| 846 |
+
|
| 847 |
+
if org_module.__class__.__name__ == "Conv2d":
|
| 848 |
+
kernel_size = org_module.kernel_size
|
| 849 |
+
stride = org_module.stride
|
| 850 |
+
padding = org_module.padding
|
| 851 |
+
self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
|
| 852 |
+
self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
|
| 853 |
+
else:
|
| 854 |
+
self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
|
| 855 |
+
self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
|
| 856 |
+
|
| 857 |
+
if type(alpha) == torch.Tensor:
|
| 858 |
+
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
|
| 859 |
+
alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
|
| 860 |
+
self.scale = alpha / self.lora_dim
|
| 861 |
+
self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える
|
| 862 |
+
|
| 863 |
+
# same as microsoft's
|
| 864 |
+
torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
|
| 865 |
+
torch.nn.init.zeros_(self.lora_up.weight)
|
| 866 |
+
|
| 867 |
+
self.multiplier = multiplier
|
| 868 |
+
self.org_module = org_module # remove in applying
|
| 869 |
+
self.region = None
|
| 870 |
+
self.region_mask = None
|
| 871 |
+
|
| 872 |
+
def apply_to(self):
|
| 873 |
+
self.org_forward = self.org_module.forward
|
| 874 |
+
self.org_module.forward = self.forward
|
| 875 |
+
del self.org_module
|
| 876 |
+
|
| 877 |
+
def merge_to(self, sd, dtype, device):
|
| 878 |
+
# get up/down weight
|
| 879 |
+
up_weight = sd["lora_up.weight"].to(torch.float).to(device)
|
| 880 |
+
down_weight = sd["lora_down.weight"].to(torch.float).to(device)
|
| 881 |
+
|
| 882 |
+
# extract weight from org_module
|
| 883 |
+
org_sd = self.org_module.state_dict()
|
| 884 |
+
weight = org_sd["weight"].to(torch.float)
|
| 885 |
+
|
| 886 |
+
# merge weight
|
| 887 |
+
if len(weight.size()) == 2:
|
| 888 |
+
# linear
|
| 889 |
+
weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale
|
| 890 |
+
elif down_weight.size()[2:4] == (1, 1):
|
| 891 |
+
# conv2d 1x1
|
| 892 |
+
weight = (
|
| 893 |
+
weight
|
| 894 |
+
+ self.multiplier
|
| 895 |
+
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
| 896 |
+
* self.scale
|
| 897 |
+
)
|
| 898 |
+
else:
|
| 899 |
+
# conv2d 3x3
|
| 900 |
+
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
|
| 901 |
+
# print(conved.size(), weight.size(), module.stride, module.padding)
|
| 902 |
+
weight = weight + self.multiplier * conved * self.scale
|
| 903 |
+
|
| 904 |
+
# set weight to org_module
|
| 905 |
+
org_sd["weight"] = weight.to(dtype)
|
| 906 |
+
self.org_module.load_state_dict(org_sd)
|
| 907 |
+
|
| 908 |
+
def set_region(self, region):
|
| 909 |
+
self.region = region
|
| 910 |
+
self.region_mask = None
|
| 911 |
+
|
| 912 |
+
def forward(self, x):
|
| 913 |
+
if self.region is None:
|
| 914 |
+
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
| 915 |
+
|
| 916 |
+
# regional LoRA FIXME same as additional-network extension
|
| 917 |
+
if x.size()[1] % 77 == 0:
|
| 918 |
+
# print(f"LoRA for context: {self.lora_name}")
|
| 919 |
+
self.region = None
|
| 920 |
+
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
| 921 |
+
|
| 922 |
+
# calculate region mask first time
|
| 923 |
+
if self.region_mask is None:
|
| 924 |
+
if len(x.size()) == 4:
|
| 925 |
+
h, w = x.size()[2:4]
|
| 926 |
+
else:
|
| 927 |
+
seq_len = x.size()[1]
|
| 928 |
+
ratio = math.sqrt((self.region.size()[0] * self.region.size()[1]) / seq_len)
|
| 929 |
+
h = int(self.region.size()[0] / ratio + 0.5)
|
| 930 |
+
w = seq_len // h
|
| 931 |
+
|
| 932 |
+
r = self.region.to(x.device)
|
| 933 |
+
if r.dtype == torch.bfloat16:
|
| 934 |
+
r = r.to(torch.float)
|
| 935 |
+
r = r.unsqueeze(0).unsqueeze(1)
|
| 936 |
+
# print(self.lora_name, self.region.size(), x.size(), r.size(), h, w)
|
| 937 |
+
r = torch.nn.functional.interpolate(r, (h, w), mode="bilinear")
|
| 938 |
+
r = r.to(x.dtype)
|
| 939 |
+
|
| 940 |
+
if len(x.size()) == 3:
|
| 941 |
+
r = torch.reshape(r, (1, x.size()[1], -1))
|
| 942 |
+
|
| 943 |
+
self.region_mask = r
|
| 944 |
+
|
| 945 |
+
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale * self.region_mask
|
| 946 |
+
|
| 947 |
+
def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
|
| 948 |
+
if network_dim is None:
|
| 949 |
+
network_dim = 4 # default
|
| 950 |
+
|
| 951 |
+
# extract dim/alpha for conv2d, and block dim
|
| 952 |
+
conv_dim = kwargs.get("conv_dim", None)
|
| 953 |
+
conv_alpha = kwargs.get("conv_alpha", None)
|
| 954 |
+
if conv_dim is not None:
|
| 955 |
+
conv_dim = int(conv_dim)
|
| 956 |
+
if conv_alpha is None:
|
| 957 |
+
conv_alpha = 1.0
|
| 958 |
+
else:
|
| 959 |
+
conv_alpha = float(conv_alpha)
|
| 960 |
+
|
| 961 |
+
"""
|
| 962 |
+
block_dims = kwargs.get("block_dims")
|
| 963 |
+
block_alphas = None
|
| 964 |
+
if block_dims is not None:
|
| 965 |
+
block_dims = [int(d) for d in block_dims.split(',')]
|
| 966 |
+
assert len(block_dims) == NUM_BLOCKS, f"Number of block dimensions is not same to {NUM_BLOCKS}"
|
| 967 |
+
block_alphas = kwargs.get("block_alphas")
|
| 968 |
+
if block_alphas is None:
|
| 969 |
+
block_alphas = [1] * len(block_dims)
|
| 970 |
+
else:
|
| 971 |
+
block_alphas = [int(a) for a in block_alphas(',')]
|
| 972 |
+
assert len(block_alphas) == NUM_BLOCKS, f"Number of block alphas is not same to {NUM_BLOCKS}"
|
| 973 |
+
conv_block_dims = kwargs.get("conv_block_dims")
|
| 974 |
+
conv_block_alphas = None
|
| 975 |
+
if conv_block_dims is not None:
|
| 976 |
+
conv_block_dims = [int(d) for d in conv_block_dims.split(',')]
|
| 977 |
+
assert len(conv_block_dims) == NUM_BLOCKS, f"Number of block dimensions is not same to {NUM_BLOCKS}"
|
| 978 |
+
conv_block_alphas = kwargs.get("conv_block_alphas")
|
| 979 |
+
if conv_block_alphas is None:
|
| 980 |
+
conv_block_alphas = [1] * len(conv_block_dims)
|
| 981 |
+
else:
|
| 982 |
+
conv_block_alphas = [int(a) for a in conv_block_alphas(',')]
|
| 983 |
+
assert len(conv_block_alphas) == NUM_BLOCKS, f"Number of block alphas is not same to {NUM_BLOCKS}"
|
| 984 |
+
"""
|
| 985 |
+
|
| 986 |
+
network = LoRANetwork(
|
| 987 |
+
text_encoder,
|
| 988 |
+
unet,
|
| 989 |
+
multiplier=multiplier,
|
| 990 |
+
lora_dim=network_dim,
|
| 991 |
+
alpha=network_alpha,
|
| 992 |
+
conv_lora_dim=conv_dim,
|
| 993 |
+
conv_alpha=conv_alpha,
|
| 994 |
+
)
|
| 995 |
+
return network
|
| 996 |
+
|
| 997 |
+
|
| 998 |
+
|
| 999 |
+
class LoRANetwork(torch.nn.Module):
|
| 1000 |
+
# is it possible to apply conv_in and conv_out?
|
| 1001 |
+
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
|
| 1002 |
+
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
|
| 1003 |
+
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
|
| 1004 |
+
LORA_PREFIX_UNET = "lora_unet"
|
| 1005 |
+
LORA_PREFIX_TEXT_ENCODER = "lora_te"
|
| 1006 |
+
|
| 1007 |
+
def __init__(
|
| 1008 |
+
self,
|
| 1009 |
+
text_encoder,
|
| 1010 |
+
unet,
|
| 1011 |
+
multiplier=1.0,
|
| 1012 |
+
lora_dim=4,
|
| 1013 |
+
alpha=1,
|
| 1014 |
+
conv_lora_dim=None,
|
| 1015 |
+
conv_alpha=None,
|
| 1016 |
+
modules_dim=None,
|
| 1017 |
+
modules_alpha=None,
|
| 1018 |
+
) -> None:
|
| 1019 |
+
super().__init__()
|
| 1020 |
+
self.multiplier = multiplier
|
| 1021 |
+
|
| 1022 |
+
self.lora_dim = lora_dim
|
| 1023 |
+
self.alpha = alpha
|
| 1024 |
+
self.conv_lora_dim = conv_lora_dim
|
| 1025 |
+
self.conv_alpha = conv_alpha
|
| 1026 |
+
|
| 1027 |
+
if modules_dim is not None:
|
| 1028 |
+
print(f"create LoRA network from weights")
|
| 1029 |
+
else:
|
| 1030 |
+
print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
|
| 1031 |
+
|
| 1032 |
+
self.apply_to_conv2d_3x3 = self.conv_lora_dim is not None
|
| 1033 |
+
if self.apply_to_conv2d_3x3:
|
| 1034 |
+
if self.conv_alpha is None:
|
| 1035 |
+
self.conv_alpha = self.alpha
|
| 1036 |
+
print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
|
| 1037 |
+
|
| 1038 |
+
# create module instances
|
| 1039 |
+
def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]:
|
| 1040 |
+
loras = []
|
| 1041 |
+
for name, module in root_module.named_modules():
|
| 1042 |
+
if module.__class__.__name__ in target_replace_modules:
|
| 1043 |
+
# TODO get block index here
|
| 1044 |
+
for child_name, child_module in module.named_modules():
|
| 1045 |
+
is_linear = child_module.__class__.__name__ == "Linear"
|
| 1046 |
+
is_conv2d = child_module.__class__.__name__ == "Conv2d"
|
| 1047 |
+
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
|
| 1048 |
+
if is_linear or is_conv2d:
|
| 1049 |
+
lora_name = prefix + "." + name + "." + child_name
|
| 1050 |
+
lora_name = lora_name.replace(".", "_")
|
| 1051 |
+
|
| 1052 |
+
if modules_dim is not None:
|
| 1053 |
+
if lora_name not in modules_dim:
|
| 1054 |
+
continue # no LoRA module in this weights file
|
| 1055 |
+
dim = modules_dim[lora_name]
|
| 1056 |
+
alpha = modules_alpha[lora_name]
|
| 1057 |
+
else:
|
| 1058 |
+
if is_linear or is_conv2d_1x1:
|
| 1059 |
+
dim = self.lora_dim
|
| 1060 |
+
alpha = self.alpha
|
| 1061 |
+
elif self.apply_to_conv2d_3x3:
|
| 1062 |
+
dim = self.conv_lora_dim
|
| 1063 |
+
alpha = self.conv_alpha
|
| 1064 |
+
else:
|
| 1065 |
+
continue
|
| 1066 |
+
|
| 1067 |
+
lora = LoRAModule(lora_name, child_module, self.multiplier, dim, alpha)
|
| 1068 |
+
loras.append(lora)
|
| 1069 |
+
return loras
|
| 1070 |
+
|
| 1071 |
+
self.text_encoder_loras = create_modules(
|
| 1072 |
+
LoRANetwork.LORA_PREFIX_TEXT_ENCODER, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
|
| 1073 |
+
)
|
| 1074 |
+
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
| 1075 |
+
|
| 1076 |
+
# extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
|
| 1077 |
+
target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE
|
| 1078 |
+
if modules_dim is not None or self.conv_lora_dim is not None:
|
| 1079 |
+
target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
| 1080 |
+
|
| 1081 |
+
self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, target_modules)
|
| 1082 |
+
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
|
| 1083 |
+
|
| 1084 |
+
self.weights_sd = None
|
| 1085 |
+
|
| 1086 |
+
# assertion
|
| 1087 |
+
names = set()
|
| 1088 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
| 1089 |
+
assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
|
| 1090 |
+
names.add(lora.lora_name)
|
| 1091 |
+
|
| 1092 |
+
def set_multiplier(self, multiplier):
|
| 1093 |
+
self.multiplier = multiplier
|
| 1094 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
| 1095 |
+
lora.multiplier = self.multiplier
|
| 1096 |
+
|
| 1097 |
+
def load_weights(self, file):
|
| 1098 |
+
if os.path.splitext(file)[1] == ".safetensors":
|
| 1099 |
+
from safetensors.torch import load_file, safe_open
|
| 1100 |
+
|
| 1101 |
+
self.weights_sd = load_file(file)
|
| 1102 |
+
else:
|
| 1103 |
+
self.weights_sd = torch.load(file, map_location="cpu")
|
| 1104 |
+
|
| 1105 |
+
def apply_to(self, text_encoder, unet, apply_text_encoder=None, apply_unet=None):
|
| 1106 |
+
if self.weights_sd:
|
| 1107 |
+
weights_has_text_encoder = weights_has_unet = False
|
| 1108 |
+
for key in self.weights_sd.keys():
|
| 1109 |
+
if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
|
| 1110 |
+
weights_has_text_encoder = True
|
| 1111 |
+
elif key.startswith(LoRANetwork.LORA_PREFIX_UNET):
|
| 1112 |
+
weights_has_unet = True
|
| 1113 |
+
|
| 1114 |
+
if apply_text_encoder is None:
|
| 1115 |
+
apply_text_encoder = weights_has_text_encoder
|
| 1116 |
+
else:
|
| 1117 |
+
assert (
|
| 1118 |
+
apply_text_encoder == weights_has_text_encoder
|
| 1119 |
+
), f"text encoder weights: {weights_has_text_encoder} but text encoder flag: {apply_text_encoder} / 重みとText Encoderのフラグが矛盾しています"
|
| 1120 |
+
|
| 1121 |
+
if apply_unet is None:
|
| 1122 |
+
apply_unet = weights_has_unet
|
| 1123 |
+
else:
|
| 1124 |
+
assert (
|
| 1125 |
+
apply_unet == weights_has_unet
|
| 1126 |
+
), f"u-net weights: {weights_has_unet} but u-net flag: {apply_unet} / 重みとU-Netのフラグが矛盾しています"
|
| 1127 |
+
else:
|
| 1128 |
+
assert apply_text_encoder is not None and apply_unet is not None, f"internal error: flag not set"
|
| 1129 |
+
|
| 1130 |
+
if apply_text_encoder:
|
| 1131 |
+
print("enable LoRA for text encoder")
|
| 1132 |
+
else:
|
| 1133 |
+
self.text_encoder_loras = []
|
| 1134 |
+
|
| 1135 |
+
if apply_unet:
|
| 1136 |
+
print("enable LoRA for U-Net")
|
| 1137 |
+
else:
|
| 1138 |
+
self.unet_loras = []
|
| 1139 |
+
|
| 1140 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
| 1141 |
+
lora.apply_to()
|
| 1142 |
+
self.add_module(lora.lora_name, lora)
|
| 1143 |
+
|
| 1144 |
+
if self.weights_sd:
|
| 1145 |
+
# if some weights are not in state dict, it is ok because initial LoRA does nothing (lora_up is initialized by zeros)
|
| 1146 |
+
info = self.load_state_dict(self.weights_sd, False)
|
| 1147 |
+
print(f"weights are loaded: {info}")
|
| 1148 |
+
|
| 1149 |
+
# TODO refactor to common function with apply_to
|
| 1150 |
+
def merge_to(self, text_encoder, unet, dtype, device):
|
| 1151 |
+
assert self.weights_sd is not None, "weights are not loaded"
|
| 1152 |
+
|
| 1153 |
+
apply_text_encoder = apply_unet = False
|
| 1154 |
+
for key in self.weights_sd.keys():
|
| 1155 |
+
if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
|
| 1156 |
+
apply_text_encoder = True
|
| 1157 |
+
elif key.startswith(LoRANetwork.LORA_PREFIX_UNET):
|
| 1158 |
+
apply_unet = True
|
| 1159 |
+
|
| 1160 |
+
if apply_text_encoder:
|
| 1161 |
+
print("enable LoRA for text encoder")
|
| 1162 |
+
else:
|
| 1163 |
+
self.text_encoder_loras = []
|
| 1164 |
+
|
| 1165 |
+
if apply_unet:
|
| 1166 |
+
print("enable LoRA for U-Net")
|
| 1167 |
+
else:
|
| 1168 |
+
self.unet_loras = []
|
| 1169 |
+
|
| 1170 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
| 1171 |
+
sd_for_lora = {}
|
| 1172 |
+
for key in self.weights_sd.keys():
|
| 1173 |
+
if key.startswith(lora.lora_name):
|
| 1174 |
+
sd_for_lora[key[len(lora.lora_name) + 1 :]] = self.weights_sd[key]
|
| 1175 |
+
lora.merge_to(sd_for_lora, dtype, device)
|
| 1176 |
+
print(f"weights are merged")
|
| 1177 |
+
|
| 1178 |
+
def enable_gradient_checkpointing(self):
|
| 1179 |
+
# not supported
|
| 1180 |
+
pass
|
| 1181 |
+
|
| 1182 |
+
def prepare_optimizer_params(self, text_encoder_lr, unet_lr):
|
| 1183 |
+
def enumerate_params(loras):
|
| 1184 |
+
params = []
|
| 1185 |
+
for lora in loras:
|
| 1186 |
+
params.extend(lora.parameters())
|
| 1187 |
+
return params
|
| 1188 |
+
|
| 1189 |
+
self.requires_grad_(True)
|
| 1190 |
+
all_params = []
|
| 1191 |
+
|
| 1192 |
+
if self.text_encoder_loras:
|
| 1193 |
+
param_data = {"params": enumerate_params(self.text_encoder_loras)}
|
| 1194 |
+
if text_encoder_lr is not None:
|
| 1195 |
+
param_data["lr"] = text_encoder_lr
|
| 1196 |
+
all_params.append(param_data)
|
| 1197 |
+
|
| 1198 |
+
if self.unet_loras:
|
| 1199 |
+
param_data = {"params": enumerate_params(self.unet_loras)}
|
| 1200 |
+
if unet_lr is not None:
|
| 1201 |
+
param_data["lr"] = unet_lr
|
| 1202 |
+
all_params.append(param_data)
|
| 1203 |
+
|
| 1204 |
+
return all_params
|
| 1205 |
+
|
| 1206 |
+
def prepare_grad_etc(self, text_encoder, unet):
|
| 1207 |
+
self.requires_grad_(True)
|
| 1208 |
+
|
| 1209 |
+
def on_epoch_start(self, text_encoder, unet):
|
| 1210 |
+
self.train()
|
| 1211 |
+
|
| 1212 |
+
def get_trainable_params(self):
|
| 1213 |
+
return self.parameters()
|
| 1214 |
+
|
| 1215 |
+
def save_weights(self, file, dtype, metadata):
|
| 1216 |
+
if metadata is not None and len(metadata) == 0:
|
| 1217 |
+
metadata = None
|
| 1218 |
+
|
| 1219 |
+
state_dict = self.state_dict()
|
| 1220 |
+
|
| 1221 |
+
if dtype is not None:
|
| 1222 |
+
for key in list(state_dict.keys()):
|
| 1223 |
+
v = state_dict[key]
|
| 1224 |
+
v = v.detach().clone().to("cpu").to(dtype)
|
| 1225 |
+
state_dict[key] = v
|
| 1226 |
+
|
| 1227 |
+
if os.path.splitext(file)[1] == ".safetensors":
|
| 1228 |
+
from safetensors.torch import save_file
|
| 1229 |
+
|
| 1230 |
+
# Precalculate model hashes to save time on indexing
|
| 1231 |
+
if metadata is None:
|
| 1232 |
+
metadata = {}
|
| 1233 |
+
model_hash, legacy_hash = precalculate_safetensors_hashes(state_dict, metadata)
|
| 1234 |
+
metadata["sshs_model_hash"] = model_hash
|
| 1235 |
+
metadata["sshs_legacy_hash"] = legacy_hash
|
| 1236 |
+
|
| 1237 |
+
save_file(state_dict, file, metadata)
|
| 1238 |
+
else:
|
| 1239 |
+
torch.save(state_dict, file)
|
| 1240 |
+
|
| 1241 |
+
@staticmethod
|
| 1242 |
+
def set_regions(networks, image):
|
| 1243 |
+
image = image.astype(np.float32) / 255.0
|
| 1244 |
+
for i, network in enumerate(networks[:3]):
|
| 1245 |
+
# NOTE: consider averaging overwrapping area
|
| 1246 |
+
region = image[:, :, i]
|
| 1247 |
+
if region.max() == 0:
|
| 1248 |
+
continue
|
| 1249 |
+
region = torch.tensor(region)
|
| 1250 |
+
network.set_region(region)
|
| 1251 |
+
|
| 1252 |
+
def set_region(self, region):
|
| 1253 |
+
for lora in self.unet_loras:
|
| 1254 |
+
lora.set_region(region)
|
| 1255 |
+
|
| 1256 |
+
from io import BytesIO
|
| 1257 |
+
import safetensors.torch
|
| 1258 |
+
import hashlib
|
| 1259 |
+
|
| 1260 |
+
def precalculate_safetensors_hashes(tensors, metadata):
|
| 1261 |
+
"""Precalculate the model hashes needed by sd-webui-additional-networks to
|
| 1262 |
+
save time on indexing the model later."""
|
| 1263 |
+
|
| 1264 |
+
# Because writing user metadata to the file can change the result of
|
| 1265 |
+
# sd_models.model_hash(), only retain the training metadata for purposes of
|
| 1266 |
+
# calculating the hash, as they are meant to be immutable
|
| 1267 |
+
metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")}
|
| 1268 |
+
|
| 1269 |
+
bytes = safetensors.torch.save(tensors, metadata)
|
| 1270 |
+
b = BytesIO(bytes)
|
| 1271 |
+
|
| 1272 |
+
model_hash = addnet_hash_safetensors(b)
|
| 1273 |
+
legacy_hash = addnet_hash_legacy(b)
|
| 1274 |
+
return model_hash, legacy_hash
|
| 1275 |
+
|
| 1276 |
+
def addnet_hash_safetensors(b):
|
| 1277 |
+
"""New model hash used by sd-webui-additional-networks for .safetensors format files"""
|
| 1278 |
+
hash_sha256 = hashlib.sha256()
|
| 1279 |
+
blksize = 1024 * 1024
|
| 1280 |
+
|
| 1281 |
+
b.seek(0)
|
| 1282 |
+
header = b.read(8)
|
| 1283 |
+
n = int.from_bytes(header, "little")
|
| 1284 |
+
|
| 1285 |
+
offset = n + 8
|
| 1286 |
+
b.seek(offset)
|
| 1287 |
+
for chunk in iter(lambda: b.read(blksize), b""):
|
| 1288 |
+
hash_sha256.update(chunk)
|
| 1289 |
+
|
| 1290 |
+
return hash_sha256.hexdigest()
|
| 1291 |
+
|
| 1292 |
+
def addnet_hash_legacy(b):
|
| 1293 |
+
"""Old model hash used by sd-webui-additional-networks for .safetensors format files"""
|
| 1294 |
+
m = hashlib.sha256()
|
| 1295 |
+
|
| 1296 |
+
b.seek(0x100000)
|
| 1297 |
+
m.update(b.read(0x10000))
|
| 1298 |
+
return m.hexdigest()[0:8]
|
microsoftexcel-supermerger/scripts/mergers/xyplot.py
ADDED
|
@@ -0,0 +1,513 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
import cv2
|
| 3 |
+
import numpy as np
|
| 4 |
+
import os
|
| 5 |
+
import copy
|
| 6 |
+
import csv
|
| 7 |
+
from PIL import Image
|
| 8 |
+
from modules import images
|
| 9 |
+
from modules.shared import opts
|
| 10 |
+
from scripts.mergers.mergers import TYPES,smerge,simggen,filenamecutter,draw_origin,wpreseter
|
| 11 |
+
from scripts.mergers.model_util import usemodelgen
|
| 12 |
+
|
| 13 |
+
hear = True
|
| 14 |
+
hearm = False
|
| 15 |
+
|
| 16 |
+
state_mergen = False
|
| 17 |
+
|
| 18 |
+
numadepth = []
|
| 19 |
+
|
| 20 |
+
def freezetime():
|
| 21 |
+
global state_mergen
|
| 22 |
+
state_mergen = True
|
| 23 |
+
|
| 24 |
+
def numanager(normalstart,xtype,xmen,ytype,ymen,esettings,
|
| 25 |
+
weights_a,weights_b,model_a,model_b,model_c,alpha,beta,mode,calcmode,useblocks,custom_name,save_sets,id_sets,wpresets,deep,tensor,
|
| 26 |
+
prompt,nprompt,steps,sampler,cfg,seed,w,h,
|
| 27 |
+
hireson,hrupscaler,hr2ndsteps,denoise_str,hr_scale,batch_size):
|
| 28 |
+
global numadepth
|
| 29 |
+
grids = []
|
| 30 |
+
sep = "|"
|
| 31 |
+
|
| 32 |
+
if sep in xmen:
|
| 33 |
+
xmens = xmen.split(sep)
|
| 34 |
+
xmen = xmens[0]
|
| 35 |
+
if seed =="-1": seed = str(random.randrange(4294967294))
|
| 36 |
+
for men in xmens[1:]:
|
| 37 |
+
numaker(xtype,men,ytype,ymen,esettings,
|
| 38 |
+
weights_a,weights_b,model_a,model_b,model_c,alpha,beta,mode,calcmode,useblocks,custom_name,save_sets,id_sets,wpresets,deep,tensor,
|
| 39 |
+
prompt,nprompt,steps,sampler,cfg,seed,w,h,
|
| 40 |
+
hireson,hrupscaler,hr2ndsteps,denoise_str,hr_scale,batch_size)
|
| 41 |
+
elif sep in ymen:
|
| 42 |
+
ymens = ymen.split(sep)
|
| 43 |
+
ymen = ymens[0]
|
| 44 |
+
if seed =="-1": seed = str(random.randrange(4294967294))
|
| 45 |
+
for men in ymens[1:]:
|
| 46 |
+
numaker(xtype,xmen,ytype,men,esettings,
|
| 47 |
+
weights_a,weights_b,model_a,model_b,model_c,alpha,beta,mode,calcmode,useblocks,custom_name,save_sets,id_sets,wpresets,deep,tensor,
|
| 48 |
+
prompt,nprompt,steps,sampler,cfg,seed,w,h,
|
| 49 |
+
hireson,hrupscaler,hr2ndsteps,denoise_str,hr_scale,batch_size)
|
| 50 |
+
|
| 51 |
+
if normalstart:
|
| 52 |
+
result,currentmodel,xyimage,a,b,c= sgenxyplot(xtype,xmen,ytype,ymen,esettings,
|
| 53 |
+
weights_a,weights_b,model_a,model_b,model_c,alpha,beta,mode,calcmode,
|
| 54 |
+
useblocks,custom_name,save_sets,id_sets,wpresets,deep,tensor,
|
| 55 |
+
prompt,nprompt,steps,sampler,cfg,seed,w,h,
|
| 56 |
+
hireson,hrupscaler,hr2ndsteps,denoise_str,hr_scale,batch_size)
|
| 57 |
+
if xyimage is not None:grids =[xyimage[0]]
|
| 58 |
+
else:print(result)
|
| 59 |
+
else:
|
| 60 |
+
if numadepth ==[]:
|
| 61 |
+
return "no reservation",*[None]*5
|
| 62 |
+
result=currentmodel=xyimage=a=b=c = None
|
| 63 |
+
|
| 64 |
+
while True:
|
| 65 |
+
for i,row in enumerate(numadepth):
|
| 66 |
+
if row[1] =="waiting":
|
| 67 |
+
numadepth[i][1] = "Operating"
|
| 68 |
+
try:
|
| 69 |
+
result,currentmodel,xyimage,a,b,c = sgenxyplot(*row[2:])
|
| 70 |
+
except Exception as e:
|
| 71 |
+
print(e)
|
| 72 |
+
numadepth[i][1] = "Error"
|
| 73 |
+
else:
|
| 74 |
+
if xyimage is not None:
|
| 75 |
+
grids.append(xyimage[0])
|
| 76 |
+
numadepth[i][1] = "Finished"
|
| 77 |
+
else:
|
| 78 |
+
print(result)
|
| 79 |
+
numadepth[i][1] = "Error"
|
| 80 |
+
wcounter = 0
|
| 81 |
+
for row in numadepth:
|
| 82 |
+
if row[1] != "waiting":
|
| 83 |
+
wcounter += 1
|
| 84 |
+
if wcounter == len(numadepth):
|
| 85 |
+
break
|
| 86 |
+
|
| 87 |
+
return result,currentmodel,grids,a,b,c
|
| 88 |
+
|
| 89 |
+
def numaker(xtype,xmen,ytype,ymen,esettings,
|
| 90 |
+
#msettings=[weights_a,weights_b,model_a,model_b,model_c,alpha,beta,mode,calcmode,useblocks,custom_name,save_sets,id_sets,wpresets]
|
| 91 |
+
weights_a,weights_b,model_a,model_b,model_c,alpha,beta,mode,calcmode,
|
| 92 |
+
useblocks,custom_name,save_sets,id_sets,wpresets,deep,tensor,
|
| 93 |
+
prompt,nprompt,steps,sampler,cfg,seed,w,h,
|
| 94 |
+
hireson,hrupscaler,hr2ndsteps,denoise_str,hr_scale,batch_size):
|
| 95 |
+
global numadepth
|
| 96 |
+
numadepth.append([len(numadepth)+1,"waiting",xtype,xmen,ytype,ymen,esettings,
|
| 97 |
+
weights_a,weights_b,model_a,model_b,model_c,alpha,beta,mode,calcmode,
|
| 98 |
+
useblocks,custom_name,save_sets,id_sets,wpresets,deep,tensor,
|
| 99 |
+
prompt,nprompt,steps,sampler,cfg,seed,w,h,
|
| 100 |
+
hireson,hrupscaler,hr2ndsteps,denoise_str,hr_scale,batch_size])
|
| 101 |
+
return numalistmaker(copy.deepcopy(numadepth))
|
| 102 |
+
|
| 103 |
+
def nulister(redel):
|
| 104 |
+
global numadepth
|
| 105 |
+
if redel == False:
|
| 106 |
+
return numalistmaker(copy.deepcopy(numadepth))
|
| 107 |
+
if redel ==-1:
|
| 108 |
+
numadepth = []
|
| 109 |
+
else:
|
| 110 |
+
try:del numadepth[int(redel-1)]
|
| 111 |
+
except Exception as e:print(e)
|
| 112 |
+
return numalistmaker(copy.deepcopy(numadepth))
|
| 113 |
+
|
| 114 |
+
def numalistmaker(numa):
|
| 115 |
+
if numa ==[]: return [["no data","",""],]
|
| 116 |
+
for i,r in enumerate(numa):
|
| 117 |
+
r[2] = TYPES[int(r[2])]
|
| 118 |
+
r[4] = TYPES[int(r[4])]
|
| 119 |
+
numa[i] = r[0:6]+r[8:11]+r[12:16]+r[6:8]
|
| 120 |
+
return numa
|
| 121 |
+
|
| 122 |
+
def caster(news,hear):
|
| 123 |
+
if hear: print(news)
|
| 124 |
+
|
| 125 |
+
def sgenxyplot(xtype,xmen,ytype,ymen,esettings,
|
| 126 |
+
weights_a,weights_b,model_a,model_b,model_c,alpha,beta,mode,calcmode,
|
| 127 |
+
useblocks,custom_name,save_sets,id_sets,wpresets,deep,tensor,
|
| 128 |
+
prompt,nprompt,steps,sampler,cfg,seed,w,h,
|
| 129 |
+
hireson,hrupscaler,hr2ndsteps,denoise_str,hr_scale,batch_size):
|
| 130 |
+
global hear
|
| 131 |
+
esettings = " ".join(esettings)
|
| 132 |
+
#type[0:none,1:aplha,2:beta,3:seed,4:mbw,5:model_A,6:model_B,7:model_C,8:pinpoint 9:deep]
|
| 133 |
+
xtype = TYPES[xtype]
|
| 134 |
+
ytype = TYPES[ytype]
|
| 135 |
+
if ytype == "none": ymen = ""
|
| 136 |
+
|
| 137 |
+
modes=["Weight" ,"Add" ,"Triple","Twice"]
|
| 138 |
+
xs=ys=0
|
| 139 |
+
weights_a_in=weights_b_in="0"
|
| 140 |
+
|
| 141 |
+
deepprint = True if "print change" in esettings else False
|
| 142 |
+
|
| 143 |
+
def castall(hear):
|
| 144 |
+
if hear :print(f"xmen:{xmen}, ymen:{ymen}, xtype:{xtype}, ytype:{ytype}, weights_a:{weights_a_in}, weights_b:{weights_b_in}, model_A:{model_a},model_B :{model_b}, model_C:{model_c}, alpha:{alpha},\
|
| 145 |
+
beta :{beta}, mode:{mode}, blocks:{useblocks}")
|
| 146 |
+
|
| 147 |
+
pinpoint = "pinpoint blocks" in xtype or "pinpoint blocks" in ytype
|
| 148 |
+
usebeta = modes[2] in mode or modes[3] in mode
|
| 149 |
+
|
| 150 |
+
#check and adjust format
|
| 151 |
+
print(f"XY plot start, mode:{mode}, X: {xtype}, Y: {ytype}, MBW: {useblocks}")
|
| 152 |
+
castall(hear)
|
| 153 |
+
None5 = [None,None,None,None,None]
|
| 154 |
+
if xmen =="": return "ERROR: parameter X is empty",*None5
|
| 155 |
+
if ymen =="" and not ytype=="none": return "ERROR: parameter Y is empty",*None5
|
| 156 |
+
if model_a ==[] and not ("model_A" in xtype or "model_A" in ytype):return f"ERROR: model_A is not selected",*None5
|
| 157 |
+
if model_b ==[] and not ("model_B" in xtype or "model_B" in ytype):return f"ERROR: model_B is not selected",*None5
|
| 158 |
+
if model_c ==[] and usebeta and not ("model_C" in xtype or "model_C" in ytype):return "ERROR: model_C is not selected",*None5
|
| 159 |
+
if xtype == ytype: return "ERROR: same type selected for X,Y",*None5
|
| 160 |
+
|
| 161 |
+
if useblocks:
|
| 162 |
+
weights_a_in=wpreseter(weights_a,wpresets)
|
| 163 |
+
weights_b_in=wpreseter(weights_b,wpresets)
|
| 164 |
+
|
| 165 |
+
#for X only plot, use same seed
|
| 166 |
+
if seed == -1: seed = int(random.randrange(4294967294))
|
| 167 |
+
|
| 168 |
+
#for XY plot, use same seed
|
| 169 |
+
def dicedealer(zs):
|
| 170 |
+
for i,z in enumerate(zs):
|
| 171 |
+
if z =="-1": zs[i] = str(random.randrange(4294967294))
|
| 172 |
+
print(f"the die was thrown : {zs}")
|
| 173 |
+
|
| 174 |
+
#adjust parameters, alpha,beta,models,seed: list of single parameters, mbw(no beta):list of text,mbw(usebeta); list of pair text
|
| 175 |
+
def adjuster(zmen,ztype,aztype):
|
| 176 |
+
if "mbw" in ztype or "prompt" in ztype:#men separated by newline
|
| 177 |
+
zs = zmen.splitlines()
|
| 178 |
+
caster(zs,hear)
|
| 179 |
+
if "mbw alpha and beta" in ztype:
|
| 180 |
+
zs = [zs[i:i+2] for i in range(0,len(zs),2)]
|
| 181 |
+
caster(zs,hear)
|
| 182 |
+
elif "elemental" in ztype:
|
| 183 |
+
zs = zmen.split("\n\n")
|
| 184 |
+
else:
|
| 185 |
+
if "pinpoint element" in ztype:
|
| 186 |
+
zmen = zmen.replace("\n",",")
|
| 187 |
+
if "effective" in ztype:
|
| 188 |
+
zmen = ","+zmen
|
| 189 |
+
zmen = zmen.replace("\n",",")
|
| 190 |
+
zs = [z.strip() for z in zmen.split(',')]
|
| 191 |
+
caster(zs,hear)
|
| 192 |
+
if "alpha" in ztype and "effective" in aztype:
|
| 193 |
+
zs = [zs[0]]
|
| 194 |
+
if "seed" in ztype:dicedealer(zs)
|
| 195 |
+
if "alpha" == ztype or "beta" == ztype:
|
| 196 |
+
oz = []
|
| 197 |
+
for z in zs:
|
| 198 |
+
try:
|
| 199 |
+
float(z)
|
| 200 |
+
oz.append(z)
|
| 201 |
+
except:
|
| 202 |
+
pass
|
| 203 |
+
zs = oz
|
| 204 |
+
return zs
|
| 205 |
+
|
| 206 |
+
xs = adjuster(xmen,xtype,ytype)
|
| 207 |
+
ys = adjuster(ymen,ytype,xtype)
|
| 208 |
+
|
| 209 |
+
#in case beta selected but mode is Weight sum or Add or Diff
|
| 210 |
+
if ("beta" in xtype or "beta" in ytype) and (not usebeta and "tensor" not in calcmode):
|
| 211 |
+
mode = modes[3]
|
| 212 |
+
print(f"{modes[3]} mode automatically selected)")
|
| 213 |
+
|
| 214 |
+
#in case mbw or pinpoint selected but useblocks not chekced
|
| 215 |
+
if ("mbw" in xtype or "pinpoint blocks" in xtype) and not useblocks:
|
| 216 |
+
useblocks = True
|
| 217 |
+
print(f"MBW mode enabled")
|
| 218 |
+
|
| 219 |
+
if ("mbw" in ytype or "pinpoint blocks" in ytype) and not useblocks:
|
| 220 |
+
useblocks = True
|
| 221 |
+
print(f"MBW mode enabled")
|
| 222 |
+
|
| 223 |
+
xyimage=[]
|
| 224 |
+
xcount =ycount=0
|
| 225 |
+
allcount = len(xs)*len(ys)
|
| 226 |
+
|
| 227 |
+
#for STOP XY bottun
|
| 228 |
+
flag = False
|
| 229 |
+
global state_mergen
|
| 230 |
+
state_mergen = False
|
| 231 |
+
|
| 232 |
+
#type[0:none,1:aplha,2:beta,3:seed,4:mbw,5:model_A,6:model_B,7:model_C,8:pinpoint ]
|
| 233 |
+
blockid=["BASE","IN00","IN01","IN02","IN03","IN04","IN05","IN06","IN07","IN08","IN09","IN10","IN11","M00","OUT00","OUT01","OUT02","OUT03","OUT04","OUT05","OUT06","OUT07","OUT08","OUT09","OUT10","OUT11"]
|
| 234 |
+
#format ,IN00 IN03,IN04-IN09,OUT4,OUT05
|
| 235 |
+
def weightsdealer(x,xtype,y,weights):
|
| 236 |
+
caster(f"weights from : {weights}",hear)
|
| 237 |
+
zz = x if "pinpoint blocks" in xtype else y
|
| 238 |
+
za = y if "pinpoint blocks" in xtype else x
|
| 239 |
+
zz = [z.strip() for z in zz.split(' ')]
|
| 240 |
+
weights_t = [w.strip() for w in weights.split(',')]
|
| 241 |
+
if zz[0]!="NOT":
|
| 242 |
+
flagger=[False]*26
|
| 243 |
+
changer = True
|
| 244 |
+
else:
|
| 245 |
+
flagger=[True]*26
|
| 246 |
+
changer = False
|
| 247 |
+
for z in zz:
|
| 248 |
+
if z =="NOT":continue
|
| 249 |
+
if "-" in z:
|
| 250 |
+
zt = [zt.strip() for zt in z.split('-')]
|
| 251 |
+
if blockid.index(zt[1]) > blockid.index(zt[0]):
|
| 252 |
+
flagger[blockid.index(zt[0]):blockid.index(zt[1])+1] = [changer]*(blockid.index(zt[1])-blockid.index(zt[0])+1)
|
| 253 |
+
else:
|
| 254 |
+
flagger[blockid.index(zt[1]):blockid.index(zt[0])+1] = [changer]*(blockid.index(zt[0])-blockid.index(zt[1])+1)
|
| 255 |
+
else:
|
| 256 |
+
flagger[blockid.index(z)] =changer
|
| 257 |
+
for i,f in enumerate(flagger):
|
| 258 |
+
if f:weights_t[i]=za
|
| 259 |
+
outext = ",".join(weights_t)
|
| 260 |
+
caster(f"weights changed: {outext}",hear)
|
| 261 |
+
return outext
|
| 262 |
+
|
| 263 |
+
def abdealer(z):
|
| 264 |
+
if " " in z:return z.split(" ")[0],z.split(" ")[1]
|
| 265 |
+
return z,z
|
| 266 |
+
|
| 267 |
+
def xydealer(z,zt,azt):
|
| 268 |
+
nonlocal alpha,beta,seed,weights_a_in,weights_b_in,model_a,model_b,model_c,deep,calcmode,prompt
|
| 269 |
+
if pinpoint or "pinpoint element" in zt or "effective" in zt:return
|
| 270 |
+
if "mbw" in zt:
|
| 271 |
+
def weightser(z):return z, z.split(',',1)[0]
|
| 272 |
+
if "mbw alpha and beta" in zt:
|
| 273 |
+
weights_a_in,alpha = weightser(wpreseter(z[0],wpresets))
|
| 274 |
+
weights_b_in,beta = weightser(wpreseter(z[1],wpresets))
|
| 275 |
+
return
|
| 276 |
+
elif "alpha" in zt:
|
| 277 |
+
weights_a_in,alpha = weightser(wpreseter(z,wpresets))
|
| 278 |
+
return
|
| 279 |
+
else:
|
| 280 |
+
weights_b_in,beta = weightser(wpreseter(z,wpresets))
|
| 281 |
+
return
|
| 282 |
+
if "and" in zt:
|
| 283 |
+
alpha,beta = abdealer(z)
|
| 284 |
+
return
|
| 285 |
+
if "alpha" in zt and not "pinpoint element" in azt:alpha = z
|
| 286 |
+
if "beta" in zt: beta = z
|
| 287 |
+
if "seed" in zt:seed = int(z)
|
| 288 |
+
if "model_A" in zt:model_a = z
|
| 289 |
+
if "model_B" in zt:model_b = z
|
| 290 |
+
if "model_C" in zt:model_c = z
|
| 291 |
+
if "elemental" in zt:deep = z
|
| 292 |
+
if "calcmode" in zt:calcmode = z
|
| 293 |
+
if "prompt" in zt:prompt = z
|
| 294 |
+
|
| 295 |
+
# plot start
|
| 296 |
+
for y in ys:
|
| 297 |
+
xydealer(y,ytype,xtype)
|
| 298 |
+
xcount = 0
|
| 299 |
+
for x in xs:
|
| 300 |
+
xydealer(x,xtype,ytype)
|
| 301 |
+
if ("alpha" in xtype or "alpha" in ytype) and pinpoint:
|
| 302 |
+
weights_a_in = weightsdealer(x,xtype,y,weights_a)
|
| 303 |
+
weights_b_in = weights_b
|
| 304 |
+
if ("beta" in xtype or "beta" in ytype) and pinpoint:
|
| 305 |
+
weights_b_in = weightsdealer(x,xtype,y,weights_b)
|
| 306 |
+
weights_a_in =weights_a
|
| 307 |
+
if "pinpoint element" in xtype or "effective" in xtype:
|
| 308 |
+
deep_in = deep +","+ str(x)+":"+ str(y)
|
| 309 |
+
elif "pinpoint element" in ytype or "effective" in ytype:
|
| 310 |
+
deep_in = deep +","+ str(y)+":"+ str(x)
|
| 311 |
+
else:
|
| 312 |
+
deep_in = deep
|
| 313 |
+
|
| 314 |
+
print(f"XY plot: X: {xtype}, {str(x)}, Y: {ytype}, {str(y)} ({xcount+ycount*len(xs)+1}/{allcount})")
|
| 315 |
+
if not (xtype=="seed" and xcount > 0):
|
| 316 |
+
_ , currentmodel,modelid,theta_0,_=smerge(weights_a_in,weights_b_in, model_a,model_b,model_c, float(alpha),float(beta),mode,calcmode,
|
| 317 |
+
useblocks,"","",id_sets,False,deep_in,tensor,deepprint = deepprint)
|
| 318 |
+
usemodelgen(theta_0,model_a,currentmodel)
|
| 319 |
+
# simggen(prompt, nprompt, steps, sampler, cfg, seed, w, h,mergeinfo="",id_sets=[],modelid = "no id"):
|
| 320 |
+
image_temp = simggen(prompt, nprompt, steps, sampler, cfg, seed, w, h,hireson,hrupscaler,hr2ndsteps,denoise_str,hr_scale,batch_size,currentmodel,id_sets,modelid)
|
| 321 |
+
xyimage.append(image_temp[0][0])
|
| 322 |
+
xcount+=1
|
| 323 |
+
if state_mergen:
|
| 324 |
+
flag = True
|
| 325 |
+
break
|
| 326 |
+
ycount+=1
|
| 327 |
+
if flag:break
|
| 328 |
+
|
| 329 |
+
if flag and ycount ==1:
|
| 330 |
+
xs = xs[:xcount]
|
| 331 |
+
ys = [ys[0],]
|
| 332 |
+
print(f"stopped at x={xcount},y={ycount}")
|
| 333 |
+
elif flag:
|
| 334 |
+
ys=ys[:ycount]
|
| 335 |
+
print(f"stopped at x={xcount},y={ycount}")
|
| 336 |
+
|
| 337 |
+
if "mbw alpha and beta" in xtype: xs = [f"alpha:({x[0]}),beta({x[1]})" for x in xs ]
|
| 338 |
+
if "mbw alpha and beta" in ytype: ys = [f"alpha:({y[0]}),beta({y[1]})" for y in ys ]
|
| 339 |
+
|
| 340 |
+
xs[0]=xtype+" = "+xs[0] #draw X label
|
| 341 |
+
if ytype!=TYPES[0] or "model" in ytype:ys[0]=ytype+" = "+ys[0] #draw Y label
|
| 342 |
+
|
| 343 |
+
if ys==[""]:ys = [" "]
|
| 344 |
+
|
| 345 |
+
if "effective" in xtype or "effective" in ytype:
|
| 346 |
+
xyimage,xs,ys = effectivechecker(xyimage,xs,ys,model_a,model_b,esettings)
|
| 347 |
+
|
| 348 |
+
if not "grid" in esettings:
|
| 349 |
+
gridmodel= makegridmodelname(model_a, model_b,model_c, useblocks,mode,xtype,ytype,alpha,beta,weights_a,weights_b,usebeta)
|
| 350 |
+
grid = smakegrid(xyimage,xs,ys,gridmodel,image_temp[4])
|
| 351 |
+
xyimage.insert(0,grid)
|
| 352 |
+
|
| 353 |
+
state_mergen = False
|
| 354 |
+
return "Finished",currentmodel,xyimage,*image_temp[1:4]
|
| 355 |
+
|
| 356 |
+
def smakegrid(imgs,xs,ys,currentmodel,p):
|
| 357 |
+
ver_texts = [[images.GridAnnotation(y)] for y in ys]
|
| 358 |
+
hor_texts = [[images.GridAnnotation(x)] for x in xs]
|
| 359 |
+
|
| 360 |
+
w, h = imgs[0].size
|
| 361 |
+
grid = Image.new('RGB', size=(len(xs) * w, len(ys) * h), color='black')
|
| 362 |
+
|
| 363 |
+
for i, img in enumerate(imgs):
|
| 364 |
+
grid.paste(img, box=(i % len(xs) * w, i // len(xs) * h))
|
| 365 |
+
|
| 366 |
+
grid = images.draw_grid_annotations(grid,w,h, hor_texts, ver_texts)
|
| 367 |
+
grid = draw_origin(grid, currentmodel,w*len(xs),h*len(ys),w)
|
| 368 |
+
if opts.grid_save:
|
| 369 |
+
images.save_image(grid, opts.outdir_txt2img_grids, "xy_grid", extension=opts.grid_format, prompt=p.prompt, seed=p.seed, grid=True, p=p)
|
| 370 |
+
|
| 371 |
+
return grid
|
| 372 |
+
|
| 373 |
+
def makegridmodelname(model_a, model_b,model_c, useblocks,mode,xtype,ytype,alpha,beta,wa,wb,usebeta):
|
| 374 |
+
model_a=filenamecutter(model_a)
|
| 375 |
+
model_b=filenamecutter(model_b)
|
| 376 |
+
model_c=filenamecutter(model_c)
|
| 377 |
+
|
| 378 |
+
if not usebeta:beta,wb = "not used","not used"
|
| 379 |
+
vals = ""
|
| 380 |
+
modes=["Weight" ,"Add" ,"Triple","Twice"]
|
| 381 |
+
|
| 382 |
+
if "mbw" in xtype:
|
| 383 |
+
if "alpha" in xtype:wa = "X"
|
| 384 |
+
if usebeta or " beta" in xtype:wb = "X"
|
| 385 |
+
|
| 386 |
+
if "mbw" in ytype:
|
| 387 |
+
if "alpha" in ytype:wa = "Y"
|
| 388 |
+
if usebeta or " beta" in ytype:wb = "Y"
|
| 389 |
+
|
| 390 |
+
wa = "alpha = " + wa
|
| 391 |
+
wb = "beta = " + wb
|
| 392 |
+
|
| 393 |
+
x = 50
|
| 394 |
+
while len(wa) > x:
|
| 395 |
+
wa = wa[:x] + '\n' + wa[x:]
|
| 396 |
+
x = x + 50
|
| 397 |
+
|
| 398 |
+
x = 50
|
| 399 |
+
while len(wb) > x:
|
| 400 |
+
wb = wb[:x] + '\n' + wb[x:]
|
| 401 |
+
x = x + 50
|
| 402 |
+
|
| 403 |
+
if "model" in xtype:
|
| 404 |
+
if "A" in xtype:model_a = "model A"
|
| 405 |
+
elif "B" in xtype:model_b="model B"
|
| 406 |
+
elif "C" in xtype:model_c="model C"
|
| 407 |
+
|
| 408 |
+
if "model" in ytype:
|
| 409 |
+
if "A" in ytype:model_a = "model A"
|
| 410 |
+
elif "B" in ytype:model_b="model B"
|
| 411 |
+
elif "C" in ytype:model_c="model C"
|
| 412 |
+
|
| 413 |
+
if modes[1] in mode:
|
| 414 |
+
currentmodel =f"{model_a} \n {model_b} - {model_c})\n x alpha"
|
| 415 |
+
elif modes[2] in mode:
|
| 416 |
+
currentmodel =f"{model_a} x \n(1-alpha-beta) {model_b} x alpha \n+ {model_c} x beta"
|
| 417 |
+
elif modes[3] in mode:
|
| 418 |
+
currentmodel =f"({model_a} x(1-alpha) \n + {model_b} x alpha)*(1-beta)\n+ {model_c} x beta"
|
| 419 |
+
else:
|
| 420 |
+
currentmodel =f"{model_a} x (1-alpha) \n {model_b} x alpha"
|
| 421 |
+
|
| 422 |
+
if "alpha" in xtype:alpha = "X"
|
| 423 |
+
if "beta" in xtype:beta = "X"
|
| 424 |
+
if "alpha" in ytype:alpha = "Y"
|
| 425 |
+
if "beta" in ytype:beta = "Y"
|
| 426 |
+
|
| 427 |
+
if "mbw" in xtype:
|
| 428 |
+
if "alpha" in xtype: alpha = "X"
|
| 429 |
+
if "beta" in xtype or usebeta: beta = "X"
|
| 430 |
+
|
| 431 |
+
if "mbw" in ytype:
|
| 432 |
+
if "alpha" in ytype: alpha = "Y"
|
| 433 |
+
if "beta" in ytype or usebeta: beta = "Y"
|
| 434 |
+
|
| 435 |
+
vals = f"\nalpha = {alpha},beta = {beta}" if not useblocks else f"\n{wa}\n{wb}"
|
| 436 |
+
|
| 437 |
+
currentmodel = currentmodel+vals
|
| 438 |
+
return currentmodel
|
| 439 |
+
|
| 440 |
+
def effectivechecker(imgs,xs,ys,model_a,model_b,esettings):
|
| 441 |
+
diffs = []
|
| 442 |
+
outnum =[]
|
| 443 |
+
im1 = np.array(imgs[0])
|
| 444 |
+
|
| 445 |
+
model_a = filenamecutter(model_a)
|
| 446 |
+
model_b = filenamecutter(model_b)
|
| 447 |
+
dir = os.path.join(opts.outdir_txt2img_samples,f"{model_a+model_b}","difgif")
|
| 448 |
+
|
| 449 |
+
if "gif" in esettings:
|
| 450 |
+
try:
|
| 451 |
+
os.makedirs(dir)
|
| 452 |
+
except FileExistsError:
|
| 453 |
+
pass
|
| 454 |
+
|
| 455 |
+
ls,ss = (xs.copy(),ys.copy()) if len(xs) > len(ys) else (ys.copy(),xs.copy())
|
| 456 |
+
|
| 457 |
+
for i in range(len(imgs)-1):
|
| 458 |
+
im2 = np.array(imgs[i+1])
|
| 459 |
+
|
| 460 |
+
abs_diff = cv2.absdiff(im2 , im1)
|
| 461 |
+
|
| 462 |
+
abs_diff_t = cv2.threshold(abs_diff, 5, 255, cv2.THRESH_BINARY)[1]
|
| 463 |
+
res = abs_diff_t.astype(np.uint8)
|
| 464 |
+
percentage = (np.count_nonzero(res) * 100)/ res.size
|
| 465 |
+
abs_diff = cv2.bitwise_not(abs_diff)
|
| 466 |
+
outnum.append(percentage)
|
| 467 |
+
|
| 468 |
+
abs_diff = Image.fromarray(abs_diff)
|
| 469 |
+
|
| 470 |
+
diffs.append(abs_diff)
|
| 471 |
+
|
| 472 |
+
if "gif" in esettings:
|
| 473 |
+
gifpath = gifpath_t = os.path.join(dir,ls[i+1].replace(":","_")+".gif")
|
| 474 |
+
|
| 475 |
+
is_file = os.path.isfile(gifpath)
|
| 476 |
+
j = 0
|
| 477 |
+
while is_file:
|
| 478 |
+
gifpath = gifpath_t.replace(".gif",f"_{j}.gif")
|
| 479 |
+
print(gifpath)
|
| 480 |
+
is_file = os.path.isfile(gifpath)
|
| 481 |
+
j = j + 1
|
| 482 |
+
|
| 483 |
+
imgs[0].save(gifpath, save_all=True, append_images=[imgs[i+1]], optimize=False, duration=1000, loop=0)
|
| 484 |
+
|
| 485 |
+
nums = []
|
| 486 |
+
outs = []
|
| 487 |
+
|
| 488 |
+
ls = ls[1:]
|
| 489 |
+
for i in range(len(ls)):
|
| 490 |
+
nums.append([ls[i],outnum[i]])
|
| 491 |
+
ls[i] = ls[i] + "\n Diff : " + str(round(outnum[i],3)) + "%"
|
| 492 |
+
|
| 493 |
+
if "csv" in esettings:
|
| 494 |
+
try:
|
| 495 |
+
os.makedirs(dir)
|
| 496 |
+
except FileExistsError:
|
| 497 |
+
pass
|
| 498 |
+
filepath = os.path.join(dir, f"{model_a+model_b}.csv")
|
| 499 |
+
with open(filepath, "a", newline="") as f:
|
| 500 |
+
writer = csv.writer(f)
|
| 501 |
+
writer.writerows(nums)
|
| 502 |
+
|
| 503 |
+
if len(ys) > len (xs):
|
| 504 |
+
for diff,img in zip(diffs,imgs[1:]):
|
| 505 |
+
outs.append(diff)
|
| 506 |
+
outs.append(img)
|
| 507 |
+
outs.append(imgs[0])
|
| 508 |
+
ss = ["diff",ss[0],"source"]
|
| 509 |
+
return outs,ss,ls
|
| 510 |
+
else:
|
| 511 |
+
outs = [imgs[0]]*len(diffs) + imgs[1:]+ diffs
|
| 512 |
+
ss = ["source",ss[0],"diff"]
|
| 513 |
+
return outs,ls,ss
|
microsoftexcel-supermerger/scripts/supermerger.py
ADDED
|
@@ -0,0 +1,552 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import gc
|
| 3 |
+
import os
|
| 4 |
+
import os.path
|
| 5 |
+
import re
|
| 6 |
+
import shutil
|
| 7 |
+
from importlib import reload
|
| 8 |
+
from pprint import pprint
|
| 9 |
+
import gradio as gr
|
| 10 |
+
from modules import (devices, script_callbacks, scripts, sd_hijack, sd_models,sd_vae, shared)
|
| 11 |
+
from modules.scripts import basedir
|
| 12 |
+
from modules.sd_models import checkpoints_loaded
|
| 13 |
+
from modules.shared import opts
|
| 14 |
+
from modules.ui import create_output_panel, create_refresh_button
|
| 15 |
+
import scripts.mergers.mergers
|
| 16 |
+
import scripts.mergers.pluslora
|
| 17 |
+
import scripts.mergers.xyplot
|
| 18 |
+
reload(scripts.mergers.mergers) # update without restarting web-ui.bat
|
| 19 |
+
reload(scripts.mergers.xyplot)
|
| 20 |
+
reload(scripts.mergers.pluslora)
|
| 21 |
+
import csv
|
| 22 |
+
import scripts.mergers.pluslora as pluslora
|
| 23 |
+
from scripts.mergers.mergers import (TYPESEG, freezemtime, rwmergelog, simggen,smergegen)
|
| 24 |
+
from scripts.mergers.xyplot import freezetime, nulister, numaker, numanager
|
| 25 |
+
|
| 26 |
+
gensets=argparse.Namespace()
|
| 27 |
+
|
| 28 |
+
def on_ui_train_tabs(params):
|
| 29 |
+
txt2img_preview_params=params.txt2img_preview_params
|
| 30 |
+
gensets.txt2img_preview_params=txt2img_preview_params
|
| 31 |
+
return None
|
| 32 |
+
|
| 33 |
+
path_root = basedir()
|
| 34 |
+
|
| 35 |
+
def on_ui_tabs():
|
| 36 |
+
weights_presets=""
|
| 37 |
+
userfilepath = os.path.join(path_root, "scripts","mbwpresets.txt")
|
| 38 |
+
if os.path.isfile(userfilepath):
|
| 39 |
+
try:
|
| 40 |
+
with open(userfilepath) as f:
|
| 41 |
+
weights_presets = f.read()
|
| 42 |
+
filepath = userfilepath
|
| 43 |
+
except OSError as e:
|
| 44 |
+
pass
|
| 45 |
+
else:
|
| 46 |
+
filepath = os.path.join(path_root, "scripts","mbwpresets_master.txt")
|
| 47 |
+
try:
|
| 48 |
+
with open(filepath) as f:
|
| 49 |
+
weights_presets = f.read()
|
| 50 |
+
shutil.copyfile(filepath, userfilepath)
|
| 51 |
+
except OSError as e:
|
| 52 |
+
pass
|
| 53 |
+
|
| 54 |
+
with gr.Blocks() as supermergerui:
|
| 55 |
+
with gr.Tab("Merge"):
|
| 56 |
+
with gr.Row().style(equal_height=False):
|
| 57 |
+
with gr.Column(scale = 3):
|
| 58 |
+
gr.HTML(value="<p>Merge models and load it for generation</p>")
|
| 59 |
+
|
| 60 |
+
with gr.Row():
|
| 61 |
+
model_a = gr.Dropdown(sd_models.checkpoint_tiles(),elem_id="model_converter_model_name",label="Model A",interactive=True)
|
| 62 |
+
create_refresh_button(model_a, sd_models.list_models,lambda: {"choices": sd_models.checkpoint_tiles()},"refresh_checkpoint_Z")
|
| 63 |
+
|
| 64 |
+
model_b = gr.Dropdown(sd_models.checkpoint_tiles(),elem_id="model_converter_model_name",label="Model B",interactive=True)
|
| 65 |
+
create_refresh_button(model_b, sd_models.list_models,lambda: {"choices": sd_models.checkpoint_tiles()},"refresh_checkpoint_Z")
|
| 66 |
+
|
| 67 |
+
model_c = gr.Dropdown(sd_models.checkpoint_tiles(),elem_id="model_converter_model_name",label="Model C",interactive=True)
|
| 68 |
+
create_refresh_button(model_c, sd_models.list_models,lambda: {"choices": sd_models.checkpoint_tiles()},"refresh_checkpoint_Z")
|
| 69 |
+
|
| 70 |
+
mode = gr.Radio(label = "Merge Mode",choices = ["Weight sum:A*(1-alpha)+B*alpha", "Add difference:A+(B-C)*alpha",
|
| 71 |
+
"Triple sum:A*(1-alpha-beta)+B*alpha+C*beta",
|
| 72 |
+
"sum Twice:(A*(1-alpha)+B*alpha)*(1-beta)+C*beta",
|
| 73 |
+
], value = "Weight sum:A*(1-alpha)+B*alpha")
|
| 74 |
+
calcmode = gr.Radio(label = "Calcutation Mode",choices = ["normal", "cosineA", "cosineB", "smoothAdd","tensor"], value = "normal")
|
| 75 |
+
with gr.Row():
|
| 76 |
+
useblocks = gr.Checkbox(label="use MBW")
|
| 77 |
+
base_alpha = gr.Slider(label="alpha", minimum=-1.0, maximum=2, step=0.001, value=0.5)
|
| 78 |
+
base_beta = gr.Slider(label="beta", minimum=-1.0, maximum=2, step=0.001, value=0.25)
|
| 79 |
+
#weights = gr.Textbox(label="weights,base alpha,IN00,IN02,...IN11,M00,OUT00,...,OUT11",lines=2,value="0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5")
|
| 80 |
+
|
| 81 |
+
with gr.Row():
|
| 82 |
+
merge = gr.Button(elem_id="model_merger_merge", value="Merge!",variant='primary')
|
| 83 |
+
mergeandgen = gr.Button(elem_id="model_merger_merge", value="Merge&Gen",variant='primary')
|
| 84 |
+
gen = gr.Button(elem_id="model_merger_merge", value="Gen",variant='primary')
|
| 85 |
+
stopmerge = gr.Button(elem_id="stopmerge", value="Stop",variant='primary')
|
| 86 |
+
with gr.Row():
|
| 87 |
+
with gr.Column(scale = 4):
|
| 88 |
+
save_sets = gr.CheckboxGroup(["save model", "overwrite","safetensors","fp16","save metadata"], value=["safetensors"], label="save settings")
|
| 89 |
+
with gr.Column(scale = 2):
|
| 90 |
+
id_sets = gr.CheckboxGroup(["image", "PNG info"], label="write merged model ID to")
|
| 91 |
+
with gr.Row():
|
| 92 |
+
with gr.Column(min_width = 50, scale=2):
|
| 93 |
+
with gr.Row():
|
| 94 |
+
custom_name = gr.Textbox(label="Custom Name (Optional)", elem_id="model_converter_custom_name")
|
| 95 |
+
mergeid = gr.Textbox(label="merge from ID", elem_id="model_converter_custom_name",value = "-1")
|
| 96 |
+
with gr.Column(min_width = 50, scale=1):
|
| 97 |
+
with gr.Row():s_reverse= gr.Button(value="Set from ID(-1 for last)",variant='primary')
|
| 98 |
+
|
| 99 |
+
with gr.Accordion("Restore faces, Tiling, Hires. fix, Batch size",open = False):
|
| 100 |
+
batch_size = denois_str = gr.Slider(minimum=0, maximum=8, step=1, label='Batch size', value=1, elem_id="sm_txt2img_batch_size")
|
| 101 |
+
genoptions = gr.CheckboxGroup(label = "Gen Options",choices=["Restore faces", "Tiling", "Hires. fix"], visible = True,interactive=True,type="value")
|
| 102 |
+
with gr.Row(elem_id="txt2img_hires_fix_row1", variant="compact"):
|
| 103 |
+
hrupscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode)
|
| 104 |
+
hr2ndsteps = gr.Slider(minimum=0, maximum=150, step=1, label='Hires steps', value=0, elem_id="txt2img_hires_steps")
|
| 105 |
+
denois_str = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength")
|
| 106 |
+
hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale")
|
| 107 |
+
|
| 108 |
+
hiresfix = [genoptions,hrupscaler,hr2ndsteps,denois_str,hr_scale]
|
| 109 |
+
|
| 110 |
+
with gr.Accordion("Elemental Merge",open = False):
|
| 111 |
+
with gr.Row():
|
| 112 |
+
esettings1 = gr.CheckboxGroup(label = "settings",choices=["print change"],type="value",interactive=True)
|
| 113 |
+
with gr.Row():
|
| 114 |
+
deep = gr.Textbox(label="Blocks:Element:Ratio,Blocks:Element:Ratio,...",lines=2,value="")
|
| 115 |
+
|
| 116 |
+
with gr.Accordion("Tensor Merge",open = False,visible=False):
|
| 117 |
+
tensor = gr.Textbox(label="Blocks:Tensors",lines=2,value="")
|
| 118 |
+
|
| 119 |
+
with gr.Row():
|
| 120 |
+
x_type = gr.Dropdown(label="X type", choices=[x for x in TYPESEG], value="alpha", type="index")
|
| 121 |
+
x_randseednum = gr.Number(value=3, label="number of -1", interactive=True, visible = True)
|
| 122 |
+
xgrid = gr.Textbox(label="Sequential Merge Parameters",lines=3,value="0.25,0.5,0.75")
|
| 123 |
+
y_type = gr.Dropdown(label="Y type", choices=[y for y in TYPESEG], value="none", type="index")
|
| 124 |
+
ygrid = gr.Textbox(label="Y grid (Disabled if blank)",lines=3,value="",visible =False)
|
| 125 |
+
with gr.Row():
|
| 126 |
+
gengrid = gr.Button(elem_id="model_merger_merge", value="Sequential XY Merge and Generation",variant='primary')
|
| 127 |
+
stopgrid = gr.Button(elem_id="model_merger_merge", value="Stop XY",variant='primary')
|
| 128 |
+
s_reserve1 = gr.Button(value="Reserve XY Plot",variant='primary')
|
| 129 |
+
dtrue = gr.Checkbox(value = True, visible = False)
|
| 130 |
+
dfalse = gr.Checkbox(value = False,visible = False)
|
| 131 |
+
dummy_t = gr.Textbox(value = "",visible = False)
|
| 132 |
+
blockid=["BASE","IN00","IN01","IN02","IN03","IN04","IN05","IN06","IN07","IN08","IN09","IN10","IN11","M00","OUT00","OUT01","OUT02","OUT03","OUT04","OUT05","OUT06","OUT07","OUT08","OUT09","OUT10","OUT11"]
|
| 133 |
+
|
| 134 |
+
with gr.Column(scale = 2):
|
| 135 |
+
currentmodel = gr.Textbox(label="Current Model",lines=1,value="")
|
| 136 |
+
submit_result = gr.Textbox(label="Message")
|
| 137 |
+
mgallery, mgeninfo, mhtmlinfo, mhtmllog = create_output_panel("txt2img", opts.outdir_txt2img_samples)
|
| 138 |
+
with gr.Row(visible = False) as row_inputers:
|
| 139 |
+
inputer = gr.Textbox(label="",lines=1,value="")
|
| 140 |
+
addtox = gr.Button(value="Add to Sequence X")
|
| 141 |
+
addtoy = gr.Button(value="Add to Sequence Y")
|
| 142 |
+
with gr.Row(visible = False) as row_blockids:
|
| 143 |
+
blockids = gr.CheckboxGroup(label = "block IDs",choices=[x for x in blockid],type="value",interactive=True)
|
| 144 |
+
with gr.Row(visible = False) as row_calcmode:
|
| 145 |
+
calcmodes = gr.CheckboxGroup(label = "calcmode",choices=["normal", "cosineA", "cosineB", "smoothAdd","tensor"],type="value",interactive=True)
|
| 146 |
+
with gr.Row(visible = False) as row_checkpoints:
|
| 147 |
+
checkpoints = gr.CheckboxGroup(label = "checkpoint",choices=[x.model_name for x in sd_models.checkpoints_list.values()],type="value",interactive=True)
|
| 148 |
+
with gr.Row(visible = False) as row_esets:
|
| 149 |
+
esettings = gr.CheckboxGroup(label = "effective chekcer settings",choices=["save csv","save anime gif","not save grid","print change"],type="value",interactive=True)
|
| 150 |
+
|
| 151 |
+
with gr.Tab("Weights Setting"):
|
| 152 |
+
with gr.Row():
|
| 153 |
+
setalpha = gr.Button(elem_id="copytogen", value="set to alpha",variant='primary')
|
| 154 |
+
readalpha = gr.Button(elem_id="copytogen", value="read from alpha",variant='primary')
|
| 155 |
+
setbeta = gr.Button(elem_id="copytogen", value="set to beta",variant='primary')
|
| 156 |
+
readbeta = gr.Button(elem_id="copytogen", value="read from beta",variant='primary')
|
| 157 |
+
setx = gr.Button(elem_id="copytogen", value="set to X",variant='primary')
|
| 158 |
+
with gr.Row():
|
| 159 |
+
weights_a = gr.Textbox(label="weights for alpha, base alpha,IN00,IN02,...IN11,M00,OUT00,...,OUT11",value = "0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5")
|
| 160 |
+
weights_b = gr.Textbox(label="weights,for beta, base beta,IN00,IN02,...IN11,M00,OUT00,...,OUT11",value = "0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2")
|
| 161 |
+
with gr.Row():
|
| 162 |
+
base= gr.Slider(label="Base", minimum=0, maximum=1, step =0.01, value=0.5)
|
| 163 |
+
in00 = gr.Slider(label="IN00", minimum=0, maximum=1, step=0.01, value=0.5)
|
| 164 |
+
in01 = gr.Slider(label="IN01", minimum=0, maximum=1, step=0.01, value=0.5)
|
| 165 |
+
in02 = gr.Slider(label="IN02", minimum=0, maximum=1, step=0.01, value=0.5)
|
| 166 |
+
in03 = gr.Slider(label="IN03", minimum=0, maximum=1, step=0.01, value=0.5)
|
| 167 |
+
with gr.Row():
|
| 168 |
+
in04 = gr.Slider(label="IN04", minimum=0, maximum=1, step=0.01, value=0.5)
|
| 169 |
+
in05 = gr.Slider(label="IN05", minimum=0, maximum=1, step=0.01, value=0.5)
|
| 170 |
+
in06 = gr.Slider(label="IN06", minimum=0, maximum=1, step=0.01, value=0.5)
|
| 171 |
+
in07 = gr.Slider(label="IN07", minimum=0, maximum=1, step=0.01, value=0.5)
|
| 172 |
+
in08 = gr.Slider(label="IN08", minimum=0, maximum=1, step=0.01, value=0.5)
|
| 173 |
+
in09 = gr.Slider(label="IN09", minimum=0, maximum=1, step=0.01, value=0.5)
|
| 174 |
+
with gr.Row():
|
| 175 |
+
in10 = gr.Slider(label="IN10", minimum=0, maximum=1, step=0.01, value=0.5)
|
| 176 |
+
in11 = gr.Slider(label="IN11", minimum=0, maximum=1, step=0.01, value=0.5)
|
| 177 |
+
mi00 = gr.Slider(label="M00", minimum=0, maximum=1, step=0.01, value=0.5)
|
| 178 |
+
ou00 = gr.Slider(label="OUT00", minimum=0, maximum=1, step=0.01, value=0.5)
|
| 179 |
+
ou01 = gr.Slider(label="OUT01", minimum=0, maximum=1, step=0.01, value=0.5)
|
| 180 |
+
ou02 = gr.Slider(label="OUT02", minimum=0, maximum=1, step=0.01, value=0.5)
|
| 181 |
+
with gr.Row():
|
| 182 |
+
ou03 = gr.Slider(label="OUT03", minimum=0, maximum=1, step=0.01, value=0.5)
|
| 183 |
+
ou04 = gr.Slider(label="OUT04", minimum=0, maximum=1, step=0.01, value=0.5)
|
| 184 |
+
ou05 = gr.Slider(label="OUT05", minimum=0, maximum=1, step=0.01, value=0.5)
|
| 185 |
+
ou06 = gr.Slider(label="OUT06", minimum=0, maximum=1, step=0.01, value=0.5)
|
| 186 |
+
ou07 = gr.Slider(label="OUT07", minimum=0, maximum=1, step=0.01, value=0.5)
|
| 187 |
+
ou08 = gr.Slider(label="OUT08", minimum=0, maximum=1, step=0.01, value=0.5)
|
| 188 |
+
with gr.Row():
|
| 189 |
+
ou09 = gr.Slider(label="OUT09", minimum=0, maximum=1, step=0.01, value=0.5)
|
| 190 |
+
ou10 = gr.Slider(label="OUT10", minimum=0, maximum=1, step=0.01, value=0.5)
|
| 191 |
+
ou11 = gr.Slider(label="OUT11", minimum=0, maximum=1, step=0.01, value=0.5)
|
| 192 |
+
with gr.Tab("Weights Presets"):
|
| 193 |
+
with gr.Row():
|
| 194 |
+
s_reloadtext = gr.Button(value="Reload Presets",variant='primary')
|
| 195 |
+
s_reloadtags = gr.Button(value="Reload Tags",variant='primary')
|
| 196 |
+
s_savetext = gr.Button(value="Save Presets",variant='primary')
|
| 197 |
+
s_openeditor = gr.Button(value="Open TextEditor",variant='primary')
|
| 198 |
+
weightstags= gr.Textbox(label="available",lines = 2,value=tagdicter(weights_presets),visible =True,interactive =True)
|
| 199 |
+
wpresets= gr.TextArea(label="",value=weights_presets,visible =True,interactive = True)
|
| 200 |
+
|
| 201 |
+
with gr.Tab("Reservation"):
|
| 202 |
+
with gr.Row():
|
| 203 |
+
s_reserve = gr.Button(value="Reserve XY Plot",variant='primary')
|
| 204 |
+
s_reloadreserve = gr.Button(value="Reloat List",variant='primary')
|
| 205 |
+
s_startreserve = gr.Button(value="Start XY plot",variant='primary')
|
| 206 |
+
s_delreserve = gr.Button(value="Delete list(-1 for all)",variant='primary')
|
| 207 |
+
s_delnum = gr.Number(value=1, label="Delete num : ", interactive=True, visible = True,precision =0)
|
| 208 |
+
with gr.Row():
|
| 209 |
+
numaframe = gr.Dataframe(
|
| 210 |
+
headers=["No.","status","xtype","xmenber", "ytype","ymenber","model A","model B","model C","alpha","beta","mode","use MBW","weights alpha","weights beta"],
|
| 211 |
+
row_count=5,)
|
| 212 |
+
# with gr.Tab("manual"):
|
| 213 |
+
# with gr.Row():
|
| 214 |
+
# gr.HTML(value="<p> exampls: Change base alpha from 0.1 to 0.9 <br>0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9<br>If you want to display the original model as well for comparison<br>0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1</p>")
|
| 215 |
+
# gr.HTML(value="<p> For block-by-block merging <br>0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5<br>1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1<br>0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1</p>")
|
| 216 |
+
|
| 217 |
+
with gr.Row():
|
| 218 |
+
|
| 219 |
+
currentcache = gr.Textbox(label="Current Cache")
|
| 220 |
+
loadcachelist = gr.Button(elem_id="model_merger_merge", value="Reload Cache List",variant='primary')
|
| 221 |
+
unloadmodel = gr.Button(value="unload model",variant='primary')
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
# main ui end
|
| 225 |
+
|
| 226 |
+
with gr.Tab("LoRA", elem_id="tab_lora"):
|
| 227 |
+
pluslora.on_ui_tabs()
|
| 228 |
+
|
| 229 |
+
with gr.Tab("History", elem_id="tab_history"):
|
| 230 |
+
|
| 231 |
+
with gr.Row():
|
| 232 |
+
load_history = gr.Button(value="load_history",variant='primary')
|
| 233 |
+
searchwrods = gr.Textbox(label="",lines=1,value="")
|
| 234 |
+
search = gr.Button(value="search")
|
| 235 |
+
searchmode = gr.Radio(label = "Search Mode",choices = ["or","and"], value = "or",type = "value")
|
| 236 |
+
with gr.Row():
|
| 237 |
+
history = gr.Dataframe(
|
| 238 |
+
headers=["ID","Time","Name","Weights alpha","Weights beta","Model A","Model B","Model C","alpha","beta","Mode","use MBW","custum name","save setting","use ID"],
|
| 239 |
+
)
|
| 240 |
+
|
| 241 |
+
with gr.Tab("Elements", elem_id="tab_deep"):
|
| 242 |
+
with gr.Row():
|
| 243 |
+
smd_model_a = gr.Dropdown(sd_models.checkpoint_tiles(),elem_id="model_converter_model_name",label="Checkpoint A",interactive=True)
|
| 244 |
+
create_refresh_button(smd_model_a, sd_models.list_models,lambda: {"choices": sd_models.checkpoint_tiles()},"refresh_checkpoint_Z")
|
| 245 |
+
smd_loadkeys = gr.Button(value="load keys",variant='primary')
|
| 246 |
+
with gr.Row():
|
| 247 |
+
keys = gr.Dataframe(headers=["No.","block","key"],)
|
| 248 |
+
|
| 249 |
+
with gr.Tab("Metadeta", elem_id="tab_metadata"):
|
| 250 |
+
with gr.Row():
|
| 251 |
+
meta_model_a = gr.Dropdown(sd_models.checkpoint_tiles(),elem_id="model_converter_model_name",label="read metadata",interactive=True)
|
| 252 |
+
create_refresh_button(meta_model_a, sd_models.list_models,lambda: {"choices": sd_models.checkpoint_tiles()},"refresh_checkpoint_Z")
|
| 253 |
+
smd_loadmetadata = gr.Button(value="load keys",variant='primary')
|
| 254 |
+
with gr.Row():
|
| 255 |
+
metadata = gr.TextArea()
|
| 256 |
+
|
| 257 |
+
smd_loadmetadata.click(
|
| 258 |
+
fn=loadmetadata,
|
| 259 |
+
inputs=[meta_model_a],
|
| 260 |
+
outputs=[metadata]
|
| 261 |
+
)
|
| 262 |
+
|
| 263 |
+
smd_loadkeys.click(
|
| 264 |
+
fn=loadkeys,
|
| 265 |
+
inputs=[smd_model_a],
|
| 266 |
+
outputs=[keys]
|
| 267 |
+
)
|
| 268 |
+
|
| 269 |
+
def unload():
|
| 270 |
+
if shared.sd_model == None: return "already unloaded"
|
| 271 |
+
sd_hijack.model_hijack.undo_hijack(shared.sd_model)
|
| 272 |
+
shared.sd_model = None
|
| 273 |
+
gc.collect()
|
| 274 |
+
devices.torch_gc()
|
| 275 |
+
return "model unloaded"
|
| 276 |
+
|
| 277 |
+
unloadmodel.click(fn=unload,outputs=[submit_result])
|
| 278 |
+
|
| 279 |
+
load_history.click(fn=load_historyf,outputs=[history ])
|
| 280 |
+
|
| 281 |
+
msettings=[weights_a,weights_b,model_a,model_b,model_c,base_alpha,base_beta,mode,calcmode,useblocks,custom_name,save_sets,id_sets,wpresets,deep,tensor]
|
| 282 |
+
imagegal = [mgallery,mgeninfo,mhtmlinfo,mhtmllog]
|
| 283 |
+
xysettings=[x_type,xgrid,y_type,ygrid,esettings]
|
| 284 |
+
|
| 285 |
+
s_reverse.click(fn = reversparams,
|
| 286 |
+
inputs =mergeid,
|
| 287 |
+
outputs = [submit_result,*msettings[0:8],*msettings[9:13],deep,calcmode]
|
| 288 |
+
)
|
| 289 |
+
|
| 290 |
+
merge.click(
|
| 291 |
+
fn=smergegen,
|
| 292 |
+
inputs=[*msettings,esettings1,*gensets.txt2img_preview_params,*hiresfix,batch_size,currentmodel,dfalse],
|
| 293 |
+
outputs=[submit_result,currentmodel]
|
| 294 |
+
)
|
| 295 |
+
|
| 296 |
+
mergeandgen.click(
|
| 297 |
+
fn=smergegen,
|
| 298 |
+
inputs=[*msettings,esettings1,*gensets.txt2img_preview_params,*hiresfix,batch_size,currentmodel,dtrue],
|
| 299 |
+
outputs=[submit_result,currentmodel,*imagegal]
|
| 300 |
+
)
|
| 301 |
+
|
| 302 |
+
gen.click(
|
| 303 |
+
fn=simggen,
|
| 304 |
+
inputs=[*gensets.txt2img_preview_params,*hiresfix,batch_size,currentmodel,id_sets],
|
| 305 |
+
outputs=[*imagegal],
|
| 306 |
+
)
|
| 307 |
+
|
| 308 |
+
s_reserve.click(
|
| 309 |
+
fn=numaker,
|
| 310 |
+
inputs=[*xysettings,*msettings,*gensets.txt2img_preview_params,*hiresfix,batch_size],
|
| 311 |
+
outputs=[numaframe]
|
| 312 |
+
)
|
| 313 |
+
|
| 314 |
+
s_reserve1.click(
|
| 315 |
+
fn=numaker,
|
| 316 |
+
inputs=[*xysettings,*msettings,*gensets.txt2img_preview_params,*hiresfix,batch_size],
|
| 317 |
+
outputs=[numaframe]
|
| 318 |
+
)
|
| 319 |
+
|
| 320 |
+
gengrid.click(
|
| 321 |
+
fn=numanager,
|
| 322 |
+
inputs=[dtrue,*xysettings,*msettings,*gensets.txt2img_preview_params,*hiresfix,batch_size],
|
| 323 |
+
outputs=[submit_result,currentmodel,*imagegal],
|
| 324 |
+
)
|
| 325 |
+
|
| 326 |
+
s_startreserve.click(
|
| 327 |
+
fn=numanager,
|
| 328 |
+
inputs=[dfalse,*xysettings,*msettings,*gensets.txt2img_preview_params,*hiresfix,batch_size],
|
| 329 |
+
outputs=[submit_result,currentmodel,*imagegal],
|
| 330 |
+
)
|
| 331 |
+
|
| 332 |
+
search.click(fn = searchhistory,inputs=[searchwrods,searchmode],outputs=[history])
|
| 333 |
+
|
| 334 |
+
s_reloadreserve.click(fn=nulister,inputs=[dfalse],outputs=[numaframe])
|
| 335 |
+
s_delreserve.click(fn=nulister,inputs=[s_delnum],outputs=[numaframe])
|
| 336 |
+
loadcachelist.click(fn=load_cachelist,inputs=[],outputs=[currentcache])
|
| 337 |
+
addtox.click(fn=lambda x:gr.Textbox.update(value = x),inputs=[inputer],outputs=[xgrid])
|
| 338 |
+
addtoy.click(fn=lambda x:gr.Textbox.update(value = x),inputs=[inputer],outputs=[ygrid])
|
| 339 |
+
|
| 340 |
+
stopgrid.click(fn=freezetime)
|
| 341 |
+
stopmerge.click(fn=freezemtime)
|
| 342 |
+
|
| 343 |
+
checkpoints.change(fn=lambda x:",".join(x),inputs=[checkpoints],outputs=[inputer])
|
| 344 |
+
blockids.change(fn=lambda x:" ".join(x),inputs=[blockids],outputs=[inputer])
|
| 345 |
+
calcmodes.change(fn=lambda x:",".join(x),inputs=[calcmodes],outputs=[inputer])
|
| 346 |
+
|
| 347 |
+
menbers = [base,in00,in01,in02,in03,in04,in05,in06,in07,in08,in09,in10,in11,mi00,ou00,ou01,ou02,ou03,ou04,ou05,ou06,ou07,ou08,ou09,ou10,ou11]
|
| 348 |
+
|
| 349 |
+
setalpha.click(fn=slider2text,inputs=menbers,outputs=[weights_a])
|
| 350 |
+
setbeta.click(fn=slider2text,inputs=menbers,outputs=[weights_b])
|
| 351 |
+
setx.click(fn=add_to_seq,inputs=[xgrid,weights_a],outputs=[xgrid])
|
| 352 |
+
|
| 353 |
+
readalpha.click(fn=text2slider,inputs=weights_a,outputs=menbers)
|
| 354 |
+
readbeta.click(fn=text2slider,inputs=weights_b,outputs=menbers)
|
| 355 |
+
|
| 356 |
+
x_type.change(fn=showxy,inputs=[x_type,y_type], outputs=[row_blockids,row_checkpoints,row_inputers,ygrid,row_esets,row_calcmode])
|
| 357 |
+
y_type.change(fn=showxy,inputs=[x_type,y_type], outputs=[row_blockids,row_checkpoints,row_inputers,ygrid,row_esets,row_calcmode])
|
| 358 |
+
x_randseednum.change(fn=makerand,inputs=[x_randseednum],outputs=[xgrid])
|
| 359 |
+
|
| 360 |
+
import subprocess
|
| 361 |
+
def openeditors():
|
| 362 |
+
subprocess.Popen(['start', filepath], shell=True)
|
| 363 |
+
|
| 364 |
+
def reloadpresets():
|
| 365 |
+
try:
|
| 366 |
+
with open(filepath) as f:
|
| 367 |
+
return f.read()
|
| 368 |
+
except OSError as e:
|
| 369 |
+
pass
|
| 370 |
+
|
| 371 |
+
def savepresets(text):
|
| 372 |
+
with open(filepath,mode = 'w') as f:
|
| 373 |
+
f.write(text)
|
| 374 |
+
|
| 375 |
+
s_reloadtext.click(fn=reloadpresets,inputs=[],outputs=[wpresets])
|
| 376 |
+
s_reloadtags.click(fn=tagdicter,inputs=[wpresets],outputs=[weightstags])
|
| 377 |
+
s_savetext.click(fn=savepresets,inputs=[wpresets],outputs=[])
|
| 378 |
+
s_openeditor.click(fn=openeditors,inputs=[],outputs=[])
|
| 379 |
+
|
| 380 |
+
return (supermergerui, "SuperMerger", "supermerger"),
|
| 381 |
+
|
| 382 |
+
msearch = []
|
| 383 |
+
mlist=[]
|
| 384 |
+
|
| 385 |
+
def loadmetadata(model):
|
| 386 |
+
import json
|
| 387 |
+
checkpoint_info = sd_models.get_closet_checkpoint_match(model)
|
| 388 |
+
if ".safetensors" not in checkpoint_info.filename: return "no metadata(not safetensors)"
|
| 389 |
+
sdict = sd_models.read_metadata_from_safetensors(checkpoint_info.filename)
|
| 390 |
+
if sdict == {}: return "no metadata"
|
| 391 |
+
return json.dumps(sdict,indent=4)
|
| 392 |
+
|
| 393 |
+
def load_historyf():
|
| 394 |
+
filepath = os.path.join(path_root,"mergehistory.csv")
|
| 395 |
+
global mlist,msearch
|
| 396 |
+
msearch = []
|
| 397 |
+
mlist=[]
|
| 398 |
+
try:
|
| 399 |
+
with open(filepath, 'r') as f:
|
| 400 |
+
reader = csv.reader(f)
|
| 401 |
+
mlist = [raw for raw in reader]
|
| 402 |
+
mlist = mlist[1:]
|
| 403 |
+
for m in mlist:
|
| 404 |
+
msearch.append(" ".join(m))
|
| 405 |
+
maxlen = len(mlist[-1][0])
|
| 406 |
+
for i,m in enumerate(mlist):
|
| 407 |
+
mlist[i][0] = mlist[i][0].zfill(maxlen)
|
| 408 |
+
return mlist
|
| 409 |
+
except:
|
| 410 |
+
return [["no data","",""],]
|
| 411 |
+
|
| 412 |
+
def searchhistory(words,searchmode):
|
| 413 |
+
outs =[]
|
| 414 |
+
ando = "and" in searchmode
|
| 415 |
+
words = words.split(" ") if " " in words else [words]
|
| 416 |
+
for i, m in enumerate(msearch):
|
| 417 |
+
hit = ando
|
| 418 |
+
for w in words:
|
| 419 |
+
if ando:
|
| 420 |
+
if w not in m:hit = False
|
| 421 |
+
else:
|
| 422 |
+
if w in m:hit = True
|
| 423 |
+
print(i,len(mlist))
|
| 424 |
+
if hit :outs.append(mlist[i])
|
| 425 |
+
|
| 426 |
+
if outs == []:return [["no result","",""],]
|
| 427 |
+
return outs
|
| 428 |
+
|
| 429 |
+
#msettings=[0 weights_a,1 weights_b,2 model_a,3 model_b,4 model_c,5 base_alpha,6 base_beta,7 mode,8 useblocks,9 custom_name,10 save_sets,11 id_sets,12 wpresets]
|
| 430 |
+
|
| 431 |
+
def reversparams(id):
|
| 432 |
+
def selectfromhash(hash):
|
| 433 |
+
for model in sd_models.checkpoint_tiles():
|
| 434 |
+
if hash in model:
|
| 435 |
+
return model
|
| 436 |
+
return ""
|
| 437 |
+
try:
|
| 438 |
+
idsets = rwmergelog(id = id)
|
| 439 |
+
except:
|
| 440 |
+
return [gr.update(value = "ERROR: history file could not open"),*[gr.update() for x in range(14)]]
|
| 441 |
+
if type(idsets) == str:
|
| 442 |
+
print("ERROR")
|
| 443 |
+
return [gr.update(value = idsets),*[gr.update() for x in range(14)]]
|
| 444 |
+
if idsets[0] == "ID":return [gr.update(value ="ERROR: no history"),*[gr.update() for x in range(14)]]
|
| 445 |
+
mgs = idsets[3:]
|
| 446 |
+
if mgs[0] == "":mgs[0] = "0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5"
|
| 447 |
+
if mgs[1] == "":mgs[1] = "0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2"
|
| 448 |
+
mgs[2] = selectfromhash(mgs[2]) if len(mgs[2]) > 5 else ""
|
| 449 |
+
mgs[3] = selectfromhash(mgs[3]) if len(mgs[3]) > 5 else ""
|
| 450 |
+
mgs[4] = selectfromhash(mgs[4]) if len(mgs[4]) > 5 else ""
|
| 451 |
+
mgs[8] = True if mgs[8] =="True" else False
|
| 452 |
+
mgs[10] = mgs[10].replace("[","").replace("]","").replace("'", "")
|
| 453 |
+
mgs[10] = [x.strip() for x in mgs[10].split(",")]
|
| 454 |
+
mgs[11] = mgs[11].replace("[","").replace("]","").replace("'", "")
|
| 455 |
+
mgs[11] = [x.strip() for x in mgs[11].split(",")]
|
| 456 |
+
while len(mgs) < 14:
|
| 457 |
+
mgs.append("")
|
| 458 |
+
mgs[13] = "normal" if mgs[13] == "" else mgs[13]
|
| 459 |
+
return [gr.update(value = "setting loaded") ,*[gr.update(value = x) for x in mgs[0:14]]]
|
| 460 |
+
|
| 461 |
+
def add_to_seq(seq,maker):
|
| 462 |
+
return gr.Textbox.update(value = maker if seq=="" else seq+"\r\n"+maker)
|
| 463 |
+
|
| 464 |
+
def load_cachelist():
|
| 465 |
+
text = ""
|
| 466 |
+
for x in checkpoints_loaded.keys():
|
| 467 |
+
text = text +"\r\n"+ x.model_name
|
| 468 |
+
return text.replace("\r\n","",1)
|
| 469 |
+
|
| 470 |
+
def makerand(num):
|
| 471 |
+
text = ""
|
| 472 |
+
for x in range(int(num)):
|
| 473 |
+
text = text +"-1,"
|
| 474 |
+
text = text[:-1]
|
| 475 |
+
return text
|
| 476 |
+
|
| 477 |
+
#row_blockids,row_checkpoints,row_inputers,ygrid
|
| 478 |
+
def showxy(x,y):
|
| 479 |
+
flags =[False]*6
|
| 480 |
+
t = TYPESEG
|
| 481 |
+
txy = t[x] + t[y]
|
| 482 |
+
if "model" in txy : flags[1] = flags[2] = True
|
| 483 |
+
if "pinpoint" in txy : flags[0] = flags[2] = True
|
| 484 |
+
if "effective" in txy or "element" in txy : flags[4] = True
|
| 485 |
+
if "calcmode" in txy : flags[5] = True
|
| 486 |
+
if not "none" in t[y] : flags[3] = flags[2] = True
|
| 487 |
+
return [gr.update(visible = x) for x in flags]
|
| 488 |
+
|
| 489 |
+
def text2slider(text):
|
| 490 |
+
vals = [t.strip() for t in text.split(",")]
|
| 491 |
+
return [gr.update(value = float(v)) for v in vals]
|
| 492 |
+
|
| 493 |
+
def slider2text(a,b,c,d,e,f,g,h,i,j,k,l,m,n,o,p,q,r,s,t,u,v,w,x,y,z):
|
| 494 |
+
numbers = [a,b,c,d,e,f,g,h,i,j,k,l,m,n,o,p,q,r,s,t,u,v,w,x,y,z]
|
| 495 |
+
numbers = [str(x) for x in numbers]
|
| 496 |
+
return gr.update(value = ",".join(numbers) )
|
| 497 |
+
|
| 498 |
+
def tagdicter(presets):
|
| 499 |
+
presets=presets.splitlines()
|
| 500 |
+
wdict={}
|
| 501 |
+
for l in presets:
|
| 502 |
+
w=[]
|
| 503 |
+
if ":" in l :
|
| 504 |
+
key = l.split(":",1)[0]
|
| 505 |
+
w = l.split(":",1)[1]
|
| 506 |
+
if "\t" in l:
|
| 507 |
+
key = l.split("\t",1)[0]
|
| 508 |
+
w = l.split("\t",1)[1]
|
| 509 |
+
if len([w for w in w.split(",")]) == 26:
|
| 510 |
+
wdict[key.strip()]=w
|
| 511 |
+
return ",".join(list(wdict.keys()))
|
| 512 |
+
|
| 513 |
+
def loadkeys(model_a):
|
| 514 |
+
checkpoint_info = sd_models.get_closet_checkpoint_match(model_a)
|
| 515 |
+
sd = sd_models.read_state_dict(checkpoint_info.filename,"cpu")
|
| 516 |
+
keys = []
|
| 517 |
+
for i, key in enumerate(sd.keys()):
|
| 518 |
+
re_inp = re.compile(r'\.input_blocks\.(\d+)\.') # 12
|
| 519 |
+
re_mid = re.compile(r'\.middle_block\.(\d+)\.') # 1
|
| 520 |
+
re_out = re.compile(r'\.output_blocks\.(\d+)\.') # 12
|
| 521 |
+
|
| 522 |
+
weight_index = -1
|
| 523 |
+
blockid=["BASE","IN00","IN01","IN02","IN03","IN04","IN05","IN06","IN07","IN08","IN09","IN10","IN11","M00","OUT00","OUT01","OUT02","OUT03","OUT04","OUT05","OUT06","OUT07","OUT08","OUT09","OUT10","OUT11","Not Merge"]
|
| 524 |
+
|
| 525 |
+
NUM_INPUT_BLOCKS = 12
|
| 526 |
+
NUM_MID_BLOCK = 1
|
| 527 |
+
NUM_OUTPUT_BLOCKS = 12
|
| 528 |
+
NUM_TOTAL_BLOCKS = NUM_INPUT_BLOCKS + NUM_MID_BLOCK + NUM_OUTPUT_BLOCKS
|
| 529 |
+
|
| 530 |
+
if 'time_embed' in key:
|
| 531 |
+
weight_index = -2 # before input blocks
|
| 532 |
+
elif '.out.' in key:
|
| 533 |
+
weight_index = NUM_TOTAL_BLOCKS - 1 # after output blocks
|
| 534 |
+
else:
|
| 535 |
+
m = re_inp.search(key)
|
| 536 |
+
if m:
|
| 537 |
+
inp_idx = int(m.groups()[0])
|
| 538 |
+
weight_index = inp_idx
|
| 539 |
+
else:
|
| 540 |
+
m = re_mid.search(key)
|
| 541 |
+
if m:
|
| 542 |
+
weight_index = NUM_INPUT_BLOCKS
|
| 543 |
+
else:
|
| 544 |
+
m = re_out.search(key)
|
| 545 |
+
if m:
|
| 546 |
+
out_idx = int(m.groups()[0])
|
| 547 |
+
weight_index = NUM_INPUT_BLOCKS + NUM_MID_BLOCK + out_idx
|
| 548 |
+
keys.append([i,blockid[weight_index+1],key])
|
| 549 |
+
return keys
|
| 550 |
+
|
| 551 |
+
script_callbacks.on_ui_tabs(on_ui_tabs)
|
| 552 |
+
script_callbacks.on_ui_train_tabs(on_ui_train_tabs)
|
microsoftexcel-tunnels/.gitignore
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Created by https://www.toptal.com/developers/gitignore/api/python
|
| 2 |
+
# Edit at https://www.toptal.com/developers/gitignore?templates=python
|
| 3 |
+
|
| 4 |
+
### Python ###
|
| 5 |
+
# Byte-compiled / optimized / DLL files
|
| 6 |
+
__pycache__/
|
| 7 |
+
*.py[cod]
|
| 8 |
+
*$py.class
|
| 9 |
+
|
| 10 |
+
# C extensions
|
| 11 |
+
*.so
|
| 12 |
+
|
| 13 |
+
# Distribution / packaging
|
| 14 |
+
.Python
|
| 15 |
+
build/
|
| 16 |
+
develop-eggs/
|
| 17 |
+
dist/
|
| 18 |
+
downloads/
|
| 19 |
+
eggs/
|
| 20 |
+
.eggs/
|
| 21 |
+
lib/
|
| 22 |
+
lib64/
|
| 23 |
+
parts/
|
| 24 |
+
sdist/
|
| 25 |
+
var/
|
| 26 |
+
wheels/
|
| 27 |
+
share/python-wheels/
|
| 28 |
+
*.egg-info/
|
| 29 |
+
.installed.cfg
|
| 30 |
+
*.egg
|
| 31 |
+
MANIFEST
|
| 32 |
+
|
| 33 |
+
# PyInstaller
|
| 34 |
+
# Usually these files are written by a python script from a template
|
| 35 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 36 |
+
*.manifest
|
| 37 |
+
*.spec
|
| 38 |
+
|
| 39 |
+
# Installer logs
|
| 40 |
+
pip-log.txt
|
| 41 |
+
pip-delete-this-directory.txt
|
| 42 |
+
|
| 43 |
+
# Unit test / coverage reports
|
| 44 |
+
htmlcov/
|
| 45 |
+
.tox/
|
| 46 |
+
.nox/
|
| 47 |
+
.coverage
|
| 48 |
+
.coverage.*
|
| 49 |
+
.cache
|
| 50 |
+
nosetests.xml
|
| 51 |
+
coverage.xml
|
| 52 |
+
*.cover
|
| 53 |
+
*.py,cover
|
| 54 |
+
.hypothesis/
|
| 55 |
+
.pytest_cache/
|
| 56 |
+
cover/
|
| 57 |
+
|
| 58 |
+
# Translations
|
| 59 |
+
*.mo
|
| 60 |
+
*.pot
|
| 61 |
+
|
| 62 |
+
# Django stuff:
|
| 63 |
+
*.log
|
| 64 |
+
local_settings.py
|
| 65 |
+
db.sqlite3
|
| 66 |
+
db.sqlite3-journal
|
| 67 |
+
|
| 68 |
+
# Flask stuff:
|
| 69 |
+
instance/
|
| 70 |
+
.webassets-cache
|
| 71 |
+
|
| 72 |
+
# Scrapy stuff:
|
| 73 |
+
.scrapy
|
| 74 |
+
|
| 75 |
+
# Sphinx documentation
|
| 76 |
+
docs/_build/
|
| 77 |
+
|
| 78 |
+
# PyBuilder
|
| 79 |
+
.pybuilder/
|
| 80 |
+
target/
|
| 81 |
+
|
| 82 |
+
# Jupyter Notebook
|
| 83 |
+
.ipynb_checkpoints
|
| 84 |
+
|
| 85 |
+
# IPython
|
| 86 |
+
profile_default/
|
| 87 |
+
ipython_config.py
|
| 88 |
+
|
| 89 |
+
# pyenv
|
| 90 |
+
# For a library or package, you might want to ignore these files since the code is
|
| 91 |
+
# intended to run in multiple environments; otherwise, check them in:
|
| 92 |
+
# .python-version
|
| 93 |
+
|
| 94 |
+
# pipenv
|
| 95 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 96 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 97 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 98 |
+
# install all needed dependencies.
|
| 99 |
+
#Pipfile.lock
|
| 100 |
+
|
| 101 |
+
# poetry
|
| 102 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
| 103 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 104 |
+
# commonly ignored for libraries.
|
| 105 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
| 106 |
+
#poetry.lock
|
| 107 |
+
|
| 108 |
+
# pdm
|
| 109 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
| 110 |
+
#pdm.lock
|
| 111 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
| 112 |
+
# in version control.
|
| 113 |
+
# https://pdm.fming.dev/#use-with-ide
|
| 114 |
+
.pdm.toml
|
| 115 |
+
|
| 116 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
| 117 |
+
__pypackages__/
|
| 118 |
+
|
| 119 |
+
# Celery stuff
|
| 120 |
+
celerybeat-schedule
|
| 121 |
+
celerybeat.pid
|
| 122 |
+
|
| 123 |
+
# SageMath parsed files
|
| 124 |
+
*.sage.py
|
| 125 |
+
|
| 126 |
+
# Environments
|
| 127 |
+
.env
|
| 128 |
+
.venv
|
| 129 |
+
env/
|
| 130 |
+
venv/
|
| 131 |
+
ENV/
|
| 132 |
+
env.bak/
|
| 133 |
+
venv.bak/
|
| 134 |
+
|
| 135 |
+
# Spyder project settings
|
| 136 |
+
.spyderproject
|
| 137 |
+
.spyproject
|
| 138 |
+
|
| 139 |
+
# Rope project settings
|
| 140 |
+
.ropeproject
|
| 141 |
+
|
| 142 |
+
# mkdocs documentation
|
| 143 |
+
/site
|
| 144 |
+
|
| 145 |
+
# mypy
|
| 146 |
+
.mypy_cache/
|
| 147 |
+
.dmypy.json
|
| 148 |
+
dmypy.json
|
| 149 |
+
|
| 150 |
+
# Pyre type checker
|
| 151 |
+
.pyre/
|
| 152 |
+
|
| 153 |
+
# pytype static type analyzer
|
| 154 |
+
.pytype/
|
| 155 |
+
|
| 156 |
+
# Cython debug symbols
|
| 157 |
+
cython_debug/
|
| 158 |
+
|
| 159 |
+
# PyCharm
|
| 160 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
| 161 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
| 162 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
| 163 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 164 |
+
#.idea/
|
| 165 |
+
|
| 166 |
+
### Python Patch ###
|
| 167 |
+
# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
|
| 168 |
+
poetry.toml
|
| 169 |
+
|
| 170 |
+
# ruff
|
| 171 |
+
.ruff_cache/
|
| 172 |
+
|
| 173 |
+
# End of https://www.toptal.com/developers/gitignore/api/python
|
| 174 |
+
|
| 175 |
+
id_rsa
|
| 176 |
+
id_rsa.pub
|
microsoftexcel-tunnels/.pre-commit-config.yaml
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
repos:
|
| 2 |
+
- repo: https://github.com/pre-commit/pre-commit-hooks
|
| 3 |
+
rev: v4.4.0
|
| 4 |
+
hooks:
|
| 5 |
+
- id: trailing-whitespace
|
| 6 |
+
args: [--markdown-linebreak-ext=md]
|
| 7 |
+
- id: end-of-file-fixer
|
| 8 |
+
|
| 9 |
+
- repo: https://github.com/asottile/pyupgrade
|
| 10 |
+
rev: v3.3.1
|
| 11 |
+
hooks:
|
| 12 |
+
- id: pyupgrade
|
| 13 |
+
args: [--py310-plus]
|
| 14 |
+
|
| 15 |
+
- repo: https://github.com/psf/black
|
| 16 |
+
rev: 23.1.0
|
| 17 |
+
hooks:
|
| 18 |
+
- id: black
|
| 19 |
+
|
| 20 |
+
- repo: https://github.com/charliermarsh/ruff-pre-commit
|
| 21 |
+
# Ruff version.
|
| 22 |
+
rev: "v0.0.244"
|
| 23 |
+
hooks:
|
| 24 |
+
- id: ruff
|
| 25 |
+
args: [--fix]
|
microsoftexcel-tunnels/LICENSE.md
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
The MIT License (MIT)
|
| 3 |
+
|
| 4 |
+
Copyright (c) 2023 Bingsu
|
| 5 |
+
|
| 6 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 7 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 8 |
+
in the Software without restriction, including without limitation the rights
|
| 9 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 10 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 11 |
+
furnished to do so, subject to the following conditions:
|
| 12 |
+
|
| 13 |
+
The above copyright notice and this permission notice shall be included in all
|
| 14 |
+
copies or substantial portions of the Software.
|
| 15 |
+
|
| 16 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 17 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 18 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 19 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 20 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 21 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 22 |
+
SOFTWARE.
|
microsoftexcel-tunnels/README.md
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# sd-webui-tunnels
|
| 2 |
+
|
| 3 |
+
Tunneling extension for [AUTOMATIC1111/stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui)
|
| 4 |
+
|
| 5 |
+
## Usage
|
| 6 |
+
|
| 7 |
+
### [cloudflared](https://try.cloudflare.com/)
|
| 8 |
+
|
| 9 |
+
add `--cloudflared` to commandline options.
|
| 10 |
+
|
| 11 |
+
### [localhost.run](https://localhost.run/)
|
| 12 |
+
|
| 13 |
+
add `--localhostrun` to commandline options.
|
| 14 |
+
|
| 15 |
+
### [remote.moe](https://github.com/fasmide/remotemoe)
|
| 16 |
+
|
| 17 |
+
add `--remotemoe` to commandline options.
|
| 18 |
+
|
| 19 |
+
The feature of `remote.moe` is that as long as the same ssh key is used, the same url is generated.
|
| 20 |
+
|
| 21 |
+
The ssh keys for `localhost.run` and `remote.moe` are created with the name `id_rsa` in the script's root folder. However, if there is a problem with the write permission, it is created in a temporary folder instead, so a different url is created each time.
|
microsoftexcel-tunnels/__pycache__/preload.cpython-310.pyc
ADDED
|
Binary file (623 Bytes). View file
|
|
|
microsoftexcel-tunnels/install.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import launch
|
| 2 |
+
|
| 3 |
+
if not launch.is_installed("pycloudflared"):
|
| 4 |
+
launch.run_pip("install pycloudflared", "pycloudflared")
|
microsoftexcel-tunnels/preload.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def preload(parser: argparse.ArgumentParser):
|
| 5 |
+
parser.add_argument(
|
| 6 |
+
"--cloudflared",
|
| 7 |
+
action="store_true",
|
| 8 |
+
help="use trycloudflare, alternative to gradio --share",
|
| 9 |
+
)
|
| 10 |
+
|
| 11 |
+
parser.add_argument(
|
| 12 |
+
"--localhostrun",
|
| 13 |
+
action="store_true",
|
| 14 |
+
help="use localhost.run, alternative to gradio --share",
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
parser.add_argument(
|
| 18 |
+
"--remotemoe",
|
| 19 |
+
action="store_true",
|
| 20 |
+
help="use remote.moe, alternative to gradio --share",
|
| 21 |
+
)
|
microsoftexcel-tunnels/pyproject.toml
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[project]
|
| 2 |
+
name = "sd-webui-tunnels"
|
| 3 |
+
version = "23.2.1"
|
| 4 |
+
description = "Tunneling extension for automatic1111 sd-webui"
|
| 5 |
+
authors = [
|
| 6 |
+
{name = "dowon", email = "ks2515@naver.com"},
|
| 7 |
+
]
|
| 8 |
+
requires-python = ">=3.8"
|
| 9 |
+
readme = "README.md"
|
| 10 |
+
license = {text = "MIT"}
|
| 11 |
+
|
| 12 |
+
[project.urls]
|
| 13 |
+
repository = "https://github.com/Bing-su/sd-webui-tunnels"
|
| 14 |
+
|
| 15 |
+
[tool.isort]
|
| 16 |
+
profile = "black"
|
| 17 |
+
known_first_party = ["modules", "launch"]
|
| 18 |
+
|
| 19 |
+
[tool.ruff]
|
| 20 |
+
select = ["A", "B", "C4", "E", "F", "I001", "N", "PT", "UP", "W"]
|
| 21 |
+
ignore = ["B008", "B905", "E501"]
|
| 22 |
+
unfixable = ["F401"]
|
| 23 |
+
|
| 24 |
+
[tool.ruff.isort]
|
| 25 |
+
known-first-party = ["modules", "launch"]
|
microsoftexcel-tunnels/scripts/__pycache__/ssh_tunnel.cpython-310.pyc
ADDED
|
Binary file (2.34 kB). View file
|
|
|
microsoftexcel-tunnels/scripts/__pycache__/try_cloudflare.cpython-310.pyc
ADDED
|
Binary file (597 Bytes). View file
|
|
|
microsoftexcel-tunnels/scripts/ssh_tunnel.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import atexit
|
| 2 |
+
import re
|
| 3 |
+
import shlex
|
| 4 |
+
import subprocess
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from tempfile import TemporaryDirectory
|
| 7 |
+
from typing import Union
|
| 8 |
+
from gradio import strings
|
| 9 |
+
import os
|
| 10 |
+
|
| 11 |
+
from modules.shared import cmd_opts
|
| 12 |
+
|
| 13 |
+
LOCALHOST_RUN = "localhost.run"
|
| 14 |
+
REMOTE_MOE = "remote.moe"
|
| 15 |
+
localhostrun_pattern = re.compile(r"(?P<url>https?://\S+\.lhr\.life)")
|
| 16 |
+
remotemoe_pattern = re.compile(r"(?P<url>https?://\S+\.remote\.moe)")
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def gen_key(path: Union[str, Path]) -> None:
|
| 20 |
+
path = Path(path)
|
| 21 |
+
arg_string = f'ssh-keygen -t rsa -b 4096 -N "" -q -f {path.as_posix()}'
|
| 22 |
+
args = shlex.split(arg_string)
|
| 23 |
+
subprocess.run(args, check=True)
|
| 24 |
+
path.chmod(0o600)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def ssh_tunnel(host: str = LOCALHOST_RUN) -> None:
|
| 28 |
+
ssh_name = "id_rsa"
|
| 29 |
+
ssh_path = Path(__file__).parent.parent / ssh_name
|
| 30 |
+
|
| 31 |
+
tmp = None
|
| 32 |
+
if not ssh_path.exists():
|
| 33 |
+
try:
|
| 34 |
+
gen_key(ssh_path)
|
| 35 |
+
# write permission error or etc
|
| 36 |
+
except subprocess.CalledProcessError:
|
| 37 |
+
tmp = TemporaryDirectory()
|
| 38 |
+
ssh_path = Path(tmp.name) / ssh_name
|
| 39 |
+
gen_key(ssh_path)
|
| 40 |
+
|
| 41 |
+
port = cmd_opts.port if cmd_opts.port else 7860
|
| 42 |
+
|
| 43 |
+
arg_string = f"ssh -R 80:127.0.0.1:{port} -o StrictHostKeyChecking=no -i {ssh_path.as_posix()} {host}"
|
| 44 |
+
args = shlex.split(arg_string)
|
| 45 |
+
|
| 46 |
+
tunnel = subprocess.Popen(
|
| 47 |
+
args, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, encoding="utf-8"
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
atexit.register(tunnel.terminate)
|
| 51 |
+
if tmp is not None:
|
| 52 |
+
atexit.register(tmp.cleanup)
|
| 53 |
+
|
| 54 |
+
tunnel_url = ""
|
| 55 |
+
lines = 27 if host == LOCALHOST_RUN else 5
|
| 56 |
+
pattern = localhostrun_pattern if host == LOCALHOST_RUN else remotemoe_pattern
|
| 57 |
+
|
| 58 |
+
for _ in range(lines):
|
| 59 |
+
line = tunnel.stdout.readline()
|
| 60 |
+
if line.startswith("Warning"):
|
| 61 |
+
print(line, end="")
|
| 62 |
+
|
| 63 |
+
url_match = pattern.search(line)
|
| 64 |
+
if url_match:
|
| 65 |
+
tunnel_url = url_match.group("url")
|
| 66 |
+
break
|
| 67 |
+
else:
|
| 68 |
+
raise RuntimeError(f"Failed to run {host}")
|
| 69 |
+
|
| 70 |
+
# print(f" * Running on {tunnel_url}")
|
| 71 |
+
os.environ['webui_url'] = tunnel_url
|
| 72 |
+
colab_url = os.getenv('colab_url')
|
| 73 |
+
strings.en["SHARE_LINK_MESSAGE"] = f"Running on public URL (recommended): {tunnel_url}"
|
| 74 |
+
|
| 75 |
+
if cmd_opts.localhostrun:
|
| 76 |
+
print("localhost.run detected, trying to connect...")
|
| 77 |
+
ssh_tunnel(LOCALHOST_RUN)
|
| 78 |
+
|
| 79 |
+
if cmd_opts.remotemoe:
|
| 80 |
+
print("remote.moe detected, trying to connect...")
|
| 81 |
+
ssh_tunnel(REMOTE_MOE)
|
microsoftexcel-tunnels/scripts/try_cloudflare.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# credit to camenduru senpai
|
| 2 |
+
from pycloudflared import try_cloudflare
|
| 3 |
+
|
| 4 |
+
from modules.shared import cmd_opts
|
| 5 |
+
|
| 6 |
+
from gradio import strings
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
|
| 10 |
+
if cmd_opts.cloudflared:
|
| 11 |
+
print("cloudflared detected, trying to connect...")
|
| 12 |
+
port = cmd_opts.port if cmd_opts.port else 7860
|
| 13 |
+
tunnel_url = try_cloudflare(port=port, verbose=False)
|
| 14 |
+
os.environ['webui_url'] = tunnel_url.tunnel
|
| 15 |
+
strings.en["PUBLIC_SHARE_TRUE"] = f"Running on public URL: {tunnel_url.tunnel}"
|
microsoftexcel-tunnels/ssh_tunnel.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import atexit
|
| 2 |
+
import re
|
| 3 |
+
import shlex
|
| 4 |
+
import subprocess
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from tempfile import TemporaryDirectory
|
| 7 |
+
from typing import Union
|
| 8 |
+
from gradio import strings
|
| 9 |
+
import os
|
| 10 |
+
|
| 11 |
+
from modules.shared import cmd_opts
|
| 12 |
+
|
| 13 |
+
LOCALHOST_RUN = "localhost.run"
|
| 14 |
+
REMOTE_MOE = "remote.moe"
|
| 15 |
+
localhostrun_pattern = re.compile(r"(?P<url>https?://\S+\.lhr\.life)")
|
| 16 |
+
remotemoe_pattern = re.compile(r"(?P<url>https?://\S+\.remote\.moe)")
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def gen_key(path: Union[str, Path]) -> None:
|
| 20 |
+
path = Path(path)
|
| 21 |
+
arg_string = f'ssh-keygen -t rsa -b 4096 -N "" -q -f {path.as_posix()}'
|
| 22 |
+
args = shlex.split(arg_string)
|
| 23 |
+
subprocess.run(args, check=True)
|
| 24 |
+
path.chmod(0o600)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def ssh_tunnel(host: str = LOCALHOST_RUN) -> None:
|
| 28 |
+
ssh_name = "id_rsa"
|
| 29 |
+
ssh_path = Path(__file__).parent.parent / ssh_name
|
| 30 |
+
|
| 31 |
+
tmp = None
|
| 32 |
+
if not ssh_path.exists():
|
| 33 |
+
try:
|
| 34 |
+
gen_key(ssh_path)
|
| 35 |
+
# write permission error or etc
|
| 36 |
+
except subprocess.CalledProcessError:
|
| 37 |
+
tmp = TemporaryDirectory()
|
| 38 |
+
ssh_path = Path(tmp.name) / ssh_name
|
| 39 |
+
gen_key(ssh_path)
|
| 40 |
+
|
| 41 |
+
port = cmd_opts.port if cmd_opts.port else 7860
|
| 42 |
+
|
| 43 |
+
arg_string = f"ssh -R 80:127.0.0.1:{port} -o StrictHostKeyChecking=no -i {ssh_path.as_posix()} {host}"
|
| 44 |
+
args = shlex.split(arg_string)
|
| 45 |
+
|
| 46 |
+
tunnel = subprocess.Popen(
|
| 47 |
+
args, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, encoding="utf-8"
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
atexit.register(tunnel.terminate)
|
| 51 |
+
if tmp is not None:
|
| 52 |
+
atexit.register(tmp.cleanup)
|
| 53 |
+
|
| 54 |
+
tunnel_url = ""
|
| 55 |
+
lines = 27 if host == LOCALHOST_RUN else 5
|
| 56 |
+
pattern = localhostrun_pattern if host == LOCALHOST_RUN else remotemoe_pattern
|
| 57 |
+
|
| 58 |
+
for _ in range(lines):
|
| 59 |
+
line = tunnel.stdout.readline()
|
| 60 |
+
if line.startswith("Warning"):
|
| 61 |
+
print(line, end="")
|
| 62 |
+
|
| 63 |
+
url_match = pattern.search(line)
|
| 64 |
+
if url_match:
|
| 65 |
+
tunnel_url = url_match.group("url")
|
| 66 |
+
break
|
| 67 |
+
else:
|
| 68 |
+
raise RuntimeError(f"Failed to run {host}")
|
| 69 |
+
|
| 70 |
+
# print(f" * Running on {tunnel_url}")
|
| 71 |
+
os.environ['webui_url'] = tunnel_url
|
| 72 |
+
colab_url = os.getenv('colab_url')
|
| 73 |
+
strings.en["SHARE_LINK_MESSAGE"] = f"Public WebUI Colab URL: {tunnel_url}"
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def googleusercontent_tunnel():
|
| 77 |
+
colab_url = os.getenv('colab_url')
|
| 78 |
+
strings.en["SHARE_LINK_MESSAGE"] = f"WebUI Colab URL: {colab_url}"
|
| 79 |
+
|
| 80 |
+
if cmd_opts.localhostrun:
|
| 81 |
+
print("localhost.run detected, trying to connect...")
|
| 82 |
+
ssh_tunnel(LOCALHOST_RUN)
|
| 83 |
+
|
| 84 |
+
if cmd_opts.remotemoe:
|
| 85 |
+
print("remote.moe detected, trying to connect...")
|
| 86 |
+
ssh_tunnel(REMOTE_MOE)
|
openpose-editor/.github/CODEOWNERS
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# default reviewer
|
| 2 |
+
* @fkunn1326
|
openpose-editor/.github/ISSUE_TEMPLATE/bug_report.md
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
name: Bug report
|
| 3 |
+
about: Create a report to help us improve
|
| 4 |
+
title: ''
|
| 5 |
+
labels: bug
|
| 6 |
+
assignees: ''
|
| 7 |
+
|
| 8 |
+
---
|
| 9 |
+
|
| 10 |
+
**Describe the bug**
|
| 11 |
+
A clear and concise description of what the bug is.
|
| 12 |
+
|
| 13 |
+
**To Reproduce**
|
| 14 |
+
Steps to reproduce the behavior:
|
| 15 |
+
1. Go to '...'
|
| 16 |
+
2. Click on '....'
|
| 17 |
+
3. Scroll down to '....'
|
| 18 |
+
4. See error
|
| 19 |
+
|
| 20 |
+
**Expected behavior**
|
| 21 |
+
A clear and concise description of what you expected to happen.
|
| 22 |
+
|
| 23 |
+
**Screenshots**
|
| 24 |
+
If applicable, add screenshots to help explain your problem.
|
| 25 |
+
|
| 26 |
+
**Environment**
|
| 27 |
+
- OS: [e.g. iOS]
|
| 28 |
+
- Browser [e.g. chrome, safari]
|
| 29 |
+
|
| 30 |
+
**Additional context**
|
| 31 |
+
Add any other context about the problem here.
|
openpose-editor/.github/ISSUE_TEMPLATE/feature_request.md
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
name: Feature request
|
| 3 |
+
about: Suggest an idea for this project
|
| 4 |
+
title: ''
|
| 5 |
+
labels: enhancement
|
| 6 |
+
assignees: ''
|
| 7 |
+
|
| 8 |
+
---
|
| 9 |
+
|
| 10 |
+
**Is your feature request related to a problem? Please describe.**
|
| 11 |
+
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
|
| 12 |
+
|
| 13 |
+
**Describe the solution you'd like**
|
| 14 |
+
A clear and concise description of what you want to happen.
|
| 15 |
+
|
| 16 |
+
**Additional context**
|
| 17 |
+
Add any other context or screenshots about the feature request here.
|
openpose-editor/.github/workflows/typos.yml
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
# yamllint disable rule:line-length
|
| 3 |
+
name: Typos
|
| 4 |
+
|
| 5 |
+
on: # yamllint disable-line rule:truthy
|
| 6 |
+
push:
|
| 7 |
+
pull_request:
|
| 8 |
+
types:
|
| 9 |
+
- opened
|
| 10 |
+
- synchronize
|
| 11 |
+
- reopened
|
| 12 |
+
|
| 13 |
+
jobs:
|
| 14 |
+
build:
|
| 15 |
+
runs-on: ubuntu-latest
|
| 16 |
+
|
| 17 |
+
steps:
|
| 18 |
+
- uses: actions/checkout@v3
|
| 19 |
+
|
| 20 |
+
- name: typos-action
|
| 21 |
+
uses: crate-ci/typos@v1.13.10
|
openpose-editor/.gitignore
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__
|
| 2 |
+
models/
|
openpose-editor/.vscode/settings.json
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"nuxt.isNuxtApp": false,
|
| 3 |
+
"editor.tabCompletion": "on",
|
| 4 |
+
"diffEditor.codeLens": true
|
| 5 |
+
}
|
openpose-editor/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2023 Fkunn1326
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
openpose-editor/README.en.md
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## Openpose Editor
|
| 2 |
+
|
| 3 |
+
[日本語](README.md) | English | [简体中文](README.zh-cn.md)
|
| 4 |
+
|
| 5 |
+

|
| 6 |
+
|
| 7 |
+
Openpose Editor for Automatic1111/stable-diffusion-webui
|
| 8 |
+
|
| 9 |
+
- Pose editing
|
| 10 |
+
- Pose detection
|
| 11 |
+
|
| 12 |
+
This can:
|
| 13 |
+
|
| 14 |
+
- Add a new person
|
| 15 |
+
- Detect pose from an image
|
| 16 |
+
- Add background image
|
| 17 |
+
|
| 18 |
+
- Save as a PNG
|
| 19 |
+
- Send to ControlNet extension
|
| 20 |
+
|
| 21 |
+
## Installation
|
| 22 |
+
|
| 23 |
+
1. Open the "Extension" tab
|
| 24 |
+
2. Click on "Install from URL"
|
| 25 |
+
3. In "URL for extension's git repository" enter this extension, https://github.com/fkunn1326/openpose-editor.git
|
| 26 |
+
4. Click "Install"
|
| 27 |
+
5. Restart WebUI
|
| 28 |
+
|
| 29 |
+
## Attention
|
| 30 |
+
|
| 31 |
+
Do not select anything for the Preprocessor in ControlNet.
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
## Fix Error
|
| 35 |
+
> urllib.error.URLError: <urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: unable to get local issuer certificate (_ssl.c:997)>
|
| 36 |
+
|
| 37 |
+
Run
|
| 38 |
+
```
|
| 39 |
+
/Applications/Python\ $version /Install\ Certificates.command
|
| 40 |
+
```
|
openpose-editor/README.md
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## Openpose Editor
|
| 2 |
+
|
| 3 |
+
日本語 | [English](README.en.md) | [简体中文](README.zh-cn.md)
|
| 4 |
+
|
| 5 |
+

|
| 6 |
+
|
| 7 |
+
Automatic1111/stable-diffusion-webui用のOpenpose Editor
|
| 8 |
+
|
| 9 |
+
- ポーズの編集
|
| 10 |
+
- ポーズの検出
|
| 11 |
+
|
| 12 |
+
ができます
|
| 13 |
+
|
| 14 |
+
- 「Add」: 人を追加する
|
| 15 |
+
- 「Detect from image」: 画像からポーズを検出する
|
| 16 |
+
- 「Add Background image」: 背景を追加する
|
| 17 |
+
|
| 18 |
+
- 「Save PNG」: PNGで保存する
|
| 19 |
+
- 「Send to ControlNet」: Controlnet拡張機能がインストールされている場合、画像をそこに送る
|
| 20 |
+
|
| 21 |
+
## インストール方法
|
| 22 |
+
|
| 23 |
+
1. "Extension" タブを開く
|
| 24 |
+
2. "Install from URL" タブを開く
|
| 25 |
+
3. "URL for extension's git repository" 欄にこのリポジトリの URL (https://github.com/fkunn1326/openpose-editor.git) を入れます。
|
| 26 |
+
4. "Install" ボタンを押す
|
| 27 |
+
5. WebUIを再起動する
|
| 28 |
+
|
| 29 |
+
## 注意
|
| 30 |
+
|
| 31 |
+
ControlNetの "Preprocessor" には、何も指定しないようにしてください。
|
| 32 |
+
|
| 33 |
+
## エラーの対策
|
| 34 |
+
|
| 35 |
+
> urllib.error.URLError: <urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: unable to get local issuer certificate (_ssl.c:997)>
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
以下のファイルを開いてださい
|
| 39 |
+
```
|
| 40 |
+
/Applications/Python\ $version /Install\ Certificates.command
|
| 41 |
+
```
|
openpose-editor/README.zh-cn.md
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## Openpose Editor
|
| 2 |
+
|
| 3 |
+
[日本語](README.md) | [English](README.en.md)|中文
|
| 4 |
+
|
| 5 |
+

|
| 6 |
+
|
| 7 |
+
适用于Automatic1111/stable-diffusion-webui 的Openpose Editor 插件。
|
| 8 |
+
|
| 9 |
+
功能:
|
| 10 |
+
- 直接编辑骨骼动作
|
| 11 |
+
- 从图像识别姿势
|
| 12 |
+
|
| 13 |
+
本插件实现以下操作:
|
| 14 |
+
|
| 15 |
+
- 「Add」:添加一个新骨骼
|
| 16 |
+
- 「Detect from image」: 从图片中识别姿势
|
| 17 |
+
- 「Add Background image」: 添加背景图片
|
| 18 |
+
- 「Load JSON」:载入JSON文件
|
| 19 |
+
|
| 20 |
+
- 「Save PNG」: 保存为PNG格式图片
|
| 21 |
+
- 「Send to ControlNet」:将骨骼姿势发送到 ControlNet
|
| 22 |
+
- 「Save JSON」:将骨骼保存为JSON
|
| 23 |
+
## 安装方法
|
| 24 |
+
|
| 25 |
+
1. 打开扩展(Extension)标签。
|
| 26 |
+
2. 点击从网址安装(Install from URL)
|
| 27 |
+
3. 在扩展的 git 仓库网址(URL for extension's git repository)处输入 https://github.com/fkunn1326/openpose-editor.git
|
| 28 |
+
4. 点击安装(Install)
|
| 29 |
+
5. 重启 WebUI
|
| 30 |
+
## 注意
|
| 31 |
+
|
| 32 |
+
不要给ConrtolNet 的 "Preprocessor" 选项指定任何值,请保持在none状态
|
| 33 |
+
|
| 34 |
+
## 常见问题
|
| 35 |
+
Mac OS可能会出现:
|
| 36 |
+
> urllib.error.URLError: <urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: unable to get local issuer certificate (_ssl.c:997)>
|
| 37 |
+
|
| 38 |
+
请执行文件
|
| 39 |
+
```
|
| 40 |
+
/Applications/Python\ $version /Install\ Certificates.command
|
| 41 |
+
```
|
openpose-editor/_typos.toml
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Files for typos
|
| 2 |
+
# Instruction: https://github.com/marketplace/actions/typos-action#getting-started
|
| 3 |
+
|
| 4 |
+
[default.extend-identifiers]
|
| 5 |
+
|
| 6 |
+
[default.extend-words]
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
[files]
|
| 11 |
+
extend-exclude = ["_typos.toml", "fabric.js"]
|
openpose-editor/configs/.gitkeep
ADDED
|
File without changes
|
openpose-editor/images//343/202/271/343/202/257/343/203/252/343/203/274/343/203/263/343/202/267/343/203/247/343/203/203/343/203/210 2023-02-19 131430.png
ADDED
|
openpose-editor/javascript/fabric.js
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
openpose-editor/javascript/main.js
ADDED
|
@@ -0,0 +1,691 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fabric.Object.prototype.transparentCorners = false;
|
| 2 |
+
fabric.Object.prototype.cornerColor = '#108ce6';
|
| 3 |
+
fabric.Object.prototype.borderColor = '#108ce6';
|
| 4 |
+
fabric.Object.prototype.cornerSize = 10;
|
| 5 |
+
|
| 6 |
+
let count = 0;
|
| 7 |
+
let executed_openpose_editor = false;
|
| 8 |
+
|
| 9 |
+
let lockMode = false;
|
| 10 |
+
const undo_history = [];
|
| 11 |
+
const redo_history = [];
|
| 12 |
+
|
| 13 |
+
coco_body_keypoints = [
|
| 14 |
+
"nose",
|
| 15 |
+
"neck",
|
| 16 |
+
"right_shoulder",
|
| 17 |
+
"right_elbow",
|
| 18 |
+
"right_wrist",
|
| 19 |
+
"left_shoulder",
|
| 20 |
+
"left_elbow",
|
| 21 |
+
"left_wrist",
|
| 22 |
+
"right_hip",
|
| 23 |
+
"right_knee",
|
| 24 |
+
"right_ankle",
|
| 25 |
+
"left_hip",
|
| 26 |
+
"left_knee",
|
| 27 |
+
"left_ankle",
|
| 28 |
+
"right_eye",
|
| 29 |
+
"left_eye",
|
| 30 |
+
"right_ear",
|
| 31 |
+
"left_ear",
|
| 32 |
+
]
|
| 33 |
+
|
| 34 |
+
let connect_keypoints = [[0, 1], [1, 2], [2, 3], [3, 4], [1, 5], [5, 6], [6, 7], [1, 8], [8, 9], [9, 10], [1, 11], [11, 12], [12, 13], [14, 0], [14, 16], [15, 0], [15, 17]]
|
| 35 |
+
|
| 36 |
+
let connect_color = [[0, 0, 255], [255, 0, 0], [255, 170, 0], [255, 255, 0], [255, 85, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0],
|
| 37 |
+
[0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [85, 0, 255],
|
| 38 |
+
[170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
|
| 39 |
+
|
| 40 |
+
let openpose_obj = {
|
| 41 |
+
// width, height
|
| 42 |
+
resolution: [512, 512],
|
| 43 |
+
// fps...?
|
| 44 |
+
fps: 1,
|
| 45 |
+
// frames
|
| 46 |
+
frames: [
|
| 47 |
+
{
|
| 48 |
+
frame_current: 1,
|
| 49 |
+
// armatures
|
| 50 |
+
armatures: {
|
| 51 |
+
},
|
| 52 |
+
}
|
| 53 |
+
]
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
let visibleEyes = true;
|
| 57 |
+
let flipped = false;
|
| 58 |
+
const default_keypoints = [[241,77],[241,120],[191,118],[177,183],[163,252],[298,118],[317,182],[332,245],[225,241],[213,359],[215,454],[270,240],[282,360],[286,456],[232,59],[253,60],[225,70],[260,72]]
|
| 59 |
+
|
| 60 |
+
async function fileToDataUrl(file) {
|
| 61 |
+
if (file.data) {
|
| 62 |
+
// Gradio version < 3.23
|
| 63 |
+
return file.data
|
| 64 |
+
}
|
| 65 |
+
return await new Promise(r => {let a=new FileReader(); a.onload=r; a.readAsDataURL(file.blob)}).then(e => e.target.result)
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
function calcResolution(width, height){
|
| 69 |
+
const viewportWidth = window.innerWidth / 2.25;
|
| 70 |
+
const viewportHeight = window.innerHeight * 0.75;
|
| 71 |
+
const ratio = Math.min(viewportWidth / width, viewportHeight / height);
|
| 72 |
+
return {width: width * ratio, height: height * ratio}
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
function resizeCanvas(width, height){
|
| 76 |
+
const elem = openpose_editor_elem;
|
| 77 |
+
const canvas = openpose_editor_canvas;
|
| 78 |
+
|
| 79 |
+
let resolution = calcResolution(width, height)
|
| 80 |
+
|
| 81 |
+
canvas.setWidth(width);
|
| 82 |
+
canvas.setHeight(height);
|
| 83 |
+
elem.style.width = resolution["width"] + "px"
|
| 84 |
+
elem.style.height = resolution["height"] + "px"
|
| 85 |
+
elem.nextElementSibling.style.width = resolution["width"] + "px"
|
| 86 |
+
elem.nextElementSibling.style.height = resolution["height"] + "px"
|
| 87 |
+
elem.parentElement.style.width = resolution["width"] + "px"
|
| 88 |
+
elem.parentElement.style.height = resolution["height"] + "px"
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
function undo() {
|
| 92 |
+
const canvas = openpose_editor_canvas;
|
| 93 |
+
if (undo_history.length > 0) {
|
| 94 |
+
lockMode = true;
|
| 95 |
+
if (undo_history.length > 1) redo_history.push(undo_history.pop());
|
| 96 |
+
const content = undo_history[undo_history.length - 1];
|
| 97 |
+
canvas.loadFromJSON(content, function () {
|
| 98 |
+
canvas.renderAll();
|
| 99 |
+
lockMode = false;
|
| 100 |
+
});
|
| 101 |
+
}
|
| 102 |
+
}
|
| 103 |
+
|
| 104 |
+
function redo() {
|
| 105 |
+
const canvas = openpose_editor_canvas;
|
| 106 |
+
if (redo_history.length > 0) {
|
| 107 |
+
lockMode = true;
|
| 108 |
+
const content = redo_history.pop();
|
| 109 |
+
undo_history.push(content);
|
| 110 |
+
canvas.loadFromJSON(content, function () {
|
| 111 |
+
canvas.renderAll();
|
| 112 |
+
lockMode = false;
|
| 113 |
+
});
|
| 114 |
+
}
|
| 115 |
+
}
|
| 116 |
+
|
| 117 |
+
function setPose(keypoints){
|
| 118 |
+
const canvas = openpose_editor_canvas;
|
| 119 |
+
|
| 120 |
+
canvas.clear()
|
| 121 |
+
|
| 122 |
+
canvas.backgroundColor = "#000"
|
| 123 |
+
|
| 124 |
+
const res = [];
|
| 125 |
+
for (let i = 0; i < keypoints.length; i += 18) {
|
| 126 |
+
const chunk = keypoints.slice(i, i + 18);
|
| 127 |
+
res.push(chunk);
|
| 128 |
+
}
|
| 129 |
+
|
| 130 |
+
for (item of res){
|
| 131 |
+
addPose(item)
|
| 132 |
+
openpose_editor_canvas.discardActiveObject();
|
| 133 |
+
}
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
function addPose(keypoints=undefined){
|
| 137 |
+
if (keypoints === undefined){
|
| 138 |
+
keypoints = default_keypoints;
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
const canvas = openpose_editor_canvas;
|
| 142 |
+
const group = new fabric.Group()
|
| 143 |
+
|
| 144 |
+
function makeCircle(color, left, top, line1, line2, line3, line4, line5) {
|
| 145 |
+
var c = new fabric.Circle({
|
| 146 |
+
left: left,
|
| 147 |
+
top: top,
|
| 148 |
+
strokeWidth: 1,
|
| 149 |
+
radius: 5,
|
| 150 |
+
fill: color,
|
| 151 |
+
stroke: color,
|
| 152 |
+
originX: 'center',
|
| 153 |
+
originY: 'center',
|
| 154 |
+
});
|
| 155 |
+
c.hasControls = c.hasBorders = false;
|
| 156 |
+
|
| 157 |
+
c.line1 = line1;
|
| 158 |
+
c.line2 = line2;
|
| 159 |
+
c.line3 = line3;
|
| 160 |
+
c.line4 = line4;
|
| 161 |
+
c.line5 = line5;
|
| 162 |
+
|
| 163 |
+
return c;
|
| 164 |
+
}
|
| 165 |
+
|
| 166 |
+
function makeLine(coords, color) {
|
| 167 |
+
return new fabric.Line(coords, {
|
| 168 |
+
fill: color,
|
| 169 |
+
stroke: color,
|
| 170 |
+
strokeWidth: 10,
|
| 171 |
+
selectable: false,
|
| 172 |
+
evented: false,
|
| 173 |
+
originX: 'center',
|
| 174 |
+
originY: 'center',
|
| 175 |
+
});
|
| 176 |
+
}
|
| 177 |
+
|
| 178 |
+
const lines = []
|
| 179 |
+
const circles = []
|
| 180 |
+
|
| 181 |
+
for (i = 0; i < connect_keypoints.length; i++){
|
| 182 |
+
// 接続されるidxを指定 [0, 1]なら0と1つなぐ
|
| 183 |
+
const item = connect_keypoints[i]
|
| 184 |
+
const line = makeLine(keypoints[item[0]].concat(keypoints[item[1]]), `rgba(${connect_color[i].join(", ")}, 0.7)`)
|
| 185 |
+
lines.push(line)
|
| 186 |
+
canvas.add(line)
|
| 187 |
+
line['id'] = item[0];
|
| 188 |
+
}
|
| 189 |
+
|
| 190 |
+
for (i = 0; i < keypoints.length; i++){
|
| 191 |
+
list = []
|
| 192 |
+
connect_keypoints.filter((item, idx) => {
|
| 193 |
+
if(item.includes(i)){
|
| 194 |
+
list.push(lines[idx])
|
| 195 |
+
return idx
|
| 196 |
+
}
|
| 197 |
+
})
|
| 198 |
+
circle = makeCircle(`rgb(${connect_color[i].join(", ")})`, keypoints[i][0], keypoints[i][1], ...list)
|
| 199 |
+
circle["id"] = i
|
| 200 |
+
circles.push(circle)
|
| 201 |
+
// canvas.add(circle)
|
| 202 |
+
group.addWithUpdate(circle);
|
| 203 |
+
}
|
| 204 |
+
|
| 205 |
+
canvas.discardActiveObject();
|
| 206 |
+
canvas.setActiveObject(group);
|
| 207 |
+
canvas.add(group);
|
| 208 |
+
group.toActiveSelection();
|
| 209 |
+
canvas.requestRenderAll();
|
| 210 |
+
}
|
| 211 |
+
|
| 212 |
+
function initCanvas(elem){
|
| 213 |
+
const canvas = window.openpose_editor_canvas = new fabric.Canvas(elem, {
|
| 214 |
+
backgroundColor: '#000',
|
| 215 |
+
// selection: false,
|
| 216 |
+
preserveObjectStacking: true
|
| 217 |
+
});
|
| 218 |
+
|
| 219 |
+
window.openpose_editor_elem = elem
|
| 220 |
+
|
| 221 |
+
function updateLines(target) {
|
| 222 |
+
if ("_objects" in target) {
|
| 223 |
+
const flipX = target.flipX ? -1 : 1;
|
| 224 |
+
const flipY = target.flipY ? -1 : 1;
|
| 225 |
+
flipped = flipX * flipY === -1;
|
| 226 |
+
const showEyes = flipped ? !visibleEyes : visibleEyes;
|
| 227 |
+
|
| 228 |
+
if (target.angle === 0) {
|
| 229 |
+
const rtop = target.top
|
| 230 |
+
const rleft = target.left
|
| 231 |
+
for (const item of target._objects){
|
| 232 |
+
let p = item;
|
| 233 |
+
p.scaleX = 1;
|
| 234 |
+
p.scaleY = 1;
|
| 235 |
+
const top = rtop + p.top * target.scaleY * flipY + target.height * target.scaleY / 2;
|
| 236 |
+
const left = rleft + p.left * target.scaleX * flipX + (target.width * target.scaleX / 2);
|
| 237 |
+
p['_top'] = top;
|
| 238 |
+
p['_left'] = left;
|
| 239 |
+
if (p["id"] === 0) {
|
| 240 |
+
p.line1 && p.line1.set({ 'x1': left, 'y1': top });
|
| 241 |
+
}else{
|
| 242 |
+
p.line1 && p.line1.set({ 'x2': left, 'y2': top });
|
| 243 |
+
}
|
| 244 |
+
if (p['id'] === 14 || p['id'] === 15) {
|
| 245 |
+
p.radius = showEyes ? 5 : 0;
|
| 246 |
+
if (p.line1) p.line1.strokeWidth = showEyes ? 10 : 0;
|
| 247 |
+
if (p.line2) p.line2.strokeWidth = showEyes ? 10 : 0;
|
| 248 |
+
}
|
| 249 |
+
p.line2 && p.line2.set({ 'x1': left, 'y1': top });
|
| 250 |
+
p.line3 && p.line3.set({ 'x1': left, 'y1': top });
|
| 251 |
+
p.line4 && p.line4.set({ 'x1': left, 'y1': top });
|
| 252 |
+
p.line5 && p.line5.set({ 'x1': left, 'y1': top });
|
| 253 |
+
|
| 254 |
+
}
|
| 255 |
+
} else {
|
| 256 |
+
const aCoords = target.aCoords;
|
| 257 |
+
const center = {'x': (aCoords.tl.x + aCoords.br.x)/2, 'y': (aCoords.tl.y + aCoords.br.y)/2};
|
| 258 |
+
const rad = target.angle * Math.PI / 180;
|
| 259 |
+
const sin = Math.sin(rad);
|
| 260 |
+
const cos = Math.cos(rad);
|
| 261 |
+
|
| 262 |
+
for (const item of target._objects){
|
| 263 |
+
let p = item;
|
| 264 |
+
const p_top = p.top * target.scaleY * flipY;
|
| 265 |
+
const p_left = p.left * target.scaleX * flipX;
|
| 266 |
+
const left = center.x + p_left * cos - p_top * sin;
|
| 267 |
+
const top = center.y + p_left * sin + p_top * cos;
|
| 268 |
+
p['_top'] = top;
|
| 269 |
+
p['_left'] = left;
|
| 270 |
+
if (p["id"] === 0) {
|
| 271 |
+
p.line1 && p.line1.set({ 'x1': left, 'y1': top });
|
| 272 |
+
}else{
|
| 273 |
+
p.line1 && p.line1.set({ 'x2': left, 'y2': top });
|
| 274 |
+
}
|
| 275 |
+
if (p['id'] === 14 || p['id'] === 15) {
|
| 276 |
+
p.radius = showEyes ? 5 : 0.3;
|
| 277 |
+
if (p.line1) p.line1.strokeWidth = showEyes ? 10 : 0;
|
| 278 |
+
if (p.line2) p.line2.strokeWidth = showEyes ? 10 : 0;
|
| 279 |
+
}
|
| 280 |
+
p.line2 && p.line2.set({ 'x1': left, 'y1': top });
|
| 281 |
+
p.line3 && p.line3.set({ 'x1': left, 'y1': top });
|
| 282 |
+
p.line4 && p.line4.set({ 'x1': left, 'y1': top });
|
| 283 |
+
p.line5 && p.line5.set({ 'x1': left, 'y1': top });
|
| 284 |
+
}
|
| 285 |
+
}
|
| 286 |
+
} else {
|
| 287 |
+
var p = target;
|
| 288 |
+
if (p["id"] === 0) {
|
| 289 |
+
p.line1 && p.line1.set({ 'x1': p.left, 'y1': p.top });
|
| 290 |
+
}else{
|
| 291 |
+
p.line1 && p.line1.set({ 'x2': p.left, 'y2': p.top });
|
| 292 |
+
}
|
| 293 |
+
p.line2 && p.line2.set({ 'x1': p.left, 'y1': p.top });
|
| 294 |
+
p.line3 && p.line3.set({ 'x1': p.left, 'y1': p.top });
|
| 295 |
+
p.line4 && p.line4.set({ 'x1': p.left, 'y1': p.top });
|
| 296 |
+
p.line5 && p.line5.set({ 'x1': p.left, 'y1': p.top });
|
| 297 |
+
}
|
| 298 |
+
canvas.renderAll();
|
| 299 |
+
}
|
| 300 |
+
|
| 301 |
+
canvas.on('object:moving', function(e) {
|
| 302 |
+
updateLines(e.target);
|
| 303 |
+
});
|
| 304 |
+
|
| 305 |
+
canvas.on('object:scaling', function(e) {
|
| 306 |
+
updateLines(e.target);
|
| 307 |
+
canvas.renderAll();
|
| 308 |
+
});
|
| 309 |
+
|
| 310 |
+
canvas.on('object:rotating', function(e) {
|
| 311 |
+
updateLines(e.target);
|
| 312 |
+
canvas.renderAll();
|
| 313 |
+
});
|
| 314 |
+
|
| 315 |
+
canvas.on("object:modified", function () {
|
| 316 |
+
if (lockMode) return;
|
| 317 |
+
undo_history.push(JSON.stringify(canvas));
|
| 318 |
+
redo_history.length = 0;
|
| 319 |
+
});
|
| 320 |
+
|
| 321 |
+
resizeCanvas(...openpose_obj.resolution)
|
| 322 |
+
|
| 323 |
+
setPose(default_keypoints)
|
| 324 |
+
|
| 325 |
+
undo_history.push(JSON.stringify(canvas));
|
| 326 |
+
|
| 327 |
+
const json_observer = new MutationObserver((m) => {
|
| 328 |
+
if(gradioApp().querySelector('#tab_openpose_editor').style.display!=='block') return;
|
| 329 |
+
try {
|
| 330 |
+
const raw = gradioApp().querySelector("#jsonbox").querySelector("textarea").value
|
| 331 |
+
if(raw.length!==0) detectImage(raw);
|
| 332 |
+
} catch(e){console.log(e)}
|
| 333 |
+
})
|
| 334 |
+
json_observer.observe(gradioApp().querySelector("#jsonbox"), { "attributes": true })
|
| 335 |
+
|
| 336 |
+
// document.addEventListener('keydown', function(e) {
|
| 337 |
+
// if (e.key !== undefined) {
|
| 338 |
+
// if((e.key == "z" && (e.metaKey || e.ctrlKey || e.altKey))) undo()
|
| 339 |
+
// if((e.key == "y" && (e.metaKey || e.ctrlKey || e.altKey))) redo()
|
| 340 |
+
// }
|
| 341 |
+
// })
|
| 342 |
+
}
|
| 343 |
+
|
| 344 |
+
function resetCanvas(){
|
| 345 |
+
const canvas = openpose_editor_canvas;
|
| 346 |
+
canvas.clear()
|
| 347 |
+
canvas.backgroundColor = "#000"
|
| 348 |
+
}
|
| 349 |
+
|
| 350 |
+
function savePNG(){
|
| 351 |
+
openpose_editor_canvas.getObjects("image").forEach((img) => {
|
| 352 |
+
img.set({
|
| 353 |
+
opacity: 0
|
| 354 |
+
});
|
| 355 |
+
})
|
| 356 |
+
if (openpose_editor_canvas.backgroundImage) openpose_editor_canvas.backgroundImage.opacity = 0
|
| 357 |
+
openpose_editor_canvas.discardActiveObject();
|
| 358 |
+
openpose_editor_canvas.renderAll()
|
| 359 |
+
openpose_editor_elem.toBlob((blob) => {
|
| 360 |
+
const a = document.createElement("a");
|
| 361 |
+
a.href = URL.createObjectURL(blob);
|
| 362 |
+
a.download = "pose.png";
|
| 363 |
+
a.click();
|
| 364 |
+
URL.revokeObjectURL(a.href);
|
| 365 |
+
});
|
| 366 |
+
openpose_editor_canvas.getObjects("image").forEach((img) => {
|
| 367 |
+
img.set({
|
| 368 |
+
opacity: 1
|
| 369 |
+
});
|
| 370 |
+
})
|
| 371 |
+
if (openpose_editor_canvas.backgroundImage) openpose_editor_canvas.backgroundImage.opacity = 0.5
|
| 372 |
+
openpose_editor_canvas.renderAll()
|
| 373 |
+
return openpose_editor_canvas
|
| 374 |
+
}
|
| 375 |
+
|
| 376 |
+
function serializeJSON(){
|
| 377 |
+
const json = JSON.stringify({
|
| 378 |
+
"width": openpose_editor_canvas.width,
|
| 379 |
+
"height": openpose_editor_canvas.height,
|
| 380 |
+
"keypoints": openpose_editor_canvas.getObjects().filter((item) => {
|
| 381 |
+
if (item.type === "circle") return item
|
| 382 |
+
}).map((item) => {
|
| 383 |
+
return [Math.round(item.left), Math.round(item.top)]
|
| 384 |
+
})
|
| 385 |
+
}, null, 4)
|
| 386 |
+
return json;
|
| 387 |
+
}
|
| 388 |
+
|
| 389 |
+
function saveJSON(){
|
| 390 |
+
const json = serializeJSON()
|
| 391 |
+
const blob = new Blob([json], {
|
| 392 |
+
type: "application/json"
|
| 393 |
+
});
|
| 394 |
+
const filename = "pose-" + Date.now().toString() + ".json"
|
| 395 |
+
const a = document.createElement("a");
|
| 396 |
+
a.href = URL.createObjectURL(blob);
|
| 397 |
+
a.download = filename;
|
| 398 |
+
a.click();
|
| 399 |
+
URL.revokeObjectURL(a.href);
|
| 400 |
+
}
|
| 401 |
+
|
| 402 |
+
async function loadJSON(file){
|
| 403 |
+
const url = await fileToDataUrl(file)
|
| 404 |
+
const response = await fetch(url)
|
| 405 |
+
const json = await response.json()
|
| 406 |
+
if (json["width"] && json["height"]) {
|
| 407 |
+
resizeCanvas(json["width"], json["height"])
|
| 408 |
+
}else{
|
| 409 |
+
throw new Error('width, height is invalid');
|
| 410 |
+
}
|
| 411 |
+
if (json["keypoints"].length % 18 === 0) {
|
| 412 |
+
setPose(json["keypoints"])
|
| 413 |
+
}else{
|
| 414 |
+
throw new Error('keypoints is invalid')
|
| 415 |
+
}
|
| 416 |
+
return [json["width"], json["height"]]
|
| 417 |
+
}
|
| 418 |
+
|
| 419 |
+
function savePreset(){
|
| 420 |
+
var name = prompt("Preset Name")
|
| 421 |
+
const json = serializeJSON()
|
| 422 |
+
return [name, json]
|
| 423 |
+
}
|
| 424 |
+
|
| 425 |
+
function loadPreset(json){
|
| 426 |
+
try {
|
| 427 |
+
json = JSON.parse(json)
|
| 428 |
+
if (json["width"] && json["height"]) {
|
| 429 |
+
resizeCanvas(json["width"], json["height"])
|
| 430 |
+
}else{
|
| 431 |
+
throw new Error('width, height is invalid');
|
| 432 |
+
}
|
| 433 |
+
if (json["keypoints"].length % 18 === 0) {
|
| 434 |
+
setPose(json["keypoints"])
|
| 435 |
+
}else{
|
| 436 |
+
throw new Error('keypoints is invalid')
|
| 437 |
+
}
|
| 438 |
+
return [json["width"], json["height"]]
|
| 439 |
+
}catch(e){
|
| 440 |
+
console.error(e)
|
| 441 |
+
alert("Invalid JSON")
|
| 442 |
+
}
|
| 443 |
+
}
|
| 444 |
+
|
| 445 |
+
async function addBackground(file){
|
| 446 |
+
const url = await fileToDataUrl(file)
|
| 447 |
+
openpose_editor_canvas.setBackgroundImage(url, openpose_editor_canvas.renderAll.bind(openpose_editor_canvas), {
|
| 448 |
+
opacity: 0.5
|
| 449 |
+
});
|
| 450 |
+
const img = new Image();
|
| 451 |
+
await (img.src = url);
|
| 452 |
+
resizeCanvas(img.width, img.height)
|
| 453 |
+
return [img.width, img.height]
|
| 454 |
+
}
|
| 455 |
+
|
| 456 |
+
function detectImage(raw){
|
| 457 |
+
const json = JSON.parse(raw)
|
| 458 |
+
|
| 459 |
+
let candidate = json["candidate"]
|
| 460 |
+
let subset = json["subset"]
|
| 461 |
+
const li = []
|
| 462 |
+
subset = subset.splice(0, 18)
|
| 463 |
+
for (i=0; subset.length > i; i++){
|
| 464 |
+
if (Number.isInteger(subset[i]) && subset[i] >= 0){
|
| 465 |
+
li.push(candidate[subset[i]])
|
| 466 |
+
}else{
|
| 467 |
+
li.push([-1,-1])
|
| 468 |
+
}
|
| 469 |
+
}
|
| 470 |
+
if(li.length === 0){
|
| 471 |
+
const bgimage = openpose_editor_canvas.backgroundImage
|
| 472 |
+
setPose(li);
|
| 473 |
+
openpose_editor_canvas.backgroundImage = bgimage
|
| 474 |
+
return;
|
| 475 |
+
}
|
| 476 |
+
if(li.every(([x,y])=>x===-1&&y===-1)){
|
| 477 |
+
const ra_width = Math.floor(Math.random() * openpose_editor_canvas.width)
|
| 478 |
+
const ra_height = Math.floor(Math.random() * openpose_editor_canvas.height)
|
| 479 |
+
li[0] = [ra_width, ra_height]
|
| 480 |
+
}
|
| 481 |
+
default_relative_keypoints = []
|
| 482 |
+
for(i=0;i<default_keypoints.length;i++){
|
| 483 |
+
default_relative_keypoints.push([])
|
| 484 |
+
for(j=0;j<default_keypoints.length;j++){
|
| 485 |
+
x = default_keypoints[j][0] - default_keypoints[i][0];
|
| 486 |
+
y = default_keypoints[j][1] - default_keypoints[i][1];
|
| 487 |
+
default_relative_keypoints[i].push([x,y])
|
| 488 |
+
}
|
| 489 |
+
}
|
| 490 |
+
kp_connect = (i,j)=>{
|
| 491 |
+
for(idx=0;idx<connect_keypoints.length;idx++){
|
| 492 |
+
cp = connect_keypoints[idx];
|
| 493 |
+
if(((cp[0]===i)&&(cp[1]===j)) || ((cp[0]===j)&&(cp[1]===i))){
|
| 494 |
+
return true;
|
| 495 |
+
}
|
| 496 |
+
}
|
| 497 |
+
return false;
|
| 498 |
+
}
|
| 499 |
+
// bfs propagate
|
| 500 |
+
while(li.some(([x,y])=>x===-1&&y===-1)){
|
| 501 |
+
for(i=0;i<li.length;i++){
|
| 502 |
+
if(li[i][0]===-1){
|
| 503 |
+
continue;
|
| 504 |
+
}
|
| 505 |
+
for(j=0;j<li.length;j++){
|
| 506 |
+
if(li[j][0]===-1 && kp_connect(i,j)){
|
| 507 |
+
x = li[i][0] + default_relative_keypoints[i][j][0]
|
| 508 |
+
y = li[i][1] + default_relative_keypoints[i][j][1]
|
| 509 |
+
x = Math.min(Math.max(x, 0), openpose_editor_canvas.width);
|
| 510 |
+
y = Math.min(Math.max(y, 0), openpose_editor_canvas.height);
|
| 511 |
+
li[j] = [x,y];
|
| 512 |
+
}
|
| 513 |
+
}
|
| 514 |
+
}
|
| 515 |
+
}
|
| 516 |
+
const bgimage = openpose_editor_canvas.backgroundImage
|
| 517 |
+
setPose(li);
|
| 518 |
+
openpose_editor_canvas.backgroundImage = bgimage
|
| 519 |
+
}
|
| 520 |
+
|
| 521 |
+
function sendImage(type, index){
|
| 522 |
+
openpose_editor_canvas.getObjects("image").forEach((img) => {
|
| 523 |
+
img.set({
|
| 524 |
+
opacity: 0
|
| 525 |
+
});
|
| 526 |
+
})
|
| 527 |
+
if (openpose_editor_canvas.backgroundImage) openpose_editor_canvas.backgroundImage.opacity = 0
|
| 528 |
+
openpose_editor_canvas.discardActiveObject();
|
| 529 |
+
openpose_editor_canvas.renderAll()
|
| 530 |
+
openpose_editor_elem.toBlob((blob) => {
|
| 531 |
+
const file = new File(([blob]), "pose.png")
|
| 532 |
+
const dt = new DataTransfer();
|
| 533 |
+
dt.items.add(file);
|
| 534 |
+
const list = dt.files
|
| 535 |
+
const selector = type === "txt2img" ? "#txt2img_script_container" : "#img2img_script_container"
|
| 536 |
+
if (type === "txt2img"){
|
| 537 |
+
switch_to_txt2img()
|
| 538 |
+
}else if(type === "img2img"){
|
| 539 |
+
switch_to_img2img()
|
| 540 |
+
}
|
| 541 |
+
|
| 542 |
+
const isNew = window.gradio_config.version.replace("\n", "") >= "3.23.0"
|
| 543 |
+
const accordion_selector = isNew ? "#controlnet > .label-wrap > .icon" : "#controlnet .transition"
|
| 544 |
+
const accordion = gradioApp().querySelector(selector).querySelector(accordion_selector)
|
| 545 |
+
if (isNew ? accordion.style.transform == "rotate(90deg)" : accordion.classList.contains("rotate-90")) {
|
| 546 |
+
accordion.click()
|
| 547 |
+
}
|
| 548 |
+
|
| 549 |
+
let input = gradioApp().querySelector(selector).querySelector("#controlnet").querySelector("input[type='file']");
|
| 550 |
+
|
| 551 |
+
const input_image = (input) =>{
|
| 552 |
+
try {
|
| 553 |
+
if(input.previousElementSibling
|
| 554 |
+
&& input.previousElementSibling.previousElementSibling
|
| 555 |
+
&& input.previousElementSibling.previousElementSibling.querySelector("button[aria-label='Clear']")) {
|
| 556 |
+
input.previousElementSibling.previousElementSibling.querySelector("button[aria-label='Clear']").click()
|
| 557 |
+
}
|
| 558 |
+
} catch (e) {
|
| 559 |
+
console.error(e)
|
| 560 |
+
}
|
| 561 |
+
input.value = "";
|
| 562 |
+
input.files = list;
|
| 563 |
+
const event = new Event('change', { 'bubbles': true, "composed": true });
|
| 564 |
+
input.dispatchEvent(event);
|
| 565 |
+
}
|
| 566 |
+
|
| 567 |
+
if (input == null){
|
| 568 |
+
const callback = (observer) => {
|
| 569 |
+
input = gradioApp().querySelector(selector).querySelector("#controlnet").querySelector("input[type='file']");
|
| 570 |
+
if (input == null) {
|
| 571 |
+
console.error('input[type=file] NOT exists')
|
| 572 |
+
return
|
| 573 |
+
}else{
|
| 574 |
+
input_image(input)
|
| 575 |
+
observer.disconnect()
|
| 576 |
+
}
|
| 577 |
+
}
|
| 578 |
+
const observer = new MutationObserver(callback);
|
| 579 |
+
observer.observe(gradioApp().querySelector(selector).querySelector("#controlnet"), { childList: true });
|
| 580 |
+
}else{
|
| 581 |
+
input_image(input)
|
| 582 |
+
}
|
| 583 |
+
});
|
| 584 |
+
openpose_editor_canvas.getObjects("image").forEach((img) => {
|
| 585 |
+
img.set({
|
| 586 |
+
opacity: 1
|
| 587 |
+
});
|
| 588 |
+
})
|
| 589 |
+
if (openpose_editor_canvas.backgroundImage) openpose_editor_canvas.backgroundImage.opacity = 0.5
|
| 590 |
+
openpose_editor_canvas.renderAll()
|
| 591 |
+
}
|
| 592 |
+
|
| 593 |
+
function canvas_onDragOver(event) {
|
| 594 |
+
canvas_drag_overlay = gradioApp().querySelector("#canvas_drag_overlay");
|
| 595 |
+
|
| 596 |
+
if (event.dataTransfer.items[0].type.startsWith("image/")) {
|
| 597 |
+
event.preventDefault();
|
| 598 |
+
canvas_drag_overlay.textContent = "Add Background";
|
| 599 |
+
canvas_drag_overlay.style.visibility = "visible";
|
| 600 |
+
} else if (event.dataTransfer.items[0].type == "application/json") {
|
| 601 |
+
event.preventDefault();
|
| 602 |
+
canvas_drag_overlay.textContent = "Load JSON";
|
| 603 |
+
canvas_drag_overlay.style.visibility = "visible";
|
| 604 |
+
}
|
| 605 |
+
}
|
| 606 |
+
|
| 607 |
+
function canvas_onDrop(event) {
|
| 608 |
+
canvas_drag_overlay = gradioApp().querySelector("#canvas_drag_overlay");
|
| 609 |
+
|
| 610 |
+
if (event.dataTransfer.items[0].type.startsWith("image/")) {
|
| 611 |
+
event.preventDefault();
|
| 612 |
+
input = gradioApp().querySelector("#openpose_bg_button").previousElementSibling;
|
| 613 |
+
input.files = event.dataTransfer.files;
|
| 614 |
+
const changeEvent = new Event('change', { 'bubbles': true, "composed": true });
|
| 615 |
+
input.dispatchEvent(changeEvent);
|
| 616 |
+
canvas_drag_overlay.style.visibility = "hidden";
|
| 617 |
+
} else if (event.dataTransfer.items[0].type == "application/json") {
|
| 618 |
+
event.preventDefault();
|
| 619 |
+
input = gradioApp().querySelector("#openpose_json_button").previousElementSibling;
|
| 620 |
+
input.files = event.dataTransfer.files;
|
| 621 |
+
const changeEvent = new Event('change', { 'bubbles': true, "composed": true });
|
| 622 |
+
input.dispatchEvent(changeEvent);
|
| 623 |
+
canvas_drag_overlay.style.visibility = "hidden";
|
| 624 |
+
}
|
| 625 |
+
}
|
| 626 |
+
|
| 627 |
+
function button_onDragOver(event) {
|
| 628 |
+
if (((event.target.id == "openpose_detect_button" || event.target.id == "openpose_bg_button") && event.dataTransfer.items[0].type.startsWith("image/")) ||
|
| 629 |
+
(event.target.id == "openpose_json_button" && event.dataTransfer.items[0].type == "application/json")) {
|
| 630 |
+
event.preventDefault();
|
| 631 |
+
event.target.classList.remove("gr-button-secondary");
|
| 632 |
+
}
|
| 633 |
+
}
|
| 634 |
+
|
| 635 |
+
function button_onDragLeave(event) {
|
| 636 |
+
event.target.classList.add("gr-button-secondary");
|
| 637 |
+
}
|
| 638 |
+
|
| 639 |
+
function detect_onDrop(event) {
|
| 640 |
+
if (event.dataTransfer.items[0].type.startsWith("image/")) {
|
| 641 |
+
event.preventDefault();
|
| 642 |
+
input = event.target.previousElementSibling;
|
| 643 |
+
input.files = event.dataTransfer.files;
|
| 644 |
+
const changeEvent = new Event('change', { 'bubbles': true, "composed": true });
|
| 645 |
+
input.dispatchEvent(changeEvent);
|
| 646 |
+
}
|
| 647 |
+
}
|
| 648 |
+
|
| 649 |
+
function json_onDrop(event) {
|
| 650 |
+
if (event.dataTransfer.items[0].type == "application/json") {
|
| 651 |
+
event.preventDefault();
|
| 652 |
+
input = event.target.previousElementSibling;
|
| 653 |
+
input.files = event.dataTransfer.files;
|
| 654 |
+
const changeEvent = new Event('change', { 'bubbles': true, "composed": true });
|
| 655 |
+
input.dispatchEvent(changeEvent);
|
| 656 |
+
}
|
| 657 |
+
}
|
| 658 |
+
|
| 659 |
+
onUiLoaded(function() {
|
| 660 |
+
initCanvas(gradioApp().querySelector('#openpose_editor_canvas'))
|
| 661 |
+
|
| 662 |
+
var canvas_drag_overlay = document.createElement("div");
|
| 663 |
+
canvas_drag_overlay.id = "canvas_drag_overlay"
|
| 664 |
+
canvas_drag_overlay.style = "pointer-events: none; visibility: hidden; display: flex; align-items: center; justify-content: center; width: 100%; height: 100%; color: white; font-size: 2.5em; font-family: inherit; font-weight: 600; line-height: 100%; background: rgba(0,0,0,0.5); margin: 0.25rem; border-radius: 0.25rem; border: 0.5px solid; position: absolute;"
|
| 665 |
+
|
| 666 |
+
var canvas = gradioApp().querySelector("#tab_openpose_editor .canvas-container")
|
| 667 |
+
canvas.appendChild(canvas_drag_overlay)
|
| 668 |
+
canvas.addEventListener("dragover", canvas_onDragOver);
|
| 669 |
+
canvas.addEventListener("dragleave", () => gradioApp().querySelector("#canvas_drag_overlay").style.visibility = "hidden");
|
| 670 |
+
canvas.addEventListener("drop", canvas_onDrop);
|
| 671 |
+
|
| 672 |
+
var bg_button = gradioApp().querySelector("#openpose_bg_button")
|
| 673 |
+
bg_button.addEventListener("dragover", button_onDragOver);
|
| 674 |
+
bg_button.addEventListener("dragleave", button_onDragLeave);
|
| 675 |
+
bg_button.addEventListener("drop", canvas_onDrop);
|
| 676 |
+
bg_button.addEventListener("drop", event => event.target.classList.add("gr-button-secondary"));
|
| 677 |
+
bg_button.classList.add("gr-button-secondary");
|
| 678 |
+
|
| 679 |
+
var detect_button = gradioApp().querySelector("#openpose_detect_button")
|
| 680 |
+
detect_button.addEventListener("dragover", button_onDragOver);
|
| 681 |
+
detect_button.addEventListener("dragleave", button_onDragLeave);
|
| 682 |
+
detect_button.addEventListener("drop", detect_onDrop);
|
| 683 |
+
detect_button.classList.add("gr-button-secondary");
|
| 684 |
+
|
| 685 |
+
var json_button = gradioApp().querySelector("#openpose_json_button")
|
| 686 |
+
json_button.addEventListener("dragover", button_onDragOver);
|
| 687 |
+
json_button.addEventListener("dragleave", button_onDragLeave);
|
| 688 |
+
json_button.addEventListener("drop", json_onDrop);
|
| 689 |
+
json_button.classList.add("gr-button-secondary");
|
| 690 |
+
|
| 691 |
+
})
|
openpose-editor/scripts/__pycache__/main.cpython-310.pyc
ADDED
|
Binary file (5.83 kB). View file
|
|
|
openpose-editor/scripts/main.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import io
|
| 3 |
+
import json
|
| 4 |
+
import numpy as np
|
| 5 |
+
import cv2
|
| 6 |
+
|
| 7 |
+
import gradio as gr
|
| 8 |
+
|
| 9 |
+
import modules.scripts as scripts
|
| 10 |
+
from modules import script_callbacks
|
| 11 |
+
from modules.shared import opts
|
| 12 |
+
from modules.paths import models_path
|
| 13 |
+
|
| 14 |
+
from basicsr.utils.download_util import load_file_from_url
|
| 15 |
+
|
| 16 |
+
from scripts.openpose.body import Body
|
| 17 |
+
|
| 18 |
+
from PIL import Image
|
| 19 |
+
|
| 20 |
+
body_estimation = None
|
| 21 |
+
presets_file = os.path.join(scripts.basedir(), "presets.json")
|
| 22 |
+
presets = {}
|
| 23 |
+
|
| 24 |
+
try:
|
| 25 |
+
with open(presets_file) as file:
|
| 26 |
+
presets = json.load(file)
|
| 27 |
+
except FileNotFoundError:
|
| 28 |
+
pass
|
| 29 |
+
|
| 30 |
+
def pil2cv(in_image):
|
| 31 |
+
out_image = np.array(in_image, dtype=np.uint8)
|
| 32 |
+
|
| 33 |
+
if out_image.shape[2] == 3:
|
| 34 |
+
out_image = cv2.cvtColor(out_image, cv2.COLOR_RGB2BGR)
|
| 35 |
+
return out_image
|
| 36 |
+
|
| 37 |
+
def candidate2li(li):
|
| 38 |
+
res = []
|
| 39 |
+
for x, y, *_ in li:
|
| 40 |
+
res.append([x, y])
|
| 41 |
+
return res
|
| 42 |
+
|
| 43 |
+
def subset2li(li):
|
| 44 |
+
res = []
|
| 45 |
+
for r in li:
|
| 46 |
+
for c in r:
|
| 47 |
+
res.append(c)
|
| 48 |
+
return res
|
| 49 |
+
|
| 50 |
+
class Script(scripts.Script):
|
| 51 |
+
def __init__(self) -> None:
|
| 52 |
+
super().__init__()
|
| 53 |
+
|
| 54 |
+
def title(self):
|
| 55 |
+
return "OpenPose Editor"
|
| 56 |
+
|
| 57 |
+
def show(self, is_img2img):
|
| 58 |
+
return scripts.AlwaysVisible
|
| 59 |
+
|
| 60 |
+
def ui(self, is_img2img):
|
| 61 |
+
return ()
|
| 62 |
+
|
| 63 |
+
def on_ui_tabs():
|
| 64 |
+
with gr.Blocks(analytics_enabled=False) as openpose_editor:
|
| 65 |
+
with gr.Row():
|
| 66 |
+
with gr.Column():
|
| 67 |
+
width = gr.Slider(label="width", minimum=64, maximum=2048, value=512, step=64, interactive=True)
|
| 68 |
+
height = gr.Slider(label="height", minimum=64, maximum=2048, value=512, step=64, interactive=True)
|
| 69 |
+
with gr.Row():
|
| 70 |
+
add = gr.Button(value="Add", variant="primary")
|
| 71 |
+
# delete = gr.Button(value="Delete")
|
| 72 |
+
with gr.Row():
|
| 73 |
+
reset_btn = gr.Button(value="Reset")
|
| 74 |
+
json_input = gr.UploadButton(label="Load from JSON", file_types=[".json"], elem_id="openpose_json_button")
|
| 75 |
+
png_input = gr.UploadButton(label="Detect from Image", file_types=["image"], type="bytes", elem_id="openpose_detect_button")
|
| 76 |
+
bg_input = gr.UploadButton(label="Add Background Image", file_types=["image"], elem_id="openpose_bg_button")
|
| 77 |
+
with gr.Row():
|
| 78 |
+
preset_list = gr.Dropdown(label="Presets", choices=sorted(presets.keys()), interactive=True)
|
| 79 |
+
preset_load = gr.Button(value="Load Preset")
|
| 80 |
+
preset_save = gr.Button(value="Save Preset")
|
| 81 |
+
|
| 82 |
+
with gr.Column():
|
| 83 |
+
# gradioooooo...
|
| 84 |
+
canvas = gr.HTML('<canvas id="openpose_editor_canvas" width="512" height="512" style="margin: 0.25rem; border-radius: 0.25rem; border: 0.5px solid"></canvas>')
|
| 85 |
+
jsonbox = gr.Text(label="json", elem_id="jsonbox", visible=False)
|
| 86 |
+
with gr.Row():
|
| 87 |
+
json_output = gr.Button(value="Save JSON")
|
| 88 |
+
png_output = gr.Button(value="Save PNG")
|
| 89 |
+
send_t2t = gr.Button(value="Send to txt2img")
|
| 90 |
+
send_i2i = gr.Button(value="Send to img2img")
|
| 91 |
+
control_net_max_models_num = getattr(opts, 'control_net_max_models_num', 0)
|
| 92 |
+
select_target_index = gr.Dropdown([str(i) for i in range(control_net_max_models_num)], label="Send to", value="0", interactive=True, visible=(control_net_max_models_num > 1))
|
| 93 |
+
|
| 94 |
+
def estimate(file):
|
| 95 |
+
global body_estimation
|
| 96 |
+
|
| 97 |
+
if body_estimation is None:
|
| 98 |
+
model_path = os.path.join(models_path, "openpose", "body_pose_model.pth")
|
| 99 |
+
if not os.path.isfile(model_path):
|
| 100 |
+
body_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/body_pose_model.pth"
|
| 101 |
+
load_file_from_url(body_model_path, model_dir=os.path.join(models_path, "openpose"))
|
| 102 |
+
body_estimation = Body(model_path)
|
| 103 |
+
|
| 104 |
+
stream = io.BytesIO(file)
|
| 105 |
+
img = Image.open(stream)
|
| 106 |
+
candidate, subset = body_estimation(pil2cv(img))
|
| 107 |
+
|
| 108 |
+
result = {
|
| 109 |
+
"candidate": candidate2li(candidate),
|
| 110 |
+
"subset": subset2li(subset),
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
return str(result).replace("'", '"')
|
| 114 |
+
|
| 115 |
+
def savePreset(name, data):
|
| 116 |
+
if name:
|
| 117 |
+
presets[name] = json.loads(data)
|
| 118 |
+
with open(presets_file, "w") as file:
|
| 119 |
+
json.dump(presets, file)
|
| 120 |
+
return gr.update(choices=sorted(presets.keys()), value=name), json.dumps(data)
|
| 121 |
+
return gr.update(), gr.update()
|
| 122 |
+
|
| 123 |
+
dummy_component = gr.Label(visible=False)
|
| 124 |
+
preset = gr.Text(visible=False)
|
| 125 |
+
width.change(None, [width, height], None, _js="(w, h) => {resizeCanvas(w, h)}")
|
| 126 |
+
height.change(None, [width, height], None, _js="(w, h) => {resizeCanvas(w, h)}")
|
| 127 |
+
png_output.click(None, [], None, _js="savePNG")
|
| 128 |
+
bg_input.upload(None, [bg_input], [width, height], _js="addBackground")
|
| 129 |
+
png_input.upload(estimate, png_input, jsonbox)
|
| 130 |
+
png_input.upload(None, png_input, [width, height], _js="addBackground")
|
| 131 |
+
add.click(None, [], None, _js="addPose")
|
| 132 |
+
send_t2t.click(None, select_target_index, None, _js="(i) => {sendImage('txt2img', i)}")
|
| 133 |
+
send_i2i.click(None, select_target_index, None, _js="(i) => {sendImage('img2img', i)}")
|
| 134 |
+
reset_btn.click(None, [], None, _js="resetCanvas")
|
| 135 |
+
json_input.upload(None, json_input, [width, height], _js="loadJSON")
|
| 136 |
+
json_output.click(None, None, None, _js="saveJSON")
|
| 137 |
+
preset_save.click(savePreset, [dummy_component, dummy_component], [preset_list, preset], _js="savePreset")
|
| 138 |
+
preset_load.click(None, preset, [width, height], _js="loadPreset")
|
| 139 |
+
preset_list.change(lambda selected: json.dumps(presets[selected]), preset_list, preset)
|
| 140 |
+
|
| 141 |
+
return [(openpose_editor, "OpenPose Editor", "openpose_editor")]
|
| 142 |
+
|
| 143 |
+
script_callbacks.on_ui_tabs(on_ui_tabs)
|
openpose-editor/scripts/openpose/__pycache__/body.cpython-310.pyc
ADDED
|
Binary file (7.35 kB). View file
|
|
|
openpose-editor/scripts/openpose/__pycache__/model.cpython-310.pyc
ADDED
|
Binary file (6.21 kB). View file
|
|
|
openpose-editor/scripts/openpose/__pycache__/util.cpython-310.pyc
ADDED
|
Binary file (5.08 kB). View file
|
|
|
openpose-editor/scripts/openpose/body.py
ADDED
|
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This code from https://github.com/lllyasviel/ControlNet
|
| 2 |
+
|
| 3 |
+
import cv2
|
| 4 |
+
import numpy as np
|
| 5 |
+
import math
|
| 6 |
+
import time
|
| 7 |
+
from scipy.ndimage.filters import gaussian_filter
|
| 8 |
+
import matplotlib.pyplot as plt
|
| 9 |
+
import matplotlib
|
| 10 |
+
import torch
|
| 11 |
+
from torchvision import transforms
|
| 12 |
+
|
| 13 |
+
from . import util
|
| 14 |
+
from .model import bodypose_model
|
| 15 |
+
|
| 16 |
+
class Body(object):
|
| 17 |
+
def __init__(self, model_path):
|
| 18 |
+
self.model = bodypose_model()
|
| 19 |
+
if torch.cuda.is_available():
|
| 20 |
+
self.model = self.model.cuda()
|
| 21 |
+
model_dict = util.transfer(self.model, torch.load(model_path))
|
| 22 |
+
self.model.load_state_dict(model_dict)
|
| 23 |
+
self.model.eval()
|
| 24 |
+
|
| 25 |
+
def __call__(self, oriImg):
|
| 26 |
+
# scale_search = [0.5, 1.0, 1.5, 2.0]
|
| 27 |
+
scale_search = [0.5]
|
| 28 |
+
boxsize = 368
|
| 29 |
+
stride = 8
|
| 30 |
+
padValue = 128
|
| 31 |
+
threshold1 = 0.1
|
| 32 |
+
threshold2 = 0.05
|
| 33 |
+
multiplier = [x * boxsize / oriImg.shape[0] for x in scale_search]
|
| 34 |
+
heatmap_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 19))
|
| 35 |
+
paf_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 38))
|
| 36 |
+
|
| 37 |
+
for m in range(len(multiplier)):
|
| 38 |
+
scale = multiplier[m]
|
| 39 |
+
imageToTest = cv2.resize(oriImg, (0, 0), fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)
|
| 40 |
+
imageToTest_padded, pad = util.padRightDownCorner(imageToTest, stride, padValue)
|
| 41 |
+
im = np.transpose(np.float32(imageToTest_padded[:, :, :, np.newaxis]), (3, 2, 0, 1)) / 256 - 0.5
|
| 42 |
+
im = np.ascontiguousarray(im)
|
| 43 |
+
|
| 44 |
+
data = torch.from_numpy(im).float()
|
| 45 |
+
if torch.cuda.is_available():
|
| 46 |
+
data = data.cuda()
|
| 47 |
+
# data = data.permute([2, 0, 1]).unsqueeze(0).float()
|
| 48 |
+
with torch.no_grad():
|
| 49 |
+
Mconv7_stage6_L1, Mconv7_stage6_L2 = self.model(data)
|
| 50 |
+
Mconv7_stage6_L1 = Mconv7_stage6_L1.cpu().numpy()
|
| 51 |
+
Mconv7_stage6_L2 = Mconv7_stage6_L2.cpu().numpy()
|
| 52 |
+
|
| 53 |
+
# extract outputs, resize, and remove padding
|
| 54 |
+
# heatmap = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[1]].data), (1, 2, 0)) # output 1 is heatmaps
|
| 55 |
+
heatmap = np.transpose(np.squeeze(Mconv7_stage6_L2), (1, 2, 0)) # output 1 is heatmaps
|
| 56 |
+
heatmap = cv2.resize(heatmap, (0, 0), fx=stride, fy=stride, interpolation=cv2.INTER_CUBIC)
|
| 57 |
+
heatmap = heatmap[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :]
|
| 58 |
+
heatmap = cv2.resize(heatmap, (oriImg.shape[1], oriImg.shape[0]), interpolation=cv2.INTER_CUBIC)
|
| 59 |
+
|
| 60 |
+
# paf = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[0]].data), (1, 2, 0)) # output 0 is PAFs
|
| 61 |
+
paf = np.transpose(np.squeeze(Mconv7_stage6_L1), (1, 2, 0)) # output 0 is PAFs
|
| 62 |
+
paf = cv2.resize(paf, (0, 0), fx=stride, fy=stride, interpolation=cv2.INTER_CUBIC)
|
| 63 |
+
paf = paf[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :]
|
| 64 |
+
paf = cv2.resize(paf, (oriImg.shape[1], oriImg.shape[0]), interpolation=cv2.INTER_CUBIC)
|
| 65 |
+
|
| 66 |
+
heatmap_avg += heatmap_avg + heatmap / len(multiplier)
|
| 67 |
+
paf_avg += + paf / len(multiplier)
|
| 68 |
+
|
| 69 |
+
all_peaks = []
|
| 70 |
+
peak_counter = 0
|
| 71 |
+
|
| 72 |
+
for part in range(18):
|
| 73 |
+
map_ori = heatmap_avg[:, :, part]
|
| 74 |
+
one_heatmap = gaussian_filter(map_ori, sigma=3)
|
| 75 |
+
|
| 76 |
+
map_left = np.zeros(one_heatmap.shape)
|
| 77 |
+
map_left[1:, :] = one_heatmap[:-1, :]
|
| 78 |
+
map_right = np.zeros(one_heatmap.shape)
|
| 79 |
+
map_right[:-1, :] = one_heatmap[1:, :]
|
| 80 |
+
map_up = np.zeros(one_heatmap.shape)
|
| 81 |
+
map_up[:, 1:] = one_heatmap[:, :-1]
|
| 82 |
+
map_down = np.zeros(one_heatmap.shape)
|
| 83 |
+
map_down[:, :-1] = one_heatmap[:, 1:]
|
| 84 |
+
|
| 85 |
+
peaks_binary = np.logical_and.reduce(
|
| 86 |
+
(one_heatmap >= map_left, one_heatmap >= map_right, one_heatmap >= map_up, one_heatmap >= map_down, one_heatmap > threshold1))
|
| 87 |
+
peaks = list(zip(np.nonzero(peaks_binary)[1], np.nonzero(peaks_binary)[0])) # note reverse
|
| 88 |
+
peaks_with_score = [x + (map_ori[x[1], x[0]],) for x in peaks]
|
| 89 |
+
peak_id = range(peak_counter, peak_counter + len(peaks))
|
| 90 |
+
peaks_with_score_and_id = [peaks_with_score[i] + (peak_id[i],) for i in range(len(peak_id))]
|
| 91 |
+
|
| 92 |
+
all_peaks.append(peaks_with_score_and_id)
|
| 93 |
+
peak_counter += len(peaks)
|
| 94 |
+
|
| 95 |
+
# find connection in the specified sequence, center 29 is in the position 15
|
| 96 |
+
limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \
|
| 97 |
+
[10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \
|
| 98 |
+
[1, 16], [16, 18], [3, 17], [6, 18]]
|
| 99 |
+
# the middle joints heatmap correpondence
|
| 100 |
+
mapIdx = [[31, 32], [39, 40], [33, 34], [35, 36], [41, 42], [43, 44], [19, 20], [21, 22], \
|
| 101 |
+
[23, 24], [25, 26], [27, 28], [29, 30], [47, 48], [49, 50], [53, 54], [51, 52], \
|
| 102 |
+
[55, 56], [37, 38], [45, 46]]
|
| 103 |
+
|
| 104 |
+
connection_all = []
|
| 105 |
+
special_k = []
|
| 106 |
+
mid_num = 10
|
| 107 |
+
|
| 108 |
+
for k in range(len(mapIdx)):
|
| 109 |
+
score_mid = paf_avg[:, :, [x - 19 for x in mapIdx[k]]]
|
| 110 |
+
candA = all_peaks[limbSeq[k][0] - 1]
|
| 111 |
+
candB = all_peaks[limbSeq[k][1] - 1]
|
| 112 |
+
nA = len(candA)
|
| 113 |
+
nB = len(candB)
|
| 114 |
+
indexA, indexB = limbSeq[k]
|
| 115 |
+
if (nA != 0 and nB != 0):
|
| 116 |
+
connection_candidate = []
|
| 117 |
+
for i in range(nA):
|
| 118 |
+
for j in range(nB):
|
| 119 |
+
vec = np.subtract(candB[j][:2], candA[i][:2])
|
| 120 |
+
norm = math.sqrt(vec[0] * vec[0] + vec[1] * vec[1])
|
| 121 |
+
norm = max(0.001, norm)
|
| 122 |
+
vec = np.divide(vec, norm)
|
| 123 |
+
|
| 124 |
+
startend = list(zip(np.linspace(candA[i][0], candB[j][0], num=mid_num), \
|
| 125 |
+
np.linspace(candA[i][1], candB[j][1], num=mid_num)))
|
| 126 |
+
|
| 127 |
+
vec_x = np.array([score_mid[int(round(startend[I][1])), int(round(startend[I][0])), 0] \
|
| 128 |
+
for I in range(len(startend))])
|
| 129 |
+
vec_y = np.array([score_mid[int(round(startend[I][1])), int(round(startend[I][0])), 1] \
|
| 130 |
+
for I in range(len(startend))])
|
| 131 |
+
|
| 132 |
+
score_midpts = np.multiply(vec_x, vec[0]) + np.multiply(vec_y, vec[1])
|
| 133 |
+
score_with_dist_prior = sum(score_midpts) / len(score_midpts) + min(
|
| 134 |
+
0.5 * oriImg.shape[0] / norm - 1, 0)
|
| 135 |
+
criterion1 = len(np.nonzero(score_midpts > threshold2)[0]) > 0.8 * len(score_midpts)
|
| 136 |
+
criterion2 = score_with_dist_prior > 0
|
| 137 |
+
if criterion1 and criterion2:
|
| 138 |
+
connection_candidate.append(
|
| 139 |
+
[i, j, score_with_dist_prior, score_with_dist_prior + candA[i][2] + candB[j][2]])
|
| 140 |
+
|
| 141 |
+
connection_candidate = sorted(connection_candidate, key=lambda x: x[2], reverse=True)
|
| 142 |
+
connection = np.zeros((0, 5))
|
| 143 |
+
for c in range(len(connection_candidate)):
|
| 144 |
+
i, j, s = connection_candidate[c][0:3]
|
| 145 |
+
if (i not in connection[:, 3] and j not in connection[:, 4]):
|
| 146 |
+
connection = np.vstack([connection, [candA[i][3], candB[j][3], s, i, j]])
|
| 147 |
+
if (len(connection) >= min(nA, nB)):
|
| 148 |
+
break
|
| 149 |
+
|
| 150 |
+
connection_all.append(connection)
|
| 151 |
+
else:
|
| 152 |
+
special_k.append(k)
|
| 153 |
+
connection_all.append([])
|
| 154 |
+
|
| 155 |
+
# last number in each row is the total parts number of that person
|
| 156 |
+
# the second last number in each row is the score of the overall configuration
|
| 157 |
+
subset = -1 * np.ones((0, 20))
|
| 158 |
+
candidate = np.array([item for sublist in all_peaks for item in sublist])
|
| 159 |
+
|
| 160 |
+
for k in range(len(mapIdx)):
|
| 161 |
+
if k not in special_k:
|
| 162 |
+
partAs = connection_all[k][:, 0]
|
| 163 |
+
partBs = connection_all[k][:, 1]
|
| 164 |
+
indexA, indexB = np.array(limbSeq[k]) - 1
|
| 165 |
+
|
| 166 |
+
for i in range(len(connection_all[k])): # = 1:size(temp,1)
|
| 167 |
+
found = 0
|
| 168 |
+
subset_idx = [-1, -1]
|
| 169 |
+
for j in range(len(subset)): # 1:size(subset,1):
|
| 170 |
+
if subset[j][indexA] == partAs[i] or subset[j][indexB] == partBs[i]:
|
| 171 |
+
subset_idx[found] = j
|
| 172 |
+
found += 1
|
| 173 |
+
|
| 174 |
+
if found == 1:
|
| 175 |
+
j = subset_idx[0]
|
| 176 |
+
if subset[j][indexB] != partBs[i]:
|
| 177 |
+
subset[j][indexB] = partBs[i]
|
| 178 |
+
subset[j][-1] += 1
|
| 179 |
+
subset[j][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2]
|
| 180 |
+
elif found == 2: # if found 2 and disjoint, merge them
|
| 181 |
+
j1, j2 = subset_idx
|
| 182 |
+
membership = ((subset[j1] >= 0).astype(int) + (subset[j2] >= 0).astype(int))[:-2]
|
| 183 |
+
if len(np.nonzero(membership == 2)[0]) == 0: # merge
|
| 184 |
+
subset[j1][:-2] += (subset[j2][:-2] + 1)
|
| 185 |
+
subset[j1][-2:] += subset[j2][-2:]
|
| 186 |
+
subset[j1][-2] += connection_all[k][i][2]
|
| 187 |
+
subset = np.delete(subset, j2, 0)
|
| 188 |
+
else: # as like found == 1
|
| 189 |
+
subset[j1][indexB] = partBs[i]
|
| 190 |
+
subset[j1][-1] += 1
|
| 191 |
+
subset[j1][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2]
|
| 192 |
+
|
| 193 |
+
# if find no partA in the subset, create a new subset
|
| 194 |
+
elif not found and k < 17:
|
| 195 |
+
row = -1 * np.ones(20)
|
| 196 |
+
row[indexA] = partAs[i]
|
| 197 |
+
row[indexB] = partBs[i]
|
| 198 |
+
row[-1] = 2
|
| 199 |
+
row[-2] = sum(candidate[connection_all[k][i, :2].astype(int), 2]) + connection_all[k][i][2]
|
| 200 |
+
subset = np.vstack([subset, row])
|
| 201 |
+
# delete some rows of subset which has few parts occur
|
| 202 |
+
deleteIdx = []
|
| 203 |
+
for i in range(len(subset)):
|
| 204 |
+
if subset[i][-1] < 4 or subset[i][-2] / subset[i][-1] < 0.4:
|
| 205 |
+
deleteIdx.append(i)
|
| 206 |
+
subset = np.delete(subset, deleteIdx, axis=0)
|
| 207 |
+
|
| 208 |
+
# subset: n*20 array, 0-17 is the index in candidate, 18 is the total score, 19 is the total parts
|
| 209 |
+
# candidate: x, y, score, id
|
| 210 |
+
return candidate, subset
|
| 211 |
+
|
| 212 |
+
if __name__ == "__main__":
|
| 213 |
+
body_estimation = Body('../model/body_pose_model.pth')
|
| 214 |
+
|
| 215 |
+
test_image = '../images/ski.jpg'
|
| 216 |
+
oriImg = cv2.imread(test_image) # B,G,R order
|
| 217 |
+
candidate, subset = body_estimation(oriImg)
|
| 218 |
+
canvas = util.draw_bodypose(oriImg, candidate, subset)
|
| 219 |
+
plt.imshow(canvas[:, :, [2, 1, 0]])
|
| 220 |
+
plt.show()
|
openpose-editor/scripts/openpose/model.py
ADDED
|
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This code from https://github.com/lllyasviel/ControlNet
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from collections import OrderedDict
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
|
| 9 |
+
def make_layers(block, no_relu_layers):
|
| 10 |
+
layers = []
|
| 11 |
+
for layer_name, v in block.items():
|
| 12 |
+
if 'pool' in layer_name:
|
| 13 |
+
layer = nn.MaxPool2d(kernel_size=v[0], stride=v[1],
|
| 14 |
+
padding=v[2])
|
| 15 |
+
layers.append((layer_name, layer))
|
| 16 |
+
else:
|
| 17 |
+
conv2d = nn.Conv2d(in_channels=v[0], out_channels=v[1],
|
| 18 |
+
kernel_size=v[2], stride=v[3],
|
| 19 |
+
padding=v[4])
|
| 20 |
+
layers.append((layer_name, conv2d))
|
| 21 |
+
if layer_name not in no_relu_layers:
|
| 22 |
+
layers.append(('relu_'+layer_name, nn.ReLU(inplace=True)))
|
| 23 |
+
|
| 24 |
+
return nn.Sequential(OrderedDict(layers))
|
| 25 |
+
|
| 26 |
+
class bodypose_model(nn.Module):
|
| 27 |
+
def __init__(self):
|
| 28 |
+
super(bodypose_model, self).__init__()
|
| 29 |
+
|
| 30 |
+
# these layers have no relu layer
|
| 31 |
+
no_relu_layers = ['conv5_5_CPM_L1', 'conv5_5_CPM_L2', 'Mconv7_stage2_L1',\
|
| 32 |
+
'Mconv7_stage2_L2', 'Mconv7_stage3_L1', 'Mconv7_stage3_L2',\
|
| 33 |
+
'Mconv7_stage4_L1', 'Mconv7_stage4_L2', 'Mconv7_stage5_L1',\
|
| 34 |
+
'Mconv7_stage5_L2', 'Mconv7_stage6_L1', 'Mconv7_stage6_L1']
|
| 35 |
+
blocks = {}
|
| 36 |
+
block0 = OrderedDict([
|
| 37 |
+
('conv1_1', [3, 64, 3, 1, 1]),
|
| 38 |
+
('conv1_2', [64, 64, 3, 1, 1]),
|
| 39 |
+
('pool1_stage1', [2, 2, 0]),
|
| 40 |
+
('conv2_1', [64, 128, 3, 1, 1]),
|
| 41 |
+
('conv2_2', [128, 128, 3, 1, 1]),
|
| 42 |
+
('pool2_stage1', [2, 2, 0]),
|
| 43 |
+
('conv3_1', [128, 256, 3, 1, 1]),
|
| 44 |
+
('conv3_2', [256, 256, 3, 1, 1]),
|
| 45 |
+
('conv3_3', [256, 256, 3, 1, 1]),
|
| 46 |
+
('conv3_4', [256, 256, 3, 1, 1]),
|
| 47 |
+
('pool3_stage1', [2, 2, 0]),
|
| 48 |
+
('conv4_1', [256, 512, 3, 1, 1]),
|
| 49 |
+
('conv4_2', [512, 512, 3, 1, 1]),
|
| 50 |
+
('conv4_3_CPM', [512, 256, 3, 1, 1]),
|
| 51 |
+
('conv4_4_CPM', [256, 128, 3, 1, 1])
|
| 52 |
+
])
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# Stage 1
|
| 56 |
+
block1_1 = OrderedDict([
|
| 57 |
+
('conv5_1_CPM_L1', [128, 128, 3, 1, 1]),
|
| 58 |
+
('conv5_2_CPM_L1', [128, 128, 3, 1, 1]),
|
| 59 |
+
('conv5_3_CPM_L1', [128, 128, 3, 1, 1]),
|
| 60 |
+
('conv5_4_CPM_L1', [128, 512, 1, 1, 0]),
|
| 61 |
+
('conv5_5_CPM_L1', [512, 38, 1, 1, 0])
|
| 62 |
+
])
|
| 63 |
+
|
| 64 |
+
block1_2 = OrderedDict([
|
| 65 |
+
('conv5_1_CPM_L2', [128, 128, 3, 1, 1]),
|
| 66 |
+
('conv5_2_CPM_L2', [128, 128, 3, 1, 1]),
|
| 67 |
+
('conv5_3_CPM_L2', [128, 128, 3, 1, 1]),
|
| 68 |
+
('conv5_4_CPM_L2', [128, 512, 1, 1, 0]),
|
| 69 |
+
('conv5_5_CPM_L2', [512, 19, 1, 1, 0])
|
| 70 |
+
])
|
| 71 |
+
blocks['block1_1'] = block1_1
|
| 72 |
+
blocks['block1_2'] = block1_2
|
| 73 |
+
|
| 74 |
+
self.model0 = make_layers(block0, no_relu_layers)
|
| 75 |
+
|
| 76 |
+
# Stages 2 - 6
|
| 77 |
+
for i in range(2, 7):
|
| 78 |
+
blocks['block%d_1' % i] = OrderedDict([
|
| 79 |
+
('Mconv1_stage%d_L1' % i, [185, 128, 7, 1, 3]),
|
| 80 |
+
('Mconv2_stage%d_L1' % i, [128, 128, 7, 1, 3]),
|
| 81 |
+
('Mconv3_stage%d_L1' % i, [128, 128, 7, 1, 3]),
|
| 82 |
+
('Mconv4_stage%d_L1' % i, [128, 128, 7, 1, 3]),
|
| 83 |
+
('Mconv5_stage%d_L1' % i, [128, 128, 7, 1, 3]),
|
| 84 |
+
('Mconv6_stage%d_L1' % i, [128, 128, 1, 1, 0]),
|
| 85 |
+
('Mconv7_stage%d_L1' % i, [128, 38, 1, 1, 0])
|
| 86 |
+
])
|
| 87 |
+
|
| 88 |
+
blocks['block%d_2' % i] = OrderedDict([
|
| 89 |
+
('Mconv1_stage%d_L2' % i, [185, 128, 7, 1, 3]),
|
| 90 |
+
('Mconv2_stage%d_L2' % i, [128, 128, 7, 1, 3]),
|
| 91 |
+
('Mconv3_stage%d_L2' % i, [128, 128, 7, 1, 3]),
|
| 92 |
+
('Mconv4_stage%d_L2' % i, [128, 128, 7, 1, 3]),
|
| 93 |
+
('Mconv5_stage%d_L2' % i, [128, 128, 7, 1, 3]),
|
| 94 |
+
('Mconv6_stage%d_L2' % i, [128, 128, 1, 1, 0]),
|
| 95 |
+
('Mconv7_stage%d_L2' % i, [128, 19, 1, 1, 0])
|
| 96 |
+
])
|
| 97 |
+
|
| 98 |
+
for k in blocks.keys():
|
| 99 |
+
blocks[k] = make_layers(blocks[k], no_relu_layers)
|
| 100 |
+
|
| 101 |
+
self.model1_1 = blocks['block1_1']
|
| 102 |
+
self.model2_1 = blocks['block2_1']
|
| 103 |
+
self.model3_1 = blocks['block3_1']
|
| 104 |
+
self.model4_1 = blocks['block4_1']
|
| 105 |
+
self.model5_1 = blocks['block5_1']
|
| 106 |
+
self.model6_1 = blocks['block6_1']
|
| 107 |
+
|
| 108 |
+
self.model1_2 = blocks['block1_2']
|
| 109 |
+
self.model2_2 = blocks['block2_2']
|
| 110 |
+
self.model3_2 = blocks['block3_2']
|
| 111 |
+
self.model4_2 = blocks['block4_2']
|
| 112 |
+
self.model5_2 = blocks['block5_2']
|
| 113 |
+
self.model6_2 = blocks['block6_2']
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def forward(self, x):
|
| 117 |
+
|
| 118 |
+
out1 = self.model0(x)
|
| 119 |
+
|
| 120 |
+
out1_1 = self.model1_1(out1)
|
| 121 |
+
out1_2 = self.model1_2(out1)
|
| 122 |
+
out2 = torch.cat([out1_1, out1_2, out1], 1)
|
| 123 |
+
|
| 124 |
+
out2_1 = self.model2_1(out2)
|
| 125 |
+
out2_2 = self.model2_2(out2)
|
| 126 |
+
out3 = torch.cat([out2_1, out2_2, out1], 1)
|
| 127 |
+
|
| 128 |
+
out3_1 = self.model3_1(out3)
|
| 129 |
+
out3_2 = self.model3_2(out3)
|
| 130 |
+
out4 = torch.cat([out3_1, out3_2, out1], 1)
|
| 131 |
+
|
| 132 |
+
out4_1 = self.model4_1(out4)
|
| 133 |
+
out4_2 = self.model4_2(out4)
|
| 134 |
+
out5 = torch.cat([out4_1, out4_2, out1], 1)
|
| 135 |
+
|
| 136 |
+
out5_1 = self.model5_1(out5)
|
| 137 |
+
out5_2 = self.model5_2(out5)
|
| 138 |
+
out6 = torch.cat([out5_1, out5_2, out1], 1)
|
| 139 |
+
|
| 140 |
+
out6_1 = self.model6_1(out6)
|
| 141 |
+
out6_2 = self.model6_2(out6)
|
| 142 |
+
|
| 143 |
+
return out6_1, out6_2
|
| 144 |
+
|
| 145 |
+
class handpose_model(nn.Module):
|
| 146 |
+
def __init__(self):
|
| 147 |
+
super(handpose_model, self).__init__()
|
| 148 |
+
|
| 149 |
+
# these layers have no relu layer
|
| 150 |
+
no_relu_layers = ['conv6_2_CPM', 'Mconv7_stage2', 'Mconv7_stage3',\
|
| 151 |
+
'Mconv7_stage4', 'Mconv7_stage5', 'Mconv7_stage6']
|
| 152 |
+
# stage 1
|
| 153 |
+
block1_0 = OrderedDict([
|
| 154 |
+
('conv1_1', [3, 64, 3, 1, 1]),
|
| 155 |
+
('conv1_2', [64, 64, 3, 1, 1]),
|
| 156 |
+
('pool1_stage1', [2, 2, 0]),
|
| 157 |
+
('conv2_1', [64, 128, 3, 1, 1]),
|
| 158 |
+
('conv2_2', [128, 128, 3, 1, 1]),
|
| 159 |
+
('pool2_stage1', [2, 2, 0]),
|
| 160 |
+
('conv3_1', [128, 256, 3, 1, 1]),
|
| 161 |
+
('conv3_2', [256, 256, 3, 1, 1]),
|
| 162 |
+
('conv3_3', [256, 256, 3, 1, 1]),
|
| 163 |
+
('conv3_4', [256, 256, 3, 1, 1]),
|
| 164 |
+
('pool3_stage1', [2, 2, 0]),
|
| 165 |
+
('conv4_1', [256, 512, 3, 1, 1]),
|
| 166 |
+
('conv4_2', [512, 512, 3, 1, 1]),
|
| 167 |
+
('conv4_3', [512, 512, 3, 1, 1]),
|
| 168 |
+
('conv4_4', [512, 512, 3, 1, 1]),
|
| 169 |
+
('conv5_1', [512, 512, 3, 1, 1]),
|
| 170 |
+
('conv5_2', [512, 512, 3, 1, 1]),
|
| 171 |
+
('conv5_3_CPM', [512, 128, 3, 1, 1])
|
| 172 |
+
])
|
| 173 |
+
|
| 174 |
+
block1_1 = OrderedDict([
|
| 175 |
+
('conv6_1_CPM', [128, 512, 1, 1, 0]),
|
| 176 |
+
('conv6_2_CPM', [512, 22, 1, 1, 0])
|
| 177 |
+
])
|
| 178 |
+
|
| 179 |
+
blocks = {}
|
| 180 |
+
blocks['block1_0'] = block1_0
|
| 181 |
+
blocks['block1_1'] = block1_1
|
| 182 |
+
|
| 183 |
+
# stage 2-6
|
| 184 |
+
for i in range(2, 7):
|
| 185 |
+
blocks['block%d' % i] = OrderedDict([
|
| 186 |
+
('Mconv1_stage%d' % i, [150, 128, 7, 1, 3]),
|
| 187 |
+
('Mconv2_stage%d' % i, [128, 128, 7, 1, 3]),
|
| 188 |
+
('Mconv3_stage%d' % i, [128, 128, 7, 1, 3]),
|
| 189 |
+
('Mconv4_stage%d' % i, [128, 128, 7, 1, 3]),
|
| 190 |
+
('Mconv5_stage%d' % i, [128, 128, 7, 1, 3]),
|
| 191 |
+
('Mconv6_stage%d' % i, [128, 128, 1, 1, 0]),
|
| 192 |
+
('Mconv7_stage%d' % i, [128, 22, 1, 1, 0])
|
| 193 |
+
])
|
| 194 |
+
|
| 195 |
+
for k in blocks.keys():
|
| 196 |
+
blocks[k] = make_layers(blocks[k], no_relu_layers)
|
| 197 |
+
|
| 198 |
+
self.model1_0 = blocks['block1_0']
|
| 199 |
+
self.model1_1 = blocks['block1_1']
|
| 200 |
+
self.model2 = blocks['block2']
|
| 201 |
+
self.model3 = blocks['block3']
|
| 202 |
+
self.model4 = blocks['block4']
|
| 203 |
+
self.model5 = blocks['block5']
|
| 204 |
+
self.model6 = blocks['block6']
|
| 205 |
+
|
| 206 |
+
def forward(self, x):
|
| 207 |
+
out1_0 = self.model1_0(x)
|
| 208 |
+
out1_1 = self.model1_1(out1_0)
|
| 209 |
+
concat_stage2 = torch.cat([out1_1, out1_0], 1)
|
| 210 |
+
out_stage2 = self.model2(concat_stage2)
|
| 211 |
+
concat_stage3 = torch.cat([out_stage2, out1_0], 1)
|
| 212 |
+
out_stage3 = self.model3(concat_stage3)
|
| 213 |
+
concat_stage4 = torch.cat([out_stage3, out1_0], 1)
|
| 214 |
+
out_stage4 = self.model4(concat_stage4)
|
| 215 |
+
concat_stage5 = torch.cat([out_stage4, out1_0], 1)
|
| 216 |
+
out_stage5 = self.model5(concat_stage5)
|
| 217 |
+
concat_stage6 = torch.cat([out_stage5, out1_0], 1)
|
| 218 |
+
out_stage6 = self.model6(concat_stage6)
|
| 219 |
+
return out_stage6
|
| 220 |
+
|
| 221 |
+
|