import torch
from sklearn.metrics.pairwise import cosine_similarity


def _model_device(model):
    """Return the device *model* runs on, or None if it cannot be determined."""
    device = getattr(model, "device", None)
    if device is not None:
        return device
    try:
        return next(model.parameters()).device
    except (AttributeError, StopIteration, TypeError):
        # Objects without parameters (or without a parameters() method):
        # leave the inputs wherever the processor produced them.
        return None


def match_captions(image_features, captions, clip_model, processor):
    """Rank *captions* by cosine similarity to an image embedding.

    Args:
        image_features: Image embedding(s) as a torch tensor or array-like.
            A 1-D vector is treated as a single image. With 2-D input only
            the FIRST row is ranked; other rows are ignored.
        captions: Sequence of caption strings to rank.
        clip_model: Object exposing ``get_text_features(**inputs)``, e.g. a
            CLIP model.
        processor: Callable that tokenizes ``text=captions`` into a mapping
            of model inputs.

    Returns:
        ``(best_captions, scores)``: captions ordered from most to least
        similar, and the matching cosine similarities as Python floats.
        Captions with equal scores keep their original relative order.
        An empty *captions* returns ``([], [])``.
    """
    if len(captions) == 0:
        # Nothing to rank; also avoids calling the processor/model with no text.
        return [], []

    text_inputs = processor(text=captions, return_tensors="pt", padding=True)
    device = _model_device(clip_model)
    if device is not None:
        # The processor builds CPU tensors; a GPU-resident model needs its
        # inputs on the same device.
        text_inputs = {
            name: value.to(device) if isinstance(value, torch.Tensor) else value
            for name, value in text_inputs.items()
        }

    with torch.no_grad():
        text_features = clip_model.get_text_features(**text_inputs)

    # Accept tensors or array-likes; cast to float32 because NumPy cannot
    # represent bfloat16, and promote a single vector to a 1-row matrix.
    image_matrix = torch.as_tensor(image_features).detach().float().cpu()
    if image_matrix.ndim == 1:
        image_matrix = image_matrix.unsqueeze(0)
    text_matrix = text_features.detach().float().cpu()

    similarities = cosine_similarity(image_matrix.numpy(), text_matrix.numpy())
    scores = similarities[0]
    # Stable descending sort: ties keep caption order deterministically.
    best_indices = (-scores).argsort(kind="stable")
    best_captions = [captions[i] for i in best_indices]
    return best_captions, scores[best_indices].tolist()