Spaces:
Running
Running
| import numpy as np | |
| import torch | |
| # TODO: consider using vesin | |
| from matscipy.neighbours import neighbour_list | |
| from torch_geometric.data import Data | |
| from ase import Atoms | |
| from ase.calculators.singlepoint import SinglePointCalculator | |
def get_neighbor(
    atoms: Atoms, cutoff: float, self_interaction: bool = False
):
    """Build the neighbor list of ``atoms`` and return it as edge arrays.

    Args:
        atoms: Structure whose ``positions``, ``cell`` and ``pbc`` are handed
            to matscipy's ``neighbour_list``.
        cutoff: Cutoff forwarded unchanged to ``neighbour_list``.
        self_interaction: When False (the default), edges joining an atom to
            itself with an all-zero shift vector are removed. Edges from an
            atom to one of its own periodic images (non-zero shift) are
            always kept.

    Returns:
        A tuple ``(edge_index, edge_shift)``. ``edge_index`` is the int64
        array made by stacking the first- and second-atom index arrays, and
        ``edge_shift`` is the product of the shift vectors ``S`` with the
        cell matrix.
    """
    lattice = atoms.cell.array
    first, second, shifts = neighbour_list(
        quantities="ijS",
        pbc=atoms.pbc,
        cell=lattice,
        positions=atoms.positions,
        cutoff=cutoff,
    )

    if not self_interaction:
        # Keep an edge if it joins two different atoms OR crosses a periodic
        # boundary; only genuine zero-shift self-edges are discarded.
        keep = (first != second) | np.any(shifts != 0, axis=1)
        first, second, shifts = first[keep], second[keep], shifts[keep]

    edge_index = np.stack((first, second)).astype(np.int64)
    edge_shift = np.dot(shifts, lattice)
    return edge_index, edge_shift
def _per_atom_tensor(values, natoms: int) -> torch.Tensor:
    """Return ``values`` as a tensor, or ``natoms`` NaNs if every value is zero."""
    if values.any():
        # Convert to the same dtype as the NaN placeholder below so that
        # torch.cat receives tensors only and the result dtype does not
        # depend on which structures happen to carry values.
        return torch.as_tensor(values, dtype=torch.get_default_dtype())
    return torch.full((natoms,), torch.nan)


def collate_fn(batch: list[Atoms], cutoff: float) -> Data:
    """Collate a list of Atoms objects into a single batched ``Data`` graph.

    For every structure the neighbor list is computed with ``get_neighbor``
    (self-interaction disabled) and its edge indices are shifted by the
    number of atoms that precede it in the batch, so the individual graphs
    stay disconnected after concatenation.

    Args:
        batch: Structures to merge. Must be non-empty; ``torch.stack`` /
            ``torch.cat`` raise on an empty list.
        cutoff: Neighbor-list cutoff forwarded to ``get_neighbor`` and stored
            in the result as a tensor.

    Returns:
        A ``Data`` object with the fields ``cell`` (stacked per graph),
        ``positions``, ``edge_index``, ``edge_shift``, ``numbers``,
        ``num_nodes`` (total atom count), ``batch`` (graph id of every atom),
        ``charges``, ``magmoms``, ``natoms`` (atom count per graph) and
        ``cutoff``. ``charges`` / ``magmoms`` hold the initial charges /
        magnetic moments of each structure; a structure whose values are all
        zero contributes NaNs instead.
    """
    # Running atom count, used to offset the edge indices of each graph.
    offset = 0

    node_batch = []
    numbers_batch = []
    positions_batch = []
    charges_batch = []
    magmoms_batch = []
    edge_index_batch = []
    edge_shift_batch = []
    cell_batch = []
    natoms_batch = []

    for graph_id, atoms in enumerate(batch):
        natoms = len(atoms)

        edge_index, edge_shift = get_neighbor(
            atoms, cutoff=cutoff, self_interaction=False
        )
        edge_index += offset
        edge_index_batch.append(torch.tensor(edge_index))
        edge_shift_batch.append(torch.tensor(edge_shift))
        offset += natoms

        node_batch.append(torch.full((natoms,), graph_id, dtype=torch.long))
        natoms_batch.append(natoms)

        cell_batch.append(torch.tensor(atoms.cell.array))
        numbers_batch.append(torch.tensor(atoms.numbers))
        positions_batch.append(torch.tensor(atoms.positions))

        charges_batch.append(
            _per_atom_tensor(atoms.get_initial_charges(), natoms)
        )
        magmoms_batch.append(
            _per_atom_tensor(atoms.get_initial_magnetic_moments(), natoms)
        )

    arrays_batch_concatenated = {
        "cell": torch.stack(cell_batch, dim=0),
        "positions": torch.cat(positions_batch, dim=0),
        "edge_index": torch.cat(edge_index_batch, dim=1),
        "edge_shift": torch.cat(edge_shift_batch, dim=0),
        "numbers": torch.cat(numbers_batch, dim=0),
        "num_nodes": offset,
        "batch": torch.cat(node_batch, dim=0),
        "charges": torch.cat(charges_batch, dim=0),
        "magmoms": torch.cat(magmoms_batch, dim=0),
        "natoms": torch.tensor(natoms_batch, dtype=torch.long),
        "cutoff": torch.tensor(cutoff),
    }

    # TODO: custom fields

    return Data.from_dict(arrays_batch_concatenated)
def decollate_fn(batch_data: Data) -> list[Atoms]:
    """Decollate a batched Data object into a list of individual Atoms objects.

    Atoms are assigned to graphs through ``batch_data.batch``; per-graph
    fields (``cell``, ``energy``, ``stress``) are indexed by the graph id,
    per-atom fields (``numbers``, ``positions``, ``forces``) by the atoms
    belonging to that graph.

    Args:
        batch_data: Batched graph with at least ``batch``, ``cell``,
            ``numbers`` and ``positions``. ``energy``, ``forces`` and
            ``stress`` are optional. ``energy`` must hold a single value per
            graph because it is converted to a Python float.

    Returns:
        One ``Atoms`` per graph id found in ``batch_data.batch``, in ascending
        id order. Each carries a ``SinglePointCalculator`` holding whichever
        of energy / forces / stress were present. Periodic boundary
        conditions are not passed to the ``Atoms`` constructor.
    """
    # FIXME: this function is not working properly when the batch_data is on GPU.
    # TODO: create a new Cell class using torch tensor to handle device placement.
    # As a temporary fix, detach the batch_data from the GPU and move it to CPU.
    batch_data = batch_data.detach().cpu()

    individual_entries = []

    for graph_id in batch_data.batch.unique(sorted=True).tolist():
        # Indices of the atoms that belong to the current graph.
        entry_indices = (batch_data.batch == graph_id).nonzero(as_tuple=True)[0]

        atoms = Atoms(
            cell=batch_data.cell[graph_id],
            positions=batch_data.positions[entry_indices],
            numbers=batch_data.numbers[entry_indices],
        )

        # Optional fields. Energy is converted to a plain float because the
        # calculator stores it as given, whereas array-valued results are
        # converted to NumPy by the calculator itself.
        energy = (
            float(batch_data.energy[graph_id]) if "energy" in batch_data else None
        )
        forces = batch_data.forces[entry_indices] if "forces" in batch_data else None
        stress = batch_data.stress[graph_id] if "stress" in batch_data else None
        # TODO: charges, magmoms, dipoles and custom fields

        # SinglePointCalculator requires the structure as its first
        # positional argument; the results are passed as keywords.
        atoms.calc = SinglePointCalculator(
            atoms,
            energy=energy,
            forces=forces,
            stress=stress,
        )

        individual_entries.append(atoms)

    return individual_entries