Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -389,11 +389,11 @@ def register(state, drawpad, model):
|
|
| 389 |
return state
|
| 390 |
|
| 391 |
|
| 392 |
-
@spaces.GPU(duration=(opt.prep_time + opt.run_time + 5))
|
| 393 |
def run(state, drawpad):
|
| 394 |
# ZeroGPU hack.
|
| 395 |
-
listener = Listener(opt.address, authkey=opt.authkey)
|
| 396 |
-
conn = listener.accept()
|
| 397 |
|
| 398 |
# Reset model.
|
| 399 |
model.device = torch.device('cuda')
|
|
@@ -407,16 +407,16 @@ def run(state, drawpad):
|
|
| 407 |
tic = time.time()
|
| 408 |
while True:
|
| 409 |
# Receive real-time mask inputs from the main process.
|
| 410 |
-
data = conn.recv()
|
| 411 |
-
if data is not None:
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
|
| 420 |
|
| 421 |
yield [state, model()]
|
| 422 |
toc = time.time()
|
|
@@ -441,7 +441,7 @@ def draw(state, drawpad):
|
|
| 441 |
# return
|
| 442 |
|
| 443 |
# ZeroGPU hack.
|
| 444 |
-
conn = Client(opt.address, authkey=opt.authkey)
|
| 445 |
|
| 446 |
user_input = np.asarray(drawpad['layers'][0]) # (H, W, 4)
|
| 447 |
foreground_mask = torch.tensor(user_input[..., -1])[None, None] # (1, 1, H, W)
|
|
@@ -466,20 +466,20 @@ def draw(state, drawpad):
|
|
| 466 |
# mask_strengths = [1] + [state.mask_strengths[v] for v in has_masks]
|
| 467 |
# mask_stds = [0] + [state.mask_stds[v] for v in has_masks]
|
| 468 |
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
|
| 472 |
-
|
| 473 |
-
|
| 474 |
-
|
| 475 |
-
|
| 476 |
-
data = dict(
|
| 477 |
-
|
| 478 |
-
|
| 479 |
-
|
| 480 |
-
)
|
| 481 |
-
conn.send(data)
|
| 482 |
-
conn.close()
|
| 483 |
|
| 484 |
### Load examples
|
| 485 |
|
|
|
|
| 389 |
return state
|
| 390 |
|
| 391 |
|
| 392 |
+
# @spaces.GPU(duration=(opt.prep_time + opt.run_time + 5))
|
| 393 |
def run(state, drawpad):
|
| 394 |
# ZeroGPU hack.
|
| 395 |
+
# listener = Listener(opt.address, authkey=opt.authkey)
|
| 396 |
+
# conn = listener.accept()
|
| 397 |
|
| 398 |
# Reset model.
|
| 399 |
model.device = torch.device('cuda')
|
|
|
|
| 407 |
tic = time.time()
|
| 408 |
while True:
|
| 409 |
# Receive real-time mask inputs from the main process.
|
| 410 |
+
# data = conn.recv()
|
| 411 |
+
# if data is not None:
|
| 412 |
+
# print('Received data!!!')
|
| 413 |
+
# for i in range(opt.max_palettes):
|
| 414 |
+
# model.update_single_layer(
|
| 415 |
+
# idx=i,
|
| 416 |
+
# mask=data['masks'][i],
|
| 417 |
+
# mask_strength=data['mask_strengths'][i],
|
| 418 |
+
# mask_std=data['mask_stds'][i],
|
| 419 |
+
# )
|
| 420 |
|
| 421 |
yield [state, model()]
|
| 422 |
toc = time.time()
|
|
|
|
| 441 |
# return
|
| 442 |
|
| 443 |
# ZeroGPU hack.
|
| 444 |
+
# conn = Client(opt.address, authkey=opt.authkey)
|
| 445 |
|
| 446 |
user_input = np.asarray(drawpad['layers'][0]) # (H, W, 4)
|
| 447 |
foreground_mask = torch.tensor(user_input[..., -1])[None, None] # (1, 1, H, W)
|
|
|
|
| 466 |
# mask_strengths = [1] + [state.mask_strengths[v] for v in has_masks]
|
| 467 |
# mask_stds = [0] + [state.mask_stds[v] for v in has_masks]
|
| 468 |
|
| 469 |
+
for i in range(len(has_masks)):
|
| 470 |
+
model.update_single_layer(
|
| 471 |
+
idx=i,
|
| 472 |
+
mask=masks[i],
|
| 473 |
+
mask_strength=mask_strengths[i],
|
| 474 |
+
mask_std=mask_stds[i],
|
| 475 |
+
)
|
| 476 |
+
# data = dict(
|
| 477 |
+
# masks=masks,
|
| 478 |
+
# mask_strengths=mask_strengths,
|
| 479 |
+
# mask_stds=mask_stds,
|
| 480 |
+
# )
|
| 481 |
+
# conn.send(data)
|
| 482 |
+
# conn.close()
|
| 483 |
|
| 484 |
### Load examples
|
| 485 |
|