0xZohar commited on
Commit
e488d41
·
verified ·
1 Parent(s): 398283d

Upload code/cube3d/training/dataset.py

Browse files
Files changed (1) hide show
  1. code/cube3d/training/dataset.py +245 -0
code/cube3d/training/dataset.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import json
4
+
5
def read_ldr_file(file_path):
    """Read an LDraw (.ldr) file and return its raw lines.

    Args:
        file_path: Path of the .ldr file on disk.

    Returns:
        List of strings, one per line, trailing newlines preserved.
    """
    with open(file_path, 'r') as f:
        return f.readlines()
10
+
11
def parse_ldr_lines(lines):
    """Extract part-reference lines from raw LDraw file lines.

    In the LDraw format the first token of a line is its type: "1"
    introduces a sub-file (part) reference, while "0" lines are
    comments/meta commands.  Only type-1 lines are kept; everything else
    is ignored.

    Args:
        lines: Iterable of raw lines, e.g. from read_ldr_file().

    Returns:
        List of stripped part-reference lines, in file order.
    """
    # NOTE(review): startswith('1') would also match a first token like
    # "16"; valid LDraw line types are single digits, so this is kept
    # as-is to preserve the original behavior.
    return [line.strip() for line in lines if line.startswith('1')]
22
+
23
class SingLegoDataset:
    """Dataset wrapping a single pre-computed LEGO latent (.npy) sample.

    Loads one "Car Arcade_wrdhot.npy" array from args.data_dir and exposes
    it as a one-item dataset of {'prompt', 'latent'} dictionaries.
    """

    def __init__(self, args, split_set="train"):
        super().__init__()

        self.split_set = split_set
        # Single .npy file; wrapped in a list so the dataset has length 1.
        data = np.load(os.path.join(args.data_dir, "Car Arcade_wrdhot" + ".npy"), allow_pickle=True)
        self.data = [data]

        # BUGFIX: __getitem__ reads self.prompts, but the original code left
        # the prompt loading commented out, so indexing always raised
        # AttributeError.  Fall back to empty prompts of matching length;
        # re-enable the json load below to restore real captions.
        # self.prompts = json.load(open(os.path.join(args.data_dir, "text.json"), 'r'))['minecraft']
        self.prompts = ["" for _ in self.data]
        print(f"{split_set} dataset total data samples: {len(self.data)}")

    def __len__(self):
        """Number of samples (always 1 as loaded by __init__)."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return sample idx as {'prompt': str, 'latent': ndarray}."""
        data_dict = {}
        data_dict['prompt'] = self.prompts[idx]
        data_dict['latent'] = self.data[idx]
        return data_dict
50
+
51
+
52
class LegosDataset:
    """Training dataset of tokenized LEGO assemblies with captions and bboxes.

    Loads per-model token sequences ("train_1k.npz"), dense captions and
    bounding-box coordinates from args.data_dir, pads every sequence to a
    fixed length, and serves {'prompt', 'target', 'bbox'} dictionaries.
    """

    def __init__(self, args, split_set="train"):
        super().__init__()

        self.max_num_tokens = 410  # fixed sequence length after padding
        self.perm_num = -1         # >0 enables brick-order permutation augmentation
        self.split_set = split_set
        data = np.load(os.path.join(args.data_dir, "train_1k.npz"), allow_pickle=True)['data']
        prompts = json.load(open(os.path.join(args.data_dir, "dense_captions", "dense_captions_rmthan300.json"), 'r'))['Car']
        bboxs = np.load(os.path.join(args.data_dir, "all_coordinates_train.npy"), allow_pickle=True)
        self.data, self.prompts, self.bboxs = self.process_data(data, prompts, bboxs)

        print(f"{split_set} dataset total data samples: {len(self.data)}")

    def padding_latent(self, data, max_len=300):
        """Pad a latent token array with value 16386 up to max_len rows.

        NOTE(review): presumably 16386 is the latent pad-token id —
        confirm against the tokenizer vocabulary.
        """
        pad_data = np.pad(data, ((0, max_len - data.shape[0]), (0, 0)), 'constant', constant_values=16386)
        return pad_data

    def padding(self, data, max_len=300):
        """Pad a (n, d) token array with -1 rows up to max_len rows.

        Padded rows get their last column set to 1 (padding-flag label)
        and their second-to-last column set to 0.
        """
        pad_data = np.pad(data, ((0, max_len - data.shape[0]), (0, 0)), 'constant', constant_values=-1)
        # BUGFIX: the original sliced pad_data[data.shape[0] - max_len:],
        # which for data.shape[0] == max_len becomes [0:] and clobbers the
        # flag columns of every real row.  Slicing from data.shape[0]
        # marks only the padded rows and is identical for shorter inputs.
        pad_data[data.shape[0]:, -1] = 1  # padding-flag label
        pad_data[data.shape[0]:, -2] = 0
        return pad_data

    def permute(self, data, n_permutations=3):
        """Return the original plus n_permutations-1 random row orderings."""
        return [data] + [data[np.random.permutation(len(data))] for _ in range(n_permutations - 1)]

    def process_data(self, data, prompts, bboxs):
        """Pad (and optionally permute) every sample; align prompts/bboxes.

        Returns:
            (padded samples list, prompts list, bboxes as one ndarray).
        """
        processed_data, processed_prompts, processed_bboxs = [], [], []

        for i in range(len(data)):
            if self.perm_num > 0:
                # Augmentation: each permuted copy reuses the same caption/bbox.
                permuted_samples = self.permute(data[i], self.perm_num)
                processed_data.extend([self.padding(p, self.max_num_tokens) for p in permuted_samples])
                processed_prompts.extend([prompts[i]] * self.perm_num)
                processed_bboxs.extend([bboxs[i]] * self.perm_num)
            else:
                processed_data.append(self.padding(data[i], self.max_num_tokens))
                processed_prompts.append(prompts[i])
                processed_bboxs.append(bboxs[i])

        return processed_data, processed_prompts, np.array(processed_bboxs)

    def __len__(self):
        """Number of (possibly augmented) samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return sample idx as {'prompt', 'target', 'bbox'}."""
        data_dict = {}
        data_dict['prompt'] = self.prompts[idx]
        data_dict['target'] = self.data[idx]
        data_dict['bbox'] = self.bboxs[idx]
        return data_dict
126
+
127
class LegosTestDataset:
    """Test-split counterpart of LegosDataset.

    Loads "test_1k.npz" sequences, the same dense captions, and the
    test-split bounding boxes; pads sequences to a fixed length and serves
    {'prompt', 'target', 'bbox'} dictionaries.
    """

    def __init__(self, args, split_set="test"):
        super().__init__()

        self.max_num_tokens = 410  # fixed sequence length after padding
        self.perm_num = -1         # >0 enables brick-order permutation augmentation
        self.split_set = split_set
        data = np.load(os.path.join(args.data_dir, "test_1k.npz"), allow_pickle=True)['data']
        prompts = json.load(open(os.path.join(args.data_dir, "dense_captions", "dense_captions_rmthan300.json"), 'r'))['Car']
        bboxs = np.load(os.path.join(args.data_dir, "all_coordinates_test.npy"), allow_pickle=True)
        self.data, self.prompts, self.bboxs = self.process_data(data, prompts, bboxs)

        print(f"{split_set} dataset total data samples: {len(self.data)}")

    def padding_latent(self, data, max_len=300):
        """Pad a latent token array with value 16386 up to max_len rows.

        NOTE(review): presumably 16386 is the latent pad-token id —
        confirm against the tokenizer vocabulary.
        """
        pad_data = np.pad(data, ((0, max_len - data.shape[0]), (0, 0)), 'constant', constant_values=16386)
        return pad_data

    def padding(self, data, max_len=300):
        """Pad a (n, d) token array with -1 rows up to max_len rows.

        Padded rows get their last column set to 1 (padding-flag label)
        and their second-to-last column set to 0.
        """
        pad_data = np.pad(data, ((0, max_len - data.shape[0]), (0, 0)), 'constant', constant_values=-1)
        # BUGFIX: the original sliced pad_data[data.shape[0] - max_len:],
        # which for data.shape[0] == max_len becomes [0:] and clobbers the
        # flag columns of every real row.  Slicing from data.shape[0]
        # marks only the padded rows and is identical for shorter inputs.
        pad_data[data.shape[0]:, -1] = 1  # padding-flag label
        pad_data[data.shape[0]:, -2] = 0
        return pad_data

    def permute(self, data, n_permutations=3):
        """Return the original plus n_permutations-1 random row orderings."""
        return [data] + [data[np.random.permutation(len(data))] for _ in range(n_permutations - 1)]

    def process_data(self, data, prompts, bboxs):
        """Pad (and optionally permute) every sample; align prompts/bboxes.

        Returns:
            (padded samples list, prompts list, bboxes as one ndarray).
        """
        processed_data, processed_prompts, processed_bboxs = [], [], []

        for i in range(len(data)):
            if self.perm_num > 0:
                # Augmentation: each permuted copy reuses the same caption/bbox.
                permuted_samples = self.permute(data[i], self.perm_num)
                processed_data.extend([self.padding(p, self.max_num_tokens) for p in permuted_samples])
                processed_prompts.extend([prompts[i]] * self.perm_num)
                processed_bboxs.extend([bboxs[i]] * self.perm_num)
            else:
                processed_data.append(self.padding(data[i], self.max_num_tokens))
                processed_prompts.append(prompts[i])
                processed_bboxs.append(bboxs[i])

        return processed_data, processed_prompts, np.array(processed_bboxs)

    def __len__(self):
        """Number of samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return sample idx as {'prompt', 'target', 'bbox'}."""
        data_dict = {}
        data_dict['prompt'] = self.prompts[idx]
        data_dict['target'] = self.data[idx]
        data_dict['bbox'] = self.bboxs[idx]
        return data_dict
202
+
203
class CubeDataset:
    """Dataset of pre-computed latents stored in a per-split .npz archive.

    Every entry of "<split>.npz" in args.data_dir becomes one sample;
    captions from "text.json" are loaded into self.prompts but are not
    currently returned by __getitem__.
    """

    def __init__(self, args, split_set="train"):
        super().__init__()

        self.split_set = split_set

        archive = np.load(os.path.join(args.data_dir, split_set + ".npz"), allow_pickle=True)
        self.data = [archive[key] for key in archive.files]
        self.prompts = json.load(open(os.path.join(args.data_dir, "text.json"), 'r'))['minecraft']
        print(f"{split_set} dataset total data samples: {len(self.data)}")

    def __len__(self):
        """Number of latent samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return sample idx as a single-key dict: {'latent': ndarray}."""
        return {'latent': self.data[idx]}
236
+
237
+
238
if __name__ == "__main__":
    # Smoke-test the LDR helpers: parse a sample model and dump its parts.
    file_path = '/public/home/wangshuo/gap/assembly/data/blue classic car/blue classic car.ldr'
    for part in parse_ldr_lines(read_ldr_file(file_path)):
        print(part)