Hugging Face Spaces — commit/diff page (Space status: Runtime error).
Commit message: "fix session state issue"
Files changed:
- app.py (+22 −12)
- requirements.txt (+1 −1)
app.py
CHANGED
|
@@ -119,7 +119,7 @@ class ImageComp:
|
|
| 119 |
self.baseoutput = output.astype(np.uint8)
|
| 120 |
return self.baseoutput
|
| 121 |
|
| 122 |
-
def
|
| 123 |
panoptic_mask_ = panoptic_mask + 1
|
| 124 |
mask_ = resize_image(mask['mask'][:, :, 0], min(panoptic_mask.shape))
|
| 125 |
mask_ = torch.tensor(mask_)
|
|
@@ -137,29 +137,29 @@ class ImageComp:
|
|
| 137 |
return final_mask, obj_class
|
| 138 |
|
| 139 |
|
| 140 |
-
def
|
| 141 |
input_pmask = self.input_pmask
|
| 142 |
input_segmask = self.input_segmask
|
| 143 |
|
| 144 |
if whole_ref:
|
| 145 |
reference_mask = torch.ones(self.ref_pmask.shape).cuda()
|
| 146 |
else:
|
| 147 |
-
reference_mask, _ = self.
|
| 148 |
|
| 149 |
-
edit_mask, _ = self.
|
| 150 |
ma = torch.max(input_pmask)
|
| 151 |
input_pmask[edit_mask == 1] = ma + 1
|
| 152 |
return reference_mask, input_pmask, input_segmask, edit_mask, ma
|
| 153 |
|
| 154 |
|
| 155 |
-
def
|
| 156 |
input_img = (self.input_img/127.5 - 1)
|
| 157 |
input_img = torch.from_numpy(input_img.astype(np.float32)).cuda().unsqueeze(0).permute(0,3,1,2)
|
| 158 |
|
| 159 |
reference_img = (self.ref_img/127.5 - 1)
|
| 160 |
reference_img = torch.from_numpy(reference_img.astype(np.float32)).cuda().unsqueeze(0).permute(0,3,1,2)
|
| 161 |
|
| 162 |
-
reference_mask, input_pmask, input_segmask, region_mask, ma = self.
|
| 163 |
|
| 164 |
input_pmask = input_pmask.float().cuda().unsqueeze(0).unsqueeze(1)
|
| 165 |
_, mean_feat_inpt, one_hot_inpt, empty_mask_flag_inpt = model.get_appearance(input_img, input_pmask, return_all=True)
|
|
@@ -182,7 +182,7 @@ class ImageComp:
|
|
| 182 |
def process(self, input_mask, ref_mask, prompt, a_prompt, n_prompt,
|
| 183 |
num_samples, ddim_steps, guess_mode, strength,
|
| 184 |
scale_s, scale_f, scale_t, seed, eta, masking=True,whole_ref=False,inter=1):
|
| 185 |
-
structure, appearance, mask, img = self.
|
| 186 |
whole_ref=whole_ref, inter=inter)
|
| 187 |
|
| 188 |
null_structure = torch.zeros(structure.shape).cuda() - 1
|
|
@@ -242,6 +242,17 @@ class ImageComp:
|
|
| 242 |
return [] + results
|
| 243 |
|
| 244 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 245 |
css = """
|
| 246 |
h1 {
|
| 247 |
text-align: center;
|
|
@@ -293,14 +304,14 @@ def create_app_demo():
|
|
| 293 |
""")
|
| 294 |
with gr.Column():
|
| 295 |
with gr.Row():
|
| 296 |
-
img_edit = ImageComp('edit_app')
|
| 297 |
with gr.Column():
|
| 298 |
btn1 = gr.Button("Input Image")
|
| 299 |
input_image = gr.Image(source='upload', label='Input Image', type="numpy",)
|
| 300 |
with gr.Column():
|
| 301 |
btn2 = gr.Button("Select Object to Edit")
|
| 302 |
input_mask = gr.Image(source="upload", label='Select Object in Input Image', type="numpy", tool="sketch")
|
| 303 |
-
input_image.change(fn=
|
| 304 |
|
| 305 |
# with gr.Row():
|
| 306 |
with gr.Column():
|
|
@@ -310,7 +321,7 @@ def create_app_demo():
|
|
| 310 |
btn4 = gr.Button("Select Reference Object")
|
| 311 |
reference_mask = gr.Image(source="upload", label='Select Object in Refernce Image', type="numpy", tool="sketch")
|
| 312 |
|
| 313 |
-
ref_img.change(fn=
|
| 314 |
|
| 315 |
with gr.Row():
|
| 316 |
prompt = gr.Textbox(label="Prompt", value='A picture of truck')
|
|
@@ -325,7 +336,6 @@ def create_app_demo():
|
|
| 325 |
|
| 326 |
with gr.Accordion("Advanced options", open=False):
|
| 327 |
num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
|
| 328 |
-
image_resolution = gr.Slider(label="Image Resolution", minimum=512, maximum=512, value=512, step=64)
|
| 329 |
strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
|
| 330 |
guess_mode = gr.Checkbox(label='Guess Mode', value=False)
|
| 331 |
ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
|
|
@@ -351,7 +361,7 @@ def create_app_demo():
|
|
| 351 |
)
|
| 352 |
ips = [input_mask, reference_mask, prompt, a_prompt, n_prompt, num_samples, ddim_steps, guess_mode, strength,
|
| 353 |
scale_s, scale_f, scale_t, seed, eta, masking, whole_ref, interpolation]
|
| 354 |
-
run_button.click(fn=
|
| 355 |
|
| 356 |
|
| 357 |
|
|
|
|
| 119 |
self.baseoutput = output.astype(np.uint8)
|
| 120 |
return self.baseoutput
|
| 121 |
|
| 122 |
+
def _process_mask(self, mask, panoptic_mask, segmask):
|
| 123 |
panoptic_mask_ = panoptic_mask + 1
|
| 124 |
mask_ = resize_image(mask['mask'][:, :, 0], min(panoptic_mask.shape))
|
| 125 |
mask_ = torch.tensor(mask_)
|
|
|
|
| 137 |
return final_mask, obj_class
|
| 138 |
|
| 139 |
|
| 140 |
+
def _edit_app(self, input_mask, ref_mask, whole_ref):
|
| 141 |
input_pmask = self.input_pmask
|
| 142 |
input_segmask = self.input_segmask
|
| 143 |
|
| 144 |
if whole_ref:
|
| 145 |
reference_mask = torch.ones(self.ref_pmask.shape).cuda()
|
| 146 |
else:
|
| 147 |
+
reference_mask, _ = self._process_mask(ref_mask, self.ref_pmask, self.ref_segmask)
|
| 148 |
|
| 149 |
+
edit_mask, _ = self._process_mask(input_mask, self.input_pmask, self.input_segmask)
|
| 150 |
ma = torch.max(input_pmask)
|
| 151 |
input_pmask[edit_mask == 1] = ma + 1
|
| 152 |
return reference_mask, input_pmask, input_segmask, edit_mask, ma
|
| 153 |
|
| 154 |
|
| 155 |
+
def _edit(self, input_mask, ref_mask, whole_ref=False, inter=1):
|
| 156 |
input_img = (self.input_img/127.5 - 1)
|
| 157 |
input_img = torch.from_numpy(input_img.astype(np.float32)).cuda().unsqueeze(0).permute(0,3,1,2)
|
| 158 |
|
| 159 |
reference_img = (self.ref_img/127.5 - 1)
|
| 160 |
reference_img = torch.from_numpy(reference_img.astype(np.float32)).cuda().unsqueeze(0).permute(0,3,1,2)
|
| 161 |
|
| 162 |
+
reference_mask, input_pmask, input_segmask, region_mask, ma = self._edit_app(input_mask, ref_mask, whole_ref)
|
| 163 |
|
| 164 |
input_pmask = input_pmask.float().cuda().unsqueeze(0).unsqueeze(1)
|
| 165 |
_, mean_feat_inpt, one_hot_inpt, empty_mask_flag_inpt = model.get_appearance(input_img, input_pmask, return_all=True)
|
|
|
|
| 182 |
def process(self, input_mask, ref_mask, prompt, a_prompt, n_prompt,
|
| 183 |
num_samples, ddim_steps, guess_mode, strength,
|
| 184 |
scale_s, scale_f, scale_t, seed, eta, masking=True,whole_ref=False,inter=1):
|
| 185 |
+
structure, appearance, mask, img = self._edit(input_mask, ref_mask,
|
| 186 |
whole_ref=whole_ref, inter=inter)
|
| 187 |
|
| 188 |
null_structure = torch.zeros(structure.shape).cuda() - 1
|
|
|
|
| 242 |
return [] + results
|
| 243 |
|
| 244 |
|
| 245 |
+
def init_input_canvas_wrapper(obj, *args):
|
| 246 |
+
return obj.init_input_canvas(*args)
|
| 247 |
+
|
| 248 |
+
def init_ref_canvas_wrapper(obj, *args):
|
| 249 |
+
return obj.init_ref_canvas(*args)
|
| 250 |
+
|
| 251 |
+
def process_wrapper(obj, *args):
|
| 252 |
+
return obj.process(*args)
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
|
| 256 |
css = """
|
| 257 |
h1 {
|
| 258 |
text-align: center;
|
|
|
|
| 304 |
""")
|
| 305 |
with gr.Column():
|
| 306 |
with gr.Row():
|
| 307 |
+
img_edit = gr.State(ImageComp('edit_app'))
|
| 308 |
with gr.Column():
|
| 309 |
btn1 = gr.Button("Input Image")
|
| 310 |
input_image = gr.Image(source='upload', label='Input Image', type="numpy",)
|
| 311 |
with gr.Column():
|
| 312 |
btn2 = gr.Button("Select Object to Edit")
|
| 313 |
input_mask = gr.Image(source="upload", label='Select Object in Input Image', type="numpy", tool="sketch")
|
| 314 |
+
input_image.change(fn=init_input_canvas_wrapper, inputs=[img_edit, input_image], outputs=[input_mask], queue=False)
|
| 315 |
|
| 316 |
# with gr.Row():
|
| 317 |
with gr.Column():
|
|
|
|
| 321 |
btn4 = gr.Button("Select Reference Object")
|
| 322 |
reference_mask = gr.Image(source="upload", label='Select Object in Refernce Image', type="numpy", tool="sketch")
|
| 323 |
|
| 324 |
+
ref_img.change(fn=init_ref_canvas_wrapper, inputs=[img_edit, ref_img], outputs=[reference_mask], queue=False)
|
| 325 |
|
| 326 |
with gr.Row():
|
| 327 |
prompt = gr.Textbox(label="Prompt", value='A picture of truck')
|
|
|
|
| 336 |
|
| 337 |
with gr.Accordion("Advanced options", open=False):
|
| 338 |
num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
|
|
|
|
| 339 |
strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
|
| 340 |
guess_mode = gr.Checkbox(label='Guess Mode', value=False)
|
| 341 |
ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
|
|
|
|
| 361 |
)
|
| 362 |
ips = [input_mask, reference_mask, prompt, a_prompt, n_prompt, num_samples, ddim_steps, guess_mode, strength,
|
| 363 |
scale_s, scale_f, scale_t, seed, eta, masking, whole_ref, interpolation]
|
| 364 |
+
run_button.click(fn=process_wrapper, inputs=[img_edit, *ips], outputs=[result_gallery])
|
| 365 |
|
| 366 |
|
| 367 |
|
requirements.txt
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
addict==2.4.0
|
| 2 |
albumentations==1.3.0
|
| 3 |
einops==0.3.0
|
| 4 |
-
gradio==3.
|
| 5 |
imageio==2.9.0
|
| 6 |
imageio-ffmpeg==0.4.2
|
| 7 |
kornia==0.6.0
|
|
|
|
| 1 |
addict==2.4.0
|
| 2 |
albumentations==1.3.0
|
| 3 |
einops==0.3.0
|
| 4 |
+
gradio==3.25.0
|
| 5 |
imageio==2.9.0
|
| 6 |
imageio-ffmpeg==0.4.2
|
| 7 |
kornia==0.6.0
|