| import asyncio | |
| import os | |
| from config import cfg | |
| from openai import AsyncOpenAI | |
| from tap import Tap | |
| from tqdm.auto import tqdm | |
| from collections.abc import Coroutine, Sequence | |
| from prompt import code_system_prompt, baseline_code_test_prompt, shots | |
| import sys | |
| sys.path.append("/home/i-luoxianzhen/data/TestCase-Gen/methods/utils") | |
| from dataset_all import get_datasets_by_name | |
# Command-line options for a generation run, parsed by Tap.
# The model is not an option here: request() takes it from cfg.model_name.
class Argument(Tap):
    # Sampling temperature forwarded to the chat-completions call.
    temperature: float = 0.2
    # Cap on generated tokens; request() sends it as `max_tokens`.
    max_completion_tokens: int = 16000
    # Size of the semaphore that bounds how many requests run at once.
    concurrency: int = 20
def get_datasets_prompt(repeats: int = 5):
    """Build the chat prompts for every item of the configured dataset.

    Loads the dataset named by ``cfg.dataset_name`` and prints its length.
    For each item the user prompt is the baseline template (with
    ``{num_of_test}`` replaced by ``cfg.gen_test_nums``), followed by the
    few-shot examples and the item's ``query_en`` text.

    Args:
        repeats: How many times the full prompt list is repeated in the
            result, i.e. how many requests are issued per dataset item.
            Defaults to 5, the factor that was previously hard-coded.
            Zero or a negative value yields an empty list.

    Returns:
        A list of dicts with keys ``tcb_id``, ``system_prompt`` and
        ``user_prompt``. Repeated entries refer to the same dict objects.
    """
    al_dataset = get_datasets_by_name(cfg.dataset_name)
    print(len(al_dataset))

    # Everything before the question text is identical for all items, so the
    # template substitution and few-shot section are assembled only once.
    prompt_prefix = (
        baseline_code_test_prompt.replace("{num_of_test}", str(cfg.gen_test_nums))
        + "===Example===\n"
        + shots
        + "===Question===\n"
    )

    prompts = [
        {
            "tcb_id": item["tcb_id"],
            "system_prompt": code_system_prompt,
            "user_prompt": prompt_prefix + item["query_en"],
        }
        for item in al_dataset
    ]
    return prompts * repeats
| import json | |
def append_dict_to_jsonl(file_path, data_dict):
    """Append ``data_dict`` to ``file_path`` as a single JSON line.

    Args:
        file_path: Destination JSONL file. It is opened in append mode, and
            any missing parent directory is created first.
        data_dict: JSON-serialisable object written on one line. Non-ASCII
            characters are written as-is (UTF-8), not escaped.

    Raises:
        TypeError: If ``data_dict`` is not JSON-serialisable. Nothing is
            written in that case.
        OSError: If the directory or file cannot be created or written.
    """
    # Serialise before touching the filesystem so that a bad record never
    # leaves an empty file or directory behind.
    line = json.dumps(data_dict, ensure_ascii=False) + "\n"

    # A bare file name has an empty dirname; os.makedirs("") would raise.
    parent_dir = os.path.dirname(file_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with open(file_path, "a", encoding="utf-8") as f:
        f.write(line)
def limit_concurrency(
    coroutines: Sequence[Coroutine], concurrency: int
) -> list[Coroutine]:
    """Wrap coroutines so that at most ``concurrency`` of them run at once.

    Each input coroutine is wrapped in a new coroutine that acquires a shared
    semaphore before awaiting the original one and releases it afterwards.
    Nothing is started here; the caller still has to schedule or await the
    returned coroutines.

    Args:
        coroutines: Coroutine objects to be throttled.
        concurrency: Maximum number of wrapped coroutines allowed inside the
            semaphore at the same time. Must be at least 1.

    Returns:
        One wrapper coroutine per input, in the same order. Awaiting a
        wrapper yields the result of the coroutine it wraps.

    Raises:
        ValueError: If ``concurrency`` is less than 1. A semaphore of size 0
            would make every wrapper wait forever.
    """
    if concurrency < 1:
        raise ValueError(f"concurrency must be at least 1, got {concurrency}")

    gate = asyncio.Semaphore(concurrency)

    async def run_gated(coroutine: Coroutine) -> object:
        # Hold a semaphore slot for the whole lifetime of the wrapped call.
        async with gate:
            return await coroutine

    return [run_gated(coroutine) for coroutine in coroutines]
async def request(data: dict, client: "AsyncOpenAI"):
    """Send one chat-completion request and return ``(tcb_id, text)``.

    The model name comes from ``cfg.model_name``. The sampling temperature
    and token limit are read from the module-level ``args`` object, which the
    ``__main__`` block creates; they are not parameters of this function.

    Args:
        data: Prompt record with keys ``tcb_id``, ``system_prompt`` and
            ``user_prompt``.
        client: Client whose ``chat.completions.create`` coroutine is called.

    Returns:
        A ``(tcb_id, text)`` tuple. ``text`` is the content of the first
        choice, or ``""`` when the call fails, the reply has no choices, or
        the content is ``None``.
    """
    try:
        response = await client.chat.completions.create(
            model=cfg.model_name,
            messages=[
                {"role": "system", "content": data["system_prompt"]},
                {"role": "user", "content": data["user_prompt"]},
            ],
            temperature=args.temperature,
            max_tokens=args.max_completion_tokens,
        )
        # Unpack the reply inside the guarded region: a response without
        # choices must count as a failed request instead of raising into the
        # driver loop and aborting every other pending request.
        content = response.choices[0].message.content
    except Exception as e:
        # Best effort by design: one failed request must not stop the run.
        print(f"Error - {e}")
        return data["tcb_id"], ""
    # Normalise a missing content to the same empty string used for failures
    # so consumers of the output only ever see strings.
    return data["tcb_id"], content or ""
async def main(args: Argument, base_url: str, api_key: str = "sglang", save_file_path: str = ""):
    """Request a completion for every prompt and append each result to a JSONL file.

    Builds the prompt list with ``get_datasets_prompt()``, issues one
    ``request`` per prompt with at most ``args.concurrency`` in flight, and
    appends ``{"tcb_id": ..., "response": ...}`` to ``save_file_path`` as soon
    as each request finishes (completion order, not prompt order).

    Note that only ``args.concurrency`` is read from the ``args`` parameter;
    ``request`` takes its sampling settings from the module-level ``args``.

    Args:
        args: Parsed command-line options.
        base_url: Base URL of the OpenAI-compatible endpoint.
        api_key: API key passed to the client.
        save_file_path: JSONL file that results are appended to.

    Raises:
        ValueError: If ``save_file_path`` is empty. Checked before any
            request is sent, so a missing path cannot waste a whole run.
    """
    if not save_file_path:
        raise ValueError("save_file_path must be a non-empty path")

    dataset = get_datasets_prompt()

    # The context manager closes the client's HTTP connections on exit,
    # including when the loop below raises.
    async with AsyncOpenAI(base_url=base_url, api_key=api_key, timeout=5000) as client:
        pending = limit_concurrency(
            [request(data, client) for data in dataset], args.concurrency
        )
        # Each item produced by as_completed is an awaitable for the next
        # request to finish; the tqdm wrapper only adds the progress bar.
        async for finished in tqdm(
            asyncio.as_completed(pending), total=len(pending), desc="Running"
        ):
            tcb_id, response = await finished
            # Persist immediately so finished work survives a later crash.
            append_dict_to_jsonl(
                save_file_path, {"tcb_id": tcb_id, "response": response}
            )
if __name__ == "__main__":
    # `args` is deliberately a module-level global: request() reads
    # args.temperature and args.max_completion_tokens from it.
    args = Argument().parse_args()
    try:
        # The output path is derived from the configured template, the fixed
        # tag "crux" and the model name.
        original_response_file = cfg.response_file.format("crux", cfg.model_name)
        print(original_response_file)
        asyncio.run(
            main(
                args,
                base_url=cfg.api_base,
                api_key=cfg.api_key,
                save_file_path=original_response_file,
            )
        )
    except Exception as e:
        print(f'Error! {e}')
        # Exit non-zero so shells and schedulers can tell that the run
        # failed; previously a failed run still exited with status 0.
        sys.exit(1)
Xet Storage Details
- Size: 3.45 kB
- Xet hash: bd6dca4156c42cbe3f789f3c497ea7e71dc39cbcec598c876abc703f24cddb8f

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.