0xZohar commited on
Commit
d153774
·
verified ·
1 Parent(s): e3e6ebe

Add code/cube3d/training/batch_dat_mapping.py

Browse files
code/cube3d/training/batch_dat_mapping.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+
4
# Detect the text encoding of a file by inspecting its first bytes.
def checkEncoding(filepath):
    """Return a codec name for *filepath* based on its leading byte-order mark.

    LDraw files saved as UTF-16 normally begin with a BOM followed by an
    ASCII character (typically '0'), so the first three bytes look like:
      UTF-16 BE: FE FF 00   (BOM + high byte of the first ASCII char)
      UTF-16 LE: FF FE 30   (BOM + low byte of '0')
    A UTF-8 BOM (EF BB BF) maps to 'utf_8_sig' so the BOM is stripped on
    read; anything else is assumed to be plain UTF-8.
    """
    with open(filepath, "rb") as encode_check:
        # read(3), not readline(3): readline stops early at a b"\n" byte,
        # which could truncate the three-byte probe.
        prefix = encode_check.read(3)
    if prefix == b"\xfe\xff\x00":
        return "utf_16_be"
    elif prefix == b"\xff\xfe0":
        return "utf_16_le"
    elif prefix == b"\xef\xbb\xbf":
        # UTF-8 with BOM: utf_8_sig keeps U+FEFF out of the decoded text.
        return "utf_8_sig"
    else:
        return "utf_8"
14
+
15
# Read a text file with encoding auto-detection; returns None if missing.
def readTextFile(filepath):
    """Return the lines of *filepath* as a list (newlines kept), or None.

    The encoding is guessed with checkEncoding(); if decoding with that
    codec fails, fall back to latin_1, which maps every byte to a code
    point and therefore cannot raise a decode error.
    Returns None when the file does not exist.
    """
    if not os.path.exists(filepath):
        return None
    file_encoding = checkEncoding(filepath)
    try:
        with open(filepath, "rt", encoding=file_encoding) as f_in:
            return f_in.readlines()
    except UnicodeError:
        # Narrowed from a bare `except:` — only a decode failure should
        # trigger the fallback, not KeyboardInterrupt / OSError etc.
        with open(filepath, "rt", encoding="latin_1") as f_in:
            return f_in.readlines()
26
+
27
# Process the data of a single LDR file.
def process_ldr_data(lines, label_mapping, label_inverse_mapping, label_frequency, label_counter):
    """Scan one LDR/MPD file and accumulate part-label statistics.

    Extracts the main model section (the first "0 FILE" ... next "0 FILE" /
    "0 NOFILE" span of an MPD file, or the whole file when no such markers
    exist), then counts every part referenced by a type-1 line.

    Updates label_mapping (part filename -> label id),
    label_inverse_mapping (label id -> part filename) and
    label_frequency (label id -> usage count) in place and returns all
    three together with the advanced label_counter.
    """
    # Locate the main_section range delimited by FILE/NOFILE meta-commands.
    startLine = 0
    endLine = 0
    lineCount = 0
    foundEnd = False
    main_section_lines = []

    for line in lines:
        parameters = line.strip().split()
        if len(parameters) > 2:
            if parameters[0] == "0" and parameters[1] == "FILE":
                if not foundEnd:
                    endLine = lineCount
                    # A second "0 FILE" closes the first sub-file: keep its
                    # body and stop — only the first section is used.
                    if endLine > startLine:
                        main_section_lines.extend(lines[startLine:endLine])
                        foundEnd = True
                        break
                # This "0 FILE" opens a (new) section.
                startLine = lineCount
                foundEnd = False

            if parameters[0] == "0" and parameters[1] == "NOFILE":
                # Explicit end-of-sub-file marker.
                endLine = lineCount
                foundEnd = True
                main_section_lines.extend(lines[startLine:endLine])
                break
        lineCount += 1

    # No closing marker seen: the section runs to end-of-file.
    if not foundEnd:
        endLine = len(lines)
        if endLine > startLine:
            main_section_lines.extend(lines[startLine:endLine])

    # Process the type-1 (sub-part reference) lines of the main section.
    for line in main_section_lines:
        if line.startswith('1'):
            parts = line.split()
            # NOTE(review): assumes the part filename is whitespace token 14
            # (LDraw type-1 layout); filenames containing spaces would be
            # truncated by split() — TODO confirm against the dataset.
            if len(parts) >= 15:
                part_filename = parts[14]
                # Normalize the extension so "X.DAT" and "X.dat" share a label.
                if ".DAT" in part_filename:
                    part_filename = part_filename.replace(".DAT", ".dat")

                # First sighting: assign the next free label id.
                if part_filename not in label_mapping:
                    label_mapping[part_filename] = label_counter
                    label_inverse_mapping[label_counter] = part_filename
                    label_counter += 1

                current_label = label_mapping[part_filename]
                label_frequency[current_label] = label_frequency.get(current_label, 0) + 1

    # label_counter is an int (immutable), so it must be returned; the
    # dicts are mutated in place but returned too for caller convenience.
    return label_mapping, label_inverse_mapping, label_frequency, label_counter
79
+
80
# Walk a directory tree and build label statistics from every .ldr file.
def process_all_ldr_in_folder(folder_path):
    """Aggregate part-label tables over all .ldr files under *folder_path*.

    The tree is searched recursively; the extension match is
    case-insensitive. Unreadable files are reported and skipped.
    Returns (label_mapping, label_inverse_mapping, label_frequency).
    """
    mapping = {}
    inverse_mapping = {}
    frequency = {}
    next_label = 0

    for root, _dirs, filenames in os.walk(folder_path):
        for name in filenames:
            if not name.lower().endswith('.ldr'):
                continue
            file_path = os.path.join(root, name)
            print(f"正在处理: {file_path}")

            lines = readTextFile(file_path)
            if lines is None:
                # Could not read this file: report and move on.
                print(f"⚠️ 无法读取文件 {file_path},已跳过")
                continue

            mapping, inverse_mapping, frequency, next_label = process_ldr_data(
                lines, mapping, inverse_mapping, frequency, next_label
            )

    return mapping, inverse_mapping, frequency
102
+
103
# Save the mapping tables and a usage-sorted frequency report as JSON.
def save_results(label_mapping, label_inverse_mapping, label_frequency, output_dir):
    """Write label_mapping.json, label_inverse_mapping.json and
    label_frequency.json (records sorted by usage count, descending)
    into *output_dir*, creating the directory if necessary.
    """
    os.makedirs(output_dir, exist_ok=True)

    def dump_json(basename, payload):
        # Shared writer: UTF-8, pretty-printed, CJK text kept unescaped.
        with open(os.path.join(output_dir, basename), 'w', encoding='utf-8') as f:
            json.dump(payload, f, indent=4, ensure_ascii=False)

    # Forward and inverse label tables.
    dump_json('label_mapping.json', label_mapping)
    dump_json('label_inverse_mapping.json', label_inverse_mapping)

    # One record per label, ordered by usage count from high to low.
    frequency_list = sorted(
        (
            {
                "label_id": label_id,
                "part_name": label_inverse_mapping.get(label_id, "未知零件"),
                "usage_count": count,
            }
            for label_id, count in label_frequency.items()
        ),
        key=lambda entry: entry["usage_count"],
        reverse=True,
    )

    dump_json('label_frequency.json', frequency_list)
130
+
131
# Script entry point.
def main():
    """Build the label tables from the input LDR folder and write the
    JSON results (mapping, inverse mapping, sorted frequency table)."""
    input_folder = '/public/home/wangshuo/gap/assembly/data/car_1k/subset_self/ldr_l30_rotrans_expand_wom'
    output_folder = '/public/home/wangshuo/gap/assembly/data/car_1k/subset_self'

    mapping, inverse_mapping, frequency = process_all_ldr_in_folder(input_folder)
    save_results(mapping, inverse_mapping, frequency, output_folder)

    print(f"\n✅ 处理完成!结果已保存到: {output_folder}")
    print(f"📊 统计摘要:")
    print(f" - 总唯一标签数: {len(mapping)}")
    print(f" - 总使用次数: {sum(frequency.values())}")
    print(f" - label_frequency.json已按使用频率从高到低排序")


if __name__ == "__main__":
    main()