GRM / datasets /base.py
Kang2691196427's picture
Upload base.py
0caca2a verified
import json
import random
from pathlib import Path

import dgl
import torch
from dgl.data.utils import load_graphs
from torch import FloatTensor
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

from datasets import util

class BaseDataset(Dataset):
    """Dataset of partial-assembly graphs paired with target/negative part embeddings.

    After ``load_dataset`` has run, each item in ``self.data`` is a dict with:

    * ``"partial assembly"`` -- a ``dgl`` graph with one node per part, carrying
      ``ndata['name']`` (integer part id) and ``ndata['embedding']`` (stacked
      part embeddings);
    * ``"label"`` -- embedding tensor of the target part;
    * ``"negative"`` -- list of embedding tensors of the negative parts;
    * ``"label_id"`` -- integer id of the target part.

    This class defines no ``__init__``. ``self.data`` and ``self.parts_dict``
    are populated by ``load_dataset`` (presumably called from a subclass
    constructor -- confirm against subclasses).
    """

    def load_dataset(self, file_paths, root_dir, embedding):
        """Load samples from a JSON file and replace part names with embeddings.

        Args:
            file_paths: Path to a JSON file containing a list of records. Each
                record has ``"partial assembly"`` (mapping node name -> list of
                neighbour node names), ``"label"`` (a part name) and
                ``"negative"`` (a list of part names).
            root_dir: Directory that contains
                ``embedding/embedding_<embedding>.json``. A ``str`` is converted
                to ``pathlib.Path``; any other object must support ``/``.
            embedding: Embedding type, either ``'comp2vec'`` or ``'uv'``.

        Part names must have an integer as their second ``_``-separated field
        (e.g. ``part_12``); that integer is used as the part id.

        Raises:
            ValueError: If ``embedding`` is not one of the supported types.
        """
        self.data = []
        if embedding == 'comp2vec':
            parts_file = 'embedding_comp2vec.json'
        elif embedding == 'uv':
            parts_file = 'embedding_uv.json'
        else:
            raise ValueError(f"Unknown embedding type: {embedding}. Expected 'comp2vec' or 'uv'.")
        # Accept plain strings as well as Path objects (str does not support `/`).
        if isinstance(root_dir, str):
            root_dir = Path(root_dir)
        parts_path = root_dir / 'embedding' / parts_file

        # JSON is UTF-8 by specification; do not depend on the platform locale.
        with open(parts_path, "r", encoding="utf-8") as file:
            self.parts_dict = json.load(file)  # part name -> embedding values
        with open(file_paths, "r", encoding="utf-8") as file:
            dataset = json.load(file)

        for data_ in tqdm(dataset, desc="Loading graphs"):
            partial_assembly = data_["partial assembly"]
            node_names = list(partial_assembly.keys())
            node_name_to_idx = {name: idx for idx, name in enumerate(node_names)}

            # One directed edge for every (node -> listed neighbour) pair.
            edges_src = []
            edges_dst = []
            for src, neighbours in partial_assembly.items():
                for dst in neighbours:
                    edges_src.append(node_name_to_idx[src])
                    edges_dst.append(node_name_to_idx[dst])

            num_nodes = len(node_names)
            dgl_graph = dgl.graph((edges_src, edges_dst), num_nodes=num_nodes)
            node_numbers = [int(name.split('_')[1]) for name in node_names]
            dgl_graph.ndata['name'] = torch.tensor(node_numbers)
            if num_nodes == 1:
                # Give a lone node a self-loop; presumably so downstream message
                # passing never sees a node with zero in-degree -- confirm.
                dgl_graph.add_edges(0, 0)

            # Node features are the part embeddings, in node-index order.
            # NOTE: torch.stack raises on an empty list, so a record with an
            # empty partial assembly fails here.
            node_features = torch.stack([torch.tensor(self.parts_dict[name]) for name in node_names])
            dgl_graph.ndata['embedding'] = node_features

            # Replace the label and negatives with their part embeddings.
            label_name = data_["label"]
            label_embedding = torch.tensor(self.parts_dict[label_name])
            negative_embeddings = [torch.tensor(self.parts_dict[n]) for n in data_["negative"]]
            self.data.append({
                "partial assembly": dgl_graph,
                "label": label_embedding,
                "negative": negative_embeddings,
                "label_id": int(label_name.split("_")[1]),
            })

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the stored sample dict at ``idx`` (the object itself, not a copy)."""
        sample = self.data[idx]
        return sample

    def _collate(self, batch):
        """Merge a list of samples into one batch dict.

        Returns:
            dict with:
            * ``"partial assemblies"`` -- ``dgl.batch`` of the sample graphs;
            * ``"batched_labels"`` -- stacked label embeddings, one row per sample;
            * ``"batched_negative"`` -- all negatives of all samples stacked in
              order; shape ``(0, *label_shape)`` when the batch has no negatives;
            * ``"neg_idx"`` -- int64 tensor with the batch index of the sample
              each row of ``"batched_negative"`` belongs to;
            * ``"batched_labels_id"`` -- int64 tensor of label part ids.
        """
        batched_assembly = dgl.batch([sample["partial assembly"] for sample in batch])
        batched_labels = torch.stack([sample["label"] for sample in batch])

        negatives = []
        neg_idx = []
        for i, sample in enumerate(batch):
            for neg in sample["negative"]:
                negatives.append(neg)
                neg_idx.append(i)
        if negatives:
            batched_negative = torch.stack(negatives)
        else:
            # torch.stack rejects an empty list; return an empty tensor whose
            # trailing dims/dtype match the labels instead of crashing.
            batched_negative = batched_labels.new_empty((0, *batched_labels.shape[1:]))

        batched_labels_id = torch.tensor([sample["label_id"] for sample in batch], dtype=torch.int64)
        return {"partial assemblies": batched_assembly, "batched_labels": batched_labels, "batched_negative": batched_negative,
                "neg_idx": torch.tensor(neg_idx, dtype=torch.long), "batched_labels_id": batched_labels_id}

    def get_dataloader(self, batch_size=64, shuffle=True, num_workers=0, drop_last=True):
        """Build a ``DataLoader`` over this dataset using ``_collate``.

        Args:
            batch_size: Samples per batch.
            shuffle: Whether to reshuffle every epoch.
            num_workers: Number of loader worker processes.
            drop_last: Drop the final incomplete batch (default ``True``).
        """
        return DataLoader(
            self,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=self._collate,
            num_workers=num_workers,  # Can be set to non-zero on Linux
            drop_last=drop_last,
        )