# CCSS_Alignment / core / splade_utility.py
# Uploaded by meruem123 ("Upload 15 files", commit 34a1c85).
import pandas as pd
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch
import torch.nn.functional as F
# SPLADE encoder used by splade_utility.get_splade_sparse_vector (read as
# module-level globals).
model_name = "naver/splade-cocondenser-ensembledistil"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# nn.Module.eval() returns the module itself, so the call can be chained;
# it switches the model to evaluation mode for inference.
model = AutoModelForMaskedLM.from_pretrained(model_name).eval()

# Standards table. Any row with a missing value in any column is discarded.
# NOTE: the CSV path is relative to the current working directory.
df = pd.read_csv('data/CCSS Common Core Standards(English Standards).csv').dropna()
# The index is renumbered later in this file, once the class is defined, so
# that row labels match the positional document ids used for retrieval.
class splade_utility:
    """Rank CCSS standards against a free-text query with SPLADE sparse vectors.

    Reads module-level globals defined elsewhere in this file:

    * ``tokenizer`` and ``model``: the SPLADE tokenizer and masked-LM encoder.
    * ``df``: the standards table; results are looked up by row position.
    * ``splade_doc_vectors``: one sparse vector per row of ``df``, in
      positional order.
    """

    def __init__(self, query, top_n=5):
        """Store the query and the number of matches to return.

        Args:
            query: Free-text query to encode and match against the standards.
            top_n: Maximum number of results returned by
                ``retrieve_top_n_splade``.
        """
        self.query = query
        self.top_n = top_n

    @staticmethod
    def _nonzero_token_ids(splade_weights):
        """Return a 1-d LongTensor of the positions where ``splade_weights`` != 0.

        ``torch.nonzero(w).squeeze()`` collapses to a 0-d tensor when exactly
        one entry is non-zero, and a 0-d tensor cannot be iterated. The first
        element of the ``as_tuple=True`` result is always 1-d, including for
        zero or one hits.
        """
        return torch.nonzero(splade_weights, as_tuple=True)[0]

    @staticmethod
    def get_splade_sparse_vector(text):
        """Encode ``text`` as a SPLADE sparse vector.

        Args:
            text: Input string. It is tokenized with truncation to at most
                512 tokens.

        Returns:
            dict mapping each vocabulary token with a non-zero weight to that
            weight (a Python float). The weight of a token is the maximum over
            sequence positions of ``log(1 + relu(logit))``.
        """
        with torch.no_grad():
            inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
            logits = model(**inputs).logits.squeeze(0)  # [seq_len, vocab_size]
            # Max-pool over the sequence: one weight per vocabulary entry.
            splade_weights = torch.log1p(F.relu(logits)).max(dim=0).values
        token_ids = splade_utility._nonzero_token_ids(splade_weights)
        # Convert all ids and gather all weights in one call each, instead of
        # one tokenizer call and one .item() per id.
        tokens = tokenizer.convert_ids_to_tokens(token_ids.tolist())
        weights = splade_weights[token_ids].tolist()
        return dict(zip(tokens, weights))

    def dot_product_sparse(self, query_vec, doc_vec):
        """Return the dot product of two ``{token: weight}`` sparse vectors.

        Only terms in ``query_vec`` are visited; terms missing from ``doc_vec``
        contribute 0. An empty ``query_vec`` yields ``0``.
        """
        return sum(weight * doc_vec.get(term, 0.0) for term, weight in query_vec.items())

    def retrieve_top_n_splade(self):
        """Return the ``top_n`` standards that score highest against ``self.query``.

        Returns:
            list of dicts, highest score first, with keys ``"score"`` (rounded
            to 4 decimals), ``"standard"``, ``"ID"``, ``"Category"`` and
            ``"Sub Category"``. Each dict is taken from the row of ``df`` at
            the same position as the matching entry of ``splade_doc_vectors``.
            Equal scores are ordered by descending row position.
        """
        query_vec = self.get_splade_sparse_vector(self.query)
        scores = [
            (self.dot_product_sparse(query_vec, doc_vec), idx)
            for idx, doc_vec in enumerate(splade_doc_vectors)
        ]
        results = []
        for score, idx in sorted(scores, reverse=True)[:self.top_n]:
            row = df.iloc[idx]  # positional lookup, matching enumerate() above
            results.append({
                "score": round(score, 4),
                "standard": row["State Standard"],
                "ID": row["ID"],
                "Category": row["Category"],
                "Sub Category": row["Sub Category"],
            })
        return results
# Renumber the index 0..n-1 so that row labels coincide with the positional
# document ids (enumerate order) used by splade_utility.retrieve_top_n_splade.
df.reset_index(drop=True, inplace=True)

# One sparse vector per standard, in the same order as the rows of `df`.
standard_texts = df["State Standard"].astype(str).tolist()
splade_doc_vectors = list(map(splade_utility.get_splade_sparse_vector, standard_texts))

# Example query. NOTE(review): this runs, and prints, whenever the module is
# imported; consider moving it under an `if __name__ == "__main__":` guard.
query = "determine main idea text explain supported key detail summarize text"
splade_instance = splade_utility(query)
results = splade_instance.retrieve_top_n_splade()
print(results)