# CoMemNet / src/model/replay.py
# Uploaded by mei2333 with huggingface_hub (commit ba569ff, verified); 7.66 kB.
import sys
sys.path.append('src/')
import numpy as np
from scipy.stats import entropy as kldiv
from utils.dataloader import Cotinual_learning_DataLoader
import torch
from scipy.spatial import distance
import os.path as osp
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.stats import wasserstein_distance
import matplotlib.pyplot as plt
import pickle
import matplotlib.pyplot as plt
import numpy as np
from fractions import Fraction
def contrastive_loss(embedding1, embedding2, temperature=0.1):
    """Return a cross-entropy contrastive loss between two embedding batches.

    Row ``i`` of ``embedding1`` is scored against every row of ``embedding2``
    with cosine similarity; the matching row ``i`` is used as the target
    class and all other rows act as negatives.

    Args:
        embedding1: Tensor whose first axis indexes samples.
        embedding2: Tensor paired row-by-row with ``embedding1``.
        temperature: Divisor applied to the similarities before the softmax.

    Returns:
        Scalar tensor holding the mean cross-entropy over the rows.
    """
    # Insert singleton axes so the two batches broadcast into a pairwise grid.
    anchors = embedding1.unsqueeze(1)
    candidates = embedding2.unsqueeze(0)
    logits = F.cosine_similarity(anchors, candidates, dim=-1) / temperature
    # The positive for row i sits on the diagonal, i.e. class index i.
    targets = torch.arange(logits.size(0)).to(embedding1.device)
    return F.cross_entropy(logits, targets)
def get_feature(data, graph, args, model, adj):
    """Run ``model.target_branch`` on the most recent slice of ``data``.

    The slice ``data[-288*7-1:-1]`` is reshaped to
    ``(-1, args.x_len, node_size, 3)`` and fed through the project data
    loader as one batch covering every window.

    Args:
        data: Array indexed as ``[time, node, ...]``; ``shape[1]`` is used as
            the node count (assumes 3 trailing channels per the reshape).
        graph: Unused by this function.
        args: Namespace providing ``x_len`` and ``year``.
        model: Object exposing ``target_branch(batch, year)``.
        adj: Unused by this function.

    Returns:
        The branch output converted to a NumPy array on the CPU.
    """
    num_nodes = data.shape[1]
    # 288*7 rows, excluding the very last one (presumably one week of
    # 5-minute readings - TODO confirm against the dataset).
    recent = data[-288 * 7 - 1:-1, :]
    windows = np.reshape(recent, (-1, args.x_len, num_nodes, 3))
    loader = Cotinual_learning_DataLoader(
        windows,
        batch_size=windows.shape[0],
        shuffle=True,
        pad_with_last_sample=True,
    )
    # batch_size equals the number of windows, so the first batch is the
    # whole slice; its features are returned immediately.
    for batch in loader.get_iterator():
        return model.target_branch(batch, args.year).cpu().detach().numpy()
def get_current(data, graph, args, model, adj):
    """Run the full ``model`` forward pass on the most recent slice of ``data``.

    The slice ``data[-288*7-1:-1]`` is reshaped to
    ``(-1, args.x_len, node_size, 3)`` and fed through the project data
    loader as one batch covering every window.

    Args:
        data: Array indexed as ``[time, node, ...]``; ``shape[1]`` is used as
            the node count (assumes 3 trailing channels per the reshape).
        graph: Unused by this function.
        args: Namespace providing ``x_len`` and ``year``.
        model: Callable invoked as ``model(batch, year)``.
        adj: Unused by this function.

    Returns:
        The model output converted to a NumPy array on the CPU.
    """
    num_nodes = data.shape[1]
    # 288*7 rows, excluding the very last one (presumably one week of
    # 5-minute readings - TODO confirm against the dataset).
    recent = data[-288 * 7 - 1:-1, :]
    windows = np.reshape(recent, (-1, args.x_len, num_nodes, 3))
    loader = Cotinual_learning_DataLoader(
        windows,
        batch_size=windows.shape[0],
        shuffle=True,
        pad_with_last_sample=True,
    )
    # batch_size equals the number of windows, so the first batch is the
    # whole slice; its output is returned immediately.
    for batch in loader.get_iterator():
        return model(batch, args.year).cpu().detach().numpy()
def get_adj(year, args):
    """Load the adjacency matrix saved for ``year`` and row-normalise it.

    Reads ``<args.graph_path>/<year>_adj.npz``, takes the array stored under
    the key ``"x"`` and divides every row by its sum. ``1e-6`` is added to
    the row sums so that all-zero rows yield zeros instead of NaN.

    Args:
        year: Value whose ``str()`` forms the file-name prefix.
        args: Namespace providing ``graph_path`` and ``device``.

    Returns:
        The normalised matrix as a ``torch.float`` tensor on ``args.device``.

    Raises:
        FileNotFoundError: If the ``.npz`` file does not exist.
        KeyError: If the archive has no ``"x"`` entry.
    """
    adj_path = osp.join(args.graph_path, str(year) + "_adj.npz")
    # np.load on an .npz keeps the archive's file handle open until it is
    # closed; the context manager releases it once the array is read.
    with np.load(adj_path) as archive:
        adj = archive["x"]
    adj = adj / (np.sum(adj, 1, keepdims=True) + 1e-6)
    return torch.from_numpy(adj).to(torch.float).to(args.device)
def score_func(pre_data, cur_data, args):
    """Return the indices of the ``args.topm`` nodes whose data shifted most.

    For every node, the previous and current values are binned into two
    10-bin histograms over their shared ``[min, max]`` range, and the KL
    divergence ``KL(previous || current)`` is used as the node's score.

    Args:
        pre_data: Array indexed as ``[time, node, ...]``; ``shape[1]`` fixes
            the number of nodes scored.
        cur_data: Array with at least as many nodes as ``pre_data``.
        args: Namespace providing ``topm``, the number of nodes to return.

    Returns:
        1-D integer array with the indices of the highest-scoring nodes, in
        no particular order. ``topm`` is clamped to the number of nodes, and
        a non-positive ``topm`` yields an empty array.
    """
    node_size = pre_data.shape[1]
    scores = np.empty(node_size, dtype=float)
    for node in range(node_size):
        pre_vals = np.asarray(pre_data[:, node])
        cur_vals = np.asarray(cur_data[:, node])
        # np.min/np.max reduce over every remaining axis, so per-node slices
        # with trailing feature dimensions are handled as well.
        min_val = min(np.min(pre_vals), np.min(cur_vals))
        max_val = max(np.max(pre_vals), np.max(cur_vals))
        pre_hist, _ = np.histogram(pre_vals, bins=10, range=(min_val, max_val))
        cur_hist, _ = np.histogram(cur_vals, bins=10, range=(min_val, max_val))
        pre_prob = pre_hist * 1.0 / pre_hist.sum()
        cur_prob = cur_hist * 1.0 / cur_hist.sum()
        scores[node] = kldiv(pre_prob, cur_prob)

    # Clamp so argpartition never sees an out-of-range kth, and treat
    # topm <= 0 explicitly: a[-0:] would otherwise select every node.
    topm = min(args.topm, node_size)
    if topm <= 0:
        return np.empty(0, dtype=np.intp)
    return np.argpartition(scores, -topm)[-topm:]
def visualize_distributions(save_dis, top_nodes, args,
                            save_dir='/root/autodl-fs/CoMemNet/figure/PEMSD8'):
    """Save bar charts comparing previous/current histograms of chosen nodes.

    One PNG named ``node_distributions_<year>.png`` is written per year, with
    one subplot per entry of ``top_nodes``. Only the histogram stored under
    ``(node, 0)`` is drawn; nodes without that key leave their slot empty.

    Args:
        save_dis: Mapping ``(node, timestep) -> (pre_prob, cur_prob)`` where
            each value holds 10 bin heights.
        top_nodes: Sequence of node indices to plot.
        args: Namespace providing ``year`` and, optionally, ``years``; when
            ``years`` is present one figure is written per entry.
        save_dir: Directory receiving the PNG files. It is created if it
            does not exist. Defaults to the path previously hard-coded here.
    """
    import os  # function scope: the module itself only imports os.path

    if len(top_nodes) == 0:
        # No nodes selected: skip plotting rather than build a zero-height figure.
        return

    os.makedirs(save_dir, exist_ok=True)

    # Layout values shared by every subplot.
    bins = np.linspace(0, 1.0, 11)
    left_edges = bins[:-1]
    pre_positions = left_edges - 0.02
    cur_positions = left_edges + 0.02
    fraction_labels = [f"{tenth}/10" for tenth in range(len(bins))]

    years = getattr(args, 'years', [args.year])
    for year in years:
        fig = plt.figure(figsize=(25, 10 * len(top_nodes)))
        try:
            for idx, node in enumerate(top_nodes):
                if (node, 0) not in save_dis:
                    continue
                pre_prob, cur_prob = save_dis[(node, 0)]
                plt.subplot(len(top_nodes), 1, idx + 1)
                plt.bar(pre_positions, pre_prob, width=0.04, alpha=1, label=f'Previous', color='peachpuff')
                plt.bar(cur_positions, cur_prob, width=0.04, alpha=1, label=f'Current', color='lightskyblue')
                plt.xlabel('Normalized Range', fontsize=30)
                plt.ylabel('Density', fontsize=30)
                plt.title(f'Distribution for Node {node} at {year} year', fontsize=30)
                plt.legend(fontsize=30)
                plt.grid(True, alpha=0.3)
                plt.xticks(bins, fraction_labels, rotation=0, fontsize=30)
                plt.yticks(fontsize=30)
            plt.tight_layout()
            plt.savefig(osp.join(save_dir, f'node_distributions_{year}.png'))
        finally:
            # Release the figure even when drawing or saving fails.
            plt.close(fig)
        print(f"Adjusted distribution plot for year {year} saved as 'node_distributions_{year}.png'")
def _raw_kl_scores(pre_data, cur_data):
    """Score each node by KL divergence between 10-bin raw-value histograms."""
    node_size = pre_data.shape[1]
    scores = []
    for node in range(node_size):
        pre_vals = pre_data[:, node, :]
        cur_vals = cur_data[:, node, :]
        max_val = max(np.max(pre_vals), np.max(cur_vals))
        min_val = min(np.min(pre_vals), np.min(cur_vals))
        pre_prob, _ = np.histogram(pre_vals, bins=10, range=(min_val, max_val))
        cur_prob, _ = np.histogram(cur_vals, bins=10, range=(min_val, max_val))
        pre_prob = pre_prob * 1.0 / pre_prob.sum()
        cur_prob = cur_prob * 1.0 / cur_prob.sum()
        scores.append(kldiv(pre_prob, cur_prob))
    return np.asarray(scores)


def _feature_shift_scores(pre_feat, cur_feat):
    """Score each node by summed per-timestep histogram distance of features.

    Returns ``(scores, histograms)`` where ``histograms`` maps
    ``(node, timestep)`` to ``[pre_prob, cur_prob]``.
    """
    scores = []
    histograms = {}
    for i in range(pre_feat.shape[2]):
        score_ = 0.0
        for j in range(pre_feat.shape[1]):
            # Best effort: a failing (node, timestep) pair is reported and
            # skipped so one bad column does not abort the whole selection.
            try:
                pre_col = pre_feat[:, j, i, 0]
                cur_col = cur_feat[:, j, i, 0]
                pre_min, pre_max = np.min(pre_col), np.max(pre_col)
                cur_min, cur_max = np.min(cur_col), np.max(cur_col)
                # Either side being constant makes min-max scaling divide by
                # zero; the resulting NaN score would sort to the top.
                if pre_max == pre_min or cur_max == cur_min:
                    print(f"Warning: Node {i}, timestep {j} has constant value")
                    continue
                pre_norm = (pre_col - pre_min) / (pre_max - pre_min)
                cur_norm = (cur_col - cur_min) / (cur_max - cur_min)
                pre_prob, _ = np.histogram(pre_norm, bins=10, range=(0, 1), density=True)
                cur_prob, _ = np.histogram(cur_norm, bins=10, range=(0, 1), density=True)
                histograms[(i, j)] = [pre_prob, cur_prob]
                # NOTE(review): the bin densities are passed as if they were
                # samples, so the distance ignores which bin holds which
                # density. If a distance between the histograms is intended,
                # pass bin positions with u_weights/v_weights. Left unchanged
                # because it alters which nodes get selected - confirm first.
                score_ += wasserstein_distance(pre_prob, cur_prob)
            except Exception as e:
                print(f"Error for node {i}, timestep {j}: {e}")
                continue
        scores.append(score_)
    return scores, histograms


def influence_node_selection(model, args, pre_data, cur_data, pre_graph, cur_graph):
    """Pick the ``args.topm`` nodes whose distribution changed the most.

    Two strategies are selected through ``args.replay_strategy``:

    * ``'original'``: compares raw values of the slice ``[-288*7-1:-1]`` of
      both arrays with a per-node KL divergence of 10-bin histograms.
    * ``'feature'``: puts the model in eval mode, extracts features for both
      years, min-max scales every (node, timestep) column and sums a
      Wasserstein distance between their 10-bin histograms. The histograms
      are pickled to ``save_dis.pkl`` in the working directory and plotted
      via ``visualize_distributions``. ``args.topm`` is overwritten with the
      clamped value.

    Args:
        model: Model used by the ``'feature'`` strategy; unused otherwise.
        args: Namespace providing ``replay_strategy`` and ``topm``; the
            ``'feature'`` strategy additionally reads ``year`` plus whatever
            the feature/adjacency helpers need.
        pre_data: Previous-period data indexed as ``[time, node, channel]``.
        cur_data: Current-period data with at least as many nodes.
        pre_graph: Forwarded to the previous-year feature extraction.
        cur_graph: Forwarded to the current-year feature extraction.

    Returns:
        ``'original'``: integer ndarray of node indices (unordered).
        ``'feature'``: list of node indices in ascending score order.
        Any other strategy: ``None``.
        ``topm`` is clamped to the number of scored nodes, and a
        non-positive ``topm`` yields an empty result.
    """
    if args.replay_strategy == 'original':
        pre_data = pre_data[-288*7-1:-1, :]
        cur_data = cur_data[-288*7-1:-1, :]
        scores = _raw_kl_scores(pre_data, cur_data)
        # Clamp locally so argpartition never receives an out-of-range kth;
        # topm <= 0 is handled apart because a[-0:] selects everything.
        topm = min(args.topm, len(scores))
        if topm <= 0:
            return np.empty(0, dtype=np.intp)
        return np.argpartition(scores, -topm)[-topm:]

    if args.replay_strategy == 'feature':
        model.eval()
        pre_adj = get_adj(args.year-1, args)
        cur_adj = get_adj(args.year, args)
        # Indexed below as (batch, timestep, node, channel).
        pre_feat = get_feature(pre_data, pre_graph, args, model, pre_adj)
        cur_feat = get_current(cur_data, cur_graph, args, model, cur_adj)
        # Only nodes present in both periods can be compared.
        num_nodes = min(pre_feat.shape[2], cur_feat.shape[2])
        pre_feat = pre_feat[:, :, :num_nodes, :]
        cur_feat = cur_feat[:, :, :num_nodes, :]
        print("Aligned num_nodes:", num_nodes)

        score, save_dis = _feature_shift_scores(pre_feat, cur_feat)
        with open('save_dis.pkl', 'wb') as f:
            pickle.dump(save_dis, f)

        args.topm = min(args.topm, len(score))
        if args.topm <= 0:
            return []
        top_nodes = np.argsort(score)[-args.topm:].tolist()
        visualize_distributions(save_dis, top_nodes, args)
        return top_nodes

    # Unknown strategy: keep the historical behaviour of returning None.
    return None