Spaces:
Sleeping
Sleeping
| """ | |
| Converts the JSON to a graph | |
| """ | |
| import numpy as np | |
| import torch as tr | |
| import math | |
| from torch_geometric.data import Data | |
| from scipy.special import sph_harm | |
| from mendeleev import element | |
| from tqdm import tqdm | |
| from .utils import list_files_in_directory, create_directory_if_not_exists, read_dict_from_json, nan_checker | |
| ## Fundamental graph elements and transformations ## | |
class MaterialMesh(Data):
    """Graph of one material structure together with its regression targets.

    On top of the usual node/edge tensors it carries a global feature vector
    ``u``, a per-edge batch vector ``bond_batch`` and the two targets
    ``onsite`` (per node) and ``hop`` (per edge).
    """

    def __init__(self, x, edge_index, edge_attr, u, bond_batch, hop, onsite):
        """Store every argument as an attribute of the same name.

        :param x: node features
        :param edge_index: edge indices
        :param edge_attr: edge features
        :param u: global features
        :param bond_batch: tells which batch element each edge belongs to
        :param hop: hopping target
        :param onsite: on-site target
        """
        super().__init__()
        # The assignment order is kept unchanged (``onsite`` before ``hop``).
        self.x = x
        self.edge_index = edge_index
        self.edge_attr = edge_attr
        self.u = u
        self.bond_batch = bond_batch
        self.onsite = onsite
        self.hop = hop

    def __cat_dim__(self, key, value, *args, **kwargs):
        """Select the concatenation dimension used when graphs are batched.

        For ``u`` this returns ``None`` so that the global features get an
        extra dimension when batched instead of being concatenated; every
        other key is delegated to the parent class.

        :param key: name of the attribute being batched
        :param value: value of that attribute
        :param args: forwarded to the parent implementation
        :param kwargs: forwarded to the parent implementation
        :return: ``None`` for ``u``, otherwise the parent's answer
        """
        if key != "u":
            return super().__cat_dim__(key, value, *args, **kwargs)
        return None
class MyTensor(tr.Tensor):
    """
    Tensor subclass whose ``max`` tolerates empty tensors.
    It is needed to work with graphs that have no edges.
    """

    def max(self, *args, **kwargs):
        """Return ``0`` for an empty tensor, otherwise defer to ``tr.max``."""
        if tr.numel(self) != 0:
            return tr.max(self, *args, **kwargs)
        # An empty tensor has no maximum; report 0 instead of raising.
        return 0
def f_cut(r, decay_rate=3, cutoff=50):
    """
    Evaluate the cosine cutoff envelope at the given distance(s).

    Parameters:
    r (float or numpy array): Distance value(s).
    decay_rate (float): Accepted for compatibility; not used by the
        current implementation.
    cutoff (float): Radius at and beyond which the result is zero.

    Returns:
    float or numpy array: ``0.5 * (cos(pi * r / cutoff) + 1)`` where
    ``r < cutoff`` and ``0`` elsewhere.
    """
    envelope = 0.5 * (np.cos(r * math.pi / cutoff) + 1.0)
    # Mask that switches off every contribution outside the cutoff radius.
    inside = r < cutoff
    return envelope * inside
def element_to_atomic_number(element_symbol):
    """Look up the atomic number that belongs to ``element_symbol``.

    :param element_symbol: identifier handed to ``element`` for the lookup
    :return: the ``atomic_number`` of the element, or ``None`` when the
        lookup raises ``KeyError``
    """
    try:
        return element(element_symbol).atomic_number
    except KeyError:
        # Unknown element: signal it with None instead of propagating.
        return None
def bessel_distance(c1, c2, n=(1, 2, 3, 4, 5, 6), rc=3):
    """Expand the distance between two points in a sine (Bessel-type) radial basis.

    For every order ``k`` in ``n`` the returned value is
    ``sqrt(2 / rc) * f_cut(r) * sin(k * pi * r / rc) / r`` where ``r`` is the
    Euclidean distance between ``c1`` and ``c2``.

    :param c1: first point; indexable with at least three components (x, y, z)
    :param c2: second point; same layout as ``c1``
    :param n: iterable of basis orders (only iterated, never modified)
    :param rc: radial scale used in the normalisation and in the sine argument
    :return: list with one basis value per entry of ``n``

    The case ``r == 0`` is not special-cased here (the expression divides by
    ``r``); callers are expected to handle coincident points themselves.
    """
    dx = c1[0] - c2[0]
    dy = c1[1] - c2[1]
    dz = c1[2] - c2[2]
    squared = dx ** 2 + dy ** 2 + dz ** 2
    # The distance is the square root of the summed squares.  The previous
    # ``sqrt(squared * squared)`` evaluated to the *squared* distance.
    rij = np.sqrt(squared)
    norm = np.sqrt(2 / rc)
    # NOTE(review): the second positional argument of f_cut is ``decay_rate``,
    # which f_cut does not use, so the envelope falls back to f_cut's default
    # cutoff rather than ``rc``. The call is kept as it was - confirm intent.
    fc = f_cut(rij, rc * 0.5)
    return [norm * fc * np.sin(order * math.pi * rij / rc) / rij for order in n]
def spherical_harmonics(c1, c2, max_l=1):
    """Evaluate real spherical harmonics for the direction ``c1 - c2``.

    :param c1: first point; must support subtraction with ``c2`` and indexing
    :param c2: second point
    :param max_l: exclusive upper bound of the degree loop
    :return: flat list of values ordered by degree ``l`` and then order ``m``

    NOTE(review): the loops run over ``l in range(max_l)`` and
    ``m in range(-l, l)``, so ``m == l`` is never produced and ``l == 0``
    contributes nothing; the list therefore has ``max_l * (max_l - 1)``
    entries. Left unchanged because other code may rely on that length.
    """
    # Move to the centre: only the relative vector matters.
    delta = c1 - c2
    _, theta, phi = cartesian_to_spherical(delta[0], delta[1], delta[2])
    return [
        real_spherical_harmonics(l, m, theta, phi)
        for l in range(max_l)
        for m in range(-l, l)
    ]
def cartesian_to_spherical(x, y, z):
    """Convert Cartesian coordinates to ``(r, theta, phi)``.

    :param x: x component
    :param y: y component
    :param z: z component
    :return: tuple ``(r, theta, phi)`` with ``r`` the Euclidean norm,
        ``theta = arccos(z / r)`` and ``phi = arctan2(y, x)``

    ``r == 0`` is not special-cased, so ``z / r`` is evaluated as is.
    """
    radius = np.sqrt(x ** 2 + y ** 2 + z ** 2)
    azimuth = np.arctan2(y, x)
    polar = np.arccos(z / radius)
    return radius, polar, azimuth
def real_spherical_harmonics(l, m, theta, phi):
    """Build a real-valued harmonic from the complex ``sph_harm`` value.

    :param l: degree, passed to ``sph_harm`` as its second argument
    :param m: order, passed to ``sph_harm`` as its first argument
    :param theta: angle passed to ``sph_harm`` as its fourth argument
    :param phi: angle passed to ``sph_harm`` as its third argument
    :return: ``Re(Y)`` for ``m == 0``, ``sqrt(2) * Re(Y)`` for ``m > 0`` and
        ``sqrt(2) * (-1) ** m * Im(Y)`` otherwise, where ``Y`` is the complex
        value returned by ``sph_harm(m, l, phi, theta)``
    """
    complex_ylm = sph_harm(m, l, phi, theta)
    if m == 0:
        return np.real(complex_ylm)
    if m > 0:
        return np.sqrt(2) * np.real(complex_ylm)
    # Negative order: use the imaginary part with the alternating sign.
    return np.sqrt(2) * (-1) ** m * np.imag(complex_ylm)
def compute_distance_matrix_torch(points):
    """
    Build the pairwise Euclidean distance matrix of a set of points with PyTorch.

    Parameters:
    points (array-like): Coordinates, one point per row; converted to a
        float32 tensor.

    Returns:
    torch.Tensor: Tensor whose entry (i, j) is the distance between
    point i and point j.
    """
    pts = tr.tensor(points, dtype=tr.float32)
    # Broadcasting a column view against a row view yields all pair differences.
    pairwise_diff = pts[:, None] - pts[None]
    squared_norms = tr.sum(pairwise_diff ** 2, dim=-1)
    return tr.sqrt(squared_norms)
def find_indices_in_range(matrix, min_val, max_val):
    """
    Collect the (i, j) positions whose value lies in ``[min_val, max_val]``.

    Parameters:
    matrix (torch.Tensor): A 2D tensor (e.g. a distance matrix).
    min_val (float): Inclusive lower bound.
    max_val (float): Inclusive upper bound.

    Returns:
    list: Tuples ``(i, j)`` of plain Python ints, in the order produced
    by ``torch.nonzero``.
    """
    in_range = (matrix >= min_val) & (matrix <= max_val)
    hits = tr.nonzero(in_range, as_tuple=False)
    # ``tolist`` already yields Python ints; re-pack each row as a tuple.
    return [(row, col) for row, col in hits.tolist()]
| # Build a dataset | |
class MaterialDS(tr.utils.data.Dataset):
    """Dataset that simply wraps an in-memory list of graphs."""

    def __init__(self, graph_list):
        """
        Convert a list of graphs into a dataset.
        :param graph_list: iterable of graphs; it is copied into a new list
        """
        self.data_list = list(graph_list)

    def __len__(self):
        """Return how many graphs are stored."""
        return len(self.data_list)

    def __getitem__(self, idx):
        """Return the graph stored at position ``idx``."""
        return self.data_list[idx]
| ## End: Fundamental graph elements and transformations ## | |
def get_nodes_from_structure(structure):
    """Create one graph node per orbital of every atom in ``structure``.

    :param structure: dict read via ``structure["structure"]["atoms"]`` (each
        atom provides ``"simbol"`` and ``"nr_orbitals"``) and via the square
        matrices ``structure["hmat"]`` and ``structure["smat"]``
    :return: tuple ``(node_features, node_target)`` of float32 tensors; each
        feature row is ``[atomic_number, orbital_index]`` and each target row
        is ``[hmat[k][k] * 100, smat[k][k] * 100]`` for the node's running
        orbital index ``k``
    """
    features = []
    targets = []
    # Running orbital counter: doubles as the diagonal index into hmat/smat.
    diag = 0
    for atom in structure["structure"]["atoms"]:
        for orbit in range(atom["nr_orbitals"]):
            # Positions are deliberately not used as node features.
            row = [element_to_atomic_number(atom["simbol"]), orbit]
            targets.append([structure["hmat"][diag][diag] * 100,
                            structure["smat"][diag][diag] * 100])
            features.append(row)
            diag += 1
    feature_tensor = tr.tensor(features, dtype=tr.float32)
    target_tensor = tr.tensor(targets, dtype=tr.float32)
    return feature_tensor, target_tensor
def get_edges_from_structure(structure, max_r=10):
    """Build the directed orbital-to-orbital edges of a structure.

    Every atom is expanded into ``nr_orbitals`` nodes that share the atom's
    coordinates. Two different nodes are connected when their distance lies in
    ``[0, max_r]``; both directions are produced because the distance matrix
    is scanned entry by entry.

    :param structure: dict read via ``structure["structure"]["atoms"]`` (each
        atom provides ``"xyz"`` and ``"nr_orbitals"``) and via the matrices
        ``structure["hmat"]`` and ``structure["smat"]``
    :param max_r: inclusive upper bound on the distance of connected nodes
    :return: tuple ``(edge_index, edge_props, edge_target)`` where

        * ``edge_index`` is an integer (``long``) tensor with one row of
          source indices and one row of destination indices,
        * ``edge_props`` is a float32 tensor with, per edge, the distance
          followed by 8 radial terms and 42 angular terms,
        * ``edge_target`` is a float32 tensor with, per edge,
          ``[hmat[a][b] * 100, smat[a][b] * 100]``.
    """
    n_radial = 8    # radial basis orders 1..8
    max_l = 7       # degree bound handed to spherical_harmonics
    n_angular = 42  # number of values spherical_harmonics yields for max_l=7

    # Extend atoms to orbitals: one coordinate entry per orbital, so that a
    # node index is also the row/column of that orbital in hmat/smat.
    ext_coordinates = []
    for atom in structure["structure"]["atoms"]:
        for _ in range(atom["nr_orbitals"]):
            ext_coordinates.append(atom["xyz"])

    distance_ = compute_distance_matrix_torch(ext_coordinates)
    pairs = find_indices_in_range(distance_, min_val=0, max_val=max_r)

    edge_index = [[], []]
    edge_props = []
    edge_target = []
    for a, b in pairs:
        if a == b:
            # No self-loops.
            continue
        edge_index[0].append(a)
        edge_index[1].append(b)

        distance = distance_[a][b]
        if distance != 0:
            coord_a = tr.tensor(ext_coordinates[a])
            coord_b = tr.tensor(ext_coordinates[b])
            radial = bessel_distance(coord_a, coord_b,
                                     n=list(range(1, n_radial + 1)))
            angular = spherical_harmonics(coord_a, coord_b, max_l=max_l)
        else:
            # Orbitals on the same atom: the direction is undefined and the
            # radial terms would divide by zero, so use zero features.
            radial = [0] * n_radial
            angular = [0] * n_angular

        edge_props.append([distance, *radial, *angular])
        edge_target.append([structure["hmat"][a][b] * 100,
                            structure["smat"][a][b] * 100])

    edge_props = tr.tensor(edge_props, dtype=tr.float32)
    # Indices must be integers; they were previously stored as float32.
    edge_index = tr.tensor(edge_index, dtype=tr.long)
    edge_target = tr.tensor(edge_target, dtype=tr.float32)
    return edge_index, edge_props, edge_target
def get_global_from_structure(structure):
    """Assemble the global feature vector of a structure.

    :param structure: dict read via ``structure["structure"]["lattice vectors"]``
        (indexed as a 3x3 table) and ``structure["structure"]["atoms"]``
    :return: tensor holding the number of atoms followed by the nine lattice
        vector components in row-major order
    """
    lattice = structure["structure"]['lattice vectors']
    print("lat vectors:", lattice)
    atoms = structure["structure"]["atoms"]
    values = [len(atoms)]
    values.extend(lattice[row][col] for row in range(3) for col in range(3))
    return tr.tensor(values)
def structure_to_graph(structure, radius=100):
    """Turn one structure dict into a ``MaterialMesh`` graph.

    :param structure: structure dict handed to the node, edge and global
        feature builders
    :param radius: maximum distance passed on to ``get_edges_from_structure``
    :return: the assembled ``MaterialMesh`` (it is also printed)
    """
    x, onsite = get_nodes_from_structure(structure)
    edge_index, edge_attr, hop = get_edges_from_structure(structure, radius)
    u = get_global_from_structure(structure)

    # A single graph: every edge gets batch id 0.
    n_edges = edge_index.shape[1]
    bond_batch = MyTensor(np.zeros(n_edges)).long()

    graph = MaterialMesh(
        x=x,
        edge_index=edge_index,
        edge_attr=edge_attr,
        u=u,
        bond_batch=bond_batch,
        hop=hop,
        onsite=onsite,
    )
    print("graph:", graph)
    return graph
def main(files_path, test_ratio, saving_spot, radius):
    """Convert every JSON structure in a directory to graphs and save a train/test split.

    :param files_path: directory that holds the structure JSON files
    :param test_ratio: fraction of the graphs that goes into the test set
    :param saving_spot: output directory; ``train.pt`` and ``test.pt`` are
        written there
    :param radius: maximum distance passed on to ``structure_to_graph``
    :return: ``0``

    The files are used in the order returned by ``list_files_in_directory``;
    no shuffling is performed.
    """
    # Construct the saving spot
    create_directory_if_not_exists(saving_spot)

    # Extract the structures and build the graphs
    files = list_files_in_directory(files_path)
    structures = [read_dict_from_json(f"{files_path}/{st}") for st in files]
    graphs = [structure_to_graph(structure, radius) for structure in tqdm(structures)]

    # Partition so that every graph lands in exactly one of the two sets.
    # The previous slice bounds (``int(1 - n * ratio)`` / ``1 - int(n * ratio)``)
    # were off by one and could leave the train set empty for small inputs.
    total = len(graphs)
    n_test = min(max(int(total * test_ratio), 0), total)
    n_train = total - n_test

    train_ds = MaterialDS(graphs[:n_train])
    tr.save(train_ds, f'{saving_spot}/train.pt')
    test_ds = MaterialDS(graphs[n_train:])
    tr.save(test_ds, f'{saving_spot}/test.pt')
    return 0
if __name__ == "__main__":
    # Default conversion run: BN DFT JSON files -> graph datasets.
    main(
        files_path="DATA/DFT/BN_DFT_JSON",
        test_ratio=0.2,
        saving_spot="DATA/DFT/BN_DFT_GRAPH",
        radius=50,
    )