Tsukihjy's picture
download
raw
4.26 kB
import sys
sys.path.append("/home/i-luoxianzhen/data/TestCase-Gen/methods/utils")
from response import TurboResponser
from dataset_all import get_datasets_by_name
from is_correct import test_output_comparison
from config import cfg
from prompt import code_system_prompt, baseline_code_test_prompt, shots
from typing import List, Optional
import re
import json
def write_json_to_file(data, filepath):
    """Serialize ``data`` as JSON into ``filepath``, replacing any existing content.

    The output is indented by 4 spaces and non-ASCII characters are written
    verbatim in UTF-8 instead of being ``\\u``-escaped.
    """
    with open(filepath, mode="w", encoding="utf-8") as out_handle:
        json.dump(data, out_handle, indent=4, ensure_ascii=False)
def append_dict_to_jsonl(file_path, data_dict):
    """Append ``data_dict`` to ``file_path`` as a single JSON Lines record.

    The file is opened in append mode (created if absent), and the record is
    written with non-ASCII characters left unescaped and terminated by ``\\n``.
    """
    with open(file_path, mode="a", encoding="utf-8") as sink:
        record = json.dumps(data_dict, ensure_ascii=False)
        sink.write(f"{record}\n")
import datetime
def write_log(message: str, log_file: str = "log-lcb.txt") -> None:
    """
    Append a timestamped log message to a log file.

    Each call appends exactly one line of the form
    ``[YYYY-mm-dd HH:MM:SS] <message>`` using the local (naive) wall-clock
    time. The file is opened in append mode, so it is created if missing.

    Args:
        message (str): The message to log.
        log_file (str): The path to the log file (default is 'log-lcb.txt').

    Returns:
        None
    """
    timestamp = datetime.datetime.now().strftime("[%Y-%m-%d %H:%M:%S]")
    with open(log_file, "a", encoding="utf-8") as f:
        f.write(f"{timestamp} {message}\n")
def extract_code(ans_str):
    """Return the body of the last ```python fenced block in ``ans_str``.

    The opening fence must be exactly "```python" followed by a newline. The
    returned text is everything up to (but not including) the next "```", so
    it normally ends with a newline.

    Raises:
        IndexError: if ``ans_str`` contains no such block. This is the same
            exception type the previous bare ``matches[-1]`` raised, so
            existing handlers keep working; the message now says what failed.
    """
    matches = re.findall(r'```python\n(.*?)```', ans_str, re.DOTALL)
    if not matches:
        raise IndexError("no ```python fenced code block found in response")
    # When several blocks are present the last one wins (presumably the
    # model's final answer).
    return matches[-1]
def extract_json(ans_str):
    """Return the body of the last ```json fenced block in ``ans_str``.

    The opening fence must be exactly "```json" followed by a newline. The
    returned text is the raw (unparsed) content up to the next "```".

    Raises:
        IndexError: if ``ans_str`` contains no such block. This is the same
            exception type the previous bare ``matches[-1]`` raised, so
            existing handlers keep working; the message now says what failed.
    """
    matches = re.findall(r'```json\n(.*?)```', ans_str, re.DOTALL)
    if not matches:
        raise IndexError("no ```json fenced block found in response")
    # When several blocks are present the last one wins.
    return matches[-1]
import json
def extract_ids_from_jsonl(filepath):
    """Collect the ``tcb_id`` value of every record in a JSON Lines file.

    Records without a ``tcb_id`` key are skipped. Blank or whitespace-only
    lines (e.g. a stray empty line) are ignored instead of making
    ``json.loads`` raise ``JSONDecodeError``.

    Args:
        filepath: Path to a UTF-8 encoded ``.jsonl`` file.

    Returns:
        list: The ``tcb_id`` values in file order (duplicates preserved).

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    ids = []
    with open(filepath, 'r', encoding='utf-8') as file:
        for line in file:
            if not line.strip():
                continue
            data = json.loads(line)
            if 'tcb_id' in data:
                ids.append(data['tcb_id'])
    return ids
import argparse
# --- Script body: build one prompt per problem, run batched vLLM inference
# --- through Ray Data, and write every generated response out as JSON.
_parser = argparse.ArgumentParser(
    description="Generate baseline test-case responses with vLLM via Ray Data."
)
_parser.add_argument(
    "--batch", type=int, default=-1,
    help="Batch index formatted into cfg.log_file (default: -1).",
)
_parser.add_argument(
    "--limit", type=int, default=2,
    help="Number of dataset problems to run; a negative value runs all of "
         "them (default: 2).",
)
# NOTE(review): the original script wrote to an undefined name
# `orginal_response_file`, so it crashed with NameError after inference. If
# `cfg` provides a canonical output path, prefer that as the default here.
_parser.add_argument(
    "--output-path", default="original_responses",
    help="Destination directory for Ray Data's JSON output files.",
)
# parse_known_args keeps tolerating unrelated command-line arguments, since the
# script previously ignored sys.argv entirely.
args, _unknown_args = _parser.parse_known_args()

batch = args.batch
log_file = cfg.log_file.format(batch)
al_dataset = get_datasets_by_name(cfg.dataset_name)
if args.limit >= 0:
    # Small debug-sized run by default; pass --limit -1 for the full dataset.
    al_dataset = al_dataset[:args.limit]
print(len(al_dataset))

prompt_list = []
for item in al_dataset:
    tcb_id = item["tcb_id"]
    query = item['query_en']
    write_log(f"{tcb_id} start", log_file)
    # Fill in the requested number of tests, then append the few-shot example
    # and the problem statement.
    user_prompt = baseline_code_test_prompt.replace("{num_of_test}", str(cfg.gen_test_nums))
    user_prompt += "===Example===\n" + shots
    user_prompt += "===Question===\n" + query
    prompt_list.append(
        {
            "tcb_id": tcb_id,
            "system_prompt": code_system_prompt,
            "user_prompt": user_prompt,
        }
    )

import ray
from packaging.version import Version
from ray.data.llm import build_llm_processor, vLLMEngineProcessorConfig

ds = ray.data.from_items(prompt_list)
size = ds.count()
print(f"Size of dataset: {size} prompts")

# Configure vLLM engine.
config = vLLMEngineProcessorConfig(
    model_source="/mnt/jfs/ckpt/checkpoints/Qwen2.5-32B-Instruct",
    engine_kwargs={
        "enable_chunked_prefill": True,
        "max_num_batched_tokens": 4096,
        "max_model_len": 16384,
        "tensor_parallel_size": 8,
    },
    concurrency=1,  # set the number of parallel vLLM replicas
    batch_size=32,
)

# Create a Processor object, which will be used to
# do batch inference on the dataset
vllm_processor = build_llm_processor(
    config,
    preprocess=lambda row: dict(
        messages=[
            {"role": "system", "content": row['system_prompt']},
            {"role": "user", "content": row["user_prompt"]},
        ],
        sampling_params=dict(
            temperature=0.2,
            max_tokens=8192,
        ),
    ),
    postprocess=lambda row: dict(
        answer=row["generated_text"],
        **row,  # This will return all the original columns in the dataset.
    ),
)
ds = vllm_processor(ds)

# Ray Data executes lazily: without materializing, `take` and `write_json`
# below would each trigger their own execution, re-running vLLM inference
# (and engine start-up) for the rows already produced for the peek.
ds = ds.materialize()

# Peek at the first result as a quick sanity check.
outputs = ds.take(limit=1)
for output in outputs:
    print(output['generated_text'])

# Write every row (original columns plus the generated text) as JSON files
# under the output path; Ray may split the output across several files.
ds.write_json(args.output_path)

Xet Storage Details

Size:
4.26 kB
·
Xet hash:
3d07d1b76d2b13872110ddc569a9f6c323bb067c60aa8ac7062d050718162921

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.