File size: 2,472 Bytes
34a1c85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import pandas as pd
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch
import torch.nn.functional as F

# SPLADE checkpoint used to encode both the standards and the queries.
model_name = "naver/splade-cocondenser-ensembledistil"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForMaskedLM.from_pretrained(model_name)
model.eval()  # inference only: disables dropout so encodings are deterministic

# Standards table. Rows with a missing value in ANY column are dropped.
df = pd.read_csv('data/CCSS Common Core Standards(English Standards).csv')
df.dropna(inplace=True)

# Reset index to align doc IDs: after dropna the labels have gaps, and a
# 0..n-1 index makes each row label equal its position (the enumerate()
# position later used as the document id).
df = df.reset_index(drop=True)

class splade_utility:
    """Rank the standards corpus against a single query using SPLADE vectors.

    Relies on module-level globals defined elsewhere in this file:
    ``tokenizer`` and ``model`` (the SPLADE checkpoint), ``df`` (the
    standards table) and ``splade_doc_vectors`` (one sparse vector per
    ``df`` row, in row order).

    Args:
        query: Free-text query to score against every standard.
        top_n: Number of best-scoring standards to return.
    """

    def __init__(self, query, top_n=5):
        self.query = query
        self.top_n = top_n

    @staticmethod
    def _weights_to_sparse_dict(weights, ids_to_tokens):
        """Return ``{token: weight}`` for every nonzero entry of 1-D ``weights``.

        ``ids_to_tokens`` maps a list of vocabulary ids to the list of their
        token strings (e.g. ``tokenizer.convert_ids_to_tokens``).
        """
        # flatten() rather than squeeze(): with exactly one nonzero weight,
        # squeeze() collapses the [1, 1] index tensor to 0-d, which cannot be
        # iterated. flatten() always gives a 1-D tensor (possibly empty).
        nonzero_ids = torch.nonzero(weights).flatten()
        tokens = ids_to_tokens(nonzero_ids.tolist())
        return dict(zip(tokens, weights[nonzero_ids].tolist()))

    @staticmethod
    def get_splade_sparse_vector(text):
        """Encode ``text`` as a SPLADE sparse vector ``{token: weight}``.

        A term's weight is the max over input positions of
        ``log(1 + relu(logit))``; only terms with a nonzero weight are kept.
        Input longer than 512 tokens is truncated.
        """
        with torch.no_grad():
            inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
            logits = model(**inputs).logits.squeeze(0)  # [seq_len, vocab_size]
            splade_weights = torch.log1p(F.relu(logits)).max(dim=0).values
        return splade_utility._weights_to_sparse_dict(
            splade_weights, tokenizer.convert_ids_to_tokens
        )

    def dot_product_sparse(self, query_vec, doc_vec):
        """Return the dot product of two sparse ``{term: weight}`` vectors.

        Query terms absent from ``doc_vec`` contribute zero.
        """
        return sum(weight * doc_vec.get(term, 0.0) for term, weight in query_vec.items())

    def retrieve_top_n_splade(self):
        """Return the ``top_n`` standards that score highest against ``self.query``.

        Returns:
            A list of dicts, best first, with keys ``score`` (rounded to 4
            decimals), ``standard``, ``ID``, ``Category`` and ``Sub Category``.
            Equal scores are ordered by descending row position.
        """
        query_vec = self.get_splade_sparse_vector(self.query)
        scores = [
            (self.dot_product_sparse(query_vec, doc_vec), idx)
            for idx, doc_vec in enumerate(splade_doc_vectors)
        ]

        top_matches = sorted(scores, reverse=True)[:self.top_n]

        results = []
        for score, idx in top_matches:
            # Positional lookup: idx is the position in splade_doc_vectors,
            # which was built from df in row order.
            row = df.iloc[idx]
            results.append({
                "score": round(score, 4),
                "standard": row["State Standard"],
                "ID": row["ID"],
                "Category": row["Category"],
                "Sub Category": row["Sub Category"]
            })
        return results

# Document ids are positions: splade_doc_vectors[i] must describe df.iloc[i],
# so the index is normalized to 0..n-1 before the vectors are built.
df = df.reset_index(drop=True)

# Text of every standard, in df row order.
standard_texts = df["State Standard"].astype(str).tolist()

# One SPLADE sparse vector per standard. splade_utility.retrieve_top_n_splade
# reads this as a module-level global, so it must exist on import.
splade_doc_vectors = [splade_utility.get_splade_sparse_vector(text) for text in standard_texts]


# Example usage: run only as a script, so importing this module does not
# execute and print a demo query.
if __name__ == "__main__":
    query = "determine main idea text explain supported key detail summarize text"
    splade_instance = splade_utility(query)
    results = splade_instance.retrieve_top_n_splade()
    print(results)