# ReDesign / app.py
# mjohanes's picture
# push app
# b5ebf86
import os
import numpy as np
import streamlit as st
from PIL import Image, ImageDraw, ImageFilter
import numpy as np
import torch
from streamlit_js_eval import streamlit_js_eval
# Import the custom component for image coordinates
from streamlit_image_coordinates import streamlit_image_coordinates
# Import diffusers pipeline for Stable Diffusion inpainting
from diffusers import StableDiffusionInpaintPipeline
# Ultralytics provides the FastSAM model class
from ultralytics import FastSAM
# Model locations live in constants so a different checkpoint can be swapped in later.
FASTSAM_CHECKPOINT = "FastSAM-x.pt"  # local file name of the FastSAM weights
SD_MODEL_ID = "runwayml/stable-diffusion-inpainting"  # HF Hub ID (SD Inpainting v1.5)

# Centered layout reads better on phones.
st.set_page_config(layout="centered", page_title="Inpainting Demo")

# Ask the browser for its viewport width; the click-handling code uses it to
# map positions on the displayed image back to image pixels.
page_width = streamlit_js_eval(
    key="WIDTH",
    js_expressions="window.innerWidth",
    want_output=True,
)
# Helper function: center-crop to the target aspect ratio, then resize
# (480x640 portrait by default).
def crop_resize_image(image, target_width=480, target_height=640):
    """Center-crop *image* to the target aspect ratio and resize it.

    Args:
        image: Image object exposing ``size`` (width, height), ``crop(box)``
            and ``resize(size)``, e.g. a ``PIL.Image.Image``.
        target_width: Width of the returned image in pixels.
        target_height: Height of the returned image in pixels.

    Returns:
        The cropped image resized to ``(target_width, target_height)``.
    """
    desired_ratio = target_width / target_height  # 480 / 640 = 0.75 by default
    width, height = image.size
    current_ratio = width / height

    if current_ratio > desired_ratio:
        # Too wide: trim the left and right edges equally.
        # Clamp to 1 px so an extremely flat source never yields an empty crop.
        new_width = max(1, int(height * desired_ratio))
        left = (width - new_width) // 2
        image = image.crop((left, 0, left + new_width, height))
    elif current_ratio < desired_ratio:
        # Too tall: trim the top and bottom edges equally.
        new_height = max(1, int(width / desired_ratio))
        top = (height - new_height) // 2
        image = image.crop((0, top, width, top + new_height))

    return image.resize((target_width, target_height))
# Ensure the FastSAM weights are available locally (download them on first run).
if not os.path.exists(FASTSAM_CHECKPOINT):
    import requests

    # Ultralytics release asset for FastSAM-x.
    fastsam_url = "https://github.com/ultralytics/assets/releases/download/v8.2.0/FastSAM-x.pt"

    # Stream into a temporary name and rename only after a complete, successful
    # download. Otherwise an HTTP error page or a truncated transfer would be
    # saved as the checkpoint and pass the os.path.exists() check on every
    # later run.
    partial_path = FASTSAM_CHECKPOINT + ".part"
    with requests.get(fastsam_url, stream=True, timeout=60) as resp:
        resp.raise_for_status()
        with open(partial_path, "wb") as weights_file:
            for chunk in resp.iter_content(chunk_size=1 << 20):
                weights_file.write(chunk)
    os.replace(partial_path, FASTSAM_CHECKPOINT)
# Load models with caching to avoid reloading on each interaction
@st.cache_resource
def load_models():
    """Load the FastSAM segmenter and the Stable Diffusion inpainting pipeline.

    Returns:
        tuple: ``(fastsam_model, sd_pipe)``. The pipeline has been moved to
        CUDA when ``torch.cuda.is_available()`` is true, otherwise to the CPU,
        and has attention slicing enabled.
    """
    # FastSAM is not moved explicitly; per the original author's note,
    # Ultralytics handles device placement when the model is called.
    fastsam_model = FastSAM(FASTSAM_CHECKPOINT)

    use_cuda = torch.cuda.is_available()
    # Passing torch_dtype=None does not select half precision on its own, so
    # request float16 explicitly on GPU and keep float32 on CPU.
    sd_pipe = StableDiffusionInpaintPipeline.from_pretrained(
        SD_MODEL_ID,
        torch_dtype=torch.float16 if use_cuda else torch.float32,
    )
    sd_pipe = sd_pipe.to("cuda" if use_cuda else "cpu")
    # Optional memory optimisation.
    sd_pipe.enable_attention_slicing()
    return fastsam_model, sd_pipe
# Initialize the models (this runs only once thanks to st.cache_resource).
fastsam_model, sd_pipe = load_models()

# Session flag read by the click-handling code further down to recognise the
# rerun that immediately follows a dot removal.
if "is_removing_dot" not in st.session_state:
    st.session_state.is_removing_dot = False

# Title
st.subheader("InteractiveInpainting")

# Make sure the list of clicked points exists before anything reads it.
if "coords_list" not in st.session_state:
    st.session_state.coords_list = []

# Capture an image from the camera unless one is already stored in the session.
if "img" not in st.session_state:
    enable = st.checkbox("Enable camera")
    picture = st.camera_input("Take a picture", disabled=not enable)
    if picture is not None:
        img = Image.open(picture)
        img = crop_resize_image(img, target_width=480, target_height=640)
        st.session_state.img = img
        # Reset coordinates list on new capture
        st.session_state.coords_list = []
        st.rerun()
    # No photo yet: end this run here so the code below, which needs `img`,
    # never executes with the name unbound.
    st.stop()

img = st.session_state.img
# --- Compute Segmentation Overlay ---
# Both names are bound up front so later code can rely on them: the overlay
# defaults to the plain image, and `combined_mask_img` stays None when no point
# has been clicked or no mask was produced.
combined_mask_img = None
seg_overlay = img.copy()

# If any points have been stored, run segmentation with FastSAM.
if st.session_state.coords_list:
    points = [[int(pt["x"]), int(pt["y"])] for pt in st.session_state.coords_list]
    labels = [1] * len(points)  # every stored click is passed with label 1
    results = fastsam_model(img, points=points, labels=labels)

    # `masks` may be None when nothing was segmented (TODO confirm against the
    # installed Ultralytics version); treat that the same as "no masks".
    masks_obj = results[0].masks if results else None
    # Assumes masks_obj.data is a tensor shaped (N, H, W); checked via ndim below.
    masks = masks_obj.data.cpu().numpy() if masks_obj is not None else None

    if masks is not None and masks.ndim == 3 and masks.shape[0] > 0:
        # Combine masks (logical OR via max).
        combined_mask = np.max(masks, axis=0)
        combined_mask_img = Image.fromarray((combined_mask * 255).astype(np.uint8))
        # Resize the mask so it matches the base image size.
        combined_mask_img = combined_mask_img.resize(img.size, Image.NEAREST)

        # Red overlay whose alpha is 80 inside the mask and 0 elsewhere.
        overlay = Image.new("RGBA", img.size, (255, 0, 0, 100))
        base = img.convert("RGBA")
        mask_alpha = combined_mask_img.point(lambda p: 80 if p > 0 else 0)
        overlay.putalpha(mask_alpha)
        seg_overlay = Image.alpha_composite(base, overlay)
# --- Mark Every Stored Click With a Red Dot ---
final_img = seg_overlay.copy()
draw = ImageDraw.Draw(final_img)
dot_radius = 5
for click in st.session_state.coords_list:
    center_x = int(click["x"])
    center_y = int(click["y"])
    bounding_box = (
        center_x - dot_radius,
        center_y - dot_radius,
        center_x + dot_radius,
        center_y + dot_radius,
    )
    draw.ellipse(bounding_box, fill="red")
# --- Click Handling ---
# Use the interactive component as the display canvas, showing the image with
# all dots. It reports the position of the most recent click.
raw_click = streamlit_image_coordinates(final_img, key="click_img", use_column_width="always")

# The component keeps reporting the same click on every rerun, so act only on
# a payload that differs from the last one handled. Tracking the payload
# (rather than a one-rerun flag) keeps a removed dot, or a click made before
# inpainting, from being re-added on a later unrelated rerun.
if raw_click and raw_click != st.session_state.get("last_click"):
    st.session_state.last_click = dict(raw_click)

    # Map the click from displayed pixels back to image pixels. Prefer the
    # rendered width reported by the component when present (TODO confirm the
    # installed version sends "width"); fall back to the browser width, and
    # finally to the image width (scale 1) so a missing page_width cannot raise.
    image_width = st.session_state.img.width  # e.g. 480 from crop_resize_image
    displayed_width = raw_click.get("width") or page_width or image_width
    scale_factor = image_width / displayed_width
    new_coord = {
        "x": raw_click["x"] * scale_factor,
        "y": raw_click["y"] * scale_factor,
    }

    # Toggle: a click within 10 px of a stored point removes that point,
    # any other click adds a new point.
    coords = st.session_state.coords_list
    target = np.array([new_coord["x"], new_coord["y"]])
    nearby = next(
        (c for c in coords if np.linalg.norm(np.array([c["x"], c["y"]]) - target) < 10),
        None,
    )
    if nearby is not None:
        coords.remove(nearby)
    else:
        coords.append(new_coord)
    # Redraw immediately so the overlay and dots reflect the change.
    st.rerun()

st.write("Stored coordinates:", st.session_state.coords_list)
# --- 4) INPAINTING LOGIC ---
prompt = st.text_input("Prompt for inpainting (describe what should replace the selected area):")

# Offer inpainting only when there is a prompt and a mask from the clicked points.
if prompt and combined_mask_img is not None:
    if st.button("Run Inpainting"):
        # Build the mask only when it is needed, not on every rerun.
        # Dilate it (MaxFilter, size 5) so it covers object edges, then blur
        # the border (radius 3). This softened mask is what gets passed to the
        # pipeline; previously it was computed but the raw mask was sent.
        inpaint_mask = (
            combined_mask_img.convert("L")
            .filter(ImageFilter.MaxFilter(5))
            .filter(ImageFilter.GaussianBlur(radius=3))
        )
        with st.spinner("Inpainting..."):
            inpainted_img = sd_pipe(
                prompt=prompt,
                image=img,
                mask_image=inpaint_mask,
                width=img.width,
                height=img.height,
                guidance_scale=8,  # how strongly to follow the prompt
                num_inference_steps=50,
            ).images[0]
        # The inpainted result becomes the new working image.
        st.session_state.img = inpainted_img
        # Stored points referred to the previous image content; clear them.
        st.session_state.coords_list = []
        st.rerun()