osakemon committed
Commit 1e315b6 · verified · 1 Parent(s): 5c94492

Upload 42 files

Files changed (42)
  1. src/LICENSE +29 -0
  2. src/build/lib/loki/__init__.py +0 -0
  3. src/build/lib/loki/align.py +568 -0
  4. src/build/lib/loki/annotate.py +102 -0
  5. src/build/lib/loki/decompose.py +143 -0
  6. src/build/lib/loki/plot.py +435 -0
  7. src/build/lib/loki/plotting.py +435 -0
  8. src/build/lib/loki/predex.py +25 -0
  9. src/build/lib/loki/preprocess.py +324 -0
  10. src/build/lib/loki/retrieve.py +28 -0
  11. src/build/lib/loki/utilities.py +159 -0
  12. src/build/lib/loki/utils.py +278 -0
  13. src/dist/loki-0.0.1-py3-none-any.whl +0 -0
  14. src/dist/loki-0.0.1.tar.gz +3 -0
  15. src/loki.egg-info/PKG-INFO +23 -0
  16. src/loki.egg-info/SOURCES.txt +16 -0
  17. src/loki.egg-info/dependency_links.txt +1 -0
  18. src/loki.egg-info/requires.txt +13 -0
  19. src/loki.egg-info/top_level.txt +1 -0
  20. src/loki/__init__.py +0 -0
  21. src/loki/__pycache__/__init__.cpython-310.pyc +0 -0
  22. src/loki/__pycache__/__init__.cpython-39.pyc +0 -0
  23. src/loki/__pycache__/align.cpython-39.pyc +0 -0
  24. src/loki/__pycache__/annotate.cpython-39.pyc +0 -0
  25. src/loki/__pycache__/decompose.cpython-39.pyc +0 -0
  26. src/loki/__pycache__/deconv.cpython-39.pyc +0 -0
  27. src/loki/__pycache__/plot.cpython-39.pyc +0 -0
  28. src/loki/__pycache__/predex.cpython-39.pyc +0 -0
  29. src/loki/__pycache__/preprocess.cpython-39.pyc +0 -0
  30. src/loki/__pycache__/retrieve.cpython-39.pyc +0 -0
  31. src/loki/__pycache__/utils.cpython-39.pyc +0 -0
  32. src/loki/align.py +568 -0
  33. src/loki/annotate.py +102 -0
  34. src/loki/decompose.py +143 -0
  35. src/loki/plot.py +435 -0
  36. src/loki/predex.py +25 -0
  37. src/loki/preprocess.py +324 -0
  38. src/loki/requirements.txt +14 -0
  39. src/loki/retrieve.py +28 -0
  40. src/loki/utils.py +278 -0
  41. src/requirements.txt +14 -0
  42. src/setup.py +32 -0
src/LICENSE ADDED
@@ -0,0 +1,29 @@
+ BSD 3-Clause License
+
+ Copyright (c) 2025, Wang Lab
+ All rights reserved.
+
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions are met:
+
+ 1. Redistributions of source code must retain the above copyright notice, this
+    list of conditions and the following disclaimer.
+
+ 2. Redistributions in binary form must reproduce the above copyright notice,
+    this list of conditions and the following disclaimer in the documentation
+    and/or other materials provided with the distribution.
+
+ 3. Neither the name of the copyright holder nor the names of its
+    contributors may be used to endorse or promote products derived from
+    this software without specific prior written permission.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
src/build/lib/loki/__init__.py ADDED
File without changes
src/build/lib/loki/align.py ADDED
@@ -0,0 +1,568 @@
+ import pycpd
+ from builtins import super
+ import numbers
+ from warnings import warn
+ import numpy as np
+ import cv2
+
+ class EMRegistration(object):
+     """
+     Expectation maximization point cloud registration.
+     Adapted from Pure Numpy Implementation of the Coherent Point Drift Algorithm:
+     https://github.com/siavashk/pycpd
+
+     Attributes
+     ----------
+     X: numpy array
+         NxD array of target points.
+
+     Y: numpy array
+         MxD array of source points.
+
+     TY: numpy array
+         MxD array of transformed source points.
+
+     sigma2: float (positive)
+         Initial variance of the Gaussian mixture model.
+
+     N: int
+         Number of target points.
+
+     M: int
+         Number of source points.
+
+     D: int
+         Dimensionality of source and target points.
+
+     iteration: int
+         The current iteration throughout registration.
+
+     max_iterations: int
+         Registration will terminate once the algorithm has taken this
+         many iterations.
+
+     tolerance: float (positive)
+         Registration will terminate once the difference between
+         consecutive objective function values falls within this tolerance.
+
+     w: float (between 0 and 1)
+         Contribution of the uniform distribution to account for outliers.
+         Valid values span 0 (inclusive) and 1 (exclusive).
+
+     q: float
+         The objective function value that represents the misalignment between source
+         and target point clouds.
+
+     diff: float (positive)
+         The absolute difference between the current and previous objective function values.
+
+     P: numpy array
+         MxN array of probabilities.
+         P[m, n] represents the probability that the m-th source point
+         corresponds to the n-th target point.
+
+     Pt1: numpy array
+         Nx1 column array.
+         Multiplication result between the transpose of P and a column vector of all 1s.
+
+     P1: numpy array
+         Mx1 column array.
+         Multiplication result between P and a column vector of all 1s.
+
+     Np: float (positive)
+         The sum of all elements in P.
+     """
+
+     def __init__(self, X, Y, sigma2=None, max_iterations=None, tolerance=None, w=None, *args, **kwargs):
+         if type(X) is not np.ndarray or X.ndim != 2:
+             raise ValueError(
+                 "The target point cloud (X) must be a 2D numpy array.")
+
+         if type(Y) is not np.ndarray or Y.ndim != 2:
+             raise ValueError(
+                 "The source point cloud (Y) must be a 2D numpy array.")
+
+         if X.shape[1] != Y.shape[1]:
+             raise ValueError(
+                 "Both point clouds need to have the same number of dimensions.")
+
+         if sigma2 is not None and (not isinstance(sigma2, numbers.Number) or sigma2 <= 0):
+             raise ValueError(
+                 "Expected a positive value for sigma2 instead got: {}".format(sigma2))
+
+         if max_iterations is not None and (not isinstance(max_iterations, numbers.Number) or max_iterations < 0):
+             raise ValueError(
+                 "Expected a positive integer for max_iterations instead got: {}".format(max_iterations))
+         elif isinstance(max_iterations, numbers.Number) and not isinstance(max_iterations, int):
+             warn("Received a non-integer value for max_iterations: {}. Casting to integer.".format(max_iterations))
+             max_iterations = int(max_iterations)
+
+         if tolerance is not None and (not isinstance(tolerance, numbers.Number) or tolerance < 0):
+             raise ValueError(
+                 "Expected a positive float for tolerance instead got: {}".format(tolerance))
+
+         if w is not None and (not isinstance(w, numbers.Number) or w < 0 or w >= 1):
+             raise ValueError(
+                 "Expected a value between 0 (inclusive) and 1 (exclusive) for w instead got: {}".format(w))
+
+         self.X = X
+         self.Y = Y
+         self.TY = Y
+         self.sigma2 = initialize_sigma2(X, Y) if sigma2 is None else sigma2
+         (self.N, self.D) = self.X.shape
+         (self.M, _) = self.Y.shape
+         self.tolerance = 0.001 if tolerance is None else tolerance
+         self.w = 0.0 if w is None else w
+         self.max_iterations = 100 if max_iterations is None else max_iterations
+         self.iteration = 0
+         self.diff = np.inf
+         self.q = np.inf
+         self.P = np.zeros((self.M, self.N))
+         self.Pt1 = np.zeros((self.N, ))
+         self.P1 = np.zeros((self.M, ))
+         self.PX = np.zeros((self.M, self.D))
+         self.Np = 0
+
+     def register(self, callback=lambda **kwargs: None):
+         """
+         Perform the EM registration.
+
+         Attributes
+         ----------
+         callback: function
+             A function that will be called after each iteration.
+             Can be used to visualize the registration process.
+
+         Returns
+         -------
+         self.TY: numpy array
+             MxD array of transformed source points.
+
+         registration_parameters:
+             Returned params dependent on registration method used.
+         """
+         self.transform_point_cloud()
+         while self.iteration < self.max_iterations and self.diff > self.tolerance:
+             self.iterate()
+             if callable(callback):
+                 kwargs = {'iteration': self.iteration,
+                           'error': self.q, 'X': self.X, 'Y': self.TY}
+                 callback(**kwargs)
+
+         return self.TY, self.get_registration_parameters()
+
+     def get_registration_parameters(self):
+         """
+         Placeholder for child classes.
+         """
+         raise NotImplementedError(
+             "Registration parameters should be defined in child classes.")
+
+     def update_transform(self):
+         """
+         Placeholder for child classes.
+         """
+         raise NotImplementedError(
+             "Updating transform parameters should be defined in child classes.")
+
+     def transform_point_cloud(self):
+         """
+         Placeholder for child classes.
+         """
+         raise NotImplementedError(
+             "Updating the source point cloud should be defined in child classes.")
+
+     def update_variance(self):
+         """
+         Placeholder for child classes.
+         """
+         raise NotImplementedError(
+             "Updating the Gaussian variance for the mixture model should be defined in child classes.")
+
+     def iterate(self):
+         """
+         Perform one iteration of the EM algorithm.
+         """
+         self.expectation()
+         self.maximization()
+         self.iteration += 1
+
+     def expectation(self):
+         """
+         Compute the expectation step of the EM algorithm.
+         """
+         P = np.sum((self.X[None, :, :] - self.TY[:, None, :])**2, axis=2)  # (M, N)
+         P = np.exp(-P/(2*self.sigma2))
+         c = (2*np.pi*self.sigma2)**(self.D/2)*self.w/(1. - self.w)*self.M/self.N
+
+         den = np.sum(P, axis=0, keepdims=True)  # (1, N)
+         den = np.clip(den, np.finfo(self.X.dtype).eps, None) + c
+
+         self.P = np.divide(P, den)
+         self.Pt1 = np.sum(self.P, axis=0)
+         self.P1 = np.sum(self.P, axis=1)
+         self.Np = np.sum(self.P1)
+         self.PX = np.matmul(self.P, self.X)
+
+     def maximization(self):
+         """
+         Compute the maximization step of the EM algorithm.
+         """
+         self.update_transform()
+         self.transform_point_cloud()
+         self.update_variance()
+
+
+ class DeformableRegistration(EMRegistration):
+     """
+     Deformable registration.
+     Adapted from Pure Numpy Implementation of the Coherent Point Drift Algorithm:
+     https://github.com/siavashk/pycpd
+
+     Attributes
+     ----------
+     alpha: float (positive)
+         Represents the trade-off between the goodness of maximum likelihood fit and regularization.
+
+     beta: float (positive)
+         Width of the Gaussian kernel.
+
+     low_rank: bool
+         Whether to use low rank approximation.
+
+     num_eig: int
+         Number of eigenvectors to use in the low-rank calculation.
+     """
+
+     def __init__(self, alpha=None, beta=None, low_rank=False, num_eig=100, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         if alpha is not None and (not isinstance(alpha, numbers.Number) or alpha <= 0):
+             raise ValueError(
+                 "Expected a positive value for regularization parameter alpha. Instead got: {}".format(alpha))
+
+         if beta is not None and (not isinstance(beta, numbers.Number) or beta <= 0):
+             raise ValueError(
+                 "Expected a positive value for the width of the coherent Gaussian kernel. Instead got: {}".format(beta))
+
+         self.alpha = 2 if alpha is None else alpha
+         self.beta = 2 if beta is None else beta
+         self.W = np.zeros((self.M, self.D))
+         self.G = gaussian_kernel(self.Y, self.beta)
+         self.low_rank = low_rank
+         self.num_eig = num_eig
+         if self.low_rank is True:
+             self.Q, self.S = low_rank_eigen(self.G, self.num_eig)
+             self.inv_S = np.diag(1./self.S)
+             self.S = np.diag(self.S)
+             self.E = 0.
+
+     def update_transform(self):
+         """
+         Calculate a new estimate of the deformable transformation.
+         See Eq. 22 of https://arxiv.org/pdf/0905.2635.pdf.
+         """
+         if self.low_rank is False:
+             A = np.dot(np.diag(self.P1), self.G) + \
+                 self.alpha * self.sigma2 * np.eye(self.M)
+             B = self.PX - np.dot(np.diag(self.P1), self.Y)
+             self.W = np.linalg.solve(A, B)
+
+         elif self.low_rank is True:
+             # Matlab code equivalent can be found here:
+             # https://github.com/markeroon/matlab-computer-vision-routines/tree/master/third_party/CoherentPointDrift
+             dP = np.diag(self.P1)
+             dPQ = np.matmul(dP, self.Q)
+             F = self.PX - np.matmul(dP, self.Y)
+
+             self.W = 1 / (self.alpha * self.sigma2) * (F - np.matmul(dPQ, (
+                 np.linalg.solve((self.alpha * self.sigma2 * self.inv_S + np.matmul(self.Q.T, dPQ)),
+                                 (np.matmul(self.Q.T, F))))))
+             QtW = np.matmul(self.Q.T, self.W)
+             self.E = self.E + self.alpha / 2 * np.trace(np.matmul(QtW.T, np.matmul(self.S, QtW)))
+
+     def transform_point_cloud(self, Y=None):
+         """
+         Update a point cloud using the new estimate of the deformable transformation.
+
+         Attributes
+         ----------
+         Y: numpy array, optional
+             Array of points to transform - use to predict on new set of points.
+             Best for predicting on new points not used to run initial registration.
+             If None, self.Y is used.
+
+         Returns
+         -------
+         If Y is None, returns None.
+         Otherwise, returns the transformed Y.
+         """
+         self.W[:, 2:] = 0
+         if Y is not None:
+             G = gaussian_kernel(X=Y, beta=self.beta, Y=self.Y)
+             return Y + np.dot(G, self.W)
+         else:
+             if self.low_rank is False:
+                 self.TY = self.Y + np.dot(self.G, self.W)
+
+             elif self.low_rank is True:
+                 self.TY = self.Y + np.matmul(self.Q, np.matmul(self.S, np.matmul(self.Q.T, self.W)))
+             return
+
+     def update_variance(self):
+         """
+         Update the variance of the mixture model using the new estimate of the deformable transformation.
+         See the update rule for sigma2 in Eq. 23 of https://arxiv.org/pdf/0905.2635.pdf.
+         """
+         qprev = self.sigma2
+
+         # The original CPD paper does not explicitly calculate the objective functional.
+         # This functional will include terms from both the negative log-likelihood and
+         # the Gaussian kernel used for regularization.
+         self.q = np.inf
+
+         xPx = np.dot(np.transpose(self.Pt1), np.sum(
+             np.multiply(self.X, self.X), axis=1))
+         yPy = np.dot(np.transpose(self.P1), np.sum(
+             np.multiply(self.TY, self.TY), axis=1))
+         trPXY = np.sum(np.multiply(self.TY, self.PX))
+
+         self.sigma2 = (xPx - 2 * trPXY + yPy) / (self.Np * self.D)
+
+         if self.sigma2 <= 0:
+             self.sigma2 = self.tolerance / 10
+
+         # Here we use the difference between the current and previous
+         # estimate of the variance as a proxy to test for convergence.
+         self.diff = np.abs(self.sigma2 - qprev)
+
+     def get_registration_parameters(self):
+         """
+         Return the current estimate of the deformable transformation parameters.
+
+         Returns
+         -------
+         self.G: numpy array
+             Gaussian kernel matrix.
+
+         self.W: numpy array
+             Deformable transformation matrix.
+         """
+         return self.G, self.W
+
+
+ def initialize_sigma2(X, Y):
+     """
+     Initialize the variance (sigma2).
+
+     Parameters
+     ----------
+     X: numpy array
+         NxD array of points for target.
+
+     Y: numpy array
+         MxD array of points for source.
+
+     Returns
+     -------
+     sigma2: float
+         Initial variance.
+     """
+     (N, D) = X.shape
+     (M, _) = Y.shape
+     diff = X[None, :, :] - Y[:, None, :]
+     err = diff ** 2
+     return np.sum(err) / (D * M * N)
+
+
+ def gaussian_kernel(X, beta, Y=None):
+     """
+     Computes a Gaussian (RBF) kernel matrix between two sets of vectors.
+
+     :param X: A numpy array of shape (n_samples_X, n_features) representing the first set of vectors.
+     :param beta: The standard deviation parameter for the Gaussian kernel. It controls the spread of the kernel.
+     :param Y: An optional numpy array of shape (n_samples_Y, n_features) representing the second set of vectors.
+               If None, the function computes the kernel between `X` and itself (i.e., the Gram matrix).
+     :return: A numpy array of shape (n_samples_X, n_samples_Y) representing the Gaussian kernel matrix.
+              Each element (i, j) in the matrix is computed as:
+              `exp(-||X[i] - Y[j]||^2 / (2 * beta^2))`
+     """
+
+     # If Y is not provided, use X for both sets, computing the kernel matrix between X and itself
+     if Y is None:
+         Y = X
+
+     # Compute the difference tensor between each pair of vectors in X and Y
+     # The resulting shape is (n_samples_X, n_samples_Y, n_features)
+     diff = X[:, None, :] - Y[None, :, :]
+
+     # Square the differences element-wise
+     diff = np.square(diff)
+
+     # Sum the squared differences across the feature dimension (axis 2) to get squared Euclidean distances
+     # The resulting shape is (n_samples_X, n_samples_Y)
+     diff = np.sum(diff, axis=2)
+
+     # Apply the Gaussian (RBF) kernel formula: exp(-||X[i] - Y[j]||^2 / (2 * beta^2))
+     kernel_matrix = np.exp(-diff / (2 * beta**2))
+
+     return kernel_matrix
+
+
+ def low_rank_eigen(G, num_eig):
+     """
+     Calculate the top `num_eig` eigenvectors and eigenvalues of a given Gaussian matrix G.
+     This function is useful for dimensionality reduction or when a low-rank approximation is needed.
+
+     :param G: A square matrix (numpy array) for which the eigen decomposition is to be performed.
+     :param num_eig: The number of top eigenvectors and eigenvalues to return, based on the magnitude of eigenvalues.
+     :return: A tuple containing:
+              - Q: A numpy array with shape (n, num_eig) containing the top `num_eig` eigenvectors of the matrix `G`.
+                Each column in `Q` corresponds to an eigenvector.
+              - S: A numpy array of shape (num_eig,) containing the top `num_eig` eigenvalues of the matrix `G`.
+     """
+
+     # Perform eigen decomposition on matrix G
+     # `S` will contain all the eigenvalues, and `Q` will contain the corresponding eigenvectors
+     S, Q = np.linalg.eigh(G)
+
+     # Sort eigenvalues in descending order based on their absolute values
+     # Get the indices of the top `num_eig` largest eigenvalues
+     eig_indices = list(np.argsort(np.abs(S))[::-1][:num_eig])
+
+     # Select the corresponding top eigenvectors based on the sorted indices
+     Q = Q[:, eig_indices]  # Q now contains the top `num_eig` eigenvectors
+
+     # Select the top `num_eig` eigenvalues based on the sorted indices
+     S = S[eig_indices]  # S now contains the top `num_eig` eigenvalues
+
+     return Q, S
+
+
+ def find_homography_translation_rotation(src_points, dst_points):
+     """
+     Find the homography between two sets of coordinates with only translation and rotation.
+
+     :param src_points: A numpy array of shape (n, 2) containing source coordinates.
+     :param dst_points: A numpy array of shape (n, 2) containing destination coordinates.
+     :return: A 3x3 homography matrix.
+     """
+     # Ensure the points are in the correct shape
+     assert src_points.shape == dst_points.shape
+     assert src_points.shape[1] == 2
+
+     # Calculate the centroids of the point sets
+     src_centroid = np.mean(src_points, axis=0)
+     dst_centroid = np.mean(dst_points, axis=0)
+
+     # Center the points around the centroids
+     centered_src_points = src_points - src_centroid
+     centered_dst_points = dst_points - dst_centroid
+
+     # Calculate the covariance matrix
+     H = np.dot(centered_src_points.T, centered_dst_points)
+
+     # Singular Value Decomposition (SVD)
+     U, S, Vt = np.linalg.svd(H)
+
+     # Calculate the rotation matrix
+     R = np.dot(Vt.T, U.T)
+
+     # Ensure a proper rotation matrix (det(R) = 1)
+     if np.linalg.det(R) < 0:
+         Vt[-1, :] *= -1
+         R = np.dot(Vt.T, U.T)
+
+     # Calculate the translation vector
+     t = dst_centroid - np.dot(R, src_centroid)
+
+     # Construct the homography matrix
+     homography_matrix = np.eye(3)
+     homography_matrix[0:2, 0:2] = R
+     homography_matrix[0:2, 2] = t
+
+     return homography_matrix
+
+
+ def apply_homography(coordinates, H):
+     """
+     Apply a 3x3 homography matrix to 2D coordinates.
+
+     :param coordinates: A numpy array of shape (n, 2) containing 2D coordinates.
+     :param H: A numpy array of shape (3, 3) representing the homography matrix.
+     :return: A numpy array of shape (n, 2) with transformed coordinates.
+     """
+     # Convert (x, y) to homogeneous coordinates (x, y, 1)
+     n = coordinates.shape[0]
+     homogeneous_coords = np.hstack((coordinates, np.ones((n, 1))))
+
+     # Apply the homography matrix
+     transformed_homogeneous = np.dot(homogeneous_coords, H.T)
+
+     # Convert back from homogeneous coordinates (x', y', w') to (x'/w', y'/w')
+     transformed_coords = transformed_homogeneous[:, :2] / transformed_homogeneous[:, [2]]
+
+     return transformed_coords
+
+
+ def align_tissue(ad_tar_coor, ad_src_coor, pca_comb_features, src_img, alpha=0.5):
+     """
+     Aligns the source coordinates to the target coordinates using Coherent Point Drift (CPD)
+     registration, and applies a homography transformation to warp the source coordinates accordingly.
+
+     :param ad_tar_coor: Numpy array of target coordinates to which the source will be aligned.
+     :param ad_src_coor: Numpy array of source coordinates that will be aligned to the target.
+     :param pca_comb_features: PCA-combined feature matrix used as additional features for the alignment process.
+     :param src_img: Source image to be warped based on the alignment.
+     :param alpha: Regularization parameter for CPD registration, default is 0.5.
+     :return:
+         - cpd_coor: The new source coordinates after CPD alignment.
+         - homo_coor: The source coordinates after applying the homography transformation.
+         - aligned_image: The source image warped based on the homography transformation.
+     """
+
+     # Normalize target and source coordinates to the range [0, 1]
+     ad_tar_coor_z = (ad_tar_coor - ad_tar_coor.min()) / (ad_tar_coor.max() - ad_tar_coor.min())
+     ad_src_coor_z = (ad_src_coor - ad_src_coor.min()) / (ad_src_coor.max() - ad_src_coor.min())
+
+     # Normalize PCA-combined features to the range [0, 1]
+     pca_comb_features_z = (pca_comb_features - pca_comb_features.min()) / (pca_comb_features.max() - pca_comb_features.min())
+
+     # Concatenate spatial and PCA-combined features for target and source
+     target = np.concatenate((ad_tar_coor_z, pca_comb_features_z[:ad_tar_coor.shape[0], :2]), axis=1)
+     source = np.concatenate((ad_src_coor_z, pca_comb_features_z[ad_tar_coor.shape[0]:, :2]), axis=1)
+
+     # Initialize and run the CPD registration (deformable with regularization)
+     reg = DeformableRegistration(X=target, Y=source, low_rank=True,
+                                  alpha=alpha,
+                                  max_iterations=int(1e9), tolerance=1e-9)
+
+     TY = reg.register()[0]  # TY contains the transformed source points
+
+     # Rescale the CPD-aligned coordinates back to the original range of target coordinates
+     cpd_coor = TY[:, :2] * (ad_tar_coor.max() - ad_tar_coor.min()) + ad_tar_coor.min()
+
+     # Find homography transformation based on CPD-aligned coordinates and apply it
+     h = find_homography_translation_rotation(ad_src_coor, cpd_coor)
+     homo_coor = apply_homography(ad_src_coor, h)
+
+     # Warp the source image using the computed homography
+     aligned_image = cv2.warpPerspective(src_img, h, (src_img.shape[1], src_img.shape[0]))
+
+     # Return the CPD-aligned coordinates, the homography-transformed coordinates, and the warped image
+     return cpd_coor, homo_coor, aligned_image
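The rigid-alignment step in `align.py` (`find_homography_translation_rotation` plus `apply_homography`) is a Kabsch-style best-fit rotation and translation packed into a 3x3 matrix. A minimal, self-contained sketch of that idea — the helper names below are restatements for illustration, not the committed API — recovers a known rotation-plus-shift:

```python
import numpy as np

def rigid_homography(src, dst):
    # Best-fit rotation + translation between matched 2D point sets,
    # via SVD of the cross-covariance (the same recipe as in align.py).
    src_c, dst_c = src.mean(axis=0), dst.mean(axis=0)
    H = (src - src_c).T @ (dst - dst_c)
    U, _, Vt = np.linalg.svd(H)
    R = Vt.T @ U.T
    if np.linalg.det(R) < 0:  # guard against reflections
        Vt[-1, :] *= -1
        R = Vt.T @ U.T
    M = np.eye(3)
    M[:2, :2] = R
    M[:2, 2] = dst_c - R @ src_c
    return M

def apply_h(coords, M):
    # Apply a 3x3 homography to (n, 2) points via homogeneous coordinates.
    h = np.hstack([coords, np.ones((len(coords), 1))]) @ M.T
    return h[:, :2] / h[:, [2]]

# A point set and a rotated + shifted copy of it
theta = np.pi / 6
R = np.array([[np.cos(theta), -np.sin(theta)],
              [np.sin(theta),  np.cos(theta)]])
src = np.random.default_rng(0).random((50, 2))
dst = src @ R.T + np.array([3.0, -1.0])

M = rigid_homography(src, dst)
recovered = apply_h(src, M)
print(np.allclose(recovered, dst, atol=1e-8))
```

Because the fit is constrained to rotation and translation only, the recovered matrix maps the source points back onto the destination set essentially exactly for noise-free input.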
src/build/lib/loki/annotate.py ADDED
@@ -0,0 +1,102 @@
+ import numpy as np
+ import torch
+ from torch.nn import functional as F
+ import os
+ import scanpy as sc
+ import json
+ import cv2
+
+
+ def annotate_with_bulk(img_features, bulk_features, normalize=True, T=1, tensor=False):
+     """
+     Annotates a tissue image with similarity scores between image features and bulk RNA-seq features.
+
+     :param img_features: Feature matrix representing histopathology image features.
+     :param bulk_features: Feature vector representing bulk RNA-seq features.
+     :param normalize: Whether to normalize similarity scores, default=True.
+     :param T: Temperature parameter to control the sharpness of the softmax distribution. Higher values result in a smoother distribution.
+     :param tensor: Whether the features are torch tensors, default=False.
+
+     :return: An array or tensor containing the normalized similarity scores.
+     """
+
+     if tensor:
+         # Compute cosine similarity between image features and bulk RNA-seq features
+         cosine_similarity = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
+         similarity = cosine_similarity(img_features, bulk_features.unsqueeze(0))  # Shape: [n]
+
+         # Optional normalization by the square root of the feature dimensionality
+         if normalize:
+             normalization_factor = torch.sqrt(torch.tensor([bulk_features.shape[0]], dtype=torch.float))  # e.g. sqrt(768)
+             similarity = similarity / normalization_factor
+
+         # Reshape and apply temperature scaling for softmax
+         similarity = similarity.unsqueeze(0)  # Shape: [1, n]
+         similarity = similarity / T  # Control distribution sharpness
+
+         # Convert similarity scores to a probability distribution using softmax
+         similarity = torch.nn.functional.softmax(similarity, dim=-1)  # Shape: [1, n]
+
+     else:
+         # Compute dot-product similarity for non-tensor mode
+         similarity = np.dot(img_features.T, bulk_features)
+
+         # Apply a softmax normalization, shifting by the maximum for numerical stability
+         max_similarity = np.max(similarity)
+         similarity = np.exp(similarity - max_similarity) / np.sum(np.exp(similarity - max_similarity))
+
+         # Normalize similarity scores to the [0, 1] range for interpretation
+         similarity = (similarity - np.min(similarity)) / (np.max(similarity) - np.min(similarity))
+
+     return similarity
+
+
+ def annotate_with_marker_genes(classes, image_embeddings, all_text_embeddings):
+     """
+     Annotates a tissue image with similarity scores between image features and marker gene features.
+
+     :param classes: A list or array of tissue type labels.
+     :param image_embeddings: A numpy array or torch tensor of image embeddings (shape: [n_images, embedding_dim]).
+     :param all_text_embeddings: A numpy array or torch tensor of text embeddings of the marker genes
+                                 (shape: [n_classes, embedding_dim]).
+
+     :return:
+         - dot_similarity: The matrix of dot product similarities between image embeddings and text embeddings.
+         - pred_class: The predicted tissue type for the image based on the highest similarity score.
+     """
+
+     # Calculate dot product similarity between image embeddings and text embeddings
+     # This results in a similarity matrix of shape [n_images, n_classes]
+     dot_similarity = image_embeddings @ all_text_embeddings.T
+
+     # Find the class with the highest similarity for each image
+     # Use argmax to identify the index of the highest similarity score
+     pred_class = classes[dot_similarity.argmax()]
+
+     return dot_similarity, pred_class
+
+
+ def load_image_annotation(image_path):
+     """
+     Loads an image with annotation.
+
+     :param image_path: The file path to the image.
+
+     :return: The processed image as a uint8 array with channels swapped from OpenCV's BGR order.
+     """
+
+     # Load the image from the specified file path using OpenCV (BGR channel order)
+     image = cv2.imread(image_path)
+
+     # Swap the channel order; OpenCV loads images as BGR, and RGB2BGR/BGR2RGB perform the same swap
+     image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
+
+     # Ensure the image is of type uint8 for downstream image processing libraries
+     image = image.astype(np.uint8)
+
+     return image
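The non-tensor branch of `annotate_with_bulk` reduces to a dot product, a numerically stable softmax, and min-max scaling. A standalone sketch of just that arithmetic — the array sizes here are hypothetical toy dimensions, not real OmiCLIP embeddings:

```python
import numpy as np

def bulk_similarity(img_features, bulk_features):
    # Mirrors the non-tensor path of annotate_with_bulk:
    # dot-product similarity, stable softmax, then min-max scaling to [0, 1].
    sim = img_features.T @ bulk_features      # one score per image patch
    sim = np.exp(sim - sim.max())             # shift by max for stability
    sim = sim / sim.sum()                     # softmax over patches
    return (sim - sim.min()) / (sim.max() - sim.min())

rng = np.random.default_rng(1)
img_features = rng.random((768, 5))   # 768-dim embeddings for 5 patches
bulk_features = rng.random(768)       # one bulk RNA-seq embedding

scores = bulk_similarity(img_features, bulk_features)
print(scores.shape)  # (5,)
```

After min-max scaling, the least-similar patch scores exactly 0 and the most-similar exactly 1, which is what makes the scores directly comparable across a slide.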
src/build/lib/loki/decompose.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ import pandas as pd
+ import tangram as tg
+ import numpy as np
+ import torch
+ import anndata
+ from sklearn.decomposition import PCA
+ from sklearn.neighbors import NearestNeighbors
+
+
+ def generate_feature_ad(ad_expr, feature_path, sc=False):
+     """
+     Generates an AnnData object with OmiCLIP text or image embeddings.
+
+     :param ad_expr: AnnData object containing metadata for the dataset.
+     :param feature_path: Path to the CSV file containing the features to be loaded.
+     :param sc: Boolean flag indicating whether to copy single-cell metadata or ST metadata. Default is False (ST).
+     :return: A new AnnData object with the loaded features and relevant metadata from ad_expr.
+     """
+
+     # Load features from the CSV file, keeping only the columns that match the cells/spots in ad_expr.obs.index
+     features = pd.read_csv(feature_path, index_col=0)[ad_expr.obs.index]
+
+     # Create a new AnnData object with the features, transposed so that cells/spots are rows
+     feature_ad = anndata.AnnData(features.T)
+
+     # Copy relevant metadata from ad_expr based on the sc flag
+     if sc:
+         # If the data is single-cell (sc), copy the metadata from ad_expr.obs
+         feature_ad.obs = ad_expr.obs.copy()
+     else:
+         # If the data is spatial, copy the 'cell_num' column, the 'spatial' info, and the spatial coordinates
+         feature_ad.obs['cell_num'] = ad_expr.obs['cell_num'].copy()
+         feature_ad.uns['spatial'] = ad_expr.uns['spatial'].copy()
+         feature_ad.obsm['spatial'] = ad_expr.obsm['spatial'].copy()
+
+     return feature_ad
+
+
+ def normalize_percentile(df, cols, min_percentile=5, max_percentile=95):
+     """
+     Clips and normalizes the specified columns of a DataFrame based on percentile thresholds,
+     transforming their values to the [0, 1] range.
+
+     :param df: A pandas DataFrame containing the columns to normalize.
+     :type df: pandas.DataFrame
+     :param cols: A list of column names in `df` that should be normalized.
+     :type cols: list[str]
+     :param min_percentile: The lower percentile used for clipping (defaults to 5).
+     :type min_percentile: float
+     :param max_percentile: The upper percentile used for clipping (defaults to 95).
+     :type max_percentile: float
+     :return: The same DataFrame with the specified columns clipped and normalized.
+     :rtype: pandas.DataFrame
+     """
+
+     # Iterate over each column that needs to be normalized
+     for col in cols:
+         # Compute the lower and upper values at the given percentiles
+         min_val = np.percentile(df[col], min_percentile)
+         max_val = np.percentile(df[col], max_percentile)
+
+         # Clip the column's values between these percentile thresholds
+         df[col] = np.clip(df[col], min_val, max_val)
+
+         # Min-max normalize to scale the clipped values to the [0, 1] range
+         df[col] = (df[col] - min_val) / (max_val - min_val)
+
+     return df
+
+
+ def cell_type_decompose(sc_ad, st_ad, cell_type_col='cell_type', NMS_mode=False, major_types=None, min_percentile=5, max_percentile=95):
+     """
+     Performs cell type decomposition on spatial data (ST or image) with single-cell data.
+
+     :param sc_ad: AnnData object containing single-cell metadata.
+     :param st_ad: AnnData object containing spatial (ST or image) metadata.
+     :param cell_type_col: The column name in `sc_ad.obs` that contains cell type annotations. Default is 'cell_type'.
+     :param NMS_mode: If True, normalize the predicted fractions of `major_types` and keep only the top-scoring type per spot (non-maximum suppression). Default is False.
+     :param major_types: List of major cell types to keep when `NMS_mode` is True.
+     :param min_percentile: Lower percentile used when normalizing predictions in NMS mode. Default is 5.
+     :param max_percentile: Upper percentile used when normalizing predictions in NMS mode. Default is 95.
+     :return: The spatial AnnData object with projected cell type annotations.
+     """
+
+     # Preprocessing with Tangram (tg): match genes between the single-cell and spatial data
+     tg.pp_adatas(sc_ad, st_ad, genes=None)
+
+     # Map single-cell data to spatial data using Tangram's "map_cells_to_space" function
+     ad_map = tg.map_cells_to_space(
+         sc_ad, st_ad,
+         mode="clusters",              # Map based on clusters (cell types)
+         cluster_label=cell_type_col,  # Column in `sc_ad.obs` representing cell type
+         device='cpu',                 # Run on CPU (or 'cuda' if a GPU is available)
+         scale=False,                  # Don't scale data (can be set to True if needed)
+         density_prior='uniform',      # Assume a uniform prior over spot cell densities
+         random_state=10,              # Set random state for reproducibility
+         verbose=False,                # Disable verbose output for cleaner logging
+     )
+
+     # Project cell type annotations from the single-cell data onto the spatial data
+     tg.project_cell_annotations(ad_map, st_ad, annotation=cell_type_col)
+
+     if NMS_mode:
+         # Clip and normalize the predicted fractions of the major cell types
+         st_ad.obs = normalize_percentile(st_ad.obsm['tangram_ct_pred'], major_types, min_percentile, max_percentile)
+
+         st_ad_binary = st_ad.obsm['tangram_ct_pred'][major_types].copy()
+         # Retain the max value in each row and set the rest to 0
+         st_ad.obs[major_types] = st_ad_binary.where(st_ad_binary.eq(st_ad_binary.max(axis=1), axis=0), other=0)
+
+     return st_ad  # Return the spatial AnnData object with the projected annotations
+
+
+ def assign_cells_to_spots(cell_locs, spot_locs, patch_size=16):
+     """
+     Assigns cells to spots based on their spatial coordinates. Each cell within the assignment
+     radius (half the patch size) of a spot is assigned to that spot.
+
+     :param cell_locs: Numpy array of shape (n_cells, 2) with the x, y coordinates of the cells.
+     :param spot_locs: Numpy array of shape (n_spots, 2) with the x, y coordinates of the spots.
+     :param patch_size: The diameter of the spot patch; the radius used for assignment is half of this value.
+     :return: A sparse matrix where each row corresponds to a cell and each column corresponds to a spot.
+              The value is 1 if the cell is assigned to that spot, 0 otherwise.
+     """
+     # Initialize the NearestNeighbors model with a radius equal to half the patch size
+     neigh = NearestNeighbors(radius=patch_size * 0.5)
+
+     # Fit the model on the spot locations
+     neigh.fit(spot_locs)
+
+     # Build the radius-neighbors graph: a sparse matrix with cells as rows and spots as
+     # columns, where a 1 indicates that the cell falls within the spot's radius
+     A = neigh.radius_neighbors_graph(cell_locs, mode='connectivity')
+
+     return A
+
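The clip-and-rescale logic inside `normalize_percentile` can be exercised standalone with plain NumPy. This sketch uses a hypothetical `clip_and_rescale` helper (not part of loki) to show how an extreme outlier is clamped to the percentile band before min-max scaling:

```python
import numpy as np

def clip_and_rescale(values, min_percentile=5, max_percentile=95):
    """Clip values to the given percentile band, then min-max scale to [0, 1]."""
    lo = np.percentile(values, min_percentile)
    hi = np.percentile(values, max_percentile)
    clipped = np.clip(values, lo, hi)
    return (clipped - lo) / (hi - lo)

scores = np.array([0.0, 1.0, 2.0, 3.0, 100.0])  # one extreme outlier
scaled = clip_and_rescale(scores)

# The outlier no longer dominates: every value lands in [0, 1].
print(scaled.min(), scaled.max())
```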
src/build/lib/loki/plot.py ADDED
@@ -0,0 +1,435 @@
+ import matplotlib.pyplot as plt
+ from pathlib import Path
+ import json
+ import cv2
+ from matplotlib import cm
+ import pandas as pd
+ import numpy as np
+ from tqdm import tqdm
+
+
+ def plot_alignment(ad_tar_coor, ad_src_coor, homo_coor, pca_hex_comb, tar_features, shift=300, s=0.8, boundary_line=True):
+     """
+     Plots the target coordinates and the alignment of the source coordinates.
+
+     :param ad_tar_coor: Numpy array of target coordinates to be plotted in the first subplot.
+     :param ad_src_coor: Numpy array of source coordinates to be plotted in the second subplot.
+     :param homo_coor: Numpy array of aligned source coordinates to be plotted in the third subplot.
+     :param pca_hex_comb: Color values (e.g., PCA or hex values) for plotting the coordinates.
+     :param tar_features: Feature matrix for the target, used to split color values between target and source data.
+     :param shift: Padding added to the plot limits around the coordinates for better visualization. Default is 300.
+     :param s: Marker size for the scatter plot points. Default is 0.8.
+     :param boundary_line: Boolean indicating whether to draw boundary lines (horizontal and vertical). Default is True.
+     :return: Displays the alignment plot of the target, source, and aligned source coordinates.
+     """
+
+     # Create a figure with three subplots, adjusting size and resolution
+     plt.figure(figsize=(10, 3), dpi=300)
+
+     # First subplot: target coordinates
+     plt.subplot(1, 3, 1)
+     plt.scatter(ad_tar_coor[:, 0], ad_tar_coor[:, 1], marker='o', s=s, c=pca_hex_comb[:len(tar_features.T)])
+     # Set plot limits based on the minimum and maximum target coordinates, with extra padding from 'shift'
+     plt.xlim([ad_tar_coor.min() - shift, ad_tar_coor.max() + shift])
+     plt.ylim([ad_tar_coor.min() - shift, ad_tar_coor.max() + shift])
+
+     # Second subplot: source coordinates
+     plt.subplot(1, 3, 2)
+     plt.scatter(ad_src_coor[:, 0], ad_src_coor[:, 1], marker='o', s=s, c=pca_hex_comb[len(tar_features.T):])
+     # Use the same limits as the target coordinates so the subplots are directly comparable
+     plt.xlim([ad_tar_coor.min() - shift, ad_tar_coor.max() + shift])
+     plt.ylim([ad_tar_coor.min() - shift, ad_tar_coor.max() + shift])
+
+     # Third subplot: aligned source coordinates
+     plt.subplot(1, 3, 3)
+     plt.scatter(homo_coor[:, 0], homo_coor[:, 1], marker='o', s=s, c=pca_hex_comb[len(tar_features.T):])
+     # Maintain the same plot limits across all subplots for a uniform comparison
+     plt.xlim([ad_tar_coor.min() - shift, ad_tar_coor.max() + shift])
+     plt.ylim([ad_tar_coor.min() - shift, ad_tar_coor.max() + shift])
+
+     # Optionally draw boundary lines at the minimum x and y values of the target coordinates
+     if boundary_line:
+         plt.axvline(x=ad_tar_coor[:, 0].min(), color='black')  # Vertical line at the minimum x of the target coordinates
+         plt.axhline(y=ad_tar_coor[:, 1].min(), color='black')  # Horizontal line at the minimum y of the target coordinates
+
+     # Remove the axis labels and ticks from the current subplot for a cleaner appearance
+     plt.axis('off')
+
+     # Display the plot
+     plt.show()
+
+
+ def plot_alignment_with_img(ad_tar_coor, ad_src_coor, homo_coor, tar_img, src_img, aligned_image, pca_hex_comb, tar_features):
+     """
+     Plots the target coordinates and the alignment of the source coordinates with their respective images in the background.
+
+     :param ad_tar_coor: Numpy array of target coordinates to be plotted in the first and third subplots.
+     :param ad_src_coor: Numpy array of source coordinates to be plotted in the second subplot.
+     :param homo_coor: Numpy array of aligned source coordinates to be plotted in the third subplot.
+     :param tar_img: Image associated with the target coordinates, used as the background in the first subplot.
+     :param src_img: Image associated with the source coordinates, used as the background in the second subplot.
+     :param aligned_image: Image associated with the aligned coordinates, used as the background in the third subplot.
+     :param pca_hex_comb: Color values (e.g., PCA or hex values) for plotting the coordinates.
+     :param tar_features: Feature matrix for the target, used to split color values between target and source data.
+     :return: Displays the alignment plot of the target, source, and aligned source coordinates with their associated images.
+     """
+
+     # Create a figure with three subplots and set the size and resolution
+     plt.figure(figsize=(10, 8), dpi=150)
+
+     # First subplot: target coordinates with the target image as the background
+     plt.subplot(1, 3, 1)
+     # Scatter plot for the target coordinates with transparency and small marker size
+     plt.scatter(ad_tar_coor[:, 0], ad_tar_coor[:, 1], marker='o', alpha=0.8, s=1, c=pca_hex_comb[:len(tar_features.T)])
+     # Overlay the target image with some transparency (alpha = 0.3)
+     plt.imshow(tar_img, origin='lower', alpha=0.3)
+
+     # Second subplot: source coordinates with the source image as the background
+     plt.subplot(1, 3, 2)
+     # Scatter plot for the source coordinates with transparency and small marker size
+     plt.scatter(ad_src_coor[:, 0], ad_src_coor[:, 1], marker='o', alpha=0.8, s=1, c=pca_hex_comb[len(tar_features.T):])
+     # Overlay the source image with some transparency (alpha = 0.3)
+     plt.imshow(src_img, origin='lower', alpha=0.3)
+
+     # Third subplot: both the target and aligned source coordinates with the aligned image as the background
+     plt.subplot(1, 3, 3)
+     # Scatter plot for the target coordinates with lower opacity (alpha = 0.2)
+     plt.scatter(ad_tar_coor[:, 0], ad_tar_coor[:, 1], marker='o', alpha=0.2, s=1, c=pca_hex_comb[:len(tar_features.T)])
+     # Scatter plot for the aligned coordinates with a '+' marker and the same color mapping
+     plt.scatter(homo_coor[:, 0], homo_coor[:, 1], marker='+', s=1, c=pca_hex_comb[len(tar_features.T):])
+     # Overlay the aligned image with some transparency (alpha = 0.3)
+     plt.imshow(aligned_image, origin='lower', alpha=0.3)
+
+     # Turn off the axis of the current subplot for a cleaner visual output
+     plt.axis('off')
+
+     # Display the plots
+     plt.show()
+
+
+ def draw_polygon(image, polygon, color='k', thickness=2):
+     """
+     Draws one or more polygons on the given image.
+
+     :param image: The image on which to draw the polygons (as a numpy array).
+     :param polygon: A list of polygons, where each polygon is a list of (x, y) coordinate tuples.
+     :param color: A string or list of strings representing the color(s) for each polygon.
+                   If a single color is provided, it is applied to all polygons. Default is 'k' (black).
+     :param thickness: An integer or list of integers representing the thickness of the polygon borders.
+                       If a single value is provided, it is applied to all polygons. Default is 2.
+     :return: The image with the polygons drawn on it.
+     """
+
+     # If `color` is a single value (string), expand it to a list with one entry per polygon
+     if not isinstance(color, list):
+         color = [color] * len(polygon)
+
+     # Loop through each polygon in the list, along with its corresponding color
+     for i, poly in enumerate(polygon):
+         # Convert the color from a string format (e.g., 'k' or '#ff0000') to an RGB tuple
+         c = color_string_to_rgb(color[i])
+
+         # Get the thickness for the current polygon (if a list is provided, use the corresponding entry)
+         t = thickness[i] if isinstance(thickness, list) else thickness
+
+         # Convert the polygon coordinates to a numpy array of integers and reshape to
+         # OpenCV's expected input format: (number of points, 1, 2)
+         poly = np.array(poly, np.int32).reshape((-1, 1, 2))
+
+         # Draw the polygon with OpenCV; `isClosed=True` connects the start and end points
+         image = cv2.polylines(image, [poly], isClosed=True, color=c, thickness=t)
+
+     return image
+
+
+ def blend_images(image1, image2, alpha=0.5):
+     """
+     Blends two images together.
+
+     :param image1: Background image, a numpy array of shape (H, W, 3), where H is height, W is width, and 3 is the RGB channels.
+     :param image2: Foreground image, a numpy array of the same shape as image1.
+     :param alpha: Blending factor, a float between 0 and 1, giving the weight of image1 in the blend:
+                   0 shows only image2, 1 shows only image1. Default is 0.5 (equal blending).
+     :return: The blended image, computed per pixel as `alpha * image1 + (1 - alpha) * image2`.
+     """
+
+     # cv2.addWeighted weights image1 by 'alpha' and image2 by '1 - alpha'
+     blended = cv2.addWeighted(image1, alpha, image2, 1 - alpha, 0)
+
+     return blended
+
+
+ def color_string_to_rgb(color_string):
+     """
+     Converts a color string to an RGB tuple.
+
+     :param color_string: A string representing the color, either in hexadecimal form (e.g., '#ff0000') or
+                          a shorthand character for basic colors (e.g., 'k' for black, 'r' for red).
+     :return: A tuple (r, g, b) of integers between 0 and 255.
+     :raises ValueError: If the color string is not recognized.
+     """
+
+     # Remove any spaces in the color string
+     color_string = color_string.replace(' ', '')
+
+     # If the string starts with '#', it's a hexadecimal color, so strip the '#'
+     if color_string.startswith('#'):
+         color_string = color_string[1:]
+     else:
+         # Handle shorthand single-letter color codes by converting them to hex values
+         if color_string == 'k':    # Black
+             color_string = '000000'
+         elif color_string == 'r':  # Red
+             color_string = 'ff0000'
+         elif color_string == 'g':  # Green
+             color_string = '00ff00'
+         elif color_string == 'b':  # Blue
+             color_string = '0000ff'
+         elif color_string == 'w':  # White
+             color_string = 'ffffff'
+         else:
+             # Raise an error if the color string is not recognized
+             raise ValueError(f"Unknown color string {color_string}")
+
+     # Convert each pair of hex digits to the red, green, and blue values
+     r = int(color_string[:2], 16)
+     g = int(color_string[2:4], 16)
+     b = int(color_string[4:], 16)
+
+     # Return the RGB values as a tuple
+     return (r, g, b)
+
+
+ def plot_heatmap(
+     coor,
+     similarity,
+     image_path=None,
+     patch_size=(256, 256),
+     save_path=None,
+     downsize=32,
+     cmap='turbo',
+     smooth=False,
+     boxes=None,
+     box_color='k',
+     box_thickness=2,
+     polygons=None,
+     polygons_color='k',
+     polygons_thickness=2,
+     image_alpha=0.5
+ ):
+     """
+     Plots a heatmap overlaid on an image based on the given coordinates and similarity scores.
+
+     :param coor: Array of coordinates (N, 2), where N is the number of patches to place on the heatmap.
+     :param similarity: Array of similarity scores (N,) corresponding to the coordinates; they are mapped to colors via the colormap.
+     :param image_path: Path to the background image on which the heatmap is overlaid. If None, a blank white background is used.
+     :param patch_size: Size of each patch in pixels. Default is (256, 256).
+     :param save_path: Path to save the heatmap image. If None, the heatmap is returned instead of being saved.
+     :param downsize: Factor by which to downsize the image and patches for faster processing. Default is 32.
+     :param cmap: Colormap used to map the similarity scores to colors. Default is 'turbo'.
+     :param smooth: Boolean indicating whether the heatmap should be smoothed. Not implemented in this version.
+     :param boxes: List of boxes in (x, y, w, h) format. If provided, boxes are drawn on the heatmap.
+     :param box_color: Color of the boxes. Default is black ('k').
+     :param box_thickness: Thickness of the box outlines.
+     :param polygons: List of polygons (N, 2) to draw on the heatmap.
+     :param polygons_color: Color of the polygon outlines. Default is black ('k').
+     :param polygons_thickness: Thickness of the polygon outlines.
+     :param image_alpha: Transparency value (0 to 1) for blending the heatmap with the original image. Default is 0.5.
+     :return:
+         - heatmap: The generated heatmap as a numpy array (RGB).
+         - image: The original image, with overlaid polygons if provided.
+     """
+
+     # Read the background image if provided; otherwise build a blank white canvas
+     # just large enough to contain every patch
+     if image_path is not None:
+         image = cv2.imread(image_path)
+     else:
+         height = max(y for x, y in coor) + patch_size[0]
+         width = max(x for x, y in coor) + patch_size[1]
+         image = np.full((height, width, 3), 255, dtype=np.uint8)
+     image_size = (image.shape[0], image.shape[1])  # Get the size of the image
+     coor = [(x // downsize, y // downsize) for x, y in coor]  # Downsize the coordinates for faster processing
+     patch_size = (patch_size[0] // downsize, patch_size[1] // downsize)  # Downsize the patch size
+
+     # Convert similarity scores to colors using the provided colormap
+     cmap = plt.get_cmap(cmap)  # Get the colormap object
+     norm = plt.Normalize(vmin=similarity.min(), vmax=similarity.max())  # Normalize scores to the color range
+     colors = cmap(norm(similarity))  # Convert the normalized scores to RGBA colors
+
+     # Initialize a blank white heatmap the size of the image
+     heatmap = np.ones((image_size[0], image_size[1], 3)) * 255
+
+     # Place the colored patches on the heatmap according to the coordinates and patch size
+     for i in range(len(coor)):
+         x, y = coor[i]
+         w = colors[i][:3] * 255  # RGB color for the patch, scaled from [0, 1] to [0, 255]
+         w = w.astype(np.uint8)  # Convert the color to uint8
+         heatmap[y:y + patch_size[0], x:x + patch_size[1], :] = w  # Place the patch on the heatmap
+
+     # If image_alpha is greater than 0, blend the heatmap with the original image
+     if image_alpha > 0:
+         image = np.array(image)
+
+         # Pad the image if necessary to match the heatmap size
+         if image.shape[0] < heatmap.shape[0]:
+             pad = heatmap.shape[0] - image.shape[0]
+             image = np.pad(image, ((0, pad), (0, 0), (0, 0)), mode='constant', constant_values=255)
+         if image.shape[1] < heatmap.shape[1]:
+             pad = heatmap.shape[1] - image.shape[1]
+             image = np.pad(image, ((0, 0), (0, pad), (0, 0)), mode='constant', constant_values=255)
+
+         # Swap the image from OpenCV's BGR channel order to RGB so it matches the heatmap, then blend
+         image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
+         image = image.astype(np.uint8)
+         heatmap = heatmap.astype(np.uint8)
+         heatmap = blend_images(heatmap, image, alpha=image_alpha)  # Blend the heatmap and the image
+
+     # If polygons are provided, draw them on the heatmap and image
+     if polygons is not None:
+         polygons = [poly // downsize for poly in polygons]  # Downsize the polygon coordinates
+         image_polygons = draw_polygon(image, polygons, color=polygons_color, thickness=polygons_thickness)
+         heatmap_polygons = draw_polygon(heatmap, polygons, color=polygons_color, thickness=polygons_thickness)
+         return heatmap_polygons, image_polygons  # Heatmap and image with the polygons drawn on them
+     else:
+         return heatmap, image  # Heatmap and image without polygons
+
+
+ def show_images_side_by_side(image1, image2, title1=None, title2=None):
+     """
+     Displays two images side by side in a single figure.
+
+     :param image1: The first image to display (as a numpy array).
+     :param image2: The second image to display (as a numpy array).
+     :param title1: The title for the first image. Default is None (no title).
+     :param title2: The title for the second image. Default is None (no title).
+     :return: Displays the images side by side.
+     """
+
+     # Create a figure with 2 subplots (1 row, 2 columns) and set the figure size
+     fig, ax = plt.subplots(1, 2, figsize=(15, 8))
+
+     # Display each image on its own subplot
+     ax[0].imshow(image1)
+     ax[1].imshow(image2)
+
+     # Set the titles (if provided)
+     ax[0].set_title(title1)
+     ax[1].set_title(title2)
+
+     # Remove axis labels and ticks for both images to give a cleaner look
+     ax[0].axis('off')
+     ax[1].axis('off')
+
+     # Show the final figure with both images displayed side by side
+     plt.show()
+
+
+ def plot_img_with_annotation(fullres_img, roi_polygon, linewidth, xlim, ylim):
+     """
+     Plots an image with polygons overlaid.
+
+     :param fullres_img: The full-resolution image to display (as a numpy array).
+     :param roi_polygon: A list of polygons, where each polygon is a list of (x, y) coordinate tuples.
+     :param linewidth: The thickness of the lines used to draw the polygons.
+     :param xlim: A tuple (xmin, xmax) defining the x-axis limits for zooming in on a specific region of the image.
+     :param ylim: A tuple (ymin, ymax) defining the y-axis limits for zooming in on a specific region of the image.
+     :return: Displays the image with the ROI polygons overlaid.
+     """
+
+     # Create a new figure with a fixed size for displaying the image and annotations
+     plt.figure(figsize=(10, 10))
+
+     # Display the full-resolution image
+     plt.imshow(fullres_img)
+
+     # Plot each polygon in roi_polygon on top of the image
+     for polygon in roi_polygon:
+         x, y = zip(*polygon)  # Unzip the list of (x, y) tuples into separate x and y coordinate lists
+         plt.plot(x, y, color='black', linewidth=linewidth)
+
+     # Set the axis limits based on the provided tuples
+     plt.xlim(xlim)
+     plt.ylim(ylim)
+
+     # Invert the y-axis to match the typical image display convention (origin at the top-left)
+     plt.gca().invert_yaxis()
+
+     # Turn off the axis for a cleaner display without ticks or labels
+     plt.axis('off')
+
+
+ def plot_annotation_heatmap(st_ad, roi_polygon, s, linewidth, xlim, ylim):
+     """
+     Plots a tissue type annotation heatmap.
+
+     :param st_ad: AnnData object containing coordinates in `obsm['spatial']`
+                   and similarity scores in `obs['bulk_simi']`.
+     :param roi_polygon: A list of polygons, where each polygon is a list of (x, y) coordinate tuples.
+     :param s: The size of the scatter plot markers representing each spatial transcriptomics spot.
+     :param linewidth: The thickness of the lines used to draw the polygons.
+     :param xlim: A tuple (xmin, xmax) defining the x-axis limits for zooming in on a specific region of the image.
+     :param ylim: A tuple (ymin, ymax) defining the y-axis limits for zooming in on a specific region of the image.
+     :return: Displays the heatmap with the polygons overlaid.
+     """
+
+     # Create a new figure with a fixed size for displaying the heatmap and annotations
+     plt.figure(figsize=(10, 10))
+
+     # Scatter plot of the spatial transcriptomics spots, colored by their 'bulk_simi' similarity scores
+     plt.scatter(
+         st_ad.obsm['spatial'][:, 0], st_ad.obsm['spatial'][:, 1],  # x and y coordinates
+         c=st_ad.obs['bulk_simi'],  # Color values based on 'bulk_simi'
+         s=s,                       # Size of each marker
+         vmin=0.1, vmax=0.95,       # Range for the color normalization
+         cmap='turbo'               # Use the 'turbo' colormap for the heatmap
+     )
+
+     # Plot each polygon in roi_polygon on top of the heatmap
+     for polygon in roi_polygon:
+         x, y = zip(*polygon)  # Unzip the list of (x, y) tuples into separate x and y coordinate lists
+         plt.plot(x, y, color='black', linewidth=linewidth)
+
+     # Set the axis limits based on the provided tuples
+     plt.xlim(xlim)
+     plt.ylim(ylim)
+
+     # Invert the y-axis to match the typical image display convention (origin at the top-left)
+     plt.gca().invert_yaxis()
+
+     # Turn off the axis for a cleaner display without ticks or labels
+     plt.axis('off')
+
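The hex branch of `color_string_to_rgb` reduces to slicing the string into three byte pairs and parsing each as base-16. A minimal stdlib-only sketch of that parsing step (the `hex_to_rgb` helper name is illustrative, not part of loki):

```python
def hex_to_rgb(color_string):
    """Parse '#rrggbb' (or bare 'rrggbb') into an (r, g, b) tuple of 0-255 ints."""
    color_string = color_string.lstrip('#')
    # Slice the six hex digits into three two-character pairs and parse base-16
    return tuple(int(color_string[i:i + 2], 16) for i in (0, 2, 4))

print(hex_to_rgb('#ff0000'))  # (255, 0, 0)
print(hex_to_rgb('00ff7f'))   # (0, 255, 127)
```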
src/build/lib/loki/plotting.py ADDED
@@ -0,0 +1,435 @@
1
+ import matplotlib.pyplot as plt
2
+ from pathlib import Path
3
+ import json
4
+ import cv2
5
+ from matplotlib import cm
6
+ import pandas as pd
7
+ import numpy as np
8
+ from tqdm import tqdm
9
+
10
+
11
+
12
+ def plot_alignment(ad_tar_coor, ad_src_coor, homo_coor, pca_hex_comb, tar_features, shift=300, s=0.8, boundary_line=True):
13
+ """
14
+ Plots the target coordinates and alignment of source coordinates.
15
+
16
+ :param ad_tar_coor: Numpy array of target coordinates to be plotted in the first subplot.
17
+ :param ad_src_coor: Numpy array of source coordinates to be plotted in the second subplot.
18
+ :param homo_coor: Numpy array of alignment of source coordinates to be plotted in the third subplot.
19
+ :param pca_hex_comb: Color values (e.g., PCA or hex values) for plotting the coordinates.
20
+ :param tar_features: Feature matrix for the target, used to split color values between target and source data.
+ :param shift: Value used to adjust the plot limits around the coordinates for better visualization. Default is 300.
+ :param s: Marker size for the scatter plot points. Default is 0.8.
+ :param boundary_line: Boolean indicating whether to draw boundary lines (horizontal and vertical lines). Default is True.
+ :return: Displays the alignment plot of target, source, and alignment of source coordinates.
+ """
+
+ # Create a figure with three subplots, adjusting size and resolution
+ plt.figure(figsize=(10, 3), dpi=300)
+
+ # First subplot: Plot target coordinates
+ plt.subplot(1, 3, 1)
+ plt.scatter(ad_tar_coor[:, 0], ad_tar_coor[:, 1], marker='o', s=s, c=pca_hex_comb[:len(tar_features.T)])
+ # Set plot limits based on the minimum and maximum target coordinates, with extra padding from 'shift'
+ plt.xlim([ad_tar_coor.min() - shift, ad_tar_coor.max() + shift])
+ plt.ylim([ad_tar_coor.min() - shift, ad_tar_coor.max() + shift])
+
+ # Second subplot: Plot source coordinates
+ plt.subplot(1, 3, 2)
+ plt.scatter(ad_src_coor[:, 0], ad_src_coor[:, 1], marker='o', s=s, c=pca_hex_comb[len(tar_features.T):])
+ # Ensure consistent plot limits across subplots by using the same limits as the target coordinates
+ plt.xlim([ad_tar_coor.min() - shift, ad_tar_coor.max() + shift])
+ plt.ylim([ad_tar_coor.min() - shift, ad_tar_coor.max() + shift])
+
+ # Third subplot: Plot alignment of source coordinates
+ plt.subplot(1, 3, 3)
+ plt.scatter(homo_coor[:, 0], homo_coor[:, 1], marker='o', s=s, c=pca_hex_comb[len(tar_features.T):])
+ # Maintain the same plot limits across all subplots for a uniform comparison
+ plt.xlim([ad_tar_coor.min() - shift, ad_tar_coor.max() + shift])
+ plt.ylim([ad_tar_coor.min() - shift, ad_tar_coor.max() + shift])
+
+ # Optionally draw boundary lines at the minimum x and y values of the target coordinates
+ if boundary_line:
+ plt.axvline(x=ad_tar_coor[:, 0].min(), color='black') # Vertical boundary line at the minimum x of target coordinates
+ plt.axhline(y=ad_tar_coor[:, 1].min(), color='black') # Horizontal boundary line at the minimum y of target coordinates
+
+ # Remove the axis labels and ticks from all subplots for a cleaner appearance
+ plt.axis('off')
+
+ # Display the plot
+ plt.show()
+
+
+
+ def plot_alignment_with_img(ad_tar_coor, ad_src_coor, homo_coor, tar_img, src_img, aligned_image, pca_hex_comb, tar_features):
+ """
+ Plots the target coordinates and alignment of source coordinates with their respective images in the background.
+
+ :param ad_tar_coor: Numpy array of target coordinates to be plotted in the first and third subplots.
+ :param ad_src_coor: Numpy array of source coordinates to be plotted in the second subplot.
+ :param homo_coor: Numpy array of alignment of source coordinates to be plotted in the third subplot.
+ :param tar_img: Image associated with the target coordinates, used as the background in the first subplot.
+ :param src_img: Image associated with the source coordinates, used as the background in the second subplot.
+ :param aligned_image: Image associated with the aligned coordinates, used as the background in the third subplot.
+ :param pca_hex_comb: Color values (e.g., PCA or hex values) for plotting the coordinates.
+ :param tar_features: Feature matrix for the target, used to split color values between target and source data.
+ :return: Displays the alignment plot of target, source, and alignment of source coordinates with their associated images.
+ """
+
+ # Create a figure with three subplots and set the size and resolution
+ plt.figure(figsize=(10, 8), dpi=150)
+
+ # First subplot: Plot target coordinates with the target image as the background
+ plt.subplot(1, 3, 1)
+ # Scatter plot for the target coordinates with transparency and small marker size
+ plt.scatter(ad_tar_coor[:, 0], ad_tar_coor[:, 1], marker='o', alpha=0.8, s=1, c=pca_hex_comb[:len(tar_features.T)])
+ # Overlay the target image with some transparency (alpha = 0.3)
+ plt.imshow(tar_img, origin='lower', alpha=0.3)
+
+ # Second subplot: Plot source coordinates with the source image as the background
+ plt.subplot(1, 3, 2)
+ # Scatter plot for the source coordinates with transparency and small marker size
+ plt.scatter(ad_src_coor[:, 0], ad_src_coor[:, 1], marker='o', alpha=0.8, s=1, c=pca_hex_comb[len(tar_features.T):])
+ # Overlay the source image with some transparency (alpha = 0.3)
+ plt.imshow(src_img, origin='lower', alpha=0.3)
+
+ # Third subplot: Plot both target and alignment of source coordinates with the aligned image as the background
+ plt.subplot(1, 3, 3)
+ # Scatter plot for the target coordinates with lower opacity (alpha = 0.2)
+ plt.scatter(ad_tar_coor[:, 0], ad_tar_coor[:, 1], marker='o', alpha=0.2, s=1, c=pca_hex_comb[:len(tar_features.T)])
+ # Scatter plot for the homologous coordinates with a '+' marker and the same color mapping
+ plt.scatter(homo_coor[:, 0], homo_coor[:, 1], marker='+', s=1, c=pca_hex_comb[len(tar_features.T):])
+ # Overlay the aligned image with some transparency (alpha = 0.3)
+ plt.imshow(aligned_image, origin='lower', alpha=0.3)
+
+ # Turn off the axis for all subplots to give a cleaner visual output
+ plt.axis('off')
+
+ # Display the plots
+ plt.show()
+
+
+
+ def draw_polygon(image, polygon, color='k', thickness=2):
+ """
+ Draws one or more polygons on the given image.
+
+ :param image: The image on which to draw the polygons (as a numpy array).
+ :param polygon: A list of polygons, where each polygon is a list of (x, y) coordinate tuples.
+ :param color: A string or list of strings representing the color(s) for each polygon.
+ If a single color is provided, it will be applied to all polygons. Default is 'k' (black).
+ :param thickness: An integer or a list of integers representing the thickness of the polygon borders.
+ If a single value is provided, it will be applied to all polygons. Default is 2.
+
+ :return: The image with the polygons drawn on it.
+ """
+
+ # If the provided `color` is a single value (string), convert it to a list of the same color for each polygon
+ if not isinstance(color, list):
+ color = [color] * len(polygon) # Create a list where each polygon gets the same color
+
+ # Loop through each polygon in the list, along with its corresponding color
+ for i, poly in enumerate(polygon):
+ # Get the color for the current polygon
+ c = color[i]
+
+ # Convert the color from a string format (e.g., 'k' or '#ff0000') to an RGB tuple
+ c = color_string_to_rgb(c)
+
+ # Get the thickness value for the current polygon (if a list is provided, use the corresponding value)
+ t = thickness[i] if isinstance(thickness, list) else thickness
+
+ # Convert the polygon coordinates to a numpy array of integers
+ poly = np.array(poly, np.int32)
+
+ # Reshape the polygon array to match OpenCV's expected input format: (number of points, 1, 2)
+ poly = poly.reshape((-1, 1, 2))
+
+ # Draw the polygon on the image using OpenCV's `cv2.polylines` function
+ # `isClosed=True` indicates that the polygon should be closed (start and end points are connected)
+ image = cv2.polylines(image, [poly], isClosed=True, color=c, thickness=t)
+
+ return image
+
+
+
+ def blend_images(image1, image2, alpha=0.5):
+ """
+ Blends two images together.
+
+ :param image1: Background image, a numpy array of shape (H, W, 3), where H is height, W is width, and 3 represents the RGB color channels.
+ :param image2: Foreground image, a numpy array of shape (H, W, 3), same dimensions as image1.
+ :param alpha: Blending factor, a float between 0 and 1. The value of alpha determines the weight of image1 in the blend,
+ where 0 means only image2 is shown, and 1 means only image1 is shown. Default is 0.5 (equal blending).
+
+ :return: A blended image, where each pixel is a weighted combination of the corresponding pixels from image1 and image2.
+ The blending is computed as: `blended = alpha * image1 + (1 - alpha) * image2`.
+ """
+
+ # Use cv2.addWeighted to blend the two images.
+ # The first image (image1) is weighted by 'alpha', and the second image (image2) is weighted by '1 - alpha'.
+ blended = cv2.addWeighted(image1, alpha, image2, 1 - alpha, 0)
+
+ # Return the resulting blended image.
+ return blended
+
+
+
+ def color_string_to_rgb(color_string):
+ """
+ Converts a color string to an RGB tuple.
+
+ :param color_string: A string representing the color. This can be in hexadecimal form (e.g., '#ff0000') or
+ a shorthand character for basic colors (e.g., 'k' for black, 'r' for red, etc.).
+ :return:
+ A tuple (r, g, b) representing the RGB values of the color, where each value is an integer between 0 and 255.
+ :raises:
+ ValueError: If the color string is not recognized.
+ """
+
+ # Remove any spaces in the color string
+ color_string = color_string.replace(' ', '')
+
+ # If the string starts with a '#', it's a hexadecimal color, so we remove the '#'
+ if color_string.startswith('#'):
+ color_string = color_string[1:]
+ else:
+ # Handle shorthand single-letter color codes by converting them to hex values
+ # 'k' -> black, 'r' -> red, 'g' -> green, 'b' -> blue, 'w' -> white
+ if color_string == 'k': # Black
+ color_string = '000000'
+ elif color_string == 'r': # Red
+ color_string = 'ff0000'
+ elif color_string == 'g': # Green
+ color_string = '00ff00'
+ elif color_string == 'b': # Blue
+ color_string = '0000ff'
+ elif color_string == 'w': # White
+ color_string = 'ffffff'
+ else:
+ # Raise an error if the color string is not recognized
+ raise ValueError(f"Unknown color string {color_string}")
+
+ # Convert the first two characters to the red (R) value
+ r = int(color_string[:2], 16)
+
+ # Convert the next two characters to the green (G) value
+ g = int(color_string[2:4], 16)
+
+ # Convert the last two characters to the blue (B) value
+ b = int(color_string[4:], 16)
+
+ # Return the RGB values as a tuple
+ return (r, g, b)
+
+
+
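The shorthand-to-RGB conversion that `draw_polygon` relies on can be exercised with the self-contained sketch below. It is re-implemented here so the snippet runs on its own; the dict-based lookup is an equivalent rewrite for brevity, not the code above.

```python
def color_string_to_rgb(color_string):
    """Convert '#rrggbb' hex strings or single-letter shorthands to an (r, g, b) tuple."""
    shorthands = {'k': '000000', 'r': 'ff0000', 'g': '00ff00',
                  'b': '0000ff', 'w': 'ffffff'}
    s = color_string.replace(' ', '')
    if s.startswith('#'):
        s = s[1:]                      # strip the leading '#'
    elif s in shorthands:
        s = shorthands[s]              # expand shorthand to a hex string
    else:
        raise ValueError(f"Unknown color string {color_string}")
    # Parse each two-character hex pair into an integer channel value
    return (int(s[:2], 16), int(s[2:4], 16), int(s[4:], 16))

print(color_string_to_rgb('r'))        # (255, 0, 0)
print(color_string_to_rgb('#336699'))  # (51, 102, 153)
```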
+ def plot_heatmap(
+ coor,
+ similarity,
+ image_path=None,
+ patch_size=(256, 256),
+ save_path=None,
+ downsize=32,
+ cmap='turbo',
+ smooth=False,
+ boxes=None,
+ box_color='k',
+ box_thickness=2,
+ polygons=None,
+ polygons_color='k',
+ polygons_thickness=2,
+ image_alpha=0.5
+ ):
+ """
+ Plots a heatmap overlaid on an image based on given coordinates and similarity scores.
+
+ :param coor: Array of coordinates (N, 2) where N is the number of patches to place on the heatmap.
+ :param similarity: Array of similarity scores (N,) corresponding to the coordinates. These scores are mapped to colors using a colormap.
+ :param image_path: Path to the background image on which the heatmap will be overlaid.
+ :param patch_size: Size of each patch in pixels (default is 256x256).
+ :param save_path: Path to save the heatmap image. If None, the heatmap is returned instead of being saved.
+ :param downsize: Factor to downsize the image and patches for faster processing. Default is 32.
+ :param cmap: Colormap to map the similarity scores to colors. Default is 'turbo'.
+ :param smooth: Boolean to indicate if the heatmap should be smoothed. Not implemented in this version.
+ :param boxes: List of boxes in (x, y, w, h) format. If provided, boxes will be drawn on the heatmap.
+ :param box_color: Color of the boxes. Default is black ('k').
+ :param box_thickness: Thickness of the box outlines.
+ :param polygons: List of polygons (N, 2) to draw on the heatmap.
+ :param polygons_color: Color of the polygon outlines. Default is black ('k').
+ :param polygons_thickness: Thickness of the polygon outlines.
+ :param image_alpha: Transparency value (0 to 1) for blending the heatmap with the original image. Default is 0.5.
+
+ :return:
+ - heatmap: The generated heatmap as a numpy array (RGB).
+ - image: The original image with overlaid polygons if provided.
+ """
+
+ # Read the background image
+ image = cv2.imread(image_path)
+ image_size = (image.shape[0], image.shape[1]) # Get the size of the image
+ coor = [(x // downsize, y // downsize) for x, y in coor] # Downsize the coordinates for faster processing
+ patch_size = (patch_size[0] // downsize, patch_size[1] // downsize) # Downsize the patch size
+
+ # Convert similarity scores to colors using the provided colormap
+ cmap = plt.get_cmap(cmap) # Get the colormap object
+ norm = plt.Normalize(vmin=similarity.min(), vmax=similarity.max()) # Normalize scores to map to the color range
+ colors = cmap(norm(similarity)) # Convert the normalized scores to RGB colors
+
+ # Initialize a blank white heatmap the size of the image
+ heatmap = np.ones((image_size[0], image_size[1], 3)) * 255 # Start with a white background
+
+ # Place the colored patches on the heatmap according to the coordinates and patch size
+ for i in range(len(coor)):
+ x, y = coor[i]
+ w = colors[i][:3] * 255 # Get the RGB color for the patch, scaling from [0, 1] to [0, 255]
+ w = w.astype(np.uint8) # Convert the color to uint8
+ heatmap[y:y + patch_size[0], x:x + patch_size[1], :] = w # Place the patch on the heatmap
+
+ # If the image_alpha is greater than 0, blend the heatmap with the original image
+ if image_alpha > 0:
+ image = np.array(image)
+
+ # Pad the image if necessary to match the heatmap size
+ if image.shape[0] < heatmap.shape[0]:
+ pad = heatmap.shape[0] - image.shape[0]
+ image = np.pad(image, ((0, pad), (0, 0), (0, 0)), mode='constant', constant_values=255)
+ if image.shape[1] < heatmap.shape[1]:
+ pad = heatmap.shape[1] - image.shape[1]
+ image = np.pad(image, ((0, 0), (0, pad), (0, 0)), mode='constant', constant_values=255)
+
+ # Swap the channel order (cv2.imread loads BGR) so the image matches the RGB heatmap, then blend
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+ image = image.astype(np.uint8)
+ heatmap = heatmap.astype(np.uint8)
+ heatmap = blend_images(heatmap, image, alpha=image_alpha) # Blend the heatmap and the image
+
+ # If polygons are provided, draw them on the heatmap and image
+ if polygons is not None:
+ polygons = [poly // downsize for poly in polygons] # Downsize the polygon coordinates
+ image_polygons = draw_polygon(image, polygons, color=polygons_color, thickness=polygons_thickness) # Draw polygons on the original image
+ heatmap_polygons = draw_polygon(heatmap, polygons, color=polygons_color, thickness=polygons_thickness) # Draw polygons on the heatmap
+
+ return heatmap_polygons, image_polygons # Return the heatmap and image with polygons drawn on them
+ else:
+ return heatmap, image # Return the heatmap and image
+
+
+
+ def show_images_side_by_side(image1, image2, title1=None, title2=None):
+ """
+ Displays two images side by side in a single figure.
+
+ :param image1: The first image to display (as a numpy array).
+ :param image2: The second image to display (as a numpy array).
+ :param title1: The title for the first image. Default is None (no title).
+ :param title2: The title for the second image. Default is None (no title).
+ :return: Displays the images side by side.
+ """
+
+ # Create a figure with 2 subplots (1 row, 2 columns), and set the figure size
+ fig, ax = plt.subplots(1, 2, figsize=(15, 8))
+
+ # Display the first image on the first subplot
+ ax[0].imshow(image1)
+
+ # Display the second image on the second subplot
+ ax[1].imshow(image2)
+
+ # Set the title for the first image (if provided)
+ ax[0].set_title(title1)
+
+ # Set the title for the second image (if provided)
+ ax[1].set_title(title2)
+
+ # Remove axis labels and ticks for both images to give a cleaner look
+ ax[0].axis('off')
+ ax[1].axis('off')
+
+ # Show the final figure with both images displayed side by side
+ plt.show()
+
+
+
+ def plot_img_with_annotation(fullres_img, roi_polygon, linewidth, xlim, ylim):
+ """
+ Plots an image with ROI polygons overlaid.
+
+ :param fullres_img: The full-resolution image to display (as a numpy array).
+ :param roi_polygon: A list of polygons, where each polygon is a list of (x, y) coordinate tuples.
+ :param linewidth: The thickness of the lines used to draw the polygons.
+ :param xlim: A tuple (xmin, xmax) defining the x-axis limits for zooming in on a specific region of the image.
+ :param ylim: A tuple (ymin, ymax) defining the y-axis limits for zooming in on a specific region of the image.
+ :return: Displays the image with ROI polygons overlaid.
+ """
+
+ # Create a new figure with a fixed size for displaying the image and annotations
+ plt.figure(figsize=(10, 10))
+
+ # Display the full-resolution image
+ plt.imshow(fullres_img)
+
+ # Loop through each polygon in roi_polygon and plot them on the image
+ for polygon in roi_polygon:
+ x, y = zip(*polygon) # Unzip the list of (x, y) tuples into separate x and y coordinate lists
+ plt.plot(x, y, color='black', linewidth=linewidth) # Plot the polygon using the specified linewidth
+
+ # Set the x-axis limits based on the provided tuple (xlim)
+ plt.xlim(xlim)
+
+ # Set the y-axis limits based on the provided tuple (ylim)
+ plt.ylim(ylim)
+
+ # Invert the y-axis to match the typical image display convention (origin at the top-left)
+ plt.gca().invert_yaxis()
+
+ # Turn off the axis to give a cleaner image display without ticks or labels
+ plt.axis('off')
+
+
+
+ def plot_annotation_heatmap(st_ad, roi_polygon, s, linewidth, xlim, ylim):
+ """
+ Plots tissue type annotation heatmap.
+
+ :param st_ad: AnnData object containing coordinates in `obsm['spatial']`
+ and similarity scores in `obs['bulk_simi']`.
+ :param roi_polygon: A list of polygons, where each polygon is a list of (x, y) coordinate tuples.
+ :param s: The size of the scatter plot markers representing each spatial transcriptomics spot.
+ :param linewidth: The thickness of the lines used to draw the polygons.
+ :param xlim: A tuple (xmin, xmax) defining the x-axis limits for zooming in on a specific region of the image.
+ :param ylim: A tuple (ymin, ymax) defining the y-axis limits for zooming in on a specific region of the image.
+ :return: Displays the heatmap with polygons overlaid.
+ """
+
+ # Create a new figure with a fixed size for displaying the heatmap and annotations
+ plt.figure(figsize=(10, 10))
+
+ # Scatter plot for the spatial transcriptomics data.
+ # The 'spatial' coordinates are plotted with color intensity based on 'bulk_simi' values.
+ plt.scatter(
+ st_ad.obsm['spatial'][:, 0], st_ad.obsm['spatial'][:, 1], # x and y coordinates
+ c=st_ad.obs['bulk_simi'], # Color values based on 'bulk_simi'
+ s=s, # Size of each marker
+ vmin=0.1, vmax=0.95, # Set the range for the color normalization
+ cmap='turbo' # Use the 'turbo' colormap for the heatmap
+ )
+
+ # Loop through each polygon in roi_polygon and plot them on the image
+ for polygon in roi_polygon:
+ x, y = zip(*polygon) # Unzip the list of (x, y) tuples into separate x and y coordinate lists
+ plt.plot(x, y, color='black', linewidth=linewidth) # Plot the polygon using the specified linewidth
+
+ # Set the x-axis limits based on the provided tuple (xlim)
+ plt.xlim(xlim)
+
+ # Set the y-axis limits based on the provided tuple (ylim)
+ plt.ylim(ylim)
+
+ # Invert the y-axis to match the typical image display convention (origin at the top-left)
+ plt.gca().invert_yaxis()
+
+ # Turn off the axis to give a cleaner image display without ticks or labels
+ plt.axis('off')
+
+
src/build/lib/loki/predex.py ADDED
@@ -0,0 +1,25 @@
+ import pandas as pd
+
+
+
+ def predict_st_gene_expr(image_text_similarity, train_data):
+ """
+ Predicts ST gene expression from an H&E image.
+
+ :param image_text_similarity: Numpy array of similarities between images and text features (shape: [n_samples, n_genes]).
+ :param train_data: Numpy array or DataFrame of training data used for making predictions (shape: [n_genes, n_shared_genes]).
+ :return: Numpy array or DataFrame containing the predicted gene expression levels for the samples.
+ """
+
+ # Compute the weighted sum of the train_data using image_text_similarity
+ weighted_sum = image_text_similarity @ train_data
+
+ # Compute the normalization factor (sum of the image-text similarities for each sample)
+ weights = image_text_similarity.sum(axis=1, keepdims=True)
+
+ # Normalize the predicted matrix to get weighted gene expression predictions
+ predicted_image_text_matrix = weighted_sum / weights
+
+ return predicted_image_text_matrix
+
+
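A quick sanity check of the weighting scheme in `predict_st_gene_expr`, using made-up toy matrices (all sizes and values below are illustrative only): the prediction is a similarity-weighted average of the reference expression profiles.

```python
import numpy as np

# Toy inputs: 2 image patches scored against 3 reference profiles,
# each reference profile carrying expression for 4 shared genes.
image_text_similarity = np.array([[0.2, 0.3, 0.5],
                                  [0.6, 0.2, 0.2]])
train_data = np.array([[1.0, 0.0, 2.0, 1.0],
                       [0.0, 1.0, 1.0, 0.0],
                       [2.0, 2.0, 0.0, 1.0]])

# Same computation as predict_st_gene_expr: weighted sum of the
# reference profiles, normalized by the total similarity per patch.
weighted_sum = image_text_similarity @ train_data
weights = image_text_similarity.sum(axis=1, keepdims=True)
predicted = weighted_sum / weights

print(predicted.shape)  # (2, 4): one 4-gene expression vector per patch
```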
src/build/lib/loki/preprocess.py ADDED
@@ -0,0 +1,324 @@
+ import scanpy as sc
+ import numpy as np
+ import pandas as pd
+ import json
+ import os
+ import logging
+ from PIL import Image
+
+ # Module-level logger used by the error paths below
+ logger = logging.getLogger(__name__)
+
+
+ def generate_gene_df(ad, house_keeping_genes, todense=True):
+ """
+ Generates a DataFrame with the top 50 genes for each observation in an AnnData object.
+ It removes genes containing '.' or '-' in their names, as well as genes listed in
+ the provided `house_keeping_genes` DataFrame/Series under the 'genesymbol' column.
+
+ :param ad: An AnnData object containing gene expression data.
+ :type ad: anndata.AnnData
+ :param house_keeping_genes: DataFrame or Series with a 'genesymbol' column listing housekeeping genes to exclude.
+ :type house_keeping_genes: pandas.DataFrame or pandas.Series
+ :param todense: Whether to convert the sparse matrix (ad.X) to a dense matrix before creating a DataFrame.
+ :type todense: bool
+ :return: A DataFrame (`top_k_genes_str`) that contains a 'label' column. Each row in 'label' is a string
+ with the top 50 gene names (space-separated) for that observation.
+ :rtype: pandas.DataFrame
+ """
+
+ # Remove genes containing '.' in their names
+ ad = ad[:, ~ad.var.index.str.contains('.', regex=False)]
+ # Remove genes containing '-'
+ ad = ad[:, ~ad.var.index.str.contains('-', regex=False)]
+ # Exclude housekeeping genes
+ ad = ad[:, ~ad.var.index.isin(house_keeping_genes['genesymbol'])]
+
+ # Convert to dense if requested; otherwise use the data as-is
+ if todense:
+ expr = pd.DataFrame(ad.X.todense(), index=ad.obs.index, columns=ad.var.index)
+ else:
+ expr = pd.DataFrame(ad.X, index=ad.obs.index, columns=ad.var.index)
+
+ # For each row (observation), find the top 50 genes with the highest expression
+ top_k_genes = expr.apply(lambda s, n: pd.Series(s.nlargest(n).index), axis=1, n=50)
+
+ # Create a new DataFrame to store the labels (space-separated top gene names)
+ top_k_genes_str = pd.DataFrame()
+ top_k_genes_str['label'] = top_k_genes[top_k_genes.columns].astype(str) \
+ .apply(lambda x: ' '.join(x), axis=1)
+
+ return top_k_genes_str
+
+
+
+ def segment_patches(img_array, coord, patch_dir, height=20, width=20):
+ """
+ Extracts small image patches centered at specified coordinates and saves them as individual PNG files.
+
+ :param img_array: A NumPy array representing the full-resolution image. Shape is expected to be (H, W[, C]).
+ :type img_array: numpy.ndarray
+ :param coord: A pandas DataFrame containing patch center coordinates in columns "pixel_x" and "pixel_y".
+ The index corresponds to spot IDs. Example columns: ["pixel_x", "pixel_y"].
+ :type coord: pandas.DataFrame
+ :param patch_dir: Directory path where the patch images will be saved.
+ :type patch_dir: str
+ :param height: The patch's height in pixels (distance in the y-direction).
+ :type height: int
+ :param width: The patch's width in pixels (distance in the x-direction).
+ :type width: int
+ :return: None. The function saves image patches to `patch_dir` but does not return anything.
+ """
+
+ # Ensure the output directory exists; create it if it doesn't
+ if not os.path.exists(patch_dir):
+ os.makedirs(patch_dir)
+
+ # Extract the overall height and width of the image
+ yrange, xrange = img_array.shape[:2]
+
+ # Iterate through each coordinate in the DataFrame
+ for spot_idx in coord.index:
+ # Retrieve the patch center for the current spot (unpacking order matches how the pixel columns were stored)
+ ycenter, xcenter = coord.loc[spot_idx, ["pixel_x", "pixel_y"]]
+
+ # Compute the top-left (x1, y1) and bottom-right (x2, y2) boundaries of the patch
+ x1 = round(xcenter - width / 2)
+ y1 = round(ycenter - height / 2)
+ x2 = x1 + width
+ y2 = y1 + height
+
+ # Check if the patch boundaries go outside the image
+ if x1 < 0 or y1 < 0 or x2 > xrange or y2 > yrange:
+ print(f"Patch {spot_idx} is out of range and will be skipped.")
+ continue
+
+ # Extract the patch and convert to a PIL Image; cast to uint8 if needed
+ patch_img = Image.fromarray(img_array[y1:y2, x1:x2].astype(np.uint8))
+
+ # Create a filename for the patch image (e.g., "0_hires.png")
+ patch_name = f"{spot_idx}_hires.png"
+ patch_path = os.path.join(patch_dir, patch_name)
+
+ # Save the patch image to disk
+ patch_img.save(patch_path)
+
+
+
+ def read_gct(file_path):
+ """
+ Reads a GCT file, parses its dimensions, and returns the data as a pandas DataFrame.
+
+ :param file_path: The path to the GCT file to be read.
+ :return: A pandas DataFrame containing the GCT data, where the first two columns represent gene names and descriptions,
+ and the subsequent columns contain the expression data.
+ """
+
+ # Open the GCT file for reading
+ with open(file_path, 'r') as file:
+ # Read and ignore the first line (GCT version line)
+ file.readline()
+
+ # Read the second line which contains the dimensions of the data matrix
+ dims = file.readline().strip().split() # Split the dimensions line by whitespace
+ num_rows = int(dims[0]) # Number of data rows (genes)
+ num_cols = int(dims[1]) # Number of data columns (samples + metadata)
+
+ # Read the data starting from the third line, using pandas for tab-delimited data
+ # The first two columns in GCT files are "Name" and "Description" (gene identifiers and annotations)
+ data = pd.read_csv(file, sep='\t', header=0, nrows=num_rows)
+
+ # Return the loaded data as a pandas DataFrame
+ return data
+
+
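The parsing logic of `read_gct` can be exercised on an in-memory buffer; the tiny GCT payload below is fabricated purely for illustration, and the helper mirrors the function above except that it accepts any file-like object instead of a path.

```python
import io
import pandas as pd

# Minimal GCT (v1.2) content: a version line, a dimensions line, then a
# tab-delimited table whose first two columns are Name and Description.
gct_text = (
    "#1.2\n"
    "2\t3\n"
    "Name\tDescription\tS1\tS2\tS3\n"
    "GENE1\tna\t1.0\t2.0\t3.0\n"
    "GENE2\tna\t4.0\t5.0\t6.0\n"
)

def read_gct_from_buffer(buf):
    buf.readline()                       # skip the "#1.2" version line
    dims = buf.readline().strip().split()
    num_rows = int(dims[0])              # number of gene rows declared
    return pd.read_csv(buf, sep="\t", header=0, nrows=num_rows)

df = read_gct_from_buffer(io.StringIO(gct_text))
print(df.shape)  # (2, 5): 2 genes; Name + Description + 3 sample columns
```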
+ def get_library_id(adata):
+ """
+ Retrieves the library ID from the AnnData object, assuming it contains spatial data.
+ The function will return the first library ID found in `adata.uns['spatial']`.
+
+ :param adata: AnnData object containing spatial information in `adata.uns['spatial']`.
+ :return: The first library ID found in `adata.uns['spatial']`.
+ :raises:
+ AssertionError: If 'spatial' is not present in `adata.uns`.
+ Logs an error if no library ID is found.
+ """
+
+ # Check if 'spatial' is present in adata.uns; raises an error if not found
+ assert 'spatial' in adata.uns, "spatial not present in adata.uns"
+
+ # Retrieve the list of library IDs (which are keys in the 'spatial' dictionary)
+ library_ids = adata.uns['spatial'].keys()
+
+ try:
+ # Attempt to return the first library ID (converting the keys object to a list)
+ library_id = list(library_ids)[0]
+ return library_id
+ except IndexError:
+ # If no library IDs exist, log an error message
+ logger.error('No library_id found in adata')
+
+
+
+ def get_scalefactors(adata, library_id=None):
+ """
+ Retrieves the scalefactors from the AnnData object for a given library ID. If no library ID is provided,
+ the function will automatically retrieve the first available library ID.
+
+ :param adata: AnnData object containing spatial data and scalefactors in `adata.uns['spatial']`.
+ :param library_id: The library ID for which the scalefactors are to be retrieved. If not provided, it defaults to the first available ID.
+ :return: A dictionary containing scalefactors for the specified library ID.
+ """
+
+ # If no library_id is provided, retrieve the first available library ID
+ if library_id is None:
+ library_id = get_library_id(adata)
+
+ try:
+ # Attempt to retrieve the scalefactors for the specified library ID
+ scalef = adata.uns['spatial'][library_id]['scalefactors']
+ return scalef
+ except KeyError:
+ # Log an error if the scalefactors or library ID is not found
+ logger.error('scalefactors not found in adata')
+
+
+
+ def get_spot_diameter_in_pixels(adata, library_id=None):
+ """
+ Retrieves the spot diameter in pixels from the AnnData object's scalefactors for a given library ID.
+ If no library ID is provided, the function will automatically retrieve the first available library ID.
+
+ :param adata: AnnData object containing spatial data and scalefactors in `adata.uns['spatial']`.
+ :param library_id: The library ID for which the spot diameter is to be retrieved. If not provided, defaults to the first available ID.
+
+ :return: The spot diameter in full resolution pixels, or None if not found.
+ """
+
+ # Get the scalefactors for the specified or default library ID
+ scalef = get_scalefactors(adata, library_id=library_id)
+
+ try:
+ # Attempt to retrieve the spot diameter in full resolution from the scalefactors
+ spot_diameter = scalef['spot_diameter_fullres']
+ return spot_diameter
+ except TypeError:
+ # Handle case where `scalef` is None or invalid (if get_scalefactors returned None)
+ pass
+ except KeyError:
+ # Log an error if the 'spot_diameter_fullres' key is not found in the scalefactors
+ logger.error('spot_diameter_fullres not found in adata')
+
+
+
+ def prepare_data_for_alignment(data_path, scale_type='tissue_hires_scalef'):
+ """
+ Prepares data for alignment by reading an AnnData object and preparing the high-resolution tissue image.
+
+ :param data_path: The path to the AnnData (.h5ad) file containing the Visium data.
+ :param scale_type: The type of scale factor to use (`tissue_hires_scalef` by default).
+
+ :return:
+ - ad: AnnData object containing the spatial transcriptomics data.
+ - ad_coor: Numpy array of scaled spatial coordinates (adjusted for the specified resolution).
+ - img: High-resolution tissue image, normalized to 8-bit unsigned integers.
+
+ :raises:
+ ValueError: If required data (e.g., scale factors, spatial coordinates, or images) is missing.
+ """
+
+ # Load the AnnData object from the specified file path
+ ad = sc.read_h5ad(data_path)
+
+ # Ensure the variable (gene) names are unique to avoid potential conflicts
+ ad.var_names_make_unique()
+
+ try:
+ # Retrieve the specified scale factor for spatial coordinates
+ scalef = get_scalefactors(ad)[scale_type]
+ except KeyError:
+ raise ValueError(f"Scale factor '{scale_type}' not found in ad.uns['spatial']")
+
+ # Scale the spatial coordinates using the specified scale factor
+ try:
+ ad_coor = np.array(ad.obsm['spatial']) * scalef
+ except KeyError:
+ raise ValueError("Spatial coordinates not found in ad.obsm['spatial']")
+
+ # Retrieve the high-resolution tissue image
+ try:
+ img = ad.uns['spatial'][get_library_id(ad)]['images']['hires']
+ except KeyError:
+ raise ValueError("High-resolution image not found in ad.uns['spatial']")
+
+ # If the image values are normalized to [0, 1], convert to 8-bit format for compatibility
+ if img.max() < 1.1:
+ img = (img * 255).astype('uint8')
+
+ return ad, ad_coor, img
+
+
+
260
+ def load_data_for_annotation(st_data_path, json_path, in_tissue=True):
261
+ """
262
+ Loads spatial transcriptomics (ST) data from an .h5ad file and prepares it for annotation.
263
+
264
+ :param st_data_path: Path to the spatial transcriptomics (.h5ad) data file.
+ :param json_path: Path to the JSON file containing the region of interest (ROI) polygon.
266
+ :param in_tissue: Boolean flag to filter the data to include only spots that are in tissue. Default is True.
267
+
268
+ :return:
269
+ - st_ad: AnnData object containing the spatial transcriptomics data, with spatial coordinates in `obs`.
270
+ - library_id: The library ID associated with the spatial data.
271
+ - roi_polygon: Region of interest polygon loaded from a JSON file for further annotation or analysis.
272
+ """
273
+
274
+ # Load the spatial transcriptomics data into an AnnData object
275
+ st_ad = sc.read_h5ad(st_data_path)
276
+
277
+ # Optionally filter the data to include only spots that are within the tissue
278
+ if in_tissue:
279
+ st_ad = st_ad[st_ad.obs['in_tissue'] == 1]
280
+
281
+ # Initialize pixel coordinates for spatial information
282
+ st_ad.obs[["pixel_y", "pixel_x"]] = None # Ensure the columns exist
283
+ st_ad.obs[["pixel_y", "pixel_x"]] = st_ad.obsm['spatial'] # Copy spatial coordinates into obs
284
+
285
+ # Retrieve the library ID associated with the spatial data
286
+ library_id = get_library_id(st_ad)
287
+
288
+ # Load the region of interest (ROI) polygon from a JSON file
289
+ with open(json_path) as f:
290
+ roi_polygon = json.load(f)
291
+
292
+ return st_ad, library_id, roi_polygon
293
+
294
+
295
+
296
+ def read_polygons(file_path, slide_id):
297
+ """
298
+ Reads polygon data from a JSON file for a specific slide ID, extracting coordinates, colors, and thickness.
299
+
300
+ :param file_path: Path to the JSON file containing polygon configurations.
301
+ :param slide_id: Identifier for the specific slide whose polygon data is to be extracted.
302
+ :return:
303
+ - polygons: A list of numpy arrays, where each array contains the coordinates of a polygon.
304
+ - polygon_colors: A list of color values corresponding to each polygon.
305
+ - polygon_thickness: A list of thickness values for each polygon's border.
306
+ """
307
+
308
+ # Open the JSON file and load the polygon configurations into a Python dictionary
309
+ with open(file_path, 'r') as f:
310
+ polygons_configs = json.load(f)
311
+
312
+ # Check if the given slide_id exists in the polygon configurations
313
+ if slide_id not in polygons_configs:
314
+ return None, None, None # If slide_id is not found, return None for all outputs
315
+
316
+ # Extract the polygon coordinates, colors, and thicknesses for the given slide_id
317
+ polygons = [np.array(poly['coords']) for poly in polygons_configs[slide_id]] # Convert polygon coordinates to numpy arrays
318
+ polygon_colors = [poly['color'] for poly in polygons_configs[slide_id]] # Extract the color for each polygon
319
+ polygon_thickness = [poly['thickness'] for poly in polygons_configs[slide_id]] # Extract the thickness for each polygon
320
+
321
+ # Return the polygons, their colors, and their thicknesses
322
+ return polygons, polygon_colors, polygon_thickness
323
+
324
+
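The JSON layout `read_polygons` expects is a mapping from slide ID to a list of records with `coords`, `color`, and `thickness` keys. A self-contained round trip with an illustrative config (slide ID and values are made up):

```python
import json
import os
import tempfile
import numpy as np

# Illustrative config in the layout read_polygons expects.
config = {
    "slide_A": [
        {"coords": [[0, 0], [10, 0], [10, 10]], "color": [255, 0, 0], "thickness": 2},
        {"coords": [[5, 5], [6, 5], [6, 6], [5, 6]], "color": [0, 255, 0], "thickness": 1},
    ]
}

path = os.path.join(tempfile.mkdtemp(), "polygons.json")
with open(path, "w") as f:
    json.dump(config, f)

# Same extraction logic as read_polygons, inlined for the demo
with open(path) as f:
    cfg = json.load(f)

polygons = [np.array(p["coords"]) for p in cfg["slide_A"]]
colors = [p["color"] for p in cfg["slide_A"]]
print(len(polygons), polygons[0].shape, colors[1])  # 2 (3, 2) [0, 255, 0]
```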
src/build/lib/loki/retrieve.py ADDED
@@ -0,0 +1,28 @@
1
+ import torch
2
+
3
+
4
+
5
+ def retrieve_st_by_image(image_embeddings, all_text_embeddings, dataframe, k=3):
6
+ """
7
+ Retrieves the top-k most similar ST samples, ranked by the dot-product similarity between the query image embedding and the ST embeddings.
8
+
9
+ :param image_embeddings: A numpy array or torch tensor containing image embeddings (shape: [1, embedding_dim]).
10
+ :param all_text_embeddings: A numpy array or torch tensor containing ST embeddings (shape: [n_samples, embedding_dim]).
11
+ :param dataframe: A pandas DataFrame containing information about the ST samples, specifically the image indices in the 'img_idx' column.
12
+ :param k: The number of top similar samples to retrieve. Default is 3.
13
+ :return: A list of the filenames or indices corresponding to the top-k similar samples.
14
+ """
15
+
16
+ # Compute the dot product (similarity) between the image embeddings and all ST embeddings
17
+ dot_similarity = image_embeddings @ all_text_embeddings.T
18
+
19
+ # Retrieve the top-k most similar samples by similarity score (dot product)
20
+ values, indices = torch.topk(dot_similarity.squeeze(0), k)
21
+
22
+ # Extract the image filenames or indices from the DataFrame based on the top-k matches
23
+ image_filenames = dataframe['img_idx'].values
24
+ matches = [image_filenames[idx] for idx in indices]
25
+
26
+ return matches
27
+
28
+
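Since the embeddings fed to this function are L2-normalized, the dot product is cosine similarity. A self-contained run of the function above on toy 2-D embeddings:

```python
import torch
import torch.nn.functional as F
import pandas as pd

def retrieve_st_by_image(image_embeddings, all_text_embeddings, dataframe, k=3):
    # Dot product of L2-normalized embeddings == cosine similarity
    dot_similarity = image_embeddings @ all_text_embeddings.T
    values, indices = torch.topk(dot_similarity.squeeze(0), k)
    image_filenames = dataframe['img_idx'].values
    return [image_filenames[idx] for idx in indices]

query = F.normalize(torch.tensor([[1.0, 0.0]]), dim=-1)
st = F.normalize(torch.tensor([[1.0, 0.0],
                               [0.0, 1.0],
                               [0.7, 0.7],
                               [-1.0, 0.0]]), dim=-1)
df = pd.DataFrame({'img_idx': ['s0', 's1', 's2', 's3']})
print(retrieve_st_by_image(query, st, df, k=2))  # ['s0', 's2']
```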
src/build/lib/loki/utilities.py ADDED
@@ -0,0 +1,159 @@
1
+ import os
2
+ import numpy as np
3
+ import pandas as pd
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from PIL import Image
7
+ import json
8
+ import cv2
9
+ from sklearn.decomposition import PCA
10
+ from open_clip import create_model_from_pretrained, get_tokenizer
11
+
12
+
13
+
14
+ def load_model(model_path, device):
15
+ model, preprocess = create_model_from_pretrained("coca_ViT-L-14", device=device, pretrained=model_path)
16
+ tokenizer = get_tokenizer('coca_ViT-L-14')
17
+
18
+ return model, preprocess, tokenizer
19
+
20
+
21
+
22
+ def encode_image(model, preprocess, image):
23
+ image_input = torch.stack([preprocess(image)])
24
+ with torch.no_grad():
25
+ image_features = model.encode_image(image_input)
26
+ image_embeddings = F.normalize(image_features, p=2, dim=-1)
27
+
28
+ return image_embeddings
29
+
30
+
31
+
32
+ def encode_image_patches(model, preprocess, data_dir, img_list):
33
+ image_embeddings = []
34
+ for img_name in img_list:
35
+ image_path = os.path.join(data_dir, 'demo_data', 'patch', img_name)
36
+ image = Image.open(image_path)
37
+ image_features = encode_image(model, preprocess, image)
38
+ image_embeddings.append(image_features)
39
+ image_embeddings = torch.from_numpy(np.array(image_embeddings))
40
+ image_embeddings = F.normalize(image_embeddings, p=2, dim=-1)
41
+ return image_embeddings
42
+
43
+
44
+
45
+ def encode_text(model, tokenizer, text):
46
+ text_input = tokenizer(text)
47
+ with torch.no_grad():
48
+ text_features = model.encode_text(text_input)
49
+ text_embeddings = F.normalize(text_features, p=2, dim=-1)
50
+
51
+ return text_embeddings
52
+
53
+
54
+
55
+ def encode_text_df(model, tokenizer, df, col_name):
56
+ text_embeddings = []
57
+ for idx in df.index:
58
+ text = df[df.index==idx][col_name].iloc[0]
59
+ text_features = encode_text(model, tokenizer, text)
60
+ text_embeddings.append(text_features)
61
+ text_embeddings = torch.from_numpy(np.array(text_embeddings))
62
+ text_embeddings = F.normalize(text_embeddings, p=2, dim=-1)
63
+ return text_embeddings
64
+
65
+
66
+
67
+ def get_pca_by_fit(tar_features, src_features):
68
+ """
69
+ Applies PCA to target features and transforms both target and source features using the fitted PCA model.
70
+ Combines the PCA-transformed features from both target and source datasets and returns the combined data
71
+ along with batch labels indicating the origin of each sample.
72
+
73
+ :param tar_features: Numpy array of target features (samples by features).
74
+ :param src_features: Numpy array of source features (samples by features).
75
+ :return:
76
+ - pca_comb_features: A numpy array containing PCA-transformed target and source features combined.
77
+ - pca_comb_features_batch: A numpy array of batch labels indicating which samples are from target (0) and source (1).
78
+ """
79
+
80
+ pca = PCA(n_components=3)
81
+
82
+ # Fit the PCA model on the target features (transposed to fit on features)
83
+ pca_fit_tar = pca.fit(tar_features.T)
84
+
85
+ # Transform the target and source features using the fitted PCA model
86
+ pca_tar = pca_fit_tar.transform(tar_features.T) # Transform target features
87
+ pca_src = pca_fit_tar.transform(src_features.T) # Transform source features using the same PCA fit
88
+
89
+ # Combine the PCA-transformed target and source features
90
+ pca_comb_features = np.concatenate((pca_tar, pca_src))
91
+
92
+ # Create a batch label array: 0 for target features, 1 for source features
93
+ pca_comb_features_batch = np.array([0] * len(pca_tar) + [1] * len(pca_src))
94
+
95
+ return pca_comb_features, pca_comb_features_batch
96
+
97
+
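A minimal run of the same fit-on-target / project-both pattern on random data (shapes are illustrative; features are stored as features-by-samples, hence the transposes):

```python
import numpy as np
from sklearn.decomposition import PCA

rng = np.random.default_rng(0)
tar = rng.normal(size=(8, 20))   # 8 features, 20 target samples
src = rng.normal(size=(8, 15))   # 8 features, 15 source samples

pca = PCA(n_components=3).fit(tar.T)    # fit on target samples only
pca_tar = pca.transform(tar.T)
pca_src = pca.transform(src.T)          # project source into the target's PC space

comb = np.concatenate((pca_tar, pca_src))
batch = np.array([0] * len(pca_tar) + [1] * len(pca_src))
print(comb.shape, int(batch.sum()))  # (35, 3) 15
```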
98
+
99
+ def cap_quantile(weight, cap_max=None, cap_min=None):
100
+ """
101
+ Caps the values in the 'weight' array based on the specified quantile thresholds for maximum and minimum values.
102
+ If the quantile thresholds are provided, the function will replace values above or below these thresholds
103
+ with the corresponding quantile values.
104
+
105
+ :param weight: Numpy array of weights to be capped.
106
+ :param cap_max: Quantile threshold for the maximum cap. Values above this quantile will be capped.
107
+ If None, no maximum capping will be applied.
108
+ :param cap_min: Quantile threshold for the minimum cap. Values below this quantile will be capped.
109
+ If None, no minimum capping will be applied.
110
+ :return: Numpy array with the values capped at the specified quantiles.
111
+ """
112
+
113
+ # If a maximum cap is specified, calculate the value at the specified cap_max quantile
114
+ if cap_max is not None:
115
+ cap_max = np.quantile(weight, cap_max) # Get the value at the cap_max quantile
116
+
117
+ # If a minimum cap is specified, calculate the value at the specified cap_min quantile
118
+ if cap_min is not None:
119
+ cap_min = np.quantile(weight, cap_min) # Get the value at the cap_min quantile
120
+
121
+ # Cap the values in 'weight' array to not exceed the maximum cap (cap_max)
122
+ weight = np.minimum(weight, cap_max)
123
+
124
+ # Cap the values in 'weight' array to not go below the minimum cap (cap_min)
125
+ weight = np.maximum(weight, cap_min)
126
+
127
+ return weight
128
+
129
+
130
+
131
+ def read_polygons(file_path, slide_id):
132
+ """
133
+ Reads polygon data from a JSON file for a specific slide ID, extracting coordinates, colors, and thickness.
134
+
135
+ :param file_path: Path to the JSON file containing polygon configurations.
136
+ :param slide_id: Identifier for the specific slide whose polygon data is to be extracted.
137
+ :return:
138
+ - polygons: A list of numpy arrays, where each array contains the coordinates of a polygon.
139
+ - polygon_colors: A list of color values corresponding to each polygon.
140
+ - polygon_thickness: A list of thickness values for each polygon's border.
141
+ """
142
+
143
+ # Open the JSON file and load the polygon configurations into a Python dictionary
144
+ with open(file_path, 'r') as f:
145
+ polygons_configs = json.load(f)
146
+
147
+ # Check if the given slide_id exists in the polygon configurations
148
+ if slide_id not in polygons_configs:
149
+ return None, None, None # If slide_id is not found, return None for all outputs
150
+
151
+ # Extract the polygon coordinates, colors, and thicknesses for the given slide_id
152
+ polygons = [np.array(poly['coords']) for poly in polygons_configs[slide_id]] # Convert polygon coordinates to numpy arrays
153
+ polygon_colors = [poly['color'] for poly in polygons_configs[slide_id]] # Extract the color for each polygon
154
+ polygon_thickness = [poly['thickness'] for poly in polygons_configs[slide_id]] # Extract the thickness for each polygon
155
+
156
+ # Return the polygons, their colors, and their thicknesses
157
+ return polygons, polygon_colors, polygon_thickness
158
+
159
+
src/build/lib/loki/utils.py ADDED
@@ -0,0 +1,278 @@
1
+ import os
2
+ import numpy as np
3
+ import pandas as pd
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from PIL import Image
7
+ import json
8
+ import cv2
9
+ from sklearn.decomposition import PCA
10
+ from open_clip import create_model_from_pretrained, get_tokenizer
11
+
12
+
13
+
14
+ def load_model(model_path, device):
15
+ """
16
+ Loads a pretrained CoCa (CLIP-like) model, along with its preprocessing function and tokenizer,
17
+ using the specified model checkpoint.
18
+
19
+ :param model_path: File path or URL to the pretrained model checkpoint. This is passed to
20
+ `create_model_from_pretrained` as the `pretrained` argument.
21
+ :type model_path: str
22
+ :param device: The device on which to load the model (e.g., 'cpu' or 'cuda').
23
+ :type device: str or torch.device
24
+ :return: A tuple `(model, preprocess, tokenizer)` where:
25
+ - model: The loaded CoCa model.
26
+ - preprocess: A function or transform that preprocesses input data for the model.
27
+ - tokenizer: A tokenizer appropriate for textual input to the model.
28
+ :rtype: (nn.Module, callable, callable)
29
+ """
30
+ # Create the model and its preprocessing transform from the specified checkpoint
31
+ model, preprocess = create_model_from_pretrained(
32
+ "coca_ViT-L-14", device=device, pretrained=model_path
33
+ )
34
+
35
+ # Retrieve a tokenizer compatible with the "coca_ViT-L-14" architecture
36
+ tokenizer = get_tokenizer('coca_ViT-L-14')
37
+
38
+ return model, preprocess, tokenizer
39
+
40
+
41
+
42
+ def encode_image(model, preprocess, image):
43
+ """
44
+ Encodes an image into a normalized feature embedding using the specified model and preprocessing function.
45
+
46
+ :param model: A model object that provides an `encode_image` method (e.g., a CLIP or CoCa model).
47
+ :type model: torch.nn.Module
48
+ :param preprocess: A preprocessing function that transforms the input image into a tensor
49
+ suitable for the model. Typically something returning a PyTorch tensor.
50
+ :type preprocess: callable
51
+ :param image: The input image (PIL Image, NumPy array, or other format supported by `preprocess`).
52
+ :type image: PIL.Image.Image or numpy.ndarray
53
+ :return: A single normalized image embedding as a PyTorch tensor of shape (1, embedding_dim).
54
+ :rtype: torch.Tensor
55
+ """
56
+ # Preprocess the image, then stack to create a batch of size 1
57
+ image_input = torch.stack([preprocess(image)])
58
+
59
+ # Generate the image features without gradient tracking
60
+ with torch.no_grad():
61
+ image_features = model.encode_image(image_input)
62
+
63
+ # Normalize embeddings across the feature dimension (L2 normalization)
64
+ image_embeddings = F.normalize(image_features, p=2, dim=-1)
65
+
66
+ return image_embeddings
67
+
68
+
69
+
70
+ def encode_image_patches(model, preprocess, data_dir, img_list):
71
+ """
72
+ Encodes multiple image patches into normalized feature embeddings using a specified model and preprocess function.
73
+
74
+ :param model: A model object that provides an `encode_image` method (e.g., a CLIP or CoCa model).
75
+ :type model: torch.nn.Module
76
+ :param preprocess: A preprocessing function that transforms the input image into a tensor
77
+ suitable for the model. Typically something returning a PyTorch tensor.
78
+ :type preprocess: callable
79
+ :param data_dir: The base directory containing image data.
80
+ :type data_dir: str
81
+ :param img_list: A list of image filenames (strings). Each filename corresponds to a patch image
82
+ stored in `data_dir/demo_data/patch/`.
83
+ :type img_list: list[str]
84
+ :return: A PyTorch tensor of shape (N, 1, embedding_dim), containing the normalized embeddings
85
+ for each image in `img_list`.
86
+ :rtype: torch.Tensor
87
+ """
88
+
89
+ # Prepare a list to hold each image's feature embedding
90
+ image_embeddings = []
91
+
92
+ # Loop through each image name in the provided list
93
+ for img_name in img_list:
94
+ # Build the path to the patch image and open it
95
+ image_path = os.path.join(data_dir, 'demo_data', 'patch', img_name)
96
+ image = Image.open(image_path)
97
+
98
+ # Encode the image using the model & preprocess; returns shape (1, embedding_dim)
99
+ image_features = encode_image(model, preprocess, image)
100
+
101
+ # Accumulate the feature embeddings in the list
102
+ image_embeddings.append(image_features)
103
+
104
+ # Convert the list of embeddings to a NumPy array, then to a PyTorch tensor
105
+ # Resulting shape will be (N, 1, embedding_dim)
106
+ image_embeddings = torch.from_numpy(np.array(image_embeddings))
107
+
108
+ # Normalize all embeddings across the feature dimension (L2 normalization)
109
+ image_embeddings = F.normalize(image_embeddings, p=2, dim=-1)
110
+
111
+ return image_embeddings
112
+
113
+
114
+
115
+ def encode_text(model, tokenizer, text):
116
+ """
117
+ Encodes text into a normalized feature embedding using a specified model and tokenizer.
118
+
119
+ :param model: A model object that provides an `encode_text` method (e.g., a CLIP-like or CoCa model).
120
+ :type model: torch.nn.Module
121
+ :param tokenizer: A tokenizer function that converts the input text into a format suitable for `model.encode_text`.
122
+ Typically returns token IDs, attention masks, etc. as a torch.Tensor or similar structure.
123
+ :type tokenizer: callable
124
+ :param text: The input text (string or list of strings) to be encoded.
125
+ :type text: str or list[str]
126
+ :return: A PyTorch tensor of shape (batch_size, embedding_dim) containing the L2-normalized text embeddings.
127
+ :rtype: torch.Tensor
128
+ """
129
+
130
+ # Convert text to the appropriate tokenized representation
131
+ text_input = tokenizer(text)
132
+
133
+ # Run the model in no-grad mode (not tracking gradients, saving memory and compute)
134
+ with torch.no_grad():
135
+ text_features = model.encode_text(text_input)
136
+
137
+ # Normalize embeddings to unit length
138
+ text_embeddings = F.normalize(text_features, p=2, dim=-1)
139
+
140
+ return text_embeddings
141
+
142
+
143
+
144
+ def encode_text_df(model, tokenizer, df, col_name):
145
+ """
146
+ Encodes text from a specified column in a pandas DataFrame using the given model and tokenizer,
147
+ returning a PyTorch tensor of normalized text embeddings.
148
+
149
+ :param model: A model object that provides an `encode_text` method (e.g., a CLIP-like or CoCa model).
150
+ :type model: torch.nn.Module
151
+ :param tokenizer: A tokenizer function that converts the input text into a format suitable for `model.encode_text`.
152
+ :type tokenizer: callable
153
+ :param df: A pandas DataFrame from which text will be extracted.
154
+ :type df: pandas.DataFrame
155
+ :param col_name: The name of the column in `df` that contains the text to be encoded.
156
+ :type col_name: str
157
+ :return: A PyTorch tensor containing the L2-normalized text embeddings,
158
+ where the shape is (number_of_rows, embedding_dim).
159
+ :rtype: torch.Tensor
160
+ """
161
+
162
+ # Prepare a list to hold each row's text embedding
163
+ text_embeddings = []
164
+
165
+ # Loop through each index in the DataFrame
166
+ for idx in df.index:
167
+ # Retrieve text from the specified column for the current row
168
+ text = df[df.index == idx][col_name].iloc[0]  # positional access avoids KeyError on non-integer indexes
169
+
170
+ # Encode the text using the provided model and tokenizer
171
+ text_features = encode_text(model, tokenizer, text)
172
+
173
+ # Accumulate the embedding tensor
174
+ text_embeddings.append(text_features)
175
+
176
+ # Stack the list of (1, embedding_dim) tensors into shape (N, 1, embedding_dim) via NumPy, then back to torch
177
+ text_embeddings = torch.from_numpy(np.array(text_embeddings))
178
+
179
+ # Normalize embeddings to unit length across the feature dimension
180
+ text_embeddings = F.normalize(text_embeddings, p=2, dim=-1)
181
+
182
+ return text_embeddings
183
+
184
+
185
+
186
+ def get_pca_by_fit(tar_features, src_features):
187
+ """
188
+ Applies PCA to target features and transforms both target and source features using the fitted PCA model.
189
+ Combines the PCA-transformed features from both target and source datasets and returns the combined data
190
+ along with batch labels indicating the origin of each sample.
191
+
192
+ :param tar_features: Numpy array of target features (samples by features).
193
+ :param src_features: Numpy array of source features (samples by features).
194
+ :return:
195
+ - pca_comb_features: A numpy array containing PCA-transformed target and source features combined.
196
+ - pca_comb_features_batch: A numpy array of batch labels indicating which samples are from target (0) and source (1).
197
+ """
198
+
199
+ pca = PCA(n_components=3)
200
+
201
+ # Fit the PCA model on the target features (transposed to fit on features)
202
+ pca_fit_tar = pca.fit(tar_features.T)
203
+
204
+ # Transform the target and source features using the fitted PCA model
205
+ pca_tar = pca_fit_tar.transform(tar_features.T) # Transform target features
206
+ pca_src = pca_fit_tar.transform(src_features.T) # Transform source features using the same PCA fit
207
+
208
+ # Combine the PCA-transformed target and source features
209
+ pca_comb_features = np.concatenate((pca_tar, pca_src))
210
+
211
+ # Create a batch label array: 0 for target features, 1 for source features
212
+ pca_comb_features_batch = np.array([0] * len(pca_tar) + [1] * len(pca_src))
213
+
214
+ return pca_comb_features, pca_comb_features_batch
215
+
216
+
217
+
218
+ def cap_quantile(weight, cap_max=None, cap_min=None):
219
+ """
220
+ Caps the values in the 'weight' array based on the specified quantile thresholds for maximum and minimum values.
221
+ If the quantile thresholds are provided, the function will replace values above or below these thresholds
222
+ with the corresponding quantile values.
223
+
224
+ :param weight: Numpy array of weights to be capped.
225
+ :param cap_max: Quantile threshold for the maximum cap. Values above this quantile will be capped.
226
+ If None, no maximum capping will be applied.
227
+ :param cap_min: Quantile threshold for the minimum cap. Values below this quantile will be capped.
228
+ If None, no minimum capping will be applied.
229
+ :return: Numpy array with the values capped at the specified quantiles.
230
+ """
231
+
232
+ # If a maximum cap is specified, calculate the value at the specified cap_max quantile
233
+ if cap_max is not None:
234
+ cap_max = np.quantile(weight, cap_max) # Get the value at the cap_max quantile
235
+
236
+ # If a minimum cap is specified, calculate the value at the specified cap_min quantile
237
+ if cap_min is not None:
238
+ cap_min = np.quantile(weight, cap_min) # Get the value at the cap_min quantile
239
+
240
+ # Cap the values in 'weight' array to not exceed the maximum cap (cap_max)
241
+ weight = np.minimum(weight, cap_max)
242
+
243
+ # Cap the values in 'weight' array to not go below the minimum cap (cap_min)
244
+ weight = np.maximum(weight, cap_min)
245
+
246
+ return weight
247
+
248
+
249
+
250
+ def read_polygons(file_path, slide_id):
251
+ """
252
+ Reads polygon data from a JSON file for a specific slide ID, extracting coordinates, colors, and thickness.
253
+
254
+ :param file_path: Path to the JSON file containing polygon configurations.
255
+ :param slide_id: Identifier for the specific slide whose polygon data is to be extracted.
256
+ :return:
257
+ - polygons: A list of numpy arrays, where each array contains the coordinates of a polygon.
258
+ - polygon_colors: A list of color values corresponding to each polygon.
259
+ - polygon_thickness: A list of thickness values for each polygon's border.
260
+ """
261
+
262
+ # Open the JSON file and load the polygon configurations into a Python dictionary
263
+ with open(file_path, 'r') as f:
264
+ polygons_configs = json.load(f)
265
+
266
+ # Check if the given slide_id exists in the polygon configurations
267
+ if slide_id not in polygons_configs:
268
+ return None, None, None # If slide_id is not found, return None for all outputs
269
+
270
+ # Extract the polygon coordinates, colors, and thicknesses for the given slide_id
271
+ polygons = [np.array(poly['coords']) for poly in polygons_configs[slide_id]] # Convert polygon coordinates to numpy arrays
272
+ polygon_colors = [poly['color'] for poly in polygons_configs[slide_id]] # Extract the color for each polygon
273
+ polygon_thickness = [poly['thickness'] for poly in polygons_configs[slide_id]] # Extract the thickness for each polygon
274
+
275
+ # Return the polygons, their colors, and their thicknesses
276
+ return polygons, polygon_colors, polygon_thickness
277
+
278
+
src/dist/loki-0.0.1-py3-none-any.whl ADDED
Binary file (22.2 kB). View file
 
src/dist/loki-0.0.1.tar.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:98f4615e981aeb895088cb71b1f358a9d6470302043c7cab8bc15396b9cbbe0d
3
+ size 20339
src/loki.egg-info/PKG-INFO ADDED
@@ -0,0 +1,23 @@
1
+ Metadata-Version: 2.1
2
+ Name: loki
3
+ Version: 0.0.1
4
+ Summary: The Loki platform offers 5 core functions: tissue alignment, cell type decomposition, tissue annotation, image-transcriptomics retrieval, and ST gene expression prediction
5
+ Author: Weiqing Chen
6
+ Author-email: wec4005@med.cornell.edu
7
+ Classifier: Programming Language :: Python :: 3
8
+ Classifier: License :: OSI Approved :: MIT License
9
+ Classifier: Operating System :: OS Independent
10
+ Requires-Python: >=3.9
11
+ Requires-Dist: anndata==0.10.9
12
+ Requires-Dist: matplotlib==3.9.2
13
+ Requires-Dist: numpy==1.25.0
14
+ Requires-Dist: pandas==2.2.3
15
+ Requires-Dist: opencv-python==4.10.0.84
16
+ Requires-Dist: pycpd==2.0.0
17
+ Requires-Dist: torch==2.3.1
18
+ Requires-Dist: tangram-sc==1.0.4
19
+ Requires-Dist: tqdm==4.66.5
20
+ Requires-Dist: torchvision==0.18.1
21
+ Requires-Dist: open_clip_torch==2.26.1
22
+ Requires-Dist: pillow==10.4.0
23
+ Requires-Dist: ipykernel==6.29.5
src/loki.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,16 @@
1
+ README.md
2
+ setup.py
3
+ loki/__init__.py
4
+ loki/align.py
5
+ loki/annotate.py
6
+ loki/decompose.py
7
+ loki/plot.py
8
+ loki/predex.py
9
+ loki/preprocess.py
10
+ loki/retrieve.py
11
+ loki/utils.py
12
+ loki.egg-info/PKG-INFO
13
+ loki.egg-info/SOURCES.txt
14
+ loki.egg-info/dependency_links.txt
15
+ loki.egg-info/requires.txt
16
+ loki.egg-info/top_level.txt
src/loki.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
1
+
src/loki.egg-info/requires.txt ADDED
@@ -0,0 +1,13 @@
1
+ anndata==0.10.9
2
+ matplotlib==3.9.2
3
+ numpy==1.25.0
4
+ pandas==2.2.3
5
+ opencv-python==4.10.0.84
6
+ pycpd==2.0.0
7
+ torch==2.3.1
8
+ tangram-sc==1.0.4
9
+ tqdm==4.66.5
10
+ torchvision==0.18.1
11
+ open_clip_torch==2.26.1
12
+ pillow==10.4.0
13
+ ipykernel==6.29.5
src/loki.egg-info/top_level.txt ADDED
@@ -0,0 +1 @@
 
+ loki
src/loki/__init__.py ADDED
File without changes
src/loki/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (139 Bytes). View file
 
src/loki/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (146 Bytes). View file
 
src/loki/__pycache__/align.cpython-39.pyc ADDED
Binary file (17.3 kB). View file
 
src/loki/__pycache__/annotate.cpython-39.pyc ADDED
Binary file (2.99 kB). View file
 
src/loki/__pycache__/decompose.cpython-39.pyc ADDED
Binary file (4.72 kB). View file
 
src/loki/__pycache__/deconv.cpython-39.pyc ADDED
Binary file (3.52 kB). View file
 
src/loki/__pycache__/plot.cpython-39.pyc ADDED
Binary file (13.6 kB). View file
 
src/loki/__pycache__/predex.cpython-39.pyc ADDED
Binary file (904 Bytes). View file
 
src/loki/__pycache__/preprocess.cpython-39.pyc ADDED
Binary file (10.8 kB). View file
 
src/loki/__pycache__/retrieve.cpython-39.pyc ADDED
Binary file (1.38 kB). View file
 
src/loki/__pycache__/utils.cpython-39.pyc ADDED
Binary file (9.44 kB). View file
 
src/loki/align.py ADDED
@@ -0,0 +1,568 @@
1
+ import pycpd
2
+ from builtins import super
3
+ import numbers
+ from warnings import warn
4
+ import numpy as np
5
+ import cv2
6
+
7
+ class EMRegistration(object):
8
+ """
9
+ Expectation maximization point cloud registration.
10
+ Adapted from Pure Numpy Implementation of the Coherent Point Drift Algorithm:
11
+ https://github.com/siavashk/pycpd
12
+
13
+
14
+ Attributes
15
+ ----------
16
+ X: numpy array
17
+ NxD array of target points.
18
+
19
+ Y: numpy array
20
+ MxD array of source points.
21
+
22
+ TY: numpy array
23
+ MxD array of transformed source points.
24
+
25
+ sigma2: float (positive)
26
+ Initial variance of the Gaussian mixture model.
27
+
28
+ N: int
29
+ Number of target points.
30
+
31
+ M: int
32
+ Number of source points.
33
+
34
+ D: int
35
+ Dimensionality of source and target points
36
+
37
+ iteration: int
38
+ The current iteration throughout registration.
39
+
40
+ max_iterations: int
41
+ Registration will terminate once the algorithm has taken this
42
+ many iterations.
43
+
44
+ tolerance: float (positive)
45
+ Registration will terminate once the difference between
46
+ consecutive objective function values falls within this tolerance.
47
+
48
+ w: float (between 0 and 1)
49
+ Contribution of the uniform distribution to account for outliers.
50
+ Valid values span 0 (inclusive) and 1 (exclusive).
51
+
52
+ q: float
53
+ The objective function value that represents the misalignment between source
54
+ and target point clouds.
55
+
56
+ diff: float (positive)
57
+ The absolute difference between the current and previous objective function values.
58
+
59
+ P: numpy array
60
+ MxN array of probabilities.
61
+ P[m, n] represents the probability that the m-th source point
62
+ corresponds to the n-th target point.
63
+
64
+ Pt1: numpy array
65
+ Nx1 column array.
66
+ Multiplication result between the transpose of P and a column vector of all 1s.
67
+
68
+ P1: numpy array
69
+ Mx1 column array.
70
+ Multiplication result between P and a column vector of all 1s.
71
+
72
+ Np: float (positive)
73
+ The sum of all elements in P.
74
+
75
+ """
76
+
77
+ def __init__(self, X, Y, sigma2=None, max_iterations=None, tolerance=None, w=None, *args, **kwargs):
78
+ if type(X) is not np.ndarray or X.ndim != 2:
79
+ raise ValueError(
80
"The target point cloud (X) must be a 2D numpy array.")
81
+
82
+ if type(Y) is not np.ndarray or Y.ndim != 2:
83
+ raise ValueError(
84
+ "The source point cloud (Y) must be a 2D numpy array.")
85
+
86
+ if X.shape[1] != Y.shape[1]:
87
+ raise ValueError(
88
+ "Both point clouds need to have the same number of dimensions.")
89
+
90
+ if sigma2 is not None and (not isinstance(sigma2, numbers.Number) or sigma2 <= 0):
91
+ raise ValueError(
92
+ "Expected a positive value for sigma2 instead got: {}".format(sigma2))
93
+
94
+ if max_iterations is not None and (not isinstance(max_iterations, numbers.Number) or max_iterations < 0):
95
+ raise ValueError(
96
+ "Expected a positive integer for max_iterations instead got: {}".format(max_iterations))
97
+ elif isinstance(max_iterations, numbers.Number) and not isinstance(max_iterations, int):
98
+ warn("Received a non-integer value for max_iterations: {}. Casting to integer.".format(max_iterations))
99
+ max_iterations = int(max_iterations)
100
+
101
+ if tolerance is not None and (not isinstance(tolerance, numbers.Number) or tolerance < 0):
102
+ raise ValueError(
103
+ "Expected a positive float for tolerance instead got: {}".format(tolerance))
104
+
105
+ if w is not None and (not isinstance(w, numbers.Number) or w < 0 or w >= 1):
106
+ raise ValueError(
107
+ "Expected a value between 0 (inclusive) and 1 (exclusive) for w instead got: {}".format(w))
108
+
109
+ self.X = X
110
+ self.Y = Y
111
+ self.TY = Y
112
+ self.sigma2 = initialize_sigma2(X, Y) if sigma2 is None else sigma2
113
+ (self.N, self.D) = self.X.shape
114
+ (self.M, _) = self.Y.shape
115
+ self.tolerance = 0.001 if tolerance is None else tolerance
116
+ self.w = 0.0 if w is None else w
117
+ self.max_iterations = 100 if max_iterations is None else max_iterations
118
+ self.iteration = 0
119
+ self.diff = np.inf
120
+ self.q = np.inf
121
+ self.P = np.zeros((self.M, self.N))
122
+ self.Pt1 = np.zeros((self.N, ))
123
+ self.P1 = np.zeros((self.M, ))
124
+ self.PX = np.zeros((self.M, self.D))
125
+ self.Np = 0
126
+
127
+ def register(self, callback=lambda **kwargs: None):
128
+ """
129
+ Perform the EM registration.
130
+
131
+ Attributes
132
+ ----------
133
+ callback: function
134
+ A function that will be called after each iteration.
135
+ Can be used to visualize the registration process.
136
+
137
+ Returns
138
+ -------
139
+ self.TY: numpy array
140
+ MxD array of transformed source points.
141
+
142
+ registration_parameters:
143
+ Registration parameters; their contents depend on the registration method used.
144
+ """
145
+ self.transform_point_cloud()
146
+ while self.iteration < self.max_iterations and self.diff > self.tolerance:
147
+ self.iterate()
148
+ if callable(callback):
149
+ kwargs = {'iteration': self.iteration,
150
+ 'error': self.q, 'X': self.X, 'Y': self.TY}
151
+ callback(**kwargs)
152
+
153
+ return self.TY, self.get_registration_parameters()
154
+
155
+ def get_registration_parameters(self):
156
+ """
157
+ Placeholder for child classes.
158
+ """
159
+ raise NotImplementedError(
160
+ "Registration parameters should be defined in child classes.")
161
+
162
+ def update_transform(self):
163
+ """
164
+ Placeholder for child classes.
165
+ """
166
+ raise NotImplementedError(
167
+ "Updating transform parameters should be defined in child classes.")
168
+
169
+ def transform_point_cloud(self):
170
+ """
171
+ Placeholder for child classes.
172
+ """
173
+ raise NotImplementedError(
174
+ "Updating the source point cloud should be defined in child classes.")
175
+
176
+ def update_variance(self):
177
+ """
178
+ Placeholder for child classes.
179
+ """
180
+ raise NotImplementedError(
181
+ "Updating the Gaussian variance for the mixture model should be defined in child classes.")
182
+
183
+ def iterate(self):
184
+ """
185
+ Perform one iteration of the EM algorithm.
186
+ """
187
+ self.expectation()
188
+ self.maximization()
189
+ self.iteration += 1
190
+
191
+ def expectation(self):
192
+ """
193
+ Compute the expectation step of the EM algorithm.
194
+ """
195
+ P = np.sum((self.X[None, :, :] - self.TY[:, None, :])**2, axis=2) # (M, N)
196
+ P = np.exp(-P/(2*self.sigma2))
197
+ c = (2*np.pi*self.sigma2)**(self.D/2)*self.w/(1. - self.w)*self.M/self.N
198
+
199
+ den = np.sum(P, axis = 0, keepdims = True) # (1, N)
200
+ den = np.clip(den, np.finfo(self.X.dtype).eps, None) + c
201
+
202
+ self.P = np.divide(P, den)
203
+ self.Pt1 = np.sum(self.P, axis=0)
204
+ self.P1 = np.sum(self.P, axis=1)
205
+ self.Np = np.sum(self.P1)
206
+ self.PX = np.matmul(self.P, self.X)
207
+
208
+ def maximization(self):
209
+ """
210
+ Compute the maximization step of the EM algorithm.
211
+ """
212
+ self.update_transform()
213
+ self.transform_point_cloud()
214
+ self.update_variance()
215
+
216
+
217
+ class DeformableRegistration(EMRegistration):
218
+ """
219
+ Deformable registration.
220
+ Adapted from Pure Numpy Implementation of the Coherent Point Drift Algorithm:
221
+ https://github.com/siavashk/pycpd
222
+
223
+ Attributes
224
+ ----------
225
+ alpha: float (positive)
226
+ Represents the trade-off between the goodness of maximum likelihood fit and regularization.
227
+
228
+ beta: float(positive)
229
+ Width of the Gaussian kernel.
230
+
231
+ low_rank: bool
232
+ Whether to use low rank approximation.
233
+
234
+ num_eig: int
235
+ Number of eigenvectors to use in the low-rank calculation.
236
+ """
237
+
238
+ def __init__(self, alpha=None, beta=None, low_rank=False, num_eig=100, *args, **kwargs):
239
+ super().__init__(*args, **kwargs)
240
+ if alpha is not None and (not isinstance(alpha, numbers.Number) or alpha <= 0):
241
+ raise ValueError(
242
+ "Expected a positive value for regularization parameter alpha. Instead got: {}".format(alpha))
243
+
244
+ if beta is not None and (not isinstance(beta, numbers.Number) or beta <= 0):
245
+ raise ValueError(
246
+ "Expected a positive value for the width of the coherent Gaussian kernel. Instead got: {}".format(beta))
247
+
248
+ self.alpha = 2 if alpha is None else alpha
249
+ self.beta = 2 if beta is None else beta
250
+ self.W = np.zeros((self.M, self.D))
251
+ self.G = gaussian_kernel(self.Y, self.beta)
252
+ self.low_rank = low_rank
253
+ self.num_eig = num_eig
254
+ if self.low_rank is True:
255
+ self.Q, self.S = low_rank_eigen(self.G, self.num_eig)
256
+ self.inv_S = np.diag(1./self.S)
257
+ self.S = np.diag(self.S)
258
+ self.E = 0.
259
+
260
+ def update_transform(self):
261
+ """
262
+ Calculate a new estimate of the deformable transformation.
263
+ See Eq. 22 of https://arxiv.org/pdf/0905.2635.pdf.
264
+
265
+ """
266
+ if self.low_rank is False:
267
+ A = np.dot(np.diag(self.P1), self.G) + \
268
+ self.alpha * self.sigma2 * np.eye(self.M)
269
+ B = self.PX - np.dot(np.diag(self.P1), self.Y)
270
+ self.W = np.linalg.solve(A, B)
271
+
272
+ elif self.low_rank is True:
273
+ # Matlab code equivalent can be found here:
274
+ # https://github.com/markeroon/matlab-computer-vision-routines/tree/master/third_party/CoherentPointDrift
275
+ dP = np.diag(self.P1)
276
+ dPQ = np.matmul(dP, self.Q)
277
+ F = self.PX - np.matmul(dP, self.Y)
278
+
279
+ self.W = 1 / (self.alpha * self.sigma2) * (F - np.matmul(dPQ, (
280
+ np.linalg.solve((self.alpha * self.sigma2 * self.inv_S + np.matmul(self.Q.T, dPQ)),
281
+ (np.matmul(self.Q.T, F))))))
282
+ QtW = np.matmul(self.Q.T, self.W)
283
+ self.E = self.E + self.alpha / 2 * np.trace(np.matmul(QtW.T, np.matmul(self.S, QtW)))
284
+
285
+ def transform_point_cloud(self, Y=None):
286
+ """
287
+ Update a point cloud using the new estimate of the deformable transformation.
288
+
289
+ Attributes
290
+ ----------
291
+ Y: numpy array, optional
292
+ Array of points to transform - use to predict on new set of points.
293
+ Best for predicting on new points not used to run initial registration.
294
+ If None, self.Y is used.
295
+
296
+ Returns
297
+ -------
298
+ If Y is None, returns None.
299
+ Otherwise, returns the transformed Y.
300
+
301
+
302
+ """
303
+ # Zero the deformation in the feature (non-spatial) columns of W so that
+ # only the first two (spatial) coordinates are warped
+ self.W[:, 2:] = 0
304
+ if Y is not None:
305
+ G = gaussian_kernel(X=Y, beta=self.beta, Y=self.Y)
306
+ return Y + np.dot(G, self.W)
307
+ else:
308
+ if self.low_rank is False:
309
+ self.TY = self.Y + np.dot(self.G, self.W)
310
+
311
+ elif self.low_rank is True:
312
+ self.TY = self.Y + np.matmul(self.Q, np.matmul(self.S, np.matmul(self.Q.T, self.W)))
313
+ return
314
+
315
+
316
+ def update_variance(self):
317
+ """
318
+ Update the variance of the mixture model using the new estimate of the deformable transformation.
319
+ See the update rule for sigma2 in Eq. 23 of of https://arxiv.org/pdf/0905.2635.pdf.
320
+
321
+ """
322
+ qprev = self.sigma2
323
+
324
+ # The original CPD paper does not explicitly calculate the objective functional.
325
+ # This functional will include terms from both the negative log-likelihood and
326
+ # the Gaussian kernel used for regularization.
327
+ self.q = np.inf
328
+
329
+ xPx = np.dot(np.transpose(self.Pt1), np.sum(
330
+ np.multiply(self.X, self.X), axis=1))
331
+ yPy = np.dot(np.transpose(self.P1), np.sum(
332
+ np.multiply(self.TY, self.TY), axis=1))
333
+ trPXY = np.sum(np.multiply(self.TY, self.PX))
334
+
335
+ self.sigma2 = (xPx - 2 * trPXY + yPy) / (self.Np * self.D)
336
+
337
+ if self.sigma2 <= 0:
338
+ self.sigma2 = self.tolerance / 10
339
+
340
+ # Here we use the difference between the current and previous
341
+ # estimate of the variance as a proxy to test for convergence.
342
+ self.diff = np.abs(self.sigma2 - qprev)
343
+
344
+ def get_registration_parameters(self):
345
+ """
346
+ Return the current estimate of the deformable transformation parameters.
347
+
348
+
349
+ Returns
350
+ -------
351
+ self.G: numpy array
352
+ Gaussian kernel matrix.
353
+
354
+ self.W: numpy array
355
+ Deformable transformation matrix.
356
+ """
357
+ return self.G, self.W
358
+
359
+
360
+
361
+ def initialize_sigma2(X, Y):
362
+ """
363
+ Initialize the variance (sigma2).
364
+
365
+ Parameters
366
+ ----------
367
+ X: numpy array
368
+ NxD array of points for target.
369
+
370
+ Y: numpy array
371
+ MxD array of points for source.
372
+
373
+ Returns
374
+ -------
375
+ sigma2: float
376
+ Initial variance.
377
+ """
378
+ (N, D) = X.shape
379
+ (M, _) = Y.shape
380
+ diff = X[None, :, :] - Y[:, None, :]
381
+ err = diff ** 2
382
+ return np.sum(err) / (D * M * N)
383
+
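The initializer averages every pairwise squared coordinate difference over D·M·N. A minimal standalone check of that formula (re-implemented here, since the diff above is not importable as a module):

```python
import numpy as np

def initialize_sigma2(X, Y):
    # Mean of all pairwise squared coordinate differences, as in the function above
    (N, D) = X.shape
    (M, _) = Y.shape
    diff = X[None, :, :] - Y[:, None, :]
    return np.sum(diff ** 2) / (D * M * N)

X = np.array([[0.0, 0.0]])   # single target point at the origin
Y = np.array([[3.0, 4.0]])   # single source point at Euclidean distance 5
sigma2 = initialize_sigma2(X, Y)
# squared differences sum to 9 + 16 = 25; D*M*N = 2, so sigma2 = 12.5
```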
384
+
385
+
386
+ def gaussian_kernel(X, beta, Y=None):
387
+ """
388
+ Computes a Gaussian (RBF) kernel matrix between two sets of vectors.
389
+
390
+ :param X: A numpy array of shape (n_samples_X, n_features) representing the first set of vectors.
391
+ :param beta: The standard deviation parameter for the Gaussian kernel. It controls the spread of the kernel.
392
+ :param Y: An optional numpy array of shape (n_samples_Y, n_features) representing the second set of vectors.
393
+ If None, the function computes the kernel between `X` and itself (i.e., the Gram matrix).
394
+ :return: A numpy array of shape (n_samples_X, n_samples_Y) representing the Gaussian kernel matrix.
395
+ Each element (i, j) in the matrix is computed as:
396
+ `exp(-||X[i] - Y[j]||^2 / (2 * beta^2))`
397
+ """
398
+
399
+ # If Y is not provided, use X for both sets, computing the kernel matrix between X and itself
400
+ if Y is None:
401
+ Y = X
402
+
403
+ # Compute the difference tensor between each pair of vectors in X and Y
404
+ # The resulting shape is (n_samples_X, n_samples_Y, n_features)
405
+ diff = X[:, None, :] - Y[None, :, :]
406
+
407
+ # Square the differences element-wise
408
+ diff = np.square(diff)
409
+
410
+ # Sum the squared differences across the feature dimension (axis 2) to get squared Euclidean distances
411
+ # The resulting shape is (n_samples_X, n_samples_Y)
412
+ diff = np.sum(diff, axis=2)
413
+
414
+ # Apply the Gaussian (RBF) kernel formula: exp(-||X[i] - Y[j]||^2 / (2 * beta^2))
415
+ kernel_matrix = np.exp(-diff / (2 * beta**2))
416
+
417
+ return kernel_matrix
418
+
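A quick sanity check of the RBF formula above on two points one unit apart (re-implemented standalone; `beta=1` is just an illustrative value):

```python
import numpy as np

def gaussian_kernel(X, beta, Y=None):
    # exp(-||x_i - y_j||^2 / (2 * beta^2)), as in the function above
    if Y is None:
        Y = X
    sq_dists = np.sum((X[:, None, :] - Y[None, :, :]) ** 2, axis=2)
    return np.exp(-sq_dists / (2 * beta ** 2))

X = np.array([[0.0, 0.0], [1.0, 0.0]])
K = gaussian_kernel(X, beta=1.0)
# K[i, i] = exp(0) = 1 and K[0, 1] = exp(-1/2); the matrix is symmetric
```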
419
+
420
+
421
+ def low_rank_eigen(G, num_eig):
422
+ """
423
+ Calculate the top `num_eig` eigenvectors and eigenvalues of a given Gaussian matrix G.
424
+ This function is useful for dimensionality reduction or when a low-rank approximation is needed.
425
+
426
+ :param G: A square matrix (numpy array) for which the eigen decomposition is to be performed.
427
+ :param num_eig: The number of top eigenvectors and eigenvalues to return, based on the magnitude of eigenvalues.
428
+ :return: A tuple containing:
429
+ - Q: A numpy array with shape (n, num_eig) containing the top `num_eig` eigenvectors of the matrix `G`.
430
+ Each column in `Q` corresponds to an eigenvector.
431
+ - S: A numpy array of shape (num_eig,) containing the top `num_eig` eigenvalues of the matrix `G`.
432
+
433
+ """
434
+
435
+ # Perform eigen decomposition on matrix G
436
+ # `S` will contain all the eigenvalues, and `Q` will contain the corresponding eigenvectors
437
+ S, Q = np.linalg.eigh(G)
438
+
439
+ # Sort eigenvalues in descending order based on their absolute values
440
+ # Get the indices of the top `num_eig` largest eigenvalues
441
+ eig_indices = list(np.argsort(np.abs(S))[::-1][:num_eig])
442
+
443
+ # Select the corresponding top eigenvectors based on the sorted indices
444
+ Q = Q[:, eig_indices] # Q now contains the top `num_eig` eigenvectors
445
+
446
+ # Select the top `num_eig` eigenvalues based on the sorted indices
447
+ S = S[eig_indices] # S now contains the top `num_eig` eigenvalues
448
+
449
+ return Q, S
450
+
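The low-rank step keeps only the largest-magnitude eigenpairs; reconstructing from them drops the small modes. A standalone sketch on a matrix with a known spectrum:

```python
import numpy as np

def low_rank_eigen(G, num_eig):
    # Top-num_eig eigenpairs by absolute eigenvalue, as in the function above
    S, Q = np.linalg.eigh(G)
    idx = np.argsort(np.abs(S))[::-1][:num_eig]
    return Q[:, idx], S[idx]

# Symmetric matrix with eigenvalues 4.0, 1.0 and 0.1
G = np.diag([4.0, 1.0, 0.1])
Q, S = low_rank_eigen(G, num_eig=2)
# Rank-2 reconstruction keeps the 4.0 and 1.0 modes and zeroes out the 0.1 mode
G_approx = Q @ np.diag(S) @ Q.T
```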
451
+
452
+
453
+ def find_homography_translation_rotation(src_points, dst_points):
454
+ """
455
+ Find the homography between two sets of coordinates with only translation and rotation.
456
+
457
+ :param src_points: A numpy array of shape (n, 2) containing source coordinates.
458
+ :param dst_points: A numpy array of shape (n, 2) containing destination coordinates.
459
+ :return: A 3x3 homography matrix.
460
+ """
461
+ # Ensure the points are in the correct shape
462
+ assert src_points.shape == dst_points.shape
463
+ assert src_points.shape[1] == 2
464
+
465
+ # Calculate the centroids of the point sets
466
+ src_centroid = np.mean(src_points, axis=0)
467
+ dst_centroid = np.mean(dst_points, axis=0)
468
+
469
+ # Center the points around the centroids
470
+ centered_src_points = src_points - src_centroid
471
+ centered_dst_points = dst_points - dst_centroid
472
+
473
+ # Calculate the covariance matrix
474
+ H = np.dot(centered_src_points.T, centered_dst_points)
475
+
476
+ # Singular Value Decomposition (SVD)
477
+ U, S, Vt = np.linalg.svd(H)
478
+
479
+ # Calculate the rotation matrix
480
+ R = np.dot(Vt.T, U.T)
481
+
482
+ # Ensure a proper rotation matrix (det(R) = 1)
483
+ if np.linalg.det(R) < 0:
484
+ Vt[-1, :] *= -1
485
+ R = np.dot(Vt.T, U.T)
486
+
487
+ # Calculate the translation vector
488
+ t = dst_centroid - np.dot(R, src_centroid)
489
+
490
+ # Construct the homography matrix
491
+ homography_matrix = np.eye(3)
492
+ homography_matrix[0:2, 0:2] = R
493
+ homography_matrix[0:2, 2] = t
494
+
495
+ return homography_matrix
496
+
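The centroid/SVD procedure above is the Kabsch algorithm for a rigid 2D fit. A standalone check that it recovers a known rotation and translation exactly (the square and the 90-degree angle are arbitrary test values):

```python
import numpy as np

def find_homography_translation_rotation(src, dst):
    # Kabsch-style rigid fit (rotation + translation), as in the function above
    src_c, dst_c = src.mean(axis=0), dst.mean(axis=0)
    H = (src - src_c).T @ (dst - dst_c)
    U, _, Vt = np.linalg.svd(H)
    R = Vt.T @ U.T
    if np.linalg.det(R) < 0:          # guard against reflections
        Vt[-1, :] *= -1
        R = Vt.T @ U.T
    t = dst_c - R @ src_c
    M = np.eye(3)
    M[:2, :2], M[:2, 2] = R, t
    return M

# Rotate a unit square by 90 degrees and translate by (2, 3)
theta = np.pi / 2
R_true = np.array([[np.cos(theta), -np.sin(theta)],
                   [np.sin(theta),  np.cos(theta)]])
src = np.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
dst = src @ R_true.T + np.array([2.0, 3.0])
M = find_homography_translation_rotation(src, dst)
```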
497
+
498
+
499
+ def apply_homography(coordinates, H):
500
+ """
501
+ Apply a 3x3 homography matrix to 2D coordinates.
502
+
503
+ :param coordinates: A numpy array of shape (n, 2) containing 2D coordinates.
504
+ :param H: A numpy array of shape (3, 3) representing the homography matrix.
505
+ :return: A numpy array of shape (n, 2) with transformed coordinates.
506
+ """
507
+ # Convert (x, y) to homogeneous coordinates (x, y, 1)
508
+ n = coordinates.shape[0]
509
+ homogeneous_coords = np.hstack((coordinates, np.ones((n, 1))))
510
+
511
+ # Apply the homography matrix
512
+ transformed_homogeneous = np.dot(homogeneous_coords, H.T)
513
+
514
+ # Convert back from homogeneous coordinates (x', y', w') to (x'/w', y'/w')
515
+ transformed_coords = transformed_homogeneous[:, :2] / transformed_homogeneous[:, [2]]
516
+
517
+ return transformed_coords
518
+
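For a rigid homography the perspective divide is a no-op (w' stays 1), so a translation matrix simply shifts the points. A standalone check:

```python
import numpy as np

def apply_homography(coords, H):
    # Homogeneous transform followed by perspective divide, as in the function above
    n = coords.shape[0]
    homo = np.hstack((coords, np.ones((n, 1))))
    out = homo @ H.T
    return out[:, :2] / out[:, [2]]

# A pure translation by (5, -2) expressed as a 3x3 homography
H = np.array([[1.0, 0.0,  5.0],
              [0.0, 1.0, -2.0],
              [0.0, 0.0,  1.0]])
pts = np.array([[0.0, 0.0], [1.0, 1.0]])
moved = apply_homography(pts, H)
```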
519
+
520
+
521
+ def align_tissue(ad_tar_coor, ad_src_coor, pca_comb_features, src_img, alpha=0.5):
522
+ """
523
+ Aligns the source coordinates to the target coordinates using Coherent Point Drift (CPD)
524
+ registration, and applies a homography transformation to warp the source coordinates accordingly.
525
+
526
+ :param ad_tar_coor: Numpy array of target coordinates to which the source will be aligned.
527
+ :param ad_src_coor: Numpy array of source coordinates that will be aligned to the target.
528
+ :param pca_comb_features: PCA-combined feature matrix used as additional features for the alignment process.
529
+ :param src_img: Source image to be warped based on the alignment.
530
+ :param alpha: Regularization parameter for CPD registration, default is 0.5.
531
+ :return:
532
+ - cpd_coor: The new source coordinates after CPD alignment.
533
+ - homo_coor: The source coordinates after applying the homography transformation.
534
+ - aligned_image: The source image warped based on the homography transformation.
535
+ """
536
+
537
+ # Normalize target and source coordinates to the range [0, 1]
538
+ ad_tar_coor_z = (ad_tar_coor - ad_tar_coor.min()) / (ad_tar_coor.max() - ad_tar_coor.min())
539
+ ad_src_coor_z = (ad_src_coor - ad_src_coor.min()) / (ad_src_coor.max() - ad_src_coor.min())
540
+
541
+ # Normalize PCA-combined features to the range [0, 1]
542
+ pca_comb_features_z = (pca_comb_features - pca_comb_features.min()) / (pca_comb_features.max() - pca_comb_features.min())
543
+
544
+ # Concatenate spatial and PCA-combined features for target and source
545
+ target = np.concatenate((ad_tar_coor_z, pca_comb_features_z[:ad_tar_coor.shape[0], :2]), axis=1)
546
+ source = np.concatenate((ad_src_coor_z, pca_comb_features_z[ad_tar_coor.shape[0]:, :2]), axis=1)
547
+
548
+ # Initialize and run the CPD registration (deformable with regularization)
549
+ reg = DeformableRegistration(X=target, Y=source, low_rank=True,
550
+ alpha=alpha,
551
+ max_iterations=int(1e9), tolerance=1e-9)
552
+
553
+ TY = reg.register()[0] # TY contains the transformed source points
554
+
555
+ # Rescale the CPD-aligned coordinates back to the original range of target coordinates
556
+ cpd_coor = TY[:, :2] * (ad_tar_coor.max() - ad_tar_coor.min()) + ad_tar_coor.min()
557
+
558
+ # Find homography transformation based on CPD-aligned coordinates and apply it
559
+ h = find_homography_translation_rotation(ad_src_coor, cpd_coor)
560
+ homo_coor = apply_homography(ad_src_coor, h)
561
+
562
+ # Warp the source image using the computed homography
563
+ aligned_image = cv2.warpPerspective(src_img, h, (src_img.shape[1], src_img.shape[0]))
564
+
565
+ # Return the CPD-aligned coordinates, the homography-transformed coordinates, and the warped image
566
+ return cpd_coor, homo_coor, aligned_image
567
+
568
+
src/loki/annotate.py ADDED
@@ -0,0 +1,102 @@
1
+ import numpy as np
2
+ import torch
3
+ from torch.nn import functional as F
4
+ import os
5
+ import scanpy as sc
6
+ import json
7
+ import cv2
8
+
9
+
10
+
11
+ def annotate_with_bulk(img_features, bulk_features, normalize=True, T=1, tensor=False):
12
+ """
13
+ Annotates tissue image with similarity scores between image features and bulk RNA-seq features.
14
+
15
+ :param img_features: Feature matrix representing histopathology image features.
16
+ :param bulk_features: Feature vector representing bulk RNA-seq features.
17
+ :param normalize: Whether to normalize similarity scores, default=True.
18
+ :param T: Temperature parameter to control the sharpness of the softmax distribution. Higher values result in a smoother distribution.
19
+ :param tensor: Feature format in torch tensor or not, default=False.
20
+
21
+ :return: An array or tensor containing the normalized similarity scores.
22
+ """
23
+
24
+ if tensor:
25
+ # Compute similarity between image features and bulk RNA-seq features
26
+ cosine_similarity = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
27
+ similarity = cosine_similarity(img_features, bulk_features.unsqueeze(0)) # Shape: [n]
28
+
29
+ # Optional normalization using the feature vector's norm
30
+ if normalize:
31
+ normalization_factor = torch.sqrt(torch.tensor([bulk_features.shape[0]], dtype=torch.float)) # sqrt(feature_dim)
32
+ similarity = similarity / normalization_factor
33
+
34
+ # Reshape and apply temperature scaling for softmax
35
+ similarity = similarity.unsqueeze(0) # Shape: [1, n]
36
+ similarity = similarity / T # Control distribution sharpness
37
+
38
+ # Convert similarity scores to probability distribution using softmax
39
+ similarity = torch.nn.functional.softmax(similarity, dim=-1) # Shape: [1, n]
40
+
41
+ else:
42
+ # Compute similarity for non-tensor mode
43
+ similarity = np.dot(img_features.T, bulk_features)
44
+
45
+ # Apply a softmax-like normalization for numerical stability
46
+ max_similarity = np.max(similarity) # Maximum value for stability
47
+ similarity = np.exp(similarity - max_similarity) / np.sum(np.exp(similarity - max_similarity))
48
+
49
+ # Normalize similarity scores to [0, 1] range for interpretation
50
+ similarity = (similarity - np.min(similarity)) / (np.max(similarity) - np.min(similarity))
51
+
52
+ return similarity
53
+
54
+
55
+
56
+ def annotate_with_marker_genes(classes, image_embeddings, all_text_embeddings):
57
+ """
58
+ Annotates tissue image with similarity scores between image features and marker gene features.
59
+
60
+ :param classes: A list or array of tissue type labels.
61
+ :param image_embeddings: A numpy array or torch tensor of image embeddings (shape: [n_images, embedding_dim]).
62
+ :param all_text_embeddings: A numpy array or torch tensor of text embeddings of the marker genes
63
+ (shape: [n_classes, embedding_dim]).
64
+
65
+ :return:
66
+ - dot_similarity: The matrix of dot product similarities between image embeddings and text embeddings.
67
+ - pred_class: The predicted tissue type for the image based on the highest similarity score.
68
+ """
69
+
70
+ # Calculate dot product similarity between image embeddings and text embeddings
71
+ # This results in a similarity matrix of shape [n_images, n_classes]
72
+ dot_similarity = image_embeddings @ all_text_embeddings.T
73
+
74
+ # Find the class with the highest similarity for each image
75
+ # Use argmax to identify the index of the highest similarity score
76
+ pred_class = classes[dot_similarity.argmax()]
77
+
78
+ return dot_similarity, pred_class
79
+
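A minimal sketch of the dot-product classification above, with hypothetical class names and unit text embeddings (note `argmax` runs over the flattened similarity matrix, which is fine for a single image):

```python
import numpy as np

classes = np.array(["tumor", "stroma", "immune"])   # hypothetical tissue types
all_text_embeddings = np.eye(3)                     # one unit vector per class
image_embeddings = np.array([[0.1, 0.9, 0.0]])      # one image embedding

dot_similarity = image_embeddings @ all_text_embeddings.T
pred_class = classes[dot_similarity.argmax()]       # highest-scoring class wins
```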
80
+
81
+
82
+ def load_image_annotation(image_path):
83
+ """
84
+ Loads an image with annotation.
85
+
86
+ :param image_path: The file path to the image.
87
+
88
+ :return: The processed image, converted to BGR color space and of type uint8.
89
+ """
90
+
91
+ # Load the image from the specified file path using OpenCV
92
+ image = cv2.imread(image_path)
93
+
94
+ # Swap the channel order: OpenCV loads images as BGR, and COLOR_RGB2BGR performs the same B/R swap as BGR2RGB
95
+ image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
96
+
97
+ # Ensure the image is of type uint8 for proper handling in OpenCV and other image processing libraries
98
+ image = image.astype(np.uint8)
99
+
100
+ return image
101
+
102
+
src/loki/decompose.py ADDED
@@ -0,0 +1,143 @@
1
+ import pandas as pd
2
+ import tangram as tg
3
+ import numpy as np
4
+ import torch
5
+ import anndata
6
+ from sklearn.decomposition import PCA
7
+ from sklearn.neighbors import NearestNeighbors
8
+
9
+
10
+
11
+ def generate_feature_ad(ad_expr, feature_path, sc=False):
12
+ """
13
+ Generates an AnnData object with OmiCLIP text or image embeddings.
14
+
15
+ :param ad_expr: AnnData object containing metadata for the dataset.
16
+ :param feature_path: Path to the CSV file containing the features to be loaded.
17
+ :param sc: Boolean flag indicating whether to copy single-cell metadata or ST metadata. Default is False (ST).
18
+ :return: A new AnnData object with the loaded features and relevant metadata from ad_expr.
19
+ """
20
+
21
+ # Load features from the CSV file. The index should match the cells/spots in ad_expr.obs.index.
22
+ features = pd.read_csv(feature_path, index_col=0)[ad_expr.obs.index]
23
+
24
+ # Create a new AnnData object with the features, transposing them to have cells/spots as rows
25
+ feature_ad = anndata.AnnData(features.T)  # features are already indexed by ad_expr.obs.index above
26
+
27
+ # Copy relevant metadata from ad_expr based on the sc flag
28
+ if sc:
29
+ # If the data is single-cell (sc), copy the metadata from ad_expr.obs
30
+ feature_ad.obs = ad_expr.obs.copy()
31
+ else:
32
+ # If the data is spatial, copy the 'cell_num', 'spatial' info, and spatial coordinates
33
+ feature_ad.obs['cell_num'] = ad_expr.obs['cell_num'].copy()
34
+ feature_ad.uns['spatial'] = ad_expr.uns['spatial'].copy()
35
+ feature_ad.obsm['spatial'] = ad_expr.obsm['spatial'].copy()
36
+
37
+ return feature_ad
38
+
39
+
40
+
41
+ def normalize_percentile(df, cols, min_percentile=5, max_percentile=95):
42
+ """
43
+ Clips and normalizes the specified columns of a DataFrame based on percentile thresholds,
44
+ transforming their values to the [0, 1] range.
45
+
46
+ :param df: A pandas DataFrame containing the columns to normalize.
47
+ :type df: pandas.DataFrame
48
+ :param cols: A list of column names in `df` that should be normalized.
49
+ :type cols: list[str]
50
+ :param min_percentile: The lower percentile used for clipping (defaults to 5).
51
+ :type min_percentile: float
52
+ :param max_percentile: The upper percentile used for clipping (defaults to 95).
53
+ :type max_percentile: float
54
+ :return: The same DataFrame with specified columns clipped and normalized.
55
+ :rtype: pandas.DataFrame
56
+ """
57
+
58
+ # Iterate over each column that needs to be normalized
59
+ for col in cols:
60
+ # Compute the lower and upper values at the given percentiles
61
+ min_val = np.percentile(df[col], min_percentile)
62
+ max_val = np.percentile(df[col], max_percentile)
63
+
64
+ # Clip the column's values between these percentile thresholds
65
+ df[col] = np.clip(df[col], min_val, max_val)
66
+
67
+ # Perform min-max normalization to scale the clipped values to the [0, 1] range
68
+ df[col] = (df[col] - min_val) / (max_val - min_val)
69
+
70
+ return df
71
+
72
+
73
+
74
+ def cell_type_decompose(sc_ad, st_ad, cell_type_col='cell_type', NMS_mode=False, major_types=None, min_percentile=5, max_percentile=95):
75
+ """
76
+ Performs cell type decomposition on spatial data (ST or image) with single-cell data.
77
+
78
+ :param sc_ad: AnnData object containing single-cell metadata.
79
+
80
+ :param st_ad: AnnData object containing spatial data (ST or image) metadata.
81
+ :param NMS_mode: Boolean flag to apply Non-Maximum Suppression (NMS) mode. Default is False.
82
+ :param major_types: Major cell types used for NMS mode. Default is None.
83
+ :param min_percentile: The lower percentile used for clipping (defaults to 5).
84
+ :param max_percentile: The upper percentile used for clipping (defaults to 95).
85
+ :return: The spatial AnnData object with projected cell type annotations.
86
+ """
87
+
88
+ # Preprocess the data for decomposition using tangram (tg)
89
+ tg.pp_adatas(sc_ad, st_ad, genes=None) # Preprocessing: match genes between single-cell and spatial data
90
+
91
+
92
+ # Map single-cell data to spatial data using Tangram's "map_cells_to_space" function
93
+ ad_map = tg.map_cells_to_space(
94
+ sc_ad, st_ad,
95
+ mode="clusters", # Map based on clusters (cell types)
96
+ cluster_label=cell_type_col, # Column in `sc_ad.obs` representing cell type
97
+ device='cpu', # Run on CPU (or 'cuda' if GPU is available)
98
+ scale=False, # Don't scale data (can be set to True if needed)
99
+ density_prior='uniform', # Assume a uniform prior over cell densities
100
+ random_state=10, # Set random state for reproducibility
101
+ verbose=False, # Disable verbose output for cleaner logging
102
+ )
103
+
104
+ # Project cell type annotations from the single-cell data to the spatial data
105
+ tg.project_cell_annotations(ad_map, st_ad, annotation=cell_type_col)
106
+
107
+
108
+ if NMS_mode:
109
110
+ st_ad.obs = normalize_percentile(st_ad.obsm['tangram_ct_pred'], major_types, min_percentile, max_percentile)
111
+
112
+ st_ad_binary = st_ad.obsm['tangram_ct_pred'][major_types].copy()
113
+ # Retain the max value in each row and set the rest to 0
114
+ st_ad.obs[major_types] = st_ad_binary.where(st_ad_binary.eq(st_ad_binary.max(axis=1), axis=0), other=0)
115
+
116
+ return st_ad # Return the spatial AnnData object with the projected annotations
117
+
118
+
119
+
120
+ def assign_cells_to_spots(cell_locs, spot_locs, patch_size=16):
121
+ """
122
+ Assigns cells to spots based on their spatial coordinates. Each cell within the specified patch size (radius)
123
+ of a spot will be assigned to that spot.
124
+
125
+ :param cell_locs: Numpy array of shape (n_cells, 2) with the x, y coordinates of the cells.
126
+ :param spot_locs: Numpy array of shape (n_spots, 2) with the x, y coordinates of the spots.
127
+ :param patch_size: The diameter of the spot patch. The radius used for assignment will be half of this value.
128
+ :return: A sparse matrix where each row corresponds to a cell and each column corresponds to a spot.
129
+ The value is 1 if the cell is assigned to that spot, 0 otherwise.
130
+ """
131
+ # Initialize the NearestNeighbors model with a radius equal to half the patch size
132
+ neigh = NearestNeighbors(radius=patch_size * 0.5)
133
+
134
+ # Fit the model on the spot locations
135
+ neigh.fit(spot_locs)
136
+
137
+ # Create the radius neighbors graph which will assign cells to spots based on proximity
138
+ # This graph is a sparse matrix where rows are cells and columns are spots, with a 1 indicating assignment
139
+ A = neigh.radius_neighbors_graph(cell_locs, mode='connectivity')
140
+
141
+ return A
142
+
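The proximity rule above is equivalent to thresholding pairwise distances at patch_size/2. A dense NumPy sketch of the same assignment (the real function returns a scipy sparse connectivity matrix via scikit-learn; the coordinates here are toy values):

```python
import numpy as np

cell_locs = np.array([[0.0, 0.0], [7.0, 0.0], [100.0, 100.0]])
spot_locs = np.array([[0.0, 0.0], [8.0, 0.0]])
patch_size = 16                                   # assignment radius = 8

dists = np.linalg.norm(cell_locs[:, None, :] - spot_locs[None, :, :], axis=2)
A = (dists <= patch_size * 0.5).astype(int)       # 1 where a cell lies inside a spot's radius
# Rows are cells, columns are spots; the far-away cell is assigned to nothing
```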
143
+
src/loki/plot.py ADDED
@@ -0,0 +1,435 @@
1
+ import matplotlib.pyplot as plt
2
+ from pathlib import Path
3
+ import json
4
+ import cv2
5
+ from matplotlib import cm
6
+ import pandas as pd
7
+ import numpy as np
8
+ from tqdm import tqdm
9
+
10
+
11
+
12
+ def plot_alignment(ad_tar_coor, ad_src_coor, homo_coor, pca_hex_comb, tar_features, shift=300, s=0.8, boundary_line=True):
13
+ """
14
+ Plots the target coordinates and alignment of source coordinates.
15
+
16
+ :param ad_tar_coor: Numpy array of target coordinates to be plotted in the first subplot.
17
+ :param ad_src_coor: Numpy array of source coordinates to be plotted in the second subplot.
18
+ :param homo_coor: Numpy array of alignment of source coordinates to be plotted in the third subplot.
19
+ :param pca_hex_comb: Color values (e.g., PCA or hex values) for plotting the coordinates.
20
+ :param tar_features: Feature matrix for the target, used to split color values between target and source data.
21
+ :param shift: Value used to adjust the plot limits around the coordinates for better visualization. Default is 300.
22
+ :param s: Marker size for the scatter plot points. Default is 0.8.
23
+ :param boundary_line: Boolean indicating whether to draw boundary lines (horizontal and vertical lines). Default is True.
24
+ :return: Displays the alignment plot of target, source, and alignment of source coordinates.
25
+ """
26
+
27
+ # Create a figure with three subplots, adjusting size and resolution
28
+ plt.figure(figsize=(10, 3), dpi=300)
29
+
30
+ # First subplot: Plot target coordinates
31
+ plt.subplot(1, 3, 1)
32
+ plt.scatter(ad_tar_coor[:, 0], ad_tar_coor[:, 1], marker='o', s=s, c=pca_hex_comb[:len(tar_features.T)])
33
+ # Set plot limits based on the minimum and maximum target coordinates, with extra padding from 'shift'
34
+ plt.xlim([ad_tar_coor.min() - shift, ad_tar_coor.max() + shift])
35
+ plt.ylim([ad_tar_coor.min() - shift, ad_tar_coor.max() + shift])
36
+
37
+ # Second subplot: Plot source coordinates
38
+ plt.subplot(1, 3, 2)
39
+ plt.scatter(ad_src_coor[:, 0], ad_src_coor[:, 1], marker='o', s=s, c=pca_hex_comb[len(tar_features.T):])
40
+ # Ensure consistent plot limits across subplots by using the same limits as the target coordinates
41
+ plt.xlim([ad_tar_coor.min() - shift, ad_tar_coor.max() + shift])
42
+ plt.ylim([ad_tar_coor.min() - shift, ad_tar_coor.max() + shift])
43
+
44
+ # Third subplot: Plot alignment of source coordinates
45
+ plt.subplot(1, 3, 3)
46
+ plt.scatter(homo_coor[:, 0], homo_coor[:, 1], marker='o', s=s, c=pca_hex_comb[len(tar_features.T):])
47
+ # Maintain the same plot limits across all subplots for a uniform comparison
48
+ plt.xlim([ad_tar_coor.min() - shift, ad_tar_coor.max() + shift])
49
+ plt.ylim([ad_tar_coor.min() - shift, ad_tar_coor.max() + shift])
50
+
51
+ # Optionally draw boundary lines at the minimum x and y values of the target coordinates
52
+ if boundary_line:
53
+ plt.axvline(x=ad_tar_coor[:, 0].min(), color='black') # Vertical boundary line at the minimum x of target coordinates
54
+ plt.axhline(y=ad_tar_coor[:, 1].min(), color='black') # Horizontal boundary line at the minimum y of target coordinates
55
+
56
+ # Remove the axis labels and ticks from all subplots for a cleaner appearance
57
+ for ax in plt.gcf().axes:
+ ax.axis('off')
58
+
59
+ # Display the plot
60
+ plt.show()
61
+
62
+
63
+
64
+ def plot_alignment_with_img(ad_tar_coor, ad_src_coor, homo_coor, tar_img, src_img, aligned_image, pca_hex_comb, tar_features):
65
+ """
66
+ Plots the target coordinates and alignment of source coordinates with their respective images in the background.
67
+
68
+ :param ad_tar_coor: Numpy array of target coordinates to be plotted in the first and third subplots.
69
+ :param ad_src_coor: Numpy array of source coordinates to be plotted in the second subplot.
70
+ :param homo_coor: Numpy array of alignment of source coordinates to be plotted in the third subplot.
71
+ :param tar_img: Image associated with the target coordinates, used as the background in the first subplot.
72
+ :param src_img: Image associated with the source coordinates, used as the background in the second subplot.
73
+ :param aligned_image: Image associated with the aligned coordinates, used as the background in the third subplot.
74
+ :param pca_hex_comb: Color values (e.g., PCA or hex values) for plotting the coordinates.
75
+ :param tar_features: Feature matrix for the target, used to split color values between target and source data.
76
+ :return: Displays the alignment plot of target, source, and alignment of source coordinates with their associated images.
77
+ """
78
+
79
+ # Create a figure with three subplots and set the size and resolution
80
+ plt.figure(figsize=(10, 8), dpi=150)
81
+
82
+ # First subplot: Plot target coordinates with the target image as the background
83
+ plt.subplot(1, 3, 1)
84
+ # Scatter plot for the target coordinates with transparency and small marker size
85
+ plt.scatter(ad_tar_coor[:, 0], ad_tar_coor[:, 1], marker='o', alpha=0.8, s=1, c=pca_hex_comb[:len(tar_features.T)])
86
+ # Overlay the target image with some transparency (alpha = 0.3)
87
+ plt.imshow(tar_img, origin='lower', alpha=0.3)
88
+
89
+ # Second subplot: Plot source coordinates with the source image as the background
90
+ plt.subplot(1, 3, 2)
91
+ # Scatter plot for the source coordinates with transparency and small marker size
92
+ plt.scatter(ad_src_coor[:, 0], ad_src_coor[:, 1], marker='o', alpha=0.8, s=1, c=pca_hex_comb[len(tar_features.T):])
93
+ # Overlay the source image with some transparency (alpha = 0.3)
94
+ plt.imshow(src_img, origin='lower', alpha=0.3)
95
+
96
+ # Third subplot: Plot both target and alignment of source coordinates with the aligned image as the background
97
+ plt.subplot(1, 3, 3)
98
+ # Scatter plot for the target coordinates with lower opacity (alpha = 0.2)
99
+ plt.scatter(ad_tar_coor[:, 0], ad_tar_coor[:, 1], marker='o', alpha=0.2, s=1, c=pca_hex_comb[:len(tar_features.T)])
100
+ # Scatter plot for the homologous coordinates with a '+' marker and the same color mapping
101
+ plt.scatter(homo_coor[:, 0], homo_coor[:, 1], marker='+', s=1, c=pca_hex_comb[len(tar_features.T):])
102
+ # Overlay the aligned image with some transparency (alpha = 0.3)
103
+ plt.imshow(aligned_image, origin='lower', alpha=0.3)
104
+
105
+ # Turn off the axis for all subplots to give a cleaner visual output
106
+ for ax in plt.gcf().axes:
+ ax.axis('off')
107
+
108
+ # Display the plots
109
+ plt.show()
110
+
111
+
112
+
113
+ def draw_polygon(image, polygon, color='k', thickness=2):
114
+ """
115
+ Draws one or more polygons on the given image.
116
+
117
+ :param image: The image on which to draw the polygons (as a numpy array).
118
+ :param polygon: A list of polygons, where each polygon is a list of (x, y) coordinate tuples.
119
+ :param color: A string or list of strings representing the color(s) for each polygon.
120
+ If a single color is provided, it will be applied to all polygons. Default is 'k' (black).
121
+ :param thickness: An integer or a list of integers representing the thickness of the polygon borders.
122
+ If a single value is provided, it will be applied to all polygons. Default is 2.
123
+
124
+ :return: The image with the polygons drawn on it.
125
+ """
126
+
127
+ # If the provided `color` is a single value (string), convert it to a list of the same color for each polygon
128
+ if not isinstance(color, list):
129
+ color = [color] * len(polygon) # Create a list where each polygon gets the same color
130
+
131
+ # Loop through each polygon in the list, along with its corresponding color
132
+ for i, poly in enumerate(polygon):
133
+ # Get the color for the current polygon
134
+ c = color[i]
135
+
136
+ # Convert the color from a string format (e.g., 'k' or '#ff0000') to an RGB tuple
137
+ c = color_string_to_rgb(c)
138
+
139
+ # Get the thickness value for the current polygon (if a list is provided, use the corresponding value)
140
+ t = thickness[i] if isinstance(thickness, list) else thickness
141
+
142
+ # Convert the polygon coordinates to a numpy array of integers
143
+ poly = np.array(poly, np.int32)
144
+
145
+ # Reshape the polygon array to match OpenCV's expected input format: (number of points, 1, 2)
146
+ poly = poly.reshape((-1, 1, 2))
147
+
148
+ # Draw the polygon on the image using OpenCV's `cv2.polylines` function
149
+ # `isClosed=True` indicates that the polygon should be closed (start and end points are connected)
150
+ image = cv2.polylines(image, [poly], isClosed=True, color=c, thickness=t)
151
+
152
+ return image
153
+
154
+
155
+
156
+ def blend_images(image1, image2, alpha=0.5):
157
+ """
158
+ Blends two images together.
159
+
160
+ :param image1: Background image, a numpy array of shape (H, W, 3), where H is height, W is width, and 3 represents the RGB color channels.
161
+ :param image2: Foreground image, a numpy array of shape (H, W, 3), same dimensions as image1.
162
+ :param alpha: Blending factor, a float between 0 and 1. The value of alpha determines the weight of image1 in the blend,
163
+ where 0 means only image2 is shown, and 1 means only image1 is shown. Default is 0.5 (equal blending).
164
+
165
+ :return: A blended image, where each pixel is a weighted combination of the corresponding pixels from image1 and image2.
166
+ The blending is computed as: `blended = alpha * image1 + (1 - alpha) * image2`.
167
+ """
168
+
169
+ # Use cv2.addWeighted to blend the two images.
170
+ # The first image (image1) is weighted by 'alpha', and the second image (image2) is weighted by '1 - alpha'.
171
+ blended = cv2.addWeighted(image1, alpha, image2, 1 - alpha, 0)
172
+
173
+ # Return the resulting blended image.
174
+ return blended
175
+
176
+
177
+
178
+ def color_string_to_rgb(color_string):
179
+ """
180
+ Converts a color string to an RGB tuple.
181
+
182
+ :param color_string: A string representing the color. This can be in hexadecimal form (e.g., '#ff0000') or
183
+ a shorthand character for basic colors (e.g., 'k' for black, 'r' for red, etc.).
184
+ :return:
185
+ A tuple (r, g, b) representing the RGB values of the color, where each value is an integer between 0 and 255.
186
+ :raises:
187
+ ValueError: If the color string is not recognized.
188
+ """
189
+
190
+ # Remove any spaces in the color string
191
+ color_string = color_string.replace(' ', '')
192
+
193
+ # If the string starts with a '#', it's a hexadecimal color, so we remove the '#'
194
+ if color_string.startswith('#'):
195
+ color_string = color_string[1:]
196
+ else:
197
+ # Handle shorthand single-letter color codes by converting them to hex values
198
+ # 'k' -> black, 'r' -> red, 'g' -> green, 'b' -> blue, 'w' -> white
199
+ if color_string == 'k': # Black
200
+ color_string = '000000'
201
+ elif color_string == 'r': # Red
202
+ color_string = 'ff0000'
203
+ elif color_string == 'g': # Green
204
+ color_string = '00ff00'
205
+ elif color_string == 'b': # Blue
206
+ color_string = '0000ff'
207
+ elif color_string == 'w': # White
208
+ color_string = 'ffffff'
209
+ else:
210
+ # Raise an error if the color string is not recognized
211
+ raise ValueError(f"Unknown color string {color_string}")
212
+
213
+ # Convert the first two characters to the red (R) value
214
+ r = int(color_string[:2], 16)
215
+
216
+ # Convert the next two characters to the green (G) value
217
+ g = int(color_string[2:4], 16)
218
+
219
+ # Convert the last two characters to the blue (B) value
220
+ b = int(color_string[4:], 16)
221
+
222
+ # Return the RGB values as a tuple
223
+ return (r, g, b)
224
+
225
+
226
+
227
+ def plot_heatmap(
228
+ coor,
229
+ similarity,
230
+ image_path=None,
231
+ patch_size=(256, 256),
232
+ save_path=None,
233
+ downsize=32,
234
+ cmap='turbo',
235
+ smooth=False,
236
+ boxes=None,
237
+ box_color='k',
238
+ box_thickness=2,
239
+ polygons=None,
240
+ polygons_color='k',
241
+ polygons_thickness=2,
242
+ image_alpha=0.5
243
+ ):
244
+ """
245
+ Plots a heatmap overlaid on an image based on given coordinates and similarity values.
246
+
247
+ :param coor: Array of coordinates (N, 2) where N is the number of patches to place on the heatmap.
248
+ :param similarity: Array of similarity values (N,) corresponding to the coordinates. These values are mapped to colors using a colormap.
249
+ :param image_path: Path to the background image on which the heatmap will be overlaid.
250
+ :param patch_size: Size of each patch in pixels (default is 256x256).
251
+ :param save_path: Path to save the heatmap image. If None, the heatmap is returned instead of being saved.
252
+ :param downsize: Factor to downsize the image and patches for faster processing. Default is 32.
253
+ :param cmap: Colormap used to map the similarity values to colors. Default is 'turbo'.
254
+ :param smooth: Boolean to indicate if the heatmap should be smoothed. Not implemented in this version.
255
+ :param boxes: List of boxes in (x, y, w, h) format. Not used in this version.
257
+ :param box_color: Color of the boxes. Default is black ('k'). Not used in this version.
258
+ :param box_thickness: Thickness of the box outlines. Not used in this version.
258
+ :param polygons: List of polygons (N, 2) to draw on the heatmap.
259
+ :param polygons_color: Color of the polygon outlines. Default is black ('k').
260
+ :param polygons_thickness: Thickness of the polygon outlines.
261
+ :param image_alpha: Transparency value (0 to 1) for blending the heatmap with the original image. Default is 0.5.
262
+
263
+ :return:
264
+ - heatmap: The generated heatmap as a numpy array (RGB).
265
+ - image: The original image with overlaid polygons if provided.
266
+ """
267
+
268
+ # Read the background image from disk
269
+ image = cv2.imread(image_path)
270
+ image_size = (image.shape[0], image.shape[1]) # Get the size of the image
271
+ coor = [(x // downsize, y // downsize) for x, y in coor] # Downsize the coordinates for faster processing
272
+ patch_size = (patch_size[0] // downsize, patch_size[1] // downsize) # Downsize the patch size
273
+
274
+ # Convert similarity values to colors using the provided colormap
275
+ cmap = plt.get_cmap(cmap) # Get the colormap object
276
+ norm = plt.Normalize(vmin=similarity.min(), vmax=similarity.max()) # Normalize similarity values to map to color range
277
+ colors = cmap(norm(similarity)) # Convert the normalized similarity values to RGB colors
278
+
279
+ # Initialize a blank white heatmap the size of the image
280
+ heatmap = np.ones((image_size[0], image_size[1], 3)) * 255 # Start with a white background
281
+
282
+ # Place the colored patches on the heatmap according to the coordinates and patch size
283
+ for i in range(len(coor)):
284
+ x, y = coor[i]
285
+ w = colors[i][:3] * 255 # Get the RGB color for the patch, scaling from [0, 1] to [0, 255]
286
+ w = w.astype(np.uint8) # Convert the color to uint8
287
+ heatmap[y:y + patch_size[0], x:x + patch_size[1], :] = w # Place the patch on the heatmap
288
+
289
+ # If the image_alpha is greater than 0, blend the heatmap with the original image
290
+ if image_alpha > 0:
291
+ image = np.array(image)
292
+
293
+ # Pad the image if necessary to match the heatmap size
294
+ if image.shape[0] < heatmap.shape[0]:
295
+ pad = heatmap.shape[0] - image.shape[0]
296
+ image = np.pad(image, ((0, pad), (0, 0), (0, 0)), mode='constant', constant_values=255)
297
+ if image.shape[1] < heatmap.shape[1]:
298
+ pad = heatmap.shape[1] - image.shape[1]
299
+ image = np.pad(image, ((0, 0), (0, pad), (0, 0)), mode='constant', constant_values=255)
300
+
301
+ # Swap the R and B channels (cv2.imread returns BGR) so the image matches the RGB heatmap, then blend
302
+ image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
303
+ image = image.astype(np.uint8)
304
+ heatmap = heatmap.astype(np.uint8)
305
+ heatmap = blend_images(heatmap, image, alpha=image_alpha) # Blend the heatmap and the image
306
+
307
+ # If polygons are provided, draw them on the heatmap and image
308
+ if polygons is not None:
309
+ polygons = [poly // downsize for poly in polygons] # Downsize the polygon coordinates
310
+ image_polygons = draw_polygon(image, polygons, color=polygons_color, thickness=polygons_thickness) # Draw polygons on the original image
311
+ heatmap_polygons = draw_polygon(heatmap, polygons, color=polygons_color, thickness=polygons_thickness) # Draw polygons on the heatmap
312
+
313
+ return heatmap_polygons, image_polygons # Return the heatmap and image with polygons drawn on them
314
+ else:
315
+ return heatmap, image # Return the heatmap and image
316
+
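The core of `plot_heatmap` — normalizing scores and stamping colored patches onto a white canvas — can be sketched without matplotlib or OpenCV. The helper below is hypothetical (a linear gray ramp stands in for the 'turbo' colormap) and is meant only to illustrate the placement loop:

```python
import numpy as np

def stamp_heatmap(coords, scores, canvas_size, patch_size):
    # Normalize scores to [0, 1], then stamp one gray patch per coordinate
    # on a white canvas, mirroring the placement loop in plot_heatmap.
    lo, hi = scores.min(), scores.max()
    norm = (scores - lo) / (hi - lo) if hi > lo else np.zeros_like(scores)
    canvas = np.full((*canvas_size, 3), 255, dtype=np.uint8)
    ph, pw = patch_size
    for (x, y), v in zip(coords, norm):
        gray = np.uint8(255 * v)
        canvas[y:y + ph, x:x + pw] = (gray, gray, gray)
    return canvas

canvas = stamp_heatmap(coords=[(0, 0), (4, 4)],
                       scores=np.array([0.2, 0.9]),
                       canvas_size=(8, 8), patch_size=(4, 4))
print(canvas[0, 0])  # lowest score -> black patch: [0 0 0]
print(canvas[5, 5])  # highest score -> white patch: [255 255 255]
```

Note that, as in `plot_heatmap`, coordinates are (x, y) pairs while the canvas is indexed row-first (`canvas[y:..., x:...]`).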
317
+
318
+
319
+ def show_images_side_by_side(image1, image2, title1=None, title2=None):
320
+ """
321
+ Displays two images side by side in a single figure.
322
+
323
+ :param image1: The first image to display (as a numpy array).
324
+ :param image2: The second image to display (as a numpy array).
325
+ :param title1: The title for the first image. Default is None (no title).
326
+ :param title2: The title for the second image. Default is None (no title).
327
+ :return: Displays the images side by side.
328
+ """
329
+
330
+ # Create a figure with 2 subplots (1 row, 2 columns), and set the figure size
331
+ fig, ax = plt.subplots(1, 2, figsize=(15,8))
332
+
333
+ # Display the first image on the first subplot
334
+ ax[0].imshow(image1)
335
+
336
+ # Display the second image on the second subplot
337
+ ax[1].imshow(image2)
338
+
339
+ # Set the title for the first image (if provided)
340
+ ax[0].set_title(title1)
341
+
342
+ # Set the title for the second image (if provided)
343
+ ax[1].set_title(title2)
344
+
345
+ # Remove axis labels and ticks for both images to give a cleaner look
346
+ ax[0].axis('off')
347
+ ax[1].axis('off')
348
+
349
+ # Show the final figure with both images displayed side by side
350
+ plt.show()
351
+
352
+
353
+
354
+ def plot_img_with_annotation(fullres_img, roi_polygon, linewidth, xlim, ylim):
355
+ """
356
+ Plots image with polygons.
357
+
358
+ :param fullres_img: The full-resolution image to display (as a numpy array).
359
+ :param roi_polygon: A list of polygons, where each polygon is a list of (x, y) coordinate tuples.
360
+ :param linewidth: The thickness of the lines used to draw the polygons.
361
+ :param xlim: A tuple (xmin, xmax) defining the x-axis limits for zooming in on a specific region of the image.
362
+ :param ylim: A tuple (ymin, ymax) defining the y-axis limits for zooming in on a specific region of the image.
363
+ :return: Displays the image with ROI polygons overlaid.
364
+ """
365
+
366
+ # Create a new figure with a fixed size for displaying the image and annotations
367
+ plt.figure(figsize=(10, 10))
368
+
369
+ # Display the full-resolution image
370
+ plt.imshow(fullres_img)
371
+
372
+ # Loop through each polygon in roi_polygon and plot them on the image
373
+ for polygon in roi_polygon:
374
+ x, y = zip(*polygon) # Unzip the list of (x, y) tuples into separate x and y coordinate lists
375
+ plt.plot(x, y, color='black', linewidth=linewidth) # Plot the polygon using the specified linewidth
376
+
377
+ # Set the x-axis limits based on the provided tuple (xlim)
378
+ plt.xlim(xlim)
379
+
380
+ # Set the y-axis limits based on the provided tuple (ylim)
381
+ plt.ylim(ylim)
382
+
383
+ # Invert the y-axis to match the typical image display convention (origin at the top-left)
384
+ plt.gca().invert_yaxis()
385
+
386
+ # Turn off the axis to give a cleaner image display without ticks or labels
387
+ plt.axis('off')
388
+
389
+
390
+
391
+ def plot_annotation_heatmap(st_ad, roi_polygon, s, linewidth, xlim, ylim):
392
+ """
393
+ Plots tissue type annotation heatmap.
394
+
395
+ :param st_ad: AnnData object containing coordinates in `obsm['spatial']`
396
+ and similarity scores in `obs['bulk_simi']`.
397
+ :param roi_polygon: A list of polygons, where each polygon is a list of (x, y) coordinate tuples.
398
+ :param s: The size of the scatter plot markers representing each spatial transcriptomics spot.
399
+ :param linewidth: The thickness of the lines used to draw the polygons.
400
+ :param xlim: A tuple (xmin, xmax) defining the x-axis limits for zooming in on a specific region of the image.
401
+ :param ylim: A tuple (ymin, ymax) defining the y-axis limits for zooming in on a specific region of the image.
402
+ :return: Displays the heatmap with polygons overlaid.
403
+ """
404
+
405
+ # Create a new figure with a fixed size for displaying the heatmap and annotations
406
+ plt.figure(figsize=(10, 10))
407
+
408
+ # Scatter plot for the spatial transcriptomics data.
409
+ # The 'spatial' coordinates are plotted with color intensity based on 'bulk_simi' values.
410
+ plt.scatter(
411
+ st_ad.obsm['spatial'][:, 0], st_ad.obsm['spatial'][:, 1], # x and y coordinates
412
+ c=st_ad.obs['bulk_simi'], # Color values based on 'bulk_simi'
413
+ s=s, # Size of each marker
414
+ vmin=0.1, vmax=0.95, # Set the range for the color normalization
415
+ cmap='turbo' # Use the 'turbo' colormap for the heatmap
416
+ )
417
+
418
+ # Loop through each polygon in roi_polygon and plot them on the image
419
+ for polygon in roi_polygon:
420
+ x, y = zip(*polygon) # Unzip the list of (x, y) tuples into separate x and y coordinate lists
421
+ plt.plot(x, y, color='black', linewidth=linewidth) # Plot the polygon using the specified linewidth
422
+
423
+ # Set the x-axis limits based on the provided tuple (xlim)
424
+ plt.xlim(xlim)
425
+
426
+ # Set the y-axis limits based on the provided tuple (ylim)
427
+ plt.ylim(ylim)
428
+
429
+ # Invert the y-axis to match the typical image display convention (origin at the top-left)
430
+ plt.gca().invert_yaxis()
431
+
432
+ # Turn off the axis to give a cleaner image display without ticks or labels
433
+ plt.axis('off')
434
+
435
+
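The color and blending helpers above can be exercised without OpenCV. The sketch below mirrors `color_string_to_rgb` and replaces `cv2.addWeighted` with the equivalent NumPy arithmetic — a minimal re-implementation for illustration, not the module's API:

```python
import numpy as np

def color_string_to_rgb(color_string):
    """Convert 'k'/'r'/'g'/'b'/'w' or '#rrggbb' to an (r, g, b) tuple."""
    shorthand = {'k': '000000', 'r': 'ff0000', 'g': '00ff00',
                 'b': '0000ff', 'w': 'ffffff'}
    color_string = color_string.replace(' ', '')
    if color_string.startswith('#'):
        color_string = color_string[1:]
    elif color_string in shorthand:
        color_string = shorthand[color_string]
    else:
        raise ValueError(f"Unknown color string {color_string}")
    return tuple(int(color_string[i:i + 2], 16) for i in (0, 2, 4))

def blend(image1, image2, alpha=0.5):
    """NumPy equivalent of cv2.addWeighted(image1, alpha, image2, 1 - alpha, 0)."""
    out = alpha * image1.astype(np.float64) + (1 - alpha) * image2.astype(np.float64)
    return np.clip(np.rint(out), 0, 255).astype(np.uint8)

red = np.full((4, 4, 3), color_string_to_rgb('r'), dtype=np.uint8)
white = np.full((4, 4, 3), color_string_to_rgb('w'), dtype=np.uint8)
mix = blend(red, white, alpha=0.5)
print(mix[0, 0])  # equal-weight blend of red and white: [255 128 128]
```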
src/loki/predex.py ADDED
@@ -0,0 +1,25 @@
1
+ import pandas as pd
2
+
3
+
4
+
5
+ def predict_st_gene_expr(image_text_similarity, train_data):
6
+ """
7
+ Predicts ST gene expression by H&E image.
8
+
9
+ :param image_text_similarity: Numpy array of similarities between image and text features (shape: [n_samples, n_references]).
10
+ :param train_data: Numpy array or DataFrame of reference expression profiles used for prediction (shape: [n_references, n_genes]).
11
+ :return: Numpy array or DataFrame containing the predicted gene expression levels for the samples.
12
+ """
13
+
14
+ # Compute the weighted sum of the train_data using image_text_similarity
15
+ weighted_sum = image_text_similarity @ train_data
16
+
17
+ # Compute the normalization factor (sum of the image-text similarities for each sample)
18
+ weights = image_text_similarity.sum(axis=1, keepdims=True)
19
+
20
+ # Normalize the predicted matrix to get weighted gene expression predictions
21
+ predicted_image_text_matrix = weighted_sum / weights
22
+
23
+ return predicted_image_text_matrix
24
+
25
+
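The weighted-average prediction above reduces to two matrix operations. A self-contained sketch with toy numbers (the arrays are invented for illustration):

```python
import numpy as np

def predict_st_gene_expr(image_text_similarity, train_data):
    # Weighted average of reference expression profiles,
    # where each sample's weights are its image-text similarities.
    weighted_sum = image_text_similarity @ train_data
    weights = image_text_similarity.sum(axis=1, keepdims=True)
    return weighted_sum / weights

similarity = np.array([[1.0, 1.0], [0.0, 2.0]])   # 2 samples x 2 references
reference = np.array([[1.0, 0.0], [3.0, 4.0]])    # 2 references x 2 genes
pred = predict_st_gene_expr(similarity, reference)
print(pred)  # sample 1 averages both references; sample 2 copies reference 2
```

Since each row of weights sums out of the product, every predicted row is a convex combination of the reference rows.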
src/loki/preprocess.py ADDED
@@ -0,0 +1,324 @@
1
+ import scanpy as sc
2
+ import numpy as np
3
+ import pandas as pd
4
+ import json
5
+ import os
6
+ from PIL import Image
+ import logging
+
+ logger = logging.getLogger(__name__)
7
+
8
+
9
+
10
+ def generate_gene_df(ad, house_keeping_genes, todense=True):
11
+ """
12
+ Generates a DataFrame with the top 50 genes for each observation in an AnnData object.
13
+ It removes genes containing '.' or '-' in their names, as well as genes listed in
14
+ the provided `house_keeping_genes` DataFrame/Series under the 'genesymbol' column.
15
+
16
+ :param ad: An AnnData object containing gene expression data.
17
+ :type ad: anndata.AnnData
18
+ :param house_keeping_genes: DataFrame or Series with a 'genesymbol' column listing housekeeping genes to exclude.
19
+ :type house_keeping_genes: pandas.DataFrame or pandas.Series
20
+ :param todense: Whether to convert the sparse matrix (ad.X) to a dense matrix before creating a DataFrame.
21
+ :type todense: bool
22
+ :return: A DataFrame (`top_k_genes_str`) that contains a 'label' column. Each row in 'label' is a string
23
+ with the top 50 gene names (space-separated) for that observation.
24
+ :rtype: pandas.DataFrame
25
+ """
26
+
27
+ # Remove genes containing '.' in their names
28
+ ad = ad[:, ~ad.var.index.str.contains('.', regex=False)]
29
+ # Remove genes containing '-'
30
+ ad = ad[:, ~ad.var.index.str.contains('-', regex=False)]
31
+ # Exclude housekeeping genes
32
+ ad = ad[:, ~ad.var.index.isin(house_keeping_genes['genesymbol'])]
33
+
34
+ # Convert to dense if requested; otherwise use the data as-is
35
+ if todense:
36
+ expr = pd.DataFrame(ad.X.todense(), index=ad.obs.index, columns=ad.var.index)
37
+ else:
38
+ expr = pd.DataFrame(ad.X, index=ad.obs.index, columns=ad.var.index)
39
+
40
+ # For each row (observation), find the top 50 genes with the highest expression
41
+ top_k_genes = expr.apply(lambda s, n: pd.Series(s.nlargest(n).index), axis=1, n=50)
42
+
43
+ # Create a new DataFrame to store the labels (space-separated top gene names)
44
+ top_k_genes_str = pd.DataFrame()
45
+ top_k_genes_str['label'] = top_k_genes[top_k_genes.columns].astype(str) \
46
+ .apply(lambda x: ' '.join(x), axis=1)
47
+
48
+ return top_k_genes_str
49
+
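The per-row top-k step in `generate_gene_df` can be demonstrated on a toy expression table (gene names and values are invented; k is reduced from 50 to 2 for the demo):

```python
import pandas as pd

# Tiny expression matrix: rows are spots, columns are genes.
expr = pd.DataFrame(
    {'GATA3': [5.0, 0.0], 'KRT8': [3.0, 1.0], 'CD3E': [1.0, 9.0]},
    index=['spot1', 'spot2'],
)
k = 2
# For each row, take the names of the k highest-expressed genes...
top_k = expr.apply(lambda s: pd.Series(s.nlargest(k).index), axis=1)
# ...and join them into one space-separated label string per spot.
labels = top_k.astype(str).apply(' '.join, axis=1)
print(labels['spot1'])  # 'GATA3 KRT8'
print(labels['spot2'])  # 'CD3E KRT8'
```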
50
+
51
+
52
+ def segment_patches(img_array, coord, patch_dir, height=20, width=20):
53
+ """
54
+ Extracts small image patches centered at specified coordinates and saves them as individual PNG files.
55
+
56
+ :param img_array: A NumPy array representing the full-resolution image. Shape is expected to be (H, W[, C]).
57
+ :type img_array: numpy.ndarray
58
+ :param coord: A pandas DataFrame containing patch center coordinates in columns "pixel_x" and "pixel_y".
59
+ The index corresponds to spot IDs. Example columns: ["pixel_x", "pixel_y"].
60
+ :type coord: pandas.DataFrame
61
+ :param patch_dir: Directory path where the patch images will be saved.
62
+ :type patch_dir: str
63
+ :param height: The patch's height in pixels (distance in the y-direction).
64
+ :type height: int
65
+ :param width: The patch's width in pixels (distance in the x-direction).
66
+ :type width: int
67
+ :return: None. The function saves image patches to `patch_dir` but does not return anything.
68
+ """
69
+
70
+ # Ensure the output directory exists; create it if it doesn't
71
+ if not os.path.exists(patch_dir):
72
+ os.makedirs(patch_dir)
73
+
74
+ # Extract the overall height and width of the image
75
+ yrange, xrange = img_array.shape[:2]
76
+
77
+ # Iterate through each coordinate in the DataFrame
78
+ for spot_idx in coord.index:
79
+ # Retrieve the center coordinates; note that "pixel_x" is used as the row (y) position and "pixel_y" as the column (x) position here
80
+ ycenter, xcenter = coord.loc[spot_idx, ["pixel_x", "pixel_y"]]
81
+
82
+ # Compute the top-left (x1, y1) and bottom-right (x2, y2) boundaries of the patch
83
+ x1 = round(xcenter - width / 2)
84
+ y1 = round(ycenter - height / 2)
85
+ x2 = x1 + width
86
+ y2 = y1 + height
87
+
88
+ # Check if the patch boundaries go outside the image
89
+ if x1 < 0 or y1 < 0 or x2 > xrange or y2 > yrange:
90
+ print(f"Patch {spot_idx} is out of range and will be skipped.")
91
+ continue
92
+
93
+ # Extract the patch and convert to a PIL Image; cast to uint8 if needed
94
+ patch_img = Image.fromarray(img_array[y1:y2, x1:x2].astype(np.uint8))
95
+
96
+ # Create a filename for the patch image (e.g., "0_hires.png")
97
+ patch_name = f"{spot_idx}_hires.png"
98
+ patch_path = os.path.join(patch_dir, patch_name)
99
+
100
+ # Save the patch image to disk
101
+ patch_img.save(patch_path)
102
+
103
+
104
+
105
+ def read_gct(file_path):
106
+ """
107
+ Reads a GCT file, parses its dimensions, and returns the data as a pandas DataFrame.
108
+
109
+ :param file_path: The path to the GCT file to be read.
110
+ :return: A pandas DataFrame containing the GCT data, where the first two columns represent gene names and descriptions,
111
+ and the subsequent columns contain the expression data.
112
+ """
113
+
114
+ # Open the GCT file for reading
115
+ with open(file_path, 'r') as file:
116
+ # Read and ignore the first line (GCT version line)
117
+ file.readline()
118
+
119
+ # Read the second line which contains the dimensions of the data matrix
120
+ dims = file.readline().strip().split() # Split the dimensions line by whitespace
121
+ num_rows = int(dims[0]) # Number of data rows (genes)
122
+ num_cols = int(dims[1]) # Number of data columns (samples + metadata)
123
+
124
+ # Read the data starting from the third line, using pandas for tab-delimited data
125
+ # The first two columns in GCT files are "Name" and "Description" (gene identifiers and annotations)
126
+ data = pd.read_csv(file, sep='\t', header=0, nrows=num_rows)
127
+
128
+ # Return the loaded data as a pandas DataFrame
129
+ return data
130
+
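The GCT parsing logic above can be checked against a tiny hand-written file. This sketch repeats the same readline/read_csv steps on a temporary file (the example data is invented):

```python
import os
import tempfile
import pandas as pd

gct_text = (
    "#1.2\n"            # version line (skipped)
    "2\t2\n"            # dimensions: 2 genes, 2 samples
    "Name\tDescription\tS1\tS2\n"
    "GENE1\tna\t1.0\t2.0\n"
    "GENE2\tna\t3.0\t4.0\n"
)
with tempfile.NamedTemporaryFile('w', suffix='.gct', delete=False) as fh:
    fh.write(gct_text)
    path = fh.name

with open(path) as f:
    f.readline()                                   # skip the version line
    n_rows, n_cols = map(int, f.readline().split()[:2])
    df = pd.read_csv(f, sep='\t', header=0, nrows=n_rows)
os.remove(path)
print(df.shape)  # (2, 4): Name, Description, and 2 sample columns
```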
131
+
132
+
133
+ def get_library_id(adata):
134
+ """
135
+ Retrieves the library ID from the AnnData object, assuming it contains spatial data.
136
+ The function will return the first library ID found in `adata.uns['spatial']`.
137
+
138
+ :param adata: AnnData object containing spatial information in `adata.uns['spatial']`.
139
+ :return: The first library ID found in `adata.uns['spatial']`.
140
+ :raises:
141
+ AssertionError: If 'spatial' is not present in `adata.uns`.
142
+ Logs an error if no library ID is found.
143
+ """
144
+
145
+ # Check if 'spatial' is present in adata.uns; raises an error if not found
146
+ assert 'spatial' in adata.uns, "spatial not present in adata.uns"
147
+
148
+ # Retrieve the list of library IDs (which are keys in the 'spatial' dictionary)
149
+ library_ids = adata.uns['spatial'].keys()
150
+
151
+ try:
152
+ # Attempt to return the first library ID (converting the keys object to a list)
153
+ library_id = list(library_ids)[0]
154
+ return library_id
155
+ except IndexError:
156
+ # If no library IDs exist, log an error message
157
+ logger.error('No library_id found in adata')
158
+
159
+
160
+
161
+ def get_scalefactors(adata, library_id=None):
162
+ """
163
+ Retrieves the scalefactors from the AnnData object for a given library ID. If no library ID is provided,
164
+ the function will automatically retrieve the first available library ID.
165
+
166
+ :param adata: AnnData object containing spatial data and scalefactors in `adata.uns['spatial']`.
167
+ :param library_id: The library ID for which the scalefactors are to be retrieved. If not provided, it defaults to the first available ID.
168
+ :return: A dictionary containing scalefactors for the specified library ID.
169
+ """
170
+
171
+ # If no library_id is provided, retrieve the first available library ID
172
+ if library_id is None:
173
+ library_id = get_library_id(adata)
174
+
175
+ try:
176
+ # Attempt to retrieve the scalefactors for the specified library ID
177
+ scalef = adata.uns['spatial'][library_id]['scalefactors']
178
+ return scalef
179
+ except KeyError:
180
+ # Log an error if the scalefactors or library ID is not found
181
+ logger.error('scalefactors not found in adata')
182
+
183
+
184
+
185
+ def get_spot_diameter_in_pixels(adata, library_id=None):
186
+ """
187
+ Retrieves the spot diameter in pixels from the AnnData object's scalefactors for a given library ID.
188
+ If no library ID is provided, the function will automatically retrieve the first available library ID.
189
+
190
+ :param adata: AnnData object containing spatial data and scalefactors in `adata.uns['spatial']`.
191
+ :param library_id: The library ID for which the spot diameter is to be retrieved. If not provided, defaults to the first available ID.
192
+
193
+ :return: The spot diameter in full resolution pixels, or None if not found.
194
+ """
195
+
196
+ # Get the scalefactors for the specified or default library ID
197
+ scalef = get_scalefactors(adata, library_id=library_id)
198
+
199
+ try:
200
+ # Attempt to retrieve the spot diameter in full resolution from the scalefactors
201
+ spot_diameter = scalef['spot_diameter_fullres']
202
+ return spot_diameter
203
+ except TypeError:
204
+ # Handle case where `scalef` is None or invalid (if get_scalefactors returned None)
205
+ pass
206
+ except KeyError:
207
+ # Log an error if the 'spot_diameter_fullres' key is not found in the scalefactors
208
+ logger.error('spot_diameter_fullres not found in adata')
209
+
210
+
211
+
+ def prepare_data_for_alignment(data_path, scale_type='tissue_hires_scalef'):
+     """
+     Prepares data for alignment by reading an AnnData object and preparing the high-resolution tissue image.
+
+     :param data_path: The path to the AnnData (.h5ad) file containing the Visium data.
+     :param scale_type: The type of scale factor to use (`tissue_hires_scalef` by default).
+
+     :return:
+         - ad: AnnData object containing the spatial transcriptomics data.
+         - ad_coor: Numpy array of scaled spatial coordinates (adjusted for the specified resolution).
+         - img: High-resolution tissue image, normalized to 8-bit unsigned integers.
+
+     :raises:
+         ValueError: If required data (e.g., scale factors, spatial coordinates, or images) is missing.
+     """
+
+     # Load the AnnData object from the specified file path
+     ad = sc.read_h5ad(data_path)
+
+     # Ensure the variable (gene) names are unique to avoid potential conflicts
+     ad.var_names_make_unique()
+
+     try:
+         # Retrieve the specified scale factor for spatial coordinates
+         scalef = get_scalefactors(ad)[scale_type]
+     except (KeyError, TypeError):
+         # TypeError covers the case where get_scalefactors returned None
+         raise ValueError(f"Scale factor '{scale_type}' not found in ad.uns['spatial']")
+
+     # Scale the spatial coordinates using the specified scale factor
+     try:
+         ad_coor = np.array(ad.obsm['spatial']) * scalef
+     except KeyError:
+         raise ValueError("Spatial coordinates not found in ad.obsm['spatial']")
+
+     # Retrieve the high-resolution tissue image
+     try:
+         img = ad.uns['spatial'][get_library_id(ad)]['images']['hires']
+     except KeyError:
+         raise ValueError("High-resolution image not found in ad.uns['spatial']")
+
+     # If the image values are normalized to [0, 1], convert to 8-bit format for compatibility
+     if img.max() < 1.1:
+         img = (img * 255).astype('uint8')
+
+     return ad, ad_coor, img
+
+
+
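The [0, 1] → uint8 conversion at the end of `prepare_data_for_alignment` can be checked in isolation. A minimal NumPy sketch with a made-up 2×2 image; the `max() < 1.1` heuristic is the same one the function uses:

```python
import numpy as np

# A small float image with values in [0, 1], as hires tissue images are often stored
img = np.array([[0.0, 0.5], [0.25, 1.0]], dtype=np.float32)

# Same heuristic as prepare_data_for_alignment: treat max < 1.1 as normalized
if img.max() < 1.1:
    img = (img * 255).astype('uint8')

print(img.dtype)   # uint8
print(img.max())   # 255
```

Note that `astype('uint8')` truncates rather than rounds, so 0.5 maps to 127, not 128.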
+ def load_data_for_annotation(st_data_path, json_path, in_tissue=True):
+     """
+     Loads spatial transcriptomics (ST) data from an .h5ad file and prepares it for annotation.
+
+     :param st_data_path: Path to the .h5ad file containing the spatial transcriptomics data.
+     :param json_path: Path to the JSON file containing the region of interest (ROI) polygon.
+     :param in_tissue: Boolean flag to filter the data to include only spots that are in tissue. Default is True.
+
+     :return:
+         - st_ad: AnnData object containing the spatial transcriptomics data, with spatial coordinates in `obs`.
+         - library_id: The library ID associated with the spatial data.
+         - roi_polygon: Region of interest polygon loaded from a JSON file for further annotation or analysis.
+     """
+
+     # Load the spatial transcriptomics data into an AnnData object
+     st_ad = sc.read_h5ad(st_data_path)
+
+     # Optionally filter the data to include only spots that are within the tissue
+     if in_tissue:
+         st_ad = st_ad[st_ad.obs['in_tissue'] == 1]
+
+     # Initialize pixel coordinates for spatial information
+     st_ad.obs[["pixel_y", "pixel_x"]] = None  # Ensure the columns exist
+     st_ad.obs[["pixel_y", "pixel_x"]] = st_ad.obsm['spatial']  # Copy spatial coordinates into obs
+
+     # Retrieve the library ID associated with the spatial data
+     library_id = get_library_id(st_ad)
+
+     # Load the region of interest (ROI) polygon from a JSON file
+     with open(json_path) as f:
+         roi_polygon = json.load(f)
+
+     return st_ad, library_id, roi_polygon
+
+
+
+ def read_polygons(file_path, slide_id):
+     """
+     Reads polygon data from a JSON file for a specific slide ID, extracting coordinates, colors, and thickness.
+
+     :param file_path: Path to the JSON file containing polygon configurations.
+     :param slide_id: Identifier for the specific slide whose polygon data is to be extracted.
+     :return:
+         - polygons: A list of numpy arrays, where each array contains the coordinates of a polygon.
+         - polygon_colors: A list of color values corresponding to each polygon.
+         - polygon_thickness: A list of thickness values for each polygon's border.
+     """
+
+     # Open the JSON file and load the polygon configurations into a Python dictionary
+     with open(file_path, 'r') as f:
+         polygons_configs = json.load(f)
+
+     # Check if the given slide_id exists in the polygon configurations
+     if slide_id not in polygons_configs:
+         return None, None, None  # If slide_id is not found, return None for all outputs
+
+     # Extract the polygon coordinates, colors, and thicknesses for the given slide_id
+     polygons = [np.array(poly['coords']) for poly in polygons_configs[slide_id]]  # Convert coordinates to numpy arrays
+     polygon_colors = [poly['color'] for poly in polygons_configs[slide_id]]  # Extract the color for each polygon
+     polygon_thickness = [poly['thickness'] for poly in polygons_configs[slide_id]]  # Extract the thickness for each polygon
+
+     # Return the polygons, their colors, and their thicknesses
+     return polygons, polygon_colors, polygon_thickness
+
+
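The JSON layout `read_polygons` expects can be exercised end to end. A minimal sketch with a made-up slide ID and polygon values, reproducing the same extraction logic:

```python
import json
import tempfile

import numpy as np

# Build a minimal polygon-config file in the expected layout:
# {slide_id: [{"coords": [...], "color": ..., "thickness": ...}, ...]}
config = {
    "slide_A": [
        {"coords": [[0, 0], [10, 0], [10, 10]], "color": [255, 0, 0], "thickness": 2}
    ]
}

with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump(config, f)
    path = f.name

with open(path) as f:
    polygons_configs = json.load(f)

# Same extraction logic as read_polygons
polygons = [np.array(p["coords"]) for p in polygons_configs["slide_A"]]
colors = [p["color"] for p in polygons_configs["slide_A"]]

print(polygons[0].shape)  # (3, 2)
```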
src/loki/requirements.txt ADDED
@@ -0,0 +1,14 @@
+ anndata==0.10.9
+ matplotlib==3.9.2
+ numpy==1.25.0
+ pandas==2.2.3
+ opencv-python==4.10.0.84
+ pycpd==2.0.0
+ torch==2.3.1
+ tangram-sc==1.0.4
+ tqdm==4.66.5
+ torchvision==0.18.1
+ open_clip_torch==2.26.1
+ pillow==10.4.0
+ ipykernel==6.29.5
+
src/loki/retrieve.py ADDED
@@ -0,0 +1,28 @@
+ import torch
+
+
+ def retrieve_st_by_image(image_embeddings, all_text_embeddings, dataframe, k=3):
+     """
+     Retrieves the top-k most similar ST samples based on the similarity between ST embeddings and image embeddings.
+
+     :param image_embeddings: A numpy array or torch tensor containing image embeddings (shape: [1, embedding_dim]).
+     :param all_text_embeddings: A numpy array or torch tensor containing ST embeddings (shape: [n_samples, embedding_dim]).
+     :param dataframe: A pandas DataFrame containing information about the ST samples, specifically the image indices in the 'img_idx' column.
+     :param k: The number of top similar samples to retrieve. Default is 3.
+     :return: A list of the filenames or indices corresponding to the top-k similar samples.
+     """
+
+     # Compute the dot product (similarity) between the image embeddings and all ST embeddings
+     dot_similarity = image_embeddings @ all_text_embeddings.T
+
+     # Retrieve the top-k most similar samples by similarity score (dot product)
+     values, indices = torch.topk(dot_similarity.squeeze(0), k)
+
+     # Extract the image filenames or indices from the DataFrame based on the top-k matches
+     image_filenames = dataframe['img_idx'].values
+     matches = [image_filenames[idx] for idx in indices]
+
+     return matches
+
+
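Because the embeddings are L2-normalized, the dot product above equals cosine similarity. The retrieval logic can be sketched without the model, using a NumPy stand-in for `torch.topk` and made-up embeddings:

```python
import numpy as np

rng = np.random.default_rng(0)

# One normalized query embedding and five normalized "ST" embeddings (toy data)
query = rng.normal(size=(1, 8))
query /= np.linalg.norm(query)
bank = rng.normal(size=(5, 8))
bank /= np.linalg.norm(bank, axis=1, keepdims=True)

# Dot product of unit vectors == cosine similarity
sims = (query @ bank.T).squeeze(0)

# Top-k indices, highest similarity first (NumPy equivalent of torch.topk)
k = 3
top_idx = np.argsort(sims)[::-1][:k]

print(len(top_idx))  # 3
```

The `top_idx` array plays the role of `indices` in `retrieve_st_by_image`, which then looks up the corresponding `img_idx` values.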
src/loki/utils.py ADDED
@@ -0,0 +1,278 @@
+ import os
+ import numpy as np
+ import pandas as pd
+ import torch
+ import torch.nn.functional as F
+ from PIL import Image
+ import json
+ import cv2
+ from sklearn.decomposition import PCA
+ from open_clip import create_model_from_pretrained, get_tokenizer
+
+
+ def load_model(model_path, device):
+     """
+     Loads a pretrained OmiCLIP model, along with its preprocessing function and tokenizer,
+     using the specified model checkpoint.
+
+     :param model_path: File path to the pretrained model checkpoint. This is passed to
+         `create_model_from_pretrained` as the `pretrained` argument.
+     :type model_path: str
+     :param device: The device on which to load the model (e.g., 'cpu' or 'cuda').
+     :type device: str or torch.device
+     :return: A tuple `(model, preprocess, tokenizer)` where:
+         - model: The loaded OmiCLIP model.
+         - preprocess: A function or transform that preprocesses input data for the model.
+         - tokenizer: A tokenizer appropriate for textual input to the model.
+     :rtype: (nn.Module, callable, callable)
+     """
+     # Create the model and its preprocessing transform from the specified checkpoint
+     model, preprocess = create_model_from_pretrained(
+         "coca_ViT-L-14", device=device, pretrained=model_path
+     )
+
+     # Retrieve a tokenizer compatible with the "coca_ViT-L-14" architecture
+     tokenizer = get_tokenizer('coca_ViT-L-14')
+
+     return model, preprocess, tokenizer
+
+
+
+ def encode_image(model, preprocess, image):
+     """
+     Encodes an image into a normalized feature embedding using the specified model and preprocessing function.
+
+     :param model: A model object that provides an `encode_image` method.
+     :type model: torch.nn.Module
+     :param preprocess: A preprocessing function that transforms the input image into a tensor
+         suitable for the model. Typically something returning a PyTorch tensor.
+     :type preprocess: callable
+     :param image: The input image (PIL Image, NumPy array, or other format supported by `preprocess`).
+     :type image: PIL.Image.Image or numpy.ndarray
+     :return: A single normalized image embedding as a PyTorch tensor of shape (1, embedding_dim).
+     :rtype: torch.Tensor
+     """
+     # Preprocess the image, then stack to create a batch of size 1
+     image_input = torch.stack([preprocess(image)])
+
+     # Generate the image features without gradient tracking
+     with torch.no_grad():
+         image_features = model.encode_image(image_input)
+
+     # Normalize embeddings across the feature dimension (L2 normalization)
+     image_embeddings = F.normalize(image_features, p=2, dim=-1)
+
+     return image_embeddings
+
+
+
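The L2 normalization applied by `F.normalize(..., p=2, dim=-1)` can be sketched in NumPy with a toy vector (not real model output):

```python
import numpy as np

# A toy 1 x 4 "embedding"
features = np.array([[3.0, 0.0, 4.0, 0.0]])

# L2-normalize along the last dimension, as F.normalize(features, p=2, dim=-1) does
embeddings = features / np.linalg.norm(features, axis=-1, keepdims=True)

print(np.linalg.norm(embeddings, axis=-1))  # [1.]
```

After normalization every embedding has unit length, which is what makes the dot products in `retrieve_st_by_image` act as cosine similarities.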
+ def encode_image_patches(model, preprocess, data_dir, img_list):
+     """
+     Encodes multiple image patches into normalized feature embeddings using a specified model and preprocess function.
+
+     :param model: A model object that provides an `encode_image` method.
+     :type model: torch.nn.Module
+     :param preprocess: A preprocessing function that transforms the input image into a tensor
+         suitable for the model. Typically something returning a PyTorch tensor.
+     :type preprocess: callable
+     :param data_dir: The base directory containing image data.
+     :type data_dir: str
+     :param img_list: A list of image filenames (strings). Each filename corresponds to a patch image
+         stored in `data_dir/demo_data/patch/`.
+     :type img_list: list[str]
+     :return: A PyTorch tensor of shape (N, 1, embedding_dim), containing the normalized embeddings
+         for each image in `img_list`.
+     :rtype: torch.Tensor
+     """
+
+     # Prepare a list to hold each image's feature embedding
+     image_embeddings = []
+
+     # Loop through each image name in the provided list
+     for img_name in img_list:
+         # Build the path to the patch image and open it
+         image_path = os.path.join(data_dir, 'demo_data', 'patch', img_name)
+         image = Image.open(image_path)
+
+         # Encode the image using the model & preprocess; returns shape (1, embedding_dim)
+         image_features = encode_image(model, preprocess, image)
+
+         # Accumulate the feature embeddings in the list
+         image_embeddings.append(image_features)
+
+     # Stack the list of (1, embedding_dim) tensors into a single (N, 1, embedding_dim) tensor
+     image_embeddings = torch.stack(image_embeddings)
+
+     # Normalize all embeddings across the feature dimension (L2 normalization)
+     image_embeddings = F.normalize(image_embeddings, p=2, dim=-1)
+
+     return image_embeddings
+
+
+
+ def encode_text(model, tokenizer, text):
+     """
+     Encodes text into a normalized feature embedding using a specified model and tokenizer.
+
+     :param model: A model object that provides an `encode_text` method.
+     :type model: torch.nn.Module
+     :param tokenizer: A tokenizer function that converts the input text into a format suitable for `model.encode_text`.
+         Typically returns token IDs, attention masks, etc. as a torch.Tensor or similar structure.
+     :type tokenizer: callable
+     :param text: The input text (string or list of strings) to be encoded.
+     :type text: str or list[str]
+     :return: A PyTorch tensor of shape (batch_size, embedding_dim) containing the L2-normalized text embeddings.
+     :rtype: torch.Tensor
+     """
+
+     # Convert text to the appropriate tokenized representation
+     text_input = tokenizer(text)
+
+     # Run the model in no-grad mode (not tracking gradients, saving memory and compute)
+     with torch.no_grad():
+         text_features = model.encode_text(text_input)
+
+     # Normalize embeddings to unit length
+     text_embeddings = F.normalize(text_features, p=2, dim=-1)
+
+     return text_embeddings
+
+
+
+ def encode_text_df(model, tokenizer, df, col_name):
+     """
+     Encodes text from a specified column in a pandas DataFrame using the given model and tokenizer,
+     returning a PyTorch tensor of normalized text embeddings.
+
+     :param model: A model object that provides an `encode_text` method.
+     :type model: torch.nn.Module
+     :param tokenizer: A tokenizer function that converts the input text into a format suitable for `model.encode_text`.
+     :type tokenizer: callable
+     :param df: A pandas DataFrame from which text will be extracted.
+     :type df: pandas.DataFrame
+     :param col_name: The name of the column in `df` that contains the text to be encoded.
+     :type col_name: str
+     :return: A PyTorch tensor containing the L2-normalized text embeddings,
+         where the shape is (number_of_rows, embedding_dim).
+     :rtype: torch.Tensor
+     """
+
+     # Prepare a list to hold each row's text embedding
+     text_embeddings = []
+
+     # Loop through each index in the DataFrame
+     for idx in df.index:
+         # Retrieve text from the specified column for the current row
+         text = df.loc[idx, col_name]
+
+         # Encode the text using the provided model and tokenizer
+         text_features = encode_text(model, tokenizer, text)
+
+         # Accumulate the embedding tensor
+         text_embeddings.append(text_features)
+
+     # Concatenate the list of (1, embedding_dim) tensors into one (N, embedding_dim) tensor
+     text_embeddings = torch.cat(text_embeddings)
+
+     # Normalize embeddings to unit length across the feature dimension
+     text_embeddings = F.normalize(text_embeddings, p=2, dim=-1)
+
+     return text_embeddings
+
+
+
+ def get_pca_by_fit(tar_features, src_features):
+     """
+     Applies PCA to target features and transforms both target and source features using the fitted PCA model.
+     Combines the PCA-transformed features from both target and source datasets and returns the combined data
+     along with batch labels indicating the origin of each sample.
+
+     :param tar_features: Numpy array of target features (samples by features).
+     :param src_features: Numpy array of source features (samples by features).
+     :return:
+         - pca_comb_features: A numpy array containing PCA-transformed target and source features combined.
+         - pca_comb_features_batch: A numpy array of batch labels indicating which samples are from target (0) and source (1).
+     """
+
+     pca = PCA(n_components=3)
+
+     # Fit the PCA model on the target features (transposed so PCA treats columns as samples)
+     pca_fit_tar = pca.fit(tar_features.T)
+
+     # Transform the target and source features using the fitted PCA model
+     pca_tar = pca_fit_tar.transform(tar_features.T)  # Transform target features
+     pca_src = pca_fit_tar.transform(src_features.T)  # Transform source features using the same PCA fit
+
+     # Combine the PCA-transformed target and source features
+     pca_comb_features = np.concatenate((pca_tar, pca_src))
+
+     # Create a batch label array: 0 for target features, 1 for source features
+     pca_comb_features_batch = np.array([0] * len(pca_tar) + [1] * len(pca_src))
+
+     return pca_comb_features, pca_comb_features_batch
+
+
+
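Fitting PCA on one matrix and reusing the same fit to project a second keeps both sets in a shared low-dimensional basis, which is the point of `get_pca_by_fit`. A minimal sketch with random matrices (toy shapes, not real features; the two matrices must share their row count, since rows become PCA features after the transpose):

```python
import numpy as np
from sklearn.decomposition import PCA

rng = np.random.default_rng(0)
tar = rng.normal(size=(10, 6))  # toy "target" matrix
src = rng.normal(size=(10, 4))  # toy "source" matrix (same number of rows)

pca = PCA(n_components=3)
fit = pca.fit(tar.T)            # fit on the target columns only

# Project both sets with the target-derived fit, then label their origin
comb = np.concatenate((fit.transform(tar.T), fit.transform(src.T)))
batch = np.array([0] * tar.shape[1] + [1] * src.shape[1])

print(comb.shape)  # (10, 3)
```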
+ def cap_quantile(weight, cap_max=None, cap_min=None):
+     """
+     Caps the values in the 'weight' array based on the specified quantile thresholds for maximum and minimum values.
+     If a quantile threshold is provided, values above or below it are replaced with the corresponding quantile value.
+
+     :param weight: Numpy array of weights to be capped.
+     :param cap_max: Quantile threshold for the maximum cap. Values above this quantile will be capped.
+         If None, no maximum capping will be applied.
+     :param cap_min: Quantile threshold for the minimum cap. Values below this quantile will be capped.
+         If None, no minimum capping will be applied.
+     :return: Numpy array with the values capped at the specified quantiles.
+     """
+
+     # If a maximum cap is specified, cap values at the cap_max quantile
+     if cap_max is not None:
+         cap_max = np.quantile(weight, cap_max)  # Get the value at the cap_max quantile
+         weight = np.minimum(weight, cap_max)    # Cap values to not exceed the maximum
+
+     # If a minimum cap is specified, cap values at the cap_min quantile
+     if cap_min is not None:
+         cap_min = np.quantile(weight, cap_min)  # Get the value at the cap_min quantile
+         weight = np.maximum(weight, cap_min)    # Cap values to not go below the minimum
+
+     return weight
+
+
+
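Quantile capping (winsorizing) as done by `cap_quantile` can be checked on a small array; the data and the 0.9/0.1 quantiles below are chosen for illustration:

```python
import numpy as np

weight = np.arange(11, dtype=float)  # 0.0 .. 10.0

# Cap at the 90th and 10th percentile values (9.0 and 1.0 for this array)
hi = np.quantile(weight, 0.9)
lo = np.quantile(weight, 0.1)
capped = np.maximum(np.minimum(weight, hi), lo)

print(capped.min(), capped.max())  # 1.0 9.0
```

Everything strictly between the two quantile values is left untouched; only the tails are clamped.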
+ def read_polygons(file_path, slide_id):
+     """
+     Reads polygon data from a JSON file for a specific slide ID, extracting coordinates, colors, and thickness.
+
+     :param file_path: Path to the JSON file containing polygon configurations.
+     :param slide_id: Identifier for the specific slide whose polygon data is to be extracted.
+     :return:
+         - polygons: A list of numpy arrays, where each array contains the coordinates of a polygon.
+         - polygon_colors: A list of color values corresponding to each polygon.
+         - polygon_thickness: A list of thickness values for each polygon's border.
+     """
+
+     # Open the JSON file and load the polygon configurations into a Python dictionary
+     with open(file_path, 'r') as f:
+         polygons_configs = json.load(f)
+
+     # Check if the given slide_id exists in the polygon configurations
+     if slide_id not in polygons_configs:
+         return None, None, None  # If slide_id is not found, return None for all outputs
+
+     # Extract the polygon coordinates, colors, and thicknesses for the given slide_id
+     polygons = [np.array(poly['coords']) for poly in polygons_configs[slide_id]]  # Convert coordinates to numpy arrays
+     polygon_colors = [poly['color'] for poly in polygons_configs[slide_id]]  # Extract the color for each polygon
+     polygon_thickness = [poly['thickness'] for poly in polygons_configs[slide_id]]  # Extract the thickness for each polygon
+
+     # Return the polygons, their colors, and their thicknesses
+     return polygons, polygon_colors, polygon_thickness
+
+
src/requirements.txt ADDED
@@ -0,0 +1,14 @@
+ anndata==0.10.9
+ matplotlib==3.9.2
+ numpy==1.25.0
+ pandas==2.2.3
+ opencv-python==4.10.0.84
+ pycpd==2.0.0
+ torch==2.3.1
+ tangram-sc==1.0.4
+ tqdm==4.66.5
+ torchvision==0.18.1
+ open_clip_torch==2.26.1
+ pillow==10.4.0
+ ipykernel==6.29.5
+
src/setup.py ADDED
@@ -0,0 +1,32 @@
+ import setuptools
+
+
+ setuptools.setup(
+     name="loki",  # The name of your package on PyPI
+     version="0.0.1",  # Choose your initial release version
+     author="Weiqing Chen",
+     author_email="wec4005@med.cornell.edu",
+     description="The Loki platform offers 5 core functions: tissue alignment, tissue annotation, cell type decomposition, image-transcriptomics retrieval, and ST gene expression prediction",
+     packages=setuptools.find_packages(),  # Finds the 'loki' folder automatically
+     classifiers=[
+         "Programming Language :: Python :: 3",
+         "License :: OSI Approved :: BSD License",
+         "Operating System :: OS Independent",
+     ],
+     python_requires='>=3.9',  # or the minimum version you support
+     install_requires=[
+         "anndata==0.10.9",
+         "matplotlib==3.9.2",
+         "numpy==1.25.0",
+         "pandas==2.2.3",
+         "opencv-python==4.10.0.84",
+         "pycpd==2.0.0",
+         "torch==2.3.1",
+         "tangram-sc==1.0.4",
+         "tqdm==4.66.5",
+         "torchvision==0.18.1",
+         "open_clip_torch==2.26.1",
+         "pillow==10.4.0",
+         "ipykernel==6.29.5",
+     ],
+ )