"""Streamlit app for exploring legislation with semantic search and trend charts.

The script loads a FAISS index, a metadata table and a sentence-embedding
model, renders sidebar filters, a "Search & Results" tab (overview metrics,
filtered bill table, semantic search with per-bill impact feedback) and a
"Trends & Insights" tab of Altair charts.
"""

import csv
import json
import os
from datetime import datetime

import altair as alt
import faiss
import numpy as np
import pandas as pd
import streamlit as st
from sentence_transformers import SentenceTransformer

# Config
DB_DIR = "."
FEEDBACK_CSV = os.path.join(DB_DIR, "impact_feedback.csv")
DEFAULT_TOP_K = 10
# Ordered from least to most impactful; impact_threshold() relies on this order.
IMPACT_ORDER = [
    "Not Impactful",
    "Slightly Impactful",
    "Moderately Impactful",
    "Very Impactful",
]

st.set_page_config(
    page_title="IGPA Legislation Explorer",
    layout="wide",
    initial_sidebar_state="expanded",
)


# Loading vector database
@st.cache_resource
def load_vector_db(db_dir: str = DB_DIR):
    """Load the search artifacts stored in *db_dir*.

    Reads ``config.json``, ``faiss_index.bin`` and ``metadata.parquet`` and
    instantiates the SentenceTransformer named by the config key
    ``embedding_model_name``. If the metadata has no ``vec_id`` column, the
    index is reset and an ``index`` column (if produced) is renamed to
    ``vec_id``.

    Returns:
        Tuple ``(index, meta, model, cfg)``.
    """
    with open(os.path.join(db_dir, "config.json"), "r", encoding="utf-8") as f:
        cfg = json.load(f)
    index = faiss.read_index(os.path.join(db_dir, "faiss_index.bin"))
    meta = pd.read_parquet(os.path.join(db_dir, "metadata.parquet"))
    if "vec_id" not in meta.columns:
        meta = meta.reset_index().rename(columns={"index": "vec_id"})
    model = SentenceTransformer(cfg["embedding_model_name"])
    return index, meta, model, cfg


index, meta_df, embed_model, cfg = load_vector_db()

DATE_COL = "status_date_y"
# Unparseable dates become NaT instead of raising.
meta_df[DATE_COL] = pd.to_datetime(meta_df[DATE_COL], errors="coerce")
MIN_DATE = meta_df[DATE_COL].min().date()
MAX_DATE = meta_df[DATE_COL].max().date()

DEFAULT_FILTERS = {
    "intended_beneficiary": "All",
    "policy_domain": "All",
    "impact_selected": "All",
    "category_main": "All",
    "category_sub": "All",
    "status_desc": "All",
    "date_range": (MIN_DATE, MAX_DATE),
}

# Seed session state once so widgets keyed on these names start at defaults.
for key, value in DEFAULT_FILTERS.items():
    if key not in st.session_state:
        st.session_state[key] = value
if "search_results" not in st.session_state:
    st.session_state.search_results = None
if "current_query" not in st.session_state:
    st.session_state.current_query = ""
if "history" not in st.session_state:
    st.session_state.history = []


def embed_query(query: str):
    """Encode *query* into a normalized float32 embedding array."""
    return embed_model.encode(
        [query],
        normalize_embeddings=True,
        convert_to_numpy=True,
    ).astype("float32")


def impact_threshold(level):
    """Return the impact levels at or above *level* in IMPACT_ORDER.

    Returns an empty list when *level* is not a known impact level.
    """
    if level not in IMPACT_ORDER:
        return []
    return IMPACT_ORDER[IMPACT_ORDER.index(level):]


def append_feedback_row(
    bill_id,
    predicted_impact,
    user_response,
    corrected_impact=None,
    path=FEEDBACK_CSV,
):
    """Append one feedback record to the CSV at *path*.

    A header row is written first when the file does not exist yet or exists
    but is empty. Success is reported in the sidebar; any failure is shown
    with ``st.error`` instead of being raised.

    Args:
        bill_id: Identifier of the rated bill.
        predicted_impact: Impact rating shown to the user.
        user_response: The user's verdict (the callers pass "Yes" or "No").
        corrected_impact: Rating the user chose instead, if any.
        path: Destination CSV file.
    """
    try:
        needs_header = not os.path.isfile(path) or os.path.getsize(path) == 0
        with open(path, "a", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            if needs_header:
                writer.writerow(
                    [
                        "timestamp",
                        "bill_id",
                        "predicted_impact",
                        "user_response",
                        "corrected_impact",
                    ]
                )
            writer.writerow(
                [
                    # NOTE(review): utcnow() is deprecated since Python 3.12;
                    # kept so new rows match the timestamp format of rows
                    # already in the file.
                    datetime.utcnow().isoformat(),
                    bill_id,
                    predicted_impact,
                    user_response,
                    corrected_impact or "",
                ]
            )
        st.sidebar.success(f"Feedback saved to: `{path}`")
    except Exception as e:  # UI boundary: report the problem, do not crash
        st.error(f"Failed to save feedback: {e}")


def build_filter_mask(df, intended_beneficiary, policy_domain, impact_selected):
    """Build a boolean row mask for *df* from the active filters.

    Beneficiary, policy domain and impact level come from the arguments;
    category, sub-category, status and date range are read from
    ``st.session_state``. "All" disables the corresponding filter.

    Returns:
        A boolean ``pd.Series`` aligned with ``df.index``.
    """
    mask = pd.Series(True, index=df.index)
    if intended_beneficiary != "All":
        mask &= df["intended_beneficiaries_standardized"] == intended_beneficiary
    if policy_domain != "All":
        mask &= df["policy_domain_standardized"] == policy_domain
    if impact_selected != "All":
        # The impact filter means "this level or higher".
        mask &= df["impact_rating_standardized"].isin(
            impact_threshold(impact_selected)
        )
    if st.session_state.category_main != "All":
        mask &= df["category_main_label"] == st.session_state.category_main
    if st.session_state.category_sub != "All":
        mask &= df["category_sub_label"] == st.session_state.category_sub
    if st.session_state.get("status_desc", "All") != "All":
        mask &= df["status_desc"] == st.session_state.status_desc

    dr = st.session_state.get("date_range")
    if dr:
        # The range widget can hold a tuple with a single date while the user
        # is still picking the second one, so never assume two elements.
        if isinstance(dr, (tuple, list)):
            start, end = dr[0], dr[-1]
        else:
            start = end = dr
        # A single date is treated as "from that date onwards".
        if end == start:
            end = df[DATE_COL].max().date()
        # NOTE: rows whose date is NaT never satisfy between() and are
        # therefore excluded whenever a date range is set.
        mask &= df[DATE_COL].between(pd.to_datetime(start), pd.to_datetime(end))
    return mask


def get_sorted_filter_options(df, col_name):
    """Return ["All"] followed by the non-null values of *col_name*.

    Values are ordered by how often they occur (``value_counts`` order).
    """
    counts = df[col_name].dropna().value_counts()
    return ["All"] + counts.index.tolist()


def reset_filters():
    """Restore every filter in session state to its default and rerun."""
    for key, value in DEFAULT_FILTERS.items():
        st.session_state[key] = value
    st.rerun()


def _join_head(values, sep, limit=5):
    """Join the first *limit* non-null entries of *values* as text."""
    return sep.join(values.dropna().astype(str).head(limit))


def _top_titles(values):
    """Aggregate a group's titles into a '; '-separated preview."""
    return _join_head(values, "; ")


def _top_bill_numbers(values):
    """Aggregate a group's bill numbers into a ', '-separated preview."""
    return _join_head(values, ", ")


def _top_beneficiaries(values):
    """Return the three most frequent values of a group, comma separated."""
    return ", ".join(map(str, values.value_counts().head(3).index))


def _latest_date(values):
    """Format a group's latest date as YYYY-MM-DD, or '' if it has none."""
    latest = values.max()
    # max() is NaT when every date in the group failed to parse.
    return "" if pd.isna(latest) else latest.strftime("%Y-%m-%d")


# Filters
with st.sidebar:
    st.header("Filters")

    if st.button("Reset Filters"):
        reset_filters()

    intended_beneficiary = st.selectbox(
        "Intended Beneficiary",
        get_sorted_filter_options(meta_df, "intended_beneficiaries_standardized"),
        key="intended_beneficiary",
    )
    policy_domain = st.selectbox(
        "Policy Area",
        get_sorted_filter_options(meta_df, "policy_domain_standardized"),
        key="policy_domain",
    )
    impact_selected = st.selectbox(
        "Impact Rating (≥ Selected Level)",
        ["All"] + IMPACT_ORDER,
        key="impact_selected",
    )
    category_main = st.selectbox(
        "Category",
        get_sorted_filter_options(meta_df, "category_main_label"),
        key="category_main",
    )
    category_sub = st.selectbox(
        "Sub Category",
        get_sorted_filter_options(meta_df, "category_sub_label"),
        key="category_sub",
    )
    top_k = st.slider("Number of results", 5, 50, DEFAULT_TOP_K, 5)
    status_desc = st.selectbox(
        "Bill Status",
        ["All"] + sorted(meta_df["status_desc"].dropna().unique().tolist()),
        key="status_desc",
    )

    st.subheader("Time Filter")
    min_date = MIN_DATE
    max_date = MAX_DATE
    default_value = st.session_state.get("date_range", (min_date, max_date))
    if isinstance(default_value, (tuple, list)):
        if len(default_value) == 2:
            start, end = default_value
        else:
            start = end = default_value[0]
    else:
        start = end = default_value
    st.date_input(
        "Status Date Range",
        value=(start, end),
        min_value=min_date,
        max_value=max_date,
        key="date_range",
    )

    # NOTE(review): rendering the feedback summary inside the sidebar is
    # assumed from context — confirm the intended placement.
    if os.path.exists(FEEDBACK_CSV):
        try:
            df_feedback = pd.read_csv(FEEDBACK_CSV)
            with open(FEEDBACK_CSV, "rb") as fh:
                feedback_bytes = fh.read()
        except pd.errors.EmptyDataError:
            st.info("Feedback CSV ready (empty)")
        except (OSError, pd.errors.ParserError) as exc:
            st.warning(f"Could not read feedback CSV: {exc}")
        else:
            st.info(f" Feedback records: {len(df_feedback)}")
            # Rendered directly: a download button nested under st.button
            # disappears on the rerun that follows the first click.
            st.download_button(
                label="Download impact_feedback.csv",
                data=feedback_bytes,
                file_name="impact_feedback.csv",
                mime="text/csv",
            )

filtered_df = meta_df[
    build_filter_mask(
        meta_df,
        st.session_state.intended_beneficiary,
        st.session_state.policy_domain,
        st.session_state.impact_selected,
    )
]

tab_search, tab_trends = st.tabs(["Search & Results", "Trends & Insights"])

# Search Tab
with tab_search:
    st.title("IGPA Legislation Explorer")

    # Overview
    col1, col2, col3, col4 = st.columns(4)
    with col1:
        st.metric("Total Bills", len(filtered_df))
    with col2:
        st.metric(
            "Policy Domains",
            filtered_df["policy_domain_standardized"].nunique(),
        )
    with col3:
        st.metric(
            "Beneficiary Groups",
            filtered_df["intended_beneficiaries_standardized"].nunique(),
        )
    with col4:
        impact_counts = (
            filtered_df["impact_rating_standardized"]
            .dropna()
            .value_counts()
            .reindex(IMPACT_ORDER, fill_value=0)
        )
        st.metric("Impact Breakdown", len(filtered_df))
        # NOTE(review): the leading/trailing fragments contain only a newline;
        # any HTML wrapper that unsafe_allow_html was meant for is not
        # present — confirm whether markup should be restored.
        st.markdown(
            "\n"
            f"Very Impactful: {impact_counts['Very Impactful']} | "
            f"Moderately: {impact_counts['Moderately Impactful']} | "
            f"Slightly: {impact_counts['Slightly Impactful']} | "
            f"Not: {impact_counts['Not Impactful']}"
            "\n",
            unsafe_allow_html=True,
        )

    # Most Impacted Beneficiary Categories
    st.subheader("Most Impacted Beneficiary Categories")
    impact_df = (
        filtered_df.dropna(subset=["beneficiary_category", "impact_rating_score"])
        .groupby("beneficiary_category")
        .agg(
            avg_impact=("impact_rating_score", "mean"),
            bills=("bill_id", "count"),
            top_bills=("title", _top_titles),
            top_beneficiaries=(
                "intended_beneficiaries_standardized",
                _top_beneficiaries,
            ),
        )
        .reset_index()
        .sort_values("avg_impact", ascending=False)
        .head(10)
    )
    if not impact_df.empty:
        st.altair_chart(
            alt.Chart(impact_df)
            .mark_bar()
            .encode(
                x=alt.X(
                    "beneficiary_category:N",
                    sort="-y",
                    title="Beneficiary Category",
                ),
                y=alt.Y("avg_impact:Q", title="Average Impact Score"),
                color=alt.Color(
                    "avg_impact:Q",
                    scale=alt.Scale(domain=[0, 4], range=["#FFF176", "#E53935"]),
                    legend=alt.Legend(title="Impact Severity"),
                ),
                tooltip=[
                    alt.Tooltip("beneficiary_category:N", title="Beneficiary"),
                    alt.Tooltip("avg_impact:Q", format=".2f", title="Average Impact"),
                    alt.Tooltip("bills:Q", title="Number of Bills"),
                    alt.Tooltip("top_bills:N", title="Top Bills"),
                    alt.Tooltip("top_beneficiaries:N", title="Top Beneficiaries"),
                ],
            )
            .properties(height=350),
            use_container_width=True,
        )

    # Bills from Filters
    st.subheader("Bills Matching Selected Filters")
    display_cols = {
        "bill_number": "Bill Number",
        "title": "Title",
        "description": "Description",
        "policy_domain_standardized": "Policy Domain",
        "category_main_label": "Category",
        "intent_standardized": "Intent",
        "legislative_goal_standardized": "Legislative Goal",
        "beneficiary_category": "Beneficiary Group",
        "intended_beneficiaries_standardized": "Intended Beneficiaries",
        "potential_impact_raw": "Potential Impact",
        "impact_rating_standardized": "Impact Rating",
        "status_desc": "Status",
        "full_text_url": "Bill Link",
    }
    # Only show the columns that actually exist in the metadata.
    available_cols = {
        k: v for k, v in display_cols.items() if k in filtered_df.columns
    }
    filter_bill_df = (
        filtered_df[list(available_cols.keys())]
        .rename(columns=available_cols)
        .copy()
    )
    st.dataframe(
        filter_bill_df,
        use_container_width=True,
        column_config={
            "Bill Link": st.column_config.LinkColumn(
                label="Bill Link",
                display_text="Open Bill",
            )
        },
    )

    st.markdown("---")

    # Search Bills
    st.subheader("Search Bills")
    query = st.text_area(
        "Ask a question about legislation",
        value=st.session_state.current_query,
        height=80,
        placeholder="Example: bills related to funding",
        key="search_query_input",
    )
    search_clicked = st.button("Search", key="search_button")

    if search_clicked and query.strip():
        st.session_state.current_query = query
        st.session_state.history.append({"query": query})
        q_vec = embed_query(query)
        # Over-fetch so that enough hits survive the sidebar filters.
        n_search = min(len(meta_df), top_k * 5)
        scores, ids = index.search(q_vec, n_search)
        ids, scores = ids[0], scores[0]
        # NOTE(review): FAISS ids are matched against DataFrame index labels;
        # presumably these coincide with the vectors' positions — verify
        # against how the index was built (a `vec_id` column also exists).
        allowed = set(filtered_df.index)
        kept = [(i, s) for i, s in zip(ids, scores) if i in allowed][:top_k]
        if not kept:
            st.warning("No results found.")
            st.session_state.search_results = None
        else:
            results = meta_df.loc[[i for i, _ in kept]].copy()
            results["similarity"] = [s for _, s in kept]
            st.session_state.search_results = results

    if st.session_state.search_results is not None:
        results = st.session_state.search_results

        # Filtered Results Table
        st.subheader("Filtered Results Table")
        review_cols = [
            "bill_number",
            "title",
            "description",
            "potential_impact_raw",
            "increasing_aspects_standardized",
            "decreasing_aspects_standardized",
            "similarity",
            "full_text_url",
        ]
        review_df = results[[c for c in review_cols if c in results.columns]].copy()
        review_df.rename(
            columns={
                "bill_number": "Bill Number",
                "title": "Title",
                "description": "Description",
                "potential_impact_raw": "Potential Impact",
                "increasing_aspects_standardized": "Increasing Aspects",
                "decreasing_aspects_standardized": "Decreasing Aspects",
                "similarity": "Score",
                "full_text_url": "Bill URL",
            },
            inplace=True,
        )
        st.dataframe(
            review_df,
            use_container_width=True,
            column_config={
                "Bill URL": st.column_config.LinkColumn(
                    "ILGA URL",
                    display_text="Open bill",
                )
            },
        )

        st.markdown("---")
        st.subheader("Filtered Results")

        # Loop-invariant: the extra "*_standardized" columns listed under
        # "More Details", minus those already shown elsewhere.
        std_cols = [
            c
            for c in results.columns
            if c.endswith("_standardized")
            and c
            not in [
                "impact_rating_standardized",
                "increasing_aspects_standardized",
                "decreasing_aspects_standardized",
                "original_law_standardized",
            ]
        ]
        sim_cols = [
            c
            for c in ["bill_number", "title", "description", "full_text_url"]
            if c in results.columns
        ]

        for idx, row in results.iterrows():
            with st.container():
                st.markdown(f"### Bill Number: {row['bill_number']}")
                st.markdown(f"**Title:** {row['title']}")
                st.write(row["description"])
                if pd.notna(row.get("category_main_label")):
                    st.write(f"**Main Category**: {row['category_main_label']}")
                if pd.notna(row.get("category_sub_label")):
                    st.write(f"**Sub Category**: {row['category_sub_label']}")
                if pd.notna(row.get("llama_summary_raw")):
                    st.markdown(f"**LLaMA Summary:** {row['llama_summary_raw']}")

                info_text = (
                    f"Session: {row.get('session','')} • "
                    f"Chamber: {row.get('chamber','')} • "
                    f"Impact: {row.get('impact_rating_standardized','')} • "
                    f"Beneficiaries: {row.get('intended_beneficiaries_standardized','')} • "
                    f"Domain: {row.get('policy_domain_standardized','')} • "
                    f"Similarity: {row.get('similarity'):.3f}"
                )
                st.caption(info_text)

                if pd.notna(row.get("full_text_url")):
                    # A plain markdown link needs no raw-HTML rendering.
                    st.markdown(f"[🔗 View Full Bill]({row['full_text_url']})")

                with st.expander("More Details"):
                    for c in std_cols:
                        val = row.get(c)
                        if pd.notna(val) and str(val).strip():
                            label = (
                                c.replace("_standardized", "")
                                .replace("_", " ")
                                .title()
                            )
                            st.write(f"**{label}**: {val}")

                with st.expander("Similar Bills"):
                    # Top other hits of the same search; the bill itself is
                    # left out of its own "similar" list.
                    sim_df = results.drop(index=idx).head(5)[sim_cols].copy()
                    st.dataframe(
                        sim_df,
                        use_container_width=True,
                        column_config={
                            "full_text_url": st.column_config.LinkColumn(
                                "Bill Link",
                                display_text="Open",
                            )
                        },
                    )

                # Impact rating feedback
                with st.expander("👍👎 Rate Impact Accuracy", expanded=False):
                    st.markdown("**Is this impact rating accurate?**")
                    predicted_impact = row.get("impact_rating_standardized", "")
                    bill_id_safe = str(row.get("bill_id", idx))
                    done_key = f"feedback_done_{bill_id_safe}"
                    correcting_key = f"show_corrected_{bill_id_safe}"

                    # Check if feedback was already submitted for this bill
                    if st.session_state.get(done_key, False):
                        st.success("Thank you for your feedback!")
                        st.caption(
                            f"Bill: {row.get('bill_number', 'N/A')} | Saved to impact_feedback.csv"
                        )
                    else:
                        yes_col, no_col = st.columns(2)
                        with yes_col:
                            if st.button(
                                "👍 **Yes - Accurate**",
                                key=f"yes_{bill_id_safe}",
                                use_container_width=True,
                            ):
                                append_feedback_row(
                                    bill_id=bill_id_safe,
                                    predicted_impact=predicted_impact,
                                    user_response="Yes",
                                    corrected_impact=None,
                                )
                                st.session_state[done_key] = True
                                st.sidebar.success(
                                    f"Feedback saved for {row.get('bill_number', bill_id_safe)}"
                                )
                                st.rerun()
                        with no_col:
                            if st.button(
                                "👎 **No - Incorrect**",
                                key=f"no_{bill_id_safe}",
                                use_container_width=True,
                            ):
                                # Reveal the correction form on the next run.
                                st.session_state[correcting_key] = True
                                st.rerun()

                        if st.session_state.get(correcting_key, False):
                            st.info("**What should the impact rating be instead?**")
                            corrected_value = st.selectbox(
                                "**Correct impact rating**",
                                IMPACT_ORDER,
                                key=f"corrected_{bill_id_safe}",
                            )
                            col_submit, col_cancel = st.columns([3, 1])
                            with col_submit:
                                if st.button(
                                    "**Submit Feedback**",
                                    key=f"submit_{bill_id_safe}",
                                    type="primary",
                                ):
                                    append_feedback_row(
                                        bill_id=bill_id_safe,
                                        predicted_impact=predicted_impact,
                                        user_response="No",
                                        corrected_impact=corrected_value,
                                    )
                                    st.session_state[done_key] = True
                                    st.session_state[correcting_key] = False
                                    st.sidebar.success(
                                        f"Feedback saved for {row.get('bill_number', bill_id_safe)}"
                                    )
                                    st.rerun()
                            with col_cancel:
                                if st.button("Cancel", key=f"cancel_{bill_id_safe}"):
                                    st.session_state[correcting_key] = False
                                    st.rerun()

# Search History (last five queries, newest first)
with st.sidebar.expander("Search History"):
    for i, item in enumerate(reversed(st.session_state.history[-5:]), 1):
        st.write(f"{i}. {item.get('query','')}")

# TRENDS TAB
with tab_trends:
    st.subheader("Trends & Insights")

    # Key Insights
    rated_df = filtered_df[filtered_df["impact_rating_standardized"].notna()]
    top_policy = filtered_df["policy_domain_standardized"].value_counts().head(1)
    top_beneficiaries = filtered_df["beneficiary_category"].value_counts().head(1)
    strategy_impact = (
        rated_df.groupby("legislative_strategy_standardized")[
            "impact_rating_standardized"
        ].apply(lambda x: (x == "Very Impactful").sum())
    )
    avg_impact_ben = (
        filtered_df.dropna(subset=["impact_rating_score"])
        .groupby("beneficiary_category")["impact_rating_score"]
        .mean()
        .sort_values(ascending=False)
    )
    total_bills = len(filtered_df)
    total_high_impact = (
        filtered_df["impact_rating_standardized"] == "Very Impactful"
    ).sum()

    st.markdown("### Key Insights")
    st.write(f"**Total Bills Considered:** {total_bills}")
    st.write(f"**Total Very Impactful Bills:** {total_high_impact}")
    st.write(
        f"**Most Active Policy Domain:** {top_policy.index[0]} ({top_policy.iloc[0]} bills)"
        if not top_policy.empty
        else "No data"
    )
    st.write(
        f"**Most Benefited Group:** {top_beneficiaries.index[0]} ({top_beneficiaries.iloc[0]} bills)"
        if not top_beneficiaries.empty
        else "No data"
    )
    st.write(
        f"**Strategy Producing Most Very Impactful Bills:** {strategy_impact.idxmax() if not strategy_impact.empty else 'N/A'}"
    )
    st.write(
        f"**Highest Average Impact (Beneficiary):** {avg_impact_ben.index[0]} ({avg_impact_ben.iloc[0]:.2f})"
        if not avg_impact_ben.empty
        else "N/A"
    )

    st.markdown("---")

    col1, col2 = st.columns(2)

    # Policy Domain
    with col1:
        st.markdown("### Policy Domain Activity")
        policy_agg = (
            filtered_df.groupby("policy_domain_standardized")
            .agg(
                Count=("bill_id", "count"),
                avg_impact=("impact_rating_score", "mean"),
                top_bills=("title", _top_titles),
                top_beneficiaries=(
                    "intended_beneficiaries_standardized",
                    _top_beneficiaries,
                ),
                recent_date=(DATE_COL, _latest_date),
                bill_numbers=("bill_number", _top_bill_numbers),
            )
            .reset_index()
            .rename(columns={"policy_domain_standardized": "Policy Domain"})
        )
        policy_chart = (
            alt.Chart(policy_agg)
            .mark_bar()
            .encode(
                x=alt.X("Policy Domain:N", sort="-y", title="Policy Domain"),
                y=alt.Y("Count:Q", title="Number of Bills"),
                color=alt.Color(
                    "Count:Q",
                    scale=alt.Scale(scheme="reds"),
                    legend=None,
                ),
                tooltip=[
                    alt.Tooltip("Policy Domain:N"),
                    alt.Tooltip("Count:Q", title="Number of Bills"),
                    alt.Tooltip("avg_impact:Q", format=".2f", title="Average Impact"),
                    alt.Tooltip("top_bills:N", title="Top Bills"),
                    alt.Tooltip("top_beneficiaries:N", title="Top Beneficiaries"),
                    alt.Tooltip("recent_date:N", title="Most Recent Bill"),
                    alt.Tooltip("bill_numbers:N", title="Bill Numbers"),
                ],
            )
            .properties(height=400)
        )
        st.altair_chart(policy_chart, use_container_width=True)

    # Impact Distribution
    with col2:
        st.markdown("### Impact Distribution")
        impact_dist = (
            rated_df["impact_rating_standardized"]
            .value_counts()
            .reindex(IMPACT_ORDER, fill_value=0)
            .reset_index()
        )
        impact_dist.columns = ["Impact Level", "Count"]
        impact_chart = (
            alt.Chart(impact_dist)
            .mark_bar()
            .encode(
                x=alt.X("Impact Level:N", sort=IMPACT_ORDER),
                y=alt.Y("Count:Q"),
                color=alt.Color("Count:Q", scale=alt.Scale(scheme="reds")),
                tooltip=[
                    alt.Tooltip("Impact Level:N"),
                    alt.Tooltip("Count:Q"),
                ],
            )
            .properties(height=300)
        )
        st.altair_chart(impact_chart, use_container_width=True)

    # Strategy High Impact
    st.markdown("### Legislative Strategy: Very Impactful Bills")
    strategy_high_impact = (
        rated_df.groupby("legislative_strategy_standardized")
        .agg(
            Very_Impactful_Bills=(
                "impact_rating_standardized",
                lambda x: (x == "Very Impactful").sum(),
            ),
            top_bills=("title", _top_titles),
            top_beneficiaries=(
                "intended_beneficiaries_standardized",
                _top_beneficiaries,
            ),
            recent_date=(DATE_COL, _latest_date),
        )
        .reset_index()
        .rename(columns={"legislative_strategy_standardized": "Strategy"})
    )
    strategy_chart = (
        alt.Chart(strategy_high_impact)
        .mark_bar()
        .encode(
            x=alt.X("Strategy:N", sort="-y", title="Strategy"),
            y=alt.Y("Very_Impactful_Bills:Q", title="Very Impactful Bills"),
            color=alt.Color(
                "Very_Impactful_Bills:Q",
                scale=alt.Scale(scheme="orangered"),
            ),
            tooltip=[
                alt.Tooltip("Strategy:N"),
                alt.Tooltip("Very_Impactful_Bills:Q"),
                alt.Tooltip("top_bills:N", title="Top Bills"),
                alt.Tooltip("top_beneficiaries:N", title="Top Beneficiaries"),
                alt.Tooltip("recent_date:N", title="Most Recent Bill"),
            ],
        )
        .properties(height=400)
    )
    st.altair_chart(strategy_chart, use_container_width=True)

    # Impact by Category
    st.markdown("### Impact by Category")
    impact_cat = (
        filtered_df[
            filtered_df["impact_rating_standardized"].notna()
            & filtered_df["category_main_label"].notna()
        ]
        .groupby(["category_main_label", "impact_rating_standardized"])
        .agg(
            Count=("bill_id", "count"),
            avg_impact=("impact_rating_score", "mean"),
            top_bills=("title", _top_titles),
            top_beneficiaries=(
                "intended_beneficiaries_standardized",
                _top_beneficiaries,
            ),
            recent_date=(DATE_COL, _latest_date),
            bill_numbers=("bill_number", _top_bill_numbers),
        )
        .reset_index()
    )
    if impact_cat.empty:
        st.write("No data available for impact by category.")
    else:
        # Keep the chart readable: only the 15 categories with most bills.
        top_categories = (
            impact_cat.groupby("category_main_label")["Count"]
            .sum()
            .sort_values(ascending=False)
            .head(15)
            .index.tolist()
        )
        impact_cat_top = impact_cat[
            impact_cat["category_main_label"].isin(top_categories)
        ]
        impact_cat_chart = (
            alt.Chart(impact_cat_top)
            .mark_bar()
            .encode(
                y=alt.Y(
                    "category_main_label:N",
                    sort=top_categories,
                    title="Category",
                ),
                x=alt.X("Count:Q", stack="zero", title="Number of Bills"),
                color=alt.Color(
                    "impact_rating_standardized:N",
                    sort=IMPACT_ORDER,
                    scale=alt.Scale(scheme="reds"),
                    title="Impact Rating",
                ),
                tooltip=[
                    alt.Tooltip("category_main_label:N", title="Category"),
                    alt.Tooltip("impact_rating_standardized:N", title="Impact Rating"),
                    alt.Tooltip("Count:Q", title="Number of Bills"),
                    alt.Tooltip("avg_impact:Q", format=".2f", title="Average Impact"),
                    alt.Tooltip("top_bills:N", title="Top Bills"),
                    alt.Tooltip("top_beneficiaries:N", title="Top Beneficiaries"),
                    alt.Tooltip("recent_date:N", title="Most Recent Bill"),
                    alt.Tooltip("bill_numbers:N", title="Bill Numbers"),
                ],
            )
            .properties(height=400)
        )
        st.altair_chart(impact_cat_chart, use_container_width=True)

    # Beneficiary Treemap
    st.markdown("### Beneficiary Coverage & Average Impact")
    ben_treemap_df = (
        filtered_df.dropna(subset=["beneficiary_category", "impact_rating_score"])
        .groupby("beneficiary_category")
        .agg(
            total_bills=("bill_id", "count"),
            avg_impact=("impact_rating_score", "mean"),
            top_bills=("title", _top_titles),
            recent_date=(DATE_COL, _latest_date),
            bill_numbers=("bill_number", _top_bill_numbers),
        )
        .reset_index()
    )
    if not ben_treemap_df.empty:
        treemap = (
            alt.Chart(ben_treemap_df)
            .mark_rect()
            .encode(
                x=alt.X("total_bills:Q", title="Number of Bills"),
                y=alt.Y(
                    "beneficiary_category:N",
                    sort="-x",
                    title="Beneficiary Category",
                ),
                size="total_bills:Q",
                color=alt.Color(
                    "avg_impact:Q",
                    scale=alt.Scale(domain=[0, 4], range=["#FFF176", "#E53935"]),
                    legend=alt.Legend(title="Average Impact Score"),
                ),
                tooltip=[
                    alt.Tooltip("beneficiary_category:N", title="Beneficiary"),
                    alt.Tooltip("total_bills:Q", title="Number of Bills"),
                    alt.Tooltip("avg_impact:Q", format=".2f", title="Average Impact"),
                    alt.Tooltip("top_bills:N", title="Top Bills"),
                    alt.Tooltip("recent_date:N", title="Most Recent Bill"),
                    alt.Tooltip("bill_numbers:N", title="Bill Numbers"),
                ],
            )
            .properties(height=400)
        )
        st.altair_chart(treemap, use_container_width=True)
    else:
        st.write("No beneficiary impact data available for selected filters.")