Spaces:
Runtime error
Runtime error
| """ | |
| BERTopic Topic Modeling Gradio App | |
| Upload a text file and visualize topics with an intertopic distance map. | |
| Uses Hugging Face sentence-transformers for embeddings. | |
| """ | |
| import gradio as gr | |
| import pandas as pd | |
| import numpy as np | |
| from bertopic import BERTopic | |
| from sentence_transformers import SentenceTransformer | |
| from hdbscan import HDBSCAN | |
| from umap import UMAP | |
| import plotly.graph_objects as go | |
| import tempfile | |
| import os | |
| import warnings | |
| warnings.filterwarnings('ignore') | |
class TopicModelingApp:
    """Topic Modeling Application using BERTopic with Hugging Face embeddings.

    Keeps the results of the most recent successful fit (``topic_model``,
    ``topics``, ``probs``, ``embeddings``, ``documents``) so later calls can
    build visualizations and look up the documents of a topic.
    """

    # With fewer topic embeddings than this, the 2-D topic map is computed with
    # PCA instead of UMAP: UMAP rejects n_neighbors < 2, and its spectral
    # initialisation needs more samples than n_components + 1.
    _MIN_POINTS_FOR_UMAP = 4

    def __init__(self):
        self.topic_model = None
        self.topics = None
        self.probs = None
        self.embeddings = None
        self.documents = []
        self.embedding_model_name = "all-MiniLM-L6-v2"
        # Loaded SentenceTransformer and the name it was loaded for, so that
        # repeated runs reuse the weights instead of reloading them each time.
        self._embedding_model = None
        self._embedding_model_name_loaded = None

    def load_documents(self, file_path):
        """Load documents from a UTF-8 text file.

        If the file contains at least one blank line (empty or whitespace-only),
        documents are the blank-line separated paragraphs. Otherwise every
        non-empty line is a document, and a file without newlines is a single
        document. Documents with fewer than 3 whitespace-separated words are
        dropped.

        Args:
            file_path: Path to the text file.

        Returns:
            list[str]: The stripped documents, in file order.

        Raises:
            OSError: If the file cannot be opened.
            UnicodeDecodeError: If the file is not valid UTF-8.
        """
        import re

        # 'utf-8-sig' removes a leading byte-order mark, which plain 'utf-8'
        # would leave glued to the first document (str.strip() keeps U+FEFF).
        # Text mode already normalises '\r\n' to '\n'.
        with open(file_path, 'r', encoding='utf-8-sig') as f:
            content = f.read()

        # A visually blank line may still contain spaces/tabs; treat it as a
        # paragraph separator just like a truly empty line.
        paragraph_break = re.compile(r'\n\s*\n')
        if paragraph_break.search(content):
            chunks = paragraph_break.split(content)
        elif '\n' in content:
            chunks = content.split('\n')
        else:
            chunks = [content]

        documents = [chunk.strip() for chunk in chunks if chunk.strip()]
        # Very short snippets carry too little signal for topic modeling.
        return [doc for doc in documents if len(doc.split()) >= 3]

    def _get_embedding_model(self):
        """Return the SentenceTransformer for ``embedding_model_name``, loading it at most once per name."""
        if (self._embedding_model is None
                or self._embedding_model_name_loaded != self.embedding_model_name):
            self._embedding_model = SentenceTransformer(self.embedding_model_name)
            self._embedding_model_name_loaded = self.embedding_model_name
        return self._embedding_model

    def fit_topic_model(self, documents, n_neighbors=15, n_components=5, min_cluster_size=10, min_samples=5):
        """Fit a BERTopic model on ``documents`` and store the results on ``self``.

        UMAP/HDBSCAN hyper-parameters are clamped to values valid for the corpus
        size. Instance state is replaced only after fitting succeeds, so a failed
        run leaves the previous results intact.

        Args:
            documents: List of document strings.
            n_neighbors: UMAP ``n_neighbors`` (capped at ``len(documents) - 1``).
            n_components: UMAP ``n_components`` (capped at ``len(documents) - 1``).
            min_cluster_size: HDBSCAN ``min_cluster_size``
                (capped at ``max(2, len(documents) // 2)``).
            min_samples: HDBSCAN ``min_samples`` (capped at ``max(1, len(documents) // 4)``).

        Returns:
            tuple: ``(topic_model, topics, probs)`` as produced by ``fit_transform``.

        Raises:
            ValueError: If fewer than 5 documents are given.
        """
        n_docs = len(documents)
        if n_docs < 5:
            raise ValueError("Need at least 5 documents to perform topic modeling.")

        embedding_model = self._get_embedding_model()
        embeddings = embedding_model.encode(documents, show_progress_bar=True)

        umap_components = min(n_components, n_docs - 1)
        # Spectral initialisation solves for n_components + 1 eigenvectors and
        # fails (scipy eigsh requires k < N) unless that is below the sample
        # count; use random initialisation for such tiny corpora.
        umap_init = 'spectral' if umap_components + 1 < n_docs else 'random'
        umap_model = UMAP(
            n_neighbors=min(n_neighbors, n_docs - 1),
            n_components=umap_components,
            min_dist=0.0,
            metric='cosine',
            init=umap_init,
            random_state=42
        )

        hdbscan_model = HDBSCAN(
            min_cluster_size=min(min_cluster_size, max(2, n_docs // 2)),
            min_samples=min(min_samples, max(1, n_docs // 4)),
            metric='euclidean',
            cluster_selection_method='eom',
            prediction_data=True
        )

        topic_model = BERTopic(
            embedding_model=embedding_model,
            umap_model=umap_model,
            hdbscan_model=hdbscan_model,
            verbose=True,
            calculate_probabilities=True
        )
        topics, probs = topic_model.fit_transform(documents, embeddings)

        # Commit all state together, and only after a successful fit.
        self.topic_model = topic_model
        self.topics = topics
        self.probs = probs
        self.embeddings = embeddings
        self.documents = documents
        return self.topic_model, self.topics, self.probs

    def get_topic_info(self):
        """Return BERTopic's topic info table, or ``None`` if no model has been fitted."""
        if self.topic_model is None:
            return None
        return self.topic_model.get_topic_info()

    @staticmethod
    def _project_to_2d(points):
        """Project row vectors to 2-D: UMAP when there are enough points, PCA otherwise.

        Args:
            points: Array-like of shape (n_points, dim), n_points >= 1.

        Returns:
            numpy.ndarray: Array of shape (n_points, 2).
        """
        points = np.asarray(points, dtype=float)
        n_points = points.shape[0]
        if n_points >= TopicModelingApp._MIN_POINTS_FOR_UMAP:
            reducer = UMAP(n_components=2, n_neighbors=min(15, n_points - 1), metric='cosine', random_state=42)
            return reducer.fit_transform(points)

        # PCA via SVD of the centred data; deterministic and valid for 1-3 points.
        centered = points - points.mean(axis=0)
        _, _, vt = np.linalg.svd(centered, full_matrices=False)
        coords = centered @ vt[:2].T
        if coords.shape[1] < 2:
            # Fewer than two principal axes available (e.g. a single point).
            coords = np.hstack([coords, np.zeros((n_points, 2 - coords.shape[1]))])
        return coords

    def create_intertopic_distance_map(self):
        """Create an interactive intertopic distance map visualization.

        Each non-outlier topic is drawn as a bubble at a 2-D projection of its
        topic embedding. Bubble size scales with the topic's document count.
        Falls back to a bar chart when no usable topic embeddings exist.

        Returns:
            plotly.graph_objects.Figure | None: ``None`` if no model is fitted.
        """
        if self.topic_model is None:
            return None

        topic_info = self.topic_model.get_topic_info()
        topic_embeddings = self.topic_model.topic_embeddings_
        if topic_embeddings is None or len(topic_embeddings) == 0:
            return self._create_fallback_visualization(topic_info)
        topic_embeddings = np.asarray(topic_embeddings)

        # Rows of topic_embeddings_ are matched to rows of get_topic_info() by
        # position. NOTE(review): assumes both are ordered by topic id (outlier
        # topic -1 first when present) - confirm after BERTopic upgrades.
        if len(topic_embeddings) != len(topic_info):
            return self._create_fallback_visualization(topic_info)

        # Drop the outlier topic (-1) before projecting: it is not plotted, and
        # keeping it could skew the layout of the real topics.
        keep = topic_info['Topic'].to_numpy() != -1
        if not keep.any():
            return self._create_fallback_visualization(topic_info)

        viz_df_filtered = topic_info.loc[keep, ['Topic', 'Count', 'Name']].reset_index(drop=True)
        coords = self._project_to_2d(topic_embeddings[keep])
        viz_df_filtered['x'] = coords[:, 0]
        viz_df_filtered['y'] = coords[:, 1]

        # Bubble sizes: linearly map counts onto [30, 100] px.
        max_count = viz_df_filtered['Count'].max()
        min_count = viz_df_filtered['Count'].min()
        if max_count > min_count:
            viz_df_filtered['Size'] = 30 + (viz_df_filtered['Count'] - min_count) / (max_count - min_count) * 70
        else:
            viz_df_filtered['Size'] = 50

        fig = go.Figure()
        fig.add_trace(go.Scatter(
            x=viz_df_filtered['x'],
            y=viz_df_filtered['y'],
            mode='markers+text',
            marker=dict(
                size=viz_df_filtered['Size'],
                color=viz_df_filtered['Topic'],
                colorscale='Viridis',
                showscale=True,
                colorbar=dict(title='Topic ID'),
                line=dict(width=2, color='white'),
                opacity=0.8
            ),
            text=viz_df_filtered['Topic'].astype(str),
            textposition='middle center',
            textfont=dict(size=12, color='white', family='Arial Black'),
            customdata=viz_df_filtered[['Name', 'Count', 'Topic']],
            hovertemplate=(
                '<b>Topic %{customdata[2]}</b><br>'
                '<b>Keywords:</b> %{customdata[0]}<br>'
                '<b>Document Count:</b> %{customdata[1]}<br>'
                '<extra></extra>'
            ),
            name='Topics'
        ))

        fig.update_layout(
            title=dict(
                text='<b>Intertopic Distance Map</b><br><sup>Bubble size represents number of documents in each topic</sup>',
                font=dict(size=20, family='Arial'),
                x=0.5,
                xanchor='center'
            ),
            xaxis=dict(
                title='Dimension 1',
                showgrid=True,
                gridcolor='lightgray',
                zeroline=True,
                zerolinecolor='gray'
            ),
            yaxis=dict(
                title='Dimension 2',
                showgrid=True,
                gridcolor='lightgray',
                zeroline=True,
                zerolinecolor='gray'
            ),
            plot_bgcolor='white',
            paper_bgcolor='white',
            width=900,
            height=700,
            hovermode='closest',
            showlegend=False
        )
        return fig

    def _create_fallback_visualization(self, topic_info):
        """Create a bar chart of document counts for up to 20 non-outlier topics.

        Used when a 2-D projection of topic embeddings is not possible.
        """
        df = topic_info[topic_info['Topic'] != -1].head(20)
        fig = go.Figure(data=[
            go.Bar(
                x=df['Topic'].astype(str),
                y=df['Count'],
                marker_color=df['Topic'],
                marker_colorscale='Viridis',
                text=df['Name'],
                textposition='outside',
                hovertemplate=(
                    '<b>Topic %{x}</b><br>'
                    '<b>Keywords:</b> %{text}<br>'
                    '<b>Document Count:</b> %{y}<br>'
                    '<extra></extra>'
                )
            )
        ])
        fig.update_layout(
            title=dict(
                text='<b>Topic Distribution</b><br><sup>Number of documents per topic</sup>',
                font=dict(size=20),
                x=0.5,
                xanchor='center'
            ),
            xaxis_title='Topic ID',
            yaxis_title='Document Count',
            plot_bgcolor='white',
            paper_bgcolor='white',
            width=900,
            height=700
        )
        return fig

    def get_topic_documents(self, topic_id, n_docs=5):
        """Return up to ``n_docs`` documents assigned to ``topic_id``, in input order.

        These are simply the first matching documents. They are not BERTopic's
        c-TF-IDF "representative" documents. Returns ``[]`` when no model has
        been fitted or the topic has no documents.
        """
        if self.topic_model is None or self.topics is None:
            return []
        matching = [doc for doc, topic in zip(self.documents, self.topics) if topic == topic_id]
        return matching[:n_docs]
# Single shared TopicModelingApp instance; process_file and get_topic_details
# read and update its fitted state. NOTE(review): every Gradio session shares
# this one object, so concurrent users would overwrite each other's results.
app: TopicModelingApp = TopicModelingApp()
def process_file(file, n_neighbors, n_components, min_cluster_size, min_samples):
    """Process an uploaded file and generate the topic model outputs (Gradio callback).

    Args:
        file: The upload, either a path string (``gr.File(type="filepath")``)
            or a tempfile-like object exposing the path as ``.name``.
        n_neighbors, n_components, min_cluster_size, min_samples: Slider
            values, forwarded to ``app.fit_topic_model`` as ints.

    Returns:
        tuple: ``(figure, topic_table, status_markdown, topic_table)``. The
        4th element repeats the 2nd to match the existing 4-output wiring. On
        failure the figure and tables are ``None`` and the status holds the
        error message.
    """
    if file is None:
        return None, None, "Please upload a text file.", None

    # A path string has no .name attribute, so this yields the path either way.
    file_path = getattr(file, "name", file)

    try:
        documents = app.load_documents(file_path)
        if len(documents) < 5:
            return None, None, f"Error: Need at least 5 documents. Found {len(documents)} documents. Please upload a file with more content.", None

        app.fit_topic_model(
            documents,
            n_neighbors=int(n_neighbors),
            n_components=int(n_components),
            min_cluster_size=int(min_cluster_size),
            min_samples=int(min_samples)
        )

        topic_info = app.get_topic_info()
        fig = app.create_intertopic_distance_map()

        is_outlier = topic_info['Topic'] == -1
        n_topics = int((~is_outlier).sum())
        n_outliers = int(topic_info.loc[is_outlier, 'Count'].sum())

        # Blank lines between fields so Markdown renders them as separate
        # paragraphs instead of running them together on one line.
        summary = (
            "## Topic Modeling Results\n\n"
            f"**Total Documents:** {len(documents)}\n\n"
            f"**Topics Discovered:** {n_topics}\n\n"
            f"**Outlier Documents:** {n_outliers}\n\n"
            "### Topic Summary Table:\n"
        )

        # Show only the columns the table component declares as headers; the
        # remaining columns hold lists (keywords, full document texts).
        display_columns = [col for col in ("Topic", "Count", "Name") if col in topic_info.columns]
        table = topic_info[display_columns]
        return fig, table, summary, table
    except Exception as e:
        import traceback
        error_msg = f"Error during processing: {str(e)}\n\n{traceback.format_exc()}"
        return None, None, error_msg, None
def get_topic_details(topic_id):
    """Render keywords and example documents for one topic as Markdown (Gradio callback).

    Args:
        topic_id: Topic id from the ``gr.Number`` input (a number, or ``None``
            when the field is empty).

    Returns:
        str: Markdown text, or a human-readable message when no model has been
        fitted, the id is missing or invalid, or the topic does not exist.
    """
    if app.topic_model is None:
        return "Please run topic modeling first."
    if topic_id is None:
        return "Please enter a topic ID."
    try:
        topic_id = int(topic_id)
        # BERTopic.get_topic returns False (not None or []) for an unknown
        # topic id, so test truthiness instead of calling len() on it.
        topic_words = app.topic_model.get_topic(topic_id)
        if not topic_words:
            return f"Topic {topic_id} not found or has no keywords."

        output = f"## Topic {topic_id} Details\n\n"
        output += "### Top Keywords:\n"
        for word, score in topic_words[:10]:
            output += f"- **{word}**: {score:.4f}\n"

        rep_docs = app.get_topic_documents(topic_id, n_docs=3)
        if rep_docs:
            output += "\n### Representative Documents:\n"
            for i, doc in enumerate(rep_docs, 1):
                snippet = f"{doc[:300]}{'...' if len(doc) > 300 else ''}"
                # Prefix every line so multi-line (paragraph) documents stay
                # inside the Markdown blockquote.
                quoted = snippet.replace("\n", "\n> ")
                output += f"\n**Document {i}:**\n> {quoted}\n"
        return output
    except Exception as e:
        return f"Error getting topic details: {str(e)}"
# Create Gradio interface
def create_interface():
    """Create the Gradio interface.

    Builds (but does not launch) a ``gr.Blocks`` app with:
    - a file upload, parameter sliders, a run button and a status area;
    - a plot and a dataframe, both filled by ``process_file``;
    - a topic-id input whose button calls ``get_topic_details``.

    Returns:
        The assembled ``gr.Blocks`` object.
    """
    # NOTE(review): newer Gradio releases expect `theme`/`css` to be passed to
    # launch() rather than gr.Blocks(); confirm against the pinned Gradio version.
    # NOTE(review): the CSS classes below are not attached to any component
    # (no elem_classes is set in this function), so they currently have no effect.
    with gr.Blocks(
        title="BERTopic Topic Modeling",
        theme=gr.themes.Soft(),
        css="""
.main-title {text-align: center; margin-bottom: 20px;}
.upload-area {min-height: 100px;}
.results-area {margin-top: 20px;}
"""
    ) as demo:
        # NOTE(review): the emoji here and in later labels appear mis-encoded
        # ("π…"), likely UTF-8 mojibake; check the source file's encoding.
        gr.Markdown(
            """
# π― BERTopic Topic Modeling App
Upload a text file to discover and visualize topics using **BERTopic** with **Hugging Face** embeddings.
**Instructions:**
1. Upload a text file (each line or paragraph will be treated as a separate document)
2. Adjust parameters if needed (or use defaults)
3. Click "Run Topic Modeling" to discover topics
4. Explore the intertopic distance map and topic table
5. Enter a topic ID to see detailed information
"""
        )
        with gr.Row():
            with gr.Column(scale=1):
                # File upload; type="filepath" hands the callback a path.
                file_input = gr.File(
                    label="Upload Text File",
                    file_types=[".txt"],
                    type="filepath"
                )
                # Parameters forwarded to app.fit_topic_model via process_file.
                gr.Markdown("### Model Parameters")
                with gr.Accordion("Advanced Settings", open=False):
                    n_neighbors = gr.Slider(
                        minimum=2, maximum=50, value=15, step=1,
                        label="UMAP n_neighbors",
                        info="Controls local vs global structure preservation"
                    )
                    n_components = gr.Slider(
                        minimum=2, maximum=20, value=5, step=1,
                        label="UMAP n_components",
                        info="Dimension of the reduced embedding space"
                    )
                    min_cluster_size = gr.Slider(
                        minimum=2, maximum=50, value=10, step=1,
                        label="HDBSCAN min_cluster_size",
                        info="Minimum cluster size for topic formation"
                    )
                    min_samples = gr.Slider(
                        minimum=1, maximum=30, value=5, step=1,
                        label="HDBSCAN min_samples",
                        info="Controls cluster density threshold"
                    )
                # Run button
                run_btn = gr.Button("π Run Topic Modeling", variant="primary", size="lg")
                # Status output (summary or error text from process_file)
                status_output = gr.Markdown(label="Status")
        with gr.Row():
            # Visualization
            with gr.Column(scale=2):
                viz_output = gr.Plot(label="Intertopic Distance Map")
            # Topic table
            with gr.Column(scale=1):
                topic_table = gr.Dataframe(
                    label="Topic Information",
                    headers=["Topic", "Count", "Name"],
                    wrap=True
                )
        with gr.Row():
            with gr.Column():
                # Topic details explorer
                gr.Markdown("### π Topic Explorer")
                with gr.Row():
                    topic_id_input = gr.Number(
                        label="Topic ID",
                        value=0,
                        precision=0,
                        # NOTE(review): minimum=0 means the outlier topic (-1)
                        # cannot be inspected from the UI.
                        minimum=0
                    )
                    get_details_btn = gr.Button("Get Topic Details", variant="secondary")
                topic_details = gr.Markdown(label="Topic Details")
        # Example text for demo
        gr.Markdown(
            """
---
### π Example Format
Your text file should contain multiple documents, each on a new line or separated by blank lines:
```
Machine learning is a subset of artificial intelligence that enables systems to learn from data.
Climate change poses significant risks to global ecosystems and human societies.
The stock market showed volatility amid concerns about inflation and interest rates.
...
```
**Tip:** For best results, upload at least 20-50 documents with varied content.
"""
        )
        # Event handlers (must be registered inside the Blocks context).
        # NOTE(review): topic_table is listed twice; process_file returns the
        # same table as its 2nd and 4th values, so the 4th output is redundant.
        run_btn.click(
            fn=process_file,
            inputs=[file_input, n_neighbors, n_components, min_cluster_size, min_samples],
            outputs=[viz_output, topic_table, status_output, topic_table]
        )
        get_details_btn.click(
            fn=get_topic_details,
            inputs=[topic_id_input],
            outputs=[topic_details]
        )
    return demo
# Script entry point: build the Gradio app and serve it on port 7860,
# bound to all network interfaces, with errors surfaced in the UI.
if __name__ == "__main__":
    demo = create_interface()
    demo.launch(
        show_error=True,
        share=False,
        server_port=7860,
        server_name="0.0.0.0",
    )