Sandwich / app.py
marc-rod's picture
Update app.py
634f585 verified
import gradio as gr
import cv2
import numpy as np
import zipfile
import io
import pandas as pd
import matplotlib.pyplot as plt
import plotly.express as px
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.preprocessing import MinMaxScaler
import os
import cv2
from mpl_toolkits.axes_grid1 import make_axes_locatable
from pathlib import Path
import SimpleITK as sitk
import matplotlib.patches as mpatches
import napari
#%% === Main Functionalities ===
def rgb2gray(rgb):
    """
    Convert an RGB(A) image to grayscale with luma weights (as in MATLAB's rgb2gray).

        gray = 0.2989 * R + 0.5870 * G + 0.1140 * B

    Args:
        rgb: A 2D grayscale array (returned unchanged), an (H, W, 1) array
            (the singleton channel axis is dropped), or an (H, W, C) array
            with C >= 3 whose first three channels are in R, G, B order.
            Any further channels (e.g. alpha) are ignored.

    Returns:
        The grayscale image. For colour input np.dot promotes the result to
        float64, so callers must cast back to uint8 if they need 8-bit data.
    """
    if rgb.ndim == 2:
        return rgb
    if rgb.ndim == 3 and rgb.shape[2] == 1:
        # Single channel stored with an explicit channel axis: the weighted
        # dot product below cannot be applied to it, so just drop the axis.
        return rgb[..., 0]
    return np.dot(rgb[..., :3], [0.2989, 0.5870, 0.1140])
def Normalize(image):
    """
    Linearly rescale *image* so that its minimum maps to 0 and its maximum to 1.

    A constant image (max == min) cannot be stretched; an all-zero array with
    the input's shape and dtype is returned in that case.
    """
    lo = np.min(image)
    hi = np.max(image)
    if hi == lo:
        return np.zeros_like(image)
    span = hi - lo
    return (image - lo) / span
def Normalize_percentiles(image, low_perc, upp_perc):
    """
    Contrast-stretch *image* between two percentiles and clip to [0, 1].

    Values at or below the ``low_perc`` percentile map to 0, values at or
    above the ``upp_perc`` percentile map to 1, and values in between are
    scaled linearly. If both percentiles are equal, an all-zero array with
    the input's shape and dtype is returned.
    """
    lo = np.percentile(image, low_perc)
    hi = np.percentile(image, upp_perc)
    if hi == lo:
        return np.zeros_like(image)
    stretched = (image - lo) / (hi - lo)
    return np.clip(stretched, 0, 1)
#%% === Preprocessing Functionalities ===
def prepare_base_image(path):
    """
    Load an image from disk and build a min-max normalised 8-bit grayscale copy.

    Args:
        path: File path passed to cv2.imread (default colour read mode).

    Returns:
        (final_img, img): the normalised uint8 grayscale image and the image
        exactly as returned by cv2.imread (BGR channel order for colour).

    Raises:
        FileNotFoundError: if OpenCV cannot read *path*.
    """
    img = cv2.imread(path)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError(f"Could not read image: {path}")
    # cv2.imread yields BGR while rgb2gray weights channels in RGB order;
    # reorder first so red and blue receive their proper luma weights.
    rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) if img.ndim == 3 else img
    gray = rgb2gray(rgb)
    norm = Normalize(gray)
    final_img = (norm * 255).astype(np.uint8)
    return final_img, img
def resize_image(image, scale_factor):
    """
    Resize *image* uniformly by *scale_factor* using Lanczos interpolation.

    Args:
        image: Array accepted by cv2.resize (2D or HxWxC).
        scale_factor: Positive multiplier applied to both width and height.

    Returns:
        The resized image. Each output dimension is truncated to an integer
        but never drops below 1 pixel.

    Raises:
        ValueError: if *scale_factor* is not a positive number.
    """
    # `not x > 0` also rejects NaN, which would otherwise reach OpenCV.
    if not scale_factor > 0:
        raise ValueError(f"scale_factor must be > 0, got {scale_factor!r}")
    h, w = image.shape[:2]
    # cv2.resize takes (width, height); clamp so tiny results stay valid.
    new_size = (max(1, int(w * scale_factor)), max(1, int(h * scale_factor)))
    return cv2.resize(image, new_size, interpolation=cv2.INTER_LANCZOS4)
def apply_clahe(image, use_clahe=True, clip_limit=6.0, tile_grid_size=(8, 8)):
    """
    Boost contrast with CLAHE (default) or with global histogram equalisation.

    Args:
        image: 8-bit image. Multi-channel input is first reduced to a single
            channel, because both OpenCV routines below operate on one
            channel (RGB order assumed for that conversion — confirm for
            BGR callers).
        use_clahe: True for CLAHE, False for cv2.equalizeHist.
        clip_limit: CLAHE clipLimit (default keeps the previous 6.0).
        tile_grid_size: CLAHE tileGridSize (default keeps the previous (8, 8)).

    Returns:
        The enhanced single-channel image.
    """
    if image.ndim == 3:
        if image.shape[2] == 1:
            image = image[..., 0]
        else:
            image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    if not use_clahe:
        return cv2.equalizeHist(image)
    clahe_obj = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size)
    return clahe_obj.apply(image)
def align_centers(fixed_img, moving_img):
    """
    Translate *moving_img* so its intensity centre of mass matches *fixed_img*'s.

    Both images must be single-channel (cv2.moments requirement). The result
    is resampled onto a canvas the size of *fixed_img*; uncovered pixels
    are filled with 0.

    Returns:
        (centered_moving, (shift_x, shift_y)) with integer shifts in pixels.
        If either image has zero total intensity its centre of mass is
        undefined, so a warning is printed and (moving_img, (0, 0)) is
        returned unchanged.
    """
    h_fixed, w_fixed = fixed_img.shape[:2]

    M_fixed = cv2.moments(fixed_img)
    if M_fixed["m00"] == 0:
        print("Warning: Fixed image is empty/black. Skipping CoM alignment.")
        return moving_img, (0, 0)
    M_moving = cv2.moments(moving_img)
    if M_moving["m00"] == 0:
        print("Warning: Moving image is empty/black. Skipping CoM alignment.")
        return moving_img, (0, 0)

    # Round (rather than truncate toward zero) so the integer centroids lie
    # within half a pixel of the true centres of mass.
    cX_fixed = round(M_fixed["m10"] / M_fixed["m00"])
    cY_fixed = round(M_fixed["m01"] / M_fixed["m00"])
    cX_moving = round(M_moving["m10"] / M_moving["m00"])
    cY_moving = round(M_moving["m01"] / M_moving["m00"])

    shift_x = cX_fixed - cX_moving
    shift_y = cY_fixed - cY_moving

    # Pure-translation affine matrix.
    T = np.float32([[1, 0, shift_x], [0, 1, shift_y]])
    centered_moving = cv2.warpAffine(
        moving_img,
        T,
        (w_fixed, h_fixed),  # output size is (width, height) of the fixed image
        flags=cv2.INTER_LINEAR,
        borderMode=cv2.BORDER_CONSTANT,
        borderValue=0,
    )
    return centered_moving, (shift_x, shift_y)
def Gaussian_blur(image, kernel_size=5, sigma=0):
    """
    Smooth *image* with a square Gaussian kernel.

    Args:
        image: Image accepted by cv2.GaussianBlur.
        kernel_size: Kernel side length. OpenCV requires an odd size, so an
            even value is increased by one and the adjusted size is printed.
        sigma: Gaussian standard deviation passed to cv2.GaussianBlur
            (0 lets OpenCV derive it from the kernel size).

    Returns:
        The blurred image.
    """
    if kernel_size % 2 == 0:
        kernel_size += 1
        # f-string so the adjusted size is actually interpolated into the message.
        print(f"Kernel adjust:{kernel_size}")
    return cv2.GaussianBlur(image, (kernel_size, kernel_size), sigma)
def find_edges(image, sigma=0.33):
    """
    Run Canny with hysteresis thresholds derived from the median intensity.

    The lower and upper thresholds are (1 - sigma) and (1 + sigma) times the
    median pixel value, truncated to int and clamped to [0, 255]. 3-channel
    input is converted to grayscale assuming BGR channel order.
    """
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if image.ndim == 3 else image
    median = np.median(gray)
    low = int(max(0, (1.0 - sigma) * median))
    high = int(min(255, (1.0 + sigma) * median))
    return cv2.Canny(gray, low, high)
def find_edges_binary(image, low_threshold=30, high_threshold=100):
    """
    Binarise *image* with Otsu's threshold, then trace its edges with Canny.

    Args:
        image: 8-bit image. 3-channel input is converted to grayscale (BGR
            order, as in find_edges) because Otsu thresholding needs a
            single channel.
        low_threshold: Canny lower hysteresis threshold (previously fixed at 30).
        high_threshold: Canny upper hysteresis threshold (previously fixed at 100).

    Returns:
        The edge map produced by cv2.Canny.
    """
    if image.ndim == 3:
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    _, binary = cv2.threshold(image, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    return cv2.Canny(binary, low_threshold, high_threshold)
def binary_mask(image, kernel_size=5):
    """
    Turn an image into a solid, cleaned-up binary silhouette of its foreground.

    The foreground (e.g. tissue) is expected to be brighter than the
    background; invert the image beforehand if it is not.

    Pipeline: grayscale -> 7x7 Gaussian blur -> Otsu threshold ->
    morphological closing (2 iterations, fills interior holes) ->
    morphological opening (1 iteration, removes small outside specks).

    Args:
        image: Grayscale or 3-channel image (BGR order assumed for conversion).
        kernel_size: Side of the elliptical structuring element used for the
            closing/opening. Larger values remove bigger specks but also
            smooth away more shape detail.

    Returns:
        Mask whose pixels are 0 or 255.
    """
    if image.ndim == 3:
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    else:
        gray = image.copy()

    # Blur first: pixel noise would otherwise fragment the threshold result.
    smoothed = cv2.GaussianBlur(gray, (7, 7), 0)
    # Otsu picks the threshold automatically; brighter pixels become 255.
    _, thresholded = cv2.threshold(smoothed, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

    ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
    filled = cv2.morphologyEx(thresholded, cv2.MORPH_CLOSE, ellipse, iterations=2)
    return cv2.morphologyEx(filled, cv2.MORPH_OPEN, ellipse, iterations=1)
#%% === Simple Matcher Functionalities ===
def get_symmetry(image, label):
    """
    Apply one of the 8 symmetries of a square (dihedral group of order 8).

    Labels:
        R0  identity (the input object itself is returned)
        R1  rotate 90 degrees clockwise
        R2  rotate 180 degrees
        R3  rotate 90 degrees counter-clockwise
        M1  mirror left <-> right (same as cv2.flip(image, 1))
        M2  mirror top <-> bottom (same as cv2.flip(image, 0))
        D1  reflect across the main diagonal (transpose)
        D2  reflect across the anti-diagonal (anti-transpose)

    Any other label returns *image* unchanged. For non-square images R1, R3,
    D1 and D2 swap height and width; trailing axes (colour channels) are
    carried along untouched. Non-identity results are C-contiguous copies so
    they can be passed straight to OpenCV.
    """
    ops = {
        "R1": lambda a: np.rot90(a, k=-1),   # np.rot90 with k > 0 is counter-clockwise
        "R2": lambda a: np.rot90(a, k=2),
        "R3": lambda a: np.rot90(a, k=1),
        "M1": lambda a: a[:, ::-1],
        "M2": lambda a: a[::-1],
        "D1": lambda a: np.swapaxes(a, 0, 1),
        # Anti-transpose = transpose followed by flipping BOTH axes. Flipping
        # only vertically after transposing would just reproduce R3.
        "D2": lambda a: np.swapaxes(a, 0, 1)[::-1, ::-1],
    }
    op = ops.get(label)
    if op is None:
        # R0 and unknown labels: hand back the original array.
        return image
    return np.ascontiguousarray(op(image))
def find_best_match_pixel(img_1, img_2, image_name):
    """
    Brute-force template matching over all 8 square symmetries.

    The image with the larger pixel area becomes the search image and the
    other one the template. Every symmetry of the template is matched with
    normalised cross-correlation (cv2.TM_CCOEFF_NORMED). If the template's
    longest side exceeds the search image's height or width, the search
    image is zero-padded (plus a 10% margin) so rotated templates still fit.

    NOTE(review): when the areas are equal or img_2 is larger, the symmetry
    is applied to img_1, and "location" is img_1's position inside img_2.
    Callers must interpret the result accordingly.

    Args:
        img_1, img_2: Single-channel images of any numeric dtype. Both are
            converted to float32, because cv2.matchTemplate only accepts
            8-bit or 32-bit float data of identical type.
        image_name: Label copied into every result dict.

    Returns:
        (best_result, results): the highest-scoring dict and the list of all
        8 dicts. Each dict has keys "name", "symmetry", "score",
        "location" and "raman_img". "location" is the (x, y) of the
        template's top-left corner in unpadded search-image coordinates.
        "raman_img" is the transformed template in its original dtype.
    """
    if img_1.shape[0] * img_1.shape[1] > img_2.shape[0] * img_2.shape[1]:
        image_big, image_small = img_1, img_2
    else:
        image_big, image_small = img_2, img_1

    search = np.ascontiguousarray(image_big, dtype=np.float32)

    h_big, w_big = search.shape[:2]
    h_small, w_small = image_small.shape[:2]
    max_dim_small = max(h_small, w_small)
    missing_h = max(0, max_dim_small - h_big)
    missing_w = max(0, max_dim_small - w_big)

    pad_top = pad_bottom = pad_left = pad_right = 0
    if missing_h > 0 or missing_w > 0:
        margin = int(max_dim_small * 0.1)
        # Give the odd leftover pixel to the bottom/right side so the padding
        # always covers the full deficit (halving both sides left the image
        # one pixel short of the template when margin == 0).
        if missing_h > 0:
            pad_top = missing_h // 2 + margin
            pad_bottom = missing_h - missing_h // 2 + margin
        if missing_w > 0:
            pad_left = missing_w // 2 + margin
            pad_right = missing_w - missing_w // 2 + margin
        search = cv2.copyMakeBorder(
            search,
            pad_top, pad_bottom, pad_left, pad_right,
            cv2.BORDER_CONSTANT, value=0
        )

    results = []
    for sym in ["R0", "R1", "R2", "R3", "M1", "M2", "D1", "D2"]:
        current_small = get_symmetry(image_small, sym)
        template = np.ascontiguousarray(current_small, dtype=np.float32)
        res = cv2.matchTemplate(search, template, cv2.TM_CCOEFF_NORMED)
        _, max_val, _, max_loc = cv2.minMaxLoc(res)
        results.append({
            "name": image_name,
            "symmetry": sym,
            "score": max_val,
            # Undo the padding so the location refers to the unpadded image.
            "location": (max_loc[0] - pad_left, max_loc[1] - pad_top),
            "raman_img": current_small
        })
    best_result = max(results, key=lambda x: x['score'])
    return best_result, results
def find_best_match_features(img_1, img_2, image_name):
    """
    Match every symmetry of img_1 against img_2 using SIFT features.

    For each of the 8 symmetries:
      - SIFT keypoints of the transformed img_1 (the "Raman" image in the
        original workflow) are matched to those of img_2 (bright field) with
        a brute-force L2 2-nearest-neighbour search.
      - Matches are filtered with Lowe's ratio test (0.75).
      - A RANSAC homography (5 px reprojection threshold) is fitted.
    The score is the number of RANSAC inliers.

    Args:
        img_1, img_2: Single-channel images. Non-uint8 input is min-max
            rescaled to uint8, because SIFT only accepts 8-bit images.
        image_name: Label copied into every result dict.

    Returns:
        (best_result, results). Each result dict holds "name", "symmetry",
        "score", "homography" (None if it could not be estimated),
        "raman_img", "keypoints_1", "good_matches" and "matches_mask".
        Returns (None, []) when no symmetry had at least 5 keypoints in
        both images.
    """
    def _as_uint8(img):
        # Min-max rescale non-8-bit images into 0..255 for SIFT.
        if img.dtype == np.uint8:
            return img
        return cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)

    img_1 = _as_uint8(img_1)
    img_2 = _as_uint8(img_2)

    # Detector and matcher do not depend on the symmetry: create them once.
    sift = cv2.SIFT_create()
    matcher = cv2.BFMatcher(cv2.NORM_L2)
    kp2, des2 = sift.detectAndCompute(img_2, None)

    results = []
    for sym in ["R0", "R1", "R2", "R3", "M1", "M2", "D1", "D2"]:
        current_raman = get_symmetry(img_1, sym)
        kp1, des1 = sift.detectAndCompute(current_raman, None)
        if des1 is None or des2 is None or len(kp1) < 5 or len(kp2) < 5:
            continue

        matches = matcher.knnMatch(des1, des2, k=2)
        # Lowe's ratio test. knnMatch can return fewer than 2 neighbours for
        # a query, which would break `for m, n in matches`; skip those.
        good_matches = [
            pair[0] for pair in matches
            if len(pair) == 2 and pair[0].distance < 0.75 * pair[1].distance
        ]

        score = 0
        H = None
        matches_mask = []
        if len(good_matches) >= 4:  # a homography needs >= 4 correspondences
            src_pts = np.float32([kp1[m.queryIdx].pt for m in good_matches]).reshape(-1, 1, 2)
            dst_pts = np.float32([kp2[m.trainIdx].pt for m in good_matches]).reshape(-1, 1, 2)
            H, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 5.0)
            if mask is not None:
                matches_mask = mask.ravel().tolist()
                score = int(np.sum(matches_mask))

        results.append({
            "name": image_name,
            "symmetry": sym,
            "score": score,
            "homography": H,
            "raman_img": current_raman,
            "keypoints_1": kp1,
            "good_matches": good_matches,
            "matches_mask": matches_mask
        })

    if not results:
        return None, []
    best_result = max(results, key=lambda x: x['score'])
    return best_result, results
def find_best_match_fft(img_1, img_2, image_name):
    """
    Test all 8 symmetries of img_1 against img_2 using FFT phase correlation.

    img_2, and img_1 after each symmetry, are zero-padded and centred onto
    a common float32 canvas. They are then compared with cv2.phaseCorrelate
    using a Hanning window to reduce edge effects (spectral leakage).

    Args:
        img_1, img_2: 2D single-channel images.
        image_name: Label copied into every result dict.

    Returns:
        (best_result, results). Each result holds:
          - "name"
          - "symmetry"
          - "score": phase-correlation peak response; higher is better.
          - "shift_xy": (dx, dy).
          - "raman_img": the unpadded transformed img_1.
        Returns (None, []) if phase correlation failed for every symmetry.

        NOTE(review): shift_xy is measured between the centre-padded images,
        not as a top-left placement offset. Confirm the sign convention and
        the conversion before using it as a placement.
    """
    h1, w1 = img_1.shape[:2]
    h2, w2 = img_2.shape[:2]
    # R1, R3, D1 and D2 swap img_1's height and width, so the canvas must fit
    # img_1 in both orientations. Sizing it from the unrotated shape made the
    # padding below raise for non-square img_1, aborting the whole search.
    h_max = max(h1, w1, h2)
    w_max = max(w1, h1, w2)

    def pad_to_size(img, th, tw):
        """Centre *img* on a zero float32 canvas of shape (th, tw)."""
        h, w = img.shape
        if h == th and w == tw:
            return img.astype(np.float32)
        padded = np.zeros((th, tw), dtype=np.float32)
        y_off = (th - h) // 2
        x_off = (tw - w) // 2
        padded[y_off:y_off + h, x_off:x_off + w] = img
        return padded

    img_2_float = pad_to_size(img_2, h_max, w_max)
    window = cv2.createHanningWindow((w_max, h_max), cv2.CV_32F)

    results = []
    for sym in ["R0", "R1", "R2", "R3", "M1", "M2", "D1", "D2"]:
        current_raman = get_symmetry(img_1, sym)
        current_raman_float = pad_to_size(current_raman, h_max, w_max)
        try:
            shift, response = cv2.phaseCorrelate(current_raman_float, img_2_float, window=window)
        except Exception as e:
            # Best effort: FFT can fail on tiny or all-zero inputs.
            print(f"FFT Error on {sym}: {e}")
            continue
        results.append({
            "name": image_name,
            "symmetry": sym,
            "score": response,
            "shift_xy": (shift[0], shift[1]),
            "raman_img": current_raman
        })

    if not results:
        return None, []
    # Highest phase-correlation peak wins.
    best_result = max(results, key=lambda x: x['score'])
    return best_result, results
#%% === Fine Tuning Functionalities ===
def fine_tune_registration(fixed_img_cv, moving_img_cv, transform_type):
    """
    Refine the alignment of *moving_img_cv* onto *fixed_img_cv* with SimpleITK,
    maximising Mattes mutual information.

    Args:
        fixed_img_cv: Reference image (e.g. the bright-field crop), 2D grayscale.
        moving_img_cv: Image to register (e.g. the pre-oriented Raman image), 2D grayscale.
        transform_type: "Rigid", "Similarity", "Affine" or "BSpline".

    Returns:
        (registered_img, final_transform):
          - registered_img: the moving image resampled onto the fixed image
            grid as a NumPy array. Samples falling outside the moving image
            are 0.
          - final_transform: the optimised sitk.Transform.
        If the optimisation raises, a message is printed and
        (moving_img_cv, None) is returned.

    Raises:
        ValueError: if *transform_type* is not one of the supported names.
    """
    # Validate first so a bad name fails before any SimpleITK work.
    if transform_type not in ("Rigid", "Similarity", "Affine", "BSpline"):
        raise ValueError("Use 'Rigid', 'Similarity', 'Affine' or 'BSpline'.")

    # 1. Convert to SimpleITK float32 images.
    fixed = sitk.GetImageFromArray(fixed_img_cv.astype(np.float32))
    moving = sitk.GetImageFromArray(moving_img_cv.astype(np.float32))

    # 2. Initial transform.
    if transform_type == "BSpline":
        # Non-rigid (elastic) deformation on a 3x3 control-point mesh over the
        # fixed image.
        # NOTE(review): no rigid/affine pre-alignment is applied here, although
        # B-spline registration usually benefits from one.
        initial_transform = sitk.BSplineTransformInitializer(fixed, [3, 3])
    else:
        if transform_type == "Rigid":
            # 3 DOF: rotation + translation.
            base_transform = sitk.Euler2DTransform()
        elif transform_type == "Similarity":
            # 4 DOF: rotation + translation + uniform scale. Useful when the
            # two microscopes genuinely differ in magnification.
            base_transform = sitk.Similarity2DTransform()
        else:
            # 6 DOF: rotation + translation + scale + shear.
            base_transform = sitk.AffineTransform(2)
        # Start from the transform that aligns the geometric image centres.
        initial_transform = sitk.CenteredTransformInitializer(
            fixed, moving,
            base_transform,
            sitk.CenteredTransformInitializerFilter.GEOMETRY
        )

    # 3. Configure the registration method.
    R = sitk.ImageRegistrationMethod()
    # Metric: Mattes MI evaluated on a random 30% sample of pixels.
    # NOTE(review): random sampling may make results vary between runs unless
    # a seed is set — confirm SimpleITK's default seed behaviour.
    R.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)
    R.SetMetricSamplingStrategy(R.RANDOM)
    R.SetMetricSamplingPercentage(0.3)

    if transform_type == "BSpline":
        # L-BFGS-B copes better with the many B-spline parameters.
        R.SetOptimizerAsLBFGSB(gradientConvergenceTolerance=1e-5, numberOfIterations=100, maximumNumberOfCorrections=5)
    else:
        # Gradient descent for the linear transforms.
        R.SetOptimizerAsGradientDescent(learningRate=1.0, numberOfIterations=100, convergenceMinimumValue=1e-6, convergenceWindowSize=10)
    R.SetOptimizerScalesFromPhysicalShift()

    R.SetInitialTransform(initial_transform, inPlace=False)
    R.SetInterpolator(sitk.sitkLinear)

    # 4. Run the registration.
    try:
        final_transform = R.Execute(fixed, moving)
        print(f"Register {transform_type} complete. Metric value: {R.GetMetricValue():.4f}")
    except Exception as e:
        print(f"Register {transform_type} fails: {e}")
        return moving_img_cv, None

    # 5. Resample the moving image onto the fixed grid with the result.
    resampler = sitk.ResampleImageFilter()
    resampler.SetReferenceImage(fixed)
    resampler.SetInterpolator(sitk.sitkBSpline)
    resampler.SetDefaultPixelValue(0)
    resampler.SetTransform(final_transform)
    out_sitk = resampler.Execute(moving)
    return sitk.GetArrayFromImage(out_sitk), final_transform
#%% === Visualization ===
def visual_debugger(img_1, img_2):
    """
    Show a red/green overlay of two binarised images and their overlap score.

    How the masks are built:
      - img_1 is binarised with Otsu.
      - img_2 has values below 30 zeroed, is binarised with Otsu, then
        dilated and closed with a 15x15 ellipse to form solid regions.
      - If img_2's mask differs in size from img_1's, it is resized to match
        with nearest-neighbour sampling (keeps it binary).

    Red = img_1 only, green = img_2 only, yellow = both. The overlap score is
    |mask_1 AND mask_2| / min(|mask_1|, |mask_2|), or 0 if either mask is
    empty.

    Both images are cast with astype(np.uint8), so they should already be in
    the 0..255 range. The figure is displayed with plt.show().

    Returns:
        The overlap score.
    """
    otsu = cv2.THRESH_BINARY + cv2.THRESH_OTSU
    _, bin_r = cv2.threshold(img_1.astype(np.uint8), 0, 255, otsu)
    _, bf_thresh = cv2.threshold(img_2.astype(np.uint8), 30, 255, cv2.THRESH_TOZERO)
    _, bin_b = cv2.threshold(bf_thresh, 0, 255, otsu)
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15))
    bin_b_processed = cv2.dilate(bin_b, kernel, iterations=1)
    bin_b_processed = cv2.morphologyEx(bin_b_processed, cv2.MORPH_CLOSE, kernel)

    h, w = bin_r.shape[:2]
    if bin_b_processed.shape[:2] != (h, w):
        # Unequal sizes would make the channel assignment below fail.
        bin_b_processed = cv2.resize(bin_b_processed, (w, h), interpolation=cv2.INTER_NEAREST)

    viz = np.zeros((h, w, 3), dtype=np.uint8)
    viz[:, :, 0] = bin_r            # red channel: img_1
    viz[:, :, 1] = bin_b_processed  # green channel: img_2 (red + green shows as yellow)

    mask_r = bin_r > 0
    mask_b = bin_b_processed > 0
    intersection = np.count_nonzero(np.logical_and(mask_r, mask_b))
    smaller_area = min(np.count_nonzero(mask_r), np.count_nonzero(mask_b))
    score = intersection / smaller_area if smaller_area > 0 else 0

    plt.figure(figsize=(12, 12))
    plt.imshow(viz)
    plt.title(f"Visual Debugger | Overlap Score: {score:.4f}", fontsize=14, fontweight='bold')
    plt.axis('off')
    patch_red = mpatches.Patch(color='red', label='Img 1')
    patch_green = mpatches.Patch(color='green', label='Img 2')
    patch_yellow = mpatches.Patch(color='yellow', label='MATCH')
    plt.legend(handles=[patch_red, patch_green, patch_yellow],
               loc='upper right', framealpha=0.9, fontsize=12, facecolor='black', labelcolor='white')
    plt.tight_layout()
    plt.show()
    return score
def show_in_napari(img_1, img_2):
    """
    Open a napari viewer for visually inspecting two (registered) images.

    Layers, added bottom to top:
        'IMG 1'               img_1, gray colormap, fully opaque
        'IMG 2'               img_2, inferno colormap, additive, opacity 0.8
        'Debug: IMG 1 Mask'   Otsu mask of img_1, green, additive, hidden
        'Debug: IMG 2 Mask'   Otsu mask of img_2 after zeroing values < 30,
                              red, additive, hidden

    Finishes by calling napari.run() to start the viewer's event loop.
    """
    otsu = cv2.THRESH_BINARY + cv2.THRESH_OTSU
    mask_1 = cv2.threshold(img_1.astype(np.uint8), 0, 255, otsu)[1]
    floored_2 = cv2.threshold(img_2.astype(np.uint8), 30, 255, cv2.THRESH_TOZERO)[1]
    mask_2 = cv2.threshold(floored_2, 0, 255, otsu)[1]

    layer_specs = [
        (img_1, dict(name='IMG 1', colormap='gray', opacity=1.0)),
        (img_2, dict(name='IMG 2', colormap='inferno', blending='additive', opacity=0.8)),
        (mask_1, dict(name='Debug: IMG 1 Mask', colormap='green', blending='additive',
                      opacity=0.5, visible=False)),
        (mask_2, dict(name='Debug: IMG 2 Mask', colormap='red', blending='additive',
                      opacity=0.5, visible=False)),
    ]

    viewer = napari.Viewer(title="IMG 1 vs IMG 2 Registration Debugger")
    for data, options in layer_specs:
        viewer.add_image(data, **options)
    napari.run()
def RGBA_visualization(img_1, img_2):
    """
    Merge two images into one 4-channel image: img_1 provides colour (JET
    colormap) and img_2 provides the alpha channel.

    Args:
        img_1: Image to colour-map. Non-uint8 input is min-max rescaled to 0..255.
        img_2: Image used as transparency. Non-uint8 input is rescaled the
            same way. Multi-channel input is reduced to one channel (BGR order
            assumed). It is resized to img_1's height and width if they differ.

    Returns:
        (rgba_img, b, g, r, alpha). Despite its name, rgba_img is in OpenCV's
        B, G, R, A channel order.
    """
    # 1. Both images must be uint8 for applyColorMap/merge.
    if img_1.dtype != np.uint8:
        img_1 = cv2.normalize(img_1, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
    if img_2.dtype != np.uint8:
        img_2 = cv2.normalize(img_2, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)

    # 2. Alpha must be a single channel, or cv2.merge rejects it.
    if img_2.ndim == 3:
        if img_2.shape[2] == 1:
            img_2 = img_2[..., 0]
        else:
            img_2 = cv2.cvtColor(np.ascontiguousarray(img_2[..., :3]), cv2.COLOR_BGR2GRAY)

    # 3. Height and width must match exactly (cv2.resize takes (width, height)).
    if img_1.shape[:2] != img_2.shape[:2]:
        img_2 = cv2.resize(img_2, (img_1.shape[1], img_1.shape[0]))

    # 4. Colour-map img_1 (3-channel uint8) and split it.
    img_1_color = cv2.applyColorMap(img_1, cv2.COLORMAP_JET)
    b, g, r = cv2.split(img_1_color)

    # 5. Use img_2 as alpha and merge into B, G, R, A.
    alpha = img_2
    rgba_img = cv2.merge([b, g, r, alpha])
    return rgba_img, b, g, r, alpha
# ==========================================
# 1. HELPER FUNCTIONS
# ==========================================
def load_from_file(file_obj):
    """
    Read an image file with OpenCV, keeping its original bit depth.

    Args:
        file_obj: None, a str/PathLike path, or an upload object exposing the
            file path as ``.name`` (as passed by gr.File).

    Returns:
        The image as a NumPy array, with colour images converted from
        OpenCV's BGR order to RGB. Returns None if no file was given or
        OpenCV could not decode it.
    """
    if file_obj is None:
        return None
    # Plain paths must be checked first: pathlib.Path also has a `.name`, but
    # it holds only the final file-name component, not the full path.
    if isinstance(file_obj, (str, os.PathLike)):
        path = os.fspath(file_obj)
    else:
        path = file_obj.name
    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if img is None:
        return None
    if img.ndim == 3:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return img
def load_and_normalize(image):
    """
    Convert a freshly loaded image into a displayable 8-bit array.

    Any non-uint8 image (16-bit TIFFs, signed integers, floats, ...) is
    min-max rescaled to 0..255 and truncated to uint8. A constant image
    becomes all zeros. uint8 images keep their values. 2D grayscale results
    are replicated into three identical channels (RGB layout).

    Returns None when *image* is None.
    """
    if image is None:
        return None
    if image.dtype != np.uint8:
        # Work in float64 so signed, unsigned and float inputs all subtract
        # safely. Previously only uint16/float32/float64 were normalised, so
        # other depths reached OpenCV unconverted.
        data = image.astype(np.float64)
        min_val = data.min()
        max_val = data.max()
        if max_val - min_val > 0:
            image = ((data - min_val) / (max_val - min_val) * 255).astype(np.uint8)
        else:
            image = np.zeros(image.shape, dtype=np.uint8)
    if image.ndim == 2:
        # Same result as cv2.cvtColor(image, cv2.COLOR_GRAY2RGB).
        image = np.stack([image, image, image], axis=-1)
    return image
def to_display(image, max_width=800):
    """
    Return a preview-sized version of *image* for plotting.

    Images wider than *max_width* are shrunk to exactly that width, keeping
    the aspect ratio. Nearest-neighbour sampling is used, so only existing
    pixel values appear (suits binary masks). Narrower images and None are
    returned unchanged (same object, no copy).
    """
    if image is None:
        return None
    h, w = image.shape[:2]
    if w <= max_width:
        return image
    scale = max_width / w
    # A very wide, short image would otherwise get height 0, which cv2.resize rejects.
    new_h = max(1, int(h * scale))
    return cv2.resize(image, (max_width, new_h), interpolation=cv2.INTER_NEAREST)
def to_interactive_plot(image, height=400):
    """
    Wrap *image* in a zoomable, pannable Plotly figure.

    2D (grayscale) input is first expanded to three channels so Plotly shows
    it as an RGB image. Axes and margins are hidden and the default drag
    mode is pan. Returns None when *image* is None.
    """
    if image is None:
        return None
    rgb = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB) if image.ndim == 2 else image
    hidden_axis = {'showticklabels': False, 'visible': False}
    figure = px.imshow(rgb)
    figure.update_layout(
        margin=dict(l=0, r=0, b=0, t=0),
        xaxis=dict(hidden_axis),
        yaxis=dict(hidden_axis),
        dragmode='pan',
        height=height,
    )
    return figure
def create_full_tensor_files(stack):
    """
    Export the sandwich images as a NumPy binary (.npy) and a text dump (.txt).

    Args:
        stack: List of dicts, each holding an image array under 'img'.

    Returns:
        (npy_filename, txt_filename), both written to the current working
        directory, or (None, None) for an empty stack.

    Images of equal shape are stacked into a single (N, ...) array.
    Otherwise a 1-D object array of the individual images is saved
    (np.load then needs allow_pickle=True). The text file starts with '#'
    comment lines followed by a Python literal of nested lists.
    """
    if not stack:
        return None, None
    arrays = [item['img'] for item in stack]
    try:
        tensor_np = np.stack(arrays, axis=0)
    except ValueError:
        # Shapes differ. Fill a pre-sized object array element by element:
        # np.array(arrays, dtype=object) tries to broadcast when leading
        # dimensions happen to match and raises on such ragged input.
        tensor_np = np.empty(len(arrays), dtype=object)
        for i, arr in enumerate(arrays):
            tensor_np[i] = arr
        header_info = "Ragged Array (dimensions vary)"
        # Convert each image separately so the dump is a plain literal,
        # not a repr of ndarray objects.
        nested = [np.asarray(arr).tolist() for arr in arrays]
    else:
        header_info = f"Shape: {tensor_np.shape}, Dtype: {tensor_np.dtype}"
        nested = tensor_np.tolist()

    npy_filename = "sandwich_tensor.npy"
    np.save(npy_filename, tensor_np)

    txt_filename = "sandwich_tensor_text.txt"
    with open(txt_filename, "w", encoding="utf-8") as f:
        f.write(f"# {header_info}\n")
        # The hint skips the '#' header lines so literal_eval sees only data.
        f.write(
            "# Load with: import ast; data = ast.literal_eval(''.join("
            f"l for l in open('{txt_filename}') if not l.startswith('#')))\n"
        )
        f.write(str(nested))
    return npy_filename, txt_filename
def create_montage_plot(stack):
    """
    Save all layers side by side (with axes) as 'sandwich_montage.png'.

    Args:
        stack: List of dicts with an image under 'img' and a label under 'name'.

    Returns:
        The PNG filename (written to the current working directory), or None
        for an empty stack.
    """
    if not stack:
        return None
    n = len(stack)
    # squeeze=False always yields a 2D grid of Axes, even for a single layer.
    fig, axes = plt.subplots(1, n, figsize=(5 * n, 5), constrained_layout=True, squeeze=False)
    filename = "sandwich_montage.png"
    try:
        for i, (ax, item) in enumerate(zip(axes[0], stack)):
            img = item['img']
            if img.ndim == 2:
                ax.imshow(img, cmap='gray')
            else:
                ax.imshow(img)
            ax.set_title(f"{i}: {item['name']}")
            ax.axis('on')
        # Save this figure explicitly instead of pyplot's shared "current
        # figure", which other handlers could change.
        fig.savefig(filename, format='png', dpi=150)
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)
    return filename
def create_metadata_csv(stack):
    """
    Write per-layer metadata of the sandwich to 'sandwich_metadata.csv'.

    Each layer's 'history' is split in two:
      - steps containing "Cropped", "Shift" or "Symmetry" go to the
        "Cropping/Registration Details" column;
      - all other steps go to "Preprocessing Steps".
    Both lists are joined with " -> ".

    Returns:
        The CSV filename (current working directory), or None for an empty stack.
    """
    if not stack:
        return None

    registration_markers = ("Cropped", "Shift", "Symmetry")

    def is_registration_step(step):
        return any(marker in step for marker in registration_markers)

    records = []
    for index, layer in enumerate(stack):
        prep, reg = [], []
        for step in layer.get('history', []):
            (reg if is_registration_step(step) else prep).append(step)
        records.append({
            "Layer Index": index,
            "Layer Name": layer['name'],
            "Shape": str(layer['img'].shape),
            "Dtype": str(layer['img'].dtype),
            "Preprocessing Steps": " -> ".join(prep),
            "Cropping/Registration Details": " -> ".join(reg),
        })

    filename = "sandwich_metadata.csv"
    pd.DataFrame(records).to_csv(filename, index=False)
    return filename
# ==========================================
# 2. GRADIO INTERFACE
# ==========================================
# Custom CSS for the Gradio app: page background colour and styling for
# elements tagged with elem_classes="section-header".
# NOTE(review): this string is not passed to gr.Blocks(...) below (which only
# sets title="MRS Demo"), so these rules appear not to be applied — confirm intent.
css = """
.gradio-container {background-color: #f4f6f9}
.section-header {background: #eef2f6; padding: 10px; border-radius: 8px; margin-bottom: 10px;}
"""
with gr.Blocks(title="MRS Demo") as app:
s_fixed_raw = gr.State(None)
s_fixed_proc = gr.State(None)
s_f_hist = gr.State([])
s_moving_raw = gr.State(None)
s_moving_proc = gr.State(None)
s_m_hist = gr.State([])
r_fixed_raw = gr.State(None)
r_fixed_proc = gr.State(None)
r_moving_raw = gr.State(None)
r_moving_proc = gr.State(None)
r_params = gr.State({"dx": 0, "dy": 0, "sym": "R0"})
s_stack = gr.State([])
res_fixed_state = gr.State(None)
res_fixed_hist_state = gr.State([])
res_moving_state = gr.State(None)
res_moving_hist_state = gr.State([])
gr.Markdown("## MRS 🥪 : Multimodal Registration Sandwich")
gr.Markdown("<span style='font-size: 16px; font-weight: normal;'>Generalized Framework for Universal Microscopy Image Correlation</span>")
with gr.Tabs() as main_tabs:
# ============================================================
# TAB 1: UPLOAD & PREPROCESS
# ============================================================
with gr.TabItem("1. Upload & Preprocess", id=0):
with gr.Row():
# === FIXED IMAGE ===
with gr.Column():
gr.Markdown("<h3 style='text-align: center;'>Fixed Image (Reference)</h3>")
file_f_input = gr.File(label="Upload Fixed Image (TIFF/PNG/JPG)")
plot_f_view = gr.Plot(label="Fixed Preview (Interactive)", container=True)
with gr.Group():
gr.Markdown("**Structural**")
with gr.Row():
num_f_resize = gr.Number(value=1.0, step=0.1, label="Scale", container=False, scale=1)
btn_f_resize = gr.Button("Apply Resize", size="sm", scale=2)
gr.Markdown("**Preprocess Options**")
with gr.Row():
btn_f_gray = gr.Button("Gray", size="sm")
btn_f_norm = gr.Button("Norm", size="sm")
btn_f_inv = gr.Button("Invert", size="sm")
with gr.Row():
btn_f_clahe = gr.Button("CLAHE", size="sm")
btn_f_mask = gr.Button("Binary Mask", size="sm")
btn_f_reset = gr.Button("Reset", variant="secondary")
# === MOVING IMAGE ===
with gr.Column():
gr.Markdown("<h3 style='text-align: center;'>Moving Image (To Register)</h3>")
file_m_input = gr.File(label="Upload Moving Image (TIFF/PNG/JPG)")
plot_m_view = gr.Plot(label="Moving Preview (Interactive)", container=True)
with gr.Group():
gr.Markdown("**Structural**")
with gr.Row():
num_m_resize = gr.Number(value=1.0, step=0.1, label="Scale", container=False, scale=1)
btn_m_resize = gr.Button("Apply Resize", size="sm", scale=2)
gr.Markdown("**Preprocess Options**")
with gr.Row():
btn_m_gray = gr.Button("Gray", size="sm")
btn_m_norm = gr.Button("Norm", size="sm")
btn_m_inv = gr.Button("Invert", size="sm")
with gr.Row():
btn_m_clahe = gr.Button("CLAHE", size="sm")
btn_m_mask = gr.Button("Binary Mask", size="sm")
btn_m_reset = gr.Button("Reset", variant="secondary")
with gr.Row():
btn_proceed = gr.Button("Proceed to Registration", variant="primary", size="lg")
# ============================================================
# TAB 2: REGISTRATION
# ============================================================
with gr.TabItem("2. Register & Crop", id=1):
with gr.Row():
# CONTROLS
with gr.Column(scale=1):
gr.Markdown("#### 1. Auto-Match", elem_classes="section-header")
dd_algo = gr.Dropdown(
["Pixel Intensity (Brute Force)"],
value="Pixel Intensity (Brute Force)", show_label=False
)
btn_auto_match = gr.Button("Run Auto-Match", variant="primary")
gr.Markdown("#### 2. Fine-Tune", elem_classes="section-header")
slider_x = gr.Number(label="Shift X", value=0)
slider_y = gr.Number(label="Shift Y", value=0)
with gr.Row():
btn_rot_cw = gr.Button("↻ 90°")
btn_rot_ccw = gr.Button("↺ -90°")
gr.Markdown("---")
btn_confirm = gr.Button("Confirm & Crop Originals", variant="primary")
# LIVE OVERLAY
with gr.Column(scale=3):
gr.Markdown("<div style='text-align: center; font-weight: 900 !important;'>Live Alignment Preview</div>")
img_overlay = gr.Plot(label="Live Overlay", container=True)
slider_opacity = gr.Slider(0, 1, value=0.5, label="Opacity")
# CROPPED RESULTS
with gr.Group(visible=False) as grp_results:
gr.Markdown("### Final Cropped Originals", elem_classes="section-header")
with gr.Row():
with gr.Column():
gr.Markdown("**Fixed Image (Cropped)**")
res_fixed_crop = gr.Image(interactive=False, height=300)
btn_add_f = gr.Button("Add to Sandwich")
btn_use_f = gr.Button("Use as Input (Tab 1)")
with gr.Column():
gr.Markdown("**Moving Image**")
res_moving_crop = gr.Image(interactive=False, height=300)
btn_add_m = gr.Button("Add to Sandwich")
btn_use_m = gr.Button("Use as Input (Tab 1)")
txt_log = gr.Textbox(label="Log", lines=2)
# ============================================================
# TAB 3: SANDWICH
# ============================================================
with gr.TabItem("3. Sandwich & Export", id=2):
gr.Markdown("### Sandwich Composition", elem_classes="section-header")
with gr.Row():
with gr.Column(scale=3):
# CHANGED: Better visualization
gr.Markdown("*Visual Layers (Full View)*")
gallery = gr.Gallery(
label="Visual Layers",
columns=3,
height="auto",
object_fit="contain"
)
df_layers = gr.Dataframe(
headers=["Index", "Name", "Shape", "Steps"],
label="Layer Metadata",
interactive=False
)
with gr.Column(scale=1):
gr.Markdown("**Manage Layers**")
num_idx = gr.Number(label="Target Layer Index", value=0, precision=0)
with gr.Row():
btn_up = gr.Button("⬆️ Up")
btn_down = gr.Button("⬇️ Down")
with gr.Row():
txt_rename = gr.Textbox(label="New Name", placeholder="Enter name...")
btn_rename = gr.Button(" Rename")
btn_delete = gr.Button(" Delete Layer", variant="stop")
gr.Markdown("---")
gr.Markdown("### Export Options")
with gr.Row():
with gr.Column():
btn_exp_images = gr.Button("1. Prepare Images (.zip)", variant="secondary")
file_zip = gr.File(label="Images (.zip)")
with gr.Column():
btn_exp_tensor = gr.Button("2. Prepare Tensor (.npy & .txt)", variant="secondary")
with gr.Row():
file_npy = gr.File(label="Numpy Binary (.npy)")
file_txt_arr = gr.File(label="Text Representation (.txt)")
with gr.Column():
btn_exp_data = gr.Button("3. Prepare Metadata & Montage", variant="secondary")
with gr.Row():
file_csv = gr.File(label="Metadata (.csv)")
file_montage = gr.File(label="Montage Plot (.png)")
gr.Markdown("---")
gr.Markdown("""
<div style='text-align: center; color: #666;'>
<p style='margin-bottom: 0;'>Datrix SPA - Politecnico di Milano (VIBRA Group)</p>
<p style='font-size: 14px; margin-top: 5px;'>Contact: marc.rodriguez@datrixgroup.com / marc.rodriguez@polimi.it</p>
</div>
""")
# ============================================================
# 3. LOGIC & HANDLERS
# ============================================================
# --- UPLOAD HANDLERS (Initialize History) ---
def on_file_upload(file_obj):
    """
    Gradio handler for a new (or cleared) upload.

    Returns:
        (raw, processed, preview_figure, history). raw and processed both
        start as the same normalised 8-bit image, and history is ["Loaded"].
        When nothing could be loaded (upload cleared or unreadable file), the
        three image outputs are None and the history is empty.
    """
    raw_img = load_from_file(file_obj)
    norm = load_and_normalize(raw_img)
    if norm is None:
        return None, None, None, []
    return norm, norm, to_interactive_plot(to_display(norm, 800), height=350), ["Loaded"]
file_f_input.change(on_file_upload, file_f_input, [s_fixed_raw, s_fixed_proc, plot_f_view, s_f_hist])
file_m_input.change(on_file_upload, file_m_input, [s_moving_raw, s_moving_proc, plot_m_view, s_m_hist])
# --- PREPROCESSING (Update History) ---
def process_image(raw, current_proc, history, op, param=None):
    """
    Gradio handler that applies one preprocessing operation.

    Args:
        raw: Current base image (only replaced by "Resize").
        current_proc: Result of the operations applied so far.
        history: List of step descriptions so far (not mutated).
        op: "Resize", "Reset", "Gray", "Invert", "Norm", "CLAHE" or
            "Binary Mask"; any other value leaves the image unchanged.
        param: Scale factor for "Resize".

    Returns:
        (raw, processed, preview_figure, history).
    """
    if raw is None:
        return None, None, None, history

    def _preview(img):
        return to_interactive_plot(to_display(img, 800), height=350)

    # 1. Structural op: resizing replaces the base image and restarts
    #    processing from the resized version.
    if op == "Resize":
        try:
            scale = float(param)
            new_raw = resize_image(raw, scale)
        except (TypeError, ValueError, cv2.error):
            # Unusable scale (empty field, text, <= 0, ...): keep current state.
            return raw, current_proc, _preview(current_proc), history
        return new_raw, new_raw, _preview(new_raw), list(history) + [f"Resize (scale={scale})"]

    if op == "Reset":
        return raw, raw, _preview(raw), ["Reset"]

    # 2. Intensity ops act on a copy of the current processed image.
    img_to_mod = current_proc.copy()
    updated_history = list(history)
    if op == "Gray":
        res = rgb2gray(img_to_mod).astype(np.uint8) if img_to_mod.ndim == 3 else img_to_mod
        updated_history.append("Grayscale")
    elif op == "Invert":
        res = cv2.bitwise_not(img_to_mod)
        updated_history.append("Invert")
    elif op == "Norm":
        res = (Normalize(img_to_mod) * 255).astype(np.uint8)
        updated_history.append("Normalize")
    elif op == "CLAHE":
        # CLAHE needs a single channel, but uploads are normalised to
        # 3 channels; reduce to grayscale first (same as the "Gray" button).
        gray = rgb2gray(img_to_mod).astype(np.uint8) if img_to_mod.ndim == 3 else img_to_mod
        res = apply_clahe(gray)
        updated_history.append("CLAHE")
    elif op == "Binary Mask":
        res = binary_mask(img_to_mod)
        updated_history.append("Binary Mask")
    else:
        res = img_to_mod
    return raw, res, _preview(res), updated_history
# Wired Events (Fixed) - Include s_f_hist
btn_f_resize.click(process_image, [s_fixed_raw, s_fixed_proc, s_f_hist, gr.State("Resize"), num_f_resize], [s_fixed_raw, s_fixed_proc, plot_f_view, s_f_hist])
for btn, op in [(btn_f_gray,"Gray"), (btn_f_norm,"Norm"), (btn_f_inv,"Invert"), (btn_f_clahe,"CLAHE"), (btn_f_mask, "Binary Mask")]:
btn.click(process_image, [s_fixed_raw, s_fixed_proc, s_f_hist, gr.State(op)], [s_fixed_raw, s_fixed_proc, plot_f_view, s_f_hist])
btn_f_reset.click(process_image, [s_fixed_raw, s_fixed_proc, s_f_hist, gr.State("Reset")], [s_fixed_raw, s_fixed_proc, plot_f_view, s_f_hist])
# Wired Events (Moving) - Include s_m_hist
btn_m_resize.click(process_image, [s_moving_raw, s_moving_proc, s_m_hist, gr.State("Resize"), num_m_resize], [s_moving_raw, s_moving_proc, plot_m_view, s_m_hist])
for btn, op in [(btn_m_gray,"Gray"), (btn_m_norm,"Norm"), (btn_m_inv,"Invert"), (btn_m_clahe,"CLAHE"), (btn_m_mask, "Binary Mask")]:
btn.click(process_image, [s_moving_raw, s_moving_proc, s_m_hist, gr.State(op)], [s_moving_raw, s_moving_proc, plot_m_view, s_m_hist])
btn_m_reset.click(process_image, [s_moving_raw, s_moving_proc, s_m_hist, gr.State("Reset")], [s_moving_raw, s_moving_proc, plot_m_view, s_m_hist])
# --- TRANSITION ---
def proceed_to_reg(fr, fp, mr, mp):
    """
    Copy the tab-1 images into the registration-tab state and switch to tab 2.

    If either raw image is missing, all four registration states are set to
    None and a bare gr.Tabs() update (no tab selection) is returned.
    """
    both_loaded = not (fr is None or mr is None)
    if both_loaded:
        return fr, fp, mr, mp, gr.Tabs(selected=1)
    return (None,) * 4 + (gr.Tabs(),)
btn_proceed.click(proceed_to_reg, [s_fixed_raw, s_fixed_proc, s_moving_raw, s_moving_proc], [r_fixed_raw, r_fixed_proc, r_moving_raw, r_moving_proc, main_tabs])
# --- REGISTRATION LOGIC ---
# Overlay rendering, auto-matching and manual-adjustment handlers for the registration tab.
def gen_overlay(fixed_raw, moving_raw, dx, dy, sym, opacity):
    """
    Render the moving image over the fixed image as an interactive Plotly figure.

    The moving image gets symmetry *sym* and is pasted onto a black canvas
    the size of the fixed image, top-left corner at (dx, dy). Parts outside
    the canvas are cropped. The result is blended as
    (1 - opacity) * fixed + opacity * canvas.

    Both images are brought to 3 channels first: grayscale is expanded and
    any 4th channel is dropped. A None shift (e.g. a cleared number box)
    counts as 0.

    Returns None if either image is missing.
    """
    if fixed_raw is None or moving_raw is None:
        return None

    def _as_rgb(img):
        # Common 3-channel layout so the paste and the blend cannot mismatch.
        if img.ndim == 2 or img.shape[2] == 1:
            return cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        return np.ascontiguousarray(img[..., :3])

    fixed = _as_rgb(fixed_raw)
    moved = _as_rgb(get_symmetry(moving_raw, sym))
    dx = 0 if dx is None else int(dx)
    dy = 0 if dy is None else int(dy)

    h, w = fixed.shape[:2]
    h_m, w_m = moved.shape[:2]
    canvas = np.zeros((h, w, 3), dtype=fixed.dtype)

    # Visible overlap in canvas coordinates...
    x1, y1 = max(0, dx), max(0, dy)
    x2, y2 = min(w, dx + w_m), min(h, dy + h_m)
    # ...and the matching top-left offset inside the moving image.
    mx1, my1 = max(0, -dx), max(0, -dy)
    if x2 > x1 and y2 > y1:
        canvas[y1:y2, x1:x2] = moved[my1:my1 + (y2 - y1), mx1:mx1 + (x2 - x1)]

    blended = cv2.addWeighted(fixed, 1 - opacity, canvas, opacity, 0)
    return to_interactive_plot(to_display(blended, 800), height=500)
def on_auto(fp, mp, fr, mr, algo):
    """Run automatic registration and produce the registration-tab outputs.

    Args:
        fp, mp: preprocessed fixed / moving images used for matching.
        fr, mr: raw fixed / moving images used to render the overlay preview.
        algo: algorithm label; dispatched on the substrings "Pixel",
            "Feature" or "FFT".

    Returns:
        ``(dx, dy, log, overlay, params)`` for the outputs
        [slider_x, slider_y, txt_log, img_overlay, r_params], where
        ``params == {"dx": dx, "dy": dy, "sym": sym}``. Failures are reported
        in ``log`` with a zero shift and symmetry "R0".
    """
    # Every return path must yield exactly five values in the order above.
    if fp is None or mp is None:
        return 0, 0, "No Data", None, {"dx": 0, "dy": 0, "sym": "R0"}
    fp_g = fp if fp.ndim == 2 else rgb2gray(fp)
    mp_g = mp if mp.ndim == 2 else rgb2gray(mp)
    try:
        if "Pixel" in algo:
            res, _ = find_best_match_pixel(fp_g, mp_g, "x")
        elif "Feature" in algo:
            res, _ = find_best_match_features(fp_g, mp_g, "x")
        elif "FFT" in algo:
            res, _ = find_best_match_fft(fp_g, mp_g, "x")
        else:
            # Without this, ``res`` would be unbound and the log would show an
            # UnboundLocalError instead of the actual problem.
            raise ValueError(f"Unknown algorithm: {algo}")
        if res:
            # Prefer "shift_xy" over "location" when the matcher provides it.
            loc = res["shift_xy"] if "shift_xy" in res else res.get("location", (0, 0))
            dx, dy, sym = int(loc[0]), int(loc[1]), res["symmetry"]
            log = f"Score: {res['score'] * 100:.2f}% (Symmetry: {sym})"
        else:
            dx, dy, sym, log = 0, 0, "R0", "No match"
    except Exception as e:
        # Best effort: show the error in the log box and reset the alignment.
        dx, dy, sym, log = 0, 0, "R0", str(e)
    ov = gen_overlay(fr, mr, dx, dy, sym, 0.5)
    return dx, dy, log, ov, {"dx": dx, "dy": dy, "sym": sym}
btn_auto_match.click(
    fn=on_auto,
    inputs=[r_fixed_proc, r_moving_proc, r_fixed_raw, r_moving_raw, dd_algo],
    outputs=[slider_x, slider_y, txt_log, img_overlay, r_params],
)
def on_manual(fr, mr, dx, dy, op, params):
    """Redraw the overlay for the current slider values and record the shift.

    Args:
        fr, mr: raw fixed / moving images.
        dx, dy: slider shift values.
        op: overlay opacity.
        params: registration parameter dict, or None.

    Returns:
        ``(overlay, params)``. ``params`` is a copy of the incoming dict (empty
        if None) with "dx"/"dy" set to the slider values and "sym" defaulted
        to "R0" when absent, so it always carries the keys the crop step reads.
    """
    params = dict(params) if params else {}
    sym = params.setdefault("sym", "R0")
    params["dx"], params["dy"] = dx, dy
    ov = gen_overlay(fr, mr, dx, dy, sym, op)
    return ov, params
for inp in (slider_x, slider_y, slider_opacity):
    inp.change(
        fn=on_manual,
        inputs=[r_fixed_raw, r_moving_raw, slider_x, slider_y, slider_opacity, r_params],
        outputs=[img_overlay, r_params],
    )
def on_rot(fr, mr, dx, dy, op, params, dir):
    """Step the symmetry code through R0 -> R1 -> R2 -> R3 and redraw the overlay.

    Args:
        fr, mr: raw fixed / moving images.
        dx, dy: current slider shift.
        op: overlay opacity.
        params: registration parameter dict, or None.
        dir: "cw" steps forward; any other value steps backward (modulo 4).
            (The name shadows the builtin; kept for interface compatibility.)

    Returns:
        ``(overlay, params)``. ``params`` is a copy (empty if None) with the new
        "sym" and the displayed "dx"/"dy", so it matches the rendered overlay.
        An unrecognised current symmetry is treated as "R0".
    """
    syms = ["R0", "R1", "R2", "R3"]
    params = dict(params) if params else {}
    current = params.get("sym", "R0")
    idx = syms.index(current) if current in syms else 0
    idx = (idx + 1) % len(syms) if dir == "cw" else (idx - 1) % len(syms)
    params["sym"] = syms[idx]
    params["dx"], params["dy"] = dx, dy
    ov = gen_overlay(fr, mr, dx, dy, params["sym"], op)
    return ov, params
btn_rot_cw.click(
    fn=on_rot,
    inputs=[r_fixed_raw, r_moving_raw, slider_x, slider_y, slider_opacity, r_params, gr.State("cw")],
    outputs=[img_overlay, r_params],
)
btn_rot_ccw.click(
    fn=on_rot,
    inputs=[r_fixed_raw, r_moving_raw, slider_x, slider_y, slider_opacity, r_params, gr.State("ccw")],
    outputs=[img_overlay, r_params],
)
# --- CROP & CONFIRM (Pass History) ---
def _empty_crop_outputs():
    """Outputs for apply_crop_wrapper when there is nothing to crop."""
    return None, None, [], [], None, None, gr.Group(visible=True)


def apply_crop_wrapper(f_raw, m_raw, params, f_hist, m_hist):
    """Crop the raw fixed and moving images to their overlap under the alignment.

    The moving image is transformed with ``get_symmetry(m_raw, sym)`` and placed
    with its top-left corner at ``(dx, dy)`` in fixed-image coordinates. Both
    images are then cut to the intersection rectangle.

    Args:
        f_raw, m_raw: raw fixed / moving images (None if not loaded).
        params: dict with optional keys "dx", "dy", "sym" (defaults 0, 0, "R0").
        f_hist, m_hist: history lists for each image (None treated as empty).

    Returns:
        Seven values: cropped fixed, cropped moving, fixed history, moving
        history, fixed preview, moving preview, and a visible results-group
        update. When an image is missing or the images do not overlap, the
        image slots are None and the histories are empty lists.
    """
    if f_raw is None or m_raw is None:
        return _empty_crop_outputs()
    params = params or {}
    dx, dy = int(params.get("dx", 0)), int(params.get("dy", 0))
    sym = params.get("sym", "R0")
    m_moved = get_symmetry(m_raw, sym)
    h_f, w_f = f_raw.shape[:2]
    h_m, w_m = m_moved.shape[:2]
    x1, y1 = max(0, dx), max(0, dy)
    x2, y2 = min(w_f, dx + w_m), min(h_f, dy + h_m)
    if x2 <= x1 or y2 <= y1:
        return _empty_crop_outputs()
    full_c_fixed = f_raw[y1:y2, x1:x2]
    # Same rectangle expressed in moving-image coordinates.
    mx1, my1 = x1 - dx, y1 - dy
    full_c_moving = m_moved[my1:my1 + (y2 - y1), mx1:mx1 + (x2 - x1)]
    view_fixed = to_display(full_c_fixed)
    view_moving = to_display(full_c_moving)
    # Append registration details to each image's history.
    new_f_hist = list(f_hist or []) + [f"Cropped to ({x1}:{x2}, {y1}:{y2})"]
    new_m_hist = list(m_hist or []) + [f"Symmetry {sym}", f"Shift ({dx},{dy})", "Cropped to Intersection"]
    return full_c_fixed, full_c_moving, new_f_hist, new_m_hist, view_fixed, view_moving, gr.Group(visible=True)
btn_confirm.click(
    fn=apply_crop_wrapper,
    inputs=[r_fixed_raw, r_moving_raw, r_params, s_f_hist, s_m_hist],
    outputs=[
        res_fixed_state,
        res_moving_state,
        res_fixed_hist_state,
        res_moving_hist_state,
        res_fixed_crop,
        res_moving_crop,
        grp_results,
    ],
)
# --- SANDWICH MANAGEMENT ---
def update_sandwich_ui(stack):
    """Build the (state, gallery, table) triple that refreshes the sandwich view.

    Each table row is ``[index, name, str(shape), history joined by " -> "]``.
    Gallery thumbnails come from ``to_display(img, 300)``. The stack itself is
    returned unchanged as the first element.
    """
    rows = [
        [pos, layer["name"], str(layer["img"].shape), " -> ".join(layer.get("history", []))]
        for pos, layer in enumerate(stack)
    ]
    thumbnails = [to_display(layer["img"], 300) for layer in stack]
    return stack, thumbnails, rows
# Modified to accept history
def add_to_stack(img, hist, stack):
    """Append an image and its history as a new sandwich layer.

    Args:
        img: image array to add; when None, nothing is added.
        hist: history list stored with the layer (None stored as []).
        stack: list of layer dicts ``{"name", "img", "history"}``; appended to
            in place (a new list is used if None).

    Returns:
        The result of ``update_sandwich_ui(stack)``: (stack, gallery, table).
    """
    if stack is None:
        stack = []
    if img is None:
        # Re-render the existing stack so the gallery/table are not blanked.
        return update_sandwich_ui(stack)
    # Default name is Layer_<len>; bump the number if that name is already
    # taken (possible after deletions).
    taken = {layer.get("name") for layer in stack}
    n = len(stack)
    while f"Layer_{n}" in taken:
        n += 1
    stack.append({"name": f"Layer_{n}", "img": img, "history": list(hist or [])})
    return update_sandwich_ui(stack)
def manage_layers(stack, idx, action, new_name):
    """Apply a layer-management action to the sandwich stack in place.

    Args:
        stack: list of layer dicts (None treated as empty).
        idx: 0-based layer index (converted with ``int``). None or
            out-of-range values leave the stack unchanged.
        action: "delete", "up", "down" or "rename"; other values are no-ops.
        new_name: new layer name for "rename". Whitespace is stripped, and
            empty/None names are ignored.

    Returns:
        The result of ``update_sandwich_ui(stack)``.
    """
    if stack is None:
        stack = []
    if not stack or idx is None:
        return update_sandwich_ui(stack)
    i = int(idx)
    if not 0 <= i < len(stack):
        return update_sandwich_ui(stack)
    if action == "delete":
        del stack[i]
    elif action == "up":
        if i > 0:
            stack[i], stack[i - 1] = stack[i - 1], stack[i]
    elif action == "down":
        if i < len(stack) - 1:
            stack[i], stack[i + 1] = stack[i + 1], stack[i]
    elif action == "rename":
        cleaned = (new_name or "").strip()
        if cleaned:
            stack[i]["name"] = cleaned
    return update_sandwich_ui(stack)
# "Add" buttons send the cropped result together with its history state.
btn_add_f.click(
    fn=add_to_stack,
    inputs=[res_fixed_state, res_fixed_hist_state, s_stack],
    outputs=[s_stack, gallery, df_layers],
)
btn_add_m.click(
    fn=add_to_stack,
    inputs=[res_moving_state, res_moving_hist_state, s_stack],
    outputs=[s_stack, gallery, df_layers],
)
# Layer management: each button passes its action name to manage_layers.
btn_up.click(
    fn=manage_layers,
    inputs=[s_stack, num_idx, gr.State("up"), txt_rename],
    outputs=[s_stack, gallery, df_layers],
)
btn_down.click(
    fn=manage_layers,
    inputs=[s_stack, num_idx, gr.State("down"), txt_rename],
    outputs=[s_stack, gallery, df_layers],
)
btn_delete.click(
    fn=manage_layers,
    inputs=[s_stack, num_idx, gr.State("delete"), txt_rename],
    outputs=[s_stack, gallery, df_layers],
)
btn_rename.click(
    fn=manage_layers,
    inputs=[s_stack, num_idx, gr.State("rename"), txt_rename],
    outputs=[s_stack, gallery, df_layers],
)
# Recycle
def recycle(img):
    """Send a cropped result back to the preprocessing tab as a fresh input.

    Returns:
        (raw, processed, preview plot, history, tab update). Raw and processed
        are both ``load_and_normalize(img)``. The preview is rendered at 800 px
        with height 350. The history restarts at ["Recycled from Crop"], and
        tab index 0 is selected.
    """
    normalized = load_and_normalize(img)
    preview = to_display(normalized, 800)
    # Earlier history is dropped: the recycled image is treated as a new start.
    return (
        normalized,
        normalized,
        to_interactive_plot(preview, height=350),
        ["Recycled from Crop"],
        gr.Tabs(selected=0),
    )
btn_use_f.click(
    fn=recycle,
    inputs=res_fixed_state,
    outputs=[s_fixed_raw, s_fixed_proc, plot_f_view, s_f_hist, main_tabs],
)
btn_use_m.click(
    fn=recycle,
    inputs=res_moving_state,
    outputs=[s_moving_raw, s_moving_proc, plot_m_view, s_m_hist, main_tabs],
)
# --- NEW EXPORTS ---
# 1. Zip Images
def do_export_zip(stack):
    """Export every sandwich layer as a PNG inside a ZIP archive.

    Entries are named ``"{index}_{name}.png"``, with "/" and "\\" in layer names
    replaced by "_" so user-supplied names cannot create sub-paths. 3-D images
    go through ``cv2.COLOR_RGB2BGR`` before ``cv2.imencode``.

    Args:
        stack: list of layer dicts with "name" and "img" keys.

    Returns:
        Path to the archive, or None if the stack is empty/None.

    Raises:
        ValueError: if a layer cannot be encoded as PNG.
    """
    if not stack:
        return None
    import tempfile

    # A fresh directory per call stops concurrent sessions from overwriting
    # each other's archive, while keeping the downloaded file name.
    zip_file = os.path.join(tempfile.mkdtemp(prefix="sandwich_"), "sandwich_images.zip")
    with zipfile.ZipFile(zip_file, "w") as zf:
        for i, item in enumerate(stack):
            img = item["img"]
            bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) if img.ndim == 3 else img
            ok, buf = cv2.imencode(".png", bgr)
            if not ok:
                raise ValueError(f"Could not encode layer {i} ({item['name']}) as PNG")
            safe_name = str(item["name"]).replace("/", "_").replace("\\", "_")
            zf.writestr(f"{i}_{safe_name}.png", buf.tobytes())
    return zip_file
btn_exp_images.click(fn=do_export_zip, inputs=s_stack, outputs=file_zip)
# 2. Tensors: writes .npy and .txt files via create_full_tensor_files.
btn_exp_tensor.click(
    fn=create_full_tensor_files,
    inputs=s_stack,
    outputs=[file_npy, file_txt_arr],
)
# 3. Metadata & Montage
def do_export_meta_montage(stack):
    """Return ``(create_metadata_csv(stack), create_montage_plot(stack))``.

    The CSV is produced first, then the montage, for the outputs
    [file_csv, file_montage].
    """
    return create_metadata_csv(stack), create_montage_plot(stack)
btn_exp_data.click(
    fn=do_export_meta_montage,
    inputs=s_stack,
    outputs=[file_csv, file_montage],
)
if __name__ == "__main__":
    # NOTE(review): theme/css are passed to launch() rather than gr.Blocks();
    # this assumes a Gradio version whose launch() accepts them. Confirm the
    # pinned version.
    app.launch(
        inbrowser=True,
        share=False,
        theme=gr.themes.Soft(),
        css=css,
    )