ObjectRemoval / app.py
AnubhutiBhardwaj's picture
feat: object removal app
1898e14
Raw
History Blame Contribute Delete
5.97 kB
import io
import json
import os
import time

import cv2
import numpy as np
import requests
import streamlit as st
from PIL import Image
from ultralytics import YOLO
# PicWish API key.
# Prefer supplying the key through the PICWISH_API_KEY environment variable so
# the credential does not have to live in source control. The literal below is
# kept only as a backward-compatible fallback.
# TODO: rotate this key and remove the hard-coded fallback.
API_KEY = os.environ.get("PICWISH_API_KEY", "wxqtrh2v6z4fv6lsl")
# Build the YOLO detector; the decorator lets Streamlit reuse the instance.
@st.cache_resource
def load_model():
    """Return a YOLO model built from the ``../yolov9e.pt`` weights file.

    The function is wrapped in ``st.cache_resource`` so the model object is
    created once and shared across script reruns instead of being rebuilt on
    every interaction.
    """
    weights_path = "../yolov9e.pt"
    return YOLO(weights_path)
def create_task(image_bytes, x1, y1, x2, y2):
    """Submit an asynchronous inpainting task for one rectangular region.

    Args:
        image_bytes: JPEG-encoded image content to upload.
        x1, y1: Top-left corner of the region to erase, in pixels.
        x2, y2: Bottom-right corner of the region to erase, in pixels.

    Returns:
        The task id reported by the service, or ``None`` when the request
        fails or the response does not describe a successfully created task.
        Transport and decoding errors are reported through ``st.error``.
    """
    headers = {"X-API-KEY": API_KEY}
    data = {
        # "0" requests asynchronous processing; the result is polled later.
        "sync": "0",
        # The service expects the regions as a JSON-encoded list of x/y/width/height.
        "rectangles": json.dumps(
            [{"x": x1, "y": y1, "width": x2 - x1, "height": y2 - y1}]
        ),
    }
    files = {"image_file": ("image.jpg", image_bytes, "image/jpeg")}
    url = "https://techhk.aoscdn.com/api/tasks/visual/inpaint"

    try:
        # Without an explicit timeout, requests can block forever on a
        # stalled connection and freeze the Streamlit script run.
        response = requests.post(
            url, headers=headers, data=data, files=files, timeout=30
        )
        response_json = response.json()
    except (requests.RequestException, ValueError) as e:
        # ValueError covers a body that is not valid JSON.
        st.error(f"Error creating task: {e}")
        return None

    if not isinstance(response_json, dict) or response_json.get("status") != 200:
        return None
    payload = response_json.get("data")
    if not isinstance(payload, dict):
        return None
    return payload.get("task_id")
def polling_task_result(task_id, timeout=30):
    """Poll the inpainting service until the task finishes or polling gives up.

    Args:
        task_id: Identifier returned by ``create_task``.
        timeout: Maximum number of polling attempts; one attempt is made
            after each one-second sleep.

    Returns:
        The value of the ``image`` field of the finished task (passed by the
        caller to ``st.image``), or ``None`` if the task failed, a request
        error occurred, or the attempts were exhausted.
    """
    headers = {"X-API-KEY": API_KEY}
    url = f"https://techhk.aoscdn.com/api/tasks/visual/inpaint/{task_id}"

    for _ in range(timeout):
        time.sleep(1)
        try:
            # Bound each request so a stalled connection cannot hang the app.
            response = requests.get(url, headers=headers, timeout=10)
            response_json = response.json()
        except (requests.RequestException, ValueError) as e:
            # ValueError covers a body that is not valid JSON.
            st.error(f"Error polling result: {e}")
            break

        if not isinstance(response_json, dict) or response_json.get("status") != 200:
            continue
        payload = response_json.get("data")
        if not isinstance(payload, dict):
            continue

        state = payload.get("state")
        if state == 1:
            # State 1 marks a completed task.
            return payload.get("image")
        if isinstance(state, int) and state < 0:
            # NOTE(review): negative states are treated as terminal failures
            # per the PicWish API documentation, so stop instead of polling
            # until the attempts run out — confirm against the current docs.
            break
    return None
def detect_objects(image):
    """Run YOLO detection on a PIL image.

    Args:
        image: A ``PIL.Image.Image`` in any mode; it is normalised to RGB
            first so palette, greyscale and alpha-channel uploads work.

    Returns:
        A tuple ``(annotated, detected_objects)`` where ``annotated`` is the
        image with detections drawn on it as an RGB NumPy array, and
        ``detected_objects`` is a list of ``(index, (x1, y1, x2, y2))``
        tuples with integer pixel coordinates.
    """
    model = load_model()

    # Ultralytics documents NumPy inputs as 3-channel BGR, whereas PIL yields
    # RGB (or RGBA/P/L). Normalise to RGB, then flip the channel order.
    rgb = np.array(image.convert("RGB"))
    bgr = np.ascontiguousarray(rgb[:, :, ::-1])

    results = model(bgr)
    detected_objects = [
        (i, tuple(map(int, box.xyxy[0])))
        for i, box in enumerate(results[0].boxes)
    ]

    # plot() draws on the BGR input; flip back so the result displays
    # correctly with RGB-based viewers such as st.image.
    annotated = np.ascontiguousarray(results[0].plot()[:, :, ::-1])
    return annotated, detected_objects
def _closest_object(detected_objects, x, y):
    """Return the detection whose box centre is nearest to ``(x, y)``, or None if there are none."""
    if not detected_objects:
        return None

    def squared_distance(obj):
        _, (x1, y1, x2, y2) = obj
        center_x, center_y = (x1 + x2) // 2, (y1 + y2) // 2
        # Squared distance orders candidates the same as Euclidean distance.
        return (center_x - x) ** 2 + (center_y - y) ** 2

    # min() keeps the first of equally distant candidates.
    return min(detected_objects, key=squared_distance)


def _encode_jpeg(image):
    """Return ``image`` as JPEG bytes, converting to RGB first when needed."""
    # JPEG cannot store alpha or palette modes; PNG uploads often use them.
    if image.mode != "RGB":
        image = image.convert("RGB")
    buffer = io.BytesIO()
    image.save(buffer, format="JPEG")
    return buffer.getvalue()


def main():
    """Streamlit entry point: upload an image, detect objects, erase one.

    The user uploads an image and enters approximate coordinates; the detected
    object whose bounding-box centre is closest to that point is sent to the
    inpainting service, and the edited image is displayed.
    """
    st.title("Object Removal with YOLO")

    # Initialize session state
    for key in ("detected_objects", "original_image", "processed_image"):
        if key not in st.session_state:
            st.session_state[key] = None

    uploaded_file = st.file_uploader("Upload an image", type=["jpg", "png", "jpeg"])
    if uploaded_file is None:
        return

    image = Image.open(uploaded_file)

    # Re-run detection only when a different file has been uploaded.
    if (
        st.session_state.original_image is None
        or uploaded_file != st.session_state.original_image
    ):
        st.session_state.original_image = uploaded_file
        processed_image, detected_objects = detect_objects(image)
        st.session_state.processed_image = processed_image
        st.session_state.detected_objects = detected_objects

    # Show the untouched upload; the annotated detections stay in session state.
    st.image(
        st.session_state.original_image,
        caption="Original Image",
        use_container_width=True,
    )

    # Ask user for object location
    st.write("Enter the approximate location of the object to remove:")
    x_input = st.number_input("X Coordinate", min_value=0, value=0)
    y_input = st.number_input("Y Coordinate", min_value=0, value=0)

    if not st.button("Remove Object"):
        return

    closest_object = _closest_object(
        st.session_state.detected_objects, x_input, y_input
    )
    if closest_object is None:
        # Also covers the case where detection found nothing at all.
        st.error("No objects found near the entered coordinates.")
        return

    _, (x1, y1, x2, y2) = closest_object

    task_id = create_task(_encode_jpeg(image), x1, y1, x2, y2)
    if not task_id:
        st.error("Failed to create removal task")
        return

    with st.spinner("Removing object..."):
        result_url = polling_task_result(task_id)

    if result_url:
        st.image(result_url, caption="Edited Image")
    else:
        st.error("Failed to get edited image")
# Run the app only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()