Spaces:
Build error
Build error
fix filtering
Browse files
app.py
CHANGED
|
@@ -15,6 +15,14 @@ st.set_page_config(
|
|
| 15 |
initial_sidebar_state="expanded"
|
| 16 |
)
|
| 17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
# Load a pre-trained model for embeddings with HF caching
|
| 19 |
@st.cache_resource
|
| 20 |
def load_model():
|
|
@@ -38,7 +46,6 @@ def load_data():
|
|
| 38 |
return df[["uuid", "problem", "source", "question_type", "problem_type"]]
|
| 39 |
except Exception as e:
|
| 40 |
st.error(f"Error loading dataset: {e}")
|
| 41 |
-
# Return empty DataFrame with correct columns if loading fails
|
| 42 |
return pd.DataFrame(columns=["uuid", "problem", "source", "question_type", "problem_type"])
|
| 43 |
|
| 44 |
# Cache embeddings computation with error handling
|
|
@@ -51,13 +58,11 @@ def compute_embeddings(problems):
|
|
| 51 |
st.error(f"Error computing embeddings: {e}")
|
| 52 |
return np.array([])
|
| 53 |
|
| 54 |
-
# ================== FUNCTION DEFINITIONS ==================
|
| 55 |
def find_similar_problems(df, similarity_threshold=0.9, progress_bar=None):
|
| 56 |
"""Find similar problems using cosine similarity, optimized for speed."""
|
| 57 |
if df.empty:
|
| 58 |
return []
|
| 59 |
|
| 60 |
-
# Compute embeddings with progress tracking
|
| 61 |
embeddings = compute_embeddings(df['problem'].tolist())
|
| 62 |
if embeddings.size == 0:
|
| 63 |
return []
|
|
@@ -65,17 +70,14 @@ def find_similar_problems(df, similarity_threshold=0.9, progress_bar=None):
|
|
| 65 |
if progress_bar:
|
| 66 |
progress_bar.progress(0.33, "Computing similarity matrix...")
|
| 67 |
|
| 68 |
-
# Compute similarity matrix
|
| 69 |
similarity_matrix = util.cos_sim(embeddings, embeddings).numpy()
|
| 70 |
if progress_bar:
|
| 71 |
progress_bar.progress(0.66, "Finding similar pairs...")
|
| 72 |
|
| 73 |
-
# Use numpy operations for better performance
|
| 74 |
num_problems = len(df)
|
| 75 |
upper_triangle_indices = np.triu_indices(num_problems, k=1)
|
| 76 |
similarity_scores = similarity_matrix[upper_triangle_indices]
|
| 77 |
|
| 78 |
-
# Filter based on threshold
|
| 79 |
mask = similarity_scores > similarity_threshold
|
| 80 |
filtered_indices = np.where(mask)[0]
|
| 81 |
|
|
@@ -121,19 +123,22 @@ def analyze_clusters(_df, pairs):
|
|
| 121 |
})
|
| 122 |
return detailed_analysis
|
| 123 |
|
| 124 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
def main():
|
| 126 |
st.title("🔍 Problem Deduplication Explorer")
|
| 127 |
|
| 128 |
-
# Check if model loaded successfully
|
| 129 |
if model is None:
|
| 130 |
st.error("Failed to load the model. Please try again later.")
|
| 131 |
return
|
| 132 |
|
| 133 |
-
# Initialize session state for pagination
|
| 134 |
-
if 'page_number' not in st.session_state:
|
| 135 |
-
st.session_state.page_number = 0
|
| 136 |
-
|
| 137 |
# Sidebar configuration
|
| 138 |
with st.sidebar:
|
| 139 |
st.header("Settings")
|
|
@@ -168,12 +173,13 @@ def main():
|
|
| 168 |
)
|
| 169 |
|
| 170 |
# Analysis section
|
| 171 |
-
if st.sidebar.button("Run Deduplication Analysis", type="primary"):
|
| 172 |
-
|
|
|
|
|
|
|
|
|
|
| 173 |
|
| 174 |
-
|
| 175 |
-
pairs = find_similar_problems(df, similarity_threshold, progress_bar)
|
| 176 |
-
results = analyze_clusters(df, pairs)
|
| 177 |
|
| 178 |
if not results:
|
| 179 |
st.warning("No similar problems found with the current threshold.")
|
|
@@ -189,18 +195,17 @@ def main():
|
|
| 189 |
with col2:
|
| 190 |
selected_qtype = st.selectbox("Filter by Question Type", [None] + question_types)
|
| 191 |
|
| 192 |
-
# Apply filters
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
if selected_qtype:
|
| 196 |
-
results = [r for r in results if df[df["uuid"] == r["base_uuid"]]["question_type"].values[0] == selected_qtype]
|
| 197 |
|
| 198 |
-
if not
|
| 199 |
st.warning("No results found with the current filters.")
|
| 200 |
return
|
| 201 |
|
| 202 |
# Pagination
|
| 203 |
-
total_pages = len(
|
|
|
|
| 204 |
|
| 205 |
col1, col2, col3 = st.columns([1, 3, 1])
|
| 206 |
with col1:
|
|
@@ -215,7 +220,7 @@ def main():
|
|
| 215 |
# Display results
|
| 216 |
start_idx = st.session_state.page_number * items_per_page
|
| 217 |
end_idx = start_idx + items_per_page
|
| 218 |
-
page_results =
|
| 219 |
|
| 220 |
for entry in page_results:
|
| 221 |
with st.container():
|
|
|
|
| 15 |
initial_sidebar_state="expanded"
|
| 16 |
)
|
| 17 |
|
| 18 |
+
# Initialize session state
|
| 19 |
+
if 'page_number' not in st.session_state:
|
| 20 |
+
st.session_state.page_number = 0
|
| 21 |
+
if 'analysis_results' not in st.session_state:
|
| 22 |
+
st.session_state.analysis_results = None
|
| 23 |
+
if 'filtered_results' not in st.session_state:
|
| 24 |
+
st.session_state.filtered_results = None
|
| 25 |
+
|
| 26 |
# Load a pre-trained model for embeddings with HF caching
|
| 27 |
@st.cache_resource
|
| 28 |
def load_model():
|
|
|
|
| 46 |
return df[["uuid", "problem", "source", "question_type", "problem_type"]]
|
| 47 |
except Exception as e:
|
| 48 |
st.error(f"Error loading dataset: {e}")
|
|
|
|
| 49 |
return pd.DataFrame(columns=["uuid", "problem", "source", "question_type", "problem_type"])
|
| 50 |
|
| 51 |
# Cache embeddings computation with error handling
|
|
|
|
| 58 |
st.error(f"Error computing embeddings: {e}")
|
| 59 |
return np.array([])
|
| 60 |
|
|
|
|
| 61 |
def find_similar_problems(df, similarity_threshold=0.9, progress_bar=None):
|
| 62 |
"""Find similar problems using cosine similarity, optimized for speed."""
|
| 63 |
if df.empty:
|
| 64 |
return []
|
| 65 |
|
|
|
|
| 66 |
embeddings = compute_embeddings(df['problem'].tolist())
|
| 67 |
if embeddings.size == 0:
|
| 68 |
return []
|
|
|
|
| 70 |
if progress_bar:
|
| 71 |
progress_bar.progress(0.33, "Computing similarity matrix...")
|
| 72 |
|
|
|
|
| 73 |
similarity_matrix = util.cos_sim(embeddings, embeddings).numpy()
|
| 74 |
if progress_bar:
|
| 75 |
progress_bar.progress(0.66, "Finding similar pairs...")
|
| 76 |
|
|
|
|
| 77 |
num_problems = len(df)
|
| 78 |
upper_triangle_indices = np.triu_indices(num_problems, k=1)
|
| 79 |
similarity_scores = similarity_matrix[upper_triangle_indices]
|
| 80 |
|
|
|
|
| 81 |
mask = similarity_scores > similarity_threshold
|
| 82 |
filtered_indices = np.where(mask)[0]
|
| 83 |
|
|
|
|
| 123 |
})
|
| 124 |
return detailed_analysis
|
| 125 |
|
| 126 |
+
def apply_filters(results, df, selected_source, selected_qtype):
|
| 127 |
+
"""Apply filters to results."""
|
| 128 |
+
filtered = results.copy()
|
| 129 |
+
if selected_source:
|
| 130 |
+
filtered = [r for r in filtered if df[df["uuid"] == r["base_uuid"]]["source"].values[0] == selected_source]
|
| 131 |
+
if selected_qtype:
|
| 132 |
+
filtered = [r for r in filtered if df[df["uuid"] == r["base_uuid"]]["question_type"].values[0] == selected_qtype]
|
| 133 |
+
return filtered
|
| 134 |
+
|
| 135 |
def main():
|
| 136 |
st.title("🔍 Problem Deduplication Explorer")
|
| 137 |
|
|
|
|
| 138 |
if model is None:
|
| 139 |
st.error("Failed to load the model. Please try again later.")
|
| 140 |
return
|
| 141 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
# Sidebar configuration
|
| 143 |
with st.sidebar:
|
| 144 |
st.header("Settings")
|
|
|
|
| 173 |
)
|
| 174 |
|
| 175 |
# Analysis section
|
| 176 |
+
if st.sidebar.button("Run Deduplication Analysis", type="primary") or st.session_state.analysis_results is not None:
|
| 177 |
+
if st.session_state.analysis_results is None:
|
| 178 |
+
progress_bar = st.progress(0, "Starting analysis...")
|
| 179 |
+
pairs = find_similar_problems(df, similarity_threshold, progress_bar)
|
| 180 |
+
st.session_state.analysis_results = analyze_clusters(df, pairs)
|
| 181 |
|
| 182 |
+
results = st.session_state.analysis_results
|
|
|
|
|
|
|
| 183 |
|
| 184 |
if not results:
|
| 185 |
st.warning("No similar problems found with the current threshold.")
|
|
|
|
| 195 |
with col2:
|
| 196 |
selected_qtype = st.selectbox("Filter by Question Type", [None] + question_types)
|
| 197 |
|
| 198 |
+
# Apply filters and store in session state
|
| 199 |
+
filtered_results = apply_filters(results, df, selected_source, selected_qtype)
|
| 200 |
+
st.session_state.filtered_results = filtered_results
|
|
|
|
|
|
|
| 201 |
|
| 202 |
+
if not filtered_results:
|
| 203 |
st.warning("No results found with the current filters.")
|
| 204 |
return
|
| 205 |
|
| 206 |
# Pagination
|
| 207 |
+
total_pages = (len(filtered_results) - 1) // items_per_page
|
| 208 |
+
st.session_state.page_number = min(st.session_state.page_number, total_pages)
|
| 209 |
|
| 210 |
col1, col2, col3 = st.columns([1, 3, 1])
|
| 211 |
with col1:
|
|
|
|
| 220 |
# Display results
|
| 221 |
start_idx = st.session_state.page_number * items_per_page
|
| 222 |
end_idx = start_idx + items_per_page
|
| 223 |
+
page_results = filtered_results[start_idx:end_idx]
|
| 224 |
|
| 225 |
for entry in page_results:
|
| 226 |
with st.container():
|