import os
import time
import streamlit as st
import pandas as pd
import numpy as np
import re
import string
import json
from io import BytesIO
# --- Visualization & PPTX ---
import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio
from pptx import Presentation
from pptx.util import Inches, Pt
# --- NLP & Analysis ---
from gliner import GLiNER
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.decomposition import LatentDirichletAllocation
# --- 1. CONFIGURATION & STYLING ---
# Cache directory used by the Hugging Face tooling.
# NOTE(review): this assignment runs after the `gliner` import above; if the
# Hugging Face libraries read HF_HOME at import time it may arrive too late --
# confirm the cache really ends up in /tmp.
os.environ['HF_HOME'] = '/tmp'

# Hex display colour for every entity label requested from the model.
entity_color_map = {
    "person": "#10b981",
    "country": "#3b82f6",
    "city": "#4ade80",
    "organization": "#f59e0b",
    "date": "#8b5cf6",
    "time": "#ec4899",
    "cardinal": "#06b6d4",
    "money": "#f43f5e",
    "position": "#a855f7",
}

# Labels passed to the model, in the insertion order of the colour map.
labels = list(entity_color_map)

# Coarse category -> entity labels grouped under it.
category_mapping = {
    "People": ["person", "organization", "position"],
    "Locations": ["country", "city"],
    "Time": ["date", "time"],
    "Numbers": ["money", "cardinal"],
}

# Inverse lookup: entity label -> coarse category.
reverse_category_mapping = {
    member_label: category_name
    for category_name, member_labels in category_mapping.items()
    for member_label in member_labels
}
# --- 2. CORE UTILITY FUNCTIONS ---
def remove_trailing_punctuation(text_string):
    """Return ``text_string`` without any ASCII punctuation at its right end.

    Only characters listed in ``string.punctuation`` are removed, and only
    from the end of the string; leading and inner punctuation is kept.
    """
    strippable = string.punctuation
    trimmed = text_string.rstrip(strippable)
    return trimmed
def highlight_entities(text, df_entities, color_map=None):
    """Render ``text`` as an HTML fragment with each entity colour-highlighted.

    Args:
        text: The raw text that the entity offsets refer to.
        df_entities: DataFrame with one row per entity and the columns
            ``start`` and ``end`` (character offsets into ``text``),
            ``label`` and ``text``.
        color_map: Optional mapping of entity label to CSS colour. When
            omitted, the module-level ``entity_color_map`` is used.

    Returns:
        The HTML-escaped ``text`` when ``df_entities`` is empty; otherwise a
        ``<div>`` whose content is the escaped text with every entity wrapped
        in a ``<span>`` carrying the label's background colour.
    """
    import html  # function-scope import keeps this block self-contained

    if df_entities.empty:
        return html.escape(text)

    palette = entity_color_map if color_map is None else color_map

    # Walk the entities left to right and escape the plain text between them
    # piece by piece; this avoids index shifting and keeps user-supplied
    # markup from being interpreted when the result is rendered as HTML.
    records = sorted(
        df_entities.to_dict('records'),
        key=lambda rec: (rec['start'], -rec['end']),
    )

    pieces = []
    cursor = 0
    for entity in records:
        start = int(entity['start'])
        end = min(int(entity['end']), len(text))
        # Skip spans that overlap an already emitted entity or are empty.
        if start < cursor or end <= start:
            continue

        surface = text[start:end]
        shown = entity.get('text')
        if not isinstance(shown, str) or not shown:
            shown = surface
        # The entity text may have had trailing punctuation stripped; emit the
        # remainder of the source span as plain text so nothing is lost.
        tail = surface[len(shown):] if surface.startswith(shown) else ''

        label = entity['label']
        color = palette.get(label, '#000000')

        pieces.append(html.escape(text[cursor:start]))
        pieces.append(
            f'<span style="background-color: {html.escape(str(color))}; '
            f'color: #ffffff; padding: 2px 4px; border-radius: 4px;" '
            f'title="{html.escape(str(label))}">{html.escape(shown)}</span>'
        )
        pieces.append(html.escape(tail))
        cursor = end

    pieces.append(html.escape(text[cursor:]))
    highlighted_text = ''.join(pieces)
    return (
        '<div style="border: 1px solid #d1d5db; border-radius: 6px; '
        f'padding: 12px; line-height: 1.8;">{highlighted_text}</div>'
    )
def perform_topic_modeling(df_entities, num_topics=2, num_top_words=10):
    """Fit an LDA topic model on the distinct entity texts.

    Args:
        df_entities: DataFrame whose ``text`` column holds the entity strings.
        num_topics: Number of LDA components to fit.
        num_top_words: Number of highest-weighted terms reported per topic.

    Returns:
        A DataFrame with the columns ``Topic_ID``, ``Word`` and ``Weight``
        (one row per reported term), or ``None`` when there is no ``text``
        column, fewer than two distinct texts, or scikit-learn rejects the
        input with a ``ValueError``.
    """
    if 'text' not in df_entities:
        return None

    documents = df_entities['text'].dropna().unique().tolist()
    if len(documents) < 2:
        return None

    try:
        tfidf_vectorizer = TfidfVectorizer(
            stop_words='english', ngram_range=(1, 3), min_df=1
        )
        tfidf = tfidf_vectorizer.fit_transform(documents)
        feature_names = tfidf_vectorizer.get_feature_names_out()
        lda = LatentDirichletAllocation(n_components=num_topics, random_state=42)
        lda.fit(tfidf)
    except ValueError:
        # Best effort: input that scikit-learn cannot model (for example a
        # vocabulary that is empty once stop words are removed) yields "no
        # topics" rather than an error. Other exceptions propagate.
        return None

    topic_data = []
    for idx, topic in enumerate(lda.components_):
        # Indices of the ``num_top_words`` largest weights, heaviest first.
        top_indices = topic.argsort()[:-num_top_words - 1:-1]
        topic_data.extend(
            {
                'Topic_ID': f'Topic #{idx + 1}',
                'Word': feature_names[i],
                'Weight': topic[i],
            }
            for i in top_indices
        )
    return pd.DataFrame(topic_data)
# --- 3. VISUALIZATION FUNCTIONS (FIXED TITLES) ---
def create_topic_word_bubbles(df_topic_data):
    """Build a bubble scatter chart of the per-topic word weights.

    Args:
        df_topic_data: DataFrame with the columns ``Topic_ID``, ``Word`` and
            ``Weight``.

    Returns:
        A plotly figure with one bubble per row, sized and positioned
        vertically by weight, coloured by topic and labelled with the word.
    """
    column_aliases = {'Topic_ID': 'topic', 'Word': 'word', 'Weight': 'weight'}
    bubbles = df_topic_data.rename(columns=column_aliases)
    # Spread the bubbles horizontally in row order.
    bubbles['x_pos'] = range(len(bubbles))

    bubble_chart = px.scatter(
        bubbles,
        x='x_pos',
        y='weight',
        size='weight',
        color='topic',
        text='word',
        title='Topic Word Weights',
    )
    # The enlarged top margin keeps the title visible.
    bubble_chart.update_layout(
        margin={'t': 80, 'b': 50},
        xaxis_showticklabels=False,
        plot_bgcolor='#f9f9f9',
    )
    bubble_chart.update_traces(
        textposition='middle center',
        textfont={'color': 'white', 'size': 10},
    )
    return bubble_chart
def generate_network_graph(df, raw_text):
    """Plot every distinct entity as a node placed on a circle.

    Args:
        df: Entity DataFrame with at least the columns ``text`` and ``label``.
        raw_text: Not used by this function; kept so existing callers keep
            working.

    Returns:
        A plotly figure with one marker per distinct entity text, sized by
        how often that text occurs in ``df`` and coloured by its label.
    """
    # Map the counts onto the rows directly rather than merging on a
    # reset_index() frame, whose key column name differs between pandas
    # versions.
    frequency_by_text = df['text'].value_counts()
    unique = (
        df.dropna(subset=['text'])
        .drop_duplicates(subset=['text'])
        .reset_index(drop=True)
    )
    unique['frequency'] = unique['text'].map(frequency_by_text)

    # Evenly spaced angles on a circle of radius 10.
    num_nodes = len(unique)
    thetas = np.linspace(0, 2 * np.pi, num_nodes, endpoint=False)
    unique['x'] = 10 * np.cos(thetas)
    unique['y'] = 10 * np.sin(thetas)

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=unique['x'],
        y=unique['y'],
        mode='markers+text',
        text=unique['text'],
        marker=dict(
            size=unique['frequency'] * 5 + 15,
            color=[entity_color_map.get(lbl, '#ccc') for lbl in unique['label']],
        ),
    ))
    # The top margin leaves room for the title.
    fig.update_layout(
        title="Entity Relationship Map",
        margin=dict(t=80),
        showlegend=False,
        xaxis_visible=False,
        yaxis_visible=False,
    )
    return fig
# --- 4. EXPORT FUNCTIONS ---
def generate_html_report(df, text_input, elapsed_time, df_topic_data):
    """Assemble a standalone HTML report for one analysis run.

    Args:
        df: Entity DataFrame with the columns ``category``, ``label``,
            ``text`` and ``score`` (plus what ``highlight_entities`` needs).
        text_input: The text the entities were extracted from.
        elapsed_time: Processing time in seconds.
        df_topic_data: Topic DataFrame as returned by
            ``perform_topic_modeling``, or ``None`` when no topics exist.

    Returns:
        A complete HTML document as a string, containing the highlighted
        text, the entity treemap, the network map and the topic chart.
    """
    fig_tree = px.treemap(
        df,
        path=[px.Constant("All"), 'category', 'label', 'text'],
        values='score',
        title="Entity Hierarchy",
    )
    fig_tree.update_layout(margin=dict(t=60, b=20, l=20, r=20))

    # Only the first chart pulls in plotly.js; the later fragments reuse it.
    tree_html = fig_tree.to_html(full_html=False, include_plotlyjs='cdn')
    net_html = generate_network_graph(df, text_input).to_html(
        full_html=False, include_plotlyjs=False
    )

    if df_topic_data is not None and not df_topic_data.empty:
        topics_html = create_topic_word_bubbles(df_topic_data).to_html(
            full_html=False, include_plotlyjs=False
        )
    else:
        topics_html = "<p>Not enough data for topic modeling.</p>"

    html_template = f"""<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8">
<title>NER &amp; Topic Analysis Report</title>
</head>
<body>
<h1>NER &amp; Topic Analysis Report</h1>
<p>Processing Time: {elapsed_time:.2f}s</p>
<h2>1. Highlighted Entities</h2>
{highlight_entities(text_input, df)}
<h2>2. Visual Analytics</h2>
{tree_html}
{net_html}
<h2>3. Topic Modeling</h2>
{topics_html}
</body>
</html>
"""
    return html_template
def generate_pptx_report(df):
    """Build a two-slide PowerPoint summary of the extracted entities.

    Args:
        df: Entity DataFrame with the columns ``text`` and ``label``.

    Returns:
        A ``BytesIO`` positioned at offset 0 that holds the saved
        presentation: a title slide followed by a slide listing the first
        15 entities as ``text (label)``.
    """
    prs = Presentation()

    title_slide = prs.slides.add_slide(prs.slide_layouts[0])
    title_slide.shapes.title.text = "Entity Analysis"

    list_slide = prs.slides.add_slide(prs.slide_layouts[1])
    list_slide.shapes.title.text = "Entity List"
    tf = list_slide.placeholders[1].text_frame

    entries = [f"{row['text']} ({row['label']})" for _, row in df.head(15).iterrows()]
    for position, entry in enumerate(entries):
        # Reuse the text frame's existing first paragraph for the first entry
        # so the list does not start with an empty bullet.
        paragraph = tf.paragraphs[0] if position == 0 else tf.add_paragraph()
        paragraph.text = entry

    buffer = BytesIO()
    prs.save(buffer)
    buffer.seek(0)
    return buffer
# --- 5. STREAMLIT UI & LOGIC ---
st.set_page_config(layout="wide", page_title="DataHarvest NER")


@st.cache_resource
def load_model():
    """Load the GLiNER checkpoint used for entity extraction.

    Wrapped in ``st.cache_resource`` so the loaded model object is cached by
    Streamlit rather than rebuilt on every call.
    """
    checkpoint = "knowledgator/gliner-multitask-large-v0.5"
    gliner_model = GLiNER.from_pretrained(checkpoint)
    return gliner_model


model = load_model()
# Session State Init
# Every key read further down is created here so a fresh session never hits a
# missing attribute.
if 'results_df' not in st.session_state:
    st.session_state.results_df = pd.DataFrame()
    st.session_state.show = False
if 'analyzed_text' not in st.session_state:
    st.session_state.analyzed_text = ""
if 'elapsed' not in st.session_state:
    st.session_state.elapsed = 0.0
if 'topics' not in st.session_state:
    st.session_state.topics = None

st.subheader("Entity & Topic Analysis Report Generator", divider="blue")
text = st.text_area("Paste text here (max 1000 words):", height=250)

if st.button("Run Analysis"):
    if not text.strip():
        st.warning("Please paste some text to analyze.")
    else:
        with st.spinner("Processing..."):
            start = time.time()
            entities = model.predict_entities(text, labels)
            df = pd.DataFrame(entities)
            if df.empty:
                # Hide results of an earlier run so they are not mistaken for
                # the outcome of this one.
                st.session_state.show = False
                st.warning("No entities found.")
            else:
                df['text'] = df['text'].apply(remove_trailing_punctuation)
                df['category'] = df['label'].map(reverse_category_mapping)
                st.session_state.results_df = df
                # Keep the exact text the offsets refer to; the text area may
                # be edited before the next rerun.
                st.session_state.analyzed_text = text
                st.session_state.elapsed = time.time() - start
                st.session_state.topics = perform_topic_modeling(df)
                st.session_state.show = True

if st.session_state.show:
    df = st.session_state.results_df
    analyzed_text = st.session_state.analyzed_text

    st.markdown("### 1. Extracted Entities")
    st.markdown(highlight_entities(analyzed_text, df), unsafe_allow_html=True)

    t1, t2, t3 = st.tabs(["Charts", "Network Map", "Topics"])
    with t1:
        fig_tree = px.treemap(
            df, path=['category', 'label', 'text'], values='score', title="Entity Treemap"
        )
        # The top margin keeps the title visible in the preview as well.
        fig_tree.update_layout(margin=dict(t=50))
        st.plotly_chart(fig_tree, use_container_width=True)
    with t2:
        st.plotly_chart(generate_network_graph(df, analyzed_text), use_container_width=True)
    with t3:
        if st.session_state.topics is not None:
            st.plotly_chart(
                create_topic_word_bubbles(st.session_state.topics),
                use_container_width=True,
            )
        else:
            st.info("Not enough data for topic modeling.")

    st.divider()
    st.markdown("### Download Artifacts")
    c1, c2, c3 = st.columns(3)
    with c1:
        st.download_button(
            "Download HTML Report",
            generate_html_report(
                df, analyzed_text, st.session_state.elapsed, st.session_state.topics
            ),
            "report.html",
            "text/html",
            type="primary",
        )
    with c2:
        csv = df.to_csv(index=False).encode('utf-8')
        st.download_button("Download CSV Data", csv, "entities.csv", "text/csv")
    with c3:
        st.download_button("Download PPTX Summary", generate_pptx_report(df), "summary.pptx")