File size: 2,302 Bytes
34a1c85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import streamlit as st
import pandas as pd
import json
import matplotlib.pyplot as plt

from core.splade_utility import splade_utility
from core.bm25_utility import bm25_utility



# ==== Import your models and functions ====
# Assume these are already defined in your notebook/script
# from your_code import retrieve_top_n_bm25, retrieve_top_n_splade, evaluate_top1_on_state_standard

# Placeholder Top-1 accuracy figures for the two retrieval models.
# They are hard-coded here and plotted in the accuracy bar chart below;
# replace them with measured values when those are available.
bm25_accuracy: float = 0.9959
splade_accuracy: float = 0.9797

# Thin wrappers that run a query through the BM25 / SPLADE retrieval utilities
def retrieve_top_n_bm25(query: str, top_n: int = 5):
    """Run ``query`` through the BM25 utility and return its top matches.

    Args:
        query: Text to search with; passed unchanged to ``bm25_utility``.
        top_n: Number of results to request. It is forwarded to
            ``bm25_utility`` so that callers asking for a different count
            actually get it.

    Returns:
        Whatever ``bm25_utility(...).retrieve_top_n_bm25()`` returns.
        NOTE(review): the UI code iterates over it and reads dict-style
        keys such as ``'standard'`` and ``'score'`` -- confirm against
        ``core.bm25_utility``.
    """
    # Forward the caller's top_n instead of a hard-coded 5, matching the
    # behavior of retrieve_top_n_splade.
    return bm25_utility(query, top_n=top_n).retrieve_top_n_bm25()

def retrieve_top_n_splade(query, top_n=5):
    """Run ``query`` through the SPLADE utility and return its top matches.

    Args:
        query: Text to search with; passed unchanged to ``splade_utility``.
        top_n: Number of results to request; forwarded to ``splade_utility``.

    Returns:
        Whatever ``splade_utility(...).retrieve_top_n_splade()`` returns.
    """
    searcher = splade_utility(query, top_n=top_n)
    matches = searcher.retrieve_top_n_splade()
    return matches

# ==== Streamlit UI ====

# Configure the browser tab title and use the centered page layout.
st.set_page_config(page_title="CCSS Alignment", layout="centered")
# The emoji is written as a \N{...} escape so the source stays pure ASCII
# and the character cannot be mangled by an encoding round-trip.
st.title("\N{BOOKS} CCSS Alignment Search")

# Select model: the chosen label decides which retriever the search uses.
model_choice = st.radio("Select Retrieval Model:", ["BM25", "SPLADE"])

# Accuracy bar chart comparing the two models' Top-1 accuracy.
# The emoji is a \N{...} escape so the source stays ASCII-only.
st.subheader("\N{DIRECT HIT} Model Top-1 Accuracy")
model_names = ["BM25", "SPLADE"]
model_accuracies = [bm25_accuracy, splade_accuracy]
fig, ax = plt.subplots()
ax.bar(model_names, model_accuracies, color=["skyblue", "lightgreen"])
# Zoom the y-axis into the 0.9-1.0 band so the small gap between the two
# bars is visible; the extra 0.01 leaves room for the value labels.
ax.set_ylim([0.9, 1.01])
ax.set_ylabel("Top-1 Accuracy")
for i, acc in enumerate(model_accuracies):
    # Print each value just above its bar.
    ax.text(i, acc + 0.001, f"{acc:.4f}", ha='center', fontsize=10)
st.pyplot(fig)
# Streamlit re-executes this script on every interaction. Figures created
# through pyplot stay registered until closed, so release this one once it
# has been rendered to avoid accumulating open figures.
plt.close(fig)

# Query input (emoji written as \N{...} escapes so the source stays ASCII).
st.subheader("\N{LEFT-POINTING MAGNIFYING GLASS} Try a Query")
query = st.text_area("Enter a lesson or objective text:", height=100)

# Search button
if st.button("Search"):
    if not query or not query.strip():
        # Do not send an empty or whitespace-only query to the retriever.
        st.warning("Please enter some text to search.")
    else:
        st.subheader("\N{PAGE FACING UP} Top Results")

        # Pick the retriever that matches the radio selection.
        if model_choice == "BM25":
            results = retrieve_top_n_bm25(query, top_n=5)
        else:
            results = retrieve_top_n_splade(query, top_n=5)

        if results:
            for rank, row in enumerate(results, 1):
                # Build the Markdown line by line, without source
                # indentation, so rendering does not depend on trailing
                # spaces or on how leading whitespace is treated. Every
                # field falls back to 'N/A' so a missing key cannot abort
                # the page.
                entry_lines = [
                    f"**Rank {rank}**",
                    f"- **Standard**: {row.get('standard', 'N/A')}",
                    f"- **ID**: {row.get('ID', 'N/A')}",
                    f"- **Category**: {row.get('Category', 'N/A')}",
                    f"- **Sub Category**: {row.get('Sub Category', 'N/A')}",
                    f"- **Score**: `{row.get('score', 'N/A')}`",
                ]
                st.markdown("\n".join(entry_lines))
        else:
            st.warning("No results found.")