import sys

# Must run before the project-local import below so that ``utils`` resolves.
sys.path.append('src/')

import os
import os.path as osp
import pickle
from fractions import Fraction

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.spatial import distance
from scipy.stats import entropy as kldiv
from scipy.stats import wasserstein_distance

from utils.dataloader import Cotinual_learning_DataLoader

# Number of trailing time steps inspected when comparing two periods.
# NOTE(review): 288 * 7 presumably means one week of 5-minute samples - confirm.
_WINDOW_STEPS = 288 * 7

# Number of histogram bins used by every distribution comparison in this file.
_NUM_BINS = 10

# Directory the distribution figures were historically written to.
_DEFAULT_FIGURE_DIR = '/root/autodl-fs/CoMemNet/figure/PEMSD8'


def contrastive_loss(embedding1, embedding2, temperature=0.1):
    """Return a cross-entropy loss over pairwise cosine similarities.

    Row ``i`` of the similarity matrix compares ``embedding1[i]`` with every
    row of ``embedding2``; the target class for row ``i`` is column ``i``.

    Args:
        embedding1: Tensor whose first axis indexes samples.
        embedding2: Tensor whose first axis indexes samples.
        temperature: Divisor applied to the similarities before the softmax.

    Returns:
        Scalar loss tensor.
    """
    sim_matrix = F.cosine_similarity(
        embedding1.unsqueeze(1), embedding2.unsqueeze(0), dim=-1
    )
    logits = sim_matrix / temperature
    labels = torch.arange(logits.size(0), device=embedding1.device)
    return F.cross_entropy(logits, labels)


def _forward_recent_window(data, args, forward):
    """Run ``forward`` on the trailing window of ``data`` as a single batch.

    The last ``_WINDOW_STEPS`` rows (excluding the very last one) are reshaped
    to ``(-1, args.x_len, node_size, 3)`` and fed through the project data
    loader with a batch size equal to the number of reshaped samples.

    Raises:
        ValueError: If the data loader yields no batch.
    """
    node_size = data.shape[1]
    window = np.reshape(
        data[-_WINDOW_STEPS - 1:-1, :], (-1, args.x_len, node_size, 3)
    )
    dataloader = Cotinual_learning_DataLoader(
        window,
        batch_size=window.shape[0],
        shuffle=True,
        pad_with_last_sample=True,
    )
    feature = None
    # Only the detached output is used, so no autograd graph is needed.
    with torch.no_grad():
        for batch in dataloader.get_iterator():
            feature = forward(batch, args.year)
    if feature is None:
        raise ValueError("data loader produced no batches; cannot extract features")
    return feature.cpu().detach().numpy()


def get_feature(data, graph, args, model, adj):
    """Return ``model.target_branch`` output for the trailing window of ``data``.

    ``graph`` and ``adj`` are accepted for interface compatibility and are not
    used by this function.

    Returns:
        NumPy array holding the output of the last batch.
    """
    return _forward_recent_window(data, args, model.target_branch)


def get_current(data, graph, args, model, adj):
    """Return ``model(...)`` output for the trailing window of ``data``.

    ``graph`` and ``adj`` are accepted for interface compatibility and are not
    used by this function.

    Returns:
        NumPy array holding the output of the last batch.
    """
    return _forward_recent_window(data, args, model)


def get_adj(year, args):
    """Load ``<year>_adj.npz`` from ``args.graph_path`` and row-normalise it.

    Returns:
        Float tensor on ``args.device``; each row of array ``"x"`` is divided
        by its row sum plus ``1e-6``.
    """
    adj_path = osp.join(args.graph_path, str(year) + "_adj.npz")
    # ``np.load`` on an .npz keeps the archive open until it is closed.
    with np.load(adj_path) as archive:
        adj = archive["x"]
    adj = adj / (np.sum(adj, 1, keepdims=True) + 1e-6)
    return torch.from_numpy(adj).to(torch.float).to(args.device)


def _kl_drift_scores(pre_data, cur_data):
    """Return one KL-divergence score per node of ``pre_data``.

    For each node index along axis 1, both periods are histogrammed over their
    shared value range and the KL divergence of the normalised histograms is
    taken.

    NOTE(review): an empty bin in the current histogram where the previous one
    is non-empty yields an infinite score; this matches the prior behaviour.
    """
    scores = []
    for node in range(pre_data.shape[1]):
        pre_vals = pre_data[:, node]
        cur_vals = cur_data[:, node]
        max_val = max(np.max(pre_vals), np.max(cur_vals))
        min_val = min(np.min(pre_vals), np.min(cur_vals))
        pre_prob, _ = np.histogram(pre_vals, bins=_NUM_BINS, range=(min_val, max_val))
        pre_prob = pre_prob * 1.0 / sum(pre_prob)
        cur_prob, _ = np.histogram(cur_vals, bins=_NUM_BINS, range=(min_val, max_val))
        cur_prob = cur_prob * 1.0 / sum(cur_prob)
        scores.append(kldiv(pre_prob, cur_prob))
    return np.asarray(scores)


def _top_m_indices(scores, topm):
    """Return the indices of the ``topm`` largest scores (unordered).

    ``topm`` is clamped to the number of scores; a non-positive ``topm``
    yields an empty index array instead of every index.
    """
    scores = np.asarray(scores)
    topm = min(topm, scores.size)
    if topm <= 0:
        return np.empty(0, dtype=np.intp)
    return np.argpartition(scores, -topm)[-topm:]


def score_func(pre_data, cur_data, args):
    """Select the ``args.topm`` nodes whose value distribution changed most.

    Args:
        pre_data: Array of the previous period; axis 1 indexes nodes.
        cur_data: Array of the current period; axis 1 indexes nodes.
        args: Object providing ``topm``.

    Returns:
        NumPy array of node indices (order unspecified).
    """
    return _top_m_indices(_kl_drift_scores(pre_data, cur_data), args.topm)


def visualize_distributions(save_dis, top_nodes, args, save_dir=None):
    """Plot previous/current histograms of the selected nodes and save as PNG.

    One figure is written per year in ``args.years`` (or ``[args.year]`` when
    that attribute is absent). Only the entry for timestep 0 of each node is
    drawn; nodes without a ``(node, 0)`` entry are skipped.

    Args:
        save_dis: Mapping ``(node, timestep) -> [pre_prob, cur_prob]``.
        top_nodes: Sequence of node indices to plot, one subplot each.
        args: Object providing ``year`` and optionally ``years``.
        save_dir: Output directory; defaults to the historical location.
            It is created when missing.
    """
    # A zero-height figure is invalid, so there is nothing to draw.
    if len(top_nodes) == 0:
        return
    if save_dir is None:
        save_dir = _DEFAULT_FIGURE_DIR
    os.makedirs(save_dir, exist_ok=True)

    years = getattr(args, 'years', [args.year])
    bins = np.linspace(0, 1.0, _NUM_BINS + 1)
    left_edges = bins[:-1]
    # The two bar series sit side by side around each left bin edge.
    pre_positions = left_edges - 0.02
    cur_positions = left_edges + 0.02
    fraction_labels = [f"{k}/10" for k in range(_NUM_BINS + 1)]

    for year in years:
        plt.figure(figsize=(25, 10 * len(top_nodes)))
        for idx, node in enumerate(top_nodes):
            if (node, 0) not in save_dis:
                continue
            pre_prob, cur_prob = save_dis[(node, 0)]
            plt.subplot(len(top_nodes), 1, idx + 1)
            plt.bar(pre_positions, pre_prob, width=0.04, alpha=1,
                    label=f'Previous', color='peachpuff')
            plt.bar(cur_positions, cur_prob, width=0.04, alpha=1,
                    label=f'Current', color='lightskyblue')
            plt.xlabel('Normalized Range', fontsize=30)
            plt.ylabel('Density', fontsize=30)
            plt.title(f'Distribution for Node {node} at {year} year', fontsize=30)
            plt.legend(fontsize=30)
            plt.grid(True, alpha=0.3)
            plt.xticks(bins, fraction_labels, rotation=0, fontsize=30)
            plt.yticks(fontsize=30)
        plt.tight_layout()
        plt.savefig(osp.join(save_dir, f'node_distributions_{year}.png'))
        plt.close()
        print(f"Adjusted distribution plot for year {year} saved as 'node_distributions_{year}.png'")


def _feature_drift_scores(pre_feat, cur_feat):
    """Score each node by summed Wasserstein distance across time steps.

    Both arrays are indexed as ``[sample, timestep, node, 0]``. For every
    (node, timestep) pair the samples of each period are min-max normalised
    independently, histogrammed over ``[0, 1]`` and compared.

    Returns:
        Tuple ``(scores, save_dis)`` where ``scores`` is a list with one float
        per node and ``save_dis`` maps ``(node, timestep)`` to
        ``[pre_hist, cur_hist]``.
    """
    num_nodes = min(pre_feat.shape[2], cur_feat.shape[2])
    print("Aligned num_nodes:", num_nodes)
    num_steps = min(pre_feat.shape[1], cur_feat.shape[1])
    # The histograms are distributions over these bin positions.
    bin_centers = (np.arange(_NUM_BINS) + 0.5) / _NUM_BINS

    scores = []
    save_dis = {}
    for i in range(num_nodes):
        node_score = 0.0
        for j in range(num_steps):
            # Best effort: one bad (node, timestep) pair must not abort the
            # whole selection, so failures are reported and skipped.
            try:
                pre_vals = pre_feat[:, j, i, 0]
                cur_vals = cur_feat[:, j, i, 0]
                pre_min, pre_max = np.min(pre_vals), np.max(pre_vals)
                cur_min, cur_max = np.min(cur_vals), np.max(cur_vals)
                # A constant slice on either side cannot be min-max scaled;
                # dividing by zero would inject NaN into the node score.
                if pre_max == pre_min or cur_max == cur_min:
                    print(f"Warning: Node {i}, timestep {j} has constant value")
                    continue
                # Normalise into fresh arrays so the inputs stay untouched.
                pre_norm = (pre_vals - pre_min) / (pre_max - pre_min)
                cur_norm = (cur_vals - cur_min) / (cur_max - cur_min)
                pre_prob, _ = np.histogram(
                    pre_norm, bins=_NUM_BINS, range=(0, 1), density=True
                )
                cur_prob, _ = np.histogram(
                    cur_norm, bins=_NUM_BINS, range=(0, 1), density=True
                )
                save_dis[(i, j)] = [pre_prob, cur_prob]
                # Histogram heights are weights over the bin positions, not
                # samples; passing them as values would ignore where the mass
                # sits.
                node_score += wasserstein_distance(
                    bin_centers, bin_centers,
                    u_weights=pre_prob, v_weights=cur_prob,
                )
            except Exception as e:
                print(f"Error for node {i}, timestep {j}: {e}")
                continue
        scores.append(node_score)
    return scores, save_dis


def influence_node_selection(model, args, pre_data, cur_data, pre_graph, cur_graph):
    """Select the nodes whose distribution drifted most between two periods.

    Behaviour depends on ``args.replay_strategy``:

    * ``'original'``: compare raw values of the trailing window with a
      per-node KL divergence and return a NumPy array of ``args.topm`` indices.
    * ``'feature'``: compare model outputs of the two periods with a per-node
      Wasserstein score, dump the histograms to ``save_dis.pkl``, plot the
      selected nodes and return a list of indices sorted by ascending score.
      ``args.topm`` is clamped in place to the number of scored nodes.
    * anything else: return ``None``.

    Args:
        model: Callable model exposing ``eval()`` and ``target_branch``.
        args: Object providing ``replay_strategy``, ``topm``, ``year``,
            ``x_len``, ``graph_path`` and ``device``.
        pre_data: Array of the previous period; axis 1 indexes nodes.
        cur_data: Array of the current period; axis 1 indexes nodes.
        pre_graph: Forwarded to ``get_feature``.
        cur_graph: Forwarded to ``get_current``.
    """
    if args.replay_strategy == 'original':
        pre_window = pre_data[-_WINDOW_STEPS - 1:-1, :]
        cur_window = cur_data[-_WINDOW_STEPS - 1:-1, :]
        return _top_m_indices(_kl_drift_scores(pre_window, cur_window), args.topm)

    if args.replay_strategy == 'feature':
        model.eval()
        pre_adj = get_adj(args.year - 1, args)
        cur_adj = get_adj(args.year, args)
        # Shape per the original note: (batch_size, num_timesteps, num_nodes, 1).
        pre_feat = get_feature(pre_data, pre_graph, args, model, pre_adj)
        cur_feat = get_current(cur_data, cur_graph, args, model, cur_adj)

        score, save_dis = _feature_drift_scores(pre_feat, cur_feat)

        with open('save_dis.pkl', 'wb') as f:
            pickle.dump(save_dis, f)

        # Kept for compatibility: callers may read the clamped value back.
        args.topm = min(args.topm, len(score))
        if args.topm <= 0:
            # ``[-0:]`` would select every node instead of none.
            return []
        top_nodes = np.argsort(score)[-args.topm:].tolist()
        visualize_distributions(save_dis, top_nodes, args)
        return top_nodes

    return None