"""Streamlit app that detects objects with YOLO and removes one of them.

The user uploads an image, YOLO proposes bounding boxes, the user enters an
approximate (x, y) location, and the box whose centre is nearest to that point
is sent to the PicWish inpaint API for removal.
"""

import io
import json
import math
import os
import time

import cv2
import numpy as np
import requests
import streamlit as st
from PIL import Image
from ultralytics import YOLO

# PicWish API key. Prefer supplying it through the PICWISH_API_KEY environment
# variable; the literal fallback only keeps existing setups working.
# NOTE(review): a key committed to source control should be considered leaked
# and rotated.
API_KEY = os.environ.get("PICWISH_API_KEY", "wxqtrh2v6z4fv6lsl")

# Endpoint used both to create inpaint tasks and (with a task id appended) to
# poll for their results.
INPAINT_URL = "https://techhk.aoscdn.com/api/tasks/visual/inpaint"

# Seconds to wait for a single HTTP request before giving up; without a
# timeout `requests` can block the Streamlit script forever.
REQUEST_TIMEOUT = 30


@st.cache_resource
def load_model():
    """Load the YOLO weights once and let Streamlit cache the model object."""
    return YOLO("../yolov9e.pt")


def _extract_data(payload):
    """Return the ``data`` dict of a successful API payload, otherwise None."""
    if not isinstance(payload, dict) or payload.get("status") != 200:
        return None
    data = payload.get("data")
    return data if isinstance(data, dict) else None


def create_task(image_bytes, x1, y1, x2, y2):
    """Submit an asynchronous inpaint task for one rectangle of the image.

    Args:
        image_bytes: JPEG-encoded image content.
        x1, y1: Top-left corner of the rectangle to remove, in pixels.
        x2, y2: Bottom-right corner of the rectangle to remove, in pixels.

    Returns:
        The task id reported by the API, or None when the request failed or
        the response did not describe a successful task creation. Request and
        decoding errors are reported through ``st.error``.
    """
    headers = {"X-API-KEY": API_KEY}
    data = {
        "sync": "0",
        "rectangles": json.dumps(
            [{"x": x1, "y": y1, "width": x2 - x1, "height": y2 - y1}]
        ),
    }
    files = {"image_file": ("image.jpg", image_bytes, "image/jpeg")}

    try:
        response = requests.post(
            INPAINT_URL,
            headers=headers,
            data=data,
            files=files,
            timeout=REQUEST_TIMEOUT,
        )
        payload = response.json()
    except (requests.RequestException, ValueError) as e:
        # ValueError covers a non-JSON body on older `requests` versions.
        st.error(f"Error creating task: {e}")
        return None

    task_data = _extract_data(payload)
    return None if task_data is None else task_data.get("task_id")


def polling_task_result(task_id, timeout=30):
    """Poll the inpaint task until it finishes or the attempts run out.

    Args:
        task_id: Identifier returned by ``create_task``.
        timeout: Maximum number of polls. One second is slept before each
            poll, so this is roughly a number of seconds (plus request time).

    Returns:
        The ``image`` value of the finished task, or None if the task did not
        finish in time, reported a failure state, or a request failed.
    """
    headers = {"X-API-KEY": API_KEY}
    url = f"{INPAINT_URL}/{task_id}"

    for _ in range(timeout):
        time.sleep(1)
        try:
            response = requests.get(url, headers=headers, timeout=REQUEST_TIMEOUT)
            payload = response.json()
        except (requests.RequestException, ValueError) as e:
            st.error(f"Error polling result: {e}")
            break

        task_data = _extract_data(payload)
        if task_data is None:
            continue

        state = task_data.get("state")
        if state == 1:
            return task_data.get("image")
        if isinstance(state, int) and state < 0:
            # NOTE(review): negative states are treated as terminal failures,
            # per the PicWish task-state convention; confirm against the API
            # documentation. Stopping here avoids polling a dead task.
            st.error(f"Removal task failed (state={state})")
            break

    return None


def detect_objects(image):
    """Run YOLO on a PIL image and collect integer bounding boxes.

    Args:
        image: PIL image in any mode; it is converted to RGB first so that
            palette, grayscale and alpha images yield a 3-channel array.

    Returns:
        A tuple ``(annotated, detected_objects)`` where ``annotated`` is the
        plotted detection image as an RGB-ordered numpy array and
        ``detected_objects`` is a list of ``(index, (x1, y1, x2, y2))``.
    """
    model = load_model()

    rgb = np.array(image.convert("RGB"))
    # Ultralytics documents numpy inputs as BGR, so reverse the channel axis
    # before inference instead of feeding RGB data as if it were BGR.
    bgr = np.ascontiguousarray(rgb[:, :, ::-1])

    result = model(bgr)[0]
    detected_objects = [
        (i, tuple(map(int, box.xyxy[0]))) for i, box in enumerate(result.boxes)
    ]

    # plot() draws on the BGR input; flip back so callers keep receiving an
    # RGB-ordered array as before.
    annotated = np.ascontiguousarray(result.plot()[..., ::-1])
    return annotated, detected_objects


def _closest_object(objects, x, y):
    """Return the detection whose box centre is nearest to (x, y), or None."""

    def distance(obj):
        _, (x1, y1, x2, y2) = obj
        return math.hypot((x1 + x2) // 2 - x, (y1 + y2) // 2 - y)

    # min() keeps the first of several equally close detections.
    return min(objects, key=distance, default=None)


def _encode_jpeg(image):
    """Encode a PIL image as JPEG bytes, dropping alpha/palette if present."""
    buffer = io.BytesIO()
    # JPEG cannot store RGBA/P/LA data; saving such modes raises OSError.
    image.convert("RGB").save(buffer, format="JPEG")
    return buffer.getvalue()


def main():
    """Render the page: upload, detect, pick a location, remove the object."""
    st.title("Object Removal with YOLO")

    # Initialize session state
    for key in ("detected_objects", "original_image", "processed_image"):
        if key not in st.session_state:
            st.session_state[key] = None

    uploaded_file = st.file_uploader("Upload an image", type=["jpg", "png", "jpeg"])
    if uploaded_file is None:
        return

    # Load the image and run detection only when the upload changed.
    image = Image.open(uploaded_file)
    if (
        st.session_state.original_image is None
        or uploaded_file != st.session_state.original_image
    ):
        st.session_state.original_image = uploaded_file
        processed_image, detected_objects = detect_objects(image)
        st.session_state.processed_image = processed_image
        st.session_state.detected_objects = detected_objects

    # Display the uploaded image
    st.image(
        st.session_state.original_image,
        caption="Original Image",
        use_container_width=True,
    )

    # Ask user for object location
    st.write("Enter the approximate location of the object to remove:")
    x_input = st.number_input("X Coordinate", min_value=0, value=0)
    y_input = st.number_input("Y Coordinate", min_value=0, value=0)

    if not st.button("Remove Object"):
        return

    # Find the closest object to the entered coordinates
    closest_object = _closest_object(
        st.session_state.detected_objects or [], x_input, y_input
    )
    if closest_object is None:
        st.error("No objects found near the entered coordinates.")
        return

    _, (x1, y1, x2, y2) = closest_object

    # Create task and get result
    task_id = create_task(_encode_jpeg(image), x1, y1, x2, y2)
    if not task_id:
        st.error("Failed to create removal task")
        return

    with st.spinner("Removing object..."):
        result_url = polling_task_result(task_id)

    if result_url:
        st.image(result_url, caption="Edited Image")
    else:
        st.error("Failed to get edited image")


if __name__ == "__main__":
    main()