CCSS_Alignment / app.py
meruem123's picture
Upload 15 files
34a1c85 verified
import streamlit as st
import pandas as pd
import json
import matplotlib.pyplot as plt
from core.splade_utility import splade_utility
from core.bm25_utility import bm25_utility
# ==== Import your models and functions ====
# Assume these are already defined in your notebook/script
# from your_code import retrieve_top_n_bm25, retrieve_top_n_splade, evaluate_top1_on_state_standard
# Dummy accuracy values (replace with real ones from your code)
bm25_accuracy = 0.9959
splade_accuracy = 0.9797
# Dummy placeholder functions (replace with your actual ones)
def retrieve_top_n_bm25(query, top_n=5):
bm25_utility_instance = bm25_utility(query, top_n=5)
top_n_results = bm25_utility_instance.retrieve_top_n_bm25()
return top_n_results
def retrieve_top_n_splade(query, top_n=5):
splade_utility_instance = splade_utility(query, top_n=top_n)
return splade_utility_instance.retrieve_top_n_splade()
# ==== Streamlit UI ====
st.set_page_config(page_title="CCSS Alignment", layout="centered")
st.title("πŸ“š CCSS Alignment Search")
# Select model
model_choice = st.radio("Select Retrieval Model:", ["BM25", "SPLADE"])
# Accuracy bar chart
st.subheader("🎯 Model Top-1 Accuracy")
fig, ax = plt.subplots()
ax.bar(["BM25", "SPLADE"], [bm25_accuracy, splade_accuracy], color=["skyblue", "lightgreen"])
ax.set_ylim([0.9, 1.01])
ax.set_ylabel("Top-1 Accuracy")
for i, acc in enumerate([bm25_accuracy, splade_accuracy]):
ax.text(i, acc + 0.001, f"{acc:.4f}", ha='center', fontsize=10)
st.pyplot(fig)
# Query input
st.subheader("πŸ” Try a Query")
query = st.text_area("Enter a lesson or objective text:", height=100)
# Search button
if st.button("Search"):
st.subheader("πŸ“„ Top Results")
if model_choice == "BM25":
results = retrieve_top_n_bm25(query, top_n=5)
else:
results = retrieve_top_n_splade(query, top_n=5)
if results:
for i, r in enumerate(results, 1):
st.markdown(f"""
**Rank {i}**
- **Standard**: {r['standard']}
- **ID**: {r.get('ID', 'N/A')}
- **Category**: {r.get('Category', 'N/A')}
- **Sub Category**: {r.get('Sub Category', 'N/A')}
- **Score**: `{r['score']}`
""")
else:
st.warning("No results found.")