""" BERTopic Topic Modeling Gradio App Upload a text file and visualize topics with an intertopic distance map. Uses Hugging Face sentence-transformers for embeddings. """ import gradio as gr import pandas as pd import numpy as np from bertopic import BERTopic from sentence_transformers import SentenceTransformer from hdbscan import HDBSCAN from umap import UMAP import plotly.graph_objects as go import tempfile import os import warnings warnings.filterwarnings('ignore') class TopicModelingApp: """Topic Modeling Application using BERTopic with Hugging Face embeddings.""" def __init__(self): self.topic_model = None self.topics = None self.probs = None self.embeddings = None self.documents = [] self.embedding_model_name = "all-MiniLM-L6-v2" def load_documents(self, file_path): """Load documents from a text file. Each line is treated as a separate document.""" documents = [] with open(file_path, 'r', encoding='utf-8') as f: content = f.read() # Try to split by different delimiters if '\n\n' in content: # Split by double newlines (paragraphs) documents = [doc.strip() for doc in content.split('\n\n') if doc.strip()] elif '\n' in content: # Split by single newlines (lines) documents = [doc.strip() for doc in content.split('\n') if doc.strip()] else: # Single document documents = [content.strip()] # Filter out very short documents documents = [doc for doc in documents if len(doc.split()) >= 3] return documents def fit_topic_model(self, documents, n_neighbors=15, n_components=5, min_cluster_size=10, min_samples=5): """Fit BERTopic model on the documents.""" if len(documents) < 5: raise ValueError("Need at least 5 documents to perform topic modeling.") # Initialize embedding model from Hugging Face embedding_model = SentenceTransformer(self.embedding_model_name) # Generate embeddings self.embeddings = embedding_model.encode(documents, show_progress_bar=True) # Configure UMAP for dimensionality reduction umap_model = UMAP( n_neighbors=min(n_neighbors, len(documents) - 1), n_components=min(n_components, len(documents) - 1), min_dist=0.0, metric='cosine', random_state=42 ) # Configure HDBSCAN for clustering hdbscan_model = HDBSCAN( min_cluster_size=min(min_cluster_size, max(2, len(documents) // 2)), min_samples=min(min_samples, max(1, len(documents) // 4)), metric='euclidean', cluster_selection_method='eom', prediction_data=True ) # Initialize BERTopic with custom models self.topic_model = BERTopic( embedding_model=embedding_model, umap_model=umap_model, hdbscan_model=hdbscan_model, verbose=True, calculate_probabilities=True ) # Fit the model self.topics, self.probs = self.topic_model.fit_transform(documents, self.embeddings) self.documents = documents return self.topic_model, self.topics, self.probs def get_topic_info(self): """Get information about discovered topics.""" if self.topic_model is None: return None return self.topic_model.get_topic_info() def create_intertopic_distance_map(self): """Create an interactive intertopic distance map visualization.""" if self.topic_model is None: return None # Get topic info topic_info = self.topic_model.get_topic_info() # Get topic embeddings (2D projection for visualization) topic_embeddings = self.topic_model.topic_embeddings_ if topic_embeddings is None or len(topic_embeddings) == 0: return self._create_fallback_visualization(topic_info) # Reduce to 2D for visualization using UMAP from umap import UMAP if topic_embeddings.shape[0] > 1: n_neighbors = min(15, topic_embeddings.shape[0] - 1) reducer = UMAP(n_components=2, n_neighbors=n_neighbors, metric='cosine', random_state=42) topic_coords_2d = reducer.fit_transform(topic_embeddings) else: topic_coords_2d = np.array([[0, 0]]) # Create DataFrame for visualization viz_df = pd.DataFrame({ 'x': topic_coords_2d[:, 0], 'y': topic_coords_2d[:, 1], 'Topic': topic_info['Topic'].values, 'Count': topic_info['Count'].values, 'Name': topic_info['Name'].values }) # Filter out outlier topic (-1) for better visualization viz_df_filtered = viz_df[viz_df['Topic'] != -1].copy() if len(viz_df_filtered) == 0: return self._create_fallback_visualization(topic_info) # Calculate bubble sizes (normalized) max_count = viz_df_filtered['Count'].max() min_count = viz_df_filtered['Count'].min() if max_count > min_count: viz_df_filtered['Size'] = 30 + (viz_df_filtered['Count'] - min_count) / (max_count - min_count) * 70 else: viz_df_filtered['Size'] = 50 # Create the interactive plot fig = go.Figure() # Add scatter plot with custom styling fig.add_trace(go.Scatter( x=viz_df_filtered['x'], y=viz_df_filtered['y'], mode='markers+text', marker=dict( size=viz_df_filtered['Size'], color=viz_df_filtered['Topic'], colorscale='Viridis', showscale=True, colorbar=dict(title='Topic ID'), line=dict(width=2, color='white'), opacity=0.8 ), text=viz_df_filtered['Topic'].astype(str), textposition='middle center', textfont=dict(size=12, color='white', family='Arial Black'), customdata=viz_df_filtered[['Name', 'Count', 'Topic']], hovertemplate=( 'Topic %{customdata[2]}
' 'Keywords: %{customdata[0]}
' 'Document Count: %{customdata[1]}
' '' ), name='Topics' )) # Update layout for better visualization fig.update_layout( title=dict( text='Intertopic Distance Map
Bubble size represents number of documents in each topic', font=dict(size=20, family='Arial'), x=0.5, xanchor='center' ), xaxis=dict( title='Dimension 1', showgrid=True, gridcolor='lightgray', zeroline=True, zerolinecolor='gray' ), yaxis=dict( title='Dimension 2', showgrid=True, gridcolor='lightgray', zeroline=True, zerolinecolor='gray' ), plot_bgcolor='white', paper_bgcolor='white', width=900, height=700, hovermode='closest', showlegend=False ) return fig def _create_fallback_visualization(self, topic_info): """Create a bar chart as fallback when 2D projection is not possible.""" # Filter out outlier topic df = topic_info[topic_info['Topic'] != -1].head(20) fig = go.Figure(data=[ go.Bar( x=df['Topic'].astype(str), y=df['Count'], marker_color=df['Topic'], marker_colorscale='Viridis', text=df['Name'], textposition='outside', hovertemplate=( 'Topic %{x}
' 'Keywords: %{text}
' 'Document Count: %{y}
' '' ) ) ]) fig.update_layout( title=dict( text='Topic Distribution
Number of documents per topic', font=dict(size=20), x=0.5, xanchor='center' ), xaxis_title='Topic ID', yaxis_title='Document Count', plot_bgcolor='white', paper_bgcolor='white', width=900, height=700 ) return fig def get_topic_documents(self, topic_id, n_docs=5): """Get representative documents for a specific topic.""" if self.topic_model is None or topic_id not in self.topics: return [] # Get indices of documents in this topic topic_doc_indices = [i for i, t in enumerate(self.topics) if t == topic_id] # Get representative documents representative_docs = [self.documents[i] for i in topic_doc_indices[:n_docs]] return representative_docs # Initialize the app app = TopicModelingApp() def process_file(file, n_neighbors, n_components, min_cluster_size, min_samples): """Process uploaded file and generate topic model.""" if file is None: return None, None, "Please upload a text file.", None try: # Load documents documents = app.load_documents(file) if len(documents) < 5: return None, None, f"Error: Need at least 5 documents. Found {len(documents)} documents. Please upload a file with more content.", None # Fit the model app.fit_topic_model( documents, n_neighbors=int(n_neighbors), n_components=int(n_components), min_cluster_size=int(min_cluster_size), min_samples=int(min_samples) ) # Get topic info topic_info = app.get_topic_info() # Create visualization fig = app.create_intertopic_distance_map() # Create summary text n_topics = len(topic_info[topic_info['Topic'] != -1]) n_docs = len(documents) n_outliers = topic_info[topic_info['Topic'] == -1]['Count'].values[0] if -1 in topic_info['Topic'].values else 0 summary = f""" ## Topic Modeling Results **Total Documents:** {n_docs} **Topics Discovered:** {n_topics} **Outlier Documents:** {n_outliers} ### Topic Summary Table: """ # Return results return fig, topic_info, summary, topic_info except Exception as e: import traceback error_msg = f"Error during processing: {str(e)}\n\n{traceback.format_exc()}" return None, None, error_msg, None def get_topic_details(topic_id): """Get detailed information about a specific topic.""" if app.topic_model is None: return "Please run topic modeling first." try: topic_id = int(topic_id) # Get topic words topic_words = app.topic_model.get_topic(topic_id) if topic_words is None or len(topic_words) == 0: return f"Topic {topic_id} not found or has no keywords." # Format output output = f"## Topic {topic_id} Details\n\n" output += "### Top Keywords:\n" for word, score in topic_words[:10]: output += f"- **{word}**: {score:.4f}\n" # Get representative documents rep_docs = app.get_topic_documents(topic_id, n_docs=3) if rep_docs: output += "\n### Representative Documents:\n" for i, doc in enumerate(rep_docs, 1): output += f"\n**Document {i}:**\n> {doc[:300]}{'...' if len(doc) > 300 else ''}\n" return output except Exception as e: return f"Error getting topic details: {str(e)}" # Create Gradio interface def create_interface(): """Create the Gradio interface.""" with gr.Blocks( title="BERTopic Topic Modeling", theme=gr.themes.Soft(), css=""" .main-title {text-align: center; margin-bottom: 20px;} .upload-area {min-height: 100px;} .results-area {margin-top: 20px;} """ ) as demo: gr.Markdown( """ # 🎯 BERTopic Topic Modeling App Upload a text file to discover and visualize topics using **BERTopic** with **Hugging Face** embeddings. **Instructions:** 1. Upload a text file (each line or paragraph will be treated as a separate document) 2. Adjust parameters if needed (or use defaults) 3. Click "Run Topic Modeling" to discover topics 4. Explore the intertopic distance map and topic table 5. Enter a topic ID to see detailed information """ ) with gr.Row(): with gr.Column(scale=1): # File upload file_input = gr.File( label="Upload Text File", file_types=[".txt"], type="filepath" ) # Parameters gr.Markdown("### Model Parameters") with gr.Accordion("Advanced Settings", open=False): n_neighbors = gr.Slider( minimum=2, maximum=50, value=15, step=1, label="UMAP n_neighbors", info="Controls local vs global structure preservation" ) n_components = gr.Slider( minimum=2, maximum=20, value=5, step=1, label="UMAP n_components", info="Dimension of the reduced embedding space" ) min_cluster_size = gr.Slider( minimum=2, maximum=50, value=10, step=1, label="HDBSCAN min_cluster_size", info="Minimum cluster size for topic formation" ) min_samples = gr.Slider( minimum=1, maximum=30, value=5, step=1, label="HDBSCAN min_samples", info="Controls cluster density threshold" ) # Run button run_btn = gr.Button("🚀 Run Topic Modeling", variant="primary", size="lg") # Status output status_output = gr.Markdown(label="Status") with gr.Row(): # Visualization with gr.Column(scale=2): viz_output = gr.Plot(label="Intertopic Distance Map") # Topic table with gr.Column(scale=1): topic_table = gr.Dataframe( label="Topic Information", headers=["Topic", "Count", "Name"], wrap=True ) with gr.Row(): with gr.Column(): # Topic details explorer gr.Markdown("### 🔍 Topic Explorer") with gr.Row(): topic_id_input = gr.Number( label="Topic ID", value=0, precision=0, minimum=0 ) get_details_btn = gr.Button("Get Topic Details", variant="secondary") topic_details = gr.Markdown(label="Topic Details") # Example text for demo gr.Markdown( """ --- ### 📝 Example Format Your text file should contain multiple documents, each on a new line or separated by blank lines: ``` Machine learning is a subset of artificial intelligence that enables systems to learn from data. Climate change poses significant risks to global ecosystems and human societies. The stock market showed volatility amid concerns about inflation and interest rates. ... ``` **Tip:** For best results, upload at least 20-50 documents with varied content. """ ) # Event handlers run_btn.click( fn=process_file, inputs=[file_input, n_neighbors, n_components, min_cluster_size, min_samples], outputs=[viz_output, topic_table, status_output, topic_table] ) get_details_btn.click( fn=get_topic_details, inputs=[topic_id_input], outputs=[topic_details] ) return demo # Main entry point if __name__ == "__main__": demo = create_interface() demo.launch( server_name="0.0.0.0", server_port=7860, share=False, show_error=True )