File size: 676 Bytes
e02c264
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import torch
from sklearn.metrics.pairwise import cosine_similarity

def match_captions(image_features, captions, clip_model, processor):
    """Rank captions by cosine similarity to an image embedding.

    The captions are tokenized with ``processor``, embedded with
    ``clip_model.get_text_features`` and compared against
    ``image_features`` using cosine similarity.

    Only the first row of ``image_features`` is ranked; any additional
    rows are ignored.

    Args:
        image_features: Torch tensor holding the image embedding. A 1-D
            tensor is treated as a single embedding; otherwise the tensor
            is presumably ``(n_images, dim)`` -- TODO confirm with callers.
        captions: Candidate caption strings. A single string is treated
            as a one-element list.
        clip_model: Model exposing ``get_text_features(**inputs)``
            (presumably a Hugging Face CLIP model -- verify).
        processor: Callable accepting ``text=``, ``return_tensors=`` and
            ``padding=`` and returning a mapping of model inputs.

    Returns:
        A ``(best_captions, scores)`` tuple: the captions ordered from
        most to least similar, and the matching similarity scores as a
        list of floats. Both lists are empty when no captions are given.
    """
    # A bare string would otherwise be indexed character by character
    # when the ranked captions are looked up below.
    if isinstance(captions, str):
        captions = [captions]
    else:
        captions = list(captions)

    # Nothing to rank: skip tokenization and the model call entirely.
    if not captions:
        return [], []

    text_inputs = processor(text=captions, return_tensors="pt", padding=True)

    # Keep the inputs on the same device as the model so a GPU-resident
    # model does not fail with a device mismatch. Guarded because neither
    # attribute is guaranteed by the code visible here.
    device = getattr(clip_model, "device", None)
    if device is not None and hasattr(text_inputs, "to"):
        text_inputs = text_inputs.to(device)

    with torch.no_grad():
        text_features = clip_model.get_text_features(**text_inputs)

    image_features = image_features.detach().cpu()
    if image_features.dim() == 1:
        # cosine_similarity requires 2-D input; promote a lone embedding
        # to a single-row matrix.
        image_features = image_features.unsqueeze(0)
    image_features = image_features.numpy()
    text_features = text_features.detach().cpu().numpy()

    similarities = cosine_similarity(image_features, text_features)
    first_image_scores = similarities[0]
    # argsort is ascending, so reverse it to put the best match first.
    best_indices = first_image_scores.argsort()[::-1]
    best_captions = [captions[i] for i in best_indices]

    return best_captions, first_image_scores[best_indices].tolist()