File size: 1,304 Bytes
1e315b6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 |
import torch
def retrieve_st_by_image(image_embeddings, all_text_embeddings, dataframe, k=3):
    """
    Retrieve the top-k ST samples most similar to an image embedding.

    Similarity is the raw dot product between the image embedding and each
    ST embedding. No normalization is applied here, so this equals cosine
    similarity only if the caller has already L2-normalized both inputs.

    :param image_embeddings: A numpy array or torch tensor containing the image
        embedding (shape: [1, embedding_dim]; a 1-D [embedding_dim] vector is
        also accepted).
    :param all_text_embeddings: A numpy array or torch tensor containing ST
        embeddings (shape: [n_samples, embedding_dim]).
    :param dataframe: A pandas DataFrame whose 'img_idx' column is assumed to be
        aligned row-for-row with all_text_embeddings.
    :param k: The number of top similar samples to retrieve. Default is 3.
        Values larger than n_samples are clamped to n_samples.
    :return: A list of 'img_idx' values for the top-k matches, most similar first.
    """
    # torch.topk only operates on tensors, so convert numpy inputs as documented.
    text = torch.as_tensor(all_text_embeddings)
    # Match the ST embeddings' dtype/device so the matmul does not fail on
    # mixed float64/float32 or CPU/GPU inputs.
    image = torch.as_tensor(image_embeddings, dtype=text.dtype, device=text.device)
    if image.dim() == 1:
        # Promote a bare vector to [1, embedding_dim]; otherwise squeeze(0)
        # below could collapse a single-sample result to a 0-d tensor.
        image = image.unsqueeze(0)

    # Dot-product similarity between the image embedding and every ST embedding.
    scores = (image @ text.T).squeeze(0)

    # topk raises if k exceeds the number of candidates; return all instead.
    k = min(k, scores.shape[-1])
    _, indices = torch.topk(scores, k)

    image_filenames = dataframe['img_idx'].values
    # .tolist() yields plain Python ints (one device sync instead of one per item).
    return [image_filenames[idx] for idx in indices.tolist()]
|