| from tqdm import tqdm
|
| import torch
|
| import dgl
|
| from dgl.data.utils import load_graphs
|
| from torch.utils.data import Dataset, DataLoader
|
| from datasets import util
|
| import random
|
| from torch import FloatTensor
|
| import json
|
class BaseDataset(Dataset):
    """Dataset of partial-assembly graphs paired with a target part and negatives.

    Samples are built by :meth:`load_dataset` from two JSON files:

    * an embedding table mapping part name -> numeric vector, and
    * a list of records, each with a ``"partial assembly"`` adjacency dict
      (part name -> list of neighbouring part names), a ``"label"`` part name
      and a ``"negative"`` list of part names.

    Each stored sample is a dict with keys ``"partial assembly"`` (a ``dgl``
    graph), ``"label"`` (label embedding tensor), ``"negative"`` (list of
    negative embedding tensors) and ``"label_id"`` (integer parsed from the
    label's name).

    NOTE(review): there is no ``__init__`` here; ``self.data`` only exists
    once ``load_dataset`` has run — presumably subclasses call it.
    """

    @staticmethod
    def _part_number(name):
        """Return the second ``_``-separated field of a part name as an int.

        For example ``"x_7"`` -> ``7``. Raises ``IndexError``/``ValueError``
        when the name does not follow that pattern.
        """
        return int(name.split("_")[1])

    @staticmethod
    def _embedding_path(root_dir, embedding):
        """Return ``root_dir / "embedding" / <file>`` for the given embedding type.

        Raises:
            ValueError: if ``embedding`` is neither ``'comp2vec'`` nor ``'uv'``.
        """
        if embedding == 'comp2vec':
            file_name = 'embedding_comp2vec.json'
        elif embedding == 'uv':
            file_name = 'embedding_uv.json'
        else:
            raise ValueError(f"Unknown embedding type: {embedding}. Expected 'comp2vec' or 'uv'.")
        return root_dir / 'embedding' / file_name

    def _build_graph(self, partial_assembly):
        """Build a ``dgl`` graph for one partial assembly.

        Node ``i`` is the ``i``-th key of ``partial_assembly`` (dict order).
        Every ``dst`` listed under key ``src`` becomes a directed edge
        ``src -> dst``. Node data written:

        * ``"name"``: integer part number of each node (see ``_part_number``);
        * ``"embedding"``: each node's vector from ``self.parts_dict``, stacked.
        """
        node_names = list(partial_assembly)
        node_index = {name: i for i, name in enumerate(node_names)}

        # Neighbours must themselves be keys of ``partial_assembly``;
        # otherwise the index lookup raises KeyError.
        edges_src, edges_dst = [], []
        for src, neighbours in partial_assembly.items():
            for dst in neighbours:
                edges_src.append(node_index[src])
                edges_dst.append(node_index[dst])

        num_nodes = len(node_names)
        graph = dgl.graph((edges_src, edges_dst), num_nodes=num_nodes)
        graph.ndata['name'] = torch.tensor([self._part_number(n) for n in node_names])

        if num_nodes == 1:
            # A lone part always gets a self-loop so the graph has an edge
            # (presumably so message passing has something to aggregate).
            # NOTE(review): if the adjacency already lists that self-loop it
            # is added a second time — confirm that is intended.
            graph.add_edges(0, 0)

        # torch.stack requires every embedding vector to have the same shape.
        graph.ndata['embedding'] = torch.stack(
            [torch.tensor(self.parts_dict[n]) for n in node_names]
        )
        return graph

    def load_dataset(self, file_paths, root_dir, embedding):
        """Load every record of ``file_paths`` into ``self.data``.

        Also sets ``self.parts_dict`` to the selected embedding table.

        Args:
            file_paths: Path of a single JSON file containing a list of
                records with keys ``"partial assembly"``, ``"label"`` and
                ``"negative"``.
            root_dir: Path-like object supporting ``/`` (e.g.
                ``pathlib.Path``); the embedding table is read from
                ``root_dir / "embedding"``.
            embedding: ``'comp2vec'`` or ``'uv'``; selects the embedding file.

        Raises:
            ValueError: if ``embedding`` is not a supported type.
        """
        self.data = []
        parts_path = self._embedding_path(root_dir, embedding)

        # JSON is UTF-8 by specification; don't rely on the locale default.
        with open(parts_path, "r", encoding="utf-8") as file:
            self.parts_dict = json.load(file)

        with open(file_paths, "r", encoding="utf-8") as file:
            dataset = json.load(file)

        for record in tqdm(dataset, desc="Loading graphs"):
            label_name = record["label"]
            self.data.append({
                "partial assembly": self._build_graph(record["partial assembly"]),
                "label": torch.tensor(self.parts_dict[label_name]),
                "negative": [torch.tensor(self.parts_dict[n]) for n in record["negative"]],
                "label_id": self._part_number(label_name),
            })

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the stored sample dict at ``idx`` (not a copy)."""
        return self.data[idx]

    def _collate(self, batch):
        """Merge a list of samples into one batch dict.

        Returns a dict with:

        * ``"partial assemblies"``: ``dgl.batch`` of the samples' graphs;
        * ``"batched_labels"``: label embeddings stacked along dim 0;
        * ``"batched_negative"``: all samples' negatives, flattened in batch
          order and stacked along dim 0 (empty along dim 0 if there are none);
        * ``"neg_idx"``: int64 tensor where ``neg_idx[k]`` is the batch
          position of the sample that negative ``k`` belongs to;
        * ``"batched_labels_id"``: int64 tensor of the samples' label ids.
        """
        batched_assembly = dgl.batch([sample["partial assembly"] for sample in batch])
        batched_labels = torch.stack([sample["label"] for sample in batch])

        negatives = [neg for sample in batch for neg in sample["negative"]]
        if negatives:
            batched_negative = torch.stack(negatives)
        else:
            # torch.stack rejects an empty list. Negatives come from the same
            # embedding table as labels, so mirror the label's trailing
            # shape and dtype for the empty result.
            batched_negative = batched_labels.new_empty((0, *batched_labels.shape[1:]))

        neg_idx = [i for i, sample in enumerate(batch) for _ in sample["negative"]]
        batched_labels_id = torch.tensor([sample["label_id"] for sample in batch], dtype=torch.int64)

        return {
            "partial assemblies": batched_assembly,
            "batched_labels": batched_labels,
            "batched_negative": batched_negative,
            "neg_idx": torch.tensor(neg_idx, dtype=torch.long),
            "batched_labels_id": batched_labels_id,
        }

    def get_dataloader(self, batch_size=64, shuffle=True, num_workers=0, drop_last=True):
        """Return a ``DataLoader`` over this dataset that uses ``_collate``.

        ``drop_last`` defaults to ``True`` (the previous fixed behaviour). With
        it enabled, a dataset smaller than ``batch_size`` yields no batches;
        pass ``drop_last=False`` (e.g. for evaluation) to keep the final
        partial batch.
        """
        return DataLoader(
            self,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=self._collate,
            num_workers=num_workers,
            drop_last=drop_last,
        )
|
|
|