| import sys | |
| sys.path.append("/home/i-luoxianzhen/data/TestCase-Gen/methods/utils") | |
| from response import TurboResponser | |
| from dataset_all import get_datasets_by_name | |
| from is_correct import test_output_comparison | |
| from config import cfg | |
| from prompt import code_system_prompt, baseline_code_test_prompt, shots | |
| from typing import List, Optional | |
| import re | |
| import json | |
def write_json_to_file(data, filepath):
    """Serialize ``data`` as pretty-printed JSON, overwriting ``filepath``.

    The target is opened in write mode with UTF-8 encoding. Non-ASCII
    characters are emitted verbatim (no ``\\uXXXX`` escapes) and nested
    structures are indented by four spaces.

    Args:
        data: Any JSON-serializable object.
        filepath: Destination path; existing content is replaced.
    """
    dump_options = {"ensure_ascii": False, "indent": 4}
    with open(filepath, mode="w", encoding="utf-8") as out_file:
        json.dump(data, out_file, **dump_options)
| def append_dict_to_jsonl(file_path, data_dict): | |
| with open(file_path, 'a', encoding='utf-8') as f: | |
| f.write(json.dumps(data_dict, ensure_ascii=False) + '\n') | |
| import datetime | |
def write_log(message: str, log_file: str = "log-lcb.txt"):
    """Append one timestamped line to a log file.

    Each line has the form ``[YYYY-MM-DD HH:MM:SS] <message>`` followed by a
    newline, using the local time returned by ``datetime.datetime.now()``.

    Args:
        message (str): Text to record.
        log_file (str): Path of the log file, opened in append mode with
            UTF-8 encoding (default is 'log-lcb.txt').

    Returns:
        None
    """
    now = datetime.datetime.now()
    stamp = now.strftime("[%Y-%m-%d %H:%M:%S]")
    entry = "{} {}\n".format(stamp, message)
    with open(log_file, mode="a", encoding="utf-8") as log_handle:
        log_handle.write(entry)
def extract_code(ans_str):
    """Return the contents of the last ```python fenced block in ``ans_str``.

    The opening fence may be followed by trailing spaces/tabs and either an
    LF or CRLF line ending. Everything between that line ending and the next
    closing fence is returned unchanged.

    Args:
        ans_str: Text to search (e.g. a model response).

    Returns:
        The body of the last matching fenced block.

    Raises:
        IndexError: If ``ans_str`` contains no ```python fenced block. The
            exception type is the one the bare ``matches[-1]`` lookup used
            to raise, but now carries a descriptive message.
    """
    pattern = r'```python[ \t]*\r?\n(.*?)```'
    matches = re.findall(pattern, ans_str, re.DOTALL)
    if not matches:
        raise IndexError("no ```python fenced code block found in the text")
    # The last block is taken when several are present.
    return matches[-1]
def extract_json(ans_str):
    """Return the contents of the last ```json fenced block in ``ans_str``.

    The opening fence may be followed by trailing spaces/tabs and either an
    LF or CRLF line ending. The block body is returned as raw text; it is
    not parsed here.

    Args:
        ans_str: Text to search (e.g. a model response).

    Returns:
        The body of the last matching fenced block, as a string.

    Raises:
        IndexError: If ``ans_str`` contains no ```json fenced block. The
            exception type is the one the bare ``matches[-1]`` lookup used
            to raise, but now carries a descriptive message.
    """
    pattern = r'```json[ \t]*\r?\n(.*?)```'
    matches = re.findall(pattern, ans_str, re.DOTALL)
    if not matches:
        raise IndexError("no ```json fenced block found in the text")
    # The last block is taken when several are present.
    return matches[-1]
| import json | |
def extract_ids_from_jsonl(filepath):
    """Collect the ``tcb_id`` value of every record in a JSON Lines file.

    Blank or whitespace-only lines are skipped instead of being passed to
    ``json.loads`` (which would raise on them). Records that are not JSON
    objects, or that have no ``tcb_id`` key, are ignored.

    Args:
        filepath: Path of the UTF-8 encoded ``.jsonl`` file to read.

    Returns:
        A list of ``tcb_id`` values in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    ids = []
    with open(filepath, 'r', encoding='utf-8') as file:
        for line in file:
            if not line.strip():
                continue
            data = json.loads(line)
            if isinstance(data, dict) and 'tcb_id' in data:
                ids.append(data['tcb_id'])
    return ids
| import argparse | |
# Build one chat prompt (system + user) per dataset problem.
batch = -1
# The configured log path is a template that is filled with the batch number.
log_file = cfg.log_file.format(batch)
# NOTE(review): only the first two problems are used — looks like a debugging
# limit; confirm before a full run.
al_dataset = get_datasets_by_name(cfg.dataset_name)[0: 2]
print(len(al_dataset))
testcases_pass_rate = {}

# Everything before the problem statement is identical for all problems, so it
# is assembled once: the instruction (with the requested number of tests filled
# in), the few-shot examples, and the header that introduces the question.
prompt_prefix = (
    baseline_code_test_prompt.replace("{num_of_test}", str(cfg.gen_test_nums))
    + "===Example===\n" + shots
    + "===Question===\n"
)

prompt_list = []
for item in al_dataset:
    tcb_id = item["tcb_id"]
    query = item['query_en']
    write_log(f"{tcb_id} start", log_file)
    prompt_list.append(
        {
            "tcb_id": tcb_id,
            "system_prompt": code_system_prompt,
            "user_prompt": prompt_prefix + query,
        }
    )
| import ray | |
| from packaging.version import Version | |
| from ray.data.llm import build_llm_processor, vLLMEngineProcessorConfig | |
# Run batch inference over the prompts with Ray Data + vLLM and save the
# generations as JSON.

# Resolve the output location BEFORE any inference happens: the name
# `orginal_response_file` was previously never defined, so the script only
# failed with NameError after the whole pipeline had been built and sampled.
# NOTE(review): the config attribute name and the fallback directory name are
# assumptions — confirm the intended output path.
orginal_response_file = getattr(cfg, "orginal_response_file", None)
if orginal_response_file is None:
    orginal_response_file = f"orginal_response_{batch}"

ds = ray.data.from_items(prompt_list)
size = ds.count()
print(f"Size of dataset: {size} prompts")

# Configure vLLM engine.
config = vLLMEngineProcessorConfig(
    model_source="/mnt/jfs/ckpt/checkpoints/Qwen2.5-32B-Instruct",
    engine_kwargs={
        "enable_chunked_prefill": True,
        "max_num_batched_tokens": 4096,
        "max_model_len": 16384,
        "tensor_parallel_size": 8,
    },
    concurrency=1,  # set the number of parallel vLLM replicas
    batch_size=32,
)

# Create a Processor object, which will be used to
# do batch inference on the dataset.
vllm_processor = build_llm_processor(
    config,
    # Turn each prompt row into a chat request with its sampling settings.
    preprocess=lambda row: dict(
        messages=[
            {"role": "system", "content": row['system_prompt']},
            {"role": "user", "content": row["user_prompt"]},
        ],
        sampling_params=dict(
            temperature=0.2,
            max_tokens=8192,
        ),
    ),
    # Expose the generation under "answer" while keeping every column of
    # the processed row.
    postprocess=lambda row: dict(
        answer=row["generated_text"],
        **row,
    ),
)
ds = vllm_processor(ds)

# Peek at the first result.
# NOTE: This is for local testing and debugging. For production use,
# write the full result out as done below.
# NOTE(review): Ray Datasets are evaluated lazily, so this peek and the write
# below may each trigger inference — confirm, and consider materializing the
# dataset once if the duplicated work matters.
outputs = ds.take(limit=1)
for output in outputs:
    print(output['generated_text'])

# Write the inference output as JSON files to the output location.
# Multiple files may be written to the destination, and each task
# writes one or more files separately.
ds.write_json(orginal_response_file)
# Xet Storage Details
# - Size: 4.26 kB
# - Xet hash: 3d07d1b76d2b13872110ddc569a9f6c323bb067c60aa8ac7062d050718162921
# Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.