yxc97's picture
Upload folder using huggingface_hub
62a2f1c verified
import os
import os.path as osp
import numpy as np
import torch
from pytorch_lightning.utilities import rank_zero_warn
from torch_geometric.data import Data, InMemoryDataset, download_url, extract_tar
from tqdm import tqdm
class rMD17(InMemoryDataset):
    """In-memory PyG dataset for the revised MD17 (rMD17) data.

    Each requested molecule is read from its own ``.npz`` file inside the
    downloaded ``rmd17.tar.bz2`` archive and processed into its own ``.pt``
    file. When several molecules are requested, their samples are exposed as
    one concatenated dataset in the order given by ``dataset_arg``.

    Args:
        root: Base directory; files are stored under ``root/<dataset_arg>``.
        transform: Optional transform, forwarded to ``InMemoryDataset``.
        pre_transform: Optional callable applied to every sample before it
            is saved to disk.
        dataset_arg: Comma-separated molecule names (keys of
            ``molecule_files``) or ``"all"`` for every molecule. Required.
        pre_filter: Optional predicate; samples for which it returns a falsy
            value are dropped during processing.

    Raises:
        AssertionError: if ``dataset_arg`` is ``None``.
        KeyError: if ``dataset_arg`` names a molecule not in ``molecule_files``.
    """

    revised_url = ('https://archive.materialscloud.org/record/'
                   'file?filename=rmd17.tar.bz2&record_id=466')

    # Public molecule name -> file name inside ``rmd17/npz_data`` of the archive.
    molecule_files = dict(
        aspirin='rmd17_aspirin.npz',
        azobenzene='rmd17_azobenzene.npz',
        benzene='rmd17_benzene.npz',
        ethanol='rmd17_ethanol.npz',
        malonaldehyde='rmd17_malonaldehyde.npz',
        naphthalene='rmd17_naphthalene.npz',
        paracetamol='rmd17_paracetamol.npz',
        salicylic='rmd17_salicylic.npz',
        toluene='rmd17_toluene.npz',
        uracil='rmd17_uracil.npz',
    )

    available_molecules = list(molecule_files.keys())

    def __init__(self, root, transform=None, pre_transform=None, dataset_arg=None,
                 pre_filter=None):
        assert dataset_arg is not None, (
            "Please provide the desired comma separated molecule(s) through "
            f"'dataset_arg'. Available molecules are {', '.join(rMD17.available_molecules)} "
            "or 'all' to train on the combined dataset."
        )

        if dataset_arg == "all":
            dataset_arg = ",".join(rMD17.available_molecules)
        self.molecules = dataset_arg.split(",")

        # Fail fast with a readable message instead of a bare KeyError raised
        # later from raw_file_names during dataset setup.
        unknown = [mol for mol in self.molecules if mol not in rMD17.molecule_files]
        if unknown:
            raise KeyError(
                f"Unknown rMD17 molecule(s) {unknown}. Available molecules are "
                f"{', '.join(rMD17.available_molecules)}."
            )

        if len(self.molecules) > 1:
            rank_zero_warn(
                "MD17 molecules have different reference energies, "
                "which is not accounted for during training."
            )

        # Each distinct molecule selection gets its own directory, so raw and
        # processed files are not shared between different dataset_arg values.
        super().__init__(osp.join(root, dataset_arg), transform, pre_transform,
                         pre_filter=pre_filter)

        # offsets[i] is the global index of the first sample of molecule i;
        # offsets[-1] is the total number of samples.
        self.offsets = [0]
        self.data_all, self.slices_all = [], []
        for path in self.processed_paths:
            data, slices = torch.load(path)
            self.data_all.append(data)
            self.slices_all.append(slices)
            # Slice tensors hold one boundary more than the number of samples.
            self.offsets.append(len(slices[list(slices.keys())[0]]) - 1 + self.offsets[-1])

    def len(self):
        """Return the total number of samples across all loaded molecules."""
        return sum(len(slices[list(slices.keys())[0]]) - 1 for slices in self.slices_all)

    def get(self, idx):
        """Return sample ``idx`` of the concatenated dataset.

        Finds the molecule whose global index range contains ``idx``, points
        ``self.data`` / ``self.slices`` at that molecule's collated storage,
        and delegates to ``InMemoryDataset.get`` with the local index.
        """
        data_idx = 0
        while data_idx < len(self.data_all) - 1 and idx >= self.offsets[data_idx + 1]:
            data_idx += 1
        # NOTE(review): this rebinds instance attributes on every call, so
        # concurrent get() calls on one shared instance could interfere.
        self.data = self.data_all[data_idx]
        self.slices = self.slices_all[data_idx]
        return super().get(idx - self.offsets[data_idx])

    @property
    def raw_file_names(self):
        """Paths (relative to ``raw_dir``) of the ``.npz`` file per molecule."""
        return [osp.join('rmd17', 'npz_data', rMD17.molecule_files[mol]) for mol in self.molecules]

    @property
    def processed_file_names(self):
        """One processed ``.pt`` file per requested molecule."""
        return [f"rmd17-{mol}.pt" for mol in self.molecules]

    def download(self):
        """Download the rMD17 archive, extract it into ``raw_dir`` and delete it."""
        path = download_url(self.revised_url, self.raw_dir)
        extract_tar(path, self.raw_dir, mode='r:bz2')
        os.unlink(path)

    def process(self):
        """Convert each molecule's ``.npz`` file into a collated ``.pt`` file.

        Every sample stores ``z`` (nuclear charges, shared by all frames),
        ``pos`` (coordinates), ``y`` (energy) and ``dy`` (forces).
        ``pre_filter`` decides whether a sample is kept; ``pre_transform`` is
        applied to the kept samples.
        """
        for path, processed_path in zip(self.raw_paths, self.processed_paths):
            # NpzFile keeps the archive open until closed; the context manager
            # releases it. Indexing it reads each array fully into memory.
            with np.load(path) as data_npz:
                z = torch.from_numpy(data_npz["nuclear_charges"]).long()
                positions = torch.from_numpy(data_npz["coords"]).float()
                energies = torch.from_numpy(data_npz["energies"]).float()
                forces = torch.from_numpy(data_npz["forces"]).float()

            # Assuming "energies" is 1-D (one value per frame) — TODO confirm.
            # Each y then becomes a (1, 1) tensor after the unsqueeze below.
            energies.unsqueeze_(1)
            samples = []
            for pos, y, dy in tqdm(zip(positions, energies, forces), total=energies.size(0)):
                data = Data(z=z, pos=pos, y=y.unsqueeze(1), dy=dy)
                # pre_filter is a predicate: drop the sample when it rejects it
                # (the previous code wrongly replaced the sample with its result).
                if self.pre_filter is not None and not self.pre_filter(data):
                    continue
                if self.pre_transform is not None:
                    data = self.pre_transform(data)
                samples.append(data)

            data, slices = self.collate(samples)
            torch.save((data, slices), processed_path)