update source code
Browse files- src/build/lib/loki/decompose.py +4 -4
- src/build/lib/loki/plot.py +92 -97
- src/build/lib/loki/utils.py +72 -140
- src/loki.egg-info/PKG-INFO +3 -2
- src/loki.egg-info/SOURCES.txt +1 -1
- src/loki/__pycache__/plot.cpython-39.pyc +0 -0
- src/loki/__pycache__/utils.cpython-39.pyc +0 -0
- src/loki/plot.py +92 -97
- src/loki/utils.py +72 -140
- src/requirements.txt +1 -0
src/build/lib/loki/decompose.py
CHANGED
|
@@ -77,11 +77,11 @@ def cell_type_decompose(sc_ad, st_ad, cell_type_col='cell_type', NMS_mode=False,
|
|
| 77 |
|
| 78 |
:param sc_ad: AnnData object containing single-cell meta data.
|
| 79 |
:param st_ad: AnnData object containing spatial data (ST or image) meta data.
|
| 80 |
-
:param density_prior: A numpy array providing prior information about cell densities in spatial spots.
|
| 81 |
:param cell_type_col: The column name in `sc_ad.obs` that contains cell type annotations. Default is 'cell_type'.
|
| 82 |
-
:param
|
| 83 |
-
:param
|
| 84 |
-
:param
|
|
|
|
| 85 |
:return: The spatial AnnData object with projected cell type annotations.
|
| 86 |
"""
|
| 87 |
|
|
|
|
| 77 |
|
| 78 |
:param sc_ad: AnnData object containing single-cell meta data.
|
| 79 |
:param st_ad: AnnData object containing spatial data (ST or image) meta data.
|
|
|
|
| 80 |
:param cell_type_col: The column name in `sc_ad.obs` that contains cell type annotations. Default is 'cell_type'.
|
| 81 |
+
:param NMS_mode: Boolean flag to apply Non-Maximum Suppression (NMS) mode. Default is False.
|
| 82 |
+
:param major_types: Major cell types used for NMS mode. Default is None.
|
| 83 |
+
:param min_percentile: The lower percentile used for clipping (defaults to 5).
|
| 84 |
+
:param max_percentile: The upper percentile used for clipping (defaults to 95).
|
| 85 |
:return: The spatial AnnData object with projected cell type annotations.
|
| 86 |
"""
|
| 87 |
|
src/build/lib/loki/plot.py
CHANGED
|
@@ -8,107 +8,102 @@ import numpy as np
|
|
| 8 |
from tqdm import tqdm
|
| 9 |
|
| 10 |
|
| 11 |
-
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
"""
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
:param ad_tar_coor: Numpy array of target coordinates to be plotted in the first subplot.
|
| 17 |
-
:param ad_src_coor: Numpy array of source coordinates to be plotted in the second subplot.
|
| 18 |
-
:param homo_coor: Numpy array of alignment of source coordinates to be plotted in the third subplot.
|
| 19 |
-
:param pca_hex_comb: Color values (e.g., PCA or hex values) for plotting the coordinates.
|
| 20 |
-
:param tar_features: Feature matrix for the target, used to split color values between target and source data.
|
| 21 |
-
:param shift: Value used to adjust the plot limits around the coordinates for better visualization. Default is 300.
|
| 22 |
-
:param s: Marker size for the scatter plot points. Default is 0.8.
|
| 23 |
-
:param boundary_line: Boolean indicating whether to draw boundary lines (horizontal and vertical lines). Default is True.
|
| 24 |
-
:return: Displays the alignment plot of target, source, and alignment of source coordinates.
|
| 25 |
"""
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
plt.
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
plt.axvline(x=ad_tar_coor[:, 0].min(), color='black') # Vertical boundary line at the minimum x of target coordinates
|
| 54 |
-
plt.axhline(y=ad_tar_coor[:, 1].min(), color='black') # Horizontal boundary line at the minimum y of target coordinates
|
| 55 |
-
|
| 56 |
-
# Remove the axis labels and ticks from all subplots for a cleaner appearance
|
| 57 |
-
plt.axis('off')
|
| 58 |
-
|
| 59 |
-
# Display the plot
|
| 60 |
plt.show()
|
| 61 |
|
| 62 |
|
| 63 |
|
| 64 |
-
def plot_alignment_with_img(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
"""
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
:param ad_tar_coor: Numpy array of target coordinates to be plotted in the first and third subplots.
|
| 69 |
-
:param ad_src_coor: Numpy array of source coordinates to be plotted in the second subplot.
|
| 70 |
-
:param homo_coor: Numpy array of alignment of source coordinates to be plotted in the third subplot.
|
| 71 |
-
:param tar_img: Image associated with the target coordinates, used as the background in the first subplot.
|
| 72 |
-
:param src_img: Image associated with the source coordinates, used as the background in the second subplot.
|
| 73 |
-
:param aligned_image: Image associated with the aligned coordinates, used as the background in the third subplot.
|
| 74 |
-
:param pca_hex_comb: Color values (e.g., PCA or hex values) for plotting the coordinates.
|
| 75 |
-
:param tar_features: Feature matrix for the target, used to split color values between target and source data.
|
| 76 |
-
:return: Displays the alignment plot of target, source, and alignment of source coordinates with their associated images.
|
| 77 |
"""
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
#
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
plt.scatter(ad_tar_coor[:, 0], ad_tar_coor[:, 1], marker='o', alpha=0.2, s=1, c=pca_hex_comb[:len(tar_features.T)])
|
| 100 |
-
# Scatter plot for the homologous coordinates with a '+' marker and the same color mapping
|
| 101 |
-
plt.scatter(homo_coor[:, 0], homo_coor[:, 1], marker='+', s=1, c=pca_hex_comb[len(tar_features.T):])
|
| 102 |
-
# Overlay the aligned image with some transparency (alpha = 0.3)
|
| 103 |
-
plt.imshow(aligned_image, origin='lower', alpha=0.3)
|
| 104 |
-
|
| 105 |
-
# Turn off the axis for all subplots to give a cleaner visual output
|
| 106 |
-
plt.axis('off')
|
| 107 |
-
|
| 108 |
-
# Display the plots
|
| 109 |
plt.show()
|
| 110 |
|
| 111 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
|
| 113 |
def draw_polygon(image, polygon, color='k', thickness=2):
|
| 114 |
"""
|
|
@@ -228,6 +223,9 @@ def plot_heatmap(
|
|
| 228 |
coor,
|
| 229 |
similairty,
|
| 230 |
image_path=None,
|
|
|
|
|
|
|
|
|
|
| 231 |
patch_size=(256, 256),
|
| 232 |
save_path=None,
|
| 233 |
downsize=32,
|
|
@@ -236,9 +234,6 @@ def plot_heatmap(
|
|
| 236 |
boxes=None,
|
| 237 |
box_color='k',
|
| 238 |
box_thickness=2,
|
| 239 |
-
polygons=None,
|
| 240 |
-
polygons_color='k',
|
| 241 |
-
polygons_thickness=2,
|
| 242 |
image_alpha=0.5
|
| 243 |
):
|
| 244 |
"""
|
|
@@ -316,7 +311,7 @@ def plot_heatmap(
|
|
| 316 |
|
| 317 |
|
| 318 |
|
| 319 |
-
def show_images_side_by_side(image1, image2, title1=
|
| 320 |
"""
|
| 321 |
Displays two images side by side in a single figure.
|
| 322 |
|
|
@@ -328,7 +323,7 @@ def show_images_side_by_side(image1, image2, title1=None, title2=None):
|
|
| 328 |
"""
|
| 329 |
|
| 330 |
# Create a figure with 2 subplots (1 row, 2 columns), and set the figure size
|
| 331 |
-
fig, ax = plt.subplots(1, 2, figsize=(
|
| 332 |
|
| 333 |
# Display the first image on the first subplot
|
| 334 |
ax[0].imshow(image1)
|
|
@@ -364,7 +359,7 @@ def plot_img_with_annotation(fullres_img, roi_polygon, linewidth, xlim, ylim):
|
|
| 364 |
"""
|
| 365 |
|
| 366 |
# Create a new figure with a fixed size for displaying the image and annotations
|
| 367 |
-
plt.figure(figsize=(
|
| 368 |
|
| 369 |
# Display the full-resolution image
|
| 370 |
plt.imshow(fullres_img)
|
|
@@ -403,7 +398,7 @@ def plot_annotation_heatmap(st_ad, roi_polygon, s, linewidth, xlim, ylim):
|
|
| 403 |
"""
|
| 404 |
|
| 405 |
# Create a new figure with a fixed size for displaying the heatmap and annotations
|
| 406 |
-
plt.figure(figsize=(
|
| 407 |
|
| 408 |
# Scatter plot for the spatial transcriptomics data.
|
| 409 |
# The 'spatial' coordinates are plotted with color intensity based on 'bulk_simi' values.
|
|
|
|
| 8 |
from tqdm import tqdm
|
| 9 |
|
| 10 |
|
| 11 |
+
def plot_alignment(
|
| 12 |
+
ad_tar_coor: np.ndarray,
|
| 13 |
+
ad_src_coor: np.ndarray,
|
| 14 |
+
homo_coor: np.ndarray,
|
| 15 |
+
pca_hex_comb: np.ndarray,
|
| 16 |
+
tar_features: np.ndarray,
|
| 17 |
+
shift: float = 300,
|
| 18 |
+
s: float = 0.8,
|
| 19 |
+
boundary_line: bool = True
|
| 20 |
+
) -> None:
|
| 21 |
"""
|
| 22 |
+
Optimized plot: target, source, and aligned coordinates with titles.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
"""
|
| 24 |
+
# Determine common limits
|
| 25 |
+
coords = np.vstack([ad_tar_coor, ad_src_coor, homo_coor])
|
| 26 |
+
x_min, x_max = coords[:,0].min() - shift, coords[:,0].max() + shift
|
| 27 |
+
y_min, y_max = coords[:,1].min() - shift, coords[:,1].max() + shift
|
| 28 |
+
|
| 29 |
+
fig, axes = plt.subplots(1, 3, figsize=(10, 3), dpi=150)
|
| 30 |
+
titles = ["Target ST", "Source ST", "Aligned Source ST"]
|
| 31 |
+
splits = [len(ad_tar_coor), len(ad_tar_coor)+len(ad_src_coor)]
|
| 32 |
+
|
| 33 |
+
for ax, title, data_slice in zip(
|
| 34 |
+
axes,
|
| 35 |
+
titles,
|
| 36 |
+
[(ad_tar_coor, pca_hex_comb[:splits[0]]),
|
| 37 |
+
(ad_src_coor, pca_hex_comb[splits[0]:splits[1]]),
|
| 38 |
+
(homo_coor, pca_hex_comb[splits[0]:splits[1]])]
|
| 39 |
+
):
|
| 40 |
+
coords_arr, colors = data_slice
|
| 41 |
+
ax.scatter(coords_arr[:,0], coords_arr[:,1], s=s, c=colors, marker='o')
|
| 42 |
+
ax.set_xlim(x_min, x_max)
|
| 43 |
+
ax.set_ylim(y_min, y_max)
|
| 44 |
+
ax.set_aspect('equal')
|
| 45 |
+
if boundary_line:
|
| 46 |
+
ax.axvline(x=ad_tar_coor[:,0].min(), color='black', linewidth=1)
|
| 47 |
+
ax.axhline(y=ad_tar_coor[:,1].min(), color='black', linewidth=1)
|
| 48 |
+
ax.set_title(title)
|
| 49 |
+
ax.axis('off')
|
| 50 |
+
plt.tight_layout()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
plt.show()
|
| 52 |
|
| 53 |
|
| 54 |
|
| 55 |
+
def plot_alignment_with_img(
|
| 56 |
+
ad_tar_coor: np.ndarray,
|
| 57 |
+
ad_src_coor: np.ndarray,
|
| 58 |
+
homo_coor: np.ndarray,
|
| 59 |
+
tar_img,
|
| 60 |
+
src_img,
|
| 61 |
+
aligned_image,
|
| 62 |
+
pca_hex_comb: np.ndarray,
|
| 63 |
+
tar_features: np.ndarray,
|
| 64 |
+
s: float = 1.0
|
| 65 |
+
) -> None:
|
| 66 |
"""
|
| 67 |
+
Optimized plot with images in the background and subplot titles.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
"""
|
| 69 |
+
fig, axes = plt.subplots(1, 3, figsize=(15, 5), dpi=150)
|
| 70 |
+
titles = ["Target + Image", "Source + Image", "Aligned + Image"]
|
| 71 |
+
splits = [len(tar_features.T), len(tar_features.T) * 2]
|
| 72 |
+
|
| 73 |
+
# Data slices for each subplot
|
| 74 |
+
data_slices = [
|
| 75 |
+
(ad_tar_coor, pca_hex_comb[:splits[0]], tar_img),
|
| 76 |
+
(ad_src_coor, pca_hex_comb[splits[0]:splits[1]], src_img),
|
| 77 |
+
(np.vstack([ad_tar_coor, homo_coor]),
|
| 78 |
+
np.concatenate([pca_hex_comb[:splits[0]], pca_hex_comb[splits[0]:splits[1]]]),
|
| 79 |
+
aligned_image)
|
| 80 |
+
]
|
| 81 |
+
|
| 82 |
+
for ax, title, (coords_arr, colors, img) in zip(axes, titles, data_slices):
|
| 83 |
+
ax.imshow(img, origin='lower', alpha=0.3)
|
| 84 |
+
ax.scatter(coords_arr[:,0], coords_arr[:,1], s=s, c=colors, marker='o')
|
| 85 |
+
ax.set_aspect('equal')
|
| 86 |
+
ax.set_title(title)
|
| 87 |
+
ax.axis('off')
|
| 88 |
+
|
| 89 |
+
plt.tight_layout()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
plt.show()
|
| 91 |
|
| 92 |
|
| 93 |
+
def show_image(img, title: str = "Aligned Source Image", origin: str = "lower", cmap=None):
|
| 94 |
+
"""
|
| 95 |
+
Display a single image with no axes and a title.
|
| 96 |
+
|
| 97 |
+
:param img: The image to display (NumPy array, PIL Image, etc.).
|
| 98 |
+
:param title: Title to show above the image.
|
| 99 |
+
:param origin: Origin parameter passed to plt.imshow (e.g. 'lower' or 'upper').
|
| 100 |
+
:param cmap: Optional colormap for grayscale or other single‑channel data.
|
| 101 |
+
"""
|
| 102 |
+
plt.imshow(img, origin=origin, cmap=cmap)
|
| 103 |
+
plt.title(title)
|
| 104 |
+
plt.axis('off')
|
| 105 |
+
plt.show()
|
| 106 |
+
|
| 107 |
|
| 108 |
def draw_polygon(image, polygon, color='k', thickness=2):
|
| 109 |
"""
|
|
|
|
| 223 |
coor,
|
| 224 |
similairty,
|
| 225 |
image_path=None,
|
| 226 |
+
polygons=None,
|
| 227 |
+
polygons_color='k',
|
| 228 |
+
polygons_thickness=2,
|
| 229 |
patch_size=(256, 256),
|
| 230 |
save_path=None,
|
| 231 |
downsize=32,
|
|
|
|
| 234 |
boxes=None,
|
| 235 |
box_color='k',
|
| 236 |
box_thickness=2,
|
|
|
|
|
|
|
|
|
|
| 237 |
image_alpha=0.5
|
| 238 |
):
|
| 239 |
"""
|
|
|
|
| 311 |
|
| 312 |
|
| 313 |
|
| 314 |
+
def show_images_side_by_side(image1, image2, title1='Annotated H&E Image', title2='Similatrity Heatmap'):
|
| 315 |
"""
|
| 316 |
Displays two images side by side in a single figure.
|
| 317 |
|
|
|
|
| 323 |
"""
|
| 324 |
|
| 325 |
# Create a figure with 2 subplots (1 row, 2 columns), and set the figure size
|
| 326 |
+
fig, ax = plt.subplots(1, 2, figsize=(8,6), dpi=150)
|
| 327 |
|
| 328 |
# Display the first image on the first subplot
|
| 329 |
ax[0].imshow(image1)
|
|
|
|
| 359 |
"""
|
| 360 |
|
| 361 |
# Create a new figure with a fixed size for displaying the image and annotations
|
| 362 |
+
plt.figure(figsize=(12, 12), dpi=150)
|
| 363 |
|
| 364 |
# Display the full-resolution image
|
| 365 |
plt.imshow(fullres_img)
|
|
|
|
| 398 |
"""
|
| 399 |
|
| 400 |
# Create a new figure with a fixed size for displaying the heatmap and annotations
|
| 401 |
+
plt.figure(figsize=(12, 12), dpi=150)
|
| 402 |
|
| 403 |
# Scatter plot for the spatial transcriptomics data.
|
| 404 |
# The 'spatial' coordinates are plotted with color intensity based on 'bulk_simi' values.
|
src/build/lib/loki/utils.py
CHANGED
|
@@ -11,175 +11,107 @@ from open_clip import create_model_from_pretrained, get_tokenizer
|
|
| 11 |
|
| 12 |
|
| 13 |
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
"""
|
| 16 |
-
|
| 17 |
-
using the specified model checkpoint.
|
| 18 |
-
|
| 19 |
-
:param model_path: File path or URL to the pretrained model checkpoint. This is passed to
|
| 20 |
-
`create_model_from_pretrained` as the `pretrained` argument.
|
| 21 |
-
:type model_path: str
|
| 22 |
-
:param device: The device on which to load the model (e.g., 'cpu' or 'cuda').
|
| 23 |
-
:type device: str or torch.device
|
| 24 |
-
:return: A tuple `(model, preprocess, tokenizer)` where:
|
| 25 |
-
- model: The loaded CoCa model.
|
| 26 |
-
- preprocess: A function or transform that preprocesses input data for the model.
|
| 27 |
-
- tokenizer: A tokenizer appropriate for textual input to the model.
|
| 28 |
-
:rtype: (nn.Module, callable, callable)
|
| 29 |
"""
|
| 30 |
-
# Create the model and its preprocessing transform from the specified checkpoint
|
| 31 |
model, preprocess = create_model_from_pretrained(
|
| 32 |
"coca_ViT-L-14", device=device, pretrained=model_path
|
| 33 |
)
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
tokenizer = get_tokenizer('coca_ViT-L-14')
|
| 37 |
-
|
| 38 |
return model, preprocess, tokenizer
|
| 39 |
|
|
|
|
| 40 |
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
:type model: torch.nn.Module
|
| 48 |
-
:param preprocess: A preprocessing function that transforms the input image into a tensor
|
| 49 |
-
suitable for the model. Typically something returning a PyTorch tensor.
|
| 50 |
-
:type preprocess: callable
|
| 51 |
-
:param image: The input image (PIL Image, NumPy array, or other format supported by `preprocess`).
|
| 52 |
-
:type image: PIL.Image.Image or numpy.ndarray
|
| 53 |
-
:return: A single normalized image embedding as a PyTorch tensor of shape (1, embedding_dim).
|
| 54 |
-
:rtype: torch.Tensor
|
| 55 |
"""
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
# Generate the image features without gradient tracking
|
| 60 |
-
with torch.no_grad():
|
| 61 |
-
image_features = model.encode_image(image_input)
|
| 62 |
-
|
| 63 |
-
# Normalize embeddings across the feature dimension (L2 normalization)
|
| 64 |
-
image_embeddings = F.normalize(image_features, p=2, dim=-1)
|
| 65 |
-
|
| 66 |
-
return image_embeddings
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
def encode_image_patches(model, preprocess, data_dir, img_list):
|
| 71 |
"""
|
| 72 |
-
|
|
|
|
|
|
|
| 73 |
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
suitable for the model. Typically something returning a PyTorch tensor.
|
| 78 |
-
:type preprocess: callable
|
| 79 |
-
:param data_dir: The base directory containing image data.
|
| 80 |
-
:type data_dir: str
|
| 81 |
-
:param img_list: A list of image filenames (strings). Each filename corresponds to a patch image
|
| 82 |
-
stored in `data_dir/demo_data/patch/`.
|
| 83 |
-
:type img_list: list[str]
|
| 84 |
-
:return: A PyTorch tensor of shape (N, 1, embedding_dim), containing the normalized embeddings
|
| 85 |
-
for each image in `img_list`.
|
| 86 |
-
:rtype: torch.Tensor
|
| 87 |
-
"""
|
| 88 |
|
| 89 |
-
# Prepare a list to hold each image's feature embedding
|
| 90 |
-
image_embeddings = []
|
| 91 |
|
| 92 |
-
# Loop through each image name in the provided list
|
| 93 |
-
for img_name in img_list:
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
|
| 98 |
-
|
| 99 |
-
|
| 100 |
|
| 101 |
-
|
| 102 |
-
|
| 103 |
|
| 104 |
-
# Convert the list of embeddings to a NumPy array, then to a PyTorch tensor
|
| 105 |
-
# Resulting shape will be (N, 1, embedding_dim)
|
| 106 |
-
image_embeddings = torch.from_numpy(np.array(image_embeddings))
|
| 107 |
|
| 108 |
-
# Normalize all embeddings across the feature dimension (L2 normalization)
|
| 109 |
-
image_embeddings = F.normalize(image_embeddings, p=2, dim=-1)
|
| 110 |
|
| 111 |
-
return image_embeddings
|
| 112 |
|
| 113 |
|
|
|
|
| 114 |
|
| 115 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
"""
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
:param model: A model object that provides an `encode_text` method (e.g., a CLIP-like or CoCa model).
|
| 120 |
-
:type model: torch.nn.Module
|
| 121 |
-
:param tokenizer: A tokenizer function that converts the input text into a format suitable for `model.encode_text`.
|
| 122 |
-
Typically returns token IDs, attention masks, etc. as a torch.Tensor or similar structure.
|
| 123 |
-
:type tokenizer: callable
|
| 124 |
-
:param text: The input text (string or list of strings) to be encoded.
|
| 125 |
-
:type text: str or list[str]
|
| 126 |
-
:return: A PyTorch tensor of shape (batch_size, embedding_dim) containing the L2-normalized text embeddings.
|
| 127 |
-
:rtype: torch.Tensor
|
| 128 |
"""
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
# Run the model in no-grad mode (not tracking gradients, saving memory and compute)
|
| 134 |
with torch.no_grad():
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
# Normalize embeddings to unit length
|
| 138 |
-
text_embeddings = F.normalize(text_features, p=2, dim=-1)
|
| 139 |
|
| 140 |
-
return text_embeddings
|
| 141 |
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 145 |
"""
|
| 146 |
-
Encodes
|
| 147 |
-
returning a PyTorch tensor of normalized text embeddings.
|
| 148 |
-
|
| 149 |
-
:param model: A model object that provides an `encode_text` method (e.g., a CLIP-like or CoCa model).
|
| 150 |
-
:type model: torch.nn.Module
|
| 151 |
-
:param tokenizer: A tokenizer function that converts the input text into a format suitable for `model.encode_text`.
|
| 152 |
-
:type tokenizer: callable
|
| 153 |
-
:param df: A pandas DataFrame from which text will be extracted.
|
| 154 |
-
:type df: pandas.DataFrame
|
| 155 |
-
:param col_name: The name of the column in `df` that contains the text to be encoded.
|
| 156 |
-
:type col_name: str
|
| 157 |
-
:return: A PyTorch tensor containing the L2-normalized text embeddings,
|
| 158 |
-
where the shape is (number_of_rows, embedding_dim).
|
| 159 |
-
:rtype: torch.Tensor
|
| 160 |
"""
|
|
|
|
|
|
|
| 161 |
|
| 162 |
-
# Prepare a list to hold each row's text embedding
|
| 163 |
-
text_embeddings = []
|
| 164 |
-
|
| 165 |
-
# Loop through each index in the DataFrame
|
| 166 |
-
for idx in df.index:
|
| 167 |
-
# Retrieve text from the specified column for the current row
|
| 168 |
-
text = df[df.index == idx][col_name][0]
|
| 169 |
-
|
| 170 |
-
# Encode the text using the provided model and tokenizer
|
| 171 |
-
text_features = encode_text(model, tokenizer, text)
|
| 172 |
-
|
| 173 |
-
# Accumulate the embedding tensor
|
| 174 |
-
text_embeddings.append(text_features)
|
| 175 |
-
|
| 176 |
-
# Convert the list of embeddings (likely shape [N, embedding_dim]) into a NumPy array, then to a torch tensor
|
| 177 |
-
text_embeddings = torch.from_numpy(np.array(text_embeddings))
|
| 178 |
-
|
| 179 |
-
# Normalize embeddings to unit length across the feature dimension
|
| 180 |
-
text_embeddings = F.normalize(text_embeddings, p=2, dim=-1)
|
| 181 |
-
|
| 182 |
-
return text_embeddings
|
| 183 |
|
| 184 |
|
| 185 |
|
|
|
|
| 11 |
|
| 12 |
|
| 13 |
|
| 14 |
+
import os
|
| 15 |
+
from typing import List, Tuple, Union
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn.functional as F
|
| 19 |
+
import numpy as np
|
| 20 |
+
from PIL import Image
|
| 21 |
+
import pandas as pd
|
| 22 |
+
|
| 23 |
+
# --- Model loading --------------------------------------------------------
|
| 24 |
+
|
| 25 |
+
def load_model(
|
| 26 |
+
model_path: str,
|
| 27 |
+
device: Union[str, torch.device]
|
| 28 |
+
) -> Tuple[torch.nn.Module, callable, callable]:
|
| 29 |
"""
|
| 30 |
+
Load pretrained OmiCLIP (COCA ViT‑L‑14) model, its image preprocess, and tokenizer.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
"""
|
|
|
|
| 32 |
model, preprocess = create_model_from_pretrained(
|
| 33 |
"coca_ViT-L-14", device=device, pretrained=model_path
|
| 34 |
)
|
| 35 |
+
tokenizer = get_tokenizer("coca_ViT-L-14")
|
| 36 |
+
model.to(device).eval()
|
|
|
|
|
|
|
| 37 |
return model, preprocess, tokenizer
|
| 38 |
|
| 39 |
+
# --- Image encoding -------------------------------------------------------
|
| 40 |
|
| 41 |
+
def encode_images(
|
| 42 |
+
model: torch.nn.Module,
|
| 43 |
+
preprocess: callable,
|
| 44 |
+
image_paths: List[str],
|
| 45 |
+
device: Union[str, torch.device]
|
| 46 |
+
) -> torch.Tensor:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
"""
|
| 48 |
+
Batch–encode a list of image file paths into L2‑normalized embeddings.
|
| 49 |
+
Returns a tensor of shape (N, D).
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
"""
|
| 51 |
+
# Load & preprocess all images
|
| 52 |
+
imgs = [preprocess(Image.open(p)) for p in image_paths]
|
| 53 |
+
batch = torch.stack(imgs, dim=0).to(device) # (N, C, H, W)
|
| 54 |
|
| 55 |
+
with torch.no_grad():
|
| 56 |
+
feats = model.encode_image(batch) # (N, D)
|
| 57 |
+
return F.normalize(feats, p=2, dim=-1) # (N, D)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
|
|
|
|
|
|
|
| 59 |
|
| 60 |
+
# # Loop through each image name in the provided list
|
| 61 |
+
# for img_name in img_list:
|
| 62 |
+
# # Build the path to the patch image and open it
|
| 63 |
+
# image_path = os.path.join(data_dir, 'demo_data', 'patch', img_name)
|
| 64 |
+
# image = Image.open(image_path)
|
| 65 |
|
| 66 |
+
# # Encode the image using the model & preprocess; returns shape (1, embedding_dim)
|
| 67 |
+
# image_features = encode_image(model, preprocess, image)
|
| 68 |
|
| 69 |
+
# # Accumulate the feature embeddings in the list
|
| 70 |
+
# image_embeddings.append(image_features)
|
| 71 |
|
| 72 |
+
# # Convert the list of embeddings to a NumPy array, then to a PyTorch tensor
|
| 73 |
+
# # Resulting shape will be (N, 1, embedding_dim)
|
| 74 |
+
# image_embeddings = torch.from_numpy(np.array(image_embeddings))
|
| 75 |
|
| 76 |
+
# # Normalize all embeddings across the feature dimension (L2 normalization)
|
| 77 |
+
# image_embeddings = F.normalize(image_embeddings, p=2, dim=-1)
|
| 78 |
|
| 79 |
+
# return image_embeddings
|
| 80 |
|
| 81 |
|
| 82 |
+
# --- Text encoding --------------------------------------------------------
|
| 83 |
|
| 84 |
+
def encode_texts(
|
| 85 |
+
model: torch.nn.Module,
|
| 86 |
+
tokenizer: callable,
|
| 87 |
+
texts: List[str],
|
| 88 |
+
device: Union[str, torch.device]
|
| 89 |
+
) -> torch.Tensor:
|
| 90 |
"""
|
| 91 |
+
Batch–encode a list of strings into L2‑normalized embeddings.
|
| 92 |
+
Returns a tensor of shape (N, D).
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
"""
|
| 94 |
+
# Tokenizer returns a dict of tensors
|
| 95 |
+
text_inputs = tokenizer(texts)
|
| 96 |
+
|
|
|
|
|
|
|
| 97 |
with torch.no_grad():
|
| 98 |
+
feats = model.encode_text(text_inputs) # (N, D)
|
| 99 |
+
return F.normalize(feats, p=2, dim=-1) # (N, D)
|
|
|
|
|
|
|
| 100 |
|
|
|
|
| 101 |
|
| 102 |
+
def encode_text_df(
|
| 103 |
+
model: torch.nn.Module,
|
| 104 |
+
tokenizer: callable,
|
| 105 |
+
df: pd.DataFrame,
|
| 106 |
+
col_name: str,
|
| 107 |
+
device: Union[str, torch.device]
|
| 108 |
+
) -> torch.Tensor:
|
| 109 |
"""
|
| 110 |
+
Encodes an entire DataFrame column into (N, D) embeddings.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
"""
|
| 112 |
+
texts = df[col_name].astype(str).tolist()
|
| 113 |
+
return encode_texts(model, tokenizer, texts, device)
|
| 114 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
|
| 116 |
|
| 117 |
|
src/loki.egg-info/PKG-INFO
CHANGED
|
@@ -1,13 +1,14 @@
|
|
| 1 |
Metadata-Version: 2.1
|
| 2 |
Name: loki
|
| 3 |
Version: 0.0.1
|
| 4 |
-
Summary: The Loki platform offers 5 core functions: tissue alignment, cell type decomposition,
|
| 5 |
Author: Weiqing Chen
|
| 6 |
Author-email: wec4005@med.cornell.edu
|
| 7 |
Classifier: Programming Language :: Python :: 3
|
| 8 |
-
Classifier: License ::
|
| 9 |
Classifier: Operating System :: OS Independent
|
| 10 |
Requires-Python: >=3.9
|
|
|
|
| 11 |
Requires-Dist: anndata==0.10.9
|
| 12 |
Requires-Dist: matplotlib==3.9.2
|
| 13 |
Requires-Dist: numpy==1.25.0
|
|
|
|
| 1 |
Metadata-Version: 2.1
|
| 2 |
Name: loki
|
| 3 |
Version: 0.0.1
|
| 4 |
+
Summary: The Loki platform offers 5 core functions: tissue alignment, tissue annotation, cell type decomposition, image-transcriptomics retrieval, and ST gene expression prediction
|
| 5 |
Author: Weiqing Chen
|
| 6 |
Author-email: wec4005@med.cornell.edu
|
| 7 |
Classifier: Programming Language :: Python :: 3
|
| 8 |
+
Classifier: License :: BSD 3-Clause License
|
| 9 |
Classifier: Operating System :: OS Independent
|
| 10 |
Requires-Python: >=3.9
|
| 11 |
+
License-File: LICENSE
|
| 12 |
Requires-Dist: anndata==0.10.9
|
| 13 |
Requires-Dist: matplotlib==3.9.2
|
| 14 |
Requires-Dist: numpy==1.25.0
|
src/loki.egg-info/SOURCES.txt
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
|
| 2 |
setup.py
|
| 3 |
loki/__init__.py
|
| 4 |
loki/align.py
|
|
|
|
| 1 |
+
LICENSE
|
| 2 |
setup.py
|
| 3 |
loki/__init__.py
|
| 4 |
loki/align.py
|
src/loki/__pycache__/plot.cpython-39.pyc
CHANGED
|
Binary files a/src/loki/__pycache__/plot.cpython-39.pyc and b/src/loki/__pycache__/plot.cpython-39.pyc differ
|
|
|
src/loki/__pycache__/utils.cpython-39.pyc
CHANGED
|
Binary files a/src/loki/__pycache__/utils.cpython-39.pyc and b/src/loki/__pycache__/utils.cpython-39.pyc differ
|
|
|
src/loki/plot.py
CHANGED
|
@@ -8,107 +8,102 @@ import numpy as np
|
|
| 8 |
from tqdm import tqdm
|
| 9 |
|
| 10 |
|
| 11 |
-
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
"""
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
:param ad_tar_coor: Numpy array of target coordinates to be plotted in the first subplot.
|
| 17 |
-
:param ad_src_coor: Numpy array of source coordinates to be plotted in the second subplot.
|
| 18 |
-
:param homo_coor: Numpy array of alignment of source coordinates to be plotted in the third subplot.
|
| 19 |
-
:param pca_hex_comb: Color values (e.g., PCA or hex values) for plotting the coordinates.
|
| 20 |
-
:param tar_features: Feature matrix for the target, used to split color values between target and source data.
|
| 21 |
-
:param shift: Value used to adjust the plot limits around the coordinates for better visualization. Default is 300.
|
| 22 |
-
:param s: Marker size for the scatter plot points. Default is 0.8.
|
| 23 |
-
:param boundary_line: Boolean indicating whether to draw boundary lines (horizontal and vertical lines). Default is True.
|
| 24 |
-
:return: Displays the alignment plot of target, source, and alignment of source coordinates.
|
| 25 |
"""
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
plt.
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
plt.axvline(x=ad_tar_coor[:, 0].min(), color='black') # Vertical boundary line at the minimum x of target coordinates
|
| 54 |
-
plt.axhline(y=ad_tar_coor[:, 1].min(), color='black') # Horizontal boundary line at the minimum y of target coordinates
|
| 55 |
-
|
| 56 |
-
# Remove the axis labels and ticks from all subplots for a cleaner appearance
|
| 57 |
-
plt.axis('off')
|
| 58 |
-
|
| 59 |
-
# Display the plot
|
| 60 |
plt.show()
|
| 61 |
|
| 62 |
|
| 63 |
|
| 64 |
-
def plot_alignment_with_img(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
"""
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
:param ad_tar_coor: Numpy array of target coordinates to be plotted in the first and third subplots.
|
| 69 |
-
:param ad_src_coor: Numpy array of source coordinates to be plotted in the second subplot.
|
| 70 |
-
:param homo_coor: Numpy array of alignment of source coordinates to be plotted in the third subplot.
|
| 71 |
-
:param tar_img: Image associated with the target coordinates, used as the background in the first subplot.
|
| 72 |
-
:param src_img: Image associated with the source coordinates, used as the background in the second subplot.
|
| 73 |
-
:param aligned_image: Image associated with the aligned coordinates, used as the background in the third subplot.
|
| 74 |
-
:param pca_hex_comb: Color values (e.g., PCA or hex values) for plotting the coordinates.
|
| 75 |
-
:param tar_features: Feature matrix for the target, used to split color values between target and source data.
|
| 76 |
-
:return: Displays the alignment plot of target, source, and alignment of source coordinates with their associated images.
|
| 77 |
"""
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
#
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
plt.scatter(ad_tar_coor[:, 0], ad_tar_coor[:, 1], marker='o', alpha=0.2, s=1, c=pca_hex_comb[:len(tar_features.T)])
|
| 100 |
-
# Scatter plot for the homologous coordinates with a '+' marker and the same color mapping
|
| 101 |
-
plt.scatter(homo_coor[:, 0], homo_coor[:, 1], marker='+', s=1, c=pca_hex_comb[len(tar_features.T):])
|
| 102 |
-
# Overlay the aligned image with some transparency (alpha = 0.3)
|
| 103 |
-
plt.imshow(aligned_image, origin='lower', alpha=0.3)
|
| 104 |
-
|
| 105 |
-
# Turn off the axis for all subplots to give a cleaner visual output
|
| 106 |
-
plt.axis('off')
|
| 107 |
-
|
| 108 |
-
# Display the plots
|
| 109 |
plt.show()
|
| 110 |
|
| 111 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
|
| 113 |
def draw_polygon(image, polygon, color='k', thickness=2):
|
| 114 |
"""
|
|
@@ -228,6 +223,9 @@ def plot_heatmap(
|
|
| 228 |
coor,
|
| 229 |
similairty,
|
| 230 |
image_path=None,
|
|
|
|
|
|
|
|
|
|
| 231 |
patch_size=(256, 256),
|
| 232 |
save_path=None,
|
| 233 |
downsize=32,
|
|
@@ -236,9 +234,6 @@ def plot_heatmap(
|
|
| 236 |
boxes=None,
|
| 237 |
box_color='k',
|
| 238 |
box_thickness=2,
|
| 239 |
-
polygons=None,
|
| 240 |
-
polygons_color='k',
|
| 241 |
-
polygons_thickness=2,
|
| 242 |
image_alpha=0.5
|
| 243 |
):
|
| 244 |
"""
|
|
@@ -316,7 +311,7 @@ def plot_heatmap(
|
|
| 316 |
|
| 317 |
|
| 318 |
|
| 319 |
-
def show_images_side_by_side(image1, image2, title1=
|
| 320 |
"""
|
| 321 |
Displays two images side by side in a single figure.
|
| 322 |
|
|
@@ -328,7 +323,7 @@ def show_images_side_by_side(image1, image2, title1=None, title2=None):
|
|
| 328 |
"""
|
| 329 |
|
| 330 |
# Create a figure with 2 subplots (1 row, 2 columns), and set the figure size
|
| 331 |
-
fig, ax = plt.subplots(1, 2, figsize=(
|
| 332 |
|
| 333 |
# Display the first image on the first subplot
|
| 334 |
ax[0].imshow(image1)
|
|
@@ -364,7 +359,7 @@ def plot_img_with_annotation(fullres_img, roi_polygon, linewidth, xlim, ylim):
|
|
| 364 |
"""
|
| 365 |
|
| 366 |
# Create a new figure with a fixed size for displaying the image and annotations
|
| 367 |
-
plt.figure(figsize=(
|
| 368 |
|
| 369 |
# Display the full-resolution image
|
| 370 |
plt.imshow(fullres_img)
|
|
@@ -403,7 +398,7 @@ def plot_annotation_heatmap(st_ad, roi_polygon, s, linewidth, xlim, ylim):
|
|
| 403 |
"""
|
| 404 |
|
| 405 |
# Create a new figure with a fixed size for displaying the heatmap and annotations
|
| 406 |
-
plt.figure(figsize=(
|
| 407 |
|
| 408 |
# Scatter plot for the spatial transcriptomics data.
|
| 409 |
# The 'spatial' coordinates are plotted with color intensity based on 'bulk_simi' values.
|
|
|
|
| 8 |
from tqdm import tqdm
|
| 9 |
|
| 10 |
|
| 11 |
+
def plot_alignment(
|
| 12 |
+
ad_tar_coor: np.ndarray,
|
| 13 |
+
ad_src_coor: np.ndarray,
|
| 14 |
+
homo_coor: np.ndarray,
|
| 15 |
+
pca_hex_comb: np.ndarray,
|
| 16 |
+
tar_features: np.ndarray,
|
| 17 |
+
shift: float = 300,
|
| 18 |
+
s: float = 0.8,
|
| 19 |
+
boundary_line: bool = True
|
| 20 |
+
) -> None:
|
| 21 |
"""
|
| 22 |
+
Optimized plot: target, source, and aligned coordinates with titles.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
"""
|
| 24 |
+
# Determine common limits
|
| 25 |
+
coords = np.vstack([ad_tar_coor, ad_src_coor, homo_coor])
|
| 26 |
+
x_min, x_max = coords[:,0].min() - shift, coords[:,0].max() + shift
|
| 27 |
+
y_min, y_max = coords[:,1].min() - shift, coords[:,1].max() + shift
|
| 28 |
+
|
| 29 |
+
fig, axes = plt.subplots(1, 3, figsize=(10, 3), dpi=150)
|
| 30 |
+
titles = ["Target ST", "Source ST", "Aligned Source ST"]
|
| 31 |
+
splits = [len(ad_tar_coor), len(ad_tar_coor)+len(ad_src_coor)]
|
| 32 |
+
|
| 33 |
+
for ax, title, data_slice in zip(
|
| 34 |
+
axes,
|
| 35 |
+
titles,
|
| 36 |
+
[(ad_tar_coor, pca_hex_comb[:splits[0]]),
|
| 37 |
+
(ad_src_coor, pca_hex_comb[splits[0]:splits[1]]),
|
| 38 |
+
(homo_coor, pca_hex_comb[splits[0]:splits[1]])]
|
| 39 |
+
):
|
| 40 |
+
coords_arr, colors = data_slice
|
| 41 |
+
ax.scatter(coords_arr[:,0], coords_arr[:,1], s=s, c=colors, marker='o')
|
| 42 |
+
ax.set_xlim(x_min, x_max)
|
| 43 |
+
ax.set_ylim(y_min, y_max)
|
| 44 |
+
ax.set_aspect('equal')
|
| 45 |
+
if boundary_line:
|
| 46 |
+
ax.axvline(x=ad_tar_coor[:,0].min(), color='black', linewidth=1)
|
| 47 |
+
ax.axhline(y=ad_tar_coor[:,1].min(), color='black', linewidth=1)
|
| 48 |
+
ax.set_title(title)
|
| 49 |
+
ax.axis('off')
|
| 50 |
+
plt.tight_layout()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
plt.show()
|
| 52 |
|
| 53 |
|
| 54 |
|
| 55 |
+
def plot_alignment_with_img(
|
| 56 |
+
ad_tar_coor: np.ndarray,
|
| 57 |
+
ad_src_coor: np.ndarray,
|
| 58 |
+
homo_coor: np.ndarray,
|
| 59 |
+
tar_img,
|
| 60 |
+
src_img,
|
| 61 |
+
aligned_image,
|
| 62 |
+
pca_hex_comb: np.ndarray,
|
| 63 |
+
tar_features: np.ndarray,
|
| 64 |
+
s: float = 1.0
|
| 65 |
+
) -> None:
|
| 66 |
"""
|
| 67 |
+
Optimized plot with images in the background and subplot titles.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
"""
|
| 69 |
+
fig, axes = plt.subplots(1, 3, figsize=(15, 5), dpi=150)
|
| 70 |
+
titles = ["Target + Image", "Source + Image", "Aligned + Image"]
|
| 71 |
+
splits = [len(tar_features.T), len(tar_features.T) * 2]
|
| 72 |
+
|
| 73 |
+
# Data slices for each subplot
|
| 74 |
+
data_slices = [
|
| 75 |
+
(ad_tar_coor, pca_hex_comb[:splits[0]], tar_img),
|
| 76 |
+
(ad_src_coor, pca_hex_comb[splits[0]:splits[1]], src_img),
|
| 77 |
+
(np.vstack([ad_tar_coor, homo_coor]),
|
| 78 |
+
np.concatenate([pca_hex_comb[:splits[0]], pca_hex_comb[splits[0]:splits[1]]]),
|
| 79 |
+
aligned_image)
|
| 80 |
+
]
|
| 81 |
+
|
| 82 |
+
for ax, title, (coords_arr, colors, img) in zip(axes, titles, data_slices):
|
| 83 |
+
ax.imshow(img, origin='lower', alpha=0.3)
|
| 84 |
+
ax.scatter(coords_arr[:,0], coords_arr[:,1], s=s, c=colors, marker='o')
|
| 85 |
+
ax.set_aspect('equal')
|
| 86 |
+
ax.set_title(title)
|
| 87 |
+
ax.axis('off')
|
| 88 |
+
|
| 89 |
+
plt.tight_layout()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
plt.show()
|
| 91 |
|
| 92 |
|
| 93 |
+
def show_image(img, title: str = "Aligned Source Image", origin: str = "lower", cmap=None):
|
| 94 |
+
"""
|
| 95 |
+
Display a single image with no axes and a title.
|
| 96 |
+
|
| 97 |
+
:param img: The image to display (NumPy array, PIL Image, etc.).
|
| 98 |
+
:param title: Title to show above the image.
|
| 99 |
+
:param origin: Origin parameter passed to plt.imshow (e.g. 'lower' or 'upper').
|
| 100 |
+
:param cmap: Optional colormap for grayscale or other single‑channel data.
|
| 101 |
+
"""
|
| 102 |
+
plt.imshow(img, origin=origin, cmap=cmap)
|
| 103 |
+
plt.title(title)
|
| 104 |
+
plt.axis('off')
|
| 105 |
+
plt.show()
|
| 106 |
+
|
| 107 |
|
| 108 |
def draw_polygon(image, polygon, color='k', thickness=2):
|
| 109 |
"""
|
|
|
|
| 223 |
coor,
|
| 224 |
similairty,
|
| 225 |
image_path=None,
|
| 226 |
+
polygons=None,
|
| 227 |
+
polygons_color='k',
|
| 228 |
+
polygons_thickness=2,
|
| 229 |
patch_size=(256, 256),
|
| 230 |
save_path=None,
|
| 231 |
downsize=32,
|
|
|
|
| 234 |
boxes=None,
|
| 235 |
box_color='k',
|
| 236 |
box_thickness=2,
|
|
|
|
|
|
|
|
|
|
| 237 |
image_alpha=0.5
|
| 238 |
):
|
| 239 |
"""
|
|
|
|
| 311 |
|
| 312 |
|
| 313 |
|
| 314 |
+
def show_images_side_by_side(image1, image2, title1='Annotated H&E Image', title2='Similatrity Heatmap'):
|
| 315 |
"""
|
| 316 |
Displays two images side by side in a single figure.
|
| 317 |
|
|
|
|
| 323 |
"""
|
| 324 |
|
| 325 |
# Create a figure with 2 subplots (1 row, 2 columns), and set the figure size
|
| 326 |
+
fig, ax = plt.subplots(1, 2, figsize=(8,6), dpi=150)
|
| 327 |
|
| 328 |
# Display the first image on the first subplot
|
| 329 |
ax[0].imshow(image1)
|
|
|
|
| 359 |
"""
|
| 360 |
|
| 361 |
# Create a new figure with a fixed size for displaying the image and annotations
|
| 362 |
+
plt.figure(figsize=(12, 12), dpi=150)
|
| 363 |
|
| 364 |
# Display the full-resolution image
|
| 365 |
plt.imshow(fullres_img)
|
|
|
|
| 398 |
"""
|
| 399 |
|
| 400 |
# Create a new figure with a fixed size for displaying the heatmap and annotations
|
| 401 |
+
plt.figure(figsize=(12, 12), dpi=150)
|
| 402 |
|
| 403 |
# Scatter plot for the spatial transcriptomics data.
|
| 404 |
# The 'spatial' coordinates are plotted with color intensity based on 'bulk_simi' values.
|
src/loki/utils.py
CHANGED
|
@@ -11,175 +11,107 @@ from open_clip import create_model_from_pretrained, get_tokenizer
|
|
| 11 |
|
| 12 |
|
| 13 |
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
"""
|
| 16 |
-
|
| 17 |
-
using the specified model checkpoint.
|
| 18 |
-
|
| 19 |
-
:param model_path: File path to the pretrained model checkpoint. This is passed to
|
| 20 |
-
`create_model_from_pretrained` as the `pretrained` argument.
|
| 21 |
-
:type model_path: str
|
| 22 |
-
:param device: The device on which to load the model (e.g., 'cpu' or 'cuda').
|
| 23 |
-
:type device: str or torch.device
|
| 24 |
-
:return: A tuple `(model, preprocess, tokenizer)` where:
|
| 25 |
-
- model: The loaded OmiCLIP model.
|
| 26 |
-
- preprocess: A function or transform that preprocesses input data for the model.
|
| 27 |
-
- tokenizer: A tokenizer appropriate for textual input to the model.
|
| 28 |
-
:rtype: (nn.Module, callable, callable)
|
| 29 |
"""
|
| 30 |
-
# Create the model and its preprocessing transform from the specified checkpoint
|
| 31 |
model, preprocess = create_model_from_pretrained(
|
| 32 |
"coca_ViT-L-14", device=device, pretrained=model_path
|
| 33 |
)
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
tokenizer = get_tokenizer('coca_ViT-L-14')
|
| 37 |
-
|
| 38 |
return model, preprocess, tokenizer
|
| 39 |
|
|
|
|
| 40 |
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
:type model: torch.nn.Module
|
| 48 |
-
:param preprocess: A preprocessing function that transforms the input image into a tensor
|
| 49 |
-
suitable for the model. Typically something returning a PyTorch tensor.
|
| 50 |
-
:type preprocess: callable
|
| 51 |
-
:param image: The input image (PIL Image, NumPy array, or other format supported by `preprocess`).
|
| 52 |
-
:type image: PIL.Image.Image or numpy.ndarray
|
| 53 |
-
:return: A single normalized image embedding as a PyTorch tensor of shape (1, embedding_dim).
|
| 54 |
-
:rtype: torch.Tensor
|
| 55 |
"""
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
# Generate the image features without gradient tracking
|
| 60 |
-
with torch.no_grad():
|
| 61 |
-
image_features = model.encode_image(image_input)
|
| 62 |
-
|
| 63 |
-
# Normalize embeddings across the feature dimension (L2 normalization)
|
| 64 |
-
image_embeddings = F.normalize(image_features, p=2, dim=-1)
|
| 65 |
-
|
| 66 |
-
return image_embeddings
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
def encode_image_patches(model, preprocess, data_dir, img_list):
|
| 71 |
"""
|
| 72 |
-
|
|
|
|
|
|
|
| 73 |
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
suitable for the model. Typically something returning a PyTorch tensor.
|
| 78 |
-
:type preprocess: callable
|
| 79 |
-
:param data_dir: The base directory containing image data.
|
| 80 |
-
:type data_dir: str
|
| 81 |
-
:param img_list: A list of image filenames (strings). Each filename corresponds to a patch image
|
| 82 |
-
stored in `data_dir/demo_data/patch/`.
|
| 83 |
-
:type img_list: list[str]
|
| 84 |
-
:return: A PyTorch tensor of shape (N, 1, embedding_dim), containing the normalized embeddings
|
| 85 |
-
for each image in `img_list`.
|
| 86 |
-
:rtype: torch.Tensor
|
| 87 |
-
"""
|
| 88 |
|
| 89 |
-
# Prepare a list to hold each image's feature embedding
|
| 90 |
-
image_embeddings = []
|
| 91 |
|
| 92 |
-
# Loop through each image name in the provided list
|
| 93 |
-
for img_name in img_list:
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
|
| 98 |
-
|
| 99 |
-
|
| 100 |
|
| 101 |
-
|
| 102 |
-
|
| 103 |
|
| 104 |
-
# Convert the list of embeddings to a NumPy array, then to a PyTorch tensor
|
| 105 |
-
# Resulting shape will be (N, 1, embedding_dim)
|
| 106 |
-
image_embeddings = torch.from_numpy(np.array(image_embeddings))
|
| 107 |
|
| 108 |
-
# Normalize all embeddings across the feature dimension (L2 normalization)
|
| 109 |
-
image_embeddings = F.normalize(image_embeddings, p=2, dim=-1)
|
| 110 |
|
| 111 |
-
return image_embeddings
|
| 112 |
|
| 113 |
|
|
|
|
| 114 |
|
| 115 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
"""
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
:param model: A model object that provides an `encode_text` method.
|
| 120 |
-
:type model: torch.nn.Module
|
| 121 |
-
:param tokenizer: A tokenizer function that converts the input text into a format suitable for `model.encode_text`.
|
| 122 |
-
Typically returns token IDs, attention masks, etc. as a torch.Tensor or similar structure.
|
| 123 |
-
:type tokenizer: callable
|
| 124 |
-
:param text: The input text (string or list of strings) to be encoded.
|
| 125 |
-
:type text: str or list[str]
|
| 126 |
-
:return: A PyTorch tensor of shape (batch_size, embedding_dim) containing the L2-normalized text embeddings.
|
| 127 |
-
:rtype: torch.Tensor
|
| 128 |
"""
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
# Run the model in no-grad mode (not tracking gradients, saving memory and compute)
|
| 134 |
with torch.no_grad():
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
# Normalize embeddings to unit length
|
| 138 |
-
text_embeddings = F.normalize(text_features, p=2, dim=-1)
|
| 139 |
|
| 140 |
-
return text_embeddings
|
| 141 |
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 145 |
"""
|
| 146 |
-
Encodes
|
| 147 |
-
returning a PyTorch tensor of normalized text embeddings.
|
| 148 |
-
|
| 149 |
-
:param model: A model object that provides an `encode_text` method.
|
| 150 |
-
:type model: torch.nn.Module
|
| 151 |
-
:param tokenizer: A tokenizer function that converts the input text into a format suitable for `model.encode_text`.
|
| 152 |
-
:type tokenizer: callable
|
| 153 |
-
:param df: A pandas DataFrame from which text will be extracted.
|
| 154 |
-
:type df: pandas.DataFrame
|
| 155 |
-
:param col_name: The name of the column in `df` that contains the text to be encoded.
|
| 156 |
-
:type col_name: str
|
| 157 |
-
:return: A PyTorch tensor containing the L2-normalized text embeddings,
|
| 158 |
-
where the shape is (number_of_rows, embedding_dim).
|
| 159 |
-
:rtype: torch.Tensor
|
| 160 |
"""
|
|
|
|
|
|
|
| 161 |
|
| 162 |
-
# Prepare a list to hold each row's text embedding
|
| 163 |
-
text_embeddings = []
|
| 164 |
-
|
| 165 |
-
# Loop through each index in the DataFrame
|
| 166 |
-
for idx in df.index:
|
| 167 |
-
# Retrieve text from the specified column for the current row
|
| 168 |
-
text = df[df.index == idx][col_name][0]
|
| 169 |
-
|
| 170 |
-
# Encode the text using the provided model and tokenizer
|
| 171 |
-
text_features = encode_text(model, tokenizer, text)
|
| 172 |
-
|
| 173 |
-
# Accumulate the embedding tensor
|
| 174 |
-
text_embeddings.append(text_features)
|
| 175 |
-
|
| 176 |
-
# Convert the list of embeddings (likely shape [N, embedding_dim]) into a NumPy array, then to a torch tensor
|
| 177 |
-
text_embeddings = torch.from_numpy(np.array(text_embeddings))
|
| 178 |
-
|
| 179 |
-
# Normalize embeddings to unit length across the feature dimension
|
| 180 |
-
text_embeddings = F.normalize(text_embeddings, p=2, dim=-1)
|
| 181 |
-
|
| 182 |
-
return text_embeddings
|
| 183 |
|
| 184 |
|
| 185 |
|
|
|
|
| 11 |
|
| 12 |
|
| 13 |
|
| 14 |
+
import os
|
| 15 |
+
from typing import List, Tuple, Union
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn.functional as F
|
| 19 |
+
import numpy as np
|
| 20 |
+
from PIL import Image
|
| 21 |
+
import pandas as pd
|
| 22 |
+
|
| 23 |
+
# --- Model loading --------------------------------------------------------
|
| 24 |
+
|
| 25 |
+
def load_model(
|
| 26 |
+
model_path: str,
|
| 27 |
+
device: Union[str, torch.device]
|
| 28 |
+
) -> Tuple[torch.nn.Module, callable, callable]:
|
| 29 |
"""
|
| 30 |
+
Load pretrained OmiCLIP (COCA ViT‑L‑14) model, its image preprocess, and tokenizer.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
"""
|
|
|
|
| 32 |
model, preprocess = create_model_from_pretrained(
|
| 33 |
"coca_ViT-L-14", device=device, pretrained=model_path
|
| 34 |
)
|
| 35 |
+
tokenizer = get_tokenizer("coca_ViT-L-14")
|
| 36 |
+
model.to(device).eval()
|
|
|
|
|
|
|
| 37 |
return model, preprocess, tokenizer
|
| 38 |
|
| 39 |
+
# --- Image encoding -------------------------------------------------------
|
| 40 |
|
| 41 |
+
def encode_images(
|
| 42 |
+
model: torch.nn.Module,
|
| 43 |
+
preprocess: callable,
|
| 44 |
+
image_paths: List[str],
|
| 45 |
+
device: Union[str, torch.device]
|
| 46 |
+
) -> torch.Tensor:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
"""
|
| 48 |
+
Batch–encode a list of image file paths into L2‑normalized embeddings.
|
| 49 |
+
Returns a tensor of shape (N, D).
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
"""
|
| 51 |
+
# Load & preprocess all images
|
| 52 |
+
imgs = [preprocess(Image.open(p)) for p in image_paths]
|
| 53 |
+
batch = torch.stack(imgs, dim=0).to(device) # (N, C, H, W)
|
| 54 |
|
| 55 |
+
with torch.no_grad():
|
| 56 |
+
feats = model.encode_image(batch) # (N, D)
|
| 57 |
+
return F.normalize(feats, p=2, dim=-1) # (N, D)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
|
|
|
|
|
|
|
| 59 |
|
| 60 |
+
# # Loop through each image name in the provided list
|
| 61 |
+
# for img_name in img_list:
|
| 62 |
+
# # Build the path to the patch image and open it
|
| 63 |
+
# image_path = os.path.join(data_dir, 'demo_data', 'patch', img_name)
|
| 64 |
+
# image = Image.open(image_path)
|
| 65 |
|
| 66 |
+
# # Encode the image using the model & preprocess; returns shape (1, embedding_dim)
|
| 67 |
+
# image_features = encode_image(model, preprocess, image)
|
| 68 |
|
| 69 |
+
# # Accumulate the feature embeddings in the list
|
| 70 |
+
# image_embeddings.append(image_features)
|
| 71 |
|
| 72 |
+
# # Convert the list of embeddings to a NumPy array, then to a PyTorch tensor
|
| 73 |
+
# # Resulting shape will be (N, 1, embedding_dim)
|
| 74 |
+
# image_embeddings = torch.from_numpy(np.array(image_embeddings))
|
| 75 |
|
| 76 |
+
# # Normalize all embeddings across the feature dimension (L2 normalization)
|
| 77 |
+
# image_embeddings = F.normalize(image_embeddings, p=2, dim=-1)
|
| 78 |
|
| 79 |
+
# return image_embeddings
|
| 80 |
|
| 81 |
|
| 82 |
+
# --- Text encoding --------------------------------------------------------
|
| 83 |
|
| 84 |
+
def encode_texts(
|
| 85 |
+
model: torch.nn.Module,
|
| 86 |
+
tokenizer: callable,
|
| 87 |
+
texts: List[str],
|
| 88 |
+
device: Union[str, torch.device]
|
| 89 |
+
) -> torch.Tensor:
|
| 90 |
"""
|
| 91 |
+
Batch–encode a list of strings into L2‑normalized embeddings.
|
| 92 |
+
Returns a tensor of shape (N, D).
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
"""
|
| 94 |
+
# Tokenizer returns a dict of tensors
|
| 95 |
+
text_inputs = tokenizer(texts)
|
| 96 |
+
|
|
|
|
|
|
|
| 97 |
with torch.no_grad():
|
| 98 |
+
feats = model.encode_text(text_inputs) # (N, D)
|
| 99 |
+
return F.normalize(feats, p=2, dim=-1) # (N, D)
|
|
|
|
|
|
|
| 100 |
|
|
|
|
| 101 |
|
| 102 |
+
def encode_text_df(
|
| 103 |
+
model: torch.nn.Module,
|
| 104 |
+
tokenizer: callable,
|
| 105 |
+
df: pd.DataFrame,
|
| 106 |
+
col_name: str,
|
| 107 |
+
device: Union[str, torch.device]
|
| 108 |
+
) -> torch.Tensor:
|
| 109 |
"""
|
| 110 |
+
Encodes an entire DataFrame column into (N, D) embeddings.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
"""
|
| 112 |
+
texts = df[col_name].astype(str).tolist()
|
| 113 |
+
return encode_texts(model, tokenizer, texts, device)
|
| 114 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
|
| 116 |
|
| 117 |
|
src/requirements.txt
CHANGED
|
@@ -11,4 +11,5 @@ torchvision==0.18.1
|
|
| 11 |
open_clip_torch==2.26.1
|
| 12 |
pillow==10.4.0
|
| 13 |
ipykernel==6.29.5
|
|
|
|
| 14 |
|
|
|
|
| 11 |
open_clip_torch==2.26.1
|
| 12 |
pillow==10.4.0
|
| 13 |
ipykernel==6.29.5
|
| 14 |
+
ipywidgets==8.1.6
|
| 15 |
|