Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -51,13 +51,6 @@ def load_model_from_config(config, ckpt, verbose=False):
|
|
| 51 |
pl_sd = torch.load(ckpt, map_location="cpu")
|
| 52 |
sd = pl_sd["state_dict"]
|
| 53 |
model = instantiate_from_config(config.model)
|
| 54 |
-
# m, u = model.load_state_dict(sd, strict=False)
|
| 55 |
-
# if len(m) > 0 and verbose:
|
| 56 |
-
# print("missing keys:")
|
| 57 |
-
# print(m)
|
| 58 |
-
# if len(u) > 0 and verbose:
|
| 59 |
-
# print("unexpected keys:")
|
| 60 |
-
# print(u)
|
| 61 |
model.to(device)
|
| 62 |
model.eval()
|
| 63 |
return model
|
|
@@ -280,7 +273,6 @@ def inference(input_prompt, input_category):
|
|
| 280 |
data = [batch_size * [prompt]]
|
| 281 |
|
| 282 |
else:
|
| 283 |
-
# print(f"reading prompts from {opt.from_file}")
|
| 284 |
with open(opt.from_file, "r") as f:
|
| 285 |
data = f.read().splitlines()
|
| 286 |
data = list(chunk(data, batch_size))
|
|
@@ -290,7 +282,6 @@ def inference(input_prompt, input_category):
|
|
| 290 |
|
| 291 |
start_code = None
|
| 292 |
if opt.fixed_code:
|
| 293 |
-
# print('start_code')
|
| 294 |
start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device)
|
| 295 |
for n in trange(opt.n_iter, desc="Sampling"):
|
| 296 |
for prompts in tqdm(data, desc="data"):
|
|
@@ -320,7 +311,6 @@ def inference(input_prompt, input_category):
|
|
| 320 |
x_sample = torch.clamp((x_samples_ddim[0] + 1.0) / 2.0, min=0.0, max=1.0)
|
| 321 |
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
|
| 322 |
|
| 323 |
-
# Image.fromarray(x_sample.astype(np.uint8)).save("demo/demo.png")
|
| 324 |
img = x_sample.astype(np.uint8)
|
| 325 |
|
| 326 |
class_name = trainclass
|
|
@@ -352,154 +342,23 @@ def inference(input_prompt, input_category):
|
|
| 352 |
mask = annotation_pred.numpy()
|
| 353 |
mask = np.expand_dims(mask, 0)
|
| 354 |
done_image_mask = plot_mask(img, mask, alpha=0.9, indexlist=[0])
|
| 355 |
-
# cv2.imwrite(os.path.join("demo/demo_mask.png"), done_image_mask)
|
| 356 |
-
|
| 357 |
-
# torchvision.utils.save_image(annotation_pred, os.path.join("demo/demo_segresult.png"), normalize=True, scale_each=True)
|
| 358 |
generated_image = x_sample.astype(np.uint8)
|
| 359 |
generated_mask = done_image_mask
|
| 360 |
return [generated_image, generated_mask]
|
| 361 |
|
| 362 |
|
| 363 |
-
# def make_transparent_foreground(pic, mask):
|
| 364 |
-
# # split the image into channels
|
| 365 |
-
# b, g, r = cv2.split(np.array(pic).astype('uint8'))
|
| 366 |
-
# # add an alpha channel with and fill all with transparent pixels (max 255)
|
| 367 |
-
# a = np.ones(mask.shape, dtype='uint8') * 255
|
| 368 |
-
# # merge the alpha channel back
|
| 369 |
-
# alpha_im = cv2.merge([b, g, r, a], 4)
|
| 370 |
-
# # create a transparent background
|
| 371 |
-
# bg = np.zeros(alpha_im.shape)
|
| 372 |
-
# # setup the new mask
|
| 373 |
-
# new_mask = np.stack([mask, mask, mask, mask], axis=2)
|
| 374 |
-
# # copy only the foreground color pixels from the original image where mask is set
|
| 375 |
-
# foreground = np.where(new_mask, alpha_im, bg).astype(np.uint8)
|
| 376 |
-
|
| 377 |
-
# return foreground
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
# def remove_background(input_image):
|
| 381 |
-
# preprocess = transforms.Compose([
|
| 382 |
-
# transforms.ToTensor(),
|
| 383 |
-
# transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
| 384 |
-
# ])
|
| 385 |
-
|
| 386 |
-
# input_tensor = preprocess(input_image)
|
| 387 |
-
# input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model
|
| 388 |
-
|
| 389 |
-
# # move the input and model to GPU for speed if available
|
| 390 |
-
# if torch.cuda.is_available():
|
| 391 |
-
# input_batch = input_batch.to('cuda')
|
| 392 |
-
# model.to('cuda')
|
| 393 |
-
|
| 394 |
-
# with torch.no_grad():
|
| 395 |
-
# output = model(input_batch)['out'][0]
|
| 396 |
-
# output_predictions = output.argmax(0)
|
| 397 |
-
|
| 398 |
-
# # create a binary (black and white) mask of the profile foreground
|
| 399 |
-
# mask = output_predictions.byte().cpu().numpy()
|
| 400 |
-
# background = np.zeros(mask.shape)
|
| 401 |
-
# bin_mask = np.where(mask, 255, background).astype(np.uint8)
|
| 402 |
-
|
| 403 |
-
# foreground = make_transparent_foreground(input_image, bin_mask)
|
| 404 |
-
|
| 405 |
-
# return foreground, bin_mask
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
# def inference(img):
|
| 409 |
-
# foreground, _ = remove_background(img)
|
| 410 |
-
# return foreground
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
# torch.hub.download_url_to_file('https://pbs.twimg.com/profile_images/691700243809718272/z7XZUARB_400x400.jpg',
|
| 414 |
-
# 'demis.jpg')
|
| 415 |
-
# torch.hub.download_url_to_file('https://hai.stanford.edu/sites/default/files/styles/person_medium/public/2020-03/hai_1512feifei.png?itok=INFuLABp',
|
| 416 |
-
# 'lifeifei.png')
|
| 417 |
-
# model = torch.hub.load('pytorch/vision:v0.6.0', 'deeplabv3_resnet101', pretrained=True)
|
| 418 |
-
# model.eval()
|
| 419 |
-
|
| 420 |
-
# gr.Interface(
|
| 421 |
-
# inference,
|
| 422 |
-
# gr.inputs.Textbox(label='Prompt', default='a photo of a lion on a mountain top at sunset'),
|
| 423 |
-
# gr.inputs.Textbox(label='category', default='lion'),
|
| 424 |
-
# gr.outputs.Image(type="pil", label="Output"),
|
| 425 |
-
# # title=title,
|
| 426 |
-
# # description=description,
|
| 427 |
-
# # article=article,
|
| 428 |
-
# # examples=[['demis.jpg'], ['lifeifei.png']],
|
| 429 |
-
# # enable_queue=True
|
| 430 |
-
# ).launch(debug=False)
|
| 431 |
-
|
| 432 |
def main():
|
| 433 |
|
| 434 |
-
# def load_example(
|
| 435 |
-
# steps: int,
|
| 436 |
-
# randomize_seed: bool,
|
| 437 |
-
# seed: int,
|
| 438 |
-
# randomize_cfg: bool,
|
| 439 |
-
# text_cfg_scale: float,
|
| 440 |
-
# image_cfg_scale: float,
|
| 441 |
-
# ):
|
| 442 |
-
# example_instruction = random.choice(example_instructions)
|
| 443 |
-
# return [example_image, example_instruction] + generate(
|
| 444 |
-
# example_image,
|
| 445 |
-
# example_instruction,
|
| 446 |
-
# steps,
|
| 447 |
-
# randomize_seed,
|
| 448 |
-
# seed,
|
| 449 |
-
# randomize_cfg,
|
| 450 |
-
# text_cfg_scale,
|
| 451 |
-
# image_cfg_scale,
|
| 452 |
-
# )
|
| 453 |
-
|
| 454 |
-
# def generate(
|
| 455 |
-
# input_image: Image.Image,
|
| 456 |
-
# instruction: str,
|
| 457 |
-
# steps: int,
|
| 458 |
-
# randomize_seed: bool,
|
| 459 |
-
# seed: int,
|
| 460 |
-
# randomize_cfg: bool,
|
| 461 |
-
# text_cfg_scale: float,
|
| 462 |
-
# image_cfg_scale: float,
|
| 463 |
-
# ):
|
| 464 |
-
# seed = random.randint(0, 100000) if randomize_seed else seed
|
| 465 |
-
# text_cfg_scale = round(random.uniform(6.0, 9.0), ndigits=2) if randomize_cfg else text_cfg_scale
|
| 466 |
-
# image_cfg_scale = round(random.uniform(1.2, 1.8), ndigits=2) if randomize_cfg else image_cfg_scale
|
| 467 |
-
|
| 468 |
-
# width, height = input_image.size
|
| 469 |
-
# factor = 512 / max(width, height)
|
| 470 |
-
# factor = math.ceil(min(width, height) * factor / 64) * 64 / min(width, height)
|
| 471 |
-
# width = int((width * factor) // 64) * 64
|
| 472 |
-
# height = int((height * factor) // 64) * 64
|
| 473 |
-
# input_image = ImageOps.fit(input_image, (width, height), method=Image.Resampling.LANCZOS)
|
| 474 |
-
|
| 475 |
-
# if instruction == "":
|
| 476 |
-
# return [input_image, seed]
|
| 477 |
-
|
| 478 |
-
# generator = torch.manual_seed(seed)
|
| 479 |
-
# edited_image = pipe(
|
| 480 |
-
# instruction, image=input_image,
|
| 481 |
-
# guidance_scale=text_cfg_scale, image_guidance_scale=image_cfg_scale,
|
| 482 |
-
# num_inference_steps=steps, generator=generator,
|
| 483 |
-
# ).images[0]
|
| 484 |
-
# return [seed, text_cfg_scale, image_cfg_scale, edited_image]
|
| 485 |
-
|
| 486 |
-
# def reset():
|
| 487 |
-
# return [0, "Randomize Seed", 1371, "Fix CFG", 7.5, 1.5, None]
|
| 488 |
-
|
| 489 |
with gr.Blocks() as demo:
|
| 490 |
gr.HTML("""<h1 style="font-weight: 900; margin-bottom: 7px;">
|
| 491 |
-
|
| 492 |
</h1>
|
| 493 |
<p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.
|
| 494 |
<br/>
|
| 495 |
-
<a href="https://huggingface.co/spaces/
|
| 496 |
<img style="margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
|
| 497 |
<p/>""")
|
| 498 |
with gr.Row():
|
| 499 |
-
# with gr.Column(scale=1, min_width=100):
|
| 500 |
-
# load_button = gr.Button("Load Example")
|
| 501 |
-
# with gr.Column(scale=1, min_width=100):
|
| 502 |
-
# reset_button = gr.Button("Reset")
|
| 503 |
with gr.Column(scale=3):
|
| 504 |
Prompt = gr.Textbox(lines=1, label="Prompt", interactive=True)
|
| 505 |
with gr.Column(scale=2):
|
|
@@ -513,40 +372,7 @@ def main():
|
|
| 513 |
generated_image.style(height=512, width=512)
|
| 514 |
generated_mask.style(height=512, width=512)
|
| 515 |
|
| 516 |
-
|
| 517 |
-
# steps = gr.Number(value=50, precision=0, label="Steps", interactive=True)
|
| 518 |
-
# randomize_seed = gr.Radio(
|
| 519 |
-
# ["Fix Seed", "Randomize Seed"],
|
| 520 |
-
# value="Randomize Seed",
|
| 521 |
-
# type="index",
|
| 522 |
-
# show_label=False,
|
| 523 |
-
# interactive=True,
|
| 524 |
-
# )
|
| 525 |
-
# seed = gr.Number(value=1371, precision=0, label="Seed", interactive=True)
|
| 526 |
-
# randomize_cfg = gr.Radio(
|
| 527 |
-
# ["Fix CFG", "Randomize CFG"],
|
| 528 |
-
# value="Fix CFG",
|
| 529 |
-
# type="index",
|
| 530 |
-
# show_label=False,
|
| 531 |
-
# interactive=True,
|
| 532 |
-
# )
|
| 533 |
-
# text_cfg_scale = gr.Number(value=7.5, label=f"Text CFG", interactive=True)
|
| 534 |
-
# image_cfg_scale = gr.Number(value=1.5, label=f"Image CFG", interactive=True)
|
| 535 |
-
|
| 536 |
-
# gr.Markdown(help_text)
|
| 537 |
-
|
| 538 |
-
# load_button.click(
|
| 539 |
-
# fn=load_example,
|
| 540 |
-
# inputs=[
|
| 541 |
-
# steps,
|
| 542 |
-
# randomize_seed,
|
| 543 |
-
# seed,
|
| 544 |
-
# randomize_cfg,
|
| 545 |
-
# text_cfg_scale,
|
| 546 |
-
# image_cfg_scale,
|
| 547 |
-
# ],
|
| 548 |
-
# outputs=[input_image, instruction, seed, text_cfg_scale, image_cfg_scale, edited_image],
|
| 549 |
-
# )
|
| 550 |
generate_button.click(
|
| 551 |
fn=inference,
|
| 552 |
inputs=[
|
|
@@ -555,11 +381,6 @@ def main():
|
|
| 555 |
],
|
| 556 |
outputs=[generated_image, generated_mask],
|
| 557 |
)
|
| 558 |
-
# reset_button.click(
|
| 559 |
-
# fn=reset,
|
| 560 |
-
# inputs=[],
|
| 561 |
-
# outputs=[steps, randomize_seed, seed, randomize_cfg, text_cfg_scale, image_cfg_scale, edited_image],
|
| 562 |
-
# )
|
| 563 |
|
| 564 |
demo.queue(concurrency_count=1)
|
| 565 |
demo.launch(share=False)
|
|
|
|
| 51 |
pl_sd = torch.load(ckpt, map_location="cpu")
|
| 52 |
sd = pl_sd["state_dict"]
|
| 53 |
model = instantiate_from_config(config.model)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
model.to(device)
|
| 55 |
model.eval()
|
| 56 |
return model
|
|
|
|
| 273 |
data = [batch_size * [prompt]]
|
| 274 |
|
| 275 |
else:
|
|
|
|
| 276 |
with open(opt.from_file, "r") as f:
|
| 277 |
data = f.read().splitlines()
|
| 278 |
data = list(chunk(data, batch_size))
|
|
|
|
| 282 |
|
| 283 |
start_code = None
|
| 284 |
if opt.fixed_code:
|
|
|
|
| 285 |
start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device)
|
| 286 |
for n in trange(opt.n_iter, desc="Sampling"):
|
| 287 |
for prompts in tqdm(data, desc="data"):
|
|
|
|
| 311 |
x_sample = torch.clamp((x_samples_ddim[0] + 1.0) / 2.0, min=0.0, max=1.0)
|
| 312 |
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
|
| 313 |
|
|
|
|
| 314 |
img = x_sample.astype(np.uint8)
|
| 315 |
|
| 316 |
class_name = trainclass
|
|
|
|
| 342 |
mask = annotation_pred.numpy()
|
| 343 |
mask = np.expand_dims(mask, 0)
|
| 344 |
done_image_mask = plot_mask(img, mask, alpha=0.9, indexlist=[0])
|
|
|
|
|
|
|
|
|
|
| 345 |
generated_image = x_sample.astype(np.uint8)
|
| 346 |
generated_mask = done_image_mask
|
| 347 |
return [generated_image, generated_mask]
|
| 348 |
|
| 349 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 350 |
def main():
|
| 351 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 352 |
with gr.Blocks() as demo:
|
| 353 |
gr.HTML("""<h1 style="font-weight: 900; margin-bottom: 7px;">
|
| 354 |
+
Guiding Text-to-Image Diffusion Model Towards Grounded Generation
|
| 355 |
</h1>
|
| 356 |
<p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.
|
| 357 |
<br/>
|
| 358 |
+
<a href="https://huggingface.co/spaces/Purple11/Grounded-Diffusion?duplicate=true">
|
| 359 |
<img style="margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
|
| 360 |
<p/>""")
|
| 361 |
with gr.Row():
|
|
|
|
|
|
|
|
|
|
|
|
|
| 362 |
with gr.Column(scale=3):
|
| 363 |
Prompt = gr.Textbox(lines=1, label="Prompt", interactive=True)
|
| 364 |
with gr.Column(scale=2):
|
|
|
|
| 372 |
generated_image.style(height=512, width=512)
|
| 373 |
generated_mask.style(height=512, width=512)
|
| 374 |
|
| 375 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 376 |
generate_button.click(
|
| 377 |
fn=inference,
|
| 378 |
inputs=[
|
|
|
|
| 381 |
],
|
| 382 |
outputs=[generated_image, generated_mask],
|
| 383 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 384 |
|
| 385 |
demo.queue(concurrency_count=1)
|
| 386 |
demo.launch(share=False)
|