# utils/pubmed_plus_utils.py
import io
import asyncio
from minio import Minio
from loguru import logger
from entities.task import PubMedPlusTask
from utils.api_utils import (
retry_operation,
get_chat_func,
compare_chat_chocies
)
from utils.r2_utils import (
get_client,
get_file_from_minio,
get_dataframe_from_minio,
upload_text_to_minio,
upload_task_json_to_minio,
)
from utils.paper_plus_utils import (
process_papers,
generate_subheadings,
assign_subheadings_to_summaries,
create_paragraphs_by_subheading,
refine_review_content,
translate_refined_review_to_chinese,
translate_to_chinese_before_references
)
from utils.pubmed_utils import (
generate_pubmed_search_string,
process_pubmed_data
)
BUCKET_NAME = "ai-scientist"
# =================================
# Function Groups: Pipeline for PubMed
#
# 1. pipeline
# 2. single model chat
# =================================
async def pubmed_plus_pipeline(
    task: PubMedPlusTask,
    client: Minio = None,
    max_retries: int = 5,
    delay: float = 0.5
):
    """
    Pubmed pipeline: run the per-model review pipeline for every model on
    the task concurrently, then optionally compare the finished reviews
    and upload the best one.

    Args:
        task: PubMedPlusTask object, containing basic information for PubMedTask
        client: Minio, minio client (a default client is created when None)
        max_retries: int, max retries for each step
        delay: float, delay between each retry
    Returns:
        None
    """
    if client is None:
        client = get_client()
    customer_name = task.customer_name
    uuid = task.uuid
    model_names = task.model_names
    task.status_string["overall"] = "processing"
    # run the full single-model pipeline for every model concurrently
    await asyncio.gather(
        *(process_pubmed_single_chat(
            task, model_name, client, max_retries, delay
        ) for model_name in model_names)
    )
    # if compare between models
    # at least 3 models should be selected
    logger.info("Check Compare...")
    if task.do_compare and len(task.model_names) >= 3:
        # skip the comparison when a previous run already completed it
        # (task.status["compare"] is set to 1 below on success)
        if task.status.get("compare", 0) == 0:
            # fetch each model's finished review text from MinIO
            contents = await asyncio.gather(
                *(get_file_from_minio(
                    bucket_name=BUCKET_NAME,
                    object_name=f"{customer_name}/{uuid}/{model_name}/review_paper.txt",
                    client=client
                ) for model_name in model_names)
            )
            contents = [c.decode("utf-8") for c in contents]
            task.status_string["overall"] = "Start Compare"
            rank_scores = await compare_chat_chocies(
                contents=contents,
                model_names=model_names
            )
            # NOTE(review): picks the entry whose score is smallest, i.e.
            # assumes rank_scores maps content index -> score and that a
            # lower score means a better review — confirm against the
            # compare_chat_chocies contract
            best_content = contents[min(rank_scores, key=rank_scores.get)]
            # NOTE(review): "reveiw" is a typo in the object name; kept
            # as-is because downstream readers may expect this exact path
            await upload_text_to_minio(
                bucket_name=BUCKET_NAME,
                object_name=f"{customer_name}/{uuid}/compared_reveiw_paper.txt",
                file_content=best_content
            )
            task.status_string["overall"] = "Finished"
            task.status["compare"] = 1
            await upload_task_json_to_minio(task, client)
        else:
            task.status_string["overall"] = "Finished"
            await upload_task_json_to_minio(task, client)
    else:
        logger.info("No Compare.")
        task.status_string["overall"] = "Finished"
        await upload_task_json_to_minio(task, client)
async def process_pubmed_single_chat(
    task: PubMedPlusTask,
    model_name: str,
    client: Minio = None,
    max_retries: int = 5,
    delay: float = 0.5
):
    """
    Run the 8-step review pipeline for one model, resuming from the last
    completed step recorded in ``task.status[model_name]``.

    Args:
        task: PubMedPlusTask object, containing basic information for PubMedTask
        model_name: str, model name used for every step of this chat
        client: Minio, minio client (a default client is created when None)
        max_retries: int, max retries for each step
        delay: float, delay between each retry
    Returns:
        None; progress is persisted to MinIO after every step
    """
    # get minio client
    if client is None:
        client = get_client()
    # first run for this model: start from step 0
    # (idiom fix: membership test directly on the dict, not .keys())
    if model_name not in task.status:
        task.status[model_name] = 0
    # set task status string
    task.status_string["overall"] = "processing"
    # ordered pipeline steps; the integer keys double as resume checkpoints
    process_steps = {
        0: process_pubmed_generate_pubmed_string,
        1: process_pubmed_fetch_data,
        2: process_pubmed_process_papers,
        3: process_pubmed_generate_subheadings,
        4: process_pubmed_assign_subheadings_to_summaries,
        5: process_pubmed_create_paragraphs_by_subheading,
        6: process_pubmed_refine,
        7: process_pubmed_translate,
    }
    state_description = {
        0: "Finished pubmed string generation.",
        1: "Finished fetching data.",
        2: "Finished paper processing.",
        3: "Finished subheading generation.",
        4: "Finished subheading assignment.",
        5: "Finished paragraph generation.",
        6: "Finished review refine.",
        7: "Finished review translate.",
    }
    # Execute Phase: resume from the last persisted checkpoint
    # (idiom fix: len(process_steps) instead of len(process_steps.keys()))
    current_state = task.status[model_name]
    for state in range(current_state, len(process_steps)):
        await process_steps[state](
            task=task,
            model_name=model_name,
            save_name=model_name,
            prev_name=model_name,
            client=client,
            max_retries=max_retries, delay=delay
        )
        task.status_string[model_name] = state_description[state]
        task.status[model_name] = state + 1
        # persist progress so a crashed run can resume at the next step
        await upload_task_json_to_minio(task, client)
    task.status_string[model_name] = "Finished."
    await upload_task_json_to_minio(task, client)
# =================================
# Function Groups: process_pubmed_*
# 1. _generate_pubmed_string
# 2. _fetch_data
# 3. _process_papers
# 4. _generate_subheadings
# 5. _assign_subheadings_to_summaries
# 6. _create_paragraphs_by_subheading
# 7. _refine
# 8. _translate
# =================================
async def process_pubmed_generate_pubmed_string(
    task: PubMedPlusTask,
    model_name: str,
    save_name: str,
    prev_name: str = None,
    client: Minio = None,
    max_retries: int = 5,
    delay: float = 0.5
):
    """
    Generate the PubMed search string for the task and store it in MinIO.

    Args:
        task: PubMedPlusTask object, containing basic information for PubMedTask
        model_name: str, model used at this step
        save_name: str, save name for the MinIO path
        prev_name: str, unused for this first step (a warning is logged if set)
        client: Minio, minio client (a default client is created when None)
        max_retries: int, max retries for this step
        delay: float, delay between each retry
    Returns:
        None; the search string is uploaded to MinIO
    Raises:
        RuntimeError: when no valid search string was produced after retries
    """
    client = get_client() if client is None else client
    if prev_name is not None:
        logger.warning("For first step, prev_model_name is not used.")
    chat_func = get_chat_func(model_names=[model_name])[0]
    search_string, errors = await retry_operation(
        generate_pubmed_search_string, task,
        query=task.query,
        max_retries=max_retries, delay=delay,
        chat_func=chat_func
    )
    if search_string is None:
        # no valid result after max retries: record the failure reasons
        # on the task, persist it, and abort this pipeline
        task.status_string[model_name] = errors
        await upload_task_json_to_minio(task, client)
        raise RuntimeError("Pubmed Search String Generation Failed.")
    await upload_text_to_minio(
        bucket_name=BUCKET_NAME,
        object_name=f"{task.customer_name}/{task.uuid}/{save_name}/pubmed_search_string.txt",
        file_content=search_string
    )
async def process_pubmed_fetch_data(
    task: PubMedPlusTask,
    model_name: str,
    save_name: str,
    prev_name: str = None,
    client: Minio = None,
    max_retries: int = 5,
    delay: float = 0.5
):
    """
    Fetch PubMed data using the previously generated search string.

    Args:
        task: PubMedPlusTask object, containing basic information for PubMedTask
        model_name: str, model name used for status reporting at this step
        save_name: str, save name for the MinIO path
        prev_name: str, previous step name whose search string is read
        client: Minio, minio client (a default client is created when None)
        max_retries: int, max retries for this step
        delay: float, delay between each retry
    Returns:
        None; results are stored by process_pubmed_data
    Raises:
        ConnectionError: when fetching yields no valid result after retries
    """
    client = get_client() if client is None else client
    customer_name, uuid = task.customer_name, task.uuid
    raw = await get_file_from_minio(
        bucket_name=BUCKET_NAME,
        object_name=f"{customer_name}/{uuid}/{prev_name}/pubmed_search_string.txt",
        client=client
    )
    search_string = raw.decode("utf-8")
    fetched, errors = await retry_operation(
        process_pubmed_data, task,
        query=search_string,
        model_name=save_name,
        start_year=task.start_year, end_year=task.end_year,
        size=task.size,
        uuid=uuid, customer_name=customer_name,
        max_retries=max_retries, delay=delay
    )
    if fetched is None:
        # no valid result after max retries: persist the failure and abort
        task.status_string[model_name] = errors
        await upload_task_json_to_minio(task, client)
        raise ConnectionError("Pubmed Data Fetch Failed.")
async def process_pubmed_process_papers(
    task: PubMedPlusTask,
    model_name: str,
    save_name: str,
    prev_name: str = None,
    client: Minio = None,
    max_retries: int = 5,
    delay: float = 0.5
):
    """
    Screen and summarize the fetched non-review PubMed papers.

    Args:
        task: PubMedPlusTask object, containing basic information for PubMedTask
        model_name: str, model used at this step
        save_name: str, save name for the MinIO path
        prev_name: str, previous step name whose fetched CSV is read
        client: Minio, minio client (a default client is created when None)
        max_retries: int, max retries for this step
        delay: float, delay between each retry
    Returns:
        None; results are stored by process_papers
    Raises:
        RuntimeError: when processing yields no valid result after retries
    """
    client = get_client() if client is None else client
    customer_name, uuid = task.customer_name, task.uuid
    chat_func = get_chat_func(model_names=[model_name])[0]
    papers_df = await get_dataframe_from_minio(
        bucket_name=BUCKET_NAME,
        object_name=f"{customer_name}/{uuid}/{prev_name}/pubmed_results_non_reviews.csv",
        client=client
    )
    outcome, errors = await retry_operation(
        process_papers, task,
        dataframe=papers_df,
        topic=task.query, direction=task.direction,
        uuid=uuid, customer_name=customer_name, model_name=save_name,
        max_retries=max_retries, delay=delay,
        chat_func=chat_func
    )
    if outcome is None:
        # no valid result after max retries: persist the failure and abort
        task.status_string[model_name] = errors
        await upload_task_json_to_minio(task, client)
        raise RuntimeError("Pubmed Paper Processing Failed.")
async def process_pubmed_generate_subheadings(
    task: PubMedPlusTask,
    model_name: str,
    save_name: str,
    prev_name: str = None,
    client: Minio = None,
    max_retries: int = 5,
    delay: float = 0.5
):
    """
    Generate review subheadings from the relevant papers.

    Args:
        task: PubMedPlusTask object, containing basic information for PubMedTask
        model_name: str, model used at this step
        save_name: str, save name for the MinIO path
        prev_name: str, previous step name whose relevant_papers.csv is read
        client: Minio, minio client (a default client is created when None)
        max_retries: int, max retries for this step
        delay: float, delay between each retry
    Returns:
        None; results are stored by generate_subheadings
    Raises:
        RuntimeError: when generation yields no valid result after retries
    """
    client = get_client() if client is None else client
    customer_name, uuid = task.customer_name, task.uuid
    chat_func = get_chat_func([model_name])[0]
    papers_df = await get_dataframe_from_minio(
        bucket_name=BUCKET_NAME,
        object_name=f"{customer_name}/{uuid}/{prev_name}/relevant_papers.csv",
        client=client
    )
    outcome, errors = await retry_operation(
        generate_subheadings, task,
        relevant_papers_df=papers_df,
        main_topic=task.query,
        uuid=uuid, customer_name=customer_name, model_name=save_name,
        chat_func=chat_func,
        max_retries=max_retries, delay=delay
    )
    if outcome is None:
        # no valid result after max retries: persist the failure and abort
        task.status_string[model_name] = errors
        await upload_task_json_to_minio(task, client)
        raise RuntimeError("Pubmed Generate Subheadings Failed.")
async def process_pubmed_assign_subheadings_to_summaries(
    task: PubMedPlusTask,
    model_name: str,
    save_name: str,
    prev_name: str = None,
    client: Minio = None,
    max_retries: int = 5,
    delay: float = 0.5
):
    """
    Assign each paper summary to one of the generated subheadings.

    Args:
        task: PubMedPlusTask object, containing basic information for PubMedTask
        model_name: str, model used at this step
        save_name: str, save name for the MinIO path
        prev_name: str, previous step name whose subheadings/papers are read
        client: Minio, minio client (a default client is created when None)
        max_retries: int, max retries for this step
        delay: float, delay between each retry
    Returns:
        None; results are stored by assign_subheadings_to_summaries
    Raises:
        RuntimeError: when assignment yields no valid result after retries
    """
    client = get_client() if client is None else client
    customer_name, uuid = task.customer_name, task.uuid
    chat_func = get_chat_func([model_name])[0]
    raw = await get_file_from_minio(
        bucket_name=BUCKET_NAME,
        object_name=f"{customer_name}/{uuid}/{prev_name}/generated_subheadings.txt",
        client=client
    )
    # one subheading per line in the stored text file
    heading_list = raw.decode("utf-8").split("\n")
    papers_df = await get_dataframe_from_minio(
        bucket_name=BUCKET_NAME,
        object_name=f"{customer_name}/{uuid}/{prev_name}/relevant_papers.csv",
        client=client
    )
    outcome, errors = await retry_operation(
        assign_subheadings_to_summaries, task,
        subheadings=heading_list,
        relevant_papers_df=papers_df,
        uuid=uuid, customer_name=customer_name, model_name=save_name,
        chat_func=chat_func,
        max_retries=max_retries, delay=delay
    )
    if outcome is None:
        # no valid result after max retries: persist the failure and abort
        task.status_string[model_name] = errors
        await upload_task_json_to_minio(task, client)
        raise RuntimeError("Pubmed Assign Subheadings Failed.")
async def process_pubmed_create_paragraphs_by_subheading(
    task: PubMedPlusTask,
    model_name: str,
    save_name: str,
    prev_name: str = None,
    client: Minio = None,
    max_retries: int = 5,
    delay: float = 0.5
):
    """
    Write the review paragraphs, one section per subheading.

    Args:
        task: PubMedPlusTask object, containing basic information for PubMedTask
        model_name: str, model used at this step
        save_name: str, save name for the MinIO path
        prev_name: str, previous step name whose subheadings/assignments are read
        client: Minio, minio client (a default client is created when None)
        max_retries: int, max retries for this step
        delay: float, delay between each retry
    Returns:
        None; results are stored by create_paragraphs_by_subheading
    Raises:
        RuntimeError: when paragraph creation yields no valid result after retries
    """
    client = get_client() if client is None else client
    customer_name, uuid = task.customer_name, task.uuid
    chat_func = get_chat_func([model_name])[0]
    raw = await get_file_from_minio(
        bucket_name=BUCKET_NAME,
        object_name=f"{customer_name}/{uuid}/{prev_name}/generated_subheadings.txt",
        client=client
    )
    # one subheading per line in the stored text file
    heading_list = raw.decode("utf-8").split("\n")
    assigned_df = await get_dataframe_from_minio(
        bucket_name=BUCKET_NAME,
        object_name=f"{customer_name}/{uuid}/{prev_name}/assigned_subheadings.csv",
        client=client
    )
    outcome, errors = await retry_operation(
        create_paragraphs_by_subheading, task,
        subheadings=heading_list, main_topic=task.query,
        relevant_papers_df=assigned_df,
        uuid=uuid, customer_name=customer_name, model_name=save_name,
        chat_func=chat_func,
        max_retries=max_retries, delay=delay
    )
    if outcome is None:
        # no valid result after max retries: persist the failure and abort
        task.status_string[model_name] = errors
        await upload_task_json_to_minio(task, client)
        raise RuntimeError("Pubmed Create Paragraphs Failed.")
async def process_pubmed_translate(
    task: PubMedPlusTask,
    model_name: str,
    save_name: str,
    prev_name: str = None,
    client: Minio = None,
    max_retries: int = 5,
    delay: float = 0.5
):
    """
    Translate the review to Chinese, using the refined .docx when
    ``task.do_refine`` is set and the plain non-refined text otherwise.

    Args:
        task: PubMedPlusTask object, containing basic information for PubMedTask
        model_name: str, model used at this step
        save_name: str, save name for the MinIO path
        prev_name: str, previous step name whose review document is read
        client: Minio, minio client (a default client is created when None)
        max_retries: int, max retries for this step
        delay: float, delay between each retry
    Returns:
        None; results are stored by the translate helpers
    Raises:
        RuntimeError: when translation yields no valid result after retries
    """
    client = get_client() if client is None else client
    customer_name, uuid = task.customer_name, task.uuid
    chat_func = get_chat_func([model_name])[0]
    if task.do_refine:
        # refined path: the review is a .docx, passed on as a bytes stream
        raw = await get_file_from_minio(
            bucket_name=BUCKET_NAME,
            object_name=f"{customer_name}/{uuid}/{prev_name}/review_paper_refined.docx",
            client=client
        )
        outcome, errors = await retry_operation(
            translate_refined_review_to_chinese, task,
            refined_review_content=io.BytesIO(raw),
            uuid=uuid, customer_name=customer_name, model_name=save_name,
            chat_func=chat_func,
            max_retries=max_retries, delay=delay
        )
        if outcome is None:
            # no valid result after max retries: persist failure and abort
            task.status_string[model_name] = errors
            await upload_task_json_to_minio(task, client)
            raise RuntimeError("Pubmed Translate Refined Review Failed.")
    else:
        # non-refined path: the review is plain UTF-8 text
        raw = await get_file_from_minio(
            bucket_name=BUCKET_NAME,
            object_name=f"{customer_name}/{uuid}/{prev_name}/review_non_refined.txt",
            client=client
        )
        outcome, errors = await retry_operation(
            translate_to_chinese_before_references, task,
            text=raw.decode("utf-8"),
            uuid=uuid, customer_name=customer_name, model_name=save_name,
            chat_func=chat_func,
            max_retries=max_retries, delay=delay
        )
        if outcome is None:
            # no valid result after max retries: persist failure and abort
            task.status_string[model_name] = errors
            await upload_task_json_to_minio(task, client)
            raise RuntimeError("Pubmed Translate Failed.")
async def process_pubmed_refine(
    task: PubMedPlusTask,
    model_name: str,
    save_name: str,
    prev_name: str = None,
    client: Minio = None,
    max_retries: int = 5,
    delay: float = 0.5
):
    """
    Refine the non-refined review content; a no-op when refinement is
    disabled on the task.

    Args:
        task: PubMedPlusTask object, containing basic information for PubMedTask
        model_name: str, model used at this step
        save_name: str, save name for the MinIO path
        prev_name: str, previous step name whose non-refined review is read
        client: Minio, minio client (a default client is created when None)
        max_retries: int, max retries for this step
        delay: float, delay between each retry
    Returns:
        1 when refinement is disabled; otherwise None (results stored by
        refine_review_content)
    Raises:
        RuntimeError: when refinement yields no valid result after retries
    """
    # refinement is optional; exit immediately with 1 when disabled
    if not task.do_refine:
        return 1
    client = get_client() if client is None else client
    customer_name, uuid = task.customer_name, task.uuid
    chat_func = get_chat_func([model_name])[0]
    raw = await get_file_from_minio(
        bucket_name=BUCKET_NAME,
        object_name=f"{customer_name}/{uuid}/{prev_name}/review_non_refined.txt",
        client=client
    )
    outcome, errors = await retry_operation(
        refine_review_content, task,
        non_refine_content=raw.decode("utf-8"),
        uuid=uuid, customer_name=customer_name, model_name=save_name,
        chat_func=chat_func,
        max_retries=max_retries, delay=delay
    )
    if outcome is None:
        # no valid result after max retries: persist the failure and abort
        task.status_string[model_name] = errors
        await upload_task_json_to_minio(task, client)
        raise RuntimeError("Pubmed Refine Failed.")