Ayesha352's picture
Update app.py
b6c2f72 verified
raw
history blame
7.2 kB
import cv2
import numpy as np
import json
import gradio as gr
import os
import xml.etree.ElementTree as ET
from lxml import etree
# ---------------- Helper functions ----------------
def get_rotated_rect_corners(x, y, w, h, rotation_deg):
    """Return the corners of a rectangle rotated about its own centre.

    Args:
        x, y: Top-left corner of the unrotated rectangle.
        w, h: Width and height of the rectangle.
        rotation_deg: Rotation angle in degrees, applied around the
            rectangle centre ``(x + w/2, y + h/2)``.

    Returns:
        A ``(4, 2)`` float32 array. Rows follow the unrotated corner order
        top-left, top-right, bottom-right, bottom-left.
    """
    theta = np.deg2rad(rotation_deg)
    c, s = np.cos(theta), np.sin(theta)
    rotation = np.array([[c, -s], [s, c]])

    half_w, half_h = w / 2, h / 2
    centre = np.array([x + half_w, y + half_h])
    # Corner offsets relative to the centre, before rotation.
    offsets = np.array([
        [-half_w, -half_h],
        [half_w, -half_h],
        [half_w, half_h],
        [-half_w, half_h],
    ])

    # Row vectors, hence the transposed rotation matrix.
    corners = offsets @ rotation.T + centre
    return corners.astype(np.float32)
def preprocess_gray_clahe(img, clip_limit=3.0, tile_grid_size=(8, 8)):
    """Convert an image to grayscale and equalise its contrast with CLAHE.

    Args:
        img: Image array. Multi-channel input is converted with
            ``cv2.COLOR_BGR2GRAY``; 2-D or single-channel input is used as
            the grayscale plane directly.
        clip_limit: CLAHE contrast clip limit (default matches the
            previously hard-coded 3.0).
        tile_grid_size: CLAHE tile grid (default matches the previously
            hard-coded ``(8, 8)``).

    Returns:
        The CLAHE-equalised grayscale image.

    Raises:
        ValueError: If ``img`` is ``None``.
    """
    if img is None:
        # A None image would otherwise surface as an opaque error from cvtColor.
        raise ValueError("preprocess_gray_clahe: image is None (failed to load?)")

    if img.ndim == 2:
        gray = img
    elif img.shape[2] == 1:
        # Single-channel 3-D array: drop the trailing axis instead of converting.
        gray = img.reshape(img.shape[:2])
    else:
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

    clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size)
    return clahe.apply(gray)
def detect_and_match(img1_gray, img2_gray, method="SIFT", ratio_thresh=0.78):
    """Detect features in two grayscale images and keep ratio-test matches.

    Args:
        img1_gray: Query image passed to the detector.
        img2_gray: Train image passed to the detector.
        method: One of ``"SIFT"``, ``"ORB"``, ``"BRISK"``, ``"KAZE"``,
            ``"AKAZE"``.
        ratio_thresh: A best match is kept only when its distance is below
            ``ratio_thresh`` times the distance of the second-best match.

    Returns:
        ``(kp1, kp2, good)``: the keypoints of both images and the list of
        accepted matches. ``(None, None, [])`` is returned for an unknown
        method or when either image yields no descriptors.
    """
    # Factories are lazy so only the requested detector is constructed.
    factories = {
        "SIFT": lambda: (cv2.SIFT_create(nfeatures=5000), cv2.NORM_L2),
        "ORB": lambda: (cv2.ORB_create(5000), cv2.NORM_HAMMING),
        "BRISK": lambda: (cv2.BRISK_create(), cv2.NORM_HAMMING),
        "KAZE": lambda: (cv2.KAZE_create(), cv2.NORM_L2),
        "AKAZE": lambda: (cv2.AKAZE_create(), cv2.NORM_HAMMING),
    }
    factory = factories.get(method)
    if factory is None:
        return None, None, []

    detector, norm_type = factory()
    matcher = cv2.BFMatcher(norm_type)

    kp1, des1 = detector.detectAndCompute(img1_gray, None)
    kp2, des2 = detector.detectAndCompute(img2_gray, None)
    if des1 is None or des2 is None:
        return None, None, []

    raw_matches = matcher.knnMatch(des1, des2, k=2)
    # Entries with fewer than two neighbours cannot be ratio-tested; the
    # previous two-name unpacking raised ValueError on them.
    good = [
        pair[0]
        for pair in raw_matches
        if len(pair) == 2 and pair[0].distance < ratio_thresh * pair[1].distance
    ]
    return kp1, kp2, good
def parse_xml_points(xml_file):
    """Read the four corner points from an XML file.

    The document is searched for ``<point>`` elements whose ``type``
    attribute is ``TopLeft``, ``TopRight``, ``BottomLeft`` and
    ``BottomRight``; their ``x``/``y`` attributes are parsed as floats.

    Args:
        xml_file: Path or file object accepted by ``ET.parse``.

    Returns:
        A ``(4, 2)`` float32 array in the order TopLeft, TopRight,
        BottomLeft, BottomRight.

    Raises:
        ValueError: If a corner point is missing or its coordinates are
            absent or not numeric.
    """
    root = ET.parse(xml_file).getroot()
    coords = []
    for pt_type in ("TopLeft", "TopRight", "BottomLeft", "BottomRight"):
        elem = root.find(f".//point[@type='{pt_type}']")
        if elem is None:
            # Previously this surfaced as AttributeError on NoneType.
            raise ValueError(f"XML has no <point type='{pt_type}'> element")
        try:
            coords.append([float(elem.get("x")), float(elem.get("y"))])
        except (TypeError, ValueError) as err:
            raise ValueError(
                f"<point type='{pt_type}'> has missing or non-numeric x/y"
            ) from err
    return np.array(coords, dtype=np.float32).reshape(-1, 2)
def extract_four_points_from_xml(xml_path):
    """Collect the typed ``<point>`` elements of an XML file into a dict.

    Points are looked up below the first ``<transform>`` element. When the
    document has no such descendant (for example because ``<transform>`` is
    the root itself) the whole document is searched instead of failing.

    Parsing uses the standard-library ElementTree, the same parser
    ``parse_xml_points`` relies on.

    Args:
        xml_path: Path or file object accepted by ``ET.parse``.

    Returns:
        Dict mapping each point's ``type`` attribute to an ``(x, y)`` tuple
        of ints (coordinates are parsed as floats, then truncated).

    Raises:
        KeyError: If a point lacks a ``type``, ``x`` or ``y`` attribute.
    """
    root = ET.parse(xml_path).getroot()
    transform = root.find('.//transform')
    # find() only inspects descendants, so it returns None when the wrapper
    # is absent or is the root element; search from the root in that case.
    scope = root if transform is None else transform

    points = {}
    for pt in scope.findall('.//point'):
        pt_type = pt.attrib['type']
        x = int(float(pt.attrib['x']))
        y = int(float(pt.attrib['y']))
        points[pt_type] = (x, y)
    return points
def draw_polygon_overlay(img, points_dict, color=(0, 0, 255), thickness=3):
    """Draw the four-corner polygon on a copy of ``img``.

    Args:
        img: Image to draw on; it is copied, not modified.
        points_dict: Mapping with the keys ``TopLeft``, ``TopRight``,
            ``BottomRight`` and ``BottomLeft``, each an ``(x, y)`` pair.
        color: Line colour passed to ``cv2.polylines``. The default is red
            for a BGR image, which is what the caller in this file passes
            (an image loaded with ``cv2.imread``) and what the UI
            description promises; the previous ``(255, 0, 0)`` rendered
            blue on such an image.
        thickness: Line thickness in pixels.

    Returns:
        A copy of ``img`` with the closed polygon drawn on it.

    Raises:
        KeyError: If any of the four corner keys is missing.
    """
    corner_order = ('TopLeft', 'TopRight', 'BottomRight', 'BottomLeft')
    missing = [name for name in corner_order if name not in points_dict]
    if missing:
        raise KeyError(f"points_dict is missing corner(s): {', '.join(missing)}")

    # polylines expects int32 points shaped (N, 1, 2).
    pts = np.array(
        [points_dict[name] for name in corner_order], np.int32
    ).reshape((-1, 1, 2))
    img_overlay = img.copy()
    cv2.polylines(img_overlay, [pts], isClosed=True, color=color, thickness=thickness)
    return img_overlay
# ---------------- Padding Helper ----------------
def pad_to_size(img, target_h, target_w):
    """Place ``img`` in the top-left corner of a white canvas.

    Args:
        img: Image array of shape ``(h, w, 3)``, ``(h, w, 1)`` or ``(h, w)``;
            grayscale input is replicated across the three channels.
        target_h: Canvas height in pixels.
        target_w: Canvas width in pixels.

    Returns:
        A ``(target_h, target_w, 3)`` uint8 array filled with 255 outside
        the pasted image.

    Raises:
        ValueError: If the image is larger than the requested canvas.
    """
    h, w = img.shape[:2]
    if h > target_h or w > target_w:
        raise ValueError(
            f"image of size {h}x{w} does not fit a {target_h}x{target_w} canvas"
        )
    if img.ndim == 2:
        # Add a channel axis so the assignment broadcasts over all 3 channels.
        img = img[:, :, np.newaxis]
    canvas = np.full((target_h, target_w, 3), 255, dtype=np.uint8)
    canvas[:h, :w] = img
    return canvas
# ---------------- Main Function ----------------
def _upload_path(upload):
    """Return the filesystem path of an upload given as a path or an object with ``.name``."""
    if isinstance(upload, (str, os.PathLike)):
        return upload
    return upload.name


def homography_all_detectors(flat_file, persp_file, json_file, xml_file):
    """Run every detector and write one 2x2 comparison grid per success.

    For each of SIFT, ORB, BRISK, KAZE and AKAZE the flat image is matched
    against the perspective image, a RANSAC homography is estimated, and the
    print-area rectangle from the mockup JSON is projected into the
    perspective image. The saved grid holds: flat image, match
    visualisation, projected ROI, and the XML polygon overlay.

    Args:
        flat_file: Path of the flat image (read with ``cv2.imread``).
        persp_file: Path of the perspective image.
        json_file: Mockup JSON upload, either a path or an object exposing
            ``.name``. ``printAreas[0]`` must provide ``position.x``,
            ``position.y``, ``width``, ``height`` and ``rotation``.
        xml_file: XML upload (path or object exposing ``.name``) holding
            the four typed corner points.

    Returns:
        A 6-tuple: the list of written PNG paths, followed by the result
        path for SIFT, ORB, BRISK, KAZE and AKAZE in that order. A detector
        that produced no result yields ``None`` in its own slot.

    Raises:
        ValueError: If either image cannot be read.
    """
    flat_img = cv2.imread(flat_file)
    persp_img = cv2.imread(persp_file)
    # cv2.imread reports failure by returning None rather than raising.
    if flat_img is None:
        raise ValueError(f"Could not read flat image: {flat_file}")
    if persp_img is None:
        raise ValueError(f"Could not read perspective image: {persp_file}")

    with open(_upload_path(json_file), encoding="utf-8") as fh:
        mockup = json.load(fh)
    print_area = mockup["printAreas"][0]
    roi_x, roi_y = print_area["position"]["x"], print_area["position"]["y"]
    roi_w, roi_h = print_area["width"], print_area["height"]
    roi_rot_deg = print_area["rotation"]

    flat_gray = preprocess_gray_clahe(flat_img)
    persp_gray = preprocess_gray_clahe(persp_img)
    xml_dict = extract_four_points_from_xml(_upload_path(xml_file))

    # Everything below is identical for every detector, so compute it once.
    roi_corners_flat = get_rotated_rect_corners(
        roi_x, roi_y, roi_w, roi_h, roi_rot_deg
    ).reshape(-1, 1, 2)
    flat_rgb = cv2.cvtColor(flat_img, cv2.COLOR_BGR2RGB)
    xml_rgb = cv2.cvtColor(
        draw_polygon_overlay(persp_img, xml_dict), cv2.COLOR_BGR2RGB
    )
    base_name = os.path.splitext(os.path.basename(persp_file))[0]

    methods = ["SIFT", "ORB", "BRISK", "KAZE", "AKAZE"]
    gallery_paths = []
    # Keyed by detector so a skipped detector cannot shift later results
    # into the wrong download slot.
    result_by_method = {}

    for method in methods:
        kp1, kp2, good_matches = detect_and_match(flat_gray, persp_gray, method)
        # A homography needs at least four correspondences.
        if kp1 is None or kp2 is None or len(good_matches) < 4:
            continue

        match_img = cv2.drawMatches(
            flat_img, kp1, persp_img, kp2, good_matches, None,
            flags=cv2.DrawMatchesFlags_NOT_DRAW_SINGLE_POINTS,
        )
        src_pts = np.float32(
            [kp1[m.queryIdx].pt for m in good_matches]
        ).reshape(-1, 1, 2)
        dst_pts = np.float32(
            [kp2[m.trainIdx].pt for m in good_matches]
        ).reshape(-1, 1, 2)
        H, _ = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 5.0)
        if H is None:
            continue

        # Homography-based ROI overlay
        roi_corners_persp = cv2.perspectiveTransform(
            roi_corners_flat, H
        ).reshape(-1, 2)
        if not np.all(np.isfinite(roi_corners_persp)):
            # Non-finite coordinates cannot be cast to pixel positions.
            continue
        persp_roi = persp_img.copy()
        cv2.polylines(
            persp_roi, [roi_corners_persp.astype(np.int32)], True, (0, 255, 0), 2
        )
        for px, py in roi_corners_persp:
            cv2.circle(persp_roi, (int(px), int(py)), 5, (255, 0, 0), -1)

        match_rgb = cv2.cvtColor(match_img, cv2.COLOR_BGR2RGB)
        roi_rgb = cv2.cvtColor(persp_roi, cv2.COLOR_BGR2RGB)

        # Pad all four tiles to a common size, then merge them into a 2x2 grid.
        tiles = (flat_rgb, match_rgb, roi_rgb, xml_rgb)
        max_h = max(tile.shape[0] for tile in tiles)
        max_w = max(tile.shape[1] for tile in tiles)
        flat_pad, match_pad, roi_pad, xml_pad = (
            pad_to_size(tile, max_h, max_w) for tile in tiles
        )
        combined_grid = np.vstack([
            np.hstack([flat_pad, match_pad]),
            np.hstack([roi_pad, xml_pad]),
        ])

        # Save
        file_name = f"{base_name}_{method.lower()}.png"
        if not cv2.imwrite(file_name, cv2.cvtColor(combined_grid, cv2.COLOR_RGB2BGR)):
            # Do not advertise a file that was never written.
            continue
        gallery_paths.append(file_name)
        result_by_method[method] = file_name

    return tuple([gallery_paths] + [result_by_method.get(m) for m in methods])
# ---------------- Gradio UI ----------------
# Gradio wiring: two image uploads plus the mockup JSON and XML files in,
# a gallery of result grids plus one download slot per detector out.
iface = gr.Interface(
    fn=homography_all_detectors,
    inputs=[
        gr.Image(label="Upload Flat Image",type="filepath"),
        gr.Image(label="Upload Perspective Image",type="filepath"),
        gr.File(label="Upload mockup.json",file_types=[".json"]),
        gr.File(label="Upload XML file",file_types=[".xml"])
    ],
    outputs=[
        gr.Gallery(label="Results per Detector",show_label=True),
        gr.File(label="Download SIFT Result"),
        gr.File(label="Download ORB Result"),
        gr.File(label="Download BRISK Result"),
        gr.File(label="Download KAZE Result"),
        gr.File(label="Download AKAZE Result")
    ],
    title="Homography ROI + Feature Matching + XML GT",
    description="Shows Flat, Feature-Matched, Homography ROI (green), and XML Ground-Truth (red) overlay in a 2x2 grid with same size."
)

# Only start the web server when run as a script, so importing this module
# (for example to reuse the helper functions) does not launch the UI.
if __name__ == "__main__":
    iface.launch()