Bradarr lnyan commited on
Commit
6988dbd
·
0 Parent(s):

Duplicate from lnyan/stablediffusion-infinity

Browse files

Co-authored-by: Lnyan <lnyan@users.noreply.huggingface.co>

.gitattributes ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ftz filter=lfs diff=lfs merge=lfs -text
6
+ *.gz filter=lfs diff=lfs merge=lfs -text
7
+ *.h5 filter=lfs diff=lfs merge=lfs -text
8
+ *.joblib filter=lfs diff=lfs merge=lfs -text
9
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
10
+ *.model filter=lfs diff=lfs merge=lfs -text
11
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
12
+ *.npy filter=lfs diff=lfs merge=lfs -text
13
+ *.npz filter=lfs diff=lfs merge=lfs -text
14
+ *.onnx filter=lfs diff=lfs merge=lfs -text
15
+ *.ot filter=lfs diff=lfs merge=lfs -text
16
+ *.parquet filter=lfs diff=lfs merge=lfs -text
17
+ *.pickle filter=lfs diff=lfs merge=lfs -text
18
+ *.pkl filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pt filter=lfs diff=lfs merge=lfs -text
21
+ *.pth filter=lfs diff=lfs merge=lfs -text
22
+ *.rar filter=lfs diff=lfs merge=lfs -text
23
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
24
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
25
+ *.tflite filter=lfs diff=lfs merge=lfs -text
26
+ *.tgz filter=lfs diff=lfs merge=lfs -text
27
+ *.wasm filter=lfs diff=lfs merge=lfs -text
28
+ *.xz filter=lfs diff=lfs merge=lfs -text
29
+ *.zip filter=lfs diff=lfs merge=lfs -text
30
+ *.zst filter=lfs diff=lfs merge=lfs -text
31
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
PyPatchMatch/.gitignore ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ /build/
2
+ /*.so
3
+ __pycache__
4
+ *.py[cod]
PyPatchMatch/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2020 Jiayuan Mao
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
PyPatchMatch/Makefile ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #
2
+ # Makefile
3
+ # Jiayuan Mao, 2019-01-09 13:59
4
+ #
5
+
6
+ SRC_DIR = csrc
7
+ INC_DIR = csrc
8
+ OBJ_DIR = build/obj
9
+ TARGET = libpatchmatch.so
10
+
11
+ LIB_TARGET = $(TARGET)
12
+ INCLUDE_DIR = -I $(SRC_DIR) -I $(INC_DIR)
13
+
14
+ CXX = $(ENVIRONMENT_OPTIONS) g++
15
+ CXXFLAGS = -std=c++14
16
+ CXXFLAGS += -Ofast -ffast-math -w
17
+ # CXXFLAGS += -g
18
+ CXXFLAGS += $(shell pkg-config --cflags opencv) -fPIC
19
+ CXXFLAGS += $(INCLUDE_DIR)
20
+ LDFLAGS = $(shell pkg-config --cflags --libs opencv) -shared -fPIC
21
+
22
+
23
+ CXXSOURCES = $(shell find $(SRC_DIR)/ -name "*.cpp")
24
+ OBJS = $(addprefix $(OBJ_DIR)/,$(CXXSOURCES:.cpp=.o))
25
+ DEPFILES = $(OBJS:.o=.d)
26
+
27
+ .PHONY: all clean rebuild test
28
+
29
+ all: $(LIB_TARGET)
30
+
31
+ $(OBJ_DIR)/%.o: %.cpp
32
+ @echo "[CC] $< ..."
33
+ @$(CXX) -c $< $(CXXFLAGS) -o $@
34
+
35
+ $(OBJ_DIR)/%.d: %.cpp
36
+ @mkdir -pv $(dir $@)
37
+ @echo "[dep] $< ..."
38
+ @$(CXX) $(INCLUDE_DIR) $(CXXFLAGS) -MM -MT "$(OBJ_DIR)/$(<:.cpp=.o) $(OBJ_DIR)/$(<:.cpp=.d)" "$<" > "$@"
39
+
40
+ sinclude $(DEPFILES)
41
+
42
+ $(LIB_TARGET): $(OBJS)
43
+ @echo "[link] $(LIB_TARGET) ..."
44
+ @$(CXX) $(OBJS) -o $@ $(CXXFLAGS) $(LDFLAGS)
45
+
46
+ clean:
47
+ rm -rf $(OBJ_DIR) $(LIB_TARGET)
48
+
49
+ rebuild:
50
+ +@make clean
51
+ +@make
52
+
53
+ # vim:ft=make
54
+ #
PyPatchMatch/README.md ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ PatchMatch based Inpainting
2
+ =====================================
3
+ This library implements the PatchMatch based inpainting algorithm. It provides both C++ and Python interfaces.
4
+ This implementation is heavily based on the implementation by Younesse ANDAM:
5
+ (younesse-cv/PatchMatch)[https://github.com/younesse-cv/PatchMatch], with some bugs fix.
6
+
7
+ Usage
8
+ -------------------------------------
9
+
10
+ You need to first install OpenCV to compile the C++ libraries. Then, run `make` to compile the
11
+ shared library `libpatchmatch.so`.
12
+
13
+ For Python users (example available at `examples/py_example.py`)
14
+
15
+ ```python
16
+ import patch_match
17
+
18
+ image = ... # either a numpy ndarray or a PIL Image object.
19
+ mask = ... # either a numpy ndarray or a PIL Image object.
20
+ result = patch_match.inpaint(image, mask, patch_size=5)
21
+ ```
22
+
23
+ For C++ users (examples available at `examples/cpp_example.cpp`)
24
+
25
+ ```cpp
26
+ #include "inpaint.h"
27
+
28
+ int main() {
29
+ cv::Mat image = ...
30
+ cv::Mat mask = ...
31
+
32
+ cv::Mat result = Inpainting(image, mask, 5).run();
33
+
34
+ return 0;
35
+ }
36
+ ```
37
+
38
+
39
+ README and COPYRIGHT by Younesse ANDAM
40
+ -------------------------------------
41
+ @Author: Younesse ANDAM
42
+
43
+ @Contact: younesse.andam@gmail.com
44
+
45
+ Description: This project is a personal implementation of an algorithm called PATCHMATCH that restores missing areas in an image.
46
+ The algorithm is presented in the following paper
47
+ PatchMatch A Randomized Correspondence Algorithm
48
+ for Structural Image Editing
49
+ by C.Barnes,E.Shechtman,A.Finkelstein and Dan B.Goldman
50
+ ACM Transactions on Graphics (Proc. SIGGRAPH), vol.28, aug-2009
51
+
52
+ For more information please refer to
53
+ http://www.cs.princeton.edu/gfx/pubs/Barnes_2009_PAR/index.php
54
+
55
+ Copyright (c) 2010-2011
56
+
57
+
58
+ Requirements
59
+ -------------------------------------
60
+
61
+ To run the project you need to install Opencv library and link it to your project.
62
+ Opencv can be download it here
63
+ http://opencv.org/downloads.html
64
+
PyPatchMatch/csrc/inpaint.cpp ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <algorithm>
2
+ #include <iostream>
3
+ #include <opencv2/imgcodecs.hpp>
4
+ #include <opencv2/imgproc.hpp>
5
+ #include <opencv2/highgui.hpp>
6
+
7
+ #include "inpaint.h"
8
+
9
+ namespace {
10
+ static std::vector<double> kDistance2Similarity;
11
+
12
+ void init_kDistance2Similarity() {
13
+ double base[11] = {1.0, 0.99, 0.96, 0.83, 0.38, 0.11, 0.02, 0.005, 0.0006, 0.0001, 0};
14
+ int length = (PatchDistanceMetric::kDistanceScale + 1);
15
+ kDistance2Similarity.resize(length);
16
+ for (int i = 0; i < length; ++i) {
17
+ double t = (double) i / length;
18
+ int j = (int) (100 * t);
19
+ int k = j + 1;
20
+ double vj = (j < 11) ? base[j] : 0;
21
+ double vk = (k < 11) ? base[k] : 0;
22
+ kDistance2Similarity[i] = vj + (100 * t - j) * (vk - vj);
23
+ }
24
+ }
25
+
26
+
27
+ inline void _weighted_copy(const MaskedImage &source, int ys, int xs, cv::Mat &target, int yt, int xt, double weight) {
28
+ if (source.is_masked(ys, xs)) return;
29
+ if (source.is_globally_masked(ys, xs)) return;
30
+
31
+ auto source_ptr = source.get_image(ys, xs);
32
+ auto target_ptr = target.ptr<double>(yt, xt);
33
+
34
+ #pragma unroll
35
+ for (int c = 0; c < 3; ++c)
36
+ target_ptr[c] += static_cast<double>(source_ptr[c]) * weight;
37
+ target_ptr[3] += weight;
38
+ }
39
+ }
40
+
41
+ /**
42
+ * This algorithme uses a version proposed by Xavier Philippeau.
43
+ */
44
+
45
+ Inpainting::Inpainting(cv::Mat image, cv::Mat mask, const PatchDistanceMetric *metric)
46
+ : m_initial(image, mask), m_distance_metric(metric), m_pyramid(), m_source2target(), m_target2source() {
47
+ _initialize_pyramid();
48
+ }
49
+
50
+ Inpainting::Inpainting(cv::Mat image, cv::Mat mask, cv::Mat global_mask, const PatchDistanceMetric *metric)
51
+ : m_initial(image, mask, global_mask), m_distance_metric(metric), m_pyramid(), m_source2target(), m_target2source() {
52
+ _initialize_pyramid();
53
+ }
54
+
55
+ void Inpainting::_initialize_pyramid() {
56
+ auto source = m_initial;
57
+ m_pyramid.push_back(source);
58
+ while (source.size().height > m_distance_metric->patch_size() && source.size().width > m_distance_metric->patch_size()) {
59
+ source = source.downsample();
60
+ m_pyramid.push_back(source);
61
+ }
62
+
63
+ if (kDistance2Similarity.size() == 0) {
64
+ init_kDistance2Similarity();
65
+ }
66
+ }
67
+
68
+ cv::Mat Inpainting::run(bool verbose, bool verbose_visualize, unsigned int random_seed) {
69
+ srand(random_seed);
70
+ const int nr_levels = m_pyramid.size();
71
+
72
+ MaskedImage source, target;
73
+ for (int level = nr_levels - 1; level >= 0; --level) {
74
+ if (verbose) std::cerr << "Inpainting level: " << level << std::endl;
75
+
76
+ source = m_pyramid[level];
77
+
78
+ if (level == nr_levels - 1) {
79
+ target = source.clone();
80
+ target.clear_mask();
81
+ m_source2target = NearestNeighborField(source, target, m_distance_metric);
82
+ m_target2source = NearestNeighborField(target, source, m_distance_metric);
83
+ } else {
84
+ m_source2target = NearestNeighborField(source, target, m_distance_metric, m_source2target);
85
+ m_target2source = NearestNeighborField(target, source, m_distance_metric, m_target2source);
86
+ }
87
+
88
+ if (verbose) std::cerr << "Initialization done." << std::endl;
89
+
90
+ if (verbose_visualize) {
91
+ auto visualize_size = m_initial.size();
92
+ cv::Mat source_visualize(visualize_size, m_initial.image().type());
93
+ cv::resize(source.image(), source_visualize, visualize_size);
94
+ cv::imshow("Source", source_visualize);
95
+ cv::Mat target_visualize(visualize_size, m_initial.image().type());
96
+ cv::resize(target.image(), target_visualize, visualize_size);
97
+ cv::imshow("Target", target_visualize);
98
+ cv::waitKey(0);
99
+ }
100
+
101
+ target = _expectation_maximization(source, target, level, verbose);
102
+ }
103
+
104
+ return target.image();
105
+ }
106
+
107
+ // EM-Like algorithm (see "PatchMatch" - page 6).
108
+ // Returns a double sized target image (unless level = 0).
109
+ MaskedImage Inpainting::_expectation_maximization(MaskedImage source, MaskedImage target, int level, bool verbose) {
110
+ const int nr_iters_em = 1 + 2 * level;
111
+ const int nr_iters_nnf = static_cast<int>(std::min(7, 1 + level));
112
+ const int patch_size = m_distance_metric->patch_size();
113
+
114
+ MaskedImage new_source, new_target;
115
+
116
+ for (int iter_em = 0; iter_em < nr_iters_em; ++iter_em) {
117
+ if (iter_em != 0) {
118
+ m_source2target.set_target(new_target);
119
+ m_target2source.set_source(new_target);
120
+ target = new_target;
121
+ }
122
+
123
+ if (verbose) std::cerr << "EM Iteration: " << iter_em << std::endl;
124
+
125
+ auto size = source.size();
126
+ for (int i = 0; i < size.height; ++i) {
127
+ for (int j = 0; j < size.width; ++j) {
128
+ if (!source.contains_mask(i, j, patch_size)) {
129
+ m_source2target.set_identity(i, j);
130
+ m_target2source.set_identity(i, j);
131
+ }
132
+ }
133
+ }
134
+ if (verbose) std::cerr << " NNF minimization started." << std::endl;
135
+ m_source2target.minimize(nr_iters_nnf);
136
+ m_target2source.minimize(nr_iters_nnf);
137
+ if (verbose) std::cerr << " NNF minimization finished." << std::endl;
138
+
139
+ // Instead of upsizing the final target, we build the last target from the next level source image.
140
+ // Thus, the final target is less blurry (see "Space-Time Video Completion" - page 5).
141
+ bool upscaled = false;
142
+ if (level >= 1 && iter_em == nr_iters_em - 1) {
143
+ new_source = m_pyramid[level - 1];
144
+ new_target = target.upsample(new_source.size().width, new_source.size().height, m_pyramid[level - 1].global_mask());
145
+ upscaled = true;
146
+ } else {
147
+ new_source = m_pyramid[level];
148
+ new_target = target.clone();
149
+ }
150
+
151
+ auto vote = cv::Mat(new_target.size(), CV_64FC4);
152
+ vote.setTo(cv::Scalar::all(0));
153
+
154
+ // Votes for best patch from NNF Source->Target (completeness) and Target->Source (coherence).
155
+ _expectation_step(m_source2target, 1, vote, new_source, upscaled);
156
+ if (verbose) std::cerr << " Expectation source to target finished." << std::endl;
157
+ _expectation_step(m_target2source, 0, vote, new_source, upscaled);
158
+ if (verbose) std::cerr << " Expectation target to source finished." << std::endl;
159
+
160
+ // Compile votes and update pixel values.
161
+ _maximization_step(new_target, vote);
162
+ if (verbose) std::cerr << " Minimization step finished." << std::endl;
163
+ }
164
+
165
+ return new_target;
166
+ }
167
+
168
+ // Expectation step: vote for best estimations of each pixel.
169
+ void Inpainting::_expectation_step(
170
+ const NearestNeighborField &nnf, bool source2target,
171
+ cv::Mat &vote, const MaskedImage &source, bool upscaled
172
+ ) {
173
+ auto source_size = nnf.source_size();
174
+ auto target_size = nnf.target_size();
175
+ const int patch_size = m_distance_metric->patch_size();
176
+
177
+ for (int i = 0; i < source_size.height; ++i) {
178
+ for (int j = 0; j < source_size.width; ++j) {
179
+ if (nnf.source().is_globally_masked(i, j)) continue;
180
+ int yp = nnf.at(i, j, 0), xp = nnf.at(i, j, 1), dp = nnf.at(i, j, 2);
181
+ double w = kDistance2Similarity[dp];
182
+
183
+ for (int di = -patch_size; di <= patch_size; ++di) {
184
+ for (int dj = -patch_size; dj <= patch_size; ++dj) {
185
+ int ys = i + di, xs = j + dj, yt = yp + di, xt = xp + dj;
186
+ if (!(ys >= 0 && ys < source_size.height && xs >= 0 && xs < source_size.width)) continue;
187
+ if (nnf.source().is_globally_masked(ys, xs)) continue;
188
+ if (!(yt >= 0 && yt < target_size.height && xt >= 0 && xt < target_size.width)) continue;
189
+ if (nnf.target().is_globally_masked(yt, xt)) continue;
190
+
191
+ if (!source2target) {
192
+ std::swap(ys, yt);
193
+ std::swap(xs, xt);
194
+ }
195
+
196
+ if (upscaled) {
197
+ for (int uy = 0; uy < 2; ++uy) {
198
+ for (int ux = 0; ux < 2; ++ux) {
199
+ _weighted_copy(source, 2 * ys + uy, 2 * xs + ux, vote, 2 * yt + uy, 2 * xt + ux, w);
200
+ }
201
+ }
202
+ } else {
203
+ _weighted_copy(source, ys, xs, vote, yt, xt, w);
204
+ }
205
+ }
206
+ }
207
+ }
208
+ }
209
+ }
210
+
211
+ // Maximization Step: maximum likelihood of target pixel.
212
+ void Inpainting::_maximization_step(MaskedImage &target, const cv::Mat &vote) {
213
+ auto target_size = target.size();
214
+ for (int i = 0; i < target_size.height; ++i) {
215
+ for (int j = 0; j < target_size.width; ++j) {
216
+ const double *source_ptr = vote.ptr<double>(i, j);
217
+ unsigned char *target_ptr = target.get_mutable_image(i, j);
218
+
219
+ if (target.is_globally_masked(i, j)) {
220
+ continue;
221
+ }
222
+
223
+ if (source_ptr[3] > 0) {
224
+ unsigned char r = cv::saturate_cast<unsigned char>(source_ptr[0] / source_ptr[3]);
225
+ unsigned char g = cv::saturate_cast<unsigned char>(source_ptr[1] / source_ptr[3]);
226
+ unsigned char b = cv::saturate_cast<unsigned char>(source_ptr[2] / source_ptr[3]);
227
+ target_ptr[0] = r, target_ptr[1] = g, target_ptr[2] = b;
228
+ } else {
229
+ target.set_mask(i, j, 0);
230
+ }
231
+ }
232
+ }
233
+ }
234
+
PyPatchMatch/csrc/inpaint.h ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <vector>
4
+
5
+ #include "masked_image.h"
6
+ #include "nnf.h"
7
+
8
+ class Inpainting {
9
+ public:
10
+ Inpainting(cv::Mat image, cv::Mat mask, const PatchDistanceMetric *metric);
11
+ Inpainting(cv::Mat image, cv::Mat mask, cv::Mat global_mask, const PatchDistanceMetric *metric);
12
+ cv::Mat run(bool verbose = false, bool verbose_visualize = false, unsigned int random_seed = 1212);
13
+
14
+ private:
15
+ void _initialize_pyramid(void);
16
+ MaskedImage _expectation_maximization(MaskedImage source, MaskedImage target, int level, bool verbose);
17
+ void _expectation_step(const NearestNeighborField &nnf, bool source2target, cv::Mat &vote, const MaskedImage &source, bool upscaled);
18
+ void _maximization_step(MaskedImage &target, const cv::Mat &vote);
19
+
20
+ MaskedImage m_initial;
21
+ std::vector<MaskedImage> m_pyramid;
22
+
23
+ NearestNeighborField m_source2target;
24
+ NearestNeighborField m_target2source;
25
+ const PatchDistanceMetric *m_distance_metric;
26
+ };
27
+
PyPatchMatch/csrc/masked_image.cpp ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include "masked_image.h"
2
+ #include <algorithm>
3
+ #include <iostream>
4
+
5
+ const cv::Size MaskedImage::kDownsampleKernelSize = cv::Size(6, 6);
6
+ const int MaskedImage::kDownsampleKernel[6] = {1, 5, 10, 10, 5, 1};
7
+
8
+ bool MaskedImage::contains_mask(int y, int x, int patch_size) const {
9
+ auto mask_size = size();
10
+ for (int dy = -patch_size; dy <= patch_size; ++dy) {
11
+ for (int dx = -patch_size; dx <= patch_size; ++dx) {
12
+ int yy = y + dy, xx = x + dx;
13
+ if (yy >= 0 && yy < mask_size.height && xx >= 0 && xx < mask_size.width) {
14
+ if (is_masked(yy, xx) && !is_globally_masked(yy, xx)) return true;
15
+ }
16
+ }
17
+ }
18
+ return false;
19
+ }
20
+
21
+ MaskedImage MaskedImage::downsample() const {
22
+ const auto &kernel_size = MaskedImage::kDownsampleKernelSize;
23
+ const auto &kernel = MaskedImage::kDownsampleKernel;
24
+
25
+ const auto size = this->size();
26
+ const auto new_size = cv::Size(size.width / 2, size.height / 2);
27
+
28
+ auto ret = MaskedImage(new_size.width, new_size.height);
29
+ if (!m_global_mask.empty()) ret.init_global_mask_mat();
30
+ for (int y = 0; y < size.height - 1; y += 2) {
31
+ for (int x = 0; x < size.width - 1; x += 2) {
32
+ int r = 0, g = 0, b = 0, ksum = 0;
33
+ bool is_gmasked = true;
34
+
35
+ for (int dy = -kernel_size.height / 2 + 1; dy <= kernel_size.height / 2; ++dy) {
36
+ for (int dx = -kernel_size.width / 2 + 1; dx <= kernel_size.width / 2; ++dx) {
37
+ int yy = y + dy, xx = x + dx;
38
+ if (yy >= 0 && yy < size.height && xx >= 0 && xx < size.width) {
39
+ if (!is_globally_masked(yy, xx)) {
40
+ is_gmasked = false;
41
+ }
42
+ if (!is_masked(yy, xx)) {
43
+ auto source_ptr = get_image(yy, xx);
44
+ int k = kernel[kernel_size.height / 2 - 1 + dy] * kernel[kernel_size.width / 2 - 1 + dx];
45
+ r += source_ptr[0] * k, g += source_ptr[1] * k, b += source_ptr[2] * k;
46
+ ksum += k;
47
+ }
48
+ }
49
+ }
50
+ }
51
+
52
+ if (ksum > 0) r /= ksum, g /= ksum, b /= ksum;
53
+
54
+ if (!m_global_mask.empty()) {
55
+ ret.set_global_mask(y / 2, x / 2, is_gmasked);
56
+ }
57
+ if (ksum > 0) {
58
+ auto target_ptr = ret.get_mutable_image(y / 2, x / 2);
59
+ target_ptr[0] = r, target_ptr[1] = g, target_ptr[2] = b;
60
+ ret.set_mask(y / 2, x / 2, 0);
61
+ } else {
62
+ ret.set_mask(y / 2, x / 2, 1);
63
+ }
64
+ }
65
+ }
66
+
67
+ return ret;
68
+ }
69
+
70
+ MaskedImage MaskedImage::upsample(int new_w, int new_h) const {
71
+ const auto size = this->size();
72
+ auto ret = MaskedImage(new_w, new_h);
73
+ if (!m_global_mask.empty()) ret.init_global_mask_mat();
74
+ for (int y = 0; y < new_h; ++y) {
75
+ for (int x = 0; x < new_w; ++x) {
76
+ int yy = y * size.height / new_h;
77
+ int xx = x * size.width / new_w;
78
+
79
+ if (is_globally_masked(yy, xx)) {
80
+ ret.set_global_mask(y, x, 1);
81
+ ret.set_mask(y, x, 1);
82
+ } else {
83
+ if (!m_global_mask.empty()) ret.set_global_mask(y, x, 0);
84
+
85
+ if (is_masked(yy, xx)) {
86
+ ret.set_mask(y, x, 1);
87
+ } else {
88
+ auto source_ptr = get_image(yy, xx);
89
+ auto target_ptr = ret.get_mutable_image(y, x);
90
+ for (int c = 0; c < 3; ++c)
91
+ target_ptr[c] = source_ptr[c];
92
+ ret.set_mask(y, x, 0);
93
+ }
94
+ }
95
+ }
96
+ }
97
+
98
+ return ret;
99
+ }
100
+
101
+ MaskedImage MaskedImage::upsample(int new_w, int new_h, const cv::Mat &new_global_mask) const {
102
+ auto ret = upsample(new_w, new_h);
103
+ ret.set_global_mask_mat(new_global_mask);
104
+ return ret;
105
+ }
106
+
107
+ void MaskedImage::compute_image_gradients() {
108
+ if (m_image_grad_computed) {
109
+ return;
110
+ }
111
+
112
+ const auto size = m_image.size();
113
+ m_image_grady = cv::Mat(size, CV_8UC3);
114
+ m_image_gradx = cv::Mat(size, CV_8UC3);
115
+ m_image_grady = cv::Scalar::all(0);
116
+ m_image_gradx = cv::Scalar::all(0);
117
+
118
+ for (int i = 1; i < size.height - 1; ++i) {
119
+ const auto *ptr = m_image.ptr<unsigned char>(i, 0);
120
+ const auto *ptry1 = m_image.ptr<unsigned char>(i + 1, 0);
121
+ const auto *ptry2 = m_image.ptr<unsigned char>(i - 1, 0);
122
+ const auto *ptrx1 = m_image.ptr<unsigned char>(i, 0) + 3;
123
+ const auto *ptrx2 = m_image.ptr<unsigned char>(i, 0) - 3;
124
+ auto *mptry = m_image_grady.ptr<unsigned char>(i, 0);
125
+ auto *mptrx = m_image_gradx.ptr<unsigned char>(i, 0);
126
+ for (int j = 3; j < size.width * 3 - 3; ++j) {
127
+ mptry[j] = (ptry1[j] / 2 - ptry2[j] / 2) + 128;
128
+ mptrx[j] = (ptrx1[j] / 2 - ptrx2[j] / 2) + 128;
129
+ }
130
+ }
131
+
132
+ m_image_grad_computed = true;
133
+ }
134
+
135
+ void MaskedImage::compute_image_gradients() const {
136
+ const_cast<MaskedImage *>(this)->compute_image_gradients();
137
+ }
138
+
PyPatchMatch/csrc/masked_image.h ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <opencv2/core.hpp>
4
+
5
+ class MaskedImage {
6
+ public:
7
+ MaskedImage() : m_image(), m_mask(), m_global_mask(), m_image_grady(), m_image_gradx(), m_image_grad_computed(false) {
8
+ // pass
9
+ }
10
+ MaskedImage(cv::Mat image, cv::Mat mask) : m_image(image), m_mask(mask), m_image_grad_computed(false) {
11
+ // pass
12
+ }
13
+ MaskedImage(cv::Mat image, cv::Mat mask, cv::Mat global_mask) : m_image(image), m_mask(mask), m_global_mask(global_mask), m_image_grad_computed(false) {
14
+ // pass
15
+ }
16
+ MaskedImage(cv::Mat image, cv::Mat mask, cv::Mat global_mask, cv::Mat grady, cv::Mat gradx, bool grad_computed) :
17
+ m_image(image), m_mask(mask), m_global_mask(global_mask),
18
+ m_image_grady(grady), m_image_gradx(gradx), m_image_grad_computed(grad_computed) {
19
+ // pass
20
+ }
21
+ MaskedImage(int width, int height) : m_global_mask(), m_image_grady(), m_image_gradx() {
22
+ m_image = cv::Mat(cv::Size(width, height), CV_8UC3);
23
+ m_image = cv::Scalar::all(0);
24
+
25
+ m_mask = cv::Mat(cv::Size(width, height), CV_8U);
26
+ m_mask = cv::Scalar::all(0);
27
+ }
28
+ inline MaskedImage clone() {
29
+ return MaskedImage(
30
+ m_image.clone(), m_mask.clone(), m_global_mask.clone(),
31
+ m_image_grady.clone(), m_image_gradx.clone(), m_image_grad_computed
32
+ );
33
+ }
34
+
35
+ inline cv::Size size() const {
36
+ return m_image.size();
37
+ }
38
+ inline const cv::Mat &image() const {
39
+ return m_image;
40
+ }
41
+ inline const cv::Mat &mask() const {
42
+ return m_mask;
43
+ }
44
+ inline const cv::Mat &global_mask() const {
45
+ return m_global_mask;
46
+ }
47
+ inline const cv::Mat &grady() const {
48
+ assert(m_image_grad_computed);
49
+ return m_image_grady;
50
+ }
51
+ inline const cv::Mat &gradx() const {
52
+ assert(m_image_grad_computed);
53
+ return m_image_gradx;
54
+ }
55
+
56
+ inline void init_global_mask_mat() {
57
+ m_global_mask = cv::Mat(m_mask.size(), CV_8U);
58
+ m_global_mask.setTo(cv::Scalar(0));
59
+ }
60
+ inline void set_global_mask_mat(const cv::Mat &other) {
61
+ m_global_mask = other;
62
+ }
63
+
64
+ inline bool is_masked(int y, int x) const {
65
+ return static_cast<bool>(m_mask.at<unsigned char>(y, x));
66
+ }
67
+ inline bool is_globally_masked(int y, int x) const {
68
+ return !m_global_mask.empty() && static_cast<bool>(m_global_mask.at<unsigned char>(y, x));
69
+ }
70
+ inline void set_mask(int y, int x, bool value) {
71
+ m_mask.at<unsigned char>(y, x) = static_cast<unsigned char>(value);
72
+ }
73
+ inline void set_global_mask(int y, int x, bool value) {
74
+ m_global_mask.at<unsigned char>(y, x) = static_cast<unsigned char>(value);
75
+ }
76
+ inline void clear_mask() {
77
+ m_mask.setTo(cv::Scalar(0));
78
+ }
79
+
80
+ inline const unsigned char *get_image(int y, int x) const {
81
+ return m_image.ptr<unsigned char>(y, x);
82
+ }
83
+ inline unsigned char *get_mutable_image(int y, int x) {
84
+ return m_image.ptr<unsigned char>(y, x);
85
+ }
86
+
87
+ inline unsigned char get_image(int y, int x, int c) const {
88
+ return m_image.ptr<unsigned char>(y, x)[c];
89
+ }
90
+ inline int get_image_int(int y, int x, int c) const {
91
+ return static_cast<int>(m_image.ptr<unsigned char>(y, x)[c]);
92
+ }
93
+
94
+ bool contains_mask(int y, int x, int patch_size) const;
95
+ MaskedImage downsample() const;
96
+ MaskedImage upsample(int new_w, int new_h) const;
97
+ MaskedImage upsample(int new_w, int new_h, const cv::Mat &new_global_mask) const;
98
+ void compute_image_gradients();
99
+ void compute_image_gradients() const;
100
+
101
+ static const cv::Size kDownsampleKernelSize;
102
+ static const int kDownsampleKernel[6];
103
+
104
+ private:
105
+ cv::Mat m_image;
106
+ cv::Mat m_mask;
107
+ cv::Mat m_global_mask;
108
+ cv::Mat m_image_grady;
109
+ cv::Mat m_image_gradx;
110
+ bool m_image_grad_computed = false;
111
+ };
112
+
PyPatchMatch/csrc/nnf.cpp ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <algorithm>
2
+ #include <iostream>
3
+ #include <cmath>
4
+
5
+ #include "masked_image.h"
6
+ #include "nnf.h"
7
+
8
+ /**
9
+ * Nearest-Neighbor Field (see PatchMatch algorithm).
10
+ * This algorithme uses a version proposed by Xavier Philippeau.
11
+ *
12
+ */
13
+
14
+ template <typename T>
15
+ T clamp(T value, T min_value, T max_value) {
16
+ return std::min(std::max(value, min_value), max_value);
17
+ }
18
+
19
+ void NearestNeighborField::_randomize_field(int max_retry, bool reset) {
20
+ auto this_size = source_size();
21
+ for (int i = 0; i < this_size.height; ++i) {
22
+ for (int j = 0; j < this_size.width; ++j) {
23
+ if (m_source.is_globally_masked(i, j)) continue;
24
+
25
+ auto this_ptr = mutable_ptr(i, j);
26
+ int distance = reset ? PatchDistanceMetric::kDistanceScale : this_ptr[2];
27
+ if (distance < PatchDistanceMetric::kDistanceScale) {
28
+ continue;
29
+ }
30
+
31
+ int i_target = 0, j_target = 0;
32
+ for (int t = 0; t < max_retry; ++t) {
33
+ i_target = rand() % this_size.height;
34
+ j_target = rand() % this_size.width;
35
+ if (m_target.is_globally_masked(i_target, j_target)) continue;
36
+
37
+ distance = _distance(i, j, i_target, j_target);
38
+ if (distance < PatchDistanceMetric::kDistanceScale)
39
+ break;
40
+ }
41
+
42
+ this_ptr[0] = i_target, this_ptr[1] = j_target, this_ptr[2] = distance;
43
+ }
44
+ }
45
+ }
46
+
47
+ void NearestNeighborField::_initialize_field_from(const NearestNeighborField &other, int max_retry) {
48
+ const auto &this_size = source_size();
49
+ const auto &other_size = other.source_size();
50
+ double fi = static_cast<double>(this_size.height) / other_size.height;
51
+ double fj = static_cast<double>(this_size.width) / other_size.width;
52
+
53
+ for (int i = 0; i < this_size.height; ++i) {
54
+ for (int j = 0; j < this_size.width; ++j) {
55
+ if (m_source.is_globally_masked(i, j)) continue;
56
+
57
+ int ilow = static_cast<int>(std::min(i / fi, static_cast<double>(other_size.height - 1)));
58
+ int jlow = static_cast<int>(std::min(j / fj, static_cast<double>(other_size.width - 1)));
59
+ auto this_value = mutable_ptr(i, j);
60
+ auto other_value = other.ptr(ilow, jlow);
61
+
62
+ this_value[0] = static_cast<int>(other_value[0] * fi);
63
+ this_value[1] = static_cast<int>(other_value[1] * fj);
64
+ this_value[2] = _distance(i, j, this_value[0], this_value[1]);
65
+ }
66
+ }
67
+
68
+ _randomize_field(max_retry, false);
69
+ }
70
+
71
+ void NearestNeighborField::minimize(int nr_pass) {
72
+ const auto &this_size = source_size();
73
+ while (nr_pass--) {
74
+ for (int i = 0; i < this_size.height; ++i)
75
+ for (int j = 0; j < this_size.width; ++j) {
76
+ if (m_source.is_globally_masked(i, j)) continue;
77
+ if (at(i, j, 2) > 0) _minimize_link(i, j, +1);
78
+ }
79
+ for (int i = this_size.height - 1; i >= 0; --i)
80
+ for (int j = this_size.width - 1; j >= 0; --j) {
81
+ if (m_source.is_globally_masked(i, j)) continue;
82
+ if (at(i, j, 2) > 0) _minimize_link(i, j, -1);
83
+ }
84
+ }
85
+ }
86
+
87
+ void NearestNeighborField::_minimize_link(int y, int x, int direction) {
88
+ const auto &this_size = source_size();
89
+ const auto &this_target_size = target_size();
90
+ auto this_ptr = mutable_ptr(y, x);
91
+
92
+ // propagation along the y direction.
93
+ if (y - direction >= 0 && y - direction < this_size.height && !m_source.is_globally_masked(y - direction, x)) {
94
+ int yp = at(y - direction, x, 0) + direction;
95
+ int xp = at(y - direction, x, 1);
96
+ int dp = _distance(y, x, yp, xp);
97
+ if (dp < at(y, x, 2)) {
98
+ this_ptr[0] = yp, this_ptr[1] = xp, this_ptr[2] = dp;
99
+ }
100
+ }
101
+
102
+ // propagation along the x direction.
103
+ if (x - direction >= 0 && x - direction < this_size.width && !m_source.is_globally_masked(y, x - direction)) {
104
+ int yp = at(y, x - direction, 0);
105
+ int xp = at(y, x - direction, 1) + direction;
106
+ int dp = _distance(y, x, yp, xp);
107
+ if (dp < at(y, x, 2)) {
108
+ this_ptr[0] = yp, this_ptr[1] = xp, this_ptr[2] = dp;
109
+ }
110
+ }
111
+
112
+ // random search with a progressive step size.
113
+ int random_scale = (std::min(this_target_size.height, this_target_size.width) - 1) / 2;
114
+ while (random_scale > 0) {
115
+ int yp = this_ptr[0] + (rand() % (2 * random_scale + 1) - random_scale);
116
+ int xp = this_ptr[1] + (rand() % (2 * random_scale + 1) - random_scale);
117
+ yp = clamp(yp, 0, target_size().height - 1);
118
+ xp = clamp(xp, 0, target_size().width - 1);
119
+
120
+ if (m_target.is_globally_masked(yp, xp)) {
121
+ random_scale /= 2;
122
+ }
123
+
124
+ int dp = _distance(y, x, yp, xp);
125
+ if (dp < at(y, x, 2)) {
126
+ this_ptr[0] = yp, this_ptr[1] = xp, this_ptr[2] = dp;
127
+ }
128
+ random_scale /= 2;
129
+ }
130
+ }
131
+
132
+ const int PatchDistanceMetric::kDistanceScale = 65535;
133
+ const int PatchSSDDistanceMetric::kSSDScale = 9 * 255 * 255;
134
+
135
+ namespace {
136
+
137
+ inline int pow2(int i) {
138
+ return i * i;
139
+ }
140
+
141
+ int distance_masked_images(
142
+ const MaskedImage &source, int ys, int xs,
143
+ const MaskedImage &target, int yt, int xt,
144
+ int patch_size
145
+ ) {
146
+ long double distance = 0;
147
+ long double wsum = 0;
148
+
149
+ source.compute_image_gradients();
150
+ target.compute_image_gradients();
151
+
152
+ auto source_size = source.size();
153
+ auto target_size = target.size();
154
+
155
+ for (int dy = -patch_size; dy <= patch_size; ++dy) {
156
+ const int yys = ys + dy, yyt = yt + dy;
157
+
158
+ if (yys <= 0 || yys >= source_size.height - 1 || yyt <= 0 || yyt >= target_size.height - 1) {
159
+ distance += (long double)(PatchSSDDistanceMetric::kSSDScale) * (2 * patch_size + 1);
160
+ wsum += 2 * patch_size + 1;
161
+ continue;
162
+ }
163
+
164
+ const auto *p_si = source.image().ptr<unsigned char>(yys, 0);
165
+ const auto *p_ti = target.image().ptr<unsigned char>(yyt, 0);
166
+ const auto *p_sm = source.mask().ptr<unsigned char>(yys, 0);
167
+ const auto *p_tm = target.mask().ptr<unsigned char>(yyt, 0);
168
+
169
+ const unsigned char *p_sgm = nullptr;
170
+ const unsigned char *p_tgm = nullptr;
171
+ if (!source.global_mask().empty()) {
172
+ p_sgm = source.global_mask().ptr<unsigned char>(yys, 0);
173
+ p_tgm = target.global_mask().ptr<unsigned char>(yyt, 0);
174
+ }
175
+
176
+ const auto *p_sgy = source.grady().ptr<unsigned char>(yys, 0);
177
+ const auto *p_tgy = target.grady().ptr<unsigned char>(yyt, 0);
178
+ const auto *p_sgx = source.gradx().ptr<unsigned char>(yys, 0);
179
+ const auto *p_tgx = target.gradx().ptr<unsigned char>(yyt, 0);
180
+
181
+ for (int dx = -patch_size; dx <= patch_size; ++dx) {
182
+ int xxs = xs + dx, xxt = xt + dx;
183
+ wsum += 1;
184
+
185
+ if (xxs <= 0 || xxs >= source_size.width - 1 || xxt <= 0 || xxt >= source_size.width - 1) {
186
+ distance += PatchSSDDistanceMetric::kSSDScale;
187
+ continue;
188
+ }
189
+
190
+ if (p_sm[xxs] || p_tm[xxt] || (p_sgm && p_sgm[xxs]) || (p_tgm && p_tgm[xxt]) ) {
191
+ distance += PatchSSDDistanceMetric::kSSDScale;
192
+ continue;
193
+ }
194
+
195
+ int ssd = 0;
196
+ for (int c = 0; c < 3; ++c) {
197
+ int s_value = p_si[xxs * 3 + c];
198
+ int t_value = p_ti[xxt * 3 + c];
199
+ int s_gy = p_sgy[xxs * 3 + c];
200
+ int t_gy = p_tgy[xxt * 3 + c];
201
+ int s_gx = p_sgx[xxs * 3 + c];
202
+ int t_gx = p_tgx[xxt * 3 + c];
203
+
204
+ ssd += pow2(static_cast<int>(s_value) - t_value);
205
+ ssd += pow2(static_cast<int>(s_gx) - t_gx);
206
+ ssd += pow2(static_cast<int>(s_gy) - t_gy);
207
+ }
208
+ distance += ssd;
209
+ }
210
+ }
211
+
212
+ distance /= (long double)(PatchSSDDistanceMetric::kSSDScale);
213
+
214
+ int res = int(PatchDistanceMetric::kDistanceScale * distance / wsum);
215
+ if (res < 0 || res > PatchDistanceMetric::kDistanceScale) return PatchDistanceMetric::kDistanceScale;
216
+ return res;
217
+ }
218
+
219
+ }
220
+
221
+ int PatchSSDDistanceMetric::operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const {
222
+ return distance_masked_images(source, source_y, source_x, target, target_y, target_x, m_patch_size);
223
+ }
224
+
225
+ int DebugPatchSSDDistanceMetric::operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const {
226
+ fprintf(stderr, "DebugPatchSSDDistanceMetric: %d %d %d %d\n", source.size().width, source.size().height, m_width, m_height);
227
+ return distance_masked_images(source, source_y, source_x, target, target_y, target_x, m_patch_size);
228
+ }
229
+
230
+ int RegularityGuidedPatchDistanceMetricV1::operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const {
231
+ double dx = remainder(double(source_x - target_x) / source.size().width, m_dx1);
232
+ double dy = remainder(double(source_y - target_y) / source.size().height, m_dy2);
233
+
234
+ double score1 = sqrt(dx * dx + dy *dy) / m_scale;
235
+ if (score1 < 0 || score1 > 1) score1 = 1;
236
+ score1 *= PatchDistanceMetric::kDistanceScale;
237
+
238
+ double score2 = distance_masked_images(source, source_y, source_x, target, target_y, target_x, m_patch_size);
239
+ double score = score1 * m_weight + score2 / (1 + m_weight);
240
+ return static_cast<int>(score / (1 + m_weight));
241
+ }
242
+
243
+ int RegularityGuidedPatchDistanceMetricV2::operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const {
244
+ if (target_y < 0 || target_y >= target.size().height || target_x < 0 || target_x >= target.size().width)
245
+ return PatchDistanceMetric::kDistanceScale;
246
+
247
+ int source_scale = m_ijmap.size().height / source.size().height;
248
+ int target_scale = m_ijmap.size().height / target.size().height;
249
+
250
+ // fprintf(stderr, "RegularityGuidedPatchDistanceMetricV2 %d %d %d %d\n", source_y * source_scale, m_ijmap.size().height, source_x * source_scale, m_ijmap.size().width);
251
+
252
+ double score1 = PatchDistanceMetric::kDistanceScale;
253
+ if (!source.is_globally_masked(source_y, source_x) && !target.is_globally_masked(target_y, target_x)) {
254
+ auto source_ij = m_ijmap.ptr<float>(source_y * source_scale, source_x * source_scale);
255
+ auto target_ij = m_ijmap.ptr<float>(target_y * target_scale, target_x * target_scale);
256
+
257
+ float di = fabs(source_ij[0] - target_ij[0]); if (di > 0.5) di = 1 - di;
258
+ float dj = fabs(source_ij[1] - target_ij[1]); if (dj > 0.5) dj = 1 - dj;
259
+ score1 = sqrt(di * di + dj *dj) / 0.707;
260
+ if (score1 < 0 || score1 > 1) score1 = 1;
261
+ score1 *= PatchDistanceMetric::kDistanceScale;
262
+ }
263
+
264
+ double score2 = distance_masked_images(source, source_y, source_x, target, target_y, target_x, m_patch_size);
265
+ double score = score1 * m_weight + score2;
266
+ return int(score / (1 + m_weight));
267
+ }
268
+
PyPatchMatch/csrc/nnf.h ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <opencv2/core.hpp>
4
+ #include "masked_image.h"
5
+
6
+ class PatchDistanceMetric {
7
+ public:
8
+ PatchDistanceMetric(int patch_size) : m_patch_size(patch_size) {}
9
+ virtual ~PatchDistanceMetric() = default;
10
+
11
+ inline int patch_size() const { return m_patch_size; }
12
+ virtual int operator()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const = 0;
13
+ static const int kDistanceScale;
14
+
15
+ protected:
16
+ int m_patch_size;
17
+ };
18
+
19
+ class NearestNeighborField {
20
+ public:
21
+ NearestNeighborField() : m_source(), m_target(), m_field(), m_distance_metric(nullptr) {
22
+ // pass
23
+ }
24
+ NearestNeighborField(const MaskedImage &source, const MaskedImage &target, const PatchDistanceMetric *metric, int max_retry = 20)
25
+ : m_source(source), m_target(target), m_distance_metric(metric) {
26
+ m_field = cv::Mat(m_source.size(), CV_32SC3);
27
+ _randomize_field(max_retry);
28
+ }
29
+ NearestNeighborField(const MaskedImage &source, const MaskedImage &target, const PatchDistanceMetric *metric, const NearestNeighborField &other, int max_retry = 20)
30
+ : m_source(source), m_target(target), m_distance_metric(metric) {
31
+ m_field = cv::Mat(m_source.size(), CV_32SC3);
32
+ _initialize_field_from(other, max_retry);
33
+ }
34
+
35
+ const MaskedImage &source() const {
36
+ return m_source;
37
+ }
38
+ const MaskedImage &target() const {
39
+ return m_target;
40
+ }
41
+ inline cv::Size source_size() const {
42
+ return m_source.size();
43
+ }
44
+ inline cv::Size target_size() const {
45
+ return m_target.size();
46
+ }
47
+ inline void set_source(const MaskedImage &source) {
48
+ m_source = source;
49
+ }
50
+ inline void set_target(const MaskedImage &target) {
51
+ m_target = target;
52
+ }
53
+
54
+ inline int *mutable_ptr(int y, int x) {
55
+ return m_field.ptr<int>(y, x);
56
+ }
57
+ inline const int *ptr(int y, int x) const {
58
+ return m_field.ptr<int>(y, x);
59
+ }
60
+
61
+ inline int at(int y, int x, int c) const {
62
+ return m_field.ptr<int>(y, x)[c];
63
+ }
64
+ inline int &at(int y, int x, int c) {
65
+ return m_field.ptr<int>(y, x)[c];
66
+ }
67
+ inline void set_identity(int y, int x) {
68
+ auto ptr = mutable_ptr(y, x);
69
+ ptr[0] = y, ptr[1] = x, ptr[2] = 0;
70
+ }
71
+
72
+ void minimize(int nr_pass);
73
+
74
+ private:
75
+ inline int _distance(int source_y, int source_x, int target_y, int target_x) {
76
+ return (*m_distance_metric)(m_source, source_y, source_x, m_target, target_y, target_x);
77
+ }
78
+
79
+ void _randomize_field(int max_retry = 20, bool reset = true);
80
+ void _initialize_field_from(const NearestNeighborField &other, int max_retry);
81
+ void _minimize_link(int y, int x, int direction);
82
+
83
+ MaskedImage m_source;
84
+ MaskedImage m_target;
85
+ cv::Mat m_field; // { y_target, x_target, distance_scaled }
86
+ const PatchDistanceMetric *m_distance_metric;
87
+ };
88
+
89
+
90
+ class PatchSSDDistanceMetric : public PatchDistanceMetric {
91
+ public:
92
+ using PatchDistanceMetric::PatchDistanceMetric;
93
+ virtual int operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const;
94
+ static const int kSSDScale;
95
+ };
96
+
97
+ class DebugPatchSSDDistanceMetric : public PatchDistanceMetric {
98
+ public:
99
+ DebugPatchSSDDistanceMetric(int patch_size, int width, int height) : PatchDistanceMetric(patch_size), m_width(width), m_height(height) {}
100
+ virtual int operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const;
101
+ protected:
102
+ int m_width, m_height;
103
+ };
104
+
105
+ class RegularityGuidedPatchDistanceMetricV1 : public PatchDistanceMetric {
106
+ public:
107
+ RegularityGuidedPatchDistanceMetricV1(int patch_size, double dx1, double dy1, double dx2, double dy2, double weight)
108
+ : PatchDistanceMetric(patch_size), m_dx1(dx1), m_dy1(dy1), m_dx2(dx2), m_dy2(dy2), m_weight(weight) {
109
+
110
+ assert(m_dy1 == 0);
111
+ assert(m_dx2 == 0);
112
+ m_scale = sqrt(m_dx1 * m_dx1 + m_dy2 * m_dy2) / 4;
113
+ }
114
+ virtual int operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const;
115
+
116
+ protected:
117
+ double m_dx1, m_dy1, m_dx2, m_dy2;
118
+ double m_scale, m_weight;
119
+ };
120
+
121
+ class RegularityGuidedPatchDistanceMetricV2 : public PatchDistanceMetric {
122
+ public:
123
+ RegularityGuidedPatchDistanceMetricV2(int patch_size, cv::Mat ijmap, double weight)
124
+ : PatchDistanceMetric(patch_size), m_ijmap(ijmap), m_weight(weight) {
125
+
126
+ }
127
+ virtual int operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const;
128
+
129
+ protected:
130
+ cv::Mat m_ijmap;
131
+ double m_width, m_height, m_weight;
132
+ };
133
+
PyPatchMatch/csrc/pyinterface.cpp ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include "pyinterface.h"
2
+ #include "inpaint.h"
3
+
4
+ static unsigned int PM_seed = 1212;
5
+ static bool PM_verbose = false;
6
+
7
+ int _dtype_py_to_cv(int dtype_py);
8
+ int _dtype_cv_to_py(int dtype_cv);
9
+ cv::Mat _py_to_cv2(PM_mat_t pymat);
10
+ PM_mat_t _cv2_to_py(cv::Mat cvmat);
11
+
12
+ void PM_set_random_seed(unsigned int seed) {
13
+ PM_seed = seed;
14
+ }
15
+
16
+ void PM_set_verbose(int value) {
17
+ PM_verbose = static_cast<bool>(value);
18
+ }
19
+
20
+ void PM_free_pymat(PM_mat_t pymat) {
21
+ free(pymat.data_ptr);
22
+ }
23
+
24
+ PM_mat_t PM_inpaint(PM_mat_t source_py, PM_mat_t mask_py, int patch_size) {
25
+ cv::Mat source = _py_to_cv2(source_py);
26
+ cv::Mat mask = _py_to_cv2(mask_py);
27
+ auto metric = PatchSSDDistanceMetric(patch_size);
28
+ cv::Mat result = Inpainting(source, mask, &metric).run(PM_verbose, false, PM_seed);
29
+ return _cv2_to_py(result);
30
+ }
31
+
32
+ PM_mat_t PM_inpaint_regularity(PM_mat_t source_py, PM_mat_t mask_py, PM_mat_t ijmap_py, int patch_size, float guide_weight) {
33
+ cv::Mat source = _py_to_cv2(source_py);
34
+ cv::Mat mask = _py_to_cv2(mask_py);
35
+ cv::Mat ijmap = _py_to_cv2(ijmap_py);
36
+
37
+ auto metric = RegularityGuidedPatchDistanceMetricV2(patch_size, ijmap, guide_weight);
38
+ cv::Mat result = Inpainting(source, mask, &metric).run(PM_verbose, false, PM_seed);
39
+ return _cv2_to_py(result);
40
+ }
41
+
42
+ PM_mat_t PM_inpaint2(PM_mat_t source_py, PM_mat_t mask_py, PM_mat_t global_mask_py, int patch_size) {
43
+ cv::Mat source = _py_to_cv2(source_py);
44
+ cv::Mat mask = _py_to_cv2(mask_py);
45
+ cv::Mat global_mask = _py_to_cv2(global_mask_py);
46
+
47
+ auto metric = PatchSSDDistanceMetric(patch_size);
48
+ cv::Mat result = Inpainting(source, mask, global_mask, &metric).run(PM_verbose, false, PM_seed);
49
+ return _cv2_to_py(result);
50
+ }
51
+
52
+ PM_mat_t PM_inpaint2_regularity(PM_mat_t source_py, PM_mat_t mask_py, PM_mat_t global_mask_py, PM_mat_t ijmap_py, int patch_size, float guide_weight) {
53
+ cv::Mat source = _py_to_cv2(source_py);
54
+ cv::Mat mask = _py_to_cv2(mask_py);
55
+ cv::Mat global_mask = _py_to_cv2(global_mask_py);
56
+ cv::Mat ijmap = _py_to_cv2(ijmap_py);
57
+
58
+ auto metric = RegularityGuidedPatchDistanceMetricV2(patch_size, ijmap, guide_weight);
59
+ cv::Mat result = Inpainting(source, mask, global_mask, &metric).run(PM_verbose, false, PM_seed);
60
+ return _cv2_to_py(result);
61
+ }
62
+
63
+ int _dtype_py_to_cv(int dtype_py) {
64
+ switch (dtype_py) {
65
+ case PM_UINT8: return CV_8U;
66
+ case PM_INT8: return CV_8S;
67
+ case PM_UINT16: return CV_16U;
68
+ case PM_INT16: return CV_16S;
69
+ case PM_INT32: return CV_32S;
70
+ case PM_FLOAT32: return CV_32F;
71
+ case PM_FLOAT64: return CV_64F;
72
+ }
73
+
74
+ return CV_8U;
75
+ }
76
+
77
+ int _dtype_cv_to_py(int dtype_cv) {
78
+ switch (dtype_cv) {
79
+ case CV_8U: return PM_UINT8;
80
+ case CV_8S: return PM_INT8;
81
+ case CV_16U: return PM_UINT16;
82
+ case CV_16S: return PM_INT16;
83
+ case CV_32S: return PM_INT32;
84
+ case CV_32F: return PM_FLOAT32;
85
+ case CV_64F: return PM_FLOAT64;
86
+ }
87
+
88
+ return PM_UINT8;
89
+ }
90
+
91
+ cv::Mat _py_to_cv2(PM_mat_t pymat) {
92
+ int dtype = _dtype_py_to_cv(pymat.dtype);
93
+ dtype = CV_MAKETYPE(pymat.dtype, pymat.shape.channels);
94
+ return cv::Mat(cv::Size(pymat.shape.width, pymat.shape.height), dtype, pymat.data_ptr).clone();
95
+ }
96
+
97
+ PM_mat_t _cv2_to_py(cv::Mat cvmat) {
98
+ PM_shape_t shape = {cvmat.size().width, cvmat.size().height, cvmat.channels()};
99
+ int dtype = _dtype_cv_to_py(cvmat.depth());
100
+ size_t dsize = cvmat.total() * cvmat.elemSize();
101
+
102
+ void *data_ptr = reinterpret_cast<void *>(malloc(dsize));
103
+ memcpy(data_ptr, reinterpret_cast<void *>(cvmat.data), dsize);
104
+
105
+ return PM_mat_t {data_ptr, shape, dtype};
106
+ }
107
+
PyPatchMatch/csrc/pyinterface.h ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <opencv2/core.hpp>
2
+ #include <cstdlib>
3
+ #include <cstdio>
4
+ #include <cstring>
5
+
6
+ extern "C" {
7
+
8
+ struct PM_shape_t {
9
+ int width, height, channels;
10
+ };
11
+
12
+ enum PM_dtype_e {
13
+ PM_UINT8,
14
+ PM_INT8,
15
+ PM_UINT16,
16
+ PM_INT16,
17
+ PM_INT32,
18
+ PM_FLOAT32,
19
+ PM_FLOAT64,
20
+ };
21
+
22
+ struct PM_mat_t {
23
+ void *data_ptr;
24
+ PM_shape_t shape;
25
+ int dtype;
26
+ };
27
+
28
+ void PM_set_random_seed(unsigned int seed);
29
+ void PM_set_verbose(int value);
30
+
31
+ void PM_free_pymat(PM_mat_t pymat);
32
+ PM_mat_t PM_inpaint(PM_mat_t image, PM_mat_t mask, int patch_size);
33
+ PM_mat_t PM_inpaint_regularity(PM_mat_t image, PM_mat_t mask, PM_mat_t ijmap, int patch_size, float guide_weight);
34
+ PM_mat_t PM_inpaint2(PM_mat_t image, PM_mat_t mask, PM_mat_t global_mask, int patch_size);
35
+ PM_mat_t PM_inpaint2_regularity(PM_mat_t image, PM_mat_t mask, PM_mat_t global_mask, PM_mat_t ijmap, int patch_size, float guide_weight);
36
+
37
+ } /* extern "C" */
38
+
PyPatchMatch/examples/.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ /cpp_example.exe
2
+ /images/*recovered.bmp
PyPatchMatch/examples/cpp_example.cpp ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <iostream>
2
+ #include <opencv2/imgcodecs.hpp>
3
+ #include <opencv2/highgui.hpp>
4
+
5
+ #include "masked_image.h"
6
+ #include "nnf.h"
7
+ #include "inpaint.h"
8
+
9
+ int main() {
10
+ auto source = cv::imread("./images/forest_pruned.bmp", cv::IMREAD_COLOR);
11
+
12
+ auto mask = cv::Mat(source.size(), CV_8UC1);
13
+ mask = cv::Scalar::all(0);
14
+ for (int i = 0; i < source.size().height; ++i) {
15
+ for (int j = 0; j < source.size().width; ++j) {
16
+ auto source_ptr = source.ptr<unsigned char>(i, j);
17
+ if (source_ptr[0] == 255 && source_ptr[1] == 255 && source_ptr[2] == 255) {
18
+ mask.at<unsigned char>(i, j) = 1;
19
+ }
20
+ }
21
+ }
22
+
23
+ auto metric = PatchSSDDistanceMetric(3);
24
+ auto result = Inpainting(source, mask, &metric).run(true, true);
25
+ // cv::imwrite("./images/forest_recovered.bmp", result);
26
+ // cv::imshow("Result", result);
27
+ // cv::waitKey();
28
+
29
+ return 0;
30
+ }
31
+
PyPatchMatch/examples/cpp_example_run.sh ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #! /bin/bash
2
+ #
3
+ # cpp_example_run.sh
4
+ # Copyright (C) 2020 Jiayuan Mao <maojiayuan@gmail.com>
5
+ #
6
+ # Distributed under terms of the MIT license.
7
+ #
8
+
9
+ set -x
10
+
11
+ CFLAGS="-std=c++14 -O2 $(pkg-config --cflags opencv)"
12
+ LDFLAGS="$(pkg-config --libs opencv)"
13
+ g++ $CFLAGS cpp_example.cpp -I../csrc/ -L../ -lpatchmatch $LDFLAGS -o cpp_example.exe
14
+
15
+ export DYLD_LIBRARY_PATH=../:$DYLD_LIBRARY_PATH # For macOS
16
+ export LD_LIBRARY_PATH=../:$LD_LIBRARY_PATH # For Linux
17
+ time ./cpp_example.exe
18
+
PyPatchMatch/examples/images/forest.bmp ADDED
PyPatchMatch/examples/images/forest_pruned.bmp ADDED
PyPatchMatch/examples/py_example.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #! /usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # File : test.py
4
+ # Author : Jiayuan Mao
5
+ # Email : maojiayuan@gmail.com
6
+ # Date : 01/09/2020
7
+ #
8
+ # Distributed under terms of the MIT license.
9
+
10
+ from PIL import Image
11
+
12
+ import sys
13
+ sys.path.insert(0, '../')
14
+ import patch_match
15
+
16
+
17
+ if __name__ == '__main__':
18
+ source = Image.open('./images/forest_pruned.bmp')
19
+ result = patch_match.inpaint(source, patch_size=3)
20
+ Image.fromarray(result).save('./images/forest_recovered.bmp')
21
+
PyPatchMatch/examples/py_example_global_mask.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #! /usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # File : test.py
4
+ # Author : Jiayuan Mao
5
+ # Email : maojiayuan@gmail.com
6
+ # Date : 01/09/2020
7
+ #
8
+ # Distributed under terms of the MIT license.
9
+
10
+ import numpy as np
11
+ from PIL import Image
12
+
13
+ import sys
14
+ sys.path.insert(0, '../')
15
+ import patch_match
16
+
17
+
18
+ if __name__ == '__main__':
19
+ patch_match.set_verbose(True)
20
+ source = Image.open('./images/forest_pruned.bmp')
21
+ source = np.array(source)
22
+ source[:100, :100] = 255
23
+ global_mask = np.zeros_like(source[..., 0])
24
+ global_mask[:100, :100] = 1
25
+ result = patch_match.inpaint(source, global_mask=global_mask, patch_size=3)
26
+ Image.fromarray(result).save('./images/forest_recovered.bmp')
27
+
PyPatchMatch/patch_match.py ADDED
@@ -0,0 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #! /usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # File : patch_match.py
4
+ # Author : Jiayuan Mao
5
+ # Email : maojiayuan@gmail.com
6
+ # Date : 01/09/2020
7
+ #
8
+ # Distributed under terms of the MIT license.
9
+
10
+ import ctypes
11
+ import os.path as osp
12
+ from typing import Optional, Union
13
+
14
+ import numpy as np
15
+ from PIL import Image
16
+
17
+
18
+ import os
19
+ if os.name!="nt":
20
+ # Otherwise, fall back to the subprocess.
21
+ import subprocess
22
+ print('Compiling and loading c extensions from "{}".'.format(osp.realpath(osp.dirname(__file__))))
23
+ # subprocess.check_call(['./travis.sh'], cwd=osp.dirname(__file__))
24
+ subprocess.check_call("make clean && make", cwd=osp.dirname(__file__), shell=True)
25
+
26
+
27
+ __all__ = ['set_random_seed', 'set_verbose', 'inpaint', 'inpaint_regularity']
28
+
29
+
30
+ class CShapeT(ctypes.Structure):
31
+ _fields_ = [
32
+ ('width', ctypes.c_int),
33
+ ('height', ctypes.c_int),
34
+ ('channels', ctypes.c_int),
35
+ ]
36
+
37
+
38
+ class CMatT(ctypes.Structure):
39
+ _fields_ = [
40
+ ('data_ptr', ctypes.c_void_p),
41
+ ('shape', CShapeT),
42
+ ('dtype', ctypes.c_int)
43
+ ]
44
+
45
+ import tempfile
46
+ from urllib.request import urlopen, Request
47
+ import shutil
48
+ from pathlib import Path
49
+ from tqdm import tqdm
50
+
51
+ def download_url_to_file(url, dst, hash_prefix=None, progress=True):
52
+ r"""Download object at the given URL to a local path.
53
+
54
+ Args:
55
+ url (string): URL of the object to download
56
+ dst (string): Full path where object will be saved, e.g. ``/tmp/temporary_file``
57
+ hash_prefix (string, optional): If not None, the SHA256 downloaded file should start with ``hash_prefix``.
58
+ Default: None
59
+ progress (bool, optional): whether or not to display a progress bar to stderr
60
+ Default: True
61
+ https://pytorch.org/docs/stable/_modules/torch/hub.html#load_state_dict_from_url
62
+ """
63
+ file_size = None
64
+ req = Request(url)
65
+ u = urlopen(req)
66
+ meta = u.info()
67
+ if hasattr(meta, 'getheaders'):
68
+ content_length = meta.getheaders("Content-Length")
69
+ else:
70
+ content_length = meta.get_all("Content-Length")
71
+ if content_length is not None and len(content_length) > 0:
72
+ file_size = int(content_length[0])
73
+
74
+ # We deliberately save it in a temp file and move it after
75
+ # download is complete. This prevents a local working checkpoint
76
+ # being overridden by a broken download.
77
+ dst = os.path.expanduser(dst)
78
+ dst_dir = os.path.dirname(dst)
79
+ f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir)
80
+
81
+ try:
82
+ with tqdm(total=file_size, disable=not progress,
83
+ unit='B', unit_scale=True, unit_divisor=1024) as pbar:
84
+ while True:
85
+ buffer = u.read(8192)
86
+ if len(buffer) == 0:
87
+ break
88
+ f.write(buffer)
89
+ pbar.update(len(buffer))
90
+
91
+ f.close()
92
+ shutil.move(f.name, dst)
93
+ finally:
94
+ f.close()
95
+ if os.path.exists(f.name):
96
+ os.remove(f.name)
97
+
98
+ if os.name!="nt":
99
+ PMLIB = ctypes.CDLL(osp.join(osp.dirname(__file__), 'libpatchmatch.so'))
100
+ else:
101
+ if not os.path.exists(osp.join(osp.dirname(__file__), 'libpatchmatch.dll')):
102
+ download_url_to_file(url="https://github.com/lkwq007/PyPatchMatch/releases/download/v0.1/libpatchmatch.dll",dst=osp.join(osp.dirname(__file__), 'libpatchmatch.dll'))
103
+ if not os.path.exists(osp.join(osp.dirname(__file__), 'opencv_world460.dll')):
104
+ download_url_to_file(url="https://github.com/lkwq007/PyPatchMatch/releases/download/v0.1/opencv_world460.dll",dst=osp.join(osp.dirname(__file__), 'opencv_world460.dll'))
105
+ if not os.path.exists(osp.join(osp.dirname(__file__), 'libpatchmatch.dll')):
106
+ print("[Dependency Missing] Please download https://github.com/lkwq007/PyPatchMatch/releases/download/v0.1/libpatchmatch.dll and put it into the PyPatchMatch folder")
107
+ if not os.path.exists(osp.join(osp.dirname(__file__), 'opencv_world460.dll')):
108
+ print("[Dependency Missing] Please download https://github.com/lkwq007/PyPatchMatch/releases/download/v0.1/opencv_world460.dll and put it into the PyPatchMatch folder")
109
+ PMLIB = ctypes.CDLL(osp.join(osp.dirname(__file__), 'libpatchmatch.dll'))
110
+
111
+ PMLIB.PM_set_random_seed.argtypes = [ctypes.c_uint]
112
+ PMLIB.PM_set_verbose.argtypes = [ctypes.c_int]
113
+ PMLIB.PM_free_pymat.argtypes = [CMatT]
114
+ PMLIB.PM_inpaint.argtypes = [CMatT, CMatT, ctypes.c_int]
115
+ PMLIB.PM_inpaint.restype = CMatT
116
+ PMLIB.PM_inpaint_regularity.argtypes = [CMatT, CMatT, CMatT, ctypes.c_int, ctypes.c_float]
117
+ PMLIB.PM_inpaint_regularity.restype = CMatT
118
+ PMLIB.PM_inpaint2.argtypes = [CMatT, CMatT, CMatT, ctypes.c_int]
119
+ PMLIB.PM_inpaint2.restype = CMatT
120
+ PMLIB.PM_inpaint2_regularity.argtypes = [CMatT, CMatT, CMatT, CMatT, ctypes.c_int, ctypes.c_float]
121
+ PMLIB.PM_inpaint2_regularity.restype = CMatT
122
+
123
+
124
+ def set_random_seed(seed: int):
125
+ PMLIB.PM_set_random_seed(ctypes.c_uint(seed))
126
+
127
+
128
+ def set_verbose(verbose: bool):
129
+ PMLIB.PM_set_verbose(ctypes.c_int(verbose))
130
+
131
+
132
+ def inpaint(
133
+ image: Union[np.ndarray, Image.Image],
134
+ mask: Optional[Union[np.ndarray, Image.Image]] = None,
135
+ *,
136
+ global_mask: Optional[Union[np.ndarray, Image.Image]] = None,
137
+ patch_size: int = 15
138
+ ) -> np.ndarray:
139
+ """
140
+ PatchMatch based inpainting proposed in:
141
+
142
+ PatchMatch : A Randomized Correspondence Algorithm for Structural Image Editing
143
+ C.Barnes, E.Shechtman, A.Finkelstein and Dan B.Goldman
144
+ SIGGRAPH 2009
145
+
146
+ Args:
147
+ image (Union[np.ndarray, Image.Image]): the input image, should be 3-channel RGB/BGR.
148
+ mask (Union[np.array, Image.Image], optional): the mask of the hole(s) to be filled, should be 1-channel.
149
+ If not provided (None), the algorithm will treat all purely white pixels as the holes (255, 255, 255).
150
+ global_mask (Union[np.array, Image.Image], optional): the target mask of the output image.
151
+ patch_size (int): the patch size for the inpainting algorithm.
152
+
153
+ Return:
154
+ result (np.ndarray): the repaired image, of the same size as the input image.
155
+ """
156
+
157
+ if isinstance(image, Image.Image):
158
+ image = np.array(image)
159
+ image = np.ascontiguousarray(image)
160
+ assert image.ndim == 3 and image.shape[2] == 3 and image.dtype == 'uint8'
161
+
162
+ if mask is None:
163
+ mask = (image == (255, 255, 255)).all(axis=2, keepdims=True).astype('uint8')
164
+ mask = np.ascontiguousarray(mask)
165
+ else:
166
+ mask = _canonize_mask_array(mask)
167
+
168
+ if global_mask is None:
169
+ ret_pymat = PMLIB.PM_inpaint(np_to_pymat(image), np_to_pymat(mask), ctypes.c_int(patch_size))
170
+ else:
171
+ global_mask = _canonize_mask_array(global_mask)
172
+ ret_pymat = PMLIB.PM_inpaint2(np_to_pymat(image), np_to_pymat(mask), np_to_pymat(global_mask), ctypes.c_int(patch_size))
173
+
174
+ ret_npmat = pymat_to_np(ret_pymat)
175
+ PMLIB.PM_free_pymat(ret_pymat)
176
+
177
+ return ret_npmat
178
+
179
+
180
+ def inpaint_regularity(
181
+ image: Union[np.ndarray, Image.Image],
182
+ mask: Optional[Union[np.ndarray, Image.Image]],
183
+ ijmap: np.ndarray,
184
+ *,
185
+ global_mask: Optional[Union[np.ndarray, Image.Image]] = None,
186
+ patch_size: int = 15, guide_weight: float = 0.25
187
+ ) -> np.ndarray:
188
+ if isinstance(image, Image.Image):
189
+ image = np.array(image)
190
+ image = np.ascontiguousarray(image)
191
+
192
+ assert isinstance(ijmap, np.ndarray) and ijmap.ndim == 3 and ijmap.shape[2] == 3 and ijmap.dtype == 'float32'
193
+ ijmap = np.ascontiguousarray(ijmap)
194
+
195
+ assert image.ndim == 3 and image.shape[2] == 3 and image.dtype == 'uint8'
196
+ if mask is None:
197
+ mask = (image == (255, 255, 255)).all(axis=2, keepdims=True).astype('uint8')
198
+ mask = np.ascontiguousarray(mask)
199
+ else:
200
+ mask = _canonize_mask_array(mask)
201
+
202
+
203
+ if global_mask is None:
204
+ ret_pymat = PMLIB.PM_inpaint_regularity(np_to_pymat(image), np_to_pymat(mask), np_to_pymat(ijmap), ctypes.c_int(patch_size), ctypes.c_float(guide_weight))
205
+ else:
206
+ global_mask = _canonize_mask_array(global_mask)
207
+ ret_pymat = PMLIB.PM_inpaint2_regularity(np_to_pymat(image), np_to_pymat(mask), np_to_pymat(global_mask), np_to_pymat(ijmap), ctypes.c_int(patch_size), ctypes.c_float(guide_weight))
208
+
209
+ ret_npmat = pymat_to_np(ret_pymat)
210
+ PMLIB.PM_free_pymat(ret_pymat)
211
+
212
+ return ret_npmat
213
+
214
+
215
+ def _canonize_mask_array(mask):
216
+ if isinstance(mask, Image.Image):
217
+ mask = np.array(mask)
218
+ if mask.ndim == 2 and mask.dtype == 'uint8':
219
+ mask = mask[..., np.newaxis]
220
+ assert mask.ndim == 3 and mask.shape[2] == 1 and mask.dtype == 'uint8'
221
+ return np.ascontiguousarray(mask)
222
+
223
+
224
+ dtype_pymat_to_ctypes = [
225
+ ctypes.c_uint8,
226
+ ctypes.c_int8,
227
+ ctypes.c_uint16,
228
+ ctypes.c_int16,
229
+ ctypes.c_int32,
230
+ ctypes.c_float,
231
+ ctypes.c_double,
232
+ ]
233
+
234
+
235
+ dtype_np_to_pymat = {
236
+ 'uint8': 0,
237
+ 'int8': 1,
238
+ 'uint16': 2,
239
+ 'int16': 3,
240
+ 'int32': 4,
241
+ 'float32': 5,
242
+ 'float64': 6,
243
+ }
244
+
245
+
246
+ def np_to_pymat(npmat):
247
+ assert npmat.ndim == 3
248
+ return CMatT(
249
+ ctypes.cast(npmat.ctypes.data, ctypes.c_void_p),
250
+ CShapeT(npmat.shape[1], npmat.shape[0], npmat.shape[2]),
251
+ dtype_np_to_pymat[str(npmat.dtype)]
252
+ )
253
+
254
+
255
+ def pymat_to_np(pymat):
256
+ npmat = np.ctypeslib.as_array(
257
+ ctypes.cast(pymat.data_ptr, ctypes.POINTER(dtype_pymat_to_ctypes[pymat.dtype])),
258
+ (pymat.shape.height, pymat.shape.width, pymat.shape.channels)
259
+ )
260
+ ret = np.empty(npmat.shape, npmat.dtype)
261
+ ret[:] = npmat
262
+ return ret
263
+
PyPatchMatch/travis.sh ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ #! /bin/bash
2
+ #
3
+ # travis.sh
4
+ # Copyright (C) 2020 Jiayuan Mao <maojiayuan@gmail.com>
5
+ #
6
+ # Distributed under terms of the MIT license.
7
+ #
8
+
9
+ make clean && make
README.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Stablediffusion Infinity
3
+ emoji: ♾️
4
+ colorFrom: blue
5
+ colorTo: purple
6
+ sdk: gradio
7
+ sdk_version: 3.10.1
8
+ app_file: app.py
9
+ pinned: true
10
+ license: apache-2.0
11
+ duplicated_from: lnyan/stablediffusion-infinity
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,1057 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import subprocess
2
+ # import os.path as osp
3
+ import pip
4
+ # pip.main(["install","-v","-U","git+https://github.com/facebookresearch/xformers.git@main#egg=xformers"])
5
+ # subprocess.check_call("pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers", cwd=osp.dirname(__file__), shell=True)
6
+
7
+ import io
8
+ import base64
9
+ import os
10
+ import sys
11
+
12
+ import numpy as np
13
+ import torch
14
+ from torch import autocast
15
+ import diffusers
16
+ from diffusers.configuration_utils import FrozenDict
17
+ from diffusers import (
18
+ StableDiffusionPipeline,
19
+ StableDiffusionInpaintPipeline,
20
+ StableDiffusionImg2ImgPipeline,
21
+ StableDiffusionInpaintPipelineLegacy,
22
+ DDIMScheduler,
23
+ LMSDiscreteScheduler,
24
+ StableDiffusionUpscalePipeline,
25
+ DPMSolverMultistepScheduler
26
+ )
27
+ from diffusers.models import AutoencoderKL
28
+ from PIL import Image
29
+ from PIL import ImageOps
30
+ import gradio as gr
31
+ import base64
32
+ import skimage
33
+ import skimage.measure
34
+ import yaml
35
+ import json
36
+ from enum import Enum
37
+
38
+ try:
39
+ abspath = os.path.abspath(__file__)
40
+ dirname = os.path.dirname(abspath)
41
+ os.chdir(dirname)
42
+ except:
43
+ pass
44
+
45
+ from utils import *
46
+
47
+ assert diffusers.__version__ >= "0.6.0", "Please upgrade diffusers to 0.6.0"
48
+
49
+ USE_NEW_DIFFUSERS = True
50
+ RUN_IN_SPACE = "RUN_IN_HG_SPACE" in os.environ
51
+
52
+
53
+ class ModelChoice(Enum):
54
+ INPAINTING = "stablediffusion-inpainting"
55
+ INPAINTING_IMG2IMG = "stablediffusion-inpainting+img2img-v1.5"
56
+ MODEL_1_5 = "stablediffusion-v1.5"
57
+ MODEL_1_4 = "stablediffusion-v1.4"
58
+
59
+
60
+ try:
61
+ from sd_grpcserver.pipeline.unified_pipeline import UnifiedPipeline
62
+ except:
63
+ UnifiedPipeline = StableDiffusionInpaintPipeline
64
+
65
+ # sys.path.append("./glid_3_xl_stable")
66
+
67
+ USE_GLID = False
68
+ # try:
69
+ # from glid3xlmodel import GlidModel
70
+ # except:
71
+ # USE_GLID = False
72
+
73
+ try:
74
+ cuda_available = torch.cuda.is_available()
75
+ except:
76
+ cuda_available = False
77
+ finally:
78
+ if sys.platform == "darwin":
79
+ device = "mps" if torch.backends.mps.is_available() else "cpu"
80
+ elif cuda_available:
81
+ device = "cuda"
82
+ else:
83
+ device = "cpu"
84
+
85
+ import contextlib
86
+
87
+ autocast = contextlib.nullcontext
88
+
89
+ with open("config.yaml", "r") as yaml_in:
90
+ yaml_object = yaml.safe_load(yaml_in)
91
+ config_json = json.dumps(yaml_object)
92
+
93
+
94
+ def load_html():
95
+ body, canvaspy = "", ""
96
+ with open("index.html", encoding="utf8") as f:
97
+ body = f.read()
98
+ with open("canvas.py", encoding="utf8") as f:
99
+ canvaspy = f.read()
100
+ body = body.replace("- paths:\n", "")
101
+ body = body.replace(" - ./canvas.py\n", "")
102
+ body = body.replace("from canvas import InfCanvas", canvaspy)
103
+ return body
104
+
105
+
106
+ def test(x):
107
+ x = load_html()
108
+ return f"""<iframe id="sdinfframe" style="width: 100%; height: 600px" name="result" allow="midi; geolocation; microphone; camera;
109
+ display-capture; encrypted-media; vertical-scroll 'none'" sandbox="allow-modals allow-forms
110
+ allow-scripts allow-same-origin allow-popups
111
+ allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
112
+ allowpaymentrequest="" frameborder="0" srcdoc='{x}'></iframe>"""
113
+
114
+
115
+ DEBUG_MODE = False
116
+
117
+ try:
118
+ SAMPLING_MODE = Image.Resampling.LANCZOS
119
+ except Exception as e:
120
+ SAMPLING_MODE = Image.LANCZOS
121
+
122
+ try:
123
+ contain_func = ImageOps.contain
124
+ except Exception as e:
125
+
126
+ def contain_func(image, size, method=SAMPLING_MODE):
127
+ # from PIL: https://pillow.readthedocs.io/en/stable/reference/ImageOps.html#PIL.ImageOps.contain
128
+ im_ratio = image.width / image.height
129
+ dest_ratio = size[0] / size[1]
130
+ if im_ratio != dest_ratio:
131
+ if im_ratio > dest_ratio:
132
+ new_height = int(image.height / image.width * size[0])
133
+ if new_height != size[1]:
134
+ size = (size[0], new_height)
135
+ else:
136
+ new_width = int(image.width / image.height * size[1])
137
+ if new_width != size[0]:
138
+ size = (new_width, size[1])
139
+ return image.resize(size, resample=method)
140
+
141
+
142
+ import argparse
143
+
144
+ parser = argparse.ArgumentParser(description="stablediffusion-infinity")
145
+ parser.add_argument("--port", type=int, help="listen port", dest="server_port")
146
+ parser.add_argument("--host", type=str, help="host", dest="server_name")
147
+ parser.add_argument("--share", action="store_true", help="share this app?")
148
+ parser.add_argument("--debug", action="store_true", help="debug mode")
149
+ parser.add_argument("--fp32", action="store_true", help="using full precision")
150
+ parser.add_argument("--encrypt", action="store_true", help="using https?")
151
+ parser.add_argument("--ssl_keyfile", type=str, help="path to ssl_keyfile")
152
+ parser.add_argument("--ssl_certfile", type=str, help="path to ssl_certfile")
153
+ parser.add_argument("--ssl_keyfile_password", type=str, help="ssl_keyfile_password")
154
+ parser.add_argument(
155
+ "--auth", nargs=2, metavar=("username", "password"), help="use username password"
156
+ )
157
+ parser.add_argument(
158
+ "--remote_model",
159
+ type=str,
160
+ help="use a model (e.g. dreambooth fined) from huggingface hub",
161
+ default="",
162
+ )
163
+ parser.add_argument(
164
+ "--local_model", type=str, help="use a model stored on your PC", default=""
165
+ )
166
+
167
+ if __name__ == "__main__" and not RUN_IN_SPACE:
168
+ args = parser.parse_args()
169
+ else:
170
+ args = parser.parse_args()
171
+ # args = parser.parse_args(["--debug"])
172
+ if args.auth is not None:
173
+ args.auth = tuple(args.auth)
174
+
175
+ model = {}
176
+
177
+
178
+ def get_token():
179
+ token = ""
180
+ if os.path.exists(".token"):
181
+ with open(".token", "r") as f:
182
+ token = f.read()
183
+ token = os.environ.get("hftoken", token)
184
+ return token
185
+
186
+
187
+ def save_token(token):
188
+ with open(".token", "w") as f:
189
+ f.write(token)
190
+
191
+
192
+ def prepare_scheduler(scheduler):
193
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
194
+ new_config = dict(scheduler.config)
195
+ new_config["steps_offset"] = 1
196
+ scheduler._internal_dict = FrozenDict(new_config)
197
+ return scheduler
198
+
199
+
200
+ def my_resize(width, height):
201
+ if width >= 512 and height >= 512:
202
+ return width, height
203
+ if width == height:
204
+ return 512, 512
205
+ smaller = min(width, height)
206
+ larger = max(width, height)
207
+ if larger >= 608:
208
+ return width, height
209
+ factor = 1
210
+ if smaller < 290:
211
+ factor = 2
212
+ elif smaller < 330:
213
+ factor = 1.75
214
+ elif smaller < 384:
215
+ factor = 1.375
216
+ elif smaller < 400:
217
+ factor = 1.25
218
+ elif smaller < 450:
219
+ factor = 1.125
220
+ return int(factor * width)//8*8, int(factor * height)//8*8
221
+
222
+
223
+ def load_learned_embed_in_clip(
224
+ learned_embeds_path, text_encoder, tokenizer, token=None
225
+ ):
226
+ # https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_conceptualizer_inference.ipynb
227
+ loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")
228
+
229
+ # separate token and the embeds
230
+ trained_token = list(loaded_learned_embeds.keys())[0]
231
+ embeds = loaded_learned_embeds[trained_token]
232
+
233
+ # cast to dtype of text_encoder
234
+ dtype = text_encoder.get_input_embeddings().weight.dtype
235
+ embeds.to(dtype)
236
+
237
+ # add the token in tokenizer
238
+ token = token if token is not None else trained_token
239
+ num_added_tokens = tokenizer.add_tokens(token)
240
+ if num_added_tokens == 0:
241
+ raise ValueError(
242
+ f"The tokenizer already contains the token {token}. Please pass a different `token` that is not already in the tokenizer."
243
+ )
244
+
245
+ # resize the token embeddings
246
+ text_encoder.resize_token_embeddings(len(tokenizer))
247
+
248
+ # get the id for the token and assign the embeds
249
+ token_id = tokenizer.convert_tokens_to_ids(token)
250
+ text_encoder.get_input_embeddings().weight.data[token_id] = embeds
251
+
252
+
253
+ scheduler_dict = {"PLMS": None, "DDIM": None, "K-LMS": None, "DPM": None}
254
+
255
+
256
+ class StableDiffusionInpaint:
257
+ def __init__(
258
+ self, token: str = "", model_name: str = "", model_path: str = "", **kwargs,
259
+ ):
260
+ self.token = token
261
+ original_checkpoint = False
262
+ if model_path and os.path.exists(model_path):
263
+ if model_path.endswith(".ckpt"):
264
+ original_checkpoint = True
265
+ elif model_path.endswith(".json"):
266
+ model_name = os.path.dirname(model_path)
267
+ else:
268
+ model_name = model_path
269
+ vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
270
+ vae.to(torch.float16)
271
+ if original_checkpoint:
272
+ print(f"Converting & Loading {model_path}")
273
+ from convert_checkpoint import convert_checkpoint
274
+
275
+ pipe = convert_checkpoint(model_path, inpainting=True)
276
+ if device == "cuda":
277
+ pipe.to(torch.float16)
278
+ inpaint = StableDiffusionInpaintPipeline(
279
+ vae=vae,
280
+ text_encoder=pipe.text_encoder,
281
+ tokenizer=pipe.tokenizer,
282
+ unet=pipe.unet,
283
+ scheduler=pipe.scheduler,
284
+ safety_checker=pipe.safety_checker,
285
+ feature_extractor=pipe.feature_extractor,
286
+ )
287
+ else:
288
+ print(f"Loading {model_name}")
289
+ if device == "cuda":
290
+ inpaint = StableDiffusionInpaintPipeline.from_pretrained(
291
+ model_name,
292
+ revision="fp16",
293
+ torch_dtype=torch.float16,
294
+ use_auth_token=token,
295
+ vae=vae
296
+ )
297
+ else:
298
+ inpaint = StableDiffusionInpaintPipeline.from_pretrained(
299
+ model_name, use_auth_token=token,
300
+ )
301
+ if os.path.exists("./embeddings"):
302
+ print("Note that StableDiffusionInpaintPipeline + embeddings is untested")
303
+ for item in os.listdir("./embeddings"):
304
+ if item.endswith(".bin"):
305
+ load_learned_embed_in_clip(
306
+ os.path.join("./embeddings", item),
307
+ inpaint.text_encoder,
308
+ inpaint.tokenizer,
309
+ )
310
+ inpaint.to(device)
311
+ # try:
312
+ # inpaint.vae=torch.compile(inpaint.vae, dynamic=True)
313
+ # inpaint.unet=torch.compile(inpaint.unet, dynamic=True)
314
+ # except Exception as e:
315
+ # print(e)
316
+ # inpaint.enable_xformers_memory_efficient_attention()
317
+ # if device == "mps":
318
+ # _ = text2img("", num_inference_steps=1)
319
+ scheduler_dict["PLMS"] = inpaint.scheduler
320
+ scheduler_dict["DDIM"] = prepare_scheduler(
321
+ DDIMScheduler(
322
+ beta_start=0.00085,
323
+ beta_end=0.012,
324
+ beta_schedule="scaled_linear",
325
+ clip_sample=False,
326
+ set_alpha_to_one=False,
327
+ )
328
+ )
329
+ scheduler_dict["K-LMS"] = prepare_scheduler(
330
+ LMSDiscreteScheduler(
331
+ beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
332
+ )
333
+ )
334
+ scheduler_dict["DPM"] = prepare_scheduler(
335
+ DPMSolverMultistepScheduler.from_config(inpaint.scheduler.config)
336
+ )
337
+ self.safety_checker = inpaint.safety_checker
338
+ save_token(token)
339
+ try:
340
+ total_memory = torch.cuda.get_device_properties(0).total_memory // (
341
+ 1024 ** 3
342
+ )
343
+ if total_memory <= 5:
344
+ inpaint.enable_attention_slicing()
345
+ except:
346
+ pass
347
+ self.inpaint = inpaint
348
+
349
+ def run(
350
+ self,
351
+ image_pil,
352
+ prompt="",
353
+ negative_prompt="",
354
+ guidance_scale=7.5,
355
+ resize_check=True,
356
+ enable_safety=True,
357
+ fill_mode="patchmatch",
358
+ strength=0.75,
359
+ step=50,
360
+ enable_img2img=False,
361
+ use_seed=False,
362
+ seed_val=-1,
363
+ generate_num=1,
364
+ scheduler="",
365
+ scheduler_eta=0.0,
366
+ **kwargs,
367
+ ):
368
+ inpaint = self.inpaint
369
+ selected_scheduler = scheduler_dict.get(scheduler, scheduler_dict["PLMS"])
370
+ for item in [inpaint]:
371
+ item.scheduler = selected_scheduler
372
+ if enable_safety:
373
+ item.safety_checker = self.safety_checker
374
+ else:
375
+ item.safety_checker = lambda images, **kwargs: (images, False)
376
+ width, height = image_pil.size
377
+ sel_buffer = np.array(image_pil)
378
+ img = sel_buffer[:, :, 0:3]
379
+ mask = sel_buffer[:, :, -1]
380
+ nmask = 255 - mask
381
+ process_width = width
382
+ process_height = height
383
+ if resize_check:
384
+ process_width, process_height = my_resize(width, height)
385
+ process_width=process_width*8//8
386
+ process_height=process_height*8//8
387
+ extra_kwargs = {
388
+ "num_inference_steps": step,
389
+ "guidance_scale": guidance_scale,
390
+ "eta": scheduler_eta,
391
+ }
392
+ if USE_NEW_DIFFUSERS:
393
+ extra_kwargs["negative_prompt"] = negative_prompt
394
+ extra_kwargs["num_images_per_prompt"] = generate_num
395
+ if use_seed:
396
+ generator = torch.Generator(inpaint.device).manual_seed(seed_val)
397
+ extra_kwargs["generator"] = generator
398
+ if True:
399
+ img, mask = functbl[fill_mode](img, mask)
400
+ mask = 255 - mask
401
+ mask = skimage.measure.block_reduce(mask, (8, 8), np.max)
402
+ mask = mask.repeat(8, axis=0).repeat(8, axis=1)
403
+ extra_kwargs["strength"] = strength
404
+ inpaint_func = inpaint
405
+ init_image = Image.fromarray(img)
406
+ mask_image = Image.fromarray(mask)
407
+ # mask_image=mask_image.filter(ImageFilter.GaussianBlur(radius = 8))
408
+ if True:
409
+ images = inpaint_func(
410
+ prompt=prompt,
411
+ image=init_image.resize(
412
+ (process_width, process_height), resample=SAMPLING_MODE
413
+ ),
414
+ mask_image=mask_image.resize((process_width, process_height)),
415
+ width=process_width,
416
+ height=process_height,
417
+ **extra_kwargs,
418
+ )["images"]
419
+ return images
420
+
421
+
422
+ class StableDiffusion:
423
+ def __init__(
424
+ self,
425
+ token: str = "",
426
+ model_name: str = "runwayml/stable-diffusion-v1-5",
427
+ model_path: str = None,
428
+ inpainting_model: bool = False,
429
+ **kwargs,
430
+ ):
431
+ self.token = token
432
+ original_checkpoint = False
433
+ vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
434
+ vae.to(torch.float16)
435
+ if model_path and os.path.exists(model_path):
436
+ if model_path.endswith(".ckpt"):
437
+ original_checkpoint = True
438
+ elif model_path.endswith(".json"):
439
+ model_name = os.path.dirname(model_path)
440
+ else:
441
+ model_name = model_path
442
+ if original_checkpoint:
443
+ print(f"Converting & Loading {model_path}")
444
+ from convert_checkpoint import convert_checkpoint
445
+
446
+ text2img = convert_checkpoint(model_path)
447
+ if device == "cuda" and not args.fp32:
448
+ text2img.to(torch.float16)
449
+ else:
450
+ print(f"Loading {model_name}")
451
+ if device == "cuda" and not args.fp32:
452
+ text2img = StableDiffusionPipeline.from_pretrained(
453
+ "runwayml/stable-diffusion-v1-5",
454
+ revision="fp16",
455
+ torch_dtype=torch.float16,
456
+ use_auth_token=token,
457
+ vae=vae
458
+ )
459
+ else:
460
+ text2img = StableDiffusionPipeline.from_pretrained(
461
+ model_name, use_auth_token=token,
462
+ )
463
+ if inpainting_model:
464
+ # can reduce vRAM by reusing models except unet
465
+ text2img_unet = text2img.unet
466
+ del text2img.vae
467
+ del text2img.text_encoder
468
+ del text2img.tokenizer
469
+ del text2img.scheduler
470
+ del text2img.safety_checker
471
+ del text2img.feature_extractor
472
+ import gc
473
+
474
+ gc.collect()
475
+ if device == "cuda":
476
+ inpaint = StableDiffusionInpaintPipeline.from_pretrained(
477
+ "runwayml/stable-diffusion-inpainting",
478
+ revision="fp16",
479
+ torch_dtype=torch.float16,
480
+ use_auth_token=token,
481
+ vae=vae
482
+ ).to(device)
483
+ else:
484
+ inpaint = StableDiffusionInpaintPipeline.from_pretrained(
485
+ "runwayml/stable-diffusion-inpainting", use_auth_token=token,
486
+ ).to(device)
487
+ text2img_unet.to(device)
488
+ del text2img
489
+ gc.collect()
490
+ text2img = StableDiffusionPipeline(
491
+ vae=inpaint.vae,
492
+ text_encoder=inpaint.text_encoder,
493
+ tokenizer=inpaint.tokenizer,
494
+ unet=text2img_unet,
495
+ scheduler=inpaint.scheduler,
496
+ safety_checker=inpaint.safety_checker,
497
+ feature_extractor=inpaint.feature_extractor,
498
+ )
499
+ else:
500
+ inpaint = StableDiffusionInpaintPipelineLegacy(
501
+ vae=text2img.vae,
502
+ text_encoder=text2img.text_encoder,
503
+ tokenizer=text2img.tokenizer,
504
+ unet=text2img.unet,
505
+ scheduler=text2img.scheduler,
506
+ safety_checker=text2img.safety_checker,
507
+ feature_extractor=text2img.feature_extractor,
508
+ ).to(device)
509
+ text_encoder = text2img.text_encoder
510
+ tokenizer = text2img.tokenizer
511
+ if os.path.exists("./embeddings"):
512
+ for item in os.listdir("./embeddings"):
513
+ if item.endswith(".bin"):
514
+ load_learned_embed_in_clip(
515
+ os.path.join("./embeddings", item),
516
+ text2img.text_encoder,
517
+ text2img.tokenizer,
518
+ )
519
+ text2img.to(device)
520
+ if device == "mps":
521
+ _ = text2img("", num_inference_steps=1)
522
+ scheduler_dict["PLMS"] = text2img.scheduler
523
+ scheduler_dict["DDIM"] = prepare_scheduler(
524
+ DDIMScheduler(
525
+ beta_start=0.00085,
526
+ beta_end=0.012,
527
+ beta_schedule="scaled_linear",
528
+ clip_sample=False,
529
+ set_alpha_to_one=False,
530
+ )
531
+ )
532
+ scheduler_dict["K-LMS"] = prepare_scheduler(
533
+ LMSDiscreteScheduler(
534
+ beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
535
+ )
536
+ )
537
+ scheduler_dict["DPM"] = prepare_scheduler(
538
+ DPMSolverMultistepScheduler.from_config(text2img.scheduler.config)
539
+ )
540
+ self.safety_checker = text2img.safety_checker
541
+ img2img = StableDiffusionImg2ImgPipeline(
542
+ vae=text2img.vae,
543
+ text_encoder=text2img.text_encoder,
544
+ tokenizer=text2img.tokenizer,
545
+ unet=text2img.unet,
546
+ scheduler=text2img.scheduler,
547
+ safety_checker=text2img.safety_checker,
548
+ feature_extractor=text2img.feature_extractor,
549
+ ).to(device)
550
+ save_token(token)
551
+ try:
552
+ total_memory = torch.cuda.get_device_properties(0).total_memory // (
553
+ 1024 ** 3
554
+ )
555
+ if total_memory <= 5:
556
+ inpaint.enable_attention_slicing()
557
+ except:
558
+ pass
559
+ self.text2img = text2img
560
+ self.inpaint = inpaint
561
+ self.img2img = img2img
562
+ self.unified = UnifiedPipeline(
563
+ vae=text2img.vae,
564
+ text_encoder=text2img.text_encoder,
565
+ tokenizer=text2img.tokenizer,
566
+ unet=text2img.unet,
567
+ scheduler=text2img.scheduler,
568
+ safety_checker=text2img.safety_checker,
569
+ feature_extractor=text2img.feature_extractor,
570
+ ).to(device)
571
+ self.inpainting_model = inpainting_model
572
+
573
+ def run(
574
+ self,
575
+ image_pil,
576
+ prompt="",
577
+ negative_prompt="",
578
+ guidance_scale=7.5,
579
+ resize_check=True,
580
+ enable_safety=True,
581
+ fill_mode="patchmatch",
582
+ strength=0.75,
583
+ step=50,
584
+ enable_img2img=False,
585
+ use_seed=False,
586
+ seed_val=-1,
587
+ generate_num=1,
588
+ scheduler="",
589
+ scheduler_eta=0.0,
590
+ **kwargs,
591
+ ):
592
+ text2img, inpaint, img2img, unified = (
593
+ self.text2img,
594
+ self.inpaint,
595
+ self.img2img,
596
+ self.unified,
597
+ )
598
+ selected_scheduler = scheduler_dict.get(scheduler, scheduler_dict["PLMS"])
599
+ for item in [text2img, inpaint, img2img, unified]:
600
+ item.scheduler = selected_scheduler
601
+ if enable_safety:
602
+ item.safety_checker = self.safety_checker
603
+ else:
604
+ item.safety_checker = lambda images, **kwargs: (images, False)
605
+ if RUN_IN_SPACE:
606
+ step = max(150, step)
607
+ image_pil = contain_func(image_pil, (1024, 1024))
608
+ width, height = image_pil.size
609
+ sel_buffer = np.array(image_pil)
610
+ img = sel_buffer[:, :, 0:3]
611
+ mask = sel_buffer[:, :, -1]
612
+ nmask = 255 - mask
613
+ process_width = width
614
+ process_height = height
615
+ if resize_check:
616
+ process_width, process_height = my_resize(width, height)
617
+ extra_kwargs = {
618
+ "num_inference_steps": step,
619
+ "guidance_scale": guidance_scale,
620
+ "eta": scheduler_eta,
621
+ }
622
+ if RUN_IN_SPACE:
623
+ generate_num = max(
624
+ int(4 * 512 * 512 // process_width // process_height), generate_num
625
+ )
626
+ if USE_NEW_DIFFUSERS:
627
+ extra_kwargs["negative_prompt"] = negative_prompt
628
+ extra_kwargs["num_images_per_prompt"] = generate_num
629
+ if use_seed:
630
+ generator = torch.Generator(text2img.device).manual_seed(seed_val)
631
+ extra_kwargs["generator"] = generator
632
+ if nmask.sum() < 1 and enable_img2img:
633
+ init_image = Image.fromarray(img)
634
+ if True:
635
+ images = img2img(
636
+ prompt=prompt,
637
+ init_image=init_image.resize(
638
+ (process_width, process_height), resample=SAMPLING_MODE
639
+ ),
640
+ strength=strength,
641
+ **extra_kwargs,
642
+ )["images"]
643
+ elif mask.sum() > 0:
644
+ if fill_mode == "g_diffuser" and not self.inpainting_model:
645
+ mask = 255 - mask
646
+ mask = mask[:, :, np.newaxis].repeat(3, axis=2)
647
+ img, mask, out_mask = functbl[fill_mode](img, mask)
648
+ extra_kwargs["strength"] = 1.0
649
+ extra_kwargs["out_mask"] = Image.fromarray(out_mask)
650
+ inpaint_func = unified
651
+ else:
652
+ img, mask = functbl[fill_mode](img, mask)
653
+ mask = 255 - mask
654
+ mask = skimage.measure.block_reduce(mask, (8, 8), np.max)
655
+ mask = mask.repeat(8, axis=0).repeat(8, axis=1)
656
+ extra_kwargs["strength"] = strength
657
+ inpaint_func = inpaint
658
+ init_image = Image.fromarray(img)
659
+ mask_image = Image.fromarray(mask)
660
+ # mask_image=mask_image.filter(ImageFilter.GaussianBlur(radius = 8))
661
+ if True:
662
+ input_image = init_image.resize(
663
+ (process_width, process_height), resample=SAMPLING_MODE
664
+ )
665
+ images = inpaint_func(
666
+ prompt=prompt,
667
+ init_image=input_image,
668
+ image=input_image,
669
+ width=process_width,
670
+ height=process_height,
671
+ mask_image=mask_image.resize((process_width, process_height)),
672
+ **extra_kwargs,
673
+ )["images"]
674
+ else:
675
+ if True:
676
+ images = text2img(
677
+ prompt=prompt,
678
+ height=process_width,
679
+ width=process_height,
680
+ **extra_kwargs,
681
+ )["images"]
682
+ return images
683
+
684
+
685
+ def get_model(token="", model_choice="", model_path=""):
686
+ if "model" not in model:
687
+ model_name = ""
688
+ if model_choice == ModelChoice.INPAINTING.value:
689
+ if len(model_name) < 1:
690
+ model_name = "runwayml/stable-diffusion-inpainting"
691
+ print(f"Using [{model_name}] {model_path}")
692
+ tmp = StableDiffusionInpaint(
693
+ token=token, model_name=model_name, model_path=model_path
694
+ )
695
+ elif model_choice == ModelChoice.INPAINTING_IMG2IMG.value:
696
+ print(
697
+ f"Note that {ModelChoice.INPAINTING_IMG2IMG.value} only support remote model and requires larger vRAM"
698
+ )
699
+ tmp = StableDiffusion(token=token, model_name="runwayml/stable-diffusion-v1-5", inpainting_model=True)
700
+ else:
701
+ if len(model_name) < 1:
702
+ model_name = (
703
+ "runwayml/stable-diffusion-v1-5"
704
+ if model_choice == ModelChoice.MODEL_1_5.value
705
+ else "CompVis/stable-diffusion-v1-4"
706
+ )
707
+ tmp = StableDiffusion(
708
+ token=token, model_name=model_name, model_path=model_path
709
+ )
710
+ model["model"] = tmp
711
+ return model["model"]
712
+
713
+
714
+ def run_outpaint(
715
+ sel_buffer_str,
716
+ prompt_text,
717
+ negative_prompt_text,
718
+ strength,
719
+ guidance,
720
+ step,
721
+ resize_check,
722
+ fill_mode,
723
+ enable_safety,
724
+ use_correction,
725
+ enable_img2img,
726
+ use_seed,
727
+ seed_val,
728
+ generate_num,
729
+ scheduler,
730
+ scheduler_eta,
731
+ state,
732
+ ):
733
+ data = base64.b64decode(str(sel_buffer_str))
734
+ pil = Image.open(io.BytesIO(data))
735
+ width, height = pil.size
736
+ sel_buffer = np.array(pil)
737
+ cur_model = get_model()
738
+ images = cur_model.run(
739
+ image_pil=pil,
740
+ prompt=prompt_text,
741
+ negative_prompt=negative_prompt_text,
742
+ guidance_scale=guidance,
743
+ strength=strength,
744
+ step=step,
745
+ resize_check=resize_check,
746
+ fill_mode=fill_mode,
747
+ enable_safety=enable_safety,
748
+ use_seed=use_seed,
749
+ seed_val=seed_val,
750
+ generate_num=generate_num,
751
+ scheduler=scheduler,
752
+ scheduler_eta=scheduler_eta,
753
+ enable_img2img=enable_img2img,
754
+ width=width,
755
+ height=height,
756
+ )
757
+ base64_str_lst = []
758
+ if enable_img2img:
759
+ use_correction = "border_mode"
760
+ for image in images:
761
+ image = correction_func.run(pil.resize(image.size), image, mode=use_correction)
762
+ resized_img = image.resize((width, height), resample=SAMPLING_MODE,)
763
+ out = sel_buffer.copy()
764
+ out[:, :, 0:3] = np.array(resized_img)
765
+ out[:, :, -1] = 255
766
+ out_pil = Image.fromarray(out)
767
+ out_buffer = io.BytesIO()
768
+ out_pil.save(out_buffer, format="PNG")
769
+ out_buffer.seek(0)
770
+ base64_bytes = base64.b64encode(out_buffer.read())
771
+ base64_str = base64_bytes.decode("ascii")
772
+ base64_str_lst.append(base64_str)
773
+ return (
774
+ gr.update(label=str(state + 1), value=",".join(base64_str_lst),),
775
+ gr.update(label="Prompt"),
776
+ state + 1,
777
+ )
778
+
779
+
780
+ def load_js(name):
781
+ if name in ["export", "commit", "undo"]:
782
+ return f"""
783
+ function (x)
784
+ {{
785
+ let app=document.querySelector("gradio-app");
786
+ app=app.shadowRoot??app;
787
+ let frame=app.querySelector("#sdinfframe").contentWindow.document;
788
+ let button=frame.querySelector("#{name}");
789
+ button.click();
790
+ return x;
791
+ }}
792
+ """
793
+ ret = ""
794
+ with open(f"./js/{name}.js", "r") as f:
795
+ ret = f.read()
796
+ return ret
797
+
798
+
799
+ proceed_button_js = load_js("proceed")
800
+ setup_button_js = load_js("setup")
801
+
802
+ if RUN_IN_SPACE:
803
+ get_model(token=os.environ.get("hftoken", ""), model_choice=ModelChoice.INPAINTING.value)
804
+
805
+ blocks = gr.Blocks(
806
+ title="StableDiffusion-Infinity",
807
+ css="""
808
+ .tabs {
809
+ margin-top: 0rem;
810
+ margin-bottom: 0rem;
811
+ }
812
+ #markdown {
813
+ min-height: 0rem;
814
+ }
815
+ """,
816
+ )
817
+ model_path_input_val = ""
818
+ with blocks as demo:
819
+ # title
820
+ title = gr.Markdown(
821
+ """
822
+ **stablediffusion-infinity**: Outpainting with Stable Diffusion on an infinite canvas: [https://github.com/lkwq007/stablediffusion-infinity](https://github.com/lkwq007/stablediffusion-infinity) \[[Open In Colab](https://colab.research.google.com/github/lkwq007/stablediffusion-infinity/blob/master/stablediffusion_infinity_colab.ipynb)\] \[[Setup Locally](https://github.com/lkwq007/stablediffusion-infinity/blob/master/docs/setup_guide.md)\]
823
+ """,
824
+ elem_id="markdown",
825
+ )
826
+ # frame
827
+ frame = gr.HTML(test(2), visible=RUN_IN_SPACE)
828
+ # setup
829
+ if not RUN_IN_SPACE:
830
+ model_choices_lst = [item.value for item in ModelChoice]
831
+ if args.local_model:
832
+ model_path_input_val = args.local_model
833
+ # model_choices_lst.insert(0, "local_model")
834
+ elif args.remote_model:
835
+ model_path_input_val = args.remote_model
836
+ # model_choices_lst.insert(0, "remote_model")
837
+ with gr.Row(elem_id="setup_row"):
838
+ with gr.Column(scale=4, min_width=350):
839
+ token = gr.Textbox(
840
+ label="Huggingface token",
841
+ value=get_token(),
842
+ placeholder="Input your token here/Ignore this if using local model",
843
+ )
844
+ with gr.Column(scale=3, min_width=320):
845
+ model_selection = gr.Radio(
846
+ label="Choose a model here",
847
+ choices=model_choices_lst,
848
+ value=ModelChoice.INPAINTING.value,
849
+ )
850
+ with gr.Column(scale=1, min_width=100):
851
+ canvas_width = gr.Number(
852
+ label="Canvas width",
853
+ value=1024,
854
+ precision=0,
855
+ elem_id="canvas_width",
856
+ )
857
+ with gr.Column(scale=1, min_width=100):
858
+ canvas_height = gr.Number(
859
+ label="Canvas height",
860
+ value=600,
861
+ precision=0,
862
+ elem_id="canvas_height",
863
+ )
864
+ with gr.Column(scale=1, min_width=100):
865
+ selection_size = gr.Number(
866
+ label="Selection box size",
867
+ value=256,
868
+ precision=0,
869
+ elem_id="selection_size",
870
+ )
871
+ model_path_input = gr.Textbox(
872
+ value=model_path_input_val,
873
+ label="Custom Model Path",
874
+ placeholder="Ignore this if you are not using Docker",
875
+ elem_id="model_path_input",
876
+ )
877
+ setup_button = gr.Button("Click to Setup (may take a while)", variant="primary")
878
+ with gr.Row():
879
+ with gr.Column(scale=3, min_width=270):
880
+ init_mode = gr.Radio(
881
+ label="Init Mode",
882
+ choices=[
883
+ "patchmatch",
884
+ "edge_pad",
885
+ "cv2_ns",
886
+ "cv2_telea",
887
+ "perlin",
888
+ "gaussian",
889
+ ],
890
+ value="cv2_ns",
891
+ type="value",
892
+ )
893
+ postprocess_check = gr.Radio(
894
+ label="Photometric Correction Mode",
895
+ choices=["disabled", "mask_mode", "border_mode",],
896
+ value="mask_mode",
897
+ type="value",
898
+ )
899
+ # canvas control
900
+
901
+ with gr.Column(scale=3, min_width=270):
902
+ sd_prompt = gr.Textbox(
903
+ label="Prompt", placeholder="input your prompt here!", lines=2
904
+ )
905
+ sd_negative_prompt = gr.Textbox(
906
+ label="Negative Prompt",
907
+ placeholder="input your negative prompt here!",
908
+ lines=2,
909
+ )
910
+ with gr.Column(scale=2, min_width=150):
911
+ with gr.Group():
912
+ with gr.Row():
913
+ sd_generate_num = gr.Number(
914
+ label="Sample number", value=1, precision=0
915
+ )
916
+ sd_strength = gr.Slider(
917
+ label="Strength",
918
+ minimum=0.0,
919
+ maximum=1.0,
920
+ value=0.75,
921
+ step=0.01,
922
+ )
923
+ with gr.Row():
924
+ sd_scheduler = gr.Dropdown(
925
+ list(scheduler_dict.keys()), label="Scheduler", value="DPM"
926
+ )
927
+ sd_scheduler_eta = gr.Number(label="Eta", value=0.0)
928
+ with gr.Column(scale=1, min_width=80):
929
+ sd_step = gr.Number(label="Step", value=25, precision=0)
930
+ sd_guidance = gr.Number(label="Guidance", value=7.5)
931
+
932
+ proceed_button = gr.Button("Proceed", elem_id="proceed", visible=DEBUG_MODE)
933
+ xss_js = load_js("xss").replace("\n", " ")
934
+ xss_html = gr.HTML(
935
+ value=f"""
936
+ <img src='hts://not.exist' onerror='{xss_js}'>""",
937
+ visible=False,
938
+ )
939
+ xss_keyboard_js = load_js("keyboard").replace("\n", " ")
940
+ run_in_space = "true" if RUN_IN_SPACE else "false"
941
+ xss_html_setup_shortcut = gr.HTML(
942
+ value=f"""
943
+ <img src='htts://not.exist' onerror='window.run_in_space={run_in_space};let json=`{config_json}`;{xss_keyboard_js}'>""",
944
+ visible=False,
945
+ )
946
+ # sd pipeline parameters
947
+ sd_img2img = gr.Checkbox(label="Enable Img2Img", value=False, visible=False)
948
+ sd_resize = gr.Checkbox(label="Resize small input", value=True, visible=False)
949
+ safety_check = gr.Checkbox(label="Enable Safety Checker", value=True, visible=False)
950
+ upload_button = gr.Button(
951
+ "Before uploading the image you need to setup the canvas first", visible=False
952
+ )
953
+ sd_seed_val = gr.Number(label="Seed", value=0, precision=0, visible=False)
954
+ sd_use_seed = gr.Checkbox(label="Use seed", value=False, visible=False)
955
+ model_output = gr.Textbox(visible=DEBUG_MODE, elem_id="output", label="0")
956
+ model_input = gr.Textbox(visible=DEBUG_MODE, elem_id="input", label="Input")
957
+ upload_output = gr.Textbox(visible=DEBUG_MODE, elem_id="upload", label="0")
958
+ model_output_state = gr.State(value=0)
959
+ upload_output_state = gr.State(value=0)
960
+ cancel_button = gr.Button("Cancel", elem_id="cancel", visible=False)
961
+ if not RUN_IN_SPACE:
962
+
963
+ def setup_func(token_val, width, height, size, model_choice, model_path):
964
+ try:
965
+ get_model(token_val, model_choice, model_path=model_path)
966
+ except Exception as e:
967
+ print(e)
968
+ return {token: gr.update(value=str(e))}
969
+ return {
970
+ token: gr.update(visible=False),
971
+ canvas_width: gr.update(visible=False),
972
+ canvas_height: gr.update(visible=False),
973
+ selection_size: gr.update(visible=False),
974
+ setup_button: gr.update(visible=False),
975
+ frame: gr.update(visible=True),
976
+ upload_button: gr.update(value="Upload Image"),
977
+ model_selection: gr.update(visible=False),
978
+ model_path_input: gr.update(visible=False),
979
+ }
980
+
981
+ setup_button.click(
982
+ fn=setup_func,
983
+ inputs=[
984
+ token,
985
+ canvas_width,
986
+ canvas_height,
987
+ selection_size,
988
+ model_selection,
989
+ model_path_input,
990
+ ],
991
+ outputs=[
992
+ token,
993
+ canvas_width,
994
+ canvas_height,
995
+ selection_size,
996
+ setup_button,
997
+ frame,
998
+ upload_button,
999
+ model_selection,
1000
+ model_path_input,
1001
+ ],
1002
+ _js=setup_button_js,
1003
+ )
1004
+
1005
+ proceed_event = proceed_button.click(
1006
+ fn=run_outpaint,
1007
+ inputs=[
1008
+ model_input,
1009
+ sd_prompt,
1010
+ sd_negative_prompt,
1011
+ sd_strength,
1012
+ sd_guidance,
1013
+ sd_step,
1014
+ sd_resize,
1015
+ init_mode,
1016
+ safety_check,
1017
+ postprocess_check,
1018
+ sd_img2img,
1019
+ sd_use_seed,
1020
+ sd_seed_val,
1021
+ sd_generate_num,
1022
+ sd_scheduler,
1023
+ sd_scheduler_eta,
1024
+ model_output_state,
1025
+ ],
1026
+ outputs=[model_output, sd_prompt, model_output_state],
1027
+ _js=proceed_button_js,
1028
+ )
1029
+ # cancel button can also remove error overlay
1030
+ # cancel_button.click(fn=None, inputs=None, outputs=None, cancels=[proceed_event])
1031
+
1032
+
1033
+ launch_extra_kwargs = {
1034
+ "show_error": True,
1035
+ # "favicon_path": ""
1036
+ }
1037
+ launch_kwargs = vars(args)
1038
+ launch_kwargs = {k: v for k, v in launch_kwargs.items() if v is not None}
1039
+ launch_kwargs.pop("remote_model", None)
1040
+ launch_kwargs.pop("local_model", None)
1041
+ launch_kwargs.pop("fp32", None)
1042
+ launch_kwargs.update(launch_extra_kwargs)
1043
+ try:
1044
+ import google.colab
1045
+
1046
+ launch_kwargs["debug"] = True
1047
+ except:
1048
+ pass
1049
+
1050
+ if RUN_IN_SPACE:
1051
+ demo.launch()
1052
+ elif args.debug:
1053
+ launch_kwargs["server_name"] = "0.0.0.0"
1054
+ demo.queue().launch(**launch_kwargs)
1055
+ else:
1056
+ demo.queue().launch(**launch_kwargs)
1057
+
canvas.py ADDED
@@ -0,0 +1,648 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import json
3
+ import io
4
+ import numpy as np
5
+ from PIL import Image
6
+ from pyodide import to_js, create_proxy
7
+ import gc
8
+ from js import (
9
+ console,
10
+ document,
11
+ devicePixelRatio,
12
+ ImageData,
13
+ Uint8ClampedArray,
14
+ CanvasRenderingContext2D as Context2d,
15
+ requestAnimationFrame,
16
+ update_overlay,
17
+ setup_overlay,
18
+ window
19
+ )
20
+
21
+ PAINT_SELECTION = "selection"
22
+ IMAGE_SELECTION = "canvas"
23
+ BRUSH_SELECTION = "eraser"
24
+ NOP_MODE = 0
25
+ PAINT_MODE = 1
26
+ IMAGE_MODE = 2
27
+ BRUSH_MODE = 3
28
+
29
+
30
+ def hold_canvas():
31
+ pass
32
+
33
+
34
+ def prepare_canvas(width, height, canvas) -> Context2d:
35
+ ctx = canvas.getContext("2d")
36
+
37
+ canvas.style.width = f"{width}px"
38
+ canvas.style.height = f"{height}px"
39
+
40
+ canvas.width = width
41
+ canvas.height = height
42
+
43
+ ctx.clearRect(0, 0, width, height)
44
+
45
+ return ctx
46
+
47
+
48
+ # class MultiCanvas:
49
+ # def __init__(self,layer,width=800, height=600) -> None:
50
+ # pass
51
+ def multi_canvas(layer, width=800, height=600):
52
+ lst = [
53
+ CanvasProxy(document.querySelector(f"#canvas{i}"), width, height)
54
+ for i in range(layer)
55
+ ]
56
+ return lst
57
+
58
+
59
+ class CanvasProxy:
60
+ def __init__(self, canvas, width=800, height=600) -> None:
61
+ self.canvas = canvas
62
+ self.ctx = prepare_canvas(width, height, canvas)
63
+ self.width = width
64
+ self.height = height
65
+
66
+ def clear_rect(self, x, y, w, h):
67
+ self.ctx.clearRect(x, y, w, h)
68
+
69
+ def clear(self,):
70
+ self.clear_rect(0, 0, self.canvas.width, self.canvas.height)
71
+
72
+ def stroke_rect(self, x, y, w, h):
73
+ self.ctx.strokeRect(x, y, w, h)
74
+
75
+ def fill_rect(self, x, y, w, h):
76
+ self.ctx.fillRect(x, y, w, h)
77
+
78
+ def put_image_data(self, image, x, y):
79
+ data = Uint8ClampedArray.new(to_js(image.tobytes()))
80
+ height, width, _ = image.shape
81
+ image_data = ImageData.new(data, width, height)
82
+ self.ctx.putImageData(image_data, x, y)
83
+ del image_data
84
+
85
+ # def draw_image(self,canvas, x, y, w, h):
86
+ # self.ctx.drawImage(canvas,x,y,w,h)
87
+ def draw_image(self,canvas, sx, sy, sWidth, sHeight, dx, dy, dWidth, dHeight):
88
+ self.ctx.drawImage(canvas, sx, sy, sWidth, sHeight, dx, dy, dWidth, dHeight)
89
+
90
+ @property
91
+ def stroke_style(self):
92
+ return self.ctx.strokeStyle
93
+
94
+ @stroke_style.setter
95
+ def stroke_style(self, value):
96
+ self.ctx.strokeStyle = value
97
+
98
+ @property
99
+ def fill_style(self):
100
+ return self.ctx.strokeStyle
101
+
102
+ @fill_style.setter
103
+ def fill_style(self, value):
104
+ self.ctx.fillStyle = value
105
+
106
+
107
+ # RGBA for masking
108
+ class InfCanvas:
109
+ def __init__(
110
+ self,
111
+ width,
112
+ height,
113
+ selection_size=256,
114
+ grid_size=64,
115
+ patch_size=4096,
116
+ test_mode=False,
117
+ ) -> None:
118
+ assert selection_size < min(height, width)
119
+ self.width = width
120
+ self.height = height
121
+ self.display_width = width
122
+ self.display_height = height
123
+ self.canvas = multi_canvas(5, width=width, height=height)
124
+ setup_overlay(width,height)
125
+ # place at center
126
+ self.view_pos = [patch_size//2-width//2, patch_size//2-height//2]
127
+ self.cursor = [
128
+ width // 2 - selection_size // 2,
129
+ height // 2 - selection_size // 2,
130
+ ]
131
+ self.data = {}
132
+ self.grid_size = grid_size
133
+ self.selection_size_w = selection_size
134
+ self.selection_size_h = selection_size
135
+ self.patch_size = patch_size
136
+ # note that for image data, the height comes before width
137
+ self.buffer = np.zeros((height, width, 4), dtype=np.uint8)
138
+ self.sel_buffer = np.zeros((selection_size, selection_size, 4), dtype=np.uint8)
139
+ self.sel_buffer_bak = np.zeros(
140
+ (selection_size, selection_size, 4), dtype=np.uint8
141
+ )
142
+ self.sel_dirty = False
143
+ self.buffer_dirty = False
144
+ self.mouse_pos = [-1, -1]
145
+ self.mouse_state = 0
146
+ # self.output = widgets.Output()
147
+ self.test_mode = test_mode
148
+ self.buffer_updated = False
149
+ self.image_move_freq = 1
150
+ self.show_brush = False
151
+ self.scale=1.0
152
+ self.eraser_size=32
153
+
154
+ def reset_large_buffer(self):
155
+ self.canvas[2].canvas.width=self.width
156
+ self.canvas[2].canvas.height=self.height
157
+ # self.canvas[2].canvas.style.width=f"{self.display_width}px"
158
+ # self.canvas[2].canvas.style.height=f"{self.display_height}px"
159
+ self.canvas[2].canvas.style.display="block"
160
+ self.canvas[2].clear()
161
+
162
+ def draw_eraser(self, x, y):
163
+ self.canvas[-2].clear()
164
+ self.canvas[-2].fill_style = "#ffffff"
165
+ self.canvas[-2].fill_rect(x-self.eraser_size//2,y-self.eraser_size//2,self.eraser_size,self.eraser_size)
166
+ self.canvas[-2].stroke_rect(x-self.eraser_size//2,y-self.eraser_size//2,self.eraser_size,self.eraser_size)
167
+
168
+ def use_eraser(self,x,y):
169
+ if self.sel_dirty:
170
+ self.write_selection_to_buffer()
171
+ self.draw_buffer()
172
+ self.canvas[2].clear()
173
+ self.buffer_dirty=True
174
+ bx0,by0=int(x)-self.eraser_size//2,int(y)-self.eraser_size//2
175
+ bx1,by1=bx0+self.eraser_size,by0+self.eraser_size
176
+ bx0,by0=max(0,bx0),max(0,by0)
177
+ bx1,by1=min(self.width,bx1),min(self.height,by1)
178
+ self.buffer[by0:by1,bx0:bx1,:]*=0
179
+ self.draw_buffer()
180
+ self.draw_selection_box()
181
+
182
+ def setup_mouse(self):
183
+ self.image_move_cnt = 0
184
+
185
+ def get_mouse_mode():
186
+ mode = document.querySelector("#mode").value
187
+ if mode == PAINT_SELECTION:
188
+ return PAINT_MODE
189
+ elif mode == IMAGE_SELECTION:
190
+ return IMAGE_MODE
191
+ return BRUSH_MODE
192
+
193
+ def get_event_pos(event):
194
+ canvas = self.canvas[-1].canvas
195
+ rect = canvas.getBoundingClientRect()
196
+ x = (canvas.width * (event.clientX - rect.left)) / rect.width
197
+ y = (canvas.height * (event.clientY - rect.top)) / rect.height
198
+ return x, y
199
+
200
+ def handle_mouse_down(event):
201
+ self.mouse_state = get_mouse_mode()
202
+ if self.mouse_state==BRUSH_MODE:
203
+ x,y=get_event_pos(event)
204
+ self.use_eraser(x,y)
205
+
206
+ def handle_mouse_out(event):
207
+ last_state = self.mouse_state
208
+ self.mouse_state = NOP_MODE
209
+ self.image_move_cnt = 0
210
+ if last_state == IMAGE_MODE:
211
+ self.update_view_pos(0, 0)
212
+ if True:
213
+ self.clear_background()
214
+ self.draw_buffer()
215
+ self.reset_large_buffer()
216
+ self.draw_selection_box()
217
+ gc.collect()
218
+ if self.show_brush:
219
+ self.canvas[-2].clear()
220
+ self.show_brush = False
221
+
222
+ def handle_mouse_up(event):
223
+ last_state = self.mouse_state
224
+ self.mouse_state = NOP_MODE
225
+ self.image_move_cnt = 0
226
+ if last_state == IMAGE_MODE:
227
+ self.update_view_pos(0, 0)
228
+ if True:
229
+ self.clear_background()
230
+ self.draw_buffer()
231
+ self.reset_large_buffer()
232
+ self.draw_selection_box()
233
+ gc.collect()
234
+
235
+ async def handle_mouse_move(event):
236
+ x, y = get_event_pos(event)
237
+ x0, y0 = self.mouse_pos
238
+ xo = x - x0
239
+ yo = y - y0
240
+ if self.mouse_state == PAINT_MODE:
241
+ self.update_cursor(int(xo), int(yo))
242
+ if True:
243
+ # self.clear_background()
244
+ # console.log(self.buffer_updated)
245
+ if self.buffer_updated:
246
+ self.draw_buffer()
247
+ self.buffer_updated = False
248
+ self.draw_selection_box()
249
+ elif self.mouse_state == IMAGE_MODE:
250
+ self.image_move_cnt += 1
251
+ if self.image_move_cnt == self.image_move_freq:
252
+ self.draw_buffer()
253
+ self.canvas[2].clear()
254
+ self.draw_selection_box()
255
+ self.update_view_pos(int(xo), int(yo))
256
+ self.cached_view_pos=tuple(self.view_pos)
257
+ self.canvas[2].canvas.style.display="none"
258
+ large_buffer=self.data2array(self.view_pos[0]-self.width//2,self.view_pos[1]-self.height//2,min(self.width*2,self.patch_size),min(self.height*2,self.patch_size))
259
+ self.canvas[2].canvas.width=large_buffer.shape[1]
260
+ self.canvas[2].canvas.height=large_buffer.shape[0]
261
+ # self.canvas[2].canvas.style.width=""
262
+ # self.canvas[2].canvas.style.height=""
263
+ self.canvas[2].put_image_data(large_buffer,0,0)
264
+ else:
265
+ self.update_view_pos(int(xo), int(yo), False)
266
+ self.canvas[1].clear()
267
+ self.canvas[1].draw_image(self.canvas[2].canvas,
268
+ self.width//2+(self.view_pos[0]-self.cached_view_pos[0]),self.height//2+(self.view_pos[1]-self.cached_view_pos[1]),
269
+ self.width,self.height,
270
+ 0,0,self.width,self.height
271
+ )
272
+ self.clear_background()
273
+ # self.image_move_cnt = 0
274
+ elif self.mouse_state == BRUSH_MODE:
275
+ self.use_eraser(x,y)
276
+
277
+ mode = document.querySelector("#mode").value
278
+ if mode == BRUSH_SELECTION:
279
+ self.draw_eraser(x,y)
280
+ self.show_brush = True
281
+ elif self.show_brush:
282
+ self.canvas[-2].clear()
283
+ self.show_brush = False
284
+ self.mouse_pos[0] = x
285
+ self.mouse_pos[1] = y
286
+
287
+ self.canvas[-1].canvas.addEventListener(
288
+ "mousedown", create_proxy(handle_mouse_down)
289
+ )
290
+ self.canvas[-1].canvas.addEventListener(
291
+ "mousemove", create_proxy(handle_mouse_move)
292
+ )
293
+ self.canvas[-1].canvas.addEventListener(
294
+ "mouseup", create_proxy(handle_mouse_up)
295
+ )
296
+ self.canvas[-1].canvas.addEventListener(
297
+ "mouseout", create_proxy(handle_mouse_out)
298
+ )
299
+ async def handle_mouse_wheel(event):
300
+ x, y = get_event_pos(event)
301
+ self.mouse_pos[0] = x
302
+ self.mouse_pos[1] = y
303
+ console.log(to_js(self.mouse_pos))
304
+ if event.deltaY>10:
305
+ window.postMessage(to_js(["click","zoom_out", self.mouse_pos[0], self.mouse_pos[1]]),"*")
306
+ elif event.deltaY<-10:
307
+ window.postMessage(to_js(["click","zoom_in", self.mouse_pos[0], self.mouse_pos[1]]),"*")
308
+ return False
309
+ self.canvas[-1].canvas.addEventListener(
310
+ "wheel", create_proxy(handle_mouse_wheel), False
311
+ )
312
+ def clear_background(self):
313
+ # fake transparent background
314
+ h, w, step = self.height, self.width, self.grid_size
315
+ stride = step * 2
316
+ x0, y0 = self.view_pos
317
+ x0 = (-x0) % stride
318
+ y0 = (-y0) % stride
319
+ if y0>=step:
320
+ val0,val1=stride,step
321
+ else:
322
+ val0,val1=step,stride
323
+ # self.canvas.clear()
324
+ self.canvas[0].fill_style = "#ffffff"
325
+ self.canvas[0].fill_rect(0, 0, w, h)
326
+ self.canvas[0].fill_style = "#aaaaaa"
327
+ for y in range(y0-stride, h + step, step):
328
+ start = (x0 - val0) if y // step % 2 == 0 else (x0 - val1)
329
+ for x in range(start, w + step, stride):
330
+ self.canvas[0].fill_rect(x, y, step, step)
331
+ self.canvas[0].stroke_rect(0, 0, w, h)
332
+
333
+ def refine_selection(self):
334
+ h,w=self.selection_size_h,self.selection_size_w
335
+ h=min(h,self.height)
336
+ w=min(w,self.width)
337
+ self.selection_size_h=h*8//8
338
+ self.selection_size_w=w*8//8
339
+ self.update_cursor(1,0)
340
+
341
+
342
+ def update_scale(self, scale, mx=-1, my=-1):
343
+ self.sync_to_data()
344
+ scaled_width=int(self.display_width*scale)
345
+ scaled_height=int(self.display_height*scale)
346
+ if max(scaled_height,scaled_width)>=self.patch_size*2-128:
347
+ return
348
+ if scaled_height<=self.selection_size_h or scaled_width<=self.selection_size_w:
349
+ return
350
+ if mx>=0 and my>=0:
351
+ scaled_mx=mx/self.scale*scale
352
+ scaled_my=my/self.scale*scale
353
+ self.view_pos[0]+=int(mx-scaled_mx)
354
+ self.view_pos[1]+=int(my-scaled_my)
355
+ self.scale=scale
356
+ for item in self.canvas:
357
+ item.canvas.width=scaled_width
358
+ item.canvas.height=scaled_height
359
+ item.clear()
360
+ update_overlay(scaled_width,scaled_height)
361
+ self.width=scaled_width
362
+ self.height=scaled_height
363
+ self.data2buffer()
364
+ self.clear_background()
365
+ self.draw_buffer()
366
+ self.update_cursor(1,0)
367
+ self.draw_selection_box()
368
+
369
+ def update_view_pos(self, xo, yo, update=True):
370
+ # if abs(xo) + abs(yo) == 0:
371
+ # return
372
+ if self.sel_dirty:
373
+ self.write_selection_to_buffer()
374
+ if self.buffer_dirty:
375
+ self.buffer2data()
376
+ self.view_pos[0] -= xo
377
+ self.view_pos[1] -= yo
378
+ if update:
379
+ self.data2buffer()
380
+ # self.read_selection_from_buffer()
381
+
382
+ def update_cursor(self, xo, yo):
383
+ if abs(xo) + abs(yo) == 0:
384
+ return
385
+ if self.sel_dirty:
386
+ self.write_selection_to_buffer()
387
+ self.cursor[0] += xo
388
+ self.cursor[1] += yo
389
+ self.cursor[0] = max(min(self.width - self.selection_size_w, self.cursor[0]), 0)
390
+ self.cursor[1] = max(min(self.height - self.selection_size_h, self.cursor[1]), 0)
391
+ # self.read_selection_from_buffer()
392
+
393
+ def data2buffer(self):
394
+ x, y = self.view_pos
395
+ h, w = self.height, self.width
396
+ if h!=self.buffer.shape[0] or w!=self.buffer.shape[1]:
397
+ self.buffer=np.zeros((self.height, self.width, 4), dtype=np.uint8)
398
+ # fill four parts
399
+ for i in range(4):
400
+ pos_src, pos_dst, data = self.select(x, y, i)
401
+ xs0, xs1 = pos_src[0]
402
+ ys0, ys1 = pos_src[1]
403
+ xd0, xd1 = pos_dst[0]
404
+ yd0, yd1 = pos_dst[1]
405
+ self.buffer[yd0:yd1, xd0:xd1, :] = data[ys0:ys1, xs0:xs1, :]
406
+
407
+ def data2array(self, x, y, w, h):
408
+ # x, y = self.view_pos
409
+ # h, w = self.height, self.width
410
+ ret=np.zeros((h, w, 4), dtype=np.uint8)
411
+ # fill four parts
412
+ for i in range(4):
413
+ pos_src, pos_dst, data = self.select(x, y, i, w, h)
414
+ xs0, xs1 = pos_src[0]
415
+ ys0, ys1 = pos_src[1]
416
+ xd0, xd1 = pos_dst[0]
417
+ yd0, yd1 = pos_dst[1]
418
+ ret[yd0:yd1, xd0:xd1, :] = data[ys0:ys1, xs0:xs1, :]
419
+ return ret
420
+
421
+ def buffer2data(self):
422
+ x, y = self.view_pos
423
+ h, w = self.height, self.width
424
+ # fill four parts
425
+ for i in range(4):
426
+ pos_src, pos_dst, data = self.select(x, y, i)
427
+ xs0, xs1 = pos_src[0]
428
+ ys0, ys1 = pos_src[1]
429
+ xd0, xd1 = pos_dst[0]
430
+ yd0, yd1 = pos_dst[1]
431
+ data[ys0:ys1, xs0:xs1, :] = self.buffer[yd0:yd1, xd0:xd1, :]
432
+ self.buffer_dirty = False
433
+
434
+ def select(self, x, y, idx, width=0, height=0):
435
+ if width==0:
436
+ w, h = self.width, self.height
437
+ else:
438
+ w, h = width, height
439
+ lst = [(0, 0), (0, h), (w, 0), (w, h)]
440
+ if idx == 0:
441
+ x0, y0 = x % self.patch_size, y % self.patch_size
442
+ x1 = min(x0 + w, self.patch_size)
443
+ y1 = min(y0 + h, self.patch_size)
444
+ elif idx == 1:
445
+ y += h
446
+ x0, y0 = x % self.patch_size, y % self.patch_size
447
+ x1 = min(x0 + w, self.patch_size)
448
+ y1 = max(y0 - h, 0)
449
+ elif idx == 2:
450
+ x += w
451
+ x0, y0 = x % self.patch_size, y % self.patch_size
452
+ x1 = max(x0 - w, 0)
453
+ y1 = min(y0 + h, self.patch_size)
454
+ else:
455
+ x += w
456
+ y += h
457
+ x0, y0 = x % self.patch_size, y % self.patch_size
458
+ x1 = max(x0 - w, 0)
459
+ y1 = max(y0 - h, 0)
460
+ xi, yi = x // self.patch_size, y // self.patch_size
461
+ cur = self.data.setdefault(
462
+ (xi, yi), np.zeros((self.patch_size, self.patch_size, 4), dtype=np.uint8)
463
+ )
464
+ x0_img, y0_img = lst[idx]
465
+ x1_img = x0_img + x1 - x0
466
+ y1_img = y0_img + y1 - y0
467
+ sort = lambda a, b: ((a, b) if a < b else (b, a))
468
+ return (
469
+ (sort(x0, x1), sort(y0, y1)),
470
+ (sort(x0_img, x1_img), sort(y0_img, y1_img)),
471
+ cur,
472
+ )
473
+
474
+ def draw_buffer(self):
475
+ self.canvas[1].clear()
476
+ self.canvas[1].put_image_data(self.buffer, 0, 0)
477
+
478
+ def fill_selection(self, img):
479
+ self.sel_buffer = img
480
+ self.sel_dirty = True
481
+
482
+ def draw_selection_box(self):
483
+ x0, y0 = self.cursor
484
+ w, h = self.selection_size_w, self.selection_size_h
485
+ if self.sel_dirty:
486
+ self.canvas[2].clear()
487
+ self.canvas[2].put_image_data(self.sel_buffer, x0, y0)
488
+ self.canvas[-1].clear()
489
+ self.canvas[-1].stroke_style = "#0a0a0a"
490
+ self.canvas[-1].stroke_rect(x0, y0, w, h)
491
+ self.canvas[-1].stroke_style = "#ffffff"
492
+ offset=round(self.scale) if self.scale>1.0 else 1
493
+ self.canvas[-1].stroke_rect(x0 - offset, y0 - offset, w + offset*2, h + offset*2)
494
+ self.canvas[-1].stroke_style = "#000000"
495
+ self.canvas[-1].stroke_rect(x0 - offset*2, y0 - offset*2, w + offset*4, h + offset*4)
496
+
497
+ def write_selection_to_buffer(self):
498
+ x0, y0 = self.cursor
499
+ x1, y1 = x0 + self.selection_size_w, y0 + self.selection_size_h
500
+ self.buffer[y0:y1, x0:x1] = self.sel_buffer
501
+ self.sel_dirty = False
502
+ self.sel_buffer = np.zeros(
503
+ (self.selection_size_h, self.selection_size_w, 4), dtype=np.uint8
504
+ )
505
+ self.buffer_dirty = True
506
+ self.buffer_updated = True
507
+ # self.canvas[2].clear()
508
+
509
+ def read_selection_from_buffer(self):
510
+ x0, y0 = self.cursor
511
+ x1, y1 = x0 + self.selection_size_w, y0 + self.selection_size_h
512
+ self.sel_buffer = self.buffer[y0:y1, x0:x1]
513
+ self.sel_dirty = False
514
+
515
+ def base64_to_numpy(self, base64_str):
516
+ try:
517
+ data = base64.b64decode(str(base64_str))
518
+ pil = Image.open(io.BytesIO(data))
519
+ arr = np.array(pil)
520
+ ret = arr
521
+ except:
522
+ ret = np.tile(
523
+ np.array([255, 0, 0, 255], dtype=np.uint8),
524
+ (self.selection_size_h, self.selection_size_w, 1),
525
+ )
526
+ return ret
527
+
528
+ def numpy_to_base64(self, arr):
529
+ out_pil = Image.fromarray(arr)
530
+ out_buffer = io.BytesIO()
531
+ out_pil.save(out_buffer, format="PNG")
532
+ out_buffer.seek(0)
533
+ base64_bytes = base64.b64encode(out_buffer.read())
534
+ base64_str = base64_bytes.decode("ascii")
535
+ return base64_str
536
+
537
+ def sync_to_data(self):
538
+ if self.sel_dirty:
539
+ self.write_selection_to_buffer()
540
+ self.canvas[2].clear()
541
+ self.draw_buffer()
542
+ if self.buffer_dirty:
543
+ self.buffer2data()
544
+
545
+ def sync_to_buffer(self):
546
+ if self.sel_dirty:
547
+ self.canvas[2].clear()
548
+ self.write_selection_to_buffer()
549
+ self.draw_buffer()
550
+
551
+ def resize(self,width,height,scale=None,**kwargs):
552
+ self.display_width=width
553
+ self.display_height=height
554
+ for canvas in self.canvas:
555
+ prepare_canvas(width=width,height=height,canvas=canvas.canvas)
556
+ setup_overlay(width,height)
557
+ if scale is None:
558
+ scale=1
559
+ self.update_scale(scale)
560
+
561
+
562
+ def save(self):
563
+ self.sync_to_data()
564
+ state={}
565
+ state["width"]=self.display_width
566
+ state["height"]=self.display_height
567
+ state["selection_width"]=self.selection_size_w
568
+ state["selection_height"]=self.selection_size_h
569
+ state["view_pos"]=self.view_pos[:]
570
+ state["cursor"]=self.cursor[:]
571
+ state["scale"]=self.scale
572
+ keys=list(self.data.keys())
573
+ data={}
574
+ for key in keys:
575
+ if self.data[key].sum()>0:
576
+ data[f"{key[0]},{key[1]}"]=self.numpy_to_base64(self.data[key])
577
+ state["data"]=data
578
+ return json.dumps(state)
579
+
580
+ def load(self, state_json):
581
+ self.reset()
582
+ state=json.loads(state_json)
583
+ self.display_width=state["width"]
584
+ self.display_height=state["height"]
585
+ self.selection_size_w=state["selection_width"]
586
+ self.selection_size_h=state["selection_height"]
587
+ self.view_pos=state["view_pos"][:]
588
+ self.cursor=state["cursor"][:]
589
+ self.scale=state["scale"]
590
+ self.resize(state["width"],state["height"],scale=state["scale"])
591
+ for k,v in state["data"].items():
592
+ key=tuple(map(int,k.split(",")))
593
+ self.data[key]=self.base64_to_numpy(v)
594
+ self.data2buffer()
595
+ self.display()
596
+
597
+ def display(self):
598
+ self.clear_background()
599
+ self.draw_buffer()
600
+ self.draw_selection_box()
601
+
602
+ def reset(self):
603
+ self.data.clear()
604
+ self.buffer*=0
605
+ self.buffer_dirty=False
606
+ self.buffer_updated=False
607
+ self.sel_buffer*=0
608
+ self.sel_dirty=False
609
+ self.view_pos = [0, 0]
610
+ self.clear_background()
611
+ for i in range(1,len(self.canvas)-1):
612
+ self.canvas[i].clear()
613
+
614
+ def export(self):
615
+ self.sync_to_data()
616
+ xmin, xmax, ymin, ymax = 0, 0, 0, 0
617
+ if len(self.data.keys()) == 0:
618
+ return np.zeros(
619
+ (self.selection_size_h, self.selection_size_w, 4), dtype=np.uint8
620
+ )
621
+ for xi, yi in self.data.keys():
622
+ buf = self.data[(xi, yi)]
623
+ if buf.sum() > 0:
624
+ xmin = min(xi, xmin)
625
+ xmax = max(xi, xmax)
626
+ ymin = min(yi, ymin)
627
+ ymax = max(yi, ymax)
628
+ yn = ymax - ymin + 1
629
+ xn = xmax - xmin + 1
630
+ image = np.zeros(
631
+ (yn * self.patch_size, xn * self.patch_size, 4), dtype=np.uint8
632
+ )
633
+ for xi, yi in self.data.keys():
634
+ buf = self.data[(xi, yi)]
635
+ if buf.sum() > 0:
636
+ y0 = (yi - ymin) * self.patch_size
637
+ x0 = (xi - xmin) * self.patch_size
638
+ image[y0 : y0 + self.patch_size, x0 : x0 + self.patch_size] = buf
639
+ ylst, xlst = image[:, :, -1].nonzero()
640
+ if len(ylst) > 0:
641
+ yt, xt = ylst.min(), xlst.min()
642
+ yb, xb = ylst.max(), xlst.max()
643
+ image = image[yt : yb + 1, xt : xb + 1]
644
+ return image
645
+ else:
646
+ return np.zeros(
647
+ (self.selection_size_h, self.selection_size_w, 4), dtype=np.uint8
648
+ )
config.yaml ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ shortcut:
2
+ clear: Escape
3
+ load: Ctrl+o
4
+ save: Ctrl+s
5
+ export: Ctrl+e
6
+ upload: Ctrl+u
7
+ selection: 1
8
+ canvas: 2
9
+ eraser: 3
10
+ outpaint: d
11
+ accept: a
12
+ cancel: c
13
+ retry: r
14
+ prev: q
15
+ next: e
16
+ zoom_in: z
17
+ zoom_out: x
18
+ random_seed: s
convert_checkpoint.py ADDED
@@ -0,0 +1,706 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # https://github.com/huggingface/diffusers/blob/main/scripts/convert_original_stable_diffusion_to_diffusers.py
16
+ """ Conversion script for the LDM checkpoints. """
17
+
18
+ import argparse
19
+ import os
20
+
21
+ import torch
22
+
23
+
24
+ try:
25
+ from omegaconf import OmegaConf
26
+ except ImportError:
27
+ raise ImportError(
28
+ "OmegaConf is required to convert the LDM checkpoints. Please install it with `pip install OmegaConf`."
29
+ )
30
+
31
+ from diffusers import (
32
+ AutoencoderKL,
33
+ DDIMScheduler,
34
+ LDMTextToImagePipeline,
35
+ LMSDiscreteScheduler,
36
+ PNDMScheduler,
37
+ StableDiffusionPipeline,
38
+ UNet2DConditionModel,
39
+ )
40
+ from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
41
+ from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
42
+ from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer
43
+
44
+
45
+ def shave_segments(path, n_shave_prefix_segments=1):
46
+ """
47
+ Removes segments. Positive values shave the first segments, negative shave the last segments.
48
+ """
49
+ if n_shave_prefix_segments >= 0:
50
+ return ".".join(path.split(".")[n_shave_prefix_segments:])
51
+ else:
52
+ return ".".join(path.split(".")[:n_shave_prefix_segments])
53
+
54
+
55
+ def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
56
+ """
57
+ Updates paths inside resnets to the new naming scheme (local renaming)
58
+ """
59
+ mapping = []
60
+ for old_item in old_list:
61
+ new_item = old_item.replace("in_layers.0", "norm1")
62
+ new_item = new_item.replace("in_layers.2", "conv1")
63
+
64
+ new_item = new_item.replace("out_layers.0", "norm2")
65
+ new_item = new_item.replace("out_layers.3", "conv2")
66
+
67
+ new_item = new_item.replace("emb_layers.1", "time_emb_proj")
68
+ new_item = new_item.replace("skip_connection", "conv_shortcut")
69
+
70
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
71
+
72
+ mapping.append({"old": old_item, "new": new_item})
73
+
74
+ return mapping
75
+
76
+
77
+ def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
78
+ """
79
+ Updates paths inside resnets to the new naming scheme (local renaming)
80
+ """
81
+ mapping = []
82
+ for old_item in old_list:
83
+ new_item = old_item
84
+
85
+ new_item = new_item.replace("nin_shortcut", "conv_shortcut")
86
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
87
+
88
+ mapping.append({"old": old_item, "new": new_item})
89
+
90
+ return mapping
91
+
92
+
93
+ def renew_attention_paths(old_list, n_shave_prefix_segments=0):
94
+ """
95
+ Updates paths inside attentions to the new naming scheme (local renaming)
96
+ """
97
+ mapping = []
98
+ for old_item in old_list:
99
+ new_item = old_item
100
+
101
+ # new_item = new_item.replace('norm.weight', 'group_norm.weight')
102
+ # new_item = new_item.replace('norm.bias', 'group_norm.bias')
103
+
104
+ # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
105
+ # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
106
+
107
+ # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
108
+
109
+ mapping.append({"old": old_item, "new": new_item})
110
+
111
+ return mapping
112
+
113
+
114
+ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
115
+ """
116
+ Updates paths inside attentions to the new naming scheme (local renaming)
117
+ """
118
+ mapping = []
119
+ for old_item in old_list:
120
+ new_item = old_item
121
+
122
+ new_item = new_item.replace("norm.weight", "group_norm.weight")
123
+ new_item = new_item.replace("norm.bias", "group_norm.bias")
124
+
125
+ new_item = new_item.replace("q.weight", "query.weight")
126
+ new_item = new_item.replace("q.bias", "query.bias")
127
+
128
+ new_item = new_item.replace("k.weight", "key.weight")
129
+ new_item = new_item.replace("k.bias", "key.bias")
130
+
131
+ new_item = new_item.replace("v.weight", "value.weight")
132
+ new_item = new_item.replace("v.bias", "value.bias")
133
+
134
+ new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
135
+ new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
136
+
137
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
138
+
139
+ mapping.append({"old": old_item, "new": new_item})
140
+
141
+ return mapping
142
+
143
+
144
+ def assign_to_checkpoint(
145
+ paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
146
+ ):
147
+ """
148
+ This does the final conversion step: take locally converted weights and apply a global renaming
149
+ to them. It splits attention layers, and takes into account additional replacements
150
+ that may arise.
151
+
152
+ Assigns the weights to the new checkpoint.
153
+ """
154
+ assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
155
+
156
+ # Splits the attention layers into three variables.
157
+ if attention_paths_to_split is not None:
158
+ for path, path_map in attention_paths_to_split.items():
159
+ old_tensor = old_checkpoint[path]
160
+ channels = old_tensor.shape[0] // 3
161
+
162
+ target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
163
+
164
+ num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
165
+
166
+ old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
167
+ query, key, value = old_tensor.split(channels // num_heads, dim=1)
168
+
169
+ checkpoint[path_map["query"]] = query.reshape(target_shape)
170
+ checkpoint[path_map["key"]] = key.reshape(target_shape)
171
+ checkpoint[path_map["value"]] = value.reshape(target_shape)
172
+
173
+ for path in paths:
174
+ new_path = path["new"]
175
+
176
+ # These have already been assigned
177
+ if attention_paths_to_split is not None and new_path in attention_paths_to_split:
178
+ continue
179
+
180
+ # Global renaming happens here
181
+ new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
182
+ new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
183
+ new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
184
+
185
+ if additional_replacements is not None:
186
+ for replacement in additional_replacements:
187
+ new_path = new_path.replace(replacement["old"], replacement["new"])
188
+
189
+ # proj_attn.weight has to be converted from conv 1D to linear
190
+ if "proj_attn.weight" in new_path:
191
+ checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
192
+ else:
193
+ checkpoint[new_path] = old_checkpoint[path["old"]]
194
+
195
+
196
+ def conv_attn_to_linear(checkpoint):
197
+ keys = list(checkpoint.keys())
198
+ attn_keys = ["query.weight", "key.weight", "value.weight"]
199
+ for key in keys:
200
+ if ".".join(key.split(".")[-2:]) in attn_keys:
201
+ if checkpoint[key].ndim > 2:
202
+ checkpoint[key] = checkpoint[key][:, :, 0, 0]
203
+ elif "proj_attn.weight" in key:
204
+ if checkpoint[key].ndim > 2:
205
+ checkpoint[key] = checkpoint[key][:, :, 0]
206
+
207
+
208
+ def create_unet_diffusers_config(original_config):
209
+ """
210
+ Creates a config for the diffusers based on the config of the LDM model.
211
+ """
212
+ unet_params = original_config.model.params.unet_config.params
213
+
214
+ block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
215
+
216
+ down_block_types = []
217
+ resolution = 1
218
+ for i in range(len(block_out_channels)):
219
+ block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
220
+ down_block_types.append(block_type)
221
+ if i != len(block_out_channels) - 1:
222
+ resolution *= 2
223
+
224
+ up_block_types = []
225
+ for i in range(len(block_out_channels)):
226
+ block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
227
+ up_block_types.append(block_type)
228
+ resolution //= 2
229
+
230
+ config = dict(
231
+ sample_size=unet_params.image_size,
232
+ in_channels=unet_params.in_channels,
233
+ out_channels=unet_params.out_channels,
234
+ down_block_types=tuple(down_block_types),
235
+ up_block_types=tuple(up_block_types),
236
+ block_out_channels=tuple(block_out_channels),
237
+ layers_per_block=unet_params.num_res_blocks,
238
+ cross_attention_dim=unet_params.context_dim,
239
+ attention_head_dim=unet_params.num_heads,
240
+ )
241
+
242
+ return config
243
+
244
+
245
+ def create_vae_diffusers_config(original_config):
246
+ """
247
+ Creates a config for the diffusers based on the config of the LDM model.
248
+ """
249
+ vae_params = original_config.model.params.first_stage_config.params.ddconfig
250
+ _ = original_config.model.params.first_stage_config.params.embed_dim
251
+
252
+ block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
253
+ down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
254
+ up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
255
+
256
+ config = dict(
257
+ sample_size=vae_params.resolution,
258
+ in_channels=vae_params.in_channels,
259
+ out_channels=vae_params.out_ch,
260
+ down_block_types=tuple(down_block_types),
261
+ up_block_types=tuple(up_block_types),
262
+ block_out_channels=tuple(block_out_channels),
263
+ latent_channels=vae_params.z_channels,
264
+ layers_per_block=vae_params.num_res_blocks,
265
+ )
266
+ return config
267
+
268
+
269
+ def create_diffusers_schedular(original_config):
270
+ schedular = DDIMScheduler(
271
+ num_train_timesteps=original_config.model.params.timesteps,
272
+ beta_start=original_config.model.params.linear_start,
273
+ beta_end=original_config.model.params.linear_end,
274
+ beta_schedule="scaled_linear",
275
+ )
276
+ return schedular
277
+
278
+
279
+ def create_ldm_bert_config(original_config):
280
+ bert_params = original_config.model.parms.cond_stage_config.params
281
+ config = LDMBertConfig(
282
+ d_model=bert_params.n_embed,
283
+ encoder_layers=bert_params.n_layer,
284
+ encoder_ffn_dim=bert_params.n_embed * 4,
285
+ )
286
+ return config
287
+
288
+
289
+ def convert_ldm_unet_checkpoint(checkpoint, config):
290
+ """
291
+ Takes a state dict and a config, and returns a converted checkpoint.
292
+ """
293
+
294
+ # extract state_dict for UNet
295
+ unet_state_dict = {}
296
+ unet_key = "model.diffusion_model."
297
+ keys = list(checkpoint.keys())
298
+ for key in keys:
299
+ if key.startswith(unet_key):
300
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
301
+
302
+ new_checkpoint = {}
303
+
304
+ new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
305
+ new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
306
+ new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
307
+ new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
308
+
309
+ new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
310
+ new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
311
+
312
+ new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
313
+ new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
314
+ new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
315
+ new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
316
+
317
+ # Retrieves the keys for the input blocks only
318
+ num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
319
+ input_blocks = {
320
+ layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
321
+ for layer_id in range(num_input_blocks)
322
+ }
323
+
324
+ # Retrieves the keys for the middle blocks only
325
+ num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
326
+ middle_blocks = {
327
+ layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
328
+ for layer_id in range(num_middle_blocks)
329
+ }
330
+
331
+ # Retrieves the keys for the output blocks only
332
+ num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
333
+ output_blocks = {
334
+ layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
335
+ for layer_id in range(num_output_blocks)
336
+ }
337
+
338
+ for i in range(1, num_input_blocks):
339
+ block_id = (i - 1) // (config["layers_per_block"] + 1)
340
+ layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
341
+
342
+ resnets = [
343
+ key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
344
+ ]
345
+ attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
346
+
347
+ if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
348
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
349
+ f"input_blocks.{i}.0.op.weight"
350
+ )
351
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
352
+ f"input_blocks.{i}.0.op.bias"
353
+ )
354
+
355
+ paths = renew_resnet_paths(resnets)
356
+ meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
357
+ assign_to_checkpoint(
358
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
359
+ )
360
+
361
+ if len(attentions):
362
+ paths = renew_attention_paths(attentions)
363
+ meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
364
+ assign_to_checkpoint(
365
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
366
+ )
367
+
368
+ resnet_0 = middle_blocks[0]
369
+ attentions = middle_blocks[1]
370
+ resnet_1 = middle_blocks[2]
371
+
372
+ resnet_0_paths = renew_resnet_paths(resnet_0)
373
+ assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
374
+
375
+ resnet_1_paths = renew_resnet_paths(resnet_1)
376
+ assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
377
+
378
+ attentions_paths = renew_attention_paths(attentions)
379
+ meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
380
+ assign_to_checkpoint(
381
+ attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
382
+ )
383
+
384
+ for i in range(num_output_blocks):
385
+ block_id = i // (config["layers_per_block"] + 1)
386
+ layer_in_block_id = i % (config["layers_per_block"] + 1)
387
+ output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
388
+ output_block_list = {}
389
+
390
+ for layer in output_block_layers:
391
+ layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
392
+ if layer_id in output_block_list:
393
+ output_block_list[layer_id].append(layer_name)
394
+ else:
395
+ output_block_list[layer_id] = [layer_name]
396
+
397
+ if len(output_block_list) > 1:
398
+ resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
399
+ attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
400
+
401
+ resnet_0_paths = renew_resnet_paths(resnets)
402
+ paths = renew_resnet_paths(resnets)
403
+
404
+ meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
405
+ assign_to_checkpoint(
406
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
407
+ )
408
+
409
+ if ["conv.weight", "conv.bias"] in output_block_list.values():
410
+ index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
411
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
412
+ f"output_blocks.{i}.{index}.conv.weight"
413
+ ]
414
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
415
+ f"output_blocks.{i}.{index}.conv.bias"
416
+ ]
417
+
418
+ # Clear attentions as they have been attributed above.
419
+ if len(attentions) == 2:
420
+ attentions = []
421
+
422
+ if len(attentions):
423
+ paths = renew_attention_paths(attentions)
424
+ meta_path = {
425
+ "old": f"output_blocks.{i}.1",
426
+ "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
427
+ }
428
+ assign_to_checkpoint(
429
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
430
+ )
431
+ else:
432
+ resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
433
+ for path in resnet_0_paths:
434
+ old_path = ".".join(["output_blocks", str(i), path["old"]])
435
+ new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
436
+
437
+ new_checkpoint[new_path] = unet_state_dict[old_path]
438
+
439
+ return new_checkpoint
440
+
441
+
442
+ def convert_ldm_vae_checkpoint(checkpoint, config):
443
+ # extract state dict for VAE
444
+ vae_state_dict = {}
445
+ vae_key = "first_stage_model."
446
+ keys = list(checkpoint.keys())
447
+ for key in keys:
448
+ if key.startswith(vae_key):
449
+ vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
450
+
451
+ new_checkpoint = {}
452
+
453
+ new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
454
+ new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
455
+ new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
456
+ new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
457
+ new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
458
+ new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
459
+
460
+ new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
461
+ new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
462
+ new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
463
+ new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
464
+ new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
465
+ new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
466
+
467
+ new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
468
+ new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
469
+ new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
470
+ new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
471
+
472
+ # Retrieves the keys for the encoder down blocks only
473
+ num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
474
+ down_blocks = {
475
+ layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
476
+ }
477
+
478
+ # Retrieves the keys for the decoder up blocks only
479
+ num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
480
+ up_blocks = {
481
+ layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
482
+ }
483
+
484
+ for i in range(num_down_blocks):
485
+ resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
486
+
487
+ if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
488
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
489
+ f"encoder.down.{i}.downsample.conv.weight"
490
+ )
491
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
492
+ f"encoder.down.{i}.downsample.conv.bias"
493
+ )
494
+
495
+ paths = renew_vae_resnet_paths(resnets)
496
+ meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
497
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
498
+
499
+ mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
500
+ num_mid_res_blocks = 2
501
+ for i in range(1, num_mid_res_blocks + 1):
502
+ resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
503
+
504
+ paths = renew_vae_resnet_paths(resnets)
505
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
506
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
507
+
508
+ mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
509
+ paths = renew_vae_attention_paths(mid_attentions)
510
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
511
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
512
+ conv_attn_to_linear(new_checkpoint)
513
+
514
+ for i in range(num_up_blocks):
515
+ block_id = num_up_blocks - 1 - i
516
+ resnets = [
517
+ key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
518
+ ]
519
+
520
+ if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
521
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
522
+ f"decoder.up.{block_id}.upsample.conv.weight"
523
+ ]
524
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
525
+ f"decoder.up.{block_id}.upsample.conv.bias"
526
+ ]
527
+
528
+ paths = renew_vae_resnet_paths(resnets)
529
+ meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
530
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
531
+
532
+ mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
533
+ num_mid_res_blocks = 2
534
+ for i in range(1, num_mid_res_blocks + 1):
535
+ resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
536
+
537
+ paths = renew_vae_resnet_paths(resnets)
538
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
539
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
540
+
541
+ mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
542
+ paths = renew_vae_attention_paths(mid_attentions)
543
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
544
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
545
+ conv_attn_to_linear(new_checkpoint)
546
+ return new_checkpoint
547
+
548
+
549
+ def convert_ldm_bert_checkpoint(checkpoint, config):
550
+ def _copy_attn_layer(hf_attn_layer, pt_attn_layer):
551
+ hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight
552
+ hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight
553
+ hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight
554
+
555
+ hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight
556
+ hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias
557
+
558
+ def _copy_linear(hf_linear, pt_linear):
559
+ hf_linear.weight = pt_linear.weight
560
+ hf_linear.bias = pt_linear.bias
561
+
562
+ def _copy_layer(hf_layer, pt_layer):
563
+ # copy layer norms
564
+ _copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0])
565
+ _copy_linear(hf_layer.final_layer_norm, pt_layer[1][0])
566
+
567
+ # copy attn
568
+ _copy_attn_layer(hf_layer.self_attn, pt_layer[0][1])
569
+
570
+ # copy MLP
571
+ pt_mlp = pt_layer[1][1]
572
+ _copy_linear(hf_layer.fc1, pt_mlp.net[0][0])
573
+ _copy_linear(hf_layer.fc2, pt_mlp.net[2])
574
+
575
+ def _copy_layers(hf_layers, pt_layers):
576
+ for i, hf_layer in enumerate(hf_layers):
577
+ if i != 0:
578
+ i += i
579
+ pt_layer = pt_layers[i : i + 2]
580
+ _copy_layer(hf_layer, pt_layer)
581
+
582
+ hf_model = LDMBertModel(config).eval()
583
+
584
+ # copy embeds
585
+ hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight
586
+ hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight
587
+
588
+ # copy layer norm
589
+ _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm)
590
+
591
+ # copy hidden layers
592
+ _copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers)
593
+
594
+ _copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits)
595
+
596
+ return hf_model
597
+
598
+
599
+ def convert_ldm_clip_checkpoint(checkpoint):
600
+ text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
601
+
602
+ keys = list(checkpoint.keys())
603
+
604
+ text_model_dict = {}
605
+
606
+ for key in keys:
607
+ if key.startswith("cond_stage_model.transformer"):
608
+ text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
609
+
610
+ text_model.load_state_dict(text_model_dict)
611
+
612
+ return text_model
613
+
614
+ import os
615
+ def convert_checkpoint(checkpoint_path, inpainting=False):
616
+ parser = argparse.ArgumentParser()
617
+
618
+ parser.add_argument(
619
+ "--checkpoint_path", default=checkpoint_path, type=str, help="Path to the checkpoint to convert."
620
+ )
621
+ # !wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml
622
+ parser.add_argument(
623
+ "--original_config_file",
624
+ default=None,
625
+ type=str,
626
+ help="The YAML config file corresponding to the original architecture.",
627
+ )
628
+ parser.add_argument(
629
+ "--scheduler_type",
630
+ default="pndm",
631
+ type=str,
632
+ help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim']",
633
+ )
634
+ parser.add_argument("--dump_path", default=None, type=str, help="Path to the output model.")
635
+
636
+ args = parser.parse_args([])
637
+ if args.original_config_file is None:
638
+ if inpainting:
639
+ args.original_config_file = "./models/v1-inpainting-inference.yaml"
640
+ else:
641
+ args.original_config_file = "./models/v1-inference.yaml"
642
+
643
+ original_config = OmegaConf.load(args.original_config_file)
644
+ checkpoint = torch.load(args.checkpoint_path)["state_dict"]
645
+
646
+ num_train_timesteps = original_config.model.params.timesteps
647
+ beta_start = original_config.model.params.linear_start
648
+ beta_end = original_config.model.params.linear_end
649
+ if args.scheduler_type == "pndm":
650
+ scheduler = PNDMScheduler(
651
+ beta_end=beta_end,
652
+ beta_schedule="scaled_linear",
653
+ beta_start=beta_start,
654
+ num_train_timesteps=num_train_timesteps,
655
+ skip_prk_steps=True,
656
+ )
657
+ elif args.scheduler_type == "lms":
658
+ scheduler = LMSDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear")
659
+ elif args.scheduler_type == "ddim":
660
+ scheduler = DDIMScheduler(
661
+ beta_start=beta_start,
662
+ beta_end=beta_end,
663
+ beta_schedule="scaled_linear",
664
+ clip_sample=False,
665
+ set_alpha_to_one=False,
666
+ )
667
+ else:
668
+ raise ValueError(f"Scheduler of type {args.scheduler_type} doesn't exist!")
669
+
670
+ # Convert the UNet2DConditionModel model.
671
+ unet_config = create_unet_diffusers_config(original_config)
672
+ converted_unet_checkpoint = convert_ldm_unet_checkpoint(checkpoint, unet_config)
673
+
674
+ unet = UNet2DConditionModel(**unet_config)
675
+ unet.load_state_dict(converted_unet_checkpoint)
676
+
677
+ # Convert the VAE model.
678
+ vae_config = create_vae_diffusers_config(original_config)
679
+ converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
680
+
681
+ vae = AutoencoderKL(**vae_config)
682
+ vae.load_state_dict(converted_vae_checkpoint)
683
+
684
+ # Convert the text model.
685
+ text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
686
+ if text_model_type == "FrozenCLIPEmbedder":
687
+ text_model = convert_ldm_clip_checkpoint(checkpoint)
688
+ tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
689
+ safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
690
+ feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
691
+ pipe = StableDiffusionPipeline(
692
+ vae=vae,
693
+ text_encoder=text_model,
694
+ tokenizer=tokenizer,
695
+ unet=unet,
696
+ scheduler=scheduler,
697
+ safety_checker=safety_checker,
698
+ feature_extractor=feature_extractor,
699
+ )
700
+ else:
701
+ text_config = create_ldm_bert_config(original_config)
702
+ text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
703
+ tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
704
+ pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
705
+
706
+ return pipe
css/w2ui.min.css ADDED
The diff for this file is too large to render. See raw diff
 
index.html ADDED
@@ -0,0 +1,404 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <html>
2
+ <head>
3
+ <title>Stablediffusion Infinity</title>
4
+ <meta charset="utf-8">
5
+ <link rel="icon" type="image/x-icon" href="./favicon.png">
6
+
7
+ <link rel="stylesheet" type="text/css" href="https://cdn.jsdelivr.net/gh/lkwq007/stablediffusion-infinity@v0.1.0/css/w2ui.min.css">
8
+ <script type="text/javascript" src="https://cdn.jsdelivr.net/gh/lkwq007/stablediffusion-infinity@v0.1.0/js/w2ui.min.js"></script>
9
+ <link rel="stylesheet" type="text/css" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.2.0/css/all.min.css">
10
+ <script src="https://cdn.jsdelivr.net/gh/lkwq007/stablediffusion-infinity@v0.1.0/js/fabric.min.js"></script>
11
+ <script defer src="https://cdn.jsdelivr.net/gh/lkwq007/stablediffusion-infinity@v0.1.0/js/toolbar.js"></script>
12
+
13
+ <link rel="stylesheet" href="https://pyscript.net/alpha/pyscript.css" />
14
+ <script defer src="https://pyscript.net/alpha/pyscript.js"></script>
15
+
16
+ <style>
17
+ #container {
18
+ position: relative;
19
+ margin:auto;
20
+ display: block;
21
+ }
22
+ #container > canvas {
23
+ position: absolute;
24
+ top: 0;
25
+ left: 0;
26
+ }
27
+ .control {
28
+ display: none;
29
+ }
30
+ </style>
31
+
32
+ </head>
33
+ <body>
34
+ <div>
35
+ <button type="button" class="control" id="export">Export</button>
36
+ <button type="button" class="control" id="outpaint">Outpaint</button>
37
+ <button type="button" class="control" id="undo">Undo</button>
38
+ <button type="button" class="control" id="commit">Commit</button>
39
+ <button type="button" class="control" id="transfer">Transfer</button>
40
+ <button type="button" class="control" id="upload">Upload</button>
41
+ <button type="button" class="control" id="draw">Draw</button>
42
+ <input type="text" id="mode" value="selection" class="control">
43
+ <input type="text" id="setup" value="0" class="control">
44
+ <input type="text" id="upload_content" value="0" class="control">
45
+ <textarea rows="1" id="selbuffer" name="selbuffer" class="control"></textarea>
46
+ <fieldset class="control">
47
+ <div>
48
+ <input type="radio" id="mode0" name="mode" value="0" checked>
49
+ <label for="mode0">SelBox</label>
50
+ </div>
51
+ <div>
52
+ <input type="radio" id="mode1" name="mode" value="1">
53
+ <label for="mode1">Image</label>
54
+ </div>
55
+ <div>
56
+ <input type="radio" id="mode2" name="mode" value="2">
57
+ <label for="mode2">Brush</label>
58
+ </div>
59
+ </fieldset>
60
+ </div>
61
+ <div id = "outer_container">
62
+ <div id = "container">
63
+ <canvas id = "canvas0"></canvas>
64
+ <canvas id = "canvas1"></canvas>
65
+ <canvas id = "canvas2"></canvas>
66
+ <canvas id = "canvas3"></canvas>
67
+ <canvas id = "canvas4"></canvas>
68
+ <div id="overlay_container" style="pointer-events: none">
69
+ <canvas id = "overlay_canvas" width="1" height="1"></canvas>
70
+ </div>
71
+ </div>
72
+ <input type="file" name="file" id="upload_file" accept="image/*" hidden>
73
+ <input type="file" name="state" id="upload_state" accept=".sdinf" hidden>
74
+ <div style="position: relative;">
75
+ <div id="toolbar" style></div>
76
+ </div>
77
+ </div>
78
+ <py-env>
79
+ - numpy
80
+ - Pillow
81
+ - paths:
82
+ - ./canvas.py
83
+ </py-env>
84
+
85
+ <py-script>
86
+ from pyodide import to_js, create_proxy
87
+ from PIL import Image
88
+ import io
89
+ import time
90
+ import base64
91
+ import numpy as np
92
+ from js import (
93
+ console,
94
+ document,
95
+ parent,
96
+ devicePixelRatio,
97
+ ImageData,
98
+ Uint8ClampedArray,
99
+ CanvasRenderingContext2D as Context2d,
100
+ requestAnimationFrame,
101
+ window,
102
+ encodeURIComponent,
103
+ w2ui,
104
+ update_eraser,
105
+ update_scale,
106
+ adjust_selection,
107
+ update_count,
108
+ enable_result_lst,
109
+ setup_shortcut,
110
+ )
111
+
112
+
113
+ from canvas import InfCanvas
114
+
115
+
116
+
117
+ base_lst = [None]
118
+ async def draw_canvas() -> None:
119
+ width=1024
120
+ height=600
121
+ canvas=InfCanvas(1024,600)
122
+ update_eraser(canvas.eraser_size,min(canvas.selection_size_h,canvas.selection_size_w))
123
+ document.querySelector("#container").style.height= f"{height}px"
124
+ document.querySelector("#container").style.width = f"{width}px"
125
+ canvas.setup_mouse()
126
+ canvas.clear_background()
127
+ canvas.draw_buffer()
128
+ canvas.draw_selection_box()
129
+ base_lst[0]=canvas
130
+
131
+ async def draw_canvas_func():
132
+
133
+ width=1500
134
+ height=600
135
+ selection_size=256
136
+ document.querySelector("#container").style.width = f"{width}px"
137
+ document.querySelector("#container").style.height= f"{height}px"
138
+ canvas=InfCanvas(int(width),int(height),selection_size=int(selection_size))
139
+ canvas.setup_mouse()
140
+ canvas.clear_background()
141
+ canvas.draw_buffer()
142
+ canvas.draw_selection_box()
143
+ base_lst[0]=canvas
144
+
145
+ async def export_func(event):
146
+ base=base_lst[0]
147
+ arr=base.export()
148
+ base.draw_buffer()
149
+ base.canvas[2].clear()
150
+ base64_str = base.numpy_to_base64(arr)
151
+ time_str = time.strftime("%Y%m%d_%H%M%S")
152
+ link = document.createElement("a")
153
+ if len(event.data)>2 and event.data[2]:
154
+ filename = event.data[2]
155
+ else:
156
+ filename = f"outpaint_{time_str}"
157
+ # link.download = f"sdinf_state_{time_str}.json"
158
+ link.download = f"{filename}.png"
159
+ # link.download = f"outpaint_{time_str}.png"
160
+ link.href = "data:image/png;base64,"+base64_str
161
+ link.click()
162
+ console.log(f"Canvas saved to {filename}.png")
163
+
164
+ img_candidate_lst=[None,0]
165
+
166
+ async def outpaint_func(event):
167
+ base=base_lst[0]
168
+ if len(event.data)==2:
169
+ app=parent.document.querySelector("gradio-app")
170
+ if app.shadowRoot:
171
+ app=app.shadowRoot
172
+ base64_str_raw=app.querySelector("#output textarea").value
173
+ base64_str_lst=base64_str_raw.split(",")
174
+ img_candidate_lst[0]=base64_str_lst
175
+ img_candidate_lst[1]=0
176
+ elif event.data[2]=="next":
177
+ img_candidate_lst[1]+=1
178
+ elif event.data[2]=="prev":
179
+ img_candidate_lst[1]-=1
180
+ enable_result_lst()
181
+ if img_candidate_lst[0] is None:
182
+ return
183
+ lst=img_candidate_lst[0]
184
+ idx=img_candidate_lst[1]
185
+ update_count(idx%len(lst)+1,len(lst))
186
+ arr=base.base64_to_numpy(lst[idx%len(lst)])
187
+ base.fill_selection(arr)
188
+ base.draw_selection_box()
189
+
190
+ async def undo_func(event):
191
+ base=base_lst[0]
192
+ img_candidate_lst[0]=None
193
+ if base.sel_dirty:
194
+ base.sel_buffer = np.zeros((base.selection_size_h, base.selection_size_w, 4), dtype=np.uint8)
195
+ base.sel_dirty = False
196
+ base.canvas[2].clear()
197
+
198
+ async def commit_func(event):
199
+ base=base_lst[0]
200
+ img_candidate_lst[0]=None
201
+ if base.sel_dirty:
202
+ base.write_selection_to_buffer()
203
+ base.draw_buffer()
204
+ base.canvas[2].clear()
205
+
206
+ async def transfer_func(event):
207
+ base=base_lst[0]
208
+ base.read_selection_from_buffer()
209
+ sel_buffer=base.sel_buffer
210
+ sel_buffer_str=base.numpy_to_base64(sel_buffer)
211
+ app=parent.document.querySelector("gradio-app")
212
+ if app.shadowRoot:
213
+ app=app.shadowRoot
214
+ app.querySelector("#input textarea").value=sel_buffer_str
215
+ app.querySelector("#proceed").click()
216
+
217
+ async def upload_func(event):
218
+ base=base_lst[0]
219
+ # base64_str=event.data[1]
220
+ base64_str=document.querySelector("#upload_content").value
221
+ base64_str=base64_str.split(",")[-1]
222
+ # base64_str=parent.document.querySelector("gradio-app").shadowRoot.querySelector("#upload textarea").value
223
+ arr=base.base64_to_numpy(base64_str)
224
+ h,w,c=base.buffer.shape
225
+ base.sync_to_buffer()
226
+ base.buffer_dirty=True
227
+ mask=arr[:,:,3:4].repeat(4,axis=2)
228
+ base.buffer[mask>0]=0
229
+ # in case mismatch
230
+ base.buffer[0:h,0:w,:]+=arr
231
+ #base.buffer[yo:yo+h,xo:xo+w,0:3]=arr[:,:,0:3]
232
+ #base.buffer[yo:yo+h,xo:xo+w,-1]=arr[:,:,-1]
233
+ base.draw_buffer()
234
+
235
+ async def setup_shortcut_func(event):
236
+ setup_shortcut(event.data[1])
237
+
238
+
239
+ document.querySelector("#export").addEventListener("click",create_proxy(export_func))
240
+ document.querySelector("#undo").addEventListener("click",create_proxy(undo_func))
241
+ document.querySelector("#commit").addEventListener("click",create_proxy(commit_func))
242
+ document.querySelector("#outpaint").addEventListener("click",create_proxy(outpaint_func))
243
+ document.querySelector("#upload").addEventListener("click",create_proxy(upload_func))
244
+
245
+ document.querySelector("#transfer").addEventListener("click",create_proxy(transfer_func))
246
+ document.querySelector("#draw").addEventListener("click",create_proxy(draw_canvas_func))
247
+
248
+ async def setup_func():
249
+ document.querySelector("#setup").value="1"
250
+
251
+ async def reset_func(event):
252
+ base=base_lst[0]
253
+ base.reset()
254
+
255
+ async def load_func(event):
256
+ base=base_lst[0]
257
+ base.load(event.data[1])
258
+
259
+ async def save_func(event):
260
+ base=base_lst[0]
261
+ json_str=base.save()
262
+ time_str = time.strftime("%Y%m%d_%H%M%S")
263
+ link = document.createElement("a")
264
+ if len(event.data)>2 and event.data[2]:
265
+ filename = str(event.data[2]).strip()
266
+ else:
267
+ filename = f"outpaint_{time_str}"
268
+ # link.download = f"sdinf_state_{time_str}.json"
269
+ link.download = f"{filename}.sdinf"
270
+ link.href = "data:text/json;charset=utf-8,"+encodeURIComponent(json_str)
271
+ link.click()
272
+
273
+ async def prev_result_func(event):
274
+ base=base_lst[0]
275
+ base.reset()
276
+
277
+ async def next_result_func(event):
278
+ base=base_lst[0]
279
+ base.reset()
280
+
281
+ async def zoom_in_func(event):
282
+ base=base_lst[0]
283
+ scale=base.scale
284
+ if scale>=0.2:
285
+ scale-=0.1
286
+ if len(event.data)>2:
287
+ base.update_scale(scale,int(event.data[2]),int(event.data[3]))
288
+ else:
289
+ base.update_scale(scale)
290
+ scale=base.scale
291
+ update_scale(f"{base.width}x{base.height} ({round(100/scale)}%)")
292
+
293
+ async def zoom_out_func(event):
294
+ base=base_lst[0]
295
+ scale=base.scale
296
+ if scale<10:
297
+ scale+=0.1
298
+ console.log(len(event.data))
299
+ if len(event.data)>2:
300
+ base.update_scale(scale,int(event.data[2]),int(event.data[3]))
301
+ else:
302
+ base.update_scale(scale)
303
+ scale=base.scale
304
+ update_scale(f"{base.width}x{base.height} ({round(100/scale)}%)")
305
+
306
+ async def sync_func(event):
307
+ base=base_lst[0]
308
+ base.sync_to_buffer()
309
+ base.canvas[2].clear()
310
+
311
+ async def eraser_size_func(event):
312
+ base=base_lst[0]
313
+ eraser_size=min(int(event.data[1]),min(base.selection_size_h,base.selection_size_w))
314
+ eraser_size=max(8,eraser_size)
315
+ base.eraser_size=eraser_size
316
+
317
+ async def resize_selection_func(event):
318
+ base=base_lst[0]
319
+ cursor=base.cursor
320
+ if len(event.data)>3:
321
+ console.log(event.data)
322
+ base.cursor[0]=int(event.data[1])
323
+ base.cursor[1]=int(event.data[2])
324
+ base.selection_size_w=int(event.data[3])//8*8
325
+ base.selection_size_h=int(event.data[4])//8*8
326
+ base.refine_selection()
327
+ base.draw_selection_box()
328
+ elif len(event.data)>2:
329
+ base.draw_selection_box()
330
+ else:
331
+ base.canvas[-1].clear()
332
+ adjust_selection(cursor[0],cursor[1],base.selection_size_w,base.selection_size_h)
333
+
334
+ async def eraser_func(event):
335
+ base=base_lst[0]
336
+ if event.data[1]!="eraser":
337
+ base.canvas[-2].clear()
338
+ else:
339
+ x,y=base.mouse_pos
340
+ base.draw_eraser(x,y)
341
+
342
+ async def resize_func(event):
343
+ base=base_lst[0]
344
+ width=int(event.data[1])
345
+ height=int(event.data[2])
346
+ if width>=256 and height>=256:
347
+ if max(base.selection_size_h,base.selection_size_w)>min(width,height):
348
+ base.selection_size_h=256
349
+ base.selection_size_w=256
350
+ base.resize(width,height)
351
+
352
+ async def message_func(event):
353
+ if event.data[0]=="click":
354
+ if event.data[1]=="clear":
355
+ await reset_func(event)
356
+ elif event.data[1]=="save":
357
+ await save_func(event)
358
+ elif event.data[1]=="export":
359
+ await export_func(event)
360
+ elif event.data[1]=="accept":
361
+ await commit_func(event)
362
+ elif event.data[1]=="cancel":
363
+ await undo_func(event)
364
+ elif event.data[1]=="zoom_in":
365
+ await zoom_in_func(event)
366
+ elif event.data[1]=="zoom_out":
367
+ await zoom_out_func(event)
368
+ elif event.data[0]=="sync":
369
+ await sync_func(event)
370
+ elif event.data[0]=="load":
371
+ await load_func(event)
372
+ elif event.data[0]=="upload":
373
+ await upload_func(event)
374
+ elif event.data[0]=="outpaint":
375
+ await outpaint_func(event)
376
+ elif event.data[0]=="mode":
377
+ if event.data[1]!="selection":
378
+ await sync_func(event)
379
+ await eraser_func(event)
380
+ document.querySelector("#mode").value=event.data[1]
381
+ elif event.data[0]=="transfer":
382
+ await transfer_func(event)
383
+ elif event.data[0]=="setup":
384
+ await draw_canvas_func(event)
385
+ elif event.data[0]=="eraser_size":
386
+ await eraser_size_func(event)
387
+ elif event.data[0]=="resize_selection":
388
+ await resize_selection_func(event)
389
+ elif event.data[0]=="shortcut":
390
+ await setup_shortcut_func(event)
391
+ elif event.data[0]=="resize":
392
+ await resize_func(event)
393
+
394
+ window.addEventListener("message",create_proxy(message_func))
395
+
396
+ import asyncio
397
+
398
+ _ = await asyncio.gather(
399
+ setup_func(),draw_canvas_func()
400
+ )
401
+ </py-script>
402
+
403
+ </body>
404
+ </html>
js/fabric.min.js ADDED
The diff for this file is too large to render. See raw diff
 
js/keyboard.js ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ window.my_setup_keyboard=setInterval(function(){
3
+ let app=document.querySelector("gradio-app");
4
+ app=app.shadowRoot??app;
5
+ let frame=app.querySelector("#sdinfframe").contentWindow;
6
+ console.log("Check iframe...");
7
+ if(frame.setup_shortcut)
8
+ {
9
+ frame.setup_shortcut(json);
10
+ clearInterval(window.my_setup_keyboard);
11
+ }
12
+ }, 1000);
13
+ var config=JSON.parse(json);
14
+ var key_map={};
15
+ Object.keys(config.shortcut).forEach(k=>{
16
+ key_map[config.shortcut[k]]=k;
17
+ });
18
+ document.addEventListener("keydown", e => {
19
+ if(e.target.tagName!="INPUT"&&e.target.tagName!="GRADIO-APP"&&e.target.tagName!="TEXTAREA")
20
+ {
21
+ let key=e.key;
22
+ if(e.ctrlKey)
23
+ {
24
+ key="Ctrl+"+e.key;
25
+ if(key in key_map)
26
+ {
27
+ e.preventDefault();
28
+ }
29
+ }
30
+ let app=document.querySelector("gradio-app");
31
+ app=app.shadowRoot??app;
32
+ let frame=app.querySelector("#sdinfframe").contentDocument;
33
+ frame.dispatchEvent(
34
+ new KeyboardEvent("keydown", {key: e.key, ctrlKey: e.ctrlKey})
35
+ );
36
+ }
37
+ })
js/mode.js ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ function(mode){
2
+ let app=document.querySelector("gradio-app").shadowRoot;
3
+ let frame=app.querySelector("#sdinfframe").contentWindow.document;
4
+ frame.querySelector("#mode").value=mode;
5
+ return mode;
6
+ }
js/outpaint.js ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ function(a){
2
+ if(!window.my_observe_outpaint)
3
+ {
4
+ console.log("setup outpaint here");
5
+ window.my_observe_outpaint = new MutationObserver(function (event) {
6
+ console.log(event);
7
+ let app=document.querySelector("gradio-app");
8
+ app=app.shadowRoot??app;
9
+ let frame=app.querySelector("#sdinfframe").contentWindow;
10
+ frame.postMessage(["outpaint", ""], "*");
11
+ });
12
+ var app=document.querySelector("gradio-app");
13
+ app=app.shadowRoot??app;
14
+ window.my_observe_outpaint_target=app.querySelector("#output span");
15
+ window.my_observe_outpaint.observe(window.my_observe_outpaint_target, {
16
+ attributes: false,
17
+ subtree: true,
18
+ childList: true,
19
+ characterData: true
20
+ });
21
+ }
22
+ return a;
23
+ }
js/proceed.js ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ function(sel_buffer_str,
2
+ prompt_text,
3
+ negative_prompt_text,
4
+ strength,
5
+ guidance,
6
+ step,
7
+ resize_check,
8
+ fill_mode,
9
+ enable_safety,
10
+ use_correction,
11
+ enable_img2img,
12
+ use_seed,
13
+ seed_val,
14
+ generate_num,
15
+ scheduler,
16
+ scheduler_eta,
17
+ state){
18
+ let app=document.querySelector("gradio-app");
19
+ app=app.shadowRoot??app;
20
+ sel_buffer=app.querySelector("#input textarea").value;
21
+ let use_correction_bak=false;
22
+ ({resize_check,enable_safety,use_correction_bak,enable_img2img,use_seed,seed_val}=window.config_obj);
23
+ return [
24
+ sel_buffer,
25
+ prompt_text,
26
+ negative_prompt_text,
27
+ strength,
28
+ guidance,
29
+ step,
30
+ resize_check,
31
+ fill_mode,
32
+ enable_safety,
33
+ use_correction,
34
+ enable_img2img,
35
+ use_seed,
36
+ seed_val,
37
+ generate_num,
38
+ scheduler,
39
+ scheduler_eta,
40
+ state,
41
+ ]
42
+ }
js/setup.js ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ function(token_val, width, height, size, model_choice, model_path){
2
+ let app=document.querySelector("gradio-app");
3
+ app=app.shadowRoot??app;
4
+ app.querySelector("#sdinfframe").style.height=80+Number(height)+"px";
5
+ // app.querySelector("#setup_row").style.display="none";
6
+ app.querySelector("#model_path_input").style.display="none";
7
+ let frame=app.querySelector("#sdinfframe").contentWindow.document;
8
+
9
+ if(frame.querySelector("#setup").value=="0")
10
+ {
11
+ window.my_setup=setInterval(function(){
12
+ let app=document.querySelector("gradio-app");
13
+ app=app.shadowRoot??app;
14
+ let frame=app.querySelector("#sdinfframe").contentWindow.document;
15
+ console.log("Check PyScript...")
16
+ if(frame.querySelector("#setup").value=="1")
17
+ {
18
+ frame.querySelector("#draw").click();
19
+ clearInterval(window.my_setup);
20
+ }
21
+ }, 100)
22
+ }
23
+ else
24
+ {
25
+ frame.querySelector("#draw").click();
26
+ }
27
+ return [token_val, width, height, size, model_choice, model_path];
28
+ }
js/toolbar.js ADDED
@@ -0,0 +1,581 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // import { w2ui,w2toolbar,w2field,query,w2alert, w2utils,w2confirm} from "https://rawgit.com/vitmalina/w2ui/master/dist/w2ui.es6.min.js"
2
+ // import { w2ui,w2toolbar,w2field,query,w2alert, w2utils,w2confirm} from "https://cdn.jsdelivr.net/gh/vitmalina/w2ui@master/dist/w2ui.es6.min.js"
3
+
4
+ // https://stackoverflow.com/questions/36280818/how-to-convert-file-to-base64-in-javascript
5
+ function getBase64(file) {
6
+ var reader = new FileReader();
7
+ reader.readAsDataURL(file);
8
+ reader.onload = function () {
9
+ add_image(reader.result);
10
+ // console.log(reader.result);
11
+ };
12
+ reader.onerror = function (error) {
13
+ console.log("Error: ", error);
14
+ };
15
+ }
16
+
17
+ function getText(file) {
18
+ var reader = new FileReader();
19
+ reader.readAsText(file);
20
+ reader.onload = function () {
21
+ window.postMessage(["load",reader.result],"*")
22
+ // console.log(reader.result);
23
+ };
24
+ reader.onerror = function (error) {
25
+ console.log("Error: ", error);
26
+ };
27
+ }
28
+
29
+ document.querySelector("#upload_file").addEventListener("change", (event)=>{
30
+ console.log(event);
31
+ let file = document.querySelector("#upload_file").files[0];
32
+ getBase64(file);
33
+ })
34
+
35
+ document.querySelector("#upload_state").addEventListener("change", (event)=>{
36
+ console.log(event);
37
+ let file = document.querySelector("#upload_state").files[0];
38
+ getText(file);
39
+ })
40
+
41
+ open_setting = function() {
42
+ if (!w2ui.foo) {
43
+ new w2form({
44
+ name: "foo",
45
+ style: "border: 0px; background-color: transparent;",
46
+ fields: [{
47
+ field: "canvas_width",
48
+ type: "int",
49
+ required: true,
50
+ html: {
51
+ label: "Canvas Width"
52
+ }
53
+ },
54
+ {
55
+ field: "canvas_height",
56
+ type: "int",
57
+ required: true,
58
+ html: {
59
+ label: "Canvas Height"
60
+ }
61
+ },
62
+ ],
63
+ record: {
64
+ canvas_width: 1200,
65
+ canvas_height: 600,
66
+ },
67
+ actions: {
68
+ Save() {
69
+ this.validate();
70
+ let record = this.getCleanRecord();
71
+ window.postMessage(["resize",record.canvas_width,record.canvas_height],"*");
72
+ w2popup.close();
73
+ },
74
+ custom: {
75
+ text: "Cancel",
76
+ style: "text-transform: uppercase",
77
+ onClick(event) {
78
+ w2popup.close();
79
+ }
80
+ }
81
+ }
82
+ });
83
+ }
84
+ w2popup.open({
85
+ title: "Form in a Popup",
86
+ body: "<div id='form' style='width: 100%; height: 100%;''></div>",
87
+ style: "padding: 15px 0px 0px 0px",
88
+ width: 500,
89
+ height: 280,
90
+ showMax: true,
91
+ async onToggle(event) {
92
+ await event.complete
93
+ w2ui.foo.resize();
94
+ }
95
+ })
96
+ .then((event) => {
97
+ w2ui.foo.render("#form")
98
+ });
99
+ }
100
+
101
+ var button_lst=["clear", "load", "save", "export", "upload", "selection", "canvas", "eraser", "outpaint", "accept", "cancel", "retry", "prev", "current", "next", "eraser_size_btn", "eraser_size", "resize_selection", "scale", "zoom_in", "zoom_out", "help"];
102
+ var upload_button_lst=['clear', 'load', 'save', "upload", 'export', 'outpaint', 'resize_selection', 'help', "setting"];
103
+ var resize_button_lst=['clear', 'load', 'save', "upload", 'export', "selection", "canvas", "eraser", 'outpaint', 'resize_selection',"zoom_in", "zoom_out", 'help', "setting"];
104
+ var outpaint_button_lst=['clear', 'load', 'save', "canvas", "eraser", "upload", 'export', 'resize_selection', "zoom_in", "zoom_out",'help', "setting"];
105
+ var outpaint_result_lst=["accept", "cancel", "retry", "prev", "current", "next"];
106
+ var outpaint_result_func_lst=["accept", "retry", "prev", "current", "next"];
107
+
108
+ function check_button(id,text="",checked=true,tooltip="")
109
+ {
110
+ return { type: "check", id: id, text: text, icon: checked?"fa-solid fa-square-check":"fa-regular fa-square", checked: checked, tooltip: tooltip };
111
+ }
112
+
113
+ var toolbar=new w2toolbar({
114
+ box: "#toolbar",
115
+ name: "toolbar",
116
+ tooltip: "top",
117
+ items: [
118
+ { type: "button", id: "clear", text: "Reset", tooltip: "Reset Canvas", icon: "fa-solid fa-rectangle-xmark" },
119
+ { type: "break" },
120
+ { type: "button", id: "load", tooltip: "Load Canvas", icon: "fa-solid fa-file-import" },
121
+ { type: "button", id: "save", tooltip: "Save Canvas", icon: "fa-solid fa-file-export" },
122
+ { type: "button", id: "export", tooltip: "Export Image", icon: "fa-solid fa-floppy-disk" },
123
+ { type: "break" },
124
+ { type: "button", id: "upload", text: "Upload Image", icon: "fa-solid fa-upload" },
125
+ { type: "break" },
126
+ { type: "radio", id: "selection", group: "1", tooltip: "Selection", icon: "fa-solid fa-arrows-up-down-left-right", checked: true },
127
+ { type: "radio", id: "canvas", group: "1", tooltip: "Canvas", icon: "fa-solid fa-image" },
128
+ { type: "radio", id: "eraser", group: "1", tooltip: "Eraser", icon: "fa-solid fa-eraser" },
129
+ { type: "break" },
130
+ { type: "button", id: "outpaint", text: "Outpaint", tooltip: "Run Outpainting", icon: "fa-solid fa-brush" },
131
+ { type: "break" },
132
+ { type: "button", id: "accept", text: "Accept", tooltip: "Accept current result", icon: "fa-solid fa-check", hidden: true, disable:true,},
133
+ { type: "button", id: "cancel", text: "Cancel", tooltip: "Cancel current outpainting/error", icon: "fa-solid fa-ban", hidden: true},
134
+ { type: "button", id: "retry", text: "Retry", tooltip: "Retry", icon: "fa-solid fa-rotate", hidden: true, disable:true,},
135
+ { type: "button", id: "prev", tooltip: "Prev Result", icon: "fa-solid fa-caret-left", hidden: true, disable:true,},
136
+ { type: "html", id: "current", hidden: true, disable:true,
137
+ async onRefresh(event) {
138
+ await event.complete
139
+ let fragment = query.html(`
140
+ <div class="w2ui-tb-text">
141
+ <div class="w2ui-tb-count">
142
+ <span>${this.sel_value ?? "1/1"}</span>
143
+ </div> </div>`)
144
+ query(this.box).find("#tb_toolbar_item_current").append(fragment)
145
+ }
146
+ },
147
+ { type: "button", id: "next", tooltip: "Next Result", icon: "fa-solid fa-caret-right", hidden: true,disable:true,},
148
+ { type: "button", id: "add_image", text: "Add Image", icon: "fa-solid fa-file-circle-plus", hidden: true,disable:true,},
149
+ { type: "button", id: "delete_image", text: "Delete Image", icon: "fa-solid fa-trash-can", hidden: true,disable:true,},
150
+ { type: "button", id: "confirm", text: "Confirm", icon: "fa-solid fa-check", hidden: true,disable:true,},
151
+ { type: "button", id: "cancel_overlay", text: "Cancel", icon: "fa-solid fa-ban", hidden: true,disable:true,},
152
+ { type: "break" },
153
+ { type: "spacer" },
154
+ { type: "break" },
155
+ { type: "button", id: "eraser_size_btn", tooltip: "Eraser Size", text:"Size", icon: "fa-solid fa-eraser", hidden: true, count: 32},
156
+ { type: "html", id: "eraser_size", hidden: true,
157
+ async onRefresh(event) {
158
+ await event.complete
159
+ // let fragment = query.html(`
160
+ // <input type="number" size="${this.eraser_size ? this.eraser_size.length:"2"}" style="margin: 0px 3px; padding: 4px;" min="8" max="${this.eraser_max ?? "256"}" value="${this.eraser_size ?? "32"}">
161
+ // <input type="range" style="margin: 0px 3px; padding: 4px;" min="8" max="${this.eraser_max ?? "256"}" value="${this.eraser_size ?? "32"}">`)
162
+ let fragment = query.html(`
163
+ <input type="range" style="margin: 0px 3px; padding: 4px;" min="8" max="${this.eraser_max ?? "256"}" value="${this.eraser_size ?? "32"}">
164
+ `)
165
+ fragment.filter("input").on("change", event => {
166
+ this.eraser_size = event.target.value;
167
+ window.overlay.freeDrawingBrush.width=this.eraser_size;
168
+ this.setCount("eraser_size_btn", event.target.value);
169
+ window.postMessage(["eraser_size", event.target.value],"*")
170
+ this.refresh();
171
+ })
172
+ query(this.box).find("#tb_toolbar_item_eraser_size").append(fragment)
173
+ }
174
+ },
175
+ // { type: "button", id: "resize_eraser", tooltip: "Resize Eraser", icon: "fa-solid fa-sliders" },
176
+ { type: "button", id: "resize_selection", text: "Resize Selection", tooltip: "Resize Selection", icon: "fa-solid fa-expand" },
177
+ { type: "break" },
178
+ { type: "html", id: "scale",
179
+ async onRefresh(event) {
180
+ await event.complete
181
+ let fragment = query.html(`
182
+ <div class="">
183
+ <div style="padding: 4px; border: 1px solid silver">
184
+ <span>${this.scale_value ?? "100%"}</span>
185
+ </div></div>`)
186
+ query(this.box).find("#tb_toolbar_item_scale").append(fragment)
187
+ }
188
+ },
189
+ { type: "button", id: "zoom_in", tooltip: "Zoom In", icon: "fa-solid fa-magnifying-glass-plus" },
190
+ { type: "button", id: "zoom_out", tooltip: "Zoom Out", icon: "fa-solid fa-magnifying-glass-minus" },
191
+ { type: "break" },
192
+ { type: "button", id: "help", tooltip: "Help", icon: "fa-solid fa-circle-info" },
193
+ { type: "new-line"},
194
+ { type: "button", id: "setting", text: "Canvas Setting", tooltip: "Resize Canvas Here", icon: "fa-solid fa-sliders" },
195
+ { type: "break" },
196
+ check_button("enable_img2img","Enable Img2Img",false),
197
+ // check_button("use_correction","Photometric Correction",false),
198
+ check_button("resize_check","Resize Small Input",true),
199
+ check_button("enable_safety","Enable Safety Checker",true),
200
+ check_button("square_selection","Square Selection Only",false),
201
+ {type: "break"},
202
+ check_button("use_seed","Use Seed:",false),
203
+ { type: "html", id: "seed_val",
204
+ async onRefresh(event) {
205
+ await event.complete
206
+ let fragment = query.html(`
207
+ <input type="number" style="margin: 0px 3px; padding: 4px; width:100px;" value="${this.config_obj.seed_val ?? "0"}">`)
208
+ fragment.filter("input").on("change", event => {
209
+ this.config_obj.seed_val = event.target.value;
210
+ parent.config_obj=this.config_obj;
211
+ this.refresh();
212
+ })
213
+ query(this.box).find("#tb_toolbar_item_seed_val").append(fragment)
214
+ }
215
+ },
216
+ { type: "button", id: "random_seed", tooltip: "Set a random seed", icon: "fa-solid fa-dice" },
217
+ ],
218
+ onClick(event) {
219
+ switch(event.target){
220
+ case "setting":
221
+ open_setting();
222
+ break;
223
+ case "upload":
224
+ this.upload_mode=true
225
+ document.querySelector("#overlay_container").style.pointerEvents="auto";
226
+ this.click("canvas");
227
+ this.click("selection");
228
+ this.show("confirm","cancel_overlay","add_image","delete_image");
229
+ this.enable("confirm","cancel_overlay","add_image","delete_image");
230
+ this.disable(...upload_button_lst);
231
+ query("#upload_file").click();
232
+ if(this.upload_tip)
233
+ {
234
+ this.upload_tip=false;
235
+ w2utils.notify("Note that only visible images will be added to canvas",{timeout:10000,where:query("#container")})
236
+ }
237
+ break;
238
+ case "resize_selection":
239
+ this.resize_mode=true;
240
+ this.disable(...resize_button_lst);
241
+ this.enable("confirm","cancel_overlay");
242
+ this.show("confirm","cancel_overlay");
243
+ window.postMessage(["resize_selection",""],"*");
244
+ document.querySelector("#overlay_container").style.pointerEvents="auto";
245
+ break;
246
+ case "confirm":
247
+ if(this.upload_mode)
248
+ {
249
+ export_image();
250
+ }
251
+ else
252
+ {
253
+ let sel_box=this.selection_box;
254
+ window.postMessage(["resize_selection",sel_box.x,sel_box.y,sel_box.width,sel_box.height],"*");
255
+ }
256
+ case "cancel_overlay":
257
+ end_overlay();
258
+ this.hide("confirm","cancel_overlay","add_image","delete_image");
259
+ if(this.upload_mode){
260
+ this.enable(...upload_button_lst);
261
+ }
262
+ else
263
+ {
264
+ this.enable(...resize_button_lst);
265
+ window.postMessage(["resize_selection","",""],"*");
266
+ if(event.target=="cancel_overlay")
267
+ {
268
+ this.selection_box=this.selection_box_bak;
269
+ }
270
+ }
271
+ if(this.selection_box)
272
+ {
273
+ this.setCount("resize_selection",`${Math.floor(this.selection_box.width/8)*8}x${Math.floor(this.selection_box.height/8)*8}`);
274
+ }
275
+ this.disable("confirm","cancel_overlay","add_image","delete_image");
276
+ this.upload_mode=false;
277
+ this.resize_mode=false;
278
+ this.click("selection");
279
+ break;
280
+ case "add_image":
281
+ query("#upload_file").click();
282
+ break;
283
+ case "delete_image":
284
+ let active_obj = window.overlay.getActiveObject();
285
+ if(active_obj)
286
+ {
287
+ window.overlay.remove(active_obj);
288
+ window.overlay.renderAll();
289
+ }
290
+ else
291
+ {
292
+ w2utils.notify("You need to select an image first",{error:true,timeout:2000,where:query("#container")})
293
+ }
294
+ break;
295
+ case "load":
296
+ query("#upload_state").click();
297
+ this.selection_box=null;
298
+ this.setCount("resize_selection","");
299
+ break;
300
+ case "next":
301
+ case "prev":
302
+ window.postMessage(["outpaint", "", event.target], "*");
303
+ break;
304
+ case "outpaint":
305
+ this.click("selection");
306
+ this.disable(...outpaint_button_lst);
307
+ this.show(...outpaint_result_lst);
308
+ if(this.outpaint_tip)
309
+ {
310
+ this.outpaint_tip=false;
311
+ w2utils.notify("The canvas stays locked until you accept/cancel current outpainting",{timeout:10000,where:query("#container")})
312
+ }
313
+ document.querySelector("#container").style.pointerEvents="none";
314
+ case "retry":
315
+ this.disable(...outpaint_result_func_lst);
316
+ window.postMessage(["transfer",""],"*")
317
+ break;
318
+ case "accept":
319
+ case "cancel":
320
+ this.hide(...outpaint_result_lst);
321
+ this.disable(...outpaint_result_func_lst);
322
+ this.enable(...outpaint_button_lst);
323
+ document.querySelector("#container").style.pointerEvents="auto";
324
+ window.postMessage(["click", event.target],"*");
325
+ let app=parent.document.querySelector("gradio-app");
326
+ app=app.shadowRoot??app;
327
+ app.querySelector("#cancel").click();
328
+ break;
329
+ case "eraser":
330
+ case "selection":
331
+ case "canvas":
332
+ if(event.target=="eraser")
333
+ {
334
+ this.show("eraser_size","eraser_size_btn");
335
+ window.overlay.freeDrawingBrush.width=this.eraser_size;
336
+ window.overlay.isDrawingMode = true;
337
+ }
338
+ else
339
+ {
340
+ this.hide("eraser_size","eraser_size_btn");
341
+ window.overlay.isDrawingMode = false;
342
+ }
343
+ if(this.upload_mode)
344
+ {
345
+ if(event.target=="canvas")
346
+ {
347
+ window.postMessage(["mode", event.target],"*")
348
+ document.querySelector("#overlay_container").style.pointerEvents="none";
349
+ document.querySelector("#overlay_container").style.opacity = 0.5;
350
+ }
351
+ else
352
+ {
353
+ document.querySelector("#overlay_container").style.pointerEvents="auto";
354
+ document.querySelector("#overlay_container").style.opacity = 1.0;
355
+ }
356
+ }
357
+ else
358
+ {
359
+ window.postMessage(["mode", event.target],"*")
360
+ }
361
+ break;
362
+ case "help":
363
+ w2popup.open({
364
+ title: "Document",
365
+ body: "Usage: <a href='https://github.com/lkwq007/stablediffusion-infinity/blob/master/docs/usage.md' target='_blank'>https://github.com/lkwq007/stablediffusion-infinity/blob/master/docs/usage.md</a>"
366
+ })
367
+ break;
368
+ case "clear":
369
+ w2confirm("Reset canvas?").yes(() => {
370
+ window.postMessage(["click", event.target],"*");
371
+ }).no(() => {})
372
+ break;
373
+ case "random_seed":
374
+ this.config_obj.seed_val=Math.floor(Math.random() * 3000000000);
375
+ parent.config_obj=this.config_obj;
376
+ this.refresh();
377
+ break;
378
+ case "enable_img2img":
379
+ case "use_correction":
380
+ case "resize_check":
381
+ case "enable_safety":
382
+ case "use_seed":
383
+ case "square_selection":
384
+ let target=this.get(event.target);
385
+ target.icon=target.checked?"fa-regular fa-square":"fa-solid fa-square-check";
386
+ this.config_obj[event.target]=!target.checked;
387
+ parent.config_obj=this.config_obj;
388
+ this.refresh();
389
+ break;
390
+ case "save":
391
+ case "export":
392
+ ask_filename(event.target);
393
+ break;
394
+ default:
395
+ // clear, save, export, outpaint, retry
396
+ // break, save, export, accept, retry, outpaint
397
+ window.postMessage(["click", event.target],"*")
398
+ }
399
+ console.log("Target: "+ event.target, event)
400
+ }
401
+ })
402
+ window.w2ui=w2ui;
403
+ w2ui.toolbar.config_obj={
404
+ resize_check: true,
405
+ enable_safety: true,
406
+ use_correction: false,
407
+ enable_img2img: false,
408
+ use_seed: false,
409
+ seed_val: 0,
410
+ square_selection: false,
411
+ };
412
+ w2ui.toolbar.outpaint_tip=true;
413
+ w2ui.toolbar.upload_tip=true;
414
+ window.update_count=function(cur,total){
415
+ w2ui.toolbar.sel_value=`${cur}/${total}`;
416
+ w2ui.toolbar.refresh();
417
+ }
418
+ window.update_eraser=function(val,max_val){
419
+ w2ui.toolbar.eraser_size=`${val}`;
420
+ w2ui.toolbar.eraser_max=`${max_val}`;
421
+ w2ui.toolbar.setCount("eraser_size_btn", `${val}`);
422
+ w2ui.toolbar.refresh();
423
+ }
424
+ window.update_scale=function(val){
425
+ w2ui.toolbar.scale_value=`${val}`;
426
+ w2ui.toolbar.refresh();
427
+ }
428
+ window.enable_result_lst=function(){
429
+ w2ui.toolbar.enable(...outpaint_result_lst);
430
+ }
431
+ function onObjectScaled(e)
432
+ {
433
+ let object = e.target;
434
+ if(object.isType("rect"))
435
+ {
436
+ let width=object.getScaledWidth();
437
+ let height=object.getScaledHeight();
438
+ object.scale(1);
439
+ width=Math.max(Math.min(width,window.overlay.width-object.left),256);
440
+ height=Math.max(Math.min(height,window.overlay.height-object.top),256);
441
+ let l=Math.max(Math.min(object.left,window.overlay.width-width-object.strokeWidth),0);
442
+ let t=Math.max(Math.min(object.top,window.overlay.height-height-object.strokeWidth),0);
443
+ if(window.w2ui.toolbar.config_obj.square_selection)
444
+ {
445
+ let max_val = Math.min(Math.max(width,height),window.overlay.width,window.overlay.height);
446
+ width=max_val;
447
+ height=max_val;
448
+ }
449
+ object.set({ width: width, height: height, left:l,top:t})
450
+ window.w2ui.toolbar.selection_box={width: width, height: height, x:object.left, y:object.top};
451
+ window.w2ui.toolbar.setCount("resize_selection",`${Math.floor(width/8)*8}x${Math.floor(height/8)*8}`);
452
+ window.w2ui.toolbar.refresh();
453
+ }
454
+ }
455
+ function onObjectMoved(e)
456
+ {
457
+ let object = e.target;
458
+ if(object.isType("rect"))
459
+ {
460
+ let l=Math.max(Math.min(object.left,window.overlay.width-object.width-object.strokeWidth),0);
461
+ let t=Math.max(Math.min(object.top,window.overlay.height-object.height-object.strokeWidth),0);
462
+ object.set({left:l,top:t});
463
+ window.w2ui.toolbar.selection_box={width: object.width, height: object.height, x:object.left, y:object.top};
464
+ }
465
+ }
466
+ window.setup_overlay=function(width,height)
467
+ {
468
+ if(window.overlay)
469
+ {
470
+ window.overlay.setDimensions({width:width,height:height});
471
+ let app=parent.document.querySelector("gradio-app");
472
+ app=app.shadowRoot??app;
473
+ app.querySelector("#sdinfframe").style.height=80+Number(height)+"px";
474
+ document.querySelector("#container").style.height= height+"px";
475
+ document.querySelector("#container").style.width = width+"px";
476
+ }
477
+ else
478
+ {
479
+ canvas=new fabric.Canvas("overlay_canvas");
480
+ canvas.setDimensions({width:width,height:height});
481
+ let app=parent.document.querySelector("gradio-app");
482
+ app=app.shadowRoot??app;
483
+ app.querySelector("#sdinfframe").style.height=80+Number(height)+"px";
484
+ canvas.freeDrawingBrush = new fabric.EraserBrush(canvas);
485
+ canvas.on("object:scaling", onObjectScaled);
486
+ canvas.on("object:moving", onObjectMoved);
487
+ window.overlay=canvas;
488
+ }
489
+ document.querySelector("#overlay_container").style.pointerEvents="none";
490
+ }
491
+ window.update_overlay=function(width,height)
492
+ {
493
+ window.overlay.setDimensions({width:width,height:height},{backstoreOnly:true});
494
+ // document.querySelector("#overlay_container").style.pointerEvents="none";
495
+ }
496
+ window.adjust_selection=function(x,y,width,height)
497
+ {
498
+ var rect = new fabric.Rect({
499
+ left: x,
500
+ top: y,
501
+ fill: "rgba(0,0,0,0)",
502
+ strokeWidth: 3,
503
+ stroke: "rgba(0,0,0,0.7)",
504
+ cornerColor: "red",
505
+ cornerStrokeColor: "red",
506
+ borderColor: "rgba(255, 0, 0, 1.0)",
507
+ width: width,
508
+ height: height,
509
+ lockRotation: true,
510
+ });
511
+ rect.setControlsVisibility({ mtr: false });
512
+ window.overlay.add(rect);
513
+ window.overlay.setActiveObject(window.overlay.item(0));
514
+ window.w2ui.toolbar.selection_box={width: width, height: height, x:x, y:y};
515
+ window.w2ui.toolbar.selection_box_bak={width: width, height: height, x:x, y:y};
516
+ }
517
+ function add_image(url)
518
+ {
519
+ fabric.Image.fromURL(url,function(img){
520
+ window.overlay.add(img);
521
+ window.overlay.setActiveObject(img);
522
+ },{left:100,top:100});
523
+ }
524
+ function export_image()
525
+ {
526
+ data=window.overlay.toDataURL();
527
+ document.querySelector("#upload_content").value=data;
528
+ window.postMessage(["upload",""],"*");
529
+ end_overlay();
530
+ }
531
+ function end_overlay()
532
+ {
533
+ window.overlay.clear();
534
+ document.querySelector("#overlay_container").style.opacity = 1.0;
535
+ document.querySelector("#overlay_container").style.pointerEvents="none";
536
+ }
537
+ function ask_filename(target)
538
+ {
539
+ w2prompt({
540
+ label: "Enter filename",
541
+ value: `outpaint_${((new Date(Date.now() -(new Date()).getTimezoneOffset() * 60000))).toISOString().replace("T","_").replace(/[^0-9_]/g, "").substring(0,15)}`,
542
+ })
543
+ .change((event) => {
544
+ console.log("change", event.detail.originalEvent.target.value);
545
+ })
546
+ .ok((event) => {
547
+ console.log("value=", event.detail.value);
548
+ window.postMessage(["click",target,event.detail.value],"*");
549
+ })
550
+ .cancel((event) => {
551
+ console.log("cancel");
552
+ });
553
+ }
554
+
555
+ document.querySelector("#container").addEventListener("wheel",(e)=>{e.preventDefault()})
556
+ window.setup_shortcut=function(json)
557
+ {
558
+ var config=JSON.parse(json);
559
+ var key_map={};
560
+ Object.keys(config.shortcut).forEach(k=>{
561
+ key_map[config.shortcut[k]]=k;
562
+ })
563
+ document.addEventListener("keydown",(e)=>{
564
+ if(e.target.tagName!="INPUT")
565
+ {
566
+ let key=e.key;
567
+ if(e.ctrlKey)
568
+ {
569
+ key="Ctrl+"+e.key;
570
+ if(key in key_map)
571
+ {
572
+ e.preventDefault();
573
+ }
574
+ }
575
+ if(key in key_map)
576
+ {
577
+ w2ui.toolbar.click(key_map[key]);
578
+ }
579
+ }
580
+ })
581
+ }
js/upload.js ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ function(a,b){
2
+ if(!window.my_observe_upload)
3
+ {
4
+ console.log("setup upload here");
5
+ window.my_observe_upload = new MutationObserver(function (event) {
6
+ console.log(event);
7
+ var frame=document.querySelector("gradio-app").shadowRoot.querySelector("#sdinfframe").contentWindow.document;
8
+ frame.querySelector("#upload").click();
9
+ });
10
+ window.my_observe_upload_target = document.querySelector("gradio-app").shadowRoot.querySelector("#upload span");
11
+ window.my_observe_upload.observe(window.my_observe_upload_target, {
12
+ attributes: false,
13
+ subtree: true,
14
+ childList: true,
15
+ characterData: true
16
+ });
17
+ }
18
+ return [a,b];
19
+ }
js/w2ui.min.js ADDED
The diff for this file is too large to render. See raw diff
 
js/xss.js ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ var setup_outpaint=function(){
2
+ if(!window.my_observe_outpaint)
3
+ {
4
+ console.log("setup outpaint here");
5
+ window.my_observe_outpaint = new MutationObserver(function (event) {
6
+ console.log(event);
7
+ let app=document.querySelector("gradio-app");
8
+ app=app.shadowRoot??app;
9
+ let frame=app.querySelector("#sdinfframe").contentWindow;
10
+ frame.postMessage(["outpaint", ""], "*");
11
+ });
12
+ var app=document.querySelector("gradio-app");
13
+ app=app.shadowRoot??app;
14
+ window.my_observe_outpaint_target=app.querySelector("#output span");
15
+ window.my_observe_outpaint.observe(window.my_observe_outpaint_target, {
16
+ attributes: false,
17
+ subtree: true,
18
+ childList: true,
19
+ characterData: true
20
+ });
21
+ }
22
+ };
23
+ window.config_obj={
24
+ resize_check: true,
25
+ enable_safety: true,
26
+ use_correction: false,
27
+ enable_img2img: false,
28
+ use_seed: false,
29
+ seed_val: 0,
30
+ };
31
+ setup_outpaint();
models/v1-inference.yaml ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 1.0e-04
3
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
4
+ params:
5
+ linear_start: 0.00085
6
+ linear_end: 0.0120
7
+ num_timesteps_cond: 1
8
+ log_every_t: 200
9
+ timesteps: 1000
10
+ first_stage_key: "jpg"
11
+ cond_stage_key: "txt"
12
+ image_size: 64
13
+ channels: 4
14
+ cond_stage_trainable: false # Note: different from the one we trained before
15
+ conditioning_key: crossattn
16
+ monitor: val/loss_simple_ema
17
+ scale_factor: 0.18215
18
+ use_ema: False
19
+
20
+ scheduler_config: # 10000 warmup steps
21
+ target: ldm.lr_scheduler.LambdaLinearScheduler
22
+ params:
23
+ warm_up_steps: [ 10000 ]
24
+ cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
25
+ f_start: [ 1.e-6 ]
26
+ f_max: [ 1. ]
27
+ f_min: [ 1. ]
28
+
29
+ unet_config:
30
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
31
+ params:
32
+ image_size: 32 # unused
33
+ in_channels: 4
34
+ out_channels: 4
35
+ model_channels: 320
36
+ attention_resolutions: [ 4, 2, 1 ]
37
+ num_res_blocks: 2
38
+ channel_mult: [ 1, 2, 4, 4 ]
39
+ num_heads: 8
40
+ use_spatial_transformer: True
41
+ transformer_depth: 1
42
+ context_dim: 768
43
+ use_checkpoint: True
44
+ legacy: False
45
+
46
+ first_stage_config:
47
+ target: ldm.models.autoencoder.AutoencoderKL
48
+ params:
49
+ embed_dim: 4
50
+ monitor: val/rec_loss
51
+ ddconfig:
52
+ double_z: true
53
+ z_channels: 4
54
+ resolution: 256
55
+ in_channels: 3
56
+ out_ch: 3
57
+ ch: 128
58
+ ch_mult:
59
+ - 1
60
+ - 2
61
+ - 4
62
+ - 4
63
+ num_res_blocks: 2
64
+ attn_resolutions: []
65
+ dropout: 0.0
66
+ lossconfig:
67
+ target: torch.nn.Identity
68
+
69
+ cond_stage_config:
70
+ target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
models/v1-inpainting-inference.yaml ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 7.5e-05
3
+ target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion
4
+ params:
5
+ linear_start: 0.00085
6
+ linear_end: 0.0120
7
+ num_timesteps_cond: 1
8
+ log_every_t: 200
9
+ timesteps: 1000
10
+ first_stage_key: "jpg"
11
+ cond_stage_key: "txt"
12
+ image_size: 64
13
+ channels: 4
14
+ cond_stage_trainable: false # Note: different from the one we trained before
15
+ conditioning_key: hybrid # important
16
+ monitor: val/loss_simple_ema
17
+ scale_factor: 0.18215
18
+ finetune_keys: null
19
+
20
+ scheduler_config: # 10000 warmup steps
21
+ target: ldm.lr_scheduler.LambdaLinearScheduler
22
+ params:
23
+ warm_up_steps: [ 2500 ] # NOTE for resuming. use 10000 if starting from scratch
24
+ cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
25
+ f_start: [ 1.e-6 ]
26
+ f_max: [ 1. ]
27
+ f_min: [ 1. ]
28
+
29
+ unet_config:
30
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
31
+ params:
32
+ image_size: 32 # unused
33
+ in_channels: 9 # 4 data + 4 downscaled image + 1 mask
34
+ out_channels: 4
35
+ model_channels: 320
36
+ attention_resolutions: [ 4, 2, 1 ]
37
+ num_res_blocks: 2
38
+ channel_mult: [ 1, 2, 4, 4 ]
39
+ num_heads: 8
40
+ use_spatial_transformer: True
41
+ transformer_depth: 1
42
+ context_dim: 768
43
+ use_checkpoint: True
44
+ legacy: False
45
+
46
+ first_stage_config:
47
+ target: ldm.models.autoencoder.AutoencoderKL
48
+ params:
49
+ embed_dim: 4
50
+ monitor: val/rec_loss
51
+ ddconfig:
52
+ double_z: true
53
+ z_channels: 4
54
+ resolution: 256
55
+ in_channels: 3
56
+ out_ch: 3
57
+ ch: 128
58
+ ch_mult:
59
+ - 1
60
+ - 2
61
+ - 4
62
+ - 4
63
+ num_res_blocks: 2
64
+ attn_resolutions: []
65
+ dropout: 0.0
66
+ lossconfig:
67
+ target: torch.nn.Identity
68
+
69
+ cond_stage_config:
70
+ target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
packages.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ build-essential
2
+ python3-opencv
3
+ libopencv-dev
4
+ cmake
perlin2d.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ ##########
4
+ # https://stackoverflow.com/questions/42147776/producing-2d-perlin-noise-with-numpy/42154921#42154921
5
+ def perlin(x, y, seed=0):
6
+ # permutation table
7
+ np.random.seed(seed)
8
+ p = np.arange(256, dtype=int)
9
+ np.random.shuffle(p)
10
+ p = np.stack([p, p]).flatten()
11
+ # coordinates of the top-left
12
+ xi, yi = x.astype(int), y.astype(int)
13
+ # internal coordinates
14
+ xf, yf = x - xi, y - yi
15
+ # fade factors
16
+ u, v = fade(xf), fade(yf)
17
+ # noise components
18
+ n00 = gradient(p[p[xi] + yi], xf, yf)
19
+ n01 = gradient(p[p[xi] + yi + 1], xf, yf - 1)
20
+ n11 = gradient(p[p[xi + 1] + yi + 1], xf - 1, yf - 1)
21
+ n10 = gradient(p[p[xi + 1] + yi], xf - 1, yf)
22
+ # combine noises
23
+ x1 = lerp(n00, n10, u)
24
+ x2 = lerp(n01, n11, u) # FIX1: I was using n10 instead of n01
25
+ return lerp(x1, x2, v) # FIX2: I also had to reverse x1 and x2 here
26
+
27
+
28
+ def lerp(a, b, x):
29
+ "linear interpolation"
30
+ return a + x * (b - a)
31
+
32
+
33
+ def fade(t):
34
+ "6t^5 - 15t^4 + 10t^3"
35
+ return 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3
36
+
37
+
38
+ def gradient(h, x, y):
39
+ "grad converts h to the right gradient vector and return the dot product with (x,y)"
40
+ vectors = np.array([[0, 1], [0, -1], [1, 0], [-1, 0]])
41
+ g = vectors[h % 4]
42
+ return g[:, :, 0] * x + g[:, :, 1] * y
43
+
44
+
45
+ ##########
postprocess.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ https://github.com/Trinkle23897/Fast-Poisson-Image-Editing
3
+ MIT License
4
+
5
+ Copyright (c) 2022 Jiayi Weng
6
+
7
+ Permission is hereby granted, free of charge, to any person obtaining a copy
8
+ of this software and associated documentation files (the "Software"), to deal
9
+ in the Software without restriction, including without limitation the rights
10
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11
+ copies of the Software, and to permit persons to whom the Software is
12
+ furnished to do so, subject to the following conditions:
13
+
14
+ The above copyright notice and this permission notice shall be included in all
15
+ copies or substantial portions of the Software.
16
+
17
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23
+ SOFTWARE.
24
+ """
25
+
26
+ import time
27
+ import argparse
28
+ import os
29
+ import fpie
30
+ from process import ALL_BACKEND, CPU_COUNT, DEFAULT_BACKEND
31
+ from fpie.io import read_images, write_image
32
+ from process import BaseProcessor, EquProcessor, GridProcessor
33
+
34
+ from PIL import Image
35
+ import numpy as np
36
+ import skimage
37
+ import skimage.measure
38
+ import scipy
39
+ import scipy.signal
40
+
41
+
42
+ class PhotometricCorrection:
43
+ def __init__(self,quite=False):
44
+ self.get_parser("cli")
45
+ args=self.parser.parse_args(["--method","grid","-g","src","-s","a","-t","a","-o","a"])
46
+ args.mpi_sync_interval = getattr(args, "mpi_sync_interval", 0)
47
+ self.backend=args.backend
48
+ self.args=args
49
+ self.quite=quite
50
+ proc: BaseProcessor
51
+ proc = GridProcessor(
52
+ args.gradient,
53
+ args.backend,
54
+ args.cpu,
55
+ args.mpi_sync_interval,
56
+ args.block_size,
57
+ args.grid_x,
58
+ args.grid_y,
59
+ )
60
+ print(
61
+ f"[PIE]Successfully initialize PIE {args.method} solver "
62
+ f"with {args.backend} backend"
63
+ )
64
+ self.proc=proc
65
+
66
+ def run(self, original_image, inpainted_image, mode="mask_mode"):
67
+ print(f"[PIE] start")
68
+ if mode=="disabled":
69
+ return inpainted_image
70
+ input_arr=np.array(original_image)
71
+ if input_arr[:,:,-1].sum()<1:
72
+ return inpainted_image
73
+ output_arr=np.array(inpainted_image)
74
+ mask=input_arr[:,:,-1]
75
+ mask=255-mask
76
+ if mask.sum()<1 and mode=="mask_mode":
77
+ mode=""
78
+ if mode=="mask_mode":
79
+ mask = skimage.measure.block_reduce(mask, (8, 8), np.max)
80
+ mask = mask.repeat(8, axis=0).repeat(8, axis=1)
81
+ else:
82
+ mask[8:-9,8:-9]=255
83
+ mask = mask[:,:,np.newaxis].repeat(3,axis=2)
84
+ nmask=mask.copy()
85
+ output_arr2=output_arr[:,:,0:3].copy()
86
+ input_arr2=input_arr[:,:,0:3].copy()
87
+ output_arr2[nmask<128]=0
88
+ input_arr2[nmask>=128]=0
89
+ output_arr2+=input_arr2
90
+ src = output_arr2[:,:,0:3]
91
+ tgt = src.copy()
92
+ proc=self.proc
93
+ args=self.args
94
+ if proc.root:
95
+ n = proc.reset(src, mask, tgt, (args.h0, args.w0), (args.h1, args.w1))
96
+ proc.sync()
97
+ if proc.root:
98
+ result = tgt
99
+ t = time.time()
100
+ if args.p == 0:
101
+ args.p = args.n
102
+
103
+ for i in range(0, args.n, args.p):
104
+ if proc.root:
105
+ result, err = proc.step(args.p) # type: ignore
106
+ print(f"[PIE] Iter {i + args.p}, abs_err {err}")
107
+ else:
108
+ proc.step(args.p)
109
+
110
+ if proc.root:
111
+ dt = time.time() - t
112
+ print(f"[PIE] Time elapsed: {dt:.4f}s")
113
+ # make sure consistent with dummy process
114
+ return Image.fromarray(result)
115
+
116
+
117
+ def get_parser(self,gen_type: str) -> argparse.Namespace:
118
+ parser = argparse.ArgumentParser()
119
+ parser.add_argument(
120
+ "-v", "--version", action="store_true", help="show the version and exit"
121
+ )
122
+ parser.add_argument(
123
+ "--check-backend", action="store_true", help="print all available backends"
124
+ )
125
+ if gen_type == "gui" and "mpi" in ALL_BACKEND:
126
+ # gui doesn't support MPI backend
127
+ ALL_BACKEND.remove("mpi")
128
+ parser.add_argument(
129
+ "-b",
130
+ "--backend",
131
+ type=str,
132
+ choices=ALL_BACKEND,
133
+ default=DEFAULT_BACKEND,
134
+ help="backend choice",
135
+ )
136
+ parser.add_argument(
137
+ "-c",
138
+ "--cpu",
139
+ type=int,
140
+ default=CPU_COUNT,
141
+ help="number of CPU used",
142
+ )
143
+ parser.add_argument(
144
+ "-z",
145
+ "--block-size",
146
+ type=int,
147
+ default=1024,
148
+ help="cuda block size (only for equ solver)",
149
+ )
150
+ parser.add_argument(
151
+ "--method",
152
+ type=str,
153
+ choices=["equ", "grid"],
154
+ default="equ",
155
+ help="how to parallelize computation",
156
+ )
157
+ parser.add_argument("-s", "--source", type=str, help="source image filename")
158
+ if gen_type == "cli":
159
+ parser.add_argument(
160
+ "-m",
161
+ "--mask",
162
+ type=str,
163
+ help="mask image filename (default is to use the whole source image)",
164
+ default="",
165
+ )
166
+ parser.add_argument("-t", "--target", type=str, help="target image filename")
167
+ parser.add_argument("-o", "--output", type=str, help="output image filename")
168
+ if gen_type == "cli":
169
+ parser.add_argument(
170
+ "-h0", type=int, help="mask position (height) on source image", default=0
171
+ )
172
+ parser.add_argument(
173
+ "-w0", type=int, help="mask position (width) on source image", default=0
174
+ )
175
+ parser.add_argument(
176
+ "-h1", type=int, help="mask position (height) on target image", default=0
177
+ )
178
+ parser.add_argument(
179
+ "-w1", type=int, help="mask position (width) on target image", default=0
180
+ )
181
+ parser.add_argument(
182
+ "-g",
183
+ "--gradient",
184
+ type=str,
185
+ choices=["max", "src", "avg"],
186
+ default="max",
187
+ help="how to calculate gradient for PIE",
188
+ )
189
+ parser.add_argument(
190
+ "-n",
191
+ type=int,
192
+ help="how many iteration would you perfer, the more the better",
193
+ default=5000,
194
+ )
195
+ if gen_type == "cli":
196
+ parser.add_argument(
197
+ "-p", type=int, help="output result every P iteration", default=0
198
+ )
199
+ if "mpi" in ALL_BACKEND:
200
+ parser.add_argument(
201
+ "--mpi-sync-interval",
202
+ type=int,
203
+ help="MPI sync iteration interval",
204
+ default=100,
205
+ )
206
+ parser.add_argument(
207
+ "--grid-x", type=int, help="x axis stride for grid solver", default=8
208
+ )
209
+ parser.add_argument(
210
+ "--grid-y", type=int, help="y axis stride for grid solver", default=8
211
+ )
212
+ self.parser=parser
213
+
214
+ if __name__ =="__main__":
215
+ import sys
216
+ import io
217
+ import base64
218
+ from PIL import Image
219
+ def base64_to_pil(base64_str):
220
+ data = base64.b64decode(str(base64_str))
221
+ pil = Image.open(io.BytesIO(data))
222
+ return pil
223
+
224
+ def pil_to_base64(out_pil):
225
+ out_buffer = io.BytesIO()
226
+ out_pil.save(out_buffer, format="PNG")
227
+ out_buffer.seek(0)
228
+ base64_bytes = base64.b64encode(out_buffer.read())
229
+ base64_str = base64_bytes.decode("ascii")
230
+ return base64_str
231
+ correction_func=PhotometricCorrection(quite=True)
232
+ while True:
233
+ buffer = sys.stdin.readline()
234
+ print(f"[PIE] suprocess {len(buffer)} {type(buffer)} ")
235
+ if len(buffer)==0:
236
+ break
237
+ if isinstance(buffer,str):
238
+ lst=buffer.strip().split(",")
239
+ else:
240
+ lst=buffer.decode("ascii").strip().split(",")
241
+ img0=base64_to_pil(lst[0])
242
+ img1=base64_to_pil(lst[1])
243
+ ret=correction_func.run(img0,img1,mode=lst[2])
244
+ ret_base64=pil_to_base64(ret)
245
+ if isinstance(buffer,str):
246
+ sys.stdout.write(f"{ret_base64}\n")
247
+ else:
248
+ sys.stdout.write(f"{ret_base64}\n".encode())
249
+ sys.stdout.flush()
process.py ADDED
@@ -0,0 +1,395 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ https://github.com/Trinkle23897/Fast-Poisson-Image-Editing
3
+ MIT License
4
+
5
+ Copyright (c) 2022 Jiayi Weng
6
+
7
+ Permission is hereby granted, free of charge, to any person obtaining a copy
8
+ of this software and associated documentation files (the "Software"), to deal
9
+ in the Software without restriction, including without limitation the rights
10
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11
+ copies of the Software, and to permit persons to whom the Software is
12
+ furnished to do so, subject to the following conditions:
13
+
14
+ The above copyright notice and this permission notice shall be included in all
15
+ copies or substantial portions of the Software.
16
+
17
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23
+ SOFTWARE.
24
+ """
25
+ import os
26
+ from abc import ABC, abstractmethod
27
+ from typing import Any, Optional, Tuple
28
+
29
+ import numpy as np
30
+
31
+ from fpie import np_solver
32
+
33
+ import scipy
34
+ import scipy.signal
35
+
36
+ CPU_COUNT = os.cpu_count() or 1
37
+ DEFAULT_BACKEND = "numpy"
38
+ ALL_BACKEND = ["numpy"]
39
+
40
+ try:
41
+ from fpie import numba_solver
42
+ ALL_BACKEND += ["numba"]
43
+ DEFAULT_BACKEND = "numba"
44
+ except ImportError:
45
+ numba_solver = None # type: ignore
46
+
47
+ try:
48
+ from fpie import taichi_solver
49
+ ALL_BACKEND += ["taichi-cpu", "taichi-gpu"]
50
+ DEFAULT_BACKEND = "taichi-cpu"
51
+ except ImportError:
52
+ taichi_solver = None # type: ignore
53
+
54
+ # try:
55
+ # from fpie import core_gcc # type: ignore
56
+ # DEFAULT_BACKEND = "gcc"
57
+ # ALL_BACKEND.append("gcc")
58
+ # except ImportError:
59
+ # core_gcc = None
60
+
61
+ # try:
62
+ # from fpie import core_openmp # type: ignore
63
+ # DEFAULT_BACKEND = "openmp"
64
+ # ALL_BACKEND.append("openmp")
65
+ # except ImportError:
66
+ # core_openmp = None
67
+
68
+ # try:
69
+ # from mpi4py import MPI
70
+
71
+ # from fpie import core_mpi # type: ignore
72
+ # ALL_BACKEND.append("mpi")
73
+ # except ImportError:
74
+ # MPI = None # type: ignore
75
+ # core_mpi = None
76
+
77
+ try:
78
+ from fpie import core_cuda # type: ignore
79
+ DEFAULT_BACKEND = "cuda"
80
+ ALL_BACKEND.append("cuda")
81
+ except ImportError:
82
+ core_cuda = None
83
+
84
+
85
+ class BaseProcessor(ABC):
86
+ """API definition for processor class."""
87
+
88
+ def __init__(
89
+ self, gradient: str, rank: int, backend: str, core: Optional[Any]
90
+ ):
91
+ if core is None:
92
+ error_msg = {
93
+ "numpy":
94
+ "Please run `pip install numpy`.",
95
+ "numba":
96
+ "Please run `pip install numba`.",
97
+ "gcc":
98
+ "Please install cmake and gcc in your operating system.",
99
+ "openmp":
100
+ "Please make sure your gcc is compatible with `-fopenmp` option.",
101
+ "mpi":
102
+ "Please install MPI and run `pip install mpi4py`.",
103
+ "cuda":
104
+ "Please make sure nvcc and cuda-related libraries are available.",
105
+ "taichi":
106
+ "Please run `pip install taichi`.",
107
+ }
108
+ print(error_msg[backend.split("-")[0]])
109
+
110
+ raise AssertionError(f"Invalid backend {backend}.")
111
+
112
+ self.gradient = gradient
113
+ self.rank = rank
114
+ self.backend = backend
115
+ self.core = core
116
+ self.root = rank == 0
117
+
118
+ def mixgrad(self, a: np.ndarray, b: np.ndarray) -> np.ndarray:
119
+ if self.gradient == "src":
120
+ return a
121
+ if self.gradient == "avg":
122
+ return (a + b) / 2
123
+ # mix gradient, see Equ. 12 in PIE paper
124
+ mask = np.abs(a) < np.abs(b)
125
+ a[mask] = b[mask]
126
+ return a
127
+
128
+ @abstractmethod
129
+ def reset(
130
+ self,
131
+ src: np.ndarray,
132
+ mask: np.ndarray,
133
+ tgt: np.ndarray,
134
+ mask_on_src: Tuple[int, int],
135
+ mask_on_tgt: Tuple[int, int],
136
+ ) -> int:
137
+ pass
138
+
139
+ def sync(self) -> None:
140
+ self.core.sync()
141
+
142
+ @abstractmethod
143
+ def step(self, iteration: int) -> Optional[Tuple[np.ndarray, np.ndarray]]:
144
+ pass
145
+
146
+
147
+ class EquProcessor(BaseProcessor):
148
+ """PIE Jacobi equation processor."""
149
+
150
+ def __init__(
151
+ self,
152
+ gradient: str = "max",
153
+ backend: str = DEFAULT_BACKEND,
154
+ n_cpu: int = CPU_COUNT,
155
+ min_interval: int = 100,
156
+ block_size: int = 1024,
157
+ ):
158
+ core: Optional[Any] = None
159
+ rank = 0
160
+
161
+ if backend == "numpy":
162
+ core = np_solver.EquSolver()
163
+ elif backend == "numba" and numba_solver is not None:
164
+ core = numba_solver.EquSolver()
165
+ elif backend == "gcc":
166
+ core = core_gcc.EquSolver()
167
+ elif backend == "openmp" and core_openmp is not None:
168
+ core = core_openmp.EquSolver(n_cpu)
169
+ elif backend == "mpi" and core_mpi is not None:
170
+ core = core_mpi.EquSolver(min_interval)
171
+ rank = MPI.COMM_WORLD.Get_rank()
172
+ elif backend == "cuda" and core_cuda is not None:
173
+ core = core_cuda.EquSolver(block_size)
174
+ elif backend.startswith("taichi") and taichi_solver is not None:
175
+ core = taichi_solver.EquSolver(backend, n_cpu, block_size)
176
+
177
+ super().__init__(gradient, rank, backend, core)
178
+
179
+ def mask2index(
180
+ self, mask: np.ndarray
181
+ ) -> Tuple[np.ndarray, int, np.ndarray, np.ndarray]:
182
+ x, y = np.nonzero(mask)
183
+ max_id = x.shape[0] + 1
184
+ index = np.zeros((max_id, 3))
185
+ ids = self.core.partition(mask)
186
+ ids[mask == 0] = 0 # reserve id=0 for constant
187
+ index = ids[x, y].argsort()
188
+ return ids, max_id, x[index], y[index]
189
+
190
+ def reset(
191
+ self,
192
+ src: np.ndarray,
193
+ mask: np.ndarray,
194
+ tgt: np.ndarray,
195
+ mask_on_src: Tuple[int, int],
196
+ mask_on_tgt: Tuple[int, int],
197
+ ) -> int:
198
+ assert self.root
199
+ # check validity
200
+ # assert 0 <= mask_on_src[0] and 0 <= mask_on_src[1]
201
+ # assert mask_on_src[0] + mask.shape[0] <= src.shape[0]
202
+ # assert mask_on_src[1] + mask.shape[1] <= src.shape[1]
203
+ # assert mask_on_tgt[0] + mask.shape[0] <= tgt.shape[0]
204
+ # assert mask_on_tgt[1] + mask.shape[1] <= tgt.shape[1]
205
+
206
+ if len(mask.shape) == 3:
207
+ mask = mask.mean(-1)
208
+ mask = (mask >= 128).astype(np.int32)
209
+
210
+ # zero-out edge
211
+ mask[0] = 0
212
+ mask[-1] = 0
213
+ mask[:, 0] = 0
214
+ mask[:, -1] = 0
215
+
216
+ x, y = np.nonzero(mask)
217
+ x0, x1 = x.min() - 1, x.max() + 2
218
+ y0, y1 = y.min() - 1, y.max() + 2
219
+ mask_on_src = (x0 + mask_on_src[0], y0 + mask_on_src[1])
220
+ mask_on_tgt = (x0 + mask_on_tgt[0], y0 + mask_on_tgt[1])
221
+ mask = mask[x0:x1, y0:y1]
222
+ ids, max_id, index_x, index_y = self.mask2index(mask)
223
+
224
+ src_x, src_y = index_x + mask_on_src[0], index_y + mask_on_src[1]
225
+ tgt_x, tgt_y = index_x + mask_on_tgt[0], index_y + mask_on_tgt[1]
226
+
227
+ src_C = src[src_x, src_y].astype(np.float32)
228
+ src_U = src[src_x - 1, src_y].astype(np.float32)
229
+ src_D = src[src_x + 1, src_y].astype(np.float32)
230
+ src_L = src[src_x, src_y - 1].astype(np.float32)
231
+ src_R = src[src_x, src_y + 1].astype(np.float32)
232
+ tgt_C = tgt[tgt_x, tgt_y].astype(np.float32)
233
+ tgt_U = tgt[tgt_x - 1, tgt_y].astype(np.float32)
234
+ tgt_D = tgt[tgt_x + 1, tgt_y].astype(np.float32)
235
+ tgt_L = tgt[tgt_x, tgt_y - 1].astype(np.float32)
236
+ tgt_R = tgt[tgt_x, tgt_y + 1].astype(np.float32)
237
+
238
+ grad = self.mixgrad(src_C - src_L, tgt_C - tgt_L) \
239
+ + self.mixgrad(src_C - src_R, tgt_C - tgt_R) \
240
+ + self.mixgrad(src_C - src_U, tgt_C - tgt_U) \
241
+ + self.mixgrad(src_C - src_D, tgt_C - tgt_D)
242
+
243
+ A = np.zeros((max_id, 4), np.int32)
244
+ X = np.zeros((max_id, 3), np.float32)
245
+ B = np.zeros((max_id, 3), np.float32)
246
+
247
+ X[1:] = tgt[index_x + mask_on_tgt[0], index_y + mask_on_tgt[1]]
248
+ # four-way
249
+ A[1:, 0] = ids[index_x - 1, index_y]
250
+ A[1:, 1] = ids[index_x + 1, index_y]
251
+ A[1:, 2] = ids[index_x, index_y - 1]
252
+ A[1:, 3] = ids[index_x, index_y + 1]
253
+ B[1:] = grad
254
+ m = (mask[index_x - 1, index_y] == 0).astype(float).reshape(-1, 1)
255
+ B[1:] += m * tgt[index_x + mask_on_tgt[0] - 1, index_y + mask_on_tgt[1]]
256
+ m = (mask[index_x, index_y - 1] == 0).astype(float).reshape(-1, 1)
257
+ B[1:] += m * tgt[index_x + mask_on_tgt[0], index_y + mask_on_tgt[1] - 1]
258
+ m = (mask[index_x, index_y + 1] == 0).astype(float).reshape(-1, 1)
259
+ B[1:] += m * tgt[index_x + mask_on_tgt[0], index_y + mask_on_tgt[1] + 1]
260
+ m = (mask[index_x + 1, index_y] == 0).astype(float).reshape(-1, 1)
261
+ B[1:] += m * tgt[index_x + mask_on_tgt[0] + 1, index_y + mask_on_tgt[1]]
262
+
263
+ self.tgt = tgt.copy()
264
+ self.tgt_index = (index_x + mask_on_tgt[0], index_y + mask_on_tgt[1])
265
+ self.core.reset(max_id, A, X, B)
266
+ return max_id
267
+
268
+ def step(self, iteration: int) -> Optional[Tuple[np.ndarray, np.ndarray]]:
269
+ result = self.core.step(iteration)
270
+ if self.root:
271
+ x, err = result
272
+ self.tgt[self.tgt_index] = x[1:]
273
+ return self.tgt, err
274
+ return None
275
+
276
+
277
+ class GridProcessor(BaseProcessor):
278
+ """PIE grid processor."""
279
+
280
+ def __init__(
281
+ self,
282
+ gradient: str = "max",
283
+ backend: str = DEFAULT_BACKEND,
284
+ n_cpu: int = CPU_COUNT,
285
+ min_interval: int = 100,
286
+ block_size: int = 1024,
287
+ grid_x: int = 8,
288
+ grid_y: int = 8,
289
+ ):
290
+ core: Optional[Any] = None
291
+ rank = 0
292
+
293
+ if backend == "numpy":
294
+ core = np_solver.GridSolver()
295
+ elif backend == "numba" and numba_solver is not None:
296
+ core = numba_solver.GridSolver()
297
+ elif backend == "gcc":
298
+ core = core_gcc.GridSolver(grid_x, grid_y)
299
+ elif backend == "openmp" and core_openmp is not None:
300
+ core = core_openmp.GridSolver(grid_x, grid_y, n_cpu)
301
+ elif backend == "mpi" and core_mpi is not None:
302
+ core = core_mpi.GridSolver(min_interval)
303
+ rank = MPI.COMM_WORLD.Get_rank()
304
+ elif backend == "cuda" and core_cuda is not None:
305
+ core = core_cuda.GridSolver(grid_x, grid_y)
306
+ elif backend.startswith("taichi") and taichi_solver is not None:
307
+ core = taichi_solver.GridSolver(
308
+ grid_x, grid_y, backend, n_cpu, block_size
309
+ )
310
+
311
+ super().__init__(gradient, rank, backend, core)
312
+
313
+ def reset(
314
+ self,
315
+ src: np.ndarray,
316
+ mask: np.ndarray,
317
+ tgt: np.ndarray,
318
+ mask_on_src: Tuple[int, int],
319
+ mask_on_tgt: Tuple[int, int],
320
+ ) -> int:
321
+ assert self.root
322
+ # check validity
323
+ # assert 0 <= mask_on_src[0] and 0 <= mask_on_src[1]
324
+ # assert mask_on_src[0] + mask.shape[0] <= src.shape[0]
325
+ # assert mask_on_src[1] + mask.shape[1] <= src.shape[1]
326
+ # assert mask_on_tgt[0] + mask.shape[0] <= tgt.shape[0]
327
+ # assert mask_on_tgt[1] + mask.shape[1] <= tgt.shape[1]
328
+
329
+ if len(mask.shape) == 3:
330
+ mask = mask.mean(-1)
331
+ mask = (mask >= 128).astype(np.int32)
332
+
333
+ # zero-out edge
334
+ mask[0] = 0
335
+ mask[-1] = 0
336
+ mask[:, 0] = 0
337
+ mask[:, -1] = 0
338
+
339
+ x, y = np.nonzero(mask)
340
+ x0, x1 = x.min() - 1, x.max() + 2
341
+ y0, y1 = y.min() - 1, y.max() + 2
342
+ mask = mask[x0:x1, y0:y1]
343
+ max_id = np.prod(mask.shape)
344
+
345
+ src_crop = src[mask_on_src[0] + x0:mask_on_src[0] + x1,
346
+ mask_on_src[1] + y0:mask_on_src[1] + y1].astype(np.float32)
347
+ tgt_crop = tgt[mask_on_tgt[0] + x0:mask_on_tgt[0] + x1,
348
+ mask_on_tgt[1] + y0:mask_on_tgt[1] + y1].astype(np.float32)
349
+ grad = np.zeros([*mask.shape, 3], np.float32)
350
+ grad[1:] += self.mixgrad(
351
+ src_crop[1:] - src_crop[:-1], tgt_crop[1:] - tgt_crop[:-1]
352
+ )
353
+ grad[:-1] += self.mixgrad(
354
+ src_crop[:-1] - src_crop[1:], tgt_crop[:-1] - tgt_crop[1:]
355
+ )
356
+ grad[:, 1:] += self.mixgrad(
357
+ src_crop[:, 1:] - src_crop[:, :-1], tgt_crop[:, 1:] - tgt_crop[:, :-1]
358
+ )
359
+ grad[:, :-1] += self.mixgrad(
360
+ src_crop[:, :-1] - src_crop[:, 1:], tgt_crop[:, :-1] - tgt_crop[:, 1:]
361
+ )
362
+
363
+ grad[mask == 0] = 0
364
+ if True:
365
+ kernel = [[1] * 3 for _ in range(3)]
366
+ nmask = mask.copy()
367
+ nmask[nmask > 0] = 1
368
+ res = scipy.signal.convolve2d(
369
+ nmask, kernel, mode="same", boundary="fill", fillvalue=1
370
+ )
371
+ res[nmask < 1] = 0
372
+ res[res == 9] = 0
373
+ res[res > 0] = 1
374
+ grad[res>0]=0
375
+ # ylst, xlst = res.nonzero()
376
+ # for y, x in zip(ylst, xlst):
377
+ # grad[y,x]=0
378
+ # for yi in range(-1,2):
379
+ # for xi in range(-1,2):
380
+ # grad[y+yi,x+xi]=0
381
+ self.x0 = mask_on_tgt[0] + x0
382
+ self.x1 = mask_on_tgt[0] + x1
383
+ self.y0 = mask_on_tgt[1] + y0
384
+ self.y1 = mask_on_tgt[1] + y1
385
+ self.tgt = tgt.copy()
386
+ self.core.reset(max_id, mask, tgt_crop, grad)
387
+ return max_id
388
+
389
+ def step(self, iteration: int) -> Optional[Tuple[np.ndarray, np.ndarray]]:
390
+ result = self.core.step(iteration)
391
+ if self.root:
392
+ tgt, err = result
393
+ self.tgt[self.x0:self.x1, self.y0:self.y1] = tgt
394
+ return self.tgt, err
395
+ return None
requirements.txt ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ --extra-index-url https://download.pytorch.org/whl/nightly/cu117
2
+ imageio==2.19.5
3
+ imageio-ffmpeg==0.4.7
4
+ numpy==1.22.4
5
+ opencv-python-headless==4.6.0.66
6
+ torch[dynamo]>=0.0.dev0
7
+ torchvision
8
+ Pillow
9
+ scipy
10
+ scikit-image
11
+ diffusers==0.9.0
12
+ transformers
13
+ ftfy
14
+ fpie
15
+ accelerate
16
+ ninja
17
+ setuptools==59.8.0
utils.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ from PIL import ImageFilter
3
+ import cv2
4
+ import numpy as np
5
+ import scipy
6
+ import scipy.signal
7
+ from scipy.spatial import cKDTree
8
+
9
+ import os
10
+ from perlin2d import *
11
+
12
+ patch_match_compiled = True
13
+
14
+ try:
15
+ from PyPatchMatch import patch_match
16
+ except Exception as e:
17
+ try:
18
+ import patch_match
19
+ except Exception as e:
20
+ patch_match_compiled = False
21
+
22
+ try:
23
+ patch_match
24
+ except NameError:
25
+ print("patch_match compiling failed, will fall back to edge_pad")
26
+ patch_match_compiled = False
27
+
28
+
29
+
30
+
31
+ def edge_pad(img, mask, mode=1):
32
+ if mode == 0:
33
+ nmask = mask.copy()
34
+ nmask[nmask > 0] = 1
35
+ res0 = 1 - nmask
36
+ res1 = nmask
37
+ p0 = np.stack(res0.nonzero(), axis=0).transpose()
38
+ p1 = np.stack(res1.nonzero(), axis=0).transpose()
39
+ min_dists, min_dist_idx = cKDTree(p1).query(p0, 1)
40
+ loc = p1[min_dist_idx]
41
+ for (a, b), (c, d) in zip(p0, loc):
42
+ img[a, b] = img[c, d]
43
+ elif mode == 1:
44
+ record = {}
45
+ kernel = [[1] * 3 for _ in range(3)]
46
+ nmask = mask.copy()
47
+ nmask[nmask > 0] = 1
48
+ res = scipy.signal.convolve2d(
49
+ nmask, kernel, mode="same", boundary="fill", fillvalue=1
50
+ )
51
+ res[nmask < 1] = 0
52
+ res[res == 9] = 0
53
+ res[res > 0] = 1
54
+ ylst, xlst = res.nonzero()
55
+ queue = [(y, x) for y, x in zip(ylst, xlst)]
56
+ # bfs here
57
+ cnt = res.astype(np.float32)
58
+ acc = img.astype(np.float32)
59
+ step = 1
60
+ h = acc.shape[0]
61
+ w = acc.shape[1]
62
+ offset = [(1, 0), (-1, 0), (0, 1), (0, -1)]
63
+ while queue:
64
+ target = []
65
+ for y, x in queue:
66
+ val = acc[y][x]
67
+ for yo, xo in offset:
68
+ yn = y + yo
69
+ xn = x + xo
70
+ if 0 <= yn < h and 0 <= xn < w and nmask[yn][xn] < 1:
71
+ if record.get((yn, xn), step) == step:
72
+ acc[yn][xn] = acc[yn][xn] * cnt[yn][xn] + val
73
+ cnt[yn][xn] += 1
74
+ acc[yn][xn] /= cnt[yn][xn]
75
+ if (yn, xn) not in record:
76
+ record[(yn, xn)] = step
77
+ target.append((yn, xn))
78
+ step += 1
79
+ queue = target
80
+ img = acc.astype(np.uint8)
81
+ else:
82
+ nmask = mask.copy()
83
+ ylst, xlst = nmask.nonzero()
84
+ yt, xt = ylst.min(), xlst.min()
85
+ yb, xb = ylst.max(), xlst.max()
86
+ content = img[yt : yb + 1, xt : xb + 1]
87
+ img = np.pad(
88
+ content,
89
+ ((yt, mask.shape[0] - yb - 1), (xt, mask.shape[1] - xb - 1), (0, 0)),
90
+ mode="edge",
91
+ )
92
+ return img, mask
93
+
94
+
95
+ def perlin_noise(img, mask):
96
+ lin = np.linspace(0, 5, mask.shape[0], endpoint=False)
97
+ x, y = np.meshgrid(lin, lin)
98
+ avg = img.mean(axis=0).mean(axis=0)
99
+ # noise=[((perlin(x, y)+1)*128+avg[i]).astype(np.uint8) for i in range(3)]
100
+ noise = [((perlin(x, y) + 1) * 0.5 * 255).astype(np.uint8) for i in range(3)]
101
+ noise = np.stack(noise, axis=-1)
102
+ # mask=skimage.measure.block_reduce(mask,(8,8),np.min)
103
+ # mask=mask.repeat(8, axis=0).repeat(8, axis=1)
104
+ # mask_image=Image.fromarray(mask)
105
+ # mask_image=mask_image.filter(ImageFilter.GaussianBlur(radius = 4))
106
+ # mask=np.array(mask_image)
107
+ nmask = mask.copy()
108
+ # nmask=nmask/255.0
109
+ nmask[mask > 0] = 1
110
+ img = nmask[:, :, np.newaxis] * img + (1 - nmask[:, :, np.newaxis]) * noise
111
+ # img=img.astype(np.uint8)
112
+ return img, mask
113
+
114
+
115
+ def gaussian_noise(img, mask):
116
+ noise = np.random.randn(mask.shape[0], mask.shape[1], 3)
117
+ noise = (noise + 1) / 2 * 255
118
+ noise = noise.astype(np.uint8)
119
+ nmask = mask.copy()
120
+ nmask[mask > 0] = 1
121
+ img = nmask[:, :, np.newaxis] * img + (1 - nmask[:, :, np.newaxis]) * noise
122
+ return img, mask
123
+
124
+
125
+ def cv2_telea(img, mask):
126
+ ret = cv2.inpaint(img, 255 - mask, 5, cv2.INPAINT_TELEA)
127
+ return ret, mask
128
+
129
+
130
+ def cv2_ns(img, mask):
131
+ ret = cv2.inpaint(img, 255 - mask, 5, cv2.INPAINT_NS)
132
+ return ret, mask
133
+
134
+
135
+ def patch_match_func(img, mask):
136
+ ret = patch_match.inpaint(img, mask=255 - mask, patch_size=3)
137
+ return ret, mask
138
+
139
+
140
+ def mean_fill(img, mask):
141
+ avg = img.mean(axis=0).mean(axis=0)
142
+ img[mask < 1] = avg
143
+ return img, mask
144
+
145
+ def g_diffuser(img,mask):
146
+ return img, mask
147
+
148
+ def dummy_fill(img,mask):
149
+ return img,mask
150
+ functbl = {
151
+ "gaussian": gaussian_noise,
152
+ "perlin": perlin_noise,
153
+ "edge_pad": edge_pad,
154
+ "patchmatch": patch_match_func if patch_match_compiled else edge_pad,
155
+ "cv2_ns": cv2_ns,
156
+ "cv2_telea": cv2_telea,
157
+ "g_diffuser": g_diffuser,
158
+ "g_diffuser_lib": dummy_fill,
159
+ }
160
+
161
+ try:
162
+ from postprocess import PhotometricCorrection
163
+ correction_func = PhotometricCorrection()
164
+ except Exception as e:
165
+ print(e, "so PhotometricCorrection is disabled")
166
+ class DummyCorrection:
167
+ def __init__(self):
168
+ self.backend=""
169
+ pass
170
+ def run(self,a,b,**kwargs):
171
+ return b
172
+ correction_func=DummyCorrection()
173
+
174
+ if "taichi" in correction_func.backend:
175
+ import sys
176
+ import io
177
+ import base64
178
+ from PIL import Image
179
+ def base64_to_pil(base64_str):
180
+ data = base64.b64decode(str(base64_str))
181
+ pil = Image.open(io.BytesIO(data))
182
+ return pil
183
+
184
+ def pil_to_base64(out_pil):
185
+ out_buffer = io.BytesIO()
186
+ out_pil.save(out_buffer, format="PNG")
187
+ out_buffer.seek(0)
188
+ base64_bytes = base64.b64encode(out_buffer.read())
189
+ base64_str = base64_bytes.decode("ascii")
190
+ return base64_str
191
+ from subprocess import Popen, PIPE, STDOUT
192
+ class SubprocessCorrection:
193
+ def __init__(self):
194
+ self.backend=correction_func.backend
195
+ self.child= Popen(["python", "postprocess.py"], stdin=PIPE, stdout=PIPE, stderr=STDOUT)
196
+ def run(self,img_input,img_inpainted,mode):
197
+ if mode=="disabled":
198
+ return img_inpainted
199
+ base64_str_input = pil_to_base64(img_input)
200
+ base64_str_inpainted = pil_to_base64(img_inpainted)
201
+ try:
202
+ if self.child.poll():
203
+ self.child= Popen(["python", "postprocess.py"], stdin=PIPE, stdout=PIPE, stderr=STDOUT)
204
+ self.child.stdin.write(f"{base64_str_input},{base64_str_inpainted},{mode}\n".encode())
205
+ self.child.stdin.flush()
206
+ out = self.child.stdout.readline()
207
+ base64_str=out.decode().strip()
208
+ while base64_str and base64_str[0]=="[":
209
+ print(base64_str)
210
+ out = self.child.stdout.readline()
211
+ base64_str=out.decode().strip()
212
+ ret=base64_to_pil(base64_str)
213
+ except:
214
+ print("[PIE] not working, photometric correction is disabled")
215
+ ret=img_inpainted
216
+ return ret
217
+ correction_func = SubprocessCorrection()