Upload 42 files
Browse files
- src/LICENSE +29 -0
- src/build/lib/loki/__init__.py +0 -0
- src/build/lib/loki/align.py +568 -0
- src/build/lib/loki/annotate.py +102 -0
- src/build/lib/loki/decompose.py +143 -0
- src/build/lib/loki/plot.py +435 -0
- src/build/lib/loki/plotting.py +435 -0
- src/build/lib/loki/predex.py +25 -0
- src/build/lib/loki/preprocess.py +324 -0
- src/build/lib/loki/retrieve.py +28 -0
- src/build/lib/loki/utilities.py +159 -0
- src/build/lib/loki/utils.py +278 -0
- src/dist/loki-0.0.1-py3-none-any.whl +0 -0
- src/dist/loki-0.0.1.tar.gz +3 -0
- src/loki.egg-info/PKG-INFO +23 -0
- src/loki.egg-info/SOURCES.txt +16 -0
- src/loki.egg-info/dependency_links.txt +1 -0
- src/loki.egg-info/requires.txt +13 -0
- src/loki.egg-info/top_level.txt +1 -0
- src/loki/__init__.py +0 -0
- src/loki/__pycache__/__init__.cpython-310.pyc +0 -0
- src/loki/__pycache__/__init__.cpython-39.pyc +0 -0
- src/loki/__pycache__/align.cpython-39.pyc +0 -0
- src/loki/__pycache__/annotate.cpython-39.pyc +0 -0
- src/loki/__pycache__/decompose.cpython-39.pyc +0 -0
- src/loki/__pycache__/deconv.cpython-39.pyc +0 -0
- src/loki/__pycache__/plot.cpython-39.pyc +0 -0
- src/loki/__pycache__/predex.cpython-39.pyc +0 -0
- src/loki/__pycache__/preprocess.cpython-39.pyc +0 -0
- src/loki/__pycache__/retrieve.cpython-39.pyc +0 -0
- src/loki/__pycache__/utils.cpython-39.pyc +0 -0
- src/loki/align.py +568 -0
- src/loki/annotate.py +102 -0
- src/loki/decompose.py +143 -0
- src/loki/plot.py +435 -0
- src/loki/predex.py +25 -0
- src/loki/preprocess.py +324 -0
- src/loki/requirements.txt +14 -0
- src/loki/retrieve.py +28 -0
- src/loki/utils.py +278 -0
- src/requirements.txt +14 -0
- src/setup.py +32 -0
src/LICENSE
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
BSD 3-Clause License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2025, Wang Lab
|
| 4 |
+
All rights reserved.
|
| 5 |
+
|
| 6 |
+
Redistribution and use in source and binary forms, with or without
|
| 7 |
+
modification, are permitted provided that the following conditions are met:
|
| 8 |
+
|
| 9 |
+
1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
list of conditions and the following disclaimer.
|
| 11 |
+
|
| 12 |
+
2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
and/or other materials provided with the distribution.
|
| 15 |
+
|
| 16 |
+
3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
contributors may be used to endorse or promote products derived from
|
| 18 |
+
this software without specific prior written permission.
|
| 19 |
+
|
| 20 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
src/build/lib/loki/__init__.py
ADDED
|
File without changes
|
src/build/lib/loki/align.py
ADDED
|
@@ -0,0 +1,568 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import numbers
import warnings
from builtins import super

import cv2
import numpy as np
import pycpd
class EMRegistration(object):
    """
    Expectation maximization point cloud registration.
    Adapted from Pure Numpy Implementation of the Coherent Point Drift Algorithm:
    https://github.com/siavashk/pycpd

    Attributes
    ----------
    X: numpy array
        NxD array of target points.

    Y: numpy array
        MxD array of source points.

    TY: numpy array
        MxD array of transformed source points.

    sigma2: float (positive)
        Initial variance of the Gaussian mixture model.

    N: int
        Number of target points.

    M: int
        Number of source points.

    D: int
        Dimensionality of source and target points.

    iteration: int
        The current iteration throughout registration.

    max_iterations: int
        Registration will terminate once the algorithm has taken this
        many iterations.

    tolerance: float (positive)
        Registration will terminate once the difference between
        consecutive objective function values falls within this tolerance.

    w: float (between 0 and 1)
        Contribution of the uniform distribution to account for outliers.
        Valid values span 0 (inclusive) and 1 (exclusive).

    q: float
        The objective function value that represents the misalignment between source
        and target point clouds.

    diff: float (positive)
        The absolute difference between the current and previous objective function values.

    P: numpy array
        MxN array of probabilities.
        P[m, n] represents the probability that the m-th source point
        corresponds to the n-th target point.

    Pt1: numpy array
        Nx1 column array.
        Multiplication result between the transpose of P and a column vector of all 1s.

    P1: numpy array
        Mx1 column array.
        Multiplication result between P and a column vector of all 1s.

    Np: float (positive)
        The sum of all elements in P.
    """

    def __init__(self, X, Y, sigma2=None, max_iterations=None, tolerance=None, w=None, *args, **kwargs):
        if type(X) is not np.ndarray or X.ndim != 2:
            # Fixed error-message typo ("must be at a 2D" -> "must be a 2D").
            raise ValueError(
                "The target point cloud (X) must be a 2D numpy array.")

        if type(Y) is not np.ndarray or Y.ndim != 2:
            raise ValueError(
                "The source point cloud (Y) must be a 2D numpy array.")

        if X.shape[1] != Y.shape[1]:
            raise ValueError(
                "Both point clouds need to have the same number of dimensions.")

        if sigma2 is not None and (not isinstance(sigma2, numbers.Number) or sigma2 <= 0):
            raise ValueError(
                "Expected a positive value for sigma2 instead got: {}".format(sigma2))

        if max_iterations is not None and (not isinstance(max_iterations, numbers.Number) or max_iterations < 0):
            raise ValueError(
                "Expected a positive integer for max_iterations instead got: {}".format(max_iterations))
        elif isinstance(max_iterations, numbers.Number) and not isinstance(max_iterations, int):
            # Bug fix: `warn` was called without ever being imported, so a
            # non-integer max_iterations raised NameError instead of warning.
            warnings.warn("Received a non-integer value for max_iterations: {}. Casting to integer.".format(max_iterations))
            max_iterations = int(max_iterations)

        if tolerance is not None and (not isinstance(tolerance, numbers.Number) or tolerance < 0):
            raise ValueError(
                "Expected a positive float for tolerance instead got: {}".format(tolerance))

        if w is not None and (not isinstance(w, numbers.Number) or w < 0 or w >= 1):
            raise ValueError(
                "Expected a value between 0 (inclusive) and 1 (exclusive) for w instead got: {}".format(w))

        self.X = X
        self.Y = Y
        self.TY = Y
        # Estimate the GMM variance from the data when not supplied.
        self.sigma2 = initialize_sigma2(X, Y) if sigma2 is None else sigma2
        (self.N, self.D) = self.X.shape
        (self.M, _) = self.Y.shape
        self.tolerance = 0.001 if tolerance is None else tolerance
        self.w = 0.0 if w is None else w
        self.max_iterations = 100 if max_iterations is None else max_iterations
        self.iteration = 0
        self.diff = np.inf
        self.q = np.inf
        self.P = np.zeros((self.M, self.N))
        self.Pt1 = np.zeros((self.N, ))
        self.P1 = np.zeros((self.M, ))
        self.PX = np.zeros((self.M, self.D))
        self.Np = 0

    def register(self, callback=lambda **kwargs: None):
        """
        Perform the EM registration.

        Attributes
        ----------
        callback: function
            A function that will be called after each iteration.
            Can be used to visualize the registration process.

        Returns
        -------
        self.TY: numpy array
            MxD array of transformed source points.

        registration_parameters:
            Returned params dependent on registration method used.
        """
        self.transform_point_cloud()
        # Iterate until the variance stops changing (diff <= tolerance)
        # or the iteration budget is exhausted.
        while self.iteration < self.max_iterations and self.diff > self.tolerance:
            self.iterate()
            if callable(callback):
                kwargs = {'iteration': self.iteration,
                          'error': self.q, 'X': self.X, 'Y': self.TY}
                callback(**kwargs)

        return self.TY, self.get_registration_parameters()

    def get_registration_parameters(self):
        """
        Placeholder for child classes.
        """
        raise NotImplementedError(
            "Registration parameters should be defined in child classes.")

    def update_transform(self):
        """
        Placeholder for child classes.
        """
        raise NotImplementedError(
            "Updating transform parameters should be defined in child classes.")

    def transform_point_cloud(self):
        """
        Placeholder for child classes.
        """
        raise NotImplementedError(
            "Updating the source point cloud should be defined in child classes.")

    def update_variance(self):
        """
        Placeholder for child classes.
        """
        raise NotImplementedError(
            "Updating the Gaussian variance for the mixture model should be defined in child classes.")

    def iterate(self):
        """
        Perform one iteration of the EM algorithm.
        """
        self.expectation()
        self.maximization()
        self.iteration += 1

    def expectation(self):
        """
        Compute the expectation step of the EM algorithm.

        Populates P, Pt1, P1, Np and PX from the current TY and sigma2.
        """
        # Squared distances between every source and target point, shape (M, N).
        P = np.sum((self.X[None, :, :] - self.TY[:, None, :])**2, axis=2)
        P = np.exp(-P/(2*self.sigma2))
        # Uniform-distribution term accounting for outliers (weight w).
        c = (2*np.pi*self.sigma2)**(self.D/2)*self.w/(1. - self.w)*self.M/self.N

        den = np.sum(P, axis = 0, keepdims = True)  # (1, N)
        # Clip to eps before adding c to avoid division by zero when all
        # responsibilities underflow for some target point.
        den = np.clip(den, np.finfo(self.X.dtype).eps, None) + c

        self.P = np.divide(P, den)
        self.Pt1 = np.sum(self.P, axis=0)
        self.P1 = np.sum(self.P, axis=1)
        self.Np = np.sum(self.P1)
        self.PX = np.matmul(self.P, self.X)

    def maximization(self):
        """
        Compute the maximization step of the EM algorithm.
        """
        self.update_transform()
        self.transform_point_cloud()
        self.update_variance()
|
| 217 |
+
class DeformableRegistration(EMRegistration):
    """
    Deformable registration.
    Adapted from Pure Numpy Implementation of the Coherent Point Drift Algorithm:
    https://github.com/siavashk/pycpd

    Attributes
    ----------
    alpha: float (positive)
        Represents the trade-off between the goodness of maximum likelihood fit and regularization.

    beta: float (positive)
        Width of the Gaussian kernel.

    low_rank: bool
        Whether to use low rank approximation.

    num_eig: int
        Number of eigenvectors to use in lowrank calculation.
    """

    def __init__(self, alpha=None, beta=None, low_rank=False, num_eig=100, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if alpha is not None and (not isinstance(alpha, numbers.Number) or alpha <= 0):
            raise ValueError(
                "Expected a positive value for regularization parameter alpha. Instead got: {}".format(alpha))

        if beta is not None and (not isinstance(beta, numbers.Number) or beta <= 0):
            raise ValueError(
                "Expected a positive value for the width of the coherent Gaussian kerenl. Instead got: {}".format(beta))

        self.alpha = 2 if alpha is None else alpha
        self.beta = 2 if beta is None else beta
        # Deformation coefficients (one D-dimensional row per source point).
        self.W = np.zeros((self.M, self.D))
        # Gram matrix of the coherence kernel over the source points.
        self.G = gaussian_kernel(self.Y, self.beta)
        self.low_rank = low_rank
        self.num_eig = num_eig
        if self.low_rank is True:
            # Truncated eigendecomposition of G for the low-rank update path.
            self.Q, self.S = low_rank_eigen(self.G, self.num_eig)
            self.inv_S = np.diag(1./self.S)
            self.S = np.diag(self.S)
            self.E = 0.

    def update_transform(self):
        """
        Calculate a new estimate of the deformable transformation.
        See Eq. 22 of https://arxiv.org/pdf/0905.2635.pdf.
        """
        if self.low_rank is False:
            lhs = np.diag(self.P1) @ self.G + self.alpha * self.sigma2 * np.eye(self.M)
            rhs = self.PX - np.diag(self.P1) @ self.Y
            self.W = np.linalg.solve(lhs, rhs)

        elif self.low_rank is True:
            # Matlab code equivalent can be found here:
            # https://github.com/markeroon/matlab-computer-vision-routines/tree/master/third_party/CoherentPointDrift
            dP = np.diag(self.P1)
            dPQ = dP @ self.Q
            F = self.PX - dP @ self.Y

            inner = np.linalg.solve(self.alpha * self.sigma2 * self.inv_S + self.Q.T @ dPQ,
                                    self.Q.T @ F)
            self.W = 1 / (self.alpha * self.sigma2) * (F - dPQ @ inner)
            QtW = self.Q.T @ self.W
            # Accumulate the regularization energy term.
            self.E = self.E + self.alpha / 2 * np.trace(QtW.T @ (self.S @ QtW))

    def transform_point_cloud(self, Y=None):
        """
        Update a point cloud using the new estimate of the deformable transformation.

        Attributes
        ----------
        Y: numpy array, optional
            Array of points to transform - use to predict on new set of points.
            Best for predicting on new points not used to run initial registration.
            If None, self.Y used.

        Returns
        -------
        If Y is None, returns None.
        Otherwise, returns the transformed Y.
        """
        # Zero out displacement components beyond the first two columns, so
        # only the leading (spatial) dimensions of W deform the cloud.
        # NOTE(review): deliberate deviation from stock pycpd — confirm intent.
        self.W[:, 2:] = 0
        if Y is not None:
            G = gaussian_kernel(X=Y, beta=self.beta, Y=self.Y)
            return Y + G @ self.W

        if self.low_rank is False:
            self.TY = self.Y + self.G @ self.W
        elif self.low_rank is True:
            self.TY = self.Y + self.Q @ (self.S @ (self.Q.T @ self.W))
        return

    def update_variance(self):
        """
        Update the variance of the mixture model using the new estimate of the deformable transformation.
        See the update rule for sigma2 in Eq. 23 of of https://arxiv.org/pdf/0905.2635.pdf.
        """
        qprev = self.sigma2

        # The original CPD paper does not explicitly calculate the objective functional.
        # This functional will include terms from both the negative log-likelihood and
        # the Gaussian kernel used for regularization.
        self.q = np.inf

        xPx = self.Pt1 @ np.sum(self.X * self.X, axis=1)
        yPy = self.P1 @ np.sum(self.TY * self.TY, axis=1)
        trPXY = np.sum(self.TY * self.PX)

        self.sigma2 = (xPx - 2 * trPXY + yPy) / (self.Np * self.D)

        # Guard against a collapsed (non-positive) variance estimate.
        if self.sigma2 <= 0:
            self.sigma2 = self.tolerance / 10

        # Here we use the difference between the current and previous
        # estimate of the variance as a proxy to test for convergence.
        self.diff = np.abs(self.sigma2 - qprev)

    def get_registration_parameters(self):
        """
        Return the current estimate of the deformable transformation parameters.

        Returns
        -------
        self.G: numpy array
            Gaussian kernel matrix.

        self.W: numpy array
            Deformable transformation matrix.
        """
        return self.G, self.W
|
| 360 |
+
|
| 361 |
+
def initialize_sigma2(X, Y):
    """
    Initialize the variance (sigma2).

    param
    ----------
    X: numpy array
        NxD array of points for target.

    Y: numpy array
        MxD array of points for source.

    Returns
    -------
    sigma2: float
        Initial variance, the mean squared coordinate difference over all
        source/target point pairs.
    """
    N, D = X.shape
    M = Y.shape[0]
    # Pairwise squared coordinate differences, shape (M, N, D).
    sq_diff = np.square(X[None, :, :] - Y[:, None, :])
    return sq_diff.sum() / (D * M * N)
| 385 |
+
|
| 386 |
+
def gaussian_kernel(X, beta, Y=None):
    """
    Computes a Gaussian (RBF) kernel matrix between two sets of vectors.

    :param X: A numpy array of shape (n_samples_X, n_features) representing the first set of vectors.
    :param beta: The standard deviation parameter for the Gaussian kernel. It controls the spread of the kernel.
    :param Y: An optional numpy array of shape (n_samples_Y, n_features) representing the second set of vectors.
              If None, the function computes the kernel between `X` and itself (i.e., the Gram matrix).
    :return: A numpy array of shape (n_samples_X, n_samples_Y) where element (i, j) is
             `exp(-||X[i] - Y[j]||^2 / (2 * beta^2))`.
    """
    # Default to the Gram matrix of X with itself.
    if Y is None:
        Y = X

    # Squared Euclidean distances between every pair, shape (n_samples_X, n_samples_Y).
    sq_dist = np.sum(np.square(X[:, None, :] - Y[None, :, :]), axis=2)

    # RBF kernel: exp(-d^2 / (2 * beta^2)).
    return np.exp(-sq_dist / (2 * beta**2))
| 420 |
+
|
| 421 |
+
def low_rank_eigen(G, num_eig):
    """
    Calculate the top `num_eig` eigenvectors and eigenvalues of a given Gaussian matrix G.
    This function is useful for dimensionality reduction or when a low-rank approximation is needed.

    :param G: A square symmetric matrix (numpy array) to eigendecompose.
    :param num_eig: The number of top eigenvectors and eigenvalues to return, ranked by eigenvalue magnitude.
    :return: A tuple (Q, S) where Q has shape (n, num_eig) with one eigenvector per
             column, and S has shape (num_eig,) with the matching eigenvalues.
    """
    # Full symmetric eigendecomposition (eigh assumes G is symmetric).
    eigvals, eigvecs = np.linalg.eigh(G)

    # Rank all eigenvalues by absolute magnitude and keep the top num_eig.
    top = np.argsort(np.abs(eigvals))[::-1][:num_eig]

    return eigvecs[:, top], eigvals[top]
| 452 |
+
|
| 453 |
+
def find_homography_translation_rotation(src_points, dst_points):
    """
    Find the homography between two sets of coordinates with only translation and rotation.

    :param src_points: A numpy array of shape (n, 2) containing source coordinates.
    :param dst_points: A numpy array of shape (n, 2) containing destination coordinates.
    :return: A 3x3 homography matrix.
    """
    # Point sets must correspond one-to-one and be 2D.
    assert src_points.shape == dst_points.shape
    assert src_points.shape[1] == 2

    # Centroids of each point set.
    c_src = np.mean(src_points, axis=0)
    c_dst = np.mean(dst_points, axis=0)

    # Cross-covariance of the centered point sets (Kabsch-style estimate).
    cov = (src_points - c_src).T @ (dst_points - c_dst)

    # Recover the rotation from the SVD of the covariance.
    U, _, Vt = np.linalg.svd(cov)
    R = Vt.T @ U.T

    # Flip the last singular direction if needed so det(R) = +1 (a proper
    # rotation, not a reflection).
    if np.linalg.det(R) < 0:
        Vt[-1, :] *= -1
        R = Vt.T @ U.T

    # Translation maps the rotated source centroid onto the destination centroid.
    t = c_dst - R @ c_src

    # Embed R and t into a 3x3 homogeneous transform.
    H = np.eye(3)
    H[0:2, 0:2] = R
    H[0:2, 2] = t

    return H
+
|
| 499 |
+
def apply_homography(coordinates, H):
    """
    Apply a 3x3 homography matrix to 2D coordinates.

    :param coordinates: A numpy array of shape (n, 2) containing 2D coordinates.
    :param H: A numpy array of shape (3, 3) representing the homography matrix.
    :return: A numpy array of shape (n, 2) with transformed coordinates.
    """
    # Lift (x, y) to homogeneous coordinates (x, y, 1).
    ones = np.ones((coordinates.shape[0], 1))
    homog = np.hstack((coordinates, ones))

    # Transform every point; row-vector convention, hence H.T on the right.
    mapped = homog @ H.T

    # Project back from homogeneous coordinates by the w component.
    return mapped[:, :2] / mapped[:, [2]]
| 520 |
+
|
| 521 |
+
def align_tissue(ad_tar_coor, ad_src_coor, pca_comb_features, src_img, alpha=0.5):
    """
    Aligns the source coordinates to the target coordinates using Coherent Point Drift (CPD)
    registration, and applies a homography transformation to warp the source coordinates accordingly.

    :param ad_tar_coor: Numpy array of target coordinates to which the source will be aligned.
    :param ad_src_coor: Numpy array of source coordinates that will be aligned to the target.
    :param pca_comb_features: PCA-combined feature matrix used as additional features for the alignment process.
        Rows are assumed to be target spots first, then source spots — TODO confirm with caller.
    :param src_img: Source image to be warped based on the alignment.
    :param alpha: Regularization parameter for CPD registration, default is 0.5.
    :return:
        - cpd_coor: The new source coordinates after CPD alignment.
        - homo_coor: The source coordinates after applying the homography transformation.
        - aligned_image: The source image warped based on the homography transformation.
    """
    # Min-max scale both coordinate sets to [0, 1] (global min/max per array).
    tar_range = ad_tar_coor.max() - ad_tar_coor.min()
    tar_z = (ad_tar_coor - ad_tar_coor.min()) / tar_range
    src_z = (ad_src_coor - ad_src_coor.min()) / (ad_src_coor.max() - ad_src_coor.min())

    # Min-max scale the PCA-combined features to [0, 1] as well.
    feat_z = (pca_comb_features - pca_comb_features.min()) / (pca_comb_features.max() - pca_comb_features.min())

    # Build 4D points: 2 spatial dims + the first 2 feature dims, with the
    # feature matrix split row-wise into target rows then source rows.
    n_tar = ad_tar_coor.shape[0]
    target = np.concatenate((tar_z, feat_z[:n_tar, :2]), axis=1)
    source = np.concatenate((src_z, feat_z[n_tar:, :2]), axis=1)

    # Deformable CPD registration with a low-rank kernel approximation.
    reg = DeformableRegistration(X=target, Y=source, low_rank=True,
                                 alpha=alpha,
                                 max_iterations=int(1e9), tolerance=1e-9)
    TY = reg.register()[0]  # transformed source points

    # Rescale the aligned spatial coordinates back to the target's original range.
    cpd_coor = TY[:, :2] * tar_range + ad_tar_coor.min()

    # Fit a rigid (rotation + translation) homography from the raw source
    # coordinates to the CPD result, and apply it.
    h = find_homography_translation_rotation(ad_src_coor, cpd_coor)
    homo_coor = apply_homography(ad_src_coor, h)

    # Warp the source image with the same rigid homography.
    aligned_image = cv2.warpPerspective(src_img, h, (src_img.shape[1], src_img.shape[0]))

    return cpd_coor, homo_coor, aligned_image
src/build/lib/loki/annotate.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
from torch.nn import functional as F
|
| 4 |
+
import os
|
| 5 |
+
import scanpy as sc
|
| 6 |
+
import json
|
| 7 |
+
import cv2
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def annotate_with_bulk(img_features, bulk_features, normalize=True, T=1, tensor=False):
    """
    Annotates tissue image with similarity scores between image features and bulk RNA-seq features.

    :param img_features: Feature matrix representing histopathology image features.
    :param bulk_features: Feature vector representing bulk RNA-seq features.
    :param normalize: Whether to normalize similarity scores (tensor mode only), default=True.
    :param T: Temperature parameter controlling softmax sharpness (tensor mode only);
              larger T gives a smoother distribution.
    :param tensor: Whether the inputs are torch tensors, default=False.
    :return: An array or tensor containing the normalized similarity scores.
    """
    if tensor:
        return _bulk_similarity_tensor(img_features, bulk_features, normalize, T)
    return _bulk_similarity_array(img_features, bulk_features)


def _bulk_similarity_tensor(img_features, bulk_features, normalize, T):
    # Cosine similarity of every image feature row against the bulk vector -> shape [n].
    cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
    sim = cos(img_features, bulk_features.unsqueeze(0))

    if normalize:
        # Scale by sqrt(feature dimension) of the bulk vector.
        scale = torch.sqrt(torch.tensor([bulk_features.shape[0]], dtype=torch.float))
        sim = sim / scale

    # Temperature-scaled softmax over all spots -> probability distribution [1, n].
    sim = sim.unsqueeze(0) / T
    return torch.nn.functional.softmax(sim, dim=-1)


def _bulk_similarity_array(img_features, bulk_features):
    # Dot-product similarity for the numpy path.
    sim = np.dot(img_features.T, bulk_features)

    # Numerically-stable softmax (subtract the max before exponentiating).
    shifted = np.exp(sim - np.max(sim))
    sim = shifted / np.sum(shifted)

    # Min-max rescale to [0, 1] for easier interpretation.
    return (sim - np.min(sim)) / (np.max(sim) - np.min(sim))
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def annotate_with_marker_genes(classes, image_embeddings, all_text_embeddings):
    """
    Annotates tissue image with similarity scores between image features and marker gene features.

    :param classes: A list or array of tissue type labels.
    :param image_embeddings: Image embeddings (shape: [n_images, embedding_dim]).
    :param all_text_embeddings: Text embeddings of the marker genes (shape: [n_classes, embedding_dim]).
    :return:
        - dot_similarity: Dot-product similarity matrix between image and text embeddings.
        - pred_class: The tissue type with the highest similarity score.
    """
    # Similarity of each image embedding against every class's text embedding.
    dot_similarity = image_embeddings @ all_text_embeddings.T

    # The predicted class is the one with the largest similarity value.
    best_index = dot_similarity.argmax()
    pred_class = classes[best_index]

    return dot_similarity, pred_class
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def load_image_annotation(image_path):
    """
    Loads an image with annotation.

    :param image_path: The file path to the image.
    :return: The loaded image with its first and third channels swapped, as uint8.
    """
    # Read the image from disk with OpenCV.
    raw = cv2.imread(image_path)

    # Swap the first/third channels. NOTE(review): cv2.imread returns BGR, and
    # COLOR_RGB2BGR performs the same swap as COLOR_BGR2RGB, so this presumably
    # yields an RGB image for downstream display — confirm against callers.
    swapped = cv2.cvtColor(raw, cv2.COLOR_RGB2BGR)

    # Force uint8 so the result is safe for OpenCV / plotting routines.
    return swapped.astype(np.uint8)
|
| 101 |
+
|
| 102 |
+
|
src/build/lib/loki/decompose.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import tangram as tg
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import anndata
|
| 6 |
+
from sklearn.decomposition import PCA
|
| 7 |
+
from sklearn.neighbors import NearestNeighbors
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def generate_feature_ad(ad_expr, feature_path, sc=False):
    """
    Generates an AnnData object with OmiCLIP text or image embeddings.

    :param ad_expr: AnnData object containing metadata for the dataset.
    :param feature_path: Path to the CSV file containing the features to be loaded.
    :param sc: If True copy single-cell metadata, otherwise copy ST metadata. Default False (ST).
    :return: A new AnnData object with the loaded features and relevant metadata from ad_expr.
    """
    obs_index = ad_expr.obs.index

    # Load the feature table and keep only the columns matching the cells/spots in ad_expr.
    features = pd.read_csv(feature_path, index_col=0)[obs_index]

    # Build the AnnData with cells/spots as rows (features arrive with them as columns).
    feature_ad = anndata.AnnData(features[obs_index].T)

    if sc:
        # Single-cell data: carry over the full obs metadata.
        feature_ad.obs = ad_expr.obs.copy()
    else:
        # Spatial data: carry over cell counts plus spatial metadata and coordinates.
        feature_ad.obs['cell_num'] = ad_expr.obs['cell_num'].copy()
        feature_ad.uns['spatial'] = ad_expr.uns['spatial'].copy()
        feature_ad.obsm['spatial'] = ad_expr.obsm['spatial'].copy()

    return feature_ad
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def normalize_percentile(df, cols, min_percentile=5, max_percentile=95):
    """
    Clips and normalizes the specified columns of a DataFrame based on percentile
    thresholds, mapping their values onto the [0, 1] range in place.

    :param df: A pandas DataFrame containing the columns to normalize.
    :type df: pandas.DataFrame
    :param cols: Names of the columns in ``df`` to normalize.
    :type cols: list[str]
    :param min_percentile: Lower percentile used for clipping (default 5).
    :type min_percentile: float
    :param max_percentile: Upper percentile used for clipping (default 95).
    :type max_percentile: float
    :return: The same DataFrame with the specified columns clipped and normalized.
    :rtype: pandas.DataFrame
    """
    for column in cols:
        # Percentile cut-offs for this column.
        lower = np.percentile(df[column], min_percentile)
        upper = np.percentile(df[column], max_percentile)

        # Clip outliers, then min-max rescale the clipped values onto [0, 1].
        clipped = np.clip(df[column], lower, upper)
        df[column] = (clipped - lower) / (upper - lower)

    return df
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def cell_type_decompose(sc_ad, st_ad, cell_type_col='cell_type', NMS_mode=False, major_types=None, min_percentile=5, max_percentile=95):
    """
    Performs cell type decomposition on spatial data (ST or image) with single-cell data,
    using Tangram to map single-cell clusters onto spatial spots.

    :param sc_ad: AnnData object containing single-cell meta data.
    :param st_ad: AnnData object containing spatial data (ST or image) meta data.
    :param cell_type_col: Column name in `sc_ad.obs` holding cell type annotations. Default 'cell_type'.
    :param NMS_mode: If True, post-process the Tangram predictions for `major_types`:
                     percentile-normalize them, then keep only the per-spot maximum
                     (non-maximum suppression), zeroing the rest. Default False.
    :param major_types: List of cell type column names to process when `NMS_mode` is True.
    :param min_percentile: Lower percentile passed to `normalize_percentile` in NMS mode. Default 5.
    :param max_percentile: Upper percentile passed to `normalize_percentile` in NMS mode. Default 95.
    :return: The spatial AnnData object with projected cell type annotations
             (Tangram writes them into `st_ad.obsm['tangram_ct_pred']`).
    """

    # Preprocess the data for decomposition using tangram (tg)
    tg.pp_adatas(sc_ad, st_ad, genes=None)  # Preprocessing: match genes between single-cell and spatial data


    # Map single-cell data to spatial data using Tangram's "map_cells_to_space" function
    ad_map = tg.map_cells_to_space(
        sc_ad, st_ad,
        mode="clusters",  # Map based on clusters (cell types)
        cluster_label=cell_type_col,  # Column in `sc_ad.obs` representing cell type
        device='cpu',  # Run on CPU (or 'cuda' if GPU is available)
        scale=False,  # Don't scale data (can be set to True if needed)
        density_prior='uniform',  # Assume a uniform cell density across spots
        random_state=10,  # Set random state for reproducibility
        verbose=False,  # Disable verbose output for cleaner logging
    )

    # Project cell type annotations from the single-cell data to the spatial data
    tg.project_cell_annotations(ad_map, st_ad, annotation=cell_type_col)


    if NMS_mode:
        # NOTE(review): no-op self-assignment kept for byte-identical behavior.
        major_types = major_types
        # NOTE(review): this replaces st_ad.obs wholesale with the normalized
        # tangram_ct_pred DataFrame (normalize_percentile mutates and returns
        # the same object) — confirm this wholesale replacement is intended.
        st_ad.obs = normalize_percentile(st_ad.obsm['tangram_ct_pred'], major_types, min_percentile, max_percentile)

        st_ad_binary = st_ad.obsm['tangram_ct_pred'][major_types].copy()
        # Retain the max value in each row and set the rest to 0
        st_ad.obs[major_types] = st_ad_binary.where(st_ad_binary.eq(st_ad_binary.max(axis=1), axis=0), other=0)

    return st_ad  # Return the spatial AnnData object with the projected annotations
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def assign_cells_to_spots(cell_locs, spot_locs, patch_size=16):
    """
    Assigns cells to spots based on their spatial coordinates. A cell is assigned to
    every spot whose center lies within half the patch size of the cell.

    :param cell_locs: Numpy array of shape (n_cells, 2) with the x, y coordinates of the cells.
    :param spot_locs: Numpy array of shape (n_spots, 2) with the x, y coordinates of the spots.
    :param patch_size: Diameter of a spot patch; the assignment radius is half of this value.
    :return: A sparse connectivity matrix with one row per cell and one column per spot;
             an entry is 1 when the cell is assigned to that spot, 0 otherwise.
    """
    # Radius is half the patch diameter.
    spot_radius = patch_size * 0.5

    # Index the spot centers, then query which spots fall within the radius of each cell.
    spot_index = NearestNeighbors(radius=spot_radius)
    spot_index.fit(spot_locs)

    # 'connectivity' mode yields a 0/1 sparse matrix (cells x spots).
    return spot_index.radius_neighbors_graph(cell_locs, mode='connectivity')
|
| 142 |
+
|
| 143 |
+
|
src/build/lib/loki/plot.py
ADDED
|
@@ -0,0 +1,435 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import matplotlib.pyplot as plt
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
import json
|
| 4 |
+
import cv2
|
| 5 |
+
from matplotlib import cm
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import numpy as np
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def plot_alignment(ad_tar_coor, ad_src_coor, homo_coor, pca_hex_comb, tar_features, shift=300, s=0.8, boundary_line=True):
    """
    Plots the target coordinates and alignment of source coordinates side by side.

    :param ad_tar_coor: Numpy array of target coordinates (first subplot).
    :param ad_src_coor: Numpy array of source coordinates (second subplot).
    :param homo_coor: Numpy array of aligned source coordinates (third subplot).
    :param pca_hex_comb: Per-point color values (e.g. PCA-derived hex colors).
    :param tar_features: Target feature matrix; its column count splits the colors
                         between target and source points.
    :param shift: Padding added around the coordinate range for the plot limits. Default 300.
    :param s: Scatter marker size. Default 0.8.
    :param boundary_line: Whether to draw boundary lines at the target minima. Default True.
    :return: Displays the alignment plot of target, source, and aligned source coordinates.
    """
    # Split the color list between target and source points.
    n_tar = len(tar_features.T)
    tar_colors = pca_hex_comb[:n_tar]
    src_colors = pca_hex_comb[n_tar:]

    # All three panels share the same limits, derived from the target coordinates.
    axis_limits = [ad_tar_coor.min() - shift, ad_tar_coor.max() + shift]

    plt.figure(figsize=(10, 3), dpi=300)

    # Panel 1: target coordinates.
    plt.subplot(1, 3, 1)
    plt.scatter(ad_tar_coor[:, 0], ad_tar_coor[:, 1], marker='o', s=s, c=tar_colors)
    plt.xlim(axis_limits)
    plt.ylim(axis_limits)

    # Panel 2: source coordinates, on the same limits for comparison.
    plt.subplot(1, 3, 2)
    plt.scatter(ad_src_coor[:, 0], ad_src_coor[:, 1], marker='o', s=s, c=src_colors)
    plt.xlim(axis_limits)
    plt.ylim(axis_limits)

    # Panel 3: aligned source coordinates.
    plt.subplot(1, 3, 3)
    plt.scatter(homo_coor[:, 0], homo_coor[:, 1], marker='o', s=s, c=src_colors)
    plt.xlim(axis_limits)
    plt.ylim(axis_limits)

    # Optional boundary lines at the target minima (drawn on the current, third panel).
    if boundary_line:
        plt.axvline(x=ad_tar_coor[:, 0].min(), color='black')
        plt.axhline(y=ad_tar_coor[:, 1].min(), color='black')

    # Hide the axes (applies to the current, i.e. third, subplot).
    plt.axis('off')

    plt.show()
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def plot_alignment_with_img(ad_tar_coor, ad_src_coor, homo_coor, tar_img, src_img, aligned_image, pca_hex_comb, tar_features):
    """
    Plots the target coordinates and alignment of source coordinates with their
    respective images in the background.

    :param ad_tar_coor: Numpy array of target coordinates (first and third subplots).
    :param ad_src_coor: Numpy array of source coordinates (second subplot).
    :param homo_coor: Numpy array of aligned source coordinates (third subplot).
    :param tar_img: Background image for the target panel.
    :param src_img: Background image for the source panel.
    :param aligned_image: Background image for the aligned panel.
    :param pca_hex_comb: Per-point color values (e.g. PCA-derived hex colors).
    :param tar_features: Target feature matrix; its column count splits the colors
                         between target and source points.
    :return: Displays the alignment plots with their associated background images.
    """
    # Split the color list between target and source points.
    n_tar = len(tar_features.T)
    tar_colors = pca_hex_comb[:n_tar]
    src_colors = pca_hex_comb[n_tar:]

    plt.figure(figsize=(10, 8), dpi=150)

    # Panel 1: target spots over the (semi-transparent) target image.
    plt.subplot(1, 3, 1)
    plt.scatter(ad_tar_coor[:, 0], ad_tar_coor[:, 1], marker='o', alpha=0.8, s=1, c=tar_colors)
    plt.imshow(tar_img, origin='lower', alpha=0.3)

    # Panel 2: source spots over the (semi-transparent) source image.
    plt.subplot(1, 3, 2)
    plt.scatter(ad_src_coor[:, 0], ad_src_coor[:, 1], marker='o', alpha=0.8, s=1, c=src_colors)
    plt.imshow(src_img, origin='lower', alpha=0.3)

    # Panel 3: faded target spots plus aligned source spots over the warped image.
    plt.subplot(1, 3, 3)
    plt.scatter(ad_tar_coor[:, 0], ad_tar_coor[:, 1], marker='o', alpha=0.2, s=1, c=tar_colors)
    plt.scatter(homo_coor[:, 0], homo_coor[:, 1], marker='+', s=1, c=src_colors)
    plt.imshow(aligned_image, origin='lower', alpha=0.3)

    # Hide the axes (applies to the current, i.e. third, subplot).
    plt.axis('off')

    plt.show()
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def draw_polygon(image, polygon, color='k', thickness=2):
    """
    Draws one or more closed polygons on the given image.

    :param image: The image on which to draw (numpy array).
    :param polygon: A list of polygons, each a list of (x, y) coordinate tuples.
    :param color: A color string or list of color strings, one per polygon.
                  A single value is applied to every polygon. Default 'k' (black).
    :param thickness: Border thickness, a single int or a list of ints (one per
                      polygon). Default 2.
    :return: The image with the polygons drawn on it.
    """
    # Broadcast a single color across all polygons.
    if not isinstance(color, list):
        color = [color] * len(polygon)

    for idx, vertices in enumerate(polygon):
        # Resolve this polygon's color string to an RGB tuple.
        rgb = color_string_to_rgb(color[idx])

        # Per-polygon thickness when a list is given, otherwise the shared value.
        border = thickness[idx] if isinstance(thickness, list) else thickness

        # OpenCV expects an int32 array shaped (n_points, 1, 2).
        pts = np.array(vertices, np.int32).reshape((-1, 1, 2))

        # isClosed=True joins the last vertex back to the first.
        image = cv2.polylines(image, [pts], isClosed=True, color=rgb, thickness=border)

    return image
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def blend_images(image1, image2, alpha=0.5):
    """
    Blends two images together as a weighted sum.

    :param image1: Background image, numpy array of shape (H, W, 3).
    :param image2: Foreground image, numpy array with the same shape as image1.
    :param alpha: Weight of image1 in the blend, between 0 and 1; image2 gets
                  weight 1 - alpha. Default 0.5 (equal blend).
    :return: The blended image, computed as `alpha * image1 + (1 - alpha) * image2`.
    """
    # cv2.addWeighted performs the per-pixel weighted sum (with saturation) in C.
    return cv2.addWeighted(image1, alpha, image2, 1 - alpha, 0)
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
# Matplotlib-style single-letter shorthand colors mapped to 6-digit hex.
_SHORTHAND_HEX = {
    'k': '000000',  # black
    'r': 'ff0000',  # red
    'g': '00ff00',  # green
    'b': '0000ff',  # blue
    'w': 'ffffff',  # white
    'c': '00ffff',  # cyan
    'm': 'ff00ff',  # magenta
    'y': 'ffff00',  # yellow
}


def color_string_to_rgb(color_string):
    """
    Converts a color string to an RGB tuple.

    :param color_string: A hexadecimal color (e.g. '#ff0000') or a matplotlib-style
                         single-letter shorthand ('k', 'r', 'g', 'b', 'w', 'c', 'm', 'y').
    :return:
        A tuple (r, g, b) of integers between 0 and 255.
    :raises:
        ValueError: If the color string is not a recognized shorthand or a
        valid 6-digit hex color.
    """
    # Remove any spaces in the color string.
    color_string = color_string.replace(' ', '')

    if color_string.startswith('#'):
        # Hexadecimal form: strip the leading '#'.
        hex_digits = color_string[1:]
    elif color_string in _SHORTHAND_HEX:
        # Shorthand form: look up the equivalent hex digits.
        hex_digits = _SHORTHAND_HEX[color_string]
    else:
        raise ValueError(f"Unknown color string {color_string}")

    # Require exactly 6 hex digits so malformed input fails with a clear error
    # instead of being partially parsed (over-long strings) or raising a
    # cryptic int() error (short strings).
    if len(hex_digits) != 6:
        raise ValueError(f"Expected a 6-digit hex color, got {hex_digits!r}")

    # Parse the red, green, and blue channel pairs.
    r = int(hex_digits[:2], 16)
    g = int(hex_digits[2:4], 16)
    b = int(hex_digits[4:], 16)

    return (r, g, b)
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def plot_heatmap(
    coor,
    similairty,
    image_path=None,
    patch_size=(256, 256),
    save_path=None,
    downsize=32,
    cmap='turbo',
    smooth=False,
    boxes=None,
    box_color='k',
    box_thickness=2,
    polygons=None,
    polygons_color='k',
    polygons_thickness=2,
    image_alpha=0.5
):
    """
    Plots a heatmap of similarity scores, optionally blended with a background image.

    :param coor: Array of (x, y) coordinates (N, 2) in full-resolution pixels, one per patch.
    :param similairty: Array of similarity scores (N,) mapped to colors via `cmap`.
    :param image_path: Path to the (already downsized) background image. If None, a blank
                       white canvas sized to fit every patch is used instead.
    :param patch_size: Size of each patch in full-resolution pixels. Default is (256, 256).
    :param save_path: If given, the final heatmap is also written to this path.
    :param downsize: Factor by which `coor` and `patch_size` are divided so they match the
                     downsized background image. Default is 32.
    :param cmap: Matplotlib colormap name used to color the scores. Default is 'turbo'.
    :param smooth: Reserved; smoothing is not implemented in this version.
    :param boxes: Reserved; box drawing is not implemented in this version.
    :param box_color: Reserved, see `boxes`.
    :param box_thickness: Reserved, see `boxes`.
    :param polygons: List of polygons (each an array of (x, y) points in full-resolution
                     pixels) to draw on both outputs.
    :param polygons_color: Outline color for the polygons. Default is black ('k').
    :param polygons_thickness: Outline thickness for the polygons.
    :param image_alpha: Blending weight (0 to 1) of the heatmap against the background image.
                        0 disables blending. Default is 0.5.

    :return:
        - heatmap: The generated heatmap as a uint8 numpy array (polygons drawn if provided).
        - image: The background image (polygons drawn if provided), or None when no
          `image_path` was given.
    :raises FileNotFoundError: If `image_path` is given but cannot be read.
    """

    # Load the (downsized) background image when one is provided.
    image = None
    if image_path is not None:
        image = cv2.imread(image_path)  # cv2.imread returns None instead of raising
        if image is None:
            raise FileNotFoundError(f"Could not read image at {image_path}")

    # Bring coordinates and patch size down to the background image's scale.
    coor = [(x // downsize, y // downsize) for x, y in coor]
    patch_size = (patch_size[0] // downsize, patch_size[1] // downsize)

    # Canvas size: match the background image, or fit all patches on a blank canvas.
    # (Previously `image_path=None` crashed on `image.shape` despite being documented.)
    if image is not None:
        image_size = (image.shape[0], image.shape[1])
    else:
        image_size = (
            max(y for _, y in coor) + patch_size[0],
            max(x for x, _ in coor) + patch_size[1],
        )

    # Map similarity scores to RGB colors through the colormap.
    cmap = plt.get_cmap(cmap)
    norm = plt.Normalize(vmin=similairty.min(), vmax=similairty.max())
    colors = cmap(norm(similairty))

    # Start from a white canvas and paint one uniformly colored patch per coordinate.
    heatmap = np.ones((image_size[0], image_size[1], 3)) * 255
    for i in range(len(coor)):
        x, y = coor[i]
        rgb = (colors[i][:3] * 255).astype(np.uint8)  # [0, 1] floats -> [0, 255] uint8
        heatmap[y:y + patch_size[0], x:x + patch_size[1], :] = rgb

    # Keep the output dtype consistent whether or not blending happens below
    # (previously a float64 array was returned when image_alpha == 0).
    heatmap = heatmap.astype(np.uint8)

    # Blend the heatmap with the background image when requested and available.
    if image_alpha > 0 and image is not None:
        image = np.array(image)

        # Pad the image with white so it matches the heatmap size exactly.
        if image.shape[0] < heatmap.shape[0]:
            pad = heatmap.shape[0] - image.shape[0]
            image = np.pad(image, ((0, pad), (0, 0), (0, 0)), mode='constant', constant_values=255)
        if image.shape[1] < heatmap.shape[1]:
            # BUG FIX: the pad width was previously computed as
            # heatmap.shape[1] - heatmap.shape[1] (always 0), so narrower images
            # were never padded and blending failed on a size mismatch.
            pad = heatmap.shape[1] - image.shape[1]
            image = np.pad(image, ((0, 0), (0, pad), (0, 0)), mode='constant', constant_values=255)

        # Convert to BGR for OpenCV compatibility and blend with the heatmap.
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        image = image.astype(np.uint8)
        heatmap = blend_images(heatmap, image, alpha=image_alpha)

    # Overlay polygons (downsized to the same scale) on both outputs when provided.
    if polygons is not None:
        polygons = [poly // downsize for poly in polygons]
        heatmap = draw_polygon(heatmap, polygons, color=polygons_color, thickness=polygons_thickness)
        if image is not None:
            image = draw_polygon(image, polygons, color=polygons_color, thickness=polygons_thickness)

    # Persist the final heatmap when requested (previously `save_path` was ignored).
    if save_path is not None:
        cv2.imwrite(save_path, heatmap)

    return heatmap, image
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
def show_images_side_by_side(image1, image2, title1=None, title2=None):
    """
    Renders two images next to each other in one figure.

    :param image1: Left image (numpy array).
    :param image2: Right image (numpy array).
    :param title1: Optional title shown above the left image.
    :param title2: Optional title shown above the right image.
    :return: Displays the images side by side.
    """
    # One row, two columns; a wide figure so each image gets room.
    fig, axes = plt.subplots(1, 2, figsize=(15, 8))

    # Same treatment for both panels: draw the image, add its title, hide the axes.
    for axis, img, label in zip(axes, (image1, image2), (title1, title2)):
        axis.imshow(img)
        axis.set_title(label)
        axis.axis('off')

    # Render the completed figure.
    plt.show()
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
def plot_img_with_annotation(fullres_img, roi_polygon, linewidth, xlim, ylim):
    """
    Displays a full-resolution image with ROI polygon outlines drawn on top.

    :param fullres_img: Full-resolution image (numpy array).
    :param roi_polygon: List of polygons, each a sequence of (x, y) coordinate tuples.
    :param linewidth: Line width used for the polygon outlines.
    :param xlim: (xmin, xmax) tuple restricting the visible x range.
    :param ylim: (ymin, ymax) tuple restricting the visible y range.
    :return: Displays the annotated image.
    """
    # Object-oriented matplotlib API: one axes on a square figure.
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(fullres_img)

    # Trace each polygon outline in black.
    for poly in roi_polygon:
        xs, ys = zip(*poly)  # split (x, y) pairs into coordinate sequences
        ax.plot(xs, ys, color='black', linewidth=linewidth)

    # Zoom to the requested window.
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)

    # Image convention: origin at the top-left, so flip the y axis.
    ax.invert_yaxis()

    # Hide ticks and labels for a clean image view.
    ax.axis('off')
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
def plot_annotation_heatmap(st_ad, roi_polygon, s, linewidth, xlim, ylim):
    """
    Displays a tissue-annotation heatmap of spot similarities with ROI outlines.

    :param st_ad: AnnData object with spot coordinates in `obsm['spatial']` and
                  similarity scores in `obs['bulk_simi']`.
    :param roi_polygon: List of polygons, each a sequence of (x, y) coordinate tuples.
    :param s: Marker size for the per-spot scatter points.
    :param linewidth: Line width used for the polygon outlines.
    :param xlim: (xmin, xmax) tuple restricting the visible x range.
    :param ylim: (ymin, ymax) tuple restricting the visible y range.
    :return: Displays the heatmap with polygons overlaid.
    """
    # Object-oriented matplotlib API: one axes on a square figure.
    fig, ax = plt.subplots(figsize=(10, 10))

    # Color each spatial spot by its bulk similarity score, clipped to [0.1, 0.95].
    spots = st_ad.obsm['spatial']
    ax.scatter(
        spots[:, 0], spots[:, 1],
        c=st_ad.obs['bulk_simi'],
        s=s,
        vmin=0.1, vmax=0.95,
        cmap='turbo',
    )

    # Trace each polygon outline in black.
    for poly in roi_polygon:
        xs, ys = zip(*poly)  # split (x, y) pairs into coordinate sequences
        ax.plot(xs, ys, color='black', linewidth=linewidth)

    # Zoom to the requested window.
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)

    # Image convention: origin at the top-left, so flip the y axis.
    ax.invert_yaxis()

    # Hide ticks and labels for a clean view.
    ax.axis('off')
|
| 434 |
+
|
| 435 |
+
|
src/build/lib/loki/plotting.py
ADDED
|
@@ -0,0 +1,435 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import matplotlib.pyplot as plt
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
import json
|
| 4 |
+
import cv2
|
| 5 |
+
from matplotlib import cm
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import numpy as np
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def plot_alignment(ad_tar_coor, ad_src_coor, homo_coor, pca_hex_comb, tar_features, shift=300, s=0.8, boundary_line=True):
    """
    Plots target, source, and aligned source coordinates in three side-by-side panels.

    :param ad_tar_coor: Numpy array of target coordinates (panel 1).
    :param ad_src_coor: Numpy array of source coordinates (panel 2).
    :param homo_coor: Numpy array of aligned source coordinates (panel 3).
    :param pca_hex_comb: Per-point color values; the first len(tar_features.T) entries
                         belong to the target points, the remainder to the source points.
    :param tar_features: Target feature matrix, used only to split `pca_hex_comb`.
    :param shift: Padding added around the target extent when setting axis limits. Default 300.
    :param s: Scatter marker size. Default 0.8.
    :param boundary_line: If True, draws guide lines at the target's minimum x and y. Default True.
    :return: Displays the three-panel alignment figure.
    """
    # Number of target points; determines where the color array splits.
    n_tar = len(tar_features.T)

    # Shared axis limits (from the target extent) so all panels are directly comparable.
    lo = ad_tar_coor.min() - shift
    hi = ad_tar_coor.max() + shift

    panels = [
        (ad_tar_coor, pca_hex_comb[:n_tar]),   # target
        (ad_src_coor, pca_hex_comb[n_tar:]),   # source
        (homo_coor, pca_hex_comb[n_tar:]),     # aligned source
    ]

    plt.figure(figsize=(10, 3), dpi=300)
    for idx, (coords, point_colors) in enumerate(panels, start=1):
        plt.subplot(1, 3, idx)
        plt.scatter(coords[:, 0], coords[:, 1], marker='o', s=s, c=point_colors)
        plt.xlim([lo, hi])
        plt.ylim([lo, hi])

    # Optional boundary guides at the target's minimum x/y; these land on the
    # third (current) subplot, matching the original call order.
    if boundary_line:
        plt.axvline(x=ad_tar_coor[:, 0].min(), color='black')
        plt.axhline(y=ad_tar_coor[:, 1].min(), color='black')

    # Hide axis decorations for a cleaner figure.
    plt.axis('off')

    plt.show()
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def plot_alignment_with_img(ad_tar_coor, ad_src_coor, homo_coor, tar_img, src_img, aligned_image, pca_hex_comb, tar_features):
    """
    Plots target, source, and aligned coordinates with their images as backgrounds.

    :param ad_tar_coor: Numpy array of target coordinates (panels 1 and 3).
    :param ad_src_coor: Numpy array of source coordinates (panel 2).
    :param homo_coor: Numpy array of aligned source coordinates (panel 3).
    :param tar_img: Background image for the target panel.
    :param src_img: Background image for the source panel.
    :param aligned_image: Background image for the aligned panel.
    :param pca_hex_comb: Per-point color values; the first len(tar_features.T) entries
                         belong to the target points, the remainder to the source points.
    :param tar_features: Target feature matrix, used only to split `pca_hex_comb`.
    :return: Displays the three-panel alignment figure with image backgrounds.
    """
    # Split the color array between target and source points.
    n_tar = len(tar_features.T)
    tar_colors = pca_hex_comb[:n_tar]
    src_colors = pca_hex_comb[n_tar:]

    plt.figure(figsize=(10, 8), dpi=150)

    # Panels 1 and 2: raw target / source points over their own images.
    first_two = [(ad_tar_coor, tar_colors, tar_img), (ad_src_coor, src_colors, src_img)]
    for idx, (coords, point_colors, img) in enumerate(first_two, start=1):
        plt.subplot(1, 3, idx)
        plt.scatter(coords[:, 0], coords[:, 1], marker='o', alpha=0.8, s=1, c=point_colors)
        plt.imshow(img, origin='lower', alpha=0.3)  # translucent image on top

    # Panel 3: faded target points plus aligned source points over the aligned image.
    plt.subplot(1, 3, 3)
    plt.scatter(ad_tar_coor[:, 0], ad_tar_coor[:, 1], marker='o', alpha=0.2, s=1, c=tar_colors)
    plt.scatter(homo_coor[:, 0], homo_coor[:, 1], marker='+', s=1, c=src_colors)
    plt.imshow(aligned_image, origin='lower', alpha=0.3)

    # Hide axis decorations for a cleaner figure.
    plt.axis('off')

    plt.show()
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def draw_polygon(image, polygon, color='k', thickness=2):
    """
    Draws closed polygon outlines on an image.

    :param image: Image to draw on (numpy array).
    :param polygon: List of polygons, each a sequence of (x, y) coordinate tuples.
    :param color: Single color string applied to every polygon, or a list with one
                  color per polygon. Default is 'k' (black).
    :param thickness: Single outline thickness applied to every polygon, or a list
                      with one thickness per polygon. Default is 2.
    :return: The image with the polygon outlines drawn on it.
    """
    # Normalize `color` to one entry per polygon.
    colors = color if isinstance(color, list) else [color] * len(polygon)

    for idx, pts in enumerate(polygon):
        # Resolve this polygon's color string to an RGB tuple.
        rgb = color_string_to_rgb(colors[idx])

        # Resolve this polygon's outline thickness.
        width = thickness[idx] if isinstance(thickness, list) else thickness

        # OpenCV expects integer vertices shaped (n_points, 1, 2).
        outline = np.array(pts, np.int32).reshape((-1, 1, 2))

        # isClosed=True joins the last vertex back to the first.
        image = cv2.polylines(image, [outline], isClosed=True, color=rgb, thickness=width)

    return image
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def blend_images(image1, image2, alpha=0.5):
    """
    Alpha-blends two same-sized images.

    :param image1: Background image, numpy array of shape (H, W, 3).
    :param image2: Foreground image, numpy array of shape (H, W, 3), same size as image1.
    :param alpha: Weight of image1 in the result, between 0 and 1. 1 shows only image1,
                  0 shows only image2. Default is 0.5 (equal mix).
    :return: Blended image computed as `alpha * image1 + (1 - alpha) * image2`.
    """
    # image2's weight is the complement of image1's.
    beta = 1 - alpha

    # cv2.addWeighted performs the per-pixel weighted sum (with saturation) in C.
    return cv2.addWeighted(image1, alpha, image2, beta, 0)
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def color_string_to_rgb(color_string):
    """
    Converts a color string to an (r, g, b) tuple.

    :param color_string: Color given either as a hex string (e.g. '#ff0000') or as a
                         single-letter shorthand ('k', 'r', 'g', 'b', 'w').
    :return:
        A tuple (r, g, b) of integers in [0, 255].
    :raises:
        ValueError: If the color string is not recognized.
    """
    # Shorthand letter -> 6-digit hex code.
    shorthand = {
        'k': '000000',  # black
        'r': 'ff0000',  # red
        'g': '00ff00',  # green
        'b': '0000ff',  # blue
        'w': 'ffffff',  # white
    }

    # Spaces carry no meaning in a color string.
    hex_code = color_string.replace(' ', '')

    if hex_code.startswith('#'):
        # Hexadecimal form: drop the leading '#'.
        hex_code = hex_code[1:]
    elif hex_code in shorthand:
        # Shorthand form: look up the equivalent hex code.
        hex_code = shorthand[hex_code]
    else:
        raise ValueError(f"Unknown color string {hex_code}")

    # Parse the three two-character channels as base-16 integers.
    return (
        int(hex_code[:2], 16),
        int(hex_code[2:4], 16),
        int(hex_code[4:], 16),
    )
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def plot_heatmap(
    coor,
    similairty,
    image_path=None,
    patch_size=(256, 256),
    save_path=None,
    downsize=32,
    cmap='turbo',
    smooth=False,
    boxes=None,
    box_color='k',
    box_thickness=2,
    polygons=None,
    polygons_color='k',
    polygons_thickness=2,
    image_alpha=0.5
):
    """
    Plots a heatmap of similarity scores, optionally blended with a background image.

    :param coor: Array of (x, y) coordinates (N, 2) in full-resolution pixels, one per patch.
    :param similairty: Array of similarity scores (N,) mapped to colors via `cmap`.
    :param image_path: Path to the (already downsized) background image. If None, a blank
                       white canvas sized to fit every patch is used instead.
    :param patch_size: Size of each patch in full-resolution pixels. Default is (256, 256).
    :param save_path: If given, the final heatmap is also written to this path.
    :param downsize: Factor by which `coor` and `patch_size` are divided so they match the
                     downsized background image. Default is 32.
    :param cmap: Matplotlib colormap name used to color the scores. Default is 'turbo'.
    :param smooth: Reserved; smoothing is not implemented in this version.
    :param boxes: Reserved; box drawing is not implemented in this version.
    :param box_color: Reserved, see `boxes`.
    :param box_thickness: Reserved, see `boxes`.
    :param polygons: List of polygons (each an array of (x, y) points in full-resolution
                     pixels) to draw on both outputs.
    :param polygons_color: Outline color for the polygons. Default is black ('k').
    :param polygons_thickness: Outline thickness for the polygons.
    :param image_alpha: Blending weight (0 to 1) of the heatmap against the background image.
                        0 disables blending. Default is 0.5.

    :return:
        - heatmap: The generated heatmap as a uint8 numpy array (polygons drawn if provided).
        - image: The background image (polygons drawn if provided), or None when no
          `image_path` was given.
    :raises FileNotFoundError: If `image_path` is given but cannot be read.
    """

    # Load the (downsized) background image when one is provided.
    image = None
    if image_path is not None:
        image = cv2.imread(image_path)  # cv2.imread returns None instead of raising
        if image is None:
            raise FileNotFoundError(f"Could not read image at {image_path}")

    # Bring coordinates and patch size down to the background image's scale.
    coor = [(x // downsize, y // downsize) for x, y in coor]
    patch_size = (patch_size[0] // downsize, patch_size[1] // downsize)

    # Canvas size: match the background image, or fit all patches on a blank canvas.
    # (Previously `image_path=None` crashed on `image.shape` despite being documented.)
    if image is not None:
        image_size = (image.shape[0], image.shape[1])
    else:
        image_size = (
            max(y for _, y in coor) + patch_size[0],
            max(x for x, _ in coor) + patch_size[1],
        )

    # Map similarity scores to RGB colors through the colormap.
    cmap = plt.get_cmap(cmap)
    norm = plt.Normalize(vmin=similairty.min(), vmax=similairty.max())
    colors = cmap(norm(similairty))

    # Start from a white canvas and paint one uniformly colored patch per coordinate.
    heatmap = np.ones((image_size[0], image_size[1], 3)) * 255
    for i in range(len(coor)):
        x, y = coor[i]
        rgb = (colors[i][:3] * 255).astype(np.uint8)  # [0, 1] floats -> [0, 255] uint8
        heatmap[y:y + patch_size[0], x:x + patch_size[1], :] = rgb

    # Keep the output dtype consistent whether or not blending happens below
    # (previously a float64 array was returned when image_alpha == 0).
    heatmap = heatmap.astype(np.uint8)

    # Blend the heatmap with the background image when requested and available.
    if image_alpha > 0 and image is not None:
        image = np.array(image)

        # Pad the image with white so it matches the heatmap size exactly.
        if image.shape[0] < heatmap.shape[0]:
            pad = heatmap.shape[0] - image.shape[0]
            image = np.pad(image, ((0, pad), (0, 0), (0, 0)), mode='constant', constant_values=255)
        if image.shape[1] < heatmap.shape[1]:
            # BUG FIX: the pad width was previously computed as
            # heatmap.shape[1] - heatmap.shape[1] (always 0), so narrower images
            # were never padded and blending failed on a size mismatch.
            pad = heatmap.shape[1] - image.shape[1]
            image = np.pad(image, ((0, 0), (0, pad), (0, 0)), mode='constant', constant_values=255)

        # Convert to BGR for OpenCV compatibility and blend with the heatmap.
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        image = image.astype(np.uint8)
        heatmap = blend_images(heatmap, image, alpha=image_alpha)

    # Overlay polygons (downsized to the same scale) on both outputs when provided.
    if polygons is not None:
        polygons = [poly // downsize for poly in polygons]
        heatmap = draw_polygon(heatmap, polygons, color=polygons_color, thickness=polygons_thickness)
        if image is not None:
            image = draw_polygon(image, polygons, color=polygons_color, thickness=polygons_thickness)

    # Persist the final heatmap when requested (previously `save_path` was ignored).
    if save_path is not None:
        cv2.imwrite(save_path, heatmap)

    return heatmap, image
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
def show_images_side_by_side(image1, image2, title1=None, title2=None):
    """
    Display two images next to each other in one figure.

    :param image1: The first image to display (as a numpy array).
    :param image2: The second image to display (as a numpy array).
    :param title1: Title for the first image; None leaves it untitled.
    :param title2: Title for the second image; None leaves it untitled.
    :return: None. The figure is rendered via ``plt.show()``.
    """
    # One row, two columns; wide figure so both images stay readable.
    fig, axes = plt.subplots(1, 2, figsize=(15, 8))

    # Render each image with its (possibly None) caption and hide the axes.
    for axis, img, caption in zip(axes, (image1, image2), (title1, title2)):
        axis.imshow(img)
        axis.set_title(caption)
        axis.axis('off')

    # Display the composed figure.
    plt.show()
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
def plot_img_with_annotation(fullres_img, roi_polygon, linewidth, xlim, ylim):
    """
    Display an image with ROI polygons drawn on top of it.

    :param fullres_img: The full-resolution image to display (as a numpy array).
    :param roi_polygon: A list of polygons, each a list of (x, y) coordinate tuples.
    :param linewidth: Line thickness used to draw the polygon outlines.
    :param xlim: (xmin, xmax) x-axis limits for zooming in on a region.
    :param ylim: (ymin, ymax) y-axis limits for zooming in on a region.
    :return: None. The annotated image is drawn on the current matplotlib figure.
    """
    plt.figure(figsize=(10, 10))
    plt.imshow(fullres_img)

    # Outline every polygon in black on top of the image.
    for poly in roi_polygon:
        xs, ys = zip(*poly)  # split (x, y) pairs into coordinate sequences
        plt.plot(xs, ys, color='black', linewidth=linewidth)

    # Restrict the view to the requested window.
    plt.xlim(xlim)
    plt.ylim(ylim)

    # Image convention: origin at top-left, so flip the y-axis.
    plt.gca().invert_yaxis()

    # Hide ticks/labels for a clean image view.
    plt.axis('off')
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
def plot_annotation_heatmap(st_ad, roi_polygon, s, linewidth, xlim, ylim):
    """
    Plot a tissue-annotation heatmap of spot similarity scores with ROI polygons overlaid.

    :param st_ad: AnnData object with spot coordinates in ``obsm['spatial']``
                  and similarity scores in ``obs['bulk_simi']``.
    :param roi_polygon: A list of polygons, each a list of (x, y) coordinate tuples.
    :param s: Marker size for each spatial transcriptomics spot.
    :param linewidth: Line thickness used to draw the polygon outlines.
    :param xlim: (xmin, xmax) x-axis limits for zooming in on a region.
    :param ylim: (ymin, ymax) y-axis limits for zooming in on a region.
    :return: None. The heatmap is drawn on the current matplotlib figure.
    """
    plt.figure(figsize=(10, 10))

    # Color each spot by its similarity score; the color range is clamped
    # to [0.1, 0.95] so extreme scores do not dominate the colormap.
    coords = st_ad.obsm['spatial']
    plt.scatter(
        coords[:, 0], coords[:, 1],
        c=st_ad.obs['bulk_simi'],
        s=s,
        vmin=0.1, vmax=0.95,
        cmap='turbo'
    )

    # Outline every ROI polygon in black over the scatter heatmap.
    for poly in roi_polygon:
        xs, ys = zip(*poly)
        plt.plot(xs, ys, color='black', linewidth=linewidth)

    # Restrict the view to the requested window.
    plt.xlim(xlim)
    plt.ylim(ylim)

    # Image convention: origin at top-left, so flip the y-axis.
    plt.gca().invert_yaxis()

    # Hide ticks/labels for a clean display.
    plt.axis('off')
|
| 434 |
+
|
| 435 |
+
|
src/build/lib/loki/predex.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def predict_st_gene_expr(image_text_similarity, train_data):
    """
    Predict ST gene expression from H&E image/text similarities.

    Each prediction is a similarity-weighted average of the training rows:
    ``(S @ T) / rowsum(S)``.

    :param image_text_similarity: Array of image-to-text similarities, shape [n_samples, n_genes].
    :param train_data: Training expression matrix, shape [n_genes, n_shared_genes].
    :return: Predicted expression matrix, shape [n_samples, n_shared_genes].
    """
    # Similarity-weighted combination of training expression profiles.
    weighted = image_text_similarity @ train_data

    # Per-sample normalizer: the total similarity mass of each row.
    totals = image_text_similarity.sum(axis=1, keepdims=True)

    # Divide so each prediction is a weighted average rather than a raw sum.
    return weighted / totals
|
| 24 |
+
|
| 25 |
+
|
src/build/lib/loki/preprocess.py
ADDED
|
@@ -0,0 +1,324 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import scanpy as sc
|
| 2 |
+
import numpy as np
|
| 3 |
+
import pandas as pd
|
| 4 |
+
import json
|
| 5 |
+
import os
|
| 6 |
+
from PIL import Image
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def generate_gene_df(ad, house_keeping_genes, todense=True):
    """
    Build a DataFrame with the top-50 expressed genes per observation of an AnnData object.

    Genes whose names contain '.' or '-' are removed, as are genes listed under
    the 'genesymbol' column of ``house_keeping_genes``.

    :param ad: AnnData object with expression values in ``ad.X``.
    :type ad: anndata.AnnData
    :param house_keeping_genes: DataFrame/Series with a 'genesymbol' column of genes to exclude.
    :type house_keeping_genes: pandas.DataFrame or pandas.Series
    :param todense: If True, densify ``ad.X`` (assumed sparse) before building the DataFrame.
    :type todense: bool
    :return: DataFrame with one 'label' column; each row is the space-separated
             names of that observation's 50 most expressed genes.
    :rtype: pandas.DataFrame
    """
    # Drop genes with '.' or '-' in the name, then the housekeeping genes.
    for token in ('.', '-'):
        ad = ad[:, ~ad.var.index.str.contains(token, regex=False)]
    ad = ad[:, ~ad.var.index.isin(house_keeping_genes['genesymbol'])]

    # Expression matrix as a DataFrame (densified if the caller says X is sparse).
    matrix = ad.X.todense() if todense else ad.X
    expr = pd.DataFrame(matrix, index=ad.obs.index, columns=ad.var.index)

    # Per observation, keep the names of the 50 highest-expression genes.
    top_k_genes = expr.apply(lambda row, n: pd.Series(row.nlargest(n).index), axis=1, n=50)

    # Join each row's gene names into a single space-separated label string.
    labels = top_k_genes[top_k_genes.columns].astype(str).apply(' '.join, axis=1)

    top_k_genes_str = pd.DataFrame()
    top_k_genes_str['label'] = labels
    return top_k_genes_str
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def segment_patches(img_array, coord, patch_dir, height=20, width=20):
    """
    Crop fixed-size patches centered at given coordinates and save each as a PNG.

    :param img_array: Full-resolution image as a NumPy array of shape (H, W[, C]).
    :type img_array: numpy.ndarray
    :param coord: DataFrame with patch centers in columns "pixel_x" and "pixel_y";
                  the index holds spot IDs.
    :type coord: pandas.DataFrame
    :param patch_dir: Output directory for the patch images (created if missing).
    :type patch_dir: str
    :param height: Patch height in pixels.
    :type height: int
    :param width: Patch width in pixels.
    :type width: int
    :return: None. Patches are written to ``patch_dir``; out-of-bounds patches are skipped.
    """
    # Make sure the destination directory exists.
    if not os.path.exists(patch_dir):
        os.makedirs(patch_dir)

    max_y, max_x = img_array.shape[:2]

    for spot_idx in coord.index:
        # NOTE(review): "pixel_x" is unpacked into the y-center and "pixel_y"
        # into the x-center. This mirrors the swapped coordinate convention used
        # when these columns are populated elsewhere in this package — confirm
        # against the caller before "fixing" the order.
        ycenter, xcenter = coord.loc[spot_idx, ["pixel_x", "pixel_y"]]

        # Patch bounding box: top-left corner plus width/height.
        left = round(xcenter - width / 2)
        top = round(ycenter - height / 2)
        right = left + width
        bottom = top + height

        # Skip patches that would fall outside the image.
        if left < 0 or top < 0 or right > max_x or bottom > max_y:
            print(f"Patch {spot_idx} is out of range and will be skipped.")
            continue

        # Crop, cast to uint8 for PIL, and write "<spot>_hires.png".
        crop = Image.fromarray(img_array[top:bottom, left:right].astype(np.uint8))
        crop.save(os.path.join(patch_dir, f"{spot_idx}_hires.png"))
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def read_gct(file_path):
    """
    Read a GCT file and return its contents as a pandas DataFrame.

    GCT layout: line 1 is the version marker ("#1.2"), line 2 declares the
    matrix dimensions, line 3 is the column header ("Name", "Description",
    then sample columns), followed by the data rows.

    :param file_path: Path to the GCT file to read.
    :return: DataFrame whose first two columns are gene names and descriptions,
             with expression values in the remaining columns.
    """
    with open(file_path, 'r') as file:
        # Skip the version line ("#1.2").
        file.readline()

        # The dimensions line declares rows and columns; only the declared row
        # count is needed, since the header row already fixes the width.
        dims = file.readline().strip().split()
        num_rows = int(dims[0])

        # Parse the remainder as tab-delimited data, limited to the declared rows.
        data = pd.read_csv(file, sep='\t', header=0, nrows=num_rows)

    return data
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def get_library_id(adata):
    """
    Return the first library ID found in ``adata.uns['spatial']``.

    :param adata: AnnData object with spatial information in ``adata.uns['spatial']``.
    :return: The first library ID, or None (after logging an error) if the
             'spatial' dict is empty.
    :raises AssertionError: If 'spatial' is not present in ``adata.uns``.
    """
    # 'spatial' must exist for any library ID lookup.
    assert 'spatial' in adata.uns, "spatial not present in adata.uns"

    # Library IDs are the keys of the 'spatial' dictionary.
    library_ids = adata.uns['spatial'].keys()

    try:
        # Return the first library ID.
        library_id = list(library_ids)[0]
        return library_id
    except IndexError:
        # Fix: the original referenced an undefined name `logger`, which made
        # this error path raise NameError; use the stdlib logging module.
        import logging
        logging.getLogger(__name__).error('No library_id found in adata')
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def get_scalefactors(adata, library_id=None):
    """
    Return the scalefactors dict for a library ID from ``adata.uns['spatial']``.

    :param adata: AnnData object with spatial data and scalefactors in ``adata.uns['spatial']``.
    :param library_id: Library ID to look up; defaults to the first available one.
    :return: The scalefactors dict, or None (after logging an error) if the
             library ID or its scalefactors are missing.
    """
    # Default to the first library ID in the object.
    if library_id is None:
        library_id = get_library_id(adata)

    try:
        scalef = adata.uns['spatial'][library_id]['scalefactors']
        return scalef
    except KeyError:
        # Fix: the original referenced an undefined name `logger` (NameError on
        # this path); use the stdlib logging module. Returning None here is
        # relied upon by get_spot_diameter_in_pixels, so keep that behavior.
        import logging
        logging.getLogger(__name__).error('scalefactors not found in adata')
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def get_spot_diameter_in_pixels(adata, library_id=None):
    """
    Return the full-resolution spot diameter from the scalefactors of a library.

    :param adata: AnnData object with spatial data and scalefactors in ``adata.uns['spatial']``.
    :param library_id: Library ID to look up; defaults to the first available one.
    :return: The spot diameter in full-resolution pixels, or None if unavailable.
    """
    # Scalefactors for the requested (or default) library; may be None.
    scalef = get_scalefactors(adata, library_id=library_id)

    try:
        spot_diameter = scalef['spot_diameter_fullres']
        return spot_diameter
    except TypeError:
        # scalef was None because get_scalefactors failed; return None quietly.
        pass
    except KeyError:
        # Fix: the original referenced an undefined name `logger` (NameError on
        # this path); use the stdlib logging module.
        import logging
        logging.getLogger(__name__).error('spot_diameter_fullres not found in adata')
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
def prepare_data_for_alignment(data_path, scale_type='tissue_hires_scalef'):
    """
    Prepares data for alignment by reading an AnnData object and preparing the high-resolution tissue image.

    :param data_path: The path to the AnnData (.h5ad) file containing the Visium data.
    :param scale_type: The type of scale factor to use (`tissue_hires_scalef` by default).

    :return:
        - ad: AnnData object containing the spatial transcriptomics data.
        - ad_coor: Numpy array of scaled spatial coordinates (adjusted for the specified resolution).
        - img: High-resolution tissue image, normalized to 8-bit unsigned integers.

    :raises:
        ValueError: If required data (e.g., scale factors, spatial coordinates, or images) is missing.
    """

    # Load the AnnData object from the specified file path
    ad = sc.read_h5ad(data_path)

    # Ensure the variable (gene) names are unique to avoid potential conflicts
    ad.var_names_make_unique()

    try:
        # Retrieve the specified scale factor for spatial coordinates.
        # NOTE(review): if get_scalefactors hits its error path it returns None,
        # and the subscript below raises TypeError (not KeyError) — that case
        # is not converted to the ValueError documented above; confirm intent.
        scalef = get_scalefactors(ad)[scale_type]
    except KeyError:
        raise ValueError(f"Scale factor '{scale_type}' not found in ad.uns['spatial']")

    # Scale the spatial coordinates using the specified scale factor
    try:
        ad_coor = np.array(ad.obsm['spatial']) * scalef
    except KeyError:
        raise ValueError("Spatial coordinates not found in ad.obsm['spatial']")

    # Retrieve the high-resolution tissue image
    try:
        img = ad.uns['spatial'][get_library_id(ad)]['images']['hires']
    except KeyError:
        raise ValueError("High-resolution image not found in ad.uns['spatial']")

    # If the image values look normalized to [0, 1] (max < 1.1), convert to
    # 8-bit format for compatibility with image libraries
    if img.max() < 1.1:
        img = (img * 255).astype('uint8')

    return ad, ad_coor, img
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
def load_data_for_annotation(st_data_path, json_path, in_tissue=True):
    """
    Loads spatial transcriptomics (ST) data from an .h5ad file and prepares it for annotation.

    :param st_data_path: Path to the .h5ad file containing the spatial transcriptomics data.
    :param json_path: Path to the JSON file containing the region-of-interest (ROI) polygon.
    :param in_tissue: Boolean flag to filter the data to include only spots that are in tissue. Default is True.

    :return:
        - st_ad: AnnData object containing the spatial transcriptomics data, with spatial coordinates in `obs`.
        - library_id: The library ID associated with the spatial data.
        - roi_polygon: Region of interest polygon loaded from the JSON file for further annotation or analysis.
    """

    # Load the spatial transcriptomics data into an AnnData object
    st_ad = sc.read_h5ad(st_data_path)

    # Optionally filter the data to include only spots that are within the tissue
    if in_tissue:
        st_ad = st_ad[st_ad.obs['in_tissue'] == 1]

    # Initialize pixel coordinates for spatial information.
    # NOTE(review): the first spatial column lands in "pixel_y" and the second
    # in "pixel_x" — presumably intentional given the package's coordinate
    # convention, but verify against consumers of these columns.
    st_ad.obs[["pixel_y", "pixel_x"]] = None  # Ensure the columns exist
    st_ad.obs[["pixel_y", "pixel_x"]] = st_ad.obsm['spatial']  # Copy spatial coordinates into obs

    # Retrieve the library ID associated with the spatial data
    library_id = get_library_id(st_ad)

    # Load the region of interest (ROI) polygon from a JSON file
    with open(json_path) as f:
        roi_polygon = json.load(f)

    return st_ad, library_id, roi_polygon
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
def read_polygons(file_path, slide_id):
    """
    Read polygon data for one slide from a JSON configuration file.

    :param file_path: Path to the JSON file containing polygon configurations.
    :param slide_id: Identifier of the slide whose polygons should be extracted.
    :return:
        - polygons: List of numpy arrays of polygon coordinates.
        - polygon_colors: Color value for each polygon.
        - polygon_thickness: Border thickness for each polygon.
        All three are None when ``slide_id`` is absent from the file.
    """
    # Load the full polygon configuration dictionary.
    with open(file_path, 'r') as f:
        configs = json.load(f)

    # Unknown slide: signal with a triple of Nones.
    if slide_id not in configs:
        return None, None, None

    entries = configs[slide_id]

    # Coordinates become numpy arrays; color and thickness pass through as-is.
    polygons = [np.array(entry['coords']) for entry in entries]
    polygon_colors = [entry['color'] for entry in entries]
    polygon_thickness = [entry['thickness'] for entry in entries]

    return polygons, polygon_colors, polygon_thickness
|
| 323 |
+
|
| 324 |
+
|
src/build/lib/loki/retrieve.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def retrieve_st_by_image(image_embeddings, all_text_embeddings, dataframe, k=3):
    """
    Retrieve the top-k ST samples most similar to an image embedding.

    :param image_embeddings: Image embedding tensor of shape [1, embedding_dim].
    :param all_text_embeddings: ST embedding tensor of shape [n_samples, embedding_dim].
    :param dataframe: DataFrame with sample identifiers in its 'img_idx' column.
    :param k: Number of top matches to return (default 3).
    :return: List of 'img_idx' values for the k highest-scoring samples.
    """
    # Dot-product similarity of the image against every ST embedding.
    scores = image_embeddings @ all_text_embeddings.T

    # Indices of the k largest similarity scores.
    _, top_indices = torch.topk(scores.squeeze(0), k)

    # Map those indices back to sample identifiers.
    sample_ids = dataframe['img_idx'].values
    return [sample_ids[i] for i in top_indices]
|
| 27 |
+
|
| 28 |
+
|
src/build/lib/loki/utilities.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import numpy as np
|
| 3 |
+
import pandas as pd
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from PIL import Image
|
| 7 |
+
import json
|
| 8 |
+
import cv2
|
| 9 |
+
from sklearn.decomposition import PCA
|
| 10 |
+
from open_clip import create_model_from_pretrained, get_tokenizer
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def load_model(model_path, device):
    """
    Load a pretrained CoCa ViT-L/14 model, its image preprocessing transform, and tokenizer.

    :param model_path: Checkpoint path/identifier passed to open_clip as ``pretrained``.
    :param device: Device on which open_clip instantiates the model.
    :return: Tuple ``(model, preprocess, tokenizer)``.
    """
    model, preprocess = create_model_from_pretrained("coca_ViT-L-14", device=device, pretrained=model_path)
    tokenizer = get_tokenizer('coca_ViT-L-14')

    return model, preprocess, tokenizer
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def encode_image(model, preprocess, image):
    """
    Encode a single image into an L2-normalized embedding.

    :param model: Model exposing ``encode_image``.
    :param preprocess: Transform turning the image into a tensor.
    :param image: Input image accepted by ``preprocess``.
    :return: L2-normalized embedding tensor of shape [1, embedding_dim].
    """
    # Preprocess and add a batch dimension of 1.
    batch = torch.stack([preprocess(image)])

    # Inference only — no gradient tracking needed.
    with torch.no_grad():
        features = model.encode_image(batch)

    # Unit-normalize along the embedding dimension.
    return F.normalize(features, p=2, dim=-1)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def encode_image_patches(model, preprocess, data_dir, img_list):
    """
    Encode image patches stored under ``<data_dir>/demo_data/patch`` into embeddings.

    :param model: Model exposing ``encode_image``.
    :param preprocess: Transform turning each image into a tensor.
    :param data_dir: Root data directory containing ``demo_data/patch``.
    :param img_list: Filenames of the patch images to encode.
    :return: L2-normalized embedding tensor (one entry per patch).
    """
    patch_root = os.path.join(data_dir, 'demo_data', 'patch')

    # Encode each patch individually (each result is already normalized).
    embeddings = []
    for img_name in img_list:
        patch = Image.open(os.path.join(patch_root, img_name))
        embeddings.append(encode_image(model, preprocess, patch))

    # Keep the original numpy round-trip so the output dtype/shape matches the
    # historical behavior; normalizing again is a no-op on unit vectors.
    stacked = torch.from_numpy(np.array(embeddings))
    return F.normalize(stacked, p=2, dim=-1)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def encode_text(model, tokenizer, text):
    """
    Encode a text string into an L2-normalized embedding.

    :param model: Model exposing ``encode_text``.
    :param tokenizer: Callable converting text into model input tokens.
    :param text: Input text to encode.
    :return: L2-normalized embedding tensor.
    """
    # Tokenize the raw text for the model.
    tokens = tokenizer(text)

    # Inference only — no gradient tracking needed.
    with torch.no_grad():
        features = model.encode_text(tokens)

    # Unit-normalize along the embedding dimension.
    return F.normalize(features, p=2, dim=-1)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def encode_text_df(model, tokenizer, df, col_name):
    """
    Encode one text column of a DataFrame, row by row, into L2-normalized embeddings.

    :param model: Model exposing ``encode_text``.
    :param tokenizer: Callable converting text into model input tokens.
    :param df: DataFrame whose rows each carry one text entry.
    :param col_name: Name of the column holding the text to encode.
    :return: L2-normalized embedding tensor (one entry per row of ``df``).
    """
    text_embeddings = []
    for idx in df.index:
        # Fix: the original `df[df.index==idx][col_name][0]` relied on
        # deprecated positional Series indexing and breaks on integer indexes;
        # df.at is a direct, unambiguous label lookup.
        text = df.at[idx, col_name]
        text_features = encode_text(model, tokenizer, text)
        text_embeddings.append(text_features)

    # Keep the original numpy round-trip so output dtype/shape is unchanged;
    # normalizing already-unit vectors is a no-op.
    text_embeddings = torch.from_numpy(np.array(text_embeddings))
    text_embeddings = F.normalize(text_embeddings, p=2, dim=-1)
    return text_embeddings
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def get_pca_by_fit(tar_features, src_features):
    """
    Fit PCA on the target features and project both target and source through it.

    The two projected sets are concatenated and returned together with batch
    labels identifying each sample's origin.

    :param tar_features: Numpy array of target features (transposed before fitting).
    :param src_features: Numpy array of source features (transposed before projecting).
    :return:
        - pca_comb_features: PCA-projected target and source samples, stacked.
        - pca_comb_features_batch: Batch labels — 0 for target samples, 1 for source.
    """
    # Three principal components, fit on the transposed target matrix.
    pca = PCA(n_components=3)
    fitted = pca.fit(tar_features.T)

    # Project both datasets with the SAME fit so they share a coordinate frame.
    tar_scores = fitted.transform(tar_features.T)
    src_scores = fitted.transform(src_features.T)

    # Stack the projections and label their origin (0 = target, 1 = source).
    combined = np.concatenate((tar_scores, src_scores))
    batch_labels = np.array([0] * len(tar_scores) + [1] * len(src_scores))

    return combined, batch_labels
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def cap_quantile(weight, cap_max=None, cap_min=None):
    """
    Cap the values in ``weight`` at the given quantile thresholds.

    :param weight: Numpy array of weights to be capped.
    :param cap_max: Quantile threshold (0-1) for the maximum cap. Values above
                    this quantile are replaced by it. If None, no maximum
                    capping is applied.
    :param cap_min: Quantile threshold (0-1) for the minimum cap. Values below
                    this quantile are replaced by it. If None, no minimum
                    capping is applied.
    :return: Numpy array with the values capped at the specified quantiles.
    """
    # Resolve both thresholds from the *original* data before any capping,
    # so the lower threshold is not influenced by the upper cap.
    upper = np.quantile(weight, cap_max) if cap_max is not None else None
    lower = np.quantile(weight, cap_min) if cap_min is not None else None

    # Bug fix: the previous implementation always called
    # np.minimum/np.maximum, passing None straight through when a cap was
    # omitted, which raises a TypeError for the default arguments.
    if upper is not None:
        weight = np.minimum(weight, upper)
    if lower is not None:
        weight = np.maximum(weight, lower)

    return weight
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def read_polygons(file_path, slide_id):
    """
    Load polygon annotations for one slide from a JSON configuration file.

    :param file_path: Path to the JSON file containing polygon configurations.
    :param slide_id: Identifier of the slide whose polygons should be loaded.
    :return:
        - polygons: List of numpy arrays of polygon coordinates.
        - polygon_colors: List of color values, one per polygon.
        - polygon_thickness: List of border thicknesses, one per polygon.
        All three are None when the slide id is absent from the file.
    """
    # Parse the full polygon configuration file.
    with open(file_path, 'r') as handle:
        configs = json.load(handle)

    # Unknown slide: signal "no annotations" with a triple of Nones.
    if slide_id not in configs:
        return None, None, None

    entries = configs[slide_id]

    # Pull out each attribute list; coordinates become numpy arrays.
    coords = [np.array(entry['coords']) for entry in entries]
    colors = [entry['color'] for entry in entries]
    thickness = [entry['thickness'] for entry in entries]

    return coords, colors, thickness
|
| 158 |
+
|
| 159 |
+
|
src/build/lib/loki/utils.py
ADDED
|
@@ -0,0 +1,278 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import numpy as np
|
| 3 |
+
import pandas as pd
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from PIL import Image
|
| 7 |
+
import json
|
| 8 |
+
import cv2
|
| 9 |
+
from sklearn.decomposition import PCA
|
| 10 |
+
from open_clip import create_model_from_pretrained, get_tokenizer
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def load_model(model_path, device):
    """
    Load a pretrained CoCa (CLIP-style) model together with its preprocessing
    transform and tokenizer.

    :param model_path: File path or URL of the pretrained checkpoint, passed
        to `create_model_from_pretrained` as the `pretrained` argument.
    :type model_path: str
    :param device: Device on which the model is placed (e.g. 'cpu' or 'cuda').
    :type device: str or torch.device
    :return: A tuple `(model, preprocess, tokenizer)` where:
        - model: The loaded CoCa model.
        - preprocess: Transform that prepares input data for the model.
        - tokenizer: Tokenizer for textual input to the model.
    :rtype: (nn.Module, callable, callable)
    """
    # Instantiate the network and its input transform from the checkpoint.
    coca_model, transform = create_model_from_pretrained(
        "coca_ViT-L-14", device=device, pretrained=model_path
    )

    # The tokenizer is tied to the architecture name, not to the checkpoint.
    coca_tokenizer = get_tokenizer('coca_ViT-L-14')

    return coca_model, transform, coca_tokenizer
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def encode_image(model, preprocess, image):
    """
    Encode a single image into an L2-normalized feature embedding.

    :param model: Model exposing an `encode_image` method (CLIP/CoCa-like).
    :type model: torch.nn.Module
    :param preprocess: Transform turning the raw image into a model-ready tensor.
    :type preprocess: callable
    :param image: Input image (PIL Image, NumPy array, or anything accepted
        by `preprocess`).
    :type image: PIL.Image.Image or numpy.ndarray
    :return: Normalized embedding of shape (1, embedding_dim).
    :rtype: torch.Tensor
    """
    # Build a one-element batch from the preprocessed image.
    batch = torch.stack([preprocess(image)])

    # Inference only: skip gradient bookkeeping.
    with torch.no_grad():
        raw_features = model.encode_image(batch)

    # Project each embedding onto the unit hypersphere (L2 normalization).
    return F.normalize(raw_features, p=2, dim=-1)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def encode_image_patches(model, preprocess, data_dir, img_list):
    """
    Encode a list of image patches into L2-normalized feature embeddings.

    :param model: Model exposing an `encode_image` method (CLIP/CoCa-like).
    :type model: torch.nn.Module
    :param preprocess: Transform turning a raw image into a model-ready tensor.
    :type preprocess: callable
    :param data_dir: Base directory holding the image data.
    :type data_dir: str
    :param img_list: Patch filenames located under `data_dir/demo_data/patch/`.
    :type img_list: list[str]
    :return: Tensor of shape (N, 1, embedding_dim) with normalized embeddings.
    :rtype: torch.Tensor
    """
    embeddings = []

    for filename in img_list:
        # Load the patch from disk and embed it; encode_image already
        # returns an L2-normalized (1, embedding_dim) tensor.
        patch_path = os.path.join(data_dir, 'demo_data', 'patch', filename)
        patch = Image.open(patch_path)
        embeddings.append(encode_image(model, preprocess, patch))

    # Stack into a (N, 1, embedding_dim) tensor by way of NumPy.
    stacked = torch.from_numpy(np.array(embeddings))

    # Re-normalize across the feature dimension; a no-op for rows that are
    # already unit length, kept for parity with the original behavior.
    return F.normalize(stacked, p=2, dim=-1)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def encode_text(model, tokenizer, text):
    """
    Encode text into an L2-normalized feature embedding.

    :param model: Model exposing an `encode_text` method (CLIP/CoCa-like).
    :type model: torch.nn.Module
    :param tokenizer: Callable converting raw text into model-ready tokens
        (token IDs, attention masks, etc.).
    :type tokenizer: callable
    :param text: The text (string or list of strings) to encode.
    :type text: str or list[str]
    :return: L2-normalized embeddings of shape (batch_size, embedding_dim).
    :rtype: torch.Tensor
    """
    # Tokenize the raw text for the model.
    tokens = tokenizer(text)

    # Inference only: no gradient tracking needed.
    with torch.no_grad():
        raw_features = model.encode_text(tokens)

    # Scale each embedding to unit length.
    return F.normalize(raw_features, p=2, dim=-1)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def encode_text_df(model, tokenizer, df, col_name):
    """
    Encode every row of a DataFrame text column into normalized embeddings.

    :param model: Model exposing an `encode_text` method (CLIP/CoCa-like).
    :type model: torch.nn.Module
    :param tokenizer: Callable converting raw text into model-ready tokens.
    :type tokenizer: callable
    :param df: DataFrame holding the text to encode.
    :type df: pandas.DataFrame
    :param col_name: Name of the column in `df` containing the text.
    :type col_name: str
    :return: L2-normalized text embeddings; shape is (number_of_rows, 1,
        embedding_dim) since each `encode_text` call yields a (1, dim) row.
    :rtype: torch.Tensor
    """
    text_embeddings = []

    for idx in df.index:
        # Bug fix: the previous lookup `df[df.index == idx][col_name][0]`
        # performs a *label-based* lookup of 0 on the filtered Series, which
        # raises a KeyError for any DataFrame whose index lacks the label 0.
        text = df.loc[idx, col_name]

        # Encode this row's text (returns a (1, embedding_dim) tensor).
        text_features = encode_text(model, tokenizer, text)
        text_embeddings.append(text_features)

    # Stack the per-row tensors into one tensor via NumPy.
    text_embeddings = torch.from_numpy(np.array(text_embeddings))

    # Normalize embeddings to unit length across the feature dimension.
    text_embeddings = F.normalize(text_embeddings, p=2, dim=-1)

    return text_embeddings
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def get_pca_by_fit(tar_features, src_features):
    """
    Fit a 3-component PCA on the target features and project both target and
    source features into that shared PCA space.

    The projected target and source samples are stacked into one array,
    accompanied by a batch-label vector marking each row's origin.

    :param tar_features: Numpy array of target features (samples by features).
    :param src_features: Numpy array of source features (samples by features).
    :return:
        - pca_comb_features: PCA-transformed target and source features, concatenated.
        - pca_comb_features_batch: Per-row labels: 0 for target rows, 1 for source rows.
    """
    # Fit PCA on the transposed target matrix and reuse the same fit for
    # both inputs so they live in a common low-dimensional space.
    fitted = PCA(n_components=3).fit(tar_features.T)

    projected_tar = fitted.transform(tar_features.T)
    projected_src = fitted.transform(src_features.T)

    # Stack projections and record which dataset each row came from.
    combined = np.concatenate((projected_tar, projected_src))
    batch_labels = np.array([0] * len(projected_tar) + [1] * len(projected_src))

    return combined, batch_labels
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def cap_quantile(weight, cap_max=None, cap_min=None):
    """
    Cap the values in ``weight`` at the given quantile thresholds.

    :param weight: Numpy array of weights to be capped.
    :param cap_max: Quantile threshold (0-1) for the maximum cap. Values above
                    this quantile are replaced by it. If None, no maximum
                    capping is applied.
    :param cap_min: Quantile threshold (0-1) for the minimum cap. Values below
                    this quantile are replaced by it. If None, no minimum
                    capping is applied.
    :return: Numpy array with the values capped at the specified quantiles.
    """
    # Resolve both thresholds from the *original* data before any capping,
    # so the lower threshold is not influenced by the upper cap.
    upper = np.quantile(weight, cap_max) if cap_max is not None else None
    lower = np.quantile(weight, cap_min) if cap_min is not None else None

    # Bug fix: the previous implementation always called
    # np.minimum/np.maximum, passing None straight through when a cap was
    # omitted, which raises a TypeError for the default arguments.
    if upper is not None:
        weight = np.minimum(weight, upper)
    if lower is not None:
        weight = np.maximum(weight, lower)

    return weight
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def read_polygons(file_path, slide_id):
    """
    Load polygon annotations for one slide from a JSON configuration file.

    :param file_path: Path to the JSON file containing polygon configurations.
    :param slide_id: Identifier of the slide whose polygons should be loaded.
    :return:
        - polygons: List of numpy arrays of polygon coordinates.
        - polygon_colors: List of color values, one per polygon.
        - polygon_thickness: List of border thicknesses, one per polygon.
        All three are None when the slide id is absent from the file.
    """
    # Parse the full polygon configuration file.
    with open(file_path, 'r') as handle:
        configs = json.load(handle)

    # Unknown slide: signal "no annotations" with a triple of Nones.
    if slide_id not in configs:
        return None, None, None

    entries = configs[slide_id]

    # Pull out each attribute list; coordinates become numpy arrays.
    coords = [np.array(entry['coords']) for entry in entries]
    colors = [entry['color'] for entry in entries]
    thickness = [entry['thickness'] for entry in entries]

    return coords, colors, thickness
|
| 277 |
+
|
| 278 |
+
|
src/dist/loki-0.0.1-py3-none-any.whl
ADDED
|
Binary file (22.2 kB). View file
|
|
|
src/dist/loki-0.0.1.tar.gz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:98f4615e981aeb895088cb71b1f358a9d6470302043c7cab8bc15396b9cbbe0d
|
| 3 |
+
size 20339
|
src/loki.egg-info/PKG-INFO
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Metadata-Version: 2.1
|
| 2 |
+
Name: loki
|
| 3 |
+
Version: 0.0.1
|
| 4 |
+
Summary: The Loki platform offers 5 core functions: tissue alignment, cell type decomposition, tissue annotation, image-transcriptomics retrieval, and ST gene expression prediction
|
| 5 |
+
Author: Weiqing Chen
|
| 6 |
+
Author-email: wec4005@med.cornell.edu
|
| 7 |
+
Classifier: Programming Language :: Python :: 3
|
| 8 |
+
Classifier: License :: OSI Approved :: BSD License
|
| 9 |
+
Classifier: Operating System :: OS Independent
|
| 10 |
+
Requires-Python: >=3.9
|
| 11 |
+
Requires-Dist: anndata==0.10.9
|
| 12 |
+
Requires-Dist: matplotlib==3.9.2
|
| 13 |
+
Requires-Dist: numpy==1.25.0
|
| 14 |
+
Requires-Dist: pandas==2.2.3
|
| 15 |
+
Requires-Dist: opencv-python==4.10.0.84
|
| 16 |
+
Requires-Dist: pycpd==2.0.0
|
| 17 |
+
Requires-Dist: torch==2.3.1
|
| 18 |
+
Requires-Dist: tangram-sc==1.0.4
|
| 19 |
+
Requires-Dist: tqdm==4.66.5
|
| 20 |
+
Requires-Dist: torchvision==0.18.1
|
| 21 |
+
Requires-Dist: open_clip_torch==2.26.1
|
| 22 |
+
Requires-Dist: pillow==10.4.0
|
| 23 |
+
Requires-Dist: ipykernel==6.29.5
|
src/loki.egg-info/SOURCES.txt
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
README.md
|
| 2 |
+
setup.py
|
| 3 |
+
loki/__init__.py
|
| 4 |
+
loki/align.py
|
| 5 |
+
loki/annotate.py
|
| 6 |
+
loki/decompose.py
|
| 7 |
+
loki/plot.py
|
| 8 |
+
loki/predex.py
|
| 9 |
+
loki/preprocess.py
|
| 10 |
+
loki/retrieve.py
|
| 11 |
+
loki/utils.py
|
| 12 |
+
loki.egg-info/PKG-INFO
|
| 13 |
+
loki.egg-info/SOURCES.txt
|
| 14 |
+
loki.egg-info/dependency_links.txt
|
| 15 |
+
loki.egg-info/requires.txt
|
| 16 |
+
loki.egg-info/top_level.txt
|
src/loki.egg-info/dependency_links.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
src/loki.egg-info/requires.txt
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
anndata==0.10.9
|
| 2 |
+
matplotlib==3.9.2
|
| 3 |
+
numpy==1.25.0
|
| 4 |
+
pandas==2.2.3
|
| 5 |
+
opencv-python==4.10.0.84
|
| 6 |
+
pycpd==2.0.0
|
| 7 |
+
torch==2.3.1
|
| 8 |
+
tangram-sc==1.0.4
|
| 9 |
+
tqdm==4.66.5
|
| 10 |
+
torchvision==0.18.1
|
| 11 |
+
open_clip_torch==2.26.1
|
| 12 |
+
pillow==10.4.0
|
| 13 |
+
ipykernel==6.29.5
|
src/loki.egg-info/top_level.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
loki
|
src/loki/__init__.py
ADDED
|
File without changes
|
src/loki/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (139 Bytes). View file
|
|
|
src/loki/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (146 Bytes). View file
|
|
|
src/loki/__pycache__/align.cpython-39.pyc
ADDED
|
Binary file (17.3 kB). View file
|
|
|
src/loki/__pycache__/annotate.cpython-39.pyc
ADDED
|
Binary file (2.99 kB). View file
|
|
|
src/loki/__pycache__/decompose.cpython-39.pyc
ADDED
|
Binary file (4.72 kB). View file
|
|
|
src/loki/__pycache__/deconv.cpython-39.pyc
ADDED
|
Binary file (3.52 kB). View file
|
|
|
src/loki/__pycache__/plot.cpython-39.pyc
ADDED
|
Binary file (13.6 kB). View file
|
|
|
src/loki/__pycache__/predex.cpython-39.pyc
ADDED
|
Binary file (904 Bytes). View file
|
|
|
src/loki/__pycache__/preprocess.cpython-39.pyc
ADDED
|
Binary file (10.8 kB). View file
|
|
|
src/loki/__pycache__/retrieve.cpython-39.pyc
ADDED
|
Binary file (1.38 kB). View file
|
|
|
src/loki/__pycache__/utils.cpython-39.pyc
ADDED
|
Binary file (9.44 kB). View file
|
|
|
src/loki/align.py
ADDED
|
@@ -0,0 +1,568 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pycpd
|
| 2 |
+
from builtins import super
|
| 3 |
+
import numbers
|
| 4 |
+
import numpy as np
|
| 5 |
+
import cv2
|
| 6 |
+
|
| 7 |
+
class EMRegistration(object):
|
| 8 |
+
"""
|
| 9 |
+
Expectation maximization point cloud registration.
|
| 10 |
+
Adapted from Pure Numpy Implementation of the Coherent Point Drift Algorithm:
|
| 11 |
+
https://github.com/siavashk/pycpd
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
Attributes
|
| 15 |
+
----------
|
| 16 |
+
X: numpy array
|
| 17 |
+
NxD array of target points.
|
| 18 |
+
|
| 19 |
+
Y: numpy array
|
| 20 |
+
MxD array of source points.
|
| 21 |
+
|
| 22 |
+
TY: numpy array
|
| 23 |
+
MxD array of transformed source points.
|
| 24 |
+
|
| 25 |
+
sigma2: float (positive)
|
| 26 |
+
Initial variance of the Gaussian mixture model.
|
| 27 |
+
|
| 28 |
+
N: int
|
| 29 |
+
Number of target points.
|
| 30 |
+
|
| 31 |
+
M: int
|
| 32 |
+
Number of source points.
|
| 33 |
+
|
| 34 |
+
D: int
|
| 35 |
+
Dimensionality of source and target points
|
| 36 |
+
|
| 37 |
+
iteration: int
|
| 38 |
+
The current iteration throughout registration.
|
| 39 |
+
|
| 40 |
+
max_iterations: int
|
| 41 |
+
Registration will terminate once the algorithm has taken this
|
| 42 |
+
many iterations.
|
| 43 |
+
|
| 44 |
+
tolerance: float (positive)
|
| 45 |
+
Registration will terminate once the difference between
|
| 46 |
+
consecutive objective function values falls within this tolerance.
|
| 47 |
+
|
| 48 |
+
w: float (between 0 and 1)
|
| 49 |
+
Contribution of the uniform distribution to account for outliers.
|
| 50 |
+
Valid values span 0 (inclusive) and 1 (exclusive).
|
| 51 |
+
|
| 52 |
+
q: float
|
| 53 |
+
The objective function value that represents the misalignment between source
|
| 54 |
+
and target point clouds.
|
| 55 |
+
|
| 56 |
+
diff: float (positive)
|
| 57 |
+
The absolute difference between the current and previous objective function values.
|
| 58 |
+
|
| 59 |
+
P: numpy array
|
| 60 |
+
MxN array of probabilities.
|
| 61 |
+
P[m, n] represents the probability that the m-th source point
|
| 62 |
+
corresponds to the n-th target point.
|
| 63 |
+
|
| 64 |
+
Pt1: numpy array
|
| 65 |
+
Nx1 column array.
|
| 66 |
+
Multiplication result between the transpose of P and a column vector of all 1s.
|
| 67 |
+
|
| 68 |
+
P1: numpy array
|
| 69 |
+
Mx1 column array.
|
| 70 |
+
Multiplication result between P and a column vector of all 1s.
|
| 71 |
+
|
| 72 |
+
Np: float (positive)
|
| 73 |
+
The sum of all elements in P.
|
| 74 |
+
|
| 75 |
+
"""
|
| 76 |
+
|
| 77 |
+
def __init__(self, X, Y, sigma2=None, max_iterations=None, tolerance=None, w=None, *args, **kwargs):
|
| 78 |
+
if type(X) is not np.ndarray or X.ndim != 2:
|
| 79 |
+
raise ValueError(
|
| 80 |
+
"The target point cloud (X) must be at a 2D numpy array.")
|
| 81 |
+
|
| 82 |
+
if type(Y) is not np.ndarray or Y.ndim != 2:
|
| 83 |
+
raise ValueError(
|
| 84 |
+
"The source point cloud (Y) must be a 2D numpy array.")
|
| 85 |
+
|
| 86 |
+
if X.shape[1] != Y.shape[1]:
|
| 87 |
+
raise ValueError(
|
| 88 |
+
"Both point clouds need to have the same number of dimensions.")
|
| 89 |
+
|
| 90 |
+
if sigma2 is not None and (not isinstance(sigma2, numbers.Number) or sigma2 <= 0):
|
| 91 |
+
raise ValueError(
|
| 92 |
+
"Expected a positive value for sigma2 instead got: {}".format(sigma2))
|
| 93 |
+
|
| 94 |
+
if max_iterations is not None and (not isinstance(max_iterations, numbers.Number) or max_iterations < 0):
|
| 95 |
+
raise ValueError(
|
| 96 |
+
"Expected a positive integer for max_iterations instead got: {}".format(max_iterations))
|
| 97 |
+
elif isinstance(max_iterations, numbers.Number) and not isinstance(max_iterations, int):
|
| 98 |
+
warn("Received a non-integer value for max_iterations: {}. Casting to integer.".format(max_iterations))
|
| 99 |
+
max_iterations = int(max_iterations)
|
| 100 |
+
|
| 101 |
+
if tolerance is not None and (not isinstance(tolerance, numbers.Number) or tolerance < 0):
|
| 102 |
+
raise ValueError(
|
| 103 |
+
"Expected a positive float for tolerance instead got: {}".format(tolerance))
|
| 104 |
+
|
| 105 |
+
if w is not None and (not isinstance(w, numbers.Number) or w < 0 or w >= 1):
|
| 106 |
+
raise ValueError(
|
| 107 |
+
"Expected a value between 0 (inclusive) and 1 (exclusive) for w instead got: {}".format(w))
|
| 108 |
+
|
| 109 |
+
self.X = X
|
| 110 |
+
self.Y = Y
|
| 111 |
+
self.TY = Y
|
| 112 |
+
self.sigma2 = initialize_sigma2(X, Y) if sigma2 is None else sigma2
|
| 113 |
+
(self.N, self.D) = self.X.shape
|
| 114 |
+
(self.M, _) = self.Y.shape
|
| 115 |
+
self.tolerance = 0.001 if tolerance is None else tolerance
|
| 116 |
+
self.w = 0.0 if w is None else w
|
| 117 |
+
self.max_iterations = 100 if max_iterations is None else max_iterations
|
| 118 |
+
self.iteration = 0
|
| 119 |
+
self.diff = np.inf
|
| 120 |
+
self.q = np.inf
|
| 121 |
+
self.P = np.zeros((self.M, self.N))
|
| 122 |
+
self.Pt1 = np.zeros((self.N, ))
|
| 123 |
+
self.P1 = np.zeros((self.M, ))
|
| 124 |
+
self.PX = np.zeros((self.M, self.D))
|
| 125 |
+
self.Np = 0
|
| 126 |
+
|
| 127 |
+
def register(self, callback=lambda **kwargs: None):
    """
    Perform the EM registration.

    Runs the expectation-maximization loop until either the change in the
    GMM variance between iterations (``self.diff``) drops below
    ``self.tolerance`` or ``self.max_iterations`` is reached.

    Attributes
    ----------
    callback: function
        A function that will be called after each iteration.
        Can be used to visualize the registration process.
        Receives ``iteration``, ``error``, ``X`` and ``Y`` keyword args.

    Returns
    -------
    self.TY: numpy array
        MxD array of transformed source points.

    registration_parameters:
        Returned params dependent on registration method used.
    """
    # Apply the current transform once before entering the EM loop so that
    # the first E-step sees an up-to-date TY.
    self.transform_point_cloud()
    while self.iteration < self.max_iterations and self.diff > self.tolerance:
        self.iterate()
        if callable(callback):
            # Report progress; self.q is the objective proxy tracked by
            # the child class's variance update.
            kwargs = {'iteration': self.iteration,
                      'error': self.q, 'X': self.X, 'Y': self.TY}
            callback(**kwargs)

    return self.TY, self.get_registration_parameters()
|
| 154 |
+
|
| 155 |
+
def get_registration_parameters(self):
    """
    Placeholder for child classes.

    Raises
    ------
    NotImplementedError
        Always; subclasses must override this to return their own
        transformation parameters.
    """
    raise NotImplementedError(
        "Registration parameters should be defined in child classes.")
|
| 161 |
+
|
| 162 |
+
def update_transform(self):
    """
    Placeholder for child classes.

    Raises
    ------
    NotImplementedError
        Always; subclasses implement the M-step transform update here.
    """
    raise NotImplementedError(
        "Updating transform parameters should be defined in child classes.")
|
| 168 |
+
|
| 169 |
+
def transform_point_cloud(self):
    """
    Placeholder for child classes.

    Raises
    ------
    NotImplementedError
        Always; subclasses implement applying the current transform to
        the source points (updating ``self.TY``) here.
    """
    raise NotImplementedError(
        "Updating the source point cloud should be defined in child classes.")
|
| 175 |
+
|
| 176 |
+
def update_variance(self):
    """
    Placeholder for child classes.

    Raises
    ------
    NotImplementedError
        Always; subclasses implement the sigma2 (GMM variance) update here.
    """
    raise NotImplementedError(
        "Updating the Gaussian variance for the mixture model should be defined in child classes.")
|
| 182 |
+
|
| 183 |
+
def iterate(self):
    """
    Perform one iteration of the EM algorithm.

    The order matters: the E-step computes correspondence probabilities
    from the current ``self.TY``/``self.sigma2``; the M-step then updates
    the transform and variance from those probabilities.
    """
    self.expectation()
    self.maximization()
    # Count completed iterations; checked against max_iterations in register().
    self.iteration += 1
|
| 190 |
+
|
| 191 |
+
def expectation(self):
    """
    Compute the expectation step of the EM algorithm.

    Builds the (M, N) matrix of posterior correspondence probabilities
    between transformed source points (``self.TY``) and target points
    (``self.X``) under a Gaussian mixture with uniform-outlier weight
    ``self.w``, then caches the sufficient statistics used by the M-step
    (``P``, ``Pt1``, ``P1``, ``Np``, ``PX``).
    """
    # Squared Euclidean distance between every source/target pair: (M, N).
    sq_dist = np.sum(np.square(self.X[None, :, :] - self.TY[:, None, :]), axis=2)
    gauss = np.exp(-sq_dist / (2 * self.sigma2))

    # Uniform-outlier term of the CPD mixture (zero when w == 0).
    outlier = (2 * np.pi * self.sigma2) ** (self.D / 2) * self.w / (1. - self.w) * self.M / self.N

    # Per-target normalizer; clipped at machine epsilon to avoid divide-by-zero.
    col_sums = np.sum(gauss, axis=0, keepdims=True)  # (1, N)
    denom = np.clip(col_sums, np.finfo(self.X.dtype).eps, None) + outlier

    self.P = gauss / denom
    self.Pt1 = self.P.sum(axis=0)   # column sums, length N
    self.P1 = self.P.sum(axis=1)    # row sums, length M
    self.Np = self.P1.sum()         # effective number of matched points
    self.PX = self.P @ self.X       # (M, D) weighted target coordinates
|
| 207 |
+
|
| 208 |
+
def maximization(self):
    """
    Compute the maximization step of the EM algorithm.

    Delegates to the child class: update the transform parameters from
    the E-step statistics, apply the transform to refresh ``self.TY``,
    then update the mixture variance (which also sets ``self.diff``).
    """
    self.update_transform()
    self.transform_point_cloud()
    self.update_variance()
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
class DeformableRegistration(EMRegistration):
    """
    Deformable registration.
    Adapted from Pure Numpy Implementation of the Coherent Point Drift Algorithm:
    https://github.com/siavashk/pycpd

    Attributes
    ----------
    alpha: float (positive)
        Represents the trade-off between the goodness of maximum likelihood fit and regularization.

    beta: float(positive)
        Width of the Gaussian kernel.

    low_rank: bool
        Whether to use low rank approximation.

    num_eig: int
        Number of eigenvectors to use in lowrank calculation.
    """

    def __init__(self, alpha=None, beta=None, low_rank=False, num_eig=100, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Validate hyperparameters before storing them.
        if alpha is not None and (not isinstance(alpha, numbers.Number) or alpha <= 0):
            raise ValueError(
                "Expected a positive value for regularization parameter alpha. Instead got: {}".format(alpha))

        if beta is not None and (not isinstance(beta, numbers.Number) or beta <= 0):
            raise ValueError(
                "Expected a positive value for the width of the coherent Gaussian kerenl. Instead got: {}".format(beta))

        self.alpha = 2 if alpha is None else alpha
        self.beta = 2 if beta is None else beta
        # W: (M, D) deformation coefficients; G: (M, M) Gaussian kernel over Y.
        self.W = np.zeros((self.M, self.D))
        self.G = gaussian_kernel(self.Y, self.beta)
        self.low_rank = low_rank
        self.num_eig = num_eig
        if self.low_rank is True:
            # Low-rank approximation of G from its top num_eig eigenpairs.
            self.Q, self.S = low_rank_eigen(self.G, self.num_eig)
            self.inv_S = np.diag(1./self.S)
            self.S = np.diag(self.S)
            self.E = 0.

    def update_transform(self):
        """
        Calculate a new estimate of the deformable transformation.
        See Eq. 22 of https://arxiv.org/pdf/0905.2635.pdf.

        """
        if self.low_rank is False:
            # Solve (diag(P1) G + alpha sigma2 I) W = PX - diag(P1) Y directly.
            A = np.dot(np.diag(self.P1), self.G) + \
                self.alpha * self.sigma2 * np.eye(self.M)
            B = self.PX - np.dot(np.diag(self.P1), self.Y)
            self.W = np.linalg.solve(A, B)

        elif self.low_rank is True:
            # Matlab code equivalent can be found here:
            # https://github.com/markeroon/matlab-computer-vision-routines/tree/master/third_party/CoherentPointDrift
            dP = np.diag(self.P1)
            dPQ = np.matmul(dP, self.Q)
            F = self.PX - np.matmul(dP, self.Y)

            # Woodbury-style solve using the eigendecomposition of G, which
            # avoids inverting the full (M, M) system.
            self.W = 1 / (self.alpha * self.sigma2) * (F - np.matmul(dPQ, (
                np.linalg.solve((self.alpha * self.sigma2 * self.inv_S + np.matmul(self.Q.T, dPQ)),
                                (np.matmul(self.Q.T, F))))))
            QtW = np.matmul(self.Q.T, self.W)
            # Accumulate the regularization energy term alpha/2 * tr(W^T G W).
            self.E = self.E + self.alpha / 2 * np.trace(np.matmul(QtW.T, np.matmul(self.S, QtW)))

    def transform_point_cloud(self, Y=None):
        """
        Update a point cloud using the new estimate of the deformable transformation.

        Attributes
        ----------
        Y: numpy array, optional
            Array of points to transform - use to predict on new set of points.
            Best for predicting on new points not used to run initial registration.
            If None, self.Y used.

        Returns
        -------
        If Y is None, returns None.
        Otherwise, returns the transformed Y.


        """
        # NOTE(review): zeroes the deformation in every column beyond the
        # first two — presumably so only the spatial (x, y) coordinates are
        # deformed while appended feature columns stay fixed; confirm intent.
        self.W[:,2:]=0
        if Y is not None:
            G = gaussian_kernel(X=Y, beta=self.beta, Y=self.Y)
            return Y + np.dot(G, self.W)
        else:
            if self.low_rank is False:
                self.TY = self.Y + np.dot(self.G, self.W)

            elif self.low_rank is True:
                # G is approximated by Q S Q^T in low-rank mode.
                self.TY = self.Y + np.matmul(self.Q, np.matmul(self.S, np.matmul(self.Q.T, self.W)))
                return


    def update_variance(self):
        """
        Update the variance of the mixture model using the new estimate of the deformable transformation.
        See the update rule for sigma2 in Eq. 23 of of https://arxiv.org/pdf/0905.2635.pdf.

        """
        qprev = self.sigma2

        # The original CPD paper does not explicitly calculate the objective functional.
        # This functional will include terms from both the negative log-likelihood and
        # the Gaussian kernel used for regularization.
        self.q = np.inf

        xPx = np.dot(np.transpose(self.Pt1), np.sum(
            np.multiply(self.X, self.X), axis=1))
        yPy = np.dot(np.transpose(self.P1), np.sum(
            np.multiply(self.TY, self.TY), axis=1))
        trPXY = np.sum(np.multiply(self.TY, self.PX))

        self.sigma2 = (xPx - 2 * trPXY + yPy) / (self.Np * self.D)

        # Guard against a non-positive variance (numerical underflow).
        if self.sigma2 <= 0:
            self.sigma2 = self.tolerance / 10

        # Here we use the difference between the current and previous
        # estimate of the variance as a proxy to test for convergence.
        self.diff = np.abs(self.sigma2 - qprev)

    def get_registration_parameters(self):
        """
        Return the current estimate of the deformable transformation parameters.


        Returns
        -------
        self.G: numpy array
            Gaussian kernel matrix.

        self.W: numpy array
            Deformable transformation matrix.
        """
        return self.G, self.W
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
|
| 361 |
+
def initialize_sigma2(X, Y):
    """
    Initialize the variance (sigma2) as the mean squared distance between
    all pairs of target and source points.

    param
    ----------
    X: numpy array
        NxD array of points for target.

    Y: numpy array
        MxD array of points for source.

    Returns
    -------
    sigma2: float
        Initial variance.
    """
    num_target, dim = X.shape
    num_source = Y.shape[0]
    # Pairwise differences: shape (M, N, D).
    pairwise = X[None, :, :] - Y[:, None, :]
    total_sq = np.square(pairwise).sum()
    return total_sq / (dim * num_source * num_target)
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
def gaussian_kernel(X, beta, Y=None):
    """
    Computes a Gaussian (RBF) kernel matrix between two sets of vectors.

    :param X: A numpy array of shape (n_samples_X, n_features) representing the first set of vectors.
    :param beta: The standard deviation parameter for the Gaussian kernel. It controls the spread of the kernel.
    :param Y: An optional numpy array of shape (n_samples_Y, n_features) representing the second set of vectors.
              If None, the function computes the kernel between `X` and itself (i.e., the Gram matrix).
    :return: A numpy array of shape (n_samples_X, n_samples_Y) where each element (i, j) equals
             `exp(-||X[i] - Y[j]||^2 / (2 * beta^2))`.
    """
    # Default to the Gram matrix of X with itself.
    if Y is None:
        Y = X

    # Pairwise squared Euclidean distances, shape (n_samples_X, n_samples_Y).
    sq_dists = np.sum((X[:, None, :] - Y[None, :, :]) ** 2, axis=2)

    # RBF kernel with bandwidth beta.
    return np.exp(-sq_dists / (2 * beta ** 2))
|
| 418 |
+
|
| 419 |
+
|
| 420 |
+
|
| 421 |
+
def low_rank_eigen(G, num_eig):
    """
    Calculate the top `num_eig` eigenvectors and eigenvalues of a given Gaussian matrix G.
    This function is useful for dimensionality reduction or when a low-rank approximation is needed.

    :param G: A square symmetric matrix (numpy array) to eigendecompose.
    :param num_eig: The number of top eigenvectors and eigenvalues to return,
                    ranked by eigenvalue magnitude.
    :return: A tuple (Q, S):
             - Q: (n, num_eig) array whose columns are the selected eigenvectors.
             - S: (num_eig,) array of the corresponding eigenvalues.
    """
    # eigh returns eigenvalues in ascending order with orthonormal eigenvectors.
    eigvals, eigvecs = np.linalg.eigh(G)

    # Rank indices by |eigenvalue|, descending, and keep the first num_eig.
    order = list(np.argsort(np.abs(eigvals))[::-1][:num_eig])

    return eigvecs[:, order], eigvals[order]
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
|
| 453 |
+
def find_homography_translation_rotation(src_points, dst_points):
    """
    Find the homography between two sets of coordinates with only translation and rotation
    (rigid / Kabsch alignment).

    :param src_points: A numpy array of shape (n, 2) containing source coordinates.
    :param dst_points: A numpy array of shape (n, 2) containing destination coordinates.
    :return: A 3x3 homography matrix.
    """
    # Both point sets must be (n, 2) and correspond one-to-one.
    assert src_points.shape == dst_points.shape
    assert src_points.shape[1] == 2

    # Centroids of each set.
    centroid_src = src_points.mean(axis=0)
    centroid_dst = dst_points.mean(axis=0)

    # Work in centered coordinates so rotation is estimated independently
    # of translation.
    src_centered = src_points - centroid_src
    dst_centered = dst_points - centroid_dst

    # Cross-covariance and its SVD (Kabsch algorithm).
    cov = src_centered.T @ dst_centered
    U, _, Vt = np.linalg.svd(cov)
    rotation = Vt.T @ U.T

    # Fix a reflection if the determinant came out negative.
    if np.linalg.det(rotation) < 0:
        Vt[-1, :] *= -1
        rotation = Vt.T @ U.T

    # Translation maps the rotated source centroid onto the destination centroid.
    translation = centroid_dst - rotation @ centroid_src

    # Pack rotation + translation into a 3x3 homogeneous matrix.
    H = np.eye(3)
    H[0:2, 0:2] = rotation
    H[0:2, 2] = translation
    return H
|
| 496 |
+
|
| 497 |
+
|
| 498 |
+
|
| 499 |
+
def apply_homography(coordinates, H):
    """
    Apply a 3x3 homography matrix to 2D coordinates.

    :param coordinates: A numpy array of shape (n, 2) containing 2D coordinates.
    :param H: A numpy array of shape (3, 3) representing the homography matrix.
    :return: A numpy array of shape (n, 2) with transformed coordinates.
    """
    # Lift (x, y) to homogeneous coordinates (x, y, 1).
    ones = np.ones((coordinates.shape[0], 1))
    homog = np.hstack((coordinates, ones))

    # Transform all points at once: each row is H @ [x, y, 1].
    mapped = homog @ H.T

    # Perspective divide: (x', y', w') -> (x'/w', y'/w').
    return mapped[:, :2] / mapped[:, [2]]
|
| 518 |
+
|
| 519 |
+
|
| 520 |
+
|
| 521 |
+
def align_tissue(ad_tar_coor, ad_src_coor, pca_comb_features, src_img, alpha=0.5):
    """
    Aligns the source coordinates to the target coordinates using Coherent Point Drift (CPD)
    registration, and applies a homography transformation to warp the source coordinates accordingly.

    :param ad_tar_coor: Numpy array of target coordinates to which the source will be aligned.
    :param ad_src_coor: Numpy array of source coordinates that will be aligned to the target.
    :param pca_comb_features: PCA-combined feature matrix used as additional features for the alignment process.
        NOTE(review): assumed to stack target rows first, then source rows — the split below
        uses ad_tar_coor.shape[0]; confirm against the caller.
    :param src_img: Source image to be warped based on the alignment.
    :param alpha: Regularization parameter for CPD registration, default is 0.5.
    :return:
        - cpd_coor: The new source coordinates after CPD alignment.
        - homo_coor: The source coordinates after applying the homography transformation.
        - aligned_image: The source image warped based on the homography transformation.
    """

    # Normalize target and source coordinates to the range [0, 1]
    # (global min/max over the whole array, not per axis).
    ad_tar_coor_z = (ad_tar_coor - ad_tar_coor.min()) / (ad_tar_coor.max() - ad_tar_coor.min())
    ad_src_coor_z = (ad_src_coor - ad_src_coor.min()) / (ad_src_coor.max() - ad_src_coor.min())

    # Normalize PCA-combined features to the range [0, 1]
    pca_comb_features_z = (pca_comb_features - pca_comb_features.min()) / (pca_comb_features.max() - pca_comb_features.min())

    # Concatenate spatial and PCA-combined features (first two components) for target and source
    target = np.concatenate((ad_tar_coor_z, pca_comb_features_z[:ad_tar_coor.shape[0], :2]), axis=1)
    source = np.concatenate((ad_src_coor_z, pca_comb_features_z[ad_tar_coor.shape[0]:, :2]), axis=1)

    # Initialize and run the CPD registration (deformable, low-rank, with regularization);
    # the huge iteration cap means convergence is governed by the tolerance.
    reg = DeformableRegistration(X=target, Y=source, low_rank=True,
                                 alpha=alpha,
                                 max_iterations=int(1e9), tolerance=1e-9)

    TY = reg.register()[0]  # TY contains the transformed source points

    # Rescale the first two (spatial) columns of the CPD-aligned coordinates
    # back to the original range of target coordinates
    cpd_coor = TY[:, :2] * (ad_tar_coor.max() - ad_tar_coor.min()) + ad_tar_coor.min()

    # Find a rigid homography (rotation + translation) from the original source
    # coordinates to the CPD-aligned coordinates, and apply it
    h = find_homography_translation_rotation(ad_src_coor, cpd_coor)
    homo_coor = apply_homography(ad_src_coor, h)

    # Warp the source image using the computed homography (output keeps the
    # source image's width x height)
    aligned_image = cv2.warpPerspective(src_img, h, (src_img.shape[1], src_img.shape[0]))

    # Return the CPD-aligned coordinates, the homography-transformed coordinates, and the warped image
    return cpd_coor, homo_coor, aligned_image
|
| 567 |
+
|
| 568 |
+
|
src/loki/annotate.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
from torch.nn import functional as F
|
| 4 |
+
import os
|
| 5 |
+
import scanpy as sc
|
| 6 |
+
import json
|
| 7 |
+
import cv2
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def annotate_with_bulk(img_features, bulk_features, normalize=True, T=1, tensor=False):
    """
    Annotates tissue image with similarity scores between image features and bulk RNA-seq features.

    :param img_features: Feature matrix representing histopathology image features.
    :param bulk_features: Feature vector representing bulk RNA-seq features.
    :param normalize: Whether to normalize similarity scores, default=True.
        NOTE(review): only honored in the tensor branch; the numpy branch ignores it.
    :param T: Temperature parameter to control the sharpness of the softmax distribution. Higher values result in a smoother distribution.
        NOTE(review): only applied in the tensor branch.
    :param tensor: Feature format in torch tensor or not, default=False.

    :return: An array or tensor containing the normalized similarity scores.
    """

    if tensor:
        # Compute cosine similarity between each image feature row and the bulk vector
        cosine_similarity = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
        similarity = cosine_similarity(img_features, bulk_features.unsqueeze(0))  # Shape: [n]

        # Optional scaling by sqrt of the feature dimension
        if normalize:
            normalization_factor = torch.sqrt(torch.tensor([bulk_features.shape[0]], dtype=torch.float))  # sqrt(feature dim)
            similarity = similarity / normalization_factor

        # Reshape and apply temperature scaling for softmax
        similarity = similarity.unsqueeze(0)  # Shape: [1, n]
        similarity = similarity / T  # Control distribution sharpness

        # Convert similarity scores to probability distribution using softmax
        similarity = torch.nn.functional.softmax(similarity, dim=-1)  # Shape: [1, n]

    else:
        # Dot-product similarity for the numpy path.
        # NOTE(review): expects img_features with features along axis 0
        # (shape [dim, n]) so that img_features.T @ bulk_features is length n.
        similarity = np.dot(img_features.T, bulk_features)

        # Numerically stable softmax (subtract the max before exponentiating)
        max_similarity = np.max(similarity)  # Maximum value for stability
        similarity = np.exp(similarity - max_similarity) / np.sum(np.exp(similarity - max_similarity))

        # Min-max rescale to [0, 1] for interpretation
        similarity = (similarity - np.min(similarity)) / (np.max(similarity) - np.min(similarity))

    return similarity
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def annotate_with_marker_genes(classes, image_embeddings, all_text_embeddings):
    """
    Annotates tissue image with similarity scores between image features and marker gene features.

    :param classes: A list or array of tissue type labels.
    :param image_embeddings: A numpy array or torch tensor of image embeddings (shape: [n_images, embedding_dim]).
    :param all_text_embeddings: A numpy array or torch tensor of text embeddings of the marker genes
                                (shape: [n_classes, embedding_dim]).

    :return:
        - dot_similarity: The matrix of dot product similarities between image embeddings and text embeddings.
        - pred_class: The predicted tissue type with the highest similarity score.
    """

    # Similarity matrix of shape [n_images, n_classes].
    dot_similarity = image_embeddings @ all_text_embeddings.T

    # argmax over the flattened similarity matrix picks the single best
    # (image, class) pair; its index selects the predicted label.
    best_index = dot_similarity.argmax()
    pred_class = classes[best_index]

    return dot_similarity, pred_class
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def load_image_annotation(image_path):
    """
    Loads an image with annotation.

    :param image_path: The file path to the image.

    :return: The loaded image with its first and third color channels swapped,
             as a uint8 numpy array.
    """

    # Load the image from the specified file path using OpenCV
    # (cv2.imread returns the pixel data in BGR channel order).
    image = cv2.imread(image_path)

    # Swap the R and B channels. NOTE(review): COLOR_RGB2BGR performs the same
    # channel swap as COLOR_BGR2RGB, so this effectively converts the loaded
    # BGR data to RGB — confirm the intended channel order downstream.
    image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

    # Ensure the image is of type uint8 for proper handling in OpenCV and other image processing libraries
    image = image.astype(np.uint8)

    return image
|
| 101 |
+
|
| 102 |
+
|
src/loki/decompose.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import tangram as tg
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import anndata
|
| 6 |
+
from sklearn.decomposition import PCA
|
| 7 |
+
from sklearn.neighbors import NearestNeighbors
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def generate_feature_ad(ad_expr, feature_path, sc=False):
    """
    Generates an AnnData object with OmiCLIP text or image embeddings.

    :param ad_expr: AnnData object containing metadata for the dataset.
    :param feature_path: Path to the CSV file containing the features to be loaded.
        NOTE(review): the CSV is expected to have features as rows and
        cells/spots as columns (it is selected by ad_expr.obs.index and then
        transposed).
    :param sc: Boolean flag indicating whether to copy single-cell metadata or ST metadata. Default is False (ST).
    :return: A new AnnData object with the loaded features and relevant metadata from ad_expr.
    """

    # Load features and keep only the columns matching cells/spots in ad_expr.
    features = pd.read_csv(feature_path, index_col=0)[ad_expr.obs.index]

    # Re-select by obs index (also fixes column order) and transpose so
    # cells/spots become rows of the new AnnData.
    feature_ad = anndata.AnnData(features[ad_expr.obs.index].T)

    # Copy relevant metadata from ad_expr based on the sc flag
    if sc:
        # Single-cell data: copy the whole per-cell metadata frame.
        feature_ad.obs = ad_expr.obs.copy()
    else:
        # Spatial data: copy cell counts plus spatial metadata and coordinates.
        feature_ad.obs['cell_num'] = ad_expr.obs['cell_num'].copy()
        feature_ad.uns['spatial'] = ad_expr.uns['spatial'].copy()
        feature_ad.obsm['spatial'] = ad_expr.obsm['spatial'].copy()

    return feature_ad
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def normalize_percentile(df, cols, min_percentile=5, max_percentile=95):
    """
    Clips and normalizes the specified columns of a DataFrame based on percentile thresholds,
    transforming their values to the [0, 1] range. The DataFrame is modified in place
    and also returned.

    :param df: A pandas DataFrame containing the columns to normalize.
    :type df: pandas.DataFrame
    :param cols: A list of column names in `df` that should be normalized.
    :type cols: list[str]
    :param min_percentile: The lower percentile used for clipping (defaults to 5).
    :type min_percentile: float
    :param max_percentile: The upper percentile used for clipping (defaults to 95).
    :type max_percentile: float
    :return: The same DataFrame with specified columns clipped and normalized.
    :rtype: pandas.DataFrame
    """

    for column in cols:
        # Percentile bounds for this column.
        lo = np.percentile(df[column], min_percentile)
        hi = np.percentile(df[column], max_percentile)

        # Clip to the bounds, then min-max rescale onto [0, 1].
        clipped = np.clip(df[column], lo, hi)
        df[column] = (clipped - lo) / (hi - lo)

    return df
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def cell_type_decompose(sc_ad, st_ad, cell_type_col='cell_type', NMS_mode=False, major_types=None, min_percentile=5, max_percentile=95):
    """
    Performs cell type decomposition on spatial data (ST or image) with single-cell data.

    :param sc_ad: AnnData object containing single-cell meta data.
    :param st_ad: AnnData object containing spatial data (ST or image) meta data.
    :param cell_type_col: The column name in `sc_ad.obs` that contains cell type annotations. Default is 'cell_type'.
    :param NMS_mode: Boolean flag to apply Non-Maximum Suppression (NMS) mode. Default is False.
    :param major_types: Major cell types used for NMS mode. Default is None.
        NOTE(review): must be provided when NMS_mode is True; no guard exists here.
    :param min_percentile: The lower percentile used for clipping (defaults to 5).
    :param max_percentile: The upper percentile used for clipping (defaults to 95).
    :return: The spatial AnnData object with projected cell type annotations.
    """

    # Preprocess: match genes between single-cell and spatial data (mutates both AnnDatas).
    tg.pp_adatas(sc_ad, st_ad, genes=None)


    # Map single-cell data to spatial data using Tangram's "map_cells_to_space" function
    ad_map = tg.map_cells_to_space(
        sc_ad, st_ad,
        mode="clusters",  # Map based on clusters (cell types)
        cluster_label=cell_type_col,  # Column in `sc_ad.obs` representing cell type
        device='cpu',  # Run on CPU (or 'cuda' if GPU is available)
        scale=False,  # Don't scale data
        density_prior='uniform',  # Uniform prior over cell densities
        random_state=10,  # Set random state for reproducibility
        verbose=False,  # Disable verbose output for cleaner logging
    )

    # Project cell type annotations from the single-cell data to the spatial data
    # (writes st_ad.obsm['tangram_ct_pred']).
    tg.project_cell_annotations(ad_map, st_ad, annotation=cell_type_col)


    if NMS_mode:
        major_types = major_types  # no-op self-assignment; kept as-is
        # Percentile-normalize the major-type scores and store them on obs.
        st_ad.obs = normalize_percentile(st_ad.obsm['tangram_ct_pred'], major_types, min_percentile, max_percentile)

        st_ad_binary = st_ad.obsm['tangram_ct_pred'][major_types].copy()
        # Retain the max value in each row and set the rest to 0
        st_ad.obs[major_types] = st_ad_binary.where(st_ad_binary.eq(st_ad_binary.max(axis=1), axis=0), other=0)

    return st_ad  # Return the spatial AnnData object with the projected annotations
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def assign_cells_to_spots(cell_locs, spot_locs, patch_size=16):
    """
    Assign segmented cells to spatial spots by proximity.

    A cell is assigned to every spot whose center lies within half of
    ``patch_size`` of the cell's location.

    :param cell_locs: Array of shape (n_cells, 2) with cell x, y coordinates.
    :param spot_locs: Array of shape (n_spots, 2) with spot x, y coordinates.
    :param patch_size: Diameter of a spot patch; the assignment radius is
        ``patch_size / 2``. Default is 16.
    :return: Sparse connectivity matrix of shape (n_cells, n_spots) with a 1
        wherever a cell falls within the radius of a spot, 0 elsewhere.
    """
    # Radius-based neighbor search: spots are the reference set, cells are queried.
    model = NearestNeighbors(radius=patch_size * 0.5)
    model.fit(spot_locs)

    # Sparse cells-by-spots adjacency; 'connectivity' mode yields binary entries.
    return model.radius_neighbors_graph(cell_locs, mode='connectivity')
|
| 142 |
+
|
| 143 |
+
|
src/loki/plot.py
ADDED
|
@@ -0,0 +1,435 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import matplotlib.pyplot as plt
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
import json
|
| 4 |
+
import cv2
|
| 5 |
+
from matplotlib import cm
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import numpy as np
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def plot_alignment(ad_tar_coor, ad_src_coor, homo_coor, pca_hex_comb, tar_features, shift=300, s=0.8, boundary_line=True):
    """
    Show target, source, and aligned-source coordinates in three panels.

    :param ad_tar_coor: Numpy array of target coordinates (first panel).
    :param ad_src_coor: Numpy array of source coordinates (second panel).
    :param homo_coor: Numpy array of aligned source coordinates (third panel).
    :param pca_hex_comb: Per-point color values (e.g. PCA-derived hex colors);
        the first ``len(tar_features.T)`` entries color the target points and
        the remainder color the source points.
    :param tar_features: Target feature matrix; only its column count is used
        to split ``pca_hex_comb`` between target and source.
    :param shift: Padding added around the coordinate extent when setting the
        axis limits. Default is 300.
    :param s: Scatter marker size. Default is 0.8.
    :param boundary_line: If True, draw black boundary lines at the minimum x
        and y of the target coordinates (on the last panel). Default is True.
    :return: None; displays the figure.
    """
    n_tar = len(tar_features.T)
    tar_colors = pca_hex_comb[:n_tar]
    src_colors = pca_hex_comb[n_tar:]

    # Shared axis limits for all three panels, derived from the target extent.
    lo = ad_tar_coor.min() - shift
    hi = ad_tar_coor.max() + shift

    plt.figure(figsize=(10, 3), dpi=300)

    # (coordinates, colors) for each panel: target, source, aligned source.
    panels = [
        (ad_tar_coor, tar_colors),
        (ad_src_coor, src_colors),
        (homo_coor, src_colors),
    ]
    for idx, (coords, colors) in enumerate(panels, start=1):
        plt.subplot(1, 3, idx)
        plt.scatter(coords[:, 0], coords[:, 1], marker='o', s=s, c=colors)
        plt.xlim([lo, hi])
        plt.ylim([lo, hi])

    # Optional boundary lines, drawn on the currently active (last) panel.
    if boundary_line:
        plt.axvline(x=ad_tar_coor[:, 0].min(), color='black')
        plt.axhline(y=ad_tar_coor[:, 1].min(), color='black')

    # Hide axis decorations on the active panel for a cleaner look.
    plt.axis('off')

    plt.show()
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def plot_alignment_with_img(ad_tar_coor, ad_src_coor, homo_coor, tar_img, src_img, aligned_image, pca_hex_comb, tar_features):
    """
    Show target, source, and aligned coordinates with images in the background.

    :param ad_tar_coor: Numpy array of target coordinates (first and third panels).
    :param ad_src_coor: Numpy array of source coordinates (second panel).
    :param homo_coor: Numpy array of aligned source coordinates (third panel).
    :param tar_img: Background image for the target panel.
    :param src_img: Background image for the source panel.
    :param aligned_image: Background image for the aligned panel.
    :param pca_hex_comb: Per-point color values; the first ``len(tar_features.T)``
        entries belong to the target points, the rest to the source points.
    :param tar_features: Target feature matrix; only its column count is used
        to split ``pca_hex_comb``.
    :return: None; displays the figure.
    """
    n_tar = len(tar_features.T)
    tar_colors = pca_hex_comb[:n_tar]
    src_colors = pca_hex_comb[n_tar:]

    plt.figure(figsize=(10, 8), dpi=150)

    # Panel 1: target points over the (semi-transparent) target image.
    plt.subplot(1, 3, 1)
    plt.scatter(ad_tar_coor[:, 0], ad_tar_coor[:, 1], marker='o', alpha=0.8, s=1, c=tar_colors)
    plt.imshow(tar_img, origin='lower', alpha=0.3)

    # Panel 2: source points over the source image.
    plt.subplot(1, 3, 2)
    plt.scatter(ad_src_coor[:, 0], ad_src_coor[:, 1], marker='o', alpha=0.8, s=1, c=src_colors)
    plt.imshow(src_img, origin='lower', alpha=0.3)

    # Panel 3: faded target points plus aligned source points ('+' markers)
    # over the aligned image, so the overlap quality is visible.
    plt.subplot(1, 3, 3)
    plt.scatter(ad_tar_coor[:, 0], ad_tar_coor[:, 1], marker='o', alpha=0.2, s=1, c=tar_colors)
    plt.scatter(homo_coor[:, 0], homo_coor[:, 1], marker='+', s=1, c=src_colors)
    plt.imshow(aligned_image, origin='lower', alpha=0.3)

    # Hide axis decorations on the active panel.
    plt.axis('off')

    plt.show()
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def draw_polygon(image, polygon, color='k', thickness=2):
    """
    Draw one or more closed polygons onto an image.

    :param image: Image to draw on (numpy array).
    :param polygon: List of polygons; each polygon is a list of (x, y) tuples.
    :param color: A color string, or a list of color strings (one per polygon).
        A single value is broadcast to all polygons. Default is 'k' (black).
    :param thickness: Outline thickness, or a list of thicknesses (one per
        polygon). A single value is broadcast. Default is 2.
    :return: The image with the polygons drawn.
    """
    # Broadcast a scalar color to one entry per polygon.
    colors = color if isinstance(color, list) else [color] * len(polygon)

    for idx, pts in enumerate(polygon):
        # Resolve the per-polygon color string into an RGB tuple.
        rgb = color_string_to_rgb(colors[idx])

        # Per-polygon thickness if a list was given, otherwise the scalar.
        t = thickness[idx] if isinstance(thickness, list) else thickness

        # OpenCV expects integer vertices shaped (n_points, 1, 2).
        contour = np.array(pts, np.int32).reshape((-1, 1, 2))

        # isClosed=True connects the last vertex back to the first.
        image = cv2.polylines(image, [contour], isClosed=True, color=rgb, thickness=t)

    return image
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def blend_images(image1, image2, alpha=0.5):
    """
    Blend two equally sized images into one.

    :param image1: Background image, numpy array of shape (H, W, 3).
    :param image2: Foreground image, numpy array of the same shape as image1.
    :param alpha: Weight of ``image1`` in the blend, between 0 and 1; ``image2``
        receives weight ``1 - alpha``. Default 0.5 (equal mix).
    :return: The blended image, computed per pixel as
        ``alpha * image1 + (1 - alpha) * image2``.
    """
    # cv2.addWeighted performs the weighted sum with saturating uint8 arithmetic.
    return cv2.addWeighted(image1, alpha, image2, 1 - alpha, 0)
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
# Shorthand single-letter color codes mapped to their 6-digit hex values.
_SHORTHAND_COLORS = {
    'k': '000000',  # black
    'r': 'ff0000',  # red
    'g': '00ff00',  # green
    'b': '0000ff',  # blue
    'w': 'ffffff',  # white
}


def color_string_to_rgb(color_string):
    """
    Converts a color string to an RGB tuple.

    :param color_string: A string representing the color. Accepted forms:
        a 6-digit hex color (e.g. '#ff0000'), a 3-digit shorthand hex color
        (e.g. '#f00', expanded digit-by-digit to 'ff0000'), or a single-letter
        code ('k', 'r', 'g', 'b', 'w').
    :return: A tuple (r, g, b) with each channel as an integer in 0-255.
    :raises ValueError: If the color string is not recognized.
    """
    # Remove any spaces in the color string
    color_string = color_string.replace(' ', '')

    if color_string.startswith('#'):
        # Hexadecimal color: strip the leading '#'
        color_string = color_string[1:]
        # Generalization: expand 3-digit shorthand ('f00' -> 'ff0000').
        # Previously this input crashed with ValueError from int('', 16).
        if len(color_string) == 3:
            color_string = ''.join(ch * 2 for ch in color_string)
    else:
        # Single-letter shorthand codes; anything else is rejected.
        try:
            color_string = _SHORTHAND_COLORS[color_string]
        except KeyError:
            raise ValueError(f"Unknown color string {color_string}")

    # Parse the three 2-digit channel values.
    r = int(color_string[:2], 16)
    g = int(color_string[2:4], 16)
    b = int(color_string[4:], 16)

    return (r, g, b)
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def plot_heatmap(
    coor,
    similairty,
    image_path=None,
    patch_size=(256, 256),
    save_path=None,
    downsize=32,
    cmap='turbo',
    smooth=False,
    boxes=None,
    box_color='k',
    box_thickness=2,
    polygons=None,
    polygons_color='k',
    polygons_thickness=2,
    image_alpha=0.5
):
    """
    Plots a heatmap overlaid on an image based on given coordinates and similairty.

    :param coor: Array of coordinates (N, 2), one per patch to place on the heatmap.
    :param similairty: Array of similarity scores (N,) aligned with ``coor``;
        mapped to colors through ``cmap``.
    :param image_path: Path to the background image.
        NOTE(review): despite the original description, ``None`` is not actually
        handled — ``cv2.imread`` is always called, so a valid path is required;
        confirm intended behavior before passing None.
    :param patch_size: Size of each patch in pixels (default 256x256).
    :param save_path: Accepted for interface compatibility; currently unused.
    :param downsize: Factor by which coordinates and patch size are divided for
        faster processing. Default 32. NOTE: the background image itself is NOT
        downsized here — it is assumed already at the downsized resolution.
    :param cmap: Colormap name used to map similarities to colors. Default 'turbo'.
    :param smooth: Accepted for interface compatibility; not implemented.
    :param boxes: Accepted for interface compatibility; currently not drawn.
    :param box_color: Accepted for interface compatibility; currently unused.
    :param box_thickness: Accepted for interface compatibility; currently unused.
    :param polygons: List of polygons (each an array of (x, y) points) to draw on
        both the heatmap and the image.
    :param polygons_color: Polygon outline color. Default black ('k').
    :param polygons_thickness: Polygon outline thickness.
    :param image_alpha: Blending weight (0 to 1) of the heatmap against the
        background image. 0 skips blending entirely. Default 0.5.

    :return:
        - heatmap: The generated heatmap as a numpy array (RGB), blended with
          the image when ``image_alpha > 0`` and with polygons drawn if given.
        - image: The background image, with polygons drawn if given.
    """
    # Read the background image; its size defines the heatmap canvas.
    image = cv2.imread(image_path)
    image_size = (image.shape[0], image.shape[1])

    # Downsize coordinates and patch size so patch placement matches the
    # (already downsized) background image resolution.
    coor = [(x // downsize, y // downsize) for x, y in coor]
    patch_size = (patch_size[0] // downsize, patch_size[1] // downsize)

    # Map similarity scores to RGBA colors via the normalized colormap.
    cmap = plt.get_cmap(cmap)
    norm = plt.Normalize(vmin=similairty.min(), vmax=similairty.max())
    colors = cmap(norm(similairty))

    # Start from a white canvas the size of the image.
    heatmap = np.ones((image_size[0], image_size[1], 3)) * 255

    # Paint each patch with its similarity color. Note patch_size[0] spans the
    # y (row) axis and patch_size[1] the x (column) axis.
    for i in range(len(coor)):
        x, y = coor[i]
        w = colors[i][:3] * 255  # drop alpha, scale [0,1] -> [0,255]
        w = w.astype(np.uint8)
        heatmap[y:y + patch_size[0], x:x + patch_size[1], :] = w

    # Blend the heatmap with the background image unless blending is disabled.
    if image_alpha > 0:
        image = np.array(image)

        # Pad the image with white so its size matches the heatmap before blending.
        if image.shape[0] < heatmap.shape[0]:
            pad = heatmap.shape[0] - image.shape[0]
            image = np.pad(image, ((0, pad), (0, 0), (0, 0)), mode='constant', constant_values=255)
        if image.shape[1] < heatmap.shape[1]:
            # BUG FIX: was `heatmap.shape[1] - heatmap.shape[1]` (always 0),
            # so width padding never happened and cv2.addWeighted failed on a
            # size mismatch when the image was narrower than the heatmap.
            pad = heatmap.shape[1] - image.shape[1]
            image = np.pad(image, ((0, 0), (0, pad), (0, 0)), mode='constant', constant_values=255)

        # NOTE(review): cv2.imread returns BGR, so this cvtColor swaps channel
        # order before blending — confirm the intended color space downstream.
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        image = image.astype(np.uint8)
        heatmap = heatmap.astype(np.uint8)
        heatmap = blend_images(heatmap, image, alpha=image_alpha)

    # Optionally overlay polygons on both outputs.
    if polygons is not None:
        polygons = [poly // downsize for poly in polygons]  # match downsized canvas
        image_polygons = draw_polygon(image, polygons, color=polygons_color, thickness=polygons_thickness)
        heatmap_polygons = draw_polygon(heatmap, polygons, color=polygons_color, thickness=polygons_thickness)

        return heatmap_polygons, image_polygons
    else:
        return heatmap, image
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
def show_images_side_by_side(image1, image2, title1=None, title2=None):
    """
    Display two images next to each other in one figure.

    :param image1: First image (numpy array), shown on the left.
    :param image2: Second image (numpy array), shown on the right.
    :param title1: Optional title for the first image. Default None.
    :param title2: Optional title for the second image. Default None.
    :return: None; displays the figure.
    """
    # One row, two columns; each axis gets an image, a title, and no axis marks.
    fig, axes = plt.subplots(1, 2, figsize=(15, 8))
    for axis, img, label in zip(axes, (image1, image2), (title1, title2)):
        axis.imshow(img)
        axis.set_title(label)
        axis.axis('off')

    plt.show()
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
def plot_img_with_annotation(fullres_img, roi_polygon, linewidth, xlim, ylim):
    """
    Display an image with ROI polygon outlines drawn on top.

    :param fullres_img: Full-resolution image to display (numpy array).
    :param roi_polygon: List of polygons; each polygon is a list of (x, y) tuples.
    :param linewidth: Thickness of the polygon outlines.
    :param xlim: (xmin, xmax) x-axis limits, for zooming into a region.
    :param ylim: (ymin, ymax) y-axis limits, for zooming into a region.
    :return: None; displays the annotated image.
    """
    plt.figure(figsize=(10, 10))
    plt.imshow(fullres_img)

    # Trace each polygon outline in black.
    for poly in roi_polygon:
        xs, ys = zip(*poly)
        plt.plot(xs, ys, color='black', linewidth=linewidth)

    plt.xlim(xlim)
    plt.ylim(ylim)

    # Flip the y-axis so the origin sits at the top-left, as in image coordinates.
    plt.gca().invert_yaxis()

    # No ticks or labels — just the image and its annotations.
    plt.axis('off')
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
def plot_annotation_heatmap(st_ad, roi_polygon, s, linewidth, xlim, ylim):
    """
    Display a similarity heatmap of spatial spots with ROI polygon outlines.

    :param st_ad: AnnData object with spot coordinates in ``obsm['spatial']``
        and similarity scores in ``obs['bulk_simi']``.
    :param roi_polygon: List of polygons; each polygon is a list of (x, y) tuples.
    :param s: Scatter marker size for each spot.
    :param linewidth: Thickness of the polygon outlines.
    :param xlim: (xmin, xmax) x-axis limits, for zooming into a region.
    :param ylim: (ymin, ymax) y-axis limits, for zooming into a region.
    :return: None; displays the heatmap.
    """
    plt.figure(figsize=(10, 10))

    # Spots colored by their 'bulk_simi' score on the 'turbo' colormap;
    # color normalization is clamped to the [0.1, 0.95] range.
    xy = st_ad.obsm['spatial']
    plt.scatter(
        xy[:, 0], xy[:, 1],
        c=st_ad.obs['bulk_simi'],
        s=s,
        vmin=0.1, vmax=0.95,
        cmap='turbo'
    )

    # Trace each polygon outline in black.
    for poly in roi_polygon:
        xs, ys = zip(*poly)
        plt.plot(xs, ys, color='black', linewidth=linewidth)

    plt.xlim(xlim)
    plt.ylim(ylim)

    # Flip the y-axis so the origin sits at the top-left, as in image coordinates.
    plt.gca().invert_yaxis()

    # No ticks or labels — just the heatmap and its annotations.
    plt.axis('off')
|
| 434 |
+
|
| 435 |
+
|
src/loki/predex.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def predict_st_gene_expr(image_text_similarity, train_data):
    """
    Predict ST gene expression from H&E image-text similarities.

    Each sample's prediction is the similarity-weighted average of the rows of
    ``train_data``.

    :param image_text_similarity: Array of image-to-text similarities,
        shape [n_samples, n_genes].
    :param train_data: Array or DataFrame of training expression profiles,
        shape [n_genes, n_shared_genes].
    :return: Predicted expression matrix, shape [n_samples, n_shared_genes].
    """
    # Similarity-weighted combination of training profiles per sample.
    numerator = image_text_similarity @ train_data

    # Per-sample normalizer: total similarity mass, kept 2-D for broadcasting.
    denominator = image_text_similarity.sum(axis=1, keepdims=True)

    # Weighted average = weighted sum divided by the total weight.
    return numerator / denominator
|
| 24 |
+
|
| 25 |
+
|
src/loki/preprocess.py
ADDED
|
@@ -0,0 +1,324 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import scanpy as sc
|
| 2 |
+
import numpy as np
|
| 3 |
+
import pandas as pd
|
| 4 |
+
import json
|
| 5 |
+
import os
|
| 6 |
+
from PIL import Image
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def generate_gene_df(ad, house_keeping_genes, todense=True, top_k=50):
    """
    Build a DataFrame with the top-``top_k`` most highly expressed genes per observation.

    Genes whose names contain '.' or '-' are removed, as are genes listed in the
    'genesymbol' column of ``house_keeping_genes``.

    :param ad: AnnData object containing gene expression data in ``ad.X``.
    :type ad: anndata.AnnData
    :param house_keeping_genes: DataFrame or Series with a 'genesymbol' column listing
        housekeeping genes to exclude.
    :type house_keeping_genes: pandas.DataFrame or pandas.Series
    :param todense: Whether to convert a sparse ``ad.X`` to dense before building the DataFrame.
    :type todense: bool
    :param top_k: Number of top-expressed genes to keep per observation.
        Defaults to 50, matching the original hard-coded behavior.
    :type top_k: int
    :return: DataFrame with a single 'label' column; each row is a space-separated
        string of the top-``top_k`` gene names for that observation.
    :rtype: pandas.DataFrame
    """
    # Drop genes with '.' or '-' in their names (version suffixes, antisense, etc.).
    ad = ad[:, ~ad.var.index.str.contains('.', regex=False)]
    ad = ad[:, ~ad.var.index.str.contains('-', regex=False)]
    # Exclude housekeeping genes.
    ad = ad[:, ~ad.var.index.isin(house_keeping_genes['genesymbol'])]

    # Materialize the expression matrix as a DataFrame (dense if requested).
    matrix = ad.X.todense() if todense else ad.X
    expr = pd.DataFrame(matrix, index=ad.obs.index, columns=ad.var.index)

    # For each observation, take the names of the top_k highest-expression genes.
    top_genes = expr.apply(lambda s, n: pd.Series(s.nlargest(n).index), axis=1, n=top_k)

    # Join the gene names into one space-separated label string per observation.
    top_k_genes_str = pd.DataFrame()
    top_k_genes_str['label'] = top_genes[top_genes.columns].astype(str) \
        .apply(lambda x: ' '.join(x), axis=1)

    return top_k_genes_str
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def segment_patches(img_array, coord, patch_dir, height=20, width=20):
    """
    Crop fixed-size patches around given spot coordinates and save each as a PNG.

    :param img_array: Full-resolution image array, expected shape (H, W[, C]).
    :type img_array: numpy.ndarray
    :param coord: DataFrame indexed by spot ID with "pixel_x" and "pixel_y" columns.
    :type coord: pandas.DataFrame
    :param patch_dir: Output directory for the patch images (created if absent).
    :type patch_dir: str
    :param height: Patch height in pixels (y-direction extent).
    :type height: int
    :param width: Patch width in pixels (x-direction extent).
    :type width: int
    :return: None; patches are written to ``patch_dir`` as "<spot>_hires.png".
    """
    # Create the output directory if needed.
    os.makedirs(patch_dir, exist_ok=True)

    # Overall image extents.
    yrange, xrange = img_array.shape[:2]

    for spot_idx in coord.index:
        # NOTE(review): "pixel_x" is unpacked into ycenter and "pixel_y" into
        # xcenter; this mirrors the swapped column convention used when
        # obsm['spatial'] is copied into ["pixel_y", "pixel_x"] elsewhere in
        # this package — confirm before changing.
        ycenter, xcenter = coord.loc[spot_idx, ["pixel_x", "pixel_y"]]

        # Patch corners: top-left from the rounded center, bottom-right by size.
        x1 = round(xcenter - width / 2)
        y1 = round(ycenter - height / 2)
        x2, y2 = x1 + width, y1 + height

        # Skip patches that would fall (partly) outside the image.
        if not (x1 >= 0 and y1 >= 0 and x2 <= xrange and y2 <= yrange):
            print(f"Patch {spot_idx} is out of range and will be skipped.")
            continue

        # Crop, cast to 8-bit, and write the patch to disk.
        patch_img = Image.fromarray(img_array[y1:y2, x1:x2].astype(np.uint8))
        patch_img.save(os.path.join(patch_dir, f"{spot_idx}_hires.png"))
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def read_gct(file_path):
    """
    Read a GCT file and return its contents as a pandas DataFrame.

    The GCT format has a version line, a dimensions line ("<rows>\t<cols>"),
    then a tab-separated table whose first two columns are gene names and
    descriptions, followed by the expression values.

    :param file_path: Path of the GCT file to read.
    :return: DataFrame with the gene name/description columns followed by the
        expression data columns.
    """
    with open(file_path, 'r') as fh:
        # Skip the version line (e.g. "#1.2").
        fh.readline()

        # Second line holds the matrix dimensions.
        dims = fh.readline().strip().split()
        num_rows = int(dims[0])   # number of gene rows in the table
        num_cols = int(dims[1])   # number of data columns (unused, parsed for parity)

        # Remaining content is a tab-delimited table; limit to the declared rows.
        return pd.read_csv(fh, sep='\t', header=0, nrows=num_rows)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def get_library_id(adata):
    """
    Return the first library ID stored under ``adata.uns['spatial']``.

    :param adata: AnnData object with spatial information in ``adata.uns['spatial']``.
    :return: The first library ID, or None if 'spatial' exists but is empty
        (an error is logged in that case).
    :raises AssertionError: If 'spatial' is not present in ``adata.uns``.
    """
    # 'spatial' must be present; fail loudly otherwise.
    assert 'spatial' in adata.uns, "spatial not present in adata.uns"

    # Library IDs are the keys of the 'spatial' dict; take the first one.
    library_ids = list(adata.uns['spatial'].keys())
    if library_ids:
        return library_ids[0]

    # Fixed: the original referenced an undefined name ``logger`` here, which
    # raised NameError instead of logging. Use the stdlib logging module.
    import logging
    logging.getLogger(__name__).error('No library_id found in adata')
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def get_scalefactors(adata, library_id=None):
    """
    Return the scalefactors dict for a library in ``adata.uns['spatial']``.

    :param adata: AnnData object with scalefactors in ``adata.uns['spatial']``.
    :param library_id: Library ID to look up; defaults to the first available one.
    :return: The scalefactors dict, or None if it cannot be found
        (an error is logged in that case).
    """
    # Default to the first library ID when none is specified.
    if library_id is None:
        library_id = get_library_id(adata)

    try:
        return adata.uns['spatial'][library_id]['scalefactors']
    except KeyError:
        # Fixed: the original referenced an undefined name ``logger`` here,
        # which raised NameError instead of logging. Use stdlib logging.
        import logging
        logging.getLogger(__name__).error('scalefactors not found in adata')
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def get_spot_diameter_in_pixels(adata, library_id=None):
    """
    Return the full-resolution spot diameter (in pixels) from the scalefactors.

    :param adata: AnnData object with scalefactors in ``adata.uns['spatial']``.
    :param library_id: Library ID to use; defaults to the first available one.
    :return: ``scalefactors['spot_diameter_fullres']``, or None if the
        scalefactors or the key are missing.
    """
    # Look up the scalefactors for the requested (or first) library.
    scalef = get_scalefactors(adata, library_id=library_id)

    try:
        return scalef['spot_diameter_fullres']
    except TypeError:
        # get_scalefactors returned None (no scalefactors available); fall
        # through and return None, matching the original behavior.
        pass
    except KeyError:
        # Fixed: the original referenced an undefined name ``logger`` here,
        # which raised NameError instead of logging. Use stdlib logging.
        import logging
        logging.getLogger(__name__).error('spot_diameter_fullres not found in adata')
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
def prepare_data_for_alignment(data_path, scale_type='tissue_hires_scalef'):
    """
    Load a Visium .h5ad file and prepare its coordinates and hires image for alignment.

    :param data_path: Path to the AnnData (.h5ad) file with the Visium data.
    :param scale_type: Which scale factor to apply to the spatial coordinates
        (default 'tissue_hires_scalef').
    :return: Tuple ``(ad, ad_coor, img)`` — the loaded AnnData object, the
        scaled spatial coordinates as a numpy array, and the high-resolution
        tissue image as 8-bit unsigned integers.
    :raises ValueError: If the scale factor, spatial coordinates, or hires
        image cannot be found in the file.
    """
    ad = sc.read_h5ad(data_path)

    # Duplicate gene names would cause conflicts downstream.
    ad.var_names_make_unique()

    # Scale factor for the requested resolution.
    try:
        scalef = get_scalefactors(ad)[scale_type]
    except KeyError:
        raise ValueError(f"Scale factor '{scale_type}' not found in ad.uns['spatial']")

    # Spot coordinates, rescaled to the requested resolution.
    try:
        ad_coor = np.array(ad.obsm['spatial']) * scalef
    except KeyError:
        raise ValueError("Spatial coordinates not found in ad.obsm['spatial']")

    # High-resolution tissue image for the (first) library.
    try:
        img = ad.uns['spatial'][get_library_id(ad)]['images']['hires']
    except KeyError:
        raise ValueError("High-resolution image not found in ad.uns['spatial']")

    # Images stored as floats in [0, 1] are rescaled to 8-bit for compatibility.
    if img.max() < 1.1:
        img = (img * 255).astype('uint8')

    return ad, ad_coor, img
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
def load_data_for_annotation(st_data_path, json_path, in_tissue=True):
    """
    Load an ST .h5ad file plus an ROI polygon JSON, ready for annotation.

    :param st_data_path: Path to the spatial transcriptomics .h5ad file.
    :param json_path: Path to the JSON file holding the region-of-interest polygon.
    :param in_tissue: If True (default), keep only spots with ``obs['in_tissue'] == 1``.
    :return: Tuple ``(st_ad, library_id, roi_polygon)`` — the AnnData object
        (with pixel coordinates copied into ``obs``), its library ID, and the
        ROI polygon loaded from JSON.
    """
    st_ad = sc.read_h5ad(st_data_path)

    # Restrict to in-tissue spots when requested.
    if in_tissue:
        st_ad = st_ad[st_ad.obs['in_tissue'] == 1]

    # Copy the spatial coordinates into obs.
    # NOTE(review): obsm['spatial'] columns are assigned to ("pixel_y", "pixel_x")
    # in that order — this matches the swapped unpacking convention used by
    # segment_patches; confirm before changing.
    st_ad.obs[["pixel_y", "pixel_x"]] = None  # make sure the columns exist
    st_ad.obs[["pixel_y", "pixel_x"]] = st_ad.obsm['spatial']

    # Library ID associated with the spatial data.
    library_id = get_library_id(st_ad)

    # Region-of-interest polygon for annotation.
    with open(json_path) as fh:
        roi_polygon = json.load(fh)

    return st_ad, library_id, roi_polygon
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
def read_polygons(file_path, slide_id):
    """
    Load polygon annotations for one slide from a JSON configuration file.

    :param file_path: Path to the JSON file containing polygon configurations.
    :param slide_id: Slide identifier whose polygon data should be extracted.
    :return: Tuple ``(polygons, polygon_colors, polygon_thickness)`` —
        a list of numpy coordinate arrays, the per-polygon colors, and the
        per-polygon border thicknesses. All three are None when ``slide_id``
        is absent from the file.
    """
    # Load every slide's polygon configuration.
    with open(file_path, 'r') as fh:
        configs = json.load(fh)

    # Unknown slide: signal "nothing found" with a triple of Nones.
    if slide_id not in configs:
        return None, None, None

    # Pull out coordinates, colors, and thicknesses for this slide.
    slide_polys = configs[slide_id]
    polygons = [np.array(entry['coords']) for entry in slide_polys]
    polygon_colors = [entry['color'] for entry in slide_polys]
    polygon_thickness = [entry['thickness'] for entry in slide_polys]

    return polygons, polygon_colors, polygon_thickness
|
| 323 |
+
|
| 324 |
+
|
src/loki/requirements.txt
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
anndata==0.10.9
|
| 2 |
+
matplotlib==3.9.2
|
| 3 |
+
numpy==1.25.0
|
| 4 |
+
pandas==2.2.3
|
| 5 |
+
opencv-python==4.10.0.84
|
| 6 |
+
pycpd==2.0.0
|
| 7 |
+
torch==2.3.1
|
| 8 |
+
tangram-sc==1.0.4
|
| 9 |
+
tqdm==4.66.5
|
| 10 |
+
torchvision==0.18.1
|
| 11 |
+
open_clip_torch==2.26.1
|
| 12 |
+
pillow==10.4.0
|
| 13 |
+
ipykernel==6.29.5
|
| 14 |
+
|
src/loki/retrieve.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def retrieve_st_by_image(image_embeddings, all_text_embeddings, dataframe, k=3):
    """
    Find the top-k ST samples whose embeddings best match an image embedding.

    :param image_embeddings: Image embedding tensor of shape (1, embedding_dim).
    :param all_text_embeddings: ST embedding tensor of shape (n_samples, embedding_dim).
    :param dataframe: DataFrame with an 'img_idx' column identifying each ST sample.
    :param k: How many top matches to return. Default is 3.
    :return: List of the 'img_idx' values for the k most similar samples.
    """
    # Dot-product similarity between the query image and every ST embedding.
    similarity = image_embeddings @ all_text_embeddings.T

    # Indices of the k highest-scoring samples.
    _, top_indices = torch.topk(similarity.squeeze(0), k)

    # Map indices back to sample identifiers from the DataFrame.
    sample_ids = dataframe['img_idx'].values
    return [sample_ids[i] for i in top_indices]
|
| 27 |
+
|
| 28 |
+
|
src/loki/utils.py
ADDED
|
@@ -0,0 +1,278 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import numpy as np
|
| 3 |
+
import pandas as pd
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from PIL import Image
|
| 7 |
+
import json
|
| 8 |
+
import cv2
|
| 9 |
+
from sklearn.decomposition import PCA
|
| 10 |
+
from open_clip import create_model_from_pretrained, get_tokenizer
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def load_model(model_path, device):
    """
    Load a pretrained OmiCLIP model plus its preprocessing transform and tokenizer.

    :param model_path: Path to the pretrained checkpoint; forwarded as the
        ``pretrained`` argument of ``create_model_from_pretrained``.
    :type model_path: str
    :param device: Device to load the model on (e.g. 'cpu' or 'cuda').
    :type device: str or torch.device
    :return: Tuple ``(model, preprocess, tokenizer)`` — the loaded model, its
        input preprocessing transform, and the matching text tokenizer.
    :rtype: (nn.Module, callable, callable)
    """
    # Both the model and its matching input transform come from the checkpoint.
    model, preprocess = create_model_from_pretrained(
        "coca_ViT-L-14", device=device, pretrained=model_path
    )

    # Tokenizer matching the "coca_ViT-L-14" architecture.
    tokenizer = get_tokenizer('coca_ViT-L-14')

    return model, preprocess, tokenizer
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def encode_image(model, preprocess, image):
    """
    Encode one image into an L2-normalized feature embedding.

    :param model: Model exposing an ``encode_image`` method.
    :type model: torch.nn.Module
    :param preprocess: Transform turning the raw image into a model-ready tensor.
    :type preprocess: callable
    :param image: Input image (PIL Image, numpy array, or anything ``preprocess`` accepts).
    :type image: PIL.Image.Image or numpy.ndarray
    :return: Normalized embedding of shape (1, embedding_dim).
    :rtype: torch.Tensor
    """
    # Batch of one: preprocess the image, then add the leading batch dimension.
    batch = torch.stack([preprocess(image)])

    # Inference only — no gradient tracking.
    with torch.no_grad():
        features = model.encode_image(batch)

    # Unit-length embeddings along the feature axis.
    return F.normalize(features, p=2, dim=-1)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def encode_image_patches(model, preprocess, data_dir, img_list):
    """
    Encode a list of patch images into normalized feature embeddings.

    Patches are read from ``<data_dir>/demo_data/patch/<name>`` for each name
    in ``img_list``.

    :param model: Model exposing an ``encode_image`` method.
    :type model: torch.nn.Module
    :param preprocess: Transform turning each image into a model-ready tensor.
    :type preprocess: callable
    :param data_dir: Base data directory containing ``demo_data/patch``.
    :type data_dir: str
    :param img_list: Patch image filenames to encode.
    :type img_list: list[str]
    :return: Tensor of shape (N, 1, embedding_dim) with L2-normalized embeddings.
    :rtype: torch.Tensor
    """
    embeddings = []
    for img_name in img_list:
        # Open the patch image and encode it; encode_image returns an
        # already-normalized (1, embedding_dim) tensor.
        image = Image.open(os.path.join(data_dir, 'demo_data', 'patch', img_name))
        embeddings.append(encode_image(model, preprocess, image))

    # Fixed: the original converted the list of torch tensors through numpy
    # (torch.from_numpy(np.array(...))), which copies the data and is fragile;
    # torch.stack produces the same (N, 1, embedding_dim) tensor directly.
    stacked = torch.stack(embeddings)

    # Re-normalize along the feature axis (a no-op for already-unit vectors,
    # kept for parity with the original behavior).
    return F.normalize(stacked, p=2, dim=-1)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def encode_text(model, tokenizer, text):
    """
    Encode text into an L2-normalized feature embedding.

    :param model: Model exposing an ``encode_text`` method.
    :type model: torch.nn.Module
    :param tokenizer: Callable converting raw text into model-ready token tensors.
    :type tokenizer: callable
    :param text: Input text (a string or list of strings).
    :type text: str or list[str]
    :return: Normalized embeddings of shape (batch_size, embedding_dim).
    :rtype: torch.Tensor
    """
    # Tokenize the raw text for the model.
    tokens = tokenizer(text)

    # Inference only — no gradient tracking.
    with torch.no_grad():
        features = model.encode_text(tokens)

    # Unit-length embeddings along the feature axis.
    return F.normalize(features, p=2, dim=-1)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def encode_text_df(model, tokenizer, df, col_name):
    """
    Encode the text in one DataFrame column into normalized embeddings.

    :param model: Model exposing an ``encode_text`` method.
    :type model: torch.nn.Module
    :param tokenizer: Callable converting raw text into model-ready token tensors.
    :type tokenizer: callable
    :param df: DataFrame holding the text to encode.
    :type df: pandas.DataFrame
    :param col_name: Name of the column in ``df`` containing the text.
    :type col_name: str
    :return: L2-normalized embeddings of shape (number_of_rows, 1, embedding_dim).
    :rtype: torch.Tensor
    """
    text_embeddings = []
    for idx in df.index:
        # Fixed: the original used ``df[df.index == idx][col_name][0]``, which
        # relies on deprecated positional-fallback indexing and raises
        # KeyError for integer indexes that do not contain the label 0.
        # Label-based ``.loc`` is the correct single-cell lookup.
        text = df.loc[idx, col_name]

        # encode_text returns an already-normalized (1, embedding_dim) tensor.
        text_embeddings.append(encode_text(model, tokenizer, text))

    # Fixed: stack the torch tensors directly instead of round-tripping the
    # list through numpy; the resulting (N, 1, embedding_dim) shape is identical.
    stacked = torch.stack(text_embeddings)

    # Re-normalize along the feature axis (no-op for unit vectors; kept for parity).
    return F.normalize(stacked, p=2, dim=-1)
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def get_pca_by_fit(tar_features, src_features):
    """
    Project target and source features into a shared 3-component PCA space.

    The PCA basis is fit on the (transposed) target features only; the source
    features are projected with the same fit so both sets share one space.

    :param tar_features: Target feature matrix; transposed internally before
        fitting, so its columns become the PCA samples.
    :param src_features: Source feature matrix, transposed the same way.
    :return: Tuple ``(pca_comb_features, pca_comb_features_batch)`` — the
        concatenated PCA coordinates (target rows first), and a 0/1 array
        labeling each row as target (0) or source (1).
    """
    reducer = PCA(n_components=3)

    # Fit only on the target features (transposed).
    fitted = reducer.fit(tar_features.T)

    # Project both sets with the same basis.
    coords_tar = fitted.transform(tar_features.T)
    coords_src = fitted.transform(src_features.T)  # same fit as the target

    # Stack target rows first, then source rows.
    combined = np.concatenate((coords_tar, coords_src))

    # Batch labels: 0 marks target rows, 1 marks source rows.
    batch_labels = np.array([0] * len(coords_tar) + [1] * len(coords_src))

    return combined, batch_labels
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def cap_quantile(weight, cap_max=None, cap_min=None):
    """
    Clip an array of weights at quantile-derived lower/upper bounds.

    :param weight: Numpy array of weights to cap.
    :param cap_max: Upper quantile (e.g. 0.99); values above the corresponding
        quantile value are capped. No upper capping when None.
    :param cap_min: Lower quantile (e.g. 0.01); values below the corresponding
        quantile value are capped. No lower capping when None.
    :return: Numpy array with values clipped to the requested quantile bounds.
    """
    # Fixed: the original unconditionally called np.minimum/np.maximum with
    # the (possibly None) caps, so omitting either cap raised a TypeError
    # instead of skipping that side. Each bound is now applied only when set.
    if cap_max is not None:
        weight = np.minimum(weight, np.quantile(weight, cap_max))

    if cap_min is not None:
        weight = np.maximum(weight, np.quantile(weight, cap_min))

    return weight
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def read_polygons(file_path, slide_id):
    """
    Load polygon annotations for a single slide from a JSON configuration file.

    :param file_path: Path to the JSON file containing polygon configurations.
    :param slide_id: Identifier of the slide whose polygon data should be loaded.
    :return: Tuple ``(polygons, polygon_colors, polygon_thickness)`` where
             ``polygons`` is a list of numpy coordinate arrays, and the other
             two are the per-polygon color and border-thickness lists. All
             three are ``None`` when ``slide_id`` is absent from the file.
    """

    # Parse the full polygon configuration file into a dictionary.
    with open(file_path, 'r') as f:
        configs = json.load(f)

    # Unknown slide: signal "no annotations" rather than raising.
    if slide_id not in configs:
        return None, None, None

    # Collect coordinates, colors, and thicknesses in a single pass over the
    # slide's polygon entries.
    polygons = []
    polygon_colors = []
    polygon_thickness = []
    for entry in configs[slide_id]:
        polygons.append(np.array(entry['coords']))
        polygon_colors.append(entry['color'])
        polygon_thickness.append(entry['thickness'])

    return polygons, polygon_colors, polygon_thickness
|
| 277 |
+
|
| 278 |
+
|
src/requirements.txt
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
anndata==0.10.9
|
| 2 |
+
matplotlib==3.9.2
|
| 3 |
+
numpy==1.25.0
|
| 4 |
+
pandas==2.2.3
|
| 5 |
+
opencv-python==4.10.0.84
|
| 6 |
+
pycpd==2.0.0
|
| 7 |
+
torch==2.3.1
|
| 8 |
+
tangram-sc==1.0.4
|
| 9 |
+
tqdm==4.66.5
|
| 10 |
+
torchvision==0.18.1
|
| 11 |
+
open_clip_torch==2.26.1
|
| 12 |
+
pillow==10.4.0
|
| 13 |
+
ipykernel==6.29.5
|
| 14 |
+
|
src/setup.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import setuptools


# Package metadata for the Loki spatial-transcriptomics toolkit.
setuptools.setup(
    name="loki",  # Distribution name on PyPI
    version="0.0.1",  # Initial release version
    author="Weiqing Chen",
    author_email="wec4005@med.cornell.edu",
    description="The Loki platform offers 5 core functions: tissue alignment, tissue annotation, cell type decomposition, image-transcriptomics retrieval, and ST gene expression prediction",
    packages=setuptools.find_packages(),  # Finds the 'loki' folder automatically
    classifiers=[
        "Programming Language :: Python :: 3",
        # Fixed: "License :: BSD 3-Clause License" is not a valid trove
        # classifier; PyPI rejects unknown classifiers on upload. The
        # canonical classifier for the BSD 3-Clause license is below.
        "License :: OSI Approved :: BSD License",
        "Operating System :: OS Independent",
    ],
    python_requires='>=3.9',  # Minimum supported Python version
    # Pinned runtime dependencies (kept in sync with requirements.txt).
    install_requires=[
        "anndata==0.10.9",
        "matplotlib==3.9.2",
        "numpy==1.25.0",
        "pandas==2.2.3",
        "opencv-python==4.10.0.84",
        "pycpd==2.0.0",
        "torch==2.3.1",
        "tangram-sc==1.0.4",
        "tqdm==4.66.5",
        "torchvision==0.18.1",
        "open_clip_torch==2.26.1",
        "pillow==10.4.0",
        "ipykernel==6.29.5",
    ],
)
|