Tsukihjy/testcase/methods/ALGO/batch_inference_api.py
Tsukihjy's picture
download
raw
3.88 kB
import asyncio
import os
from config import cfg
from openai import AsyncOpenAI
from tap import Tap
from tqdm.auto import tqdm
from collections.abc import Coroutine, Sequence
from prompt import CF_PROMPT
import sys
sys.path.append("/home/i-luoxianzhen/data/TestCase-Gen/methods/utils")
from dataset_all import get_datasets_by_name
class Argument(Tap):
    """Command-line arguments for the batch inference run (typed-argument-parser)."""
    # model_name_or_path: str = "/mnt/jfs/ckpt/checkpoints/Qwen2.5-32B-Instruct"
    temperature: float = 0.2  # sampling temperature passed to chat.completions.create
    max_completion_tokens: int = 16000  # upper bound on generated tokens per request
    concurrency: int = 10  # maximum number of simultaneously in-flight API requests
def get_datasets_prompt(copies_per_item: int = 5):
    """Build the flat list of generation prompts for every dataset item.

    For each item in the dataset named by ``cfg.dataset_name`` two prompt
    templates are instantiated — ``data_generator_original`` (type
    "random") and ``oracle_generator`` (type "brute-force") — and each is
    emitted ``copies_per_item`` times so the model can be sampled
    repeatedly per prompt.

    Args:
        copies_per_item: number of copies of each generator prompt to emit.
            Defaults to 5, matching the previously hard-coded fan-out.

    Returns:
        list[dict]: dicts with keys ``tcb_id``, ``type``, ``system_prompt``
        and ``user_prompt``.
    """
    al_dataset = get_datasets_by_name(cfg.dataset_name)
    print(len(al_dataset))
    prompt_list = []
    for item in al_dataset:
        tcb_id = item["tcb_id"]
        gen_prompt = CF_PROMPT["data_generator_original"].format(problem=item["query_en"])
        # One independent dict per copy: `[d] * n` would alias the same
        # object n times, which is fragile if a consumer ever mutates one.
        prompt_list.extend(
            {
                "tcb_id": tcb_id,
                "type": "random",
                "system_prompt": "You are an expert software tester tasked with thoroughly testing a given piece of code. ",
                "user_prompt": gen_prompt,
            }
            for _ in range(copies_per_item)
        )
        oracle_prompt = CF_PROMPT["oracle_generator"].format(problem=item["query_en"])
        prompt_list.extend(
            {
                "tcb_id": tcb_id,
                "type": "brute-force",
                "system_prompt": "You are an AI programming assistant.",
                "user_prompt": oracle_prompt,
            }
            for _ in range(copies_per_item)
        )
    return prompt_list
import json
def append_dict_to_jsonl(file_path, data_dict):
    """Serialize *data_dict* as one JSON line and append it to *file_path*."""
    line = json.dumps(data_dict, ensure_ascii=False)
    with open(file_path, mode='a', encoding='utf-8') as sink:
        sink.write(line + '\n')
def limit_concurrency(
coroutines: Sequence[Coroutine], concurrency: int
) -> list[Coroutine]:
semaphore = asyncio.Semaphore(concurrency)
async def with_concurrency_limit(coroutine: Coroutine) -> Coroutine:
async with semaphore:
return await coroutine
return [with_concurrency_limit(coroutine) for coroutine in coroutines]
async def request(data: dict, client: AsyncOpenAI):
    """Send one chat-completion request and return ``(tcb_id, type, text)``.

    Args:
        data: prompt record with keys ``tcb_id``, ``type``,
            ``system_prompt`` and ``user_prompt`` (the original annotation
            said ``str``, but the body indexes it like a dict).
        client: the shared ``AsyncOpenAI`` client.

    Returns:
        tuple: ``(tcb_id, type, completion_text)``; the text is ``""`` when
        the request failed.

    NOTE(review): reads module-level ``args`` (parsed under ``__main__``)
    and ``cfg.model_name`` — confirm both are set before this is awaited.
    """
    try:
        response = await client.chat.completions.create(
            model=cfg.model_name,
            messages=[
                {"role": "system", "content": data['system_prompt']},
                {"role": "user", "content": data['user_prompt']}
            ],
            temperature=args.temperature,
            max_tokens=args.max_completion_tokens,
        )
    except Exception as e:
        # Best-effort: log and return an empty completion so one bad
        # prompt does not abort the whole batch.
        print(f"Error - {e}")
        return data['tcb_id'], data['type'], ""
    return data['tcb_id'], data['type'], response.choices[0].message.content
async def main(args: Argument, base_url: str, api_key: str = "sglang", save_file_path: str = ""):
    """Fan out every dataset prompt to the API and stream results to JSONL.

    Args:
        args: parsed CLI arguments (temperature, token cap, concurrency).
        base_url: OpenAI-compatible endpoint to send requests to.
        api_key: API key for the endpoint (defaults to the sglang dummy key).
        save_file_path: JSONL file each result dict is appended to.
    """
    dataset = get_datasets_prompt()
    client = AsyncOpenAI(base_url=base_url, api_key=api_key, timeout=6000)
    tasks = [request(data, client) for data in dataset]
    tasks = limit_concurrency(tasks, args.concurrency)
    # `asyncio.as_completed` yields plain awaitables from a regular
    # generator; `async for` over it only works on Python 3.13+ (where it
    # gained __aiter__) and requires tqdm's asyncio-aware wrapper. A plain
    # `for` + `await` works on every supported version with the same order.
    for pending in tqdm(
        asyncio.as_completed(tasks), total=len(tasks), desc="Running"
    ):
        tcb_id, type_of_func, response = await pending
        res_dict = {
            "tcb_id": tcb_id,
            "type": type_of_func,
            "response": response
        }
        # Append incrementally so partial progress survives a crash.
        append_dict_to_jsonl(save_file_path, res_dict)
if __name__ == "__main__":
args = Argument().parse_args()
try:
orginal_response_file = cfg.response_file.format("algo", cfg.model_name)
print(orginal_response_file)
asyncio.run(main(args, base_url=cfg.api_base, api_key=cfg.api_key, save_file_path=orginal_response_file))
except Exception as e:
print(f'Error! {e}')

Xet Storage Details

Size:
3.88 kB
·
Xet hash:
2162a6fd96ce99ce5fe71ae8931d367f1ef0cf4154a23f920dfb73636eb45ffc

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.