Commit c366528 · 1 Parent(s): bfead6e
Update model_inference.py

model_inference.py CHANGED (+55 -40)
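In brief: the commit replaces the old three-response summarizer, which prompted gpt-3.5-turbo-16k-0613 with a fixed [Response 1]/[Response 2]/[Response 3] template, with a construct_message / summarize_message pair built on gpt-3.5-turbo-0613; drops the auth_token and round parameters from Inference and fixes rounds at 2; spells out the HuggingFace Inference API payload ("inputs" plus "max_new_tokens": 256); and completes several previously broken statements (the agent loop now iterates over range(len(agent_contexts)), prompt.format receives instruction by keyword, and generate_answer's retry and return lines are finished). Deleted lines that the page truncated beyond recovery are marked with … below.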
@@ -13,32 +13,42 @@ def load_json(prompt_path, endpoint_path):

     return prompt_dict, endpoint_dict

-def construct_message(…
-    …
+def construct_message(agent_context, instruction, idx):
+    prefix_string = "Here are a list of opinions from different agents: "
+
+    prefix_string = prefix_string + agent_context + "\n\n Write a summary of the different opinions from each of the individual agent."
+
+    message = [{"role": "user", "content": prefix_string}]
+
+    try:
+        completion = openai.ChatCompletion.create(
+            model="gpt-3.5-turbo-0613",
+            messages=message,
+            max_tokens=256,
+            n=1
+        )['choices'][0]['message']['content']
+    except:
+        print("retrying ChatGPT due to an error......")
+        time.sleep(5)
+        return construct_message(agent_context, instruction, idx)
+
+    prefix_string = f"Here is a summary of responses from other agents: {completion}"
+    prefix_string = prefix_string + "\n\n Use this summarization carefully as additional advice, can you provide an updated answer? Make sure to state your answer at the end of the response." + instruction
+    return prefix_string

-    …
-    summarize_prompt = f"[Response 1]: {contexts[0]}\n[Response 2]: {contexts[1]}\nResponse 3: {contexts[2]}\n\nThese are response of each model to a certain question. Summarize comprehensively without compromising the meaning of each response."
+def summarize_message(agent_contexts, instruction, idx):
+    prefix_string = "Here are a list of opinions from different agents: "

-    …
-    ]
+    for agent in agent_contexts:
+        agent_response = agent[-1]["content"]
+        response = "\n\n One agent response: ```{}```".format(agent_response)

-    …
-        model="gpt-3.5-turbo-16k-0613",
-        messages=message,
-        max_tokens=256,
-        n=1
-    )
+        prefix_string = prefix_string + response

-    prefix_string = …
-    …
+    prefix_string = prefix_string + "\n\n Write a summary of the different opinions from each of the individual agent."
+    completion = construct_message(prefix_string, instruction, idx)
+
+    return completion

 def generate_question(agents, question):
     agent_contexts = [[{"model": agent, "content": question}] for agent in agents]
@@ -47,7 +57,7 @@ def generate_question(agents, question):

     return agent_contexts, content

-def Inference(model_list, question, API_KEY, auth_token, round, cot):
+def Inference(model_list, question, API_KEY, cot):
     if len(model_list) != 3:
         raise ValueError("Please choose just '3' models! Neither more nor less!")

@@ -58,16 +68,21 @@ def Inference(model_list, question, API_KEY, auth_token, round, cot):
     def generate_answer(model, formatted_prompt):
         API_URL = endpoint_dict[model]["API_URL"]
         headers = endpoint_dict[model]["headers"]
-        payload = {…
+        payload = {
+            "inputs": formatted_prompt,
+            "parameters": {
+                "max_new_tokens": 256
+            }
+        }
         try:
             resp = requests.post(API_URL, json=payload, headers=headers)
             response = resp.json()
         except:
             print("retrying due to an error......")
             time.sleep(5)
-            return generate_answer(…
+            return generate_answer(model, formatted_prompt)

-        return {"model": model, "content": response[0]["generated_text"]…
+        return {"model": model, "content": response[0]["generated_text"]}

     def prompt_formatting(model, instruction, cot):
         if model == "alpaca" or model == "orca":
@@ -77,37 +92,37 @@ def Inference(model_list, question, API_KEY, auth_token, round, cot):

         if cot:
             instruction += "Let's think step by step."
-        …
-        return {"model": model, "content": prompt.format(instruction)}
-        …
+
+        return {"model": model, "content": prompt.format(instruction=instruction)}
+
     agents = len(model_list)
-    rounds = …
+    rounds = 2

-    …
+    agent_contexts, content = generate_question(agents=model_list, question=args.question)

-    …
+    message = []

     # Debate
     for debate in range(rounds+1):
         # Refer to the summarized previous response
         if debate != 0:
-            message…
-            for i in range(agent_contexts):
+            message.append(summarize_message(agent_contexts, content, 2 * debate - 1))
+            for i in range(len(agent_contexts)):
                 agent_contexts[i].append(prompt_formatting(agent_contexts[i][-1]["model"], message, args.cot))

         # Generate new response based on summarized response
         for agent_context in agent_contexts:
-            completion = generate_answer(agent_context[-1]["model"], agent_context[-1]["content"]
+            completion = generate_answer(agent_context[-1]["model"], agent_context[-1]["content"])
             agent_context.append(completion)

     models_response = {
-        f"{…
-        f"{…
-        f"{…
+        f"{model_list[0]}": [agent_contexts[0][1]["content"], agent_contexts[0][3]["content"], agent_contexts[0][-1]["content"]],
+        f"{model_list[1]}": [agent_contexts[1][1]["content"], agent_contexts[1][3]["content"], agent_contexts[1][-1]["content"]],
+        f"{model_list[2]}": [agent_contexts[2][1]["content"], agent_contexts[2][3]["content"], agent_contexts[2][-1]["content"]]
     }
     response_summarization = [
-        …
+        message[0], message[1]
     ]
-    generated_description…
+    generated_description = {"question": content, "agent_response": models_response, "summarization": response_summarization}

     return generated_description
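One design note on the error handling above: both construct_message and generate_answer retry by calling themselves from a bare except, so a persistent failure (revoked token, renamed endpoint) loops forever while the call stack grows. Below is a minimal sketch of a capped alternative, assuming the same endpoint_dict structure and response shape as the diff; the function name, max_retries, and the RuntimeError are illustrative additions, not part of the commit.

    import time
    import requests

    def generate_answer_bounded(model, formatted_prompt, endpoint_dict, max_retries=5):
        # Illustrative variant of generate_answer: a capped retry loop instead of
        # unbounded recursion. endpoint_dict and max_retries are assumptions here.
        API_URL = endpoint_dict[model]["API_URL"]
        headers = endpoint_dict[model]["headers"]
        payload = {"inputs": formatted_prompt, "parameters": {"max_new_tokens": 256}}

        for attempt in range(max_retries):
            try:
                resp = requests.post(API_URL, json=payload, headers=headers)
                response = resp.json()
                # Same response shape the diff assumes: a list whose first element
                # carries "generated_text".
                return {"model": model, "content": response[0]["generated_text"]}
            except Exception as err:
                print(f"retrying due to an error ({err}), attempt {attempt + 1}/{max_retries}......")
                time.sleep(5)
        raise RuntimeError(f"{model} endpoint failed after {max_retries} attempts")

The same bound could be applied to the ChatCompletion retry inside construct_message.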
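For orientation, here is a sketch of how the updated entry point might be driven. Everything below is an assumption layered on the diff: the import path, the "llama" model name (only alpaca and orca are visible in prompt_formatting), and both key values. Note also that the added lines still read args.question and args.cot inside Inference, which raises a NameError unless an argparse args object exists at module scope; the sketch assumes those resolve to the question and cot parameters.

    import openai
    from model_inference import Inference  # hypothetical import path for this Space's module

    # Assumption: openai.ChatCompletion in the diff reads the key from openai.api_key.
    openai.api_key = "sk-..."  # OpenAI key for the summarization calls

    result = Inference(
        model_list=["llama", "alpaca", "orca"],  # exactly three models, as the ValueError enforces
        question="What is the capital of Australia?",
        API_KEY="hf_...",  # HuggingFace token, per the signature
        cot=True,          # appends "Let's think step by step." to each instruction
    )

    # Per the diff, the result bundles the question, each model's answers across
    # the debate rounds, and the two intermediate summaries.
    print(result["agent_response"])
    print(result["summarization"])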