Tsukihjy/testcase / methods /Predo /batch_inference_api.py
Tsukihjy's picture
download
raw
3.69 kB
import asyncio
import os
from config import cfg
from openai import AsyncOpenAI
from tap import Tap
from tqdm.auto import tqdm
from collections.abc import Coroutine, Sequence
from prompt import test_gen_prompt, predo_prompt, input_template, System_prompt
import sys
sys.path.append("/home/i-luoxianzhen/data/TestCase-Gen/methods/utils")
from dataset_all import get_datasets_by_name
# Command-line options for this script, parsed via Tap in the `__main__`
# block (`Argument().parse_args()`). The model name, API endpoint, API key and
# dataset name are NOT CLI options; they are read from `cfg`.
class Argument(Tap):
    # Commented-out former option. The request now takes its model from
    # `cfg.model_name` instead.
    # model_name_or_path: str = "/mnt/jfs/ckpt/checkpoints/Qwen2.5-7B-Instruct"

    # Sampling temperature used for "input"-type requests (prompts built from
    # `test_gen_prompt`). "answer"-type requests use a fixed temperature set in
    # `request`.
    temperature: float = 0.2
    # Forwarded to the chat-completions API as `max_tokens` for every request.
    max_completion_tokens: int = 10000
    # Maximum number of API requests awaited at the same time
    # (enforced by `limit_concurrency`).
    concurrency: int = 10
def get_datasets_prompt(samples_per_type: int = 10):
    """Build the prompt records to send for every problem in the dataset.

    The dataset is loaded with ``get_datasets_by_name(cfg.dataset_name)``. For
    each item, two kinds of prompt record are produced:

    * ``"input"``  - system prompt ``test_gen_prompt``, user prompt from
      ``input_template`` with ``[[Question]]`` replaced by the question;
    * ``"answer"`` - system prompt ``System_prompt``, user prompt from
      ``predo_prompt`` with ``{question}`` replaced by the question.

    Each record is repeated ``samples_per_type`` times, so several completions
    are sampled per problem.

    Args:
        samples_per_type: number of copies of each record to emit per item.
            Defaults to 10, the value that was previously hard-coded.

    Returns:
        A list of dicts with keys ``tcb_id``, ``type``, ``system_prompt`` and
        ``user_prompt``. For every item, its ``"input"`` records come before
        its ``"answer"`` records.
    """
    al_dataset = get_datasets_by_name(cfg.dataset_name)
    # Progress/debug output: number of problems loaded.
    print(len(al_dataset))

    prompt_list = []
    for item in al_dataset:
        tcb_id = item["tcb_id"]
        question = item["query_en"]

        random_generator = {
            "tcb_id": tcb_id,
            "type": "input",
            "system_prompt": test_gen_prompt,
            "user_prompt": input_template.replace("[[Question]]", question),
        }
        oracle_generator = {
            "tcb_id": tcb_id,
            "type": "answer",
            "system_prompt": System_prompt,
            "user_prompt": predo_prompt.replace("{question}", question),
        }

        # Emit independent copies rather than `[d] * n`. The latter aliases a
        # single dict n times, so any later mutation of one record would
        # silently change all of its siblings.
        prompt_list.extend(dict(random_generator) for _ in range(samples_per_type))
        prompt_list.extend(dict(oracle_generator) for _ in range(samples_per_type))
    return prompt_list
import json
def append_dict_to_jsonl(file_path, data_dict):
    """Append ``data_dict`` to ``file_path`` as one JSON Lines record.

    The file is opened in append mode with UTF-8 encoding (and created if
    missing), so successive calls accumulate one record per line. Non-ASCII
    characters are written verbatim rather than as ``\\u`` escapes.
    """
    with open(file_path, mode="a", encoding="utf-8") as handle:
        record = json.dumps(data_dict, ensure_ascii=False)
        handle.write(record + "\n")
def limit_concurrency(
    coroutines: Sequence[Coroutine], concurrency: int
) -> list[Coroutine]:
    """Throttle a batch of coroutines to at most ``concurrency`` in flight.

    Each input coroutine is wrapped in a new coroutine. The wrapper must
    acquire a shared ``asyncio.Semaphore`` before it awaits the original.
    Nothing runs until the returned wrappers are scheduled (e.g. by
    ``asyncio.as_completed`` or ``asyncio.gather``).

    Args:
        coroutines: the coroutine objects to throttle.
        concurrency: maximum number of them that may be awaited at once.

    Returns:
        One wrapper coroutine per input, in input order. Awaiting a wrapper
        yields the wrapped coroutine's result or propagates its exception.
    """
    gate = asyncio.Semaphore(concurrency)

    async def _run_gated(coro: Coroutine):
        # Hold one slot for the whole duration of the wrapped call.
        async with gate:
            result = await coro
        return result

    return list(map(_run_gated, coroutines))
async def request(data: dict, client: AsyncOpenAI, run_args: "Argument | None" = None):
    """Send one chat-completion request for a single prompt record.

    Args:
        data: prompt record as built by ``get_datasets_prompt``, with keys
            ``tcb_id``, ``type``, ``system_prompt`` and ``user_prompt``.
        client: async OpenAI client pointed at the inference server.
        run_args: parsed CLI options providing ``temperature`` and
            ``max_completion_tokens``. If omitted, falls back to the
            module-level ``args`` that the ``__main__`` block binds, so existing
            two-argument calls keep working when the file runs as a script.

    Returns:
        ``(tcb_id, type, content)`` on success, or ``(tcb_id, "error", "")``
        if the API call raised. The exception is printed, not propagated.
    """
    if run_args is None:
        # Legacy fallback: `args` only exists at module level when this file
        # is executed as a script.
        run_args = args

    # "input" requests use the CLI temperature; "answer" requests always use
    # a fixed 0.7.
    temperature = run_args.temperature if data['type'] == "input" else 0.7
    try:
        response = await client.chat.completions.create(
            model=cfg.model_name,
            messages=[
                {"role": "system", "content": data['system_prompt']},
                {"role": "user", "content": data['user_prompt']},
            ],
            temperature=temperature,
            max_tokens=run_args.max_completion_tokens,
        )
    except Exception as e:
        # Best-effort batch: record the failure and keep going.
        print(f"Error - {e}")
        return data['tcb_id'], "error", ""
    return data['tcb_id'], data['type'], response.choices[0].message.content


async def main(args: Argument, base_url: str, api_key: str = "sglang", save_file_path: str = ""):
    """Run every prompt through the chat API and stream results to a JSONL file.

    Requests are issued with at most ``args.concurrency`` in flight. Each
    result is appended to ``save_file_path`` as one JSON line
    (``{"tcb_id", "type", "response"}``) as soon as it completes. Lines are
    therefore in completion order, not input order.

    Args:
        args: parsed CLI options (temperature, token limit, concurrency).
        base_url: base URL of the API server. The default ``api_key`` of
            ``"sglang"`` suggests an SGLang server (assumption, not verified
            here).
        api_key: API key passed to the client.
        save_file_path: path of the output ``.jsonl`` file (appended to).

    Raises:
        ValueError: if ``save_file_path`` is empty. Without this check, every
            request would be sent first and the run would then fail on the
            first write.
    """
    if not save_file_path:
        raise ValueError("save_file_path must be a non-empty path to a .jsonl file")
    # Create the output directory up front so a missing directory does not
    # abort the run after requests have already been paid for.
    out_dir = os.path.dirname(save_file_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    dataset = get_datasets_prompt()
    # `async with` closes the client's HTTP connection pool when the run ends.
    async with AsyncOpenAI(base_url=base_url, api_key=api_key) as client:
        tasks = limit_concurrency(
            [request(data, client, args) for data in dataset], args.concurrency
        )
        async for finished in tqdm(
            asyncio.as_completed(tasks), total=len(tasks), desc="Running"
        ):
            tcb_id, type_of_func, response = await finished
            res_dict = {
                "tcb_id": tcb_id,
                "type": type_of_func,
                "response": response,
            }
            # Write incrementally so partial results survive a crash.
            append_dict_to_jsonl(save_file_path, res_dict)
if __name__ == "__main__":
    # Bound at module level; `request` falls back to this global when it is
    # called without explicit options.
    args = Argument().parse_args()
    try:
        # Output path template from the config, formatted with "predo" and
        # the model name.
        original_response_file = cfg.response_file.format("predo", cfg.model_name)
        print(original_response_file)
        asyncio.run(
            main(
                args,
                base_url=cfg.api_base,
                api_key=cfg.api_key,
                save_file_path=original_response_file,
            )
        )
    except Exception as e:
        import traceback

        print(f'Error! {e}')
        traceback.print_exc()
        # Exit non-zero so shells and job schedulers see the failure instead
        # of a successful exit status.
        sys.exit(1)

Xet Storage Details

Size:
3.69 kB
·
Xet hash:
347a7dc211d8b8eca866e4a946dda4926465e13f36a1c5f6cc9ebc633fa70592

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.