Upload run_program.py with huggingface_hub
Browse files
- run_program.py +11 -8
run_program.py
CHANGED
|
@@ -238,9 +238,11 @@ def update_question_with_new_parameters():
|
|
| 238 |
json.dump(program_data, outfile, indent=4)
|
| 239 |
|
| 240 |
|
| 241 |
-
def call_answer_question(question, model_name='gpt'):
|
| 242 |
-
|
| 243 |
-
|
|
|
|
|
|
|
| 244 |
prompt = prompt_template.format_map(
|
| 245 |
{"question": question}
|
| 246 |
)
|
|
@@ -321,7 +323,8 @@ def call_answer_question(question, model_name='gpt'):
|
|
| 321 |
outputs = llama_pipeline(
|
| 322 |
messages,
|
| 323 |
max_new_tokens=300,
|
| 324 |
-
temperature=0.00001
|
|
|
|
| 325 |
)
|
| 326 |
# print(outputs[0]["generated_text"][-1])
|
| 327 |
return outputs[0]["generated_text"][-1]['content']
|
|
@@ -332,19 +335,19 @@ def answer_question(model_name='gpt'):
|
|
| 332 |
program_data = json.load(infile)
|
| 333 |
print(len(program_data))
|
| 334 |
for case in tqdm(program_data):
|
| 335 |
-
response = call_answer_question(case['question'], model_name=model_name)
|
| 336 |
case['prediction'] = response
|
| 337 |
# print(case['prediction'])
|
| 338 |
case['new_prediction'] = []
|
| 339 |
for question in case['new_questions']:
|
| 340 |
-
response = call_answer_question(question, model_name=model_name)
|
| 341 |
case['new_prediction'].append(response)
|
| 342 |
# print(case)
|
| 343 |
# break
|
| 344 |
# print(case)
|
| 345 |
# break
|
| 346 |
-
outfile = open('data/math/test_dump_gsm8k_train_perturbed_with_new_questions_answer_llama8b.json', 'w')
|
| 347 |
-
|
| 348 |
json.dump(program_data, outfile, indent=4)
|
| 349 |
|
| 350 |
|
|
|
|
| 238 |
json.dump(program_data, outfile, indent=4)
|
| 239 |
|
| 240 |
|
| 241 |
+
def call_answer_question(question, model_name='gpt', cot=False):
|
| 242 |
+
if cot:
|
| 243 |
+
prompt_template = PROMPT_DICT['prompt_answer_question_few_shot_cot']
|
| 244 |
+
else:
|
| 245 |
+
prompt_template = PROMPT_DICT['prompt_answer_question']
|
| 246 |
prompt = prompt_template.format_map(
|
| 247 |
{"question": question}
|
| 248 |
)
|
|
|
|
| 323 |
outputs = llama_pipeline(
|
| 324 |
messages,
|
| 325 |
max_new_tokens=300,
|
| 326 |
+
# temperature=0.00001
|
| 327 |
+
temperature = 0.7
|
| 328 |
)
|
| 329 |
# print(outputs[0]["generated_text"][-1])
|
| 330 |
return outputs[0]["generated_text"][-1]['content']
|
|
|
|
| 335 |
program_data = json.load(infile)
|
| 336 |
print(len(program_data))
|
| 337 |
for case in tqdm(program_data):
|
| 338 |
+
response = call_answer_question(case['question'], model_name=model_name, cot=True)
|
| 339 |
case['prediction'] = response
|
| 340 |
# print(case['prediction'])
|
| 341 |
case['new_prediction'] = []
|
| 342 |
for question in case['new_questions']:
|
| 343 |
+
response = call_answer_question(question, model_name=model_name, cot=True)
|
| 344 |
case['new_prediction'].append(response)
|
| 345 |
# print(case)
|
| 346 |
# break
|
| 347 |
# print(case)
|
| 348 |
# break
|
| 349 |
+
# outfile = open('data/math/test_dump_gsm8k_train_perturbed_with_new_questions_answer_llama8b.json', 'w')
|
| 350 |
+
outfile = open('data/math/gsm8k_cot_sc_llama3.1_8b/temp=0.7_iter=5.json', 'w')
|
| 351 |
json.dump(program_data, outfile, indent=4)
|
| 352 |
|
| 353 |
|