chatbot-model-bills / src /components /render_tab_inspect.py
penguinsfly's picture
reorganize into component files, add info, and add pca plot
fd4a87f verified
import numpy as np
import pandas as pd
import streamlit as st
from annotated_text import annotated_text
def render(
    tab,
    df,
    seg_df,
    text_df,
    disp_sim_df,
    min_nwords,
    thres_ratio,
    ratio_bg_cmap = 'Reds',
    inspect_disp_max_cols = 4,
    thres_fuzzy_ratio_valid = 50,
    disp_char_jitter = 100,
    disp_text_flank = ' [...] ',
    disp_text_color = '#fc9272',
):
    """Render the inspection view for the bill-model pair selected in ``disp_sim_df``.

    Does nothing unless ``st.session_state['bill_model_pair']`` holds a row
    selection (read from its ``['selection']['rows']``). For the selected row it
    writes into ``tab``:

    * a bar chart of ``progress_pct_src_sim`` (one bar per bill version);
    * a table with one row per model sentence whose ``ratio`` is above
      ``thres_ratio`` in at least one bill version and one column per bill
      version (``target__doc_rank``); its single-row selection is stored under
      the ``'selected_model_sentence'`` key;
    * the per-version detail of the selected sentence, via
      ``__inspect_single_sentence__``.

    If no model sentence passes ``thres_ratio`` for the pair, only an info
    message is shown.

    Args:
        tab: Streamlit container (e.g. a tab) that receives every element.
        df: Sentence-level similarity rows; must provide ``source__model``,
            ``target__bill_id``, ``source__est_nwords``,
            ``source__model_sent_idx``, ``ratio``, ``target__bill``,
            ``target__doc_id``, ``target__doc_rank`` and the ``*postproc*``
            offset columns.
        seg_df: Model sentences with ``model``, ``section_label``,
            ``model_sent_idx`` and ``sentence`` columns.
        text_df: Bill version texts (``doc_id``, ``text``), forwarded to the
            detail view.
        disp_sim_df: The table backing the ``'bill_model_pair'`` selection; must
            provide ``model``, ``bill_id`` and ``progress_pct_src_sim`` and have
            exactly one row per model/bill pair.
        min_nwords: Minimum ``source__est_nwords`` for a model sentence to be used.
        thres_ratio: Strict lower bound on ``ratio`` for a sentence to be listed;
            also the lower end of the table's color scale (upper end is 100).
        ratio_bg_cmap: Colormap name passed to ``Styler.background_gradient``.
        inspect_disp_max_cols, thres_fuzzy_ratio_valid, disp_char_jitter,
        disp_text_flank, disp_text_color: Forwarded to
            ``__inspect_single_sentence__``.

    Returns:
        None.
    """
    if 'bill_model_pair' not in st.session_state:
        return

    # Get selection. The row list is empty when nothing is selected, and a kept
    # index may point past the end of `disp_sim_df` after it was re-filtered
    # (fall back to the first row, as `__inspect_single_sentence__` does).
    sel_bm_rows = st.session_state['bill_model_pair']['selection']['rows']
    if not sel_bm_rows or len(disp_sim_df) == 0:
        return
    sel_bm_idx = sel_bm_rows[0]
    if sel_bm_idx >= len(disp_sim_df):
        sel_bm_idx = 0
    sel_bm_row = disp_sim_df.iloc[sel_bm_idx]
    sel_model, sel_bill = sel_bm_row['model'], sel_bm_row['bill_id']

    # Get selected bill with high sim. sentences
    sel_bm_df = (
        df.query(
            'source__model == @sel_model '
            'and target__bill_id == @sel_bill '
            'and source__est_nwords >= @min_nwords'
        )
        .reset_index(drop=True)
    )
    # Model sentences exceeding `thres_ratio` in at least one bill version
    idx_highsentidx = (
        sel_bm_df.query('ratio > @thres_ratio')
        ['source__model_sent_idx'].unique()
    )
    # All rows (every bill version) of those sentences, joined with the
    # sentence text and section label from `seg_df`
    sel_sim_seg_df = (
        sel_bm_df.query(
            'source__model_sent_idx in @idx_highsentidx'
        )
        .filter(
            regex='ratio|postproc|source__model(_sent_idx)?$|target__bill|target__doc_id|target__doc_rank'
        )
        .merge(
            seg_df.filter([
                'model','section_label','model_sent_idx','sentence'
            ]).add_prefix('source__'),
            how='left'
        )
    )

    # Nothing above threshold: report it instead of crashing, since the bill
    # name lookup and the pivot below both need at least one row.
    if sel_sim_seg_df.empty:
        tab.info(
            f'No sentence from model {sel_model} has a similarity above '
            f'{thres_ratio} with bill {sel_bill}.'
        )
        return

    assert sel_sim_seg_df['target__bill'].nunique() == 1
    sel_bill_name = sel_sim_seg_df['target__bill'].iloc[0]

    # Display table for high sim. model sentences: one row per model sentence,
    # one column per bill version (`target__doc_rank`), values are `ratio`
    disp_ver_seg_df = (
        sel_sim_seg_df
        .rename(columns={
            'source__model_sent_idx': ('model', 'sentence id'),
            'source__section_label': ('model', 'section'),
            'source__sentence': ('model', 'sentence')
        })
        .pivot(
            index=[
                ('model', 'sentence id'),
                ('model', 'section'),
                ('model', 'sentence'),
            ],
            columns='target__doc_rank',
            values=['ratio']
        )
        .add_prefix('bill ver. ')
        .rename(
            columns={'bill ver. ratio': 'similarity % with model sentence'}
        )
        .reset_index()
    )
    disp_ver_seg_df.columns.names = [None,None]

    # Create sub components
    tab.header(f'Inspect similarity between {sel_model} model and {sel_bill_name} bill')
    selseg_col1, selseg_col2 = tab.columns([1, 2])

    # Visualize whole-bill similarity across available bill progress
    selseg_col1.subheader('Bill-model similarity across versions')
    sel_bill_prog = (
        disp_sim_df.query('bill_id == @sel_bill and model == @sel_model')
    )
    assert len(sel_bill_prog) == 1
    sel_bill_prog = sel_bill_prog.iloc[0]['progress_pct_src_sim']
    # A single-version bill may store a scalar instead of a sequence
    if not hasattr(sel_bill_prog, '__len__'):
        sel_bill_prog = [sel_bill_prog]
    sel_bill_nvers = len(sel_bill_prog)
    selseg_col1.bar_chart(
        pd.DataFrame({
            'version': np.arange(1, sel_bill_nvers + 1),
            '% model sentences in bill': sel_bill_prog
        }),
        x='version',
        y='% model sentences in bill'
    )

    # Table model sentences with single-row selection to inspect further
    selseg_col2.subheader(f'Sentences from model {sel_model} found in bill')
    selseg_col2.text(
        'Note: Select one of the model sentences from table below.'
    )
    selseg_col2.dataframe(
        data=(
            disp_ver_seg_df.style
            .background_gradient(
                cmap=ratio_bg_cmap,
                vmin=thres_ratio,
                vmax=100.0,
                subset=disp_ver_seg_df.filter(regex='bill').columns
            )
            .format(precision=1)
        ),
        selection_mode='single-row-required',
        key='selected_model_sentence',
        on_select="rerun",
    )

    # Display model sentence and the potential corresponding bill's sentences per version
    __inspect_single_sentence__(
        tab,
        text_df,
        sel_sim_seg_df,
        disp_ver_seg_df,
        sel_bill_nvers,
        inspect_disp_max_cols,
        thres_fuzzy_ratio_valid,
        disp_char_jitter,
        disp_text_flank,
        disp_text_color,
    )
def __inspect_single_sentence__(
    tab,
    text_df,
    sel_sim_seg_df,
    disp_ver_seg_df,
    sel_bill_nvers,
    inspect_disp_max_cols = 4,
    thres_fuzzy_ratio_valid = 50,
    disp_char_jitter = 100,
    disp_text_flank = ' [...] ',
    disp_text_color = '#fc9272',
):
    """Show the selected model sentence next to its match in each bill version.

    Does nothing unless ``st.session_state['selected_model_sentence']`` holds a
    row selection of ``disp_ver_seg_df``. Display column 0 shows the model
    sentence. Columns 1..``sel_bill_nvers`` show, for the bill version with that
    ``target__doc_rank``, one of two things:

    * the matched span of the bill text, highlighted, with up to
      ``disp_char_jitter`` characters of context on each side;
    * a placeholder, when there is no row for that version or its ``ratio`` is
      not above ``thres_fuzzy_ratio_valid``.

    Args:
        tab: Streamlit container that receives the elements.
        text_df: Bill version texts with ``doc_id`` and ``text`` columns.
        sel_sim_seg_df: Similarity rows for the selected bill-model pair, with
            ``source__model_sent_idx``, ``target__doc_rank``, ``target__doc_id``,
            ``ratio`` and ``target_postproc_start``/``target_postproc_end``.
        disp_ver_seg_df: The table the selection refers to; its
            ``('model', 'sentence id')`` and ``('model', 'sentence')`` columns
            identify the selected sentence.
        sel_bill_nvers: Number of bill versions to show.
        inspect_disp_max_cols: Number of display columns per grid row.
        thres_fuzzy_ratio_valid: Matches with ``ratio`` at or below this are hidden.
        disp_char_jitter: Maximum context characters shown on each side of a match.
        disp_text_flank: String added before and after each displayed excerpt.
        disp_text_color: Color passed to ``annotated_text`` for the matched span.

    Returns:
        None.
    """
    if 'selected_model_sentence' not in st.session_state:
        return

    # Get selection; nothing to show without a selected row or without rows at all
    sel_state_rows = st.session_state['selected_model_sentence']['selection']['rows']
    if not sel_state_rows or len(disp_ver_seg_df) == 0:
        return
    sel_state_ms_idx = sel_state_rows[0]
    # The stored index may be stale (e.g. left over from a previously shown,
    # longer table); fall back to the first row
    if sel_state_ms_idx >= len(disp_ver_seg_df):
        sel_state_ms_idx = 0
    sel_ms_row = disp_ver_seg_df.iloc[sel_state_ms_idx]
    sel_ms_idx = sel_ms_row[('model', 'sentence id')]
    sel_ms = sel_ms_row[('model', 'sentence')]

    tab.subheader(f'Selected model sentence id {sel_ms_idx}, and found occurences across bill versions')
    tab.text(
        'Note: The annotation package currently messes with spacing so currently there may be disjoint sentences in show bill texts. '
        'Typically at least 70% is where the similarity should be considered. Disregard stuff under that. '
        'Additionally, the highlighted bill sentence may be a bit off sometimes to save processing time.'
    )

    # Create columns to display bill's segments similar to model sentence:
    # slot 0 is the model sentence, slots 1..sel_bill_nvers the bill versions,
    # laid out in grid rows of `inspect_disp_max_cols` columns
    num_dispsent_cols = sel_bill_nvers + 1
    dispsent_cols = []
    for _ in range(int(np.ceil(num_dispsent_cols / inspect_disp_max_cols))):
        dispsent_cols.extend(tab.columns([1] * inspect_disp_max_cols))

    for i_dsc, ds_col in enumerate(dispsent_cols):
        # Display model sentence first
        if i_dsc == 0:
            ds_col.badge("Model sentence", color="blue")
            ds_col.text(sel_ms)
            continue

        # Do nothing when columns exceed number of bill versions (grid padding)
        if i_dsc > sel_bill_nvers:
            continue

        # `i_dsc` between 1 and `sel_bill_nvers` are the bill versions
        sel_tgt_row = sel_sim_seg_df.query(
            'source__model_sent_idx == @sel_ms_idx '
            'and target__doc_rank == @i_dsc'
        )

        # None found
        if len(sel_tgt_row) == 0:
            ds_col.badge(
                f"Bill ver. {i_dsc} sentence",
                color="red"
            )
            ds_col.markdown('*(no fuzzy equiv. found)*')
            continue

        # Get version and text
        assert len(sel_tgt_row) == 1
        sel_tgt_row = sel_tgt_row.iloc[0]
        sel_disp_doc_id = sel_tgt_row['target__doc_id']
        sel_bill_ver_model_ratio = sel_tgt_row['ratio']
        ds_col.badge(
            f"Bill ver. {i_dsc} sentence" + (
                f" ({sel_bill_ver_model_ratio:.1f}%)"
                if sel_bill_ver_model_ratio > thres_fuzzy_ratio_valid
                else ''
            ),
            color="red"
        )

        # Too low fuzzy ratio to be displayed
        if sel_bill_ver_model_ratio <= thres_fuzzy_ratio_valid:
            ds_col.markdown('*(no fuzzy equiv. found)*')
            continue

        # Sufficient fuzzy ratio to display text, even if not reaching the global `thres_ratio`
        sel_bill_ver_text = text_df.query('doc_id == @sel_disp_doc_id')
        assert len(sel_bill_ver_text) == 1
        sel_bill_ver_text = sel_bill_ver_text.iloc[0]['text']

        # Target bill text segment with some flanking to see some context around.
        # Note: the indices are found after some processing from `rapidfuzz`, but
        # here just display pre-proc text, so the highlight is approximate.
        left_text, match_text, right_text = _split_flanked_segment(
            sel_bill_ver_text,
            int(sel_tgt_row['target_postproc_start']),
            int(sel_tgt_row['target_postproc_end']),
            disp_char_jitter,
        )
        with ds_col:
            annotated_text(
                disp_text_flank + left_text,
                (match_text, '', disp_text_color),
                right_text + disp_text_flank,
            )


def _split_flanked_segment(text, start, end, jitter):
    """Split ``text`` around ``[start, end)`` into (left, segment, right) pieces.

    Each flank holds at most ``jitter`` characters (fewer at the text
    boundaries; the right flank may reach the very last character). Every ``$``
    in the pieces is escaped with a backslash, presumably so the renderer does
    not treat it as a LaTeX delimiter. Escaping happens after slicing: escaping
    the whole text first would lengthen it and shift every offset after a
    ``$``, and could leave a lone backslash at a slice boundary.
    """
    flank_start = max(start - jitter, 0)
    flank_end = min(end + jitter, len(text))
    pieces = (text[flank_start:start], text[start:end], text[end:flank_end])
    return tuple(piece.replace('$', r'\$') for piece in pieces)