|
|
import numpy as np |
|
|
import torch |
|
|
from torch.nn import functional as F |
|
|
import os |
|
|
import scanpy as sc |
|
|
import json |
|
|
import cv2 |
|
|
|
|
|
|
|
|
|
|
|
def annotate_with_bulk(img_features, bulk_features, normalize=True, T=1, tensor=False): |
|
|
""" |
|
|
Annotates tissue image with similarity scores between image features and bulk RNA-seq features. |
|
|
|
|
|
:param img_features: Feature matrix representing histopathology image features. |
|
|
:param bulk_features: Feature vector representing bulk RNA-seq features. |
|
|
:param normalize: Whether to normalize similarity scores, default=True. |
|
|
:param T: Temperature parameter to control the sharpness of the softmax distribution. Higher values result in a smoother distribution. |
|
|
:param tensor: Feature format in torch tensor or not, default=False. |
|
|
|
|
|
:return: An array or tensor containing the normalized similarity scores. |
|
|
""" |
|
|
|
|
|
if tensor: |
|
|
|
|
|
cosine_similarity = torch.nn.CosineSimilarity(dim=1, eps=1e-6) |
|
|
similarity = cosine_similarity(img_features, bulk_features.unsqueeze(0)) |
|
|
|
|
|
|
|
|
if normalize: |
|
|
normalization_factor = torch.sqrt(torch.tensor([bulk_features.shape[0]], dtype=torch.float)) |
|
|
similarity = similarity / normalization_factor |
|
|
|
|
|
|
|
|
similarity = similarity.unsqueeze(0) |
|
|
similarity = similarity / T |
|
|
|
|
|
|
|
|
similarity = torch.nn.functional.softmax(similarity, dim=-1) |
|
|
|
|
|
else: |
|
|
|
|
|
similarity = np.dot(img_features.T, bulk_features) |
|
|
|
|
|
|
|
|
max_similarity = np.max(similarity) |
|
|
similarity = np.exp(similarity - max_similarity) / np.sum(np.exp(similarity - max_similarity)) |
|
|
|
|
|
|
|
|
similarity = (similarity - np.min(similarity)) / (np.max(similarity) - np.min(similarity)) |
|
|
|
|
|
return similarity |
|
|
|
|
|
|
|
|
|
|
|
def annotate_with_marker_genes(classes, image_embeddings, all_text_embeddings):
    """
    Score a tissue image against marker-gene text embeddings and pick the
    best-matching tissue type.

    :param classes: A list or array of tissue type labels, one per text embedding row.
    :param image_embeddings: A numpy array or torch tensor of image embeddings
        (shape: [n_images, embedding_dim]).
    :param all_text_embeddings: A numpy array or torch tensor of text embeddings
        of the marker genes (shape: [n_classes, embedding_dim]).

    :return:
        - dot_similarity: The matrix of dot product similarities between image
          embeddings and text embeddings.
        - pred_class: The predicted tissue type for the image based on the
          highest similarity score.
    """
    # One row per image, one column per class: pairwise dot products.
    scores = image_embeddings @ all_text_embeddings.T

    # The single highest-scoring entry (flattened argmax) decides the label.
    winner = scores.argmax()
    predicted_label = classes[winner]

    return scores, predicted_label
|
|
|
|
|
|
|
|
|
|
|
def load_image_annotation(image_path):
    """
    Loads an image with annotation and swaps its channel order.

    :param image_path: The file path to the image.

    :return: The image as a uint8 array. Note: cv2.imread decodes files to BGR,
        so the channel swap below actually yields RGB data — the previous
        docstring's claim of BGR output contradicted OpenCV's documented
        imread behavior.
    :raises FileNotFoundError: If the image cannot be read from image_path.
    """
    image = cv2.imread(image_path)

    # BUG FIX: cv2.imread returns None (no exception) for a missing or
    # unreadable file; without this guard, cvtColor fails with a cryptic error.
    if image is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")

    # cv2.imread produces BGR; this swap converts the data to RGB.
    # (COLOR_RGB2BGR and COLOR_BGR2RGB perform the identical channel swap.)
    image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

    # Ensure the expected 8-bit unsigned pixel type.
    image = image.astype(np.uint8)

    return image
|
|
|
|
|
|
|
|
|