import os
import re
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List, Optional, Tuple
import boto3
import gradio as gr
import markdown
import pandas as pd
import spaces
from gradio import Progress as progress
from tqdm import tqdm
from tools.config import (
AWS_ACCESS_KEY,
AWS_LLM_PII_OPTION,
AWS_REGION,
AWS_SECRET_KEY,
CLOUD_LLM_PII_MODEL_CHOICE,
CLOUD_SUMMARISATION_MODEL_CHOICE,
DEFAULT_INFERENCE_SERVER_PII_MODEL,
INFERENCE_SERVER_API_URL,
INFERENCE_SERVER_PII_OPTION,
LLM_CONTEXT_LENGTH,
LLM_MAX_NEW_TOKENS,
LOCAL_TRANSFORMERS_LLM_PII_MODEL_CHOICE,
LOCAL_TRANSFORMERS_LLM_PII_OPTION,
MAX_SPACES_GPU_RUN_TIME,
OUTPUT_FOLDER,
PRIORITISE_SSO_OVER_AWS_ENV_ACCESS_KEYS,
REASONING_SUFFIX,
RUN_AWS_FUNCTIONS,
SUMMARY_PAGE_GROUP_MAX_WORKERS,
model_name_map,
)
from tools.file_conversion import is_pdf, word_level_ocr_df_to_line_level_ocr_df
from tools.helper_functions import (
clean_column_name,
create_batch_file_path_details,
get_file_name_no_ext,
)
from tools.llm_funcs import (
calculate_tokens_from_metadata,
construct_azure_client,
construct_gemini_generative_model,
load_model,
process_requests,
)
# Module-level alias of the configured LLM_MAX_NEW_TOKENS.
# NOTE(review): no use is visible in this part of the file — confirm consumers.
max_tokens = LLM_MAX_NEW_TOKENS
# Module-level alias of the configured REASONING_SUFFIX. summarise_text_chunk
# declares its own ``reasoning_suffix`` parameter (default "") which shadows
# this name inside that function.
reasoning_suffix = REASONING_SUFFIX
# NOTE(review): use not visible here; presumably a character limit for some
# text field — confirm against its callers.
max_text_length = 500
###
# System prompt
###
# Base system prompts. generic_system_prompt and system_prompt are currently
# identical strings.
generic_system_prompt = """You are a researcher analysing a document. Use British English spelling and grammar."""
system_prompt = """You are a researcher analysing a document. Use British English spelling and grammar."""
# Instruction for requests that expect only a markdown table. It starts with a
# space, presumably so it can be appended directly to another prompt.
markdown_additional_prompt = """ You will be given a request for a markdown table. You must respond with ONLY the markdown table. Do not include any introduction, explanation, or concluding text."""
###
# SUMMARISE TOPICS PROMPT
###
# NOTE(review): not referenced in the visible code; presumably an assistant
# prefill for summary requests.
summary_assistant_prefill = ""
# System prompt for per-chunk summaries. summarise_text_chunk calls .format()
# on it with column_name/consultation_context, but the text contains no
# placeholders, so those arguments do not change it.
summarise_topic_descriptions_system_prompt = system_prompt
# User prompt for per-chunk summaries; summarise_text_chunk fills
# {summary_format}, {additional_summary_instructions} and {summaries}.
# {additional_summary_instructions} directly follows "summary." with no space,
# so a non-empty value needs its own leading whitespace.
summarise_topic_descriptions_prompt = """Your task is to make a consolidated summary of the text below. {summary_format}
Return only the summary and no other text. Do not mention specific response numbers in the summary.{additional_summary_instructions}
Text to summarise:
{summaries}
Summary:"""
# Presumably alternative values for the {summary_format} placeholder — their
# use is not visible in this part of the file.
concise_summary_format_prompt = "Return a concise summary that summarises only the most important themes from the original text"
detailed_summary_format_prompt = (
    "Return a summary that includes as much detail as possible from the original text"
)
###
# OVERALL SUMMARY PROMPTS
###
# Prompts for summarising a table of topic summaries. The user prompt expects
# {summary_format}, {additional_summary_instructions} and {topic_summary_table}.
# NOTE(review): not used in the visible part of the file.
summarise_everything_system_prompt = system_prompt
summarise_everything_prompt = """Below is a table that gives an overview of the main issues related to a document.
Your task is to summarise the text in the table below. {summary_format}. Return only the summary and no other text. Use headers and paragraphs to structure the summary where appropriate. Format the output for Excel display using: **bold text** for main headings, • bullet points for sub-items, and line breaks between sections. Avoid markdown symbols like # or ##. {additional_summary_instructions}
Table to summarise:
{topic_summary_table}
Summary:"""
def _summarisation_upload_to_paths(file_upload):
    """Flatten a Gradio file-upload value into a list of usable file paths.

    The upload may be a single path string, a dict with a ``name``/``path``
    key, an object exposing a ``name``/``path`` attribute, or a list of any of
    those. Each entry is resolved to a path. Entries that resolve to nothing,
    or only whitespace, are dropped; entries of an unrecognised shape are
    ignored.

    Args:
        file_upload: Raw value from the Gradio file component.

    Returns:
        list: File paths in upload order (empty when nothing usable was given).
    """
    if not file_upload:
        return []

    def _resolve(entry):
        # None marks an entry of unrecognised shape; it is filtered out below.
        if isinstance(entry, str):
            return entry
        if isinstance(entry, dict):
            return entry.get("name") or entry.get("path") or ""
        if hasattr(entry, "name"):
            return entry.name
        if hasattr(entry, "path"):
            return entry.path
        return None

    entries = file_upload if isinstance(file_upload, list) else [file_upload]
    resolved = (_resolve(entry) for entry in entries)
    return [path for path in resolved if path and str(path).strip()]
def _upload_contains_pdf(file_upload):
    """Report whether any file in a summarisation upload is a PDF.

    The upload is first normalised with ``_summarisation_upload_to_paths``.
    Checking stops at the first PDF found. An empty upload yields False.
    """
    for candidate_path in _summarisation_upload_to_paths(file_upload):
        if is_pdf(candidate_path):
            return True
    return False
###
# Document Summarisation Functions
###
def get_model_choice_from_inference_method(inference_method: str) -> str:
    """
    Return the default summarisation model for an internal inference method.

    The defaults come from config.py: CLOUD_SUMMARISATION_MODEL_CHOICE for
    "aws-bedrock", LOCAL_TRANSFORMERS_LLM_PII_MODEL_CHOICE for "local" and
    DEFAULT_INFERENCE_SERVER_PII_MODEL for "inference-server".

    Args:
        inference_method: One of "aws-bedrock", "local", "inference-server".

    Returns:
        str: The model choice string to use.

    Raises:
        ValueError: If ``inference_method`` is not one of the names above.
    """
    # Checked in order with ==, so any value comparing equal to a method name matches.
    method_defaults = (
        ("aws-bedrock", CLOUD_SUMMARISATION_MODEL_CHOICE),
        ("local", LOCAL_TRANSFORMERS_LLM_PII_MODEL_CHOICE),
        ("inference-server", DEFAULT_INFERENCE_SERVER_PII_MODEL),
    )
    for method_name, default_model in method_defaults:
        if inference_method == method_name:
            return default_model
    raise ValueError(
        f"Unknown inference method: {inference_method}. "
        f"Expected one of: 'aws-bedrock', 'local', 'inference-server'"
    )
def get_model_source_from_model_choice(model_choice: str) -> str:
    """
    Infer the model source for a model choice using the config.py defaults.

    model_name_map is deliberately not consulted. Resolution order:
    - the local transformers default gives "Local";
    - the inference-server default gives "inference-server";
    - either cloud default, or any other value, gives "AWS".

    Args:
        model_choice: The model choice string.

    Returns:
        str: "Local", "inference-server" or "AWS".
    """
    if model_choice == LOCAL_TRANSFORMERS_LLM_PII_MODEL_CHOICE:
        return "Local"
    if model_choice == DEFAULT_INFERENCE_SERVER_PII_MODEL:
        return "inference-server"

    matches_cloud_default = (
        model_choice == CLOUD_LLM_PII_MODEL_CHOICE
        or model_choice == CLOUD_SUMMARISATION_MODEL_CHOICE
    )
    # Bedrock model ids conventionally start with "amazon." or "anthropic.".
    if matches_cloud_default or model_choice.startswith(("amazon.", "anthropic.")):
        return "AWS"

    # Unrecognised names (possibly custom inference-server models) fall back
    # to AWS for backward compatibility.
    return "AWS"
def load_csv_files_to_dataframe(file_input):
    """
    Load OCR-result CSV files from a Gradio file input into one DataFrame.

    The input is normalised with ``_summarisation_upload_to_paths``, so it can
    be any of the following (blank paths are ignored):
    - a path string;
    - a list of paths, dicts with ``name``/``path``, or file objects;
    - a single such object.

    Files whose basename contains "ocr_results_with_words" and which hold
    word-level data (a ``word_text`` column but no ``text`` column) are first
    converted to line level. Loading is best-effort: a file that fails to load
    or lacks the required columns is reported on stdout and skipped.

    Args:
        file_input: Gradio file input (single file, list of files, or file objects).

    Returns:
        pd.DataFrame: Combined rows with columns page, line and text (empty,
        with those columns, when nothing could be loaded).
    """
    required_columns = ["page", "line", "text"]
    if not file_input:
        return pd.DataFrame(columns=required_columns)

    # Shared normaliser: also accepts dict entries and objects exposing only
    # ``.path``, which the previous hand-rolled extraction silently ignored.
    file_paths = _summarisation_upload_to_paths(file_input)

    all_dfs = []
    for file_path in file_paths:
        try:
            df = pd.read_csv(file_path)
            # Convert word-level OCR to line-level if the user uploaded a word-level file
            if "ocr_results_with_words" in os.path.basename(file_path) and (
                "word_text" in df.columns and "text" not in df.columns
            ):
                df = word_level_ocr_df_to_line_level_ocr_df(df)
            if all(column in df.columns for column in required_columns):
                all_dfs.append(df[required_columns])
            else:
                print(
                    f"Warning: {file_path} does not have required columns (page, line, text)"
                )
        except Exception as e:
            # Deliberately best-effort: one bad file should not block the others.
            print(f"Error loading {file_path}: {e}")

    if not all_dfs:
        return pd.DataFrame(columns=required_columns)
    return pd.concat(all_dfs, ignore_index=True)
# Wrapper function to convert inference method to model choice
@spaces.GPU(duration=MAX_SPACES_GPU_RUN_TIME)
def summarise_document_wrapper(
    all_page_line_level_ocr_results_df,
    output_folder,
    summarisation_inference_method,
    summarisation_api_key,
    summarisation_temperature,
    file_name,
    summarisation_context,
    summarisation_aws_access_key,
    summarisation_aws_secret_key,
    summarisation_hf_api_key,
    summarisation_azure_endpoint,
    summarisation_format,
    summarisation_additional_instructions,
    summarisation_max_pages_per_group,
    in_summarisation_ocr_files=None,
):
    """
    Resolve UI settings, load the OCR input and run ``summarise_document``.

    Steps:
    - The UI inference-method option is mapped to an internal method name;
      unknown options fall back to "aws-bedrock".
    - The default model for that method is chosen.
    - If OCR CSV files were uploaded, they replace the pre-loaded DataFrame.
    - A blank ``file_name`` is derived from the first uploaded file: the first
      20 characters of its name, with filename-unsafe characters replaced by
      "_". Otherwise it is "document".

    The AWS region and inference-server URL come from config (AWS_REGION,
    INFERENCE_SERVER_API_URL).

    Args:
        all_page_line_level_ocr_results_df (pd.DataFrame): Pre-loaded line-level OCR results.
        output_folder (str): Folder where outputs should be saved.
        summarisation_inference_method (str): UI option selecting the inference method.
        summarisation_api_key (str): API key for the selected inference method, if required.
        summarisation_temperature (float): Sampling temperature for the model.
        file_name (str): Base name for output files; derived when blank.
        summarisation_context (str): Additional context to include in the summarisation.
        summarisation_aws_access_key (str): AWS access key if using AWS inference.
        summarisation_aws_secret_key (str): AWS secret key if using AWS inference.
        summarisation_hf_api_key (str): HuggingFace API key if required.
        summarisation_azure_endpoint (str): Endpoint if using Azure inference.
        summarisation_format (str): Summary format instruction passed to ``summarise_document``.
        summarisation_additional_instructions (str): Extra instructions for the LLM.
        summarisation_max_pages_per_group (int): Maximum pages per LLM summarisation pass.
        in_summarisation_ocr_files (str | list | object, optional): Uploaded OCR CSV file(s).

    Returns:
        tuple: (output_files, status_message, llm_model_name,
        llm_total_input_tokens, llm_total_output_tokens, summary_display_text,
        elapsed_seconds).

        If ``summarise_document`` returns fewer than six values, the missing
        ones default to [], "", "", 0, 0, "". This happens on its early
        "No OCR results" exit, which returns only files and a status message.
    """
    # Map the UI option to an internal inference method name
    inference_method_map = {
        AWS_LLM_PII_OPTION: "aws-bedrock",
        LOCAL_TRANSFORMERS_LLM_PII_OPTION: "local",
        INFERENCE_SERVER_PII_OPTION: "inference-server",
    }
    inference_method = inference_method_map.get(
        summarisation_inference_method, "aws-bedrock"
    )
    # Region and server URL always come from config
    summarisation_aws_region = AWS_REGION
    summarisation_api_url = INFERENCE_SERVER_API_URL
    model_choice = get_model_choice_from_inference_method(inference_method)

    # Uploaded CSV files take precedence over the pre-loaded DataFrame
    if in_summarisation_ocr_files:
        ocr_df = load_csv_files_to_dataframe(in_summarisation_ocr_files)
    else:
        ocr_df = all_page_line_level_ocr_results_df

    if not file_name or not str(file_name).strip():
        upload_paths = _summarisation_upload_to_paths(in_summarisation_ocr_files)
        file_name = "document"
        if upload_paths:
            stem, _ = os.path.splitext(os.path.basename(upload_paths[0]))
            filename_prefix = stem[:20]
            for char in '<>:"/\\|?*':
                filename_prefix = filename_prefix.replace(char, "_")
            file_name = filename_prefix if filename_prefix else "document"

    # Call the actual summarise_document function (timed for usage logs)
    start_time = time.perf_counter()
    result = summarise_document(
        ocr_df,
        output_folder,
        model_choice,
        summarisation_api_key,
        summarisation_temperature,
        file_name,
        summarisation_context,
        summarisation_aws_access_key,
        summarisation_aws_secret_key,
        summarisation_aws_region,
        summarisation_hf_api_key,
        summarisation_azure_endpoint,
        summarisation_api_url,
        summarisation_format,
        summarisation_additional_instructions,
        max_pages_per_group=summarisation_max_pages_per_group,
    )
    elapsed_seconds = round(time.perf_counter() - start_time, 1)

    # summarise_document can exit early with only (output_files, status_message);
    # pad so unpacking never fails with "not enough values to unpack".
    default_outputs = ([], "", "", 0, 0, "")
    result = tuple(result) if result is not None else ()
    (
        output_files,
        status_message,
        llm_model_name,
        llm_total_input_tokens,
        llm_total_output_tokens,
        summary_display_text,
    ) = (result + default_outputs[len(result):])[: len(default_outputs)]

    return (
        output_files,
        status_message,
        llm_model_name,
        llm_total_input_tokens,
        llm_total_output_tokens,
        summary_display_text,
        elapsed_seconds,
    )
def _is_missing_ocr_value(value) -> bool:
    """Return True when a scalar cell value is missing (None, NaN, pd.NA, NaT)."""
    # pd.isna on a list-like returns an array, so only test genuine scalars.
    return pd.api.types.is_scalar(value) and bool(pd.isna(value))


def _take_page_group(page_list, start_idx, max_pages, available_tokens):
    """Take consecutive pages from ``page_list`` starting at ``start_idx`` for one group.

    Pages are added until ``max_pages`` pages are taken or the next page would
    push the group over ``available_tokens``. The first page is always taken,
    even if it alone exceeds the budget, so callers always make progress.

    Args:
        page_list: Sequence of (page_num, formatted_page, page_tokens) tuples.
        start_idx: Index of the first page to consider.
        max_pages: Maximum number of pages in this group (>= 1).
        available_tokens: Token budget for the group.

    Returns:
        tuple: (page_numbers, group_text joined with blank lines, next_index).
    """
    page_numbers = []
    page_chunks = []
    used_tokens = 0
    idx = start_idx
    while idx < len(page_list) and len(page_numbers) < max_pages:
        page_num, formatted_page, page_tokens = page_list[idx]
        if page_numbers and used_tokens + page_tokens > available_tokens:
            break  # group full by token budget; next group starts here
        page_numbers.append(page_num)
        page_chunks.append(formatted_page)
        used_tokens += page_tokens
        idx += 1
    return page_numbers, "\n\n".join(page_chunks), idx


def group_pages_by_context_length(
    all_page_line_level_ocr_results_df: pd.DataFrame,
    context_length: int = LLM_CONTEXT_LENGTH,
    tokenizer=None,
    model_source: str = "Local",
    max_pages_per_group: int = 30,
) -> List[Tuple[List[int], str]]:
    """
    Group pages into chunks that fit within the LLM context length.

    Pages are split into roughly equal-sized groups. For example, 56 pages
    with room for 50 per context become two groups of 28, not 50 and 6. Each
    page is prefixed with '=== Page x ==='; the lines of a page are joined
    with spaces in row order.

    Missing text cells (e.g. empty CSV cells read as NaN) count as empty
    strings. Rows without a page number are skipped.

    Args:
        all_page_line_level_ocr_results_df: DataFrame with columns 'page', 'line', 'text'.
        context_length: Maximum context length in tokens. 500 tokens are
            reserved for the prompt template.
        tokenizer: Tokenizer for accurate token counting.
        model_source: Source of the model for token counting.
        max_pages_per_group: Maximum pages per group. Invalid values fall back
            to 30; values below 1 are raised to 1.

    Returns:
        List of tuples: (list of page numbers, formatted text for that group).
    """
    df = all_page_line_level_ocr_results_df
    if df is None or df.empty:
        return []

    # Collect each page's line texts in row order
    page_texts = {}
    text_values = df["text"] if "text" in df.columns else [""] * len(df)
    for page_value, text_value in zip(df["page"], text_values):
        if _is_missing_ocr_value(page_value):
            continue  # int(NaN) would raise; such rows cannot be placed
        # str(NaN) would inject the literal "nan" into the prompt
        text = "" if _is_missing_ocr_value(text_value) else str(text_value)
        page_texts.setdefault(int(page_value), []).append(text)

    # Format each page with its header and count its tokens
    page_list = []  # (page_num, formatted_page, page_tokens)
    for page_num in sorted(page_texts):
        page_text = " ".join(page_texts[page_num])
        formatted_page = f"=== Page {page_num} ===\n{page_text}"
        page_tokens = count_tokens_in_text(formatted_page, tokenizer, model_source)
        page_list.append((page_num, formatted_page, page_tokens))
    if not page_list:
        return []

    reserved_tokens = 500  # headroom for the prompt template
    available_tokens = context_length - reserved_tokens

    try:
        max_pages_per_group_int = int(max_pages_per_group)
    except (TypeError, ValueError, OverflowError):
        max_pages_per_group_int = 30
    max_pages_per_group_int = max(max_pages_per_group_int, 1)

    # Step 1: greedy pass gives the minimum number of groups by tokens
    k_token = 0
    cur_tokens = 0
    for _, _, page_tokens in page_list:
        if cur_tokens > 0 and cur_tokens + page_tokens > available_tokens:
            k_token += 1
            cur_tokens = 0
        cur_tokens += page_tokens
    k_token += 1  # last group

    n = len(page_list)
    k_pages = -(-n // max_pages_per_group_int)  # ceil(n / max pages)
    k = max(k_token, k_pages)

    # Step 2: target sizes for a roughly equal split (56 pages, 2 groups -> 28, 28)
    q, r = divmod(n, k)
    target_per_group = [q + 1] * r + [q] * (k - r)

    # Step 3: fill groups to their targets, respecting the token budget
    groups = []
    page_idx = 0
    for target in target_per_group:
        group_pages, group_text, page_idx = _take_page_group(
            page_list,
            page_idx,
            min(target, max_pages_per_group_int),
            available_tokens,
        )
        if group_pages:
            groups.append((group_pages, group_text))

    # Pages left over (a group hit the token budget before its target) go into
    # extra groups.
    while page_idx < n:
        group_pages, group_text, page_idx = _take_page_group(
            page_list, page_idx, max_pages_per_group_int, available_tokens
        )
        if group_pages:
            groups.append((group_pages, group_text))

    return groups
def summarise_text_chunk(
    text_chunk: str,
    model_choice: str,
    in_api_key: str,
    temperature: float,
    context_textbox: str = "",
    aws_access_key_textbox: str = "",
    aws_secret_key_textbox: str = "",
    aws_region_textbox: str = "",
    model_name_map: dict = None,
    hf_api_key_textbox: str = "",
    azure_endpoint_textbox: str = "",
    api_url: str = None,
    reasoning_suffix: str = "",
    local_model=None,
    tokenizer=None,
    assistant_model=None,
    summarise_format_radio: str = "Return a summary up to two paragraphs long that includes as much detail as possible from the original text",
    additional_summary_instructions: str = "",
) -> Tuple[str, str, dict]:
    """
    Summarise a single text chunk using summarise_output_topics_query.

    The model source is derived from ``model_choice``.

    - Local models: ``load_model()`` supplies whichever of model, tokenizer
      and assistant model were not passed in (only if model or tokenizer is
      missing).
    - AWS models: a Bedrock runtime client is created. Credentials are tried
      in this order: SSO (when prioritised), keys from the UI, SSO, then
      environment keys.
    - Reasoning suffix: ``reasoning_suffix`` is appended to the system prompt
      for gpt-oss models and for Local models.

    Args:
        text_chunk: Text to summarise.
        model_choice: Model identifier.
        in_api_key: API key, forwarded to summarise_output_topics_query.
        temperature: Sampling temperature.
        context_textbox: Optional context about the document, appended to the
            system prompt.
        aws_access_key_textbox / aws_secret_key_textbox: AWS keys from the UI.
        aws_region_textbox: AWS region; falls back to AWS_REGION.
        model_name_map: Accepted for interface compatibility; not used here.
        hf_api_key_textbox: Accepted for interface compatibility; not used here.
        azure_endpoint_textbox: Forwarded to summarise_output_topics_query.
        api_url: Inference-server URL, forwarded.
        reasoning_suffix: Optional suffix for the system prompt (see above).
        local_model, tokenizer, assistant_model: Pre-loaded local model parts.
        summarise_format_radio: Text inserted as the {summary_format} placeholder.
        additional_summary_instructions: Optional extra instructions.

    Returns:
        Tuple of (summary_text, full_prompt, metadata). On any error from the
        LLM call it returns ("", full_prompt, {}) after printing the error.

    Raises:
        Exception: If an AWS model is selected but no Bedrock credentials are available.
    """
    if additional_summary_instructions:
        # The template places this placeholder right after "summary.", so the
        # leading space keeps the sentences separated.
        additional_summary_instructions = (
            " Important additional instructions to follow closely: "
            + additional_summary_instructions
        )

    formatted_summary_prompt = [
        summarise_topic_descriptions_prompt.format(
            summaries=text_chunk,
            summary_format=summarise_format_radio,
            additional_summary_instructions=additional_summary_instructions,
        )
    ]

    formatted_system_prompt = summarise_topic_descriptions_system_prompt.format(
        column_name="document text",
        consultation_context=context_textbox if context_textbox else "",
    )
    # The current system template has no {consultation_context} placeholder,
    # so .format() above drops the user's context; append it explicitly unless
    # the template consumes it itself.
    if (
        context_textbox
        and str(context_textbox).strip()
        and "{consultation_context}" not in summarise_topic_descriptions_system_prompt
    ):
        formatted_system_prompt += f"\nContext about the document: {context_textbox}"

    # Determine model source from model_choice using defaults from config.py
    model_source = get_model_source_from_model_choice(model_choice)

    # Load model and tokenizer together (atomically) so they match.
    if model_source == "Local" and (local_model is None or tokenizer is None):
        # The module-level name ``progress`` is the gradio Progress *class*
        # (imported via an alias), not a tracker, so calling it with
        # (fraction, message) raised TypeError. Log to stdout instead.
        print(f"Using model: {LOCAL_TRANSFORMERS_LLM_PII_MODEL_CHOICE}")
        loaded_model, loaded_tokenizer, loaded_assistant_model = load_model()
        if local_model is None:
            local_model = loaded_model
        if tokenizer is None:
            tokenizer = loaded_tokenizer
        if assistant_model is None:
            assistant_model = loaded_assistant_model

    # Bedrock client for AWS models (same precedence as file_redaction.py)
    bedrock_runtime = None
    if model_source == "AWS":
        region = aws_region_textbox if aws_region_textbox else AWS_REGION
        if RUN_AWS_FUNCTIONS and PRIORITISE_SSO_OVER_AWS_ENV_ACCESS_KEYS:
            print("Connecting to Bedrock via existing SSO connection")
            bedrock_runtime = boto3.client("bedrock-runtime", region_name=region)
        elif aws_access_key_textbox and aws_secret_key_textbox:
            print(
                "Connecting to Bedrock using AWS access key and secret keys from user input."
            )
            bedrock_runtime = boto3.client(
                "bedrock-runtime",
                aws_access_key_id=aws_access_key_textbox,
                aws_secret_access_key=aws_secret_key_textbox,
                region_name=region,
            )
        elif RUN_AWS_FUNCTIONS:
            print("Connecting to Bedrock via existing SSO connection")
            bedrock_runtime = boto3.client("bedrock-runtime", region_name=region)
        elif AWS_ACCESS_KEY and AWS_SECRET_KEY:
            print("Getting Bedrock credentials from environment variables")
            bedrock_runtime = boto3.client(
                "bedrock-runtime",
                aws_access_key_id=AWS_ACCESS_KEY,
                aws_secret_access_key=AWS_SECRET_KEY,
                region_name=region,
            )
        else:
            out_message = "Cannot connect to AWS Bedrock service. Please provide access keys under LLM settings, or choose another model type."
            print(out_message)
            raise Exception(out_message)
    # Gemini and Azure/OpenAI clients are handled inside
    # summarise_output_topics_query via process_requests.

    if reasoning_suffix:
        is_gpt_oss_model = (
            "gpt-oss" in model_choice.lower() or "gpt_oss" in model_choice.lower()
        )
        if is_gpt_oss_model or "Local" in model_source:
            formatted_system_prompt = formatted_system_prompt + "\n" + reasoning_suffix

    full_prompt = formatted_system_prompt + "\n" + formatted_summary_prompt[0]
    try:
        _response, _conversation_history, metadata, response_text = (
            summarise_output_topics_query(
                model_choice,
                in_api_key,
                temperature,
                formatted_summary_prompt,
                formatted_system_prompt,
                model_source,
                bedrock_runtime,
                local_model if local_model is not None else [],
                tokenizer if tokenizer is not None else [],
                assistant_model if assistant_model is not None else [],
                azure_endpoint_textbox,
                api_url,
            )
        )
        return response_text, full_prompt, metadata
    except Exception as e:
        # Best-effort: callers treat an empty summary as "no result".
        print(f"Error summarising text chunk: {e}")
        return "", full_prompt, {}
def recursively_summarise(
    summaries: List[str],
    model_choice: str,
    in_api_key: str,
    temperature: float,
    context_length: int = LLM_CONTEXT_LENGTH,
    tokenizer=None,
    model_source: str = "Local",
    token_accumulator=None,
    **kwargs,
) -> List[str]:
    """
    Recursively summarise summaries until they fit within the context length.

    If the joined summaries already fit (context_length minus 500 reserved
    tokens), they are returned unchanged. Otherwise consecutive summaries are
    packed greedily into groups under that budget, each group is summarised
    with ``summarise_text_chunk``, and the process repeats on the results.

    Recursion stops when any of these holds:
    - at most one summary remains;
    - the result fits the budget;
    - a round fails to reduce the total token count. This guards against
      infinite recursion when the model output does not shrink.

    Args:
        summaries: Summaries to condense.
        model_choice, in_api_key, temperature: Forwarded to summarise_text_chunk.
        context_length: Model context length in tokens.
        tokenizer: Tokenizer used for token counting (also forwarded).
        model_source: Model source used for token counting only;
            summarise_text_chunk derives its own from ``model_choice``.
        token_accumulator: Optional mutable [input_tokens, output_tokens]
            updated in place from each call's metadata.
        **kwargs: Extra keyword arguments for summarise_text_chunk.

    Returns:
        List of condensed summaries. Empty responses are dropped.
    """
    reserved_tokens = 500  # headroom for the prompt template
    available_tokens = context_length - reserved_tokens

    def _tokens(text):
        return count_tokens_in_text(text, tokenizer, model_source)

    total_tokens = _tokens("\n\n".join(summaries))
    if total_tokens <= available_tokens:
        return summaries

    # Greedily pack consecutive summaries into groups under the budget; a
    # single oversized summary forms its own group.
    groups = []
    current_group = []
    current_tokens = 0
    for summary in summaries:
        summary_tokens = _tokens(summary)
        if current_group and current_tokens + summary_tokens > available_tokens:
            groups.append("\n\n".join(current_group))
            current_group = [summary]
            current_tokens = summary_tokens
        else:
            current_group.append(summary)
            current_tokens += summary_tokens
    if current_group:
        groups.append("\n\n".join(current_group))

    new_summaries = []
    for group_text in groups:
        # model_source is NOT forwarded: summarise_text_chunk has no such
        # parameter (passing it raised TypeError) and derives it itself.
        summary_text, _, metadata = summarise_text_chunk(
            group_text,
            model_choice,
            in_api_key,
            temperature,
            tokenizer=tokenizer,
            **kwargs,
        )
        if summary_text:
            new_summaries.append(summary_text)
        # Count tokens for every call that returned metadata, even if the
        # summary text itself came back empty.
        if token_accumulator is not None and metadata:
            metadata_string = (
                metadata if isinstance(metadata, str) else str(metadata)
            )
            input_tokens, output_tokens, _ = calculate_tokens_from_metadata(
                metadata_string, model_choice, model_name_map
            )
            token_accumulator[0] += input_tokens
            token_accumulator[1] += output_tokens

    if len(new_summaries) <= 1:
        return new_summaries

    # Only recurse while the total strictly shrinks, which guarantees termination.
    if _tokens("\n\n".join(new_summaries)) >= total_tokens:
        print(
            "Warning: summaries are no longer shrinking; stopping recursive summarisation."
        )
        return new_summaries

    return recursively_summarise(
        new_summaries,
        model_choice,
        in_api_key,
        temperature,
        context_length,
        tokenizer,
        model_source,
        token_accumulator=token_accumulator,
        **kwargs,
    )
def summarise_document(
all_page_line_level_ocr_results_df: pd.DataFrame,
output_folder: str,
model_choice: str,
in_api_key: str,
temperature: float,
file_name: str = "document",
context_textbox: str = "",
aws_access_key_textbox: str = "",
aws_secret_key_textbox: str = "",
aws_region_textbox: str = "",
hf_api_key_textbox: str = "",
azure_endpoint_textbox: str = "",
api_url: str = None,
summarise_format_radio: str = "Return a summary up to two paragraphs long that includes as much detail as possible from the original text",
additional_summary_instructions: str = "",
max_pages_per_group: int = 30,
summary_page_group_max_workers: Optional[int] = None,
progress=gr.Progress(track_tqdm=True),
) -> Tuple[List[str], str]:
"""
Main function to summarise a document from OCR results.
Args:
all_page_line_level_ocr_results_df (pd.DataFrame): DataFrame containing line-level OCR results.
output_folder (str): The folder where outputs will be saved.
model_choice (str): The model to use for summarization.
in_api_key (str): API key for the selected model/inference method.
temperature (float): LLM temperature hyperparameter.
file_name (str, optional): Name to use for the output files. Default is "document".
context_textbox (str, optional): Extra context for summarization. Default is "".
aws_access_key_textbox (str, optional): AWS access key, if using AWS. Default is "".
aws_secret_key_textbox (str, optional): AWS secret key, if using AWS. Default is "".
aws_region_textbox (str, optional): AWS region string. Default is "".
hf_api_key_textbox (str, optional): HuggingFace API key, if used. Default is "".
azure_endpoint_textbox (str, optional): Azure endpoint, if used. Default is "".
api_url (str, optional): API URL. Default is None.
summarise_format_radio (str, optional): Summary output format instructions. Default is detailed summary.
additional_summary_instructions (str, optional): Extra instructions for the summarization. Default is "".
max_pages_per_group (int, optional): Maximum number of pages to group per LLM pass. Default is 30.
progress (gr.Progress, optional): Gradio progress tracker. Default is Gradio Progress with tqdm.
Returns:
Tuple of (output_file_paths, status_message)
"""
import os
from datetime import datetime
from tools.llm_funcs import load_model
output_files = []
all_prompts = []
all_responses = []
all_token_counts = (
[]
) # Store (input_tokens, output_tokens) for each prompt/response
page_group_page_ranges = (
[]
) # Store (min_page, max_page) for each saved prompt/response
page_group_summaries = []
# Initialize token tracking variables
llm_total_input_tokens = 0
llm_total_output_tokens = 0
llm_model_name = ""
try:
# Determine model source from model_choice using defaults from config.py
# Does not check model_name_map - uses the defined defaults
model_source = get_model_source_from_model_choice(model_choice)
local_model = None
tokenizer = None
assistant_model = None
# Setup model based on model source - check for Local models
# Load model and tokenizer together to ensure they're from the same source
# This prevents mismatches that could occur if they're loaded separately
# Similar to llm_funcs.py pattern (lines 830-839) and llm_entity_detection.py (lines 519-533)
if model_source == "Local":
if local_model is None or tokenizer is None:
progress(0.05, "Loading local model...")
# Use load_model() to ensure both are loaded atomically
# This is safer than calling get_pii_model() and get_pii_tokenizer() separately
loaded_model, loaded_tokenizer, loaded_assistant_model = load_model()
if local_model is None:
local_model = loaded_model
if tokenizer is None:
tokenizer = loaded_tokenizer
if assistant_model is None:
assistant_model = loaded_assistant_model
# Step 1: Group pages by context length
progress(0.1, "Grouping pages by context length...")
page_groups = group_pages_by_context_length(
all_page_line_level_ocr_results_df,
LLM_CONTEXT_LENGTH,
tokenizer,
model_source,
max_pages_per_group=max_pages_per_group,
)
if not page_groups:
return [], "No OCR results found. Please run text extraction first."
# Step 2: Summarise each page group (optionally in parallel)
_summary_page_group_max_workers = (
summary_page_group_max_workers
if summary_page_group_max_workers is not None
else SUMMARY_PAGE_GROUP_MAX_WORKERS
)
use_parallel_page_groups = (
_summary_page_group_max_workers > 1 and len(page_groups) > 1
)
progress(0.2, f"Summarising {len(page_groups)} page groups...")
def _summarise_one_group(args):
i, page_nums, group_text = args
summary_text, full_prompt, metadata = summarise_text_chunk(
group_text,
model_choice,
in_api_key,
temperature,
context_textbox=context_textbox,
aws_access_key_textbox=aws_access_key_textbox,
aws_secret_key_textbox=aws_secret_key_textbox,
aws_region_textbox=aws_region_textbox,
hf_api_key_textbox=hf_api_key_textbox,
azure_endpoint_textbox=azure_endpoint_textbox,
api_url=api_url,
local_model=local_model,
tokenizer=tokenizer,
assistant_model=assistant_model,
summarise_format_radio=summarise_format_radio,
additional_summary_instructions=additional_summary_instructions,
)
return (i, page_nums, summary_text, full_prompt, metadata)
if use_parallel_page_groups:
max_workers = min(_summary_page_group_max_workers, len(page_groups))
tasks = [
(i, page_nums, group_text)
for i, (page_nums, group_text) in enumerate(page_groups)
]
results_by_index = [None] * len(page_groups)
pbar = tqdm(
total=len(page_groups),
unit="groups",
desc="Summarising page groups",
)
with ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = {
executor.submit(_summarise_one_group, t): t[0] for t in tasks
}
completed = 0
for future in as_completed(futures):
i, page_nums, summary_text, full_prompt, metadata = future.result()
results_by_index[i] = (
page_nums,
summary_text,
full_prompt,
metadata,
)
completed += 1
pbar.update(1)
progress(
0.2 + (completed / len(page_groups)) * 0.5,
f"Summarising page group {completed}/{len(page_groups)} (pages {min(page_nums)}-{max(page_nums)})...",
)
pbar.close()
# Build lists in page-group order
for i in range(len(page_groups)):
if results_by_index[i] is None:
continue
page_nums, summary_text, full_prompt, metadata = results_by_index[i]
if summary_text:
try:
min_page = int(min(page_nums)) if page_nums else 0
max_page = int(max(page_nums)) if page_nums else 0
except Exception:
min_page, max_page = 0, 0
page_group_page_ranges.append((min_page, max_page))
page_group_summaries.append(summary_text)
all_prompts.append(full_prompt)
all_responses.append(summary_text)
input_tokens, output_tokens = 0, 0
if metadata:
metadata_string = (
str(metadata) if not isinstance(metadata, str) else metadata
)
input_tokens, output_tokens, _ = calculate_tokens_from_metadata(
metadata_string, model_choice, model_name_map
)
llm_total_input_tokens += input_tokens
llm_total_output_tokens += output_tokens
if not llm_model_name and model_choice:
llm_model_name = model_choice
all_token_counts.append((input_tokens, output_tokens))
else:
seq_pbar = tqdm(
page_groups,
unit="groups",
desc="Summarising page groups",
)
for i, (page_nums, group_text) in enumerate(seq_pbar):
progress(
0.2 + (i / len(page_groups)) * 0.5,
f"Summarising page group {i+1}/{len(page_groups)} (pages {min(page_nums)}-{max(page_nums)})...",
)
summary_text, full_prompt, metadata = summarise_text_chunk(
group_text,
model_choice,
in_api_key,
temperature,
context_textbox=context_textbox,
aws_access_key_textbox=aws_access_key_textbox,
aws_secret_key_textbox=aws_secret_key_textbox,
aws_region_textbox=aws_region_textbox,
hf_api_key_textbox=hf_api_key_textbox,
azure_endpoint_textbox=azure_endpoint_textbox,
api_url=api_url,
local_model=local_model,
tokenizer=tokenizer,
assistant_model=assistant_model,
summarise_format_radio=summarise_format_radio,
additional_summary_instructions=additional_summary_instructions,
)
if summary_text:
try:
min_page = int(min(page_nums)) if page_nums else 0
max_page = int(max(page_nums)) if page_nums else 0
except Exception:
min_page, max_page = 0, 0
page_group_page_ranges.append((min_page, max_page))
page_group_summaries.append(summary_text)
all_prompts.append(full_prompt)
all_responses.append(summary_text)
input_tokens, output_tokens = 0, 0
if metadata:
metadata_string = (
str(metadata) if not isinstance(metadata, str) else metadata
)
input_tokens, output_tokens, _ = calculate_tokens_from_metadata(
metadata_string, model_choice, model_name_map
)
llm_total_input_tokens += input_tokens
llm_total_output_tokens += output_tokens
if not llm_model_name and model_choice:
llm_model_name = model_choice
all_token_counts.append((input_tokens, output_tokens))
seq_pbar.close()
# Step 3: Recursively summarise if needed
progress(0.7, "Checking if recursive summarisation is needed...")
# Create token accumulator for recursive summarization
recursive_token_accumulator = [0, 0] # [input_tokens, output_tokens]
final_summaries = recursively_summarise(
page_group_summaries,
model_choice,
in_api_key,
temperature,
context_length=LLM_CONTEXT_LENGTH,
tokenizer=tokenizer,
model_source=model_source,
token_accumulator=recursive_token_accumulator,
context_textbox=context_textbox,
aws_access_key_textbox=aws_access_key_textbox,
aws_secret_key_textbox=aws_secret_key_textbox,
aws_region_textbox=aws_region_textbox,
hf_api_key_textbox=hf_api_key_textbox,
azure_endpoint_textbox=azure_endpoint_textbox,
api_url=api_url,
local_model=local_model,
assistant_model=assistant_model,
summarise_format_radio=summarise_format_radio,
additional_summary_instructions=additional_summary_instructions,
)
# Add tokens from recursive summarization
llm_total_input_tokens += recursive_token_accumulator[0]
llm_total_output_tokens += recursive_token_accumulator[1]
# Step 4: Create overall summary
progress(0.85, "Creating overall summary...")
# Create a topic summary DataFrame for overall_summary: three columns only
summary_numbers = list(range(1, len(final_summaries) + 1))
if len(final_summaries) == len(page_groups):
page_ranges = [f"Pages {min(pg[0])}-{max(pg[0])}" for pg in page_groups]
else:
# Recursion combined some summaries - use "All" or full range
if len(final_summaries) == 1 and page_groups:
all_pages = [p for pg in page_groups for p in pg[0]]
page_ranges = [f"Pages {min(all_pages)}-{max(all_pages)}"]
else:
page_ranges = ["All"] * len(final_summaries)
topic_summary_df = pd.DataFrame(
{
"Summary number": summary_numbers,
"Page range": page_ranges,
"Summary": final_summaries,
}
)
# Call overall_summary
(
output_files,
html_output_table,
overall_summarised_outputs_df,
out_metadata_str,
overall_input_tokens,
overall_output_tokens,
number_of_calls_num,
time_taken,
out_message,
overall_logged_content,
overall_prompt,
overall_response,
) = overall_summary(
topic_summary_df=topic_summary_df,
model_choice=model_choice,
in_api_key=in_api_key,
temperature=temperature,
reference_data_file_name=file_name,
output_folder=output_folder,
context_textbox=context_textbox,
aws_access_key_textbox=aws_access_key_textbox,
aws_secret_key_textbox=aws_secret_key_textbox,
aws_region_textbox=aws_region_textbox,
hf_api_key_textbox=hf_api_key_textbox,
azure_endpoint_textbox=azure_endpoint_textbox,
api_url=api_url,
local_model=local_model,
tokenizer=tokenizer,
assistant_model=assistant_model,
summarise_format_radio=summarise_format_radio,
additional_summary_instructions=additional_summary_instructions,
progress=progress,
)
llm_total_input_tokens += overall_input_tokens
llm_total_output_tokens += overall_output_tokens
# Extract summary texts from the DataFrame
if (
overall_summarised_outputs_df is not None
and not overall_summarised_outputs_df.empty
):
if "Summary" in overall_summarised_outputs_df.columns:
overall_summary_texts = overall_summarised_outputs_df[
"Summary"
].tolist()
else:
# Fallback: get from first column if "Summary" column doesn't exist
overall_summary_texts = overall_summarised_outputs_df.iloc[
:, 0
].tolist()
else:
overall_summary_texts = []
# Step 5: Save outputs
progress(0.95, "Saving output files...")
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
file_name_clean = get_file_name_no_ext(file_name) if file_name else "document"
# Ensure file_name_clean is not empty
if not file_name_clean or file_name_clean.strip() == "":
file_name_clean = "document"
summaries_folder = os.path.join(output_folder, "summaries")
os.makedirs(summaries_folder, exist_ok=True)
# Save prompts and responses as .txt files for page group summaries
for i, (prompt, response) in enumerate(zip(all_prompts, all_responses)):
# Page range for this prompt/response pair
min_page, max_page = (
page_group_page_ranges[i] if i < len(page_group_page_ranges) else (0, 0)
)
page_range_slug = f"pages_{min_page}_{max_page}"
txt_file_path = os.path.join(
summaries_folder,
f"{file_name_clean}_{page_range_slug}_prompt_response_{timestamp}.txt",
)
# Get token counts for this prompt/response pair
input_tokens, output_tokens = (
all_token_counts[i] if i < len(all_token_counts) else (0, 0)
)
with open(txt_file_path, "w", encoding="utf-8") as f:
f.write("=" * 80 + "\n")
f.write("TOKEN INFORMATION\n")
f.write("=" * 80 + "\n")
f.write(f"Page Range: {min_page}-{max_page}\n")
f.write(f"Input Tokens: {input_tokens}\n")
f.write(f"Output Tokens: {output_tokens}\n")
f.write(f"Maximum Context Length: {LLM_CONTEXT_LENGTH}\n")
f.write(f"Model: {model_choice}\n")
f.write(f"Temperature: {temperature}\n")
f.write("=" * 80 + "\n\n")
f.write("=" * 80 + "\n")
f.write("PROMPT\n")
f.write("=" * 80 + "\n")
f.write(prompt)
f.write("\n\n" + "=" * 80 + "\n")
f.write("RESPONSE\n")
f.write("=" * 80 + "\n")
f.write(response)
output_files.append(txt_file_path)
# Save overall summary prompt/response
# Fallback: If we don't have prompt/response from logged_content, use summary texts
# This should rarely happen, but provides a safety net
if not overall_prompt and overall_summary_texts:
# Construct a basic prompt representation (this is a fallback, not ideal)
overall_prompt = (
f"Overall summary request for document: {file_name_clean}\n"
)
overall_prompt += f"Input: {len(final_summaries)} summary group(s) to combine into overall summary\n"
overall_prompt += f"Summary format: {summarise_format_radio}\n"
if additional_summary_instructions:
overall_prompt += (
f"Additional instructions: {additional_summary_instructions}\n"
)
# If we still don't have a response, use summary texts
if not overall_response and overall_summary_texts:
overall_response = (
"\n\n".join(overall_summary_texts)
if isinstance(overall_summary_texts, list)
else str(overall_summary_texts)
)
# Save overall summary .txt file if we have response content (always create if we have summary texts)
if overall_response or overall_summary_texts:
txt_file_path = os.path.join(
summaries_folder,
f"{file_name_clean}_overall_summary_prompt_response_{timestamp}.txt",
)
with open(txt_file_path, "w", encoding="utf-8") as f:
f.write("=" * 80 + "\n")
f.write("TOKEN INFORMATION\n")
f.write("=" * 80 + "\n")
f.write(f"Input Tokens: {overall_input_tokens}\n")
f.write(f"Output Tokens: {overall_output_tokens}\n")
f.write(f"Maximum Context Length: {LLM_CONTEXT_LENGTH}\n")
f.write(f"Model: {model_choice}\n")
f.write(f"Temperature: {temperature}\n")
f.write("=" * 80 + "\n\n")
f.write("=" * 80 + "\n")
f.write("PROMPT\n")
f.write("=" * 80 + "\n")
f.write(overall_prompt)
f.write("\n\n" + "=" * 80 + "\n")
f.write("RESPONSE\n")
f.write("=" * 80 + "\n")
f.write(overall_response)
output_files.append(txt_file_path)
# Save summaries as CSV
summary_data = {"Type": [], "Page_Range": [], "Summary": []}
# Add page group summaries
for i, (page_nums, summary) in enumerate(
zip([pg[0] for pg in page_groups], page_group_summaries)
):
summary_data["Type"].append("Page Group Summary")
summary_data["Page_Range"].append(f"{min(page_nums)}-{max(page_nums)}")
summary_data["Summary"].append(summary)
# Add final summaries if different from page group summaries
if final_summaries != page_group_summaries:
for i, summary in enumerate(final_summaries):
summary_data["Type"].append("Final Summary")
summary_data["Page_Range"].append(f"Group {i+1}")
summary_data["Summary"].append(summary)
# Add overall summary - ensure overall_summary_texts is a list of strings
if overall_summary_texts:
# Handle case where overall_summary_texts might be a single string
if isinstance(overall_summary_texts, str):
overall_summary_texts = [overall_summary_texts]
# Ensure each item is a string, not being iterated character by character
for summary in overall_summary_texts:
if isinstance(summary, str):
summary_data["Type"].append("Overall Summary")
summary_data["Page_Range"].append("All")
summary_data["Summary"].append(summary)
elif hasattr(summary, "__iter__") and not isinstance(summary, str):
# If it's iterable but not a string, convert to string
summary_str = str(summary)
summary_data["Type"].append("Overall Summary")
summary_data["Page_Range"].append("All")
summary_data["Summary"].append(summary_str)
summary_df = pd.DataFrame(summary_data)
csv_file_path = os.path.join(
summaries_folder, f"{file_name_clean}_summaries_{timestamp}.csv"
)
summary_df.to_csv(csv_file_path, index=False, encoding="utf-8-sig")
output_files.append(csv_file_path)
progress(1.0, "Summarisation complete!")
status_message = (
f"Summarisation complete! Generated {len(output_files)} output files."
)
# Prepare summary text for display (combine all overall summary texts)
summary_display_text = ""
if overall_summary_texts:
if isinstance(overall_summary_texts, list):
summary_display_text = "\n\n".join(overall_summary_texts)
else:
summary_display_text = str(overall_summary_texts)
return (
output_files,
status_message,
llm_model_name,
llm_total_input_tokens,
llm_total_output_tokens,
summary_display_text,
)
except Exception as e:
error_message = f"Error during summarisation: {str(e)}"
print(error_message)
import traceback
traceback.print_exc()
return (
output_files,
error_message,
llm_model_name,
llm_total_input_tokens,
llm_total_output_tokens,
"", # Empty summary display text on error
)
def join_unique_summaries(x):
    """
    Merge summary texts into one newline-separated string of unique, cleaned lines.

    Each non-NA item is converted to ``str`` and split into lines. Every line is cleaned:
    a leading "Rows N to M:" marker is removed (case-insensitive), a short single-word
    label such as "Summary: " is stripped, and runs of whitespace (including
    non-breaking spaces) collapse to one space. Empty lines and lines already seen
    are dropped; first-seen order is kept.

    Args:
        x: Iterable of summary values (e.g. a pandas Series inside ``GroupBy.agg``).

    Returns:
        str: The unique cleaned lines joined with ``"\\n"``.
    """
    rows_prefix = re.compile(r"^Rows\s+\d+\s+to\s+\d+:\s*", flags=re.IGNORECASE)
    whitespace_run = re.compile(r"\s+")

    def _clean(raw_line):
        text = rows_prefix.sub("", raw_line).strip()
        # Only drop a label when it is short and a single word, so genuine
        # sentences containing ": " are left intact.
        label, sep, remainder = text.partition(": ")
        if sep and len(label) < 50 and " " not in label:
            text = remainder.strip()
        return whitespace_run.sub(" ", text).strip()

    # A dict keeps insertion order and gives O(1) duplicate checks.
    ordered_lines = {}
    for item in x:
        if pd.isna(item):
            continue
        for raw_line in str(item).strip().split("\n"):
            cleaned = _clean(raw_line)
            if cleaned:
                ordered_lines.setdefault(cleaned, None)
    return "\n".join(ordered_lines)
def sample_reference_table_summaries(
    reference_df: pd.DataFrame,
    random_seed: int,
    no_of_sampled_summaries: int = 100,
    sample_reference_table_checkbox: bool = False,
):
    """
    Optionally down-sample topic summaries so the summarisation prompt stays short.

    When ``sample_reference_table_checkbox`` is False the reference table is returned
    untouched. Otherwise, for each (General topic, Subtopic, Sentiment, Group) group
    with more than one row, up to ``no_of_sampled_summaries`` rows (unique on the
    grouping keys plus "Start row of group") are sampled with ``random_seed``. The
    samples are then regrouped: "Response References" becomes a row count and
    "Summary" is merged with ``join_unique_summaries``. Rows whose Sentiment is
    "Not Mentioned", or with a count of 1 or less, are removed. If nothing was
    sampled, the original table is used, filtered on Sentiment only.

    Note: a "Group" column set to "All" is added to ``reference_df`` in place when
    it is missing.

    Args:
        reference_df (pd.DataFrame): Reference table of topic assignments.
        random_seed (int): Seed passed to ``DataFrame.sample``.
        no_of_sampled_summaries (int): Maximum rows sampled per group.
        sample_reference_table_checkbox (bool): Whether to sample at all.

    Returns:
        tuple: (sampled DataFrame, that DataFrame rendered as a markdown table).

    Raises:
        Exception: If ``reference_df`` already has a "Revised summary" column.
    """
    if not sample_reference_table_checkbox:
        untouched_df = reference_df
        return untouched_df, untouched_df.to_markdown(index=False)

    topic_keys = ["General topic", "Subtopic", "Sentiment", "Group"]
    all_summaries = pd.DataFrame(
        columns=topic_keys + ["Response References", "Summary"]
    )

    if "Group" not in reference_df.columns:
        reference_df["Group"] = "All"

    grouped_reference = reference_df.groupby(topic_keys)

    if "Revised summary" in reference_df.columns:
        out_message = "Summary has already been created for this file"
        print(out_message)
        raise Exception(out_message)

    for group_keys, group_df in grouped_reference:
        # Single-row groups are not sampled.
        if len(group_df["General topic"]) <= 1:
            continue
        unique_rows = group_df.reset_index().drop_duplicates(
            subset=topic_keys + ["Start row of group"]
        )
        # The cap applies per group, not across all groups.
        sample_size = min(no_of_sampled_summaries, len(unique_rows))
        print(
            f"Sampling {sample_size} summaries from group {group_keys}, from dataframe filtered_reference_df_unique.head(5):\n{unique_rows.head(5)}"
        )
        sampled_rows = unique_rows.sample(sample_size, random_state=random_seed)
        all_summaries = pd.concat([all_summaries, sampled_rows])
        print("all_summaries.tail(5):\n", all_summaries.tail(5))

    if all_summaries.empty:
        # Nothing qualified: "Response References" here is still the raw string
        # column, so only the sentiment filter applies.
        result_df = reference_df.loc[reference_df["Sentiment"] != "Not Mentioned"]
    else:
        aggregated = all_summaries.groupby(topic_keys).agg(
            {
                "Response References": "size",
                "Summary": join_unique_summaries,
            }
        ).reset_index()
        keep_mask = (aggregated["Sentiment"] != "Not Mentioned") & (
            aggregated["Response References"] > 1
        )
        result_df = aggregated.loc[keep_mask]

    return result_df, result_df.to_markdown(index=False)
def count_tokens_in_text(text: str, tokenizer=None, model_source: str = "Local") -> int:
    """
    Count the number of tokens in the given text.

    For local models, an actual tokenizer is used when one is available. ``tokenizer``
    may be a bare tokenizer object (anything with an ``encode`` method) or a
    list/tuple whose first element is the tokenizer. Otherwise the count is
    estimated as ``int(word_count * 1.3)``.

    Args:
        text (str): The text to count tokens for.
        tokenizer (object, optional): Tokenizer, or a list/tuple wrapping one, for
            local models. Defaults to None.
        model_source (str): Source of the model; only "Local" uses the tokenizer.
            Defaults to "Local".

    Returns:
        int: Number of tokens (exact or estimated) in the text; 0 for empty text.
    """
    if not text:
        return 0

    def _estimate() -> int:
        # Rough heuristic: roughly 1.3 tokens per whitespace-separated word.
        return int(len(text.split()) * 1.3)

    if model_source != "Local" or tokenizer is None:
        return _estimate()

    # Callers pass either a bare tokenizer or a list wrapper around one. The
    # previous code always indexed [0], which breaks on bare tokenizer objects.
    if isinstance(tokenizer, (list, tuple)):
        if not tokenizer:
            return _estimate()
        tokenizer = tokenizer[0]

    if not hasattr(tokenizer, "encode"):
        return _estimate()

    try:
        return len(tokenizer.encode(text, add_special_tokens=False))
    except Exception as e:
        print(f"Error counting tokens: {e}. Using word count estimation.")
        return _estimate()
def clean_markdown_table_whitespace(markdown_text: str) -> str:
    """
    Normalise whitespace in a markdown table so every cell is single-spaced.

    - Collapses any run of whitespace (including non-breaking spaces) inside a cell to
      one space, and drops rows that contain only pipes/whitespace.
    - Rewrites separator rows (only ``|``, ``-``, ``:`` and whitespace) to short fixed-
      width markers, keeping alignment colons (e.g. ``|:---|---:|`` -> ``|:----|----:|``).
    - Re-emits data rows as ``| a | b |``. Only the empty pieces produced by leading or
      trailing pipes are removed; genuinely empty interior cells are preserved.

    Args:
        markdown_text (str): Markdown text containing a table.

    Returns:
        str: The cleaned table; falsy input is returned unchanged.
    """
    if not markdown_text:
        return markdown_text

    cleaned_lines = []
    for line in markdown_text.splitlines():
        cells = [re.sub(r"[\s\u00A0]+", " ", cell.strip()) for cell in line.split("|")]

        # Skip "ghost" rows that have no content once pipes and whitespace are removed.
        if not "".join(cells).strip():
            continue

        # Separator row: reset to a small fixed width so it doesn't stretch the table.
        if re.match(r"^[|\s\-:]+$", line):
            new_separator = []
            for cell in cells:
                if not cell:  # Outer pipes (or an empty interior segment)
                    new_separator.append("")
                elif ":" in cell:  # Alignment markers
                    left = ":" if cell.startswith(":") else "-"
                    right = ":" if cell.endswith(":") else "-"
                    new_separator.append(f"{left}---{right}")
                else:
                    new_separator.append("---")
            cleaned_lines.append("|".join(new_separator))
            continue

        # Data row: drop only the empty first/last pieces created by outer pipes.
        # Use the real position: cells.index("") always returns the FIRST empty cell,
        # which previously dropped empty interior cells and shifted columns.
        last_index = len(cells) - 1
        kept_cells = [
            cell
            for index, cell in enumerate(cells)
            if cell or index not in (0, last_index)
        ]
        cleaned_lines.append("| " + " | ".join(kept_cells) + " |")

    return "\n".join(cleaned_lines)
def summarise_output_topics_query(
    model_choice: str,
    in_api_key: str,
    temperature: float,
    formatted_summary_prompt: str,
    summarise_topic_descriptions_system_prompt: str,
    model_source: str,
    bedrock_runtime: boto3.Session.client,
    local_model=list(),
    tokenizer=list(),
    assistant_model=list(),
    azure_endpoint: str = "",
    api_url: str = None,
):
    """
    Query an LLM to generate a summary of topics based on the provided prompts.

    Args:
        model_choice (str): The name/type of model to use for generation.
        in_api_key (str): API key for accessing the model service (used for Gemini).
        temperature (float): Temperature parameter for controlling randomness in generation.
        formatted_summary_prompt (str | list[str]): The formatted prompt(s) containing
            topics to summarise. For a list, only the first prompt is used for the
            context-length check.
        summarise_topic_descriptions_system_prompt (str): System prompt providing context
            and instructions.
        model_source (str): Source of the model (e.g. "AWS", "Gemini", "Azure/OpenAI", "Local").
        bedrock_runtime (boto3.Session.client): AWS Bedrock runtime client for AWS models.
        local_model (object, optional): Local model object if using local inference.
            Defaults to empty list.
        tokenizer (object, optional): Tokenizer object (or list wrapping one) if using
            local inference. Defaults to empty list.
        assistant_model (object, optional): Assistant model forwarded to
            ``process_requests``. Defaults to empty list.
        azure_endpoint (str, optional): Endpoint for Azure/OpenAI models.
        api_url (str, optional): Inference-server URL forwarded to ``process_requests``.

    Returns:
        tuple: Contains:
            - summarised_output (str): Response with blank lines collapsed and newlines
              converted to ``<br>`` tags.
            - conversation_history (list): Conversation history, with the system prompt
              inserted first if it was missing.
            - whole_conversation_metadata (list): Metadata about the conversation.
            - response_text (str): The raw response text from the model.

    Raises:
        ValueError: If the system prompt plus the (first) prompt exceeds LLM_CONTEXT_LENGTH.
    """
    conversation_history = list()
    whole_conversation_metadata = list()
    client = list()
    client_config = {}

    # Combine the system prompt and user prompt for token counting.
    if isinstance(formatted_summary_prompt, list):
        first_prompt = formatted_summary_prompt[0] if formatted_summary_prompt else ""
    else:
        first_prompt = formatted_summary_prompt
    full_input_text = summarise_topic_descriptions_system_prompt + "\n" + first_prompt

    # Count tokens in the input text.
    input_token_count = count_tokens_in_text(full_input_text, tokenizer, model_source)

    # Check whether the input exceeds the context length.
    if input_token_count > LLM_CONTEXT_LENGTH:
        error_message = f"Input text exceeds LLM context length. Input tokens: {input_token_count}, Max context length: {LLM_CONTEXT_LENGTH}. Please reduce the input text size."
        print(error_message)
        raise ValueError(error_message)

    print(f"Input token count: {input_token_count} (Max: {LLM_CONTEXT_LENGTH})")

    # Build API clients where needed. The returned config must be kept and passed to
    # process_requests; it was previously bound to an unused name and discarded.
    if "Gemini" in model_source:
        client, client_config = construct_gemini_generative_model(
            in_api_key=in_api_key,
            temperature=temperature,
            model_choice=model_choice,
            # Use this call's system prompt, not the module-level generic one.
            system_prompt=summarise_topic_descriptions_system_prompt,
            max_tokens=max_tokens,
        )
    elif "Azure/OpenAI" in model_source:
        # NOTE(review): the Azure credential comes only from the environment;
        # in_api_key is not used here. Confirm this is intended.
        client, client_config = construct_azure_client(
            in_api_key=os.environ.get("AZURE_INFERENCE_CREDENTIAL", ""),
            endpoint=azure_endpoint,
        )
    # "Local" and "AWS" sources need no client construction here.

    whole_conversation = [summarise_topic_descriptions_system_prompt]

    # Process requests to the large language model.
    (
        responses,
        conversation_history,
        whole_conversation,
        whole_conversation_metadata,
        response_text,
    ) = process_requests(
        formatted_summary_prompt,
        summarise_topic_descriptions_system_prompt,
        conversation_history,
        whole_conversation,
        whole_conversation_metadata,
        client,
        client_config,
        model_choice,
        temperature,
        bedrock_runtime=bedrock_runtime,
        model_source=model_source,
        local_model=local_model,
        tokenizer=tokenizer,
        assistant_model=assistant_model,
        assistant_prefill=summary_assistant_prefill,
        api_url=api_url,
    )

    summarised_output = response_text or ""
    # Replace multiple line breaks with a single line break.
    summarised_output = re.sub(r"\n{2,}", "\n", summarised_output)
    # Remove one or more line breaks at the start.
    summarised_output = re.sub(r"^\n{1,}", "", summarised_output)
    # Replace \n with more HTML-friendly <br> tags.
    summarised_output = re.sub(r"\n", "<br>", summarised_output)
    summarised_output = summarised_output.strip()

    print("Finished summary query")

    # Ensure the system prompt is included in the conversation history (best effort).
    try:
        if isinstance(conversation_history, list):
            has_system_prompt = False
            if conversation_history:
                first_entry = conversation_history[0]
                if isinstance(first_entry, dict):
                    role_is_system = first_entry.get("role") == "system"
                    parts = first_entry.get("parts")
                    content_matches = parts == summarise_topic_descriptions_system_prompt or (
                        isinstance(parts, list)
                        and summarise_topic_descriptions_system_prompt in parts
                    )
                    has_system_prompt = role_is_system and content_matches
                elif isinstance(first_entry, str):
                    has_system_prompt = first_entry.strip().lower().startswith("system:")

            if not has_system_prompt:
                conversation_history.insert(
                    0,
                    {
                        "role": "system",
                        "parts": [summarise_topic_descriptions_system_prompt],
                    },
                )
    except Exception as e:
        # Non-fatal: keep the original conversation history, but make the failure visible.
        print(f"Could not add system prompt to conversation history: {e}")

    return (
        summarised_output,
        conversation_history,
        whole_conversation_metadata,
        response_text,
    )
def process_debug_output_iteration(
    output_debug_files: str,
    summaries_folder: str,
    batch_file_path_details: str,
    model_choice_clean_short: str,
    final_system_prompt: str,
    summarised_output: str,
    conversation_history: list,
    metadata: list,
    log_output_files: list,
    task_type: str,
) -> tuple[str, str, str, str]:
    """
    Writes debug files for summary generation if output_debug_files is "True",
    and returns the content of the prompt, summary, conversation, and metadata for the current iteration.

    Args:
        output_debug_files (str): Flag to indicate if debug files should be written ("True" to write).
        summaries_folder (str): The folder where output files are saved. It is concatenated
            directly with ``batch_file_path_details``, so it should end with a path separator.
        batch_file_path_details (str): Details for the batch file path.
        model_choice_clean_short (str): Shortened cleaned model choice.
        final_system_prompt (str): The system prompt content.
        summarised_output (str): The summarised output content.
        conversation_history (list): The full conversation history. Entries may be strings,
            ``{"role": ..., "parts": ...}`` dicts, or a mix of both.
        metadata (list): The metadata for the conversation.
        log_output_files (list): A list to append paths of written log files. This list is modified in-place.
        task_type (str): The type of task being performed.

    Returns:
        tuple[str, str, str, str]: A tuple containing the content of the prompt,
        summarised output, conversation history (as string),
        and metadata (as string) for the current iteration.
    """

    def _format_entry(entry) -> str:
        # Decide per entry, not from the first element: histories can mix a dict
        # system entry with plain strings. Only real dicts may be indexed by key;
        # `"role" in some_string` is a substring test, not a key check.
        if isinstance(entry, dict):
            if "role" in entry and "parts" in entry:
                role = str(entry["role"]).capitalize()
                parts = entry["parts"]
                if isinstance(parts, list):
                    message = " ".join(str(part) for part in parts)
                else:
                    message = str(parts)
                return f"{role}: {message}"
            # Fallback for unexpected dict format.
            return str(entry)
        return entry if isinstance(entry, str) else str(entry)

    current_prompt_content = final_system_prompt
    current_summary_content = summarised_output

    if isinstance(conversation_history, list):
        current_conversation_content = "\n".join(
            _format_entry(entry) for entry in conversation_history
        )
    else:
        current_conversation_content = str(conversation_history)

    current_metadata_content = str(metadata)

    if output_debug_files == "True":
        try:
            base_path = summaries_folder + batch_file_path_details
            suffix = "_" + model_choice_clean_short + "_" + task_type + ".txt"
            outputs = [
                (base_path + "_full_prompt" + suffix, current_prompt_content),
                (base_path + "_full_response" + suffix, current_summary_content),
                (base_path + "_full_conversation" + suffix, current_conversation_content),
                (base_path + "_metadata" + suffix, current_metadata_content),
            ]
            for output_path, content in outputs:
                with open(output_path, "w", encoding="utf-8-sig", errors="replace") as f:
                    f.write(content)
            # Record paths only after all four files were written successfully.
            log_output_files.extend(output_path for output_path, _ in outputs)
        except Exception as e:
            print(f"Error in writing debug files for summary: {e}")

    # Return the content of the objects for the current iteration.
    # The caller can then append these to separate lists if accumulation is desired.
    return (
        current_prompt_content,
        current_summary_content,
        current_conversation_content,
        current_metadata_content,
    )
def convert_markdown_headers_to_excel_format(text: str) -> str:
    """
    Rewrite markdown ATX headers as plain-text markers that read well in a spreadsheet cell.

    A line counts as a header when it is 1-6 ``#`` characters, then whitespace, then text.
    The header text is stripped and rendered by level:

    - H1 -> ``=== TEXT ===`` (upper-cased)
    - H2 -> ``--- Text ---`` (title-cased)
    - H3 -> ``── Text ──`` (title-cased)
    - H4 -> ``• Text``
    - H5 / H6 -> ``• Text`` with a leading-space prefix

    All other lines pass through unchanged.

    Args:
        text (str): Text possibly containing markdown headers.

    Returns:
        str: The converted text; falsy input is returned as-is.
    """
    if not text:
        return text

    formatters = {
        1: lambda title: f"=== {title.upper()} ===",
        2: lambda title: f"--- {title.title()} ---",
        3: lambda title: f"── {title.title()} ──",
        4: lambda title: f"• {title}",
        5: lambda title: f" • {title}",
        6: lambda title: f" • {title}",
    }
    header_pattern = re.compile(r"^(#{1,6})\s+(.+)$")

    def _convert(line: str) -> str:
        found = header_pattern.match(line)
        if found is None:
            return line
        hashes, title = found.groups()
        return formatters[len(hashes)](title.strip())

    return "\n".join(_convert(line) for line in text.split("\n"))
def _escape_markdown_table_cell(value):
    """Return ``value`` made safe for a single-line markdown table cell.

    Markdown tables do not reliably support multi-line cells. For strings,
    CRLF and CR line endings are first normalised to LF. Each newline is then
    replaced by a literal backslash followed by "n", so the row stays on one
    line. Finally "|" is escaped so it is not read as a column separator.
    Non-string values are returned unchanged.
    """
    if not isinstance(value, str):
        return value
    escaped = value.replace("\r\n", "\n").replace("\r", "\n")
    escaped = escaped.replace("\n", "\\n")
    return escaped.replace("|", "\\|")


def _connect_bedrock_runtime(
    aws_access_key_textbox: str,
    aws_secret_key_textbox: str,
    aws_region_textbox: str,
):
    """Create a boto3 ``bedrock-runtime`` client from the first available credentials.

    Sources are tried in this order:
    1. boto3's default credential chain (logged as an existing SSO connection),
       when RUN_AWS_FUNCTIONS and PRIORITISE_SSO_OVER_AWS_ENV_ACCESS_KEYS are truthy.
    2. The access/secret keys supplied by the user.
    3. boto3's default credential chain, when RUN_AWS_FUNCTIONS is truthy.
    4. AWS_ACCESS_KEY / AWS_SECRET_KEY from config.

    The region is ``aws_region_textbox`` when non-empty, otherwise AWS_REGION.

    Raises:
        Exception: if none of the credential sources is available.
    """
    region = aws_region_textbox if aws_region_textbox else AWS_REGION

    if RUN_AWS_FUNCTIONS and PRIORITISE_SSO_OVER_AWS_ENV_ACCESS_KEYS:
        print("Connecting to Bedrock via existing SSO connection")
        return boto3.client("bedrock-runtime", region_name=region)
    if aws_access_key_textbox and aws_secret_key_textbox:
        print(
            "Connecting to Bedrock using AWS access key and secret keys from user input."
        )
        return boto3.client(
            "bedrock-runtime",
            aws_access_key_id=aws_access_key_textbox,
            aws_secret_access_key=aws_secret_key_textbox,
            region_name=region,
        )
    if RUN_AWS_FUNCTIONS:
        print("Connecting to Bedrock via existing SSO connection")
        return boto3.client("bedrock-runtime", region_name=region)
    if AWS_ACCESS_KEY and AWS_SECRET_KEY:
        print("Getting Bedrock credentials from environment variables")
        return boto3.client(
            "bedrock-runtime",
            aws_access_key_id=AWS_ACCESS_KEY,
            aws_secret_access_key=AWS_SECRET_KEY,
            region_name=region,
        )

    error_message = "Cannot connect to AWS Bedrock service. Please provide access keys under LLM settings, or choose another model type."
    print(error_message)
    raise Exception(error_message)


def _fit_markdown_table_to_token_budget(
    table_df: pd.DataFrame,
    available_tokens: int,
    tokenizer: object,
    model_source: str,
    group_label: str,
) -> Tuple[pd.DataFrame, str]:
    """Render ``table_df`` as a markdown table, keeping as many leading rows as fit.

    Rows are only ever dropped from the end. If the whole table's token count
    (from ``count_tokens_in_text``) is within ``available_tokens``, it is used
    as is. Otherwise a binary search finds the largest prefix of rows whose
    rendered table fits. If not even one row fits, an empty table is used.
    The search assumes the token count never decreases as rows are added.

    Returns:
        (fitted_df, markdown_text): the possibly truncated DataFrame and its
        whitespace-cleaned markdown rendering.
    """

    def _render(df: pd.DataFrame) -> str:
        return clean_markdown_table_whitespace(df.to_markdown(index=False))

    full_text = _render(table_df)
    if len(table_df) == 0:
        return table_df, full_text

    full_token_count = count_tokens_in_text(full_text, tokenizer, model_source)
    if full_token_count <= available_tokens:
        return table_df, full_text

    print(
        f"Warning: Summary table for group '{group_label}' exceeds context limit. "
        f"Truncating rows. Table tokens: {full_token_count}, Available: {available_tokens}"
    )

    original_row_count = len(table_df)
    low, high = 0, original_row_count
    best_df = table_df.iloc[:0]  # Empty table as the fallback
    while low < high:
        mid = (low + high + 1) // 2
        candidate_df = table_df.iloc[:mid]
        candidate_tokens = count_tokens_in_text(
            _render(candidate_df), tokenizer, model_source
        )
        if candidate_tokens <= available_tokens:
            best_df = candidate_df
            low = mid
        else:
            high = mid - 1

    print(
        f"Truncated to {len(best_df)} rows (from {original_row_count} original rows) "
        f"to fit within context limit."
    )
    return best_df, _render(best_df)


@spaces.GPU(duration=MAX_SPACES_GPU_RUN_TIME)
def overall_summary(
    topic_summary_df: pd.DataFrame,
    model_choice: str,
    in_api_key: str,
    temperature: float,
    reference_data_file_name: str,
    output_folder: str = OUTPUT_FOLDER,
    context_textbox: str = "",
    aws_access_key_textbox: str = "",
    aws_secret_key_textbox: str = "",
    aws_region_textbox: str = "",
    model_name_map: dict = model_name_map,
    hf_api_key_textbox: str = "",
    azure_endpoint_textbox: str = "",
    existing_logged_content: Optional[list] = None,
    api_url: str = None,
    output_debug_files: str = "False",
    log_output_files: Optional[list] = None,
    reasoning_suffix: str = reasoning_suffix,
    local_model: object = None,
    tokenizer: object = None,
    assistant_model: object = None,
    summarise_everything_prompt: str = summarise_everything_prompt,
    summarise_everything_system_prompt: str = summarise_everything_system_prompt,
    summarise_format_radio: str = detailed_summary_format_prompt,
    additional_summary_instructions: str = "",
    do_summaries: str = "Yes",
    progress=gr.Progress(track_tqdm=True),
) -> Tuple[
    List[str],
    str,
    pd.DataFrame,
    str,
    int,
    int,
    int,
    float,
    str,
    List[dict],
    str,
    str,
]:
    """
    Create an overall summary of all responses based on a topic summary table.

    The whole table is treated as a single group ("All"). It is rendered as
    markdown, truncated from the end if needed to fit within
    LLM_CONTEXT_LENGTH, and sent to the selected model in one request.

    Args:
        topic_summary_df (pd.DataFrame): DataFrame with columns "Summary number", "Page range", "Summary".
        model_choice (str): Name of the LLM model to use.
        in_api_key (str): API key for model access.
        temperature (float): Temperature parameter for model generation.
        reference_data_file_name (str): Name of reference data file (used for output naming and logging).
        output_folder (str, optional): Folder to save outputs; a "summaries" subfolder is created. Defaults to OUTPUT_FOLDER.
        context_textbox (str, optional): Additional context, wrapped as "The context of this analysis is '...'." if not already.
        aws_access_key_textbox (str, optional): AWS access key. Defaults to empty string.
        aws_secret_key_textbox (str, optional): AWS secret key. Defaults to empty string.
        aws_region_textbox (str, optional): AWS region; falls back to AWS_REGION when empty.
        model_name_map (dict, optional): Mapping of model names. Defaults to model_name_map.
        hf_api_key_textbox (str, optional): Hugging Face API key (not used in this function).
        azure_endpoint_textbox (str, optional): Azure endpoint passed to the model query.
        existing_logged_content (list | pd.DataFrame, optional): Previously logged records; new records are appended. None means empty.
        api_url (str, optional): API URL for inference-server models. Defaults to None.
        output_debug_files (str, optional): "True" to write debug files and the overall summary CSV. Defaults to "False".
        log_output_files (list, optional): Ignored; a fresh list is used internally.
        reasoning_suffix (str, optional): Suffix appended to the system prompt for GPT-OSS and local models.
        local_model (object, optional): Local model object; loaded via load_model() for local models when missing.
        tokenizer (object, optional): Tokenizer object; loaded via load_model() for local models when missing.
        assistant_model (object, optional): Assistant model object. Defaults to None.
        summarise_everything_prompt (str, optional): Prompt template for the overall summary.
        summarise_everything_system_prompt (str, optional): System prompt template for the overall summary.
        summarise_format_radio (str, optional): Summary format instructions inserted into the prompt.
        additional_summary_instructions (str, optional): Extra instructions inserted into the prompt.
        do_summaries (str, optional): Only "Yes" generates a summary; any other value returns an empty result. Defaults to "Yes".
        progress (gr.Progress, optional): Progress tracker. Defaults to gr.Progress(track_tqdm=True).

    Returns:
        Tuple containing:
            List[str]: Output file paths
            str: HTML table of the summaries for display
            pd.DataFrame: Summaries with columns "Group" and "Summary" (Excel-friendly text)
            str: Output metadata joined with ". "
            int: Number of input tokens
            int: Number of output tokens
            int: Number of API calls
            float: Time taken in seconds
            str: Output message
            List[dict]: Existing plus newly logged content
            str: Combined system + user prompt of the last request ("" if none)
            str: Raw response text of the last request ("" if none or on failure)

    Raises:
        ValueError: if topic_summary_df lacks any of the required columns.
        Exception: if an AWS model is chosen and no Bedrock credentials are available.
    """
    tic = time.perf_counter()
    task_type = "Overall summary"
    # Single "group" containing the whole table (no grouping by a Group column).
    unique_groups = ["All"]

    # NOTE(review): any caller-supplied log_output_files is discarded and a
    # fresh list is used, as before; confirm whether it should be passed through.
    log_output_files = list()

    if existing_logged_content is None:
        existing_logged_content = list()
    elif isinstance(existing_logged_content, pd.DataFrame):
        existing_logged_content = existing_logged_content.to_dict(orient="records")
    existing_logged_content = list(existing_logged_content)

    summaries_folder = os.path.join(output_folder, "summaries")
    os.makedirs(summaries_folder, exist_ok=True)

    # Expect three columns: Summary number, Page range, Summary
    required_cols = ["Summary number", "Page range", "Summary"]
    if not all(c in topic_summary_df.columns for c in required_cols):
        raise ValueError(
            "topic_summary_df must have columns: Summary number, Page range, Summary"
        )
    topic_summary_df = topic_summary_df[required_cols].copy()
    topic_summary_df = topic_summary_df.sort_values(by="Summary number", ascending=True)

    if do_summaries != "Yes":
        # Nothing to generate: return an empty result with the same 12-element
        # shape as the normal return so callers can always unpack it.
        time_taken = time.perf_counter() - tic
        out_message = "Overall summary not created because do_summaries is not 'Yes'."
        print(out_message)
        return (
            list(),
            "",
            pd.DataFrame(columns=["Group", "Summary"]),
            "",
            0,
            0,
            0,
            time_taken,
            out_message,
            existing_logged_content,
            "",
            "",
        )

    if context_textbox and "The context of this analysis is" not in context_textbox:
        context_textbox = "The context of this analysis is '" + context_textbox + "'."

    batch_file_path_details = create_batch_file_path_details(reference_data_file_name)

    # Use the short name from model_name_map when available, else model_choice itself.
    if model_name_map and model_choice in model_name_map:
        model_choice_clean = model_name_map[model_choice]["short_name"]
    else:
        model_choice_clean = model_choice
    model_choice_clean_short = clean_column_name(
        model_choice_clean, max_length=20, front_characters=False
    )

    # Model source comes from the config defaults, not from model_name_map.
    model_source = get_model_source_from_model_choice(model_choice)

    # Load model and tokenizer together via load_model() so they come from the
    # same source; only fill in whichever ones the caller did not supply.
    if model_source == "Local" and (local_model is None or tokenizer is None):
        progress(0.1, f"Using model: {LOCAL_TRANSFORMERS_LLM_PII_MODEL_CHOICE}")
        loaded_model, loaded_tokenizer, loaded_assistant_model = load_model()
        if local_model is None:
            local_model = loaded_model
        if tokenizer is None:
            tokenizer = loaded_tokenizer
        if assistant_model is None:
            assistant_model = loaded_assistant_model

    bedrock_runtime = None
    if model_source == "AWS":
        bedrock_runtime = _connect_bedrock_runtime(
            aws_access_key_textbox, aws_secret_key_textbox, aws_region_textbox
        )

    # The system prompt does not depend on the group, so build it once.
    system_prompt_text = summarise_everything_system_prompt.format(
        consultation_context=context_textbox
    )
    model_choice_lower = model_choice.lower()
    if "gpt-oss" in model_choice_lower or "gpt_oss" in model_choice_lower:
        # GPT-OSS models (any source) always get a reasoning suffix.
        system_prompt_text += "\n" + (reasoning_suffix or "Reasoning: low")
    elif "Local" in model_source and reasoning_suffix:
        system_prompt_text += "\n" + reasoning_suffix

    # Prefix once, outside the loop, so it cannot be applied more than once.
    if additional_summary_instructions:
        additional_summary_instructions = (
            "Important additional instructions to follow closely: "
            + additional_summary_instructions
        )

    def _format_user_prompt(table_text: str) -> str:
        return summarise_everything_prompt.format(
            topic_summary_table=table_text,
            summary_format=summarise_format_radio,
            additional_summary_instructions=additional_summary_instructions,
        )

    # Tokens left for the table = context length minus the prompt without a table.
    # NOTE(review): no room is reserved for generated tokens (LLM_MAX_NEW_TOKENS);
    # confirm whether LLM_CONTEXT_LENGTH is an input-only budget.
    base_token_count = count_tokens_in_text(
        system_prompt_text + "\n" + _format_user_prompt(""), tokenizer, model_source
    )
    available_tokens = LLM_CONTEXT_LENGTH - base_token_count

    output_files = list()
    summarised_outputs = list()
    summarised_outputs_for_df = list()
    out_metadata = list()
    new_logged_content = list()
    combined_prompt = ""
    response_text = ""

    summary_loop = tqdm(
        unique_groups, desc="Creating overall summary for groups", unit="groups"
    )
    for summary_group in summary_loop:
        print("Creating overall summary for group:", summary_group)

        group_df = topic_summary_df.copy()
        for column in group_df.columns:
            group_df[column] = group_df[column].apply(_escape_markdown_table_cell)

        group_df, summary_text = _fit_markdown_table_to_token_budget(
            group_df, available_tokens, tokenizer, model_source, summary_group
        )

        formatted_summary_prompt = [_format_user_prompt(summary_text)]
        combined_prompt = system_prompt_text + "\n" + formatted_summary_prompt[0]

        # Defaults for when the model call fails, so the bookkeeping below can
        # still run. Presumably metadata is a list of strings (it is joined with
        # ". " later) — confirm against summarise_output_topics_query.
        response_text = ""
        conversation_history = list()
        metadata = list()
        try:
            response, conversation_history, metadata, response_text = (
                summarise_output_topics_query(
                    model_choice,
                    in_api_key,
                    temperature,
                    formatted_summary_prompt,
                    system_prompt_text,
                    model_source,
                    bedrock_runtime,
                    local_model,
                    tokenizer=tokenizer,
                    assistant_model=assistant_model,
                    azure_endpoint=azure_endpoint_textbox,
                    api_url=api_url,
                )
            )
            summarised_output_for_df = response_text
            summarised_output = response
        except Exception as e:
            print(
                "Cannot create overall summary for group:",
                summary_group,
                "due to:",
                e,
            )
            summarised_output = ""
            summarised_output_for_df = ""

        # Collapse runs of blank lines to a single line break.
        if summarised_output_for_df:
            summarised_output_for_df = convert_markdown_headers_to_excel_format(
                re.sub(r"\n{2,}", "\n", summarised_output_for_df)
            )
        if summarised_output:
            summarised_output = re.sub(r"\n{2,}", "\n", summarised_output)

        summarised_outputs_for_df.append(summarised_output_for_df)
        summarised_outputs.append(summarised_output)
        out_metadata.extend(metadata)

        (
            prompt_content_logged,
            summary_content_logged,
            _conversation_content_logged,
            metadata_content_logged,
        ) = process_debug_output_iteration(
            output_debug_files,
            summaries_folder,
            batch_file_path_details,
            model_choice_clean_short,
            combined_prompt,
            summarised_output,
            conversation_history,
            metadata,
            log_output_files,
            task_type=task_type,
        )

        new_logged_content.append(
            {
                "prompt": prompt_content_logged,
                "response": summary_content_logged,
                "metadata": metadata_content_logged,
                "batch": "1",
                "model_choice": model_choice_clean_short,
                "validated": "No",
                "group": summary_group,
                "task_type": task_type,
                "file_name": reference_data_file_name,
            }
        )

    out_metadata_str = ". ".join(out_metadata)

    # Write overall outputs to csv (os.path.join works with or without a
    # trailing separator on output_folder).
    overall_summary_output_csv_path = os.path.join(
        summaries_folder,
        batch_file_path_details
        + "_overall_summary_"
        + model_choice_clean_short
        + ".csv",
    )
    summarised_outputs_df = pd.DataFrame(
        data={"Group": unique_groups, "Summary": summarised_outputs_for_df}
    )
    if output_debug_files == "True":
        summarised_outputs_df.to_csv(
            overall_summary_output_csv_path, index=False, encoding="utf-8-sig"
        )
        output_files.append(overall_summary_output_csv_path)

    # Render summaries as HTML for display.
    # NOTE(review): r"\n" with regex=False matches a literal backslash + "n",
    # not real newlines — confirm this is intended.
    # NOTE(review): escape=False emits model output as raw HTML; make sure the
    # display component sanitises it.
    summarised_outputs_df_for_display = pd.DataFrame(
        data={"Group": unique_groups, "Summary": summarised_outputs}
    )
    summarised_outputs_df_for_display["Summary"] = (
        summarised_outputs_df_for_display["Summary"]
        .apply(lambda x: markdown.markdown(x) if isinstance(x, str) else x)
        .str.replace(r"\n", "<br>", regex=False)
        .str.replace(r"(<br>\s*){2,}", "<br>", regex=True)
    )
    html_output_table = summarised_outputs_df_for_display.to_html(
        index=False, escape=False
    )

    output_files = list(set(output_files))

    input_tokens_num, output_tokens_num, number_of_calls_num = (
        calculate_tokens_from_metadata(out_metadata_str, model_choice, model_name_map)
    )

    time_taken = time.perf_counter() - tic
    out_message = (
        " " + f"Overall summary finished processing. Total time: {time_taken:.2f}s"
    )
    print(out_message)

    out_logged_content = existing_logged_content + new_logged_content

    return (
        output_files,
        html_output_table,
        summarised_outputs_df,
        out_metadata_str,
        input_tokens_num,
        output_tokens_num,
        number_of_calls_num,
        time_taken,
        out_message,
        out_logged_content,
        combined_prompt,
        response_text,
    )