# Policy_Lens / src/streamlit_app.py
# (Hugging Face Space page residue, kept as comments so the file parses:
#  "ANASAKHTAR's picture — Update src/streamlit_app.py — dde2b61 verified")
# app.py
import streamlit as st
from transformers import pipeline
import spacy
from collections import Counter
import matplotlib.pyplot as plt
# Load AI models ONCE at startup to cache them and avoid reloading on every interaction
@st.cache_resource
def load_summarizer():
    """Build the BART summarization pipeline once; Streamlit caches the
    returned object across reruns via @st.cache_resource."""
    model_name = "facebook/bart-large-cnn"
    return pipeline("summarization", model=model_name)
@st.cache_resource
def load_ner_model():
    """Load spaCy's small English pipeline for NER, cached across reruns."""
    spacy_model = "en_core_web_sm"
    return spacy.load(spacy_model)
# Initialize the models at import time. The @st.cache_resource decorators on
# the loaders mean the expensive loads happen only once per app process; on
# every subsequent Streamlit rerun these calls return the cached objects.
summarizer = load_summarizer()
nlp = load_ner_model()
def summarize_text(text):
    """Summarize *text* in plain language with the BART pipeline.

    Args:
        text: Document text of arbitrary length.

    Returns:
        The generated summary string, or "" when the input is empty or
        whitespace-only (the summarization pipeline raises on empty prompts).
    """
    if not text or not text.strip():
        return ""
    # Truncate by characters as a cheap guard against exceeding the model's
    # input limit (BART's context window is roughly 1024 tokens).
    input_text = text[:2000]
    summary = summarizer(input_text, max_length=150, min_length=30, do_sample=False)
    return summary[0]['summary_text']
def extract_entities(text):
    """Run spaCy NER over *text* and keep only the entity types we display.

    Returns:
        A list of (entity_text, entity_label) tuples in document order,
        restricted to PERSON, ORG, GPE, MONEY, and LAW entities.
    """
    wanted_labels = {'PERSON', 'ORG', 'GPE', 'MONEY', 'LAW'}
    doc = nlp(text)
    return [(ent.text, ent.label_) for ent in doc.ents if ent.label_ in wanted_labels]
# app.py (continued)
# ---- Page chrome ----
# NOTE: the emoji below were mojibake in the original (UTF-8 bytes shown
# through a wrong decoder, e.g. "πŸ“œ"); restored to the intended characters.
st.set_page_config(page_title="Policy Lens", page_icon="📜", layout="wide")
st.title("📜 Policy Lens")
st.markdown("**AI-Powered Legislative Analysis** - Paste a bill or policy below to get a plain language summary and key insights.")

# ---- Input ----
input_text = st.text_area("Paste Legislative Text Here:", height=250, placeholder="Paste the text of a bill, policy, or news article here...")

# Analyze only when the button is pressed AND there is text to work on.
if st.button("Analyze", type="primary") and input_text:
    with st.spinner("Analyzing text with AI..."):
        col1, col2 = st.columns(2)

        with col1:
            st.header("📋 Summary")
            st.success(summarize_text(input_text))

        with col2:
            st.header("🧠 Key Entities")
            entities = extract_entities(input_text)

            # One (heading, spaCy label) pair per displayed category;
            # the set() comprehension removes duplicate mentions.
            for heading, wanted_label in (
                ("People", "PERSON"),
                ("Organizations", "ORG"),
                ("Financials", "MONEY"),
                ("Locations", "GPE"),
            ):
                names = {text for text, label in entities if label == wanted_label}
                if names:
                    st.write(f"**{heading}:**", ", ".join(names))

        # ---- Entity-type frequency chart ----
        st.header("📊 Entity Frequency")
        if entities:
            # Count occurrences per entity TYPE (label), not per surface text.
            label_counts = Counter(label for _, label in entities)
            fig, ax = plt.subplots()
            ax.bar(label_counts.keys(), label_counts.values())
            ax.set_ylabel('Frequency')
            ax.set_title('Most Common Entity Types')
            plt.xticks(rotation=45)
            st.pyplot(fig)
        else:
            st.info("No significant entities found to display.")
else:
    # Shown on first load and whenever the button is pressed with no text.
    st.info("👆 Please paste some text to analyze. For demo purposes, you can find text on sites like congress.gov")

# ---- Footer ----
st.markdown("---")
st.caption("Policy Lens uses Facebook's BART model for summarization and spaCy for entity recognition.")