# Program name: Initial node generator
# Author: Jiayi Chen
# Date: 2023/7/27
# Description: This program will automatically generate the initial node of a traffic scenario by map and scenario description
import copy
import json
import os
import time

from langchain.chat_models import ChatOpenAI, AzureChatOpenAI
from langchain.memory.chat_message_histories import ChatMessageHistory
from langchain.prompts.chat import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)
from langchain.schema.messages import BaseMessage

from generate_scenario.generate_init import map_prompts
def parse_json(s: str) -> object:
    """Try to decode *s* as JSON.

    Returns the decoded Python object on success, or *s* unchanged when it
    is not valid JSON — callers detect failure by checking for a str result.
    """
    try:
        return json.loads(s)
    except (TypeError, ValueError):
        # json.JSONDecodeError is a ValueError; TypeError covers non-string
        # input. The original bare `except:` also swallowed KeyboardInterrupt
        # and SystemExit, which we must not do.
        return s
def check_format(history: ChatMessageHistory, chat: AzureChatOpenAI, prompt_template: ChatPromptTemplate, formaterror: str, result: BaseMessage) -> bool:
    """Retry a chat call until the model reply parses as JSON.

    Records *result* into *history*, then re-prompts the model with
    *formaterror* up to 5 times while the reply is still not valid JSON.
    Returns True once a JSON-parseable reply is obtained, False otherwise.
    """
    snapshot = copy.deepcopy(history)  # pristine history, taken before any retry traffic
    history.add_ai_message(result.content)
    attempts = 0
    while attempts < 5 and type(parse_json(result.content)) is str:
        print(result.content)
        result = chat(prompt_template.format_prompt(text=formaterror).to_messages(), history=history)
        history.add_user_message(formaterror)
        history.add_ai_message(result.content)
        attempts += 1
    if type(parse_json(result.content)) is str:
        # Exhausted all retries without a valid JSON reply.
        return False
    # NOTE(review): rebinding the local name `history` does not propagate to
    # the caller's object — the caller still sees the retry exchanges. Confirm
    # whether an in-place reset was intended here.
    history = copy.deepcopy(snapshot)
    history.add_ai_message(result.content)
    return True
def decode_judge(scenario: str, judgement: str, curr_score, curr_ans):
    """Decode the judge model's JSON verdict and keep the best scenario so far.

    scenario:   the generator answer that *judgement* scored.
    judgement:  JSON text with a "score" number and a "mistake" list of
                {"description": ...} entries.
    curr_score, curr_ans: best score/answer seen so far.

    Returns (best_score, best_answer, lst_error) where lst_error is a
    regeneration prompt listing the mistakes ("" when there are none).
    Exits the process (SystemExit) when *judgement* is not valid JSON.
    """
    try:
        parsed = json.loads(judgement)
    except (TypeError, ValueError):
        parsed = judgement
    if isinstance(parsed, str):
        print("Cannot decode the string, it's not in json format")
        # raise SystemExit instead of the site-builtin exit(), which may be
        # unavailable in non-interactive interpreters.
        raise SystemExit(-1)
    score = parsed["score"]
    # Build the mistake list with join rather than quadratic `+=` concatenation.
    lst_error = "".join(
        f"mistake {index}: " + data["description"] + "\n"
        for index, data in enumerate(parsed["mistake"], start=1)
    )
    if lst_error != "":
        lst_error += "You should correct the mistake and regenerate the initial node. Please output in only one json file"
    if score > curr_score:
        return score, scenario, lst_error
    return curr_score, curr_ans, lst_error
def find_json(answer: str):
    """Return the substring of *answer* from the first '{' to the last '}'.

    Returns "" when either brace is absent. Note this is a purely textual
    extraction — the result is not guaranteed to be valid JSON.
    """
    open_idx = answer.find("{")
    close_idx = answer.rfind("}")
    if open_idx == -1 or close_idx == -1:
        return ""
    return answer[open_idx:close_idx + 1]
# --- Azure OpenAI configuration ---------------------------------------------
# setdefault lets a deployment override these via real environment variables
# instead of having them silently clobbered at import time.
os.environ.setdefault("OPENAI_API_TYPE", "azure")
os.environ.setdefault("OPENAI_API_VERSION", "2023-03-15-preview")
os.environ.setdefault("OPENAI_API_BASE", "https://scenariogen.openai.azure.com/")
# SECURITY: hard-coded API key committed to source control — rotate this key
# and load it from the environment or a secrets manager instead.
os.environ.setdefault("OPENAI_API_KEY", "02c96d0cde1f4b07a7e77948d57aaaec")
class get_initial_point(object):
    """Generate the initial node of a traffic scenario via Azure OpenAI.

    Holds a "generator" chat model that proposes initial nodes and a
    "judge" chat model intended to score them (the judge/refine loop is
    not wired into get_init yet).
    """

    def __init__(self, map_prompt):
        """Build the chat models and prompt templates for one map.

        map_prompt: textual map information, spliced into both the
        generator and judge system prompts.
        """
        # Two separate handles so generator and judge can be configured
        # independently, even though they currently share one deployment.
        self.generator = AzureChatOpenAI(
            deployment_name="gpt35turbo16k",
            model_name="gpt-35-turbo-16k",
            temperature=0.9,
            max_tokens=3000,
        )
        self.judge = AzureChatOpenAI(
            deployment_name="gpt35turbo16k",
            model_name="gpt-35-turbo-16k",
            temperature=0.9,
            max_tokens=3000,
        )
        self.curr_gene_text = map_prompts.generator_prompts + "\nNext is the current map information\n" + map_prompt + "\n" + map_prompts.generator_answer
        curr_judge_text = map_prompts.judge_prompts + "\nNext is the current map information\n" + map_prompt + "\n" + map_prompts.judge_answer
        human_template = "{text}"
        system_prompt_gene = SystemMessagePromptTemplate.from_template(self.curr_gene_text)
        system_prompt_judge = SystemMessagePromptTemplate.from_template(curr_judge_text)
        human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)
        self.gene_prompt = ChatPromptTemplate.from_messages([system_prompt_gene, human_message_prompt])
        self.judge_prompt = ChatPromptTemplate.from_messages([system_prompt_judge, human_message_prompt])
        # Canned correction prompts sent back when a model reply is not JSON.
        self.gene_formaterror = "Please check the format of your answer. You should only output the json format answer. e.g.\n" + map_prompts.generator_answer
        self.judge_formaterror = "Please check the format of your answer. You should only output the json format answer. e.g.\n" + map_prompts.judge_answer

    def get_init(self, description):
        """Ask the generator for an initial node and return its JSON part.

        description: natural-language scenario description.
        Returns the substring of the model answer between the first '{'
        and the last '}' (see find_json), or "" when no JSON is present.
        """
        scena_gen = description + "\nYou have to select the location of these vehicle.Use #inner_monologue to indicate the thinking process and #json_answer to indicate the json format answer.Anything that are not relevent are not required\n"
        print(scena_gen)
        result_gene = self.generator(self.gene_prompt.format_prompt(text=scena_gen).to_messages())
        print(result_gene.content)
        ans = result_gene.content
        return find_json(ans)
        # NOTE(review): the original file continued after this return with
        # unreachable statements referencing an undefined `json_answer_gene`,
        # followed by a large commented-out judge/refine loop. Both were dead
        # code and have been removed; recover them from version control if the
        # generate→judge→regenerate loop is to be completed.