| import torch | |
def metrics(similarity_matrix: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute top-1 retrieval accuracy in both directions of a square similarity matrix.

    Entry ``[i, j]`` scores row item ``i`` against column item ``j``; the
    correct match for row ``i`` is column ``i`` (i.e. the diagonal holds the
    true pairs). Following the variable naming, rows are treated as images
    and columns as captions.

    Args:
        similarity_matrix: 2-D tensor of shape ``(n, n)`` with ``n >= 1``.

    Returns:
        ``(img_acc, cap_acc)`` as 0-dim float tensors on the input's device:
        ``img_acc`` is the fraction of rows whose highest-scoring column is
        the diagonal one (image -> caption); ``cap_acc`` is the fraction of
        columns whose highest-scoring row is the diagonal one
        (caption -> image). Ties resolve to the first maximal index, as
        ``torch.argmax`` does.

    Raises:
        ValueError: if ``similarity_matrix`` is not a non-empty square 2-D tensor.
    """
    if similarity_matrix.dim() != 2 or similarity_matrix.shape[0] != similarity_matrix.shape[1]:
        # A non-square matrix would otherwise broadcast silently (e.g. an
        # (n, 1) input) and yield a meaningless caption accuracy.
        raise ValueError(
            "similarity_matrix must be a square 2-D tensor, "
            f"got shape {tuple(similarity_matrix.shape)}"
        )
    n = similarity_matrix.shape[0]
    if n == 0:
        raise ValueError("similarity_matrix must be non-empty")

    # Build the diagonal targets directly on the input's device rather than
    # allocating on CPU and copying over.
    targets = torch.arange(n, device=similarity_matrix.device)

    img2cap_match_idx = similarity_matrix.argmax(dim=1)  # best column per row
    cap2img_match_idx = similarity_matrix.argmax(dim=0)  # best row per column

    img_acc = (img2cap_match_idx == targets).float().mean()
    cap_acc = (cap2img_match_idx == targets).float().mean()
    return img_acc, cap_acc