# VotIE-demo / app.py
# Anonymous3445 — "Solved windowing issues" (commit 23eab0c)
"""
VotIE Demo - Portuguese Voting Information Extraction
Streamlit app for extracting structured voting information from Portuguese municipal meeting minutes
"""
import streamlit as st
import json
import sys
from pathlib import Path
from transformers import AutoTokenizer, AutoModel
from typing import Dict, List, Any
from torchcrf import CRF
import nltk
nltk.download('punkt_tab', quiet=True)
# Add project root to path
sys.path.insert(0, str(Path(__file__).parent))
from src.utils.event_constructor import EventConstructor
# Page configuration
st.set_page_config(
page_title="VotIE - Voting Information Extraction",
page_icon="πŸ—³οΈ",
layout="wide",
initial_sidebar_state="expanded"
)
# Custom CSS
st.markdown("""
<style>
.main-header {
font-size: 3rem;
font-weight: 800;
margin-bottom: 1.5rem;
margin-top: 0.5rem;
text-align: center;
}
.sub-header {
font-size: 1.1rem;
color: #666;
margin-bottom: 2rem;
}
.entity-label {
padding: 2px 8px;
border-radius: 4px;
font-size: 0.85rem;
font-weight: 600;
margin: 0 2px;
display: inline-block;
}
.highlight-subject {
background-color: #fff3cd;
border-bottom: 2px solid #ffc107;
padding: 2px;
color: #000;
}
.highlight-voting {
background-color: #d4edda;
border-bottom: 2px solid #28a745;
padding: 2px;
color: #000;
}
.highlight-counting {
background-color: #d1ecf1;
border-bottom: 2px solid #17a2b8;
padding: 2px;
color: #000;
}
.highlight-voter-favor {
background-color: #e2d6f3;
border-bottom: 2px solid #9c27b0;
padding: 2px;
color: #000;
}
.highlight-voter-against {
background-color: #f8d7da;
border-bottom: 2px solid #dc3545;
padding: 2px;
color: #000;
}
.highlight-voter-abstention {
background-color: #ffe5b4;
border-bottom: 2px solid #fd7e14;
padding: 2px;
color: #000;
}
.highlight-voter-absent {
background-color: #e0e0e0;
border-bottom: 2px solid #6c757d;
padding: 2px;
color: #000;
}
</style>
""", unsafe_allow_html=True)
# Entity color mapping (using background colors from CSS)
ENTITY_COLORS = {
'SUBJECT': ('🏷️', '#fff3cd', 'highlight-subject'),
'VOTING': ('βœ…', '#d4edda', 'highlight-voting'),
'COUNTING-UNANIMITY': ('πŸ‘₯', '#d1ecf1', 'highlight-counting'),
'COUNTING-MAJORITY': ('πŸ‘₯', '#d1ecf1', 'highlight-counting'),
'VOTER-FAVOR': ('πŸ‘', '#e2d6f3', 'highlight-voter-favor'),
'VOTER-AGAINST': ('πŸ‘Ž', '#f8d7da', 'highlight-voter-against'),
'VOTER-ABSTENTION': ('🀷', '#ffe5b4', 'highlight-voter-abstention'),
'VOTER-ABSENT': ('❌', '#e0e0e0', 'highlight-voter-absent'),
}
# Cached so the model is downloaded/initialized only once per session.
@st.cache_resource
def load_model():
    """Load the VotIE tokenizer and model from the Hugging Face Hub.

    Returns:
        Tuple of (tokenizer, model) for the XLM-RoBERTa-CRF-VotIE checkpoint.
    """
    repo_id = "Anonymous3445/XLM-RoBERTa-CRF-VotIE"
    votie_tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
    votie_model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
    return votie_tokenizer, votie_model
def predict_with_windowing(text, tokenizer, model, max_length=512, overlap_words=50):
    """
    Run model prediction with a sliding window for texts that exceed the
    model's max sequence length.

    For short texts that fit within max_length tokens, runs a single
    prediction. For long texts, splits into overlapping word-level windows,
    predicts each, and merges results using a center-preference strategy
    (predictions near a window edge are discarded in favour of the
    neighbouring window, where those same words sit closer to the center).

    Args:
        text: Raw input text to label.
        tokenizer: Hugging Face tokenizer matching *model*.
        model: Model exposing decode(**inputs, tokenizer=..., text=...)
            returning a list of {'word', 'label'} dicts.
        max_length: Maximum subword sequence length accepted by the model.
        overlap_words: Number of words shared between consecutive windows.

    Returns:
        Tuple of (predictions, num_windows): predictions is a List[dict]
        with keys 'word' and 'label' (same format as model.decode()), and
        num_windows is how many windows were processed.
    """
    # Fast path: check if text fits in a single window
    inputs = tokenizer(text, return_tensors="pt")
    if inputs['input_ids'].shape[1] <= max_length:
        return model.decode(**inputs, tokenizer=tokenizer, text=text), 1
    # Text is too long — split into overlapping word-level windows
    words = nltk.word_tokenize(text, language='portuguese')
    effective_max_length = int((max_length - 2) * 0.9)  # Reserve for [CLS]/[SEP] + safety margin
    # Build windows of words that fit within the token limit
    windows = []  # Each entry is a list of words
    start_idx = 0
    while start_idx < len(words):
        # Conservative initial window size (half the budget), grown below
        # until the subword budget is reached
        initial_size = min(effective_max_length // 2, len(words) - start_idx)
        end_idx = start_idx + initial_size
        window_words = words[start_idx:end_idx]
        window_text = ' '.join(window_words)
        subword_count = len(tokenizer.tokenize(window_text))
        # Grow window by 10 words at a time until hitting the limit
        while subword_count < effective_max_length and end_idx < len(words):
            candidate_end = min(end_idx + 10, len(words))
            candidate_text = ' '.join(words[start_idx:candidate_end])
            candidate_count = len(tokenizer.tokenize(candidate_text))
            if candidate_count <= effective_max_length:
                end_idx = candidate_end
                window_words = words[start_idx:end_idx]
                subword_count = candidate_count
            else:
                break
        # Shrink by 10% iteratively if still over limit
        while subword_count > effective_max_length and len(window_words) > 1:
            new_size = max(1, int(len(window_words) * 0.9))
            window_words = window_words[:new_size]
            end_idx = start_idx + new_size
            subword_count = len(tokenizer.tokenize(' '.join(window_words)))
        windows.append(window_words)
        if end_idx >= len(words):
            break
        # Advance so consecutive windows share `overlap_words` words
        step_size = max(len(window_words) - overlap_words, 1)
        start_idx += step_size
        # Safety: prevent infinite loops. NOTE(review): if this cap is ever
        # hit, any text beyond the 100th window is silently dropped.
        if len(windows) > 100:
            break
    # Predict each window independently (re-joined with single spaces, so
    # original inter-word whitespace is not preserved in window text)
    window_predictions = []
    for window_words in windows:
        window_text = ' '.join(window_words)
        window_inputs = tokenizer(window_text, return_tensors="pt", truncation=True, max_length=max_length)
        preds = model.decode(**window_inputs, tokenizer=tokenizer, text=window_text)
        window_predictions.append(preds)
    # Merge predictions using center-preference strategy: trim half the
    # overlap from each side of adjoining windows. NOTE(review): this
    # assumes model.decode returns one prediction per input word so that
    # prediction indices line up with word indices — confirm.
    if len(window_predictions) == 1:
        return window_predictions[0], 1
    trim = overlap_words // 2
    merged = []
    for i, preds in enumerate(window_predictions):
        if i == 0:
            # First window: keep everything except last trim predictions
            keep_end = max(len(preds) - trim, 1)
            merged.extend(preds[:keep_end])
        elif i == len(window_predictions) - 1:
            # Last window: skip first trim predictions
            skip_start = min(trim, len(preds) - 1)
            merged.extend(preds[skip_start:])
        else:
            # Middle windows: skip first trim, keep until last trim
            skip_start = min(trim, len(preds) - 1)
            keep_end = max(len(preds) - trim, skip_start + 1)
            merged.extend(preds[skip_start:keep_end])
    return merged, len(windows)
# Cache examples from one document
@st.cache_data
def load_document_examples():
    """Load the demo examples bundled with the app.

    Returns:
        Tuple of (examples, document_id) read from data/demo_examples.json.
    """
    demo_file = Path(__file__).parent / "data" / "demo_examples.json"
    payload = json.loads(demo_file.read_text(encoding='utf-8'))
    return payload['examples'], payload['document_id']
@st.cache_data
def get_examples_by_category():
    """Group the demo examples by their 'category' field.

    Returns:
        List of (category_name, [(example_index, example), ...]) pairs.
        Known categories come first in a fixed display order; any other
        categories follow in alphabetical order.
    """
    from collections import defaultdict

    examples, _ = load_document_examples()
    grouped = defaultdict(list)
    for idx, example in enumerate(examples):
        grouped[example.get('category', 'Uncategorized')].append((idx, example))

    # Preferred display order for the known categories
    preferred = [
        'Unanimous Votes',
        'Majority with Abstentions',
        'Against Votes',
        'Absent Voters'
    ]
    ordered = [(name, grouped[name]) for name in preferred if name in grouped]
    # Append any categories not covered by the preferred order
    ordered += [(name, grouped[name]) for name in sorted(grouped) if name not in preferred]
    return ordered
def get_entity_type(label: str) -> str:
    """Strip the BIO prefix from a label (e.g. 'B-VOTING' -> 'VOTING').

    'O' and labels without a 'B-'/'I-' prefix are returned unchanged.
    """
    if label == 'O':
        return 'O'
    has_bio_prefix = label.startswith(('B-', 'I-'))
    return label[2:] if has_bio_prefix else label
def _extract_entities(labels: List[str]) -> set:
    """Collect (entity_type, start, end) spans from a BIO label sequence.

    A span opens at each 'B-' tag and closes just before the next 'B-' or
    'O' (or at the end of the sequence). 'I-' tags simply extend the open
    span; an 'I-' with no preceding 'B-' is ignored. This matches the
    logic that was previously duplicated inline for predictions and
    ground truth.
    """
    entities = set()
    current_entity = None
    start_pos = 0
    for i, label in enumerate(labels):
        if label.startswith('B-'):
            if current_entity:
                entities.add((current_entity, start_pos, i - 1))
            current_entity = label[2:]
            start_pos = i
        elif label == 'O' and current_entity:
            entities.add((current_entity, start_pos, i - 1))
            current_entity = None
    if current_entity:
        # Close a span that runs to the end of the sequence
        entities.add((current_entity, start_pos, len(labels) - 1))
    return entities


def calculate_metrics(predicted: List[str], ground_truth: List[str]) -> Dict[str, Any]:
    """Calculate token accuracy plus entity-level precision/recall/F1.

    Args:
        predicted: Predicted BIO labels, parallel to *ground_truth*.
        ground_truth: Gold BIO labels. Token accuracy is computed over
            zip(predicted, ground_truth) divided by len(ground_truth).

    Returns:
        Dict with 'token_accuracy', 'entity_precision', 'entity_recall',
        'entity_f1', 'correct_tokens', 'total_tokens', 'true_positives',
        'false_positives', 'false_negatives'.
    """
    correct = sum(1 for p, g in zip(predicted, ground_truth) if p == g)
    total = len(ground_truth)
    accuracy = correct / total if total > 0 else 0
    # Exact-span matching: an entity counts as correct only if type,
    # start and end all agree.
    pred_entities = _extract_entities(predicted)
    true_entities = _extract_entities(ground_truth)
    tp = len(pred_entities & true_entities)
    fp = len(pred_entities - true_entities)
    fn = len(true_entities - pred_entities)
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
    return {
        'token_accuracy': accuracy,
        'entity_precision': precision,
        'entity_recall': recall,
        'entity_f1': f1,
        'correct_tokens': correct,
        'total_tokens': total,
        'true_positives': tp,
        'false_positives': fp,
        'false_negatives': fn
    }
def render_highlighted_text(tokens: List[str], labels: List[str], ground_truth: List[str] = None) -> str:
    """Render tokens as HTML with per-entity highlighting.

    Args:
        tokens: Word tokens to display.
        labels: Predicted BIO labels, parallel to *tokens*.
        ground_truth: Optional gold BIO labels; when given, tokens whose
            predicted label disagrees are outlined in red with a tooltip
            showing both labels.

    Returns:
        An HTML string for st.markdown(..., unsafe_allow_html=True).
    """
    html_parts = []
    for i, (token, label) in enumerate(zip(tokens, labels)):
        entity_type = get_entity_type(label)
        # Flag disagreements with the gold labels when they are available
        mismatch = False
        if ground_truth and i < len(ground_truth):
            mismatch = label != ground_truth[i]
        if entity_type == 'O':
            if mismatch:
                html_parts.append(f'<span style="background-color: #ffcccc; padding: 2px; border: 1px solid red; color: #000;" title="Predicted: {label} | Truth: {ground_truth[i]}">{token}</span> ')
            else:
                html_parts.append(f'{token} ')
        else:
            _, color, css_class = ENTITY_COLORS.get(entity_type, ('', '#000', ''))
            style_extra = 'border: 2px solid red; ' if mismatch else ''
            # Bug fix: the previous code appended f": {label}" on the
            # no-mismatch branch, producing tooltips like
            # "Predicted: B-VOTING: B-VOTING".
            title = f"Predicted: {label}" + (f" | Truth: {ground_truth[i]}" if mismatch else "")
            html_parts.append(
                f'<span class="{css_class}" style="{style_extra}" title="{title}">{token}</span> '
            )
    return ''.join(html_parts)
def render_comparison_text(tokens: List[str], pred_labels: List[str], ground_truth_labels: List[str]) -> str:
    """Render an HTML prediction-vs-ground-truth view of the tokens.

    Correct predictions get a green border (entity tokens only; correct
    'O' tokens render as plain text); incorrect predictions get a red
    border with a tooltip showing both labels. Label lists shorter than
    *tokens* are padded with 'O'.
    """
    pieces = []
    for i, token in enumerate(tokens):
        pred = pred_labels[i] if i < len(pred_labels) else 'O'
        gold = ground_truth_labels[i] if i < len(ground_truth_labels) else 'O'
        pred_kind = get_entity_type(pred)
        if pred == gold:
            # Correct: plain text for 'O', green-bordered chip otherwise
            if pred_kind == 'O':
                pieces.append(f'{token} ')
            else:
                _, _, css_class = ENTITY_COLORS.get(pred_kind, ('', '', ''))
                pieces.append(
                    f'<span class="{css_class}" style="border: 2px solid #28a745; box-shadow: 0 0 3px #28a745;" title="✓ Correct: {pred}">{token}</span> '
                )
        elif pred_kind == 'O':
            # Missed entity: predicted 'O' where the gold label is an entity
            pieces.append(
                f'<span style="background-color: #ffdddd; padding: 2px; border: 2px solid #dc3545; box-shadow: 0 0 3px #dc3545; color: #000;" title="✗ Predicted: {pred} | Truth: {gold}">{token}</span> '
            )
        else:
            # Wrong entity label: keep the predicted entity's coloring,
            # outlined in red
            _, _, css_class = ENTITY_COLORS.get(pred_kind, ('', '', ''))
            pieces.append(
                f'<span class="{css_class}" style="border: 2px solid #dc3545; box-shadow: 0 0 3px #dc3545;" title="✗ Predicted: {pred} | Truth: {gold}">{token}</span> '
            )
    return ''.join(pieces)
def format_event(event: Dict[str, Any]) -> None:
    """Display a structured voting event in the Streamlit UI.

    Args:
        event: Output of EventConstructor.construct_event(). Read keys:
            'has_voting_event' and, when True, an 'event' dict with
            'subject', 'voting_expressions', 'counting', 'participants'
            and 'outcome' (shapes assumed from usage below — confirm
            against EventConstructor).
    """
    if not event['has_voting_event']:
        st.warning("No voting event detected in this text.")
        return
    ev = event['event']
    # Subject
    st.markdown("**📋 Subject**")
    if ev['subject']:
        st.info(ev['subject'])
    else:
        st.caption("_Not detected_")
    # Voting Expressions
    st.markdown(f"**✅ Voting Expressions** ({len(ev['voting_expressions'])})")
    if ev['voting_expressions']:
        for expr in ev['voting_expressions']:
            st.markdown(f"- {expr}")
    else:
        st.caption("_None detected_")
    # Counting
    if ev['counting']:
        st.markdown(f"**👥 Counting** ({len(ev['counting'])})")
        for count in ev['counting']:
            st.markdown(f"- Text: **\"{count['text']}\"** | Type: _{count['type']}_")
    # Participants, grouped by voting position
    if ev['participants']:
        st.markdown(f"**🗳️ Participants** ({len(ev['participants'])})")
        positions = {'Favor': [], 'Against': [], 'Abstention': [], 'Absent': []}
        for p in ev['participants']:
            # Robustness fix: setdefault tolerates unexpected position
            # values instead of raising KeyError; unknown positions are
            # listed after the four standard groups.
            positions.setdefault(p['position'], []).append(p['text'])
        for position, names in positions.items():
            if names:
                st.markdown(f"**{position}** ({len(names)})")
                for name in names:
                    st.caption(f"• {name}")
    # Outcome
    st.markdown("**🎯 Outcome**")
    if ev['outcome'] == 'Approved':
        st.success(f"✅ {ev['outcome']}")
    elif ev['outcome'] == 'Rejected':
        st.error(f"❌ {ev['outcome']}")
    else:
        st.info("❓ Cannot be determined")
def main():
    """Drive the Streamlit app: collect input, run the model, show results."""
    # Header - centered with large title
    st.markdown("<h1 style='text-align: center; font-size: 2.8rem; font-weight: 700; margin-bottom: 0.5rem;'>🗳️ VotIE: Voting Information Extraction</h1>", unsafe_allow_html=True)
    st.markdown("<p style='text-align: center; font-size: 1.2rem; color: #666; margin-bottom: 2rem;'>VotIE extracts structured voting information from Portuguese text using <strong>XLM-RoBERTa</strong> + CRF layer.</p>", unsafe_allow_html=True)
    # Sidebar: input-mode switch, optional example picker, entity legend
    with st.sidebar:
        st.markdown("### ⚙️ Input Mode")
        input_mode = st.radio(
            "Select input source:",
            ["Custom Text", "Sample Data"],
            label_visibility="collapsed"
        )
        # Show example selector only in Sample Data mode
        example_idx = None
        if input_mode == "Sample Data":
            st.markdown("---")
            st.markdown("### 📄 Select Example by Category")
            # Get examples grouped by category
            categories = get_examples_by_category()
            # Category selector
            category_names = [cat for cat, _ in categories]
            selected_category = st.selectbox(
                "Choose entity type to test:",
                category_names,
                format_func=lambda cat: f"{cat}",
                key="category_selector"
            )
            # Get examples for selected category
            selected_category_examples = None
            for cat, examples_list in categories:
                if cat == selected_category:
                    selected_category_examples = examples_list
                    break
            if selected_category_examples:
                # Show category description (taken from the first example)
                if selected_category_examples[0][1].get('category_description'):
                    st.caption(selected_category_examples[0][1]['category_description'])
                # Example selector within category; options are global
                # example indices, displayed via the id's last '_' segment
                example_idx = st.selectbox(
                    f"Choose from {len(selected_category_examples)} example(s):",
                    [idx for idx, _ in selected_category_examples],
                    format_func=lambda i: f"{selected_category_examples[[idx for idx, _ in selected_category_examples].index(i)][1]['id'].split('_')[-1]}",
                    key="example_selector",
                    label_visibility="collapsed"
                )
        st.markdown("---")
        st.markdown("### 🏷️ Entity Types")
        # Legend: one colored chip per entity type
        for entity, (emoji, color, _) in ENTITY_COLORS.items():
            st.markdown(
                f'{emoji} <span class="entity-label" style="background-color:{color};color:#000">{entity}</span>',
                unsafe_allow_html=True
            )
        st.markdown("---")
        st.markdown("**Model**: [XLM-RoBERTa-CRF-VotIE](https://huggingface.co/Anonymous3445/XLM-RoBERTa-CRF-VotIE)")
    # Main content area - unified layout
    st.markdown("---")
    # How it works - collapsible section at top
    with st.expander("🔧 How It Works", expanded=False):
        st.markdown("""
**Process**:
1. **Tokenization**: Text is split into tokens using XLM-RoBERTa tokenizer
2. **Entity Recognition**: Each token is classified using XLM-RoBERTa + CRF model
3. **Token Classification**: Tokens are labeled with BIO tags (B-SUBJECT, I-VOTING, etc.)
4. **Event Construction**: Labeled entities are grouped into structured voting events
5. **Outcome Determination**: System infers voting results from extracted data
**Example**:
**Input**:
```
"A proposta foi aprovada por unanimidade, tendo votado a favor os vereadores João Silva e Maria Santos."
```
**Token Classification**:
```
A → B-SUBJECT
proposta → I-SUBJECT
foi → O
aprovada → B-VOTING
por → B-COUNTING-UNANIMITY
unanimidade → I-COUNTING-UNANIMITY
...
```
**Output**:
```json
{
"subject": "A proposta",
"voting_expressions":
["aprovada"],
"counting": [{
"text": "por unanimidade",
"type": "unanimity"
}],
"participants": [
{"name": "João Silva",
"position": "Favor"},
{"name": "Maria Santos",
"position": "Favor"}
],
"outcome": "Approved"
}
```
""")
    st.markdown("## 📝 Input Text")
    # Determine text to display based on mode
    if input_mode == "Sample Data" and example_idx is not None:
        examples, _ = load_document_examples()
        example = examples[example_idx]
        text_input = st.text_area(
            "Text from selected example:",
            value=example['text'],
            height=300,
            key=f"sample_text_input_{example_idx}"  # Dynamic key based on example index
        )
        has_ground_truth = True
        ground_truth_tokens = example['tokens']
        ground_truth_labels = example['labels']
    else:
        text_input = st.text_area(
            "Enter Portuguese municipal meeting text:",
            height=300,
            placeholder="Paste Portuguese municipal meeting text here...",
            key="custom_text_input"
        )
        has_ground_truth = False
    # Predict button
    if st.button("🔍 Extract Voting Information", type="primary", use_container_width=True):
        if text_input.strip():
            with st.spinner("Loading model and processing..."):
                # Load model (cached across reruns)
                tokenizer, model = load_model()
                # Predict (with windowing for long texts)
                predictions, num_windows = predict_with_windowing(text_input, tokenizer, model)
                if num_windows > 1:
                    st.info(f"Text was split into {num_windows} overlapping windows for processing.")
                # Extract tokens and labels
                tokens = [p['word'] for p in predictions]
                labels = [p['label'] for p in predictions]
                # For sample data with ground truth, align predictions
                if has_ground_truth:
                    # Try both the original text and reconstructed text from tokens
                    text_from_tokens = ' '.join(ground_truth_tokens)
                    # Try on reconstructed text from ground truth tokens
                    predictions_reconstructed, _ = predict_with_windowing(text_from_tokens, tokenizer, model)
                    pred_tokens_recon = [p['word'] for p in predictions_reconstructed]
                    pred_labels_recon = [p['label'] for p in predictions_reconstructed]
                    # Use whichever matches ground truth better
                    if len(pred_tokens_recon) == len(ground_truth_tokens):
                        tokens = pred_tokens_recon
                        labels = pred_labels_recon
                        tokens_match = True
                    elif len(tokens) == len(ground_truth_tokens):
                        tokens_match = True
                    else:
                        # Neither matches perfectly - use ground truth tokens
                        st.warning(f"⚠️ Tokenization mismatch: Model produced {len(tokens)} tokens, ground truth has {len(ground_truth_tokens)} tokens.")
                        tokens = ground_truth_tokens
                        labels = pred_labels_recon if len(pred_labels_recon) == len(ground_truth_tokens) else labels
                        tokens_match = False
                    # Align labels to ground truth length (pad with 'O' or truncate)
                    if len(labels) != len(ground_truth_labels):
                        if len(labels) < len(ground_truth_labels):
                            labels = labels + ['O'] * (len(ground_truth_labels) - len(labels))
                        else:
                            labels = labels[:len(ground_truth_labels)]
                # Construct event from the (possibly aligned) labels
                event_constructor = EventConstructor()
                event = event_constructor.construct_event(
                    tokens=tokens,
                    labels=labels,
                    example_id="sample" if has_ground_truth else "custom"
                )
            st.markdown("---")
            # Display results based on whether we have ground truth
            if has_ground_truth:
                # Calculate metrics (for use in comparison tab)
                metrics = calculate_metrics(labels, ground_truth_labels)
                # Construct ground truth event
                event_gt = event_constructor.construct_event(
                    tokens=ground_truth_tokens,
                    labels=ground_truth_labels,
                    example_id=example['id']
                )
                # Show tabs: Model Prediction, Ground Truth, Comparison
                tab1, tab2, tab3 = st.tabs(["🤖 Model Prediction", "✅ Ground Truth", "🔍 Comparison"])
                with tab1:
                    st.markdown("### 🏷️ Entity Recognition")
                    st.markdown(render_highlighted_text(tokens, labels), unsafe_allow_html=True)
                    st.markdown("---")
                    st.markdown("### 📊 Structured Event")
                    format_event(event)
                    with st.expander("📄 View Raw JSON"):
                        st.json(event, expanded=True)
                with tab2:
                    st.markdown("### 🏷️ Ground Truth Labels")
                    st.markdown(render_highlighted_text(ground_truth_tokens, ground_truth_labels), unsafe_allow_html=True)
                    st.markdown("---")
                    st.markdown("### 📊 Ground Truth Event")
                    format_event(event_gt)
                    with st.expander("📄 View Raw JSON"):
                        st.json(event_gt, expanded=False)
                with tab3:
                    st.markdown("### 🏷️ Comparison View")
                    st.caption("Green border = correct prediction | Red border = incorrect prediction")
                    st.markdown(render_comparison_text(tokens, labels, ground_truth_labels), unsafe_allow_html=True)
                    st.markdown("---")
                    # Performance metrics
                    st.markdown("### 📊 Performance Metrics")
                    if tokens_match:
                        st.success("✓ Tokenization aligned perfectly with ground truth")
                    col1, col2, col3, col4 = st.columns(4)
                    with col1:
                        st.metric("Token Accuracy", f"{metrics['token_accuracy']:.1%}")
                    with col2:
                        st.metric("Entity Precision", f"{metrics['entity_precision']:.1%}")
                    with col3:
                        st.metric("Entity Recall", f"{metrics['entity_recall']:.1%}")
                    with col4:
                        st.metric("Entity F1", f"{metrics['entity_f1']:.1%}")
                    st.markdown("---")
                    # Side-by-side event comparison
                    st.markdown("### 📋 Event Comparison")
                    comp_col1, comp_col2 = st.columns(2)
                    with comp_col1:
                        st.markdown("**🤖 Predicted Event**")
                        format_event(event)
                        with st.expander("📄 View Raw JSON"):
                            st.json(event, expanded=False)
                    with comp_col2:
                        st.markdown("**✅ Ground Truth Event**")
                        format_event(event_gt)
                        with st.expander("📄 View Raw JSON"):
                            st.json(event_gt, expanded=False)
            else:
                # Custom text mode - only show predictions
                st.markdown("### 🏷️ Entity Recognition")
                st.markdown(render_highlighted_text(tokens, labels), unsafe_allow_html=True)
                st.markdown("---")
                st.markdown("### 📊 Structured Event")
                format_event(event)
                with st.expander("📄 View Raw JSON"):
                    st.json(event, expanded=False)
        else:
            st.warning("Please enter some text to analyze.")
# Script entry point (run with `streamlit run app.py`)
if __name__ == "__main__":
    main()