File size: 8,815 Bytes
b147284
126d0b4
0e9299e
f21aa39
0e9299e
8d71ca6
c99d0ab
 
 
b147284
 
 
 
 
011ea0a
0e9299e
8d71ca6
b147284
 
 
 
c3c3dd4
b147284
648c268
c99d0ab
a6e60bd
c99d0ab
b147284
 
 
172a17e
c676f9b
f21aa39
c676f9b
aa64939
c676f9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aa64939
b147284
 
 
 
 
 
46db2a8
b147284
 
 
 
0e9299e
011ea0a
b147284
 
 
 
011ea0a
b147284
 
 
 
011ea0a
b147284
 
 
 
 
0e9299e
b147284
c7b01f0
0e9299e
92249de
 
 
 
876a83e
e7e965c
876a83e
172a17e
11f70e4
b147284
126d0b4
 
172a17e
8029369
 
 
08a6710
126d0b4
2f2581c
9ccd0cc
 
2f2581c
 
f21aa39
2f2581c
c676f9b
 
2f2581c
 
 
126d0b4
 
bbbe3f7
 
 
 
c7b01f0
 
 
 
 
 
 
 
 
 
3f2393a
c7b01f0
 
3f2393a
 
c7b01f0
 
 
b4ddf81
 
 
52e1198
c7b01f0
 
 
 
 
 
 
 
 
 
bbbe3f7
 
c7b01f0
c8553c3
 
c7b01f0
c8553c3
 
bbbe3f7
2575860
bbbe3f7
c8553c3
b5ebf86
 
 
 
 
c8553c3
bbbe3f7
92249de
c9a98ff
4f92585
c4c16e7
 
1f6c057
c9a98ff
08e68c1
c9a98ff
1f6c057
92249de
08e68c1
 
1f6c057
071911e
 
233f57a
c7b01f0
bbbe3f7
 
172a17e
56cc275
 
 
 
 
b16906f
1b4308f
 
 
 
 
 
 
56cc275
 
 
 
 
 
2d7d767
 
b46ffe3
 
b16906f
56cc275
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
import os
import numpy as np
import streamlit as st
from PIL import Image, ImageDraw, ImageFilter
import numpy as np
import torch
from streamlit_js_eval import streamlit_js_eval



# Import the custom component for image coordinates
from streamlit_image_coordinates import streamlit_image_coordinates

# Import diffusers pipeline for Stable Diffusion inpainting
from diffusers import StableDiffusionInpaintPipeline


# Ultralytics provides the FastSAM model class
from ultralytics import FastSAM

# Set page config for a better mobile experience
st.set_page_config(page_title="Inpainting Demo", layout="centered")


# Ask the browser for its viewport width so clicks on the responsively scaled
# image can later be mapped back to original-image pixel coordinates.
# NOTE(review): streamlit_js_eval resolves asynchronously, so page_width can be
# None on the very first script run — downstream users should guard for that.
page_width = streamlit_js_eval(js_expressions='window.innerWidth', key='WIDTH',  want_output = True,)


# Define model paths or IDs for easy switching in the future
FASTSAM_CHECKPOINT = "FastSAM-x.pt"  # file name of the FastSAM model weights
SD_MODEL_ID = "runwayml/stable-diffusion-inpainting"  # HF Hub model for SD Inpainting v1.5

# Helper function: center-crop to the target aspect ratio, then resize.
def crop_resize_image(image, target_width=480, target_height=640):
    """Center-crop *image* to the target aspect ratio, then resize it.

    Excess width or height is trimmed symmetrically so the crop stays
    centered; the result is always exactly (target_width, target_height)
    pixels. Defaults produce a 480x640 portrait frame.
    """
    desired_ratio = target_width / target_height
    width, height = image.size
    current_ratio = width / height

    # Start from the full frame and shrink the crop box along one axis only.
    left, top, right, bottom = 0, 0, width, height
    if current_ratio > desired_ratio:
        # Too wide: trim equal amounts from the left and right edges.
        crop_width = int(height * desired_ratio)
        left = (width - crop_width) // 2
        right = left + crop_width
    elif current_ratio < desired_ratio:
        # Too tall: trim equal amounts from the top and bottom edges.
        crop_height = int(width / desired_ratio)
        top = (height - crop_height) // 2
        bottom = top + crop_height

    return image.crop((left, top, right, bottom)).resize((target_width, target_height))

# Ensure FastSAM model weights are available (download if not present)
if not os.path.exists(FASTSAM_CHECKPOINT):
    # Download FastSAM weights (if not already in the repo)
    # Here we use the official Ultralytics release URL for FastSAM-x (68MB).
    import requests
    fastsam_url = "https://github.com/ultralytics/assets/releases/download/v8.2.0/FastSAM-x.pt"
    # st.write("Downloading FastSAM model weights...")
    resp = requests.get(fastsam_url, timeout=300)
    # Fail loudly on HTTP errors instead of silently saving an HTML error
    # page as the model checkpoint (previously any response body was written).
    resp.raise_for_status()
    with open(FASTSAM_CHECKPOINT, "wb") as f:
        f.write(resp.content)

# Load models with caching to avoid reloading on each interaction
@st.cache_resource
def load_models():
    """Load and cache the FastSAM segmenter and the SD inpainting pipeline.

    Returns:
        (fastsam_model, sd_pipe): the FastSAM model and the
        StableDiffusionInpaintPipeline, with the pipeline placed on GPU
        when one is available.
    """
    # Load FastSAM model from the local checkpoint.
    # (Ultralytics handles device assignment internally when the model is called.)
    fastsam_model = FastSAM(FASTSAM_CHECKPOINT)

    use_cuda = torch.cuda.is_available()
    # Pick the dtype explicitly: diffusers does NOT auto-select fp16 when
    # torch_dtype is None (it defaults to full precision). fp16 halves VRAM
    # on GPU; keep fp32 on CPU where half precision is poorly supported.
    sd_pipe = StableDiffusionInpaintPipeline.from_pretrained(
        SD_MODEL_ID,
        torch_dtype=torch.float16 if use_cuda else torch.float32,
    )
    # Move pipeline to GPU for faster inference, if a GPU is available
    sd_pipe = sd_pipe.to("cuda" if use_cuda else "cpu")
    # (Optional) Enable memory optimizations
    sd_pipe.enable_attention_slicing()  # improve memory usage
    return fastsam_model, sd_pipe

# Initialize the models (this will run only once thanks to caching)
fastsam_model, sd_pipe = load_models()

# Ensure we have a state for removing_dots.
# is_removing_dot is set True for one rerun right after a dot is deleted, so
# the same click event is not immediately re-added as a new point.
if "is_removing_dot" not in st.session_state:
    st.session_state.is_removing_dot = False

# Title
st.subheader("InteractiveInpainting")



# Camera input widget (opens device camera on mobile/desktop)
# picture = st.camera_input("Take a picture")


# picture = Image.new(mode="RGB", size=(512, 512), color=(153, 153, 255))
# Capture image from camera and process it



if "img" not in st.session_state:
    enable = st.checkbox("Enable camera")
    picture = st.camera_input("Take a picture", disabled=not enable)
    if picture is not None:
        img = Image.open(picture)
        img = crop_resize_image(img, target_width=480, target_height=640)
        st.session_state.img = img
        # Reset coordinates list on new capture
        st.session_state.coords_list = []
        st.rerun()

else:
    img = st.session_state.img

    # Initialize the coordinates list if it doesn't exist.
    if "coords_list" not in st.session_state:
        st.session_state.coords_list = []

    # --- Compute Segmentation Overlay ---
    # If any points have been stored, run segmentation with FastSAM.
    if st.session_state.coords_list:
        points = [[int(pt["x"]), int(pt["y"])] for pt in st.session_state.coords_list]
        labels = [1] * len(points)
        results = fastsam_model(img, points=points, labels=labels)
        # Assume results[0].masks.data is a tensor with shape (N, H, W)
        masks_tensor = results[0].masks.data
        masks = masks_tensor.cpu().numpy()
        if masks.ndim == 3 and masks.shape[0] > 0:
             # Combine masks (logical OR via max)
            combined_mask = np.max(masks, axis=0)
            combined_mask_img = Image.fromarray((combined_mask * 255).astype(np.uint8))
            # Resize the mask to ensure it matches the base image size
            combined_mask_img = combined_mask_img.resize(img.size, Image.NEAREST)
            # Create a red overlay with transparency
            overlay = Image.new("RGBA", img.size, (255, 0, 0, 100))
            base = img.convert("RGBA")
            mask_alpha = combined_mask_img.point(lambda p: 80 if p > 0 else 0)
            overlay.putalpha(mask_alpha)

            seg_overlay = Image.alpha_composite(base, overlay)
        else:
            seg_overlay = img.copy()
    else:
        seg_overlay = img.copy()

    # --- Draw Red Dots on Top ---
    final_img = seg_overlay.copy()
    draw = ImageDraw.Draw(final_img)
    for pt in st.session_state.coords_list:
        cx, cy = int(pt["x"]), int(pt["y"])
        draw.ellipse((cx - 5, cy - 5, cx + 5, cy + 5), fill="red")

  
    # Get the original width from the image stored in session_state.
    original_width = st.session_state.img.width  # e.g. 480 from crop_resize_image

    # Compute the scaling factor.
    scale_factor = original_width / page_width
    # Use the interactive component as the display canvas, showing the image with all dots.
    new_coord = streamlit_image_coordinates(final_img, key="click_img", use_column_width="always")

    # Remap from displayed coordinate to original coordinate
    if new_coord:
        new_coord = {
            "x": new_coord["x"] * scale_factor,
            "y": new_coord["y"] * scale_factor
        }

    # If a new coordinate is received and it's not already in our list, add it and force a rerun.
    if new_coord and new_coord not in st.session_state.coords_list and not st.session_state.is_removing_dot:
        is_close = False
        for coord in st.session_state.coords_list:
            existing = np.array([coord["x"], coord["y"]])
            new = np.array([new_coord["x"], new_coord["y"]])
            if  np.linalg.norm(existing - new) < 10:
                is_close = True
                break
        if is_close:
            st.session_state.coords_list.remove(coord)
            st.session_state.is_removing_dot = True
        else:
            st.session_state.coords_list.append(new_coord)
        st.rerun()
    else:
        st.session_state.is_removing_dot = False

    st.write("Stored coordinates:", st.session_state.coords_list)


    
   # --- 4) INPAINTING LOGIC ---
    prompt = st.text_input("Prompt for inpainting (describe what should replace the selected area):")

    # If there's a prompt and we have at least one mask from the combined points, do inpainting
    if prompt and combined_mask_img is not None:

        combined_mask_img = combined_mask_img.convert("L")

        # Dilate the mask: using a MaxFilter with a size (e.g. 5)
        dilated_mask = combined_mask_img.filter(ImageFilter.MaxFilter(5))

        # Blur the mask edges: adjust radius as needed (e.g. radius=3)
        blurred_mask = dilated_mask.filter(ImageFilter.GaussianBlur(radius=3))
        if st.button("Run Inpainting"):
            with st.spinner("Inpainting..."):
                # Run Stable Diffusion Inpainting on the entire combined mask
                inpainted_img = sd_pipe(
                    prompt=prompt,
                    image=img,
                    mask_image=combined_mask_img,
                    width=img.width,          
                    height=img.height,
                    guidance_scale=8,  # How strongly to follow the prompt
                    num_inference_steps=50
                ).images[0]

                # Update the session image to the newly inpainted result
                st.session_state.img = inpainted_img
                # Optionally reset the points or keep them
                st.session_state.coords_list = []
                st.rerun()