File size: 5,477 Bytes
68fd999
 
 
 
c795cd4
 
1310186
c795cd4
 
68fd999
c795cd4
 
 
68fd999
c795cd4
 
76dd5e2
68fd999
 
c795cd4
68fd999
 
 
 
 
 
 
 
 
 
 
c795cd4
 
 
 
 
 
 
 
 
68fd999
1310186
c795cd4
 
ab4ff40
c795cd4
7b687d4
 
ab4ff40
76dd5e2
7b687d4
68fd999
 
7b687d4
c795cd4
 
 
bccb1fa
68fd999
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bccb1fa
c795cd4
 
 
 
68fd999
 
 
 
 
 
 
 
 
 
 
 
 
 
 
edfee12
68fd999
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76dd5e2
68fd999
 
 
 
bccb1fa
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
# This script defines functions that search the corpus for blocks of text that are similar to a query.
# How query embeddings are obtained had to change for the production deployment, because
# the CSV files used previously took too much space for the free tier of Hugging Face Spaces.

import polars as pl
from sklearn.decomposition import TruncatedSVD
from sklearn.metrics.pairwise import cosine_similarity
from huggingface_hub import hf_hub_download
import numpy as np
from numpy.typing import NDArray
from joblib import load
import scipy
import fasttext
from collections.abc import Callable


def query_worker(
    query: str,
    rownames: list[str],
    fasttext_model: fasttext.FastText._FastText,
    idf: NDArray[np.float64],
    dtm_svd: TruncatedSVD,
    dtm_svd_mat: NDArray[np.float64],
    vocab_norm: NDArray[np.float64],
    concentration: float = 10,
) -> pl.DataFrame:
    """
    Calculate the cosine similarity of the query to each block of text from the corpus.

    Every whitespace-separated query term is embedded with fastText and compared with the
    unit-length embedding of every vocabulary term. A softmax over those similarities, sharpened
    by `concentration`, spreads each query term over the vocabulary. The result is weighted by
    IDF, projected with the fitted SVD, and averaged over the query terms. That average is
    compared to each document's SVD projection by cosine similarity.

    Parameters:
        query (str): Search query. It is split on whitespace into terms.
        rownames (list[str]): Document identifiers, in the same order as the rows of `dtm_svd_mat`.
        fasttext_model (fasttext.FastText._FastText): Model used to embed the query terms.
        idf (numpy.ndarray): IDF weights of the vocabulary terms. They are broadcast against the
            (n_query_terms, n_vocab_terms) similarity matrix, so presumably a (1, n_vocab_terms)
            row vector.
        dtm_svd (TruncatedSVD): SVD fitted to the corpus document-term matrix.
        dtm_svd_mat (numpy.ndarray): SVD projection of the corpus, one row per document.
        vocab_norm (numpy.ndarray): Unit-length vocabulary embeddings, one row per vocabulary term.
        concentration (float): Softmax inverse temperature. Larger values concentrate each query
            term on its most similar vocabulary terms.

    Returns:
        polars.DataFrame: Columns `score-tfidf`, `file` and `rank-tfidf`, sorted so that the best
        matches (highest `score-tfidf`) are listed first, with rank 1.

    Raises:
        ValueError: If the query has no term with a non-zero embedding (including an empty query).
    """
    # Embed the query terms and drop any all-zero vector. Normalising a zero vector divides by
    # zero and would turn every score into NaN. The vocabulary is filtered the same way.
    term_vectors = [fasttext_model.get_word_vector(term) for term in query.split()]
    term_vectors = [vector for vector in term_vectors if np.any(vector != 0)]
    if not term_vectors:
        raise ValueError("The query must contain at least one term with a non-zero word embedding.")
    query_embeddings = np.array(term_vectors)

    # Scale each query embedding to unit length, so the dot products below are cosine similarities.
    query_norm = query_embeddings / np.linalg.norm(query_embeddings, axis=1, keepdims=True)

    # Similarity of every query term to every vocabulary term: (n_query_terms, n_vocab_terms).
    query_similarities = query_norm @ vocab_norm.T

    # Turn each query term into a soft distribution over the vocabulary, then weight it by IDF.
    query_tfidf = idf * scipy.special.softmax(query_similarities * concentration, axis=1)

    # Project each term's pseudo TF-IDF row into the SVD space and average over the query terms.
    query_weights = np.mean(dtm_svd.transform(query_tfidf), axis=0)

    # Cosine similarity between the averaged query vector and every document's projection.
    mean_query_score = cosine_similarity(query_weights.reshape(1, -1), dtm_svd_mat).ravel()

    sorted_df = (
        pl.DataFrame({"score-tfidf": mean_query_score, "file": rownames})
        .sort("score-tfidf", descending=True)
        .with_columns(pl.Series("rank-tfidf", list(range(1, len(mean_query_score) + 1))))
    )
    return sorted_df



def query_factory(
    rownames: list[str],
    fasttext_model: fasttext.FastText._FastText,
    idf: NDArray[np.float64],
    dtm_svd: TruncatedSVD,
    dtm_svd_mat: NDArray[np.float64],
    vocab_norm: NDArray[np.float64],
    concentration: float = 10,
) -> Callable[[str], pl.DataFrame]:
    """
    Create a function that will compare query text to the documents in the corpus.

    The returned function closes over the arguments below. It forwards each query, together
    with them, to `query_worker`.

    Parameters:
        rownames (list[str]): Document identifiers, in the same order as the rows of `dtm_svd_mat`.
        fasttext_model (fasttext.FastText._FastText): Model used to embed query terms.
        idf (numpy.ndarray): IDF weights of the vocabulary terms.
        dtm_svd (TruncatedSVD): SVD fitted to the corpus document-term matrix.
        dtm_svd_mat (numpy.ndarray): SVD projection of the corpus, one row per document.
        vocab_norm (numpy.ndarray): Unit-length vocabulary embeddings, one row per vocabulary term.
        concentration (float): Softmax inverse temperature passed on to `query_worker`.

    Returns:
        Callable[[str], polars.DataFrame]: Function that takes a query string and returns the
        ranked results produced by `query_worker`.
    """

    def do_query(query: str) -> pl.DataFrame:
        """
        Call the worker that compares the query term distribution to the documents in the corpus.

        Parameters:
            query (str): Text to compare to the documents.

        Returns:
            polars.DataFrame: Results sorted so that the best matches (according to column `score-tfidf`) are listed first.
        """
        return query_worker(
            query,
            rownames,
            fasttext_model,
            idf,
            dtm_svd,
            dtm_svd_mat,
            vocab_norm,
            concentration,
        )

    return do_query



def create_tfidf_search_function(
    dtm_df_path: str,
    vectorizer_path: str,
    model_name: str = "facebook/fasttext-en-vectors",
    n_components: int = 300,
    concentration: float = 30,
) -> Callable[[str], pl.DataFrame]:
    """
    Create a function that compares the word distribution in a query to each document in the corpus.

    Parameters:
        dtm_df_path (str): Path to a TF-IDF document-term matrix (DTM) for the corpus in parquet format.
            It must have a `file` column; every other column is treated as a term column.
        vectorizer_path (str): Path to the saved vectorizer that generated the DTM saved at
            `dtm_df_path`. We expect that the vectorizer was dumped to disk by `joblib`.
        model_name (str): Name of a model on HuggingFace that generates word embeddings
            (default is 'facebook/fasttext-en-vectors').
        n_components (int): Number of SVD pseudo-topics the DTM is projected onto (default 300).
        concentration (float): Softmax inverse temperature used when spreading query terms over
            the vocabulary (default 30).

    Returns:
        callable: Function that compares the query string to the corpus.
    """

    # Load the fastText model from the HuggingFace Hub.
    fasttext_model = fasttext.load_model(hf_hub_download(model_name, "model.bin"))

    # Load the corpus DTM and the vectorizer that produced it.
    # NOTE: joblib.load unpickles arbitrary objects; only load vectorizer files you trust.
    dtm_df = pl.read_parquet(dtm_df_path)
    vectorizer = load(vectorizer_path)

    # Embed every vocabulary term.
    vocabulary = vectorizer.get_feature_names_out()
    vocab_embeddings = np.array([fasttext_model.get_word_vector(term) for term in vocabulary])

    # Keep only terms with a non-zero embedding: a zero vector cannot be normalised below.
    keep_terms = np.any(vocab_embeddings != 0, axis=1)
    vocab_embeddings = vocab_embeddings[keep_terms]

    # IDF weights of the kept terms as a (1, n_terms) row vector, so they broadcast per query term.
    idf = vectorizer.idf_[keep_terms].reshape(1, -1)

    # Scale each embedding to unit length, so dot products with it are cosine similarities.
    vocab_norm = vocab_embeddings / np.linalg.norm(vocab_embeddings, axis=1, keepdims=True)

    # Restrict the DTM to the kept terms and project it onto `n_components` pseudo-topics.
    # NOTE(review): this assumes the non-`file` parquet columns are in the vectorizer's
    # feature order, so the boolean mask lines up with them. Confirm against the DTM writer.
    filenames = dtm_df["file"].to_list()
    doc_term_mat = dtm_df.select(pl.exclude(["file"]))[:, keep_terms.tolist()]
    dtm_svd = TruncatedSVD(n_components=n_components)
    dtm_svd_mat = dtm_svd.fit_transform(doc_term_mat)

    return query_factory(
        rownames=filenames,
        fasttext_model=fasttext_model,
        idf=idf,
        dtm_svd=dtm_svd,
        dtm_svd_mat=dtm_svd_mat,
        vocab_norm=vocab_norm,
        concentration=concentration,
    )