Tsukihjy/testcase / methods/lcb/batch_inference_api.py

import asyncio
import json
import os
import sys
from collections.abc import Coroutine, Sequence

from openai import AsyncOpenAI
from tap import Tap
from tqdm.auto import tqdm

from config import cfg
from prompt import (
    adversarial_input_template,
    adversarial_prompt,
    random_input_prompt,
    random_input_template,
    shots,
)

# Make the shared utils package importable before loading the dataset helpers.
sys.path.append("/home/i-luoxianzhen/data/TestCase-Gen/methods/utils")
from dataset_all import get_datasets_by_name


class Argument(Tap):
    # model_name_or_path: str = "/mnt/jfs/ckpt/checkpoints/Qwen2.5-Coder-14B-Instruct"
    temperature: float = 0.2
    max_completion_tokens: int = 16000
    concurrency: int = 10
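
# Tap derives the command-line flags from the annotated fields above, so a typical
# (illustrative) invocation looks like:
#   python batch_inference_api.py --temperature 0.2 --max_completion_tokens 16000 --concurrency 10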


def get_datasets_prompt():
    """Build the full prompt list: 10 random and 10 adversarial prompts per problem."""
    al_dataset = get_datasets_by_name(cfg.dataset_name)
    print(len(al_dataset))
    prompt_list = []
    for item in al_dataset:
        tcb_id = item["tcb_id"]
        # 10 copies of the random-input generation prompt for this problem
        random_generator = {
            "tcb_id": tcb_id,
            "type": "random",
            "system_prompt": random_input_prompt.replace("{EXAMPLE_PROBLEM}", shots),
            "user_prompt": random_input_template.replace("{PROBLEM}", item["query_en"]),
        }
        prompt_list += [random_generator] * 10
        # 10 copies of the adversarial (edge-case) input generation prompt
        adversarial_generator = {
            "tcb_id": tcb_id,
            "type": "eage",
            "system_prompt": adversarial_prompt.replace("{EXAMPLE_PROBLEM}", shots),
            "user_prompt": adversarial_input_template.replace("{PROBLEM}", item["query_en"]),
        }
        prompt_list += [adversarial_generator] * 10
    return prompt_list


def append_dict_to_jsonl(file_path, data_dict):
    """Append a single dict as one JSON line to a JSONL file."""
    with open(file_path, 'a', encoding='utf-8') as f:
        f.write(json.dumps(data_dict, ensure_ascii=False) + '\n')
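
# A saved result line produced by main() looks roughly like this (the tcb_id and
# response values here are hypothetical):
#   {"tcb_id": "lcb_0001", "type": "random", "response": "..."}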


def limit_concurrency(
    coroutines: Sequence[Coroutine], concurrency: int
) -> list[Coroutine]:
    """Wrap coroutines so that at most `concurrency` of them run at the same time."""
    semaphore = asyncio.Semaphore(concurrency)

    async def with_concurrency_limit(coroutine: Coroutine) -> Coroutine:
        async with semaphore:
            return await coroutine

    return [with_concurrency_limit(coroutine) for coroutine in coroutines]
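
# Illustrative call pattern (not part of this script): wrap the raw coroutines first,
# then drive the wrapped ones with asyncio.gather or asyncio.as_completed, e.g.
#   results = await asyncio.gather(*limit_concurrency(coros, concurrency=10))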


async def request(data: dict, client: AsyncOpenAI):
    """Send one chat-completion request; on failure return an empty response string."""
    try:
        response = await client.chat.completions.create(
            model=cfg.model_name,
            messages=[
                {"role": "system", "content": data['system_prompt']},
                {"role": "user", "content": data['user_prompt']},
            ],
            # `args` is the module-level Argument instance parsed in __main__
            temperature=args.temperature,
            max_tokens=args.max_completion_tokens,
        )
    except Exception as e:
        print(f"Error - {e}")
        return data['tcb_id'], data['type'], ""
    return data['tcb_id'], data['type'], response.choices[0].message.content


async def main(args: Argument, base_url: str, api_key: str = "sglang", save_file_path: str = ""):
    dataset = get_datasets_prompt()
    client = AsyncOpenAI(base_url=base_url, api_key=api_key, timeout=5000)
    tasks = [request(data, client) for data in dataset]
    tasks = limit_concurrency(tasks, args.concurrency)
    # asyncio.as_completed yields plain awaitables, so iterate with a regular for loop
    for finished in tqdm(
        asyncio.as_completed(tasks), total=len(tasks), desc="Running"
    ):
        tcb_id, type_of_func, response = await finished
        res_dict = {
            "tcb_id": tcb_id,
            "type": type_of_func,
            "response": response,
        }
        # save each result as soon as it finishes
        append_dict_to_jsonl(save_file_path, res_dict)


if __name__ == "__main__":
    args = Argument().parse_args()
    try:
        original_response_file = cfg.response_file.format("lcb", cfg.model_name)
        print(original_response_file)
        asyncio.run(main(args, base_url=cfg.api_base, api_key=cfg.api_key, save_file_path=original_response_file))
    except Exception as e:
        print(f'Error! {e}')
