# klm-hitl / src/streamlit_app.py
# Rajaa's picture
# Rename src/app.py to src/streamlit_app.py
# 0379b83 verified
"""
MUSART-Augmented: FDR Audit Hub
================================
A Human-in-the-Loop (HITL) Extraction Audit Dashboard for validating
factual triples extracted from Wikipedia articles.
Usage:
# With mock data (works out-of-the-box):
streamlit run scripts/hitl_audit_app.py
# With your own CSV:
AUDIT_CSV=data/manual_evaluation_sample.csv streamlit run scripts/hitl_audit_app.py
"""
import os
import sys
import random
import hashlib
import html
import argparse
import pandas as pd
import numpy as np
import streamlit as st
# statsmodels is lazy-loaded inside compute_fdr_metrics() to avoid
# slow startup (importing scipy/statsmodels takes several seconds).
# ---------------------------------------------------------------------------
# Page configuration — must be the very first Streamlit command
# ---------------------------------------------------------------------------
# Configure the browser tab and overall layout. This call has to precede every
# other Streamlit command in the script.
# The page icon was a mis-encoded byte sequence; it is restored to the
# microscope emoji so the favicon renders as intended.
st.set_page_config(
    page_title="MUSART‑Augmented: FDR Audit Hub",
    page_icon="🔬",
    layout="wide",
    initial_sidebar_state="expanded",
)
# ---------------------------------------------------------------------------
# Annotation label constants
# ---------------------------------------------------------------------------
# Verdict labels stored in the ``human_annotation`` column. ``PENDING`` marks a
# record that has not been reviewed yet; the other three are reviewer verdicts.
PENDING, CORRECT, RELATION_MISMATCH, EXTRACTION_ERROR = (
    "Pending",
    "Correct",
    "Relation Mismatch",
    "Extraction Error",
)

# Every label a record can carry, unreviewed state first.
ANNOTATION_OPTIONS = [
    PENDING,
    CORRECT,
    RELATION_MISMATCH,
    EXTRACTION_ERROR,
]
# ---------------------------------------------------------------------------
# Mock data generator (realistic domains so the app runs out‑of‑the‑box)
# ---------------------------------------------------------------------------
# Template records used to synthesise demo data when no CSV is supplied.
# Each entry is a 4-tuple: (subject, relation, extracted_object, article_snippet).
# The extracted object occurs literally (case-insensitively) in its snippet, so
# the mock generator can compute a real highlight span for it.
_MOCK_RECORDS = [
    # (subject, relation, extracted_object, article_snippet)
    ("Paris Saint-Germain", "league", "Ligue 1",
     "Paris Saint-Germain Football Club, commonly referred to as Paris Saint-Germain, "
     "PSG, Paris SG or simply Paris, is a French professional football club based in "
     "Paris. Founded in 1970, the club has competed in Ligue 1, the top division of "
     "French football, for most of its history. PSG has won twelve Ligue 1 titles, "
     "a record fifteen Coupes de France, and nine Coupes de la Ligue."),
    ("Paris Saint-Germain", "founded", "1970",
     "Paris Saint-Germain Football Club, commonly referred to as Paris Saint-Germain, "
     "PSG, Paris SG or simply Paris, is a French professional football club based in "
     "Paris. Founded in 1970, the club has competed in Ligue 1, the top division of "
     "French football, for most of its history."),
    ("Mega Man 6", "game mode", "single-player",
     "Mega Man 6, known as Rockman 6: Shijō Saidai no Tatakai!! in Japan, is an "
     "action-platform video game developed and published by Capcom for the Nintendo "
     "Entertainment System (NES). It was released in Japan on November 5, 1993. "
     "The game supports single-player mode in which the player controls Mega Man."),
    ("Mega Man 6", "publisher", "Capcom",
     "Mega Man 6, known as Rockman 6: Shijō Saidai no Tatakai!! in Japan, is an "
     "action-platform video game developed and published by Capcom for the Nintendo "
     "Entertainment System (NES). It was released in Japan on November 5, 1993."),
    ("Mega Man 6", "platform", "Nintendo Entertainment System",
     "Mega Man 6, known as Rockman 6: Shijō Saidai no Tatakai!! in Japan, is an "
     "action-platform video game developed and published by Capcom for the Nintendo "
     "Entertainment System (NES). It was released in Japan on November 5, 1993."),
    ("Eiffel Tower", "located in", "Paris",
     "The Eiffel Tower is a wrought-iron lattice tower on the Champ de Mars in Paris, "
     "France. It is named after the engineer Gustave Eiffel, whose company designed "
     "and built the tower from 1887 to 1889. Locally nicknamed 'La dame de fer', it "
     "was constructed as the centerpiece of the 1889 World's Fair."),
    ("Eiffel Tower", "architect", "Gustave Eiffel",
     "The Eiffel Tower is a wrought-iron lattice tower on the Champ de Mars in Paris, "
     "France. It is named after the engineer Gustave Eiffel, whose company designed "
     "and built the tower from 1887 to 1889."),
    ("Homo sapiens", "kingdom", "Animalia",
     "Homo sapiens, or modern humans, are the most common and widespread species "
     "of primate, and the last surviving species of the genus Homo. They are part "
     "of the kingdom Animalia and belong to the family Hominidae. Modern humans "
     "evolved in Africa around 300,000 years ago."),
    ("Penicillin", "discovered by", "Alexander Fleming",
     "Penicillin is a group of antibiotics that are widely used to treat bacterial "
     "infections. It was discovered in 1928 by Scottish scientist Alexander Fleming "
     "as a natural product of the mould Penicillium rubens. Howard Florey and "
     "Ernst Boris Chain later purified and mass-produced it."),
    ("Halo 4", "genre", "first-person shooter",
     "Halo 4 is a 2012 first-person shooter video game developed by 343 Industries "
     "and published by Microsoft Studios for the Xbox 360 video game console. The "
     "game was released on November 6, 2012. Halo 4's story follows a cybernetically "
     "enhanced human supersoldier, Master Chief."),
    # The name below was mis-encoded ("AndrΓ©s"); restored to its UTF-8 form.
    ("Lionel Messi", "position", "forward",
     "Lionel Andrés Messi, also known as Leo Messi, is an Argentine professional "
     "footballer who plays as a forward for Inter Miami and the Argentina national "
     "team. Widely regarded as one of the greatest players of all time, Messi has "
     "won a record eight Ballon d'Or awards."),
    ("Aspirin", "active ingredient", "acetylsalicylic acid",
     "Aspirin, also known as acetylsalicylic acid (ASA), is a nonsteroidal "
     "anti-inflammatory drug used to reduce pain, fever, and inflammation. It was "
     "first synthesized by Felix Hoffmann at Bayer in 1897. Aspirin is one of the "
     "most widely used medications globally."),
]
def _generate_mock_df(n: int = 50, seed: int = 42) -> pd.DataFrame:
    """Create *n* reproducible mock extraction rows sampled from ``_MOCK_RECORDS``.

    Args:
        n: Number of rows to generate.
        seed: Seed for the private random generator, so repeated calls with
            the same arguments yield the same rows.

    Returns:
        A DataFrame with the dashboard columns ``extraction_id``,
        ``subject_label``, ``relation_label``, ``extracted_object``,
        ``wikipedia_text``, ``span_start``, ``span_end`` and
        ``human_annotation`` (always ``PENDING``).
    """
    picker = random.Random(seed)

    def _locate(article_text, target):
        # Case-insensitive search for the object inside the snippet.
        pos = article_text.lower().find(target.lower())
        if pos < 0:
            # Not literally present: fall back to a leading span of the
            # object's length, which simulates an extraction error.
            return 0, min(len(target), len(article_text))
        return pos, pos + len(target)

    records = []
    for row_no in range(n):
        subject, relation, target, article_text = picker.choice(_MOCK_RECORDS)
        begin, finish = _locate(article_text, target)
        # Short, stable id derived from the row number and the triple.
        digest = hashlib.md5(
            f"{row_no}-{subject}-{relation}-{target}".encode()
        ).hexdigest()
        records.append({
            "extraction_id": f"ext-{digest[:10]}",
            "subject_label": subject,
            "relation_label": relation,
            "extracted_object": target,
            "wikipedia_text": article_text,
            "span_start": begin,
            "span_end": finish,
            "human_annotation": PENDING,
        })
    return pd.DataFrame(records)
# ---------------------------------------------------------------------------
# Real data loader — loads the evaluation CSV generated by our pipeline
# ---------------------------------------------------------------------------
def _load_real_csv(path: str) -> pd.DataFrame:
    """Load an evaluation CSV and reshape it into the dashboard schema.

    The raw pipeline CSV is expected to carry columns such as ``category``,
    ``subject_title``, ``relation``, ``question``, ``new_entity_qid``,
    ``new_entity_labels`` and ``wikipedia_url``. Missing columns are
    tolerated and read as empty strings.

    A file that already contains a ``human_annotation`` column is taken to be
    a file previously saved by this app and is returned as loaded (all values
    are strings, blanks are ``""``).

    Args:
        path: Path of the CSV file to read.

    Returns:
        A DataFrame in the dashboard schema. For a CSV with a header but no
        data rows the result is empty but still has every schema column, so
        later lookups such as ``df["relation_label"]`` do not fail.
    """
    raw = pd.read_csv(path, dtype=str).fillna("")
    # If the file was previously saved by the app, it already has our schema
    if "human_annotation" in raw.columns:
        return raw

    # Fixed column order; also applied when there are no rows at all.
    columns = [
        "extraction_id",
        "subject_label",
        "relation_label",
        "extracted_object",
        "all_entity_labels",
        "wikipedia_url",
        "wikidata_url",
        "entity_wikidata_url",
        "question",
        "category",
        "wikipedia_text",
        "span_start",
        "span_end",
        "human_annotation",
    ]

    rows = []
    for i, r in raw.iterrows():
        obj = r.get("new_entity_labels", "").split(" | ")[0]  # first label
        # Reuse an id from the CSV when present, otherwise derive a short
        # stable one from the row index and the triple.
        uid = r.get("extraction_id", "") or hashlib.md5(
            f"{i}-{r.get('subject_title','')}-{r.get('relation','')}-{obj}".encode()
        ).hexdigest()[:10]
        qid = r.get("new_entity_qid", "")
        rows.append({
            "extraction_id": f"ext-{uid}",
            "subject_label": r.get("subject_title", ""),
            "relation_label": r.get("relation", ""),
            "extracted_object": obj,
            "all_entity_labels": r.get("new_entity_labels", ""),
            "wikipedia_url": r.get("wikipedia_url", ""),
            "wikidata_url": r.get("wikidata_url", ""),
            "entity_wikidata_url": f"https://www.wikidata.org/wiki/{qid}" if qid else "",
            "question": r.get("question", ""),
            "category": r.get("category", ""),
            # The CSV carries no article text, so there is nothing to
            # highlight: the span stays at 0-0.
            "wikipedia_text": "",
            "span_start": 0,
            "span_end": 0,
            # A pre-filled verdict is kept; a blank one means "not reviewed".
            "human_annotation": r.get("verdict", "") or PENDING,
        })
    return pd.DataFrame(rows, columns=columns)
# ---------------------------------------------------------------------------
# Core display helper — highlight the extracted span in the source text
# ---------------------------------------------------------------------------
def highlight_span(text: str, start: int, end: int) -> str:
    """Return HTML-escaped *text* with ``text[start:end]`` wrapped in ``<mark>``.

    This gives the annotator an instant visual proof of text grounding.

    Args:
        text: Raw (unescaped) source text.
        start: Offset of the first highlighted character.
        end: Offset one past the last highlighted character.

    Returns:
        An HTML string. Offsets are clamped to ``[0, len(text)]`` and a
        reversed range collapses to an empty span, so no part of the text is
        ever duplicated or dropped. When the span is empty the escaped text is
        returned without a ``<mark>`` element. For empty *text* a placeholder
        message is returned instead.
    """
    if not text:
        return "<em>(No article text available — use the Wikipedia link instead.)</em>"

    # Clamp first: unclamped negative or reversed offsets would slice
    # overlapping pieces and repeat or lose characters.
    length = len(text)
    start = min(max(start, 0), length)
    end = min(max(end, start), length)

    if start == end:
        # Nothing to highlight; avoid emitting an empty, padded <mark>.
        return html.escape(text)

    # Escape each piece separately, because escaping the whole string first
    # would shift the character offsets.
    prefix = html.escape(text[:start])
    span = html.escape(text[start:end])
    suffix = html.escape(text[end:])
    return (
        f'{prefix}'
        f'<mark style="background:#ffea00; font-weight:700; padding:2px 4px; '
        f'border-radius:3px;">{span}</mark>'
        f'{suffix}'
    )
# ---------------------------------------------------------------------------
# FDR computation with Wilson Score CI
# ---------------------------------------------------------------------------
def compute_fdr_metrics(df: pd.DataFrame):
    """Compute the empirical False Discovery Rate and its 95 % Wilson interval.

    FDR = (false positives) / (reviewed records), where a false positive is a
    record annotated ``RELATION_MISMATCH`` or ``EXTRACTION_ERROR`` and a
    reviewed record is any record whose annotation is not ``PENDING``.

    The Wilson score interval is used because it has better coverage
    properties than the Wald interval at small sample sizes — critical for an
    in-progress audit where n may be < 30.

    Args:
        df: DataFrame with a ``human_annotation`` column.

    Returns:
        ``(n_reviewed, fdr, ci_low, ci_high)``. All zeros when nothing has
        been reviewed yet.
    """
    labels = df["human_annotation"]
    reviewed = labels[labels != PENDING]
    n = len(reviewed)
    if n == 0:
        return 0, 0.0, 0.0, 0.0

    fp = int(reviewed.isin([RELATION_MISMATCH, EXTRACTION_ERROR]).sum())
    fdr = fp / n

    # Lazy-load statsmodels to keep app startup fast
    try:
        from statsmodels.stats.proportion import proportion_confint
    except ImportError:
        # Closed-form Wilson score interval, so the reported interval is the
        # same kind with or without statsmodels installed.
        z = 1.959963984540054  # two-sided 95 % normal quantile
        z2 = z * z
        denom = 1.0 + z2 / n
        centre = (fdr + z2 / (2 * n)) / denom
        half_width = z * np.sqrt(fdr * (1 - fdr) / n + z2 / (4 * n * n)) / denom
        ci_low = max(0.0, centre - half_width)
        ci_high = min(1.0, centre + half_width)
    else:
        ci_low, ci_high = proportion_confint(fp, n, alpha=0.05, method="wilson")
    return n, fdr, float(ci_low), float(ci_high)
# ---------------------------------------------------------------------------
# Custom CSS for "Academic Clean" styling
# ---------------------------------------------------------------------------
# Stylesheet injected once per run via st.markdown(..., unsafe_allow_html=True).
# It styles the metric cards (.metric-card), the claim box (.claim-box), the
# source-text pane (.source-text) and enlarges buttons inside horizontal blocks.
_CSS = """
<style>
/* Metric cards */
.metric-card {
background: linear-gradient(135deg, #1e1e2f 0%, #2a2a40 100%);
border: 1px solid #3a3a5c;
border-radius: 12px;
padding: 20px 24px;
text-align: center;
}
.metric-card .metric-value {
font-size: 2.4rem;
font-weight: 800;
color: #e0e0ff;
margin: 4px 0;
font-family: 'SF Mono', 'Fira Code', monospace;
}
.metric-card .metric-label {
font-size: 0.85rem;
color: #9999bb;
text-transform: uppercase;
letter-spacing: 1.5px;
}
/* Claim display */
.claim-box {
background: #1a1a2e;
border-left: 4px solid #6c63ff;
padding: 18px 22px;
border-radius: 0 10px 10px 0;
margin: 10px 0;
}
.claim-box .claim-subject { color: #82b1ff; font-size: 1.3rem; font-weight: 700; }
.claim-box .claim-arrow { color: #555; font-size: 1.3rem; margin: 0 8px; }
.claim-box .claim-relation { color: #ffa726; font-size: 1.1rem; }
.claim-box .claim-object { color: #69f0ae; font-size: 1.3rem; font-weight: 700; }
/* Source text container */
.source-text {
font-family: 'Georgia', serif;
font-size: 1.05rem;
line-height: 1.75;
color: #d0d0e0;
}
/* Action buttons row */
div[data-testid="stHorizontalBlock"] button {
font-size: 1.05rem !important;
padding: 12px 20px !important;
}
</style>
"""
# ===========================================================================
# MAIN APP
# ===========================================================================
def main():
    """Render the FDR audit dashboard and handle annotation actions.

    Layout, top to bottom: a sidebar (filters, progress, record navigation,
    CSV export), three live metric cards (reviewed count, empirical FDR,
    Wilson interval), a two-column claim / source-text view, and a row of
    annotation buttons.

    State kept in ``st.session_state``:
        df             working DataFrame of extractions.
        data_source    the ``AUDIT_CSV`` path, or ``"mock"`` when it is unset.
        current_index  index label (in ``df``) of the record being shown.

    Every annotation is written to disk immediately: to
    ``<stem>_annotated<suffix>`` next to the input CSV, or to
    ``data/manual_evaluation_annotated.csv`` for mock data. When both
    ``HF_TOKEN`` and ``SPACE_ID`` are set, the saved file is also uploaded to
    the Hugging Face Space in a background thread.
    """
    from pathlib import Path

    def _as_int(value, default=0):
        """Coerce a span offset to int; CSV-loaded values arrive as strings."""
        try:
            return int(float(value))
        except (TypeError, ValueError, OverflowError):
            return default

    st.markdown(_CSS, unsafe_allow_html=True)

    # --- Session State Initialization ---
    if "df" not in st.session_state:
        # Real data is selected with the AUDIT_CSV environment variable, e.g.
        #   AUDIT_CSV=path/to/data.csv streamlit run <this script>
        csv_path = os.environ.get("AUDIT_CSV")
        if csv_path:
            p = Path(csv_path)
            annotated_path = p.with_name(p.stem + "_annotated" + p.suffix)
            # Resume from annotated file if it exists
            if annotated_path.exists():
                st.session_state.df = _load_real_csv(str(annotated_path))
                st.session_state.data_source = csv_path
                st.toast(f"♻️ Resumed from {annotated_path.name}")
            else:
                st.session_state.df = _load_real_csv(csv_path)
                st.session_state.data_source = csv_path
        else:
            st.session_state.df = _generate_mock_df(50)
            st.session_state.data_source = "mock"
    if "current_index" not in st.session_state:
        st.session_state.current_index = 0
    df = st.session_state.df

    # -----------------------------------------------------------------------
    # SIDEBAR — Navigation, Filters & Progress
    # -----------------------------------------------------------------------
    with st.sidebar:
        st.title("🔬 MUSART‑Augmented")
        st.caption("FDR Audit Hub")
        st.divider()
        if st.session_state.data_source != "mock":
            # `icon` must be a real emoji character; the mis-encoded bytes
            # that were here are not one.
            st.info(f"📂 Loaded: `{st.session_state.data_source}`", icon="📂")

        # Filter: relation
        all_relations = ["All"] + sorted(df["relation_label"].unique().tolist())
        selected_relation = st.selectbox("Filter by Relation", all_relations, index=0)

        # Filter: annotation status
        status_options = ["Pending Only", "All", CORRECT, RELATION_MISMATCH, EXTRACTION_ERROR]
        selected_status = st.selectbox("Filter by Status", status_options, index=0)

        # Build the filtered view. The mask shares df's index so the boolean
        # operations below align by label whatever that index is.
        mask = pd.Series(True, index=df.index)
        if selected_relation != "All":
            mask &= df["relation_label"] == selected_relation
        if selected_status == "Pending Only":
            mask &= df["human_annotation"] == PENDING
        elif selected_status != "All":
            mask &= df["human_annotation"] == selected_status
        filtered_indices = df.index[mask].tolist()
        st.divider()

        # Progress
        total = len(df)
        annotated = len(df[df["human_annotation"] != PENDING])
        # NOTE(review): 384 looks like a target audit sample size - confirm.
        target_n = min(384, total)
        progress = annotated / target_n if target_n > 0 else 0.0
        st.metric("Progress", f"{annotated} / {target_n}")
        st.progress(min(progress, 1.0))

        # Navigation
        st.divider()
        st.subheader("Navigate")
        if filtered_indices:
            position_in_filtered = 0
            if st.session_state.current_index in filtered_indices:
                position_in_filtered = filtered_indices.index(st.session_state.current_index)
            nav_val = st.number_input(
                f"Record # (1–{len(filtered_indices)})",
                min_value=1, max_value=len(filtered_indices),
                value=position_in_filtered + 1,
                step=1, key="nav_input",
            )
            st.session_state.current_index = filtered_indices[nav_val - 1]
        else:
            st.warning("No records match the current filters.")
        st.divider()

        # Export
        csv_data = df.to_csv(index=False).encode("utf-8")
        st.download_button(
            "⬇️ Export Annotated CSV",
            csv_data,
            file_name="musart_audit_annotations.csv",
            mime="text/csv",
            use_container_width=True,
        )

    # -----------------------------------------------------------------------
    # TOP BAR — Live Defense Metrics
    # -----------------------------------------------------------------------
    n_reviewed, fdr, ci_low, ci_high = compute_fdr_metrics(df)
    m1, m2, m3 = st.columns(3)
    with m1:
        st.markdown(
            f'<div class="metric-card">'
            f'<div class="metric-label">Samples Reviewed</div>'
            f'<div class="metric-value">{n_reviewed}</div>'
            f'</div>',
            unsafe_allow_html=True,
        )
    with m2:
        st.markdown(
            f'<div class="metric-card">'
            f'<div class="metric-label">Empirical FDR</div>'
            f'<div class="metric-value">{fdr:.1%}</div>'
            f'</div>',
            unsafe_allow_html=True,
        )
    with m3:
        st.markdown(
            f'<div class="metric-card">'
            f'<div class="metric-label">95 % Wilson CI</div>'
            f'<div class="metric-value">[{ci_low:.1%}, {ci_high:.1%}]</div>'
            f'</div>',
            unsafe_allow_html=True,
        )
    st.divider()

    # -----------------------------------------------------------------------
    # Guard: nothing to show
    # -----------------------------------------------------------------------
    if not filtered_indices:
        st.success("🎉 All records in this view have been annotated!")
        return
    idx = st.session_state.current_index
    if idx not in df.index:
        idx = filtered_indices[0]
        st.session_state.current_index = idx
    row = df.loc[idx]

    # -----------------------------------------------------------------------
    # MAIN VIEW — Two-Column Layout
    # -----------------------------------------------------------------------
    col_left, col_right = st.columns([2, 3], gap="large")

    # --- Left Column: The Claim ---
    with col_left:
        st.subheader("📌 Factual Claim")
        # Values come from the data file, so they are escaped before being
        # embedded in raw HTML.
        st.markdown(
            f'<div class="claim-box">'
            f'<span class="claim-subject">{html.escape(str(row["subject_label"]))}</span>'
            f'<span class="claim-arrow"> ➔ </span>'
            f'<span class="claim-relation">{html.escape(str(row["relation_label"]))}</span>'
            f'<span class="claim-arrow"> ➔ </span>'
            f'<span class="claim-object">{html.escape(str(row["extracted_object"]))}</span>'
            f'</div>',
            unsafe_allow_html=True,
        )
        # Show extra metadata if available (for real data)
        if "question" in row and row.get("question"):
            st.markdown(f"**Question:** {row['question']}")
        if "all_entity_labels" in row and row.get("all_entity_labels"):
            st.markdown(f"**All labels:** `{row['all_entity_labels']}`")
        if "category" in row and row.get("category"):
            st.caption(f"Category: {row['category']}")
        if "wikipedia_url" in row and row.get("wikipedia_url"):
            st.markdown(f"🔗 [Open Wikipedia article (subject)]({row['wikipedia_url']})")
        if "wikidata_url" in row and row.get("wikidata_url"):
            st.markdown(f"🔗 [Open Wikidata (subject)]({row['wikidata_url']})")
        if "entity_wikidata_url" in row and row.get("entity_wikidata_url"):
            st.markdown(f"🔗 [Open Wikidata (augmented entity)]({row['entity_wikidata_url']})")

        # Current status badge
        status = row["human_annotation"]
        badge_colors = {
            PENDING: "🔵", CORRECT: "🟩",
            RELATION_MISMATCH: "🟧", EXTRACTION_ERROR: "🟥",
        }
        st.markdown(f"**Status:** {badge_colors.get(status, '⬜')} {status}")
        st.caption(f"ID: `{row['extraction_id']}` · Record {filtered_indices.index(idx)+1}/{len(filtered_indices)}")

    # --- Right Column: Provenance / Text Grounding ---
    with col_right:
        st.subheader("📄 Source Text Grounding")
        wiki_text = str(row.get("wikipedia_text", ""))
        # Offsets may be strings (or blank) when the data came from a CSV.
        span_s = _as_int(row.get("span_start", 0))
        span_e = _as_int(row.get("span_end", 0))
        if wiki_text.strip():
            highlighted_html = highlight_span(wiki_text, span_s, span_e)
            with st.container(height=350):
                st.markdown(
                    f'<div class="source-text">{highlighted_html}</div>',
                    unsafe_allow_html=True,
                )
        else:
            # No inline text — embed the Wikipedia article in an iframe
            wiki_url = str(row.get("wikipedia_url", ""))
            entity_label = str(row.get("extracted_object", ""))
            if wiki_url:
                # A "#:~:text=" text fragment asks supporting browsers to
                # scroll to and highlight the entity label.
                from urllib.parse import quote
                highlight_url = f"{wiki_url}#:~:text={quote(entity_label)}" if entity_label else wiki_url
                # The URL originates from the data file: escape it so it
                # cannot break out of the src attribute.
                st.markdown(
                    f'<iframe src="{html.escape(highlight_url, quote=True)}" '
                    f'width="100%" height="500" '
                    f'style="border:1px solid #3a3a5c; border-radius:8px;"></iframe>',
                    unsafe_allow_html=True,
                )
                st.caption(f"🔎 Auto-highlighting: **{entity_label}** · "
                           f"[Open in new tab ↗]({highlight_url})")
            else:
                st.warning("No article text or Wikipedia URL available for this record.")

    # -----------------------------------------------------------------------
    # BOTTOM ACTION BAR — Annotation Engine
    # -----------------------------------------------------------------------
    st.divider()
    st.subheader("⚑ Annotate")
    b1, b2, b3, b_skip = st.columns(4)

    def sync_to_hf_hub(local_path: str):
        """Upload the annotated file to the Hugging Face Space, if configured.

        Does nothing unless both ``HF_TOKEN`` and ``SPACE_ID`` are set. The
        upload runs in a background thread and is best-effort: a failure is
        logged and never interrupts annotating.
        """
        token = os.environ.get("HF_TOKEN")
        space_id = os.environ.get("SPACE_ID")
        if token and space_id:
            import threading

            def _upload():
                try:
                    from huggingface_hub import HfApi
                    api = HfApi(token=token)
                    # NOTE(review): the local path doubles as the in-repo
                    # path - confirm this is intended for absolute paths.
                    api.upload_file(
                        path_or_fileobj=str(local_path),
                        path_in_repo=str(local_path),
                        repo_id=space_id,
                        repo_type="space"
                    )
                except Exception:
                    # Keep the sync best-effort, but leave a trace instead of
                    # discarding the error.
                    import logging
                    logging.getLogger(__name__).warning(
                        "Hugging Face Hub sync failed for %s", local_path, exc_info=True
                    )

            threading.Thread(target=_upload).start()

    def _annotate(label: str):
        """Update annotation, auto-save to disk, and advance to next record."""
        st.session_state.df.at[idx, "human_annotation"] = label

        # --- Auto-save to disk ---
        src = st.session_state.data_source
        if src and src != "mock":
            # Save next to the input file: foo.csv -> foo_annotated.csv
            p = Path(src)
            save_path = p.with_name(p.stem + "_annotated" + p.suffix)
        else:
            save_path = Path("data/manual_evaluation_annotated.csv")
        # The target directory (notably "data/" for mock runs) may not exist;
        # writing the CSV would fail without it.
        save_path.parent.mkdir(parents=True, exist_ok=True)
        st.session_state.df.to_csv(save_path, index=False)

        # --- Auto-save to Hugging Face Hub ---
        sync_to_hf_hub(save_path)

        # Advance to next record in the filtered list
        pos = filtered_indices.index(idx)
        if pos + 1 < len(filtered_indices):
            st.session_state.current_index = filtered_indices[pos + 1]

    with b1:
        if st.button("🟩 Correct (TP)", use_container_width=True, type="primary"):
            _annotate(CORRECT)
            st.rerun()
    with b2:
        if st.button("🟧 Relation Mismatch", use_container_width=True):
            _annotate(RELATION_MISMATCH)
            st.rerun()
    with b3:
        if st.button("🟥 Extraction Error", use_container_width=True):
            _annotate(EXTRACTION_ERROR)
            st.rerun()
    with b_skip:
        if st.button("⏭️ Skip", use_container_width=True):
            # Move on without recording a verdict.
            pos = filtered_indices.index(idx)
            if pos + 1 < len(filtered_indices):
                st.session_state.current_index = filtered_indices[pos + 1]
            st.rerun()
# Build the dashboard only when this file is executed as the main script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()