# Program name: Initial node generator
# Author: Jiayi Chen
# Date: 2023/7/27
# Description: This program will automatically generate the initial node of a
#   traffic scenario from a map and a scenario description.
import copy
import json
import os
import sys
import time

from langchain.chat_models import ChatOpenAI, AzureChatOpenAI
from langchain.prompts.chat import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)
from langchain.schema.messages import BaseMessage
from langchain.memory.chat_message_histories import ChatMessageHistory

from generate_scenario.generate_init import map_prompts


def parse_json(s: str) -> object:
    """Decode *s* as JSON; return *s* itself unchanged if it is not valid JSON.

    Callers detect failure with ``isinstance(result, str)``. Note that a JSON
    document whose top-level value is a string is therefore indistinguishable
    from a decode failure.
    """
    try:
        return json.loads(s)
    except (ValueError, TypeError):
        # json.JSONDecodeError subclasses ValueError; TypeError covers
        # non-string input such as None.
        return s


def check_format(history: ChatMessageHistory, chat: AzureChatOpenAI,
                 prompt_template: ChatPromptTemplate, formaterror: str,
                 result: BaseMessage, max_retries: int = 5) -> bool:
    """Re-prompt *chat* with *formaterror* until its reply parses as JSON.

    Args:
        history: Conversation so far. It is modified in place.
        chat: Chat model used for the re-prompts.
        prompt_template: Template whose ``text`` variable receives
            *formaterror*; its last formatted message must be the human turn.
        formaterror: Instruction sent when the reply is not JSON.
        result: The model reply to validate first.
        max_retries: Maximum number of re-prompts (default 5, as before).

    Returns:
        True if a JSON reply was obtained. In that case *history* is reset to
        its original messages followed by the final valid AI reply, and the
        failed attempts are dropped. False otherwise; all attempts then stay
        in *history*.
    """
    keep_history = copy.deepcopy(history)
    history.add_ai_message(result.content)
    attempts = 0
    while isinstance(parse_json(result.content), str) and attempts < max_retries:
        print(result.content)
        messages = prompt_template.format_prompt(text=formaterror).to_messages()
        # Put the prior turns before the final human turn so the model sees its
        # malformed answer. (Passing history=... as a keyword would only forward
        # an unknown parameter to the completion API.)
        messages = messages[:-1] + list(history.messages) + messages[-1:]
        result = chat(messages)
        history.add_user_message(formaterror)
        history.add_ai_message(result.content)
        attempts += 1

    if isinstance(parse_json(result.content), str):
        return False

    # Mutate in place: rebinding the parameter would not affect the caller's object.
    history.messages[:] = keep_history.messages
    history.add_ai_message(result.content)
    return True


def decode_judge(scenario: str, judgement: str, curr_score, curr_ans):
    """Parse a judge reply and keep the best-scoring scenario seen so far.

    Args:
        scenario: The candidate answer that was judged.
        judgement: Judge reply; expected to be JSON with a ``"score"`` and an
            optional list ``"mistake"`` of objects carrying ``"description"``.
        curr_score: Best score so far.
        curr_ans: Answer that achieved *curr_score*.

    Returns:
        ``(best_score, best_answer, lst_error)``. *lst_error* lists the
        mistakes followed by a regeneration instruction, or is ``""`` when
        there are none.

    Exits the process with status -1 if *judgement* is not valid JSON.
    """
    parsed = parse_json(judgement)
    if isinstance(parsed, str):
        print("Cannot decode the string, it's not in json format")
        sys.exit(-1)

    score = parsed["score"]
    mistakes = parsed.get("mistake") or []
    lst_error = "".join(
        f"mistake {index}: " + data["description"] + "\n"
        for index, data in enumerate(mistakes, start=1)
    )
    if lst_error != "":
        lst_error += ("You should correct the mistake and regenerate the initial node. "
                      "Please output in only one json file")
    if score > curr_score:
        return score, scenario, lst_error
    return curr_score, curr_ans, lst_error


def find_json(answer: str):
    """Return the text from the first ``{`` to the last ``}`` in *answer*.

    Returns ``""`` if either brace is missing, or if the last ``}`` precedes
    the first ``{``.
    """
    start = answer.find("{")
    end = answer.rfind("}")
    if start == -1 or end == -1:
        return ""
    return answer[start:end + 1]


# Azure OpenAI connection settings (set at import time).
os.environ["OPENAI_API_TYPE"] = "azure"
os.environ["OPENAI_API_VERSION"] = "2023-03-15-preview"
os.environ["OPENAI_API_BASE"] = "https://scenariogen.openai.azure.com/"
# SECURITY: the API key is no longer hard-coded in source. Provide
# OPENAI_API_KEY through the environment. The key that was previously
# committed here must be treated as leaked and rotated.


class get_initial_point(object):
    """Ask an LLM to place the vehicles of a scenario description on a given map.

    A generator model and a judge model are configured with map-specific
    system prompts. Only the generator is used by ``get_init`` at present.
    """

    @staticmethod
    def _make_chat_model():
        """Build one Azure chat model with the shared deployment settings."""
        return AzureChatOpenAI(
            deployment_name="gpt35turbo16k",
            model_name="gpt-35-turbo-16k",
            temperature=0.9,
            max_tokens=3000,
        )

    def __init__(self, map_prompt):
        """Set up models and prompt templates for the map described by *map_prompt*."""
        self.generator = self._make_chat_model()
        self.judge = self._make_chat_model()

        self.curr_gene_text = (map_prompts.generator_prompts
                               + "\nNext is the current map information\n"
                               + map_prompt + "\n" + map_prompts.generator_answer)
        curr_judge_text = (map_prompts.judge_prompts
                           + "\nNext is the current map information\n"
                           + map_prompt + "\n" + map_prompts.judge_answer)
        self.curr_judge_text = curr_judge_text

        # NOTE(review): from_template treats "{...}" as template variables. It
        # presumably relies on map_prompts / map_prompt escaping literal JSON
        # braces as "{{ }}"; confirm, or format_prompt(text=...) will fail.
        human_template = "{text}"
        system_prompt_gene = SystemMessagePromptTemplate.from_template(self.curr_gene_text)
        system_prompt_judge = SystemMessagePromptTemplate.from_template(curr_judge_text)
        human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)
        self.gene_prompt = ChatPromptTemplate.from_messages([system_prompt_gene, human_message_prompt])
        self.judge_prompt = ChatPromptTemplate.from_messages([system_prompt_judge, human_message_prompt])

        self.gene_formaterror = ("Please check the format of your answer. You should only "
                                 "output the json format answer. e.g.\n"
                                 + map_prompts.generator_answer)
        self.judge_formaterror = ("Please check the format of your answer. You should only "
                                  "output the json format answer. e.g.\n"
                                  + map_prompts.judge_answer)

    def get_init(self, description):
        """Generate initial vehicle locations for *description*.

        Prints the prompt and the raw model reply. Returns the JSON-looking
        substring of the reply (see ``find_json``), or ``""`` if none is found.
        """
        scena_gen = description + "\nYou have to select the location of these vehicle.Use #inner_monologue to indicate the thinking process and #json_answer to indicate the json format answer.Anything that are not relevent are not required\n"
        print(scena_gen)
        result_gene = self.generator(self.gene_prompt.format_prompt(text=scena_gen).to_messages())
        print(result_gene.content)
        return find_json(result_gene.content)
        # NOTE: an iterative judge pass is disabled. It sent the generator's
        # answer to self.judge, scored it with decode_judge, and regenerated up
        # to a few times. The leftover code was unreachable and referenced an
        # undefined name; recover it from version control if it is revived.