| import asyncio | |
| import os | |
| from config import cfg | |
| from openai import AsyncOpenAI | |
| from sglang.utils import launch_server_cmd, terminate_process, wait_for_server | |
| from tap import Tap | |
| from tqdm.auto import tqdm | |
| from collections.abc import Coroutine, Sequence | |
| from prompt import validator_prompt, input_generator_prompt | |
| import sys | |
| sys.path.append("/home/i-luoxianzhen/data/TestCase-Gen/methods/utils") | |
| from dataset_all import get_datasets_by_name | |
class Argument(Tap):
    # CLI arguments for the two-stage validator/generator sampling run.
    # Path (or name) of the model that sglang serves and the client queries.
    model_name_or_path: str = "/mnt/jfs/ckpt/checkpoints/Qwen2.5-32B-Instruct"
    # Sampling temperature used for both chat requests.
    temperature: float = 0.2
    # Per-request completion budget (forwarded as max_tokens).
    max_completion_tokens: int = 10000
    # Maximum number of simultaneously in-flight requests.
    concurrency: int = 100
| import re | |
def extract_code(ans_str: str) -> str:
    """Extract the contents of the last ```json fenced block in *ans_str*.

    Args:
        ans_str: Model response text expected to contain at least one
            ```json ... ``` fenced code block.

    Returns:
        The raw text inside the last ```json fence (trailing newline kept).

    Raises:
        ValueError: if no ```json fence is present — a clear message instead
            of the opaque IndexError the bare ``matches[-1]`` produced.
            (Callers wrap this call in a broad try/except, so the change is
            backward-compatible.)
    """
    pattern = r'```json\n(.*?)```'
    matches = re.findall(pattern, ans_str, re.DOTALL)
    if not matches:
        raise ValueError("no ```json fenced block found in model response")
    return matches[-1]
def get_datasets_prompt():
    """Build validator/generator prompt pairs for every solvable dataset item.

    Items without a reference solution are skipped.  Each remaining item
    contributes ten copies of its pair so that ten independent samples are
    drawn per problem downstream.
    """
    dataset = get_datasets_by_name(cfg.dataset_name)
    print(len(dataset))
    pairs = []
    for item in dataset:
        solutions = item['solutions']
        if not solutions:
            continue
        tcb_id = item["tcb_id"]
        spec = item['query_en']
        oracle = solutions[0]['code']
        validator_user = (
            validator_prompt
            .replace("{{ problem_specification }}", spec)
            .replace("{{ oracle_program }}", oracle)
        )
        # The "{{ input_validator }}" slot is filled later, once the
        # validator response is available (see request()).
        generator_user = (
            input_generator_prompt
            .replace("{{ problem_specification }}", spec)
            .replace("{{ oracle_program }}", oracle)
        )
        pair = {
            "validator": {
                "tcb_id": tcb_id,
                "type": "validator",
                "system_prompt": "You are a helpful code assistant.",
                "user_prompt": validator_user,
            },
            "generator": {
                "tcb_id": tcb_id,
                "type": "generator",
                "system_prompt": "You are a helpful code assistant.",
                "user_prompt": generator_user,
            },
        }
        pairs.extend([pair] * 10)
    return pairs
| import json | |
def append_dict_to_jsonl(file_path, data_dict):
    """Serialize *data_dict* as a single JSON line and append it to *file_path*."""
    line = json.dumps(data_dict, ensure_ascii=False)
    with open(file_path, mode='a', encoding='utf-8') as fh:
        fh.write(line)
        fh.write('\n')
def limit_concurrency(
    coroutines: Sequence[Coroutine], concurrency: int
) -> list[Coroutine]:
    """Wrap each coroutine so at most *concurrency* of them run at once.

    A shared semaphore gates every awaited coroutine; the returned list
    preserves the input order.
    """
    gate = asyncio.Semaphore(concurrency)

    async def guarded(coro: Coroutine) -> Coroutine:
        async with gate:
            return await coro

    return [guarded(coro) for coro in coroutines]
async def request(data: dict, client: AsyncOpenAI):
    """Run the two-stage validator -> generator chat pipeline for one item.

    First asks the model for an input validator, extracts the ```json payload
    from its answer, splices the produced ``input_validator`` text into the
    generator prompt, then asks the model for the input generator.

    Args:
        data: A prompt pair as built by get_datasets_prompt, with
            "validator" and "generator" sub-dicts.  (The original annotation
            ``data: str`` was wrong — a dict is always passed.)
        client: Async OpenAI-compatible client pointed at the sglang server.

    Returns:
        ``(tcb_id, validator_response, generator_response)``.  Both response
        strings are empty when any stage fails — deliberate best-effort so a
        single bad item never aborts the whole batch (the caller drops
        empty results).
    """
    validator = data["validator"]
    generator = data["generator"]
    try:
        response_validator = await client.chat.completions.create(
            model=args.model_name_or_path,
            messages=[
                {"role": "system", "content": validator['system_prompt']},
                {"role": "user", "content": validator['user_prompt']}
            ],
            temperature=args.temperature,
            max_tokens=args.max_completion_tokens,
        )
        # The validator answer must carry a ```json block containing an
        # "input_validator" field; feed it into the generator prompt.
        response_validator_dict = json.loads(
            extract_code(response_validator.choices[0].message.content)
        )
        generator_prompt = generator['user_prompt'].replace(
            "{{ input_validator }}", response_validator_dict['input_validator']
        )
        response_generator = await client.chat.completions.create(
            model=args.model_name_or_path,
            messages=[
                {"role": "system", "content": generator['system_prompt']},
                {"role": "user", "content": generator_prompt}
            ],
            temperature=args.temperature,
            max_tokens=args.max_completion_tokens,
        )
    except Exception:
        # Broad on purpose: API errors, malformed JSON, or a missing fence
        # all degrade to an empty result rather than crashing the batch.
        return validator['tcb_id'], "", ""
    return (
        validator['tcb_id'],
        response_validator.choices[0].message.content,
        response_generator.choices[0].message.content,
    )
async def main(args: Argument, base_url: str, api_key: str = "sglang", save_file_path: str = ""):
    """Fan out all prompt pairs against the server and append results to JSONL.

    Results are processed in completion order; failed requests (both
    responses empty) are skipped rather than written.

    Args:
        args: Parsed CLI arguments (model path, sampling and concurrency knobs).
        base_url: OpenAI-compatible endpoint of the local sglang server.
        api_key: Placeholder key accepted by sglang.
        save_file_path: Destination .jsonl file (appended to, one dict per line).
    """
    dataset = get_datasets_prompt()
    client = AsyncOpenAI(base_url=base_url, api_key=api_key)
    tasks = [request(data, client) for data in dataset]
    tasks = limit_concurrency(tasks, args.concurrency)
    # Plain `for` + `await` works on every Python version; the original
    # `async for` relies on asyncio.as_completed being async-iterable,
    # which was only added in Python 3.13.
    for finished in tqdm(
        asyncio.as_completed(tasks), total=len(tasks), desc="Running"
    ):
        tcb_id, response_validator, response_generator = await finished
        if response_validator == "" and response_generator == "":
            continue
        res_dict = {
            "tcb_id": tcb_id,
            "response_validator": response_validator,
            "response_generator": response_generator
        }
        append_dict_to_jsonl(save_file_path, res_dict)
if __name__ == "__main__":
    args = Argument().parse_args()
    # Initialize before the try so the `finally` clause cannot hit a
    # NameError when launch_server_cmd itself raises (original bug).
    server_process = None
    try:
        # Allow prompts longer than the model's native context window.
        os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] = "1"
        server_process, port = launch_server_cmd(
            (
                "python3 -m sglang.launch_server "
                "--tp 8 "
                "--dp 1 "
                f"--model-path {args.model_name_or_path} "
                # f"--served-model-name {args.model} "
                # "--reasoning-parser qwen3 "
                "--context-length 32768 "
                # """--json-model-override-args {"rope_scaling":{"rope_type":"yarn","factor":4.0,"original_max_position_embeddings":32768}} """
                "--host 0.0.0.0 "
                "--port 33382 "
                "--log-level warning "
            )
        )
        wait_for_server(f"http://localhost:{port}")
        orginal_response_file = cfg.response_file.format("ht", os.path.basename(args.model_name_or_path))
        print(orginal_response_file)
        asyncio.run(main(args, base_url=f"http://localhost:{port}/v1", api_key="sglang", save_file_path=orginal_response_file))
    except Exception as e:
        print(f'Error! {e}')
    finally:
        # Only terminate a server we actually managed to start.
        if server_process is not None:
            terminate_process(server_process)
Xet Storage Details
- Size: 6.12 kB
- Xet hash: 75a9c7a95718cc466e5d984cd1b742ad4909529bf6efaf665ab35f5e72529508

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.