Spaces:
Sleeping
Sleeping
| """ | |
| Utility functions for managing files and plots. | |
| """ | |
| import os | |
| import json | |
| import matplotlib.pyplot as plt | |
| from matplotlib.colors import TwoSlopeNorm | |
| import numpy as np | |
| import glob | |
| import math | |
| import torch | |
| ## Save spot ## | |
| def save_spot(exp_name, spot_nr, model, data): | |
| # Create directory | |
| create_directory_if_not_exists("EXPERIMENTS") | |
| create_directory_if_not_exists(f"EXPERIMENTS/{exp_name}") | |
| create_directory_if_not_exists(f"EXPERIMENTS/{exp_name}/spot{spot_nr}") | |
| path = f"EXPERIMENTS/{exp_name}/spot{spot_nr}" | |
| path_img = os.path.join(path, "img") | |
| create_directory_if_not_exists(path_img) | |
| path_txt = os.path.join(path, "txt") | |
| create_directory_if_not_exists(path_txt) | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| model.to(device) | |
| # Save model | |
| torch.save(model.state_dict(), f"{path}/model.pt") | |
| # Save data | |
| for ko, inputs in enumerate(data): | |
| inputs = inputs.to(device) | |
| x = inputs.x.to(torch.float32) | |
| edge_index = inputs.edge_index.to(torch.int64) | |
| edge_attr = inputs.edge_attr.to(torch.float32) | |
| state = inputs.u.to(torch.float32) | |
| batch = inputs.batch | |
| bond_batch = inputs.bond_batch | |
| with torch.no_grad(): | |
| hii, hij, ij = model(x, edge_index, edge_attr, state, batch, bond_batch) | |
| # Move tensors to CPU for further processing and numpy conversion | |
| hii = hii.cpu() | |
| hij = hij.cpu() | |
| ij = ij.cpu() | |
| pred_mat_r = torch.zeros([len(hii), len(hii)]) | |
| pred_mat_i = torch.zeros([len(hii), len(hii)]) | |
| for i, hi in enumerate(hii): | |
| pred_mat_r[i][i] = hi[0] | |
| pred_mat_i[i][i] = hi[1] | |
| for i, hx in enumerate(hij): | |
| pred_mat_r[ij[0][i]][ij[1][i]] = hx[0] | |
| pred_mat_i[ij[0][i]][ij[1][i]] = hx[1] | |
| target_mat_r = torch.zeros([len(hii), len(hii)]) | |
| target_mat_i = torch.zeros([len(hii), len(hii)]) | |
| for i, hi in enumerate(inputs.onsite): | |
| target_mat_r[i][i] = hi[0] | |
| target_mat_i[i][i] = hi[1] | |
| for i, hx in enumerate(inputs.hop): | |
| target_mat_r[ij[0][i]][ij[1][i]] = hx[0] | |
| target_mat_i[ij[0][i]][ij[1][i]] = hx[1] | |
| dif_mat_i = target_mat_i - pred_mat_i | |
| dif_mat_r = target_mat_r - pred_mat_r | |
| target_mat_r = target_mat_r.detach().numpy() | |
| pred_mat_r = pred_mat_r.detach().numpy() | |
| dif_mat_r = dif_mat_r.detach().numpy() | |
| dif_mat_i = dif_mat_i.detach().numpy() | |
| pred_mat_i=pred_mat_i.detach().numpy() | |
| target_mat_i=target_mat_i.detach().numpy() | |
| generate_heatmap(target_mat_r, filename=f'{path_img}/{ko}_tar_hmat.png') | |
| generate_heatmap(pred_mat_r, filename=f'{path_img}/{ko}_pred_hmat.png') | |
| generate_heatmap(dif_mat_r, filename=f'{path_img}/{ko}_dif_hmat.png') | |
| generate_heatmap(dif_mat_i, filename=f'{path_img}/{ko}_dif_smat.png') | |
| generate_heatmap(pred_mat_i, filename=f'{path_img}/{ko}_pred_smat.png') | |
| generate_heatmap(target_mat_i, filename=f'{path_img}/{ko}_target_smat.png') | |
| print("Done") | |
| print("max:", dif_mat_r.max()) | |
| print("min:", dif_mat_r.min()) | |
| np.save(os.path.join(path_txt, f'{ko}_dif_mat_hmat.npy'), dif_mat_r) | |
| np.save(os.path.join(path_txt, f'{ko}_target_mat_hmat.npy'), target_mat_r) | |
| np.save(os.path.join(path_txt, f'{ko}_pred_mat_hmat.npy'), pred_mat_r) | |
| np.save(os.path.join(path_txt, f'{ko}_dif_mat_smat.npy'), dif_mat_i) | |
| np.save(os.path.join(path_txt, f'{ko}_target_mat_smat.npy'), target_mat_i) | |
| np.save(os.path.join(path_txt, f'{ko}_pred_mat_smat.npy'), pred_mat_i) | |
| print("Done") | |
| def nan_checker(lst): | |
| """ | |
| Check if there are any NaN values in the list. | |
| Parameters: | |
| lst (list): The list to check for NaN values. | |
| Returns: | |
| bool: True if there is at least one NaN value in the list, False otherwise. | |
| """ | |
| return any(math.isnan(x) for x in lst) | |
| def generate_heatmap(matrix, filename, grid1_step=1, grid2_step=13): | |
| """ | |
| Generate and save a heatmap from a given matrix. | |
| :param matrix: 2D array of data | |
| :param filename: The file name to save the heatmap | |
| :param grid1_step: Step for the first grid (default is 1) | |
| :param grid2_step: Step for the second grid (default is 13) | |
| """ | |
| plt.close() | |
| # Determine the min and max values of the matrix | |
| min_val = np.min(matrix) | |
| if min_val >=0: | |
| min_val=-0.1 | |
| max_val = np.max(matrix) | |
| if max_val <=0: | |
| max_val=+0.1 | |
| # Create the heatmap | |
| fig=plt.figure() | |
| norm = TwoSlopeNorm(vmin=min_val, vcenter=0, vmax=max_val) | |
| plt.imshow(matrix, cmap='seismic', norm=norm, interpolation='nearest') | |
| plt.colorbar() | |
| # Add grids | |
| ax = plt.gca() | |
| ax.set_xticks(np.arange(-0.5, matrix.shape[1], grid1_step), minor=True) | |
| ax.set_yticks(np.arange(-0.5, matrix.shape[0], grid1_step), minor=True) | |
| ax.grid(which='minor', color='gray', linestyle='-', linewidth=0.05) | |
| if grid2_step > 0: | |
| ax.set_xticks(np.arange(-0.5, matrix.shape[1], grid2_step), minor=False) | |
| ax.set_yticks(np.arange(-0.5, matrix.shape[0], grid2_step), minor=False) | |
| ax.grid(which='major', color='gray', linestyle='-', linewidth=0.25) | |
| plt.grid(True) | |
| plt.savefig(filename) | |
| return fig | |
| def list_files_in_directory(directory_path): | |
| """ | |
| List all files in the specified directory. | |
| :param directory_path: Path to the directory | |
| :return: List of file names in the directory | |
| """ | |
| try: | |
| # Get the list of all files and directories | |
| entries = os.listdir(directory_path) | |
| # Filter out only the files | |
| files = [entry for entry in entries if os.path.isfile(os.path.join(directory_path, entry))] | |
| return files | |
| except FileNotFoundError: | |
| return f"The directory {directory_path} does not exist." | |
| except PermissionError: | |
| return f"Permission denied for accessing the directory {directory_path}." | |
| except Exception as e: | |
| return f"An error occurred: {e}" | |
| def list_subdirectories(directory_path): | |
| """ | |
| List all subdirectories in the specified directory. | |
| :param directory_path: Path to the directory | |
| :return: List of subdirectory names in the directory | |
| """ | |
| try: | |
| # Get the list of all files and directories | |
| entries = os.listdir(directory_path) | |
| # Filter out only the subdirectories | |
| subdirectories = [entry for entry in entries if os.path.isdir(os.path.join(directory_path, entry))] | |
| return subdirectories | |
| except FileNotFoundError: | |
| return f"The directory {directory_path} does not exist." | |
| except PermissionError: | |
| return f"Permission denied for accessing the directory {directory_path}." | |
| except Exception as e: | |
| return f"An error occurred: {e}" | |
| def create_directory_if_not_exists(directory_path): | |
| """ | |
| Create a directory if it does not exist. | |
| :param directory_path: Path to the directory to be created | |
| """ | |
| try: | |
| # Check if the directory exists | |
| if not os.path.exists(directory_path): | |
| # Create the directory | |
| os.makedirs(directory_path) | |
| print(f"Directory '{directory_path}' created successfully.") | |
| else: | |
| print(f"Directory '{directory_path}' already exists.") | |
| except PermissionError: | |
| print(f"Permission denied for creating the directory '{directory_path}'.") | |
| except Exception as e: | |
| print(f"An error occurred: {e}") | |
| def save_dict_to_json(dictionary, file_path): | |
| """ | |
| Save a dictionary to a JSON file. | |
| :param dictionary: Dictionary to save | |
| :param file_path: Path to the JSON file | |
| """ | |
| try: | |
| with open(file_path, 'w') as json_file: | |
| json.dump(dictionary, json_file, indent=4) | |
| print(f"Dictionary successfully saved to {file_path}") | |
| except Exception as e: | |
| print(f"An error occurred: {e}") | |
| def erase_png_files(directory): | |
| """ | |
| Erases all .png files from the specified directory. | |
| Parameters: | |
| directory (str): The path to the directory where .png files should be erased. | |
| Returns: | |
| int: The number of .png files deleted. | |
| """ | |
| # Construct the path to all .png files in the directory | |
| png_files = glob.glob(os.path.join(directory, '*.png')) | |
| # Delete each .png file | |
| for file_path in png_files: | |
| try: | |
| os.remove(file_path) | |
| print(f"Deleted: {file_path}") | |
| except Exception as e: | |
| print(f"Error deleting {file_path}: {e}") | |
| return len(png_files) | |
| def read_dict_from_json(file_path): | |
| """ | |
| Reads a dictionary from a JSON file. | |
| Parameters: | |
| file_path (str): The path to the JSON file. | |
| Returns: | |
| dict: The dictionary read from the JSON file. | |
| """ | |
| try: | |
| with open(file_path, 'r') as file: | |
| data = json.load(file) | |
| return data | |
| except FileNotFoundError: | |
| print(f"Error: The file at {file_path} was not found.") | |
| return None | |
| except json.JSONDecodeError: | |
| print(f"Error: The file at {file_path} is not a valid JSON file.") | |
| return None | |
| except Exception as e: | |
| print(f"An unexpected error occurred: {e}") | |
| return None |