Spaces:
Sleeping
Sleeping
import io
import json
import os
import time

import cv2
import numpy as np
import requests
import streamlit as st
from PIL import Image
from ultralytics import YOLO
# Pickwish API key. Prefer supplying it through the PICKWISH_API_KEY
# environment variable (e.g. a deployment secret) so the credential does not
# have to live in source control.
# NOTE(review): the literal fallback is a credential that has already been
# committed to source; rotate it with the provider, then drop the fallback.
API_KEY = os.environ.get("PICKWISH_API_KEY", "wxqtrh2v6z4fv6lsl")
# Initialize YOLO model
@st.cache_resource
def load_model(weights_path="../yolov9e.pt"):
    """Load the YOLO detector once and reuse it across Streamlit reruns.

    ``detect_objects`` calls this for every newly uploaded image. Without
    caching, the model would be re-created from the weights file each time.

    Args:
        weights_path: Path to the YOLO weights file. NOTE(review): a relative
            path is presumably resolved against the process's working
            directory, not this file's location; confirm for your deployment.

    Returns:
        The loaded ``ultralytics.YOLO`` model.
    """
    return YOLO(weights_path)
def create_task(image_bytes, x1, y1, x2, y2, timeout=30):
    """Submit an asynchronous Pickwish inpaint task that erases one rectangle.

    The rectangle ``(x1, y1)``-``(x2, y2)`` is sent as ``x``/``y`` plus
    ``width``/``height``. ``sync`` is ``"0"``, so the API replies with a task
    id that ``polling_task_result`` later resolves to the edited image.

    Args:
        image_bytes: JPEG-encoded image to edit.
        x1, y1, x2, y2: Corners of the region to erase, in image pixels.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        ``data["task_id"]`` from the API response on success, otherwise
        ``None``. In that case the reason is shown with ``st.error``.
    """
    headers = {"X-API-KEY": API_KEY}
    rectangle = {"x": x1, "y": y1, "width": x2 - x1, "height": y2 - y1}
    data = {"sync": "0", "rectangles": json.dumps([rectangle])}
    files = {"image_file": ("image.jpg", image_bytes, "image/jpeg")}
    url = "https://techhk.aoscdn.com/api/tasks/visual/inpaint"
    try:
        # requests has no default timeout; without one a stalled connection
        # would freeze the Streamlit script indefinitely.
        response = requests.post(
            url, headers=headers, data=data, files=files, timeout=timeout
        )
        response_json = response.json()
    except (requests.RequestException, ValueError) as e:
        # ValueError also covers a non-JSON response body.
        st.error(f"Error creating task: {e}")
        return None
    if (
        isinstance(response_json, dict)
        and response_json.get("status") == 200
        and isinstance(response_json.get("data"), dict)
    ):
        return response_json["data"].get("task_id")
    # Show what the API actually returned instead of failing silently.
    st.error(f"Error creating task: unexpected API response {response_json!r}")
    return None
def polling_task_result(task_id, timeout=30, interval=1):
    """Poll an inpainting task until it completes or ``timeout`` expires.

    Args:
        task_id: Task id returned by ``create_task``.
        timeout: Overall wall-clock budget in seconds. This includes time
            spent inside HTTP requests, not just the sleeps between them.
        interval: Seconds to sleep before each status request.

    Returns:
        The ``image`` field of the task data once its ``state`` is 1.
        Returns ``None`` if the task reports failure, the budget runs out, or
        a request/JSON error occurs. Errors are reported with ``st.error``.
    """
    headers = {"X-API-KEY": API_KEY}
    url = f"https://techhk.aoscdn.com/api/tasks/visual/inpaint/{task_id}"
    deadline = time.monotonic() + timeout
    while True:
        time.sleep(interval)
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            return None
        try:
            # Cap each request by the remaining budget, so a stalled
            # connection cannot push the poll past ``timeout``.
            response = requests.get(url, headers=headers, timeout=remaining)
            response_json = response.json()
        except (requests.RequestException, ValueError) as e:
            st.error(f"Error polling result: {e}")
            return None
        if not isinstance(response_json, dict) or response_json.get("status") != 200:
            continue
        data = response_json.get("data")
        if not isinstance(data, dict):
            continue
        state = data.get("state")
        if state == 1:
            return data.get("image")
        # NOTE(review): the PicWish/Pickwish task API documents negative
        # ``state`` values as terminal failures (confirm against current docs).
        # Polling further cannot succeed, so stop instead of burning the budget.
        if isinstance(state, int) and state < 0:
            st.error(f"Object removal task failed (state {state}).")
            return None
def detect_objects(image):
    """Run YOLO on a PIL image and return an annotated preview plus boxes.

    Args:
        image: A PIL image in any mode, e.g. RGB, RGBA, or a palette PNG.

    Returns:
        A tuple ``(annotated, objects)``:
        - ``annotated`` is the array produced by ``Results.plot()``, which
          the Ultralytics docs describe as BGR.
        - ``objects`` is a list of ``(index, (x1, y1, x2, y2))`` integer
          pixel boxes.
    """
    model = load_model()
    # Ultralytics treats NumPy arrays as BGR (OpenCV convention) but PIL
    # images as RGB. Passing ``np.array(image)`` therefore swapped the colour
    # channels. Handing over an RGB PIL image avoids that, and normalising the
    # mode also keeps RGBA / palette uploads from producing non-3-channel input.
    results = model(image.convert("RGB"))
    result = results[0]

    detected_objects = []
    for i, box in enumerate(result.boxes):
        x1, y1, x2, y2 = map(int, box.xyxy[0])
        detected_objects.append((i, (x1, y1, x2, y2)))

    return result.plot(), detected_objects
def _find_closest_object(detected_objects, x, y):
    """Return the detection whose box centre is nearest to ``(x, y)``.

    ``detected_objects`` holds ``(index, (x1, y1, x2, y2))`` tuples as built
    by ``detect_objects``. Ties go to the earliest entry, and an empty
    sequence yields ``None``.
    """

    def squared_center_distance(obj):
        _, (x1, y1, x2, y2) = obj
        center_x, center_y = (x1 + x2) // 2, (y1 + y2) // 2
        # The squared distance ranks candidates exactly like the Euclidean
        # distance, without needing a square root.
        return (center_x - x) ** 2 + (center_y - y) ** 2

    return min(detected_objects, key=squared_center_distance, default=None)


def _image_to_jpeg_bytes(image):
    """Encode a PIL image as JPEG and return the raw bytes.

    JPEG cannot store an alpha channel or a palette. RGBA and "P" images
    (common for PNG uploads) are therefore converted to RGB first; otherwise
    ``Image.save`` raises ``OSError``.
    """
    if image.mode not in ("RGB", "L"):
        image = image.convert("RGB")
    buffer = io.BytesIO()
    image.save(buffer, format="JPEG")
    return buffer.getvalue()


def main():
    """Streamlit entry point: upload an image, detect objects, remove one.

    Streamlit re-executes this function on every widget interaction. The
    detections for the current upload are therefore kept in
    ``st.session_state``, and YOLO only runs when a different file is
    uploaded. The user then enters a pixel coordinate, and the detection whose
    box centre is closest to it is sent to the Pickwish inpaint API. The
    edited image returned by the API is displayed.
    """
    st.title("Object Removal with YOLO")

    # Initialize session state
    for key in ("detected_objects", "original_image", "processed_image"):
        if key not in st.session_state:
            st.session_state[key] = None

    uploaded_file = st.file_uploader("Upload an image", type=["jpg", "png", "jpeg"])
    if uploaded_file is None:
        return

    image = Image.open(uploaded_file)
    if (
        st.session_state.original_image is None
        or uploaded_file != st.session_state.original_image
    ):
        processed_image, detected_objects = detect_objects(image)
        # Record the upload only after detection succeeded. Otherwise a
        # failed run would leave the previous image's boxes attached to it.
        st.session_state.original_image = uploaded_file
        st.session_state.processed_image = processed_image
        st.session_state.detected_objects = detected_objects

    # Display the uploaded image (detections are not drawn on it).
    st.image(
        st.session_state.original_image,
        caption="Original Image",
        use_container_width=True,
    )

    # Ask the user for the object location, in the same pixel space as the
    # detected boxes.
    st.write("Enter the approximate location of the object to remove:")
    x_input = st.number_input("X Coordinate", min_value=0, value=0)
    y_input = st.number_input("Y Coordinate", min_value=0, value=0)

    if not st.button("Remove Object"):
        return

    closest_object = _find_closest_object(
        st.session_state.detected_objects or [], x_input, y_input
    )
    if closest_object is None:
        # This branch is reached only when the detector found nothing at all.
        st.error("No objects were detected in the image.")
        return
    _, (x1, y1, x2, y2) = closest_object

    task_id = create_task(_image_to_jpeg_bytes(image), x1, y1, x2, y2)
    if not task_id:
        st.error("Failed to create removal task")
        return

    with st.spinner("Removing object..."):
        result_url = polling_task_result(task_id)
    if result_url:
        st.image(result_url, caption="Edited Image")
    else:
        st.error("Failed to get edited image")


if __name__ == "__main__":
    main()