Spaces:
Running
Running
| import os | |
| import time | |
| import streamlit as st | |
| import pandas as pd | |
| import numpy as np | |
| import re | |
| import string | |
| import json | |
| from io import BytesIO | |
| # --- Visualization & PPTX --- | |
| import plotly.express as px | |
| import plotly.graph_objects as go | |
| import plotly.io as pio | |
| from pptx import Presentation | |
| from pptx.util import Inches, Pt | |
| # --- NLP & Analysis --- | |
| from gliner import GLiNER | |
| from sklearn.feature_extraction.text import TfidfVectorizer | |
| from sklearn.decomposition import LatentDirichletAllocation | |
| # --- 1. CONFIGURATION & STYLING --- | |
# Point HF_HOME at /tmp (presumably so model downloads land in a writable
# directory — confirm against the deployment environment).
# NOTE(review): this runs after `gliner` has already been imported; if the
# Hugging Face libraries read HF_HOME at import time the setting may come too
# late — verify, and move it above the imports if so.
os.environ['HF_HOME'] = '/tmp'

# Highlight colour (hex) for every entity label requested from the model.
entity_color_map = dict(
    person="#10b981",
    country="#3b82f6",
    city="#4ade80",
    organization="#f59e0b",
    date="#8b5cf6",
    time="#ec4899",
    cardinal="#06b6d4",
    money="#f43f5e",
    position="#a855f7",
)

# Label names in colour-map order.
labels = [*entity_color_map]

# Coarse display category -> the entity labels that belong to it.
category_mapping = {
    "People": ["person", "organization", "position"],
    "Locations": ["country", "city"],
    "Time": ["date", "time"],
    "Numbers": ["money", "cardinal"],
}

# Inverse lookup: entity label -> display category.
reverse_category_mapping = {
    label: category
    for category in category_mapping
    for label in category_mapping[category]
}
| # --- 2. CORE UTILITY FUNCTIONS --- | |
def remove_trailing_punctuation(text_string):
    """Return *text_string* with any run of trailing ASCII punctuation removed.

    Only characters listed in ``string.punctuation`` are stripped, and only
    from the right-hand end; whitespace and interior punctuation are kept.
    """
    strippable = string.punctuation
    cleaned = text_string.rstrip(strippable)
    return cleaned
def highlight_entities(text, df_entities):
    """Render *text* as an HTML fragment with each entity in a coloured span.

    Args:
        text: The raw text that the entity character offsets refer to.
        df_entities: DataFrame with ``start``, ``end``, ``label`` and ``text``
            columns, one row per entity.

    Returns:
        ``text`` unchanged when ``df_entities`` is empty; otherwise an HTML
        string consisting of a ``<div class="highlighted-text">`` wrapper
        around the escaped text with one ``<span>`` per highlighted entity.
    """
    import html  # local import so the block does not rely on a file-level import

    if df_entities.empty:
        return text

    # Walk the spans left to right; for spans sharing a start, the longest
    # one comes first so it is the one that gets highlighted.
    ordered = df_entities.sort_values(
        by=['start', 'end'], ascending=[True, False]
    ).to_dict('records')

    pieces = []
    cursor = 0
    for entity in ordered:
        start, end = int(entity['start']), int(entity['end'])
        # Skip spans that overlap one already emitted (or are malformed):
        # splicing overlapping ranges would corrupt the generated markup.
        if start < cursor or end < start:
            continue

        shown = str(entity['text'])
        original_span = text[start:end]
        # The entity text may be shorter than its span (e.g. trailing
        # punctuation was stripped); keep the leftover characters in the
        # output, outside the highlight, instead of silently dropping them.
        remainder = original_span[len(shown):] if original_span.startswith(shown) else ''

        color = entity_color_map.get(entity['label'], '#000000')
        # Everything that originates from the input is escaped, because the
        # fragment is meant to be rendered as raw HTML.
        pieces.append(html.escape(text[cursor:start]))
        pieces.append(
            f'<span style="background-color: {color}; color: white; padding: 2px 4px; '
            f'border-radius: 3px; font-weight: bold;">{html.escape(shown)}</span>'
        )
        pieces.append(html.escape(remainder))
        cursor = end

    pieces.append(html.escape(text[cursor:]))
    highlighted_text = ''.join(pieces)
    return f'<div class="highlighted-text" style="border: 1px solid #ddd; padding: 15px; border-radius: 8px; background-color: #ffffff; line-height: 2; white-space: pre-wrap;">{highlighted_text}</div>'
def perform_topic_modeling(df_entities, num_topics=2, num_top_words=10):
    """Fit an LDA topic model over the distinct entity strings.

    Each unique value of ``df_entities['text']`` is treated as one document,
    vectorised with TF-IDF (1- to 3-grams, English stop words removed) and
    passed to ``LatentDirichletAllocation``.

    Args:
        df_entities: DataFrame with a ``text`` column.
        num_topics: Number of LDA components to fit.
        num_top_words: How many highest-weighted terms to report per topic.

    Returns:
        A DataFrame with ``Topic_ID``, ``Word`` and ``Weight`` columns, or
        ``None`` when there are fewer than two distinct documents or when
        vectorisation / fitting fails.
    """
    import logging  # local import so the block does not rely on a file-level import

    documents = df_entities['text'].unique().tolist()
    if len(documents) < 2:
        return None

    try:
        tfidf_vectorizer = TfidfVectorizer(stop_words='english', ngram_range=(1, 3), min_df=1)
        tfidf = tfidf_vectorizer.fit_transform(documents)
        feature_names = tfidf_vectorizer.get_feature_names_out()
        lda = LatentDirichletAllocation(n_components=num_topics, random_state=42)
        lda.fit(tfidf)
    except Exception:
        # Topic modelling is best-effort: callers treat None as "no topics".
        # Unlike a bare ``except``, this lets KeyboardInterrupt/SystemExit
        # propagate and records why the model could not be fitted.
        logging.getLogger(__name__).warning("Topic modeling failed", exc_info=True)
        return None

    topic_data = [
        {'Topic_ID': f'Topic #{idx + 1}', 'Word': feature_names[i], 'Weight': topic[i]}
        for idx, topic in enumerate(lda.components_)
        # argsort is ascending; the reversed slice yields the top terms first.
        for i in topic.argsort()[:-num_top_words - 1:-1]
    ]
    return pd.DataFrame(topic_data)
| # --- 3. VISUALIZATION FUNCTIONS (FIXED TITLES) --- | |
def create_topic_word_bubbles(df_topic_data):
    """Build a bubble scatter plot of topic words.

    ``df_topic_data`` is expected to carry ``Topic_ID``, ``Word`` and
    ``Weight`` columns. Each row becomes one bubble: placed by row order on
    the x axis, positioned and sized by weight, coloured by topic and
    labelled with the word. The input frame is not modified.
    """
    column_aliases = {'Topic_ID': 'topic', 'Word': 'word', 'Weight': 'weight'}
    bubbles = df_topic_data.rename(columns=column_aliases)
    # Spread the bubbles horizontally in row order.
    bubbles['x_pos'] = range(len(bubbles))

    figure = px.scatter(
        bubbles,
        x='x_pos',
        y='weight',
        size='weight',
        color='topic',
        text='word',
        title='Topic Word Weights',
    )
    # Extra top margin leaves room for the title.
    figure.update_layout(
        margin={'t': 80, 'b': 50},
        xaxis_showticklabels=False,
        plot_bgcolor='#f9f9f9',
    )
    figure.update_traces(
        textposition='middle center',
        textfont={'color': 'white', 'size': 10},
    )
    return figure
def generate_network_graph(df, raw_text):
    """Plot each distinct entity as a labelled node on a circle.

    Args:
        df: Entity DataFrame with ``text`` and ``label`` columns.
        raw_text: Unused; kept so existing callers keep working.

    Returns:
        A ``plotly.graph_objects.Figure`` with one marker per distinct entity
        text, sized by how often that text occurs in ``df`` and coloured by
        the label of its first occurrence. No edges are drawn.
    """
    # Count occurrences and attach them with ``map`` rather than
    # ``value_counts().reset_index(name=...)`` + ``merge``: the column names
    # produced by that reset differ between pandas versions, so the merge key
    # is not guaranteed to exist.
    frequency_by_text = df['text'].value_counts()
    unique = (
        df.dropna(subset=['text'])
        .drop_duplicates(subset=['text'])
        .reset_index(drop=True)
    )
    unique['frequency'] = unique['text'].map(frequency_by_text)

    # Place the nodes evenly on a circle of radius 10.
    num_nodes = len(unique)
    thetas = np.linspace(0, 2 * np.pi, num_nodes, endpoint=False)
    unique['x'] = 10 * np.cos(thetas)
    unique['y'] = 10 * np.sin(thetas)

    node_colors = [entity_color_map.get(label, '#ccc') for label in unique['label']]

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=unique['x'], y=unique['y'], mode='markers+text', text=unique['text'],
        marker=dict(size=unique['frequency'] * 5 + 15, color=node_colors)
    ))
    # Extra top margin leaves room for the title.
    fig.update_layout(title="Entity Relationship Map", margin=dict(t=80), showlegend=False, xaxis_visible=False, yaxis_visible=False)
    return fig
| # --- 4. EXPORT FUNCTIONS --- | |
def generate_html_report(df, text_input, elapsed_time, df_topic_data):
    """Assemble the standalone HTML report.

    Args:
        df: Entity DataFrame (needs the ``category``, ``label``, ``text`` and
            ``score`` columns used by the treemap, plus the columns read by
            ``highlight_entities`` and ``generate_network_graph``).
        text_input: The analysed text, rendered with highlighted entities.
        elapsed_time: Processing time in seconds.
        df_topic_data: Topic DataFrame as returned by
            ``perform_topic_modeling``, or ``None``; when it holds rows the
            topic bubble chart is appended to the report.

    Returns:
        The complete HTML document as a string.
    """
    fig_tree = px.treemap(df, path=[px.Constant("All"), 'category', 'label', 'text'], values='score', title="Entity Hierarchy")
    fig_tree.update_layout(margin=dict(t=60, b=20, l=20, r=20))

    # plotly.js is loaded exactly once, by the first chart; later charts reuse
    # it. Loading it again per chart (or additionally via a separate
    # "plotly-latest" script tag) downloads the library repeatedly and can mix
    # library versions on one page.
    tree_html = fig_tree.to_html(full_html=False, include_plotlyjs='cdn')
    net_html = generate_network_graph(df, text_input).to_html(full_html=False, include_plotlyjs=False)

    topics_html = ''
    if df_topic_data is not None and not df_topic_data.empty:
        topic_chart = create_topic_word_bubbles(df_topic_data).to_html(full_html=False, include_plotlyjs=False)
        topics_html = f'<div class="chart-box">{topic_chart}</div>'

    # The charset declaration keeps non-ASCII input readable when the saved
    # file is opened directly in a browser.
    html_template = f"""<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<style>
body {{ font-family: sans-serif; background: #f4f7f6; padding: 30px; }}
.card {{ background: white; padding: 25px; border-radius: 12px; margin-bottom: 25px; box-shadow: 0 2px 10px rgba(0,0,0,0.05); }}
/* Critical for title visibility */
.chart-box {{ min-height: 500px; overflow: visible !important; border: 1px solid #eee; }}
h1, h2 {{ color: #2c3e50; border-bottom: 2px solid #3498db; padding-bottom: 10px; }}
</style>
</head>
<body>
<div class="card">
<h1>NER & Topic Analysis Report</h1>
<p>Processing Time: {elapsed_time:.2f}s</p>
<h2>1. Highlighted Entities</h2>
{highlight_entities(text_input, df)}
<h2>2. Visual Analytics</h2>
<div class="chart-box">{tree_html}</div>
<div class="chart-box">{net_html}</div>
{topics_html}
</div>
</body>
</html>
"""
    return html_template
def generate_pptx_report(df):
    """Build a two-slide PowerPoint summary of the entities.

    Slide 1 is a title slide; slide 2 lists up to the first 15 rows of ``df``
    as ``"<text> (<label>)"`` bullets.

    Args:
        df: Entity DataFrame with ``text`` and ``label`` columns.

    Returns:
        A ``BytesIO`` holding the saved presentation, rewound to position 0.
    """
    prs = Presentation()

    title_slide = prs.slides.add_slide(prs.slide_layouts[0])
    title_slide.shapes.title.text = "Entity Analysis"

    list_slide = prs.slides.add_slide(prs.slide_layouts[1])
    list_slide.shapes.title.text = "Entity List"
    tf = list_slide.placeholders[1].text_frame

    for position, row in enumerate(df.head(15).to_dict('records')):
        # The text frame starts with one empty paragraph; fill it with the
        # first entity so the list does not begin with a blank bullet.
        paragraph = tf.paragraphs[0] if position == 0 else tf.add_paragraph()
        paragraph.text = f"{row['text']} ({row['label']})"

    buffer = BytesIO()
    prs.save(buffer)
    buffer.seek(0)  # rewind so the caller can read from the start
    return buffer
| # --- 5. STREAMLIT UI & LOGIC --- | |
st.set_page_config(layout="wide", page_title="DataHarvest NER")


@st.cache_resource
def load_model():
    """Load the GLiNER model used for entity extraction.

    Cached with ``st.cache_resource`` so the model is created once and reused
    across Streamlit reruns instead of being re-instantiated on every widget
    interaction.
    """
    return GLiNER.from_pretrained("knowledgator/gliner-multitask-large-v0.5", nested_ner=True)


model = load_model()
# Session state init: give every key read by the results section a default so
# a rerun can never hit a missing attribute.
for state_key, state_default in (
    ('results_df', pd.DataFrame()),
    ('show', False),
    ('analyzed_text', ''),
    ('elapsed', 0.0),
    ('topics', None),
):
    if state_key not in st.session_state:
        st.session_state[state_key] = state_default

st.subheader("Entity & Topic Analysis Report Generator", divider="blue")
text = st.text_area("Paste text here (max 1000 words):", height=250)

if st.button("Run Analysis"):
    if text:
        with st.spinner("Processing..."):
            start = time.time()
            entities = model.predict_entities(text, labels)
            df = pd.DataFrame(entities)
            if not df.empty:
                df['text'] = df['text'].apply(remove_trailing_punctuation)
                df['category'] = df['label'].map(reverse_category_mapping)
                st.session_state.results_df = df
                # Remember exactly which text was analysed: the entity offsets
                # are only valid for it, not for later edits to the text box.
                st.session_state.analyzed_text = text
                st.session_state.elapsed = time.time() - start
                st.session_state.topics = perform_topic_modeling(df)
                st.session_state.show = True
            else:
                # Hide results from a previous run so they are not shown as if
                # they belonged to the current input.
                st.session_state.show = False
                st.warning("No entities found.")

if st.session_state.show:
    df = st.session_state.results_df
    analyzed_text = st.session_state.analyzed_text

    st.markdown("### 1. Extracted Entities")
    st.markdown(highlight_entities(analyzed_text, df), unsafe_allow_html=True)

    t1, t2, t3 = st.tabs(["Charts", "Network Map", "Topics"])
    with t1:
        fig_tree = px.treemap(df, path=['category', 'label', 'text'], values='score', title="Entity Treemap")
        # Top margin keeps the chart title visible in the preview.
        fig_tree.update_layout(margin=dict(t=50))
        st.plotly_chart(fig_tree, use_container_width=True)
    with t2:
        st.plotly_chart(generate_network_graph(df, analyzed_text), use_container_width=True)
    with t3:
        if st.session_state.topics is not None:
            st.plotly_chart(create_topic_word_bubbles(st.session_state.topics), use_container_width=True)
        else:
            st.info("Not enough data for topic modeling.")

    st.divider()
    st.markdown("### Download Artifacts")
    c1, c2, c3 = st.columns(3)
    with c1:
        st.download_button("Download HTML Report",
                           generate_html_report(df, analyzed_text, st.session_state.elapsed, st.session_state.topics),
                           "report.html", "text/html", type="primary")
    with c2:
        csv = df.to_csv(index=False).encode('utf-8')
        st.download_button("Download CSV Data", csv, "entities.csv", "text/csv")
    with c3:
        st.download_button("Download PPTX Summary", generate_pptx_report(df), "summary.pptx")