Spaces:
Sleeping
Sleeping
| # app.py β Unified Panel App with Semantic Search + Filterable Tabulator | |
| import os, io, gc | |
| import panel as pn | |
| import pandas as pd | |
| import boto3, torch | |
| import psycopg2 | |
| from sentence_transformers import SentenceTransformer, util | |
| pn.extension('tabulator') | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # 1) Database and Resource Loading | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| DB_HOST = os.getenv("DB_HOST") | |
| DB_PORT = os.getenv("DB_PORT", "5432") | |
| DB_NAME = os.getenv("DB_NAME") | |
| DB_USER = os.getenv("DB_USER") | |
| DB_PASSWORD = os.getenv("DB_PASSWORD") | |
| def get_data(): | |
| conn = psycopg2.connect( | |
| host=DB_HOST, port=DB_PORT, | |
| dbname=DB_NAME, user=DB_USER, password=DB_PASSWORD, | |
| sslmode="require" | |
| ) | |
| df_ = pd.read_sql_query(""" | |
| SELECT id, country, year, section, | |
| question_code, question_text, | |
| answer_code, answer_text | |
| FROM survey_info; | |
| """, conn) | |
| conn.close() | |
| # Ensure year column is int, show blank instead of NaN | |
| if "year" in df_.columns: | |
| df_["year"] = pd.to_numeric(df_["year"], errors="coerce").astype("Int64").astype(str).replace({'<NA>': ''}) | |
| return df_ | |
| df = get_data() | |
| def load_embeddings(): | |
| BUCKET, KEY = "cgd-embeddings-bucket", "survey_info_embeddings.pt" | |
| buf = io.BytesIO() | |
| boto3.client("s3").download_fileobj(BUCKET, KEY, buf) | |
| buf.seek(0) | |
| ckpt = torch.load(buf, map_location="cpu") | |
| buf.close(); gc.collect() | |
| return ckpt["ids"], ckpt["embeddings"] | |
| def get_st_model(): | |
| return SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", device="cpu") | |
| def get_semantic_resources(): | |
| model = get_st_model() | |
| ids_list, emb_tensor = load_embeddings() | |
| return model, ids_list, emb_tensor | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # 2) Widgets | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| country_opts = sorted(df["country"].dropna().unique()) | |
| year_opts = sorted(df["year"].dropna().unique()) | |
| ALL_COLUMNS = ["country","year","section","question_code","question_text","answer_code","answer_text","Score"] | |
| w_columns = pn.widgets.MultiChoice( | |
| name="Columns to show", | |
| options=ALL_COLUMNS, | |
| value=["country","year","question_text","answer_text"] | |
| ) | |
| w_countries = pn.widgets.MultiSelect(name="Countries", options=country_opts) | |
| w_years = pn.widgets.MultiSelect(name="Years", options=year_opts) | |
| w_keyword = pn.widgets.TextInput(name="Keyword Search", placeholder="Search questions or answers with exact string matching") | |
| w_group = pn.widgets.Checkbox(name="Group by Question Text", value=False) | |
| w_topk = pn.widgets.Select(name="Top-K (semantic)", options=[5, 10, 20, 50, 100], value=10, disabled=True) | |
| w_semquery = pn.widgets.TextInput(name="Semantic Query", placeholder="LLM-powered semantic search") | |
| w_search_button = pn.widgets.Button(name="Search", button_type="primary") | |
| w_clear_filters = pn.widgets.Button(name="Clear Filters", button_type="warning") | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # 3) Unified Results Table (Tabulator) | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| result_table = pn.widgets.Tabulator( | |
| pagination='remote', | |
| page_size=15, | |
| sizing_mode="stretch_width", | |
| layout='fit_columns', | |
| show_index=False | |
| ) | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # 4) Search Logic | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def _group_by_question(df_in: pd.DataFrame) -> pd.DataFrame: | |
| if df_in.empty: | |
| return pd.DataFrame(columns=["question_text", "Countries", "Years", "Sample Answers"]) | |
| tmp = df_in.copy() | |
| tmp["year"] = tmp["year"].replace('', pd.NA) | |
| grouped = ( | |
| tmp.groupby("question_text", dropna=False) | |
| .agg({ | |
| "country": lambda x: sorted({v for v in x if pd.notna(v)}), | |
| "year": lambda x: sorted({str(v) for v in x if pd.notna(v)}), | |
| "answer_text": lambda x: list(x.dropna())[:3], | |
| }) | |
| .reset_index() | |
| .rename(columns={"country": "Countries", "year": "Years", "answer_text": "Sample Answers"}) | |
| ) | |
| return grouped | |
| def _selected_cols(has_score=False): | |
| allowed = set(ALL_COLUMNS) | |
| if not has_score and "Score" in w_columns.value: | |
| w_columns.value = [c for c in w_columns.value if c != "Score"] | |
| cols = [c for c in w_columns.value if c in allowed] | |
| if not cols: | |
| cols = ["country", "year", "question_text", "answer_text"] | |
| return cols | |
| def search(event=None): | |
| query = w_semquery.value.strip() | |
| filt = df.copy() | |
| if w_countries.value: | |
| filt = filt[filt["country"].isin(w_countries.value)] | |
| if w_years.value: | |
| filt = filt[filt["year"].isin(w_years.value)] | |
| if w_keyword.value: | |
| filt = filt[ | |
| filt["question_text"].str.contains(w_keyword.value, case=False, na=False) | | |
| filt["answer_text"].str.contains(w_keyword.value, case=False, na=False) | | |
| filt["question_code"].astype(str).str.contains(w_keyword.value, case=False, na=False) | |
| ] | |
| if not query: | |
| result_table.value = _group_by_question(filt) if w_group.value else filt[_selected_cols(False)] | |
| return | |
| model, ids_list, emb_tensor = get_semantic_resources() | |
| filtered_ids = filt["id"].tolist() | |
| id_to_index = {id_: i for i, id_ in enumerate(ids_list)} | |
| filtered_indices = [id_to_index[id_] for id_ in filtered_ids if id_ in id_to_index] | |
| if not filtered_indices: | |
| result_table.value = _group_by_question(filt.iloc[0:0]) if w_group.value else pd.DataFrame(columns=_selected_cols(True)) | |
| return | |
| top_k = min(int(w_topk.value), len(filtered_indices)) | |
| filtered_embs = emb_tensor[filtered_indices] | |
| q_vec = model.encode(query, convert_to_tensor=True, device="cpu").cpu() | |
| sims = util.cos_sim(q_vec, filtered_embs)[0] | |
| top_vals, top_idx = torch.topk(sims, k=top_k) | |
| top_filtered_ids = [filtered_ids[i] for i in top_idx.tolist()] | |
| sem_rows = filt[filt["id"].isin(top_filtered_ids)].copy() | |
| score_map = dict(zip(top_filtered_ids, top_vals.tolist())) | |
| sem_rows["Score"] = sem_rows["id"].map(score_map) | |
| sem_rows = sem_rows.sort_values("Score", ascending=False) | |
| result_table.value = _group_by_question(sem_rows.drop(columns=["Score"])) if w_group.value else sem_rows[_selected_cols(True)] | |
| def clear_filters(event=None): | |
| w_countries.value = [] | |
| w_years.value = [] | |
| w_keyword.value = "" | |
| w_semquery.value = "" | |
| w_topk.disabled = True | |
| result_table.value = df[["country", "year", "question_text", "answer_text"]].copy() | |
| w_search_button.on_click(search) | |
| w_clear_filters.on_click(clear_filters) | |
| # Live updates for filters (except semantic query and keyword) | |
| w_group.param.watch(lambda e: search(), 'value') | |
| w_countries.param.watch(lambda e: search(), 'value') | |
| w_years.param.watch(lambda e: search(), 'value') | |
| w_columns.param.watch(lambda e: search(), 'value') | |
| # Allow pressing Enter in semantic query or keyword to trigger search | |
| w_semquery.param.watch(lambda e: search(), 'enter_pressed') | |
| w_keyword.param.watch(lambda e: search(), 'enter_pressed') | |
| # Enable/disable Top-K based on semantic query presence | |
| def _toggle_topk_disabled(event=None): | |
| w_topk.disabled = (w_semquery.value.strip() == '') | |
| _toggle_topk_disabled() | |
| w_semquery.param.watch(lambda e: _toggle_topk_disabled(), 'value') | |
| # Show all data at startup | |
| result_table.value = df[["country", "year", "question_text", "answer_text"]].copy() | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # 5) Layout | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| sidebar = pn.Column( | |
| "## π Filters", | |
| w_countries, w_years, w_keyword, w_group, w_columns, | |
| pn.Spacer(height=20), | |
| "## π§ Semantic Search", | |
| w_semquery, | |
| w_topk, | |
| w_search_button, | |
| pn.Spacer(height=20), | |
| w_clear_filters, | |
| width=300 | |
| ) | |
| main = pn.Column( | |
| pn.pane.Markdown("## π CGD Survey Explorer"), | |
| result_table | |
| ) | |
| pn.template.FastListTemplate( | |
| title="CGD Survey Explorer", | |
| sidebar=sidebar, | |
| main=main, | |
| theme_toggle=True, | |
| ).servable() | |