Commit 823c850
Parent(s): 4cb1f6b

Update model_inference.py

Files changed: model_inference.py (+25 -25)
model_inference.py CHANGED

@@ -47,31 +47,6 @@ def generate_question(agents, question):
 
     return agent_contexts, content
 
-def generate_answer(model, formatted_prompt):
-    API_URL = endpoint_dict[model]
-    headers = {"Authorization": f"Bearer {args.auth_token}"}
-    payload = {"inputs": formatted_prompt}
-    try:
-        resp = requests.post(API_URL, json=payload, headers=headers)
-        response = resp.json()
-    except:
-        print("retrying due to an error......")
-        time.sleep(5)
-        return generate_answer(API_URL, headers, payload)
-
-    return {"model": model, "content": response[0]["generated_text"].split(prompt_dict[model]["response_split"])[-1]}
-
-def prompt_formatting(model, instruction, cot):
-    if model == "alpaca" or model == "orca":
-        prompt = prompt_dict[model]["prompt_no_input"]
-    else:
-        prompt = prompt_dict[model]["prompt"]
-
-    if cot:
-        instruction += "Let's think step by step."
-
-    return {"model": model, "content": prompt.format(instruction)}
-
 def Inference(model_list, question, API_KEY, auth_token, round, cot):
     if len(model_list) != 3:
         raise ValueError("Please choose just '3' models! Neither more nor less!")
@@ -80,6 +55,31 @@ def Inference(model_list, question, API_KEY, auth_token, round, cot):
 
     prompt_dict, endpoint_dict = load_json("src/prompt_template.json", "src/inference_endpoint.json")
 
+    def generate_answer(model, formatted_prompt):
+        API_URL = endpoint_dict[model]["API_URL"]
+        headers = endpoint_dict[model]["headers"]
+        payload = {"inputs": formatted_prompt}
+        try:
+            resp = requests.post(API_URL, json=payload, headers=headers)
+            response = resp.json()
+        except:
+            print("retrying due to an error......")
+            time.sleep(5)
+            return generate_answer(API_URL, headers, payload)
+
+        return {"model": model, "content": response[0]["generated_text"].split(prompt_dict[model]["response_split"])[-1]}
+
+    def prompt_formatting(model, instruction, cot):
+        if model == "alpaca" or model == "orca":
+            prompt = prompt_dict[model]["prompt_no_input"]
+        else:
+            prompt = prompt_dict[model]["prompt"]
+
+        if cot:
+            instruction += "Let's think step by step."
+
+        return {"model": model, "content": prompt.format(instruction)}
+
     agents = len(model_list)
     rounds = round
 
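The nested generate_answer now reads both the endpoint URL and the request headers from endpoint_dict, which load_json builds from src/inference_endpoint.json. That file is not part of this commit; a hypothetical entry consistent with the keys the new code accesses might look like the following Python dict (model key, URL, and token are placeholders, not values from the repo):

# Hypothetical shape of the data load_json could return for
# src/inference_endpoint.json -- not shown in this commit.
endpoint_dict = {
    "alpaca": {  # example model key; the real keys are defined by the repo
        "API_URL": "https://example-endpoint.huggingface.cloud",      # placeholder URL
        "headers": {"Authorization": "Bearer <auth_token>"},          # placeholder token
    },
}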
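Both the old and the new versions retry by calling generate_answer(API_URL, headers, payload), which does not match the function's (model, formatted_prompt) signature, and the bare except retries without a bound. A minimal sketch of a bounded retry loop, keeping the same requests call and response handling and assuming endpoint_dict and prompt_dict are in scope as in the nested version, could look like this (a sketch, not part of the commit):

import time
import requests

def generate_answer_with_retry(model, formatted_prompt, max_retries=5):
    """Sketch only: bounded retries around the same Hugging Face endpoint call."""
    API_URL = endpoint_dict[model]["API_URL"]   # assumes the new config layout
    headers = endpoint_dict[model]["headers"]
    payload = {"inputs": formatted_prompt}
    for _ in range(max_retries):
        try:
            resp = requests.post(API_URL, json=payload, headers=headers)
            resp.raise_for_status()
            response = resp.json()
            break
        except requests.RequestException:
            print("retrying due to an error......")
            time.sleep(5)
    else:
        # loop finished without a successful response
        raise RuntimeError(f"endpoint for {model} failed after {max_retries} attempts")
    generated = response[0]["generated_text"]
    return {"model": model,
            "content": generated.split(prompt_dict[model]["response_split"])[-1]}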