from tqdm import tqdm
import torch
import dgl
from dgl.data.utils import load_graphs
from torch.utils.data import Dataset, DataLoader
from datasets import util
import random
from torch import FloatTensor
import json


class BaseDataset(Dataset):
    """Dataset of partial-assembly graphs paired with label/negative part embeddings.

    Samples are built eagerly by :meth:`load_dataset` and kept in ``self.data``;
    ``__len__`` and ``__getitem__`` only work after that method has been called.
    """

    # Embedding names accepted by ``load_dataset``; each one maps to the file
    # ``<root_dir>/embedding/embedding_<name>.json``.
    _EMBEDDING_TYPES = ('comp2vec', 'uv')

    def load_dataset(self, file_paths, root_dir, embedding):
        """Read the part embeddings and the dataset file, and build all samples.

        Args:
            file_paths: Path of the dataset JSON file. It is iterated as a
                sequence of records, each with the keys ``"partial assembly"``,
                ``"label"`` and ``"negative"``.
            root_dir: Directory containing the ``embedding`` sub-directory.
                It is combined with ``/``, so it presumably is a
                ``pathlib.Path`` -- confirm against callers.
            embedding: Which embedding file to use: ``'comp2vec'`` or ``'uv'``.

        Raises:
            ValueError: If ``embedding`` is not one of the supported names.
        """
        self.data = []
        if embedding not in self._EMBEDDING_TYPES:
            raise ValueError(f"Unknown embedding type: {embedding}. Expected 'comp2vec' or 'uv'.")
        parts_path = root_dir / 'embedding' / f'embedding_{embedding}.json'

        # JSON is UTF-8 by specification; do not depend on the platform's
        # default locale encoding when decoding it.
        with open(parts_path, "r", encoding="utf-8") as file:
            self.parts_dict = json.load(file)
        with open(file_paths, "r", encoding="utf-8") as file:
            dataset = json.load(file)

        for data_ in tqdm(dataset, desc="Loading graphs"):
            self.data.append(self._build_sample(data_))

    def _build_sample(self, record):
        """Turn one dataset record into a sample dict holding a DGL graph and embeddings."""
        partial_assembly = record["partial assembly"]
        node_names = list(partial_assembly)
        node_name_to_idx = {name: idx for idx, name in enumerate(node_names)}

        # Every (node -> neighbour) pair of the adjacency mapping is one
        # directed edge, expressed with the positional node indices.
        edges_src = []
        edges_dst = []
        for src, neighbours in partial_assembly.items():
            for dst in neighbours:
                edges_src.append(node_name_to_idx[src])
                edges_dst.append(node_name_to_idx[dst])

        num_nodes = len(node_names)
        dgl_graph = dgl.graph((edges_src, edges_dst), num_nodes=num_nodes)

        # The integer after the first underscore of the node name is stored
        # as the node's 'name' feature.
        node_numbers = [int(name.split('_')[1]) for name in node_names]
        dgl_graph.ndata['name'] = torch.tensor(node_numbers)

        # A single-node graph gets a self-loop (presumably so that message
        # passing still has an edge to work with -- confirm with the model).
        if num_nodes == 1:
            dgl_graph.add_edges(0, 0)

        # Use the part embeddings as node features.
        dgl_graph.ndata['embedding'] = torch.stack(
            [torch.tensor(self.parts_dict[name]) for name in node_names]
        )

        # Replace the label and the negatives with their part embeddings.
        label_name = record["label"]
        label_embedding = torch.tensor(self.parts_dict[label_name])
        negative_embeddings = [
            torch.tensor(self.parts_dict[name]) for name in record["negative"]
        ]

        return {
            "partial assembly": dgl_graph,
            "label": label_embedding,
            "negative": negative_embeddings,
            "label_id": int(label_name.split("_")[1]),
        }

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the sample dict stored at position ``idx``."""
        return self.data[idx]

    def _collate(self, batch):
        """Merge a list of samples into one batch dict.

        Returns a dict with:
            ``"partial assemblies"``: the graphs batched with ``dgl.batch``.
            ``"batched_labels"``: label embeddings stacked along dim 0.
            ``"batched_negative"``: the negative embeddings of all samples,
                flattened in sample order and stacked along dim 0.
            ``"neg_idx"``: for every row of ``"batched_negative"``, the index
                (within the batch) of the sample it belongs to.
            ``"batched_labels_id"``: int64 tensor of the label ids.
        """
        batched_assembly = dgl.batch([data["partial assembly"] for data in batch])
        batched_labels = torch.stack([data["label"] for data in batch])

        negatives = []
        neg_idx = []
        for i, data in enumerate(batch):
            for neg in data["negative"]:
                negatives.append(neg)
                neg_idx.append(i)

        if negatives:
            batched_negative = torch.stack(negatives)
        else:
            # torch.stack rejects an empty list; when no sample in the batch
            # has negatives, return an empty tensor with the labels' dtype
            # and trailing shape instead of crashing.
            batched_negative = batched_labels.new_empty((0, *batched_labels.shape[1:]))

        batched_labels_id = torch.tensor(
            [data["label_id"] for data in batch], dtype=torch.int64
        )
        return {
            "partial assemblies": batched_assembly,
            "batched_labels": batched_labels,
            "batched_negative": batched_negative,
            "neg_idx": torch.tensor(neg_idx, dtype=torch.long),
            "batched_labels_id": batched_labels_id,
        }

    def get_dataloader(self, batch_size=64, shuffle=True, num_workers=0, drop_last=True):
        """Create a ``DataLoader`` over this dataset that uses :meth:`_collate`.

        Args:
            batch_size: Number of samples per batch.
            shuffle: Whether the loader reshuffles the samples.
            num_workers: Number of loader worker processes.
            drop_last: Whether to drop the final incomplete batch. Defaults to
                ``True``; pass ``False`` to keep every sample, e.g. when the
                dataset is smaller than ``batch_size``.
        """
        return DataLoader(
            self,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=self._collate,
            num_workers=num_workers,  # Can be set to non-zero on Linux
            drop_last=drop_last,
        )