| import gradio as gr |
| import cv2 |
| import numpy as np |
| import zipfile |
| import io |
| import pandas as pd |
| import matplotlib.pyplot as plt |
| import plotly.express as px |
|
|
| import numpy as np |
| import matplotlib.pyplot as plt |
| from PIL import Image |
| from sklearn.preprocessing import MinMaxScaler |
| import os |
| import cv2 |
| from mpl_toolkits.axes_grid1 import make_axes_locatable |
| from pathlib import Path |
| import SimpleITK as sitk |
| import matplotlib.patches as mpatches |
| import napari |
|
|
|
|
| |
|
|
def rgb2gray(rgb):
    """Convert an RGB image to grayscale with ITU-R BT.601 luma weights.

    Args:
        rgb: Array of shape (H, W, C) with C >= 3, or an already-gray
            (H, W) array.

    Returns:
        (H, W) float array of luminance values; a 2-D input is returned
        unchanged.
    """
    # Nothing to do for an already single-channel image.
    if rgb.ndim == 2:
        return rgb
    # Weighted sum over the first three channels (any alpha is ignored).
    luma_weights = [0.2989, 0.5870, 0.1140]
    return np.dot(rgb[..., :3], luma_weights)
|
|
def Normalize(image):
    """Min-max scale an array into the [0, 1] range.

    A constant image (max == min) maps to an all-zero array of the same
    shape and dtype instead of dividing by zero.
    """
    lo = np.min(image)
    hi = np.max(image)
    span = hi - lo
    if span == 0:
        return np.zeros_like(image)
    return (image - lo) / span
| |
def Normalize_percentiles(image, low_perc, upp_perc):
    """Contrast-stretch an image between two percentiles, clipped to [0, 1].

    Args:
        image: Input array.
        low_perc: Lower percentile (0-100) mapped to 0.
        upp_perc: Upper percentile (0-100) mapped to 1.

    Returns:
        Float array in [0, 1]; all zeros if both percentiles coincide.
    """
    lo = np.percentile(image, low_perc)
    hi = np.percentile(image, upp_perc)
    if hi == lo:
        return np.zeros_like(image)
    stretched = (image - lo) / (hi - lo)
    # Values outside the percentile window saturate at the range ends.
    return np.clip(stretched, 0, 1)
|
|
| |
|
|
def prepare_base_image(path):
    """Load an image from disk and build an 8-bit normalized grayscale copy.

    Args:
        path: Filesystem path to the image file.

    Returns:
        (final_img, img): the uint8 min-max-normalized grayscale image and
        the raw image exactly as read by OpenCV (BGR for color files).

    Raises:
        FileNotFoundError: If OpenCV cannot read the file.
    """
    img = cv2.imread(path)
    # cv2.imread signals failure by returning None instead of raising;
    # fail loudly here rather than with an opaque error inside rgb2gray.
    if img is None:
        raise FileNotFoundError(f"Could not read image: {path}")
    gray = rgb2gray(img)
    norm = Normalize(gray)
    final_img = (norm * 255).astype(np.uint8)

    return final_img, img
|
|
def resize_image(image, scale_factor):
    """Rescale an image by a uniform factor using Lanczos interpolation.

    Args:
        image: Input image of shape (H, W[, C]).
        scale_factor: Multiplier applied to both width and height.

    Returns:
        The resized image. Output dimensions are clamped to at least
        1 pixel so a very small scale factor cannot request an invalid
        zero-sized image from cv2.resize.
    """
    h, w = image.shape[:2]
    new_size = (max(1, int(w * scale_factor)), max(1, int(h * scale_factor)))
    return cv2.resize(image, new_size, interpolation=cv2.INTER_LANCZOS4)
|
|
def apply_clahe(image, use_clahe=True):
    """Enhance the local contrast of a single-channel 8-bit image.

    Args:
        image: Grayscale uint8 image.
        use_clahe: When True, apply CLAHE (clip limit 6.0, 8x8 tiles);
            otherwise fall back to plain global histogram equalization.

    Returns:
        The contrast-enhanced image.
    """
    if not use_clahe:
        return cv2.equalizeHist(image)
    clahe = cv2.createCLAHE(clipLimit=6.0, tileGridSize=(8, 8))
    return clahe.apply(image)
|
|
def align_centers(fixed_img, moving_img):
    """Translate `moving_img` so its center of mass matches `fixed_img`'s.

    Centers of mass come from zeroth/first image moments. If either image
    has zero total mass (all black) the alignment is skipped and the moving
    image is returned untouched with a (0, 0) shift.

    Returns:
        (centered_moving, (shift_x, shift_y)): the moving image resampled
        onto the fixed image's canvas after the shift, and the applied
        integer translation.
    """
    h_fixed, w_fixed = fixed_img.shape[:2]

    def center_of_mass(img, label):
        # Returns (cx, cy) or None when the image carries no mass.
        m = cv2.moments(img)
        if m["m00"] == 0:
            print(f"Warning: {label} image is empty/black. Skipping CoM alignment.")
            return None
        return int(m["m10"] / m["m00"]), int(m["m01"] / m["m00"])

    com_fixed = center_of_mass(fixed_img, "Fixed")
    if com_fixed is None:
        return moving_img, (0, 0)

    com_moving = center_of_mass(moving_img, "Moving")
    if com_moving is None:
        return moving_img, (0, 0)

    shift_x = com_fixed[0] - com_moving[0]
    shift_y = com_fixed[1] - com_moving[1]

    # 2x3 affine matrix encoding a pure translation.
    T = np.float32([[1, 0, shift_x], [0, 1, shift_y]])
    centered_moving = cv2.warpAffine(
        moving_img,
        T,
        (w_fixed, h_fixed),
        flags=cv2.INTER_LINEAR,
        borderMode=cv2.BORDER_CONSTANT,
        borderValue=0,
    )
    return centered_moving, (shift_x, shift_y)
|
|
def Gaussian_blur(image, kernel_size=5, sigma=0):
    """Apply a Gaussian blur with a square kernel.

    Args:
        image: Input image.
        kernel_size: Kernel side length; even values are bumped up by one
            because OpenCV requires an odd kernel size.
        sigma: Gaussian standard deviation; 0 lets OpenCV derive it from
            the kernel size.

    Returns:
        The blurred image.
    """
    if kernel_size % 2 == 0:
        kernel_size += 1
        # Bug fix: the original print was missing the f-prefix, so it
        # showed the literal text "{kernel_size}" instead of the value.
        print(f"Kernel adjust:{kernel_size}")
    return cv2.GaussianBlur(image, (kernel_size, kernel_size), sigma)
| |
def find_edges(image, sigma=0.33):
    """Run Canny edge detection with thresholds auto-derived from the median.

    lower = (1 - sigma) * median, upper = (1 + sigma) * median, both
    clamped into [0, 255]. Color inputs are converted to grayscale first.
    """
    if image.ndim == 3:
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    median = np.median(image)
    lower = int(max(0, (1.0 - sigma) * median))
    upper = int(min(255, (1.0 + sigma) * median))
    return cv2.Canny(image, lower, upper)
|
|
def find_edges_binary(image):
    """Binarize an image with Otsu's threshold, then extract Canny edges."""
    # Otsu picks the threshold automatically; the 0 here is ignored.
    _, otsu = cv2.threshold(image, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    return cv2.Canny(otsu, 30, 100)
|
|
def binary_mask(image, kernel_size=5):
    """Turn a bright-on-dark image into a clean solid silhouette mask.

    Pipeline: grayscale -> Gaussian blur (7x7) -> Otsu binarization ->
    morphological closing (fills internal holes) -> opening (removes
    isolated specks).

    Args:
        image: Grayscale or BGR input; the tissue is assumed bright on a
            dark background.
        kernel_size: Side of the elliptical structuring element. Larger
            values remove bigger noise spots but may smooth shape details.

    Returns:
        A binary (0/255) mask containing only the main tissue shape.
    """
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if image.ndim == 3 else image.copy()

    # Light smoothing so Otsu's threshold is not driven by pixel noise.
    blurred = cv2.GaussianBlur(gray, (7, 7), 0)
    _, binary = cv2.threshold(blurred, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))

    # Close first to make the silhouette solid, then open to drop specks.
    solid = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel, iterations=2)
    return cv2.morphologyEx(solid, cv2.MORPH_OPEN, kernel, iterations=1)
|
|
| |
|
|
def get_symmetry(image, label):
    """Apply one of the 8 square symmetries (dihedral group D8) to an image.

    Labels: R0-R3 rotate by 0/90/180/270 degrees, M1/M2 mirror
    horizontally/vertically, D1/D2 are the two diagonal reflections
    (transpose and anti-transpose). Unknown labels return the image as-is.
    """
    transforms = {
        "R0": lambda img: img,
        "R1": lambda img: cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE),
        "R2": lambda img: cv2.rotate(img, cv2.ROTATE_180),
        "R3": lambda img: cv2.rotate(img, cv2.ROTATE_90_COUNTERCLOCKWISE),
        "M1": lambda img: cv2.flip(img, 1),
        "M2": lambda img: cv2.flip(img, 0),
        "D1": lambda img: cv2.transpose(img),
        "D2": lambda img: cv2.flip(cv2.transpose(img), 0),
    }
    op = transforms.get(label)
    return op(image) if op is not None else image
|
|
def find_best_match_pixel(img_1, img_2, image_name):
    """Brute-force template matching over all 8 square symmetries.

    The smaller (by area) of the two images is used as the template and
    slid across the larger one with normalized cross-correlation
    (cv2.matchTemplate, TM_CCOEFF_NORMED). If a rotated template could
    exceed the big image in some dimension, the big image is zero-padded
    first so every symmetry fits.

    Args:
        img_1, img_2: Images to compare; either may be the larger one.
        image_name: Identifier copied into each result record.

    Returns:
        (best_result, results): the highest-scoring record and the full
        8-entry list. Each record holds: name, symmetry label, correlation
        score, top-left match location expressed in the *unpadded*
        big-image coordinates, and the symmetry-transformed template
        (stored under the key "raman_img").
    """
    results = []

    h1, w1 = img_1.shape[:2]
    h2, w2 = img_2.shape[:2]

    area1 = h1 * w1
    area2 = h2 * w2

    # The larger image by area becomes the search canvas, the other the template.
    if area1 > area2:
        image_big = img_1
        image_small = img_2
    else:
        image_big = img_2
        image_small = img_1

    h_big, w_big = image_big.shape[:2]
    h_small, w_small = image_small.shape[:2]

    # After a 90-degree rotation the template's longest side can land on
    # either axis, so the canvas must cover max(h_small, w_small) in both.
    max_dim_small = max(h_small, w_small)

    missing_h = max(0, max_dim_small - h_big)
    missing_w = max(0, max_dim_small - w_big)

    pad_top = 0
    pad_left = 0

    if missing_h > 0 or missing_w > 0:
        # Extra 10% margin so the match is not forced flush to the border.
        margin = int(max_dim_small * 0.1)

        if missing_h > 0:
            pad_top = (missing_h // 2) + margin
            pad_bottom = (missing_h // 2) + margin
        else:
            pad_top = 0
            pad_bottom = 0

        if missing_w > 0:
            pad_left = (missing_w // 2) + margin
            pad_right = (missing_w // 2) + margin
        else:
            pad_left = 0
            pad_right = 0

        image_big = cv2.copyMakeBorder(
            image_big,
            pad_top, pad_bottom, pad_left, pad_right,
            cv2.BORDER_CONSTANT, value=0
        )

    symmetries = ["R0", "R1", "R2", "R3", "M1", "M2", "D1", "D2"]

    for sym in symmetries:
        current_small = get_symmetry(image_small, sym)

        res = cv2.matchTemplate(image_big, current_small, cv2.TM_CCOEFF_NORMED)

        # Only the best (maximum) correlation peak is kept per symmetry.
        _, max_val, _, max_loc = cv2.minMaxLoc(res)

        # Convert back from padded-canvas to original big-image coordinates.
        real_x = max_loc[0] - pad_left
        real_y = max_loc[1] - pad_top

        results.append({
            "name": image_name,
            "symmetry": sym,
            "score": max_val,
            "location": (real_x, real_y),
            "raman_img": current_small
        })

    best_result = max(results, key=lambda x: x['score'])

    return best_result, results
|
|
def find_best_match_features(img_1, img_2, image_name):
    """Tests all 8 symmetries of img_1 against img_2 using SIFT features.

    Pipeline per symmetry: SIFT keypoints/descriptors -> brute-force KNN
    matching (L2 norm, k=2) -> Lowe's ratio test (0.75) -> RANSAC
    homography estimation. The score is the RANSAC inlier count.

    Args:
        img_1: Image whose 8 symmetry variants are tried (the "raman" side).
        img_2: Reference image, described once outside the loop.
        image_name: Identifier copied into each result record.

    Returns:
        (best_result, results): record with the highest inlier count and
        the list of all attempted symmetries; (None, []) if no symmetry
        produced enough keypoints. Each record contains the homography (or
        None), keypoints, ratio-test matches and the RANSAC inlier mask.
    """
    sift = cv2.SIFT_create()

    # The reference image is static, so describe it once.
    kp2, des2 = sift.detectAndCompute(img_2, None)

    results = []

    symmetries = ["R0", "R1", "R2", "R3", "M1", "M2", "D1", "D2"]

    for sym in symmetries:
        current_raman = get_symmetry(img_1, sym)

        kp1, des1 = sift.detectAndCompute(current_raman, None)

        # Too few features to attempt a meaningful match for this symmetry.
        if des1 is None or des2 is None or len(kp1) < 5 or len(kp2) < 5:
            continue

        bf = cv2.BFMatcher(cv2.NORM_L2)
        matches = bf.knnMatch(des1, des2, k=2)

        # Lowe's ratio test: keep a match only if it is clearly better
        # than the second-best candidate.
        good_matches = []
        for m, n in matches:
            if m.distance < 0.75 * n.distance:
                good_matches.append(m)

        score = 0
        H = None
        matches_mask = []

        # findHomography needs at least 4 point correspondences.
        if len(good_matches) >= 4:
            src_pts = np.float32([kp1[m.queryIdx].pt for m in good_matches]).reshape(-1, 1, 2)
            dst_pts = np.float32([kp2[m.trainIdx].pt for m in good_matches]).reshape(-1, 1, 2)

            # RANSAC with a 5-pixel reprojection threshold.
            H, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 5.0)

            if mask is not None:
                matches_mask = mask.ravel().tolist()
                # Inlier count is the comparison score across symmetries.
                score = np.sum(matches_mask)

        results.append({
            "name": image_name,
            "symmetry": sym,
            "score": score,
            "homography": H,
            "raman_img": current_raman,
            "keypoints_1": kp1,
            "good_matches": good_matches,
            "matches_mask": matches_mask
        })

    if not results:
        return None, []

    best_result = max(results, key=lambda x: x['score'])

    return best_result, results
|
|
|
|
def find_best_match_fft(img_1, img_2, image_name):
    """Tests all 8 symmetries using FFT phase correlation.

    Both images are zero-padded (centered) to a common size and a Hanning
    window is applied inside cv2.phaseCorrelate to reduce edge artifacts.
    The phase-correlation response peak is the score; the returned shift
    is the (dx, dy) translation between the two images.

    Args:
        img_1: Image whose 8 symmetry variants are tried; assumed 2-D
            (single channel) since pad_to_size unpacks exactly (h, w).
        img_2: Reference image, padded once outside the loop.
        image_name: Identifier copied into each result record.

    Returns:
        (best_result, results): the record with the highest response and
        the list of all successful symmetries; (None, []) if every attempt
        raised an error.
    """
    # Common canvas large enough for both images.
    h_max = max(img_1.shape[0], img_2.shape[0])
    w_max = max(img_1.shape[1], img_2.shape[1])

    def pad_to_size(img, th, tw):
        # Center `img` on a (th, tw) float32 canvas of zeros.
        h, w = img.shape
        if h == th and w == tw: return img.astype(np.float32)

        padded = np.zeros((th, tw), dtype=np.float32)

        y_off = (th - h) // 2
        x_off = (tw - w) // 2
        padded[y_off:y_off+h, x_off:x_off+w] = img
        return padded

    img_2_float = pad_to_size(img_2, h_max, w_max)

    # Note: createHanningWindow takes (width, height) order.
    window = cv2.createHanningWindow((w_max, h_max), cv2.CV_32F)

    results = []

    symmetries = ["R0", "R1", "R2", "R3", "M1", "M2", "D1", "D2"]

    for sym in symmetries:
        current_raman = get_symmetry(img_1, sym)

        current_raman_float = pad_to_size(current_raman, h_max, w_max)

        try:
            # `response` is the normalized peak height — used as the score.
            shift, response = cv2.phaseCorrelate(current_raman_float, img_2_float, window=window)

            dx, dy = shift

            results.append({
                "name": image_name,
                "symmetry": sym,
                "score": response,
                "shift_xy": (dx, dy),
                "raman_img": current_raman
            })

        except Exception as e:
            # One failed symmetry should not abort the remaining attempts.
            print(f"FFT Error on {sym}: {e}")
            continue

    if not results:
        return None, []

    best_result = max(results, key=lambda x: x['score'])

    return best_result, results
|
|
| |
|
|
def fine_tune_registration(fixed_img_cv, moving_img_cv, transform_type):
    """Complete registration using Mutual Information (SimpleITK).

    Registers `moving_img_cv` onto `fixed_img_cv` with Mattes mutual
    information as the similarity metric; the optimizer depends on the
    transform (L-BFGS-B for BSpline, gradient descent otherwise).

    Args:
        fixed_img_cv: Reference image as a grayscale numpy array.
        moving_img_cv: Image to be registered (grayscale numpy array).
        transform_type: One of "Rigid", "Similarity", "Affine", "BSpline".

    Returns:
        (registered_img, final_transform): the resampled moving image as a
        numpy array and the fitted sitk.Transform; on optimizer failure the
        original moving image and None are returned instead.

    Raises:
        ValueError: If `transform_type` is not one of the supported names.
    """
    # SimpleITK works on its own image type; convert from numpy float32.
    fixed = sitk.GetImageFromArray(fixed_img_cv.astype(np.float32))
    moving = sitk.GetImageFromArray(moving_img_cv.astype(np.float32))

    # Initialize each transform by aligning the geometric centers.
    if transform_type == "Rigid":
        # Rotation + translation only.
        initial_transform = sitk.CenteredTransformInitializer(
            fixed, moving,
            sitk.Euler2DTransform(),
            sitk.CenteredTransformInitializerFilter.GEOMETRY
        )

    elif transform_type == "Similarity":
        # Rigid plus isotropic scaling.
        initial_transform = sitk.CenteredTransformInitializer(
            fixed, moving,
            sitk.Similarity2DTransform(),
            sitk.CenteredTransformInitializerFilter.GEOMETRY
        )

    elif transform_type == "Affine":
        # Full 2x2 linear part plus translation (shear allowed).
        initial_transform = sitk.CenteredTransformInitializer(
            fixed, moving,
            sitk.AffineTransform(2),
            sitk.CenteredTransformInitializerFilter.GEOMETRY
        )

    elif transform_type == "BSpline":
        # NOTE(review): init_rigid and grid_physical_spacing are computed
        # but never used — the BSpline is initialized from mesh_size alone.
        # Presumably leftovers from an earlier multi-stage setup; confirm.
        init_rigid = sitk.CenteredTransformInitializer(
            fixed, moving, sitk.Euler2DTransform(), sitk.CenteredTransformInitializerFilter.GEOMETRY
        )

        grid_physical_spacing = [50.0, 50.0]
        mesh_size = [3, 3]

        initial_transform = sitk.BSplineTransformInitializer(fixed, mesh_size)

    else:
        raise ValueError("Use 'Rigid', 'Similarity', 'Affine' or 'BSpline'.")

    R = sitk.ImageRegistrationMethod()

    # Mutual information handles multimodal intensity relationships;
    # random 30% sampling keeps each metric evaluation cheap.
    R.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)
    R.SetMetricSamplingStrategy(R.RANDOM)
    R.SetMetricSamplingPercentage(0.3)

    if transform_type == "BSpline":
        # L-BFGS-B suits the high-dimensional BSpline parameter space.
        R.SetOptimizerAsLBFGSB(gradientConvergenceTolerance=1e-5, numberOfIterations=100, maximumNumberOfCorrections=5)
    else:
        R.SetOptimizerAsGradientDescent(learningRate=1.0, numberOfIterations=100, convergenceMinimumValue=1e-6, convergenceWindowSize=10)
        # Balance parameter scales (rotation vs. translation units).
        R.SetOptimizerScalesFromPhysicalShift()

    # inPlace=False: keep the initializer untouched; Execute returns a copy.
    R.SetInitialTransform(initial_transform, inPlace=False)
    R.SetInterpolator(sitk.sitkLinear)

    try:
        final_transform = R.Execute(fixed, moving)

        print(f"Register {transform_type} complete. Metric value: {R.GetMetricValue():.4f}")

    except Exception as e:
        # Best-effort: report and fall back to the unregistered image.
        print(f"Register {transform_type} fails: {e}")
        return moving_img_cv, None

    # Resample the moving image onto the fixed image's grid using the
    # fitted transform; BSpline interpolation for a smooth result.
    resampler = sitk.ResampleImageFilter()
    resampler.SetReferenceImage(fixed)
    resampler.SetInterpolator(sitk.sitkBSpline)
    resampler.SetDefaultPixelValue(0)
    resampler.SetTransform(final_transform)

    out_sitk = resampler.Execute(moving)

    return sitk.GetArrayFromImage(out_sitk), final_transform
|
|
| |
|
|
def visual_debugger(img_1, img_2):
    """Show a red/green overlap visualization of two images via matplotlib.

    Both images are Otsu-binarized (img_2 is first floored at 30 and its
    mask dilated/closed with a 15x15 ellipse). img_1's mask is drawn in
    red, img_2's in green, so overlap appears yellow. The title reports an
    overlap score: intersection area over the smaller of the two mask
    areas. Displays a plot window (plt.show); returns nothing.
    """
    _, bin_r = cv2.threshold(img_1.astype(np.uint8), 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

    # Drop near-black background (< 30) before Otsu on the second image.
    ret, bf_thresh = cv2.threshold(img_2.astype(np.uint8), 30, 255, cv2.THRESH_TOZERO)
    _, bin_b = cv2.threshold(bf_thresh, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

    # Thicken and solidify the second mask so thin structures still overlap.
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15))
    bin_b_processed = cv2.dilate(bin_b, kernel, iterations=1)
    bin_b_processed = cv2.morphologyEx(bin_b_processed, cv2.MORPH_CLOSE, kernel)

    h, w = bin_r.shape
    viz = np.zeros((h, w, 3), dtype=np.uint8)

    # Red channel: image 1 mask.
    viz[:,:,0] = bin_r

    # Green channel: image 2 mask (overlap renders yellow).
    viz[:,:,1] = bin_b_processed

    # Overlap score = intersection / smaller mask area (0 if a mask is empty).
    mask_r = bin_r > 0
    mask_b = bin_b_processed > 0
    intersection = np.count_nonzero(np.logical_and(mask_r, mask_b))
    area_r = np.count_nonzero(mask_r)
    area_b = np.count_nonzero(mask_b)
    score = intersection / min(area_r, area_b) if min(area_r, area_b) > 0 else 0

    plt.figure(figsize=(12, 12))
    plt.imshow(viz)
    plt.title(f"Visual Debugger | Overlap Score: {score:.4f}", fontsize=14, fontweight='bold')
    plt.axis('off')

    patch_red = mpatches.Patch(color='red', label='Img 1')
    patch_green = mpatches.Patch(color='green', label='Img 2')
    patch_yellow = mpatches.Patch(color='yellow', label='MATCH')

    plt.legend(handles=[patch_red, patch_green, patch_yellow],
               loc='upper right', framealpha=0.9, fontsize=12, facecolor='black', labelcolor='white')

    plt.tight_layout()
    plt.show()
| |
def show_in_napari(img_1, img_2):
    """Open a napari viewer overlaying two images plus their binary masks.

    img_1 is shown in grayscale, img_2 in additive 'inferno' on top; the
    Otsu masks of both are added as hidden debug layers (toggle visibility
    in the viewer). Blocks until the napari window closes; returns nothing.
    """
    _, bin_1 = cv2.threshold(img_1.astype(np.uint8), 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

    # Floor near-black background (< 30) before Otsu on the second image.
    ret, bf_thresh = cv2.threshold(img_2.astype(np.uint8), 30, 255, cv2.THRESH_TOZERO)
    _, bin_2 = cv2.threshold(bf_thresh, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

    viewer = napari.Viewer(title="IMG 1 vs IMG 2 Registration Debugger")

    # Base layer: first image, fully opaque grayscale.
    viewer.add_image(
        img_1,
        name='IMG 1',
        colormap='gray',
        opacity=1.0
    )

    # Second image composited additively on top for visual comparison.
    viewer.add_image(
        img_2,
        name='IMG 2',
        colormap='inferno',
        blending='additive',
        opacity=0.8
    )

    # Hidden-by-default debug mask layers.
    viewer.add_image(
        bin_1,
        name='Debug: IMG 1 Mask',
        colormap='green',
        blending='additive',
        opacity=0.5,
        visible=False
    )

    viewer.add_image(
        bin_2,
        name='Debug: IMG 2 Mask',
        colormap='red',
        blending='additive',
        opacity=0.5,
        visible=False
    )

    # Blocks until the viewer window is closed.
    napari.run()
| |
def RGBA_visualization(img_1, img_2):
    """Build an RGBA overlay: img_1 JET-colorized, img_2 as the alpha channel.

    Non-uint8 inputs are min-max normalized to uint8 first, and img_2 is
    resized to img_1's dimensions if they differ.

    Returns:
        (rgba_img, b, g, r, alpha): the merged 4-channel image plus each
        of its channels separately.
    """
    def as_uint8(img):
        if img.dtype == np.uint8:
            return img
        return cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)

    base = as_uint8(img_1)
    alpha = as_uint8(img_2)

    # The alpha plane must match the color image pixel-for-pixel.
    if base.shape[:2] != alpha.shape[:2]:
        alpha = cv2.resize(alpha, (base.shape[1], base.shape[0]))

    colored = cv2.applyColorMap(base, cv2.COLORMAP_JET)
    b, g, r = cv2.split(colored)

    rgba_img = cv2.merge([b, g, r, alpha])
    return rgba_img, b, g, r, alpha
|
|
|
|
| |
| |
| |
|
|
def load_from_file(file_obj):
    """Reads an image from a generic file object (path) via OpenCV."""
    if file_obj is None:
        return None
    # -1 (IMREAD_UNCHANGED) preserves bit depth and any alpha channel.
    img = cv2.imread(file_obj.name, -1)
    if img is None:
        return None
    # OpenCV loads color images as BGR; downstream code expects RGB.
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB) if img.ndim == 3 else img
|
|
def load_and_normalize(image):
    """Handles TIFF/16-bit loading and normalization.

    High-bit-depth or float inputs are min-max rescaled to uint8; 2-D
    grayscale inputs are expanded to 3-channel RGB. None passes through.
    """
    if image is None:
        return None

    if image.dtype in (np.uint16, np.float32, np.float64):
        lo = np.min(image)
        hi = np.max(image)
        if hi > lo:
            image = ((image - lo) / (hi - lo) * 255).astype(np.uint8)
        else:
            # Constant image: avoid division by zero.
            image = np.zeros_like(image, dtype=np.uint8)

    if image.ndim == 2:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    return image
|
|
def to_display(image, max_width=800):
    """Creates a small copy (Standard Numpy) for creating the Plot.

    Downscales to `max_width` pixels wide (aspect preserved, nearest-
    neighbor); narrower images are returned unchanged.
    """
    if image is None:
        return None
    h, w = image.shape[:2]
    if w <= max_width:
        return image
    new_h = int(h * (max_width / w))
    return cv2.resize(image, (max_width, new_h), interpolation=cv2.INTER_NEAREST)
|
|
def to_interactive_plot(image, height=400):
    """Converts a numpy image to a zoomable/pannable Plotly figure."""
    if image is None:
        return None
    # px.imshow renders 3-channel input as RGB; expand grayscale first.
    if image.ndim == 2:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)

    fig = px.imshow(image)
    fig.update_layout(
        margin=dict(l=0, r=0, b=0, t=0),
        xaxis={'showticklabels': False, 'visible': False},
        yaxis={'showticklabels': False, 'visible': False},
        dragmode='pan',  # pan by default; zoom via scroll/controls
        height=height,
    )
    return fig
|
|
def create_full_tensor_files(stack):
    """Exports the stack as a Numpy text representation and a binary .npy.

    Args:
        stack: List of layer dicts, each carrying an 'img' numpy array.

    Returns:
        (npy_filename, txt_filename), or (None, None) for an empty stack.
        Files are written into the current working directory.
    """
    if not stack: return None, None

    arrays = [item['img'] for item in stack]

    try:
        tensor_np = np.stack(arrays, axis=0)
        header_info = f"Shape: {tensor_np.shape}, Dtype: {tensor_np.dtype}"
    except ValueError:
        # np.stack raises ValueError when layer shapes differ; fall back
        # to a ragged object array. (Was a bare `except:` that also hid
        # unrelated errors such as KeyboardInterrupt.)
        tensor_np = np.array(arrays, dtype=object)
        header_info = "Ragged Array (dimensions vary)"

    npy_filename = "sandwich_tensor.npy"
    np.save(npy_filename, tensor_np)

    txt_filename = "sandwich_tensor_text.txt"
    with open(txt_filename, "w") as f:
        f.write(f"# {header_info}\n")
        f.write(f"# Load with: import ast; data = ast.literal_eval(open('file.txt').read())\n")
        f.write(str(tensor_np.tolist()))

    return npy_filename, txt_filename
|
|
def create_montage_plot(stack):
    """Creates a Matplotlib figure with all layers side-by-side with axes.

    Renders every layer of the stack in a 1xN grid (grayscale colormap for
    2-D layers), saves it as 'sandwich_montage.png' in the working
    directory and returns that filename; None for an empty stack.
    """
    if not stack:
        return None

    n = len(stack)
    fig, axes = plt.subplots(1, n, figsize=(5 * n, 5), constrained_layout=True)

    # plt.subplots returns a bare Axes (not a list) when n == 1.
    if n == 1:
        axes = [axes]

    for idx, (ax, layer) in enumerate(zip(axes, stack)):
        img = layer['img']
        # cmap=None is matplotlib's default for color images.
        ax.imshow(img, cmap='gray' if img.ndim == 2 else None)
        ax.set_title(f"{idx}: {layer['name']}")
        ax.axis('on')

    buf = io.BytesIO()
    plt.savefig(buf, format='png', dpi=150)
    plt.close(fig)
    buf.seek(0)

    filename = "sandwich_montage.png"
    with open(filename, "wb") as f:
        f.write(buf.read())
    return filename
|
|
|
|
def create_metadata_csv(stack):
    """Creates a CSV with metadata, splitting history into Preprocessing and Crop details.

    Each layer becomes one row: index, name, shape, dtype, and its history
    partitioned into preprocessing steps versus cropping/registration steps
    (any step mentioning "Cropped", "Shift" or "Symmetry").

    Returns:
        The CSV filename ('sandwich_metadata.csv'), or None if the stack
        is empty.
    """
    if not stack:
        return None

    crop_keywords = ("Cropped", "Shift", "Symmetry")

    def is_crop_step(step):
        # A step belongs to the crop/registration column if it mentions
        # any of the keywords.
        return any(keyword in step for keyword in crop_keywords)

    rows = []
    for i, item in enumerate(stack):
        history = item.get('history', [])
        prep_steps = [step for step in history if not is_crop_step(step)]
        crop_steps = [step for step in history if is_crop_step(step)]

        rows.append({
            "Layer Index": i,
            "Layer Name": item['name'],
            "Shape": str(item['img'].shape),
            "Dtype": str(item['img'].dtype),
            "Preprocessing Steps": " -> ".join(prep_steps),
            "Cropping/Registration Details": " -> ".join(crop_steps),
        })

    filename = "sandwich_metadata.csv"
    pd.DataFrame(rows).to_csv(filename, index=False)
    return filename
|
|
| |
| |
| |
|
|
| css = """ |
| .gradio-container {background-color: #f4f6f9} |
| .section-header {background: #eef2f6; padding: 10px; border-radius: 8px; margin-bottom: 10px;} |
| """ |
|
|
| with gr.Blocks(title="MRS Demo") as app: |
| |
| s_fixed_raw = gr.State(None) |
| s_fixed_proc = gr.State(None) |
| s_f_hist = gr.State([]) |
| |
| s_moving_raw = gr.State(None) |
| s_moving_proc = gr.State(None) |
| s_m_hist = gr.State([]) |
| |
| r_fixed_raw = gr.State(None) |
| r_fixed_proc = gr.State(None) |
| r_moving_raw = gr.State(None) |
| r_moving_proc = gr.State(None) |
| |
| r_params = gr.State({"dx": 0, "dy": 0, "sym": "R0"}) |
| |
|
|
| s_stack = gr.State([]) |
| |
| res_fixed_state = gr.State(None) |
| res_fixed_hist_state = gr.State([]) |
|
|
| res_moving_state = gr.State(None) |
| res_moving_hist_state = gr.State([]) |
|
|
| gr.Markdown("## MRS 🥪 : Multimodal Registration Sandwich") |
| gr.Markdown("<span style='font-size: 16px; font-weight: normal;'>Generalized Framework for Universal Microscopy Image Correlation</span>") |
|
|
|
|
| with gr.Tabs() as main_tabs: |
|
|
| |
| |
| |
| with gr.TabItem("1. Upload & Preprocess", id=0): |
| |
| with gr.Row(): |
| |
| with gr.Column(): |
| gr.Markdown("<h3 style='text-align: center;'>Fixed Image (Reference)</h3>") |
| |
| file_f_input = gr.File(label="Upload Fixed Image (TIFF/PNG/JPG)") |
| plot_f_view = gr.Plot(label="Fixed Preview (Interactive)", container=True) |
| |
| with gr.Group(): |
| gr.Markdown("**Structural**") |
| with gr.Row(): |
| num_f_resize = gr.Number(value=1.0, step=0.1, label="Scale", container=False, scale=1) |
| btn_f_resize = gr.Button("Apply Resize", size="sm", scale=2) |
| |
| gr.Markdown("**Preprocess Options**") |
| with gr.Row(): |
| btn_f_gray = gr.Button("Gray", size="sm") |
| btn_f_norm = gr.Button("Norm", size="sm") |
| btn_f_inv = gr.Button("Invert", size="sm") |
| with gr.Row(): |
| btn_f_clahe = gr.Button("CLAHE", size="sm") |
| btn_f_mask = gr.Button("Binary Mask", size="sm") |
| |
| btn_f_reset = gr.Button("Reset", variant="secondary") |
|
|
| |
| with gr.Column(): |
| gr.Markdown("<h3 style='text-align: center;'>Moving Image (To Register)</h3>") |
| |
| file_m_input = gr.File(label="Upload Moving Image (TIFF/PNG/JPG)") |
| plot_m_view = gr.Plot(label="Moving Preview (Interactive)", container=True) |
| |
| with gr.Group(): |
| gr.Markdown("**Structural**") |
| with gr.Row(): |
| num_m_resize = gr.Number(value=1.0, step=0.1, label="Scale", container=False, scale=1) |
| btn_m_resize = gr.Button("Apply Resize", size="sm", scale=2) |
| |
| gr.Markdown("**Preprocess Options**") |
| with gr.Row(): |
| btn_m_gray = gr.Button("Gray", size="sm") |
| btn_m_norm = gr.Button("Norm", size="sm") |
| btn_m_inv = gr.Button("Invert", size="sm") |
| with gr.Row(): |
| btn_m_clahe = gr.Button("CLAHE", size="sm") |
| btn_m_mask = gr.Button("Binary Mask", size="sm") |
| |
| btn_m_reset = gr.Button("Reset", variant="secondary") |
|
|
| with gr.Row(): |
| btn_proceed = gr.Button("Proceed to Registration", variant="primary", size="lg") |
|
|
| |
| |
| |
| with gr.TabItem("2. Register & Crop", id=1): |
| |
| with gr.Row(): |
| |
| with gr.Column(scale=1): |
| gr.Markdown("#### 1. Auto-Match", elem_classes="section-header") |
| dd_algo = gr.Dropdown( |
| ["Pixel Intensity (Brute Force)"], |
| value="Pixel Intensity (Brute Force)", show_label=False |
| ) |
| btn_auto_match = gr.Button("Run Auto-Match", variant="primary") |
| |
| gr.Markdown("#### 2. Fine-Tune", elem_classes="section-header") |
| slider_x = gr.Number(label="Shift X", value=0) |
| slider_y = gr.Number(label="Shift Y", value=0) |
| |
| with gr.Row(): |
| btn_rot_cw = gr.Button("↻ 90°") |
| btn_rot_ccw = gr.Button("↺ -90°") |
| |
| gr.Markdown("---") |
| btn_confirm = gr.Button("Confirm & Crop Originals", variant="primary") |
|
|
| |
| with gr.Column(scale=3): |
| gr.Markdown("<div style='text-align: center; font-weight: 900 !important;'>Live Alignment Preview</div>") |
| img_overlay = gr.Plot(label="Live Overlay", container=True) |
| slider_opacity = gr.Slider(0, 1, value=0.5, label="Opacity") |
|
|
| |
| with gr.Group(visible=False) as grp_results: |
| gr.Markdown("### Final Cropped Originals", elem_classes="section-header") |
| |
| with gr.Row(): |
| with gr.Column(): |
| gr.Markdown("**Fixed Image (Cropped)**") |
| res_fixed_crop = gr.Image(interactive=False, height=300) |
| btn_add_f = gr.Button("Add to Sandwich") |
| btn_use_f = gr.Button("Use as Input (Tab 1)") |
| |
| with gr.Column(): |
| gr.Markdown("**Moving Image**") |
| res_moving_crop = gr.Image(interactive=False, height=300) |
| btn_add_m = gr.Button("Add to Sandwich") |
| btn_use_m = gr.Button("Use as Input (Tab 1)") |
| |
| txt_log = gr.Textbox(label="Log", lines=2) |
|
|
| |
| |
| |
| with gr.TabItem("3. Sandwich & Export", id=2): |
| gr.Markdown("### Sandwich Composition", elem_classes="section-header") |
| |
| with gr.Row(): |
| with gr.Column(scale=3): |
| |
| gr.Markdown("*Visual Layers (Full View)*") |
| gallery = gr.Gallery( |
| label="Visual Layers", |
| columns=3, |
| height="auto", |
| object_fit="contain" |
| ) |
| |
| df_layers = gr.Dataframe( |
| headers=["Index", "Name", "Shape", "Steps"], |
| label="Layer Metadata", |
| interactive=False |
| ) |
|
|
| with gr.Column(scale=1): |
| gr.Markdown("**Manage Layers**") |
| num_idx = gr.Number(label="Target Layer Index", value=0, precision=0) |
| |
| with gr.Row(): |
| btn_up = gr.Button("⬆️ Up") |
| btn_down = gr.Button("⬇️ Down") |
| |
| with gr.Row(): |
| txt_rename = gr.Textbox(label="New Name", placeholder="Enter name...") |
| btn_rename = gr.Button(" Rename") |
| |
| btn_delete = gr.Button(" Delete Layer", variant="stop") |
| |
| gr.Markdown("---") |
| gr.Markdown("### Export Options") |
| |
| with gr.Row(): |
| with gr.Column(): |
| btn_exp_images = gr.Button("1. Prepare Images (.zip)", variant="secondary") |
| file_zip = gr.File(label="Images (.zip)") |
| |
| with gr.Column(): |
| btn_exp_tensor = gr.Button("2. Prepare Tensor (.npy & .txt)", variant="secondary") |
| with gr.Row(): |
| file_npy = gr.File(label="Numpy Binary (.npy)") |
| file_txt_arr = gr.File(label="Text Representation (.txt)") |
|
|
| with gr.Column(): |
| btn_exp_data = gr.Button("3. Prepare Metadata & Montage", variant="secondary") |
| with gr.Row(): |
| file_csv = gr.File(label="Metadata (.csv)") |
| file_montage = gr.File(label="Montage Plot (.png)") |
|
|
| gr.Markdown("---") |
| gr.Markdown(""" |
| <div style='text-align: center; color: #666;'> |
| <p style='margin-bottom: 0;'>Datrix SPA - Politecnico di Milano (VIBRA Group)</p> |
| <p style='font-size: 14px; margin-top: 5px;'>Contact: marc.rodriguez@datrixgroup.com / marc.rodriguez@polimi.it</p> |
| </div> |
| """) |
|
|
| |
| |
| |
|
|
| |
def on_file_upload(file_obj):
    """Load an uploaded file, normalize it, and seed the raw/processed state."""
    raw_img = load_from_file(file_obj)
    norm = load_and_normalize(raw_img)
    preview = to_interactive_plot(to_display(norm, 800), height=350)
    # Raw and processed states start out identical; history begins fresh.
    return norm, norm, preview, ["Loaded"]
|
|
| file_f_input.change(on_file_upload, file_f_input, [s_fixed_raw, s_fixed_proc, plot_f_view, s_f_hist]) |
| file_m_input.change(on_file_upload, file_m_input, [s_moving_raw, s_moving_proc, plot_m_view, s_m_hist]) |
|
|
| |
def process_image(raw, current_proc, history, op, param=None):
    """Apply one preprocessing operation and refresh state plus preview.

    Args:
        raw: The 'original' image state (changed only by Resize/Reset).
        current_proc: The currently processed image state.
        history: List of step descriptions applied so far.
        op: One of "Resize", "Reset", "Gray", "Invert", "Norm", "CLAHE",
            "Binary Mask"; anything else is a no-op.
        param: Operation parameter (the scale factor for "Resize").

    Returns:
        (raw, processed, plot, history) matching the Gradio output states.
    """
    if raw is None: return None, None, None, history

    updated_history = list(history)

    if op == "Resize":
        try:
            scale = float(param)
            new_raw = resize_image(raw, scale)
            updated_history.append(f"Resize (scale={scale})")
            return new_raw, new_raw, to_interactive_plot(to_display(new_raw, 800), height=350), updated_history
        except Exception:
            # Invalid scale input (non-numeric, zero/negative size, ...):
            # keep the current state untouched. (Was a bare `except:` which
            # also swallowed KeyboardInterrupt/SystemExit.)
            return raw, current_proc, to_interactive_plot(to_display(current_proc, 800), height=350), history

    if op == "Reset":
        return raw, raw, to_interactive_plot(to_display(raw, 800), height=350), ["Reset"]

    img_to_mod = current_proc.copy()

    if op == "Gray":
        # Already-gray images pass through unchanged.
        if img_to_mod.ndim == 3: res = rgb2gray(img_to_mod).astype(np.uint8)
        else: res = img_to_mod
        updated_history.append("Grayscale")
    elif op == "Invert":
        res = cv2.bitwise_not(img_to_mod)
        updated_history.append("Invert")
    elif op == "Norm":
        norm = Normalize(img_to_mod)
        res = (norm * 255).astype(np.uint8)
        updated_history.append("Normalize")
    elif op == "CLAHE":
        res = apply_clahe(img_to_mod)
        updated_history.append("CLAHE")
    elif op == "Binary Mask":
        res = binary_mask(img_to_mod)
        updated_history.append("Binary Mask")
    else:
        res = img_to_mod

    return raw, res, to_interactive_plot(to_display(res, 800), height=350), updated_history
|
|
| |
| btn_f_resize.click(process_image, [s_fixed_raw, s_fixed_proc, s_f_hist, gr.State("Resize"), num_f_resize], [s_fixed_raw, s_fixed_proc, plot_f_view, s_f_hist]) |
| for btn, op in [(btn_f_gray,"Gray"), (btn_f_norm,"Norm"), (btn_f_inv,"Invert"), (btn_f_clahe,"CLAHE"), (btn_f_mask, "Binary Mask")]: |
| btn.click(process_image, [s_fixed_raw, s_fixed_proc, s_f_hist, gr.State(op)], [s_fixed_raw, s_fixed_proc, plot_f_view, s_f_hist]) |
| btn_f_reset.click(process_image, [s_fixed_raw, s_fixed_proc, s_f_hist, gr.State("Reset")], [s_fixed_raw, s_fixed_proc, plot_f_view, s_f_hist]) |
|
|
| |
| btn_m_resize.click(process_image, [s_moving_raw, s_moving_proc, s_m_hist, gr.State("Resize"), num_m_resize], [s_moving_raw, s_moving_proc, plot_m_view, s_m_hist]) |
| for btn, op in [(btn_m_gray,"Gray"), (btn_m_norm,"Norm"), (btn_m_inv,"Invert"), (btn_m_clahe,"CLAHE"), (btn_m_mask, "Binary Mask")]: |
| btn.click(process_image, [s_moving_raw, s_moving_proc, s_m_hist, gr.State(op)], [s_moving_raw, s_moving_proc, plot_m_view, s_m_hist]) |
| btn_m_reset.click(process_image, [s_moving_raw, s_moving_proc, s_m_hist, gr.State("Reset")], [s_moving_raw, s_moving_proc, plot_m_view, s_m_hist]) |
|
|
| |
def proceed_to_reg(fr, fp, mr, mp):
    """Hand the fixed/moving raw+processed images to the registration tab.

    Returns the four images unchanged plus a gr.Tabs update selecting
    tab index 1. If either raw image is missing, all registration states
    are cleared and the tab selection is left as-is.
    """
    if fr is None or mr is None: return None, None, None, None, gr.Tabs()
    return fr, fp, mr, mp, gr.Tabs(selected=1)
|
|
| btn_proceed.click(proceed_to_reg, [s_fixed_raw, s_fixed_proc, s_moving_raw, s_moving_proc], [r_fixed_raw, r_fixed_proc, r_moving_raw, r_moving_proc, main_tabs]) |
|
|
| |
| |
def gen_overlay(fixed_raw, moving_raw, dx, dy, sym, opacity):
    """Generates Interactive Plotly Overlay.

    Places the symmetry-transformed moving image at offset (dx, dy) on a
    canvas the size of the fixed image, then alpha-blends the two.

    Args:
        fixed_raw: Reference image (2-D or 3-channel uint8 array).
        moving_raw: Image to overlay.
        dx, dy: Translation of the moving image in pixels (may be negative).
        sym: Symmetry label understood by get_symmetry ("R0".."D2").
        opacity: Blend weight of the moving image in [0, 1].

    Returns:
        A Plotly figure, or None if either input image is missing.
    """
    if fixed_raw is None or moving_raw is None: return None
    moved = get_symmetry(moving_raw, sym)
    h, w = fixed_raw.shape[:2]
    # Canvas matches the fixed image's channel layout.
    if fixed_raw.ndim == 3: canvas = np.zeros((h, w, 3), dtype=np.uint8)
    else: canvas = np.zeros((h, w), dtype=np.uint8)
    h_m, w_m = moved.shape[:2]
    # Clip the paste rectangle to the canvas; mx1/my1 are the matching
    # offsets into the moving image when dx/dy are negative.
    x1, y1 = max(0, int(dx)), max(0, int(dy))
    x2, y2 = min(w, int(dx)+w_m), min(h, int(dy)+h_m)
    mx1, my1 = max(0, -int(dx)), max(0, -int(dy))
    try:
        if x2 > x1 and y2 > y1:
            canvas[y1:y2, x1:x2] = moved[my1:my1+(y2-y1), mx1:mx1+(x2-x1)]
    except ValueError:
        # Channel-count mismatch between canvas and moving image: show the
        # fixed image alone rather than crash. (Was a bare `except: pass`.)
        pass
    f = fixed_raw if fixed_raw.ndim == 3 else cv2.cvtColor(fixed_raw, cv2.COLOR_GRAY2RGB)
    m = canvas if canvas.ndim == 3 else cv2.cvtColor(canvas, cv2.COLOR_GRAY2RGB)
    blended = cv2.addWeighted(f, 1-opacity, m, opacity, 0)
    return to_interactive_plot(to_display(blended, 800), height=500)
|
|
def on_auto(fp, mp, fr, mr, algo):
    """Run the selected automatic registration algorithm.

    Parameters:
        fp, mp: processed fixed/moving images fed to the matcher.
        fr, mr: raw fixed/moving images used to render the overlay.
        algo: algorithm label; matched by substring ("Pixel", "Feature", "FFT").

    Returns:
        (dx, dy, log_text, overlay_plot, params_dict) — exactly one value per
        wired output (slider_x, slider_y, txt_log, img_overlay, r_params).
    """
    if fp is None:
        # BUG FIX: this branch previously returned six values against five
        # wired outputs; mirror the success-path arity instead.
        return 0, 0, "No Data", None, {"dx": 0, "dy": 0, "sym": "R0"}
    fp_g = fp if fp.ndim == 2 else rgb2gray(fp)
    mp_g = mp if mp.ndim == 2 else rgb2gray(mp)
    try:
        if "Pixel" in algo:
            res, _ = find_best_match_pixel(fp_g, mp_g, "x")
        elif "Feature" in algo:
            res, _ = find_best_match_features(fp_g, mp_g, "x")
        elif "FFT" in algo:
            res, _ = find_best_match_fft(fp_g, mp_g, "x")
        else:
            # Unknown algorithm label: fall through to the "no match" branch
            # instead of raising NameError on an unbound `res`.
            res = None
        if res:
            # Prefer 'shift_xy' when the matcher provides it, else 'location'.
            loc = res.get('location', (0, 0))
            if 'shift_xy' in res:
                loc = res['shift_xy']
            dx, dy, sym = int(loc[0]), int(loc[1]), res['symmetry']
            log = f"Score: {res['score'] * 100:.2f}% (Symmetry: {sym})"
        else:
            dx, dy, sym, log = 0, 0, "R0", "No match"
    except Exception as e:
        # Surface matcher failures in the log box rather than crashing the UI.
        dx, dy, sym, log = 0, 0, "R0", str(e)
    ov = gen_overlay(fr, mr, dx, dy, sym, 0.5)
    return dx, dy, log, ov, {"dx": dx, "dy": dy, "sym": sym}
|
|
# Auto-match fills the sliders, log, overlay and stored params in one shot.
btn_auto_match.click(on_auto, [r_fixed_proc, r_moving_proc, r_fixed_raw, r_moving_raw, dd_algo], [slider_x, slider_y, txt_log, img_overlay, r_params])
|
|
def on_manual(fr, mr, dx, dy, op, params):
    """Re-render the overlay after a slider change and record the new shift."""
    symmetry = params.get("sym", "R0")
    overlay = gen_overlay(fr, mr, dx, dy, symmetry, op)
    params.update({"dx": dx, "dy": dy})
    return overlay, params
|
|
# Any slider change re-renders the overlay with the current shift/opacity.
for inp in [slider_x, slider_y, slider_opacity]:
    inp.change(on_manual, [r_fixed_raw, r_moving_raw, slider_x, slider_y, slider_opacity, r_params], [img_overlay, r_params])
|
|
def on_rot(fr, mr, dx, dy, op, params, dir):
    """Step the symmetry code one quarter-turn and refresh the overlay.

    `dir` is "cw" for clockwise; anything else steps counter-clockwise.
    """
    syms = ["R0", "R1", "R2", "R3"]
    current = params.get("sym", "R0")
    idx = syms.index(current) if current in syms else 0
    step = 1 if dir == "cw" else -1
    params["sym"] = syms[(idx + step) % 4]
    ov = gen_overlay(fr, mr, dx, dy, params["sym"], op)
    return ov, params
|
|
# Rotation buttons cycle the symmetry code; direction is a baked-in gr.State constant.
btn_rot_cw.click(on_rot, [r_fixed_raw, r_moving_raw, slider_x, slider_y, slider_opacity, r_params, gr.State("cw")], [img_overlay, r_params])
btn_rot_ccw.click(on_rot, [r_fixed_raw, r_moving_raw, slider_x, slider_y, slider_opacity, r_params, gr.State("ccw")], [img_overlay, r_params])
|
|
| |
def apply_crop_wrapper(f_raw, m_raw, params, f_hist, m_hist):
    """Crop both images to their overlap using the stored registration params.

    Returns (cropped_fixed, cropped_moving, fixed_history, moving_history,
    fixed_preview, moving_preview, results_group_update) — seven values, one
    per wired output.
    """
    dx = int(params["dx"])
    dy = int(params["dy"])
    sym = params["sym"]
    m_moved = get_symmetry(m_raw, sym)
    h_f, w_f = f_raw.shape[:2]
    h_m, w_m = m_moved.shape[:2]

    # Overlap rectangle expressed in fixed-image coordinates.
    x1 = max(0, dx)
    y1 = max(0, dy)
    x2 = min(w_f, dx + w_m)
    y2 = min(h_f, dy + h_m)

    if x2 <= x1 or y2 <= y1:
        # No overlap at all — show the (empty) results group anyway.
        return None, None, [], [], None, None, gr.Group(visible=True)

    full_c_fixed = f_raw[y1:y2, x1:x2]
    # Same rectangle expressed in moving-image coordinates.
    mx1 = x1 - dx
    my1 = y1 - dy
    full_c_moving = m_moved[my1:my1 + (y2 - y1), mx1:mx1 + (x2 - x1)]

    view_fixed = to_display(full_c_fixed)
    view_moving = to_display(full_c_moving)

    # Record the operations in each image's processing history.
    new_f_hist = [*f_hist, f"Cropped to ({x1}:{x2}, {y1}:{y2})"]
    new_m_hist = [*m_hist, f"Symmetry {sym}", f"Shift ({dx},{dy})", "Cropped to Intersection"]

    return full_c_fixed, full_c_moving, new_f_hist, new_m_hist, view_fixed, view_moving, gr.Group(visible=True)
|
|
# Confirm applies the registration and crops both images to their overlap.
btn_confirm.click(apply_crop_wrapper,
    [r_fixed_raw, r_moving_raw, r_params, s_f_hist, s_m_hist],
    [res_fixed_state, res_moving_state, res_fixed_hist_state, res_moving_hist_state, res_fixed_crop, res_moving_crop, grp_results])
|
|
| |
def update_sandwich_ui(stack):
    """Build the gallery thumbnails and metadata table rows for the stack.

    Each row is [index, name, shape string, " -> "-joined history].
    """
    rows = [
        [i, layer["name"], str(layer["img"].shape), " -> ".join(layer.get("history", []))]
        for i, layer in enumerate(stack)
    ]
    thumbs = [to_display(layer["img"], 300) for layer in stack]
    return stack, thumbs, rows
|
|
| |
def add_to_stack(img, hist, stack):
    """Append an image (with its processing history) to the sandwich stack.

    Returns the refreshed (stack, gallery, table) triple; a missing image
    leaves the stack untouched and clears the gallery/table outputs.
    """
    if img is None:
        return stack, [], []
    layer = {"name": f"Layer_{len(stack)}", "img": img, "history": hist}
    stack.append(layer)
    return update_sandwich_ui(stack)
|
|
def manage_layers(stack, idx, action, new_name):
    """Apply a stack-management action and refresh the UI.

    Parameters:
        stack: list of layer dicts ({"name", "img", "history"}).
        idx: index of the targeted layer (a gr.Number value — may be None
             or a float when the field is cleared/edited).
        action: one of "up", "down", "delete", "rename".
        new_name: replacement name; only used by "rename" (may be None).

    Returns:
        The refreshed (stack, gallery images, table rows) triple.
    """
    # Guard: empty stack or a cleared index field (int(None) would raise).
    if not stack or idx is None:
        return update_sandwich_ui(stack)
    i = int(idx)
    if i < 0 or i >= len(stack):
        return update_sandwich_ui(stack)

    if action == "delete":
        stack.pop(i)
    elif action == "up":
        if i > 0:
            stack[i], stack[i - 1] = stack[i - 1], stack[i]
    elif action == "down":
        if i < len(stack) - 1:
            stack[i], stack[i + 1] = stack[i + 1], stack[i]
    elif action == "rename":
        # Guard against a cleared textbox (None) as well as whitespace-only input.
        if new_name and new_name.strip():
            stack[i]["name"] = new_name.strip()

    return update_sandwich_ui(stack)
|
|
| |
# Push the cropped fixed/moving results onto the sandwich stack.
btn_add_f.click(add_to_stack, [res_fixed_state, res_fixed_hist_state, s_stack], [s_stack, gallery, df_layers])
btn_add_m.click(add_to_stack, [res_moving_state, res_moving_hist_state, s_stack], [s_stack, gallery, df_layers])


# Layer-management buttons share one handler; the action is a gr.State constant.
btn_up.click(manage_layers, [s_stack, num_idx, gr.State("up"), txt_rename], [s_stack, gallery, df_layers])
btn_down.click(manage_layers, [s_stack, num_idx, gr.State("down"), txt_rename], [s_stack, gallery, df_layers])
btn_delete.click(manage_layers, [s_stack, num_idx, gr.State("delete"), txt_rename], [s_stack, gallery, df_layers])
btn_rename.click(manage_layers, [s_stack, num_idx, gr.State("rename"), txt_rename], [s_stack, gallery, df_layers])
|
|
| |
def recycle(img):
    """Feed a cropped result back into the preprocessing tab as a new input.

    Returns the normalized image (as both raw and processed state), its
    preview plot, a fresh history, and a switch back to the first tab.
    """
    norm = load_and_normalize(img)
    preview = to_interactive_plot(to_display(norm, 800), height=350)
    return norm, norm, preview, ["Recycled from Crop"], gr.Tabs(selected=0)
| |
# "Use as input" buttons send a cropped result back to the preprocessing tab.
btn_use_f.click(recycle, res_fixed_state, [s_fixed_raw, s_fixed_proc, plot_f_view, s_f_hist, main_tabs])
btn_use_m.click(recycle, res_moving_state, [s_moving_raw, s_moving_proc, plot_m_view, s_m_hist, main_tabs])
|
|
| |
| |
| |
def do_export_zip(stack):
    """Write every stack layer as a PNG into a zip archive in the CWD.

    Parameters:
        stack: list of layer dicts with "name" and "img" (uint8 ndarray).

    Returns:
        The archive path ("sandwich_images.zip"), or None for an empty stack.
    """
    if not stack:
        return None
    zip_file = "sandwich_images.zip"
    with zipfile.ZipFile(zip_file, 'w') as zf:
        for i, item in enumerate(stack):
            img = item["img"]
            # OpenCV encodes in BGR channel order; color layers are held as RGB.
            bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) if img.ndim == 3 else img
            ok, buf = cv2.imencode('.png', bgr)
            if not ok:
                # Skip layers that fail to encode instead of writing garbage
                # (the success flag was previously discarded).
                continue
            # writestr expects bytes; convert the encoded ndarray explicitly.
            zf.writestr(f"{i}_{item['name']}.png", buf.tobytes())
    return zip_file
|
|
# Per-layer PNG export as a single zip download.
btn_exp_images.click(do_export_zip, s_stack, file_zip)


# Tensor export (.npy + text dump) is handled by a module-level helper.
btn_exp_tensor.click(create_full_tensor_files, s_stack, [file_npy, file_txt_arr])
| |
| |
def do_export_meta_montage(stack):
    """Produce the metadata CSV and the montage image for download."""
    return create_metadata_csv(stack), create_montage_plot(stack)
|
|
# Export metadata CSV + montage preview in one click.
btn_exp_data.click(do_export_meta_montage, s_stack, [file_csv, file_montage])


if __name__ == "__main__":
    # NOTE(review): current Gradio's Blocks.launch() does not accept `theme`
    # or `css` keyword arguments — those belong to the gr.Blocks(...)
    # constructor. Confirm this call works on the pinned Gradio version.
    app.launch(inbrowser=True, share=False, theme=gr.themes.Soft(), css=css)