Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -31,6 +31,7 @@ import glob
|
|
| 31 |
import pathlib
|
| 32 |
from functools import partial
|
| 33 |
from pprint import pprint
|
|
|
|
| 34 |
|
| 35 |
import numpy as np
|
| 36 |
from PIL import Image
|
|
@@ -150,6 +151,11 @@ opt.excluded_keys = ['inpainting_mode', 'is_running', 'active_palettes', 'curren
|
|
| 150 |
opt.prep_time = 20
|
| 151 |
|
| 152 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 153 |
### Event handlers
|
| 154 |
|
| 155 |
def add_palette(state):
|
|
@@ -385,6 +391,11 @@ def register(state, drawpad, model):
|
|
| 385 |
|
| 386 |
@spaces.GPU(duration=(opt.prep_time + opt.run_time + 5))
|
| 387 |
def run(state, drawpad):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 388 |
model.device = torch.device('cuda')
|
| 389 |
model.reset_seed(model.generator, opt.seed)
|
| 390 |
model.reset_latent()
|
|
@@ -395,6 +406,17 @@ def run(state, drawpad):
|
|
| 395 |
|
| 396 |
tic = time.time()
|
| 397 |
while True:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 398 |
yield [state, model()]
|
| 399 |
toc = time.time()
|
| 400 |
tdelta = toc - tic
|
|
@@ -412,12 +434,14 @@ def show_element():
|
|
| 412 |
return gr.update(visible=True)
|
| 413 |
|
| 414 |
|
| 415 |
-
@spaces.GPU
|
| 416 |
def draw(state, drawpad):
|
| 417 |
if not state.is_running:
|
| 418 |
print('[WARNING] Streaming is currently off, update ignored.')
|
| 419 |
return
|
| 420 |
|
|
|
|
|
|
|
|
|
|
| 421 |
user_input = np.asarray(drawpad['layers'][0]) # (H, W, 4)
|
| 422 |
foreground_mask = torch.tensor(user_input[..., -1])[None, None] # (1, 1, H, W)
|
| 423 |
user_input = torch.tensor(user_input[..., :-1]) # (H, W, 3)
|
|
@@ -441,13 +465,15 @@ def draw(state, drawpad):
|
|
| 441 |
# mask_strengths = [1] + [state.mask_strengths[v] for v in has_masks]
|
| 442 |
# mask_stds = [0] + [state.mask_stds[v] for v in has_masks]
|
| 443 |
|
| 444 |
-
for i in range(len(has_masks)):
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
|
|
|
|
|
|
| 451 |
|
| 452 |
### Load examples
|
| 453 |
|
|
|
|
| 31 |
import pathlib
|
| 32 |
from functools import partial
|
| 33 |
from pprint import pprint
|
| 34 |
+
from multiprocessing.connection import Client, Listener
|
| 35 |
|
| 36 |
import numpy as np
|
| 37 |
from PIL import Image
|
|
|
|
| 151 |
opt.prep_time = 20
|
| 152 |
|
| 153 |
|
| 154 |
+
### Shared memory hack for ZeroGPU
|
| 155 |
+
opt.address = ('localhost', 6000)
|
| 156 |
+
= b'secret password'
|
| 157 |
+
|
| 158 |
+
|
| 159 |
### Event handlers
|
| 160 |
|
| 161 |
def add_palette(state):
|
|
|
|
| 391 |
|
| 392 |
@spaces.GPU(duration=(opt.prep_time + opt.run_time + 5))
|
| 393 |
def run(state, drawpad):
|
| 394 |
+
# ZeroGPU hack.
|
| 395 |
+
listener = Listener(opt.address, authkey=opt.authkey)
|
| 396 |
+
conn = listener.accept()
|
| 397 |
+
|
| 398 |
+
# Reset model.
|
| 399 |
model.device = torch.device('cuda')
|
| 400 |
model.reset_seed(model.generator, opt.seed)
|
| 401 |
model.reset_latent()
|
|
|
|
| 406 |
|
| 407 |
tic = time.time()
|
| 408 |
while True:
|
| 409 |
+
# Receive real-time mask inputs from the main process.
|
| 410 |
+
msg = conn.recv()
|
| 411 |
+
print(msg + ' Received!!!')
|
| 412 |
+
# for i in range(opt.max_palettes):
|
| 413 |
+
# model.update_single_layer(
|
| 414 |
+
# idx=i,
|
| 415 |
+
# mask=masks[i],
|
| 416 |
+
# mask_strength=mask_strengths[i],
|
| 417 |
+
# mask_std=mask_stds[i],
|
| 418 |
+
# )
|
| 419 |
+
|
| 420 |
yield [state, model()]
|
| 421 |
toc = time.time()
|
| 422 |
tdelta = toc - tic
|
|
|
|
| 434 |
return gr.update(visible=True)
|
| 435 |
|
| 436 |
|
|
|
|
| 437 |
def draw(state, drawpad):
|
| 438 |
if not state.is_running:
|
| 439 |
print('[WARNING] Streaming is currently off, update ignored.')
|
| 440 |
return
|
| 441 |
|
| 442 |
+
# ZeroGPU hack.
|
| 443 |
+
conn = Client(opt.address, authkey=opt.authkey)
|
| 444 |
+
|
| 445 |
user_input = np.asarray(drawpad['layers'][0]) # (H, W, 4)
|
| 446 |
foreground_mask = torch.tensor(user_input[..., -1])[None, None] # (1, 1, H, W)
|
| 447 |
user_input = torch.tensor(user_input[..., :-1]) # (H, W, 3)
|
|
|
|
| 465 |
# mask_strengths = [1] + [state.mask_strengths[v] for v in has_masks]
|
| 466 |
# mask_stds = [0] + [state.mask_stds[v] for v in has_masks]
|
| 467 |
|
| 468 |
+
# for i in range(len(has_masks)):
|
| 469 |
+
# model.update_single_layer(
|
| 470 |
+
# idx=i,
|
| 471 |
+
# mask=masks[i],
|
| 472 |
+
# mask_strength=mask_strengths[i],
|
| 473 |
+
# mask_std=mask_stds[i],
|
| 474 |
+
# )
|
| 475 |
+
conn.send('Hello!!!!')
|
| 476 |
+
conn.close()
|
| 477 |
|
| 478 |
### Load examples
|
| 479 |
|