import os
os.environ['HF_HOME'] = '/tmp'
import time
import streamlit as st
import streamlit.components.v1 as components
import pandas as pd
import io
import plotly.express as px
import plotly.graph_objects as go
import numpy as np
import re
import string
import json
from itertools import cycle
# --- PPTX Imports (Note: pptx must be installed via 'pip install python-pptx') ---
from io import BytesIO
import plotly.io as pio
# ---------------------------
# --- Stable Scikit-learn LDA Imports ---
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.decomposition import LatentDirichletAllocation
# ------------------------------
from gliner import GLiNER
from streamlit_extras.stylable_container import stylable_container
# Using a try/except for comet_ml import
try:
    from comet_ml import Experiment
except ImportError:
    # comet_ml is optional: provide a do-nothing stand-in exposing the same
    # methods the app calls, so logging calls become silent no-ops.
    class Experiment:
        """No-op replacement for comet_ml.Experiment when the package is absent."""

        def __init__(self, **kwargs):
            pass

        def log_parameter(self, *args):
            pass

        def log_table(self, *args):
            pass

        def end(self):
            pass
# --- Model Home Directory (Fix for deployment environments) ---
# NOTE(review): duplicate of the identical assignment at the top of the file;
# harmless, but one of the two could be removed.
os.environ['HF_HOME'] = '/tmp'
# --- Fixed Label Definitions and Mappings (Used as Fallback) ---
# Default entity label set used when the user supplies no custom labels.
FIXED_LABELS = ["person", "country", "city", "organization", "date", "time", "cardinal", "money", "position"]
# Stable highlight/legend color per fixed label.
FIXED_ENTITY_COLOR_MAP = {
    "person": "#10b981",  # Green
    "country": "#3b82f6",  # Blue
    "city": "#4ade80",  # Light Green
    "organization": "#f59e0b",  # Orange
    "date": "#8b5cf6",  # Purple
    "time": "#ec4899",  # Pink
    "cardinal": "#06b6d4",  # Cyan
    "money": "#f43f5e",  # Red
    "position": "#a855f7",  # Violet
}
# --- Fixed Category Mapping ---
# Groups the fixed labels into display categories for charts and tabs.
FIXED_CATEGORY_MAPPING = {
    "People & Roles": ["person", "organization", "position"],
    "Locations": ["country", "city"],
    "Time & Dates": ["date", "time"],
    "Numbers & Finance": ["money", "cardinal"]}
# Inverted lookup: label -> category name.
REVERSE_FIXED_CATEGORY_MAPPING = {label: category for category, label_list in FIXED_CATEGORY_MAPPING.items() for label in label_list}
# --- Dynamic Color Generator for Custom Labels ---
# Use Plotly's Alphabet set for a large pool of distinct colors
# (an endless cycle; consumed by get_dynamic_color_map for custom labels).
COLOR_PALETTE = cycle(px.colors.qualitative.Alphabet)
def extract_label(node_name):
    """Return the parenthesised label at the end of a node string like 'Text (Label)'.

    Falls back to "Unknown" when no trailing '(...)' group is present.
    """
    found = re.search(r'\(([^)]+)\)$', node_name)
    if found:
        return found.group(1)
    return "Unknown"
def remove_trailing_punctuation(text_string):
    """Strip any run of trailing ASCII punctuation characters from the string."""
    while text_string and text_string[-1] in string.punctuation:
        text_string = text_string[:-1]
    return text_string
def get_dynamic_color_map(active_labels, fixed_map):
    """Build a label -> color mapping for the active label set.

    Returns *fixed_map* verbatim when the active labels are exactly the
    default FIXED_LABELS. Otherwise each custom label reuses its fixed color
    when one exists, or draws the next color from the shared COLOR_PALETTE
    cycle.
    """
    if active_labels == FIXED_LABELS:
        return fixed_map
    # Dict comprehension preserves label order and consumes the palette
    # exactly once per label without a fixed color.
    return {
        label: fixed_map[label] if label in fixed_map else next(COLOR_PALETTE)
        for label in active_labels
    }
def highlight_entities(text, df_entities, entity_color_map):
    """Return HTML rendering *text* with each entity wrapped in a colored <span>.

    Parameters
    ----------
    text : str
        The full original document; entity 'start'/'end' offsets must index
        into this string.
    df_entities : pandas.DataFrame
        Entity rows with at least 'start', 'end' and 'label' columns.
    entity_color_map : dict
        Mapping label -> CSS color; unknown labels fall back to black.

    NOTE(review): the literal <span>/<div> markup was stripped in the
    checked-in file (it was a syntax error); this is a reconstruction —
    confirm styling against the originally rendered app.
    """
    if df_entities.empty:
        return text
    # Process entities right-to-left so earlier offsets stay valid as we splice.
    entities = df_entities.sort_values(by='start', ascending=False).to_dict('records')
    highlighted_text = text
    for entity in entities:
        # Clamp indices defensively to the document bounds.
        start = max(0, entity['start'])
        end = min(len(text), entity['end'])
        # Take the surface string from the full document via its offsets.
        entity_text_from_full_doc = text[start:end]
        label = entity['label']
        color = entity_color_map.get(label, '#000000')
        # Colored span with the label as a hover tooltip.
        highlight_html = (
            f'<span style="background-color: {color}; color: white; '
            f'padding: 2px 4px; border-radius: 4px;" '
            f'title="{label}">{entity_text_from_full_doc}</span>'
        )
        highlighted_text = highlighted_text[:start] + highlight_html + highlighted_text[end:]
    # Wrap in a div that mimics the Streamlit input-box look; the report
    # generator rewrites "div style" to attach a CSS class, so keep that order.
    return (
        '<div style="border: 1px solid #e0e0e0; border-radius: 8px; '
        'padding: 16px; background-color: #ffffff; line-height: 1.8;">'
        f'{highlighted_text}</div>'
    )
def perform_topic_modeling(df_entities, num_topics=2, num_top_words=10):
    """Run LDA over the unique entity strings and return a tidy DataFrame.

    Returns a DataFrame with columns Topic_ID / Word / Weight, or None when
    there are too few documents/features or when any stage of the sklearn
    pipeline fails (the caller treats None as "no topic data").
    """
    documents = df_entities['text'].unique().tolist()
    # Topic modeling normally prefers full sentences; here the extracted
    # entity strings themselves act as the corpus, per the original design.
    if len(documents) < 2:
        return None
    top_n = min(num_top_words, len(documents))
    try:
        vectorizer = TfidfVectorizer(max_df=0.95, min_df=2, stop_words='english', ngram_range=(1, 3))
        matrix = vectorizer.fit_transform(documents)
        feature_names = vectorizer.get_feature_names_out()
        if len(feature_names) < num_topics:
            # Relax the document-frequency limits for tiny corpora and retry.
            vectorizer = TfidfVectorizer(max_df=1.0, min_df=1, stop_words='english', ngram_range=(1, 3))
            matrix = vectorizer.fit_transform(documents)
            feature_names = vectorizer.get_feature_names_out()
            if len(feature_names) < num_topics:
                return None
        lda = LatentDirichletAllocation(n_components=num_topics, max_iter=5,
                                        learning_method='online', random_state=42, n_jobs=-1)
        lda.fit(matrix)
        rows = []
        for topic_idx, topic in enumerate(lda.components_):
            # Indices of the top_n highest-weighted vocabulary terms.
            best_indices = topic.argsort()[:-top_n - 1:-1]
            for i in best_indices:
                rows.append({
                    'Topic_ID': f'Topic #{topic_idx + 1}',
                    'Word': feature_names[i],
                    'Weight': topic[i],
                })
        return pd.DataFrame(rows)
    except Exception:
        # Best-effort: any sklearn failure simply yields "no topics".
        return None
def create_topic_word_bubbles(df_topic_data):
    """Generate a Plotly bubble chart of top-word weights across all topics.

    Expects the DataFrame produced by perform_topic_modeling (columns
    Topic_ID / Word / Weight). Returns None when there is nothing to plot.

    Fixes: the hovertemplate string literal was broken across physical lines
    (its <br> tags were stripped, leaving a syntax error); the empty-input
    check now runs before any column manipulation.
    """
    if df_topic_data.empty:
        return None
    df_topic_data = df_topic_data.rename(
        columns={'Topic_ID': 'topic', 'Word': 'word', 'Weight': 'weight'})
    # Spread bubbles along x by row index; the axis itself is hidden below.
    df_topic_data['x_pos'] = df_topic_data.index
    fig = px.scatter(
        df_topic_data,
        x='x_pos', y='weight', size='weight', color='topic', text='word',
        hover_name='word', size_max=40,
        title='Topic Word Weights (Bubble Chart)',
        color_discrete_sequence=px.colors.qualitative.Bold,
        labels={'x_pos': 'Entity/Word Index', 'weight': 'Word Weight', 'topic': 'Topic ID'},
        custom_data=['word', 'weight', 'topic']
    )
    fig.update_layout(
        xaxis_title="Entity/Word", yaxis_title="Word Weight",
        xaxis={'showgrid': False, 'showticklabels': False, 'zeroline': False, 'showline': False},
        yaxis={'showgrid': True},
        showlegend=True, height=600,
        margin=dict(t=50, b=100, l=50, r=10),
        plot_bgcolor='#f9f9f9', paper_bgcolor='#f9f9f9'
    )
    fig.update_traces(
        textposition='middle center',
        textfont=dict(color='white', size=10),
        # HTML <br> line breaks restore the intended multi-line tooltip.
        hovertemplate="%{customdata[0]}<br>Weight: %{customdata[1]:.3f}<br>Topic: %{customdata[2]}",
        marker=dict(line=dict(width=1, color='DarkSlateGrey'))
    )
    return fig
def generate_network_graph(df, raw_text, entity_color_map):
    """Render a network (node) plot of unique entities.

    Nodes are unique (text, label) entities placed on a jittered circle;
    edges connect entities whose surface strings co-occur within the same
    sentence of *raw_text*.

    NOTE(review): the edge-building code, sentence-split regex and node trace
    were corrupted in the checked-in file; this body reconstructs them from
    the surviving fragments (hovertemplate fields, legend loop, layout) —
    verify edge styling against the originally rendered app.
    """
    # Per-string frequency, merged back onto the unique (text, label) rows.
    entity_counts = df['text'].value_counts().reset_index()
    entity_counts.columns = ['text', 'frequency']
    unique_entities = df.drop_duplicates(subset=['text', 'label']).merge(entity_counts, on='text')
    if unique_entities.shape[0] < 2:
        return go.Figure().update_layout(title="Not enough unique entities for a meaningful graph.")
    # Circular layout with Gaussian jitter so nodes don't overlap perfectly.
    num_nodes = len(unique_entities)
    thetas = np.linspace(0, 2 * np.pi, num_nodes, endpoint=False)
    radius = 10
    unique_entities['x'] = radius * np.cos(thetas) + np.random.normal(0, 0.5, num_nodes)
    unique_entities['y'] = radius * np.sin(thetas) + np.random.normal(0, 0.5, num_nodes)
    pos_map = unique_entities.set_index('text')[['x', 'y']].to_dict('index')
    # --- Edges: connect entities appearing in the same sentence ---
    edges = set()
    sentences = re.split(r'(?<=[.!?])\s+', raw_text)
    entity_texts = list(dict.fromkeys(unique_entities['text']))
    for sentence in sentences:
        present = [t for t in entity_texts if t in sentence]
        for i in range(len(present)):
            for j in range(i + 1, len(present)):
                # Canonical (sorted) pair so A-B and B-A dedupe.
                edges.add(tuple(sorted((present[i], present[j]))))
    fig = go.Figure()
    # Draw all edges as one trace, segments separated by None gaps.
    edge_x, edge_y = [], []
    for a, b in edges:
        edge_x += [pos_map[a]['x'], pos_map[b]['x'], None]
        edge_y += [pos_map[a]['y'], pos_map[b]['y'], None]
    if edge_x:
        fig.add_trace(go.Scatter(
            x=edge_x, y=edge_y, mode='lines',
            line=dict(width=0.5, color='#888888'),
            hoverinfo='skip', showlegend=False))
    # Node trace: marker size scales with frequency, color from the label map.
    node_colors = [entity_color_map.get(lbl, '#cccccc') for lbl in unique_entities['label']]
    fig.add_trace(go.Scatter(
        x=unique_entities['x'], y=unique_entities['y'],
        mode='markers+text',
        text=unique_entities['text'],
        textposition='top center',
        marker=dict(size=10 + unique_entities['frequency'] * 2, color=node_colors,
                    line=dict(width=1, color='DarkSlateGrey')),
        customdata=unique_entities[['label', 'score', 'frequency']],
        showlegend=False,
        hovertemplate=("%{text}<br>Label: %{customdata[0]}"
                       "<br>Score: %{customdata[1]:.2f}"
                       "<br>Frequency: %{customdata[2]}")
    ))
    # One invisible marker per label purely to populate the legend.
    legend_traces = []
    seen_labels = set()
    for index, row in unique_entities.iterrows():
        label = row['label']
        if label not in seen_labels:
            seen_labels.add(label)
            color = entity_color_map.get(label, '#cccccc')
            legend_traces.append(go.Scatter(
                x=[None], y=[None], mode='markers',
                marker=dict(size=10, color=color),
                name=f"{label.capitalize()}", showlegend=True))
    for trace in legend_traces:
        fig.add_trace(trace)
    fig.update_layout(
        title='Entity Co-occurrence Network (Edges = Same Sentence)',
        showlegend=True, hovermode='closest',
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False, range=[-15, 15]),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False, range=[-15, 15]),
        plot_bgcolor='#f9f9f9', paper_bgcolor='#f9f9f9',
        margin=dict(t=50, b=10, l=10, r=10), height=600
    )
    return fig
# --- CSV GENERATION FUNCTION ---
def generate_entity_csv(df):
    """Serialize the entity columns of *df* to CSV inside an in-memory byte buffer.

    Returns a BytesIO positioned at the start, ready for a download handler.
    """
    export_columns = ['text', 'label', 'category', 'score', 'start', 'end']
    csv_bytes = df[export_columns].to_csv(index=False).encode('utf-8')
    # BytesIO(initial_bytes) starts at position 0, equivalent to write+seek(0).
    return BytesIO(csv_bytes)
# -----------------------------------
# --- HTML REPORT GENERATION FUNCTION (MODIFIED FOR WHITE-LABEL) ---
def generate_html_report(df, text_input, elapsed_time, df_topic_data, entity_color_map, report_title="Entity and Topic Analysis Report", branding_html=""):
    """
    Generate a single self-contained HTML report with all analysis results.

    Parameters
    ----------
    df : pandas.DataFrame
        Extracted entities ('text', 'label', 'score', 'start', 'end', 'category').
    text_input : str
        The analyzed document (used for highlighting and the network graph).
    elapsed_time : float
        Seconds taken by the extraction run, shown in the report header.
    df_topic_data : pandas.DataFrame or None
        Output of perform_topic_modeling; None disables the topic section.
    entity_color_map : dict
        Label -> CSS color, shared with the app for consistent visuals.
    report_title, branding_html : str
        White-label hooks: custom <h1>/<title> text and an HTML snippet
        injected at the top of the body.

    NOTE(review): the literal HTML fragments in this function were stripped
    in the checked-in file (several string literals were syntax errors); the
    markup below is a reconstruction — confirm styling against the
    originally shipped report.
    """
    # 1. Visualizations (Plotly -> embeddable HTML snippets)
    # 1a. Treemap
    fig_treemap = px.treemap(
        df,
        path=[px.Constant("All Entities"), 'category', 'label', 'text'],
        values='score',
        color='category',
        title="Entity Distribution by Category and Label",
        color_discrete_sequence=px.colors.qualitative.Dark24
    )
    fig_treemap.update_layout(margin=dict(t=50, l=25, r=25, b=25))
    treemap_html = fig_treemap.to_html(full_html=False, include_plotlyjs='cdn')
    # 1b. Pie Chart
    grouped_counts = df['category'].value_counts().reset_index()
    grouped_counts.columns = ['Category', 'Count']
    color_seq = px.colors.qualitative.Pastel if len(grouped_counts) > 1 else px.colors.sequential.Cividis
    fig_pie = px.pie(grouped_counts, values='Count', names='Category', title='Distribution of Entities by Category', color_discrete_sequence=color_seq)
    fig_pie.update_layout(margin=dict(t=50, b=10))
    pie_html = fig_pie.to_html(full_html=False, include_plotlyjs='cdn')
    # 1c. Bar Chart (Category Count)
    fig_bar_category = px.bar(grouped_counts, x='Category', y='Count', color='Category', title='Total Entities per Category', color_discrete_sequence=color_seq)
    fig_bar_category.update_layout(xaxis={'categoryorder': 'total descending'}, margin=dict(t=50, b=100))
    bar_category_html = fig_bar_category.to_html(full_html=False, include_plotlyjs='cdn')
    # 1d. Bar Chart (Most Frequent Entities)
    word_counts = df['text'].value_counts().reset_index()
    word_counts.columns = ['Entity', 'Count']
    repeating_entities = word_counts[word_counts['Count'] > 1].head(10)
    bar_freq_html = '<p>No entities appear more than once in the text for visualization.</p>'
    if not repeating_entities.empty:
        fig_bar_freq = px.bar(repeating_entities, x='Entity', y='Count', color='Entity', title='Top 10 Most Frequent Entities', color_discrete_sequence=px.colors.sequential.Viridis)
        fig_bar_freq.update_layout(xaxis={'categoryorder': 'total descending'}, margin=dict(t=50, b=100))
        bar_freq_html = fig_bar_freq.to_html(full_html=False, include_plotlyjs='cdn')
    # 1e. Network graph — pass the app's color map for consistency.
    network_fig = generate_network_graph(df, text_input, entity_color_map)
    network_html = network_fig.to_html(full_html=False, include_plotlyjs='cdn')
    # 1f. Topic charts
    topic_charts_html = '<h3>Topic Word Weights (Bubble Chart)</h3>'
    if df_topic_data is not None and not df_topic_data.empty:
        bubble_figure = create_topic_word_bubbles(df_topic_data)
        if bubble_figure:
            topic_charts_html += bubble_figure.to_html(full_html=False, include_plotlyjs="cdn", config={"responsive": True})
        else:
            topic_charts_html += '<p>Error: Topic modeling data was available but visualization failed.</p>'
    else:
        topic_charts_html += (
            '<div style="border: 1px solid #3b82f6; border-radius: 8px; padding: 16px;">'
            '<p><strong>Topic Modeling requires more unique input.</strong></p>'
            '<p>Please enter text containing at least two unique entities to generate the Topic Bubble Chart.</p>'
            '</div>'
        )
    # 2. Highlighted text, restyled for the report via a CSS class hook.
    highlighted_text_html = highlight_entities(text_input, df, entity_color_map).replace("div style", "div class='highlighted-text' style")
    # 3. Entity table
    entity_table_html = df[['text', 'label', 'score', 'start', 'end', 'category']].to_html(
        classes='table table-striped',
        index=False
    )
    # 4. Final document (white-label: branding header + custom title)
    html_content = f"""<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<title>{report_title}</title>
<style>
body {{ font-family: Arial, sans-serif; margin: 40px; color: #222; }}
h1, h2, h3 {{ color: #1f2937; }}
.table {{ border-collapse: collapse; width: 100%; }}
.table th, .table td {{ border: 1px solid #ddd; padding: 8px; text-align: left; }}
.highlighted-text {{ line-height: 1.8; }}
</style>
</head>
<body>
{branding_html}
<h1>{report_title}</h1>
<p>Analysis completed in {elapsed_time:.2f} seconds.</p>
<h2>1. Analyzed Text &amp; Extracted Entities</h2>
<h3>Original Text with Highlighted Entities</h3>
{highlighted_text_html}
<h2>2. Full Extracted Entities Table</h2>
{entity_table_html}
<h2>3. Data Visualizations</h2>
<h3>3.1 Entity Distribution Treemap</h3>
{treemap_html}
<h3>3.2 Comparative Charts (Pie, Category Count, Frequency) - <em>Stacked Vertically</em></h3>
{pie_html}
{bar_category_html}
<h3>3.3 Entity Relationship Map (Edges = Same Sentence)</h3>
{network_html}
<h2>4. Topic Modelling</h2>
{topic_charts_html}
<h3>3.4 Most Frequent Entities</h3>
{bar_freq_html}
</body>
</html>"""
    return html_content
# --- CHUNKING IMPLEMENTATION FOR LARGE TEXT ---
def chunk_text(text, max_chunk_size=1500):
    """Split *text* into (chunk, start_offset) pairs bounded by max_chunk_size chars.

    Splitting happens on paragraph breaks (double newline) or whitespace that
    follows sentence-ending punctuation. The capturing group keeps the
    separators in the stream, so accumulated offsets are exact character
    positions into the original text.
    """
    pieces = re.split(r'(\n\n|(?<=[.!?])\s+)', text)
    result = []
    buffer = ""
    offset = 0
    for piece in pieces:
        if not piece:
            continue
        if buffer and len(buffer) + len(piece) > max_chunk_size:
            # Flush the full buffer; the next chunk starts where it ended.
            result.append((buffer, offset))
            offset += len(buffer)
            buffer = piece
        else:
            buffer += piece
    if buffer:
        result.append((buffer, offset))
    return result
def process_chunked_text(text, labels, model):
    """Run NER over *text* in chunks and return entities with document-level offsets.

    The GLiNER context window is limited, so long inputs are split with
    chunk_text() and each chunk is predicted separately; each entity's
    start/end indices are then shifted by its chunk's offset so they index
    into the original document.

    Fix: the loop variable was previously named `chunk_text`, shadowing the
    helper function of the same name inside this scope; renamed to chunk_str.
    """
    # ~3500 chars stays safely inside the model's effective context size
    # (roughly 500 words per chunk).
    MAX_CHUNK_CHARS = 3500
    chunks = chunk_text(text, max_chunk_size=MAX_CHUNK_CHARS)
    all_entities = []
    for chunk_str, chunk_offset in chunks:
        # Predict entities on the small chunk.
        chunk_entities = model.predict_entities(chunk_str, labels)
        for entity in chunk_entities:
            # Shift chunk-relative indices to document-relative ones.
            entity['start'] += chunk_offset
            entity['end'] += chunk_offset
            all_entities.append(entity)
    return all_entities
# -----------------------------------
# --- Page Configuration and Styling (No Sidebar) ---
st.set_page_config(layout="wide", page_title="NER & Topic Report App")
# --- Conditional Mobile Warning ---
# Shown unconditionally; the tip is only relevant on small screens.
st.markdown(
    """
β οΈ **Tip for Mobile Users:** For the best viewing experience of the charts and tables, please switch your browser to **"Desktop Site"** view.
""",
    unsafe_allow_html=True)
# --- Topic Modeling Settings (Moved to main body; sidebar header removed) ---
st.subheader("Entity and Topic Analysis Report Generator", divider="blue")  # Changed divider from "rainbow" to "blue"
# Link button removed for white-labeling.
tab1, tab2 = st.tabs(["Embed", "Important Notes"])
with tab1:
    with st.expander("Embed"):
        st.write("Use the following code to embed the DataHarvest web app on your website. Feel free to adjust the width and height values to fit your page.")
        # NOTE(review): the embed snippet below is empty — presumably an
        # <iframe> tag was stripped during extraction; confirm the original.
        code = '''
'''
        st.code(code, language="html")
with tab2:
    expander = st.expander("**Important Notes**")
    expander.markdown("""
**Named Entities (Fixed Mode):** This DataHarvest web app predicts nine (9) labels: "person", "country", "city", "organization", "date", "time", "cardinal", "money", "position".
**Custom Labels Mode:** You can define your own comma-separated labels (e.g., `product, symptom, client_id`) in the input box below.
**Results:** Results are compiled into a single, comprehensive **HTML report** and a **CSV file** for easy download and sharing.
**How to Use:** Type or paste your text into the text area below, then click the 'Results' button.
""")
st.markdown("For any errors or inquiries, please contact us at [info@your-company.com](mailto:info@your-company.com)")  # Updated contact info
# --- Comet ML Setup (Placeholder/Conditional) ---
# All three env vars must be present for Comet logging to be considered enabled.
COMET_API_KEY = os.environ.get("COMET_API_KEY")
COMET_WORKSPACE = os.environ.get("COMET_WORKSPACE")
COMET_PROJECT_NAME = os.environ.get("COMET_PROJECT_NAME")
comet_initialized = bool(COMET_API_KEY and COMET_WORKSPACE and COMET_PROJECT_NAME)
# --- Model Loading ---
@st.cache_resource
def load_ner_model(labels):
    """Load and cache the GLiNER multitask NER model.

    Args:
        labels: entity labels passed as generation constraints at load time,
            as this checkpoint requires.

    Halts the Streamlit script via st.stop() if loading fails.
    """
    try:
        # The model requires constraints (labels) to be passed during loading
        return GLiNER.from_pretrained("knowledgator/gliner-multitask-large-v0.5", nested_ner=True, num_gen_sequences=2, gen_constraints=labels)
    except Exception as e:
        # Log the actual error to the console for debugging
        print(f"FATAL ERROR: Failed to load NER model: {e}")
        st.error(f"Failed to load NER model. This may be due to a dependency issue or resource limits: {e}")
        st.stop()
# --- LONG DEFAULT TEXT (178 Words) ---
# Sample document pre-filling the text area; exercises many of the fixed labels.
DEFAULT_TEXT = (
    "In June 2024, the founder, Dr. Emily Carter, officially announced a new, expansive partnership between "
    "TechSolutions Inc. and the European Space Agency (ESA). This strategic alliance represents a significant "
    "leap forward for commercial space technology across the entire **European Union**. The agreement, finalized "
    "on Monday in Paris, France, focuses specifically on jointly developing the next generation of the 'Astra' "
    "software platform. This version of the **Astra** platform is critical for processing and managing the vast amounts of data being sent "
    "back from the recent Mars rover mission. This project underscores the ESA's commitment to advancing "
    "space capabilities within the **European Union**. The core team, including lead engineer Marcus Davies, will hold "
    "their first collaborative workshop in Berlin, Germany, on August 15th. The community response on social "
    "media platform X (under the username @TechCEO) was overwhelmingly positive, with many major tech "
    "publications, including Wired Magazine, predicting a major impact on the space technology industry by the "
    "end of the year, further strengthening the technological standing of the **European Union**. The platform is designed to be compatible with both Windows and Linux operating systems. "
    "The initial funding, secured via a Series B round, totaled $50 million. Financial analysts from Morgan Stanley "
    "are closely monitoring the impact on TechSolutions Inc.'s Q3 financial reports, expected to be released to the "
    "general public by October 1st. The goal is to deploy the **Astra** v2 platform before the next solar eclipse event in 2026.")
# -----------------------------------
# --- Session State Initialization (CRITICAL FIX) ---
# One guarded assignment per key so Streamlit reruns never clobber existing state.
if 'show_results' not in st.session_state: st.session_state.show_results = False
if 'last_text' not in st.session_state: st.session_state.last_text = ""
if 'results_df' not in st.session_state: st.session_state.results_df = pd.DataFrame()
if 'elapsed_time' not in st.session_state: st.session_state.elapsed_time = 0.0
if 'topic_results' not in st.session_state: st.session_state.topic_results = None
if 'my_text_area' not in st.session_state: st.session_state.my_text_area = DEFAULT_TEXT
if 'custom_labels_input' not in st.session_state: st.session_state.custom_labels_input = ""
if 'active_labels_list' not in st.session_state: st.session_state.active_labels_list = FIXED_LABELS
if 'is_custom_mode' not in st.session_state: st.session_state.is_custom_mode = False
# Initialize Topic Model settings in state, so they can be set even if not using the sidebar
if 'num_topics_slider' not in st.session_state: st.session_state.num_topics_slider = 5
if 'num_top_words_slider' not in st.session_state: st.session_state.num_top_words_slider = 10
if 'last_num_topics' not in st.session_state: st.session_state.last_num_topics = None
if 'last_num_top_words' not in st.session_state: st.session_state.last_num_top_words = None
# --- Clear Button Function (MODIFIED) ---
def clear_text():
    """Reset the text area to empty and wipe all cached analysis results."""
    # Item-style access is equivalent to attribute access on session_state.
    st.session_state['my_text_area'] = ""
    st.session_state['show_results'] = False
    st.session_state['last_text'] = ""
    st.session_state['results_df'] = pd.DataFrame()
    st.session_state['elapsed_time'] = 0.0
    st.session_state['topic_results'] = None
# --- Text Input and Clear Button ---
word_limit = 10000  # Updated to 10000
text = st.text_area(
    f"Type or paste your text below (max {word_limit} words), and then press Ctrl + Enter",
    height=250,
    key='my_text_area',
)
# Simple whitespace-based word count displayed under the box.
word_count = len(text.split())
st.markdown(f"**Word count:** {word_count}/{word_limit}")
# --- Custom Labels Input ---
custom_labels_text = st.text_area(
    "**Optional:** Enter your own comma-separated entity labels here (e.g., `product, symptom, client_id`). Leave blank for default labels.",
    height=60,
    key='custom_labels_input',
    placeholder="e.g., product, symptom, client_id"  # Show placeholder after the prompt
)
# Use columns to align the buttons neatly
col_results, col_clear = st.columns([1, 1])
with col_results:
    run_button = st.button("Results", key='run_results', use_container_width=True)
with col_clear:
    st.button("Clear text", on_click=clear_text, use_container_width=True)
# --- Results Trigger and Processing (Completed Logic with Chunking and Topic Vars) ---
if run_button:
    # 1. Determine Active Labels and Mode
    custom_labels_raw = st.session_state.custom_labels_input
    if custom_labels_raw.strip():
        # Sanitize and parse custom labels
        custom_labels_list = [label.strip().lower() for label in custom_labels_raw.split(',') if label.strip()]
        if not custom_labels_list:
            # Fallback if user enters commas but no actual words
            st.session_state.active_labels_list = FIXED_LABELS
            st.session_state.is_custom_mode = False
            st.info("No valid custom labels found. Falling back to default fixed labels.")
        else:
            st.session_state.active_labels_list = custom_labels_list
            st.session_state.is_custom_mode = True
    else:
        st.session_state.active_labels_list = FIXED_LABELS
        st.session_state.is_custom_mode = False
    active_labels = st.session_state.active_labels_list
    # Guard clauses: empty input and over-limit input both abort the run.
    if not text.strip():
        st.warning("Please enter some text to extract entities.")
        st.session_state.show_results = False
    elif word_count > word_limit:
        st.warning(f"Your text exceeds the {word_limit} word limit. Please shorten it to continue.")
        st.session_state.show_results = False
    else:
        # Define a safe threshold for when to start chunking (e.g., above 500 words)
        CHUNKING_THRESHOLD = 500
        should_chunk = word_count > CHUNKING_THRESHOLD
        mode_msg = f"{'custom' if st.session_state.is_custom_mode else 'fixed'} labels"
        if should_chunk:
            mode_msg += " with **chunking** for large text"
        # --- Topic Modeling Input Retrieval ---
        # The sliders are only visible after results are shown, so use the
        # state defaults / last successfully run values here.
        current_num_topics = st.session_state.num_topics_slider
        current_num_top_words = st.session_state.num_top_words_slider
        with st.spinner(f"Extracting entities using {mode_msg}...", show_time=True):
            # Re-run prediction only if text, active labels, OR topic parameters have changed
            current_settings = (text, tuple(active_labels), current_num_topics, current_num_top_words)
            last_settings = (
                st.session_state.last_text,
                tuple(st.session_state.get('last_active_labels', [])),
                st.session_state.get('last_num_topics', None),
                st.session_state.get('last_num_top_words', None)
            )
            if current_settings != last_settings:
                start_time = time.time()
                ner_model = load_ner_model(labels=active_labels)
                # 2. Perform NER Extraction (chunked path for long documents)
                if should_chunk:
                    all_entities_list = process_chunked_text(text, active_labels, ner_model)
                else:
                    all_entities_list = ner_model.predict_entities(text, active_labels)
                df = pd.DataFrame(all_entities_list)
                if df.empty:
                    df_topic_data = None
                else:
                    # 3. Add Category Mapping (custom labels fall into a catch-all bucket)
                    df['category'] = df['label'].apply(
                        lambda l: REVERSE_FIXED_CATEGORY_MAPPING.get(l, "User Defined Entities")
                    )
                    # 4. Perform Topic Modeling with the current slider values
                    df_topic_data = perform_topic_modeling(
                        df_entities=df,
                        num_topics=current_num_topics,
                        num_top_words=current_num_top_words
                    )
                end_time = time.time()
                elapsed_time = end_time - start_time
                # 5. Save Results to Session State so reruns can reuse them
                st.session_state.results_df = df
                st.session_state.topic_results = df_topic_data
                st.session_state.elapsed_time = elapsed_time
                st.session_state.last_text = text
                st.session_state.show_results = True
                st.session_state.last_active_labels = active_labels
                st.session_state.last_num_topics = current_num_topics  # Save topic settings
                st.session_state.last_num_top_words = current_num_top_words  # Save topic settings
            else:
                # Nothing changed since the last run: just re-show cached results.
                st.info("Results already calculated for the current text and settings.")
                st.session_state.show_results = True
# --- Display Download Link and Results (Updated with White-Label inputs) ---
if st.session_state.show_results:
df = st.session_state.results_df
# Note: Topic data needs to be re-run if the sliders change, but here we reuse the state value unless the re-run button is hit.
# To fix this, we need to handle the Topic Modeling calculation separately so that changing the slider triggers a run without hitting the main 'Results' button.
# --- Topic Model Slider Re-Run Logic (New Block) ---
st.markdown("---")
st.markdown("### 4. Advanced Analysis")
st.markdown("π‘ **Topic Modeling Settings:** Adjust these sliders and click **'Re-Run Topic Model'** to see instant changes.")
col_slider_topic, col_slider_words, col_rerun_btn = st.columns([1, 1, 0.5])
with col_slider_topic:
new_num_topics = st.slider(
"Number of Topics",
min_value=2,
max_value=10,
value=st.session_state.num_topics_slider,
step=1,
key='num_topics_slider_new',
help="The number of topics to discover (2 to 10)."
)
with col_slider_words:
new_num_top_words = st.slider(
"Number of Top Words",
min_value=5,
max_value=20,
value=st.session_state.num_top_words_slider,
step=1,
key='num_top_words_slider_new',
help="The number of top words to display per topic (5 to 20)."
)
def rerun_topic_model():
    """Recompute ONLY the LDA topic model with the current slider values.

    Used as the ``on_click`` callback of the 'Re-Run Topic Model' button so
    the expensive entity-extraction step is not repeated. Reads the freshly
    chosen values from the ``*_new`` slider keys, persists them into the
    canonical session-state keys, and refreshes ``topic_results``.
    """
    # Promote the new slider values to the canonical session-state keys.
    st.session_state.num_topics_slider = st.session_state.num_topics_slider_new
    st.session_state.num_top_words_slider = st.session_state.num_top_words_slider_new
    # Only recompute (and only report success) when entity results exist;
    # with an empty DataFrame there is nothing to model.
    if not st.session_state.results_df.empty:
        st.session_state.topic_results = perform_topic_modeling(
            df_entities=st.session_state.results_df,
            num_topics=st.session_state.num_topics_slider,
            num_top_words=st.session_state.num_top_words_slider,
        )
        st.session_state.last_num_topics = st.session_state.num_topics_slider
        st.session_state.last_num_top_words = st.session_state.num_top_words_slider
        st.success("Topic Model Re-Run Complete!")
    # Streamlit reruns the script after the callback, displaying the updated state.
with col_rerun_btn:
    # Empty markdown element acts as a spacer so the button lines up
    # vertically with the sliders to its left.
    st.markdown("", unsafe_allow_html=True)
    st.button(
        "Re-Run Topic Model",
        on_click=rerun_topic_model,
        use_container_width=True,
        type="primary",
    )
# Pull the (possibly just-recomputed) topic model results for display below.
df_topic_data = st.session_state.topic_results
# --- End Topic Model Slider Re-Run Logic ---
# Build a color map covering every label present; falls back to dynamic
# colors for labels not in the fixed palette.
entity_color_map = get_dynamic_color_map(df['label'].unique().tolist(), FIXED_ENTITY_COLOR_MAP)
# Nothing extracted: show a warning instead of the full results layout.
if df.empty:
st.warning("No entities were found in the provided text with the current label set.")
else:
st.subheader("Analysis Results", divider="blue")
# 1. Highlighted Text
st.markdown(f"### 1. Analyzed Text with Highlighted Entities ({'Custom Mode' if st.session_state.is_custom_mode else 'Fixed Mode'})")
# highlight_entities returns HTML spans, hence unsafe_allow_html.
st.markdown(highlight_entities(st.session_state.last_text, df, entity_color_map), unsafe_allow_html=True)
# 2. Detailed Entity Analysis Tabs
st.markdown("### 2. Detailed Entity Analysis")
tab_category_details, tab_treemap_viz = st.tabs(["π Entities Grouped by Category", "πΊοΈ Treemap Distribution"])
# Determine which categories to use for the tabs
if st.session_state.is_custom_mode:
unique_categories = ["User Defined Entities"]
tabs_to_show = df['label'].unique().tolist()
st.markdown(f"**Custom Labels Detected: {', '.join(tabs_to_show)}**")
else:
unique_categories = list(FIXED_CATEGORY_MAPPING.keys())
# --- Section 2a: Detailed Tables by Category/Label ---
with tab_category_details:
st.markdown("#### Detailed Entities Table (Grouped by Category)")
if st.session_state.is_custom_mode:
# In custom mode, group by the actual label since the category is just "User Defined Entities"
tabs_list = df['label'].unique().tolist()
tabs_category = st.tabs(tabs_list)
for label, tab in zip(tabs_list, tabs_category):
# One tab per custom label, rows sorted by confidence descending.
df_label = df[df['label'] == label][['text', 'label', 'score', 'start', 'end']].sort_values(by='score', ascending=False)
with tab:
st.markdown(f"##### {label.capitalize()} Entities ({len(df_label)} total)")
st.dataframe(
df_label,
use_container_width=True,
column_config={'score': st.column_config.NumberColumn(format="%.4f")}
)
else:
# In fixed mode, group by the category defined in FIXED_CATEGORY_MAPPING
tabs_category = st.tabs(unique_categories)
for category, tab in zip(unique_categories, tabs_category):
df_category = df[df['category'] == category][['text', 'label', 'score', 'start', 'end']].sort_values(by='score', ascending=False)
with tab:
st.markdown(f"##### {category} Entities ({len(df_category)} total)")
# Fixed-mode categories may legitimately be empty; show an info note.
if not df_category.empty:
st.dataframe(
df_category,
use_container_width=True,
column_config={'score': st.column_config.NumberColumn(format="%.4f")}
)
else:
st.info(f"No entities of category **{category}** were found in the text.")
# --- INSERTED GLOSSARY HERE ---
# Fix: the original glossary string had no newlines between the bullet
# items, so Markdown rendered all six entries as a single run-on line.
with st.expander("See Glossary of tags"):
    st.write('''
- **text**: entity extracted from your text data
- **label**: label (tag) assigned to a given extracted entity (custom or fixed)
- **category**: the grouping category (e.g., "Locations" or "User Defined Entities")
- **score**: accuracy score; how accurately a tag has been assigned to a given entity
- **start**: index of the start of the corresponding entity
- **end**: index of the end of the corresponding entity
''')
# --- END GLOSSARY INSERTION ---
# --- Section 2b: Treemap Visualization ---
with tab_treemap_viz:
    st.markdown("#### Treemap: Entity Distribution")
    # Hierarchy: root -> category -> label -> entity text; tile area is
    # proportional to the confidence score, colored per category.
    treemap_fig = px.treemap(
        df,
        path=[px.Constant("All Entities"), 'category', 'label', 'text'],
        values='score',
        color='category',
        color_discrete_sequence=px.colors.qualitative.Dark24,
    )
    treemap_fig.update_layout(margin=dict(t=10, l=10, r=10, b=10))
    st.plotly_chart(treemap_fig, use_container_width=True)
# --- Section 3: Comparative Charts ---
st.markdown("---")
st.markdown("### 3. Comparative Charts")
col1, col2, col3 = st.columns(3)
# Entity counts per category feed the pie and category-bar charts.
grouped_counts = df['category'].value_counts().reset_index()
grouped_counts.columns = ['Category', 'Count']
# Qualitative palette when several categories exist; sequential otherwise.
if len(grouped_counts) > 1:
    chart_color_seq = px.colors.qualitative.Pastel
else:
    chart_color_seq = px.colors.sequential.Cividis
with col1:
    # Pie chart: relative share of entities per category.
    fig_pie = px.pie(
        grouped_counts,
        values='Count',
        names='Category',
        title='Distribution of Entities by Category',
        color_discrete_sequence=chart_color_seq,
    )
    fig_pie.update_layout(margin=dict(t=30, b=10, l=10, r=10), height=350)
    st.plotly_chart(fig_pie, use_container_width=True)
with col2:
    # Bar chart: absolute entity count per category.
    st.markdown("#### Entity Count by Category")
    fig_bar_category = px.bar(
        grouped_counts,
        x='Category',
        y='Count',
        color='Category',
        title='Total Entities per Category',
        color_discrete_sequence=chart_color_seq,
    )
    fig_bar_category.update_layout(
        margin=dict(t=30, b=10, l=10, r=10), height=350, showlegend=False
    )
    st.plotly_chart(fig_bar_category, use_container_width=True)
with col3:
    # Bar chart: the ten most frequent entities, counting repeats only.
    st.markdown("#### Top 10 Most Frequent Entities")
    word_counts = df['text'].value_counts().reset_index()
    word_counts.columns = ['Entity', 'Count']
    repeating_entities = word_counts[word_counts['Count'] > 1].head(10)
    if repeating_entities.empty:
        st.info("No entities were repeated enough for a Top 10 frequency chart.")
    else:
        fig_bar_freq = px.bar(
            repeating_entities,
            x='Entity',
            y='Count',
            title='Top 10 Most Frequent Entities',
            color='Entity',
            color_discrete_sequence=px.colors.sequential.Viridis,
        )
        fig_bar_freq.update_layout(
            margin=dict(t=30, b=10, l=10, r=10), height=350, showlegend=False
        )
        st.plotly_chart(fig_bar_freq, use_container_width=True)
# 4. Network Graph and Topic Modeling, rendered side by side in two columns.
col_network, col_topic = st.columns(2)
with col_network:
    with st.expander("π Entity Co-occurrence Network Graph", expanded=True):
        network_fig = generate_network_graph(
            df, st.session_state.last_text, entity_color_map
        )
        st.plotly_chart(network_fig, use_container_width=True)
with col_topic:
    with st.expander("π‘ Topic Modeling (LDA)", expanded=True):
        # Show which LDA settings produced the currently displayed result.
        st.markdown(f"""
**Current LDA Parameters:**
* Topics: **{st.session_state.last_num_topics}**
* Top Words: **{st.session_state.last_num_top_words}**
""")
        if df_topic_data is None or df_topic_data.empty:
            st.info("Topic Modeling requires at least two unique entities with a minimum frequency to perform statistical analysis.")
        else:
            st.plotly_chart(create_topic_word_bubbles(df_topic_data), use_container_width=True)
            st.markdown("This chart visualizes the key words driving the identified topics, based on extracted entities.")
# --- 5. White-Label Configuration (NEW SECTION FOR CUSTOM BRANDING) ---
st.markdown("---")
# Fix: restored the 🎨 emoji that had been corrupted by a mis-decoded
# UTF-8 byte sequence in the heading string.
st.markdown("### 5. White-Label Report Configuration 🎨")
# Default title reflects whether custom or fixed labels were used for the run.
default_report_title = f"{'Custom' if st.session_state.is_custom_mode else 'Fixed'} Entity Analysis Report"
custom_report_title = st.text_input(
    "Type Your Report Title (for HTML Report), and then press Enter.",
    value=default_report_title,
)
# Plain-text branding input; it is wrapped in styled HTML just before the
# HTML report is generated (see the branding wrapper below).
custom_branding_text_input = st.text_area(
    "Type Your Brand Name or Tagline (Appears below the title in the report), and then press Enter.",
    value="Analysis powered by My Own Brand",
    key='custom_branding_input',
    help="Enter your brand name or a short tagline. This text will be automatically styled and included below the main title.",
)
# 6. Downloads (Updated to pass custom variables)
st.markdown("---")
st.markdown("### 6. Downloads")
# Two side-by-side download options: raw entity CSV and full HTML report.
col_csv, col_html = st.columns(2)
# CSV Download
csv_buffer = generate_entity_csv(df)
with col_csv:
    st.download_button(
        # Fix: restored the ⬇️ emoji that had been corrupted by a
        # mis-decoded UTF-8 byte sequence in the label string.
        label="⬇️ Download Entities as CSV",
        data=csv_buffer,
        file_name="ner_entities_report.csv",
        mime="text/csv",
        use_container_width=True,
    )
# --- NEW LOGIC: Wrap the simple text input into proper HTML for the report ---
# Fix: the original assignment was a broken f-string split across three
# lines (its HTML tags had been stripped, leaving an unterminated literal).
# Reconstructed per the surrounding comments: wrap the user's plain text in
# a styled HTML paragraph element. Exact inline styling is a best guess —
# TODO confirm against generate_html_report's expected markup.
branding_to_pass = (
    '<p style="font-size: 1.1em; font-style: italic; color: #555;">'
    f'{custom_branding_text_input}'
    '</p>'
)
# HTML Download: build the report with the white-label settings applied.
html_content = generate_html_report(
    df,
    st.session_state.last_text,
    st.session_state.elapsed_time,
    df_topic_data,
    entity_color_map,
    # Custom title and pre-wrapped branding HTML from the section above.
    report_title=custom_report_title,
    branding_html=branding_to_pass,
)
# download_button needs bytes for binary-safe delivery.
html_bytes = html_content.encode('utf-8')
with col_html:
    st.download_button(
        # Fix: restored the ⬇️ emoji that had been corrupted by a
        # mis-decoded UTF-8 byte sequence in the label string.
        label="⬇️ Download Full HTML Report",
        data=html_bytes,
        file_name="ner_topic_full_report.html",
        mime="text/html",
        use_container_width=True,
    )