"""Gradio app for cephalogram landmark inference.

Loads a heatmap-based landmark model (HRNet or ViT) from an encrypted
checkpoint, runs inference on an uploaded image, and displays predicted
landmark coordinates, per-landmark confidences, and heatmap overlays.
"""

import argparse

import cv2
import gradio as gr
import numpy as np
import torch

from model import SimpleHRNet, ViTHeatmap
from heatmap_utils import heatmaps_to_coords_dark
from secure_torch_load import secure_torch_load


def parse_args():
    """Parse command-line options for the inference app."""
    parser = argparse.ArgumentParser(description="Cephalogram landmark inference app")
    parser.add_argument("--checkpoint", type=str, default="best.pt.enc",
                        help="Path to model checkpoint")
    parser.add_argument(
        "--device",
        type=str,
        default=("cuda" if torch.cuda.is_available() else "cpu"),
        help="Torch device, e.g. cuda or cpu",
    )
    parser.add_argument("--server-port", type=int, default=44065,
                        help="Port for Gradio app")
    parser.add_argument("--server-name", type=str, default="127.0.0.1",
                        help="Host for Gradio app")
    parser.add_argument("--share", action="store_true",
                        help="Enable public Gradio share link")
    parser.add_argument("--inbrowser", action="store_true",
                        help="Open app in browser on launch")
    return parser.parse_args()


def load_model(checkpoint_path, device):
    """Load a checkpoint and construct the matching model in eval mode.

    Args:
        checkpoint_path: path to the (encrypted) checkpoint file.
        device: torch device string the model is moved to.

    Returns:
        (model, train_args_dict, landmark_symbols_or_None) — the training
        args dict and optional landmark-name list are read from the checkpoint.
    """
    # secure_torch_load replaces a raw torch.load so untrusted pickle data
    # is never deserialized directly.
    ckpt = secure_torch_load(checkpoint_path, map_location="cpu")
    args = ckpt["args"]
    landmark_symbols = ckpt.get("landmark_symbols", None)
    if args["model"] == "hrnet":
        model = SimpleHRNet(num_landmarks=args["num_landmarks"])
    else:
        model = ViTHeatmap(
            num_landmarks=args["num_landmarks"],
            model_name=args["vit_name"],
            pretrained=False,
            img_size=(args["input_height"], args["input_width"]),
        )
    model.load_state_dict(ckpt["model_state_dict"])
    model.to(device)
    model.eval()
    return model, args, landmark_symbols


def get_symbols(n, checkpoint_symbols):
    """Return n landmark names: the checkpoint's list if its length matches,
    otherwise generic placeholders LM_0..LM_{n-1}."""
    if checkpoint_symbols is not None and len(checkpoint_symbols) == n:
        return checkpoint_symbols
    return [f"LM_{i}" for i in range(n)]


def preprocess(image, model_args, device):
    """Resize an image to the model input size and build a batched tensor.

    Returns:
        (tensor of shape [1, C, H, W] scaled to [0, 1],
         (h_orig, w_orig), (h_in, w_in)).
    """
    h_orig, w_orig = image.shape[:2]
    h_in = model_args["input_height"]
    w_in = model_args["input_width"]
    # NOTE(review): assumes a 3-channel (RGB) uint8 image — a grayscale
    # input would fail at permute(2, 0, 1); confirm Gradio always delivers RGB.
    resized = cv2.resize(image, (w_in, h_in))
    tensor = torch.from_numpy(resized).permute(2, 0, 1).float() / 255.0
    tensor = tensor.unsqueeze(0).to(device)
    return tensor, (h_orig, w_orig), (h_in, w_in)


def decode(pred_heatmaps, orig_size, input_size):
    """Convert predicted heatmaps into (x, y) pixel coordinates of the
    original image.

    Applies DARK sub-pixel decoding in heatmap space, then rescales twice:
    heatmap resolution -> network input resolution -> original image.
    """
    h_orig, w_orig = orig_size
    h_in, w_in = input_size
    h_hm, w_hm = pred_heatmaps.shape[2], pred_heatmaps.shape[3]
    coords_hm = heatmaps_to_coords_dark(pred_heatmaps)[0]
    coords_in = coords_hm.clone()
    coords_in[:, 0] *= (w_in / w_hm)
    coords_in[:, 1] *= (h_in / h_hm)
    coords_orig = coords_in.clone()
    coords_orig[:, 0] *= (w_orig / w_in)
    coords_orig[:, 1] *= (h_orig / h_in)
    return coords_orig.cpu().numpy()


def compute_confidence(heatmaps):
    """Per-landmark confidence: the peak activation of each heatmap
    (first batch element)."""
    hm = heatmaps[0].detach().cpu().numpy()
    return hm.reshape(hm.shape[0], -1).max(axis=1)


def draw_points(image, coords, symbols, color=(255, 0, 0)):
    """Draw labelled landmark dots on a copy of the image.

    Points falling outside the image bounds are silently skipped.
    """
    out = image.copy()
    h, w = out.shape[:2]
    for i, (x, y) in enumerate(coords):
        x, y = int(round(float(x))), int(round(float(y)))
        if 0 <= x < w and 0 <= y < h:
            cv2.circle(out, (x, y), 4, color, -1, lineType=cv2.LINE_AA)
            cv2.putText(
                out,
                symbols[i],
                (x + 5, y - 5),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.4,
                color,
                1,
                cv2.LINE_AA,
            )
    return out


def heatmap_overlay(image, heatmap):
    """Blend a single-channel heatmap (JET-colored) over the image."""
    h, w = image.shape[:2]
    hm = cv2.resize(heatmap, (w, h), interpolation=cv2.INTER_LINEAR)
    # Min-max normalize; epsilon guards against a flat (constant) heatmap.
    hm = (hm - hm.min()) / (hm.max() - hm.min() + 1e-6)
    hm_color = cv2.applyColorMap((hm * 255).astype(np.uint8), cv2.COLORMAP_JET)
    # applyColorMap emits BGR; convert so it composites correctly on RGB input.
    hm_color = cv2.cvtColor(hm_color, cv2.COLOR_BGR2RGB)
    return cv2.addWeighted(image, 0.6, hm_color, 0.4, 0)


def make_single_landmark_view(orig, coords, symbols, hm_np, idx):
    """Render one landmark's heatmap overlay with its labelled point."""
    out = heatmap_overlay(orig, hm_np[idx])
    out = draw_points(
        out,
        np.array([coords[idx]], dtype=np.float32),
        [symbols[idx]],
        color=(255, 255, 255),
    )
    return out


def build_demo(model, model_args, checkpoint_symbols, device):
    """Build the Gradio Blocks UI wired to the given model.

    Returns the (un-launched) gr.Blocks demo.
    """
    default_symbols = get_symbols(model_args["num_landmarks"], checkpoint_symbols)

    def run_inference(image):
        # Seven return slots must match the seven outputs of run_button.click.
        if image is None:
            return None, None, None, None, None, None, gr.Dropdown()
        orig = image.copy()
        tensor, orig_size, input_size = preprocess(orig, model_args, device)
        with torch.no_grad():
            heatmaps = model(tensor)
        coords = decode(heatmaps, orig_size, input_size)
        hm_np = heatmaps[0].detach().cpu().numpy()
        conf = compute_confidence(heatmaps)
        symbols = get_symbols(len(coords), checkpoint_symbols)
        pred_overlay = draw_points(orig, coords, symbols)
        summed_overlay = heatmap_overlay(orig, hm_np.sum(axis=0))
        single_overlay = make_single_landmark_view(orig, coords, symbols, hm_np, 0)
        table = [
            [symbols[i], float(coords[i, 0]), float(coords[i, 1]), float(conf[i])]
            for i in range(len(symbols))
        ]
        # Cache everything the landmark-selector callback needs so changing
        # the dropdown does not re-run inference.
        cache = {
            "orig": orig,
            "coords": coords,
            "symbols": symbols,
            "heatmaps": hm_np,
            "pred_overlay": pred_overlay,
            "summed_overlay": summed_overlay,
            "table": table,
        }
        dropdown_update = gr.Dropdown(choices=symbols, value=symbols[0])
        return orig, pred_overlay, summed_overlay, single_overlay, table, cache, dropdown_update

    def update_selected_landmark(selected_landmark, cache):
        # No inference has been run yet — nothing to show.
        if cache is None:
            return None
        symbols = cache["symbols"]
        try:
            idx = symbols.index(selected_landmark)
        except ValueError:
            idx = 0  # fall back to the first landmark on a stale selection
        return make_single_landmark_view(
            cache["orig"],
            cache["coords"],
            cache["symbols"],
            cache["heatmaps"],
            idx,
        )

    with gr.Blocks() as demo:
        gr.Markdown("## Cephalogram Landmark Inference")
        cache_state = gr.State()
        with gr.Row():
            with gr.Column(scale=1, min_width=320):
                input_image = gr.Image(type="numpy", label="Input Image", height=420)
                run_button = gr.Button("Run Inference", variant="primary")
                selected_landmark = gr.Dropdown(
                    choices=default_symbols,
                    value=default_symbols[0],
                    label="Landmark Heatmap Selector",
                )
            with gr.Column(scale=2):
                with gr.Row():
                    out_orig = gr.Image(label="Original", height=284)
                    out_pred = gr.Image(label="Predictions", height=284)
                with gr.Row():
                    out_sum = gr.Image(label="All-Landmark Heatmap Overlay", height=284)
                    out_single = gr.Image(label="Selected Landmark Heatmap Overlay", height=284)
                out_table = gr.Dataframe(
                    headers=["Landmark", "X", "Y", "Confidence"],
                    label="Predictions",
                    interactive=False,
                    wrap=True,
                )
        run_button.click(
            fn=run_inference,
            inputs=[input_image],
            outputs=[
                out_orig,
                out_pred,
                out_sum,
                out_single,
                out_table,
                cache_state,
                selected_landmark,
            ],
        )
        selected_landmark.change(
            fn=update_selected_landmark,
            inputs=[selected_landmark, cache_state],
            outputs=[out_single],
        )
    return demo


if __name__ == "__main__":
    cli_args = parse_args()
    model, model_args, checkpoint_symbols = load_model(cli_args.checkpoint, cli_args.device)
    # TEMPORARY HARD CODE: override the checkpoint's landmark names until
    # checkpoints reliably embed them.
    checkpoint_symbols = [
        "A", "ANS", "B", "Me", "N", "Or", "Pog", "PNS", "Pn", "R", "S", "Ar",
        "Co", "Gn", "Go", "Po", "LPM", "LIT", "LMT", "UPM", "UIA", "UIT",
        "UMT", "LIA", "Li", "Ls", "N`", "Pog`", "Sn",
    ]
    demo = build_demo(model, model_args, checkpoint_symbols, cli_args.device)
    # Wire the CLI server options through to launch(); previously these were
    # commented out, leaving --server-name/--server-port/--share/--inbrowser dead.
    demo.launch(
        server_name=cli_args.server_name,
        server_port=cli_args.server_port,
        share=cli_args.share,
        inbrowser=cli_args.inbrowser,
    )