| |
|
|
| import os, io, gc |
| import panel as pn |
| import pandas as pd |
| import boto3, torch |
| import psycopg2 |
| from sentence_transformers import SentenceTransformer, util |
|
|
# Enable the Tabulator extension used by the results table below.
pn.extension('tabulator')

# Postgres connection settings come from the environment; only the port
# has a default. Missing variables surface as None at connect time.
DB_HOST = os.getenv("DB_HOST")
DB_PORT = os.getenv("DB_PORT", "5432")
DB_NAME = os.getenv("DB_NAME")
DB_USER = os.getenv("DB_USER")
DB_PASSWORD = os.getenv("DB_PASSWORD")
|
|
@pn.cache()
def get_data():
    """Load the survey_info table into a DataFrame (cached for the session).

    Returns one row per survey answer with the columns selected below.
    The ``year`` column is normalized to strings, with missing or
    non-numeric years becoming '' so downstream widgets treat them as text.
    """
    conn = psycopg2.connect(
        host=DB_HOST, port=DB_PORT,
        dbname=DB_NAME, user=DB_USER, password=DB_PASSWORD,
        sslmode="require",
    )
    try:
        df_ = pd.read_sql_query(
            """
            SELECT id, country, year, section,
                   question_code, question_text,
                   answer_code, answer_text
            FROM survey_info;
            """,
            conn,
        )
    finally:
        # Close even when the query raises; the original leaked the
        # connection on error.
        conn.close()

    if "year" in df_.columns:
        # Coerce to nullable ints first so '2020.0'-style values collapse,
        # then render as strings with '' standing in for missing years.
        df_["year"] = (
            pd.to_numeric(df_["year"], errors="coerce")
            .astype("Int64")
            .astype(str)
            .replace({"<NA>": ""})
        )
    return df_
|
|
# Load the full survey table once at startup (memoized by pn.cache).
df = get_data()
|
|
@pn.cache()
def load_embeddings():
    """Fetch the precomputed embedding checkpoint from S3 (cached).

    Returns the ``(ids, embeddings)`` pair stored in the checkpoint dict.
    """
    bucket = "cgd-embeddings-bucket"
    key = "survey_info_embeddings.pt"

    buffer = io.BytesIO()
    boto3.client("s3").download_fileobj(bucket, key, buffer)
    buffer.seek(0)

    # NOTE(review): torch.load unpickles arbitrary objects; acceptable for a
    # trusted bucket, but consider weights_only=True if the torch version
    # supports it — confirm the checkpoint contents allow it.
    checkpoint = torch.load(buffer, map_location="cpu")
    buffer.close()
    gc.collect()

    return checkpoint["ids"], checkpoint["embeddings"]
|
|
@pn.cache()
def get_st_model():
    """Return a CPU-only MiniLM sentence encoder (cached for the session)."""
    model_name = "sentence-transformers/all-MiniLM-L6-v2"
    return SentenceTransformer(model_name, device="cpu")
|
|
@pn.cache()
def get_semantic_resources():
    """Bundle the encoder with the precomputed (ids, embeddings) pair."""
    ids_list, emb_tensor = load_embeddings()
    return get_st_model(), ids_list, emb_tensor
|
|
| |
| |
| |
# Filter options are derived from the loaded data itself.
country_opts = sorted(df["country"].dropna().unique())
year_opts = sorted(df["year"].dropna().unique())

# 'Score' only exists on semantic-search results; _selected_cols() strips it
# from the widget selection when a non-semantic view is shown.
ALL_COLUMNS = ["country","year","section","question_code","question_text","answer_code","answer_text","Score"]
w_columns = pn.widgets.MultiChoice(
    name="Columns to show",
    options=ALL_COLUMNS,
    value=["country","year","question_text","answer_text"]
)

w_countries = pn.widgets.MultiSelect(name="Countries", options=country_opts)
w_years = pn.widgets.MultiSelect(name="Years", options=year_opts)
w_keyword = pn.widgets.TextInput(name="Keyword Search", placeholder="Search questions or answers with exact string matching")
w_group = pn.widgets.Checkbox(name="Group by Question Text", value=False)
# Disabled until a semantic query is typed (see _toggle_topk_disabled below).
w_topk = pn.widgets.Select(name="Top-K (semantic)", options=[5, 10, 20, 50, 100], value=10, disabled=True)

w_semquery = pn.widgets.TextInput(name="Semantic Query", placeholder="LLM-powered search")
w_search_button = pn.widgets.Button(name="Search", button_type="primary")
w_clear_filters = pn.widgets.Button(name="Clear Filters", button_type="warning")

# Results table; its .value is populated by search()/clear_filters().
result_table = pn.widgets.Tabulator(
    pagination='remote',
    page_size=25,
    sizing_mode="stretch_width",
    layout='fit_columns',
    show_index=False
)
|
|
| |
| |
| |
|
|
| def _group_by_question(df_in: pd.DataFrame) -> pd.DataFrame: |
| if df_in.empty: |
| return pd.DataFrame(columns=["question_text", "Countries", "Years", "Sample Answers"]) |
| tmp = df_in.copy() |
| tmp["year"] = tmp["year"].replace('', pd.NA) |
| grouped = ( |
| tmp.groupby("question_text", dropna=False) |
| .agg({ |
| "country": lambda x: sorted({v for v in x if pd.notna(v)}), |
| "year": lambda x: sorted({str(v) for v in x if pd.notna(v)}), |
| "answer_text": lambda x: list(x.dropna())[:3], |
| }) |
| .reset_index() |
| .rename(columns={"country": "Countries", "year": "Years", "answer_text": "Sample Answers"}) |
| ) |
| return grouped |
| |
def _selected_cols(has_score=False):
    """Return the user-selected display columns, guaranteed non-empty.

    When the current result has no Score column, 'Score' is dropped from
    the widget selection as a side effect (note: writing the widget value
    re-fires its 'value' watcher).
    """
    valid = set(ALL_COLUMNS)
    if not has_score and "Score" in w_columns.value:
        w_columns.value = [c for c in w_columns.value if c != "Score"]
    selection = [c for c in w_columns.value if c in valid]
    return selection or ["country", "year", "question_text", "answer_text"]
|
|
| |
def search(event=None):
    """Filter df by the sidebar widgets and refresh result_table.

    With an empty semantic query, shows the filtered rows (or their
    per-question grouping). Otherwise ranks the filtered rows by cosine
    similarity to the query and shows the top-K with a Score column.
    """
    query = w_semquery.value.strip()

    filt = df.copy()
    if w_countries.value:
        filt = filt[filt["country"].isin(w_countries.value)]
    if w_years.value:
        filt = filt[filt["year"].isin(w_years.value)]
    if w_keyword.value:
        # regex=False: the widget promises "exact string matching", and raw
        # user input containing regex metacharacters (e.g. '(') would raise.
        kw = w_keyword.value
        filt = filt[
            filt["question_text"].str.contains(kw, case=False, na=False, regex=False) |
            filt["answer_text"].str.contains(kw, case=False, na=False, regex=False) |
            filt["question_code"].astype(str).str.contains(kw, case=False, na=False, regex=False)
        ]

    if not query:
        result_table.value = _group_by_question(filt) if w_group.value else filt[_selected_cols(False)]
        return

    model, ids_list, emb_tensor = get_semantic_resources()
    id_to_index = {id_: i for i, id_ in enumerate(ids_list)}
    # Keep only ids that actually have an embedding; matched_ids stays
    # aligned row-for-row with the embedding slice built from it.
    matched_ids = [id_ for id_ in filt["id"] if id_ in id_to_index]
    if not matched_ids:
        result_table.value = (
            _group_by_question(filt.iloc[0:0]) if w_group.value
            else pd.DataFrame(columns=_selected_cols(True))
        )
        return

    matched_embs = emb_tensor[[id_to_index[id_] for id_ in matched_ids]]
    q_vec = model.encode(query, convert_to_tensor=True, device="cpu").cpu()
    sims = util.cos_sim(q_vec, matched_embs)[0]
    top_k = min(int(w_topk.value), len(matched_ids))
    top_vals, top_idx = torch.topk(sims, k=top_k)

    # BUGFIX: top_idx indexes the embedding slice, so map through matched_ids
    # (aligned with it) — the original indexed the unfiltered id list, which
    # diverges whenever an id has no embedding.
    top_ids = [matched_ids[i] for i in top_idx.tolist()]
    sem_rows = filt[filt["id"].isin(top_ids)].copy()
    score_map = dict(zip(top_ids, top_vals.tolist()))
    sem_rows["Score"] = sem_rows["id"].map(score_map)
    sem_rows = sem_rows.sort_values("Score", ascending=False)

    result_table.value = (
        _group_by_question(sem_rows.drop(columns=["Score"])) if w_group.value
        else sem_rows[_selected_cols(True)]
    )
|
|
|
|
def clear_filters(event=None):
    """Reset all filter widgets and restore the default table view."""
    # Writing watched widget values re-fires their watchers (extra search()
    # calls) — redundant but harmless.
    for selector in (w_countries, w_years):
        selector.value = []
    w_keyword.value = ""
    w_semquery.value = ""
    w_topk.disabled = True
    default_cols = ["country", "year", "question_text", "answer_text"]
    result_table.value = df[default_cols].copy()
|
|
w_search_button.on_click(search)
w_clear_filters.on_click(clear_filters)

# Re-run the search whenever a filter/display widget changes.
w_group.param.watch(lambda e: search(), 'value')
w_countries.param.watch(lambda e: search(), 'value')
w_years.param.watch(lambda e: search(), 'value')
w_columns.param.watch(lambda e: search(), 'value')

# Trigger search when Enter is pressed in either text box.
# NOTE(review): 'enter_pressed' is only available on recent Panel TextInput
# releases — confirm against the pinned Panel version.
w_semquery.param.watch(lambda e: search(), 'enter_pressed')
w_keyword.param.watch(lambda e: search(), 'enter_pressed')

def _toggle_topk_disabled(event=None):
    # Top-K only applies to semantic search, so keep it disabled while the
    # semantic query box is empty.
    w_topk.disabled = (w_semquery.value.strip() == '')
_toggle_topk_disabled()
w_semquery.param.watch(lambda e: _toggle_topk_disabled(), 'value')

# Initial view: unfiltered data with the default column set.
result_table.value = df[["country", "year", "question_text", "answer_text"]].copy()
|
|
| |
| |
| |
# Sidebar holds all filter and semantic-search controls.
# NOTE(review): the "## π ..." header strings look mojibake-encoded (likely
# emoji mangled in transit) — verify the file's encoding; left unchanged here.
sidebar = pn.Column(
    "## π Filters",
    w_countries, w_years, w_keyword, w_group, w_columns,
    pn.Spacer(height=20),
    "## π§ Semantic Search",
    w_semquery,
    w_topk,
    w_search_button,
    pn.Spacer(height=20),
    w_clear_filters,
    width=300
)

# Main area: title plus the results table.
main = pn.Column(
    pn.pane.Markdown("## π CGD Question Search"),
    result_table
)

# Serve the app with the FastList template (run via `panel serve`).
pn.template.FastListTemplate(
    title="CGD Survey Explorer",
    sidebar=sidebar,
    main=main,
    theme_toggle=True,
).servable()
|
|