banyapon commited on
Commit
8550198
·
1 Parent(s): 2543583

Add main space file

Browse files
Files changed (3) hide show
  1. app.py +75 -0
  2. models/Unet_2020-10-30/weights.pth +3 -0
  3. requirements.txt +9 -0
app.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import cv2
3
+ import torch
4
+ import albumentations as albu
5
+ from pylab import imshow
6
+ import matplotlib.pyplot as plt
7
+ from diffusers import StableDiffusionInpaintPipeline
8
+ from PIL import Image
9
+ from iglovikov_helper_functions.utils.image_utils import load_rgb, pad, unpad
10
+ from iglovikov_helper_functions.dl.pytorch.utils import tensor_from_rgb_image
11
+ from cloths_segmentation.pre_trained_models import create_model
12
+
13
# Load the pre-trained cloth segmentation U-Net ("Unet_2020-10-30" checkpoint).
# NOTE(review): create_model fetches/loads its own weights; the repo also ships
# models/Unet_2020-10-30/weights.pth — confirm which copy is actually used.
model = create_model("Unet_2020-10-30")
model.eval()  # inference mode: freeze dropout / batch-norm statistics
# Load the Stable Diffusion inpainting pipeline from the Hugging Face hub.
# NOTE(review): no .to("cuda") here — runs on CPU unless moved; confirm intent.
pipe = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting")
19
+
20
def load_and_preprocess_image(image_path):
    """Read an RGB image from disk and prepare it for the segmentation net.

    Returns a tuple ``(batch, rgb, pads)``: the normalized, padded input
    tensor with a leading batch dimension, the original RGB array, and the
    padding record needed later to un-pad the predicted mask.
    """
    rgb = load_rgb(image_path)
    # Pad so height/width are multiples of 32 (U-Net downsampling requirement).
    padded, pads = pad(rgb, factor=32, border=cv2.BORDER_CONSTANT)
    normalize = albu.Compose([albu.Normalize(p=1)], p=1)
    normalized = normalize(image=padded)["image"]
    batch = tensor_from_rgb_image(normalized).unsqueeze(0)
    return batch, rgb, pads
27
+
28
def segment_cloth(image_tensor, model, pads):
    """Run the segmentation model and return a binary uint8 cloth mask.

    The mask is thresholded at logit > 0 and un-padded back to the
    original image dimensions using *pads*.
    """
    with torch.no_grad():
        logits = model(image_tensor)[0][0]
        padded_mask = (logits > 0).cpu().numpy().astype(np.uint8)
        return unpad(padded_mask, pads)
34
+
35
def perform_inpainting(image_path, mask_path, prompt):
    """Inpaint the masked region of an image according to *prompt*.

    Opens both files from disk, forces the mask to single-channel
    grayscale at the source image's size (the pipeline expects matching
    dimensions), and returns the first generated PIL image.
    """
    source = Image.open(image_path)
    mask = Image.open(mask_path).convert("L").resize(source.size)
    result = pipe(prompt=prompt, image=source, mask_image=mask)
    return result.images[0]
42
+
43
def resize_and_upscale(image, new_width, new_height):
    """Resize *image* to (new_width, new_height) via bicubic interpolation.

    Round-trips through a numpy array so OpenCV can do the resampling,
    then returns a PIL image again.
    """
    as_array = np.array(image)
    scaled = cv2.resize(as_array, (new_width, new_height), interpolation=cv2.INTER_CUBIC)
    return Image.fromarray(scaled)
46
+
47
+ import gradio as gr
48
+
49
def image_segmentation_and_inpainting(image, prompt="Chinese Red and Golder Armor"):
    """Full pipeline: segment the cloth in *image*, then inpaint it with *prompt*.

    Accepts any payload a gradio image input may deliver: a numpy array
    (gr.Image's default ``type="numpy"``), a filepath string, or a
    file-like object exposing ``.name``. Returns the inpainted result
    resized to 1280x720.
    """
    # BUG FIX: the original read image.name unconditionally, which raises
    # AttributeError for gr.Image's default numpy-array payload. Resolve
    # the input to a filesystem path first.
    if isinstance(image, str):
        image_path = image
    elif hasattr(image, "name"):  # tempfile wrapper (gr.File-style input)
        image_path = image.name
    else:  # numpy array from gr.Image default
        image_path = "temp_input.png"
        Image.fromarray(np.asarray(image, dtype=np.uint8)).save(image_path)

    x, _, pads = load_and_preprocess_image(image_path)
    mask = segment_cloth(x, model, pads)

    # Persist the mask so the inpainting step can re-open it as a file.
    mask_path = "temp_mask.jpg"
    plt.imsave(mask_path, mask, cmap='gray')

    output_image = perform_inpainting(image_path, mask_path, prompt)
    # NOTE(review): default prompt "Golder" is likely a typo for "Golden";
    # kept as-is to preserve the published default.
    return resize_and_upscale(output_image, 1280, 720)
60
+
61
+
62
# ---- Gradio UI -----------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# Cloth Image Segmentation and Inpainting")

    with gr.Row():
        with gr.Column():
            # Input side: image upload, optional prompt, and the trigger.
            image_input = gr.Image(label="Upload Image")
            prompt_input = gr.Textbox(
                label="Inpainting Prompt (Optional)",
                value="Chinese Red and Golder Armor",
            )
            run_button = gr.Button("Run")
        with gr.Column():
            # Output side: the inpainted result.
            image_output = gr.Image(label="Result")

    run_button.click(
        fn=image_segmentation_and_inpainting,
        inputs=[image_input, prompt_input],
        outputs=image_output,
    )

demo.launch(share=True)
models/Unet_2020-10-30/weights.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8d6e949a0e98e79fa1a814213bb73945e258699936f9352aedceda44247ab1f6
3
+ size 53237049
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ torch==1.13.1
2
+ albumentations==1.2.1
3
+ matplotlib==3.6.2
4
+ diffusers==0.11.1
5
+ transformers==4.26.1
6
+ iglovikov_helper_functions  # image pad/unpad utilities (available on PyPI; consider pinning a version)
7
+ cloths_segmentation  # pre-trained cloth segmentation models (available on PyPI; consider pinning a version)
8
+ opencv-python-headless==4.6.0.66
9
+ gradio==3.15.0