Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -320,7 +320,7 @@ def import_state(state, json_text):
|
|
| 320 |
### Main worker
|
| 321 |
|
| 322 |
|
| 323 |
-
def register(state, drawpad):
|
| 324 |
seed_everything(state.seed if state.seed >=0 else np.random.randint(2147483647))
|
| 325 |
print('Generate!')
|
| 326 |
|
|
@@ -362,15 +362,15 @@ def register(state, drawpad):
|
|
| 362 |
# prompts, negative_prompts = preprocess_prompts(
|
| 363 |
# prompts, negative_prompts, style_name=state.style_name, quality_name=state.quality_name)
|
| 364 |
|
| 365 |
-
|
| 366 |
background.convert('RGB'),
|
| 367 |
prompt=None,
|
| 368 |
negative_prompt=None,
|
| 369 |
)
|
| 370 |
-
state.prompts[0] =
|
| 371 |
-
state.neg_prompts[0] =
|
| 372 |
|
| 373 |
-
|
| 374 |
prompts=prompts,
|
| 375 |
negative_prompts=negative_prompts,
|
| 376 |
masks=masks.to(device),
|
|
@@ -384,23 +384,23 @@ def register(state, drawpad):
|
|
| 384 |
|
| 385 |
@spaces.GPU(duration=120)
|
| 386 |
def run(state, drawpad):
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
state.model.prepare()
|
| 392 |
|
| 393 |
-
state = register(state, drawpad)
|
| 394 |
state.is_running = True
|
| 395 |
|
| 396 |
tic = time.time()
|
| 397 |
while True:
|
| 398 |
-
yield [state,
|
| 399 |
toc = time.time()
|
| 400 |
tdelta = toc - tic
|
| 401 |
if tdelta > opt.run_time:
|
| 402 |
state.is_running = False
|
| 403 |
-
|
|
|
|
| 404 |
|
| 405 |
|
| 406 |
def hide_element():
|
|
@@ -412,7 +412,11 @@ def show_element():
|
|
| 412 |
|
| 413 |
|
| 414 |
def draw(state, drawpad):
|
|
|
|
|
|
|
|
|
|
| 415 |
if not state.is_running:
|
|
|
|
| 416 |
return
|
| 417 |
|
| 418 |
user_input = np.asarray(drawpad['layers'][0]) # (H, W, 4)
|
|
@@ -601,7 +605,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css, head=head) as demo:
|
|
| 601 |
state.model_id = opt.model
|
| 602 |
state.style_name = '(None)'
|
| 603 |
state.quality_name = 'Standard v3.1'
|
| 604 |
-
state.model =
|
| 605 |
|
| 606 |
# State variables (one-hot).
|
| 607 |
state.active_palettes = 5
|
|
|
|
| 320 |
### Main worker
|
| 321 |
|
| 322 |
|
| 323 |
+
def register(state, drawpad, model):
|
| 324 |
seed_everything(state.seed if state.seed >=0 else np.random.randint(2147483647))
|
| 325 |
print('Generate!')
|
| 326 |
|
|
|
|
| 362 |
# prompts, negative_prompts = preprocess_prompts(
|
| 363 |
# prompts, negative_prompts, style_name=state.style_name, quality_name=state.quality_name)
|
| 364 |
|
| 365 |
+
model.update_background(
|
| 366 |
background.convert('RGB'),
|
| 367 |
prompt=None,
|
| 368 |
negative_prompt=None,
|
| 369 |
)
|
| 370 |
+
state.prompts[0] = model.background.prompt
|
| 371 |
+
state.neg_prompts[0] = model.background.negative_prompt
|
| 372 |
|
| 373 |
+
model.update_layers(
|
| 374 |
prompts=prompts,
|
| 375 |
negative_prompts=negative_prompts,
|
| 376 |
masks=masks.to(device),
|
|
|
|
| 384 |
|
| 385 |
@spaces.GPU(duration=120)
|
| 386 |
def run(state, drawpad):
|
| 387 |
+
model.device = torch.device('cuda')
|
| 388 |
+
model.reset_seed(model.generator, opt.seed)
|
| 389 |
+
model.reset_latent()
|
| 390 |
+
model.prepare()
|
|
|
|
| 391 |
|
| 392 |
+
state = register(state, drawpad, model)
|
| 393 |
state.is_running = True
|
| 394 |
|
| 395 |
tic = time.time()
|
| 396 |
while True:
|
| 397 |
+
yield [state, model()]
|
| 398 |
toc = time.time()
|
| 399 |
tdelta = toc - tic
|
| 400 |
if tdelta > opt.run_time:
|
| 401 |
state.is_running = False
|
| 402 |
+
state.model = None
|
| 403 |
+
return [state, model()]
|
| 404 |
|
| 405 |
|
| 406 |
def hide_element():
|
|
|
|
| 412 |
|
| 413 |
|
| 414 |
def draw(state, drawpad):
|
| 415 |
+
if not hasattr(state, 'model') or state.model is None:
|
| 416 |
+
print('[WARNING] Model is not registered, update ignored.')
|
| 417 |
+
return
|
| 418 |
if not state.is_running:
|
| 419 |
+
print('[WARNING] Streaming is currently off, update ignored.')
|
| 420 |
return
|
| 421 |
|
| 422 |
user_input = np.asarray(drawpad['layers'][0]) # (H, W, 4)
|
|
|
|
| 605 |
state.model_id = opt.model
|
| 606 |
state.style_name = '(None)'
|
| 607 |
state.quality_name = 'Standard v3.1'
|
| 608 |
+
state.model = None
|
| 609 |
|
| 610 |
# State variables (one-hot).
|
| 611 |
state.active_palettes = 5
|