Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -423,7 +423,11 @@ def register(state, drawpad, model):
|
|
| 423 |
seed_everything(state.seed if state.seed >=0 else np.random.randint(2147483647))
|
| 424 |
print('Generate!')
|
| 425 |
|
| 426 |
-
background = drawpad['background']
|
|
|
|
|
|
|
|
|
|
|
|
|
| 427 |
inpainting_mode = np.asarray(background).sum() != 0
|
| 428 |
if not inpainting_mode:
|
| 429 |
background = Image.new(size=(opt.width, opt.height), mode='RGB', color=(255, 255, 255))
|
|
@@ -432,9 +436,14 @@ def register(state, drawpad, model):
|
|
| 432 |
background_prompt = None
|
| 433 |
print('Inpainting mode: ', inpainting_mode)
|
| 434 |
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 438 |
|
| 439 |
palette = torch.tensor([
|
| 440 |
tuple(int(s[i+1:i+3], 16) for i in (0, 2, 4))
|
|
|
|
| 423 |
seed_everything(state.seed if state.seed >=0 else np.random.randint(2147483647))
|
| 424 |
print('Generate!')
|
| 425 |
|
| 426 |
+
background = drawpad['background']
|
| 427 |
+
if background is None:
|
| 428 |
+
background = Image.new(size=(opt.width, opt.height), mode='RGB', color=(255, 255, 255))
|
| 429 |
+
else:
|
| 430 |
+
background = background.convert('RGBA')
|
| 431 |
inpainting_mode = np.asarray(background).sum() != 0
|
| 432 |
if not inpainting_mode:
|
| 433 |
background = Image.new(size=(opt.width, opt.height), mode='RGB', color=(255, 255, 255))
|
|
|
|
| 436 |
background_prompt = None
|
| 437 |
print('Inpainting mode: ', inpainting_mode)
|
| 438 |
|
| 439 |
+
if drawpad['composite'] is None:
|
| 440 |
+
user_input = np.zeros((opt.height, opt.width, 4))
|
| 441 |
+
foreground_mask = torch.zeros((1, 1, opt.height, opt.width))
|
| 442 |
+
user_input = torch.tensor(user_input[..., :-1]) # (H, W, 3)
|
| 443 |
+
else:
|
| 444 |
+
user_input = np.asarray(drawpad['composite']) # (H, W, 4)
|
| 445 |
+
foreground_mask = torch.tensor(user_input[..., -1])[None, None] # (1, 1, H, W)
|
| 446 |
+
user_input = torch.tensor(user_input[..., :-1]) # (H, W, 3)
|
| 447 |
|
| 448 |
palette = torch.tensor([
|
| 449 |
tuple(int(s[i+1:i+3], 16) for i in (0, 2, 4))
|