# Upload metadata: uploaded by puffyyy ("Upload 101 files", commit 33bdf0e)
# Program name: Initial node generator
# Author: Jiayi Chen
# Date: 2023/7/27
# Description: This program automatically generates the initial node of a traffic scenario from a map and a scenario description
from langchain.chat_models import ChatOpenAI, AzureChatOpenAI
from langchain.prompts.chat import (
ChatPromptTemplate,
SystemMessagePromptTemplate,
HumanMessagePromptTemplate,
)
from langchain.schema.messages import BaseMessage
from langchain.memory.chat_message_histories import ChatMessageHistory
import time
import json
import copy
import os
from generate_scenario.generate_init import map_prompts
def parse_json(s:str)->object:
    """Try to decode *s* as JSON.

    Args:
        s: Text expected to contain a JSON document.

    Returns:
        The decoded Python object on success. On failure the original
        argument ``s`` is returned unchanged, so callers detect failure with
        ``isinstance(result, str)``. Note that a valid JSON string literal
        (e.g. ``'"abc"'``) also decodes to a ``str`` and is therefore
        indistinguishable from a failure.
    """
    try:
        return json.loads(s)
    except (ValueError, TypeError, RecursionError):
        # ValueError covers json.JSONDecodeError (malformed text), TypeError
        # covers non-str/bytes input such as None, and RecursionError covers
        # pathologically nested documents. Anything else (KeyboardInterrupt,
        # SystemExit, ...) now propagates instead of being silently swallowed.
        return s
def check_format(history:ChatMessageHistory, chat:AzureChatOpenAI, prompt_template:ChatPromptTemplate, formaterror:str, result:BaseMessage, max_retries:int=5)->bool:
    """Re-prompt *chat* until its reply parses as JSON.

    Args:
        history: Conversation transcript. It is mutated in place: on success it
            is restored to its state on entry plus the final (valid) AI reply;
            on failure it keeps every rejected reply and correction request.
        chat: Chat model called with a list of messages; returns a message
            with a ``.content`` attribute.
        prompt_template: Template rendered with ``text=formaterror``. Its last
            message is assumed to be the human turn (true for the
            system + "{text}" templates built in this file).
        formaterror: Correction request sent when a reply is not JSON.
        result: The model reply to validate first.
        max_retries: Maximum number of correction rounds (default 5).

    Returns:
        True if a reply that ``parse_json`` can decode was obtained, else False.
    """
    keep_history = copy.deepcopy(history)
    history.add_ai_message(result.content)
    retries = 0
    while isinstance(parse_json(result.content), str) and retries < max_retries:
        print(result.content)
        prompt_messages = prompt_template.format_prompt(text=formaterror).to_messages()
        # Put the transcript (including the rejected answer) before the final
        # human turn so the model actually sees what it is asked to fix. The
        # former `history=` keyword was forwarded to the chat model as a
        # request parameter rather than being added to the conversation.
        messages = prompt_messages[:-1] + list(history.messages) + prompt_messages[-1:]
        result = chat(messages)
        history.add_user_message(formaterror)
        history.add_ai_message(result.content)
        retries += 1
    if isinstance(parse_json(result.content), str):
        return False
    # Restore the caller's object in place; rebinding the local name `history`
    # (as before) would leave the caller's transcript untouched.
    history.messages[:] = keep_history.messages
    history.add_ai_message(result.content)
    return True
def decode_judge( scenario:str, judgement:str, curr_score, curr_ans ):
    """Parse a judge reply and keep the best-scoring answer seen so far.

    Args:
        scenario: The candidate answer that *judgement* evaluates.
        judgement: Raw judge output. It must decode to a JSON object with a
            ``"score"`` comparable to *curr_score* and an optional
            ``"mistake"`` list whose items carry a ``"description"`` (schema
            inferred from the keys read here).
        curr_score: Best score recorded so far.
        curr_ans: Answer that achieved *curr_score*.

    Returns:
        ``(best_score, best_answer, feedback)``. *scenario* replaces the
        current best only if its score is strictly higher. *feedback* lists
        the numbered mistakes plus a regeneration instruction, or is "" when
        no mistakes were reported.

    Raises:
        SystemExit: (status -1) when *judgement* does not decode to a JSON
            object.
        KeyError: when the decoded object has no ``"score"``.
    """
    import sys

    judgement = parse_json(judgement)
    # parse_json returns the input string on failure; a JSON list/number would
    # otherwise crash on the key lookups below.
    if not isinstance(judgement, dict):
        print("Cannot decode the string, it's not in json format")
        sys.exit(-1)
    score = judgement["score"]
    mistakes = judgement.get("mistake") or []
    lst_error = "".join(
        f"mistake {index}: {data['description']}\n"
        for index, data in enumerate(mistakes, start=1)
    )
    if lst_error:
        lst_error += "You should correct the mistake and regenerate the initial node. Please output in only one json file"
    if score > curr_score:
        return score, scenario, lst_error
    return curr_score, curr_ans, lst_error
def find_json(answer:str):
    """Return the span of *answer* from the first "{" to the last "}" inclusive.

    This strips surrounding prose from an LLM reply. The returned span is not
    validated as JSON. The result is "" when either brace is missing, or when
    the last "}" comes before the first "{".
    """
    opening = answer.find("{")
    closing = answer.rfind("}")
    if opening < 0 or closing < 0:
        return ""
    # When closing < opening this slice is empty, i.e. "".
    return answer[opening:closing + 1]
# Azure OpenAI connection settings, exported through the environment
# (presumably picked up by AzureChatOpenAI below — confirm for the installed
# langchain version). setdefault lets values already exported by the user
# take precedence over these defaults.
os.environ.setdefault("OPENAI_API_TYPE", "azure")
os.environ.setdefault("OPENAI_API_VERSION", "2023-03-15-preview")
os.environ.setdefault("OPENAI_API_BASE", "https://scenariogen.openai.azure.com/")
# SECURITY: the API key is a secret and must not be committed to source
# control. Export OPENAI_API_KEY before running. A key that used to be
# hard-coded here should be treated as compromised and rotated.
class get_initial_point(object):
    """Generate the initial vehicle placement (initial node) for one map.

    Two identically configured Azure OpenAI chat models are created: a
    *generator* that proposes the initial node and a *judge* meant to critique
    it. Only the generator is currently used by ``get_init``.

    Attributes:
        generator, judge: AzureChatOpenAI clients.
        curr_gene_text: Full system prompt text given to the generator.
        gene_prompt, judge_prompt: ChatPromptTemplates made of the system
            prompt followed by a human turn rendered from ``{text}``.
        gene_formaterror, judge_formaterror: Follow-up messages asking the
            model to answer in JSON only.
    """

    def __init__(self, map_prompt, deployment_name="gpt35turbo16k", model_name="gpt-35-turbo-16k", temperature=0.9, max_tokens=3000):
        """Build the models and prompt templates for the map in *map_prompt*.

        Args:
            map_prompt: Text describing the current map. It is embedded in
                both the generator and judge system prompts.
            deployment_name, model_name, temperature, max_tokens: Settings
                shared by both chat models (defaults match the original
                hard-coded values).
        """
        self.generator = self._make_chat_model(deployment_name, model_name, temperature, max_tokens)
        self.judge = self._make_chat_model(deployment_name, model_name, temperature, max_tokens)
        map_section = "\nNext is the current map information\n" + map_prompt + "\n"
        self.curr_gene_text = map_prompts.generator_prompts + map_section + map_prompts.generator_answer
        curr_judge_text = map_prompts.judge_prompts + map_section + map_prompts.judge_answer
        human_template = "{text}"
        # NOTE(review): from_template parses its argument as an f-string
        # template, so any literal "{" / "}" in map_prompt or in the
        # map_prompts texts (e.g. JSON answer examples) becomes a template
        # variable unless it is escaped as "{{" / "}}". Confirm the prompt
        # texts are escaped accordingly.
        system_prompt_gene = SystemMessagePromptTemplate.from_template(self.curr_gene_text)
        system_prompt_judge = SystemMessagePromptTemplate.from_template(curr_judge_text)
        human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)
        self.gene_prompt = ChatPromptTemplate.from_messages([system_prompt_gene, human_message_prompt])
        self.judge_prompt = ChatPromptTemplate.from_messages([system_prompt_judge, human_message_prompt])
        self.gene_formaterror = "Please check the format of your answer. You should only output the json format answer. e.g.\n" + map_prompts.generator_answer
        self.judge_formaterror = "Please check the format of your answer. You should only output the json format answer. e.g.\n" + map_prompts.judge_answer

    @staticmethod
    def _make_chat_model(deployment_name, model_name, temperature, max_tokens):
        """Build one AzureChatOpenAI client with the given settings."""
        return AzureChatOpenAI(
            deployment_name=deployment_name,
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
        )

    def get_init(self, description):
        """Ask the generator for an initial node that matches *description*.

        Prints the request and the raw reply.

        Returns:
            The part of the generator's reply from the first "{" to the last
            "}" (via find_json), or "" if the reply has no such span. The text
            is not validated as JSON here.
        """
        scena_gen = description + "\nYou have to select the location of these vehicle.Use #inner_monologue to indicate the thinking process and #json_answer to indicate the json format answer.Anything that are not relevent are not required\n"
        print(scena_gen)
        result_gene = self.generator(self.gene_prompt.format_prompt(text=scena_gen).to_messages())
        print(result_gene.content)
        # TODO: a judge-feedback refinement loop is not wired up. The intent
        # appears to be: send the candidate to self.judge with
        # self.judge_prompt, validate replies with check_format, score them
        # with decode_judge, and regenerate from the listed mistakes until no
        # mistakes remain. An unreachable draft of this loop, which referenced
        # the undefined name `json_answer_gene`, used to follow this return.
        return find_json(result_gene.content)