Upload folder using huggingface_hub
Browse files- .DS_Store +0 -0
- .gitattributes +2 -0
- hard_divide.py +64 -0
- simple_cell_acc.py +104 -0
- sudoku_cal_hardness3.py +552 -0
- test.csv +3 -0
- train.csv +3 -0
.DS_Store
CHANGED
|
Binary files a/.DS_Store and b/.DS_Store differ
|
|
|
.gitattributes
CHANGED
|
@@ -66,3 +66,5 @@ cd5_train.jsonl filter=lfs diff=lfs merge=lfs -text
|
|
| 66 |
hard_train.csv filter=lfs diff=lfs merge=lfs -text
|
| 67 |
path_train.jsonl filter=lfs diff=lfs merge=lfs -text
|
| 68 |
sudoku_train.csv filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 66 |
hard_train.csv filter=lfs diff=lfs merge=lfs -text
|
| 67 |
path_train.jsonl filter=lfs diff=lfs merge=lfs -text
|
| 68 |
sudoku_train.csv filter=lfs diff=lfs merge=lfs -text
|
| 69 |
+
test.csv filter=lfs diff=lfs merge=lfs -text
|
| 70 |
+
train.csv filter=lfs diff=lfs merge=lfs -text
|
hard_divide.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
|
def process_csv(input_file, output_file, sample_size=50000):
    """Process a sudoku CSV file.

    Steps:
      1. Read the data and replace every '.' in the ``question`` column
         with '0' (dot notation -> zero notation for empty cells).
      2. Rename ``question`` -> ``quizzes`` and ``answer`` -> ``solutions``.
      3. Keep only the first ``sample_size`` records.
      4. Save the processed data to ``output_file``.

    Parameters:
        input_file (str): input CSV path; must contain the columns
            ``source``, ``question``, ``answer`` and ``rating``.
        output_file (str): output CSV path.
        sample_size (int): number of records to keep (default 50,000).

    Raises:
        ValueError: if a required column is missing.
    """
    df = pd.read_csv(input_file)

    # Validate the expected schema before touching the data.
    required_columns = ['source', 'question', 'answer', 'rating']
    for col in required_columns:
        if col not in df.columns:
            raise ValueError(f"CSV文件中缺少必需的列: {col}")

    # Replace '.' with '0' in the puzzle strings. regex=False is required:
    # '.' must be treated literally, not as the regex wildcard (the default
    # differs across pandas versions and would otherwise replace everything).
    df['question'] = df['question'].str.replace('.', '0', regex=False)

    df = df.rename(columns={
        'question': 'quizzes',   # the question column becomes quizzes
        'answer': 'solutions'
    })

    # Truncate to the first sample_size records. Capture the original length
    # first: the old code printed len(df) *after* head(), so the message
    # always reported sample_size as the original count.
    total = len(df)
    if total > sample_size:
        df = df.head(sample_size)
        print(f"已从{total}条记录中保留前{sample_size}条")
    else:
        print(f"警告:文件只有{total}条记录,不足{sample_size}条,将保留全部记录")

    # Keep only the columns the downstream pipeline needs.
    df = df[['quizzes', 'solutions', 'rating']]

    df.to_csv(output_file, index=False)
    print(f"处理完成,结果已保存到: {output_file}")
    print(f"最终记录数: {len(df)}")
    print("\n替换效果验证(前3个示例):")
    print(df[['quizzes']].head(3).to_string(index=False))
if __name__ == "__main__":
    # Input/output paths for the test split (adjust to your own files).
    test_in = "data/test.csv"
    test_out = "data/hard_test.csv"
    process_csv(test_in, test_out, sample_size=5000)

    # Train split: paths kept for reference, processing currently disabled.
    train_in = "data/train.csv"
    train_out = "data/hard_train.csv"
    # process_csv(train_in, train_out, sample_size=100000)
simple_cell_acc.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import matplotlib.pyplot as plt
|
| 3 |
+
import os
|
| 4 |
+
import numpy as np
|
| 5 |
+
from collections import defaultdict
|
| 6 |
+
|
def calculate_accuracy(label, predict):
    """Return the per-character accuracy of ``predict`` against ``label``.

    ``[PAD]`` tokens are stripped and ``[EOS]`` tokens are normalised to the
    character '0' in both strings before comparison; matching positions are
    then counted (``zip`` stops at the shorter string).
    """
    label = label.replace("[PAD]", "").replace("[EOS]", "0")
    predict = predict.replace("[PAD]", "").replace("[EOS]", "0")
    total_chars = len(label)
    if total_chars == 0:
        # Guard against ZeroDivisionError on an empty / all-padding label.
        return 0.0
    correct_chars = sum(1 for l, p in zip(label, predict) if l == p)
    return correct_chars / total_chars


def evaluate_jsonl(file_path):
    """Average cell accuracy over a JSONL file with 'label'/'predict' fields.

    Lines whose label and predict lengths differ are reported and skipped.
    Returns 0.0 (a scalar) when no valid sample is found — the old code
    returned a ``(0, [])`` tuple on that path while every other path
    returned a scalar, which broke callers storing the result as a number.
    """
    total_accuracy = 0
    sample_count = 0

    with open(file_path, 'r') as f:
        for line in f:
            data = json.loads(line)
            label = data['label']
            predict = data['predict']

            if len(label) != len(predict):
                print(f"警告: 第{sample_count+1}行长度不一致 (label:{len(label)} vs predict:{len(predict)})")
                continue

            total_accuracy += calculate_accuracy(label, predict)
            sample_count += 1

    if sample_count == 0:
        return 0.0

    return total_accuracy / sample_count
def load_experiment_data(exp_path, dataset_name, T_list, acc_key="predict_acc"):
    """Collect accuracy values for one experiment across several T settings.

    For each ``t`` in ``T_list`` the file
    ``{exp_path}/{dataset_name}_T{t}/all_results.json`` is read. If
    ``acc_key`` is missing there, the accuracy is recomputed from
    ``generated_predictions.jsonl`` and written back into the JSON file.
    Percent strings such as ``"87%"`` are converted to fractions.

    Returns:
        (results, missing_files): ``results`` maps t -> accuracy,
        ``missing_files`` lists the all_results.json paths that were absent.
    """
    missing_files = []
    results = {}

    for t in T_list:
        file_path = os.path.join(exp_path, f"{dataset_name}_T{t}")
        filename = os.path.join(file_path, "all_results.json")
        if not os.path.exists(filename):
            missing_files.append(filename)
            continue
        try:
            with open(filename, 'r') as f:
                data = json.load(f)
            if acc_key not in data:
                # Lazily backfill the accuracy and persist it for next time.
                # (The message previously printed the literal "(unknown)"
                # instead of the actual path.)
                print(f"警告: {filename} 中未找到键 '{acc_key}'")
                src_filename = os.path.join(file_path, "generated_predictions.jsonl")
                data[acc_key] = evaluate_jsonl(src_filename)
                with open(filename, 'w', encoding='utf-8') as f:
                    json.dump(data, f, ensure_ascii=False, indent=4)

            acc = data[acc_key]
            if isinstance(acc, str) and acc.endswith('%'):
                acc = float(acc.strip('%')) / 100.0
            results[t] = acc
        except Exception as e:
            print(f"处理文件 {filename} 时出错: {str(e)}")

    return results, missing_files
# Previous run configuration, kept for reference:
#   experiments = ["output/sudoku/gpt2-model-bs1024-lr1e-3-ep100-20250702-123750",
#                  "output/sudoku/gpt2-model-bs1024-lr1e-3-ep100-20250703-021344"]
#   dataset_names = ["hard_test", "sudoku_test"]

experiments = [
    "output/sudoku/gpt2-model-bs1024-lr1e-3-ep100-20250703-073900",
    "output/sudoku/gpt2-model-bs1024-lr1e-3-ep100-20250703-075910",
    "output/sudoku/gpt2-model-bs1024-lr1e-3-ep300-20250618-082232",
]
dataset_names = ["sudoku_test", "sudoku_test", "sudoku_test"]


acc_key = "cell_acc"

# Backfill the cell accuracy into every experiment/dataset result file
# that does not already contain it.
for experiment in experiments:
    for dataset_name in dataset_names:
        result_dir = os.path.join(experiment, dataset_name)
        result_json = os.path.join(result_dir, "all_results.json")
        with open(result_json, 'r') as f:
            data = json.load(f)
        if acc_key in data:
            continue
        print(f"警告: (unknown) 中未找到键 '{acc_key}'")
        predictions_path = os.path.join(result_dir, "generated_predictions.jsonl")
        data[acc_key] = evaluate_jsonl(predictions_path)
        with open(result_json, 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=4)
sudoku_cal_hardness3.py
ADDED
|
@@ -0,0 +1,552 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
"""
|
| 3 |
+
Sudoku batch solver + difficulty analytics
|
| 4 |
+
- Bitmask + MRV DFS
|
| 5 |
+
- Multiprocessing
|
| 6 |
+
- Caching (.npz)
|
| 7 |
+
- Plots (hist, CDF, log-scale, quintiles)
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import os
|
| 11 |
+
import re
|
| 12 |
+
import csv
|
| 13 |
+
import time
|
| 14 |
+
import numpy as np
|
| 15 |
+
import matplotlib
|
| 16 |
+
matplotlib.use('Agg') # 非交互模式(服务器/脚本环境)
|
| 17 |
+
import matplotlib.pyplot as plt
|
| 18 |
+
from matplotlib.ticker import MaxNLocator
|
| 19 |
+
from tqdm import tqdm
|
| 20 |
+
from multiprocessing import Pool
|
| 21 |
+
import sys
|
| 22 |
+
|
# Deep recursion guard: very hard puzzles can recurse past the default limit
# (tune as needed).
sys.setrecursionlimit(10000)

# ==================== Global lookup tables & constants ====================
# Number of set bits for every 9-bit candidate mask (0..511).
POPCOUNT = [bin(v).count("1") for v in range(512)]
# Precomputed 3x3-box index for every (row, col) cell.
CELL_TO_BOX = [[(r // 3) * 3 + c // 3 for c in range(9)] for r in range(9)]
RANDOM_SEED = 2025


# ==================== Bitmask Sudoku Solver ====================
def initialize_masks(board):
    """Build the row / column / box bitmasks for a 9x9 board."""
    row_mask = [0] * 9
    col_mask = [0] * 9
    box_mask = [0] * 9
    for r in range(9):
        for c in range(9):
            val = board[r][c]
            if val:
                bit = 1 << (val - 1)
                row_mask[r] |= bit
                col_mask[c] |= bit
                box_mask[CELL_TO_BOX[r][c]] |= bit
    return row_mask, col_mask, box_mask


def select_mrv_bitmask(board, row_mask, col_mask, box_mask):
    """Pick the empty cell with the fewest candidates (MRV heuristic).

    Returns ``((r, c), candidate_mask)``; the cell is ``None`` when the
    board has no empty cell left.
    """
    best_count = 10
    best_cell = None
    best_mask = None
    for r in range(9):
        for c in range(9):
            if board[r][c] != 0:
                continue
            used = row_mask[r] | col_mask[c] | box_mask[CELL_TO_BOX[r][c]]
            candidates = (~used) & 0x1FF
            n = POPCOUNT[candidates]
            if n < best_count:
                best_count = n
                best_cell = (r, c)
                best_mask = candidates
                if n == 1:
                    # A forced cell cannot be beaten; stop scanning early.
                    return best_cell, best_mask
    return best_cell, best_mask  # cell is None when the board is full


def get_initial_min_remaining(board, row_mask, col_mask, box_mask):
    """Smallest candidate count over all empty cells (MRV baseline; 0 if full)."""
    best = 10
    for r in range(9):
        for c in range(9):
            if board[r][c] != 0:
                continue
            used = row_mask[r] | col_mask[c] | box_mask[CELL_TO_BOX[r][c]]
            n = POPCOUNT[(~used) & 0x1FF]
            if n < best:
                best = n
                if n == 1:
                    return 1
    return 0 if best == 10 else best


def solve_sudoku_bitmask(board, row_mask, col_mask, box_mask, steps=None):
    """DFS + MRV + bitmask solver; mutates ``board`` in place.

    ``steps`` is an optional one-element list used as a mutable counter of
    placement attempts. Returns True iff a solution was found.
    """
    cell, candidates = select_mrv_bitmask(board, row_mask, col_mask, box_mask)
    if cell is None:
        return True   # no empty cell left: solved
    if candidates == 0:
        return False  # dead end: an empty cell has no candidate

    r, c = cell
    box = CELL_TO_BOX[r][c]
    while candidates:
        bit = candidates & -candidates   # lowest set bit
        candidates ^= bit

        board[r][c] = bit.bit_length()   # bit k-1 encodes digit k
        row_mask[r] |= bit
        col_mask[c] |= bit
        box_mask[box] |= bit

        if steps is not None:
            steps[0] += 1

        if solve_sudoku_bitmask(board, row_mask, col_mask, box_mask, steps):
            return True

        # Undo the tentative placement before trying the next candidate.
        board[r][c] = 0
        row_mask[r] ^= bit
        col_mask[c] ^= bit
        box_mask[box] ^= bit
    return False


def evaluate_sudoku(board):
    """Return ``(empty_count, steps_used, initial_mrv_min)`` for a puzzle.

    The input board is not modified (a copy is solved). Even when the puzzle
    is unsolvable the accumulated step count is returned, which is useful
    for statistics.
    """
    empty_count = int(np.sum(np.array(board) == 0))
    work = [row[:] for row in board]
    row_mask, col_mask, box_mask = initialize_masks(work)

    initial_min_remaining = get_initial_min_remaining(work, row_mask, col_mask, box_mask)

    steps = [0]
    solve_sudoku_bitmask(work, row_mask, col_mask, box_mask, steps)
    return empty_count, int(steps[0]), int(initial_min_remaining)
| 135 |
+
|
| 136 |
+
|
# ==================== Parallel processing ====================
def process_single_sudoku_optimized(args):
    """Worker entry point: evaluate one puzzle; ``args`` is ``(board, index)``."""
    board, idx = args
    empty_count, steps_count, initial_min_remaining = evaluate_sudoku(board)
    return idx, empty_count, steps_count, initial_min_remaining


def parallel_solve_optimized(boards, n_workers=4):
    """Evaluate all boards in a process pool.

    Returns three lists (empty counts, step counts, initial MRV minima)
    ordered the same way as ``boards``.
    """
    n = len(boards)
    empty_counts = [0] * n
    steps_counts = [0] * n
    initial_min_remainings = [0] * n
    tasks = [(board, i) for i, board in enumerate(boards)]
    with Pool(processes=n_workers) as pool:
        with tqdm(total=n, desc="Solving Sudoku", unit="puzzle") as pbar:
            for idx, empties, n_steps, mrv_min in pool.imap(
                    process_single_sudoku_optimized, tasks, chunksize=100):
                empty_counts[idx] = empties
                steps_counts[idx] = n_steps
                initial_min_remainings[idx] = mrv_min
                pbar.update(1)
    return empty_counts, steps_counts, initial_min_remainings


# ==================== Benchmarking ====================
def benchmark_solver(boards, sample_size=100):
    """Time the parallel solver on the first ``sample_size`` boards."""
    print(f"Benchmarking solver with {sample_size} puzzles...")
    start_time = time.time()
    empty_counts, steps_counts, initial_min_remainings = parallel_solve_optimized(
        boards[:sample_size], n_workers=4)
    total_time = time.time() - start_time
    print(f"Processed {sample_size} puzzles in {total_time:.2f} seconds")
    print(f"Average per puzzle: {total_time / sample_size * 1000:.2f} ms")
    print(f"Puzzles per second: {sample_size / total_time:.1f}")
    return empty_counts, steps_counts, initial_min_remainings
| 171 |
+
|
| 172 |
+
|
# ==================== Plotting helpers ====================
def auto_bins(data):
    """Histogram bin count via the Freedman-Diaconis rule, clamped to [10, 200].

    Falls back to 30 bins for empty input or a zero inter-quartile range.
    """
    arr = np.asarray(data)
    arr = arr[np.isfinite(arr)]
    if arr.size == 0:
        return 30
    q75, q25 = np.percentile(arr, [75, 25])
    iqr = q75 - q25
    if iqr <= 0:
        return 30
    width = 2 * iqr * (arr.size ** (-1 / 3))
    n_bins = int((arr.max() - arr.min()) / width)
    return max(10, min(200, n_bins))
| 187 |
+
|
def _finalize_figure(fig, save_path):
    """Apply tight layout, optionally save to ``save_path``, and close the figure.

    Extracted helper: this tail was copy-pasted identically in all five plot
    functions. Also guards against ``os.makedirs('')`` when ``save_path`` is
    a bare filename with no directory component.
    """
    plt.tight_layout()
    if save_path:
        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Plot saved to: {save_path}")
    plt.close(fig)


def plot_cdf(data, title, xlabel, save_path=None):
    """Plot the empirical CDF of ``data``; save to ``save_path`` when given."""
    data_sorted = np.sort(data)
    if len(data_sorted) == 0:
        print(f"[WARN] plot_cdf: empty data for {title}")
        return
    cdf = np.arange(1, len(data_sorted) + 1) / len(data_sorted)

    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(data_sorted, cdf, linewidth=1)  # default colour cycle
    ax.set_title(title, fontsize=12)
    ax.set_xlabel(xlabel, fontsize=10)
    ax.set_ylabel('CDF', fontsize=10)
    ax.grid(True, linestyle='--', alpha=0.3)
    _finalize_figure(fig, save_path)


def plot_histogram_auto(data, title, xlabel, save_path=None):
    """Histogram with automatic bin count and integer-preferring x ticks."""
    if len(data) == 0:
        print(f"[WARN] plot_histogram_auto: empty data for {title}")
        return
    fig, ax = plt.subplots(figsize=(8, 5))
    ax.hist(data, bins=auto_bins(data), alpha=0.7, edgecolor='white', linewidth=0.5)
    ax.xaxis.set_major_locator(MaxNLocator(nbins=15, integer=True))
    ax.set_title(title, fontsize=12)
    ax.set_xlabel(xlabel, fontsize=10)
    ax.set_ylabel('Frequency', fontsize=10)
    ax.grid(True, linestyle='--', alpha=0.3)
    _finalize_figure(fig, save_path)


def plot_histogram_log(data, title, xlabel, save_path=None):
    """Histogram with a logarithmic x axis."""
    if len(data) == 0:
        print(f"[WARN] plot_histogram_log: empty data for {title}")
        return
    fig, ax = plt.subplots(figsize=(8, 5))
    ax.hist(data, bins=auto_bins(data), alpha=0.7, edgecolor='white', linewidth=0.5)
    ax.set_xscale("log")
    ax.set_title(title, fontsize=12)
    ax.set_xlabel(xlabel + " (log scale)", fontsize=10)
    ax.set_ylabel('Frequency', fontsize=10)
    ax.grid(True, linestyle='--', alpha=0.3)
    _finalize_figure(fig, save_path)


def plot_histogram_with_vlines_log(data, vlines, title, xlabel, save_path=None):
    """Log-x histogram with dashed vertical lines marking thresholds."""
    if len(data) == 0:
        print(f"[WARN] plot_histogram_with_vlines_log: empty data for {title}")
        return
    fig, ax = plt.subplots(figsize=(8, 5))
    ax.hist(data, bins=auto_bins(data), alpha=0.7, edgecolor='white', linewidth=0.5)
    ax.set_xscale("log")
    for v in vlines:
        ax.axvline(v, linestyle='--', linewidth=1)
    ax.set_title(title, fontsize=12)
    ax.set_xlabel(xlabel + " (log scale)", fontsize=10)
    ax.set_ylabel('Frequency', fontsize=10)
    ax.grid(True, linestyle='--', alpha=0.3)
    _finalize_figure(fig, save_path)


def plot_cdf_multiple(datasets, labels, title, xlabel, save_path=None):
    """Overlay the empirical CDFs of several datasets; empty sets are skipped."""
    fig, ax = plt.subplots(figsize=(8, 5))
    for data, lab in zip(datasets, labels):
        arr = np.asarray(data)
        if arr.size == 0:
            continue
        arr_sorted = np.sort(arr)
        cdf = np.arange(1, arr_sorted.size + 1) / arr_sorted.size
        ax.plot(arr_sorted, cdf, linewidth=1, label=lab)
    ax.set_title(title, fontsize=12)
    ax.set_xlabel(xlabel, fontsize=10)
    ax.set_ylabel('CDF', fontsize=10)
    ax.grid(True, linestyle='--', alpha=0.3)
    ax.legend()
    _finalize_figure(fig, save_path)
| 307 |
+
|
| 308 |
+
|
# ==================== Data caching ====================
def save_data(empty_counts, steps_counts, initial_min_remainings, data_file):
    """Cache the per-puzzle metrics as int32 arrays in an .npz archive."""
    np.savez(
        data_file,
        empty_counts=np.asarray(empty_counts, dtype=np.int32),
        steps_counts=np.asarray(steps_counts, dtype=np.int32),
        initial_min_remainings=np.asarray(initial_min_remainings, dtype=np.int32),
    )
    print(f"Data saved to: {data_file}")


def load_data(data_file):
    """Load cached metrics; returns ``(None, None, None)`` when the file is absent.

    ``initial_min_remainings`` may come back as None for caches written by
    older code that did not record it. The membership test uses
    ``data.files`` instead of ``data.get`` because ``NpzFile`` is not a
    Mapping (and has no ``get``) on older NumPy versions.
    """
    if not os.path.exists(data_file):
        return None, None, None
    data = np.load(data_file)
    print(f"Data loaded from: {data_file}")
    if 'initial_min_remainings' in data.files:
        initial_min_remainings = data['initial_min_remainings']
    else:
        initial_min_remainings = None
    return data['empty_counts'], data['steps_counts'], initial_min_remainings
| 328 |
+
|
| 329 |
+
|
# ==================== CSV Loader ====================
def load_sudoku_csv(file_path):
    """Read puzzles from a CSV whose first column is an 81-char board string
    ('.' or '0' marks an empty cell). Returns a list of 9x9 int boards."""
    boards = []
    with open(file_path, newline='') as f:
        reader = csv.reader(f)
        next(reader, None)  # skip a header row if present
        for row in tqdm(reader, desc="Loading CSV"):
            digits = re.sub(r'[^0-9.]', '', row[0])
            cells = [0 if ch == '.' else int(ch) for ch in digits]
            boards.append(np.array(cells).reshape(9, 9).tolist())
    return boards
| 342 |
+
|
| 343 |
+
|
def load_sudoku_csv_strings(file_path):
    """Load the raw 81-char puzzle strings, mapping '.' to '0', for export."""
    puzzles = []
    with open(file_path, newline='') as f:
        reader = csv.reader(f)
        next(reader, None)  # skip a header row if present
        for row in reader:
            cleaned = re.sub(r'[^0-9.]', '', row[0])
            puzzles.append(cleaned.replace('.', '0'))
    return puzzles
| 355 |
+
|
| 356 |
+
|
# ==================== Main ====================
if __name__ == "__main__":
    INPUT_CSV_FILES = ['data/hard_test.csv']  # replace with your file list
    N_WORKERS = 4
    SAMPLE_SIZE = 5000   # max records per CSV (for quick trial runs)
    BENCHMARK = True     # True: also benchmark the solver
    for csv_file in INPUT_CSV_FILES:
        print(f"\nProcessing file: {csv_file}")
        file_prefix = os.path.splitext(os.path.basename(csv_file))[0]
        data_file = f"{file_prefix}_data.npz"

        empty_counts, steps_counts, initial_min_remainings = load_data(data_file)
        # Cache check deliberately disabled (always recompute); restore the
        # commented condition to reuse cached results.
        # if empty_counts is None or steps_counts is None or initial_min_remainings is None:
        if True:
            print("Data not found, processing...")
            quizzes = load_sudoku_csv(csv_file)
            print(f"Loaded {len(quizzes)} sudoku puzzles")
            if BENCHMARK:
                empty_counts, steps_counts, initial_min_remainings = benchmark_solver(quizzes, SAMPLE_SIZE)
            else:
                empty_counts, steps_counts, initial_min_remainings = parallel_solve_optimized(
                    quizzes[:SAMPLE_SIZE], n_workers=N_WORKERS)
            save_data(empty_counts, steps_counts, initial_min_remainings, data_file)
        else:
            print("Using cached data")

        # ============ Output directory for figures ============
        out_dir = os.path.join('figures', file_prefix)
        os.makedirs(out_dir, exist_ok=True)

        # ============ Full-distribution plots ============
        plot_histogram_auto(empty_counts,
                            f'{file_prefix} - Empty Count Distribution',
                            'Empty Count',
                            os.path.join(out_dir, f'{file_prefix}_empty_count_hist.png'))

        plot_histogram_auto(steps_counts,
                            f'{file_prefix} - Steps Count Distribution',
                            'Steps Count',
                            os.path.join(out_dir, f'{file_prefix}_steps_count_hist.png'))

        plot_histogram_log(steps_counts,
                           f'{file_prefix} - Steps Count Distribution (Log X)',
                           'Steps Count',
                           os.path.join(out_dir, f'{file_prefix}_steps_count_hist_log.png'))

        plot_cdf(steps_counts,
                 f'{file_prefix} - Steps Count CDF',
                 'Steps Count',
                 os.path.join(out_dir, f'{file_prefix}_steps_count_cdf.png'))

        plot_cdf(empty_counts,
                 f'{file_prefix} - Empty Count CDF',
                 'Empty Count',
                 os.path.join(out_dir, f'{file_prefix}_empty_count_cdf.png'))

        # ============ Quintile binning + sampled exports ============
        TOTAL_N = 5000
        SAMPLE_PER_BIN = 1000
        CONSIDER_N = int(min(TOTAL_N, len(steps_counts)))
        # (removed: unused ``consider_idx = np.arange(CONSIDER_N)``)

        if CONSIDER_N > 0:
            steps_consider = np.array(steps_counts)[:CONSIDER_N]
            empty_consider = np.array(empty_counts)[:CONSIDER_N]

            # Rank-based split instead of the earlier value-based binning
            # (np.digitize over quantile thresholds), which could yield
            # unequal bins when threshold values repeat. Each of the first
            # four bins gets exactly SAMPLE_PER_BIN entries when feasible.
            sorted_idx = np.argsort(steps_consider, kind='stable')
            bin_size = SAMPLE_PER_BIN
            b1 = sorted_idx[0:bin_size]
            b2 = sorted_idx[bin_size:bin_size*2]
            b3 = sorted_idx[bin_size*2:bin_size*3]
            b4 = sorted_idx[bin_size*3:bin_size*4]
            b5 = sorted_idx[bin_size*4:CONSIDER_N]

            # Histogram with rank-derived quintile thresholds (log x).
            def safe_idx(pos):
                """Clamp a rank position into the valid index range."""
                return min(max(pos, 0), CONSIDER_N - 1)
            v20 = steps_consider[sorted_idx[safe_idx(bin_size - 1)]]
            v40 = steps_consider[sorted_idx[safe_idx(bin_size * 2 - 1)]]
            v60 = steps_consider[sorted_idx[safe_idx(bin_size * 3 - 1)]]
            v80 = steps_consider[sorted_idx[safe_idx(bin_size * 4 - 1)]]
            plot_histogram_with_vlines_log(
                steps_consider,
                [v20, v40, v60, v80],
                f'{file_prefix} - Steps Histogram with Quintile Thresholds',
                'Steps Count',
                save_path=os.path.join(out_dir, f'{file_prefix}_steps_hist_with_quintiles.png')
            )

            # Per-bin data & overlaid CDFs (bins come from the rank split).
            bin_data = [
                steps_consider[b1],
                steps_consider[b2],
                steps_consider[b3],
                steps_consider[b4],
                steps_consider[b5],
            ]
            plot_cdf_multiple(
                bin_data,
                [f'Bin{k}' for k in [1, 2, 3, 4, 5]],
                f'{file_prefix} - Steps CDF by Quintile Bins',
                'Steps Count',
                save_path=os.path.join(out_dir, f'{file_prefix}_steps_cdf_quintiles.png')
            )

            # Raw puzzle strings for the CSV exports below.
            puzzles_raw = load_sudoku_csv_strings(csv_file)

            sampled_by_bin = {}
            for target_bin in [1, 3, 5]:
                # Index set of the requested rank-based bin.
                if target_bin == 1:
                    idx_in_bin = b1
                elif target_bin == 3:
                    idx_in_bin = b3
                elif target_bin == 5:
                    idx_in_bin = b5
                else:
                    idx_in_bin = np.array([], dtype=int)

                if len(idx_in_bin) == 0:
                    print(f"Bin {target_bin} has no samples.")
                    sampled_by_bin[target_bin] = np.array([], dtype=int)
                    continue

                # Take up to SAMPLE_PER_BIN rows (all of them if fewer).
                take_n = min(len(idx_in_bin), SAMPLE_PER_BIN)
                sampled_idx = idx_in_bin[:take_n]

                sampled_by_bin[target_bin] = np.array(sampled_idx, dtype=int)

                out_csv = f"{file_prefix}_bin{target_bin}_sample{SAMPLE_PER_BIN}.csv"
                with open(out_csv, 'w', newline='') as wf:
                    writer = csv.writer(wf)
                    writer.writerow(['puzzle_index', 'steps', 'empty_count', 'puzzle'])
                    for i in sampled_idx:
                        writer.writerow([int(i), int(steps_consider[i]), int(empty_consider[i]), puzzles_raw[i]])
                print(f"Saved bin {target_bin} sample to: {out_csv} (count={len(sampled_idx)})")

                # Histogram and CDF for this bin's sample.
                d_steps = steps_consider[sampled_by_bin[target_bin]]
                plot_histogram_auto(
                    d_steps,
                    f'{file_prefix} - Bin{target_bin} Sample Steps Hist',
                    'Steps Count',
                    os.path.join(out_dir, f'{file_prefix}_bin{target_bin}_sample_steps_hist.png')
                )
                plot_cdf(
                    d_steps,
                    f'{file_prefix} - Bin{target_bin} Sample Steps CDF',
                    'Steps Count',
                    os.path.join(out_dir, f'{file_prefix}_bin{target_bin}_sample_steps_cdf.png')
                )

            # Export a mixed sample combining the sampled bins.
            def save_mix(mix_bins, out_name):
                """Concatenate the sampled indices of ``mix_bins`` into one CSV."""
                mix_idx = np.concatenate([sampled_by_bin.get(b, np.array([], dtype=int)) for b in mix_bins])
                if mix_idx.size == 0:
                    print(f"Mix {out_name} has no samples; skipped.")
                    return
                out_csv = f"{file_prefix}_{out_name}.csv"
                with open(out_csv, 'w', newline='') as wf:
                    writer = csv.writer(wf)
                    writer.writerow(['puzzle_index', 'steps', 'empty_count', 'puzzle'])
                    for i in mix_idx:
                        writer.writerow([int(i), int(steps_consider[i]), int(empty_consider[i]), puzzles_raw[i]])
                print(f"Saved mix {out_name} to: {out_csv} (count={mix_idx.size})")

            # Mix 1+3+5: export plus histogram and CDF. (Plain string: the
            # old code used an f-string with no placeholders.)
            mix_name = "mix_bin1_3_5_sample2000"
            save_mix([1, 3, 5], mix_name)
            mix_idx_all = np.concatenate([sampled_by_bin.get(b, np.array([], dtype=int)) for b in [1, 3, 5]])
            if mix_idx_all.size > 0:
                d_mix = steps_consider[mix_idx_all]
                plot_histogram_auto(
                    d_mix,
                    f'{file_prefix} - Mix(1+3+5) Sample Steps Hist',
                    'Steps Count',
                    os.path.join(out_dir, f'{file_prefix}_{mix_name}_steps_hist.png')
                )
                plot_cdf(
                    d_mix,
                    f'{file_prefix} - Mix(1+3+5) Sample Steps CDF',
                    'Steps Count',
                    os.path.join(out_dir, f'{file_prefix}_{mix_name}_steps_cdf.png')
                )

    print("\nDone.")
test.csv
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a2fd52aea23d331d5b4ee723c856236e838a9fb9a70e66f4e0e0cf26c338c6a8
|
| 3 |
+
size 79360390
|
train.csv
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:64b46674db0148e0d73a16346dadeb2b1c00824d3fca3f85b2ae7037f6b4b38e
|
| 3 |
+
size 718819925
|