import streamlit as st

# st.set_page_config must be the first Streamlit command executed on the page.
st.set_page_config(page_title="SNAP", layout="wide")

# Silence noisy torch/Streamlit file-watcher warnings.
import warnings

warnings.filterwarnings('ignore', message='.*torch.classes.*__path__._path.*')
warnings.filterwarnings('ignore', message='.*torch.classes.*registered via torch::class_.*')

import pandas as pd
import numpy as np
import os
import io
import time
from datetime import datetime
import base64
import re
import pickle
from typing import List, Dict, Any, Tuple
import plotly.express as px
import torch

# Parallel generation of cluster summaries
from concurrent.futures import ThreadPoolExecutor, as_completed
from functools import partial

# Embeddings, similarity, clustering, and tokenization
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from bertopic import BERTopic
from hdbscan import HDBSCAN
import nltk
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize

# LLM tooling
from langchain.chains import LLMChain
from langchain_community.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate
from openai import OpenAI
from transformers import GPT2TokenizerFast

# OpenAI client; expects OPENAI_API_KEY to be set in the environment.
client = OpenAI()


def get_base_dir():
    """Return the directory containing this script, falling back to the CWD."""
    try:
        base_dir = os.path.dirname(__file__)
        if not base_dir:
            return os.getcwd()
        return base_dir
    except NameError:
        # __file__ is undefined (e.g., in an interactive session)
        return os.getcwd()


BASE_DIR = get_base_dir()


def get_model_dir():
    """Return (creating if needed) the local cache directory for models."""
    base_dir = get_base_dir()
    model_dir = os.path.join(base_dir, 'models')
    os.makedirs(model_dir, exist_ok=True)
    return model_dir


def load_tokenizer():
    """Load the GPT-2 tokenizer from the local cache, downloading it on first use."""
    model_dir = get_model_dir()
    tokenizer_dir = os.path.join(model_dir, 'tokenizer')
    os.makedirs(tokenizer_dir, exist_ok=True)

    try:
        # Try the locally cached copy first.
        tokenizer = GPT2TokenizerFast.from_pretrained(tokenizer_dir)
    except Exception:
        # Fall back to downloading from the Hugging Face hub and cache it locally.
        tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
        tokenizer.save_pretrained(tokenizer_dir)

    return tokenizer


# The tokenizer is optional: summaries fall back to untrimmed text if it fails to load.
try:
    tokenizer = load_tokenizer()
except Exception:
    tokenizer = None

# Token budget used when trimming cluster text before sending it to the LLM.
MAX_CONTEXT_WINDOW = 128000

# Conversation history for the Chat tab
if 'chat_history' not in st.session_state:
    st.session_state.chat_history = []


def get_chat_response(messages):
    """Send a chat completion request and return the reply text, or None on error."""
    try:
        response = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=messages,
            temperature=0,
        )
        return response.choices[0].message.content.strip()
    except Exception as e:
        st.error(f"Error querying OpenAI: {e}")
        return None
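

# Example usage (illustrative only; not called by the app):
#
#   reply = get_chat_response([
#       {"role": "system", "content": "You are a helpful assistant."},
#       {"role": "user", "content": "Say hello."},
#   ])
#
# `reply` holds the assistant's text, or None if the API call failed.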


def generate_raw_cluster_summary(
    topic_val: int,
    cluster_df: pd.DataFrame,
    llm: Any,
    chat_prompt: Any
) -> Dict[str, Any]:
    """Generate a summary for a single cluster without reference enhancement,
    automatically trimming text if it exceeds a safe token limit."""
    cluster_text = " ".join(cluster_df['text'].tolist())
    if not cluster_text.strip():
        return None

    # Stay 5% below the model's context window to leave room for the prompt.
    safe_limit = int(MAX_CONTEXT_WINDOW * 0.95)

    # Trim to the token budget (skipped when the tokenizer failed to load).
    if tokenizer is not None:
        encoded_text = tokenizer.encode(cluster_text, add_special_tokens=False)
        if len(encoded_text) > safe_limit:
            encoded_text = encoded_text[:safe_limit]
            cluster_text = tokenizer.decode(encoded_text)

    user_prompt_local = f"**Text to summarize**: {cluster_text}"
    try:
        local_chain = LLMChain(llm=llm, prompt=chat_prompt)
        summary_local = local_chain.run(user_prompt=user_prompt_local).strip()
        return {'Topic': topic_val, 'Summary': summary_local}
    except Exception as e:
        st.error(f"Error generating summary for cluster {topic_val}: {str(e)}")
        return None
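

# Illustration of the trimming rule above (numbers are hypothetical): with
# MAX_CONTEXT_WINDOW = 128000 the budget is int(128000 * 0.95) = 121600 tokens,
# so a 150000-token cluster keeps only its first 121600 tokens:
#
#   tokens = tokenizer.encode(long_text, add_special_tokens=False)
#   tokens = tokens[:121600]
#   long_text = tokenizer.decode(tokens)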


def enhance_summary_with_references(
    summary_dict: Dict[str, Any],
    df_scope: pd.DataFrame,
    reference_id_column: str,
    url_column: str = None,
    llm: Any = None
) -> Dict[str, Any]:
    """Add references to a summary."""
    if not summary_dict or 'Summary' not in summary_dict:
        return summary_dict

    try:
        cluster_df = df_scope[df_scope['Topic'] == summary_dict['Topic']]
        enhanced = add_references_to_summary(
            summary_dict['Summary'],
            cluster_df,
            reference_id_column,
            url_column,
            llm
        )
        summary_dict['Enhanced_Summary'] = enhanced
        return summary_dict
    except Exception as e:
        st.error(f"Error enhancing summary for cluster {summary_dict.get('Topic')}: {str(e)}")
        return summary_dict


def process_summaries_in_parallel(
    df_scope: pd.DataFrame,
    unique_selected_topics: List[int],
    llm: Any,
    chat_prompt: Any,
    enable_references: bool = False,
    reference_id_column: str = None,
    url_column: str = None,
    max_workers: int = 16
) -> List[Dict[str, Any]]:
    """Process multiple cluster summaries in parallel using ThreadPoolExecutor."""
    summaries = []
    total_topics = len(unique_selected_topics)

    progress_text = st.empty()
    progress_bar = st.progress(0)

    try:
        # Phase 1: generate a raw summary per cluster.
        progress_text.text(f"Phase 1/3: Generating cluster summaries in parallel (0/{total_topics} completed)")
        completed_summaries = 0

        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            future_to_topic = {
                executor.submit(
                    generate_raw_cluster_summary,
                    topic_val,
                    df_scope[df_scope['Topic'] == topic_val],
                    llm,
                    chat_prompt
                ): topic_val
                for topic_val in unique_selected_topics
            }

            # Consume futures as they finish so the progress bar updates promptly.
            for future in as_completed(future_to_topic):
                try:
                    result = future.result()
                    if result:
                        summaries.append(result)
                    completed_summaries += 1
                    progress = completed_summaries / total_topics
                    progress_bar.progress(progress)
                    progress_text.text(
                        f"Phase 1/3: Generating cluster summaries in parallel ({completed_summaries}/{total_topics} completed)"
                    )
                except Exception as e:
                    topic_val = future_to_topic[future]
                    st.error(f"Error in summary generation for cluster {topic_val}: {str(e)}")
                    completed_summaries += 1
                    continue

        # Phase 2 (optional): attach source references to each summary.
        if enable_references and reference_id_column and summaries:
            total_to_enhance = len(summaries)
            completed_enhancements = 0
            progress_text.text(f"Phase 2/3: Adding references to summaries (0/{total_to_enhance} completed)")
            progress_bar.progress(0)

            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                future_to_summary = {
                    executor.submit(
                        enhance_summary_with_references,
                        summary_dict,
                        df_scope,
                        reference_id_column,
                        url_column,
                        llm
                    ): summary_dict.get('Topic')
                    for summary_dict in summaries
                }

                enhanced_summaries = []
                for future in as_completed(future_to_summary):
                    try:
                        result = future.result()
                        if result:
                            enhanced_summaries.append(result)
                        completed_enhancements += 1
                        progress = completed_enhancements / total_to_enhance
                        progress_bar.progress(progress)
                        progress_text.text(
                            f"Phase 2/3: Adding references to summaries ({completed_enhancements}/{total_to_enhance} completed)"
                        )
                    except Exception as e:
                        topic_val = future_to_summary[future]
                        st.error(f"Error in reference enhancement for cluster {topic_val}: {str(e)}")
                        completed_enhancements += 1
                        continue

            summaries = enhanced_summaries

        # Phase 3: generate a short display name per cluster.
        if summaries:
            total_to_name = len(summaries)
            completed_names = 0
            progress_text.text(f"Phase 3/3: Generating cluster names (0/{total_to_name} completed)")
            progress_bar.progress(0)

            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                future_to_summary = {
                    executor.submit(
                        generate_cluster_name,
                        summary_dict.get('Enhanced_Summary', summary_dict['Summary']),
                        llm
                    ): summary_dict.get('Topic')
                    for summary_dict in summaries
                }

                named_summaries = []
                for future in as_completed(future_to_summary):
                    try:
                        cluster_name = future.result()
                        topic_val = future_to_summary[future]
                        summary_dict = next(s for s in summaries if s['Topic'] == topic_val)
                        summary_dict['Cluster_Name'] = cluster_name
                        named_summaries.append(summary_dict)
                        completed_names += 1
                        progress = completed_names / total_to_name
                        progress_bar.progress(progress)
                        progress_text.text(
                            f"Phase 3/3: Generating cluster names ({completed_names}/{total_to_name} completed)"
                        )
                    except Exception as e:
                        topic_val = future_to_summary[future]
                        st.error(f"Error in cluster naming for cluster {topic_val}: {str(e)}")
                        completed_names += 1
                        continue

            summaries = named_summaries
    finally:
        # Always clear the progress widgets, even on failure.
        progress_text.empty()
        progress_bar.empty()

    return summaries
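

# Sketch of how the three-phase pipeline above is typically invoked (argument
# values, including the reference column name, are hypothetical):
#
#   summaries = process_summaries_in_parallel(
#       df_scope=clustered_df,              # must contain 'Topic' and 'text'
#       unique_selected_topics=[0, 1, 2],
#       llm=llm,
#       chat_prompt=chat_prompt,
#       enable_references=True,
#       reference_id_column="Result code",
#   )
#
# Each returned dict holds 'Topic' and 'Summary', plus 'Enhanced_Summary' and
# 'Cluster_Name' when the corresponding phases ran.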


def generate_cluster_name(summary_text: str, llm: Any) -> str:
    """Generate a concise, descriptive name for a cluster based on its summary."""
    system_prompt = """You are a cluster naming expert. Your task is to generate a very concise (3-6 words) but descriptive name for a cluster based on its summary. The name should capture the main theme or focus of the cluster.

Rules:
1. Keep it between 3-6 words
2. Be specific but concise
3. Capture the main theme/focus
4. Use title case
5. Do not include words like "Cluster", "Topic", or "Theme"
6. Focus on the content, not metadata

Example good names:
- Agricultural Water Management Innovation
- Gender Equality in Farming
- Climate-Smart Village Implementation
- Sustainable Livestock Practices"""

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": f"Generate a concise cluster name based on this summary:\n\n{summary_text}"}
    ]

    try:
        response = get_chat_response(messages)
        # Strip stray quotes around the model's answer.
        cluster_name = response.strip().strip('"').strip("'").strip()
        return cluster_name
    except Exception as e:
        st.error(f"Error generating cluster name: {str(e)}")
        return "Unnamed Cluster"


# NLTK setup
def init_nltk_resources():
    """Initialize NLTK resources with better error handling and less verbose output."""
    # Default NLTK data location on Streamlit Community Cloud.
    nltk.data.path.append('/home/appuser/nltk_data')

    resources = {
        'tokenizers/punkt': 'punkt_tab',
        'corpora/stopwords': 'stopwords'
    }

    for resource_path, resource_name in resources.items():
        try:
            nltk.data.find(resource_path)
        except LookupError:
            try:
                nltk.download(resource_name, quiet=True)
            except Exception as e:
                st.warning(f"Error downloading NLTK resource {resource_name}: {e}")

    # Smoke-test the sentence tokenizer and retry the download if it is broken.
    try:
        from nltk.tokenize import PunktSentenceTokenizer
        sent_tokenizer = PunktSentenceTokenizer()
        sent_tokenizer.tokenize("Test sentence.")
    except Exception as e:
        st.error(f"Error initializing NLTK tokenizer: {e}")
        try:
            nltk.download('punkt_tab', quiet=True)
        except Exception as e:
            st.error(f"Failed to download punkt_tab tokenizer: {e}")


init_nltk_resources()


def add_references_to_summary(summary, source_df, reference_column, url_column=None, llm=None):
    """
    Add references to a summary by identifying which parts of the summary come
    from which source documents. References are appended as [ID], optionally
    hyperlinked if a URL column is provided.

    Args:
        summary (str): The summary text to enhance with references.
        source_df (DataFrame): DataFrame containing the source documents.
        reference_column (str): Column name to use for reference IDs.
        url_column (str, optional): Column name containing URLs for hyperlinks.
        llm (LLM, optional): Language model for source attribution.

    Returns:
        str: Enhanced summary with references as HTML if possible.
    """
    if summary.strip() == "" or source_df.empty or reference_column not in source_df.columns:
        return summary

    # Source attribution requires a language model.
    if llm is None:
        return summary

    # Process the summary paragraph by paragraph.
    paragraphs = summary.split('\n\n')
    enhanced_paragraphs = []

    # Collect source texts, their reference IDs, and optional URLs.
    source_texts = []
    reference_ids = []
    urls = []
    for _, row in source_df.iterrows():
        if 'text' in row and pd.notna(row['text']) and pd.notna(row[reference_column]):
            source_texts.append(str(row['text']))
            reference_ids.append(str(row[reference_column]))
            if url_column and url_column in row and pd.notna(row[url_column]):
                urls.append(str(row[url_column]))
            else:
                urls.append(None)
    if not source_texts:
        return summary

    # Map reference IDs to URLs for hyperlinking.
    url_map = {}
    for ref_id, u in zip(reference_ids, urls):
        if u:
            url_map[ref_id] = u

    system_prompt = """
You are an expert at identifying the source of information. You will be given:
1. A sentence or bullet point from a summary
2. A list of source texts with their IDs

Your task is to identify which source text(s) the text most likely came from.
Return ONLY the IDs of the source texts that contributed to the text, separated by commas.
If you cannot confidently attribute the text to any source, return "unknown".
"""

    for paragraph in paragraphs:
        if not paragraph.strip():
            enhanced_paragraphs.append('')
            continue

        # Bullet-point paragraphs are attributed line by line.
        if any(line.strip().startswith('- ') or line.strip().startswith('* ') for line in paragraph.split('\n')):
            bullet_lines = paragraph.split('\n')
            enhanced_bullets = []
            for line in bullet_lines:
                if not line.strip():
                    enhanced_bullets.append(line)
                    continue

                if line.strip().startswith('- ') or line.strip().startswith('* '):
                    source_texts_formatted = '\n'.join([f"ID: {ref_id}, Text: {text[:500]}..." for ref_id, text in zip(reference_ids, source_texts)])
                    user_prompt = f"""
Text: {line.strip()}

Source texts:
{source_texts_formatted}

Which source ID(s) did this text most likely come from? Return only the ID(s) separated by commas, or "unknown".
"""

                    try:
                        system_message = SystemMessagePromptTemplate.from_template(system_prompt)
                        # The template takes the prompt as a variable; passing the raw
                        # text here would make LangChain parse any braces inside it.
                        human_message = HumanMessagePromptTemplate.from_template("{user_prompt}")
                        chat_prompt = ChatPromptTemplate.from_messages([system_message, human_message])
                        chain = LLMChain(llm=llm, prompt=chat_prompt)
                        response = chain.run(user_prompt=user_prompt)
                        source_ids = response.strip()

                        if source_ids.lower() == "unknown":
                            enhanced_bullets.append(line)
                        else:
                            # Keep only digits, commas, and whitespace (IDs are assumed numeric).
                            source_ids = re.sub(r'[^0-9,\s]', '', source_ids)
                            source_ids = re.sub(r'\s+', '', source_ids)
                            ids = [id_.strip() for id_ in source_ids.split(',') if id_.strip()]

                            if ids:
                                ref_parts = []
                                for id_ in ids:
                                    if id_ in url_map:
                                        ref_parts.append(f'<a href="{url_map[id_]}" target="_blank">{id_}</a>')
                                    else:
                                        ref_parts.append(id_)
                                ref_string = ", ".join(ref_parts)
                                enhanced_bullets.append(f"{line} [{ref_string}]")
                            else:
                                enhanced_bullets.append(line)
                    except Exception:
                        enhanced_bullets.append(line)
                else:
                    enhanced_bullets.append(line)

            enhanced_paragraphs.append('\n'.join(enhanced_bullets))
        else:
            # Prose paragraphs are attributed sentence by sentence.
            sentences = re.split(r'(?<=[.!?])\s+', paragraph)
            enhanced_sentences = []

            for sentence in sentences:
                if not sentence.strip():
                    continue

                source_texts_formatted = '\n'.join([f"ID: {ref_id}, Text: {text[:500]}..." for ref_id, text in zip(reference_ids, source_texts)])
                user_prompt = f"""
Sentence: {sentence.strip()}

Source texts:
{source_texts_formatted}

Which source ID(s) did this sentence most likely come from? Return only the ID(s) separated by commas, or "unknown".
"""

                try:
                    system_message = SystemMessagePromptTemplate.from_template(system_prompt)
                    human_message = HumanMessagePromptTemplate.from_template("{user_prompt}")
                    chat_prompt = ChatPromptTemplate.from_messages([system_message, human_message])
                    chain = LLMChain(llm=llm, prompt=chat_prompt)
                    response = chain.run(user_prompt=user_prompt)
                    source_ids = response.strip()

                    if source_ids.lower() == "unknown":
                        enhanced_sentences.append(sentence)
                    else:
                        # Keep only digits, commas, and whitespace (IDs are assumed numeric).
                        source_ids = re.sub(r'[^0-9,\s]', '', source_ids)
                        source_ids = re.sub(r'\s+', '', source_ids)
                        ids = [id_.strip() for id_ in source_ids.split(',') if id_.strip()]

                        if ids:
                            ref_parts = []
                            for id_ in ids:
                                if id_ in url_map:
                                    ref_parts.append(f'<a href="{url_map[id_]}" target="_blank">{id_}</a>')
                                else:
                                    ref_parts.append(id_)
                            ref_string = ", ".join(ref_parts)
                            enhanced_sentences.append(f"{sentence} [{ref_string}]")
                        else:
                            enhanced_sentences.append(sentence)
                except Exception:
                    enhanced_sentences.append(sentence)

            enhanced_paragraphs.append(' '.join(enhanced_sentences))

    return '\n\n'.join(enhanced_paragraphs)
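

# Example of the transformation performed above (ID and URL are hypothetical):
#
#   input:  "Farmers adopted drought-tolerant maize."
#   output: "Farmers adopted drought-tolerant maize. [1234]"
#
# or, when a URL column is supplied:
#
#   output: "Farmers adopted drought-tolerant maize.
#            [<a href='https://example.org/1234' target='_blank'>1234</a>]"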


st.sidebar.image("static/SNAP_logo.png", width=350)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
if device == 'cuda':
    st.sidebar.success(f"Using GPU: {torch.cuda.get_device_name(0)}")
else:
    st.sidebar.info("Using CPU")


@st.cache_resource
def get_embedding_model():
    """Load the sentence-transformer model, cached locally and across reruns."""
    model_dir = get_model_dir()
    st_model_dir = os.path.join(model_dir, 'sentence_transformer')
    os.makedirs(st_model_dir, exist_ok=True)

    model_name = 'all-MiniLM-L6-v2'
    try:
        # Try the locally saved copy first.
        model = SentenceTransformer(st_model_dir)
    except Exception:
        try:
            # Download and save for subsequent runs.
            model = SentenceTransformer(model_name)
            model.save(st_model_dir)
        except Exception as download_e:
            st.error(f"Error downloading sentence transformer model: {str(download_e)}")
            raise

    return model.to(device)


def generate_embeddings(texts, model):
    with st.spinner('Calculating embeddings...'):
        embeddings = model.encode(texts, show_progress_bar=True, device=device)
    return embeddings
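

# Minimal sketch of how the two helpers above fit together (illustrative only):
#
#   model = get_embedding_model()
#   vectors = generate_embeddings(["first text", "second text"], model)
#   # vectors.shape -> (2, 384) for all-MiniLM-L6-v2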


@st.cache_data
def load_default_dataset(default_dataset_path):
    if os.path.exists(default_dataset_path):
        df_ = pd.read_excel(default_dataset_path)
        return df_
    else:
        st.error("Default dataset not found. Please ensure the file exists in the 'input' directory.")
        return None


@st.cache_data
def load_uploaded_dataset(uploaded_file):
    df_ = pd.read_excel(uploaded_file)
    return df_


def load_or_compute_embeddings(df, using_default_dataset, uploaded_file_name=None, text_columns=None):
    """
    Load pre-computed embeddings from a pickle file if they match the current data;
    otherwise compute and cache them.
    """
    if not text_columns:
        return None, None

    base_name = "PRMS_2022_2023_2024_QAed" if using_default_dataset else "custom_dataset"
    if uploaded_file_name:
        base_name = os.path.splitext(uploaded_file_name)[0]

    # The cache key reflects the selected text columns.
    cols_key = "_".join(sorted(text_columns))
    embeddings_dir = BASE_DIR
    embeddings_file = os.path.join(embeddings_dir, f'{base_name}_{cols_key}.pkl')

    df_fill = df.fillna("")
    texts = df_fill[text_columns].astype(str).agg(' '.join, axis=1).tolist()

    # Reuse embeddings already held in session state when they still match.
    if ('embeddings' in st.session_state
            and 'last_text_columns' in st.session_state
            and st.session_state['last_text_columns'] == text_columns
            and len(st.session_state['embeddings']) == len(texts)):
        return st.session_state['embeddings'], st.session_state.get('embeddings_file', None)

    # Fall back to the on-disk cache when the row count matches.
    if os.path.exists(embeddings_file):
        with open(embeddings_file, 'rb') as f:
            embeddings = pickle.load(f)
        if len(embeddings) == len(texts):
            st.write("Loaded pre-calculated embeddings.")
            st.session_state['embeddings'] = embeddings
            st.session_state['embeddings_file'] = embeddings_file
            st.session_state['last_text_columns'] = text_columns
            return embeddings, embeddings_file

    # Otherwise compute from scratch and cache to disk.
    st.write("Generating embeddings...")
    model = get_embedding_model()
    embeddings = generate_embeddings(texts, model)
    with open(embeddings_file, 'wb') as f:
        pickle.dump(embeddings, f)

    st.session_state['embeddings'] = embeddings
    st.session_state['embeddings_file'] = embeddings_file
    st.session_state['last_text_columns'] = text_columns
    return embeddings, embeddings_file
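

# Cache behavior sketch (the file name shown is illustrative): with the default
# dataset and text columns ['Title', 'Description'], the pickle is written as
# 'PRMS_2022_2023_2024_QAed_Description_Title.pkl' next to the script and is
# reused on later runs as long as the row count still matches.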


def reset_filters():
    st.session_state['selected_additional_filters'] = {}


st.sidebar.radio("Select view", ["Automatic Mode", "Power User Mode"], key="view")

if st.session_state.view == "Power User Mode":
    st.header("Power User Mode")

    st.sidebar.title("Data Selection")
    dataset_option = st.sidebar.selectbox('Select Dataset', ('PRMS 2022+2023+2024 QAed', 'Upload my dataset'))

    if 'df' not in st.session_state:
        st.session_state['df'] = pd.DataFrame()
    if 'filtered_df' not in st.session_state:
        st.session_state['filtered_df'] = pd.DataFrame()

    if dataset_option == 'PRMS 2022+2023+2024 QAed':
        default_dataset_path = os.path.join(BASE_DIR, 'input', 'export_data_table_results_20251203_101413CET.xlsx')
        df = load_default_dataset(default_dataset_path)
        if df is not None:
            st.session_state['df'] = df.copy()
            st.session_state['using_default_dataset'] = True

            if 'filtered_df' not in st.session_state or st.session_state['filtered_df'].empty:
                st.session_state['filtered_df'] = df.copy()

            if 'filter_state' not in st.session_state:
                st.session_state['filter_state'] = {
                    'applied': False,
                    'filters': {}
                }

            # Default to Title + Description as the text columns when present.
            if 'text_columns' not in st.session_state or not st.session_state['text_columns']:
                default_text_cols = []
                if 'Title' in df.columns and 'Description' in df.columns:
                    default_text_cols = ['Title', 'Description']
                st.session_state['text_columns'] = default_text_cols

            df_cols = df.columns.tolist()

            st.subheader("Select Filters")
            if 'additional_filters_selected' not in st.session_state:
                st.session_state['additional_filters_selected'] = []
            if 'filter_values' not in st.session_state:
                st.session_state['filter_values'] = {}

            with st.form("filter_selection_form"):
                all_columns = df.columns.tolist()
                selected_additional_cols = st.multiselect(
                    "Select columns from your dataset to use as filters:",
                    all_columns,
                    default=st.session_state['additional_filters_selected']
                )
                add_filters_submitted = st.form_submit_button("Add Additional Filters")

            if add_filters_submitted:
                if selected_additional_cols != st.session_state['additional_filters_selected']:
                    st.session_state['additional_filters_selected'] = selected_additional_cols
                    # Keep stored values only for filters that are still selected.
                    st.session_state['filter_values'] = {
                        k: v for k, v in st.session_state['filter_values'].items()
                        if k in selected_additional_cols
                    }

            if st.session_state['additional_filters_selected']:
                st.subheader("Apply Filters")

                for col_name in st.session_state['additional_filters_selected']:
                    unique_vals = sorted(df[col_name].dropna().unique().tolist())

                    search_key = f"search_{col_name}"
                    if search_key not in st.session_state:
                        st.session_state[search_key] = ""

                    col1, col2 = st.columns([3, 1])
                    with col1:
                        search_term = st.text_input(
                            f"Search in {col_name}",
                            key=search_key,
                            help="Enter text to find and select all matching values"
                        )
                    with col2:
                        if st.button("Select Matching", key=f"select_{col_name}"):
                            if search_term:
                                matching_vals = [
                                    val for val in unique_vals
                                    if any(search_term.lower() in str(part).lower()
                                           for part in (val.split(',') if isinstance(val, str) else [val]))
                                ]
                                current_selected = st.session_state['filter_values'].get(col_name, [])
                                st.session_state['filter_values'][col_name] = list(set(current_selected + matching_vals))

                                if matching_vals:
                                    st.success(f"Found and selected {len(matching_vals)} matching values")
                                else:
                                    st.warning("No matching values found")

                with st.form("apply_filters_form"):
                    for col_name in st.session_state['additional_filters_selected']:
                        unique_vals = sorted(df[col_name].dropna().unique().tolist())
                        selected_vals = st.multiselect(
                            f"Filter by {col_name}",
                            options=unique_vals,
                            default=st.session_state['filter_values'].get(col_name, [])
                        )
                        st.session_state['filter_values'][col_name] = selected_vals

                    col1, col2 = st.columns([1, 4])
                    with col1:
                        clear_filters = st.form_submit_button("Clear All")
                    with col2:
                        apply_filters_submitted = st.form_submit_button("Apply Filters to Dataset")

                if clear_filters:
                    st.session_state['filter_values'] = {}
                    # Stale summaries no longer match the data, so drop them.
                    if 'summary_df' in st.session_state:
                        del st.session_state['summary_df']
                    if 'high_level_summary' in st.session_state:
                        del st.session_state['high_level_summary']
                    if 'enhanced_summary' in st.session_state:
                        del st.session_state['enhanced_summary']
                    st.rerun()

            with st.expander("⚙️ Advanced Settings", expanded=False):
                st.subheader("**Select Text Columns for Embedding**")
                text_columns_selected = st.multiselect(
                    "Text Columns:",
                    df_cols,
                    default=st.session_state['text_columns'],
                    help="Choose columns containing text for semantic search and clustering. "
                         "If multiple are selected, their text will be concatenated."
                )
                st.session_state['text_columns'] = text_columns_selected

            filtered_df = df.copy()
            # apply_filters_submitted only exists when the filter form was rendered.
            if 'apply_filters_submitted' in locals() and apply_filters_submitted:
                # Invalidate cached summaries before re-filtering.
                if 'summary_df' in st.session_state:
                    del st.session_state['summary_df']
                if 'high_level_summary' in st.session_state:
                    del st.session_state['high_level_summary']
                if 'enhanced_summary' in st.session_state:
                    del st.session_state['enhanced_summary']

                for col_name in st.session_state['additional_filters_selected']:
                    selected_vals = st.session_state['filter_values'].get(col_name, [])
                    if selected_vals:
                        filtered_df = filtered_df[filtered_df[col_name].isin(selected_vals)]
                st.success("Filters applied successfully!")
                st.session_state['filtered_df'] = filtered_df.copy()
                st.session_state['filter_state'] = {
                    'applied': True,
                    'filters': st.session_state['filter_values'].copy()
                }
                # A data change invalidates clustering artifacts.
                for k in ['clustered_data', 'topic_model', 'current_clustering_data',
                          'current_clustering_option', 'hierarchy']:
                    if k in st.session_state:
                        del st.session_state[k]

            elif 'filter_state' in st.session_state and st.session_state['filter_state']['applied']:
                # Re-apply the previously stored filters on rerun.
                for col_name, selected_vals in st.session_state['filter_state']['filters'].items():
                    if selected_vals:
                        filtered_df = filtered_df[filtered_df[col_name].isin(selected_vals)]
                st.session_state['filtered_df'] = filtered_df.copy()

            if st.session_state['filtered_df'] is not None:
                if st.session_state['filter_state']['applied']:
                    st.write("Filtered Data Preview:")
                else:
                    st.write("Current Data Preview:")
                st.dataframe(st.session_state['filtered_df'].head(), hide_index=True)
                st.write(f"Total number of results: {len(st.session_state['filtered_df'])}")

                output = io.BytesIO()
                writer = pd.ExcelWriter(output, engine='openpyxl')
                st.session_state['filtered_df'].to_excel(writer, index=False)
                writer.close()
                processed_data = output.getvalue()

                st.download_button(
                    label="Download Current Data",
                    data=processed_data,
                    file_name='data.xlsx',
                    mime='application/vnd.openxmlformats-officedocument.spreadsheetml.sheet'
                )
        else:
            st.warning("Please ensure the default dataset exists in the 'input' directory.")

    else:
        # User-uploaded dataset
        uploaded_file = st.sidebar.file_uploader("Upload your Excel file", type=["xlsx"])
        if uploaded_file is not None:
            df = load_uploaded_dataset(uploaded_file)
            if df is not None:
                st.session_state['df'] = df.copy()
                st.session_state['using_default_dataset'] = False
                st.session_state['uploaded_file_name'] = uploaded_file.name
                st.write("Data preview:")
                st.write(df.head())
                df_cols = df.columns.tolist()

                st.subheader("**Select Text Columns for Embedding**")
                text_columns_selected = st.multiselect(
                    "Text Columns:",
                    df_cols,
                    default=df_cols[:1] if df_cols else []
                )
                st.session_state['text_columns'] = text_columns_selected

                st.write("**Additional Filters**")
                selected_additional_cols = st.multiselect(
                    "Select additional columns from your dataset to use as filters:",
                    df_cols,
                    default=[]
                )
                st.session_state['additional_filters_selected'] = selected_additional_cols

                filtered_df = df.copy()
                for col_name in selected_additional_cols:
                    if f'selected_filter_{col_name}' not in st.session_state:
                        st.session_state[f'selected_filter_{col_name}'] = []
                    unique_vals = sorted(df[col_name].dropna().unique().tolist())
                    selected_vals = st.multiselect(
                        f"Filter by {col_name}",
                        options=unique_vals,
                        default=st.session_state[f'selected_filter_{col_name}']
                    )
                    st.session_state[f'selected_filter_{col_name}'] = selected_vals
                    if selected_vals:
                        filtered_df = filtered_df[filtered_df[col_name].isin(selected_vals)]

                st.session_state['filtered_df'] = filtered_df
                st.write("Filtered Data Preview:")
                st.dataframe(filtered_df.head(), hide_index=True)
                st.write(f"Total number of results: {len(filtered_df)}")

                output = io.BytesIO()
                writer = pd.ExcelWriter(output, engine='openpyxl')
                filtered_df.to_excel(writer, index=False)
                writer.close()
                processed_data = output.getvalue()

                st.download_button(
                    label="Download Filtered Data",
                    data=processed_data,
                    file_name='filtered_data.xlsx',
                    mime='application/vnd.openxmlformats-officedocument.spreadsheetml.sheet'
                )
            else:
                st.warning("Failed to load the uploaded dataset.")
        else:
            st.warning("Please upload an Excel file to proceed.")

    if 'filtered_df' in st.session_state:
        st.write(f"Total number of results: {len(st.session_state['filtered_df'])}")

    if 'active_tab_index' not in st.session_state:
        st.session_state.active_tab_index = 0

    tabs_titles = ["Semantic Search", "Clustering", "Summarization", "Chat", "Help"]
    tabs = st.tabs(tabs_titles)
    tab_semantic, tab_clustering, tab_summarization, tab_chat, tab_help = tabs

    with tab_help:
        st.header("Help")
        st.markdown("""
        ### About SNAP

        SNAP allows you to explore, filter, search, cluster, and summarize textual datasets.

        **Workflow**:
        1. **Data Selection (Sidebar)**: Choose the default dataset or upload your own.
        2. **Filtering**: Set additional filters for your dataset.
        3. **Select Text Columns**: Choose which columns to embed.
        4. **Semantic Search** (Tab): Provide a query and threshold to find relevant documents.
        5. **Clustering** (Tab): Group documents into topics.
        6. **Summarization** (Tab): Summarize the clustered documents (with optional references).

        ### Troubleshooting
        - If you see no results, try lowering the similarity threshold or removing negative/required keywords.
        - Ensure you have at least one text column selected for embeddings.
        """)

    with tab_semantic:
        st.header("Semantic Search")
        if 'filtered_df' in st.session_state and not st.session_state['filtered_df'].empty:
            text_columns = st.session_state.get('text_columns', [])
            if not text_columns:
                st.warning("No text columns selected. Please select at least one column for text embedding.")
            else:
                df_full = st.session_state['df']
                embeddings, _ = load_or_compute_embeddings(
                    df_full,
                    st.session_state.get('using_default_dataset', False),
                    st.session_state.get('uploaded_file_name'),
                    text_columns
                )

                if embeddings is not None:
                    with st.expander("ℹ️ How Semantic Search Works", expanded=False):
                        st.markdown("""
                        ### Understanding Semantic Search

                        Unlike traditional keyword search, which looks for exact matches, semantic search understands the meaning and context of your query. Here's how it works:

                        1. **Query Processing**:
                           - Your search query is converted into a numerical representation (embedding) that captures its meaning
                           - Example: Searching for "Climate Smart Villages" will understand the concept, not just the words
                           - Related terms like "sustainable communities", "resilient farming", or "agricultural adaptation" might be found even if they don't contain the exact words

                        2. **Similarity Matching**:
                           - Documents are ranked by how closely their meaning matches your query
                           - The similarity threshold controls how strict this matching is
                           - Higher threshold (e.g., 0.8) = more precise but fewer results
                           - Lower threshold (e.g., 0.3) = more results but might be less relevant

                        3. **Advanced Features**:
                           - **Negative Keywords**: Use to explicitly exclude documents containing certain terms
                           - **Required Keywords**: Ensure specific terms appear in the results
                           - These work as traditional keyword filters after the semantic search

                        ### Search Tips

                        - **Phrase Queries**: Enter complete phrases for better context
                          - "Climate Smart Villages" (as one concept)
                          - Better than separate terms: "climate", "smart", "villages"

                        - **Descriptive Queries**: Add context for better results
                          - Instead of: "water"
                          - Better: "water management in agriculture"

                        - **Conceptual Queries**: Focus on concepts rather than specific terms
                          - Instead of: "increased yield"
                          - Better: "agricultural productivity improvements"

                        ### Example Searches

                        1. **Query**: "Climate Smart Villages"
                           - Will find: Documents about climate-resilient communities, adaptive farming practices, sustainable village development
                           - Even if they don't use these exact words

                        2. **Query**: "Gender equality in agriculture"
                           - Will find: Women's empowerment in farming, female farmer initiatives, gender-inclusive rural development
                           - Related concepts are captured semantically

                        3. **Query**: "Sustainable water management" + required keyword "irrigation"
                           - Combines semantic understanding of water sustainability with a specific irrigation focus
                        """)

                    with st.form("search_parameters"):
                        query = st.text_input("Enter your search query:")
                        include_keywords = st.text_input("Include only documents containing these words (comma-separated):")
                        similarity_threshold = st.slider("Similarity threshold", 0.0, 1.0, 0.35)
                        submitted = st.form_submit_button("Search")

                    if submitted:
                        if query.strip():
                            with st.spinner("Performing Semantic Search..."):
                                # Invalidate summaries produced for a previous search.
                                if 'summary_df' in st.session_state:
                                    del st.session_state['summary_df']
                                if 'high_level_summary' in st.session_state:
                                    del st.session_state['high_level_summary']
                                if 'enhanced_summary' in st.session_state:
                                    del st.session_state['enhanced_summary']

                                model = get_embedding_model()
                                df_filtered = st.session_state['filtered_df'].fillna("")

                                # Embeddings are aligned with the full dataset's rows,
                                # so select only the rows that survived filtering.
                                subset_indices = df_filtered.index
                                subset_embeddings = embeddings[subset_indices]

                                query_embedding = model.encode([query], device=device)
                                similarities = cosine_similarity(query_embedding, subset_embeddings)[0]

                                # Show the distribution of similarity scores with the threshold marked.
                                fig = px.histogram(
                                    x=similarities,
                                    nbins=30,
                                    labels={'x': 'Similarity Score', 'y': 'Number of Documents'},
                                    title='Distribution of Similarity Scores'
                                )
                                fig.add_vline(
                                    x=similarity_threshold,
                                    line_dash="dash",
                                    line_color="red",
                                    annotation_text=f"Threshold: {similarity_threshold:.2f}",
                                    annotation_position="top"
                                )
                                st.write("### Similarity Score Distribution")
                                st.plotly_chart(fig)

                                above_threshold_indices = np.where(similarities > similarity_threshold)[0]
                                if len(above_threshold_indices) == 0:
                                    st.warning("No results found above the similarity threshold.")
                                    if 'search_results' in st.session_state:
                                        del st.session_state['search_results']
                                else:
                                    selected_indices = subset_indices[above_threshold_indices]
                                    results = df_filtered.loc[selected_indices].copy()
                                    results['similarity_score'] = similarities[above_threshold_indices]
                                    results.sort_values(by='similarity_score', ascending=False, inplace=True)

                                    # Post-filter: require every comma-separated keyword to appear.
                                    if include_keywords.strip():
                                        inc_words = [w.strip().lower() for w in include_keywords.split(',') if w.strip()]
                                        if inc_words:
                                            results = results[
                                                results.apply(
                                                    lambda row: all(
                                                        w in (' '.join(row.astype(str)).lower()) for w in inc_words
                                                    ),
                                                    axis=1
                                                )
                                            ]

                                    if results.empty:
                                        st.warning("No results found after applying keyword filters.")
                                        if 'search_results' in st.session_state:
                                            del st.session_state['search_results']
                                    else:
                                        st.session_state['search_results'] = results.copy()
                                        output = io.BytesIO()
                                        writer = pd.ExcelWriter(output, engine='openpyxl')
                                        results.to_excel(writer, index=False)
                                        writer.close()
                                        processed_data = output.getvalue()
                                        st.session_state['search_results_processed_data'] = processed_data
                        else:
                            st.warning("Please enter a query to search.")

                    if 'search_results' in st.session_state and not st.session_state['search_results'].empty:
                        st.write("## Search Results")
                        results = st.session_state['search_results']
                        cols_to_display = [c for c in results.columns if c != 'similarity_score'] + ['similarity_score']
                        st.dataframe(results[cols_to_display], hide_index=True)
                        st.write(f"Total number of results: {len(results)}")

                        if 'search_results_processed_data' in st.session_state:
                            st.download_button(
                                label="Download Full Results",
                                data=st.session_state['search_results_processed_data'],
                                file_name='search_results.xlsx',
                                mime='application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
                                key='download_search_results'
                            )
                    else:
                        st.info("No search results to display. Enter a query and click 'Search'.")
                else:
                    st.warning("No embeddings available because no text columns were chosen.")
        else:
            st.warning("Filtered dataset is empty or not loaded. Please adjust your filters or upload data.")

    with tab_clustering:
        st.header("Clustering")
        if 'filtered_df' in st.session_state and not st.session_state['filtered_df'].empty:
            with st.expander("ℹ️ How Clustering Works", expanded=False):
                st.markdown("""
                ### Understanding Document Clustering

                Clustering automatically groups similar documents together, helping you discover patterns and themes in your data. Here's how it works:

                1. **Cluster Formation**:
                   - Documents are grouped based on their semantic similarity
                   - Each cluster represents a distinct theme or topic
                   - Documents that are too different from others may remain unclustered (labeled as -1)
                   - The "Min Cluster Size" parameter controls how clusters are formed

                2. **Interpreting Results**:
                   - Each cluster is assigned a number (e.g., 0, 1, 2...)
                   - Cluster -1 contains "outlier" documents that didn't fit well in other clusters
                   - The size of each cluster indicates how common that theme is
                   - Keywords for each cluster show the main topics/concepts

                3. **Visualizations**:
                   - **Intertopic Distance Map**: Shows how clusters relate to each other
                     - Closer clusters are more semantically similar
                     - Size of circles indicates number of documents
                     - Hover to see top terms for each cluster

                   - **Topic Document Visualization**: Shows individual documents
                     - Each point is a document
                     - Colors indicate cluster membership
                     - Distance between points shows similarity

                   - **Topic Hierarchy**: Shows how topics are related
                     - Tree structure shows topic relationships
                     - Parent topics contain broader themes
                     - Child topics show more specific sub-themes

                ### How to Use Clusters

                1. **Exploration**:
                   - Use clusters to discover main themes in your data
                   - Look for unexpected groupings that might reveal insights
                   - Identify outliers that might need special attention

                2. **Analysis**:
                   - Compare cluster sizes to understand theme distribution
                   - Examine keywords to understand what defines each cluster
                   - Use the hierarchy to see how themes are nested

                3. **Practical Applications**:
                   - Generate summaries for specific clusters
                   - Focus detailed analysis on clusters of interest
                   - Use clusters to organize and categorize documents
                   - Identify gaps or overlaps in your dataset

                ### Tips for Better Results

                - **Adjust Min Cluster Size**:
                  - Larger values (15-20): Fewer, broader clusters
                  - Smaller values (2-5): More specific, smaller clusters
                  - Balance between too many small clusters and too few large ones

                - **Choose Data Wisely**:
                  - Cluster the full dataset for overall themes
                  - Cluster search results for focused analysis
                  - More documents generally give better clusters

                - **Interpret with Context**:
                  - Consider your domain knowledge
                  - Look for patterns across multiple visualizations
                  - Use cluster insights to guide further analysis
                """)

            df_to_cluster = None

            with st.form("clustering_form"):
                st.subheader("Clustering Settings")

                clustering_option = st.radio(
                    "Select data for clustering:",
                    ('Full Dataset', 'Filtered Dataset', 'Semantic Search Results')
                )

                min_cluster_size_val = st.slider(
                    "Min Cluster Size",
                    min_value=2,
                    max_value=50,
                    value=st.session_state.get('min_cluster_size', 5),
                    help="Minimum size of each cluster in HDBSCAN; in other words, the minimum number of documents/texts that must be grouped together to form a valid cluster.\n\n- A larger value (e.g., 20) will result in fewer, larger clusters\n- A smaller value (e.g., 2-5) will allow for more clusters, including smaller ones\n- Documents that don't fit into any cluster meeting this minimum size requirement are labeled as noise (typically assigned to cluster -1)"
                )

                run_clustering = st.form_submit_button("Run Clustering")

            if run_clustering:
                st.session_state.active_tab_index = tabs_titles.index("Clustering")
                st.session_state['min_cluster_size'] = min_cluster_size_val

                # Pick the data scope for this clustering run.
                if clustering_option == 'Semantic Search Results':
                    if 'search_results' in st.session_state and not st.session_state['search_results'].empty:
                        df_to_cluster = st.session_state['search_results'].copy()
                    else:
                        st.warning("No semantic search results found. Please run a search first.")
                elif clustering_option == 'Filtered Dataset':
                    if 'filtered_df' in st.session_state and not st.session_state['filtered_df'].empty:
                        df_to_cluster = st.session_state['filtered_df'].copy()
                    else:
                        st.warning("Filtered dataset is empty. Please check your filters.")
                else:
                    if 'df' in st.session_state and not st.session_state['df'].empty:
                        df_to_cluster = st.session_state['df'].copy()

            text_columns = st.session_state.get('text_columns', [])
            if not text_columns:
                st.warning("No text columns selected. Please select text columns to embed before clustering.")
            else:
                # Embeddings cover the full dataset; the clustering scope is subset below.
                df_full = st.session_state['df']
                embeddings, _ = load_or_compute_embeddings(
                    df_full,
                    st.session_state.get('using_default_dataset', False),
                    st.session_state.get('uploaded_file_name'),
                    text_columns
                )

                if df_to_cluster is not None and embeddings is not None and not df_to_cluster.empty and run_clustering:
                    with st.spinner("Performing clustering..."):
                        # Invalidate summaries produced for a previous run.
                        if 'summary_df' in st.session_state:
                            del st.session_state['summary_df']
                        if 'high_level_summary' in st.session_state:
                            del st.session_state['high_level_summary']
                        if 'enhanced_summary' in st.session_state:
                            del st.session_state['enhanced_summary']

                        dfc = df_to_cluster.copy().fillna("")
                        dfc['text'] = dfc[text_columns].astype(str).agg(' '.join, axis=1)

                        # Select the embedding rows matching the chosen data scope.
                        selected_indices = dfc.index
                        embeddings_clustering = embeddings[selected_indices]

                        # Remove stopwords before topic modeling.
                        stop_words = set(stopwords.words('english'))
                        texts_cleaned = []
                        for text in dfc['text'].tolist():
                            try:
                                try:
                                    word_tokens = word_tokenize(text)
                                except LookupError:
                                    # Tokenizer data is missing; fetch it and retry.
                                    nltk.download('punkt_tab', quiet=False)
                                    word_tokens = word_tokenize(text)
                                except Exception as e:
                                    # Last resort: whitespace tokenization.
                                    st.warning(f"Using fallback tokenization due to error: {e}")
                                    word_tokens = text.split()

                                filtered_text = ' '.join([w for w in word_tokens if w.lower() not in stop_words])
                                texts_cleaned.append(filtered_text)
                            except Exception as e:
                                st.error(f"Error processing text: {e}")
                                # Keep the raw text so counts stay aligned with embeddings.
                                texts_cleaned.append(text)

                        try:
                            # Guard: HDBSCAN needs at least min_cluster_size documents.
                            if len(texts_cleaned) < min_cluster_size_val:
                                st.error(f"Not enough documents to form clusters. You have {len(texts_cleaned)} documents but minimum cluster size is set to {min_cluster_size_val}.")
                                st.session_state['clustering_error'] = "Insufficient documents for clustering"
                                st.session_state.active_tab_index = tabs_titles.index("Clustering")
                                st.stop()

                            # BERTopic expects numpy arrays, not torch tensors.
                            if torch.is_tensor(embeddings_clustering):
                                embeddings_for_clustering = embeddings_clustering.cpu().numpy()
                            else:
                                embeddings_for_clustering = embeddings_clustering

                            # Sanity check: one embedding per document.
                            if embeddings_for_clustering.shape[0] != len(texts_cleaned):
                                st.error("Mismatch between number of embeddings and texts.")
                                st.session_state['clustering_error'] = "Embedding and text count mismatch"
                                st.session_state.active_tab_index = tabs_titles.index("Clustering")
                                st.stop()

                            try:
                                hdbscan_model = HDBSCAN(
                                    min_cluster_size=min_cluster_size_val,
                                    metric='euclidean',
                                    cluster_selection_method='eom'
                                )

                                topic_model = BERTopic(
                                    embedding_model=get_embedding_model(),
                                    hdbscan_model=hdbscan_model
                                )

                                # Fit on the cleaned texts, reusing the precomputed embeddings.
                                topics, probs = topic_model.fit_transform(
                                    texts_cleaned,
                                    embeddings=embeddings_for_clustering
                                )

                                # Warn when the clustering is degenerate.
                                unique_topics = set(topics)
                                if len(unique_topics) < 2:
                                    st.warning("Clustering resulted in too few clusters. Try reducing the minimum cluster size.")
                                    if -1 in unique_topics:
                                        non_noise_docs = sum(1 for t in topics if t != -1)
                                        st.info(f"Only {non_noise_docs} documents were assigned to clusters. The rest were marked as noise (-1).")
                                        if non_noise_docs < min_cluster_size_val:
                                            st.error("Not enough documents were successfully clustered. Try reducing the minimum cluster size.")
                                            st.session_state['clustering_error'] = "Insufficient clustered documents"
                                            st.session_state.active_tab_index = tabs_titles.index("Clustering")
                                            st.stop()

                                # Persist results for the display block below.
                                dfc['Topic'] = topics
                                st.session_state['topic_model'] = topic_model
                                st.session_state['clustered_data'] = dfc.copy()
                                st.session_state['clustering_texts_cleaned'] = texts_cleaned
                                st.session_state['clustering_embeddings'] = embeddings_for_clustering
                                st.session_state['clustering_completed'] = True

                                # Visualizations can fail on degenerate clusterings; store None in that case.
                                try:
                                    st.session_state['intertopic_distance_fig'] = topic_model.visualize_topics()
                                except Exception:
                                    st.warning("Could not generate topic visualization. This usually happens when there are too few total clusters. Try adjusting the minimum cluster size or adding more documents.")
                                    st.session_state['intertopic_distance_fig'] = None

                                try:
                                    st.session_state['topic_document_fig'] = topic_model.visualize_documents(
                                        texts_cleaned,
                                        embeddings=embeddings_for_clustering
                                    )
                                except Exception:
                                    st.warning("Could not generate document visualization. This might happen when the clustering results are not optimal. Try adjusting the clustering parameters.")
                                    st.session_state['topic_document_fig'] = None

                                try:
                                    hierarchy = topic_model.hierarchical_topics(texts_cleaned)
                                    st.session_state['hierarchy'] = hierarchy if hierarchy is not None else pd.DataFrame()
                                    st.session_state['hierarchy_fig'] = topic_model.visualize_hierarchy()
                                except Exception:
                                    st.warning("Could not generate topic hierarchy visualization. This usually happens when there aren't enough distinct topics to form a hierarchy.")
                                    st.session_state['hierarchy'] = pd.DataFrame()
                                    st.session_state['hierarchy_fig'] = None

                            except ValueError as ve:
                                if "zero-size array to reduction operation maximum which has no identity" in str(ve):
                                    st.error("Clustering failed: No valid clusters could be formed. Try reducing the minimum cluster size.")
                                elif "Cannot use scipy.linalg.eigh for sparse A with k > N" in str(ve):
                                    st.error("Clustering failed: Too many components requested for the number of documents. Try with more documents or adjust clustering parameters.")
                                else:
                                    st.error(f"Clustering error: {str(ve)}")
                                st.session_state['clustering_error'] = str(ve)
                                st.session_state.active_tab_index = tabs_titles.index("Clustering")
                                st.stop()

                        except Exception as e:
                            st.error(f"An error occurred during clustering: {str(e)}")
                            st.session_state['clustering_error'] = str(e)
                            st.session_state['clustering_completed'] = False
                            st.session_state.active_tab_index = tabs_titles.index("Clustering")
                            st.stop()

                # Display results whenever a completed clustering is in session state.
                if st.session_state.get('clustering_completed', False):
                    st.subheader("Topic Overview")
                    dfc = st.session_state['clustered_data']
                    topic_model = st.session_state['topic_model']
                    topics = dfc['Topic'].tolist()

                    unique_topics = sorted(list(set(topics)))
                    cluster_info = []
                    for t in unique_topics:
                        cluster_docs = dfc[dfc['Topic'] == t]
                        count = len(cluster_docs)
                        top_words = topic_model.get_topic(t)
                        if top_words:
                            top_keywords = ", ".join([w[0] for w in top_words[:5]])
                        else:
                            top_keywords = "N/A"
                        cluster_info.append((t, count, top_keywords))
                    cluster_df = pd.DataFrame(cluster_info, columns=["Topic", "Count", "Top Keywords"])

                    st.write("### Topic Overview")
                    st.dataframe(
                        cluster_df,
                        column_config={
                            "Topic": st.column_config.NumberColumn("Topic", help="Topic ID (-1 represents outliers)"),
                            "Count": st.column_config.NumberColumn("Count", help="Number of documents in this topic"),
                            "Top Keywords": st.column_config.TextColumn(
                                "Top Keywords",
                                help="Top 5 keywords that characterize this topic"
                            )
                        },
                        hide_index=True
                    )

                    st.subheader("Clustering Results")
                    columns_to_display = [c for c in dfc.columns if c != 'text']
                    st.dataframe(dfc[columns_to_display], hide_index=True)

                    st.write("### Intertopic Distance Map")
                    if st.session_state.get('intertopic_distance_fig') is not None:
                        try:
                            st.plotly_chart(st.session_state['intertopic_distance_fig'])
                        except Exception:
                            st.info("Topic visualization is not available for the current clustering results.")

                    st.write("### Topic Document Visualization")
                    if st.session_state.get('topic_document_fig') is not None:
                        try:
                            st.plotly_chart(st.session_state['topic_document_fig'])
                        except Exception:
                            st.info("Document visualization is not available for the current clustering results.")

                    st.write("### Topic Hierarchy")
                    if st.session_state.get('hierarchy_fig') is not None:
                        try:
                            st.plotly_chart(st.session_state['hierarchy_fig'])
                        except Exception:
                            st.info("Topic hierarchy visualization is not available for the current clustering results.")
        else:
            st.warning("Please select or upload a dataset and filter as needed.")

    with tab_summarization:
        st.header("Summarization")
        with st.expander("ℹ️ How Summarization Works", expanded=False):
            st.markdown("""
            ### Understanding Document Summarization

            Summarization condenses multiple documents into concise, meaningful summaries while preserving key information. Here's how it works:

            1. **Summary Generation**:
               - Documents are processed using advanced language models
               - Key themes and important points are identified
               - Content is condensed while maintaining context
               - Both high-level and cluster-specific summaries are available

            2. **Reference System**:
               - Summaries can include references to source documents
               - References are shown as [ID] or as clickable links
               - Each statement can be traced back to its source
               - Helps maintain accountability and verification

            3. **Types of Summaries**:
               - **High-Level Summary**: Overview of all selected documents
                 - Captures main themes across the entire selection
                 - Ideal for quick understanding of large document sets
                 - Shows relationships between different topics

               - **Cluster-Specific Summaries**: Focused on each cluster
                 - More detailed for specific themes
                 - Shows unique aspects of each cluster
                 - Helps understand sub-topics in depth

            ### How to Use Summaries

            1. **Configuration**:
               - Choose between all clusters or specific ones
               - Set temperature for creativity vs. consistency
               - Adjust max tokens for summary length
               - Enable/disable the reference system

            2. **Reference Options**:
               - Select a column for reference IDs
               - Add hyperlinks to references
               - Choose a URL column for clickable links
               - References help track information sources

            3. **Practical Applications**:
               - Quick overview of large datasets
               - Detailed analysis of specific themes
               - Evidence-based reporting with references
               - Compare different document groups

            ### Tips for Better Results

            - **Temperature Setting**:
              - Higher (0.7-1.0): More creative, varied summaries
              - Lower (0.1-0.3): More consistent, conservative summaries
              - Balance based on your needs for creativity vs. consistency

            - **Token Length**:
              - Longer limits: More detailed summaries
              - Shorter limits: More concise, focused summaries
              - Adjust based on document complexity

            - **Reference Usage**:
              - Enable references for traceability
              - Use hyperlinks for easy navigation
              - Choose meaningful reference columns
              - Helps validate summary accuracy

            ### Best Practices

            1. **For a General Overview**:
               - Use the high-level summary
               - Keep temperature moderate (0.5-0.7)
               - Enable references for verification
               - Focus on broader themes

            2. **For Detailed Analysis**:
               - Use cluster-specific summaries
               - Adjust temperature based on need
               - Include references with hyperlinks
               - Look for patterns within clusters

            3. **For Reporting**:
               - Combine both summary types
               - Use references extensively
               - Balance detail and brevity
               - Ensure source traceability
            """)
|
|
| df_summ = None |
| |
| if 'clustered_data' in st.session_state and not st.session_state['clustered_data'].empty: |
| df_summ = st.session_state['clustered_data'] |
| elif 'filtered_df' in st.session_state and not st.session_state['filtered_df'].empty: |
| df_summ = st.session_state['filtered_df'] |
| else: |
| st.warning("No data available for summarization. Please cluster first or have some filtered data.") |
| |
| if df_summ is not None and not df_summ.empty: |
| text_columns = st.session_state.get('text_columns', []) |
| if not text_columns: |
| st.warning("No text columns selected. Please select columns for text embedding first.") |
| else: |
| if 'Topic' not in df_summ.columns or 'topic_model' not in st.session_state: |
| st.warning("No 'Topic' column found. Summaries per cluster are only available if you've run clustering.") |
| else: |
| topic_model = st.session_state['topic_model'] |
| df_summ['text'] = df_summ.fillna("").astype(str)[text_columns].agg(' '.join, axis=1) |
|
|
| |
| topics = sorted(df_summ['Topic'].unique()) |
| cluster_info = [] |
| for t in topics: |
| cluster_docs = df_summ[df_summ['Topic'] == t] |
| count = len(cluster_docs) |
| top_words = topic_model.get_topic(t) |
| if top_words: |
| top_keywords = ", ".join([w[0] for w in top_words[:5]]) |
| else: |
| top_keywords = "N/A" |
| cluster_info.append((t, count, top_keywords)) |
| cluster_df = pd.DataFrame(cluster_info, columns=["Topic", "Count", "Top Keywords"]) |
|
|
| |
| if 'summary_df' in st.session_state and 'Cluster_Name' in st.session_state['summary_df'].columns: |
| summary_df = st.session_state['summary_df'] |
| |
| topic_names = dict(zip(summary_df['Topic'], summary_df['Cluster_Name'])) |
| |
| cluster_df['Cluster_Name'] = cluster_df['Topic'].map(lambda x: topic_names.get(x, 'Unnamed Cluster')) |
| |
| cluster_df = cluster_df[['Topic', 'Cluster_Name', 'Count', 'Top Keywords']] |
| |
| st.write("### Available Clusters:") |
| st.dataframe( |
| cluster_df, |
| column_config={ |
| "Topic": st.column_config.NumberColumn("Topic", help="Topic ID (-1 represents outliers)"), |
| "Cluster_Name": st.column_config.TextColumn("Cluster Name", help="AI-generated name describing the cluster theme"), |
| "Count": st.column_config.NumberColumn("Count", help="Number of documents in this topic"), |
| "Top Keywords": st.column_config.TextColumn( |
| "Top Keywords", |
| help="Top 5 keywords that characterize this topic" |
| ) |
| }, |
| hide_index=True |
| ) |
|
|
| |
| st.subheader("Summarization Settings") |
| |
| summary_scope = st.radio( |
| "Generate summaries for:", |
| ["All clusters", "Specific clusters"] |
| ) |
| if summary_scope == "Specific clusters": |
| |
| if 'Cluster_Name' in cluster_df.columns: |
| topic_options = [f"Cluster {t} - {name}" for t, name in zip(cluster_df['Topic'], cluster_df['Cluster_Name'])] |
| topic_to_id = {opt: t for opt, t in zip(topic_options, cluster_df['Topic'])} |
| selected_topic_options = st.multiselect("Select clusters to summarize", topic_options) |
| selected_topics = [topic_to_id[opt] for opt in selected_topic_options] |
| else: |
| selected_topics = st.multiselect("Select clusters to summarize", topics) |
| else: |
| selected_topics = topics |
|
|
| |
| default_system_prompt = """You are an expert summarizer skilled in creating concise and relevant summaries. |
| You will be given text and an objective context. Please produce a clear, cohesive, |
| and thematically relevant summary. |
| Focus on key points, insights, or patterns that emerge from the text.""" |
|
|
| if 'system_prompt' not in st.session_state: |
| st.session_state['system_prompt'] = default_system_prompt |
|
|
| with st.expander("🔧 Advanced Settings", expanded=False): |
| st.markdown(""" |
| ### System Prompt Configuration |
| |
| The system prompt guides the AI in how to generate summaries. You can customize it to better suit your needs: |
| - Be specific about the style and focus you want |
| - Add domain-specific context if needed |
| - Include any special formatting requirements |
| """) |
| |
| system_prompt = st.text_area( |
| "Customize System Prompt", |
| value=st.session_state['system_prompt'], |
| height=150, |
| help="This prompt guides the AI in how to generate summaries. Edit it to customize the summary style and focus." |
| ) |
| |
| if st.button("Reset to Default"): |
| system_prompt = default_system_prompt |
| st.session_state['system_prompt'] = default_system_prompt |
|
|
| st.markdown("### Generation Parameters") |
| temperature = st.slider( |
| "Temperature", |
| 0.0, 1.0, 0.7, |
| help="Higher values (0.7-1.0) make summaries more creative but less predictable. Lower values (0.1-0.3) make them more focused and consistent." |
| ) |
| max_tokens = st.slider( |
| "Max Tokens", |
| 100, 3000, 1000, |
| help="Maximum length of generated summaries. Higher values allow for more detailed summaries but take longer to generate." |
| ) |
|
|
| st.session_state['system_prompt'] = system_prompt |
|
|
| st.write("### Enhanced Summary References") |
| st.write("Select columns for references (optional).") |
| all_cols = [c for c in df_summ.columns if c not in ['text', 'Topic', 'similarity_score']] |
| |
| |
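| # Defaults for the reference settings: first column as the reference ID, and a name-based guess for the URL column. |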
| if 'reference_id_column' not in st.session_state: |
| st.session_state.reference_id_column = all_cols[0] if all_cols else None |
| |
| url_guess = next((c for c in all_cols if 'url' in c.lower() or 'link' in c.lower()), None) |
| if 'url_column' not in st.session_state: |
| st.session_state.url_column = url_guess |
|
|
| enable_references = st.checkbox( |
| "Enable references in summaries", |
| value=True, |
| help="Add source references to the final summary text." |
| ) |
| reference_id_column = st.selectbox( |
| "Select column to use as reference ID:", |
| all_cols, |
| index=all_cols.index(st.session_state.reference_id_column) if st.session_state.reference_id_column in all_cols else 0 |
| ) |
| add_hyperlinks = st.checkbox( |
| "Add hyperlinks to references", |
| value=True, |
| help="If the reference column has a matching URL, make it clickable." |
| ) |
| url_column = None |
| if add_hyperlinks: |
| url_column = st.selectbox( |
| "Select column containing URLs:", |
| all_cols, |
| index=all_cols.index(st.session_state.url_column) if (st.session_state.url_column in all_cols) else 0 |
| ) |
|
|
| |
| if st.button("Generate Summaries"): |
| openai_api_key = os.environ.get('OPENAI_API_KEY') |
| if not openai_api_key: |
| st.error("OpenAI API key not found. Please set the OPENAI_API_KEY environment variable.") |
| else: |
| |
| st.session_state['_summarization_button_clicked'] = True |
| |
| llm = ChatOpenAI( |
| api_key=openai_api_key, |
| model_name='gpt-4o-mini', |
| temperature=temperature, |
| max_tokens=max_tokens |
| ) |
|
|
| |
| if selected_topics: |
| df_scope = df_summ[df_summ['Topic'].isin(selected_topics)] |
| else: |
| st.warning("No topics selected for summarization.") |
| df_scope = pd.DataFrame() |
|
|
| if df_scope.empty: |
| st.warning("No documents match the selected topics for summarization.") |
| else: |
| all_texts = df_scope['text'].tolist() |
| combined_text = " ".join(all_texts) |
| if not combined_text.strip(): |
| st.warning("No text data available for summarization.") |
| else: |
| |
| local_system_message = SystemMessagePromptTemplate.from_template(st.session_state['system_prompt']) |
| local_human_message = HumanMessagePromptTemplate.from_template("{user_prompt}") |
| local_chat_prompt = ChatPromptTemplate.from_messages([local_system_message, local_human_message]) |
|
|
| |
| |
| unique_selected_topics = df_scope['Topic'].unique() |
| if len(unique_selected_topics) >= 1: |
| st.write("### Summaries per Selected Cluster") |
| |
| |
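| # Fan out one summarization call per cluster; the worker cap (16) is presumably chosen to stay within API rate limits. |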
| with st.spinner("Generating cluster summaries in parallel..."): |
| summaries = process_summaries_in_parallel( |
| df_scope=df_scope, |
| unique_selected_topics=unique_selected_topics, |
| llm=llm, |
| chat_prompt=local_chat_prompt, |
| enable_references=enable_references, |
| reference_id_column=reference_id_column, |
| url_column=url_column if add_hyperlinks else None, |
| max_workers=min(16, len(unique_selected_topics)) |
| ) |
|
|
| if summaries: |
| summary_df = pd.DataFrame(summaries) |
| |
| st.session_state['summary_df'] = summary_df |
| |
| st.session_state['has_references'] = enable_references |
| st.session_state['reference_id_column'] = reference_id_column |
| st.session_state['url_column'] = url_column if add_hyperlinks else None |
| |
| |
| if 'Cluster_Name' in summary_df.columns: |
| topic_names = dict(zip(summary_df['Topic'], summary_df['Cluster_Name'])) |
| cluster_df['Cluster_Name'] = cluster_df['Topic'].map(lambda x: topic_names.get(x, 'Unnamed Cluster')) |
| cluster_df = cluster_df[['Topic', 'Cluster_Name', 'Count', 'Top Keywords']] |
| |
| |
| st.write("### Updated Topic Overview:") |
| st.dataframe( |
| cluster_df, |
| column_config={ |
| "Topic": st.column_config.NumberColumn("Topic", help="Topic ID (-1 represents outliers)"), |
| "Cluster_Name": st.column_config.TextColumn("Cluster Name", help="AI-generated name describing the cluster theme"), |
| "Count": st.column_config.NumberColumn("Count", help="Number of documents in this topic"), |
| "Top Keywords": st.column_config.TextColumn( |
| "Top Keywords", |
| help="Top 5 keywords that characterize this topic" |
| ) |
| }, |
| hide_index=True |
| ) |
| |
| |
| with st.spinner("Generating high-level summary from cluster summaries..."): |
| |
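| # Map step of a map-reduce summary: pack cluster summaries into batches under ~75% of the context window so each call leaves headroom for the prompt and response. |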
| MAX_SAFE_TOKENS = int(MAX_CONTEXT_WINDOW * 0.75) |
| summary_batches = [] |
| current_batch = [] |
| current_batch_tokens = 0 |
|
|
| for _, row in summary_df.iterrows(): |
| summary_text = row.get('Enhanced_Summary', row['Summary']) |
| formatted_summary = f"### Cluster {row['Topic']} Summary:\n\n{summary_text}" |
| summary_tokens = len(tokenizer(formatted_summary)["input_ids"]) |
| |
| |
| if current_batch_tokens + summary_tokens > MAX_SAFE_TOKENS: |
| if current_batch: |
| summary_batches.append(current_batch) |
| current_batch = [] |
| current_batch_tokens = 0 |
| |
| current_batch.append(formatted_summary) |
| current_batch_tokens += summary_tokens |
|
|
| |
| if current_batch: |
| summary_batches.append(current_batch) |
|
|
| |
| batch_overviews = [] |
| with st.spinner("Generating batch summaries..."): |
| for i, batch in enumerate(summary_batches, 1): |
| st.write(f"Processing batch {i} of {len(summary_batches)}...") |
| |
| batch_text = "\n\n".join(batch) |
| batch_prompt = f"""Below are summaries from a subset of clusters from results made using Transformers NLP on a set of results from the CGIAR reporting system. Each summary contains references to source documents in the form of hyperlinked IDs like [ID] or <a href="...">ID</a>. |
| |
| Please create a comprehensive overview that synthesizes these clusters so that both the main themes and findings are covered in an organized way. IMPORTANT: |
| 1. Preserve all hyperlinked references exactly as they appear in the input summaries |
| 2. Maintain the HTML anchor tags (<a href="...">) intact when using information from the summaries |
| 3. Keep the markdown formatting for better readability |
| 4. Note that this is part {i} of {len(summary_batches)} parts, so focus on the themes present in these specific clusters |
| |
| Here are the cluster summaries to synthesize: |
| |
| {batch_text}""" |
|
|
| |
| high_level_system_message = SystemMessagePromptTemplate.from_template(st.session_state['system_prompt']) |
| high_level_human_message = HumanMessagePromptTemplate.from_template("{user_prompt}") |
| high_level_chat_prompt = ChatPromptTemplate.from_messages([high_level_system_message, high_level_human_message]) |
| high_level_chain = LLMChain(llm=llm, prompt=high_level_chat_prompt) |
| batch_overview = high_level_chain.run(user_prompt=batch_prompt).strip() |
| batch_overviews.append(batch_overview) |
|
|
| |
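| # Reduce step: merge the per-batch overviews into one final synthesis, falling back to concatenation if the combined prompt is too large. |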
| with st.spinner("Generating final combined summary..."): |
| combined_overviews = "\n\n### Part ".join([f"{i+1}:\n\n{overview}" for i, overview in enumerate(batch_overviews)]) |
| final_prompt = f"""Below are {len(batch_overviews)} overview summaries, each covering different clusters of research results. Each part maintains its original references to source documents. |
| |
| Please create a final comprehensive synthesis that: |
| 1. Integrates the key themes and findings from all parts |
| 2. Preserves all hyperlinked references exactly as they appear |
| 3. Maintains the HTML anchor tags (<a href="...">) intact |
| 4. Keeps the markdown formatting for better readability |
| 5. Creates a coherent narrative across all parts |
| 6. Highlights any themes that span multiple parts |
| |
| Here are the overviews to synthesize: |
| |
| {combined_overviews}""" |
|
|
| |
| final_prompt_tokens = len(tokenizer(final_prompt)["input_ids"]) |
| if final_prompt_tokens > MAX_SAFE_TOKENS: |
| st.error(f"❌ Final synthesis prompt ({final_prompt_tokens:,} tokens) exceeds safe limit ({MAX_SAFE_TOKENS:,}). Using batch summaries separately.") |
| high_level_summary = "# Overall Summary\n\n" + "\n\n".join([f"## Batch {i+1}\n\n{overview}" for i, overview in enumerate(batch_overviews)]) |
| else: |
| |
| high_level_chain = LLMChain(llm=llm, prompt=high_level_chat_prompt) |
| high_level_summary = high_level_chain.run(user_prompt=final_prompt).strip() |
|
|
| |
| st.session_state['high_level_summary'] = high_level_summary |
| st.session_state['enhanced_summary'] = high_level_summary |
|
|
| |
| st.session_state['summarization_completed'] = True |
|
|
| |
| st.write("### High-Level Summary:") |
| st.markdown(high_level_summary, unsafe_allow_html=True) |
|
|
| |
| st.write("### Cluster Summaries:") |
| if enable_references and 'Enhanced_Summary' in summary_df.columns: |
| for idx, row in summary_df.iterrows(): |
| cluster_name = row.get('Cluster_Name', 'Unnamed Cluster') |
| st.write(f"**Topic {row['Topic']} - {cluster_name}**") |
| st.markdown(row['Enhanced_Summary'], unsafe_allow_html=True) |
| st.write("---") |
| with st.expander("View original summaries in table format"): |
| display_df = summary_df[['Topic', 'Cluster_Name', 'Summary']].rename(columns={'Cluster_Name': 'Cluster Name'}) |
| st.dataframe(display_df, hide_index=True) |
| else: |
| st.write("### Summaries per Cluster:") |
| if 'Cluster_Name' in summary_df.columns: |
| display_df = summary_df[['Topic', 'Cluster_Name', 'Summary']].rename(columns={'Cluster_Name': 'Cluster Name'}) |
| st.dataframe(display_df, hide_index=True) |
| else: |
| st.dataframe(summary_df, hide_index=True) |
|
|
| |
| if 'Cluster_Name' in summary_df.columns: |
| dl_df = summary_df[['Topic', 'Cluster_Name', 'Summary']].rename(columns={'Cluster_Name': 'Cluster Name'}) |
| else: |
| dl_df = summary_df |
| csv_bytes = dl_df.to_csv(index=False).encode('utf-8') |
| b64 = base64.b64encode(csv_bytes).decode() |
| href = f'<a href="data:file/csv;base64,{b64}" download="summaries.csv">Download Summaries CSV</a>' |
| st.markdown(href, unsafe_allow_html=True) |
|
|
| |
| if st.session_state.get('summarization_completed', False): |
| if 'summary_df' in st.session_state and not st.session_state['summary_df'].empty: |
| if 'high_level_summary' in st.session_state: |
| st.write("### High-Level Summary:") |
| st.markdown(st.session_state.get('enhanced_summary') or st.session_state['high_level_summary'], unsafe_allow_html=True) |
|
|
| st.write("### Cluster Summaries:") |
| summary_df = st.session_state['summary_df'] |
| if 'Enhanced_Summary' in summary_df.columns: |
| for idx, row in summary_df.iterrows(): |
| cluster_name = row.get('Cluster_Name', 'Unnamed Cluster') |
| st.write(f"**Topic {row['Topic']} - {cluster_name}**") |
| st.markdown(row['Enhanced_Summary'], unsafe_allow_html=True) |
| st.write("---") |
| with st.expander("View original summaries in table format"): |
| display_df = summary_df[['Topic', 'Cluster_Name', 'Summary']].rename(columns={'Cluster_Name': 'Cluster Name'}) |
| st.dataframe(display_df, hide_index=True) |
| else: |
| st.dataframe(summary_df, hide_index=True) |
|
|
| |
| dl_df = summary_df[['Topic', 'Cluster_Name', 'Summary']].rename(columns={'Cluster_Name': 'Cluster Name'}) if 'Cluster_Name' in summary_df.columns else summary_df |
| csv_bytes = dl_df.to_csv(index=False).encode('utf-8') |
| b64 = base64.b64encode(csv_bytes).decode() |
| href = f'<a href="data:file/csv;base64,{b64}" download="summaries.csv">Download Summaries CSV</a>' |
| st.markdown(href, unsafe_allow_html=True) |
| else: |
| st.warning("No data available for summarization.") |
|
|
| |
| if not st.session_state.get('_summarization_button_clicked', False): |
| if 'high_level_summary' in st.session_state: |
| st.write("### Existing High-Level Summary:") |
| if st.session_state.get('enhanced_summary'): |
| st.markdown(st.session_state['enhanced_summary'], unsafe_allow_html=True) |
| with st.expander("View original summary (without references)"): |
| st.write(st.session_state['high_level_summary']) |
| else: |
| st.write(st.session_state['high_level_summary']) |
|
|
| if 'summary_df' in st.session_state and not st.session_state['summary_df'].empty: |
| st.write("### Existing Cluster Summaries:") |
| summary_df = st.session_state['summary_df'] |
| if 'Enhanced_Summary' in summary_df.columns: |
| for idx, row in summary_df.iterrows(): |
| cluster_name = row.get('Cluster_Name', 'Unnamed Cluster') |
| st.write(f"**Topic {row['Topic']} - {cluster_name}**") |
| st.markdown(row['Enhanced_Summary'], unsafe_allow_html=True) |
| st.write("---") |
| with st.expander("View original summaries in table format"): |
| display_df = summary_df[['Topic', 'Cluster_Name', 'Summary']].rename(columns={'Cluster_Name': 'Cluster Name'}) |
| st.dataframe(display_df, hide_index=True) |
| else: |
| st.dataframe(summary_df, hide_index=True) |
|
|
| |
| dl_df = summary_df[['Topic', 'Cluster_Name', 'Summary']].rename(columns={'Cluster_Name': 'Cluster Name'}) if 'Cluster_Name' in summary_df.columns else summary_df |
| csv_bytes = dl_df.to_csv(index=False).encode('utf-8') |
| b64 = base64.b64encode(csv_bytes).decode() |
| href = f'<a href="data:file/csv;base64,{b64}" download="summaries.csv">Download Summaries CSV</a>' |
| st.markdown(href, unsafe_allow_html=True) |
|
|
|
|
| |
| |
| |
| with tab_chat: |
| st.header("Chat with Your Data") |
| |
| |
| with st.expander("ℹ️ How Chat Works", expanded=False): |
| st.markdown(""" |
| ### Understanding Chat with Your Data |
| |
| The chat functionality allows you to have an interactive conversation about your data, whether it's filtered, clustered, or raw. Here's how it works: |
| |
| 1. **Data Selection**: |
| - Choose which dataset to chat about (filtered, clustered, or search results) |
| - Optionally focus on specific clusters if clustering was performed |
| - System automatically includes relevant context from your selection |
| |
| 2. **Context Window**: |
| - Shows how much of the model's context window is being used |
| - Helps you understand if you need to filter data further |
| - Displays token usage statistics |
| |
| 3. **Chat Features**: |
| - Ask questions about your data |
| - Get insights and analysis |
| - Reference specific documents or clusters |
| - Download chat context for transparency |
| |
| ### Best Practices |
| |
| 1. **Data Selection**: |
| - Start with filtered or clustered data for more focused conversations |
| - Select specific clusters if you want to dive deep into a topic |
| - Consider the context window usage when selecting data |
| |
| 2. **Asking Questions**: |
| - Be specific in your questions |
| - Ask about patterns, trends, or insights |
| - Reference clusters or documents by their IDs |
| - Build on previous questions for deeper analysis |
| |
| 3. **Managing Context**: |
| - Monitor the context window usage |
| - Filter data further if context is too full |
| - Download chat context for documentation |
| - Clear chat history to start fresh |
| |
| ### Tips for Better Results |
| |
| - **Question Types**: |
| - "What are the main themes in cluster 3?" |
| - "Compare the findings between clusters 1 and 2" |
| - "Summarize the methodology used across these documents" |
| - "What are the common outcomes reported?" |
| |
| - **Follow-up Questions**: |
| - Build on previous answers |
| - Ask for clarification |
| - Request specific examples |
| - Explore relationships between findings |
| """) |
|
|
| |
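| # Offer chat only over datasets that actually exist in session state. |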
| def get_available_data_sources(): |
| sources = [] |
| if 'filtered_df' in st.session_state and not st.session_state['filtered_df'].empty: |
| sources.append("Filtered Dataset") |
| if 'clustered_data' in st.session_state and not st.session_state['clustered_data'].empty: |
| sources.append("Clustered Data") |
| if 'search_results' in st.session_state and not st.session_state['search_results'].empty: |
| sources.append("Search Results") |
| if ('high_level_summary' in st.session_state or |
| ('summary_df' in st.session_state and not st.session_state['summary_df'].empty)): |
| sources.append("Summarized Data") |
| return sources |
|
|
| |
| available_sources = get_available_data_sources() |
| |
| if not available_sources: |
| st.warning("No data available for chat. Please filter, cluster, search, or summarize first.") |
| st.stop() |
|
|
| |
| if 'chat_data_source' not in st.session_state: |
| st.session_state.chat_data_source = available_sources[0] |
| elif st.session_state.chat_data_source not in available_sources: |
| st.session_state.chat_data_source = available_sources[0] |
|
|
| |
| data_source = st.radio( |
| "Select data to chat about:", |
| available_sources, |
| index=available_sources.index(st.session_state.chat_data_source), |
| help="Choose which dataset you want to analyze in the chat." |
| ) |
|
|
| |
| if data_source != st.session_state.chat_data_source: |
| st.session_state.chat_data_source = data_source |
| |
| if 'chat_selected_cluster' in st.session_state: |
| del st.session_state.chat_selected_cluster |
|
|
| |
| df_chat = None |
| if data_source == "Filtered Dataset": |
| df_chat = st.session_state['filtered_df'] |
| elif data_source == "Clustered Data": |
| df_chat = st.session_state['clustered_data'] |
| elif data_source == "Search Results": |
| df_chat = st.session_state['search_results'] |
| elif data_source == "Summarized Data": |
| |
| summary_rows = [] |
| |
| |
| if 'high_level_summary' in st.session_state: |
| summary_rows.append({ |
| 'Summary_Type': 'High-Level Summary', |
| 'Content': st.session_state.get('enhanced_summary', st.session_state['high_level_summary']) |
| }) |
| |
| |
| if 'summary_df' in st.session_state and not st.session_state['summary_df'].empty: |
| summary_df = st.session_state['summary_df'] |
| for _, row in summary_df.iterrows(): |
| summary_rows.append({ |
| 'Summary_Type': f"Cluster {row['Topic']} Summary", |
| 'Content': row.get('Enhanced_Summary', row['Summary']) |
| }) |
| |
| if summary_rows: |
| df_chat = pd.DataFrame(summary_rows) |
|
|
| if df_chat is not None and not df_chat.empty: |
| |
| selected_cluster = None |
| if data_source != "Summarized Data" and 'Topic' in df_chat.columns: |
| cluster_option = st.radio( |
| "Choose cluster scope:", |
| ["All Clusters", "Specific Cluster"] |
| ) |
| if cluster_option == "Specific Cluster": |
| unique_topics = sorted(df_chat['Topic'].unique()) |
| |
| if 'summary_df' in st.session_state and 'Cluster_Name' in st.session_state['summary_df'].columns: |
| summary_df = st.session_state['summary_df'] |
| |
| topic_names = dict(zip(summary_df['Topic'], summary_df['Cluster_Name'])) |
| |
| topic_options = [ |
| (t, f"Cluster {t} - {topic_names.get(t, 'Unnamed Cluster')}") |
| for t in unique_topics |
| ] |
| selected_cluster = st.selectbox( |
| "Select cluster to focus on:", |
| [t[0] for t in topic_options], |
| format_func=lambda x: next(opt[1] for opt in topic_options if opt[0] == x) |
| ) |
| else: |
| selected_cluster = st.selectbox( |
| "Select cluster to focus on:", |
| unique_topics, |
| format_func=lambda x: f"Cluster {x}" |
| ) |
| if selected_cluster is not None: |
| df_chat = df_chat[df_chat['Topic'] == selected_cluster] |
| st.session_state.chat_selected_cluster = selected_cluster |
| elif 'chat_selected_cluster' in st.session_state: |
| del st.session_state.chat_selected_cluster |
|
|
| |
| text_columns = st.session_state.get('text_columns', []) |
| if not text_columns and data_source != "Summarized Data": |
| st.warning("No text columns selected. Please select text columns to enable chat functionality.") |
| st.stop() |
|
|
| |
| MAX_ALLOWED_TOKENS = int(MAX_CONTEXT_WINDOW * 0.95) |
| |
| |
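| # Role prompt for the chat model; the data-format notes below adapt to whether raw rows or AI-generated summaries are being discussed. |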
| system_msg = { |
| "role": "system", |
| "content": """You are a specialized assistant analyzing data from a research database. |
| Your role is to: |
| 1. Provide clear, concise answers based on the data provided |
| 2. Highlight relevant information from specific results when answering |
| 3. When referencing specific results, use their row index or ID if available |
| 4. Clearly state if information is not available in the results |
| 5. Maintain a professional and analytical tone |
| 6. Format your responses using Markdown: |
| - Use **bold** for emphasis |
| - Use bullet points and numbered lists for structured information |
| - Create tables using Markdown syntax when presenting structured data |
| - Use backticks for code or technical terms |
| - Include hyperlinks when referencing external sources |
| - Use headings (###) to organize long responses |
| |
| The data is provided in a structured format where:""" + (""" |
| - Each result contains multiple fields |
| - Text content is primarily in the following columns: """ + ", ".join(text_columns) + """ |
| - Additional metadata and fields are available for reference |
| - If clusters are present, they are numbered (e.g., Cluster 0, Cluster 1, etc.)""" if data_source != "Summarized Data" else """ |
| - The data consists of AI-generated summaries of the documents |
| - Each summary may contain references to source documents in markdown format |
| - References are shown as [ID] or as clickable hyperlinks |
| - Summaries may be high-level (covering all documents) or cluster-specific""") + """ |
| """ |
| } |
|
|
| |
| system_tokens = len(tokenizer(system_msg["content"])["input_ids"]) |
| remaining_tokens = MAX_ALLOWED_TOKENS - system_tokens |
|
|
| |
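| # Greedily pack rows into the remaining token budget, stopping at the first row that would overflow. |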
| data_text = "Available Data:\n" |
| included_rows = 0 |
| total_rows = len(df_chat) |
|
|
| if data_source == "Summarized Data": |
| |
| for idx, row in df_chat.iterrows(): |
| row_text = f"\n{row['Summary_Type']}:\n{row['Content']}\n" |
| row_tokens = len(tokenizer(row_text)["input_ids"]) |
| |
| if remaining_tokens - row_tokens > 0: |
| data_text += row_text |
| remaining_tokens -= row_tokens |
| included_rows += 1 |
| else: |
| break |
| else: |
| |
| for idx, row in df_chat.iterrows(): |
| row_text = f"\nItem {idx}:\n" |
| for col in df_chat.columns: |
| if not pd.isna(row[col]) and str(row[col]).strip() and col != 'similarity_score': |
| row_text += f"{col}: {row[col]}\n" |
| |
| row_tokens = len(tokenizer(row_text)["input_ids"]) |
| if remaining_tokens - row_tokens > 0: |
| data_text += row_text |
| remaining_tokens -= row_tokens |
| included_rows += 1 |
| else: |
| break |
|
|
| |
| data_tokens = len(tokenizer(data_text)["input_ids"]) |
| total_tokens = system_tokens + data_tokens |
| context_usage_percent = (total_tokens / MAX_CONTEXT_WINDOW) * 100 |
|
|
| |
| st.subheader("Context Window Usage") |
| st.write(f"System Message: {system_tokens:,} tokens") |
| st.write(f"Data Context: {data_tokens:,} tokens") |
| st.write(f"Total: {total_tokens:,} tokens ({context_usage_percent:.1f}% of available context)") |
| st.write(f"Documents included: {included_rows:,} out of {total_rows:,} ({(included_rows/total_rows*100):.1f}%)") |
| |
| if context_usage_percent > 90: |
| st.warning("⚠️ High context usage! Consider reducing the number of results or filtering further.") |
| elif context_usage_percent > 75: |
| st.info("ℹ️ Moderate context usage. Still room for your question, but consider reducing results if asking a long question.") |
|
|
| |
| chat_context = f"""System Message: |
| {system_msg['content']} |
| |
| {data_text}""" |
| st.download_button( |
| label="📥 Download Chat Context", |
| data=chat_context, |
| file_name="chat_context.txt", |
| mime="text/plain", |
| help="Download the exact context that the chatbot receives" |
| ) |
|
|
| |
| col_chat1, col_chat2 = st.columns([3, 1]) |
| with col_chat1: |
| user_input = st.text_area("Ask a question about your data:", key="chat_input") |
| with col_chat2: |
| if st.button("Clear Chat History"): |
| st.session_state.chat_history = [] |
| st.rerun() |
|
|
| |
| current_tab = tabs_titles.index("Chat") |
| |
| if st.button("Send", key="send_button"): |
| if user_input: |
| |
| st.session_state.active_tab_index = current_tab |
| |
| with st.spinner("Processing your question..."): |
| |
| st.session_state.chat_history.append({"role": "user", "content": user_input}) |
| |
| |
| messages = [system_msg] |
| messages.append({"role": "user", "content": f"Here is the data to reference:\n\n{data_text}\n\nUser question: {user_input}"}) |
| |
| |
| response = get_chat_response(messages) |
| |
| if response: |
| st.session_state.chat_history.append({"role": "assistant", "content": response}) |
|
|
| |
| st.subheader("Chat History") |
| for message in st.session_state.chat_history: |
| if message["role"] == "user": |
| st.write("**You:**", message["content"]) |
| else: |
| st.write("**Assistant:**") |
| st.markdown(message["content"], unsafe_allow_html=True) |
| st.write("---") |
|
|
|
|
| |
| |
| |
|
|
| else: |
| st.header("Automatic Mode") |
| |
| |
| if 'df' not in st.session_state: |
| default_dataset_path = os.path.join(BASE_DIR, 'input', 'export_data_table_results_20251203_101413CET.xlsx') |
| df = load_default_dataset(default_dataset_path) |
| if df is not None: |
| st.session_state['df'] = df.copy() |
| st.session_state['using_default_dataset'] = True |
| st.session_state['filtered_df'] = df.copy() |
| |
| |
| if 'text_columns' not in st.session_state or not st.session_state['text_columns']: |
| default_text_cols = [] |
| if 'Title' in df.columns and 'Description' in df.columns: |
| default_text_cols = ['Title', 'Description'] |
| st.session_state['text_columns'] = default_text_cols |
|
|
| |
| |
| query = st.text_input("Write your query here:") |
|
|
| |
|
|
| |
| if st.button("SNAP!"): |
| if query.strip(): |
| |
| st.write("### Step 1: Semantic Search") |
| with st.spinner("Performing Semantic Search..."): |
| text_columns = st.session_state.get('text_columns', []) |
| if text_columns: |
| df_full = st.session_state['df'] |
| embeddings, _ = load_or_compute_embeddings( |
| df_full, |
| st.session_state.get('using_default_dataset', False), |
| st.session_state.get('uploaded_file_name'), |
| text_columns |
| ) |
| |
| if embeddings is not None: |
| model = get_embedding_model() |
| df_filtered = st.session_state['filtered_df'].fillna("") |
| search_texts = df_filtered[text_columns].agg(' '.join, axis=1).tolist() |
| |
| subset_indices = df_filtered.index |
| subset_embeddings = embeddings[subset_indices] |
| |
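| # Embed the query and rank the filtered rows by cosine similarity; 0.35 is the fixed relevance cutoff applied below. |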
| query_embedding = model.encode([query], device=device) |
| similarities = cosine_similarity(query_embedding, subset_embeddings)[0] |
| |
| similarity_threshold = 0.35 |
| above_threshold_indices = np.where(similarities > similarity_threshold)[0] |
| |
| if len(above_threshold_indices) > 0: |
| selected_indices = subset_indices[above_threshold_indices] |
| results = df_filtered.loc[selected_indices].copy() |
| results['similarity_score'] = similarities[above_threshold_indices] |
| results.sort_values(by='similarity_score', ascending=False, inplace=True) |
| st.session_state['search_results'] = results.copy() |
| st.write(f"Found {len(results)} relevant documents") |
| else: |
| st.warning("No results found above the similarity threshold.") |
| st.stop() |
| |
| |
| if 'search_results' in st.session_state and not st.session_state['search_results'].empty: |
| st.write("### Step 2: Clustering") |
| with st.spinner("Performing clustering..."): |
| df_to_cluster = st.session_state['search_results'].copy() |
| dfc = df_to_cluster.copy().fillna("") |
| dfc['text'] = dfc[text_columns].astype(str).agg(' '.join, axis=1) |
| |
| |
| selected_indices = dfc.index |
| embeddings_clustering = embeddings[selected_indices] |
| |
| |
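| # Strip English stopwords so topic keywords stay informative; fall back to the raw text if tokenization fails. |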
| stop_words = set(stopwords.words('english')) |
| texts_cleaned = [] |
| for text in dfc['text'].tolist(): |
| try: |
| word_tokens = word_tokenize(text) |
| filtered_text = ' '.join([w for w in word_tokens if w.lower() not in stop_words]) |
| texts_cleaned.append(filtered_text) |
| except Exception: |
| texts_cleaned.append(text) |
| |
| min_cluster_size = 5 |
| |
| try: |
| |
| if torch.is_tensor(embeddings_clustering): |
| embeddings_for_clustering = embeddings_clustering.cpu().numpy() |
| else: |
| embeddings_for_clustering = embeddings_clustering |
| |
| |
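| # Density-based clustering: 'eom' selection favors stable clusters, and documents that fit no cluster are labelled -1. |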
| hdbscan_model = HDBSCAN( |
| min_cluster_size=min_cluster_size, |
| metric='euclidean', |
| cluster_selection_method='eom' |
| ) |
| |
| |
| topic_model = BERTopic( |
| embedding_model=get_embedding_model(), |
| hdbscan_model=hdbscan_model |
| ) |
| |
| |
| topics, probs = topic_model.fit_transform( |
| texts_cleaned, |
| embeddings=embeddings_for_clustering |
| ) |
| |
| |
| dfc['Topic'] = topics |
| st.session_state['topic_model'] = topic_model |
| st.session_state['clustered_data'] = dfc.copy() |
| st.session_state['clustering_completed'] = True |
|
|
| |
| unique_topics = sorted(set(topics)) |
| num_clusters = len([t for t in unique_topics if t != -1]) |
| noise_docs = len([t for t in topics if t == -1]) |
| clustered_docs = len(topics) - noise_docs |
| |
| st.write(f"Found {num_clusters} distinct clusters") |
| |
| |
| |
| |
| |
| cluster_info = [] |
| for t in unique_topics: |
| if t != -1: |
| cluster_docs = dfc[dfc['Topic'] == t] |
| count = len(cluster_docs) |
| top_words = topic_model.get_topic(t) |
| top_keywords = ", ".join([w[0] for w in top_words[:5]]) if top_words else "N/A" |
| cluster_info.append((t, count, top_keywords)) |
| |
| if cluster_info: |
| |
| cluster_df = pd.DataFrame(cluster_info, columns=["Topic", "Count", "Top Keywords"]) |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
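| # Cache the BERTopic visualizations in session state; each can fail independently (e.g., with too few topics), so guard them separately. |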
| try: |
| st.session_state['intertopic_distance_fig'] = topic_model.visualize_topics() |
| except Exception: |
| st.session_state['intertopic_distance_fig'] = None |
| |
| try: |
| st.session_state['topic_document_fig'] = topic_model.visualize_documents( |
| texts_cleaned, |
| embeddings=embeddings_for_clustering |
| ) |
| except Exception: |
| st.session_state['topic_document_fig'] = None |
| |
| try: |
| hierarchy = topic_model.hierarchical_topics(texts_cleaned) |
| st.session_state['hierarchy'] = hierarchy if hierarchy is not None else pd.DataFrame() |
| st.session_state['hierarchy_fig'] = topic_model.visualize_hierarchy() |
| except Exception: |
| st.session_state['hierarchy'] = pd.DataFrame() |
| st.session_state['hierarchy_fig'] = None |
| |
| except Exception as e: |
| st.error(f"An error occurred during clustering: {str(e)}") |
| st.stop() |
| |
| |
| if st.session_state.get('clustering_completed', False): |
| st.write("### Step 3: Summarization") |
| |
| |
| openai_api_key = os.environ.get('OPENAI_API_KEY') |
| if not openai_api_key: |
| st.error("OpenAI API key not found. Please set the OPENAI_API_KEY environment variable.") |
| st.stop() |
| |
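| # Automatic mode uses fixed generation settings: moderate creativity (0.7) and up to 1000 tokens per summary. |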
| llm = ChatOpenAI( |
| api_key=openai_api_key, |
| model_name='gpt-4o-mini', |
| temperature=0.7, |
| max_tokens=1000 |
| ) |
| |
| df_scope = st.session_state['clustered_data'] |
| unique_selected_topics = df_scope['Topic'].unique() |
| |
| |
| with st.spinner("Generating summaries..."): |
| local_system_message = SystemMessagePromptTemplate.from_template("""You are an expert summarizer skilled in creating concise and relevant summaries. |
| You will be given text and an objective context. Please produce a clear, cohesive, |
| and thematically relevant summary. |
| Focus on key points, insights, or patterns that emerge from the text.""") |
| local_human_message = HumanMessagePromptTemplate.from_template("{user_prompt}") |
| local_chat_prompt = ChatPromptTemplate.from_messages([local_system_message, local_human_message]) |
| |
| |
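| # Heuristic: treat the first column whose name mentions url/link/pdf as the hyperlink source for references. |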
| url_column = next((col for col in df_scope.columns if 'url' in col.lower() or 'link' in col.lower() or 'pdf' in col.lower()), None) |
| |
| summaries = process_summaries_in_parallel( |
| df_scope=df_scope, |
| unique_selected_topics=unique_selected_topics, |
| llm=llm, |
| chat_prompt=local_chat_prompt, |
| enable_references=True, |
| reference_id_column=df_scope.columns[0], |
| url_column=url_column, |
| max_workers=min(16, len(unique_selected_topics)) |
| ) |
| |
| if summaries: |
| summary_df = pd.DataFrame(summaries) |
| st.session_state['summary_df'] = summary_df |
| |
| |
| if 'Cluster_Name' in summary_df.columns: |
| st.write("### Updated Topic Overview:") |
| cluster_info = [] |
| for t in unique_selected_topics: |
| cluster_docs = df_scope[df_scope['Topic'] == t] |
| count = len(cluster_docs) |
| top_words = topic_model.get_topic(t) |
| top_keywords = ", ".join([w[0] for w in top_words[:5]]) if top_words else "N/A" |
| cluster_name = summary_df[summary_df['Topic'] == t]['Cluster_Name'].iloc[0] |
| cluster_info.append((t, cluster_name, count, top_keywords)) |
| |
| cluster_df = pd.DataFrame(cluster_info, columns=["Topic", "Cluster_Name", "Count", "Top Keywords"]) |
| st.dataframe( |
| cluster_df, |
| column_config={ |
| "Topic": st.column_config.NumberColumn("Topic", help="Topic ID (-1 represents outliers)"), |
| "Cluster_Name": st.column_config.TextColumn("Cluster Name", help="AI-generated name describing the cluster theme"), |
| "Count": st.column_config.NumberColumn("Count", help="Number of documents in this topic"), |
| "Top Keywords": st.column_config.TextColumn( |
| "Top Keywords", |
| help="Top 5 keywords that characterize this topic" |
| ) |
| }, |
| hide_index=True |
| ) |
| |
| |
| with st.spinner("Generating high-level summary..."): |
| summary_batches = [] |
| current_batch = [] |
| current_batch_tokens = 0 |
| MAX_SAFE_TOKENS = int(MAX_CONTEXT_WINDOW * 0.75) |
| |
| for _, row in summary_df.iterrows(): |
| summary_text = row.get('Enhanced_Summary', row['Summary']) |
| formatted_summary = f"### Cluster {row['Topic']} Summary:\n\n{summary_text}" |
| summary_tokens = len(tokenizer(formatted_summary)["input_ids"]) |
| |
| if current_batch_tokens + summary_tokens > MAX_SAFE_TOKENS: |
| if current_batch: |
| summary_batches.append(current_batch) |
| current_batch = [] |
| current_batch_tokens = 0 |
| |
| current_batch.append(formatted_summary) |
| current_batch_tokens += summary_tokens |
| |
| if current_batch: |
| summary_batches.append(current_batch) |
| |
| |
| batch_overviews = [] |
| for i, batch in enumerate(summary_batches, 1): |
| st.write(f"Processing summary batch {i} of {len(summary_batches)}...") |
| batch_text = "\n\n".join(batch) |
| batch_prompt = f"""Below are summaries from a subset of clusters from results made using Transformers NLP on a set of results from the CGIAR reporting system. Each summary contains references to source documents in the form of hyperlinked IDs like [ID] or <a href="...">ID</a>. |
| |
| Please create a comprehensive overview that synthesizes these clusters so that both the main themes and findings are covered in an organized way. IMPORTANT: |
| 1. Preserve all hyperlinked references exactly as they appear in the input summaries |
| 2. Maintain the HTML anchor tags (<a href="...">) intact when using information from the summaries |
| 3. Keep the markdown formatting for better readability |
| 4. Create clear sections with headings for different themes |
| 5. Use bullet points or numbered lists where appropriate |
| 6. Focus on synthesizing the main themes and findings |
| |
| Here are the cluster summaries to synthesize: |
| |
| {batch_text}""" |
| |
| high_level_chain = LLMChain(llm=llm, prompt=local_chat_prompt) |
| batch_overview = high_level_chain.run(user_prompt=batch_prompt).strip() |
| batch_overviews.append(batch_overview) |
| |
| |
| if len(batch_overviews) > 1: |
| st.write("Generating final synthesis...") |
| combined_overviews = "\n\n# Part ".join([f"{i+1}\n\n{overview}" for i, overview in enumerate(batch_overviews)]) |
| final_prompt = f"""Below are multiple overview summaries, each covering different aspects of CGIAR research results. Each part maintains its original references to source documents. |
| |
| Please create a final comprehensive synthesis that: |
| 1. Integrates the key themes and findings from all parts into a cohesive narrative |
| 2. Preserves all hyperlinked references exactly as they appear |
| 3. Maintains the HTML anchor tags (<a href="...">) intact |
| 4. Uses clear section headings and structured formatting |
| 5. Highlights cross-cutting themes and relationships between different aspects |
| 6. Provides a clear introduction and conclusion |
| |
| Here are the overviews to synthesize: |
| |
| {combined_overviews}""" |
| |
| final_prompt_tokens = len(tokenizer(final_prompt)["input_ids"]) |
| if final_prompt_tokens > MAX_SAFE_TOKENS: |
| |
| high_level_summary = "# Comprehensive Overview\n\n" + "\n\n# Part ".join([f"{i+1}\n\n{overview}" for i, overview in enumerate(batch_overviews)]) |
| else: |
| high_level_chain = LLMChain(llm=llm, prompt=local_chat_prompt) |
| high_level_summary = high_level_chain.run(user_prompt=final_prompt).strip() |
| else: |
| |
| high_level_summary = batch_overviews[0] |
| |
| st.session_state['high_level_summary'] = high_level_summary |
| st.session_state['enhanced_summary'] = high_level_summary |
| |
| |
| st.write("### High-Level Summary:") |
| with st.expander("High-Level Summary", expanded=True): |
| st.markdown(high_level_summary, unsafe_allow_html=True) |
|
|
| st.write("### Cluster Summaries:") |
| for idx, row in summary_df.iterrows(): |
| cluster_name = row.get('Cluster_Name', 'Unnamed Cluster') |
| with st.expander(f"Topic {row['Topic']} - {cluster_name}", expanded=False): |
| st.markdown(row.get('Enhanced_Summary', row['Summary']), unsafe_allow_html=True) |
| st.markdown("##### About this tool") |
| with st.expander("Click to expand/collapse", expanded=True): |
| st.markdown(""" |
| This tool draws on CGIAR quality-assured results data from 2022-2024 to provide verifiable responses to user questions about the themes and areas CGIAR has worked on and is working on. |
| |
| **Tips:** |
| - **Craft a phrase** that describes your topic of interest (e.g., `"climate-smart agriculture"`, `"gender equality livestock"`). |
| - Avoid writing full questions — **this is not a chatbot**. |
| - Combine **related terms** for better results (e.g., `"irrigation water access smallholders"`). |
| - Focus on **concepts or themes** — not single words like `"climate"` or `"yield"` alone. |
| - Example good queries: |
| - `"climate adaptation smallholder farming"` |
| - `"digital agriculture innovations"` |
| - `"nutrition-sensitive value chains"` |
| |
| **Example use case**: |
| You're interested in CGIAR's contributions to **poverty reduction through improved maize varieties in Africa**. |
| A good search phrase would be: |
| 👉 `"poverty reduction maize Africa"` |
| This will retrieve results related to improved crop varieties, livelihood outcomes, and region-specific interventions, even if the documents use different wording like *"enhanced maize genetics"*, *"smallholder income"*, or *"eastern Africa trials"*. |
| """) |