osakemon committed
Commit b118ecd · verified · 1 Parent(s): ed1ae7f

update source code

src/build/lib/loki/decompose.py CHANGED
@@ -77,11 +77,11 @@ def cell_type_decompose(sc_ad, st_ad, cell_type_col='cell_type', NMS_mode=False,
77
 
78
  :param sc_ad: AnnData object containing single-cell metadata.
79
  :param st_ad: AnnData object containing spatial data (ST or image) metadata.
80
- :param density_prior: A numpy array providing prior information about cell densities in spatial spots.
81
  :param cell_type_col: The column name in `sc_ad.obs` that contains cell type annotations. Default is 'cell_type'.
82
- :param target_count: If True, sums up the total number of cells in `st_ad.obs['cell_num']`. Can also be set to a specific value.
83
- :param pca_mode: Boolean flag to apply PCA for dimensionality reduction. Default is True.
84
- :param n_components: Number of PCA components to use if `pca_mode` is True. Default is 300.
 
85
  :return: The spatial AnnData object with projected cell type annotations.
86
  """
87
 
 
77
 
78
  :param sc_ad: AnnData object containing single-cell metadata.
79
  :param st_ad: AnnData object containing spatial data (ST or image) metadata.
 
80
  :param cell_type_col: The column name in `sc_ad.obs` that contains cell type annotations. Default is 'cell_type'.
81
+ :param NMS_mode: Boolean flag to apply Non-Maximum Suppression (NMS) mode. Default is False.
82
+ :param major_types: Major cell types used for NMS mode. Default is None.
83
+ :param min_percentile: The lower percentile used for clipping (defaults to 5).
84
+ :param max_percentile: The upper percentile used for clipping (defaults to 95).
85
  :return: The spatial AnnData object with projected cell type annotations.
86
  """
87
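
For context on the revised docstring, here is a minimal usage sketch of `cell_type_decompose` with the newly documented NMS parameters. It is not part of the commit: the `.h5ad` paths and the `major_types` list are placeholders, and the call otherwise relies on the defaults described above.

    # Hypothetical usage sketch; assumes loki is installed and the .h5ad files exist.
    import anndata as ad
    from loki.decompose import cell_type_decompose

    sc_ad = ad.read_h5ad("sc_reference.h5ad")   # single-cell AnnData with .obs['cell_type']
    st_ad = ad.read_h5ad("st_sample.h5ad")      # spatial (ST or image) AnnData

    # Default mode: project cell-type annotations onto the spatial spots.
    st_ad = cell_type_decompose(sc_ad, st_ad, cell_type_col='cell_type')

    # NMS mode with the parameters documented in this diff (values are illustrative).
    st_ad = cell_type_decompose(
        sc_ad, st_ad,
        cell_type_col='cell_type',
        NMS_mode=True,
        major_types=['T cells', 'B cells', 'Tumor cells'],
        min_percentile=5,
        max_percentile=95,
    )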
 
src/build/lib/loki/plot.py CHANGED
@@ -8,107 +8,102 @@ import numpy as np
8
  from tqdm import tqdm
9
 
10
 
11
-
12
- def plot_alignment(ad_tar_coor, ad_src_coor, homo_coor, pca_hex_comb, tar_features, shift=300, s=0.8, boundary_line=True):
 
 
 
 
 
 
 
 
13
  """
14
- Plots the target coordinates and alignment of source coordinates.
15
-
16
- :param ad_tar_coor: Numpy array of target coordinates to be plotted in the first subplot.
17
- :param ad_src_coor: Numpy array of source coordinates to be plotted in the second subplot.
18
- :param homo_coor: Numpy array of alignment of source coordinates to be plotted in the third subplot.
19
- :param pca_hex_comb: Color values (e.g., PCA or hex values) for plotting the coordinates.
20
- :param tar_features: Feature matrix for the target, used to split color values between target and source data.
21
- :param shift: Value used to adjust the plot limits around the coordinates for better visualization. Default is 300.
22
- :param s: Marker size for the scatter plot points. Default is 0.8.
23
- :param boundary_line: Boolean indicating whether to draw boundary lines (horizontal and vertical lines). Default is True.
24
- :return: Displays the alignment plot of target, source, and alignment of source coordinates.
25
  """
26
-
27
- # Create a figure with three subplots, adjusting size and resolution
28
- plt.figure(figsize=(10, 3), dpi=300)
29
-
30
- # First subplot: Plot target coordinates
31
- plt.subplot(1, 3, 1)
32
- plt.scatter(ad_tar_coor[:, 0], ad_tar_coor[:, 1], marker='o', s=s, c=pca_hex_comb[:len(tar_features.T)])
33
- # Set plot limits based on the minimum and maximum target coordinates, with extra padding from 'shift'
34
- plt.xlim([ad_tar_coor.min() - shift, ad_tar_coor.max() + shift])
35
- plt.ylim([ad_tar_coor.min() - shift, ad_tar_coor.max() + shift])
36
-
37
- # Second subplot: Plot source coordinates
38
- plt.subplot(1, 3, 2)
39
- plt.scatter(ad_src_coor[:, 0], ad_src_coor[:, 1], marker='o', s=s, c=pca_hex_comb[len(tar_features.T):])
40
- # Ensure consistent plot limits across subplots by using the same limits as the target coordinates
41
- plt.xlim([ad_tar_coor.min() - shift, ad_tar_coor.max() + shift])
42
- plt.ylim([ad_tar_coor.min() - shift, ad_tar_coor.max() + shift])
43
-
44
- # Third subplot: Plot alignment of source coordinates
45
- plt.subplot(1, 3, 3)
46
- plt.scatter(homo_coor[:, 0], homo_coor[:, 1], marker='o', s=s, c=pca_hex_comb[len(tar_features.T):])
47
- # Maintain the same plot limits across all subplots for a uniform comparison
48
- plt.xlim([ad_tar_coor.min() - shift, ad_tar_coor.max() + shift])
49
- plt.ylim([ad_tar_coor.min() - shift, ad_tar_coor.max() + shift])
50
-
51
- # Optionally draw boundary lines at the minimum x and y values of the target coordinates
52
- if boundary_line:
53
- plt.axvline(x=ad_tar_coor[:, 0].min(), color='black') # Vertical boundary line at the minimum x of target coordinates
54
- plt.axhline(y=ad_tar_coor[:, 1].min(), color='black') # Horizontal boundary line at the minimum y of target coordinates
55
-
56
- # Remove the axis labels and ticks from all subplots for a cleaner appearance
57
- plt.axis('off')
58
-
59
- # Display the plot
60
  plt.show()
61
 
62
 
63
 
64
- def plot_alignment_with_img(ad_tar_coor, ad_src_coor, homo_coor, tar_img, src_img, aligned_image, pca_hex_comb, tar_features):
 
 
 
 
 
 
 
 
 
 
65
  """
66
- Plots the target coordinates and alignment of source coordinates with their respective images in the background.
67
-
68
- :param ad_tar_coor: Numpy array of target coordinates to be plotted in the first and third subplots.
69
- :param ad_src_coor: Numpy array of source coordinates to be plotted in the second subplot.
70
- :param homo_coor: Numpy array of alignment of source coordinates to be plotted in the third subplot.
71
- :param tar_img: Image associated with the target coordinates, used as the background in the first subplot.
72
- :param src_img: Image associated with the source coordinates, used as the background in the second subplot.
73
- :param aligned_image: Image associated with the aligned coordinates, used as the background in the third subplot.
74
- :param pca_hex_comb: Color values (e.g., PCA or hex values) for plotting the coordinates.
75
- :param tar_features: Feature matrix for the target, used to split color values between target and source data.
76
- :return: Displays the alignment plot of target, source, and alignment of source coordinates with their associated images.
77
  """
78
-
79
- # Create a figure with three subplots and set the size and resolution
80
- plt.figure(figsize=(10, 8), dpi=150)
81
-
82
- # First subplot: Plot target coordinates with the target image as the background
83
- plt.subplot(1, 3, 1)
84
- # Scatter plot for the target coordinates with transparency and small marker size
85
- plt.scatter(ad_tar_coor[:, 0], ad_tar_coor[:, 1], marker='o', alpha=0.8, s=1, c=pca_hex_comb[:len(tar_features.T)])
86
- # Overlay the target image with some transparency (alpha = 0.3)
87
- plt.imshow(tar_img, origin='lower', alpha=0.3)
88
-
89
- # Second subplot: Plot source coordinates with the source image as the background
90
- plt.subplot(1, 3, 2)
91
- # Scatter plot for the source coordinates with transparency and small marker size
92
- plt.scatter(ad_src_coor[:, 0], ad_src_coor[:, 1], marker='o', alpha=0.8, s=1, c=pca_hex_comb[len(tar_features.T):])
93
- # Overlay the source image with some transparency (alpha = 0.3)
94
- plt.imshow(src_img, origin='lower', alpha=0.3)
95
-
96
- # Third subplot: Plot both target and alignment of source coordinates with the aligned image as the background
97
- plt.subplot(1, 3, 3)
98
- # Scatter plot for the target coordinates with lower opacity (alpha = 0.2)
99
- plt.scatter(ad_tar_coor[:, 0], ad_tar_coor[:, 1], marker='o', alpha=0.2, s=1, c=pca_hex_comb[:len(tar_features.T)])
100
- # Scatter plot for the homologous coordinates with a '+' marker and the same color mapping
101
- plt.scatter(homo_coor[:, 0], homo_coor[:, 1], marker='+', s=1, c=pca_hex_comb[len(tar_features.T):])
102
- # Overlay the aligned image with some transparency (alpha = 0.3)
103
- plt.imshow(aligned_image, origin='lower', alpha=0.3)
104
-
105
- # Turn off the axis for all subplots to give a cleaner visual output
106
- plt.axis('off')
107
-
108
- # Display the plots
109
  plt.show()
110
 
111
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
 
113
  def draw_polygon(image, polygon, color='k', thickness=2):
114
  """
@@ -228,6 +223,9 @@ def plot_heatmap(
228
  coor,
229
  similairty,
230
  image_path=None,
 
 
 
231
  patch_size=(256, 256),
232
  save_path=None,
233
  downsize=32,
@@ -236,9 +234,6 @@ def plot_heatmap(
236
  boxes=None,
237
  box_color='k',
238
  box_thickness=2,
239
- polygons=None,
240
- polygons_color='k',
241
- polygons_thickness=2,
242
  image_alpha=0.5
243
  ):
244
  """
@@ -316,7 +311,7 @@ def plot_heatmap(
316
 
317
 
318
 
319
- def show_images_side_by_side(image1, image2, title1=None, title2=None):
320
  """
321
  Displays two images side by side in a single figure.
322
 
@@ -328,7 +323,7 @@ def show_images_side_by_side(image1, image2, title1=None, title2=None):
328
  """
329
 
330
  # Create a figure with 2 subplots (1 row, 2 columns), and set the figure size
331
- fig, ax = plt.subplots(1, 2, figsize=(15,8))
332
 
333
  # Display the first image on the first subplot
334
  ax[0].imshow(image1)
@@ -364,7 +359,7 @@ def plot_img_with_annotation(fullres_img, roi_polygon, linewidth, xlim, ylim):
364
  """
365
 
366
  # Create a new figure with a fixed size for displaying the image and annotations
367
- plt.figure(figsize=(10, 10))
368
 
369
  # Display the full-resolution image
370
  plt.imshow(fullres_img)
@@ -403,7 +398,7 @@ def plot_annotation_heatmap(st_ad, roi_polygon, s, linewidth, xlim, ylim):
403
  """
404
 
405
  # Create a new figure with a fixed size for displaying the heatmap and annotations
406
- plt.figure(figsize=(10, 10))
407
 
408
  # Scatter plot for the spatial transcriptomics data.
409
  # The 'spatial' coordinates are plotted with color intensity based on 'bulk_simi' values.
 
8
  from tqdm import tqdm
9
 
10
 
11
+ def plot_alignment(
12
+ ad_tar_coor: np.ndarray,
13
+ ad_src_coor: np.ndarray,
14
+ homo_coor: np.ndarray,
15
+ pca_hex_comb: np.ndarray,
16
+ tar_features: np.ndarray,
17
+ shift: float = 300,
18
+ s: float = 0.8,
19
+ boundary_line: bool = True
20
+ ) -> None:
21
  """
22
+ Plot target, source, and aligned source coordinates side by side with shared axis limits and subplot titles.
 
 
 
 
 
 
 
 
 
 
23
  """
24
+ # Determine common limits
25
+ coords = np.vstack([ad_tar_coor, ad_src_coor, homo_coor])
26
+ x_min, x_max = coords[:,0].min() - shift, coords[:,0].max() + shift
27
+ y_min, y_max = coords[:,1].min() - shift, coords[:,1].max() + shift
28
+
29
+ fig, axes = plt.subplots(1, 3, figsize=(10, 3), dpi=150)
30
+ titles = ["Target ST", "Source ST", "Aligned Source ST"]
31
+ splits = [len(ad_tar_coor), len(ad_tar_coor)+len(ad_src_coor)]
32
+
33
+ for ax, title, data_slice in zip(
34
+ axes,
35
+ titles,
36
+ [(ad_tar_coor, pca_hex_comb[:splits[0]]),
37
+ (ad_src_coor, pca_hex_comb[splits[0]:splits[1]]),
38
+ (homo_coor, pca_hex_comb[splits[0]:splits[1]])]
39
+ ):
40
+ coords_arr, colors = data_slice
41
+ ax.scatter(coords_arr[:,0], coords_arr[:,1], s=s, c=colors, marker='o')
42
+ ax.set_xlim(x_min, x_max)
43
+ ax.set_ylim(y_min, y_max)
44
+ ax.set_aspect('equal')
45
+ if boundary_line:
46
+ ax.axvline(x=ad_tar_coor[:,0].min(), color='black', linewidth=1)
47
+ ax.axhline(y=ad_tar_coor[:,1].min(), color='black', linewidth=1)
48
+ ax.set_title(title)
49
+ ax.axis('off')
50
+ plt.tight_layout()
 
 
 
 
 
 
 
51
  plt.show()
52
 
53
 
54
 
55
+ def plot_alignment_with_img(
56
+ ad_tar_coor: np.ndarray,
57
+ ad_src_coor: np.ndarray,
58
+ homo_coor: np.ndarray,
59
+ tar_img,
60
+ src_img,
61
+ aligned_image,
62
+ pca_hex_comb: np.ndarray,
63
+ tar_features: np.ndarray,
64
+ s: float = 1.0
65
+ ) -> None:
66
  """
67
+ Plot target, source, and aligned source coordinates over their background images, with subplot titles.
 
 
 
 
 
 
 
 
 
 
68
  """
69
+ fig, axes = plt.subplots(1, 3, figsize=(15, 5), dpi=150)
70
+ titles = ["Target + Image", "Source + Image", "Aligned + Image"]
71
+ splits = [len(tar_features.T), len(tar_features.T) * 2]
72
+
73
+ # Data slices for each subplot
74
+ data_slices = [
75
+ (ad_tar_coor, pca_hex_comb[:splits[0]], tar_img),
76
+ (ad_src_coor, pca_hex_comb[splits[0]:splits[1]], src_img),
77
+ (np.vstack([ad_tar_coor, homo_coor]),
78
+ np.concatenate([pca_hex_comb[:splits[0]], pca_hex_comb[splits[0]:splits[1]]]),
79
+ aligned_image)
80
+ ]
81
+
82
+ for ax, title, (coords_arr, colors, img) in zip(axes, titles, data_slices):
83
+ ax.imshow(img, origin='lower', alpha=0.3)
84
+ ax.scatter(coords_arr[:,0], coords_arr[:,1], s=s, c=colors, marker='o')
85
+ ax.set_aspect('equal')
86
+ ax.set_title(title)
87
+ ax.axis('off')
88
+
89
+ plt.tight_layout()
 
 
 
 
 
 
 
 
 
 
90
  plt.show()
91
 
92
 
93
+ def show_image(img, title: str = "Aligned Source Image", origin: str = "lower", cmap=None):
94
+ """
95
+ Display a single image with no axes and a title.
96
+
97
+ :param img: The image to display (NumPy array, PIL Image, etc.).
98
+ :param title: Title to show above the image.
99
+ :param origin: Origin parameter passed to plt.imshow (e.g. 'lower' or 'upper').
100
+ :param cmap: Optional colormap for grayscale or other single-channel data.
101
+ """
102
+ plt.imshow(img, origin=origin, cmap=cmap)
103
+ plt.title(title)
104
+ plt.axis('off')
105
+ plt.show()
106
+
107
 
108
  def draw_polygon(image, polygon, color='k', thickness=2):
109
  """
 
223
  coor,
224
  similairty,
225
  image_path=None,
226
+ polygons=None,
227
+ polygons_color='k',
228
+ polygons_thickness=2,
229
  patch_size=(256, 256),
230
  save_path=None,
231
  downsize=32,
 
234
  boxes=None,
235
  box_color='k',
236
  box_thickness=2,
 
 
 
237
  image_alpha=0.5
238
  ):
239
  """
 
311
 
312
 
313
 
314
+ def show_images_side_by_side(image1, image2, title1='Annotated H&E Image', title2='Similarity Heatmap'):
315
  """
316
  Displays two images side by side in a single figure.
317
 
 
323
  """
324
 
325
  # Create a figure with 2 subplots (1 row, 2 columns), and set the figure size
326
+ fig, ax = plt.subplots(1, 2, figsize=(8,6), dpi=150)
327
 
328
  # Display the first image on the first subplot
329
  ax[0].imshow(image1)
 
359
  """
360
 
361
  # Create a new figure with a fixed size for displaying the image and annotations
362
+ plt.figure(figsize=(12, 12), dpi=150)
363
 
364
  # Display the full-resolution image
365
  plt.imshow(fullres_img)
 
398
  """
399
 
400
  # Create a new figure with a fixed size for displaying the heatmap and annotations
401
+ plt.figure(figsize=(12, 12), dpi=150)
402
 
403
  # Scatter plot for the spatial transcriptomics data.
404
  # The 'spatial' coordinates are plotted with color intensity based on 'bulk_simi' values.
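
As a rough illustration of the rewritten `plot_alignment` signature, a hedged sketch with synthetic data follows. The coordinates, colors, and the zero `tar_features` array are placeholders; in the real pipeline they come from the alignment step and a PCA-to-hex color mapping that this commit does not show.

    # Hypothetical usage sketch with synthetic coordinates.
    import numpy as np
    from loki.plot import plot_alignment

    rng = np.random.default_rng(0)
    tar = rng.uniform(0, 1000, size=(200, 2))      # target spot coordinates
    src = rng.uniform(0, 1000, size=(150, 2))      # source spot coordinates
    aligned = src + np.array([50.0, -30.0])        # stand-in for aligned source coordinates

    # One hex color per spot: target spots first, then source spots.
    colors = np.array(['#1f77b4'] * len(tar) + ['#ff7f0e'] * len(src))

    # tar_features stays in the signature; a (features x target spots) array is enough here.
    tar_features = np.zeros((10, len(tar)))

    plot_alignment(tar, src, aligned, colors, tar_features, shift=100, s=2, boundary_line=True)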
src/build/lib/loki/utils.py CHANGED
@@ -11,175 +11,107 @@ from open_clip import create_model_from_pretrained, get_tokenizer
11
 
12
 
13
 
14
- def load_model(model_path, device):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  """
16
- Loads a pretrained CoCa (CLIP-like) model, along with its preprocessing function and tokenizer,
17
- using the specified model checkpoint.
18
-
19
- :param model_path: File path or URL to the pretrained model checkpoint. This is passed to
20
- `create_model_from_pretrained` as the `pretrained` argument.
21
- :type model_path: str
22
- :param device: The device on which to load the model (e.g., 'cpu' or 'cuda').
23
- :type device: str or torch.device
24
- :return: A tuple `(model, preprocess, tokenizer)` where:
25
- - model: The loaded CoCa model.
26
- - preprocess: A function or transform that preprocesses input data for the model.
27
- - tokenizer: A tokenizer appropriate for textual input to the model.
28
- :rtype: (nn.Module, callable, callable)
29
  """
30
- # Create the model and its preprocessing transform from the specified checkpoint
31
  model, preprocess = create_model_from_pretrained(
32
  "coca_ViT-L-14", device=device, pretrained=model_path
33
  )
34
-
35
- # Retrieve a tokenizer compatible with the "coca_ViT-L-14" architecture
36
- tokenizer = get_tokenizer('coca_ViT-L-14')
37
-
38
  return model, preprocess, tokenizer
39
 
 
40
 
41
-
42
- def encode_image(model, preprocess, image):
43
- """
44
- Encodes an image into a normalized feature embedding using the specified model and preprocessing function.
45
-
46
- :param model: A model object that provides an `encode_image` method (e.g., a CLIP or CoCa model).
47
- :type model: torch.nn.Module
48
- :param preprocess: A preprocessing function that transforms the input image into a tensor
49
- suitable for the model. Typically something returning a PyTorch tensor.
50
- :type preprocess: callable
51
- :param image: The input image (PIL Image, NumPy array, or other format supported by `preprocess`).
52
- :type image: PIL.Image.Image or numpy.ndarray
53
- :return: A single normalized image embedding as a PyTorch tensor of shape (1, embedding_dim).
54
- :rtype: torch.Tensor
55
  """
56
- # Preprocess the image, then stack to create a batch of size 1
57
- image_input = torch.stack([preprocess(image)])
58
-
59
- # Generate the image features without gradient tracking
60
- with torch.no_grad():
61
- image_features = model.encode_image(image_input)
62
-
63
- # Normalize embeddings across the feature dimension (L2 normalization)
64
- image_embeddings = F.normalize(image_features, p=2, dim=-1)
65
-
66
- return image_embeddings
67
-
68
-
69
-
70
- def encode_image_patches(model, preprocess, data_dir, img_list):
71
  """
72
- Encodes multiple image patches into normalized feature embeddings using a specified model and preprocess function.
 
 
73
 
74
- :param model: A model object that provides an `encode_image` method (e.g., a CLIP or CoCa model).
75
- :type model: torch.nn.Module
76
- :param preprocess: A preprocessing function that transforms the input image into a tensor
77
- suitable for the model. Typically something returning a PyTorch tensor.
78
- :type preprocess: callable
79
- :param data_dir: The base directory containing image data.
80
- :type data_dir: str
81
- :param img_list: A list of image filenames (strings). Each filename corresponds to a patch image
82
- stored in `data_dir/demo_data/patch/`.
83
- :type img_list: list[str]
84
- :return: A PyTorch tensor of shape (N, 1, embedding_dim), containing the normalized embeddings
85
- for each image in `img_list`.
86
- :rtype: torch.Tensor
87
- """
88
 
89
- # Prepare a list to hold each image's feature embedding
90
- image_embeddings = []
91
 
92
- # Loop through each image name in the provided list
93
- for img_name in img_list:
94
- # Build the path to the patch image and open it
95
- image_path = os.path.join(data_dir, 'demo_data', 'patch', img_name)
96
- image = Image.open(image_path)
97
 
98
- # Encode the image using the model & preprocess; returns shape (1, embedding_dim)
99
- image_features = encode_image(model, preprocess, image)
100
 
101
- # Accumulate the feature embeddings in the list
102
- image_embeddings.append(image_features)
103
 
104
- # Convert the list of embeddings to a NumPy array, then to a PyTorch tensor
105
- # Resulting shape will be (N, 1, embedding_dim)
106
- image_embeddings = torch.from_numpy(np.array(image_embeddings))
107
 
108
- # Normalize all embeddings across the feature dimension (L2 normalization)
109
- image_embeddings = F.normalize(image_embeddings, p=2, dim=-1)
110
 
111
- return image_embeddings
112
 
113
 
 
114
 
115
- def encode_text(model, tokenizer, text):
 
 
 
 
 
116
  """
117
- Encodes text into a normalized feature embedding using a specified model and tokenizer.
118
-
119
- :param model: A model object that provides an `encode_text` method (e.g., a CLIP-like or CoCa model).
120
- :type model: torch.nn.Module
121
- :param tokenizer: A tokenizer function that converts the input text into a format suitable for `model.encode_text`.
122
- Typically returns token IDs, attention masks, etc. as a torch.Tensor or similar structure.
123
- :type tokenizer: callable
124
- :param text: The input text (string or list of strings) to be encoded.
125
- :type text: str or list[str]
126
- :return: A PyTorch tensor of shape (batch_size, embedding_dim) containing the L2-normalized text embeddings.
127
- :rtype: torch.Tensor
128
  """
129
-
130
- # Convert text to the appropriate tokenized representation
131
- text_input = tokenizer(text)
132
-
133
- # Run the model in no-grad mode (not tracking gradients, saving memory and compute)
134
  with torch.no_grad():
135
- text_features = model.encode_text(text_input)
136
-
137
- # Normalize embeddings to unit length
138
- text_embeddings = F.normalize(text_features, p=2, dim=-1)
139
 
140
- return text_embeddings
141
 
142
-
143
-
144
- def encode_text_df(model, tokenizer, df, col_name):
 
 
 
 
145
  """
146
- Encodes text from a specified column in a pandas DataFrame using the given model and tokenizer,
147
- returning a PyTorch tensor of normalized text embeddings.
148
-
149
- :param model: A model object that provides an `encode_text` method (e.g., a CLIP-like or CoCa model).
150
- :type model: torch.nn.Module
151
- :param tokenizer: A tokenizer function that converts the input text into a format suitable for `model.encode_text`.
152
- :type tokenizer: callable
153
- :param df: A pandas DataFrame from which text will be extracted.
154
- :type df: pandas.DataFrame
155
- :param col_name: The name of the column in `df` that contains the text to be encoded.
156
- :type col_name: str
157
- :return: A PyTorch tensor containing the L2-normalized text embeddings,
158
- where the shape is (number_of_rows, embedding_dim).
159
- :rtype: torch.Tensor
160
  """
 
 
161
 
162
- # Prepare a list to hold each row's text embedding
163
- text_embeddings = []
164
-
165
- # Loop through each index in the DataFrame
166
- for idx in df.index:
167
- # Retrieve text from the specified column for the current row
168
- text = df[df.index == idx][col_name][0]
169
-
170
- # Encode the text using the provided model and tokenizer
171
- text_features = encode_text(model, tokenizer, text)
172
-
173
- # Accumulate the embedding tensor
174
- text_embeddings.append(text_features)
175
-
176
- # Convert the list of embeddings (likely shape [N, embedding_dim]) into a NumPy array, then to a torch tensor
177
- text_embeddings = torch.from_numpy(np.array(text_embeddings))
178
-
179
- # Normalize embeddings to unit length across the feature dimension
180
- text_embeddings = F.normalize(text_embeddings, p=2, dim=-1)
181
-
182
- return text_embeddings
183
 
184
 
185
 
 
11
 
12
 
13
 
14
+ import os
15
+ from typing import List, Tuple, Union
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ import numpy as np
20
+ from PIL import Image
21
+ import pandas as pd
22
+
23
+ # --- Model loading --------------------------------------------------------
24
+
25
+ def load_model(
26
+ model_path: str,
27
+ device: Union[str, torch.device]
28
+ ) -> Tuple[torch.nn.Module, callable, callable]:
29
  """
30
+ Load the pretrained OmiCLIP (CoCa ViT-L-14) model, its image preprocessing transform, and its tokenizer.
 
 
 
 
 
 
 
 
 
 
 
 
31
  """
 
32
  model, preprocess = create_model_from_pretrained(
33
  "coca_ViT-L-14", device=device, pretrained=model_path
34
  )
35
+ tokenizer = get_tokenizer("coca_ViT-L-14")
36
+ model.to(device).eval()
 
 
37
  return model, preprocess, tokenizer
38
 
39
+ # --- Image encoding -------------------------------------------------------
40
 
41
+ def encode_images(
42
+ model: torch.nn.Module,
43
+ preprocess: callable,
44
+ image_paths: List[str],
45
+ device: Union[str, torch.device]
46
+ ) -> torch.Tensor:
 
 
 
 
 
 
 
 
47
  """
48
+ Batch-encode a list of image file paths into L2-normalized embeddings.
49
+ Returns a tensor of shape (N, D).
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  """
51
+ # Load & preprocess all images
52
+ imgs = [preprocess(Image.open(p)) for p in image_paths]
53
+ batch = torch.stack(imgs, dim=0).to(device) # (N, C, H, W)
54
 
55
+ with torch.no_grad():
56
+ feats = model.encode_image(batch) # (N, D)
57
+ return F.normalize(feats, p=2, dim=-1) # (N, D)
 
 
 
 
 
 
 
 
 
 
 
58
 
 
 
59
 
60
+ # # Loop through each image name in the provided list
61
+ # for img_name in img_list:
62
+ # # Build the path to the patch image and open it
63
+ # image_path = os.path.join(data_dir, 'demo_data', 'patch', img_name)
64
+ # image = Image.open(image_path)
65
 
66
+ # # Encode the image using the model & preprocess; returns shape (1, embedding_dim)
67
+ # image_features = encode_image(model, preprocess, image)
68
 
69
+ # # Accumulate the feature embeddings in the list
70
+ # image_embeddings.append(image_features)
71
 
72
+ # # Convert the list of embeddings to a NumPy array, then to a PyTorch tensor
73
+ # # Resulting shape will be (N, 1, embedding_dim)
74
+ # image_embeddings = torch.from_numpy(np.array(image_embeddings))
75
 
76
+ # # Normalize all embeddings across the feature dimension (L2 normalization)
77
+ # image_embeddings = F.normalize(image_embeddings, p=2, dim=-1)
78
 
79
+ # return image_embeddings
80
 
81
 
82
+ # --- Text encoding --------------------------------------------------------
83
 
84
+ def encode_texts(
85
+ model: torch.nn.Module,
86
+ tokenizer: callable,
87
+ texts: List[str],
88
+ device: Union[str, torch.device]
89
+ ) -> torch.Tensor:
90
  """
91
+ Batch-encode a list of strings into L2-normalized embeddings.
92
+ Returns a tensor of shape (N, D).
 
 
 
 
 
 
 
 
 
93
  """
94
+ # Tokenize the inputs (the open_clip tokenizer returns a tensor of token IDs)
95
+ text_inputs = tokenizer(texts)
96
+
 
 
97
  with torch.no_grad():
98
+ feats = model.encode_text(text_inputs) # (N, D)
99
+ return F.normalize(feats, p=2, dim=-1) # (N, D)
 
 
100
 
 
101
 
102
+ def encode_text_df(
103
+ model: torch.nn.Module,
104
+ tokenizer: callable,
105
+ df: pd.DataFrame,
106
+ col_name: str,
107
+ device: Union[str, torch.device]
108
+ ) -> torch.Tensor:
109
  """
110
+ Encodes an entire DataFrame column into (N, D) embeddings.
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  """
112
+ texts = df[col_name].astype(str).tolist()
113
+ return encode_texts(model, tokenizer, texts, device)
114
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
 
116
 
117
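
To make the new batch-oriented helpers concrete, here is a hedged usage sketch. The checkpoint path and patch filenames are placeholders; the sketch stays on CPU because the diffed encode_texts does not move the tokenized text to another device.

    # Hypothetical usage sketch; paths are placeholders and must exist locally.
    from loki.utils import load_model, encode_images, encode_texts

    device = "cpu"
    model, preprocess, tokenizer = load_model("checkpoint.pt", device)

    # Batch-encode image patches into (N, D) L2-normalized embeddings.
    img_emb = encode_images(model, preprocess, ["patch_0.png", "patch_1.png"], device)

    # Batch-encode text prompts into (M, D) L2-normalized embeddings.
    txt_emb = encode_texts(model, tokenizer, ["tumor region", "normal stroma"], device)

    # Cosine similarity between every patch and every prompt (embeddings are unit-norm).
    similarity = img_emb @ txt_emb.T   # shape (N, M)
    print(similarity)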
 
src/loki.egg-info/PKG-INFO CHANGED
@@ -1,13 +1,14 @@
1
  Metadata-Version: 2.1
2
  Name: loki
3
  Version: 0.0.1
4
- Summary: The Loki platform offers 5 core functions: tissue alignment, cell type decomposition, tissue annotation, image-transcriptomics retrieval, and ST gene expression prediction
5
  Author: Weiqing Chen
6
  Author-email: wec4005@med.cornell.edu
7
  Classifier: Programming Language :: Python :: 3
8
- Classifier: License :: OSI Approved :: MIT License
9
  Classifier: Operating System :: OS Independent
10
  Requires-Python: >=3.9
 
11
  Requires-Dist: anndata==0.10.9
12
  Requires-Dist: matplotlib==3.9.2
13
  Requires-Dist: numpy==1.25.0
 
1
  Metadata-Version: 2.1
2
  Name: loki
3
  Version: 0.0.1
4
+ Summary: The Loki platform offers 5 core functions: tissue alignment, tissue annotation, cell type decomposition, image-transcriptomics retrieval, and ST gene expression prediction
5
  Author: Weiqing Chen
6
  Author-email: wec4005@med.cornell.edu
7
  Classifier: Programming Language :: Python :: 3
8
+ Classifier: License :: BSD 3-Clause License
9
  Classifier: Operating System :: OS Independent
10
  Requires-Python: >=3.9
11
+ License-File: LICENSE
12
  Requires-Dist: anndata==0.10.9
13
  Requires-Dist: matplotlib==3.9.2
14
  Requires-Dist: numpy==1.25.0
src/loki.egg-info/SOURCES.txt CHANGED
@@ -1,4 +1,4 @@
1
- README.md
2
  setup.py
3
  loki/__init__.py
4
  loki/align.py
 
1
+ LICENSE
2
  setup.py
3
  loki/__init__.py
4
  loki/align.py
src/loki/__pycache__/plot.cpython-39.pyc CHANGED
Binary files a/src/loki/__pycache__/plot.cpython-39.pyc and b/src/loki/__pycache__/plot.cpython-39.pyc differ
 
src/loki/__pycache__/utils.cpython-39.pyc CHANGED
Binary files a/src/loki/__pycache__/utils.cpython-39.pyc and b/src/loki/__pycache__/utils.cpython-39.pyc differ
 
src/loki/plot.py CHANGED
@@ -8,107 +8,102 @@ import numpy as np
8
  from tqdm import tqdm
9
 
10
 
11
-
12
- def plot_alignment(ad_tar_coor, ad_src_coor, homo_coor, pca_hex_comb, tar_features, shift=300, s=0.8, boundary_line=True):
 
 
 
 
 
 
 
 
13
  """
14
- Plots the target coordinates and alignment of source coordinates.
15
-
16
- :param ad_tar_coor: Numpy array of target coordinates to be plotted in the first subplot.
17
- :param ad_src_coor: Numpy array of source coordinates to be plotted in the second subplot.
18
- :param homo_coor: Numpy array of alignment of source coordinates to be plotted in the third subplot.
19
- :param pca_hex_comb: Color values (e.g., PCA or hex values) for plotting the coordinates.
20
- :param tar_features: Feature matrix for the target, used to split color values between target and source data.
21
- :param shift: Value used to adjust the plot limits around the coordinates for better visualization. Default is 300.
22
- :param s: Marker size for the scatter plot points. Default is 0.8.
23
- :param boundary_line: Boolean indicating whether to draw boundary lines (horizontal and vertical lines). Default is True.
24
- :return: Displays the alignment plot of target, source, and alignment of source coordinates.
25
  """
26
-
27
- # Create a figure with three subplots, adjusting size and resolution
28
- plt.figure(figsize=(10, 3), dpi=300)
29
-
30
- # First subplot: Plot target coordinates
31
- plt.subplot(1, 3, 1)
32
- plt.scatter(ad_tar_coor[:, 0], ad_tar_coor[:, 1], marker='o', s=s, c=pca_hex_comb[:len(tar_features.T)])
33
- # Set plot limits based on the minimum and maximum target coordinates, with extra padding from 'shift'
34
- plt.xlim([ad_tar_coor.min() - shift, ad_tar_coor.max() + shift])
35
- plt.ylim([ad_tar_coor.min() - shift, ad_tar_coor.max() + shift])
36
-
37
- # Second subplot: Plot source coordinates
38
- plt.subplot(1, 3, 2)
39
- plt.scatter(ad_src_coor[:, 0], ad_src_coor[:, 1], marker='o', s=s, c=pca_hex_comb[len(tar_features.T):])
40
- # Ensure consistent plot limits across subplots by using the same limits as the target coordinates
41
- plt.xlim([ad_tar_coor.min() - shift, ad_tar_coor.max() + shift])
42
- plt.ylim([ad_tar_coor.min() - shift, ad_tar_coor.max() + shift])
43
-
44
- # Third subplot: Plot alignment of source coordinates
45
- plt.subplot(1, 3, 3)
46
- plt.scatter(homo_coor[:, 0], homo_coor[:, 1], marker='o', s=s, c=pca_hex_comb[len(tar_features.T):])
47
- # Maintain the same plot limits across all subplots for a uniform comparison
48
- plt.xlim([ad_tar_coor.min() - shift, ad_tar_coor.max() + shift])
49
- plt.ylim([ad_tar_coor.min() - shift, ad_tar_coor.max() + shift])
50
-
51
- # Optionally draw boundary lines at the minimum x and y values of the target coordinates
52
- if boundary_line:
53
- plt.axvline(x=ad_tar_coor[:, 0].min(), color='black') # Vertical boundary line at the minimum x of target coordinates
54
- plt.axhline(y=ad_tar_coor[:, 1].min(), color='black') # Horizontal boundary line at the minimum y of target coordinates
55
-
56
- # Remove the axis labels and ticks from all subplots for a cleaner appearance
57
- plt.axis('off')
58
-
59
- # Display the plot
60
  plt.show()
61
 
62
 
63
 
64
- def plot_alignment_with_img(ad_tar_coor, ad_src_coor, homo_coor, tar_img, src_img, aligned_image, pca_hex_comb, tar_features):
 
 
 
 
 
 
 
 
 
 
65
  """
66
- Plots the target coordinates and alignment of source coordinates with their respective images in the background.
67
-
68
- :param ad_tar_coor: Numpy array of target coordinates to be plotted in the first and third subplots.
69
- :param ad_src_coor: Numpy array of source coordinates to be plotted in the second subplot.
70
- :param homo_coor: Numpy array of alignment of source coordinates to be plotted in the third subplot.
71
- :param tar_img: Image associated with the target coordinates, used as the background in the first subplot.
72
- :param src_img: Image associated with the source coordinates, used as the background in the second subplot.
73
- :param aligned_image: Image associated with the aligned coordinates, used as the background in the third subplot.
74
- :param pca_hex_comb: Color values (e.g., PCA or hex values) for plotting the coordinates.
75
- :param tar_features: Feature matrix for the target, used to split color values between target and source data.
76
- :return: Displays the alignment plot of target, source, and alignment of source coordinates with their associated images.
77
  """
78
-
79
- # Create a figure with three subplots and set the size and resolution
80
- plt.figure(figsize=(10, 8), dpi=150)
81
-
82
- # First subplot: Plot target coordinates with the target image as the background
83
- plt.subplot(1, 3, 1)
84
- # Scatter plot for the target coordinates with transparency and small marker size
85
- plt.scatter(ad_tar_coor[:, 0], ad_tar_coor[:, 1], marker='o', alpha=0.8, s=1, c=pca_hex_comb[:len(tar_features.T)])
86
- # Overlay the target image with some transparency (alpha = 0.3)
87
- plt.imshow(tar_img, origin='lower', alpha=0.3)
88
-
89
- # Second subplot: Plot source coordinates with the source image as the background
90
- plt.subplot(1, 3, 2)
91
- # Scatter plot for the source coordinates with transparency and small marker size
92
- plt.scatter(ad_src_coor[:, 0], ad_src_coor[:, 1], marker='o', alpha=0.8, s=1, c=pca_hex_comb[len(tar_features.T):])
93
- # Overlay the source image with some transparency (alpha = 0.3)
94
- plt.imshow(src_img, origin='lower', alpha=0.3)
95
-
96
- # Third subplot: Plot both target and alignment of source coordinates with the aligned image as the background
97
- plt.subplot(1, 3, 3)
98
- # Scatter plot for the target coordinates with lower opacity (alpha = 0.2)
99
- plt.scatter(ad_tar_coor[:, 0], ad_tar_coor[:, 1], marker='o', alpha=0.2, s=1, c=pca_hex_comb[:len(tar_features.T)])
100
- # Scatter plot for the homologous coordinates with a '+' marker and the same color mapping
101
- plt.scatter(homo_coor[:, 0], homo_coor[:, 1], marker='+', s=1, c=pca_hex_comb[len(tar_features.T):])
102
- # Overlay the aligned image with some transparency (alpha = 0.3)
103
- plt.imshow(aligned_image, origin='lower', alpha=0.3)
104
-
105
- # Turn off the axis for all subplots to give a cleaner visual output
106
- plt.axis('off')
107
-
108
- # Display the plots
109
  plt.show()
110
 
111
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
 
113
  def draw_polygon(image, polygon, color='k', thickness=2):
114
  """
@@ -228,6 +223,9 @@ def plot_heatmap(
228
  coor,
229
  similairty,
230
  image_path=None,
 
 
 
231
  patch_size=(256, 256),
232
  save_path=None,
233
  downsize=32,
@@ -236,9 +234,6 @@ def plot_heatmap(
236
  boxes=None,
237
  box_color='k',
238
  box_thickness=2,
239
- polygons=None,
240
- polygons_color='k',
241
- polygons_thickness=2,
242
  image_alpha=0.5
243
  ):
244
  """
@@ -316,7 +311,7 @@ def plot_heatmap(
316
 
317
 
318
 
319
- def show_images_side_by_side(image1, image2, title1=None, title2=None):
320
  """
321
  Displays two images side by side in a single figure.
322
 
@@ -328,7 +323,7 @@ def show_images_side_by_side(image1, image2, title1=None, title2=None):
328
  """
329
 
330
  # Create a figure with 2 subplots (1 row, 2 columns), and set the figure size
331
- fig, ax = plt.subplots(1, 2, figsize=(15,8))
332
 
333
  # Display the first image on the first subplot
334
  ax[0].imshow(image1)
@@ -364,7 +359,7 @@ def plot_img_with_annotation(fullres_img, roi_polygon, linewidth, xlim, ylim):
364
  """
365
 
366
  # Create a new figure with a fixed size for displaying the image and annotations
367
- plt.figure(figsize=(10, 10))
368
 
369
  # Display the full-resolution image
370
  plt.imshow(fullres_img)
@@ -403,7 +398,7 @@ def plot_annotation_heatmap(st_ad, roi_polygon, s, linewidth, xlim, ylim):
403
  """
404
 
405
  # Create a new figure with a fixed size for displaying the heatmap and annotations
406
- plt.figure(figsize=(10, 10))
407
 
408
  # Scatter plot for the spatial transcriptomics data.
409
  # The 'spatial' coordinates are plotted with color intensity based on 'bulk_simi' values.
 
8
  from tqdm import tqdm
9
 
10
 
11
+ def plot_alignment(
12
+ ad_tar_coor: np.ndarray,
13
+ ad_src_coor: np.ndarray,
14
+ homo_coor: np.ndarray,
15
+ pca_hex_comb: np.ndarray,
16
+ tar_features: np.ndarray,
17
+ shift: float = 300,
18
+ s: float = 0.8,
19
+ boundary_line: bool = True
20
+ ) -> None:
21
  """
22
+ Plot target, source, and aligned source coordinates side by side with shared axis limits and subplot titles.
 
 
 
 
 
 
 
 
 
 
23
  """
24
+ # Determine common limits
25
+ coords = np.vstack([ad_tar_coor, ad_src_coor, homo_coor])
26
+ x_min, x_max = coords[:,0].min() - shift, coords[:,0].max() + shift
27
+ y_min, y_max = coords[:,1].min() - shift, coords[:,1].max() + shift
28
+
29
+ fig, axes = plt.subplots(1, 3, figsize=(10, 3), dpi=150)
30
+ titles = ["Target ST", "Source ST", "Aligned Source ST"]
31
+ splits = [len(ad_tar_coor), len(ad_tar_coor)+len(ad_src_coor)]
32
+
33
+ for ax, title, data_slice in zip(
34
+ axes,
35
+ titles,
36
+ [(ad_tar_coor, pca_hex_comb[:splits[0]]),
37
+ (ad_src_coor, pca_hex_comb[splits[0]:splits[1]]),
38
+ (homo_coor, pca_hex_comb[splits[0]:splits[1]])]
39
+ ):
40
+ coords_arr, colors = data_slice
41
+ ax.scatter(coords_arr[:,0], coords_arr[:,1], s=s, c=colors, marker='o')
42
+ ax.set_xlim(x_min, x_max)
43
+ ax.set_ylim(y_min, y_max)
44
+ ax.set_aspect('equal')
45
+ if boundary_line:
46
+ ax.axvline(x=ad_tar_coor[:,0].min(), color='black', linewidth=1)
47
+ ax.axhline(y=ad_tar_coor[:,1].min(), color='black', linewidth=1)
48
+ ax.set_title(title)
49
+ ax.axis('off')
50
+ plt.tight_layout()
 
 
 
 
 
 
 
51
  plt.show()
52
 
53
 
54
 
55
+ def plot_alignment_with_img(
56
+ ad_tar_coor: np.ndarray,
57
+ ad_src_coor: np.ndarray,
58
+ homo_coor: np.ndarray,
59
+ tar_img,
60
+ src_img,
61
+ aligned_image,
62
+ pca_hex_comb: np.ndarray,
63
+ tar_features: np.ndarray,
64
+ s: float = 1.0
65
+ ) -> None:
66
  """
67
+ Plot target, source, and aligned source coordinates over their background images, with subplot titles.
 
 
 
 
 
 
 
 
 
 
68
  """
69
+ fig, axes = plt.subplots(1, 3, figsize=(15, 5), dpi=150)
70
+ titles = ["Target + Image", "Source + Image", "Aligned + Image"]
71
+ splits = [len(tar_features.T), len(tar_features.T) * 2]
72
+
73
+ # Data slices for each subplot
74
+ data_slices = [
75
+ (ad_tar_coor, pca_hex_comb[:splits[0]], tar_img),
76
+ (ad_src_coor, pca_hex_comb[splits[0]:splits[1]], src_img),
77
+ (np.vstack([ad_tar_coor, homo_coor]),
78
+ np.concatenate([pca_hex_comb[:splits[0]], pca_hex_comb[splits[0]:splits[1]]]),
79
+ aligned_image)
80
+ ]
81
+
82
+ for ax, title, (coords_arr, colors, img) in zip(axes, titles, data_slices):
83
+ ax.imshow(img, origin='lower', alpha=0.3)
84
+ ax.scatter(coords_arr[:,0], coords_arr[:,1], s=s, c=colors, marker='o')
85
+ ax.set_aspect('equal')
86
+ ax.set_title(title)
87
+ ax.axis('off')
88
+
89
+ plt.tight_layout()
 
 
 
 
 
 
 
 
 
 
90
  plt.show()
91
 
92
 
93
+ def show_image(img, title: str = "Aligned Source Image", origin: str = "lower", cmap=None):
94
+ """
95
+ Display a single image with no axes and a title.
96
+
97
+ :param img: The image to display (NumPy array, PIL Image, etc.).
98
+ :param title: Title to show above the image.
99
+ :param origin: Origin parameter passed to plt.imshow (e.g. 'lower' or 'upper').
100
+ :param cmap: Optional colormap for grayscale or other single-channel data.
101
+ """
102
+ plt.imshow(img, origin=origin, cmap=cmap)
103
+ plt.title(title)
104
+ plt.axis('off')
105
+ plt.show()
106
+
107
 
108
  def draw_polygon(image, polygon, color='k', thickness=2):
109
  """
 
223
  coor,
224
  similairty,
225
  image_path=None,
226
+ polygons=None,
227
+ polygons_color='k',
228
+ polygons_thickness=2,
229
  patch_size=(256, 256),
230
  save_path=None,
231
  downsize=32,
 
234
  boxes=None,
235
  box_color='k',
236
  box_thickness=2,
 
 
 
237
  image_alpha=0.5
238
  ):
239
  """
 
311
 
312
 
313
 
314
+ def show_images_side_by_side(image1, image2, title1='Annotated H&E Image', title2='Similarity Heatmap'):
315
  """
316
  Displays two images side by side in a single figure.
317
 
 
323
  """
324
 
325
  # Create a figure with 2 subplots (1 row, 2 columns), and set the figure size
326
+ fig, ax = plt.subplots(1, 2, figsize=(8,6), dpi=150)
327
 
328
  # Display the first image on the first subplot
329
  ax[0].imshow(image1)
 
359
  """
360
 
361
  # Create a new figure with a fixed size for displaying the image and annotations
362
+ plt.figure(figsize=(12, 12), dpi=150)
363
 
364
  # Display the full-resolution image
365
  plt.imshow(fullres_img)
 
398
  """
399
 
400
  # Create a new figure with a fixed size for displaying the heatmap and annotations
401
+ plt.figure(figsize=(12, 12), dpi=150)
402
 
403
  # Scatter plot for the spatial transcriptomics data.
404
  # The 'spatial' coordinates are plotted with color intensity based on 'bulk_simi' values.
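
A hedged sketch of the new show_image helper and the updated show_images_side_by_side defaults, using random arrays in place of real H&E and heatmap renders so it runs without the demo data.

    # Hypothetical usage sketch with synthetic images.
    import numpy as np
    from loki.plot import show_image, show_images_side_by_side

    he_img = np.random.rand(128, 128, 3)       # stand-in for an annotated H&E image
    heatmap_img = np.random.rand(128, 128, 3)  # stand-in for a similarity heatmap

    # Single image, no axes, custom title and origin.
    show_image(he_img, title="Aligned Source Image", origin="upper")

    # Two images side by side; titles fall back to the new defaults.
    show_images_side_by_side(he_img, heatmap_img)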
src/loki/utils.py CHANGED
@@ -11,175 +11,107 @@ from open_clip import create_model_from_pretrained, get_tokenizer
11
 
12
 
13
 
14
- def load_model(model_path, device):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  """
16
- Loads a pretrained OmiCLIP model, along with its preprocessing function and tokenizer,
17
- using the specified model checkpoint.
18
-
19
- :param model_path: File path to the pretrained model checkpoint. This is passed to
20
- `create_model_from_pretrained` as the `pretrained` argument.
21
- :type model_path: str
22
- :param device: The device on which to load the model (e.g., 'cpu' or 'cuda').
23
- :type device: str or torch.device
24
- :return: A tuple `(model, preprocess, tokenizer)` where:
25
- - model: The loaded OmiCLIP model.
26
- - preprocess: A function or transform that preprocesses input data for the model.
27
- - tokenizer: A tokenizer appropriate for textual input to the model.
28
- :rtype: (nn.Module, callable, callable)
29
  """
30
- # Create the model and its preprocessing transform from the specified checkpoint
31
  model, preprocess = create_model_from_pretrained(
32
  "coca_ViT-L-14", device=device, pretrained=model_path
33
  )
34
-
35
- # Retrieve a tokenizer compatible with the "coca_ViT-L-14" architecture
36
- tokenizer = get_tokenizer('coca_ViT-L-14')
37
-
38
  return model, preprocess, tokenizer
39
 
 
40
 
41
-
42
- def encode_image(model, preprocess, image):
43
- """
44
- Encodes an image into a normalized feature embedding using the specified model and preprocessing function.
45
-
46
- :param model: A model object that provides an `encode_image` method.
47
- :type model: torch.nn.Module
48
- :param preprocess: A preprocessing function that transforms the input image into a tensor
49
- suitable for the model. Typically something returning a PyTorch tensor.
50
- :type preprocess: callable
51
- :param image: The input image (PIL Image, NumPy array, or other format supported by `preprocess`).
52
- :type image: PIL.Image.Image or numpy.ndarray
53
- :return: A single normalized image embedding as a PyTorch tensor of shape (1, embedding_dim).
54
- :rtype: torch.Tensor
55
  """
56
- # Preprocess the image, then stack to create a batch of size 1
57
- image_input = torch.stack([preprocess(image)])
58
-
59
- # Generate the image features without gradient tracking
60
- with torch.no_grad():
61
- image_features = model.encode_image(image_input)
62
-
63
- # Normalize embeddings across the feature dimension (L2 normalization)
64
- image_embeddings = F.normalize(image_features, p=2, dim=-1)
65
-
66
- return image_embeddings
67
-
68
-
69
-
70
- def encode_image_patches(model, preprocess, data_dir, img_list):
71
  """
72
- Encodes multiple image patches into normalized feature embeddings using a specified model and preprocess function.
 
 
73
 
74
- :param model: A model object that provides an `encode_image` method.
75
- :type model: torch.nn.Module
76
- :param preprocess: A preprocessing function that transforms the input image into a tensor
77
- suitable for the model. Typically something returning a PyTorch tensor.
78
- :type preprocess: callable
79
- :param data_dir: The base directory containing image data.
80
- :type data_dir: str
81
- :param img_list: A list of image filenames (strings). Each filename corresponds to a patch image
82
- stored in `data_dir/demo_data/patch/`.
83
- :type img_list: list[str]
84
- :return: A PyTorch tensor of shape (N, 1, embedding_dim), containing the normalized embeddings
85
- for each image in `img_list`.
86
- :rtype: torch.Tensor
87
- """
88
 
89
- # Prepare a list to hold each image's feature embedding
90
- image_embeddings = []
91
 
92
- # Loop through each image name in the provided list
93
- for img_name in img_list:
94
- # Build the path to the patch image and open it
95
- image_path = os.path.join(data_dir, 'demo_data', 'patch', img_name)
96
- image = Image.open(image_path)
97
 
98
- # Encode the image using the model & preprocess; returns shape (1, embedding_dim)
99
- image_features = encode_image(model, preprocess, image)
100
 
101
- # Accumulate the feature embeddings in the list
102
- image_embeddings.append(image_features)
103
 
104
- # Convert the list of embeddings to a NumPy array, then to a PyTorch tensor
105
- # Resulting shape will be (N, 1, embedding_dim)
106
- image_embeddings = torch.from_numpy(np.array(image_embeddings))
107
 
108
- # Normalize all embeddings across the feature dimension (L2 normalization)
109
- image_embeddings = F.normalize(image_embeddings, p=2, dim=-1)
110
 
111
- return image_embeddings
112
 
113
 
 
114
 
115
- def encode_text(model, tokenizer, text):
 
 
 
 
 
116
  """
117
- Encodes text into a normalized feature embedding using a specified model and tokenizer.
118
-
119
- :param model: A model object that provides an `encode_text` method.
120
- :type model: torch.nn.Module
121
- :param tokenizer: A tokenizer function that converts the input text into a format suitable for `model.encode_text`.
122
- Typically returns token IDs, attention masks, etc. as a torch.Tensor or similar structure.
123
- :type tokenizer: callable
124
- :param text: The input text (string or list of strings) to be encoded.
125
- :type text: str or list[str]
126
- :return: A PyTorch tensor of shape (batch_size, embedding_dim) containing the L2-normalized text embeddings.
127
- :rtype: torch.Tensor
128
  """
129
-
130
- # Convert text to the appropriate tokenized representation
131
- text_input = tokenizer(text)
132
-
133
- # Run the model in no-grad mode (not tracking gradients, saving memory and compute)
134
  with torch.no_grad():
135
- text_features = model.encode_text(text_input)
136
-
137
- # Normalize embeddings to unit length
138
- text_embeddings = F.normalize(text_features, p=2, dim=-1)
139
 
140
- return text_embeddings
141
 
142
-
143
-
144
- def encode_text_df(model, tokenizer, df, col_name):
 
 
 
 
145
  """
146
- Encodes text from a specified column in a pandas DataFrame using the given model and tokenizer,
147
- returning a PyTorch tensor of normalized text embeddings.
148
-
149
- :param model: A model object that provides an `encode_text` method.
150
- :type model: torch.nn.Module
151
- :param tokenizer: A tokenizer function that converts the input text into a format suitable for `model.encode_text`.
152
- :type tokenizer: callable
153
- :param df: A pandas DataFrame from which text will be extracted.
154
- :type df: pandas.DataFrame
155
- :param col_name: The name of the column in `df` that contains the text to be encoded.
156
- :type col_name: str
157
- :return: A PyTorch tensor containing the L2-normalized text embeddings,
158
- where the shape is (number_of_rows, embedding_dim).
159
- :rtype: torch.Tensor
160
  """
 
 
161
 
162
- # Prepare a list to hold each row's text embedding
163
- text_embeddings = []
164
-
165
- # Loop through each index in the DataFrame
166
- for idx in df.index:
167
- # Retrieve text from the specified column for the current row
168
- text = df[df.index == idx][col_name][0]
169
-
170
- # Encode the text using the provided model and tokenizer
171
- text_features = encode_text(model, tokenizer, text)
172
-
173
- # Accumulate the embedding tensor
174
- text_embeddings.append(text_features)
175
-
176
- # Convert the list of embeddings (likely shape [N, embedding_dim]) into a NumPy array, then to a torch tensor
177
- text_embeddings = torch.from_numpy(np.array(text_embeddings))
178
-
179
- # Normalize embeddings to unit length across the feature dimension
180
- text_embeddings = F.normalize(text_embeddings, p=2, dim=-1)
181
-
182
- return text_embeddings
183
 
184
 
185
 
 
11
 
12
 
13
 
14
+ import os
15
+ from typing import List, Tuple, Union
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ import numpy as np
20
+ from PIL import Image
21
+ import pandas as pd
22
+
23
+ # --- Model loading --------------------------------------------------------
24
+
25
+ def load_model(
26
+ model_path: str,
27
+ device: Union[str, torch.device]
28
+ ) -> Tuple[torch.nn.Module, callable, callable]:
29
  """
30
+ Load the pretrained OmiCLIP (CoCa ViT-L-14) model, its image preprocessing transform, and its tokenizer.
 
 
 
 
 
 
 
 
 
 
 
 
31
  """
 
32
  model, preprocess = create_model_from_pretrained(
33
  "coca_ViT-L-14", device=device, pretrained=model_path
34
  )
35
+ tokenizer = get_tokenizer("coca_ViT-L-14")
36
+ model.to(device).eval()
 
 
37
  return model, preprocess, tokenizer
38
 
39
+ # --- Image encoding -------------------------------------------------------
40
 
41
+ def encode_images(
42
+ model: torch.nn.Module,
43
+ preprocess: callable,
44
+ image_paths: List[str],
45
+ device: Union[str, torch.device]
46
+ ) -> torch.Tensor:
 
 
 
 
 
 
 
 
47
  """
48
+ Batch-encode a list of image file paths into L2-normalized embeddings.
49
+ Returns a tensor of shape (N, D).
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  """
51
+ # Load & preprocess all images
52
+ imgs = [preprocess(Image.open(p)) for p in image_paths]
53
+ batch = torch.stack(imgs, dim=0).to(device) # (N, C, H, W)
54
 
55
+ with torch.no_grad():
56
+ feats = model.encode_image(batch) # (N, D)
57
+ return F.normalize(feats, p=2, dim=-1) # (N, D)
 
 
 
 
 
 
 
 
 
 
 
58
 
 
 
59
 
60
+ # # Loop through each image name in the provided list
61
+ # for img_name in img_list:
62
+ # # Build the path to the patch image and open it
63
+ # image_path = os.path.join(data_dir, 'demo_data', 'patch', img_name)
64
+ # image = Image.open(image_path)
65
 
66
+ # # Encode the image using the model & preprocess; returns shape (1, embedding_dim)
67
+ # image_features = encode_image(model, preprocess, image)
68
 
69
+ # # Accumulate the feature embeddings in the list
70
+ # image_embeddings.append(image_features)
71
 
72
+ # # Convert the list of embeddings to a NumPy array, then to a PyTorch tensor
73
+ # # Resulting shape will be (N, 1, embedding_dim)
74
+ # image_embeddings = torch.from_numpy(np.array(image_embeddings))
75
 
76
+ # # Normalize all embeddings across the feature dimension (L2 normalization)
77
+ # image_embeddings = F.normalize(image_embeddings, p=2, dim=-1)
78
 
79
+ # return image_embeddings
80
 
81
 
82
+ # --- Text encoding --------------------------------------------------------
83
 
84
+ def encode_texts(
85
+ model: torch.nn.Module,
86
+ tokenizer: callable,
87
+ texts: List[str],
88
+ device: Union[str, torch.device]
89
+ ) -> torch.Tensor:
90
  """
91
+ Batch-encode a list of strings into L2-normalized embeddings.
92
+ Returns a tensor of shape (N, D).
 
 
 
 
 
 
 
 
 
93
  """
94
+ # Tokenize the inputs (the open_clip tokenizer returns a tensor of token IDs)
95
+ text_inputs = tokenizer(texts)
96
+
 
 
97
  with torch.no_grad():
98
+ feats = model.encode_text(text_inputs) # (N, D)
99
+ return F.normalize(feats, p=2, dim=-1) # (N, D)
 
 
100
 
 
101
 
102
+ def encode_text_df(
103
+ model: torch.nn.Module,
104
+ tokenizer: callable,
105
+ df: pd.DataFrame,
106
+ col_name: str,
107
+ device: Union[str, torch.device]
108
+ ) -> torch.Tensor:
109
  """
110
+ Encodes an entire DataFrame column into (N, D) embeddings.
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  """
112
+ texts = df[col_name].astype(str).tolist()
113
+ return encode_texts(model, tokenizer, texts, device)
114
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
 
116
 
117
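
For the DataFrame path specifically, a short hedged sketch; the column name and text values are made up, and the function simply delegates to encode_texts as shown in the diff.

    # Hypothetical usage sketch; 'annotation' is a placeholder column name.
    import pandas as pd
    from loki.utils import load_model, encode_text_df

    device = "cpu"  # the diffed code does not move tokenized text to a GPU
    model, preprocess, tokenizer = load_model("checkpoint.pt", device)

    df = pd.DataFrame({"annotation": ["immune infiltrate", "dense tumor nest"]})
    emb = encode_text_df(model, tokenizer, df, "annotation", device)
    print(emb.shape)   # expected: (2, D) L2-normalized embeddings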
 
src/requirements.txt CHANGED
@@ -11,4 +11,5 @@ torchvision==0.18.1
11
  open_clip_torch==2.26.1
12
  pillow==10.4.0
13
  ipykernel==6.29.5
 
14
 
 
11
  open_clip_torch==2.26.1
12
  pillow==10.4.0
13
  ipykernel==6.29.5
14
+ ipywidgets==8.1.6
15