# utils/paper_utils.py — helpers for building an automated review paper
import os
import math
import asyncio
from loguru import logger
from .r2_utils import (
upload_text_to_minio,
upload_dataframe_to_minio,
)
from .common_utils import escape_csv_field
BUCKET_NAME = "ai-scientist"
async def is_relevant(title, abstract, topic, direction, chat_func):
    """
    Check if a paper is relevant to a topic and obtain keywords as reason.

    Args:
        title (str): Title of the paper.
        abstract (str): Abstract of the paper.
        topic (str): Topic to check relevance against.
        direction (str): Direction to check relevance against.
        chat_func (function): Async function to call the chat model; may
            return None on server error.

    Returns:
        bool: True if the paper is relevant, False otherwise.
        str: Keywords that indicate relevance (or a failure reason).
    """
    relevance_prompt = (
        f"You are an academic expert in {topic}. Identify if the following paper is "
        f"related to '{direction}' and list only the main keywords that indicate relevance:\n\n"
        f"Title: {title}\nAbstract: {abstract}\n\n"
        "Answer format:\n"
        "Relevance: True or False\n"
        "Keywords: [Comma-separated keywords]"
    )
    response = await chat_func(relevance_prompt)
    if response is None:
        return False, "Relevance check unavailable due to server error."
    try:
        response_text = response.choices[0].message.content
        # NOTE(review): bare substring test — any "True" anywhere in the
        # reply (even inside the keyword list) counts as relevant. Kept
        # as-is; confirm the model reliably follows the answer format.
        relevance = "True" in response_text
        keywords = response_text.split(
            "Keywords:")[-1].strip() if "Keywords:" in response_text else ""
        return relevance, keywords
    except AttributeError:
        # Bug fix: loguru formats extra positional args via str.format, so
        # the previous call ("...:", response) silently dropped the payload.
        logger.error("Error in chat_func response format: {}", response)
        return False, "Relevance check failed"
async def summarize_abstract(title, abstract, first_author, chat_func):
    """
    Summarize the abstract of a research paper in 2-3 sentences.

    Args:
        title (str): Title of the paper.
        abstract (str): Abstract of the paper.
        first_author (str): Name of the first author.
        chat_func (function): Async function to call the chat model; may
            return None on server error.

    Returns:
        str: Summary of the abstract, or a fallback message on failure.
    """
    formatted_author = reformat_author_name(first_author)
    summary_prompt = (
        f"Write a concise, high-level summary in 2-3 sentences, highlighting the study's "
        f"purpose, specific methodology, main findings, and significance. Avoid generalizing "
        f"or replacing specific method names or entities with vague language. Retain concrete terms "
        f"and clear descriptions of methodology and findings.\n\n"
        f"Title: {title}\nAbstract: {abstract}\n\n"
        f"Summary by {formatted_author} et al.:"
    )
    response = await chat_func(summary_prompt)
    if response is None:
        return "Summary unavailable due to server error."
    try:
        result = response.choices[0].message.content
        # Collapse all whitespace runs (incl. newlines) to single spaces.
        summary = " ".join(result.split())
        return summary
    except AttributeError:
        # Bug fix: loguru formats extra positional args via str.format, so
        # the previous call ("...:", response) silently dropped the payload.
        logger.error("Error in chat_func response format: {}", response)
        return "Summary unavailable"
# Normalize an author name for use inside prompts.
def reformat_author_name(author_name):
    """
    Strip commas from an author's name (e.g. "Smith, John" -> "Smith John").

    Args:
        author_name (str): Name of the first author.

    Returns:
        str: The name without commas, or "Unknown Author" when the input
        does not support string replacement (e.g. None).
    """
    try:
        cleaned = author_name.replace(",", "")
    except AttributeError:
        return "Unknown Author"
    return cleaned
# Derive 3-5 hierarchical subheadings for the review from the paper summaries.
async def generate_subheadings(
    relevant_papers_df, main_topic,
    uuid, customer_name, model_name,
    chat_func
):
    """
    Generate 3-5 hierarchical subheadings for the main topic from the
    summaries of relevant papers, and upload them as a text file.

    Args:
        relevant_papers_df: DataFrame with a 'Summary' column.
        main_topic: Main topic of the research.
        uuid: Unique identifier for the task (object-key component).
        customer_name: Name of the customer (object-key component).
        model_name: Name of the model (object-key component).
        chat_func: Async function to send chat messages to the chatbot.

    Returns:
        List[str]: Lines of the model's reply, one subheading per line.
    """
    combined_summaries = " ".join(relevant_papers_df['Summary'].tolist())
    subheading_prompt = (
        f"The main topic is '{main_topic}'. Based on this topic and the following summaries from relevant research papers, "
        "generate 3-5 hierarchical subheadings that progressively explore the topic. Begin with broader subheadings and "
        "move towards more specific themes, avoiding overlap in scope or content. Subheadings should be distinct and arranged "
        "in a logical order suitable for a structured review.\n\n"
        f"Summaries:\n{combined_summaries}\n\n"
        "Output format:\n- Subheading 1\n- Subheading 2\n- Subheading 3\n..."
    )
    reply = await chat_func(subheading_prompt)
    subheadings = reply.choices[0].message.content.strip().splitlines()
    logger.info("Generated Subheadings:\n" + "\n".join(subheadings))
    output_filename = f"{customer_name}/{uuid}/{model_name}/generated_subheadings.txt"
    await upload_text_to_minio(
        bucket_name=BUCKET_NAME,
        object_name=output_filename,
        file_content="\n".join(subheadings)
    )
    logger.info(f"Subheadings saved to {output_filename}")
    return subheadings
async def assign_subheadings_to_summaries(
    relevant_papers_df,
    subheadings,
    uuid, customer_name, model_name,
    chat_func
):
    """
    Assign each paper summary to a subheading, enforce a minimum number of
    papers per subheading, and upload the result as a CSV.

    Args:
        relevant_papers_df: DataFrame with a 'Summary' column.
        subheadings: List of subheadings to choose from.
        uuid: Unique identifier for the task (object-key component).
        customer_name: Name of the customer (object-key component).
        model_name: Name of the model (object-key component).
        chat_func: Async function to send chat messages to the chatbot.

    Returns:
        DataFrame with an added 'Assigned Subheading' column.
    """
    total_papers = len(relevant_papers_df)
    # Each subheading must end up with at least this many papers.
    min_papers_per_subheading = math.ceil(
        total_papers / (len(subheadings) + 1))
    prompts = [
        (
            "Given the following subheadings and a research paper summary, determine the most appropriate subheading "
            "for this summary. Each subheading should cover a unique aspect of the main topic without overlap. "
            "Select the best-fitting subheading based on thematic relevance and coherence with similar studies.\n\n"
            f"Subheadings:\n{subheadings}\n\n"
            f"Summary:\n{summary}\n\n"
            "Output format:\nSubheading: [Chosen subheading]"
        )
        for summary in relevant_papers_df['Summary']
    ]
    # One concurrent request per summary.
    responses = await asyncio.gather(
        *(chat_func(prompt) for prompt in prompts)
    )
    # Bug fix: split(": ")[1] truncated any chosen subheading that itself
    # contained ": "; maxsplit=1 keeps the full text. strip() removes the
    # trailing newline the model tends to append, so values can actually
    # equal entries of `subheadings` in the counts below.
    assigned_subheadings = [
        response.choices[0].message.content.split(": ", 1)[1].strip()
        for response in responses
    ]
    relevant_papers_df['Assigned Subheading'] = assigned_subheadings
    # Ensure minimum papers per subheading by reassigning rows from others.
    counts = relevant_papers_df['Assigned Subheading'].value_counts().to_dict()
    for subheading in subheadings:
        deficit = min_papers_per_subheading - counts.get(subheading, 0)
        if deficit > 0:
            # NOTE(review): sample() is unseeded (nondeterministic) and
            # `counts` is not refreshed between iterations, so a later
            # top-up can steal rows from an earlier one — confirm intended.
            extra_summaries = relevant_papers_df[
                relevant_papers_df['Assigned Subheading'] != subheading
            ].sample(deficit)
            relevant_papers_df.loc[extra_summaries.index,
                                   'Assigned Subheading'] = subheading
    prefix = f"{customer_name}/{uuid}/{model_name}/"
    csv_filename = os.path.join(prefix, "assigned_subheadings.csv")
    await upload_dataframe_to_minio(
        bucket_name=BUCKET_NAME,
        object_name=csv_filename,
        df=relevant_papers_df,
    )
    logger.info(f"Assigned subheadings saved to {csv_filename}")
    logger.info(f"Found {len(relevant_papers_df)} related papers")
    return relevant_papers_df
async def create_paragraphs_by_subheading(
    relevant_papers_df, subheadings, main_topic,
    uuid, customer_name, model_name,
    chat_func
):
    """
    Build the review body: introduction, one extended paragraph per
    subheading with consistent [Ref: n] indexing, conclusion, and a
    references section. Uploads the grouped summaries CSV and the
    non-refined review text.

    Args:
        relevant_papers_df (pd.DataFrame): Papers with 'Assigned Subheading',
            'Summary', 'Title', 'First Author', 'Publication Date' columns.
        subheadings (list): List of subheadings for the review paper.
        main_topic (str): Main topic of the review paper.
        uuid (str): UUID of the task (object-key component).
        customer_name (str): Name of the customer (object-key component).
        model_name (str): Name of the model (object-key component).
        chat_func (function): Async function to call the chat model.

    Returns:
        str: The full non-refined review text.
    """
    paragraphs = []
    # Introduction
    intro_prompt = (
        f"Write a concise and advanced introductory paragraph for a scientific review paper on '{main_topic}'. "
        "Introduce the topic, its importance, and the scope of the review. The introduction should provide a logical "
        "setup for the following subheadings.\n\n"
        "Output format:\n[Write introduction here]"
    )
    intro_response = await chat_func(intro_prompt)
    intro_paragraph = intro_response.choices[0].message.content.strip()
    paragraphs.append(f"**Introduction**\n{intro_paragraph}\n")
    # Body paragraphs: give each cited title a stable reference index so
    # citations stay consistent across subheadings.
    # (Removed dead code: total_papers / min_papers_per_subheading were
    # computed here but never used.)
    reference_map = {}
    used_references = []
    ref_counter = 1
    paragraph_prompts = []
    for subheading in subheadings:
        relevant_summaries = relevant_papers_df[
            relevant_papers_df['Assigned Subheading'] == subheading]
        summaries_text = []
        # itertuples(index=False) yields values in the selected column order;
        # clearer than the previous iterrows() Series-unpacking.
        for summary, title, author, pub_date in relevant_summaries[
                ['Summary', 'Title', 'First Author', 'Publication Date']
        ].itertuples(index=False):
            if title not in reference_map:
                reference_map[title] = ref_counter
                ref_counter += 1
            summaries_text.append(f"{summary} [Ref: {reference_map[title]}]")
            used_references.append((title, author, pub_date))
        # Compose prompt to generate an extended paragraph (~800 words).
        paragraph_prompt = (
            f"Write an 800-word thematic and critical paragraph under the subheading '{subheading}' for a scientific review on '{main_topic}'. "
            f"Combine the following summaries into a coherent, well-structured paragraph discussing the studies’ objectives, findings, "
            "and methodologies. Use advanced academic language, include in-text citations in the format [Ref: number], and avoid repeating "
            "content from previous sections. Provide critical insights and comparative analysis where relevant.\n\n"
            f"Summaries:\n{' '.join(summaries_text)}\n\n"
            "Output format:\n[Write paragraph here]"
        )
        paragraph_prompts.append(paragraph_prompt)
    # All body paragraphs are generated concurrently.
    paragraph_responses = await asyncio.gather(
        *(chat_func(para_prompt) for para_prompt in paragraph_prompts)
    )
    for subheading, paragraph_response in zip(subheadings, paragraph_responses):
        paragraphs.append(
            f"**{subheading}**\n"
            f"{paragraph_response.choices[0].message.content.strip()}\n"
        )
    # Conclusion
    conclusion_prompt = (
        f"Write a concluding paragraph for a scientific review on '{main_topic}'. Summarize the main points discussed in the previous sections, "
        "highlight the significance of the research, and suggest possible future directions or applications.\n\n"
        "Output format:\n[Write conclusion here]"
    )
    conclusion_response = await chat_func(conclusion_prompt)
    conclusion_paragraph = conclusion_response.choices[0].message.content.strip()
    paragraphs.append(f"**Conclusion**\n{conclusion_paragraph}\n")
    # References section (only references actually cited above)
    references = "\n".join(
        f"[Ref: {reference_map[title]}] {title}, {author}, {pub_date}"
        for title, author, pub_date in used_references
    )
    paragraphs.append(f"**References**\n{references}")
    # Compile paragraphs into final content
    final_content = "\n\n".join(paragraphs)
    # Persist grouped summaries (CSV) and the non-refined review (text).
    prefix = f"{customer_name}/{uuid}/{model_name}/"
    csv_filename = os.path.join(prefix, "grouped_summaries.csv")
    output_filename = os.path.join(prefix, "review_non_refined.txt")
    grouped_data = relevant_papers_df[['Assigned Subheading', 'Summary']]
    await upload_dataframe_to_minio(
        bucket_name=BUCKET_NAME,
        object_name=csv_filename,
        df=grouped_data
    )
    await upload_text_to_minio(
        bucket_name=BUCKET_NAME,
        object_name=output_filename,
        file_content=final_content
    )
    logger.info(f"\nGrouped summaries saved to {csv_filename}")
    logger.info(f"Non-refined review saved to {output_filename}")
    return final_content
# Polish the review text so it reads like a *Nature*-style manuscript.
async def enhance_language_readability(
    content,
    uuid, customer_name, model_name,
    chat_func
):
    """
    Rewrite each paragraph of the content in the style of the *Nature*
    journal and upload the enhanced text.

    Args:
        content (str): The content to enhance.
        uuid (str): Unique identifier for the task (object-key component).
        customer_name (str): Name of the customer (object-key component).
        model_name (str): Name of the model (object-key component).
        chat_func (function): The function to use for the chat completion.

    Returns:
        str: The enhanced content.
    """
    # One prompt per paragraph-break-separated section.
    prompts = [
        (
            "Enhance the following text to align with the writing style of *Nature* journal. Refine language to be sophisticated and objective, "
            "using advanced vocabulary and a factual tone. Ensure a high level of lexical diversity and rhythm, with alternating sentence lengths "
            "and varied structures for readability. Avoid emotional, speculative, or conversational language, focusing on objective analysis.\n\n"
            f"Text:\n{section}\n\n"
            "Output format:\n[Enhanced text here]"
        )
        for section in content.split("\n\n")
    ]
    responses = await asyncio.gather(
        *(chat_func(prompt) for prompt in prompts)
    )
    enhanced_sections = [
        response.choices[0].message.content.strip() for response in responses
    ]
    enhanced_content = "\n\n".join(enhanced_sections)
    await upload_text_to_minio(
        bucket_name=BUCKET_NAME,
        object_name=f"{customer_name}/{uuid}/{model_name}/review_paper.txt",
        file_content=enhanced_content
    )
    return enhanced_content
async def process_papers(
    dataframe, topic, direction,
    uuid, customer_name, model_name,
    chat_func
):
    """
    Filter papers for relevance, summarize the relevant ones, and upload
    the result as a CSV.

    Args:
        dataframe (pandas.DataFrame): Papers with columns 'TI', 'AB', 'JT',
            'DCOM' and 'FAU-frist' (sic — that misspelled key is what the
            upstream export actually provides; do not "fix" it here).
        topic (str): The topic to filter the papers by.
        direction (str): The direction to filter the papers by.
        uuid (str): The UUID of the task (object-key component).
        customer_name (str): The name of the customer (object-key component).
        model_name (str): Name of the model (object-key component).
        chat_func (function): The function to use for the chat completion.

    Returns:
        str: The object name the relevant-papers CSV was uploaded under.
    """
    prefix = f"{customer_name}/{uuid}/{model_name}/"
    output_path = os.path.join(prefix, "relevant_papers.csv")
    fieldnames = ["Journal Title", "Publication Date", "Title",
                  "First Author", "Summary", "Is Relevant", "Relevance Keywords"]
    csv_lines = [",".join(escape_csv_field(x) for x in fieldnames)]
    # Extract the input columns once, preserving row order.
    titles = dataframe["TI"].tolist()
    abstracts = dataframe["AB"].tolist()
    journal_titles = dataframe["JT"].tolist()
    pub_dates = dataframe["DCOM"].tolist()
    first_authors = dataframe["FAU-frist"].tolist()
    # Relevance checks run concurrently, one per paper.
    relevance_results = await asyncio.gather(
        *(is_relevant(
            title, abstract, topic, direction, chat_func
        ) for title, abstract in zip(titles, abstracts))
    )
    # Keep only rows flagged relevant, preserving input order.
    relevant_rows = [
        (title, abstract, journal_title, pub_date, first_author, flag, keywords)
        for (flag, keywords), title, abstract, journal_title, pub_date, first_author
        in zip(relevance_results, titles, abstracts,
               journal_titles, pub_dates, first_authors)
        if flag
    ]
    # Summaries for the relevant rows, also concurrent.
    summaries = await asyncio.gather(
        *(summarize_abstract(row[0], row[1], row[4], chat_func)
          for row in relevant_rows)
    )
    for summary, (title, _abstract, journal_title, pub_date,
                  first_author, flag, keywords) in zip(summaries, relevant_rows):
        escaped = [
            escape_csv_field(journal_title),
            escape_csv_field(pub_date),
            escape_csv_field(title),
            escape_csv_field(first_author),
            escape_csv_field(summary),
            str(flag),
            escape_csv_field(keywords),
        ]
        csv_lines.append(",".join(escaped))
        logger.info(f"Added summary: {summary}")
        logger.info(f"Relevance Keywords: {keywords}")
    texts = "\n".join(csv_lines) + "\n"
    await upload_text_to_minio(
        bucket_name=BUCKET_NAME,
        object_name=output_path,
        file_content=texts
    )
    return output_path
async def translate_to_chinese_before_references(
    text,
    uuid, customer_name, model_name,
    chat_func
):
    """
    Translate review text to academic Chinese, keeping the '**References**'
    section (and everything after it) in English, then upload the result.

    Args:
        text (str): The review content to translate.
        uuid (str): Unique identifier for the task (object-key component).
        customer_name (str): Name of the customer (object-key component).
        model_name (str): Name of the model (object-key component).
        chat_func (function): The function to use for translation.

    Returns:
        str: The object name the translated content was uploaded under.
    """
    lines = text.split("\n")
    # Locate the '**References**' marker line, if present.
    references_index = next(
        (i for i, line in enumerate(lines)
         if line.strip() == "**References**"),
        None
    )
    # Split into body (to be translated) and references (kept verbatim).
    if references_index is not None:
        main_content_lines = lines[:references_index]
        references_content_lines = lines[references_index:]
    else:
        # No '**References**' marker: treat the whole text as body.
        main_content_lines = lines
        references_content_lines = []
    main_content = "\n".join(main_content_lines)
    # Translate the body section by section, concurrently.
    sections = main_content.split("\n\n")
    prompts = []
    for section in sections:
        prompt = (
            "Translate the following text to academic Chinese:\n\n"
            f"Text:\n{section}\n\n"
            "Output format:\n[Translated Chinese text here]"
        )
        prompts.append(prompt)
    responses = await asyncio.gather(
        *(chat_func(prompt) for prompt in prompts)
    )
    translated_sections = [
        response.choices[0].message.content.strip() for response in responses
    ]
    translated_content = "\n\n".join(translated_sections)
    # Re-attach the untranslated references section.
    if references_content_lines:
        references_content = "\n".join(references_content_lines)
        final_content = translated_content + "\n\n" + references_content
    else:
        final_content = translated_content
    output_filename = f"{customer_name}/{uuid}/{model_name}/review_paper_translated.txt"
    await upload_text_to_minio(
        bucket_name=BUCKET_NAME,
        object_name=output_filename,
        file_content=final_content
    )
    logger.info(f"\nTranslated content saved to {output_filename}")
    return output_filename
async def create_review_paper(
    relevant_papers_df,
    main_topic,
    uuid, customer_name, model_name,
    chat_func,
    translate_to_cn=False
):
    """
    Automate the review-paper pipeline: subheadings -> summary assignment ->
    drafting -> language enhancement -> optional Chinese translation ->
    upload.

    Args:
        relevant_papers_df (pd.DataFrame): DataFrame containing relevant papers.
        main_topic (str): Main topic of the review paper.
        uuid (str): Unique identifier for the review paper.
        customer_name (str): Name of the customer.
        model_name (str): Name of the model (object-key component).
        chat_func (function): Function to handle chat interactions.
        translate_to_cn (bool): Flag to indicate if translation to Chinese is required.

    Returns:
        str: The object name the final review paper was uploaded under.
    """
    # Step 1: Generate subheadings related to the main topic.
    # Bug fix: uuid/customer_name/model_name were previously omitted from
    # this call, which raised TypeError at runtime.
    subheadings = await generate_subheadings(
        relevant_papers_df, main_topic,
        uuid, customer_name, model_name,
        chat_func
    )
    # Step 2: Assign each summary to a subheading
    relevant_papers_df = await assign_subheadings_to_summaries(
        relevant_papers_df, subheadings,
        uuid, customer_name, model_name,
        chat_func
    )
    # Step 3: Create paragraphs by subheading, with introductory and concluding sections, and references
    review_content = await create_paragraphs_by_subheading(
        relevant_papers_df, subheadings, main_topic,
        uuid, customer_name, model_name,
        chat_func
    )
    # Step 4: Enhance language and readability.
    # Bug fix: the task identifiers were previously omitted here as well.
    enhanced_content = await enhance_language_readability(
        review_content,
        uuid, customer_name, model_name,
        chat_func
    )
    prefix = f"{customer_name}/{uuid}/{model_name}/"
    output_filename = os.path.join(prefix, "review_paper.txt")
    # Step 5: Optionally translate to Chinese (references stay in English).
    # Bug fix: the previous call passed a derived filename where the
    # signature expects uuid/customer_name/model_name; the translation
    # helper derives its own output object name.
    if translate_to_cn:
        await translate_to_chinese_before_references(
            enhanced_content,
            uuid, customer_name, model_name,
            chat_func
        )
    # Step 6: Save the generated content to object storage.
    await upload_text_to_minio(
        bucket_name=BUCKET_NAME,
        object_name=output_filename,
        file_content=enhanced_content
    )
    logger.info(f"\nReview paper saved to {output_filename}")
    return output_filename