object-assembler / code /cube3d /training /check_rotation_onehot.py
0xZohar's picture
Add code/cube3d/training/check_rotation_onehot.py
b90e24e verified
import numpy as np
import itertools
def signed_perm_mats_det_plus_1():
    """
    Build every 3x3 signed permutation matrix with determinant +1 (24 matrices).

    A signed permutation matrix has entries in {-1, 0, +1} with exactly one
    non-zero per row and per column; keeping only det == +1 leaves the proper
    rotations.

    Enumeration order: the outer loop runs over column permutations from
    ``itertools.permutations(range(3))``. The inner loop runs over per-row sign
    patterns from ``itertools.product([1, -1], repeat=3)``. Candidates with
    det == -1 are discarded.

    Returns:
        list of 24 int numpy arrays of shape (3, 3).
    """
    candidates = []
    for cols in itertools.permutations(range(3)):
        base = np.zeros((3, 3), int)
        base[np.arange(3), list(cols)] = 1  # row r has its 1 in column cols[r]
        for row_signs in itertools.product([1, -1], repeat=3):
            signed = base * np.array(row_signs)[:, None]  # flip the sign of each row
            if int(round(np.linalg.det(signed))) == 1:
                candidates.append(signed)
    # De-duplicate by entry pattern. The dict keeps first-seen order, and a
    # later duplicate would replace the stored value.
    by_key = {tuple(m.reshape(-1)): m for m in candidates}
    return list(by_key.values())
# Lookup table for the 24-way rotation classification: R24[i] is the rotation
# matrix of class i (used by rot_to_onehot24 / onehot24_to_rot).
# NOTE: ROTATIONS_24 (built later by generate_24_rotations) enumerates signs in
# the opposite order, so its indices are NOT interchangeable with R24's.
R24 = signed_perm_mats_det_plus_1() # len(R24) == 24
def rot_to_onehot24(R):
    """
    Classify a (near-)axis-aligned rotation into one of the 24 classes in R24.

    Entries of ``R`` are rounded to the nearest integer to remove numerical
    noise. The rounded matrix is then matched exactly against R24 (compared as
    flattened entries).

    If there is no exact match, the code falls back to the entry of R24 with
    the smallest L1 distance to the rounded matrix, with ties going to the
    lowest index. Use that fallback with care: it silently assigns a class even
    to matrices that are not one of the 24 rotations.

    Args:
        R: matrix-like of rotation entries (presumably 3x3).

    Returns:
        (one_hot, idx): an int array of length 24 with a single 1 at ``idx``,
        and the class index ``idx`` itself.
    """
    rounded = np.round(R).astype(int)
    wanted = tuple(rounded.reshape(-1))
    idx = next(
        (i for i, M in enumerate(R24) if tuple(M.reshape(-1)) == wanted),
        None,
    )
    if idx is None:
        # Nearest-neighbour fallback on the rounded matrix (L1 distance).
        l1_dists = [np.sum(np.abs(M - rounded)) for M in R24]
        idx = int(np.argmin(l1_dists))
    vec = np.zeros(24, dtype=int)
    vec[idx] = 1
    return vec, idx
def onehot24_to_rot(one_hot):
    """
    Map a 24-way one-hot (or score) vector back to its rotation matrix.

    The class index is ``argmax(one_hot)``, so a soft score vector also works.
    An all-zero vector maps to class 0.

    Args:
        one_hot: array-like of length 24.

    Returns:
        A fresh copy of ``R24[idx]`` (3x3 int array). A copy is returned so
        that callers who modify the result in place cannot corrupt the shared
        lookup table used by every later conversion.
    """
    idx = int(np.asarray(one_hot).argmax())
    return R24[idx].copy()
# # 从索引出发:idx -> R -> onehot -> idx_back
# for i in range(24):
# R = R24[i]
# oh, j = rot_to_onehot24(R)
# assert j == i
# Rb = onehot24_to_rot(oh)
# assert np.array_equal(R, Rb)
# print("24类旋转:往返一致 ✅")
import os
import numpy as np
from itertools import product, permutations
def generate_24_rotations():
    """
    Generate the 24 signed permutation matrices with det == +1 (the cube's rotation group).

    Each candidate is ``diag(signs) @ P``, where ``P`` is the permutation
    matrix with ``P[i, p[i]] = 1``. The outer loop runs over
    ``permutations(range(3))`` and the inner loop over
    ``product([-1, 1], repeat=3)``. Matrices with det != +1 are dropped, and
    duplicates are removed while keeping the first occurrence in that order.

    Returns:
        list of 3x3 int numpy arrays; its length should be 24.
    """
    seen = set()
    ordered = []
    for axes in permutations(range(3)):
        perm_mat = np.zeros((3, 3), dtype=int)
        perm_mat[np.arange(3), list(axes)] = 1
        for flips in product([-1, 1], repeat=3):
            cand = (np.diag(flips) @ perm_mat).astype(int)
            if round(np.linalg.det(cand)) != 1:
                continue  # reflection, not a rotation
            key = tuple(cand.ravel())
            if key in seen:
                continue
            seen.add(key)
            ordered.append(cand)
    return ordered
# Reference table used by match_in_24; the returned index points into this list.
# NOTE: generate_24_rotations enumerates signs as [-1, 1] whereas R24 uses
# [1, -1], so an index here is NOT the same class id as in R24.
ROTATIONS_24 = generate_24_rotations()
def is_signed_perm_det1(R_int):
    """
    Return whether ``R_int`` is a 3x3 signed permutation matrix with det == +1.

    The checks, in order:
      * the shape is (3, 3);
      * every entry is in {-1, 0, 1};
      * every row and every column has exactly one non-zero entry;
      * the rounded determinant equals +1.

    ``R_int`` must be a numpy array, since ``.shape`` is read directly.
    """
    if R_int.shape != (3, 3):
        return False
    if not np.isin(R_int, [-1, 0, 1]).all():
        return False
    nonzero = R_int != 0
    for axis in (1, 0):  # rows first, then columns
        if not (nonzero.sum(axis=axis) == 1).all():
            return False
    return round(np.linalg.det(R_int)) == 1
def snap_to_signed_perm(R, tol=1e-3):
    """
    Snap entries that are very close to -1, 0 or +1 onto exactly those values.

    "Close" means ``np.isclose(value, target, atol=tol)``, using numpy's
    default rtol. Entries that are clearly off (e.g. 0.57, 0.82) are left
    untouched, so a later membership check will reject the matrix.

    Args:
        R: array-like of numbers.
        tol: absolute tolerance for snapping.

    Returns:
        A new float array with the same shape as ``R``.
    """
    values = np.asarray(R, dtype=float)
    snapped = values.copy()
    # Masks come from the ORIGINAL values. When tol is large enough for the
    # windows to overlap, later targets win, so 0 beats -1 and -1 beats +1.
    for target in (1.0, -1.0, 0.0):
        snapped[np.isclose(values, target, atol=tol)] = target
    return snapped
def match_in_24(R, tol=1e-3):
    """
    Check whether ``R`` (after snapping) is one of the 24 rotations in ROTATIONS_24.

    Returns:
        (matched, index, R_int):
          * matched: True if the snapped matrix equals an entry of ROTATIONS_24.
          * index: the position in ROTATIONS_24 when matched, otherwise None.
          * R_int: the snapped integer matrix, useful for debugging. If some
            entry could not be snapped to -1/0/+1, this is the float snapped
            matrix instead.
    """
    snapped = snap_to_signed_perm(R, tol=tol)
    # Reject before astype(int): truncation would silently turn 0.57 into 0.
    if not np.isin(snapped, [-1.0, 0.0, 1.0]).all():
        return False, None, snapped
    as_int = snapped.astype(int)
    if is_signed_perm_det1(as_int):
        hit = next(
            (k for k, ref in enumerate(ROTATIONS_24) if np.array_equal(as_int, ref)),
            None,
        )
        if hit is not None:
            return True, hit, as_int
    return False, None, as_int
def parse_rotation_from_parts(parts):
    """
    Extract the 3x3 rotation matrix from a tokenised LDR part line.

    ``parts`` is the whitespace-split line. Tokens ``parts[5:14]`` are the nine
    matrix entries in row-major order.

    Returns:
        A (3, 3) float numpy array. Raises ValueError if a token is not numeric
        or fewer than nine entries are present.
    """
    flat = np.array([float(tok) for tok in parts[5:14]], dtype=float)
    return flat.reshape(3, 3)
def check_ldr_file(ldr_path, tol=1e-3, verbose=True, max_lines=310):
    """
    Scan an LDR file and check every part line's rotation against the 24 rotations.

    A part line is any line whose first whitespace-separated token is ``1``
    and which has at least 15 tokens. Its rotation is taken from
    ``parts[5:14]``. Tokens may be separated by spaces or tabs.

    Args:
        ldr_path: path of the file, read as UTF-8 with undecodable bytes ignored.
        tol: absolute tolerance used when snapping entries to -1/0/+1.
        verbose: print each mismatch as it is found, plus up to 5 examples
            after the summary.
        max_lines: files with more lines than this are skipped; a message is
            printed and None is returned. Pass None to disable the limit.

    Returns:
        A dict with keys "total", "matched", "not_matched" and "mismatches",
        where each mismatch is ``(lineno, R, snapped)``. Returns None if the
        file was skipped.
    """
    with open(ldr_path, 'r', encoding='utf-8', errors='ignore') as f:
        lines = f.readlines()
    if max_lines is not None and len(lines) > max_lines:
        print(f"Skipping {ldr_path}: too many lines ({len(lines)}).")
        return None

    total = 0
    matched = 0
    not_matched = 0
    mismatches = []  # (lineno, R, R_int/snap)
    for ln, line in enumerate(lines, start=1):
        parts = line.split()
        # Compare the first token rather than testing a '1 ' prefix: LDraw
        # allows tabs as separators, which a prefix test would miss.
        if not parts or parts[0] != '1':
            continue
        if len(parts) < 15:
            continue
        total += 1
        R = parse_rotation_from_parts(parts)
        ok, _idx, R_int = match_in_24(R, tol=tol)
        if ok:
            matched += 1
        else:
            not_matched += 1
            mismatches.append((ln, R, R_int))
            if verbose:
                print(f"[NO] line {ln}: not in 24 rotations. snapped=\n{R_int}\norig=\n{R}")

    print("\n===== SUMMARY =====")
    print(f"file: {ldr_path}")
    print(f"total part-lines: {total}")
    print(f"matched in 24: {matched}")
    print(f"not matched: {not_matched}")
    if not_matched and verbose:
        print("\nExamples of not-matched (up to 5):")
        for ln, R, R_int in mismatches[:5]:
            print(f"- line {ln}: snapped=\n{R_int}\n orig=\n{R}\n")
    return {
        "total": total,
        "matched": matched,
        "not_matched": not_matched,
        "mismatches": mismatches,
    }
if __name__ == "__main__":
    import sys

    # Root folder to scan recursively. An optional first command-line argument
    # overrides the default path.
    folder_path = sys.argv[1] if len(sys.argv) > 1 else '/public/home/wangshuo/gap/assembly/data'
    for root, _dirs, files in os.walk(folder_path):
        # Sorted so that the report order is deterministic across runs.
        for file in sorted(files):
            # Only "modified*.ldr" files are checked.
            if file.endswith('.ldr') and file.startswith('modified'):
                file_path = os.path.join(root, file)
                check_ldr_file(file_path, tol=1e-3, verbose=True)
# import numpy as np
# from itertools import product, permutations
# def generate_24_rotations():
# mats = []
# for p in permutations(range(3)):
# P = np.zeros((3,3), dtype=int)
# for i,j in enumerate(p): P[i,j] = 1
# for s in product([-1,1], repeat=3):
# S = np.diag(s)
# R = S @ P
# if round(np.linalg.det(R)) == 1:
# mats.append(R.astype(float))
# # 去重
# uniq = []
# for R in mats:
# if not any(np.array_equal(R, U) for U in uniq):
# uniq.append(R)
# assert len(uniq) == 24
# return uniq
# ROT24 = generate_24_rotations()
# def rotation_distance_deg(R1, R2):
# """
# 计算 R2 到 R1 的相对旋转角度(度数)。
# Δ = R2^T R1; angle = arccos((trace(Δ)-1)/2)
# 数值上做夹取,避免浮点误差。
# """
# Delta = R2.T @ R1
# c = (np.trace(Delta) - 1.0) / 2.0
# c = np.clip(c, -1.0, 1.0)
# return float(np.degrees(np.arccos(c)))
# def nearest_in_24(R):
# best_idx, best_R, best_deg = None, None, 1e9
# for i, Q in enumerate(ROT24):
# deg = rotation_distance_deg(R, Q)
# if deg < best_deg:
# best_idx, best_R, best_deg = i, Q, deg
# return best_idx, best_R, best_deg
# # —— 示例:对你给出的矩阵做最近邻 ——
# R1 = np.array([
# [ 0. , 0. , -1. ],
# [-0.707107, -0.707107, 0. ],
# [-0.707107, 0.707106, 0. ],
# ], dtype=float)
# idx, Q, deg = nearest_in_24(R1)
# print("nearest idx:", idx)
# print("nearest R:\n", Q)
# print("angle error (deg):", deg)