Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
|
@@ -14,19 +14,13 @@ from matplotlib import pyplot as plt
|
|
| 14 |
from torchvision import transforms
|
| 15 |
from diffusers import DiffusionPipeline
|
| 16 |
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
transform = transforms.Compose([
|
| 26 |
-
transforms.ToTensor(),
|
| 27 |
-
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
| 28 |
-
transforms.Resize((512, 512)),
|
| 29 |
-
])
|
| 30 |
|
| 31 |
def read_content(file_path: str) -> str:
|
| 32 |
"""read the content of target file
|
|
@@ -36,11 +30,11 @@ def read_content(file_path: str) -> str:
|
|
| 36 |
|
| 37 |
return content
|
| 38 |
|
| 39 |
-
def predict(dict,
|
| 40 |
-
init_image = dict["image"].convert("RGB")
|
| 41 |
-
mask = dict["mask"].convert("RGB")
|
| 42 |
-
|
| 43 |
-
return
|
| 44 |
|
| 45 |
|
| 46 |
css = '''
|
|
@@ -89,9 +83,9 @@ with image_blocks as demo:
|
|
| 89 |
with gr.Box():
|
| 90 |
with gr.Row():
|
| 91 |
with gr.Column():
|
| 92 |
-
image = gr.Image(source='upload', tool='sketch', elem_id="image_upload", type="pil", label="Upload")
|
|
|
|
| 93 |
with gr.Row(elem_id="prompt-container").style(mobile_collapse=False, equal_height=True):
|
| 94 |
-
prompt = gr.Textbox(placeholder = 'Your prompt (what you want in place of what is erased)', show_label=False, elem_id="input-text")
|
| 95 |
btn = gr.Button("Inpaint!").style(
|
| 96 |
margin=False,
|
| 97 |
rounded=(False, True, True, False),
|
|
@@ -105,7 +99,7 @@ with image_blocks as demo:
|
|
| 105 |
share_button = gr.Button("Share to community", elem_id="share-btn", visible=False)
|
| 106 |
|
| 107 |
|
| 108 |
-
btn.click(fn=predict, inputs=[image,
|
| 109 |
share_button.click(None, [], [], _js=share_js)
|
| 110 |
|
| 111 |
|
|
|
|
| 14 |
from torchvision import transforms
|
| 15 |
from diffusers import DiffusionPipeline
|
| 16 |
|
| 17 |
+
pipe = DiffusionPipeline.from_pretrained(
|
| 18 |
+
"patrickvonplaten/new_inpaint_test",
|
| 19 |
+
torch_dtype=torch.float16,
|
| 20 |
+
)
|
| 21 |
+
pipe = pipe.to("cuda")
|
| 22 |
|
| 23 |
+
from share_btn import community_icon_html, loading_icon_html, share_js
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
def read_content(file_path: str) -> str:
|
| 26 |
"""read the content of target file
|
|
|
|
| 30 |
|
| 31 |
return content
|
| 32 |
|
| 33 |
+
def predict(dict, example_image):
|
| 34 |
+
init_image = dict["image"].convert("RGB")
|
| 35 |
+
mask = dict["mask"].convert("RGB")
|
| 36 |
+
image = pipe(image=init_image, mask_image=mask, example_image=example_image).images[0]
|
| 37 |
+
return image, gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
|
| 38 |
|
| 39 |
|
| 40 |
css = '''
|
|
|
|
| 83 |
with gr.Box():
|
| 84 |
with gr.Row():
|
| 85 |
with gr.Column():
|
| 86 |
+
image = gr.Image(source='upload', tool='sketch', elem_id="image_upload", type="pil", label="Upload")
|
| 87 |
+
example = gr.Image(source='upload', elem_id="image_upload", type="pil", label="Upload")
|
| 88 |
with gr.Row(elem_id="prompt-container").style(mobile_collapse=False, equal_height=True):
|
|
|
|
| 89 |
btn = gr.Button("Inpaint!").style(
|
| 90 |
margin=False,
|
| 91 |
rounded=(False, True, True, False),
|
|
|
|
| 99 |
share_button = gr.Button("Share to community", elem_id="share-btn", visible=False)
|
| 100 |
|
| 101 |
|
| 102 |
+
btn.click(fn=predict, inputs=[image, example], outputs=[image_out, community_icon, loading_icon, share_button])
|
| 103 |
share_button.click(None, [], [], _js=share_js)
|
| 104 |
|
| 105 |
|