meruem123 committed on
Commit
34a1c85
·
verified ·
1 Parent(s): 703a7c0

Upload 15 files

Browse files
Readme.md ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 📚 CCSS Alignment with BM25 & SPLADE
2
+
3
+ This project allows you to align input educational text (lesson plans, learning objectives) with Common Core State Standards (ELA) using two retrieval techniques:
4
+
5
+ - **BM25** (sparse lexical search)
6
+ - **SPLADE** (sparse transformer embeddings)
7
+
8
+ ## 🚀 How to Run the App
9
+
10
+ Make sure you're in the project root folder, then run:
11
+
12
+ ```bash
13
+ streamlit run app.py
14
+ ```
15
+
16
+ You will be able to:
17
+ - Select either BM25 or SPLADE
18
+ - Input a query (e.g., "identify key ideas and details")
19
+ - View top-matching CCSS standards
20
+ - Compare accuracy between both retrieval models
21
+
22
+ ## 🧪 Sample Starter Code for app.py
23
+
24
+ ```python
25
+ import streamlit as st
26
+ from core.bm25_utility import bm25_utility
27
+ from core.splade_utility import splade_utility
28
+
29
+ query = st.text_input("Enter your query:")
30
+ method = st.selectbox("Choose retrieval method", ["BM25", "SPLADE"])
31
+
32
+ if st.button("Get Standards"):
33
+ if method == "BM25":
34
+ results = bm25_utility(query).retrieve_top_n_bm25()
35
+ else:
36
+ results = splade_utility(query).retrieve_top_n_splade()
37
+
38
+ for r in results:
39
+ st.write(f"**{r['ID']}** - {r['standard']} (Score: {r['score']})")
40
+ ```
41
+
42
+
43
+ ## 📝 Notes
44
+
45
+ - Ensure that model weights for SPLADE are downloaded or cached.
46
+ - Make sure you're using cleaned and preprocessed CCSS data for accurate matching.
47
+ - Streamlit interface supports rapid switching between BM25 and SPLADE for testing.
48
+
49
+ ---
50
+
51
+ **Author**: Shivendra Gupta
52
+ **Purpose**: Educational NLP for aligning teaching content to learning standards.
app.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import json
4
+ import matplotlib.pyplot as plt
5
+
6
+ from core.splade_utility import splade_utility
7
+ from core.bm25_utility import bm25_utility
8
+
9
+
10
+
11
+ # ==== Import your models and functions ====
12
+ # Assume these are already defined in your notebook/script
13
+ # from your_code import retrieve_top_n_bm25, retrieve_top_n_splade, evaluate_top1_on_state_standard
14
+
15
# Hard-coded Top-1 accuracy figures plotted in the accuracy bar chart.
# Dummy accuracy values (replace with real ones from your code)
bm25_accuracy = 0.9959
splade_accuracy = 0.9797
18
+
19
+ # Thin wrappers around the BM25 and SPLADE retrieval utilities
20
def retrieve_top_n_bm25(query, top_n=5):
    """Return the best BM25 matches for ``query``.

    Args:
        query: Free-form lesson or objective text to align.
        top_n: Maximum number of results to return.

    Returns:
        The list of result dicts produced by
        ``bm25_utility.retrieve_top_n_bm25``.
    """
    # Forward the caller's top_n; it was previously pinned to 5 no matter
    # what the caller asked for.
    return bm25_utility(query, top_n=top_n).retrieve_top_n_bm25()
24
+
25
def retrieve_top_n_splade(query, top_n=5):
    """Run SPLADE retrieval for ``query`` and return up to ``top_n`` results."""
    retriever = splade_utility(query, top_n=top_n)
    matches = retriever.retrieve_top_n_splade()
    return matches
28
+
29
+ # ==== Streamlit UI ====
30
+
31
st.set_page_config(page_title="CCSS Alignment", layout="centered")
st.title("📚 CCSS Alignment Search")

# Select model
model_choice = st.radio("Select Retrieval Model:", ["BM25", "SPLADE"])

# Accuracy bar chart
st.subheader("🎯 Model Top-1 Accuracy")
fig, ax = plt.subplots()
ax.bar(["BM25", "SPLADE"], [bm25_accuracy, splade_accuracy], color=["skyblue", "lightgreen"])
# Zoom the y-axis so the small gap between the two accuracies is visible
ax.set_ylim([0.9, 1.01])
ax.set_ylabel("Top-1 Accuracy")
for i, acc in enumerate([bm25_accuracy, splade_accuracy]):
    # Print the exact value just above each bar
    ax.text(i, acc + 0.001, f"{acc:.4f}", ha='center', fontsize=10)
st.pyplot(fig)
# This script runs top to bottom each time it is executed, creating a new
# figure every run; close it once rendered so pyplot does not keep
# accumulating open figures.
plt.close(fig)

# Query input
st.subheader("🔍 Try a Query")
query = st.text_area("Enter a lesson or objective text:", height=100)

# Search button
if st.button("Search"):
    if not query.strip():
        # An empty query would still be ranked and displayed; ask for input
        # instead of showing meaningless matches.
        st.warning("Please enter some text before searching.")
    else:
        st.subheader("📄 Top Results")

        if model_choice == "BM25":
            results = retrieve_top_n_bm25(query, top_n=5)
        else:
            results = retrieve_top_n_splade(query, top_n=5)

        if results:
            for i, r in enumerate(results, 1):
                st.markdown(f"""
                **Rank {i}**
                - **Standard**: {r['standard']}
                - **ID**: {r.get('ID', 'N/A')}
                - **Category**: {r.get('Category', 'N/A')}
                - **Sub Category**: {r.get('Sub Category', 'N/A')}
                - **Score**: `{r['score']}`
                """)
        else:
            st.warning("No results found.")
core/__init__.py ADDED
File without changes
core/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (161 Bytes). View file
 
core/__pycache__/bm25_utility.cpython-312.pyc ADDED
Binary file (2.77 kB). View file
 
core/__pycache__/preprocessing_pipeline.cpython-312.pyc ADDED
Binary file (3.25 kB). View file
 
core/__pycache__/splade_utility.cpython-312.pyc ADDED
Binary file (4.6 kB). View file
 
core/bm25_utility.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from core.preprocessing_pipeline import preprocessing_pipeline
3
+ from rank_bm25 import BM25Okapi
4
+ import pandas as pd
5
+
6
# Load the standards table and discard rows with any missing value
df = pd.read_csv('data/CCSS Common Core Standards(English Standards).csv')
df.dropna(inplace=True)


def _preprocess_standard(raw_text):
    """Run one standard's text through the shared preprocessing pipeline."""
    return preprocessing_pipeline(raw_text).preprocess()


# The column is overwritten with its preprocessed form
df['State Standard'] = df['State Standard'].apply(_preprocess_standard)

# BM25Okapi takes each document as a list of tokens
tokenized_docs = [standard.lower().split() for standard in df['State Standard']]
bm25 = BM25Okapi(tokenized_docs)
14
+
15
+
16
class bm25_utility:
    """Rank the loaded standards against a piece of text using BM25."""

    def __init__(self,text,top_n=5):
        self.text = text
        self.top_n = top_n

    def retrieve_top_n_bm25(self):
        """Return up to ``top_n`` result dicts, highest BM25 score first.

        Each dict carries the keys ``ID``, ``Category``, ``Sub Category``,
        ``standard`` and ``score`` (rounded to four decimals).
        """
        # The query goes through the same preprocessing as the indexed documents
        query_tokens = preprocessing_pipeline(self.text).preprocess().split()
        scores = bm25.get_scores(query_tokens)

        # Positions of all documents, best score first, cut to top_n
        ranked = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)
        best_positions = ranked[:self.top_n]

        def as_result(position):
            record = df.iloc[position]
            return {
                "ID": record["ID"],
                "Category": record["Category"],
                "Sub Category": record["Sub Category"],
                "standard": record["State Standard"],
                "score": round(scores[position], 4),
            }

        return [as_result(position) for position in best_positions]
45
+
46
if __name__ == "__main__":
    # Manual smoke test. Guarded so that importing this module (as app.py
    # does) no longer runs a query and prints its results as a side effect.
    query = "Identify the main idea of a text"
    bm25_utility_instance = bm25_utility(query, top_n=5)
    top_n_results = bm25_utility_instance.retrieve_top_n_bm25()
    print(top_n_results)
core/main.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from preprocessing_pipeline import preprocessing_pipeline
2
+
3
# Quick manual check: push one sample sentence through the pipeline and print it
sample_text = "i am a student and i am learning how to code. hi, how are you? and what are you doing?"
print(preprocessing_pipeline(sample_text).preprocess())
core/preprocessing_pipeline.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import pandas as pd
3
+ import re
4
+ from nltk.tokenize import word_tokenize
5
+ from nltk.corpus import stopwords
6
+ from nltk.stem import WordNetLemmatizer
7
# English stop words, built once at import time as a set so that
# remove_stopwords can use a membership test per token.
stop_words = set(stopwords.words('english'))
8
+
9
class preprocessing_pipeline:
    """Normalise text: clean, lowercase, strip punctuation, drop stop words, lemmatize."""

    # Single-character substitutions applied by clean_text
    _CHAR_MAP = str.maketrans({
        "\n": " ",
        "\xa0": " ",
        "“": "\"",
        "”": "\"",
        "–": "-",
    })

    def __init__(self,text):
        self.text = text

    def preprocess(self):
        """Apply every step in order, storing each result in ``self.text``, and return it."""
        steps = (
            self.clean_text,
            self.lowercase,
            self.remove_punctuation,
            self.remove_stopwords,
            self.lemmatize_tokens,
        )
        for step in steps:
            self.text = step(self.text)
        return self.text

    def clean_text(self , text: str) -> str:
        """Trim the ends, then map newlines/non-breaking spaces to spaces and curly quotes/en dashes to ASCII."""
        return text.strip().translate(self._CHAR_MAP)

    def lowercase(self, text: str) -> str:
        """Return ``text`` in lower case."""
        lowered = text.lower()
        return lowered

    def remove_punctuation(self, text: str) -> str:
        """Delete every character that is neither a word character nor whitespace."""
        return re.compile(r"[^\w\s]").sub("", text)

    def remove_stopwords(self, text: str) -> str:
        """Tokenize ``text`` and re-join the tokens that are not in ``stop_words``."""
        kept = [token for token in word_tokenize(text) if token not in stop_words]
        return ' '.join(kept)

    def lemmatize_tokens(self, text: str) -> str:
        """Tokenize ``text`` and re-join the lemmatized tokens with single spaces."""
        tokens = word_tokenize(text)
        lemmatize = WordNetLemmatizer().lemmatize
        return ' '.join(map(lemmatize, tokens))
41
+
42
+
core/splade_utility.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ from transformers import AutoTokenizer, AutoModelForMaskedLM
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
# Checkpoint name shared by the tokenizer and the masked-LM model below
model_name = "naver/splade-cocondenser-ensembledistil"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForMaskedLM.from_pretrained(model_name)
# Put the model in evaluation mode; it is only used for inference here
model.eval()
# Load the standards table and drop rows with any missing value
df = pd.read_csv('data/CCSS Common Core Standards(English Standards).csv')
df.dropna(inplace=True)
12
+
13
+ # NOTE: the index is reset after the class definition below (df.reset_index)
14
+
15
class splade_utility:
    """Rank the loaded standards against a query using SPLADE sparse vectors."""

    def __init__(self, query, top_n=5):
        self.query = query
        self.top_n = top_n

    @staticmethod
    def get_splade_sparse_vector(text):
        """Encode ``text`` as a sparse ``{token: weight}`` dict.

        For each vocabulary term the weight is the maximum over sequence
        positions of ``log(1 + relu(logit))``; only terms with a non-zero
        weight are kept.
        """
        with torch.no_grad():
            inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
            logits = model(**inputs).logits.squeeze(0)  # [seq_len, vocab_size]
            splade_weights = torch.log1p(F.relu(logits)).max(dim=0).values
            # squeeze(-1) rather than a bare squeeze(): with exactly one active
            # term a bare squeeze() yields a 0-d tensor, which cannot be iterated.
            indices = torch.nonzero(splade_weights).squeeze(-1)
            token_ids = indices.tolist()
            weights = splade_weights[indices].tolist()
        # One batched id -> token lookup instead of one tokenizer call per term
        tokens = tokenizer.convert_ids_to_tokens(token_ids)
        return dict(zip(tokens, weights))

    def dot_product_sparse(self , query_vec, doc_vec):
        """Return the dot product of two sparse ``{term: weight}`` dicts.

        Only terms present in ``query_vec`` are visited; terms missing from
        ``doc_vec`` contribute 0.
        """
        return sum(weight * doc_vec.get(term, 0.0) for term, weight in query_vec.items())

    def retrieve_top_n_splade(self):
        """Return up to ``top_n`` result dicts, highest SPLADE score first.

        Each dict carries the keys ``score`` (rounded to four decimals),
        ``standard``, ``ID``, ``Category`` and ``Sub Category``.
        """
        query_vec = self.get_splade_sparse_vector(self.query)
        scores = [
            (self.dot_product_sparse(query_vec, doc_vec), idx)
            for idx, doc_vec in enumerate(splade_doc_vectors)
        ]

        # Tuples sort by score first, then by document position
        top_matches = sorted(scores, reverse=True)[:self.top_n]

        results = []
        for score, idx in top_matches:
            row = df.iloc[idx]  # fetch the row once instead of once per field
            results.append({
                "score": round(score, 4),
                "standard": row["State Standard"],
                "ID": row["ID"],
                "Category": row["Category"],
                "Sub Category": row["Sub Category"],
            })
        return results
55
+
56
# Renumber the rows 0..n-1 now that incomplete rows have been dropped
df = df.reset_index(drop=True)

# Encode every standard once at import time; splade_doc_vectors[i] is the
# sparse vector for the row at position i of df
standard_texts = df["State Standard"].astype(str).tolist()
splade_doc_vectors = list(map(splade_utility.get_splade_sparse_vector, standard_texts))
63
+
64
+
65
if __name__ == "__main__":
    # Example usage. Guarded so that importing this module (as app.py does)
    # no longer runs a query and prints its results as a side effect.
    query = "determine main idea text explain supported key detail summarize text"
    splade_instance = splade_utility(query)
    results = splade_instance.retrieve_top_n_splade()
    print(results)
data/CCSS Common Core Standards(English Standards).csv ADDED
The diff for this file is too large to render. See raw diff
 
data/data.csv ADDED
The diff for this file is too large to render. See raw diff
 
notebook/ccss_standard_mapper.ipynb ADDED
@@ -0,0 +1,1120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "e2192e55",
6
+ "metadata": {},
7
+ "source": [
8
+ "# Project: CCSS Standard Alignment using BM25 and SPLADE\n",
9
+ "\n",
10
+ "---\n",
11
+ "\n",
12
+ "## Background\n",
13
+ "\n",
14
+ "### BM25 (Best Matching 25)\n",
15
+ "\n",
16
+ "BM25 is a **traditional lexical retrieval model** used in information retrieval systems (like search engines). It ranks documents based on the **term frequency–inverse document frequency (TF-IDF)** concept, with additional normalization for document length.\n",
17
+ "\n",
18
+ "**Core Characteristics:**\n",
19
+ "- Lexical-only: matches exact words (not synonyms/paraphrases)\n",
20
+ "- Scores documents using a tunable function of:\n",
21
+ " - **Term frequency (TF)** – how often a query term appears in the doc\n",
22
+ " - **Inverse Document Frequency (IDF)** – how rare the term is overall\n",
23
+ " - **Document length normalization**\n",
24
+ "- Fast and interpretable\n",
25
+ "\n",
26
+ "**Strengths:**\n",
27
+ "- Simple and fast\n",
28
+ "- Strong for keyword-heavy queries\n",
29
+ "- Works well on small datasets\n",
30
+ "\n",
31
+ "**Limitations:**\n",
32
+ "- Cannot understand synonyms, rephrasing, or context\n",
33
+ "\n",
34
+ "---\n",
35
+ "\n",
36
+ "### SPLADE (Sparse Lexical and Expansion Model)\n",
37
+ "\n",
38
+ "SPLADE is a **neural sparse retriever** that combines the **interpretability of sparse vectors** with the **semantic power of transformers (like BERT)**.\n",
39
+ "\n",
40
+ "**How it works:**\n",
41
+ "- Instead of dense embeddings (like BERT or SBERT), SPLADE generates **sparse term-weighted vectors**\n",
42
+ "- These vectors can:\n",
43
+ " - Activate terms **not explicitly in the query** (semantic expansion)\n",
44
+ " - Assign importance scores to vocabulary terms\n",
45
+ "- Supports use of **inverted indexes** like BM25, but with neural knowledge\n",
46
+ "\n",
47
+ "**Strengths:**\n",
48
+ "- Captures paraphrasing and synonyms\n",
49
+ "- Sparse and interpretable\n",
50
+ "- Works better on natural language queries\n",
51
+ "\n",
52
+ "**Limitations:**\n",
53
+ "- Slower than BM25\n",
54
+ "- Requires GPU for efficient inference\n",
55
+ "\n",
56
+ "---\n",
57
+ "\n",
58
+ "## Project Overview\n",
59
+ "\n",
60
+ "### Goal:\n",
61
+ "\n",
62
+ "Build a system that **automatically aligns educational content (e.g., lesson descriptions, learning objectives)** to the most relevant **Common Core State Standards (CCSS)** for English Language Arts (ELA).\n",
63
+ "\n",
64
+ "---\n",
65
+ "\n",
66
+ "### Approach:\n",
67
+ "\n",
68
+ "We implement and compare **two retrieval pipelines**:\n",
69
+ "\n",
70
+ "| Component | Pipeline 1 | Pipeline 2 |\n",
71
+ "|---------------|----------------------|------------------------|\n",
72
+ "| Model | BM25 | SPLADE |\n",
73
+ "| Representation | Token frequency | Sparse transformer weights |\n",
74
+ "| Input | Free-form text | Free-form text |\n",
75
+ "| Output | Top-N most relevant CCSS standards with scores | Top-N most relevant CCSS standards with scores |\n",
76
+ "\n",
77
+ "---\n",
78
+ "\n",
79
+ "### Dataset:\n",
80
+ "\n",
81
+ "- Source: `CCSS Common Core Standards.xlsx`\n",
82
+ "- Focus: Only **ELA standards**\n",
83
+ "- Fields used: `ID`, `Sub Category`, `State Standard`\n",
84
+ "\n",
85
+ "---\n",
86
+ "\n",
87
+ "### Output Format:\n",
88
+ "\n",
89
+ "Each pipeline returns a list of matches:\n",
90
+ "```json\n",
91
+ "[\n",
92
+ " {\n",
93
+ " \"rank\": 1,\n",
94
+ " \"score\": 10.87,\n",
95
+ " \"ID\": \"4.RI.2\",\n",
96
+ " \"Category\": \"Reading Informational\",\n",
97
+ " \"Sub Category\": \"Key Ideas and Details\",\n",
98
+ " \"State Standard\": \"Determine the main idea of a text...\"\n",
99
+ " },\n",
100
+ " ...\n",
101
+ "]\n```"
102
+ ]
103
+ },
104
+ {
105
+ "cell_type": "code",
106
+ "execution_count": 28,
107
+ "id": "cfa8b1b6",
108
+ "metadata": {},
109
+ "outputs": [],
110
+ "source": [
111
+ "import pandas as pd\n",
112
+ "import re\n",
113
+ "from nltk.tokenize import word_tokenize\n",
114
+ "from nltk.corpus import stopwords\n",
115
+ "from nltk.stem import WordNetLemmatizer"
116
+ ]
117
+ },
118
+ {
119
+ "cell_type": "code",
120
+ "execution_count": 46,
121
+ "id": "748918e3",
122
+ "metadata": {},
123
+ "outputs": [],
124
+ "source": [
125
+ "stop_words = set(stopwords.words('english'))"
126
+ ]
127
+ },
128
+ {
129
+ "cell_type": "code",
130
+ "execution_count": 62,
131
+ "id": "3cf17d78",
132
+ "metadata": {},
133
+ "outputs": [],
134
+ "source": [
135
+ "df = pd.read_csv('/Users/shivendragupta/Desktop/internship25/CCSS/data/CCSS Common Core Standards(English Standards).csv')"
136
+ ]
137
+ },
138
+ {
139
+ "cell_type": "code",
140
+ "execution_count": 63,
141
+ "id": "ee2b47e5",
142
+ "metadata": {},
143
+ "outputs": [
144
+ {
145
+ "data": {
146
+ "text/html": [
147
+ "<div>\n",
148
+ "<style scoped>\n",
149
+ " .dataframe tbody tr th:only-of-type {\n",
150
+ " vertical-align: middle;\n",
151
+ " }\n",
152
+ "\n",
153
+ " .dataframe tbody tr th {\n",
154
+ " vertical-align: top;\n",
155
+ " }\n",
156
+ "\n",
157
+ " .dataframe thead th {\n",
158
+ " text-align: right;\n",
159
+ " }\n",
160
+ "</style>\n",
161
+ "<table border=\"1\" class=\"dataframe\">\n",
162
+ " <thead>\n",
163
+ " <tr style=\"text-align: right;\">\n",
164
+ " <th></th>\n",
165
+ " <th>ID</th>\n",
166
+ " <th>Category</th>\n",
167
+ " <th>Sub Category</th>\n",
168
+ " <th>State Standard</th>\n",
169
+ " </tr>\n",
170
+ " </thead>\n",
171
+ " <tbody>\n",
172
+ " <tr>\n",
173
+ " <th>0</th>\n",
174
+ " <td>K.RL.1</td>\n",
175
+ " <td>Reading Literature</td>\n",
176
+ " <td>Key Ideas and Details</td>\n",
177
+ " <td>With prompting and support, ask and answer que...</td>\n",
178
+ " </tr>\n",
179
+ " <tr>\n",
180
+ " <th>1</th>\n",
181
+ " <td>K.RL.2</td>\n",
182
+ " <td>Reading Literature</td>\n",
183
+ " <td>Key Ideas and Details</td>\n",
184
+ " <td>With prompting and support, retell familiar st...</td>\n",
185
+ " </tr>\n",
186
+ " <tr>\n",
187
+ " <th>2</th>\n",
188
+ " <td>K.RL.3</td>\n",
189
+ " <td>Reading Literature</td>\n",
190
+ " <td>Key Ideas and Details</td>\n",
191
+ " <td>With prompting and support, identify character...</td>\n",
192
+ " </tr>\n",
193
+ " <tr>\n",
194
+ " <th>3</th>\n",
195
+ " <td>K.RL.4</td>\n",
196
+ " <td>Reading Literature</td>\n",
197
+ " <td>Craft and Structure</td>\n",
198
+ " <td>Ask and answer questions about unknown words i...</td>\n",
199
+ " </tr>\n",
200
+ " <tr>\n",
201
+ " <th>4</th>\n",
202
+ " <td>K.RL.5</td>\n",
203
+ " <td>Reading Literature</td>\n",
204
+ " <td>Craft and Structure</td>\n",
205
+ " <td>Recognize common types of texts (e.g., storybo...</td>\n",
206
+ " </tr>\n",
207
+ " </tbody>\n",
208
+ "</table>\n",
209
+ "</div>"
210
+ ],
211
+ "text/plain": [
212
+ " ID Category Sub Category \\\n",
213
+ "0 K.RL.1 Reading Literature Key Ideas and Details \n",
214
+ "1 K.RL.2 Reading Literature Key Ideas and Details \n",
215
+ "2 K.RL.3 Reading Literature Key Ideas and Details \n",
216
+ "3 K.RL.4 Reading Literature Craft and Structure \n",
217
+ "4 K.RL.5 Reading Literature Craft and Structure \n",
218
+ "\n",
219
+ " State Standard \n",
220
+ "0 With prompting and support, ask and answer que... \n",
221
+ "1 With prompting and support, retell familiar st... \n",
222
+ "2 With prompting and support, identify character... \n",
223
+ "3 Ask and answer questions about unknown words i... \n",
224
+ "4 Recognize common types of texts (e.g., storybo... "
225
+ ]
226
+ },
227
+ "execution_count": 63,
228
+ "metadata": {},
229
+ "output_type": "execute_result"
230
+ }
231
+ ],
232
+ "source": [
233
+ "df.head() # Display the first few rows of the DataFrame"
234
+ ]
235
+ },
236
+ {
237
+ "cell_type": "code",
238
+ "execution_count": 64,
239
+ "id": "3958653b",
240
+ "metadata": {},
241
+ "outputs": [
242
+ {
243
+ "data": {
244
+ "text/plain": [
245
+ "(1486, 4)"
246
+ ]
247
+ },
248
+ "execution_count": 64,
249
+ "metadata": {},
250
+ "output_type": "execute_result"
251
+ }
252
+ ],
253
+ "source": [
254
+ "df.shape"
255
+ ]
256
+ },
257
+ {
258
+ "cell_type": "code",
259
+ "execution_count": 65,
260
+ "id": "0e747290",
261
+ "metadata": {},
262
+ "outputs": [
263
+ {
264
+ "data": {
265
+ "text/plain": [
266
+ "ID 501\n",
267
+ "Category 501\n",
268
+ "Sub Category 501\n",
269
+ "State Standard 501\n",
270
+ "dtype: int64"
271
+ ]
272
+ },
273
+ "execution_count": 65,
274
+ "metadata": {},
275
+ "output_type": "execute_result"
276
+ }
277
+ ],
278
+ "source": [
279
+ "df.isna().sum()"
280
+ ]
281
+ },
282
+ {
283
+ "cell_type": "code",
284
+ "execution_count": 66,
285
+ "id": "34001c04",
286
+ "metadata": {},
287
+ "outputs": [],
288
+ "source": [
289
+ "df.dropna(inplace=True) # Drop rows with any NaN values"
290
+ ]
291
+ },
292
+ {
293
+ "cell_type": "markdown",
294
+ "id": "5e02750b",
295
+ "metadata": {},
296
+ "source": [
297
+ "# ```Preprocessing data```"
298
+ ]
299
+ },
300
+ {
301
+ "cell_type": "code",
302
+ "execution_count": 67,
303
+ "id": "506a332c",
304
+ "metadata": {},
305
+ "outputs": [],
306
+ "source": [
307
+ "def clean_text(text: str) -> str:\n",
308
+ " text = text.strip()\n",
309
+ " text = text.replace(\"\\n\", \" \").replace(\"\\xa0\", \" \")\n",
310
+ " text = text.replace(\"“\", \"\\\"\").replace(\"”\", \"\\\"\").replace(\"–\", \"-\")\n",
311
+ " return text"
312
+ ]
313
+ },
314
+ {
315
+ "cell_type": "markdown",
316
+ "id": "a97a9cfb",
317
+ "metadata": {},
318
+ "source": [
319
+ "## ```Lower Casing```"
320
+ ]
321
+ },
322
+ {
323
+ "cell_type": "code",
324
+ "execution_count": 38,
325
+ "id": "2f843d49",
326
+ "metadata": {},
327
+ "outputs": [],
328
+ "source": [
329
+ "def lowercase(text: str) -> str:\n",
330
+ " return text.lower()"
331
+ ]
332
+ },
333
+ {
334
+ "cell_type": "markdown",
335
+ "id": "4bc931f1",
336
+ "metadata": {},
337
+ "source": [
338
+ "## ```Removing Punctuation```"
339
+ ]
340
+ },
341
+ {
342
+ "cell_type": "code",
343
+ "execution_count": 39,
344
+ "id": "734d4b30",
345
+ "metadata": {},
346
+ "outputs": [],
347
+ "source": [
348
+ "def remove_punctuation(text: str) -> str:\n",
349
+ " return re.sub(r\"[^\\w\\s]\", \"\", text)"
350
+ ]
351
+ },
352
+ {
353
+ "cell_type": "markdown",
354
+ "id": "d12ab012",
355
+ "metadata": {},
356
+ "source": [
357
+ "## ``` Removing Stop Words ```"
358
+ ]
359
+ },
360
+ {
361
+ "cell_type": "code",
362
+ "execution_count": 49,
363
+ "id": "b925980c",
364
+ "metadata": {},
365
+ "outputs": [],
366
+ "source": [
367
+ "def remove_stopwords(text: str) -> str:\n",
368
+ " tokens = word_tokenize(text)\n",
369
+ " return ' '.join([word for word in tokens if word not in stop_words])"
370
+ ]
371
+ },
372
+ {
373
+ "cell_type": "markdown",
374
+ "id": "814c3818",
375
+ "metadata": {},
376
+ "source": [
377
+ "## ``` Lemmatization ```"
378
+ ]
379
+ },
380
+ {
381
+ "cell_type": "code",
382
+ "execution_count": 41,
383
+ "id": "70500704",
384
+ "metadata": {},
385
+ "outputs": [],
386
+ "source": [
387
+ "lemmatizer = WordNetLemmatizer()"
388
+ ]
389
+ },
390
+ {
391
+ "cell_type": "code",
392
+ "execution_count": 50,
393
+ "id": "7e287fc2",
394
+ "metadata": {},
395
+ "outputs": [],
396
+ "source": [
397
+ "def lemmatize_tokens(text: str) -> str:\n",
398
+ " tokens = word_tokenize(text)\n",
399
+ " return ' '.join([lemmatizer.lemmatize(token) for token in tokens])\n"
400
+ ]
401
+ },
402
+ {
403
+ "cell_type": "markdown",
404
+ "id": "39bd33a6",
405
+ "metadata": {},
406
+ "source": [
407
+ "## ``` PipeLine ```"
408
+ ]
409
+ },
410
+ {
411
+ "cell_type": "code",
412
+ "execution_count": 68,
413
+ "id": "b443ffec",
414
+ "metadata": {},
415
+ "outputs": [],
416
+ "source": [
417
+ "def preprocessing_pipeline(text: str) -> str:\n",
418
+ " text = clean_text(text)\n",
419
+ " text = lowercase(text)\n",
420
+ " text = remove_punctuation(text)\n",
421
+ " text = remove_stopwords(text)\n",
422
+ " text = lemmatize_tokens(text)\n",
423
+ " return text"
424
+ ]
425
+ },
426
+ {
427
+ "cell_type": "code",
428
+ "execution_count": 69,
429
+ "id": "13a7d65a",
430
+ "metadata": {},
431
+ "outputs": [],
432
+ "source": [
433
+ "df['State Standard'] = df['State Standard'].apply(preprocessing_pipeline)"
434
+ ]
435
+ },
436
+ {
437
+ "cell_type": "code",
438
+ "execution_count": 70,
439
+ "id": "f5a8cb5d",
440
+ "metadata": {},
441
+ "outputs": [
442
+ {
443
+ "data": {
444
+ "text/plain": [
445
+ "0 prompting support ask answer question key deta...\n",
446
+ "1 prompting support retell familiar story includ...\n",
447
+ "2 prompting support identify character setting m...\n",
448
+ "3 ask answer question unknown word text\n",
449
+ "4 recognize common type text eg storybook poem\n",
450
+ " ... \n",
451
+ "980 use technology including internet produce publ...\n",
452
+ "981 conduct short well sustained research project ...\n",
453
+ "982 gather relevant information multiple authorita...\n",
454
+ "983 draw evidence informational text support analy...\n",
455
+ "984 write routinely extended time frame time refle...\n",
456
+ "Name: State Standard, Length: 985, dtype: object"
457
+ ]
458
+ },
459
+ "execution_count": 70,
460
+ "metadata": {},
461
+ "output_type": "execute_result"
462
+ }
463
+ ],
464
+ "source": [
465
+ "df['State Standard']"
466
+ ]
467
+ },
468
+ {
469
+ "cell_type": "markdown",
470
+ "id": "ecb1369e",
471
+ "metadata": {},
472
+ "source": [
473
+ "## ``` BM25 Retriever Function ```"
474
+ ]
475
+ },
476
+ {
477
+ "cell_type": "code",
478
+ "execution_count": 73,
479
+ "id": "41e30b91",
480
+ "metadata": {},
481
+ "outputs": [],
482
+ "source": [
483
+ "from rank_bm25 import BM25Okapi"
484
+ ]
485
+ },
486
+ {
487
+ "cell_type": "code",
488
+ "execution_count": 74,
489
+ "id": "d34de1f1",
490
+ "metadata": {},
491
+ "outputs": [],
492
+ "source": [
493
+ "tokenized_docs = [doc.lower().split() for doc in df['State Standard']]"
494
+ ]
495
+ },
496
+ {
497
+ "cell_type": "code",
498
+ "execution_count": 77,
499
+ "id": "32594e27",
500
+ "metadata": {},
501
+ "outputs": [],
502
+ "source": [
503
+ "bm25 = BM25Okapi(tokenized_docs)"
504
+ ]
505
+ },
506
+ {
507
+ "cell_type": "code",
508
+ "execution_count": 155,
509
+ "id": "3d552a24",
510
+ "metadata": {},
511
+ "outputs": [],
512
+ "source": [
513
+ "def retrieve_top_n_bm25(query: str, top_n=5):\n",
514
+ " query_tokens = preprocessing_pipeline(query)\n",
515
+ " tokenized_query = query_tokens.split()\n",
516
+ " \n",
517
+ " scores = bm25.get_scores(tokenized_query)\n",
518
+ "\n",
519
+ " top_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:top_n]\n",
520
+ "\n",
521
+ "\n",
522
+ " # ID\tCategory\tSub Category\tState Standard\n",
523
+ "\n",
524
+ " results = []\n",
525
+ " for idx in top_indices:\n",
526
+ " row = df.iloc[idx]\n",
527
+ " results.append({\n",
528
+ " \"ID\": row[\"ID\"],\n",
529
+ " \"Category\": row[\"Category\"],\n",
530
+ " \"Sub Category\": row[\"Sub Category\"],\n",
531
+ " \"standard\": row[\"State Standard\"],\n",
532
+ " \"score\": round(scores[idx], 4)\n",
533
+ "\n",
534
+ " })\n",
535
+ " return results\n"
536
+ ]
537
+ },
538
+ {
539
+ "cell_type": "code",
540
+ "execution_count": 123,
541
+ "id": "5a18deb6",
542
+ "metadata": {},
543
+ "outputs": [],
544
+ "source": [
545
+ "query = \"Identify the main idea of a text\""
546
+ ]
547
+ },
548
+ {
549
+ "cell_type": "code",
550
+ "execution_count": 124,
551
+ "id": "11954a8a",
552
+ "metadata": {},
553
+ "outputs": [],
554
+ "source": [
555
+ "results = retrieve_top_n_bm25(query, top_n=5)"
556
+ ]
557
+ },
558
+ {
559
+ "cell_type": "markdown",
560
+ "id": "49538e11",
561
+ "metadata": {},
562
+ "source": [
563
+ "## ``` Top 5 Results from BM25 Retrieval ```"
564
+ ]
565
+ },
566
+ {
567
+ "cell_type": "code",
568
+ "execution_count": 125,
569
+ "id": "f7d12c4d",
570
+ "metadata": {},
571
+ "outputs": [
572
+ {
573
+ "data": {
574
+ "text/plain": [
575
+ "[{'ID': '1.RI.2',\n",
576
+ " 'Category': 'Reading Informational',\n",
577
+ " 'Sub Category': 'Key Ideas and Details',\n",
578
+ " 'State Standard': 'identify main topic retell key detail text',\n",
579
+ " 'score': 10.666},\n",
580
+ " {'ID': '3.RI.2',\n",
581
+ " 'Category': 'Reading Informational',\n",
582
+ " 'Sub Category': 'Key Ideas and Details',\n",
583
+ " 'State Standard': 'determine main idea text recount key detail explain support main idea',\n",
584
+ " 'score': 10.0953},\n",
585
+ " {'ID': 'K.RI.2',\n",
586
+ " 'Category': 'Reading Informational',\n",
587
+ " 'Sub Category': 'Key Ideas and Details',\n",
588
+ " 'State Standard': 'prompting support identify main topic retell key detail text',\n",
589
+ " 'score': 9.8043},\n",
590
+ " {'ID': '2.RI.6',\n",
591
+ " 'Category': 'Reading Informational',\n",
592
+ " 'Sub Category': 'Craft and Structure',\n",
593
+ " 'State Standard': 'identify main purpose text including author want answer explain describe',\n",
594
+ " 'score': 9.4236},\n",
595
+ " {'ID': '2.RI.2',\n",
596
+ " 'Category': 'Reading Informational',\n",
597
+ " 'Sub Category': 'Key Ideas and Details',\n",
598
+ " 'State Standard': 'identify main topic multiparagraph text well focus specific paragraph within text',\n",
599
+ " 'score': 9.3944}]"
600
+ ]
601
+ },
602
+ "execution_count": 125,
603
+ "metadata": {},
604
+ "output_type": "execute_result"
605
+ }
606
+ ],
607
+ "source": [
608
+ "results"
609
+ ]
610
+ },
611
+ {
612
+ "cell_type": "markdown",
613
+ "id": "1dd7ac6e",
614
+ "metadata": {},
615
+ "source": [
616
+ "## ``` Using Splade sparse retriever```"
617
+ ]
618
+ },
619
+ {
620
+ "cell_type": "code",
621
+ "execution_count": 126,
622
+ "id": "f8c3fee3",
623
+ "metadata": {},
624
+ "outputs": [
625
+ {
626
+ "data": {
627
+ "application/vnd.jupyter.widget-view+json": {
628
+ "model_id": "be42ffa9ef0949679ea06670a3436378",
629
+ "version_major": 2,
630
+ "version_minor": 0
631
+ },
632
+ "text/plain": [
633
+ "tokenizer_config.json: 0%| | 0.00/466 [00:00<?, ?B/s]"
634
+ ]
635
+ },
636
+ "metadata": {},
637
+ "output_type": "display_data"
638
+ },
639
+ {
640
+ "data": {
641
+ "application/vnd.jupyter.widget-view+json": {
642
+ "model_id": "e638218262224627be57d394b0bb8d07",
643
+ "version_major": 2,
644
+ "version_minor": 0
645
+ },
646
+ "text/plain": [
647
+ "vocab.txt: 0.00B [00:00, ?B/s]"
648
+ ]
649
+ },
650
+ "metadata": {},
651
+ "output_type": "display_data"
652
+ },
653
+ {
654
+ "data": {
655
+ "application/vnd.jupyter.widget-view+json": {
656
+ "model_id": "448331ce71bf4b98b07f0291f734fd97",
657
+ "version_major": 2,
658
+ "version_minor": 0
659
+ },
660
+ "text/plain": [
661
+ "tokenizer.json: 0.00B [00:00, ?B/s]"
662
+ ]
663
+ },
664
+ "metadata": {},
665
+ "output_type": "display_data"
666
+ },
667
+ {
668
+ "data": {
669
+ "application/vnd.jupyter.widget-view+json": {
670
+ "model_id": "0b4b7108d21c44ca96368b6b6f137002",
671
+ "version_major": 2,
672
+ "version_minor": 0
673
+ },
674
+ "text/plain": [
675
+ "special_tokens_map.json: 0%| | 0.00/112 [00:00<?, ?B/s]"
676
+ ]
677
+ },
678
+ "metadata": {},
679
+ "output_type": "display_data"
680
+ },
681
+ {
682
+ "data": {
683
+ "application/vnd.jupyter.widget-view+json": {
684
+ "model_id": "1f86b5b8152645bcb5318b15d57243d8",
685
+ "version_major": 2,
686
+ "version_minor": 0
687
+ },
688
+ "text/plain": [
689
+ "config.json: 0%| | 0.00/670 [00:00<?, ?B/s]"
690
+ ]
691
+ },
692
+ "metadata": {},
693
+ "output_type": "display_data"
694
+ },
695
+ {
696
+ "data": {
697
+ "application/vnd.jupyter.widget-view+json": {
698
+ "model_id": "d9e4d08d968e4a6a8bd93476a68e1f83",
699
+ "version_major": 2,
700
+ "version_minor": 0
701
+ },
702
+ "text/plain": [
703
+ "pytorch_model.bin: 0%| | 0.00/438M [00:00<?, ?B/s]"
704
+ ]
705
+ },
706
+ "metadata": {},
707
+ "output_type": "display_data"
708
+ },
709
+ {
710
+ "data": {
711
+ "text/plain": [
712
+ "BertForMaskedLM(\n",
713
+ " (bert): BertModel(\n",
714
+ " (embeddings): BertEmbeddings(\n",
715
+ " (word_embeddings): Embedding(30522, 768, padding_idx=0)\n",
716
+ " (position_embeddings): Embedding(512, 768)\n",
717
+ " (token_type_embeddings): Embedding(2, 768)\n",
718
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
719
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
720
+ " )\n",
721
+ " (encoder): BertEncoder(\n",
722
+ " (layer): ModuleList(\n",
723
+ " (0-11): 12 x BertLayer(\n",
724
+ " (attention): BertAttention(\n",
725
+ " (self): BertSdpaSelfAttention(\n",
726
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
727
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
728
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
729
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
730
+ " )\n",
731
+ " (output): BertSelfOutput(\n",
732
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
733
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
734
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
735
+ " )\n",
736
+ " )\n",
737
+ " (intermediate): BertIntermediate(\n",
738
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
739
+ " (intermediate_act_fn): GELUActivation()\n",
740
+ " )\n",
741
+ " (output): BertOutput(\n",
742
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
743
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
744
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
745
+ " )\n",
746
+ " )\n",
747
+ " )\n",
748
+ " )\n",
749
+ " )\n",
750
+ " (cls): BertOnlyMLMHead(\n",
751
+ " (predictions): BertLMPredictionHead(\n",
752
+ " (transform): BertPredictionHeadTransform(\n",
753
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
754
+ " (transform_act_fn): GELUActivation()\n",
755
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
756
+ " )\n",
757
+ " (decoder): Linear(in_features=768, out_features=30522, bias=True)\n",
758
+ " )\n",
759
+ " )\n",
760
+ ")"
761
+ ]
762
+ },
763
+ "execution_count": 126,
764
+ "metadata": {},
765
+ "output_type": "execute_result"
766
+ },
767
+ {
768
+ "data": {
769
+ "application/vnd.jupyter.widget-view+json": {
770
+ "model_id": "afa41fd273a147f08a2017dad5455866",
771
+ "version_major": 2,
772
+ "version_minor": 0
773
+ },
774
+ "text/plain": [
775
+ "model.safetensors: 0%| | 0.00/438M [00:00<?, ?B/s]"
776
+ ]
777
+ },
778
+ "metadata": {},
779
+ "output_type": "display_data"
780
+ }
781
+ ],
782
+ "source": [
783
+ "from transformers import AutoTokenizer, AutoModelForMaskedLM\n",
784
+ "import torch\n",
785
+ "import torch.nn.functional as F\n",
786
+ "\n",
787
+ "model_name = \"naver/splade-cocondenser-ensembledistil\"\n",
788
+ "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
789
+ "model = AutoModelForMaskedLM.from_pretrained(model_name)\n",
790
+ "model.eval()\n"
791
+ ]
792
+ },
793
+ {
794
+ "cell_type": "code",
795
+ "execution_count": 127,
796
+ "id": "0dcffefe",
797
+ "metadata": {},
798
+ "outputs": [],
799
+ "source": [
800
+ "def get_splade_sparse_vector(text):\n",
801
+ " with torch.no_grad():\n",
802
+ " inputs = tokenizer(text, return_tensors=\"pt\", truncation=True, max_length=512)\n",
803
+ " logits = model(**inputs).logits.squeeze(0) # [seq_len, vocab_size]\n",
804
+ " relu_out = F.relu(logits)\n",
805
+ " splade_weights = torch.log1p(relu_out).max(dim=0).values\n",
806
+ " indices = torch.nonzero(splade_weights).squeeze()\n",
807
+ " return {\n",
808
+ " tokenizer.convert_ids_to_tokens([i.item()])[0]: splade_weights[i].item()\n",
809
+ " for i in indices\n",
810
+ " }\n"
811
+ ]
812
+ },
813
+ {
814
+ "cell_type": "code",
815
+ "execution_count": 128,
816
+ "id": "0493ae98",
817
+ "metadata": {},
818
+ "outputs": [
819
+ {
820
+ "name": "stderr",
821
+ "output_type": "stream",
822
+ "text": [
823
+ "100%|██████████| 985/985 [00:39<00:00, 24.77it/s]\n"
824
+ ]
825
+ }
826
+ ],
827
+ "source": [
828
+ "from tqdm import tqdm\n",
829
+ "\n",
830
+ "# Reset index to align doc IDs\n",
831
+ "df = df.reset_index(drop=True)\n",
832
+ "\n",
833
+ "# Get list of standard texts\n",
834
+ "standard_texts = df[\"State Standard\"].astype(str).tolist()\n",
835
+ "\n",
836
+ "# Compute sparse vectors\n",
837
+ "splade_doc_vectors = [get_splade_sparse_vector(text) for text in tqdm(standard_texts)]\n"
838
+ ]
839
+ },
840
+ {
841
+ "cell_type": "code",
842
+ "execution_count": 129,
843
+ "id": "42c8b47b",
844
+ "metadata": {},
845
+ "outputs": [],
846
+ "source": [
847
+ "def dot_product_sparse(query_vec, doc_vec):\n",
848
+ " return sum(query_vec.get(term, 0.0) * doc_vec.get(term, 0.0) for term in query_vec)\n"
849
+ ]
850
+ },
851
+ {
852
+ "cell_type": "code",
853
+ "execution_count": 131,
854
+ "id": "6254a5e9",
855
+ "metadata": {},
856
+ "outputs": [],
857
+ "source": [
858
+ "def retrieve_top_n_splade(query, top_n=5):\n",
859
+ " query_vec = get_splade_sparse_vector(query)\n",
860
+ " scores = [\n",
861
+ " (dot_product_sparse(query_vec, doc_vec), idx)\n",
862
+ " for idx, doc_vec in enumerate(splade_doc_vectors)\n",
863
+ " ]\n",
864
+ " \n",
865
+ " top_matches = sorted(scores, reverse=True)[:top_n]\n",
866
+ " \n",
867
+ " results = []\n",
868
+ " for score, idx in top_matches:\n",
869
+ " results.append({\n",
870
+ " \"rank\": len(results) + 1,\n",
871
+ " \"score\": round(score, 4),\n",
872
+ " \"standard\": df.iloc[idx][\"State Standard\"],\n",
873
+ " \"ID\": df.iloc[idx][\"ID\"],\n",
874
+ " \"Category\": df.iloc[idx][\"Category\"],\n",
875
+ " \"Sub Category\": df.iloc[idx][\"Sub Category\"]\n",
876
+ " })\n",
877
+ " return results\n"
878
+ ]
879
+ },
880
+ {
881
+ "cell_type": "code",
882
+ "execution_count": 146,
883
+ "id": "26f27920",
884
+ "metadata": {},
885
+ "outputs": [],
886
+ "source": [
887
+ "query = \"Identify the main idea of a text\"\n",
888
+ "results = retrieve_top_n_splade(query)\n"
889
+ ]
890
+ },
891
+ {
892
+ "cell_type": "code",
893
+ "execution_count": 147,
894
+ "id": "6f736b83",
895
+ "metadata": {},
896
+ "outputs": [
897
+ {
898
+ "data": {
899
+ "text/plain": [
900
+ "[{'rank': 1,\n",
901
+ " 'score': 21.3089,\n",
902
+ " 'standard': 'determine main idea text explain supported key detail summarize text',\n",
903
+ " 'ID': '4.RI.2',\n",
904
+ " 'Category': 'Reading Informational',\n",
905
+ " 'Sub Category': 'Key Ideas and Details'},\n",
906
+ " {'rank': 2,\n",
907
+ " 'score': 20.8493,\n",
908
+ " 'standard': 'determine main idea text recount key detail explain support main idea',\n",
909
+ " 'ID': '3.RI.2',\n",
910
+ " 'Category': 'Reading Informational',\n",
911
+ " 'Sub Category': 'Key Ideas and Details'},\n",
912
+ " {'rank': 3,\n",
913
+ " 'score': 20.2714,\n",
914
+ " 'standard': 'determine two main idea text explain supported key detail summarize text',\n",
915
+ " 'ID': '5.RI.2',\n",
916
+ " 'Category': 'Reading Informational',\n",
917
+ " 'Sub Category': 'Key Ideas and Details'},\n",
918
+ " {'rank': 4,\n",
919
+ " 'score': 17.5151,\n",
920
+ " 'standard': 'determine main idea supporting detail text read aloud information presented diverse medium format including visually quantitatively orally',\n",
921
+ " 'ID': '3.SL.2',\n",
922
+ " 'Category': 'Speaking & Listening',\n",
923
+ " 'Sub Category': 'Comprehension and Collaboration'},\n",
924
+ " {'rank': 5,\n",
925
+ " 'score': 17.512,\n",
926
+ " 'standard': 'identify main purpose text including author want answer explain describe',\n",
927
+ " 'ID': '2.RI.6',\n",
928
+ " 'Category': 'Reading Informational',\n",
929
+ " 'Sub Category': 'Craft and Structure'}]"
930
+ ]
931
+ },
932
+ "execution_count": 147,
933
+ "metadata": {},
934
+ "output_type": "execute_result"
935
+ }
936
+ ],
937
+ "source": [
938
+ "results"
939
+ ]
940
+ },
941
+ {
942
+ "cell_type": "code",
943
+ "execution_count": 148,
944
+ "id": "f6232e19",
945
+ "metadata": {},
946
+ "outputs": [],
947
+ "source": [
948
+ "def evaluate_top1_accuracy(df, retrieve_fn):\n",
949
+ " correct = 0\n",
950
+ " total = len(df)\n",
951
+ "\n",
952
+ " for i in range(total):\n",
953
+ " query = df.loc[i, \"State Standard\"]\n",
954
+ " expected = query.strip().lower()\n",
955
+ "\n",
956
+ " results = retrieve_fn(query, top_n=1)\n",
957
+ " predicted = results[0][\"standard\"].strip().lower()\n",
958
+ "\n",
959
+ " if predicted == expected:\n",
960
+ " correct += 1\n",
961
+ "\n",
962
+ " accuracy = round(correct / total, 4)\n",
963
+ " print(f\"Top-1 Accuracy: {accuracy}\")\n",
964
+ " return accuracy\n"
965
+ ]
966
+ },
967
+ {
968
+ "cell_type": "code",
969
+ "execution_count": 156,
970
+ "id": "5d653426",
971
+ "metadata": {},
972
+ "outputs": [
973
+ {
974
+ "name": "stdout",
975
+ "output_type": "stream",
976
+ "text": [
977
+ "Top-1 Accuracy: 0.9959\n"
978
+ ]
979
+ },
980
+ {
981
+ "data": {
982
+ "text/plain": [
983
+ "0.9959"
984
+ ]
985
+ },
986
+ "execution_count": 156,
987
+ "metadata": {},
988
+ "output_type": "execute_result"
989
+ }
990
+ ],
991
+ "source": [
992
+ "# For BM25\n",
993
+ "evaluate_top1_accuracy(df, retrieve_top_n_bm25)"
994
+ ]
995
+ },
996
+ {
997
+ "cell_type": "code",
998
+ "execution_count": 153,
999
+ "id": "6f9e5c59",
1000
+ "metadata": {},
1001
+ "outputs": [
1002
+ {
1003
+ "name": "stdout",
1004
+ "output_type": "stream",
1005
+ "text": [
1006
+ "Top-1 Accuracy: 0.9797\n"
1007
+ ]
1008
+ },
1009
+ {
1010
+ "data": {
1011
+ "text/plain": [
1012
+ "0.9797"
1013
+ ]
1014
+ },
1015
+ "execution_count": 153,
1016
+ "metadata": {},
1017
+ "output_type": "execute_result"
1018
+ }
1019
+ ],
1020
+ "source": [
1021
+ "# For SPLADE\n",
1022
+ "evaluate_top1_accuracy(df, retrieve_top_n_splade)\n"
1023
+ ]
1024
+ },
1025
+ {
1026
+ "cell_type": "markdown",
1027
+ "id": "a2b5800b",
1028
+ "metadata": {},
1029
+ "source": [
1030
+ "## Comparison: BM25 vs SPLADE for CCSS Alignment\n",
1031
+ "\n",
1032
+ "**Query:** \n",
1033
+ "> *\"Identify the main idea of a text\"*\n",
1034
+ "\n",
1035
+ "---\n",
1036
+ "\n",
1037
+ "### Top-5 Results: **BM25**\n",
1038
+ "\n",
1039
+ "| Rank | ID | Category | Sub Category | State Standard | Score |\n",
1040
+ "|------|---------|----------------------|----------------------|----------------------------------------------------------------------------------|---------|\n",
1041
+ "| 1 | 1.RI.2 | Reading Informational| Key Ideas and Details| identify main topic retell key detail text | 10.666 |\n",
1042
+ "| 2 | 3.RI.2 | Reading Informational| Key Ideas and Details| determine main idea text recount key detail explain support main idea | 10.0953 |\n",
1043
+ "| 3 | K.RI.2 | Reading Informational| Key Ideas and Details| prompting support identify main topic retell key detail text | 9.8043 |\n",
1044
+ "| 4 | 2.RI.6 | Reading Informational| Craft and Structure | identify main purpose text including author want answer explain describe | 9.4236 |\n",
1045
+ "| 5 | 2.RI.2 | Reading Informational| Key Ideas and Details| identify main topic multiparagraph text well focus specific paragraph within text | 9.3944 |\n",
1046
+ "\n",
1047
+ "---\n",
1048
+ "\n",
1049
+ "### Top-5 Results: **SPLADE (Sparse Embedding Model)**\n",
1050
+ "\n",
1051
+ "| Rank | ID | Category | Sub Category | State Standard | Score |\n",
1052
+ "|------|---------|----------------------|----------------------|----------------------------------------------------------------------------------|---------|\n",
1053
+ "| 1 | 4.RI.2 | Reading Informational| Key Ideas and Details| determine main idea text explain supported key detail summarize text | 21.3089 |\n",
1054
+ "| 2 | 3.RI.2 | Reading Informational| Key Ideas and Details| determine main idea text recount key detail explain support main idea | 20.8493 |\n",
1055
+ "| 3 | 5.RI.2 | Reading Informational| Key Ideas and Details| determine two main idea text explain supported key detail summarize text | 20.2714 |\n",
1056
+ "| 4 | 3.SL.2 | Speaking & Listening | Comprehension and Collaboration | determine main idea supporting detail text read aloud information presented diverse medium format including visually quantitatively orally | 17.5151 |\n",
1057
+ "| 5 | 2.RI.6 | Reading Informational| Craft and Structure | identify main purpose text including author want answer explain describe | 17.512 |\n",
1058
+ "\n",
1059
+ "---\n",
1060
+ "\n",
1061
+ "### Insights:\n",
1062
+ "\n",
1063
+ "- Both **BM25 and SPLADE** correctly rank **\"3.RI.2\"** and **\"2.RI.6\"** in the top-5.\n",
1064
+ "- **SPLADE ranks more abstract or paraphrased variants** (e.g., \"summarize\", \"supported key detail\") higher due to its semantic understanding.\n",
1065
+ "- SPLADE retrieves **higher grade-level matches** like **\"5.RI.2\"** and **\"4.RI.2\"**, which are **semantically related** to the query but not lexically identical to it.\n",
1066
+ "- BM25 relies on **exact term overlap**, favoring simpler phrasings like \"identify main topic\".\n",
1067
+ "\n",
1068
+ "---\n",
1069
+ "\n",
1070
+ "### Conclusion:\n",
1071
+ "\n",
1072
+ "| Feature | BM25 | SPLADE |\n",
1073
+ "|--------------------------|----------------------------|----------------------------------|\n",
1074
+ "| Matching Type | Exact lexical match | Semantic sparse match |\n",
1075
+ "| Interpretability | High (term overlap) | High (per-term weights) |\n",
1076
+ "| Handles Paraphrasing | No | Yes |\n",
1077
+ "| Use Case Fit | Good for short, exact queries | Great for natural language input |\n",
1078
+ "\n",
1079
+ "---\n",
1080
+ "\n",
1081
+ "### Top-1 Accuracy\n",
1082
+ "\n",
1083
+ "| Model | Top-1 Accuracy |\n",
1084
+ "|---------|----------------|\n",
1085
+ "| BM25 | **0.9959** |\n",
1086
+ "| SPLADE | **0.9797** |\n",
1087
+ "\n",
1088
+ "---\n",
1089
+ "\n",
1090
+ "### Insights\n",
1091
+ "\n",
1092
+ "- **BM25** achieves near-perfect accuracy due to exact term matching: in this evaluation each standard's own text is used as the query, so every query is identical to an indexed document.\n",
1093
+ "- **SPLADE** scores slightly lower because it may **rank paraphrases or semantic neighbors above the original text**, even though the original text is present in the index.\n",
1094
+ "\n",
1095
+ "---\n"
1096
+ ]
1097
+ }
1098
+ ],
1099
+ "metadata": {
1100
+ "kernelspec": {
1101
+ "display_name": "venv",
1102
+ "language": "python",
1103
+ "name": "venv"
1104
+ },
1105
+ "language_info": {
1106
+ "codemirror_mode": {
1107
+ "name": "ipython",
1108
+ "version": 3
1109
+ },
1110
+ "file_extension": ".py",
1111
+ "mimetype": "text/x-python",
1112
+ "name": "python",
1113
+ "nbconvert_exporter": "python",
1114
+ "pygments_lexer": "ipython3",
1115
+ "version": "3.12.7"
1116
+ }
1117
+ },
1118
+ "nbformat": 4,
1119
+ "nbformat_minor": 5
1120
+ }
requirements.txt ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Core libraries
2
+ pandas
3
+ numpy
4
+ scikit-learn
5
+
6
+ # NLP
7
+ nltk
8
+ transformers
9
+ torch
10
+
11
+ # BM25
12
+ rank_bm25
13
+
14
+ # Streamlit UI
15
+ streamlit
16
+
17
+ # Plotting (optional, if using matplotlib in app.py)
18
+ matplotlib
19
+
20
+ # For reading Excel
21
+ openpyxl
22
+
23
+ # Ensure compatibility with tokenizers
24
+ tokenizers