mjohanes commited on
Commit
b147284
·
1 Parent(s): d400066
Files changed (2) hide show
  1. app.py +91 -94
  2. requirements.txt +7 -21
app.py CHANGED
@@ -1,110 +1,107 @@
 
1
  import streamlit as st
 
2
  import numpy as np
3
- from PIL import Image
4
- import torch
 
 
 
5
  from diffusers import StableDiffusionInpaintPipeline
6
- from fastsam import FastSAM, FastSAMPrompt
7
- from huggingface_hub import hf_hub_download
8
 
9
- # Initialize session state
10
- if "points" not in st.session_state:
11
- st.session_state.points = []
 
 
 
 
 
 
12
 
13
- # Load models with caching
 
 
 
 
 
 
 
 
 
 
14
  @st.cache_resource
15
  def load_models():
16
- fastsam_path = hf_hub_download(
17
- repo_id="An-619/FastSAM",
18
- filename="FastSAM.pt",
19
- repo_type="model"
20
- )
21
- fastsam_model = FastSAM(fastsam_path)
22
 
23
- pipe = StableDiffusionInpaintPipeline.from_pretrained(
24
- "stabilityai/stable-diffusion-2-inpainting",
25
- torch_dtype=torch.float16,
 
26
  )
27
- if torch.cuda.is_available():
28
- pipe = pipe.to("cuda")
29
- return fastsam_model, pipe
30
-
31
- fastsam, pipe = load_models()
32
-
33
- st.title("Mobile Inpainting with Point Selection")
34
- st.write("1. Take photo 2. Select points 3. Enter prompt")
35
 
36
- # Camera input
37
- img_file = st.camera_input("Take a picture:")
38
 
39
- def add_point(x, y):
40
- st.session_state.points.append((x, y))
 
 
 
 
 
41
 
42
- def clear_points():
43
- st.session_state.points = []
44
 
45
- if img_file:
46
- img = Image.open(img_file).convert("RGB")
47
- w, h = img.size
 
48
 
49
- # Display image with click coordinates
50
- st.image(img, caption="Original Image")
51
- st.write(f"Image dimensions: {w}x{h} pixels")
52
-
53
- # Point input controls
54
- col1, col2 = st.columns(2)
55
- with col1:
56
- x = st.number_input("X coordinate", 0, w-1, w//2)
57
- with col2:
58
- y = st.number_input("Y coordinate", 0, h-1, h//2)
59
-
60
- st.button("Add Point", on_click=add_point, args=(x, y))
61
- st.button("Clear Points", on_click=clear_points)
 
 
62
 
63
- # Show selected points
64
- if st.session_state.points:
65
- st.write("Selected points (x,y):")
66
- st.write(st.session_state.points)
 
 
 
 
 
 
 
 
67
 
68
- # Generate mask
69
- if st.button("Generate Mask"):
70
- img_np = np.array(img)
71
-
72
- with st.spinner("Segmenting..."):
73
- # Normalize points
74
- norm_points = [[x/w, y/h] for x, y in st.session_state.points]
75
-
76
- # FastSAM processing
77
- results = fastsam(
78
- img_np,
79
- device="cuda" if torch.cuda.is_available() else "cpu",
80
- imgsz=1024,
81
- conf=0.4,
82
- )
83
- prompt_process = FastSAMPrompt(img_np, results, device="cpu")
84
- mask = prompt_process.point_prompt(
85
- points=norm_points,
86
- pointlabel=[1]*len(norm_points)
87
- )
88
- mask = mask[0].astype(np.uint8) * 255
89
-
90
- st.image(mask, caption="Generated Mask")
91
-
92
- # Inpainting
93
- prompt = st.text_input("What should replace the selected area?")
94
- if prompt:
95
- with st.spinner("Generating result..."):
96
- img_512 = img.resize((512, 512))
97
- mask_512 = Image.fromarray(mask).resize((512, 512))
98
-
99
- result = pipe(
100
- prompt=prompt,
101
- image=img_512,
102
- mask_image=mask_512,
103
- num_inference_steps=30,
104
- ).images[0]
105
-
106
- st.image(result, caption="Final Result")
107
- else:
108
- st.warning("Add points to create a mask")
109
- else:
110
- st.info("Take a photo to begin")
 
1
+ import os
2
  import streamlit as st
3
+ from PIL import Image, ImageDraw
4
  import numpy as np
5
+
6
+ # Import the custom component for image coordinates
7
+ from streamlit_image_coordinates import streamlit_image_coordinates
8
+
9
+ # Import diffusers pipeline for Stable Diffusion inpainting
10
  from diffusers import StableDiffusionInpaintPipeline
 
 
11
 
12
+ # Ultralytics provides the FastSAM model class
13
+ from ultralytics import FastSAM
14
+
15
# Set page config for a better mobile experience
st.set_page_config(page_title="Inpainting Demo", layout="wide")

# Define model paths or IDs for easy switching in the future
FASTSAM_CHECKPOINT = "FastSAM-x.pt"  # file name of the FastSAM model weights
SD_MODEL_ID = "runwayml/stable-diffusion-inpainting"  # HF Hub model for SD Inpainting v1.5

# Ensure FastSAM model weights are available (download if not present).
if not os.path.exists(FASTSAM_CHECKPOINT):
    # Download FastSAM weights from the official Ultralytics release (~68 MB).
    import requests

    fastsam_url = "https://github.com/ultralytics/assets/releases/download/v8.2.0/FastSAM-x.pt"
    st.write("Downloading FastSAM model weights...")
    # Stream the download so the whole file is never buffered in memory,
    # use a timeout so a dead connection cannot hang the app forever, and
    # check the HTTP status so an error page is never written to disk as
    # the checkpoint file. The context manager guarantees the file handle
    # is closed even if a chunk write fails.
    resp = requests.get(fastsam_url, stream=True, timeout=60)
    resp.raise_for_status()
    with open(FASTSAM_CHECKPOINT, "wb") as f:
        for chunk in resp.iter_content(chunk_size=1 << 20):  # 1 MiB chunks
            f.write(chunk)
31
+
32
# Load models with caching to avoid reloading on each interaction
@st.cache_resource
def load_models():
    """Load FastSAM and the Stable Diffusion inpainting pipeline exactly once.

    Returns:
        tuple: (fastsam_model, sd_pipe) where `fastsam_model` is the
        Ultralytics FastSAM segmentation model and `sd_pipe` is a
        StableDiffusionInpaintPipeline already moved to the best device.
    """
    # BUGFIX: `torch` was never imported anywhere in this file, so the
    # torch.cuda.is_available() call below raised NameError at startup.
    # A function-local import keeps the fix self-contained; the cached
    # resource means this runs only once per process.
    import torch

    # Load FastSAM model (Ultralytics handles per-call device assignment).
    fastsam_model = FastSAM(FASTSAM_CHECKPOINT)

    use_cuda = torch.cuda.is_available()
    # Choose the dtype explicitly: diffusers does NOT automatically pick
    # float16 on GPU — `torch_dtype=None` always loads float32 weights.
    # float16 halves VRAM usage on CUDA; CPU inference needs float32.
    sd_pipe = StableDiffusionInpaintPipeline.from_pretrained(
        SD_MODEL_ID,
        torch_dtype=torch.float16 if use_cuda else torch.float32,
    )
    # Move the pipeline to GPU for faster inference when one is available.
    sd_pipe = sd_pipe.to("cuda" if use_cuda else "cpu")
    # Memory optimization: compute attention in slices to lower peak usage.
    sd_pipe.enable_attention_slicing()
    return fastsam_model, sd_pipe
 
 
 
50
 
51
# Initialize the models — cached via st.cache_resource, so this line only
# does real work on the very first run of the process.
fastsam_model, sd_pipe = load_models()

# Page header and step-by-step usage instructions.
st.title("📱 Mobile Inpainting with FastSAM and Stable Diffusion")
_INSTRUCTIONS = (
    "1. **Capture** an image using the camera.\n"
    "2. **Tap** on an object in the image to select it.\n"
    "3. **Describe** what should replace it, and let the app do the rest!"
)
st.markdown(_INSTRUCTIONS)

# Camera input widget (opens device camera on mobile/desktop)
picture = st.camera_input("Take a picture")
65
# Main interaction flow: runs only once the user has captured a photo.
if picture is not None:
    # When an image is captured, display it and allow point selection.
    img = Image.open(picture)  # read image as PIL
    st.image(img, caption="Captured Image", use_column_width=True)

    # Let user click a point on the image. Returns a dict with 'x' and 'y'.
    # NOTE(review): the component reports coordinates in the pixel space of
    # the image it renders — confirm this matches the captured image's
    # natural resolution used for segmentation below.
    coords = streamlit_image_coordinates(img, key="click_img")
    if coords:
        # If a point was clicked, mark it on the image for user feedback.
        cx, cy = int(coords['x']), int(coords['y'])
        # Draw a small red circle on an image copy to show the selected point.
        img_with_dot = img.copy()
        draw = ImageDraw.Draw(img_with_dot)
        draw.ellipse((cx - 5, cy - 5, cx + 5, cy + 5), fill='red')
        st.image(img_with_dot, caption=f"Selected Point: ({cx}, {cy})", use_column_width=True)
    else:
        cx = cy = None

    # Prompt input for inpainting
    prompt = st.text_input("Prompt for inpainting (describe what should replace the selected area):")

    # Only proceed when a point is selected and prompt is provided
    if coords and prompt:
        cx, cy = int(coords['x']), int(coords['y'])
        st.write("Generating mask with FastSAM...")
        # Run FastSAM segmentation with the selected point as a positive
        # point prompt (points=[[x, y]], labels=[1]).
        results = fastsam_model(img, points=[[cx, cy]], labels=[1])
        # Guard: FastSAM may find no segment at the clicked point, in which
        # case `results[0].masks` is None and the original code crashed with
        # AttributeError before reaching the pipeline.
        if not results or results[0].masks is None or len(results[0].masks.data) == 0:
            st.error("No object was found at the selected point — try tapping a different spot.")
        else:
            # Extract the first mask (the object closest to the point prompt).
            mask_data = results[0].masks.data[0]  # mask tensor (H x W)
            mask_array = mask_data.cpu().numpy()  # convert to numpy array
            # Build a PIL mask image: 1.0 -> 255 (inpaint), 0.0 -> 0 (keep).
            mask_image = Image.fromarray((mask_array * 255).astype(np.uint8))

            # For debugging, the mask can be displayed — uncomment if needed
            # st.image(mask_image, caption="Segmentation Mask", use_column_width=True)

            st.write("Running Stable Diffusion Inpainting...")
            # Run the Stable Diffusion inpainting pipeline
            result = sd_pipe(prompt=prompt, image=img, mask_image=mask_image).images[0]

            # Display the final inpainted image
            st.image(result, caption="Inpainted Image", use_column_width=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,21 +1,7 @@
1
- # Core dependencies
2
- streamlit>=1.28.0
3
- torch
4
- torchvision
5
- ultralytics==8.0.121 # Critical for FastSAM
6
-
7
- # Hugging Face ecosystem
8
- diffusers>=0.19.0
9
- transformers>=4.34.0
10
- huggingface-hub>=0.17.0
11
-
12
- # Image processing
13
- opencv-python>=4.7.0.72
14
- matplotlib>=3.7.2
15
- pillow>=9.5.0
16
-
17
- # Additional
18
- accelerate>=0.24.0
19
-
20
- # FastSAM from specific commit
21
- git+https://github.com/CASIA-IVA-Lab/FastSAM.git
 
1
+ streamlit>=1.28.0  # '1.x' is not a valid pip version specifier and breaks installation
2
+ streamlit-image-coordinates==0.2.0 # component for getting click coordinates on images
3
+ ultralytics==8.0.134 # includes FastSAM integration
4
+ diffusers==0.17.0 # for Stable Diffusion pipeline
5
+ transformers==4.30.2 # for Stable Diffusion text encoder
6
+ accelerate==0.20.3 # helps with model acceleration
7
+ torch # PyTorch (will auto-select a CUDA version on GPU)