0xZohar commited on
Commit
b90e24e
·
verified ·
1 Parent(s): 2fb76b6

Add code/cube3d/training/check_rotation_onehot.py

Browse files
code/cube3d/training/check_rotation_onehot.py ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import itertools
3
+
4
+ def signed_perm_mats_det_plus_1():
5
+ """生成所有 3x3 的 ±1 置换矩阵,且 det=+1(24 个)"""
6
+ mats = []
7
+ for perm in itertools.permutations(range(3)): # 6 种列置换
8
+ P = np.zeros((3,3), int)
9
+ for r, c in enumerate(perm):
10
+ P[r, c] = 1
11
+ for signs in itertools.product([1, -1], repeat=3): # 行符号
12
+ R = P * np.array(signs)[:, None] # 每行乘 ±1
13
+ if int(round(np.linalg.det(R))) == 1:
14
+ mats.append(R)
15
+ # 去重(以 tuple 做 key)
16
+ uniq = {}
17
+ for R in mats:
18
+ key = tuple(R.reshape(-1))
19
+ uniq[key] = R # 同一个 key 只保留一个
20
+ return list(uniq.values()) # 24 个
21
+
22
+ R24 = signed_perm_mats_det_plus_1() # len(R24) == 24
23
+
24
+ #import ipdb; ipdb.set_trace()
25
+
26
+ def rot_to_onehot24(R):
27
+ Rr = np.round(R).astype(int) # 去毛刺到 {-1,0,1}
28
+ # 直接精确匹配;如果有极少量数值误差,也可以用 L1/L2 最近邻
29
+ keys = [tuple(M.reshape(-1)) for M in R24]
30
+ kR = tuple(Rr.reshape(-1))
31
+ if kR in keys:
32
+ idx = keys.index(kR)
33
+ else:
34
+ # 兜底最近邻(谨慎使用)
35
+ diffs = [np.sum(np.abs(M - Rr)) for M in R24]
36
+ idx = int(np.argmin(diffs))
37
+ one_hot = np.zeros(24, dtype=int)
38
+ one_hot[idx] = 1
39
+ return one_hot, idx
40
+
41
+ def onehot24_to_rot(one_hot):
42
+ idx = int(np.asarray(one_hot).argmax())
43
+ #import ipdb; ipdb.set_trace()
44
+ return R24[idx]
45
+
46
+ # # 从索引出发:idx -> R -> onehot -> idx_back
47
+ # for i in range(24):
48
+ # R = R24[i]
49
+ # oh, j = rot_to_onehot24(R)
50
+ # assert j == i
51
+ # Rb = onehot24_to_rot(oh)
52
+ # assert np.array_equal(R, Rb)
53
+ # print("24类旋转:往返一致 ✅")
54
+
55
+ import os
56
+ import numpy as np
57
+ from itertools import product, permutations
58
+
59
+ def generate_24_rotations():
60
+ """生成 24 个 det=+1 的 ±1 置换矩阵(立方体的旋转群)。"""
61
+ rots = []
62
+ for p in permutations(range(3)): # 轴置换
63
+ P = np.zeros((3,3), dtype=int)
64
+ for i, j in enumerate(p):
65
+ P[i, j] = 1
66
+ for signs in product([-1, 1], repeat=3): # 各轴符号
67
+ S = np.diag(signs)
68
+ R = S @ P
69
+ if round(np.linalg.det(R)) == 1: # 只要 det=+1
70
+ rots.append(R.astype(int))
71
+ # 去重并固定顺序
72
+ unique = []
73
+ for R in rots:
74
+ if not any(np.array_equal(R, U) for U in unique):
75
+ unique.append(R)
76
+ return unique # 长度应为 24
77
+
78
+ ROTATIONS_24 = generate_24_rotations()
79
+
80
+ def is_signed_perm_det1(R_int):
81
+ """R_int 是否为 entries∈{-1,0,1} 的 3x3 矩阵,且每行/列唯一一个非零,det=+1。"""
82
+ if R_int.shape != (3,3): return False
83
+ if not np.all(np.isin(R_int, [-1, 0, 1])): return False
84
+ # 每行/列恰好一个非零
85
+ if not np.all(np.sum(R_int != 0, axis=1) == 1): return False
86
+ if not np.all(np.sum(R_int != 0, axis=0) == 1): return False
87
+ # det=+1
88
+ return round(np.linalg.det(R_int)) == 1
89
+
90
+ def snap_to_signed_perm(R, tol=1e-3):
91
+ """
92
+ 仅将非常接近 -1/0/+1 的元素吸附到 -1/0/+1;若存在明显偏离(例如 0.57、0.82),保持原值——最终判定会失败。
93
+ """
94
+ R = np.asarray(R, dtype=float)
95
+ R_snap = R.copy()
96
+
97
+ # 接近 1 -> 1;接近 -1 -> -1;接近 0 -> 0
98
+ R_snap[np.isclose(R, 1.0, atol=tol)] = 1.0
99
+ R_snap[np.isclose(R, -1.0, atol=tol)] = -1.0
100
+ R_snap[np.isclose(R, 0.0, atol=tol)] = 0.0
101
+
102
+ # 其余保持原值(不是 24 旋转就会在后续判定中失败)
103
+ return R_snap
104
+
105
+ def match_in_24(R, tol=1e-3):
106
+ """
107
+ 返回 (matched, index, R_int)
108
+ - matched: 是否匹配 24 旋转
109
+ - index: 若匹配,给出在 ROTATIONS_24 中的索引,否则为 None
110
+ - R_int: 吸附后的整数矩阵(或原矩阵的近似整数版,便于调试)
111
+ """
112
+ R_snap = snap_to_signed_perm(R, tol=tol)
113
+ # 尝试转成 int(吸附后如果仍有非 -1/0/1 的值,astype 后会引发误判,先检查)
114
+ if not np.all(np.isin(R_snap, [-1.0, 0.0, 1.0])):
115
+ return False, None, R_snap # 明显不在 24 旋转里
116
+
117
+ R_int = R_snap.astype(int)
118
+ if not is_signed_perm_det1(R_int):
119
+ return False, None, R_int
120
+
121
+ # 找索引
122
+ for idx, R0 in enumerate(ROTATIONS_24):
123
+ if np.array_equal(R_int, R0):
124
+ return True, idx, R_int
125
+
126
+ return False, None, R_int
127
+
128
+ def parse_rotation_from_parts(parts):
129
+ """
130
+ parts 是一行 LDR 的分割结果。旋转矩阵是 parts[5:14](9 个数,按行展开)。
131
+ 返回 3x3 numpy 数组。
132
+ """
133
+ vals = list(map(float, parts[5:14]))
134
+ R = np.array(vals, dtype=float).reshape(3,3)
135
+ return R
136
+
137
+ def check_ldr_file(ldr_path, tol=1e-3, verbose=True):
138
+ """
139
+ 逐行读取 LDR 文件,遇到以 '1 ' 开头且列数>=15 的零件行,抽取旋转矩阵并与 24 旋转比对。
140
+ """
141
+ total = 0
142
+ matched = 0
143
+ not_matched = 0
144
+ mismatches = [] # (lineno, R, R_int/snap)
145
+
146
+ with open(ldr_path, 'r', encoding='utf-8', errors='ignore') as f:
147
+ lines = f.readlines()
148
+ if len(lines)>310:
149
+ print(f"Skipping {ldr_path}: too many lines ({len(lines)}).")
150
+ return
151
+ for ln, line in enumerate(lines, start=1):
152
+ line_strip = line.strip()
153
+ if not line_strip.startswith('1 '):
154
+ continue
155
+ parts = line_strip.split()
156
+ if len(parts) < 15:
157
+ continue
158
+
159
+ total += 1
160
+ R = parse_rotation_from_parts(parts)
161
+ ok, idx, R_int = match_in_24(R, tol=tol)
162
+ if ok:
163
+ matched += 1
164
+ # if verbose:
165
+ # print(f"[OK] line {ln}: matched rot index={idx} matrix=\n{R_int}")
166
+ else:
167
+ not_matched += 1
168
+ mismatches.append((ln, R, R_int))
169
+ if verbose:
170
+ print(f"[NO] line {ln}: not in 24 rotations. snapped=\n{R_int}\norig=\n{R}")
171
+
172
+ print("\n===== SUMMARY =====")
173
+ print(f"file: {ldr_path}")
174
+ print(f"total part-lines: {total}")
175
+ print(f"matched in 24: {matched}")
176
+ print(f"not matched: {not_matched}")
177
+
178
+ if not_matched and verbose:
179
+ print("\nExamples of not-matched (up to 5):")
180
+ for ln, R, R_int in mismatches[:5]:
181
+ print(f"- line {ln}: snapped=\n{R_int}\n orig=\n{R}\n")
182
+
183
+ return {
184
+ "total": total,
185
+ "matched": matched,
186
+ "not_matched": not_matched,
187
+ "mismatches": mismatches,
188
+ }
189
+
190
+ if __name__ == "__main__":
191
+ folder_path = '/public/home/wangshuo/gap/assembly/data' # 替换为你的文件夹路径
192
+ for root, dirs, files in os.walk(folder_path):
193
+ for file in files:
194
+ if file.endswith('.ldr') and file.startswith('modified'): # 只处理ldr文件
195
+ file_path = os.path.join(root, file)
196
+
197
+ # 用法示例:替换为你的 LDR 文件路径
198
+ #ldr_file = "/public/home/wangshuo/gap/assembly/data/blue classic car/modified_blue classic car.ldr"
199
+ check_ldr_file(file_path, tol=1e-3, verbose=True)
200
+
201
+
202
+ # import numpy as np
203
+ # from itertools import product, permutations
204
+
205
+ # def generate_24_rotations():
206
+ # mats = []
207
+ # for p in permutations(range(3)):
208
+ # P = np.zeros((3,3), dtype=int)
209
+ # for i,j in enumerate(p): P[i,j] = 1
210
+ # for s in product([-1,1], repeat=3):
211
+ # S = np.diag(s)
212
+ # R = S @ P
213
+ # if round(np.linalg.det(R)) == 1:
214
+ # mats.append(R.astype(float))
215
+ # # 去重
216
+ # uniq = []
217
+ # for R in mats:
218
+ # if not any(np.array_equal(R, U) for U in uniq):
219
+ # uniq.append(R)
220
+ # assert len(uniq) == 24
221
+ # return uniq
222
+
223
+ # ROT24 = generate_24_rotations()
224
+
225
+ # def rotation_distance_deg(R1, R2):
226
+ # """
227
+ # 计算 R2 到 R1 的相对旋转角度(度数)。
228
+ # Δ = R2^T R1; angle = arccos((trace(Δ)-1)/2)
229
+ # 数值上做夹取,避免浮点误差。
230
+ # """
231
+ # Delta = R2.T @ R1
232
+ # c = (np.trace(Delta) - 1.0) / 2.0
233
+ # c = np.clip(c, -1.0, 1.0)
234
+ # return float(np.degrees(np.arccos(c)))
235
+
236
+ # def nearest_in_24(R):
237
+ # best_idx, best_R, best_deg = None, None, 1e9
238
+ # for i, Q in enumerate(ROT24):
239
+ # deg = rotation_distance_deg(R, Q)
240
+ # if deg < best_deg:
241
+ # best_idx, best_R, best_deg = i, Q, deg
242
+ # return best_idx, best_R, best_deg
243
+
244
+ # # —— 示例:对你给出的矩阵做最近邻 ——
245
+ # R1 = np.array([
246
+ # [ 0. , 0. , -1. ],
247
+ # [-0.707107, -0.707107, 0. ],
248
+ # [-0.707107, 0.707106, 0. ],
249
+ # ], dtype=float)
250
+
251
+ # idx, Q, deg = nearest_in_24(R1)
252
+ # print("nearest idx:", idx)
253
+ # print("nearest R:\n", Q)
254
+ # print("angle error (deg):", deg)