Spaces:
Sleeping
Sleeping
| import torch | |
| from src.obth_gnn import HGnn | |
| from src.dft_data_to_grphs import MaterialDS, MaterialMesh, MyTensor | |
| from src.utils import generate_heatmap | |
| from torch_geometric.loader import DataLoader | |
| import json | |
# Everything in this module (checkpoint loading, model placement) targets the CPU.
device = torch.device("cpu")

# Model label accepted by compute_mat -> checkpoint stem loaded from Models/<stem>.pt.
model_name_map = {
    "wizard_tb 0.3": "demo_model_1",
    "wizard 0.3": "demo_model_0",
}

# Dataset label -> position of that graph in DATA/demo-graph/train.pt
# (iteration order of an unshuffled, batch_size=1 DataLoader).
data_map = {label: position for position, label in enumerate(("aBN_01", "aBN_02"))}
def output_to_matrix(hii, hij, ij):
    """Assemble two dense square matrices from per-node and per-edge outputs.

    The two matrices are named "h" and "s" (presumably Hamiltonian and
    overlap -- confirm against the model definition).

    Args:
        hii: sequence of length N. Element k supplies the diagonal entries
            h[k, k] = hii[k][0] and s[k, k] = hii[k][1].
        hij: sequence of length E. Element e supplies the entries at
            (ij[0][e], ij[1][e]): h gets hij[e][0], s gets hij[e][1].
        ij: pair of index sequences (rows, cols), each indexable for every
            position of hij.

    Returns:
        (h, s) as NumPy arrays of shape (N, N), detached from autograd, in
        torch's default dtype. Entries not written stay zero. When a
        (row, col) pair repeats in ij, the later edge overwrites the earlier.
    """
    size = len(hii)
    h_out = torch.zeros([size, size])
    s_out = torch.zeros([size, size])

    # Diagonal: one entry pair per node.
    for node, onsite in enumerate(hii):
        h_out[node, node] = onsite[0]
        s_out[node, node] = onsite[1]

    # Off-diagonal: sequential writes so that later duplicates win.
    for edge, hop in enumerate(hij):
        row, col = ij[0][edge], ij[1][edge]
        h_out[row, col] = hop[0]
        s_out[row, col] = hop[1]

    return h_out.detach().numpy(), s_out.detach().numpy()
def plot_mat(mat, save_path="OutputFiles/img/mat.jpg", *, grid1_step=1, grid2_step=13):
    """Draw ``mat`` with ``generate_heatmap`` and return its result.

    Args:
        mat: matrix to draw (compute_mat passes the NumPy arrays produced by
            output_to_matrix).
        save_path: forwarded as generate_heatmap's second positional argument,
            presumably the image output path -- confirm in src.utils. The
            default is the previously hard-coded path, so calls that rely on
            it target the same file.
        grid1_step: forwarded unchanged to generate_heatmap.
        grid2_step: forwarded unchanged to generate_heatmap. NOTE(review): 13 is
            presumably the number of orbitals per atom block -- confirm.

    Returns:
        Whatever generate_heatmap returns (presumably a figure object).
    """
    return generate_heatmap(mat, save_path, grid1_step=grid1_step, grid2_step=grid2_step)
def text_out(h_mat, s_mat, model_name, data_name):
    """Write both matrices to ``OutputFiles/<data_name>_<model_name>.json``.

    Args:
        h_mat, s_mat: objects exposing ``.tolist()`` (NumPy arrays in
            compute_mat); stored under the keys "h_mat" and "s_mat".
        model_name, data_name: labels interpolated verbatim into the file
            name. Callers are expected to pass validated labels.

    Returns:
        The path of the written JSON file, as a string.
    """
    import os

    out_dir = "OutputFiles"
    file_content = {"h_mat": h_mat.tolist(), "s_mat": s_mat.tolist()}
    file_name = f"{out_dir}/{data_name}_{model_name}.json"
    # Create the output directory on first use instead of failing with
    # FileNotFoundError on a fresh checkout.
    os.makedirs(out_dir, exist_ok=True)
    with open(file_name, "w", encoding="utf-8") as json_file:
        json.dump(file_content, json_file, indent=4)
    return file_name
def _load_model(label):
    """Build the HGnn for ``label``, load its checkpoint and put it in eval mode.

    Raises:
        ValueError: if no architecture is defined here for ``label``.
    """
    if label != "wizard_tb 0.3":
        # NOTE(review): "wizard 0.3" is listed in model_name_map, but no
        # architecture for it is defined here yet.
        raise ValueError(f"Model {label!r} is not in the list of available models")
    model = HGnn(edge_shape=51,
                 node_shape=2,
                 u_shape=10,
                 embed_size=[20, 20, 10],
                 ham_graph_emb=[7, 7, 7],
                 n_blocks=3)
    # NOTE(review): torch.load unpickles its input; only point it at trusted local files.
    state = torch.load(f"Models/{model_name_map[label]}.pt", map_location=device)
    model.load_state_dict(state)
    model.to(device)
    # Inference: disable training-only behaviour (dropout, batch-norm updates) if HGnn has any.
    model.eval()
    return model


def _load_sample(data_key):
    """Return the demo graph selected by ``data_key``, collated as a batch of one.

    Raises:
        KeyError: if ``data_key`` is not in data_map (checked before loading the file).
    """
    index = data_map[data_key]
    dataset = torch.load("DATA/demo-graph/train.pt")
    # Collate only the requested graph. This gives the same batch as the
    # index-th item of DataLoader(dataset, batch_size=1, shuffle=False),
    # without collating every other graph first.
    loader = DataLoader([dataset[index]], batch_size=1, shuffle=False)
    return next(iter(loader))


def compute_mat(data_name, model_name):
    """Run the selected model on the selected demo graph and export the result.

    Args:
        data_name: sequence whose first element is a key of data_map.
        model_name: sequence whose first element is a model label. Only
            "wizard_tb 0.3" currently has an architecture defined.

    Returns:
        (h_plot, s_plot, file_name): plot_mat results for the h and s
        matrices, and the JSON path returned by text_out.

    Raises:
        ValueError: unsupported model label (previously this only printed a
            message and then crashed on an undefined ``model``).
        KeyError: data_name[0] is not in data_map.
    """
    print("model_name", model_name)
    model_key = model_name[0]
    model = _load_model(model_key)

    print("data_name", data_name)
    data_key = data_name[0]
    sample = _load_sample(data_key)

    # Inference only: no autograd graph is needed for the outputs.
    with torch.no_grad():
        hii, hij, ij = model(
            sample.x.to(device=device, dtype=torch.float32),
            sample.edge_index.to(device=device, dtype=torch.int64),
            sample.edge_attr.to(device=device, dtype=torch.float32),
            sample.u.to(device=device, dtype=torch.float32),
            sample.batch.to(device),
            sample.bond_batch.to(device),
        )

    h_mat, s_mat = output_to_matrix(hii, hij, ij)
    # NOTE(review): plot_mat is called with its fixed default image path, so the
    # s plot may overwrite the h image on disk -- confirm in generate_heatmap.
    h_plot = plot_mat(h_mat)
    s_plot = plot_mat(s_mat)
    file_rsp = text_out(h_mat, s_mat, model_key, data_key)
    return h_plot, s_plot, file_rsp
def upload_struct():
    """Placeholder that is not implemented yet: it does nothing and returns None.

    NOTE(review): the name suggests uploading a user-supplied structure --
    confirm the intended behaviour before implementing.
    """
    return None