Tsukihjy/testcase/methods/Hardtest/batch_inference_api.py
Tsukihjy's picture
download
raw
5.25 kB
import asyncio
import os
from config import cfg
from openai import AsyncOpenAI
from sglang.utils import launch_server_cmd, terminate_process, wait_for_server
from tap import Tap
from tqdm.auto import tqdm
from collections.abc import Coroutine, Sequence
from prompt import validator_prompt, input_generator_prompt
import sys
sys.path.append("/home/i-luoxianzhen/data/TestCase-Gen/methods/utils")
from dataset_all import get_datasets_by_name
class Argument(Tap):
    """Command-line arguments for the batch inference run (typed-argument-parser)."""
    # model_name_or_path: str = "/mnt/jfs/ckpt/checkpoints/Qwen2.5-Coder-7B-Instruct"
    temperature: float = 0.2  # sampling temperature for both API calls
    max_completion_tokens: int = 10000  # per-request completion token cap
    concurrency: int = 20  # max simultaneous in-flight requests
import re
def extract_code(ans_str):
    """Extract the contents of the last ```json fenced block in *ans_str*.

    Args:
        ans_str: Model response text expected to contain at least one
            ```json ... ``` fenced code block.

    Returns:
        The text between the last pair of fence markers (fences excluded).

    Raises:
        ValueError: If no ```json fenced block is present. (Previously this
            surfaced as a bare IndexError from ``matches[-1]``.)
    """
    pattern = r'```json\n(.*?)```'
    matches = re.findall(pattern, ans_str, re.DOTALL)
    if not matches:
        raise ValueError("no ```json fenced code block found in response")
    # Use the last block: models often restate or refine earlier drafts.
    return matches[-1]
def get_datasets_prompt(repeats: int = 10):
    """Build (validator, generator) prompt pairs for every usable dataset item.

    Items without at least one reference solution are skipped, since both
    prompt templates require an oracle program to substitute in.

    Args:
        repeats: Number of duplicate pairs emitted per item (multiple samples
            per problem). Defaults to 10, matching the previous hard-coded
            multiplier.

    Returns:
        A list of dicts, each with "validator" and "generator" request specs
        (tcb_id, type, system_prompt, user_prompt).
    """
    al_dataset = get_datasets_by_name(cfg.dataset_name)
    print(len(al_dataset))
    prompt_list = []
    for item in al_dataset:
        tcb_id = item["tcb_id"]
        if len(item['solutions']) <= 0:
            # No oracle program available -> cannot fill the templates.
            continue
        oracle_code = item['solutions'][0]['code']
        response_validator = {
            "tcb_id": tcb_id,
            "type": "validator",
            "system_prompt": "You are a helpful code assistant.",
            "user_prompt": validator_prompt.replace("{{ problem_specification }}", item['query_en']).replace("{{ oracle_program }}", oracle_code)
        }
        response_generator = {
            "tcb_id": tcb_id,
            "type": "generator",
            "system_prompt": "You are a helpful code assistant.",
            # "{{ input_validator }}" is deliberately left unfilled here;
            # request() substitutes it after parsing the validator response.
            "user_prompt": input_generator_prompt.replace("{{ problem_specification }}", item['query_en']).replace("{{ oracle_program }}", oracle_code)
        }
        ht_pair = {
            "validator": response_validator,
            "generator": response_generator
        }
        prompt_list += [ht_pair] * repeats
    return prompt_list
import json
def append_dict_to_jsonl(file_path, data_dict):
    """Serialize *data_dict* as one JSON line and append it to *file_path*."""
    line = json.dumps(data_dict, ensure_ascii=False)
    with open(file_path, 'a', encoding='utf-8') as out:
        out.write(line)
        out.write('\n')
def limit_concurrency(
    coroutines: Sequence[Coroutine], concurrency: int
) -> list[Coroutine]:
    """Wrap each coroutine so at most *concurrency* of them run at once.

    Returns new awaitables; each original is awaited while holding a shared
    semaphore, so scheduling all of them (e.g. via as_completed or gather)
    keeps no more than *concurrency* in flight simultaneously.
    """
    gate = asyncio.Semaphore(concurrency)

    async def gated(coro: Coroutine) -> Coroutine:
        async with gate:
            return await coro

    return [gated(c) for c in coroutines]
async def request(data: dict, client: AsyncOpenAI):
    """Run the two-step validator -> generator call for one prompt pair.

    First asks the model for an input validator, extracts the JSON payload
    from its answer, splices the validator program into the generator prompt
    template, then asks the model for the input generator.

    NOTE(review): reads the module-level ``args`` for temperature and token
    limits rather than receiving them as parameters.

    Args:
        data: Dict with "validator" and "generator" request specs, as built
            by get_datasets_prompt(). (Was mis-annotated as ``str``.)
        client: Shared AsyncOpenAI client pointed at the serving endpoint.

    Returns:
        Tuple ``(tcb_id, validator_text, generator_text)``; both texts are
        "" when either call or the JSON extraction fails — callers use that
        sentinel to skip the item.
    """
    validator = data["validator"]
    generator = data["generator"]
    try:
        # Step 1: ask the model to write an input validator.
        response_validator = await client.chat.completions.create(
            model=cfg.model_name,
            messages=[
                {"role": "system", "content": validator['system_prompt']},
                {"role": "user", "content": validator['user_prompt']}
            ],
            temperature=args.temperature,
            max_tokens=args.max_completion_tokens,
        )
        # Step 2: parse the JSON answer and fill the generator template.
        response_validator_dict = json.loads(extract_code(response_validator.choices[0].message.content))
        generator_prompt = generator['user_prompt'].replace("{{ input_validator }}", response_validator_dict['input_validator'])
        # Step 3: ask the model for the input generator.
        response_generator = await client.chat.completions.create(
            model=cfg.model_name,
            messages=[
                {"role": "system", "content": generator['system_prompt']},
                {"role": "user", "content": generator_prompt}
            ],
            temperature=args.temperature,
            max_tokens=args.max_completion_tokens,
        )
    except Exception as e:
        # Best-effort batch: report the failure instead of swallowing it
        # silently (the diagnostic print was commented out before), but keep
        # the ""/"" sentinel so one bad item doesn't abort the whole run.
        print(f"request failed for tcb_id={validator['tcb_id']}: {e}")
        return validator['tcb_id'], "", ""
    return validator['tcb_id'], response_validator.choices[0].message.content, response_generator.choices[0].message.content
async def main(args: Argument, base_url: str, api_key: str = "sglang", save_file_path: str = ""):
    """Fan out every prompt pair to the API and append results as JSONL."""
    prompt_pairs = get_datasets_prompt()
    client = AsyncOpenAI(base_url=base_url, api_key=api_key)
    pending = limit_concurrency(
        [request(pair, client) for pair in prompt_pairs], args.concurrency
    )
    async for finished in tqdm(
        asyncio.as_completed(pending), total=len(pending), desc="Running"
    ):
        tcb_id, validator_text, generator_text = await finished
        # Both-empty marks a failed request() -- skip rather than persist.
        if validator_text == "" and generator_text == "":
            continue
        append_dict_to_jsonl(save_file_path, {
            "tcb_id": tcb_id,
            "response_validator": validator_text,
            "response_generator": generator_text,
        })
if __name__ == "__main__":
    # `args` stays module-level on purpose: request() reads it directly.
    args = Argument().parse_args()
    try:
        response_path = cfg.response_file.format("ht", cfg.model_name)
        print(response_path)
        asyncio.run(
            main(
                args,
                base_url=cfg.api_base,
                api_key=cfg.api_key,
                save_file_path=response_path,
            )
        )
    except Exception as e:
        print(f'Error! {e}')

Xet Storage Details

Size:
5.25 kB
·
Xet hash:
c769898835c8ea2b9e28ba31ca55dfe27ff9ecb23ad4ad40f2fec7bd22a1c8b3

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.