| |
|
| | import re |
| | import json |
| | import copy |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
class SurveyManager:
    """Static helpers to render and update a structured survey dict.

    A survey dict has the top-level keys "title", "abstract", "introduction",
    "sections" and "conclusion" (see ``BASE_SURVEY_STRUCTURE``). Entries of
    "sections" may contain a "subsections" list, whose entries may contain a
    "subsubsections" list. A part that has not been written yet carries a
    "plan" string. Once written, it carries "content". For "abstract" and
    "conclusion", the part is instead replaced by the content string itself.
    """

    BASE_SURVEY_STRUCTURE = {
        "title": "",
        "abstract": "",
        "introduction": {
            "content": ""
        },
        "sections": [],
        "conclusion": ""
    }

    # Prefix of each path component of an update position, outermost first.
    _LEVEL_PREFIXES = ("section-", "subsection-", "subsubsection-")
    # Key holding the child list of a node at each depth below "sections".
    _CHILD_KEYS = ("subsections", "subsubsections")

    def __init__(self):
        pass

    @staticmethod
    def _parse_indices(update_pos):
        """Return the 1-based integer indices encoded in a section path.

        ``"section-2/subsection-3"`` -> ``[2, 3]``. Prefixes are matched
        case-insensitively.

        Raises:
            ValueError: if the path has more than three components or a
                component does not parse as an integer after its prefix.
        """
        keys = update_pos.split("/")
        if len(keys) > len(SurveyManager._LEVEL_PREFIXES):
            raise ValueError("unsupported update_pos keys")
        return [
            int(key.lower().split(prefix)[-1])
            for key, prefix in zip(keys, SurveyManager._LEVEL_PREFIXES)
        ]

    @staticmethod
    def parse_update_pos(update_pos):
        """Normalize the position an answer wants to update.

        Accepted forms:
          (1) "title", "abstract", "introduction", "conclusion" or "plan",
              returned unchanged;
          (2) "section-i", "section-i/subsection-j" or
              "section-i/subsection-j/subsubsection-k" (prefixes matched
              case-insensitively), returned in lower-case canonical form.

        Raises:
            ValueError: for paths deeper than three levels or non-integer
                indices.
        """
        if update_pos in ["title", "abstract", "introduction", "conclusion", "plan"]:
            return update_pos
        indices = SurveyManager._parse_indices(update_pos)
        return "/".join(
            f"{prefix}{index}"
            for prefix, index in zip(SurveyManager._LEVEL_PREFIXES, indices)
        )

    @staticmethod
    def _node_name(node):
        """Return a node's heading, preferring its "name" key over "title"."""
        # Decided per node, so sections and their children may use different keys.
        return node["name"] if "name" in node else node["title"]

    @staticmethod
    def _to_one_line(string):
        """Return the text to render for a survey part.

        For a dict, its non-empty "content" is used. Otherwise its "plan" is
        flattened to one line and prefixed with "[PLAN] ". Falsy values become
        "". Other values (e.g. written strings) are returned unchanged.
        """
        if isinstance(string, dict):
            if "content" in string and string["content"]:
                return SurveyManager._to_one_line(string["content"])
            return "[PLAN] " + string.get("plan", "").replace("\n", " ").strip()
        if not string:
            return ""
        return string

    @staticmethod
    def convert_survey_dict_to_str(current_survey):
        """Render the full survey as Markdown headings followed by their text.

        Parts that are missing (or fail to render) are shown as "None".
        Returns "There is no survey." for an empty dict.
        """
        if current_survey == {}:
            return "There is no survey."

        string = ""
        try:
            content = SurveyManager._to_one_line(current_survey["title"])
            string += f"# {content}\n"
        except Exception:
            string += "# Title: None\n"

        try:
            content = SurveyManager._to_one_line(current_survey["abstract"])
            string += f"## Abstract\n{content}\n"
        except Exception:
            string += "## Abstract\nNone\n"

        try:
            content = SurveyManager._to_one_line(current_survey["introduction"])
            string += f"## Introduction\n{content}\n"
        except Exception:
            string += "## Introduction\nNone\n"

        if "sections" in current_survey:
            for section in current_survey["sections"]:
                name = SurveyManager._node_name(section)
                string += f"## {name}\n{SurveyManager._to_one_line(section)}\n"
                if "subsections" in section:
                    for subsection in section["subsections"]:
                        name = SurveyManager._node_name(subsection)
                        string += f"### {name}\n{SurveyManager._to_one_line(subsection)}\n"
                        if "subsubsections" in subsection:
                            for subsubsection in subsection["subsubsections"]:
                                name = SurveyManager._node_name(subsubsection)
                                string += f"#### {name}\n{SurveyManager._to_one_line(subsubsection)}\n"

        try:
            content = SurveyManager._to_one_line(current_survey["conclusion"])
            string += f"## Conclusion\n{content}\n"
        except Exception:
            # Same heading shape as the other fallbacks (no trailing colon).
            string += "## Conclusion\nNone\n"

        return string

    @staticmethod
    def _abbr_one_line(string, abbr=True):
        """Return a one-line status summary of a survey part.

        Written text is prefixed with "[OK] " and, when ``abbr`` is true and
        the text is longer than 50 characters, cut to 50 characters plus
        "...". An unwritten dict part shows "[PLAN] " plus its plan. Empty
        parts give "".
        """
        if isinstance(string, dict):
            if "content" in string and string["content"]:
                return SurveyManager._abbr_one_line(string["content"], abbr=abbr)
            if "plan" in string:
                return "[PLAN] " + string["plan"].replace("\n", " ").strip()
            return ""
        if not string:
            return ""
        flat = string.replace("\n", " ").strip()
        if abbr and len(string) > 50:
            return "[OK] " + flat[:50] + "..."
        return "[OK] " + flat

    @staticmethod
    def convert_survey_dict_to_abbr_str(current_survey):
        """Render a compact outline showing which parts are written or planned.

        Title and abstract are shown in full, and the other parts are
        abbreviated. Sections are numbered from 1. Returns
        "There is no survey." for an empty dict.
        """
        if current_survey == {}:
            return "There is no survey."

        string = ""
        try:
            content = SurveyManager._abbr_one_line(current_survey["title"], abbr=False)
            string += f"# Title: {content}\n"
        except Exception:
            string += "# Title: None\n"

        try:
            content = SurveyManager._abbr_one_line(current_survey["abstract"], abbr=False)
            string += f"# Abstract: {content}\n"
        except Exception:
            string += "# Abstract: None\n"

        try:
            content = SurveyManager._abbr_one_line(current_survey["introduction"])
            string += f"# Introduction: {content}\n"
        except Exception:
            string += "# Introduction: None\n"

        if "sections" in current_survey:
            for i, section in enumerate(current_survey["sections"]):
                name = SurveyManager._node_name(section)
                string += f"# Section-{i+1} [{name}]: {SurveyManager._abbr_one_line(section)}\n"
                if "subsections" in section:
                    for j, subsection in enumerate(section["subsections"]):
                        name = SurveyManager._node_name(subsection)
                        string += f" ## Subsection-{j+1} [{name}]: {SurveyManager._abbr_one_line(subsection)}\n"
                        if "subsubsections" in subsection:
                            for k, subsubsection in enumerate(subsection["subsubsections"]):
                                name = SurveyManager._node_name(subsubsection)
                                string += f" ### Subsubsection-{k+1} [{name}]: {SurveyManager._abbr_one_line(subsubsection)}\n"

        try:
            content = SurveyManager._abbr_one_line(current_survey["conclusion"])
            string += f"# Conclusion: {content}\n"
        except Exception:
            string += "# Conclusion: None\n"

        return string

    @staticmethod
    def update_one_section(sections, i, content):
        """Set ``sections[i]["content"]``.

        Returns True on success, or False if ``i`` is outside
        ``[0, len(sections))``.
        """
        if 0 <= i < len(sections):
            sections[i]["content"] = content
            return True
        return False

    @staticmethod
    def update_current_survey(current_survey, answer) -> bool:
        """Apply ``answer`` ({"update": position, "content": ...}) in place.

        Positions:
          - "plan": fills an empty survey with a deep copy of the plan dict
            (fails if the survey is not empty);
          - "abstract"/"conclusion": replaced by the content string;
          - "introduction": replaced by {"content": content};
          - "section-i/subsection-j/subsubsection-k" (1-based, 1-3 levels):
            sets "content" on that node.

        Returns:
            True if the survey was updated. False for an unknown or
            out-of-range position (including indices below 1) or a malformed
            answer.
        """
        try:
            update_pos, content = answer["update"], answer["content"]

            if update_pos == "plan":
                if current_survey != {}:
                    return False
                for key, value in content.items():
                    current_survey[key] = copy.deepcopy(value)
                return True

            if update_pos in ["conclusion", "abstract"]:
                if update_pos not in current_survey:
                    return False
                current_survey[update_pos] = content
                return True

            if update_pos == "introduction":
                if update_pos not in current_survey:
                    return False
                current_survey[update_pos] = {"content": content}
                return True

            indices = [index - 1 for index in SurveyManager._parse_indices(update_pos)]
            # Reject 0/negative positions explicitly: Python's negative list
            # indexing would otherwise silently target nodes from the end.
            if any(index < 0 for index in indices):
                return False
            nodes = current_survey["sections"]
            for index, child_key in zip(indices[:-1], SurveyManager._CHILD_KEYS):
                nodes = nodes[index][child_key]
            return SurveyManager.update_one_section(nodes, indices[-1], content)
        except Exception:
            # Any malformed answer or missing node is reported as a failed update.
            return False
| | |
| |
|
| | from prompts import * |
class PromptManger:
    """Central holder of the prompt templates used by BufferManager.

    The template constants are provided by the ``prompts`` module. Placeholders
    such as "<user_query>" inside them are substituted by the prompt builders.
    """

    # System message sent at every generation step.
    system_prompt = SYSTEM_PROMPT_0415_BUFFER

    # User template for the very first step (no trajectory yet).
    user_prompt_v0 = USER_PROMPT_v0_0424_BUFFER

    # User template for every later step.
    user_prompt = USER_PROMPT_0415_BUFFER
| |
|
| |
|
| |
|
| |
|
| |
|
class BufferManager:
    """Manage prompts and responses produced during the rollout phase.

    The collected data is kept for later training. ``batch_rollout_data`` is a
    list with one entry per rollout. Each query is repeated ``repeat_n`` times,
    and an entry's index in the list is its ``running_id``::

        {
            "query": str,                 # user query text
            "uid": ...,                   # id shared by all repeats of a query
            "state": {
                "done": bool,             # whether this rollout has finished
                "current_survey": dict,   # structured survey (see SurveyManager)
                # "score" / "format_score" are added by update_all_scores /
                # update_all_format_scores.
            },
            "trajectory": [               # one entry per generation step
                {
                    "step": int,
                    "original_response": str,   # raw model output
                    "answer_thought": str,      # <think> text before <answer>
                    "answer": dict,             # {"update": str, "content": str | dict} or {}
                    "tool_call_thought": str,   # <think> text before <tool_call>
                    "tool_call": dict,          # {"name": ..., "arguments": {...}} or {}
                    "search_keywords": ...,     # taken from the env feedback
                    "summarys": list[str],      # paper summaries from the env feedback
                    "update_success": bool,     # whether the answer was applied
                },
                ...
            ],
            "history_messages": [         # per step: [system, user, assistant]
                [...],
                ...
            ],
        }
    """

    # LaTeX citation commands (\cite, \citep, \citet, \nocite, ... but not
    # \citestyle); group 1 is the comma-separated key list.
    _CITE_RE = re.compile(r"\\\w*cite(?!style)\w*\{(.+?)\}")
    _NOCITE_RE = re.compile(r"\\nocite\{(.+?)\}")
    # Placeholder keys such as "#1" are never real references.
    _PLACEHOLDER_RE = re.compile(r"^#\d+$")

    def __init__(self, prompts, repeat_n: int = 1):
        """Create ``repeat_n`` empty rollouts for every query in ``prompts``.

        ``prompts`` is read through ``batch['input_ids'].size(0)`` (batch size)
        and ``non_tensor_batch['uid' | 'raw_prompt' | 'ground_truth']``. Each
        query is the "content" of the last message of its raw prompt.
        """
        # Presumably incremented by the caller; this class only records it.
        self.step = 0
        self.batch_rollout_data = []
        self.running_ids = []
        batch_size = prompts.batch['input_ids'].size(0)
        uids = prompts.non_tensor_batch['uid']
        raw_prompts = prompts.non_tensor_batch['raw_prompt']
        ground_truths = prompts.non_tensor_batch['ground_truth']

        querys = [raw_prompts[i_batch][-1]["content"] for i_batch in range(batch_size)]

        assert len(querys) == len(uids)
        # ground_truths is not stored, but still bounds the zip as before.
        for query, uid, _ground_truth in zip(querys, uids, ground_truths):
            for _ in range(repeat_n):
                self.batch_rollout_data.append({
                    "query": query,
                    "uid": uid,
                    "state": {
                        "done": False,
                        "current_survey": {}
                    },
                    "trajectory": [],
                    "history_messages": [],
                })

    @staticmethod
    def _build_system_prompt():
        """Return the system prompt template."""
        return PromptManger.system_prompt

    @staticmethod
    def _build_user_prompt_v0(query, current_survey):
        """Fill the first-step template ("<user_query>", "<init_survey>")."""
        prompt = PromptManger.user_prompt_v0.replace("<user_query>", query)
        return prompt.replace("<init_survey>", SurveyManager.convert_survey_dict_to_abbr_str(current_survey))

    @staticmethod
    def _latest_summarys(trajs):
        """Return the summaries of the most recent step that has any, else []."""
        for traj in reversed(trajs):
            if traj["summarys"]:
                return traj["summarys"]
        return []

    @staticmethod
    def _build_user_prompt(query, current_survey, trajs):
        """Fill the per-step user template from the survey and the last step.

        Substitutes "<user_query>", "<current_survey>", "<last_step_thought>",
        "<last_step_tool_call>" (JSON) and "<summarys>". The summaries come
        from the most recent step that returned any.
        """
        last_traj = trajs[-1]

        prompt = PromptManger.user_prompt.replace("<user_query>", query)
        prompt = prompt.replace("<current_survey>", SurveyManager.convert_survey_dict_to_abbr_str(current_survey))

        last_thought = last_traj["tool_call_thought"]
        if last_thought == "":
            last_thought = "Your last thought is not available, please give new plan"
        prompt = prompt.replace("<last_step_thought>", last_thought)
        prompt = prompt.replace("<last_step_tool_call>", json.dumps(last_traj["tool_call"]))

        summarys = BufferManager._latest_summarys(trajs)
        if not summarys:
            prompt = prompt.replace("<summarys>", "There is no result.")
        else:
            prompt = prompt.replace("<summarys>", f"There are {len(summarys)} results:\n\n" + "\n\n".join(summarys))

        return prompt

    @staticmethod
    def _find_next_section(current_survey):
        """Return the update position of the first part still lacking content.

        Parts are visited in writing order: abstract, introduction, each
        section (its subsections are visited only once the section has
        content, and likewise for subsubsections), then the conclusion.
        Returns "" when every part has been written.
        """
        abstract = current_survey["abstract"]
        if isinstance(abstract, dict) and "content" not in abstract:
            return "abstract"
        if "content" not in current_survey["introduction"]:
            return "introduction"
        for i, section in enumerate(current_survey.get("sections", []), start=1):
            if "content" not in section:
                return f"section-{i}"
            for j, subsection in enumerate(section.get("subsections", []), start=1):
                if "content" not in subsection:
                    return f"section-{i}/subsection-{j}"
                for k, subsubsection in enumerate(subsection.get("subsubsections", []), start=1):
                    if "content" not in subsubsection:
                        return f"section-{i}/subsection-{j}/subsubsection-{k}"
        # Checked only after all sections are written. Previously this test
        # sat in an unreachable elif, so a missing conclusion went unnoticed.
        conclusion = current_survey["conclusion"]
        if isinstance(conclusion, dict) and "content" not in conclusion:
            return "conclusion"
        return ""

    @staticmethod
    def _build_user_prompt_force_correct(query, current_survey, trajs):
        """Build the next prompt after a failed update, steering the model.

        Overwrites ``trajs[-1]["tool_call_thought"]`` (in place) with a hint
        naming the next part to write. The hint is "plan" for an empty survey,
        or a request for more information if no search results exist yet.
        It then delegates to ``_build_user_prompt``.
        """
        if current_survey == {}:
            now_section = "plan"
        else:
            now_section = BufferManager._find_next_section(current_survey)

        if now_section:
            trajs[-1]["tool_call_thought"] = f"Next I will provide {now_section}"
        else:
            trajs[-1]["tool_call_thought"] = "Next I will finalize the survey."

        if now_section == "plan" and not BufferManager._latest_summarys(trajs):
            trajs[-1]["tool_call_thought"] = "I need to get enough information."

        return BufferManager._build_user_prompt(query, current_survey, trajs)

    @staticmethod
    def _check_finalize(query, current_survey, trajs):
        """Return True only if a plan exists and every part has content.

        ``query`` and ``trajs`` are accepted for interface compatibility and
        are not used.
        """
        if current_survey == {}:
            return False
        return BufferManager._find_next_section(current_survey) == ""

    def build_prompt_for_generator(self):
        """Build chat messages for every unfinished rollout.

        Also records the matching indices in ``self.running_ids`` for
        ``update_trajectory``. Each message list is stored in the rollout's
        "history_messages" and also returned. Because both places hold the
        same list object, the assistant reply appended later by
        ``update_trajectory`` shows up in both.
        """
        total_messages = []
        self.running_ids = []
        for running_id, data in enumerate(self.batch_rollout_data):
            if data["state"]["done"]:
                continue
            query = data["query"]
            current_survey = data["state"]["current_survey"]
            trajectory = data["trajectory"]
            if not trajectory:
                user_prompt = BufferManager._build_user_prompt_v0(query, current_survey)
            elif trajectory[-1]["update_success"]:
                user_prompt = BufferManager._build_user_prompt(query, current_survey, trajectory)
            else:
                user_prompt = BufferManager._build_user_prompt_force_correct(query, current_survey, trajectory)
            messages = [
                {
                    "role": "system",
                    "content": BufferManager._build_system_prompt(),
                },
                {
                    "role": "user",
                    "content": user_prompt,
                },
            ]
            data["history_messages"].append(messages)
            total_messages.append(messages)
            self.running_ids.append(running_id)
        return total_messages

    def update_all_scores(self, scores):
        """Store ``scores[i]`` as ``state["score"]`` of rollout i."""
        assert len(scores) == len(self.batch_rollout_data)
        for score, log in zip(scores, self.batch_rollout_data):
            log["state"]["score"] = score

    def update_all_format_scores(self, scores):
        """Store ``scores[i]`` as ``state["format_score"]`` of rollout i."""
        assert len(scores) == len(self.batch_rollout_data)
        for score, log in zip(scores, self.batch_rollout_data):
            log["state"]["format_score"] = score

    def update_trajectory(self, model_responses, env_feedbacks):
        """Record one generation step for every rollout in ``self.running_ids``.

        Args:
            model_responses: dicts as produced by ``parse_generator_response``
                (keys read: "true", "answer", "original_response",
                "answer_thought", "tool_call_thought", "tool_call").
            env_feedbacks: dicts with the keys "done", "search_keywords" and
                "summarys".

        A well-formed answer is applied to the rollout's survey. A "done"
        signal is revoked when the survey still has unwritten parts.
        """
        assert len(self.running_ids) == len(model_responses)
        assert len(self.running_ids) == len(env_feedbacks)

        for running_id, response, feedback in zip(self.running_ids, model_responses, env_feedbacks):
            data = self.batch_rollout_data[running_id]
            state = data["state"]
            state["done"] = feedback["done"]
            current_survey = state["current_survey"]
            answer = response["answer"]

            update_success = False
            if response["true"]:
                if current_survey != {}:
                    if answer:
                        update_success = SurveyManager.update_current_survey(current_survey, answer)
                else:
                    # Without a plan, an answer is applied only if the previous
                    # prompt contained search results. An empty answer after a
                    # prompt without results counts as success.
                    last_prompt = data["history_messages"][-1][1]["content"]
                    had_no_results = "There is no result" in last_prompt
                    if answer and not had_no_results:
                        update_success = SurveyManager.update_current_survey(current_survey, answer)
                    elif had_no_results and not answer:
                        update_success = True

            data["trajectory"].append({
                "step": self.step,
                "original_response": response["original_response"],
                "answer_thought": response["answer_thought"],
                "answer": response["answer"],
                "tool_call_thought": response["tool_call_thought"],
                "tool_call": response["tool_call"],
                "search_keywords": feedback["search_keywords"],
                "summarys": feedback["summarys"],
                "update_success": update_success and response["true"],
            })

            data["history_messages"][-1].append({
                "role": "assistant",
                "content": response["original_response"],
            })

            if state["done"] and not BufferManager._check_finalize(
                    data["query"], current_survey, data["trajectory"]):
                state["done"] = False

    @staticmethod
    def _split_bibkeys(key_groups):
        """Split comma-separated citation key lists into a set of real keys.

        Keys are stripped first. Empty keys, "*" and placeholders like "#1"
        are dropped.
        """
        keys = set()
        for group in key_groups:
            for bib in group.split(","):
                bib = bib.strip()
                if bib and bib != "*" and not BufferManager._PLACEHOLDER_RE.match(bib):
                    keys.add(bib)
        return keys

    @staticmethod
    def match_reference(text: str):
        """Return the citation keys used in ``text``, excluding \\nocite keys.

        The order of the returned list is unspecified.
        """
        cited = BufferManager._split_bibkeys(BufferManager._CITE_RE.findall(text))
        # Set difference instead of repeated remove(): a key appearing in
        # several \nocite commands previously raised KeyError.
        nocited = BufferManager._split_bibkeys(BufferManager._NOCITE_RE.findall(text))
        return list(cited - nocited)

    @staticmethod
    def _require(condition):
        """Raise ValueError if ``condition`` is false.

        Used instead of ``assert``, which is stripped under ``python -O`` and
        would silently disable validation.
        """
        if not condition:
            raise ValueError("invalid generator response")

    @staticmethod
    def _validate_outline_node(node):
        """Require ``node`` to be a dict with string "plan" and "title"."""
        require = BufferManager._require
        require(isinstance(node, dict))
        require("plan" in node)
        require("title" in node)
        require(isinstance(node["plan"], str))
        require(isinstance(node["title"], str))

    @staticmethod
    def _validate_plan(plan):
        """Validate a plan answer, dropping an optional "instruction" key in place.

        The plan must have exactly the keys "abstract", "introduction",
        "conclusion", "sections" and "title". Each section, subsection and
        subsubsection must be a node with string "plan" and "title", and no
        section may be titled "Methodology".

        Raises:
            ValueError: on any violation.
            KeyError, TypeError: on structurally malformed input.
        """
        require = BufferManager._require
        require(isinstance(plan, dict))
        plan.pop("instruction", None)
        keys = ["abstract", "introduction", "conclusion", "sections", "title"]
        for key in keys:
            require(key in plan)
        for key in plan:
            require(key in keys)
            if key == "sections":
                require(isinstance(plan[key], list))
                for section in plan[key]:
                    BufferManager._validate_outline_node(section)
                    require(section["title"] != "Methodology")
                    if "subsections" in section:
                        require(isinstance(section["subsections"], list))
                        for subsection in section["subsections"]:
                            BufferManager._validate_outline_node(subsection)
                            # Must test the subsection itself (was wrongly
                            # testing the parent section).
                            if "subsubsections" in subsection:
                                require(isinstance(subsection["subsubsections"], list))
                                for subsubsection in subsection["subsubsections"]:
                                    BufferManager._validate_outline_node(subsubsection)
            elif key == "title":
                require(isinstance(plan[key], str))
            else:
                require(isinstance(plan[key], dict))
                require("plan" in plan[key])

    @staticmethod
    def _parse_answer(answer_text):
        """Decode and validate the JSON inside <answer>.

        Returns the answer dict ({} is accepted as-is), with "update"
        normalized via ``SurveyManager.parse_update_pos``. Raises on anything
        invalid.
        """
        answer = json.loads(answer_text)
        if answer == {}:
            return answer
        BufferManager._require(isinstance(answer["update"], str))
        answer["update"] = SurveyManager.parse_update_pos(answer["update"])
        if answer["update"] == "plan":
            BufferManager._require(isinstance(answer["content"], dict))
            BufferManager._validate_plan(answer["content"])
        else:
            BufferManager._require(isinstance(answer["content"], str))
        return answer

    @staticmethod
    def _parse_tool_call(tool_text):
        """Decode and validate the JSON inside <tool_call>; raise if invalid."""
        tool_call = json.loads(tool_text)
        BufferManager._require(tool_call["name"] in ["search_engine", "finalize"])
        if tool_call["name"] == "search_engine":
            BufferManager._require(isinstance(tool_call["arguments"]["query"], list))
        return tool_call

    @staticmethod
    def parse_generator_response(response):
        """Parse a generator response into its thoughts, answer and tool call.

        Expected format::

            Current Update:
            <think> [Your Thoughts]: str </think>
            <answer> {"update": str, "content": str | dict} </answer>

            Next Plan:
            <think> [Your Thoughts]: str </think>
            <tool_call> {"name": "search_engine" | "finalize", "arguments": {"query": [...]}} </tool_call>

        Design notes (acted on by the caller, not here): on a parse failure
        the step is regenerated with a hint; on success, a search tool call
        triggers a search request and "finalize" ends the task.

        Returns:
            A dict with "original_response", "answer_thought", "answer"
            ({} if missing/invalid), "tool_call_thought", "tool_call" ({} if
            missing/invalid) and "true". "true" is set only when both answer
            and tool call are valid and the response contains no CJK
            characters.
        """
        extracted_result = {
            "original_response": response
        }

        think_pattern = r"<think>(.*?)</think>"
        answer_pattern = r"<answer>(.*?)</answer>"
        tool_pattern = r"<tool_call>(.*?)</tool_call>"

        # Text between "Current Update:" and "Next Plan:" (the whole response
        # when the markers are absent).
        current_update = response.split("Current Update:")[-1].split("Next Plan:")[0]

        think_match = re.search(think_pattern, current_update, re.DOTALL)
        extracted_result["answer_thought"] = think_match.group(1).strip() if think_match else ""

        answer, has_answer = {}, False
        answer_match = re.search(answer_pattern, current_update, re.DOTALL)
        if answer_match:
            try:
                answer = BufferManager._parse_answer(answer_match.group(1).strip())
                has_answer = True
            except Exception:
                answer = {}
        extracted_result["answer"] = answer

        # Text after the first "Next Plan:" (up to any second one), else after
        # the first "</answer>", else the whole response.
        if "Next Plan:" in response:
            next_plan = response.split("Next Plan:")[1]
        elif "</answer>" in response:
            next_plan = response.split("</answer>")[1]
        else:
            next_plan = response

        think_match = re.search(think_pattern, next_plan, re.DOTALL)
        extracted_result["tool_call_thought"] = think_match.group(1).strip() if think_match else ""

        tool_call, has_tool_call = {}, False
        tool_match = re.search(tool_pattern, next_plan, re.DOTALL)
        if tool_match:
            try:
                tool_call = BufferManager._parse_tool_call(tool_match.group(1).strip())
                has_tool_call = True
            except Exception:
                tool_call = {}
        extracted_result["tool_call"] = tool_call

        # Responses containing CJK ideographs are rejected.
        has_chinese = re.search(r"[\u4e00-\u9fa5]", response) is not None
        extracted_result["true"] = has_answer and has_tool_call and not has_chinese

        return extracted_result
| | |
| |
|
class BufferManager_V2(BufferManager):
    """BufferManager variant constructed from a plain list of query strings.

    It does not read a prompt batch object and does not call
    ``BufferManager.__init__``. Instead, each query gets the synthesized uid
    ``"query_<position>"`` and is repeated ``repeat_n`` times.
    """

    def __init__(self, querys, repeat_n=1):
        self.step = 0
        self.batch_rollout_data = []
        self.running_ids = []

        for position, query in enumerate(querys):
            print("CURRENT QUERY: ", query)
            rollout_uid = f"query_{position}"
            # A fresh dict per repeat, so rollouts never share mutable state.
            self.batch_rollout_data.extend(
                {
                    "query": query,
                    "uid": rollout_uid,
                    "state": {
                        "done": False,
                        "current_survey": {},
                    },
                    "trajectory": [],
                    "history_messages": [],
                }
                for _ in range(repeat_n)
            )
| |
|
| |
|