Tsukihjy's picture
download
raw
4.82 kB
import concurrent.futures
import json
import multiprocessing
import os
import threading
import time

import psutil
def get_response_function(repsonse_path, model_name, test_al):
    """Load model responses from a JSONL file and extract the test code in each.

    Args:
        repsonse_path: Path template with two ``{}`` placeholders, filled as
            ``repsonse_path.format(test_al, model_name)``. (The misspelled
            name is kept because callers pass it by keyword.)
        model_name: Model identifier, substituted into the second placeholder.
        test_al: Test-case algorithm identifier, substituted into the first
            placeholder.

    Returns:
        A list of code strings in file order. For each record the code is
        the last fenced ```python block if the response contains that marker,
        else the last <ASSISTANT>...</ASSISTANT> span if that marker is
        present, else the whole response. Code that does not compile is
        passed through ``replace_newline_in_fstring`` before being kept.
        Records without a string ``"response"`` value, or whose marker has no
        complete block, are skipped.
    """
    test_functions = []
    for response_item in read_jsonl(repsonse_path.format(test_al, model_name)):
        response = response_item.get("response") if isinstance(response_item, dict) else None
        if not isinstance(response, str):
            continue  # malformed record: nothing to extract
        try:
            if "```python" in response:
                code_string = extract_code(response)
            elif '<ASSISTANT>' in response:
                code_string = extract_content_code(response)
            else:
                code_string = response
        except IndexError:
            # The marker is present but there is no complete block to take.
            continue
        try:
            # Syntax check only: compiling never runs the generated code
            # (exec() here would have executed it with this module's globals).
            compile(code_string, "<response>", "exec")
        except (SyntaxError, ValueError):
            # Typical failure: a raw newline inside an f"..." literal.
            # (ValueError covers null bytes on older Python versions.)
            code_string = replace_newline_in_fstring(code_string)
        test_functions.append(code_string)
    return test_functions
def read_jsonl(file_path):
    """Parse a JSON Lines file into a list of Python objects.

    Each non-blank line must hold exactly one JSON value. Blank or
    whitespace-only lines (e.g. an extra empty line at the end of the file)
    are skipped instead of raising ``json.JSONDecodeError``.

    Args:
        file_path: Path of the ``.jsonl`` file; it is decoded as UTF-8.

    Returns:
        The decoded records, in file order.
    """
    with open(file_path, 'r', encoding="utf-8") as file:
        return [json.loads(line) for line in file if line.strip()]
import re
def extract_code(ans_str):
    """Return the body of the last ```python fenced block in *ans_str*.

    The fence must be written exactly as ```python followed by a newline.
    Raises IndexError when no complete fenced block is found.
    """
    fence_re = re.compile(r'```python\n(.*?)```', re.DOTALL)
    return fence_re.findall(ans_str)[-1]
def extract_content_code(ans_str):
    """Return the text inside the last <ASSISTANT>...</ASSISTANT> pair.

    Matching is non-greedy and spans newlines. Raises IndexError when no
    complete tag pair is present.
    """
    segments = [
        hit.group(1)
        for hit in re.finditer(r'<ASSISTANT>(.*?)</ASSISTANT>', ans_str, re.DOTALL)
    ]
    return segments[-1]
import re
def replace_newline_in_fstring(code: str) -> str:
    """Escape raw newlines inside double-quoted f-string literals in *code*.

    Every span matching ``f"`` + any run of non-``"`` characters (newlines
    included) + ``"`` has each literal newline character replaced by the
    two-character escape backslash-n. This turns generated code such as an
    ``f"a`` / ``b"`` literal split across two lines (a SyntaxError) into
    ``f"a\\nb"``.

    Limitations: single-quoted and triple-quoted f-strings, and literals
    containing escaped double quotes, are not handled.

    Args:
        code: Python source text.

    Returns:
        The rewritten source text.
    """
    def replace_in_fstring(match):
        # group(1) is the literal's body between the opening f" and closing ".
        return 'f"' + match.group(1).replace("\n", "\\n") + '"'

    return re.sub(r'f"([^"]*)"', replace_in_fstring, code)
# Global counter: number of processes killed by monitor_memory for
# exceeding the memory limit.
memory_explosion_count = 0
# RSS limit used by monitor_memory: 1 GiB, expressed in bytes.
MEMORY_THRESHOLD = 2 ** 30
# Memory watchdog for a single process.
def monitor_memory(pid):
    """Kill process *pid* once its resident memory exceeds ``MEMORY_THRESHOLD``.

    Polls the process's RSS about once a second. On the first reading above
    the threshold it prints a message, kills the process and increments the
    global ``memory_explosion_count``. Returns after the kill, or as soon as
    the process turns out to no longer exist, so a watchdog thread ends by
    itself after its process exits normally.

    Args:
        pid: ID of the process to watch.
    """
    global memory_explosion_count
    try:
        process = psutil.Process(pid)
        while process.memory_info().rss <= MEMORY_THRESHOLD:
            time.sleep(1)
        print(f"Memory usage exceeded 1GB! Killing process with PID: {pid}")
        # psutil's kill() is portable (SIGKILL on POSIX) and refuses to act if
        # the PID has meanwhile been reused by an unrelated process.
        process.kill()
    except psutil.NoSuchProcess:
        # The process already exited (psutil.ZombieProcess is a subclass):
        # nothing left to watch or kill.
        return
    # NOTE(review): `+=` on a global is not atomic across watchdog threads;
    # guard it with a lock if exact counts matter under heavy parallelism.
    memory_explosion_count += 1
# Execute one generated code string and call its entry function.
def execute_code(codestr, entry_point="aaa"):
    """Run *codestr* in a fresh namespace, then call its *entry_point* function.

    Any exception raised while defining or calling the code is printed
    rather than propagated, so one bad snippet does not stop the caller.

    SECURITY: this executes arbitrary (model-generated) code with the full
    privileges of the current process; only run it in a sandbox.

    Args:
        codestr: Python source expected to define a function named
            *entry_point*.
        entry_point: Name of the zero-argument function to call after the
            code has been executed. Defaults to ``"aaa"``.
    """
    # One dict serves as both globals and locals, so top-level definitions
    # land where their functions later resolve global names: helpers can call
    # each other, and the entry point can be looked up afterwards. (A bare
    # exec() inside a function writes into a throwaway locals mapping, so a
    # direct `aaa()` call afterwards raises NameError.) A non-"__main__"
    # __name__ keeps the snippet's own `if __name__ == "__main__":` section
    # from running.
    namespace = {"__name__": "__generated__"}
    try:
        exec(codestr, namespace)
        namespace[entry_point]()
    except Exception as e:
        print(f"Error executing code: {e}")
# Helper: run one code string in its own child process under a memory watchdog.
def _run_in_subprocess(codestr):
    """Run ``execute_code(codestr)`` in a child process watched by ``monitor_memory``.

    Returns:
        The child's exit code (on POSIX, negative when it was killed by a
        signal).
    """
    proc = multiprocessing.Process(target=execute_code, args=(codestr,))
    proc.start()
    # The watchdog runs in this (parent) process, so its update of the global
    # memory_explosion_count is visible to our caller.
    watchdog = threading.Thread(target=monitor_memory, args=(proc.pid,), daemon=True)
    watchdog.start()
    proc.join()
    # Let the watchdog finish (it polls about once a second) so a kill it just
    # issued is counted before we return; daemon=True ensures it can never
    # keep the interpreter alive if it is slow to notice the exit.
    watchdog.join(timeout=2)
    return proc.exitcode


# Main entry point for parallel execution.
def run_parallel(codestr_list, max_parallel):
    """Execute every code string in its own process, at most *max_parallel* at once.

    Each string runs through ``execute_code`` in a separate child process,
    with a ``monitor_memory`` watchdog attached to that child's PID. A
    runaway snippet is therefore killed without taking this process down.
    Errors raised while launching or waiting on a task are printed, not
    propagated.

    Args:
        codestr_list: Code strings to execute.
        max_parallel: Maximum number of child processes alive at once.
    """
    # Threads share this process's PID and cannot be killed individually, so
    # each snippet gets its own process; the pool only bounds concurrency.
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_parallel) as executor:
        futures = [executor.submit(_run_in_subprocess, codestr) for codestr in codestr_list]
        for future in concurrent.futures.as_completed(futures):
            try:
                future.result()
            except Exception as e:
                print(f"Error in parallel execution: {e}")
if __name__ == "__main__":
# 示例代码字符串列表,假设每个字符串里面有相同的代码
model_name = 'gpt-4o'
testcase_alg = 'lcb'
codestr_list = get_response_function(repsonse_path="/home/luoxianzhen/yang/data/response-orginal/orginal_response_{}_{}.jsonl", model_name=model_name, test_al=testcase_alg)
codestr_list = codestr_list[0:10]
# 设置最大并行数量
max_parallel = 20
# 调用并行执行方法
run_parallel(codestr_list, max_parallel)
print(f"Total memory explosion count: {memory_explosion_count}")

Xet Storage Details

Size:
4.82 kB
·
Xet hash:
e5839dc9de370d298aac59217599f343b6a41f50ea8aadca4713af69375269ba

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.