File size: 1,061 Bytes
ef814bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torch
from utils.helpers import get_token_embeddings

def compute_similarity_matrix(model, dataset, device):
    """
    Compute the cosine-similarity matrix between mean token embeddings.

    The embeddings returned by ``get_token_embeddings`` are averaged over
    their first dimension, each resulting row is L2-normalized, and the
    normalized matrix is multiplied by its own transpose.

    Args:
        model (torch.nn.Module): Model, passed through to
            ``get_token_embeddings``.
        dataset (torch.utils.data.Dataset): Dataset, passed through to
            ``get_token_embeddings``.
        device (str): Device to use, passed through to
            ``get_token_embeddings``.
    Returns:
        np.ndarray: Square similarity matrix with one row/column per row of
        the averaged embeddings. Rows whose mean embedding is all zeros
        yield similarity 0 rather than NaN.
    """
    # Expected shape: (n_samples, seq_len, d_model); torch.mm below requires
    # the averaged tensor to be 2-D.
    embeddings = get_token_embeddings(model, dataset, device)

    # Detach first: Tensor.numpy() raises on tensors that require grad, and
    # no gradient is needed for a matrix that is returned as a numpy array.
    mean_token_embeddings = embeddings.detach().mean(dim=0)  # (seq_len, d_model)

    # Normalize each row to unit length for cosine similarity. The norm is
    # clamped away from zero so an all-zero row does not produce 0/0 = NaN.
    eps = 1e-12
    row_norms = mean_token_embeddings.norm(dim=1, keepdim=True).clamp_min(eps)
    unit_embeddings = mean_token_embeddings / row_norms

    # Dot products of unit vectors are the pairwise cosine similarities.
    similarity_matrix = torch.mm(unit_embeddings, unit_embeddings.T)

    return similarity_matrix.cpu().numpy()