Spaces:
Sleeping
Sleeping
File size: 2,302 Bytes
34a1c85 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 |
import streamlit as st
import pandas as pd
import json
import matplotlib.pyplot as plt
from core.splade_utility import splade_utility
from core.bm25_utility import bm25_utility
# ==== Model evaluation results ====
# Top-1 accuracy of each retriever, displayed in the accuracy bar chart.
# NOTE(review): these are hard-coded numbers, not computed at runtime —
# update them by hand whenever the models are re-evaluated.
bm25_accuracy, splade_accuracy = 0.9959, 0.9797
# ==== Retrieval wrappers ====
def retrieve_top_n_bm25(query, top_n=5):
    """Return the top-ranked CCSS standards for ``query`` using BM25.

    Args:
        query: Lesson or objective text to match against the standards.
        top_n: Number of results to request from the BM25 utility. This is
            now forwarded as given; previously 5 was always passed regardless
            of the argument.

    Returns:
        Whatever ``bm25_utility.retrieve_top_n_bm25()`` returns. The UI in
        this script iterates it and reads each item's ``'standard'`` and
        ``'score'`` keys (plus optional ``'ID'``, ``'Category'``,
        ``'Sub Category'``).
    """
    # Forward the caller's top_n, matching retrieve_top_n_splade.
    return bm25_utility(query, top_n=top_n).retrieve_top_n_bm25()
def retrieve_top_n_splade(query, top_n=5):
    """Return the top-ranked CCSS standards for ``query`` using SPLADE.

    Args:
        query: Lesson or objective text to match against the standards.
        top_n: Number of results to request from the SPLADE utility.

    Returns:
        Whatever ``splade_utility.retrieve_top_n_splade()`` returns,
        passed through unchanged.
    """
    return splade_utility(query, top_n=top_n).retrieve_top_n_splade()
# ==== Streamlit UI ====
# Number of ranked results requested from whichever retriever is selected.
TOP_N = 5

st.set_page_config(page_title="CCSS Alignment", layout="centered")
# NOTE: heading emoji restored; the previous text ("π", "π―") was UTF-8
# emoji bytes mis-decoded as ISO-8859-7.
st.title("📚 CCSS Alignment Search")

# Select model
model_choice = st.radio("Select Retrieval Model:", ["BM25", "SPLADE"])

# Accuracy bar chart
st.subheader("🎯 Model Top-1 Accuracy")
model_names = ["BM25", "SPLADE"]
accuracies = [bm25_accuracy, splade_accuracy]
fig, ax = plt.subplots()
ax.bar(model_names, accuracies, color=["skyblue", "lightgreen"])
# Zoom in near the top of the scale so small differences are visible, but
# lower the floor when an accuracy drops below 0.9 so its bar is not clipped.
ax.set_ylim([min(0.9, min(accuracies) - 0.01), 1.01])
ax.set_ylabel("Top-1 Accuracy")
for i, acc in enumerate(accuracies):
    ax.text(i, acc + 0.001, f"{acc:.4f}", ha="center", fontsize=10)
st.pyplot(fig)
# pyplot keeps every figure created by plt.subplots() registered until it is
# closed; Streamlit reruns this script on each interaction, so close it to
# avoid accumulating figures.
plt.close(fig)

# Query input
st.subheader("🔍 Try a Query")
query = st.text_area("Enter a lesson or objective text:", height=100)

# Search button
if st.button("Search"):
    st.subheader("📊 Top Results")
    if not query.strip():
        # Nothing to search for; don't send an empty query to the retriever.
        st.warning("Please enter a lesson or objective text to search.")
    else:
        retrieve = (
            retrieve_top_n_bm25 if model_choice == "BM25" else retrieve_top_n_splade
        )
        results = retrieve(query, top_n=TOP_N)
        if results:
            for rank, r in enumerate(results, 1):
                # 'standard' and 'score' are required; the rest are optional.
                card = [
                    f"**Rank {rank}**",
                    f"- **Standard**: {r['standard']}",
                    f"- **ID**: {r.get('ID', 'N/A')}",
                    f"- **Category**: {r.get('Category', 'N/A')}",
                    f"- **Sub Category**: {r.get('Sub Category', 'N/A')}",
                    f"- **Score**: `{r['score']}`",
                ]
                st.markdown("\n".join(card))
        else:
            st.warning("No results found.")
|