import torch
from sklearn.metrics.pairwise import cosine_similarity
def match_captions(image_features, captions, clip_model, processor):
    """Rank candidate captions by cosine similarity to an image embedding.

    Args:
        image_features: Image embedding(s) as a tensor (or array-like). A
            1-D vector is treated as a single image; for a 2-D input only
            the first row is ranked against the captions.
        captions: Sequence of candidate caption strings.
        clip_model: Model exposing ``get_text_features(**inputs)``. If it
            has a ``device`` attribute, the tokenized text is moved there
            before the forward pass.
        processor: Callable that tokenizes ``text=captions`` into a mapping
            of tensors accepted by ``clip_model.get_text_features``.

    Returns:
        A pair ``(best_captions, scores)``: the captions ordered from most
        to least similar, and the corresponding cosine similarities as
        Python floats. Both lists are empty when ``captions`` is empty.

    Raises:
        ValueError: If ``image_features`` contains no image rows.
    """
    if not captions:
        return [], []

    image_features = torch.as_tensor(image_features).detach()
    if image_features.dim() == 1:
        # Accept a single embedding vector as a batch of one image.
        image_features = image_features.unsqueeze(0)
    if image_features.shape[0] == 0:
        raise ValueError("image_features must contain at least one image embedding")

    text_inputs = processor(text=captions, return_tensors="pt", padding=True)
    # The processor produces CPU tensors; move them to the model's device so
    # a GPU-resident model does not fail with a device-mismatch error.
    device = getattr(clip_model, "device", None)
    if device is not None:
        text_inputs = {key: value.to(device) for key, value in text_inputs.items()}

    with torch.no_grad():
        text_features = clip_model.get_text_features(**text_inputs)

    # Only the first image is ranked. Compare in float32 on the text
    # features' device. L2-normalizing both sides turns the dot product into
    # a cosine similarity; zero vectors stay zero and so score 0.
    image_vec = image_features[0].to(device=text_features.device, dtype=torch.float32)
    text_mat = text_features.detach().to(torch.float32)
    text_unit = torch.nn.functional.normalize(text_mat, dim=-1)
    image_unit = torch.nn.functional.normalize(image_vec, dim=-1)
    scores = text_unit @ image_unit

    order = torch.argsort(scores, descending=True)
    best_captions = [captions[i] for i in order.tolist()]
    return best_captions, scores[order].tolist()