import os
import shutil
import tempfile

import cv2
import gradio as gr

# NOTE: resolves to the project's local test.py (not the stdlib ``test``
# package) when the script is run from its own directory.
from test import predict_one
from plot import (
    autocrop,
    get_json_corners,
    extract_points_from_xml,
    draw_feature_matching,
    stack_images_side_by_side,
)
| |
|
| | |
# Checkpoint path handed to ``predict_one``. Resolved next to this script
# instead of against the current working directory, so the demo works no
# matter where it is launched from. Set the MODEL_CKPT environment variable
# to use a different checkpoint.
MODEL_CKPT = os.environ.get(
    "MODEL_CKPT",
    os.path.join(os.path.dirname(os.path.abspath(__file__)), "best_model.pth"),
)
| |
|
| |
|
| | |
| | |
| | |
# Drawing colours. The working images are kept in RGB order (converted from
# OpenCV's BGR on load and back to BGR only when the PNG is written), so these
# tuples are (R, G, B): green = ground truth, blue = prediction, red = divider.
_GT_COLOR = (0, 255, 0)
_PRED_COLOR = (0, 0, 255)
_DIVIDER_COLOR = (255, 0, 0)


def _load_rgb(path):
    """Read the image at *path* with OpenCV and return it in RGB channel order.

    Raises:
        gr.Error: if OpenCV cannot decode the file. ``cv2.imread`` signals
            failure by returning ``None`` rather than raising, so this check
            is required to get a readable error.
    """
    bgr = cv2.imread(path)
    if bgr is None:
        raise gr.Error(f"Could not read image file: {os.path.basename(path)}")
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)


def run_pipeline(flat_img, pers_img, mockup_json, xml_gt):
    """Predict key points for one sample and render a GT-vs-prediction image.

    Runs ``predict_one`` on the mockup JSON and the perspective image, writing
    its XML to a fresh temporary directory. It then draws two feature-matching
    panels (flat image vs. perspective image): one with the ground-truth
    points, one with the predicted points. The panels are stacked side by side
    and labelled, and the result is saved as a PNG.

    Args:
        flat_img: File path of the flat image.
        pers_img: File path of the perspective image.
        mockup_json: File path of the mockup JSON (source of the corner points).
        xml_gt: File path of the ground-truth XML.

    Returns:
        ``(result_path, xml_pred_path)``: the comparison PNG and the predicted
        XML, both inside a newly created temporary directory. On success the
        directory is left in place so the UI can serve the files.

    Raises:
        gr.Error: if an input is missing, an image cannot be read, the
            prediction produced no XML, or the PNG could not be written.
            Any other exception from the helpers propagates unchanged. On
            every failure the temporary directory is removed.
    """
    # Gradio passes None for inputs the user never supplied; report them all
    # at once instead of crashing somewhere inside the model or OpenCV.
    missing = [
        label
        for label, value in (
            ("Flat Image", flat_img),
            ("Perspective Image", pers_img),
            ("Mockup JSON", mockup_json),
            ("Ground Truth XML", xml_gt),
        )
        if not value
    ]
    if missing:
        raise gr.Error("Please provide: " + ", ".join(missing))

    tmpdir = tempfile.mkdtemp(prefix="mkpt_demo_")
    xml_pred_path = os.path.join(tmpdir, "pred.xml")
    result_path = os.path.join(tmpdir, "result.png")

    try:
        # Decode the images first so unreadable uploads fail before the
        # (comparatively expensive) model prediction runs.
        img_json = autocrop(_load_rgb(flat_img))
        img_xml = autocrop(_load_rgb(pers_img))

        predict_one(mockup_json, pers_img, MODEL_CKPT, out_path=xml_pred_path)
        if not os.path.isfile(xml_pred_path):
            raise gr.Error("Prediction did not produce an XML file.")

        json_pts = get_json_corners(mockup_json)
        gt_pts = extract_points_from_xml(xml_gt)
        pred_pts = extract_points_from_xml(xml_pred_path)

        # Copies are passed because draw_feature_matching may draw in place
        # and each panel needs clean source images.
        match_json_gt = draw_feature_matching(
            img_json.copy(), json_pts, img_xml.copy(), gt_pts,
            _GT_COLOR, draw_boxes=True,
        )
        match_json_pred = draw_feature_matching(
            img_json.copy(), json_pts, img_xml.copy(), pred_pts,
            _PRED_COLOR, draw_boxes=True,
        )

        stacked = stack_images_side_by_side(match_json_gt, match_json_pred)

        # NOTE(review): assumes the two stacked panels have equal width, so
        # w // 2 is the seam between them — confirm in stack_images_side_by_side.
        h, w = stacked.shape[:2]
        center_x = w // 2
        cv2.line(stacked, (center_x, 0), (center_x, h), _DIVIDER_COLOR, 4)

        font = cv2.FONT_HERSHEY_SIMPLEX
        cv2.putText(stacked, "Ground Truth", (50, 50), font, 2,
                    _GT_COLOR, 3, cv2.LINE_AA)
        cv2.putText(stacked, "Our Result", (center_x + 50, 50), font, 2,
                    _PRED_COLOR, 3, cv2.LINE_AA)

        # cv2.imwrite reports failure via its return value, not an exception.
        if not cv2.imwrite(result_path, cv2.cvtColor(stacked, cv2.COLOR_RGB2BGR)):
            raise gr.Error("Failed to write the comparison image.")
    except Exception:
        # Don't leak a temp directory per failed request in a long-running demo.
        shutil.rmtree(tmpdir, ignore_errors=True)
        raise

    return result_path, xml_pred_path
| |
|
| |
|
| | |
| | |
| | |
with gr.Blocks() as demo:
    # Page heading.
    gr.Markdown("## Mesh Key Point Transformer Demo")

    # First row: the two input images, handed to the callback as file paths.
    with gr.Row():
        flat_in = gr.Image(
            label="Flat Image",
            type="filepath",
            height=300,
            width=300,
        )
        pers_in = gr.Image(
            label="Perspective Image",
            type="filepath",
            height=300,
            width=300,
        )

    # Second row: the annotation files (also delivered as file paths).
    with gr.Row():
        mockup_json_in = gr.File(label="Mockup JSON", type="filepath")
        xml_gt_in = gr.File(label="Ground Truth XML", type="filepath")

    run_btn = gr.Button("Run Prediction + Visualization")

    # Third row: the rendered comparison and the downloadable predicted XML.
    with gr.Row():
        out_img = gr.Image(
            label="Comparison Output",
            type="filepath",
            height=600,
            width=800,
        )
        out_xml = gr.File(label="Predicted XML", type="filepath")

    # Inputs map positionally onto run_pipeline(flat_img, pers_img,
    # mockup_json, xml_gt); its two returned paths fill the two outputs.
    run_btn.click(
        run_pipeline,
        [flat_in, pers_in, mockup_json_in, xml_gt_in],
        [out_img, out_xml],
    )
| |
|
| |
|
if __name__ == "__main__":
    # A public Gradio share link is still requested by default. Set
    # DEMO_SHARE to 0/false/no/off to launch without creating one.
    _share = os.environ.get("DEMO_SHARE", "1").strip().lower() not in {
        "0", "false", "no", "off",
    }
    demo.launch(share=_share)
| |
|