diff --git "a/tools/dedup_summaries.py" "b/tools/dedup_summaries.py"
new file mode 100644
--- /dev/null
+++ "b/tools/dedup_summaries.py"
@@ -0,0 +1,2905 @@
+import os
+import re
+import time
+from io import StringIO
+from typing import List, Tuple
+
+import boto3
+import gradio as gr
+import markdown
+import pandas as pd
+import spaces
+from rapidfuzz import fuzz, process
+from tqdm import tqdm
+
+from tools.aws_functions import connect_to_bedrock_runtime
+from tools.config import (
+ BATCH_SIZE_DEFAULT,
+ CHOSEN_LOCAL_MODEL_TYPE,
+ DEDUPLICATION_THRESHOLD,
+ DEFAULT_SAMPLED_SUMMARIES,
+ LLM_CONTEXT_LENGTH,
+ LLM_MAX_NEW_TOKENS,
+ LLM_SEED,
+ MAX_COMMENT_CHARS,
+ MAX_GROUPS,
+ MAX_SPACES_GPU_RUN_TIME,
+ MAX_TIME_FOR_LOOP,
+ NUMBER_OF_RETRY_ATTEMPTS,
+ OUTPUT_DEBUG_FILES,
+ OUTPUT_FOLDER,
+ REASONING_SUFFIX,
+ RUN_LOCAL_MODEL,
+ TIMEOUT_WAIT,
+ model_name_map,
+)
+from tools.helper_functions import (
+ clean_column_name,
+ convert_reference_table_to_pivot_table,
+ create_batch_file_path_details,
+ create_topic_summary_df_from_reference_table,
+ ensure_model_in_map,
+ generate_zero_shot_topics_df,
+ get_basic_response_data,
+ get_file_name_no_ext,
+ load_in_data_file,
+ read_file,
+ wrap_text,
+)
+from tools.llm_funcs import (
+ calculate_tokens_from_metadata,
+ call_llm_with_markdown_table_checks,
+ construct_azure_client,
+ construct_gemini_generative_model,
+ get_assistant_model,
+ get_model,
+ get_tokenizer,
+ process_requests,
+)
+from tools.prompts import (
+ comprehensive_summary_format_prompt,
+ comprehensive_summary_format_prompt_by_group,
+ llm_deduplication_prompt,
+ llm_deduplication_prompt_with_candidates,
+ llm_deduplication_system_prompt,
+ summarise_everything_prompt,
+ summarise_everything_system_prompt,
+ summarise_topic_descriptions_prompt,
+ summarise_topic_descriptions_system_prompt,
+ summary_assistant_prefill,
+ system_prompt,
+)
+
+max_tokens = LLM_MAX_NEW_TOKENS
+timeout_wait = TIMEOUT_WAIT
+number_of_api_retry_attempts = NUMBER_OF_RETRY_ATTEMPTS
+max_time_for_loop = MAX_TIME_FOR_LOOP
+batch_size_default = BATCH_SIZE_DEFAULT
+deduplication_threshold = DEDUPLICATION_THRESHOLD
+max_comment_character_length = MAX_COMMENT_CHARS
+reasoning_suffix = REASONING_SUFFIX
+output_debug_files = OUTPUT_DEBUG_FILES
+default_number_of_sampled_summaries = DEFAULT_SAMPLED_SUMMARIES
+max_text_length = 500
+
+
+# DEDUPLICATION/SUMMARISATION FUNCTIONS
+def deduplicate_categories(
+ category_series: pd.Series,
+ join_series: pd.Series,
+ reference_df: pd.DataFrame,
+ general_topic_series: pd.Series = None,
+ merge_general_topics="No",
+ merge_sentiment: str = "No",
+ threshold: float = 90,
+) -> pd.DataFrame:
+ """
+ Deduplicates similar category names in a pandas Series based on a fuzzy matching threshold,
+ merging smaller topics into larger topics.
+
+ Parameters:
+ category_series (pd.Series): Series containing category names to deduplicate.
+ join_series (pd.Series): Additional series used for joining back to original results.
+ reference_df (pd.DataFrame): DataFrame containing the reference data used to count occurrences of each category.
+ general_topic_series (pd.Series, optional): Series of general topics aligned with category_series, used to restrict matching to within the same general topic. Defaults to None.
+ merge_general_topics (str, optional): Whether to allow matches across general topics ("Yes" or "No"). Defaults to "No".
+ merge_sentiment (str, optional): Whether to merge topics regardless of sentiment ("Yes" or "No"); currently unused here and accepted for interface consistency. Defaults to "No".
+ threshold (float): Similarity score above which two strings are considered duplicates.
+
+ Returns:
+ pd.DataFrame: DataFrame with columns ['old_category', 'deduplicated_category', 'match_score'].
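+
+ Example (illustrative sketch with toy data):
+ >>> cats = pd.Series(["Housing costs", "Housing cost"])
+ >>> joins = pd.Series(["Positive", "Positive"])
+ >>> ref = pd.DataFrame({"Subtopic": ["Housing costs", "Housing costs"]})
+ >>> out = deduplicate_categories(
+ ... cats, joins, ref,
+ ... general_topic_series=pd.Series(["Economy", "Economy"]),
+ ... threshold=90,
+ ... )
+ >>> out["deduplicated_category"].tolist()
+ ['Housing costs', 'Housing costs']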
+ """
+ # Count occurrences of each category in the reference_df
+ category_counts = reference_df["Subtopic"].value_counts().to_dict()
+
+ # Initialize dictionaries for both category mapping and scores
+ deduplication_map = {}
+ match_scores = {} # New dictionary to store match scores
+
+ # First pass: Handle exact matches
+ for category in category_series.unique():
+ if category in deduplication_map:
+ continue
+
+ # Find all exact matches
+ exact_matches = category_series[
+ category_series.str.lower() == category.lower()
+ ].index.tolist()
+ if len(exact_matches) > 1:
+ # Find the variant with the highest count
+ match_counts = {
+ match: category_counts.get(category_series[match], 0)
+ for match in exact_matches
+ }
+ most_common = max(match_counts.items(), key=lambda x: x[1])[0]
+ most_common_category = category_series[most_common]
+
+ # Map all exact matches to the most common variant and store score
+ for match in exact_matches:
+ deduplication_map[category_series[match]] = most_common_category
+ match_scores[category_series[match]] = 100 # Exact matches get a score of 100
+
+ # Second pass: Handle fuzzy matches for remaining categories
+ # Create a DataFrame to maintain the relationship between categories and general topics
+ categories_df = pd.DataFrame(
+ {"category": category_series, "general_topic": general_topic_series}
+ ).drop_duplicates()
+
+ for _, row in categories_df.iterrows():
+ category = row["category"]
+ if category in deduplication_map:
+ continue
+
+ current_general_topic = row["general_topic"]
+
+ # Filter potential matches to only those within the same General topic if relevant
+ if merge_general_topics == "No":
+ potential_matches = categories_df[
+ (categories_df["category"] != category)
+ & (categories_df["general_topic"] == current_general_topic)
+ ]["category"].tolist()
+ else:
+ potential_matches = categories_df[(categories_df["category"] != category)][
+ "category"
+ ].tolist()
+
+ matches = process.extract(
+ category, potential_matches, scorer=fuzz.WRatio, score_cutoff=threshold
+ )
+
+ if matches:
+ best_match = max(matches, key=lambda x: x[1])
+ match, score, _ = best_match
+
+ if category_counts.get(category, 0) < category_counts.get(match, 0):
+ deduplication_map[category] = match
+ match_scores[category] = score
+ else:
+ deduplication_map[match] = category
+ match_scores[match] = score
+ else:
+ deduplication_map[category] = category
+ match_scores[category] = 100
+
+ # Create the result DataFrame with scores
+ result_df = pd.DataFrame(
+ {
+ "old_category": category_series + " | " + join_series,
+ "deduplicated_category": category_series.map(
+ lambda x: deduplication_map.get(x, x)
+ ),
+ "match_score": category_series.map(
+ lambda x: match_scores.get(x, 100)
+ ), # Add scores column
+ }
+ )
+
+ # print(result_df)
+
+ return result_df
+
+
+def deduplicate_topics(
+ reference_df: pd.DataFrame,
+ topic_summary_df: pd.DataFrame,
+ reference_table_file_name: str,
+ unique_topics_table_file_name: str,
+ in_excel_sheets: str = "",
+ merge_sentiment: str = "No",
+ merge_general_topics: str = "No",
+ score_threshold: int = 90,
+ in_data_files: List[str] = list(),
+ chosen_cols: List[str] = "",
+ output_folder: str = OUTPUT_FOLDER,
+ deduplicate_topics: str = "Yes",
+):
+ """
+ Deduplicate topics based on a reference and unique topics table, merging similar topics.
+
+ Args:
+ reference_df (pd.DataFrame): DataFrame containing reference data with topics.
+ topic_summary_df (pd.DataFrame): DataFrame summarizing unique topics.
+ reference_table_file_name (str): Base file name for the output reference table.
+ unique_topics_table_file_name (str): Base file name for the output unique topics table.
+ in_excel_sheets (str, optional): Comma-separated list of Excel sheet names to load. Defaults to "".
+ merge_sentiment (str, optional): Whether to merge topics regardless of sentiment ("Yes" or "No"). Defaults to "No".
+ merge_general_topics (str, optional): Whether to merge topics across different general topics ("Yes" or "No"). Defaults to "No".
+ score_threshold (int, optional): Fuzzy matching score threshold for deduplication. Defaults to 90.
+ in_data_files (List[str], optional): List of input data file paths. Defaults to [].
+ chosen_cols (List[str], optional): List of chosen columns from the input data files. Defaults to "".
+ output_folder (str, optional): Folder path to save output files. Defaults to OUTPUT_FOLDER.
+ deduplicate_topics (str, optional): Whether to perform topic deduplication ("Yes" or "No"). Defaults to "Yes".
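+
+ Returns:
+ tuple: (reference_df, topic_summary_df, output_files, log_output_files,
+ deduplicated_unique_table_markdown)
+
+ Example (illustrative sketch; assumes the input frames already exist with the expected columns):
+ >>> ref_out, topics_out, files, logs, md = deduplicate_topics(
+ ... reference_df,
+ ... topic_summary_df,
+ ... "reference_table.csv",
+ ... "unique_topics.csv",
+ ... score_threshold=90,
+ ... ) # doctest: +SKIP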
+ """
+ output_files = list()
+ log_output_files = list()
+ file_data = pd.DataFrame()
+ deduplicated_unique_table_markdown = ""
+
+ if (len(reference_df["Response References"].unique()) == 1) | (
+ len(topic_summary_df["Topic number"].unique()) == 1
+ ):
+ print(
+ "Data file outputs are too short for deduplicating. Returning original data."
+ )
+
+ # Get file name without extension and create proper output paths
+ reference_table_file_name_no_ext = get_file_name_no_ext(
+ reference_table_file_name
+ )
+ unique_topics_table_file_name_no_ext = get_file_name_no_ext(
+ unique_topics_table_file_name
+ )
+
+ # Create output paths with _dedup suffix to match normal path
+ reference_file_out_path = (
+ output_folder + reference_table_file_name_no_ext + "_dedup.csv"
+ )
+ unique_topics_file_out_path = (
+ output_folder + unique_topics_table_file_name_no_ext + "_dedup.csv"
+ )
+
+ # Save the DataFrames to CSV files
+ reference_df.drop(["1", "2", "3"], axis=1, errors="ignore").to_csv(
+ reference_file_out_path, index=None, encoding="utf-8-sig"
+ )
+ topic_summary_df.drop(["1", "2", "3"], axis=1, errors="ignore").to_csv(
+ unique_topics_file_out_path, index=None, encoding="utf-8-sig"
+ )
+
+ output_files.append(reference_file_out_path)
+ output_files.append(unique_topics_file_out_path)
+
+ # Create markdown output for display
+ topic_summary_df_revised_display = topic_summary_df.apply(
+ lambda col: col.map(lambda x: wrap_text(x, max_text_length=max_text_length))
+ )
+ deduplicated_unique_table_markdown = (
+ topic_summary_df_revised_display.to_markdown(index=False)
+ )
+
+ return (
+ reference_df,
+ topic_summary_df,
+ output_files,
+ log_output_files,
+ deduplicated_unique_table_markdown,
+ )
+
+ # For checking that data is not lost during the process
+ initial_unique_references = len(reference_df["Response References"].unique())
+
+ if topic_summary_df.empty:
+ topic_summary_df = create_topic_summary_df_from_reference_table(reference_df)
+
+ # Then merge the topic numbers back to the original dataframe
+ reference_df = reference_df.merge(
+ topic_summary_df[
+ ["General topic", "Subtopic", "Sentiment", "Topic number"]
+ ],
+ on=["General topic", "Subtopic", "Sentiment"],
+ how="left",
+ )
+
+ if in_data_files and chosen_cols:
+ file_data, data_file_names_textbox, total_number_of_batches = load_in_data_file(
+ in_data_files, chosen_cols, 1, in_excel_sheets
+ )
+ else:
+ out_message = "No file data found, pivot table output will not be created."
+ print(out_message)
+ # raise Exception(out_message)
+
+ # Run through this process up to 8 times to try to catch all chained duplicate topics
+ if deduplicate_topics == "Yes":
+ if "Group" not in reference_df.columns:
+ reference_df["Group"] = "All"
+ for _ in range(8):
+ if merge_sentiment == "No":
+ if merge_general_topics == "No":
+ reference_df["old_category"] = (
+ reference_df["Subtopic"] + " | " + reference_df["Sentiment"]
+ )
+ reference_df_unique = reference_df.drop_duplicates("old_category")
+
+ # Create an empty list to store results from each group
+ results = list()
+ # Iterate over each group instead of using .apply()
+ for name, group in reference_df_unique.groupby(
+ ["General topic", "Sentiment", "Group"]
+ ):
+ # Run your function on the 'group' DataFrame
+ result = deduplicate_categories(
+ group["Subtopic"],
+ group["Sentiment"],
+ reference_df,
+ general_topic_series=group["General topic"],
+ merge_general_topics="No",
+ threshold=score_threshold,
+ )
+ results.append(result)
+
+ # Concatenate all the results into a single DataFrame
+ deduplicated_topic_map_df = pd.concat(results).reset_index(
+ drop=True
+ )
+
+ else:
+ # This case should allow cross-topic matching but is still grouping by Sentiment
+ reference_df["old_category"] = (
+ reference_df["Subtopic"] + " | " + reference_df["Sentiment"]
+ )
+ reference_df_unique = reference_df.drop_duplicates("old_category")
+
+ results = list()
+ for name, group in reference_df_unique.groupby("Sentiment"):
+ result = deduplicate_categories(
+ group["Subtopic"],
+ group["Sentiment"],
+ reference_df,
+ general_topic_series=None,
+ merge_general_topics="Yes",
+ threshold=score_threshold,
+ )
+ results.append(result)
+ deduplicated_topic_map_df = pd.concat(results).reset_index(
+ drop=True
+ )
+
+ else:
+ if merge_general_topics == "No":
+ reference_df["old_category"] = (
+ reference_df["Subtopic"] + " | " + reference_df["Sentiment"]
+ )
+ reference_df_unique = reference_df.drop_duplicates("old_category")
+
+ results = list()
+ for name, group in reference_df_unique.groupby("General topic"):
+ result = deduplicate_categories(
+ group["Subtopic"],
+ group["Sentiment"],
+ reference_df,
+ general_topic_series=group["General topic"],
+ merge_general_topics="No",
+ merge_sentiment=merge_sentiment,
+ threshold=score_threshold,
+ )
+ results.append(result)
+ deduplicated_topic_map_df = pd.concat(results).reset_index(
+ drop=True
+ )
+
+ else:
+ reference_df["old_category"] = (
+ reference_df["Subtopic"] + " | " + reference_df["Sentiment"]
+ )
+ reference_df_unique = reference_df.drop_duplicates("old_category")
+
+ deduplicated_topic_map_df = deduplicate_categories(
+ reference_df_unique["Subtopic"],
+ reference_df_unique["Sentiment"],
+ reference_df,
+ general_topic_series=None,
+ merge_general_topics="Yes",
+ merge_sentiment=merge_sentiment,
+ threshold=score_threshold,
+ ).reset_index(drop=True)
+
+ if deduplicated_topic_map_df["deduplicated_category"].isnull().all():
+ print("No deduplicated categories found, skipping the following code.")
+
+ else:
+ # Remove rows where 'deduplicated_category' is blank or NaN
+ deduplicated_topic_map_df = deduplicated_topic_map_df.loc[
+ (
+ deduplicated_topic_map_df["deduplicated_category"].str.strip()
+ != ""
+ )
+ & ~(deduplicated_topic_map_df["deduplicated_category"].isnull()),
+ ["old_category", "deduplicated_category", "match_score"],
+ ]
+
+ reference_df = reference_df.merge(
+ deduplicated_topic_map_df, on="old_category", how="left"
+ )
+
+ reference_df.rename(
+ columns={"Subtopic": "Subtopic_old", "Sentiment": "Sentiment_old"},
+ inplace=True,
+ )
+ # Extract subtopic and sentiment from deduplicated_category (these
+ # yield NaN when deduplicated_category is a plain subtopic name)
+ reference_df["Subtopic"] = reference_df[
+ "deduplicated_category"
+ ].str.extract(r"^(.*?) \|")[0]
+ reference_df["Sentiment"] = reference_df[
+ "deduplicated_category"
+ ].str.extract(r"\| (.*)$")[0]
+
+ # Combine with old values so that no data is lost: take the deduplicated
+ # subtopic where present, otherwise keep the previous subtopic/sentiment
+ reference_df["Subtopic"] = reference_df[
+ "deduplicated_category"
+ ].combine_first(reference_df["Subtopic_old"])
+ reference_df["Sentiment"] = reference_df["Sentiment"].combine_first(
+ reference_df["Sentiment_old"]
+ )
+
+ reference_df = reference_df.rename(
+ columns={"General Topic": "General topic"}, errors="ignore"
+ )
+ reference_df = reference_df[
+ [
+ "Response References",
+ "General topic",
+ "Subtopic",
+ "Sentiment",
+ "Summary",
+ "Start row of group",
+ "Group",
+ ]
+ ]
+
+ if merge_general_topics == "Yes":
+ # Replace General topic names for each Subtopic with that for the Subtopic with the most responses
+ # Step 1: Count the number of occurrences for each General topic and Subtopic combination
+ count_df = (
+ reference_df.groupby(["Subtopic", "General topic"])
+ .size()
+ .reset_index(name="Count")
+ )
+
+ # Step 2: Find the General topic with the maximum count for each Subtopic
+ max_general_topic = count_df.loc[
+ count_df.groupby("Subtopic")["Count"].idxmax()
+ ]
+
+ # Step 3: Map the General topic back to the original DataFrame
+ reference_df = reference_df.merge(
+ max_general_topic[["Subtopic", "General topic"]],
+ on="Subtopic",
+ suffixes=("", "_max"),
+ how="left",
+ )
+
+ reference_df["General topic"] = reference_df[
+ "General topic_max"
+ ].combine_first(reference_df["General topic"])
+
+ if merge_sentiment == "Yes":
+ # Step 1: Count the number of occurrences for each General topic and Subtopic combination
+ count_df = (
+ reference_df.groupby(["Subtopic", "Sentiment"])
+ .size()
+ .reset_index(name="Count")
+ )
+
+ # Step 2: Determine the number of unique Sentiment values for each Subtopic
+ unique_sentiments = (
+ count_df.groupby("Subtopic")["Sentiment"]
+ .nunique()
+ .reset_index(name="UniqueCount")
+ )
+
+ # Step 3: Update Sentiment to 'Mixed' where there is more than one unique sentiment
+ reference_df = reference_df.merge(
+ unique_sentiments, on="Subtopic", how="left"
+ )
+ reference_df["Sentiment"] = reference_df.apply(
+ lambda row: "Mixed" if row["UniqueCount"] > 1 else row["Sentiment"],
+ axis=1,
+ )
+
+ # Clean up the DataFrame by dropping the UniqueCount column
+ reference_df.drop(columns=["UniqueCount"], inplace=True)
+
+ # print("reference_df:", reference_df)
+ reference_df = reference_df[
+ [
+ "Response References",
+ "General topic",
+ "Subtopic",
+ "Sentiment",
+ "Summary",
+ "Start row of group",
+ "Group",
+ ]
+ ]
+ # reference_df.drop(['old_category', 'deduplicated_category', "Subtopic_old", "Sentiment_old"], axis=1, inplace=True, errors="ignore")
+
+ # Update reference summary column with all summaries
+ reference_df["Summary"] = reference_df.groupby(
+ ["Response References", "General topic", "Subtopic", "Sentiment"]
+ )["Summary"].transform("
".join)
+
+ # Check that we have not inadvertently removed some data during the above process
+ end_unique_references = len(reference_df["Response References"].unique())
+
+ if initial_unique_references != end_unique_references:
+ raise Exception(
+ f"Number of unique references changed during processing: Initial={initial_unique_references}, Final={end_unique_references}"
+ )
+
+ # Drop duplicates in the reference table - each comment should only have the same topic referred to once
+ reference_df.drop_duplicates(
+ ["Response References", "General topic", "Subtopic", "Sentiment"],
+ inplace=True,
+ )
+
+ # Remake topic_summary_df based on new reference_df
+ topic_summary_df = create_topic_summary_df_from_reference_table(reference_df)
+
+ # Then merge the topic numbers back to the original dataframe
+ reference_df = reference_df.merge(
+ topic_summary_df[
+ ["General topic", "Subtopic", "Sentiment", "Group", "Topic number"]
+ ],
+ on=["General topic", "Subtopic", "Sentiment", "Group"],
+ how="left",
+ )
+
+ else:
+ print("Topics have not beeen deduplicated")
+
+ reference_table_file_name_no_ext = get_file_name_no_ext(reference_table_file_name)
+ unique_topics_table_file_name_no_ext = get_file_name_no_ext(
+ unique_topics_table_file_name
+ )
+
+ if not file_data.empty:
+ basic_response_data = get_basic_response_data(file_data, chosen_cols)
+ reference_df_pivot = convert_reference_table_to_pivot_table(
+ reference_df, basic_response_data
+ )
+
+ reference_pivot_file_path = (
+ output_folder + reference_table_file_name_no_ext + "_pivot_dedup.csv"
+ )
+ reference_df_pivot.drop(["1", "2", "3"], axis=1, errors="ignore").to_csv(
+ reference_pivot_file_path, index=None, encoding="utf-8-sig"
+ )
+ log_output_files.append(reference_pivot_file_path)
+
+ reference_file_out_path = (
+ output_folder + reference_table_file_name_no_ext + "_dedup.csv"
+ )
+ unique_topics_file_out_path = (
+ output_folder + unique_topics_table_file_name_no_ext + "_dedup.csv"
+ )
+ reference_df.drop(["1", "2", "3"], axis=1, errors="ignore").to_csv(
+ reference_file_out_path, index=None, encoding="utf-8-sig"
+ )
+ topic_summary_df.drop(["1", "2", "3"], axis=1, errors="ignore").to_csv(
+ unique_topics_file_out_path, index=None, encoding="utf-8-sig"
+ )
+
+ output_files.append(reference_file_out_path)
+ output_files.append(unique_topics_file_out_path)
+
+ # Outputs for markdown table output
+ topic_summary_df_revised_display = topic_summary_df.apply(
+ lambda col: col.map(lambda x: wrap_text(x, max_text_length=max_text_length))
+ )
+ deduplicated_unique_table_markdown = topic_summary_df_revised_display.to_markdown(
+ index=False
+ )
+
+ return (
+ reference_df,
+ topic_summary_df,
+ output_files,
+ log_output_files,
+ deduplicated_unique_table_markdown,
+ )
+
+
+def deduplicate_topics_llm(
+ reference_df: pd.DataFrame,
+ topic_summary_df: pd.DataFrame,
+ reference_table_file_name: str,
+ unique_topics_table_file_name: str,
+ model_choice: str,
+ in_api_key: str,
+ temperature: float,
+ model_source: str,
+ bedrock_runtime=None,
+ local_model=None,
+ tokenizer=None,
+ assistant_model=None,
+ in_excel_sheets: str = "",
+ merge_sentiment: str = "No",
+ merge_general_topics: str = "No",
+ in_data_files: List[str] = list(),
+ chosen_cols: List[str] = "",
+ output_folder: str = OUTPUT_FOLDER,
+ candidate_topics=None,
+ azure_endpoint: str = "",
+ output_debug_files: str = "False",
+ api_url: str = None,
+):
+ """
+ Deduplicate topics using LLM semantic understanding to identify and merge similar topics.
+
+ Args:
+ reference_df (pd.DataFrame): DataFrame containing reference data with topics.
+ topic_summary_df (pd.DataFrame): DataFrame summarizing unique topics.
+ reference_table_file_name (str): Base file name for the output reference table.
+ unique_topics_table_file_name (str): Base file name for the output unique topics table.
+ model_choice (str): The LLM model to use for deduplication.
+ in_api_key (str): API key for the LLM service.
+ temperature (float): Temperature setting for the LLM.
+ model_source (str): Source of the model (AWS, Gemini, Local, etc.).
+ bedrock_runtime: AWS Bedrock runtime client (if using AWS).
+ local_model: Local model instance (if using local model).
+ tokenizer: Tokenizer for local model.
+ assistant_model: Assistant model for speculative decoding.
+ in_excel_sheets (str, optional): Comma-separated list of Excel sheet names to load. Defaults to "".
+ merge_sentiment (str, optional): Whether to merge topics regardless of sentiment ("Yes" or "No"). Defaults to "No".
+ merge_general_topics (str, optional): Whether to merge topics across different general topics ("Yes" or "No"). Defaults to "No".
+ in_data_files (List[str], optional): List of input data file paths. Defaults to [].
+ chosen_cols (List[str], optional): List of chosen columns from the input data files. Defaults to "".
+ output_folder (str, optional): Folder path to save output files. Defaults to OUTPUT_FOLDER.
+ candidate_topics (optional): Candidate topics file for zero-shot guidance. Defaults to None.
+ azure_endpoint (str, optional): Azure endpoint for the LLM. Defaults to "".
+ output_debug_files (str, optional): Whether to output debug files. Defaults to "False".
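+ api_url (str, optional): URL of an inference server when model_source is "inference-server". Defaults to None.
+
+ Returns:
+ tuple: (reference_df, topic_summary_df, output_files, log_output_files,
+ deduplicated_unique_table_markdown, total_input_tokens,
+ total_output_tokens, number_of_calls, estimated_time_taken)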
+ """
+
+ output_files = list()
+ log_output_files = list()
+ file_data = pd.DataFrame()
+ deduplicated_unique_table_markdown = ""
+
+ # Check if data is too short for deduplication
+ if (len(reference_df["Response References"].unique()) == 1) | (
+ len(topic_summary_df["Topic number"].unique()) == 1
+ ):
+ print(
+ "Data file outputs are too short for deduplicating. Returning original data."
+ )
+
+ # Get file name without extension and create proper output paths
+ reference_table_file_name_no_ext = get_file_name_no_ext(
+ reference_table_file_name
+ )
+ unique_topics_table_file_name_no_ext = get_file_name_no_ext(
+ unique_topics_table_file_name
+ )
+
+ # Create output paths with _dedup suffix to match normal path
+ reference_file_out_path = (
+ output_folder + reference_table_file_name_no_ext + "_dedup.csv"
+ )
+ unique_topics_file_out_path = (
+ output_folder + unique_topics_table_file_name_no_ext + "_dedup.csv"
+ )
+
+ # Save the DataFrames to CSV files
+ reference_df.drop(["1", "2", "3"], axis=1, errors="ignore").to_csv(
+ reference_file_out_path, index=None, encoding="utf-8-sig"
+ )
+ topic_summary_df.drop(["1", "2", "3"], axis=1, errors="ignore").to_csv(
+ unique_topics_file_out_path, index=None, encoding="utf-8-sig"
+ )
+
+ output_files.append(reference_file_out_path)
+ output_files.append(unique_topics_file_out_path)
+
+ # Create markdown output for display
+ topic_summary_df_revised_display = topic_summary_df.apply(
+ lambda col: col.map(lambda x: wrap_text(x, max_text_length=max_text_length))
+ )
+ deduplicated_unique_table_markdown = (
+ topic_summary_df_revised_display.to_markdown(index=False)
+ )
+
+ # Return with token counts set to 0 for early return
+ return (
+ reference_df,
+ topic_summary_df,
+ output_files,
+ log_output_files,
+ deduplicated_unique_table_markdown,
+ 0, # input_tokens
+ 0, # output_tokens
+ 0, # number_of_calls
+ 0.0, # estimated_time_taken
+ )
+
+ # For checking that data is not lost during the process
+ initial_unique_references = len(reference_df["Response References"].unique())
+
+ # Create topic summary if it doesn't exist
+ if topic_summary_df.empty:
+ topic_summary_df = create_topic_summary_df_from_reference_table(reference_df)
+
+ # Merge topic numbers back to the original dataframe
+ reference_df = reference_df.merge(
+ topic_summary_df[
+ ["General topic", "Subtopic", "Sentiment", "Topic number"]
+ ],
+ on=["General topic", "Subtopic", "Sentiment"],
+ how="left",
+ )
+
+ # Load data files if provided
+ if in_data_files and chosen_cols:
+ file_data, data_file_names_textbox, total_number_of_batches = load_in_data_file(
+ in_data_files, chosen_cols, 1, in_excel_sheets
+ )
+ else:
+ out_message = "No file data found, pivot table output will not be created."
+ print(out_message)
+
+ # Process candidate topics if provided
+ candidate_topics_table = ""
+ if candidate_topics is not None:
+ try:
+ # Read and process candidate topics
+ # Handle both string paths (CLI) and gr.FileData objects (Gradio)
+ candidate_topics_path = (
+ candidate_topics
+ if isinstance(candidate_topics, str)
+ else getattr(candidate_topics, "name", None)
+ )
+ if candidate_topics_path is None:
+ raise ValueError(
+ "candidate_topics must be a file path string or a FileData object with a 'name' attribute"
+ )
+ candidate_topics_df = read_file(candidate_topics_path)
+ candidate_topics_df = candidate_topics_df.fillna("")
+ candidate_topics_df = candidate_topics_df.astype(str)
+
+ # Generate zero-shot topics DataFrame
+ zero_shot_topics_df = generate_zero_shot_topics_df(
+ candidate_topics_df, "No", False
+ )
+
+ if not zero_shot_topics_df.empty:
+ candidate_topics_table = zero_shot_topics_df[
+ ["General topic", "Subtopic"]
+ ].to_markdown(index=False)
+ print(
+ f"Found {len(zero_shot_topics_df)} candidate topics to consider during deduplication"
+ )
+ except Exception as e:
+ print(f"Error processing candidate topics: {e}")
+ candidate_topics_table = ""
+
+ # Prepare topics table for LLM analysis
+ topics_table = topic_summary_df[
+ ["General topic", "Subtopic", "Sentiment", "Number of responses"]
+ ].to_markdown(index=False)
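+ # This markdown table (one row per unique topic, with response counts) is
+ # embedded in the deduplication prompt so the LLM can propose merges across
+ # the whole topic set in a single call.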
+
+ # Format the prompt with candidate topics if available
+ if candidate_topics_table:
+ formatted_prompt = llm_deduplication_prompt_with_candidates.format(
+ topics_table=topics_table, candidate_topics_table=candidate_topics_table
+ )
+ else:
+ formatted_prompt = llm_deduplication_prompt.format(topics_table=topics_table)
+
+ # Initialise conversation history
+ conversation_history = list()
+ whole_conversation = list()
+ whole_conversation_metadata = list()
+
+ # Set up model clients based on model source
+ if "Gemini" in model_source:
+ client, config = construct_gemini_generative_model(
+ in_api_key,
+ temperature,
+ model_choice,
+ llm_deduplication_system_prompt,
+ max_tokens,
+ LLM_SEED,
+ )
+ bedrock_runtime = None
+ elif "AWS" in model_source:
+ if not bedrock_runtime:
+ bedrock_runtime = boto3.client("bedrock-runtime")
+ client = None
+ config = None
+ elif "Azure/OpenAI" in model_source:
+ client, config = construct_azure_client(in_api_key, azure_endpoint)
+ bedrock_runtime = None
+ elif "Local" in model_source:
+ client = None
+ config = None
+ bedrock_runtime = None
+ elif "inference-server" in model_source:
+ client = None
+ config = None
+ bedrock_runtime = None
+ # api_url is already passed to call_llm_with_markdown_table_checks
+ if api_url is None:
+ raise ValueError(
+ "api_url is required when model_source is 'inference-server'"
+ )
+ else:
+ raise ValueError(f"Unsupported model source: {model_source}")
+
+ # Call LLM to get deduplication suggestions
+ print("Calling LLM for topic deduplication analysis...")
+
+ # Use the existing call_llm_with_markdown_table_checks function
+ (
+ responses,
+ conversation_history,
+ whole_conversation,
+ whole_conversation_metadata,
+ response_text,
+ ) = call_llm_with_markdown_table_checks(
+ batch_prompts=[formatted_prompt],
+ system_prompt=llm_deduplication_system_prompt,
+ conversation_history=conversation_history,
+ whole_conversation=whole_conversation,
+ whole_conversation_metadata=whole_conversation_metadata,
+ client=client,
+ client_config=config,
+ model_choice=model_choice,
+ temperature=temperature,
+ reported_batch_no=1,
+ local_model=local_model,
+ tokenizer=tokenizer,
+ bedrock_runtime=bedrock_runtime,
+ model_source=model_source,
+ MAX_OUTPUT_VALIDATION_ATTEMPTS=3,
+ assistant_prefill="",
+ master=False,
+ CHOSEN_LOCAL_MODEL_TYPE=CHOSEN_LOCAL_MODEL_TYPE,
+ random_seed=LLM_SEED,
+ api_url=api_url,
+ )
+
+ # Generate debug files if enabled
+ if output_debug_files == "True":
+ try:
+ # Create batch file path details for debug files
+ batch_file_path_details = (
+ get_file_name_no_ext(reference_table_file_name) + "_llm_dedup"
+ )
+ model_choice_clean_short = (
+ model_choice.replace("/", "_").replace(":", "_").replace(".", "_")
+ )
+
+ # Create full prompt for debug output
+ full_prompt = llm_deduplication_system_prompt + "\n" + formatted_prompt
+
+ # Write debug files
+ (
+ current_prompt_content_logged,
+ current_summary_content_logged,
+ current_conversation_content_logged,
+ current_metadata_content_logged,
+ ) = process_debug_output_iteration(
+ output_debug_files,
+ output_folder,
+ batch_file_path_details,
+ model_choice_clean_short,
+ full_prompt,
+ response_text,
+ whole_conversation,
+ whole_conversation_metadata,
+ log_output_files,
+ task_type="llm_deduplication",
+ )
+
+ print("Debug files written for LLM deduplication analysis")
+
+ except Exception as e:
+ print(f"Error writing debug files for LLM deduplication: {e}")
+
+ # Parse the LLM response to extract merge suggestions
+ merge_suggestions_df = pd.DataFrame() # Initialise empty DataFrame for analysis results
+ num_merges_applied = 0
+
+ try:
+ # Extract the markdown table from the response
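+ # The pattern below matches a GitHub-style markdown table: a header row,
+ # a separator row, then any number of body rows, each delimited by pipes.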
+ table_match = re.search(
+ r"\|.*\|.*\n\|.*\|.*\n(\|.*\|.*\n)*", response_text, re.MULTILINE
+ )
+ if table_match:
+ table_text = table_match.group(0)
+
+ # Convert markdown table to DataFrame
+ merge_suggestions_df = pd.read_csv(
+ StringIO(table_text), sep="|", skipinitialspace=True
+ )
+
+ # Clean up the DataFrame
+ merge_suggestions_df = merge_suggestions_df.dropna(
+ axis=1, how="all"
+ ) # Remove empty columns created by the leading/trailing pipes
+ merge_suggestions_df.columns = merge_suggestions_df.columns.str.strip()
+
+ # Remove the markdown separator row (e.g. |---|---|) and rows where all values are NaN
+ if not merge_suggestions_df.empty:
+ merge_suggestions_df = merge_suggestions_df[
+ ~merge_suggestions_df.iloc[:, 0]
+ .astype(str)
+ .str.contains("---", na=False)
+ ]
+ merge_suggestions_df = merge_suggestions_df.dropna(how="all")
+
+ if not merge_suggestions_df.empty:
+ print(
+ f"LLM identified {len(merge_suggestions_df)} potential topic merges"
+ )
+
+ # Apply the merges to the reference_df
+ def _cell(row, col):
+ # Treat missing columns/NaN cells as empty strings so .strip() is safe
+ value = row.get(col, "")
+ return "" if pd.isna(value) else str(value).strip()
+
+ for _, row in merge_suggestions_df.iterrows():
+ original_general = _cell(row, "Original General topic")
+ original_subtopic = _cell(row, "Original Subtopic")
+ original_sentiment = _cell(row, "Original Sentiment")
+ merged_general = _cell(row, "Merged General topic")
+ merged_subtopic = _cell(row, "Merged Subtopic")
+ merged_sentiment = _cell(row, "Merged Sentiment")
+
+ if all(
+ [
+ original_general,
+ original_subtopic,
+ original_sentiment,
+ merged_general,
+ merged_subtopic,
+ merged_sentiment,
+ ]
+ ):
+
+ # Find matching rows in reference_df
+ mask = (
+ (reference_df["General topic"] == original_general)
+ & (reference_df["Subtopic"] == original_subtopic)
+ & (reference_df["Sentiment"] == original_sentiment)
+ )
+
+ if mask.any():
+ # Update the matching rows
+ reference_df.loc[mask, "General topic"] = merged_general
+ reference_df.loc[mask, "Subtopic"] = merged_subtopic
+ reference_df.loc[mask, "Sentiment"] = merged_sentiment
+ num_merges_applied += 1
+ print(
+ f"Merged: {original_general} | {original_subtopic} | {original_sentiment} -> {merged_general} | {merged_subtopic} | {merged_sentiment}"
+ )
+ else:
+ print("No merge suggestions found in LLM response")
+ else:
+ print("No markdown table found in LLM response")
+
+ except Exception as e:
+ print(f"Error parsing LLM response: {e}")
+ print("Continuing with original data...")
+
+ # Update reference summary column with all summaries
+ reference_df["Summary"] = reference_df.groupby(
+ ["Response References", "General topic", "Subtopic", "Sentiment"]
+ )["Summary"].transform("
".join)
+
+ # Check that we have not inadvertently removed some data during the process
+ end_unique_references = len(reference_df["Response References"].unique())
+
+ if initial_unique_references != end_unique_references:
+ raise Exception(
+ f"Number of unique references changed during processing: Initial={initial_unique_references}, Final={end_unique_references}"
+ )
+
+ # Drop duplicates in the reference table
+ reference_df.drop_duplicates(
+ ["Response References", "General topic", "Subtopic", "Sentiment"], inplace=True
+ )
+
+ # Remake topic_summary_df based on new reference_df
+ if "Group" not in reference_df.columns:
+ reference_df["Group"] = "All"
+ topic_summary_df = create_topic_summary_df_from_reference_table(reference_df)
+
+ # Merge the topic numbers back to the original dataframe
+ reference_df = reference_df.merge(
+ topic_summary_df[
+ ["General topic", "Subtopic", "Sentiment", "Group", "Topic number"]
+ ],
+ on=["General topic", "Subtopic", "Sentiment", "Group"],
+ how="left",
+ )
+
+ # Create pivot table if file data is available
+ if not file_data.empty:
+ basic_response_data = get_basic_response_data(file_data, chosen_cols)
+ reference_df_pivot = convert_reference_table_to_pivot_table(
+ reference_df, basic_response_data
+ )
+
+ reference_pivot_file_path = (
+ output_folder
+ + get_file_name_no_ext(reference_table_file_name)
+ + "_pivot_dedup.csv"
+ )
+ reference_df_pivot.to_csv(
+ reference_pivot_file_path, index=None, encoding="utf-8-sig"
+ )
+ log_output_files.append(reference_pivot_file_path)
+
+ # Save analysis results CSV if merge suggestions were found
+ if not merge_suggestions_df.empty:
+ analysis_results_file_path = (
+ output_folder
+ + get_file_name_no_ext(reference_table_file_name)
+ + "_dedup_llm_analysis_results.csv"
+ )
+ merge_suggestions_df.to_csv(
+ analysis_results_file_path, index=None, encoding="utf-8-sig"
+ )
+ log_output_files.append(analysis_results_file_path)
+ print(f"Analysis results saved to: {analysis_results_file_path}")
+
+ # Save output files
+ reference_file_out_path = (
+ output_folder + get_file_name_no_ext(reference_table_file_name) + "_dedup.csv"
+ )
+ unique_topics_file_out_path = (
+ output_folder
+ + get_file_name_no_ext(unique_topics_table_file_name)
+ + "_dedup.csv"
+ )
+ reference_df.drop(["1", "2", "3"], axis=1, errors="ignore").to_csv(
+ reference_file_out_path, index=None, encoding="utf-8-sig"
+ )
+ topic_summary_df.drop(["1", "2", "3"], axis=1, errors="ignore").to_csv(
+ unique_topics_file_out_path, index=None, encoding="utf-8-sig"
+ )
+
+ output_files.append(reference_file_out_path)
+ output_files.append(unique_topics_file_out_path)
+
+ # Outputs for markdown table output
+ topic_summary_df_revised_display = topic_summary_df.apply(
+ lambda col: col.map(lambda x: wrap_text(x, max_text_length=max_text_length))
+ )
+ deduplicated_unique_table_markdown = topic_summary_df_revised_display.to_markdown(
+ index=False
+ )
+
+ # Calculate token usage and timing information for logging
+ total_input_tokens = 0
+ total_output_tokens = 0
+ number_of_calls = 1 # Single LLM call for deduplication
+
+ # Extract token usage from conversation metadata
+ if whole_conversation_metadata:
+ for metadata in whole_conversation_metadata:
+ if "input_tokens:" in metadata and "output_tokens:" in metadata:
+ try:
+ input_tokens = int(
+ metadata.split("input_tokens: ")[1].split(" ")[0]
+ )
+ output_tokens = int(
+ metadata.split("output_tokens: ")[1].split(" ")[0]
+ )
+ total_input_tokens += input_tokens
+ total_output_tokens += output_tokens
+ except (ValueError, IndexError):
+ pass
+
+ # Calculate estimated time taken (rough estimate based on token usage)
+ estimated_time_taken = (
+ total_input_tokens + total_output_tokens
+ ) / 1000 # Rough estimate in seconds
+
+ return (
+ reference_df,
+ topic_summary_df,
+ output_files,
+ log_output_files,
+ deduplicated_unique_table_markdown,
+ total_input_tokens,
+ total_output_tokens,
+ number_of_calls,
+ estimated_time_taken,
+ ) # , num_merges_applied
+
+
+def sample_reference_table_summaries(
+ reference_df: pd.DataFrame,
+ random_seed: int,
+ no_of_sampled_summaries: int = default_number_of_sampled_summaries,
+ sample_reference_table_checkbox: bool = False,
+):
+ """
+ Sample up to a set number of summaries per topic group from which to produce consolidated summaries, so that the input token length does not become too long.
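+
+ Args:
+ reference_df (pd.DataFrame): Reference table with one row per response/topic assignment.
+ random_seed (int): Seed used for reproducible sampling.
+ no_of_sampled_summaries (int, optional): Maximum number of unique summaries sampled per topic group. Defaults to default_number_of_sampled_summaries.
+ sample_reference_table_checkbox (bool, optional): If False, the reference table is returned unsampled. Defaults to False.
+
+ Returns:
+ tuple: (sampled_reference_table_df, summarised_references_markdown)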
+ """
+
+ if sample_reference_table_checkbox:
+
+ all_summaries = pd.DataFrame(
+ columns=[
+ "General topic",
+ "Subtopic",
+ "Sentiment",
+ "Group",
+ "Response References",
+ "Summary",
+ ]
+ )
+
+ if "Group" not in reference_df.columns:
+ reference_df["Group"] = "All"
+
+ reference_df_grouped = reference_df.groupby(
+ ["General topic", "Subtopic", "Sentiment", "Group"]
+ )
+
+ if "Revised summary" in reference_df.columns:
+ out_message = "Summary has already been created for this file"
+ print(out_message)
+ raise Exception(out_message)
+
+ for group_keys, reference_df_group in reference_df_grouped:
+ if len(reference_df_group["General topic"]) > 1:
+
+ filtered_reference_df = reference_df_group.reset_index()
+
+ filtered_reference_df_unique = filtered_reference_df.drop_duplicates(
+ ["General topic", "Subtopic", "Sentiment", "Summary"]
+ )
+
+ # Sample n of the unique topic summaries PER GROUP. To limit the length of the text going into the summarisation tool
+ # This ensures each group gets up to no_of_sampled_summaries summaries, not the total across all groups
+ filtered_reference_df_unique_sampled = (
+ filtered_reference_df_unique.sample(
+ min(no_of_sampled_summaries, len(filtered_reference_df_unique)),
+ random_state=random_seed,
+ )
+ )
+
+ all_summaries = pd.concat(
+ [all_summaries, filtered_reference_df_unique_sampled]
+ )
+
+ # If no responses/topics qualify, just go ahead with the original reference dataframe
+ if all_summaries.empty:
+ sampled_reference_table_df = reference_df
+ # Filter by sentiment only (Response References is a string in original df, not a count)
+ sampled_reference_table_df = sampled_reference_table_df.loc[
+ sampled_reference_table_df["Sentiment"] != "Not Mentioned"
+ ]
+ else:
+ # FIXED: Preserve Group column in aggregation to maintain group-specific summaries
+ sampled_reference_table_df = (
+ all_summaries.groupby(
+ ["General topic", "Subtopic", "Sentiment", "Group"]
+ )
+ .agg(
+ {
+ "Response References": "size", # Count the number of references
+ "Summary": lambda x: "\n".join(
+ [s.split(": ", 1)[1] for s in x if ": " in s]
+ ), # Join substrings after ': '
+ }
+ )
+ .reset_index()
+ )
+ # Filter by sentiment and count (Response References is now a numeric count after aggregation)
+ sampled_reference_table_df = sampled_reference_table_df.loc[
+ (sampled_reference_table_df["Sentiment"] != "Not Mentioned")
+ & (sampled_reference_table_df["Response References"] > 1)
+ ]
+ else:
+ sampled_reference_table_df = reference_df
+
+ summarised_references_markdown = sampled_reference_table_df.to_markdown(index=False)
+
+ return sampled_reference_table_df, summarised_references_markdown
+
+
+def count_tokens_in_text(text: str, tokenizer=None, model_source: str = "Local") -> int:
+ """
+ Count the number of tokens in the given text.
+
+ Args:
+ text (str): The text to count tokens for
+ tokenizer (object, optional): Tokenizer object for local models. Defaults to None.
+ model_source (str): Source of the model to determine tokenization method. Defaults to "Local".
+
+ Returns:
+ int: Number of tokens in the text
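+
+ Example (illustrative; the ~1.3 tokens-per-word fallback applies when no local tokenizer is supplied):
+ >>> count_tokens_in_text("hello world", model_source="Gemini")
+ 2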
+ """
+ if not text:
+ return 0
+
+ try:
+ if model_source == "Local" and tokenizer and len(tokenizer) > 0:
+ # Use local tokenizer if available
+ tokens = tokenizer[0].encode(text, add_special_tokens=False)
+ return len(tokens)
+ else:
+ # Fallback: rough estimation using word count (approximately 1.3 tokens per word)
+ word_count = len(text.split())
+ return int(word_count * 1.3)
+ except Exception as e:
+ print(f"Error counting tokens: {e}. Using word count estimation.")
+ # Fallback: rough estimation using word count
+ word_count = len(text.split())
+ return int(word_count * 1.3)
+
+
+def summarise_output_topics_query(
+ model_choice: str,
+ in_api_key: str,
+ temperature: float,
+ formatted_summary_prompt: str,
+ summarise_topic_descriptions_system_prompt: str,
+ model_source: str,
+ bedrock_runtime: boto3.Session.client,
+ local_model=list(),
+ tokenizer=list(),
+ assistant_model=list(),
+ azure_endpoint: str = "",
+ api_url: str = None,
+):
+ """
+ Query an LLM to generate a summary of topics based on the provided prompts.
+
+ Args:
+ model_choice (str): The name/type of model to use for generation
+ in_api_key (str): API key for accessing the model service
+ temperature (float): Temperature parameter for controlling randomness in generation
+ formatted_summary_prompt (str): The formatted prompt containing topics to summarize
+ summarise_topic_descriptions_system_prompt (str): System prompt providing context and instructions
+ model_source (str): Source of the model (e.g. "AWS", "Gemini", "Local")
+ bedrock_runtime (boto3.Session.client): AWS Bedrock runtime client for AWS models
+ local_model (object, optional): Local model object if using local inference. Defaults to empty list.
+ tokenizer (object, optional): Tokenizer object if using local inference. Defaults to empty list.
+ assistant_model (object, optional): Assistant model for speculative decoding. Defaults to empty list.
+ azure_endpoint (str, optional): Azure endpoint for Azure/OpenAI models. Defaults to "".
+ api_url (str, optional): URL of an inference server, if used. Defaults to None.
+
+ Returns:
+ tuple: Contains:
+ - summarised_output (str): The generated summary text, cleaned for display
+ - conversation_history (list): History of the conversation with the model
+ - whole_conversation_metadata (list): Metadata about the conversation
+ - response_text (str): The raw response text from the model
+ """
+ conversation_history = list()
+ whole_conversation_metadata = list()
+ client = list()
+ client_config = {}
+
+ # Combine system prompt and user prompt for token counting
+ full_input_text = (
+ summarise_topic_descriptions_system_prompt + "\n" + formatted_summary_prompt[0]
+ if isinstance(formatted_summary_prompt, list)
+ else summarise_topic_descriptions_system_prompt
+ + "\n"
+ + formatted_summary_prompt
+ )
+
+ # Count tokens in the input text
+ input_token_count = count_tokens_in_text(full_input_text, tokenizer, model_source)
+
+ # Check if input exceeds context length
+ if input_token_count > LLM_CONTEXT_LENGTH:
+ error_message = f"Input text exceeds LLM context length. Input tokens: {input_token_count}, Max context length: {LLM_CONTEXT_LENGTH}. Please reduce the input text size."
+ print(error_message)
+ raise ValueError(error_message)
+
+ print(f"Input token count: {input_token_count} (Max: {LLM_CONTEXT_LENGTH})")
+
+ # Prepare Gemini models before query
+ if "Gemini" in model_source:
+ # print("Using Gemini model:", model_choice)
+ client, client_config = construct_gemini_generative_model(
+ in_api_key=in_api_key,
+ temperature=temperature,
+ model_choice=model_choice,
+ system_prompt=summarise_topic_descriptions_system_prompt,
+ max_tokens=max_tokens,
+ )
+ elif "Azure/OpenAI" in model_source:
+ client, client_config = construct_azure_client(
+ in_api_key=os.environ.get("AZURE_INFERENCE_CREDENTIAL", ""),
+ endpoint=azure_endpoint,
+ )
+ elif "Local" in model_source:
+ pass
+ # print("Using local model: ", model_choice)
+ elif "AWS" in model_source:
+ pass
+ # print("Using AWS Bedrock model:", model_choice)
+
+ whole_conversation = [summarise_topic_descriptions_system_prompt]
+
+ # Process requests to large language model
+ (
+ responses,
+ conversation_history,
+ whole_conversation,
+ whole_conversation_metadata,
+ response_text,
+ ) = process_requests(
+ formatted_summary_prompt,
+ summarise_topic_descriptions_system_prompt,
+ conversation_history,
+ whole_conversation,
+ whole_conversation_metadata,
+ client,
+ client_config,
+ model_choice,
+ temperature,
+ bedrock_runtime=bedrock_runtime,
+ model_source=model_source,
+ local_model=local_model,
+ tokenizer=tokenizer,
+ assistant_model=assistant_model,
+ assistant_prefill=summary_assistant_prefill,
+ api_url=api_url,
+ )
+
+ summarised_output = re.sub(
+ r"\n{2,}", "\n", response_text
+ ) # Replace multiple line breaks with a single line break
+ summarised_output = re.sub(
+ r"^\n{1,}", "", summarised_output
+ ) # Remove one or more line breaks at the start
+ summarised_output = re.sub(
+ r"\n", "<br>", summarised_output
+ ) # Replace \n with more HTML-friendly <br> tags
+ summarised_output = summarised_output.strip()
+
+ print("Finished summary query")
+
+ # Ensure the system prompt is included in the conversation history
+ try:
+ if isinstance(conversation_history, list):
+ has_system_prompt = False
+
+ if conversation_history:
+ first_entry = conversation_history[0]
+ if isinstance(first_entry, dict):
+ role_is_system = first_entry.get("role") == "system"
+ parts = first_entry.get("parts")
+ content_matches = (
+ parts == summarise_topic_descriptions_system_prompt
+ or (
+ isinstance(parts, list)
+ and summarise_topic_descriptions_system_prompt in parts
+ )
+ )
+ has_system_prompt = role_is_system and content_matches
+ elif isinstance(first_entry, str):
+ has_system_prompt = (
+ first_entry.strip().lower().startswith("system:")
+ )
+
+ if not has_system_prompt:
+ conversation_history.insert(
+ 0,
+ {
+ "role": "system",
+ "parts": [summarise_topic_descriptions_system_prompt],
+ },
+ )
+ except Exception as _e:
+ # Non-fatal: if anything goes wrong, return the original conversation history
+ pass
+
+ return (
+ summarised_output,
+ conversation_history,
+ whole_conversation_metadata,
+ response_text,
+ )
+
+
+def process_debug_output_iteration(
+ output_debug_files: str,
+ output_folder: str,
+ batch_file_path_details: str,
+ model_choice_clean_short: str,
+ final_system_prompt: str,
+ summarised_output: str,
+ conversation_history: list,
+ metadata: list,
+ log_output_files: list,
+ task_type: str,
+) -> Tuple[str, str, str, str]:
+ """
+ Writes debug files for summary generation if output_debug_files is "True",
+ and returns the content of the prompt, summary, conversation, and metadata for the current iteration.
+
+ Args:
+ output_debug_files (str): Flag to indicate if debug files should be written.
+ output_folder (str): The folder where output files are saved.
+ batch_file_path_details (str): Details for the batch file path.
+ model_choice_clean_short (str): Shortened cleaned model choice.
+ final_system_prompt (str): The system prompt content.
+ summarised_output (str): The summarised output content.
+ conversation_history (list): The full conversation history.
+ metadata (list): The metadata for the conversation.
+ log_output_files (list): A list to append paths of written log files. This list is modified in-place.
+ task_type (str): The type of task being performed.
+ Returns:
+ tuple[str, str, str, str]: A tuple containing the content of the prompt,
+ summarised output, conversation history (as string),
+ and metadata (as string) for the current iteration.
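+
+ Example output file names (illustrative, assuming output_folder="output/",
+ batch_file_path_details="case_batch_1", model_choice_clean_short="gemma"
+ and task_type="summary"):
+ output/case_batch_1_full_prompt_gemma_summary.txt
+ output/case_batch_1_full_response_gemma_summary.txt
+ output/case_batch_1_full_conversation_gemma_summary.txt
+ output/case_batch_1_metadata_gemma_summary.txt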
+ """
+ current_prompt_content = final_system_prompt
+ current_summary_content = summarised_output
+
+ if isinstance(conversation_history, list):
+
+ # Handle both list of strings and list of dicts
+ if conversation_history and isinstance(conversation_history[0], dict):
+ # Convert list of dicts to list of strings
+ conversation_strings = list()
+ for entry in conversation_history:
+ if "role" in entry and "parts" in entry:
+ role = entry["role"].capitalize()
+ message = (
+ " ".join(entry["parts"])
+ if isinstance(entry["parts"], list)
+ else str(entry["parts"])
+ )
+ conversation_strings.append(f"{role}: {message}")
+ else:
+ # Fallback for unexpected dict format
+ conversation_strings.append(str(entry))
+ current_conversation_content = "\n".join(conversation_strings)
+ else:
+ # Handle list of strings
+ current_conversation_content = "\n".join(conversation_history)
+ else:
+ current_conversation_content = str(conversation_history)
+ current_metadata_content = str(metadata)
+ current_task_type = task_type
+
+ if output_debug_files == "True":
+ try:
+ formatted_prompt_output_path = (
+ output_folder
+ + batch_file_path_details
+ + "_full_prompt_"
+ + model_choice_clean_short
+ + "_"
+ + current_task_type
+ + ".txt"
+ )
+ final_table_output_path = (
+ output_folder
+ + batch_file_path_details
+ + "_full_response_"
+ + model_choice_clean_short
+ + "_"
+ + current_task_type
+ + ".txt"
+ )
+ whole_conversation_path = (
+ output_folder
+ + batch_file_path_details
+ + "_full_conversation_"
+ + model_choice_clean_short
+ + "_"
+ + current_task_type
+ + ".txt"
+ )
+ whole_conversation_path_meta = (
+ output_folder
+ + batch_file_path_details
+ + "_metadata_"
+ + model_choice_clean_short
+ + "_"
+ + current_task_type
+ + ".txt"
+ )
+
+ with open(
+ formatted_prompt_output_path,
+ "w",
+ encoding="utf-8-sig",
+ errors="replace",
+ ) as f:
+ f.write(current_prompt_content)
+ with open(
+ final_table_output_path, "w", encoding="utf-8-sig", errors="replace"
+ ) as f:
+ f.write(current_summary_content)
+ with open(
+ whole_conversation_path, "w", encoding="utf-8-sig", errors="replace"
+ ) as f:
+ f.write(current_conversation_content)
+ with open(
+ whole_conversation_path_meta,
+ "w",
+ encoding="utf-8-sig",
+ errors="replace",
+ ) as f:
+ f.write(current_metadata_content)
+
+ log_output_files.append(formatted_prompt_output_path)
+ log_output_files.append(final_table_output_path)
+ log_output_files.append(whole_conversation_path)
+ log_output_files.append(whole_conversation_path_meta)
+ except Exception as e:
+ print(f"Error in writing debug files for summary: {e}")
+
+ # Return the content of the objects for the current iteration.
+ # The caller can then append these to separate lists if accumulation is desired.
+ return (
+ current_prompt_content,
+ current_summary_content,
+ current_conversation_content,
+ current_metadata_content,
+ )
+
+
+@spaces.GPU(duration=MAX_SPACES_GPU_RUN_TIME)
+def summarise_output_topics(
+ sampled_reference_table_df: pd.DataFrame,
+ topic_summary_df: pd.DataFrame,
+ reference_table_df: pd.DataFrame,
+ model_choice: str,
+ in_api_key: str,
+ temperature: float,
+ reference_data_file_name: str,
+ summarised_outputs: list = list(),
+ latest_summary_completed: int = 0,
+ out_metadata_str: str = "",
+ in_data_files: List[str] = list(),
+ in_excel_sheets: str = "",
+ chosen_cols: List[str] = list(),
+ log_output_files: list[str] = list(),
+ summarise_format_radio: str = "Return a summary up to two paragraphs long that includes as much detail as possible from the original text",
+ output_folder: str = OUTPUT_FOLDER,
+ context_textbox: str = "",
+ aws_access_key_textbox: str = "",
+ aws_secret_key_textbox: str = "",
+ aws_region_textbox: str = "",
+ model_name_map: dict = model_name_map,
+ hf_api_key_textbox: str = "",
+ azure_endpoint_textbox: str = "",
+ existing_logged_content: list = list(),
+ additional_summary_instructions_provided: str = "",
+ output_debug_files: str = "False",
+ group_value: str = "All",
+ reasoning_suffix: str = reasoning_suffix,
+ local_model: object = None,
+ tokenizer: object = None,
+ assistant_model: object = None,
+ summarise_topic_descriptions_prompt: str = summarise_topic_descriptions_prompt,
+ summarise_topic_descriptions_system_prompt: str = summarise_topic_descriptions_system_prompt,
+ do_summaries: str = "Yes",
+ api_url: str = None,
+ progress=gr.Progress(track_tqdm=True),
+):
+ """
+ Create improved summaries of topics by consolidating raw batch-level summaries from the initial model run. Works on a single group of summaries at a time (called from wrapper function summarise_output_topics_by_group).
+
+ Args:
+ sampled_reference_table_df (pd.DataFrame): DataFrame containing sampled reference data with summaries
+ topic_summary_df (pd.DataFrame): DataFrame containing topic summary information
+ reference_table_df (pd.DataFrame): DataFrame mapping response references to topics
+ model_choice (str): Name of the LLM model to use
+ in_api_key (str): API key for model access
+ temperature (float): Temperature parameter for model generation
+ reference_data_file_name (str): Name of the reference data file
+ summarised_outputs (list, optional): List to store generated summaries. Defaults to empty list.
+ latest_summary_completed (int, optional): Index of last completed summary. Defaults to 0.
+ out_metadata_str (str, optional): String for metadata output. Defaults to empty string.
+ in_data_files (List[str], optional): List of input data file paths. Defaults to empty list.
+ in_excel_sheets (str, optional): Excel sheet names if using Excel files. Defaults to empty string.
+ chosen_cols (List[str], optional): List of columns selected for analysis. Defaults to empty list.
+ log_output_files (list[str], optional): List of log file paths. Defaults to empty list.
+ summarise_format_radio (str, optional): Format instructions for summary generation. Defaults to two paragraph format.
+ output_folder (str, optional): Folder path for outputs. Defaults to OUTPUT_FOLDER.
+ context_textbox (str, optional): Additional context for summarization. Defaults to empty string.
+ aws_access_key_textbox (str, optional): AWS access key. Defaults to empty string.
+ aws_secret_key_textbox (str, optional): AWS secret key. Defaults to empty string.
+ aws_region_textbox (str, optional): AWS region. Defaults to empty string.
+ model_name_map (dict, optional): Dictionary mapping model choices to their properties. Defaults to model_name_map.
+ hf_api_key_textbox (str, optional): Hugging Face API key. Defaults to empty string.
+ azure_endpoint_textbox (str, optional): Azure endpoint. Defaults to empty string.
+ additional_summary_instructions_provided (str, optional): Additional summary instructions provided by the user. Defaults to empty string.
+ existing_logged_content (list, optional): List of existing logged content. Defaults to empty list.
+ output_debug_files (str, optional): Flag to indicate if debug files should be written. Defaults to "False".
+ group_value (str, optional): Value of the group to summarise. Defaults to "All".
+ reasoning_suffix (str, optional): Suffix for reasoning. Defaults to reasoning_suffix.
+ local_model (object, optional): Local model object if using local inference. Defaults to None.
+ tokenizer (object, optional): Tokenizer object if using local inference. Defaults to None.
+ assistant_model (object, optional): Assistant model object if using local inference. Defaults to None.
+ summarise_topic_descriptions_prompt (str, optional): Prompt template for topic summarization.
+ summarise_topic_descriptions_system_prompt (str, optional): System prompt for topic summarization.
+ do_summaries (str, optional): Flag to control summary generation. Defaults to "Yes".
+ api_url (str, optional): URL of an inference server, if used. Defaults to None.
+ progress (gr.Progress, optional): Gradio progress tracker. Defaults to track_tqdm=True.
+
+ Returns:
+ Multiple outputs including summarized content, metadata, and file paths
+ """
+ out_metadata = list()
+ summarised_output_markdown = ""
+ output_files = list()
+ acc_input_tokens = 0
+ acc_output_tokens = 0
+ acc_number_of_calls = 0
+ time_taken = 0
+ out_metadata_str = "" # Output metadata is replaced when a summarisation task starts
+ out_message = list()
+ task_type = "Topic summarisation"
+ topic_summary_df_revised = pd.DataFrame()
+
+ all_prompts_content = list()
+ all_summaries_content = list()
+ all_metadata_content = list()
+ all_groups_content = list()
+ all_batches_content = list()
+ all_model_choice_content = list()
+ all_validated_content = list()
+ all_task_type_content = list()
+ all_logged_content = list()
+ all_file_names_content = list()
+
+ tic = time.perf_counter()
+
+ # Ensure custom model_choice is registered in model_name_map
+ ensure_model_in_map(model_choice, model_name_map)
+
+ model_choice_clean = clean_column_name(
+ model_name_map[model_choice]["short_name"],
+ max_length=20,
+ front_characters=False,
+ )
+
+ if context_textbox and "The context of this analysis is" not in context_textbox:
+ context_textbox = "The context of this analysis is '" + context_textbox + "'."
+
+ if log_output_files is None:
+ log_output_files = list()
+
+ # Check for data for summarisations
+ if not topic_summary_df.empty and not reference_table_df.empty:
+ print("Unique table and reference table data found.")
+ else:
+ out_message = "Please upload a unique topic table and reference table file to continue with summarisation."
+ print(out_message)
+ raise Exception(out_message)
+
+ if "Revised summary" in reference_table_df.columns:
+ out_message = "Summary has already been created for this file"
+ print(out_message)
+ raise Exception(out_message)
+
+ # Load in data file and chosen columns if exists to create pivot table later
+ file_data = pd.DataFrame()
+ if in_data_files and chosen_cols:
+ file_data, data_file_names_textbox, total_number_of_batches = load_in_data_file(
+ in_data_files, chosen_cols, 1, in_excel_sheets=in_excel_sheets
+ )
+ else:
+ out_message = "No file data found, pivot table output will not be created."
+ print(out_message)
+ # Note: file_data will remain empty, so pivot tables will not be created
+
+ reference_table_df = reference_table_df.rename(
+ columns={"General Topic": "General topic"}, errors="ignore"
+ )
+ topic_summary_df = topic_summary_df.rename(
+ columns={"General Topic": "General topic"}, errors="ignore"
+ )
+ if "Group" not in reference_table_df.columns:
+ reference_table_df["Group"] = "All"
+ if "Group" not in topic_summary_df.columns:
+ topic_summary_df["Group"] = "All"
+ if "Group" not in sampled_reference_table_df.columns:
+ sampled_reference_table_df["Group"] = "All"
+
+ # Use the Summary column if it exists, otherwise use the Revised summary column
+ if "Summary" in sampled_reference_table_df.columns:
+ all_summaries = sampled_reference_table_df["Summary"].tolist()
+ else:
+ all_summaries = sampled_reference_table_df["Revised summary"].tolist()
+
+ all_groups = sampled_reference_table_df["Group"].tolist()
+
+ if not group_value:
+ group_value = str(all_groups[0])
+ else:
+ group_value = str(group_value)
+
+ length_all_summaries = len(all_summaries)
+
+ model_source = model_name_map[model_choice]["source"]
+
+ if (model_source == "Local") and (RUN_LOCAL_MODEL == "1") and (not local_model):
+ progress(0.1, f"Using global model: {CHOSEN_LOCAL_MODEL_TYPE}")
+ local_model = get_model()
+ tokenizer = get_tokenizer()
+ assistant_model = get_assistant_model()
+
+ print(
+ "Revising topic-level summaries. "
+ + str(latest_summary_completed)
+ + " summaries completed so far."
+ )
+ summary_loop = progress.tqdm(
+ range(latest_summary_completed, length_all_summaries),
+ desc="Revising topic-level summaries",
+ unit="summaries",
+ )
+
+ if do_summaries == "Yes":
+
+ bedrock_runtime = connect_to_bedrock_runtime(
+ model_name_map,
+ model_choice,
+ aws_access_key_textbox,
+ aws_secret_key_textbox,
+ aws_region_textbox,
+ )
+
+ model_choice_clean_short = clean_column_name(
+ model_choice_clean, max_length=20, front_characters=False
+ )
+ file_name_clean = f"{clean_column_name(reference_data_file_name, max_length=15)}_{clean_column_name(str(group_value), max_length=15).replace(' ','_')}"
+ in_column_cleaned = clean_column_name(chosen_cols, max_length=20)
+
+ combined_summary_instructions = (
+ summarise_format_radio + ". " + additional_summary_instructions_provided
+ )
+
+ for summary_no in summary_loop:
+ print("Current summary number is:", summary_no)
+
+ batch_file_path_details = f"{file_name_clean}_batch_{latest_summary_completed + 1}_size_1_col_{in_column_cleaned}"
+
+ summary_text = all_summaries[summary_no]
+ formatted_summary_prompt = [
+ summarise_topic_descriptions_prompt.format(
+ summaries=summary_text, summary_format=combined_summary_instructions
+ )
+ ]
+
+ formatted_summarise_topic_descriptions_system_prompt = (
+ summarise_topic_descriptions_system_prompt.format(
+ column_name=chosen_cols, consultation_context=context_textbox
+ )
+ )
+
+ if "Local" in model_source and reasoning_suffix:
+ formatted_summarise_topic_descriptions_system_prompt = (
+ formatted_summarise_topic_descriptions_system_prompt
+ + "\n"
+ + reasoning_suffix
+ )
+
+ try:
+ response, conversation_history, metadata, response_text = (
+ summarise_output_topics_query(
+ model_choice,
+ in_api_key,
+ temperature,
+ formatted_summary_prompt,
+ formatted_summarise_topic_descriptions_system_prompt,
+ model_source,
+ bedrock_runtime,
+ local_model,
+ tokenizer=tokenizer,
+ assistant_model=assistant_model,
+ azure_endpoint=azure_endpoint_textbox,
+ api_url=api_url,
+ )
+ )
+ summarised_output = response_text
+ except Exception as e:
+ print("Creating summary failed:", e)
+ summarised_output = ""
+ # Safe defaults so that the logging steps below do not fail on a NameError
+ conversation_history = list()
+ metadata = list()
+
+ summarised_outputs.append(summarised_output)
+ out_metadata.extend(metadata)
+ out_metadata_str = ". ".join(out_metadata)
+
+ # Call the new function to process and log debug outputs for the current iteration.
+ # The returned values are the contents of the prompt, summary, conversation, and metadata
+
+ full_prompt = (
+ formatted_summarise_topic_descriptions_system_prompt
+ + "\n"
+ + formatted_summary_prompt[0]
+ )
+
+ # Coerce toggle to string expected by debug writer (accepts True/False or "True"/"False")
+ output_debug_files_str = (
+ "True"
+ if (
+ (isinstance(output_debug_files, bool) and output_debug_files)
+ or (str(output_debug_files) == "True")
+ )
+ else "False"
+ )
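+ # e.g. True or "True" -> "True"; anything else (False, "False", None) -> "False"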
+
+ (
+ current_prompt_content_logged,
+ current_summary_content_logged,
+ current_conversation_content_logged,
+ current_metadata_content_logged,
+ ) = process_debug_output_iteration(
+ output_debug_files_str,
+ output_folder,
+ batch_file_path_details,
+ model_choice_clean_short,
+ full_prompt,
+ summarised_output,
+ conversation_history,
+ metadata,
+ log_output_files,
+ task_type=task_type,
+ )
+
+ all_prompts_content.append(current_prompt_content_logged)
+ all_summaries_content.append(current_summary_content_logged)
+ # all_conversation_content.append(current_conversation_content_logged)
+ all_metadata_content.append(current_metadata_content_logged)
+ all_groups_content.append(all_groups[summary_no])
+ all_batches_content.append(f"{summary_no}:")
+ all_model_choice_content.append(model_choice_clean_short)
+ all_validated_content.append("No")
+ all_task_type_content.append(task_type)
+ all_file_names_content.append(reference_data_file_name)
+ latest_summary_completed += 1
+
+ toc = time.perf_counter()
+ time_taken = toc - tic
+
+ if time_taken > max_time_for_loop:
+ print(
+ "Time taken for loop is greater than maximum time allowed. Exiting and restarting loop"
+ )
+ summary_loop.close()
+ tqdm._instances.clear()
+ break
+
+ # If all summaries completed, make final outputs
+ if latest_summary_completed >= length_all_summaries:
+ print("All summaries completed. Creating outputs.")
+
+ sampled_reference_table_df["Revised summary"] = summarised_outputs
+
+ join_cols = ["General topic", "Subtopic", "Sentiment"]
+ join_plus_summary_cols = [
+ "General topic",
+ "Subtopic",
+ "Sentiment",
+ "Revised summary",
+ ]
+
+ summarised_references_j = sampled_reference_table_df[
+ join_plus_summary_cols
+ ].drop_duplicates(join_plus_summary_cols)
+
+ topic_summary_df_revised = topic_summary_df.merge(
+ summarised_references_j, on=join_cols, how="left"
+ )
+
+ # If no new summary is available, keep the original
+ # But prefer the version without "Rows X to Y" prefix to avoid duplication
+ def clean_summary_text(text):
+ if pd.isna(text):
+ return text
+ # Remove "Rows X to Y:" prefix if present (both at start and after
tags)
+ import re
+
+ # First remove from the beginning
+ cleaned = re.sub(r"^Rows\s+\d+\s+to\s+\d+:\s*", "", str(text))
+ # Then remove from after
tags
+ cleaned = re.sub(r"
\s*Rows\s+\d+\s+to\s+\d+:\s*", "
", cleaned)
+ return cleaned
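+
+ # Illustrative behaviour: clean_summary_text("Rows 1 to 5: Respondents broadly agreed")
+ # returns "Respondents broadly agreed"; an embedded "<br>Rows 6 to 10:" marker
+ # is reduced to a plain "<br>".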
+
+ topic_summary_df_revised["Revised summary"] = topic_summary_df_revised[
+ "Revised summary"
+ ].combine_first(topic_summary_df_revised["Summary"])
+ # Clean the revised summary to remove "Rows X to Y" prefixes
+ topic_summary_df_revised["Revised summary"] = topic_summary_df_revised[
+ "Revised summary"
+ ].apply(clean_summary_text)
+ topic_summary_df_revised = topic_summary_df_revised[
+ [
+ "General topic",
+ "Subtopic",
+ "Sentiment",
+ "Group",
+ "Number of responses",
+ "Revised summary",
+ ]
+ ]
+
+ # Note: "Rows X to Y:" prefixes are now cleaned by the clean_summary_text function above
+ topic_summary_df_revised["Topic number"] = range(
+ 1, len(topic_summary_df_revised) + 1
+ )
+
+ # If no new summary is available, keep the original. Also join on topic number to ensure consistent topic number assignment
+ reference_table_df_revised = reference_table_df.copy()
+ reference_table_df_revised = reference_table_df_revised.drop(
+ "Topic number", axis=1, errors="ignore"
+ )
+
+ # Ensure reference table has Topic number column
+ if (
+ "Topic number" not in reference_table_df_revised.columns
+ or "Revised summary" not in reference_table_df_revised.columns
+ ):
+ if (
+ "Topic number" in topic_summary_df_revised.columns
+ and "Revised summary" in topic_summary_df_revised.columns
+ ):
+ reference_table_df_revised = reference_table_df_revised.merge(
+ topic_summary_df_revised[
+ [
+ "General topic",
+ "Subtopic",
+ "Sentiment",
+ "Group",
+ "Topic number",
+ "Revised summary",
+ ]
+ ],
+ on=["General topic", "Subtopic", "Sentiment", "Group"],
+ how="left",
+ )
+
+ reference_table_df_revised["Revised summary"] = reference_table_df_revised[
+ "Revised summary"
+ ].combine_first(reference_table_df_revised["Summary"])
+ # Clean the revised summary to remove "Rows X to Y" prefixes
+ reference_table_df_revised["Revised summary"] = reference_table_df_revised[
+ "Revised summary"
+ ].apply(clean_summary_text)
+ reference_table_df_revised = reference_table_df_revised.drop(
+ "Summary", axis=1, errors="ignore"
+ )
+
+ # Remove topics that are tagged as 'Not Mentioned'
+ topic_summary_df_revised = topic_summary_df_revised.loc[
+ topic_summary_df_revised["Sentiment"] != "Not Mentioned", :
+ ]
+ reference_table_df_revised = reference_table_df_revised.loc[
+ reference_table_df_revised["Sentiment"] != "Not Mentioned", :
+ ]
+
+ # Combine the logged content into a list of dictionaries
+ all_logged_content = [
+ {
+ "prompt": prompt,
+ "response": summary,
+ "metadata": metadata,
+ "batch": batch,
+ "model_choice": model_choice,
+ "validated": validated,
+ "group": group,
+ "task_type": task_type,
+ "file_name": file_name,
+ }
+ for prompt, summary, metadata, batch, model_choice, validated, group, task_type, file_name in zip(
+ all_prompts_content,
+ all_summaries_content,
+ all_metadata_content,
+ all_batches_content,
+ all_model_choice_content,
+ all_validated_content,
+ all_groups_content,
+ all_task_type_content,
+ all_file_names_content,
+ )
+ ]
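+ # Each logged record is a dict of the form (illustrative values):
+ # {"prompt": "...", "response": "...", "metadata": "...", "batch": "3:",
+ # "model_choice": "...", "validated": "No", "group": "All",
+ # "task_type": "Topic summarisation", "file_name": "..."}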
+
+ if isinstance(existing_logged_content, pd.DataFrame):
+ existing_logged_content = existing_logged_content.to_dict(orient="records")
+
+ out_logged_content = existing_logged_content + all_logged_content
+
+ ### Save output files
+
+ if output_debug_files == "True":
+
+ if not file_data.empty:
+ basic_response_data = get_basic_response_data(file_data, chosen_cols)
+ reference_table_df_revised_pivot = (
+ convert_reference_table_to_pivot_table(
+ reference_table_df_revised, basic_response_data
+ )
+ )
+
+ ### Save pivot file to log area
+ reference_table_df_revised_pivot_path = (
+ output_folder
+ + file_name_clean
+ + "_summ_reference_table_pivot_"
+ + model_choice_clean
+ + ".csv"
+ )
+ reference_table_df_revised_pivot.drop(
+ ["1", "2", "3"], axis=1, errors="ignore"
+ ).to_csv(
+ reference_table_df_revised_pivot_path,
+ index=None,
+ encoding="utf-8-sig",
+ )
+ log_output_files.append(reference_table_df_revised_pivot_path)
+
+ # Save to file
+ topic_summary_df_revised_path = (
+ output_folder
+ + file_name_clean
+ + "_summ_unique_topics_table_"
+ + model_choice_clean
+ + ".csv"
+ )
+ topic_summary_df_revised.drop(
+ ["1", "2", "3"], axis=1, errors="ignore"
+ ).to_csv(topic_summary_df_revised_path, index=None, encoding="utf-8-sig")
+
+ reference_table_df_revised_path = (
+ output_folder
+ + file_name_clean
+ + "_summ_reference_table_"
+ + model_choice_clean
+ + ".csv"
+ )
+ reference_table_df_revised.drop(
+ ["1", "2", "3"], axis=1, errors="ignore"
+ ).to_csv(reference_table_df_revised_path, index=None, encoding="utf-8-sig")
+
+ log_output_files.extend(
+ [reference_table_df_revised_path, topic_summary_df_revised_path]
+ )
+
+ ###
+ topic_summary_df_revised_display = topic_summary_df_revised.apply(
+ lambda col: col.map(lambda x: wrap_text(x, max_text_length=max_text_length))
+ )
+ summarised_output_markdown = topic_summary_df_revised_display.to_markdown(
+ index=False
+ )
+
+ # Ensure the same file name is not returned twice, preserving original order
+ output_files = list(dict.fromkeys(output_files))
+ log_output_files = list(dict.fromkeys(log_output_files))
+
+ acc_input_tokens, acc_output_tokens, acc_number_of_calls = (
+ calculate_tokens_from_metadata(
+ out_metadata_str, model_choice, model_name_map
+ )
+ )
+
+ toc = time.perf_counter()
+ time_taken = toc - tic
+
+ if isinstance(out_message, list):
+ out_message = "\n".join(out_message)
+
+ out_message = (
+ out_message
+ + f"\nTopic summarisation finished processing. Total time: {round(float(time_taken), 1)}s"
+ )
+ print(out_message)
+
+ return (
+ sampled_reference_table_df,
+ topic_summary_df_revised,
+ reference_table_df_revised,
+ output_files,
+ summarised_outputs,
+ latest_summary_completed,
+ out_metadata_str,
+ summarised_output_markdown,
+ log_output_files,
+ output_files,
+ acc_input_tokens,
+ acc_output_tokens,
+ acc_number_of_calls,
+ time_taken,
+ out_message,
+ out_logged_content,
+ )
+
+
+@spaces.GPU(duration=MAX_SPACES_GPU_RUN_TIME)
+def wrapper_summarise_output_topics_per_group(
+ grouping_col: str,
+ sampled_reference_table_df: pd.DataFrame,
+ topic_summary_df: pd.DataFrame,
+ reference_table_df: pd.DataFrame,
+ model_choice: str,
+ in_api_key: str,
+ temperature: float,
+ reference_data_file_name: str,
+ summarised_outputs: list = list(),
+ latest_summary_completed: int = 0,
+ out_metadata_str: str = "",
+ in_data_files: List[str] = list(),
+ in_excel_sheets: str = "",
+ chosen_cols: List[str] = list(),
+ log_output_files: List[str] = list(),
+ summarise_format_radio: str = "Return a summary up to two paragraphs long that includes as much detail as possible from the original text",
+ output_folder: str = OUTPUT_FOLDER,
+ context_textbox: str = "",
+ aws_access_key_textbox: str = "",
+ aws_secret_key_textbox: str = "",
+ aws_region_textbox: str = "",
+ model_name_map: dict = model_name_map,
+ hf_api_key_textbox: str = "",
+ azure_endpoint_textbox: str = "",
+ existing_logged_content: list = list(),
+ sample_reference_table: bool = False,
+ no_of_sampled_summaries: int = default_number_of_sampled_summaries,
+ random_seed: int = 42,
+ api_url: str = None,
+ additional_summary_instructions_provided: str = "",
+ output_debug_files: str = OUTPUT_DEBUG_FILES,
+ reasoning_suffix: str = reasoning_suffix,
+ local_model: object = None,
+ tokenizer: object = None,
+ assistant_model: object = None,
+ summarise_topic_descriptions_prompt: str = summarise_topic_descriptions_prompt,
+ summarise_topic_descriptions_system_prompt: str = summarise_topic_descriptions_system_prompt,
+ do_summaries: str = "Yes",
+ progress=gr.Progress(track_tqdm=True),
+) -> Tuple[
+ pd.DataFrame,
+ pd.DataFrame,
+ pd.DataFrame,
+ List[str],
+ List[str],
+ int,
+ str,
+ str,
+ List[str],
+ List[str],
+ int,
+ int,
+ int,
+ float,
+ str,
+ List[dict],
+]:
+ """
+ A wrapper function that iterates through unique values in a specified grouping column
+ and calls the `summarise_output_topics` function for each group of summaries.
+ It accumulates results from each call and returns a consolidated output.
+
+ :param grouping_col: The name of the column to group the data by.
+ :param sampled_reference_table_df: DataFrame containing sampled reference data with summaries
+ :param topic_summary_df: DataFrame containing topic summary information
+ :param reference_table_df: DataFrame mapping response references to topics
+ :param model_choice: Name of the LLM model to use
+ :param in_api_key: API key for model access
+ :param temperature: Temperature parameter for model generation
+ :param reference_data_file_name: Name of the reference data file
+ :param summarised_outputs: List to store generated summaries
+ :param latest_summary_completed: Index of last completed summary
+ :param out_metadata_str: String for metadata output
+ :param in_data_files: List of input data file paths
+ :param in_excel_sheets: Excel sheet names if using Excel files
+ :param chosen_cols: List of columns selected for analysis
+ :param log_output_files: List of log file paths
+ :param summarise_format_radio: Format instructions for summary generation
+ :param output_folder: Folder path for outputs
+ :param context_textbox: Additional context for summarization
+ :param aws_access_key_textbox: AWS access key
+ :param aws_secret_key_textbox: AWS secret key
+ :param model_name_map: Dictionary mapping model choices to their properties
+ :param hf_api_key_textbox: Hugging Face API key
+ :param azure_endpoint_textbox: Azure endpoint
+ :param existing_logged_content: List of existing logged content
+ :param additional_summary_instructions_provided: Additional summary instructions
+ :param output_debug_files: Flag to indicate if debug files should be written
+ :param reasoning_suffix: Suffix for reasoning
+ :param local_model: Local model object if using local inference
+ :param tokenizer: Tokenizer object if using local inference
+ :param assistant_model: Assistant model object if using local inference
+ :param summarise_topic_descriptions_prompt: Prompt template for topic summarization
+ :param summarise_topic_descriptions_system_prompt: System prompt for topic summarization
+ :param do_summaries: Flag to control summary generation
+ :param sample_reference_table: If True, sample the reference table at the top of the function
+ :param no_of_sampled_summaries: Number of summaries to sample per group (defaults to DEFAULT_SAMPLED_SUMMARIES)
+ :param random_seed: Random seed for reproducible sampling (default 42)
+ :param progress: Gradio progress tracker
+ :return: A tuple containing consolidated results, mimicking the return structure of `summarise_output_topics`
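+
+ Example (illustrative sketch; the DataFrames and model name are placeholders):
+ >>> results = wrapper_summarise_output_topics_per_group(
+ ... grouping_col="Group",
+ ... sampled_reference_table_df=sampled_df,
+ ... topic_summary_df=topics_df,
+ ... reference_table_df=refs_df,
+ ... model_choice="my-model",
+ ... in_api_key="",
+ ... temperature=0.0,
+ ... reference_data_file_name="consultation_responses",
+ ... )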
+ """
+
+ acc_input_tokens = 0
+ acc_output_tokens = 0
+ acc_number_of_calls = 0
+ out_message = list()
+
+ # Logged content
+ all_groups_logged_content = existing_logged_content
+
+ # Check if we have data to process
+ # Allow empty sampled_reference_table_df if sample_reference_table is True (it will be created from reference_table_df)
+ if (
+ (sampled_reference_table_df.empty and not sample_reference_table)
+ or topic_summary_df.empty
+ or reference_table_df.empty
+ ):
+ out_message = "Please upload reference table, topic summary, and sampled reference table files to continue with summarisation."
+ print(out_message)
+ raise Exception(out_message)
+
+ # Ensure Group column exists
+ if "Group" not in sampled_reference_table_df.columns:
+ sampled_reference_table_df["Group"] = "All"
+ if "Group" not in topic_summary_df.columns:
+ topic_summary_df["Group"] = "All"
+ if "Group" not in reference_table_df.columns:
+ reference_table_df["Group"] = "All"
+
+ # Sample reference table if requested
+ if sample_reference_table:
+ print(
+ f"Sampling reference table with {no_of_sampled_summaries} summaries per group..."
+ )
+ sampled_reference_table_df, _ = sample_reference_table_summaries(
+ reference_table_df,
+ random_seed=random_seed,
+ no_of_sampled_summaries=no_of_sampled_summaries,
+ sample_reference_table_checkbox=sample_reference_table,
+ )
+ print(
+ f"Sampling complete. {len(sampled_reference_table_df)} summaries selected."
+ )
+
+ # Get unique group values
+ unique_values = sampled_reference_table_df["Group"].unique()
+
+ if len(unique_values) > MAX_GROUPS:
+ print(
+ f"Warning: More than {MAX_GROUPS} unique values found in '{grouping_col}'. Processing only the first {MAX_GROUPS}."
+ )
+ unique_values = unique_values[:MAX_GROUPS]
+
+ # Initialize accumulators for results across all groups
+ acc_sampled_reference_table_df = pd.DataFrame()
+ acc_topic_summary_df_revised = pd.DataFrame()
+ acc_reference_table_df_revised = pd.DataFrame()
+ acc_output_files = list()
+ acc_log_output_files = list()
+ acc_summarised_outputs = list()
+ acc_latest_summary_completed = latest_summary_completed
+ acc_out_metadata_str = out_metadata_str
+ acc_summarised_output_markdown = ""
+ acc_total_time_taken = 0.0
+ acc_logged_content = list()
+
+ if len(unique_values) == 1:
+ # If only one unique value, no need for progress bar, iterate directly
+ loop_object = unique_values
+ else:
+ # If multiple unique values, use tqdm progress bar
+ loop_object = progress.tqdm(
+ unique_values, desc="Summarising group", unit="groups"
+ )
+
+ for i, group_value in enumerate(loop_object):
+ print(
+ f"\nProcessing summary group: {grouping_col} = {group_value} ({i+1}/{len(unique_values)})"
+ )
+
+ # Filter data for current group
+ filtered_sampled_reference_table_df = sampled_reference_table_df[
+ sampled_reference_table_df["Group"] == group_value
+ ].copy()
+ filtered_topic_summary_df = topic_summary_df[
+ topic_summary_df["Group"] == group_value
+ ].copy()
+ filtered_reference_table_df = reference_table_df[
+ reference_table_df["Group"] == group_value
+ ].copy()
+
+ if filtered_sampled_reference_table_df.empty:
+ print(f"No data for {grouping_col} = {group_value}. Skipping.")
+ continue
+
+ # Create unique file name for this group's outputs
+ group_file_name = f"{reference_data_file_name}_{clean_column_name(str(group_value), max_length=15).replace(' ','_')}"
+
+ # Call summarise_output_topics for the current group
+ try:
+ (
+ seg_sampled_reference_table_df,
+ seg_topic_summary_df_revised,
+ seg_reference_table_df_revised,
+ seg_output_files,
+ seg_summarised_outputs,
+ seg_latest_summary_completed,
+ seg_out_metadata_str,
+ seg_summarised_output_markdown,
+ seg_log_output_files,
+ seg_output_files_2,
+ seg_acc_input_tokens,
+ seg_acc_output_tokens,
+ seg_acc_number_of_calls,
+ seg_time_taken,
+ seg_out_message,
+ seg_logged_content,
+ ) = summarise_output_topics(
+ sampled_reference_table_df=filtered_sampled_reference_table_df,
+ topic_summary_df=filtered_topic_summary_df,
+ reference_table_df=filtered_reference_table_df,
+ model_choice=model_choice,
+ in_api_key=in_api_key,
+ temperature=temperature,
+ reference_data_file_name=group_file_name,
+ summarised_outputs=list(), # Fresh for each call
+ latest_summary_completed=0, # Reset for each group
+ out_metadata_str="", # Fresh for each call
+ in_data_files=in_data_files,
+ in_excel_sheets=in_excel_sheets,
+ chosen_cols=chosen_cols,
+ log_output_files=list(), # Fresh for each call
+ summarise_format_radio=summarise_format_radio,
+ output_folder=output_folder,
+ context_textbox=context_textbox,
+ aws_access_key_textbox=aws_access_key_textbox,
+ aws_secret_key_textbox=aws_secret_key_textbox,
+ aws_region_textbox=aws_region_textbox,
+ model_name_map=model_name_map,
+ hf_api_key_textbox=hf_api_key_textbox,
+ azure_endpoint_textbox=azure_endpoint_textbox,
+ existing_logged_content=all_groups_logged_content,
+ additional_summary_instructions_provided=additional_summary_instructions_provided,
+ output_debug_files=output_debug_files,
+ group_value=group_value,
+ reasoning_suffix=reasoning_suffix,
+ local_model=local_model,
+ tokenizer=tokenizer,
+ assistant_model=assistant_model,
+ summarise_topic_descriptions_prompt=summarise_topic_descriptions_prompt,
+ summarise_topic_descriptions_system_prompt=summarise_topic_descriptions_system_prompt,
+ do_summaries=do_summaries,
+ api_url=api_url,
+ )
+
+ # Aggregate results
+ acc_sampled_reference_table_df = pd.concat(
+ [acc_sampled_reference_table_df, seg_sampled_reference_table_df]
+ )
+ acc_topic_summary_df_revised = pd.concat(
+ [acc_topic_summary_df_revised, seg_topic_summary_df_revised]
+ )
+ acc_reference_table_df_revised = pd.concat(
+ [acc_reference_table_df_revised, seg_reference_table_df_revised]
+ )
+
+ # For lists, extend
+ acc_output_files.extend(
+ f for f in seg_output_files if f not in acc_output_files
+ )
+ acc_log_output_files.extend(
+ f for f in seg_log_output_files if f not in acc_log_output_files
+ )
+ acc_summarised_outputs.extend(seg_summarised_outputs)
+
+ acc_latest_summary_completed = seg_latest_summary_completed
+ acc_out_metadata_str += (
+ ("\n---\n" if acc_out_metadata_str else "")
+ + f"Group {grouping_col}={group_value}:\n"
+ + seg_out_metadata_str
+ )
+ acc_summarised_output_markdown = (
+ seg_summarised_output_markdown # Keep the latest markdown
+ )
+ acc_total_time_taken += float(seg_time_taken)
+ acc_logged_content.extend(seg_logged_content)
+
+ # Accumulate token counts
+ acc_input_tokens += seg_acc_input_tokens
+ acc_output_tokens += seg_acc_output_tokens
+ acc_number_of_calls += seg_acc_number_of_calls
+
+ print(
+ f"Group {grouping_col} = {group_value} summarised. Time: {seg_time_taken:.2f}s"
+ )
+
+ except Exception as e:
+ print(f"Error processing summary group {grouping_col} = {group_value}: {e}")
+ # Optionally, decide if you want to continue with other groups or stop
+ # For now, it will continue
+ continue
+
+ # Ensure custom model_choice is registered in model_name_map
+ ensure_model_in_map(model_choice, model_name_map)
+
+ # Create consolidated output files
+ overall_file_name = clean_column_name(reference_data_file_name, max_length=20)
+ model_choice_clean = model_name_map[model_choice]["short_name"]
+ model_choice_clean_short = clean_column_name(
+ model_choice_clean, max_length=20, front_characters=False
+ )
+
+ # Save consolidated outputs
+ if (
+ not acc_topic_summary_df_revised.empty
+ and not acc_reference_table_df_revised.empty
+ ):
+ # Sort the dataframes
+ if "General topic" in acc_topic_summary_df_revised.columns:
+ acc_topic_summary_df_revised["Number of responses"] = (
+ acc_topic_summary_df_revised["Number of responses"].astype(int)
+ )
+ acc_topic_summary_df_revised.sort_values(
+ [
+ "Group",
+ "Number of responses",
+ "General topic",
+ "Subtopic",
+ "Sentiment",
+ ],
+ ascending=[True, False, True, True, True],
+ inplace=True,
+ )
+ elif "Main heading" in acc_topic_summary_df_revised.columns:
+ acc_topic_summary_df_revised["Number of responses"] = (
+ acc_topic_summary_df_revised["Number of responses"].astype(int)
+ )
+ acc_topic_summary_df_revised.sort_values(
+ [
+ "Group",
+ "Number of responses",
+ "Main heading",
+ "Subheading",
+ "Topic number",
+ ],
+ ascending=[True, False, True, True, True],
+ inplace=True,
+ )
+
+ # Save consolidated files
+ consolidated_topic_summary_path = (
+ output_folder
+ + overall_file_name
+ + "_all_final_summ_unique_topics_"
+ + model_choice_clean_short
+ + ".csv"
+ )
+ consolidated_reference_table_path = (
+ output_folder
+ + overall_file_name
+ + "_all_final_summ_reference_table_"
+ + model_choice_clean_short
+ + ".csv"
+ )
+
+ acc_topic_summary_df_revised.drop(
+ ["1", "2", "3"], axis=1, errors="ignore"
+ ).to_csv(consolidated_topic_summary_path, index=None, encoding="utf-8-sig")
+ acc_reference_table_df_revised.drop(
+ ["1", "2", "3"], axis=1, errors="ignore"
+ ).to_csv(consolidated_reference_table_path, index=None, encoding="utf-8-sig")
+
+ acc_output_files.extend(
+ [consolidated_topic_summary_path, consolidated_reference_table_path]
+ )
+
+ # Create markdown output for display
+ topic_summary_df_revised_display = acc_topic_summary_df_revised.apply(
+ lambda col: col.map(lambda x: wrap_text(x, max_text_length=max_text_length))
+ )
+ acc_summarised_output_markdown = topic_summary_df_revised_display.to_markdown(
+ index=False
+ )
+
+ out_message = "\n".join(out_message)
+ out_message = (
+ out_message
+ + " "
+ + f"Topic summarisation finished processing all groups. Total time: {acc_total_time_taken:.2f}s"
+ )
+ print(out_message)
+
+ # The return signature should match summarise_output_topics
+ return (
+ acc_sampled_reference_table_df,
+ acc_topic_summary_df_revised,
+ acc_reference_table_df_revised,
+ acc_output_files,
+ acc_summarised_outputs,
+ acc_latest_summary_completed,
+ acc_out_metadata_str,
+ acc_summarised_output_markdown,
+ acc_log_output_files,
+ acc_output_files, # Duplicate for compatibility
+ acc_input_tokens,
+ acc_output_tokens,
+ acc_number_of_calls,
+ acc_total_time_taken,
+ out_message,
+ acc_logged_content,
+ )
+
+
+@spaces.GPU(duration=MAX_SPACES_GPU_RUN_TIME)
+def overall_summary(
+ topic_summary_df: pd.DataFrame,
+ model_choice: str,
+ in_api_key: str,
+ temperature: float,
+ reference_data_file_name: str,
+ output_folder: str = OUTPUT_FOLDER,
+ chosen_cols: List[str] = list(),
+ context_textbox: str = "",
+ aws_access_key_textbox: str = "",
+ aws_secret_key_textbox: str = "",
+ aws_region_textbox: str = "",
+ model_name_map: dict = model_name_map,
+ hf_api_key_textbox: str = "",
+ azure_endpoint_textbox: str = "",
+ existing_logged_content: list = list(),
+ api_url: str = None,
+ output_debug_files: str = output_debug_files,
+ log_output_files: list = list(),
+ reasoning_suffix: str = reasoning_suffix,
+ local_model: object = None,
+ tokenizer: object = None,
+ assistant_model: object = None,
+ summarise_everything_prompt: str = summarise_everything_prompt,
+ comprehensive_summary_format_prompt: str = comprehensive_summary_format_prompt,
+ comprehensive_summary_format_prompt_by_group: str = comprehensive_summary_format_prompt_by_group,
+ summarise_everything_system_prompt: str = summarise_everything_system_prompt,
+ do_summaries: str = "Yes",
+ progress=gr.Progress(track_tqdm=True),
+) -> Tuple[
+ List[str],
+ str,
+ pd.DataFrame,
+ str,
+ int,
+ int,
+ int,
+ float,
+ str,
+ List[dict],
+]:
+ """
+ Create an overall summary of all responses based on a topic summary table.
+
+ Args:
+ topic_summary_df (pd.DataFrame): DataFrame containing topic summaries
+ model_choice (str): Name of the LLM model to use
+ in_api_key (str): API key for model access
+ temperature (float): Temperature parameter for model generation
+ reference_data_file_name (str): Name of reference data file
+ output_folder (str, optional): Folder to save outputs. Defaults to OUTPUT_FOLDER.
+ chosen_cols (List[str], optional): Columns to analyze. Defaults to empty list.
+ context_textbox (str, optional): Additional context. Defaults to empty string.
+ aws_access_key_textbox (str, optional): AWS access key. Defaults to empty string.
+ aws_secret_key_textbox (str, optional): AWS secret key. Defaults to empty string.
+ aws_region_textbox (str, optional): AWS region. Defaults to empty string.
+ model_name_map (dict, optional): Mapping of model names. Defaults to model_name_map.
+ hf_api_key_textbox (str, optional): Hugging Face API key. Defaults to empty string.
+ existing_logged_content (list, optional): List of existing logged content. Defaults to empty list.
+ api_url (str, optional): API URL for inference-server models. Defaults to None.
+ output_debug_files (str, optional): Flag to indicate if debug files should be written. Defaults to OUTPUT_DEBUG_FILES.
+ log_output_files (list, optional): List of log output file paths. Defaults to empty list.
+ reasoning_suffix (str, optional): Suffix for reasoning. Defaults to reasoning_suffix.
+ local_model (object, optional): Local model object. Defaults to None.
+ tokenizer (object, optional): Tokenizer object. Defaults to None.
+ assistant_model (object, optional): Assistant model object. Defaults to None.
+ summarise_everything_prompt (str, optional): Prompt for overall summary
+ comprehensive_summary_format_prompt (str, optional): Prompt for comprehensive summary format
+ comprehensive_summary_format_prompt_by_group (str, optional): Prompt for group summary format
+ summarise_everything_system_prompt (str, optional): System prompt for overall summary
+ do_summaries (str, optional): Whether to generate summaries. Defaults to "Yes".
+ progress (gr.Progress, optional): Progress tracker. Defaults to gr.Progress(track_tqdm=True).
+
+ Returns:
+ Tuple containing:
+ List[str]: Output file paths
+ str: HTML table of group summaries for display
+ pd.DataFrame: DataFrame of group summaries
+ str: Output metadata
+ int: Number of input tokens
+ int: Number of output tokens
+ int: Number of API calls
+ float: Time taken
+ str: Output message
+ List[dict]: List of logged content
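+
+ Example (illustrative sketch; `topics_df` and the model name are placeholders):
+ >>> outputs = overall_summary(
+ ... topic_summary_df=topics_df,
+ ... model_choice="my-model",
+ ... in_api_key="",
+ ... temperature=0.0,
+ ... reference_data_file_name="consultation_responses",
+ ... )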
+ """
+
+ out_metadata = list()
+ out_metadata_str = ""
+ latest_summary_completed = 0
+ output_files = list()
+ txt_summarised_outputs = list()
+ summarised_outputs = list()
+ summarised_outputs_for_df = list()
+ summarised_outputs_df = pd.DataFrame()
+ html_output_table = ""
+ input_tokens_num = 0
+ output_tokens_num = 0
+ number_of_calls_num = 0
+ time_taken = 0
+ out_message = list()
+ all_logged_content = list()
+ all_prompts_content = list()
+ all_summaries_content = list()
+ all_metadata_content = list()
+ all_groups_content = list()
+ all_batches_content = list()
+ all_model_choice_content = list()
+ all_validated_content = list()
+ all_task_type_content = list()
+ all_file_names_content = list()
+ task_type = "Overall summary"
+ log_output_files = list()
+ tic = time.perf_counter()
+
+ if "Group" not in topic_summary_df.columns:
+ topic_summary_df["Group"] = "All"
+
+ topic_summary_df = topic_summary_df.sort_values(
+ by=["Group", "Number of responses"], ascending=[True, False]
+ )
+
+ unique_groups = sorted(topic_summary_df["Group"].unique())
+
+ length_groups = len(unique_groups)
+
+ if context_textbox and "The context of this analysis is" not in context_textbox:
+ context_textbox = "The context of this analysis is '" + context_textbox + "'."
+
+ if length_groups > 1:
+ comprehensive_summary_format_prompt = (
+ comprehensive_summary_format_prompt_by_group
+ )
+
+ # Ensure custom model_choice is registered in model_name_map
+ ensure_model_in_map(model_choice, model_name_map)
+
+ batch_file_path_details = create_batch_file_path_details(reference_data_file_name)
+ model_choice_clean = model_name_map[model_choice]["short_name"]
+ model_choice_clean_short = clean_column_name(
+ model_choice_clean, max_length=20, front_characters=False
+ )
+
+ if (
+ (model_choice == CHOSEN_LOCAL_MODEL_TYPE)
+ and (RUN_LOCAL_MODEL == "1")
+ and (not local_model)
+ ):
+ progress(0.1, f"Using model: {CHOSEN_LOCAL_MODEL_TYPE}")
+ local_model = get_model()
+ tokenizer = get_tokenizer()
+ assistant_model = get_assistant_model()
+
+ summary_loop = tqdm(
+ unique_groups, desc="Creating overall summary for groups", unit="groups"
+ )
+
+ if do_summaries == "Yes":
+ model_source = model_name_map[model_choice]["source"]
+ bedrock_runtime = connect_to_bedrock_runtime(
+ model_name_map,
+ model_choice,
+ aws_access_key_textbox,
+ aws_secret_key_textbox,
+ aws_region_textbox,
+ )
+
+ for summary_group in summary_loop:
+
+ print("Creating overall summary for group:", summary_group)
+
+ summary_text = topic_summary_df.loc[
+ topic_summary_df["Group"] == summary_group
+ ].to_markdown(index=False)
+
+ formatted_summary_prompt = [
+ summarise_everything_prompt.format(
+ topic_summary_table=summary_text,
+ summary_format=comprehensive_summary_format_prompt,
+ )
+ ]
+
+ formatted_summarise_everything_system_prompt = (
+ summarise_everything_system_prompt.format(
+ column_name=chosen_cols, consultation_context=context_textbox
+ )
+ )
+
+ if "Local" in model_source and reasoning_suffix:
+ formatted_summarise_everything_system_prompt = (
+ formatted_summarise_everything_system_prompt
+ + "\n"
+ + reasoning_suffix
+ )
+
+ try:
+ response, conversation_history, metadata, response_text = (
+ summarise_output_topics_query(
+ model_choice,
+ in_api_key,
+ temperature,
+ formatted_summary_prompt,
+ formatted_summarise_everything_system_prompt,
+ model_source,
+ bedrock_runtime,
+ local_model,
+ tokenizer=tokenizer,
+ assistant_model=assistant_model,
+ azure_endpoint=azure_endpoint_textbox,
+ api_url=api_url,
+ )
+ )
+ summarised_output_for_df = response_text
+ summarised_output = response
+ except Exception as e:
+ print(
+ "Cannot create overall summary for group:",
+ summary_group,
+ "due to:",
+ e,
+ )
+ summarised_output = ""
+ summarised_output_for_df = ""
+ # Safe defaults so that the logging steps below do not fail on a NameError
+ conversation_history = list()
+ metadata = list()
+
+ summarised_outputs_for_df.append(summarised_output_for_df)
+ summarised_outputs.append(summarised_output)
+ txt_summarised_outputs.append(
+ f"""Group name: {summary_group}\n""" + summarised_output
+ )
+
+ out_metadata.extend(metadata)
+ out_metadata_str = ". ".join(out_metadata)
+
+ full_prompt = (
+ formatted_summarise_everything_system_prompt
+ + "\n"
+ + formatted_summary_prompt[0]
+ )
+
+ (
+ current_prompt_content_logged,
+ current_summary_content_logged,
+ current_conversation_content_logged,
+ current_metadata_content_logged,
+ ) = process_debug_output_iteration(
+ output_debug_files,
+ output_folder,
+ batch_file_path_details,
+ model_choice_clean_short,
+ full_prompt,
+ summarised_output,
+ conversation_history,
+ metadata,
+ log_output_files,
+ task_type=task_type,
+ )
+
+ all_prompts_content.append(current_prompt_content_logged)
+ all_summaries_content.append(current_summary_content_logged)
+ # all_conversation_content.append(current_conversation_content_logged)
+ all_metadata_content.append(current_metadata_content_logged)
+ all_groups_content.append(summary_group)
+ all_batches_content.append("1")
+ all_model_choice_content.append(model_choice_clean_short)
+ all_validated_content.append("No")
+ all_task_type_content.append(task_type)
+ all_file_names_content.append(reference_data_file_name)
+ latest_summary_completed += 1
+
+ # Write overall outputs to csv
+ overall_summary_output_csv_path = (
+ output_folder
+ + batch_file_path_details
+ + "_overall_summary_"
+ + model_choice_clean_short
+ + ".csv"
+ )
+ summarised_outputs_df = pd.DataFrame(
+ data={"Group": unique_groups, "Summary": summarised_outputs_for_df}
+ )
+ summarised_outputs_df.drop(["1", "2", "3"], axis=1, errors="ignore").to_csv(
+ overall_summary_output_csv_path, index=None, encoding="utf-8-sig"
+ )
+ output_files.append(overall_summary_output_csv_path)
+
+ summarised_outputs_df_for_display = pd.DataFrame(
+ data={"Group": unique_groups, "Summary": summarised_outputs}
+ )
+ summarised_outputs_df_for_display["Summary"] = (
+ summarised_outputs_df_for_display["Summary"]
+ .apply(lambda x: markdown.markdown(x) if isinstance(x, str) else x)
+ .str.replace(r"\n", "
", regex=False)
+ )
+ html_output_table = summarised_outputs_df_for_display.to_html(
+ index=False, escape=False
+ )
+
+ output_files = list(dict.fromkeys(output_files))
+
+ input_tokens_num, output_tokens_num, number_of_calls_num = (
+ calculate_tokens_from_metadata(
+ out_metadata_str, model_choice, model_name_map
+ )
+ )
+
+ # Check if beyond max time allowed for processing and break if necessary
+ toc = time.perf_counter()
+ time_taken = toc - tic
+
+ out_message = "\n".join(out_message)
+ out_message = (
+ out_message
+ + " "
+ + f"Overall summary finished processing. Total time: {time_taken:.2f}s"
+ )
+ print(out_message)
+
+ # Combine the logged content into a list of dictionaries
+ all_logged_content = [
+ {
+ "prompt": prompt,
+ "response": summary,
+ "metadata": metadata,
+ "batch": batch,
+ "model_choice": model_choice,
+ "validated": validated,
+ "group": group,
+ "task_type": task_type,
+ "file_name": file_name,
+ }
+ for prompt, summary, metadata, batch, model_choice, validated, group, task_type, file_name in zip(
+ all_prompts_content,
+ all_summaries_content,
+ all_metadata_content,
+ all_batches_content,
+ all_model_choice_content,
+ all_validated_content,
+ all_groups_content,
+ all_task_type_content,
+ all_file_names_content,
+ )
+ ]
+
+ if isinstance(existing_logged_content, pd.DataFrame):
+ existing_logged_content = existing_logged_content.to_dict(orient="records")
+
+ out_logged_content = existing_logged_content + all_logged_content
+
+ return (
+ output_files,
+ html_output_table,
+ summarised_outputs_df,
+ out_metadata_str,
+ input_tokens_num,
+ output_tokens_num,
+ number_of_calls_num,
+ time_taken,
+ out_message,
+ out_logged_content,
+ )