| import asyncio | |
| import os | |
| from config import cfg | |
| from openai import AsyncOpenAI | |
| from sglang.utils import launch_server_cmd, terminate_process, wait_for_server | |
| from tap import Tap | |
| from tqdm.auto import tqdm | |
| from collections.abc import Coroutine, Sequence | |
| from prompt import CF_PROMPT | |
| import sys | |
| sys.path.append("/home/i-luoxianzhen/data/TestCase-Gen/methods/utils") | |
| from dataset_all import get_datasets_by_name | |
class Argument(Tap):
    """Command-line arguments (typed-argument-parser / Tap)."""

    # Path (or name) of the model served by sglang; also sent as the model
    # id on every chat.completions request.
    model_name_or_path: str = "/mnt/jfs/ckpt/checkpoints/Qwen2.5-32B-Instruct"
    # Sampling temperature forwarded to chat.completions.create.
    temperature: float = 0.2
    # Upper bound on generated tokens per completion (passed as max_tokens).
    max_completion_tokens: int = 10000
    # Maximum number of in-flight requests (enforced by limit_concurrency).
    concurrency: int = 100
def get_datasets_prompt(num_copies: int = 5):
    """Build the list of generation prompts for every dataset item.

    For each item of the dataset named by ``cfg.dataset_name`` two prompt
    dicts are created — a "random" test-case-generator prompt and a
    "brute-force" oracle prompt — and each is scheduled ``num_copies``
    times so several samples are drawn per item.

    Args:
        num_copies: Number of duplicate requests per prompt. Defaults to 5,
            the value previously hard-coded.

    Returns:
        list[dict]: dicts with keys ``tcb_id``, ``type``, ``system_prompt``
        and ``user_prompt``, ready to be submitted to the chat API.
    """
    al_dataset = get_datasets_by_name(cfg.dataset_name)
    print(len(al_dataset))
    prompt_list = []
    for item in al_dataset:
        tcb_id = item["tcb_id"]
        # Prompt asking the model to act as a random test-case generator.
        gen_prompt = CF_PROMPT["data_generator_original"].format(
            problem=item["query_en"]
        )
        random_generator = {
            "tcb_id": tcb_id,
            "type": "random",
            "system_prompt": "You are an expert software tester tasked with thoroughly testing a given piece of code. ",
            "user_prompt": gen_prompt,
        }
        # NOTE: the repeated entries alias the same dict object; safe because
        # they are only read downstream, never mutated.
        prompt_list += [random_generator] * num_copies
        # Prompt asking the model for a brute-force / oracle solution.
        oracle_prompt = CF_PROMPT["oracle_generator"].format(
            problem=item["query_en"]
        )
        adversarial_generator = {
            "tcb_id": tcb_id,
            "type": "brute-force",
            "system_prompt": "You are an AI programming assistant.",
            "user_prompt": oracle_prompt,
        }
        prompt_list += [adversarial_generator] * num_copies
    return prompt_list
import json


def append_dict_to_jsonl(file_path, data_dict):
    """Serialize *data_dict* as one JSON line and append it to *file_path*."""
    line = json.dumps(data_dict, ensure_ascii=False)
    with open(file_path, "a", encoding="utf-8") as out:
        out.write(line)
        out.write("\n")
def limit_concurrency(
    coroutines: Sequence[Coroutine], concurrency: int
) -> list[Coroutine]:
    """Wrap each coroutine so at most *concurrency* of them run at once.

    A single shared semaphore gates entry: the returned coroutines may all
    be scheduled simultaneously (e.g. via ``asyncio.as_completed``) while
    never exceeding the requested parallelism.
    """
    gate = asyncio.Semaphore(concurrency)

    async def _gated(coro: Coroutine) -> Coroutine:
        # Hold one semaphore slot for the full lifetime of the wrapped
        # coroutine; released automatically on return or exception.
        async with gate:
            return await coro

    return [_gated(coro) for coro in coroutines]
async def request(data: dict, client: AsyncOpenAI):
    """Send one chat-completion request and return its labelled result.

    Fix: the parameter was annotated ``data: str`` but is indexed as a
    mapping (``data['system_prompt']`` etc.); annotation corrected to dict.

    Args:
        data: Prompt dict with keys ``tcb_id``, ``type``, ``system_prompt``
            and ``user_prompt`` (as produced by ``get_datasets_prompt``).
        client: OpenAI-compatible async client pointed at the sglang server.

    Returns:
        Tuple of (tcb_id, type, completion_text) for the first choice.

    NOTE(review): reads the module-level ``args`` for the model name and
    sampling parameters — must be called after CLI parsing.
    """
    response = await client.chat.completions.create(
        model=args.model_name_or_path,
        messages=[
            {"role": "system", "content": data['system_prompt']},
            {"role": "user", "content": data['user_prompt']},
        ],
        temperature=args.temperature,
        max_tokens=args.max_completion_tokens,
    )
    return data['tcb_id'], data['type'], response.choices[0].message.content
async def main(args: Argument, base_url: str, api_key: str = "sglang", save_file_path: str = ""):
    """Fan out all prompts to the server and stream results to a JSONL file.

    Fix: the original never released the ``AsyncOpenAI`` client, leaking
    its HTTP connection pool; the client is now closed in a ``finally``.

    Args:
        args: Parsed CLI arguments (model name, sampling, concurrency).
        base_url: OpenAI-compatible endpoint of the local sglang server.
        api_key: Placeholder key accepted by sglang.
        save_file_path: Output ``.jsonl`` path; one line is appended per
            finished request, in completion (not submission) order.
    """
    dataset = get_datasets_prompt()
    client = AsyncOpenAI(base_url=base_url, api_key=api_key)
    try:
        tasks = limit_concurrency(
            [request(data, client) for data in dataset], args.concurrency
        )
        # tqdm.auto's tqdm supports `async for` over as_completed's iterator,
        # advancing the bar as each request finishes — presumably the
        # tqdm.asyncio variant is in effect here; confirm against the
        # installed tqdm version.
        async for rezip_response in tqdm(
            asyncio.as_completed(tasks), total=len(tasks), desc="Running"
        ):
            tcb_id, type_of_func, response = await rezip_response
            res_dict = {
                "tcb_id": tcb_id,
                "type": type_of_func,
                "response": response,
            }
            # Persist incrementally so partial progress survives a crash.
            append_dict_to_jsonl(save_file_path, res_dict)
    finally:
        # Release the client's underlying HTTP connections.
        await client.close()
if __name__ == "__main__":
    args = Argument().parse_args()
    # Fix: bind before the try so the `finally` below cannot raise NameError
    # when launch_server_cmd itself fails.
    server_process = None
    try:
        os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] = "1"
        server_process, port = launch_server_cmd(
            (
                "python3 -m sglang.launch_server "
                "--tp 8 "
                "--dp 1 "
                f"--model-path {args.model_name_or_path} "
                # f"--served-model-name {args.model} "
                # "--reasoning-parser qwen3 "
                "--context-length 20000 "
                # """--json-model-override-args {"rope_scaling":{"rope_type":"yarn","factor":4.0,"original_max_position_embeddings":32768}} """
                "--host 0.0.0.0 "
                "--port 33389 "
                "--log-level warning "
            )
        )
        wait_for_server(f"http://localhost:{port}")
        original_response_file = cfg.response_file.format("algo", os.path.basename(args.model_name_or_path))
        print(original_response_file)
        asyncio.run(main(args, base_url=f"http://localhost:{port}/v1", api_key="sglang", save_file_path=original_response_file))
    except Exception as e:
        # Best-effort top-level boundary: report and fall through so the
        # server is still torn down. Full traceback aids post-mortem.
        import traceback
        traceback.print_exc()
        print(f'Error! {e}')
    finally:
        # Only terminate a server that actually started.
        if server_process is not None:
            terminate_process(server_process)
# --- NOTE(review): the lines below are Hugging Face "Xet Storage" page
# metadata accidentally captured when this file was downloaded; they are
# not part of the program and are preserved here only as a comment. ---
# Xet Storage Details
# - Size: 4.66 kB
# - Xet hash:
#   75706c108fbc3c7f82163bf0b0657a090db8d8b0337685b8a370e41136279dec
# Xet efficiently stores files, intelligently splitting them into unique
# chunks and accelerating uploads and downloads.