# """ # Visualization Utilities # ======================= # Plotting and overlay functions for dental teeth segmentation results. # Includes FDI-aware labelling, per-quadrant colour coding, and training curves. # """ # import os # import numpy as np # import cv2 # import matplotlib # matplotlib.use("Agg") #sets the backend (for Docker use) # import matplotlib.pyplot as plt # import matplotlib.patches as patches # import matplotlib.cm as cm # from typing import Optional, List, Dict # from pathlib import Path # import pandas as pd # parent_dir = os.path.abspath("..") # import sys # sys.path.append(parent_dir) # import utils.preprocessing # import importlib # importlib.reload(utils.preprocessing) # from utils.preprocessing import ( # count_teeth_per_image, # class_frequency # ) # # FDI quadrant colours for intuitive display # QUADRANT_COLORS = { # "UR": "#4A90D9", # upper right — blue # "UL": "#E87040", # upper left — orange # "LL": "#2ECC71", # lower left — green # "LR": "#9B59B6", # lower right — purple # } # FDI_TO_QUADRANT = { # **{fdi: "UR" for fdi in range(11, 19)}, # **{fdi: "UL" for fdi in range(21, 29)}, # **{fdi: "LL" for fdi in range(31, 39)}, # **{fdi: "LR" for fdi in range(41, 49)}, # } # def _quadrant_color(class_name): # """ # Pick a colour for a class name based on FDI quadrant. # AKUDENTAL categories: # "11 - Central Incisor" → extract FDI number → quadrant color # "Bridge" / "Filling-Crown" / "Implant" → gray # """ # # Try to extract FDI number from name e.g. "11 - Central Incisor" # try: # fdi = int(class_name.split(" ")[0]) # if 11 <= fdi <= 18: return QUADRANT_COLORS["UR"] # if 21 <= fdi <= 28: return QUADRANT_COLORS["UL"] # if 31 <= fdi <= 38: return QUADRANT_COLORS["LL"] # if 41 <= fdi <= 48: return QUADRANT_COLORS["LR"] # except (ValueError, IndexError): # pass # # Non-FDI categories: Bridge, Filling-Crown, Implant # return "#AAAAAA" # gray # def apply_masks(image, masks,class_names=None, alpha = 0.45): # """ # Draw semi-transparent tooth masks(model predictions) on the # original jpg. # Args: # image: RGB image [H, W, 3] uint8. # masks: Bool masks [H, W, N]. # class_names: Class name per mask (used for color). # alpha: Mask opacity. # """ # output = image.copy().astype(np.float32) # for i in range(masks.shape[-1]): # if class_names and i < len(class_names): # hex_col = _quadrant_color(class_names[i]) # r, g, b = int(hex_col[1:3], 16), int(hex_col[3:5], 16), int(hex_col[5:7], 16) # colour = np.array([r, g, b], dtype=np.float32) # else: # cmap = cm.get_cmap("tab20", max(1, masks.shape[-1])) # colour = np.array(cmap(i)[:3]) * 255 # for c in range(3): # output[:, :, c] = np.where( # masks[:, :, i], # output[:, :, c] * (1 - alpha) + colour[c] * alpha, # output[:, :, c], # ) # return output.astype(np.uint8) # def draw_bounding_boxes(image,rois,class_ids,scores,class_names): # """ # Draw bounding boxes with label and confidence score. # Args: # image: (H,W,3) uint8 .jpg image # roi: bounding boxes (one 4-element array for each N teeth) # class_ids: array of class indices(one per tooth) # scores: confidence score (N,) # class_names: List of class names ["BG","tooth"] # """ # out = image.copy() # for i, roi in enumerate(rois): # y1, x1, y2, x2 = roi # name = class_names[class_ids[i]] if class_ids[i] < len(class_names) else "unknown" # hex_col = _quadrant_color(name) # color = (int(hex_col[1:3], 16), int(hex_col[3:5], 16), int(hex_col[5:7], 16)) # label = f"{name} {scores[i]:.0%}" # cv2.rectangle(out, (x1, y1), (x2, y2), color, 2) # cv2.putText(out, label, (x1, max(y1 - 5, 12)), # cv2.FONT_HERSHEY_SIMPLEX, 0.38, color, 1, cv2.LINE_AA) # return out # def visualize_prediction(image, result, class_names, save_path=None, show=False): # """ # Visualizes both colored masks and bounding boxes # Args: # image: original Xray uint8 # result: dict from model.detect (prediction dictionary) # class_names: ['bg','11-central incisor'] # save_path: optional path to where to save the image # show: call plt.show() or not # """ # masks = result.get("masks", np.zeros((*image.shape[:2], 0), dtype=bool)) # rois = result.get("rois", np.zeros((0, 4), dtype=int)) # class_ids = result.get("class_ids", np.array([], dtype=int)) # scores = result.get("scores", np.array([], dtype=float)) # det_names = [class_names[cid] for cid in class_ids if cid < len(class_names)] # annotated = apply_masks(image, masks, det_names) # annotated = draw_bounding_boxes(annotated, rois, class_ids, scores, class_names) # fig, axes = plt.subplots(1, 2, figsize=(16, 5)) # axes[0].imshow(image) # axes[0].set_title("Original Panoramic X-ray") # axes[0].axis("off") # axes[1].imshow(annotated) # axes[1].set_title(f"Segmentation — {masks.shape[-1]} teeth detected") # axes[1].axis("off") # # Legend — quadrant colors + gray for restorations # legend_elements = ([patches.Patch(facecolor=c, label=q) for q, c in QUADRANT_COLORS.items()] # + [patches.Patch(facecolor="#AAAAAA", label="Bridge/Implant/Crown")]) # axes[1].legend(handles=legend_elements, loc="lower right", # fontsize=8, title="Category", framealpha=0.8) # plt.tight_layout() # if save_path: # os.makedirs(os.path.dirname(save_path), exist_ok=True) # plt.savefig(save_path, dpi=150, bbox_inches="tight") # if show: # plt.show() # plt.close(fig) # return annotated # def plot_class_distribution(coco, save_path = None): # """ # Bar chart showing how many times each tooth was annotated # across all images we have. # Grouped by quadrant with quadrant colours. # Args: # coco: loaded coco dict # save_path: optional path for where to save the chart # """ # freq = class_frequency(coco) # cat_map = {c['id']: c['name'] for c in coco['categories']} # cat_ids = sorted(freq.keys()) # names = [cat_map.get(i,str(i)) for i in cat_ids] # counts = [freq[i] for i in cat_ids] # colors = [_quadrant_color(n) for n in names] # fig, ax = plt.subplots(figsize=(20, 5)) # ax.bar(range(len(cat_ids)), counts, color=colors, edgecolor="white") # ax.set_xticks(range(len(cat_ids))) # ax.set_xticklabels(names, rotation=45, ha='right', fontsize=7) # ax.set_ylabel("Annotation count") # ax.set_title("Annotation Frequency per Category (AKUDENTAL)") # ax.grid(axis="y", alpha=0.3) # legend_elements = [ # patches.Patch(facecolor=c, label=q) # for q, c in QUADRANT_COLORS.items() # ] + [patches.Patch(facecolor="#AAAAAA", label="Bridge/Implant/Crown")] # ax.legend(handles=legend_elements, fontsize=8) # plt.tight_layout() # if save_path: # plt.savefig(save_path, dpi=150, bbox_inches="tight") # plt.close(fig) # def plot_teeth_per_image(coco, save_path = None): # """ # Bar chart of tooth count per image. # Args: # coco - annotation directory # save_path - full path where to save the image # """ # counts = count_teeth_per_image(coco) # names = sorted(counts.keys()) # values = [counts[n] for n in names] # fig, ax = plt.subplots(figsize=(12, 4)) # ax.bar(range(len(names)), values, color="#185FA5", edgecolor="white") # ax.axhline(np.mean(values), color="#D85A30", linestyle="--", linewidth=1.5, # label=f"Mean = {np.mean(values):.1f}") # ax.set_xlabel("Image") # ax.set_ylabel("Number of annotated teeth") # ax.set_title("Annotation count per image") # ax.legend() # ax.grid(axis="y", alpha=0.3) # plt.xticks(rotation=45, ha="right") # plt.tight_layout() # if save_path: # plt.savefig(save_path, dpi=150, bbox_inches="tight") # plt.close(fig) """ Visualization Utilities ======================= Plotting and overlay functions for dental teeth segmentation results. Includes FDI-aware labelling, per-quadrant colour coding, and training curves. Supports both binary mode ('tooth') and FDI multi-class mode (35 categories). """ import os import numpy as np import cv2 import matplotlib matplotlib.use("Agg") # sets the backend (for Docker use) import matplotlib.pyplot as plt import matplotlib.patches as patches import matplotlib.cm as cm from pathlib import Path import pandas as pd parent_dir = os.path.abspath("..") import sys sys.path.append(parent_dir) import utils.preprocessing import importlib importlib.reload(utils.preprocessing) from utils.preprocessing import ( count_teeth_per_image, class_frequency, ) # FDI quadrant colours for intuitive display QUADRANT_COLORS = { "UR": "#4A90D9", # upper right — blue "UL": "#E87040", # upper left — orange "LL": "#2ECC71", # lower left — green "LR": "#9B59B6", # lower right — purple } FDI_TO_QUADRANT = { **{fdi: "UR" for fdi in range(11, 19)}, **{fdi: "UL" for fdi in range(21, 29)}, **{fdi: "LL" for fdi in range(31, 39)}, **{fdi: "LR" for fdi in range(41, 49)}, } def _quadrant_color(class_name): """ Pick a colour for a class name based on FDI quadrant. Binary mode: 'tooth' → blue FDI mode: '11 - Central Incisor' → quadrant color Other: 'Bridge'/'Implant'/etc. → gray """ # Binary mode — single tooth class if class_name == 'tooth': return "#4A90D9" # blue # FDI mode — "11 - Central Incisor" → extract FDI number → quadrant color try: fdi = int(class_name.split(" ")[0]) if 11 <= fdi <= 18: return QUADRANT_COLORS["UR"] if 21 <= fdi <= 28: return QUADRANT_COLORS["UL"] if 31 <= fdi <= 38: return QUADRANT_COLORS["LL"] if 41 <= fdi <= 48: return QUADRANT_COLORS["LR"] except (ValueError, IndexError): pass # Non-FDI categories: Bridge, Filling-Crown, Implant return "#AAAAAA" def apply_masks(image, masks, class_names=None, alpha=0.45): """ Draw semi-transparent tooth masks on the original image. Args: image: RGB image [H, W, 3] uint8. masks: Bool masks [H, W, N]. class_names: Class name per mask (used for color). alpha: Mask opacity (0=transparent, 1=opaque). """ output = image.copy().astype(np.float32) for i in range(masks.shape[-1]): if class_names and i < len(class_names): hex_col = _quadrant_color(class_names[i]) r = int(hex_col[1:3], 16) g = int(hex_col[3:5], 16) b = int(hex_col[5:7], 16) colour = np.array([r, g, b], dtype=np.float32) else: cmap = cm.get_cmap("tab20", max(1, masks.shape[-1])) colour = np.array(cmap(i)[:3]) * 255 for c in range(3): output[:, :, c] = np.where( masks[:, :, i], output[:, :, c] * (1 - alpha) + colour[c] * alpha, output[:, :, c], ) return output.astype(np.uint8) def draw_bounding_boxes(image, rois, class_ids, scores, class_names): """ Draw bounding boxes with label and confidence score. Args: image: (H, W, 3) uint8 image. rois: (N, 4) bounding boxes [y1, x1, y2, x2]. class_ids: (N,) class indices. scores: (N,) confidence scores. class_names: List of class names e.g. ['BG', 'tooth']. """ out = image.copy() for i, roi in enumerate(rois): y1, x1, y2, x2 = roi name = class_names[class_ids[i]] if class_ids[i] < len(class_names) else "unknown" hex_col = _quadrant_color(name) color = (int(hex_col[1:3], 16), int(hex_col[3:5], 16), int(hex_col[5:7], 16)) label = f"{name} {scores[i]:.0%}" cv2.rectangle(out, (x1, y1), (x2, y2), color, 2) cv2.putText(out, label, (x1, max(y1 - 5, 12)), cv2.FONT_HERSHEY_SIMPLEX, 0.38, color, 1, cv2.LINE_AA) return out def visualize_prediction(image, result, class_names, save_path=None, show=False): """ Visualizes both colored masks and bounding boxes. Supports binary mode (['BG', 'tooth']) and FDI mode (['BG', '11 - Central Incisor', ...]). Args: image: Original X-ray uint8 (H, W, 3). result: Dict from model.detect() {masks, rois, class_ids, scores}. class_names: ['BG', 'tooth'] for binary or full FDI list for multi-class. save_path: Optional path to save figure. show: Whether to call plt.show(). """ masks = result.get("masks", np.zeros((*image.shape[:2], 0), dtype=bool)) rois = result.get("rois", np.zeros((0, 4), dtype=int)) class_ids = result.get("class_ids", np.array([], dtype=int)) scores = result.get("scores", np.array([], dtype=float)) det_names = [class_names[cid] for cid in class_ids if cid < len(class_names)] annotated = apply_masks(image, masks, det_names) annotated = draw_bounding_boxes(annotated, rois, class_ids, scores, class_names) fig, axes = plt.subplots(1, 2, figsize=(16, 5)) axes[0].imshow(image) axes[0].set_title("Original Panoramic X-ray") axes[0].axis("off") axes[1].imshow(annotated) axes[1].set_title(f"Segmentation — {masks.shape[-1]} teeth detected") axes[1].axis("off") # Legend — adapt to binary vs FDI multi-class mode if 'tooth' in class_names: # Binary mode — single color legend_elements = [ patches.Patch(facecolor="#4A90D9", label="tooth") ] else: # FDI multi-class mode — quadrant colors + gray for restorations legend_elements = ( [patches.Patch(facecolor=c, label=q) for q, c in QUADRANT_COLORS.items()] + [patches.Patch(facecolor="#AAAAAA", label="Bridge/Implant/Crown")] ) axes[1].legend(handles=legend_elements, loc="lower right", fontsize=8, title="Category", framealpha=0.8) plt.tight_layout() if save_path: os.makedirs(os.path.dirname(save_path), exist_ok=True) plt.savefig(save_path, dpi=150, bbox_inches="tight") if show: plt.show() plt.close(fig) return annotated def plot_class_distribution(coco, save_path=None): """ Bar chart showing annotation frequency per category. FDI teeth colored by quadrant, Bridge/Implant/Crown in gray. Args: coco: Loaded COCO annotation dict. save_path: Optional path to save the chart. """ freq = class_frequency(coco) cat_map = {c['id']: c['name'] for c in coco['categories']} cat_ids = sorted(freq.keys()) names = [cat_map.get(i, str(i)) for i in cat_ids] counts = [freq[i] for i in cat_ids] colors = [_quadrant_color(n) for n in names] fig, ax = plt.subplots(figsize=(20, 5)) ax.bar(range(len(cat_ids)), counts, color=colors, edgecolor="white") ax.set_xticks(range(len(cat_ids))) ax.set_xticklabels(names, rotation=45, ha='right', fontsize=7) ax.set_ylabel("Annotation count") ax.set_title("Annotation Frequency per Category (AKUDENTAL)") ax.grid(axis="y", alpha=0.3) legend_elements = ( [patches.Patch(facecolor=c, label=q) for q, c in QUADRANT_COLORS.items()] + [patches.Patch(facecolor="#AAAAAA", label="Bridge/Implant/Crown")] ) ax.legend(handles=legend_elements, fontsize=8) plt.tight_layout() if save_path: plt.savefig(save_path, dpi=150, bbox_inches="tight") plt.close(fig) def plot_teeth_per_image(coco, save_path=None): """ Histogram of annotation count per image. Args: coco: Loaded COCO annotation dict. save_path: Optional path to save the chart. """ counts = count_teeth_per_image(coco) values = list(counts.values()) fig, ax = plt.subplots(figsize=(10, 4)) ax.hist(values, bins=20, color="#185FA5", edgecolor="white") ax.axvline(np.mean(values), color="#D85A30", linestyle="--", linewidth=1.5, label=f"Mean = {np.mean(values):.1f}") ax.set_xlabel("Number of annotated instances") ax.set_ylabel("Number of images") ax.set_title("Annotation count per image (AKUDENTAL)") ax.legend() ax.grid(axis="y", alpha=0.3) plt.tight_layout() if save_path: plt.savefig(save_path, dpi=150, bbox_inches="tight") plt.close(fig)