# Source: Hugging Face Space by dxcanh (commit 2c9403e, "Rename test.py to app.py")
import gradio as gr
import cv2
import numpy as np
import matplotlib.pyplot as plt
def extract_all(image: np.ndarray, area_threshold: int = 100, lower_thresh: int = 100, upper_thresh: int = 200) -> dict:
    """Segment an image into connected pixel "islands" (objects) via Canny edges.

    The edge map is morphologically closed, dilated, and closed again so each
    object's outline becomes a solid external contour; every contour larger
    than ``area_threshold`` is converted to the list of pixels it covers.

    Args:
        image: BGR (3-channel) or grayscale input image.
        area_threshold: Minimum contour area for a region to be kept.
        lower_thresh: Lower hysteresis threshold for cv2.Canny.
        upper_thresh: Upper hysteresis threshold for cv2.Canny.

    Returns:
        dict mapping each island's first (row, col) pixel to the full list
        of (row, col) tuples belonging to that island.
    """
    if len(image.shape) == 3:
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    else:
        gray = image.copy()
    blurred = cv2.GaussianBlur(gray, (5, 5), 0)
    edges = cv2.Canny(blurred, lower_thresh, upper_thresh)
    # Close small gaps, dilate, then close again so outlines become
    # solid regions that findContours can treat as filled shapes.
    closed_edges = cv2.morphologyEx(edges, cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8), iterations=3)
    closed_edges = cv2.dilate(closed_edges, np.ones((5, 5), np.uint8), iterations=1)
    closed_edges = cv2.morphologyEx(closed_edges, cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8), iterations=2)
    cv2.imwrite("canny_binary.jpg", closed_edges)  # debug artifact; kept for behavioral parity
    contours, _ = cv2.findContours(closed_edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    real_islands = {}
    for contour in contours:
        if cv2.contourArea(contour) <= area_threshold:
            continue
        # Rasterize the filled contour to recover every pixel it covers.
        mask = np.zeros_like(gray)
        cv2.drawContours(mask, [contour], -1, 255, thickness=cv2.FILLED)
        pixels = list(zip(*np.where(mask == 255)))
        # NOTE(review): keyed by the island's first pixel — two islands sharing
        # that pixel would collide, though filled external contours make
        # that unlikely. (Removed dead `contour_id` counter from original.)
        real_islands[(pixels[0][0], pixels[0][1])] = pixels
    print(f"Detected {len(real_islands)} islands from {len(contours)} contours")
    return real_islands
def extract_object(image: np.ndarray, island: list[tuple]) -> np.ndarray:
coords = np.array(island)
min_y, min_x = coords.min(axis=0)
max_y, max_x = coords.max(axis=0)
height, width = max_y - min_y + 1, max_x - min_x + 1
num_channels = image.shape[2] if len(image.shape) == 3 else 1
result = np.zeros((height, width, num_channels), dtype=np.uint8)
y_coords = coords[:, 0] - min_y
x_coords = coords[:, 1] - min_x
result[y_coords, x_coords] = image[coords[:, 0], coords[:, 1]]
return result
def draw_bound(img: np.ndarray, top: int, down: int, left: int, right: int, size: int, color=(0, 255, 0)) -> np.ndarray:
    """Draw a hollow rectangle with border thickness ``size`` on a copy of ``img``.

    The frame spans (left, top)-(right, down) and is built from four filled
    strips, one per edge.

    Returns:
        A new image with the frame drawn; the input image is not modified.
    """
    framed = img.copy()
    strips = (
        ((left, top), (right, top + size)),       # top edge
        ((left, down - size), (right, down)),     # bottom edge
        ((left, top), (left + size, down)),       # left edge
        ((right - size, top), (right, down)),     # right edge
    )
    for corner_a, corner_b in strips:
        cv2.rectangle(framed, corner_a, corner_b, color, thickness=-1)
    return framed
def compute_template_matching(img: np.ndarray, template: np.ndarray, method, mask: np.ndarray):
    """Run cv2.matchTemplate after rejecting zero-variance inputs.

    Normalized matching methods divide by the standard deviation of the
    image/template, so a constant input would yield NaN/inf scores; such
    inputs are rejected up front, and any remaining +/-inf entries in the
    score map are zeroed.

    Args:
        img: Image to search within (cast to uint8).
        template: Patch to search for (cast to uint8).
        method: A cv2.TM_* matching-method constant.
        mask: Optional template mask forwarded to matchTemplate.

    Returns:
        The match-score map with +/-inf values replaced by 0.

    Raises:
        ValueError: If the template or the image has zero standard deviation.
    """
    n_img = img.astype(np.uint8)
    n_template = template.astype(np.uint8)
    # Distinct messages (original raised "Standard = 0" for both cases)
    # so callers can tell which input was degenerate.
    if np.std(n_template) == 0:
        raise ValueError("Template standard deviation is 0 (constant template)")
    if np.std(n_img) == 0:
        raise ValueError("Image standard deviation is 0 (constant image)")
    result = cv2.matchTemplate(n_img, n_template, method, mask=mask)
    result = np.where(np.isinf(result), 0, result)
    return result
def process_single_object_loop(img: np.ndarray, template: np.ndarray, method, mask: np.ndarray):
    """Match one template against ``img`` and box the best hit.

    Returns:
        Tuple of (best score, raw score map, image copy with the best-match
        bounding box drawn, best-match top-left position as (row, col)).
    """
    score_map = compute_template_matching(img, template, method, mask)
    _, best_score, _, best_xy = cv2.minMaxLoc(score_map)
    # minMaxLoc reports (x, y); the rest of this file works in (row, col).
    x, y = best_xy
    t_h = template.shape[0]
    t_w = template.shape[1]
    boxed = draw_bound(img, y, y + t_h, x, x + t_w, 8, (0, 255, 0))
    return best_score, score_map, boxed, (y, x)
def process_template_at_scale(source: np.ndarray, template: np.ndarray, method, scale: float):
    """Resize the template to ``scale`` and match it against ``source``.

    A binary foreground mask is derived from the template (median blur,
    erode, threshold) so matchTemplate ignores the black background around
    the extracted object.

    Returns:
        Tuple of (best score, score map, image with box drawn, masked
        template at this scale, best-match (row, col) position).
    """
    tmpl = template.copy().astype(np.uint8)
    # Build the foreground mask: smooth noise, shrink edges, binarize.
    smoothed = cv2.medianBlur(tmpl.copy(), 5)
    erode_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
    smoothed = cv2.erode(smoothed, erode_kernel, iterations=1)
    _, mask = cv2.threshold(smoothed, 1, 255, cv2.THRESH_BINARY)
    scaled_size = (int(mask.shape[1] * scale), int(mask.shape[0] * scale))
    mask = cv2.resize(mask, scaled_size, interpolation=cv2.INTER_NEAREST_EXACT)
    # Resize the template to exactly the mask's new dimensions.
    tmpl = cv2.resize(tmpl, (mask.shape[1], mask.shape[0]), interpolation=cv2.INTER_NEAREST_EXACT)
    best, score_map, boxed, pos = process_single_object_loop(source.copy(), tmpl, method, mask.astype(np.uint8))
    # Keep only the masked-in pixels of the scaled template for the caller.
    keep = mask.astype(bool)
    max_template = np.zeros_like(tmpl)
    max_template[keep] = tmpl[keep]
    return best, score_map, boxed, max_template, pos
def process_images(source_img, objects_img, confidence_threshold=0.7):
    """Detect every object from ``objects_img`` inside ``source_img``.

    Pipeline: segment the objects image into pixel islands, crop each
    island into a template, then template-match each template against the
    source at 20 scales (0.25x to 1.0x), drawing a numbered green box for
    every object whose best confidence reaches ``confidence_threshold``.

    Args:
        source_img: Scene image to search in (ndarray from Gradio, or an
            object convertible with np.array()).
        objects_img: Image whose segmented objects become the templates.
        confidence_threshold: Minimum TM_CCOEFF_NORMED score for a match
            to be drawn (default 0.7).

    Returns:
        A copy of the source image with bounding boxes and object numbers.
    """
    # NOTE(review): Gradio type="numpy" always delivers an ndarray, so the
    # RGB -> BGR conversion branches below never fire for UI input —
    # confirm the intended channel order downstream.
    if isinstance(source_img, np.ndarray):
        source = source_img
    else:
        source = np.array(source_img)[:, :, ::-1]  # RGB -> BGR
    if isinstance(objects_img, np.ndarray):
        objects = objects_img
    else:
        objects = np.array(objects_img)[:, :, ::-1]  # RGB -> BGR
    # Denoise before segmentation; templates are cropped from the
    # unblurred ``objects`` so they keep full detail.
    object_img = cv2.medianBlur(objects.copy(), 3)
    islands = extract_all(object_img, area_threshold=100, lower_thresh=100, upper_thresh=200)
    objects_extracted = []
    for island in islands.values():
        object_image = extract_object(objects, island)
        objects_extracted.append(object_image)
    result_image = source.copy()
    method = cv2.TM_CCOEFF_NORMED
    print("\nProcessing object detection...")
    print(f"Confidence threshold: {confidence_threshold}")
    print(f"Total objects to detect: {len(objects_extracted)}\n")
    for i, template in enumerate(objects_extracted):
        print(f"\nProcessing object {i+1}/{len(objects_extracted)}")
        max_val = 0
        max_pos = None       # best (row, col) top-left across scales
        max_template = None  # scaled, masked template at the best scale
        scale_steps = np.linspace(0.25, 1.0, 20)
        for scale in scale_steps:
            local_max, _, temp_bound_image, local_template, pos = process_template_at_scale(
                source, template, method, scale
            )
            print(f"Scale {scale:.2f}: Confidence = {local_max:.4f}")
            if local_max > max_val:
                max_val = local_max
                max_template = local_template
                max_pos = pos
            # Early exit: the first scale that clears the threshold wins;
            # ``scale`` deliberately leaks out of the loop for the message.
            if max_val >= confidence_threshold:
                print(f"Stopping at scale {scale:.2f} as confidence {max_val:.4f} >= threshold")
                break
        print(f"Final confidence for object {i+1}: {max_val:.4f}")
        if max_pos is not None and max_val >= confidence_threshold:
            h, w = max_template.shape[:2]
            result_image = draw_bound(
                result_image,
                max_pos[0],
                max_pos[0] + h,
                max_pos[1],
                max_pos[1] + w,
                8,
                (0, 255, 0)
            )
            # Label the box with the 1-based object index just above it;
            # putText takes (x, y), hence the swapped max_pos indices.
            cv2.putText(
                result_image,
                f"{i+1}",
                (max_pos[1], max_pos[0]-10),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.9,
                (0, 255, 0),
                2
            )
            print(f"Object {i+1} detected at position ({max_pos[0]}, {max_pos[1]}) with size ({h}x{w})")
        else:
            print(f"Object {i+1} not detected (confidence {max_val:.4f} < threshold {confidence_threshold})")
    print("\nDetection completed!")
    return result_image
# Create the Gradio interface: two image inputs plus a threshold slider on
# the left, the annotated result on the right.
with gr.Blocks(title="Object Detection in Images") as demo:
    gr.Markdown("# Object Detection in Images")
    gr.Markdown("Upload a source image and an objects image to detect and draw bounding boxes around matching objects.")
    with gr.Row():
        with gr.Column():
            # Inputs: scene image, objects image, and the match threshold
            # forwarded to process_images as confidence_threshold.
            source_input = gr.Image(label="Source Image", type="numpy")
            objects_input = gr.Image(label="Objects Image", type="numpy")
            threshold_input = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.7,
                step=0.01,
                label="Confidence Threshold"
            )
            submit_btn = gr.Button("Detect Objects")
        with gr.Column():
            # Output: annotated copy of the source image.
            output_image = gr.Image(label="Result with Bounding Boxes", type="numpy")
    # Wire the button to the detection pipeline.
    submit_btn.click(
        fn=process_images,
        inputs=[source_input, objects_input, threshold_input],
        outputs=output_image
    )
demo.launch()