|
|
|
|
|
|
|
|
| import pickle
|
| import json
|
| from pathlib import Path
|
| from tqdm import tqdm
|
| from scripts.tools.tool_libraries import FuncAgent
|
| from scripts.tools.tool_prompts import get_system_prompt, get_ego_prompts, get_detection_prompt
|
|
|
|
|
class AgentThink:
    """Replay the tool calls recorded in a DriveLMM-o1 chain-of-thought sample.

    An instance is bound to one frame token. On construction it loads the
    pickled tool-result dictionary for that frame and wraps it in a
    ``FuncAgent``, whose methods are the tools invoked by ``tool_call``.
    """

    def __init__(self, token: str = None, split: str = 'train',
                 data_path: str = 'DriveLMM-o1-main/data/tool_results',
                 drivelmm_json_file: str = 'Drive-MLLM-main/data/DriveLMMo1/DriveLMMo1_TEST.json',
                 model_name: str = "qwen2.5-VL", verbose: bool = False) -> None:
        """
        Initialize AgentThink class for processing driving scenario data.

        Args:
            token (str): Frame token; tool results are read from
                ``<data_path>/<val|train>/<token>.pkl``.
            split (str): Data split name. Any value containing ``"val"``
                selects the ``val`` folder; everything else selects ``train``.
            data_path (str): Root directory of the pickled tool results.
            drivelmm_json_file (str): Path to the DriveLMM JSON file. It is
                stored on the instance but not read here.
            model_name (str): Name of the model being used (used in output
                file names).
            verbose (bool): Whether to show detailed logs.

        Raises:
            OSError: If the tool-result pickle for ``token`` cannot be opened.
        """
        self.token = token
        self.split = split
        self.data_path = data_path
        self.drivelmm_json_file = drivelmm_json_file
        self.model_name = model_name
        self.verbose = verbose

        self.file_name = self._tool_result_path(self.token)
        self.data_dict = self._load_pickle(self.file_name)

        self.func_agent = FuncAgent(self.data_dict)

        # Per-sample call budgets.
        # NOTE(review): only the detection budget is enforced (in
        # get_tool_results); the other three are currently unused.
        self.num_call_detection_times = 3
        self.num_call_prediction_times = 1
        self.num_call_occupancy_times = 1
        self.num_call_map_times = 1

    def _tool_result_path(self, token) -> Path:
        """Return ``<data_path>/<val|train>/<token>.pkl`` for this instance's split."""
        folder_name = "val" if "val" in self.split else "train"
        return Path(self.data_path) / folder_name / f"{token}.pkl"

    @staticmethod
    def _load_pickle(path):
        """Load and return the object pickled at ``path``.

        Security: ``pickle.load`` can execute arbitrary code embedded in the
        file. Only point this at trusted, locally generated tool-result files.
        """
        with open(path, "rb") as f:
            return pickle.load(f)

    def _preprocess_tool_results(self, json_data):
        """
        Attach each sample's pickled tool results and dump the result to JSON.

        Each sample's ``idx`` is expected to look like ``<scene>_<frame>``. The
        frame token selects the pickle to load.

        Args:
            json_data: Iterable of sample dicts. Each sample is mutated in
                place: a ``'tool_results'`` key is added.

        Returns:
            List of the same (mutated) sample dicts.

        Side effects:
            Writes ``cot_<split>_<model_name>.json`` to the current working
            directory.
        """
        new_json_data = []
        for sample in json_data:
            frame_token = sample['idx'].split('_')[1]
            sample['tool_results'] = self._load_pickle(self._tool_result_path(frame_token))
            new_json_data.append(sample)

        # NOTE(review): the loaded pickles must be JSON-serializable for this
        # dump to succeed.
        output_file = f'cot_{self.split}_{self.model_name}.json'
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(new_json_data, f, indent=4)

        return new_json_data

    @staticmethod
    def _normalize_tool_args(tool_name, raw_args):
        """Map CoT ``parameters`` onto keyword arguments for a FuncAgent method.

        Args:
            tool_name: Tool name from the chain node.
            raw_args: The node's ``parameters``. This is normally a positional
                list; a non-empty dict is passed through as keyword arguments.

        Returns:
            ``(tool_name, kwargs)``, or ``None`` when the parameters cannot be
            mapped. The tool name is rewritten for open-vocabulary detection.
        """
        if isinstance(raw_args, dict):
            return (tool_name, raw_args) if raw_args else None
        if not isinstance(raw_args, list) or not raw_args:
            # Missing or empty parameters: nothing sensible to call with.
            return None
        if raw_args[0] == '':
            # A single empty-string placeholder means "call with no arguments".
            return tool_name, {}

        first = raw_args[0]
        if 'occupancy' in tool_name:
            if len(raw_args) < 2:
                return None  # the timestep argument is missing
            locations = tuple(first) if isinstance(first, list) else first
            return tool_name, {'locations': [locations], 'timestep': raw_args[1]}
        if 'location' in tool_name:
            locations = tuple(first) if isinstance(first, list) else first
            return tool_name, {'locations': [locations]}
        if 'open' in tool_name:
            # Every open-vocabulary variant maps onto one FuncAgent method.
            return 'get_open_world_vocabulary_detection', {'object_names': first}
        return tool_name, {'object_ids': first}

    def tool_call(self, response_message):
        """
        Execute a tool call based on the response message.

        Args:
            response_message: Chain node shaped like
                ``{'Tool': {'function_name': str, 'parameters': list | dict}}``.

        Returns:
            ``{"name", "args", "prompt"}``, or ``None`` in any of these cases:
            the node is malformed; the tool name is empty or ``'none'``; the
            parameters cannot be mapped; the tool does not exist or is not
            callable; or the tool raises or returns something that is not a
            ``(prompt, data)`` pair.
        """
        try:
            tool = response_message['Tool']
            tool_name = tool['function_name']
            raw_args = tool['parameters']
        except (KeyError, TypeError, IndexError):
            return None

        if not isinstance(tool_name, str) or tool_name in ('', 'none'):
            return None

        normalized = self._normalize_tool_args(tool_name, raw_args)
        if normalized is None:
            return None
        tool_name, function_args = normalized

        try:
            function_to_call = getattr(self.func_agent, tool_name)
        except AttributeError:
            return None

        if not callable(function_to_call):
            print(f"Function {tool_name} is not callable!")
            return None

        try:
            # Unpacking lives inside the try so a malformed return value is
            # treated like any other tool failure.
            tool_prompt, _tool_result_data = function_to_call(**function_args)
        except Exception as exc:
            # Best-effort: a failing tool is skipped rather than aborting the
            # whole sample, but the failure is reported in verbose mode.
            if self.verbose:
                print(f"Tool {tool_name} failed: {exc!r}")
            return None

        if tool_prompt is None:
            tool_prompt = ""

        tool_response = {
            "name": tool_name,
            "args": function_args,
            "prompt": tool_prompt,
        }

        if self.verbose:
            print(f"Tool: {tool_name}")
            print(f"Args: {function_args}")
            print(f"Prompt: {tool_prompt}")

        return tool_response

    @staticmethod
    def _compose_system_message(init_system_message, ego_prompts):
        """Return the base system prompt and ego prompts, each followed by a newline.

        ``ego_prompts=None`` is treated as an empty string.
        """
        ego_text = "" if ego_prompts is None else ego_prompts
        return init_system_message + "\n" + ego_text + "\n"

    def get_tool_results(self, sample, ego_prompts=None):
        """
        Replay the tool calls listed in a sample's chain of thought.

        Detection tools (any name containing ``'detection'``) are capped at
        ``self.num_call_detection_times`` per sample. Chain nodes without a
        string ``function_name`` are skipped.

        Args:
            sample: Data sample; ``sample['cot_data']['Chain']`` must exist.
            ego_prompts: Optional ego prompt text appended to the system
                prompt. ``None`` is treated as empty.

        Returns:
            Tuple of (full_messages, system_message, tool_responses), where
            only successful tool calls contribute to the two lists.
        """
        system_message = self._compose_system_message(get_system_prompt(), ego_prompts)

        if self.verbose:
            print("System Message:", system_message)
            print("Detection Prompt:", get_detection_prompt())

        full_messages = []
        tool_responses = []
        num_detection_calls = 0
        for chain_node in sample['cot_data']['Chain']:
            try:
                tool_name = chain_node['Tool']['function_name']
            except (KeyError, TypeError):
                continue
            if not isinstance(tool_name, str):
                continue

            if 'detection' in tool_name:
                # Count attempts, not successes, against the budget.
                num_detection_calls += 1
                if num_detection_calls > self.num_call_detection_times:
                    continue

            tool_response = self.tool_call(chain_node)
            if tool_response is None:
                continue

            full_messages.append({
                'role': 'function',
                'name': tool_response['name'],
                'content': tool_response['prompt'],
            })
            tool_responses.append(tool_response)

        return full_messages, system_message, tool_responses
|
|
|
|
|
def main(drivelmm_json_file="/path/to/final_cot_test_gpt-4.1-mini.json",
         data_path="/path/to/tool_results", split='val', model_name='Qwen2.5-VL'):
    """
    Main function to process DriveLMM JSON data with AgentThink.

    For every sample, the frame token (the second ``_``-separated field of
    ``sample['idx']``) selects the pickled tool results. The sample's
    chain-of-thought tool calls are replayed against them. Each sample is
    mutated in place: ``'tool_result'`` and ``'system_prompts'`` keys are
    added.

    Args:
        drivelmm_json_file: Path to DriveLMM JSON file (a list of samples).
        data_path: Root directory of the pickled tool results; the output
            JSON is also written here.
        split: Data split passed to AgentThink (selects the val/train folder).
        model_name: Model name used in the output file name.

    Returns:
        Processed JSON data (list of the mutated samples).

    Side effects:
        Writes ``<data_path>/cot_<split>_<model_name>.json``.
    """
    with open(drivelmm_json_file, "r", encoding="utf-8") as file:
        json_data = json.load(file)

    new_json_data = []
    for sample in tqdm(json_data, desc="Processing JSON samples"):
        frame_token = sample['idx'].split('_')[1]

        agent = AgentThink(
            token=frame_token,
            split=split,
            data_path=data_path,
            drivelmm_json_file=drivelmm_json_file,
            model_name=model_name,
        )

        ego_prompts = get_ego_prompts(agent.data_dict)
        _full_messages, system_prompts, tool_responses = agent.get_tool_results(
            sample=sample,
            ego_prompts=ego_prompts,
        )

        sample['tool_result'] = tool_responses
        sample['system_prompts'] = system_prompts
        new_json_data.append(sample)

    # Built from the parameters rather than the loop's last `agent`, so an
    # empty input no longer raises NameError.
    output_file = Path(data_path) / f'cot_{split}_{model_name}.json'
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(new_json_data, f, indent=4)

    return new_json_data
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the batch pipeline with its default arguments.
    main()