|
|
import os |
|
|
import numpy as np |
|
|
import streamlit as st |
|
|
from PIL import Image, ImageDraw, ImageFilter |
|
|
import numpy as np |
|
|
import torch |
|
|
from streamlit_js_eval import streamlit_js_eval |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from streamlit_image_coordinates import streamlit_image_coordinates |
|
|
|
|
|
|
|
|
from diffusers import StableDiffusionInpaintPipeline |
|
|
|
|
|
|
|
|
|
|
|
from ultralytics import FastSAM |
|
|
|
|
|
|
|
|
# Model identifiers: the local FastSAM weights file and the Hugging Face repo
# of the Stable Diffusion inpainting pipeline.
FASTSAM_CHECKPOINT = "FastSAM-x.pt"
SD_MODEL_ID = "runwayml/stable-diffusion-inpainting"

# Page configuration goes first, before any other element is rendered.
st.set_page_config(page_title="Inpainting Demo", layout="centered")

# Browser viewport width in CSS pixels, evaluated client-side via JavaScript.
# NOTE(review): presumably None until the browser has answered — confirm.
page_width = streamlit_js_eval(
    js_expressions="window.innerWidth",
    key="WIDTH",
    want_output=True,
)
|
|
|
|
|
|
|
|
def crop_resize_image(image, target_width=480, target_height=640):
    """Center-crop *image* to the target aspect ratio, then resize it.

    Whichever dimension is too large relative to ``target_width /
    target_height`` is trimmed equally from both sides. Scaling to the target
    size therefore does not distort the picture.

    Args:
        image: A PIL image (anything providing ``size``, ``crop`` and
            ``resize``).
        target_width: Output width in pixels.
        target_height: Output height in pixels.

    Returns:
        A new image of size ``(target_width, target_height)``.
    """
    wanted = target_width / target_height
    src_w, src_h = image.size
    have = src_w / src_h

    box = None
    if have > wanted:
        # Too wide: keep the full height, trim left and right.
        keep_w = int(src_h * wanted)
        x0 = (src_w - keep_w) // 2
        box = (x0, 0, x0 + keep_w, src_h)
    elif have < wanted:
        # Too tall: keep the full width, trim top and bottom.
        keep_h = int(src_w / wanted)
        y0 = (src_h - keep_h) // 2
        box = (0, y0, src_w, y0 + keep_h)

    cropped = image if box is None else image.crop(box)
    return cropped.resize((target_width, target_height))
|
|
|
|
|
|
|
|
if not os.path.exists(FASTSAM_CHECKPOINT):
    # Fetch the FastSAM weights on first run.
    # The download is streamed to a temporary ".part" file. It is renamed into
    # place only after the HTTP status has been checked and the body fully
    # written. A failed or interrupted request therefore never leaves a
    # corrupt checkpoint behind; otherwise the existence check above would
    # skip re-downloading it forever.
    import requests

    fastsam_url = "https://github.com/ultralytics/assets/releases/download/v8.2.0/FastSAM-x.pt"
    partial_path = FASTSAM_CHECKPOINT + ".part"

    # The timeout bounds each connect/read wait so a stalled server cannot hang
    # the app indefinitely.
    with requests.get(fastsam_url, stream=True, timeout=60) as resp:
        resp.raise_for_status()
        with open(partial_path, "wb") as fh:
            for chunk in resp.iter_content(chunk_size=1 << 20):
                fh.write(chunk)

    os.replace(partial_path, FASTSAM_CHECKPOINT)
|
|
|
|
|
|
|
|
@st.cache_resource
def load_models():
    """Build the FastSAM segmenter and the Stable Diffusion inpainting pipeline.

    Decorated with ``st.cache_resource``, so the expensive loading runs once.
    The same objects are then reused across script reruns.

    Returns:
        tuple: ``(fastsam_model, sd_pipe)``.
            - ``fastsam_model`` is the FastSAM model loaded from
              ``FASTSAM_CHECKPOINT``.
            - ``sd_pipe`` is the inpainting pipeline from ``SD_MODEL_ID``.
              It is placed on CUDA when available (CPU otherwise) and has
              attention slicing enabled.
    """
    segmenter = FastSAM(FASTSAM_CHECKPOINT)

    pipeline = StableDiffusionInpaintPipeline.from_pretrained(SD_MODEL_ID, torch_dtype=None)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    pipeline = pipeline.to(device)
    # Compute attention in slices: lower peak memory at a small speed cost.
    pipeline.enable_attention_slicing()

    return segmenter, pipeline
|
|
|
|
|
|
|
|
# Models come from the st.cache_resource-decorated loader, so reruns reuse them.
fastsam_model, sd_pipe = load_models()

# Initialise the per-session ``is_removing_dot`` flag on the first run only.
if "is_removing_dot" not in st.session_state:
    st.session_state["is_removing_dot"] = False

st.subheader("InteractiveInpainting")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Obtain the working image: from the camera on first use, otherwise from the
# session (which also holds the latest inpainting result).
if "img" not in st.session_state:
    enable = st.checkbox("Enable camera")
    picture = st.camera_input("Take a picture", disabled=not enable)

    if picture is None:
        # Everything below needs an image. Stop this run here rather than
        # failing later with a NameError on ``img``.
        st.stop()

    # Normalise to RGB (defensive: the PIL mode of the captured file is not
    # guaranteed), then crop/resize to the fixed 480x640 working size.
    img = Image.open(picture).convert("RGB")
    img = crop_resize_image(img, target_width=480, target_height=640)

    st.session_state.img = img
    # A new photo starts with no selected points.
    st.session_state.coords_list = []
    st.rerun()
else:
    img = st.session_state.img

if "coords_list" not in st.session_state:
    st.session_state.coords_list = []
|
|
|
|
|
|
|
|
|
|
|
# Segment the clicked object(s) with FastSAM and build the preview image.
# ``combined_mask_img`` is the selection mask (resized to ``img.size``) used by
# the inpainting step. It stays None when no point is selected or no mask is
# returned, so later code can test it instead of hitting a NameError.
combined_mask_img = None
seg_overlay = img.copy()

if st.session_state.coords_list:
    # Every stored click is used as a positive (label 1) point prompt.
    points = [[int(pt["x"]), int(pt["y"])] for pt in st.session_state.coords_list]
    labels = [1] * len(points)
    results = fastsam_model(img, points=points, labels=labels)

    # ultralytics ``Results.masks`` is None when nothing was segmented.
    mask_data = results[0].masks if results else None
    if mask_data is not None:
        masks = mask_data.data.cpu().numpy()
        if masks.ndim == 3 and masks.shape[0] > 0:
            # Union of all returned masks as an 8-bit 0/255 image.
            combined_mask = np.max(masks, axis=0)
            combined_mask_img = Image.fromarray((combined_mask * 255).astype(np.uint8))
            # Match the photo's size; NEAREST keeps the mask hard-edged.
            combined_mask_img = combined_mask_img.resize(img.size, Image.NEAREST)

            # Tint the selected region red: alpha 80 inside the mask, 0 outside.
            overlay = Image.new("RGBA", img.size, (255, 0, 0, 100))
            base = img.convert("RGBA")
            mask_alpha = combined_mask_img.point(lambda p: 80 if p > 0 else 0)
            overlay.putalpha(mask_alpha)
            seg_overlay = Image.alpha_composite(base, overlay)

# Mark every selected point with a small red dot on top of the preview.
final_img = seg_overlay.copy()
draw = ImageDraw.Draw(final_img)
for pt in st.session_state.coords_list:
    cx, cy = int(pt["x"]), int(pt["y"])
    draw.ellipse((cx - 5, cy - 5, cx + 5, cy + 5), fill="red")
|
|
|
|
|
|
|
|
|
|
|
# Turn clicks on the displayed preview into image-pixel points.
# A click within 10 px of an existing point removes it; otherwise it is added.
original_width = st.session_state.img.width

new_coord = streamlit_image_coordinates(final_img, key="click_img", use_column_width="always")

if "last_click" not in st.session_state:
    st.session_state.last_click = None

# The component keeps returning its most recent click on every rerun (e.g.
# while typing a prompt, or after st.rerun()). Only act on a click not yet
# processed. Otherwise a removed point, or the list cleared after inpainting,
# gets the stale click re-added on the next unrelated rerun.
if new_coord and new_coord != st.session_state.last_click:
    st.session_state.last_click = new_coord

    # Scale display pixels to image pixels. Prefer the displayed width
    # reported by the component, if this version returns one.
    # window.innerWidth (page_width) is wider than the centered column on
    # large screens, and it can be None before the browser has answered.
    display_width = new_coord.get("width") or page_width or original_width
    scale_factor = original_width / display_width
    click = {
        "x": new_coord["x"] * scale_factor,
        "y": new_coord["y"] * scale_factor,
    }

    nearby = next(
        (
            coord
            for coord in st.session_state.coords_list
            if np.hypot(coord["x"] - click["x"], coord["y"] - click["y"]) < 10
        ),
        None,
    )
    if nearby is not None:
        st.session_state.coords_list.remove(nearby)
    else:
        st.session_state.coords_list.append(click)
    st.rerun()

st.write("Stored coordinates:", st.session_state.coords_list)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Inpainting: once a region is selected and a prompt is given, repaint the
# region with Stable Diffusion and continue editing on the result.
prompt = st.text_input("Prompt for inpainting (describe what should replace the selected area):")

if prompt and combined_mask_img is not None:
    if st.button("Run Inpainting"):
        # Grow the selection a few pixels (5x5 max filter) and feather its edge.
        # The repainted area then covers the object's boundary instead of
        # leaving a halo of the original outline.
        mask_l = combined_mask_img.convert("L")
        dilated_mask = mask_l.filter(ImageFilter.MaxFilter(5))
        blurred_mask = dilated_mask.filter(ImageFilter.GaussianBlur(radius=3))

        with st.spinner("Inpainting..."):
            inpainted_img = sd_pipe(
                prompt=prompt,
                image=img,
                mask_image=blurred_mask,
                width=img.width,
                height=img.height,
                guidance_scale=8,
                num_inference_steps=50,
            ).images[0]

        # The result becomes the new working image, with no points selected.
        st.session_state.img = inpainted_img
        st.session_state.coords_list = []
        st.rerun()