Spaces:
Sleeping
Sleeping
| """ | |
| app.py — Streamlit frontend for the AI-driven Topic Modeling application. | |
| This module provides an interactive web interface that allows users to: | |
| 1. Upload a CSV file containing research paper Titles and Abstracts. | |
| 2. Configure pipeline parameters (min topics, LLM label generation). | |
| 3. Run the TopicAgent pipeline with a single click. | |
| 4. View and explore results: topics table, comparison, taxonomy map. | |
| 5. Review topics with an editable review table. | |
| 6. Visualize topic distributions with interactive Plotly charts. | |
| 7. Download all generated outputs (CSV, JSON). | |
| """ | |
| import os | |
| import json | |
| import tempfile | |
| import streamlit as st | |
| import pandas as pd | |
| import plotly.express as px | |
| import plotly.graph_objects as go | |
| from agent import TopicAgent | |
# ---------------------------------------------------------------------------
# Output directory selection.
# The relative "outputs" folder is tried first. When it cannot be created or
# written to (for example a read-only working directory on HuggingFace
# Spaces), a folder under the system temp directory is used instead.
# ---------------------------------------------------------------------------
def _can_write_to(directory):
    """Return True when *directory* can be created and a probe file written in it."""
    try:
        os.makedirs(directory, exist_ok=True)
        probe_path = os.path.join(directory, ".write_test")
        with open(probe_path, "w") as probe:
            probe.write("ok")
        os.remove(probe_path)
    except OSError:  # PermissionError is a subclass of OSError
        return False
    return True


OUTPUT_DIR = "outputs"
if not _can_write_to(OUTPUT_DIR):
    OUTPUT_DIR = os.path.join(tempfile.gettempdir(), "topic_modeler_outputs")
    os.makedirs(OUTPUT_DIR, exist_ok=True)
# ---------------------------------------------------------------------------
# Page configuration: browser-tab title and icon, wide layout, sidebar open.
# NOTE(review): the non-ASCII characters in these literals look mis-encoded in
# this copy of the file -- confirm against the original source.
# ---------------------------------------------------------------------------
_PAGE_SETTINGS = {
    "page_title": "Research Topic Modeler β AI Agent",
    "page_icon": "π¬",
    "layout": "wide",
    "initial_sidebar_state": "expanded",
}
st.set_page_config(**_PAGE_SETTINGS)
# ---------------------------------------------------------------------------
# Custom stylesheet for the whole page (text colors chosen to stay readable
# on dark backgrounds). It is kept in a named constant and injected once.
# ---------------------------------------------------------------------------
_CUSTOM_CSS = """
<style>
/* Import Google Font */
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap');
/* Global */
html, body, [class*="css"] {
font-family: 'Inter', sans-serif;
}
/* Header gradient banner */
.main-header {
background: linear-gradient(135deg, #0f0c29 0%, #302b63 50%, #24243e 100%);
padding: 2rem 2.5rem;
border-radius: 16px;
margin-bottom: 1.5rem;
box-shadow: 0 8px 32px rgba(48, 43, 99, 0.3);
}
.main-header h1 {
color: #ffffff;
font-size: 2.2rem;
font-weight: 700;
margin: 0;
letter-spacing: -0.5px;
}
.main-header p {
color: #b8b5ff;
font-size: 1.05rem;
margin: 0.5rem 0 0 0;
font-weight: 300;
}
/* Stat cards */
.stat-card {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
padding: 1.25rem 1.5rem;
border-radius: 12px;
color: white;
text-align: center;
box-shadow: 0 4px 15px rgba(102, 126, 234, 0.3);
transition: transform 0.2s ease;
}
.stat-card:hover {
transform: translateY(-2px);
}
.stat-card .stat-value {
font-size: 2rem;
font-weight: 700;
line-height: 1.2;
color: #ffffff;
}
.stat-card .stat-label {
font-size: 0.85rem;
opacity: 0.85;
margin-top: 0.3rem;
font-weight: 400;
color: #e8e6ff;
}
/* Status badge */
.status-badge {
display: inline-block;
padding: 0.3rem 1rem;
border-radius: 20px;
font-size: 0.8rem;
font-weight: 600;
text-transform: uppercase;
letter-spacing: 0.5px;
}
.status-success {
background: linear-gradient(135deg, #11998e, #38ef7d);
color: #ffffff;
}
.status-failed {
background: linear-gradient(135deg, #eb3349, #f45c43);
color: #ffffff;
}
.status-running {
background: linear-gradient(135deg, #f7971e, #ffd200);
color: #1a1a2e;
}
/* Section headers β always readable on both light and dark backgrounds */
.section-header {
font-size: 1.3rem;
font-weight: 600;
color: #c4b5fd;
margin: 1.5rem 0 0.75rem 0;
padding-bottom: 0.5rem;
border-bottom: 2px solid #667eea;
display: inline-block;
}
/* Taxonomy badges */
.mapped-badge {
display: inline-block;
background: linear-gradient(135deg, #11998e, #38ef7d);
color: #ffffff;
padding: 0.2rem 0.7rem;
border-radius: 12px;
font-size: 0.75rem;
font-weight: 600;
}
.novel-badge {
display: inline-block;
background: linear-gradient(135deg, #fc4a1a, #f7b733);
color: #ffffff;
padding: 0.2rem 0.7rem;
border-radius: 12px;
font-size: 0.75rem;
font-weight: 600;
}
/* Sidebar styling */
section[data-testid="stSidebar"] {
background: linear-gradient(180deg, #1a1a2e 0%, #16213e 100%);
}
section[data-testid="stSidebar"] .stMarkdown {
color: #e0e0e0;
}
section[data-testid="stSidebar"] label {
color: #e0e0e0 !important;
}
section[data-testid="stSidebar"] .stSlider label {
color: #e0e0e0 !important;
}
/* Data table enhancements */
.stDataFrame {
border-radius: 8px;
overflow: hidden;
}
/* Info box β dark-safe: dark background with light text */
.info-box {
background: linear-gradient(135deg, #1e1e3f 0%, #2d2b55 100%);
padding: 1rem 1.5rem;
border-radius: 10px;
border-left: 4px solid #667eea;
margin: 0.75rem 0;
color: #e0e0e0;
}
.info-box strong {
color: #ffffff;
}
.info-box code {
background: rgba(102, 126, 234, 0.2);
color: #b8b5ff;
padding: 0.1rem 0.4rem;
border-radius: 4px;
}
/* Pipeline step */
.step-item {
padding: 0.5rem 1rem;
margin: 0.3rem 0;
border-radius: 8px;
background: rgba(102, 126, 234, 0.15);
border-left: 3px solid #667eea;
font-size: 0.9rem;
color: #e0e0e0;
}
/* Chart container styling */
.chart-container {
background: linear-gradient(135deg, #1a1a2e 0%, #16213e 100%);
border-radius: 12px;
padding: 1rem;
margin: 0.5rem 0;
box-shadow: 0 4px 15px rgba(0, 0, 0, 0.2);
}
/* Review section header */
.review-header {
background: linear-gradient(135deg, #11998e 0%, #38ef7d 100%);
padding: 1rem 1.5rem;
border-radius: 12px;
margin-bottom: 1rem;
box-shadow: 0 4px 15px rgba(17, 153, 142, 0.3);
}
.review-header h3 {
color: #ffffff;
margin: 0;
font-weight: 600;
}
.review-header p {
color: #e0fff8;
margin: 0.3rem 0 0 0;
font-size: 0.9rem;
}
/* Save confirmation */
.save-confirm {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: #ffffff;
padding: 0.75rem 1.25rem;
border-radius: 10px;
margin-top: 0.5rem;
font-weight: 500;
}
/* Ensure tab labels are readable */
.stTabs [data-baseweb="tab-list"] button {
color: #c4b5fd;
}
.stTabs [data-baseweb="tab-list"] button[aria-selected="true"] {
color: #ffffff;
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# ---------------------------------------------------------------------------
# Page header banner (styled by the .main-header CSS rules).
# ---------------------------------------------------------------------------
_HEADER_HTML = """
<div class="main-header">
<h1>π¬ Research Topic Modeler</h1>
<p>AI-powered topic modeling agent for research papers β discover, compare, and classify themes across Titles and Abstracts</p>
</div>
"""
st.markdown(_HEADER_HTML, unsafe_allow_html=True)
# ---------------------------------------------------------------------------
# Sidebar: dataset selection, pipeline parameters, and a static step list.
# Sets uploaded_file, use_default, min_topics, use_llm and groq_key, which
# the rest of the script reads.
# ---------------------------------------------------------------------------
with st.sidebar:
    st.markdown("## βοΈ Configuration")
    st.markdown("---")

    # --- Dataset source -----------------------------------------------------
    st.markdown("### π Dataset")
    uploaded_file = st.file_uploader(
        "Upload CSV with Title & Abstract columns",
        type=["csv"],
        help="The CSV must contain at least 'Title' and 'Abstract' columns.",
    )
    # The bundled dataset is pre-selected only while nothing has been uploaded.
    use_default = st.checkbox(
        "Use default dataset (dataset.csv)",
        value=not uploaded_file,
        help="Use the bundled dataset.csv file in the project directory.",
    )

    # --- Pipeline parameters ------------------------------------------------
    st.markdown("---")
    st.markdown("### π― Parameters")
    min_topics = st.slider(
        "Minimum Topics",
        min_value=50,
        max_value=200,
        value=100,
        step=10,
        help="Minimum number of topics to generate per source (Titles / Abstracts).",
    )
    use_llm = st.checkbox(
        "π€ Use LLM for Label Generation (Groq)",
        value=False,
        help="Use Groq's LLaMA model to generate contextual topic labels. "
             "Falls back to keyword heuristic if unchecked.",
    )
    # The environment variable pre-fills the key; the input box appears only
    # when LLM labelling is switched on.
    groq_key = os.environ.get("GROQ_API_KEY", "")
    if use_llm:
        groq_key = st.text_input(
            "Groq API Key",
            value=groq_key,
            type="password",
            help="Your Groq API key for LLM label generation.",
        )

    # --- Static overview of the pipeline steps ------------------------------
    st.markdown("---")
    st.markdown("### π Pipeline Steps")
    steps_info = [
        "1. Load & validate CSV",
        "2. Preprocess text (Titles + Abstracts)",
        f"3. Topic modeling β Titles (β₯{min_topics} topics)",
        f"4. Topic modeling β Abstracts (β₯{min_topics} topics)",
        "5. Generate human-readable labels",
        "6. Combine topics table",
        "7. Compare themes (Title vs Abstract)",
        "8. Build taxonomy map (MAPPED / NOVEL)",
        "9. Export outputs (CSV, JSON)",
    ]
    for step in steps_info:
        step_html = f'<div class="step-item">{step}</div>'
        st.markdown(step_html, unsafe_allow_html=True)
# ---------------------------------------------------------------------------
# Main area: run button (left) and a badge for the last pipeline result (right).
# NOTE(review): indentation is not visible in this copy of the file; the final
# `else` is assumed to pair with the "no result yet" test -- confirm.
# ---------------------------------------------------------------------------
_STATUS_BADGES = {
    "success": '<span class="status-badge status-success">β Pipeline Complete</span>',
    "failed": '<span class="status-badge status-failed">β Pipeline Failed</span>',
}

col_run, col_status = st.columns([2, 3])

with col_run:
    run_clicked = st.button("π Run Topic Modeling Agent", use_container_width=True, type="primary")

with col_status:
    if "result" in st.session_state and st.session_state.result is not None:
        res = st.session_state.result
        # Any status other than "success"/"failed" shows no badge at all.
        badge_html = _STATUS_BADGES.get(res.status)
        if badge_html is not None:
            st.markdown(badge_html, unsafe_allow_html=True)
    else:
        st.markdown('<span class="status-badge status-running">β Awaiting Input</span>', unsafe_allow_html=True)
# ---------------------------------------------------------------------------
# Execute pipeline.
# Runs only on the script pass in which the run button was clicked. It picks
# the input CSV, drives the TopicAgent steps one by one while updating a
# progress bar, stores the agent's result in the session state, and triggers
# a rerun so the result sections render.
# ---------------------------------------------------------------------------
if run_clicked:
    # Decide which CSV the agent should read.
    csv_path = None
    if uploaded_file is not None:
        # Persist the upload so the agent can open it by path. The system temp
        # directory is used rather than the working directory, because the
        # working directory can be read-only (e.g. on HuggingFace Spaces).
        with tempfile.NamedTemporaryFile(delete=False, suffix=".csv") as tmp:
            tmp.write(uploaded_file.getvalue())
            csv_path = tmp.name
    elif use_default:
        csv_path = "dataset.csv"
        if not os.path.exists(csv_path):
            st.error("β Default dataset.csv not found in the project directory.")
            st.stop()
    else:
        st.error("β Please upload a CSV file or select the default dataset.")
        st.stop()

    try:
        with st.spinner("π Running the Topic Modeling Agent β¦ this may take a few minutes."):
            progress = st.progress(0, text="Initializing β¦")
            agent = TopicAgent(
                csv_path=csv_path,
                output_dir=OUTPUT_DIR,
                min_topics=min_topics,
                use_llm_labels=use_llm,
                groq_api_key=groq_key if use_llm else None,
            )
            # (progress percent, progress text, agent step) in execution order.
            pipeline_steps = [
                (5, "Step 1/9: Loading CSV β¦", agent._step_load_csv),
                (10, "Step 2/9: Preprocessing text β¦", agent._step_preprocess),
                (20, "Step 3/9: Topic modeling on Titles β¦", agent._step_model_titles),
                (45, "Step 4/9: Topic modeling on Abstracts β¦", agent._step_model_abstracts),
                (65, "Step 5/9: Generating topic labels β¦", agent._step_generate_labels),
                (75, "Step 6/9: Building combined topics table β¦", agent._step_combine_topics),
                (80, "Step 7/9: Comparing themes β¦", agent._step_compare_themes),
                (90, "Step 8/9: Building taxonomy map β¦", agent._step_taxonomy_map),
                (95, "Step 9/9: Exporting outputs β¦", agent._step_export),
            ]
            for percent, message, run_step in pipeline_steps:
                progress.progress(percent, text=message)
                run_step()
            agent._result.status = "success"
            progress.progress(100, text="β Pipeline complete!")
        st.session_state.result = agent._result
        # The review table cached in the session belongs to the previous run;
        # drop it so the review tab rebuilds it for the new topics.
        st.session_state.pop("review_df", None)
    finally:
        # Remove the temporary copy of the upload even when a step raised.
        if uploaded_file is not None and csv_path and os.path.exists(csv_path):
            try:
                os.unlink(csv_path)
            except OSError:
                # Best-effort cleanup: a leftover temp file must not mask the
                # pipeline outcome.
                pass
    st.rerun()
# ---------------------------------------------------------------------------
# Shared Plotly theme: dark backgrounds with light, readable text.
# ---------------------------------------------------------------------------
def _axis_theme():
    """Return a fresh dict of grid, zero-line, tick and title colors for one axis."""
    return {
        "gridcolor": "rgba(102, 126, 234, 0.15)",
        "zerolinecolor": "rgba(102, 126, 234, 0.25)",
        "tickfont": {"color": "#c4b5fd"},
        "title_font": {"color": "#e0e0e0"},
    }


# Keyword arguments meant to be splatted into fig.update_layout(...).
PLOTLY_LAYOUT = {
    "paper_bgcolor": "rgba(26, 26, 46, 0.95)",
    "plot_bgcolor": "rgba(22, 33, 62, 0.95)",
    "font": {"family": "Inter, sans-serif", "size": 13, "color": "#e0e0e0"},
    "title_font": {"size": 18, "color": "#ffffff"},
    "legend": {
        "font": {"color": "#e0e0e0"},
        "bgcolor": "rgba(26, 26, 46, 0.7)",
        "bordercolor": "#667eea",
        "borderwidth": 1,
    },
    # Each axis gets its own dict so that editing one never affects the other.
    "xaxis": _axis_theme(),
    "yaxis": _axis_theme(),
    "margin": {"l": 60, "r": 30, "t": 60, "b": 60},
}

# Gradient-like color sequence
CHART_COLORS = [
    "#667eea", "#764ba2", "#f093fb", "#f5576c",
    "#4facfe", "#00f2fe", "#43e97b", "#38f9d7",
    "#fa709a", "#fee140", "#a18cd1", "#fbc2eb",
    "#ff9a9e", "#fad0c4", "#ffecd2", "#fcb69f",
]
| # --------------------------------------------------------------------------- | |
| # Display results | |
| # --------------------------------------------------------------------------- | |
| if "result" in st.session_state and st.session_state.result is not None: | |
| result = st.session_state.result | |
| if result.status == "failed": | |
| st.error(f"Pipeline failed with errors: {result.errors}") | |
| st.stop() | |
# ---- Summary Statistics ----
# Five stat cards in a row: topic counts per source, the combined total, and
# the mapped/novel counts read from the taxonomy map's "metadata" entry.
st.markdown('<div class="section-header">π Summary Statistics</div>', unsafe_allow_html=True)

taxonomy_metadata = result.taxonomy_map.get("metadata", {})
mapped_count = taxonomy_metadata.get("mapped_count", 0)
novel_count = taxonomy_metadata.get("novel_count", 0)

stat_cards = [
    (len(result.title_topics), "Title Topics"),
    (len(result.abstract_topics), "Abstract Topics"),
    (len(result.combined_topics), "Total Topics"),
    (mapped_count, "Mapped Themes"),
    (novel_count, "Novel Themes"),
]
for card_column, (card_value, card_label) in zip(st.columns(5), stat_cards):
    with card_column:
        st.markdown(f"""
<div class="stat-card">
<div class="stat-value">{card_value}</div>
<div class="stat-label">{card_label}</div>
</div>
""", unsafe_allow_html=True)

st.markdown("<br>", unsafe_allow_html=True)
# ---- Tabbed Results ----
# One tab per result view; the later sections fill tab4, tab5, tab_review and
# tab_charts.
tab_titles = [
    "π Topics Table",
    "π¬ Title Topics",
    "π Abstract Topics",
    "βοΈ Theme Comparison",
    "πΊοΈ Taxonomy Map",
    "βοΈ Review Table",
    "π Charts",
]
tab1, tab2, tab3, tab4, tab5, tab_review, tab_charts = st.tabs(tab_titles)

# Tab 1: combined topics table with a source filter and a keyword/label search.
with tab1:
    st.markdown('<div class="section-header">Combined Topics Table</div>', unsafe_allow_html=True)
    st.markdown(f"Showing all **{len(result.combined_topics)}** topics from both Titles and Abstracts.")

    fcol1, fcol2 = st.columns(2)
    with fcol1:
        source_filter = st.multiselect(
            "Filter by Source",
            options=result.combined_topics["source"].unique().tolist(),
            default=result.combined_topics["source"].unique().tolist(),
        )
    with fcol2:
        search_term = st.text_input("π Search keywords", "")

    in_selected_sources = result.combined_topics["source"].isin(source_filter)
    display_df = result.combined_topics[in_selected_sources]
    if search_term:
        # Case-insensitive match in either the keywords or the label column.
        hit_in_keywords = display_df["keywords"].str.contains(search_term, case=False, na=False)
        hit_in_label = display_df["label"].str.contains(search_term, case=False, na=False)
        display_df = display_df[hit_in_keywords | hit_in_label]

    topics_column_config = {
        "topic_id": st.column_config.NumberColumn("Topic ID", width="small"),
        "keywords": st.column_config.TextColumn("Keywords", width="large"),
        "label": st.column_config.TextColumn("Label", width="medium"),
        "source": st.column_config.TextColumn("Source", width="small"),
    }
    st.dataframe(
        display_df,
        use_container_width=True,
        height=500,
        column_config=topics_column_config,
    )

# Tabs 2 and 3: the per-source topic tables share one layout.
per_source_views = (
    (tab2, "Title Topics", result.title_topics, "paper titles"),
    (tab3, "Abstract Topics", result.abstract_topics, "paper abstracts"),
)
for source_tab, heading, topics_frame, origin in per_source_views:
    with source_tab:
        st.markdown(f'<div class="section-header">{heading}</div>', unsafe_allow_html=True)
        st.markdown(f"**{len(topics_frame)}** topics discovered from {origin}.")
        st.dataframe(topics_frame, use_container_width=True, height=500)
# Tab 4: Theme Comparison
# Shows how many comparison rows fall into each alignment level, then the
# comparison table filtered by the selected levels.
ALIGNMENT_LEVELS = ["Strong", "Moderate", "Weak", "No Match"]

with tab4:
    st.markdown('<div class="section-header">Theme Comparison: Titles vs Abstracts</div>', unsafe_allow_html=True)
    if result.comparison.empty:
        st.info("No comparison data available.")
    else:
        # One metric per alignment level; levels absent from the data show 0.
        align_counts = result.comparison["alignment"].value_counts()
        for metric_column, alignment in zip(st.columns(4), ALIGNMENT_LEVELS):
            with metric_column:
                st.metric(
                    label=f"{alignment} Alignment",
                    value=align_counts.get(alignment, 0),
                )
        st.markdown("<br>", unsafe_allow_html=True)

        alignment_filter = st.multiselect(
            "Filter by Alignment",
            options=list(ALIGNMENT_LEVELS),
            default=list(ALIGNMENT_LEVELS),
        )
        keep_row = result.comparison["alignment"].isin(alignment_filter)
        filtered_comp = result.comparison[keep_row]
        similarity_bar = st.column_config.ProgressColumn(
            "Similarity",
            min_value=0,
            max_value=1,
            format="%.2f",
        )
        st.dataframe(
            filtered_comp,
            use_container_width=True,
            height=500,
            column_config={"similarity": similarity_bar},
        )
# Tab 5: Taxonomy Map
# Summary box built from the taxonomy map's "metadata" entry, followed by one
# sub-tab each for its "mapped" and "novel" lists.
with tab5:
    st.markdown('<div class="section-header">Taxonomy Map</div>', unsafe_allow_html=True)
    taxonomy = result.taxonomy_map
    meta = taxonomy.get("metadata", {})

    total_topics_shown = meta.get('total_topics', 0)
    mapped_shown = meta.get('mapped_count', 0)
    novel_shown = meta.get('novel_count', 0)
    threshold_shown = meta.get('threshold', 0.15)
    st.markdown(f"""
<div class="info-box">
<strong>Classification Summary:</strong><br>
Total Topics: <strong>{total_topics_shown}</strong> |
<span class="mapped-badge">MAPPED: {mapped_shown}</span> |
<span class="novel-badge">NOVEL: {novel_shown}</span> |
Threshold: {threshold_shown}
</div>
""", unsafe_allow_html=True)

    tax_tab1, tax_tab2 = st.tabs(["β Mapped Themes", "π Novel Themes"])
    # (sub-tab, key in the taxonomy map, message shown when the list is empty)
    taxonomy_views = (
        (tax_tab1, "mapped", "No mapped themes found."),
        (tax_tab2, "novel", "No novel themes found."),
    )
    for taxonomy_tab, bucket_key, empty_message in taxonomy_views:
        with taxonomy_tab:
            bucket_entries = taxonomy.get(bucket_key, [])
            if not bucket_entries:
                st.info(empty_message)
                continue
            st.dataframe(
                pd.DataFrame(bucket_entries),
                use_container_width=True,
                height=400,
                column_config={
                    "score": st.column_config.ProgressColumn(
                        "Match Score",
                        min_value=0,
                        max_value=1,
                        format="%.3f",
                    ),
                },
            )
# ==================================================================
# Tab 6: Editable Review Table
# Reviewers approve, relabel and annotate topics. The working table lives in
# st.session_state.review_df and is written to review_table.csv inside
# OUTPUT_DIR when the save button is pressed.
# ==================================================================
_REVIEW_COLUMNS = ["topic_id", "label", "keywords", "source", "approve", "rename_to", "reasoning"]
_REVIEW_EDITABLE = ["label", "approve", "rename_to", "reasoning"]


def _fresh_review_table(topics_df):
    """Return an un-reviewed table with one row per row of *topics_df*."""
    rows = [
        {
            "topic_id": int(row["topic_id"]),
            "label": row.get("label", ""),
            "keywords": row.get("keywords", ""),
            "source": row.get("source", ""),
            "approve": False,
            "rename_to": "",
            "reasoning": "",
        }
        for _, row in topics_df.iterrows()
    ]
    # Passing the column list keeps the expected columns even with zero rows.
    return pd.DataFrame(rows, columns=_REVIEW_COLUMNS)


def _normalize_review_table(table):
    """Return a copy of *table* with reviewer columns coerced to str / bool.

    pandas.read_csv reads empty text cells back as NaN, so a saved table whose
    notes were left blank would otherwise come back with NaN in the text
    columns.
    """
    table = table.copy()
    for text_column in ("rename_to", "reasoning"):
        if text_column in table.columns:
            table[text_column] = table[text_column].fillna("").astype(str)
    if "approve" in table.columns:
        table["approve"] = table["approve"].fillna(False).astype(bool)
    return table


def _apply_saved_edits(fresh, saved):
    """Copy approve / rename_to / reasoning from *saved* onto matching rows of *fresh*."""
    if saved.empty or "topic_id" not in saved.columns:
        return fresh
    # NOTE(review): topic ids may repeat across sources (not verifiable from
    # this file), so rows are matched on source as well when it was saved.
    match_on_source = "source" in saved.columns
    for _, saved_row in saved.iterrows():
        hit = fresh["topic_id"] == saved_row["topic_id"]
        if match_on_source:
            hit &= fresh["source"] == saved_row["source"]
        if not hit.any():
            continue
        for edit_column in ("approve", "rename_to", "reasoning"):
            if edit_column in saved_row and pd.notna(saved_row[edit_column]):
                fresh.loc[hit, edit_column] = saved_row[edit_column]
    return fresh


def _load_review_table(topics_df, csv_path):
    """Build the review table for *topics_df*, reusing edits saved at *csv_path*.

    A saved table is reused as-is when it has every expected column and covers
    all current topic ids; otherwise a fresh table is built and the saved
    edits are copied onto it.
    """
    fresh = _fresh_review_table(topics_df)
    if not os.path.exists(csv_path):
        return fresh
    try:
        saved = _normalize_review_table(pd.read_csv(csv_path))
    except (pd.errors.EmptyDataError, pd.errors.ParserError):
        # An unreadable review file must not block the tab; start over.
        return fresh
    if "topic_id" not in saved.columns:
        return fresh
    current_ids = set(topics_df["topic_id"].tolist())
    saved_ids = set(saved["topic_id"].tolist())
    has_all_columns = set(_REVIEW_COLUMNS).issubset(saved.columns)
    if has_all_columns and saved_ids.issuperset(current_ids):
        return saved
    return _apply_saved_edits(fresh, saved)


with tab_review:
    # Resolved against OUTPUT_DIR, which may be the temp-directory fallback,
    # so the messages below show the real path instead of a fixed string.
    review_csv_path = os.path.join(OUTPUT_DIR, "review_table.csv")
    st.markdown(f"""
<div class="review-header">
<h3>βοΈ Topic Review Table</h3>
<p>Review, approve, rename, and annotate each topic. Changes are saved to {review_csv_path}.</p>
</div>
""", unsafe_allow_html=True)

    # Build the table once per session; later reruns keep the in-memory edits.
    if "review_df" not in st.session_state:
        st.session_state.review_df = _load_review_table(result.combined_topics, review_csv_path)

    # Filter controls for review table
    rv_col1, rv_col2, rv_col3 = st.columns(3)
    with rv_col1:
        review_source_filter = st.multiselect(
            "Filter by Source",
            options=st.session_state.review_df["source"].unique().tolist(),
            default=st.session_state.review_df["source"].unique().tolist(),
            key="review_source_filter",
        )
    with rv_col2:
        review_search = st.text_input("π Search in review table", "", key="review_search")
    with rv_col3:
        review_approval_filter = st.selectbox(
            "Show",
            options=["All Topics", "Approved Only", "Not Approved"],
            index=0,
            key="review_approval_filter",
        )

    # Apply filters
    filtered_review = st.session_state.review_df[
        st.session_state.review_df["source"].isin(review_source_filter)
    ]
    if review_search:
        search_mask = (
            filtered_review["keywords"].str.contains(review_search, case=False, na=False)
            | filtered_review["label"].str.contains(review_search, case=False, na=False)
        )
        filtered_review = filtered_review[search_mask]
    is_approved = filtered_review["approve"].astype(bool)
    if review_approval_filter == "Approved Only":
        filtered_review = filtered_review[is_approved]
    elif review_approval_filter == "Not Approved":
        filtered_review = filtered_review[~is_approved]

    # Editable grid; identifying columns are read-only.
    edited_df = st.data_editor(
        filtered_review,
        use_container_width=True,
        height=500,
        num_rows="fixed",
        key="review_editor",
        column_config={
            "topic_id": st.column_config.NumberColumn(
                "Topic ID", width="small", disabled=True
            ),
            "label": st.column_config.TextColumn(
                "Label", width="medium",
            ),
            "keywords": st.column_config.TextColumn(
                "Keywords", width="large", disabled=True,
            ),
            "source": st.column_config.TextColumn(
                "Source", width="small", disabled=True,
            ),
            "approve": st.column_config.CheckboxColumn(
                "β Approve", width="small", default=False,
            ),
            "rename_to": st.column_config.TextColumn(
                "Rename To", width="medium",
            ),
            "reasoning": st.column_config.TextColumn(
                "Reasoning / Notes", width="large",
            ),
        },
        column_order=["topic_id", "label", "keywords", "approve", "rename_to", "reasoning", "source"],
    )

    # Write the edited (possibly filtered) rows back into the full table in
    # one index-aligned assignment.
    if edited_df is not None:
        edited_columns = [c for c in _REVIEW_EDITABLE if c in edited_df.columns]
        shared_index = edited_df.index.intersection(st.session_state.review_df.index)
        if edited_columns and len(shared_index) > 0:
            st.session_state.review_df.loc[shared_index, edited_columns] = (
                edited_df.loc[shared_index, edited_columns]
            )

    # Save button and approval counter
    sv_col1, sv_col2, sv_col3 = st.columns([1, 1, 2])
    with sv_col1:
        if st.button("πΎ Save Review Table", use_container_width=True, type="primary"):
            os.makedirs(OUTPUT_DIR, exist_ok=True)
            st.session_state.review_df.to_csv(review_csv_path, index=False)
            st.markdown(
                f'<div class="save-confirm">β Review table saved to {review_csv_path}</div>',
                unsafe_allow_html=True,
            )
    with sv_col2:
        approved_count = int(st.session_state.review_df["approve"].sum()) if "approve" in st.session_state.review_df.columns else 0
        total_count = len(st.session_state.review_df)
        st.markdown(f"""
<div class="stat-card" style="padding: 0.75rem 1rem;">
<div class="stat-value" style="font-size: 1.4rem;">{approved_count}/{total_count}</div>
<div class="stat-label">Topics Approved</div>
</div>
""", unsafe_allow_html=True)
| # ================================================================== | |
| # Tab 7: Charts | |
| # ================================================================== | |
| with tab_charts: | |
st.markdown('<div class="section-header">π Topic Visualizations</div>', unsafe_allow_html=True)

# -----------------------------------------------------------
# Chart 1: bar chart of how many topics each source produced.
# -----------------------------------------------------------
st.markdown("#### π Topic Frequency by Source")
st.caption("Number of topics discovered from each source (Titles vs Abstracts).")

source_counts = result.combined_topics["source"].value_counts().reset_index()
source_counts.columns = ["Source", "Count"]

fig1 = px.bar(
    source_counts,
    x="Source",
    y="Count",
    color="Source",
    color_discrete_sequence=["#667eea", "#764ba2"],
    text="Count",
)
# Counts are printed above the bars; bar outlines are removed.
count_label_font = dict(color="#e0e0e0", size=14, family="Inter")
fig1.update_traces(
    textposition="outside",
    textfont=count_label_font,
    marker=dict(line=dict(width=0)),
)
# Shared theme first, then the chart-specific settings.
chart1_layout = dict(
    PLOTLY_LAYOUT,
    title="Topic Count by Source",
    xaxis_title="Source",
    yaxis_title="Number of Topics",
    showlegend=False,
    height=420,
)
fig1.update_layout(**chart1_layout)
st.plotly_chart(fig1, use_container_width=True)
st.markdown("---")
# -----------------------------------------------------------
# Chart 2: the 25 most frequent keywords across all topics.
# -----------------------------------------------------------
st.markdown("#### π€ Top Keywords Across All Topics")
st.caption("Most frequently occurring keywords across all discovered topics.")

# Split each comma-separated keyword string, normalise the pieces, and keep
# only keywords longer than two characters.
all_keywords = [
    keyword
    for kw_str in result.combined_topics["keywords"].dropna()
    for keyword in (piece.strip().lower() for piece in kw_str.split(","))
    if len(keyword) > 2
]
kw_counts = pd.Series(all_keywords).value_counts().head(25).reset_index()
kw_counts.columns = ["Keyword", "Frequency"]

fig2 = px.bar(
    kw_counts,
    x="Frequency",
    y="Keyword",
    orientation="h",
    color="Frequency",
    color_continuous_scale=["#302b63", "#667eea", "#f093fb", "#f5576c"],
)
fig2.update_traces(marker=dict(line=dict(width=0)))

colorbar_style = dict(
    title="Freq",
    tickfont=dict(color="#c4b5fd"),
    title_font=dict(color="#e0e0e0"),
)
fig2.update_layout(
    **PLOTLY_LAYOUT,
    title="Top 25 Keywords by Frequency",
    xaxis_title="Frequency (across all topics)",
    yaxis_title="",
    height=700,
    coloraxis_colorbar=colorbar_style,
)
# The y-axis tweaks go in a second call because PLOTLY_LAYOUT already passes a
# `yaxis` keyword in the call above.
keyword_axis = dict(
    autorange="reversed",
    gridcolor="rgba(102, 126, 234, 0.1)",
    tickfont=dict(color="#c4b5fd", size=12),
)
fig2.update_layout(yaxis=keyword_axis)
st.plotly_chart(fig2, use_container_width=True)
st.markdown("---")
# -----------------------------------------------------------
# Chart 3: MAPPED vs NOVEL counts from the taxonomy map metadata, drawn as a
# donut chart (left) and a bar chart (right).
# -----------------------------------------------------------
st.markdown("#### 𧬠Taxonomy Classification Distribution")
st.caption("How topics are classified against the known research taxonomy.")

tax_meta = result.taxonomy_map.get("metadata", {})
tax_data = pd.DataFrame({
    "Classification": ["MAPPED", "NOVEL"],
    "Count": [tax_meta.get("mapped_count", 0), tax_meta.get("novel_count", 0)],
})

chart3_col1, chart3_col2 = st.columns(2)

with chart3_col1:
    fig3a = px.pie(
        tax_data,
        values="Count",
        names="Classification",
        color="Classification",
        color_discrete_map={"MAPPED": "#38ef7d", "NOVEL": "#f7b733"},
        hole=0.55,
    )
    fig3a.update_traces(
        textfont=dict(color="#ffffff", size=14),
        textinfo="percent+label",
        marker=dict(line=dict(color="#1a1a2e", width=3)),
    )
    # The donut sets its layout explicitly and does not use PLOTLY_LAYOUT.
    donut_layout = dict(
        paper_bgcolor="rgba(26, 26, 46, 0.95)",
        plot_bgcolor="rgba(22, 33, 62, 0.95)",
        font=dict(family="Inter, sans-serif", size=13, color="#e0e0e0"),
        title=dict(text="Mapped vs Novel", font=dict(size=16, color="#ffffff")),
        legend=dict(font=dict(color="#e0e0e0")),
        height=380,
        margin=dict(l=20, r=20, t=50, b=20),
    )
    fig3a.update_layout(**donut_layout)
    st.plotly_chart(fig3a, use_container_width=True)

with chart3_col2:
    fig3b = px.bar(
        tax_data,
        x="Classification",
        y="Count",
        color="Classification",
        color_discrete_map={"MAPPED": "#38ef7d", "NOVEL": "#f7b733"},
        text="Count",
    )
    fig3b.update_traces(
        textposition="outside",
        textfont=dict(color="#e0e0e0", size=16, family="Inter"),
        marker=dict(line=dict(width=0)),
    )
    bar_layout = dict(
        PLOTLY_LAYOUT,
        title="Classification Count",
        xaxis_title="",
        yaxis_title="Number of Topics",
        showlegend=False,
        height=380,
    )
    fig3b.update_layout(**bar_layout)
    st.plotly_chart(fig3b, use_container_width=True)

st.markdown("---")
# -----------------------------------------------------------
# Chart 4: Alignment Distribution (from comparisons)
# -----------------------------------------------------------
# NOTE(review): this copy of the file carries no indentation, so the extent
# of the guard below is reconstructed. Chart 5 is kept inside it because it
# also plots result.comparison — confirm against the original layout.
if not result.comparison.empty:
    st.markdown("#### βοΈ Theme Alignment Distribution")
    st.caption("Distribution of alignment strength between Title and Abstract topics.")

    # Number of comparison rows per alignment label.
    alignment_counts = result.comparison["alignment"].value_counts().reset_index()
    alignment_counts.columns = ["Alignment", "Count"]

    # Colour per alignment level; the dict's insertion order also serves as
    # the left-to-right order of the bars.
    align_colors = {
        "Strong": "#38ef7d",
        "Moderate": "#4facfe",
        "Weak": "#f7971e",
        "No Match": "#f5576c",
    }
    align_order = list(align_colors)

    fig4 = px.bar(
        alignment_counts,
        x="Alignment",
        y="Count",
        color="Alignment",
        color_discrete_map=align_colors,
        text="Count",
        category_orders={"Alignment": align_order},
    )
    fig4.update_traces(
        textposition="outside",
        textfont={"color": "#e0e0e0", "size": 14, "family": "Inter"},
        marker={"line": {"width": 0}},
    )
    fig4.update_layout(
        **PLOTLY_LAYOUT,
        title="Title β Abstract Alignment Distribution",
        xaxis_title="Alignment Level",
        yaxis_title="Number of Topic Pairs",
        showlegend=False,
        height=420,
    )
    st.plotly_chart(fig4, use_container_width=True)
    st.markdown("---")

    # -----------------------------------------------------------
    # Chart 5: Similarity Score Histogram
    # -----------------------------------------------------------
    st.markdown("#### π Similarity Score Distribution")
    st.caption("Distribution of Jaccard similarity scores between matched Title and Abstract topics.")

    fig5 = px.histogram(
        result.comparison,
        x="similarity",
        nbins=30,
        color_discrete_sequence=["#667eea"],
        marginal="box",
    )
    # The selector restricts the outline styling to the histogram trace, so
    # the marginal box plot is left untouched.
    histogram_outline = {"line": {"width": 1, "color": "#b8b5ff"}}
    fig5.update_traces(marker=histogram_outline, selector={"type": "histogram"})
    fig5.update_layout(
        **PLOTLY_LAYOUT,
        title="Similarity Score Histogram",
        xaxis_title="Jaccard Similarity Score",
        yaxis_title="Count",
        height=420,
        bargap=0.05,
    )
    st.plotly_chart(fig5, use_container_width=True)
# ---- Downloads Section ----
# Four side-by-side download buttons: topics table, comparison table,
# taxonomy map, and the review table saved under OUTPUT_DIR.
st.markdown('<div class="section-header">π₯ Download Outputs</div>', unsafe_allow_html=True)
dcol1, dcol2, dcol3, dcol4 = st.columns(4)

with dcol1:
    # Serialised in memory; nothing is written to disk for this button.
    csv_data = result.combined_topics.to_csv(index=False)
    st.download_button(
        "β¬οΈ Topics Table (CSV)",
        data=csv_data,
        file_name="topics_table.csv",
        mime="text/csv",
        use_container_width=True,
    )

with dcol2:
    comp_data = result.comparison.to_csv(index=False)
    st.download_button(
        "β¬οΈ Comparison (CSV)",
        data=comp_data,
        file_name="comparison.csv",
        mime="text/csv",
        use_container_width=True,
    )

with dcol3:
    # ensure_ascii=False keeps non-ASCII characters readable in the JSON.
    json_data = json.dumps(result.taxonomy_map, indent=2, ensure_ascii=False)
    st.download_button(
        "β¬οΈ Taxonomy Map (JSON)",
        data=json_data,
        file_name="taxonomy_map.json",
        mime="application/json",
        use_container_width=True,
    )

with dcol4:
    # The review table is only downloadable once it exists on disk; until
    # then a disabled placeholder button is shown in the same slot.
    review_path = os.path.join(OUTPUT_DIR, "review_table.csv")
    if os.path.exists(review_path):
        # Decode explicitly as UTF-8 instead of relying on the platform's
        # locale default, matching how taxonomy_map.json is written in this
        # file. NOTE(review): assumes the review table is saved as UTF-8
        # (the pandas to_csv default) — confirm where it is written.
        with open(review_path, "r", encoding="utf-8") as f:
            review_data = f.read()
        st.download_button(
            "β¬οΈ Review Table (CSV)",
            data=review_data,
            file_name="review_table.csv",
            mime="text/csv",
            use_container_width=True,
        )
    else:
        st.download_button(
            "β¬οΈ Review Table (CSV)",
            data="Not saved yet. Go to Review Table tab and click Save.",
            file_name="review_table.csv",
            mime="text/csv",
            use_container_width=True,
            disabled=True,
        )
# ---- Auto-save comparison.csv and taxonomy_map.json to outputs ----
# Written on every execution of this section; existing files are overwritten.
os.makedirs(OUTPUT_DIR, exist_ok=True)
comparison_csv_path = os.path.join(OUTPUT_DIR, "comparison.csv")
taxonomy_json_path = os.path.join(OUTPUT_DIR, "taxonomy_map.json")

result.comparison.to_csv(comparison_csv_path, index=False)
with open(taxonomy_json_path, "w", encoding="utf-8") as taxonomy_file:
    # ensure_ascii=False stores non-ASCII characters as-is (file is UTF-8).
    json.dump(result.taxonomy_map, taxonomy_file, indent=2, ensure_ascii=False)
# ---- Pipeline Log ----
# Collapsible summary of the run: status, completed steps, any errors, and
# the exported files.
with st.expander("π Pipeline Execution Log"):
    st.markdown(f"**Status:** `{result.status}`")

    completed_steps = result.steps_completed
    # The denominator is a fixed 9 in the displayed text.
    st.markdown(f"**Steps Completed:** {len(completed_steps)}/9")
    for step_number, step_name in enumerate(completed_steps, start=1):
        st.markdown(f" β Step {step_number}: `{step_name}`")

    # The error heading is only shown when there is something to list.
    if result.errors:
        st.markdown("**Errors:**")
        for error_message in result.errors:
            st.error(error_message)

    st.markdown("**Exported Files:**")
    for export_name, export_path in result.exported_files.items():
        st.markdown(f" π `{export_name}` β `{export_path}`")
| else: | |
# ---- Welcome / instructions when no results ----
# Static HTML card with usage instructions, rendered with raw HTML enabled.
welcome_html = """
<div class="info-box">
<strong>π Welcome!</strong><br><br>
This application uses an AI agent to perform comprehensive topic modeling on research papers.
<br><br>
<strong>How to use:</strong><br>
1οΈβ£ Upload a CSV file with <code>Title</code> and <code>Abstract</code> columns (or use the default dataset).<br>
2οΈβ£ Configure the minimum number of topics and label generation method in the sidebar.<br>
3οΈβ£ Click <strong>"π Run Topic Modeling Agent"</strong> to start the analysis.<br>
4οΈβ£ Explore topics, comparisons, and taxonomy classification in the results tabs.<br>
5οΈβ£ Review and annotate topics in the <strong>βοΈ Review Table</strong> tab.<br>
6οΈβ£ View interactive charts in the <strong>π Charts</strong> tab.<br>
7οΈβ£ Download all outputs as CSV and JSON files.
</div>
"""
st.markdown(welcome_html, unsafe_allow_html=True)
st.markdown("<br>", unsafe_allow_html=True)

# Show a preview if default dataset exists
if os.path.exists("dataset.csv"):
    with st.expander("π Preview Default Dataset", expanded=False):
        try:
            # Only the first ten rows are read.
            preview_df = pd.read_csv("dataset.csv", nrows=10)
            available_columns = preview_df.columns
            st.markdown(f"**Columns:** {', '.join(available_columns.tolist())}")

            # Prefer Title (+ Abstract when present); without a Title column
            # fall back to showing every column.
            if "Title" in available_columns:
                if "Abstract" in available_columns:
                    preview_view = preview_df[["Title", "Abstract"]]
                else:
                    preview_view = preview_df[["Title"]]
            else:
                preview_view = preview_df
            st.dataframe(preview_view.head(10), use_container_width=True)
        except Exception as e:
            # Any failure to read or render becomes a warning, not a crash.
            st.warning(f"Could not preview dataset: {e}")