| import asyncio | |
| import os | |
| from config import cfg | |
| from openai import AsyncOpenAI | |
| from tap import Tap | |
| from tqdm.auto import tqdm | |
| from collections.abc import Coroutine, Sequence | |
| from prompt import test_gen_prompt, predo_prompt, input_template, System_prompt | |
| import sys | |
| sys.path.append("/home/i-luoxianzhen/data/TestCase-Gen/methods/utils") | |
| from dataset_all import get_datasets_by_name | |
class Argument(Tap):
    """Command-line arguments, parsed via typed-argument-parser (Tap)."""

    # model_name_or_path: str = "/mnt/jfs/ckpt/checkpoints/Qwen2.5-7B-Instruct"
    temperature: float = 0.2           # sampling temperature used for "input"-type requests
    max_completion_tokens: int = 10000  # per-request cap on generated tokens
    concurrency: int = 10              # max simultaneous in-flight API calls
def get_datasets_prompt(repeats: int = 10):
    """Build the flat list of chat requests for the configured dataset.

    For each dataset item two request templates are produced — an "input"
    request (test-case generation) and an "answer" request (oracle
    generation) — and each is enqueued ``repeats`` times so multiple
    samples are drawn per item.

    Args:
        repeats: copies of each request to enqueue. Defaults to 10,
            matching the previously hard-coded sampling count.

    Returns:
        list[dict]: request dicts with keys ``tcb_id``, ``type``,
        ``system_prompt`` and ``user_prompt``.
    """
    al_dataset = get_datasets_by_name(cfg.dataset_name)
    print(len(al_dataset))
    prompt_list = []
    for item in al_dataset:
        tcb_id = item["tcb_id"]
        query = item["query_en"]
        input_request = {
            "tcb_id": tcb_id,
            "type": "input",
            "system_prompt": test_gen_prompt,
            "user_prompt": input_template.replace("[[Question]]", query),
        }
        # .copy() so the repeated entries are independent dicts rather than
        # `repeats` references to one shared object (the old `[d] * 10`).
        prompt_list.extend(input_request.copy() for _ in range(repeats))
        oracle_request = {
            "tcb_id": tcb_id,
            "type": "answer",
            "system_prompt": System_prompt,
            "user_prompt": predo_prompt.replace("{question}", query),
        }
        prompt_list.extend(oracle_request.copy() for _ in range(repeats))
    return prompt_list
| import json | |
def append_dict_to_jsonl(file_path, data_dict):
    """Serialize *data_dict* as a single JSON line and append it to *file_path*."""
    line = json.dumps(data_dict, ensure_ascii=False)
    with open(file_path, 'a', encoding='utf-8') as sink:
        sink.write(line)
        sink.write('\n')
def limit_concurrency(
    coroutines: Sequence[Coroutine], concurrency: int
) -> list[Coroutine]:
    """Wrap each coroutine so at most *concurrency* of them run concurrently.

    A single shared semaphore gates execution; the returned awaitables are
    in the same order as the input sequence.
    """
    gate = asyncio.Semaphore(concurrency)

    async def gated(coro: Coroutine) -> Coroutine:
        async with gate:
            return await coro

    return [gated(coro) for coro in coroutines]
async def request(data: dict, client: AsyncOpenAI):
    """Send one chat-completion request described by *data*.

    Fixed the parameter annotation: *data* is a request dict (see
    ``get_datasets_prompt``) with keys ``tcb_id``, ``type``,
    ``system_prompt`` and ``user_prompt`` — the old ``data: str`` was wrong.

    NOTE(review): reads the module-level globals ``args`` (CLI options) and
    ``cfg`` (model name); ``args`` is only bound under ``__main__``.

    Returns:
        tuple: ``(tcb_id, type, content)``; on any failure the type slot is
        the string ``"error"`` and the content is empty. Errors are
        deliberately swallowed (best-effort batch run), only printed.
    """
    try:
        response = await client.chat.completions.create(
            model=cfg.model_name,
            messages=[
                {"role": "system", "content": data['system_prompt']},
                {"role": "user", "content": data['user_prompt']}
            ],
            # "input" requests use the CLI temperature; oracle ("answer")
            # requests are sampled at a fixed 0.7.
            temperature=args.temperature if data['type'] == "input" else 0.7,
            max_tokens=args.max_completion_tokens,
        )
    except Exception as e:
        print(f"Error - {e}")
        return data['tcb_id'], "error", ""
    else:
        return data['tcb_id'], data['type'], response.choices[0].message.content
async def main(args: Argument, base_url: str, api_key: str = "sglang", save_file_path: str = ""):
    """Fan out every generation request and stream results to a JSONL file.

    Results are appended one-by-one as they complete, so a crash loses at
    most the in-flight requests.
    """
    prompts = get_datasets_prompt()
    client = AsyncOpenAI(base_url=base_url, api_key=api_key)
    pending = [request(prompt, client) for prompt in prompts]
    pending = limit_concurrency(pending, args.concurrency)
    # NOTE(review): `async for` over as_completed relies on tqdm.auto's
    # asyncio-aware tqdm class — confirm the installed tqdm supports it.
    progress = tqdm(asyncio.as_completed(pending), total=len(pending), desc="Running")
    async for finished in progress:
        tcb_id, type_of_func, response = await finished
        record = {
            "tcb_id": tcb_id,
            "type": type_of_func,
            "response": response,
        }
        append_dict_to_jsonl(save_file_path, record)
| if __name__ == "__main__": | |
| args = Argument().parse_args() | |
| try: | |
| orginal_response_file = cfg.response_file.format("predo", cfg.model_name) | |
| print(orginal_response_file) | |
| asyncio.run(main(args, base_url=cfg.api_base, api_key=cfg.api_key, save_file_path=orginal_response_file)) | |
| except Exception as e: | |
| print(f'Error! {e}') | |
# --- file-host metadata (not code) ---------------------------------------
# Xet Storage Details — size: 3.69 kB
# Xet hash: 347a7dc211d8b8eca866e4a946dda4926465e13f36a1c5f6cc9ebc633fa70592
# (Trailing UI text scraped from the file host; converted to a comment so
# the module remains importable.)