# Source path: learn/test_time_scaling/get_response_utils.py
# Uploaded by unfair11212 ("Upload folder using huggingface_hub"), commit a80f6e6 (verified).
from openai import OpenAI
import os
import logging
import torch
import re
import os
import random
import transformers
from tqdm import tqdm
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
)
def setup_logger(log_path):
    """Configure root logging to write INFO-and-above records to ``log_path``.

    The log file is opened in write mode (``"w"``), so an existing file at
    ``log_path`` is truncated. Each record is formatted as
    ``<asctime> - <levelname> - <message>``.

    Note: ``logging.basicConfig`` is documented to do nothing when the root
    logger already has handlers configured.
    """
    config = {
        "filename": log_path,
        "filemode": "w",
        "format": "%(asctime)s - %(levelname)s - %(message)s",
        "level": logging.INFO,
    }
    logging.basicConfig(**config)
# Matches the "#### <number>" answer marker. Group 1 captures an optional
# leading minus sign followed by one or more digits, dots, or commas.
ANS_RE = re.compile(r"#### (\-?[0-9\.\,]+)")
# Sentinel returned by extract_answer_from_output when no marker is found.
INVALID_ANS = "[invalid]"
# NOTE(review): not referenced in the visible part of this file; presumably a
# phrase that introduces the final answer in prompts/outputs — confirm.
ANSWER_TRIGGER = "The answer is"
# prompt="""Please solve the following problem carefully, showing all intermediate steps. At the end, write your final answer in the following format: "#### your_answer_here" where "your_answer_here" is the final numerical result.
# Problem:
# {question}
# Remember: Only the final answer should be placed after "####", and do not include any extra text after it."""
def extract_answer_from_output(completion):
    """Extract the numeric answer that follows the first ``#### `` marker.

    Args:
        completion: Text to search, e.g. a model response or a reference
            solution.

    Returns:
        The captured number as a string with commas and any trailing
        periods removed, or ``INVALID_ANS`` when ``completion`` is not a
        string, contains no marker, or the captured text holds no digit.
    """
    # A missing response (e.g. ``None``) counts as "no answer" instead of
    # raising TypeError from the regex search.
    if not isinstance(completion, str):
        return INVALID_ANS

    match = ANS_RE.search(completion)
    if match is None:
        return INVALID_ANS

    answer = match.group(1).strip().replace(",", "")
    # The pattern accepts dots, so a sentence-ending period ("#### 42.") is
    # captured as well; drop it so the result compares equal to "42".
    answer = answer.rstrip(".")
    # A capture made only of punctuation (e.g. "#### ,") is not an answer.
    if not any(ch.isdigit() for ch in answer):
        return INVALID_ANS
    return answer
def is_correct(model_answer, answer):
    """Check an extracted model answer against a reference solution.

    Args:
        model_answer: Answer already extracted from the model output.
        answer: Reference solution text; its answer is extracted with
            ``extract_answer_from_output``.

    Returns:
        The result of ``model_answer == <answer extracted from answer>``.

    Raises:
        AssertionError: If no answer can be extracted from ``answer``.
    """
    gt_answer = extract_answer_from_output(answer)
    # Explicit check instead of a bare `assert`: asserts are stripped under
    # `python -O`, which would let an invalid model answer compare equal to
    # an invalid reference and be scored as correct. AssertionError is kept
    # so callers see the same exception type as before.
    if gt_answer == INVALID_ANS:
        raise AssertionError(
            f"no '#### <number>' answer found in reference: {answer!r}"
        )
    return model_answer == gt_answer
def is_correct_direct(model_answer, answer):
    """Return the result of comparing ``model_answer`` to ``answer`` with ``==``.

    Neither argument is extracted, converted, or normalised first.
    """
    outcome = model_answer == answer
    return outcome
def is_correct_choice(model_answer, answer):
    """Check a choice answer after converting the model's answer to ``int``.

    Args:
        model_answer: The model's choice; converted with ``int()``.
        answer: The expected choice, compared with ``==`` against the
            converted value.

    Returns:
        The result of ``int(model_answer) == answer``, or ``False`` when
        ``model_answer`` cannot be converted to an integer.
    """
    try:
        choice = int(model_answer)
    except (TypeError, ValueError):
        # An unparseable answer (e.g. "[invalid]" or None) is simply wrong;
        # it should not abort the whole evaluation run.
        return False
    return choice == answer
# Client shared by calls that do not pass their own; created on first use.
_DEFAULT_CLIENT = None


def _get_default_client():
    """Return the shared client for the local endpoint, creating it once."""
    global _DEFAULT_CLIENT
    if _DEFAULT_CLIENT is None:
        _DEFAULT_CLIENT = OpenAI(
            api_key="EMPTY",
            base_url="http://127.0.0.1:8422/v1",
        )
    return _DEFAULT_CLIENT


# Call a locally served vLLM model.
def get_response_template(message, model="meta-llama/Meta-Llama-3-8B-Instruct", client=None):
    """Send ``message`` as a single user turn and return the reply text.

    Args:
        message: Content of the user message.
        model: Model name passed to the chat-completions request.
        client: Object exposing ``chat.completions.create``. When ``None``,
            a shared ``OpenAI`` client pointing at
            ``http://127.0.0.1:8422/v1`` is created on first use. It is no
            longer built at import time as a default-argument side effect.

    Returns:
        The ``message.content`` of the first choice in the response.
    """
    if client is None:
        client = _get_default_client()

    chat_response = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": message},
        ],
        temperature=0.8,
    )
    content = chat_response.choices[0].message.content
    print("Chat response:", content)
    return content