diff --git a/loss_params.pth b/loss_params.pth
new file mode 100644
index 0000000000000000000000000000000000000000..9fc2cfc6dc7bb8d4e6eab51ba6ebd96c163dca11
--- /dev/null
+++ b/loss_params.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4e4c687fb455b7495e325d5f1761391d281323de6d2a493b153a3dac9536664e
+size 3120
diff --git a/onnx/up_blocks.0/attentions.1.transformer_blocks.6.norm3.weight b/onnx/up_blocks.0/attentions.1.transformer_blocks.6.norm3.weight
new file mode 100644
index 0000000000000000000000000000000000000000..e54789179622edc0d02162a4780a5b59a58b96f5
Binary files /dev/null and b/onnx/up_blocks.0/attentions.1.transformer_blocks.6.norm3.weight differ
diff --git a/onnx/up_blocks.0/attentions.1.transformer_blocks.7.attn1.to_out.0.bias b/onnx/up_blocks.0/attentions.1.transformer_blocks.7.attn1.to_out.0.bias
new file mode 100644
index 0000000000000000000000000000000000000000..c73319ebfbc2c76def06c85255d94ff913fab1de
Binary files /dev/null and b/onnx/up_blocks.0/attentions.1.transformer_blocks.7.attn1.to_out.0.bias differ
diff --git a/onnx/up_blocks.0/attentions.1.transformer_blocks.7.norm1.bias b/onnx/up_blocks.0/attentions.1.transformer_blocks.7.norm1.bias
new file mode 100644
index 0000000000000000000000000000000000000000..f381f72c4c8e0ea867a219e4df2bf83c1d3e29ac
Binary files /dev/null and b/onnx/up_blocks.0/attentions.1.transformer_blocks.7.norm1.bias differ
diff --git a/onnx/up_blocks.0/attentions.1.transformer_blocks.7.norm1.weight b/onnx/up_blocks.0/attentions.1.transformer_blocks.7.norm1.weight
new file mode 100644
index 0000000000000000000000000000000000000000..24eb2a85be6cc626b289d1dd9097982ea80a6a62
Binary files /dev/null and b/onnx/up_blocks.0/attentions.1.transformer_blocks.7.norm1.weight differ
diff --git a/onnx/up_blocks.0/attentions.1.transformer_blocks.7.norm3.weight b/onnx/up_blocks.0/attentions.1.transformer_blocks.7.norm3.weight
new file mode 100644
index 0000000000000000000000000000000000000000..9bbb0492a60af22de563c06888817cbaed6695a3
Binary files /dev/null and b/onnx/up_blocks.0/attentions.1.transformer_blocks.7.norm3.weight differ
diff --git a/onnx/up_blocks.0/attentions.1.transformer_blocks.8.ff.net.2.bias b/onnx/up_blocks.0/attentions.1.transformer_blocks.8.ff.net.2.bias
new file mode 100644
index 0000000000000000000000000000000000000000..91d99d19912b30ab617dd3efa6f403c2efc262cf
Binary files /dev/null and b/onnx/up_blocks.0/attentions.1.transformer_blocks.8.ff.net.2.bias differ
diff --git a/onnx/up_blocks.0/attentions.1.transformer_blocks.8.norm1.weight b/onnx/up_blocks.0/attentions.1.transformer_blocks.8.norm1.weight
new file mode 100644
index 0000000000000000000000000000000000000000..6c3b66d2a385eae6a08099767467bb8687522928
Binary files /dev/null and b/onnx/up_blocks.0/attentions.1.transformer_blocks.8.norm1.weight differ
diff --git a/onnx/up_blocks.0/attentions.1.transformer_blocks.9.ff.net.0.proj.bias b/onnx/up_blocks.0/attentions.1.transformer_blocks.9.ff.net.0.proj.bias
new file mode 100644
index 0000000000000000000000000000000000000000..98d954d04b22fd98ff8cb03ab974e27a8f3b31f4
Binary files /dev/null and b/onnx/up_blocks.0/attentions.1.transformer_blocks.9.ff.net.0.proj.bias differ
diff --git a/onnx/up_blocks.0/attentions.1.transformer_blocks.9.norm3.bias b/onnx/up_blocks.0/attentions.1.transformer_blocks.9.norm3.bias
new file mode 100644
index 0000000000000000000000000000000000000000..11544b4e42a08a5c66cbcc4b5e614a93aeb531b2
Binary files /dev/null and b/onnx/up_blocks.0/attentions.1.transformer_blocks.9.norm3.bias differ
diff --git a/onnx/up_blocks.0/attentions.2.transformer_blocks.4.norm2.weight b/onnx/up_blocks.0/attentions.2.transformer_blocks.4.norm2.weight
new file mode 100644
index 0000000000000000000000000000000000000000..6f0b3723e5a669498e699362e4102dccc2c78ac0
Binary files /dev/null and b/onnx/up_blocks.0/attentions.2.transformer_blocks.4.norm2.weight differ
diff --git a/onnx/up_blocks.0/attentions.2.transformer_blocks.6.attn1.to_out.0.bias b/onnx/up_blocks.0/attentions.2.transformer_blocks.6.attn1.to_out.0.bias
new file mode 100644
index 0000000000000000000000000000000000000000..b286c97b82f8090b735fd36d685ea5c9d4d74a7f
Binary files /dev/null and b/onnx/up_blocks.0/attentions.2.transformer_blocks.6.attn1.to_out.0.bias differ
diff --git a/onnx/up_blocks.0/attentions.2.transformer_blocks.6.ff.net.0.proj.bias b/onnx/up_blocks.0/attentions.2.transformer_blocks.6.ff.net.0.proj.bias
new file mode 100644
index 0000000000000000000000000000000000000000..33a30d89e3cac35133d9e99af9ea54d3a155c76b
Binary files /dev/null and b/onnx/up_blocks.0/attentions.2.transformer_blocks.6.ff.net.0.proj.bias differ
diff --git a/onnx/up_blocks.0/attentions.2.transformer_blocks.6.ff.net.2.bias b/onnx/up_blocks.0/attentions.2.transformer_blocks.6.ff.net.2.bias
new file mode 100644
index 0000000000000000000000000000000000000000..0a1a063ad3d3fc621ccb0785ea435e9a465b58bf
Binary files /dev/null and b/onnx/up_blocks.0/attentions.2.transformer_blocks.6.ff.net.2.bias differ
diff --git a/onnx/up_blocks.0/attentions.2.transformer_blocks.6.norm1.bias b/onnx/up_blocks.0/attentions.2.transformer_blocks.6.norm1.bias
new file mode 100644
index 0000000000000000000000000000000000000000..50ea8aaa1de8053d43b8654c5d5ecaa3e9a6fba4
Binary files /dev/null and b/onnx/up_blocks.0/attentions.2.transformer_blocks.6.norm1.bias differ
diff --git a/onnx/up_blocks.0/attentions.2.transformer_blocks.6.norm1.weight b/onnx/up_blocks.0/attentions.2.transformer_blocks.6.norm1.weight
new file mode 100644
index 0000000000000000000000000000000000000000..d19b9096aa20f41e2951dbf0a67eaa4de6d387ae
Binary files /dev/null and b/onnx/up_blocks.0/attentions.2.transformer_blocks.6.norm1.weight differ
diff --git a/onnx/up_blocks.0/attentions.2.transformer_blocks.6.norm2.weight b/onnx/up_blocks.0/attentions.2.transformer_blocks.6.norm2.weight
new file mode 100644
index 0000000000000000000000000000000000000000..ac6e0a51862d15861c198a052089a76689a81461
Binary files /dev/null and b/onnx/up_blocks.0/attentions.2.transformer_blocks.6.norm2.weight differ
diff --git a/onnx/up_blocks.0/attentions.2.transformer_blocks.7.attn1.to_out.0.bias b/onnx/up_blocks.0/attentions.2.transformer_blocks.7.attn1.to_out.0.bias
new file mode 100644
index 0000000000000000000000000000000000000000..6509ffc1ec970fec02ed4d5c0245e13c9eb8cba8
Binary files /dev/null and b/onnx/up_blocks.0/attentions.2.transformer_blocks.7.attn1.to_out.0.bias differ
diff --git a/onnx/up_blocks.0/attentions.2.transformer_blocks.7.norm1.bias b/onnx/up_blocks.0/attentions.2.transformer_blocks.7.norm1.bias
new file mode 100644
index 0000000000000000000000000000000000000000..8b0588d4ab92e4bf5c8bc16ee3c8c7d4e05717df
Binary files /dev/null and b/onnx/up_blocks.0/attentions.2.transformer_blocks.7.norm1.bias differ
diff --git a/onnx/up_blocks.0/attentions.2.transformer_blocks.7.norm1.weight b/onnx/up_blocks.0/attentions.2.transformer_blocks.7.norm1.weight
new file mode 100644
index 0000000000000000000000000000000000000000..146e7d3a80c6d2013b20af2872791c5109b0a191
Binary files /dev/null and b/onnx/up_blocks.0/attentions.2.transformer_blocks.7.norm1.weight differ
diff --git a/onnx/up_blocks.0/attentions.2.transformer_blocks.7.norm2.weight b/onnx/up_blocks.0/attentions.2.transformer_blocks.7.norm2.weight
new file mode 100644
index 0000000000000000000000000000000000000000..3bd4e5e3284531543183f48952d82032152f2653
Binary files /dev/null and b/onnx/up_blocks.0/attentions.2.transformer_blocks.7.norm2.weight differ
diff --git a/onnx/up_blocks.0/attentions.2.transformer_blocks.7.norm3.weight b/onnx/up_blocks.0/attentions.2.transformer_blocks.7.norm3.weight
new file mode 100644
index 0000000000000000000000000000000000000000..72d489b6d888b9acce7df55c90b08e5c75f11b60
Binary files /dev/null and b/onnx/up_blocks.0/attentions.2.transformer_blocks.7.norm3.weight differ
diff --git a/onnx/up_blocks.0/attentions.2.transformer_blocks.8.attn1.to_out.0.bias b/onnx/up_blocks.0/attentions.2.transformer_blocks.8.attn1.to_out.0.bias
new file mode 100644
index 0000000000000000000000000000000000000000..d758763870c3feb4227b8c614d187ab3e3a8507d
Binary files /dev/null and b/onnx/up_blocks.0/attentions.2.transformer_blocks.8.attn1.to_out.0.bias differ
diff --git a/onnx/up_blocks.0/attentions.2.transformer_blocks.8.attn2.to_out.0.bias b/onnx/up_blocks.0/attentions.2.transformer_blocks.8.attn2.to_out.0.bias
new file mode 100644
index 0000000000000000000000000000000000000000..1ad5101c40947382da404c67eabe4bb7b8889de9
Binary files /dev/null and b/onnx/up_blocks.0/attentions.2.transformer_blocks.8.attn2.to_out.0.bias differ
diff --git a/onnx/up_blocks.0/attentions.2.transformer_blocks.8.norm1.bias b/onnx/up_blocks.0/attentions.2.transformer_blocks.8.norm1.bias
new file mode 100644
index 0000000000000000000000000000000000000000..2380df4857d665e3fe3506320945bf8e2b2654d1
Binary files /dev/null and b/onnx/up_blocks.0/attentions.2.transformer_blocks.8.norm1.bias differ
diff --git a/onnx/up_blocks.0/attentions.2.transformer_blocks.8.norm2.weight b/onnx/up_blocks.0/attentions.2.transformer_blocks.8.norm2.weight
new file mode 100644
index 0000000000000000000000000000000000000000..618187eab8484013fa3757bdedf6bb235a097cdd
Binary files /dev/null and b/onnx/up_blocks.0/attentions.2.transformer_blocks.8.norm2.weight differ
diff --git a/onnx/up_blocks.0/attentions.2.transformer_blocks.8.norm3.bias b/onnx/up_blocks.0/attentions.2.transformer_blocks.8.norm3.bias
new file mode 100644
index 0000000000000000000000000000000000000000..4a67aeb7e0a1b236f7495884ee125dfd3e7ebf76
Binary files /dev/null and b/onnx/up_blocks.0/attentions.2.transformer_blocks.8.norm3.bias differ
diff --git a/onnx/up_blocks.0/attentions.2.transformer_blocks.8.norm3.weight b/onnx/up_blocks.0/attentions.2.transformer_blocks.8.norm3.weight
new file mode 100644
index 0000000000000000000000000000000000000000..6fedf2b8373471feef3bde32b2e66a82a5d5922c
Binary files /dev/null and b/onnx/up_blocks.0/attentions.2.transformer_blocks.8.norm3.weight differ
diff --git a/onnx/up_blocks.0/attentions.2.transformer_blocks.9.attn1.to_out.0.bias b/onnx/up_blocks.0/attentions.2.transformer_blocks.9.attn1.to_out.0.bias
new file mode 100644
index 0000000000000000000000000000000000000000..42e1e6f3e68a2a59243f65867d52368800db58d3
Binary files /dev/null and b/onnx/up_blocks.0/attentions.2.transformer_blocks.9.attn1.to_out.0.bias differ
diff --git a/onnx/up_blocks.0/attentions.2.transformer_blocks.9.attn2.to_out.0.bias b/onnx/up_blocks.0/attentions.2.transformer_blocks.9.attn2.to_out.0.bias
new file mode 100644
index 0000000000000000000000000000000000000000..8135adc1337f782abf0d0801fbe44329195f79c1
Binary files /dev/null and b/onnx/up_blocks.0/attentions.2.transformer_blocks.9.attn2.to_out.0.bias differ
diff --git a/onnx/up_blocks.0/attentions.2.transformer_blocks.9.ff.net.0.proj.bias b/onnx/up_blocks.0/attentions.2.transformer_blocks.9.ff.net.0.proj.bias
new file mode 100644
index 0000000000000000000000000000000000000000..003322f9589e33635d865b0db8837a399aa135e6
Binary files /dev/null and b/onnx/up_blocks.0/attentions.2.transformer_blocks.9.ff.net.0.proj.bias differ
diff --git a/onnx/up_blocks.0/attentions.2.transformer_blocks.9.norm1.bias b/onnx/up_blocks.0/attentions.2.transformer_blocks.9.norm1.bias
new file mode 100644
index 0000000000000000000000000000000000000000..981c69ccd3af192ab782a6d33127693d9cee3ba5
Binary files /dev/null and b/onnx/up_blocks.0/attentions.2.transformer_blocks.9.norm1.bias differ
diff --git a/onnx/up_blocks.0/attentions.2.transformer_blocks.9.norm2.weight b/onnx/up_blocks.0/attentions.2.transformer_blocks.9.norm2.weight
new file mode 100644
index 0000000000000000000000000000000000000000..4cd76d79c962cb7e1ef3ac3163a833a520a3910c
Binary files /dev/null and b/onnx/up_blocks.0/attentions.2.transformer_blocks.9.norm2.weight differ
diff --git a/onnx/up_blocks.0/attentions.2.transformer_blocks.9.norm3.bias b/onnx/up_blocks.0/attentions.2.transformer_blocks.9.norm3.bias
new file mode 100644
index 0000000000000000000000000000000000000000..16334f71c03fa6aa24cbe79225bd3a377c428d2a
Binary files /dev/null and b/onnx/up_blocks.0/attentions.2.transformer_blocks.9.norm3.bias differ
diff --git a/onnx/up_blocks.0/attentions.2.transformer_blocks.9.norm3.weight b/onnx/up_blocks.0/attentions.2.transformer_blocks.9.norm3.weight
new file mode 100644
index 0000000000000000000000000000000000000000..08949c3ee51a6a84106821f49cc3c4776ba9a2d2
Binary files /dev/null and b/onnx/up_blocks.0/attentions.2.transformer_blocks.9.norm3.weight differ
diff --git a/onnx/up_blocks.0/onnx__Add_6908 b/onnx/up_blocks.0/onnx__Add_6908
new file mode 100644
index 0000000000000000000000000000000000000000..2936ef35c70dd3aa99f802621eaeb52c4dcedf97
Binary files /dev/null and b/onnx/up_blocks.0/onnx__Add_6908 differ
diff --git a/onnx/up_blocks.0/onnx__Add_7154 b/onnx/up_blocks.0/onnx__Add_7154
new file mode 100644
index 0000000000000000000000000000000000000000..dd8f84e2b9b6e3cbd885b319ab6d7682b89206e9
Binary files /dev/null and b/onnx/up_blocks.0/onnx__Add_7154 differ
diff --git a/onnx/up_blocks.0/onnx__Add_7400 b/onnx/up_blocks.0/onnx__Add_7400
new file mode 100644
index 0000000000000000000000000000000000000000..5c4c1425cc33c3124e2029c29ac77f5d3b135338
Binary files /dev/null and b/onnx/up_blocks.0/onnx__Add_7400 differ
diff --git a/onnx/up_blocks.0/onnx__Add_7402 b/onnx/up_blocks.0/onnx__Add_7402
new file mode 100644
index 0000000000000000000000000000000000000000..8cff03f0857b754b6a0eb3b37edd0e4a1af9c448
Binary files /dev/null and b/onnx/up_blocks.0/onnx__Add_7402 differ
diff --git a/onnx/up_blocks.0/onnx__Add_7404 b/onnx/up_blocks.0/onnx__Add_7404
new file mode 100644
index 0000000000000000000000000000000000000000..628234a840b23f23dc98bd689754c2485b76cd93
Binary files /dev/null and b/onnx/up_blocks.0/onnx__Add_7404 differ
diff --git a/onnx/up_blocks.0/onnx__Mul_7155 b/onnx/up_blocks.0/onnx__Mul_7155
new file mode 100644
index 0000000000000000000000000000000000000000..819dae504c2f0e43434b17fd68c8ab3a2d65bf60
Binary files /dev/null and b/onnx/up_blocks.0/onnx__Mul_7155 differ
diff --git a/onnx/up_blocks.0/resnets.0.conv2.bias b/onnx/up_blocks.0/resnets.0.conv2.bias
new file mode 100644
index 0000000000000000000000000000000000000000..689a2910b96d69f4a3b62c52e53e8b25cbf82b5e
Binary files /dev/null and b/onnx/up_blocks.0/resnets.0.conv2.bias differ
diff --git a/onnx/up_blocks.0/resnets.0.time_emb_proj.bias b/onnx/up_blocks.0/resnets.0.time_emb_proj.bias
new file mode 100644
index 0000000000000000000000000000000000000000..ae25c67d95fa9a236c17a89393ece17fd6e47243
Binary files /dev/null and b/onnx/up_blocks.0/resnets.0.time_emb_proj.bias differ
diff --git a/onnx/up_blocks.0/resnets.1.conv2.bias b/onnx/up_blocks.0/resnets.1.conv2.bias
new file mode 100644
index 0000000000000000000000000000000000000000..041afb33112c03a524b87c9778bd8692e2b7833f
Binary files /dev/null and b/onnx/up_blocks.0/resnets.1.conv2.bias differ
diff --git a/onnx/up_blocks.0/resnets.1.conv_shortcut.bias b/onnx/up_blocks.0/resnets.1.conv_shortcut.bias
new file mode 100644
index 0000000000000000000000000000000000000000..deab7520a0a1d99efd74bb9cbfce4e3eb9e88ddd
Binary files /dev/null and b/onnx/up_blocks.0/resnets.1.conv_shortcut.bias differ
diff --git a/onnx/up_blocks.0/resnets.2.conv1.bias b/onnx/up_blocks.0/resnets.2.conv1.bias
new file mode 100644
index 0000000000000000000000000000000000000000..111cdfecdd3fc920216c01af061d4b0542ef1437
Binary files /dev/null and b/onnx/up_blocks.0/resnets.2.conv1.bias differ
diff --git a/onnx/up_blocks.0/resnets.2.conv_shortcut.bias b/onnx/up_blocks.0/resnets.2.conv_shortcut.bias
new file mode 100644
index 0000000000000000000000000000000000000000..38d115bc6354443376ebec603a67b060338ea4c7
Binary files /dev/null and b/onnx/up_blocks.0/resnets.2.conv_shortcut.bias differ
diff --git a/onnx/up_blocks.0/resnets.2.time_emb_proj.bias b/onnx/up_blocks.0/resnets.2.time_emb_proj.bias
new file mode 100644
index 0000000000000000000000000000000000000000..1b1643c8ec1a850b5626b87250d9ca553033852f
Binary files /dev/null and b/onnx/up_blocks.0/resnets.2.time_emb_proj.bias differ
diff --git a/onnx/up_blocks.0/upsamplers.0.conv.bias b/onnx/up_blocks.0/upsamplers.0.conv.bias
new file mode 100644
index 0000000000000000000000000000000000000000..a11abe507831e52cb30788627683439c0013cdf5
Binary files /dev/null and b/onnx/up_blocks.0/upsamplers.0.conv.bias differ
diff --git a/src/assets/sdxl_cache.png b/src/assets/sdxl_cache.png
new file mode 100644
index 0000000000000000000000000000000000000000..cc40b7f0331215eff229d3c5f8eb16607f587bdf
Binary files /dev/null and b/src/assets/sdxl_cache.png differ
diff --git a/src/cache_diffusion/cachify.py b/src/cache_diffusion/cachify.py
new file mode 100644
index 0000000000000000000000000000000000000000..df5b3efbb222c69984c78f75ea84a6a3fb04dd16
--- /dev/null
+++ b/src/cache_diffusion/cachify.py
@@ -0,0 +1,144 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: MIT
+#
+# Permission is hereby granted, free of charge, to any person obtaining a
+# copy of this software and associated documentation files (the "Software"),
+# to deal in the Software without restriction, including without limitation
+# the rights to use, copy, modify, merge, publish, distribute, sublicense,
+# and/or sell copies of the Software, and to permit persons to whom the
+# Software is furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
+# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+# DEALINGS IN THE SOFTWARE.
+
+import fnmatch
+from contextlib import contextmanager
+
+from diffusers.models.attention import BasicTransformerBlock, JointTransformerBlock
+from diffusers.models.transformers.pixart_transformer_2d import PixArtTransformer2DModel
+from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel
+from diffusers.models.unets.unet_2d_blocks import (
+    CrossAttnDownBlock2D,
+    CrossAttnUpBlock2D,
+    DownBlock2D,
+    UNetMidBlock2DCrossAttn,
+    UpBlock2D,
+)
+from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
+from diffusers.models.unets.unet_3d_blocks import (
+    CrossAttnDownBlockSpatioTemporal,
+    CrossAttnUpBlockSpatioTemporal,
+    DownBlockSpatioTemporal,
+    UNetMidBlockSpatioTemporal,
+    UpBlockSpatioTemporal,
+)
+from diffusers.models.unets.unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
+
+from .module import CachedModule
+from .utils import replace_module
+
+CACHED_PIPE = {
+    UNet2DConditionModel: (
+        DownBlock2D,
+        CrossAttnDownBlock2D,
+        UNetMidBlock2DCrossAttn,
+        CrossAttnUpBlock2D,
+        UpBlock2D,
+    ),
+    PixArtTransformer2DModel: (BasicTransformerBlock),
+    UNetSpatioTemporalConditionModel: (
+        CrossAttnDownBlockSpatioTemporal,
+        DownBlockSpatioTemporal,
+        UpBlockSpatioTemporal,
+        CrossAttnUpBlockSpatioTemporal,
+        UNetMidBlockSpatioTemporal,
+    ),
+    SD3Transformer2DModel: (JointTransformerBlock),
+}
+
+
+def _apply_to_modules(model, action, modules=None, config_list=None):
+    if hasattr(model, "use_trt_infer") and model.use_trt_infer:
+        for key, module in model.engines.items():
+            if isinstance(module, CachedModule):
+                action(module)
+            elif config_list:
+                for config in config_list:
+                    if _pass(key, config["wildcard_or_filter_func"]):
+                        model.engines[key] = CachedModule(module, config["select_cache_step_func"])
+    else:
+        for name, module in model.named_modules():
+            if isinstance(module, CachedModule):
+                action(module)
+            elif modules and config_list:
+                for config in config_list:
+                    if _pass(name, config["wildcard_or_filter_func"]) and isinstance(
+                        module, modules
+                    ):
+                        replace_module(
+                            model,
+                            name,
+                            CachedModule(module, config["select_cache_step_func"]),
+                        )
+
+
+def cachify(model, config_list, modules):
+    def cache_action(module):
+        pass  # No action needed, caching is handled in the loop itself
+
+    _apply_to_modules(model, cache_action, modules, config_list)
+
+
+def disable(pipe):
+    model = get_model(pipe)
+    _apply_to_modules(model, lambda module: module.disable_cache())
+
+
+def enable(pipe):
+    model = get_model(pipe)
+    _apply_to_modules(model, lambda module: module.enable_cache())
+
+
+def reset_status(pipe):
+    model = get_model(pipe)
+    _apply_to_modules(model, lambda module: setattr(module, "cur_step", 0))
+
+
+def _pass(name, wildcard_or_filter_func):
+    if isinstance(wildcard_or_filter_func, str):
+        return fnmatch.fnmatch(name, wildcard_or_filter_func)
+    elif callable(wildcard_or_filter_func):
+        return wildcard_or_filter_func(name)
+    else:
+        raise NotImplementedError(f"Unsupported type {type(wildcard_or_filter_func)}")
+
+
+def get_model(pipe):
+    if hasattr(pipe, "unet"):
+        return pipe.unet
+    elif hasattr(pipe, "transformer"):
+        return pipe.transformer
+    else:
+        raise KeyError
+
+
+@contextmanager
+def infer(pipe):
+    try:
+        yield pipe
+    finally:
+        reset_status(pipe)
+
+
+def prepare(pipe, config_list):
+    model = get_model(pipe)
+    assert model.__class__ in CACHED_PIPE.keys(), f"{model.__class__} is not supported!"
+    cachify(model, config_list, CACHED_PIPE[model.__class__])
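For orientation, here is a minimal usage sketch of the API above. Only `prepare` and `infer` come from `src/cache_diffusion/cachify.py`; the pipeline setup, model id, and step count are illustrative assumptions, not part of this diff.

# Hypothetical usage of cache_diffusion with an SDXL pipeline (sketch only).
import torch
from diffusers import StableDiffusionXLPipeline

from cache_diffusion import cachify  # import path assumed from src/cache_diffusion/
from cache_diffusion.utils import SDXL_DEFAULT_CONFIG

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

cachify.prepare(pipe, SDXL_DEFAULT_CONFIG)  # wraps cacheable UNet blocks in CachedModule

with cachify.infer(pipe) as cached_pipe:  # resets per-block step counters on exit
    image = cached_pipe("a photo of a cat", num_inference_steps=20).images[0]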
diff --git a/src/cache_diffusion/module.py b/src/cache_diffusion/module.py
new file mode 100644
index 0000000000000000000000000000000000000000..a1ed434c30fd1ab1feeeafc6addeb73dd04655c4
--- /dev/null
+++ b/src/cache_diffusion/module.py
@@ -0,0 +1,55 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: MIT
+#
+# Permission is hereby granted, free of charge, to any person obtaining a
+# copy of this software and associated documentation files (the "Software"),
+# to deal in the Software without restriction, including without limitation
+# the rights to use, copy, modify, merge, publish, distribute, sublicense,
+# and/or sell copies of the Software, and to permit persons to whom the
+# Software is furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
+# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+# DEALINGS IN THE SOFTWARE.
+
+from torch import nn
+
+
+class CachedModule(nn.Module):
+    def __init__(self, block, select_cache_step_func) -> None:
+        super().__init__()
+        self.block = block
+        self.select_cache_step_func = select_cache_step_func
+        self.cur_step = 0
+        self.cached_results = None
+        self.enabled = True
+
+    def __getattr__(self, name):
+        try:
+            return super().__getattr__(name)
+        except AttributeError:
+            return getattr(self.block, name)
+
+    def if_cache(self):
+        return self.select_cache_step_func(self.cur_step) and self.enabled
+
+    def enable_cache(self):
+        self.enabled = True
+
+    def disable_cache(self):
+        self.enabled = False
+        self.cur_step = 0
+
+    def forward(self, *args, **kwargs):
+        if not self.if_cache():
+            self.cached_results = self.block(*args, **kwargs)
+        if self.enabled:
+            self.cur_step += 1
+        return self.cached_results
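The caching contract of `CachedModule` is easiest to see on a toy block (a sketch; the import path is assumed from this diff): on steps where `select_cache_step_func` returns False the wrapped block is recomputed and stored, and on steps where it returns True the stored result is returned unchanged.

# Toy demonstration of CachedModule's step logic (illustrative, standalone).
import torch
from torch import nn

from cache_diffusion.module import CachedModule

block = nn.Linear(4, 4)
# Same policy as SDXL_DEFAULT_CONFIG: recompute on even steps, reuse on odd steps.
cached = CachedModule(block, lambda step: (step % 2) != 0)

x = torch.randn(1, 4)
y0 = cached(x)  # step 0: if_cache() is False, so the block runs and the result is stored
y1 = cached(x)  # step 1: if_cache() is True, so the stored step-0 result is returned
assert torch.equal(y0, y1)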
diff --git a/src/cache_diffusion/utils.py b/src/cache_diffusion/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8b2c27726367f7961e55f215ca6f064353dc8bd
--- /dev/null
+++ b/src/cache_diffusion/utils.py
@@ -0,0 +1,61 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: MIT
+#
+# Permission is hereby granted, free of charge, to any person obtaining a
+# copy of this software and associated documentation files (the "Software"),
+# to deal in the Software without restriction, including without limitation
+# the rights to use, copy, modify, merge, publish, distribute, sublicense,
+# and/or sell copies of the Software, and to permit persons to whom the
+# Software is furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
+# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+# DEALINGS IN THE SOFTWARE.
+
+import re
+
+SDXL_DEFAULT_CONFIG = [
+    {
+        "wildcard_or_filter_func": lambda name: "up_blocks.2" not in name,
+        "select_cache_step_func": lambda step: (step % 2) != 0,
+    }
+]
+
+PIXART_DEFAULT_CONFIG = [
+    {
+        "wildcard_or_filter_func": lambda name: not re.search(
+            r"transformer_blocks\.(2[1-7])\.", name
+        ),
+        "select_cache_step_func": lambda step: (step % 3) != 0,
+    }
+]
+
+SVD_DEFAULT_CONFIG = [
+    {
+        "wildcard_or_filter_func": lambda name: "up_blocks.3" not in name,
+        "select_cache_step_func": lambda step: (step % 2) != 0,
+    }
+]
+
+SD3_DEFAULT_CONFIG = [
+    {
+        "wildcard_or_filter_func": lambda name: re.search(
+            r"^((?!transformer_blocks\.(1[6-9]|2[0-3])).)*$", name
+        ),
+        "select_cache_step_func": lambda step: (step % 2) != 0,
+    }
+]
+
+
+def replace_module(parent, name_path, new_module):
+    path_parts = name_path.split(".")
+    for part in path_parts[:-1]:
+        parent = getattr(parent, part)
+    setattr(parent, path_parts[-1], new_module)
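A quick sanity check of how a config entry drives selection (the asserts follow directly from the lambdas above): the SDXL filter caches everything except `up_blocks.2`, the final, highest-resolution up block, and recomputes on even steps.

# Sanity check of SDXL_DEFAULT_CONFIG's two callables (runs as-is).
from cache_diffusion.utils import SDXL_DEFAULT_CONFIG

cfg = SDXL_DEFAULT_CONFIG[0]
# Module filter: every block except the final up block is cacheable.
assert cfg["wildcard_or_filter_func"]("up_blocks.0")
assert not cfg["wildcard_or_filter_func"]("up_blocks.2")
# Step policy: recompute on even steps, serve cached results on odd steps.
assert [cfg["select_cache_step_func"](s) for s in range(4)] == [False, True, False, True]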
diff --git a/src/loss.py b/src/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..24a18c6efc382f39f458948bdd16d2c7ea972713
--- /dev/null
+++ b/src/loss.py
@@ -0,0 +1,45 @@
+_A=None
+import torch
+from tqdm import tqdm
+class LossSchedulerModel(torch.nn.Module):
+ def __init__(A,wx,we):super(LossSchedulerModel,A).__init__();assert len(wx.shape)==1 and len(we.shape)==2;B=wx.shape[0];assert B==we.shape[0]and B==we.shape[1];A.register_parameter('wx',torch.nn.Parameter(wx));A.register_parameter('we',torch.nn.Parameter(we))
+ def forward(A,t,xT,e_prev):
+  B=e_prev;assert t-len(B)+1==0;C=xT*A.wx[t]
+  for(D,E)in zip(B,A.we[t]):C+=D*E
+  return C.to(xT.dtype)
+class LossScheduler:
+ def __init__(A,timesteps,model):A.timesteps=timesteps;A.model=model;A.init_noise_sigma=1.;A.order=1
+ @staticmethod
+ def load(path):A,B,C=torch.load(path,map_location='cpu');D=LossSchedulerModel(B,C);return LossScheduler(A,D)
+ def save(A,path):B,C,D=A.timesteps,A.model.wx,A.model.we;torch.save((B,C,D),path)
+ def set_timesteps(A,num_inference_steps,device='cuda'):B=device;A.xT=_A;A.e_prev=[];A.t_prev=-1;A.model=A.model.to(B);A.timesteps=A.timesteps.to(B)
+ def scale_model_input(A,sample,*B,**C):return sample
+ @torch.no_grad()
+ def step(self,model_output,timestep,sample,*D,**E):
+  A=self;B=A.timesteps.tolist().index(timestep);assert A.t_prev==-1 or B==A.t_prev+1
+  if A.t_prev==-1:A.xT=sample
+  A.e_prev.append(model_output);C=A.model(B,A.xT,A.e_prev)
+  if B+1==len(A.timesteps):A.xT=_A;A.e_prev=[];A.t_prev=-1
+  else:A.t_prev=B
+  return C,
+class SchedulerWrapper:
+ def __init__(A,scheduler,loss_params_path='loss_params.pth'):A.scheduler=scheduler;A.catch_x,A.catch_e,A.catch_x_={},{},{};A.loss_scheduler=_A;A.loss_params_path=loss_params_path
+ def set_timesteps(A,num_inference_steps,**C):
+  D=num_inference_steps
+  if A.loss_scheduler is _A:B=A.scheduler.set_timesteps(D,**C);A.timesteps=A.scheduler.timesteps;A.init_noise_sigma=A.scheduler.init_noise_sigma;A.order=A.scheduler.order;return B
+  else:B=A.loss_scheduler.set_timesteps(D,**C);A.timesteps=A.loss_scheduler.timesteps;A.init_noise_sigma=A.scheduler.init_noise_sigma;A.order=A.scheduler.order;return B
+ def step(B,model_output,timestep,sample,**F):
+  D=sample;E=model_output;A=timestep
+  if B.loss_scheduler is _A:
+   C=B.scheduler.step(E,A,D,**F);A=A.tolist()
+   if A not in B.catch_x:B.catch_x[A]=[];B.catch_e[A]=[];B.catch_x_[A]=[]
+   B.catch_x[A].append(D.clone().detach().cpu());B.catch_e[A].append(E.clone().detach().cpu());B.catch_x_[A].append(C[0].clone().detach().cpu());return C
+  else:C=B.loss_scheduler.step(E,A,D,**F);return C
+ def scale_model_input(A,sample,timestep):return sample
+ def add_noise(A,original_samples,noise,timesteps):B=A.scheduler.add_noise(original_samples,noise,timesteps);return B
+ def get_path(C):
+  A=sorted([A for A in C.catch_x],reverse=True);B,D=[],[]
+  for E in A:F=torch.cat(C.catch_x[E],dim=0);B.append(F);G=torch.cat(C.catch_e[E],dim=0);D.append(G)
+  H=A[-1];I=torch.cat(C.catch_x_[H],dim=0);B.append(I);A=torch.tensor(A,dtype=torch.int32);B=torch.stack(B);D=torch.stack(D);return A,B,D
+ def load_loss_params(A):B,C,D=torch.load(A.loss_params_path,map_location='cpu');A.loss_model=LossSchedulerModel(C,D);A.loss_scheduler=LossScheduler(B,A.loss_model)
+ def prepare_loss(A,num_accelerate_steps=15):A.load_loss_params()
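The file above is minified; a readable summary and assumed usage follow. Only `SchedulerWrapper`, `prepare_loss`, and the update rule come from the code above; `pipe` and the prompt are illustrative assumptions. The learned scheduler predicts each next latent as a linear combination of the initial noise and all previous model outputs, with per-step weights `wx` and `we` loaded from `loss_params.pth`.

# Sketch of how the wrapper appears to be used (pipeline setup assumed).
from loss import SchedulerWrapper  # src/loss.py in this diff

pipe.scheduler = SchedulerWrapper(pipe.scheduler, loss_params_path="loss_params.pth")
pipe.scheduler.prepare_loss()  # load (timesteps, wx, we); switch to LossScheduler

# Each LossScheduler.step() then computes, per LossSchedulerModel.forward:
#     x_next = wx[t] * xT + sum_i we[t][i] * e_prev[i]
# where xT is the initial latent and e_prev are all model outputs so far.
image = pipe(prompt).images[0]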
diff --git a/src/main.py b/src/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..95720d31c57fddcdf5318dff60cb0ea106c7037b
--- /dev/null
+++ b/src/main.py
@@ -0,0 +1,59 @@
+import atexit
+from io import BytesIO
+from multiprocessing.connection import Listener
+from os import chmod, remove
+from os.path import abspath, exists
+from pathlib import Path
+
+import torch
+
+from PIL.JpegImagePlugin import JpegImageFile
+from pipelines.models import TextToImageRequest
+
+from pipeline import load_pipeline, infer
+
+SOCKET = abspath(Path(__file__).parent.parent / "inferences.sock")
+
+
+def at_exit():
+    torch.cuda.empty_cache()
+
+
+def main():
+    atexit.register(at_exit)
+
+    print("Loading pipeline")
+    pipeline = load_pipeline()
+
+    print(f"Pipeline loaded, creating socket at '{SOCKET}'")
+
+    if exists(SOCKET):
+        remove(SOCKET)
+
+    with Listener(SOCKET) as listener:
+        chmod(SOCKET, 0o777)
+
+        print("Awaiting connections")
+        with listener.accept() as connection:
+            print("Connected")
+
+            while True:
+                try:
+                    request = TextToImageRequest.model_validate_json(connection.recv_bytes().decode("utf-8"))
+                except EOFError:
+                    print("Inference socket exiting")
+
+                    return
+
+                image = infer(request, pipeline)
+
+                data = BytesIO()
+                image.save(data, format=JpegImageFile.format)
+
+                packet = data.getvalue()
+
+                connection.send_bytes(packet)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/src/pipe/config.py b/src/pipe/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..3be8b352611cccbf36fa426e5fafb33533d41edf
--- /dev/null
+++ b/src/pipe/config.py
@@ -0,0 +1,162 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: MIT
+#
+# Permission is hereby granted, free of charge, to any person obtaining a
+# copy of this software and associated documentation files (the "Software"),
+# to deal in the Software without restriction, including without limitation
+# the rights to use, copy, modify, merge, publish, distribute, sublicense,
+# and/or sell copies of the Software, and to permit persons to whom the
+# Software is furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
+# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+# DEALINGS IN THE SOFTWARE.
+from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel
+from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
+
+sd3_common_transformer_block_config = {
+    "dummy_input": {
+        "hidden_states": (2, 4096, 1536),
+        "encoder_hidden_states": (2, 333, 1536),
+        "temb": (2, 1536),
+    },
+    "output_names": ["encoder_hidden_states_out", "hidden_states_out"],
+    "dynamic_axes": {
+        "hidden_states": {0: "batch_size"},
+        "encoder_hidden_states": {0: "batch_size"},
+        "temb": {0: "steps"},
+    },
+}
+
+ONNX_CONFIG = {
+    UNet2DConditionModel: {
+        "down_blocks.0": {
+            "dummy_input": {
+                "hidden_states": (2, 320, 128, 128),
+                "temb": (2, 1280),
+            },
+            "output_names": ["sample", "res_samples_0", "res_samples_1", "res_samples_2"],
+            "dynamic_axes": {
+                "hidden_states": {0: "batch_size"},
+                "temb": {0: "steps"},
+            },
+        },
+        "down_blocks.1": {
+            "dummy_input": {
+                "hidden_states": (2, 320, 64, 64),
+                "temb": (2, 1280),
+                "encoder_hidden_states": (2, 77, 2048),
+            },
+            "output_names": ["sample", "res_samples_0", "res_samples_1", "res_samples_2"],
+            "dynamic_axes": {
+                "hidden_states": {0: "batch_size"},
+                "temb": {0: "steps"},
+                "encoder_hidden_states": {0: "batch_size"},
+            },
+        },
+        "down_blocks.2": {
+            "dummy_input": {
+                "hidden_states": (2, 640, 32, 32),
+                "temb": (2, 1280),
+                "encoder_hidden_states": (2, 77, 2048),
+            },
+            "output_names": ["sample", "res_samples_0", "res_samples_1"],
+            "dynamic_axes": {
+                "hidden_states": {0: "batch_size"},
+                "temb": {0: "steps"},
+                "encoder_hidden_states": {0: "batch_size"},
+            },
+        },
+        "mid_block": {
+            "dummy_input": {
+                "hidden_states": (2, 1280, 32, 32),
+                "temb": (2, 1280),
+                "encoder_hidden_states": (2, 77, 2048),
+            },
+            "output_names": ["sample"],
+            "dynamic_axes": {
+                "hidden_states": {0: "batch_size"},
+                "temb": {0: "steps"},
+                "encoder_hidden_states": {0: "batch_size"},
+            },
+        },
+        "up_blocks.0": {
+            "dummy_input": {
+                "hidden_states": (2, 1280, 32, 32),
+                "res_hidden_states_0": (2, 640, 32, 32),
+                "res_hidden_states_1": (2, 1280, 32, 32),
+                "res_hidden_states_2": (2, 1280, 32, 32),
+                "temb": (2, 1280),
+                "encoder_hidden_states": (2, 77, 2048),
+            },
+            "output_names": ["sample"],
+            "dynamic_axes": {
+                "hidden_states": {0: "batch_size"},
+                "temb": {0: "steps"},
+                "encoder_hidden_states": {0: "batch_size"},
+                "res_hidden_states_0": {0: "batch_size"},
+                "res_hidden_states_1": {0: "batch_size"},
+                "res_hidden_states_2": {0: "batch_size"},
+            },
+        },
+        "up_blocks.1": {
+            "dummy_input": {
+                "hidden_states": (2, 1280, 64, 64),
+                "res_hidden_states_0": (2, 320, 64, 64),
+                "res_hidden_states_1": (2, 640, 64, 64),
+                "res_hidden_states_2": (2, 640, 64, 64),
+                "temb": (2, 1280),
+                "encoder_hidden_states": (2, 77, 2048),
+            },
+            "output_names": ["sample"],
+            "dynamic_axes": {
+                "hidden_states": {0: "batch_size"},
+                "temb": {0: "steps"},
+                "encoder_hidden_states": {0: "batch_size"},
+                "res_hidden_states_0": {0: "batch_size"},
+                "res_hidden_states_1": {0: "batch_size"},
+                "res_hidden_states_2": {0: "batch_size"},
+            },
+        },
+        "up_blocks.2": {
+            "dummy_input": {
+                "hidden_states": (2, 640, 128, 128),
+                "res_hidden_states_0": (2, 320, 128, 128),
+                "res_hidden_states_1": (2, 320, 128, 128),
+                "res_hidden_states_2": (2, 320, 128, 128),
+                "temb": (2, 1280),
+            },
+            "output_names": ["sample"],
+            "dynamic_axes": {
+                "hidden_states": {0: "batch_size"},
+                "temb": {0: "steps"},
+                "res_hidden_states_0": {0: "batch_size"},
+                "res_hidden_states_1": {0: "batch_size"},
+                "res_hidden_states_2": {0: "batch_size"},
+            },
+        },
+    },
+    SD3Transformer2DModel: {
+        **{f"transformer_blocks.{i}": sd3_common_transformer_block_config for i in range(23)},
+        "transformer_blocks.23": {
+            "dummy_input": {
+                "hidden_states": (2, 4096, 1536),
+                "encoder_hidden_states": (2, 333, 1536),
+                "temb": (2, 1536),
+            },
+            "output_names": ["hidden_states_out"],
+            "dynamic_axes": {
+                "hidden_states": {0: "batch_size"},
+                "encoder_hidden_states": {0: "batch_size"},
+                "temb": {0: "steps"},
+            },
+        },
+    },
+}
diff --git a/src/pipe/deploy.py b/src/pipe/deploy.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3d0b667dafa3e5a807beeb391efb6f7657dc616
--- /dev/null
+++ b/src/pipe/deploy.py
@@ -0,0 +1,210 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: MIT
+#
+# Permission is hereby granted, free of charge, to any person obtaining a
+# copy of this software and associated documentation files (the "Software"),
+# to deal in the Software without restriction, including without limitation
+# the rights to use, copy, modify, merge, publish, distribute, sublicense,
+# and/or sell copies of the Software, and to permit persons to whom the
+# Software is furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
+# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+# DEALINGS IN THE SOFTWARE.
+
+import types
+from pathlib import Path
+
+import tensorrt as trt
+import torch
+from cache_diffusion.cachify import CACHED_PIPE, get_model
+from cuda import cudart
+from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel
+from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
+from pipe.config import ONNX_CONFIG
+from pipe.models.sd3 import sd3_forward
+from pipe.models.sdxl import (
+    cachecrossattnupblock2d_forward,
+    cacheunet_forward,
+    cacheupblock2d_forward,
+)
+from polygraphy.backend.trt import (
+    CreateConfig,
+    Profile,
+    engine_from_network,
+    network_from_onnx_path,
+    save_engine,
+)
+from torch.onnx import export as onnx_export
+
+from .utils import Engine
+
+
+def replace_new_forward(backbone):
+    if backbone.__class__ == UNet2DConditionModel:
+        backbone.forward = types.MethodType(cacheunet_forward, backbone)
+        for upsample_block in backbone.up_blocks:
+            if (
+                hasattr(upsample_block, "has_cross_attention")
+                and upsample_block.has_cross_attention
+            ):
+                upsample_block.forward = types.MethodType(
+                    cachecrossattnupblock2d_forward, upsample_block
+                )
+            else:
+                upsample_block.forward = types.MethodType(cacheupblock2d_forward, upsample_block)
+    elif backbone.__class__ == SD3Transformer2DModel:
+        backbone.forward = types.MethodType(sd3_forward, backbone)
+
+
+def get_input_info(dummy_dict, info: str = None, batch_size: int = 1):
+    return_val = [] if info == "profile_shapes" or info == "input_names" else {}
+
+    def collect_leaf_keys(d):
+        for key, value in d.items():
+            if isinstance(value, dict):
+                collect_leaf_keys(value)
+            else:
+                value = (value[0] * batch_size,) + value[1:]
+                if info == "profile_shapes":
+                    return_val.append((key, value))  # type: ignore
+                elif info == "profile_shapes_dict":
+                    return_val[key] = value  # type: ignore
+                elif info == "dummy_input":
+                    return_val[key] = torch.ones(value).half().cuda()  # type: ignore
+                elif info == "input_names":
+                    return_val.append(key)  # type: ignore
+
+    collect_leaf_keys(dummy_dict)
+    return return_val
+
+
+def compile2trt(cls, onnx_path: Path, engine_path: Path, batch_size: int = 1):
+    subdirs = [f for f in onnx_path.iterdir() if f.is_dir()]
+    for subdir in subdirs:
+        if subdir.name not in ONNX_CONFIG[cls].keys():
+            continue
+        model_path = subdir / "model.onnx"
+        plan_path = engine_path / f"{subdir.name}.plan"
+        if not plan_path.exists():
+            print(f"Building {str(model_path)}")
+            build_profile = Profile()
+            profile_shapes = get_input_info(
+                ONNX_CONFIG[cls][subdir.name]["dummy_input"], "profile_shapes", batch_size
+            )
+            for input_name, input_shape in profile_shapes:
+                min_input_shape = (2,) + input_shape[1:]
+                build_profile.add(input_name, min_input_shape, input_shape, input_shape)
+            block_network = network_from_onnx_path(
+                str(model_path), flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM], strongly_typed=True
+            )
+            build_config = CreateConfig(
+                builder_optimization_level=6,
+                tf32=True,
+                #bf16=True,
+                profiles=[build_profile],
+            )
+            engine = engine_from_network(
+                block_network,
+                config=build_config,
+            )
+            save_engine(engine, path=plan_path)
+        else:
+            print(f"{str(model_path)} already exists!")
+
+
+def get_total_device_memory(backbone):
+    max_device_memory = 0
+    for _, engine in backbone.engines.items():
+        max_device_memory = max(max_device_memory, engine.engine.device_memory_size)
+    return max_device_memory
+
+
+def load_engines(backbone, engine_path: Path, batch_size: int = 1):
+    backbone.engines = {}
+    for f in engine_path.iterdir():
+        if f.is_file():
+            eng = Engine()
+            eng.load(str(f))
+            backbone.engines[f"{f.stem}"] = eng
+    _, shared_device_memory = cudart.cudaMalloc(get_total_device_memory(backbone))
+    for engine in backbone.engines.values():
+        engine.activate(shared_device_memory)
+    backbone.cuda_stream = cudart.cudaStreamCreate()[1]
+    for block_name in backbone.engines.keys():
+        backbone.engines[block_name].allocate_buffers(
+            shape_dict=get_input_info(
+                ONNX_CONFIG[backbone.__class__][block_name]["dummy_input"],
+                "profile_shapes_dict",
+                batch_size,
+            ),
+            device=backbone.device,
+            batch_size=batch_size,
+        )
+    # TODO: Free and clean up the original PyTorch CUDA memory
+
+
+def export_onnx(backbone, onnx_path: Path):
+    for name, module in backbone.named_modules():
+        if isinstance(module, CACHED_PIPE[backbone.__class__]):
+            _onnx_dir = onnx_path.joinpath(f"{name}")
+            _onnx_file = _onnx_dir.joinpath("model.onnx")
+            if not _onnx_file.exists():
+                _onnx_dir.mkdir(parents=True, exist_ok=True)
+                dummy_input = get_input_info(
+                    ONNX_CONFIG[backbone.__class__][f"{name}"]["dummy_input"], "dummy_input"
+                )
+                input_names = get_input_info(
+                    ONNX_CONFIG[backbone.__class__][f"{name}"]["dummy_input"], "input_names"
+                )
+                output_names = ONNX_CONFIG[backbone.__class__][f"{name}"]["output_names"]
+                onnx_export(
+                    module,
+                    args=dummy_input,
+                    f=_onnx_file.as_posix(),
+                    input_names=input_names,
+                    output_names=output_names,
+                    dynamic_axes=ONNX_CONFIG[backbone.__class__][f"{name}"]["dynamic_axes"],
+                    do_constant_folding=True,
+                    opset_version=17,
+                )
+            else:
+                print(f"{str(_onnx_file)} already exists!")
+
+
+def warm_up(backbone, batch_size: int = 1):
+    print("Warming-up TensorRT engines...")
+    for name, engine in backbone.engines.items():
+        dummy_input = get_input_info(
+            ONNX_CONFIG[backbone.__class__][name]["dummy_input"], "dummy_input", batch_size
+        )
+        _ = engine(dummy_input, backbone.cuda_stream)
+
+
+def teardown(pipe):
+    backbone = get_model(pipe)
+    for engine in backbone.engines.values():
+        del engine
+
+    cudart.cudaStreamDestroy(backbone.cuda_stream)
+    del backbone.cuda_stream
+
+
+def compile(pipe, onnx_path: Path, engine_path: Path, batch_size: int = 1):
+    backbone = get_model(pipe)
+    onnx_path.mkdir(parents=True, exist_ok=True)
+    engine_path.mkdir(parents=True, exist_ok=True)
+
+    replace_new_forward(backbone)
+    export_onnx(backbone, onnx_path)
+    compile2trt(backbone.__class__, onnx_path, engine_path, batch_size)
+    load_engines(backbone, engine_path, batch_size)
+    warm_up(backbone, batch_size)
+    backbone.use_trt_infer = True
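An assumed end-to-end flow for the TensorRT path (a sketch: paths are illustrative, and the `compile`-before-`prepare` ordering is inferred from the `use_trt_infer` branch in `cachify._apply_to_modules`, which wraps loaded engines rather than PyTorch modules):

# Sketch: build and use the per-block TensorRT engines (paths illustrative).
from pathlib import Path

from cache_diffusion import cachify
from cache_diffusion.utils import SDXL_DEFAULT_CONFIG
from pipe.deploy import compile, teardown

compile(pipe, Path("./onnx"), Path("./engine"), batch_size=1)  # export, build, load, warm up
cachify.prepare(pipe, SDXL_DEFAULT_CONFIG)  # wrap the loaded engines in CachedModule
# ... run inference as usual; cached blocks skip their engine on cache steps ...
teardown(pipe)  # release engines and destroy the CUDA stream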
+# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +from typing import Any, Dict, List, Optional, Union + +import torch +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.utils import ( + USE_PEFT_BACKEND, + is_torch_version, + scale_lora_layers, + unscale_lora_layers, +) + + +def sd3_forward( + self, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + pooled_projections: torch.FloatTensor = None, + timestep: torch.LongTensor = None, + block_controlnet_hidden_states: List = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, +) -> Union[torch.FloatTensor, Transformer2DModelOutput]: + """ + The [`SD3Transformer2DModel`] forward method. + + Args: + hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): + Input `hidden_states`. + encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. + pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected + from the embeddings of input conditions. + timestep ( `torch.LongTensor`): + Used to indicate denoising step. + block_controlnet_hidden_states: (`list` of `torch.Tensor`): + A list of tensors that if specified are added to the residuals of transformer blocks. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + if joint_attention_kwargs is not None: + joint_attention_kwargs = joint_attention_kwargs.copy() + lora_scale = joint_attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + + height, width = hidden_states.shape[-2:] + + hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too. 
+    temb = self.time_text_embed(timestep, pooled_projections)
+    encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
+    for index_block, block in enumerate(self.transformer_blocks):
+        if self.training and self.gradient_checkpointing:
+
+            def create_custom_forward(module, return_dict=None):
+                def custom_forward(*inputs):
+                    if return_dict is not None:
+                        return module(*inputs, return_dict=return_dict)
+                    else:
+                        return module(*inputs)
+
+                return custom_forward
+
+            ckpt_kwargs: Dict[str, Any] = (
+                {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+            )
+            encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
+                create_custom_forward(block),
+                hidden_states,
+                encoder_hidden_states,
+                temb,
+                **ckpt_kwargs,
+            )
+
+        else:
+            if hasattr(self, "use_trt_infer") and self.use_trt_infer:
+                feed_dict = {
+                    "hidden_states": hidden_states,
+                    "encoder_hidden_states": encoder_hidden_states,
+                    "temb": temb,
+                }
+                _results = self.engines[f"transformer_blocks.{index_block}"](
+                    feed_dict, self.cuda_stream
+                )
+                # The final block (context_pre_only) emits no
+                # encoder_hidden_states output, so skip reading it there.
+                if index_block != len(self.transformer_blocks) - 1:
+                    encoder_hidden_states = _results["encoder_hidden_states_out"]
+                hidden_states = _results["hidden_states_out"]
+            else:
+                encoder_hidden_states, hidden_states = block(
+                    hidden_states=hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    temb=temb,
+                )
+
+        # controlnet residual
+        if block_controlnet_hidden_states is not None and block.context_pre_only is False:
+            interval_control = len(self.transformer_blocks) // len(block_controlnet_hidden_states)
+            hidden_states = (
+                hidden_states + block_controlnet_hidden_states[index_block // interval_control]
+            )
+
+    hidden_states = self.norm_out(hidden_states, temb)
+    hidden_states = self.proj_out(hidden_states)
+
+    # unpatchify
+    patch_size = self.config.patch_size
+    height = height // patch_size
+    width = width // patch_size
+
+    hidden_states = hidden_states.reshape(
+        shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels)
+    )
+    hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
+    output = hidden_states.reshape(
+        shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
+    )
+
+    if USE_PEFT_BACKEND:
+        # remove `lora_scale` from each PEFT layer
+        unscale_lora_layers(self, lora_scale)
+
+    if not return_dict:
+        return (output,)
+
+    return Transformer2DModelOutput(sample=output)
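# ---------------------------------------------------------------------------
# Minimal standalone sketch of the unpatchify step above. The sizes are made
# up for illustration; the einsum string is the one used in sd3_forward.
# ---------------------------------------------------------------------------
import torch

B, H, W, p, C = 1, 2, 3, 2, 4                # tiny illustrative sizes
tokens = torch.randn(B, H * W, p * p * C)    # (B, num_patches, p*p*C)
x = tokens.reshape(B, H, W, p, p, C)         # split patch grid and patch pixels
x = torch.einsum("nhwpqc->nchpwq", x)        # (B, C, H, p, W, q)
image = x.reshape(B, C, H * p, W * p)        # stitch patches back into an image
assert image.shape == (1, 4, 4, 6)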
diff --git a/src/pipe/models/sdxl.py b/src/pipe/models/sdxl.py
new file mode 100644
index 0000000000000000000000000000000000000000..92aaece6ab8a3313be41825c48d4c2109f3d1df7
--- /dev/null
+++ b/src/pipe/models/sdxl.py
@@ -0,0 +1,275 @@
+# Adapted from
+# https://github.com/huggingface/diffusers/blob/73acebb8cfbd1d2954cabe1af4185f9994e61917/src/diffusers/models/unets/unet_2d_condition.py#L1039-L1312
+# https://github.com/huggingface/diffusers/blob/73acebb8cfbd1d2954cabe1af4185f9994e61917/src/diffusers/models/unets/unet_2d_blocks.py#L2482-L2564
+# https://github.com/huggingface/diffusers/blob/73acebb8cfbd1d2954cabe1af4185f9994e61917/src/diffusers/models/unets/unet_2d_blocks.py#L2617-L2679
+
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Not a contribution
+# Changes made by NVIDIA CORPORATION & AFFILIATES or otherwise documented as
+# NVIDIA-proprietary are not a contribution and subject to the following terms and conditions:
+# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+from typing import Any, Dict, Optional, Tuple, Union
+
+import torch
+from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
+
+
+def cachecrossattnupblock2d_forward(
+    self,
+    hidden_states: torch.FloatTensor,
+    res_hidden_states_0: torch.FloatTensor,
+    res_hidden_states_1: torch.FloatTensor,
+    res_hidden_states_2: torch.FloatTensor,
+    temb: Optional[torch.FloatTensor] = None,
+    encoder_hidden_states: Optional[torch.FloatTensor] = None,
+    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+    upsample_size: Optional[int] = None,
+    attention_mask: Optional[torch.FloatTensor] = None,
+    encoder_attention_mask: Optional[torch.FloatTensor] = None,
+) -> torch.FloatTensor:
+    res_hidden_states_tuple = (res_hidden_states_0, res_hidden_states_1, res_hidden_states_2)
+    for resnet, attn in zip(self.resnets, self.attentions):
+        # pop res hidden states
+        res_hidden_states = res_hidden_states_tuple[-1]
+        res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+
+        hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+        hidden_states = resnet(hidden_states, temb)
+        hidden_states = attn(
+            hidden_states,
+            encoder_hidden_states=encoder_hidden_states,
+            cross_attention_kwargs=cross_attention_kwargs,
+            attention_mask=attention_mask,
+            encoder_attention_mask=encoder_attention_mask,
+            return_dict=False,
+        )[0]
+
+    if self.upsamplers is not None:
+        for upsampler in self.upsamplers:
+            hidden_states = upsampler(hidden_states, upsample_size)
+
+    return hidden_states
+
+
+def cacheupblock2d_forward(
+    self,
+    hidden_states: torch.FloatTensor,
+    res_hidden_states_0: torch.FloatTensor,
+    res_hidden_states_1: torch.FloatTensor,
+    res_hidden_states_2: torch.FloatTensor,
+    temb: Optional[torch.FloatTensor] = None,
+    upsample_size: Optional[int] = None,
+) -> torch.FloatTensor:
+    res_hidden_states_tuple = (res_hidden_states_0, res_hidden_states_1, res_hidden_states_2)
+    for resnet in self.resnets:
+        # pop res hidden states
+        res_hidden_states = res_hidden_states_tuple[-1]
+        res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+
+        hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+        hidden_states = resnet(hidden_states, temb)
+
+    if self.upsamplers is not None:
+        for upsampler in self.upsamplers:
+            hidden_states = upsampler(hidden_states, upsample_size)
+
+    return hidden_states
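# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patch): the two forwards above unroll
# diffusers' variable-length res_hidden_states_tuple into three named tensors,
# presumably so each exported ONNX graph / TensorRT engine binds a fixed set
# of named inputs (see the feed_dicts in cacheunet_forward below). The loop
# then consumes the residuals back to front, like popping a stack:
# ---------------------------------------------------------------------------
_stack = ("res0", "res1", "res2")
_popped = []
while _stack:
    _popped.append(_stack[-1])  # take the most recently pushed residual
    _stack = _stack[:-1]
assert _popped == ["res2", "res1", "res0"]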
+
+
+def cacheunet_forward(
+    self,
+    sample: torch.FloatTensor,
+    timestep: Union[torch.Tensor, float, int],
+    encoder_hidden_states: torch.Tensor,
+    class_labels: Optional[torch.Tensor] = None,
+    timestep_cond: Optional[torch.Tensor] = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+    added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+    down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+    mid_block_additional_residual: Optional[torch.Tensor] = None,
+    down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+    encoder_attention_mask: Optional[torch.Tensor] = None,
+    return_dict: bool = True,
+) -> Union[UNet2DConditionOutput, Tuple]:
+    # 1. time
+    t_emb = self.get_time_embed(sample=sample, timestep=timestep)
+    emb = self.time_embedding(t_emb, timestep_cond)
+
+    aug_emb = self.get_aug_embed(
+        emb=emb,
+        encoder_hidden_states=encoder_hidden_states,
+        added_cond_kwargs=added_cond_kwargs,
+    )
+    emb = emb + aug_emb if aug_emb is not None else emb
+
+    encoder_hidden_states = self.process_encoder_hidden_states(
+        encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
+    )
+
+    # 2. pre-process
+    sample = self.conv_in(sample)
+
+    if hasattr(self, "_export_precess_onnx") and self._export_precess_onnx:
+        return (
+            sample,
+            encoder_hidden_states,
+            emb,
+        )
+
+    down_block_res_samples = (sample,)
+    for i, downsample_block in enumerate(self.down_blocks):
+        if (
+            hasattr(downsample_block, "has_cross_attention")
+            and downsample_block.has_cross_attention
+        ):
+            if hasattr(self, "use_trt_infer") and self.use_trt_infer:
+                feed_dict = {
+                    "hidden_states": sample,
+                    "temb": emb,
+                    "encoder_hidden_states": encoder_hidden_states,
+                }
+                down_results = self.engines[f"down_blocks.{i}"](feed_dict, self.cuda_stream)
+                sample = down_results["sample"]
+                res_samples_0 = down_results["res_samples_0"]
+                res_samples_1 = down_results["res_samples_1"]
+                if "res_samples_2" in down_results:
+                    res_samples_2 = down_results["res_samples_2"]
+            else:
+                # For t2i-adapter CrossAttnDownBlock2D
+                additional_residuals = {}
+
+                sample, res_samples = downsample_block(
+                    hidden_states=sample,
+                    temb=emb,
+                    encoder_hidden_states=encoder_hidden_states,
+                    attention_mask=attention_mask,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    encoder_attention_mask=encoder_attention_mask,
+                    **additional_residuals,
+                )
+        else:
+            if hasattr(self, "use_trt_infer") and self.use_trt_infer:
+                feed_dict = {"hidden_states": sample, "temb": emb}
+                down_results = self.engines[f"down_blocks.{i}"](feed_dict, self.cuda_stream)
+                sample = down_results["sample"]
+                res_samples_0 = down_results["res_samples_0"]
+                res_samples_1 = down_results["res_samples_1"]
+                if "res_samples_2" in down_results:
+                    res_samples_2 = down_results["res_samples_2"]
+            else:
+                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+
+        if hasattr(self, "use_trt_infer") and self.use_trt_infer:
+            down_block_res_samples += (
+                res_samples_0,
+                res_samples_1,
+            )
+            if "res_samples_2" in down_results:
+                down_block_res_samples += (res_samples_2,)
+        else:
+            down_block_res_samples += res_samples
+
+    if hasattr(self, "use_trt_infer") and self.use_trt_infer:
+        feed_dict = {
+            "hidden_states": sample,
+            "temb": emb,
+            "encoder_hidden_states": encoder_hidden_states,
+        }
+        mid_results = self.engines["mid_block"](feed_dict, self.cuda_stream)
+        sample = mid_results["sample"]
+    else:
+        sample = self.mid_block(
+            sample,
+            emb,
+            encoder_hidden_states=encoder_hidden_states,
+            attention_mask=attention_mask,
+            cross_attention_kwargs=cross_attention_kwargs,
+            encoder_attention_mask=encoder_attention_mask,
+        )
+
+    # 5. up
+    for i, upsample_block in enumerate(self.up_blocks):
+        res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+        down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+        if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
+            if hasattr(self, "use_trt_infer") and self.use_trt_infer:
+                feed_dict = {
+                    "hidden_states": sample,
+                    "res_hidden_states_0": res_samples[0],
+                    "res_hidden_states_1": res_samples[1],
+                    "res_hidden_states_2": res_samples[2],
+                    "temb": emb,
+                    "encoder_hidden_states": encoder_hidden_states,
+                }
+                up_results = self.engines[f"up_blocks.{i}"](feed_dict, self.cuda_stream)
+                sample = up_results["sample"]
+            else:
+                sample = upsample_block(
+                    hidden_states=sample,
+                    temb=emb,
+                    res_hidden_states_0=res_samples[0],
+                    res_hidden_states_1=res_samples[1],
+                    res_hidden_states_2=res_samples[2],
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                )
+        else:
+            if hasattr(self, "use_trt_infer") and self.use_trt_infer:
+                feed_dict = {
+                    "hidden_states": sample,
+                    "res_hidden_states_0": res_samples[0],
+                    "res_hidden_states_1": res_samples[1],
+                    "res_hidden_states_2": res_samples[2],
+                    "temb": emb,
+                }
+                up_results = self.engines[f"up_blocks.{i}"](feed_dict, self.cuda_stream)
+                sample = up_results["sample"]
+            else:
+                sample = upsample_block(
+                    hidden_states=sample,
+                    temb=emb,
+                    res_hidden_states_0=res_samples[0],
+                    res_hidden_states_1=res_samples[1],
+                    res_hidden_states_2=res_samples[2],
+                )
+
+    # 6. post-process
+    if self.conv_norm_out:
+        sample = self.conv_norm_out(sample)
+        sample = self.conv_act(sample)
+    sample = self.conv_out(sample)
+
+    if not return_dict:
+        return (sample,)
+
+    return UNet2DConditionOutput(sample=sample)
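# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patch): the SDXL UNet receives its
# extra conditioning through added_cond_kwargs, consumed by get_aug_embed
# above. The shapes below are the usual SDXL-base ones for a batch of 2; they
# are assumptions for the example, not read from this repo.
# ---------------------------------------------------------------------------
import torch

added_cond_kwargs = {
    "text_embeds": torch.randn(2, 1280),  # pooled text embedding
    "time_ids": torch.randn(2, 6),        # original size + crop coords + target size
}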
diff --git a/src/pipe/utils.py b/src/pipe/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..340f2ea58f10e9bf11fa214ce8a55b589b5c0b70
--- /dev/null
+++ b/src/pipe/utils.py
@@ -0,0 +1,129 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: MIT
+#
+# Permission is hereby granted, free of charge, to any person obtaining a
+# copy of this software and associated documentation files (the "Software"),
+# to deal in the Software without restriction, including without limitation
+# the rights to use, copy, modify, merge, publish, distribute, sublicense,
+# and/or sell copies of the Software, and to permit persons to whom the
+# Software is furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
+# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+# DEALINGS IN THE SOFTWARE.
+
+from collections import OrderedDict
+
+import numpy as np
+import tensorrt as trt
+import torch
+from cuda import cudart
+from polygraphy.backend.common import bytes_from_path
+from polygraphy.backend.trt import engine_from_bytes
+
+numpy_to_torch_dtype_dict = {
+    np.uint8: torch.uint8,
+    np.int8: torch.int8,
+    np.int16: torch.int16,
+    np.int32: torch.int32,
+    np.int64: torch.int64,
+    np.float16: torch.float16,
+    np.float32: torch.float32,
+    np.float64: torch.float64,
+    np.complex64: torch.complex64,
+    np.complex128: torch.complex128,
+}
+
+
+class Engine:
+    def __init__(self):
+        self.engine = None
+        self.context = None
+        self.buffers = OrderedDict()
+        self.tensors = OrderedDict()
+        self.cuda_graph_instance = None  # cuda graph
+        self.has_cross_attention = False
+
+    def __del__(self):
+        del self.engine
+        del self.context
+        del self.buffers
+        del self.tensors
+
+    def load(self, engine_path):
+        self.engine = engine_from_bytes(bytes_from_path(engine_path))
+
+    def activate(self, reuse_device_memory=None):
+        if reuse_device_memory:
+            self.context = self.engine.create_execution_context_without_device_memory()  # type: ignore
+            self.context.device_memory = reuse_device_memory
+        else:
+            self.context = self.engine.create_execution_context()  # type: ignore
+
+    def allocate_buffers(self, shape_dict=None, device="cuda", batch_size=1):
+        for binding in range(self.engine.num_io_tensors):  # type: ignore
+            name = self.engine.get_tensor_name(binding)  # type: ignore
+            if shape_dict and name in shape_dict:
+                shape = shape_dict[name]
+            else:
+                shape = self.engine.get_tensor_shape(name)  # type: ignore
+                # Double the batch dimension to account for classifier-free
+                # guidance, which runs conditional and unconditional passes.
+                shape = (batch_size * 2,) + shape[1:]
+            dtype = trt.nptype(self.engine.get_tensor_dtype(name))  # type: ignore
+            if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:  # type: ignore
+                self.context.set_input_shape(name, shape)  # type: ignore
+            tensor = torch.empty(tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype]).to(
+                device=device
+            )
+            self.tensors[name] = tensor
+
+    def __call__(self, feed_dict, stream, use_cuda_graph=False):
+        for name, buf in feed_dict.items():
+            self.tensors[name].copy_(buf)
+
+        for name, tensor in self.tensors.items():
+            self.context.set_tensor_address(name, tensor.data_ptr())  # type: ignore
+
+        if use_cuda_graph:
+            if self.cuda_graph_instance is not None:
+                cuassert(cudart.cudaGraphLaunch(self.cuda_graph_instance, stream))
+                cuassert(cudart.cudaStreamSynchronize(stream))
+            else:
+                # do inference before CUDA graph capture
+                noerror = self.context.execute_async_v3(stream)  # type: ignore
+                if not noerror:
+                    raise ValueError("ERROR: inference failed.")
+                # capture cuda graph
+                cuassert(
+                    cudart.cudaStreamBeginCapture(
+                        stream, cudart.cudaStreamCaptureMode.cudaStreamCaptureModeGlobal
+                    )
+                )
+                self.context.execute_async_v3(stream)  # type: ignore
+                self.graph = cuassert(cudart.cudaStreamEndCapture(stream))
+                self.cuda_graph_instance = cuassert(cudart.cudaGraphInstantiate(self.graph, 0))
+        else:
+            noerror = self.context.execute_async_v3(stream)  # type: ignore
+            if not noerror:
+                raise ValueError("ERROR: inference failed.")
+
+        return self.tensors
+
+
+def cuassert(cuda_ret):
+    err = cuda_ret[0]
+    if err != cudart.cudaError_t.cudaSuccess:
+        raise RuntimeError(
+            f"CUDA ERROR: {err}, error code reference: https://nvidia.github.io/cuda-python/module/cudart.html#cuda.cudart.cudaError_t"
+        )
+    if len(cuda_ret) > 1:
+        return cuda_ret[1]
+    return None
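# ---------------------------------------------------------------------------
# Usage sketch (illustrative): the lifecycle of the Engine wrapper above. The
# engine path, binding names, shapes, and dtype are assumptions for the
# example; a real caller would take them from the built engine.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    engine = Engine()
    engine.load("engines/mid_block.plan")  # hypothetical engine file
    engine.activate()                      # context owns its device memory here
    engine.allocate_buffers(
        shape_dict={"hidden_states": (2, 1280, 32, 32), "temb": (2, 1280)},
        device="cuda",
    )
    stream = cuassert(cudart.cudaStreamCreate())
    feed = {
        "hidden_states": torch.randn(2, 1280, 32, 32, device="cuda", dtype=torch.float16),
        "temb": torch.randn(2, 1280, device="cuda", dtype=torch.float16),
    }
    outputs = engine(feed, stream)  # returns the dict of preallocated tensors
    cuassert(cudart.cudaStreamSynchronize(stream))
    cuassert(cudart.cudaStreamDestroy(stream))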