Tsukihjy/testcase / methods /ALGO /batch_inference.py
Tsukihjy's picture
download
raw
4.66 kB
import asyncio
import os
from config import cfg
from openai import AsyncOpenAI
from sglang.utils import launch_server_cmd, terminate_process, wait_for_server
from tap import Tap
from tqdm.auto import tqdm
from collections.abc import Coroutine, Sequence
from prompt import CF_PROMPT
import sys
sys.path.append("/home/i-luoxianzhen/data/TestCase-Gen/methods/utils")
from dataset_all import get_datasets_by_name
class Argument(Tap):
    """Command-line arguments, parsed by `tap.Tap`."""

    # Checkpoint path served by sglang; also passed as the OpenAI `model`
    # name in each chat-completion request (see `request`).
    model_name_or_path: str = "/mnt/jfs/ckpt/checkpoints/Qwen2.5-32B-Instruct"
    # Sampling temperature for chat completions.
    temperature: float = 0.2
    # Upper bound on generated tokens per request (forwarded as `max_tokens`).
    max_completion_tokens: int = 10000
    # Maximum number of in-flight requests (semaphore size in limit_concurrency).
    concurrency: int = 100
def get_datasets_prompt():
    """Build the request list for the configured dataset.

    For every dataset item, emit five copies of a 'random' test-data
    generator prompt and five copies of a 'brute-force' oracle prompt.
    Each entry carries the tcb_id, type, and the system/user prompt pair.
    """
    all_items = get_datasets_by_name(cfg.dataset_name)
    print(len(all_items))

    requests = []
    for sample in all_items:
        problem_id = sample["tcb_id"]

        data_gen_prompt = CF_PROMPT["data_generator_original"].format(
            problem=sample["query_en"]
        )
        # Five identical 'random' generator requests per problem.
        requests.extend([{
            "tcb_id": problem_id,
            "type": "random",
            "system_prompt": "You are an expert software tester tasked with thoroughly testing a given piece of code. ",
            "user_prompt": data_gen_prompt,
        }] * 5)

        oracle_gen_prompt = CF_PROMPT["oracle_generator"].format(
            problem=sample["query_en"]
        )
        # Five identical 'brute-force' oracle requests per problem.
        requests.extend([{
            "tcb_id": problem_id,
            "type": "brute-force",
            "system_prompt": "You are an AI programming assistant.",
            "user_prompt": oracle_gen_prompt,
        }] * 5)

    return requests
import json
def append_dict_to_jsonl(file_path, data_dict):
    """Serialize *data_dict* as a single JSON line and append it to *file_path*."""
    line = json.dumps(data_dict, ensure_ascii=False)
    with open(file_path, 'a', encoding='utf-8') as sink:
        sink.write(line + '\n')
def limit_concurrency(
    coroutines: Sequence[Coroutine], concurrency: int
) -> list[Coroutine]:
    """Wrap coroutines so at most *concurrency* of them execute at once.

    Returns a new list of coroutine objects; awaiting each one first
    acquires a shared semaphore, so scheduling all of them together
    (e.g. via asyncio.gather / asyncio.as_completed) never runs more
    than *concurrency* bodies concurrently.

    Fix: the inner wrapper was annotated `-> Coroutine`, but it returns
    the *awaited result* of the wrapped coroutine, not a coroutine.
    """
    semaphore = asyncio.Semaphore(concurrency)

    async def with_concurrency_limit(coroutine: Coroutine):
        # Hold a semaphore slot for the full lifetime of the wrapped coroutine.
        async with semaphore:
            return await coroutine

    return [with_concurrency_limit(coroutine) for coroutine in coroutines]
async def request(data: dict, client: AsyncOpenAI):
    """Send one chat-completion request and return its identifying result.

    Fix: *data* was annotated `str`, but the body indexes it as a dict
    (keys: 'system_prompt', 'user_prompt', 'tcb_id', 'type').

    NOTE(review): model name, temperature, and max_tokens are read from the
    module-level `args` (set in the __main__ guard), not passed in — this
    function is only usable after argument parsing.

    Returns:
        (tcb_id, type, generated text) for the first completion choice.
    """
    response = await client.chat.completions.create(
        model=args.model_name_or_path,
        messages=[
            {"role": "system", "content": data['system_prompt']},
            {"role": "user", "content": data['user_prompt']}
        ],
        temperature=args.temperature,
        max_tokens=args.max_completion_tokens,
    )
    return data['tcb_id'], data['type'], response.choices[0].message.content
async def main(args: Argument, base_url: str, api_key: str = "sglang", save_file_path: str = ""):
    """Fan out all generated prompts to the server and append results to JSONL.

    Builds one `request` coroutine per prompt, caps concurrency with
    `limit_concurrency`, and writes each completed result incrementally
    (so partial progress survives a crash).

    Fix: the original used `async for` over `asyncio.as_completed(...)`,
    which raises TypeError on Python < 3.13 where `as_completed` returns a
    plain (synchronous) iterator of futures. A plain `for` + `await` is
    equivalent and works on all supported versions.
    """
    dataset = get_datasets_prompt()
    client = AsyncOpenAI(base_url=base_url, api_key=api_key)
    tasks = [request(data, client) for data in dataset]
    tasks = limit_concurrency(tasks, args.concurrency)
    for pending in tqdm(
        asyncio.as_completed(tasks), total=len(tasks), desc="Running"
    ):
        tcb_id, type_of_func, response = await pending
        res_dict = {
            "tcb_id": tcb_id,
            "type": type_of_func,
            "response": response
        }
        # Append immediately rather than buffering, so results are persisted
        # as they arrive.
        append_dict_to_jsonl(save_file_path, res_dict)
if __name__ == "__main__":
    args = Argument().parse_args()
    # Fix: initialize before `try` so the `finally` clause cannot raise
    # NameError (masking the real error) when launch_server_cmd itself fails.
    server_process = None
    try:
        os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] = "1"
        server_process, port = launch_server_cmd(
            (
                "python3 -m sglang.launch_server "
                "--tp 8 "
                "--dp 1 "
                f"--model-path {args.model_name_or_path} "
                # f"--served-model-name {args.model} "
                # "--reasoning-parser qwen3 "
                "--context-length 20000 "
                # """--json-model-override-args {"rope_scaling":{"rope_type":"yarn","factor":4.0,"original_max_position_embeddings":32768}} """
                "--host 0.0.0.0 "
                "--port 33389 "
                "--log-level warning "
            )
        )
        wait_for_server(f"http://localhost:{port}")
        # cfg.response_file has two positional placeholders: task name and
        # model name (basename of the checkpoint path).
        original_response_file = cfg.response_file.format("algo", os.path.basename(args.model_name_or_path))
        print(original_response_file)
        asyncio.run(main(args, base_url=f"http://localhost:{port}/v1", api_key="sglang", save_file_path=original_response_file))
    except Exception as e:
        print(f'Error! {e}')
    finally:
        # Only tear down the server if it was actually launched.
        if server_process is not None:
            terminate_process(server_process)

Xet Storage Details

Size:
4.66 kB
·
Xet hash:
75706c108fbc3c7f82163bf0b0657a090db8d8b0337685b8a370e41136279dec

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.