zeyuzy commited on
Commit
077b816
·
verified ·
1 Parent(s): 53319dd

Upload folder using huggingface_hub

Browse files
Files changed (7) hide show
  1. .DS_Store +0 -0
  2. .gitattributes +2 -0
  3. hard_divide.py +64 -0
  4. simple_cell_acc.py +104 -0
  5. sudoku_cal_hardness3.py +552 -0
  6. test.csv +3 -0
  7. train.csv +3 -0
.DS_Store CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
 
.gitattributes CHANGED
@@ -66,3 +66,5 @@ cd5_train.jsonl filter=lfs diff=lfs merge=lfs -text
66
  hard_train.csv filter=lfs diff=lfs merge=lfs -text
67
  path_train.jsonl filter=lfs diff=lfs merge=lfs -text
68
  sudoku_train.csv filter=lfs diff=lfs merge=lfs -text
 
 
 
66
  hard_train.csv filter=lfs diff=lfs merge=lfs -text
67
  path_train.jsonl filter=lfs diff=lfs merge=lfs -text
68
  sudoku_train.csv filter=lfs diff=lfs merge=lfs -text
69
+ test.csv filter=lfs diff=lfs merge=lfs -text
70
+ train.csv filter=lfs diff=lfs merge=lfs -text
hard_divide.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+
3
def process_csv(input_file, output_file, sample_size=50000):
    """Convert a raw sudoku CSV into the quizzes/solutions format.

    Steps:
      1. Read the data and replace every '.' in the question column with '0'.
      2. Rename columns (question -> quizzes, answer -> solutions).
      3. Keep only the first ``sample_size`` records.
      4. Save the processed data.

    Parameters:
        input_file (str): path of the input CSV file
        output_file (str): path of the output CSV file
        sample_size (int): number of records to keep, default 50,000

    Raises:
        ValueError: if a required column is missing from the input CSV.
    """
    df = pd.read_csv(input_file)

    # Fail fast if the expected schema is not present.
    required_columns = ['source', 'question', 'answer', 'rating']
    for col in required_columns:
        if col not in df.columns:
            raise ValueError(f"CSV文件中缺少必需的列: {col}")

    # Replace '.' placeholders with '0' (sudoku empty cells).
    # regex=False is essential: with the historical pandas default
    # (regex=True) the pattern '.' matches ANY character, turning every
    # puzzle into all zeros; pandas 2.0 flipped the default, so being
    # explicit also makes the behavior version-independent.
    df['question'] = df['question'].str.replace('.', '0', regex=False)

    df = df.rename(columns={
        'question': 'quizzes',  # use the question column as quizzes
        'answer': 'solutions'
    })

    # Keep only the first N records.  Capture the ORIGINAL row count
    # first: the old code printed len(df) AFTER truncation, so the
    # message always claimed the file had exactly sample_size rows.
    original_len = len(df)
    if original_len > sample_size:
        df = df.head(sample_size)
        print(f"已从{original_len}条记录中保留前{sample_size}条")
    else:
        print(f"警告:文件只有{original_len}条记录,不足{sample_size}条,将保留全部记录")

    # Keep only the columns the downstream training code needs.
    df = df[['quizzes', 'solutions', 'rating']]

    df.to_csv(output_file, index=False)
    print(f"处理完成,结果已保存到: {output_file}")
    print(f"最终记录数: {len(df)}")
    print("\n替换效果验证(前3个示例):")
    print(df[['quizzes']].head(3).to_string(index=False))
50
+
51
if __name__ == "__main__":
    # Convert the raw test split into the "hard" format (top 5000 rows).
    input_csv, output_csv = "data/test.csv", "data/hard_test.csv"
    process_csv(input_csv, output_csv, sample_size=5000)

    # The train split conversion is prepared but left disabled —
    # uncomment the call below to regenerate data/hard_train.csv.
    input_csv, output_csv = "data/train.csv", "data/hard_train.csv"
    # process_csv(input_csv, output_csv, sample_size=100000)
simple_cell_acc.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import matplotlib.pyplot as plt
3
+ import os
4
+ import numpy as np
5
+ from collections import defaultdict
6
+
7
def calculate_accuracy(label, predict):
    """Character-level accuracy of *predict* against *label*.

    Both strings are normalised first: "[PAD]" tokens are stripped and
    "[EOS]" tokens are mapped to "0" so they line up with digit cells.

    Returns a float in [0, 1]; 0.0 when the normalised label is empty
    (the original code raised ZeroDivisionError in that case).
    """
    label = label.replace("[PAD]", "").replace("[EOS]", "0")
    predict = predict.replace("[PAD]", "").replace("[EOS]", "0")
    total_chars = len(label)
    if total_chars == 0:
        return 0.0
    # zip() truncates to the shorter string; callers are expected to
    # pass equal-length pairs (evaluate_jsonl checks this upstream).
    correct_chars = sum(1 for l, p in zip(label, predict) if l == p)
    return correct_chars / total_chars
16
+
17
def evaluate_jsonl(file_path):
    """Average character-level accuracy over a predictions JSONL file.

    Each line must be a JSON object with 'label' and 'predict' keys.
    Lines whose label/predict lengths differ are reported and skipped.

    Returns a float in [0, 1]; 0.0 when no valid sample was found.
    (The original returned a (0, []) tuple in the empty case, an
    inconsistent type that broke callers expecting a scalar; it also
    kept an `individual_accuracies` list that was never used.)
    """
    total_accuracy = 0
    sample_count = 0

    with open(file_path, 'r') as f:
        for line in f:
            data = json.loads(line)
            label = data['label']
            predict = data['predict']

            if len(label) != len(predict):
                print(f"警告: 第{sample_count+1}行长度不一致 (label:{len(label)} vs predict:{len(predict)})")
                continue

            total_accuracy += calculate_accuracy(label, predict)
            sample_count += 1

    if sample_count == 0:
        return 0.0

    return total_accuracy / sample_count
41
+
42
def load_experiment_data(exp_path, dataset_name, T_list, acc_key="predict_acc"):
    """Collect accuracies for one experiment across sampling temperatures.

    For every t in T_list, reads {exp_path}/{dataset_name}_T{t}/all_results.json.
    When *acc_key* is missing there, the accuracy is recomputed from
    generated_predictions.jsonl and written back into the JSON file.
    Percentage strings such as "85%" are converted to floats.

    Returns:
        (results, missing_files): dict mapping t -> accuracy, and the
        list of all_results.json paths that did not exist.
    """
    # (The original also built t_values/accuracies lists that were
    # appended to but never returned — removed as dead code.)
    missing_files = []
    results = {}

    for t in T_list:
        file_path = os.path.join(exp_path, f"{dataset_name}_T{t}")
        filename = os.path.join(file_path, "all_results.json")
        if not os.path.exists(filename):
            missing_files.append(filename)
            continue
        try:
            with open(filename, 'r') as f:
                data = json.load(f)
            if acc_key not in data:
                print(f"警告: (unknown) 中未找到键 '{acc_key}'")
                # Recompute from raw predictions and cache the result.
                src_filename = os.path.join(file_path, "generated_predictions.jsonl")
                data[acc_key] = evaluate_jsonl(src_filename)
                with open(filename, 'w', encoding='utf-8') as f:
                    json.dump(data, f, ensure_ascii=False, indent=4)

            acc = data[acc_key]
            if isinstance(acc, str) and acc.endswith('%'):
                acc = float(acc.strip('%')) / 100.0
            results[t] = acc
        except Exception as e:
            # Best-effort: report and keep going with the remaining temperatures.
            print(f"处理文件 (unknown) 时出错: {str(e)}")

    return results, missing_files
74
+
75
+
76
# Experiment roots whose all_results.json should be backfilled with a
# cell-level accuracy entry.
experiments = [
    "output/sudoku/gpt2-model-bs1024-lr1e-3-ep100-20250703-073900",
    "output/sudoku/gpt2-model-bs1024-lr1e-3-ep100-20250703-075910",
    "output/sudoku/gpt2-model-bs1024-lr1e-3-ep300-20250618-082232"
]
dataset_names = ["sudoku_test", "sudoku_test", "sudoku_test"]

acc_key = "cell_acc"

# NOTE(review): every (experiment, dataset) pair is visited; since the
# dataset list repeats "sudoku_test", each experiment dir is touched
# several times — later passes just hit the already-cached key.
for experiment in experiments:
    for dataset_name in dataset_names:
        run_dir = os.path.join(experiment, dataset_name)
        results_path = os.path.join(run_dir, "all_results.json")
        with open(results_path, 'r') as f:
            results = json.load(f)
        if acc_key not in results:
            print(f"警告: (unknown) 中未找到键 '{acc_key}'")
            predictions_path = os.path.join(run_dir, "generated_predictions.jsonl")
            results[acc_key] = evaluate_jsonl(predictions_path)
            with open(results_path, 'w', encoding='utf-8') as f:
                json.dump(results, f, ensure_ascii=False, indent=4)
104
+
sudoku_cal_hardness3.py ADDED
@@ -0,0 +1,552 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Sudoku batch solver + difficulty analytics
4
+ - Bitmask + MRV DFS
5
+ - Multiprocessing
6
+ - Caching (.npz)
7
+ - Plots (hist, CDF, log-scale, quintiles)
8
+ """
9
+
10
+ import os
11
+ import re
12
+ import csv
13
+ import time
14
+ import numpy as np
15
+ import matplotlib
16
+ matplotlib.use('Agg') # 非交互模式(服务器/脚本环境)
17
+ import matplotlib.pyplot as plt
18
+ from matplotlib.ticker import MaxNLocator
19
+ from tqdm import tqdm
20
+ from multiprocessing import Pool
21
+ import sys
22
+
23
# Raise the recursion limit so the DFS on extremely hard puzzles does not
# blow the stack (tune as needed).
sys.setrecursionlimit(10000)

# ==================== Global lookup tables & constants ====================
# popcount lookup for every possible 9-bit candidate mask (0..511)
POPCOUNT = [bin(x).count("1") for x in range(512)]
# precomputed 3x3 box index for each (row, col) cell
CELL_TO_BOX = [[3*(r//3)+(c//3) for c in range(9)] for r in range(9)]
# NOTE(review): RANDOM_SEED is defined but never used in this file —
# confirm whether sampling was meant to be seeded.
RANDOM_SEED = 2025
30
+
31
+
32
+ # ==================== Bitmask Sudoku Solver ====================
33
def initialize_masks(board):
    """Build the row/column/box occupancy bitmasks for *board*.

    Bit (n-1) of a mask is set when digit n is already placed in that
    row, column or 3x3 box; empty cells are encoded as 0 on the board.
    """
    row_mask = [0] * 9
    col_mask = [0] * 9
    box_mask = [0] * 9
    for r, row in enumerate(board):
        for c, num in enumerate(row):
            if num:
                bit = 1 << (num - 1)
                row_mask[r] |= bit
                col_mask[c] |= bit
                box_mask[3 * (r // 3) + (c // 3)] |= bit
    return row_mask, col_mask, box_mask
47
+
48
+
49
def select_mrv_bitmask(board, row_mask, col_mask, box_mask):
    """Pick the next empty cell via MRV (fewest candidates first).

    Returns ((r, c), candidate_mask) for the most constrained empty
    cell, or (None, None) when the board has no empty cell.  The scan
    stops early as soon as a single-candidate cell is found.
    """
    best_count = 10          # larger than any possible candidate count
    best_cell = None
    best_mask = None
    for r in range(9):
        for c in range(9):
            if board[r][c]:
                continue
            used = row_mask[r] | col_mask[c] | box_mask[3 * (r // 3) + (c // 3)]
            candidates = (~used) & 0x1FF
            n = bin(candidates).count("1")
            if n < best_count:
                best_count = n
                best_cell = (r, c)
                best_mask = candidates
                if n == 1:
                    return best_cell, best_mask
    return best_cell, best_mask
67
+
68
+
69
def get_initial_min_remaining(board, row_mask, col_mask, box_mask):
    """Smallest candidate count among the empty cells (MRV baseline).

    Returns 0 for a completely filled board; short-circuits once a cell
    with exactly one candidate is seen (nothing can beat it).
    """
    best = 10
    for r in range(9):
        for c in range(9):
            if board[r][c]:
                continue
            used = row_mask[r] | col_mask[c] | box_mask[3 * (r // 3) + (c // 3)]
            n = bin((~used) & 0x1FF).count("1")
            if n < best:
                best = n
                if n == 1:
                    return best
    return best if best != 10 else 0
83
+
84
+
85
def solve_sudoku_bitmask(board, row_mask, col_mask, box_mask, steps=None):
    """Solve *board* in place via DFS + MRV + bitmasks.

    Args:
        board: 9x9 list of ints, 0 = empty; mutated during the search.
        row_mask/col_mask/box_mask: occupancy bitmasks kept in sync with
            *board* (see initialize_masks).
        steps: optional one-element list ([0]) used as a mutable counter
            of digit placements — the search-effort metric.

    Returns True when a full solution is reached, False for a dead end.
    On failure the board/masks are restored by backtracking.
    """
    cell, mask = select_mrv_bitmask(board, row_mask, col_mask, box_mask)
    if cell is None:
        return True  # no empty cell left: solved
    if mask == 0:
        return False  # empty cell with no candidates: contradiction

    r, c = cell
    box_idx = CELL_TO_BOX[r][c]
    while mask:
        bit = mask & -mask  # lowest set bit = smallest remaining digit
        num = (bit.bit_length() - 1) + 1
        mask -= bit

        # Tentatively place the digit and mark it in all three masks.
        board[r][c] = num
        row_mask[r] |= bit
        col_mask[c] |= bit
        box_mask[box_idx] |= bit

        if steps is not None:
            steps[0] += 1

        if solve_sudoku_bitmask(board, row_mask, col_mask, box_mask, steps):
            return True

        # Backtrack: clear the cell and un-mark the digit in every mask.
        board[r][c] = 0
        row_mask[r] ^= bit
        col_mask[c] ^= bit
        box_mask[box_idx] ^= bit
    return False
117
+
118
+
119
def evaluate_sudoku(board):
    """Difficulty metrics for one puzzle.

    Returns (empty_count, steps_used, initial_mrv_min).  The solver runs
    on a copy, so *board* is left untouched; steps are reported even
    when the puzzle turns out to be unsolvable (useful for statistics).
    """
    empty_count = sum(row.count(0) for row in board)
    work = [list(row) for row in board]
    masks = initialize_masks(work)

    initial_min_remaining = get_initial_min_remaining(work, *masks)

    steps = [0]
    solve_sudoku_bitmask(work, *masks, steps)
    return empty_count, int(steps[0]), int(initial_min_remaining)
135
+
136
+
137
+ # ==================== 并行处理 ====================
138
def process_single_sudoku_optimized(args):
    """Pool worker entry point: unpack (board, idx) and score the board."""
    board, idx = args
    metrics = evaluate_sudoku(board)
    return (idx,) + metrics
142
+
143
+
144
def parallel_solve_optimized(boards, n_workers=4):
    """Score every board in a worker pool, with a progress bar.

    Returns three lists aligned with *boards*: empty-cell counts,
    solver step counts, and initial MRV minimums.
    """
    n = len(boards)
    empty_counts = [0] * n
    steps_counts = [0] * n
    initial_min_remainings = [0] * n
    jobs = [(board, i) for i, board in enumerate(boards)]
    with Pool(processes=n_workers) as pool, \
            tqdm(total=n, desc="Solving Sudoku", unit="puzzle") as pbar:
        # imap preserves submission order; results are slotted by index anyway.
        for idx, empties, steps, mrv_min in pool.imap(
                process_single_sudoku_optimized, jobs, chunksize=100):
            empty_counts[idx] = empties
            steps_counts[idx] = steps
            initial_min_remainings[idx] = mrv_min
            pbar.update(1)
    return empty_counts, steps_counts, initial_min_remainings
158
+
159
+
160
+ # ==================== 性能测试 ====================
161
def benchmark_solver(boards, sample_size=100):
    """Time the parallel solver on (up to) the first *sample_size* boards.

    Returns the same three lists as parallel_solve_optimized for the
    benchmarked slice.
    """
    # Guard against short board lists: the old code divided by
    # sample_size even when fewer puzzles were actually processed,
    # skewing the per-puzzle and throughput numbers.
    n = min(sample_size, len(boards))
    print(f"Benchmarking solver with {n} puzzles...")
    start_time = time.time()
    empty_counts, steps_counts, initial_min_remainings = parallel_solve_optimized(
        boards[:n], n_workers=4)
    total_time = time.time() - start_time
    print(f"Processed {n} puzzles in {total_time:.2f} seconds")
    print(f"Average per puzzle: {total_time / n * 1000:.2f} ms")
    print(f"Puzzles per second: {n / total_time:.1f}")
    return empty_counts, steps_counts, initial_min_remainings
171
+
172
+
173
+ # ==================== 绘图函数 ====================
174
def auto_bins(data):
    """Histogram bin count via the Freedman–Diaconis rule, clamped to [10, 200].

    Non-finite values are dropped first; falls back to 30 bins for empty
    input or zero IQR (no spread to base the rule on).
    """
    arr = np.asarray(data)
    arr = arr[np.isfinite(arr)]
    if arr.size == 0:
        return 30
    q75, q25 = np.percentile(arr, [75, 25])
    iqr = q75 - q25
    width = 2 * iqr * (len(arr) ** (-1 / 3)) if iqr > 0 else 0
    if width <= 0:
        return 30
    n_bins = int((arr.max() - arr.min()) / width)
    return max(10, min(200, n_bins))
187
+
188
+
189
def plot_cdf(data, title, xlabel, save_path=None):
    """Empirical CDF plot of *data*; saved as PNG when *save_path* is given."""
    data_sorted = np.sort(data)
    if len(data_sorted) == 0:
        print(f"[WARN] plot_cdf: empty data for {title}")
        return
    cdf = np.arange(1, len(data_sorted) + 1) / len(data_sorted)

    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(data_sorted, cdf, linewidth=1)  # default colour cycle

    ax.set_title(title, fontsize=12)
    ax.set_xlabel(xlabel, fontsize=10)
    ax.set_ylabel('CDF', fontsize=10)
    ax.grid(True, linestyle='--', alpha=0.3)

    plt.tight_layout()
    if save_path:
        # `or '.'` keeps bare filenames working: dirname is '' then and
        # os.makedirs('') raises FileNotFoundError.
        os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Plot saved to: {save_path}")
    plt.close(fig)
211
+
212
+
213
def plot_histogram_auto(data, title, xlabel, save_path=None):
    """Histogram with automatic bin count and (mostly) integer x ticks."""
    if len(data) == 0:
        print(f"[WARN] plot_histogram_auto: empty data for {title}")
        return
    fig, ax = plt.subplots(figsize=(8, 5))
    bins = auto_bins(data)

    ax.hist(data, bins=bins, alpha=0.7, edgecolor='white', linewidth=0.5)

    ax.xaxis.set_major_locator(MaxNLocator(nbins=15, integer=True))

    ax.set_title(title, fontsize=12)
    ax.set_xlabel(xlabel, fontsize=10)
    ax.set_ylabel('Frequency', fontsize=10)
    ax.grid(True, linestyle='--', alpha=0.3)

    plt.tight_layout()
    if save_path:
        # `or '.'` keeps bare filenames working: dirname is '' then and
        # os.makedirs('') raises FileNotFoundError.
        os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Plot saved to: {save_path}")
    plt.close(fig)
236
+
237
+
238
def plot_histogram_log(data, title, xlabel, save_path=None):
    """Histogram of *data* with a logarithmic x axis."""
    if len(data) == 0:
        print(f"[WARN] plot_histogram_log: empty data for {title}")
        return
    fig, ax = plt.subplots(figsize=(8, 5))
    bins = auto_bins(data)
    ax.hist(data, bins=bins, alpha=0.7, edgecolor='white', linewidth=0.5)
    ax.set_xscale("log")

    ax.set_title(title, fontsize=12)
    ax.set_xlabel(xlabel + " (log scale)", fontsize=10)
    ax.set_ylabel('Frequency', fontsize=10)
    ax.grid(True, linestyle='--', alpha=0.3)

    plt.tight_layout()
    if save_path:
        # `or '.'` keeps bare filenames working: dirname is '' then and
        # os.makedirs('') raises FileNotFoundError.
        os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Plot saved to: {save_path}")
    plt.close(fig)
259
+
260
+
261
def plot_histogram_with_vlines_log(data, vlines, title, xlabel, save_path=None):
    """Log-x histogram with dashed vertical lines marking thresholds."""
    if len(data) == 0:
        print(f"[WARN] plot_histogram_with_vlines_log: empty data for {title}")
        return
    fig, ax = plt.subplots(figsize=(8, 5))
    bins = auto_bins(data)
    ax.hist(data, bins=bins, alpha=0.7, edgecolor='white', linewidth=0.5)
    ax.set_xscale("log")
    for v in vlines:
        ax.axvline(v, linestyle='--', linewidth=1)

    ax.set_title(title, fontsize=12)
    ax.set_xlabel(xlabel + " (log scale)", fontsize=10)
    ax.set_ylabel('Frequency', fontsize=10)
    ax.grid(True, linestyle='--', alpha=0.3)

    plt.tight_layout()
    if save_path:
        # `or '.'` keeps bare filenames working: dirname is '' then and
        # os.makedirs('') raises FileNotFoundError.
        os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Plot saved to: {save_path}")
    plt.close(fig)
284
+
285
+
286
def plot_cdf_multiple(datasets, labels, title, xlabel, save_path=None):
    """Overlay the empirical CDFs of several datasets on one axes."""
    fig, ax = plt.subplots(figsize=(8, 5))
    for data, lab in zip(datasets, labels):
        data = np.asarray(data)
        if data.size == 0:
            continue  # silently skip empty groups
        data_sorted = np.sort(data)
        cdf = np.arange(1, len(data_sorted) + 1) / len(data_sorted)
        ax.plot(data_sorted, cdf, linewidth=1, label=lab)
    ax.set_title(title, fontsize=12)
    ax.set_xlabel(xlabel, fontsize=10)
    ax.set_ylabel('CDF', fontsize=10)
    ax.grid(True, linestyle='--', alpha=0.3)
    ax.legend()
    plt.tight_layout()
    if save_path:
        # `or '.'` keeps bare filenames working: dirname is '' then and
        # os.makedirs('') raises FileNotFoundError.
        os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Plot saved to: {save_path}")
    plt.close(fig)
307
+
308
+
309
+ # ==================== 数据缓存 ====================
310
def save_data(empty_counts, steps_counts, initial_min_remainings, data_file):
    """Cache the three metric arrays into a compressed-free .npz file."""
    arrays = {
        'empty_counts': np.asarray(empty_counts, dtype=np.int32),
        'steps_counts': np.asarray(steps_counts, dtype=np.int32),
        'initial_min_remainings': np.asarray(initial_min_remainings, dtype=np.int32),
    }
    np.savez(data_file, **arrays)
    print(f"Data saved to: {data_file}")
318
+
319
+
320
def load_data(data_file):
    """Load cached metrics from *data_file* (.npz).

    Returns (empty_counts, steps_counts, initial_min_remainings); the
    third item is None for old caches that predate it, and all three
    are None when the file does not exist.
    """
    if not os.path.exists(data_file):
        return None, None, None
    # Context manager closes the underlying file handle — the original
    # left the NpzFile open (handle leak).  Arrays must be read while
    # the file is still open because NpzFile loads lazily.
    with np.load(data_file) as data:
        empty_counts = data['empty_counts']
        steps_counts = data['steps_counts']
        # Tolerate old caches that lack this key.
        initial_min_remainings = (
            data['initial_min_remainings']
            if 'initial_min_remainings' in data else None
        )
    print(f"Data loaded from: {data_file}")
    return empty_counts, steps_counts, initial_min_remainings
328
+
329
+
330
+ # ==================== CSV Loader ====================
331
def load_sudoku_csv(file_path):
    """Read puzzles from a CSV whose first column is an 81-char grid
    ('.' or '0' marks an empty cell); returns 9x9 integer boards."""
    boards = []
    with open(file_path, newline='') as f:
        reader = csv.reader(f)
        next(reader, None)  # drop the header row, if any
        for row in tqdm(reader, desc="Loading CSV"):
            cells = re.sub(r'[^0-9.]', '', row[0])
            grid = np.array([0 if ch == '.' else int(ch) for ch in cells]).reshape(9, 9).tolist()
            boards.append(grid)
    return boards
342
+
343
+
344
def load_sudoku_csv_strings(file_path):
    """Load raw 81-char puzzle strings, mapping '.' to '0' (handy for
    writing sampled puzzles back out later)."""
    puzzles = []
    with open(file_path, newline='') as f:
        reader = csv.reader(f)
        next(reader, None)  # assume a header row and drop it
        for row in reader:
            cleaned = re.sub(r'[^0-9.]', '', row[0])
            puzzles.append(cleaned.replace('.', '0'))
    return puzzles
355
+
356
+
357
+ # ==================== Main ====================
358
if __name__ == "__main__":
    INPUT_CSV_FILES = ['data/hard_test.csv']  # replace with your own file list
    N_WORKERS = 4
    SAMPLE_SIZE = 5000  # max records processed per CSV (quick trial runs)
    BENCHMARK = True  # True -> also time the solver while scoring
    for csv_file in INPUT_CSV_FILES:
        print(f"\nProcessing file: {csv_file}")
        file_prefix = os.path.splitext(os.path.basename(csv_file))[0]
        data_file = f"{file_prefix}_data.npz"

        empty_counts, steps_counts, initial_min_remainings = load_data(data_file)
        # if empty_counts is None or steps_counts is None or initial_min_remainings is None:
        # NOTE(review): the cache check above is disabled — metrics are
        # always recomputed even when the .npz cache exists.
        if True:
            print("Data not found, processing...")
            quizzes = load_sudoku_csv(csv_file)
            print(f"Loaded {len(quizzes)} sudoku puzzles")
            if BENCHMARK:
                empty_counts, steps_counts, initial_min_remainings = benchmark_solver(quizzes, SAMPLE_SIZE)
            else:
                empty_counts, steps_counts, initial_min_remainings = parallel_solve_optimized(
                    quizzes[:SAMPLE_SIZE], n_workers=N_WORKERS)
            save_data(empty_counts, steps_counts, initial_min_remainings, data_file)
        else:
            print("Using cached data")

        # ============ Output directory for the plots ============
        out_dir = os.path.join('figures', file_prefix)
        os.makedirs(out_dir, exist_ok=True)

        # ============ Full-distribution plots ============
        plot_histogram_auto(empty_counts,
                            f'{file_prefix} - Empty Count Distribution',
                            'Empty Count',
                            os.path.join(out_dir, f'{file_prefix}_empty_count_hist.png'))

        plot_histogram_auto(steps_counts,
                            f'{file_prefix} - Steps Count Distribution',
                            'Steps Count',
                            os.path.join(out_dir, f'{file_prefix}_steps_count_hist.png'))

        plot_histogram_log(steps_counts,
                           f'{file_prefix} - Steps Count Distribution (Log X)',
                           'Steps Count',
                           os.path.join(out_dir, f'{file_prefix}_steps_count_hist_log.png'))

        plot_cdf(steps_counts,
                 f'{file_prefix} - Steps Count CDF',
                 'Steps Count',
                 os.path.join(out_dir, f'{file_prefix}_steps_count_cdf.png'))

        plot_cdf(empty_counts,
                 f'{file_prefix} - Empty Count CDF',
                 'Empty Count',
                 os.path.join(out_dir, f'{file_prefix}_empty_count_cdf.png'))

        # ============ Quintile binning + sampled exports ============
        TOTAL_N = 5000
        SAMPLE_PER_BIN = 1000
        CONSIDER_N = int(min(TOTAL_N, len(steps_counts)))
        consider_idx = np.arange(CONSIDER_N)  # NOTE(review): unused — kept for reference

        if CONSIDER_N > 0:
            steps_consider = np.array(steps_counts)[:CONSIDER_N]
            empty_consider = np.array(empty_counts)[:CONSIDER_N]

            # Old logic: value-based quantile binning; duplicate boundary
            # values could make the bins unequal in size.
            # q20, q40, q60, q80 = np.quantile(steps_consider, [0.2, 0.4, 0.6, 0.8])
            # bins = [-np.inf, q20, q40, q60, q80, np.inf]
            # labels = np.digitize(steps_consider, bins)

            # Rank-based split instead: guarantees each bin holds exactly
            # SAMPLE_PER_BIN entries (when feasible).
            sorted_idx = np.argsort(steps_consider, kind='stable')
            # Start/end indices per bin (first four bins get SAMPLE_PER_BIN each).
            bin_size = SAMPLE_PER_BIN
            b1 = sorted_idx[0:bin_size]
            b2 = sorted_idx[bin_size:bin_size*2]
            b3 = sorted_idx[bin_size*2:bin_size*3]
            b4 = sorted_idx[bin_size*3:bin_size*4]
            b5 = sorted_idx[bin_size*4:CONSIDER_N]

            # Histogram with rank-derived thresholds (log x).
            def safe_idx(pos):
                # Clamp a rank position into the valid index range.
                return min(max(pos, 0), CONSIDER_N - 1)
            v20 = steps_consider[sorted_idx[safe_idx(bin_size - 1)]]
            v40 = steps_consider[sorted_idx[safe_idx(bin_size * 2 - 1)]]
            v60 = steps_consider[sorted_idx[safe_idx(bin_size * 3 - 1)]]
            v80 = steps_consider[sorted_idx[safe_idx(bin_size * 4 - 1)]]
            plot_histogram_with_vlines_log(
                steps_consider,
                [v20, v40, v60, v80],
                f'{file_prefix} - Steps Histogram with Quintile Thresholds',
                'Steps Count',
                save_path=os.path.join(out_dir, f'{file_prefix}_steps_hist_with_quintiles.png')
            )

            # Per-bin data & overlaid CDFs (bins come from the rank split).
            bin_data = [
                steps_consider[b1],
                steps_consider[b2],
                steps_consider[b3],
                steps_consider[b4],
                steps_consider[b5],
            ]
            plot_cdf_multiple(
                bin_data,
                [f'Bin{k}' for k in [1, 2, 3, 4, 5]],
                f'{file_prefix} - Steps CDF by Quintile Bins',
                'Steps Count',
                save_path=os.path.join(out_dir, f'{file_prefix}_steps_cdf_quintiles.png')
            )

            # Raw puzzle strings, needed for the CSV exports below.
            puzzles_raw = load_sudoku_csv_strings(csv_file)

            sampled_by_bin = {}
            for target_bin in [1, 3, 5]:
                # Index set of the requested rank-split bin.
                if target_bin == 1:
                    idx_in_bin = b1
                elif target_bin == 3:
                    idx_in_bin = b3
                elif target_bin == 5:
                    idx_in_bin = b5
                else:
                    idx_in_bin = np.array([], dtype=int)

                if len(idx_in_bin) == 0:
                    print(f"Bin {target_bin} has no samples.")
                    sampled_by_bin[target_bin] = np.array([], dtype=int)
                    continue

                # Take the first SAMPLE_PER_BIN rows so export sizes match
                # (export everything when the bin is smaller than that).
                take_n = min(len(idx_in_bin), SAMPLE_PER_BIN)
                sampled_idx = idx_in_bin[:take_n]

                sampled_by_bin[target_bin] = np.array(sampled_idx, dtype=int)

                out_csv = f"{file_prefix}_bin{target_bin}_sample{SAMPLE_PER_BIN}.csv"
                with open(out_csv, 'w', newline='') as wf:
                    writer = csv.writer(wf)
                    writer.writerow(['puzzle_index', 'steps', 'empty_count', 'puzzle'])
                    for i in sampled_idx:
                        writer.writerow([int(i), int(steps_consider[i]), int(empty_consider[i]), puzzles_raw[i]])
                print(f"Saved bin {target_bin} sample to: {out_csv} (count={len(sampled_idx)})")

                # Histogram and CDF for this bin's sample.
                d_steps = steps_consider[sampled_by_bin[target_bin]]
                plot_histogram_auto(
                    d_steps,
                    f'{file_prefix} - Bin{target_bin} Sample Steps Hist',
                    'Steps Count',
                    os.path.join(out_dir, f'{file_prefix}_bin{target_bin}_sample_steps_hist.png')
                )
                plot_cdf(
                    d_steps,
                    f'{file_prefix} - Bin{target_bin} Sample Steps CDF',
                    'Steps Count',
                    os.path.join(out_dir, f'{file_prefix}_bin{target_bin}_sample_steps_cdf.png')
                )

            # Export a mixed sample combining the requested bins.
            def save_mix(mix_bins, out_name):
                mix_idx = np.concatenate([sampled_by_bin.get(b, np.array([], dtype=int)) for b in mix_bins])
                if mix_idx.size == 0:
                    print(f"Mix {out_name} has no samples; skipped.")
                    return
                out_csv = f"{file_prefix}_{out_name}.csv"
                with open(out_csv, 'w', newline='') as wf:
                    writer = csv.writer(wf)
                    writer.writerow(['puzzle_index', 'steps', 'empty_count', 'puzzle'])
                    for i in mix_idx:
                        writer.writerow([int(i), int(steps_consider[i]), int(empty_consider[i]), puzzles_raw[i]])
                print(f"Saved mix {out_name} to: {out_csv} (count={mix_idx.size})")

            # Mix of bins 1+3+5: export plus histogram and CDF.
            mix_name = f"mix_bin1_3_5_sample2000"
            save_mix([1, 3, 5], mix_name)
            mix_idx_all = np.concatenate([sampled_by_bin.get(b, np.array([], dtype=int)) for b in [1, 3, 5]])
            if mix_idx_all.size > 0:
                d_mix = steps_consider[mix_idx_all]
                plot_histogram_auto(
                    d_mix,
                    f'{file_prefix} - Mix(1+3+5) Sample Steps Hist',
                    'Steps Count',
                    os.path.join(out_dir, f'{file_prefix}_{mix_name}_steps_hist.png')
                )
                plot_cdf(
                    d_mix,
                    f'{file_prefix} - Mix(1+3+5) Sample Steps CDF',
                    'Steps Count',
                    os.path.join(out_dir, f'{file_prefix}_{mix_name}_steps_cdf.png')
                )

    print("\nDone.")
test.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a2fd52aea23d331d5b4ee723c856236e838a9fb9a70e66f4e0e0cf26c338c6a8
3
+ size 79360390
train.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:64b46674db0148e0d73a16346dadeb2b1c00824d3fca3f85b2ae7037f6b4b38e
3
+ size 718819925