| import sys |
| sys.path.append('src/') |
| import numpy as np |
| from scipy.stats import entropy as kldiv |
| from utils.dataloader import Cotinual_learning_DataLoader |
| import torch |
| from scipy.spatial import distance |
| import os.path as osp |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from scipy.stats import wasserstein_distance |
| import matplotlib.pyplot as plt |
| import pickle |
| import matplotlib.pyplot as plt |
| import numpy as np |
| from fractions import Fraction |
|
|
|
|
def contrastive_loss(embedding1, embedding2, temperature=0.1):
    """Contrastive (InfoNCE-style) loss between two batches of embeddings.

    Row ``k`` of ``embedding1`` is treated as the positive match of row ``k``
    of ``embedding2``; all other rows of ``embedding2`` serve as negatives.

    Args:
        embedding1: tensor, presumably of shape (N, D) -- TODO confirm.
        embedding2: tensor, presumably of shape (M, D) with M >= N, since the
            target for row ``k`` is column ``k``.
        temperature: divisor applied to the cosine similarities; smaller
            values sharpen the softmax.

    Returns:
        Scalar tensor: mean cross-entropy of the similarity logits against the
        diagonal targets ``0..N-1``.
    """
    # (N, 1, D) vs (1, M, D) broadcasts to an (N, M) cosine-similarity matrix.
    logits = F.cosine_similarity(
        embedding1.unsqueeze(1), embedding2.unsqueeze(0), dim=-1
    ) / temperature
    targets = torch.arange(logits.size(0)).to(embedding1.device)
    return F.cross_entropy(logits, targets)
|
|
def get_feature(data, graph, args, model, adj):
    """Run ``model.target_branch`` on the trailing 288*7-row window of ``data``.

    The 288*7 rows that precede the final row of ``data`` (the final row itself
    is excluded) are reshaped into samples of shape
    ``(args.x_len, node_size, 3)`` and passed to the model as one batch.

    Args:
        data: array whose axis 1 is the node axis; presumably
            ``(time, node, 3)`` -- TODO confirm. 288*7 rows must reshape into
            whole ``(args.x_len, node_size, 3)`` samples.
        graph: unused; kept for call-site compatibility.
        args: namespace providing ``x_len`` and ``year``.
        model: object exposing ``target_branch(batch, year)`` that returns a
            tensor.
        adj: unused; kept for call-site compatibility.

    Returns:
        numpy array holding the model output for the batch produced by the
        loader. Only the last batch is kept; with ``batch_size`` equal to the
        sample count the loader is presumably expected to yield exactly one.

    Raises:
        ValueError: if the loader yields no batches (previously this surfaced
            as an ``UnboundLocalError``).
    """
    node_size = data.shape[1]
    window = np.reshape(data[-288 * 7 - 1:-1, :], (-1, args.x_len, node_size, 3))
    loader = Cotinual_learning_DataLoader(
        window, batch_size=window.shape[0], shuffle=True, pad_with_last_sample=True
    )
    feature = None
    # Inference only: the result is detached anyway, so skip building a graph.
    with torch.no_grad():
        for batch in loader.get_iterator():
            feature = model.target_branch(batch, args.year)
    if feature is None:
        raise ValueError("get_feature: data loader produced no batches")
    return feature.cpu().detach().numpy()
|
|
def get_current(data, graph, args, model, adj):
    """Run the full ``model`` forward pass on the trailing 288*7-row window of ``data``.

    The 288*7 rows that precede the final row of ``data`` (the final row itself
    is excluded) are reshaped into samples of shape
    ``(args.x_len, node_size, 3)`` and passed to ``model(batch, args.year)``
    as one batch.

    Args:
        data: array whose axis 1 is the node axis; presumably
            ``(time, node, 3)`` -- TODO confirm. 288*7 rows must reshape into
            whole ``(args.x_len, node_size, 3)`` samples.
        graph: unused; kept for call-site compatibility.
        args: namespace providing ``x_len`` and ``year``.
        model: callable ``model(batch, year)`` that returns a tensor.
        adj: unused; kept for call-site compatibility.

    Returns:
        numpy array holding the model output for the batch produced by the
        loader. Only the last batch is kept; with ``batch_size`` equal to the
        sample count the loader is presumably expected to yield exactly one.

    Raises:
        ValueError: if the loader yields no batches (previously this surfaced
            as an ``UnboundLocalError``).
    """
    node_size = data.shape[1]
    window = np.reshape(data[-288 * 7 - 1:-1, :], (-1, args.x_len, node_size, 3))
    loader = Cotinual_learning_DataLoader(
        window, batch_size=window.shape[0], shuffle=True, pad_with_last_sample=True
    )
    output = None
    # Inference only: the result is detached anyway, so skip building a graph.
    with torch.no_grad():
        for batch in loader.get_iterator():
            output = model(batch, args.year)
    if output is None:
        raise ValueError("get_current: data loader produced no batches")
    return output.cpu().detach().numpy()
|
|
|
|
def get_adj(year, args):
    """Load the adjacency matrix for ``year`` and row-normalize it.

    Reads the array stored under key ``"x"`` in
    ``<args.graph_path>/<year>_adj.npz``. It then divides each row by its sum;
    ``1e-6`` is added to the sum so all-zero rows do not divide by zero.

    Args:
        year: year used to build the file name.
        args: namespace providing ``graph_path`` and ``device``.

    Returns:
        ``torch.float`` (float32) tensor on ``args.device``.
    """
    path = osp.join(args.graph_path, str(year) + "_adj.npz")
    # np.load on an .npz returns an NpzFile that keeps the file open until
    # closed; the context manager releases the handle deterministically.
    with np.load(path) as archive:
        adj = archive["x"]
    adj = adj / (np.sum(adj, 1, keepdims=True) + 1e-6)
    return torch.from_numpy(adj).to(torch.float).to(args.device)
|
|
def score_func(pre_data, cur_data, args):
    """Select the ``args.topm`` nodes whose value distribution shifted most.

    For every node (column), 10-bin histograms of ``pre_data`` and
    ``cur_data`` are built over their shared ``[min, max]`` range. The shift
    score is the KL divergence ``KL(pre || cur)`` of the normalized histograms.

    Args:
        pre_data: 2-D array indexed as ``[sample, node]``.
        cur_data: 2-D array indexed as ``[sample, node]`` with at least as many
            node columns as ``pre_data``.
        args: namespace providing ``topm``, the number of nodes to select.

    Returns:
        Integer ndarray with the indices of the ``topm`` highest-scoring nodes,
        in no particular order. ``topm`` is clamped to ``[0, node_count]``, so
        ``topm == 0`` yields an empty array and ``topm`` larger than the node
        count yields every node.
    """
    node_size = pre_data.shape[1]
    score = []
    for node in range(node_size):
        pre_col = pre_data[:, node]
        cur_col = cur_data[:, node]
        # A shared range gives both histograms identical bin edges.
        lo = min(np.min(pre_col), np.min(cur_col))
        hi = max(np.max(pre_col), np.max(cur_col))
        pre_prob, _ = np.histogram(pre_col, bins=10, range=(lo, hi))
        cur_prob, _ = np.histogram(cur_col, bins=10, range=(lo, hi))
        pre_prob = pre_prob / pre_prob.sum()
        cur_prob = cur_prob / cur_prob.sum()
        # Becomes inf when ``cur`` has an empty bin where ``pre`` does not.
        score.append(kldiv(pre_prob, cur_prob))

    topm = max(0, min(int(args.topm), node_size))
    if topm == 0:
        # argpartition(...)[-0:] would return *every* node, not none.
        return np.empty(0, dtype=np.intp)
    return np.argpartition(np.asarray(score), -topm)[-topm:]
|
|
|
|
|
|
def visualize_distributions(save_dis, top_nodes, args, save_dir=None):
    """Plot previous-vs-current histograms (timestep 0) for the selected nodes.

    Writes one figure per year to ``<save_dir>/node_distributions_<year>.png``.
    The years are taken from ``args.years``, or just ``args.year`` when that
    attribute is absent. Each figure has one subplot per node in ``top_nodes``.
    A node with no ``(node, 0)`` entry in ``save_dis`` leaves its subplot
    empty.

    NOTE(review): ``save_dis`` carries no year information, so every year's
    figure shows the same bars; only the title differs.

    Args:
        save_dis: dict mapping ``(node, timestep)`` to ``[pre_prob, cur_prob]``.
            Each value is a 10-bin histogram over ``[0, 1]``.
        top_nodes: iterable of node indices to plot. If it is empty, nothing is
            written.
        args: namespace providing ``year`` and optionally ``years``.
        save_dir: output directory. Defaults to the previously hard-coded
            location and is created if missing.
    """
    import os

    if save_dir is None:
        save_dir = '/root/autodl-fs/CoMemNet/figure/PEMSD8'
    top_nodes = list(top_nodes)
    if not top_nodes:
        # A zero-height figure has nothing to show and cannot be laid out.
        return
    os.makedirs(save_dir, exist_ok=True)

    bins = np.linspace(0, 1.0, 11)
    # True bin centres, so each bar pair sits inside the bin it represents.
    bin_centers = (bins[:-1] + bins[1:]) / 2
    pre_positions = bin_centers - 0.02
    cur_positions = bin_centers + 0.02
    # Integer-based labels avoid float truncation of e.g. int(0.3 * 10).
    fraction_labels = [f"{k}/10" for k in range(len(bins))]

    years = getattr(args, 'years', [args.year])
    for year in years:
        plt.figure(figsize=(25, 10 * len(top_nodes)))

        for idx, node in enumerate(top_nodes):
            if (node, 0) not in save_dis:
                continue
            pre_prob, cur_prob = save_dis[(node, 0)]

            plt.subplot(len(top_nodes), 1, idx + 1)
            plt.bar(pre_positions, pre_prob, width=0.04, alpha=1, label='Previous', color='peachpuff')
            plt.bar(cur_positions, cur_prob, width=0.04, alpha=1, label='Current', color='lightskyblue')
            plt.xlabel('Normalized Range', fontsize=30)
            plt.ylabel('Density', fontsize=30)
            plt.title(f'Distribution for Node {node} at {year} year', fontsize=30)
            plt.legend(fontsize=30)
            plt.grid(True, alpha=0.3)
            plt.xticks(bins, fraction_labels, rotation=0, fontsize=30)
            plt.yticks(fontsize=30)

        plt.tight_layout()
        plt.savefig(osp.join(save_dir, f'node_distributions_{year}.png'))
        plt.close()
        print(f"Adjusted distribution plot for year {year} saved as 'node_distributions_{year}.png'")
|
|
|
|
def _select_by_raw_kl(pre_data, cur_data, topm):
    """Return the ``topm`` node indices with the largest raw-data KL shift."""
    # Same 288*7-row window (final row excluded) as get_feature/get_current.
    pre_data = pre_data[-288 * 7 - 1:-1, :]
    cur_data = cur_data[-288 * 7 - 1:-1, :]
    node_size = pre_data.shape[1]
    score = []
    for node in range(node_size):
        pre_node = pre_data[:, node, :]
        cur_node = cur_data[:, node, :]
        # Shared range -> identical bin edges for both histograms.
        lo = min(np.min(pre_node), np.min(cur_node))
        hi = max(np.max(pre_node), np.max(cur_node))
        pre_prob, _ = np.histogram(pre_node, bins=10, range=(lo, hi))
        cur_prob, _ = np.histogram(cur_node, bins=10, range=(lo, hi))
        score.append(kldiv(pre_prob / pre_prob.sum(), cur_prob / cur_prob.sum()))

    topm = max(0, min(int(topm), node_size))
    if topm == 0:
        # argpartition(...)[-0:] would return every node instead of none.
        return np.empty(0, dtype=np.intp)
    return np.argpartition(np.asarray(score), -topm)[-topm:]


def _feature_wasserstein_scores(pre_feat, cur_feat):
    """Score nodes by per-timestep Wasserstein shift of feature channel 0.

    Returns ``(score, save_dis)``. ``score[i]`` is node ``i``'s summed score.
    ``save_dis[(i, j)]`` is ``[pre_density, cur_density]`` for node ``i`` and
    timestep ``j``.
    """
    # Node counts can differ between periods; compare only the shared prefix.
    num_nodes = min(pre_feat.shape[2], cur_feat.shape[2])
    pre_feat = pre_feat[:, :, :num_nodes, :]
    cur_feat = cur_feat[:, :, :num_nodes, :]
    print("Aligned num_nodes:", num_nodes)

    score = []
    save_dis = {}
    for i in range(num_nodes):
        score_ = 0.0
        for j in range(pre_feat.shape[1]):
            pre_vals = pre_feat[:, j, i, 0]
            cur_vals = cur_feat[:, j, i, 0]
            try:
                # A constant series cannot be min-max scaled (0/0 -> NaN, which
                # argsort would then rank as the most-changed node).
                if (np.max(pre_vals) == np.min(pre_vals)
                        or np.max(cur_vals) == np.min(cur_vals)):
                    print(f"Warning: Node {i}, timestep {j} has constant value")
                    continue
                # NOTE(review): each period is scaled to [0, 1] independently,
                # so pure level shifts are not detected -- only shape changes.
                pre_norm = (pre_vals - np.min(pre_vals)) / (np.max(pre_vals) - np.min(pre_vals))
                cur_norm = (cur_vals - np.min(cur_vals)) / (np.max(cur_vals) - np.min(cur_vals))

                pre_prob, edges = np.histogram(pre_norm, bins=10, range=(0, 1), density=True)
                cur_prob, _ = np.histogram(cur_norm, bins=10, range=(0, 1), density=True)

                save_dis[(i, j)] = [pre_prob, cur_prob]
                # wasserstein_distance takes sample *values*; histograms must
                # be passed as weights on the bin centres, otherwise bin
                # positions are ignored entirely.
                centers = (edges[:-1] + edges[1:]) / 2
                score_ += wasserstein_distance(
                    centers, centers, u_weights=pre_prob, v_weights=cur_prob
                )
            except Exception as e:
                # Best-effort: one bad (node, timestep) pair must not abort selection.
                print(f"Error for node {i}, timestep {j}: {e}")
                continue
        score.append(score_)
    return score, save_dis


def influence_node_selection(model, args, pre_data, cur_data, pre_graph, cur_graph):
    """Pick the nodes whose data distribution changed most between two periods.

    The strategy is selected by ``args.replay_strategy``:

    * ``'original'``: compare raw ``pre_data`` / ``cur_data`` per node with a
      KL divergence over shared-range 10-bin histograms.
    * ``'feature'``: put ``model`` in eval mode and compute features for both
      periods via ``get_feature`` / ``get_current``. Each node is scored by the
      Wasserstein distance between histograms of feature channel 0, summed
      over timesteps. This branch also writes ``save_dis.pkl`` to the working
      directory and plots the selected nodes via ``visualize_distributions``;
      a plotting ``OSError`` is reported rather than raised.

    Args:
        model: model used by the ``'feature'`` strategy (unused otherwise).
        args: namespace providing ``replay_strategy`` and ``topm``. The
            ``'feature'`` strategy also needs ``year`` plus whatever
            ``get_adj`` / ``get_feature`` read.
        pre_data: previous-period data. ``'original'`` indexes it as
            ``[time, node, channel]``.
        cur_data: current-period data, same layout as ``pre_data``.
        pre_graph: forwarded to ``get_feature``.
        cur_graph: forwarded to ``get_current``.

    Returns:
        * ``'original'``: ndarray of up to ``args.topm`` node indices, unordered.
        * ``'feature'``: list of up to ``args.topm`` node indices in ascending
          score order (most-changed last). ``args.topm`` is clamped in place to
          ``[0, node_count]``.
        * any other strategy: ``None``.
    """
    if args.replay_strategy == 'original':
        return _select_by_raw_kl(pre_data, cur_data, args.topm)

    if args.replay_strategy == 'feature':
        model.eval()
        pre_adj = get_adj(args.year - 1, args)
        cur_adj = get_adj(args.year, args)

        pre_feat = get_feature(pre_data, pre_graph, args, model, pre_adj)
        cur_feat = get_current(cur_data, cur_graph, args, model, cur_adj)

        score, save_dis = _feature_wasserstein_scores(pre_feat, cur_feat)

        with open('save_dis.pkl', 'wb') as f:
            pickle.dump(save_dis, f)

        args.topm = max(0, min(args.topm, len(score)))
        order = np.argsort(score)
        # Slice from len - topm so that topm == 0 selects nothing (not all).
        top_nodes = order[len(order) - args.topm:].tolist()

        try:
            visualize_distributions(save_dis, top_nodes, args)
        except OSError as e:
            # Plotting is diagnostic only; don't lose the selection over it.
            print(f"Skipping distribution plot: {e}")

        return top_nodes

    # Unknown strategy: keep the historical behaviour of returning None.
    return None
|
|