|
|
""" |
|
|
Agent used to synthesize a final report by iteratively writing each section of the report. |
|
|
Used to produce long reports given drafts of each section. Broadly aligned with the methodology described here: |
|
|
|
|
|
|
|
|
The LongWriterAgent takes as input a string in the following format: |
|
|
=========================================================== |
|
|
ORIGINAL QUERY: <original user query> |
|
|
|
|
|
CURRENT REPORT DRAFT: <current working draft of the report, all sections up to the current one being written> |
|
|
|
|
|
TITLE OF NEXT SECTION TO WRITE: <title of the next section of the report to be written> |
|
|
|
|
|
DRAFT OF NEXT SECTION: <draft of the next section of the report> |
|
|
=========================================================== |
|
|
|
|
|
The Agent then: |
|
|
1. Reads the current draft and the draft of the next section |
|
|
2. Writes the next section of the report |
|
|
3. Produces an updated draft of the new section to fit the flow of the report |
|
|
4. Returns the updated draft of the new section along with references/citations |
|
|
""" |
|
|
|
|
|
|
|
|
try: |
|
|
from ..utils.llm_client import ( |
|
|
long_model, |
|
|
qianwen_plus_model, |
|
|
) |
|
|
from ..utils.baseclass import ResearchAgent, ResearchRunner |
|
|
from ..utils.parse_output import create_type_parser |
|
|
from ..utils.schemas import ReportDraft |
|
|
from ..config_logger import logger |
|
|
except ImportError: |
|
|
|
|
|
from utils.llm_client import ( |
|
|
long_model, |
|
|
qianwen_plus_model, |
|
|
) |
|
|
from utils.baseclass import ResearchAgent, ResearchRunner |
|
|
from utils.parse_output import create_type_parser |
|
|
from utils.schemas import ReportDraft |
|
|
from config_logger import logger |
|
|
|
|
|
import re |
|
|
from datetime import datetime |
|
|
from typing import Dict, List, Tuple, Optional |
|
|
|
|
|
from openai.types.responses import ResponseTextDeltaEvent |
|
|
from pydantic import BaseModel, Field, ValidationError |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LongWriterOutput(BaseModel): |
|
|
next_section_markdown: str = Field( |
|
|
description="The final draft of the next section in markdown format" |
|
|
) |
|
|
references: List[str] = Field( |
|
|
description="A list of references and their corresponding reference hash id for the section" |
|
|
) |
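# Expected shape of the parsed model output (values are illustrative, not real data):
# {
#     "next_section_markdown": "## Introduction\n\nRecent studies <a1b2c3d4> suggest ...",
#     "references": ["<a1b2c3d4> Smith J. (2020). Advances in gene therapy."]
# }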
|
|
|
|
|
|
|
|
INSTRUCTIONS = f""" |
|
|
You are an expert report writer tasked with iteratively writing each section of a report. |
|
|
Today's date is {datetime.now().strftime("%Y-%m-%d")}. |
|
|
You will be provided with: |
|
|
1. The original research query |
|
|
2. A final draft of the report containing the table of contents and all sections written up until this point (in the first iteration there will be no sections written yet)
|
|
3. A first draft of the next section of the report to be written |
|
|
|
|
|
OBJECTIVE: |
|
|
1. Write a final draft of the next section of the report with numbered citations in square brackets in the body of the report |
|
|
2. Produce a list of references to be appended to the end of the report |
|
|
3. Content Depth: The review should comprehensively cover the provided articles, ensuring detailed analysis and discussion of each study, including methodologies, key findings, and contributions. Feel free to include supplementary information, explanations, and insights that may enhance the depth and breadth of your review, even if it seems verbose. The goal is to produce a comprehensive and thorough output that fulfills the length requirement. |
|
|
|
|
|
CITATIONS/REFERENCES: |
|
|
The citations should be in numerical order, written in numbered square brackets in the body of the report. |
|
|
Separately, a list of all references and their corresponding reference numbers will be included at the end of the report. |
|
|
|
|
|
|
|
|
For the References:
|
|
1. Use ONLY information that is explicitly provided in the articles |
|
|
2. DO NOT invent or fabricate any information, dates, journal names, or other details |
|
|
3. For missing information, use "N/A" or omit the field entirely, but NEVER invent data |
|
|
4. Use this format: Author(s). (Year). Title.
|
|
5. If any piece of information is missing, simply exclude it rather than making it up |
|
|
For example, if author, year and title are available but not journal details: |
|
|
- Smith J, Johnson K. (2020). Advances in gene therapy for cancer treatment. |
|
|
|
|
|
If only author and title are available: |
|
|
- Smith J. Advances in gene therapy for cancer treatment. |
|
|
Follow the examples above for formatting.
|
|
|
|
|
DO NOT create fictional references or invent missing data. |
|
|
GUIDELINES: |
|
|
- You can reformat and reorganize the flow of the content and headings within a section to flow logically, but DO NOT remove details that were included in the first draft |
|
|
- Only remove text from the first draft if it is already mentioned earlier in the report, or if it should be covered in a later section per the table of contents |
|
|
- Ensure the heading for the section matches the table of contents |
|
|
- Format the final output and references section as markdown |
|
|
- Do not include a title for the reference section, just a list of numbered references |
|
|
|
|
|
Important: |
|
|
- Ensure that the body of your review contains 3,000-3,500 words.
|
|
|
|
|
Only output JSON. Follow the JSON schema below. Do not output anything else. I will be parsing this with Pydantic so output valid JSON only: |
|
|
{LongWriterOutput.model_json_schema()} |
|
|
""" |
|
|
INSTRUCTIONS_TEST = f""" |
|
|
You are an expert academic writer specializing in writing each section of comprehensive literature reviews. |
|
|
Today's date is {datetime.now().strftime("%Y-%m-%d")}. |
|
|
|
|
|
INPUT PROVIDED: |
|
|
1. Original research query |
|
|
2. A final draft of the report containing the table of contents and all sections written up until this point (in the first iteration there will be no sections written yet) |
|
|
3. A first draft of the next section of the report to be written |
|
|
4. Language preference |
|
|
|
|
|
OBJECTIVE: |
|
|
Write a final draft of a comprehensive, well-structured literature review section for the next section of the report, with <hash string> citations in the body.
|
|
ATTENTION: Each <hash string> is the hash string the user provides for a reference. DO NOT change any <hash string> to a different string.
|
|
CRITICAL FORMATTING REQUIREMENTS (MUST FOLLOW EXACTLY): |
|
|
|
|
|
## Section Structure: |
|
|
- Main section title: ## [Section Title] |
|
|
- Primary subsections: ### [Subsection Title] |
|
|
- Secondary subsections: #### [Sub-subsection Title] |
|
|
- Never use numbered headings (e.g., avoid "2.1", "2.2") |
|
|
|
|
|
## Writing Style: |
|
|
- Use flowing narrative paragraphs, NOT bullet points or lists |
|
|
- Each paragraph should be 4-8 sentences with clear topic sentences |
|
|
- Integrate citations naturally within sentences using <hash string> format |
|
|
- Maintain academic tone with sophisticated vocabulary |
|
|
- Use transitional phrases between paragraphs for smooth flow |
|
|
|
|
|
## Content Organization: |
|
|
- Start each subsection with a clear introductory paragraph |
|
|
- Present findings in a logical sequence |
|
|
- Compare and contrast studies within the same paragraph when relevant |
|
|
- Synthesize information across multiple sources |
|
|
- End subsections with brief summary or transition to next topic |
|
|
|
|
|
## Citation Requirements: |
|
|
- Use ONLY the separate <hash string> format provided (e.g., <a1b2c3d4> <a3b4c5d6>)
|
|
- NEVER change or modify the hash strings |
|
|
- Integrate citations naturally: "Recent studies have shown <a1b2c3d4> that..." |
|
|
- For multiple citations, DO NOT use "<a1b2c3d4,a3b4c5d6>"; use separate tags in the text, e.g. "<a1b2c3d4> <a3b4c5d6> that...", and keep each <hash string> exactly as the user provided it for the reference
|
|
- Avoid citation clustering at paragraph ends |
|
|
|
|
|
## Language Requirements: |
|
|
- If language is "CH": Write in Chinese but keep <hash string> unchanged |
|
|
- If language is "EN": Write in English |
|
|
- Maintain consistent language throughout |
|
|
|
|
|
## Content Requirements: |
|
|
- Comprehensively cover ALL provided articles |
|
|
- Include methodology discussion when relevant |
|
|
- Discuss key findings, limitations, and implications |
|
|
- Maintain 800-1000 words for the section |
|
|
- Do NOT remove details from the original draft unless clearly redundant, and do not change the <hash string>
|
|
|
|
|
## Prohibited Formats: |
|
|
- No bullet points (•) or numbered lists (1., 2., 3.) |
|
|
|
|
|
- No excessive short paragraphs (under 3 sentences) |
|
|
- No standalone citation sentences |
|
|
|
|
|
REFERENCES FORMAT: |
|
|
Collect all <hash string> citations and their corresponding sources exactly as provided by the user. |
|
|
|
|
|
Only output JSON. Follow the JSON schema below. Do not output anything else. I will be parsing this with Pydantic so output valid JSON only: |
|
|
{LongWriterOutput.model_json_schema()} |
|
|
""" |
|
|
|
|
|
|
|
|
selected_model = long_model |
|
|
|
|
|
long_writer_agent = ResearchAgent( |
|
|
name="LongWriterAgent", |
|
|
instructions=INSTRUCTIONS_TEST, |
|
|
model=selected_model, |
|
|
|
|
|
|
|
|
) |
|
|
|
|
|
INSTRUCTIONS_Translation = """ |
|
|
You are an expert translator tasked with translating a text from English to Chinese. |
|
|
|
|
|
INPUT PROVIDED: |
|
|
1. A text in English |
|
|
|
|
|
OBJECTIVE: |
|
|
Translate the text from English to Chinese |
|
|
|
|
|
""" |
|
|
|
|
|
translation_agent = ResearchAgent( |
|
|
name="TranslationAgent", |
|
|
instructions=INSTRUCTIONS_Translation, |
|
|
model=selected_model, |
|
|
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
async def write_next_section( |
|
|
original_query: str, |
|
|
report_draft: str, |
|
|
next_section_title: str, |
|
|
next_section_draft: str, |
|
|
thoughts_callback, |
|
|
language: str = "EN", |
|
|
): |
|
|
"""Write the next section of the report""" |
|
|
|
|
|
user_message = f""" |
|
|
<ORIGINAL QUERY> |
|
|
{original_query} |
|
|
</ORIGINAL QUERY> |
|
|
|
|
|
<CURRENT REPORT DRAFT> |
|
|
{report_draft or "No draft yet"} |
|
|
</CURRENT REPORT DRAFT> |
|
|
|
|
|
<TITLE OF NEXT SECTION TO WRITE> |
|
|
{next_section_title} |
|
|
</TITLE OF NEXT SECTION TO WRITE> |
|
|
|
|
|
<DRAFT OF NEXT SECTION> |
|
|
{next_section_draft} |
|
|
</DRAFT OF NEXT SECTION> |
|
|
|
|
|
<LANGUAGE> |
|
|
{language} |
|
|
</LANGUAGE> |
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
max_iter = 3 |
|
|
iter_num = 0 |
|
|
temp_agent_type = "" |
|
|
|
|
|
while iter_num < max_iter: |
|
|
full_response = "" |
|
|
try: |
|
|
result = ResearchRunner.run_streamed( |
|
|
long_writer_agent, |
|
|
user_message, |
|
|
) |
|
|
|
|
|
async for event in result.stream_events(): |
|
|
|
|
|
if event.type == "raw_response_event" and isinstance( |
|
|
event.data, ResponseTextDeltaEvent |
|
|
): |
|
|
full_response += event.data.delta |
|
|
elif event.type == "agent_updated_stream_event": |
|
|
if event.new_agent.name != temp_agent_type: |
|
|
temp_agent_type = event.new_agent.name |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
try: |
|
|
|
|
|
|
|
|
cleaned_response = clean_json_response(full_response) |
|
|
|
|
|
resf = create_type_parser(LongWriterOutput) |
|
|
res = resf(cleaned_response) |
|
|
return res |
|
|
except Exception as parse_error: |
|
|
|
|
|
logger.warning( |
|
|
f"Failed to parse output as JSON in write_next_section ,try extract from failed json: {str(parse_error)[:200]}" |
|
|
) |
|
|
try: |
|
|
manual_result = extract_from_failed_json(full_response) |
|
|
if manual_result: |
|
|
return manual_result |
|
|
except Exception as manual_error: |
|
|
logger.error( |
|
|
f"Manual extraction also failed: {str(manual_error)[:100]}" |
|
|
) |
|
|
|
|
|
|
|
|
iter_num += 1 |
|
|
logger.error( |
|
|
f"Parse error occurred: {parse_error}. Retrying {iter_num}/{max_iter}..." |
|
|
) |
|
|
continue |
|
|
|
|
|
except ValidationError: |
|
|
|
|
|
resf = create_type_parser(LongWriterOutput) |
|
|
res = resf(full_response) |
|
|
return res |
|
|
except Exception as e: |
|
|
logger.error(f"Write next section error: {e}") |
|
|
iter_num += 1 |
|
|
logger.error(f"Error occurred: {e}. Retrying {iter_num}/{max_iter}...") |
|
|
|
|
|
return LongWriterOutput( |
|
|
next_section_markdown="The section generate error", references=[] |
|
|
) |
|
|
|
|
|
|
|
|
def clean_json_response(response: str) -> str: |
|
|
"""Clean and prepare JSON response for parsing""" |
|
|
import json |
|
|
|
|
|
|
|
|
response = response.strip() |
|
|
|
|
|
|
|
|
if not response.startswith("{"): |
|
|
json_start = response.find("{") |
|
|
if json_start != -1: |
|
|
response = response[json_start:] |
|
|
|
|
|
|
|
|
if not response.endswith("}"): |
|
|
json_end = response.rfind("}") |
|
|
if json_end != -1: |
|
|
response = response[: json_end + 1] |
|
|
|
|
|
|
|
|
|
|
|
# Normalize curly quotes to their ASCII equivalents
response = response.replace("\u201c", '"').replace("\u201d", '"')

response = response.replace("\u2018", "'").replace("\u2019", "'")
|
|
|
|
|
|
|
|
|
|
|
response = response.replace("\\<", "<").replace("\\>", ">") |
|
|
|
|
|
|
|
|
response = re.sub(r",(\s*[}\]])", r"\1", response) |
|
|
|
|
|
|
|
|
try: |
|
|
json.loads(response) |
|
|
return response |
|
|
except json.JSONDecodeError as e: |
|
|
logger.warning( |
|
|
f"JSON decode error at position {e.pos}: {str(e)},try to fix it " |
|
|
) |
|
|
|
|
|
if hasattr(e, "pos") and e.pos > 0: |
|
|
|
|
|
truncate_pos = e.pos |
|
|
while truncate_pos > 0 and response[truncate_pos - 1] not in [ |
|
|
'"', |
|
|
"}", |
|
|
"]", |
|
|
]: |
|
|
truncate_pos -= 1 |
|
|
if truncate_pos > 0: |
|
|
truncated = response[:truncate_pos] |
|
|
|
|
|
if truncated.count('"') % 2 == 1: |
|
|
truncated += '"' |
|
|
if truncated.count("{") > truncated.count("}"): |
|
|
truncated += "}" |
|
|
if truncated.count("[") > truncated.count("]"): |
|
|
truncated += "]" |
|
|
try: |
|
|
json.loads(truncated) |
|
|
return truncated |
|
|
except json.JSONDecodeError:
|
|
pass |
|
|
|
|
|
return response |
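# Illustrative behaviour of clean_json_response (the input value is made up): given a
# fenced, slightly malformed model response such as
#     '```json\n{"next_section_markdown": "## Intro",}\n```'
# the helper strips everything before the first "{" and after the last "}", drops the
# trailing comma, and returns '{"next_section_markdown": "## Intro"}'.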
|
|
|
|
|
|
|
|
def extract_from_failed_json(response: str) -> Optional[LongWriterOutput]: |
|
|
"""Attempt to extract data from malformed JSON response""" |
|
|
try: |
|
|
import re |
|
|
|
|
|
|
|
|
|
|
|
markdown_patterns = [ |
|
|
r'"next_section_markdown":\s*"(.*?)"(?=,\s*"references")', |
|
|
r'"next_section_markdown":\s*"(.*?)"(?=,?\s*"references")', |
|
|
r'"next_section_markdown":\s*"(.*?)"\s*,', |
|
|
r'"next_section_markdown":\s*"(.*?)"(?=\s*[,}])', |
|
|
] |
|
|
|
|
|
markdown_content = None |
|
|
for pattern in markdown_patterns: |
|
|
markdown_match = re.search(pattern, response, re.DOTALL) |
|
|
if markdown_match: |
|
|
markdown_content = markdown_match.group(1) |
|
|
break |
|
|
|
|
|
if not markdown_content: |
|
|
|
|
|
start_match = re.search(r'"next_section_markdown":\s*"', response) |
|
|
if start_match: |
|
|
start_pos = start_match.end() |
|
|
|
|
|
quote_count = 0 |
|
|
end_pos = start_pos |
|
|
while end_pos < len(response): |
|
|
if response[end_pos] == '"' and ( |
|
|
end_pos == 0 or response[end_pos - 1] != "\\" |
|
|
): |
|
|
break |
|
|
end_pos += 1 |
|
|
if end_pos < len(response): |
|
|
markdown_content = response[start_pos:end_pos] |
|
|
|
|
|
|
|
|
references = [] |
|
|
refs_patterns = [ |
|
|
r'"references":\s*\[(.*?)\]', |
|
|
r'"references":\s*\[(.*?)(?=\s*})', |
|
|
] |
|
|
|
|
|
for pattern in refs_patterns: |
|
|
refs_match = re.search(pattern, response, re.DOTALL) |
|
|
if refs_match: |
|
|
refs_content = refs_match.group(1) |
|
|
|
|
|
ref_items = re.findall(r'"([^"]*<[a-f0-9]{8}>[^"]*)"', refs_content) |
|
|
references = ref_items |
|
|
break |
|
|
|
|
|
if markdown_content: |
|
|
|
|
|
markdown_content = ( |
|
|
markdown_content.replace('\\"', '"') |
|
|
.replace("\\n", "\n") |
|
|
.replace("\\/", "/") |
|
|
) |
|
|
|
|
|
return LongWriterOutput( |
|
|
next_section_markdown=markdown_content, references=references |
|
|
) |
|
|
except Exception as e: |
|
|
logger.error(f"Manual extraction error: {e}") |
|
|
return None |
|
|
|
|
|
return None |
|
|
|
|
|
|
|
|
def extract_hash_strings_from_text(text: str) -> List[str]: |
|
|
"""Extract all <hash_string> patterns from text, preserving order and removing duplicates""" |
|
|
pattern = r"<([a-f0-9]{8})>" |
|
|
matches = re.findall(pattern, text) |
|
|
|
|
|
|
|
|
return list(dict.fromkeys(matches)) |
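# Illustrative call (hash ids are made up):
#     extract_hash_strings_from_text("A <a1b2c3d4> B <a1b2c3d4> C <b2c3d4e5>")
# returns ["a1b2c3d4", "b2c3d4e5"]: first-appearance order is kept, duplicates dropped.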
|
|
|
|
|
|
|
|
def replace_hash_strings_with_numbered_refs( |
|
|
final_draft: str, all_references: List[str] |
|
|
) -> Tuple[str, List[str]]: |
|
|
""" |
|
|
Search final_draft for all <hash str> citations and compare them against all_references.

Each distinct hash str is replaced, in order of first appearance, with a numbered reference: [1], [2], ...

If a hash str has no match in all_references, that reference is dropped and the citation is removed from the body text as well.

Args:

    final_draft: the report body text

    all_references: list of all references, each formatted as "<hash> source"

Returns:

    (updated_final_draft, formatted_references)
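
Example (illustrative; the hash id and reference text are made up):
    draft = "Gene therapy advanced rapidly <a1b2c3d4>."
    refs = ["<a1b2c3d4> Smith J. (2020). Advances in gene therapy."]
    replace_hash_strings_with_numbered_refs(draft, refs)
    # -> ("Gene therapy advanced rapidly [1].",
    #     ["[1] Smith J. (2020). Advances in gene therapy."])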
|
|
""" |
|
|
|
|
|
hash_strings_in_text = extract_hash_strings_from_text(final_draft) |
|
|
|
|
|
|
|
|
hash_to_source = {} |
|
|
for ref in all_references: |
|
|
if ref and "<" in ref and ">" in ref: |
|
|
|
|
|
match = re.match(r"<([a-f0-9]{8})>\s*(.*)", ref) |
|
|
if match: |
|
|
hash_str, source = match.groups() |
|
|
hash_to_source[hash_str] = source.strip() |
|
|
|
|
|
|
|
|
hash_to_number = {} |
|
|
formatted_references = [] |
|
|
ref_counter = 1 |
|
|
|
|
|
for hash_str in hash_strings_in_text: |
|
|
if hash_str in hash_to_source: |
|
|
hash_to_number[hash_str] = ref_counter |
|
|
formatted_references.append(f"[{ref_counter}] {hash_to_source[hash_str]}") |
|
|
ref_counter += 1 |
|
|
|
|
|
|
|
|
updated_final_draft = final_draft |
|
|
for hash_str in hash_strings_in_text: |
|
|
pattern = f"<{hash_str}>" |
|
|
if hash_str in hash_to_number: |
|
|
|
|
|
replacement = f"[{hash_to_number[hash_str]}]" |
|
|
updated_final_draft = updated_final_draft.replace(pattern, replacement) |
|
|
else: |
|
|
|
|
|
updated_final_draft = updated_final_draft.replace(pattern, "") |
|
|
|
|
|
return updated_final_draft, formatted_references |
|
|
|
|
|
|
|
|
async def write_report( |
|
|
original_query: str, |
|
|
report_title: str, |
|
|
report_draft: ReportDraft, |
|
|
ref: List[str], |
|
|
thoughts_callback, |
|
|
language: str = "EN", |
|
|
) -> str: |
|
|
"""Write the final report by iteratively writing each section""" |
|
|
|
|
|
if thoughts_callback is None:
|
|
|
|
|
async def thoughts_callback(thought): |
|
|
pass |
|
|
|
|
|
if language == "CH": |
|
|
report_title_response = await ResearchRunner.run( |
|
|
translation_agent, |
|
|
report_title, |
|
|
) |
|
|
report_title = report_title_response.final_output |
|
|
final_draft = f"# {report_title}\n\n" + "\n\n" |
|
|
else: |
|
|
|
|
|
final_draft = ( |
|
|
f"# {report_title}\n\n" |
|
|
+ "## Table of Contents\n\n" |
|
|
+ "\n".join( |
|
|
[ |
|
|
f"{i + 1}. {section.section_title}" |
|
|
for i, section in enumerate(report_draft.sections) |
|
|
] |
|
|
) |
|
|
+ "\n\n" |
|
|
) |
|
|
all_references = ref |
|
|
|
|
|
for section in report_draft.sections: |
|
|
|
|
|
|
|
|
next_section_draft = await write_next_section( |
|
|
original_query, |
|
|
final_draft, |
|
|
section.section_title, |
|
|
section.section_content, |
|
|
thoughts_callback, |
|
|
language, |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
section_markdown = next_section_draft.next_section_markdown |
|
|
section_markdown = reformat_section_headings(section_markdown) |
|
|
final_draft += section_markdown + "\n\n" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
final_draft, formatted_references = replace_hash_strings_with_numbered_refs( |
|
|
final_draft, all_references |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if formatted_references: |
|
|
final_draft += "## References\n\n" + "\n".join(formatted_references) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return final_draft |
|
|
|
|
|
|
|
|
async def write_report_from_section_drafts( |
|
|
original_query: str, |
|
|
abstract: str, |
|
|
report_title: str, |
|
|
report_draft: ReportDraft, |
|
|
ref: List[str], |
|
|
thoughts_callback, |
|
|
language: str = "EN", |
|
|
) -> str: |
|
|
"""Write the final report by iteratively writing each section""" |
|
|
if thoughts_callback is None:
|
|
|
|
|
async def thoughts_callback(thought): |
|
|
pass |
|
|
|
|
|
if abstract: |
|
|
abstract_string = "# Abstract\n\n" + abstract + "\n\n" |
|
|
else: |
|
|
abstract_string = "" |
|
|
final_draft = ( |
|
|
f"# {report_title}\n\n" |
|
|
+ "## Table of Contents\n\n" |
|
|
+ "\n".join( |
|
|
[ |
|
|
f"{i + 1}. {section.section_title}" |
|
|
for i, section in enumerate(report_draft.sections) |
|
|
] |
|
|
) |
|
|
+ "\n\n" |
|
|
+ abstract_string |
|
|
) |
|
|
all_references = ref |
|
|
for section in report_draft.sections: |
|
|
section_markdown = section.section_content |
|
|
section_markdown = reformat_section_headings(section_markdown) |
|
|
final_draft += section_markdown + "\n\n" |
|
|
final_draft, formatted_references = replace_hash_strings_with_numbered_refs( |
|
|
final_draft, all_references |
|
|
) |
|
|
if formatted_references: |
|
|
final_draft += "## References\n\n" + "\n\n".join(formatted_references) |
|
|
return final_draft |
|
|
|
|
|
|
|
|
def reformat_references( |
|
|
section_markdown: str, section_references: List[str], all_references: List[str] |
|
|
) -> Tuple[str, List[str]]: |
|
|
""" |
|
|
This method gracefully handles the re-numbering, de-duplication and re-formatting of references as new sections are added to the report draft. |
|
|
It takes as input: |
|
|
1. The markdown content of the new section containing inline references in square brackets, e.g. [1], [2] |
|
|
2. The list of references for the new section, e.g. ["[1] Authors, (year). Title", "[2] Authors, (year). Title"]
|
|
3. The list of references covering all prior sections of the report |
|
|
|
|
|
It returns: |
|
|
1. The updated markdown content of the new section with the references re-numbered and de-duplicated, such that they increment from the previous references |
|
|
2. The updated list of references for the full report, to include the new section's references |
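
Illustrative example (URLs are placeholders, not real references):
    section_markdown   = "Alpha [1] and beta [2]."
    section_references = ["[1] https://example.org/a", "[2] https://example.org/b"]
    all_references     = ["[1] https://example.org/b"]
    # "[2]" in the section maps to the existing report reference "[1]"; "[1]" in the
    # section is new and becomes "[2]", so the call returns:
    #   ("Alpha [2] and beta [1].",
    #    ["[1] https://example.org/b", "[2] https://example.org/a"])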
|
|
""" |
|
|
|
|
|
def convert_ref_list_to_map(ref_list: List[str]) -> Dict[str, str]: |
|
|
ref_map = {} |
|
|
for ref in ref_list: |
|
|
try: |
|
|
ref_num = int(ref.split("]")[0].strip("[")) |
|
|
url = ref.split("]", 1)[1].strip() |
|
|
ref_map[url] = ref_num |
|
|
except ValueError: |
|
|
print(f"Invalid reference format: {ref}") |
|
|
continue |
|
|
return ref_map |
|
|
|
|
|
section_ref_map = convert_ref_list_to_map(section_references) |
|
|
report_ref_map = convert_ref_list_to_map(all_references) |
|
|
section_to_report_ref_map = {} |
|
|
|
|
|
report_urls = set(report_ref_map.keys()) |
|
|
ref_count = max(report_ref_map.values() or [0]) |
|
|
for url, section_ref_num in section_ref_map.items(): |
|
|
if url in report_urls: |
|
|
section_to_report_ref_map[section_ref_num] = report_ref_map[url] |
|
|
else: |
|
|
|
|
|
ref_count += 1 |
|
|
section_to_report_ref_map[section_ref_num] = ref_count |
|
|
all_references.append(f"[{ref_count}] {url}") |
|
|
|
|
|
def replace_reference(match): |
|
|
|
|
|
ref_num = int(match.group(1)) |
|
|
|
|
|
mapped_ref_num = section_to_report_ref_map.get(ref_num) |
|
|
if mapped_ref_num: |
|
|
return f"[{mapped_ref_num}]" |
|
|
return "" |
|
|
|
|
|
|
|
|
section_markdown = re.sub(r"\[(\d+)\]", replace_reference, section_markdown) |
|
|
|
|
|
return section_markdown, all_references |
|
|
|
|
|
|
|
|
def reformat_section_headings(section_markdown: str) -> str: |
|
|
""" |
|
|
Reformat the headings of a section to be consistent with the report, by rebasing the section's heading to be a level-2 heading |
|
|
|
|
|
E.g. this: |
|
|
# Big Title |
|
|
Some content |
|
|
## Subsection |
|
|
|
|
|
Becomes this: |
|
|
## Big Title |
|
|
Some content |
|
|
### Subsection |
|
|
""" |
|
|
|
|
|
if not section_markdown.strip(): |
|
|
return section_markdown |
|
|
|
|
|
|
|
|
first_heading_match = re.search(r"^(#+)\s", section_markdown, re.MULTILINE) |
|
|
if not first_heading_match: |
|
|
return section_markdown |
|
|
|
|
|
|
|
|
first_heading_level = len(first_heading_match.group(1)) |
|
|
level_adjustment = 2 - first_heading_level |
|
|
|
|
|
def adjust_heading_level(match): |
|
|
hashes = match.group(1) |
|
|
content = match.group(2) |
|
|
new_level = max(2, len(hashes) + level_adjustment) |
|
|
return "#" * new_level + " " + content |
|
|
|
|
|
|
|
|
return re.sub( |
|
|
r"^(#+)\s(.+)$", adjust_heading_level, section_markdown, flags=re.MULTILINE |
|
|
) |
|
|
|