File size: 8,815 Bytes
b147284 126d0b4 0e9299e f21aa39 0e9299e 8d71ca6 c99d0ab b147284 011ea0a 0e9299e 8d71ca6 b147284 c3c3dd4 b147284 648c268 c99d0ab a6e60bd c99d0ab b147284 172a17e c676f9b f21aa39 c676f9b aa64939 c676f9b aa64939 b147284 46db2a8 b147284 0e9299e 011ea0a b147284 011ea0a b147284 011ea0a b147284 0e9299e b147284 c7b01f0 0e9299e 92249de 876a83e e7e965c 876a83e 172a17e 11f70e4 b147284 126d0b4 172a17e 8029369 08a6710 126d0b4 2f2581c 9ccd0cc 2f2581c f21aa39 2f2581c c676f9b 2f2581c 126d0b4 bbbe3f7 c7b01f0 3f2393a c7b01f0 3f2393a c7b01f0 b4ddf81 52e1198 c7b01f0 bbbe3f7 c7b01f0 c8553c3 c7b01f0 c8553c3 bbbe3f7 2575860 bbbe3f7 c8553c3 b5ebf86 c8553c3 bbbe3f7 92249de c9a98ff 4f92585 c4c16e7 1f6c057 c9a98ff 08e68c1 c9a98ff 1f6c057 92249de 08e68c1 1f6c057 071911e 233f57a c7b01f0 bbbe3f7 172a17e 56cc275 b16906f 1b4308f 56cc275 2d7d767 b46ffe3 b16906f 56cc275 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 |
import os
import numpy as np
import streamlit as st
from PIL import Image, ImageDraw, ImageFilter
import numpy as np
import torch
from streamlit_js_eval import streamlit_js_eval
# Import the custom component for image coordinates
from streamlit_image_coordinates import streamlit_image_coordinates
# Import diffusers pipeline for Stable Diffusion inpainting
from diffusers import StableDiffusionInpaintPipeline
# Ultralytics provides the FastSAM model class
from ultralytics import FastSAM
# Configure the page first; the centered layout is meant to work well on phones.
st.set_page_config(layout="centered", page_title="Inpainting Demo")

# Ask the browser for its window width. The click handler further down uses it
# to map click positions on the displayed image back to image pixels.
# NOTE(review): the value presumably is None until the browser has answered
# the first time - confirm against streamlit_js_eval's behaviour.
page_width = streamlit_js_eval(
    js_expressions="window.innerWidth",
    key="WIDTH",
    want_output=True,
)

# Model locations, kept in one place so they are easy to swap.
# Local file name of the FastSAM weights.
FASTSAM_CHECKPOINT = "FastSAM-x.pt"
# Hugging Face Hub id of the Stable Diffusion inpainting model.
SD_MODEL_ID = "runwayml/stable-diffusion-inpainting"
# Helper function: center crop to the target aspect ratio, then resize
# (defaults to 480x640, i.e. portrait).
def crop_resize_image(image, target_width=480, target_height=640):
    """Center-crop *image* to the target aspect ratio and resize it.

    Args:
        image: Object exposing ``size`` (``(width, height)``), ``crop(box)``
            and ``resize(size)``, e.g. a ``PIL.Image.Image``.
        target_width: Width of the returned image in pixels.
        target_height: Height of the returned image in pixels.

    Returns:
        The result of ``image.resize((target_width, target_height))`` after
        any cropping needed to match the target aspect ratio.
    """
    width, height = image.size

    # Compare aspect ratios by cross-multiplication and size the crop with
    # integer arithmetic. Going through a float ratio can truncate an exact
    # result downwards (e.g. int(49 * (1 / 49)) == 0).
    scaled_width = width * target_height
    scaled_height = height * target_width

    if scaled_width > scaled_height:
        # Too wide: trim the left and right edges equally.
        new_width = max(1, height * target_width // target_height)
        left = (width - new_width) // 2
        image = image.crop((left, 0, left + new_width, height))
    elif scaled_width < scaled_height:
        # Too tall: trim the top and bottom edges equally.
        new_height = max(1, width * target_height // target_width)
        top = (height - new_height) // 2
        image = image.crop((0, top, width, top + new_height))

    return image.resize((target_width, target_height))
# Ensure FastSAM model weights are available (download if not present)
if not os.path.exists(FASTSAM_CHECKPOINT):
    # Imported lazily: only needed the first time, when the weights are missing.
    import requests

    # Ultralytics release asset for the FastSAM-x weights.
    fastsam_url = "https://github.com/ultralytics/assets/releases/download/v8.2.0/FastSAM-x.pt"
    resp = requests.get(fastsam_url, timeout=60)
    # Fail loudly on an HTTP error instead of saving the error page as the
    # checkpoint (which would then never be re-downloaded).
    resp.raise_for_status()

    # Write to a temporary name and rename afterwards, so an interrupted
    # download never leaves a truncated file that passes the existence check.
    partial_path = FASTSAM_CHECKPOINT + ".part"
    with open(partial_path, "wb") as checkpoint_file:
        checkpoint_file.write(resp.content)
    os.replace(partial_path, FASTSAM_CHECKPOINT)
# Load models with caching to avoid reloading on each interaction
@st.cache_resource
def load_models():
    """Load the FastSAM segmenter and the Stable Diffusion inpainting pipeline.

    Returns:
        A ``(fastsam_model, sd_pipe)`` tuple. The pipeline is moved to CUDA
        when ``torch.cuda.is_available()`` and stays on the CPU otherwise.
    """
    # FastSAM is left where Ultralytics puts it.
    # NOTE(review): presumably Ultralytics picks the device at call time - confirm.
    fastsam_model = FastSAM(FASTSAM_CHECKPOINT)

    use_cuda = torch.cuda.is_available()
    # Half precision must be requested explicitly; passing torch_dtype=None
    # does not switch to float16 on a GPU.
    sd_pipe = StableDiffusionInpaintPipeline.from_pretrained(
        SD_MODEL_ID,
        torch_dtype=torch.float16 if use_cuda else torch.float32,
    )
    sd_pipe = sd_pipe.to("cuda" if use_cuda else "cpu")

    # Trade a little speed for lower peak memory during attention.
    sd_pipe.enable_attention_slicing()
    return fastsam_model, sd_pipe
# Fetch both models; load_models is wrapped in st.cache_resource.
fastsam_model, sd_pipe = load_models()

# Make sure the dot-removal flag exists before anything reads it.
if "is_removing_dot" not in st.session_state:
    st.session_state["is_removing_dot"] = False

# Title
st.subheader("InteractiveInpainting")

# --- Image capture ---
# Until a picture is stored in the session, only the camera widgets are shown.
if "img" not in st.session_state:
    camera_enabled = st.checkbox("Enable camera")
    picture = st.camera_input("Take a picture", disabled=not camera_enabled)
    if picture is None:
        # Nothing captured yet, so there is nothing to segment or inpaint.
        st.stop()
    captured = crop_resize_image(Image.open(picture), target_width=480, target_height=640)
    st.session_state.img = captured
    # A new capture starts with an empty selection.
    st.session_state.coords_list = []
    st.rerun()

# From here on a working image is guaranteed to exist.
img = st.session_state.img

# Initialize the coordinates list if it doesn't exist.
if "coords_list" not in st.session_state:
    st.session_state.coords_list = []
# --- Compute Segmentation Overlay ---
# Start from "no selection" so both names are always bound: the inpainting
# step tests combined_mask_img against None, and the dot-drawing step copies
# seg_overlay.
combined_mask_img = None
seg_overlay = img.copy()

# If any points have been stored, run segmentation with FastSAM.
if st.session_state.coords_list:
    points = [[int(pt["x"]), int(pt["y"])] for pt in st.session_state.coords_list]
    labels = [1] * len(points)  # every stored point is a positive prompt
    results = fastsam_model(img, points=points, labels=labels)

    # The result carries no mask object when nothing was segmented.
    masks_obj = results[0].masks if results else None
    if masks_obj is not None:
        # Assumes masks_obj.data is a tensor with shape (N, H, W) - TODO confirm.
        masks = masks_obj.data.cpu().numpy()
        if masks.ndim == 3 and masks.shape[0] > 0:
            # Combine masks (logical OR via max)
            combined_mask = np.max(masks, axis=0)
            combined_mask_img = Image.fromarray((combined_mask * 255).astype(np.uint8))
            # Resize the mask to ensure it matches the base image size
            combined_mask_img = combined_mask_img.resize(img.size, Image.NEAREST)

            # Red layer whose alpha is 80 inside the mask and 0 elsewhere,
            # composited over the photo.
            overlay = Image.new("RGBA", img.size, (255, 0, 0, 100))
            mask_alpha = combined_mask_img.point(lambda p: 80 if p > 0 else 0)
            overlay.putalpha(mask_alpha)
            seg_overlay = Image.alpha_composite(img.convert("RGBA"), overlay)
# --- Draw Red Dots on Top ---
# Work on a copy so the overlay itself stays untouched.
final_img = seg_overlay.copy()
draw = ImageDraw.Draw(final_img)
dot_radius = 5
for stored_point in st.session_state.coords_list:
    center_x = int(stored_point["x"])
    center_y = int(stored_point["y"])
    bounding_box = (
        center_x - dot_radius,
        center_y - dot_radius,
        center_x + dot_radius,
        center_y + dot_radius,
    )
    draw.ellipse(bounding_box, fill="red")
# --- Interactive canvas and click handling ---
# Clicks are reported in displayed-image coordinates; scale them back to the
# stored image using the ratio between the image width and the page width.
original_width = st.session_state.img.width  # e.g. 480 from crop_resize_image
# page_width may be None (or 0) when the browser has not reported it; fall
# back to a 1:1 mapping instead of crashing on the division.
scale_factor = original_width / page_width if page_width else 1.0

# Use the interactive component as the display canvas, showing the image with all dots.
raw_click = streamlit_image_coordinates(final_img, key="click_img", use_column_width="always")

# The component returns its last value again on later reruns, so the same
# click must be handled only once. Remembering the last handled click
# replaces the `is_removing_dot` flag, which only suppressed the first rerun
# after a removal and let the removed dot come back on the next one.
if raw_click and raw_click != st.session_state.get("last_raw_click"):
    st.session_state["last_raw_click"] = raw_click

    # Remap from displayed coordinate to original coordinate
    new_coord = {
        "x": raw_click["x"] * scale_factor,
        "y": raw_click["y"] * scale_factor,
    }
    clicked = np.array([new_coord["x"], new_coord["y"]])

    # A click within 10 px of a stored point removes that point; any other
    # click adds a new point.
    nearby = next(
        (
            coord
            for coord in st.session_state.coords_list
            if np.linalg.norm(np.array([coord["x"], coord["y"]]) - clicked) < 10
        ),
        None,
    )
    if nearby is not None:
        st.session_state.coords_list.remove(nearby)
    else:
        st.session_state.coords_list.append(new_coord)
    # Rerun so the overlay and dots reflect the updated selection.
    st.rerun()

st.write("Stored coordinates:", st.session_state.coords_list)
# --- 4) INPAINTING LOGIC ---
prompt = st.text_input("Prompt for inpainting (describe what should replace the selected area):")

# Inpaint only when there is a prompt and a mask from the selected points.
# coords_list is tested first so the mask name is only evaluated when the
# segmentation step actually ran.
if prompt and st.session_state.coords_list and combined_mask_img is not None:
    if st.button("Run Inpainting"):
        with st.spinner("Inpainting..."):
            # Refine the mask only when it is about to be used.
            refined_mask = combined_mask_img.convert("L")
            # Dilate the mask: using a MaxFilter with a size (e.g. 5)
            refined_mask = refined_mask.filter(ImageFilter.MaxFilter(5))
            # Blur the mask edges: adjust radius as needed (e.g. radius=3)
            refined_mask = refined_mask.filter(ImageFilter.GaussianBlur(radius=3))

            # Pass the refined mask to the pipeline. Previously it was
            # computed and then dropped in favour of the raw mask.
            inpainted_img = sd_pipe(
                prompt=prompt,
                image=img,
                mask_image=refined_mask,
                width=img.width,
                height=img.height,
                guidance_scale=8,  # How strongly to follow the prompt
                num_inference_steps=50,
            ).images[0]

        # The result becomes the new working image, with an empty selection.
        st.session_state.img = inpainted_img
        st.session_state.coords_list = []
        st.rerun()